achierius an hour ago

ML compilers in particular rely on even more control-flow stability than you would expect from numerical programs. Due to how the SIMT model handles thread/warp divergence, the hardware heavily punishes unstable branches. E.g. take 32 threads that hit a branch and then reconverge at a barrier: if they all go the same direction, they go down the execution pipe as a single bundle, but if 1 takes it while 31 don't, that's 2x the ex-pipe usage by default (and with e.g. a computed branch, performance goes out the window). Consequently, the whole stack is built around the expectation of stable control flow, even to the detriment of performance (from a local perspective).
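
To put rough numbers on the divergence cost, here is a toy model in plain Python, not real GPU code. It assumes a warp runs one serialized pass per distinct branch target its lanes take and that every pass costs the same; `warp_passes` is a made-up helper name, and real hardware (reconvergence points, independent thread scheduling) is more subtle than this.

```python
def warp_passes(lane_targets):
    """Toy model: one serialized pass per distinct branch target in the warp.

    lane_targets holds one entry per lane (thread) naming where that lane
    jumps. This is an illustrative assumption, not a hardware spec.
    """
    return len(set(lane_targets))


WARP_SIZE = 32

# All 32 lanes agree: the warp goes down the pipe as a single bundle.
uniform = [True] * WARP_SIZE

# 1 lane takes the branch, 31 don't: both sides are executed in turn.
one_diverges = [True] + [False] * (WARP_SIZE - 1)

# Computed branch where every lane lands on a different target.
computed = list(range(WARP_SIZE))

print(warp_passes(uniform))       # 1
print(warp_passes(one_diverges))  # 2
print(warp_passes(computed))      # 32
```

Under these assumptions the 1-vs-31 split costs the same as a 16-vs-16 split, which is why the stack cares about whether a branch is uniform rather than how lopsided it is.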

ML frameworks even take advantage of this to compute, ahead of time, how much memory will be in use at each point in the program graph, and then schedule memcpys to make space as necessary. Of course this only works for well-behaved program classes, but e.g. most LLM architectures fit into that category. Interestingly, MoE models don't, since they require data-dependent control flow; hence the recent push towards accommodating dynamism in frameworks (like JAX, which until fairly recently couldn't handle it at all).
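
A minimal sketch of that kind of ahead-of-time memory accounting, again in plain Python. It is not how any particular framework does it: the graph encoding, the tensor names and sizes, and the helper `live_bytes_per_step` are all made up for illustration, and it assumes a fixed op order with statically known sizes (exactly what data-dependent control flow takes away).

```python
def live_bytes_per_step(ops, sizes):
    """Bytes live while each op runs, computed without executing anything.

    ops:   list of (output_name, [input_names]) in execution order.
    sizes: dict mapping tensor name -> size in bytes.
    Simplifying assumptions: a tensor becomes live at its first appearance
    and is freed right after its last use.
    """
    # Pass 1: find the last step at which each tensor is touched.
    last_use = {}
    for step, (out, ins) in enumerate(ops):
        last_use.setdefault(out, step)
        for name in ins:
            last_use[name] = step

    # Pass 2: walk the schedule, tracking the live set.
    live, profile = set(), []
    for step, (out, ins) in enumerate(ops):
        live.add(out)
        live.update(ins)
        profile.append(sum(sizes[name] for name in live))
        live = {name for name in live if last_use[name] > step}
    return profile


# Hypothetical two-layer graph: h = f(x, w1); y = g(h, w2)
sizes = {"x": 4, "w1": 8, "h": 4, "w2": 8, "y": 2}
ops = [("h", ["x", "w1"]), ("y", ["h", "w2"])]

profile = live_bytes_per_step(ops, sizes)
print(profile)  # [16, 14]

# Steps exceeding a (made-up) budget are where a planner would have to
# schedule copies out of device memory beforehand.
budget = 15
print([step for step, used in enumerate(profile) if used > budget])  # [0]
```

Since the whole profile comes out of the graph alone, a planner can decide on the copies before the first kernel launches. With MoE-style routing the set of ops that run depends on the data, so a single static profile like this no longer exists.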