Last active: November 14, 2018 17:41
-
-
Save duhaime/0805ea30469cb970e3202e400b60ad2a to your computer and use it in GitHub Desktop.
Linear Regression Gradient Descent
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html>
<html lang='en'>
  <head>
    <meta charset='UTF-8'>
    <title>Linear Regression</title>
    <style>
      * {
        /* Generic fallback in case the Courier font is not installed. */
        font-family: courier, monospace;
      }
      .lr-container {
        display: inline-block;
        text-align: center;
      }
    </style>
  </head>
  <body>
    <script src='https://d3js.org/d3.v5.min.js'></script>
    <div class='lr-container'>
      <h1>Linear Regression Gradient Descent</h1>
      <div>
        <!-- Learning-rate slider; the script below divides its value by 10000. -->
        <label for='alpha'>alpha</label>
        <input id='alpha' type='range' min='1' max='1001' step='1' value='400'>
      </div>
      <svg id='lr'></svg>
    </div>
    <script>
// Fixed size of the plotting canvas, in pixels.
var w = 500;
var h = 500;
// Select the <svg id="lr"> element and size it; later code draws into `svg`.
var svg = d3.select('#lr');
svg.attr('width', w).attr('height', h);
// Parameters of the true line y = m * x + b that the sample data is drawn from.
var params = {
  m: 0.3,
  b: 2.0,
};

// Generate 100 noisy samples of that line at x = 0, 0.1, ..., 9.9, each with
// uniform noise in [-0.5, 0.5) added to y. x is derived from an integer
// counter (k / 10) rather than accumulated with `x += 0.1`: the accumulated
// sum drifts (e.g. 0.30000000000000004) and can end just below 10 after 100
// steps, which adds an unintended extra point. The IIFE keeps the loop
// variables out of the global scope.
var data = (function () {
  var samples = [];
  for (var k = 0; k < 100; k++) {
    var x = k / 10;
    samples.push({
      x: x,
      y: x * params.m + params.b + (Math.random() - 0.5),
    });
  }
  return samples;
})();
// Axis domains in data units: x covers the sampled x range, y is a fixed window.
var domains = {};
domains.x = d3.extent(data, function (point) {
  return point.x;
});
domains.y = [0, 7];

// Data-to-pixel scales with a 10px margin on every side. The y range runs
// from the bottom (h - 10) to the top (10) because SVG y grows downward.
var scales = {};
scales.x = d3.scaleLinear().domain(domains.x).range([10, w - 10]);
scales.y = d3.scaleLinear().domain(domains.y).range([h - 10, 10]);
// Start the fit from a random line: slope in [-5, 5), intercept in [0, 5).
var estimate = {};
estimate.m = Math.random() * 10 - 5;
estimate.b = Math.random() * 5;
// Plot every sample as a small gray dot.
svg.selectAll('circle')
  .data(data)
  .enter()
  .append('circle')
  .attr('cx', function (point) {
    return scales.x(point.x);
  })
  .attr('cy', function (point) {
    return scales.y(point.y);
  })
  .attr('r', 3)
  .attr('fill', 'gray');

// Dashed red reference line: the true line y = params.m * x + params.b that
// the samples were generated from (not a fitted line), drawn from x = -10
// to x = 10.
svg.append('line')
  .attr('id', 'best-fit')
  .attr('stroke', 'red')
  .attr('stroke-width', 3)
  .attr('x1', scales.x(-10))
  .attr('x2', scales.x(10))
  .attr('y1', scales.y(-10 * params.m + params.b))
  .attr('y2', scales.y(10 * params.m + params.b))
  .attr('stroke-dasharray', '3');

// Empty placeholder for the moving estimate line; drawEstimate() styles it
// and sets its endpoints.
svg.append('line').attr('id', 'estimate');
// Animate the green "#estimate" line to the slope/intercept currently held
// in `estimate`. The segment spans x = -10..10 in data units, beyond the
// sampled x range, so the line crosses the whole plot.
function drawEstimate() {
  var xLeft = -10;
  var xRight = 10;
  var yLeft = estimate.m * xLeft + estimate.b;
  var yRight = estimate.m * xRight + estimate.b;

  svg.select('#estimate')
    .transition()
    .duration(200)
    .attr('stroke', 'green')
    .style('opacity', 0.7)
    .attr('stroke-width', 3)
    .attr('x1', scales.x(xLeft))
    .attr('x2', scales.x(xRight))
    .attr('y1', scales.y(yLeft))
    .attr('y2', scales.y(yRight));
}
/**
 * Add up the numbers in an array, left to right.
 * @param {number[]} arr - values to total
 * @returns {number} the sum, or 0 for an empty array
 */
function sum(arr) {
  var total = 0;
  arr.forEach(function (value) {
    total += value;
  });
  return total;
}
/*
 * Run one step of batch gradient descent on `estimate`, redraw the estimate
 * line, and schedule the next step with requestAnimationFrame until both
 * partial derivatives are within tolerance.
 *
 * Cost function = 1/2 mean squared error:
 *
 *   J(theta_0, theta_1) = 1/(2m) * sum_i (h_theta(x_i) - y_i)^2
 *
 * where:
 *   theta_0      = intercept term (estimate.b)
 *   theta_1      = slope coefficient (estimate.m)
 *   m            = number of observations (data.length)
 *   h_theta(x_i) = theta_1 * x_i + theta_0, the predicted y for the ith point
 *   y_i          = the observed y value of the ith point
 *
 * Partial derivatives:
 *   dJ/dtheta_0 = 1/m * sum_i (h_theta(x_i) - y_i)
 *   dJ/dtheta_1 = 1/m * sum_i (h_theta(x_i) - y_i) * x_i
 *
 * Reads the globals `data` and `estimate` and the #alpha slider; mutates
 * `estimate`. Takes no arguments (the timestamp requestAnimationFrame passes
 * is ignored).
 */
function iterate() {
  // Stop descending once both partial derivatives are at most this size.
  var tolerance = 0.00001;
  var n = data.length;
  // Nothing to fit; dividing by zero below would turn the estimate into NaN.
  if (n === 0) {
    return;
  }

  // Accumulate both gradient sums in a single pass over the data.
  var errSum = 0;
  var weightedErrSum = 0;
  for (var i = 0; i < n; i++) {
    var point = data[i];
    var err = (estimate.m * point.x + estimate.b) - point.y;
    errSum += err;
    weightedErrSum += err * point.x;
  }
  // Declared locally (the original leaked these as implicit globals).
  var djdb = errSum / n; // dJ/dtheta_0, the intercept term
  var djdm = weightedErrSum / n; // dJ/dtheta_1, the slope coefficient

  // Learning rate: the slider's value scaled down by 10000.
  var alpha = parseFloat(document.querySelector('#alpha').value) / 10000;
  var nextB = estimate.b - djdb * alpha;
  var nextM = estimate.m - djdm * alpha;

  // Too large an alpha makes the descent diverge; stop before Infinity/NaN
  // reach `estimate` and the line's SVG coordinates.
  if (!isFinite(nextB) || !isFinite(nextM)) {
    return;
  }
  estimate.b = nextB;
  estimate.m = nextM;
  drawEstimate();

  // Near-zero partial derivatives mean the model has converged.
  if (Math.abs(djdb) > tolerance || Math.abs(djdm) > tolerance) {
    requestAnimationFrame(iterate);
  }
}
// main: draw the initial random estimate, then start the descent loop,
// which reschedules itself via requestAnimationFrame until it converges.
drawEstimate();
iterate();
</script> | |
</body> | |
</html> | |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.