Comment by mkaic
1 year ago
I strongly enjoy the simplicity of their "minGRU" architecture. It's basically just:
import torch
from torch import nn

class MinGRU(nn.Module):
    def __init__(self, token_size, hidden_state_size):
        super().__init__()
        # Both projections look only at the current token, never at the hidden state.
        self.token_to_proposal = nn.Linear(token_size, hidden_state_size)
        self.token_to_mix_factors = nn.Linear(token_size, hidden_state_size)

    def forward(self, previous_hidden_state, current_token):
        proposed_hidden_state = self.token_to_proposal(current_token)
        mix_factors = torch.sigmoid(self.token_to_mix_factors(current_token))
        # lerp(a, b, w) = (1 - w) * a + w * b: gate between the proposal and the previous state
        return torch.lerp(proposed_hidden_state, previous_hidden_state, mix_factors)
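A single recurrent step then looks like this (batch size and dimensions made up for illustration):

model = MinGRU(token_size=256, hidden_state_size=512)
tokens = torch.randn(8, 256)    # a batch of 8 token embeddings
hidden = torch.zeros(8, 512)    # zero-initialized hidden state
hidden = model(hidden, tokens)  # one step of the recurrence -> shape (8, 512)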
And since the proposed hidden states and mix factors for each layer are both only dependent on the current token, you can compute all of them in parallel if you know the whole sequence ahead of time (like during training), and then combine them in linear time using a parallel scan.
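To sketch why that works (my own toy illustration, not the paper's code): each step is an affine map of the previous hidden state, and affine maps compose associatively, which is exactly what a parallel scan needs.

def compose(first, second):
    # Compose two affine maps h -> a*h + b, applying `first` and then `second`:
    # second(first(h)) = a2*(a1*h + b1) + b2 = (a2*a1)*h + (a2*b1 + b2)
    a1, b1 = first
    a2, b2 = second
    return a2 * a1, a2 * b1 + b2

def scan_min_gru(proposals, mix_factors, h0):
    # proposals, mix_factors: (T, H), all computed in parallel from the tokens alone.
    # The lerp above is h_t = (1 - z_t) * proposal_t + z_t * h_{t-1},
    # i.e. the affine map (a_t, b_t) = (z_t, (1 - z_t) * proposal_t).
    a = mix_factors
    b = (1.0 - mix_factors) * proposals
    acc = (torch.zeros_like(h0), h0)      # the constant map h -> h0
    states = []
    for t in range(proposals.shape[0]):   # a sequential fold for clarity; because
        acc = compose(acc, (a[t], b[t]))  # `compose` is associative, a parallel scan
        states.append(acc[1])             # can produce the same states in O(log T) depth
    return torch.stack(states)

The paper's actual implementation does this with a real prefix scan in log-space; the loop here is only to show that nothing at step t depends on earlier hidden states except through the final combine.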
The fact that this is competitive with transformers and state-space models in their small-scale experiments is gratifying to the "best PRs are the ones that delete code" side of me. That said, we won't know for sure if this is a capital-B Breakthrough until someone tries scaling it up to parameter and data counts comparable to SOTA models.
One detail I found really interesting is that they seem to do all their calculations in log-space, according to the Appendix. They say it's for numerical stability, which is curious to me—I'm not sure I have a good intuition for why running everything in log-space makes the model more stable. Is it because they removed the tanh from the output, making it possible for values to explode if calculations are done in linear space?
EDIT: Another thought—it's kind of fascinating that this sort of sequence modeling works at all. It's like if I gave you all the pages of a book individually torn out and in a random order, and asked you to try to make a vector representation for each page as well as instructions for how to mix that vector with the vector representing all previous pages — except you have zero knowledge of those previous pages. Then, I take all your page vectors, sequentially mix them together in-order, and grade you based on how good of a whole-book summary the final vector represents. Wild stuff.
FURTHER EDIT: Yet another thought—right now, they're just using two dense linear layers to transform the token into the proposed hidden state and the lerp mix factors. I'm curious what would happen if you made those transforms MLPs instead of single linear layers.
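Something like this is what I have in mind (the 4x expansion and GELU are arbitrary guesses on my part, not anything from the paper):

from torch import nn

def token_mlp(token_size, hidden_state_size):
    # Hypothetical drop-in replacement for the single nn.Linear projections above.
    return nn.Sequential(
        nn.Linear(token_size, 4 * hidden_state_size),
        nn.GELU(),
        nn.Linear(4 * hidden_state_size, hidden_state_size),
    )

# in MinGRU.__init__, instead of the two nn.Linear layers:
#     self.token_to_proposal = token_mlp(token_size, hidden_state_size)
#     self.token_to_mix_factors = token_mlp(token_size, hidden_state_size)

Since these would still depend only on the current token, the parallel-scan trick should be unaffected.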
This architecture, on the surface, seems to preclude the basic function of recognizing sequences of tokens. At the very least, it seems like it should suffer from something like the pumping lemma: if [the ][cat ][is ][black ] results in the output getting close to a certain vector, shouldn't [the ][cat ][is ][black ][the ][cat ][is ][black ][the ][cat ][is ][black ] get even closer to that vector, and nowhere close to a "why did you just repeat the same sentence three times" vector? Without non-linear mixing between the input token and the hidden state, there will be a lot of linear similarity between similar token sequences...
Counterpoint: the hidden state at the beginning of ([the][cat][is][black]) x 3 is (probably) initialized to all zeros, but after seeing those first 4 tokens, it will not be all zeros. Thus, going into the second repetition of the sentence, the model has a different initial hidden state, and should exhibit different behavior. I think this makes it possible for the model to learn to recognize repeated sequences and avoid your proposed pitfall.
The new hidden state after the first repetition will just be a linear combination of zero and what the non-recurrent layers output. After more repetitions, it will move closer to what those layers output.
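A toy way to see this (numbers invented, and collapsing the whole four-token sentence into a single update for brevity): repeating the same input applies the same affine update over and over, so the state does keep changing, but it converges geometrically toward a fixed point.

import torch

z = torch.tensor(0.7)         # made-up mix factor for the repeated sentence
proposal = torch.tensor(1.0)  # made-up proposed hidden state
h = torch.tensor(0.0)         # zero-initialized hidden state
for repetition in range(5):
    h = torch.lerp(proposal, h, z)  # same update each time the sentence repeats
    print(h.item())                 # roughly 0.3, 0.51, 0.657, 0.76, 0.83 -> heads toward 1.0

So both points hold: each repetition does start from a different hidden state, but the difference shrinks with every pass.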
I don't think it's a capital-B Breakthrough, but recurrent networks are everywhere, and a simplification that improves training and performance clears the stage to build complexity back up again to even greater heights.
Log space is important if the token probabilities span a large range of values (many orders of magnitude apart). There is a reason that maximum likelihood fitting is always performed with log likelihoods.
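A toy illustration of the same effect (numbers made up): multiplying a long run of values below 1 underflows quickly in linear space, while the equivalent sum of logs stays well within range.

import math

gates = [0.1] * 500                      # 500 made-up gate/probability values below 1
product = 1.0
for g in gates:
    product *= g
print(product)                           # 0.0 -- the product has underflowed
print(sum(math.log(g) for g in gates))   # about -1151.3 -- the log-space version is fine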