Comment by augment_me
11 hours ago
TLDR:
The authors observe that functions with a global row-wise dependency, like RMSNorm/LayerNorm, have baked-in per-row scales that commute with a subsequent projection in certain setups, so the scale can be moved out to after that projection and be partially aggregated on tiles of rows.
So (W1 @ gamma * globally_computed_scale) * W2 can be written as (W1 @ gamma * W2) * globally_computed_scale, as long as the scale only interacts per row.
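To make that concrete, here is a rough pure-Python sketch of the identity for RMSNorm, not the authors' kernel. The helper names (matmul, rms_scale, norm_then_project, project_then_scale) and the eps value are my own assumptions for illustration; gamma is treated as a per-column weight and the scale as one number per row.

```python
import math
import random

def matmul(a, b):
    """Plain-Python matrix product of two lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def rms_scale(row, eps=1e-6):
    """Per-row RMSNorm scale: 1 / sqrt(mean(x^2) + eps). eps is an assumed value."""
    return 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)

def norm_then_project(x, gamma, w):
    """Reference order: scale each row, apply gamma per column, then project."""
    normed = [[v * rms_scale(row) * g for v, g in zip(row, gamma)] for row in x]
    return matmul(normed, w)

def project_then_scale(x, gamma, w):
    """Reordered: apply gamma, project, then scale each output row afterwards."""
    weighted = [[v * g for v, g in zip(row, gamma)] for row in x]
    projected = matmul(weighted, w)
    return [[v * rms_scale(row) for v in out] for row, out in zip(x, projected)]

random.seed(0)
x = [[random.uniform(-1, 1) for _ in range(4)] for _ in range(3)]      # 3 rows, 4 features
gamma = [random.uniform(0.5, 1.5) for _ in range(4)]                   # per-feature weight
w = [[random.uniform(-1, 1) for _ in range(5)] for _ in range(4)]      # 4 -> 5 projection

a = norm_then_project(x, gamma, w)
b = project_then_scale(x, gamma, w)

# Largest elementwise gap between the two orderings (floating-point noise only).
max_diff = max(abs(p - q) for ra, rb in zip(a, b) for p, q in zip(ra, rb))
print(max_diff)
```

The two orderings agree because the scale is a single number per row, so it factors out of every dot product in that row. The sum of squares behind it can also be accumulated tile by tile while the GEMM runs, which is where the partial aggregation comes in.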
This was usually not done before because left-to-right graph compilers like torch.compile can't assume that a global row-wise reduction sitting between two GEMMs commutes with them.