Created
June 15, 2018 00:13
-
-
Save snoble/da9ba3b6b382bdf794c69ce90a89c22b to your computer and use it in GitHub Desktop.
A thread-safe way to do weighted k-sampling without replacement
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package snoble | |
import scala.collection.immutable.Stream | |
import scala.collection.mutable.Queue | |
import scala.util.Random | |
/** An immutable binary tree whose nodes each carry a value and a sampling weight.
  *
  * The hierarchy is sealed: a tree is either an [[ImmutableWeightedTreeNode]] or an
  * empty [[NilTree]].
  *
  * @tparam T type of the values stored in the nodes
  */
sealed trait ImmutableWeightedTree[T] {

  /** Weight of this node alone; 0 for an empty tree. */
  def weight: Double

  /** Weight of this node plus the total weight of both subtrees; 0 for an empty tree. */
  def totalWeight: Double

  /** Number of nodes on the longest path from this node down to a leaf; 0 for an empty tree. */
  def height: Int
}

/** A non-empty tree node.
  *
  * The derived members are `lazy val`s, so they are computed at most once per
  * instance and only when first requested. `copy` creates a new instance, whose
  * derived members are therefore recomputed from the new fields.
  *
  * @param value  the element stored at this node
  * @param weight the sampling weight of `value` (expected to be non-negative)
  * @param left   left subtree
  * @param right  right subtree
  */
final case class ImmutableWeightedTreeNode[T](
    value: T,
    weight: Double,
    left: ImmutableWeightedTree[T],
    right: ImmutableWeightedTree[T]
) extends ImmutableWeightedTree[T] {

  /** Total weight of the left subtree. */
  lazy val leftWeight: Double = left.totalWeight

  /** Total weight of the right subtree. */
  lazy val rightWeight: Double = right.totalWeight

  lazy val totalWeight: Double = weight + leftWeight + rightWeight

  // Take the taller of the two children: looking only at `left` under-reports the
  // height of a node whose right subtree is deeper than its left one.
  lazy val height: Int = 1 + math.max(left.height, right.height)
}

/** The empty tree: no value, zero weight, zero height. */
final case class NilTree[T]() extends ImmutableWeightedTree[T] {
  def weight: Double = 0.0
  def totalWeight: Double = 0.0
  def height: Int = 0
}
/** Construction and weighted sampling-without-replacement operations on
  * [[ImmutableWeightedTree]].
  *
  * Every operation returns a new tree instead of mutating its argument.
  */
object ImmutableWeightedTree {

  /** Builds a tree by consuming `(value, weight)` pairs from the front of `stream`.
    *
    * Each consumed pair becomes the new root: the tree built so far becomes its left
    * child, and its right child is built recursively from the following pairs, limited
    * to the height of that left child.
    *
    * Consumption stops when the stream is empty or when the tree's height equals
    * `heightLimit`. With no limit, an infinite stream never terminates.
    *
    * @param stream      remaining `(value, weight)` pairs
    * @param tree        tree accumulated so far (starts empty)
    * @param heightLimit if defined, stop as soon as `tree.height` equals this value
    * @return the built tree together with the unconsumed remainder of the stream
    */
  def treeFromStream[T](
      stream: Stream[(T, Double)],
      tree: ImmutableWeightedTree[T] = NilTree[T](),
      heightLimit: Option[Int] = None
  ): (ImmutableWeightedTree[T], Stream[(T, Double)]) =
    (stream, heightLimit) match {
      case (Stream.Empty, _) => (tree, Stream.Empty)
      case (_, Some(limit)) if tree.height == limit => (tree, stream)
      case ((value, valueWeight) #:: rest, _) =>
        // The right sibling may grow only as tall as the tree built so far.
        val (right, remaining) = treeFromStream(rest, NilTree[T](), Some(tree.height))
        treeFromStream(
          remaining,
          ImmutableWeightedTreeNode(value, valueWeight, tree, right),
          heightLimit
        )
    }

  /** Applies the rebuild functions to `tree` in list order (head first).
    *
    * @param updateAncestors functions that re-attach a rebuilt child to its parent,
    *                        nearest ancestor first
    * @param tree            the subtree to start from
    * @return the result of threading `tree` through every function
    */
  def updateTree[T](
      updateAncestors: List[ImmutableWeightedTree[T] => ImmutableWeightedTree[T]],
      tree: ImmutableWeightedTree[T]
  ): ImmutableWeightedTree[T] =
    updateAncestors.foldLeft(tree)((subtree, reattach) => reattach(subtree))

  /** Decision taken by a selector at a node: descend left, descend right, or pick the node. */
  sealed trait Selection
  case object SelectLeft extends Selection
  case object SelectRight extends Selection
  case object SelectValue extends Selection

  /** Walks down `tree` under the direction of `fn` until a node is picked or an empty
    * subtree is reached, then rebuilds the path back to the root.
    *
    * @param tree            subtree currently being inspected
    * @param state           selector state, threaded through successive calls of `fn`
    * @param fn              given the state and the current node, returns the next state
    *                        and where to go
    * @param updateAncestors rebuild functions for the nodes already descended through,
    *                        nearest ancestor first; pass an empty list at the root
    * @return the rebuilt tree and the picked value; when a node is picked its weight is
    *         set to zero in the rebuilt tree, and when an empty subtree is reached the
    *         value is `None`
    */
  @scala.annotation.tailrec
  def selectAndZeroWeight[S, T](
      tree: ImmutableWeightedTree[T],
      state: S,
      fn: (S, ImmutableWeightedTreeNode[T]) => (S, Selection),
      updateAncestors: List[ImmutableWeightedTree[T] => ImmutableWeightedTree[T]]
  ): (ImmutableWeightedTree[T], Option[T]) =
    tree match {
      case _: NilTree[T] => (updateTree(updateAncestors, tree), None)
      case node: ImmutableWeightedTreeNode[T] =>
        val (nextState, selection) = fn(state, node)
        selection match {
          case SelectLeft =>
            val reattach: ImmutableWeightedTree[T] => ImmutableWeightedTree[T] =
              child => node.copy(left = child)
            selectAndZeroWeight(node.left, nextState, fn, reattach :: updateAncestors)
          case SelectRight =>
            val reattach: ImmutableWeightedTree[T] => ImmutableWeightedTree[T] =
              child => node.copy(right = child)
            selectAndZeroWeight(node.right, nextState, fn, reattach :: updateAncestors)
          case SelectValue =>
            (updateTree(updateAncestors, node.copy(weight = 0.0)), Some(node.value))
        }
    }

  /** Selector that picks the node itself, its left subtree or its right subtree with
    * probability proportional to their respective weights.
    *
    * Comparisons are strict so that a part with zero weight (for example a node that
    * has already been sampled) is never chosen while some other part still has weight.
    * When nothing under `node` has weight left, the selector keeps descending right,
    * which ends at an empty subtree.
    *
    * @param rng  random source; returned unchanged as the next state
    * @param node node at which to decide
    * @return `rng` together with the decision
    */
  def randomSelector[T](rng: Random, node: ImmutableWeightedTreeNode[T]): (Random, Selection) = {
    val draw = rng.nextDouble * node.totalWeight
    val selection: Selection =
      if (draw < node.weight) SelectValue
      else if (draw < node.weight + node.leftWeight) SelectLeft
      // From here on the draw landed in the right part, or rounding pushed it to the
      // very top of the range: prefer whichever part actually has weight left.
      else if (node.rightWeight > 0) SelectRight
      else if (node.leftWeight > 0) SelectLeft
      else if (node.weight > 0) SelectValue
      else SelectRight
    (rng, selection)
  }

  /** Draws `n` weighted samples without replacement from `tree`.
    *
    * @param tree tree to sample from; it is not modified
    * @param n    number of draws; no draw is made when `n` is not positive
    * @param rng  random source; pass a seeded instance for reproducible results
    * @return the drawn values, most recent draw first, and the tree in which the weight
    *         of every drawn node has been zeroed; a draw that finds nothing left to
    *         select contributes `None`
    */
  def sampleNValues[T](
      tree: ImmutableWeightedTree[T],
      n: Long,
      rng: Random = new Random()
  ): (List[Option[T]], ImmutableWeightedTree[T]) = {
    val selector: (Random, ImmutableWeightedTreeNode[T]) => (Random, Selection) =
      randomSelector(_, _)

    @scala.annotation.tailrec
    def draw(
        remaining: Long,
        drawn: List[Option[T]],
        current: ImmutableWeightedTree[T]
    ): (List[Option[T]], ImmutableWeightedTree[T]) =
      if (remaining <= 0) (drawn, current)
      else {
        val (pruned, picked) = selectAndZeroWeight(current, rng, selector, Nil)
        draw(remaining - 1, picked :: drawn, pruned)
      }

    draw(n, Nil, tree)
  }
}
/** Demo entry point: builds a 100-element weighted tree, prints it level by level,
  * draws five samples and prints the samples followed by the pruned tree.
  */
object Main extends App {
  // Values "0".."99" with weights 105 down to 6.
  val tree = ImmutableWeightedTree
    .treeFromStream(Stream.range(0, 100).map(x => (x.toString, 105D - x.toDouble)))
    ._1

  /** Prints `tree` breadth-first, one level per line, each node as `value(weight) `.
    *
    * An empty tree prints nothing. A non-empty tree is terminated with a newline so
    * that whatever is printed next starts on its own line.
    */
  def printTree(tree: ImmutableWeightedTree[String]): Unit = tree match {
    case root: ImmutableWeightedTreeNode[String] =>
      // Breadth-first traversal; each queued node remembers its depth.
      val pending = Queue[(ImmutableWeightedTreeNode[String], Int)]((root, 0))
      var lastDepth = 0
      while (pending.nonEmpty) {
        val (node, depth) = pending.dequeue()
        if (depth > lastDepth) {
          // First node of a deeper level: start a new output line.
          print("\n")
          lastDepth = depth
        }
        print(s"${node.value}(${node.weight}) ")
        List(node.left, node.right).foreach {
          case child: ImmutableWeightedTreeNode[String] => pending.enqueue((child, depth + 1))
          case _: NilTree[String] => ()
        }
      }
      // Close the last level's line.
      println()
    case _: NilTree[String] => ()
  }

  printTree(tree)
  val (samples, nextTree) = ImmutableWeightedTree.sampleNValues(tree, 5)
  println(samples)
  printTree(nextTree)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment