Comment by timlarshanson
1 day ago
I doubt it. This does not seem to be a particularly well-written or well-thought-out paper -- e.g. equations 6 and 7 contradict their descriptions in the sentence below them, and the 'theorem' is an assertion.
After reading a few times, I gather that, rather than kernelizing or linearizing attention (which has been thoroughly explored in the literature), they are using an MLP to do run-time modelling of the attention operation. If that's the case (?), which is interesting, sure:

1 -- Why did they not say this plainly?
2 -- Why does eq. 12 show the memory MLP being indexed by the key, whereas eq. 15 shows it indexed by the query?
3 -- What's with all the extra LSTM-esque forget and remember gates?

Meh. Wouldn't trust it without ablations.
I guess if an MLP can model a radiance field (NeRF) well, it stands to reason it can approximate attention too. The Q, K, V projection matrices will need to be learned beforehand using standard training.
While the memory & compute savings are clear, I'm uncertain whether this helps with reasoning, or generalization thereof. I doubt that too.
Eq. 12 is a loss function used to associate a given key and value in the memory MLP via test-time training with gradient descent.
Eq. 15 is simply the operation that queries a value previously inserted by earlier tokens via eq. 12.
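For concreteness, my reading of the two equations (transcribed from memory, so the exact notation may be off): eq. 12 is roughly loss(M; x_t) = || M(k_t) - v_t ||_2^2, i.e. regress the memory MLP's output on the key toward the value; eq. 15 is roughly y_t = M(q_t), i.e. read the memory back out with the query.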
Basically, for each autoregressively processed segment you do (rough sketch in code below):
1) Test-time inference: query values from memory with eq. 15.
2) Test-time training: associate new keys and values into the memory with the loss from eq. 12.
The forget and remember gates are there because... well, the architecture in general is very similar to an LSTM, but it uses test-time gradient descent to decide what to insert into the long-term memory.
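Here's a minimal PyTorch sketch of that loop as I understand it (not the paper's reference code: the MLP architecture, projection dims, and the fixed alpha/theta scalars are my placeholders; in the paper the gates are data-dependent and there's also a momentum/'surprise' term on the gradient that I'm omitting):

```python
import torch

d = 64
# "Long-term memory": a small MLP trained at test time (architecture is a guess).
M = torch.nn.Sequential(
    torch.nn.Linear(d, d), torch.nn.SiLU(), torch.nn.Linear(d, d)
)
# Q/K/V projections, learned beforehand with standard training (frozen here).
W_q = torch.nn.Linear(d, d, bias=False)
W_k = torch.nn.Linear(d, d, bias=False)
W_v = torch.nn.Linear(d, d, bias=False)

alpha, theta = 0.01, 0.1  # forget gate / inner-loop step size (made-up constants)

def process_segment(x):  # x: (seg_len, d) embeddings of one segment
    q, k, v = W_q(x), W_k(x), W_v(x)

    # 1) Test-time inference (eq. 15): read memory with the queries.
    with torch.no_grad():
        y = M(q)

    # 2) Test-time training (eq. 12): one gradient step associating
    #    the new keys with the new values via ||M(k) - v||^2.
    loss = ((M(k) - v) ** 2).sum()
    grads = torch.autograd.grad(loss, list(M.parameters()))
    with torch.no_grad():
        for p, g in zip(M.parameters(), grads):
            p.mul_(1 - alpha)   # forget gate: decay what's already stored
            p.sub_(theta * g)   # write: move toward the new associations
    return y
```

Note the read happens before the write, matching steps 1 and 2 above: a segment's queries can only retrieve what earlier segments inserted.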
Ok, thanks for the clarification.
Seems the implicit assumption then is that M(q) -> v 'looks like' or 'is smooth like' the dot product; otherwise 'train on keys, inference on queries' wouldn't work? (Safe assumption imo, with that l2 norm and in general; unsafe if q and k come from different distributions.)
Correct me if I'm wrong, but typically k and v are generated via affine projections K, V of the tokens, i.e. k = Kx and v = Vx; if M is matrix-valued and there are no forget and remember gates (to somehow approximate the softmax?), then the zero-loss solution satisfies MKx = Vx for all x, i.e. M = V K^-1 (pseudoinverse if K isn't square/invertible).
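Quick numeric sanity check of that linear case (my own toy script, not from the paper): train a matrix-valued M by gradient descent on ||Mk - v||^2 with k = Kx, v = Vx and no gating, and it converges to V K^-1:

```python
import torch

torch.manual_seed(0)
d = 8
# Key/value projections; K is kept near the identity so K^-1 exists and is
# well-conditioned (a convenience for the demo, not an assumption of the paper).
K = torch.eye(d) + 0.1 * torch.randn(d, d)
V = torch.randn(d, d)

M = torch.zeros(d, d, requires_grad=True)  # matrix-valued "memory"
opt = torch.optim.SGD([M], lr=0.5)

for _ in range(5000):
    x = torch.randn(32, d)          # batch of token embeddings
    k, v = x @ K.T, x @ V.T         # k = Kx, v = Vx
    loss = ((k @ M.T - v) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

# Gradient descent recovers the closed form M = V K^-1.
print((M - V @ torch.linalg.inv(K)).abs().max())  # ~0
```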
The paper has ablations.