Skip to content

Instantly share code, notes, and snippets.

@jaredwinick
Last active May 26, 2018 03:43
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save jaredwinick/31cf37cb0cf5db911eedd54c78e992c1 to your computer and use it in GitHub Desktop.
Save jaredwinick/31cf37cb0cf5db911eedd54c78e992c1 to your computer and use it in GitHub Desktop.
Visualizing Linear Regression by Gradient Descent

Inspired by Professor Ng's lectures in the Coursera Machine Learning class, these animations visualize single-variable linear regression fitted by gradient descent. The graph on the left shows the data we are trying to fit, together with the hypothesis line as the parameters θ0 and θ1 converge. The plot on the right traces the (θ0, θ1) values visited during the descent, with each point coloured by the value of the cost function at that point. The animation loops forever, each time starting from a random value of θ0 and θ1.

<!DOCTYPE html>
<meta charset="utf-8">
<style>
/* Small sans-serif default font for the page (axis tick labels, axis titles, legend). */
body {
font: 10px sans-serif;
}
/* Axis domain path and tick lines: unfilled, black, snapped to crisp pixel edges. */
.axis path,
.axis line {
fill: none;
stroke: #000;
shape-rendering: crispEdges;
}
/* The "line" class is applied to the hypothesis line drawn on the data plot. */
.line {
fill: none;
stroke: steelblue;
stroke-width: 1px;
}
</style>
<body>
<!-- Load D3 v4 and the d3-legend plugin. Both are fetched over HTTPS so the
     page still works when it is itself served over HTTPS (browsers block
     plain-http scripts on secure pages as mixed content). -->
<script src="https://d3js.org/d3.v4.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/d3-legend/2.18.0/d3-legend.min.js"></script>
<script>
// Plot geometry: each SVG is 450x500 pixels overall; width/height describe the
// inner drawing area that remains once the margins are taken off.
var margin = {top: 20, right: 20, bottom: 30, left: 50};
var width = 450 - margin.left - margin.right;
var height = 500 - margin.top - margin.bottom;

// Data-space extent of the scatter plot (also reused as the theta0 range).
var minX = 0;
var maxX = 100;
var minY = 0;
var maxY = 100;

// Range of slopes (theta1) used for random starts and for the theta1 axis.
var minSlope = -2;
var maxSlope = 2;
// Random starting parameters for each descent run: the intercept is drawn
// from the y extent and the slope from [minSlope, maxSlope].
var theta0Generator = d3.randomUniform(minY, maxY);
var theta1Generator = d3.randomUniform(minSlope, maxSlope);

// Scales for the data plot: x maps left-to-right, y maps bottom-to-top
// (SVG y grows downwards, hence the inverted range).
var x = d3.scaleLinear();
x.domain([minX, maxX]);
x.range([0, width]);

var y = d3.scaleLinear();
y.domain([minY, maxY]);
y.range([height, 0]);

// Scales for the parameter plot: theta0 along the horizontal axis and
// theta1 along the vertical axis.
var t0 = d3.scaleLinear();
t0.domain([minY, maxY]);
t0.range([0, width]);

var t1 = d3.scaleLinear();
t1.domain([minSlope, maxSlope]);
t1.range([height, 0]);

// Axis generators, one per scale.
var xAxis = d3.axisBottom(x);
var yAxis = d3.axisLeft(y);
var t0Axis = d3.axisBottom(t0);
var t1Axis = d3.axisLeft(t1);
// Data plot: builds the first <svg>, its margin-shifted drawing group, both
// axes and the two axis titles. `svg` refers to the inner <g>, so everything
// appended to it later is positioned in drawing-area coordinates.
var svg = (function() {
  var root = d3.select("body").append("svg");
  root.attr("width", width + margin.left + margin.right);
  root.attr("height", height + margin.top + margin.bottom);

  var plot = root.append("g");
  plot.attr("transform", "translate(" + margin.left + "," + margin.top + ")");

  // Horizontal axis sits at the bottom edge of the drawing area.
  var bottomAxis = plot.append("g");
  bottomAxis.attr("class", "x axis");
  bottomAxis.attr("transform", "translate(0," + height + ")");
  bottomAxis.call(xAxis);

  // Vertical axis stays at the left edge.
  var leftAxis = plot.append("g");
  leftAxis.attr("class", "y axis");
  leftAxis.call(yAxis);

  // Axis titles are anchored at their middle so the translate target is the
  // centre of the text.
  var xTitle = plot.append("text");
  xTitle.attr("text-anchor", "middle");
  xTitle.attr("transform", "translate(" + (width / 2) + "," + (height + margin.bottom) + ")"); // centred below the x axis
  xTitle.text("x");

  var yTitle = plot.append("text");
  yTitle.attr("text-anchor", "middle");
  yTitle.attr("transform", "translate(" + (-30) + "," + (height / 2) + ")"); // centred to the left of the y axis
  yTitle.text("y");

  return plot;
}());
// Parameter plot: builds the second <svg> with a theta0 axis along the bottom
// and a theta1 axis on the left. `svg2` refers to the inner, margin-shifted
// <g> that the cost points and the legend are appended to.
var svg2 = (function() {
  var root = d3.select("body").append("svg");
  root.attr("width", width + margin.left + margin.right);
  root.attr("height", height + margin.top + margin.bottom);

  var plot = root.append("g");
  plot.attr("transform", "translate(" + margin.left + "," + margin.top + ")");

  // theta0 axis at the bottom edge of the drawing area.
  var bottomAxis = plot.append("g");
  bottomAxis.attr("class", "x axis");
  bottomAxis.attr("transform", "translate(0," + height + ")");
  bottomAxis.call(t0Axis);

  // theta1 axis at the left edge.
  var leftAxis = plot.append("g");
  leftAxis.attr("class", "y axis");
  leftAxis.call(t1Axis);

  // Titles are set with .html() so the &Theta; entity is parsed as markup.
  var theta0Title = plot.append("text");
  theta0Title.attr("text-anchor", "middle");
  theta0Title.attr("transform", "translate(" + (width / 2) + "," + (height + margin.bottom) + ")"); // centred below the theta0 axis
  theta0Title.html("&Theta;0");

  var theta1Title = plot.append("text");
  theta1Title.attr("text-anchor", "middle");
  theta1Title.attr("transform", "translate(" + (-35) + "," + (height / 2) + ")"); // centred to the left of the theta1 axis
  theta1Title.html("&Theta;1");

  return plot;
}());
/**
 * Builds a noisy sample of the line y = a * x + b.
 *
 * Each x is drawn uniformly between minX and maxX; each y is the value on
 * the line plus normally distributed noise (mean 0, standard deviation 10),
 * so the points scatter around the line instead of lying exactly on it.
 *
 * @param {number} numberOfPoints - how many points to generate
 * @param {number} minX - lower bound for the generated x values
 * @param {number} maxX - upper bound for the generated x values
 * @param {number} a - slope of the underlying line
 * @param {number} b - intercept of the underlying line
 * @returns {Array<{x: number, y: number}>} the generated points
 */
function generateData(numberOfPoints, minX, maxX, a, b) {
  var drawX = d3.randomUniform(minX, maxX);
  var drawNoise = d3.randomNormal(0, 10);
  var points = [];
  d3.range(numberOfPoints).forEach(function() {
    // Draw x first, then the noise, one pair per point.
    var px = drawX();
    var py = (a * px) + b + drawNoise();
    points.push({x: px, y: py});
  });
  return points;
}
// Sample 200 noisy points around y = 1.06 * x + 26; this is the data set the
// gradient descent will fit.
var data = generateData(200, minX, maxX, 1.06, 26);

// Draw one small circle per data point on the data plot.
var circles = svg.selectAll("circle").data(data).enter().append("circle");
circles.attr("cx", function(point) { return x(point.x); });
circles.attr("cy", function(point) { return y(point.y); });
circles.attr("r", 2);
/**
 * Hypothesis function of the linear model: h(x) = theta0 + theta1 * x.
 *
 * @param {number} x - input value
 * @param {number} theta0 - intercept
 * @param {number} theta1 - slope
 * @returns {number} the predicted y for x
 */
function h(x, theta0, theta1) {
  var slopeTerm = theta1 * x;
  return theta0 + slopeTerm;
}
/**
 * Performs one batch gradient-descent update of the linear model
 * h(x) = theta0 + theta1 * x against the squared-error cost.
 *
 * Both parameters are updated simultaneously: every residual is evaluated
 * with the current theta0/theta1 before either new value is computed.
 *
 * @param {number} currentTheta0 - current intercept
 * @param {number} currentTheta1 - current slope
 * @param {number} alpha - learning rate
 * @param {Array<{x: number, y: number}>} data - training points
 * @returns {{theta0: number, theta1: number}} the updated parameters; the
 *   current parameters unchanged when data is empty
 */
function gradiantDescentStep(currentTheta0, currentTheta1, alpha, data) {
  // With no samples there is no gradient. Without this guard 1 / data.length
  // is Infinity and Infinity * 0 turns both parameters into NaN.
  if (data.length === 0) {
    return {theta0: currentTheta0, theta1: currentTheta1};
  }

  // Accumulate both partial-derivative sums in a single pass so that each
  // residual h(x) - y is evaluated only once per point.
  var theta0Sum = 0;
  var theta1Sum = 0;
  data.forEach(function(point) {
    var residual = h(point.x, currentTheta0, currentTheta1) - point.y;
    theta0Sum += residual;
    theta1Sum += residual * point.x;
  });

  var newTheta0 = currentTheta0 - (alpha * (1.0 / data.length) * theta0Sum);
  var newTheta1 = currentTheta1 - (alpha * (1.0 / data.length) * theta1Sum);
  return {theta0: newTheta0, theta1: newTheta1};
}
/**
 * Evaluates the squared-error cost of the linear model for the given
 * parameters: the sum of squared residuals multiplied by 1 / (2 * m), where
 * m is the number of data points.
 *
 * @param {number} currentTheta0 - intercept to evaluate
 * @param {number} currentTheta1 - slope to evaluate
 * @param {Array<{x: number, y: number}>} data - training points
 * @returns {number} the cost at (currentTheta0, currentTheta1)
 */
function calculateCost(currentTheta0, currentTheta1, data) {
  var squaredErrorTotal = 0;
  data.forEach(function(sample) {
    var residual = h(sample.x, currentTheta0, currentTheta1) - sample.y;
    squaredErrorTotal = squaredErrorTotal + Math.pow(residual, 2);
  });
  var halfMeanFactor = 1.0 / (2 * data.length);
  return squaredErrorTotal * halfMeanFactor;
}
/**
 * Computes the two endpoints of the hypothesis line across the plotted x
 * range, evaluating the hypothesis at minX and at maxX.
 *
 * @param {number} theta0 - intercept
 * @param {number} theta1 - slope
 * @returns {Array<{x: number, y: number}>} endpoint at minX, then at maxX
 */
function calculateLineData(theta0, theta1) {
  var endpointXs = [minX, maxX];
  return endpointXs.map(function(xValue) {
    return {x: xValue, y: h(xValue, theta0, theta1)};
  });
}
/*
 * Decides how many gradient-descent updates to perform for the given
 * iteration count. Early on (below 20) only a single update is performed so
 * that the fast initial movement of the line and the cost stays visible;
 * after that the count is twice the iteration number, capped at 1000.
 */
function numberOfStepsForIteration(iteration) {
  var maxSteps = 1000;
  return iteration < 20 ? 1 : Math.min(2 * iteration, maxSteps);
}
// Every {x: theta0, y: theta1, cost} sample recorded so far; the parameter
// plot draws one circle per entry.
var costData = [];

// Piecewise log colour scale for the cost: blue at 50 through teal, yellow
// and orange up to red at 4000.
var costScale = d3.scaleLog();
costScale.domain([ 50, 100, 400, 1000, 4000 ]);
costScale.range(["#2c7bb6", "#00ccbc", "#ffff8c", "#f29e2e", "#d7191c"].map(function(hex) {
  // Wrapped so d3.rgb only ever receives the colour string.
  return d3.rgb(hex);
}));

// Container for the legend, placed near the top-right of the parameter plot.
svg2.append("g")
  .attr("class", "legendLog")
  .attr("transform", "translate(" + (width - 40) + ",10)");

// Colour legend with one cell per domain stop of the cost scale.
var legend = d3.legendColor();
legend.cells([50, 100, 400, 1000, 4000]);
legend.title("Cost");
legend.scale(costScale);

legend(svg2.select(".legendLog"));
/**
 * Runs one animated gradient-descent fit and then starts the next one.
 *
 * A random (theta0, theta1) is drawn, a hypothesis line is added to the data
 * plot, and a d3 timer repeatedly:
 *   1. redraws the line for the current parameters,
 *   2. records the current cost and plots it on the parameter plot,
 *   3. advances the parameters by a batch of gradient-descent steps.
 * Once more than 30000 steps have been taken the timer is stopped and a new
 * run is started from fresh random parameters.
 *
 * The line element added here is not removed when the run ends, and costData
 * keeps the points of all previous runs.
 */
function runGradientDescent() {
  var theta0 = theta0Generator();
  var theta1 = theta1Generator();
  var lineData = calculateLineData(theta0, theta1);
  var l = svg.append("line")
    .attr("class", "line")
    .attr("x1", x(lineData[0].x))
    .attr("y1", y(lineData[0].y))
    .attr("x2", x(lineData[1].x))
    .attr("y2", y(lineData[1].y));
  var iteration = 1;

  // The second argument of d3.timer is the delay before the callback first
  // runs; after that the callback is invoked on every animation frame.
  var t = d3.timer(function() {
    lineData = calculateLineData(theta0, theta1);
    l.attr("x1", x(lineData[0].x))
      .attr("y1", y(lineData[0].y))
      .attr("x2", x(lineData[1].x))
      .attr("y2", y(lineData[1].y));

    // add a new point to the cost data to render; the cost is computed once
    // and reused for both the plot and the log line
    var cost = calculateCost(theta0, theta1, data);
    costData.push({x: theta0, y: theta1, cost: cost});
    console.log("cost:" + cost);

    // Only points without a circle yet (the enter selection) are drawn.
    svg2.selectAll("circle")
      .data(costData)
      .enter().append("circle")
      .attr("cx", function(d) { return t0(d.x); })
      .attr("cy", function(d) { return t1(d.y); })
      .attr("r", 4)
      .style("fill", function(d) { return costScale(d.cost); });

    // we are just running a fixed number of iterations
    // as opposed to checking for convergence
    if (iteration > 30000) {
      t.stop();
      // start all over again; return so this finished run does no more work
      runGradientDescent();
      return;
    }

    // Fix the batch size before the loop: `iteration` changes inside the
    // loop, so re-evaluating the bound on every pass would keep raising it
    // and defeat the gradual ramp-up of steps per frame.
    var steps = numberOfStepsForIteration(iteration);
    for (var i = 0; i < steps; ++i) {
      var update = gradiantDescentStep(theta0, theta1, .0005, data);
      theta0 = update.theta0;
      theta1 = update.theta1;
      ++iteration;
    }
  }, 500);
}
// Start the first run; runGradientDescent calls itself again once a run has
// passed its iteration limit, so the animation keeps looping.
runGradientDescent();
</script>
</body>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment