Skip to content

Instantly share code, notes, and snippets.

@arpitnarechania
Last active October 22, 2020 13:40
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 5 You must be signed in to fork a gist
  • Save arpitnarechania/dbf03d8ef7fffa446379d59db6354bac to your computer and use it in GitHub Desktop.
Save arpitnarechania/dbf03d8ef7fffa446379d59db6354bac to your computer and use it in GitHub Desktop.
Confusion Matrix
license: MIT

Pure D3 implementation of a Confusion Matrix with some computed metrics in a tabular view.

DEMO

<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Confusion Matrix</title>
<link rel="stylesheet" type="text/css" href="style.css"/>
<script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.17/d3.min.js"></script>
</head>
<body>
<div id="dataView"></div>
<div style="display:inline-block; float:left" id="container"></div>
<div style="display:inline-block; float:left" id="legend"></div>
<script src="main.js"></script>
<script>
var confusionMatrix = [
[169, 10],
[7, 46]
];
var tp = confusionMatrix[0][0];
var fn = confusionMatrix[0][1];
var fp = confusionMatrix[1][0];
var tn = confusionMatrix[1][1];
var p = tp + fn;
var n = fp + tn;
var accuracy = (tp+tn)/(p+n);
var f1 = 2*tp/(2*tp+fp+fn);
var precision = tp/(tp+fp);
var recall = tp/(tp+fn);
accuracy = Math.round(accuracy * 100) / 100
f1 = Math.round(f1 * 100) / 100
precision = Math.round(precision * 100) / 100
recall = Math.round(recall * 100) / 100
var computedData = [];
computedData.push({"F1":f1, "PRECISION":precision,"RECALL":recall,"ACCURACY":accuracy});
var labels = ['Class A', 'Class B'];
Matrix({
container : '#container',
data : confusionMatrix,
labels : labels,
start_color : '#ffffff',
end_color : '#e67e22'
});
// rendering the table
var table = tabulate(computedData, ["F1", "PRECISION","RECALL","ACCURACY"]);
</script>
</body>
var margin = {top: 50, right: 50, bottom: 100, left: 100};
function Matrix(options) {
var width = 250,
height = 250,
data = options.data,
container = options.container,
labelsData = options.labels,
startColor = options.start_color,
endColor = options.end_color;
var widthLegend = 100;
if(!data){
throw new Error('Please pass data');
}
if(!Array.isArray(data) || !data.length || !Array.isArray(data[0])){
throw new Error('It should be a 2-D array');
}
var maxValue = d3.max(data, function(layer) { return d3.max(layer, function(d) { return d; }); });
var minValue = d3.min(data, function(layer) { return d3.min(layer, function(d) { return d; }); });
var numrows = data.length;
var numcols = data[0].length;
var svg = d3.select(container).append("svg")
.attr("width", width + margin.left + margin.right)
.attr("height", height + margin.top + margin.bottom)
.append("g")
.attr("transform", "translate(" + margin.left + "," + margin.top + ")");
var background = svg.append("rect")
.style("stroke", "black")
.style("stroke-width", "2px")
.attr("width", width)
.attr("height", height);
var x = d3.scale.ordinal()
.domain(d3.range(numcols))
.rangeBands([0, width]);
var y = d3.scale.ordinal()
.domain(d3.range(numrows))
.rangeBands([0, height]);
var colorMap = d3.scale.linear()
.domain([minValue,maxValue])
.range([startColor, endColor]);
var row = svg.selectAll(".row")
.data(data)
.enter().append("g")
.attr("class", "row")
.attr("transform", function(d, i) { return "translate(0," + y(i) + ")"; });
var cell = row.selectAll(".cell")
.data(function(d) { return d; })
.enter().append("g")
.attr("class", "cell")
.attr("transform", function(d, i) { return "translate(" + x(i) + ", 0)"; });
cell.append('rect')
.attr("width", x.rangeBand())
.attr("height", y.rangeBand())
.style("stroke-width", 0);
cell.append("text")
.attr("dy", ".32em")
.attr("x", x.rangeBand() / 2)
.attr("y", y.rangeBand() / 2)
.attr("text-anchor", "middle")
.style("fill", function(d, i) { return d >= maxValue/2 ? 'white' : 'black'; })
.text(function(d, i) { return d; });
row.selectAll(".cell")
.data(function(d, i) { return data[i]; })
.style("fill", colorMap);
var labels = svg.append('g')
.attr('class', "labels");
var columnLabels = labels.selectAll(".column-label")
.data(labelsData)
.enter().append("g")
.attr("class", "column-label")
.attr("transform", function(d, i) { return "translate(" + x(i) + "," + height + ")"; });
columnLabels.append("line")
.style("stroke", "black")
.style("stroke-width", "1px")
.attr("x1", x.rangeBand() / 2)
.attr("x2", x.rangeBand() / 2)
.attr("y1", 0)
.attr("y2", 5);
columnLabels.append("text")
.attr("x", 30)
.attr("y", y.rangeBand() / 2)
.attr("dy", ".22em")
.attr("text-anchor", "end")
.attr("transform", "rotate(-60)")
.text(function(d, i) { return d; });
var rowLabels = labels.selectAll(".row-label")
.data(labelsData)
.enter().append("g")
.attr("class", "row-label")
.attr("transform", function(d, i) { return "translate(" + 0 + "," + y(i) + ")"; });
rowLabels.append("line")
.style("stroke", "black")
.style("stroke-width", "1px")
.attr("x1", 0)
.attr("x2", -5)
.attr("y1", y.rangeBand() / 2)
.attr("y2", y.rangeBand() / 2);
rowLabels.append("text")
.attr("x", -8)
.attr("y", y.rangeBand() / 2)
.attr("dy", ".32em")
.attr("text-anchor", "end")
.text(function(d, i) { return d; });
var key = d3.select("#legend")
.append("svg")
.attr("width", widthLegend)
.attr("height", height + margin.top + margin.bottom);
var legend = key
.append("defs")
.append("svg:linearGradient")
.attr("id", "gradient")
.attr("x1", "100%")
.attr("y1", "0%")
.attr("x2", "100%")
.attr("y2", "100%")
.attr("spreadMethod", "pad");
legend
.append("stop")
.attr("offset", "0%")
.attr("stop-color", endColor)
.attr("stop-opacity", 1);
legend
.append("stop")
.attr("offset", "100%")
.attr("stop-color", startColor)
.attr("stop-opacity", 1);
key.append("rect")
.attr("width", widthLegend/2-10)
.attr("height", height)
.style("fill", "url(#gradient)")
.attr("transform", "translate(0," + margin.top + ")");
var y = d3.scale.linear()
.range([height, 0])
.domain([minValue, maxValue]);
var yAxis = d3.svg.axis()
.scale(y)
.orient("right");
key.append("g")
.attr("class", "y axis")
.attr("transform", "translate(41," + margin.top + ")")
.call(yAxis)
}
// The table generation function
function tabulate(data, columns) {
var table = d3.select("#dataView").append("table")
.attr("style", "margin-left: " + margin.left +"px"),
thead = table.append("thead"),
tbody = table.append("tbody");
// append the header row
thead.append("tr")
.selectAll("th")
.data(columns)
.enter()
.append("th")
.text(function(column) { return column; });
// create a row for each object in the data
var rows = tbody.selectAll("tr")
.data(data)
.enter()
.append("tr");
// create a cell in each row for each column
var cells = rows.selectAll("td")
.data(function(row) {
return columns.map(function(column) {
return {column: column, value: row[column]};
});
})
.enter()
.append("td")
.attr("style", "font-family: Courier") // sets the font style
.html(function(d) { return d.value; });
return table;
}
.axis text {
font: 10px sans-serif;
}
.axis line, .axis path {
fill: none;
stroke: #000;
shape-rendering: crispEdges;
}
td, th, tr {
padding: 4px;
border: 1px solid black;
}
table{
border-collapse: collapse;
}
#dataView{
margin-top:50px;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment