Comment by EMM_386
20 hours ago
I believe it's something along these lines. The MTP head runs simultaneously and generates a list of probabilities for what it expects the upcoming tokens to be, learned during training.
If n+1 = "Barack", then n+2 = "Obama" (confidence: 0.90)
If n+1 = "The", then n+2 = "quick" (confidence: 0.45)
If n+1 = "President", then n+2 = "Biden" (confidence: 0.75)
A threshold is set (say, at 90%) so that if the n+2 prediction is at or above it (as in the first example), the token is used without running the main model for that position. It's confident "enough".
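Roughly, in code, the gate could look like the sketch below. `main_head` and `mtp_head` are made-up stand-ins for the two heads (not any real API), with the toy probabilities from the example above baked in:

```python
# Minimal sketch of threshold-gated acceptance of an MTP draft token.
# Both heads below are hypothetical stand-ins: the main head scores
# token n+1, the MTP head scores n+2 conditioned on that draft.

ACCEPT_THRESHOLD = 0.90  # accept the n+2 draft only at/above this confidence

def main_head(context):
    # stand-in: returns (token, probability) for position n+1
    return "Barack", 0.97

def mtp_head(context, next_token):
    # stand-in: returns (token, probability) for position n+2,
    # conditioned on the draft token for n+1
    return {"Barack": ("Obama", 0.90),
            "The": ("quick", 0.45),
            "President": ("Biden", 0.75)}[next_token]

def step(context):
    tok1, _ = main_head(context)
    tok2, conf = mtp_head(context, tok1)
    if conf >= ACCEPT_THRESHOLD:
        # confident "enough": emit both tokens without a second main-model pass
        return [tok1, tok2]
    # otherwise fall back to the main model for n+2 on the next step
    return [tok1]

print(step(["some", "prompt"]))  # -> ['Barack', 'Obama']
```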
Well, yeah; also, inference benefits massively from batching, so you use the guesses to pre-fill the context needed to infer the next speculated tokens, and if a guess was wrong, you only have to recompute the speculated tokens that depended on the guessed context.
You compute the next token and guess the one after; then you tentatively take the guess as real and, in the same batch, run inference to verify the guessed token while also computing the token after it, which is speculated on the guess being correct.
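As a toy sketch of that verify-while-speculating loop (stand-in functions and random tokens, not a real model or speculator), assuming greedy acceptance:

```python
# One "batched" model call scores the accepted context and the guessed
# continuation in a single pass; a wrong guess only costs discarding the
# token that was speculated on top of it.

import random
random.seed(0)

VOCAB = ["the", "quick", "brown", "fox"]

def model_batch(contexts):
    # stand-in for one batched forward pass: one next token per context
    return [random.choice(VOCAB) for _ in contexts]

def draft_guess(context):
    # stand-in for the cheap speculator (e.g. an MTP head)
    return random.choice(VOCAB)

def generate(context, steps=5):
    guess = draft_guess(context)
    for _ in range(steps):
        # batched call: verify `guess` by scoring n+1 from the real context,
        # and simultaneously score n+2 assuming `guess` was correct
        real_next, next_after_guess = model_batch(
            [context, context + [guess]]
        )
        context.append(real_next)
        if real_next == guess:
            # guess confirmed: n+2 came out of the same batch, keep it
            context.append(next_after_guess)
        # guess wrong: only the speculated n+2 is discarded and redone
        guess = draft_guess(context)
    return context

print(generate(["a"]))
```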