| ▲ | magicalhippo 9 days ago |
| Maybe I'm especially daft this morning, but I don't get the point of speculative decoding. How does the target model validate the draft tokens without running the inference as normal? Because if it is doing just that, I don't see the point: you can't trust the draft tokens before they are validated, so you're still stuck waiting for the target model. |
|
| ▲ | porridgeraisin 9 days ago | parent | next [-] |
| Let's say I want to run f2(f1(x)), where f1 and f2 are each a single pass through GPT-4. This takes 2 seconds, assuming 1 second per pass. What I do instead is kick off f1(x) in another thread, and then run f2(g1(x)), where g1 is one pass through GPT-nano. This takes 1 + 0.1 seconds, assuming GPT-nano takes 0.1s per pass. Within that 1.1 seconds, the f1(x) we kicked off in the second thread has finished (it takes 1 second). So after 1.1 seconds we have f1(x), f2(g1(x)), and the intermediate g1(x) available. We compare g1(x) and f1(x). If they are equal, i.e. g1(x) = f1(x), then we already have our answer, f2(g1(x)), in just 1.1s. If not, we compute f2(output of f1(x) from the second thread), which takes 1 further second, bringing the total to 2.1s. If the small model matches the big model in, say, 2/3 of cases, you spend 2/3 * 1.1 + 1/3 * 2.1 = 1.433s on average for this computation. Without speculative decoding, it is always 2s. |
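A minimal sketch of that control flow, with big_pass and small_pass as hypothetical stand-ins for one forward pass of the large and small models (the 1s / 0.1s timings from the example are assumptions):

    from concurrent.futures import ThreadPoolExecutor

    def speculative_two_steps(x, big_pass, small_pass):
        with ThreadPoolExecutor(max_workers=1) as pool:
            future_f1 = pool.submit(big_pass, x)  # f1(x) runs in the background (~1s)
            guess = small_pass(x)                 # g1(x), the cheap draft (~0.1s)
            candidate = big_pass(guess)           # f2(g1(x)), overlaps with f1(x)
            f1 = future_f1.result()               # already done by now
        if f1 == guess:
            return candidate                      # draft accepted: ~1.1s total
        return big_pass(f1)                       # draft rejected: ~2.1s total

With a 2/3 acceptance rate this averages out to the ~1.4s figure above.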
| |
| ▲ | magicalhippo 9 days ago | parent | next [-] | | Thanks, very nice explanation, that makes perfect sense. I guess their graphics confused me for some reason and had me thinking all wrong. Now I see they tried to point out the obvious thing which is to predict multiple tokens ahead, not just two as in your example. | |
| ▲ | arkmm 9 days ago | parent | prev [-] | | This is a really great explanation. |
|
|
| ▲ | cristoperb 9 days ago | parent | prev | next [-] |
| My simplified understanding: The target model can validate the draft tokens all at once, in a single forward pass. The output of that forward pass is a probability for each draft token, which is compared to the probability the draft model assigned to it. If the target model's probability is the same or greater than the draft model's, the token is accepted. Worst case, none of the draft tokens are accepted and the target model instead selects the single next token as usual. |
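For concreteness, a minimal sketch of that acceptance check (hypothetical helper; the standard scheme also accepts probabilistically when the target's probability is lower, which is what keeps the output distribution identical to the target model's):

    import random

    def accept_draft_tokens(draft_tokens, p_draft, p_target):
        # p_draft[i] / p_target[i]: probability each model assigned to draft_tokens[i]
        accepted = []
        for tok, q, p in zip(draft_tokens, p_draft, p_target):
            if p >= q or random.random() < p / q:
                accepted.append(tok)   # target agrees strongly enough
            else:
                break                  # first rejection: target resamples from here
        return accepted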
|
| ▲ | furyofantares 9 days ago | parent | prev | next [-] |
| Not an expert, but here's how I understand it. You know how input tokens are cheaper than output tokens? It's related to that. Say the text so far is "The capital of France". The small model generates "is Paris.", which let's say is 5 tokens. You feed the large model "The capital of France is Paris." to validate all 5 of those tokens in a single forward pass. |
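A sketch of that verification step, using greedy decoding for simplicity; predict_next is a hypothetical stand-in for one forward pass of the large model that returns its chosen next token at every position:

    def verify_draft(predict_next, prefix_tokens, draft_tokens):
        full = prefix_tokens + draft_tokens
        predicted = predict_next(full)          # one pass scores every position
        n = len(prefix_tokens)
        for i, tok in enumerate(draft_tokens):
            if predicted[n + i - 1] != tok:     # target disagrees at draft position i
                # The target's own prediction here is a valid next token,
                # so we still make progress even when the draft is wrong.
                return draft_tokens[:i] + [predicted[n + i - 1]]
        return draft_tokens + [predicted[-1]]   # all accepted, plus one bonus token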
| |
| ▲ | isoprophlex 9 days ago | parent | next [-] | | but... do you get any validation during the forward pass? the small model could just as well have generated "is Berlin." or whatever. do these models somehow give you a likelihood for the next token when you're prefilling, that you can compare against? if so why not just... use that always? or is this a scenario where computation is expensive but validation is cheap? EDIT: thanks, people, for educating me! very insightful :) | | |
▲ | sanxiyn 9 days ago | parent | next [-] | | Yes, models give likelihoods you can compare against. No, you can't do that without drafting, because the likelihood of token N+2 depends on token N+1. That is, you get P(is | The capital of France) and P(Berlin | The capital of France is), but for the latter you need to give "is" as input; you can't do P(Berlin | The capital of France _). | |
| ▲ | pama 9 days ago | parent | prev | next [-] | | If you want to go down the rabbit hole of the state of the art, I recommend the EAGLE3 paper: https://arxiv.org/abs/2503.01840 | |
| ▲ | shikon7 9 days ago | parent | prev [-] | | Yes, the forward pass does a next token prediction on all input tokens (so we know exactly how many tokens from the small model matched). The expensive thing is not the computation, but the memory bandwidth, as each pass needs to load the model from memory. If the small model predicts some tokens correctly, you save some passes, at the expense of doing some extra computations when the tokens were not correct. In any case, each forward pass will give at least one new token. |
| |
▲ | ahmedfromtunis 9 days ago | parent | prev [-] | | But what would happen if the small model's prediction was "is Rome."? Wouldn't that result in costlier inference if the small model is "wrong" more often than it is correct? Also, if the small model were sufficiently more "correct" than "wrong", wouldn't it be more efficient to get rid of the large model at that point? | | |
▲ | imtringued 9 days ago | parent | next [-] | | You're forgetting that some sequences are more predictable than others, hence the name "speculative" decoding. Let's say your token vocabulary has 128k tokens; the model has to pick the right token out of those 128k. Some of those tokens are incredibly rare, while others are extremely common. The big model has seen the rare tokens many more times than the small model has. This means the small model will be able to produce grammatically correct English, but may not know anything about a specific JS framework. The post-training/fine-tuning costs (low thousands of dollars) are the main reason why speculative decoding is relatively unpopular. The most effective speculative decoding strategies require you to train multiple prediction heads à la Medusa (or whatever succeeded it). If you don't do any fine-tuning, the probability of the small model being useful is slim. Using an arbitrary model as your draft model will probably give you very disappointing results. | |
▲ | acters 9 days ago | parent | prev | next [-] | | I believe that is exactly the downside of using speculative decoding, which is why it is very important to size the two models properly relative to each other: the small one has to be big enough to be mostly correct while still being substantially faster than the larger one. And the larger one has to be fast enough that catching mistakes doesn't introduce too many random delays. Also, if the small one is wrong, having the larger one correct the mistake is miles better than leaving incorrect output in. The point is keeping the large model's quality while getting the small model's speed most of the time. The tradeoff is that you consume more memory by having two models loaded instead of just one. If memory were the only concern, it would make sense to just run the smaller model on its own. | | |
▲ | acters 9 days ago | parent [-] | | Another caveat with this method is that the larger and smaller models need to behave very similarly, because a lot of the savings come from the draft model filling in the predictable fluff around each detail: grammar, formatting, and the connective words in between. Unsurprisingly, gpt-oss comes in larger and smaller sizes that behave very similarly! Because the two are so close, getting a few draft tokens wrong doesn't slow things down to the speed of the larger model alone (which is the worst case with this setup); most of the time you keep the speed of the smaller model. That is all. |
| |
▲ | cwyers 9 days ago | parent | prev | next [-] | | So, the way speculative decoding works, the target model resumes generating from the first wrong token, so you still get 'is' for free. | |
| ▲ | 9 days ago | parent | prev [-] | | [deleted] |
|
|
|
| ▲ | bhaney 9 days ago | parent | prev | next [-] |
| > How does the target model validate the draft tokens without running the inference as normal? It does run the inference as normal, just in parallel with the other inferences. > if it is doing just that, I don't get the point Running inferences in parallel lets you read the model weights out of memory only once for N parallel inferences, as opposed to reading them out of memory N times for N serial inferences. Inference is massively bottlenecked by memory bandwidth to the tune of one or two orders of magnitude compared to compute, so this helps a lot. |
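Rough illustrative numbers (assumptions, not measurements) for why the single weight read dominates:

    weights_gb = 140          # e.g. a 70B-parameter model at 2 bytes per parameter
    bandwidth_gb_s = 2000     # ~2 TB/s of HBM on a modern accelerator
    n_tokens = 5              # draft tokens verified at once

    serial_ms = n_tokens * weights_gb / bandwidth_gb_s * 1000   # weights re-read per token: ~350 ms
    batched_ms = weights_gb / bandwidth_gb_s * 1000             # one read scores all 5: ~70 ms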
| |
▲ | littlestymaar 9 days ago | parent [-] | | > Inference is massively bottlenecked by memory bandwidth to the tune of one or two orders of magnitude compared to compute, so this helps a lot. Nitpick: it's only bottlenecked by memory bandwidth if the batch size is too low (that is, if you don't have many users calling the same model in parallel). Speculative decoding is just a way of running a single query as if it were several parallel queries. |
|
|
| ▲ | joliu 9 days ago | parent | prev | next [-] |
| It does run inference, but on the batch of tokens that were drafted, akin to the prefill phase. So your draft model can decode N new tokens, then the real model does one inference pass to score the N new drafted tokens. Prefill is computation bound whereas decode is bandwidth bound, so in practice doing one prefill over N tokens is cheaper than doing N decode passes. |
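Putting it together, a sketch of the whole loop; draft_step and target_verify are hypothetical stand-ins for the small model's decode step and the large model's single scoring pass:

    def speculative_decode(draft_step, target_verify, prompt, n_draft=4, max_new=64):
        tokens = list(prompt)
        while len(tokens) < len(prompt) + max_new:
            draft = []
            for _ in range(n_draft):                          # N cheap decode steps on the draft model
                draft.append(draft_step(tokens + draft))
            n_ok, correction = target_verify(tokens, draft)   # one prefill-like pass over all N drafts
            tokens += draft[:n_ok] + [correction]             # always at least one new token per target pass
        return tokens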
|
| ▲ | jlebar 9 days ago | parent | prev | next [-] |
| Just want to suggest: Ask an LLM about it! If you have access to a reasoning model like o3, I've found it to be very helpful. I think this answer is as good as any of the human-generated ones in the thread so far, but the real power is that you can ask it follow-up questions. https://chatgpt.com/share/6894504f-4458-8008-a8c9-f371588259... |
|
| ▲ | robrenaud 9 days ago | parent | prev [-] |
| I think your core misunderstanding is that you are assuming K calls generating 1 token each are only as expensive as 1 call over K tokens. It is actually much more expensive to generate serially than even in small batches. |
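A toy cost model (illustrative numbers only) makes the gap concrete:

    per_pass_fixed_ms = 70    # reading the weights once per forward pass, memory-bound
    per_token_compute_ms = 1  # marginal compute for one more token in the same pass
    K = 5

    serial_ms = K * (per_pass_fixed_ms + per_token_compute_ms)   # 355 ms for K single-token calls
    batched_ms = per_pass_fixed_ms + K * per_token_compute_ms    # 75 ms for one call over K tokens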