Comment by yagizdegirmenci
2 months ago
Google introduced this idea in 2022 with "FNet: Mixing Tokens with Fourier Transforms" [0].
Later they found that, in most scenarios, matrix multiplication on their TPU(s) was faster than the FFT.
Referenced in this paper:
"Overall, while approaches such as FNet, Performer, and sparse transformers demonstrate that either fixed or approximate token mixing can reduce computational overhead, our adaptive spectral filtering strategy uniquely merges the efficiency of the FFT with a learnable, input-dependent spectral filter. This provides a compelling combination of scalability and adaptability, which is crucial for complex sequence modeling tasks."
And a comparison section after that.
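For anyone who hasn't read FNet: as I understand it, the mixing sublayer is just a 2D DFT over the sequence and hidden dimensions with the imaginary part thrown away. A minimal NumPy sketch of that idea (the function name `fourier_mix` and the toy shapes are mine, not from the paper's code):

```python
import numpy as np

def fourier_mix(x):
    """FNet-style token mixing: 2D DFT over (seq_len, hidden_dim), keep the real part."""
    # np.fft.fft2 transforms the last two axes, i.e. both axes of a 2D input.
    return np.fft.fft2(x).real

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 4))   # toy input: 8 tokens, hidden size 4
mixed = fourier_mix(x)            # same shape as x, no learnable parameters
```

There are no weights in this step, which is the "fixed token mixing" the quoted paragraph contrasts with a learnable spectral filter.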
Except that the paper is written as if they discovered that you can use an FFT for attention. They even have a "proof". It's in the title. Then you discover everyone already knew this, and all they do is add some extra learnable parameters.
Pretty lame.
Search engines don't always turn up prior art the way you'd like. Simple jargon discrepancies can cause a lot of mischief. Though I'm sure a case could be made about it being confirmation bias. It's hard to get people to search in earnest for bad news. If it's not in your face they declare absence of evidence as evidence of absence.
That seems like an odd comparison, specialty hardware is often better, right?
Hey, do DSPs have special hardware to help with FFTs? (I’m actually asking, this isn’t a rhetorical question, I haven’t used one of the things but it seems like it could vaguely be helpful).
Xilinx has a very highly optimized core for the FFT. You are restricted to power-of-2 sizes, which usually isn't a problem because it's fairly common to zero-pad an FFT anyway to avoid highly aliased (i.e. hard-edged) binning.
The downside of implementing it directly in hardware is that the size would be fixed.
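In software the same zero-pad-to-a-power-of-2 trick looks roughly like this. It's a NumPy sketch, not anything specific to the Xilinx core; `next_pow2` is a helper name I made up:

```python
import numpy as np

def next_pow2(n):
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()

x = np.arange(100, dtype=float)     # 100 samples, not a power of two
n_fft = next_pow2(len(x))           # rounds up to 128
spectrum = np.fft.fft(x, n=n_fft)   # n > len(x) makes NumPy zero-pad the input
```

Padding gives a finer frequency grid over the same data; it does not add resolution that wasn't in the original samples.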
They usually have dedicated acceleration hardware, yes: https://www.ti.com/lit/an/sprabb6b/sprabb6b.pdf?ts=174057874...
Yes, almost all DSPs I know of have native hardware support for the FFT, since it's the bread and butter of signal processing.
I remember hearing about logic to help with deinterleaving the results of the butterfly network after the FFT is done.
Yeah, bit-reversed addressing mode as seen on the dsPIC is an example of this.
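For context, an in-place radix-2 FFT leaves its outputs in bit-reversed index order, and the addressing mode does that reordering for free. A software sketch of the permutation (plain Python for illustration, not dsPIC code; the function name is my own):

```python
def bit_reverse_indices(n):
    """Return the bit-reversed index order for a power-of-two length n."""
    bits = n.bit_length() - 1
    if n < 1 or (1 << bits) != n:
        raise ValueError("n must be a power of two")
    order = []
    for i in range(n):
        r = 0
        for b in range(bits):
            # Move bit b of i to the mirrored position (bits - 1 - b).
            if i & (1 << b):
                r |= 1 << (bits - 1 - b)
        order.append(r)
    return order

order = bit_reverse_indices(8)   # [0, 4, 2, 6, 1, 5, 3, 7]
```

On a DSP the address generator produces this sequence directly, so the loop above costs nothing at run time.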
(Discrete) Fast Fourier Transform implementations:
https://fftw.org/ ; FFTW: https://en.wikipedia.org/wiki/FFTW
gh topic: fftw: https://github.com/topics/fftw
xtensor-stack/xtensor-fftw is similar to numpy.fft: https://github.com/xtensor-stack/xtensor-fftw
NVIDIA cuFFTW, amd/amd-fftw, Intel MKL FFTW
NVIDIA cuFFT (GPU FFT) https://docs.nvidia.com/cuda/cufft/index.html
ROCm/rocFFT (GPU FFT) https://github.com/ROCm/rocFFT .. docs: https://rocm.docs.amd.com/projects/rocFFT/en/latest/
AMD FFT, Intel FFT: https://www.google.com/search?q=AMD+FFT , https://www.google.com/search?q=Intel+FFT
project-gemmi/benchmarking-fft: https://github.com/project-gemmi/benchmarking-fft
"An FFT Accelerator Using Deeply-coupled RISC-V Instruction Set Extension for Arbitrary Number of Points" (2023) https://ieeexplore.ieee.org/document/10265722 :
> with data loading from either specially designed vector registers (V-mode) or RAM off-the-core (R-mode). The evaluation shows the proposed FFT acceleration scheme achieves a performance gain of 118 times in V-mode and 6.5 times in R-mode respectively, with only 16% power consumption required as compared to the vanilla NutShell RISC-V microprocessor
"CSIFA: A Configurable SRAM-based In-Memory FFT Accelerator" (2024) https://ieeexplore.ieee.org/abstract/document/10631146
/? dsp hardware FFT: https://www.google.com/search?q=dsp+hardware+fft
GPU saw a 10% improvement over the TPU
>The TPU is so inefficient at FTs that the researchers did not use the FFT algorithm on sequences < 4096 elements, instead opting for a quadratic-scaling FT implementation using a pre-computed DFT matrix.
> on an Nvidia Quadro P6000 GPU, the FT was responsible for up to 30% of the inference time on the FNet architecture [0]
This company [0] claimed in 2021 that they could cut inference time by 40% if Google used their light chips on TPU. Perhaps more if FFTNet does more heavy lifting.
[0]: https://scribe.rip/optalysys/attention-fourier-transforms-a-...
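The "pre-computed DFT matrix" approach from the quote is easy to sketch: build the n x n matrix once, then the transform is an ordinary matmul, which is what a TPU is good at. A NumPy illustration (helper name and shapes are mine, not taken from the FNet code):

```python
import numpy as np

def dft_matrix(n):
    """Precompute the n x n DFT matrix W[j, k] = exp(-2*pi*i*j*k / n)."""
    k = np.arange(n)
    return np.exp(-2j * np.pi * np.outer(k, k) / n)

rng = np.random.default_rng(0)
x = rng.standard_normal((16, 4))   # toy input: (seq_len, hidden_dim)
W = dft_matrix(x.shape[0])         # built once, reused for every input
via_matmul = W @ x                 # quadratic in seq_len, but a plain matmul
via_fft = np.fft.fft(x, axis=0)    # same result via the FFT
```

Both paths give the same numbers up to floating-point error; the only question is which one the hardware runs faster at a given sequence length.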
I have been entertaining myself a bit lately by thinking about the ways in which some improvements to a design are very, very interesting to people when it takes 1.2 machines to do a task, not worth paying attention to when it's 6 machines to do the task, and suddenly very interesting again when it's 120 machines to do the task. There's that weird saddle point in the middle where I cannot get anyone else interested in my 20% resource improvements. It's just crickets.
4096 tokens is pretty short by today's standards for transformers too.
I would guess that the FFT scales better as you increase the number of tokens in the context window. Interesting that Google's models outperform their competitors on context size.
I'm glad someone else had the same thought. I have been wondering what their "secret sauce" is for a while given how their model doesn't degrade for long-context nearly as much as other LLMs that are otherwise competitive. It could also just be that they used longer-context training data than anyone else though.
> faster than FFT
Not only that, but FFT support on TPU has always been best effort. Last I tried this, there were serious precision issues.
Reminds me how CNNs were also not implemented using FFTs.
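The CNN connection is the convolution theorem: a convolution becomes a pointwise product in the frequency domain. A small NumPy sketch for the 1D circular case (function names are mine; real CNN layers compute cross-correlation with small kernels, which is part of why direct methods tend to win there):

```python
import numpy as np

def circular_conv_direct(x, k):
    """Circular convolution by the definition, O(n * len(k))."""
    n = len(x)
    return np.array([sum(x[(i - j) % n] * k[j] for j in range(len(k)))
                     for i in range(n)])

def circular_conv_fft(x, k):
    """Same result via FFT: transform, multiply pointwise, invert."""
    n = len(x)
    # fft(k, n=n) zero-pads the kernel to the signal length.
    return np.fft.ifft(np.fft.fft(x) * np.fft.fft(k, n=n)).real

rng = np.random.default_rng(0)
x = rng.standard_normal(32)
k = rng.standard_normal(5)
direct = circular_conv_direct(x, k)
fast = circular_conv_fft(x, k)
```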
But this would only work on a very limited number of tokens, right?
Reference for the latter part?
Section "3.3 Implementation" is mostly about hardware-level speedups, and basically says:
On GPU(s) the FFT is consistently faster, but on TPU(s) matrix multiplication was faster for shorter sequences.
Yeah, but a comparison of power utilization is needed too. You can build hardware that is better than a GPU at something, e.g. making MatMul really efficient and fast. However, actual FFT hardware would annihilate it on power and speed at large enough n, simply because a DFT done as a matrix multiply needs O(n^2) multiplies per length-n sequence, as opposed to the O(n log n) multiplies an FFT does (complex versus real multiplies notwithstanding).
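To put rough numbers on that for a single length-n transform: a DFT-matrix product takes n^2 complex multiplies, while the textbook radix-2 FFT count is (n/2) * log2(n). A back-of-the-envelope sketch that counts only multiplies and ignores memory traffic and how well each maps to the hardware (function names are mine):

```python
def dft_matmul_multiplies(n):
    """Complex multiplies for one length-n DFT done as a matrix-vector product."""
    return n * n

def radix2_fft_multiplies(n):
    """Textbook complex-multiply count for a radix-2 FFT, n a power of two."""
    log2_n = n.bit_length() - 1
    if n < 1 or (1 << log2_n) != n:
        raise ValueError("n must be a power of two")
    return (n // 2) * log2_n

for n in (512, 4096, 65536):
    ratio = dft_matmul_multiplies(n) / radix2_fft_multiplies(n)
    print(f"n={n}: matmul={dft_matmul_multiplies(n)}, "
          f"fft={radix2_fft_multiplies(n)}, ratio={ratio:.0f}x")
```

The gap grows roughly as n / log n, so the crossover depends on how much faster the hardware runs a multiply inside a matmul than inside an FFT.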