furyofantares 4 hours ago

Is it really no quality degradation?

I'm curious where my understanding is wrong, but as I understand how speculative decoding is used, you don't necessarily get the exact same output. I thought tokens from the small model are accepted if they're "good enough", meaning within the top few tokens the larger model would produce.

I thought the draft doesn't have to produce the exact token the larger model would have produced in order to be accepted (and that requiring this would reduce the hit rate a lot), just one the larger model could have produced under whatever top-k and temperature settings are in use.

Klaus23 3 hours ago | parent | next [-]

It really is. This is because LLM inference for a single output/user is strongly memory-bandwidth limited. The hardware can compute multiple tokens simultaneously, but it cannot exploit that when each token depends on the one before it, as is the case with regular text generation.

The draft model essentially predicts the next token quickly, enabling you to start generating the subsequent token in parallel. If the guess is right, the second generated token is correct. If it is wrong, the second generated token is also potentially wrong, so it must be generated again using the correct prior token obtained through the big model.

A poor draft model will simply slow down the process without affecting the output.

furyofantares 3 hours ago | parent [-]

> If the guess is right

This is the crux. What makes the guess "right"?

I think the acceptance criterion is not that the token is exactly the token the big model would have produced. It's accepted if the big model verifies that the probability of that token was high enough.

How close it is to the same output (or same distribution of outputs) you'd get from running the big model would be dependent on temperature, top-k, top-p settings, or other inference parameters.

Klaus23 2 hours ago | parent | next [-]

The token is correct if it matches the one generated by the main model. It works like this:

The draft model quickly generates draft-token 1.

The main model then starts working on two tokens in parallel. It calculates token 1 based on the context, and token 2 based on the context + draft-token 1.

Once the two tokens have been generated, you can check whether the draft-token 1 from the draft model matches token 1 from the main model.

If they match, you have just calculated two tokens in the time it takes to generate one, because the calculation was done in parallel. If they do not match, delete token 2 and generate it again. Since you have already generated the correct token 1 with the big model, you can use the context + token 1 (from the main model). This takes more time, but the result is always the same.
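The steps above can be sketched in a few lines. This is only a toy illustration, assuming greedy decoding (the main model always picks one deterministic next token); `main_model`, `good_draft`, and `bad_draft` are made-up stand-ins, not any real library's API. The two main-model calls per round are written sequentially here, but they stand for the single parallel pass described above.

```python
from typing import Callable, List, Tuple

Model = Callable[[List[int]], int]  # context -> next token (greedy pick)

def main_model(ctx: List[int]) -> int:
    # Hypothetical stand-in for the big model: any deterministic function works.
    return (31 * sum(ctx) + len(ctx)) % 50

def good_draft(ctx: List[int]) -> int:
    # Hypothetical draft that agrees with the main model most of the time.
    t = main_model(ctx)
    return t if len(ctx) % 3 else (t + 1) % 50

def bad_draft(ctx: List[int]) -> int:
    # Hypothetical draft that is always wrong.
    return (main_model(ctx) + 1) % 50

def plain_decode(main: Model, prompt: List[int], n: int) -> List[int]:
    ctx = list(prompt)
    for _ in range(n):
        ctx.append(main(ctx))
    return ctx[len(prompt):]

def speculative_decode(main: Model, draft: Model,
                       prompt: List[int], n: int) -> Tuple[List[int], int]:
    ctx = list(prompt)
    rounds = 0  # number of (notionally parallel) main-model passes
    while len(ctx) - len(prompt) < n:
        d1 = draft(ctx)          # draft-token 1
        t1 = main(ctx)           # token 1 from context
        t2 = main(ctx + [d1])    # token 2 from context + draft-token 1
        rounds += 1
        ctx.append(t1)           # token 1 always comes from the main model
        if d1 == t1:
            ctx.append(t2)       # guess was right, so token 2 is valid too
        # otherwise token 2 is dropped and recomputed in the next round
    return ctx[len(prompt):][:n], rounds

if __name__ == "__main__":
    reference = plain_decode(main_model, [1, 2, 3], 12)
    for draft in (good_draft, bad_draft):
        out, rounds = speculative_decode(main_model, draft, [1, 2, 3], 12)
        print(draft.__name__, out == reference, rounds)
```

In this sketch both drafts reproduce the plain decode exactly; the bad draft just needs one main-model pass per token, so it gains nothing.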

petu 2 hours ago | parent | prev | next [-]

> What makes the guess "right"?

Matching the token that would've been picked without speculative decoding. That seems to be more or less agreed upon.

e.g. vLLM docs list tests they run to ensure that output doesn't change if spec. decoding is used: https://github.com/vllm-project/vllm/blob/main/docs/features...

But introducing some threshold to accept other high-probability tokens is an interesting idea.

dist-epoch 2 hours ago | parent | prev [-]

When running LLMs, there is more compute available than memory bandwidth.

It's like branch prediction - the CPU predicts what branch you'll take and starts executing it. Later you find out exactly what branch you took. If the prediction was correct, the speculatively executed code is kept. If the prediction was wrong, it's thrown away, the pipeline is flushed, and execution resumes from the branch point.

The same with this thing: 3 tokens, A-B-C, were "predicted", and you start computing all 3 of them at the same time, hoping that the prediction checks out. And because of the mathematical structure of the transformer, it costs you almost the same to compute 3 tokens at a time as just one - you are limited by bandwidth, not compute. But CRITICALLY, each token depends on all the previous ones, so if you mispredicted one of the tokens, you need to discard all tokens predicted after it (flush the pipeline).

This is why a prediction is required and why you can't always compute 3 tokens simultaneously - the serial dependency between consecutive tokens. If you were to start computing 3 tokens simultaneously without a prediction, for token C you would need to assume some exact values for tokens A and B, but those have not been computed yet! But if they were speculatively predicted, you can start and hope the prediction was correct.
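To make the "flush" concrete, here is a small sketch of checking a multi-token draft, again assuming greedy decoding. `main_model` is a made-up deterministic stand-in and `verify_draft` is a hypothetical helper name; the list comprehension stands for the one pass in which the big model scores every drafted position at once.

```python
from typing import Callable, List

def main_model(ctx: List[int]) -> int:
    # Hypothetical stand-in for the big model; depends only on the last token.
    return (7 * ctx[-1] + 3) % 11

def verify_draft(main: Callable[[List[int]], int],
                 ctx: List[int], draft: List[int]) -> List[int]:
    # Main-model prediction at every drafted position. Position i assumes
    # draft tokens 0..i-1 were right, which is why a draft is needed at all.
    preds = [main(ctx + draft[:i]) for i in range(len(draft))]
    kept: List[int] = []
    for guess, truth in zip(draft, preds):
        kept.append(truth)      # the main model's token is always what is kept
        if guess != truth:
            break               # everything after the first miss is discarded
    return kept

if __name__ == "__main__":
    ctx = [1]
    print(verify_draft(main_model, ctx, [10, 7, 8]))  # A, B, C all correct
    print(verify_draft(main_model, ctx, [10, 5, 8]))  # B wrong, C discarded
    print(verify_draft(main_model, ctx, [4, 7, 8]))   # A wrong, B and C discarded
```

Note that the third case still yields one token: the position computed from the real context never depends on the draft, so a round is never worse than one ordinary decoding step in terms of output.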

petu 3 hours ago | parent | prev [-]

Speculative decoding batches the completions for all possible outcomes (0/1/2 draft tokens accepted) and checks whether the big model deviates at any point, thus verifying each token. So there's no difference in output.