["Genie 2: A large-scale foundation model" Parker-Holder et al]
["Generative AI for designing and validating easily synthesizable and structurally novel antibiotics" Swanson et al]
Probabilistic ML has made high-dimensional inference tractable
1024x1024xTime
https://parti.research.google
A portrait photo of a kangaroo wearing an orange hoodie and blue sunglasses standing on the grass in front of the Sydney Opera House holding a sign on the chest that says Welcome Friends!
Data
A PDF that we can optimize
Maximize the likelihood of the data
Maximize the likelihood of the training samples
Parametric Model
Training Samples
Trained Model
Evaluate probabilities
Low Probability
High Probability
Generate Novel Samples
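A minimal sketch of this recipe in JAX, fitting a 1-D Gaussian as the parametric model by maximizing the likelihood of the training samples (all names and values here are illustrative):

```python
import jax
import jax.numpy as jnp

def log_prob(params, x):
    # Log-density of a 1-D Gaussian with learnable mean and log-std
    mu, log_sigma = params
    sigma = jnp.exp(log_sigma)
    return -0.5 * ((x - mu) / sigma) ** 2 - log_sigma - 0.5 * jnp.log(2.0 * jnp.pi)

def loss(params, x):
    # Negative mean log-likelihood of the training samples
    return -jnp.mean(log_prob(params, x))

# Training samples drawn from N(2, 0.5^2)
key = jax.random.PRNGKey(0)
x = 2.0 + 0.5 * jax.random.normal(key, (1024,))

params = jnp.array([0.0, 0.0])  # initial mean 0, log-std 0
grad_fn = jax.jit(jax.grad(loss))
for _ in range(500):
    params = params - 0.1 * grad_fn(params, x)

# The trained model can now evaluate probabilities and generate novel samples
mu_hat, sigma_hat = params[0], jnp.exp(params[1])
new_samples = mu_hat + sigma_hat * jax.random.normal(jax.random.PRNGKey(1), (8,))
```

The fitted mean and standard deviation recover the data-generating values, and sampling from the fitted density produces novel samples never seen in training.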
Simulator
Generative Model
Generative Model
Simulator
GANs
VAEs
Normalizing Flows
Diffusion Models
[Image Credit: https://lilianweng.github.io/posts/2018-10-13-flow-models/]
Base
Data
How is the bridge constrained?
Normalizing flows: Reverse = Forward inverse
Diffusion: Forward = Gaussian noising
Flow Matching: Forward = Interpolant
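For flow matching with a linear interpolant, the forward bridge and the velocity it implies are simple closed forms; a sketch (variable names are illustrative):

```python
import jax
import jax.numpy as jnp

def interpolant(x0, x1, t):
    # Linear interpolant bridging a base sample x0 and a data sample x1
    return (1.0 - t) * x0 + t * x1

def conditional_velocity(x0, x1):
    # Time derivative of the interpolant: the regression target for the network
    return x1 - x0

key = jax.random.PRNGKey(0)
x0 = jax.random.normal(key, (2,))  # base (Gaussian) sample
x1 = jnp.array([3.0, -1.0])        # data sample
xt = interpolant(x0, x1, 0.5)      # a point halfway along the bridge
```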
Is p(x0) restricted?
Diffusion: p(x0) is Gaussian
Normalizing flows: p(x0) can be evaluated
Is the bridge stochastic (SDE) or deterministic (ODE)?
Diffusion: Stochastic (SDE)
Normalizing flows: Deterministic (ODE)
(Exact likelihood evaluation)
sampled from a Gaussian distribution with mean 0 and variance 1
How is the transformed sample distributed?
Base distribution
Target distribution
Invertible transformation
Normalizing flows in 1D
[Image Credit: "Understanding Deep Learning" Simon J.D. Prince]
Bijective
Sample
Evaluate probabilities
Probability mass conserved locally
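The change-of-variables rule behind this local conservation of probability mass can be checked numerically for a simple affine bijection (a sketch; the flow here is hand-picked, not learned):

```python
import jax
import jax.numpy as jnp

def forward(z):
    # Affine bijection x = 2 z + 1
    return 2.0 * z + 1.0

def inverse(x):
    return (x - 1.0) / 2.0

def base_log_prob(z):
    # Standard Gaussian base density
    return -0.5 * z**2 - 0.5 * jnp.log(2.0 * jnp.pi)

def flow_log_prob(x):
    # Change of variables: log p_x(x) = log p_z(f^{-1}(x)) + log |d f^{-1} / dx|
    z = inverse(x)
    log_det = jnp.log(jnp.abs(jax.grad(inverse)(x)))
    return base_log_prob(z) + log_det

x = jnp.array(1.5)
# x = 2 z + 1 with z ~ N(0, 1) means x ~ N(1, 2^2); compare against that density
analytic = -0.5 * ((x - 1.0) / 2.0) ** 2 - jnp.log(2.0) - 0.5 * jnp.log(2.0 * jnp.pi)
```

The flow density and the analytic N(1, 4) density agree, which is exactly what bijectivity plus the Jacobian correction guarantees.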
[Image Credit: "Understanding Deep Learning" Simon J.D. Prince]
Splines
Issues with NFs: lack of flexibility
Neural Network
Sample
Evaluate probabilities
Forward Model
Observable
Dark matter
Dark energy
Inflation
Predict
Infer
Parameters
Inverse mapping
Fault line stress
Plate velocity
Normalizing flow
Continuity Equation
[Image Credit: "Understanding Deep Learning" Simon J.D. Prince]
Chen et al. (2018), Grathwohl et al. (2018)
Generate
Evaluate Probability
Loss requires solving an ODE!
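The ODE in question comes from the continuity equation, in the instantaneous change-of-variables form used by Chen et al. (2018), where $f_\theta$ is the learned vector field $\mathrm{d}\mathbf{x}/\mathrm{d}t = f_\theta(\mathbf{x}(t), t)$:

```latex
\frac{\mathrm{d}\log p(\mathbf{x}(t))}{\mathrm{d}t}
  = -\operatorname{tr}\!\left(\frac{\partial f_\theta}{\partial \mathbf{x}(t)}\right),
\qquad
\log p(\mathbf{x}(t_1)) = \log p(\mathbf{x}(t_0))
  - \int_{t_0}^{t_1} \operatorname{tr}\!\left(\frac{\partial f_\theta}{\partial \mathbf{x}(t)}\right)\mathrm{d}t
```

Evaluating the likelihood therefore means integrating this ODE over the whole trajectory, once per training example.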
Diffusion, flow matching, interpolants... all ways to avoid solving an ODE at training time
Assume a conditional vector field (known at training time)
The loss that we can compute
The gradients of the losses are the same!
["Flow Matching for Generative Modeling" Lipman et al]
["Stochastic Interpolants: A Unifying framework for Flows and Diffusions" Albergo et al]
Intractable
Continuity equation
[Image Credit: "Understanding Deep Learning" Simon J.D. Prince]
Sample
Evaluate probabilities
Reverse diffusion: Denoise previous step
Forward diffusion: Add Gaussian noise (fixed)
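A sketch of the fixed forward noising process in the usual DDPM-style parameterization (the noise schedule here is illustrative):

```python
import jax
import jax.numpy as jnp

def q_sample(x0, t, alpha_bar, key):
    # Sample x_t ~ N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I) in closed form
    eps = jax.random.normal(key, x0.shape)
    return jnp.sqrt(alpha_bar[t]) * x0 + jnp.sqrt(1.0 - alpha_bar[t]) * eps, eps

# Illustrative linear beta schedule with T = 1000 steps
T = 1000
betas = jnp.linspace(1e-4, 0.02, T)
alpha_bar = jnp.cumprod(1.0 - betas)

x0 = jnp.ones((2,))
xt, eps = q_sample(x0, T - 1, alpha_bar, jax.random.PRNGKey(0))
# At the final step alpha_bar is near zero, so x_t is almost pure Gaussian noise
```

Because the forward process is fixed and Gaussian, the reverse (denoising) network can be trained by plain regression on the added noise.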
Prompt
A person half Yoda half Gandalf
Denoising = Regression
Fixed base distribution:
Gaussian
["A point cloud approach to generative modeling for galaxy surveys at the field level"
Cuesta-Lazaro and Mishra-Sharma
International Conference on Machine Learning ICML AI4Astro 2023, Spotlight talk, arXiv:2311.17141]
Base Distribution
Target Distribution
Simulated Galaxy 3D Map
Prompt: A person half Yoda half Gandalf
["CosmoFlow: Scale-Aware Representation Learning for Cosmology with Flow Matching" Kannan et al (in prep)]
Real or Fake?
[Figures: mean relative velocity and k-nearest-neighbour summary statistics, plotted against pair separation]
Varying cosmological parameters
Physics as a testing ground: Well-understood summary statistics enable rigorous validation of generative models
["Generalization in diffusion models arises from geometry-adaptive harmonic representations" Kadkhodaie et al (2024)]
["CosmoFlow: Scale-Aware Representation Learning for Cosmology with Flow Matching" Kannan et al (in prep)]
["CosmoFlow: Scale-Aware Representation Learning for Cosmology with Flow Matching" Kannan et al (in prep)]
Gaussian
MNIST
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        # Linear
        x = nn.Dense(features=64)(x)
        # Non-linearity
        x = nn.silu(x)
        # Linear
        x = nn.Dense(features=64)(x)
        # Non-linearity
        x = nn.silu(x)
        # Linear
        x = nn.Dense(features=2)(x)
        return x

model = MLP()

example_input = jnp.ones((1, 4))
params = model.init(jax.random.PRNGKey(0), example_input)
y = model.apply(params, example_input)
Architecture
Parameters
Call
cuestalz@mit.edu