
Comment by hirako2000

9 hours ago

The article's claims assume far more compute and far more VRAM. While the trick reduces the back and forth, it doesn't eliminate it.

I doubt you meant 50M. Rather 50B?

You can only try it and see, but don't get your hopes up for a large context. If their technique works, I'd guess an 8096-token context limit would still OOM; 2048 maybe.

I'm extrapolating from my own experiments without this paper's trick of leveraging system memory.

> You can only try it and see, but don't get your hopes up for a large context.

You may or may not know this, but: when training off-the-shelf LLMs (i.e. ones with a huge vocabulary), what consumes a huge amount of memory is calculating the cross-entropy loss (and it gets worse the more tokens you stuff into your batch), so always use a fused cross-entropy kernel.

For example, for a Gemma 2 model with 2B parameters at a batch size of 8k, this consumes 24GB of VRAM by default (!). You can fuse the cross-entropy loss with @torch.compile, which cuts the memory usage down to something like a few gigabytes, but with a dedicated kernel it becomes a few megabytes.
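To make the mechanics concrete, here is a minimal sketch of the chunking idea in plain PyTorch. The shapes, chunk size, and ~256k vocabulary below are illustrative assumptions, not the exact Gemma 2 setup; dedicated kernels like Liger's go further by fusing the projection and the loss into a single kernel so even the per-chunk logits never hit global memory.

    import torch
    import torch.nn.functional as F

    # Naively, 8192 tokens x ~256k vocab x 4 bytes is ~8.4 GB just for the
    # fp32 logits, before gradients and softmax buffers are counted.
    # Chunking keeps only a small slice of the logits alive at any time,
    # and torch.compile can fuse the projection + log-softmax per chunk.
    @torch.compile
    def chunked_cross_entropy(hidden, lm_head_weight, targets, chunk_size=1024):
        total = hidden.new_zeros(())
        for i in range(0, hidden.shape[0], chunk_size):
            logits = hidden[i:i + chunk_size] @ lm_head_weight.t()  # (chunk, vocab)
            total = total + F.cross_entropy(
                logits, targets[i:i + chunk_size], reduction="sum"
            )
        return total / hidden.shape[0]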

  • I'd not heard of this before; a quick search turned up this 2025 post, which suggests a "fused cross-entropy loss" kernel was integrated into PyTorch:

    https://pytorch.org/blog/peak-performance-minimized-memory/

      > "The integration involves modifying the TransformerDecoder module in torchtune to bypass the linear layer computation, allowing the Liger Fused Linear Cross Entropy Loss to handle the forward projection weights. "
    

    Is this the same thing as you discuss above?

    • Yes.

      Although this wasn't integrated into PyTorch itself (but into torchtune, which is a different thing). If you're writing your own training loop you need to use a third-party kernel, e.g. the Liger kernel mentioned in the article, or Cut Cross Entropy (which is much better than the Liger one, although IIRC it has a numerical bug in one of its kernels that makes the results very slightly off).
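      For anyone hand-rolling a training loop, usage looks roughly like the sketch below. The import path and the `linear_cross_entropy` call follow the Cut Cross Entropy project's README as I remember it, so treat them as assumptions and double-check; the sizes are made up for illustration.

          import torch
          from cut_cross_entropy import linear_cross_entropy  # assumed import path

          # final hidden states, output-projection weights, and target token ids
          hidden = torch.randn(8192, 2304, device="cuda", dtype=torch.bfloat16, requires_grad=True)
          lm_head = torch.randn(256_000, 2304, device="cuda", dtype=torch.bfloat16, requires_grad=True)
          targets = torch.randint(0, 256_000, (8192,), device="cuda")

          # The kernel computes the vocab projection and the loss together, so
          # the full (tokens x vocab) logits matrix is never materialized in VRAM.
          loss = linear_cross_entropy(hidden, lm_head, targets)
          loss.backward()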

  • Activations would still require gigabytes for a context of just a few thousand tokens.

    There are plenty of techniques to optimise, but the question is what an RTX 3080 can train before it OOMs. The answer is: not that much.

    It can barely do quantized fine-tuning, and even then only with a small context.

    • > Activations would still require gigabytes for a context of just a few thousand tokens.

      For that you use activation checkpointing, and you can also offload that to the CPU in a smart way to hide the latency. Although, yes, for long context training the activations do dominate the memory usage (and quantizing them degrades things more than just quantizing weights and/or optimizer states).
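      A minimal sketch of both ideas in stock PyTorch (the block, sizes, and layer count are made up for illustration; real offloading schemes overlap the copies with compute rather than relying on the synchronous stock hook):

          import torch
          from torch.utils.checkpoint import checkpoint

          class Block(torch.nn.Module):
              """Stand-in for a transformer block (illustrative only)."""
              def __init__(self, dim=2048):
                  super().__init__()
                  self.ff = torch.nn.Sequential(
                      torch.nn.Linear(dim, 4 * dim),
                      torch.nn.GELU(),
                      torch.nn.Linear(4 * dim, dim),
                  )
              def forward(self, x):
                  return x + self.ff(x)

          blocks = torch.nn.ModuleList(Block() for _ in range(12)).cuda()
          x = torch.randn(8192, 2048, device="cuda", requires_grad=True)

          # 1) Activation checkpointing: each block's intermediate activations
          #    are discarded on the forward pass and recomputed during backward,
          #    trading extra compute for a much smaller memory peak.
          h = x
          for blk in blocks:
              h = checkpoint(blk, h, use_reentrant=False)
          h.sum().backward()

          # 2) CPU offload: park whatever autograd saves for backward in pinned
          #    host RAM instead of VRAM (here on a plain, un-checkpointed forward).
          with torch.autograd.graph.save_on_cpu(pin_memory=True):
              out = blocks[0](x)
          out.sum().backward()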

Maybe they're not talking about LLMs? Training a 9M-parameter YOLO can take over 20GB of VRAM if you use a large image size.