Skip to content

Instantly share code, notes, and snippets.

@rakeshchada
Last active June 21, 2018 20:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rakeshchada/43532fc344082fc1c5d4530110817306 to your computer and use it in GitHub Desktop.
Save rakeshchada/43532fc344082fc1c5d4530110817306 to your computer and use it in GitHub Desktop.
Neural Embedding Animator

This tool lets you produce and animate points on a scatter plot. It's built with the idea of understanding transitions of embeddings produced from a neural network. Please refer to this blog post for more details.

How to use?

Animating hidden representations of data points
  1. Upload a CSV file containing the colors for your data points (one color per point, corresponding to its label in your classification task) using the "Colors CSV File" button. Note that only csv files are accepted, so if you have something like a ".txt" file, you need to change its extension to ".csv".
  2. Upload the first file containing the 2-D or 3-D initial hidden representations (weights from the final layer of the neural network) using the "1st Data CSV File" button. Its rows must be in one-to-one correspondence with the colors uploaded above.
  3. Upload the second file containing the 2-D or 3-D final hidden representations (weights from the final layer of the neural network) using the "2nd Data CSV File" button.
  4. Ignore the optional "Word List CSV" button
  5. Click the "Start Animation" button. You will now see the points animate.
Animating word embeddings
  1. Ignore the colors upload button.
  2. Upload the first file containing the initial word embeddings using the "1st Data CSV File" button.
  3. Upload the second file containing the final word embeddings using the "2nd Data CSV File" button.
  4. Upload the list of words corresponding to the embeddings using the optional "Word List CSV" button. The order of the words is important here: the first word should correspond to index 1 in your tokenizer (index 0 is the mask), the second word to index 2, and so on.
  5. Click the "Start Animation" button. You will now see the points animate.

For smoother animations and more interpretable visualizations, I typically keep the number of points below ~5000. The tool does, however, work with any number of points.

Credits to Mike Bostock (https://gist.github.com/mbostock) whose amazing examples were a great inspiration in producing this.

// Global configuration and state shared by every function in this script.
//Width and height
// Size of the SVG canvas in pixels, and the margin kept free around the plot
// (the axes are drawn inside this margin).
var w = 600;
var h = 600;
var padding = 25;
// Parsed CSV uploads (arrays of rows). Each one stays undefined until the
// corresponding file has been read by the upload handlers further down.
var colors;
var firstData;
var secondData;
var words;
// [xScale, yScale] pairs returned by getScales() for each dataset.
var scalesFirstData;
var scalesSecondData;
// Reset things
// Remove the first <svg> already present in the document, if any, before a
// fresh one is created below.
d3.select("svg")
.remove();
//Create SVG element
var svg = d3.select("#chart")
.append("svg")
.attr("width", w)
.attr("height", h);
//Define clipping path
// The clip rectangle spans x from padding to w - 2 * padding and y from
// padding to h - padding; showPlot() applies it to the group holding the
// circles through url(#chart-area).
svg.append("clipPath") //Make a new clipPath
.attr("id", "chart-area") //Assign an ID
.append("rect") //Within the clipPath, create a new rect
.attr("x", padding) //Set rect's position and size…
.attr("y", padding)
.attr("width", w - padding * 3)
.attr("height", h - padding * 2);
// Define the div for the tooltip
// It starts fully transparent; the mouseover/mouseout handlers registered in
// showPlot() fade it in and out.
var div = d3.select("body")
.append("div")
.attr("class", "tooltip")
.style("opacity", 0);
/**
 * Build the linear scales that map a dataset onto the drawable chart area.
 *
 * Only the first two columns of each row are used (x and y). Cells are
 * coerced with unary plus, so the string values produced by CSV parsing are
 * handled; cells that do not coerce to a number are ignored when computing
 * the extents instead of turning the whole domain into NaN.
 *
 * @param {Array<Array<number|string>>} dataset - Rows of coordinates.
 * @returns {Array} `[xScale, yScale]`: d3 linear scales whose domains are the
 *   min/max of each column. For an empty dataset the domains are
 *   `[Infinity, -Infinity]`.
 */
function getScales(dataset) {
  // Track the extents in one pass. Spreading every value into
  // Math.min(...values) passes one argument per data point and can throw a
  // RangeError once the dataset gets large.
  let xMin = Infinity;
  let xMax = -Infinity;
  let yMin = Infinity;
  let yMax = -Infinity;
  dataset.forEach(function (row) {
    const x = +row[0];
    const y = +row[1];
    if (x < xMin) { xMin = x; }
    if (x > xMax) { xMax = x; }
    if (y < yMin) { yMin = y; }
    if (y > yMax) { yMax = y; }
  });

  const xScale = d3.scale
    .linear()
    .domain([xMin, xMax])
    .range([padding, w - padding * 2]);

  // The y range is inverted because SVG y grows downwards. It is derived
  // from the chart height (h), not the width, so non-square charts work.
  const yScale = d3.scale
    .linear()
    .domain([yMin, yMax])
    .range([h - padding * 2, padding]);

  return [xScale, yScale];
}
/**
 * Input handler for the range slider (wired through the slider's `oninput`
 * attribute in the page): redraws the circles for the slider's current value.
 */
function sliderInput() { makeTransition(); }
/**
 * Compute the on-screen position of point `i` for the current slider value.
 *
 * The slider value divided by 10000 is used as the blend factor between the
 * point's position in the first dataset and its position in the second one,
 * each mapped through that dataset's own scales.
 *
 * @param {number} i - Row index shared by `firstData` and `secondData`.
 * @returns {Array<number>} `[x, y]` in pixel coordinates.
 */
function interpolate(i) {
  const progress = d3.select('#slider').property('value') / 10000;
  const startRow = firstData[i];
  const endRow = secondData[i];

  // Axis 0 is x, axis 1 is y; both are blended the same way.
  return [0, 1].map((axis) => {
    const from = scalesFirstData[axis](startRow[axis]);
    const to = scalesSecondData[axis](endRow[axis]);
    return d3.interpolate(from, to)(progress);
  });
}
/**
 * Move every circle to the position that matches the current slider value.
 *
 * The interpolated position is computed once per circle and used for both
 * coordinates. Calling interpolate() separately for `cx` and for `cy` would
 * double the slider lookups and scale evaluations on every animation frame.
 */
function makeTransition() {
  svg.selectAll("circle").each(function (d, i) {
    const position = interpolate(i);
    // `this` is the circle's DOM node.
    this.setAttribute("cx", position[0]);
    this.setAttribute("cy", position[1]);
  });
}
/**
 * Rewind the animation: put the slider back at 0 and redraw the circles for
 * that value.
 */
function reset() {
  const slider = document.getElementById('slider');
  slider.value = 0;
  makeTransition();
}
// On click, rewind the slider and redraw the points at their initial positions
d3.select("#reset")
  .on("click", function () {
    reset();
  });

// On click, start the animation: the slider is driven from 0 to its maximum
// over a fixed duration, redrawing the points on every timer tick.
d3.select("#start")
  .on("click", (function () {
    var DURATION_MS = 5000;  // length of one full animation run
    var SLIDER_MAX = 10000;  // matches the slider's max and the divisor in interpolate()
    // Identifier of the most recently started run. Without it, every extra
    // click on "Start" would add another timer and the concurrent timers
    // would keep overwriting each other's slider value.
    var latestRun = 0;

    return function () {
      var thisRun = ++latestRun;
      d3.timer(function (elapsed) {
        if (thisRun !== latestRun) {
          return true; // superseded by a newer run: stop this timer
        }
        document.getElementById('slider').value =
          Math.min(elapsed * (SLIDER_MAX / DURATION_MS), SLIDER_MAX);
        makeTransition();
        if (elapsed >= DURATION_MS) {
          return true; // this will stop the d3 timer.
        }
      });
    };
  })());
// Zoom behaviour shared by the plotted circles: starts untranslated at scale
// 1, allows zooming between 1x and 8x, and reports changes to zoomed().
var zoom = d3.behavior.zoom();
zoom.translate([0, 0]);
zoom.scale(1);
zoom.scaleExtent([1, 8]);
zoom.on("zoom", zoomed);
/**
 * Zoom event handler: applies the current zoom translation and scale as a
 * transform to every circle and every text element inside the SVG.
 */
function zoomed() {
  ["circle", "text"].forEach((selector) => {
    const transform = `translate(${d3.event.translate})scale(${d3.event.scale})`;
    svg.selectAll(selector).attr("transform", transform);
  });
}
/**
 * Draw the scatter plot for the first dataset: one circle per row of
 * `firstData`, positioned with `scalesFirstData`, plus the x and y axes.
 *
 * Reads the globals `firstData`, `scalesFirstData`, `colors` and `words`.
 * Any previously drawn groups are removed first, so the function can be
 * called again (for example after a colors file was uploaded) without
 * stacking duplicate circle groups and axes.
 */
function showPlot() {
  var xScale = scalesFirstData[0];
  var yScale = scalesFirstData[1];

  // Tooltip label for the point at row index i.
  // Without a word list the row index itself is shown. With a word list the
  // usage notes for this tool define row 0 as the mask entry and word k as
  // the label of row k + 1, hence the i - 1 lookup. Rows beyond the end of
  // the word list fall back to their index.
  function tooltipLabel(i) {
    if (!words) { return i; }
    if (i === 0) { return "<ignore>"; }
    var entry = words[i - 1];
    return entry == null ? i : String(entry);
  }

  // Start from a clean slate: drop the circle group and axes of an earlier call.
  svg.selectAll("g").remove();

  //Define X axis
  var xAxis = d3.svg.axis()
    .scale(xScale)
    .orient("bottom")
    .ticks(5);

  //Define Y axis
  var yAxis = d3.svg.axis()
    .scale(yScale)
    .orient("left")
    .ticks(5);

  //Create circles
  svg.append("g") //Create new g
    .attr("id", "circles") //Assign ID of 'circles'
    .attr("clip-path", "url(#chart-area)") //Add reference to clipPath
    .selectAll("circle")
    .data(firstData)
    .enter()
    .append("circle")
    .attr("cx", function (d) {
      return xScale(d[0]);
    })
    .attr("cy", function (d) {
      return yScale(d[1]);
    })
    .attr("r", 2)
    // Uploaded colors take precedence; otherwise every point is black.
    .attr("fill", function (d, i) {
      return colors ? colors[i] : 'black';
    })
    .style("opacity", function () {
      return colors ? 0.8 : 0.5;
    })
    .on("mouseover", function (d, i) {
      div.transition()
        .duration(200)
        .style("opacity", .9);
      // text() instead of html(): the label comes from an uploaded file and
      // must be shown literally ("<ignore>" would otherwise be parsed as a
      // tag and render as nothing).
      div.text(tooltipLabel(i))
        .style("left", (d3.event.pageX) + "px")
        .style("top", (d3.event.pageY - 28) + "px");
    })
    .on("mouseout", function () {
      div.transition()
        .duration(500)
        .style("opacity", 0);
    })
    .call(zoom);

  //Create X axis
  svg.append("g")
    .attr("class", "x axis")
    .attr("transform", "translate(0," + (h - padding) + ")")
    .call(xAxis);

  //Create Y axis
  svg.append("g")
    .attr("class", "y axis")
    .attr("transform", "translate(" + padding + ",0)")
    .call(yAxis);
}
/**
 * Show the slider and the start/reset buttons only once both datasets are
 * loaded; hide them otherwise. When they become visible the slider is also
 * put back at 0.
 */
function toggleSlider() {
  const bothLoaded = firstData && secondData;
  const visibility = bothLoaded ? "visible" : "hidden";

  ["slider", "start", "reset"].forEach((id) => {
    document.getElementById(id).style.visibility = visibility;
  });

  if (bothLoaded) {
    document.getElementById('slider').value = 0;
  }
}
// File upload functions
// Once the DOM is ready: hide the animation controls until data is available
// and wire the four file inputs to their CSV handlers.
$(document).ready(function () {
  toggleSlider();

  // True when a parsed CSV produced at least one row.
  var hasRows = function (rows) {
    return Array.isArray(rows) && rows.length > 0;
  };

  // Read the file selected in a file-input change event as text, parse it as
  // CSV and pass the resulting array of rows to onParsed.
  var readCsvUpload = function (evt, onParsed) {
    var file = evt.target.files[0];
    // The selection can be empty (no file chosen); reading `undefined`
    // would throw, so there is nothing to do in that case.
    if (!file) {
      return;
    }
    var reader = new FileReader();
    // Register the handler before starting the read.
    reader.onload = function (event) {
      onParsed($.csv.toArrays(event.target.result));
    };
    reader.readAsText(file);
  };

  // Colors file: store the colors and, if points are already plotted,
  // redraw them with the new colors.
  var uploadColorsFile = function (evt) {
    readCsvUpload(evt, function (rows) {
      colors = rows;
      toggleSlider();
      if (hasRows(firstData)) {
        // Remove the whole previous drawing (circle group and axes) so the
        // redraw does not leave a second set of axes behind.
        svg.selectAll("g").remove();
        showPlot();
      }
    });
  };

  // First data file: store it, compute its scales and draw the plot.
  var uploadFile = function (evt) {
    readCsvUpload(evt, function (rows) {
      // An empty file would give infinite scale domains; keep the current
      // plot instead.
      if (!hasRows(rows)) {
        return;
      }
      firstData = rows;
      scalesFirstData = getScales(firstData);
      svg.selectAll("g").remove();
      showPlot();
      toggleSlider();
    });
  };

  // Second data file: store it and compute its scales; the animation
  // controls become visible once both datasets are present.
  var uploadSecondData = function (evt) {
    readCsvUpload(evt, function (rows) {
      if (!hasRows(rows)) {
        return;
      }
      secondData = rows;
      scalesSecondData = getScales(secondData);
      toggleSlider();
    });
  };

  // Optional word list used for the tooltips.
  var uploadWordList = function (evt) {
    readCsvUpload(evt, function (rows) {
      words = rows;
    });
  };

  // Confirm browser supports HTML5 File API
  var browserSupportFileUpload = function () {
    return Boolean(window.File && window.FileReader && window.FileList && window.Blob);
  };

  // event listener for file upload
  if (browserSupportFileUpload()) {
    document.getElementById('txtFileUpload').addEventListener('change', uploadFile, false);
    document.getElementById('colorsFileUpload').addEventListener('change', uploadColorsFile, false);
    document.getElementById('secondDataUpload').addEventListener('change', uploadSecondData, false);
    document.getElementById('wordListUpload').addEventListener('change', uploadWordList, false);
  } else {
    $("#introHeader").html('The File APIs is not fully supported in this browser. Please use another browser.');
  }
});
<!DOCTYPE html>
<html lang="en">
<head>
  <!-- The charset declaration belongs inside <head> -->
  <meta charset="utf-8">
  <title>Neural Embedding Animator</title>
  <link rel="stylesheet" href="styles.css">
</head>
<body>
  <!-- Upload controls and animation buttons -->
  <div class="uielems">
    <center><h4>Neural Embedding Animator</h4>
      <table>
        <tr>
          <td>
            <!-- Keeps the "introHeader" id: the script writes its
                 "File API not supported" message into this element.
                 Ids must be unique, so the other headers use their own. -->
            <h5 id="introHeader">Colors CSV File</h5>
            <input type="file" name="File Upload" id="colorsFileUpload" accept=".csv" />
          </td>
          <td>
            <h5 id="firstDataHeader">1st Data CSV File</h5>
            <input type="file" name="File Upload" id="txtFileUpload" accept=".csv" />
          </td>
          <td>
            <h5 id="secondDataHeader">2nd Data CSV File</h5>
            <input type="file" name="File Upload" id="secondDataUpload" accept=".csv" />
          </td>
          <td>
            <h5 id="wordListHeader">Word List CSV File (optional) </h5>
            <input type="file" name="File Upload" id="wordListUpload" accept=".csv" />
          </td>
        </tr>
        <tr>
          <td>
            <br/>
          </td>
        </tr>
        <tr>
          <td>
            <p>
              <span class="btn" id="start">Start Animation</span>
            </p>
          </td>
          <td>
            <p>
              <span class="btn" id="reset">Reset</span>
            </p>
          </td>
        </tr>
      </table>
    </center>
  </div>
  <!-- Slider (animation progress, 0-10000) next to the chart container -->
  <center>
    <table>
      <tr>
        <td>
          <input type="range" id="slider" class="slider" min="0" max="10000" value="0" oninput="sliderInput()">
        </td>
        <td>
          <div id="chart"></div>
        </td>
      </tr>
    </table>
  </center>
  <!-- Libraries first, then the page script that depends on them -->
  <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.0.0-alpha1/jquery.min.js"></script>
  <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery-csv/0.71/jquery.csv-0.71.min.js"></script>
  <script src="https://d3js.org/d3.v3.min.js"></script>
  <script src="embedding-animator.js"></script>
</body>
</html>
/* Page-wide font */
body {
  font-family: "Helvetica Neue", Helvetica, sans-serif;
}

/* Axis lines: thin, crisp black strokes without fill */
.axis path,
.axis line {
  stroke: black;
  fill: none;
  shape-rendering: crispEdges;
}

/* Axis tick labels */
.axis text {
  font-size: 11px;
  font-family: sans-serif;
}

/* Start / Reset buttons: dark pill that darkens on hover and press */
.btn {
  display: inline-block;
  padding: 0.5em 1em;
  border-radius: 5px;
  color: #fff;
  background-color: rgba(0, 0, 0, 0.75);
  cursor: pointer;
}

.btn:hover {
  background-color: rgba(0, 0, 0, 0.9);
}

.btn:active {
  background-color: rgba(0, 0, 0, 1);
}

/* Range slider, rotated to run vertically next to the chart */
input.slider {
  height: 40%;
  margin: 5px;
  transform: rotate(90deg);
}

/* Container for the upload controls and buttons */
div.uielems {
  margin: 5px;
}
/* Hover tooltip positioned by the script next to the mouse pointer.
   The box keeps its former 60x20px size as a minimum but grows with its
   content, so labels longer than a few characters (e.g. words from the word
   list) no longer spill outside the background. */
div.tooltip {
  position: absolute;
  text-align: center;
  min-width: 60px;
  min-height: 20px;
  padding: 2px;
  font: 12px sans-serif;
  white-space: nowrap; /* keep a label on a single line */
  background: lightsteelblue;
  border: 0px;
  border-radius: 8px;
  pointer-events: none; /* never intercept the mouse events of the points below */
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment