crazygringo 14 days ago

Very curious how this compares to JAX [1].

JAX lets you write Python code that executes on Nvidia GPUs, but also on GPUs from other vendors (support varies). It similarly has drop-in replacements for NumPy functions.
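To make "drop-in replacement" concrete, here is a minimal sketch, assuming `jax.numpy` is importable and falling back to plain NumPy when it is not. The `normalize` helper is hypothetical and only for illustration; the point is that the function body is identical under either import.

```python
# Sketch: the same array code runs against jax.numpy or NumPy.
# Assumption: JAX may not be installed, so fall back to NumPy.
try:
    import jax.numpy as xp  # uses an accelerator if a supported backend is installed
    BACKEND = "jax"
except ImportError:
    import numpy as xp      # fallback so the sketch still runs without JAX
    BACKEND = "numpy"


def normalize(x):
    """Scale a vector to unit Euclidean length (hypothetical helper)."""
    return x / xp.sqrt(xp.sum(x * x))


v = xp.asarray([3.0, 4.0])
unit = normalize(v)
print(BACKEND, unit)
```

Only the import line changes between the two backends. Which device JAX actually runs on depends on the installed backend, which is where the "support varies" caveat comes in.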

This only supports Nvidia. But can it do things JAX can't? Is it easier to use? Is it less fixed-size-array-oriented? Is it worth locking yourself into one brand of GPU?

[1] https://github.com/jax-ml/jax

odo1242 14 days ago | parent [-]

Well, the idea is that you’d be writing low-level CUDA kernels that implement operations not already implemented by JAX/CUDA and integrating them into existing projects. Numba [1] is probably the closest thing I can think of that currently exists. (In fact, looking at it right now, it seems this effort from Nvidia is actually based on Numba.)

[1]: https://numba.readthedocs.io/en/stable/cuda/overview.html