gavinray 3 hours ago

I'd not heard of this before; a quick search turned up this 2025 post, which suggests a "fused cross-entropy loss" kernel was integrated into PyTorch:

https://pytorch.org/blog/peak-performance-minimized-memory/

  > "The integration involves modifying the TransformerDecoder module in torchtune to bypass the linear layer computation, allowing the Liger Fused Linear Cross Entropy Loss to handle the forward projection weights."
Is this the same thing as you discuss above?
kouteiheika 2 hours ago

Yes.

That said, it wasn't integrated into PyTorch itself but into torchtune, which is a separate project. If you're writing your own training loop you need a third-party kernel, e.g. the Liger kernel mentioned in the article, or Cut Cross Entropy (which is much better than the Liger one, although IIRC one of its kernels has a numeric bug that makes the results very slightly off).
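For anyone wondering what the "fusing" buys you: a rough sketch of the core idea, assuming NumPy and plain forward-pass math. This is not the Liger or Cut Cross Entropy code, and the function names and `chunk_size` parameter are made up for illustration. The naive version materializes the full (tokens × vocab) logits matrix before the loss; the chunked version only ever holds a (chunk_size × vocab) slice, which is why the loss needs direct access to the projection weights instead of receiving precomputed logits.

```python
import numpy as np

def naive_linear_cross_entropy(hidden, weight, targets):
    # hidden: (N, D) final hidden states; weight: (V, D) output projection
    # (same layout as an nn.Linear weight); targets: (N,) token ids.
    # Materializes the full (N, V) logits matrix, which dominates memory for large V.
    logits = hidden @ weight.T
    m = logits.max(axis=1, keepdims=True)
    log_z = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))  # stable log-sum-exp
    target_logits = logits[np.arange(len(targets)), targets]
    return float(np.mean(log_z - target_logits))

def chunked_linear_cross_entropy(hidden, weight, targets, chunk_size=4):
    # Same loss, but only a (chunk_size, V) slice of logits exists at any time.
    # A real fused kernel must also avoid materializing logits in the backward
    # pass; this forward-only sketch does not show that part.
    n = hidden.shape[0]
    total = 0.0
    for start in range(0, n, chunk_size):
        h = hidden[start:start + chunk_size]
        t = targets[start:start + chunk_size]
        logits = h @ weight.T
        m = logits.max(axis=1, keepdims=True)
        log_z = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
        total += float((log_z - logits[np.arange(len(t)), t]).sum())
    return total / n
```

Both functions should agree up to floating-point error; the difference is peak memory, which scales with `chunk_size` instead of the total token count.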