Skip to content

Instantly share code, notes, and snippets.

@tmcw
Last active April 7, 2016 20:57
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tmcw/3701320 to your computer and use it in GitHub Desktop.
Save tmcw/3701320 to your computer and use it in GitHub Desktop.
<!DOCTYPE html>
<!-- Demo page: renders the one-dimensional k-means diagram into #vis with d3 v2. -->
<html>
<head>
<title>k-means in one dimension</title>
<!-- d3 v2 supplies the d3.* API (scales, selections, drag behavior) used by the diagram script. -->
<script src="https://d3js.org/d3.v2.js"></script>
<!-- Presumably defines the k-means helpers the diagram calls (sample, means_clusters,
     clusters_means, dist1d, average1d); it has to load before kmeans-diagram.js. -->
<script src="kmeans.js"></script>
<style>
body { margin: 0; padding: 0; font:normal 12px/20px sans-serif; }
#vis { width: 640px; }
/* NOTE(review): nothing in the scripts shown is given class "title"; this rule looks unused. */
rect.title {
fill:#fff;
}
/* Grey track behind each step row. */
rect.stage {
fill:#ccc;
}
/* Span covering the points of one cluster (step rows 2 and 4). */
rect.closest {
fill:#A0D1F2;
}
/* Data-point circles. */
circle.control {
fill: #d23634;
}
/* Only the points in the first row ("input data") are draggable. */
g.row-0 {
cursor:move;
}
/* Mean markers. */
circle.mean {
fill:#006EB8;
}
/* Value labels centred on the point circles. */
text.controltext {
fill:#fff;
text-anchor:middle;
}
</style>
</head>
<body>
<div id='vis'>
<!-- Loaded inside #vis so the div already exists when the script calls d3.select('#vis'). -->
<script src='kmeans-diagram.js'></script>
</div>
</body>
</html>
// Layout constants in px: drawing width and height, vertical distance
// between step rows, and the outer padding.
var w = 560,
h = 500,
row = 100,
padding = 40;
// Horizontal scale shared by every row: data values 0..20 map onto
// 0..(w - padding) px; clamp(true) keeps scaled positions inside that range.
var x = d3.scale.linear()
.range([0, w - padding])
.domain([0, 20])
.clamp(true);
var svg = d3.select('#vis').append('svg')
.attr('width', w + 2 * padding)
.attr('height', h + 2 * padding);
// Input values. Dragging a first-row point overwrites its entry here in place.
var points = [0, 2, 3, 4, 5, 10, 18, 16, 20];
// Initial k-means state for the first render (k = 4): sampled starting
// means, then two assign/recompute rounds. update() recomputes this same
// pipeline and reassigns these globals.
var means = sample(points, 4);
var clusters1 = means_clusters(points, means, dist1d, function(d) { return d; });
var means2 = clusters_means(clusters1, average1d, function(d) { return d; });
var clusters2 = means_clusters(points, means2, dist1d, function(d) { return d; });
var means3 = clusters_means(clusters2, average1d, function(d) { return d; });
// One <g class="step"> per stage of the algorithm, captioned with these
// labels and stacked `row` px apart inside the padding.
var steps = svg.selectAll('g')
.data(['input data',
'means',
'all points closest to mean choices',
'new means',
'all points closest to new mean',
'the new means'])
.enter()
.append('g')
.attr('class', 'step')
.attr('transform', function(d, i) {
return 'translate(' + padding + ',' + (padding + (i * row)) + ')';
});
// Grey track for each row; its height is row - 2 * padding (20px with the
// constants above), shifted up by 10px so it is centred on the row's y = 0.
steps.append('rect')
.attr('class', 'stage')
.attr('width', w - padding)
.attr('height', row - padding * 2)
.attr('y', -10);
// Stage caption, drawn above the track.
steps.append('text')
.text(function(d) { return d; })
.attr('dy', -15);
// Data function for the cluster spans: returns the clusters drawn in step
// row `i`. d3 calls it once per step group with (step label, step index);
// only the index is used, the other parameters are kept for that calling
// convention. Row 2 shows the clusters around the sampled means, row 4 the
// clusters around the first recomputed means; every other row shows none.
function closest_d(d, i, v) {
  switch (i) {
    case 2:
      return clusters1;
    case 4:
      return clusters2;
    default:
      return [];
  }
}
// Cluster spans: one <g class="closest"> holding a 20px-high <rect> per
// cluster, in the step rows for which closest_d returns clusters. Their x
// position and width are set by update().
var c = steps.selectAll('g.closest')
  .data(closest_d)
  .enter()
  .append('g').attr('class', 'closest');
var closest_areas = c.append('rect')
  .attr('class', 'closest')
  .attr('height', 20)
  .attr('y', -10);
// One group per input point in every step row. d3 v2 passes the enclosing
// step's index as the third argument, so each group is tagged "row-<step>".
// The selector uses the same "point" class the groups receive, so a later
// re-join against it actually finds them.
var p = steps.selectAll('g.point')
  .data(points)
  .enter()
  .append('g').attr('class', function(d, i, v) {
    return 'point row-' + v;
  });
p.append('circle')
  .attr('r', 10)
  .attr('class', function(d, i, v) {
    return 'control row-' + v;
  });
// Value labels; their text is filled in by update().
var controltext = p.append('text')
  .attr('class', 'controltext')
  .attr('dx', '0px')
  .attr('dy', '4px');
// Only first-row points are draggable: the pointer's x is converted back to
// a data value, written into `points`, and the whole diagram is redrawn.
steps.selectAll('g.row-0')
  .call(d3.behavior.drag()
    .on("drag", function(d, i) {
      points[i] = x.invert(d3.event.x);
      update();
    })
  );
// Data function for the mean markers: returns the means drawn in step row
// `i`. d3 calls it once per step group with (step label, step index); only
// the index is used. Rows 1-2 show the sampled means, rows 3-4 the first
// recomputed means, row 5 the second recomputed means; row 0 shows none.
function means_d(d, i, v) {
  switch (i) {
    case 1:
    case 2:
      return means;
    case 3:
    case 4:
      return means2;
    case 5:
      return means3;
    default:
      return [];
  }
}
// Mean markers: one <g class="mean"> with a 5px circle per mean in each step
// row (means_d decides which means a row shows); update() positions them.
// The selector matches the "mean" class the groups are given (the previous
// 'g.means' could never match), so a re-join would find existing markers.
var m = steps.selectAll('g.mean')
  .data(means_d)
  .enter()
  .append('g').attr('class', 'mean');
m.append('circle')
  .attr('r', 5)
  .attr('class', 'mean');
// Recompute the k-means pipeline from the current `points` and push the
// results into every row of the diagram. Runs once at load (below) and again
// on each drag of a first-row point.
function update() {
  // Same two assign/recompute rounds as the initial render. The starting
  // means are re-sampled at random on every call, so the lower rows can
  // change between drag events.
  means = sample(points, 4);
  clusters1 = means_clusters(points, means, dist1d, function(d) { return d; });
  means2 = clusters_means(clusters1, average1d, function(d) { return d; });
  clusters2 = means_clusters(points, means2, dist1d, function(d) { return d; });
  means3 = clusters_means(clusters2, average1d, function(d) { return d; });

  // Rebind the cluster spans of each step. The rects were created once, at
  // load; when a row now has fewer clusters, the rects left without data are
  // hidden instead of being redrawn from their stale data.
  var areas = closest_areas.data(closest_d);
  areas.exit().attr('display', 'none');
  areas.attr('display', null)
    .transition()
    .attr('x', function(d) {
      return x(d3.min(d, function(v) { return v; }));
    })
    .attr('width', function(d) {
      // The scale maps 0 to 0, so scaling the span length gives its width.
      return x(d3.max(d, function(v) { return v; }) -
        d3.min(d, function(v) { return v; }));
    });

  // Move every point group (all rows) to its current value.
  p.data(points);
  p.attr('transform', function(d) {
    return 'translate(' + x(d) + ', 0)';
  });
  // Re-select the labels through `p`: select() hands each label its parent's
  // new datum, whereas the `controltext` selection still carries the values
  // bound at load and would never show a dragged value. Dragged values are
  // fractional, so they are rounded to fit inside the circle.
  p.select('text.controltext')
    .text(function(d) { return Math.round(d); });

  // Mean markers get the same hide-when-unbound treatment as the spans.
  var meanMarks = m.data(means_d);
  meanMarks.exit().attr('display', 'none');
  meanMarks.attr('display', null)
    .transition()
    .attr('transform', function(d) {
      return 'translate(' + x(d) + ', 15)';
    });
}
update();
// Draw `m` elements of `list` without replacement using Floyd's sampling
// algorithm over positions: every m-element subset of positions is equally
// likely (the order inside the result is not a uniform shuffle). Positions,
// not values, are tracked, so lists with repeated values are handled.
//
// list - array to sample from
// m    - number of elements to draw, 0 <= m <= list.length
// Returns an array of m elements of `list`, or undefined (after logging a
// warning) when m is larger than list.length.
function sample(list, m) {
  var n = list.length;
  if (m > n) {
    // Log first, then return undefined. A single `void console && ...`
    // expression never logs, because `void` binds tighter than `&&`.
    if (typeof console !== 'undefined') {
      console.log('list length must be >= sample size');
    }
    return undefined;
  }
  var picked = [];
  for (var j = n - m; j < n; j++) {
    // Floyd: choose uniformly from positions 0..j inclusive; on a collision
    // take position j itself, which cannot have been chosen yet.
    var t = Math.floor(Math.random() * (j + 1));
    picked.push(picked.indexOf(t) === -1 ? t : j);
  }
  return picked.map(function(k) { return list[k]; });
}
// Distance between two numbers on the real line: the magnitude of their
// difference.
function dist1d(a, b) {
  var gap = b - a;
  return Math.abs(gap);
}
// Arithmetic mean of `n` after mapping each element through `val`.
// Elements are summed front to back; an empty list gives 0 / 0, i.e. NaN.
function average1d(n, val) {
  var total = 0;
  var k = 0;
  while (k < n.length) {
    total += val(n[k]);
    k += 1;
  }
  return total / n.length;
}
// Euclidean distance between two coordinate arrays. The length of `a`
// decides how many coordinates are compared.
function dist(a, b) {
  var squared = 0;
  var k = 0;
  while (k < a.length) {
    var delta = a[k] - b[k];
    squared += Math.pow(delta, 2);
    k += 1;
  }
  return Math.sqrt(squared);
}
// Default accessor: hands back its argument unchanged.
function identity(value) {
  return value;
}
// k-means assignment step: put every value of `x` into the group of its
// nearest mean.
//
// x        - array of values to cluster
// means    - array of current means, in the same representation as the values
// distance - function(a, b) -> number; defaults to dist1d. It is expected to
//            return ordinary numbers (comparisons use <, which is false for NaN).
// val      - accessor applied to values and means before measuring;
//            defaults to identity
//
// Returns the non-empty groups ordered by the index of their mean; a mean
// that attracts no values produces no group. Ties go to the earlier mean.
// An empty `means` array yields [].
function means_clusters(x, means, distance, val) {
  if (!val) val = identity;
  if (!distance) distance = dist1d;
  if (means.length === 0) return [];
  // One bucket per mean, indexed like `means`. An array keeps the order
  // explicit and, unlike for...in over a plain object, never picks up
  // enumerable properties inherited from Object.prototype.
  var buckets = [];
  for (var j = 0; j < means.length; j++) buckets.push([]);
  for (var i = 0; i < x.length; i++) {
    var v = val(x[i]);
    var best = 0;
    var bestDist = distance(v, val(means[0]));
    for (var k = 1; k < means.length; k++) {
      var d = distance(v, val(means[k]));
      // Strict < keeps the first of several equally close means.
      if (d < bestDist) {
        best = k;
        bestDist = d;
      }
    }
    buckets[best].push(x[i]);
  }
  var out = [];
  for (var b = 0; b < buckets.length; b++) {
    if (buckets[b].length > 0) out.push(buckets[b]);
  }
  return out;
}
// k-means update step: the new mean of each cluster.
//
// clusters - array of groups of values, e.g. as returned by means_clusters
// average  - function(group, val) -> mean; defaults to average1d
// val      - accessor handed to `average`; defaults to identity
//
// Returns one mean per cluster, in cluster order.
function clusters_means(clusters, average, val) {
  if (!average) average = average1d;
  if (!val) val = identity;
  var newmeans = [];
  // `var i` keeps the counter local; without it the loop wrote a global `i`
  // (and threw a ReferenceError in strict-mode code).
  for (var i = 0; i < clusters.length; i++) {
    newmeans.push(average(clusters[i], val));
  }
  return newmeans;
}
// Lloyd's k-means: start from `n` means sampled from `x`, then alternate the
// assignment step (means_clusters) and the update step (clusters_means)
// until the means stop moving or `maxIterations` rounds have run.
//
// x             - array of values to cluster
// n             - number of starting means (clusters may end up fewer, since
//                 empty clusters are dropped)
// distance      - function(a, b) -> number; defaults to dist1d
// average       - function(group, val) -> mean; defaults to average1d.
//                 `distance` and `average` must agree on the value
//                 representation (e.g. numbers with dist1d/average1d).
// maxIterations - optional cap on update rounds, defaults to 100
//
// Returns { means: [...], clusters: [...] }, with clusters assigned to the
// returned means, or undefined when n exceeds x.length (sample logs why).
function kmeans(x, n, distance, average, maxIterations) {
  if (!distance) distance = dist1d;
  if (!average) average = average1d;
  if (maxIterations === undefined) maxIterations = 100;

  // True when both mean lists are the same length and every mean sits at
  // distance 0 from its counterpart.
  function unchanged(a, b) {
    if (a.length !== b.length) return false;
    for (var k = 0; k < a.length; k++) {
      if (distance(a[k], b[k]) !== 0) return false;
    }
    return true;
  }

  var means = sample(x, n);
  if (!means) return undefined;
  var clusters = means_clusters(x, means, distance);
  for (var iter = 0; iter < maxIterations; iter++) {
    var next = clusters_means(clusters, average);
    if (unchanged(next, means)) break;
    means = next;
    clusters = means_clusters(x, means, distance);
  }
  return { means: means, clusters: clusters };
}
if (typeof module !== 'undefined') {
module.exports = {
sample: sample,
dist1d: dist1d,
dist: dist,
means_clusters: means_clusters,
clusters_means: clusters_means,
kmeans: kmeans
};
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment