crazygringo 14 days ago
Very curious how this compares to JAX [1]. JAX lets you write Python code that executes on Nvidia GPUs, and also on GPUs from other vendors (support varies). It similarly has drop-in replacements for NumPy functions. This only supports Nvidia. But can it do things JAX can't? Is it easier to use? Is it less fixed-size-array-oriented? Is it worth locking yourself into one brand of GPU?
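For anyone who hasn't used it, here is a minimal sketch of what the NumPy-style API in JAX looks like. The `standardize` function is a made-up example, not something from the JAX docs; the same code runs on CPU, or on a GPU if a supported backend happens to be installed.

```python
import jax
import jax.numpy as jnp  # NumPy-like API backed by XLA


@jax.jit  # traced and compiled once per input shape/dtype
def standardize(x):
    # Hypothetical helper: shift to zero mean and scale to unit variance.
    return (x - jnp.mean(x)) / jnp.std(x)


x = jnp.arange(8.0)
y = standardize(x)

# Shows which backend JAX picked (CPU, GPU, ...).
print(jax.devices())
```

Since `jax.jit` compiles per input shape, calling `standardize` with a different array length triggers a new compilation, which is roughly what the fixed-size-array question is about.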
odo1242 14 days ago | in reply to crazygringo
Well, the idea is that you'd be writing low-level CUDA kernels that implement operations not already implemented by JAX/CUDA, and integrating them into existing projects. Numba [1] is probably the closest existing thing I can think of. (In fact, looking at it right now, it seems this effort from Nvidia is actually based on Numba.)

[1]: https://numba.readthedocs.io/en/stable/cuda/overview.html