Skip to content

Instantly share code, notes, and snippets.

@avibryant
Created October 17, 2017 05:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save avibryant/00bc2ca495766f86c62b2573263cd26e to your computer and use it in GitHub Desktop.
Save avibryant/00bc2ca495766f86c62b2573263cd26e to your computer and use it in GitHub Desktop.
//A type-level implementation of the broadcasting rules from NumPy,
//such that incompatible shapes are a compile-time error.
//see eg http://scipy.github.io/old-wiki/pages/EricsBroadcastingDoc
//for more info on the constraints this is encoding.
//Note: this is only the shape logic. It does not include,
//and is agnostic to, any particular multi-dimensional array implementation.
/** Root of the shape encoding: a shape is either a single [[Dimension]] or a nested [[By]]. */
sealed trait Shape

/** A single axis of a shape: either the unit axis [[One]] or a labelled axis [[N]]. */
sealed trait Dimension extends Shape

/** The unit axis. It carries no data; only its type matters to implicit resolution. */
final case class One() extends Dimension {
  /** Builds a two-level shape with this axis as the outer dimension and `inner` nested inside. */
  def by[X <: Shape](inner: X): By[One, X] = By(this, inner)
}

/**
 * A labelled axis. `A` is a phantom tag: the class has no fields, so two `N` values
 * differ only in their type argument.
 *
 * @tparam A compile-time label identifying this axis
 */
final case class N[A]() extends Dimension {
  /** Builds a two-level shape with this axis as the outer dimension and `inner` nested inside. */
  def by[X <: Shape](inner: X): By[N[A], X] = By(this, inner)
}

/**
 * A shape made of an outer axis followed by the remaining (inner) shape.
 *
 * @param outer the leading dimension
 * @param inner the rest of the shape, itself a [[Dimension]] or another `By`
 */
final case class By[D <: Dimension, X <: Shape](outer: D, inner: X) extends Shape
/** Entry points that delegate to the implicitly resolved [[Broadcaster]] / [[NewAxis]] instances. */
object Shape {

  /**
   * Combines two shapes using whichever [[Broadcaster]] the compiler resolves for `X` and `Y`.
   * The result type `Z` is fixed by that instance; if none is found the call does not compile.
   *
   * @param x left-hand shape
   * @param y right-hand shape
   * @param b evidence describing how `X` and `Y` combine into `Z`
   * @return the shape produced by `b`
   */
  def broadcast[X <: Shape, Y <: Shape, Z <: Shape](x: X, y: Y)(
    implicit b: Broadcaster[X, Y, Z]
  ): Z =
    b.apply(x, y)

  /**
   * Applies the implicitly resolved [[NewAxis]] instance for `X` to `x`.
   *
   * @param x the shape to extend
   * @param n evidence describing how `X` is turned into `Z`
   * @return the shape produced by `n`
   */
  def newAxis[X <: Shape, Z <: Shape](x: X)(
    implicit n: NewAxis[X, Z]
  ): Z =
    n.apply(x)
}
/**
 * Evidence that shapes `X` and `Y` can be combined, producing a shape of type `Z`.
 * Instances are looked up implicitly; the absence of an instance is what turns an
 * incompatible pair of shapes into a compile-time error.
 *
 * @tparam X type of the left-hand shape
 * @tparam Y type of the right-hand shape
 * @tparam Z type of the combined shape
 */
// Fully-qualified annotation so no file-level import is required.
@scala.annotation.implicitNotFound(
  "Shapes ${X} and ${Y} cannot be broadcast together: no Broadcaster instance is available for them."
)
trait Broadcaster[X <: Shape, Y <: Shape, Z <: Shape] {

  /** Combines the two shape values into the resulting shape. */
  def apply(x: X, y: Y): Z
}
/**
 * Base [[Broadcaster]] instances. `object Broadcaster` inherits these rather than
 * defining them directly, so an instance declared in the companion itself
 * (`outerInner`) is preferred whenever it and one of these both apply.
 *
 * Every implicit carries an explicit result type: Scala 3 requires it, and in Scala 2
 * it keeps the instances eligible regardless of where in the file they are used.
 */
trait BroadcasterLowPriority {

  /** `One` with `One` stays `One`. */
  implicit def one2one: Broadcaster[One, One, One] =
    new Broadcaster[One, One, One] {
      def apply(x: One, y: One): One = x
    }

  /** `One` on the left takes on the labelled axis on the right. */
  implicit def one2n[A]: Broadcaster[One, N[A], N[A]] =
    new Broadcaster[One, N[A], N[A]] {
      def apply(x: One, y: N[A]): N[A] = y
    }

  /** `One` on the right takes on the labelled axis on the left. */
  implicit def n2one[A]: Broadcaster[N[A], One, N[A]] =
    new Broadcaster[N[A], One, N[A]] {
      def apply(x: N[A], y: One): N[A] = x
    }

  /** Two labelled axes combine only when they share the same label `A`. */
  implicit def n2n[A]: Broadcaster[N[A], N[A], N[A]] =
    new Broadcaster[N[A], N[A], N[A]] {
      def apply(x: N[A], y: N[A]): N[A] = x
    }

  /**
   * Left operand has an extra outer axis: keep that axis and combine the
   * left operand's inner shape with the whole right operand.
   */
  implicit def leftInner[X <: Dimension, B <: Shape, C <: Shape, D <: Shape](
    implicit innerBroadcaster: Broadcaster[B, C, D]
  ): Broadcaster[By[X, B], C, By[X, D]] =
    new Broadcaster[By[X, B], C, By[X, D]] {
      def apply(x: By[X, B], y: C): By[X, D] =
        By(x.outer, innerBroadcaster(x.inner, y))
    }

  /**
   * Right operand has an extra outer axis: keep that axis and combine the
   * whole left operand with the right operand's inner shape.
   */
  implicit def rightInner[X <: Dimension, B <: Shape, C <: Shape, D <: Shape](
    implicit innerBroadcaster: Broadcaster[B, C, D]
  ): Broadcaster[B, By[X, C], By[X, D]] =
    new Broadcaster[B, By[X, C], By[X, D]] {
      def apply(x: B, y: By[X, C]): By[X, D] =
        By(y.outer, innerBroadcaster(x, y.inner))
    }
}
/**
 * Implicit scope for [[Broadcaster]]. The instance defined here is preferred over the
 * ones inherited from [[BroadcasterLowPriority]] when both would apply.
 */
object Broadcaster extends BroadcasterLowPriority {

  /**
   * Both operands are nested shapes: combine their outer axes with one instance and
   * their inner shapes with another, then reassemble the pair.
   *
   * The explicit result type is required by Scala 3 and keeps the instance eligible
   * in Scala 2 regardless of where in the file it is used.
   */
  implicit def outerInner[
    X <: Dimension, Y <: Dimension, Z <: Dimension,
    B <: Shape, C <: Shape, D <: Shape
  ](
    implicit outerBroadcaster: Broadcaster[X, Y, Z],
    innerBroadcaster: Broadcaster[B, C, D]
  ): Broadcaster[By[X, B], By[Y, C], By[Z, D]] =
    new Broadcaster[By[X, B], By[Y, C], By[Z, D]] {
      def apply(x: By[X, B], y: By[Y, C]): By[Z, D] =
        By(outerBroadcaster(x.outer, y.outer), innerBroadcaster(x.inner, y.inner))
    }
}
/**
 * Evidence that a shape of type `X` can be extended into a shape of type `Z`.
 * Instances are resolved implicitly (see `Shape.newAxis`).
 *
 * @tparam X type of the input shape
 * @tparam Z type of the extended shape
 */
trait NewAxis[X <: Shape, Z <: Shape] {

  /** Produces the extended shape for `x`. */
  def apply(x: X): Z
}
/**
 * Implicit [[NewAxis]] instances. Together they append a `One` axis at the
 * innermost position of a shape.
 */
object NewAxis {

  /** A lone `One` becomes `One By One`. */
  implicit def one: NewAxis[One, By[One, One]] =
    new NewAxis[One, By[One, One]] {
      def apply(x: One): By[One, One] = x.by(One())
    }

  /** A lone labelled axis `N[A]` becomes `N[A] By One`. */
  implicit def n[A]: NewAxis[N[A], By[N[A], One]] =
    new NewAxis[N[A], By[N[A], One]] {
      def apply(x: N[A]): By[N[A], One] = x.by(One())
    }

  /**
   * A nested shape keeps its outer axis and recurses into its inner shape.
   * The result type is spelled out explicitly, matching `one` and `n`.
   */
  implicit def by[D <: Dimension, X <: Shape, Z <: Shape](
    implicit innerNewAxis: NewAxis[X, Z]
  ): NewAxis[By[D, X], By[D, Z]] =
    new NewAxis[By[D, X], By[D, Z]] {
      def apply(x: By[D, X]): By[D, Z] = By(x.outer, innerNewAxis(x.inner))
    }
}
/**
 * Compile-time usage examples. Each `val` carries an explicit type annotation, so
 * the file only compiles if implicit resolution yields exactly that shape type.
 * The commented-out lines are combinations that are expected to be rejected.
 */
object Example {

  /**
   * Exercises `Shape.broadcast` and `Shape.newAxis` on scalars, vectors and matrices.
   *
   * @tparam Foo label for one axis
   * @tparam Bar label for a second, distinct axis
   */
  def example[Foo, Bar]: Unit = {
    val scalar = One()
    // Call `apply` with explicit parentheses: relying on auto-application of an
    // empty-paren method is deprecated in Scala 2.13 and not allowed in Scala 3.
    val vector = N[Foo]()
    val vector2 = N[Bar]()

    // Single dimensions.
    val z1: One = Shape.broadcast(scalar, scalar)
    val z2: N[Foo] = Shape.broadcast(vector, vector)
    val z3: N[Foo] = Shape.broadcast(scalar, vector)
    val z4: N[Foo] = Shape.broadcast(vector, scalar)
    //fails to compile
    //val z5 = Shape.broadcast(vector, vector2)

    // Two-level shapes combined with each other.
    val matrix1: One By N[Foo] = scalar by vector
    val matrix2: N[Bar] By N[Foo] = vector2 by vector
    val matrix3: N[Foo] By N[Foo] = vector by vector
    val z6: One By N[Foo] = Shape.broadcast(matrix1, matrix1)
    val z7: N[Bar] By N[Foo] = Shape.broadcast(matrix2, matrix2)
    val z8: N[Bar] By N[Foo] = Shape.broadcast(matrix1, matrix2)
    //fails to compile
    //val z9 = Shape.broadcast(matrix2, matrix3)

    // A vector combined with a two-level shape.
    val z10: One By N[Foo] = Shape.broadcast(vector, matrix1)
    val z11: N[Bar] By N[Foo] = Shape.broadcast(vector, matrix2)
    val z12: N[Foo] By N[Foo] = Shape.broadcast(vector, matrix3)
    //fails to compile
    //val z13 = Shape.broadcast(vector2, matrix1)

    // A scalar combined with a two-level shape, in either argument position.
    val z14: One By N[Foo] = Shape.broadcast(scalar, matrix1)
    val z15: N[Bar] By N[Foo] = Shape.broadcast(scalar, matrix2)
    val z16: N[Foo] By N[Foo] = Shape.broadcast(scalar, matrix3)
    val z17: One By N[Foo] = Shape.broadcast(matrix1, scalar)
    val z18: N[Bar] By N[Foo] = Shape.broadcast(matrix2, scalar)
    val z19: N[Foo] By N[Foo] = Shape.broadcast(matrix3, scalar)

    // newAxis appends a `One` at the innermost position each time it is applied.
    val z20: N[Foo] By (N[Foo] By One) = Shape.newAxis(matrix3)
    val z21: N[Foo] By (N[Foo] By (One By One)) = Shape.newAxis(z20)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment