Comment by pizza

2 years ago

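softmax(Q @ K.T / normalizing term) gives you a probability matrix of shape [seq, seq]: every entry is non-negative and every row sums to 1. Think of it as a weighted adjacency matrix over the tokens, where the edge weights are probabilities. Multiplying it by V then routes information between positions: each output row is a probability-weighted mix of the value vectors. Hence, semantic routing of parts of X, aggregated through V.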
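A minimal NumPy sketch of the "probability adjacency matrix" view. It assumes random scores as a stand-in for the scaled Q @ K.T, and the sizes are illustrative. It checks that row-wise softmax yields a row-stochastic [seq, seq] matrix and that multiplying by V mixes value rows:

```python
import numpy as np

def softmax(scores, axis=-1):
    # Subtract the row max for numerical stability; the result is unchanged.
    shifted = scores - scores.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
seq, d = 4, 3
scores = rng.normal(size=(seq, seq))   # stand-in for Q @ K.T / normalizing term
A = softmax(scores)                    # [seq, seq] attention ("adjacency") matrix
print(A.shape)                         # (4, 4)

# Row i is a probability distribution over the positions token i attends to:
# non-negative entries, and each row sums to 1 up to floating-point rounding.
row_sums = A.sum(axis=1)

V = rng.normal(size=(seq, d))
out = A @ V   # row i = probability-weighted mix of the rows of V
```

Reading A as a graph, A[i, j] is the fraction of token i's output that flows in from token j's value vector.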

where

- Q = X @ W_Q [query]

- K = X @ W_K [key]

- V = X @ W_V [value]

- X [input]
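The projections above, sketched in NumPy under illustrative assumptions: the sizes are made up and the weights are random, whereas in a real model they are learned. The sketch also shows why the scores use K.T: the transpose is what turns two [seq, d_k] matrices into one [seq, seq] score per (query token, key token) pair.

```python
import numpy as np

rng = np.random.default_rng(0)
seq, d_model, d_k = 5, 8, 4            # illustrative sizes, not from the comment

X = rng.normal(size=(seq, d_model))    # input: one row per token
W_Q = rng.normal(size=(d_model, d_k))  # learned in a real model; random here
W_K = rng.normal(size=(d_model, d_k))
W_V = rng.normal(size=(d_model, d_k))

Q = X @ W_Q   # [seq, d_k] queries
K = X @ W_K   # [seq, d_k] keys
V = X @ W_V   # [seq, d_k] values

scores = Q @ K.T      # [seq, seq]: one score per (query, key) pair
print(scores.shape)   # (5, 5)
```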

hence

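attn_head_i = softmax(Q @ K.T / sqrt(d_k)) @ V, where K.T is the transpose of K and the normalizing term sqrt(d_k) is the square root of the key dimension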
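One head written as a NumPy function. This is a sketch of the standard scaled dot-product formula; the function name `attention_head` and the sizes are hypothetical choices for illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_head(X, W_Q, W_K, W_V):
    """One head: softmax(Q @ K.T / sqrt(d_k)) @ V."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    d_k = K.shape[-1]
    A = softmax(Q @ K.T / np.sqrt(d_k))   # [seq, seq] routing weights
    return A @ V, A                        # [seq, d_v] output, plus the weights

rng = np.random.default_rng(1)
seq, d_model, d_k = 6, 16, 8
X = rng.normal(size=(seq, d_model))
W_Q, W_K, W_V = (rng.normal(size=(d_model, d_k)) for _ in range(3))
head_out, A = attention_head(X, W_Q, W_K, W_V)
print(head_out.shape)   # (6, 8)
```

Dividing by sqrt(d_k) keeps the dot products from growing with the key dimension, which would otherwise push the softmax toward near one-hot rows.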

Each head has its own W_Q, W_K and W_V, so each one learns a different routing system, and all heads run concurrently over the same input.
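A sketch of several concurrent heads, assuming the usual multi-head formulation from the original Transformer paper: the head outputs are concatenated and mixed back to d_model by an output projection W_O. W_O is not mentioned in the comment, and the function names and sizes here are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_head(X, W_Q, W_K, W_V):
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    return softmax(Q @ K.T / np.sqrt(K.shape[-1])) @ V

def multi_head_attention(X, heads, W_O):
    # Each head has its own (W_Q, W_K, W_V), so each computes its own
    # [seq, seq] routing matrix over the same tokens, independently.
    outs = [attention_head(X, *w) for w in heads]
    # Concatenate along the feature axis, then mix back to d_model with W_O.
    return np.concatenate(outs, axis=-1) @ W_O

rng = np.random.default_rng(2)
seq, d_model, n_heads = 5, 16, 4
d_k = d_model // n_heads
heads = [tuple(rng.normal(size=(d_model, d_k)) for _ in range(3))
         for _ in range(n_heads)]
W_O = rng.normal(size=(n_heads * d_k, d_model))
X = rng.normal(size=(seq, d_model))
Y = multi_head_attention(X, heads, W_O)
print(Y.shape)   # (5, 16)
```

The loop over heads is for readability; practical implementations usually batch all heads into one reshaped matrix multiply.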

The transformer just wraps that in normalization layers and MLP (feed-forward) layers for feature learning.
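A simplified sketch of one transformer block built around attention. It makes several assumptions: a single head, LayerNorm without its learned scale and shift, a ReLU MLP, and residual connections around each sublayer as in the standard design. It uses the post-norm arrangement of the original paper; many newer models normalize before each sublayer instead. All names and sizes are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    # Normalize each token's features to zero mean and unit variance
    # (learned scale and shift omitted for brevity).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def attention(x, W_Q, W_K, W_V):
    Q, K, V = x @ W_Q, x @ W_K, x @ W_V
    return softmax(Q @ K.T / np.sqrt(K.shape[-1])) @ V

def mlp(x, W1, W2):
    # Position-wise feed-forward layer: applied to each token independently.
    return np.maximum(0.0, x @ W1) @ W2

def transformer_block(x, p):
    # Attention routes information between tokens; the MLP transforms features.
    x = layer_norm(x + attention(x, p["W_Q"], p["W_K"], p["W_V"]))
    x = layer_norm(x + mlp(x, p["W1"], p["W2"]))
    return x

rng = np.random.default_rng(3)
seq, d_model, d_ff = 5, 8, 32
shapes = {"W_Q": (d_model, d_model), "W_K": (d_model, d_model),
          "W_V": (d_model, d_model), "W1": (d_model, d_ff), "W2": (d_ff, d_model)}
p = {name: rng.normal(size=shape) * 0.1 for name, shape in shapes.items()}
x = rng.normal(size=(seq, d_model))
y = transformer_block(x, p)
print(y.shape)   # (5, 8)
```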