Last active
November 26, 2018 13:32
-
-
Save duhaime/afbb5535c07b1eed68febe31740359b5 to your computer and use it in GitHub Desktop.
Gradient Descent
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
license: MIT
height: 560
scrolling: no
border: yes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html> | |
<html> | |
<head> | |
<meta charset='UTF-8'> | |
<title>Gradient Descent</title> | |
<style> | |
* { | |
font-family: courier; | |
} | |
.gd-container { | |
display: inline-block; | |
text-align: center; | |
} | |
input { | |
position: relative; | |
top: 7px; | |
} | |
code { | |
display: block; | |
} | |
.numerator { | |
display: inline-block; | |
border-bottom: 1px solid black; | |
} | |
.denominator { | |
display: block; | |
} | |
.left { | |
display: inline-block; | |
} | |
.right { | |
vertical-align: top; | |
position: relative; | |
top: 7px; | |
} | |
div#input-container { | |
position: relative; | |
font-size: 11px; | |
display: inline-block; | |
} | |
div#min { | |
position: absolute; | |
left: 0; | |
padding: 3px; | |
} | |
div#max { | |
position: absolute; | |
right: 0; | |
padding: 3px; | |
} | |
</style> | |
</head> | |
<body> | |
<div class='gd-container'> | |
<h1>Gradient Descent</h1> | |
<div> | |
<span>alpha</span> | |
<div id='input-container'> | |
<input id='alpha' type='range' value='10' min='10' max='1001'> | |
<div id='min'>min</div> | |
<div id='max'>max</div> | |
</div> | |
<button id='restart'>Restart</button> | |
</div> | |
<svg id='gd'></svg> | |
<code>f(x) = x<sup>2</sup></code> | |
<br/> | |
<code> | |
<span class='left'> | |
<span class='numerator'> | |
<i>df</i> | |
</span> | |
<span class='denominator'> | |
<i>dx</i> | |
</span> | |
</span> | |
<span class='right'> = 2x</span> | |
</code> | |
</div> | |
<script src='https://d3js.org/d3.v5.min.js'></script> | |
<script> | |
// Canvas size in pixels, shared by the SVG element and the plotting scales.
var w = 480;
var h = 350;

// Grab the <svg id='gd'> element that hosts the chart and size it.
// selection.attr(name, value) returns the same selection, so `svg` is the
// sized #gd selection either way.
var svg = d3.select('#gd');
svg.attr('width', w).attr('height', h);
// The objective to minimise: f(x) = x^2 (its gradient is f'(x) = 2x).
function f(x) { return x ** 2; }

// Sample f on the closed interval [-6, 6] in steps of 0.01 (1201 points).
// Each x is derived from an integer index rather than by repeatedly adding
// 0.01: repeated addition accumulates floating-point error, and incrementing
// before the first push skipped the left endpoint (-6), which made the
// plotted curve (and hence the x scale's domain) asymmetric.
var data = [];
for (let i = -600; i <= 600; i++) {
  const x = i / 100;
  data.push([x, f(x)]);
}
// Linear scales mapping data coordinates to pixels, inset by a 10px margin
// on every side. The y range runs from bottom to top because SVG's y axis
// grows downward.
var scales = (function() {
  var margin = 10;
  var xDomain = d3.extent(data, function(point) { return point[0]; });
  var yDomain = d3.extent(data, function(point) { return point[1]; });
  return {
    x: d3.scaleLinear().domain(xDomain).range([margin, w - margin]),
    y: d3.scaleLinear().domain(yDomain).range([h - margin, margin]),
  };
})();
// Render the sampled curve: bind one <circle> per [x, y] sample and place it
// at the pixel position given by the scales, as a small gray dot.
svg.selectAll('circle')
  .data(data)
  .enter()
  .append('circle')
  .attr('cx', ([px]) => scales.x(px))
  .attr('cy', ([, py]) => scales.y(py))
  .attr('r', 1)
  .attr('fill', 'gray');
// Append a red marker at (estimate, f(estimate)) for the current value of the
// shared `estimate`. Markers accumulate; each carries the class 'gd-estimate'
// so that all of them can be selected and removed together.
function drawEstimate() {
  var cx = scales.x(estimate),
      cy = scales.y(f(estimate));
  svg.append('circle')
    .attr('class', 'gd-estimate')
    .attr('fill', 'red')
    .attr('r', 4)
    .attr('cx', cx)
    .attr('cy', cy);
}
// Handle of the pending iterate() timeout, kept so a restart can cancel it.
var timer = null;

// Take one gradient-descent step on f(x) = x^2, draw the new estimate, and
// schedule the next step 200ms later until the estimate settles.
// The learning rate alpha is re-read from the #alpha slider on every step
// (slider value / 1000), so moving the slider mid-run takes effect at once.
function iterate() {
  var alpha = parseFloat(document.querySelector('#alpha').value) / 1000,
      oldEstimate = estimate,
      gradient = 2 * estimate, // f'(x) = 2x
      scaled = gradient * alpha;
  estimate -= scaled;
  // to watch the learning unfold, uncomment the following
  //console.log(' *', gradient, scaled, estimate, oldEstimate - estimate)
  drawEstimate();
  // Each step maps x to x * (1 - 2 * alpha). The slider allows alpha up to
  // 1.001, and for alpha >= 1 the estimate oscillates (alpha = 1) or
  // diverges (alpha > 1). In both cases the step size never drops below the
  // threshold, so the loop would run forever. Stop as soon as the loss fails
  // to decrease; the negated comparison also stops on NaN.
  if (!(f(estimate) < f(oldEstimate))) {
    return;
  }
  // continue iterating until the change between new and old
  // estimates is minimal
  if (Math.abs(oldEstimate - estimate) > .001) {
    timer = window.setTimeout(iterate, 200);
  }
}

// restart the simulation
document.querySelector('#restart').addEventListener('click', function() {
  // Cancel any step still pending from the previous run. Otherwise every
  // click would start an additional setTimeout chain, all mutating the same
  // shared estimate. clearTimeout(null) is a harmless no-op.
  window.clearTimeout(timer);
  d3.selectAll('.gd-estimate').remove();
  estimate = Math.random() > 0.5 ? 5 : -5;
  drawEstimate();
  iterate();
});

// initialize and draw the estimate state
var estimate = 5;
drawEstimate();
iterate();
</script> | |
</body> | |
</html> | |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.