Comment by elchananHaas
3 days ago
First, I think this is really cool. It's great to see novel generative architectures.
Here are my thoughts on the statistics behind this. Let D be a data sample, and start with the standard generative-model objective, the expectation of -Log[P(D)].
We then condition on the model output at step N:
= - expectation of Log[Sum over model outputs at step N {P(D | model output at step N) * P(model output at step N)}]
Every term in this sum is non-negative, so keeping only the term for the output that the argmins actually select (the sampled model output at step N) gives an upper bound. This is the Jensen's-inequality (ELBO) bound with all of the posterior weight on that one output:
<= - expectation of Log[P(D | sampled model output at step N) * P(sampled model output at step N)]
Apply the log product-to-sum rule
= - expectation of {Log(P(D | sampled model output at step N)) + Log(P(sampled model output at step N))}
If we assume D is the model output plus isotropic Gaussian noise with variance σ², then -Log(P(D | output)) is the squared L2 distance divided by 2σ², plus a constant that does not depend on the model, so the first term becomes the standard L2 objective. Dropping that constant, and writing "L2 distance" for the scaled squared distance:
= expectation of {L2 distance(D, sampled model output at step N) - Log(P(sampled model output at step N))}
Apply linearity of expectation
= expectation of {L2 distance(D, sampled model output at step N)} - expectation of {Log(P(sampled model output at step N))}
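A quick numerical sanity check of the bound behind this part of the derivation, -Log P(D) <= -Log P(D | o*) - Log P(o*) for the argmin-selected output o*, which is what the final expression above reduces to for a single data point. This is a minimal sketch assuming a toy model with three fixed 2-D outputs, a known prior over them, and isotropic Gaussian noise with standard deviation `sigma`; `gaussian_nll`, `exact_nll`, and `single_term_bound` are hypothetical helpers, not part of any DDN code.

```python
import math


def gaussian_nll(d, o, sigma):
    """-Log P(D | o), assuming D = o + isotropic Gaussian noise with std sigma."""
    sq_dist = sum((di - oi) ** 2 for di, oi in zip(d, o))
    # Scaled squared L2 distance plus a constant that does not depend on o.
    return sq_dist / (2 * sigma**2) + 0.5 * len(d) * math.log(2 * math.pi * sigma**2)


def logsumexp(xs):
    """Numerically stable Log(Sum(exp(x)))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def exact_nll(d, outputs, priors, sigma):
    """-Log Sum over outputs {P(D | output) * P(output)}."""
    return -logsumexp(
        [-gaussian_nll(d, o, sigma) + math.log(p) for o, p in zip(outputs, priors)]
    )


def single_term_bound(d, outputs, priors, sigma):
    """Keep only the argmin (closest in L2) output o*: -Log P(D | o*) - Log P(o*)."""
    i = min(range(len(outputs)), key=lambda j: gaussian_nll(d, outputs[j], sigma))
    return gaussian_nll(d, outputs[i], sigma) - math.log(priors[i])


outputs = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
priors = [1 / 3] * 3
d = (0.9, 0.1)
print(exact_nll(d, outputs, priors, 0.5), single_term_bound(d, outputs, priors, 0.5))
```

With a uniform prior over K outputs, the argmin output is also the most likely one under isotropic Gaussian noise, so the gap between the two quantities is between 0 and Log(K).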
Now focus on just the - expectation of Log(P(sampled model output at step N)) term:
= - expectation of Log[P(model output at step N)]
and condition on the previous step to get
= - expectation of Log[Sum over possible samples at step N - 1 of (P(sample at step N | sample at step N - 1) * P(sample at step N - 1))]
Now, each P(sample at step T | sample at step T - 1) is approximately equal to 1/K. This is enforced by the Split-and-Prune operations, which try to keep each output selected at roughly equal frequencies.
So this is approximately equal to
≃ - expectation of Log[Sum over possible samples at N-1 of (1/K * P(possible sample at step N - 1))]
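Since every term in the sum is non-negative, you get an upper bound by keeping only the term for the actual sample at step N - 1 (the one that produced the sample at step N).
<= - expectation of Log[1/K * P(actual sample at step N - 1)]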
And applying some log rules you get
= Log(K) - expectation of Log[P(sample at step N - 1)]
Now, you have (approximately) expectation of -Log[P(sample at step N)] <= Log(K) - expectation of Log[P(sample at step N - 1)]. You can repeatedly apply this transformation until step 0 to get
(approximately) expectation of -Log[P(sample at step N)] <= N * Log(K) - expectation of Log[P(sample at step 0)]
and, assuming the step-0 input is fixed so that P(sample at step 0) = 1 and its log is 0, we get
expectation of -Log[P(sample at step N)] <= N * Log(K)
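The recursion can be checked on a toy K-ary tree. This is a minimal sketch assuming each step picks one of K outputs exactly uniformly (the ideal the Split-and-Prune operations aim for) and that the step-0 input is fixed; `leaf_log_prob` is a hypothetical helper that applies Log P(step t) = Log(1/K) + Log P(step t - 1) once per step.

```python
import itertools
import math


def leaf_log_prob(path, K):
    """Log P(sample at step N) for a path of N choices, each uniform over K outputs.

    Applies Log P(step t) = Log(1/K) + Log P(step t - 1) once per step,
    starting from Log P(step 0) = Log(1) = 0 (fixed step-0 input, assumed).
    """
    log_p = 0.0
    for choice in path:
        if not 0 <= choice < K:
            raise ValueError("choice out of range")
        log_p += math.log(1.0 / K)  # one uniform choice among K outputs
    return log_p


K, N = 3, 4
paths = list(itertools.product(range(K), repeat=N))  # all K**N outputs at step N
nlls = [-leaf_log_prob(p, K) for p in paths]
print(len(paths), max(nlls), N * math.log(K))  # every NLL equals N * Log(K), up to rounding
```

Under these assumptions the bound on the prior term is tight: every step-N output has -Log P = N * Log(K), and the probabilities of all K**N outputs sum to 1.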
Plugging this back into the main objective, we get (assuming the Split-and-Prune is perfect)
expectation of -Log[P(D)] <= expectation of {L2 distance(D, sampled model output at step N)} + N * Log(K)
And this makes sense. You are providing the model with an additional Log_2(K) bits of information every time you perform an argmin operation, so in total you have provided the model with N * Log_2(K) bits of information (the N * Log(K) above is the same quantity measured in nats). However, this term is constant, so the gradient-based optimizer can ignore it.
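The full bound, including the greedy layer-by-layer argmin, can be checked on a toy 1-D hierarchy. This is a minimal sketch under loudly stated assumptions: `toy_layer_outputs` is a hypothetical stand-in for a DDN layer (K candidates spread around the current output, with a spacing that halves at each depth), generation picks each layer's output uniformly, the step-0 input is 0, and D is the chosen output plus Gaussian noise. `exact_nll` enumerates all K**N paths, so it is only feasible for tiny N and K.

```python
import math


def toy_layer_outputs(x, n, K):
    """Hypothetical layer at depth n: K candidates around x, spacing 2**-n."""
    step = 2.0 ** (-n)
    return [x + step * (k - (K - 1) / 2) for k in range(K)]


def greedy_argmin_output(d, N, K):
    """Pick the closest candidate at each depth (the argmin), starting from x = 0."""
    x = 0.0
    for n in range(1, N + 1):
        x = min(toy_layer_outputs(x, n, K), key=lambda c: (c - d) ** 2)
    return x


def all_outputs(N, K):
    """Every step-N output, one per path of N choices (K**N of them)."""
    xs = [0.0]
    for n in range(1, N + 1):
        xs = [c for x in xs for c in toy_layer_outputs(x, n, K)]
    return xs


def gaussian_nll_1d(d, o, sigma):
    return (d - o) ** 2 / (2 * sigma**2) + 0.5 * math.log(2 * math.pi * sigma**2)


def exact_nll(d, N, K, sigma):
    """-Log P(D) when every path has prior probability K**-N."""
    terms = [-gaussian_nll_1d(d, o, sigma) - N * math.log(K) for o in all_outputs(N, K)]
    m = max(terms)
    return -(m + math.log(sum(math.exp(t - m) for t in terms)))


def upper_bound(d, N, K, sigma):
    """L2 term for the greedy argmin output, plus N * Log(K) nats."""
    return gaussian_nll_1d(d, greedy_argmin_output(d, N, K), sigma) + N * math.log(K)


N, K, sigma = 4, 2, 0.1
print(exact_nll(0.3, N, K, sigma), upper_bound(0.3, N, K, sigma))
print("bits provided by the argmins:", N * math.log2(K))
```

The greedy output is always the endpoint of some path, so the bound holds whether or not greedy selection finds the globally closest output.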
So, given this analysis, my conclusions are:
1) The Split-and-Prune is a load-bearing component of the architecture with regard to its statistical correctness. I'm not entirely sure how this fits with the gradient-based optimizer. Is it working with the gradient-based optimizer, fighting it, or somewhere in the middle? I think the answer to this question will strongly affect this approach's scalability. It will also take a more in-depth analysis to study how deviations from perfect splitting affect the upper bound on the loss.
2) With regard to statistical correctness, the L2 distance between the output at step N and D is the only one that matters. The L2 losses in the middle layers can be considered auxiliary losses. Maybe the final L2 loss, or the L2 losses deeper in the model, should be weighted more heavily? In final evaluation, the intermediate L2 losses can be ignored.
3) Future possibilities could include some sort of RL to determine the number of samples K and the depth N dynamically. Even a split with K=2 increases the NLL by Log_2(2) = 1 bit. For many samples, beyond a certain depth the increase in loss from the additional information outweighs the decrease in L2 loss. This also points to another difficulty: it is hard to give fractional bits of information in this Discrete Distribution Network architecture. In contrast, diffusion models and autoregressive models can handle fractional bits. This could be another direction for future development.
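The depth trade-off in point 3 can be made concrete: pick the depth n that minimizes the expected L2 term at depth n plus n * Log(K). This is a minimal sketch; `best_depth` is a hypothetical helper, the per-depth L2 values are made-up numbers rather than measurements, and all quantities are in nats (divide by Log(2) for bits).

```python
import math


def best_depth(l2_nats_by_depth, K):
    """Return the depth n minimizing E[L2 term at depth n] + n * Log(K), and all totals.

    l2_nats_by_depth[n] is the expected -Log P(D | output at depth n) term in nats,
    with index 0 meaning depth 0 (no argmins yet).
    """
    totals = [l2 + n * math.log(K) for n, l2 in enumerate(l2_nats_by_depth)]
    n_best = min(range(len(totals)), key=lambda n: totals[n])
    return n_best, totals


# Made-up numbers: each extra layer helps less than the one before it.
n_best, totals = best_depth([10.0, 8.0, 6.5, 6.0, 5.8], K=2)
print(n_best)  # → 2
```

Going from depth 2 to depth 3 lowers the L2 term by 0.5 nats but costs Log(2) ≈ 0.69 nats (1 bit), so the deeper model scores worse here.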
A thought on why the intermediate L2 losses are important: in the early layers there is little information, so the L2 loss will be high and the images blurry. In much deeper layers, the information from the argmins will dominate and there will be little left to learn. The intermediate L2 losses help by providing a good training signal when some information about the target is known but large unknowns remain.
The model can be thought of as N Discrete Distribution Networks, one for each depth from 1 to N, stacked on each other and trained simultaneously.
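Viewing the model as N stacked DDNs trained together suggests writing the training loss as a (possibly weighted) sum of per-depth L2 losses. This is a minimal sketch assuming a squared L2 loss per depth; `combined_l2_loss` and its `weights` argument are hypothetical, and putting all the weight on the last depth keeps only the step-N term that appears in the NLL bound.

```python
def squared_l2(a, b):
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def combined_l2_loss(d, outputs_by_depth, weights=None):
    """Weighted sum of squared L2 losses, one per depth 1..N.

    outputs_by_depth[n - 1] is the argmin-selected output at depth n.
    weights defaults to 1 for every depth (all N stacked DDNs count equally).
    """
    if weights is None:
        weights = [1.0] * len(outputs_by_depth)
    if len(weights) != len(outputs_by_depth):
        raise ValueError("need one weight per depth")
    return sum(w * squared_l2(d, o) for w, o in zip(weights, outputs_by_depth))


d = [1.0, 0.0]
outs = [[0.0, 0.0], [0.5, 0.0], [1.0, 0.0]]  # outputs getting closer with depth
print(combined_l2_loss(d, outs))                   # → 1.25
print(combined_l2_loss(d, outs, [0.0, 0.0, 1.0]))  # → 0.0 (final depth only)
```

Intermediate weights between these two extremes correspond to the "weight deeper losses more heavily" idea in point 2.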
One more concern I noticed: this generative approach needs each layer not only to select each output with uniform probability, but to do so regardless of its input.
Here is the bad case I am concerned about:
Layer 1 -> (A, B) Layer 2 -> (C, D)
Let's say Layer 1 outputs A and B each with probability 1/2 (a perfect split). Now, Layer 2 outputs C whenever it gets A as input and D whenever it gets B. Marginally, Layer 2 outputs C and D each with probability 1/2, but conditioned on the output of Layer 1 it is deterministic rather than 1/2 each.
If this happens, the claim that diversity increases exponentially with each layer breaks down: two layers with K = 2 yield only 2 distinct outcomes instead of 4.
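The bad case can be made concrete by tallying which choice each layer's argmin picks. This is a minimal sketch: `selection_stats` is a hypothetical helper, choices are encoded as 0/1 (A/B at layer 1, C/D at layer 2), and both lists of selected paths are constructed by hand to match the example rather than produced by a real model.

```python
import math
from collections import Counter


def selection_stats(paths, K):
    """paths: (layer-1 choice, layer-2 choice) pairs selected by the argmins.

    Returns the marginal counts of the layer-2 choice, the layer-2 counts
    conditioned on each layer-1 choice, the number of distinct paths used,
    and the empirical entropy of the path in bits.
    """
    n = len(paths)
    marginal2 = Counter(p[1] for p in paths)
    conditional2 = {a: Counter(p[1] for p in paths if p[0] == a) for a in range(K)}
    joint = Counter(paths)
    entropy_bits = -sum(c / n * math.log2(c / n) for c in joint.values())
    return marginal2, conditional2, len(joint), entropy_bits


# Bad case: A (0) always leads to C (0), and B (1) always leads to D (1).
bad = [(0, 0), (1, 1)] * 50
# Healthy case: every layer-2 choice is equally likely after either layer-1 choice.
healthy = [(a, b) for a in range(2) for b in range(2)] * 25

for name, paths in [("bad", bad), ("healthy", healthy)]:
    marginal2, conditional2, distinct, bits = selection_stats(paths, 2)
    print(name, dict(marginal2), {a: dict(c) for a, c in conditional2.items()}, distinct, bits)
```

Both cases look perfectly balanced marginally at layer 2; only the conditional counts and the path entropy (1 bit instead of N * Log_2(K) = 2 bits) reveal the collapse.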
It could be that the first-order approximation provided by Split-and-Prune is good enough. My guess, though, is that the gradient and Split-and-Prune are helping each other keep the outputs reasonably balanced on the datasets you are working on. Split-and-Prune lets the optimization process "tunnel" through regions of the loss landscape that would otherwise make it hard to balance the classes.