In [1]:
import numpy as np
import scipy.stats as sps
import matplotlib.pyplot as plt
from scipy.special import logsumexp
#%matplotlib qt5
%matplotlib inline
plt.rcParams.update({'figure.figsize': (10.0, 6.0), 'font.size': 18})

Gaussian Mixture Model (GMM)

  • probability density function
    $$p(x \mid \{\mu_c\}, \{\sigma_c^2\}, \{\pi_c\}) = \sum_{c=1}^{C} \mathcal{N}(x \mid \mu_c, \sigma_c^2)\,\pi_c$$

where

  • $\{\mu_c\}$ is the set of C means
  • $\{\sigma_c^2\}$ is the set of C variances
  • $\{\pi_c\}$ is the set of C weights such that $\sum_{c=1}^{C} \pi_c = 1$

and the univariate Gaussian distribution

$$p(x \mid \mu, \sigma^2) = \mathcal{N}(x \mid \mu, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}}\exp\left\{-\frac{(x-\mu)^2}{2\sigma^2}\right\}$$
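
As a quick sanity check, the formula can be evaluated directly in NumPy and compared against scipy's implementation (a minimal sketch; mu, sigma2 and the evaluation points below are arbitrary choices for this check only). Note that sps.norm.pdf expects the standard deviation, hence the np.sqrt(sigma2).

# Evaluate N(x | mu, sigma^2) straight from the formula and compare with sps.norm.pdf
mu, sigma2 = 0.0, 1.4            # arbitrary example parameters
xs = np.linspace(-3, 3, 7)
pdf_manual = np.exp(-(xs - mu)**2 / (2 * sigma2)) / np.sqrt(2 * np.pi * sigma2)
print(np.allclose(pdf_manual, sps.norm.pdf(xs, mu, np.sqrt(sigma2))))  # expected: True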
In [2]:
# Plot GMM pdf together with the individual components pdfs; return the GMM pdf line
def plot_GMM(t, mus, sigmas, pis):
  p_xz = sps.norm.pdf(t[:,np.newaxis], mus, sigmas) * pis # all GMM components are evaluated at once
  px = np.sum(p_xz, axis=1)
  plt.plot(t, p_xz, ':')
  plt.plot(t, px, 'k')
  return px


# Handcraft some GMM parameters
mus = [-4.0, 0.0, 4.0, 5]
sigmas = [1.0, 1.4, 1.2, 1]
pis = [0.1, 0.4, 0.2, 0.3]

t = np.linspace(-10,10,1000)
true_GMM_pdf = plot_GMM(t, mus, sigmas, pis)

# Generate N datapoints from this GMM
N = 100
Nc = sps.multinomial.rvs(N, pis) # Draw observation counts for each component from multinomial distribution
x = sps.norm.rvs(np.repeat(mus, Nc), np.repeat(sigmas, Nc))
np.random.shuffle(x)
plt.plot(x, np.zeros_like(x), '+k');
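
The cell above draws per-component observation counts first and then samples each block of observations at once. An equivalent ancestral-sampling sketch (z and x_alt are illustrative names, not used later) draws one component index per observation and then one Gaussian sample from that component:

# Equivalent ancestral sampling: z_n ~ Categorical(pis), then x_n ~ N(mu_{z_n}, sigma_{z_n}^2)
z = np.random.choice(len(pis), size=N, p=pis)
x_alt = sps.norm.rvs(np.array(mus)[z], np.array(sigmas)[z])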

GMM - EM algorithm

  • E-step
$$\gamma_{nc} = P(z_n = c \mid x_n, \eta^{old}) = \frac{p(x_n \mid z_n = c, \eta^{old})\,P(z_n = c \mid \eta^{old})}{p(x_n \mid \eta^{old})} = \frac{\mathcal{N}(x_n \mid \mu_c^{old}, \sigma_c^{2\,old})\,\pi_c^{old}}{\sum_k \mathcal{N}(x_n \mid \mu_k^{old}, \sigma_k^{2\,old})\,\pi_k^{old}}$$
  • M-step
$$\mu_c^{new} = \frac{\sum_n \gamma_{nc}\, x_n}{\sum_n \gamma_{nc}}, \qquad \sigma_c^{2\,new} = \frac{\sum_n \gamma_{nc}\,(x_n - \mu_c^{new})^2}{\sum_n \gamma_{nc}}, \qquad \pi_c^{new} = \frac{\sum_n \gamma_{nc}}{N}$$
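
Before running the full EM loop, here is a minimal sketch of the E-step alone, evaluated under the handcrafted parameters from above (it assumes x, mus, sigmas and pis as defined in the previous cell). The responsibilities γ_nc are posteriors P(z_n=c|x_n), so every row of the resulting matrix must sum to one:

# E-step only: joint log-densities log N(x_n | mu_c, sigma_c^2) + log pi_c for all n, c at once
log_p_xz = sps.norm.logpdf(x[:, np.newaxis], mus, sigmas) + np.log(pis)
log_p_x = logsumexp(log_p_xz, axis=1, keepdims=True)   # log p(x_n) = logsumexp over components
gammas = np.exp(log_p_xz - log_p_x)                    # responsibilities gamma_nc
print(np.allclose(gammas.sum(axis=1), 1.0))            # expected: True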
In [3]:
#Choose some initial parameters
C = 3        # number of GMM components 
mus = x[:C]  # we choose the first few observations as the initial means
sigmas = np.repeat(np.std(x), C) # sigma for all components is set to the std of the training data
pis = np.ones(C)/C

plt.clf()
plt.plot(t, true_GMM_pdf, 'gray')
plot_GMM(t, mus, sigmas, pis);
In [4]:
for _ in range(50):
  #E-step
  log_p_xz = sps.norm.logpdf(x[:,np.newaxis], mus, sigmas) + np.log(pis)
  log_p_x = logsumexp(log_p_xz, axis=1, keepdims=True)
  print "Training data log likelihood:", log_p_x.sum()

  gammas = np.exp(log_p_xz - log_p_x)
  #M-step
  Nc = gammas.sum(axis=0)
  mus =  x.dot(gammas) / Nc
  sigmas = np.sqrt((x**2).dot(gammas) / Nc - mus**2) # sqrt of E[x^2] - mu^2; note we store std, not variance!
  pis = Nc / Nc.sum()
    
plot_GMM(t, mus, sigmas, pis)

plt.clf()
plt.plot(t, true_GMM_pdf, 'gray')
plot_GMM(t, mus, sigmas, pis);
plt.plot(x, np.zeros_like(x), '+k');
Training data log likelihood: -275.687446231
Training data log likelihood: -257.756537213
Training data log likelihood: -256.73326682
Training data log likelihood: -255.418278557
Training data log likelihood: -253.868553645
Training data log likelihood: -252.588253235
Training data log likelihood: -251.9249919
Training data log likelihood: -251.682616958
Training data log likelihood: -251.604724967
Training data log likelihood: -251.57672686
Training data log likelihood: -251.562926418
Training data log likelihood: -251.553611058
Training data log likelihood: -251.546153868
Training data log likelihood: -251.539755108
Training data log likelihood: -251.534101033
Training data log likelihood: -251.529025116
Training data log likelihood: -251.52441792
Training data log likelihood: -251.520199681
Training data log likelihood: -251.516309709
Training data log likelihood: -251.51270094
Training data log likelihood: -251.509336448
Training data log likelihood: -251.506186972
Training data log likelihood: -251.503229061
Training data log likelihood: -251.500443693
Training data log likelihood: -251.497815217
Training data log likelihood: -251.495330568
Training data log likelihood: -251.492978667
Training data log likelihood: -251.490749977
Training data log likelihood: -251.488636164
Training data log likelihood: -251.486629848
Training data log likelihood: -251.484724409
Training data log likelihood: -251.482913848
Training data log likelihood: -251.48119267
Training data log likelihood: -251.479555809
Training data log likelihood: -251.47799856
Training data log likelihood: -251.476516528
Training data log likelihood: -251.475105591
Training data log likelihood: -251.473761871
Training data log likelihood: -251.472481707
Training data log likelihood: -251.471261635
Training data log likelihood: -251.470098376
Training data log likelihood: -251.468988817
Training data log likelihood: -251.467930006
Training data log likelihood: -251.466919136
Training data log likelihood: -251.465953544
Training data log likelihood: -251.465030697
Training data log likelihood: -251.464148189
Training data log likelihood: -251.463303737
Training data log likelihood: -251.462495171
Training data log likelihood: -251.461720434
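
As a possible follow-up (a sketch assuming x, mus, sigmas and pis still hold the values from the loop above), EM can be run to convergence by stopping once the log-likelihood improvement falls below a tolerance instead of using a fixed number of iterations; the tolerance value here is an arbitrary choice:

# Continue EM until the training log likelihood stops improving (hypothetical tolerance)
prev_ll, tol = -np.inf, 1e-6
while True:
  log_p_xz = sps.norm.logpdf(x[:, np.newaxis], mus, sigmas) + np.log(pis)
  log_p_x = logsumexp(log_p_xz, axis=1, keepdims=True)
  ll = log_p_x.sum()
  if ll - prev_ll < tol:   # EM never decreases the likelihood, so this eventually triggers
    break
  prev_ll = ll
  gammas = np.exp(log_p_xz - log_p_x)   # E-step: responsibilities
  Nc = gammas.sum(axis=0)               # M-step: effective counts per component
  mus = x.dot(gammas) / Nc
  sigmas = np.sqrt((x**2).dot(gammas) / Nc - mus**2)
  pis = Nc / Nc.sum()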