Conditional Flow Matching (CFM): a simulation-free training objective for continuous normalizing flows. We explore a few flow matching variants and ODE solvers on a simple dataset. This post is a brief overview of the mlx-cfm GitHub repo, which implements simple conditional flow matching with ODE solvers in MLX.

Background
Training: consider a smooth, time-varying vector field $u: [0,1] \times \mathbb{R}^d \to \mathbb{R}^d$ that governs the dynamics of an ordinary differential equation (ODE), $dx = u_t(x)\,dt$. The probability path $p_t(x)$ can be generated by transporting mass along the vector field $u_t(x)$ between distributions over time, following the continuity equation

$$\frac{\partial p_t}{\partial t} = -\nabla \cdot (p_t u_t).$$
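As a quick numerical sanity check of the continuity equation, the sketch below uses a hypothetical 1D path that is not from the repo: a Gaussian $p_t = \mathcal{N}(t, 0.5^2)$ whose mean moves at unit speed, transported by the constant field $u_t(x) = 1$. Central finite differences of $\partial p_t / \partial t$ and $\partial (p_t u_t) / \partial x$ should cancel up to discretization error, while a mismatched field (here, the hypothetical constant 2) should leave a clear residual. It is written in plain Python rather than MLX so it runs anywhere.

```python
import math

STD = 0.5  # hypothetical width of the 1D Gaussian path


def p(t, x):
    # Probability path p_t(x) = N(x; t, STD^2): the mean moves at unit speed.
    return math.exp(-0.5 * ((x - t) / STD) ** 2) / (STD * math.sqrt(2 * math.pi))


def u(t, x):
    # Constant vector field u_t(x) = 1 that transports this path.
    return 1.0


def continuity_residual(t, x, field=u, h=1e-4):
    # dp/dt + d(p * u)/dx via central differences; ~0 when the continuity equation holds.
    dp_dt = (p(t + h, x) - p(t - h, x)) / (2 * h)
    dflux_dx = (p(t, x + h) * field(t, x + h) - p(t, x - h) * field(t, x - h)) / (2 * h)
    return dp_dt + dflux_dx


for t, x in [(0.2, 0.0), (0.5, 0.7), (0.9, 1.3)]:
    good = continuity_residual(t, x)
    bad = continuity_residual(t, x, field=lambda t, x: 2.0)  # wrong field for this path
    print(f"t={t}, x={x}: residual(u=1)={good:.2e}, residual(u=2)={bad:.2e}")
```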
However, the probability path $p_t(x)$ and the vector field $u_t(x)$ are generally intractable in practice. Therefore, we assume the probability path can be expressed as a marginal over latent variables:

$$p_t(x) = \int p_t(x \mid z)\, q(z)\, dz,$$
where $p_t(x \mid z) = \mathcal{N}(x \mid \mu_t(z), \sigma_t(z)^2 I)$ is the conditional probability path and the latent $z$ is sampled from a prior distribution $q(z)$. The dynamics of the conditional probability path are now governed by a conditional vector field $u_t(x \mid z)$. We approximate the marginal vector field with a neural network $v_\theta: [0,1] \times \mathbb{R}^d \to \mathbb{R}^d$ and train it by regressing onto $u_t(x \mid z)$ with the conditional flow matching loss:

$$\mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}_{t,\, q(z),\, p_t(x \mid z)} \left\| v_\theta(t, x) - u_t(x \mid z) \right\|^2,$$

where $t \sim \mathcal{U}(0, 1)$, $z \sim q(z)$, and $x \sim p_t(x \mid z)$. But how do we compute $u_t(x \mid z)$? For a Gaussian conditional probability path, the unique vector field that generates it (Theorem 3; Lipman et al. 2023) is
$$u_t(x \mid z) = \frac{\dot{\sigma}_t(z)}{\sigma_t(z)} \left(x - \mu_t(z)\right) + \dot{\mu}_t(z),$$
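where $\dot{\mu}_t$ and $\dot{\sigma}_t$ are the time derivatives of the mean and standard deviation. If we consider $z \equiv (x_0, x_1)$ and $q(z) = q_0(x_0)\, q_1(x_1)$ with

$$\begin{aligned} \mu_t(z) &= t x_1 + (1 - t) x_0, \\ \sigma_t(z) &= \sigma > 0, \end{aligned}$$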
then we have independent conditional flow matching (Tong et al. 2023) with the resulting conditional probability path and vector field
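$$\begin{aligned} p_t(x \mid z) &= \mathcal{N}\left(x \mid t x_1 + (1 - t) x_0,\ \sigma^2 I\right), \\ u_t(x \mid z) &= x_1 - x_0. \end{aligned}$$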
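The I-CFM training target can be sketched as follows: plug $\mu_t(z)$ and $\sigma_t(z)$ into the general Gaussian vector field, sample $x_t \sim p_t(x \mid z)$, and average the squared error to estimate $\mathcal{L}_{\text{CFM}}$ by Monte Carlo. This is a minimal plain-Python sketch under our own assumptions, not the mlx-cfm implementation: `v` stands in for the network $v_\theta$, vectors are lists of floats, and `sample_q0` / `sample_q1` are hypothetical samplers for $q_0$ and $q_1$.

```python
import random


def gaussian_cond_field(x, mu, sigma, dmu, dsigma):
    # u_t(x|z) = (dsigma / sigma) * (x - mu) + dmu, applied per coordinate (sigma > 0).
    return [dsigma / sigma * (xi - mi) + dmi for xi, mi, dmi in zip(x, mu, dmu)]


def icfm_sample(x0, x1, t, sigma, rng):
    # mu_t(z) = t * x1 + (1 - t) * x0 and sigma_t(z) = sigma (constant, so dsigma = 0).
    mu = [t * b + (1 - t) * a for a, b in zip(x0, x1)]
    dmu = [b - a for a, b in zip(x0, x1)]
    xt = [m + sigma * rng.gauss(0.0, 1.0) for m in mu]  # x_t ~ N(mu_t, sigma^2 I)
    ut = gaussian_cond_field(xt, mu, sigma, dmu, dsigma=0.0)  # reduces to x1 - x0
    return xt, ut


def icfm_loss(v, sample_q0, sample_q1, sigma=0.1, n=1000, seed=0):
    # Monte Carlo estimate of L_CFM with t ~ U(0, 1) and independent x0 ~ q0, x1 ~ q1.
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        t = rng.random()
        x0, x1 = sample_q0(rng), sample_q1(rng)
        xt, ut = icfm_sample(x0, x1, t, sigma, rng)
        total += sum((pred - target) ** 2 for pred, target in zip(v(t, xt), ut))
    return total / n


# Toy usage: hypothetical 2D Gaussians and an untrained "model" that always predicts zero.
q0 = lambda rng: [rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)]
q1 = lambda rng: [rng.gauss(3.0, 0.5), rng.gauss(-1.0, 0.5)]
zero_model = lambda t, x: [0.0, 0.0]
print(icfm_loss(zero_model, q0, q1))
```

In an actual training loop, `v` would be an MLX model and the loss would be minimized over mini-batches with a gradient-based optimizer rather than estimated in a Python loop.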
Alternatively, the variance-preserving stochastic interpolant (Albergo & Vanden-Eijnden 2023) has the form
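$$\begin{aligned} \mu_t(z) &= \cos(\pi t / 2)\, x_0 + \sin(\pi t / 2)\, x_1, \quad \sigma_t(z) = 0, \\ u_t(x \mid z) &= \frac{\pi}{2} \left( \cos(\pi t / 2)\, x_1 - \sin(\pi t / 2)\, x_0 \right). \end{aligned}$$

Since $\sigma_t(z) = 0$, the path is deterministic given $z$ and the vector field is simply $\dot{\mu}_t(z)$.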
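The variance-preserving target can be sketched the same way. Below, `vp_interpolant` (a hypothetical helper, plain Python with vectors as lists of floats) returns $x_t = \mu_t(z)$ and $u_t(x \mid z)$; a central finite difference of $\mu_t$ in $t$ checks that the closed-form velocity is its time derivative, and `vp_variance` estimates $\mathrm{Var}[x_t]$ for independent standard normal $x_0, x_1$, which should stay near 1 for every $t$ because $\cos^2 + \sin^2 = 1$.

```python
import math
import random


def vp_interpolant(x0, x1, t):
    # mu_t(z) = cos(pi t / 2) x0 + sin(pi t / 2) x1, with sigma_t(z) = 0, so x_t = mu_t(z).
    c, s = math.cos(math.pi * t / 2), math.sin(math.pi * t / 2)
    xt = [c * a + s * b for a, b in zip(x0, x1)]
    # u_t(x|z) = d mu_t / dt = (pi / 2) (cos(pi t / 2) x1 - sin(pi t / 2) x0)
    ut = [math.pi / 2 * (c * b - s * a) for a, b in zip(x0, x1)]
    return xt, ut


def finite_diff_velocity(x0, x1, t, h=1e-5):
    # Central difference of mu_t in t, used to check the closed-form u_t(x|z).
    xp, _ = vp_interpolant(x0, x1, t + h)
    xm, _ = vp_interpolant(x0, x1, t - h)
    return [(p - m) / (2 * h) for p, m in zip(xp, xm)]


def vp_variance(t, n=20000, seed=0):
    # Sample variance of x_t for independent scalar x0, x1 ~ N(0, 1).
    rng = random.Random(seed)
    xs = [vp_interpolant([rng.gauss(0, 1)], [rng.gauss(0, 1)], t)[0][0] for _ in range(n)]
    mean = sum(xs) / n
    return sum((x - mean) ** 2 for x in xs) / (n - 1)


for t in (0.0, 0.25, 0.5, 1.0):
    print(f"t={t}: Var[x_t] ~ {vp_variance(t):.3f}")
```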
Sampling: with a trained vector field, we draw $x_0 \sim q_0(x)$ from the prior and integrate the ODE forward with a solver (e.g., fixed-step Euler, or the higher-order, adaptive Dormand–Prince method). The Euler update is

$$x_{t+\Delta} = x_t + v_\theta(t, x_t)\,\Delta,$$

applied with step size $\Delta$ from $t = 0$ to $t = 1$.
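A fixed-step sampler can be sketched in a few lines. `euler_solve` implements the Euler update above; `rk4_solve` is the classic fourth-order Runge-Kutta method, included only to illustrate a higher-order fixed-step solver (it is not the adaptive Dormand–Prince method). Both functions and the toy field $v(t, x) = x$, whose exact solution at $t = 1$ is $e\,x_0$, are our own assumptions; with a trained model, `v` would be the network's forward pass and `x0` a sample from $q_0$.

```python
import math


def euler_solve(v, x0, n_steps=100):
    # Fixed-step Euler: x_{t+dt} = x_t + v(t, x_t) * dt, integrated from t = 0 to t = 1.
    dt = 1.0 / n_steps
    x = list(x0)
    for i in range(n_steps):
        t = i * dt
        x = [xi + dt * vi for xi, vi in zip(x, v(t, x))]
    return x


def rk4_solve(v, x0, n_steps=100):
    # Classic fourth-order Runge-Kutta with the same fixed step size.
    dt = 1.0 / n_steps
    x = list(x0)
    for i in range(n_steps):
        t = i * dt
        k1 = v(t, x)
        k2 = v(t + dt / 2, [xi + dt / 2 * k for xi, k in zip(x, k1)])
        k3 = v(t + dt / 2, [xi + dt / 2 * k for xi, k in zip(x, k2)])
        k4 = v(t + dt, [xi + dt * k for xi, k in zip(x, k3)])
        x = [xi + dt / 6 * (a + 2 * b + 2 * c + d) for xi, a, b, c, d in zip(x, k1, k2, k3, k4)]
    return x


exp_field = lambda t, x: list(x)  # toy field dx/dt = x, exact x(1) = e * x(0)
print(euler_solve(exp_field, [1.0]), rk4_solve(exp_field, [1.0]), math.e)
```

Euler costs one field evaluation per step and RK4 four, but RK4's error shrinks much faster as the step size decreases; adaptive solvers such as Dormand–Prince instead pick the step size from an embedded error estimate.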
Getting Started
Check out the repo on GitHub to recreate the figure at the top of this post! 😄