kouteiheika 2 days ago

This is another potential improvement to the transformer architecture from Facebook (the other one that comes to mind, from the same authors, is https://arxiv.org/abs/2405.18719), but note that it comes with a major problem that might not be obvious at first glance: it's just not usable in practice without a ton of work. It modifies the innards of the attention mechanism, so it is incompatible with Flash Attention (or any other optimized attention library), and you do not want to train anything beyond toy models without Flash Attention (the performance hit is just way too big).

There's PyTorch's FlexAttention, which could maybe make this practical, but currently it's just way too buggy.
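For the curious, FlexAttention's hook looks roughly like this (a sketch assuming PyTorch >= 2.5; the distance penalty here is an arbitrary example, not the paper's method):

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    B, H, S, D = 2, 8, 256, 64  # batch, heads, sequence, head dim (arbitrary)
    q, k, v = (torch.randn(B, H, S, D) for _ in range(3))

    def distance_penalty(score, b, h, q_idx, kv_idx):
        # A per-score tweak of the attention innards -- exactly the kind of
        # change that breaks pre-baked Flash Attention kernels.
        return score - 0.1 * (q_idx - kv_idx).abs()

    out = flex_attention(q, k, v, score_mod=distance_penalty)
    # In practice you'd wrap the call in torch.compile to get fused kernels.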

jszymborski 2 days ago | parent | next [-]

People familiar with exotic RNNs and improvements to LSTMs know this problem all too well. The moment your LSTM isn't a bog-standard LSTM, it loses all the speed-ups from cuDNN and becomes borderline unusable for anything but toy models.

tpurves a day ago | parent [-]

These would be inherently temporary problems though, right? If it eventually became clear that alternate methods were the way forward, NVIDIA would be highly motivated to do the optimization work, wouldn't they? Any new step functions that can forestall the asymptotic plateauing of AI progress are things they desperately need.

jszymborski a day ago | parent | next [-]

That follows reason, but in practice I find that it's often not the case. My suspicion is that it's hard to establish that your method is superior to another if, for example, it takes 10-100x the compute to train a model. This is in large part due to the fact that machine learning is currently a deeply empirical field.

Nvidia isn't likely to start shipping optimized kernels for an obscure architecture for which there is limited evidence of improvement, and even less adoption.

kouteiheika 19 hours ago | parent [-]

Indeed. Especially when a lot of papers just use cherry-picked results that show some improvement so they can publish something, but the method doesn't work that well once it comes in contact with reality (e.g. see the deluge of papers which claim to have come up with an optimizer better than AdamW). And the majority of people don't even properly benchmark their new methods with respect to the time overhead: no, it doesn't matter if your method achieves 1% better loss if it takes 10% longer to train, because if I'd trained for 10% longer without your method I'd get an even better loss. And don't even get me started on people not tuning their baselines.
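The fix is mechanical: compare methods at equal wall-clock time, not equal step counts. A minimal sketch (train_step_baseline and train_step_new are hypothetical stand-ins for your own training code):

    import time

    def run_for(budget_s, train_step):
        # Run whatever fits in the budget; a slower method simply gets
        # fewer steps, which is exactly the penalty it should pay.
        steps, t0 = 0, time.monotonic()
        while time.monotonic() - t0 < budget_s:
            train_step()
            steps += 1
        return steps

    # run_for(3600, train_step_baseline)  # tuned baseline, one hour
    # run_for(3600, train_step_new)       # new method, the same hour
    # ...then compare validation losses; equal-step comparisons hide overhead.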

I've been burnt way too many times by fancy new methods that claimed improvements: I spent a ton of effort implementing them, and they ended up being poop.

Every person working in the field and pushing papers should read this blog post and apply what's written in it: https://kellerjordan.github.io/posts/muon/#discussion-solvin...

porridgeraisin 15 hours ago | parent [-]

Yep. Offline RL is especially full of these types of papers. The sheer number of alternatives to the KL divergence used to keep the learned policy from drifting too far from the collected data distribution... there's probably one method for each person on Earth.
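For context, the baseline recipe they're all riffing on looks something like this (a sketch in PyTorch; the Normal policies are hypothetical stand-ins for an actor and a fitted behavior model):

    import torch
    from torch.distributions import Normal, kl_divergence

    alpha = 1.0  # trades off return against staying on-distribution

    def actor_loss(q_values, pi, pi_behavior):
        # Maximize Q under the policy while penalizing divergence from the
        # data distribution; published variants mostly swap out the KL term.
        return (-q_values + alpha * kl_divergence(pi, pi_behavior)).mean()

    pi = Normal(torch.zeros(32), torch.ones(32))               # learned policy
    pi_behavior = Normal(torch.zeros(32), 2 * torch.ones(32))  # behavior policy
    loss = actor_loss(torch.randn(32), pi, pi_behavior)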

ssivark a day ago | parent | prev [-]

Check out "The Hardware Lottery" [1], which drove a lot of discussion a few years ago.

[1]: https://arxiv.org/abs/2009.06489

albertzeyer a day ago | parent | prev [-]

Why do you say FlexAttention is too buggy? I have heard of a lot of successful uses of it, and I've never heard of any such problems.

Also note that, depending on your model dimensions and sequence lengths, the attention computation often plays only a minor role (maybe 10% overall or so), and the MLP computation dominates.
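A quick back-of-the-envelope check of that claim (per-token FLOPs for one standard transformer block with a 4x MLP; the dimensions are illustrative):

    d, S = 4096, 2048            # model dim, sequence length

    qkvo_proj = 8 * d * d        # Q, K, V, O projections: four matmuls, 2*d^2 each
    attn_core = 4 * S * d        # QK^T scores + attn @ V (the Flash Attention part)
    mlp       = 16 * d * d       # two d <-> 4d matmuls, 2*d*4d each

    total = qkvo_proj + attn_core + mlp
    print(f"attention core: {attn_core / total:.1%} of layer FLOPs")  # ~7.7%

The attention-core share grows with sequence length and shrinks with model width, which is why the answer is "it depends on your dimensions".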

kouteiheika a day ago | parent [-]

Last time I tried it I encountered both showstopper bugs (it was completely, obviously broken) and subtle correctness bugs (it looked like it was working, but since I'm paranoid I have unit tests for everything, and numerically the errors were too big compared to what you'd get with eager attention or Flash Attention). It was also too slow for my taste compared to Flash Attention, so I just dropped it. And I wasn't even doing anything super exotic with it.

Maybe it's better now, but I'd still consider it completely irresponsible to use FlexAttention without a corresponding unit test checking its accuracy against an equivalent eager implementation.

gessha 17 hours ago | parent [-]

What unit tests do you use for nn modules, and how do you come up with them?

kouteiheika 9 hours ago | parent | next [-]

Unit tests that run random inputs across different sizes (e.g. different numbers of heads, head sizes, embedding dimensions, etc.) and compare the outputs of two different implementations against each other (e.g. attention implemented manually in an eager fashion vs. a bunch of accelerated attention libraries).
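Roughly this shape, to give an idea (a sketch comparing an eager reference against PyTorch's built-in SDPA; the shapes are arbitrary):

    import torch
    import torch.nn.functional as F

    def eager_attention(q, k, v):
        # Plain softmax attention written out explicitly, as the reference.
        scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
        return torch.softmax(scores, dim=-1) @ v

    def test_attention_matches_reference():
        torch.manual_seed(0)
        # Sweep heads, head dims, and (deliberately odd) sequence lengths.
        for heads, head_dim, seq in [(4, 32, 16), (8, 64, 128), (1, 128, 7)]:
            q, k, v = (torch.randn(2, heads, seq, head_dim) for _ in range(3))
            fast = F.scaled_dot_product_attention(q, k, v)
            torch.testing.assert_close(fast, eager_attention(q, k, v),
                                       rtol=1e-4, atol=1e-5)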

Also more integration-like tests, where I take an already pretrained model, load it using an established library (e.g. Huggingface Transformers), load the very same checkpoint into my reimplementation (where I vary the implementation, e.g. swap the attention implementation), and compare the outputs. Funnily enough, I recently found a bug in HF Transformers this way when I updated to a newer version and my previously matching outputs no longer matched.
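Something like this, schematically (MyModel is a hypothetical stand-in for your own reimplementation):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    name = "gpt2"  # any small pretrained checkpoint works
    tok = AutoTokenizer.from_pretrained(name)
    ref = AutoModelForCausalLM.from_pretrained(name).eval()

    ids = tok("attention is all you need", return_tensors="pt")["input_ids"]
    with torch.no_grad():
        ref_logits = ref(ids).logits
        # my_logits = MyModel.from_pretrained(name).eval()(ids)  # hypothetical
    # torch.testing.assert_close(my_logits, ref_logits, rtol=1e-4, atol=1e-4)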

porridgeraisin 15 hours ago | parent | prev [-]

I would like to know too.