Skip to content

Instantly share code, notes, and snippets.

@bricedev
Last active March 24, 2016 16:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bricedev/b426500ba721d1b1a7fe to your computer and use it in GitHub Desktop.
Save bricedev/b426500ba721d1b1a7fe to your computer and use it in GitHub Desktop.
Learning rate

Convergence of the cost function J(θ) under gradient descent for different values of the learning rate α.

population profit
6.1101 17.592
5.5277 9.1302
8.5186 13.662
7.0032 11.854
5.8598 6.8233
8.3829 11.886
7.4764 4.3483
8.5781 12
6.4862 6.5987
5.0546 3.8166
5.7107 3.2522
14.164 15.505
5.734 3.1551
8.4084 7.2258
5.6407 0.71618
5.3794 3.5129
6.3654 5.3048
5.1301 0.56077
6.4296 3.6518
7.0708 5.3893
6.1891 3.1386
20.27 21.767
5.4901 4.263
6.3261 5.1875
5.5649 3.0825
18.945 22.638
12.828 13.501
10.957 7.0467
13.176 14.692
22.203 24.147
5.2524 -1.22
6.5894 5.9966
9.2482 12.134
5.8918 1.8495
8.2111 6.5426
7.9334 4.5623
8.0959 4.1164
5.6063 3.3928
12.836 10.117
6.3534 5.4974
5.4069 0.55657
6.8825 3.9115
11.708 5.3854
5.7737 2.4406
7.8247 6.7318
7.0931 1.0463
5.0702 5.1337
5.8014 1.844
11.7 8.0043
5.5416 1.0179
7.5402 6.7504
5.3077 1.8396
7.4239 4.2885
7.6031 4.9981
6.3328 1.4233
6.3589 -1.4211
6.2742 2.4756
5.6397 4.6042
9.3102 3.9624
9.4536 5.4141
8.8254 5.1694
5.1793 -0.74279
21.279 17.929
14.908 12.054
18.959 17.054
7.2182 4.8852
8.2951 5.7442
10.236 7.7754
5.4994 1.0173
20.341 20.992
10.136 6.6799
7.3345 4.0259
6.0062 1.2784
7.2259 3.3411
5.0269 -2.6807
6.5479 0.29678
7.5386 3.8845
5.0365 5.7014
10.274 6.7526
5.1077 2.0576
5.7292 0.47953
5.1884 0.20421
6.3557 0.67861
9.7687 7.5435
6.5159 5.3436
8.5172 4.2415
9.1802 6.7981
6.002 0.92695
5.5204 0.152
5.0594 2.8214
5.7077 1.8451
7.6366 4.2959
5.8707 7.2029
5.3054 1.9869
8.2934 0.14454
13.394 9.0551
5.4369 0.61705
<!DOCTYPE html>
<meta charset="utf-8">
<style>
body {
font: 10px sans-serif;
}
.axis path,
.axis line {
fill: none;
stroke: #fff;
shape-rendering: crispEdges;
}
.axis text {
font-weight: bold;
}
.line {
fill: none;
stroke: #000;
stroke-width: 1.2px;
stroke-linejoin: round;
}
</style>
<body>
<script src="https://d3js.org/d3.v3.min.js"></script>
<script>
// Chart geometry: a 960x500 outer frame; `width` and `height` describe the
// inner plotting area that is left once the margins are subtracted.
var margin = {top: 20, right: 20, bottom: 20, left: 30};
var width = 960 - margin.left - margin.right;
var height = 500 - margin.top - margin.bottom;

// Linear scales mapping data values to pixels. Their domains are set later,
// once the data is known. The y range is flipped because SVG's y axis grows
// downwards.
var x = d3.scale.linear();
x.range([0, width]);

var y = d3.scale.linear();
y.range([height, 0]);

// Axes whose ticks use a negative size equal to the plot height/width, so the
// tick lines extend across the plotting area and act as grid lines.
var xAxis = d3.svg.axis()
    .orient("bottom")
    .scale(x)
    .tickPadding(8)
    .tickSize(-height);

var yAxis = d3.svg.axis()
    .orient("left")
    .scale(y)
    .tickPadding(8)
    .tickSize(-width);

// Root <svg> sized to the outer frame; `svg` itself refers to the inner <g>
// shifted by the left/top margins, so drawing code works in plot coordinates.
var svg = d3.select("body")
  .append("svg")
    .attr("width", width + margin.left + margin.right)
    .attr("height", height + margin.top + margin.bottom)
  .append("g")
    .attr("transform", "translate(" + [margin.left, margin.top].join(",") + ")");
d3.csv("data.csv", function(error, data) {
// Stop right away when the CSV could not be loaded; otherwise the code below
// would fail later on `data.length` with a far less helpful message.
if (error) throw error;

// Gradient descent settings.
var iterationNumber = 150,                    // descent steps simulated per learning rate
    m = data.length,                          // number of training examples
    alpha = [0.01, 0.001, 0.0005, 0.0001];    // learning rates to compare

// Coerce both columns to numbers (unary plus) before doing any arithmetic.
data.forEach(function(d) {
  d.population = +d.population;
  d.profit = +d.profit;
});

// For every learning rate, run gradient descent from theta0 = theta1 = 0 and
// record the cost at the start of each iteration.
// Result shape: [{alpha: Number, data: [{iteration: Number, cost: Number}, ...]}, ...]
var values = alpha.map(function(alphaValue) {
  var costHistory = [],
      theta0 = 0,
      theta1 = 0;

  // Signed prediction error of the current hypothesis for one data row.
  function residual(d) {
    return (theta1 * d.population + theta0) - d.profit;
  }

  // `i` is declared locally; the original loop leaked it as an implicit global.
  for (var i = 0; i < iterationNumber; i++) {
    costHistory.push({iteration: i, cost: computeCost(data, theta0, theta1)});

    // Simultaneous update: both new values are computed from the current
    // theta0/theta1 before either one is overwritten.
    var temp0 = theta0 - alphaValue * (1 / m) * d3.sum(data, residual);
    var temp1 = theta1 - alphaValue * (1 / m) * d3.sum(data, function(d) {
      return residual(d) * d.population;
    });
    theta0 = temp0;
    theta1 = temp1;
  }

  return {alpha: alphaValue, data: costHistory};
});

// Scale domains: x spans the iteration count, y spans the smallest to the
// largest cost seen across all learning rates (rounded to nice values).
x.domain([0, iterationNumber]);

var yMin = d3.min(values, function(alphaValue) {
  return d3.min(alphaValue.data, function(d) { return d.cost; });
});
var yMax = d3.max(values, function(alphaValue) {
  return d3.max(alphaValue.data, function(d) { return d.cost; });
});
y.domain([yMin, yMax]).nice();

// Path generator turning one cost history into a smoothed curve.
var line = d3.svg.line()
    .interpolate("basis")
    .x(function(d) { return x(d.iteration); })
    .y(function(d) { return y(d.cost); });

// Grey backdrop behind the plotting area.
svg.append("rect")
    .attr("class", "background")
    .attr("width", width)
    .attr("height", height)
    .style("fill", "#e7e7e7");

// X axis along the bottom edge, with its label placed inside the plot.
svg.append("g")
    .attr("class", "x axis")
    .attr("transform", "translate(0," + height + ")")
    .call(xAxis)
  .append("text")
    .attr("class", "label")
    .attr("x", width - 5)
    .attr("y", -6)
    .style("text-anchor", "end")
    .style("font-weight", "bold")
    .text("Number of iterations");

// Y axis on the left edge, with a rotated label.
svg.append("g")
    .attr("class", "y axis")
    .call(yAxis)
  .append("text")
    .attr("class", "label")
    .attr("transform", "rotate(-90)")
    .attr("dx", "-.71em")
    .attr("y", 6)
    .attr("dy", ".71em")
    .style("font-weight", "bold")
    .style("text-anchor", "end")
    .text("J(θ)");

// One entering datum per learning rate; both the curve and its label are
// appended from the same enter selection.
var path = svg.selectAll(".line")
    .data(values).enter();

// The curve itself. Its id is the learning rate so the label below can
// reference it.
path.append("path")
    .attr("id", function(d) { return d.alpha; })
    .attr("class", "line")
    .attr("d", function(d) { return line(d.data); });

// Label laid out along the curve through a <textPath> pointing at the path id.
path.append("text")
    .attr("dy", "-3px")
  .append("textPath")
    .attr("xlink:href", function(d) { return "#" + d.alpha; })
    .attr("startOffset", "24%")
    .style("font-weight", "bold")
    .text(function(d) { return "α = " + d.alpha; });
/**
 * Squared-error cost J(θ) of the linear hypothesis
 * h(population) = theta0 + theta1 * population over a dataset:
 *
 *   J(θ) = 1 / (2 * n) * Σ (h(population) - profit)^2,  n = data.length
 *
 * The divisor is taken from the `data` argument itself rather than from a
 * variable of the enclosing scope, so the function is correct for any
 * dataset it is handed and can be used on its own.
 *
 * @param {Array<{population: number, profit: number}>} data - rows to evaluate.
 * @param {number} theta0 - intercept of the hypothesis.
 * @param {number} theta1 - slope of the hypothesis.
 * @returns {number} the cost; NaN when `data` is empty (0 / 0).
 */
function computeCost(data, theta0, theta1) {
  var squaredErrorSum = 0;
  data.forEach(function(d) {
    squaredErrorSum += Math.pow(theta1 * d.population + theta0 - d.profit, 2);
  });
  return squaredErrorSum / (2 * data.length);
}
});
</script>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment