def test_ScalingNshifting_saveload():
    p = source.Gaussian([2, 2])
    s = float(np.random.randint(1, 11))
    t = float(np.random.randint(0, 11))
    f = flow.ScalingNshifting(s, t)
    f.prior = p

    pp = source.Gaussian([2, 2])
    s = float(np.random.randint(1, 11))
    t = float(np.random.randint(0, 11))
    blankf = flow.ScalingNshifting(s, t)
    blankf.prior = pp

    saveload(f, blankf)
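
# Hedged sketch: the `saveload` helper used above comes from the shared test
# utilities and is not shown in this file. Judging from the explicit
# save/load round-trips in the MERA tests below, it presumably serializes the
# first flow, loads the state into the second, and asserts that both agree.
# `_saveload_sketch` is our hypothetical reconstruction, not the real helper.
def _saveload_sketch(f, blankf):
    torch.save(f.save(), "testsaving.saving")
    blankf.load(torch.load("testsaving.saving"))
    x, _ = f.sample(100)
    zf, _ = f.inverse(x)
    zb, _ = blankf.inverse(x)
    assert_allclose(zf.detach().numpy(), zb.detach().numpy())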
def test_grad():
    decimal = flow.ScalingNshifting(256, -128)
    p = source.MixtureDiscreteLogistic([3, 32, 32], 5, decimal,
                                       utils.roundingWidentityGradient)

    maskList = []
    for n in range(4):
        if n % 2 == 0:
            b = torch.cat([torch.zeros(3 * 32 * 16), torch.ones(3 * 32 * 16)])[torch.randperm(3 * 32 * 32)].reshape(1, 3, 32, 32)
        else:
            b = 1 - b
        maskList.append(b)
    maskList = torch.cat(maskList, 0).to(torch.float32)

    tList = [utils.SimpleMLPreshape([3 * 32 * 16, 200, 500, 3 * 32 * 16],
                                    [nn.ELU(), nn.ELU(), None]) for _ in range(4)]
    f = flow.DiscreteNICE(maskList, tList, decimal,
                          utils.roundingWidentityGradient, p)

    # Copy the flow but swap the straight-through rounding for plain
    # torch.round, whose gradient is zero almost everywhere.
    fcopy = deepcopy(f)
    fcopy.rounding = torch.round

    field = p.sample(100).detach()
    cfield = deepcopy(field).requires_grad_()
    field.requires_grad_()

    xfield, _ = f.inverse(field)
    xcfield, _ = fcopy.inverse(cfield)

    L = xfield.sum()
    Lc = xcfield.sum()
    L.backward()
    Lc.backward()

    ou = [term for term in f.parameters()]
    ouc = [term for term in fcopy.parameters()]
    # The two rounding schemes must produce different parameter gradients.
    assert not np.all(ou[-1].grad.detach().numpy() == ouc[-1].grad.detach().numpy())
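
# Hedged sketch: `utils.roundingWidentityGradient` is defined elsewhere; the
# test above only makes sense if it is a straight-through estimator, i.e. it
# rounds in the forward pass but passes gradients through unchanged (unlike
# torch.round, whose gradient is zero almost everywhere). The hypothetical
# reconstruction below (the class name is ours) illustrates that contract.
class _RoundingWIdentityGradientSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Quantize to the nearest integer in the forward pass.
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: behave like the identity for gradients.
        return grad_output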
def test_saveload():
    decimal = flow.ScalingNshifting(256, -128)
    p = source.MixtureDiscreteLogistic([3, 32, 32], 5, decimal,
                                       utils.roundingWidentityGradient)

    maskList = []
    for n in range(4):
        if n % 2 == 0:
            b = torch.cat([torch.zeros(3 * 32 * 16), torch.ones(3 * 32 * 16)])[torch.randperm(3 * 32 * 32)].reshape(1, 3, 32, 32)
        else:
            b = 1 - b
        maskList.append(b)
    maskList = torch.cat(maskList, 0).to(torch.float32)

    tList = [utils.SimpleMLPreshape([3 * 32 * 16, 200, 500, 3 * 32 * 16],
                                    [nn.ELU(), nn.ELU(), None]) for _ in range(4)]
    f = flow.DiscreteNICE(maskList, tList, decimal,
                          utils.roundingWidentityGradient, p)

    p = source.MixtureDiscreteLogistic([3, 32, 32], 5, decimal,
                                       utils.roundingWidentityGradient)

    maskList = []
    for n in range(4):
        if n % 2 == 0:
            b = torch.cat([torch.zeros(3 * 32 * 16), torch.ones(3 * 32 * 16)])[torch.randperm(3 * 32 * 32)].reshape(1, 3, 32, 32)
        else:
            b = 1 - b
        maskList.append(b)
    maskList = torch.cat(maskList, 0).to(torch.float32)

    tList = [utils.SimpleMLPreshape([3 * 32 * 16, 200, 500, 3 * 32 * 16],
                                    [nn.ELU(), nn.ELU(), None]) for _ in range(4)]
    blankf = flow.DiscreteNICE(maskList, tList, decimal,
                               utils.roundingWidentityGradient, p)

    saveload(f, blankf)
def test_ScalingNshifting():
    p = source.Gaussian([2, 2])
    s = float(np.random.randint(1, 11))
    t = float(np.random.randint(0, 11))
    f = flow.ScalingNshifting(s, t)
    f.prior = p

    bijective(f)
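
# Hedged sketch: `bijective` is another shared helper not defined in this
# file. By analogy with test_bijective below, it presumably checks that
# forward() and inverse() are exact inverses on sampled data.
# `_bijective_sketch` is our hypothetical reconstruction.
def _bijective_sketch(f, batchSize=100):
    x, _ = f.sample(batchSize)
    z, _ = f.inverse(x)
    rcn, _ = f.forward(z)
    assert_allclose(x.detach().numpy(), rcn.detach().numpy())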
def test_bijective():
    decimal = flow.ScalingNshifting(256, -128)

    layerList = []
    for i in range(4 * 2):
        f = torch.nn.Sequential(
            torch.nn.Conv2d(9, 9, 3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(9, 9, 1, padding=0),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(9, 3, 3, padding=1))
        layerList.append(f)

    meanNNlist = []
    scaleNNlist = []
    meanNNlist.append(torch.nn.Sequential(
        torch.nn.Conv2d(3, 9, 3, padding=1),
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(9, 9, 1, padding=0),
        torch.nn.ReLU(inplace=True)))
    scaleNNlist.append(torch.nn.Sequential(
        torch.nn.Conv2d(3, 9, 3, padding=1),
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(9, 9, 1, padding=0),
        torch.nn.ReLU(inplace=True)))

    t = flow.SimpleMERA(8, layerList, meanNNlist, scaleNNlist, 2, None, 5,
                        decimal, utils.roundingWidentityGradient)

    samples = torch.randint(0, 255, (100, 3, 8, 8)).float()
    zSamples, _ = t.inverse(samples)
    rcnSamples, _ = t.forward(zSamples)
    prob = t.logProbability(samples)

    assert_allclose(samples.detach().numpy(), rcnSamples.detach().numpy())

    # Test the depth argument
    t = flow.SimpleMERA(8, layerList, meanNNlist, scaleNNlist, 2, 2, 5,
                        decimal, utils.roundingWidentityGradient)

    samples = torch.randint(0, 255, (100, 3, 8, 8)).float()
    zSamples, _ = t.inverse(samples)
    rcnSamples, _ = t.forward(zSamples)
    #prob = t.logProbability(samples)

    assert_allclose(samples.detach().numpy(), rcnSamples.detach().numpy())
def test_integer():
    decimal = flow.ScalingNshifting(256, -128)
    p = source.MixtureDiscreteLogistic([3, 32, 32], 5, decimal,
                                       utils.roundingWidentityGradient)

    maskList = []
    for n in range(4):
        if n % 2 == 0:
            b = torch.cat([torch.zeros(3 * 32 * 16), torch.ones(3 * 32 * 16)])[torch.randperm(3 * 32 * 32)].reshape(1, 3, 32, 32)
        else:
            b = 1 - b
        maskList.append(b)
    maskList = torch.cat(maskList, 0).to(torch.float32)

    tList = [utils.SimpleMLPreshape([3 * 32 * 16, 200, 500, 3 * 32 * 16],
                                    [nn.ELU(), nn.ELU(), None]) for _ in range(4)]
    f = flow.DiscreteNICE(maskList, tList, decimal,
                          utils.roundingWidentityGradient, p)

    # Samples, latents, and reconstructions must all stay on the integer lattice.
    x, _ = f.sample(100)
    assert np.all(np.equal(np.mod(x.detach().numpy(), 1), 0))

    zx, _ = f.inverse(x)
    assert np.all(np.equal(np.mod(zx.detach().numpy(), 1), 0))

    xzx, _ = f.forward(zx)
    assert np.all(np.equal(np.mod(xzx.detach().numpy(), 1), 0))
def test_grad():
    length = 8
    channel = 3
    decimal = flow.ScalingNshifting(256, -128)

    p1 = source.DiscreteLogistic([channel, 16, 3], decimal, rounding=utils.roundingWidentityGradient)
    p2 = source.DiscreteLogistic([channel, 4, 3], decimal, rounding=utils.roundingWidentityGradient)
    p3 = source.MixtureDiscreteLogistic([channel, 1, 4], 5, decimal, rounding=utils.roundingWidentityGradient)

    P = source.HierarchyPrior(channel, length, [p1, p2, p3], repeat=2)

    x = P.sample(100)
    logp = P.logProbability(x)
    L = logp.mean()
    L.backward()

    # Every sub-prior's parameters must receive a nonzero gradient.
    assert p1.mean.grad.sum().detach().item() != 0
    assert p2.mean.grad.sum().detach().item() != 0
    assert p3.mean.grad.sum().detach().item() != 0
    assert p1.logscale.grad.sum().detach().item() != 0
    assert p2.logscale.grad.sum().detach().item() != 0
    assert p3.logscale.grad.sum().detach().item() != 0
    assert p3.mixing.grad.sum().detach().item() != 0
def test_saveload():
    shapeList2D = [3] + [12] * (1 + 1) + [3 * 3]
    shapeList1D = [3] + [12] * (1 + 1) + [3]

    def buildLayers2D(shapeList):
        layers = []
        for no, chn in enumerate(shapeList[:-1]):
            if no != 0 and no != len(shapeList) - 2:
                layers.append(torch.nn.Conv2d(chn, shapeList[no + 1], 1))
            else:
                layers.append(torch.nn.Conv2d(chn, shapeList[no + 1], 3, padding=1))
            if no != len(shapeList) - 2:
                layers.append(torch.nn.ReLU(inplace=True))
        return layers

    def buildLayers1D(shapeList):
        layers = []
        for no, chn in enumerate(shapeList[:-1]):
            if no != 0 and no != len(shapeList) - 2:
                layers.append(torch.nn.Conv1d(chn, shapeList[no + 1], 1))
            else:
                layers.append(torch.nn.Conv1d(chn, shapeList[no + 1], 3, padding=1))
            if no != len(shapeList) - 2:
                layers.append(torch.nn.ReLU(inplace=True))
        layers = torch.nn.Sequential(*layers)
        #torch.nn.init.zeros_(layers[-1].weight)
        #torch.nn.init.zeros_(layers[-1].bias)
        return layers

    decimal = flow.ScalingNshifting(256, 0)

    layerList = []
    for i in range(2 * 2):
        layerList.append(buildLayers1D(shapeList1D))

    meanNNlist = []
    scaleNNlist = []
    layers = buildLayers2D(shapeList2D)
    meanNNlist.append(torch.nn.Sequential(*layers))
    layers = buildLayers2D(shapeList2D)
    scaleNNlist.append(torch.nn.Sequential(*layers))

    t = flow.OneToTwoMERA(8, layerList, meanNNlist, scaleNNlist, 2, None, 5,
                          decimal=decimal, rounding=utils.roundingWidentityGradient)

    decimal = flow.ScalingNshifting(256, 0)

    layerList = []
    for i in range(2 * 2):
        layerList.append(buildLayers1D(shapeList1D))

    meanNNlist = []
    scaleNNlist = []
    layers = buildLayers2D(shapeList2D)
    meanNNlist.append(torch.nn.Sequential(*layers))
    layers = buildLayers2D(shapeList2D)
    scaleNNlist.append(torch.nn.Sequential(*layers))

    tt = flow.OneToTwoMERA(8, layerList, meanNNlist, scaleNNlist, 2, None, 5,
                           decimal=decimal, rounding=utils.roundingWidentityGradient)

    samples = torch.randint(0, 255, (100, 3, 8, 8)).float()

    torch.save(t.save(), "testsaving.saving")
    tt.load(torch.load("testsaving.saving"))

    tzSamples, _ = t.inverse(samples)
    ttzSamples, _ = tt.inverse(samples)

    rcnSamples, _ = t.forward(tzSamples)
    ttrcnSamples, _ = tt.forward(ttzSamples)

    assert_allclose(tzSamples.detach().numpy(), ttzSamples.detach().numpy())
    assert_allclose(samples.detach().numpy(), rcnSamples.detach().numpy())
    assert_allclose(rcnSamples.detach().numpy(), ttrcnSamples.detach().numpy())
def test_wavelet():
    # Map each channel of a tensor onto [0, 1].
    def back01(tensor):
        ten = tensor.clone().float()
        ten = ten.view(ten.shape[0] * ten.shape[1], -1)
        ten -= ten.min(1, keepdim=True)[0]
        ten /= ten.max(1, keepdim=True)[0]
        ten = ten.view(tensor.shape)
        return ten

    # yet another renorm fn
    def batchNorm(tensor, base=1.0):
        m = nn.BatchNorm2d(tensor.shape[1], affine=False)
        return m(tensor).float() + base

    renormFn = lambda x: back01(batchNorm(x))

    def im2grp(t):
        return t.reshape(t.shape[0], t.shape[1], t.shape[2] // 2, 2, t.shape[3] // 2, 2).permute([0, 1, 2, 4, 3, 5]).reshape(t.shape[0], t.shape[1], -1, 4)

    def grp2im(t):
        return t.reshape(t.shape[0], t.shape[1], int(t.shape[2] ** 0.5), int(t.shape[2] ** 0.5), 2, 2).permute([0, 1, 2, 4, 3, 5]).reshape(t.shape[0], t.shape[1], int(t.shape[2] ** 0.5) * 2, int(t.shape[2] ** 0.5) * 2)

    decimal = flow.ScalingNshifting(256, 0)
    psudoRounding = torch.nn.Identity()

    IMG = Image.open('./etc/lena512color.tiff')
    IMG = torch.from_numpy(np.array(IMG)).permute([2, 0, 1])
    IMG = IMG.reshape(1, *IMG.shape).float()

    v = IMG

    # Build the single-level (Haar-like) wavelet transform matrix of size n x n.
    def buildTransMatrix(n):
        core = torch.tensor([[0.5, 0.5], [-1, 1]])
        gap = torch.zeros(2, n)
        return torch.cat([core if i % 2 == 0 else gap for i in range(n - 1)], -1).reshape(2, n // 2, n).permute([1, 0, 2]).reshape(n, n)

    depth = int(math.log(v.shape[-1], 2))

    # Reference implementation: apply the transform matrix along both spatial
    # axes at each level, then split the result into ul/ur/dl/dr quadrants.
    up = v
    blockSize = v.shape[-1]
    UR = []
    DL = []
    DR = []
    for i in range(depth):
        transMatrix = buildTransMatrix(blockSize)
        for _ in range(2):
            up = torch.matmul(up, transMatrix.t())
            up = up.permute([0, 1, 3, 2])
        blockSize //= 2

        _x = im2grp(up)
        ul = _x[:, :, :, 0].reshape(*_x.shape[:2], int(_x.shape[2] ** 0.5), int(_x.shape[2] ** 0.5)).contiguous()
        ur = _x[:, :, :, 1].reshape(*_x.shape[:2], int(_x.shape[2] ** 0.5), int(_x.shape[2] ** 0.5)).contiguous()
        dl = _x[:, :, :, 2].reshape(*_x.shape[:2], int(_x.shape[2] ** 0.5), int(_x.shape[2] ** 0.5)).contiguous()
        dr = _x[:, :, :, 3].reshape(*_x.shape[:2], int(_x.shape[2] ** 0.5), int(_x.shape[2] ** 0.5)).contiguous()

        UR.append(ur)
        DL.append(dl)
        DR.append(dr)
        up = ul

    # Reassemble the multi-level decomposition into a single tensor.
    ul = up
    for no in reversed(range(depth)):
        ur = UR[no].reshape(*ul.shape, 1)
        dl = DL[no].reshape(*ul.shape, 1)
        dr = DR[no].reshape(*ul.shape, 1)
        ul = ul.reshape(*ul.shape, 1)
        _x = torch.cat([ul, ur, dl, dr], -1).reshape(*ul.shape[:2], -1, 4)
        ul = grp2im(_x).contiguous()

    transV = ul

    '''
    ul = ul
    UR = []
    DL = []
    DR = []
    for _ in range(depth):
        _x = im2grp(ul)
        ul = _x[:, :, :, 0].reshape(*_x.shape[:2], int(_x.shape[2] ** 0.5), int(_x.shape[2] ** 0.5)).contiguous()
        ur = _x[:, :, :, 1].reshape(*_x.shape[:2], int(_x.shape[2] ** 0.5), int(_x.shape[2] ** 0.5)).contiguous()
        dl = _x[:, :, :, 2].reshape(*_x.shape[:2], int(_x.shape[2] ** 0.5), int(_x.shape[2] ** 0.5)).contiguous()
        dr = _x[:, :, :, 3].reshape(*_x.shape[:2], int(_x.shape[2] ** 0.5), int(_x.shape[2] ** 0.5)).contiguous()
        UR.append(renormFn(ur))
        DL.append(renormFn(dl))
        DR.append(renormFn(dr))
        #ul = back01(backMeanStd(batchNorm(ul, 0)))
        ul = renormFn(ul)
        #ul = back01(clip(backMeanStd(batchNorm(ul))))

    for no in reversed(range(depth)):
        ur = UR[no]
        dl = DL[no]
        dr = DR[no]
        upper = torch.cat([ul, ur], -1)
        down = torch.cat([dl, dr], -1)
        ul = torch.cat([upper, down], -2)

    # convert zremain to numpy array
    zremain = ul.permute([0, 2, 3, 1]).detach().cpu().numpy()

    waveletPlot = plt.figure(figsize=(8, 8))
    waveletAx = waveletPlot.add_subplot(111)
    waveletAx.imshow(zremain[0])
    plt.axis('off')
    plt.savefig('./testWavelet.pdf', bbox_inches="tight", pad_inches=0)
    plt.close()
    '''

    initMethods = []
    initMethods.append(lambda: harrInitMethod1(3))
    initMethods.append(lambda: harrInitMethod2(3))

    orders = [True, False]

    layerList = []
    for j in range(2):
        layerList.append(buildWaveletLayers(initMethods[j], 3, 12, 1, orders[j]))

    shapeList2D = [3] + [12] * (1 + 1) + [3 * 3]
    shapeList1D = [3] + [12] * (1 + 1) + [3]

    def buildLayers2D(shapeList):
        layers = []
        for no, chn in enumerate(shapeList[:-1]):
            if no != 0 and no != len(shapeList) - 2:
                layers.append(torch.nn.Conv2d(chn, shapeList[no + 1], 1))
            else:
                layers.append(torch.nn.Conv2d(chn, shapeList[no + 1], 3, padding=1))
            if no != len(shapeList) - 2:
                layers.append(torch.nn.ReLU(inplace=True))
        return layers

    def buildLayers1D(shapeList):
        layers = []
        for no, chn in enumerate(shapeList[:-1]):
            if no != 0 and no != len(shapeList) - 2:
                layers.append(torch.nn.Conv1d(chn, shapeList[no + 1], 1))
            else:
                layers.append(torch.nn.Conv1d(chn, shapeList[no + 1], 3, padding=1))
            if no != len(shapeList) - 2:
                layers.append(torch.nn.ReLU(inplace=True))
        layers = torch.nn.Sequential(*layers)
        torch.nn.init.zeros_(layers[-1].weight)
        torch.nn.init.zeros_(layers[-1].bias)
        return layers

    # repeat, add one more layer of NICE
    for _ in range(2):
        layerList.append(buildLayers1D(shapeList1D))

    meanNNlist = []
    scaleNNlist = []
    layers = buildLayers2D(shapeList2D)
    meanNNlist.append(torch.nn.Sequential(*layers))
    layers = buildLayers2D(shapeList2D)
    scaleNNlist.append(torch.nn.Sequential(*layers))

    # Zero-init the last conv of the mean/scale nets so the flow starts out
    # as the pure wavelet transform.
    torch.nn.init.zeros_(meanNNlist[-1][-1].weight)
    torch.nn.init.zeros_(meanNNlist[-1][-1].bias)
    torch.nn.init.zeros_(scaleNNlist[-1][-1].weight)
    torch.nn.init.zeros_(scaleNNlist[-1][-1].bias)

    f = flow.OneToTwoMERA(v.shape[-1], layerList, meanNNlist, scaleNNlist,
                          repeat=2, depth=depth, nMixing=5, decimal=decimal,
                          rounding=psudoRounding.forward)

    vpp = f.inverse(v)[0]
    assert_allclose(vpp.detach().numpy(), transV.detach().numpy())

    # Test depth
    vp = IMG
    depth = 2

    up = vp
    blockSize = v.shape[-1]
    UR = []
    DL = []
    DR = []
    for i in range(depth):
        transMatrix = buildTransMatrix(blockSize)
        for _ in range(2):
            up = torch.matmul(up, transMatrix.t())
            up = up.permute([0, 1, 3, 2])
        blockSize //= 2

        _x = im2grp(up)
        ul = _x[:, :, :, 0].reshape(*_x.shape[:2], int(_x.shape[2] ** 0.5), int(_x.shape[2] ** 0.5)).contiguous()
        ur = _x[:, :, :, 1].reshape(*_x.shape[:2], int(_x.shape[2] ** 0.5), int(_x.shape[2] ** 0.5)).contiguous()
        dl = _x[:, :, :, 2].reshape(*_x.shape[:2], int(_x.shape[2] ** 0.5), int(_x.shape[2] ** 0.5)).contiguous()
        dr = _x[:, :, :, 3].reshape(*_x.shape[:2], int(_x.shape[2] ** 0.5), int(_x.shape[2] ** 0.5)).contiguous()

        UR.append(ur)
        DL.append(dl)
        DR.append(dr)
        up = ul

    ul = up
    for no in reversed(range(depth)):
        ur = UR[no].reshape(*ul.shape, 1)
        dl = DL[no].reshape(*ul.shape, 1)
        dr = DR[no].reshape(*ul.shape, 1)
        ul = ul.reshape(*ul.shape, 1)
        _x = torch.cat([ul, ur, dl, dr], -1).reshape(*ul.shape[:2], -1, 4)
        ul = grp2im(_x).contiguous()

    transVp = ul

    fp = flow.OneToTwoMERA(8, layerList, meanNNlist, scaleNNlist, 2, depth, 5,
                           decimal=decimal, rounding=psudoRounding.forward)

    vpp = fp.inverse(vp)[0]
    assert_allclose(vpp.detach().numpy(), transVp.detach().numpy())
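
# Illustration (ours, not part of the original test): buildTransMatrix builds
# a Haar-style analysis matrix that pairs neighbouring pixels into an
# average-like row (0.5, 0.5) and a difference row (-1, 1). For n = 4 it
# evaluates to
#
#   [[ 0.5, 0.5, 0.0, 0.0],
#    [-1.0, 1.0, 0.0, 0.0],
#    [ 0.0, 0.0, 0.5, 0.5],
#    [ 0.0, 0.0, -1.0, 1.0]]
#
# so right-multiplying by its transpose along each spatial axis in turn
# performs one level of the wavelet transform that transV/transVp reproduce.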
else:
    # Load saved parameters and decode them into memory
    with open(rootFolder + "/parameter.json", 'r') as f:
        config = json.load(f)
        locals().update(config)

# Building the target dataset
if target == "CIFAR":
    # Define dimensions
    targetSize = [3, 32, 32]
    dimensional = 2
    channel = targetSize[0]
    blockLength = targetSize[-1]

    # Define normalization and decimal
    decimal = flow.ScalingNshifting(256, 0)
    rounding = utils.roundingWidentityGradient

    # Building train & test datasets
    lambd = lambda x: (x * 255).byte().to(torch.float32).to(device)
    trainsetTransform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Lambda(lambd)
    ])
    trainTarget = torchvision.datasets.CIFAR10(root='./data/cifar', train=True,
                                               download=True, transform=trainsetTransform)
    testTarget = torchvision.datasets.CIFAR10(root='./data/cifar', train=False,
                                              download=True, transform=trainsetTransform)
    targetTrainLoader = torch.utils.data.DataLoader(trainTarget, batch_size=batch,
                                                    shuffle=True)
    targetTestLoader = torch.utils.data.DataLoader(testTarget, batch_size=batch,
                                                   shuffle=False)
elif target == "ImageNet32":
    # Define dimensions
    targetSize = [3, 32, 32]
    dimensional = 2
    channel = targetSize[0]
    blockLength = targetSize[-1]
if HUE:
    lambd = lambda x: (x * 255).byte().to(torch.float32).to(device)
else:
    # Convert RGB to YCbCr before training.
    lambd = lambda x: utils.rgb2ycc((x * 255).byte().float(), True).to(torch.float32).to(device)

# Building the target dataset
if target == "CIFAR":
    # Define dimensions
    targetSize = [3, 32, 32]
    dimensional = 2
    channel = targetSize[0]
    blockLength = targetSize[-1]

    # Define normalization and decimal
    decimal = flow.ScalingNshifting(256, -128)
    rounding = utils.roundingWidentityGradient

    # Building train & test datasets
    trainsetTransform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Lambda(lambd)
    ])
    trainTarget = torchvision.datasets.CIFAR10(root='./data/cifar', train=True,
                                               download=True, transform=trainsetTransform)
    testTarget = torchvision.datasets.CIFAR10(root='./data/cifar', train=False,
                                              download=True, transform=trainsetTransform)
def test_hierarchyPrior():

    class UniTestPrior(source.Source):
        def __init__(self, nvars, element, name="UniTestPrior"):
            super(UniTestPrior, self).__init__(nvars, 1.0, name)
            self.element = torch.nn.Parameter(torch.tensor(element), requires_grad=False)

        def sample(self, batchSize):
            return torch.ones([batchSize] + self.nvars).to(self.element).float() * self.element

        def _energy(self, z):
            return (torch.tensor([2])**self.element * np.prod(z.shape[2:]))

    length = 32
    channel = 3
    decimal = flow.ScalingNshifting(256, -128)

    p1 = source.DiscreteLogistic([channel, 256, 3], decimal, rounding=utils.roundingWidentityGradient)
    p2 = source.DiscreteLogistic([channel, 64, 3], decimal, rounding=utils.roundingWidentityGradient)
    p3 = source.DiscreteLogistic([channel, 16, 3], decimal, rounding=utils.roundingWidentityGradient)
    p4 = source.DiscreteLogistic([channel, 4, 3], decimal, rounding=utils.roundingWidentityGradient)
    p5 = source.MixtureDiscreteLogistic([channel, 1, 4], 5, decimal, rounding=utils.roundingWidentityGradient)

    P = source.HierarchyPrior(channel, length, [p1, p2, p3, p4, p5], repeat=1)

    x = P.sample(100)
    logp = P.logProbability(x)

    import math

    # dispatch/collect must be exact inverses: scatter x into per-level parts,
    # then gather the parts back and recover x.
    zparts = []
    for no in range(int(math.log(length, 2))):
        _, parts = utils.dispatch(P.factorOutIList[no], P.factorOutJList[no], x)
        zparts.append(parts)

    rcnX = torch.zeros_like(x)
    for no in range(int(math.log(length, 2))):
        part = zparts[no]
        rcnX = utils.collect(P.factorOutIList[no], P.factorOutJList[no], rcnX, part)

    assert_allclose(x.detach(), rcnX.detach())

    length = 8

    p1 = UniTestPrior([channel, 16, 3], 1)
    p2 = UniTestPrior([channel, 4, 3], 2)
    p3 = UniTestPrior([channel, 1, 4], 3)

    Pp = source.HierarchyPrior(channel, length, [p1, p2, p3], repeat=2)

    x = Pp.sample(1)
    logp = Pp.logProbability(x)

    target = np.array([[3, 1, 2, 1, 3, 1, 2, 1],
                       [1, 1, 1, 1, 1, 1, 1, 1],
                       [2, 1, 2, 1, 2, 1, 2, 1],
                       [1, 1, 1, 1, 1, 1, 1, 1],
                       [3, 1, 2, 1, 3, 1, 2, 1],
                       [1, 1, 1, 1, 1, 1, 1, 1],
                       [2, 1, 2, 1, 2, 1, 2, 1],
                       [1, 1, 1, 1, 1, 1, 1, 1]])

    assert_allclose(x[0, 0].detach().numpy(), target)
    assert logp == -(16 * 3 * 2**1 + 4 * 3 * 2**2 + 1 * 4 * 2**3)

    p1 = UniTestPrior([channel, 16, 3], 1)
    p2 = UniTestPrior([channel, 4, 3], 2)
    p3 = UniTestPrior([channel, 1, 4], 3)

    Ppodd = source.HierarchyPrior(channel, length, [p1, p2, p3], repeat=1)

    x = Ppodd.sample(1)
    logp = Ppodd.logProbability(x)

    target = np.array([[3, 1, 2, 1, 3, 1, 2, 1],
                       [1, 1, 1, 1, 1, 1, 1, 1],
                       [2, 1, 2, 1, 2, 1, 2, 1],
                       [1, 1, 1, 1, 1, 1, 1, 1],
                       [3, 1, 2, 1, 3, 1, 2, 1],
                       [1, 1, 1, 1, 1, 1, 1, 1],
                       [2, 1, 2, 1, 2, 1, 2, 1],
                       [1, 1, 1, 1, 1, 1, 1, 1]])

    assert_allclose(x[0, 0].detach().numpy(), target)
    assert logp == -(16 * 3 * 2**1 + 4 * 3 * 2**2 + 1 * 4 * 2**3)
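
# Worked check (ours) of the asserted log-probability values above, assuming
# HierarchyPrior.logProbability sums -_energy over its sub-priors. Each
# UniTestPrior contributes 2**element times prod(z.shape[2:]) sites:
#   p1: 16 * 3 * 2**1 = 96
#   p2:  4 * 3 * 2**2 = 48
#   p3:  1 * 4 * 2**3 = 32
# giving logp = -(96 + 48 + 32) = -176 for both the repeat=2 and repeat=1 priors.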
def test_saveload():
    decimal = flow.ScalingNshifting(256, -128)

    layerList = []
    for i in range(4):
        f = torch.nn.Sequential(
            torch.nn.Conv2d(9, 9, 3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(9, 9, 1, padding=0),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(9, 3, 3, padding=1))
        layerList.append(f)

    meanNNlist = []
    scaleNNlist = []
    meanNNlist.append(torch.nn.Sequential(
        torch.nn.Conv2d(3, 9, 3, padding=1),
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(9, 9, 1, padding=0),
        torch.nn.ReLU(inplace=True)))
    scaleNNlist.append(torch.nn.Sequential(
        torch.nn.Conv2d(3, 9, 3, padding=1),
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(9, 9, 1, padding=0),
        torch.nn.ReLU(inplace=True)))

    t = flow.SimpleMERA(8, layerList, meanNNlist, scaleNNlist, 1, None, 5,
                        decimal, utils.roundingWidentityGradient)

    decimal = flow.ScalingNshifting(256, -128)

    layerList = []
    for i in range(4):
        f = torch.nn.Sequential(
            torch.nn.Conv2d(9, 9, 3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(9, 9, 1, padding=0),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(9, 3, 3, padding=1))
        layerList.append(f)

    meanNNlist = []
    scaleNNlist = []
    meanNNlist.append(torch.nn.Sequential(
        torch.nn.Conv2d(3, 9, 3, padding=1),
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(9, 9, 1, padding=0),
        torch.nn.ReLU(inplace=True)))
    scaleNNlist.append(torch.nn.Sequential(
        torch.nn.Conv2d(3, 9, 3, padding=1),
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(9, 9, 1, padding=0),
        torch.nn.ReLU(inplace=True)))

    tt = flow.SimpleMERA(8, layerList, meanNNlist, scaleNNlist, 1, None, 5,
                         decimal, utils.roundingWidentityGradient)

    samples = torch.randint(0, 255, (100, 3, 8, 8)).float()

    torch.save(t.save(), "testsaving.saving")
    tt.load(torch.load("testsaving.saving"))

    tzSamples, _ = t.inverse(samples)
    ttzSamples, _ = tt.inverse(samples)

    rcnSamples, _ = t.forward(tzSamples)
    ttrcnSamples, _ = tt.forward(ttzSamples)

    assert_allclose(tzSamples.detach().numpy(), ttzSamples.detach().numpy())
    assert_allclose(samples.detach().numpy(), rcnSamples.detach().numpy())
    assert_allclose(rcnSamples.detach().numpy(), ttrcnSamples.detach().numpy())