Score-based diffusions explained in just one paragraph
Posted on June 19, 2022
If you work in machine learning, then you will have noticed that score-based diffusion models are busy taking over the world: most notably through impressive projects like DALL·E 2 and Imagen. Correspondingly, the internet has become awash with how-tos and explainer posts on how score-based diffusions work. (e.g. one/two/three etc.)
Now these posts are generally pretty complicated, and I haven’t seen much intuition offered on score-based diffusions. So, it’s time to throw my hat in the ring. Here’s a one-paragraph explanation of diffusion models – without proofs – which I’ve not seen emphasised in the blog-o-sphere before.
Score-based diffusions
Consider the dataset (of images or whatever) as an empirical probability distribution $\nu$. Consider the simplest-possible SDE $y(0) \sim \nu; \mathrm{d}y(t) = \mathrm{d}w(t)$. Run this SDE for long enough and the data will have been corrupted into noise. Generation is performed by reversing this process, to turn noise back into data. Reversing the SDE produces the ODE $\mathrm{d}y(t) = -\frac{1}{2}\nabla_y \log p_t(y(t))\,\mathrm{d}t$, where $p_t$ is the marginal density of $y(t)$, to be solved backward in time. Now here's the crucial bit: because time runs backward, each step moves $y$ in the direction of $+\nabla_y \log p_t$, so this is just gradient descent to optimise a log-likelihood! (A good sample is one which maximises log-likelihood.) Finally, training means finding a model $s_\theta(t, y) \approx \nabla_y \log p_t(y)$ that can be used to evaluate this ODE.
That’s it.
That’s the explanation.
Job done.
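To make the paragraph concrete, here is a minimal NumPy sketch (NumPy rather than JAX, and not taken from the author's JAX implementation; the toy dataset, step counts and function names are all hypothetical). It relies on one fact: for an empirical $\nu$ with points $x_i$, the marginal of $\mathrm{d}y = \mathrm{d}w$ is the Gaussian mixture $p_t = \frac{1}{N}\sum_i \mathcal{N}(x_i, tI)$, so the score is available in closed form and no training is needed. The sketch corrupts data into noise with the forward SDE, then solves the ODE backward in time with explicit Euler.

```python
import numpy as np

def score(t, y, data):
    """Exact score grad_y log p_t(y), where p_t = mean_i N(x_i, t I) is the law of y(t).

    y: (M, d) query points; data: (N, d) points x_i of the empirical distribution nu.
    """
    diff = data[None, :, :] - y[:, None, :]          # (M, N, d), entries x_i - y
    logits = -np.sum(diff**2, axis=-1) / (2 * t)     # log N(y; x_i, t I) up to a constant
    logits -= logits.max(axis=1, keepdims=True)      # stabilise the softmax
    weights = np.exp(logits)
    weights /= weights.sum(axis=1, keepdims=True)    # posterior probability that y came from x_i
    return np.sum(weights[:, :, None] * diff, axis=1) / t

def forward_sde(y0, T, n_steps, rng):
    """Euler-Maruyama for dy = dw: data -> noise (exact in law, as there is no drift)."""
    y = y0.copy()
    dt = T / n_steps
    for _ in range(n_steps):
        y += np.sqrt(dt) * rng.standard_normal(y.shape)
    return y

def reverse_ode(yT, T, t_min, n_steps, data):
    """Explicit Euler for dy = -1/2 score dt, integrated backward from t = T to t = t_min."""
    ts = np.geomspace(T, t_min, n_steps + 1)         # geometric grid: finer steps as t -> 0
    y = yT.copy()
    for t, t_next in zip(ts[:-1], ts[1:]):
        # t_next - t < 0, so this step moves y up the gradient of log p_t.
        y += -0.5 * score(t, y, data) * (t_next - t)
    return y

rng = np.random.default_rng(0)
data = np.array([[-2.0, 0.0], [2.0, 0.0], [0.0, 3.0]])       # a toy 2D "dataset"
y0 = data[rng.integers(len(data), size=1000)]                 # 1000 draws from nu
yT = forward_sde(y0, T=25.0, n_steps=100, rng=rng)            # data corrupted into noise
samples = reverse_ode(yT, T=25.0, t_min=1e-3, n_steps=500, data=data)
dist = np.linalg.norm(samples[:, None, :] - data[None, :, :], axis=-1).min(axis=1)
print(np.mean(dist < 0.5))   # fraction of generated samples that landed next to a data point
```

In an actual model, `score` would be replaced by the learned $s_\theta$. The stopping time `t_min` and the geometric time grid are design choices forced by the score growing like $1/t$ as $t \to 0$, where uniform Euler steps would become unstable.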
Various footnotes.
Yes, I’m aware that many SOTA models use the (overcomplicated) discrete-time formulation rather than the (much simpler) continuous-time formulation I’ve used here.
There are “stronger” ways to reverse an SDE. The above ODE only matches marginals, but it's also possible to reverse SDEs pathwise, i.e. for almost all Brownian paths. I actually don't have a self-contained reference for this fact, but it's immediate in the rough path theoretic formulation of SDEs. (The correct way to do stochastic calculus, I think.) The simplest and closest thing I know of is Theorem 5.10 of my thesis, which shows how to backpropagate through SDE solvers.
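A toy numerical illustration of pathwise reversal, under loudly stated assumptions: additive noise, the Euler-Maruyama discretisation, and a hypothetical drift $f(y) = \sin(y)$ chosen only to make the example non-trivial. This is not the rough-path argument nor the thesis result. The sketch solves $\mathrm{d}y = f(y)\,\mathrm{d}t + \mathrm{d}w$ forward, then, reusing the same Brownian increments, undoes each step to recover the initial condition of that particular path (up to floating-point error). For the drift-free SDE of the main paragraph this reduces to $y(0) = y(T) - w(T)$. The ODE above, by contrast, only reproduces the distribution of $y(0)$, not any individual sample.

```python
import numpy as np

def drift(y):
    # Hypothetical drift for illustration only (Lipschitz constant 1).
    return np.sin(y)

def em_forward(y0, dws, dt):
    """Euler-Maruyama for dy = drift(y) dt + dw, driven by the given Brownian increments."""
    y = y0
    for dw in dws:
        y = y + drift(y) * dt + dw
    return y

def em_reverse(yT, dws, dt, n_iter=50):
    """Undo each Euler-Maruyama step, last to first, reusing the same Brownian increments.

    Each step y_next = y + drift(y) dt + dw is inverted by the fixed-point iteration
    y <- y_next - drift(y) dt - dw, which contracts because dt * Lip(drift) < 1.
    """
    y = yT
    for dw in dws[::-1]:
        y_next = y
        for _ in range(n_iter):
            y = y_next - drift(y) * dt - dw
    return y

rng = np.random.default_rng(0)
dt, n_steps = 0.01, 100
y0 = rng.standard_normal(5)
dws = np.sqrt(dt) * rng.standard_normal((n_steps, 5))   # one Brownian path per component
yT = em_forward(y0, dws, dt)
y0_recovered = em_reverse(yT, dws, dt)
print(np.max(np.abs(y0_recovered - y0)))                 # round-off sized, not sampling-error sized
```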
A while ago I wrote this JAX implementation of score-based diffusions. For those of you getting started on this topic, this is the shortest and simplest implementation of score-based diffusions that I've seen anywhere.
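As a rough illustration of the training step (a NumPy sketch, not the linked JAX code; the objective is an assumed standard choice rather than something stated in this post): denoising score matching perturbs data to $y = x + \sqrt{t}\,z$, which is exactly the marginal of $\mathrm{d}y = \mathrm{d}w$ at time $t$, and regresses $s_\theta(t, y)$ onto $-z/\sqrt{t}$. The best possible regressor is $\mathbb{E}[-z/\sqrt{t} \mid y] = \nabla_y \log p_t(y)$. To keep the result checkable, the hypothetical dataset is drawn from $\mathcal{N}(m, \sigma^2)$, so $p_t = \mathcal{N}(m, \sigma^2 + t)$ and the true score $(m - y)/(\sigma^2 + t)$ is linear. The "model" is therefore a linear function fitted separately at each $t$ by least squares, standing in for a neural network trained by gradient descent.

```python
import numpy as np

def dsm_fit_linear(x, t, rng):
    """Denoising score matching at a single time t, with a linear model s(y) = a*y + b.

    Perturb data as y = x + sqrt(t) z and regress onto the target -z / sqrt(t).
    The least-squares minimiser is E[-z/sqrt(t) | y] = grad_y log p_t(y).
    """
    z = rng.standard_normal(x.shape)
    y = x + np.sqrt(t) * z
    target = -z / np.sqrt(t)
    features = np.stack([y, np.ones_like(y)], axis=1)     # columns: y, 1
    (a, b), *_ = np.linalg.lstsq(features, target, rcond=None)
    return a, b

rng = np.random.default_rng(0)
m, sigma = 2.0, 1.0
x = m + sigma * rng.standard_normal(200_000)               # hypothetical Gaussian "dataset"
fits = {t: dsm_fit_linear(x, t, rng) for t in (0.1, 1.0, 5.0)}
for t, (a, b) in fits.items():
    # Exact score for Gaussian data: (m - y) / (sigma^2 + t), i.e. a = -1/(sigma^2 + t), b = m/(sigma^2 + t).
    print(f"t={t}: fitted a={a:.3f}, b={b:.3f}; exact a={-1 / (sigma**2 + t):.3f}, b={m / (sigma**2 + t):.3f}")
```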
In supervised learning, we don't tend to use SGD. We use Adam. This is because SGD is just an ODE solver (the explicit Euler method, plus minibatch noise) applied to the gradient flow equation, but we don't care about finding a solution to the gradient flow equation. We just want the final steady state (which will be a minimum of our loss function), so we can use specialised steady-state-finding approaches instead, such as Adam. Perhaps the same is true here! So, research proposal: can we “solve” the inference-time ODE via Adam, and in doing so get more efficient sampling?
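The claim that gradient descent is explicit Euler applied to the gradient flow $\mathrm{d}\theta/\mathrm{d}t = -\nabla L(\theta)$ can be checked directly. The sketch below assumes a hypothetical quadratic loss, uses full-batch gradients (so there is no minibatch noise), and writes Adam out by hand with its usual default hyperparameters $\beta_1 = 0.9$, $\beta_2 = 0.999$. Gradient descent with learning rate $h$ and explicit Euler with step $h$ produce identical iterates, while Adam makes no attempt to follow the flow and simply heads for the steady state.

```python
import numpy as np

A = np.diag([1.0, 10.0])               # hypothetical ill-conditioned quadratic loss
b = np.array([1.0, -2.0])
theta_star = np.linalg.solve(A, b)     # the steady state, i.e. the minimiser

def grad(theta):
    # Gradient of L(theta) = 0.5 theta^T A theta - b^T theta.
    return A @ theta - b

def gradient_descent(theta, lr, n_steps):
    for _ in range(n_steps):
        theta = theta - lr * grad(theta)
    return theta

def euler_gradient_flow(theta, h, n_steps):
    # Explicit Euler for the ODE d(theta)/dt = -grad L(theta).
    for _ in range(n_steps):
        dtheta_dt = -grad(theta)
        theta = theta + h * dtheta_dt
    return theta

def adam(theta, lr, n_steps, beta1=0.9, beta2=0.999, eps=1e-8):
    # Standard Adam update with bias-corrected first and second moment estimates.
    m = np.zeros_like(theta)
    v = np.zeros_like(theta)
    for k in range(1, n_steps + 1):
        g = grad(theta)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g**2
        m_hat = m / (1 - beta1**k)
        v_hat = v / (1 - beta2**k)
        theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta

theta0 = np.zeros(2)
gd = gradient_descent(theta0, lr=0.05, n_steps=200)
euler = euler_gradient_flow(theta0, h=0.05, n_steps=200)
ad = adam(theta0, lr=0.01, n_steps=5000)
print(gd, euler, ad, theta_star)
```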
This line of thought is a can of worms. Besides Adam, you can draw all kinds of links to Langevin dynamics, the use of MCMC/HMC samplers, steady-state-finding as a nonlinear optimisation problem, and so on. Much of the older literature did exactly this, so I'm a bit mystified why the current literature has focused on either the explicit Euler method (looking at you, ugly discrete-time formulation) or on developing novel ODE/SDE solvers (which are surely less efficient).
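Of these links, Langevin dynamics is the most direct: the unadjusted Langevin algorithm (ULA) samples from $p$ using nothing but its score, via $y \leftarrow y + h\,\nabla_y \log p(y) + \sqrt{2h}\,z$ with $z \sim \mathcal{N}(0, 1)$, i.e. a gradient-ascent step on $\log p$ plus injected noise. The sketch below is a hypothetical 1D example with a Gaussian target whose score is known exactly (in a diffusion model it would be replaced by the learned $s_\theta$); the step size and chain count are arbitrary choices.

```python
import numpy as np

def score(y, m=1.0, s=0.5):
    # Score of a hypothetical 1D Gaussian target N(m, s^2): d/dy log p(y).
    return -(y - m) / s**2

def ula(y, h, n_steps, rng):
    """Unadjusted Langevin algorithm: a gradient-ascent step on log p plus Gaussian noise."""
    for _ in range(n_steps):
        y = y + h * score(y) + np.sqrt(2 * h) * rng.standard_normal(y.shape)
    return y

rng = np.random.default_rng(0)
y0 = 5.0 * rng.standard_normal(10_000)     # 10k independent chains, started far from the target
samples = ula(y0, h=0.01, n_steps=2_000, rng=rng)
# Compare with m = 1.0 and s**2 = 0.25 (ULA slightly inflates the variance, by O(h)).
print(samples.mean(), samples.var())
```

Because it omits a Metropolis correction, ULA's stationary distribution is slightly off from $p$: for this Gaussian target the variance becomes $s^2 / (1 - h/(2s^2))$. Adding the correction gives MALA (the Metropolis-adjusted Langevin algorithm), one of the MCMC samplers alluded to above.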