Skip to content

Instantly share code, notes, and snippets.

@chornbaker
Last active December 14, 2016 21:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save chornbaker/b83774b243b66e9fd0ff43a45253df61 to your computer and use it in GitHub Desktop.
Save chornbaker/b83774b243b66e9fd0ff43a45253df61 to your computer and use it in GitHub Desktop.
Matplotlib Scatterplot in D3
license: mit

Matplotlib Scatterplot in D3

This is a useful example for converting a matplotlib scatterplot to D3 with an initial animation. Using this template, you can quickly generate a dynamic scatterplot using all of your favorite matplotlib styles, and add more advanced interactions without dealing with a lot of additional formatting.

See scatterplot.py for code used to generate scatterplot.svg. It generates a simple scatterplot with some basic formatting. Matplotlib's savefig function will automatically output an svg format if the filepath has a .svg extension. The key is to assign a unique gid to each dot using the gid parameter, so it is simple to find the dots when you read the svg in index.html.

Assigning a gid to individual matplotlib elements is fairly simple for lines and bars, but a bit more complicated for scatterplots, since the dots are stored in a single PathCollection object. As a workaround, I write the svg to a StringIO buffer and use BeautifulSoup to find each dot and add an id to it before saving the output file.

This is Part 3 in a series of examples for using matplotlib generated plots in D3.

<!DOCTYPE html>
<head>
<meta charset="utf-8">
<!-- CDN resource versions -->
<script src="https://d3js.org/d3.v4.min.js"></script>
<script src="https://d3js.org/d3-ease.v1.min.js"></script>
<script src="https://d3js.org/d3-queue.v3.min.js"></script>
<script src="https://d3js.org/d3-selection.v1.min.js"></script>
<script src="https://d3js.org/d3-transition.v1.min.js"></script>
<style>
body { margin:0;position:fixed;top:0;right:0;bottom:0;left:0; }
</style>
</head>
<body>
<!-- chart SVGs-->
<div id="chart"></div>
<script type="text/javascript">
d3.queue()
.defer(d3.xml, "scatterplot.svg")
.await(ready);
function ready(error, xml) {
if (error) throw error;
// Load SVG into chart
d3.select("#chart").node().appendChild(xml.documentElement);
///////////////////////////////////////////////////////////////////////////
/////////////////////////// Helper Functions //////////////////////////////
///////////////////////////////////////////////////////////////////////////
// Lock out user interactions for `time` milliseconds so that running
// animations can finish. Raises the shared isTransitioning flag right away
// and schedules a timer that lowers it again once `time` has elapsed.
var disableUserInterractions = function (time) {
  function unlock() {
    isTransitioning = false;
  }
  isTransitioning = true;
  setTimeout(unlock, time);
};//disableUserInterractions
// Convert paths into rectangles
/**
 * Derive a rectangle from the "d" attribute of a rectangular SVG path.
 *
 * The numeric tokens of the path data are read as consecutive x/y pairs and
 * only the first three vertices are used. Separators between tokens may be
 * spaces, newlines or commas, so the result does not depend on how the SVG
 * serializer (or the XML parser's attribute normalization) laid out the data.
 *
 * NOTE(review): assumes absolute coordinates and the vertex order
 * (x0, y0) -> (x1, y0) -> (x1, y1), which is what the original index-based
 * arithmetic implied - confirm against the generated scatterplot.svg.
 *
 * @param {string} path - Contents of a path element's "d" attribute.
 * @returns {{x: number, y: number, w: number, h: number}} x/y of the first
 *   vertex, w = second.x - first.x, h = first.y - third.y.
 * @throws {Error} If fewer than three x/y pairs can be read from `path`.
 */
function getRectFromPath(path) {
  // Signed decimals with optional fraction/exponent; command letters are skipped.
  var numberPattern = /[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?/g;
  var tokens = String(path).match(numberPattern) || [];
  if (tokens.length < 6) {
    throw new Error("getRectFromPath: expected at least three x/y vertices in path data: " + path);
  }
  var first = {"x": parseFloat(tokens[0]), "y": parseFloat(tokens[1])};
  var second = {"x": parseFloat(tokens[2]), "y": parseFloat(tokens[3])};
  var third = {"x": parseFloat(tokens[4]), "y": parseFloat(tokens[5])};
  return {"x": first.x,
          "y": first.y,
          "w": second.x - first.x,
          // SVG y grows downwards, so the first vertex has the larger y.
          "h": first.y - third.y
         };
}//getRectFromPath
///////////////////////////////////////////////////////////////////////////
///////////////////////// Animation Elements //////////////////////////////
///////////////////////////////////////////////////////////////////////////
// Shared flag: true while an animation is running and clicks must be ignored
var isTransitioning = false;

// Prefix of the ids that mark the individual dot groups in the SVG
var label = "dot_";

// Basic plot elements
var svg = d3.select("#chart").select("svg");
var plot = svg.select("#figure_1");
var width = parseFloat(svg.style("width"));
var height = parseFloat(svg.style("height"));

// Rectangle described by the path inside #patch_2
var axesPatch = plot.select("#patch_2").select("path");
var plot_bbox = getRectFromPath(axesPatch.attr("d"));

// Select all the dots: every <g> whose id contains the label.
// The RegExp is built per call on purpose, because a shared /g regex would
// carry lastIndex state from one test() call to the next.
var dots = plot.selectAll("g").filter(function () {
  var matcher = new RegExp(label, 'g');
  return matcher.test(this.id);
});

// Remember each dot's position and fill colour, keyed by group id
var dot_locs = {};
dots.each(function () {
  var marker = d3.select(this).selectAll("use");
  dot_locs[this.id] = {
    "x": parseFloat(marker.attr("x")),
    "y": parseFloat(marker.attr("y")),
    "c": marker.style("fill")
  };
});

// Replay the animation on click, unless one is still in progress
svg.on("click", function () {
  if (isTransitioning) {
    return;
  }
  init();
});
///////////////////////////////////////////////////////////////////////////
////////////////// Initialize Graphic and Animations //////////////////////
///////////////////////////////////////////////////////////////////////////
/**
 * Play the intro animation: move every dot to the plot_bbox x/y position,
 * then transition each one back to the location stored in dot_locs.
 * User interactions are disabled for twice the transition duration.
 */
function init() {
  // Timing: each dot's start delay is (index % WAVES) * DELAY milliseconds
  var WAVES = 20;
  var DURATION = 3000;
  var DELAY = DURATION / WAVES;

  disableUserInterractions(2 * DURATION);

  // Put all dots at the origin
  var markers = dots.selectAll("use");
  markers.attr("x", plot_bbox.x);
  markers.attr("y", plot_bbox.y);

  // Animate dots
  dots.each(function (datum, index) {
    var dotId = this.id;
    var startDelay = (index % WAVES) * DELAY;
    var marker = d3.select(this).selectAll("use");

    marker.transition()
      .delay(startDelay)
      .duration(DURATION)
      .ease(d3.easeCubicOut)
      // Targets are given as functions so the lookup stays deferred.
      .attr("x", function () {
        return dot_locs[dotId].x;
      })
      .attr("y", function () {
        return dot_locs[dotId].y;
      });
  });
}//init
init()
};
</script>
</body>
# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler
from bs4 import BeautifulSoup
import io
def main():
    """Draw a random scatterplot and save it as 'scatterplot.svg'.

    The figure is rendered to an in-memory SVG first so that every child
    group of the scatter collection (gid "dots") can be given its own id
    ("dot_0", "dot_1", ...) before the file is written.
    """
    # Plot parameters
    n = 200
    lims = [0, 5]
    offset = 0.2
    low, high = lims[0] + offset, lims[1] - offset

    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(111)

    # Random dot positions, kept `offset` away from the axis limits
    X = np.random.uniform(low, high, n)
    Y = np.random.uniform(low, high, n)

    # One colour per dot, sampled from the 'winter' colormap
    colors = cycler('color', [plt.cm.winter(i) for i in np.linspace(0, 1, n)])
    color_list = [style['color'] for _, style in zip(range(n), colors())]

    ax.scatter(X, Y, alpha=0.8, gid="dots", c=color_list)
    ax.set_xlim(lims)
    ax.set_ylim(lims)

    # Ticks only on the left and bottom spines; hide the right and top spines
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)

    # Render the figure into an in-memory text buffer
    imgdata = io.StringIO()
    plt.savefig(imgdata, bbox_inches='tight', transparent=True, format="svg")
    svg_data = imgdata.getvalue()

    # Locate the 'dots' group and number the <g> elements found under it
    soup = BeautifulSoup(svg_data, 'xml')
    element = soup.find("g", {"id": "dots"})
    for index, dot_group in enumerate(element.findChildren("g")):
        dot_group["id"] = "dot_" + str(index)

    # Write the modified document as UTF-8 bytes
    html = soup.prettify("utf-8")
    with open('scatterplot.svg', "wb") as out:
        out.write(html)
if __name__ == '__main__':
main()
Display the source blob
Display the rendered blob
Raw
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment