Comment by psb217

1 year ago

Well, that's what the Transformer already does... One problem with the scaling you're describing is that there would be a massive amount of redundant information stored in hidden activations while training the RNN. The hidden state at each time step t in the sequence would need to contain all info that (i) could be useful for predicting the token at time t and (ii) could be useful for predicting tokens at times > t. (i) is obvious, and (ii) holds because all information about the past is transferred to future predictions through the current hidden state. In principle, Transformers can avoid storing redundant info in multiple hidden states, at the cost of having to maintain and access (via attention) a larger hidden state at test/eval time.
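
To make the tradeoff concrete, here is a toy NumPy sketch (made-up shapes and weights, not any particular model). The recurrent path squeezes everything into one fixed-size state that is carried forward, and all T states are typically kept around for backprop during training; the attention path instead keeps a cache that grows with the sequence and re-reads any past token directly at eval time.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d = 6, 4                       # toy sequence length and state size
x = rng.normal(size=(T, d))       # stand-ins for per-token inputs
W = rng.normal(size=(d, d)) * 0.1

# RNN-style: one fixed-size state must carry everything forward. Anything not
# preserved in h at step t is unavailable at every later step, and training
# typically keeps all T states around for backprop (T * d floats).
h = np.zeros(d)
rnn_states = []
for t in range(T):
    h = np.tanh(W @ h + x[t])
    rnn_states.append(h.copy())

# Transformer-style: nothing is squeezed into a single state; every past
# vector is cached and re-read via attention, so the eval-time "state"
# (the cache) grows linearly with the sequence length.
kv_cache = []
for t in range(T):
    kv_cache.append(x[t])
    scores = np.array([x[t] @ k for k in kv_cache])
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    attended = sum(wgt * v for wgt, v in zip(weights, kv_cache))

print(len(rnn_states), "stored RNN states vs a cache of", len(kv_cache), "vectors")
```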

> there would be a massive amount of redundant information stored in hidden activations

Is there a way to prove this? One potential caveat that comes to mind is that the act of lerping (linearly interpolating) between the old state and the new one could be used by the model to perform semantically meaningful transformations on the old state. It just doesn't seem obvious to me that the hidden state is necessarily a collection of "redundant information"; perhaps the information is culled/distilled the further along in the sequence you go? There will always be some redundancy, sure, but I don't think that redundancy necessarily means we have to use superlinear methods like attention.
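
For what it's worth, the kind of gated "lerp" update described above might look roughly like the toy sketch below (hypothetical weights `W_g` and `W_c`, not any specific architecture). The gate decides per dimension whether to preserve the old state or overwrite it with new content, which is one way the state could cull/distill information rather than just accumulate it.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
d = 8
W_g = rng.normal(size=(d, d)) * 0.1   # hypothetical gate weights
W_c = rng.normal(size=(d, d)) * 0.1   # hypothetical candidate weights

def lerp_update(h_prev, x_t):
    """Gated interpolation between the old state and a new candidate.

    Per dimension, g near 1 overwrites (culls) that slot of the old state and
    g near 0 preserves it, so the state is not a passive archive of the past:
    the model can keep, rewrite, or discard information at each step.
    """
    g = sigmoid(W_g @ x_t)            # per-dimension write strength
    candidate = np.tanh(W_c @ x_t)    # content proposed by the current token
    return (1.0 - g) * h_prev + g * candidate

h = np.zeros(d)
for x_t in rng.normal(size=(5, d)):   # a few toy steps
    h = lerp_update(h, x_t)
print(h)
```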

  • All information about the past that will be available for predicting future tokens must be stored in the present state. So if some bits of info about a past token at time t_p will be used for predicting some future token at time t_f, those bits must be passed, via the recurrence, through every state from t_p to t_f. Once information about past tokens is dropped from the hidden state it is gone forever, so it has to be stored and carried across many steps until it finally becomes useful.

    The memory cost of making the RNN state much bigger is high when done naively, but maybe someone will figure out a clever way to avoid storing full hidden states in memory during training, or big improvements in hardware could make memory use less of a bottleneck (see the recompute-instead-of-store sketch at the bottom of this thread).

    • > The memory cost of making the RNN state much bigger is high when done naively, but maybe someone will figure out a clever way to avoid storing full hidden states in memory during training, or big improvements in hardware could make memory use less of a bottleneck.

      Isn't this essentially what Mamba [1] does via its 'Hardware-aware Algorithm'?

      [1] https://arxiv.org/pdf/2312.00752
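
For concreteness, here is a much-simplified sketch of the recompute-instead-of-store idea the comments above are pointing at, in the style of gradient checkpointing for a toy scalar recurrence. It only illustrates the high-level trade of extra compute for less memory; Mamba's hardware-aware scan is a fused, GPU-specific kernel that avoids materializing the expanded states and recomputes the intermediates in the backward pass (see the paper for the real algorithm). All names here (`step`, `forward_checkpointed`, `backward_checkpointed`) are made up for the illustration.

```python
import numpy as np

# Toy nonlinear recurrence h_t = tanh(w * h_{t-1} + x_t), scalar for clarity.
# Plain BPTT stores every h_t, so memory grows with sequence length T. The
# sketch below keeps only sparse checkpoints and recomputes the in-between
# states during the backward pass. Assumes len(xs) is a multiple of `every`.

w = 0.7

def step(h, x):
    return np.tanh(w * h + x)

def forward_checkpointed(h0, xs, every=4):
    """Run the recurrence, keeping only every `every`-th state (plus h0)."""
    checkpoints = {0: h0}
    h = h0
    for t, x in enumerate(xs, start=1):
        h = step(h, x)
        if t % every == 0:
            checkpoints[t] = h
    return h, checkpoints

def backward_checkpointed(xs, checkpoints, dL_dhT, every=4):
    """BPTT that recomputes states segment by segment from the checkpoints."""
    T = len(xs)
    dh = dL_dhT                             # gradient flowing back into h_T
    dxs = [0.0] * T
    for seg_end in range(T, 0, -every):
        seg_start = max(seg_end - every, 0)
        # Recompute this segment's states from its checkpoint instead of
        # having stored them all during the forward pass.
        hs = [checkpoints[seg_start]]
        for t in range(seg_start, seg_end):
            hs.append(step(hs[-1], xs[t]))
        # Backprop through the segment using the recomputed states.
        for t in range(seg_end - 1, seg_start - 1, -1):
            pre = w * hs[t - seg_start] + xs[t]
            local = 1.0 - np.tanh(pre) ** 2  # d tanh(pre) / d pre
            dxs[t] = dh * local              # grad w.r.t. the input at this step
            dh = dh * local * w              # push gradient back to h_t
    return dxs

xs = list(np.linspace(-1.0, 1.0, 12))
hT, ckpts = forward_checkpointed(0.0, xs, every=4)
grads = backward_checkpointed(xs, ckpts, dL_dhT=1.0, every=4)
print(hT, grads[:3])
```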