def test_bijective():
    decimal = flow.ScalingNshifting(256, -128)

    layerList = []
    for i in range(4 * 2):
        f = torch.nn.Sequential(
            torch.nn.Conv2d(9, 9, 3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(9, 9, 1, padding=0),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(9, 3, 3, padding=1))
        layerList.append(f)

    meanNNlist = []
    scaleNNlist = []
    meanNNlist.append(torch.nn.Sequential(
        torch.nn.Conv2d(3, 9, 3, padding=1),
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(9, 9, 1, padding=0),
        torch.nn.ReLU(inplace=True)))
    scaleNNlist.append(torch.nn.Sequential(
        torch.nn.Conv2d(3, 9, 3, padding=1),
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(9, 9, 1, padding=0),
        torch.nn.ReLU(inplace=True)))

    t = flow.SimpleMERA(8, layerList, meanNNlist, scaleNNlist, 2, None, 5,
                        decimal, utils.roundingWidentityGradient)

    samples = torch.randint(0, 255, (100, 3, 8, 8)).float()

    # Round trip: inverse followed by forward should reproduce the input exactly.
    zSamples, _ = t.inverse(samples)
    rcnSamples, _ = t.forward(zSamples)
    prob = t.logProbability(samples)

    assert_allclose(samples.detach().numpy(), rcnSamples.detach().numpy())

    # Test the depth argument
    t = flow.SimpleMERA(8, layerList, meanNNlist, scaleNNlist, 2, 2, 5,
                        decimal, utils.roundingWidentityGradient)

    samples = torch.randint(0, 255, (100, 3, 8, 8)).float()

    zSamples, _ = t.inverse(samples)
    rcnSamples, _ = t.forward(zSamples)
    # prob = t.logProbability(samples)

    assert_allclose(samples.detach().numpy(), rcnSamples.detach().numpy())
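# A possible refactor (hypothetical, not part of the original tests): the
# 3x3 -> ReLU -> 1x1 -> ReLU -> 3x3 conv stack above is repeated verbatim in
# several tests, so a small factory like this sketch could build it. The
# helper name is illustrative only; channel counts mirror the calls above.
def _convBlock(inChannel, hiddenChannel, outChannel):
    return torch.nn.Sequential(
        torch.nn.Conv2d(inChannel, hiddenChannel, 3, padding=1),
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(hiddenChannel, hiddenChannel, 1, padding=0),
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(hiddenChannel, outChannel, 3, padding=1))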
    # scaleNNlist.append(torch.nn.Sequential(
    #     torch.nn.Conv2d(3, hchnl, 3, padding=1), torch.nn.ReLU(inplace=True),
    #     torch.nn.Conv2d(hchnl, hchnl, 1, padding=0), torch.nn.ReLU(inplace=True),
    #     torch.nn.Conv2d(hchnl, 9, 3, padding=1)))
    torch.nn.init.zeros_(meanNNlist[-1][-1].weight)
    torch.nn.init.zeros_(meanNNlist[-1][-1].bias)
    torch.nn.init.zeros_(scaleNNlist[-1][-1].weight)
    torch.nn.init.zeros_(scaleNNlist[-1][-1].bias)
else:
    meanNNlist = None
    scaleNNlist = None

# Building MERA model
f = flow.SimpleMERA(blockLength, layerList, meanNNlist, scaleNNlist, repeat,
                    None, nMixing, decimal=decimal,
                    rounding=utils.roundingWidentityGradient, clamp=clamp,
                    sameDetail=diffDetail).to(device)


# Define plot function
def plotfn(f, train, test, LOSS, VALLOSS):
    # loss plot
    lossfig = plt.figure(figsize=(8, 5))
    lossax = lossfig.add_subplot(111)

    epoch = len(LOSS)

    lossax.plot(np.arange(epoch),
    torch.nn.init.zeros_(scaleNNlist[-1][-1].weight)
    torch.nn.init.zeros_(scaleNNlist[-1][-1].bias)
    # List replication repeats references, so every depth level reuses the
    # same mean/scale networks (and therefore shares their parameters).
    meanNNlist = meanNNlist * depth
    scaleNNlist = scaleNNlist * depth
else:
    meanNNlist = None
    scaleNNlist = None

# Building MERA model
f = flow.SimpleMERA(blockLength, layerList, meanNNlist, scaleNNlist, repeat,
                    None, nMixing, decimal=decimal,
                    rounding=utils.roundingWidentityGradient, clamp=clamp,
                    compatible=True).to(device)


# Define plot function
def plotfn(f, train, test, LOSS, VALLOSS):
    # loss plot
    lossfig = plt.figure(figsize=(8, 5))
    lossax = lossfig.add_subplot(111)

    epoch = len(LOSS)

    lossax.plot(np.arange(epoch),
def test_saveload():
    decimal = flow.ScalingNshifting(256, -128)

    layerList = []
    for i in range(4):
        f = torch.nn.Sequential(
            torch.nn.Conv2d(9, 9, 3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(9, 9, 1, padding=0),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(9, 3, 3, padding=1))
        layerList.append(f)

    meanNNlist = []
    scaleNNlist = []
    meanNNlist.append(torch.nn.Sequential(
        torch.nn.Conv2d(3, 9, 3, padding=1),
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(9, 9, 1, padding=0),
        torch.nn.ReLU(inplace=True)))
    scaleNNlist.append(torch.nn.Sequential(
        torch.nn.Conv2d(3, 9, 3, padding=1),
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(9, 9, 1, padding=0),
        torch.nn.ReLU(inplace=True)))

    t = flow.SimpleMERA(8, layerList, meanNNlist, scaleNNlist, 1, None, 5,
                        decimal, utils.roundingWidentityGradient)

    # Build a second, independently initialized model; load() must copy the
    # saved parameters for the two models to agree on the samples below.
    decimal = flow.ScalingNshifting(256, -128)

    layerList = []
    for i in range(4):
        f = torch.nn.Sequential(
            torch.nn.Conv2d(9, 9, 3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(9, 9, 1, padding=0),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(9, 3, 3, padding=1))
        layerList.append(f)

    meanNNlist = []
    scaleNNlist = []
    meanNNlist.append(torch.nn.Sequential(
        torch.nn.Conv2d(3, 9, 3, padding=1),
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(9, 9, 1, padding=0),
        torch.nn.ReLU(inplace=True)))
    scaleNNlist.append(torch.nn.Sequential(
        torch.nn.Conv2d(3, 9, 3, padding=1),
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(9, 9, 1, padding=0),
        torch.nn.ReLU(inplace=True)))

    tt = flow.SimpleMERA(8, layerList, meanNNlist, scaleNNlist, 1, None, 5,
                         decimal, utils.roundingWidentityGradient)

    samples = torch.randint(0, 255, (100, 3, 8, 8)).float()

    torch.save(t.save(), "testsaving.saving")
    tt.load(torch.load("testsaving.saving"))

    tzSamples, _ = t.inverse(samples)
    ttzSamples, _ = tt.inverse(samples)

    rcnSamples, _ = t.forward(tzSamples)
    ttrcnSamples, _ = tt.forward(ttzSamples)

    assert_allclose(tzSamples.detach().numpy(), ttzSamples.detach().numpy())
    assert_allclose(samples.detach().numpy(), rcnSamples.detach().numpy())
    assert_allclose(rcnSamples.detach().numpy(), ttrcnSamples.detach().numpy())
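# Optional variant (a sketch, not part of the original test): torch.save and
# torch.load also accept file-like objects, so the same save/load round trip
# can go through an in-memory buffer and leave no "testsaving.saving" file on
# disk. The helper name is illustrative; t and tt are models as built above.
def _saveload_roundtrip_in_memory(t, tt):
    import io
    buffer = io.BytesIO()
    torch.save(t.save(), buffer)
    buffer.seek(0)
    tt.load(torch.load(buffer))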
prior = f.prior
prior.depth = int(math.log(targetSize[-1], 2))

if 'simplePrior_False' not in name:
    prior.priorList = torch.nn.ModuleList(
        [prior.priorList[0] for _ in range(int(math.log(targetSize[-1], 2)) - 1)]
        + [prior.priorList[-1]])

# Building MERA model
if 'easyMera' in name:
    f = flow.SimpleMERA(blockLength, layerList, meanNNlist, scaleNNlist,
                        repeat, 1, nMixing, decimal=decimal,
                        rounding=utils.roundingWidentityGradient).to(device)
elif '1to2Mera' in name:
    f = flow.OneToTwoMERA(blockLength, layerList, meanNNlist, scaleNNlist,
                          repeat, 1, nMixing, decimal=decimal,
                          rounding=utils.roundingWidentityGradient).to(device)
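# Hypothetical sanity check (not in the original script): mirror the
# inverse/forward round trip used in the tests to confirm the rebuilt flow
# still reconstructs its input. Assumes targetSize is the (channel, height,
# width) shape used above and that f, device are defined as above.
_check = torch.randint(0, 255, (4, *targetSize)).float().to(device)
_z, _ = f.inverse(_check)
_recon, _ = f.forward(_z)
assert torch.allclose(_check, _recon)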