Skip to content

Instantly share code, notes, and snippets.

@duhaime
Last active November 14, 2018 17:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save duhaime/0805ea30469cb970e3202e400b60ad2a to your computer and use it in GitHub Desktop.
Save duhaime/0805ea30469cb970e3202e400b60ad2a to your computer and use it in GitHub Desktop.
Linear Regression Gradient Descent
<!DOCTYPE html>
<html lang='en'>
<head>
<meta charset='UTF-8'>
<title>Linear Regression</title>
<style>
/* Monospace everywhere; fall back to the generic family when Courier is missing. */
* {
font-family: courier, monospace;
}
.lr-container {
display: inline-block;
text-align: center;
}
</style>
</head>
<body>
<script src='https://d3js.org/d3.v5.min.js'></script>
<div class='lr-container'>
<h1>Linear Regression Gradient Descent</h1>
<div>
<!-- Learning-rate slider; the label is tied to the input for accessibility. -->
<label for='alpha'>alpha</label>
<input id='alpha' type='range' min='1' max='1001' step='1' value='400' >
</div>
<svg id='lr'></svg>
</div>
<script>
// Chart dimensions in pixels.
var w = 500;
var h = 500;
// Size the <svg id='lr'> element that hosts the whole chart.
var svg = d3.select('#lr');
svg.attr('width', w);
svg.attr('height', h);
// Ground-truth slope (m) and intercept (b) used to generate the sample data.
var params = {
  m: 0.3,
  b: 2.0,
};
// Generate 100 noisy points on y = m*x + b for x = 0.0, 0.1, ..., 9.9, with
// uniform noise in [-0.5, 0.5). x is derived from an integer counter: stepping
// a float by 0.1 accumulates rounding error (x = 0.30000000000000004, ...) and,
// because the running sum never quite reaches 10, yields an extra 101st point.
var data = [];
for (var k = 0; k < 100; k++) {
  var x = k / 10;
  data.push({
    x: x,
    y: x * params.m + params.b + (Math.random() - 0.5),
  });
}
// Data domains: x spans the generated points; y is a fixed [0, 7] window.
var domains = {
  x: d3.extent(data, function(point) { return point.x; }),
  y: [0, 7],
};
// Data -> pixel scales with a 10px margin on each side. The y range is
// inverted so larger y values are drawn higher on the screen.
var scales = {};
scales.x = d3.scaleLinear().domain(domains.x).range([10, w - 10]);
scales.y = d3.scaleLinear().domain(domains.y).range([h - 10, 10]);
// Random starting guess for gradient descent:
// slope m in [-5, 5), intercept b in [0, 5).
var estimate = {};
estimate.m = Math.random() * 10 - 5;
estimate.b = Math.random() * 5;
// Scatter plot of the sample data: one small gray dot per point.
svg.selectAll('circle')
  .data(data)
  .enter()
  .append('circle')
  .attr('cx', function(point) { return scales.x(point.x); })
  .attr('cy', function(point) { return scales.y(point.y); })
  .attr('r', 3)
  .attr('fill', 'gray');
// Dashed red reference line for the parameters that generated the data
// (params), drawn between x = -10 and x = 10.
svg.append('line')
  .attr('id', 'best-fit')
  .attr('stroke', 'red')
  .attr('stroke-width', 3)
  .attr('x1', scales.x(-10))
  .attr('x2', scales.x(10))
  .attr('y1', scales.y(params.m * -10 + params.b))
  .attr('y2', scales.y(params.m * 10 + params.b))
  .attr('stroke-dasharray', '3');
// Empty placeholder line for the gradient-descent estimate; drawEstimate()
// assigns its stroke and endpoints.
svg.append('line').attr('id', 'estimate');
// Animate the #estimate line so it shows y = estimate.m * x + estimate.b.
function drawEstimate() {
  // Line endpoints at x = -10 and x = 10, converted to pixel coordinates.
  var leftX = -10;
  var rightX = 10;
  var x1 = scales.x(leftX);
  var x2 = scales.x(rightX);
  var y1 = scales.y(estimate.m * leftX + estimate.b);
  var y2 = scales.y(estimate.m * rightX + estimate.b);
  // 200 ms transition to the new position.
  var line = svg.select('#estimate').transition().duration(200);
  line.attr('stroke', 'green');
  line.style('opacity', 0.7);
  line.attr('stroke-width', 3);
  line.attr('x1', x1).attr('x2', x2).attr('y1', y1).attr('y2', y2);
}
// Add up the numbers in `arr` left to right; an empty array sums to 0.
function sum(arr) {
  var total = 0;
  arr.forEach(function(value) {
    total += value;
  });
  return total;
}
function iterate() {
  /*
   * Perform one step of batch gradient descent on `estimate`, redraw it, and
   * schedule the next step via requestAnimationFrame until the gradients are
   * (nearly) zero.
   *
   * Cost function = 1/2 mean squared error:
   *
   *   J(theta_0, theta_1) = 1/(2m) * sum_i (h_theta(x_i) - y_i)^2
   *
   * where:
   *   theta_0      = intercept term (estimate.b)
   *   theta_1      = slope coefficient (estimate.m)
   *   m            = number of observations (data.length)
   *   h_theta(x_i) = the "hypothesis" for the ith point,
   *                  i.e. theta_1 * x_i + theta_0
   *   y_i          = the actual y value of the ith point
   *
   * Partial derivatives:
   *   dJ/dtheta_0 = 1/m * sum_i (h_theta(x_i) - y_i)
   *   dJ/dtheta_1 = 1/m * sum_i (h_theta(x_i) - y_i) * x_i
   *
   * The learning rate is re-read from the #alpha slider (value / 10000) on
   * every step, so moving the slider takes effect while the loop runs.
   * The loop stops once both gradients are within `tolerance` of zero; a NaN
   * gradient also fails the comparison and so stops the loop.
   */
  var tolerance = 0.00001;
  var n = data.length;
  // Nothing to fit: avoid dividing by zero and writing NaN into the estimate.
  if (n === 0) {
    return;
  }
  // Residuals h_theta(x_i) - y_i for the current estimate.
  var errs = data.map(function(point) {
    return (estimate.m * point.x + estimate.b) - point.y;
  });
  // Gradients are declared locally so they do not leak as implicit globals
  // (which would throw a ReferenceError in strict mode).
  var djdb = sum(errs) / n; // dJ/dtheta_0 (intercept)
  var djdm = sum(errs.map(function(err, i) {
    return err * data[i].x;
  })) / n; // dJ/dtheta_1 (slope)
  // Learning rate from the slider.
  var alpha = parseFloat(document.querySelector('#alpha').value) / 10000;
  // Step the parameters against their gradients.
  estimate.b -= djdb * alpha;
  estimate.m -= djdm * alpha;
  drawEstimate();
  // Keep iterating until the partial derivatives are minimal (converged).
  if (Math.abs(djdb) > tolerance || Math.abs(djdm) > tolerance) {
    requestAnimationFrame(iterate);
  }
}
// Entry point: show the initial random guess, then start the descent loop.
drawEstimate();
iterate();
</script>
</body>
</html>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment