The Multi-Armed Bandit
height: 1000
scrolling: true

The problem

This is a simulation of the multi-armed bandit problem. Given a number of options to choose between, the multi-armed bandit problem asks how to choose the best option when you know little about any of them.

You are faced repeatedly with n choices, of which you must choose one. After your choice, you are again faced with n choices of which you must choose one, and so on.

After each choice, you receive a numerical reward drawn from a normal distribution that corresponds to your choice. You don't know that distribution before choosing, but after you have picked the same option a few times you start to get an idea of it.
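As a rough illustration of the reward model described above, the sketch below draws rewards for a single option from a normal distribution and averages them. The names `sampleNormal` and `estimateMean` are hypothetical, a standard deviation of 1 is assumed, and the basic Box-Muller transform is used only because it is short.

```javascript
// Draw one sample from a normal distribution with the given mean and
// standard deviation, using the basic Box-Muller transform.
function sampleNormal(mean, stdev) {
  var u1 = 1 - Math.random(); // in (0, 1], so Math.log(u1) is finite
  var u2 = Math.random();
  var z = Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
  return mean + stdev * z;
}

// Pick the same option `pulls` times and return the average reward.
// Assumption: every option has a reward standard deviation of 1.
function estimateMean(trueMean, pulls) {
  var total = 0;
  for (var i = 0; i < pulls; i++) {
    total += sampleNormal(trueMean, 1);
  }
  return total / pulls;
}

// A handful of pulls gives a noisy estimate; many pulls give a closer one.
console.log(estimateMean(1.5, 5));
console.log(estimateMean(1.5, 5000));
```

The output differs on every run, but the second estimate should typically sit much closer to 1.5 than the first.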

The aim is to maximise your total reward over a given number of selections.

One analogy for this problem is this: you are placed in a room with a number of slot machines, and each slot machine when played will spit out a reward sampled from its probability distribution. Your aim is to maximise your total reward.

The simulation

In keeping with the slot machine analogy, each "choice" in this simulation is labelled as a slot machine, and your knowledge about each slot machine is represented by a circle. The vertical position of the circle is the average reward you have received from the slot machine. The size of the circle is representative of how many times you have chosen that particular slot machine.
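One way to turn a selection count into a circle size is sketched here. It mirrors the formula in the simulation's update step: a fixed minimum radius plus the radius of a disc whose area equals the number of selections. The function name `circleRadius` is hypothetical.

```javascript
// Radius of a slot machine's circle after it has been chosen nChosen times.
// minRadius keeps machines that have never been chosen visible.
// Math.sqrt(nChosen / Math.PI) is the radius of a disc with area nChosen.
function circleRadius(nChosen, minRadius) {
  return minRadius + Math.sqrt(nChosen / Math.PI);
}

console.log(circleRadius(0, 5));   // 5: an unchosen machine gets the minimum
console.log(circleRadius(100, 5)); // larger, but grows only with the square root
```

Because the radius grows with the square root of the count, a machine chosen four times as often gets a circle whose added radius is only twice as large.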

Inherent to this problem is a trade-off between exploration and reward maximisation. You would like to sample from each slot machine a few times so you can find the best one, but at the same time you would also like to maximise your total reward. Spend too much time on exploration and you choose the non-optimal slot machines too many times. Spend too much time on reward maximisation and you might not find the best slot machine.

One way to approach this problem is to choose the slot machine you currently think is best a set fraction of the time (0.9 by default), which is called a greedy selection. The rest of the time you select a slot machine at random; in this simulation the random pick is uniform over all n machines, so it can also land on the current best. The chance of making a random selection is denoted by the Greek letter ε (epsilon), and this approach to the multi-armed bandit problem is called the ε-greedy method.
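The ε-greedy method can be tried without any graphics. The following is a minimal sketch, not the simulation's own code: `runEpsilonGreedy` and `standardNormal` are hypothetical names, rewards are assumed to be normal with standard deviation 1, and ties in the greedy step go to the first machine.

```javascript
// Standard normal sample via the basic Box-Muller transform.
function standardNormal() {
  var u1 = 1 - Math.random(); // avoid Math.log(0)
  return Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * Math.random());
}

// Run epsilon-greedy for `steps` selections on machines whose rewards are
// normal with the given true means and standard deviation 1.
function runEpsilonGreedy(trueMeans, epsilon, steps) {
  var n = trueMeans.length;
  var counts = new Array(n).fill(0);   // times each machine was chosen
  var averages = new Array(n).fill(0); // average reward seen per machine
  var total = 0;
  for (var step = 0; step < steps; step++) {
    var choice;
    if (Math.random() < epsilon) {
      // explore: any machine, uniformly at random
      choice = Math.floor(Math.random() * n);
    } else {
      // exploit: first machine with the highest average so far
      choice = averages.indexOf(Math.max.apply(null, averages));
    }
    var reward = trueMeans[choice] + standardNormal();
    counts[choice] += 1;
    // incremental form of (total reward) / (times chosen)
    averages[choice] += (reward - averages[choice]) / counts[choice];
    total += reward;
  }
  return { averageReward: total / steps, counts: counts };
}

var result = runEpsilonGreedy([0, 5], 0.1, 2000);
console.log(result.averageReward, result.counts);
```

With two machines this far apart, most selections should end up on the second one once a random pick has tried it. Rerunning with a larger ε shows the cost of exploring more than necessary.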

You can set the number of slot machines n with a slider (default 10). You can also set ε, the chance of choosing randomly, with a slider, and you can change it while the simulation is running. There is a slider for t, the number of selections the simulation runs for, which defaults to 2000. A final slider controls the speed of the simulation (the time between selections, in milliseconds), although this cannot be changed while the simulation is running.

You also have the option to display the true mean of the probability distribution that corresponds to each slot machine. This will not affect the simulation and is purely for your own interest.

/** Functions from clicking radio buttons and buttons **/
/**
Called when the Start button is pressed
Pulls values of t and n from sliders
If reset button has not been pressed before, sets up the simulation and runs it
Else simply runs the simulation
**/
function start(){
//pull value from sliders
n = +document.getElementById("n_slot_machines").value;
t = +document.getElementById("t").value;
//Case 1: reset button hasn't been pressed
//Case 2: reset button has been pressed
//set reset flag to false to enable the simulation to run and the reset button to work
if(resetFlag === true){
resetFlag = false;
run_simulation();
} else {
//first time
setup_simulation();
run_simulation();
}
}
/**
Called when the Reset button is pressed
Pulls values of t and n from sliders
Sets score to 0
Resets the slot machines, scales, axes and circle positions
Shows ghosts if the "Yes" radio box is ticked for true means
**/
function reset(){
//pull value from sliders
n = +document.getElementById("n_slot_machines").value;
t = +document.getElementById("t").value;
//reset the score
score = 0;
avg_score = 0;
document.getElementById("score_display").innerHTML = 0;
//setting to true kills the simulation if it's running
resetFlag = true;
//reset slot machines
slot_machines = create_slot_machines(n);
//reset the scales
setup_scales();
//reset axes
//probably don't have to reset x axis but i'll do it anyway
reset_axes();
//bring all circles back to 0, change the amount if n has changed
setup_circle_position();
reset_circles();
//if true means radiobox checked, call it
if(document.getElementById("ghosts_yes").checked === true){
show_ghosts();
}
}
/**
Display the true means of each circle
Triggered by the Yes radio button for true means
**/
function show_ghosts(){
ghostFlag = true;
ghost_circles = svg.selectAll("circle.true_mean")
.data(slot_machines);
//drop old circles
ghost_circles
.exit()
.remove();
//change existing circles
svg.selectAll("circle.true_mean")
.attr("class", "true_mean")
.attr("r", radius)
.attr("cx", (d)=> d.cx)
.attr("cy", (d)=> yScale(d.true_mean))
.style("fill", "#F66A00");
//add new circles
ghost_circles
.enter()
.append("circle")
.attr("class", "true_mean")
.attr("r", radius)
.attr("cx", (d)=> d.cx)
.attr("cy", (d)=> yScale(d.true_mean))
.style("fill", "#F66A00");
}
/**
Hide the true means of each circle
Triggered when No radio button for true means is checked
**/
function hide_ghosts(){
ghostFlag = false;
ghost_circles = svg.selectAll("circle.true_mean")
.remove();
}
/**
Slider Functions:
A collection of functions that update variables when the sliders are moved
**/
//Speed of simulation - time between ticks
function set_time_between_ticks(){
//pull the slider value
interval = +document.getElementById("speed").value;
//update the label with display
document.getElementById('speed_display').innerHTML = interval;
}
// Number of slot machines
function set_n_slot_machines(){
// problems with changing n mid run
// n = +document.getElementById("n_slot_machines").value;
//update HTML display
n_temp = +document.getElementById("n_slot_machines").value;
document.getElementById('n_slot_machine_display').innerHTML = n_temp;
}
// Epsilon
function set_epsilon(){
epsilon = +document.getElementById("epsilon").value;
//update HTML display
document.getElementById('epsilon_display').innerHTML = epsilon;
}
// Number of ticks t
function set_t(){
t_temp = +document.getElementById("t").value;
//update HTML display
document.getElementById('t_display').innerHTML = t_temp;
}
<!DOCTYPE html>
<meta charset="utf-8">
<div class="settings">
<div class="row">
<div class="slider col1" id="n_slot_machine_slider_div">
<label class="label" for="n_slot_machines">Number of slot machines (n): </label>
<em id="n_slot_machine_display" class="label" style="font-style: normal;">10</em>
<br>
<input type="range" name="n_slot_machines" id="n_slot_machines" min="1" max="200" value="10"
oninput="set_n_slot_machines()" />
</div>
<div class="slider col2" id="epsilon_slider_div">
<label class="label" for="epsilon">Chance to choose randomly (ε): </label>
<em id="epsilon_display" class="label" style="font-style: normal;">0.1</em>
<br>
<input type="range" name="epsilon" id="epsilon" min="0" max="1" value="0.1"
step="0.01" oninput="set_epsilon()" />
<br>
</div>
</div>
<div class="row">
<div class="slider col1" id="t_div">
<label class="label" for="t">Number of action selections (t): </label>
<em id = 't_display' class="label" style='font-style: normal;'> 2000 </em>
<br>
<input type="range" name="t" id="t" min="200" max="5000" value="2000" step="100"
oninput="set_t()" />
<br>
</div>
<div class="slider col2" id="speed_slider_div">
<label class="label" for="speed">Time between action selections: </label>
<em id = 'speed_display' class="label" style='font-style: normal;'> 5 </em>
<br>
<input type="range" name="speed" id="speed" min="0" max="1000" value="5" step="5"
oninput="set_time_between_ticks()" />
<br>
</div>
</div>
<div class="row">
<div class ="ghosts_radio_div col1">
<label class="label"> Show true means? </label>
<label class="label" >
<input name="ghosts" type="radio" id="ghosts_yes" value="Yes" onclick="show_ghosts()">
Yes
</label>
<label class="label">
<input name="ghosts" type="radio" id="ghosts_no" value="No" onclick="hide_ghosts()" checked="checked">
No
</label>
</div>
<div class="score_div col2">
<em class="score label" style='font-style: normal; font-weight: bold;'> Average Score: <span id="score_display">0</span> </em>
</div>
</div>
<div class="row">
<div class="button_div col1">
<button id="startButton" class="button" onclick="start()"> Start </button>
</div>
<div class="button_div col2">
<button id="resetButton" class="button" onclick="reset()"> Reset </button>
</div>
</div>
</div>
<div class="simulation">
<svg class="svg_box" width="940" height="700"></svg>
</div>
<link rel="stylesheet" type="text/css" href="styling.css">
<script src="https://d3js.org/d3.v4.min.js"></script>
<script src="svg_functions.js"></script>
<script src="html_functions.js"></script>
<script src="slot_machine_functions.js"></script>
<script src="statistical_functions.js"></script>
<script>
//variables for the svg block and the borders
var svg,
border,
width,
height;
//variables for the axis and scales
var xScale, yScale;
var xAxis, yAxis;
//variables that are set by the sliders
//n-armed bandit - how many slot machines there are
var n = 10;
var n_temp; // a temp variable that allows the slider display to update
//t time steps
var t = 2000;
var t_temp; // a temp variable that allows the slider display to update
//epsilon - the chance to choose randomly
var epsilon = 0.1;
//interval between loops - set by slider with id "speed"
var interval = 5;
//set by radio button for true means
var ghostFlag = false;
//score - the total score at any point from the slot machines
//avg_score - score divided by the number of selections
var score = 0;
var avg_score = 0;
// Slot machine and circle properties
//Minimum circle radius
var radius = 5;
//the slot machines
var slot_machines;
//The svg circles themselves to represent slot_machines
var circles;
// "Ghost" circles that show the true mean
var ghost_circles;
//flag to determine if we've hit the reset button or not
var resetFlag = false;
//Setup the svg canvas and the borders
// initialise the variables svg, border, width, height
setup_svg();
// From this point all action is by buttons and sliders.
</script>
/** Functions related to creation of slot machines and simulating **/
/**
Make an array of objects that represent slot machines. Each object has:
- a name
- true mean
- a function to allow drawing (normal dist with mean true mean and sd 1 )
- number of times chosen
- total reward
- average reward
**/
function create_slot_machines(n_slot_machines) {
slot_machines = Array(n_slot_machines);
//sampler for the standard normal distribution (mean 0, sd 1)
var standard = gaussian(0,1);
//create true means - normally distributed with mean 0 and sd 1
//initialise the number of times each slot machine is chosen as 0 at the start
//the total reward is 0 and the average reward is 0
for(var i = 0; i < n_slot_machines; i++){
slot_machines[i] = {};
slot_machines[i].name = "#" + (i+1);
slot_machines[i].true_mean = standard();
slot_machines[i].pull_lever = gaussian(slot_machines[i].true_mean, 1);
slot_machines[i].n_chosen = 0;
slot_machines[i].total_reward = 0;
slot_machines[i].average_reward = 0;
}
return(slot_machines);
}
/**
Returns an integer denoting which slot machine to pick by the epsilon-greedy method.
Inputs:
- epsilon: the chance of choosing randomly
- n: the number of slot machines
- average_rewards: the average reward of each slot machine
**/
function choose_slot_machine(epsilon, n, average_rewards){
//simulate a random number - if this number is below epsilon choose randomly
//else choose whichever slot machine has the highest average reward
var x = Math.random();
//strict comparison so that epsilon = 0 never chooses randomly
if(x < epsilon){
//choose uniformly at random from all n slot machines
return Math.floor(Math.random() * n);
} else {
//greedy choice: index of the first slot machine with the maximum average reward
var max_reward = Math.max.apply(null, average_rewards);
return average_rewards.indexOf(max_reward);
}
}
/**
Runs the simulation.
Performs one iteration of the simulation every "interval" milliseconds.
Checks to see if the reset button is triggered at any point
If it is, stop the simulation
**/
function run_simulation(){
// used for loop termination
var count = 0;
//run the function every interval milliseconds
var interval_fn = setInterval(function () {
count = count + 1;
//Stop after we have reached t iterations
//Also stop after the reset button is clicked
if (count > t || resetFlag === true) {
count = 0;
clearInterval(interval_fn);
return;
}
//extract the average rewards from the slot machines
var average_rewards = slot_machines.map((d) => d.average_reward)
//choose slot machine
var choice = choose_slot_machine(epsilon, n, average_rewards)
//generate reward
var reward = slot_machines[choice].pull_lever()
//update slot machine parameters
slot_machines[choice].n_chosen += 1;
slot_machines[choice].total_reward += reward;
slot_machines[choice].average_reward = slot_machines[choice].total_reward / slot_machines[choice].n_chosen;
//update circle location
slot_machines[choice].cy = yScale(slot_machines[choice].average_reward)
//update circle position on the canvas
circles.attr("cx", (d) => d.cx)
circles.attr("cy", (d) => d.cy)
//update circle size to an area proportional to number chosen plus some constant (so you can see the small circles )
circles.attr("r", (d) => radius + Math.sqrt(d.n_chosen / Math.PI))
//update score
score += reward;
avg_score = score / count
document.getElementById("score_display").innerHTML = avg_score;
}, interval)
}
/** Store all the statistical functions used **/
/**
Returns a gaussian random function with the given mean and stdev.
Note that this returns a function, not a value.
**/
function gaussian(mean, stdev) {
var y2;
var use_last = false;
return function() {
var y1;
if(use_last) {
y1 = y2;
use_last = false;
}
else {
var x1, x2, w;
do {
x1 = 2.0 * Math.random() - 1.0;
x2 = 2.0 * Math.random() - 1.0;
w = x1 * x1 + x2 * x2;
} while( w >= 1.0);
w = Math.sqrt((-2.0 * Math.log(w))/w);
y1 = x1 * w;
y2 = x2 * w;
use_last = true;
}
var retval = mean + stdev * y1;
return retval;
}
}
.settings {
max-width: 800px;
width: 100%;
border-radius: 10px;
padding: 10px;
}
.simulation {
}
.row {
width: 100%;
letter-spacing: -0.31em;
}
.col1, .col2{
display: inline-block;
vertical-align: top;
text-align: center;
background: white;
padding: 5px;
margin: auto;
letter-spacing: normal;
}
.col1 {
width: 45%;
}
.col2 {
width: 45%;
}
.button {
margin: 10px;
padding: 12px 12px;
cursor: pointer;
text-align: center;
text-decoration: none !important;
text-transform: capitalize;
color: #FFFFFF;
background: #333030;
border: 0 none;
border-radius: 4px;
width: 120px;
}
.label {
font: 14px sans-serif;
}
.axis {
font: 14px sans-serif;
font-weight: bold;
}
.tick{
font: 14px sans-serif;
}
.svg_box {
border: 1px solid black;
}
body {
background-color: white
}
/** Functions to set up the simulation, draw circles,axes, scales etc **/
/**
Set up the SVG box and the borders for the simulation.
**/
function setup_svg(){
svg = d3.select("svg");
//Define a buffer to have around the borders of the svg element
border = {
top: 20,
bottom: 60,
left: 50,
right: 20
}
width = +svg.attr("width") - border.left - border.right;
height = +svg.attr("height")- border.bottom - border.top;
svg = svg.append("g")
.attr("transform", "translate(" + border.left + "," + border.top + ")");
}
/**
Creates slot machines
Creates scales
Draws axes and labels
Draws circles
**/
function setup_simulation(){
//if reset button has been pressed, we don't need to set up the axes and slot machines
// Create the slot machines, set up their properties
slot_machines = create_slot_machines(n);
// Set the x and y scales
setup_scales();
// Use the x and y scales to draw axes and axes labels
//must be called after create_slot_machines otherwise x-axis labels won't draw correctly
setup_axes();
//Set the initial position of the circles
//must be called after setup_scales otherwise the scales won't be defined
setup_circle_position();
//Draws the circles on the page
draw_circles()
}
/**
Initialises the x and y scales
**/
function setup_scales(){
// the domain is hardcoded for now
yScale = d3.scaleLinear()
.domain([4,-4])
.range([0,height]);
//Use a point scale rather than a band scale because of the circles
xScale = d3.scalePoint()
.domain(slot_machines.map((d) => d.name))
.range([0, width])
.padding(0.5);
}
/**
Draws the axes and the axis labels
**/
function setup_axes(){
// set up the y axis
yAxis = d3.axisLeft(yScale);
// add the y axis
svg.append("g")
.attr("class", "y axis")
.call(yAxis);
//add label for y axis
svg.append("text")
.attr("class", "y axis")
.attr("transform", "rotate(-90)")
.attr("x", 0 - height /2)
.attr("y", 0 - border.left + 10)
.attr("dy", "1em")
.style("text-anchor", "middle")
.text("Average Reward");
//set up x axis
//add labels to ticks only if we have 30 or fewer slot machines
if(n > 30){
xAxis = d3.axisBottom(xScale)
.tickFormat("");
} else {
xAxis = d3.axisBottom(xScale);
}
//add x axis
svg.append("g")
.attr("class", "x axis")
.attr("transform", "translate(0," + height + ")")
.call(xAxis);
//add x axis title
svg.append("text")
.attr("class", "x axis")
.attr("y", height + (border.bottom - 20))
.attr("x", width /2)
.style("text-anchor", "middle")
.text("Slot Machine");
}
/**
Sets the initial x/y position of the circles.
Sets the initial cx and the cy properties of the slot machines.
Relies on setup_scales() being called beforehand to define xScale and yScale.
**/
function setup_circle_position(){
for(var i = 0; i < slot_machines.length; i++) {
slot_machines[i].cx = xScale(slot_machines[i].name)
slot_machines[i].cy = yScale(0)
};
}
/**
Updates the axes to reflect a new amount of slot machines
**/
function reset_axes(){
yAxis = d3.axisLeft(yScale);
//add labels to ticks only if we have 30 or fewer slot machines
if(n > 30){
xAxis = d3.axisBottom(xScale)
.tickFormat("");
} else {
xAxis = d3.axisBottom(xScale);
}
svg.selectAll("g.y.axis")
.attr("class", "y axis")
.call(yAxis);
svg.selectAll("g.x.axis")
.attr("class", "x axis")
.call(xAxis);
}
/**
Binds new slot_machines data to the average reward circles
Removes redundant circles
Adds new circles
Sets all circles to the origin
**/
function reset_circles(){
// debugger;
circles = svg.selectAll("circle.average_reward")
.data(slot_machines);
//drop old circles
circles.exit().remove();
//change existing circles
svg.selectAll("circle.average_reward")
.attr("cx", (d) => d.cx)
.attr("cy", (d) => d.cy)
.attr("r", radius);
//add new ones
circles.enter()
.append("circle")
.attr("class","average_reward")
.attr("cx", (d) => d.cx)
.attr("cy", (d) => d.cy)
.attr("r", radius);
circles = svg.selectAll("circle.average_reward");
}
/**
Draws circles for the average reward for slot machines
**/
function draw_circles(){
// Initialise the circles
circles = svg.selectAll("circle.average_reward")
.data(slot_machines)
.enter()
.append("circle")
.attr("class", "average_reward")
.attr("cx", (d) => d.cx)
.attr("cy", (d) => d.cy)
.attr("r", radius);
}