Comment by libraryofbabel


Speculative decoding is an amazingly clever invention that almost seems too good to be true (faster inference with zero degradation in quality relative to the main model). The core idea is: if you can find a way to generate a short run of draft next tokens with a smaller model that have a reasonable likelihood of being correct, it's fast to check that they are actually correct with the main model because you can run the checks in parallel. And if you think about it, a lot of next tokens are pretty obvious in certain situations (e.g. it doesn't take a frontier model to guess the likely next token in "United States of...", and a lot of code is boilerplate that's easy to predict from previous code sections).

I always encourage folks who are interested in LLM internals to read up on speculative decoding (both the basic version and the more advanced MTP) and, if you have time, to try implementing your own version of it (writing the core without a coding agent, to begin with!).
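As a rough picture of the whole loop, a greedy-decoding toy sketch might look something like this (`speculative_generate`, `main_model`, and `draft_model` are made-up names; both models are assumed to be callables mapping token ids to logits, and there is no KV-cache reuse, so this is an illustration rather than an efficient implementation):

```python
import torch

def speculative_generate(main_model, draft_model, ids, n_new, k=4):
    # Both "models" are assumed callables: token ids [1, T] -> logits [1, T, vocab].
    # Greedy everywhere, no KV-cache reuse, to keep the sketch short.
    out = ids
    target_len = ids.shape[-1] + n_new
    while out.shape[-1] < target_len:
        # 1. The cheap draft model proposes k tokens, one at a time.
        draft = out
        for _ in range(k):
            next_id = draft_model(draft)[0, -1].argmax().view(1, 1)
            draft = torch.cat([draft, next_id], dim=-1)
        # 2. The main model scores every drafted position in ONE forward pass.
        logits = main_model(draft)
        T = out.shape[-1]
        for i in range(k):
            main_pick = logits[0, T - 1 + i].argmax().view(1, 1)
            out = torch.cat([out, main_pick], dim=-1)
            if main_pick.item() != draft[0, T + i].item():
                break  # first disagreement: keep the main model's token and re-draft
    return out[:, :target_len]
```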

> it's fast to check that they are actually correct with the main model because you can run the checks in parallel.

Can you give an intuition as to why it's faster? I would have thought that regardless of how many you run in parallel, the successful check still has to execute the full model over the full sequence, so you'd need exactly the same amount of time? Or is it a process of elimination, so it terminates early once it eliminates the non-viable choices? (In which case, how do you guarantee the correct output was speculatively generated at all, so that it's the last survivor?)

  • The small draft model proposes a sequence of tokens d1 d2 d3.

    The big target model calculates P(d1), P(d2 | d1), and P(d3 | d1 d2) in parallel.

    If we were just greedy decoding it would be simple: stop when the draft model doesn’t predict the most likely token as judged by the target model. At that point, append the correct token from the target model and kick off both models again in parallel.

    In practice we aren’t using greedy decoding. We are sampling, and we need to match the target model’s distribution. To do this, we accept tokens from the draft model probabilistically, which is possible because we have the logits of both the draft model and the target model at that position. The ratio of their softmax probabilities (target over draft, capped at 1) is used as the acceptance probability.

    You are right that actually accepting tokens has to happen sequentially but that’s a heck of a lot faster than a forward pass.
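    Concretely, that accept/reject step might look something like this (a sketch; `accept_or_resample` and its arguments are made-up names, where `p_target` and `q_draft` are the two models' probability vectors over the vocab at the position being checked):

    ```python
    import torch

    def accept_or_resample(p_target, q_draft, draft_token, generator=None):
        p, q = p_target[draft_token], q_draft[draft_token]
        # Accept with probability min(1, p/q); this keeps the overall output
        # distributed exactly like the target model's own samples.
        if torch.rand((), generator=generator) < torch.clamp(p / q, max=1.0):
            return draft_token, True
        # Rejected: resample from the "leftover" distribution max(0, p_target - q_draft).
        residual = torch.clamp(p_target - q_draft, min=0)
        return torch.multinomial(residual / residual.sum(), 1).item(), False
    ```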

    • nice ... I think I get the idea - it's effectively the same/similar benefit as batching, but you're batching against your own speculated future path. Which would be pointless if you didn't have a high-probability path to evaluate against - but the draft gives you that.


  • AIUI you run the checks of several predicted tokens in lockstep, and the computation for each token is served by the same data loaded from memory. In normal execution, each token would depend on the previous one, precluding the parallelization and causing much more per-token memory traffic.

    So this is a case of putting to work compute capacity that would otherwise sit idle waiting on the bottleneck (memory access).

  • An obscure fact about the transformer architecture is that it more or less computes the most likely next token for every single token in the context window at once. This is because the KV cache values needed to predict the next token are needed for every token, and the attention modules do nearly all the work, so once you've computed the KVs, running them through the final layers to get the target probabilities is nearly free.

    The reason it's designed this way is a bit subtle but it has the advantage during training that you can use a single block of 10 tokens to generate 9 training examples in parallel, so it's highly efficient. This efficiency is basically the main benefit of transformers - the algorithm parallelizes really well and that's what allowed the scale up to large language models as opposed to the previous reality of just language models.
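    To make the "a prediction at every position from one pass" point concrete, here's a toy sketch (an embedding plus a linear head standing in for the real attention stack; everything here is made up for illustration):

    ```python
    import torch
    import torch.nn.functional as F

    vocab, d = 100, 16
    embed = torch.nn.Embedding(vocab, d)    # toy stand-in for the transformer body
    lm_head = torch.nn.Linear(d, vocab)

    tokens = torch.randint(0, vocab, (1, 10))  # one block of 10 tokens
    hidden = embed(tokens)                     # [1, 10, d] (a real model applies causal attention here)
    logits = lm_head(hidden)                   # [1, 10, vocab]: a next-token prediction at every position

    # In training, position i is scored against token i+1, so one block yields 9 examples at once.
    loss = F.cross_entropy(logits[:, :-1].reshape(-1, vocab), tokens[:, 1:].reshape(-1))
    ```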

    The blog post does discuss why MTP is faster but it's maybe a bit hard to understand if you haven't studied LLM internals. During inference the hardware has arithmetic units idling because they spend so much time waiting for the weight matrices to get moved closer to the processors. Because data movement and computation can be overlapped, if you can reuse the same loaded data for multiple calculations at once you're winning - it's free latency-wise because you're just exploiting previously idle resources (it's not free in terms of energy).

    Speculative decoding and MTP exploit this to run the model in parallel on several tokens at once. Say your context window contains "The United". The KV cache has been populated by the main model for this set of tokens. The draft model is given "The United" and predicts " States of America" in one forward pass (this part, where it can predict multiple tokens at once with a single pass, is the MTP part). Then the main model is given the KV cache from last time along with " States of America". In its own forward pass it can then compute in parallel the completions of "The United", "The United States", "The United States of", and "The United States of America" (the last one might be an eos token indicating it wants to stop talking). That's the speculative decoding part.

    Now you decode the main model at each position (look at the token probabilities and pick one according to some decoding strategy). It's possible the main model didn't pick " States" at all, or picked " States" but then its prediction diverged, e.g. if it wants to say "The United States is a country". So you just select the tokens that match and toss all the tokens starting from the first one that didn't. Repeat.
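    In toy form, with strings standing in for token ids (made-up data, just to show the prefix acceptance):

    ```python
    draft = [" States", " of", " America"]   # the draft model's proposal
    main  = [" States", " is", " America"]   # the main model's pick at each checked position

    # Keep the matching prefix, then take the main model's token at the first mismatch.
    n_match = next((i for i, (d, m) in enumerate(zip(draft, main)) if d != m), len(draft))
    accepted = draft[:n_match] + ([main[n_match]] if n_match < len(draft) else [])
    # -> [' States', ' is']: " of" and " America" are tossed, and decoding resumes after " is"
    ```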

    The parallelism comes almost for free because the same weight matrices can be reused multiple times before they're swapped out for the next ones.

So we've basically taken the concept of branch prediction from CPUs and applied it to LLMs?

  • Maybe at a very high level of abstraction, but there's no branching involved.

    • Well, there are multiple token proposals processed in parallel, from which only one is picked - that seems like branching to me. The only difference is that in the case of a CPU there is always only one possible branch that is correct.

  • The concept of predicting future elements in a series is not specific to CS. It's older than computers.

  • Well, the TPUs they're running on don't have branch prediction, so that had to end up somewhere in the stack.

Naively it seems odd that running multiple checks in parallel is faster than just running the autoregressive model multiple times in series. It's the same amount of compute, right?

But I think the key is that in the standard autoregressive case we become memory-bandwidth bound, so there are tons of idle compute resources. And so checking multiple tokens is cheap, because we can batch and thus reuse the weights we've already read from memory for multiple tokens.

The verification step is similar to a prefill with a small batch size. The difference is what we do with the generated logits.
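A quick way to see the effect on your own machine (a toy with a random matrix standing in for one big weight matrix; names and sizes are made up, and exact numbers will vary a lot by hardware):

```python
import time
import torch

d, vocab, k = 4096, 32000, 8
W = torch.randn(vocab, d)   # stand-in for a weight matrix (~0.5 GB, far bigger than cache)
xs = torch.randn(k, d)      # hidden states for k positions to check

t0 = time.perf_counter()
for x in xs:                # "one token at a time": W is streamed from memory k times
    _ = x @ W.T
t1 = time.perf_counter()
_ = xs @ W.T                # "check k tokens at once": one pass over W serves all k positions
t2 = time.perf_counter()

print(f"sequential: {t1 - t0:.3f}s  batched: {t2 - t1:.3f}s")
```

Typically the batched version is several times cheaper per token, and that headroom is what speculative decoding spends.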

  • That’s correct, and yes - it's not less total compute on the main model (actually slightly more, since checking failed draft tokens costs you compute), but it's faster because inference is memory-bandwidth bound. And like you, I also think of it as a “mini prefill” (but on top of the existing KV cache, of course); the code is very similar to prefill if you implement a simple toy version yourself.

    Most of the complexity in implementing a simple toy version comes from having to get the KV cache back into a good state for the next cycle (e.g. if only the first half of your draft tokens were correct).
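    For example, with a hypothetical per-layer cache of [batch, heads, seq, head_dim] tensors, the cleanup is just a truncation back to the accepted length (the names and layout here are made up, not any particular library's API):

    ```python
    import torch

    def rollback_kv_cache(kv_cache, prefix_len, n_accepted):
        # kv_cache: list of (K, V) pairs, each of shape [batch, heads, seq, head_dim].
        # Keep only the committed prefix plus the draft tokens that survived verification.
        kept = prefix_len + n_accepted
        return [(k[:, :, :kept, :], v[:, :, :kept, :]) for k, v in kv_cache]
    ```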

  • > But I think the key is that in the standard autoregressive case we get memory bandwidth bound, so there are tons of idle compute resources.

    Right, this is the same way batching works. It's "free" until we exhaust available compute resources, at which point decode throughput becomes compute bound. (This is a good place to be, because scaling out compute is a lot easier than adding fast VRAM.) This is why MTP is mostly useful when you have one or few users, which means compute is abundant. When you're running large batches you're better off using that compute to grow your batch size.

    Of course, batch size is usually limited by things like bulky KV caches. So perhaps MTP has some residual use in that setting. But if you're sharing cached context in a subagent swarm, or running a model like the recent DeepSeek V4 with its tiny KV cache, you can go a lot further in processing a larger batch.

    • You can disaggregate though. So draft models can run on cheaper hardware with less RAM, saving time on the more expensive machines with more RAM.