Comment by big-chungus4

1 day ago

How does `x.cos().cos()` run faster than doing two cos calls separately? The first cos call returns a tensor either way; the only difference is whether it gets assigned to a variable. How could Python even know about that difference?

The author forgot to add "fused" here, like they did in other parts of the same section.

Non-fused:

  foreach i
    y[i] = cos(x[i])
  foreach i
    z[i] = cos(y[i])

Fused, no intermediate variable:

  foreach i
    t = cos(x[i])
    z[i] = cos(t)
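The two loop shapes above can be written as runnable pure Python. This is only a sketch: plain lists stand in for GPU arrays (an assumption), and the function names are hypothetical. CPython gets no bandwidth benefit from this. The point is the structure. The unfused version materializes a full intermediate array `y`, while the fused version keeps only a scalar `t` per element. Both produce identical results.

```python
import math

def cos_cos_unfused(x):
    n = len(x)
    # Pass 1: write every cos(x[i]) into an intermediate array y
    # (on a GPU, y would be written to global memory).
    y = [0.0] * n
    for i in range(n):
        y[i] = math.cos(x[i])
    # Pass 2: read y back and apply cos again.
    z = [0.0] * n
    for i in range(n):
        z[i] = math.cos(y[i])
    return z

def cos_cos_fused(x):
    n = len(x)
    z = [0.0] * n
    # Single pass: t is a per-element temporary
    # (on a GPU, it would stay in a register).
    for i in range(n):
        t = math.cos(x[i])
        z[i] = math.cos(t)
    return z

x = [0.0, 0.5, 1.0, 2.0]
print(cos_cos_unfused(x) == cos_cos_fused(x))  # True
```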

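The temporary "t" stays in a register and is never written out to the GPU's global memory. Sweeping the array twice doubles the global-memory traffic (read x, write y, read y, write z, instead of just read x, write z), so you become twice as dependent on memory bandwidth.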
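A back-of-the-envelope count of the traffic, assuming float32 elements (4 bytes each), ignoring caches, and assuming every array access goes to global memory. The function name and the array size are hypothetical and chosen for illustration:

```python
def global_memory_bytes(n, fused, itemsize=4):
    """Bytes read from plus written to global memory for z = cos(cos(x)).

    Assumes n elements of itemsize bytes each and no caching.
    """
    if fused:
        # read x once, write z once
        accesses = 2 * n
    else:
        # pass 1: read x, write y; pass 2: read y, write z
        accesses = 4 * n
    return accesses * itemsize

n = 100_000_000  # hypothetical: 1e8 float32 elements
print(global_memory_bytes(n, fused=False))  # 1600000000
print(global_memory_bytes(n, fused=True))   # 800000000
```

Since cos is cheap compared with a trip to global memory, a kernel like this is typically bandwidth-bound. Moving twice the bytes then costs roughly twice the time.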

It’s really not a concept you can express easily in idiomatic Python. The difference shows up in the actual generated assembly: the unfused version copies data from global GPU memory into registers (slow, and bandwidth saturates quickly) and back out again between the two cosines. If you can avoid that intermediate round trip, the cost drops by approximately half.

Yeah, that part should not be read literally; `x.cos().cos()` and `x1 = x.cos(); x2 = x1.cos()` both launch the same number of kernels (two in eager/unfused mode, one when fused with torch.compile; see this test notebook [1]). I think the author chained the two cos calls to symbolize combining them without exposing the intermediate result, but chaining the calls doesn't by itself trigger operator fusion.
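The original question was how Python could tell the two forms apart. In eager mode it effectively doesn't. `x.cos().cos()` evaluates `x.cos()`, gets back a new object, then calls `.cos()` on that object, which is the same call sequence as the two-statement version. Below is a minimal sketch using a hypothetical `CountingArray` class (not PyTorch) that counts each eager `.cos()` call as one "kernel launch":

```python
import math

class CountingArray:
    """Hypothetical tensor stand-in: each .cos() call counts as one
    kernel launch and eagerly returns a new array."""
    launches = 0

    def __init__(self, data):
        self.data = list(data)

    def cos(self):
        CountingArray.launches += 1  # one "kernel" per eager op
        return CountingArray(math.cos(v) for v in self.data)


x = CountingArray([0.0, 1.0, 2.0])

# Chained form
CountingArray.launches = 0
chained = x.cos().cos()
chained_launches = CountingArray.launches

# Two separate statements
CountingArray.launches = 0
x1 = x.cos()
x2 = x1.cos()
separate_launches = CountingArray.launches

print(chained_launches, separate_launches)  # 2 2
print(chained.data == x2.data)  # True
```

Getting a single kernel requires something that sees the whole expression before running it, which is what a compiler such as torch.compile does. Method chaining alone doesn't provide that.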

[1] https://colab.research.google.com/drive/13a4Y-ko6QLMPAhBz64c...