musebox35 12 hours ago

I think https://jax-ml.github.io/scaling-book/ is one of the best references to work through. It details how single-device and distributed computations map onto TPU hardware features. The emphasis is on mapping transformer computations, both the forward and backward passes, so it requires some familiarity with how transformer networks are structured.