Comment by p1esk
6 days ago
I’m not sure why you’re talking about efficiency when the question is “do sparse models work better than dense models?” The answer is no, they don’t.
Even the old LTH paper you cited trains a dense model and then tries to prune it without too much quality loss. Pruning is a well-known method to compress models - to make them smaller and faster, not better.
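To make that concrete, here's a minimal sketch of magnitude pruning (illustrative numpy, not the procedure from any specific paper); the zeros can be stored sparsely, which is where the compression comes from, and in practice you'd fine-tune afterwards to recover quality:

    import numpy as np

    def magnitude_prune(w, sparsity=0.9):
        # zero out the smallest-magnitude weights, keeping the largest (1 - sparsity) fraction
        k = int(w.size * sparsity)
        threshold = np.sort(np.abs(w), axis=None)[k]
        return np.where(np.abs(w) >= threshold, w, 0.0)

    w = np.random.randn(1024, 1024).astype(np.float32)
    w_pruned = magnitude_prune(w)  # ~90% zeros, but same shape and dtype as before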
Before we had proper GPUs, everyone said the same thing about neural networks.
Current model architectures are optimized to get the most out of GPUs, which is why transformers dominate: they're mostly large dense matrix multiplies.
There's plenty of work showing transformers improve with inner dimension size, but it's not feasible to scale them up further because it blows up parameter and activation sizes (including KV caches), so people turn to low-rank ("sparse") decompositions like MLA.
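Roughly what that low-rank trade-off looks like (made-up sizes, not MLA's actual shapes):

    import numpy as np

    d, r = 4096, 512              # model dim vs. low-rank bottleneck (illustrative)
    W = np.random.randn(d, d)     # full projection: d*d = ~16.8M parameters

    A = np.random.randn(d, r)     # factorized: d*r + r*d = ~4.2M parameters
    B = np.random.randn(r, d)

    x = np.random.randn(d)
    y_full    = x @ W             # one big dense matmul
    y_lowrank = (x @ A) @ B       # two skinny matmuls through the rank-r bottleneck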
The lottery ticket hypothesis shows that most of the weights in current models are redundant and that we could get away with much smaller sparse models, but currently there's no advantage to doing so because on GPUs you still end up doing dense multiplies.
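To illustrate the "no advantage on GPUs" point: zeroing weights with an unstructured mask doesn't change the kernel that runs (torch sketch, assuming no structured-sparsity kernels are used):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    w = torch.randn(4096, 4096, device=device)
    mask = (torch.rand_like(w) < 0.1).float()    # keep ~10% of weights
    w_masked = w * mask                          # 90% zeros, but still a dense tensor

    x = torch.randn(64, 4096, device=device)
    y = x @ w_masked   # same dense matmul kernel and FLOPs as x @ w; the zeros aren't skipped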
Plenty of mech interp work shows that models are forced to commingle different concepts to fit them into the "low" dimensional vector space. (https://www.neelnanda.io/mechanistic-interpretability/glossa...)
https://arxiv.org/abs/2210.06313
https://arxiv.org/abs/2305.01610
Yes, we know that large dense layers work better than small dense layers (up to a point). We also know how to train large dense models and then prune them. But we don’t know how to train large sparse models to be better than large dense models. If someone figures it out then we can talk about building hardware for it.
It isn't directly what you are asking for, but there is a similar relationship at work with L_1 versus L_2 regularization. The number of samples required to train a model scales as O(log(d)) for L_1 but O(d) for L_2, where d is the dimensionality [1]. This relates to the standard random-projection results about how you can approximate high-dimensional vectors in a log(d)-dimensional space with (probably) small error.
At a very hand-wavy level, it seems reasonable that moving from L_1 to L_0 would have a similar relationship in learning complexity, but I don't think that has ever been addressed formally.
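A rough empirical sketch of that L_1 vs L_2 gap (hypothetical sizes, plain sklearn, not the setup from [1]): only a handful of the d features matter, and with few samples the L_1-regularized model tends to generalize noticeably better.

    import numpy as np
    from sklearn.linear_model import LogisticRegression

    rng = np.random.default_rng(0)
    d, k, n_train, n_test = 1000, 10, 100, 5000   # many features, few relevant, few samples

    w_true = np.zeros(d)
    w_true[:k] = 1.0                              # only the first k features carry signal
    X = rng.standard_normal((n_train + n_test, d))
    y = (X @ w_true + 0.1 * rng.standard_normal(n_train + n_test) > 0).astype(int)

    for penalty, solver in [("l1", "liblinear"), ("l2", "lbfgs")]:
        clf = LogisticRegression(penalty=penalty, C=1.0, solver=solver, max_iter=1000)
        clf.fit(X[:n_train], y[:n_train])
        print(penalty, "test accuracy:", clf.score(X[n_train:], y[n_train:]))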
[1] https://www.andrewng.org/publications/feature-selection-l1-v...