Skip to content

Instantly share code, notes, and snippets.

@duhaime
Last active November 26, 2018 13:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save duhaime/afbb5535c07b1eed68febe31740359b5 to your computer and use it in GitHub Desktop.
Gradient Descent
license: MIT
height: 560
scrolling: no
border: yes
<!DOCTYPE html>
<html lang='en'>
<head>
  <meta charset='UTF-8'>
  <title>Gradient Descent</title>
  <style>
    /* Courier is not installed on every system; fall back to the generic
       monospace family so the formulas keep a fixed-width look. */
    * {
      font-family: courier, monospace;
    }
    .gd-container {
      display: inline-block;
      text-align: center;
    }
    input {
      position: relative;
      top: 7px;
    }
    code {
      display: block;
    }
    .numerator {
      display: inline-block;
      border-bottom: 1px solid black;
    }
    .denominator {
      display: block;
    }
    .left {
      display: inline-block;
    }
    .right {
      vertical-align: top;
      position: relative;
      top: 7px;
    }
    div#input-container {
      position: relative;
      font-size: 11px;
      display: inline-block;
    }
    div#min {
      position: absolute;
      left: 0;
      padding: 3px;
    }
    div#max {
      position: absolute;
      right: 0;
      padding: 3px;
    }
  </style>
</head>
<body>
  <div class='gd-container'>
    <h1>Gradient Descent</h1>
    <div>
      <!-- iterate() divides the slider's value by 1000 to get the learning
           rate, so this range spans alpha = 0.01 .. 1.001. On f(x) = x^2
           each step maps x to (1 - 2*alpha) * x, so values of 1000 and up
           (alpha >= 1) never converge. Presumably the max deliberately
           reaches just past 1000 to show overshooting. -->
      <label for='alpha'>alpha</label>
      <div id='input-container'>
        <input id='alpha' type='range' value='10' min='10' max='1001'>
        <div id='min'>min</div>
        <div id='max'>max</div>
      </div>
      <button id='restart'>Restart</button>
    </div>
    <svg id='gd'></svg>
    <code>f(x) = x<sup>2</sup></code>
    <br/>
    <code>
      <span class='left'>
        <span class='numerator'>
          <i>df</i>
        </span>
        <span class='denominator'>
          <i>dx</i>
        </span>
      </span>
      <span class='right'> = 2x</span>
    </code>
  </div>
  <script src='https://d3js.org/d3.v5.min.js'></script>
  <script>
// Chart size in pixels, applied to the #gd <svg> element below.
var w = 480;
var h = 350;
var svg = d3.select('#gd');
svg.attr('width', w).attr('height', h);
// identify the function to plot: f(x) = x^2, the same formula printed
// under the chart.
function f(x) {
  return x ** 2;
}
// use f() to generate points: sample f at evenly spaced x values over
// [SAMPLE_X_MIN, SAMPLE_X_MAX], both endpoints included, so the plotted
// parabola is symmetric about x = 0. Each x is computed directly from the
// loop index instead of by repeatedly adding 0.01, so floating-point
// rounding error does not accumulate across the samples.
var data = [];
const SAMPLE_X_MIN = -6;
const SAMPLE_X_MAX = 6;
const SAMPLE_STEPS = 1200; // 1200 steps over a width of 12 -> spacing 0.01
for (let i = 0; i <= SAMPLE_STEPS; i++) {
  const x = SAMPLE_X_MIN + ((SAMPLE_X_MAX - SAMPLE_X_MIN) * i) / SAMPLE_STEPS;
  data.push([x, f(x)]);
}
// get chart scales given the data: map the x and y extents of the sampled
// points onto the SVG with a 10px margin. The y range is [h - 10, 10], so
// larger f(x) values are drawn higher on screen.
var xExtent = d3.extent(data, function(point) { return point[0]; });
var yExtent = d3.extent(data, function(point) { return point[1]; });
var scales = {
  x: d3.scaleLinear().domain(xExtent).range([10, w - 10]),
  y: d3.scaleLinear().domain(yExtent).range([h - 10, 10]),
};
// plot the points: one small gray dot per sampled [x, f(x)] pair.
var curveDots = svg.selectAll('circle')
  .data(data)
  .enter()
  .append('circle');
curveDots
  .attr('cx', function(point) { return scales.x(point[0]); })
  .attr('cy', function(point) { return scales.y(point[1]); })
  .attr('r', 1)
  .attr('fill', 'gray');
// create function to show the minimum estimate
/**
 * Appends a red dot (radius 4) at the point (estimate, f(estimate)) on the
 * chart. Each dot gets the class `gd-estimate`, the class the Restart handler
 * selects when it clears previous dots.
 */
function drawEstimate() {
  var px = scales.x(estimate);
  var py = scales.y(f(estimate));
  svg.append('circle')
    .attr('class', 'gd-estimate')
    .attr('fill', 'red')
    .attr('r', 4)
    .attr('cx', px)
    .attr('cy', py);
}
// create an iteration function that computes the gradient
// and updates the estimate
//
// Each call takes one gradient-descent step on f(x) = x^2 (gradient 2x),
// draws the new estimate, and schedules the next step 200 ms later.
// `step` counts the steps already taken in the current run; callers start
// a new run with a bare iterate().
let iterateTimer = null;
const ITERATE_MAX_STEPS = 1000;
function iterate(step = 0) {
  // A new call supersedes any step still queued from an earlier run. Without
  // this, pressing Restart mid-descent would leave two timer chains racing
  // on the shared `estimate`.
  window.clearTimeout(iterateTimer);
  iterateTimer = null;
  // The slider's integer value is scaled down by 1000 to get the learning rate.
  const alpha = parseFloat(document.querySelector('#alpha').value) / 1000;
  const oldEstimate = estimate;
  const gradient = 2 * estimate;
  const scaled = gradient * alpha;
  estimate -= scaled;
  // to watch the learning unfold, uncomment the following
  //console.log(' *', gradient, scaled, estimate, oldEstimate - estimate)
  drawEstimate();
  // Continue iterating until the change between new and old estimates is
  // minimal. Each step maps x to (1 - 2 * alpha) * x, so for alpha >= 1 the
  // change never shrinks: alpha = 1 oscillates between +x and -x, and larger
  // alpha diverges. The step cap stops the animation in that case instead of
  // appending dots forever.
  if (Math.abs(oldEstimate - estimate) > .001 && step < ITERATE_MAX_STEPS) {
    iterateTimer = window.setTimeout(function() {
      iterate(step + 1);
    }, 200);
  }
}
// restart the simulation / initialize and draw the estimate state.
//
// startDescent() seeds the shared `estimate` with a starting x, draws that
// starting point, and hands off to iterate() for the animated descent.
function startDescent(startX) {
  estimate = startX;
  drawEstimate();
  iterate();
}

// Restart: remove every previously drawn estimate dot, then descend again
// from x = 5 or x = -5, chosen at random.
document.querySelector('#restart').addEventListener('click', function() {
  d3.selectAll('.gd-estimate').remove();
  var startX;
  if (Math.random() > 0.5) {
    startX = 5;
  } else {
    startX = -5;
  }
  startDescent(startX);
});

// the current guess at the minimizer of f; the first run starts from x = 5
var estimate;
startDescent(5);
</script>
</body>
</html>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment