Comment by porridgeraisin

19 hours ago

Background:

LLMs take your input, upscale it into a very high-dimensional space, and then downscale it back to 1D at the end. This 1D list is interpreted as a list of probabilities -- one for each word in your vocabulary. i.e. f(x) = downscale(upscale(x)). Each of downscale() and upscale() is parameterized (billions of params). I see you have a gamedev background, so as an example: bezier curves are parameterized functions where the bezier handles are the parameters. During training, these parameters are continuously adjusted so that the output of the overall function gets closer to the expected result. Neural networks are just really flexible functions whose parameters you can choose to get any expected result, provided you have enough of them (similar to bezier curves in this regard).
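To make that concrete, here's a toy PyTorch-ish sketch of f(x) = downscale(upscale(x)). The names and sizes are made up for illustration, and a real LLM has a big stack of transformer layers inside upscale() that I'm skipping entirely:

    import torch
    import torch.nn as nn

    vocab_size, hidden_dim = 50_000, 4_096   # made-up sizes

    # "upscale": token ids -> high-dimensional vectors (a real model also runs
    # a big transformer here, omitted in this sketch)
    upscale = nn.Embedding(vocab_size, hidden_dim)

    # "downscale": high-dimensional vector -> one score per vocabulary word
    downscale = nn.Linear(hidden_dim, vocab_size)

    def f(token_ids):
        h = upscale(token_ids)            # (seq_len, hidden_dim)
        logits = downscale(h)             # (seq_len, vocab_size)
        return torch.softmax(logits, -1)  # probabilities, one per vocab word

    # f(torch.tensor([11, 42]))  # two made-up token ids standing in for "I use"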

---

When training, you make an LLM learn that

I use arch = downscale(upscale(I use))

If you want to predict the next word after that, you then run the same function again on the extended sequence:

I use arch btw = downscale(upscale(I use arch))
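In code, that's just calling the toy f() from above in a loop, taking the most likely word each time (greedy, for illustration):

    def generate(token_ids, n_new):
        # Greedy next-word generation with the toy f() above: one word per pass.
        for _ in range(n_new):
            probs = f(token_ids)              # (seq_len, vocab_size)
            next_id = probs[-1].argmax()      # most likely next word
            token_ids = torch.cat([token_ids, next_id.view(1)])
        return token_ids

    # "I use" -> "I use arch" -> "I use arch btw", one new word per forward pass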

Now, multi-token prediction means having two downscale functions, one for each of the next two words, and learning it that way. Basically, you have a second downscale2() that learns how to predict the next-to-next word.

i.e in parallel:

I use arch = downscale1(upscale(I use))

I use ____ btw = downscale2(upscale(I use))

However, this way you need twice as many parameters as a single downscale() has. And if you want to predict more tokens ahead, you need even more parameters.
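A sketch of that naive version, reusing the toy pieces from above (not how any real model names things):

    # Naive multi-token prediction: one full-size downscale head per future word.
    downscale1 = nn.Linear(hidden_dim, vocab_size)   # predicts the next word
    downscale2 = nn.Linear(hidden_dim, vocab_size)   # predicts the next-to-next word

    def predict_two(token_ids):
        h = upscale(token_ids)[-1]                        # one shared upscale pass
        p_next      = torch.softmax(downscale1(h), -1)    # "arch"
        p_next_next = torch.softmax(downscale2(h), -1)    # "btw"
        return p_next, p_next_next

    # Each head costs hidden_dim * vocab_size parameters, so predicting k words
    # ahead means roughly k full-size heads.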

What Qwen has done is, instead of downscale1 and downscale2 being completely separately parameterized functions, set downscale1(.) = lightweight1(downscale_common(.)) and downscale2(.) = lightweight2(downscale_common(.)). This is essentially betting that a lot of the logic is common, and that the difference between predicting the next and the next-to-next token can be captured in one lightweight function each. Lightweight here means fewer parameters. The bet paid off.

So overall, you save params.

Concretely,

Before: downscale1.params + downscale2.params

After: downscale_common.params + lightweight1.params + lightweight2.params

Edit: it's actually downscale_common(lightweight()) and not the other way around as I have written above. Doesn't change the crux of the answer, but I'm including this for clarity.
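With that corrected order, a toy sketch of the shared version looks like this (again made-up sizes, not Qwen's actual code; the point is only where the parameters go):

    # One shared expensive head, plus a small adapter per future word,
    # composed as downscale_common(lightweightN(.)).
    downscale_common = nn.Linear(hidden_dim, vocab_size)   # the expensive part, shared
    lightweight1 = nn.Linear(hidden_dim, hidden_dim)        # small: next word
    lightweight2 = nn.Linear(hidden_dim, hidden_dim)        # small: next-to-next word

    def predict_two_shared(token_ids):
        h = upscale(token_ids)[-1]
        p_next      = torch.softmax(downscale_common(lightweight1(h)), -1)
        p_next_next = torch.softmax(downscale_common(lightweight2(h)), -1)
        return p_next, p_next_next

    def n_params(*modules):
        return sum(p.numel() for m in modules for p in m.parameters())

    n_params(downscale1, downscale2)                        # ~410M with the toy sizes
    n_params(downscale_common, lightweight1, lightweight2)  # ~238M, since vocab_size >> hidden_dim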

so after your edit it would be (just to clarify):

    I use ____ ___ = downscale_common(lightweight1(.)) + downscale_common(lightweight2(.)) ?

And does it generate 2 at a time and keep going that way, or is there some overlap?

  • You generate blocks of 2 at a time, yes. In general, blocks of k. As you can imagine, larger k performs worse. LLM(I like cats) is very likely to continue with "because they", but beyond that, there are too many possibilities. LLM(I like cats because they are) = small and cute and they meow, while LLM(I like cats because they eat) = all the rats in my garden.

    If you try to predict the whole thing at once you might end up with

    I like cats because they are all the rats and they garden

    > Overlap

    Check out an inference method called self-speculative decoding, which (somewhat) solves the above problem of k-token prediction and does overlap the same ___ across multiple computations. Rough sketch of the idea below.
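    Here's a greedy, very simplified sketch of that draft-then-verify loop with the toy two-head model from above (not any particular paper's exact algorithm):

        def self_speculative_step(token_ids):
            # Draft: one pass proposes the next two words at once.
            p1, p2 = predict_two_shared(token_ids)
            d1, d2 = p1.argmax().view(1), p2.argmax().view(1)

            # Verify: reuse the drafted d1 (the "___") as real input and check
            # whether the ordinary next-word head agrees with the drafted d2.
            extended = torch.cat([token_ids, d1])
            q1, _ = predict_two_shared(extended)   # this pass also drafts the next block
            v2 = q1.argmax().view(1)

            # Keep both drafted words if they survive verification,
            # otherwise fall back to the verified word.
            return torch.cat([extended, d2 if torch.equal(v2, d2) else v2])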

Dude, this was like that woosh of cool air on your brain when an axe splits your head in half. That really brought a lot of stuff into focus.