In this tutorial, we explore hierarchical Bayesian regression with NumPyro and walk through the entire workflow in a structured manner. We start by generating synthetic data, then we define a probabilistic model that captures both global patterns and group-level variations. Through each snippet, we set up inference using NUTS, analyze posterior distributions, and perform posterior predictive checks to understand how well our model captures the underlying structure. By approaching the tutorial step by step, we build an intuitive understanding of how NumPyro enables flexible, scalable Bayesian modeling.
try:
    import numpyro
except ImportError:
    !pip install -q "llvmlite>=0.45.1" "numpyro[cpu]" matplotlib pandas
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.diagnostics import hpdi
numpyro.set_host_device_count(1)
We set up our environment by installing NumPyro and importing all required libraries. We prepare JAX, NumPyro, and plotting tools so we have everything ready for Bayesian inference. As we run this cell, we ensure our Colab session is fully equipped for hierarchical modeling.
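As an optional sanity check (not part of the original snippet), we can confirm the installed versions and the device JAX will run on before we start sampling:
# Optional sanity check: confirm library versions and the JAX backend in use.
print("jax", jax.__version__)
print("numpyro", numpyro.__version__)
print("backend:", jax.default_backend())
print("devices:", jax.devices())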
def generate_data(key, n_groups=8, n_per_group=40):
    k1, k2, k3, k4 = random.split(key, 4)
    # True population-level parameters and noise scales
    true_alpha = 1.0
    true_beta = 0.6
    sigma_alpha_g = 0.8
    sigma_beta_g = 0.5
    sigma_eps = 0.7
    group_ids = np.repeat(np.arange(n_groups), n_per_group)
    n = n_groups * n_per_group
    # Group-specific deviations from the population intercept and slope
    alpha_g = random.normal(k1, (n_groups,)) * sigma_alpha_g
    beta_g = random.normal(k2, (n_groups,)) * sigma_beta_g
    x = random.normal(k3, (n,)) * 2.0
    eps = random.normal(k4, (n,)) * sigma_eps
    a = true_alpha + alpha_g[group_ids]
    b = true_beta + beta_g[group_ids]
    y = a + b * x + eps
    df = pd.DataFrame({"y": np.array(y), "x": np.array(x), "group": group_ids})
    truth = dict(true_alpha=true_alpha, true_beta=true_beta,
                 sigma_alpha_group=sigma_alpha_g, sigma_beta_group=sigma_beta_g,
                 sigma_eps=sigma_eps)
    return df, truth
key = random.PRNGKey(0)
df, truth = generate_data(key)
x = jnp.array(df["x"].values)
y = jnp.array(df["y"].values)
groups = jnp.array(df["group"].values)
n_groups = int(df["group"].nunique())
We generate synthetic hierarchical data that mimics real-world group-level variation. We convert this data into JAX-friendly arrays so NumPyro can process it efficiently. By doing this, we lay the foundation for fitting a model that learns both global trends and group differences.
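Before fitting, it is worth taking a quick look at the generated data. This short snippet is an optional addition, not part of the original walkthrough; it prints per-group summaries and plots the raw points colored by group:
# Optional: inspect the synthetic data before modeling.
print(df.groupby("group")[["x", "y"]].describe().round(2).head())
plt.figure(figsize=(7, 4))
for g in range(n_groups):
    sub = df[df["group"] == g]
    plt.scatter(sub["x"], sub["y"], s=10, label=f"group {g}")
plt.xlabel("x"); plt.ylabel("y"); plt.legend(ncol=4, fontsize=8)
plt.show()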
def hierarchical_regression_model(x, group_idx, n_groups, y=None):
    # Global (population-level) priors for the intercept and slope
    mu_alpha = numpyro.sample("mu_alpha", dist.Normal(0.0, 5.0))
    mu_beta = numpyro.sample("mu_beta", dist.Normal(0.0, 5.0))
    sigma_alpha = numpyro.sample("sigma_alpha", dist.HalfCauchy(2.0))
    sigma_beta = numpyro.sample("sigma_beta", dist.HalfCauchy(2.0))
    # Group-level intercepts and slopes drawn around the global means
    with numpyro.plate("group", n_groups):
        alpha_g = numpyro.sample("alpha_g", dist.Normal(mu_alpha, sigma_alpha))
        beta_g = numpyro.sample("beta_g", dist.Normal(mu_beta, sigma_beta))
    sigma_obs = numpyro.sample("sigma_obs", dist.Exponential(1.0))
    alpha = alpha_g[group_idx]
    beta = beta_g[group_idx]
    mean = alpha + beta * x
    # Likelihood over all observations
    with numpyro.plate("data", x.shape[0]):
        numpyro.sample("y", dist.Normal(mean, sigma_obs), obs=y)
nuts = NUTS(hierarchical_regression_model, target_accept_prob=0.9)
mcmc = MCMC(nuts, num_warmup=1000, num_samples=1000, num_chains=1, progress_bar=True)
mcmc.run(random.PRNGKey(1), x=x, group_idx=groups, n_groups=n_groups, y=y)
samples = mcmc.get_samples()
We define our hierarchical regression model and launch the NUTS-based MCMC sampler. We allow NumPyro to explore the posterior space and learn parameters such as group intercepts and slopes. As this sampling completes, we obtain rich posterior distributions that reflect uncertainty at every level.
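It is good practice to check sampler diagnostics before interpreting these posteriors. NumPyro's MCMC object exposes a summary with effective sample sizes and R-hat values; the call below is a standard check rather than something specific to this tutorial:
# Convergence diagnostics: n_eff should be reasonably large and r_hat close to 1.0.
mcmc.print_summary()
If R-hat values drift noticeably above 1.0, common remedies include increasing num_warmup or switching the group effects to a non-centered parameterization.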
def param_summary(arr):
    arr = np.asarray(arr)
    mean = arr.mean()
    lo, hi = hpdi(arr, prob=0.9)
    return mean, float(lo), float(hi)

for name in ["mu_alpha", "mu_beta", "sigma_alpha", "sigma_beta", "sigma_obs"]:
    m, lo, hi = param_summary(samples[name])
    print(f"{name}: mean={m:.3f}, HPDI=[{lo:.3f}, {hi:.3f}]")
predictive = Predictive(hierarchical_regression_model, samples, return_sites=["y"])
ppc = predictive(random.PRNGKey(2), x=x, group_idx=groups, n_groups=n_groups)
y_rep = np.asarray(ppc["y"])
group_to_plot = 0
mask = df["group"].values == group_to_plot
x_g = df.loc[mask, "x"].values
y_g = df.loc[mask, "y"].values
y_rep_g = y_rep[:, mask]
order = np.argsort(x_g)
x_sorted = x_g[order]
y_rep_sorted = y_rep_g[:, order]
y_med = np.median(y_rep_sorted, axis=0)
y_lo, y_hi = np.percentile(y_rep_sorted, [5, 95], axis=0)
plt.figure(figsize=(8, 5))
plt.scatter(x_g, y_g)
plt.plot(x_sorted, y_med)
plt.fill_between(x_sorted, y_lo, y_hi, alpha=0.3)
plt.show()
We analyze our posterior samples by computing summaries and performing posterior predictive checks. We visualize how well the model recreates observed data for a selected group. This step helps us understand how accurately our model captures the underlying generative process.
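Beyond the visual check, we can compute a simple quantitative posterior predictive statistic. The sketch below is an illustrative addition that reuses the y_rep draws from above and compares the spread of the observed data with the spread of each replicated dataset:
# Bayesian p-value for the standard deviation: values near 0.5 indicate the
# replicated data spreads out about as much as the observed data.
obs_sd = float(np.std(np.asarray(y)))
rep_sd = y_rep.std(axis=1)
p_val = float((rep_sd >= obs_sd).mean())
print(f"observed sd={obs_sd:.3f}, replicated sd={rep_sd.mean():.3f} (p about {p_val:.2f})")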
alpha_g = np.asarray(samples["alpha_g"]).mean(axis=0)
beta_g = np.asarray(samples["beta_g"]).mean(axis=0)
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].bar(range(n_groups), alpha_g)
axes[0].axhline(truth["true_alpha"], linestyle="--")
axes[1].bar(range(n_groups), beta_g)
axes[1].axhline(truth["true_beta"], linestyle="--")
plt.tight_layout()
plt.show()
We plot the estimated group-level intercepts and slopes to compare their learned patterns with the true values. We explore how each group behaves and how the model adapts to their differences. This final visualization brings together the complete picture of hierarchical inference.
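To see partial pooling at work, we can also compare the hierarchical slope estimates with independent per-group fits. This is an optional comparison; np.polyfit is used here only as a quick no-pooling baseline and is not part of the original code:
# Compare no-pooling OLS slopes with the partially pooled posterior means.
ols_slopes = []
for g in range(n_groups):
    m = df["group"].values == g
    slope, _ = np.polyfit(df.loc[m, "x"].values, df.loc[m, "y"].values, 1)
    ols_slopes.append(slope)
comparison = pd.DataFrame({"ols_slope": np.round(ols_slopes, 3),
                           "hierarchical_slope": np.round(beta_g, 3)})
print(comparison)
Groups whose independent estimates are noisy or extreme are pulled toward the population mean slope, which is exactly the regularizing behavior we expect from the hierarchical prior.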
In conclusion, we have seen how NumPyro allows us to model hierarchical relationships with clarity, efficiency, and strong expressive power. We observed how the posterior results reveal meaningful global and group-specific effects, and how predictive checks validate the model’s fit to the generated data. As we put everything together, we gain confidence in constructing, fitting, and interpreting hierarchical models using JAX-powered inference. This process strengthens our ability to apply Bayesian thinking to richer, more realistic datasets where multilevel structure is essential.
