BarakWidawsky 6 days ago

I wonder how much of this is due to diffusion models having less capacity for memorization than autoregressive models.

The autoregressive models consistently show better loss for the same number of training tokens.

I find a lot of the conclusions compelling, but I would’ve loved to see more epochs of training on the 1B model with a 10B dataset, as that model was showing epoch-over-epoch improvements.

thesz 6 days ago

> I wonder how much of this is due to diffusion models having less capacity for memorization than autoregressive models

Diffusion requires more compute than autoregressive models, and the excess is proportional to sequence length. Time-dilated RNNs and adaptive computation in image recognition hint that we can compute more with the same weights and achieve better results.

Which, I believe, also points to at least one flaw of the TS study: I did not see them match the DLM (diffusion language model) and AR models by compute, only by weights.
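A back-of-the-envelope sketch of one reading of the "proportional to sequence length" claim, at generation time. Assumptions (mine, not from the study): the diffusion model unmasks one token per denoising step, so the step count equals the sequence length; the AR model uses a KV cache; and cost is approximated as ~2 FLOPs per parameter per position processed, ignoring attention. The function names and the 1B parameter count are hypothetical.

```python
def ar_positions(seq_len: int) -> int:
    # AR decoding with a KV cache: each position goes through the network once.
    return seq_len


def diffusion_positions(seq_len: int, num_steps: int) -> int:
    # Masked-diffusion decoding: every denoising step is a full pass over all positions.
    return seq_len * num_steps


def approx_forward_flops(n_params: int, positions: int) -> int:
    # Rule of thumb: ~2 FLOPs per parameter per position processed (attention cost ignored).
    return 2 * n_params * positions


seq_len = 1024
n_params = 1_000_000_000  # hypothetical: same 1B-parameter size for both models
ar = approx_forward_flops(n_params, ar_positions(seq_len))
dlm = approx_forward_flops(n_params, diffusion_positions(seq_len, num_steps=seq_len))
print(dlm // ar)  # 1024: with one token per step, the excess grows linearly with length
```

If the diffusion model uses fewer steps than tokens, the ratio drops to the step count, which is why matching by weights alone says little about matching by compute.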

heyitsguay 6 days ago

Do you have references on adaptive methods for image recognition?

godelski 6 days ago

I don't have an exact reference, but there are plenty of other hints that support the claim (compute more with the same weights). In fact, I wouldn't even call them hints since they aren't subtle at all. For one, animal brains are perfect examples of this. But in the ML space, we can think about this purely from a mathematical perspective.

I think it might be confusing because neurons are neurons, right? And they can only hold so much memory, so what's the difference? Well, that difference is architecture and training.

Let's think about signals for a moment, and to help understand this, let's move to small dimensions[0], like 2D or 3D. (I'll use 3D, but you'll see why even this can ruin visualization.) We're talking about universal approximators, so we can think of these as finite-length strings with fixed end points. Our goal is then to untangle these strings. Oh no, this bundle has a knot! We can't untangle this string just by stretching, and we also have a rule that we can't cut and glue things. We'd be stuck if we didn't have a trick up our sleeves: we can move into a higher dimension and untangle the strings there[1]. We'll need at least 2N-D. To the flatlander this will look like a cut, but it isn't.
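The knot picture doesn't translate directly into a few lines of code, but a classic and much simpler cousin of the same idea is XOR (a swapped-in example, not the commenter's): no line separates the four XOR points in 2D, yet adding one extra coordinate, x1 * x2, makes a plane sufficient. The helper names and the brute-force grid below are illustrative assumptions; the grid search only illustrates the 2D failure, since no linear separator exists at all.

```python
from itertools import product

# XOR truth table: inputs (x1, x2) -> label
XOR = {(0, 0): 0, (0, 1): 1, (1, 0): 1, (1, 1): 0}


def separates(weights, bias, featurize):
    # True if the sign of (w . phi(x) + b) matches every label.
    for x, label in XOR.items():
        score = sum(w * f for w, f in zip(weights, featurize(x))) + bias
        if (score > 0) != bool(label):
            return False
    return True


def identity_2d(x):
    return x


def lift_3d(x):
    # Move into a higher dimension by adding the coordinate x1 * x2.
    return (x[0], x[1], x[0] * x[1])


grid = [i / 4 for i in range(-8, 9)]  # -2.0 .. 2.0 in steps of 0.25
found_2d = any(separates((w1, w2), b, identity_2d) for w1, w2, b in product(grid, repeat=3))
found_3d = separates((1.0, 1.0, -2.0), -0.5, lift_3d)
print(found_2d, found_3d)  # False True
```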

This matters because we need to know where those dimensions come from: architecture and training. Let's just focus on architecture. While we're learning these relationships, we need the capacity to perform these higher-dimensional movements, but once we have uncovered the relationships, we don't necessarily need that capacity anymore. What's required depends on the dimensionality of the relationship itself, not of the data.

This is true for all models and is fundamentally why things like distillation work at all. It is also why the FFN layer after attention in a transformer needs to project into a higher dimension before projecting back (4x is typical, and I think you can reason about why that gives more flexibility than 2x). This is also related to the latent manifold hypothesis.
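For concreteness, a minimal pure-Python sketch of the position-wise FFN that follows attention: project d_model up to expansion * d_model, apply a nonlinearity (tanh-approximated GELU here, one common choice), then project back down. The names `d_model` and `expansion`, the random initialization, and the helper functions are illustrative assumptions, not any particular library's API.

```python
import math
import random


def gelu(x: float) -> float:
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def matvec(W, x, b):
    # W is a list of rows with shape [out][in]; returns W @ x + b.
    return [sum(w * xj for w, xj in zip(row, x)) + bi for row, bi in zip(W, b)]


def make_ffn(d_model: int, expansion: int = 4, seed: int = 0):
    rng = random.Random(seed)
    d_ff = expansion * d_model  # the wider hidden layer
    W1 = [[rng.gauss(0.0, 1.0 / math.sqrt(d_model)) for _ in range(d_model)] for _ in range(d_ff)]
    W2 = [[rng.gauss(0.0, 1.0 / math.sqrt(d_ff)) for _ in range(d_ff)] for _ in range(d_model)]
    return W1, [0.0] * d_ff, W2, [0.0] * d_model


def ffn_forward(params, x):
    W1, b1, W2, b2 = params
    hidden = [gelu(h) for h in matvec(W1, x, b1)]  # up-projection + nonlinearity
    return matvec(W2, hidden, b2), hidden           # down-projection back to d_model


def ffn_param_count(d_model: int, expansion: int = 4) -> int:
    d_ff = expansion * d_model
    return d_model * d_ff + d_ff + d_ff * d_model + d_model  # W1, b1, W2, b2


out, hidden = ffn_forward(make_ffn(d_model=8), [0.1] * 8)
print(len(hidden), len(out))                          # 32 8
print(ffn_param_count(8, 2), ffn_param_count(8, 4))   # 280 552
```

The trade-off is visible in the parameter count: doubling the expansion factor roughly doubles the FFN's parameters in exchange for a wider intermediate space.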

If you ever wondered if math is useful to machine learning, I hope this gives some motivation to learn more. You don't need math to build good models, but even a little math goes a long way to help make better models.

[0] Note, we're doing a significant amount of simplification here. There's a lot of depth and complexity to all of this but I think this will be sufficient to point anyone in (mostly) the right direction.

[1] Think about a Klein bottle. In 4D it is a single surface that never passes through itself, but its 3D projection makes it look like it is intersecting itself. Unfortunately we can't really visualize the 4D version :(

godelski 6 days ago

> as that model was showing epoch-over-epoch improvements

Both of them were showing improvements. I agree with you that I'd like to see more, but I'm not sure more would significantly change the argument (which is largely about how metrics aren't straightforward), especially given what the 96B token experiment shows.

IN FACT, those results are so similar that I had to open them up in GIMP to align them and spot the differences. Now I'm actually not convinced there wasn't a mistake. There are differences, just very minor ones. It's harder to tell with the AR model because of the scale, but in the diffusion model you can see a little bump in the second one right before the concavity change at the end. There are also some bumps earlier on in the AR model that help show differences, but the fact that the envelopes are nearly identical is... suspicious. I'm not claiming malice: even if it is a mistake, these are so easy to make that they are common. I'm not even convinced there is a mistake, but it warrants extra thought.

That said, money is finite and these experiments are quite computationally heavy. The author looks to be a research fellow, so I'm assuming they aren't backed by big tech.
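To put "computationally heavy" in rough numbers, here is a sketch using the common ~6 FLOPs per parameter per training token rule of thumb for dense transformers. Pairing the 1B model with the 96B token budget is a hypothetical assumption, as is the 100 TFLOP/s sustained per-GPU throughput; swap in real figures if you have them.

```python
def training_flops(n_params: float, n_tokens: float) -> float:
    # Rule of thumb for dense transformers: ~6 FLOPs per parameter per training token
    # (roughly 2 for the forward pass and 4 for the backward pass).
    return 6.0 * n_params * n_tokens


def gpu_hours(total_flops: float, sustained_flops_per_gpu: float) -> float:
    # Convert a FLOP budget into GPU-hours at a given sustained throughput.
    return total_flops / sustained_flops_per_gpu / 3600.0


flops = training_flops(1e9, 96e9)  # hypothetical pairing: 1B params, 96B tokens seen
hours = gpu_hours(flops, 1e14)     # hypothetical sustained 100 TFLOP/s per GPU
print(f"{flops:.2e} FLOPs, ~{hours:,.0f} GPU-hours")
```

That is for a single run; sweeps over model sizes, epochs, and both model families multiply it.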

cma 6 days ago

> The autoregressive models consistently show better loss for the same number of training tokens

I thought bidirectional transformers (non-autoregressive) show lower loss than autoregressive ones for the same number of training tokens.

pama 5 days ago

It is the other way around. If the data is causal and presented in causal order, it is impossible to beat the loss of a pure autoregressive model, because it has the correct probability distribution for the dataset. Language data is mostly causal (words follow in the context of previous words when they are spoken or written). Most of the remaining additional information gained by heavily oversampling the same data with diffusion models should also be obtainable with AR models via fill-in-the-middle or order-reversal strategies, with significant compute savings during training.
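For reference, the argument leans on two standard identities (written out here as a sketch, not taken from the thread): the chain rule, which lets a left-to-right model represent any joint distribution over a sequence, and the decomposition of cross-entropy loss against a model q when the data come from p.

```latex
\log p(x_1, \dots, x_T) = \sum_{t=1}^{T} \log p(x_t \mid x_{<t})

\mathbb{E}_{x \sim p}\left[-\log q(x)\right] = H(p) + D_{\mathrm{KL}}(p \,\|\, q) \;\ge\; H(p)
```

Since the KL term is zero only when q = p, a model whose conditionals exactly match the data reaches the floor H(p), and no model of the full sequence can go below it.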

cma 4 days ago

I mean models like BERT and not diffusion.

> Language data is mostly causal (as words follow in the context of previous words when they are spoken/written).

But where it isn't, the old KV entries are frozen in place, so any amendment has to come from what follows, whereas BERT-like models take the whole sequence into account at every position.

I have definitely heard that they have lower loss for the same number of training tokens, but they are less efficient to compute, and running next-token prediction with them would be much more expensive.
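A toy sketch of the difference being described, under simplifying assumptions: "attention" is replaced by a plain average over whatever each position is allowed to see (a hypothetical stand-in, not real attention), and generation cost is counted as token positions encoded. With a causal mask, earlier positions never change when a token is appended, so their cached keys/values can be reused; with a bidirectional mask, every earlier output changes, so a BERT-style model would re-encode the whole prefix at each generation step.

```python
def causal_mask(n):
    # Position i may attend only to positions 0..i (GPT-style decoder).
    return [[j <= i for j in range(n)] for i in range(n)]


def bidirectional_mask(n):
    # Every position attends to every position (BERT-style encoder).
    return [[True] * n for _ in range(n)]


def toy_mix(values, mask):
    # Stand-in for attention: each position averages the values it can see.
    out = []
    for row in mask:
        visible = [v for v, ok in zip(values, row) if ok]
        out.append(sum(visible) / len(visible))
    return out


def positions_encoded_for_generation(n_tokens, causal):
    # Causal + KV cache: each new token is encoded once.
    # Bidirectional: the prefix is re-encoded at every step, 1 + 2 + ... + n positions.
    return n_tokens if causal else n_tokens * (n_tokens + 1) // 2


seq = [1.0, 2.0, 3.0]
longer = seq + [10.0]
print(toy_mix(seq, causal_mask(3)), toy_mix(longer, causal_mask(4)))
# [1.0, 1.5, 2.0] [1.0, 1.5, 2.0, 4.0]
print(toy_mix(seq, bidirectional_mask(3)), toy_mix(longer, bidirectional_mask(4)))
# [2.0, 2.0, 2.0] [4.0, 4.0, 4.0, 4.0]
print(positions_encoded_for_generation(1024, True), positions_encoded_for_generation(1024, False))
# 1024 524800
```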