Comment by sailingparrot
3 months ago
But you are missing causal attention in your analysis. The output is not the only thing that is preserved: there is also the KV cache.
At token 1, the model goes through, say, 28 transformer blocks; for each of those blocks we save 2 projections of the hidden state (the key and the value) in a cache.
At token 2, on top of seeing the new token, the model is now also able, in each of those 28 blocks, to look at the previously saved hidden states from token 1.
At token 3, it can see the states from tokens 1 and 2, and so on.
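To make the mechanics concrete, here is a minimal sketch of that loop: toy single-head attention in numpy, made-up sizes, 2 blocks standing in for the 28, and no MLPs, norms, or multi-head details. At each step, the new token's two projections are appended to each block's cache, and its query attends over everything cached so far.

    import numpy as np

    d = 16                     # toy hidden size (made up)
    n_layers = 2               # stand-in for the 28 blocks
    rng = np.random.default_rng(0)

    # One (W_q, W_k, W_v) triple per block; a real block also has MLPs, norms, many heads...
    weights = [tuple(0.1 * rng.standard_normal((d, d)) for _ in range(3))
               for _ in range(n_layers)]

    # kv_cache[layer] = (cached K rows, cached V rows), one row per past token
    kv_cache = [([], []) for _ in range(n_layers)]

    def softmax(x):
        e = np.exp(x - x.max())
        return e / e.sum()

    def decode_step(h):
        """Process one new token's embedding h, reading and extending the cache."""
        for (W_q, W_k, W_v), (K_rows, V_rows) in zip(weights, kv_cache):
            # The "2 projections of the hidden state" saved per block:
            K_rows.append(h @ W_k)
            V_rows.append(h @ W_v)
            K, V = np.stack(K_rows), np.stack(V_rows)
            # Attend over every cached position, i.e. over all tokens seen so far.
            attn = softmax((h @ W_q) @ K.T / np.sqrt(d)) @ V
            h = h + attn                       # residual connection, attention part only
        return h

    # Token 1 only sees itself; token 2 also sees token 1's cached K/V; and so on.
    for token_embedding in rng.standard_normal((3, d)):
        out = decode_step(token_embedding)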
However, I still agree that this is not a perfect information-passing mechanism because of how these models are trained (something like the Feedback Transformer would be better), but information is still very much being passed from earlier tokens to later ones.
Like another commenter said, isn't the KV cache a performance optimization to avoid redoing work that was already done? Or does it fundamentally alter the output of the LLM, and so preserve state that is not present in the output of the LLM?
Yes, it's "just" an optimization technique, in the sense that you could drop it and still end up with the same result (given the same input sequence), just much slower.
Conceptually, what matters is not the KV cache but the attention. But IMHO, thinking about how the model behaves during inference, outputting one token at a time and attending over the KV cache, is much easier to grok than thinking about training/prefilling, where the KV cache is absent and everything happens in parallel (although the two are mathematically equivalent).
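A quick way to see that equivalence for a single attention layer: compute it once in parallel over the whole sequence with a causal mask (training/prefill style) and once token by token with a cache (inference style), and check that the numbers match. Toy numpy again, made-up sizes.

    import numpy as np

    d, T = 8, 5
    rng = np.random.default_rng(1)
    X = rng.standard_normal((T, d))                  # hidden states of T tokens
    W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))

    def softmax(x):
        e = np.exp(x - x.max(axis=-1, keepdims=True))
        return e / e.sum(axis=-1, keepdims=True)

    # (1) Training/prefill style: all tokens in parallel, with a causal mask.
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    mask = np.triu(np.full((T, T), -np.inf), k=1)    # token t only sees tokens <= t
    parallel_out = softmax(Q @ K.T / np.sqrt(d) + mask) @ V

    # (2) Inference style: one token at a time, appending K and V to a cache.
    K_cache, V_cache, step_out = [], [], []
    for x in X:
        K_cache.append(x @ W_k)
        V_cache.append(x @ W_v)
        Kc, Vc = np.stack(K_cache), np.stack(V_cache)
        step_out.append(softmax((x @ W_q) @ Kc.T / np.sqrt(d)) @ Vc)

    # Same numbers either way: the cache only avoids recomputing K and V.
    assert np.allclose(parallel_out, np.stack(step_out))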
The important part of my point is that when the model is processing token N, it can check its past internal state from tokens 1, ..., N-1, and thus "see" its previous plan and reasoning and iterate on it, rather than rebuilding everything from scratch in each token's hidden state (with a caveat, explained at the end):
token_1 ──▶ h₁ᴸ ────────┐
token_2 ──▶ h₂ᴸ ──attn──┤   (sees h₁ᴸ: refines the reasoning)
token_3 ──▶ h₃ᴸ ──attn──┘   (sees h₁ᴸ and h₂ᴸ: refines it further)
And the KV cache makes this persistent across time, so the entire system (LLM + cache) is effectively able to save its state and iterate on it at each token, instead of starting from scratch every time.
But ultimately it's a Markov chain, so again, mathematically, yes, you could just redo the full computation every time and end up in the same place.
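In code, that state-machine view looks something like the sketch below: a toy 1-layer greedy "model" with made-up sizes, where the only thing carried from one step to the next is the (cache, last token) pair.

    import numpy as np

    d, vocab = 8, 10
    rng = np.random.default_rng(2)
    E = rng.standard_normal((vocab, d))              # token embedding table
    W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))
    W_out = rng.standard_normal((d, vocab))

    def softmax(x):
        e = np.exp(x - x.max())
        return e / e.sum()

    def step(kv_cache, token):
        """One decode step: a pure function of the cache and the last token."""
        h = E[token]
        K_rows, V_rows = kv_cache
        K_rows = K_rows + [h @ W_k]                  # extend the cache
        V_rows = V_rows + [h @ W_v]
        attn = softmax((h @ W_q) @ np.stack(K_rows).T / np.sqrt(d)) @ np.stack(V_rows)
        next_token = int(np.argmax((h + attn) @ W_out))   # greedy "sampling"
        return next_token, (K_rows, V_rows)

    cache, token = ([], []), 3                       # arbitrary start token
    for _ in range(5):
        token, cache = step(cache, token)
    print(token)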
Caveat: because token N at layer L can attend to all tokens < N, but only at layer L, it can only see what the reasoning looked like at that depth, not what it looked like after a full pass. So it's not a perfect information-passing mechanism, and the flow is more pyramidal than a straight line; hence why I referenced Feedback Transformers in another message. But the principle still applies: information is passed across time steps.
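To see why it is pyramidal, here is a tiny dependency walk over (token, layer) pairs, under the standard assumption that layer l of token n reads the layer l-1 states of tokens 1..n (toy indices only, ignoring residuals and MLPs).

    def visible(n, l):
        """(token, layer) states that token n's state at layer l can depend on."""
        if l == 0:
            return {(n, 0)}                  # layer 0: just token n's own embedding
        deps = {(n, l)}
        for t in range(1, n + 1):            # causal attention reads layer l-1 of tokens 1..n
            deps |= visible(t, l - 1)
        return deps

    # Token 4 at layer 2 reaches earlier tokens only at depths below 2...
    print(sorted(visible(4, 2)))
    # ...so token 1's deeper states (e.g. its layer-3 state) are never visible to it:
    print((1, 3) in visible(4, 2))           # False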