
Comment by theanonymousone

1 year ago

I remember that, the way I understood it, Transformers solved two major "issues" of RNNs that enabled the later boom: vanishing gradients limiting the context (and model?) size, and difficulty in parallelisation limiting the size of the training data.

Do we have solutions for these two problems now?

Transformers can also fetch, at any moment, any previous information that becomes useful.

RNNs are constantly updating and overwriting their memory. That means they need to be able to predict what is going to be useful in order to store it for later.

This is a massive advantage for Transformers in interactive use cases like in ChatGPT. You give it context and ask questions in multiple turns. Which part of the context was important for a given question only becomes known later in the token sequence.

To be more precise, I should say it's an advantage of Attention-based models, because there are also hybrid models successfully mixing both approaches, like Jamba.
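
A minimal sketch of the kind of lookup described above, assuming single-head scaled dot-product attention over toy tensors (the shapes and names here are illustrative, not from any particular model):

```python
import torch
import torch.nn.functional as F

d = 16                                 # toy embedding size
keys = torch.randn(100, d)             # representations of 100 earlier tokens
values = torch.randn(100, d)
query = torch.randn(d)                 # the current token's query

# Scaled dot-product attention: the current position scores every earlier
# position and reads a weighted mix of their values directly...
weights = F.softmax(keys @ query / d ** 0.5, dim=0)
context = weights @ values

# ...whereas an RNN would have had to compress all 100 earlier tokens into one
# fixed-size state before knowing which of them the current question would need.
```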

  • You could theoretically run the input twice, allowing the model to correlate later tokens with previous ones. That would fix the problem of not knowing what information to retain. A more complicated approach would train the RNN to request replaying some earlier data when needed.

    A great thing about RNNs is that they can easily fork the state and generate trees, which would make it possible to backtrack and work on combinatorial search problems.

    It's also easier to cache demonstrations for free in the initial state: a model that has seen lots of data uses no more memory than a model starting from scratch (see the sketch below).
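
    A minimal sketch of the forking idea, assuming a PyTorch nn.LSTM (the shapes and the single decode step here are made up for illustration):

    ```python
    import torch
    import torch.nn as nn

    lstm = nn.LSTM(input_size=8, hidden_size=32, batch_first=True)
    prefix = torch.randn(1, 20, 8)          # shared prompt / cached demonstrations
    _, shared_state = lstm(prefix)          # pay for the prefix once

    def fork(state):
        h, c = state
        return (h.clone(), c.clone())       # cost depends on hidden size, not prefix length

    branch_a, branch_b = fork(shared_state), fork(shared_state)
    step = torch.randn(1, 1, 8)             # next-token input for each branch
    out_a, branch_a = lstm(step, branch_a)  # the two branches now evolve independently,
    out_b, branch_b = lstm(step, branch_b)  # without ever re-reading the prefix
    ```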

Vanishing (or exploding) gradients affected all deep architectures, not just RNNs. They were solved by LSTMs first proposed in 1997. See:

https://www.semanticscholar.org/paper/Long-Short-Term-Memory...

I find it interesting that this knowledge seems to be all but forgotten now. Back in the day, ca. 2014, LSTMs were all the rage, e.g. see:

https://karpathy.github.io/2015/05/21/rnn-effectiveness/

https://colah.github.io/posts/2015-08-Understanding-LSTMs/

  • > They were solved by LSTMs first proposed in 1997.

    I see this stuff everywhere online and it's often taught this way so I don't blame folks for repeating it, but I think it's likely promulgated by folks who don't train LSTMs with long contexts.

    LSTMs do add something like a "skip-connection" (before that term was a thing) which helps deal with the catastrophic vanishing gradients you get from e.g. Jordan RNNs right from the jump.

    However (!), while this stops us from seeing vanishing gradients after e.g. 10s or 100s of time-steps, when you get to multiple 1000s of tokens the wheels start falling off. I saw this in my own research: training on amino acid sequences of length 3,000 led to a huge amount of instability. It was only after tokenizing the amino acid sequences (which was uncommon at the time), which got us down to ~1,500 timesteps on average, that we started seeing stable losses during training. Check out the ablation at [0], and the toy gradient-norm sketch below.

    You can think of ResNets by analogy. ResNets didn't "solve" vanishing gradients (there's still a practical limit on network depth), but they went a long way towards dealing with them.

    EDIT: I wanted to add that while I was trying to troubleshoot this for myself, it was super hard to find evidence online of why I was seeing instability. Everything pertaining to "vanishing gradients" and LSTMs was blog posts and pre-prints that just merrily repeated "LSTMs solve the problem of vanishing gradients". That made it hard for me, a junior PhD student at the time, to suss out the fact that LSTMs do demonstrably and reliably suffer from vanishing gradients at longer contexts.

    [0] https://academic.oup.com/bioinformatics/article/38/16/3958/6...
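
    A toy way to see the effect being described, assuming PyTorch (untrained weights, random data; the names and shapes are illustrative, and this is a sketch of the measurement, not the experiment in [0]):

    ```python
    import torch
    import torch.nn as nn

    def grad_norm_at_first_step(seq_len, hidden=64):
        """Norm of the gradient of a last-step loss w.r.t. the first input token."""
        lstm = nn.LSTM(input_size=8, hidden_size=hidden, batch_first=True)
        x = torch.randn(1, seq_len, 8, requires_grad=True)
        out, _ = lstm(x)
        loss = out[:, -1].pow(2).sum()      # loss that depends only on the final timestep
        loss.backward()
        return x.grad[0, 0].norm().item()   # how much signal reaches timestep 0

    for seq_len in (100, 1000, 3000):
        print(seq_len, grad_norm_at_first_step(seq_len))
    ```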

    • >> I see this stuff everywhere online and it's often taught this way so I don't blame folks for repeating it, but I think it's likely promulgated by folks who don't train LSTMs with long contexts.

      To clarify, this wasn't taught to me. I studied LSTMs during my MSc in 2014, on my own initiative, because they were popular at the time [1]. I remember there being a hefty amount of literature on LSTMs, and I mean scholarly articles, not just blog posts. Rather, at the time I think there were only two blog posts, the ones by Andrej Karpathy and Chris Olah that I link above. The motivation with respect to vanishing gradients is well documented in previous work by Hochreiter (I think it's his thesis), and maybe a little less so in the 1997 paper that introduces the "constant error carousel".

      What kind of "instability" did you see? Vanishing gradients weren't something I noticed in my experiments. If that was because I didn't use a long enough context, as you say, I wouldn't be able to tell, but there was a different kind of instability: the loss would enter an oscillatory pattern, which I put down to the usual behaviour of gradient descent (either it gets stuck in local minima or at saddle points). Is that what you mean?

      _______________

      [1] More precisely, our tutor asked us to study an RNN architecture, expecting we'd look at something relatively simple like an Elman network, but I wanted to try out the hot new stuff. The code and report are here:

      https://github.com/stassa/lstm_rnn

      There may be errors in the code, and I don't know if you'll be able to run it, in case you get really curious. I don't think I had really grokked automatic differentiation at the time.

    • Highway networks add a skip connection, but LSTMs don't. Btw you might be interested in truncated backprop thru time, which we introduced in our ULMFiT paper.

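      A generic sketch of truncated backpropagation through time (the standard detach-every-k-steps idea, not the ULMFiT-specific variant; the model, data, and window size here are made up), assuming PyTorch:

      ```python
      import torch
      import torch.nn as nn

      lstm = nn.LSTM(input_size=8, hidden_size=32, batch_first=True)
      opt = torch.optim.SGD(lstm.parameters(), lr=0.1)

      x = torch.randn(1, 1000, 8)              # one long sequence
      target = torch.randn(1, 1000, 32)
      k = 100                                  # truncation window
      state = None

      for start in range(0, x.size(1), k):
          chunk = x[:, start:start + k]
          out, state = lstm(chunk, state)
          loss = (out - target[:, start:start + k]).pow(2).mean()
          opt.zero_grad()
          loss.backward()
          opt.step()
          # carry the values forward, but cut the graph so gradients
          # only flow back through at most k timesteps
          state = tuple(s.detach() for s in state)
      ```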

  • Agreed. Ilya Sutskever himself spent a long time with LSTMs and published papers like this one while working at Google: http://proceedings.mlr.press/v37/jozefowicz15.pdf

    In recent comments he has said that any architecture can achieve transformer-level accuracy and recall, but that we have devoted our energy to refining transformers due to the early successes.

  • LSTMs and GRUs did not quite solve the issue, but they made it less bad. Overall, recurrent units are notoriously prone to vanishing and exploding gradients (a one-step sketch below shows why the gating helps but doesn't eliminate the problem).

    I don't want to downplay the value of these models. Some people seem to be under the impression that transformers replaced them or made them obsolete, which is far from the truth.
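
    A one-step sketch of the textbook LSTM cell update (not from the paper under discussion; the weight names are illustrative), showing why the gating helps without fully fixing things:

    ```python
    import torch

    def lstm_step(x, h, c, W, U, b):
        """One LSTM step with all four gates packed into a single projection."""
        z = x @ W + h @ U + b
        i, f, o, g = z.chunk(4, dim=-1)
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
        g = torch.tanh(g)
        c = f * c + i * g            # additive cell update (the "constant error carousel")
        h = o * torch.tanh(c)
        return h, c

    # The gradient along the cell path scales roughly like the product of the
    # forget gates f_t over time: close to 1 per step, but over thousands of
    # steps that product can still shrink, matching the long-context
    # instability described in the thread above.
    ```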

From my (admittedly loose) reading, the paper particularly targets parallelization and fast training, not "vanishing gradients". However, by simplifying the recurrent units, the authors managed to improve both!
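
For intuition on the parallelization point: when a recurrence is simplified so that its per-step coefficients depend only on the current input (i.e. it has the form h_t = a_t * h_{t-1} + b_t), all the h_t can be computed with an associative scan instead of a sequential loop. A generic sketch of that trick, assuming PyTorch (this illustrates the general technique, not the paper's actual implementation):

```python
import torch

def linear_scan_parallel(a, b):
    """Compute h_t = a_t * h_{t-1} + b_t for all t (h_0 = 0) with an inclusive
    Hillis-Steele-style scan: O(log T) sequential passes instead of O(T)."""
    T = a.shape[0]
    A, B = a.clone(), b.clone()
    offset = 1
    while offset < T:
        A_prev = torch.ones_like(A)     # identity coefficients for positions < offset
        B_prev = torch.zeros_like(B)
        A_prev[offset:] = A[:-offset]
        B_prev[offset:] = B[:-offset]
        # compose the segment ending at t-offset with the segment ending at t
        A, B = A_prev * A, A * B_prev + B
        offset *= 2
    return B                            # B[t] now equals h_t

# Matches the sequential recurrence on random data:
T, d = 512, 8
a, b = torch.rand(T, d), torch.randn(T, d)
h, hs = torch.zeros(d), []
for t in range(T):
    h = a[t] * h + b[t]
    hs.append(h)
print(torch.allclose(torch.stack(hs), linear_scan_parallel(a, b), atol=1e-4))
```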

This is very clever and very interesting. The paper repeatedly calls it a "decade-old architecture," but in practice it's still used massively, thanks to how easily it adapts to different domains. Framing it as a "competitor" to transformers is also not entirely fair, as transformers and RNNs are not mutually exclusive, and there are many methods that merge them.

An improvement in RNNs is an improvement in a lot of other, sometimes surprising, places. A very interesting read.