Comment by miven
9 days ago
I'm not sure I understand what you're trying to say here. Information between tokens is propagated through self-attention, and there's an attention block inside each transformer block within the model. That's a whole lot of internal state, stored in (mostly) inscrutable key and value vectors with hundreds of dimensions per attention head, a few dozen heads per attention block, and a few dozen blocks per model.
Yes, but all that internal state only survives until the end of the computation chain that predicts the next token - it doesn't survive across the entire sequence as it would in a recurrent network.
There is literally no difference between a model predicting the tokens "<thought> I think the second choice looks best </thought>" and a user putting those tokens into the prompt: The input for the next round would be exactly the same.
So the tokens kind of act like a bottleneck (or more precisely the sampling of exactly one next token at the end of each prediction round does). During prediction of one token, the model can go crazy with hidden state, but not across several tokens. That forces the model to do "long form" reasoning through the tokens and not through hidden state.
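To make that concrete, here's a minimal sketch of such a prediction loop, assuming a HuggingFace-style causal LM (the model name, prompt and greedy sampling are my own placeholders, not something from this thread):

```python
# Minimal sketch of the "token bottleneck": the only thing carried from one
# prediction round to the next is the growing token sequence itself; every
# hidden activation is recomputed from those tokens each round.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder; any causal LM behaves the same way here
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

tokens = tokenizer("The second choice looks", return_tensors="pt").input_ids
for _ in range(10):
    with torch.no_grad():
        logits = model(tokens, use_cache=False).logits  # full recompute, no cache
    next_token = logits[0, -1].argmax()  # sampling collapses everything to one token
    tokens = torch.cat([tokens, next_token.view(1, 1)], dim=1)
print(tokenizer.decode(tokens[0]))
```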
The key and value vectors are cached; that's kind of the whole point of autoregressive transformer models. The "state" not only survives within the KV cache but, in some sense, grows continuously with each token added, and it is reused for each subsequent token.
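A tiny sketch of what "grows continuously" means here, again assuming a HuggingFace-style causal LM (model name and prompt are placeholders): the cached K/V tensors gain one position per token seen and are passed back in on the next forward pass.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")          # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

ids = tokenizer("I think the second choice", return_tensors="pt").input_ids
with torch.no_grad():
    out = model(ids, use_cache=True)
# One (key, value) pair per layer; shape: (batch, heads, seq_len, head_dim),
# where seq_len equals the number of tokens processed so far.
print(out.past_key_values[0][0].shape)
```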
Hmm, maybe I misunderstood that part, but so far I thought the KV cache was really just that - a cache. Because all the previous tokens of the sequence stay the same, it makes no sense to compute the same K and V vectors again in each round.
But that doesn't change the fact that the only inputs to the Q, K and V calculations are the tokens (or, in later layers, information derived from the tokens), and each vector in the cache maps directly to an input token.
So I think you could disable the cache and recompute everything in each round and you'd still get the same result, just a lot slower.
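For what it's worth, that's easy to check with a small sketch (again assuming a HuggingFace-style causal LM; model name and prompt are placeholders): feed the same sequence once token-by-token with the KV cache and once in full without it, and compare the next-token logits.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder; any causal LM should show the same behaviour
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

ids = tokenizer("I think the second choice looks best", return_tensors="pt").input_ids

with torch.no_grad():
    # Incremental: feed one token at a time, reusing the cached K/V of earlier tokens.
    past = None
    for i in range(ids.shape[1]):
        out = model(ids[:, i:i + 1], past_key_values=past, use_cache=True)
        past = out.past_key_values
    cached_logits = out.logits[0, -1]

    # Full recompute: feed the whole sequence with the cache disabled.
    full_logits = model(ids, use_cache=False).logits[0, -1]

# Identical up to floating-point noise: the cache is an optimization, not extra state.
print(torch.allclose(cached_logits, full_logits, atol=1e-4))
```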