def test_bijective_different_depth():
    """Check that a MERA flow built with ``depth=1`` passes ``bijective``.

    Eight RNVP blocks are assembled. Each has four coupling layers acting on a
    single-channel 2x2 patch (4 values). Even-indexed couplings get a random
    mask that selects half of the 4 entries. Odd-indexed couplings get the
    complement of the mask just before them. The blocks are wrapped as
    ``flow.MERA(2, 4, layers, 2, depth=1, prior=...)`` over a ``[1, 4, 4]``
    Gaussian prior, and the result is handed to the ``bijective`` helper.
    """
    prior = source.Gaussian([1, 4, 4])
    numBlocks = 2 * 2 * 2
    numCouplings = 4

    # One stacked (numCouplings, 1, 2, 2) float32 mask tensor per block.
    blockMasks = []
    for _ in range(numBlocks):
        couplingMasks = []
        for k in range(numCouplings):
            if k % 2:
                # Odd coupling: complement of the preceding even mask.
                couplingMasks.append(1 - couplingMasks[-1])
                continue
            flat = torch.zeros(1, 4)
            picked = torch.randperm(flat.numel()).narrow(0, 0, flat.numel() // 2)
            flat[:, picked] = 1
            couplingMasks.append(flat.reshape(1, 1, 2, 2))
        blockMasks.append(torch.cat(couplingMasks, 0).to(torch.float32))

    layers = []
    for masks in blockMasks:
        # First net list has no output activation; the second ends in ScalableTanh.
        linearNets = [
            utils.SimpleMLPreshape([4, 32, 32, 4], [nn.ELU(), nn.ELU(), None])
            for _ in range(numCouplings)
        ]
        tanhNets = [
            utils.SimpleMLPreshape(
                [4, 32, 32, 4], [nn.ELU(), nn.ELU(), utils.ScalableTanh(4)])
            for _ in range(numCouplings)
        ]
        layers.append(flow.RNVP(masks, linearNets, tanhNets))

    mera = flow.MERA(2, 4, layers, 2, depth=1, prior=prior)
    bijective(mera)
def symmetryMERAInit(L, d, nlayers, nmlp, nhidden, nrepeat, symmetryList, device, dtype, name=None, channel=1, depthMERA=None):
    """Build a MERA flow of RNVP blocks, optionally wrapped in ``Symmetrized``.

    Args:
        L: Linear lattice size. The prior has shape ``[channel] + [L] * d``.
            The number of RNVP blocks is ``int(log2(L)) * nrepeat * 2``, so L
            is presumably expected to be a power of 2 (TODO confirm).
        d: Number of spatial dimensions of the prior.
        nlayers: Number of coupling layers inside each RNVP block. Even layers
            use a fresh random mask selecting half of the block's entries; odd
            layers use the complement of the preceding mask.
        nmlp: Number of hidden layers (each followed by ELU) in every MLP.
        nhidden: Width of each hidden layer.
        nrepeat: Forwarded to ``flow.MERA``; also scales the block count.
        symmetryList: Symmetry operations for ``Symmetrized``. ``None`` skips
            the wrapping.
        device: Forwarded to ``.to`` on the final flow.
        dtype: Forwarded to ``.to`` on the final flow.
        name: Forwarded to ``Symmetrized``.
        channel: Number of channels. Each block acts on ``4 * channel``
            values, viewed as a 2x2 patch per channel.
        depthMERA: Forwarded to ``flow.MERA`` as ``depth``.

    Returns:
        The flow (``Symmetrized``-wrapped when ``symmetryList`` is not None),
        after ``.to(device=device, dtype=dtype)``.
    """
    s = source.Gaussian([channel] + [L] * d)
    # math.log(L, 2) divides two independently rounded logarithms and may come
    # out a hair below the true integer, which int() would truncate to the
    # wrong block count. math.log2 is the dedicated, more accurate base-2 log.
    depth = int(math.log2(L)) * nrepeat * 2
    coreSize = 4 * channel

    MaskList = []
    for _ in range(depth):
        masklist = []
        for n in range(nlayers):
            if n % 2 == 0:
                # Random mask selecting half of the coreSize entries.
                b = torch.zeros(1, coreSize)
                i = torch.randperm(b.numel()).narrow(0, 0, b.numel() // 2)
                b[:, i] = 1
                b = b.view(1, channel, 2, 2)
            else:
                # Complement of the preceding mask.
                b = 1 - b
            masklist.append(b)
        MaskList.append(torch.cat(masklist, 0).to(torch.float32))

    # coreSize -> nhidden x nmlp -> coreSize
    dimList = [coreSize] + [nhidden] * nmlp + [coreSize]

    layers = [
        flow.RNVP(MaskList[n], [
            utils.SimpleMLPreshape(dimList, [nn.ELU() for _ in range(nmlp)] + [None])
            for _ in range(nlayers)
        ], [
            utils.SimpleMLPreshape(dimList, [nn.ELU() for _ in range(nmlp)] + [utils.ScalableTanh(coreSize)])
            for _ in range(nlayers)
        ])
        for n in range(depth)
    ]

    f = flow.MERA(2, L, layers, nrepeat, depth=depthMERA, prior=s)
    if symmetryList is not None:
        f = Symmetrized(f, symmetryList, name=name)
    f.to(device=device, dtype=dtype)
    return f
def test_bijective():
    """Check a sign-symmetrized MERA flow for sign-consistent sampling.

    Eight RNVP blocks (four couplings each over a 2x2 patch) are wrapped in
    ``flow.MERA`` over a 4x4 Gaussian prior. The result is then wrapped in
    ``train.Symmetrized`` with the single symmetry ``x -> -x``. The test
    asserts that:

    * two ``inverse`` passes from the same latent agree element-wise up to
      sign: every element equals either the other sample's element or its
      negation;
    * the log-probabilities of the two samples agree to 5 decimals;
    * a ``forward`` followed by ``inverse`` reproduces the sample in absolute
      value.
    """
    p = source.Gaussian([4, 4])
    BigList = []
    for _ in range(2 * 2 * 2):
        maskList = []
        for n in range(4):
            if n % 2 == 0:
                # Random mask selecting half of the 4 entries of a 2x2 patch.
                b = torch.zeros(1, 4)
                i = torch.randperm(b.numel()).narrow(0, 0, b.numel() // 2)
                b[:, i] = 1
                b = b.view(1, 2, 2)
            else:
                # Complement of the preceding mask.
                b = 1 - b
            maskList.append(b)
        BigList.append(torch.cat(maskList, 0).to(torch.float32))
    layers = [
        flow.RNVP(BigList[n], [
            utils.SimpleMLPreshape(
                [4, 32, 32, 4], [nn.ELU(), nn.ELU(), None])
            for _ in range(4)
        ], [
            utils.SimpleMLPreshape(
                [4, 32, 32, 4], [nn.ELU(), nn.ELU(), utils.ScalableTanh(4)])
            for _ in range(4)
        ])
        for n in range(2 * 2 * 2)
    ]
    length = 4
    repeat = 2
    batchSize = 100
    # Pass the prior by keyword, as the other MERA call sites do
    # (depth=..., prior=...). A fifth positional argument could otherwise
    # bind to ``depth`` rather than ``prior``.
    t = flow.MERA(2, length, layers, repeat, prior=p)

    def op(x):
        return -x

    sym = [op]
    m = train.Symmetrized(t, sym)
    z = m.prior.sample(batchSize)
    xz1, _ = m.inverse(z)
    xz2, _ = m.inverse(z)
    p1 = m.logProbability(xz1)
    p2 = m.logProbability(xz2)
    z1, _ = m.forward(xz1)
    xz1p, _ = m.inverse(z1)
    # Every element must match the other sample either directly or negated.
    assert ((xz1 == xz2).sum() + (xz1 == -xz2).sum()) == batchSize * length * length
    assert_array_almost_equal(p1.detach().numpy(), p2.detach().numpy(), decimal=5)
    assert_array_almost_equal(np.fabs(xz1.detach().numpy()), np.fabs(xz1p.detach().numpy()))