
Comment by thesz

1 day ago

The difference between training and inference is that 1) in training one has to keep intermediate results for the backward pass, and 2) training needs more computation because of the backward pass, which costs roughly twice as much as the forward pass.
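Here is a minimal NumPy sketch of both points. It assumes a toy two-layer tanh MLP, and `forward` and `backward` are illustrative names, not a library API. In inference mode the intermediates can be dropped as soon as the next layer has consumed them. In training mode they are cached because the backward pass reuses them, and the backward pass adds extra matrix products on top of the forward ones.

```python
import numpy as np

def forward(W1, W2, x, training):
    """Forward pass of a toy 2-layer MLP: y = W2 @ tanh(W1 @ x).

    In inference, intermediates can be discarded once the next layer has
    consumed them; in training they are kept for the backward pass.
    """
    h = np.tanh(W1 @ x)                     # hidden activation
    y = W2 @ h                              # output
    cache = (x, h) if training else None    # what backprop will need later
    return y, cache

def backward(W2, cache, dy):
    """Backward pass: reuses the cached forward activations.

    Given dL/dy for a scalar loss L, returns (dL/dW1, dL/dW2).
    """
    x, h = cache
    dW2 = np.outer(dy, h)          # needs h from the forward pass
    dh = W2.T @ dy                 # extra matmul to push the gradient back
    dh_pre = dh * (1.0 - h ** 2)   # tanh'(z) = 1 - tanh(z)^2, reuses h again
    dW1 = np.outer(dh_pre, x)      # needs x from the forward pass
    return dW1, dW2
```

Calling `forward(W1, W2, x, training=False)` returns no cache at all. The `training=True` path is what forces activation memory to grow with depth and batch size.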

Training is also done over batches, which increases memory requirements by several orders of magnitude. This is why training needs costly compute.

One way out of this unfortunate situation is to use something like Stochastic Average Gradient (SAG) [1]. The examples there are mostly concerned with regularized logistic regression, which makes the problem more or less convex. Neural networks are inherently non-convex. Still, maybe some ideas from there can be used for neural networks, like using an estimated Lipschitz constant of the gradient to bound curvature and choose an appropriate step size.

  [1] https://www.cs.ubc.ca/~schmidtm/Courses/540-W19/L12.pdf
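For reference, here is a minimal sketch of SAG on L2-regularized logistic regression, assuming labels in {-1, +1}. It is not taken from [1]. It uses the closed-form bound L = max‖a_i‖²/4 + λ on the per-example gradient Lipschitz constant instead of estimating L on the fly. The step size is 1/L, which is commonly used in practice; the original convergence proof uses a smaller step. `sag_logistic` is a hypothetical name.

```python
import numpy as np

def sag_logistic(A, b, lam, n_iters, seed=0):
    """Stochastic Average Gradient for L2-regularized logistic regression.

    Minimizes (1/n) * sum_i [log(1 + exp(-b_i * a_i.w)) + (lam/2) * ||w||^2]
    with labels b_i in {-1, +1}.
    """
    n, p = A.shape
    rng = np.random.default_rng(seed)
    # Lipschitz constant of each per-example gradient: ||a_i||^2 / 4 + lam.
    L = np.max(np.sum(A ** 2, axis=1)) / 4.0 + lam
    alpha = 1.0 / L                    # practical step size (assumption: 1/L)
    w = np.zeros(p)
    grads = np.zeros((n, p))           # last gradient computed for each example
    d = np.zeros(p)                    # running sum of the stored gradients
    seen = np.zeros(n, dtype=bool)     # which examples have been visited
    for _ in range(n_iters):
        i = rng.integers(n)
        margin = b[i] * (A[i] @ w)
        sig = np.exp(-np.logaddexp(0.0, margin))   # 1 / (1 + e^margin), overflow-safe
        g = -b[i] * sig * A[i] + lam * w           # gradient of the i-th term
        d += g - grads[i]              # swap the stale gradient for the fresh one
        grads[i] = g
        seen[i] = True
        # Step along the average stored gradient (averaged over visited examples).
        w -= alpha * d / seen.sum()
    return w
```

Note the `grads` table. SAG replaces the big batch with an n-by-p memory of past gradients. For linear models this can be compressed to one scalar per example, but for a neural network it would mean storing a model-sized gradient per training example.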

So one way to think about it is roughly,

Training is inference + backwards pass (~2x inference cost) + activations (VRAM overhead) + optimizer state (VRAM overhead) + gradients (VRAM overhead).
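That decomposition can be turned into a back-of-the-envelope calculator. This sketch makes several assumptions: about 2 FLOPs per parameter per token for the forward pass, a backward pass at about 2x that, bf16 (2-byte) weights and gradients, and Adam's two moment buffers in fp32. It leaves out activations and fp32 master weights, since those depend on architecture, batch size and training setup. The function names are hypothetical.

```python
from dataclasses import dataclass

@dataclass
class Cost:
    flops_per_token: float
    static_bytes: float   # weights + gradients + optimizer state; no activations

def inference_cost(n_params: float, bytes_per_param: int = 2) -> Cost:
    # Forward pass: ~2 FLOPs (one multiply, one add) per parameter per token.
    return Cost(flops_per_token=2 * n_params,
                static_bytes=n_params * bytes_per_param)

def training_cost(n_params: float, bytes_per_param: int = 2,
                  optimizer_states: int = 2, optimizer_bytes: int = 4) -> Cost:
    fwd = 2 * n_params
    bwd = 2 * fwd                           # backward pass ~2x the forward pass
    weights = n_params * bytes_per_param
    grads = n_params * bytes_per_param      # one gradient entry per parameter
    # Assumed Adam: first and second moments, each stored in fp32.
    opt = n_params * optimizer_states * optimizer_bytes
    return Cost(flops_per_token=fwd + bwd, static_bytes=weights + grads + opt)
```

For a 7B-parameter model this gives 3x the per-token FLOPs for training, and 84 GB of static state versus 14 GB of weights for inference, before any activations are counted.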

  • Multiply "inference + backwards pass (~2x inference cost) + activations (vram overhead)" by the batch size (thousands) to get the actual RAM and compute cost. An optimizer like Adam adds only two or three model-sized buffers of overhead.

    And last, but not least, inference only needs one hidden layer's activations kept in RAM at a time, but computing the gradient for even one sample needs all of them kept in RAM (61 layers for DeepSeek models).

    • Microbatch size is a hyperparameter; it can be set to 1 and work just as effectively. With gradient accumulation it's even equivalent. Large batch sizes are used to increase parallelism, and sometimes to reduce variance in the loss signal (at the cost of increased bias).

      Batch size is frequently limited by compute bottlenecks well before memory.

    • And of course you do all of this for every example in your training set, which is going to be larger than the total number of queries any individual user makes.
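The gradient-accumulation equivalence can be checked on a toy least-squares problem. This sketch assumes a loss that is a plain average over independent examples; layers that mix examples within a batch, such as batch norm, break the exact equivalence. `accumulated_grad` is an illustrative helper. It walks the batch in micro-batches, so only one micro-batch's intermediates need to be live at once, but the total compute is unchanged.

```python
import numpy as np

def mse_grad(w, X, y):
    """Gradient of the mean squared error (1/m) * sum_j (x_j.w - y_j)^2."""
    residual = X @ w - y
    return 2.0 * X.T @ residual / len(y)

def accumulated_grad(w, X, y, micro_batch):
    """Gradient accumulation over micro-batches of size `micro_batch`.

    Each micro-batch gradient is re-weighted by its size, so an uneven last
    micro-batch still yields exactly the full-batch mean gradient.
    """
    total = np.zeros_like(w)
    for start in range(0, len(y), micro_batch):
        Xm = X[start:start + micro_batch]
        ym = y[start:start + micro_batch]
        total += mse_grad(w, Xm, ym) * len(ym)   # undo the per-micro-batch mean
    return total / len(y)
```

With `micro_batch=1` the peak activation memory is that of a single example, which is the setting described in the reply above.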

It's all got much more complex than that in recent years. Training now involves large amounts of inference for RL rollouts and similar. You can't disentangle them computationally like that. "Inference" is just the word used to mean serving customer traffic now, and "training" means creating the model you serve.

That is an estimate of the relative cost of one training step, but you have to multiply it by the number of training steps, an unknown quantity.
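Here is the same point written as an equation. It is a rough sketch that uses the common approximations of about 2N FLOPs per token forward and 4N backward for a model with N parameters. B (tokens per step), S (number of steps) and D (total training tokens) are symbols introduced here for illustration:

```latex
C_{\text{step}} \approx (2N + 4N)\,B = 6NB, \qquad
C_{\text{train}} \approx S \cdot C_{\text{step}} = 6NSB = 6ND, \qquad D = S\,B
```

Since S (equivalently D) isn't published for most models, the per-step ratio alone can't pin down the total. RL rollouts add further inference-style cost of roughly 2N FLOPs per generated token on top of this.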