On Hierarchical Gaussian Processes, Kernel Design, and GP Latent Variable Models

Vidhi Lalchand

Research Highlights

 

13-02-2025

(A compilation of research directions over the last few years and a snapshot of current research)

  • Hierarchical GPs - are they worth the effort?

 

  • Adapting GPs for unsupervised learning. 

 

  • Scientific applications & some new directions.

 

 

 

Outline

"Functions describe the world"

- Thomas Garrity

Gaussian processes are a Bayesian nonparametric paradigm for learning "functions".

  • They are probabilistic \( \rightarrow\) our predictions are distributions over functions.
  • They don't have standard parameters; they model the mapping \( f: \mathcal{X} \longrightarrow \mathbb{R} \) directly by placing a prior over the space of functions.

Gaussian Processes

Gaussian Processes: A generalisation of a Gaussian distribution

A sample from a \(k\)-dimensional Gaussian \( \mathbf{x} \sim \mathcal{N}(\mu, \Sigma) \) is a vector of size \(k\). $$ \mathbf{x} = [x_{1}, \ldots, x_{k}] $$

A GP is an infinite dimensional analogue of a Gaussian distribution (so samples are functions \( \Leftrightarrow \)vectors of infinite length)

$$ f(x) \sim \mathcal{GP}(m(x),k(x, x^{\prime})) $$

The mathematical crux of a GP is that \( \textbf{f} \equiv [f(x_{1}), f(x_{2}), f(x_{3}), \ldots, f(x_{N})]\) is just a draw from an N-dimensional multivariate Gaussian.

$$ \begin{bmatrix} f_{1} \\ \vdots\\ f_{499} \\ f_{500} \end{bmatrix} \sim \mathcal{N}\left(\begin{bmatrix} m(x_{1}) \\ \vdots\\ m(x_{499}) \\ m(x_{500}) \end{bmatrix}, \begin{bmatrix} k(x_{1}, x_{1}) & \ldots & k(x_{1}, x_{500}) \\ \vdots & \ddots & \vdots \\ k(x_{500}, x_{1}) & \ldots & k(x_{500}, x_{500}) \end{bmatrix} \right) $$

At any given point we only need to represent our function \( f(x) \) at a finite index set \( \mathcal{I} = [x_{1},\ldots, x_{500}]\), so we are interested in the 500-dimensional function vector \( [f(x_{1}), f(x_{2}), f(x_{3}), \ldots, f(x_{500})]\).

$$ p(f_{1}, f_{2}, \ldots, f_{N}) = \int\ldots \int p(f_{1}, f_{2}, \ldots)\, \prod_{j \notin \mathcal{I}} df_{j} $$

Marginalising function values corresponding to inputs not in our index set.
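To make the finite-index-set view concrete, here is a minimal sketch (an illustration, not code from the slides) that draws GP prior samples by building the kernel matrix on 500 inputs and sampling the corresponding multivariate Gaussian; the RBF kernel and its hyperparameters are assumed choices.

```python
# A minimal sketch (not from the slides): draw GP "function vectors" by
# evaluating a kernel on a finite index set and sampling the resulting
# multivariate Gaussian. The RBF kernel and its hyperparameters are
# illustrative assumptions.
import numpy as np

def rbf_kernel(x1, x2, lengthscale=1.0, variance=1.0):
    """Squared-exponential kernel k(x, x') = s^2 exp(-(x - x')^2 / (2 l^2))."""
    sqdist = (x1[:, None] - x2[None, :]) ** 2
    return variance * np.exp(-0.5 * sqdist / lengthscale ** 2)

# Finite index set: 500 input locations, matching the 500-dimensional example above.
x = np.linspace(-5, 5, 500)
K = rbf_kernel(x, x) + 1e-6 * np.eye(len(x))   # jitter for numerical stability

# f = [f(x_1), ..., f(x_500)] is one draw from N(0, K).
rng = np.random.default_rng(0)
f_samples = rng.multivariate_normal(mean=np.zeros(len(x)), cov=K, size=3)
print(f_samples.shape)  # (3, 500): three function draws over the index set
```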

Gaussian processes

A powerful, Bayesian, non-parametric paradigm for learning functions.

\(f(x) \sim \mathcal{GP}(m(x), k_{\theta}(x,x^{\prime})) \)

For a finite set of points \( X \):

\(f(X) \sim \mathcal{N}(m(X), K_{X})\)

\( k_{\theta}(x,x^{\prime})\) encodes the support and inductive biases in function space.

GP Math: Canonical GPs

Learning occurs through adapting the hyperparameters of the kernel function by optimising the marginal likelihood.

A central quantity in Bayesian ML is the marginal likelihood.

Learning Step:

 

Hyperparameters \( \textcolor{blue}{\theta} \); parameters \( \textcolor{blue}{p(\bm{f}|\theta) = \mathcal{N}(0, K_{\theta})} \) (the prior); data \( \textcolor{blue}{p(\bm{y}|\bm{f}) = \mathcal{N}(\bm{f}, \sigma^{2}_{n}\mathbb{I})} \) (the data likelihood).

The marginal likelihood (the denominator of Bayes' rule) integrates the data likelihood against the prior:

$$ \begin{aligned} p(\bm{y}|\theta) &= \int p(\bm{y}|\bm{f})p(\bm{f}|\theta)\,d\bm{f}\\ &= \int \mathcal{N}(\bm{y}|\bm{f}, \sigma_{n}^{2}\mathbb{I})\,\mathcal{N}(\bm{f}|\bm{0}, K_{\theta})\,d\bm{f} \\ &= \mathcal{N}(0, K_{\theta} + \sigma^{2}_{n}\mathbb{I}) \end{aligned} $$

$$ \mathcal{L}(\theta) = \log p(\bm{y}|\theta) = -\overbrace{\frac{1}{2}\bm{y}^{T}(K_{\theta} + \sigma^{2}_{n}\mathbb{I})^{-1}\bm{y}}^{\textrm{model fit}} -\overbrace{\frac{1}{2}\log|K_{\theta} + \sigma^{2}_{n}\mathbb{I}|}^{\textrm{complexity penalty}} - \dfrac{n}{2}\log 2\pi $$

$$ \bm{\theta}_{\star} = \argmax_{\theta}\log p(\bm{y}|\theta) $$

Prediction Step: condition on the learnt hyperparameters,

$$ p(\bm{f}|\bm{y}, \theta) = \dfrac{p(\bm{y}|\bm{f})\,p(\bm{f}|\theta)}{p(\bm{y}|\theta)}, \qquad p(\bm{f}_{*}| X_{*}, X, \bm{y}, \theta_{\star}) = \mathcal{N}(\bm{\mu}_{\star}, \Sigma_{\star}) $$
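A hedged sketch of the learning step above: ML-II optimises the kernel hyperparameters by minimising the negative log marginal likelihood with scipy. The RBF kernel, the log-parameterisation, and the toy data are illustrative assumptions, not the slides' exact setup.

```python
# A hedged ML-II sketch: minimise the negative log marginal likelihood
# -log p(y|theta) = 0.5 y^T (K + s_n^2 I)^{-1} y + 0.5 log|K + s_n^2 I| + (n/2) log 2pi.
# The RBF kernel, log-parameterisation and toy data are illustrative assumptions.
import numpy as np
from scipy.optimize import minimize

def rbf_kernel(x1, x2, lengthscale, variance):
    sqdist = (x1[:, None] - x2[None, :]) ** 2
    return variance * np.exp(-0.5 * sqdist / lengthscale ** 2)

def neg_log_marginal_likelihood(log_params, x, y):
    lengthscale, variance, noise = np.exp(log_params)        # positivity via log-parameters
    K = rbf_kernel(x, x, lengthscale, variance) + (noise + 1e-8) * np.eye(len(x))
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))      # (K + s_n^2 I)^{-1} y
    return (0.5 * y @ alpha                                   # model fit
            + np.sum(np.log(np.diag(L)))                      # complexity penalty: 0.5 log|K|
            + 0.5 * len(x) * np.log(2 * np.pi))

rng = np.random.default_rng(0)
x = np.linspace(0, 10, 50)
y = np.sin(x) + 0.1 * rng.standard_normal(50)

res = minimize(neg_log_marginal_likelihood, x0=np.log([1.0, 1.0, 0.1]),
               args=(x, y), method="L-BFGS-B")
print(np.exp(res.x))   # ML-II estimates of (lengthscale, signal variance, noise variance)
```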

Behaviour of ML-II 

Non-identifiability of ML-II under weak data, multiple restarts needed.

ML-II - does it always work?

Weakly identified hyperparameters can manifest as flat ridges in the marginal likelihood surface (where different combinations of hyperparameters give very similar marginal likelihood values), making ML-II estimates subject to high variability.

The problem of ridges in the marginal likelihood surface also does not necessarily go away as more observations are collected.

A fully Bayesian treatment of GP models would integrate away kernel hyperparameters when making predictions:

$$ p(\bm{f}_{*}| X_{*}, X, \bm{y}) = \int p(\bm{f}_{*}| X_{*}, X, \bm{y}, \bm{\theta})\,\textcolor{red}{p(\bm{\theta}|\bm{y})}\,d\bm{\theta} $$

where \( \bm{f}_{*}| X_{*}, X, \bm{y}, \bm{\theta} \sim \mathcal{N}(\bm{\mu_{*},\Sigma_{*}}) \). The posterior over hyperparameters is given by

$$ \boxed{ \textcolor{red}{p(\bm{\theta}|\bm{y})} \propto p(\bm{y}|\bm{\theta})\,p(\bm{\theta}) } $$

where \(p(\bm{y}|\bm{\theta}) \) is the GP marginal likelihood.

Hierarchical GPs

Propagate hyperparameter uncertainty to outputs.

$$ \begin{aligned} &\textcolor{red}{ \bm{\theta} \sim p(\bm{\theta}) } && \textrm{(hyperprior)} \\ &\bm{f}\,|\, X, \bm{\theta} \sim \mathcal{N}(\bm{0}, K_{\theta}) && \textrm{(prior)} \\ &\bm{y}\,|\, \bm{f} \sim \mathcal{N}(\bm{f}, \sigma_{n}^{2}\mathbb{I}) && \textrm{(likelihood)} \\ &\bm{y} = \bm{f} + \bm{\epsilon}, \quad \bm{\epsilon} \sim \mathcal{N}(0, \sigma_{n}^{2}\mathbb{I}) && \textrm{(model)} \end{aligned} $$

The integral over \( \bm{\theta} \) required for prediction is intractable.

Integrating out kernel hyperparameters gives rise to an (approximate) posterior which is a heavy-tailed, non-Gaussian process; this may make it a better choice for certain applications.
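A small sketch of the hierarchical generative model above: hyperparameters are drawn from a hyperprior and functions from the GP prior, so the marginal prior predictive in function space is a mixture of Gaussians (hence non-Gaussian and heavier-tailed). The lognormal hyperprior and its scale are assumed, illustrative choices.

```python
# A hedged sketch of the hierarchical generative model: theta ~ p(theta),
# then f | X, theta ~ N(0, K_theta). The lognormal hyperprior and its scale
# are illustrative assumptions, not the paper's exact choices.
import numpy as np

def rbf_kernel(x, lengthscale, variance):
    sqdist = (x[:, None] - x[None, :]) ** 2
    return variance * np.exp(-0.5 * sqdist / lengthscale ** 2)

rng = np.random.default_rng(1)
x = np.linspace(-3, 3, 200)

prior_draws = []
for _ in range(10):
    # Hyperprior: draw kernel hyperparameters.
    lengthscale = rng.lognormal(mean=0.0, sigma=0.5)
    variance = rng.lognormal(mean=0.0, sigma=0.5)
    # Prior: f | X, theta ~ N(0, K_theta).
    K = rbf_kernel(x, lengthscale, variance) + 1e-6 * np.eye(len(x))
    prior_draws.append(rng.multivariate_normal(np.zeros(len(x)), K))

# Marginally over theta these draws form a mixture of Gaussians in function
# space, i.e. a non-Gaussian, heavier-tailed prior predictive.
print(np.stack(prior_draws).shape)   # (10, 200)
```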

Visualising the prior predictive function space (panels: canonical vs. hierarchical).

We adapt a technique frequently used in the physics and astronomy literature to sample from the hyperparameter posterior.

Nested Sampling (Skilling, 2004) is a gradient-free method for Bayesian computation.

Fergus Simpson*, Vidhi Lalchand*, Carl E. Rasmussen. Marginalised Gaussian processes with Nested Sampling. https://arxiv.org/pdf/2010.16344.pdf

Principle: Sample from "nested shells" / iso-likelihood contours of the evidence and weight them appropriately to give posterior samples.

Nested Sampling: The principle

  • Start with N "live" points \( \{\theta_{1}, \theta_{2}, \ldots, \theta_{N} \} \) sampled from the prior, \(\theta_{i} \sim p(\theta) \), and set \( \mathcal{Z} = 0\).
  • Compute the minimum likelihood \(\mathcal{L}_{i} = \min(\mathcal{L}(\theta_{1}), \ldots, \mathcal{L}(\theta_{N})) \) over the current set of live points and discard the corresponding point \( \theta_{i}\).
  • Sample a new point \( \theta^{\prime}\) from the prior subject to \( \mathcal{L}(\theta^{\prime}) > \mathcal{L}_{i} \).

Define a notion of prior volume, $$ X(\lambda) = \int_{\mathcal{L}(\theta) > \lambda} \pi(\theta)d\theta, $$ the area/volume of parameter space "inside" an iso-likelihood contour. Successive volumes form a shrinking sequence,

$$ 0 < \ldots < X_{i+1} < X_{i} < X_{i-1} < \ldots < 1. $$

One can re-cast the multi-dimensional evidence integral as a 1d function of prior volume \( X\).

$$ \mathcal{Z} = \int \mathcal{L}(\theta)\pi(\theta)d\theta = \int_{0}^{1} \mathcal{L}(X)dX$$

$$ \mathcal{Z} \approx \sum_{i=1}^{M}\mathcal{L_{i}}(X_{i} - X_{i+1})$$

  • Assign the estimated prior mass at this step as \(\hat{X}_{i} = e^{-i/N}\)  # why?
  • Assign the weight for the saved point, \( w_{i} = \hat{X}_{i-1} - \hat{X}_{i}\)  # has to be positive as the enclosed prior volume shrinks at each step

$$ \begin{aligned} X_{i} &\approx \gamma X_{i-1}, \qquad \mathbb{E}(\gamma) = 1 - \dfrac{1}{N+1} = \dfrac{N}{N + 1} \\ X_{i} &\approx \left(\dfrac{N}{N + 1}\right) X_{i-1} \approx \left(\dfrac{N}{N + 1}\right)^{i} X_{0} = \left(\dfrac{N}{N + 1}\right)^{i} \approx e^{-i/N} \end{aligned} $$

Results

Why does a more diffused predictive interval yield better test predictive density?


2d modelling task:

$$ y = (\cos 2 x_1 \times \cos 2x_2) \sqrt{|x_1 x_2|} $$

Time series data: ML-II v. HMC v. Nested

HMC v. Nested (Training runtime)

Hierarchical Non-stationary Gaussian Processes: Gibbs Kernel

Gibbs kernel:

$$ \log(\ell_{d}) \sim \mathcal{GP}(\mu_{\ell}, k_{\ell}), \qquad f \sim \mathcal{GP}(\mu_{f}, k_{f}(\ell(x_{i}),\ell(x_{j}))) $$

Stationary kernels assume fixed inductive biases across the whole input space, where the smoothness and structure properties do not change over the space of covariates/inputs.

Non-stationary kernels allow hyperparameters to be input-dependent; in a hierarchical setting, these hyperparameter functions can in turn be modelled probabilistically.

$$ k_{f}(\bm{x}_{i}, \bm{x}_{j}) = \prod_{d=1}^{D}\sqrt{\dfrac{2\ell_{d}(\bm{x}_{i})\ell_{d}(\bm{x}_{j})}{\ell_{d}^{2}(\bm{x}_{i}) + \ell_{d}^{2}(\bm{x}_{j})}}\exp \left\{ - \sum_{d=1}^{D} \dfrac{(x_{i}^{(d)} - x_{j}^{(d)})^{2}}{\ell_{d}^{2}(\bm{x}_{i}) + \ell_{d}^{2}(\bm{x}_{j})}\right \} $$
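A minimal sketch of the Gibbs kernel above, assuming the per-dimension lengthscale functions \( \ell_d(\cdot) \) are supplied (in the hierarchical model they would be the exponential of a latent GP draw); the example lengthscale function below is purely illustrative.

```python
# A minimal sketch of the Gibbs kernel, assuming the per-dimension lengthscale
# functions ell_d(x) are supplied (e.g. the exponential of a latent GP draw).
import numpy as np

def gibbs_kernel(X1, X2, lengthscale_fn):
    """Non-stationary Gibbs kernel.

    X1: (N, D), X2: (M, D); lengthscale_fn maps (., D) inputs to (., D)
    positive lengthscales, one per input dimension.
    """
    L1 = lengthscale_fn(X1)                                      # (N, D)
    L2 = lengthscale_fn(X2)                                      # (M, D)
    sum_sq = L1[:, None, :] ** 2 + L2[None, :, :] ** 2           # ell_d(x_i)^2 + ell_d(x_j)^2
    prefactor = np.sqrt(2.0 * L1[:, None, :] * L2[None, :, :] / sum_sq)
    sq_dist = (X1[:, None, :] - X2[None, :, :]) ** 2
    return np.prod(prefactor, axis=-1) * np.exp(-np.sum(sq_dist / sum_sq, axis=-1))

# Illustrative input-dependent lengthscale: shorter near the origin.
ell = lambda X: 0.2 + np.abs(X)
X = np.random.default_rng(0).uniform(-2, 2, size=(100, 1))
K = gibbs_kernel(X, X, ell)
print(K.shape, np.min(np.linalg.eigvalsh(K + 1e-8 * np.eye(100))) > 0)   # (100, 100) True
```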

Prior Predictive samples

Hierarchical GP Model: MAP Inference

Classical:

$$ \begin{aligned} &\bm{f}\,|\, X, \bm{\theta} \sim \mathcal{N}(\bm{0}, K_{\theta}) && \textrm{(prior)} \\ &\bm{y}\,|\, \bm{f} \sim \mathcal{N}(\bm{f}, \sigma_{n}^{2}\mathbb{I}) && \textrm{(likelihood)} \\ &\bm{y} = \bm{f} + \bm{\epsilon}, \quad \bm{\epsilon} \sim \mathcal{N}(0, \sigma_{n}^{2}\mathbb{I}) && \textrm{(model)} \end{aligned} $$

$$ \bm{\theta}_{\star} = \argmax_{\bm{\theta}}\log p(\bm{y}|\bm{\theta}) $$

Stationary Hierarchical:

$$ \begin{aligned} &\textcolor{red}{ \bm{\theta} \sim p(\bm{\theta}) } && \textrm{(hyperprior)} \\ &\bm{f}\,|\, X, \bm{\theta} \sim \mathcal{N}(\bm{0}, K_{\theta}) && \textrm{(prior)} \\ &\bm{y}\,|\, \bm{f} \sim \mathcal{N}(\bm{f}, \sigma_{n}^{2}\mathbb{I}) && \textrm{(likelihood)} \end{aligned} $$

$$ \begin{aligned} \theta_{\text{MAP}} &= \argmax_{\theta}\log p(\bm{y}|\bm{\theta})p(\bm{\theta}) = \argmax_{\theta} \log \mathcal{N}(\bm{y}|0, K_{f} + \sigma^{2}_{n}\mathbb{I})\,\mathcal{N}(\bm{\theta}|0, \Sigma) \end{aligned} $$

Non-Stationary Hierarchical:

$$ \begin{aligned} &\textcolor{red}{ \log \bm{\ell}_{d}(\bm{x}) \sim \mathcal{N}(0, K_{\ell}) } && \textrm{(hyperprior)} \\ &\bm{f}\,|\, X, \bm{\theta} \sim \mathcal{N}(\bm{0}, K_{\theta}) && \textrm{(prior)} \\ &\bm{y}\,|\, \bm{f} \sim \mathcal{N}(\bm{f}, \sigma_{n}^{2}\mathbb{I}) && \textrm{(likelihood)} \end{aligned} $$

$$ \begin{aligned} \ell_{\text{MAP}} &= \argmax_{\ell}\log p(\bm{y}|\hat{\ell})p(\hat{\ell}) = \argmax_{\ell} \log \mathcal{N}(\bm{y}|0, K_{f} + \sigma^{2}_{n}\mathbb{I})\,\mathcal{N}(\hat{\ell}|0, K_{\ell}) \end{aligned} $$
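A hedged sketch of MAP inference for the stationary hierarchical variant: the objective adds a log hyperprior to the log marginal likelihood. The Gaussian prior on the log-hyperparameters and the toy data are assumptions for illustration, not the paper's exact setup.

```python
# A hedged sketch of MAP inference for the stationary hierarchical model:
# maximise log p(y|theta) + log p(theta). The Gaussian prior on the
# log-hyperparameters and the toy data are illustrative assumptions.
import numpy as np
from scipy.optimize import minimize
from scipy.stats import norm

def rbf_kernel(x1, x2, lengthscale, variance):
    sqdist = (x1[:, None] - x2[None, :]) ** 2
    return variance * np.exp(-0.5 * sqdist / lengthscale ** 2)

def neg_log_joint(log_params, x, y):
    lengthscale, variance, noise = np.exp(log_params)
    K = rbf_kernel(x, x, lengthscale, variance) + (noise + 1e-8) * np.eye(len(x))
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))
    nll = 0.5 * y @ alpha + np.sum(np.log(np.diag(L))) + 0.5 * len(x) * np.log(2 * np.pi)
    log_prior = norm.logpdf(log_params, loc=0.0, scale=1.0).sum()   # hyperprior p(theta)
    return nll - log_prior                                          # -[log p(y|theta) + log p(theta)]

rng = np.random.default_rng(0)
x = np.linspace(0, 10, 60)
y = np.sin(x) * np.sqrt(np.abs(x)) + 0.1 * rng.standard_normal(60)

theta_map = minimize(neg_log_joint, x0=np.zeros(3), args=(x, y), method="L-BFGS-B")
print(np.exp(theta_map.x))   # MAP estimates of (lengthscale, signal variance, noise variance)
```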

Synthetic Examples: 1d and 2d

Figures: learning a truncated trigonometric function in 2d with an exact GP using the SE-ARD kernel vs. an exact GP using the Gibbs kernel. Panels show the ground-truth function surface, the posterior predictive means and variances under the stationary and non-stationary kernels, the kernel matrix after training vs. the kernel matrix with ground-truth hyperparameters, and the learnt lengthscale processes across the input space.

Modelling Precipitation across the US

  • A significant amount of climate research suggests that climate processes are non-stationary.
  • For instance, the pattern of precipitation, a core target of climate modelling is tightly linked to the underlying terrain.
  • Very often the data is just spatial or spatio-temporal with 3 coordinates (lat, lon, time).

Source: National Oceanic Atmospheric Administration (https://www.ncei.noaa.gov/cdo-web/)

Kernel Identification with Transformers

A framework for performing kernel selection via a transformer deep neural network.

(a) The ability to identify multiple candidate kernels compatible with the data.

(b) It is blazingly fast to evaluate once it is pre-trained.

Training data \( \{x_{n}, y_{n}\}_{n=1}^{N} \) with kernel labels (e.g. RBF*PER + MAT52) are used to pre-train a transformer DNN; at test time the transformer maps a dataset to a distribution over candidate kernels, e.g. RBF (0.43), MAT52 (0.31), RBF + NOISE (0.2), RBF + MAT52 (0.06).
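A hedged sketch (assumed, not KITT's released code) of how such labelled training pairs can be generated: draw functions from GP priors under randomly chosen primitive kernels and record which kernel produced each draw.

```python
# A hedged sketch (assumed, not KITT's released code): generate labelled
# training pairs by sampling functions from GP priors under randomly chosen
# primitive kernels; the kernel that produced each draw is its label.
import numpy as np

def rbf(x, l=1.0):
    return np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / l ** 2)

def periodic(x, l=1.0, p=2.0):
    return np.exp(-2.0 * np.sin(np.pi * np.abs(x[:, None] - x[None, :]) / p) ** 2 / l ** 2)

def matern52(x, l=1.0):
    r = np.abs(x[:, None] - x[None, :]) / l
    return (1 + np.sqrt(5) * r + 5 * r ** 2 / 3) * np.exp(-np.sqrt(5) * r)

PRIMITIVES = {"RBF": rbf, "PER": periodic, "MAT52": matern52}

def sample_training_pair(rng, n=64):
    """Return (x, y, label): y ~ N(0, K_label(x) + jitter) with its kernel label."""
    label = rng.choice(list(PRIMITIVES))
    x = np.sort(rng.uniform(-3, 3, n))
    K = PRIMITIVES[label](x) + 1e-4 * np.eye(n)
    y = rng.multivariate_normal(np.zeros(n), K)
    return x, y, label

rng = np.random.default_rng(0)
x, y, label = sample_training_pair(rng)
print(label, x.shape, y.shape)   # e.g. MAT52 (64,) (64,)
```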


Kernel Identification with Transformers

Left: Classification performance for random samples drawn from primitive kernels across a range of test sizes and dimensionality.

Right: The time taken to predict a kernel for each of the UCI datasets. While KITT's overhead remains approximately constant, the tree search becomes impractical for larger inputs.

  • GPs can be used in unsupervised settings by learning a non-linear, probabilistic mapping from latent space \( \mathbf{X} \) to data-space \( \mathbf{Y} \).
  • We assume the inputs \( \mathbf{X} \) are latent (unobserved).

 

Given: High dimensional training data \( \mathbf{Y} \equiv \{\bm{y}_{n}\}_{n=1}^{N},  Y \in \mathbb{R}^{N \times D}\)

Learn: Low dimensional latent space \( \mathbf{X} \equiv \{\bm{x}_{n}\}_{n=1}^{N}, X \in \mathbb{R}^{N \times Q}\)

Lalchand et al. (2022)

Gaussian Processes for Latent Variable Modelling (at scale)

Vidhi Lalchand, Aditya Ravuri, Neil D. Lawrence. Generalised GPLVM with Stochastic Variational Inference. In International Conference on Artificial Intelligence and Statistics (AISTATS), 2022.

The GPLVM is probabilistic kernel PCA with a non-linear mapping from a low-dimensional latent space \( \mathbf{X}\) to a high-dimensional data space \(\mathbf{Y}\).  

$$ \mathbf{X} \xrightarrow{f}\mathbf{Y}, \qquad f_{d} \sim \mathcal{GP}(m,k(x, x^{\prime})) $$

The Gaussian process mapping (diagram): latent space \( X \in \mathbb{R}^{N \times Q}\), per-dimension GPs \( f_{d} \sim \mathcal{GP}(0,k_{\theta})\), and the high-dimensional data space \( Y \in \mathbb{R}^{N \times D}\) (\( = F + \text{noise} \)); rows of \( \mathbf{Y} \) index the \(N\) data points and columns index the \(D\) output dimensions.

$$ \bm{y}_{:,d} = f_{d}(\mathbf{X}) + \bm{\epsilon}, \qquad \bm{y}_{n,d} = f_{d}(\bm{x}_{n}) + \bm{\epsilon}_{n} $$
GP prior over mappings (per dimension, \(d\)):

$$ p(f_{1:D}|X) = \displaystyle \prod_{d=1}^{D}\mathcal{N}(f_{d}| 0, \mathbf{K}_{\theta}) $$

GP marginal likelihood:

$$ \log \prod_{d=1}^{D}p(\bm{y}_{:,d}|\mathbf{X}) = \mathcal{L}(\mathbf{X}, \theta) = -\frac{DN}{2} \log 2\pi - \frac{D}{2} \log |\mathbf{K}_{\theta}| - \frac{1}{2} \text{tr}(\mathbf{K}_{\theta}^{-1} \mathbf{Y}\mathbf{Y}^{\top}) $$

Optimisation problem:

$$ \hat{\mathbf{X}}, \hat{\theta} = \text{argmax}_{\mathbf{X}, \theta}\, \mathcal{L}(\mathbf{X}, \theta) $$

GPLVM generalises probabilistic PCA: one can recover probabilistic PCA by using a linear kernel.
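A minimal sketch of the optimisation problem above (an illustration, not the paper's implementation): jointly optimise the latents \( \mathbf{X} \) and kernel hyperparameters by maximising \( \mathcal{L}(\mathbf{X}, \theta) \). The RBF kernel and the random placeholder data are assumptions.

```python
# A minimal GPLVM sketch (an illustration, not the paper's implementation):
# jointly optimise the latent coordinates X and the kernel hyperparameters by
# maximising L(X, theta). The RBF kernel and random placeholder data are assumptions.
import numpy as np
from scipy.optimize import minimize

def rbf_kernel(X, lengthscale, variance):
    sqdist = ((X[:, None, :] - X[None, :, :]) ** 2).sum(-1)
    return variance * np.exp(-0.5 * sqdist / lengthscale ** 2)

def neg_log_marginal(params, Y, Q):
    N, D = Y.shape
    X = params[:N * Q].reshape(N, Q)
    lengthscale, variance, noise = np.exp(params[N * Q:])
    K = rbf_kernel(X, lengthscale, variance) + (noise + 1e-8) * np.eye(N)
    L = np.linalg.cholesky(K)
    logdet = 2.0 * np.sum(np.log(np.diag(L)))                 # log |K|
    Kinv_Y = np.linalg.solve(L.T, np.linalg.solve(L, Y))
    # -L(X, theta) = (DN/2) log 2pi + (D/2) log|K| + 0.5 tr(K^{-1} Y Y^T)
    return 0.5 * D * N * np.log(2 * np.pi) + 0.5 * D * logdet + 0.5 * np.sum(Y * Kinv_Y)

rng = np.random.default_rng(0)
N, D, Q = 40, 12, 2
Y = rng.standard_normal((N, D))                               # placeholder data
x0 = np.concatenate([0.1 * rng.standard_normal(N * Q), np.log([1.0, 1.0, 0.1])])
res = minimize(neg_log_marginal, x0, args=(Y, Q), method="L-BFGS-B")
X_hat = res.x[:N * Q].reshape(N, Q)                           # learnt latent coordinates
print(X_hat.shape)   # (40, 2)
```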

Gaussian Processes for Latent Variable Modelling (at scale)

Lalchand et al. (2022)


Stochastic Variational Formulation

The exact marginal likelihood above is prohibitive to optimise at scale, so we work with a stochastic variational lower bound built on inducing variables.

Prior on latents:

$$ p(\mathbf{X}) = \prod_{n=1}^{N}\mathcal{N}(\bm{x}_{n}|\textbf{0}, \mathbb{I}_{Q}) $$

Prior on inducing variables:

$$ p(\textbf{U}) = \prod_{d=1}^{D}p(\bm{u}_{d}|Z) = \prod_{d=1}^{D}\mathcal{N}(\bm{u}_{d}|\textbf{0}, K_{mm}) $$

ELBO (with \( \bm{y}_{n,d} = f_{d}(\bm{x}_{n}) + \bm{\epsilon}_{n} \)):

$$ \log p(Y|\bm{\theta}) \geq \mathcal{L}_{1:D} = \sum_{d=1}^{D}\mathcal{L}_{d} = \sum_{d=1}^{D}\sum_{n=1}^{N} \mathbb{E}_{q(f, \textbf{X}, \textbf{U})}[\log p(y_{n,d}|\bm{f}_{d}, \bm{x}_{n})] - \sum_{d}\textrm{KL}(q(\bm{u}_{d})||p(\bm{u}_{d})) - \sum_{n}\textrm{KL}(q(\bm{x}_{n})||p(\bm{x}_{n})) $$

Bayesian GPLVM: Automatic relevance determination (ARD)


In the plot: Showing the two dimensions corresponding to the highest inverse lengthscales.

In the plot: Inverse lengthscales for each latent dimension. Both modes of inference switch off 7 out of 10 dimensions.

Data: Multiphase oil-flow data consisting of 1,000 12-dimensional observations belonging to three known classes corresponding to different phases of oil flow.

Decoder only 

Encoder-Decoder 

$$ p(\mathbf{X}|\mathbf{Y}) \approx q(\mathbf{X}) = \prod_{n=1}^{N}\mathcal{N}\big( \bm{x}_{n}\,|\, G_{\phi_{1}}(\bm{y}_{n}),\, H_{\phi_{2}}(\bm{y}_{n}) H_{\phi_{2}}(\bm{y}_{n})^{T}\big) $$

GPLVMs are typically decoder-only models, but they can be amortised with an encoder network that learns the parameters of the variational distribution.
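A hedged sketch of such an amortised encoder: two small networks map each \( \bm{y}_n \) to the mean \( G_{\phi_1}(\bm{y}_n) \) and a covariance factor \( H_{\phi_2}(\bm{y}_n) \) of \( q(\bm{x}_n) \). The PyTorch layer sizes and activations are illustrative assumptions; the paper's architecture may differ.

```python
# A hedged sketch of an amortised encoder: two small networks map each y_n to
# the mean G(y_n) and a covariance factor H(y_n) of q(x_n). Layer sizes and
# activations are illustrative assumptions; the paper's architecture may differ.
import torch
import torch.nn as nn

class AmortisedEncoder(nn.Module):
    def __init__(self, data_dim, latent_dim, hidden=64):
        super().__init__()
        self.mean_net = nn.Sequential(                 # G_phi1: y_n -> mean of q(x_n)
            nn.Linear(data_dim, hidden), nn.Tanh(), nn.Linear(hidden, latent_dim))
        self.factor_net = nn.Sequential(               # H_phi2: y_n -> covariance factor
            nn.Linear(data_dim, hidden), nn.Tanh(),
            nn.Linear(hidden, latent_dim * latent_dim))
        self.latent_dim = latent_dim

    def forward(self, Y):
        mu = self.mean_net(Y)                                               # (N, Q)
        H = self.factor_net(Y).view(-1, self.latent_dim, self.latent_dim)   # (N, Q, Q)
        cov = H @ H.transpose(-1, -2) + 1e-5 * torch.eye(self.latent_dim)   # H H^T + jitter
        return mu, cov

encoder = AmortisedEncoder(data_dim=12, latent_dim=2)
Y = torch.randn(100, 12)
mu, cov = encoder(Y)
print(mu.shape, cov.shape)   # torch.Size([100, 2]) torch.Size([100, 2, 2])
```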

$$ k_{f}(\bm{x}, \bm{x}^{\prime}) = \exp\left\{-\sum_{q=1}^{Q}\frac{(\bm{x}_{q} - \bm{x}^{\prime}_{q})^{2}}{2{\color{blue}{l_{q}^{2}}}}\right\} $$

Robust to missing dimensions in training data (missing at random)

Vidhi Lalchand, Aditya Ravuri, Neil D. Lawrence. Generalised GPLVM with Stochastic Variational Inference. In International Conference on Artificial Intelligence and Statistics
(AISTATS), 2022

Figure: training data are MNIST images with 30% or 60% of pixels masked, shown alongside their reconstructions.

Note: this is different to tasks where missing data is only introduced at test time.

Learning to reconstruct dynamical data 

Learning interpretable latent dimensions in biological data

Vidhi Lalchand*, Aditya Ravuri*, Emma Dann*, Natsuhiko Kumasaka, Dinithi Sumanaweera, Rik G.H. Lindeboom, Shaista Madad, Neil D. Lawrence, Sarah A. Teichmann. Modelling Technical and Biological Effects in single-cell RNA-seq data with Scalable Gaussian Process Latent Variable Models (GPLVM). In Machine Learning in Computational Biology (MLCB), 2022

COVID Data: Gene expression profiles of peripheral blood mononuclear cells (PBMC) from a cohort of 107 patients.

N = 600,000 cells

Data dimension D = 5000

Latent dimension Q = 10

One of the latent dimensions, when averaged, captures gradations of COVID severity across patients.

The model is able to find structure in the gene expression profiles.

In the plot: Schematic of the GPLVM model learning a mapping from a low-dimensional latent space to gene expression profiles per cell.

Generative model for small molecules + Property prediction

How do we use generative modelling to identify molecules which optimise a property of interest?

Gómez-Bombarelli R, et al. Automatic chemical design using a data-driven continuous representation of molecules. ACS central science. 2018. (ChemicalVAE)

Generative model for small molecules + Property prediction

Gaussian process property decoders on the generative latent space

Recurrent VAE + Gaussian process decoders

Overall loss:

$$ \mathcal{L}_{\text{V-RNN}} + \mathcal{L}_{\text{GP}} $$

$$ k_{i}(\bm{z}, \bm{z}^{\prime}) = {\color{blue}{\sigma^{2}}}\exp\left\{-\sum_{q=1}^{Q}\frac{(\bm{z}_{q} - \bm{z}^{\prime}_{q})^{2}}{2{\color{blue}{\ell_{q}^{2}}}}\right\} $$

Q is the dimensionality of the latent space. The learnt lengthscales in each dimension indicate the influence of that dimension on the prediction.
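A hedged sketch of the joint objective: a GP negative log marginal likelihood over latent codes (using the ARD kernel above) added to a placeholder recurrent-VAE loss, so that gradients from the property term shape the latent space. The names, shapes, zero mean, and PyTorch usage are assumptions for illustration.

```python
# A hedged sketch of the joint objective: a GP negative log marginal likelihood
# over latent codes (RBF-ARD kernel on z) added to a placeholder recurrent-VAE
# loss, so gradients from the property term shape the latent space. All names,
# shapes and the zero mean are illustrative assumptions.
import math
import torch

def gp_nll(z, y, lengthscales, signal_var, noise_var):
    """Negative log marginal likelihood of properties y given latent codes z."""
    zs = z / lengthscales                               # ARD: one lengthscale per latent dim
    K = signal_var * torch.exp(-0.5 * torch.cdist(zs, zs) ** 2)
    K = K + noise_var * torch.eye(len(z))
    L = torch.linalg.cholesky(K)
    alpha = torch.cholesky_solve(y.unsqueeze(-1), L).squeeze(-1)
    return (0.5 * y @ alpha + torch.log(torch.diag(L)).sum()
            + 0.5 * len(z) * math.log(2 * math.pi))

z = torch.randn(128, 8, requires_grad=True)             # latent codes from the VAE encoder
qed = torch.rand(128)                                   # property targets, e.g. QED scores
loss_vrnn = torch.tensor(0.0)                           # placeholder for the V-RNN loss
loss = loss_vrnn + gp_nll(z, qed, lengthscales=torch.ones(8),
                          signal_var=torch.tensor(1.0), noise_var=torch.tensor(0.1))
loss.backward()                                         # gradients flow back into z
print(float(loss))
```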

Joint Model

Evolution of a 2d subspace of the latent space

Points shaded by actual QED (drug-likeness) scores

Can we do optimisations in a low dimensional subspace of the latent space and map back to the full latent space?

We need extremely flexible inductive biases that adapt to the data. With GPs you can simply write down a new kernel function \(\longrightarrow\) it is much easier to explore non-stationary choices and parameterise them in interesting ways.

Hierarchical GPs also give rise to models with fat-tailed predictive posteriors; this may be relevant in applications where we know our data to be heavy-tailed.

Hierarchical GPs are worth it only if you are in a weak-data regime or in regimes with very high epistemic and aleatoric uncertainty. In these settings, a hierarchical GP will be much more robust than a canonical GP.

Summary, TLDR & Thanks

GPs can play the role of regularisers of the latent space in conditional generative models; the strength of the regularisation can be controlled by the choice of the kernel function.

email: vidrl@mit.edu / vr308@cam.ac.uk

twitter: @VRLalchand

 Nested Sampling: The principle

Define a notion of prior volume, $$ X(\lambda) = \int_{\mathcal{L}(\theta) > \lambda} \pi(\theta)d\theta$$

The area/volume of parameter space "inside" an iso-likelihood contour

One can re-cast the multi-dimensional evidence integral as a 1d function of prior volume \( X\).

$$ \mathcal{Z} = \int \mathcal{L}(\theta)\pi(\theta)d\theta = \int_{0}^{1} \mathcal{L}(X)dX$$

The evidence can then just be estimated by 1d quadrature.

$$ \mathcal{Z} \approx \sum_{i=1}^{M}\mathcal{L_{i}}(X_{i} - X_{i+1})$$

Static Nested Sampling (Skilling, 2004)

Start with N "live" points \( \{\theta_{1}, \theta_{2}, \ldots, \theta_{N} \} \) sampled from the prior, \(\theta_{i} \sim p(\theta) \), and set \( \mathcal{Z} = 0\).

for \( i = 1, \ldots, K\) (while the stopping criterion is unmet) do

  • Compute the minimum likelihood \(\mathcal{L}_{i} = \min(\mathcal{L}(\theta_{1}), \ldots, \mathcal{L}(\theta_{N})) \) over the current set of live points.
  • Add the point \( \theta_{i}\) associated with the lowest likelihood \( \mathcal{L}_{i}\) to a list of "saved" points.
  • Sample a new point \( \theta^{\prime}\) from the prior subject to \( \mathcal{L}(\theta^{\prime}) > \mathcal{L}_{i} \)  \( \rightarrow \) hard problem
  • Assign the estimated prior mass at this step as \(\hat{X}_{i} = e^{-i/N}\)  # why? (see Prior Mass Estimation below)
  • Assign the weight for the saved point, \( w_{i} = \hat{X}_{i-1} - \hat{X}_{i}\)  # has to be positive as the enclosed prior volume shrinks at each step
  • Accumulate evidence, \( \mathcal{Z} = \mathcal{Z} + \mathcal{L}_{i}w_{i}\)
  • Evaluate the stopping criterion; if triggered then break

end

Add the final N live points to the "saved" list with:

  • Each remaining weight \( w_{N} = \hat{X}_{K}/N\)  # the final slab of enclosed prior mass
  • Final evidence given by \( \mathcal{Z} = \sum_{i=1}^{N+K}\mathcal{L}_{i}w_{i} \)
  • Importance weights for each sample given by \( p_{i} = \mathcal{L}_{i}w_{i}/\mathcal{Z}\)

return the set of saved points \(\{ \theta_{i}\}_{i=1}^{N + K} \), along with the importance weights \( \{p_{i} \}_{i=1}^{N + K}\) and the evidence \( \mathcal{Z} \)
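A minimal sketch of the static nested-sampling loop above, on a toy 2-d problem with a uniform prior and a narrow Gaussian likelihood; for simplicity, new points are drawn by rejection from the constrained prior (the "hard problem" flagged above), which is acceptable at this toy scale. The likelihood, prior, and settings are illustrative assumptions.

```python
# A minimal sketch of the static nested-sampling loop on a toy 2-d problem with
# a uniform prior on [-1, 1]^2 and a narrow Gaussian likelihood. New points are
# drawn by rejection from the constrained prior (the "hard problem" above),
# which is acceptable at this toy scale.
import numpy as np

def log_likelihood(theta):
    return -0.5 * np.sum((theta / 0.1) ** 2)            # Gaussian bump, sigma = 0.1

def sample_prior(rng, size=None):
    return rng.uniform(-1, 1, size=(size, 2) if size else 2)

rng = np.random.default_rng(0)
N, K = 200, 1000                                        # live points, iterations
live = sample_prior(rng, N)
live_logl = np.array([log_likelihood(t) for t in live])
saved_logl, weights = [], []
X_prev = 1.0                                            # total prior volume

for i in range(1, K + 1):
    worst = np.argmin(live_logl)                        # lowest-likelihood live point
    L_i = live_logl[worst]
    X_i = np.exp(-i / N)                                # estimated remaining prior volume
    saved_logl.append(L_i)
    weights.append(X_prev - X_i)                        # positive: the volume is shrinking
    X_prev = X_i
    while True:                                         # replace the discarded point,
        theta_new = sample_prior(rng)                   # subject to L(theta') > L_i
        if log_likelihood(theta_new) > L_i:
            break
    live[worst], live_logl[worst] = theta_new, log_likelihood(theta_new)

# Add the final live points, each with weight X_K / N, then accumulate the evidence.
saved_logl += list(live_logl)
weights += [X_prev / N] * N
Z = np.sum(np.exp(saved_logl) * np.array(weights))
print(f"Z estimate: {Z:.4f}  analytic: {2 * np.pi * 0.1**2 / 4:.4f}")   # ~0.0157
```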

Prior Mass Estimation 

Why is \( X_{i} \) set to \( e^{-i/N}\), where \(i\) is the iteration number and \(N\) is the number of live points?

$$ 0 < \ldots < X_{i+1} < X_{i} < X_{i-1} < \ldots < 1 $$

$$ \begin{aligned} X_{i} &\approx \gamma X_{i-1}, \qquad \mathbb{E}(\gamma) = 1 - \dfrac{1}{N+1} = \dfrac{N}{N + 1} \\ X_{i} &\approx \left(\dfrac{N}{N + 1}\right) X_{i-1} \approx \left(\dfrac{N}{N + 1}\right)^{i} X_{0} = \left(\dfrac{N}{N + 1}\right)^{i} \approx e^{-i/N} \end{aligned} $$
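A quick numerical check of the shrinkage argument: the volume ratio \( \gamma = X_i / X_{i-1} \) behaves like the largest of N uniform variates, whose mean is \( N/(N+1) \).

```python
# A quick Monte Carlo check of the shrinkage argument: the volume ratio
# gamma = X_i / X_{i-1} behaves like the largest of N uniform variates,
# whose expectation is N / (N + 1).
import numpy as np

rng = np.random.default_rng(0)
N, trials = 50, 100_000
gamma = rng.uniform(size=(trials, N)).max(axis=1)   # largest of N uniforms
print(gamma.mean(), N / (N + 1))                    # ~0.9804 vs 0.98039...
```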

(images from An Intro to dynamic nested sampling. Speagle (2017) )

Sampling from the constrained prior

Dynamic Nested Sampling. Speagle (2019)

At every step we need a sample \( \theta^{\prime}\) such that \( \mathcal{L}(\theta^{\prime}) > \mathcal{L}_{i}\).

We could keep sampling uniformly from the prior and rejecting until we find a point that meets the likelihood condition, but this takes too long when the likelihood threshold is high; there is a better way.

What if we could sample directly from the constrained prior?

$$ p_{\lambda}(\theta) = \begin{cases} p(\theta) / X_{\lambda}, & \mathcal{L}(\theta) > \lambda \\ 0, & \mathcal{L}(\theta) \leq \lambda \end{cases} $$
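A small illustration of why naive rejection sampling from \( p_{\lambda}(\theta) \) becomes expensive: the acceptance rate equals the enclosed prior volume \( X(\lambda) \), which shrinks rapidly as the threshold \( \lambda \) rises. The toy likelihood and prior below are assumptions.

```python
# A small illustration of why naive rejection from the constrained prior gets
# expensive: the acceptance rate equals the enclosed prior volume X(lambda),
# which shrinks rapidly as the likelihood threshold rises. The toy Gaussian
# likelihood and uniform prior are illustrative assumptions.
import numpy as np

rng = np.random.default_rng(0)
log_likelihood = lambda theta: -0.5 * np.sum(theta ** 2 / 0.1 ** 2, axis=-1)

theta = rng.uniform(-1, 1, size=(1_000_000, 2))      # prior draws on [-1, 1]^2
logl = log_likelihood(theta)
for threshold in [-50.0, -10.0, -2.0, -0.5]:
    accept = np.mean(logl > threshold)               # fraction with L(theta) > lambda
    print(f"log-threshold {threshold:6.1f}: acceptance rate {accept:.4f}")
```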

New Directions

  • Regression to model functions with inhomogeneous smoothness properties: building non-stationarity into kernels.
  • Probabilistic latent variable models + neural encoders.
  • Geometry: latent variables may live on a non-Euclidean manifold?
  • Building custom generative AI for scientific applications. 

Neural encoder (parametric)

Gaussian process (non-parametric)

GPLVM: Generative Model with Sparse Gaussian processes

$$ \textbf{F} \equiv \{ f_{d} \}_{d=1}^{D}, \qquad \textbf{U} \equiv \{ \bm{u}_{d} \}_{d=1}^{D} $$

Prior over latents:

$$ p(\mathbf{X}) = \prod_{n=1}^{N}\mathcal{N}(\bm{x}_{n}|\textbf{0}, \mathbb{I}_{Q}) $$

Prior over inducing variables:

$$ p(\textbf{U}) = \prod_{d=1}^{D}p(\bm{u}_{d}|Z) = \prod_{d=1}^{D}\mathcal{N}(\bm{u}_{d}|\textbf{0}, K_{mm}) $$

Conditional prior:

$$ p(\bm{f}_{d}|\bm{u}_{d}, \mathbf{X}) = \mathcal{N}(\bm{f}_{d}\,|\, K_{nm}K_{mm}^{-1}\bm{u}_{d},\, Q_{nn}), \quad \textrm{where } Q_{nn} = K_{nn} - K_{nm}K_{mm}^{-1}K_{mn} $$

Data likelihood:

$$ p(\bm{y}_{:,d}|\bm{f}_{d}) = \mathcal{N}(\bm{f}_{d}, \sigma_{n}^{2}\mathbb{I}) $$

Stochastic Variational Evidence Lower bound 

Stochastic variational inference for GP regression was introduced in Hensman et al. (2013) (\( \mathcal{L}_{1}\) below). In this work we extend the SVI bound in two ways: we introduce a variational distribution over the unknown \(X\) and make \(Y\) multi-output.

Variational Formulation:

$$ p(F, X, U |Y) \approx q(F, X, U) = \Big[\prod_{d=1}^{D}p(\bm{f}_{d}|\bm{u}_{d},X)\,q(\bm{u}_{d}) \Big]\,q(X), \qquad q(X) =\prod_{n=1}^{N}\mathcal{N}(\bm{x}_{n}; \mu_{n}, s_{n}\mathbb{I}_{Q}) $$

$$ \log p(Y|\theta) = \log \int p(Y|X)p(X)dX \geq \int q(X) \left(\mathcal{L}_{1} + \log p(X) - \log q(X)\right)dX $$

ELBO:

$$ \log p(Y|\bm{\theta}) \geq \mathcal{L}_{1:D} = \sum_{d=1}^{D}\mathcal{L}_{d} = \sum_{d=1}^{D}\sum_{n=1}^{N} \mathbb{E}_{q(f, \textbf{X}, \textbf{U})}[\log p(y_{n,d}|\bm{f}_{d}, \bm{x}_{n})] - \sum_{d}\textrm{KL}(q(\bm{u}_{d})||p(\bm{u}_{d})) - \sum_{n}\textrm{KL}(q(\bm{x}_{n})||p(\bm{x}_{n})) $$

This formulation accommodates non-Gaussian likelihoods, flexible variational families, amortised inference, and missing-data problems.
