Comment by timlarshanson

2 days ago

I doubt it. This does not seem to be a particularly well-written or well-thought-out paper -- e.g., equations 6 and 7 contradict their descriptions in the sentence below them, and the 'theorem' is an assertion.

After reading it a few times, I gather that, rather than kernelizing or linearizing attention (which has been thoroughly explored in the literature), they are using an MLP to do run-time modelling of the attention operation. If that's the case (?) (which is interesting, sure): 1 -- Why didn't they say this plainly? 2 -- Why does eq. 12 show the memory MLP being indexed by the key, whereas eq. 15 shows it indexed by the query? 3 -- What's with all the extra LSTM-esque forget and remember gates? Meh. I wouldn't trust it without ablations.

I guess if an MLP can model a radiance field (NeRF) well, it stands to reason it can approximate attention too. The Q, K, V projection matrices will need to be learned beforehand using standard training.

While the memory & compute savings are clear, it's uncertain whether this helps with reasoning or the generalization thereof. I doubt that too.

Eq. 12 is a loss function that associates a given key and value in the memory MLP, via test-time training with gradient descent.

Eq. 15 is simply the operation that queries a value previously inserted (by earlier tokens) via eq. 12.

Basically, for each autoregressively processed segment you do:

1) Test-time inference: query values from memory with eq. 15.

2) Test-time training: associate new keys and values into the memory with the loss from eq. 12.

The forget and remember gates are there because... well, the architecture in general is very similar to an LSTM, but it uses test-time gradient descent to decide what to insert into the long-term memory.
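The two steps above can be sketched in a few lines of numpy. This is my own minimal toy, not the paper's implementation: the "memory MLP" collapses to a single weight matrix W (so M(x) = W @ x), momentum and the remember gate are omitted, and only a scalar forget factor is kept. The function names `retrieve` and `write` are mine.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8

# Toy memory: a single weight matrix, so M(x) = W @ x.
# (The paper allows deeper MLPs; this is the simplest case.)
W = np.zeros((d, d))

def retrieve(q):
    """Roughly eq. 15: read a value out of memory with the query."""
    return W @ q

def write(k, v, lr=0.5, forget=0.0):
    """Roughly eq. 12: one gradient step on 0.5 * ||M(k) - v||^2,
    after a scalar forget gate decays the old memory."""
    global W
    err = W @ k - v                       # prediction error ("surprise")
    W = (1 - forget) * W - lr * np.outer(err, k)

# Toy segment: a few (k, v) pairs to associate into memory.
keys = [rng.standard_normal(d) for _ in range(3)]
keys = [k / np.linalg.norm(k) for k in keys]   # unit keys, stable step size
vals = [rng.standard_normal(d) for _ in range(3)]

# 1) Test-time inference would call retrieve(q) here;
# 2) Test-time training: a few hundred inner gradient sweeps.
for _ in range(300):
    for k, v in zip(keys, vals):
        write(k, v, lr=0.5)

# Querying with a stored key now recovers its value.
print(np.allclose(retrieve(keys[0]), vals[0], atol=1e-3))
```

With near-orthogonal unit keys this is essentially cyclic Kaczmarz iteration, so the associations converge quickly; correlated keys would interfere, which is presumably part of what the gates are for.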

  • Ok, thanks for the clarification.

    Seems the implicit assumption then is that M(q) -> v 'looks like' or 'is smooth like' the dot product, otherwise 'train on keys, infer on queries' wouldn't work? (Safe assumption imo with that l2 norm & in general; unsafe if q and k are from different distributions.)

    Correct me if I'm wrong, but typically k and v are generated via affine projections K, V of the tokens; if M is matrix-valued and there are no forget and remember gates (to somehow approximate the softmax?), then M = V K^-1.
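A quick numeric check of that claim (my own sketch, not from the paper): for a matrix-valued memory with no gates, plain gradient descent on the batch loss ||MK - V||_F^2 converges to M = V K^-1. I use orthonormal keys purely to keep the toy well-conditioned (so K^-1 = K.T and the step size is easy to pick); random affine-projected keys would converge too, just more slowly.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4

# Orthonormal keys (columns of K) keep the example well-conditioned;
# for an orthogonal K, K^-1 is just K.T.  Values are arbitrary.
K, _ = np.linalg.qr(rng.standard_normal((d, d)))
V = rng.standard_normal((d, d))

# Matrix-valued memory, no forget/remember gates: plain gradient
# descent on the batch loss ||M K - V||_F^2.
M = np.zeros((d, d))
lr = 0.5
for _ in range(100):
    M -= lr * (M @ K - V) @ K.T     # gradient of 0.5 * ||MK - V||_F^2

# Gradient descent lands on the closed-form solution M = V K^-1.
print(np.allclose(M, V @ np.linalg.inv(K)))   # True
```

With orthonormal keys the residual shrinks by a factor (1 - lr) per step, so 100 steps is far more than enough.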

    • It's actually implied in the paper that the neural memory module M can be anything, and there's probably a lot of room to test different kinds of architectures for M. But in this paper M is a 1-layer MLP (fig. 7 is an ablation study using different numbers of layers for the MLP).

      > using a matrix-valued memory M [...] is an online linear regression objective and so the optimal solution assumes the underlying dependency of historical data is linear. On the other hand, we argue that deep memory modules (i.e., M ≥ 2) [...]. Aligning with the theoretical results that MLPs with at least two layers are strictly more expressive than linear models (Hornik, Stinchcombe, and White 1989), in Section 5.5, we show that deep memory modules are more effective in practice
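A toy illustration of the expressivity point (my own construction, not the paper's): if the keys are linearly dependent, a matrix memory is forced to make the stored values linearly dependent too, whereas a 2-layer ReLU MLP has no such constraint.

```python
import numpy as np

# Keys with a linear dependency: k3 = k1 + k2.  Any matrix memory M
# must then satisfy M k3 = M k1 + M k2 -- but we want v3 = 0 while
# v1, v2 are nonzero, so no matrix can store all three associations.
k1, k2 = np.array([1.0, 0.0]), np.array([0.0, 1.0])
k3 = k1 + k2
K = np.stack([k1, k2, k3], axis=1)      # keys as columns
V = np.array([[1.0, 0.0, 0.0],
              [0.0, 1.0, 0.0]])         # values as columns; v3 = 0

# The best possible matrix memory (least squares) has irreducible error:
Mt, *_ = np.linalg.lstsq(K.T, V.T, rcond=None)
linear_err = np.linalg.norm(Mt.T @ K - V)
print(linear_err > 0.1)

# A hand-built 2-layer ReLU MLP stores all three pairs exactly:
relu = lambda z: np.maximum(z, 0.0)
W1 = np.array([[1.0, -1.0], [-1.0, 1.0]])   # hidden layer
W2 = np.eye(2)                              # output layer
mlp = lambda k: W2 @ relu(W1 @ k)
print(np.allclose(mlp(k1), V[:, 0]),
      np.allclose(mlp(k2), V[:, 1]),
      np.allclose(mlp(k3), V[:, 2]))
```

The hidden ReLUs compute relu(x - y) and relu(y - x), which break the linear dependency among the keys; that's the nonlinearity a depth-1 (matrix) memory lacks.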