MediaSquirrel 14 hours ago

Attention memory usage grows quadratically with sequence length, so using shorter sequences during fine-tuning can prevent memory blowups. On my 64 GB RAM machine, I'm limited to input sequences of about 2,000 tokens, given that my average output for the fine-tuning task is around 1,000 tokens (~3k tokens total).
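The quadratic term is easy to ballpark. A rough sketch (the layer/head counts and fp16 element size here are illustrative assumptions, not the poster's actual config):

```python
# Back-of-envelope estimate of attention-score memory during training
# (illustrative model shape, not from the post).
def attn_matrix_bytes(seq_len, n_layers=32, n_heads=32, bytes_per_elem=2):
    # Each head materializes a seq_len x seq_len score matrix, and
    # training keeps these activations around for the backward pass.
    return n_layers * n_heads * seq_len**2 * bytes_per_elem

for s in (1_000, 3_000, 6_000):
    gib = attn_matrix_bytes(s) / 2**30
    print(f"{s:>5} tokens -> ~{gib:.1f} GiB of attention scores")
```

Doubling the sequence length quadruples that term, which is why the input budget drops off so fast.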

zozbot234 8 hours ago | parent | next [-]

Shouldn't FlashAttention address the quadratic growth in memory footprint for fine-tuning/training? I'm also fairly sure the quadratic footprint doesn't apply to pure inference, since KV caching keeps memory linear in sequence length.

LuxBennu 12 hours ago | parent | prev | next [-]

Ah, that makes sense; quadratic scaling is brutal. So with 96 GB I'd probably get somewhere around 4-5k total sequence length before hitting the wall, which is still pretty limiting for anything multimodal. Do you do any gradient checkpointing, or is that not worth the speed tradeoff at these sizes?
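One way to sanity-check a guess like that: if a fixed chunk of RAM goes to weights/optimizer state and the remainder feeds the quadratic activation term, the max sequence length scales with the square root of the leftover memory. A toy model (the 32 GB fixed overhead is a made-up assumption, calibrated to the 64 GB / ~3k data point above):

```python
import math

def max_seq_len(total_gb, fixed_gb=32.0, ref_total_gb=64.0, ref_seq=3_000):
    # Assumption: fixed_gb is consumed by weights/optimizer state, and
    # the remaining RAM goes to an activation term growing as seq_len**2.
    # Calibrated so ref_total_gb supports ref_seq tokens. All numbers
    # are illustrative, not measured.
    scale = (total_gb - fixed_gb) / (ref_total_gb - fixed_gb)
    return ref_seq * math.sqrt(scale)

print(round(max_seq_len(96)))  # lands in the low-4k range under these assumptions
```

The answer is sensitive to the fixed-overhead guess, which is why estimates anywhere in the 4-5k range are plausible.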

MediaSquirrel 10 hours ago | parent [-]

Haven’t tried it yet. That’s on the to-do list. But good suggestion.
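For reference, gradient checkpointing is a small change per block in PyTorch: it drops intermediate activations during the forward pass and recomputes them in the backward pass, trading compute for memory. A minimal sketch (the tiny two-layer module is hypothetical; `torch.utils.checkpoint.checkpoint` is the real API):

```python
import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(dim, dim), torch.nn.GELU(), torch.nn.Linear(dim, dim)
        )

    def forward(self, x, use_ckpt=False):
        if use_ckpt:
            # Activations inside self.net are recomputed during backward
            # instead of being stored, cutting activation memory.
            return checkpoint(self.net, x, use_reentrant=False)
        return self.net(x)

block = Block()
x = torch.randn(4, 64, requires_grad=True)
# Same forward result either way; only the memory/compute tradeoff differs.
assert torch.allclose(block(x), block(x, use_ckpt=True))
```

Whether the extra forward recompute is worth it depends on how activation-bound the run is; at the sequence lengths discussed above, attention activations usually dominate.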
