def test_saveload():
    decimal = flow.ScalingNshifting(256, -128)
    p = source.MixtureDiscreteLogistic([3, 32, 32], 5, decimal, utils.roundingWidentityGradient)

    # Random half-half masks: even layers draw a fresh random mask,
    # odd layers use the complement of the previous one.
    maskList = []
    for n in range(4):
        if n % 2 == 0:
            b = torch.cat([torch.zeros(3 * 32 * 16), torch.ones(3 * 32 * 16)])[torch.randperm(3 * 32 * 32)].reshape(1, 3, 32, 32)
        else:
            b = 1 - b
        maskList.append(b)
    maskList = torch.cat(maskList, 0).to(torch.float32)

    tList = [utils.SimpleMLPreshape([3 * 32 * 16, 200, 500, 3 * 32 * 16], [nn.ELU(), nn.ELU(), None]) for _ in range(4)]
    f = flow.DiscreteNICE(maskList, tList, decimal, utils.roundingWidentityGradient, p)

    # Build a second, identically shaped flow to load the saved state into.
    p = source.MixtureDiscreteLogistic([3, 32, 32], 5, decimal, utils.roundingWidentityGradient)
    maskList = []
    for n in range(4):
        if n % 2 == 0:
            b = torch.cat([torch.zeros(3 * 32 * 16), torch.ones(3 * 32 * 16)])[torch.randperm(3 * 32 * 32)].reshape(1, 3, 32, 32)
        else:
            b = 1 - b
        maskList.append(b)
    maskList = torch.cat(maskList, 0).to(torch.float32)
    tList = [utils.SimpleMLPreshape([3 * 32 * 16, 200, 500, 3 * 32 * 16], [nn.ELU(), nn.ELU(), None]) for _ in range(4)]
    blankf = flow.DiscreteNICE(maskList, tList, decimal, utils.roundingWidentityGradient, p)

    saveload(f, blankf)
def test_grad():
    decimal = flow.ScalingNshifting(256, -128)
    p = source.MixtureDiscreteLogistic([3, 32, 32], 5, decimal, utils.roundingWidentityGradient)

    maskList = []
    for n in range(4):
        if n % 2 == 0:
            b = torch.cat([torch.zeros(3 * 32 * 16), torch.ones(3 * 32 * 16)])[torch.randperm(3 * 32 * 32)].reshape(1, 3, 32, 32)
        else:
            b = 1 - b
        maskList.append(b)
    maskList = torch.cat(maskList, 0).to(torch.float32)

    tList = [utils.SimpleMLPreshape([3 * 32 * 16, 200, 500, 3 * 32 * 16], [nn.ELU(), nn.ELU(), None]) for _ in range(4)]
    f = flow.DiscreteNICE(maskList, tList, decimal, utils.roundingWidentityGradient, p)

    # Same flow, but with plain torch.round, whose gradient is zero almost everywhere.
    fcopy = deepcopy(f)
    fcopy.rounding = torch.round

    field = p.sample(100).detach()
    cfield = deepcopy(field).requires_grad_()
    field.requires_grad_()

    xfield, _ = f.inverse(field)
    xcfield, _ = fcopy.inverse(cfield)

    L = xfield.sum()
    Lc = xcfield.sum()
    L.backward()
    Lc.backward()

    # The two rounding schemes must produce different parameter gradients.
    ou = [term for term in f.parameters()]
    ouc = [term for term in fcopy.parameters()]
    assert not np.all(ou[-1].grad.detach().numpy() == ouc[-1].grad.detach().numpy())
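# Context for the test above: utils.roundingWidentityGradient presumably rounds in
# the forward pass but passes gradients through unchanged (a straight-through
# estimator), whereas plain torch.round has zero gradient almost everywhere. A
# generic, self-contained sketch of that technique (not the repo's implementation):
import torch

class _RoundStraightThrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Round to the nearest integer in the forward pass.
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Treat rounding as the identity when backpropagating.
        return grad_output

# _RoundStraightThrough.apply(x) rounds x while keeping x's gradient path alive,
# which is why the straight-through flow above receives nonzero gradients.
roundStraightThrough = _RoundStraightThrough.apply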
def __init__(self, length, layerList, meanNNlist=None, scaleNNlist=None, repeat=1, depth=None, nMixing=5, decimal=None, rounding=None, name="OneToTwoMERA"): kernelSize = 2 if depth is None or depth == -1: depth = int(math.log(length, kernelSize)) if meanNNlist is None or scaleNNlist is None: prior = source.SimpleHierarchyPrior(length, nMixing, decimal, rounding) else: lastPrior = source.MixtureDiscreteLogistic([3, 1, 4], nMixing, decimal, rounding) prior = source.PassiveHierarchyPrior(length, lastPrior, decimal=decimal, rounding=rounding) super(OneToTwoMERA, self).__init__(prior, name) self.decimal = decimal self.rounding = rounding self.repeat = repeat self.depth = depth layerList = layerList * depth self.layerList = torch.nn.ModuleList(layerList) if meanNNlist is not None and scaleNNlist is not None: meanNNlist = meanNNlist * depth scaleNNlist = scaleNNlist * depth self.meanNNlist = torch.nn.ModuleList(meanNNlist) self.scaleNNlist = torch.nn.ModuleList(scaleNNlist) else: self.meanNNlist = None self.scaleNNlist = None
def test_integer():
    decimal = flow.ScalingNshifting(256, -128)
    p = source.MixtureDiscreteLogistic([3, 32, 32], 5, decimal, utils.roundingWidentityGradient)

    maskList = []
    for n in range(4):
        if n % 2 == 0:
            b = torch.cat([torch.zeros(3 * 32 * 16), torch.ones(3 * 32 * 16)])[torch.randperm(3 * 32 * 32)].reshape(1, 3, 32, 32)
        else:
            b = 1 - b
        maskList.append(b)
    maskList = torch.cat(maskList, 0).to(torch.float32)

    tList = [utils.SimpleMLPreshape([3 * 32 * 16, 200, 500, 3 * 32 * 16], [nn.ELU(), nn.ELU(), None]) for _ in range(4)]
    f = flow.DiscreteNICE(maskList, tList, decimal, utils.roundingWidentityGradient, p)

    # Samples and both flow directions must stay on the integer lattice.
    x, _ = f.sample(100)
    assert np.all(np.equal(np.mod(x.detach().numpy(), 1), 0))
    zx, _ = f.inverse(x)
    assert np.all(np.equal(np.mod(zx.detach().numpy(), 1), 0))
    xzx, _ = f.forward(zx)
    assert np.all(np.equal(np.mod(xzx.detach().numpy(), 1), 0))
def test_grad():
    length = 8
    channel = 3
    decimal = flow.ScalingNshifting(256, -128)

    p1 = source.DiscreteLogistic([channel, 16, 3], decimal, rounding=utils.roundingWidentityGradient)
    p2 = source.DiscreteLogistic([channel, 4, 3], decimal, rounding=utils.roundingWidentityGradient)
    p3 = source.MixtureDiscreteLogistic([channel, 1, 4], 5, decimal, rounding=utils.roundingWidentityGradient)

    P = source.HierarchyPrior(channel, length, [p1, p2, p3], repeat=2)

    x = P.sample(100)
    logp = P.logProbability(x)
    L = logp.mean()
    L.backward()

    # Every sub-prior's parameters must receive a nonzero gradient.
    assert p1.mean.grad.sum().detach().item() != 0
    assert p2.mean.grad.sum().detach().item() != 0
    assert p3.mean.grad.sum().detach().item() != 0
    assert p1.logscale.grad.sum().detach().item() != 0
    assert p2.logscale.grad.sum().detach().item() != 0
    assert p3.logscale.grad.sum().detach().item() != 0
    assert p3.mixing.grad.sum().detach().item() != 0
def test_hierarchyPrior():
    import math

    # Deterministic dummy prior: samples are constant fields of `element`,
    # and each site contributes 2**element to the energy.
    class UniTestPrior(source.Source):
        def __init__(self, nvars, element, name="UniTestPrior"):
            super(UniTestPrior, self).__init__(nvars, 1.0, name)
            self.element = torch.nn.Parameter(torch.tensor(element), requires_grad=False)

        def sample(self, batchSize):
            return torch.ones([batchSize] + self.nvars).to(self.element).float() * self.element

        def _energy(self, z):
            return (torch.tensor([2])**self.element * np.prod(z.shape[2:]))

    length = 32
    channel = 3
    decimal = flow.ScalingNshifting(256, -128)

    p1 = source.DiscreteLogistic([channel, 256, 3], decimal, rounding=utils.roundingWidentityGradient)
    p2 = source.DiscreteLogistic([channel, 64, 3], decimal, rounding=utils.roundingWidentityGradient)
    p3 = source.DiscreteLogistic([channel, 16, 3], decimal, rounding=utils.roundingWidentityGradient)
    p4 = source.DiscreteLogistic([channel, 4, 3], decimal, rounding=utils.roundingWidentityGradient)
    p5 = source.MixtureDiscreteLogistic([channel, 1, 4], 5, decimal, rounding=utils.roundingWidentityGradient)

    P = source.HierarchyPrior(channel, length, [p1, p2, p3, p4, p5], repeat=1)

    x = P.sample(100)
    logp = P.logProbability(x)

    # Dispatch the sample into its factored-out parts, then collect them back;
    # the round trip must reproduce the original sample.
    zparts = []
    for no in range(int(math.log(length, 2))):
        _, parts = utils.dispatch(P.factorOutIList[no], P.factorOutJList[no], x)
        zparts.append(parts)

    rcnX = torch.zeros_like(x)
    for no in range(int(math.log(length, 2))):
        part = zparts[no]
        rcnX = utils.collect(P.factorOutIList[no], P.factorOutJList[no], rcnX, part)

    assert_allclose(x.detach(), rcnX.detach())

    length = 8
    p1 = UniTestPrior([channel, 16, 3], 1)
    p2 = UniTestPrior([channel, 4, 3], 2)
    p3 = UniTestPrior([channel, 1, 4], 3)

    Pp = source.HierarchyPrior(channel, length, [p1, p2, p3], repeat=2)

    x = Pp.sample(1)
    logp = Pp.logProbability(x)

    target = np.array([
        [3, 1, 2, 1, 3, 1, 2, 1],
        [1, 1, 1, 1, 1, 1, 1, 1],
        [2, 1, 2, 1, 2, 1, 2, 1],
        [1, 1, 1, 1, 1, 1, 1, 1],
        [3, 1, 2, 1, 3, 1, 2, 1],
        [1, 1, 1, 1, 1, 1, 1, 1],
        [2, 1, 2, 1, 2, 1, 2, 1],
        [1, 1, 1, 1, 1, 1, 1, 1]])
    assert_allclose(x[0, 0].detach().numpy(), target)
    # Each sub-prior contributes -(2**element * #sites) to logp:
    # 16*3 sites at 2**1, 4*3 sites at 2**2, and 1*4 sites at 2**3.
    assert logp == -(16 * 3 * 2**1 + 4 * 3 * 2**2 + 1 * 4 * 2**3)

    p1 = UniTestPrior([channel, 16, 3], 1)
    p2 = UniTestPrior([channel, 4, 3], 2)
    p3 = UniTestPrior([channel, 1, 4], 3)

    Ppodd = source.HierarchyPrior(channel, length, [p1, p2, p3], repeat=1)

    x = Ppodd.sample(1)
    logp = Ppodd.logProbability(x)

    target = np.array([
        [3, 1, 2, 1, 3, 1, 2, 1],
        [1, 1, 1, 1, 1, 1, 1, 1],
        [2, 1, 2, 1, 2, 1, 2, 1],
        [1, 1, 1, 1, 1, 1, 1, 1],
        [3, 1, 2, 1, 3, 1, 2, 1],
        [1, 1, 1, 1, 1, 1, 1, 1],
        [2, 1, 2, 1, 2, 1, 2, 1],
        [1, 1, 1, 1, 1, 1, 1, 1]])
    assert_allclose(x[0, 0].detach().numpy(), target)
    assert logp == -(16 * 3 * 2**1 + 4 * 3 * 2**2 + 1 * 4 * 2**3)
def __init__(self, length, layerList, meanNNlist=None, scaleNNlist=None, repeat=1, depth=None, nMixing=5, decimal=None, rounding=None, clamp=None, sameDetail=True, compatible=False, name="SimpleMERA"):
    kernelSize = 2
    if depth is None or depth == -1:
        depth = int(math.log(length, kernelSize))

    if meanNNlist is None or scaleNNlist is None:
        prior = source.SimpleHierarchyPrior(length, nMixing, decimal, rounding, clamp=clamp, sameDetail=sameDetail, compatible=compatible)
    else:
        lastPrior = source.MixtureDiscreteLogistic([3, 1, 4], nMixing, decimal, rounding, clamp=clamp)
        prior = source.PassiveHierarchyPrior(length, lastPrior, decimal=decimal, rounding=rounding, compatible=compatible)

    super(SimpleMERA, self).__init__(prior, name)

    self.decimal = decimal
    self.rounding = rounding
    self.repeat = repeat
    self.depth = depth
    self.compatible = compatible

    if compatible:
        # In compatible mode the caller must pass fully expanded lists.
        assert len(layerList) == 4 * repeat * depth
        assert len(meanNNlist) == depth
        assert len(scaleNNlist) == depth

    # Otherwise, tile one level's modules across all depth levels.
    if len(layerList) != 4 * repeat * depth:
        layerList = layerList * depth
    assert len(layerList) == 4 * repeat * depth
    self.layerList = torch.nn.ModuleList(layerList)

    if meanNNlist is not None and scaleNNlist is not None:
        if len(meanNNlist) != depth:
            meanNNlist = meanNNlist * depth
            scaleNNlist = scaleNNlist * depth
        assert len(meanNNlist) == depth
        assert len(scaleNNlist) == depth
        self.meanNNlist = torch.nn.ModuleList(meanNNlist)
        self.scaleNNlist = torch.nn.ModuleList(scaleNNlist)
    else:
        self.meanNNlist = None
        self.scaleNNlist = None
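# The tiling logic above implies a caller normally supplies the 4 * repeat coupling
# layers of a single level and lets the constructor repeat them across all levels.
# An illustrative check of that invariant (nn.Identity modules are placeholders,
# not the repo's coupling nets):
import torch

repeat, depth = 2, 3
perLevel = [torch.nn.Identity() for _ in range(4 * repeat)]  # one level's layers
layerList = perLevel * depth                                 # tiled across levels
assert len(layerList) == 4 * repeat * depth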