gpjt 2 days ago

Exactly! If I can get it down to an hour or two (which seems very plausible on an 8x H200 node with 160 GiB of VRAM per GPU, though those are almost never available on Lambda Labs), I'll run the experiments with dropout and the other possible culprits, then see if I can bake it all into a new training run on the RTX 3090 and confirm it repros there. Looks like I'll definitely need gradient accumulation for that.

I assume the zero_grad would need to go in the same if block?
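
If I'm picturing it right, it'd look roughly like this -- just a sketch, with the model, data, and accumulation count as placeholders, not the real setup:

    import torch
    from torch import nn

    # Placeholder model and data, purely to show the loop structure.
    model = nn.Linear(32, 10)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    batches = [(torch.randn(8, 32), torch.randint(0, 10, (8,)))
               for _ in range(64)]

    accum_steps = 16  # micro-batch 8 x 16 steps -> effective batch size 128

    optimizer.zero_grad()
    for step, (x, y) in enumerate(batches):
        loss = nn.functional.cross_entropy(model(x), y)
        # Scale so the accumulated gradient is an average over the window.
        (loss / accum_steps).backward()
        if (step + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()  # reset in the same if block, right after step()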

gpjt a day ago

Hmm, interesting. With a batch size of 512 (8x B200s with 160 GiB each) I get worse results! Maybe there's a sweet spot somewhere in between.
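
If I end up sweeping for that sweet spot, I'm picturing something of roughly this shape (toy stand-in for the real training run; the batch sizes and step counts are placeholders):

    import torch
    from torch import nn

    # Toy proxy for the real run, just to show the sweep shape.
    def run(batch_size, steps=200, seed=0):
        torch.manual_seed(seed)
        model = nn.Linear(32, 10)
        opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
        loss = torch.tensor(0.0)
        for _ in range(steps):
            x = torch.randn(batch_size, 32)
            y = torch.randint(0, 10, (batch_size,))
            loss = nn.functional.cross_entropy(model(x), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
        return loss.item()

    for bs in (64, 128, 256, 384, 512):
        print(bs, run(bs))

That said, the learning rate may need adjusting alongside the batch size, which could itself explain the worse results at 512.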