Comment by threeducks

14 hours ago

I thought I'd do something smart and inline all the matrix multiplications into the einsums of the vectorized multi-head attention implementation from the article, then set optimize="optimal" so that the optimal matrix chain multiplication algorithm (https://en.wikipedia.org/wiki/Matrix_chain_multiplication) would give a nice performance boost.

    def multi_head_attention_golfed(X, W_q, W_k, W_v, W_o, optimize="optimal"):
        # Q = X @ W_q and K = X @ W_k are folded into the scores einsum.
        scores = np.einsum('si,hij,tm,hmj->hst', X, W_q, X, W_k, optimize=optimize)
        weights = softmax(W_k.shape[-1]**-0.5 * scores, axis=-1)
        # V = X @ W_v and the output projection W_o are folded into the second einsum.
        projected = np.einsum('hst,ti,hiv,hvd->shd', weights, X, W_v, W_o, optimize=optimize)
        # Concatenate the per-head outputs along the last axis.
        return projected.reshape(X.shape[0], W_v.shape[2])
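
For the curious, np.einsum_path shows which pairwise contraction order optimize="optimal" actually ends up picking for the scores einsum (the sizes below are made up for illustration, not the article's):

    import numpy as np

    seq_len, num_heads, d_model, d_k = 256, 8, 512, 64  # made-up sizes
    X = np.random.rand(seq_len, d_model)
    W_q = np.random.rand(num_heads, d_model, d_k)
    W_k = np.random.rand(num_heads, d_model, d_k)

    # einsum_path returns the contraction order plus a report with FLOP counts
    # and the largest intermediate for the scores einsum above.
    path, report = np.einsum_path('si,hij,tm,hmj->hst', X, W_q, X, W_k,
                                  optimize='optimal')
    print(path)    # e.g. ['einsum_path', (0, 1), (0, 1), (0, 1)] -- depends on shapes
    print(report)  # human-readable breakdown of the chosen contraction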

This is indeed twice as fast as the vectorized implementation, but, disappointingly, the naive implementation with loops is even faster. Here is the code if someone wants to figure out why the performance comes out this way: https://pastebin.com/raw/peptFyCw
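
For a quick local repro without pulling the pastebin, something along these lines should work (the softmax stand-in and the sizes are assumptions; the naive and vectorized implementations from the pastebin can be swapped in for comparison):

    import timeit
    import numpy as np

    def softmax(x, axis=-1):  # stand-in for the article's softmax (assumption)
        e = np.exp(x - x.max(axis=axis, keepdims=True))
        return e / e.sum(axis=axis, keepdims=True)

    # Made-up sizes, chosen so the final reshape in the golfed version is valid
    # (num_heads * W_o.shape[2] == W_v.shape[2]); the article's sizes may differ.
    seq_len, num_heads, d_model, d_k = 256, 8, 512, 64
    X = np.random.rand(seq_len, d_model)
    W_q = np.random.rand(num_heads, d_model, d_k)
    W_k = np.random.rand(num_heads, d_model, d_k)
    W_v = np.random.rand(num_heads, d_model, d_model)
    W_o = np.random.rand(num_heads, d_model, d_model // num_heads)

    t = timeit.timeit(lambda: multi_head_attention_golfed(X, W_q, W_k, W_v, W_o),
                      number=20)
    print(f"golfed: {t / 20 * 1e3:.1f} ms per call")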

My guess is that einsum could do a better job of exploiting cache locality when evaluating the contraction.
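
One way to test that guess would be to compute the same thing with each intermediate (Q, K, V, scores, attention output) materialized explicitly and fed to a batched matmul, and see whether that closes the gap to the loop version. A sketch, relying on the same softmax as above:

    def multi_head_attention_staged(X, W_q, W_k, W_v, W_o):
        # Same math as the golfed version, but every intermediate is an
        # explicit array handed to a single batched matmul.
        Q = X @ W_q                                      # (heads, seq, d_k)
        K = X @ W_k                                      # (heads, seq, d_k)
        V = X @ W_v                                      # (heads, seq, d_v)
        scores = Q @ K.transpose(0, 2, 1)                # (heads, seq, seq)
        weights = softmax(W_k.shape[-1]**-0.5 * scores, axis=-1)
        out = (weights @ V) @ W_o                        # (heads, seq, d_out)
        # Concatenate the per-head outputs, matching the golfed reshape.
        return out.transpose(1, 0, 2).reshape(X.shape[0], W_v.shape[2])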

> This is indeed twice as fast as the vectorized implementation, but, disappointingly, the naive implementation with loops is even faster.

On CPU or GPU?