amluto 3 hours ago

I skimmed the paper, and I think I completely lost the plot.

Sections 2.1 through 2.4 talk about decomposing the per-token-pair attention (key vector from the ith token with query vector from the jth token, where, in inference, the jth token is the one being sampled) into an approximation that is only mildly outrageously exponential in size compared to the original exponential-of-a-dot product. And they get something that's a polynomial (in the mathematical sense -- you're literally evaluating a polynomial) and has a size that's manageable at 4th order.

Okay, great, they took something simple and made it bigger and nastier but less transcendental without losing too much precision. (As far as I know, there is really nothing special about the exp in attention in the first place, so trying to approximate it well seems mostly useful insofar as it will keep existing models working.)
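To make the approximation concrete, here's a quick sketch in NumPy (my own illustration, not the paper's code) of how well a low-order Taylor polynomial tracks the exp of a scaled dot product when the scores stay in a modest range; the paper works with tensor products of q and k rather than a precomputed scalar score, but the accuracy question is the same:

    import numpy as np
    from math import factorial

    rng = np.random.default_rng(0)
    d = 64
    q = rng.standard_normal(d)
    k = rng.standard_normal(d)
    x = q @ k / np.sqrt(d)      # the scaled dot product that attention exponentiates

    exact = np.exp(x)
    for order in (2, 4, 8):
        approx = sum(x**m / factorial(m) for m in range(order + 1))
        print(order, exact, approx, abs(exact - approx))
The truncation error grows quickly once |x| gets large, which I take to be the convergence worry raised downthread.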

But the reason that attention is quadratic is that each token gets evaluated with respect to each other token. They haven't changed this at all. Section 2.5 seems like it's deferring this to an appendix. Section 2.6 gives the hidden state size per token, which, on first read, is strictly larger than the hidden state in normal attention (in normal attention it's d_v * d_k -- I'm not sure where their +1 comes from).

So what did the paper gain? Is there some detail that I missed or that the paper completely glossed over that explains why there is any gain of efficiency at all?

For what it's worth, the paper's overall claim is, in some sense, impossible. You can think of attention as being a sort of vector database, and this gets more accurate the sharper you make the exponential. If you replace softmax with actual max, a query locates the key that is the closest match to the query and returns the associated value. This operation is a plain linear search; it's possible (in principle, anyway) to do lots of queries and recover the entire contents of the database, and I think that any paper claiming to do it faster than linear time should explain how it's compressing the data and where the loss is.

In language model terms, imagine a prompt like so:

    1: [string 1]
    2: [string 2]
    3: [string 3]
    ...
    n: [string n]
    
    Tell me the string associated with the number k.
As long as there's enough precision and enough query/key space to fit some embedding of the number k that will match the right thing (and there is a lot of room in high-dimensional spaces), one might expect a transformer to be able to answer this question. But this obviously requires memory with size linear in the prompt length. If you try to get rid of that, you necessarily lose something. (This is not to say that nice attention scaling is impossible -- one could imagine schemes where it takes the model multiple tokens to answer the question, and the number of tokens needed could scale, say, logarithmically with prompt size. But you still need that linear memory.)
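A minimal sketch of that soft-lookup picture (my own toy example, not from the paper): with random keys and a sharp enough softmax, a query equal to one of the stored keys reads out the associated value almost exactly, and repeating that for every slot is what would let you recover the whole table.

    import numpy as np

    rng = np.random.default_rng(1)
    n, d_k, d_v = 32, 64, 8
    K = rng.standard_normal((n, d_k))   # one key per stored item
    V = rng.standard_normal((n, d_v))   # the value associated with each key

    def lookup(q, temperature):
        scores = (K @ q) / (np.sqrt(d_k) * temperature)
        w = np.exp(scores - scores.max())
        w /= w.sum()                    # softmax over all n keys: the O(n) linear search
        return w @ V

    i = 17
    q = K[i]                            # a query that matches key i
    for t in (1.0, 0.1, 0.01):
        print(t, np.linalg.norm(lookup(q, t) - V[i]))   # error shrinks as softmax -> max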
csense 2 hours ago | parent | next [-]

This paper combines two different insights; the second one is buried in the appendix.

Let's say you consider the 3 most recent tokens. The first insight is that you can use a Taylor approximation: at token position 3 you compute A_3 = ((q1, q2, q3) . (k1, k2, k3))^1, B_3 = ((q1, q2, q3) . (k1, k2, k3))^2, C_3 = ((q1, q2, q3) . (k1, k2, k3))^3, etc. [1] [2]

The second insight is that you can compute e.g. B_{i+1} incrementally from B_i, with far fewer FLOPs than computing B_{i+1} from scratch. [3]

[1] I'd buy that it's empirically "good enough" that you don't need to go beyond D_3 (fourth degree polynomial).

[2] I'd also buy that it's empirically "good enough" to assume the inputs aren't extreme enough for E_3, F_3 etc. to matter. I agree with other posters that radius of convergence worries aren't addressed. I find it plausible that these issues don't sink the paper. I'd not be surprised to learn that either it doesn't matter in practice, or workarounds can be implemented without much performance impact.

[3] The author's choice to bury this insight in an appendix rather than putting it front and center is baffling pedagogically, but it's a small issue in the grand scheme of things. Perhaps that second insight is prior work (possibly by others) that experts in the latest LLM linear algebra could reasonably be expected to be familiar with, and is included as an appendix because it's not universally known in, e.g., HN comment sections.
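A rough sketch of what I understand [3] to mean, using monomial features up to degree 2 as a stand-in for the paper's Taylor terms (the names and dimensions here are mine, not the paper's): the state is a pair of running sums that get an O(1)-per-token update, so position i+1 never has to revisit earlier keys and values.

    import numpy as np

    rng = np.random.default_rng(2)
    d_k, d_v, n = 8, 4, 100

    def phi(x):
        # Degree-2 Taylor features: phi(q) . phi(k) = 1 + q.k + (q.k)^2 / 2
        return np.concatenate([[1.0], x, np.outer(x, x).ravel() / np.sqrt(2.0)])

    D = 1 + d_k + d_k * d_k             # feature size: constant, independent of n
    S = np.zeros((D, d_v))              # running sum of phi(k_i) v_i^T
    z = np.zeros(D)                     # running sum of phi(k_i), for normalization

    for i in range(n):
        k_i, v_i, q_i = (rng.standard_normal(m) for m in (d_k, d_v, d_k))

        f = phi(k_i)                    # incremental update: O(D * d_v) work per token,
        S += np.outer(f, v_i)           # no loop over the i-1 previous tokens
        z += f

        g = phi(q_i)
        out = (g @ S) / (g @ z)         # causal attention-like output at position i
Adding higher-degree terms makes the approximation better but makes D grow quickly, which is the "(large) constant" caveat raised below.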

fheinsen 2 hours ago | parent [-]

[3] is linear attention, https://arxiv.org/abs/2006.16236, a well-known result with ~3K citations: https://scholar.google.com/scholar_lookup?arxiv_id=2006.1623...

yorwba an hour ago | parent | prev | next [-]

> But the reason that attention is quadratic is that each token gets evaluated with respect to each other token. They haven't changed this at all. Section 2.5 seems like it's deferring this to an appendix.

They defer it to the appendix because it's a standard construction: (QKᵀ)V = Q(KᵀV), where QKᵀ is an n×n matrix and takes O(n²) time to compute, but KᵀV has a constant size and can be computed in O(n) time, and the multiplication with Q can also be done in O(n) time.
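A quick check of that regrouping in NumPy (row convention, one token per row; this assumes the exp has already been replaced by whatever feature map makes the product reassociable):

    import numpy as np

    rng = np.random.default_rng(3)
    n, d = 1000, 64
    Q = rng.standard_normal((n, d))     # feature-mapped queries, one row per token
    K = rng.standard_normal((n, d))     # feature-mapped keys
    V = rng.standard_normal((n, d))

    quadratic = (Q @ K.T) @ V           # materializes an n x n matrix: O(n^2 d)
    linear = Q @ (K.T @ V)              # materializes a d x d matrix: O(n d^2)
    print(np.allclose(quadratic, linear))   # True: same result, different cost
For causal decoding you keep KᵀV as a running sum rather than forming it once, but the per-token cost is the same.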

> Section 2.6 gives the hidden state size per token, which, on first read, is strictly larger than the hidden state in normal attention (in normal attention it's d_v * d_k -- I'm not sure where their +1 comes from).

Actually, their hidden state has a (large) constant size, so strike the words "per token" from section 2.6. In normal attention, the total state is n(d_v + d_k), but their state is basically (d_v + 1)D_k, where D_k is much larger than d_k, but independent of n. The +1 is because they also need to compute the normalization factor for the softmax.

It's true that a constant state size implies that you cannot use it to losslessly store arbitrarily large databases, but LLMs in practice cannot do this either, so there's no loss of capability in that sense. (In fact, if you use enough terms in the Taylor expansion to get the same result as standard attention to within machine precision, the resulting constant state size should give you an upper bound for the amount of data the LLM can effectively retrieve from its context.)
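For a feel for the sizes, here's a back-of-envelope comparison under my own assumption (not necessarily the paper's exact bookkeeping) that the degree-m term needs the symmetric m-th tensor power of the key, of dimension C(d_k + m - 1, m):

    from math import comb

    d_k, d_v, degree = 64, 64, 4
    D_k = sum(comb(d_k + m - 1, m) for m in range(degree + 1))
    state = (d_v + 1) * D_k                     # constant, independent of context length n
    print(D_k, state)

    for n in (1_000, 100_000, 1_000_000):
        print(n, n * (d_k + d_v), "vs", state)  # standard KV cache grows with n
With these made-up per-head sizes the fixed state is huge and only undercuts a plain KV cache at context lengths in the hundreds of thousands of tokens, which is in line with the "(large) constant" above.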

fheinsen 2 hours ago | parent | prev | next [-]

This is a form of linear attention (https://arxiv.org/abs/2006.16236) that approximates standard scaled dot-product attention to arbitrary precision by adding Taylor terms in an efficient manner. Each additional Taylor term improves the approximation. Efficiency is achieved by exploiting certain mathematical symmetries that become evident only after decomposing the standard formulation of attention into an expression over chains of tensor products. The GitHub repository's README walks through examples; the first example uses 8 Taylor terms.

jsenn 2 hours ago | parent | prev | next [-]

> Section 2.6 gives the hidden state size per token, which, on first read, is strictly larger than the hidden state in normal attention

This is where you’ve gone off track. The “hidden state” for their model is a fixed size thing, like in an RNN, not per token. For a transformer, the “hidden state” is called the KV cache, and it grows with sequence length. This is why their method is linear not quadratic.

The Taylor series they derive isn't just for softmax (after all, real implementations of softmax will likely already use a Taylor series!); it's for the entire tensor-level softmax(QKᵀ) computation.
