RossBencina 11 hours ago

Can you suggest a good reference for understanding which algorithms map well onto the regular grid systolic arrays used by TPUs? The fine article says dense matmul and convolution are good, but is there anything else? Eigendecomposition? SVD? Matrix exponential? Solving Ax = b or AX = B? Cholesky?

cdavid 9 hours ago | parent | next [-]

SVD/eigendecomposition will often boil down to doing many matmuls (e.g. when using Krylov-based methods such as Arnoldi, Krylov-Schur, etc.), so I would expect TPUs to work well there. GMRES, one method for solving Ax = b, is also based on the Arnoldi decomposition.
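
To make the "boils down to matmul" point concrete, here is a minimal sketch of the Arnoldi iteration in NumPy. It is only an illustration, not how any TPU library implements it; the function name `arnoldi` and the breakdown tolerance are my own choices. In this single-vector form the heavy operation is a dense matrix-vector product per step; block variants that carry several vectors at once would turn that into a matmul.

```python
import numpy as np

def arnoldi(A, b, k, tol=1e-12):
    """Run k Arnoldi steps starting from b.

    Returns Q (n x (k+1), orthonormal columns) and H ((k+1) x k, upper
    Hessenberg) such that A @ Q[:, :k] is approximately Q @ H.
    The name and the tolerance are illustrative assumptions.
    """
    n = A.shape[0]
    Q = np.zeros((n, k + 1))
    H = np.zeros((k + 1, k))
    Q[:, 0] = b / np.linalg.norm(b)
    for j in range(k):
        # Dense matvec: the step that dominates the cost for large A.
        w = A @ Q[:, j]
        # Modified Gram-Schmidt against the basis built so far.
        for i in range(j + 1):
            H[i, j] = Q[:, i] @ w
            w = w - H[i, j] * Q[:, i]
        H[j + 1, j] = np.linalg.norm(w)
        if H[j + 1, j] < tol:
            # Breakdown: the Krylov subspace is invariant, stop early.
            return Q[:, :j + 1], H[:j + 1, :j + 1]
        Q[:, j + 1] = w / H[j + 1, j]
    return Q, H
```

The eigenvalues of the small square part of `H` (Ritz values) approximate some eigenvalues of `A`, and GMRES solves a small least-squares problem with `H` at each step, so everything outside the `A @ ...` product is cheap by comparison.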

WithinReason 11 hours ago | parent | prev | next [-]

Anything that you can express as 128x128 (but ideally much larger) dense matrix multiplication and nothing else
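
A rough sketch of what "express it as 128x128 matmul" means in practice, in plain NumPy. It is an illustration of the tiling idea only, not how a TPU compiler actually schedules work; `TILE`, `pad_to_tile`, and `tiled_matmul` are hypothetical names, and the 128 is just the figure quoted above.

```python
import numpy as np

TILE = 128  # assumed tile edge, taken from the 128x128 figure above

def pad_to_tile(M, tile=TILE):
    """Zero-pad both dimensions of M up to the next multiple of `tile`."""
    pad_rows = -M.shape[0] % tile
    pad_cols = -M.shape[1] % tile
    return np.pad(M, ((0, pad_rows), (0, pad_cols)))

def tiled_matmul(A, B, tile=TILE):
    """Compute A @ B as a sum of tile x tile block products."""
    m, n = A.shape[0], B.shape[1]
    Ap, Bp = pad_to_tile(A, tile), pad_to_tile(B, tile)
    C = np.zeros((Ap.shape[0], Bp.shape[1]), dtype=np.result_type(A, B))
    for i in range(0, Ap.shape[0], tile):
        for j in range(0, Bp.shape[1], tile):
            for k in range(0, Ap.shape[1], tile):
                # Each block product is one fixed-size dense matmul.
                C[i:i + tile, j:j + tile] += (
                    Ap[i:i + tile, k:k + tile] @ Bp[k:k + tile, j:j + tile]
                )
    # Drop the zero padding again.
    return C[:m, :n]
```

The padding is where small or oddly shaped problems lose out: a 200x300 operand becomes 256x384, so a good share of the multiply-adds are spent on zeros, which is one reason "ideally much larger" matters.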

musebox35 10 hours ago | parent | prev [-]

I think https://jax-ml.github.io/scaling-book/ is one of the best references to go through. It details how single-device and distributed computations map to TPU hardware features. The emphasis is on mapping the transformer computations, both forward and backward passes, so it requires some familiarity with how transformer networks are structured.