
Comment by timlarshanson

1 day ago

I doubt it. This does not seem to be a particularly well-written or well-thought-out paper -- e.g. equations 6 and 7 contradict their descriptions in the sentence below them, and the 'theorem' is an assertion.

After reading a few times, I gather that, rather than kernelizing or linearizing attention (which has been thoroughly explored in the literature), they are using an MLP to do run-time modelling of the attention operation. If that's the case (which is interesting, sure): 1 -- why did they not say this plainly? 2 -- why does eq. 12 show the memory MLP being indexed by the key, whereas eq. 15 shows it indexed by the query? 3 -- what's with all the extra LSTM-esque forget and remember gates? Meh. Wouldn't trust it without ablations.

I guess if an MLP can model a radiance field (NeRF) well, it stands to reason it can approximate attention too. The Q, K, V projection matrices will need to be learned beforehand using standard training.

While the memory & compute savings are clear, it's uncertain whether this helps with reasoning or the generalization thereof. I doubt that too.

Eq. 12 is a loss function used to associate a given key and value in the memory MLP via test-time training with gradient descent.

Eq. 15 is simply the operation that queries a value previously inserted, at earlier tokens, using eq. 12.
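
To make that concrete, here's a minimal sketch of my reading (the `write`/`read` names are mine, and I'm collapsing the update to a single vanilla gradient step): the memory is a small MLP, eq. 12 writes by descending on ||M(k) - v||^2, and eq. 15 reads with a plain forward pass.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

d = 64
memory = nn.Sequential(nn.Linear(d, 4 * d), nn.SiLU(), nn.Linear(4 * d, d))

def write(memory, k, v, lr=0.1):
    # Eq. 12 (as I read it): one test-time gradient step on ||M(k) - v||^2
    loss = F.mse_loss(memory(k), v)
    grads = torch.autograd.grad(loss, list(memory.parameters()))
    with torch.no_grad():
        for p, g in zip(memory.parameters(), grads):
            p -= lr * g

def read(memory, q):
    # Eq. 15 (as I read it): retrieval is just a forward pass, M(q)
    with torch.no_grad():
        return memory(q)

k, v, q = torch.randn(d), torch.randn(d), torch.randn(d)
write(memory, k, v)
out = read(memory, q)
```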

Basically, for each autoregressively processed segment you do:

1) Test-time inference: query values from memory with eq. 15.

2) Test-time training: associate new keys and values into the memory with the loss from eq. 12.

The forget and remember gates are there because... well, the architecture in general is very similar to an LSTM, but it uses test-time gradient descent to decide what to insert into the long-term memory.
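
Under that reading, the per-segment loop would look something like the sketch below. I'm guessing at the gate placement: `alpha` as a weight-decay-style forget gate on the memory parameters and `eta` as momentum on the gradient (the "surprise"), with fixed scalars standing in for whatever data-dependent gates the paper actually uses.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def process_segment(memory, qs, ks, vs, alpha=0.01, eta=0.9, theta=0.1,
                    state=None):
    params = list(memory.parameters())
    if state is None:
        state = [torch.zeros_like(p) for p in params]  # "surprise" momentum
    outs = []
    for q, k, v in zip(qs, ks, vs):
        with torch.no_grad():                 # 1) test-time inference (eq. 15)
            outs.append(memory(q))
        loss = F.mse_loss(memory(k), v)       # 2) test-time training (eq. 12)
        grads = torch.autograd.grad(loss, params)
        with torch.no_grad():
            for p, g, s in zip(params, grads, state):
                s.mul_(eta).add_(g, alpha=-theta)  # remember: momentum on -grad
                p.mul_(1 - alpha).add_(s)          # forget: decay, then write
    return outs, state

d = 16
memory = nn.Sequential(nn.Linear(d, d), nn.Tanh(), nn.Linear(d, d))
qs, ks, vs = torch.randn(4, d), torch.randn(4, d), torch.randn(4, d)
outs, state = process_segment(memory, qs, ks, vs)
```

If that's roughly right, the decay term is what lets old associations fade instead of the MLP saturating.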

  • Ok, thanks for the clarification.

    Seems the implicit assumption then is that M(q) -> v 'looks like' or 'is smooth like' the dot product, otherwise 'train on keys, inference on queries' wouldn't work? (Safe assumption imo given the l2 norm & in general; unsafe if q and k come from different distributions.)

    Correct me if I'm wrong, but typically k and v are generated via affine projections K, V of the tokens; if M is matrix-valued and there are no forget and remember gates (to somehow approximate the softmax?), then M = V K^-1.
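
    A quick numpy sanity check of that limiting case (my construction, not the paper's): with a purely linear, gate-free memory, minimizing the l2 write loss over all stored pairs gives the least-squares solution M = V K^+, which reduces to V K^-1 when the key matrix is square and invertible -- i.e. basically linear attention.

    ```python
    import numpy as np

    rng = np.random.default_rng(0)
    d, n = 8, 8                      # square K so K^+ coincides with K^-1
    K = rng.standard_normal((d, n))  # columns are stored keys
    V = rng.standard_normal((d, n))  # columns are stored values
    M = V @ np.linalg.pinv(K)        # closed-form minimizer of ||M K - V||_F

    # Querying with a stored key returns its value exactly:
    assert np.allclose(M @ K[:, 0], V[:, 0])
    ```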