Skip to content

Instantly share code, notes, and snippets.

@widged
Forked from enjalot/README.md
Created October 14, 2015 22:53
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save widged/813ae0bd69dc4e73bb73 to your computer and use it in GitHub Desktop.
cosine similarity

Click on the various examples

Each example is a random vector in the same "space" as the sample in the top left.

Drag on individual rows

Each row represents one dimension of our vectors. You can drag them back and forth to change the value of our vector for that dimension.

Watch how the similarity bar changes

The bar below each vector is a measure of how similar it is to the sample in the top left.

What's going on?

Cosine similarity is a technique for getting the "distance" between two vectors in high-dimensional spaces. We intuitively understand distance in 2 and 3 dimensions, but what happens if we have more than that? It can be convenient to calculate a single number that tells us how similar or different two high-dimensional vectors are.

I made this explanation because I'm trying to find similar blocks, and in order to do that I need to understand cosine similarity. They say nothing helps you understand something better than trying to explain it!

<!DOCTYPE html>
<head>
<meta charset="utf-8">
<script src="https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.5/d3.min.js"></script>
<style>
body {
position: fixed;
top:0;left:0;bottom:0;right:0;
}
svg {
width: 100%;
height: 100%;
}
.selected {
fill: #d33f67;
}
.highlighted {
stroke: #d33f67;
stroke-width: 3;
}
.active {
pointer-events: none;
}
.random {
cursor: pointer;
}
.background {
fill: #efefef;
}
.background.selected {
fill: #f8f0f0;
}
.row {
cursor: ew-resize;
}
</style>
</head>
<body>
<svg></svg>
<script>
// Reference vector shown in the top-left editor; every random vector is
// compared against it.
var example = [ 3, 2, 4, 2, 0, 2 ];
//var example = [ 0, 1, 2, 3, 4, 0 ];
// All vectors share the dimensionality of the example.
var dimensions = example.length;
// Component values are integers drawn from the half-open interval [0, 5).
var range = [0, 5];
// How many random vectors are generated for comparison.
var num = 8;
var randomData = d3.range(num).map(function() {
  var vector = [];
  for (var k = 0; k < dimensions; k++) {
    vector.push(Math.floor(Math.random() * range[1]));
  }
  return vector;
});
// The random vector that starts out selected.
var selectedIndex = 3;
var selectedData = randomData[selectedIndex];
// Root SVG element that every part of the demo is drawn into.
var svg = d3.select("svg")
// Group holding the editable example vector, offset to the top-left corner.
var exg = svg.append("g")
.attr("transform", "translate(10, 15)")
// Interactive strip bound to `example`. strip() keeps a reference to the array
// it is given and writes dragged values into it, so after every change the
// similarity bars are re-measured and the read-only strips are refreshed.
var ex = new strip()
.data(example)
.onChange(function(data) {
rerender();
updateExplanation();
})
.update(exg)
// Data join: one <g class="random"> per random vector.
var randomg = svg.selectAll("g.random")
.data(randomData)
// Edge length (in px) of a cell in the random-vector strips.
var cellSize = 15;
var renter = randomg.enter().append("g").classed("random", true)
// The background rect is appended first, so everything added to the group
// afterwards paints on top of it.
renter.append("rect").classed("background", true)
.attr({
x: -20,
y: -20,
width: 200,
height: 155,
})
// Lay the groups out in a grid with 4 columns.
renter.attr("transform", function(d,i) {
var x = 29 + 238 * (i % 4);
var y = 175 + 167 * Math.floor(i / 4);
return "translate(" + [x, y] + ")"
})
// Each group gets its own interactive strip bound to its vector; editing it
// re-measures the similarity bars and refreshes the read-only strips.
.each(function(d,i) {
var s = new strip()
.data(d)
.cellWidth(cellSize)
.cellHeight(cellSize)
.onChange(function(data) {
rerender();
updateExplanation();
})
.update(d3.select(this))
})
// Clicking a group makes its vector the selected one: the styling classes are
// cleared everywhere, re-applied to the clicked group only, and the read-only
// strips are refreshed to show the new selection.
.on("click", function(d,i) {
selectedData = d;
d3.selectAll("rect.background").classed("selected", false);
d3.select(this).select("rect.background").classed("selected", true);
d3.selectAll("rect.progress").classed("selected", false);
d3.select(this).select("rect.progress").classed("selected", true);
d3.selectAll("rect.active").classed("highlighted", false);
d3.select(this).select("rect.active").classed("highlighted", true);
d3.selectAll("rect.bg").classed("highlighted", false);
d3.select(this).select("rect.bg").classed("highlighted", true);
updateExplanation();
})
// Width of the outline drawn behind each similarity bar; a filled bar reaches
// this width when the cosine similarity is 1.
var progressWidth = cellSize * range[1] * 2;
// Filled bar below each random vector; its width is the cosine similarity to
// the example vector scaled by progressWidth.
var progress = renter.append("rect").classed("progress", true)
.attr({
x: 0,
y: cellSize * example.length + 30,
height: 5,
width: function(d,i) { return cosineSimilarity(example, d) * progressWidth; },
fill: "#515151",
})
// Unfilled outline at full width, drawn at the same position as the bar.
renter.append("rect").classed("bg", true)
.attr({
x: 0,
y: cellSize * example.length + 30,
height: 5,
width: progressWidth,
fill: "none",
stroke: "#333333"
})
// Apply the selection styling to the group that is selected on load.
renter.filter(function(d,i) {
return i === selectedIndex;
})
.each(function(d,i) {
d3.select(this).select("rect.background").classed("selected", true);
d3.select(this).select("rect.progress").classed("selected", true);
d3.select(this).select("rect.active").classed("highlighted", true);
d3.select(this).select("rect.bg").classed("highlighted", true);
})
// Recompute the width of every similarity bar from the current contents of
// the example vector and of the vector bound to each bar.
function rerender() {
  progress.attr("width", function(vector) {
    return cosineSimilarity(example, vector) * progressWidth;
  });
}
// Read-only strip showing the current values of the example vector.
var explanation = svg.append("g").classed("explanation", true)
  .attr("transform", "translate(" + [422, 15] + ")");
var sexp = new strip2()
  .data(example)
  .update(explanation);
// Read-only strip showing the current values of the selected random vector.
var selected = svg.append("g").classed("selected", true)
  .attr("transform", "translate(" + [544, 15] + ")");
// The example strip is built exactly once above; constructing it a second
// time here would only re-bind the same data to the same elements.
var sel = new strip2()
  .data(selectedData)
  .update(selected);
updateExplanation();
// Refresh both read-only strips: the one mirroring the example vector and the
// one mirroring the currently selected random vector.
function updateExplanation() {
  // a lot of hardcoding going on here...
  var views = [
    { strip: sexp, values: example, target: explanation },
    { strip: sel, values: selectedData, target: selected }
  ];
  views.forEach(function(view) {
    view.strip.data(view.values).update(view.target);
  });
}
/**
 * Cosine similarity of two vectors: (u . v) / (|u| * |v|).
 *
 * Despite often being described as a "distance", this is a similarity score:
 * 1 means the vectors point in the same direction, 0 means they are
 * orthogonal. For vectors with only non-negative components the result lies
 * in [0, 1].
 *
 * @param {number[]} u - first vector; its length drives the iteration.
 * @param {number[]} v - second vector, expected to have the same length as u.
 * @returns {number} the similarity, or 0 when either vector has zero
 *   magnitude (the ratio is undefined there and would otherwise be NaN,
 *   which is not a valid width for the similarity bar).
 */
function cosineSimilarity(u, v) {
  var udotv = 0;
  var umag = 0;
  var vmag = 0;
  for (var i = 0; i < u.length; i++) {
    udotv += u[i] * v[i];
    umag += u[i] * u[i];
    vmag += v[i] * v[i];
  }
  var denominator = Math.sqrt(umag) * Math.sqrt(vmag);
  // An all-zero vector has no direction; report "no similarity" instead of NaN.
  if (denominator === 0) {
    return 0;
  }
  return udotv / denominator;
}
// Read-only rendering of a vector: one coloured square per dimension with the
// displayed number next to it.
//   data(val)  - with a truthy val, stores it and returns the strip;
//                otherwise returns the stored array.
//   update(g)  - (re)binds the stored array to rect.bar / text.number elements
//                inside the selection g and returns the strip.
function strip2() {
  var cols = 5;
  var rows = 6;
  var cellWidth = 20;
  var cellHeight = 20;
  var fontSize = 14;
  var half = Math.floor(cols / 2);
  var palette = [
    "#e41a1c",
    "#377eb8",
    "#4daf4a",
    "#984ea3",
    "#ff7f00",
    "#ffff33"
  ];
  var colors = d3.scale.ordinal()
    .domain(d3.range(rows))
    .range(palette);
  var data = [];

  // The number shown for a stored value d.
  function displayed(d) {
    return (cols - 1) - d;
  }

  // Row colour, darkened for low displayed numbers and brightened for high ones.
  function shade(d, i) {
    var base = d3.hcl(colors(i));
    var v = displayed(d);
    if (v < half) {
      return base.darker(1 - v / half);
    }
    if (v > 2) {
      return base.brighter((v - 1) / cols);
    }
    return base;
  }

  this.data = function(val) {
    if (!val) {
      return data;
    }
    data = val;
    return this;
  };

  this.update = function(g) {
    var bars = g.selectAll("rect.bar").data(data);
    bars.enter().append("rect").classed("bar", true);
    bars.attr({
      x: function() { return 0; },
      y: function(d, i) { return i * cellHeight; },
      width: cellWidth,
      height: cellHeight,
      fill: shade
    });

    var labels = g.selectAll("text.number").data(data);
    labels.enter().append("text").classed("number", true);
    labels.attr({
      x: function() { return cellWidth + 5; },
      y: function(d, i) { return fontSize + i * cellHeight; },
      "font-size": fontSize,
      "font-family": "Courier New"
    });
    labels.text(displayed);
    return this;
  };

  return this;
}
/**
 * Interactive vector editor.
 *
 * Each entry of the bound array is drawn as a horizontal row of `cols` cells.
 * Dragging a row left or right changes that entry's value; the value is
 * written straight into the bound array, after which the onChange callback
 * (if any) is invoked with the array.
 *
 * Chainable accessors (each returns the strip when given a truthy value, and
 * the current setting otherwise):
 *   data(array), cellWidth(px), cellHeight(px), onChange(callback)
 *
 * update(g) renders the editor into the d3 selection `g` and returns the
 * strip. It appends a fresh container on every call, so it is meant to be
 * called once per target group.
 */
function strip() {
  var cellWidth = 20;
  var cellHeight = 20;
  var rows = 6; // size of the colour scale's domain
  var cols = 5; // cells per row; values are clamped to 0 .. cols-1
  var half = Math.floor(cols/2);
  var colors = d3.scale.ordinal()
    .domain(d3.range(rows))
    .range([
      "#e41a1c",
      "#377eb8",
      "#4daf4a",
      "#984ea3",
      "#ff7f00",
      "#ffff33"
    ]);
  var data = [];
  var cb; // callback for onChange

  this.update = function(g) {
    // Value -> horizontal shift in cells. A row holding value v is shifted so
    // that its cell in column (cols-1) - v sits under the frame at x = 0.
    var mapping = d3.scale.linear()
      .domain([0, cols-1]) // based on max value
      .range([-(cols-1), 0]);

    // Transform placing row `d` according to its current value.
    function rowTransform(d) {
      return "translate(" + [cellWidth * mapping(data[d]), 0] + ")";
    }

    var container = g.append("g")
      .attr("transform", "translate(" + [cellWidth * cols, 0] + ")");

    // Rows are bound to indices into `data` rather than to the values, so the
    // drag handler can write back into the array.
    var rowsg = container.selectAll("g.row")
      .data(d3.range(data.length));
    rowsg.enter().append("g").classed("row", true)
      .attr("transform", rowTransform);

    var cellsg = rowsg.selectAll("rect.cell")
      .data(function(d) {
        return d3.range(cols).map(function(c) {
          return { c: c, r: d };
        });
      });
    cellsg.enter().append("rect").classed("cell", true);
    cellsg.attr({
      x: function(d) { return d.c * cellWidth; },
      y: function(d) { return d.r * cellHeight; },
      width: cellWidth,
      height: cellHeight,
      fill: function(d) {
        var color = d3.hcl(colors(d.r));
        if(d.c < half) {
          return color.darker(1 - d.c/half);
        } else if(d.c > half) {
          return color.brighter(d.c/cols);
        } else {
          return color;
        }
      }
    });

    // We convert the relative mouse position to one of the potential
    // values in our array
    var mouseScale = d3.scale.linear()
      .domain([-cellWidth * cols/2, cellWidth/2 * (cols+1)])
      .range([0, cols-1]);
    var drag = d3.behavior.drag()
      .on("drag", function(d) {
        var v = Math.floor(mouseScale(d3.event.x));
        v = Math.max(0, Math.min(v, cols-1));
        // Nothing to do while the pointer stays within the same value.
        if(data[d] === v) return;
        data[d] = v;
        rowsg.attr("transform", rowTransform);
        if(cb) cb(data);
      });
    rowsg.call(drag);

    // Frame marking the column that is "read" from every row. Its height
    // follows the number of rows actually drawn.
    container.append("rect").classed("active", true)
      .attr({
        x: 0,//cellWidth * Math.floor(cols/2),
        y: 0,
        width: cellWidth,
        height: cellHeight * data.length,
        fill: "none",
        stroke: "#000"
      });
    return this;
  };

  this.data = function(val) {
    if(val) { data = val; return this; }
    return data;
  };
  this.cellWidth = function(val) {
    if(val) { cellWidth = val; return this; }
    return cellWidth;
  };
  this.cellHeight = function(val) {
    if(val) { cellHeight = val; return this; }
    return cellHeight;
  };
  this.onChange = function(val) {
    if(val) { cb = val; return this; }
    // Getter form: hand back the registered callback.
    return cb;
  };
  return this;
}
</script>
</body>
<script>
// Google Analytics loader snippet: defines a queueing `ga` function on window,
// then inserts an async <script> for analytics.js before the first script
// element in the document.
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','//www.google-analytics.com/analytics.js','ga');
// Queue creation of the tracker and a page-view hit.
ga('create', 'UA-67666917-1', 'auto');
ga('send', 'pageview');
</script>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment