Comment by ebonnafoux
2 days ago
Yes, but in practice, if you compute K = X·wk and Q = X·wq and then K·Qᵀ, you do three matrix multiplications. Wouldn't it be faster to compute W = wk·wqᵀ beforehand and then just X·W·Xᵀ, which is only two matrix multiplications? Is there something I am missing?
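The algebra in the question does hold: K·Qᵀ = (X·wk)·(X·wq)ᵀ = X·(wk·wqᵀ)·Xᵀ. A tiny pure-Python sketch with made-up integer matrices (the sizes and values are arbitrary assumptions, chosen only for illustration) shows that both routes produce identical scores:

```python
# Check that K·Qᵀ equals X·(wk·wqᵀ)·Xᵀ using plain Python lists (no NumPy).
# All sizes and values are arbitrary, for illustration only.

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    """Return the transpose of a matrix given as a list of rows."""
    return [list(row) for row in zip(*a)]

X = [[1, 2, 0], [0, 1, 3]]        # 2 tokens, input dimension d = 3
wk = [[1, 0], [2, 1], [0, 1]]     # d x d_head, with d_head = 2
wq = [[0, 1], [1, 1], [3, 0]]     # d x d_head

# Route 1: three multiplications.
K = matmul(X, wk)
Q = matmul(X, wq)
scores_three = matmul(K, transpose(Q))

# Route 2: premultiply W = wk·wqᵀ once (a d x d matrix), then X·W·Xᵀ.
W = matmul(wk, transpose(wq))
scores_two = matmul(matmul(X, W), transpose(X))

print(scores_three == scores_two)  # True
```

So the question is not whether the two routes agree, but which one is cheaper.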
Most models have a per-head dimension much smaller than the input dimension, so it's faster to multiply by the small matrices wk and wq individually than to multiply by the large d × d matrix W. Also, if you use rotary positional embeddings (RoPE), the rotation matrices need to be sandwiched between wq and wk, and they depend on each token's position, so there is no single W you could premultiply just once.
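A rough multiply-add count makes the first point concrete. The sketch below counts per-head costs for n tokens, input dimension d and head dimension d_head, ignoring the one-time cost of forming W. The sizes (d = 4096, d_head = 128, n = 2048) are illustrative assumptions, not taken from any specific model:

```python
def factored_cost(n, d, d_head):
    """Multiply-adds for K = X·wk, Q = X·wq, then K·Qᵀ (one head)."""
    project = 2 * n * d * d_head   # two (n x d) @ (d x d_head) products
    scores = n * n * d_head        # (n x d_head) @ (d_head x n)
    return project + scores

def premultiplied_cost(n, d):
    """Multiply-adds for X·W, then (X·W)·Xᵀ, with W = wk·wqᵀ precomputed (one head)."""
    xw = n * d * d                 # (n x d) @ (d x d)
    scores = n * n * d             # (n x d) @ (d x n)
    return xw + scores

n, d, d_head = 2048, 4096, 128     # illustrative sizes (assumption)
f = factored_cost(n, d, d_head)
p = premultiplied_cost(n, d)
print(f, p, p / f)                 # 2684354560 51539607552 19.2

# Parameters per head: W is d x d, while wk and wq together are 2 x d x d_head.
print(d * d, 2 * d * d_head)       # 16777216 1048576
```

With these sizes the factored route is about 19× cheaper, and storing W would take 16× more memory than wk and wq combined, even though W has rank at most d_head.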