Comment by albertzeyer

1 day ago

Why do you say FlexAttention is too buggy? I have heard about a lot of successful uses of it, and I have never heard of any such problems.

Also note that, depending on your model dimensions and sequence lengths, the attention computation often plays only a minor role (maybe 10% of the total), and the MLP computation dominates.
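
As a rough back-of-the-envelope sketch of why that is (counting only the big matmuls and ignoring softmax, norms, biases, GQA and so on, so treat the numbers as ballpark only):

```python
# Approximate FLOPs per token for one transformer layer; illustrative only.
def attention_matmul_share(d_model: int, seq_len: int, mlp_ratio: int = 4) -> float:
    proj = 2 * 4 * d_model * d_model             # Q, K, V and output projections
    attn = 2 * 2 * seq_len * d_model             # QK^T scores + weighting of V
    mlp = 2 * 2 * mlp_ratio * d_model * d_model  # MLP up- and down-projection
    return attn / (proj + attn + mlp)            # share of the part Flash/FlexAttention replace

# e.g. d_model=4096, seq_len=2048 -> roughly 8% of the layer's FLOPs are in the
# attention matmuls themselves; the share grows with seq_len relative to d_model.
print(f"{attention_matmul_share(4096, 2048):.1%}")
```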

Last time I tried it, I encountered both showstopper bugs (it was completely, obviously broken) and subtle correctness bugs: it looked like it was working, but since I'm paranoid I have unit tests for everything, and numerically the errors were too large compared to what you get with eager attention or Flash Attention. It was also too slow for my taste compared to Flash Attention, so I just dropped it. And I wasn't even doing anything particularly exotic with it.

Maybe it's better now, but I'd still consider it completely irresponsible to use FlexAttention without a corresponding unit test that checks its accuracy against an equivalent eager implementation.
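
For concreteness, a minimal version of such a test might look roughly like this (assuming a recent PyTorch where torch.nn.attention.flex_attention is available; the shapes and tolerances are illustrative, and in practice you'd also want to cover the torch.compile'd path and whatever score_mod / block_mask you actually use):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

def eager_attention(q, k, v):
    # Plain softmax attention as the reference implementation.
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return scores.softmax(dim=-1) @ v

def test_flex_attention_matches_eager():
    torch.manual_seed(0)
    batch, heads, seq_len, head_dim = 2, 4, 128, 64
    q, k, v = (torch.randn(batch, heads, seq_len, head_dim) for _ in range(3))
    ref = eager_attention(q, k, v)
    out = flex_attention(q, k, v)  # no score_mod: should reduce to plain attention
    torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
```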

  • What unit tests do you use for nn modules and how do you come up with them?

    • Unit tests that feed random inputs across different sizes (e.g. different numbers of heads, head sizes, embedding dimensions, etc.) and compare the outputs of two different implementations against each other (e.g. attention implemented manually in an eager fashion vs. a bunch of accelerated attention libraries); a sketch of such a test is below.

      Also more integration-like tests, where I take an already pretrained model, load it with an established library (e.g. Hugging Face Transformers), load the very same checkpoint into my reimplementation (where I vary the implementation, e.g. swap out the attention implementation), and compare the outputs; that one is also sketched below. Funnily enough, I recently found a bug in HF Transformers this way: after updating to a newer version, my previously matching output no longer matched.
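
      A rough sketch of the randomized comparison tests, here checking an eager reference against F.scaled_dot_product_attention (the sizes and tolerances are placeholders, not recommendations):

      ```python
      import pytest
      import torch
      import torch.nn.functional as F

      def eager_attention(q, k, v):
          # Plain softmax attention as the reference implementation.
          scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
          return scores.softmax(dim=-1) @ v

      @pytest.mark.parametrize("heads", [1, 4, 12])
      @pytest.mark.parametrize("head_dim", [32, 64, 128])
      @pytest.mark.parametrize("seq_len", [1, 17, 256])
      def test_sdpa_matches_eager(heads, head_dim, seq_len):
          torch.manual_seed(0)
          q, k, v = (torch.randn(2, heads, seq_len, head_dim) for _ in range(3))
          ref = eager_attention(q, k, v)
          out = F.scaled_dot_product_attention(q, k, v)
          torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
      ```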
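
      And a rough sketch of the checkpoint-comparison test; MyGPT2 here stands in for my own reimplementation (purely hypothetical), while the reference side uses the standard Hugging Face Transformers API:

      ```python
      import torch
      from transformers import AutoModelForCausalLM, AutoTokenizer

      from my_models import MyGPT2  # hypothetical reimplementation of the same architecture

      def test_reimplementation_matches_hf():
          name = "gpt2"
          tok = AutoTokenizer.from_pretrained(name)
          ref_model = AutoModelForCausalLM.from_pretrained(name).eval()
          my_model = MyGPT2.from_pretrained(name).eval()  # loads the very same checkpoint

          inputs = tok("The quick brown fox jumps over the lazy dog", return_tensors="pt")
          with torch.no_grad():
              ref_logits = ref_model(**inputs).logits
              my_logits = my_model(inputs["input_ids"])
          torch.testing.assert_close(my_logits, ref_logits, atol=1e-4, rtol=1e-4)
      ```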