kouteiheika | 4 hours ago
> You can only give it a try, but don't get your hopes up for a large context. You may or may not know this, but when training off-the-shelf LLMs (i.e. ones with a huge vocabulary), what consumes a huge amount of memory is calculating the cross-entropy loss (and it gets worse the more tokens you stuff into your batch), so always use a fused cross-entropy kernel. For example, for a Gemma 2 model with 2B parameters at a batch size of 8k tokens this consumes 24GB of VRAM by default (!); fusing the cross-entropy loss with @torch.compile can cut that down to a few gigabytes, but with a dedicated kernel it becomes a few megabytes.
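A back-of-the-envelope check of that 24GB figure. The helper below is hypothetical, and the vocabulary size and dtypes are assumptions (a Gemma-family-style ~256k vocab, bf16 logits upcast to fp32 for the softmax), not numbers from the comment:

```python
# Back-of-envelope estimate of the memory needed to materialize the full
# logits tensor for a naive cross-entropy loss. All concrete numbers
# (vocab size, dtypes) are assumptions for illustration.

def logits_memory_gib(tokens: int, vocab: int, bytes_per_el: int) -> float:
    """Memory for one (tokens x vocab) logits tensor, in GiB."""
    return tokens * vocab * bytes_per_el / 2**30

tokens = 8192     # batch of 8k tokens
vocab = 256_000   # Gemma-family vocabulary size (approximate)

bf16 = logits_memory_gib(tokens, vocab, 2)  # raw bf16 logits
fp32 = logits_memory_gib(tokens, vocab, 4)  # fp32 upcast for the softmax

print(f"bf16 logits: {bf16:.1f} GiB")  # ~3.9 GiB
print(f"fp32 upcast: {fp32:.1f} GiB")  # ~7.8 GiB
# Autograd keeps the bf16 logits, the fp32 copy, a same-sized gradient
# buffer, and softmax intermediates alive at once, so the naive peak
# plausibly lands in the ~20-24 GiB range the comment describes.
```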
gavinray | 3 hours ago
I hadn't heard of this before; a quick search turned up this 2025 post, which suggests a "fused cross-entropy loss" kernel was integrated into PyTorch: https://pytorch.org/blog/peak-performance-minimized-memory/
Is this the same thing you're describing above?
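For context, the core trick behind such kernels can be sketched without a GPU: compute the loss a block of tokens at a time, so the full (tokens × vocab) logits matrix never exists in memory at once. This is a toy NumPy illustration of the technique, not the actual PyTorch or Liger-style kernel, and all shapes are made up:

```python
# Chunked cross-entropy sketch: only a (chunk, vocab) slice of logits is
# ever materialized, instead of the full (tokens, vocab) matrix.
import numpy as np

def chunked_cross_entropy(hidden, unembed, targets, chunk=1024):
    """Mean cross-entropy of hidden @ unembed.T against integer targets,
    materializing at most (chunk, vocab) logits at a time."""
    losses = []
    for i in range(0, hidden.shape[0], chunk):
        logits = hidden[i:i + chunk] @ unembed.T       # (<=chunk, vocab)
        logits -= logits.max(axis=1, keepdims=True)    # numerical stability
        log_z = np.log(np.exp(logits).sum(axis=1))     # log partition fn
        t = targets[i:i + chunk]
        losses.append(log_z - logits[np.arange(len(t)), t])
    return float(np.concatenate(losses).mean())

# Toy sizes: 4096 "tokens", hidden dim 64, "vocab" of 1000.
rng = np.random.default_rng(0)
hidden = rng.standard_normal((4096, 64))
unembed = rng.standard_normal((1000, 64))
targets = rng.integers(0, 1000, size=4096)

# Matches the naive full-logits computation, with a far smaller peak.
full = hidden @ unembed.T
full -= full.max(axis=1, keepdims=True)
naive = float((np.log(np.exp(full).sum(axis=1))
               - full[np.arange(4096), targets]).mean())
print(abs(chunked_cross_entropy(hidden, unembed, targets) - naive) < 1e-9)
```

A fused kernel goes further by also computing the gradient inside the same pass, so even the per-chunk logits never need to be saved for backward.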
hirako2000 | 3 hours ago
Activations would still require gigabytes for even a few-KB context. There are plenty of techniques to optimize, but the question is what an RTX 3080 can train before OOM, and the answer is: not much. It can barely do quantized fine-tuning, and even then only with a small context.