Skip to content

Instantly share code, notes, and snippets.

@feyderm
Last active April 14, 2017 23:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save feyderm/6bd8e75420d7aff0b19aa204651eab76 to your computer and use it in GitHub Desktop.
Save feyderm/6bd8e75420d7aff0b19aa204651eab76 to your computer and use it in GitHub Desktop.
Exploring Gradient Descent with Momentum

The green decision boundary uses momentum, while the grey decision boundary does not. Inspired by a recent article on Distill.

<!DOCTYPE html>
<meta charset="utf-8">
<style>
  /* All SVG text elements in the chart. */
  text {
    font-family: sans-serif;
    fill: #000;
  }

  /* Data points: shared outline, with the fill colour chosen per group. */
  .pts {
    stroke: #595959;
  }
  .group1 {
    fill: steelblue;
  }
  .group2 {
    fill: red;
  }

  /* Applies to every SVG <line>, which includes both decision boundaries. */
  line {
    fill: none;
    opacity: 0.6;
  }

  /* Decision boundary fitted without momentum. */
  #dec_boundary {
    stroke: #000;
    stroke-width: 2px;
  }

  /* Decision boundary fitted with momentum. */
  #dec_boundary_m {
    stroke: #008000;
    stroke-width: 4px;
  }

  /* Read-out of the current momentum coefficient next to the slider. */
  #beta_val {
    font-family: sans-serif;
    position: relative;
    left: 20px;
  }
</style>
<body>
<!--range slider for beta (i.e. momentum coefficient); the explicit value matches the beta used for the initial run-->
<form>
<input type="range" name="beta" min="0" max="1.0" step="0.01" value="0.5" oninput="displayBeta(this.value)" onchange="runGradientDescent(this.value)">
<label id="beta_val"></label>
</form>
<!--viz-->
<div id="chart"></div>
<!--both libraries are loaded over https so they are not blocked as mixed content when the page itself is served over https-->
<script src="https://d3js.org/d3.v4.min.js"></script>
<script src="https://feyderm.github.io/math/math.js"></script>
<script type="text/javascript">
// Layout: overall SVG size plus the margins reserved around the plot area.
var margin = {top: 20, right: 0, bottom: 50, left: 85};
var svg_dx = 500;
var svg_dy = 400;
var plot_dx = svg_dx - margin.right - margin.left;
var plot_dy = svg_dy - margin.top - margin.bottom;

// Linear scales from data values to pixel positions. Their domains are set
// once the data has loaded.
var xPos = d3.scaleLinear()
  .range([margin.left, plot_dx]);
// The y range runs from bottom to top so larger values are drawn higher up.
var yPos = d3.scaleLinear()
  .range([plot_dy, margin.top]);

// Root SVG element that every other function draws into.
var svg = d3.select("#chart")
  .append("svg")
  .attr("width", svg_dx)
  .attr("height", svg_dy);

// Load the data, fit the scales to it, draw the static parts of the chart and
// start the first gradient-descent animation.
// NOTE: the callback deliberately keeps a single parameter. With d3 v4 a
// one-parameter callback receives the parsed rows, or null if the request
// failed; adding a second parameter would change what is passed in.
d3.csv("logistic_reg_grad_decent.csv", rows => {
  if (!rows || rows.length === 0) {
    // Without rows there is nothing to scale or plot, and the matrix maths in
    // runGradientDescent would fail, so report the problem and stop here.
    console.error("Could not load any data points from logistic_reg_grad_decent.csv");
    return;
  }

  xPos.domain(d3.extent(rows, row => +row.x));
  yPos.domain(d3.extent(rows, row => +row.y));

  plotAxes(d3.axisBottom(xPos), d3.axisLeft(yPos));
  plotPts(rows);
  runGradientDescent(0.5); // initial beta = 0.5
});
// Handle of the d3 timer driving the animation currently in progress, or null
// when no animation is running.
var gradientDescentTimer = null;

/**
 * (Re)start the animated logistic-regression fit.
 *
 * Two parameter vectors are trained side by side from the same starting
 * point on the plotted points: one with plain gradient descent and one with
 * momentum. After every step both decision boundaries are redrawn.
 *
 * @param {number|string} beta - momentum coefficient. The range slider hands
 *   its value over as a string, so it is converted to a number before use.
 */
function runGradientDescent(beta) {
  // A previous run may still be animating (e.g. the slider was moved before
  // it finished). Stop it first, otherwise its timer would keep computing
  // updates for boundary lines that are about to be removed.
  if (gradientDescentTimer) {
    gradientDescentTimer.stop();
    gradientDescentTimer = null;
  }

  removeDecBnds();
  displayBeta(beta);

  const momentum = +beta;

  // The training data is read back from the plotted points.
  const pts = d3.selectAll(".pts").data();
  if (pts.length === 0) {
    // Nothing has been plotted yet, so there is nothing to fit.
    return;
  }

  const d_extent_x = d3.extent(pts, pt => +pt.x);

  // Design matrix with a leading 1 per row for the intercept term, and the
  // vector of group values used as training targets.
  const X = math.matrix(pts.map(pt => [1, +pt.x, +pt.y]));
  const y = math.matrix(pts.map(pt => +pt.group));

  const iterationNumber = 400; // number of descent steps to animate
  const m = pts.length;        // number of training examples
  const alpha = 0.0004;        // step size applied to each update

  let iteration = 0;
  let velocity = math.matrix([0.0, 0.0, 0.0]);
  let theta = math.matrix([-24, 0.5, 0.2]);
  let theta_m = math.matrix([-24, 0.5, 0.2]);

  // decision boundary w/o momentum
  const dec_bnd = svg.append("line")
    .attr("class", "dec_boundary")
    .attr("id", "dec_boundary");

  // decision boundary w/ momentum
  const dec_bnd_m = svg.append("line")
    .attr("class", "dec_boundary")
    .attr("id", "dec_boundary_m");

  // NOTE: per the d3-timer documentation the second argument of d3.timer is a
  // start delay in milliseconds, not an interval between invocations.
  const timer = d3.timer(() => {
    // update theta w/o momentum and plot decision boundary
    const h = math.multiply(X, theta).map(z => sigmoid(z));
    const grad = computeGradient(m, y, h, X);
    theta = theta.map((t, i) => t - (alpha * math.subset(grad, math.index(i))));
    updateDecisionBoundary(dec_bnd, theta, d_extent_x);

    // update theta w/ momentum and plot decision boundary
    const h_m = math.multiply(X, theta_m).map(z => sigmoid(z));
    const grad_m = computeGradient(m, y, h_m, X);
    // velocity = beta * velocity + grad_m
    velocity = math.add(math.multiply(momentum, velocity), grad_m);
    theta_m = theta_m.map((t, i) => t - (alpha * math.subset(velocity, math.index(i))));
    updateDecisionBoundary(dec_bnd_m, theta_m, d_extent_x);

    // Stop after exactly iterationNumber steps.
    if (++iteration >= iterationNumber) {
      timer.stop();
      if (gradientDescentTimer === timer) {
        gradientDescentTimer = null;
      }
    }
  }, 200);

  gradientDescentTimer = timer;
}
/**
 * Move a decision-boundary line so it shows the set of points where
 * theta0 + theta1 * x + theta2 * y = 0, drawn across the x extent of the data.
 *
 * @param dec_bnd - d3 selection of the SVG line to update.
 * @param theta - parameter vector read as [theta0, theta1, theta2].
 * @param {number[]} d_extent_x - [min, max] of the data's x values.
 */
function updateDecisionBoundary(dec_bnd, theta, d_extent_x) {
  const theta0 = math.subset(theta, math.index(0));
  const theta1 = math.subset(theta, math.index(1));
  const theta2 = math.subset(theta, math.index(2));

  // With theta2 equal to zero the boundary cannot be written as y = f(x);
  // the division below would produce non-finite coordinates, so the line is
  // left where it is.
  if (theta2 === 0) {
    return;
  }

  // Boundary equation solved for y.
  const boundaryY = x => (-1 / theta2) * (theta1 * x + theta0);

  const [xMin, xMax] = d_extent_x;

  // Each endpoint's y is evaluated at the same x it is drawn at, so the
  // rendered segment lies exactly on the boundary.
  dec_bnd.attr("x1", xPos(xMin))
    .attr("y1", yPos(boundaryY(xMin)))
    .attr("x2", xPos(xMax))
    .attr("y2", yPos(boundaryY(xMax)));
}
/**
 * Logistic function 1 / (1 + e^-z).
 *
 * @param {number} z - input value.
 * @returns {number} a value between 0 and 1; 0.5 when z is 0.
 */
function sigmoid(z) {
  const expNegZ = Math.pow(Math.E, -z);
  return 1 / (1 + expNegZ);
}
/**
 * Gradient used for the descent step.
 * Conversion from the Octave expression: grad = (1 / m) * (h - y)' * X
 *
 * @param {number} m - divisor applied to the summed result (the caller passes
 *   the number of rows of X).
 * @param y - math.js vector of target values.
 * @param h - math.js vector of predictions, same length as y.
 * @param X - math.js design matrix.
 * @returns the math.js vector (h - y) * X with every entry multiplied by 1 / m.
 */
function computeGradient(m, y, h, X) {
  // Element-wise prediction error h - y.
  const residuals = h.map((predicted, i) => predicted - math.subset(y, math.index(i)));
  // Weight each column of X by the residuals and sum over the rows.
  const summed = math.multiply(residuals, X);
  return summed.map(total => (1 / m) * total);
}
/**
 * Remove every element with class "dec_boundary" from the page, i.e. both
 * decision-boundary lines created by runGradientDescent.
 */
function removeDecBnds() {
  const boundaries = d3.selectAll(".dec_boundary");
  boundaries.remove();
}
/**
 * Show the given momentum coefficient in the label next to the slider.
 *
 * @param {number|string} beta - value to display, appended to the caption as-is.
 */
function displayBeta(beta) {
  const label = d3.select("#beta_val");
  label.text("Momentum Coefficient: " + beta);
}
/**
 * Draw one symbol per data row inside a new <g> of the chart.
 * Rows whose group is "1" get a circle and the classes "pts group1"; all
 * other rows get a cross and the classes "pts group2".
 *
 * @param {Object[]} d - data rows with x, y and group fields.
 */
function plotPts(d) {
  // Loose equality is kept on purpose so that a numeric 1 is treated the same
  // as the string "1".
  const isGroupOne = pt => pt.group == "1";

  const markerClass = pt => (isGroupOne(pt) ? "pts group1" : "pts group2");
  const markerShape = d3.symbol()
    .type(pt => (isGroupOne(pt) ? d3.symbolCircle : d3.symbolCross));
  const markerPosition = pt => `translate(${xPos(pt.x)},${yPos(pt.y)})`;

  const layer = svg.append("g");
  const markers = layer.selectAll("path")
    .data(d)
    .enter()
    .append("path");

  markers.attr("class", markerClass);
  markers.attr("d", markerShape);
  markers.attr("transform", markerPosition);
}
/**
 * Append the two axes to the chart, each in its own <g>.
 *
 * @param x - axis generator rendered into the group with id "axis_x".
 * @param y - axis generator rendered into the group with id "axis_y".
 */
function plotAxes(x, y) {
  // Both axes are shifted by half of their margin.
  const xAxisOffset = plot_dy + margin.bottom / 2;
  const yAxisOffset = margin.left / 2;

  const xAxisGroup = svg.append("g");
  xAxisGroup
    .attr("id", "axis_x")
    .attr("transform", `translate(0,${xAxisOffset})`)
    .call(x);

  const yAxisGroup = svg.append("g");
  yAxisGroup
    .attr("id", "axis_y")
    .attr("transform", `translate(${yAxisOffset}, 0)`)
    .call(y);
}
</script>
</body>
x y group
34.62365962451697 78.0246928153624 0
30.28671076822607 43.89499752400101 0
35.84740876993872 72.90219802708364 0
60.18259938620976 86.30855209546826 1
79.0327360507101 75.3443764369103 1
45.08327747668339 56.3163717815305 0
61.10666453684766 96.51142588489624 1
75.02474556738889 46.55401354116538 1
76.09878670226257 87.42056971926803 1
84.43281996120035 43.53339331072109 1
95.86155507093572 38.22527805795094 0
75.01365838958247 30.60326323428011 0
82.30705337399482 76.48196330235604 1
69.36458875970939 97.71869196188608 1
39.53833914367223 76.03681085115882 0
53.9710521485623 89.20735013750205 1
69.07014406283025 52.74046973016765 1
67.94685547711617 46.67857410673128 0
70.66150955499435 92.92713789364831 1
76.97878372747498 47.57596364975532 1
67.37202754570876 42.83843832029179 0
89.67677575072079 65.79936592745237 1
50.534788289883 48.85581152764205 0
34.21206097786789 44.20952859866288 0
77.9240914545704 68.9723599933059 1
62.27101367004632 69.95445795447587 1
80.1901807509566 44.82162893218353 1
93.114388797442 38.80067033713209 0
61.83020602312595 50.25610789244621 0
38.78580379679423 64.99568095539578 0
61.379289447425 72.80788731317097 1
85.40451939411645 57.05198397627122 1
52.10797973193984 63.12762376881715 0
52.04540476831827 69.43286012045222 1
40.23689373545111 71.16774802184875 0
54.63510555424817 52.21388588061123 0
33.91550010906887 98.86943574220611 0
64.17698887494485 80.90806058670817 1
74.78925295941542 41.57341522824434 0
34.1836400264419 75.2377203360134 0
83.90239366249155 56.30804621605327 1
51.54772026906181 46.85629026349976 0
94.44336776917852 65.56892160559052 1
82.36875375713919 40.61825515970618 0
51.04775177128865 45.82270145776001 0
62.22267576120188 52.06099194836679 0
77.19303492601364 70.45820000180959 1
97.77159928000232 86.7278223300282 1
62.07306379667647 96.76882412413983 1
91.56497449807442 88.69629254546599 1
79.94481794066932 74.16311935043758 1
99.2725269292572 60.99903099844988 1
90.54671411399852 43.39060180650027 1
34.52451385320009 60.39634245837173 0
50.2864961189907 49.80453881323059 0
49.58667721632031 59.80895099453265 0
97.64563396007767 68.86157272420604 1
32.57720016809309 95.59854761387875 0
74.24869136721598 69.82457122657193 1
71.79646205863379 78.45356224515052 1
75.3956114656803 85.75993667331619 1
35.28611281526193 47.02051394723416 0
56.25381749711624 39.26147251058019 0
30.05882244669796 49.59297386723685 0
44.66826172480893 66.45008614558913 0
66.56089447242954 41.09209807936973 0
40.45755098375164 97.53518548909936 1
49.07256321908844 51.88321182073966 0
80.27957401466998 92.11606081344084 1
66.74671856944039 60.99139402740988 1
32.72283304060323 43.30717306430063 0
64.0393204150601 78.03168802018232 1
72.34649422579923 96.22759296761404 1
60.45788573918959 73.09499809758037 1
58.84095621726802 75.85844831279042 1
99.82785779692128 72.36925193383885 1
47.26426910848174 88.47586499559782 1
50.45815980285988 75.80985952982456 1
60.45555629271532 42.50840943572217 0
82.22666157785568 42.71987853716458 0
88.9138964166533 69.80378889835472 1
94.83450672430196 45.69430680250754 1
67.31925746917527 66.58935317747915 1
57.23870631569862 59.51428198012956 1
80.36675600171273 90.96014789746954 1
68.46852178591112 85.59430710452014 1
42.0754545384731 78.84478600148043 0
75.47770200533905 90.42453899753964 1
78.63542434898018 96.64742716885644 1
52.34800398794107 60.76950525602592 0
94.09433112516793 77.15910509073893 1
90.44855097096364 87.50879176484702 1
55.48216114069585 35.57070347228866 0
74.49269241843041 84.84513684930135 1
89.84580670720979 45.35828361091658 1
83.48916274498238 48.38028579728175 1
42.2617008099817 87.10385094025457 1
99.31500880510394 68.77540947206617 1
55.34001756003703 64.9319380069486 1
74.77589300092767 89.52981289513276 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment