Metal only being supported via wgpu, due to pulling in CubeCL as a dependency, is... not ideal.
If you need to extend it, you'll have to write custom Autodiff rules with the ergonomics of a Rust front-end and a JAX-like backend. :(