Comment by quotemstr

1 year ago

But an RNN isn't going to remember the first token of input. It won't know until it sees the last token whether that first token was relevant after all, so it has to learn token-specific update rules that let it guess how long to hold what kinds of information. (In multi-layer systems, the network uses ineffable abstractions rather than tokens, but the same idea applies.)
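
To make that concrete, here is a bare-bones vanilla (Elman) RNN step in numpy; the shapes and names are illustrative, not any particular library's API. The whole history has to fit in a single fixed-size vector `h`, and nothing in the update explicitly decides what to keep, so "remembering the first token" means the learned weights keep re-encoding it into `h` at every step.

```python
import numpy as np

def rnn_step(x, h, W_xh, W_hh, b):
    """One step of a plain (Elman) RNN: the new hidden state mixes the
    current input with the previous state and squashes it with tanh.
    There is no explicit gate deciding what to keep or drop."""
    return np.tanh(W_xh @ x + W_hh @ h + b)

# Toy usage: 3-dim inputs, 4-dim hidden state, a 10-token sequence.
rng = np.random.default_rng(0)
n_in, n_h = 3, 4
W_xh = rng.normal(size=(n_h, n_in))
W_hh = rng.normal(size=(n_h, n_h))
b = np.zeros(n_h)
h = np.zeros(n_h)
for x in rng.normal(size=(10, n_in)):
    h = rnn_step(x, h, W_xh, W_hh, b)  # the first token survives only via h
```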

What the RNN must be doing reminds me of "sliding window attention" --- the model learns how to partition its state between short- and long-range memories to minimize overall loss. The two approaches seem related, perhaps even equivalent up to implementation details.
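
A rough sketch of the sliding-window idea, assuming a single attention head and a made-up window size: each position can only attend to itself and the previous few tokens, so anything older has to have been carried forward some other way.

```python
import numpy as np

def sliding_window_attention(q, k, v, window=4):
    """Single-head attention with a banded causal mask: position t may
    attend only to positions t-window+1 .. t."""
    T, d = q.shape
    scores = q @ k.T / np.sqrt(d)                      # (T, T) similarities
    idx = np.arange(T)
    allowed = (idx[None, :] <= idx[:, None]) & (idx[:, None] - idx[None, :] < window)
    scores = np.where(allowed, scores, -np.inf)        # mask outside the window
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v                                 # (T, d) context vectors

# Toy usage: 8 tokens, 3-dim embeddings, window of 4.
x = np.random.default_rng(0).normal(size=(8, 3))
out = sliding_window_attention(x, x, x, window=4)
```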

The most popular RNNs (the ones successful enough to power Google Translate and the like) actually had this behavior baked into the architecture: LSTMs, or "Long Short-Term Memory" networks.
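
For concreteness, a minimal LSTM cell in numpy (illustrative weights, not any particular library's API): the forget and input gates are exactly the learned "how long to hold which information" rules, applied to a persistent cell state c, while h plays the role of short-term working memory.

```python
import numpy as np

def lstm_step(x, h, c, W, U, b):
    """One LSTM step. The forget gate f scales down the old cell state c,
    the input gate i admits new content g, and the output gate o decides
    how much of the cell state to expose as the hidden state h."""
    sigmoid = lambda a: 1.0 / (1.0 + np.exp(-a))
    z = W @ x + U @ h + b                 # all four gate pre-activations at once
    i, f, o, g = np.split(z, 4)
    i, f, o, g = sigmoid(i), sigmoid(f), sigmoid(o), np.tanh(g)
    c_new = f * c + i * g                 # long-term memory update
    h_new = o * np.tanh(c_new)            # short-term output
    return h_new, c_new

# Toy usage: input size 3, hidden size 2, a 5-token sequence.
rng = np.random.default_rng(0)
n_in, n_h = 3, 2
W = rng.normal(size=(4 * n_h, n_in))
U = rng.normal(size=(4 * n_h, n_h))
b = np.zeros(4 * n_h)
h, c = np.zeros(n_h), np.zeros(n_h)
for x in rng.normal(size=(5, n_in)):
    h, c = lstm_step(x, h, c, W, U, b)
```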