Mock data is from the week 3 programming exercise of Andrew Ng's machine learning course on Coursera.
Last active
February 21, 2017 13:09
-
-
Save feyderm/61f613ceb9d8a2007e1de1b3f52362b9 to your computer and use it in GitHub Desktop.
Gradient Descent for Logistic Regression
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html>
<meta charset="utf-8">
<style>
  text {
    font-family: sans-serif;
    fill: #000000;
  }
  .pts {
    stroke: #595959;
  }
  .group1 {
    fill: steelblue;
  }
  .group2 {
    fill: red;
  }
  #dec_boundary {
    fill: none;
    stroke: #000000;
    stroke-width: 2px;
    opacity: 0.6;
  }
</style>
<body>
<!--viz-->
<div id="chart"></div>
<script src="https://d3js.org/d3.v4.min.js"></script>
<!-- Loaded over https: browsers block an http:// script as mixed content when the page itself is served over https. -->
<script src="https://feyderm.github.io/math/math.js"></script>
<script type="text/javascript">
// Chart geometry in pixels: the SVG is svg_dx x svg_dy, and plot_dx / plot_dy
// are the width / height left after subtracting the margins.
const margin = {top: 20, right: 0, bottom: 50, left: 85};
const svg_dx = 500;
const svg_dy = 400;
const plot_dx = svg_dx - margin.right - margin.left;
const plot_dy = svg_dy - margin.top - margin.bottom;

// Data-to-pixel scales; their domains are assigned once the CSV has loaded.
// The y range runs from bottom to top so larger values plot higher.
// NOTE(review): the x range ends at plot_dx rather than margin.left + plot_dx,
// leaving the right-hand strip unused; the axis offsets are tuned to this.
const xPos = d3.scaleLinear().range([margin.left, plot_dx]);
const yPos = d3.scaleLinear().range([plot_dy, margin.top]);

// Root SVG element that the axes, points and decision boundary are drawn into.
const svg = d3.select("#chart").append("svg");
svg.attr("width", svg_dx).attr("height", svg_dy);
// Load the training examples, fit the scales to them, draw both axes and one
// symbol per example (circle for group "1", cross otherwise), then start the
// gradient-descent animation.
d3.csv("logistic_reg_grad_decent.csv", (error, data) => {
  // A two-argument callback makes d3 v4 pass (error, rows). The one-argument
  // form receives null on failure and would then crash inside d3.extent.
  if (error) {
    console.error("Failed to load logistic_reg_grad_decent.csv", error);
    return;
  }

  xPos.domain(d3.extent(data, pt => +pt.x));
  yPos.domain(d3.extent(data, pt => +pt.y));

  svg.append("g")
      .attr("id", "axis_x")
      .attr("transform", `translate(0,${plot_dy + margin.bottom / 2})`)
      .call(d3.axisBottom(xPos));

  svg.append("g")
      .attr("id", "axis_y")
      .attr("transform", `translate(${margin.left / 2}, 0)`)
      .call(d3.axisLeft(yPos));

  // CSV fields are parsed as strings, so compare against the string "1".
  const isGroup1 = pt => pt.group === "1";

  svg.append("g")
    .selectAll("path")
    .data(data)
    .enter()
    .append("path")
      .attr("class", pt => isGroup1(pt) ? "pts group1" : "pts group2")
      .attr("d", d3.symbol().type(pt => isGroup1(pt) ? d3.symbolCircle : d3.symbolCross))
      .attr("transform", pt => `translate(${xPos(+pt.x)},${yPos(+pt.y)})`);

  runGradientDescent();
});
/**
 * Logistic (sigmoid) function 1 / (1 + e^(-z)).
 *
 * @param {number} z - linear score; the caller passes entries of X * theta.
 * @returns {number} value in [0, 1]; evaluates to exactly 0 or 1 once e^(-z)
 *   overflows to Infinity or underflows to 0.
 */
function sigmoid(z) {
  // Math.exp is the direct primitive for e^x (instead of Math.pow(Math.E, x)).
  return 1 / (1 + Math.exp(-z));
}
/**
 * Gradient of the logistic-regression cost, ported from the Octave expression
 *   grad = (1 / m) * (h - y)' * X;
 *
 * @param {number} m - number of training examples (rows of X).
 * @param {math.Matrix|Array} y - label vector of length m.
 * @param {math.Matrix|Array} h - hypothesis sigmoid(X * theta), length m.
 * @param {math.Matrix} X - m-by-n design matrix.
 * @returns {math.Matrix} gradient vector of length n.
 */
function computeGradient(m, y, h, X) {
  // Element-wise residual; avoids per-element math.subset lookups and the
  // implicit coercion of a one-element subset result back to a number.
  const residual = math.subtract(h, y);
  // Row vector (length m) times X (m x n) yields a length-n vector, which is
  // then scaled by 1/m.
  return math.multiply(1 / m, math.multiply(residual, X));
}
/**
 * Animate batch gradient descent for logistic regression on the plotted points.
 *
 * Reads the data bound to the `.pts` elements and builds the design matrix
 * rows [1, x, y] and the labels from `group`. On each d3 timer tick it takes
 * one step theta := theta - alpha * grad. It then redraws the decision
 * boundary theta0 + theta1 * x + theta2 * y = 0 across the x-extent of the
 * data. It stops after exactly `iterationNumber` steps.
 */
function runGradientDescent() {
  const data = d3.selectAll(".pts").data();
  if (data.length === 0) return;

  const [xMin, xMax] = d3.extent(data, pt => +pt.x);

  // Design matrix with a leading bias column, and the numeric labels.
  const X = math.matrix(data.map(pt => [1, +pt.x, +pt.y]));
  const y = math.matrix(data.map(pt => +pt.group));
  const m = data.length;

  const iterationNumber = 400;
  const alpha = 0.0004;
  // Initial parameters [theta0, theta1, theta2].
  let theta = math.matrix([-24, 0.5, 0.2]);
  let iteration = 0;

  const decBoundary = svg.append("line")
      .attr("id", "dec_boundary");

  const ticker = d3.timer(() => {
    const h = math.multiply(X, theta).map(z => sigmoid(z));
    const grad = computeGradient(m, y, h, X);

    // Vectorized update; avoids math.subset with Matrix.map's index argument.
    theta = math.subtract(theta, math.multiply(alpha, grad));

    const [theta0, theta1, theta2] = theta.toArray();
    // Solve theta0 + theta1 * x + theta2 * y = 0 for y.
    const boundaryY = x => -(theta0 + theta1 * x) / theta2;

    // Evaluate both endpoints at the same x they are drawn at, so the line
    // lies on the actual decision boundary.
    decBoundary
        .attr("x1", xPos(xMin))
        .attr("y1", yPos(boundaryY(xMin)))
        .attr("x2", xPos(xMax))
        .attr("y2", yPos(boundaryY(xMax)));

    iteration += 1;
    if (iteration >= iterationNumber) {
      ticker.stop();
    }
  }, 200);
}
</script> | |
</body> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
x | y | group | |
---|---|---|---|
34.62365962451697 | 78.0246928153624 | 0 | |
30.28671076822607 | 43.89499752400101 | 0 | |
35.84740876993872 | 72.90219802708364 | 0 | |
60.18259938620976 | 86.30855209546826 | 1 | |
79.0327360507101 | 75.3443764369103 | 1 | |
45.08327747668339 | 56.3163717815305 | 0 | |
61.10666453684766 | 96.51142588489624 | 1 | |
75.02474556738889 | 46.55401354116538 | 1 | |
76.09878670226257 | 87.42056971926803 | 1 | |
84.43281996120035 | 43.53339331072109 | 1 | |
95.86155507093572 | 38.22527805795094 | 0 | |
75.01365838958247 | 30.60326323428011 | 0 | |
82.30705337399482 | 76.48196330235604 | 1 | |
69.36458875970939 | 97.71869196188608 | 1 | |
39.53833914367223 | 76.03681085115882 | 0 | |
53.9710521485623 | 89.20735013750205 | 1 | |
69.07014406283025 | 52.74046973016765 | 1 | |
67.94685547711617 | 46.67857410673128 | 0 | |
70.66150955499435 | 92.92713789364831 | 1 | |
76.97878372747498 | 47.57596364975532 | 1 | |
67.37202754570876 | 42.83843832029179 | 0 | |
89.67677575072079 | 65.79936592745237 | 1 | |
50.534788289883 | 48.85581152764205 | 0 | |
34.21206097786789 | 44.20952859866288 | 0 | |
77.9240914545704 | 68.9723599933059 | 1 | |
62.27101367004632 | 69.95445795447587 | 1 | |
80.1901807509566 | 44.82162893218353 | 1 | |
93.114388797442 | 38.80067033713209 | 0 | |
61.83020602312595 | 50.25610789244621 | 0 | |
38.78580379679423 | 64.99568095539578 | 0 | |
61.379289447425 | 72.80788731317097 | 1 | |
85.40451939411645 | 57.05198397627122 | 1 | |
52.10797973193984 | 63.12762376881715 | 0 | |
52.04540476831827 | 69.43286012045222 | 1 | |
40.23689373545111 | 71.16774802184875 | 0 | |
54.63510555424817 | 52.21388588061123 | 0 | |
33.91550010906887 | 98.86943574220611 | 0 | |
64.17698887494485 | 80.90806058670817 | 1 | |
74.78925295941542 | 41.57341522824434 | 0 | |
34.1836400264419 | 75.2377203360134 | 0 | |
83.90239366249155 | 56.30804621605327 | 1 | |
51.54772026906181 | 46.85629026349976 | 0 | |
94.44336776917852 | 65.56892160559052 | 1 | |
82.36875375713919 | 40.61825515970618 | 0 | |
51.04775177128865 | 45.82270145776001 | 0 | |
62.22267576120188 | 52.06099194836679 | 0 | |
77.19303492601364 | 70.45820000180959 | 1 | |
97.77159928000232 | 86.7278223300282 | 1 | |
62.07306379667647 | 96.76882412413983 | 1 | |
91.56497449807442 | 88.69629254546599 | 1 | |
79.94481794066932 | 74.16311935043758 | 1 | |
99.2725269292572 | 60.99903099844988 | 1 | |
90.54671411399852 | 43.39060180650027 | 1 | |
34.52451385320009 | 60.39634245837173 | 0 | |
50.2864961189907 | 49.80453881323059 | 0 | |
49.58667721632031 | 59.80895099453265 | 0 | |
97.64563396007767 | 68.86157272420604 | 1 | |
32.57720016809309 | 95.59854761387875 | 0 | |
74.24869136721598 | 69.82457122657193 | 1 | |
71.79646205863379 | 78.45356224515052 | 1 | |
75.3956114656803 | 85.75993667331619 | 1 | |
35.28611281526193 | 47.02051394723416 | 0 | |
56.25381749711624 | 39.26147251058019 | 0 | |
30.05882244669796 | 49.59297386723685 | 0 | |
44.66826172480893 | 66.45008614558913 | 0 | |
66.56089447242954 | 41.09209807936973 | 0 | |
40.45755098375164 | 97.53518548909936 | 1 | |
49.07256321908844 | 51.88321182073966 | 0 | |
80.27957401466998 | 92.11606081344084 | 1 | |
66.74671856944039 | 60.99139402740988 | 1 | |
32.72283304060323 | 43.30717306430063 | 0 | |
64.0393204150601 | 78.03168802018232 | 1 | |
72.34649422579923 | 96.22759296761404 | 1 | |
60.45788573918959 | 73.09499809758037 | 1 | |
58.84095621726802 | 75.85844831279042 | 1 | |
99.82785779692128 | 72.36925193383885 | 1 | |
47.26426910848174 | 88.47586499559782 | 1 | |
50.45815980285988 | 75.80985952982456 | 1 | |
60.45555629271532 | 42.50840943572217 | 0 | |
82.22666157785568 | 42.71987853716458 | 0 | |
88.9138964166533 | 69.80378889835472 | 1 | |
94.83450672430196 | 45.69430680250754 | 1 | |
67.31925746917527 | 66.58935317747915 | 1 | |
57.23870631569862 | 59.51428198012956 | 1 | |
80.36675600171273 | 90.96014789746954 | 1 | |
68.46852178591112 | 85.59430710452014 | 1 | |
42.0754545384731 | 78.84478600148043 | 0 | |
75.47770200533905 | 90.42453899753964 | 1 | |
78.63542434898018 | 96.64742716885644 | 1 | |
52.34800398794107 | 60.76950525602592 | 0 | |
94.09433112516793 | 77.15910509073893 | 1 | |
90.44855097096364 | 87.50879176484702 | 1 | |
55.48216114069585 | 35.57070347228866 | 0 | |
74.49269241843041 | 84.84513684930135 | 1 | |
89.84580670720979 | 45.35828361091658 | 1 | |
83.48916274498238 | 48.38028579728175 | 1 | |
42.2617008099817 | 87.10385094025457 | 1 | |
99.31500880510394 | 68.77540947206617 | 1 | |
55.34001756003703 | 64.9319380069486 | 1 | |
74.77589300092767 | 89.52981289513276 | 1 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment