hirako2000 | 6 hours ago
The claims of the article assume far more compute and far more VRAM; while the trick enables less back and forth, it doesn't eliminate it. I doubt you meant 50M. Rather 50B? You can only give it a try, but don't get your hopes up for a large context. Even if their technique works, I would guess an 8192-token context would still OOM; 2048 maybe. I'm extrapolating from my own experiments without this paper's trick of leveraging system memory.
kouteiheika | 6 hours ago | parent
> You can only give it a try, but don't get your hopes high on a large context.

You may or may not know this, but when training off-the-shelf LLMs (i.e. ones with a huge vocabulary), what consumes a huge amount of memory is calculating the cross-entropy loss (and it gets worse the more tokens you stuff into your batch), so always use a fused cross-entropy kernel. For example, for a Gemma 2 model with 2B parameters at a batch size of 8k tokens, this consumes 24GB of VRAM by default (!). You can fuse your cross-entropy loss with @torch.compile, which cuts the memory usage down to something like a few gigabytes, but with a dedicated kernel it becomes a few megabytes.
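A rough back-of-the-envelope sketch of why the cross-entropy step dominates (the vocabulary size and 3x multiplier here are my assumptions, not figures from the comment; Gemma 2's vocabulary is roughly 256k entries):

```python
# Sketch: estimate the memory of a naive cross-entropy loss over a large vocabulary.
VOCAB = 256_000   # assumed approximate Gemma 2 vocabulary size
TOKENS = 8 * 1024 # batch of 8k tokens, as in the comment
BYTES = 4         # fp32 logits

GIB = 2**30

logits_bytes = VOCAB * TOKENS * BYTES
# A naive implementation also materializes the softmax probabilities and a
# gradient buffer of the same shape, so roughly 3x the logits tensor:
naive_total = 3 * logits_bytes

print(f"logits alone:           {logits_bytes / GIB:.1f} GiB")  # 7.8 GiB
print(f"naive cross-entropy ~3x: {naive_total / GIB:.1f} GiB")  # 23.4 GiB
```

Under those assumptions the estimate lands right around the 24GB figure above; a fused kernel avoids ever materializing the full logits tensor, computing the loss in small chunks instead.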