spi 3 days ago

A separate comment about the conclusions on why these models are worse than OpenAI's GPT-2, which to me seem to miss the point.

One main point is batch size - I'd agree with Gemini here. A batch size <= 5 with 1024 sequence length is really tiny (at most ~5K tokens per step). Nowadays models are trained with an effective batch size of millions of tokens in total. Of course, that won't fit into memory; one uses gradient accumulation for that purpose, again as mentioned by Gemini.

Training duration is definitely also a reason - models do get better over time, otherwise people wouldn't train for so long, wasting millions :-) Just how long is optimal is unclear, but certainly < 2 days is not optimal even at this "small" scale.

The optimizer could also play a role. As the author mentions, a fixed learning rate is hardly optimal: it is typically both increased in the beginning ("warm up" - though that's for stability, so if training works without it, it's not an issue) and scaled down at the end ("cool down", that is, annealing, with cosine as mentioned in the article). This generally squeezes out a bit more performance. Also, while it's true that dropout was used back then (it might be useful for many epochs, but is likely only harmful for < 1 epoch), using _both_ dropout _and_ weight_decay > 0, as the author does, is probably wrong and makes training too slow and careful to get good results. And even if weight decay is used, a "good" implementation should skip some parameters like embeddings and biases (GPT2 did that, and it's relatively important to do so).
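To make the last two points concrete, this is roughly the kind of setup I mean in PyTorch (an untested sketch; the exact constants and the name-based check for embeddings are just illustrative):

  import math
  import torch

  def make_optimizer_and_scheduler(model, lr=6e-4, weight_decay=0.1,
                                   warmup_steps=2000, total_steps=20000):
      # No weight decay on biases, norm parameters and embeddings
      decay, no_decay = [], []
      for name, p in model.named_parameters():
          if not p.requires_grad:
              continue
          if p.ndim < 2 or "embed" in name.lower():
              no_decay.append(p)   # biases, LayerNorm weights, embedding tables
          else:
              decay.append(p)      # linear weight matrices
      optimizer = torch.optim.AdamW(
          [{"params": decay, "weight_decay": weight_decay},
           {"params": no_decay, "weight_decay": 0.0}],
          lr=lr,
      )

      # Linear warm up, then cosine cool down towards zero
      def lr_lambda(step):
          if step < warmup_steps:
              return step / max(1, warmup_steps)
          progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
          return 0.5 * (1.0 + math.cos(math.pi * progress))

      scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
      return optimizer, scheduler

(scheduler.step() is then called once per optimizer step.)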

On the other hand, I'm pretty sure that using mixed precision and TF32 has absolutely no downsides. It's really standard nowadays to use either mixed precision (FP16 gradients + FP32 base weights) or directly BF16 ("brain" float 16, a bit like the TF32 described there, but with only 16 bits) and I have almost never seen either one fail... and when it does, it typically fails spectacularly, with NaN losses or the model degenerating to trivial performance.
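For reference, this is about all it takes with PyTorch autocast (a sketch, assuming an Ampere-or-newer GPU and a model that returns the loss directly):

  import torch

  # TF32 for the remaining FP32 matmuls (no effect on pre-Ampere GPUs)
  torch.backends.cuda.matmul.allow_tf32 = True
  torch.backends.cudnn.allow_tf32 = True

  def train_step(model, batch, optimizer):
      optimizer.zero_grad(set_to_none=True)
      # BF16 autocast: activations in bfloat16, master weights stay in FP32,
      # and unlike FP16 no loss scaler is needed
      with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
          loss = model(batch)
      loss.backward()
      optimizer.step()
      return loss.item()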

gpjt 3 days ago | parent | next [-]

OP here -- thanks! I'm in the process of doing some trains using the same code plus DDP on big Lambda Labs machines, and (within the bounds of what I can afford) will hopefully have some interesting results about all of those shortly.

gpjt 2 days ago | parent [-]

OK, early indicators support both you and Gemini quite strongly re: batch size. On my (somewhat ad-hoc) test dataset, I get losses like this:

  * OpenAI medium weights: 3.231
  * OpenAI small weights: 3.500
  * My locally trained model, FineWeb Chinchilla, batch size 6: 3.944
  * My locally trained model, FineWeb-Edu Chinchilla, batch size 6: 4.167
  * My locally trained model, FineWeb-Edu double Chinchilla, batch size 6: 4.135
  * My cloud trained model, FineWeb Chinchilla, batch size 13 * 8 = 104: 3.674

That last one was trained on an 8x A100 machine with 40 GiB per GPU, with the same code as before, just converted to DDP. It certainly looks like the much larger batch size has improved the model significantly.
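(For context, "converted to DDP" means essentially the standard wrapper - something along these lines rather than my exact code, launched with torchrun --nproc_per_node=8:)

  import os
  import torch
  import torch.distributed as dist
  from torch.nn.parallel import DistributedDataParallel as DDP

  dist.init_process_group(backend="nccl")
  local_rank = int(os.environ["LOCAL_RANK"])   # set by torchrun
  torch.cuda.set_device(local_rank)

  model = build_model().to(local_rank)         # build_model() stands in for the real setup
  model = DDP(model, device_ids=[local_rank])  # gradients get all-reduced across the 8 GPUs

  # ...usual training loop; each rank reads a different shard of the data,
  # so the effective batch size is the per-GPU batch size times the world size.

  dist.destroy_process_group()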

I'll be trying on larger machines. No gradient accumulation yet, but it's certainly looking like a valuable lever to pull for local training runs (and, I suspect, might also be useful on "small" cloud machines like the one I used -- will have to see what things look like with the bigger mini-batches I can squeeze onto 80 GiB and 160 GiB GPUs).

spi 2 days ago | parent [-]

Thanks, very nice to see these results! Certainly using GPUs with more RAM makes things simpler to scale. Gradient accumulation is as easy as adding a counter for the number of steps and an `if counter % gradient_accumulation_steps == 0:` check around `optimizer.step()` (rough sketch below), so it can also be tried on a single GPU / cheaper GPUs. But if you can just use 8x A100 and your pipeline parallelizes well, you also get results (almost) 8 times faster, which is certainly nicer for experimenting, of course!
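Roughly like this (untested sketch, names made up):

  accum_steps = 8   # effective batch size = per-step batch size * accum_steps
  optimizer.zero_grad(set_to_none=True)
  for step, batch in enumerate(dataloader):
      loss = model(batch)
      (loss / accum_steps).backward()   # scale so the accumulated gradient is an average
      if (step + 1) % accum_steps == 0:
          optimizer.step()
          optimizer.zero_grad(set_to_none=True)   # reset only after actually stepping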

gpjt 2 days ago | parent [-]

Exactly! If I can get it down to an hour or two (seems very plausible on an 8x H200 with 160 GiB VRAM per GPU, though those are almost never available on Lambda Labs), I'll do the experiments with dropout and the other possible causes of issues, then see if I can bake that all into a new train on the RTX 3090 and confirm it repros there. Looks like I'll definitely need gradient accumulation there.

I assume the zero_grad would need to go in the same if block?

gpjt a day ago | parent [-]

Hmm, interesting. With a batch size of 512 (8x B200s with 160 GiB each) I get worse results! Maybe there's a sweet spot somewhere in between.

whimsicalism 3 days ago | parent | prev | next [-]

> Nowadays models are trained with effective batch size of millions of tokens in total. Of course, this won't fit into memory, one uses gradient accumulations to that purpose, again as mentioned by Gemini.

I would be surprised if there is much/any gradient acc in modern large-scale pretraining runs. You can always just recruit more GPUs with DP/PP/TP rather than training for longer.

alansaber 3 days ago | parent | prev [-]

As a caveat, smaller batch sizes are generally better for model stability, but we go bigger because it substantially speeds up training.

spi 2 days ago | parent [-]

Mmh, not really. As OP shows, speed increases with larger batch size, but only initially, until the GPU has high enough utilization; then the speed improvements flatten out (although you might hit OOM before that and never "really" see the flat part). Using a smaller batch size increases _noise_, so it quite literally decreases stability. That can be good sometimes: in the limit, if the batch is as large as your training set, you'll end up in a local minimum and not be able to get out of it. But that's mostly a concern for toy datasets like MNIST; here it's an entirely different beast.

With corpora as large as the ones used here, and very noisy ones at that, gradient updates are very noisy and that can harm quality. Or anyway, common lore is that one needs a pretty large batch size to have the language model improve steadily.

alansaber 2 days ago | parent [-]

Are you sure about the top-cap on batch size for speed? See https://arxiv.org/pdf/1904.00962