Remix.run Logo
Athas an hour ago

I wrote a bit about this some years ago: https://futhark-lang.org/blog/2020-05-03-higher-order-parall... - but note that Jax isn't subject to these constraints; it's more like Accelerate and Futhark.

I'm not a Jax expert. Accelerate's 'map' allows for almost arbitrary sequential code - there is some fine print, because it's an embedded language, and the biggest fine print is that nested parallelism is not allowed. You can define your own Haskell-level higher order functions, and Accelerate will handle them just fine, because essentially all the Haskell-level computation is "compiled away" (by being run) before the Accelerate code is JIT-compiled at run-time. You can consider Haskell to be a meta-language in which you ultimately construct Accelerate program terms, and then those are compiled and run - not too dissimilar from how Jax does it, actually.

Recursion works, but for an uninteresting reason: the recursion is on the Haskell side, and will essentially be unrolled before Accelerate gets its hand on it. This allows you to do some fun things (like partially evaluating a ray tracer on its scene description), but it's often not what you want, and Accelerate provides some combinators (that look like higher-order Haskell functions) for expressing sequential looping.