Comment by ryao

7 days ago

At the heart of inference is matrix-vector multiplication. If you have many of these operations to do and only the vector part differs (which is the case when you have multiple queries), you can stuff the vectors into a matrix and do a single matrix-matrix multiplication instead. Computing hardware is able to do a matrix-matrix multiplication equivalent to dozens of matrix-vector multiplications in about the same time it takes to do 1 matrix-vector multiplication. This is called batching. That is the main trick.
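
To make the batching idea concrete, here is a minimal pure-Python sketch. The function names and the tiny matrix are made up for illustration. It only shows that multiplying by several query vectors at once gives the same answers as doing them one at a time, while each row of weights is visited once and reused for every query. That reuse of the weights is the usual explanation for why the batched version is nearly free on real hardware, though this toy code does not measure speed.

```python
def matvec(W, x):
    """Multiply matrix W (a list of rows) by a single vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def batched_matmul(W, xs):
    """Multiply W by several vectors at once.

    Each row of W is visited once and reused for every vector in xs.
    Returns one output vector per input vector.
    """
    outputs = [[] for _ in xs]
    for row in W:                          # each weight row is read once...
        for out, x in zip(outputs, xs):    # ...and reused for every query
            out.append(sum(w * xi for w, xi in zip(row, x)))
    return outputs


# Hypothetical 3x2 "weights" shared by all queries, and three query vectors.
W = [[1, 2], [3, 4], [5, 6]]
queries = [[1, 0], [0, 1], [1, 1]]

one_at_a_time = [matvec(W, x) for x in queries]
batched = batched_matmul(W, queries)
print(batched == one_at_a_time)
```

In a real inference stack the inner loop would be a single call into a GPU matrix-multiply kernel rather than Python loops.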

A second trick is speculative decoding. Inference has two phases: prompt processing and token generation. Both work the same way, using what is called a forward pass, except that prompt processing can run its forward passes in parallel by switching from matrix-vector to matrix-matrix multiplication and feeding all of the prompt’s tokens in at once. Each forward pass produces a new token, but during prompt processing every one of them can be discarded except the one from the last forward pass, which becomes the first token of token generation. You then feed that token into the next forward pass to get the token after it, and so on. It would be nice if all of these forward passes could be done in parallel, but you do not know the future, so ordinarily you cannot.

However, if you have a draft model, a very fast model that runs in a fraction of the time and guesses the next token correctly most of the time, you can run its forward pass sequentially N times instead. You then take the N drafted tokens and put them through the main model’s prompt processing routine, which does N forward passes in parallel. Instead of discarding all tokens except the last one, as in prompt processing, you compare the outputs to the drafted input tokens. Every output token up to and including the first one that differs from the draft is a valid output token of the main model. This is guaranteed to produce at least 1 valid token: in the worst case the first drafted token does not match, but the main model’s output at that position is exactly what a forward pass without speculative decoding would have produced. You can get a 2x to 4x performance increase from this if done right.
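
The accept-or-reject logic is easier to see in code. Below is a toy sketch, assuming greedy decoding and two made-up stand-in "models" (`target_next` and `draft_next` are hypothetical functions, not any real library's API). The draft agrees with the main model except in one case. The main model's N parallel forward passes are simulated here with an ordinary loop, and real implementations differ in details such as sampling and how many positions they verify.

```python
def target_next(tokens):
    """Hypothetical main model: next token is a fixed function of the last one."""
    return (tokens[-1] * 2 + 1) % 7


def draft_next(tokens):
    """Hypothetical draft model: matches the main model except after token 3."""
    if tokens[-1] == 3:
        return 5  # deliberately wrong guess
    return (tokens[-1] * 2 + 1) % 7


def speculative_step(tokens, n_draft):
    # 1. Run the cheap draft model sequentially n_draft times.
    draft = []
    for _ in range(n_draft):
        draft.append(draft_next(tokens + draft))

    # 2. Run the main model on every drafted position. In a real system this
    #    is one parallel pass, like prompt processing; a loop stands in here.
    verified = [target_next(tokens + draft[:i]) for i in range(n_draft)]

    # 3. Keep main-model outputs up to and including the first mismatch.
    accepted = []
    for guess, actual in zip(draft, verified):
        accepted.append(actual)
        if guess != actual:
            break
    return accepted


def generate_speculative(prompt, n_new, n_draft=4):
    tokens = list(prompt)
    while len(tokens) - len(prompt) < n_new:
        tokens += speculative_step(tokens, n_draft)
    return tokens[:len(prompt) + n_new]


def generate_plain(prompt, n_new):
    tokens = list(prompt)
    for _ in range(n_new):
        tokens.append(target_next(tokens))
    return tokens


print(generate_speculative([1], 10) == generate_plain([1], 10))
```

Every accepted token comes from the main model on a prefix that the main model itself has confirmed, which is why the output matches plain one-token-at-a-time generation. The draft only decides how many tokens get accepted per step.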

Now, I do not work on any of this professionally, but I am willing to guess that beyond these techniques, they have groups of machines handling queries of similar length in parallel (since doing a batch where 1 query is much longer than the others is inefficient) and some sort of dynamic load balancing so that machines do not sit idle when they are assigned a query length that few requests currently have.