Linear regression using the gradient descent method
Last active: March 24, 2016 16:06
-
-
Save bricedev/1bd45a5f6d727499ee46 to your computer and use it in GitHub Desktop.
Gradient descent
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
population | profit | |
---|---|---|
6.1101 | 17.592 | |
5.5277 | 9.1302 | |
8.5186 | 13.662 | |
7.0032 | 11.854 | |
5.8598 | 6.8233 | |
8.3829 | 11.886 | |
7.4764 | 4.3483 | |
8.5781 | 12 | |
6.4862 | 6.5987 | |
5.0546 | 3.8166 | |
5.7107 | 3.2522 | |
14.164 | 15.505 | |
5.734 | 3.1551 | |
8.4084 | 7.2258 | |
5.6407 | 0.71618 | |
5.3794 | 3.5129 | |
6.3654 | 5.3048 | |
5.1301 | 0.56077 | |
6.4296 | 3.6518 | |
7.0708 | 5.3893 | |
6.1891 | 3.1386 | |
20.27 | 21.767 | |
5.4901 | 4.263 | |
6.3261 | 5.1875 | |
5.5649 | 3.0825 | |
18.945 | 22.638 | |
12.828 | 13.501 | |
10.957 | 7.0467 | |
13.176 | 14.692 | |
22.203 | 24.147 | |
5.2524 | -1.22 | |
6.5894 | 5.9966 | |
9.2482 | 12.134 | |
5.8918 | 1.8495 | |
8.2111 | 6.5426 | |
7.9334 | 4.5623 | |
8.0959 | 4.1164 | |
5.6063 | 3.3928 | |
12.836 | 10.117 | |
6.3534 | 5.4974 | |
5.4069 | 0.55657 | |
6.8825 | 3.9115 | |
11.708 | 5.3854 | |
5.7737 | 2.4406 | |
7.8247 | 6.7318 | |
7.0931 | 1.0463 | |
5.0702 | 5.1337 | |
5.8014 | 1.844 | |
11.7 | 8.0043 | |
5.5416 | 1.0179 | |
7.5402 | 6.7504 | |
5.3077 | 1.8396 | |
7.4239 | 4.2885 | |
7.6031 | 4.9981 | |
6.3328 | 1.4233 | |
6.3589 | -1.4211 | |
6.2742 | 2.4756 | |
5.6397 | 4.6042 | |
9.3102 | 3.9624 | |
9.4536 | 5.4141 | |
8.8254 | 5.1694 | |
5.1793 | -0.74279 | |
21.279 | 17.929 | |
14.908 | 12.054 | |
18.959 | 17.054 | |
7.2182 | 4.8852 | |
8.2951 | 5.7442 | |
10.236 | 7.7754 | |
5.4994 | 1.0173 | |
20.341 | 20.992 | |
10.136 | 6.6799 | |
7.3345 | 4.0259 | |
6.0062 | 1.2784 | |
7.2259 | 3.3411 | |
5.0269 | -2.6807 | |
6.5479 | 0.29678 | |
7.5386 | 3.8845 | |
5.0365 | 5.7014 | |
10.274 | 6.7526 | |
5.1077 | 2.0576 | |
5.7292 | 0.47953 | |
5.1884 | 0.20421 | |
6.3557 | 0.67861 | |
9.7687 | 7.5435 | |
6.5159 | 5.3436 | |
8.5172 | 4.2415 | |
9.1802 | 6.7981 | |
6.002 | 0.92695 | |
5.5204 | 0.152 | |
5.0594 | 2.8214 | |
5.7077 | 1.8451 | |
7.6366 | 4.2959 | |
5.8707 | 7.2029 | |
5.3054 | 1.9869 | |
8.2934 | 0.14454 | |
13.394 | 9.0551 | |
5.4369 | 0.61705 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html>
<meta charset="utf-8">
<title>Gradient descent</title>
<style>

body {
  font: 10px sans-serif;
}

/* Axis domain paths and tick lines: unfilled, crisp black strokes. */
.axis path,
.axis line {
  fill: none;
  stroke: #000;
  shape-rendering: crispEdges;
}

/* The fitted regression line (the SVG <line> with class "line"). */
.line {
  fill: none;
  stroke: black;
  stroke-width: 1px;
}

</style>
<body>
<!-- The script below uses the d3 v3 API (d3.scale.linear, d3.svg.axis, d3.timer). -->
<script src="https://d3js.org/d3.v3.min.js"></script>
<script>
// Inner drawing area: the full 960x500 SVG minus the margins reserved for
// the axes.
var margin = {top: 20, right: 20, bottom: 30, left: 40};
var width = 960 - margin.left - margin.right;
var height = 500 - margin.top - margin.bottom;

// Fixed three-decimal number formatter.
var format = d3.format(".3f");

// Linear scales; only their pixel ranges are fixed here. The y range is
// flipped so that larger values are drawn higher up.
var x = d3.scale.linear().range([0, width]);
var y = d3.scale.linear().range([height, 0]);

var xAxis = d3.svg.axis().scale(x).orient("bottom");
var yAxis = d3.svg.axis().scale(y).orient("left");

// Root <g>, shifted by the margins so children can use inner coordinates.
var svg = d3.select("body")
  .append("svg")
    .attr("width", width + margin.left + margin.right)
    .attr("height", height + margin.top + margin.bottom)
  .append("g")
    .attr("transform", "translate(" + margin.left + "," + margin.top + ")");
d3.csv("data.csv", function(error, data) {
  // Fail loudly if data.csv cannot be fetched or parsed; otherwise `data` is
  // undefined and the forEach below throws a less helpful TypeError.
  if (error) throw error;

  // d3.csv yields strings; coerce both columns to numbers.
  data.forEach(function(d) {
    d.population = +d.population;
    d.profit = +d.profit;
  });

  x.domain(d3.extent(data, function(d) { return d.population; })).nice();
  y.domain(d3.extent(data, function(d) { return d.profit; })).nice();

  // The fitted line is drawn across the whole (niced) x domain.
  var xMin = x.domain()[0],
      xMax = x.domain()[1];

  svg.append("g")
      .attr("class", "x axis")
      .attr("transform", "translate(0," + height + ")")
      .call(xAxis)
    .append("text")
      .attr("class", "label")
      .attr("x", width)
      .attr("y", -6)
      .style("text-anchor", "end")
      .style("font-weight", "bold")
      .text("Population of City in 10,000s");

  svg.append("g")
      .attr("class", "y axis")
      .call(yAxis)
    .append("text")
      .attr("class", "label")
      .attr("transform", "rotate(-90)")
      .attr("y", 6)
      .attr("dy", ".71em")
      .style("font-weight", "bold")
      .style("text-anchor", "end")
      .text("Profit in $10,000s");

  svg.selectAll(".dot")
      .data(data)
    .enter().append("circle")
      .attr("class", "dot")
      .attr("r", 3.5)
      .attr("cx", function(d) { return x(d.population); })
      .attr("cy", function(d) { return y(d.profit); })
      .style("fill", "#d73027");

  // Gradient descent settings. A single `var` statement keeps theta0 and
  // theta1 local to this callback instead of leaking them as globals.
  var iteration = 0,
      iterationNumber = 1500,
      m = data.length,
      alpha = 0.01,
      theta0 = 0,
      theta1 = 0;

  // Current hypothesis h(x) = theta0 + theta1 * x.
  function predict(population) {
    return theta1 * population + theta0;
  }

  // Half mean squared error J(theta0, theta1) over `data`. Not called by the
  // animation below; available for checking convergence.
  function computeCost(data, theta0, theta1) {
    var cost = 0;
    data.forEach(function(d) {
      cost += Math.pow(theta1 * d.population + theta0 - d.profit, 2);
    });
    // Normalize by the size of the set actually passed in, not the outer `m`.
    return cost / (2 * data.length);
  }

  var line = svg.append("line")
      .attr("class", "line");

  var hyp = svg.append("text")
      .attr("x", width / 2)
      .attr("y", 40)
      .style("text-anchor", "middle")
      .style("font-size", "35px")
      .text("hθ(x) = 0 + 0x");

  // Position the regression line from the current parameters.
  function drawLine() {
    line.attr("x1", x(xMin))
        .attr("y1", y(predict(xMin)))
        .attr("x2", x(xMax))
        .attr("y2", y(predict(xMax)));
  }

  drawLine();

  d3.timer(function() {
    // Batch gradient step: both partial derivatives are computed from the
    // current parameters before either parameter changes.
    var grad0 = d3.sum(data, function(d) {
          return predict(d.population) - d.profit;
        }) / m,
        grad1 = d3.sum(data, function(d) {
          return (predict(d.population) - d.profit) * d.population;
        }) / m;
    theta0 -= alpha * grad0;
    theta1 -= alpha * grad1;

    drawLine();
    hyp.text("hθ(x) = " + format(theta0) + " + " + format(theta1) + "x");

    // Returning true stops the d3 timer; `>=` performs exactly
    // `iterationNumber` gradient steps.
    return ++iteration >= iterationNumber;
  }, 200);
});
</script>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.