
Comment by KaiserPro

7 days ago

Same explanation but with less mysticism:

Inference is (mostly) stateless. So unlike training, where you need to have memory coherence over something like 100k machines and somehow avoid the certainty of machine failure, you just need to route mostly small amounts of data to a bunch of big machines.

I don't know what the specs of their inference machines are, but where I worked the machines research used were all 8-GPU monsters. So long as your model fitted in (combined) VRAM, your job was a good'un.

To scale, the secret ingredient was industrial amounts of cash. Sure we had DGXs (fun fact: Nvidia sent literal gold-plated DGX machines) but they weren't dense, and were very expensive.

Most large companies have robust RPC and orchestration, which means the hard part isn't routing the message, it's making the model fit in the boxes you have. (That's not my area of expertise, though.)

> Inference is (mostly) stateless. ... you just need to route mostly small amounts of data to a bunch of big machines.

I think this might just be the key insight. The key advantage of doing batched inference at a huge scale is that once you maximize parallelism and sharding, your model parameters and the memory bandwidth associated with them are essentially free (since at any given moment they're being shared among a huge number of requests!), you "only" pay for the request-specific raw compute and the memory storage and bandwidth for the activations. And the proprietary models are now huge, highly-quantized extreme-MoE models where the former factor (model size) is huge and the latter (request-specific compute) has been correspondingly minimized - and where it hasn't, you're definitely paying "pro" pricing for it. I think this goes a long way towards explaining how inference at scale can work better than locally.
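
A back-of-envelope sketch of that amortization (all numbers are made-up illustrative assumptions, not anyone's real deployment figures):

    # Why batching makes weight bandwidth (nearly) free: the weights touched in a
    # decode step are streamed from HBM once and shared by the whole batch, while
    # KV/activation traffic stays per-request.
    GB = 1e9
    active_weight_bytes = 20 * GB   # assumed weights actually touched per token (sparse MoE)
    kv_bytes_per_req = 2 * GB       # assumed per-request KV cache at a long context

    for batch in (1, 16, 256):
        weight_traffic_per_req = active_weight_bytes / batch
        print(f"batch={batch:4d}: weights/request = {weight_traffic_per_req / GB:6.2f} GB, "
              f"KV/request = {kv_bytes_per_req / GB:.2f} GB")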

(There are "tricks" you could do locally to try and compete with this setup, such as storing model parameters on disk and accessing them via mmap, at least when doing token gen on CPU. But of course you're paying for that with increased latency, which you may or may not be okay with in that context.)
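
For example, a minimal sketch of the mmap trick (file name, dtype and shapes are hypothetical; runtimes like llama.cpp do this for you):

    import numpy as np

    # Weights stay on disk; the OS pages them into RAM on first touch, so resident
    # memory roughly tracks the parameters you actually use, at the cost of
    # page-fault latency on cold accesses.
    weights = np.memmap("model_weights.bin", dtype=np.float16,
                        mode="r", shape=(50_000, 8_192))

    x = np.random.rand(8_192).astype(np.float16)
    y = weights[:4_096] @ x   # only the rows touched here get paged in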

  • > The key advantage of doing batched inference at a huge scale is that once you maximize parallelism and sharding, your model parameters and the memory bandwidth associated with them are essentially free (since at any given moment they're being shared among a huge number of requests!)

    Kind of unrelated, but this comment made me wonder when we will start seeing side channel attacks that force queries to leak into each other.

    • I asked a colleague about this recently and he explained it away with a wave of the hand, saying "different streams of tokens and their context are on different ranks of the matrices". And I kinda believed him, based on the diagrams I see on the Welch Labs YouTube channel.

      On the other hand, I've learned that when I ask questions about security to experts in a field (who are not experts in security) I almost always get convincing hand waves, and they are almost always proven to be completely wrong.

      Sigh.

  • mmap is not free. It just moves bandwidth around.

    • Using mmap for model parameters allows you to run vastly larger models for any given amount of system RAM. It's especially worthwhile when you're running MoE models, since parameters for unused "experts" can just be evicted from RAM, leaving room for more relevant data. But of course this applies more generally, e.g. to individual model layers, etc.
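
      A rough sketch of that expert-eviction behaviour (expert count, shapes and file name are made up):

          import numpy as np

          # All experts live in one on-disk memmap; only the experts the router
          # actually selects get paged into RAM, and cold ones get evicted by the
          # OS under memory pressure.
          n_experts, d_in, d_out = 64, 4096, 4096
          experts = np.memmap("experts.bin", dtype=np.float16, mode="r",
                              shape=(n_experts, d_in, d_out))

          def moe_forward(x, expert_ids):
              # Indexing experts[i] faults in just that expert's pages.
              return sum(x @ experts[i] for i in expert_ids)

          y = moe_forward(np.random.rand(d_in).astype(np.float16), expert_ids=[3, 17])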

> Inference is (mostly) stateless

Quite the opposite. Context caching requires state (the K/V cache) kept close to the VRAM. Streaming requires state. Constrained decoding (known as Structured Outputs) also requires state.
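
To make that state concrete, here's a schematic sketch of what a serving node has to hold onto between steps (not any particular engine's API; the names are made up, and real engines such as vLLM keep the K/V blocks in GPU memory with paged allocators):

    from dataclasses import dataclass, field

    @dataclass
    class RequestState:
        kv_cache: list = field(default_factory=list)          # (K, V) blocks per layer, kept near the GPU
        streamed_tokens: list = field(default_factory=list)   # what has been sent to the client so far
        grammar_cursor: int = 0                                # constrained-decoding / structured-output state

    # Follow-up chunks of a stream (or a cached prefix) have to land on a node that
    # still holds this state, or the server pays to rebuild or move it.
    sessions = {}  # request id -> RequestState
    sessions["req-42"] = RequestState()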

  • > Quite the opposite.

    Unless something has dramatically changed, the model is stateless. The context cache needs to be injected before the new prompt, but from what I understand (and please do correct me if I'm wrong) the context cache isn't that big, on the order of a few tens of kilobytes. Plus the cache saves seconds of GPU time, so having an extra 100ms of latency is nothing compared to a cache miss. So a broad cache is much, much better than a narrow local cache.

    But! Even if it's larger, your bottleneck isn't the network, it's waiting on the GPUs to be free[1]. So whilst having the cache really close, i.e. in the same rack or the same machine, will give the best performance, it will limit your scale (because the cache is only effective for a small number of users).

    [1] 100 MB of data shipped over the same datacentre network every 2-3 seconds per node isn't that much, especially if you have a partitioned network (i.e. like AWS, where you have a block network and a "network" network)

    • KV cache for dense models is on the order of 50% of the parameter count. For sparse MoE models it can be significantly smaller I believe, but I don’t think it is measured in kilobytes.
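
      For a rough sense of scale, a back-of-envelope sketch with purely illustrative numbers for a dense GQA model (actual sizes depend on layer count, KV heads, precision and context length):

          # Every number here is an assumption for illustration, not a real model.
          n_layers, n_kv_heads, head_dim = 80, 8, 128
          bytes_per_elem = 2            # fp16 KV cache
          context_len = 32_000

          per_token = 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem  # K and V
          cache_bytes = per_token * context_len
          print(f"KV bytes per token: {per_token / 1024:.0f} KiB")                # ~320 KiB
          print(f"KV cache at {context_len} tokens: {cache_bytes / 1e9:.1f} GB")  # ~10.5 GB

          # Shipping that over a 100 Gbit/s datacentre link:
          link_bytes_per_s = 100e9 / 8
          print(f"Transfer time: {cache_bytes / link_bytes_per_s * 1000:.0f} ms")  # ~840 ms

      With these (made-up) numbers the cache is kilobytes per token but gigabytes per request at long contexts, which is why keeping it near the GPU matters.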