porridgeraisin 9 days ago

Let's say I want to run f2(f1(x)) where f1 and f2 are both a single pass through GPT-4.

This takes 2 seconds, assuming 1 second for every pass.

What I instead do is kick off f1(x) in another thread, and then run f2(g1(x)) where g1 is one pass through GPT-nano.

This takes 0.1 + 1 = 1.1 seconds, assuming GPT-nano takes 0.1s for every pass. Within those 1.1 seconds, the f1(x) that we kicked off in the second thread will have finished (it takes 1 second).

So after 1.1 seconds we have both f1(x) and f2(g1(x)) available, and we have also stored the intermediate g1(x).

We compare g1(x) and f1(x).

If they are equal, i.e. g1(x) = f1(x), then our answer is f2(g1(x)), obtained in just 1.1s.

If they are not, we compute f2(f1(x)) using the f1(x) from the second thread, which takes one more second and brings the total to 2.1s.
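
A minimal Python sketch of the scheme above, assuming f1, f2 and g1 are plain callables standing in for model passes (the name speculative_compose is hypothetical, not from any library). f1(x) runs in a background thread while the cheap guess g1(x) and f2(g1(x)) run in the foreground; the speculative result is kept only if the guess matches.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Tuple


def speculative_compose(
    f1: Callable[[Any], Any],
    f2: Callable[[Any], Any],
    g1: Callable[[Any], Any],
    x: Any,
) -> Tuple[Any, bool]:
    """Compute f2(f1(x)), using g1(x) as a cheap guess for f1(x).

    Returns (result, draft_accepted). The result always equals f2(f1(x)).
    """
    with ThreadPoolExecutor(max_workers=1) as pool:
        # Start the slow, authoritative f1(x) in a background thread.
        slow = pool.submit(f1, x)
        # Meanwhile, guess f1(x) with the small model and run f2 on the guess.
        guess = g1(x)
        speculative = f2(guess)
        # Block until f1(x) is done (usually already finished by now).
        truth = slow.result()
    if guess == truth:
        # Guess was right: f2(g1(x)) == f2(f1(x)), no extra pass needed.
        return speculative, True
    # Guess was wrong: pay for one more f2 pass on the real f1(x).
    return f2(truth), False
```

Because of CPython's GIL, the two threads only truly overlap if the callables spend their time waiting (for example on an HTTP call to a hosted model), which is the situation described here.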

If the small model matches the big model in, say, 2/3 of cases, you will spend 2/3 * 1.1 + 1/3 * 2.1 ≈ 1.433s on average for this computation. Without speculative decoding, it is always 2s.
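
The averaging can be written as a small function under the same cost model: big pass b = 1s, small pass s = 0.1s, and p the probability that g1(x) equals f1(x). The function and parameter names are illustrative only.

```python
def expected_latency(p: float, big: float = 1.0, small: float = 0.1) -> float:
    """Average time for f2(f1(x)) with a speculative guess from a small model.

    Hit:  small pass, then f2 (f1 runs in parallel)  -> small + big
    Miss: the same, plus a second f2 pass on f1(x)   -> small + 2 * big
    """
    hit = small + big
    miss = small + 2 * big
    return p * hit + (1 - p) * miss


baseline = 2 * 1.0  # two sequential big passes, no speculation
print(round(expected_latency(2 / 3), 3))  # 1.433
print(baseline)  # 2.0
```

Expanding the expression gives s + 2b - p*b, so under these assumptions speculation beats the 2b baseline whenever p*b > s, i.e. whenever the small model agrees more than 10% of the time here.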

magicalhippo 9 days ago

Thanks, that's a very nice explanation and it makes perfect sense. I guess their graphics confused me for some reason and had me thinking about it all wrong.

Now I see they were pointing out the obvious extension, which is to predict multiple tokens ahead, not just two as in your example.
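
The multi-token version can be sketched the same way. This is a toy under stated assumptions: greedy decoding, and "models" represented as hypothetical functions that map a token list to the next token. The small model drafts k tokens, the big model checks them left to right, keeps the agreeing prefix and replaces the first mismatch with its own token. In practice the big model scores all k drafted positions in one batched forward pass rather than one call per position, which is where the time is saved.

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical model interface: greedy next token given the context so far.
Model = Callable[[Sequence[int]], int]


def speculative_step(big: Model, small: Model, prefix: List[int], k: int) -> List[int]:
    """One draft-and-verify round; returns the tokens to append."""
    # 1. Draft k tokens cheaply with the small model.
    draft: List[int] = []
    ctx = list(prefix)
    for _ in range(k):
        t = small(ctx)
        draft.append(t)
        ctx.append(t)
    # 2. Verify with the big model (one batched pass in real systems).
    accepted: List[int] = []
    ctx = list(prefix)
    for t in draft:
        expected = big(ctx)
        if expected != t:
            # First disagreement: keep the big model's token and stop.
            accepted.append(expected)
            return accepted
        accepted.append(t)
        ctx.append(t)
    # All k drafts accepted: the same verification pass yields one bonus token.
    accepted.append(big(ctx))
    return accepted


def generate(big: Model, small: Model, prefix: List[int], n_new: int, k: int) -> Tuple[List[int], int]:
    """Generate n_new tokens; returns (tokens, number of verify rounds)."""
    out = list(prefix)
    rounds = 0
    while len(out) - len(prefix) < n_new:
        out.extend(speculative_step(big, small, out, k))
        rounds += 1
    return out[: len(prefix) + n_new], rounds
```

The output is always exactly what the big model would produce greedily on its own; only the number of rounds changes. With a perfect draft model each round emits k + 1 tokens, and with a useless one it emits a single token, the same as ordinary decoding.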

arkmm 9 days ago

This is a really great explanation.