Comment by bbminner
8 hours ago
I still consider jax.vmap to be a little miracle: fn2 = vmap(fn, (1,2)), if i remember correctly, traverces the computation graph of fn and correctly broacasts all operations in a way that ensures that fn2 acts like fn applied in a loop across the second dimension of the first argument (but accelerated, has auto-gradients, etc).
No comments yet
Contribute on Hacker News ↗