Comment by krackers
2 months ago
> extra attention head were added that queried the KV data from lower layers
Isn't this sort of similar to latent looping? E.g. [1]. But actually as [2] argues, even that wasn't a good experiment because it used the very last hidden state, which is too close to the logits and loses most of the rich embedding structure. Perhaps you don't even need access to the state of anything except the penultimate hidden layer, since based on my vague reading of [3] the residual stream doesn't "lose information" as it passes deeper down the attention layers, so each block maybe manipulates a different subspace of the residual stream.
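To make that concrete, here's a toy sketch of a pre-norm residual stream (my own illustration, attention omitted, not taken from any of the linked posts): every block just adds its output back into the stream, so in principle each block can write into its own subspace and any intermediate state still carries everything written so far.

```python
# Toy pre-norm residual stream (illustration only; attention omitted).
# Each block *adds* its output to the stream, so the penultimate state
# still "contains" everything the earlier blocks wrote.
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(nn.Linear(d_model, 4 * d_model),
                                 nn.GELU(),
                                 nn.Linear(4 * d_model, d_model))

    def forward(self, x):
        return x + self.mlp(self.norm(x))  # write into the residual stream

d_model, n_layers = 64, 8
blocks = nn.ModuleList([ToyBlock(d_model) for _ in range(n_layers)])

x = torch.randn(1, 16, d_model)   # (batch, seq, d_model) toy input
states = [x]
for blk in blocks:
    x = blk(x)
    states.append(x)

last, penultimate = states[-1], states[-2]
# The difference between consecutive states is exactly what that block wrote.
print((last - penultimate).norm())
```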
> Perhaps you don't even need access to the state of anything except the penultimate hidden layer, since based on my vague reading of [3] the residual stream doesn't "lose information" as it passes deeper down the attention layers, so each block maybe manipulates a different subspace of the residual stream.
I imagine that conventional transformers kind of force this. If you train a transformer such that it needs to learn the ability to do tasks like “Repeat the following words: apple banana cat” then the model is sort of forced to internally propagate the input far enough along to be able to perform the task. But maybe if you pre-trained from scratch with an architecture where later layers get direct access to earlier layers and/or the raw input, then the model wouldn’t need to propagate information.
Or maybe it would all fall apart and something would go wrong with the gradients.
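For concreteness, this is roughly the kind of wiring I have in mind (purely hypothetical; the `InputAwareBlock` name and the concat-then-project scheme are made up, not from any paper):

```python
# Hypothetical variant: give each block direct access to the raw input
# embedding, so the stream no longer has to carry the input forward "by hand".
# Whether gradients would still behave is exactly the open question above.
import torch
import torch.nn as nn

class InputAwareBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.norm = nn.LayerNorm(2 * d_model)
        self.mlp = nn.Sequential(nn.Linear(2 * d_model, 4 * d_model),
                                 nn.GELU(),
                                 nn.Linear(4 * d_model, d_model))

    def forward(self, x, x0):
        # Concatenate the current stream with the original embedding x0,
        # so later layers can read the input directly instead of relying
        # on earlier layers to have propagated it.
        return x + self.mlp(self.norm(torch.cat([x, x0], dim=-1)))

d_model = 64
blocks = nn.ModuleList([InputAwareBlock(d_model) for _ in range(8)])
x0 = torch.randn(1, 16, d_model)  # raw input embedding
x = x0
for blk in blocks:
    x = blk(x, x0)
```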
Apparently a new paper from DS shows this is not the case, or rather that the information isn't preserved with as much fidelity as you'd expect. Intuitively, the residual stream doesn't have enough dimensions for each layer to carve out its own subspace [1].
>And this makes it hard for layers to explore new features that are beneficial for just a few layers because you need to revert or overwrite those features as they will not be useful for later layers.
This is because with a residual stream architecture, removing a feature can't be done by simply zeroing out a weight; a later layer has to explicitly write the additive inverse back into the stream to cancel it.
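Toy illustration of that point (my own numbers, nothing from the linked paper):

```python
# Once a block has added a feature direction to the residual stream, the only
# way a later block can remove it is to write the additive inverse back in.
# Zeroing a weight only stops *new* writes; it doesn't undo old ones.
import torch

d_model = 8
stream = torch.zeros(d_model)
feature = torch.zeros(d_model); feature[3] = 1.0  # some scratch feature

stream = stream + 2.0 * feature       # an early block writes the feature
# ... intermediate blocks use it ...
stream = stream + (-2.0) * feature    # a later block must explicitly cancel it
print(stream.norm())                  # ~0: removed only by subtraction
```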
>This leads each layer to contribute "generally useful" features and one immediate pattern is continuously refining features. I think this is the reason why later layers in LLMs tend to behave like that.
Greatly increasing the number of "channels" in the residual stream helps, however (although you have to play some tricks to preserve the useful "identity mapping" behavior) [2, 3].
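I don't know the exact construction in [2, 3], but the rough shape I imagine is something like keeping several parallel copies of the stream with a learned mixing that starts out as the identity (so training initially behaves like a plain residual connection); treat this as a guess, not their method:

```python
# Rough guess at a "more residual channels" scheme: n parallel streams,
# mixed by a learned matrix initialized to the identity, with each block
# reading one channel and writing its update back via learned weights.
import torch
import torch.nn as nn

class MultiStreamBlock(nn.Module):
    def __init__(self, d_model, n_streams):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(nn.Linear(d_model, 4 * d_model),
                                 nn.GELU(),
                                 nn.Linear(4 * d_model, d_model))
        # Mixing over the stream dimension, initialized to identity so the
        # block initially behaves like an ordinary single-stream residual.
        self.mix = nn.Parameter(torch.eye(n_streams))
        self.write = nn.Parameter(torch.zeros(n_streams))
        self.write.data[0] = 1.0  # start by writing only to channel 0

    def forward(self, streams):          # streams: (n_streams, batch, seq, d)
        mixed = torch.einsum('ij,jbtd->ibtd', self.mix, streams)
        update = self.mlp(self.norm(mixed[0]))            # read channel 0
        return mixed + self.write[:, None, None, None] * update

n_streams, d_model = 4, 64
streams = torch.randn(n_streams, 1, 16, d_model)
block = MultiStreamBlock(d_model, n_streams)
streams = block(streams)
```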
[1] https://x.com/rosinality/status/2006902561727721670
[2] https://x.com/norxornor/status/2006649194690257285#m
[3] https://x.com/byebyescaling/status/2007147288809087281#