Skip to content

Instantly share code, notes, and snippets.

@bricof
Last active April 21, 2017 13:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save bricof/289866256121e20949abfc3a39d5805d to your computer and use it in GitHub Desktop.
Save bricof/289866256121e20949abfc3a39d5805d to your computer and use it in GitHub Desktop.
Latent Value Learning

This is a cleaned and simplified version of a simulation / animation of latent variable learning used in the Recommendation Systems section of the Stitch Fix Algorithms Tour.

Each circle represents an entity with a latent value along the horizontal axis: a true value for some attribute that we cannot observe directly, but that we can estimate from the feedback on attempted pairings of one A element with one B element.

The algorithm used to find them is as follows:

  • assign each entity a current estimated latent value, initialized at the center of the scale
  • select A-B pairs randomly, weighted by the distance between their current estimated latent values (shorter distances give a higher probability of selection)
  • if the feedback from the pair attempt says their relative latent values differ from what our estimates suggest, move both current estimates in the direction indicated by the feedback (e.g. if A says B is too small, move A to the right and B to the left), by a random step scaled by a learning rate
  • repeat

The underlying simulation, then, runs this algorithm over a set of entities while also simulating the entities themselves. Each entity has its own hidden latent value. The feedback it provides when paired with another entity is determined by which of the two true latent values is larger. Noise could be added to this feedback for realism, but this simplified version computes it without noise.

The SVG update is straightforward. At each timestep, the selected pairs are drawn as lines between their circles, and each circle transitions to the location given by its updated estimated latent value.

/**
 * Create `n` simulated entities for the latent-value learning demo.
 *
 * Each entity gets a hidden `latent_value` drawn from `random()` and an
 * estimated position initialised at the centre (0.5) of the [0, 1] scale.
 *
 * @param {number} n - number of entities to create (0 yields an empty array)
 * @param {function(): number} [random=Math.random] - uniform [0, 1) generator
 * @returns {Array<{id: number, latent_value: number,
 *                  current_position: number, next_position: number}>}
 */
function makeLatentEntities(n, random) {
  var rand = random || Math.random;
  var entities = [];
  for (var i = 0; i < n; i++) {
    entities.push({id: i, latent_value: rand(), current_position: 0.5, next_position: 0.5});
  }
  return entities;
}

/**
 * Clamp an estimated position into the [0, 1] scale.
 * @param {number} value
 * @returns {number}
 */
function clampToUnit(value) {
  return Math.min(1, Math.max(0, value));
}

/**
 * Selection weight for pairing two entities with the given estimated
 * positions. It is 1 when the estimates coincide and falls linearly to 0 at
 * the largest possible distance (1) on the [0, 1] scale. Closer pairs are
 * therefore more likely to be selected.
 *
 * @param {number} posA
 * @param {number} posB
 * @returns {number} weight in [0, 1]
 */
function pairSelectionWeight(posA, posB) {
  return Math.max(0, 1 - Math.abs(posA - posB));
}

/**
 * Pick the index of a partner from `candidates`, with probability
 * proportional to pairSelectionWeight(anchorPosition, candidate position).
 * Zero-weight candidates are never picked, except when every weight is zero;
 * in that case the pick falls back to uniform.
 *
 * @param {number} anchorPosition - estimated position of the entity being paired
 * @param {Array<{current_position: number}>} candidates - must be non-empty
 * @param {function(): number} [random=Math.random] - uniform [0, 1) generator
 * @returns {number} index into `candidates`
 */
function selectWeightedPartner(anchorPosition, candidates, random) {
  var rand = random || Math.random;
  var total = 0;
  // Running totals of non-negative weights are already non-decreasing, so no
  // sort is needed before the inverse-CDF lookup below.
  var cumulative = candidates.map(function (c) {
    total += pairSelectionWeight(anchorPosition, c.current_position);
    return total;
  });
  if (!(total > 0)) {
    return Math.floor(rand() * candidates.length);
  }
  var threshold = rand() * total;
  for (var k = 0; k < cumulative.length; k++) {
    // Strict `<` skips zero-weight entries, whose running total does not grow.
    if (threshold < cumulative[k]) {
      return k;
    }
  }
  // Only reachable through floating-point rounding at the very top end.
  return cumulative.length - 1;
}

/**
 * Simulated feedback for one A-B pairing, based on the true latent values.
 * Returns 1 when A's latent value is larger ("A says B is too small") and
 * -1 otherwise.
 *
 * @param {{latent_value: number}} a
 * @param {{latent_value: number}} b
 * @returns {number} 1 or -1
 */
function pairFeedback(a, b) {
  return a.latent_value > b.latent_value ? 1 : -1;
}

/**
 * Update `next_position` of a paired A and B from one piece of feedback.
 *
 * Estimates move only when they contradict the feedback (or are tied):
 * - feedback 1: A moves right and B moves left;
 * - feedback -1: A moves left and B moves right.
 *
 * Each step is a uniform random fraction of `learningRate`, and both next
 * positions are clamped to [0, 1]. Steps start from `current_position`, so a
 * later update for the same entity in one timestep overwrites an earlier one.
 *
 * @param {{current_position: number, next_position: number}} a
 * @param {{current_position: number, next_position: number}} b
 * @param {number} feedback - 1 or -1, as returned by pairFeedback
 * @param {number} learningRate - maximum step size
 * @param {function(): number} [random=Math.random] - uniform [0, 1) generator
 */
function applyPairFeedback(a, b, feedback, learningRate, random) {
  var rand = random || Math.random;
  var contradicts = feedback === 1
    ? a.current_position <= b.current_position
    : a.current_position >= b.current_position;
  if (contradicts) {
    a.next_position = a.current_position + feedback * rand() * learningRate;
    b.next_position = b.current_position - feedback * rand() * learningRate;
  }
  a.next_position = clampToUnit(a.next_position);
  b.next_position = clampToUnit(b.next_position);
}

/**
 * Run the latent-value learning simulation and animate it in the first <svg>
 * inside <body>. Requires the d3 v4 global `d3`.
 *
 * Every 1200 ms it:
 * - pairs random A elements with B elements, choosing B with probability
 *   that decreases with the distance between estimates;
 * - nudges any estimates that contradict the simulated feedback;
 * - draws the pairs as lines;
 * - transitions every circle to its updated estimate.
 *
 * @param {Object} [options]
 * @param {number} [options.learningRate=0.2] - maximum step per update
 * @param {number} [options.nAs=10] - number of A elements
 * @param {number} [options.nBs=10] - number of B elements
 * @param {number} [options.nPairs=8] - pairings attempted per timestep
 * @returns {Object} the d3 interval timer; call `.stop()` to end the animation
 */
function animated_learning(options) {
  var opts = options || {};
  var learning_rate = opts.learningRate != null ? opts.learningRate : 0.2;
  var n_As = opts.nAs != null ? opts.nAs : 10;
  var n_Bs = opts.nBs != null ? opts.nBs : 10;
  var n_pairs = opts.nPairs != null ? opts.nPairs : 8;

  var svg = d3.select("body").select("svg");
  // Estimates live on [0, 1] and are drawn between x = 200 and x = 500.
  var x = d3.scaleLinear().range([200, 500]).domain([0, 1]);

  // Create elements with hidden latent values and centred estimates.
  var As = makeLatentEntities(n_As);
  var Bs = makeLatentEntities(n_Bs);

  // Construct circles at their initial estimated positions.
  svg.selectAll(".A").data(As, function (d) { return d.id; })
    .enter().append("circle")
    .attr("class", "A A-color")
    .attr("cx", function (d) { return x(d.current_position); })
    .attr("cy", 230)
    .attr("r", 3);
  svg.selectAll(".B").data(Bs, function (d) { return d.id; })
    .enter().append("circle")
    .attr("class", "B B-color")
    .attr("cx", function (d) { return x(d.current_position); })
    .attr("cy", 270)
    .attr("r", 3);

  var delay = 400;  // pause before each transition starts (ms)
  var move = 500;   // transition duration (ms); delay + move fits in the 1200 ms tick

  // Simulation / animation loop.
  return d3.interval(function () {
    // *** SIMULATION ***
    var pairs = [];
    // Pairing needs at least one element on each side.
    var n_attempts = (As.length > 0 && Bs.length > 0) ? n_pairs : 0;
    for (var i = 0; i < n_attempts; i++) {
      var A_id = Math.floor(Math.random() * As.length);
      // Partner choice favours Bs whose estimates are close to this A's estimate.
      var B_id = selectWeightedPartner(As[A_id].current_position, Bs);
      pairs.push({A_id: A_id, B_id: B_id});
      applyPairFeedback(As[A_id], Bs[B_id], pairFeedback(As[A_id], Bs[B_id]), learning_rate);
    }

    // *** SVG ANIMATION ***
    // Redraw the pair lines and animate them along with the circles they join.
    svg.selectAll(".pair").remove();
    svg.selectAll(".pair").data(pairs).enter().append("line")
      .attr("class", "pair")
      .attr("y1", 230)
      .attr("y2", 270)
      .style("stroke", "#000")
      .style("stroke-width", 0.25)
      .style("fill", "none")
      .attr("x1", function (d) { return x(As[d.A_id].current_position); })
      .attr("x2", function (d) { return x(Bs[d.B_id].current_position); })
      .transition().delay(delay).duration(move)
      .attr("x1", function (d) { return x(As[d.A_id].next_position); })
      .attr("x2", function (d) { return x(Bs[d.B_id].next_position); });
    // Animate the circles.
    svg.selectAll(".A")
      .transition().delay(delay).duration(move)
      .attr("cx", function (d) { return x(d.next_position); });
    svg.selectAll(".B")
      .transition().delay(delay).duration(move)
      .attr("cx", function (d) { return x(d.next_position); });
    // *** end of svg animation code ***

    // Prep next timestep: the updated estimates become the current ones.
    As.forEach(function (d) { d.current_position = d.next_position; });
    Bs.forEach(function (d) { d.current_position = d.next_position; });
  }, 1200);
}
<!DOCTYPE html>
<!-- Demo page for the latent-value learning animation. It draws a static axis
     and two labels into an SVG, then loads d3 v4 and animated-learning.js and
     calls animated_learning(), which draws into the first <svg> in <body>. -->
<meta charset="utf-8">
<style>
/* Fill colour for A elements (the A circles and the "A elements" label). */
.A-color {
fill: #4B90A6;
}
/* Fill colour for B elements (the B circles and the "B elements" label). */
.B-color {
fill: #F3A54A;
}
</style>
<body>
<!-- viewBox shows user-space x 120-700, y 50-500 of the drawing. -->
<svg width="960" height="500" viewBox="120 50 580 450">
<!-- Static horizontal axis at y=250, between the A circles (cy=230) and the B circles (cy=270). -->
<line style="stroke:#000; stroke-width: 0.75; fill: none;" y2="250" y1="250" x2="520" x1="180"></line>
<text class="A-color" text-anchor="middle" font-size="12px" y="200" x="350">A elements</text>
<text class="B-color" text-anchor="middle" font-size="12px" y="310" x="350">B elements</text>
</svg>
<!-- d3 must load before animated-learning.js, which uses the global `d3`. -->
<script src="https://d3js.org/d3.v4.min.js"></script>
<script src="animated-learning.js"></script>
<script>
// Start the simulation; the <svg> above already exists at this point.
animated_learning()
</script>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment