Comment by cavisne

7 months ago

From the paper, the speedup was over an XLA GPU kernel they wrote using JAX, which is probably not SOTA. I don't think JAX even has an official flash attention implementation.
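
For context, here's a minimal sketch of what a JAX-written, XLA-compiled attention kernel typically looks like (names and shapes are illustrative, not taken from the paper). The point is that XLA fuses these ops but still materializes the full (seq, seq) score matrix, unlike a tiled flash-attention kernel, so beating this baseline is a much lower bar than beating SOTA:

```python
import jax
import jax.numpy as jnp

# Illustrative baseline: plain attention written in JAX, compiled by XLA.
# Not a flash-attention-style kernel -- the full attention matrix is
# materialized in memory before the softmax.
@jax.jit
def naive_attention(q, k, v):
    # q, k, v: (batch, heads, seq, head_dim)
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scale
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", weights, v)

key = jax.random.PRNGKey(0)
q = k = v = jax.random.normal(key, (1, 8, 1024, 64))
out = naive_attention(q, k, v)
```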