Comment by WithinReason
4 days ago
I would like to know your thoughts on using 2/3 of such a small model's size for embeddings. What would be different if you used a byte-level vocabulary and spent the parameter budget on transformer parameters instead? I think you would lose performance (tok/s) but might gain accuracy.
At this small scale the embeddings indeed were a big focus. Consider this thought process.
The tokens themselves are a form of compression. Let's say we have the word "WaffleHouse". At the character level this would be 11 tokens, but with a subword tokenizer it would be perhaps 2 or 3 tokens (I didn't actually run it through the tokenizer, but we could verify precisely). This matters a lot, especially for on-device processing.
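To make the compression point concrete, here is a minimal sketch comparing a character-level split with a toy greedy longest-match subword split. The vocabulary `TOY_VOCAB` and the helper `greedy_subword_tokens` are made up for illustration; this is not Gemma's tokenizer, and a real BPE or SentencePiece vocabulary may split the word differently.

```python
def char_tokens(text):
    """Character-level tokenization: one token per character."""
    return list(text)


def greedy_subword_tokens(text, vocab):
    """Toy greedy longest-match split; falls back to single characters."""
    tokens = []
    i = 0
    while i < len(text):
        # Try the longest remaining substring first, then shrink it.
        for j in range(len(text), i, -1):
            if text[i:j] in vocab:
                tokens.append(text[i:j])
                i = j
                break
        else:
            # No vocabulary entry matched: emit one character.
            tokens.append(text[i])
            i += 1
    return tokens


# Hypothetical vocabulary, chosen only so the example splits cleanly.
TOY_VOCAB = {"Waffle", "House"}

word = "WaffleHouse"
char_count = len(char_tokens(word))
subword = greedy_subword_tokens(word, TOY_VOCAB)

print(char_count)  # 11
print(subword)     # ['Waffle', 'House']
```

With this toy vocabulary the model would process 2 tokens instead of 11 for the same text. The real count depends on the actual tokenizer.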
So while we could get more intelligence out of the model by bumping up the "knowledge" parameters, the device would need to process more input and output tokens.
Another advantage on small devices is that the embeddings are just a lookup table, which requires little to no computation. It's the rest of the parameters that have the expensive matrix multiplications, so if we increased those we'd also be increasing the number of FLOPs needed for a forward pass.
This blog post explains it well. https://www.adamcasson.com/posts/transformer-flops
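As a rough back-of-the-envelope sketch of that tradeoff, the snippet below uses the common rule of thumb of about 2 FLOPs per matmul weight per token. All parameter counts are hypothetical round numbers, not the real model's, and the token counts reuse the "WaffleHouse" example (11 characters versus a guessed 3 subword tokens). It deliberately ignores the context-length-dependent attention term and the final projection to logits, both of which the linked post accounts for.

```python
def forward_flops_per_token(matmul_params):
    """Rule of thumb: each matmul weight costs one multiply and one add."""
    return 2 * matmul_params


def prompt_flops(num_tokens, matmul_params):
    """Approximate forward-pass FLOPs to process num_tokens tokens."""
    return num_tokens * forward_flops_per_token(matmul_params)


# Hypothetical model A: large subword vocabulary.
# 2/3 of a 300M budget sits in the embedding table (a lookup, no matmul),
# 1/3 in the transformer blocks.
blocks_a = 100_000_000
tokens_a = 3  # guessed subword token count for "WaffleHouse"

# Hypothetical model B: byte-level vocabulary.
# The embedding table is tiny, so almost the whole budget goes to blocks.
blocks_b = 300_000_000
tokens_b = 11  # one token per character of "WaffleHouse"

flops_a = prompt_flops(tokens_a, blocks_a)
flops_b = prompt_flops(tokens_b, blocks_b)

print(flops_a)             # 600000000
print(flops_b)             # 6600000000
print(flops_b // flops_a)  # 11
```

Under these assumptions the byte-level variant pays twice: more FLOPs per token and more tokens for the same text.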
So all this is to say there are definite tradeoffs between model size, performance on evals, and compute cost. We ran many internal experiments with different choices to see what could work well, and then picked what we believed would work best for the open community.
How would this matrix get trained with PyTorch? I currently have a toy Transformer network - I ended up marking the matrix as sparse and using SparseAdam - gives a bit of a performance boost, but at the same time I can't use torch.compile() on the fetch from this matrix.
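For readers unfamiliar with the setup described above, here is a minimal sketch of a sparse embedding trained with `torch.optim.SparseAdam` alongside a dense optimizer for the remaining weights. The sizes and the tiny model are hypothetical, and this is not a description of how Gemma was trained. It also does not address the `torch.compile()` limitation mentioned in the comment.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

VOCAB_SIZE, DIM = 1000, 16  # hypothetical toy sizes

# sparse=True makes the embedding produce sparse gradients:
# only rows that were actually looked up receive a gradient.
embedding = nn.Embedding(VOCAB_SIZE, DIM, sparse=True)
head = nn.Linear(DIM, VOCAB_SIZE)

# SparseAdam only accepts sparse gradients, so dense parameters
# need a separate optimizer.
sparse_opt = torch.optim.SparseAdam(list(embedding.parameters()), lr=1e-3)
dense_opt = torch.optim.AdamW(head.parameters(), lr=1e-3)

token_ids = torch.tensor([[1, 5, 7, 5]])
targets = torch.tensor([[5, 7, 5, 2]])

before = embedding.weight.detach().clone()

logits = head(embedding(token_ids))
loss = nn.functional.cross_entropy(
    logits.view(-1, VOCAB_SIZE), targets.view(-1)
)

sparse_opt.zero_grad()
dense_opt.zero_grad()
loss.backward()
grad_is_sparse = embedding.weight.grad.is_sparse
sparse_opt.step()
dense_opt.step()

# Only the rows for the token ids seen in this batch should have moved.
changed = (embedding.weight.detach() != before).any(dim=1)
changed_rows = changed.nonzero().flatten().tolist()
print(grad_is_sparse, changed_rows)
```

The point of the sparse path is that the optimizer step touches only the looked-up rows rather than the whole table.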
Does Gemma use any specific scheme to compress embeddings? Which have you considered?
For instance, it's well-known that transformer embeddings tend to form clusters. Have you considered splitting the embedding table into "cluster centroid" and "offset from centroid" tables, where the latter would presumably have a smaller range and precision?
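The centroid-plus-offset idea can be sketched on synthetic data as follows. Everything here is an assumption made for illustration: the clustered table is generated artificially, the clustering is a few plain k-means iterations, and the offsets are quantized to int8 with a single shared scale. It is not a scheme Gemma is known to use.

```python
import numpy as np

rng = np.random.default_rng(0)
NUM_TOKENS, DIM, NUM_CLUSTERS = 512, 8, 4  # hypothetical toy sizes

# Synthetic "embedding table" with tight clusters around a few centers.
true_centers = rng.normal(0.0, 5.0, size=(NUM_CLUSTERS, DIM))
labels = rng.integers(0, NUM_CLUSTERS, size=NUM_TOKENS)
noise = rng.normal(0.0, 0.1, size=(NUM_TOKENS, DIM))
table = (true_centers[labels] + noise).astype(np.float32)

# A few iterations of plain k-means to find centroids.
init_rows = rng.choice(NUM_TOKENS, NUM_CLUSTERS, replace=False)
centroids = table[init_rows].copy()
for _ in range(10):
    dists = ((table[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    assign = dists.argmin(axis=1)
    for k in range(NUM_CLUSTERS):
        members = table[assign == k]
        if len(members):
            centroids[k] = members.mean(axis=0)

# Split each row into (centroid id, offset from that centroid).
offsets = table - centroids[assign]

# Quantize offsets to int8 with one shared scale.
scale = float(np.abs(offsets).max()) / 127.0
q_offsets = np.round(offsets / scale).astype(np.int8)

# Reconstruct and compare against the original table.
recon = centroids[assign] + q_offsets.astype(np.float32) * scale
max_err = float(np.abs(recon - table).max())

# Scale that direct int8 quantization of the raw table would need.
raw_scale = float(np.abs(table).max()) / 127.0
print(scale, raw_scale, max_err)
```

Storage in this sketch is one small float centroid table, one centroid id per token, and one int8 offset per element. On clustered data the offset scale would typically come out smaller than the raw scale, which is the "smaller range" the comment refers to. Whether real embedding tables cluster tightly enough for this to pay off is an open question here.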
Beautiful writeup! Thanks for your service!
Makes sense, thank you.