Comment by pests
20 hours ago
The state of the system can be cached after the system prompt is processed, and all new chats start from that state. O(n^2) is not great, but apparently it's fine at these context lengths, and I'm sure this is a factor in their minimum prompt cost. Advances like grouped-query attention, multi-head attention, or sparse attention will eventually get rid of that quadratic growth, hopefully.
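For illustration, here's a minimal NumPy sketch of what "caching the state after the system prompt" means in practice: the keys and values for the system-prompt tokens are computed once and reused, while each new chat only appends its own keys/values. The shapes and values here are toy placeholders, not any real model's.

```python
import numpy as np

def attend(q, K, V):
    # Scaled dot-product attention for a single query vector.
    scores = K @ q / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ V

d = 8
rng = np.random.default_rng(0)

# Keys/values for the system-prompt tokens: computed once, cached,
# and shared by every new chat.
K_sys = rng.normal(size=(5, d))
V_sys = rng.normal(size=(5, d))

# Each chat appends its own user-token keys/values to the cached prefix.
K_user = rng.normal(size=(3, d))
V_user = rng.normal(size=(3, d))

K = np.concatenate([K_sys, K_user])
V = np.concatenate([V_sys, V_user])

# A new query still attends over the full (cached + new) sequence,
# which is why the per-token cost keeps growing with context length.
q = rng.normal(size=(d,))
out = attend(q, K, V)
```

The cache saves recomputing the prefix, but the attention step itself still scans every cached position, so the quadratic total cost remains.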
That's not how it works. The system prompt doesn't "get calculated first" or anything: you combine it with the user prompt and run generation for the first new token on the whole thing, which basically boils down to one huge matmul executed in parallel. So at best you can cache part of the input matrices for that first step, and then you'll very quickly run into the n^2 complexity anyway.
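A quick back-of-the-envelope sketch of why the n^2 term dominates regardless of prefix caching: generating token t attends over all t previous positions, so the multiply-adds summed over n tokens grow as n(n+1)/2. The flop-counting function here is a hypothetical toy, not a profiler.

```python
def attention_flops(n, d):
    # Generating token t needs a (1 x d) @ (d x t) score matmul and a
    # (1 x t) @ (t x d) value matmul: roughly 2*t*d multiplies each.
    # Summed over all n tokens this is ~n^2 * d.
    return sum(2 * 2 * t * d for t in range(1, n + 1))

# Doubling the context length roughly quadruples the attention work.
a = attention_flops(1000, 64)
b = attention_flops(2000, 64)
```

Caching the system-prompt portion removes a constant chunk of this work per request, but the quadratic growth in the remaining term is untouched.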