thesz 2 hours ago

The difference between training and inference is that 1) in training you have to keep intermediate results (activations) for the backward pass, and 2) training needs at least twice the computation because of the backward pass.
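A minimal sketch of both points for a single dense layer `y = x @ W`, written with plain Python lists so the arithmetic is visible. `Dense`, `matmul` and `transpose` are hypothetical helpers, not any framework's API, and FLOPs follow the usual convention of 2 per multiply-add. The layer has to keep its input `x` when training because the weight gradient `x^T @ dy` needs it, and the backward pass performs two matmuls the size of the forward one, so a training step costs roughly three forward passes of compute.

```python
def matmul(a, b):
    """Multiply an (r x k) matrix by a (k x c) matrix; return (result, FLOPs)."""
    r, k, c = len(a), len(b), len(b[0])
    out = [[sum(a[i][t] * b[t][j] for t in range(k)) for j in range(c)] for i in range(r)]
    return out, 2 * r * k * c  # convention: one multiply + one add per term


def transpose(m):
    return [list(col) for col in zip(*m)]


class Dense:
    """Toy dense layer y = x @ W (illustrative only)."""

    def __init__(self, w):
        self.w = w
        self.cached_x = None  # activation that must stay alive until backward

    def forward(self, x, training):
        if training:
            self.cached_x = x  # inference can drop x as soon as y is computed
        return matmul(x, self.w)

    def backward(self, grad_y):
        # Only valid after forward(..., training=True), which cached x.
        # dL/dx = dL/dy @ W^T  (same size as the forward matmul)
        grad_x, f1 = matmul(grad_y, transpose(self.w))
        # dL/dW = x^T @ dL/dy  (needs the cached input, again the same size)
        grad_w, f2 = matmul(transpose(self.cached_x), grad_y)
        return grad_x, grad_w, f1 + f2


layer = Dense([[0.5, -1.0], [2.0, 0.0], [1.0, 1.0]])  # W is 3 x 2
y, fwd_flops = layer.forward([[1.0, 2.0, 3.0]], training=True)  # batch of 1
_, _, bwd_flops = layer.backward([[1.0, 1.0]])
print(fwd_flops, bwd_flops)  # backward costs twice the forward
```

Calling `forward(x, training=False)` skips the cache, which is the memory inference saves.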

Training is also done over batches; activation memory grows with the batch size, which can raise memory requirements by several orders of magnitude. This is why training needs costly compute.
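A back-of-the-envelope sketch, assuming activation memory grows linearly with batch size while the weights are shared by every sample in the batch. The model size and per-sample activation count are made-up placeholders, not figures from any real model.

```python
# Hypothetical sizes, chosen only to illustrate the scaling.
N_PARAMS = 1_000_000_000       # weights, shared by the whole batch
ACTS_PER_SAMPLE = 50_000_000   # activation values kept per sample for backward
BYTES_PER_VALUE = 2            # e.g. 16-bit floats


def memory_gib(batch_size):
    """Return (weight GiB, activation GiB) for one training step."""
    gib = 1024 ** 3
    weights = N_PARAMS * BYTES_PER_VALUE / gib
    activations = ACTS_PER_SAMPLE * batch_size * BYTES_PER_VALUE / gib
    return weights, activations


for b in (1, 32, 1024):
    w, a = memory_gib(b)
    print(f"batch={b:5d}  weights={w:7.2f} GiB  activations={a:9.2f} GiB")
```

Going from a batch of 1 to a batch of 1024 is three orders of magnitude more activation memory, while the weight term stays fixed.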

One way out of this unfortunate situation is to use something like the Stochastic Average Gradient (SAG) method [1]. The examples there mostly concern regularized logistic regression, which is a convex problem. Neural networks are inherently non-convex. Still, some ideas from there might carry over to neural networks, such as using an estimated Lipschitz constant of the gradient to bound the curvature and choose an appropriate step size.

  [1] https://www.cs.ubc.ca/~schmidtm/Courses/540-W19/L12.pdf
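A minimal SAG sketch for L2-regularized logistic regression in plain Python (function names are illustrative). SAG stores the last gradient seen for every sample and steps along the average of the stored gradients. For this loss the per-sample Lipschitz constant can be bounded in closed form, `||a_i||^2 / 4 + lam`, since the logistic loss has second derivative at most 1/4; for a neural network it would have to be estimated instead, which this sketch does not do. The step `1/(16L)` is the conservative choice from the SAG convergence analysis; larger steps are often tried in practice.

```python
import math
import random


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def grad_i(w, a, b, lam):
    """Gradient of log(1 + exp(-b * a.w)) + lam/2 * ||w||^2 for one sample."""
    margin = b * sum(wj * aj for wj, aj in zip(w, a))
    coef = -b * sigmoid(-margin)
    return [coef * aj + lam * wj for wj, aj in zip(w, a)]


def sag(A, y, lam=0.1, epochs=50, step_scale=1.0 / 16.0, seed=0):
    """Stochastic Average Gradient on L2-regularized logistic regression."""
    rng = random.Random(seed)
    n, dim = len(A), len(A[0])
    # Lipschitz bound on each per-sample gradient: ||a_i||^2 / 4 + lam.
    L = max(0.25 * sum(v * v for v in a) + lam for a in A)
    step = step_scale / L
    w = [0.0] * dim
    memory = [[0.0] * dim for _ in range(n)]  # last gradient seen per sample
    total = [0.0] * dim                       # running sum of stored gradients
    for _ in range(epochs * n):
        i = rng.randrange(n)
        g = grad_i(w, A[i], y[i], lam)
        for j in range(dim):
            total[j] += g[j] - memory[i][j]   # swap old gradient for new one
        memory[i] = g
        w = [wj - step * tj / n for wj, tj in zip(w, total)]
    return w
```

Each iteration evaluates only one sample's gradient, yet the update uses information from all of them; the price is an `n x dim` table of stored gradients.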
janalsncm an hour ago

So one way to think about it, roughly:

Training is inference + backward pass (~2x inference cost) + activations (VRAM overhead) + optimizer state (VRAM overhead) + gradients (VRAM overhead).
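The breakdown above can be turned into a rough accounting helper. This is a sketch under stated assumptions: `training_footprint` is hypothetical, every tensor is stored at the same precision (real mixed-precision setups usually keep optimizer state in fp32), the optimizer keeps two model-sized states as Adam does, and forward compute uses the common rule of thumb of about 2 FLOPs per parameter per token, ignoring attention.

```python
def training_footprint(n_params, activation_values, bytes_per_value=2, optimizer_states=2):
    """Rough memory (bytes) and compute (FLOPs per token) for training vs inference."""
    weights = n_params * bytes_per_value
    gradients = n_params * bytes_per_value                     # one gradient per weight
    optimizer = optimizer_states * n_params * bytes_per_value  # e.g. Adam's two moments
    activations = activation_values * bytes_per_value          # saved for backward
    forward_flops = 2 * n_params         # rule of thumb: ~2 FLOPs/param/token
    backward_flops = 2 * forward_flops   # "~2x inference cost"
    return {
        "weights": weights,
        "gradients": gradients,
        "optimizer": optimizer,
        "activations": activations,
        "total_bytes": weights + gradients + optimizer + activations,
        "inference_flops_per_token": forward_flops,
        "training_flops_per_token": forward_flops + backward_flops,
    }


print(training_footprint(n_params=1000, activation_values=500))
```

Inference needs only the `weights` entry (plus transient activations); training pays for every line.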

thesz 32 minutes ago

Multiply "inference + backward pass (~2x inference cost) + activations (VRAM overhead)" by the batch size (thousands) to get the actual RAM and compute cost. An optimizer like Adam adds only two or three model-sized buffers of overhead.
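Where the "two or three model-sized" figure comes from: textbook Adam keeps a first-moment and a second-moment estimate per parameter, each the same size as the model; a third copy appears when, for example, mixed-precision training also keeps fp32 master weights. A minimal sketch over a flat list of parameters (the `Adam` class here is illustrative, not a library API):

```python
import math


class Adam:
    """Textbook Adam update over a flat list of parameters."""

    def __init__(self, n_params, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = [0.0] * n_params  # first-moment estimate: model-sized buffer #1
        self.v = [0.0] * n_params  # second-moment estimate: model-sized buffer #2
        self.t = 0

    def step(self, params, grads):
        self.t += 1
        b1, b2 = self.beta1, self.beta2
        for k, g in enumerate(grads):
            self.m[k] = b1 * self.m[k] + (1 - b1) * g
            self.v[k] = b2 * self.v[k] + (1 - b2) * g * g
            m_hat = self.m[k] / (1 - b1 ** self.t)  # bias correction
            v_hat = self.v[k] / (1 - b2 ** self.t)
            params[k] -= self.lr * m_hat / (math.sqrt(v_hat) + self.eps)


# Usage: minimize f(x) = (x - 3)^2 starting from x = 0.
x = [0.0]
opt = Adam(n_params=1, lr=0.1)
for _ in range(500):
    opt.step(x, [2 * (x[0] - 3.0)])
print(x[0])
```

The state grows with the number of parameters, not with the batch size, which is why it is small next to batch-scaled activations.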

And last, but not least, inference needs only one hidden layer's activations kept in RAM at a time, but computing the gradient for even one sample needs all of them (61 layers for DeepSeek models) kept in RAM.
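A toy illustration, assuming a plain feed-forward stack (autoregressive transformers additionally keep a per-layer KV cache at inference, which this ignores). `make_layer` and `run` are hypothetical helpers; the point is only how many hidden states must stay alive: one during inference, every layer's input during training.

```python
def make_layer(scale):
    # Hypothetical toy layer: scaled ReLU on a list of floats.
    return lambda h: [max(0.0, scale * v) for v in h]


def run(x, layers, training):
    """Forward pass that counts how many hidden states are held at once."""
    saved = []      # layer inputs kept for the backward pass
    h = x
    peak_held = 0
    for layer in layers:
        if training:
            saved.append(h)  # backward will need this layer's input
        h = layer(h)         # in inference the previous h can be freed here
        peak_held = max(peak_held, len(saved) + 1)  # saved states + current one
    return h, saved, peak_held


layers = [make_layer(1.0) for _ in range(61)]  # 61 layers, as in the comment
_, _, peak_inference = run([1.0, -2.0], layers, training=False)
_, _, peak_training = run([1.0, -2.0], layers, training=True)
print(peak_inference, peak_training)
```

Techniques such as activation checkpointing trade extra recomputation for keeping fewer of these saved states.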