Last active
August 29, 2015 14:00
-
-
Save whitews/c350a0a4a69e4921d908 to your computer and use it in GitHub Desktop.
Find modes using Expectation-Maximization of a Gaussian Mixture Model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example: locate the modes of a 1-D sample by fitting a Gaussian Mixture
# Model with Expectation-Maximization, then draw the fitted means as dashed
# vertical lines over a histogram of the data.
import numpy as np
import matplotlib.pyplot as plt

try:
    from gmm_modes import get_gmm_modes
except ImportError:
    # fall back to the module name this example originally imported from
    from utils import get_gmm_modes

data = np.array([3643, 4699, 91, 144, 13, 3662, 3870, 4360, 90,
                 62, 8100, -48, 3679, 4479, 115, 4105, 69, -22,
                 -55, 4339, 87, -24, 4273, 142, 4956, 4256, 4162,
                 134, 4090, 86, 4100, 4366, 150, 4600, -63, 4346,
                 136, 4531, 3966, 4394, 67, 58, -48, 3793, 4239,
                 4239, 88, 4514, 99, 4399, 4212, 144, 12, 3768,
                 1731, 4381, 90, 4669, 1, 33, -12, 8865, 4095,
                 96, -55, 58, 61, 163, -18, -98, 4218, 4251,
                 120, 4028, 4369, 4339, 4366, 82, -6, 36, 3952,
                 99, 138, 32, 4274, 4390, 132, 132, 4498, 62,
                 4242, 164, 22, 220, 4723, 4534, 82, 175, 4150,
                 121, 153, 103, 144, -4, 122, 3931, 4747, -85,
                 4677, 3692, 4682, 98, 4143, 10, -36, 217, 92,
                 76, 3819, 33, 188, 52, 3859, 144, 157, -20,
                 243, 243, 3973, 4080, 4810, 4232, 3997, 63, 4540,
                 4424, 4309, 15, 4152, 3561, 114, 3625, 3946, 74,
                 93, 4339, 72, 4412, 70, 208, 162, 4621, 225,
                 4249, 3992, 93, 66, 4231, 3870, 146, 222, 222,
                 3836, 109, 34, 3631, 198, 4329, 30, 4330, 4378,
                 4378, 40, 3859, 4381, 146, 73, 4767, -16, 157,
                 4304, 4230, 4285, 3931, 32, 4419, 4623, 4174, 222,
                 4096, 4230, 114, 145, -54, 14, 4248, 4248, 4280,
                 37, 3903, 3804, 4071, 234, 4818, 49, 4460, 156,
                 4176, 4384, 97, 142, 4246, 138, 3564, 5096, 4081,
                 3806, 4435, 110, 24, 4358, 163, 76, 72, 4191,
                 4191, 60, 4302, 4302, 216, 13, 208, 80, 67,
                 4176, 49, 4090, 106, 66, 4456, 4456, 122, -51,
                 212, 3817, 55, 48, 118, 146, 99, 3124, 200,
                 105, 105, 130, 163, 4076, 154, 4131, 132, 4582,
                 4582, 26, 4447, 4182, 78, 70, 3602, -16, 105,
                 4152, 100, 12646, 164, 4500, 4466, 254, 231, 3592,
                 3592, 3883, 4714, 152, 4257, 7, 160, 69, 4111,
                 79, 4106, 4670, 236, 135, 4173, 4874, 201, 93,
                 42, 4519, 103, 93, 4164, 75, 4989, 4194, 4089,
                 92, 8, 4126, -7, 208, 4411, 4189, 4582, 114,
                 46, 43, 216, 51, 4236, 4105, 3877, -8, 4369,
                 99, 37, 132, 4312, 4312, 66, 68, 80, 8,
                 4036, 4173, 164, 8337, 128, 3754, 87, 196, 4230,
                 104, -40, 241, 4357, 4340, 139, -6, 86, 97,
                 121, 123, -144, 64, 5050, 4921, 136, -81, 4285,
                 4153, 3843, 6, 31, 4036, 3859, 136, 54, 4700,
                 32, 4228, 3943, 4346, 3610, 60, 211, 133, 224,
                 218, 3942, 4719, -37, 3606, 79, 43, 218, 256,
                 4632, 4291, 123, 4537, 154, 3756, 12864, 3966, 4576,
                 4262, 231, 231, 123, 85, 4090, 122, 105, 69,
                 4405, 134, 30, 3889, 121, 230, 3, 4417, -68,
                 123, 123, 8179, -1, 3549, 13609, 56, 3975, 1281,
                 4809, 28, 4098, 91, 4503, 169, 165, 4296, 4549,
                 4874, -37, 70, 4327, 6, 45, -30, 98, 175,
                 4488, 4198, 114, -24, 168, 4438, 3861, 84, 4321,
                 4507, 93, 4123, 4422, 3870, 96, 120, -25, 4380,
                 -73, 91, 4239, 4130, 3484, 152, 169, 3024, 183,
                 44, 151, 4646, 24, 4131, 1800, 4414, 4232, 4789,
                 4276, 4276, 56, 33, -60, 128, 132, 142, 4719,
                 4719, 4272, 4401, 4401, 67, 63, 180, 4135, 4618,
                 48, 6, 121, 73, 134])

# print(...) with a single argument behaves the same on Python 2 and 3
print("Length: " + str(len(data)))
print("Mean: " + str(data.mean()))
print("Max: " + str(data.max()))
print("Min: " + str(data.min()))

# Initial guess of parameters and initializations
# doesn't have to be a good guess!
max_iterations = 1000
cluster_count = 2

# Optional starting parameters: uncomment these together with the matching
# keyword arguments below to override the defaults chosen by get_gmm_modes.
# weights = [0.5, 0.5]
# mu_list = [data.min() * .33, data.max() * .66]
# sigma_list = [400, 800]

(gmm_pi, gmm_mu, gmm_sigma) = get_gmm_modes(
    data,
    max_iteration=max_iterations,
    cluster_count=cluster_count,
    # cluster_weights=weights,
    # mu_list=mu_list,
    # sigma_list=sigma_list,
)

print("Pi: " + str(gmm_pi))
print("Mu: " + str(gmm_mu))
print("Sigma: " + str(gmm_sigma))

plt.figure(num=None, figsize=(16, 8))
# hist() needs an integer bin count; len(data) / 2 is a float on Python 3
plt.hist(data, bins=len(data) // 2, color='steelblue')
# mark each fitted cluster mean
for m in gmm_mu:
    plt.axvline(m, color='orange', linestyle='dashed', linewidth=2)
plt.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def parametrized_normal_pdf(data, mu, sigma):
    """
    Evaluate the normal (Gaussian) probability density with mean ``mu`` and
    standard deviation ``sigma`` at ``data`` (a scalar or a numpy array,
    evaluated element-wise).
    """
    z_score = (data - mu) / sigma
    normalizer = sigma * np.sqrt(2 * np.pi)
    return np.exp(-0.5 * z_score ** 2) / normalizer
def gaussian_mixture_model(data, weights, mu_list, sigma_list):
    """
    Evaluate a 1-D Gaussian mixture density at ``data``.

    The result is the sum over components ``idx`` of
    ``weights[idx] * parametrized_normal_pdf(data, mu_list[idx],
    sigma_list[idx])``. ``mu_list`` and ``sigma_list`` are indexed in step
    with ``weights``. An empty ``weights`` yields 0.
    """
    return sum(
        component_weight * parametrized_normal_pdf(
            data, mu_list[idx], sigma_list[idx])
        for idx, component_weight in enumerate(weights))
# Floor applied to standard deviations so a collapsed cluster (or constant
# data) never produces a zero divisor in parametrized_normal_pdf.
_MIN_SIGMA = 0.0000001


def _initial_means(data, cluster_count):
    """
    Return ``cluster_count`` starting means: the left edges of the most
    populated histogram bins, most populated first.

    Uses roughly one bin per unit of data range (and at least one bin). The
    bin count is converted to a Python int so float-valued or constant data
    does not make ``np.histogram`` reject the ``bins`` argument.
    """
    bin_count = max(int(np.ceil(data.max() - data.min())), 1)
    counts, edges = np.histogram(data, bins=bin_count)
    means = []
    for _ in range(cluster_count):
        max_index = counts.argmax()
        means.append(edges[max_index])
        # zero the bin so the next pass finds the next most common value
        counts[max_index] = 0
    return means


def _initial_sigmas(data, cluster_count):
    """
    Return ``cluster_count`` identical starting standard deviations equal to
    the data range divided by ``cluster_count`` (floored at ``_MIN_SIGMA``).
    """
    # float() keeps this true division on Python 2 with integer data
    width = (data.max() - data.min()) / float(cluster_count)
    if width == 0:
        width = _MIN_SIGMA
    return [width] * cluster_count


def get_gmm_modes(
        data,
        max_iteration,
        cluster_count,
        cluster_weights=None,
        mu_list=None,
        sigma_list=None,
        tolerance=1e-20):
    """
    Fit a 1-D Gaussian Mixture Model to ``data`` with Expectation-Maximization
    and return the fitted mixture parameters.

    :param data: 1-dimensional data array (anything ``numpy.asarray``
        accepts). Must not be empty.
    :param max_iteration: Maximum number of iterations. Note: The actual
        number of iterations may be less if convergence is found. At least
        one iteration is always run.
    :param cluster_count: Number of clusters to attempt to fit (>= 1)
    :param cluster_weights: (Optional) List of cluster weights to use for
        initial values. Length must match cluster_count and the weights must
        sum to 1.0 (within ``numpy.isclose`` default tolerance). If not
        specified (None or empty), the default is equal weighting
        (1 / cluster_count).
    :param mu_list: (Optional) List of means to use for initial values.
        Length must match cluster_count. If not specified (None or empty),
        the default is the left edges of the cluster_count most populated
        histogram bins, i.e. approximately the most common values.
    :param sigma_list: (Optional) List of widths (standard deviations) to use
        for initial values. Length must match cluster_count. If not
        specified (None or empty), the default is the data range divided by
        the cluster_count. For example, with a cluster_count of 2 and a data
        range from 0 to 100, the initial widths will all be 50.
    :param tolerance: (Optional) Convergence threshold. Iteration stops once
        the summed absolute changes of all parameters over the last 5
        iterations fall below this value. The default (1e-20) is strict
        enough that, in practice, iteration continues until the parameters
        stop changing or max_iteration is reached.
    :return: tuple ``(pi_list, mu_list, sigma_list)`` holding the fitted
        weights, means and standard deviations, one entry per cluster.
    :raises ValueError: if data is empty, cluster_count < 1, the weights do
        not sum to 1.0, or any supplied list length != cluster_count.
    """
    data = np.asarray(data)
    if data.size == 0:
        raise ValueError("data must not be empty")
    if cluster_count < 1:
        raise ValueError("cluster_count must be at least 1")

    # input validation / defaults
    # Explicit None/len checks (instead of truthiness) let callers pass numpy
    # arrays, whose truth value is ambiguous; empty sequences still fall back
    # to the defaults, as before.
    if cluster_weights is not None and len(cluster_weights) > 0:
        # exact float equality would reject valid weights such as [0.1] * 10
        if not np.isclose(sum(cluster_weights), 1.0):
            raise ValueError("Sum of cluster weights must equal 1.0")
    else:
        cluster_weights = [1.0 / cluster_count] * cluster_count
    if mu_list is None or len(mu_list) == 0:
        mu_list = _initial_means(data, cluster_count)
    if sigma_list is None or len(sigma_list) == 0:
        sigma_list = _initial_sigmas(data, cluster_count)

    if len(cluster_weights) != cluster_count:
        raise ValueError("Length of cluster weights != cluster count")
    if len(mu_list) != cluster_count:
        raise ValueError("Length of mu list != cluster count")
    if len(sigma_list) != cluster_count:
        raise ValueError("Length of sigma list != cluster count")

    # Work on copies so the caller's sequences are never aliased.
    pi_list = list(cluster_weights)
    mu_list = list(mu_list)
    sigma_list = list(sigma_list)

    # Rolling window of summed absolute parameter changes, one slot per
    # recent iteration. Seeding with 1 means convergence cannot be declared
    # until the window has been filled with real differences.
    window = 5
    rolling_pis = [1] * window
    rolling_mus = [1] * window
    rolling_sigmas = [1] * window

    counter = 0
    converged = False
    while not converged:
        counter += 1
        # The mixture density (the responsibility denominator) depends only
        # on the current parameters, so compute it once per iteration.
        mixture = gaussian_mixture_model(data, pi_list, mu_list, sigma_list)

        tmp_pi_list = []
        tmp_mu_list = []
        tmp_sigma_list = []
        for i in range(cluster_count):
            # E-step: responsibility of cluster i for each data point
            weighted = pi_list[i] * parametrized_normal_pdf(
                data, mu_list[i], sigma_list[i])
            # May have zeroes in divisor (0 / 0), ignore invalid warning and
            # map the resulting NaNs to 0 below
            with np.errstate(invalid='ignore'):
                gamma = weighted / mixture
            gamma = np.nan_to_num(gamma)

            # M-step: re-estimate cluster i from its responsibilities
            n_i = gamma.sum()
            if n_i == 0:
                mu, sigma = 0, 0
            else:
                mu = np.sum(gamma * data) / n_i
                sigma = np.sqrt(np.sum(gamma * (data - mu) ** 2) / n_i)
            # the spread may collapse to zero for non-existent clusters,
            # which causes a divide by zero issue, so reset to small value
            if sigma == 0:
                sigma = _MIN_SIGMA
            tmp_mu_list.append(mu)
            tmp_sigma_list.append(sigma)
            tmp_pi_list.append(n_i / data.size)

        # Check for maximum iteration or convergence
        if counter >= max_iteration:
            converged = True
        else:
            slot = (counter - 1) % window
            rolling_pis[slot] = np.sum(
                np.abs(np.array(pi_list) - np.array(tmp_pi_list)))
            rolling_mus[slot] = np.sum(
                np.abs(np.array(mu_list) - np.array(tmp_mu_list)))
            rolling_sigmas[slot] = np.sum(
                np.abs(np.array(sigma_list) - np.array(tmp_sigma_list)))
            if sum(rolling_pis + rolling_mus + rolling_sigmas) < tolerance:
                converged = True

        pi_list = tmp_pi_list
        mu_list = tmp_mu_list
        sigma_list = tmp_sigma_list

    return pi_list, mu_list, sigma_list
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
On line 3 of the example script, `from utils import get_gmm_modes` should read `from gmm_modes import get_gmm_modes`.