liuliu 2 days ago

But that's only for prefilling right? Or is it beneficial for decoding too (I guess you can do KV lookup on shards, not sure how much speed-up that will be though).

zackangelo 2 days ago | parent | next [-]

No, you use tensor parallelism in both cases.

The way it typically works in an attention block: smaller portions of the Q, K, and V linear layers are assigned to each node and processed independently. Attention, RoPE, norm, etc. are run on each node's local output. Then, when the output linear layer is applied, an "all-reduce" is computed that combines the outputs of all the nodes.

EDIT: just realized it wasn't clear -- this means each node ends up holding the portion of the KV cache specific to its KV tensor shards. This can change based on the specific style of attention (e.g., in GQA, where there are fewer KV heads than ranks, you end up having to do some replication, etc.).
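
A minimal single-process numpy sketch of that scheme (all shapes are hypothetical; the `for s in range(n_shards)` loop stands in for separate nodes, and the final sum over shard outputs stands in for the all-reduce; RoPE and norms are omitted):

    import numpy as np

    # Hypothetical sizes: 8 attention heads split across 4 "nodes" (shards).
    n_heads, head_dim, d_model, seq = 8, 64, 512, 16
    n_shards = 4
    heads_per_shard = n_heads // n_shards

    rng = np.random.default_rng(0)
    x = rng.standard_normal((seq, d_model))

    def softmax(a, axis=-1):
        a = a - a.max(axis=axis, keepdims=True)
        e = np.exp(a)
        return e / e.sum(axis=axis, keepdims=True)

    shard_outputs = []
    for s in range(n_shards):
        # Each shard owns only its slice of the Q/K/V and output projections.
        d_shard = heads_per_shard * head_dim
        wq = rng.standard_normal((d_model, d_shard)) / np.sqrt(d_model)
        wk = rng.standard_normal((d_model, d_shard)) / np.sqrt(d_model)
        wv = rng.standard_normal((d_model, d_shard)) / np.sqrt(d_model)
        wo = rng.standard_normal((d_shard, d_model)) / np.sqrt(d_shard)

        # Local projections -> this shard's slice of the KV cache lives only here.
        q = (x @ wq).reshape(seq, heads_per_shard, head_dim)
        k = (x @ wk).reshape(seq, heads_per_shard, head_dim)
        v = (x @ wv).reshape(seq, heads_per_shard, head_dim)

        # Attention is computed entirely within the shard, per local head.
        scores = np.einsum("qhd,khd->hqk", q, k) / np.sqrt(head_dim)
        attn = np.einsum("hqk,khd->qhd", softmax(scores), v).reshape(seq, d_shard)

        # Partial output projection; the cross-shard sum below plays the
        # role of the all-reduce that combines every node's contribution.
        shard_outputs.append(attn @ wo)

    out = np.sum(shard_outputs, axis=0)  # "all-reduce" (sum) across shards
    print(out.shape)  # (seq, d_model)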

liuliu 2 days ago | parent [-]

I usually call it "head parallelism" (a type of tensor parallelism, but aimed at small clusters and specific to attention). That is what you described: shard the input tensor by number of heads and send it to the respective Q, K, V shards. Each shard can do the Q/K/V projections, RoPE, QK norm, and attention entirely locally. The output projection is also done in that shard, but then an all-reduce sum among shards is needed to broadcast the final output projection to every participating shard, which then carries on with the rest on its own.

What I am asking, however, is whether that speeds up decoding as linearly as it does prefill.

awnihannun 2 days ago | parent [-]

Right, my comment was mostly about decoding speed. For prefill you can get a speed-up too, but there you are less latency bound.

In our benchmarks with MLX / mlx-lm it's as much as 3.5x for token generation (decoding) at batch size 1 across 4 machines. In that case you are memory-bandwidth bound, so sharding the model and KV cache 4 ways means each machine only needs to read 1/4 as much memory.
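
As a back-of-envelope illustration of why that scales (all numbers below are hypothetical placeholders, not the actual benchmark setup): at batch size 1 each decode step has to stream the node's share of the weights plus its share of the KV cache from memory, so per-token time is bounded below by bytes read divided by memory bandwidth.

    # Hypothetical roofline estimate for batch-1 decoding; the numbers are
    # placeholders, not the configuration used in the MLX benchmarks.
    weight_bytes = 120e9      # assumed total (quantized) model size in bytes
    kv_cache_bytes = 8e9      # assumed KV cache size at some context length
    bandwidth = 800e9         # assumed per-machine memory bandwidth, bytes/s
    n_machines = 4

    def tokens_per_second(shards: int) -> float:
        # Each decode step streams this machine's 1/shards slice of the
        # weights and KV cache from memory at least once.
        bytes_per_token = (weight_bytes + kv_cache_bytes) / shards
        return bandwidth / bytes_per_token

    single = tokens_per_second(1)
    sharded = tokens_per_second(n_machines)
    print(f"1 machine : {single:5.1f} tok/s")
    print(f"{n_machines} machines: {sharded:5.1f} tok/s (ideal {sharded / single:.1f}x)")
    # Communication overhead (the per-layer all-reduce) is why measured
    # speedups land below the shard count, e.g. ~3.5x on 4 machines.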

liuliu 2 days ago | parent [-]

Oh! That's great to hear. Congrats! Now, I want to get the all-to-all primitives ready in s4nnc...

monster_truck 2 days ago | parent | prev [-]

Even if it weren't outright beneficial for decoding by itself, it would still allow you to connect a second machine running a smaller, more heavily quantized version of the model for speculative decoding, which can net you >4x without quality loss.
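
For context, a minimal sketch of the speculative-decoding loop being referred to. The `draft_model` / `target_model` callables are hypothetical stand-ins (not any specific library's API), and the acceptance rule shown is the simplified greedy variant:

    from typing import Callable, List

    def speculative_decode_step(
        target_model: Callable[[List[int]], List[int]],      # greedy next-token prediction per position
        draft_model: Callable[[List[int], int], List[int]],  # proposes k tokens autoregressively
        prompt: List[int],
        k: int = 4,
    ) -> List[int]:
        """One round of (greedy) speculative decoding.

        The small draft model proposes k tokens cheaply; the large target
        model scores the whole proposed block in a single forward pass and
        keeps the longest prefix it agrees with, plus one corrected or
        bonus token. The output matches what the target model alone would
        have produced, which is why there is no quality loss.
        """
        draft = draft_model(prompt, k)                 # k cheap proposals
        target_preds = target_model(prompt + draft)    # one big verify pass

        accepted: List[int] = []
        for i, tok in enumerate(draft):
            # target_preds[len(prompt) + i - 1] is the target's choice for
            # the position where the draft proposed `tok`.
            if target_preds[len(prompt) + i - 1] == tok:
                accepted.append(tok)
            else:
                accepted.append(target_preds[len(prompt) + i - 1])
                return accepted                        # stop at first mismatch
        # All k proposals accepted: the verify pass also yields a bonus token.
        accepted.append(target_preds[len(prompt) + k - 1])
        return accepted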