import numpy as np
import scipy.stats as sps
import matplotlib.pyplot as plt
%matplotlib inline
# Global figure appearance: 10x6 inch canvas and an 18 pt font
plt.rcParams['figure.figsize'] = (10.0, 6.0)
plt.rcParams['font.size'] = 18
# Shared x-axis grid (1000 points on [-10, 10]) for the 1-D density plots
t = np.linspace(-10, 10, num=1000)
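The Gaussian distribution parametrized by mean $\mu$ and precision $\lambda = 1/\sigma^2$ (rather than variance $\sigma^2$) is:
$$\mathcal{N}(x \mid \mu, \lambda^{-1}) = \sqrt{\frac{\lambda}{2\pi}} \exp\left(-\frac{\lambda}{2}(x-\mu)^2\right)$$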
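The NormalGamma distribution can be used as the conjugate prior over the mean $\mu$ and precision $\lambda$ parameters of a Gaussian distribution:
$$\mathrm{NormalGamma}(\mu, \lambda \mid m, \kappa, a, b) = \mathcal{N}\left(\mu \mid m, (\kappa\lambda)^{-1}\right)\,\mathrm{Gam}(\lambda \mid a, b),$$
where the gamma distribution with shape $a$ and rate $b$ is
$$\mathrm{Gam}(\lambda \mid a, b) = \frac{b^a}{\Gamma(a)}\,\lambda^{a-1} e^{-b\lambda}$$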
def NormalGamma_pdf(mu, lmbd, m, kappa, a, b):
    """Evaluate the NormalGamma density at the point(s) (mu, lmbd).

    The density is N(mu | m, 1/(kappa*lmbd)) * Gamma(lmbd | shape=a, rate=b).
    All arguments broadcast against each other following NumPy rules.

    Parameters
    ----------
    mu : float or array_like
        Value(s) of the Gaussian mean at which to evaluate the density.
    lmbd : float or array_like
        Value(s) of the Gaussian precision at which to evaluate the density.
    m, kappa : float
        Location and precision-scaling parameters of the normal factor.
    a, b : float
        Shape and rate parameters of the gamma factor.

    Returns
    -------
    float or numpy.ndarray
        Density value(s); exactly 0 wherever ``lmbd <= 0`` (outside the
        support), instead of NaN from taking the square root of a negative
        number.
    """
    lmbd = np.asarray(lmbd, dtype=float)
    outside_support = lmbd <= 0
    # Evaluate the factors at a harmless placeholder precision where the
    # density is zero anyway, so no NaN / divide-by-zero warnings are raised.
    safe_lmbd = np.where(outside_support, 1.0, lmbd)
    # sps.norm.pdf takes mean and std (not variance), therefore 1.0/sqrt(precision)
    normal_part = sps.norm.pdf(mu, m, 1.0 / np.sqrt(safe_lmbd * kappa))
    # sps.gamma.pdf takes shape and scale (not rate), therefore 1.0/b
    gamma_part = sps.gamma.pdf(safe_lmbd, a, scale=1.0 / b)
    density = np.where(outside_support, 0.0, normal_part * gamma_part)
    # Indexing with () turns a 0-d result back into a scalar and leaves
    # n-d arrays untouched.
    return density[()]
def NormalGamma_rvs(m, kappa, a, b, N, random_state=None):
    """Draw N samples (mu, lmbd) from a NormalGamma distribution.

    Sampling is ancestral: the precision is drawn from
    Gamma(shape=a, rate=b) first, then the mean from
    N(m, 1/(kappa*lmbd)) given that precision.

    Parameters
    ----------
    m, kappa : float
        Location and precision-scaling parameters of the normal factor.
    a, b : float
        Shape and rate parameters of the gamma factor.
    N : int
        Number of samples to draw.
    random_state : None, int, numpy.random.Generator or RandomState, optional
        Source of randomness. ``None`` (default) keeps the original
        behavior of drawing from SciPy's default global random state.

    Returns
    -------
    mu : numpy.ndarray
        Sampled means.
    lmbd : numpy.ndarray
        Sampled precisions, aligned element-wise with ``mu``.
    """
    if random_state is not None and not isinstance(
            random_state, (np.random.Generator, np.random.RandomState)):
        # Build one generator up front: handing the same integer seed to both
        # rvs calls would restart the stream and correlate the two draws.
        random_state = np.random.default_rng(random_state)
    # sps.gamma.rvs takes shape and scale (not rate), therefore 1.0/b
    lmbd = sps.gamma.rvs(a, scale=1.0 / b, size=N, random_state=random_state)
    # sps.norm.rvs takes mean and std (not variance), therefore 1.0/sqrt(precision)
    mu = sps.norm.rvs(m, 1.0 / np.sqrt(lmbd * kappa), size=N,
                      random_state=random_state)
    return mu, lmbd
def NormalGamma_plot(m0, kappa0, a0, b0, limits):
    """Draw the NormalGamma density as a greyscale image on the current axes.

    Parameters
    ----------
    m0, kappa0, a0, b0 : float
        Parameters of the NormalGamma distribution to display.
    limits : sequence of 4 floats
        ``[mu_min, mu_max, lambda_min, lambda_max]``; used both as the
        evaluation grid bounds and as the image extent.
    """
    # 500 x 500 evaluation grid: mu varies along columns, lambda along rows
    mu_vals, lmbd_vals = np.meshgrid(
        np.linspace(limits[0], limits[1], 500),
        np.linspace(limits[2], limits[3], 500),
    )
    pdf = NormalGamma_pdf(mu_vals, lmbd_vals, m0, kappa0, a0, b0)
    # origin='lower' puts the first grid row (smallest lambda) at the bottom
    plt.imshow(pdf, origin='lower', cmap='Greys', extent=limits)
    # Raw strings: '\m' and '\l' are invalid escape sequences in normal
    # string literals and trigger warnings on newer Python versions.
    plt.xlabel(r'$\mu$')
    plt.ylabel(r'$\lambda$')
# Example prior NormalGamma(m0=0, kappa0=1, a0=1, b0=1), shown as a density map
m0 = 0.0
kappa0 = 1.0
a0 = 1.0
b0 = 1.0
plot_limits = [-6, 6, 0.01, 5]
NormalGamma_plot(m0, kappa0, a0, b0, limits=plot_limits)
# Overlay 50 (mean, precision) pairs drawn from the prior on the density map
mu_sampled, lmbd_sampled = NormalGamma_rvs(m0, kappa0, a0, b0, 50)
plt.plot(mu_sampled, lmbd_sampled, '+')
plt.axis(plot_limits)
# Every sampled pair defines one Gaussian; draw all of them in a new figure
# (std is 1/sqrt(precision); one curve per sampled pair)
plt.figure()
plt.plot(t, sps.norm.pdf(t[:, np.newaxis], loc=mu_sampled, scale=1/np.sqrt(lmbd_sampled)));
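Given observations $\mathbf{x} = [x_1, x_2, \dots, x_N]$ from a Gaussian distribution $\mathcal{N}(x \mid \mu, \lambda^{-1})$ and given the prior $\mathrm{NormalGamma}(\mu, \lambda \mid m_0, \kappa_0, a_0, b_0)$, the posterior distribution over the Gaussian parameters is:
$$p(\mu, \lambda \mid \mathbf{x}) = \mathrm{NormalGamma}(\mu, \lambda \mid m_N, \kappa_N, a_N, b_N),$$
where
$$\kappa_N = \kappa_0 + N, \qquad m_N = \frac{\kappa_0 m_0 + N\bar{x}}{\kappa_N}, \qquad a_N = a_0 + \frac{N}{2}, \qquad b_N = b_0 + \frac{N}{2}\left(s + \frac{\kappa_0(\bar{x} - m_0)^2}{\kappa_0 + N}\right),$$
with the sample mean $\bar{x} = \frac{1}{N}\sum_n x_n$ and the (biased) sample variance $s = \frac{1}{N}\sum_n (x_n - \bar{x})^2$.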
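Note that the prior parameters can be interpreted as follows:
- $m_0$ is the prior estimate of the mean, as if it had been obtained from $\kappa_0$ earlier (pseudo-)observations;
- $a_0$ and $b_0$ describe the prior estimate of the precision, as if it had been obtained from $2a_0$ earlier (pseudo-)observations whose sum of squared deviations from the mean was $2b_0$ (i.e. a prior variance estimate of $b_0/a_0$).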
def NormalGammaPosteriorParams(m0, kappa0, a0, b0, x):
    """Return the NormalGamma posterior parameters given prior and data.

    Uses the sufficient statistics N, sample mean and (biased) sample
    variance of the observations.

    Parameters
    ----------
    m0, kappa0, a0, b0 : float
        Parameters of the NormalGamma prior.
    x : array_like
        Observations; flattened, so every element counts as one observation.
        May be empty, in which case the posterior equals the prior.

    Returns
    -------
    tuple
        ``(mN, kappaN, aN, bN)``, the parameters of the posterior.
    """
    x = np.asarray(x, dtype=float).ravel()
    N = x.size
    kappaN = kappa0 + N
    aN = a0 + 0.5 * N
    if N == 0:
        # np.mean / np.var of an empty array are NaN (with a warning);
        # with no observations the posterior is simply the prior.
        return m0, kappaN, aN, b0
    mean = np.mean(x)
    var = np.var(x)  # biased estimate, so N * var is the sum of squared deviations
    mN = (kappa0 * m0 + mean * N) / kappaN
    bN = b0 + 0.5 * N * (var + kappa0 * (mean - m0) ** 2 / kappaN)
    return mN, kappaN, aN, bN
def NormalGammaPosteriorParams2(m0, kappa0, a0, b0, x):
    """Return the NormalGamma posterior parameters from raw power sums.

    Alternative, more practical implementation based on the first- and
    second-order sums of the observations; it also works for N = 0, where
    it returns the prior parameters.

    Parameters
    ----------
    m0, kappa0, a0, b0 : float
        Parameters of the NormalGamma prior.
    x : array_like
        1-D observations (NumPy array, list or tuple); may be empty.

    Returns
    -------
    tuple
        ``(mN, kappaN, aN, bN)``, the parameters of the posterior.
    """
    # Coerce to an array so that plain Python sequences support x**2
    x = np.asarray(x, dtype=float)
    N = len(x)
    f = np.sum(x)       # first-order statistic: sum of observations
    s = np.sum(x ** 2)  # second-order statistic: sum of squared observations
    kappaN = kappa0 + N
    mN = (kappa0 * m0 + f) / kappaN
    aN = a0 + 0.5 * N
    bN = b0 + 0.5 * (s + kappa0 * m0 ** 2 - kappaN * mN ** 2)
    return mN, kappaN, aN, bN
# Generate observations from a given Gaussian distribution
N, mu, sigma = 5, 1.0, 1  # number of observations, mean, standard deviation (std)
x = sps.norm.rvs(loc=mu, scale=sigma, size=N)
# Combine the NormalGamma prior with the observations, then plot the
# posterior distribution of the Gaussian parameters
mN, kappaN, aN, bN = NormalGammaPosteriorParams2(m0, kappa0, a0, b0, x)
NormalGamma_plot(mN, kappaN, aN, bN, limits=plot_limits)
# Overlay 50 (mean, precision) pairs drawn from the posterior
mu_sampled, lmbd_sampled = NormalGamma_rvs(mN, kappaN, aN, bN, 50)
plt.plot(mu_sampled, lmbd_sampled, '+')
plt.axis(plot_limits)
# Every sampled pair defines one Gaussian (std is 1/sqrt(precision));
# evaluate all of them on the grid t and draw them in a new figure
pdfs_sampled = sps.norm.pdf(t[:, np.newaxis], loc=mu_sampled, scale=1/np.sqrt(lmbd_sampled))
plt.figure()
plt.plot(t, pdfs_sampled);
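For the Gaussian with NormalGamma prior, the posterior predictive distribution is the Student's t-distribution:
$$p(x' \mid \mathbf{x}) = \mathrm{St}(x' \mid m_N, \nu, \sigma^2),$$
where:
$$\nu = 2a_N \quad \text{(degrees of freedom)}, \qquad \sigma^2 = \frac{b_N(\kappa_N + 1)}{a_N \kappa_N} \quad \text{(squared scale)}$$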
# Reference curve (thick red): the Gaussian the observations were generated from
plt.plot(t, sps.norm.pdf(t, loc=mu, scale=sigma), 'r', linewidth=3)
# Posterior predictive distribution (thin black): Student's t with 2*aN degrees
# of freedom, location mN and scale sqrt(bN*(kappaN+1)/(aN*kappaN)); it depends on
# the prior and the observations through the posterior over \mu and \lambda
plt.plot(t, sps.t.pdf(t, df=2*aN, loc=mN, scale=np.sqrt(bN*(kappaN+1)/aN/kappaN)), 'k', linewidth=1)
# Monte Carlo check: average the sampled Gaussian densities from the previous
# figure; with many samples this should approximate the predictive distribution well
plt.plot(t, np.mean(pdfs_sampled, axis=1));