In [1]:
import numpy as np
import scipy.stats as sps
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams.update({'figure.figsize': (10.0, 6.0), 'font.size': 18})

# x-axis for our plots
t = np.linspace(-10, 10, 1000)

# Seed NumPy's global random state so that "Restart & Run All" reproduces every figure:
# scipy.stats .rvs() calls made without an explicit random_state draw from this global state.
SEED = 0
np.random.seed(SEED)

Bayesian inference for the Gaussian distribution

Gaussian distribution parametrized by the mean $\mu$ and the precision $\lambda = 1/\sigma^2$ rather than the variance $\sigma^2$:

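$$p(x \mid \mu,\lambda) = \mathcal{N}(x \mid \mu,\lambda^{-1}) = \sqrt{\frac{\lambda}{2\pi}}\,\exp\left\{-\frac{\lambda (x-\mu)^2}{2}\right\}$$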

The Normal-Gamma distribution can be used as the conjugate prior over the mean $\mu$ and precision $\lambda$ of the Gaussian distribution:

$$\mathrm{NormalGamma}(\mu,\lambda \mid m,\kappa,a,b) = \mathcal{N}\left(\mu \mid m,(\kappa\lambda)^{-1}\right)\,\mathrm{Gam}(\lambda \mid a,b)$$

where the gamma distribution with shape $a$ and rate $b$ is

$$\mathrm{Gam}(\lambda \mid a,b) = \frac{1}{\Gamma(a)}\, b^{a} \lambda^{a-1} \exp\{-b\lambda\}$$
In [2]:
def NormalGamma_pdf(mu, lmbd, m, kappa, a, b):
    """Density of NormalGamma(mu, lmbd | m, kappa, a, b).

    NormalGamma(mu, lmbd | m, kappa, a, b) = N(mu | m, (kappa*lmbd)^-1) * Gam(lmbd | a, b),
    where b is the *rate* of the gamma distribution. Evaluated element-wise, so
    mu and lmbd may be arrays (e.g. a meshgrid).
    """
    # scipy's norm is parametrised by mean and standard deviation, so the
    # precision kappa*lmbd is turned into std = 1/sqrt(kappa*lmbd).
    std = 1.0 / np.sqrt(lmbd * kappa)
    # scipy's gamma takes shape and *scale*; the rate b becomes scale = 1/b.
    gamma_density = sps.gamma.pdf(lmbd, a, scale=1.0 / b)
    return sps.norm.pdf(mu, m, std) * gamma_density

def NormalGamma_rvs(m, kappa, a, b, N, random_state=None):
    """Draw N joint samples (mu, lmbd) from NormalGamma(m, kappa, a, b).

    Sampling is hierarchical: first lmbd ~ Gam(a, rate=b), then
    mu | lmbd ~ N(m, (kappa*lmbd)^-1).

    Parameters
    ----------
    m, kappa, a, b : float
        NormalGamma parameters (b is the gamma *rate*).
    N : int
        Number of samples.
    random_state : None, int, numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. None (default) uses NumPy's global random state, as
        before; an int seeds a fresh Generator so the draws are reproducible.

    Returns
    -------
    mu, lmbd : ndarray of shape (N,)
        Sampled means and precisions.
    """
    if isinstance(random_state, (int, np.integer)):
        # Build ONE generator shared by both draws; passing the same int to both
        # calls would create two identically-seeded (correlated) streams.
        random_state = np.random.default_rng(random_state)
    # scipy's gamma takes scale = 1/rate.
    lmbd = sps.gamma.rvs(a, scale=1.0/b, size=N, random_state=random_state)
    # scipy's norm takes the standard deviation = 1/sqrt(precision).
    mu = sps.norm.rvs(m, 1.0/np.sqrt(lmbd*kappa), size=N, random_state=random_state)
    return mu, lmbd

def NormalGamma_plot(m0, kappa0, a0, b0, limits, n_grid=500):
    """Show the NormalGamma(m0, kappa0, a0, b0) density as a greyscale image.

    Parameters
    ----------
    m0, kappa0, a0, b0 : float
        NormalGamma parameters.
    limits : sequence of 4 floats
        [mu_min, mu_max, lambda_min, lambda_max]; also used as the image extent.
        lambda_min should be > 0, because the normal's std 1/sqrt(kappa*lambda)
        is infinite at lambda = 0.
    n_grid : int, optional
        Number of grid points along each axis (default 500).
    """
    mu_min, mu_max, lmbd_min, lmbd_max = limits
    mu_vals, lmbd_vals = np.meshgrid(np.linspace(mu_min, mu_max, n_grid),
                                     np.linspace(lmbd_min, lmbd_max, n_grid))
    pdf = NormalGamma_pdf(mu_vals, lmbd_vals, m0, kappa0, a0, b0)
    # origin='lower' puts the smallest lambda at the bottom, matching the y-axis.
    plt.imshow(pdf, origin='lower', cmap='Greys', extent=limits)
    # Raw strings: '\m' and '\l' are invalid escape sequences in ordinary string
    # literals (SyntaxWarning on Python 3.12+).
    plt.xlabel(r'$\mu$'); plt.ylabel(r'$\lambda$')


# Plot an example NormalGamma prior distribution
m0, kappa0, a0, b0 = 0.0, 1.0, 1.0, 1.0
plot_limits = [-6, 6, 0.01, 5]
NormalGamma_plot(m0, kappa0, a0, b0, limits=plot_limits)

# Obtain a few samples from the prior distribution and overlay them on the density.
mu_sampled, lmbd_sampled = NormalGamma_rvs(m0, kappa0, a0, b0, 50)
plt.plot(mu_sampled, lmbd_sampled, '+'); plt.axis(plot_limits)
plt.title('NormalGamma prior and samples')

# Each sample is a (mean, precision) pair; plot the corresponding Gaussian densities.
# sps.norm.pdf takes the std, hence 1/sqrt(precision). Broadcasting t[:, np.newaxis]
# against the sample vectors gives one column per sample, and plt.plot draws each
# column as a separate curve.
plt.figure()
plt.plot(t, sps.norm.pdf(t[:, np.newaxis], mu_sampled, 1/np.sqrt(lmbd_sampled)))
plt.xlabel('$x$'); plt.ylabel('$p(x)$'); plt.title('Gaussians for the prior samples');

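Given $N$ observations $\mathbf{x} = [x_1, x_2, \dots, x_N]$ from the Gaussian distribution $\mathcal{N}(x \mid \mu,\lambda^{-1})$ and the prior $p(\mu,\lambda) = \mathrm{NormalGamma}(\mu,\lambda \mid m_0,\kappa_0,a_0,b_0)$, the posterior distribution over the Gaussian parameters is: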

$$p(\mu,\lambda \mid \mathbf{x}) = \mathrm{NormalGamma}(\mu,\lambda \mid m_N,\kappa_N,a_N,b_N),$$

where

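$$\begin{aligned}
m_N &= \frac{\kappa_0 m_0 + N\bar{x}}{\kappa_0 + N} \\
\kappa_N &= \kappa_0 + N \\
a_N &= a_0 + \frac{N}{2} \\
b_N &= b_0 + \frac{N}{2}\left(s + \frac{\kappa_0 (\bar{x} - m_0)^2}{\kappa_0 + N}\right) \\
\bar{x} &= \frac{1}{N}\sum_i x_i \\
s &= \frac{1}{N}\sum_i (x_i - \bar{x})^2
\end{aligned}$$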

Note that the prior parameters can be interpreted as follows:

  • $2a_0$ – prior number of observations for the precision (or variance)
  • $b_0/a_0$ – prior variance (around $m_0$)
  • $\kappa_0$ – prior number of observations for the mean
  • $m_0$ – prior mean
In [11]:
def NormalGammaPosteriorParams(m0, kappa0, a0, b0, x):
    """Posterior NormalGamma parameters after observing x under a NormalGamma prior.

    Uses the sufficient statistics N (count), mean and biased (1/N) variance of x.

    Parameters
    ----------
    m0, kappa0, a0, b0 : float
        Prior NormalGamma parameters.
    x : array_like
        Observations; flattened to 1-D.

    Returns
    -------
    mN, kappaN, aN, bN : float
        Posterior parameters. For empty x the prior parameters are returned unchanged.
    """
    x = np.ravel(np.asarray(x, dtype=float))
    N = x.size
    if N == 0:
        # No data: the posterior is the prior (np.mean/np.var of an empty array give nan).
        return m0, kappa0, a0, b0
    mean = np.mean(x)
    var = np.var(x)  # biased (ddof=0) variance, i.e. s in the update formulas

    kappaN = kappa0 + N
    mN = (kappa0*m0 + N*mean) / kappaN
    aN = a0 + 0.5*N
    bN = b0 + 0.5*N*(var + kappa0*(mean - m0)**2 / kappaN)
    return mN, kappaN, aN, bN

def NormalGammaPosteriorParams2(m0, kappa0, a0, b0, x):
    """Posterior NormalGamma parameters computed from sum(x) and sum(x**2).

    An alternative to the mean/variance formulation that needs no special case
    for N = 0: empty x simply yields the prior parameters.

    Parameters
    ----------
    m0, kappa0, a0, b0 : float
        Prior NormalGamma parameters.
    x : array_like
        Observations (any shape; all elements are used).

    Returns
    -------
    mN, kappaN, aN, bN : float
        Posterior parameters.
    """
    # asarray so plain Python lists work too (list ** 2 raises TypeError).
    x = np.asarray(x, dtype=float)
    N = x.size
    f = np.sum(x)       # sum of observations
    s = np.sum(x**2)    # sum of squared observations

    kappaN = kappa0 + N
    mN = (kappa0*m0 + f) / kappaN
    aN = a0 + 0.5*N
    # Algebraically equal to b0 + N/2*(var + kappa0*(mean - m0)**2/kappaN). Working
    # from raw sums can lose precision (cancellation) when |mean| is large compared
    # with the spread of x.
    bN = b0 + 0.5*(s + kappa0*m0**2 - kappaN*mN**2)
    return mN, kappaN, aN, bN


# Generate observations from a given Gaussian distribution
N = 5
mu = 1.0     # mean
sigma = 1.0  # standard deviation (std)
x = sps.norm.rvs(mu, sigma, N)

# Given the NormalGamma prior and the observations, calculate and plot the posterior
# distribution over the Gaussian parameters.
mN, kappaN, aN, bN = NormalGammaPosteriorParams2(m0, kappa0, a0, b0, x)
NormalGamma_plot(mN, kappaN, aN, bN, limits=plot_limits)

# Obtain a few samples from the posterior distribution and overlay them on the density.
mu_sampled, lmbd_sampled = NormalGamma_rvs(mN, kappaN, aN, bN, 50)
plt.plot(mu_sampled, lmbd_sampled, '+'); plt.axis(plot_limits)
plt.title('NormalGamma posterior and samples')

# Each sample is a (mean, precision) pair; plot the corresponding Gaussian densities
# (one column of pdfs_sampled per sample; reused below for the predictive comparison).
pdfs_sampled = sps.norm.pdf(t[:, np.newaxis], mu_sampled, 1/np.sqrt(lmbd_sampled))
plt.figure()
plt.plot(t, pdfs_sampled)
plt.xlabel('$x$'); plt.ylabel('$p(x)$'); plt.title('Gaussians for the posterior samples');

Predictive probability

For the Gaussian with a Normal-Gamma prior, the posterior predictive distribution is Student's t-distribution:

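$$p(x \mid \mathbf{x}) = \iint p(x \mid \mu,\lambda)\, p(\mu,\lambda \mid \mathbf{x})\, d\mu\, d\lambda = \iint \mathcal{N}(x \mid \mu,\lambda^{-1})\, \mathrm{NormalGamma}(\mu,\lambda \mid m_N,\kappa_N,a_N,b_N)\, d\mu\, d\lambda = \mathrm{St}\left(x \,\middle|\, m_N,\ 2a_N,\ \frac{a_N \kappa_N}{b_N(\kappa_N + 1)}\right)$$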

where:

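$$\mathrm{St}(x \mid \mu,\nu,\gamma) = \frac{\Gamma\left(\frac{\nu}{2}+\frac{1}{2}\right)}{\Gamma\left(\frac{\nu}{2}\right)} \left(\frac{\gamma}{\pi\nu}\right)^{\frac{1}{2}} \left[1 + \frac{\gamma (x-\mu)^2}{\nu}\right]^{-\frac{\nu}{2}-\frac{1}{2}}$$

is Student's t-distribution with location $\mu$, $\nu$ degrees of freedom and precision $\gamma$. In `scipy.stats.t` this corresponds to `loc` $=\mu$, `df` $=\nu$ and `scale` $= 1/\sqrt{\gamma}$.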
In [12]:
# Plot the distribution from which the observations were generated
plt.plot(t, sps.norm.pdf(t, mu, sigma), 'r', lw=3, label='true distribution')

# Plot the posterior predictive distribution, which depends on the prior and the observations
# (i.e. on the posterior distribution over the parameters mu and lambda). It is a Student's t
# with location mN, 2*aN degrees of freedom and precision aN*kappaN/(bN*(kappaN+1));
# scipy's t is parametrised by scale = 1/sqrt(precision).
plt.plot(t, sps.t.pdf(t, loc=mN, df=2*aN, scale=np.sqrt(bN*(kappaN+1)/aN/kappaN)), 'k', lw=1,
         label='posterior predictive')

# Average all the sampled Gaussian densities from the previous figure. For a large number
# of samples this is a good Monte Carlo approximation to the predictive distribution.
plt.plot(t, pdfs_sampled.mean(axis=1), label='mean of sampled Gaussians')
plt.xlabel('$x$'); plt.ylabel('$p(x)$'); plt.legend();