Siddhant Shah

Stochastic Variational Inference (SVI) with Pyro

Published on Sunday, 24 September 2023

This is an in-depth explanation of a tutorial available on the Pyro library website, which was adapted from Chapter 7 of the excellent book Statistical Rethinking by Richard McElreath. It draws on the paper Ruggedness: The blessing of bad geography in Africa. The goal of this exercise is to explore the relationship between a nation's topographic heterogeneity, as measured by the Terrain Ruggedness Index, and its GDP per capita.


Abstract of Ruggedness: The blessing of bad geography in Africa

We show that geography, through its impact on history, can have important effects on current economic development. The analysis focuses on the historic interaction between ruggedness and Africa’s slave trades. Although rugged terrain hinders trade and most productive activities, negatively affecting income globally, within Africa rugged terrain afforded protection to those being raided during the slave trades. Since the slave trades retarded subsequent economic development, within Africa ruggedness has also had a historic indirect positive effect on income. Studying all countries worldwide, we estimate the differential effect of ruggedness on income for Africa. We show that:

Below, we model these data in Pyro.

Set Up

Importing the Required Libraries

%reset -s -f

import logging
import os

import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import pyro

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.8.6')

logging.basicConfig(format='%(message)s', level=logging.INFO)

# Set matplotlib settings
%matplotlib inline
plt.style.use('default')

import pyro.distributions as dist
import pyro.distributions.constraints as constraints

Importing the data

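# Importing Data
# DATA_URL points to the rugged_data.csv file used in the Pyro tutorial.
DATA_URL = "https://d2hg8soec8ck9v.cloudfront.net/datasets/rugged_data.csv"
data = pd.read_csv(DATA_URL, encoding="ISO-8859-1")
df = data[["cont_africa", "rugged", "rgdppc_2000"]]
df = df[np.isfinite(df.rgdppc_2000)].copy()  # drop nations with missing GDP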

# Now we log-normalize the GDP
df["rgdppc_2000"] = np.log(df["rgdppc_2000"])

# We then convert the Numpy array behind this dataframe to a
# torch.Tensor for analysis with PyTorch and Pyro.
train = torch.tensor(df.values, dtype=torch.float)
is_cont_africa, ruggedness, log_gdp = train[:, 0], train[:, 1], train[:, 2]

Note that as the variable GDP is highly skewed, we log-transform it before proceeding.

Preliminary Visualisation

# Plot Scatter Plots
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6), sharey='all')
african_nations = df[df["cont_africa"] == 1]
non_african_nations = df[df["cont_africa"] == 0]
sns.scatterplot(x=non_african_nations["rugged"],
                y=non_african_nations["rgdppc_2000"],
                ax=ax[0])
ax[0].set(xlabel="Terrain Ruggedness Index",
          ylabel="log GDP (2000)",
          title="Non African Nations")
sns.scatterplot(x=african_nations["rugged"],
                y=african_nations["rgdppc_2000"],
                ax=ax[1])
ax[1].set(xlabel="Terrain Ruggedness Index",
          ylabel="log GDP (2000)",
          title="African Nations");

These plots only superficially suggest a possible relationship between ruggedness and GDP. Further analysis is needed to confirm it.

Figure: Preliminary Visualisation

Our Mathematical Model

We will now implement a Bayesian regression model.

Our initial hypothesis is that ruggedness has the same effect on current income in all parts of the world. This relationship can be written as

\begin{equation} y_i = \kappa_1 - \alpha r_i + \beta q_i + e_i \end{equation}

In equation (1), we assume that the common impact of ruggedness on income is negative (hence the term $-\alpha r_i$). This is not important for the exposition; it simply anticipates our empirical finding of a negative common effect of ruggedness.

Historical studies and the empirical work of Nunn (2008) have documented that Africa’s slave trades adversely affected the political and social structures of societies. We capture this effect of Africa’s slave trades with the following equation

\begin{equation} q_i = \begin{cases} \kappa_2 - \gamma x_i + u_i & \text{if } i \text{ is in Africa}, \\ u_i & \text{otherwise} \end{cases} \end{equation}

where $x_i$ denotes slave exports, $\kappa_2$ and $\gamma$ are constants ($\gamma > 0$), and $u_i$ is a classical error term.

Historical accounts argue that the number of slaves taken from an area was reduced by the ruggedness of the terrain. This relationship is given by

\begin{equation} x_i = \kappa_3 - \lambda r_i + v_i \end{equation}

where $\kappa_3$ and $\lambda$ are constants ($\lambda > 0$), and $v_i$ is a classical error term.

Equations (1), (2) and (3) are the essential relationships in our analysis. We combine all three into one and rename the variables to obtain the fundamental relationship for our model:
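A sketch of the algebra behind this step, which the tutorial does not spell out: substitute equation (3) into equation (2), then substitute the result into equation (1). Here $\gamma$ is the constant from equation (2), and the unsubscripted $\beta$ is the coefficient of $q_i$ in equation (1).

```latex
\begin{align*}
\text{In Africa:}\quad q_i &= \kappa_2 - \gamma(\kappa_3 - \lambda r_i + v_i) + u_i \\
y_i &= \kappa_1 - \alpha r_i + \beta\left[\kappa_2 - \gamma\kappa_3 + \gamma\lambda r_i - \gamma v_i + u_i\right] + e_i \\
    &= \underbrace{\kappa_1}_{\beta_0}
       + r_i\big(\underbrace{-\alpha}_{\beta_1} + \underbrace{\beta\gamma\lambda}_{\beta_2}\big)
       + \underbrace{\beta(\kappa_2 - \gamma\kappa_3)}_{\beta_3}
       + \underbrace{e_i + \beta u_i - \beta\gamma v_i}_{\varepsilon_i} \\
\text{Outside Africa:}\quad y_i &= \kappa_1 - \alpha r_i + \beta u_i + e_i
    = \beta_0 + \beta_1 r_i + \varepsilon_i, \qquad \varepsilon_i = e_i + \beta u_i
\end{align*}
```

Under this reading, the argument predicts $\beta_1 = -\alpha < 0$. If $\beta > 0$, it also predicts $\beta_2 = \beta\gamma\lambda > 0$, because $\gamma$ and $\lambda$ are positive.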

\begin{align*} y_i &= \beta_0 + r_i (\beta_1 + \beta_2 I_i^{Africa}) + \beta_3 I_i^{Africa} + \varepsilon_i \\ y_i &= \beta_0 + \gamma r_i + \beta_3 I_i^{Africa} + \varepsilon_i \end{align*}

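where $I_i^{Africa}$ is an indicator variable that equals 1 if $i$ is in Africa and 0 otherwise, and $r_i$ is our measure of ruggedness. The coefficients $\beta_0$, $\beta_1$, $\beta_2$ and $\beta_3$ are constants, and $\varepsilon_i$ is a classical error term. The slope $\gamma = \beta_1 + \beta_2 I_i^{Africa}$ describes how log GDP changes with ruggedness: it equals $\beta_1$ outside Africa and $\beta_1 + \beta_2$ within Africa. This $\gamma$ is not the same as the $\gamma$ of equation (2). In the code below, a, bR, bSR and bS correspond to $\beta_0$, $\beta_1$, $\beta_2$ and $\beta_3$ respectively.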
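As a quick check of the renaming, the plain-Python sketch below evaluates both forms of the combined equation for a few hypothetical nations and confirms that they agree. It also shows that the slope γ is β1 outside Africa and β1 + β2 within Africa. The coefficient values and function names are illustrative assumptions, not estimates from the data.

```python
def mean_log_gdp(r, in_africa, b0, b1, b2, b3):
    """First form: y = b0 + r * (b1 + b2 * I) + b3 * I (noise term omitted)."""
    indicator = 1.0 if in_africa else 0.0
    return b0 + r * (b1 + b2 * indicator) + b3 * indicator


def slope(in_africa, b1, b2):
    """Slope gamma of log GDP with respect to ruggedness for one group."""
    return b1 + b2 * (1.0 if in_africa else 0.0)


def mean_log_gdp_gamma(r, in_africa, b0, b1, b2, b3):
    """Second form: y = b0 + gamma * r + b3 * I, with gamma from slope()."""
    indicator = 1.0 if in_africa else 0.0
    return b0 + slope(in_africa, b1, b2) * r + b3 * indicator


# Hypothetical coefficients, chosen only for illustration.
coeffs = dict(b0=9.0, b1=-0.2, b2=0.35, b3=-1.9)

for in_africa in (False, True):
    for r in (0.5, 2.0, 4.0):
        first = mean_log_gdp(r, in_africa, **coeffs)
        second = mean_log_gdp_gamma(r, in_africa, **coeffs)
        assert abs(first - second) < 1e-12  # both forms give the same mean
    group = "Africa" if in_africa else "Outside Africa"
    print(group, "slope:", slope(in_africa, coeffs["b1"], coeffs["b2"]))
```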

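def model(is_cont_africa, ruggedness, log_gdp=None):
    # Priors over the regression coefficients and the noise scale
    a = pyro.sample("a", dist.Normal(0., 10.))        # intercept, beta_0
    b_s = pyro.sample("bS", dist.Normal(0., 1.))      # Africa offset, beta_3
    b_r = pyro.sample("bR", dist.Normal(0., 1.))      # ruggedness slope, beta_1
    b_sr = pyro.sample("bSR", dist.Normal(0., 1.))    # extra slope within Africa, beta_2
    sigma = pyro.sample("sigma", dist.Uniform(0., 10.))  # observation noise

    # Mean of log GDP for every nation
    mean = a + b_r * ruggedness + is_cont_africa * (b_s + b_sr * ruggedness)

    # Observations are conditionally independent given the latent variables
    with pyro.plate("data", len(ruggedness)):
        return pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)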

In a Bayesian linear regression model, we need to specify prior distributions on the intercept $\beta_0$ (represented by a) and on the coefficients $\beta_1, \beta_2, \beta_3$ (represented by the scalars b_r, b_sr and b_s, sampled at the sites bR, bSR and bS). These priors are probability distributions that represent our beliefs, before observing any data, about reasonable values for the coefficients. We also add a random scale parameter $\sigma$, with a Uniform(0, 10) prior, that controls the observation noise.
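Because model is a generative program, it can be read from top to bottom as a recipe for simulating data. The sketch below follows that recipe in plain Python with the random module instead of Pyro, so that it runs on its own. It draws each latent variable from the priors listed above and then draws one synthetic log-GDP value per nation. The three nations' inputs are hypothetical. Calling model directly with log_gdp=None, outside any inference algorithm, performs the equivalent ancestral sampling in Pyro.

```python
import random


def sample_prior_predictive(ruggedness, is_cont_africa, rng):
    """Draw latents from the priors used in `model`, then one synthetic
    log-GDP value per nation (pure-Python sketch, no Pyro)."""
    a = rng.gauss(0.0, 10.0)          # a ~ Normal(0, 10)
    b_s = rng.gauss(0.0, 1.0)         # bS ~ Normal(0, 1)
    b_r = rng.gauss(0.0, 1.0)         # bR ~ Normal(0, 1)
    b_sr = rng.gauss(0.0, 1.0)        # bSR ~ Normal(0, 1)
    sigma = rng.uniform(0.0, 10.0)    # sigma ~ Uniform(0, 10)
    latents = {"a": a, "bS": b_s, "bR": b_r, "bSR": b_sr, "sigma": sigma}

    obs = []
    for r, africa in zip(ruggedness, is_cont_africa):
        mean = a + b_r * r + africa * (b_s + b_sr * r)
        obs.append(rng.gauss(mean, sigma))  # obs ~ Normal(mean, sigma)
    return latents, obs


rng = random.Random(0)
# Hypothetical inputs for three nations (not taken from the dataset).
toy_ruggedness = [0.5, 1.2, 3.0]
toy_is_cont_africa = [0.0, 1.0, 1.0]
latents, obs = sample_prior_predictive(toy_ruggedness, toy_is_cont_africa, rng)
print(latents)
print(obs)
```

Repeating this many times shows which log-GDP values the priors consider plausible before any data are observed.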

Process of Inference

Notation: throughout the mathematical references in this document, the following notation remains consistent:

Now that we have specified a model, Bayes' rule tells us how to use it to perform inference, that is, to draw conclusions about latent variables from data by computing the posterior distribution over $\mathbf{z}$:

$$p_\theta(\mathbf{z} \mid \mathbf{x}) = \frac{p_\theta(\mathbf{x}, \mathbf{z})}{\int p_\theta(\mathbf{x}, \mathbf{z}) \, d\mathbf{z}}$$

To check the results of modeling and inference, we would like to know how well a model fits the observed data $\mathbf{x}$, which we can quantify with the evidence, or marginal likelihood:

$$p_\theta(\mathbf{x}) = \int p_\theta(\mathbf{x}, \mathbf{z}) \, d\mathbf{z}$$

We would also like to make predictions for new data, which we can do with the posterior predictive distribution:

$$p_\theta(x' \mid \mathbf{x}) = \int p_\theta(x' \mid \mathbf{z}) \, p_\theta(\mathbf{z} \mid \mathbf{x}) \, d\mathbf{z}$$
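All three quantities can be computed exactly when the latent variable is discrete, because each integral over z becomes a sum. The sketch below uses a hypothetical coin model that is unrelated to the ruggedness data and has no learnable parameters θ. The coin bias z takes one of three values with equal prior probability, and x is a short sequence of flips.

```python
# Toy model (hypothetical): latent coin bias z takes one of three values with
# equal prior probability; observed data x are coin flips (1 = heads).
z_values = [0.2, 0.5, 0.8]
prior = {z: 1.0 / len(z_values) for z in z_values}
x = [1, 1, 0, 1]


def likelihood(data, z):
    """p(x | z) for independent Bernoulli flips."""
    p = 1.0
    for flip in data:
        p *= z if flip == 1 else 1.0 - z
    return p


# Joint p(x, z) = p(z) p(x | z) for each value of z.
joint = {z: prior[z] * likelihood(x, z) for z in z_values}

# Evidence p(x): the integral over z becomes a sum because z is discrete.
evidence = sum(joint.values())

# Posterior p(z | x) = p(x, z) / p(x).
posterior = {z: joint[z] / evidence for z in z_values}

# Posterior predictive probability that the next flip is heads:
# p(x' = 1 | x) = sum_z p(x' = 1 | z) p(z | x).
predictive_heads = sum(z * posterior[z] for z in z_values)

print(f"evidence = {evidence:.4f}")
print({z: round(p, 4) for z, p in posterior.items()})
print(f"p(next flip = heads | x) = {predictive_heads:.4f}")
```

For continuous latent variables such as the regression coefficients in our model, these sums become the integrals written above, which is where the difficulty comes from.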

We aim to learn the parameters $\theta$ of our models from the observed data $\mathbf{x}$ by maximising the marginal likelihood:

$$\theta_{\max} = \operatorname*{arg\,max}_{\theta} \, p_\theta(\mathbf{x})$$
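To make "maximising the marginal likelihood" concrete, the sketch below uses a hypothetical model where the integral over z has a closed form: z_i ~ Normal(θ, 1) and x_i | z_i ~ Normal(z_i, 1). Integrating out z_i gives x_i ~ Normal(θ, 2). A grid search over θ then recovers θ_max, which for this model is the sample mean of the data. The data values are made up for illustration.

```python
import math

# Hypothetical model: z_i ~ Normal(theta, 1), x_i | z_i ~ Normal(z_i, 1).
# Integrating out z_i gives the marginal x_i ~ Normal(theta, 2) (variance 1 + 1).
x = [1.0, 2.0, 4.5]


def log_marginal_likelihood(theta, data):
    """log p_theta(x) = sum_i log Normal(x_i; theta, variance=2)."""
    var = 2.0
    return sum(-0.5 * math.log(2 * math.pi * var) - (xi - theta) ** 2 / (2 * var)
               for xi in data)


# Grid search for theta_max over [-5, 5] in steps of 0.01.
grid = [i / 100 for i in range(-500, 501)]
theta_max = max(grid, key=lambda t: log_marginal_likelihood(t, x))
print(theta_max)  # → 2.5, the sample mean of x
```

In the ruggedness model no such closed form is available, which motivates the variational approach described next.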

Estimating the Learnable Parameters


Each of these computations (the posterior distribution, the marginal likelihood and the posterior predictive distribution) requires performing integrals that are often impossible or computationally intractable.

While Pyro includes support for many different exact and approximate inference algorithms, the best-supported is variational inference. It offers a unified scheme for finding $\theta_{\max}$ and computing a tractable approximation $q_\phi(\mathbf{z})$ to the true, unknown posterior $p_{\theta_{\max}}(\mathbf{z} \mid \mathbf{x})$, by converting the intractable integrals into the optimisation of a functional of $p$ and $q$.

This distribution $q_\phi$ is called the variational distribution in much of the literature; in the context of Pyro it is called the guide.

Just like the model, the guide is encoded as a Python program, guide(), that contains pyro.sample and pyro.param statements. It does not contain observed data, since the guide needs to be a properly normalised distribution that is easy to sample from.

Note: the custom guide below assumes that the latent variables are uncorrelated (a mean-field approximation). Pyro's AutoNormal guide makes the same assumption.
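Concretely, the mean-field assumption means the guide density factorises as $q_\phi(\mathbf z) = \prod_j q_{\phi_j}(z_j)$. Each latent site is drawn independently, and $\log q_\phi$ is a sum of per-site terms. The plain-Python sketch below shows this with hypothetical locs and scales for the five sites of the model (a, bS, bR, bSR, sigma). It does not use Pyro, and the values are not the ones learned in this post.

```python
import math
import random


def normal_log_pdf(value, loc, scale):
    """Log density of Normal(loc, scale) at value."""
    return -0.5 * math.log(2 * math.pi) - math.log(scale) - 0.5 * ((value - loc) / scale) ** 2


# Hypothetical variational parameters: one (loc, scale) pair per latent site.
guide_params = {
    "a": (0.0, 1.0),
    "bS": (0.1, 1.0),
    "bR": (-0.3, 1.0),
    "bSR": (0.2, 1.0),
    "sigma": (1.0, 0.05),
}


def sample_mean_field(params, rng):
    """Draw every latent independently from its own Normal."""
    return {name: rng.gauss(loc, scale) for name, (loc, scale) in params.items()}


def mean_field_log_prob(values, params):
    """Under the mean-field assumption, log q(z) is a sum of per-site terms."""
    return sum(normal_log_pdf(values[name], loc, scale)
               for name, (loc, scale) in params.items())


rng = random.Random(1)
z = sample_mean_field(guide_params, rng)
print(z)
print(mean_field_log_prob(z, guide_params))
```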

def custom_guide(is_cont_africa, ruggedness, log_gdp=None):
    # Variational parameters; scales must stay positive
    a_loc = pyro.param('a_loc', lambda: torch.tensor(0.))
    a_scale = pyro.param('a_scale', lambda: torch.tensor(1.),
                         constraint=constraints.positive)
    sigma_loc = pyro.param('sigma_loc', lambda: torch.tensor(1.),
                           constraint=constraints.positive)
    weights_loc = pyro.param('weights_loc', lambda: torch.randn(3))
    weights_scale = pyro.param('weights_scale', lambda: torch.ones(3),
                               constraint=constraints.positive)
    # One independent Normal per latent site (same site names as the model)
    a = pyro.sample("a", dist.Normal(a_loc, a_scale))
    b_s = pyro.sample("bS", dist.Normal(weights_loc[0], weights_scale[0]))
    b_r = pyro.sample("bR", dist.Normal(weights_loc[1], weights_scale[1]))
    b_sr = pyro.sample("bSR", dist.Normal(weights_loc[2], weights_scale[2]))
    sigma = pyro.sample("sigma", dist.Normal(sigma_loc, torch.tensor(0.05)))
    return {"a": a, "b_s": b_s, "b_r": b_r, "b_sr": b_sr, "sigma": sigma}

pyro.render_model(custom_guide, model_args=(is_cont_africa, ruggedness, log_gdp), render_params=True)

Figure: Custom Guide Parameter Visualisation

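Pyro can build a similar mean-field Normal guide automatically. AutoNormal also learns a scale for sigma and keeps sigma within the support of its Uniform(0, 10) prior:

auto_guide = pyro.infer.autoguide.AutoNormal(model)
pyro.render_model(auto_guide, model_args=(is_cont_africa, ruggedness, log_gdp), render_params=True)

Figure: AutoGuide Parameter Visualisation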

ELBO Minimisation

Variational inference approximates the true posterior by searching the space of variational distributions for the one that is most similar to the true posterior according to some measure of distance or divergence, here the Kullback-Leibler divergence $KL(q_\phi(\mathbf{z}) \,\|\, p_{\theta_{\max}}(\mathbf{z} \mid \mathbf{x}))$. Computing this divergence directly, however, requires knowing the true posterior ahead of time, which would defeat the purpose.

We are interested in minimising this divergence, which might sound even harder. In fact, Bayes' theorem lets us rewrite $KL(q_\phi(\mathbf{z}) \,\|\, p_{\theta_{\max}}(\mathbf{z} \mid \mathbf{x}))$ as the difference between an intractable constant that does not depend on $q_\phi$ (the log evidence $\log p_{\theta_{\max}}(\mathbf{x})$) and a tractable term called the evidence lower bound (ELBO), $\mathrm{ELBO}(\phi) = \mathbb{E}_{q_\phi(\mathbf{z})}\left[\log p_{\theta_{\max}}(\mathbf{x}, \mathbf{z}) - \log q_\phi(\mathbf{z})\right]$. Maximising this tractable term therefore produces the same solution as minimising the original KL divergence.
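With the ELBO written as $\mathbb{E}_{q(\mathbf z)}[\log p(\mathbf x, \mathbf z) - \log q(\mathbf z)]$, the identity is $\log p(\mathbf x) = \mathrm{ELBO}(q) + KL(q(\mathbf z) \,\|\, p(\mathbf z \mid \mathbf x))$. The sketch below checks it numerically on a hypothetical one-dimensional conjugate model, z ~ Normal(0, 1) and x | z ~ Normal(z, 1). For this model the evidence, the exact posterior Normal(x/2, 1/2), and the ELBO of any Normal q all have closed forms. It is a toy chosen so that every quantity can be computed exactly, not part of the ruggedness analysis.

```python
import math

# Hypothetical conjugate model: z ~ Normal(0, 1), x | z ~ Normal(z, 1).
# Exact results: p(x) = Normal(x; 0, 2) and p(z | x) = Normal(x / 2, 1 / 2).


def log_evidence(x):
    """log p(x) with p(x) = Normal(x; 0, variance 2)."""
    return -0.5 * math.log(2 * math.pi * 2.0) - x ** 2 / 4.0


def elbo(m, s, x):
    """Closed-form ELBO for q(z) = Normal(m, s^2):
    E_q[log p(z)] + E_q[log p(x | z)] + entropy of q."""
    e_log_prior = -0.5 * math.log(2 * math.pi) - (m ** 2 + s ** 2) / 2.0
    e_log_lik = -0.5 * math.log(2 * math.pi) - ((x - m) ** 2 + s ** 2) / 2.0
    entropy = 0.5 * math.log(2 * math.pi * math.e * s ** 2)
    return e_log_prior + e_log_lik + entropy


def kl_normal(m1, s1, m2, s2):
    """KL(Normal(m1, s1^2) || Normal(m2, s2^2))."""
    return math.log(s2 / s1) + (s1 ** 2 + (m1 - m2) ** 2) / (2 * s2 ** 2) - 0.5


x = 1.3
post_m, post_s = x / 2.0, math.sqrt(0.5)
for m, s in [(0.0, 1.0), (0.4, 0.9), (post_m, post_s)]:
    gap = log_evidence(x) - elbo(m, s, x)
    print(f"q = N({m:.3f}, {s:.3f}^2): ELBO = {elbo(m, s, x):.4f}, "
          f"log p(x) - ELBO = {gap:.4f}, KL = {kl_normal(m, s, post_m, post_s):.4f}")
```

The KL divergence is never negative, so the ELBO never exceeds $\log p(\mathbf x)$. It reaches $\log p(\mathbf x)$ exactly when q equals the true posterior.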


pyro.clear_param_store()

# These should be reset each training loop.
auto_guide = pyro.infer.autoguide.AutoNormal(model)
adam = pyro.optim.Adam({"lr": 0.02})  # Consider decreasing learning rate.
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(model, auto_guide, adam, elbo)

losses = []
for step in range(1000 if not smoke_test else 2):  # Consider running for more steps.
    # One gradient step; returns the ELBO loss (the negative ELBO estimate)
    loss = svi.step(is_cont_africa, ruggedness, log_gdp)
    losses.append(loss)
    if step % 100 == 0:
        logging.info("Elbo loss: {}".format(loss))

plt.figure(figsize=(5, 2))
plt.plot(losses)
plt.xlabel("SVI step")
plt.ylabel("ELBO loss");

Figure: ELBO Loss vs SVI Steps
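Each call to svi.step estimates the ELBO and its gradient from samples drawn from the guide, then takes one optimiser step. For reparameterisable guides such as AutoNormal, Trace_ELBO differentiates through those samples. The from-scratch sketch below applies the same idea to the hypothetical conjugate model z ~ Normal(0, 1), x | z ~ Normal(z, 1), whose exact posterior Normal(x/2, 1/2) lets us check the result. This is not Pyro's implementation. It uses hand-derived gradients and plain stochastic gradient ascent instead of Adam, and the step size, particle count and number of steps are arbitrary choices.

```python
import math
import random

# Hypothetical conjugate model (not the ruggedness model):
# z ~ Normal(0, 1), x | z ~ Normal(z, 1), with one observation x.
# The exact posterior is Normal(x / 2, 1 / 2), so we can check the answer.
x = 2.0


def dlogjoint_dz(z):
    """d/dz [log p(z) + log p(x | z)] = -z + (x - z)."""
    return x - 2.0 * z


def svi(num_steps=3000, num_particles=10, lr=0.01, seed=0):
    """Stochastic gradient ascent on the ELBO for q(z) = Normal(m, exp(rho)^2),
    using reparameterised samples z = m + s * eps with eps ~ Normal(0, 1)."""
    rng = random.Random(seed)
    m, rho = 0.0, 0.0  # initial variational parameters (s = 1)
    for _ in range(num_steps):
        s = math.exp(rho)
        grad_m, grad_rho = 0.0, 0.0
        for _ in range(num_particles):
            eps = rng.gauss(0.0, 1.0)
            g = dlogjoint_dz(m + s * eps)  # gradient flows through the sample
            grad_m += g / num_particles
            grad_rho += g * eps * s / num_particles
        grad_rho += 1.0  # gradient of the entropy term (log s = rho) w.r.t. rho
        m += lr * grad_m
        rho += lr * grad_rho
    return m, math.exp(rho)


m, s = svi()
print(f"q(z) = Normal({m:.3f}, {s:.3f}^2); "
      f"exact posterior = Normal({x / 2:.3f}, {math.sqrt(0.5):.3f}^2)")
```

Because each step uses only a few random samples, the estimates are noisy. This is also why the ELBO loss curve from Pyro fluctuates rather than decreasing smoothly.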

After training, we print the learned values of the variational parameters, which are stored in Pyro's parameter store:

for name, value in pyro.get_param_store().items():
    print(name, pyro.param(name).data.cpu().numpy())
AutoNormal.locs.a 9.176606
AutoNormal.scales.a 0.07974657
AutoNormal.locs.bS -1.8880633
AutoNormal.scales.bS 0.13369536
AutoNormal.locs.bR -0.18150033
AutoNormal.scales.bR 0.042435974
AutoNormal.locs.bSR 0.3381295
AutoNormal.scales.bSR 0.08354119
AutoNormal.locs.sigma -2.2117462
AutoNormal.scales.sigma 0.053564996
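The sigma entries are not on the scale of sigma itself. As we understand AutoNormal, it stores its locs and scales in an unconstrained space and maps samples onto the support of each prior. For the Uniform(0, 10) prior on sigma, we assume this map is PyTorch's default bijection onto an interval, u ↦ 10 · sigmoid(u). Under that assumption, the sketch below converts the printed loc into the guide's posterior median of sigma; a monotone map sends the median of the underlying Normal to the median of sigma. In Pyro, auto_guide.median() should return such constrained values directly.

```python
import math

# Value copied from the parameter-store printout above.
sigma_loc_unconstrained = -2.2117462  # AutoNormal.locs.sigma


def to_uniform_0_10(u):
    """Assumed map from an unconstrained real to (0, 10): 10 * sigmoid(u)."""
    return 10.0 / (1.0 + math.exp(-u))


sigma_median = to_uniform_0_10(sigma_loc_unconstrained)
print(f"posterior median of sigma under the guide ≈ {sigma_median:.3f}")
```

The result, roughly 0.99, is the typical size of the residual noise in log GDP under the fitted model.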

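We draw 800 samples from the guide and plot density histograms of the slope $\gamma$ of the regression line of $\log$(GDP) on Terrain Ruggedness. For African nations the slope is bR + bSR; for non-African nations it is bR. Because the AutoNormal guide treats bR and bSR as independent Normal variables, $\gamma$ is Normally distributed under the guide in both cases, and the histograms reflect this.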

with pyro.plate("samples", 800, dim=-1):
    samples = auto_guide(is_cont_africa, ruggedness)

gamma_within_africa = samples["bR"] + samples["bSR"]
gamma_outside_africa = samples["bR"]

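fig = plt.figure(figsize=(10, 6))
sns.histplot(gamma_within_africa.detach().cpu().numpy(),
             kde=True, stat="density",
             label="African nations")
sns.histplot(gamma_outside_africa.detach().cpu().numpy(),
             kde=True, stat="density",
             label="Non-African nations", color="orange")
fig.suptitle("Density of Slope : log(GDP) vs. Terrain Ruggedness")
plt.xlabel("Slope of regression line")
plt.legend()
plt.show()

Figure: Density vs Slope of Regression Line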
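The histograms can also be checked against a closed form. Because the AutoNormal guide treats bR and bSR as independent Normal variables, the slope within Africa, bR + bSR, is itself Normal. Its mean is the sum of the two locs, and its variance is the sum of the squared scales. The sketch below plugs in the values printed from the parameter store.

```python
import math

# Locs and scales copied from the parameter-store printout above.
bR_loc, bR_scale = -0.18150033, 0.042435974
bSR_loc, bSR_scale = 0.3381295, 0.08354119

# Under the mean-field guide, bR and bSR are independent Normals, so their sum
# is Normal with mean loc_R + loc_SR and variance scale_R^2 + scale_SR^2.
gamma_out_mean, gamma_out_sd = bR_loc, bR_scale
gamma_in_mean = bR_loc + bSR_loc
gamma_in_sd = math.sqrt(bR_scale ** 2 + bSR_scale ** 2)

print(f"outside Africa: gamma ~ Normal({gamma_out_mean:.3f}, {gamma_out_sd:.3f}^2)")
print(f"within Africa:  gamma ~ Normal({gamma_in_mean:.3f}, {gamma_in_sd:.3f}^2)")
```

With these numbers, the slope is about -0.18 (standard deviation about 0.04) outside Africa and about 0.16 (standard deviation about 0.09) within Africa. Under the fitted model, ruggedness is associated with lower log GDP outside Africa and higher log GDP within it.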

Making Predictions with the Learned Model

We now compute the posterior predictive distribution for this model, using auto_guide as the guide, and draw 800 samples from it. We then use these samples to form a 90% credible interval for each nation's log GDP, running from the 5th to the 95th percentile of the predictive samples.

predictive = pyro.infer.Predictive(model, guide=auto_guide, num_samples=800)
svi_samples = predictive(is_cont_africa, ruggedness, log_gdp=None)
svi_gdp = svi_samples["obs"]

predictions = pd.DataFrame({
    "cont_africa": is_cont_africa,
    "rugged": ruggedness,
    "y_mean": svi_gdp.mean(0).detach().cpu().numpy(),
    "y_perc_5": svi_gdp.kthvalue(int(len(svi_gdp) * 0.05), dim=0)[0].detach().cpu().numpy(),
    "y_perc_95": svi_gdp.kthvalue(int(len(svi_gdp) * 0.95), dim=0)[0].detach().cpu().numpy(),
    "true_gdp": log_gdp,
})
african_nations = predictions[predictions["cont_africa"] == 1].sort_values(by=["rugged"])
non_african_nations = predictions[predictions["cont_africa"] == 0].sort_values(by=["rugged"])

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6), sharey='all')
fig.suptitle("Posterior predictive distribution with 90% CI", fontsize=16)

ax[0].plot(non_african_nations["rugged"], non_african_nations["y_mean"])
ax[0].fill_between(non_african_nations["rugged"], non_african_nations["y_perc_5"],
                   non_african_nations["y_perc_95"], alpha=0.5)
ax[0].plot(non_african_nations["rugged"], non_african_nations["true_gdp"], "o")
ax[0].set(xlabel="Terrain Ruggedness Index", ylabel="log GDP (2000)", title="Non African Nations")

ax[1].plot(african_nations["rugged"], african_nations["y_mean"])
ax[1].fill_between(african_nations["rugged"], african_nations["y_perc_5"],
                   african_nations["y_perc_95"], alpha=0.5)
ax[1].plot(african_nations["rugged"], african_nations["true_gdp"], "o")
ax[1].set(xlabel="Terrain Ruggedness Index", ylabel="log GDP (2000)", title="African Nations");
Figure: Posterior Distribution
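The interval bounds above come from torch.kthvalue, which returns the k-th smallest value along a dimension, with k counted from 1. With 800 samples, int(800 * 0.05) = 40 and int(800 * 0.95) = 760. The band for each nation therefore runs from its 40th to its 760th smallest predictive draw. The stdlib sketch below reproduces that indexing on a hypothetical list of draws. torch.quantile, which interpolates between order statistics, would be an alternative.

```python
def kth_smallest(values, k):
    """Return the k-th smallest value (1-based), as torch.kthvalue does along one dim."""
    return sorted(values)[k - 1]


# Hypothetical stand-in for 800 posterior predictive draws for one nation.
num_samples = 800
draws = [float(i) for i in range(num_samples)]

k_low = int(num_samples * 0.05)   # 40
k_high = int(num_samples * 0.95)  # 760
lower, upper = kth_smallest(draws, k_low), kth_smallest(draws, k_high)
print(k_low, k_high, lower, upper)  # → 40 760 39.0 759.0
```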

Appendix: Models in Pyro

Probabilistic models in Pyro are specified as Python functions model(*args, **kwargs) that generate observed data from latent variables using special primitive functions whose behaviour can be changed by Pyro’s internals depending on the high-level computation being performed.

Specifically, the different mathematical pieces of model() are encoded via the mapping:

  1. latent random variables $\iff$ pyro.sample
  2. observed random variables $\iff$ pyro.sample with the obs keyword argument
  3. learnable parameters $\iff$ pyro.param
  4. plates $\iff$ pyro.plate context managers

The pyro.sample Primitive

Probabilistic programs in Pyro are built up around samples from primitive probability distributions, marked by pyro.sample:

def sample(
    name: str,
    fn: pyro.distributions.Distribution,
    obs: typing.Optional[torch.Tensor] = None,
    infer: typing.Optional[dict] = None,
) -> torch.Tensor:
    ...

For example, the observed sample statement in our model is:

obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)

When a model runs on its own, each pyro.sample statement simply draws a value from its distribution. However, when any sample statement is observed, the cumulative effect of every other sample statement in the model changes according to Bayes' rule. It is the job of Pyro's inference algorithms to "run the program backwards" and assign mathematically consistent values to all pyro.sample statements in a model.

The pyro.param Primitive

The next primitive, pyro.param, used in our custom guide, is a frontend for reading from and writing to Pyro's key-value parameter store:

def param(
    name: str,
    init: Optional[Union[torch.Tensor, Callable[..., torch.Tensor]]] = None,
    constraint: torch.distributions.constraints.Constraint = constraints.real
) -> torch.Tensor:
    ...

pyro.param values persist across model calls, unless the parameter store is updated by an optimization algorithm or cleared via pyro.clear_param_store(). Unlike pyro.sample, pyro.param can be called with the same name multiple times in a model; every call with the same name will return the same value. The global parameter store itself is accessible by calling pyro.get_param_store().
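The persistence rule described above can be illustrated with a toy key-value store. This is a hypothetical, heavily simplified class written for illustration, not Pyro's ParamStoreDict. It ignores constraints, tensors and optimisers. It only shows that the first call with a name runs the initialiser, and every later call with that name returns the stored value.

```python
class ToyParamStore:
    """Hypothetical, simplified stand-in for a key-value parameter store.
    Not Pyro's implementation: it only illustrates the lookup semantics."""

    def __init__(self):
        self._params = {}

    def param(self, name, init=None):
        # First call with a name: evaluate the initialiser and store the value.
        if name not in self._params:
            if init is None:
                raise KeyError(f"parameter {name!r} has not been initialised")
            self._params[name] = init() if callable(init) else init
        # Every later call with the same name returns the stored value.
        return self._params[name]

    def clear(self):
        """Analogous in spirit to pyro.clear_param_store()."""
        self._params.clear()


store = ToyParamStore()
first = store.param("a_loc", lambda: 0.0)
again = store.param("a_loc", lambda: 123.0)  # initialiser ignored: name exists
print(first, again)  # → 0.0 0.0
```

This lazy, callable initialiser pattern is why the custom guide passes lambda: torch.tensor(0.) rather than a tensor: the initial value is only built the first time the name is seen.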

The pyro.plate Primitive

pyro.plate is Pyro’s formal encoding of plate notation, widely used in probabilistic machine learning to simplify visualisation and analysis of models with lots of conditionally independent and identically distributed random variables.

def plate(
    name: str,
    size: int,
    dim: Optional[int] = None,
) -> contextlib.AbstractContextManager:
    ...
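In our model, the "data" plate declares the observations conditionally independent given the latent variables. The log-likelihood of all observations is therefore a sum of per-observation log densities, or equivalently a product of densities. The stdlib sketch below illustrates that factorisation with hypothetical numbers. In Pyro, the sum is taken over the plate's batch dimension of a vectorised distribution rather than with an explicit loop.

```python
import math


def normal_log_pdf(value, loc, scale):
    """Log density of Normal(loc, scale) at value."""
    return -0.5 * math.log(2 * math.pi) - math.log(scale) - 0.5 * ((value - loc) / scale) ** 2


def plate_log_likelihood(observations, means, sigma):
    """Conditionally independent observations inside a plate: the joint
    log-likelihood factorises into a sum of per-observation terms."""
    assert len(observations) == len(means)
    return sum(normal_log_pdf(y, mu, sigma) for y, mu in zip(observations, means))


# Hypothetical values for three data points in a plate of size 3.
obs = [8.1, 7.4, 9.0]
means = [8.0, 7.5, 8.7]
sigma = 1.0
print(plate_log_likelihood(obs, means, sigma))
```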