Esempio n. 1
0
def test_bijective_different_depth():
    """Check that a depth-limited ``flow.MERA`` is bijective.

    Eight (2 * 2 * 2) ``flow.RNVP`` blocks are built. Each block has four
    masks shaped ``(1, 1, 2, 2)``: even-indexed masks set a random half of the
    four entries to 1, and odd-indexed masks are the complement of the mask
    before them. The blocks are stacked into a ``flow.MERA`` with length 4,
    repeat 2, ``depth=1`` and a ``[1, 4, 4]`` Gaussian prior, and the result
    is passed to ``bijective``.
    """
    prior = source.Gaussian([1, 4, 4])

    n_blocks = 2 * 2 * 2
    n_couplings = 4

    mask_sets = []
    for _ in range(n_blocks):
        masks = []
        for k in range(n_couplings):
            if k % 2 == 1:
                # Complement of the mask produced on the previous (even) step.
                mask = 1 - mask
            else:
                mask = torch.zeros(1, 4)
                half = torch.randperm(mask.numel()).narrow(0, 0, mask.numel() // 2)
                mask.zero_()[:, half] = 1
                mask = mask.reshape(1, 1, 2, 2)
            masks.append(mask)
        mask_sets.append(torch.cat(masks, 0).to(torch.float32))

    def make_plain_mlp():
        # 4 -> 32 -> 32 -> 4 MLP with no final activation.
        return utils.SimpleMLPreshape([4, 32, 32, 4],
                                      [nn.ELU(), nn.ELU(), None])

    def make_tanh_mlp():
        # Same MLP shape, finished with a ScalableTanh(4) activation.
        return utils.SimpleMLPreshape([4, 32, 32, 4],
                                      [nn.ELU(), nn.ELU(), utils.ScalableTanh(4)])

    layers = []
    for masks in mask_sets:
        plain_nets = [make_plain_mlp() for _ in range(n_couplings)]
        tanh_nets = [make_tanh_mlp() for _ in range(n_couplings)]
        layers.append(flow.RNVP(masks, plain_nets, tanh_nets))

    length = 4
    repeat = 2

    model = flow.MERA(2, length, layers, repeat, depth=1, prior=prior)
    bijective(model)
Esempio n. 2
0
def symmetryMERAInit(L,
                     d,
                     nlayers,
                     nmlp,
                     nhidden,
                     nrepeat,
                     symmetryList,
                     device,
                     dtype,
                     name=None,
                     channel=1,
                     depthMERA=None):
    """Build a ``flow.MERA`` of RNVP blocks, optionally wrapped in ``Symmetrized``.

    Args:
        L: linear lattice size. The prior is a Gaussian over
            ``[channel] + [L] * d``. ``floor(log2(L)) * nrepeat * 2`` RNVP
            blocks are built, so ``L`` is presumably a power of two (confirm
            against ``flow.MERA``).
        d: number of lattice dimensions. In this function it only sets the
            prior's shape.
        nlayers: number of masks / MLP pairs per RNVP block.
        nmlp: number of hidden layers (each followed by ELU) per MLP.
        nhidden: width of every hidden layer.
        nrepeat: passed to ``flow.MERA`` as its 4th argument. It also scales
            the number of RNVP blocks built.
        symmetryList: if not ``None``, the flow is wrapped as
            ``Symmetrized(f, symmetryList, name=name)``.
        device: target device for the final ``f.to(...)``.
        dtype: target dtype for the final ``f.to(...)``.
        name: forwarded only to ``Symmetrized``.
        channel: number of channels per site. Each 2x2 core holds
            ``4 * channel`` values.
        depthMERA: forwarded to ``flow.MERA`` as ``depth``.

    Returns:
        The (possibly symmetrized) flow, after ``.to(device=device, dtype=dtype)``.
    """
    s = source.Gaussian([channel] + [L] * d)

    # math.log2 is exact for powers of two. math.log(L, 2) is a quotient of
    # two rounded logarithms and may be off by an ulp; int() truncates, so a
    # value just below an integer would silently build fewer blocks.
    depth = int(math.log2(L)) * nrepeat * 2

    coreSize = 4 * channel

    MaskList = []
    for _ in range(depth):
        masklist = []
        for n in range(nlayers):
            if n % 2 == 0:
                # Fresh random mask: exactly half of the coreSize entries are 1.
                b = torch.zeros(1, coreSize)
                i = torch.randperm(b.numel()).narrow(0, 0, b.numel() // 2)
                b[:, i] = 1
                b = b.view(1, channel, 2, 2)
            else:
                # Complement of the mask built on the previous (even) step.
                b = 1 - b
            masklist.append(b)
        MaskList.append(torch.cat(masklist, 0).to(torch.float32))

    # MLP layer widths: coreSize -> nhidden (x nmlp) -> coreSize.
    dimList = [coreSize] + [nhidden] * nmlp + [coreSize]

    layers = [
        flow.RNVP(masks, [
            utils.SimpleMLPreshape(dimList, [nn.ELU()
                                             for _ in range(nmlp)] + [None])
            for _ in range(nlayers)
        ], [
            utils.SimpleMLPreshape(dimList, [nn.ELU() for _ in range(nmlp)] +
                                   [utils.ScalableTanh(coreSize)])
            for _ in range(nlayers)
        ]) for masks in MaskList
    ]

    f = flow.MERA(2, L, layers, nrepeat, depth=depthMERA, prior=s)
    if symmetryList is not None:
        f = Symmetrized(f, symmetryList, name=name)
    f.to(device=device, dtype=dtype)
    return f
Esempio n. 3
0
def test_bijective():
    """Check a ``train.Symmetrized`` MERA flow under the symmetry ``x -> -x``.

    Eight (2 * 2 * 2) ``flow.RNVP`` blocks are built, each with four masks
    shaped ``(1, 2, 2)`` (a random half-mask alternating with its complement).
    They are stacked into a ``flow.MERA`` (length 4, repeat 2) with a
    ``[4, 4]`` Gaussian prior, and the flow is symmetrized with the single op
    ``x -> -x``. The test asserts that:

    * two inverse passes on the same latent batch agree element-wise up to sign;
    * their log-probabilities agree to 5 decimals;
    * inverse -> forward -> inverse reproduces the samples up to sign.
    """
    p = source.Gaussian([4, 4])

    BigList = []
    for _ in range(2 * 2 * 2):
        maskList = []
        for n in range(4):
            if n % 2 == 0:
                # Random mask with exactly half of the 4 entries set to 1.
                b = torch.zeros(1, 4)
                i = torch.randperm(b.numel()).narrow(0, 0, b.numel() // 2)
                b[:, i] = 1
                b = b.view(1, 2, 2)
            else:
                # Complement of the mask built on the previous (even) step.
                b = 1 - b
            maskList.append(b)
        maskList = torch.cat(maskList, 0).to(torch.float32)
        BigList.append(maskList)

    layers = [
        flow.RNVP(BigList[n], [
            utils.SimpleMLPreshape(
                [4, 32, 32, 4], [nn.ELU(), nn.ELU(), None]) for _ in range(4)
        ], [
            utils.SimpleMLPreshape(
                [4, 32, 32, 4],
                [nn.ELU(), nn.ELU(), utils.ScalableTanh(4)]) for _ in range(4)
        ]) for n in range(2 * 2 * 2)
    ]

    length = 4
    repeat = 2

    # Pass the prior by keyword: flow.MERA also takes a ``depth`` argument, so a
    # bare 5th positional argument risks binding the prior to the wrong parameter.
    t = flow.MERA(2, length, layers, repeat, prior=p)

    def op(x):
        return -x

    sym = [op]
    m = train.Symmetrized(t, sym)
    z = m.prior.sample(100)
    # Two inverse passes on the same z. The assertion below expects them to
    # agree up to sign, presumably because Symmetrized may apply ``op``; confirm.
    xz1, _ = m.inverse(z)
    xz2, _ = m.inverse(z)
    p1 = m.logProbability(xz1)
    p2 = m.logProbability(xz2)

    z1, _ = m.forward(xz1)
    xz1p, _ = m.inverse(z1)

    # Every one of the 100 * 4 * 4 entries must match either xz2 or -xz2.
    assert ((xz1 == xz2).sum() + (xz1 == -xz2).sum()) == 100 * 4 * 4
    assert_array_almost_equal(p1.detach().numpy(),
                              p2.detach().numpy(),
                              decimal=5)
    assert_array_almost_equal(np.fabs(xz1.detach().numpy()),
                              np.fabs(xz1p.detach().numpy()))