def test_saveload():
    """Build two DiscreteNICE flows of identical architecture and hand them to ``saveload``.

    Both flows share one ``decimal`` transform but get their own prior, their own
    random masks and their own freshly initialised coupling MLPs. ``saveload`` is
    defined elsewhere in the test module (presumably it saves ``f`` and loads it
    into ``blankf`` — confirm against its definition).
    """
    decimal = flow.ScalingNshifting(256, -128)
    nvars = 3 * 32 * 32
    half = nvars // 2

    def makeFlow():
        # Construction order (prior, masks, MLPs, flow) matches the original
        # so random draws happen in the same sequence.
        prior = source.MixtureDiscreteLogistic([3, 32, 32], 5, decimal, utils.roundingWidentityGradient)

        # Two random half/half binary masks, each followed by its complement.
        masks = []
        for _ in range(2):
            mask = torch.cat([torch.zeros(half), torch.ones(half)])[torch.randperm(nvars)].reshape(1, 3, 32, 32)
            masks.extend([mask, 1 - mask])
        stacked = torch.cat(masks, 0).to(torch.float32)

        couplings = [utils.SimpleMLPreshape([half, 200, 500, half], [nn.ELU(), nn.ELU(), None]) for _ in range(4)]
        return flow.DiscreteNICE(stacked, couplings, decimal, utils.roundingWidentityGradient, prior)

    f = makeFlow()
    blankf = makeFlow()

    saveload(f, blankf)
def test_grad():
    """Check that swapping the rounding function changes DiscreteNICE gradients.

    A DiscreteNICE using ``utils.roundingWidentityGradient`` is deep-copied and
    the copy's ``rounding`` attribute is replaced by ``torch.round``. The same
    prior samples go through ``inverse`` of both flows, and the summed outputs
    are back-propagated. The last parameter's gradients of the two flows must
    not be equal everywhere.
    """
    decimal = flow.ScalingNshifting(256, -128)
    p = source.MixtureDiscreteLogistic([3, 32, 32], 5, decimal, utils.roundingWidentityGradient)

    # Two random half/half binary masks, each followed by its complement.
    nvars = 3 * 32 * 32
    half = nvars // 2
    masks = []
    for _ in range(2):
        mask = torch.cat([torch.zeros(half), torch.ones(half)])[torch.randperm(nvars)].reshape(1, 3, 32, 32)
        masks.extend([mask, 1 - mask])
    maskList = torch.cat(masks, 0).to(torch.float32)
    tList = [utils.SimpleMLPreshape([half, 200, 500, half], [nn.ELU(), nn.ELU(), None]) for _ in range(4)]
    f = flow.DiscreteNICE(maskList, tList, decimal, utils.roundingWidentityGradient, p)

    fcopy = deepcopy(f)
    fcopy.rounding = torch.round

    field = p.sample(100).detach()
    cfield = deepcopy(field).requires_grad_()
    field.requires_grad_()

    xfield, _ = f.inverse(field)
    xcfield, _ = fcopy.inverse(cfield)
    # Both losses are built before either backward pass, as in a two-step version.
    for loss in (xfield.sum(), xcfield.sum()):
        loss.backward()

    grad = list(f.parameters())[-1].grad.detach().numpy()
    gradRound = list(fcopy.parameters())[-1].grad.detach().numpy()
    assert np.any(grad != gradRound)
Example #3
0
    def __init__(self,
                 length,
                 layerList,
                 meanNNlist=None,
                 scaleNNlist=None,
                 repeat=1,
                 depth=None,
                 nMixing=5,
                 decimal=None,
                 rounding=None,
                 name="OneToTwoMERA"):
        """Assemble the layer stack and hierarchical prior of a OneToTwoMERA flow.

        Args:
            length: Size forwarded to the prior; when ``depth`` is ``None`` or
                ``-1`` the depth defaults to ``int(math.log(length, 2))``.
            layerList: Layers repeated ``depth`` times into ``self.layerList``.
                List repetition reuses the same module objects.
            meanNNlist, scaleNNlist: Optional networks. Only when both are given
                is a ``PassiveHierarchyPrior`` topped by a
                ``MixtureDiscreteLogistic([3, 1, 4], ...)`` used, and each list
                is repeated ``depth`` times. Otherwise a ``SimpleHierarchyPrior``
                is used and both attributes are set to ``None``.
            repeat: Stored on ``self.repeat`` (not used in this constructor).
            nMixing: Mixture component count forwarded to the prior.
            decimal, rounding: Forwarded to the prior and stored on ``self``.
            name: Forwarded to the base-class constructor.
        """
        if depth is None or depth == -1:
            depth = int(math.log(length, 2))

        hasMeanScaleNets = meanNNlist is not None and scaleNNlist is not None
        if hasMeanScaleNets:
            topPrior = source.MixtureDiscreteLogistic([3, 1, 4], nMixing, decimal, rounding)
            prior = source.PassiveHierarchyPrior(length, topPrior, decimal=decimal, rounding=rounding)
        else:
            prior = source.SimpleHierarchyPrior(length, nMixing, decimal, rounding)
        # The base class must be initialised before any submodule is assigned.
        super(OneToTwoMERA, self).__init__(prior, name)

        self.decimal = decimal
        self.rounding = rounding
        self.repeat = repeat
        self.depth = depth

        self.layerList = torch.nn.ModuleList(layerList * depth)

        if not hasMeanScaleNets:
            self.meanNNlist = None
            self.scaleNNlist = None
        else:
            self.meanNNlist = torch.nn.ModuleList(meanNNlist * depth)
            self.scaleNNlist = torch.nn.ModuleList(scaleNNlist * depth)
def test_integer():
    """Check that DiscreteNICE sample, inverse and forward outputs are whole numbers.

    Each output is checked by requiring ``value mod 1 == 0`` element-wise.
    """
    decimal = flow.ScalingNshifting(256, -128)
    p = source.MixtureDiscreteLogistic([3, 32, 32], 5, decimal, utils.roundingWidentityGradient)

    # Two random half/half binary masks, each followed by its complement.
    nvars = 3 * 32 * 32
    half = nvars // 2
    masks = []
    for _ in range(2):
        mask = torch.cat([torch.zeros(half), torch.ones(half)])[torch.randperm(nvars)].reshape(1, 3, 32, 32)
        masks.extend([mask, 1 - mask])
    maskList = torch.cat(masks, 0).to(torch.float32)
    tList = [utils.SimpleMLPreshape([half, 200, 500, half], [nn.ELU(), nn.ELU(), None]) for _ in range(4)]
    f = flow.DiscreteNICE(maskList, tList, decimal, utils.roundingWidentityGradient, p)

    def assertIntegral(tensor):
        # Every element must have zero fractional part.
        assert (np.mod(tensor.detach().numpy(), 1) == 0).all()

    x, _ = f.sample(100)
    assertIntegral(x)

    zx, _ = f.inverse(x)
    assertIntegral(zx)

    xzx, _ = f.forward(zx)
    assertIntegral(xzx)
def test_grad():
    """Check that back-propagating a HierarchyPrior log-likelihood reaches every sub-prior.

    After ``logProbability(x).mean().backward()``, the summed gradients of
    ``mean`` and ``logscale`` on all three sub-priors, and of ``mixing`` on the
    mixture prior, must be nonzero.
    """
    length, channel = 8, 3
    decimal = flow.ScalingNshifting(256, -128)
    rounding = utils.roundingWidentityGradient
    p1 = source.DiscreteLogistic([channel, 16, 3], decimal, rounding=rounding)
    p2 = source.DiscreteLogistic([channel, 4, 3], decimal, rounding=rounding)
    p3 = source.MixtureDiscreteLogistic([channel, 1, 4], 5, decimal, rounding=rounding)

    P = source.HierarchyPrior(channel, length, [p1, p2, p3], repeat=2)

    x = P.sample(100)
    P.logProbability(x).mean().backward()

    for attr in ("mean", "logscale"):
        for prior in (p1, p2, p3):
            assert getattr(prior, attr).grad.sum().detach().item() != 0

    assert p3.mixing.grad.sum().detach().item() != 0
def test_hierarchyPrior():
    """Exercise HierarchyPrior's dispatch/collect round trip and its sample layout.

    Part 1 dispatches a sample into its factored-out parts level by level. It then
    collects them back into a zero tensor and requires the result to equal the sample.
    Part 2 uses deterministic stand-in priors with ``repeat=2`` and ``repeat=1``.
    It checks the spatial layout of the sample and the exact log-probability.
    """
    import math

    class UniTestPrior(source.Source):
        # Deterministic stand-in prior: samples are filled with ``element`` and
        # the energy is 2**element times the product of z.shape[2:].
        def __init__(self, nvars, element, name="UniTestPrior"):
            super(UniTestPrior, self).__init__(nvars, 1.0, name)
            self.element = torch.nn.Parameter(torch.tensor(element), requires_grad=False)

        def sample(self, batchSize):
            return torch.ones([batchSize] + self.nvars).to(self.element).float() * self.element

        def _energy(self, z):
            return (torch.tensor([2])**self.element * np.prod(z.shape[2:]))

    channel = 3
    decimal = flow.ScalingNshifting(256, -128)
    rounding = utils.roundingWidentityGradient

    # --- Part 1: dispatch / collect round trip -------------------------------
    length = 32
    P = source.HierarchyPrior(channel, length, [
        source.DiscreteLogistic([channel, 256, 3], decimal, rounding=rounding),
        source.DiscreteLogistic([channel, 64, 3], decimal, rounding=rounding),
        source.DiscreteLogistic([channel, 16, 3], decimal, rounding=rounding),
        source.DiscreteLogistic([channel, 4, 3], decimal, rounding=rounding),
        source.MixtureDiscreteLogistic([channel, 1, 4], 5, decimal, rounding=rounding),
    ], repeat=1)

    x = P.sample(100)
    P.logProbability(x)  # smoke check only; the value is not inspected

    levels = int(math.log(length, 2))
    zparts = []
    for no in range(levels):
        _, parts = utils.dispatch(P.factorOutIList[no], P.factorOutJList[no], x)
        zparts.append(parts)

    rcnX = torch.zeros_like(x)
    for no, part in enumerate(zparts):
        rcnX = utils.collect(P.factorOutIList[no], P.factorOutJList[no], rcnX, part)

    assert_allclose(x.detach(), rcnX.detach())

    # --- Part 2: deterministic layout and log-probability --------------------
    length = 8
    rowA = [3, 1, 2, 1, 3, 1, 2, 1]
    rowB = [2, 1, 2, 1, 2, 1, 2, 1]
    ones = [1, 1, 1, 1, 1, 1, 1, 1]
    target = np.array([rowA, ones, rowB, ones] * 2)
    expectedLogp = -(16 * 3 * 2**1 + 4 * 3 * 2**2 + 1 * 4 * 2**3)

    for repeat in (2, 1):
        stub = source.HierarchyPrior(channel, length, [
            UniTestPrior([channel, 16, 3], 1),
            UniTestPrior([channel, 4, 3], 2),
            UniTestPrior([channel, 1, 4], 3),
        ], repeat=repeat)

        x = stub.sample(1)
        logp = stub.logProbability(x)

        assert_allclose(x[0, 0].detach().numpy(), target)
        assert logp == expectedLogp
Example #7
0
    def __init__(self,
                 length,
                 layerList,
                 meanNNlist=None,
                 scaleNNlist=None,
                 repeat=1,
                 depth=None,
                 nMixing=5,
                 decimal=None,
                 rounding=None,
                 clamp=None,
                 sameDetail=True,
                 compatible=False,
                 name="SimpleMERA"):
        """Assemble the layer stack and hierarchical prior of a SimpleMERA flow.

        Args:
            length: Size forwarded to the prior. When ``depth`` is ``None`` or
                ``-1`` the depth defaults to ``int(math.log(length, 2))``.
            layerList: Layers for the stack. If it does not already hold
                ``4 * repeat * depth`` entries, it is repeated ``depth`` times
                (list repetition reuses the same module objects). The final
                list must hold exactly ``4 * repeat * depth`` entries.
            meanNNlist, scaleNNlist: Optional networks. Only when BOTH are given
                is a ``PassiveHierarchyPrior`` topped by a
                ``MixtureDiscreteLogistic([3, 1, 4], ...)`` used and the lists
                stored. Each list is repeated ``depth`` times unless it already
                holds ``depth`` entries. Otherwise a ``SimpleHierarchyPrior`` is
                used and both attributes are set to ``None``.
            repeat: Factor in the expected layer count ``4 * repeat * depth``.
            depth: Number of levels; ``None`` or ``-1`` selects the default.
            nMixing: Mixture component count forwarded to the prior.
            decimal, rounding: Forwarded to the prior and stored on ``self``.
            clamp: Forwarded to ``SimpleHierarchyPrior`` or, in the network
                branch, to the top ``MixtureDiscreteLogistic``.
            sameDetail: Forwarded to ``SimpleHierarchyPrior`` only.
            compatible: Forwarded to the prior and stored on ``self``. When true,
                ``layerList`` and any supplied ``meanNNlist`` / ``scaleNNlist``
                must already be fully expanded.
            name: Forwarded to the base-class constructor.

        Raises:
            AssertionError: if a list length does not match its expected count.
        """
        kernelSize = 2
        if depth is None or depth == -1:
            depth = int(math.log(length, kernelSize))

        useNN = meanNNlist is not None and scaleNNlist is not None
        if useNN:
            lastPrior = source.MixtureDiscreteLogistic([3, 1, 4],
                                                       nMixing,
                                                       decimal,
                                                       rounding,
                                                       clamp=clamp)
            prior = source.PassiveHierarchyPrior(length,
                                                 lastPrior,
                                                 decimal=decimal,
                                                 rounding=rounding,
                                                 compatible=compatible)
        else:
            prior = source.SimpleHierarchyPrior(length,
                                                nMixing,
                                                decimal,
                                                rounding,
                                                clamp=clamp,
                                                sameDetail=sameDetail,
                                                compatible=compatible)
        # The base class must be initialised before any submodule is assigned.
        super(SimpleMERA, self).__init__(prior, name)

        self.decimal = decimal
        self.rounding = rounding
        self.repeat = repeat
        self.depth = depth
        self.compatible = compatible

        nLayers = 4 * repeat * depth
        if compatible:
            # Compatible mode expects pre-expanded lists. Only length-check the
            # network lists that were actually supplied: compatible=True is also
            # valid for the SimpleHierarchyPrior branch, where len(None) would
            # otherwise raise TypeError.
            assert len(layerList) == nLayers, (
                "compatible mode: layerList must have 4 * repeat * depth = {} "
                "entries, got {}".format(nLayers, len(layerList)))
            if meanNNlist is not None:
                assert len(meanNNlist) == depth, (
                    "compatible mode: meanNNlist must have depth = {} entries, "
                    "got {}".format(depth, len(meanNNlist)))
            if scaleNNlist is not None:
                assert len(scaleNNlist) == depth, (
                    "compatible mode: scaleNNlist must have depth = {} entries, "
                    "got {}".format(depth, len(scaleNNlist)))

        if len(layerList) != nLayers:
            # Repeating a list reuses the same module objects, so the repeated
            # entries share parameters.
            layerList = layerList * depth

        assert len(layerList) == nLayers, (
            "layerList must have 4 * repeat * depth = {} entries after "
            "expansion, got {}".format(nLayers, len(layerList)))

        self.layerList = torch.nn.ModuleList(layerList)

        if useNN:
            # Expand each list on its own length, so e.g. a single mean network
            # combined with an already-expanded scale list is still accepted.
            if len(meanNNlist) != depth:
                meanNNlist = meanNNlist * depth
            if len(scaleNNlist) != depth:
                scaleNNlist = scaleNNlist * depth

            assert len(meanNNlist) == depth, (
                "meanNNlist must have depth = {} entries after expansion, "
                "got {}".format(depth, len(meanNNlist)))
            assert len(scaleNNlist) == depth, (
                "scaleNNlist must have depth = {} entries after expansion, "
                "got {}".format(depth, len(scaleNNlist)))

            self.meanNNlist = torch.nn.ModuleList(meanNNlist)
            self.scaleNNlist = torch.nn.ModuleList(scaleNNlist)
        else:
            self.meanNNlist = None
            self.scaleNNlist = None