Skip to content

Instantly share code, notes, and snippets.

@whitews
Last active August 29, 2015 14:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save whitews/c350a0a4a69e4921d908 to your computer and use it in GitHub Desktop.
Save whitews/c350a0a4a69e4921d908 to your computer and use it in GitHub Desktop.
Find modes using Expectation-Maximization of a Gaussian Mixture Model
import numpy as np
import matplotlib.pyplot as plt
from utils import get_gmm_modes
# Sample 1-D integer measurements used to demonstrate mode finding.
# Most values appear to fall roughly near 0-250 and 3500-5000, with a few
# larger outliers (e.g. 8100, 12646, 13609); the script below fits a
# 2-component Gaussian mixture to locate those groups.
data = np.array([3643, 4699, 91, 144, 13, 3662, 3870, 4360, 90,
62, 8100, -48, 3679, 4479, 115, 4105, 69, -22,
-55, 4339, 87, -24, 4273, 142, 4956, 4256, 4162,
134, 4090, 86, 4100, 4366, 150, 4600, -63, 4346,
136, 4531, 3966, 4394, 67, 58, -48, 3793, 4239,
4239, 88, 4514, 99, 4399, 4212, 144, 12, 3768,
1731, 4381, 90, 4669, 1, 33, -12, 8865, 4095,
96, -55, 58, 61, 163, -18, -98, 4218, 4251,
120, 4028, 4369, 4339, 4366, 82, -6, 36, 3952,
99, 138, 32, 4274, 4390, 132, 132, 4498, 62,
4242, 164, 22, 220, 4723, 4534, 82, 175, 4150,
121, 153, 103, 144, -4, 122, 3931, 4747, -85,
4677, 3692, 4682, 98, 4143, 10, -36, 217, 92,
76, 3819, 33, 188, 52, 3859, 144, 157, -20,
243, 243, 3973, 4080, 4810, 4232, 3997, 63, 4540,
4424, 4309, 15, 4152, 3561, 114, 3625, 3946, 74,
93, 4339, 72, 4412, 70, 208, 162, 4621, 225,
4249, 3992, 93, 66, 4231, 3870, 146, 222, 222,
3836, 109, 34, 3631, 198, 4329, 30, 4330, 4378,
4378, 40, 3859, 4381, 146, 73, 4767, -16, 157,
4304, 4230, 4285, 3931, 32, 4419, 4623, 4174, 222,
4096, 4230, 114, 145, -54, 14, 4248, 4248, 4280,
37, 3903, 3804, 4071, 234, 4818, 49, 4460, 156,
4176, 4384, 97, 142, 4246, 138, 3564, 5096, 4081,
3806, 4435, 110, 24, 4358, 163, 76, 72, 4191,
4191, 60, 4302, 4302, 216, 13, 208, 80, 67,
4176, 49, 4090, 106, 66, 4456, 4456, 122, -51,
212, 3817, 55, 48, 118, 146, 99, 3124, 200,
105, 105, 130, 163, 4076, 154, 4131, 132, 4582,
4582, 26, 4447, 4182, 78, 70, 3602, -16, 105,
4152, 100, 12646, 164, 4500, 4466, 254, 231, 3592,
3592, 3883, 4714, 152, 4257, 7, 160, 69, 4111,
79, 4106, 4670, 236, 135, 4173, 4874, 201, 93,
42, 4519, 103, 93, 4164, 75, 4989, 4194, 4089,
92, 8, 4126, -7, 208, 4411, 4189, 4582, 114,
46, 43, 216, 51, 4236, 4105, 3877, -8, 4369,
99, 37, 132, 4312, 4312, 66, 68, 80, 8,
4036, 4173, 164, 8337, 128, 3754, 87, 196, 4230,
104, -40, 241, 4357, 4340, 139, -6, 86, 97,
121, 123, -144, 64, 5050, 4921, 136, -81, 4285,
4153, 3843, 6, 31, 4036, 3859, 136, 54, 4700,
32, 4228, 3943, 4346, 3610, 60, 211, 133, 224,
218, 3942, 4719, -37, 3606, 79, 43, 218, 256,
4632, 4291, 123, 4537, 154, 3756, 12864, 3966, 4576,
4262, 231, 231, 123, 85, 4090, 122, 105, 69,
4405, 134, 30, 3889, 121, 230, 3, 4417, -68,
123, 123, 8179, -1, 3549, 13609, 56, 3975, 1281,
4809, 28, 4098, 91, 4503, 169, 165, 4296, 4549,
4874, -37, 70, 4327, 6, 45, -30, 98, 175,
4488, 4198, 114, -24, 168, 4438, 3861, 84, 4321,
4507, 93, 4123, 4422, 3870, 96, 120, -25, 4380,
-73, 91, 4239, 4130, 3484, 152, 169, 3024, 183,
44, 151, 4646, 24, 4131, 1800, 4414, 4232, 4789,
4276, 4276, 56, 33, -60, 128, 132, 142, 4719,
4719, 4272, 4401, 4401, 67, 63, 180, 4135, 4618,
48, 6, 121, 73, 134])
# Summary statistics of the sample. print(...) with a single argument
# behaves the same on Python 2 and Python 3.
print("Length: " + str(len(data)))
print("Mean: " + str(data.mean()))
print("Max: " + str(data.max()))
print("Min: " + str(data.min()))

# Initial guess of parameters and initializations
# doesn't have to be a good guess!
max_iterations = 1000
cluster_count = 2

### Example optional parameters
# weights = [0.5, 0.5]
# mu_list = [data.min() * .33, data.max() * .66]
# sigma_list = [400, 800]

(gmm_pi, gmm_mu, gmm_sigma) = get_gmm_modes(
    data,
    max_iteration=max_iterations,
    # trailing comma so the optional kwargs below can simply be un-commented
    cluster_count=cluster_count,
    # cluster_weights=weights,
    # mu_list=mu_list,
    # sigma_list=sigma_list
)

print("Pi: " + str(gmm_pi))
print("Mu: " + str(gmm_mu))
print("Sigma: " + str(gmm_sigma))

plt.figure(num=None, figsize=(16, 8))
# bins must be an integer; plain '/' yields a float on Python 3
hist = plt.hist(data, bins=len(data) // 2, color='steelblue')
# mark each fitted mean with a dashed vertical line
for m in gmm_mu:
    plt.axvline(m, color='orange', linestyle='dashed', linewidth=2)
plt.show()
import numpy as np
def parametrized_normal_pdf(data, mu, sigma):
    """
    Evaluate the normal probability density with mean ``mu`` and standard
    deviation ``sigma`` at ``data``.

    :param data: scalar or NumPy array of points (evaluated element-wise)
    :param mu: mean of the distribution
    :param sigma: standard deviation; a zero value divides by zero
    :return: density value(s), same shape as ``data``
    """
    z_score = (data - mu) / sigma
    normaliser = sigma * np.sqrt(2 * np.pi)
    return np.exp(-0.5 * z_score ** 2) / normaliser
def gaussian_mixture_model(data, weights, mu_list, sigma_list):
    """
    Evaluate a Gaussian mixture density at ``data``.

    Component ``k`` contributes
    ``weights[k] * parametrized_normal_pdf(data, mu_list[k], sigma_list[k])``.

    :param data: scalar or NumPy array of points
    :param weights: per-component mixing weights
    :param mu_list: per-component means (indexed in step with ``weights``)
    :param sigma_list: per-component standard deviations
    :return: the summed density; 0 when ``weights`` is empty
    """
    return sum(
        w * parametrized_normal_pdf(data, mu_list[k], sigma_list[k])
        for k, w in enumerate(weights)
    )
def get_gmm_modes(
        data,
        max_iteration,
        cluster_count,
        cluster_weights=None,
        mu_list=None,
        sigma_list=None,
        tolerance=1e-20):
    """
    Fit a 1-D Gaussian mixture model to ``data`` using
    Expectation-Maximization and return the fitted parameters.

    :param data: 1-dimensional array (or sequence) of numeric values
    :param max_iteration: Maximum number of iterations. Note: The actual
        number of iterations may be less if convergence is found.
    :param cluster_count: Number of clusters to attempt to fit
    :param cluster_weights: (Optional) List of cluster weights to use for
        initial values. Length must match cluster_count and the weights
        must sum to 1.0 (within floating-point tolerance). If not
        specified, the default will be equal weighting (1 / cluster_count)
    :param mu_list: (Optional) List of means to use for initial values. If
        not specified, the default is the left edges of the cluster_count
        most populated bins of a histogram with roughly one bin per unit of
        data range (i.e. the most common values).
    :param sigma_list: (Optional) List of widths (standard deviations) to
        use for initial values. If not specified, the default will be the
        data range divided by the cluster_count.
        For example, with a cluster_count of 2, and a data range
        from 0 to 100, the initial widths will all be 50.
    :param tolerance: (Optional) Convergence threshold: iteration stops
        once the summed absolute change of all parameters over the last 5
        iterations falls below this value. Defaults to 1e-20.
    :return: tuple (pi_list, mu_list, sigma_list) of fitted weights, means
        and standard deviations, one entry per cluster
    :raises ValueError: if the weights do not sum to 1.0, or if any
        parameter list length differs from cluster_count
    """
    # accept any sequence; float dtype keeps the arithmetic in floating point
    data = np.asarray(data, dtype=float)

    # input validation / defaults
    if cluster_weights is not None and len(cluster_weights) > 0:
        # exact == on floats would reject valid inputs such as [0.1] * 10
        if not np.isclose(sum(cluster_weights), 1.0):
            raise ValueError("Sum of cluster weights must equal 1.0")
    else:
        cluster_weights = [1.0 / cluster_count] * cluster_count

    if mu_list is None or len(mu_list) == 0:
        # np.histogram needs an integer bin count >= 1; the raw range is a
        # float for float data and 0 for constant data
        bin_count = max(int(np.ceil(data.max() - data.min())), 1)
        counts, edges = np.histogram(data, bins=bin_count)
        mu_list = []
        for _ in range(cluster_count):
            max_index = counts.argmax()
            mu_list.append(edges[max_index])
            # set max index to zero to find the next most common index
            counts[max_index] = 0

    if sigma_list is None or len(sigma_list) == 0:
        width = (data.max() - data.min()) / float(cluster_count)
        if width == 0:
            # constant data: a zero width would divide by zero in the pdf
            width = 0.0000001
        sigma_list = [width] * cluster_count

    if len(cluster_weights) != cluster_count:
        raise ValueError("Length of cluster weights != cluster count")
    if len(mu_list) != cluster_count:
        raise ValueError("Length of mu list != cluster count")
    if len(sigma_list) != cluster_count:
        raise ValueError("Length of sigma list != cluster count")

    pi_list = list(cluster_weights)
    mu_list = list(mu_list)
    sigma_list = list(sigma_list)

    # rolling diffs of parameters for convergence test; seeded with 1s so
    # at least `window` iterations run before convergence can be declared
    window = 5
    rolling_pis = [1] * window
    rolling_mus = [1] * window
    rolling_sigmas = [1] * window

    counter = 0
    converged = False
    while not converged:
        counter += 1
        tmp_pi_list = []
        tmp_mu_list = []
        tmp_sigma_list = []

        # the mixture density does not depend on the cluster: compute once
        total = gaussian_mixture_model(data, pi_list, mu_list, sigma_list)

        # Compute responsibility function and new parameters
        for i in range(cluster_count):
            weighted = pi_list[i] * parametrized_normal_pdf(
                data, mu_list[i], sigma_list[i])
            # May have zeroes in divisor (0/0 -> nan), ignore invalid warning
            with np.errstate(invalid='ignore'):
                gamma = weighted / total
            gamma = np.nan_to_num(gamma)
            n_i = 1. * gamma.sum()

            if n_i == 0:
                mu_i = 0
                sigma_i = 0
            else:
                mu_i = np.sum(gamma * data) / n_i
                sigma_i = np.sqrt(np.sum(gamma * (data - mu_i) ** 2) / n_i)

            # the spread may collapse to zero for non-existent clusters,
            # which causes a divide by zero issue, so reset to small value
            if sigma_i == 0:
                sigma_i = 0.0000001

            tmp_mu_list.append(mu_i)
            tmp_sigma_list.append(sigma_i)
            tmp_pi_list.append(n_i / data.size)

        # Check for maximum iteration or convergence
        if counter >= max_iteration:
            converged = True
        else:
            slot = (counter - 1) % window
            rolling_pis[slot] = np.sum(
                np.abs(np.array(pi_list) - np.array(tmp_pi_list)))
            rolling_mus[slot] = np.sum(
                np.abs(np.array(mu_list) - np.array(tmp_mu_list)))
            rolling_sigmas[slot] = np.sum(
                np.abs(np.array(sigma_list) - np.array(tmp_sigma_list)))
            if sum(rolling_pis + rolling_mus + rolling_sigmas) < tolerance:
                converged = True

        pi_list = tmp_pi_list
        mu_list = tmp_mu_list
        sigma_list = tmp_sigma_list

    return pi_list, mu_list, sigma_list
@dleehr
Copy link

dleehr commented May 5, 2014

On line 3 of the example script (the import of `get_gmm_modes`),

from utils import get_gmm_modes

should read

from gmm_modes import get_gmm_modes

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment