def main(args):
    """Build a RealNVP flow with MLP scale/translation networks, train it
    for ``args.num_epoch`` epochs, then run the final evaluation.

    ``args`` must provide: hidden_dim, hidden_size, num_layers, num_epoch,
    batch_size, log_train_step.
    """

    # Factories rather than single modules: RealNVP presumably instantiates
    # one fresh network per coupling layer — TODO confirm against RealNVP.
    def scale_network():
        # Tanh keeps the predicted log-scale bounded.
        return nn.Sequential(
            nn.Linear(args.hidden_dim, args.hidden_size), nn.LeakyReLU(),
            nn.Linear(args.hidden_size, args.hidden_size), nn.LeakyReLU(),
            nn.Linear(args.hidden_size, args.hidden_dim), nn.Tanh())

    def translation_network():
        return nn.Sequential(
            nn.Linear(args.hidden_dim, args.hidden_size), nn.LeakyReLU(),
            nn.Linear(args.hidden_size, args.hidden_size), nn.LeakyReLU(),
            nn.Linear(args.hidden_size, args.hidden_dim))

    # Alternating binary mask and its complement, stacked column-wise.
    masks = torch.tensor([1, 0] * args.num_layers, dtype=torch.float)
    masks = torch.stack((masks, 1 - masks), dim=1)
    # Standard-normal base distribution over the latent space.
    prior = torch.distributions.MultivariateNormal(
        torch.zeros(args.hidden_dim), torch.eye(args.hidden_dim))
    flow = RealNVP(scale_network, translation_network, masks, prior)
    # Optimize only trainable parameters (idiomatic truthiness test,
    # not `== True`).
    optimizer = optim.Adam(p for p in flow.parameters() if p.requires_grad)
    for epoch in range(args.num_epoch):
        loss = run_epoch(flow, optimizer, args.batch_size)
        if epoch % args.log_train_step == 0:
            logging.info(" Epoch: {} | Loss: {}".format(epoch, loss))
    test(flow)
def test_parallel():
    """DataParallel evaluation must agree with the single-device forward."""
    base = Gaussian([2, 4, 4])
    sample = base(3)
    structure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0], [1, 2, 1, 0]]
    scale_nets = [CNN(structure, inchannel=2) for _ in range(4)]
    trans_nets = [CNN(structure, inchannel=2) for _ in range(4)]
    flow = RealNVP([2, 4, 4], scale_nets, trans_nets, base)
    reference = flow(sample)
    print(reference)
    parallel = torch.nn.DataParallel(flow.cuda(0), device_ids=[0, 1])
    parallel_out = parallel(sample.cuda())
    print(parallel_out)
    assert_array_almost_equal(reference.data.numpy(),
                              parallel_out.cpu().data.numpy(),
                              decimal=5)
def test_sample():
    """Smoke-test RealNVP.sample in both modes plus logProbability."""
    base = Gaussian([2, 4, 4])
    # Drawn but unused in the original test; kept so the RNG state evolves
    # identically.
    x3d = base(3)
    structure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0], [1, 2, 1, 0]]
    scale_nets = [CNN(structure, inchannel=2) for _ in range(4)]
    trans_nets = [CNN(structure, inchannel=2) for _ in range(4)]
    flow = RealNVP([2, 4, 4], scale_nets, trans_nets, base, "checkerboard")
    z3d = flow.sample(100, True)
    zp3d = flow.sample(100, False)
    print(flow.logProbability(z3d))
def test_invertible_2d():
    """TEBD 2d round trip: inference inverts generate, and a model restored
    from disk reproduces the original generate output.

    Bug fixed: the restored model's layers were built from the ORIGINAL
    sList/tList/masktypelist instead of the freshly created primed copies
    (sListp/tListp/masktypelistp), so loadModel was never exercised on
    independently initialised weights.
    """
    #RNVP block
    Nlayers = 4
    Hs = 10
    Ht = 10
    sList = [MLPreshape(4, Hs) for _ in range(Nlayers)]
    tList = [MLPreshape(4, Ht) for _ in range(Nlayers)]
    masktypelist = ['channel', 'channel'] * (Nlayers // 2)
    #assamble RNVP blocks into a TEBD layer
    prior = Gaussian([4, 4])
    layers = [
        RealNVP([2, 2], sList, tList, Gaussian([2, 2]), masktypelist)
        for _ in range(4)
    ]
    print(layers[0])
    print(layers[0].generate(
        Variable(torch.FloatTensor([1, 2, 3, 4]).view(1, 2, 2))))
    model = TEBD(2, [2, 2], 4, layers, prior)
    z = model.prior(1)
    print("original")
    x = model.generate(z)
    print("Forward")
    zp = model.inference(x)
    print("Backward")
    assert_array_almost_equal(z.data.numpy(), zp.data.numpy())
    saveDict = model.saveModel({})
    torch.save(saveDict, './saveNet.testSave')
    # Fresh, independently initialised copy to restore the weights into.
    sListp = [MLPreshape(4, Hs) for _ in range(Nlayers)]
    tListp = [MLPreshape(4, Ht) for _ in range(Nlayers)]
    masktypelistp = ['channel', 'channel'] * (Nlayers // 2)
    #assamble RNVP blocks into a TEBD layer
    priorp = Gaussian([4, 4])
    layersp = [
        RealNVP([2, 2], sListp, tListp, Gaussian([2, 2]), masktypelistp)
        for _ in range(4)
    ]
    modelp = TEBD(2, [2, 2], 4, layersp, priorp)
    saveDictp = torch.load('./saveNet.testSave')
    modelp.loadModel(saveDictp)
    xp = modelp.generate(z)
    assert_array_almost_equal(xp.data.numpy(), x.data.numpy())
def test_invertible():
    """TEBD 1d round trip: generate/inference invert each other with
    cancelling log-Jacobians, and a restored model reproduces generate."""
    n_blocks, hidden_s, hidden_t = 4, 10, 10

    def build_model():
        # Creation order matches the original so RNG draws line up.
        s_nets = [MLP(2, hidden_s) for _ in range(n_blocks)]
        t_nets = [MLP(2, hidden_t) for _ in range(n_blocks)]
        mask_types = ['channel', 'channel'] * (n_blocks // 2)
        prior = Gaussian([8])
        blocks = [RealNVP([2], s_nets, t_nets, Gaussian([2]), mask_types)
                  for _ in range(4)]
        return TEBD(1, 2, 4, blocks, prior)

    model = build_model()
    z = model.prior(10)
    print("original")
    x = model.generate(z, ifLogjac=True)
    print("Forward")
    zp = model.inference(x, ifLogjac=True)
    print("Backward")
    assert_array_almost_equal(z.data.numpy(), zp.data.numpy())
    assert_array_almost_equal(model._generateLogjac.data.numpy(),
                              -model._inferenceLogjac.data.numpy())
    torch.save(model.saveModel({}), './saveNet.testSave')

    restored = build_model()
    restored.loadModel(torch.load('./saveNet.testSave'))
    xp = restored.generate(z)
    assert_array_almost_equal(xp.data.numpy(), x.data.numpy())
def test_invertible_2d():
    """2d MERA round trip: inference(generate(z)) recovers z (to 4 decimals),
    and a model restored from disk reproduces the original generate output.

    NOTE(review): another test_invertible_2d appears earlier in this source;
    if both live in one module the later definition shadows the earlier —
    confirm they belong to different test files.
    """
    Nlayers = 4
    Hs = 10
    Ht = 10
    sList = [MLPreshape(4, Hs) for _ in range(Nlayers)]
    tList = [MLPreshape(4, Ht) for _ in range(Nlayers)]
    masktypelist = ['channel', 'vchannel'] * (Nlayers // 2)
    #assamble RNVP blocks into a TEBD layer
    prior = Gaussian([8, 8])
    layers = [
        RealNVP([2, 2], sList, tList, Gaussian([2, 2]), masktypelist)
        for _ in range(6)
    ]
    print(layers[0].mask)
    #layers = [debugRealNVP() for _ in range(6)]
    model = MERA(2, [2, 2], 64, layers, prior)
    #z = prior(1)
    # Deterministic 8x8 input so the round trip is reproducible.
    z = Variable(torch.from_numpy(np.arange(64)).float().view(1, 8, 8))
    x = model.generate(z)
    zz = model.inference(x)
    print(zz)
    print(z)
    assert_array_almost_equal(
        z.data.numpy(), zz.data.numpy(),
        decimal=4)  # don't work for decimal >=5, maybe caz by float
    saveDict = model.saveModel({})
    torch.save(saveDict, './saveNet.testSave')
    # Fresh, independently initialised copy for the restore check.
    Nlayersp = 4
    Hsp = 10
    Htp = 10
    sListp = [MLPreshape(4, Hsp) for _ in range(Nlayersp)]
    tListp = [MLPreshape(4, Htp) for _ in range(Nlayersp)]
    masktypelistp = ['channel', 'vchannel'] * (Nlayersp // 2)
    #assamble RNVP blocks into a TEBD layer
    priorp = Gaussian([8, 8])
    layersp = [
        RealNVP([2, 2], sListp, tListp, Gaussian([2, 2]), masktypelistp)
        for _ in range(6)
    ]
    modelp = MERA(2, [2, 2], 64, layersp, priorp)
    saveDictp = torch.load('./saveNet.testSave')
    modelp.loadModel(saveDictp)
    xp = modelp.generate(z)
    assert_array_almost_equal(xp.data.numpy(), x.data.numpy())
def train_and_eval(epochs, lr, train_loader, test_loader, target_distribution):
    """Build a 4-layer RealNVP, train it for `epochs` epochs with Adam, and
    record train/test losses after every epoch.

    Returns (flow, train_losses, test_losses).
    """
    # Alternate the coupling parity across the four affine transforms.
    transforms = [AffineTransform2D(parity)
                  for parity in (True, False, True, False)]
    flow = RealNVP(transforms)
    optimizer = torch.optim.Adam(flow.parameters(), lr=lr)
    train_losses = []
    test_losses = []
    for _ in range(epochs):
        train(flow, train_loader, optimizer, target_distribution)
        train_losses.append(eval_loss(flow, train_loader, target_distribution))
        test_losses.append(eval_loss(flow, test_loader, target_distribution))
    return flow, train_losses, test_losses
def test_translationalinvariance_1d():
    """1d MERA log-probability is unchanged when samples are rolled by half
    the chain length in either direction."""
    depth = 2
    s_nets = [MLP(2, 10, activation=ScalableTanh([2])) for _ in range(depth)]
    t_nets = [MLP(2, 10, activation=ScalableTanh([2])) for _ in range(depth)]
    mask_types = ['evenodd', 'evenodd'] * (depth // 2)
    prior = Gaussian([8])
    bijectors = [RealNVP([2], s_nets, t_nets, Gaussian([2]), mask_types)
                 for _ in range(6)]
    model = MERA(1, 2, 8, bijectors, prior)
    x = model.sample(10)
    rolled_right = Roll(4, 1).forward(x)
    rolled_left = Roll(-4, 1).forward(x)
    logp = model.logProbability(x)
    for shifted in (rolled_left, rolled_right):
        assert_array_almost_equal(logp.data.numpy(),
                                  model.logProbability(shifted).data.numpy(),
                                  decimal=6)
def test_translationalinvariance():
    """TEBD log-probability is unchanged when samples are rolled by the
    kernel size in either direction."""
    depth = 8
    n_blocks = 2
    s_nets = [MLP(2, 10) for _ in range(n_blocks)]
    t_nets = [MLP(2, 10) for _ in range(n_blocks)]
    mask_types = ['channel', 'channel'] * (n_blocks // 2)
    prior = Gaussian([8])
    bijectors = [RealNVP([2], s_nets, t_nets, Gaussian([2]), mask_types)
                 for _ in range(depth)]
    model = TEBD(1, 2, depth, bijectors, prior)
    x = model.sample(10)
    rolled_right = Roll(2, 1).forward(x)
    rolled_left = Roll(-2, 1).forward(x)
    logp = model.logProbability(x)
    for shifted in (rolled_left, rolled_right):
        assert_array_almost_equal(logp.data.numpy(),
                                  model.logProbability(shifted).data.numpy(),
                                  decimal=6)
def test_invertible_1d():
    """1d MERA round trip: generate/inference invert each other with
    cancelling log-Jacobians, and a model restored from disk reproduces the
    original generate output."""
    Nlayers = 4
    Hs = 10
    Ht = 10
    sList = [MLP(2, Hs) for _ in range(Nlayers)]
    tList = [MLP(2, Ht) for _ in range(Nlayers)]
    masktypelist = ['channel', 'channel'] * (Nlayers // 2)
    #assamble RNVP blocks into a TEBD layer
    prior = Gaussian([8])
    layers = [
        RealNVP([2], sList, tList, Gaussian([2]), masktypelist)
        for _ in range(6)
    ]
    model = MERA(1, 2, 8, layers, prior)
    z = prior(4)
    x = model.generate(z, ifLogjac=True)
    zz = model.inference(x, ifLogjac=True)
    assert_array_almost_equal(z.data.numpy(), zz.data.numpy())
    print(model._generateLogjac)
    print(model._inferenceLogjac)
    # The forward and backward log-Jacobians must cancel exactly.
    assert_array_almost_equal(model._generateLogjac.data.numpy(),
                              -model._inferenceLogjac.data.numpy())
    saveDict = model.saveModel({})
    torch.save(saveDict, './saveNet.testSave')
    # Fresh, independently initialised copy for the restore check.
    Nlayersp = 4
    Hsp = 10
    Htp = 10
    sListp = [MLP(2, Hsp) for _ in range(Nlayersp)]
    tListp = [MLP(2, Htp) for _ in range(Nlayersp)]
    masktypelistp = ['channel', 'channel'] * (Nlayersp // 2)
    #assamble RNVP blocks into a TEBD layer
    priorp = Gaussian([8])
    layersp = [
        RealNVP([2], sListp, tListp, Gaussian([2]), masktypelistp)
        for _ in range(6)
    ]
    modelp = MERA(1, 2, 8, layersp, priorp)
    saveDictp = torch.load('./saveNet.testSave')
    modelp.loadModel(saveDictp)
    xp = modelp.generate(z)
    assert_array_almost_equal(xp.data.numpy(), x.data.numpy())
def test_tempalte_invertibleMLP():
    """_generate/_inference with channel masks are mutual inverses and their
    log-Jacobians cancel; _logProbability is smoke-tested."""
    print("test mlp")
    base = Gaussian([2])
    s_nets = [MLP(2, 10) for _ in range(4)]
    t_nets = [MLP(2, 10) for _ in range(4)]
    model = RealNVP([2], s_nets, t_nets, base)
    x = model.prior(10)
    mask = model.createMask(["channel"] * 4, ifByte=0)
    print("original")
    z = model._generate(x, model.mask, model.mask_, True)
    print("Forward")
    zp = model._inference(z, model.mask, model.mask_, True)
    print("Backward")
    assert_array_almost_equal(model._generateLogjac.data.numpy(),
                              -model._inferenceLogjac.data.numpy())
    print("logProbability")
    print(model._logProbability(z, model.mask, model.mask_))
    assert_array_almost_equal(x.data.numpy(), zp.data.numpy())
def test_tempalte_contractionCNN_checkerboard_cuda():
    """GPU: _generate/_inference with checkerboard masks invert each other
    and their log-Jacobians cancel."""
    gaussian3d = Gaussian([2, 4, 4])
    x3d = gaussian3d(3).cuda()
    # presumably [channel, filter_size, stride, padding] per CNN layer —
    # TODO confirm against CNN's constructor
    netStructure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0], [2, 2, 1, 0]]
    sList3d = [
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2)
    ]
    tList3d = [
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2)
    ]
    realNVP3d = RealNVP([2, 4, 4], sList3d, tList3d, gaussian3d)
    realNVP3d = realNVP3d.cuda()
    # cuda=0: presumably places the masks on GPU 0 — TODO confirm
    mask3d = realNVP3d.createMask(["checkerboard"] * 4, ifByte=0, cuda=0)
    z3d = realNVP3d._generate(x3d, realNVP3d.mask, realNVP3d.mask_, True)
    zp3d = realNVP3d._inference(z3d, realNVP3d.mask, realNVP3d.mask_, True)
    print(realNVP3d._logProbability(z3d, realNVP3d.mask, realNVP3d.mask_))
    assert_array_almost_equal(x3d.cpu().data.numpy(), zp3d.cpu().data.numpy())
    assert_array_almost_equal(realNVP3d._generateLogjac.data.cpu().numpy(),
                              -realNVP3d._inferenceLogjac.data.cpu().numpy())
def test_tempalte_contraction_mlp(): gaussian = Gaussian([2]) sList = [MLP(1, 10), MLP(1, 10), MLP(1, 10), MLP(1, 10)] tList = [MLP(1, 10), MLP(1, 10), MLP(1, 10), MLP(1, 10)] realNVP = RealNVP([2], sList, tList, gaussian) x = realNVP.prior(10) mask = realNVP.createMask(["channel"] * 4, ifByte=1) print("original") #print(x) z = realNVP._generateWithContraction(x, realNVP.mask, realNVP.mask_, 0, True) print("Forward") #print(z) zp = realNVP._inferenceWithContraction(z, realNVP.mask, realNVP.mask_, 0, True) print("Backward") #print(zp) assert_array_almost_equal(realNVP._generateLogjac.data.numpy(), -realNVP._inferenceLogjac.data.numpy()) x_data = realNVP.prior(10) y_data = realNVP.prior.logProbability(x_data) print("logProbability") '''
def test_slice_cudaNo0():
    """Slice-based generate/inference round trip on a non-default GPU
    (device 2); log-Jacobians must cancel."""
    base = Gaussian([2, 4, 4])
    sample = base(3).cuda(2)
    structure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0], [1, 2, 1, 0]]
    scale_nets = [CNN(structure) for _ in range(4)]
    trans_nets = [CNN(structure) for _ in range(4)]
    flow = RealNVP([2, 4, 4], scale_nets, trans_nets, base)
    flow = flow.cuda(2)
    z = flow._generateWithSlice(sample, 0, True)
    print(flow._logProbabilityWithSlice(z, 0))
    recovered = flow._inferenceWithSlice(z, 0, True)
    assert_array_almost_equal(sample.cpu().data.numpy(),
                              recovered.cpu().data.numpy())
    assert_array_almost_equal(flow._generateLogjac.data.cpu().numpy(),
                              -flow._inferenceLogjac.data.cpu().numpy())
def test_logProbabilityWithInference_cuda():
    """GPU: generate/inference round trip with checkerboard masks;
    logProbabilityWithInference is smoke-tested."""
    base = Gaussian([2, 4, 4])
    sample = base(3).cuda()
    structure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0], [1, 2, 1, 0]]
    scale_nets = [CNN(structure, inchannel=2) for _ in range(4)]
    trans_nets = [CNN(structure, inchannel=2) for _ in range(4)]
    flow = RealNVP([2, 4, 4], scale_nets, trans_nets, base).cuda()
    mask3d = flow.createMask(["checkerboard"] * 4, cuda=0)
    z3d = flow.generate(sample)
    recovered = flow.inference(z3d)
    print(flow.logProbabilityWithInference(z3d)[1])
    assert_array_almost_equal(sample.cpu().data.numpy(),
                              recovered.cpu().data.numpy())
def test_multiplyMask_generateWithSlice_CNN():
    """Slice-based round trip with alternating channel/checkerboard byte
    masks; log-Jacobians must cancel."""
    gaussian3d = Gaussian([2, 4, 4])
    x = gaussian3d(3)
    #z3dp = z3d[:,0,:,:].view(10,-1,4,4)
    #print(z3dp)
    #print(x)
    netStructure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0],
                    [1, 2, 1, 0]]  # [channel, filter_size, stride, padding]
    sList3d = [
        CNN(netStructure),
        CNN(netStructure),
        CNN(netStructure),
        CNN(netStructure)
    ]
    tList3d = [
        CNN(netStructure),
        CNN(netStructure),
        CNN(netStructure),
        CNN(netStructure)
    ]
    realNVP = RealNVP([2, 4, 4], sList3d, tList3d, gaussian3d)
    # ifByte=1: presumably byte-typed masks — TODO confirm createMask
    mask = realNVP.createMask(
        ["channel", "checkerboard", "channel", "checkerboard"], ifByte=1)
    z = realNVP._generateWithSlice(x, 0, True)
    #print(z)
    zz = realNVP._inferenceWithSlice(z, 0, True)
    #print(zz)
    assert_array_almost_equal(x.data.numpy(), zz.data.numpy())
    #print(realNVP._generateLogjac.data.numpy())
    #print(realNVP._inferenceLogjac.data.numpy())
    assert_array_almost_equal(realNVP._generateLogjac.data.numpy(),
                              -realNVP._inferenceLogjac.data.numpy())
def test_template_contraction_function_with_channel():
    """Contraction-based generate/inference with channel masks invert each
    other; log-Jacobians must cancel."""
    base = Gaussian([2, 4, 4])
    sample = base(3)
    # [channel, filter_size, stride, padding] per CNN layer
    structure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0], [1, 2, 1, 0]]
    scale_nets = [CNN(structure, inchannel=2) for _ in range(4)]
    trans_nets = [CNN(structure, inchannel=2) for _ in range(4)]
    flow = RealNVP([2, 4, 4], scale_nets, trans_nets, base)
    mask = flow.createMask(["channel"] * 4, 1)
    z = flow._generateWithContraction(sample, flow.mask, flow.mask_, 2, True)
    recovered = flow._inferenceWithContraction(z, flow.mask, flow.mask_, 2,
                                               True)
    assert_array_almost_equal(sample.data.numpy(), recovered.data.numpy())
    assert_array_almost_equal(flow._generateLogjac.data.numpy(),
                              -flow._inferenceLogjac.data.numpy())
def test_template_slice_function():
    """Slice-based generate/inference invert each other; log-Jacobians must
    cancel."""
    base = Gaussian([2, 4, 4])
    sample = base(3)
    # [channel, filter_size, stride, padding] per CNN layer
    structure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0], [1, 2, 1, 0]]
    scale_nets = [CNN(structure) for _ in range(4)]
    trans_nets = [CNN(structure) for _ in range(4)]
    flow = RealNVP([2, 4, 4], scale_nets, trans_nets, base)
    z = flow._generateWithSlice(sample, 0, True)
    recovered = flow._inferenceWithSlice(z, 0, True)
    assert_array_almost_equal(sample.data.numpy(), recovered.data.numpy())
    assert_array_almost_equal(flow._generateLogjac.data.numpy(),
                              -flow._inferenceLogjac.data.numpy())
def test_forward():
    """forward() returns a per-sample result of shape [batch] by default;
    with pointer set to "generate" it returns configuration-shaped output."""
    base = Gaussian([2, 4, 4])
    sample = base(3)
    structure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0], [1, 2, 1, 0]]
    scale_nets = [CNN(structure, inchannel=2) for _ in range(4)]
    trans_nets = [CNN(structure, inchannel=2) for _ in range(4)]
    flow = RealNVP([2, 4, 4], scale_nets, trans_nets, base)
    out = flow(sample)
    assert list(out.data.shape) == [3]
    flow.pointer = "generate"
    out = flow(sample)
    assert list(out.data.shape) == [3, 2, 4, 4]
def train(param, x, y):
    """Fit a RealNVP flow to samples `x` by maximum likelihood and save the
    trained weights to 'flow_model.pytorch'.

    param must provide batch_size, lr, total_it. `y` is unused but kept for
    interface compatibility with callers.

    Fixes: idiomatic `p.requires_grad` (not `== True`), drop the unused
    enumerate index, and print `loss.item()` instead of a live tensor.
    """
    dim_in = x.shape[1]
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(device)
    dataloader = DataLoader(torch.from_numpy(x.astype(np.float32)),
                            batch_size=param.batch_size,
                            shuffle=True,
                            num_workers=2)
    flow = RealNVP(dim_in, device)
    flow.to(device)
    flow.train()
    optimizer = torch.optim.Adam(
        [p for p in flow.parameters() if p.requires_grad], lr=param.lr)
    it, print_cnt = 0, 0
    while it < param.total_it:
        for data in dataloader:
            # Negative log-likelihood under the flow.
            loss = -flow.log_prob(data.to(device)).mean()
            optimizer.zero_grad()
            # retain_graph kept from the original — presumably required by
            # this RealNVP implementation; TODO confirm it is needed.
            loss.backward(retain_graph=True)
            optimizer.step()
            it += data.shape[0]
            print_cnt += data.shape[0]
            if print_cnt > PRINT_FREQ:
                print('it {:d} -- loss {:.03f}'.format(it, loss.item()))
                print_cnt = 0
    torch.save(flow.state_dict(), 'flow_model.pytorch')
def test_invertible_2d_metaDepth3():
    """2d MERA with metaDepth=3: inference recovers the input of generate
    to 4 decimals (single precision limits accuracy beyond that)."""
    n_blocks = 4
    s_nets = [MLPreshape(4, 10) for _ in range(n_blocks)]
    t_nets = [MLPreshape(4, 10) for _ in range(n_blocks)]
    mask_types = ['channel', 'channel'] * (n_blocks // 2)
    prior = Gaussian([8, 8])
    bijectors = [RealNVP([2, 2], s_nets, t_nets, Gaussian([2, 2]), mask_types)
                 for _ in range(9)]
    model = MERA(2, [2, 2], 64, bijectors, prior, metaDepth=3)
    # Deterministic 8x8 input.
    z = Variable(torch.from_numpy(np.arange(64)).float().view(1, 8, 8))
    x = model.generate(z)
    recovered = model.inference(x)
    assert_array_almost_equal(z.data.numpy(), recovered.data.numpy(),
                              decimal=4)
def test_contraction_cuda_withDifferentMasks():
    """GPU: contraction-based generate/inference with alternating
    channel/checkerboard masks invert each other; log-Jacobians cancel."""
    gaussian3d = Gaussian([2, 4, 4])
    x = gaussian3d(3).cuda()
    #z3dp = z3d[:,0,:,:].view(10,-1,4,4)
    #print(z3dp)
    #print(x)
    netStructure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0],
                    [1, 2, 1, 0]]  # [channel, filter_size, stride, padding]
    sList3d = [
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2)
    ]
    tList3d = [
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2)
    ]
    realNVP = RealNVP([2, 4, 4], sList3d, tList3d, gaussian3d)
    realNVP = realNVP.cuda()
    # cuda=0: presumably places the masks on GPU 0 — TODO confirm
    mask = realNVP.createMask(
        ["channel", "checkerboard", "channel", "checkerboard"], 1, cuda=0)
    z = realNVP._generateWithContraction(x, realNVP.mask, realNVP.mask_, 2,
                                         True)
    print(
        realNVP._logProbabilityWithContraction(z, realNVP.mask, realNVP.mask_,
                                               2))
    zz = realNVP._inferenceWithContraction(z, realNVP.mask, realNVP.mask_, 2,
                                           True)
    assert_array_almost_equal(x.cpu().data.numpy(), zz.cpu().data.numpy())
    assert_array_almost_equal(realNVP._generateLogjac.data.cpu().numpy(),
                              -realNVP._inferenceLogjac.data.cpu().numpy())
# Script fragment: rebuild a trained MERA model and generate samples for
# plotting. NOTE(review): relies on names defined elsewhere in this script
# (Nvars, mlpsize, Hs, Ht, Nlayers, nperdepth, L, kernel_size, d,
# Ndisentangler, args) — confirm against the full file.
depth = int(math.log(Nvars, mlpsize))
print('depth of the mera network', depth)
# One list of s/t networks per RealNVP block, nperdepth blocks per MERA level.
sList = [[
    MLPreshape(mlpsize, Hs, activation=ScalableTanh([mlpsize]))
    for _ in range(Nlayers)
] for l in range(nperdepth * depth)]
tList = [[MLPreshape(mlpsize, Ht) for _ in range(Nlayers)]
         for l in range(nperdepth * depth)]
masktypelist = ['channel', 'channel'] * (Nlayers // 2)
prior = Gaussian([L, L])
#assamble RNVP blocks into a MERA
layers = [
    RealNVP(kernel_size, sList[l], tList[l], None, masktypelist)
    for l in range(nperdepth * depth)
]
model = MERA(d, kernel_size, Nvars, layers, prior,
             metaDepth=Ndisentangler + 1)
# Restore trained weights before sampling.
model.loadModel(torch.load(args.modelname))
z = prior(args.batch)
# save=True: presumably records intermediate configurations in model.saving
# — TODO confirm against MERA.generate
x = model.generate(z, save=True)
N = len(model.saving) // (Ndisentangler + 1)
import matplotlib.pyplot as plt
from matplotlib import cm
# Script fragment: build a MERA model from args, optionally restore saved
# weights, and move it to GPU when requested. NOTE(review): relies on names
# defined elsewhere (Nvars, mlpsize, nperdepth, kernel_size, prior, key,
# args) — confirm against the full file.
depth = int(math.log(Nvars, mlpsize))
print('depth of the mera network', depth)
sList = [[
    MLPreshape(mlpsize, args.Hs, activation=ScalableTanh([mlpsize]))
    for _ in range(args.Nlayers)
] for l in range(nperdepth * depth)]
tList = [[MLPreshape(mlpsize, args.Ht) for _ in range(args.Nlayers)]
         for l in range(nperdepth * depth)]
masktypelist = ['channel', 'channel'] * (args.Nlayers // 2)
#assamble RNVP blocks into a MERA
layers = [
    RealNVP(kernel_size, sList[l], tList[l], None, masktypelist)
    for l in range(nperdepth * depth)
]
model = MERA(args.d, kernel_size, Nvars, layers, prior,
             metaDepth=args.Ndisentangler + 1, name=key)
if args.modelname is not None:
    # Best-effort restore: a missing checkpoint is reported, not fatal.
    try:
        model.loadModel(torch.load(args.modelname))
        print('#load model', args.modelname)
    except FileNotFoundError:
        print('model file not found:', args.modelname)
if args.cuda:
    model = model.cuda()
    print("moving model to GPU")
def test_invertible():
    """RealNVP on [2]: inference inverts generate, and a model restored from
    a save file reproduces the original generate output.

    Bug fixed: the post-restore check called realNVP.generate(z) — the
    ORIGINAL model — which passes trivially; it now calls the restored
    model realNVPp so loadModel is actually exercised.
    """
    print("test realNVP")
    gaussian = Gaussian([2])
    sList = [MLP(2, 10), MLP(2, 10), MLP(2, 10), MLP(2, 10)]
    tList = [MLP(2, 10), MLP(2, 10), MLP(2, 10), MLP(2, 10)]
    realNVP = RealNVP([2], sList, tList, gaussian)
    print(realNVP.mask)
    print(realNVP.mask_)
    z = realNVP.prior(10)
    #mask = realNVP.createMask()
    # Default mask layout: 4 coupling layers over 2 variables.
    assert realNVP.mask.shape[0] == 4
    assert realNVP.mask.shape[1] == 2
    print("original")
    x = realNVP.generate(z)
    print("Forward")
    zp = realNVP.inference(x)
    print("Backward")
    assert_array_almost_equal(z.data.numpy(), zp.data.numpy())
    saveDict = realNVP.saveModel({})
    torch.save(saveDict, './saveNet.testSave')
    # Fresh, independently initialised copy to restore the weights into.
    sListp = [MLP(2, 10), MLP(2, 10), MLP(2, 10), MLP(2, 10)]
    tListp = [MLP(2, 10), MLP(2, 10), MLP(2, 10), MLP(2, 10)]
    realNVPp = RealNVP([2], sListp, tListp, gaussian)
    saveDictp = torch.load('./saveNet.testSave')
    realNVPp.loadModel(saveDictp)
    xx = realNVPp.generate(z)  # use the RESTORED model, not the original
    print("Forward after restore")
    assert_array_almost_equal(xx.data.numpy(), x.data.numpy())
def train_and_eval(flow, epochs, lr, train_loader, test_loader,
                   target_distribution):
    """Train `flow` for `epochs` epochs with Adam (learning rate `lr`) and
    record train/test losses after every epoch.

    Returns (flow, train_losses, test_losses).

    NOTE(review): another train_and_eval with a different signature appears
    earlier in this source — confirm they belong to different files.
    """
    print('no of parameters is',
          sum(param.numel() for param in flow.parameters()))
    optimizer = torch.optim.Adam(flow.parameters(), lr=lr)
    train_losses, test_losses = [], []
    for epoch in range(epochs):
        print('Starting epoch:', epoch + 1, 'of', epochs)
        train(flow, train_loader, optimizer, target_distribution)
        train_losses.append(eval_loss(flow, train_loader,
                                      target_distribution))
        test_losses.append(eval_loss(flow, test_loader, target_distribution))
    return flow, train_losses, test_losses


if __name__ == '__main__':
    print('Device is:', device)
    from torch.distributions.normal import Normal
    import numpy as np
    flow = RealNVP(INPUT_H, INPUT_W).to(device)
    # Standard-normal target for the flow's latent space.
    target_distribution = Normal(
        torch.tensor(0).float().to(device),
        torch.tensor(1).float().to(device))
    flow, train_losses, test_losses = train_and_eval(flow, 100, 5e-4,
                                                     train_loader,
                                                     test_loader,
                                                     target_distribution)
    print('train losses are', train_losses)
    print('test losses are', test_losses)
    torch.save(flow.state_dict(), 'trained_weights.pt')
def test_3d():
    """RealNVP on a [2,4,4] configuration: generate/inference round trip,
    logProbability smoke test, and save/load reproducibility."""
    gaussian3d = Gaussian([2, 4, 4])
    x3d = gaussian3d(3)
    #z3dp = z3d[:,0,:,:].view(10,-1,4,4)
    #print(z3dp)
    #print(x)
    netStructure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0],
                    [1, 2, 1, 0]]  # [channel, filter_size, stride, padding]
    sList3d = [
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2)
    ]
    tList3d = [
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2)
    ]
    realNVP3d = RealNVP([2, 4, 4], sList3d, tList3d,
                        gaussian3d)  #,maskType = "checkerboard")
    print(realNVP3d.mask)
    #mask3d = realNVP3d.createMask()
    # Default mask layout: 4 coupling layers over the [2,4,4] lattice.
    assert realNVP3d.mask.shape[0] == 4
    assert realNVP3d.mask.shape[1] == 2
    assert realNVP3d.mask.shape[2] == 4
    assert realNVP3d.mask.shape[3] == 4
    print("test high dims")
    print("Testing 3d")
    print("3d original:")
    #print(x3d)
    z3d = realNVP3d.generate(x3d)
    print("3d forward:")
    #print(z3d)
    zp3d = realNVP3d.inference(z3d)
    print("Backward")
    #print(zp3d)
    print("3d logProbability")
    print(realNVP3d.logProbability(z3d))
    saveDict3d = realNVP3d.saveModel({})
    torch.save(saveDict3d, './saveNet3d.testSave')
    # realNVP.loadModel({})
    # Fresh, independently initialised copy for the restore check.
    sListp3d = [
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2)
    ]
    tListp3d = [
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2)
    ]
    realNVPp3d = RealNVP([2, 4, 4], sListp3d, tListp3d, gaussian3d)
    saveDictp3d = torch.load('./saveNet3d.testSave')
    realNVPp3d.loadModel(saveDictp3d)
    zz3d = realNVPp3d.generate(x3d)
    print("3d Forward after restore")
    #print(zz3d)
    assert_array_almost_equal(x3d.data.numpy(), zp3d.data.numpy())
    assert_array_almost_equal(zz3d.data.numpy(), z3d.data.numpy())
kwargs = {'num_workers': 1, 'pin_memory': True} if device == 'cuda' else {} train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, **kwargs) # Set figures plt.subplots(nrows=2, ncols=2) plt.subplots_adjust(hspace=0.5, wspace=0.3) plt.subplot(2, 2, 1) plt.scatter(input[:, 0], input[:, 1], c='b', s=10) plt.title("INPUT: x ~ p(x)") # Set model mask = torch.tensor([0.0, 1.0]) model = RealNVP(INPUT_DIM, OUTPUT_DIM, HID_DIM, mask, 8) optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) log_gaussian = torch.distributions.MultivariateNormal(torch.zeros(2), torch.eye(2)) def train(args): "Forward flow data for construction normalization distribution" model.train() for epoch in range(args.epochs): train_loss = 0.0 for i, data in enumerate(train_loader): optimizer.zero_grad() z, log_det_sum = model(data)
def test_mera_1d():
    """HierarchyBijector assembled from Roll/Mask/Placeholder layers:
    inference and generate invert each other, their log-Jacobians cancel,
    and a model restored from disk reproduces the original inference
    output."""
    # Byte masks select which sites a RealNVP block acts on at each level.
    masks = [
        Variable(torch.ByteTensor([1, 0, 1, 0, 1, 0, 1, 0])),
        Variable(torch.ByteTensor([1, 0, 0, 0, 1, 0, 0, 0]))
    ]
    # Complements of the masks above.
    masks_ = [
        Variable(torch.ByteTensor([0, 1, 0, 1, 0, 1, 0, 1])),
        Variable(torch.ByteTensor([0, 1, 1, 1, 0, 1, 1, 1]))
    ]
    # Alternate identity (Placeholder) and a roll by one site.
    rollList = [
        Placeholder(),
        Roll(1, 1),
        Placeholder(),
        Roll(1, 1),
        Placeholder(),
        Roll(1, 1)
    ]
    maskList = [
        Placeholder(2),
        Placeholder(2),
        Mask(masks[0], masks_[0]),
        Mask(masks[0], masks_[0]),
        Mask(masks[1], masks_[1]),
        Mask(masks[1], masks_[1])
    ]
    Nlayers = 4
    Hs = 10
    Ht = 10
    sList = [MLP(2, Hs) for _ in range(Nlayers)]
    tList = [MLP(2, Ht) for _ in range(Nlayers)]
    masktypelist = ['channel', 'channel'] * (Nlayers // 2)
    #assamble RNVP blocks into a TEBD layer
    prior = Gaussian([8])
    layers = [
        RealNVP([2], sList, tList, Gaussian([2]), masktypelist)
        for _ in range(6)
    ]
    model = HierarchyBijector(1, [2 for _ in range(6)], rollList, layers,
                              maskList, Gaussian([8]))
    z = prior(4)
    print(z)
    x = model.inference(z, True)
    print(x)
    fLog = model._inferenceLogjac
    print(model._inferenceLogjac)
    zz = model.generate(x, True)
    print(zz)
    bLog = model._generateLogjac
    print(model._generateLogjac)
    print(model.sample(10))
    saveDict = model.saveModel({})
    torch.save(saveDict, './savetest.testSave')
    # Fresh, independently initialised copy for the restore check.
    masksp = [
        Variable(torch.ByteTensor([1, 0, 1, 0, 1, 0, 1, 0])),
        Variable(torch.ByteTensor([1, 0, 0, 0, 1, 0, 0, 0]))
    ]
    masks_p = [
        Variable(torch.ByteTensor([0, 1, 0, 1, 0, 1, 0, 1])),
        Variable(torch.ByteTensor([0, 1, 1, 1, 0, 1, 1, 1]))
    ]
    rollListp = [
        Placeholder(),
        Roll(1, 1),
        Placeholder(),
        Roll(1, 1),
        Placeholder(),
        Roll(1, 1)
    ]
    maskListp = [
        Placeholder(2),
        Placeholder(2),
        Mask(masksp[0], masks_p[0]),
        Mask(masksp[0], masks_p[0]),
        Mask(masksp[1], masks_p[1]),
        Mask(masksp[1], masks_p[1])
    ]
    Nlayersp = 4
    Hsp = 10
    Htp = 10
    sListp = [MLP(2, Hsp) for _ in range(Nlayersp)]
    tListp = [MLP(2, Htp) for _ in range(Nlayersp)]
    masktypelistp = ['channel', 'channel'] * (Nlayersp // 2)
    #assamble RNVP blocks into a TEBD layer
    priorp = Gaussian([8])
    layersp = [
        RealNVP([2], sListp, tListp, Gaussian([2]), masktypelistp)
        for _ in range(6)
    ]
    modelp = HierarchyBijector(1, [2 for _ in range(6)], rollListp, layersp,
                               maskListp, Gaussian([8]))
    saveDictp = torch.load('./savetest.testSave')
    modelp.loadModel(saveDictp)
    xp = modelp.inference(z)
    print(xp)
    assert_array_almost_equal(z.data.numpy(), zz.data.numpy())
    assert_array_almost_equal(fLog.data.numpy(), -bLog.data.numpy())
    assert_array_almost_equal(xp.data.numpy(), x.data.numpy())
def test_tempalte_invertibleCNN():
    """CNN-based RealNVP with channel masks: _generate/_inference round
    trip, cancelling log-Jacobians, and save/load reproducibility."""
    gaussian3d = Gaussian([2, 4, 4])
    x3d = gaussian3d(3)
    #z3dp = z3d[:,0,:,:].view(10,-1,4,4)
    #print(z3dp)
    netStructure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0],
                    [2, 2, 1, 0]]  # [channel, filter_size, stride, padding]
    sList3d = [
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2)
    ]
    tList3d = [
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2)
    ]
    realNVP3d = RealNVP([2, 4, 4], sList3d, tList3d, gaussian3d)
    mask3d = realNVP3d.createMask(["channel"] * 4, ifByte=0)
    print("Testing 3d")
    print("3d original:")
    #print(x3d)
    z3d = realNVP3d._generate(x3d, realNVP3d.mask, realNVP3d.mask_, True)
    print("3d forward:")
    #print(z3d)
    zp3d = realNVP3d._inference(z3d, realNVP3d.mask, realNVP3d.mask_, True)
    print("Backward")
    #print(zp3d)
    # Forward and backward log-Jacobians must cancel exactly.
    assert_array_almost_equal(realNVP3d._generateLogjac.data.numpy(),
                              -realNVP3d._inferenceLogjac.data.numpy())
    print("3d logProbability")
    print(realNVP3d._logProbability(z3d, realNVP3d.mask, realNVP3d.mask_))
    saveDict3d = realNVP3d.saveModel({})
    torch.save(saveDict3d, './saveNet3d.testSave')
    # realNVP.loadModel({})
    # Fresh, independently initialised copy for the restore check.
    sListp3d = [
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2)
    ]
    tListp3d = [
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2),
        CNN(netStructure, inchannel=2)
    ]
    realNVPp3d = RealNVP([2, 4, 4], sListp3d, tListp3d, gaussian3d)
    saveDictp3d = torch.load('./saveNet3d.testSave')
    realNVPp3d.loadModel(saveDictp3d)
    zz3d = realNVPp3d._generate(x3d, realNVPp3d.mask, realNVPp3d.mask_)
    print("3d Forward after restore")
    #print(zz3d)
    assert_array_almost_equal(x3d.data.numpy(), zp3d.data.numpy())
    assert_array_almost_equal(zz3d.data.numpy(), z3d.data.numpy())