Skip to content

Instantly share code, notes, and snippets.

@snoble
Created June 15, 2018 00:13
Show Gist options
  • Save snoble/da9ba3b6b382bdf794c69ce90a89c22b to your computer and use it in GitHub Desktop.
Save snoble/da9ba3b6b382bdf794c69ce90a89c22b to your computer and use it in GitHub Desktop.
A thread-safe way to do weighted k-sampling without replacement
package snoble
import scala.collection.immutable.Stream
import scala.collection.mutable.Queue
import scala.util.Random
/** An immutable binary tree in which every node carries a value and a sampling weight.
  *
  * `totalWeight` is the summed weight of a whole (sub)tree. `height` is measured along the
  * left children only (see [[ImmutableWeightedTreeNode]]).
  */
sealed trait ImmutableWeightedTree[T] {
  /** Weight of this node's own value; 0 for the empty tree. */
  def weight: Double
  /** Weight of this node plus the total weight of both subtrees. */
  def totalWeight: Double
  /** Number of nodes on the path that always follows the left child. */
  def height: Int
}

/** A non-empty node. Derived quantities are computed lazily and then cached. */
case class ImmutableWeightedTreeNode[T](
  value: T,
  weight: Double,
  left: ImmutableWeightedTree[T],
  right: ImmutableWeightedTree[T]
) extends ImmutableWeightedTree[T] {
  /** Total weight of the left subtree. */
  lazy val leftWeight: Double = left.totalWeight
  /** Total weight of the right subtree. */
  lazy val rightWeight: Double = right.totalWeight
  // Summed as (own + left) + right, i.e. the same association order as a plain `a + b + c`.
  lazy val totalWeight: Double = (weight + leftWeight) + rightWeight
  // Only the left child is consulted. treeFromStream never builds a right subtree taller than its
  // left sibling, so for such trees this equals the usual height.
  lazy val height: Int = 1 + left.height
}

/** The empty tree: no value, zero weight, zero height. */
case class NilTree[T]() extends ImmutableWeightedTree[T] {
  def weight: Double = 0.0
  def totalWeight: Double = 0.0
  def height: Int = 0
}
object ImmutableWeightedTree {

  /** Builds a weighted tree from a stream of `(value, weight)` pairs.
    *
    * Each consumed element becomes a new root. Its left child is the tree built so far (initially
    * `tree`). Its right child is built recursively from the following elements and stops growing
    * once its `height` equals that left child's height, or when the stream runs out.
    *
    * @param stream      `(value, weight)` pairs; weights presumably non-negative (TODO confirm)
    * @param tree        tree built so far; callers normally keep the default empty tree
    * @param heightLimit when set, stop consuming once `tree.height` equals this limit
    * @return the built tree together with the unconsumed remainder of `stream`
    */
  def treeFromStream[T](
    stream: Stream[(T, Double)],
    tree: ImmutableWeightedTree[T] = NilTree[T](),
    heightLimit: Option[Int] = None
  ): (ImmutableWeightedTree[T], Stream[(T, Double)]) = {
    (stream, heightLimit) match {
      case (Stream.Empty, _) => (tree, Stream.Empty)
      case (_, Some(limit)) if tree.height == limit => (tree, stream)
      case ((next, nextWeight) #:: rest, _) =>
        // Cap the right sibling at the current tree's height so left subtrees are never shorter.
        val (right, nextRest) = treeFromStream(rest, NilTree(), Some(tree.height))
        treeFromStream(nextRest, ImmutableWeightedTreeNode(next, nextWeight, tree, right), heightLimit)
    }
  }

  /** Rebuilds a tree bottom-up by applying `updateAncestors` to `tree` in list order.
    *
    * Each function wraps its argument in one ancestor node, innermost ancestor first, so the
    * result of applying the whole list is the new root.
    */
  def updateTree[T](
    updateAncestors: List[ImmutableWeightedTree[T] => ImmutableWeightedTree[T]],
    tree: ImmutableWeightedTree[T]
  ): ImmutableWeightedTree[T] = Function.chain(updateAncestors).apply(tree)

  /** A decision taken at a node while searching for a value to select. */
  sealed trait Selection extends Product with Serializable
  case object SelectLeft extends Selection
  case object SelectRight extends Selection
  case object SelectValue extends Selection

  /** Walks down from `tree`, letting `fn` decide at each node whether to take the node's own
    * value or descend left/right. The taken value's weight is set to 0.
    *
    * @param tree            subtree currently being searched
    * @param state           state threaded through successive `fn` calls (e.g. an RNG)
    * @param fn              returns the updated state and the decision for a node
    * @param updateAncestors functions rebuilding each ancestor around a replacement child,
    *                        innermost first; pass `List()` when starting at the root
    * @return the rebuilt root, and the selected value or `None` if the walk reached an empty tree
    */
  def selectAndZeroWeight[S, T](
    tree: ImmutableWeightedTree[T],
    state: S,
    fn: (S, ImmutableWeightedTreeNode[T]) => (S, Selection),
    updateAncestors: List[ImmutableWeightedTree[T] => ImmutableWeightedTree[T]]
  ): (ImmutableWeightedTree[T], Option[T]) = tree match {
    case _: NilTree[T] => (updateTree(updateAncestors, tree), None)
    case nonemptyTree: ImmutableWeightedTreeNode[T] =>
      val (nextState, selection) = fn(state, nonemptyTree)
      selection match {
        case SelectLeft =>
          selectAndZeroWeight(nonemptyTree.left, nextState, fn,
            ((childTree: ImmutableWeightedTree[T]) => nonemptyTree.copy(left = childTree)) :: updateAncestors)
        case SelectRight =>
          selectAndZeroWeight(nonemptyTree.right, nextState, fn,
            ((childTree: ImmutableWeightedTree[T]) => nonemptyTree.copy(right = childTree)) :: updateAncestors)
        case SelectValue =>
          (updateTree(updateAncestors, nonemptyTree.copy(weight = 0)), Some(nonemptyTree.value))
      }
  }

  /** Chooses this node's value or one of its subtrees with probability proportional to weight.
    *
    * Comparisons are strict, so a zero-weight value or a zero-weight subtree is never chosen.
    * When the whole node has zero weight (everything already sampled), `rn` is 0 and the choice
    * falls through to `SelectRight`. Repeated right moves end at an empty tree, which yields
    * `None` instead of re-selecting an already-sampled value.
    *
    * @return `rng` unchanged (it is mutable) together with the selection
    */
  def randomSelector[T](rng: Random, node: ImmutableWeightedTreeNode[T]): (Random, Selection) = {
    val rn = rng.nextDouble * node.totalWeight
    val selection =
      if (rn < node.weight) SelectValue
      else if (rn < node.weight + node.leftWeight) SelectLeft
      else SelectRight
    (rng, selection)
  }

  /** Draws `n` values from `tree` without replacement, weighted by their current weights.
    *
    * Each draw zeroes the weight of the drawn value. Once no positive weight remains, further
    * draws yield `None`.
    *
    * @param tree source tree; it is not modified
    * @param n    number of draws
    * @param rng  random source; pass a seeded instance for reproducible samples
    * @return the draws, most recent first, and the tree with the drawn weights zeroed
    */
  def sampleNValues[T](
    tree: ImmutableWeightedTree[T],
    n: Long,
    rng: Random = new Random()
  ): (List[Option[T]], ImmutableWeightedTree[T]) = {
    val fn: (Random, ImmutableWeightedTreeNode[T]) => (Random, Selection) = randomSelector(_, _)
    Stream.range(0, n)
      .foldLeft((List[Option[T]](), tree)) { case ((sample, prunedTree), _) =>
        val (nextTree, nextSample) = selectAndZeroWeight(prunedTree, rng, fn, List())
        (nextSample :: sample, nextTree)
      }
  }
}
object Main extends App {
  val tree = ImmutableWeightedTree.treeFromStream(Stream.range(0, 100).map(x => (x.toString, 105D - x.toDouble)))._1

  /** Prints `tree` breadth-first as `value(weight)` entries, one line per depth level.
    *
    * Prints nothing for an empty tree. Otherwise the output ends with a newline, so whatever is
    * printed next starts on its own line.
    */
  def printTree(tree: ImmutableWeightedTree[String]): Unit = {
    tree match {
      case root: ImmutableWeightedTreeNode[String] =>
        val q = Queue[(ImmutableWeightedTreeNode[String], Int)]()
        q.enqueue((root, 0))
        var lastDepth = 0
        while (q.nonEmpty) {
          val (n, depth) = q.dequeue()
          // Breadth-first order visits depths in increasing order, so a larger depth starts a new line.
          if (depth > lastDepth) {
            print("\n")
            lastDepth = depth
          }
          print(s"${n.value.toString}(${n.weight.toString}) ")
          n.left match {
            case l: ImmutableWeightedTreeNode[String] => q.enqueue((l, depth + 1))
            case _ => ()
          }
          n.right match {
            case r: ImmutableWeightedTreeNode[String] => q.enqueue((r, depth + 1))
            case _ => ()
          }
        }
        // Terminate the last level; otherwise the next output would continue on the same line.
        print("\n")
      case _ => ()
    }
  }

  printTree(tree)
  val (samples, nextTree) = ImmutableWeightedTree.sampleNValues(tree, 5)
  println(samples)
  printTree(nextTree)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment