Conditional Flow Matching (CFM): a simulation-free training objective for continuous normalizing flows. We explore a few flow matching variants and ODE solvers on a simple dataset. This repo was inspired by and adapted from the awesome work in TorchCFM and Torchdyn. 😄
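Training: consider a smooth time-varying vector field $u: [0, 1] \times \mathbb{R}^d \to \mathbb{R}^d$ that governs the dynamics of an ordinary differential equation (ODE), $dx = u_t(x)\, dt$. The probability path $p_t$ can be generated by transporting mass¹ along the vector field from $p_0$ to $p_1$ over time, following the continuity equation

$$\frac{\partial p_t(x)}{\partial t} = -\nabla \cdot \big(p_t(x)\, u_t(x)\big)$$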
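As a quick sanity check of the continuity equation $\partial_t p_t = -\nabla \cdot (p_t u_t)$, here is a minimal NumPy sketch (not part of this repo) on a toy 1D case: a constant velocity field $u_t(x) = c$ transports $p_0 = \mathcal{N}(0, 1)$ into $p_t = \mathcal{N}(ct, 1)$, so $\partial_t p_t + \partial_x (p_t u_t)$ should vanish up to finite-difference error. The helper names and the values of `c` and `h` are illustrative assumptions.

```python
import numpy as np

def gaussian_pdf(x, mean, std):
    """Density of N(mean, std^2) evaluated at x."""
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))

def continuity_residual(x, t, c=1.5, std=1.0, h=1e-4):
    """Central-difference estimate of dp/dt + d(p u)/dx for a translating Gaussian.

    The constant field u_t(x) = c moves p_0 = N(0, std^2) to p_t = N(c t, std^2),
    so the residual should vanish up to finite-difference error.
    """
    def p(tt, xx):
        return gaussian_pdf(xx, c * tt, std)

    def flux(tt, xx):
        return p(tt, xx) * c  # p_t(x) * u_t(x) with u_t(x) = c

    dp_dt = (p(t + h, x) - p(t - h, x)) / (2.0 * h)
    dflux_dx = (flux(t, x + h) - flux(t, x - h)) / (2.0 * h)
    return dp_dt + dflux_dx

xs = np.linspace(-4.0, 6.0, 201)
print(np.max(np.abs(continuity_residual(xs, t=0.5))))  # near zero (finite-difference error only)
```

The same check fails if the density and the field disagree, e.g. a Gaussian moving at speed $c$ paired with a field of speed $c/2$.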
However, the probability path $p_t$ and the vector field $u_t$ are intractable in practice. Therefore, we assume the probability path can be expressed as a marginal over latent variables:

$$p_t(x) = \int p_t(x \mid z)\, q(z)\, dz$$
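where $p_t(x \mid z)$ is the conditional probability path, with a latent $z$ sampled from a prior distribution $q(z)$. The dynamics of the conditional probability path are now governed by a conditional vector field $u_t(x \mid z)$. We approximate this using a neural network that parameterizes the time-dependent vector field $v_\theta(t, x)$, and train the network by regressing the conditional flow matching loss:

$$\mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}_{t \sim \mathcal{U}[0, 1],\, z \sim q(z),\, x \sim p_t(x \mid z)} \left\| v_\theta(t, x) - u_t(x \mid z) \right\|^2$$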
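such that $\nabla_\theta \mathcal{L}_{\text{CFM}}(\theta) = \nabla_\theta \mathcal{L}_{\text{FM}}(\theta)$, i.e., it has the same gradients as the intractable flow matching loss $\mathcal{L}_{\text{FM}}(\theta) = \mathbb{E}_{t,\, x \sim p_t(x)} \left\| v_\theta(t, x) - u_t(x) \right\|^2$. But how do we compute $u_t(x \mid z)$? Well, assuming a Gaussian conditional probability path $p_t(x \mid z) = \mathcal{N}\big(x \mid \mu_t(z), \sigma_t(z)^2 I\big)$, we have a unique vector field (Theorem 3; Lipman et al. 2023) given by

$$u_t(x \mid z) = \frac{\sigma_t'(z)}{\sigma_t(z)} \big(x - \mu_t(z)\big) + \mu_t'(z)$$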
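where $\mu_t'$ and $\sigma_t'$ are the time derivatives of the mean and standard deviation. If we consider $z = (x_0, x_1)$ and $q(z) = q(x_0)\, q(x_1)$ with

$$\mu_t(z) = t\, x_1 + (1 - t)\, x_0, \qquad \sigma_t(z) = \sigma,$$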
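then we have independent conditional flow matching (I-CFM; Tong et al. 2023) with the resulting conditional probability path and vector field

$$p_t(x \mid z) = \mathcal{N}\big(x \mid t\, x_1 + (1 - t)\, x_0,\ \sigma^2 I\big), \qquad u_t(x \mid z) = x_1 - x_0.$$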
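A training step for independent CFM can be sketched in NumPy as follows. This is an illustrative sketch, not the repo's cfm.py (which builds on MLX): `sample_icfm`, `cfm_loss`, and the `zero_field` stand-in for $v_\theta$ are hypothetical names. Pairing row $i$ of `x0` with row $i$ of `x1` gives the independent coupling $q(z) = q(x_0)\, q(x_1)$ as long as the two batches are drawn independently; $t$ is drawn uniformly, $x_t \sim \mathcal{N}(t x_1 + (1 - t) x_0, \sigma^2 I)$, and the regression target is $x_1 - x_0$.

```python
import numpy as np

def sample_icfm(x0, x1, sigma=0.1, rng=None):
    """Draw (t, x_t, u_t) for independent CFM.

    p_t(x | z) = N(t x1 + (1 - t) x0, sigma^2 I) and u_t(x | z) = x1 - x0.
    """
    rng = np.random.default_rng() if rng is None else rng
    t = rng.uniform(size=(x0.shape[0], 1))              # t ~ U[0, 1), one per pair
    mu_t = t * x1 + (1.0 - t) * x0                      # mean of the conditional path
    x_t = mu_t + sigma * rng.standard_normal(x0.shape)  # sample from p_t(x | z)
    u_t = x1 - x0                                       # sigma_t' = 0, so only mu_t' remains
    return t, x_t, u_t

def cfm_loss(v_theta, t, x_t, u_t):
    """Monte Carlo estimate of E || v_theta(t, x_t) - u_t(x_t | z) ||^2."""
    return np.mean(np.sum((v_theta(t, x_t) - u_t) ** 2, axis=-1))

# Usage with toy data: standard normal prior, shifted Gaussian target.
rng = np.random.default_rng(0)
x0 = rng.standard_normal((512, 2))
x1 = rng.standard_normal((512, 2)) + np.array([4.0, 0.0])
t, x_t, u_t = sample_icfm(x0, x1, sigma=0.1, rng=rng)
zero_field = lambda t, x: np.zeros_like(x)  # untrained placeholder for v_theta
print(cfm_loss(zero_field, t, x_t, u_t))
```

Because $\sigma_t' = 0$ for this path, the target $x_1 - x_0$ depends on neither $t$ nor $x_t$; $\sigma$ only smooths the conditional path around the straight line from $x_0$ to $x_1$.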
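Alternatively, the variance-preserving stochastic interpolant (Albergo & Vanden-Eijnden 2023) has the form

$$p_t(x \mid z) = \mathcal{N}\big(x \mid \cos(\tfrac{\pi}{2} t)\, x_0 + \sin(\tfrac{\pi}{2} t)\, x_1,\ \sigma^2 I\big), \qquad u_t(x \mid z) = \frac{\pi}{2} \big(\cos(\tfrac{\pi}{2} t)\, x_1 - \sin(\tfrac{\pi}{2} t)\, x_0\big).$$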
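The variance-preserving interpolant can be sampled the same way. This NumPy sketch uses a hypothetical helper `sample_vp` (not the repo's API) with mean $\cos(\tfrac{\pi}{2}t)\, x_0 + \sin(\tfrac{\pi}{2}t)\, x_1$ and target $\tfrac{\pi}{2}\big(\cos(\tfrac{\pi}{2}t)\, x_1 - \sin(\tfrac{\pi}{2}t)\, x_0\big)$, i.e. the time derivative of that mean; the optional constant noise level `sigma` is an assumption, and `sigma=0.0` gives the deterministic interpolant.

```python
import numpy as np

def sample_vp(x0, x1, sigma=0.0, rng=None):
    """Draw (t, x_t, u_t) for the variance-preserving (trigonometric) interpolant.

    mu_t = cos(pi t / 2) x0 + sin(pi t / 2) x1, and
    u_t = d mu_t / dt = (pi / 2) (cos(pi t / 2) x1 - sin(pi t / 2) x0).
    """
    rng = np.random.default_rng() if rng is None else rng
    t = rng.uniform(size=(x0.shape[0], 1))
    a, b = np.cos(0.5 * np.pi * t), np.sin(0.5 * np.pi * t)
    x_t = a * x0 + b * x1 + sigma * rng.standard_normal(x0.shape)
    u_t = 0.5 * np.pi * (a * x1 - b * x0)
    return t, x_t, u_t

# Usage: independent unit-variance endpoints.
rng = np.random.default_rng(0)
x0 = rng.standard_normal((4096, 2))
x1 = rng.standard_normal((4096, 2))
t, x_t, u_t = sample_vp(x0, x1, rng=rng)
print(x_t.var(axis=0))  # close to 1 per coordinate
```

The name comes from $\cos^2 + \sin^2 = 1$: with independent, equal-variance endpoints the interpolant keeps that variance at every $t$, whereas the linear mean $t x_1 + (1 - t) x_0$ shrinks it to half at $t = 0.5$.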
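Sampling: now that we have our vector field, we can sample from our prior $x_0 \sim q(x_0)$ and run a forward ODE solver (e.g., fixed-step Euler or the higher-order, adaptive Dormand–Prince method), generally defined by

$$x_{t + \Delta t} = x_t + \int_t^{t + \Delta t} v_\theta(s, x_s)\, ds,$$

for steps between $t = 0$ and $t = 1$. Each solver approximates the integral differently; Euler, for instance, uses $\Delta t\, v_\theta(t, x_t)$.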
Run with default params and save the result in media/*.png:
python main.py --method vp --solver dopri5
- main.py: training and sampling
- models.py: neural net definition
- datasets.py: generate prior and target data
- cfm.py: flow matching variants
- odeint.py: adaptive and fixed numerical integrators
- solver.py: solver definition for integrator

Install the dependencies (optimized for Apple silicon; yay for MLX!):
pip install -r requirements.txt
¹ This is a very important thing!