Skip to content

Instantly share code, notes, and snippets.

@avimoondra
Last active August 27, 2020 23:32
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save avimoondra/547525daf20ddfa918ee to your computer and use it in GitHub Desktop.
Save avimoondra/547525daf20ddfa918ee to your computer and use it in GitHub Desktop.
Probability Tree
<!DOCTYPE html>
<html lang="en">
<head>
<link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.2.0/css/bootstrap.min.css">
<script src="https://d3js.org/d3.v3.min.js"></script>
<meta name="viewport" content="width=device-width, initial-scale=1">
<meta charset="utf-8">
</head>
<body>
<style type="text/css">
/* Layout: the SVG and the instruction panel sit side by side. */
#svgWrapper {
  display: inline-block;
}
#panelWrapper {
  width: 250px;
  border: solid #808080 1px;
  padding: 10px;
  display: inline-block;
  position: fixed;
  margin-top: 10px;
}

/* Tree branches: first-level links are red, second-level links are gray. */
.link.event1 {
  stroke: #F05133;
}
.link.event2 {
  stroke: #808080;
}
.link {
  fill: none;
}

/* Everything tagged with a step class (and the Bayes formula group) starts
   effectively invisible; the button handlers transition opacity to 1. */
.step1, .step2, .step3, .mathg {
  opacity: 1e-6;
}
.question {
  font-size: 16px;
  font-weight: bold;
}

/* Fraction bars in the formula. These are SVG <line> elements, so the
   thickness is controlled by stroke-width (the CSS `width` property does not
   apply to them). */
.math-line {
  stroke: black;
  stroke-width: 1px;
}
.mathg text {
  text-anchor: middle;
}

/* Any circle whose class attribute ends in "border". */
circle[class$="border"] {
  fill: none;
  stroke-width: 1;
}

/* Removed by the "Compute" (Bayes) button handler. */
.hidden {
  opacity: 1e-6;
}
</style>
<div class="row">
<div id="svgWrapper">
<svg>
<g class="mathg" transform="translate(60,500)">
<text x="10" y="0"> P( cancer | +) = </text>
<g transform="translate(160,-10)">
<text> P(cancer & +) </text>
<line class="math-line" x1="-100" y1="6" x2="100" y2="6"></line>
<text y="20"> P(cancer & +) + P(no cancer & +) </text>
</g>
<text x="275"> = </text>
<g transform="translate(460,-10)">
<g id="flippedPr" class="bayesVis">
<circle cy="-50" class="circle-0"></circle>
<circle cy="-50" class="circle-0-border"></circle>
</g>
<g id="LawOfTotalPr" class="bayesVis" transform="translate(0,60)">
<circle cx="-60" class="circle-0"></circle>
<circle cx="-60" class="circle-0-border"></circle>
<circle cx="60" class="circle-2"></circle>
<circle cx="60" class="circle-2-border"></circle>
<text>+</text>
</g>
<line class="math-line" x1="-165" y1="6" x2="165" y2="6"></line>
</g>
<g transform="translate(655,0)">
<text>≈ <tspan id="bayesCalc"></tspan></text>
</g>
</g>
</svg>
</div>
<div id="panelWrapper">
<div>
<p><b>Step 1.</b> In Canada, about 0.35% of women over 40 will be diagnosed with breast cancer in any given year. <button id="animateTree-I" type="button" class="btn btn-primary btn-xs"> Draw tree </button> </p>
<p><b>Step 2.</b> A common screening test for cancer is the mammogram, but this test is not perfect. In about 11% of patients with breast cancer, the test gives a false negative: it indicates a woman does not have breast cancer when she does have breast cancer. Similarly, the test gives a false positive in 7% of patients who do not have breast cancer: it indicates these patients have breast cancer when they actually do not. <button id="animateTree-II" type="button" class="btn btn-primary btn-xs"> Draw subtrees </button> </p>
<p><b>Step 3.</b> Compute joint probabilities based on cancer status and mammogram test results <button id="computeJointProbs" type="button" class="btn btn-primary btn-xs"> Compute </button> </p>
<p><b>Step 4.</b> Compute P( cancer | + ) using Bayes' Theorem <button id="computeBayesProb" type="button" class="btn btn-primary btn-xs"> Compute </button></p>
</div>
</div>
</div>
<script>
// ---- Canvas geometry --------------------------------------------------
// Drawing happens in an 800x800 viewBox coordinate system; the <svg>
// element itself is rendered at 650x650.
var svgViewBoxWidth = 800;
var svgViewBoxHeight = 800;
var svgWidth = 650;
var svgHeight = svgWidth;

// Inner drawing area, expressed in viewBox units.
var margin = {top: 40, right: 100, bottom: 40, left: 100};
var width = svgViewBoxWidth - margin.left - margin.right;
var height = svgViewBoxHeight - margin.top - margin.bottom;

var svg = d3.select("#svgWrapper").select("svg");
svg.attr("width", svgWidth);
svg.attr("height", svgHeight);
svg.attr("viewBox", [0, 0, svgViewBoxWidth, svgViewBoxHeight].join(" "));

// ---- Palette ----------------------------------------------------------
var red = "#F05133";
var gray = "#808080";
var grayRedScale = d3.interpolateRgb(gray, red);
var fillColor = "#FFFFFF";
var strokeColor = "#808080";

// ---- Tree sizing ------------------------------------------------------
// trunkWidth is the on-screen thickness that represents probability 1;
// maxCircleArea is the circle area that represents probability 1.
var trunkWidth = 80;
var maxCircleArea = Math.pow(trunkWidth, 2);
var graphWidth = width - 100;
var graphHeight = height - 400;

// ---- Shared state -----------------------------------------------------
// Both are replaced with the contents of probTree.json once it has loaded.
var TREEDATA = {};
var METADATA = {};

// Event bus used to push freshly computed joint probabilities to listeners.
var dispatch = d3.dispatch("statechange");
/**
 * Radius of the circle whose area is `probability * maxCircleArea`,
 * so that circle *area* (not radius) is proportional to the probability.
 *
 * @param {number} probability
 * @returns {number} radius in viewBox units
 */
var computeRadius = function(probability){
  var area = probability * maxCircleArea;
  return Math.sqrt(area / Math.PI);
};
/**
 * Keeps the Bayes' theorem panel in sync with the tree.
 *
 * `joints` is the array produced by computeJoints. Entries 0 and 2 are the
 * two joints whose second-level outcome is the first child of each branch
 * ("cancer & +" and "no cancer & +" in the shipped data); they are the only
 * terms that appear in the formula.
 */
dispatch.on("statechange.joints", function(joints){
  [0, 2].forEach(function(index){
    // Declared locally: the original leaked `radius` into the global scope.
    var radius = computeRadius(joints[index].jointPr);
    d3.selectAll("g.bayesVis>.circle-" + index)
      .attr("r", radius)
      .style("fill", fillColor)
      .style("stroke", strokeColor);
  });

  // P(A | B) = P(A & B) / (P(A & B) + P(not A & B))
  var posterior = joints[0].jointPr / (joints[0].jointPr + joints[2].jointPr);
  d3.select("#bayesCalc")
    .text(formatDecimal(posterior, 3));
});
/**
 * Rounds `num` to `numDecimals` decimal places and returns it as a string
 * padded to exactly that many decimals (e.g. 1 -> "1.0000").
 *
 * @param {number} num value to format
 * @param {number} numDecimals number of digits after the decimal point
 * @returns {string}
 */
var formatDecimal = function(num, numDecimals){
  // `scale` is a local now; the original assignment created a global.
  var scale = Math.pow(10, numDecimals);
  return (Math.round(num * scale) / scale).toFixed(numDecimals);
};
/**
 * Builds the display names of the four joint events of a two-level binary
 * tree, in the order [c0&g0, c0&g1, c1&g0, c1&g1] (c = first-level child,
 * g = its child). The order matches computeJoints.
 *
 * @param {Object} data tree root; only the first two children of the root
 *   and the first two children of each of those are read
 * @returns {string[]} e.g. ["cancer & +", "cancer & -", ...]
 */
var computeJointsNames = function(data){
  // All working variables are locals; the original leaked them as globals.
  var jointsNames = [];
  for (var i = 0; i < 2; i++) {
    var node = data.children[i];
    for (var j = 0; j < 2; j++) {
      jointsNames.push(node.eventOutcome + " & " + node.children[j].eventOutcome);
    }
  }
  return jointsNames;
};
/**
 * Computes the four joint probabilities of a two-level binary tree, in the
 * order [c0&g0, c0&g1, c1&g0, c1&g1] (c = first-level child, g = its child).
 * The order matches computeJointsNames.
 *
 * @param {Object} data tree root; only the first two children of the root
 *   and the first two children of each of those are read
 * @returns {Array<{pr: number, conditionalPr: number, jointPr: number}>}
 *   `pr` is the first-level probability, `conditionalPr` the second-level
 *   probability, and `jointPr` their product.
 */
var computeJoints = function(data){
  // All working variables are locals; the original leaked them as globals.
  var joints = [];
  for (var i = 0; i < 2; i++) {
    var node = data.children[i];
    for (var j = 0; j < 2; j++) {
      var child = node.children[j];
      joints.push({
        "pr": node.probability,
        "conditionalPr": child.probability,
        "jointPr": node.probability * child.probability
      });
    }
  }
  return joints;
};
/**
 * Deep-clones every element matched by `selector`, inserting each clone
 * directly after its original, and returns the selection with the entries
 * of its first group replaced by the clones.
 *
 * @param {string} selector CSS selector passed to d3.selectAll
 * @returns {Object} the d3 selection, now pointing at the clones
 */
var cloneAll = function(selector) {
  var matched = d3.selectAll(selector);
  matched.each(function(datum, index) {
    var copy = this.cloneNode(true);
    var inserted = this.parentNode.insertBefore(copy, this.nextSibling);
    matched[0][index] = inserted;
  });
  return matched;
};
// Converts a height measured along the trunk (0..trunkWidth) into a
// probability (0..1).
var trunkToProbScale = d3.scale.linear();
trunkToProbScale.domain([0, trunkWidth]);
trunkToProbScale.range([0, 1]);
/**
 * Drag behavior attached to every ".node.event" group.
 *
 * Dragging vertically inside the node's trunk (a band of `trunkWidth`
 * centred on the node) splits that band between the node's two children:
 * the part above the pointer becomes children[0]'s probability and the
 * remainder children[1]'s. The tree is then re-rendered.
 *
 * Pointer positions outside the band are ignored. The original kept the
 * rectangle heights in that case but still overwrote the probabilities with
 * 0 and 1, so the drawing and the data went out of sync.
 */
var drag = d3.behavior.drag()
  .on("drag", function(){
    var g = d3.select(this);
    var currNodeData = g.datum();

    // The node's layout x is its vertical position on screen (the tree is
    // drawn with x and y swapped).
    var mid = currNodeData.x;
    var upper = mid + trunkWidth / 2;
    var lower = mid - trunkWidth / 2;
    var pointerY = d3.event.y;

    // Written as a negated "in range" test so a NaN coordinate is ignored too.
    if (!(pointerY >= lower && pointerY <= upper)) {
      return;
    }

    var topHeight = pointerY - lower;
    var bottomHeight = upper - pointerY;
    g.selectAll("rect")
      .attr("height", function(d){
        return d.position === "top" ? topHeight : bottomHeight;
      });

    var topProb = trunkToProbScale(topHeight);
    currNodeData.children[0].probability = topProb;
    currNodeData.children[1].probability = 1 - topProb;

    d3.select(".gWrapper").call(graph, TREEDATA);
  });
/**
 * Factory for a probability-tree renderer, following d3's reusable-chart
 * convention.
 *
 * The returned function `probTree(g, data)` draws (or updates) the tree
 * described by `data` inside the selection `g`, and is meant to be used as
 * `selection.call(probTree, data)`. It exposes chainable getter/setters
 * `width`, `height` and `trunkWidth`.
 *
 * Compared with the original:
 *  - joints and their names are computed from the `data` argument rather
 *    than from the global TREEDATA (callers in this file pass TREEDATA, so
 *    the result is unchanged);
 *  - working variables are declared locally instead of leaking as globals;
 *  - the unused `duration` and `node` locals, and the commented-out border
 *    code, are removed.
 */
d3.probTree = function(){
  var width = 1,
      height = 1,
      trunkWidth = 1;

  var probTree = function(g, data){
    // ---- layout -------------------------------------------------------
    var tree = d3.layout.tree()
      .size([height, width])
      .separation(function(){ return 1; });

    // ---- links --------------------------------------------------------
    // Gap left between a node's position and the ends of its links.
    var trunkYOffset = 12;

    var diagonal = d3.svg.diagonal()
      .source(function(d){
        // Shift the start of each branch away from the parent's centre so
        // the two branches (whose thickness is proportional to their
        // probability) stack inside the trunk instead of overlapping.
        var offset = trunkWidth/2 - (trunkWidth * d.target.probability)/2;
        if (d.target.position === "top") {
          return {x: d.source.x - offset, y: d.source.y + trunkYOffset};
        }
        return {x: d.source.x + offset, y: d.source.y + trunkYOffset};
      })
      .target(function(d){
        return {x: d.target.x, y: d.target.y - trunkYOffset};
      })
      // Swap x and y so the tree grows left-to-right.
      .projection(function(d){
        return [d.y, d.x];
      });

    var nodes = tree.nodes(data),
        links = tree.links(nodes);

    var link = g.selectAll("path.link")
      .data(links);

    link.enter().append("path")
      .attr("class", function(d){
        return d.target.linkClass; })
      .attr("d", diagonal)
      .style("stroke-width", function(d){
        return d.target.probability*trunkWidth; });

    link
      .attr("d", diagonal)
      // A dash as long as the path itself; animateTree reveals the path by
      // transitioning stroke-dashoffset from that length to 0.
      .attr("stroke-dasharray", function(){
        var pathLength = this.getTotalLength();
        return pathLength + " " + pathLength; })
      .style("stroke-width", function(d){
        return d.target.probability*trunkWidth; });

    // ---- nodes --------------------------------------------------------
    // Node groups are only created on first render; they are never moved.
    g.selectAll(".node")
      .data(nodes)
      .enter().append("g")
      .attr("class", function(d){
        return d.nodeClass; })
      .attr("transform", function(d){
        return "translate(" + d.y + "," + d.x + ")"; });

    // Event nodes: a draggable pair of rectangles (one per child) plus a
    // probability label for each child.
    g.selectAll(".node.event").each(function(){
      var selection = d3.select(this).call(drag);
      var children = selection.datum().children;

      selection.selectAll("rect")
        .data(children)
        .enter().append("rect")
        .attr("width", 20)
        .attr("height", function(d){
          return d.probability * trunkWidth; })
        .attr("x", -10)
        .attr("y", -trunkWidth/2)
        // The bottom rectangle is mirrored so it grows upward from the
        // bottom edge of the trunk.
        .attr("transform", function(d){
          return d.position === "bot" ? "scale(1,-1)" : null; })
        .style("stroke", "black")
        .style("fill", "white");

      selection.selectAll("text")
        .data(children)
        .enter().append("text")
        .attr("text-anchor", "middle")
        .attr("y", function(d){
          return d.position === "top" ? -(trunkWidth/2 + 10) : trunkWidth/2 + 10; })
        .attr("dy", ".35em")
        .text(function(d){
          return d.name === "joint"
            ? "P( " + d.eventOutcome + " | " + d.parent.eventOutcome + " ) = "
            : "P( " + d.eventOutcome + " ) = "; })
        .append("tspan")
        .attr("class", "eventText");

      // The numeric part lives in its own tspan so it can be refreshed on
      // every render (e.g. after a drag).
      selection.selectAll("tspan.eventText")
        .transition()
        .text(function(d){
          return formatDecimal(d.probability, 4); });
    });

    // ---- joint nodes --------------------------------------------------
    var joints = computeJoints(data);

    // Each joint node gets its own index as datum; that index is used to
    // look up the matching entry in `joints` / `jointsNames`.
    var circle = g.selectAll(".node.joint").selectAll("circle")
      .data(function(d, i){
        return [i]; });

    circle.enter().append("circle")
      .attr("class", function(d){ return "tree-circle-" + d; });

    circle
      .attr("r", function(d){
        return computeRadius(joints[d].jointPr); })
      .style("fill", fillColor)
      .style("stroke", strokeColor)
      .style("opacity", 0.8);

    // TODO: border circles for the joints (cloned via cloneAll and tagged
    // "circle-border hidden") are not created yet.

    // Let listeners (the Bayes panel) know the joints changed.
    dispatch.statechange(joints);

    var jointsNames = computeJointsNames(data);

    var text = g.selectAll(".node.joint").selectAll("text")
      .data(function(d, i){
        return [i]; });

    text.enter().append("text")
      .attr("class", "jointText")
      .attr("x", -30) //TODO MAGIC NUMBER
      .attr("y", 20)
      .attr("text-anchor", "start");

    text.text(function(d){
      return "P( " + jointsNames[d] + " ) = " + formatDecimal(joints[d].jointPr, 4); });
  };

  // Getter/setter: extent of the tree along the horizontal axis.
  probTree.width = function(x) {
    if (!arguments.length) return width;
    width = x;
    return probTree;
  };

  // Getter/setter: extent of the tree along the vertical axis.
  probTree.height = function(x) {
    if (!arguments.length) return height;
    height = x;
    return probTree;
  };

  // Getter/setter: stroke width that represents probability 1.
  probTree.trunkWidth = function(x) {
    if (!arguments.length) return trunkWidth;
    trunkWidth = x;
    return probTree;
  };

  return probTree;
};
/**
 * Builds the click handler that reveals one step of the tree.
 *
 * The handler fades in the question label and the nodes tagged with `step`,
 * and "draws" the links tagged with `step` by animating their dash offset
 * from the full path length down to 0 over one second.
 *
 * @param {string} step class name of the step to reveal, e.g. "step1"
 * @returns {Function} handler suitable for selection.on("click", ...)
 */
var animateTree = function(step){
  // Linear fade to full opacity, shared by labels and nodes.
  var fadeIn = function(selection){
    selection
      .transition()
      .ease("linear")
      .style("opacity", 1);
  };

  return function(){
    fadeIn(d3.select(".question." + step));
    fadeIn(d3.selectAll(".node." + step));

    var stepLinks = d3.selectAll(".link." + step);
    stepLinks
      .attr("stroke-dasharray", function(){
        var length = this.getTotalLength();
        return length + " " + length;
      })
      .attr("stroke-dashoffset", function(){
        return this.getTotalLength();
      });
    stepLinks
      .transition()
      .duration(1000)
      .ease("linear")
      .attr("stroke-dashoffset", 0)
      .style("opacity", 1);
  };
};
/**
 * Click handler for step 3: fades in every joint-probability node.
 */
var animateJoints = function(){
  var jointNodes = d3.selectAll(".node.joint");
  var fade = jointNodes.transition().ease("linear");
  fade.style("opacity", 1);
};
/**
 * Click handler for step 4: fades in the Bayes' theorem formula group and
 * un-hides anything still carrying the "hidden" class.
 */
var animateBayes = function(){
  var formula = d3.selectAll(".mathg");
  var fade = formula.transition().ease("linear");
  fade.style("opacity", 1);

  var concealed = d3.selectAll(".hidden");
  concealed.classed("hidden", false);
};
// The single tree renderer used by this page, sized to the drawing area.
var graph = d3.probTree();
graph.width(graphWidth);
graph.height(graphHeight);
graph.trunkWidth(trunkWidth);
/**
 * Adds the two column headings to `g`: the first-level event name at x = 0
 * (revealed with step 1) and the second-level event name at the horizontal
 * midpoint of the graph (revealed with step 2).
 *
 * The bound datum is a flag: 1 for the first heading, 0 for the second.
 *
 * @param {Object} g d3 selection to append the <text> elements to
 */
var questionLabels = function(g){
  var labels = g.selectAll("text.question").data([1, 0]);

  labels.enter().append("text")
    .attr("text-anchor", "middle")
    .attr("class", function(isFirst){
      if (isFirst) {
        return "step1 question";
      }
      return "step2 question";
    })
    .attr("x", function(isFirst){
      if (isFirst) {
        return 0;
      }
      return graphWidth/2;
    })
    .text(function(isFirst){
      if (isFirst) {
        return METADATA["event1"];
      }
      return METADATA["event2"];
    });
};
// Load the tree description, render it once, then wire up the four buttons.
d3.json("probTree.json", function(error, data) {
  // Surface load/parse failures directly. The original ignored `error` and
  // failed later with an unrelated TypeError on `data.tree`.
  if (error) throw error;

  TREEDATA = data.tree;
  METADATA = data.meta;

  // ".gWrapper" is looked up again by the drag handler to re-render the tree.
  svg.append("g")
    .attr("transform", "translate(" + margin.left + "," + margin.top + ")")
    .attr("class", "gWrapper")
    .call(graph, TREEDATA)
    .call(questionLabels);

  d3.select("#animateTree-I").on("click", animateTree("step1"));
  d3.select("#animateTree-II").on("click", animateTree("step2"));
  d3.select("#computeJointProbs").on("click", animateJoints);
  d3.select("#computeBayesProb").on("click", animateBayes);
});
</script>
</body>
</html>
{
"meta": {
"dataset": "Example 2.55",
"event1": "Truth",
"event2": "Mammogram"
},
"tree": {
"name": "event",
"nodeClass": "node event step1",
"children": [
{
"name": "event",
"position": "top",
"nodeClass": "node event step2",
"linkClass": "link event1 step1",
"eventOutcome": "cancer",
"probability": 0.0035,
"children": [
{
"name": "joint",
"probability": 0.89,
"position": "top",
"nodeClass": "node joint step3",
"linkClass": "link event2 step2",
"eventOutcome": "+"
},
{
"name": "joint",
"position": "bot",
"nodeClass": "node joint step3",
"linkClass": "link event2 step2",
"probability": 0.11,
"eventOutcome": "-"
}
]
},
{
"name": "event",
"position": "bot",
"nodeClass": "node event step2",
"linkClass": "link event1 step1",
"probability": 0.9965,
"eventOutcome": "no cancer",
"children": [
{
"name": "joint",
"probability": 0.07,
"position": "top",
"nodeClass": "node joint step3",
"linkClass": "link event2 step2",
"eventOutcome": "+"
},
{
"name": "joint",
"position": "bot",
"nodeClass": "node joint step3",
"linkClass": "link event2 step2",
"probability": 0.93,
"eventOutcome": "-"
}
]
}
]
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment