cobolexpert 4 hours ago

Dumb question: is the quadratic time complexity for training, inference, or both?

dave_universetf 2 hours ago | parent | next [-]

Both, with caveats. The attention computation is fundamentally quadratic: for every token in the sequence, you do a computation that involves every other token in the sequence. So it's O(N) per token, O(N^2) for the whole sequence.
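
Roughly what the quadratic part looks like, as a toy numpy sketch (illustrative only, not any particular library's implementation; shapes and names are assumptions):

    import numpy as np

    def attention(Q, K, V):
        # Q, K, V: (N, d) arrays for a sequence of N tokens
        d = Q.shape[-1]
        scores = Q @ K.T / np.sqrt(d)   # (N, N) matrix: the O(N^2) part
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)
        return weights @ V              # (N, d) outputs, one per token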

The big mitigation for this is that in causal transformers (i.e. all the chatbot-type applications, where each token is only allowed to see the tokens before it), you're running inference repeatedly on the same prefix, growing it one token at a time. So if you cache the computations for tokens 0..N-1, on each inference pass you only have to do O(N) work for the newly added token at the end of the sequence.
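
As a sketch of that (again toy numpy, hypothetical function and shapes): with a warm cache, each new token only touches one row's worth of work rather than the whole N x N matrix.

    import numpy as np

    def decode_step(q_new, k_new, v_new, k_cache, v_cache):
        # q_new, k_new, v_new: (d,) vectors for the newly generated token
        # k_cache, v_cache: (N-1, d) arrays saved from previous steps
        k_cache = np.vstack([k_cache, k_new])   # now (N, d)
        v_cache = np.vstack([v_cache, v_new])
        d = q_new.shape[-1]
        scores = k_cache @ q_new / np.sqrt(d)   # (N,): linear in N, not quadratic
        w = np.exp(scores - scores.max())
        w /= w.sum()
        return w @ v_cache, k_cache, v_cache    # output plus the grown cache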

That's why caching (and caching charges) appear so prominently in the pricing of inference.

In practice, caching is most beneficial at inference time, because you typically have relatively long conversations that start with the same cacheable prefix (the system prompt). At training time the same optimization can apply, but you're typically not pushing the same prefixes through the model repeatedly, so you end up paying the quadratic cost more often.

The quadratic cost of attention is the fundamental compute bottleneck for transformer architectures, which is why there's research like this trying to find shortcuts in computing attention, as well as research into completely new primitives to replace attention (e.g. SSM, which is O(N) on a cold cache and O(1) on a warm cache).
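
To make the SSM comparison concrete, here's a toy state-space recurrence (illustrative only, not any specific SSM design): the entire history is folded into a fixed-size state, so a cold pass over the sequence is O(N) and each new token after that is O(1).

    import numpy as np

    def ssm_step(h, x, A, B, C):
        # h: (d_state,) running state; x: scalar input for this token
        # A, B, C: (d_state,) parameters of the toy recurrence
        h = A * h + B * x        # constant work per token, regardless of N
        y = C @ h                # output for this token
        return h, y

    def ssm_scan(xs, A, B, C, d_state):
        # cold start: process the whole sequence once, O(N) total
        h = np.zeros(d_state)
        ys = []
        for x in xs:
            h, y = ssm_step(h, x, A, B, C)
            ys.append(y)
        return ys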

omneity 4 hours ago | parent | prev [-]

Attention is calculated during the forward pass of the model, which happens in both inference (forward only) and training (forward & backward).

SubiculumCode 3 hours ago | parent [-]

Dumb question: Can inference be done in a reverse pass? Outputs predicting inputs?

dave_universetf an hour ago | parent | next [-]

Strictly speaking: no. The "forward pass" terminology does not imply that there exists a "reverse pass" doing the same kind of computation in the other direction. Rather, the two terms describe two different kinds of computation, and the direction in which each one propagates.

The forward pass is propagating from inputs to outputs, computing the thing the model was trained for. The reverse/backward pass is propagating from outputs back to inputs, but it's calculating the gradients of the parameters for training (roughly: how much changing each parameter in isolation affects the output, and whether that change moves the output closer to the desired training output). The result of the "reverse pass" isn't a set of inputs, but a set of annotations on the model's parameters that guide their adjustment.
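
A stripped-down illustration of that split, using a hypothetical one-parameter "model" (plain Python, just to show what each pass produces):

    def forward(w, x):
        # forward pass: input -> output
        return w * x

    def backward(w, x, target):
        # backward pass: loss -> gradient on the parameter w, not a recovered input
        y = forward(w, x)
        dloss_dy = 2 * (y - target)   # derivative of the squared-error loss (y - target)^2
        dloss_dw = dloss_dy * x       # chain rule: how the loss changes as w changes
        return dloss_dw               # used to nudge w; says nothing about x itself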

The computations of the forward pass are not trivially reversible (e.g. they include additions, which destroy information about the operand values). As a sibling thread points out, you can still probabilistically explore what inputs _could_ produce a given output, and get some information back that way, but it's a lossy process.

And of course, you could train a "reverse" model, one that predicts the prefix of a sequence given a suffix (trivially: it's the same next-token prediction problem, just trained on reversed sequences). But that would be a separate model trained from scratch on that task, and in that model the prefix prediction would be its forward pass.
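
Sketched out (hypothetical helper, not a real training API), the data transformation is the whole trick:

    def reverse_dataset(sequences):
        # train an ordinary next-token predictor on these, and its "forward pass"
        # learns to predict what came before a given suffix
        return [list(reversed(seq)) for seq in sequences]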

gpm 3 hours ago | parent | prev | next [-]

Not as trivially as the forward direction: unsurprisingly, some information is lost, but it works better than you might expect. See for example https://arxiv.org/pdf/2405.15012

root_axis 3 hours ago | parent | prev [-]

Sounds like a great premise for a sci-fi short story.

anu7df 2 hours ago | parent [-]

Sci-fi ? You mean historical fiction!