Skip to content

Instantly share code, notes, and snippets.

@hyponymous
Last active February 10, 2017 00:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hyponymous/32b23a355904e7a25a78205d34ded1c8 to your computer and use it in GitHub Desktop.
Save hyponymous/32b23a355904e7a25a78205d34ded1c8 to your computer and use it in GitHub Desktop.
Gridworld I — SARSA(λ)
/* global d3, _ */
/**
 * Tunable settings shared by the renderer and the learning loop.
 */
var CONFIG = {
  // duration handed to d3 transitions when circles and bars are updated
  transitionDuration: 0,
  // delay passed to setTimeout between two simulation steps
  stepDelay: 10,
  // alpha: learning rate applied to every temporal-difference update
  alpha: 0.1,
  // epsilon: probability of picking a uniformly random action instead of a greedy one
  epsilon: 0.01,
  // gamma: discount factor applied to future reward
  gamma: 0.9,
  // lambda: decay factor applied (together with gamma) to eligibility traces
  lambda: 0.9
};
// Available moves as [dx, dy] grid offsets, addressed by action index.
// y grows downwards on screen, so the order is up, right, down, left.
var ACTIONS = [
  [0, -1], // 0: up
  [1, 0],  // 1: right
  [0, 1],  // 2: down
  [-1, 0]  // 3: left
];
/**
 * Adds up all elements of an array.
 *
 * @param {number[]} arr - values to add
 * @returns {number} the total; 0 for an empty array
 */
function sum(arr) {
  let total = 0;
  arr.forEach((value) => {
    total = total + value;
  });
  return total;
}
/**
 * Finds the largest element of an array.
 *
 * @param {number[]} arr - values to compare
 * @returns {number} the largest value; -Infinity for an empty array
 */
function max(arr) {
  let best = Number.NEGATIVE_INFINITY;
  arr.forEach((value) => {
    best = Math.max(best, value);
  });
  return best;
}
/**
 * Epsilon-greedy selection over a list of action weights.
 *
 * With probability `epsilon` any index is eligible; otherwise only the indices
 * holding the largest weight are. One eligible index is then drawn at random,
 * so ties between equally good actions are broken randomly.
 *
 * @param {number[]} weights - one weight per action
 * @param {number} epsilon - probability of exploring
 * @returns {number} the chosen index into `weights`
 */
function eGreedy(weights, epsilon) {
  const everyIndex = _.range(weights.length);

  // explore: any action may be drawn
  if (Math.random() < epsilon) {
    return _.sample(everyIndex);
  }

  // exploit: draw among the actions tied for the best weight
  const best = max(weights);
  const bestIndices = everyIndex.filter((index) => weights[index] === best);
  return _.sample(bestIndices);
}
/**
 * View that owns an <svg> element sized to hold the whole grid.
 *
 * @constructor
 * @param {*} el - whatever d3.select accepts (the caller passes a selector string)
 * @param {number[]} dims - grid size as [columns, rows]
 */
function GridView(el, dims) {
  // edge length of one grid cell in pixels
  this.CELL_SIZE = 32;

  // two extra pixels on each axis beyond the cells themselves
  const svgWidth = dims[0] * this.CELL_SIZE + 2;
  const svgHeight = dims[1] * this.CELL_SIZE + 2;

  const canvas = d3.select(el).append('svg');
  this.svg = canvas
    .attr('width', svgWidth)
    .attr('height', svgHeight);
}
Object.assign(GridView.prototype, {
  /**
   * Draws the current world state into this view's SVG element.
   *
   * Grid cells and walls are drawn as squares, the agent and the treasure as
   * circles, and every entry of `eligibility` / `qValues` as a thin bar along
   * one edge of its cell (action index 0..3 maps to the top, right, bottom
   * and left edge respectively).
   *
   * @param {Object} model - world state; only `grid`, `walls`, `treasure` and
   *   `agent` are read. `grid` and `walls` are arrays of [x, y] cells;
   *   `treasure.position` and `agent.position` are single [x, y] cells.
   * @param {Object} qValues - maps "x,y" keys to an array of per-action values
   * @param {Object} eligibility - maps "x,y" keys to an array of per-action traces
   */
  render: function({grid, walls, treasure, agent}, qValues, eligibility) {
    const CELL_SIZE = this.CELL_SIZE;
    const HALF_INNER_CELL = (CELL_SIZE - 4.0) / 2.0;
    // accessor factories: pixel position of a datum's column / row plus an offset
    const x = offset => d => offset + CELL_SIZE * d[0];
    const y = offset => d => offset + CELL_SIZE * d[1];
    const svg = this.svg;

    // One square per cell. Existing squares are left untouched; only newly
    // entering data gets a rect and surplus rects are removed.
    function renderGridRects(data, klass) {
      const boxes = svg.selectAll(['rect', klass].join('.'))
        .data(data);
      boxes.enter()
        .append('rect')
        .attr('class', klass)
        .attr('width', CELL_SIZE - 2)
        .attr('height', CELL_SIZE - 2)
        .attr('x', x(2))
        .attr('y', y(2));
      boxes.exit()
        .remove();
    }

    // One circle per datum [column, row, radiusScale]; circles that already
    // exist are moved to their new position through a transition.
    function renderCircles(data, klass) {
      const r = d => (d[2] * HALF_INNER_CELL);
      const cx = d => (x(3 + HALF_INNER_CELL)(d));
      const cy = d => (y(3 + HALF_INNER_CELL)(d));
      const circle = svg.selectAll(['circle', klass].join('.'))
        .data(data);
      circle.enter()
        .append('circle')
        .attr('class', klass)
        .attr('r', r)
        .attr('cx', cx)
        .attr('cy', cy);
      circle
        .transition()
        .duration(CONFIG.transitionDuration)
        .attr('r', r)
        .attr('cx', cx)
        .attr('cy', cy);
      circle.exit()
        .remove();
    }

    // One rect per datum {x, y, width, height} given in pixels; rects that
    // already exist are resized and moved through a transition.
    function renderRects(data, klass) {
      const bar = svg.selectAll(['rect', klass].join('.'))
        .data(data);
      bar.enter()
        .append('rect')
        .attr('class', klass)
        .attr('width', d => d.width)
        .attr('height', d => d.height)
        .attr('x', d => d.x)
        .attr('y', d => d.y);
      bar
        .transition()
        .duration(CONFIG.transitionDuration)
        .attr('width', d => d.width)
        .attr('height', d => d.height)
        .attr('x', d => d.x)
        .attr('y', d => d.y);
      bar.exit()
        .remove();
    }

    renderGridRects(grid, 'grid-square');
    renderGridRects(walls, 'wall-square');

    // eligibility and qValues are keyed by "x,y" strings and hold one number
    // per action, so unroll them into a flat array of pixel rects.
    //   mapping    - the keyed table to unroll
    //   filter     - predicate over [key, values] pairs selecting cells to draw
    //   getLongDim - maps a value to the bar's length in pixels
    //   inset      - distance of the bar from the cell edge in pixels
    function toRects({ mapping, filter, getLongDim, inset }) {
      return _.chain(mapping)
        .toPairs()
        .filter(filter)
        .map(pair => {
          // explicit radix so the key is always read as decimal
          const coords = pair[0].split(',').map(part => Number.parseInt(part, 10));
          return pair[1].map((value, i) => {
            // actions 0 and 2 sit on the top / bottom edge and run horizontally
            const isHoriz = (i === 0 || i === 2);
            const long = getLongDim(value);
            // centres the bar along its edge
            const offset = 1.0 + (CELL_SIZE - long) / 2.0;
            return {
              width: isHoriz ? long : 2.0,
              height: isHoriz ? 2.0 : long,
              x: CELL_SIZE * coords[0] + [offset, CELL_SIZE - inset, offset, inset][i],
              y: CELL_SIZE * coords[1] + [inset, offset, CELL_SIZE - inset, offset][i]
            };
          });
        })
        .flatten()
        .value();
    }

    renderRects(
      toRects({
        mapping: eligibility,
        // only draw cells whose traces add up to something visible
        filter: e => {
          return sum(e[1]) > 0.01;
        },
        getLongDim: value => CELL_SIZE * value,
        inset: 4.0
      }),
      'eligibility');
    renderRects(
      toRects({
        mapping: qValues,
        // every cell is drawn
        filter: _.identity,
        // logistic curve (midpoint 1.6, steepness 1.9) keeps the bar length
        // between 0 and CELL_SIZE whatever the magnitude of the value
        getLongDim: value => {
          return Math.max(0.0,
            CELL_SIZE * (1.0 / (1.0 + Math.exp(-1.9 * (value - 1.6)))));
        },
        inset: 2.0
      }),
      'q-value');

    // third component is the radius scale; both markers are drawn full size
    renderCircles([agent.position].map(pos => [pos[0], pos[1], 1.0]), 'agent');
    renderCircles([treasure.position].map(pos => [pos[0], pos[1], 1.0]), 'treasure');
  }
});
/**
 * Builds a predicate that tells whether a position occupies the same cell
 * as `posA`.
 *
 * @param {number[]} posA - reference position as [x, y]
 * @returns {function(number[]): boolean} true when the argument equals `posA`
 *   in both coordinates
 */
function collidesWith(posA) {
  return (posB) => {
    if (posA[0] !== posB[0]) {
      return false;
    }
    return posA[1] === posB[1];
  };
}
/**
 * Runs the SARSA(lambda) agent indefinitely: each tick renders the world,
 * performs one environment step (or restarts the episode once the treasure
 * has been reached) and schedules the next tick with setTimeout.
 *
 * @param {Object} options
 * @param {number[]} options.dims - grid size as [columns, rows]
 * @param {Object} options.view - object whose render(model, qValues, eligibility) draws the world
 * @param {Object} options.model - world state; `agent.position`, `agent.satisfaction`,
 *   `treasure.position` and `walls` are read, and the two agent fields are written
 * @param {number[]} options.startPosition - cell the agent returns to after each episode
 * @param {Object} options.qValues - action values per state, updated in place
 * @param {Object} options.eligibility - eligibility traces per state, updated in place
 */
function spin({ dims, view, model, startPosition, qValues, eligibility }) {
  // States are [x, y] arrays; used as property keys they become "x,y".
  function samplePolicy(state) {
    return eGreedy(qValues[state], CONFIG.epsilon);
  }
  function Q(state, actionIndex) { return qValues[state][actionIndex]; }
  function setQ(state, actionIndex, value) { qValues[state][actionIndex] = value; }
  function Z(state, actionIndex) { return eligibility[state][actionIndex]; }
  function setZ(state, actionIndex, value) { eligibility[state][actionIndex] = value; }
  // Zeroes every trace in place, so holders of the same table see the reset.
  function resetTraces() {
    _.each(eligibility, traces => {
      traces.fill(0.0);
    });
  }

  let actionIndex = samplePolicy(model.agent.position);
  let episodeCount = 0;
  let moveCount = 0;

  (function step() {
    view.render(model, qValues, eligibility);

    // transition
    if (collidesWith(model.agent.position)(model.treasure.position)) {
      // teleport back to start
      model.agent.position = startPosition;
      console.log(episodeCount++, moveCount);
      moveCount = 0;
      // A new episode starts here: credit must not leak back into the
      // previous episode's trajectory, and the first action has to be chosen
      // for the start state rather than reused from the treasure state.
      resetTraces();
      actionIndex = samplePolicy(startPosition);
    } else {
      const currentState = model.agent.position;
      let nextState = _.clone(model.agent.position);

      // apply the action chosen on the previous tick, clamped to the grid
      const action = ACTIONS[actionIndex];
      nextState[0] += action[0];
      nextState[0] = Math.max(nextState[0], 0);
      nextState[0] = Math.min(nextState[0], dims[0] - 1);
      nextState[1] += action[1];
      nextState[1] = Math.max(nextState[1], 0);
      nextState[1] = Math.min(nextState[1], dims[1] - 1);

      // commit move if no collision (TODO: query environment -- might include wind)
      if (_.some(model.walls, collidesWith(nextState))) {
        nextState = currentState;
      }

      // rewards (TODO: query environment)
      const reward = collidesWith(nextState)(model.treasure.position) ?
        100 :
        -1;
      model.agent.satisfaction += reward;

      // epsilon-greedy choice of the action to take from the next state
      const nextActionIndex = samplePolicy(nextState);

      // SARSA(lambda) update: the TD error of this step is applied to every
      // state-action pair in proportion to its eligibility trace.
      // NOTE(review): no update ever starts from the treasure cell, so the
      // bootstrap term on the final step appears to use that cell's initial
      // Q values -- confirm whether a terminal value of 0 was intended.
      const delta = reward + CONFIG.gamma * Q(nextState, nextActionIndex) - Q(currentState, actionIndex);
      // accumulating trace for the pair that was just taken
      setZ(currentState, actionIndex, Z(currentState, actionIndex) + 1.0);
      _.each(qValues, (actionValues, state) => {
        actionValues.forEach((value, i) => {
          setQ(state, i, Q(state, i) + CONFIG.alpha * delta * Z(state, i));
          setZ(state, i, CONFIG.gamma * CONFIG.lambda * Z(state, i));
        });
      });

      actionIndex = nextActionIndex;
      model.agent.position = nextState;
      moveCount++;
    }

    // repeat
    setTimeout(step, CONFIG.stepDelay);
  })();
}
// Bootstrap: build a random 15x15 world, its view and the learning tables,
// then start the simulation loop.
(function() {
  const dims = [15, 15];
  const view = new GridView('body', dims);

  // every cell as [x, y], enumerated row by row
  const allCoords = [];
  for (let row = 0; row < dims[1]; row++) {
    for (let column = 0; column < dims[0]; column++) {
      allCoords.push([column, row]);
    }
  }

  // one draw of distinct cells: two for the treasure and the start, the
  // remainder (a tenth of the grid) become walls
  const sampledPositions = _.sampleSize(allCoords, 2 + 0.1 * dims[0] * dims[1]);
  const treasurePosition = sampledPositions.pop();
  const startPosition = sampledPositions.pop();

  const model = {
    dims,
    grid: allCoords,
    walls: sampledPositions,
    treasure: {
      position: treasurePosition
    },
    agent: {
      // the agent begins each run on the start cell
      position: startPosition,
      // running total of the rewards the agent has received
      satisfaction: 0
    }
  };

  // per-cell tables with one entry per action; an array used as a property
  // key is stored under its "x,y" string form
  const qValues = {};
  allCoords.forEach(coords => {
    // start with uniform value for all actions
    qValues[coords] = [1.0, 1.0, 1.0, 1.0];
  });
  const eligibility = {};
  allCoords.forEach(coords => {
    eligibility[coords] = [0.0, 0.0, 0.0, 0.0];
  });

  spin({ dims, view, model, startPosition, qValues, eligibility });
})();
<!DOCTYPE html>
<!-- Host page for the gridworld demo: styles the SVG shapes by class and loads the scripts. -->
<head>
<meta charset="utf-8">
<style>
/* one square per grid cell: white with a grey outline */
.grid-square {
fill: #fff;
stroke: #aaa;
}
/* wall cells: dark grey squares */
.wall-square {
fill: #666;
stroke: #222;
}
/* the goal marker: green circle */
.treasure {
fill: #0c0;
stroke: #060;
}
/* the agent marker: black circle */
.agent {
fill: #000;
}
/* bars for the eligibility traces: blue */
.eligibility {
fill: #48e;
}
/* bars for the action values: green */
.q-value {
fill: #4c9;
}
</style>
</head>
<body>
<!-- d3 and lodash are loaded first; gridworld.js is presumably the script that uses the d3 and _ globals -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/d3/4.5.0/d3.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/lodash.js/4.17.4/lodash.min.js"></script>
<script src="gridworld.js"></script>
</body>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment