Skip to content

Instantly share code, notes, and snippets.

@bricedev
Last active March 24, 2016 16:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bricedev/1bd45a5f6d727499ee46 to your computer and use it in GitHub Desktop.
Save bricedev/1bd45a5f6d727499ee46 to your computer and use it in GitHub Desktop.
Gradient descent

Linear regression fitted with batch gradient descent, animated step by step with D3.

population profit
6.1101 17.592
5.5277 9.1302
8.5186 13.662
7.0032 11.854
5.8598 6.8233
8.3829 11.886
7.4764 4.3483
8.5781 12
6.4862 6.5987
5.0546 3.8166
5.7107 3.2522
14.164 15.505
5.734 3.1551
8.4084 7.2258
5.6407 0.71618
5.3794 3.5129
6.3654 5.3048
5.1301 0.56077
6.4296 3.6518
7.0708 5.3893
6.1891 3.1386
20.27 21.767
5.4901 4.263
6.3261 5.1875
5.5649 3.0825
18.945 22.638
12.828 13.501
10.957 7.0467
13.176 14.692
22.203 24.147
5.2524 -1.22
6.5894 5.9966
9.2482 12.134
5.8918 1.8495
8.2111 6.5426
7.9334 4.5623
8.0959 4.1164
5.6063 3.3928
12.836 10.117
6.3534 5.4974
5.4069 0.55657
6.8825 3.9115
11.708 5.3854
5.7737 2.4406
7.8247 6.7318
7.0931 1.0463
5.0702 5.1337
5.8014 1.844
11.7 8.0043
5.5416 1.0179
7.5402 6.7504
5.3077 1.8396
7.4239 4.2885
7.6031 4.9981
6.3328 1.4233
6.3589 -1.4211
6.2742 2.4756
5.6397 4.6042
9.3102 3.9624
9.4536 5.4141
8.8254 5.1694
5.1793 -0.74279
21.279 17.929
14.908 12.054
18.959 17.054
7.2182 4.8852
8.2951 5.7442
10.236 7.7754
5.4994 1.0173
20.341 20.992
10.136 6.6799
7.3345 4.0259
6.0062 1.2784
7.2259 3.3411
5.0269 -2.6807
6.5479 0.29678
7.5386 3.8845
5.0365 5.7014
10.274 6.7526
5.1077 2.0576
5.7292 0.47953
5.1884 0.20421
6.3557 0.67861
9.7687 7.5435
6.5159 5.3436
8.5172 4.2415
9.1802 6.7981
6.002 0.92695
5.5204 0.152
5.0594 2.8214
5.7077 1.8451
7.6366 4.2959
5.8707 7.2029
5.3054 1.9869
8.2934 0.14454
13.394 9.0551
5.4369 0.61705
<!DOCTYPE html>
<meta charset="utf-8">
<style>
body {
font: 10px sans-serif;
}
.axis path,
.axis line {
fill: none;
stroke: #000;
shape-rendering: crispEdges;
}
.line {
fill: none;
stroke: black;
stroke-width: 1px;
}
</style>
<body>
<script src="https://d3js.org/d3.v3.min.js"></script>
<script>
// Chart geometry: the outer SVG is 960x500 and the plotting area is inset by
// these margins. `width`/`height` describe the inner plotting area only.
var margin = {top: 20, right: 20, bottom: 30, left: 40};
var width = 960 - margin.left - margin.right;
var height = 500 - margin.top - margin.bottom;

// Three decimal places for the fitted coefficients shown in the hypothesis label.
var format = d3.format(".3f");

// Linear scales from data space to pixels. The y range is inverted so that
// larger values are drawn higher up; domains are set once the data is loaded.
var x = d3.scale.linear().range([0, width]);
var y = d3.scale.linear().range([height, 0]);

// Axis generators bound to the scales above.
var xAxis = d3.svg.axis().scale(x).orient("bottom");
var yAxis = d3.svg.axis().scale(y).orient("left");

// `svg` is the inner <g>, already shifted by the margins, not the <svg> element itself.
var svg = d3.select("body")
  .append("svg")
    .attr("width", width + margin.left + margin.right)
    .attr("height", height + margin.top + margin.bottom)
  .append("g")
    .attr("transform", "translate(" + [margin.left, margin.top].join(",") + ")");
// Load the training data, plot it, then animate batch gradient descent for a
// univariate linear hypothesis h(x) = theta0 + theta1 * x, redrawing the fitted
// line and its equation after every parameter update.
d3.csv("data.csv", function(error, data) {
  // Surface load/parse failures instead of failing later on an undefined `data`.
  if (error) throw error;

  // CSV fields arrive as strings; coerce them to numbers.
  data.forEach(function(d) {
    d.population = +d.population;
    d.profit = +d.profit;
  });

  x.domain(d3.extent(data, function(d) { return d.population; })).nice();
  y.domain(d3.extent(data, function(d) { return d.profit; })).nice();

  // The regression line is drawn across the full (niced) x domain.
  var xMin = x.domain()[0],
      xMax = x.domain()[1];

  svg.append("g")
      .attr("class", "x axis")
      .attr("transform", "translate(0," + height + ")")
      .call(xAxis)
    .append("text")
      .attr("class", "label")
      .attr("x", width)
      .attr("y", -6)
      .style("text-anchor", "end")
      .style("font-weight", "bold")
      .text("Population of City in 10,000s");

  svg.append("g")
      .attr("class", "y axis")
      .call(yAxis)
    .append("text")
      .attr("class", "label")
      .attr("transform", "rotate(-90)")
      .attr("y", 6)
      .attr("dy", ".71em")
      .style("font-weight", "bold")
      .style("text-anchor", "end")
      .text("Profit in $10,000s");

  svg.selectAll(".dot")
      .data(data)
    .enter().append("circle")
      .attr("class", "dot")
      .attr("r", 3.5)
      .attr("cx", function(d) { return x(d.population); })
      .attr("cy", function(d) { return y(d.profit); })
      .style("fill", "#d73027");

  // Gradient descent settings. All of these are locals of this callback; the
  // original ended the `var` list early with `;`, which turned theta0/theta1
  // into implicit globals.
  var iteration = 0,
      iterationNumber = 1500,  // total number of parameter updates to perform
      m = data.length,         // number of training examples
      alpha = 0.01,            // learning rate
      theta0 = 0,              // intercept
      theta1 = 0;              // slope

  // With no examples the gradient would divide by zero (NaN parameters).
  if (m === 0) return;

  var line = svg.append("line")
      .attr("class", "line");

  var hyp = svg.append("text")
      .attr("x", width / 2)
      .attr("y", 40)
      .style("text-anchor", "middle")
      .style("font-size", "35px")
      .text("hθ(x) = 0 + 0x");

  // Position the fitted line from the current parameters over [xMin, xMax].
  function updateLine() {
    line.attr("x1", x(xMin))
        .attr("y1", y(theta1 * xMin + theta0))
        .attr("x2", x(xMax))
        .attr("y2", y(theta1 * xMax + theta0));
  }

  // Squared-error cost J(theta) = 1/(2n) * sum((h(x) - y)^2) over `data`.
  // Normalises by the length of the array it is given rather than the outer
  // `m`. Not called by the animation below.
  function computeCost(data, theta0, theta1) {
    var cost = 0;
    data.forEach(function(d) {
      cost += Math.pow(theta1 * d.population + theta0 - d.profit, 2);
    });
    return cost / (2 * data.length);
  }

  updateLine();

  d3.timer(function() {
    // Residuals h(x) - y for the current parameters. They are computed once and
    // reused for both partial derivatives, so both thetas are updated from the
    // same (pre-update) values.
    var residuals = data.map(function(d) {
      return (theta1 * d.population + theta0) - d.profit;
    });
    var weighted = residuals.map(function(r, i) {
      return r * data[i].population;
    });

    theta0 = theta0 - alpha * (1 / m) * d3.sum(residuals);
    theta1 = theta1 - alpha * (1 / m) * d3.sum(weighted);

    updateLine();
    hyp.text("hθ(x) = " + format(theta0) + " + " + format(theta1) + "x");

    // Returning true stops the timer; `>=` yields exactly iterationNumber updates
    // (the original `>` performed one extra).
    return ++iteration >= iterationNumber;
  }, 200);
});
</script>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment