Comment by vgatherps

1 year ago

https://github.com/jax-ml/jax

To expand on this link: right now, this is probably the closest you're going to get to "I'll 'program' in LinAlg, and a JIT compiles it to whatever wonky form your hardware requires." JAX implements a good portion of the NumPy interface, which is the most common interface for linear-algebra-heavy code in Python. So you can often write ordinary NumPy code with `jax.numpy` in place of `numpy`, then wrap it in `jax.jit` so that XLA compiles it for your accelerator (GPU or TPU).
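Here's a rough sketch of what that swap looks like. The function `normalize_rows` and its input are made up for illustration. The only JAX-specific parts are the `jnp` import and the `jax.jit` wrapper:

```python
import numpy as np
import jax
import jax.numpy as jnp

# Plain NumPy: scale each row of a matrix to unit L2 norm.
def normalize_rows_np(x):
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / norms

# The same function body, with jnp swapped in for np.
def normalize_rows(x):
    norms = jnp.linalg.norm(x, axis=1, keepdims=True)
    return x / norms

# jax.jit traces the function once per input shape/dtype and hands the traced
# computation to XLA, which compiles it for the default backend (CPU, GPU or TPU).
normalize_rows_jit = jax.jit(normalize_rows)

x = np.array([[3.0, 4.0], [6.0, 8.0]], dtype=np.float32)
print(normalize_rows_np(x))
print(normalize_rows_jit(x))   # same values, computed by the XLA-compiled version
print(jax.default_backend())   # which backend JAX picked on this machine
```

The first call to `normalize_rows_jit` pays the compilation cost. Later calls with the same shape and dtype reuse the compiled kernel. One place the drop-in swap breaks down is in-place mutation: JAX arrays are immutable, so `x[i] = v` has to be written as `x = x.at[i].set(v)`.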

I was about to say that it is literally just JAX.

It genuinely deserves to exist alongside PyTorch. It's not just Google's latest framework that you're forced to use to target TPUs.