import os
from collections import defaultdict
import numpy as np
import scipy.stats
import torch
ts = torch.tensor
mt = torch.empty
zs = torch.zeros
import torch.distributions as dist
from torch.distributions import constraints
from matplotlib import pyplot
import pyro
from pyro import poutine
from pyro.contrib.autoguide import AutoDelta
from pyro.optim import Adam
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, infer_discrete
from utilities.debugGizmos import *
"""
This shows an example of using the "speed-of-light" trick to map between a vector space and a polytope.
In this case, the polytope exists in RC-dimensional space, conceived of as the possible values of an RxC matrix.
The row and column sums are known, which restricts the matrix to an (R-1)(C-1)-dimensional space. The polytope
boundaries are defined by the restriction that all the matrix entries are non-negative. (This is for a voting
application, so the rows are races and the columns are candidates; row sums and column sums are observed, but
the secret ballot prevents observing individual entries.)
The key logic can be seen in `polytopize`/`polytopizeU` for the mapping from vector space to polytope, and
in `depolytopize`/`depolytopizeU` for the reverse mapping. The actual math here is quite simple, because the
"speed of light" we're using here is linear. This would be substantially (though not unmanageably) more complex if
we used a quadratic "speed of light"; on the other hand, the resulting mappings would then have a continuous
first derivative.
"""
# Absolute tolerance used by approx_eq when comparing tensors element-wise.
MIN_DIFF = 5e-2
def approx_eq(a,b):
    """Return whether tensors `a` and `b` are approximately equal at all points.

    `b` is cast to TTYPE before being subtracted from `a`; the tensors count
    as equal when every absolute difference is strictly below MIN_DIFF.

    Returns:
        A zero-dimensional tensor that is truthy iff all elements are close,
        so it can be used directly in `assert` / `if`.
    """
    # NOTE(review): the comparison was truncated in the source; rebuilt from
    # the surviving fragments (`-b.type(TTYPE)` and `MIN_DIFF`).
    return torch.all(torch.lt(torch.abs(torch.add(a, -b.type(TTYPE))), MIN_DIFF))
def get_indep(R, C, ns, vs): #note, not-voting is a candidate
    """In matrix example, get the "center" of the polytope; the point the origin maps to.

    The center is the independence table: entry (r, c) is ns[r] * vs[c] / tot,
    where tot is the sum of the row sums.

    Args:
        R: number of rows.
        C: number of cols.
        ns: 1-D tensor of row sums (length R, no zero entries).
        vs: 1-D tensor of column sums (length C, no zero entries).

    Returns:
        An R x C tensor, detached from autograd.

    Raises:
        AssertionError: if the lengths do not match R / C, if any sum is zero,
            or if the row and column totals disagree (per approx_eq).
    """
    assert len(ns)==R
    assert len(vs)==C
    # `!= 0` replaces the old `torch.eq(x, 0.) ^ 1` negation, which relied on
    # XOR-ing a comparison result with an integer.
    assert torch.all(ns != 0)
    assert torch.all(vs != 0)
    tot = torch.sum(ns,0)
    vtot = torch.sum(vs)
    assert approx_eq(tot,vtot), f'#print("sums",{tot},{vtot})'
    # Outer product instead of a Python-level nested list of 0-dim tensors.
    # `.detach()` keeps the old behaviour of returning a tensor with no
    # autograd history (torch.tensor(...) always built a fresh leaf).
    indep = (ns.unsqueeze(1) * vs.unsqueeze(0) / tot).detach()
    return indep
def process_data(data):
    """Preprocessing: compute the "center" and total population of each precinct.

    Args:
        data: tuple of (ns, vs) where
            ns : row sums; dim 0 is precinct, dim 1 is row
            vs : column sums, laid out the same way

    Returns:
        (ns, vs, indeps, tots): the inputs unchanged, a list holding one
        get_indep() center per (ns, vs) pair, and a list holding the sum of
        each precinct's row sums.
    """
    ns, vs = data
    n_rows, n_cols = len(ns[0]), len(vs[0])
    # Centers come first: get_indep validates each precinct's sums.
    centers = []
    for precinct_ns, precinct_vs in zip(ns, vs):
        centers.append(get_indep(n_rows, n_cols, precinct_ns, precinct_vs))
    totals = list(map(torch.sum, ns))
    return (ns, vs, centers, totals)
def process_dataU(data):
    """Vectorized version of process_data. Not optimized.

    Returns:
        (ns, vs, indepsU, totsU): ns and vs as given; indepsU is the stacked
        centers with each one flattened into a row; totsU is the stacked
        per-precinct totals.
    """
    ns, vs, centers, totals = process_data(data)
    stacked_centers = torch.stack(centers)
    # One row per precinct; the row length is the element count of a center.
    flat_centers = stacked_centers.view(-1, torch.numel(centers[0]))
    stacked_totals = torch.stack(totals)
    return (ns, vs, flat_centers, stacked_totals)
def to_subspace(raw, R, C, ns, vs):
    """Shift `raw` so that its column sums equal `vs`.

    Each column's shortfall (vs minus the current column sum) is spread over
    the rows in proportion to ns[r] / sum(vs).

    Args:
        raw: 2-D tensor to adjust; it is added to an R-row correction, so it
            is presumably R x C -- TODO confirm against callers.
        R: number of rows.
        C: number of cols (not used by the computation).
        ns: row sums.
        vs: column sums.

    Returns:
        The adjusted tensor.

    Raises:
        AssertionError: if the adjusted column sums still differ from `vs`
            (per approx_eq); diagnostics are printed first.
    """
    vdiffs = vs - torch.sum(raw,0)
    tot = torch.sum(vs)
    correction = torch.stack([vdiffs*ns[r]/tot for r in range(R)],0)
    result = raw + correction
    # NOTE(review): the source printed these "error" diagnostics on every call
    # and never returned `result`; they are now emitted only on failure.
    if not approx_eq(torch.sum(result,0), vs):
        print(f"to_subspace error:{torch.sum(raw,0)}")
        print(f"to_subspace error:{torch.sum(result,0)}")
        print(f"to_subspace error:{vs}")
        print(f"to_subspace error:{(torch.sum(result,1), ns)}")
        print(f"to_subspace error:{vdiffs}")
        print(f"to_subspace error:{tot}")
        print(f"to_subspace error:{vdiffs[0]*vs[0]/tot}")
        print(f"to_subspace error:{correction}")
        raise AssertionError("to_subspace: adjusted column sums do not match vs")
    return result
def polytopize(R, C, raw, start, do_aug=True):
    """Map a point of the (R-1)x(C-1) vector space into the polytope around `start`.

    Args:
        R: number of rows of the full matrix.
        C: number of cols of the full matrix.
        raw: 2-D tensor; (R-1)x(C-1) when `do_aug` is true, otherwise already
            an RxC direction.
        start: RxC tensor, the point the origin maps to (the polytope "center").
        do_aug: when true, append a row and then a column to `raw` so that
            every column and row of the augmented matrix sums to zero.

    Returns:
        An RxC tensor: `start` moved along the augmented direction. The move
        is a fraction 1 - exp(-rho) of the way to the first face where an
        entry reaches zero, where rho is the largest entry of aug / -start.
    """
    if do_aug:
        aug1 = torch.cat((raw, -raw.sum(0).unsqueeze(0)), 0)
        aug2 = torch.cat((aug1, -aug1.sum(1).unsqueeze(1)), 1)
    else:
        aug2 = raw
    if 0 == torch.max(torch.abs(aug2)):
        # The origin maps to the center; continuing would compute 0/0 below.
        return start
    try:
        ratio = torch.div(aug2, -start)
    except RuntimeError:
        # NOTE(review): reconstructed as a failure-only diagnostic (e.g. a
        # shape mismatch); the source had lost the surrounding structure.
        print(f"line 67:{R},{C},{raw.size()},{start.size()}")
        raise
    # The entry with the largest ratio is the first to hit zero along aug2.
    row, col = divmod(int(torch.argmax(ratio)), C)
    r = start[row, col]
    edgedir = -r * aug2 / aug2[row, col]  # vector from start to that face
    edgepoint = start + edgedir
    backoff = torch.exp(-ratio[row, col])
    return edgepoint - backoff * edgedir
def polytopizeU(R, C, raw, start, return_ldaj=False, return_plural=False):
    """Vectorized (batched) version of polytopize.

    Args:
        R: number of rows of each full matrix.
        C: number of cols of each full matrix.
        raw: tensor of shape (U, R-1, C-1), one vector-space point per unit.
        start: tensor of shape (U, R*C), the flattened center of each unit.
        return_ldaj: when true, also return the log-abs-det-Jacobian term
            computed below.
        return_plural: with `return_ldaj`, return the per-unit terms
            (shape (U, 1)) instead of their sum.

    Returns:
        A (U, R, C) tensor, or a `(result, ldaj)` tuple when `return_ldaj`
        is true.

    NOTE(review): unlike polytopize there is no guard for an all-zero row of
    `raw`; such a row divides 0 by 0 below and produces NaN.
    """
    aug1 = torch.cat((raw, -raw.sum(1).unsqueeze(1)), 1)
    aug2 = torch.cat((aug1, -aug1.sum(2).unsqueeze(2)), 2).view(-1, R*C)
    try:
        ratio = torch.div(aug2, -start)
    except RuntimeError:
        # NOTE(review): reconstructed as a failure-only diagnostic (e.g. a
        # shape mismatch); the source had lost the surrounding structure.
        print(f"line 67:{R},{C},{raw.size()},{start.size()}")
        raise
    # Per unit, the entry with the largest ratio hits zero first.
    closest = torch.argmax(ratio, 1).unsqueeze(1)
    r = start.gather(1, closest)
    edgedir = -r * aug2 / aug2.gather(1, closest)
    edgepoint = start + edgedir
    closest_ratio = ratio.gather(1, closest)
    backoff = torch.exp(-closest_ratio)
    result = (edgepoint - backoff * edgedir).view(-1, R, C)
    if return_ldaj: # log det abs jacobian
        lowdim = (R-1)*(C-1)
        ldajs = -closest_ratio*lowdim + torch.log(closest_ratio)*(lowdim + 1)
        if return_plural:
            ldaj = ldajs
        else:
            ldaj = torch.sum(ldajs)
        return (result, ldaj)
    return result
def depolytopizeU(R, C, rawpoly, start, line=None):
    """Vectorized (batched) version of depolytopize.

    Args:
        R: number of rows of each full matrix.
        C: number of cols of each full matrix.
        rawpoly: polytope points; viewed here as (-1, R*C).
        start: flattened centers; must have the same size as the flattened
            `rawpoly`.
        line: optional tag echoed in the failure diagnostic.

    Returns:
        A tensor of shape (-1, R-1, C-1): the vector-space coordinates.
    """
    flat = rawpoly.view(-1, R*C)
    assert flat.size() == start.size(), f"depoly fail {R},{C},{flat.size()},{start.size()}"
    offset = flat - start
    # Entries that sit exactly on the center are nudged by DEPOLY_EPSILON so
    # the division below never evaluates 0/0.
    nudged = offset + (offset == 0).type(TTYPE) * DEPOLY_EPSILON
    ratio = flat / -start
    closest = ratio.argmax(dim=1).unsqueeze(1)
    all_facs = start * torch.log(-ratio) / nudged
    # Every entry of a unit is rescaled by the factor at its closest entry.
    result = torch.gather(all_facs, 1, closest) * nudged
    if torch.isnan(result).any():
        print("depolytopizeU fail",line)
        print(R, C, flat[:3,], start[:3,])
        print("2depolytopize fail")
        for i in range(rawpoly.size(0)):
            if torch.isnan(result[i]).any():
                print("problem index: ",i)
                # NOTE(review): interactive breakpoint kept from the original;
                # it will block or abort non-interactive runs.
                import pdb; pdb.set_trace()
    return result.view(-1, R, C)[:, :(R-1), :(C-1)]
def depolytopize(R, C, poly, start):
    """Map a point of the polytope back to the (R-1)x(C-1) vector space.

    Args:
        R: number of rows of the full matrix.
        C: number of cols of the full matrix.
        poly: RxC tensor, a point inside the polytope.
        start: RxC tensor, the polytope "center"; same size as `poly`.

    Returns:
        An (R-1)x(C-1) tensor: the leading block of the rescaled offset
        from `start`.

    Raises:
        AssertionError: if `poly` and `start` differ in size.
    """
    assert poly.size() == start.size(), f"depoly fail {R},{C},{poly.size()},{start.size()}"
    diff = poly - start
    if 0 == torch.max(torch.abs(diff)):
        # The center maps back to the origin; continuing would compute
        # log(1) / 0 and return NaN.
        return diff[:(R-1), :(C-1)]
    ratio = torch.div(diff, -start)
    # The entry with the largest ratio is the one closest to its zero face.
    row, col = divmod(int(torch.argmax(ratio)), C)
    r = start[row, col]
    fac = r * torch.log(1 - ratio[row, col]) / diff[row, col]
    result = fac * diff
    if torch.any(torch.isnan(result)):
        print("depolytopize fail")
        print(R, C, poly, start)
        print("2depolytopize fail...")
    return result[:(R-1), :(C-1)]
def dummyPrecinct(R, C, i=0, israndom=True):
    """Create a dummy precinct, for testing.

    Args:
        R: number of rows.
        C: number of cols.
        i: offset added to the deterministic sums (ignored when random).
        israndom: draw the sums from Exponential(.01) when true; otherwise
            use the fixed sequences r+i+1 and c+i+2.

    Returns:
        (ns, vs, indep): row sums, column sums rescaled to the same total as
        the row sums, and the get_indep() center for them.
    """
    if israndom:
        ns = dist.Exponential(.01).sample(torch.Size([R]))
        vs = dist.Exponential(.01).sample(torch.Size([C]))
    else:
        # The missing `else:` made these assignments overwrite the random draws.
        ns = ts([r+i+1. for r in range(R)])
        vs = ts([c+i+2. for c in range(C)])
    # Rescale so the column sums add up to the same total as the row sums.
    vs = vs / torch.sum(vs) * torch.sum(ns)
    indep = get_indep(R,C,ns,vs)
    # Callers unpack three values, so the tuple must be returned.
    return (ns, vs, indep)
def test_funs(R, C, innerReps=4, outerReps=4, israndom=True):
    """Check that polytopize and depolytopize are inverses, on some dummy precincts.

    For each of `outerReps` dummy precincts, `innerReps` points are drawn
    from Normal(0, 4), mapped into the polytope and back. Each round trip is
    checked for row sums, column sums, non-negativity and recovery of the
    original point. Failures are reported on stdout rather than raised.
    """
    for i in range(outerReps):
        ns,vs,indep = dummyPrecinct(R,C,i,israndom)
        for j in range(innerReps):
            loc = pyro.distributions.Normal(0.,4.).sample(torch.Size([R-1,C-1]))
            polytopedLoc = polytopize(R,C,loc,indep)
            depolytopedLoc = depolytopize(R,C,polytopedLoc,indep)
            try:
                assert approx_eq(ns, torch.sum(polytopedLoc, dim=1).view(R)) , "ns fail"
                assert approx_eq(vs, torch.sum(polytopedLoc, dim=0).view(C)) , "vs fail"
                assert torch.all(torch.ge(polytopedLoc,0)) , "ge fail"
                assert approx_eq(loc,depolytopedLoc) , "round-trip fail"
                dp(" ((success))")
            except Exception as e:
                # Report which check failed; the exception used to be dropped.
                print(" error",e)
                print(" loc",loc)
                print(" indep",indep.view(R,C))
                print(" polytopedLoc",polytopedLoc)
                print(" depolytopedLoc",depolytopedLoc)
def test_funsU(U, R, C, innerReps=16, outerReps=16, israndom=True):
    """Check that polytopizeU and depolytopizeU are inverses, on some dummy precincts.

    For each of `outerReps` rounds, `U` dummy precincts are stacked into a
    batch; `innerReps` batches of points are drawn from Normal(0, 4), mapped
    into the polytopes and back, and checked for row sums, column sums,
    non-negativity and recovery of the original points. Failures are reported
    on stdout (first unit only) rather than raised.
    """
    for i in range(outerReps):
        ns,vs,indep = zip(*[dummyPrecinct(R,C,i,israndom) for u in range(U)])
        ns,vs,indep = [torch.stack(a) for a in [ns,vs,indep]]
        indep = indep.view(U,R*C)  # polytopizeU expects flattened centers
        for j in range(innerReps):
            loc = pyro.distributions.Normal(0.,4.).sample(torch.Size([U,R-1,C-1]))
            polytopedLoc = polytopizeU(R,C,loc,indep)
            depolytopedLoc = depolytopizeU(R,C,polytopedLoc,indep)
            try:
                assert approx_eq(ns, torch.sum(polytopedLoc, dim=2)) , "ns fail"
                assert approx_eq(vs, torch.sum(polytopedLoc, dim=1)) , "vs fail"
                assert torch.all(torch.ge(polytopedLoc,0)) , ">=0 fail"
                assert approx_eq(loc,depolytopedLoc) , "round-trip fail"
                dp(" ((success))")
            except Exception as e:
                dp(" (fail)")
                # Report which check failed; the exception used to be dropped.
                print(" error",e)
                print(" loc",loc[0])
                print(" indep",indep[0].view(R,C))
                print(" polytopedLoc",polytopedLoc[0])
                print(" depolytopedLoc",depolytopedLoc[0])
