BERT is just a single text diffusion step

3 months ago (nathan.rs)

To my knowledge this connection was first noted in 2021 in https://arxiv.org/abs/2107.03006 (page 5). We wanted to do text diffusion where you’d corrupt words to semantically similar words (like “quick brown fox” -> “speedy black dog”) but kept finding that masking was easier for the model to uncover. Historically this goes back even further to https://arxiv.org/abs/1904.09324, which made a generative MLM without framing it in diffusion math.
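A minimal sketch of the two corruption styles being compared, assuming word-level tokens. `SIMILAR` is a hypothetical three-entry synonym table, and the function names are illustrative rather than taken from the linked papers. With masking, the corrupted positions are explicit. With substitution, the model first has to work out which words are wrong.

```python
import random

MASK = "[MASK]"

# Hypothetical synonym table, for illustration only.
SIMILAR = {"quick": "speedy", "brown": "black", "fox": "dog"}

def mask_corrupt(tokens, rate, rng):
    """Absorbing-state corruption: each token becomes [MASK] with probability `rate`."""
    return [MASK if rng.random() < rate else t for t in tokens]

def similar_word_corrupt(tokens, rate, rng):
    """Substitution corruption: swap a token for a semantically similar word, if one is known."""
    return [SIMILAR.get(t, t) if rng.random() < rate else t for t in tokens]

rng = random.Random(0)
sentence = ["the", "quick", "brown", "fox"]
print(mask_corrupt(sentence, 0.5, rng))
print(similar_word_corrupt(sentence, 1.0, rng))  # ['the', 'speedy', 'black', 'dog']
```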

  • It goes further back than that. In 2014, Li Yao et al. (https://arxiv.org/abs/1409.0585) drew an equivalence between autoregressive (roughly, next-token prediction) generative models and generative stochastic networks (denoising autoencoders, the predecessor to diffusion models). They argued that the parallel sampling style correctly approximates sequential sampling.

    In my own work circa 2016, I used this approach in Counterpoint by Convolution (https://arxiv.org/abs/1903.07227), where we in turn argued that despite being an approximation, it leads to better results. Sadly, because it was dressed up as an application paper, we weren't able to draw enough attention to get those sweet diffusion citations.

    Pretty sure it goes further back than that still.

Back when BERT came out, everyone was trying to get it to generate text. These attempts generally didn't work, but here's one for reference: https://arxiv.org/abs/1902.04094

This doesn't have an explicit diffusion tie-in, but Savinov et al. at DeepMind figured out that doing two steps at training time and randomizing the masking probability is enough to get BERT-style generation to work reasonably well.
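A minimal sketch of the masking difference, assuming word-level tokens and a hypothetical `[MASK]` symbol. BERT is trained at one fixed rate (15%, shown here without its 80/10/10 replacement rule), whereas drawing a fresh rate per example exposes the model to every noise level it will meet during iterative generation. The second, unrolled training step the comment mentions is not shown.

```python
import random

MASK = "[MASK]"

def bert_style(tokens, rng, rate=0.15):
    """Fixed-rate masking as in BERT pretraining (simplified: every selected token becomes [MASK])."""
    return [MASK if rng.random() < rate else t for t in tokens]

def random_rate_style(tokens, rng):
    """Diffusion-style masking: sample the masking rate itself, uniformly in [0, 1)."""
    rate = rng.random()
    return rate, [MASK if rng.random() < rate else t for t in tokens]

rng = random.Random(0)
tokens = "the quick brown fox jumps over the lazy dog".split()
print(bert_style(tokens, rng))
rate, corrupted = random_rate_style(tokens, rng)
print(round(rate, 2), corrupted)
```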

To me, the diffusion-based approach "feels" more akin to what's going on in an animal brain than the token-at-a-time approach of the in-vogue LLMs. Speaking for myself, I don't generate words one at a time based on previously spoken words; I start by having some fuzzy idea in my head, and the challenge is in serializing it into language coherently.

  • > the token-at-a-time approach of the in-vogue LLMs. Speaking for myself, I don't generate words one at a time based on previously spoken words

    Autoregressive LLMs don't do that either, actually. Sure, with one forward pass you only get one token at a time, but looking at what is happening in the latent space, there are clear signs of long-term planning and reasoning that go beyond just the next token.

    So I don't think it's necessarily more or less similar to us than diffusion; we do say one word at a time sequentially, even if we have the bigger picture in mind.

    • To take a simple example, let’s say we ask an autoregressive model a yes/no factual question like “is 1+1=2?”. Then, we force the LLM to start with the wrong answer “No, “ and continue decoding.

      An autoregressive model can’t edit the past. If it happens to sample the wrong first token (or we force it to in this case), there’s no going back. Of course there can be many more complicated lines of thinking as well where backtracking would be nice.

      “Reasoning” LLMs tack this on with reasoning tokens. But the issue with this is that the LLM has to attend to every incorrect, irrelevant line of thinking, which is at a minimum a waste and likely confusing.

      As an analogy, on HN I don’t need to attend to every comment under a post in order to generate my next word. I probably just care about the current thread, from my comment up to the OP. Of course, a model could learn that relationship, but that’s a huge waste of compute.

      Text diffusion solves the problem entirely by allowing the model to simply revise the “no” to a “yes”. Very simple.

    • If a process is necessary for performing a task, (sufficiently large) neural networks trained on that task will approximate that process. That doesn't mean they're doing it in anything resembling an efficient way, or that a different architecture or algorithm wouldn't produce a better result.

    • You're right that there is long-term planning going on, but that doesn't contradict the fact that an autoregressive LLM does, in fact, literally generate words one at a time based on previously spoken words. Planning and action are different things.

    • There is some long-term planning going on, but bad luck when sampling the next token can take the process off the rails, so it's not just an implementation detail.

  • You 100% do pronounce or write words one at a time sequentially.

    But before starting your sentence, you internally formulate the gist of the sentence you're going to say.

    Which is exactly what happens in an LLM's latent space, too, before it starts outputting the first token.

    • I'm curious what makes you so confident about this. I confess I expect that people are often far more cognizant of the last thing they want to say when they start.

      I don't think you do a random walk through the words of a sentence as you conceive it. But it is hard not to think that people center themes and moods in their minds as they compose their thoughts into sentences.

      Similarly, have you ever looked into how actors learn their lines? It is often in a way that is a lot closer to diffusion than to token-at-a-time generation.

    • Like most people, I jump back and forth when I speak, disclaiming, correcting, and appending to previous utterances. I do this even more when I write, eradicating entire sentences, and even the ideas they contain, within paragraphs where, by the time they were finished, the sentence seemed unnecessary or inconsistent.

      I did it multiple times while writing this comment, and it is only four sentences. The previous sentence once said "two sentences," and after I added this statement it was changed to "four sentences."

    • For most serious texts I start with a tree outline, before I engage my literary skills.

    • >You 100% do pronounce or write words one at a time sequentially.

      It's statements like these that make me wonder if I am the same species as everyone else. Quite often, I've picked adjectives and idioms first, and then fill in around them to form sentences. Often because there is some pun or wordplay, or just something that has a nice ring to it, and I want to lead my words in that direction. If you're only choosing them one at a time and sequentially, have you ever considered that you might just be a dimwit?

      It's not like you don't see this happening all around you in others. Sure you can't read minds, but have you never once watched someone copyedit something they've written, where they move phrases and sentences around, where they switch out words for synonyms, and so on? There are at least dozens of fictional scenes in popular media, you must have seen one. You have to have noticed hints at some point in your life that this occurs. Please. Just tell me that you spoke hastily to score internet argument points, and that you don't believe this thing you've said.

    • (Just to expand on that, it's true not just for the first token. There's a lot of computation, potentially including planning ahead, before each token is output.)

      That's why saying "it's just predicting the next word" is a misguided take.

  • Interpretability research has found that autoregressive LLMs also plan ahead what they are going to say.

  • The fact that you’re cognitively aware is evidence that this is nowhere near diffusion. It's more like rumination or thinking tokens, if we absolutely had to find a present-day LLM metaphor.

  • It feels like a mix of both to me, diffusion "chunks" being generated in sequence. As I write this comment, I'm deciding on the next word while also shaping the next sentence, like turning a fuzzy idea into a clear sequence.

  • Maybe it's two different modes of thinking. I can have thoughts that coalesce from the ether, but also sometimes string a thought together linearly. Brains might be able to do both.

  • I feel completely the opposite way.

    When you speak or do anything, you focus on what you’re going to do next: your next action. At that moment, you are relying on your recent memory and on things you have put in place while doing the overall activity (context).

    In fact, what’s actually missing from AI currently is simultaneous collaboration, like a group of people interacting in a human conversation; for now it is very one-on-one.

    Diffusion is like looking at a cloud and trying to find a pattern.

  • LLMs are notoriously bad at reflecting on how they work, and I feel like humans are probably in the same boat.

  • > Speaking for myself, I don't generate words one at a time based on previously spoken words

    This is a common but fundamentally weird assumption people make about neurology: that what they consciously perceive has some bearing on what's actually happening at the operational or physical level.

It feels like it would make more sense to allow the model to do Levenshtein-like edits instead of just masking and filling in the masked tokens. Intuitively, it seems really hard in this diffusion setup to swap one word for a longer but better synonym towards the end, because there's no way to shift everything to the right afterwards.
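A fixed-length masked model can only overwrite positions in place. Here is a minimal sketch of what Levenshtein-style edits would buy, using Python's `difflib.SequenceMatcher` only as a convenient way to produce an edit script; the `(tag, start, end, new_tokens)` representation and the helper names are assumptions for illustration, not how any particular model does it. Replacing one token with two simply lengthens the sequence, and everything after it shifts right.

```python
from difflib import SequenceMatcher

def edit_ops(src, dst):
    """Levenshtein-style edit script: (tag, start, end, new_tokens) for each non-equal span."""
    ops = []
    for tag, i1, i2, j1, j2 in SequenceMatcher(None, src, dst).get_opcodes():
        if tag != "equal":
            ops.append((tag, i1, i2, dst[j1:j2]))  # new_tokens is [] for "delete"
    return ops

def apply_ops(src, ops):
    """Apply the script; everything after an insertion shifts right automatically."""
    out, cursor = [], 0
    for _tag, start, end, new_tokens in ops:
        out.extend(src[cursor:start])  # copy the untouched span
        out.extend(new_tokens)         # replacement/insertion (empty for deletion)
        cursor = end
    out.extend(src[cursor:])
    return out

src = "the fox ran fast".split()
dst = "the fox sprinted very fast".split()
ops = edit_ops(src, dst)
print(ops)  # [('replace', 2, 3, ['sprinted', 'very'])]
print(apply_ops(src, ops))
```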

I love seeing these simple experiments. Easy to read through quickly and understand a bit more of the principles.

One of my stumbling blocks with text diffusers is that ideally you wouldn’t treat the tokens as discrete but rather as probability fields. Image diffusers have the natural property that a pixel is a continuous value: you can smoothly transition from one color to another. Not so with tokens. In this case they just do a full replacement. You can’t add noise to a token; you have to work in the embedding space. But how can you train the embeddings directly? I found a bunch of different approaches that have been tried, but they are all much more complicated than the image-based diffusion process.
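A toy sketch of the embedding-space workaround, assuming a fixed, made-up 2-D embedding table `EMB` (in practice the embeddings would be learned, which is exactly the hard part raised above). Noise moves a vector continuously, but decoding back to text still snaps it to the nearest discrete token, so small noise usually leaves the token unchanged and larger noise can flip it to a neighbor.

```python
import math
import random

# Toy 2-D embedding table; values are made up for illustration.
EMB = {"cat": (1.0, 0.0), "dog": (0.9, 0.3), "car": (-1.0, 0.5)}

def add_noise(vec, sigma, rng):
    """Continuous corruption: add isotropic Gaussian noise to an embedding vector."""
    return tuple(x + rng.gauss(0.0, sigma) for x in vec)

def round_to_token(vec):
    """Decode a (noisy) vector to the token whose embedding is nearest in Euclidean distance."""
    return min(EMB, key=lambda tok: math.dist(vec, EMB[tok]))

rng = random.Random(0)
for sigma in (0.0, 0.1, 1.0):
    noisy = add_noise(EMB["cat"], sigma, rng)
    print(sigma, round_to_token(noisy))
```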

When text diffusion models started popping up, I thought the same thing as this guy (“wait, this is just MLM”), though I was thinking more of MaskGIT. The only thing I could think of that would make it “diffusion” is if the model had to learn to replace incorrect tokens with correct ones (since continuous diffusion’s big thing is noise resistance). I don’t think anyone has done this because it’s hard to come up with good incorrect tokens.

  • I've played around with MLM at the UTF8 byte level to train unorthodox models on full sequence translation tasks. Mostly using curriculum learning and progressive random corruption. If you just want to add noise, setting random indices to random byte values might be all you need. For example:

    Feeding the model the following input pattern:

      [Source UTF8 bytes] => [Corrupted Target UTF8 bytes]
    

    I expect it to output the full corrected target bytes. The overall training process follows this curriculum:

      Curriculum Level 0: Corrupt nothing and wait until the population/model masters simple repetition.
    
      Curriculum Level 1: Corrupt 1 random byte per target and wait until the population/model stabilizes.
    
      Curriculum Level N: Corrupt N random bytes per target. 
      
      Rinse & repeat until all target sequences are fully saturated with noise.
    

    An important aspect is to always score the entire target sequence each time so that we build upon prior success. If we just evaluate on the masked tokens, the step between each level of difficulty would be highly discontinuous in the learning domain.

    I've stopped caring about a lot of the jargon & definitions. I find that trying to stick things into buckets like "is this diffusion" gets in the way of thinking and trying new ideas. I am more concerned with whether or not it works than with what it is called.
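A minimal sketch of the corruption and scoring scheme described above, assuming plain Python `bytes`, a literal `b" => "` separator, and hypothetical helper names; the model itself and the "wait until it stabilizes" logic are left out. Note that a random replacement byte can coincidentally equal the original, so level N corrupts at most N bytes.

```python
import random

def corrupt_bytes(target: bytes, n: int, rng: random.Random) -> bytes:
    """Curriculum level n: overwrite n distinct random positions with random byte values."""
    buf = bytearray(target)
    for i in rng.sample(range(len(buf)), min(n, len(buf))):
        buf[i] = rng.randrange(256)  # may coincidentally equal the original byte
    return bytes(buf)

def make_example(source: str, target: str, level: int, rng: random.Random):
    """Input is '[source] => [corrupted target]'; the label is the full clean target."""
    tgt = target.encode("utf-8")
    return source.encode("utf-8") + b" => " + corrupt_bytes(tgt, level, rng), tgt

def full_sequence_score(pred: bytes, target: bytes) -> float:
    """Fraction of *all* target bytes reproduced, not just the corrupted ones."""
    if not target:
        return 1.0
    return sum(p == t for p, t in zip(pred, target)) / len(target)

rng = random.Random(0)
for level in range(4):  # level 0 is pure copying; each level adds one more corrupted byte
    inp, label = make_example("bonjour", "hello", level, rng)
    # Baseline: score of a "model" that just copies the corrupted target back out.
    print(level, inp, full_sequence_score(inp[-len(label):], label))
```

Scoring the whole target rewards the model for keeping the untouched bytes intact, which is what keeps each curriculum level a small step from the last.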

    • The problem with that is we want the model to learn to deal with its own mistakes. With continuous diffusion mistakes mostly look like noise, but with what you’re proposing mistakes are just incorrect words that are semantically pretty similar to the real text, so the model wouldn’t learn to consider those “noise”. The noising function would have to generate semantically similar text (e.g., out of order correct tokens maybe? Tokens from a paraphrased version?)
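The first parenthetical suggestion, out-of-order correct tokens, is easy to sketch. This assumes shuffling within non-overlapping windows, with `window` as a made-up knob for how far tokens can drift; whether such noise resembles a model's own mistakes closely enough is still the open question.

```python
import random

def local_shuffle(tokens, window, rng):
    """Shuffle tokens within non-overlapping windows of size `window`.
    Every token is still a correct token; only its position is wrong."""
    out = list(tokens)
    for start in range(0, len(out), window):
        chunk = out[start:start + window]
        rng.shuffle(chunk)
        out[start:start + window] = chunk
    return out

rng = random.Random(0)
print(local_shuffle("the model should put these words back in order".split(), 3, rng))
```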

I've really wanted to fine-tune an inline code completion model to see if I could get at all close to Cursor (I can't, but it would be fun), but as far as I know there are no open diffusion models to use as a base, let alone any that would be good as a base. Hopefully something viable for it comes out soon.

To me, part of the appeal of image diffusion models was starting with random noise to produce an image. Why do text diffusion models start with a blank slate (i.e., all "masked" tokens) instead of with random tokens?

  • It depends on what you want the model to do for you. If you want the model to complete text, then you would provide the input text unmasked followed by a number of masked tokens that it's the model's job to fill in. Perhaps your goal is to have the model simply make edits to a bit of code. In that case, you'd mask out the part that it's supposed to edit and the model would iteratively fill in those masked tokens with generated tokens.

    One of the supposed strengths of text diffusion models is coding. Autoregressive LLMs don't inherently come with the ability to edit; they can generate instructions that another system interprets as editing commands. Being able to literally unmask the parts you want to edit is a pretty powerful paradigm that could improve, or at least speed up, many coding tasks.

    I suspect that elements of text diffusion will be baked into coding models like GPT Codex (if they aren't already). There's no reason you couldn't train a diffusion output head specifically designed for code editing, with the same model using that head when it makes the most sense to do so.
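A minimal sketch of how the two use cases differ only in where the mask tokens go, assuming a hypothetical `[MASK]` symbol and plain token lists. In this simple form, the number of masked slots fixes the length of what gets generated.

```python
MASK = "[MASK]"

def completion_input(prompt, n_new):
    """Completion: the prompt stays visible and n_new masked slots are appended for the model to fill."""
    return list(prompt) + [MASK] * n_new

def edit_input(tokens, start, end):
    """Editing: mask only tokens[start:end]; everything else stays as fixed context on both sides."""
    return list(tokens[:start]) + [MASK] * (end - start) + list(tokens[end:])

code = ["total", "=", "a", "+", "b"]
print(completion_input(["def", "add", "(", "a", ",", "b", ")", ":"], 3))
print(edit_input(code, 2, 5))  # ['total', '=', '[MASK]', '[MASK]', '[MASK]']
```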

  • They don't all do that. There are many approaches being experimented with.

    Some start with random tokens, others with masks, and others even start with random vector embeddings.
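A sketch of the three starting states listed above, assuming a toy vocabulary and a hypothetical `init_canvas` helper; `dim` is an arbitrary embedding size. Whichever start a model uses has to match the corruption it was trained to reverse.

```python
import random

MASK = "[MASK]"

def init_canvas(length, strategy, vocab, rng, dim=4):
    """Initial state for `length` positions to be generated."""
    if strategy == "mask":               # absorbing-state diffusion: start fully masked
        return [MASK] * length
    if strategy == "random_tokens":      # uniform-noise diffusion: start from random tokens
        return [rng.choice(vocab) for _ in range(length)]
    if strategy == "random_embeddings":  # continuous diffusion: start from Gaussian vectors
        return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(length)]
    raise ValueError(f"unknown strategy: {strategy!r}")

rng = random.Random(0)
vocab = ["yes", "no", "maybe"]
for s in ("mask", "random_tokens", "random_embeddings"):
    print(s, init_canvas(3, s, vocab, rng, dim=2))
```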

The problem with this approach to text generation is that it's still not flexible enough. If during inference the model changes its mind and wants to output something considerably different it can't because there are too many tokens already in place.

  • That's not true. You could have just looked at the first GIF animation in the OP and seen that tokens disappear; the only part that stays untouched is the prompt. Adding noise is part of the diffusion process, and the code that does it is even posted in the article (Ctrl+F "def diffusion_collator").

I think another easy improvement to this diffusion model would be for the logprobs to also affect the chance of a token being turned into a mask. Higher-confidence tokens should have less of a chance of being pruned, so it should converge faster. I wonder if backprop would be able to exploit that. (I'm not an ML engineer.)
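A minimal sketch of the suggested rule as a sampling-time heuristic, assuming per-token logprobs are available and using a made-up linear form, p_mask = base_rate * (1 - p_token); all names are hypothetical. MaskGIT, mentioned upthread, uses a deterministic relative of this idea: at each step it keeps its most confident predictions and re-masks the least confident ones according to a schedule.

```python
import math
import random

MASK = "[MASK]"

def remask(tokens, logprobs, n_prompt, base_rate, rng):
    """Turn generated tokens back into [MASK] with a probability that shrinks
    as the model's confidence in them grows; prompt positions are never touched."""
    out = list(tokens)
    for i in range(n_prompt, len(out)):
        confidence = math.exp(logprobs[i])       # probability of the sampled token
        p_mask = base_rate * (1.0 - confidence)  # confident tokens are rarely pruned
        if rng.random() < p_mask:
            out[i] = MASK
    return out

rng = random.Random(0)
tokens = ["Is", "1+1=2?", "No", ",", "it", "is"]
logprobs = [0.0, 0.0, -3.0, -0.5, -0.1, -0.05]
print(remask(tokens, logprobs, n_prompt=2, base_rate=0.9, rng=rng))
```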