Comment by moffkalast

1 day ago

Hmm, but isn't the checking only required because the draft model is not the same model and can only speculate about what the main one is thinking, hence the name? If the main model generates two tokens itself, how can it be wrong about its own predictions?

Because if you generate token n+1 with all 48 layers of Qwen3-Next and 80 billion params, and also generate token n+2 with the single MTP layer at ~2 billion params... that n+2 token can be much lower quality than the n+1 token, but it's mostly correct.

Let's say you have a model that generates the string "The 44th president of the United States is ___ ___". Your model will generate "Barack" as the n+1 token, and the MTP layer probably does a good enough job to generate "Obama" as the n+2 token (even though that MTP layer is a mere <2 billion parameters in size). Then you just check if "Obama" is correct via the same speculative decoding process, which is a lot faster than if you had to start over from layers 1-48 and generate "Obama" the regular way.
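In PyTorch-flavored pseudocode, the drafting half of that idea might look something like this (to be clear, `main_model` and `mtp_head` are hypothetical stand-ins for illustration, not Qwen3-Next's actual interfaces):

```python
# Hypothetical interfaces: `main_model` returns logits plus its final hidden
# states, and `mtp_head` is the small extra layer that drafts one more token
# from that hidden state. Neither is Qwen3-Next's real API.

def draft_two_tokens(main_model, mtp_head, input_ids):
    # Full 48-layer, 80B-param forward pass produces token n+1 ("Barack").
    logits, last_hidden = main_model(input_ids)
    tok_n1 = logits[:, -1].argmax(dim=-1, keepdim=True)

    # The cheap ~2B-param MTP head drafts token n+2 ("Obama") from the same
    # hidden state, skipping another full pass through the big model.
    draft_logits = mtp_head(last_hidden[:, -1], tok_n1)
    tok_n2_draft = draft_logits.argmax(dim=-1, keepdim=True)

    # tok_n2_draft still has to be verified against the big model on the
    # next step (the speculative-decoding check discussed below).
    return tok_n1, tok_n2_draft
```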

  • > Then you just check if "Obama" is correct via the same speculative decoding process, which is a lot faster than if you had to start over from layers 1-48 and generate "Obama" the regular way.

    That doesn't match my understanding of what speculative decoding does: AFAIK, with regular speculative decoding you ask a smaller LLM to infer the next few tokens (let's say 5 tokens), and then you can have the big model infer tokens 1, 2, 3, 4, 5 and 6 in parallel (each time starting from the sentence partially completed by the smaller model). Because LLMs are bandwidth bound, doing the same work six times in parallel isn't slower than doing it only once (what's costly is moving the massive model weights between VRAM and the GPU cores).

    If tokens 1, 2 and 3 match what the small model inferred, then you keep them. As soon as you have a mismatched token (say token 4), you have to discard the next inferred tokens (here tokens 5 and 6) because they were calculated under a wrong assumption for token 4 (rough sketch of the whole scheme at the end of this comment).

    So if the MTP layer merely replaces the smaller LLM in the previous scheme, with everything else working the same way, you wouldn't save anything when inferring "Obama" (you'd still need to "generate it the regular way", as there isn't really another way), but you could also start working on the word immediately after "Obama" by assuming "Obama" was already chosen. And if the model actually output "Hussein" instead of "Obama", then the token calculated to come after "Obama" would have to be discarded.

    Or maybe my understanding of speculative decoding is completely off…
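
    Concretely, the scheme I have in mind would look roughly like this in PyTorch-flavored pseudocode (greedy decoding for simplicity; `big_model` and `small_model` are hypothetical stand-ins that map token ids to logits):

```python
import torch

def speculative_step(big_model, small_model, input_ids, k=5):
    """One speculative-decoding step, greedy decoding for simplicity."""
    seq_len = input_ids.shape[1]

    # 1. The cheap draft model proposes k tokens autoregressively.
    draft = input_ids
    for _ in range(k):
        next_tok = small_model(draft)[:, -1:].argmax(dim=-1)
        draft = torch.cat([draft, next_tok], dim=-1)

    # 2. One batched pass of the big model scores every drafted position at
    #    once. Decoding is bandwidth-bound, so scoring k+1 positions costs
    #    about the same as scoring one: the weights stream through only once.
    big_choice = big_model(draft)[:, seq_len - 1 :].argmax(dim=-1)  # k+1 predictions
    proposed = draft[:, seq_len:]                                   # k drafted tokens

    # 3. Keep the longest prefix on which the big model agrees with the
    #    draft. Everything after the first mismatch was computed under a
    #    wrong assumption and is discarded; the big model's own token at
    #    that position (or one bonus token, if every draft was accepted)
    #    comes along for free.
    agree = (big_choice[:, :k] == proposed).long().cumprod(dim=-1)
    n_accept = int(agree.sum())
    bonus = big_choice[:, n_accept : n_accept + 1]
    return torch.cat([input_ids, proposed[:, :n_accept], bonus], dim=-1)
```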

    • Sounds right. The policy for rejection can depend on what you want - you might accept the top-K highest-probability tokens or the top-P probability mass. Or you can do something like importance sampling and probabilistically reject based on the ratio of likelihoods.
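
      A sketch of that likelihood-ratio rule, in the same pseudo-PyTorch style (this is the standard speculative-sampling acceptance test; the variable names are just illustrative):

```python
import torch

def accept_draft_token(p_big, p_small, token):
    """Decide whether to keep one drafted token.

    p_big and p_small are the big and draft models' probability
    distributions over the vocabulary at this position; `token` is the
    draft's sampled token id.
    """
    # Accept with probability min(1, p_big[token] / p_small[token]).
    if torch.rand(()) < p_big[token] / p_small[token]:
        return token, True

    # On rejection, resample from the residual distribution
    # max(p_big - p_small, 0), renormalized, so the overall output
    # distribution still matches the big model exactly.
    residual = torch.clamp(p_big - p_small, min=0)
    residual = residual / residual.sum()
    return torch.multinomial(residual, 1).item(), False
```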

If you ask me to guess an answer, I'll _usually_ produce the same answer as if I had time to think about it deeply, but not always...

I believe it's something along these lines. The MTP head runs simultaneously and generates a probability list based on what it thinks the results will be, learned during training.

If n+1 = "Barack" then n+2 = "Obama" (confidence: 0.90)
If n+1 = "The" then n+2 = "quick" (confidence: 0.45)
If n+1 = "President" then n+2 = "Biden" (confidence: 0.75)

A threshold is set (say, 90%) so that if the n+2 prediction's confidence is above it (as in the first example), the prediction is used without having to determine that token with the main model. It's confident "enough".
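
If it does work that way, a toy version might look like the sketch below. To be clear, this fixed-threshold rule is just my guess at the mechanism rather than anything documented for Qwen3-Next (textbook speculative decoding always verifies the draft with the main model); `main_model_decode` is a hypothetical fallback that runs the full model for that position.

```python
def choose_next_token(mtp_prediction, mtp_confidence, main_model_decode,
                      threshold=0.90):
    """Toy version of the thresholding idea described above (a guess,
    not Qwen3-Next's documented behavior)."""
    if mtp_confidence >= threshold:
        # Confident "enough": take the cheap MTP guess as-is.
        return mtp_prediction
    # Otherwise fall back to the full 48-layer model for this position.
    return main_model_decode()

# With the examples above: ("Obama", 0.90) is taken directly, while
# ("quick", 0.45) and ("Biden", 0.75) fall back to the main model.
```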

  • Well yeah; also, inference benefits massively from batching, so you use the guesses to prefill the context needed to infer the next speculated tokens, and if the guesses were wrong, you just have to re-compute the speculated ones that depended on the guessed context.

    You compute the next token and guess the one after it; then you treat the guess as real: in one batch you verify the guessed token and compute the token that follows it, with that following token being speculative, conditioned on the guess being correct (rough sketch below).
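
    One such pipelined step might look roughly like this in pseudo-PyTorch (`big_model` maps token ids to logits and `guess_head` cheaply drafts one token for a given context; both are hypothetical stand-ins):

```python
import torch

def pipelined_step(big_model, guess_head, input_ids, guess):
    """Verify the previous guess while already computing the token after it."""
    # Run the big model on the context *plus* the unverified guess. One
    # batched pass yields (a) the real token at the guess's position and
    # (b) the token after it, valid only if the guess was right.
    extended = torch.cat([input_ids, guess], dim=-1)
    logits = big_model(extended)
    real_tok = logits[:, -2:-1].argmax(dim=-1)  # what the guess should have been
    next_tok = logits[:, -1:].argmax(dim=-1)    # speculated on the guess

    if torch.equal(real_tok, guess):
        # Guess confirmed: keep it and the token computed on top of it.
        new_ids = torch.cat([input_ids, guess, next_tok], dim=-1)
    else:
        # Guess was wrong: discard the speculated continuation and keep
        # only the big model's own token for that position.
        new_ids = torch.cat([input_ids, real_tok], dim=-1)

    # Draft the next guess for whatever context we ended up with.
    return new_ids, guess_head(new_ids)
```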

the 2nd token is generated without knowing what token was chosen for the 1st token