Comment by trott

1 year ago

My feeling is that the answer is "no", in the sense that these RNNs wouldn't be able to universally replace Transformers in LLMs, even though they might be good enough in some cases and beat them in others.

Here's why.

A user of an LLM might give the model some long text and then say "Translate this into German please". A Transformer can look back at its whole history. But what is an RNN to do? While the length of its context is unlimited, the amount of information the model retains about it is bounded by whatever is in its hidden state at any given time.

Relevant: https://arxiv.org/abs/2402.01032
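To put a rough number on the hidden-state bound, here is a pigeonhole-style sketch. It makes two assumptions that are ours, not the commenter's. First, the state is a fixed number of values with a fixed bit width. Second, each token is an arbitrary choice from the vocabulary, so redundancy in the text cannot be exploited. The function name and the sizes in the example are hypothetical.

```python
import math


def max_recoverable_tokens(state_dim: int, bits_per_value: int, vocab_size: int) -> int:
    """Upper bound on how many arbitrary tokens a fixed-size state can hold verbatim.

    Assumption: the state is state_dim numbers of bits_per_value bits each, and
    every token is an unconstrained choice among vocab_size symbols.
    """
    state_bits = state_dim * bits_per_value   # total information capacity of the state
    bits_per_token = math.log2(vocab_size)    # information needed per arbitrary token
    return math.floor(state_bits / bits_per_token)


# Hypothetical sizes, chosen only for illustration.
print(max_recoverable_tokens(state_dim=4096, bits_per_value=16, vocab_size=50_000))  # → 4198
```

Real text is compressible, so this bounds verbatim recall of arbitrary strings rather than everything a model might usefully keep. Past that length, though, some of the input has to be dropped or summarized.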

> the amount of information the model retains about it is bounded by whatever is in its hidden state

This is no different than a transformer, which, after all, is bound by a finite state, just organized in a different manner.

  • > This is no different than a transformer, which, after all, is bound by a finite state, just organized in a different manner.

    It's not just a matter of organizing things differently. Suppose your network dimension and sequence length are both X.

    Then your memory usage (per layer) will be O(X^2), while your training update cost will be O(X^3). That's for both Transformers and RNNs.

    However, at the end of the sequence, a Transformer layer can look back and see O(X^2) numbers, while an RNN can only see O(X) numbers.
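Here is a small sketch of the arithmetic above. It uses a simplified cost model that we are assuming: leading-order multiply counts only, constants dropped, and one key vector and one value vector cached per position. With sequence length and width both set to X, training cost is cubic in X for both architectures. At the last position, however, the Transformer layer can read about X times as many numbers.

```python
def visible_numbers(seq_len: int, width: int) -> dict[str, int]:
    # Numbers one layer can read when producing the last token (simplified):
    # a Transformer layer attends over a cached key and value vector for every
    # position; an RNN layer only has its current hidden state.
    return {"transformer": 2 * seq_len * width, "rnn": width}


def training_cost(seq_len: int, width: int) -> dict[str, int]:
    # Leading-order multiply counts per layer, constants dropped (assumption):
    # attention ~ seq_len^2 * width, projections/MLP ~ seq_len * width^2,
    # dense RNN recurrence ~ seq_len * width^2.
    return {
        "transformer": seq_len**2 * width + seq_len * width**2,
        "rnn": seq_len * width**2,
    }


for x in (256, 1024, 4096):  # sequence length = width = X, as in the comment
    seen = visible_numbers(x, x)
    cost = training_cost(x, x)
    print(
        f"X={x}: Transformer sees {seen['transformer'] // seen['rnn']}x more numbers, "
        f"for {cost['transformer'] / cost['rnn']:.1f}x the training cost"
    )
```

Under these assumptions, the "sees" ratio grows as 2X (512, 2048, 8192) while the cost ratio stays fixed at 2.0. The gap in visible state therefore grows with X even though both costs are O(X^3).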

That problem has plagued RNNs since the 90s: there's an information precision problem (how many bits do you need older states to carry), a decay problem (the oldest information is the weakest) and a mixing problem (it tends to mix/sum representations).
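The decay and mixing problems already show up in the simplest possible recurrence. This sketch assumes a scalar linear RNN, h_t = a * h_{t-1} + x_t. It is a stand-in chosen for illustration, not any architecture from the thread. The final state is a weighted sum of the inputs, and the input at position t is scaled by a^(n-1-t). The function names are hypothetical.

```python
def contribution_weights(seq_len: int, decay: float) -> list[float]:
    """Weight of each input x_t in the final state of h_t = decay * h_{t-1} + x_t."""
    return [decay ** (seq_len - 1 - t) for t in range(seq_len)]


def run_linear_rnn(xs: list[float], decay: float) -> float:
    h = 0.0
    for x in xs:
        h = decay * h + x  # old state shrinks (decay), new input is summed in (mixing)
    return h


weights = contribution_weights(seq_len=100, decay=0.9)
print(f"newest token weight: {weights[-1]:.3f}, oldest token weight: {weights[0]:.2e}")

# Mixing: two different input sequences collapse to the same final state.
print(run_linear_rnn([1.0, 0.0], 0.5), run_linear_rnn([0.0, 0.5], 0.5))
```

With decay 0.9, the first of 100 tokens survives only at a weight of roughly 3e-5. A sum cannot tell you which inputs produced it, which is the mixing problem in miniature.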

The counterargument here is that you can just scale the size of the hidden state sufficiently such that it can hold compressed representations of whatever-length sequence you like. Ultimately, what I care about is whether RNNs could compete with transformers if FLOPs are held constant—something TFA doesn't really investigate.

  • Well, that's what a Transformer already does... One problem with the scaling you're describing is that a massive amount of redundant information would be stored in the hidden activations when training the RNN. The hidden state at each time step t in the sequence would need to contain all info that (i) could be useful for predicting the token at time t and (ii) could be useful for predicting tokens at times >t. (i) is obvious, and (ii) follows because all information about the past reaches future predictions only through the current hidden state. In principle, Transformers can avoid storing redundant info in multiple hidden states, at the cost of having to maintain and access (via attention) a larger hidden state at test/eval time.

    • > there would be a massive amount of redundant information stored in hidden activations

      Is there a way to prove this? One potential caveat that comes to mind for me is that perhaps the action of lerping between the old state and the new could be used by the model to perform semantically meaningful transformations on the old state. I guess in my mind it just doesn't seem obvious that the hidden state is necessarily a collection of "redundant information" — perhaps the information is culled/distilled the further along in the sequence you go? There will always be some redundancy, sure, but I don't think that such redundancy necessarily means we have to use superlinear methods like attention.

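For concreteness, here is what "lerping between the old state and the new" can look like in a GRU-style gated update, h_t = (1 - z_t) * h_{t-1} + z_t * candidate_t. This is a simplified sketch with hypothetical weights and a scalar input, and it is not necessarily the cell used in TFA. Each channel has its own input-dependent gate. One channel can therefore hold on to old information while another overwrites it, which is one way a state could be "culled/distilled" rather than piling up redundant copies.

```python
import math


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def lerp(a: float, b: float, z: float) -> float:
    """Linear interpolation: z=0 keeps a, z=1 replaces it with b."""
    return (1.0 - z) * a + z * b


def gated_step(h: list[float], x: float, w_gate: list[float], w_cand: list[float]) -> list[float]:
    """One GRU-style update of a vector state (simplified, hypothetical weights).

    Channel i gets its own gate z_i = sigmoid(w_gate[i] * x) and candidate
    tanh(w_cand[i] * x); the new value lerps between the old value and the candidate.
    """
    new_h = []
    for h_i, wg, wc in zip(h, w_gate, w_cand):
        z = sigmoid(wg * x)
        candidate = math.tanh(wc * x)
        new_h.append(lerp(h_i, candidate, z))
    return new_h


# Hypothetical weights: for positive inputs, channel 0's gate stays near 0, so it
# mostly keeps its old value; channel 1's gate is near 1, so it mostly overwrites.
state = [0.8, 0.8]
for x in [1.0, 2.0, 3.0]:
    state = gated_step(state, x, w_gate=[-4.0, 4.0], w_cand=[1.0, -1.0])
print(state)
```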

>> A user of an LLM might give the model some long text and then say "Translate this into German please". A Transformer can look back at its whole history.

Which isn't necessary if you say "translate the following to German" instead. Then all it needs is to remember the task at hand and a much smaller amount of recent input. Well, that and the ability to output in parallel with processing input.

  • It's necessary for arbitrary information processing if you can forget and have no way to "unforget".

    A model can decide to forget something that turns out to be important for some future prediction. A human can go back and re-read or re-listen. A Transformer is always re-reading, but an RNN can't, and is fucked.

    • If these networks are ever to be a path to something closer to general intelligence, they will need to be able to ask for context to be repeated anyway, or to have separate storage they can "choose" to replay themselves. So this problem likely has to be solved another way regardless, for both Transformers and RNNs.


    • That's just because we twisted its arm. One could, for example, feed the reversed input afterwards, i.e. abc|cba, where | is a special token. That would allow it to react to any part of the message.

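The mirrored-input trick is easy to construct. Here is a minimal sketch, assuming `|` is a token reserved for this purpose that never appears in the input. The function name is hypothetical.

```python
def mirror_with_separator(tokens: list[str], sep: str = "|") -> list[str]:
    """Build tokens + [sep] + reversed(tokens).

    sep stands in for a reserved special token (assumption); after reading it,
    the model sees the same input again from the end backwards.
    """
    if sep in tokens:
        raise ValueError("separator must not occur in the input")
    return tokens + [sep] + tokens[::-1]


print("".join(mirror_with_separator(list("abc"))))  # → abc|cba
```

Because the reversed copy comes after the separator, the model revisits every token after it has already seen the whole message, including an instruction at the very end. The price is doubling the sequence length.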

  • Also, a lightweight network could do a first pass to identify tasks, instructions, constraints, etc., and then a second pass could use the RNN.

    Consider the flood fill algorithm, or the union-find algorithm, which feels magical upon first exposure.

    https://en.wikipedia.org/wiki/Hoshen%E2%80%93Kopelman_algori...

    Having 2 passes can enable so much more than a single pass.

    Another alternative would be to have a first pass make notes in a separate buffer while parsing the input. The bandwidth needed for writing and reading those notes can be much, much lower than that required for fetching the billions of parameters.
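The linked Hoshen-Kopelman algorithm shows well why two passes help. A single left-to-right scan cannot know that two provisional labels belong to the same cluster until it later reaches a cell that touches both. Below is a minimal sketch of two-pass, union-find-based cluster labeling. It assumes a 0/1 grid with 4-connectivity, and the function and variable names are our own.

```python
def label_clusters(grid: list[list[int]]) -> list[list[int]]:
    """Two-pass cluster labeling of a 0/1 grid, Hoshen-Kopelman style.

    Pass 1 scans row by row, hands out provisional labels, and records in a
    union-find structure which labels turn out to touch. Pass 2 replaces each
    provisional label with its root. Assumes 4-connectivity (up/left neighbours).
    """
    parent: list[int] = [0]  # parent[label]; index 0 is the background

    def find(a: int) -> int:
        while parent[a] != a:
            parent[a] = parent[parent[a]]  # path halving
            a = parent[a]
        return a

    def union(a: int, b: int) -> None:
        ra, rb = find(a), find(b)
        if ra != rb:
            parent[max(ra, rb)] = min(ra, rb)  # keep the smaller label as root

    rows = len(grid)
    cols = len(grid[0]) if rows else 0
    labels = [[0] * cols for _ in range(rows)]

    # Pass 1: provisional labels plus recorded equivalences.
    for r in range(rows):
        for c in range(cols):
            if not grid[r][c]:
                continue
            up = labels[r - 1][c] if r > 0 else 0
            left = labels[r][c - 1] if c > 0 else 0
            if up and left:
                union(up, left)
                labels[r][c] = min(up, left)
            elif up or left:
                labels[r][c] = up or left
            else:
                parent.append(len(parent))  # new label starts as its own root
                labels[r][c] = len(parent) - 1

    # Pass 2: resolve every provisional label to its cluster root.
    for r in range(rows):
        for c in range(cols):
            if labels[r][c]:
                labels[r][c] = find(labels[r][c])
    return labels


u_shape = [
    [1, 0, 1],
    [1, 0, 1],
    [1, 1, 1],
]
for row in label_clusters(u_shape):
    print(row)
```

On the U-shaped example, the first pass gives the two arms different provisional labels. Only the bottom-right cell reveals that they are connected. The second pass then relabels both arms with the shared root, printing [1, 0, 1], [1, 0, 1], [1, 1, 1].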