Comment by refulgentis

5 hours ago

n^2 isn't a setting someone chose, it's a mathematical consequence of what attention is.

Here's what attention does: every token looks at every token (itself included) to decide what's relevant. If you have n tokens, and each one looks at all n, you get n * n = n^2 operations.
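
A minimal sketch of that counting, in plain Python. The toy vectors and the function name `attention_scores` are made up for illustration, and the same vectors are used as queries and keys (real attention applies learned projections first, which this skips):

```python
import math

def attention_scores(tokens):
    """Scaled dot-product scores for every (query, key) pair of tokens.

    `tokens` is a list of n equal-length vectors. For simplicity the same
    vectors serve as queries and keys (no learned projections).
    """
    d = len(tokens[0])
    return [
        [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in tokens]
        for q in tokens
    ]

tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]  # n = 4 toy vectors
scores = attention_scores(tokens)
num_scores = sum(len(row) for row in scores)  # one score per pair: n * n
```

With n = 4 the score table has 4 rows of 4 entries. Doubling n quadruples the table, which is all "n^2" means here.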

Put another way: n^2 is when every token gets to look at every other token. What would n^3 be? n^10?

(A sibling comment has the same interpretation as you, then handwaves that transformers can emulate more complex systems.)

There are plenty of operations more complicated than comparing every token to every other token, & the complexity increases when you start comparing not just token pairs but token bigrams, trigrams, & so on. There is no obvious proof that all those comparisons would be equivalent to the standard attention mechanism of comparing every token to every other one.
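
To put rough numbers on how fast that grows: the sketch below assumes "higher-order" means scoring ordered k-tuples of token positions (k = 2 is ordinary pairwise attention). That is only one possible reading of the comment above, and `count_tuple_comparisons` is a made-up helper name, not anything from a library:

```python
from itertools import product

def count_tuple_comparisons(n, k):
    """Count ordered k-tuples of positions drawn from n tokens.

    Enumerates them explicitly so the n ** k growth is visible, rather
    than just returning n ** k.
    """
    return sum(1 for _ in product(range(n), repeat=k))

pairs = count_tuple_comparisons(4, 2)    # every token vs. every token
triples = count_tuple_comparisons(4, 3)  # every token vs. every pair of tokens
```

Under that reading, k = 3 is one way to interpret "n^3": each token scored against every pair of tokens instead of every single token.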

  • That skips an important part: the "deep" in "deep learning".

    Attention already composes across layers.

    After layer 1, you're not comparing raw tokens anymore. You're comparing tokens-informed-by-their-context. By layer 20, you're effectively comparing rich representations that encode phrases, relationships, and abstract patterns. The "higher-order" stuff emerges from depth. This is the whole point of deep networks, and attention.

    TL;DR for rest of comment: people have tried shallow-and-wide instead of deep, it doesn't work in practice. (rest of comment fleshes out search/ChatGPT prompt terms to look into to understand more of the technical stuff here)

    A shallow network can approximate any function (universal approximation theorem), but it may need exponentially more neurons. Deep networks represent the same functions with way fewer parameters. There's formal work on "depth separation": functions that deep nets compute efficiently, but that shallow nets need exponential width to match.

    Empirically, people have tried shallow-and-wide vs. deep-and-narrow many times, across many domains. Deep wins consistently for the same parameter budget. This is part of why "deep learning" took off: the depth is load-bearing.

    For transformers specifically, stacking attention layers is crucial. A single attention layer, even with more heads or bigger dimensions, doesn't match what you get from depth. The representations genuinely get richer in ways that width alone can't replicate.
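
For a sense of what "same parameter budget" means in the reply above, here is a small sketch that counts parameters for two plain fully connected nets. The layer widths are invented for illustration, and the code only shows how the budgets are matched; it says nothing about which net performs better:

```python
def mlp_param_count(widths):
    """Weights plus biases of a fully connected net with the given layer widths."""
    return sum(w_in * w_out + w_out for w_in, w_out in zip(widths, widths[1:]))

deep_narrow = [64, 128, 128, 128, 128, 10]  # four hidden layers of 128 units
shallow_wide = [64, 788, 10]                # one hidden layer of 788 units

deep_params = mlp_param_count(deep_narrow)
shallow_params = mlp_param_count(shallow_wide)
```

Both configurations land within about 0.1% of each other in parameter count, so a comparison between them isolates depth vs. width rather than model size.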