def test_workmode2():
    """Round-trip a ``mode=2`` RealNVP (``sliceDim=0``) and check save/load.

    Builds a 2-dimensional RealNVP with four coupling layers whose s- and
    t-networks are ``MLP(1, 10)``. It checks that ``inference`` inverts
    ``generate``. It then saves the parameters, loads them into a freshly
    constructed network of the same configuration, and checks that the
    restored network reproduces the original forward output.
    """
    import os
    import tempfile

    gaussian = Gaussian([2])
    sList = [MLP(1, 10) for _ in range(4)]
    tList = [MLP(1, 10) for _ in range(4)]
    realNVP = RealNVP([2], sList, tList, gaussian, mode=2)

    z = realNVP.prior(10)
    x = realNVP.generate(z, sliceDim=0)
    zp = realNVP.inference(x, sliceDim=0)
    assert_array_almost_equal(z.data.numpy(), zp.data.numpy())

    # Write the checkpoint into a throwaway directory. This leaves no file in
    # the cwd and avoids sharing a path with other tests.
    with tempfile.TemporaryDirectory() as tmpDir:
        savePath = os.path.join(tmpDir, 'saveNet.testSave')
        torch.save(realNVP.saveModel({}), savePath)
        sListp = [MLP(1, 10) for _ in range(4)]
        tListp = [MLP(1, 10) for _ in range(4)]
        # Build the restored network with the same mode as the one saved.
        realNVPp = RealNVP([2], sListp, tListp, gaussian, mode=2)
        realNVPp.loadModel(torch.load(savePath))

    # The forward pass must use the *restored* network. Calling
    # realNVP.generate here would compare the original network to itself.
    xx = realNVPp.generate(z, sliceDim=0)
    print("Forward after restore")
    assert_array_almost_equal(xx.data.numpy(), x.data.numpy())
def test_logProbabilityWithInference_cuda():
    """Check on the GPU that ``inference`` inverts ``generate`` for a 3d RealNVP.

    The test is skipped when CUDA is unavailable. It also prints the second
    element returned by ``logProbabilityWithInference`` but makes no
    assertion about it.
    """
    import unittest

    if not torch.cuda.is_available():
        # unittest.SkipTest is reported as a skip by both pytest and
        # unittest, so CPU-only machines skip instead of erroring on .cuda().
        raise unittest.SkipTest("CUDA is not available")

    gaussian3d = Gaussian([2, 4, 4])
    x3d = gaussian3d(3).cuda()
    # Each entry is [channel, filter_size, stride, padding].
    netStructure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0], [1, 2, 1, 0]]
    sList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
    tList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
    realNVP3d = RealNVP([2, 4, 4], sList3d, tList3d, gaussian3d).cuda()
    # The return value was previously bound to an unused name. The call is
    # kept because it may reconfigure the network's masks (TODO confirm).
    realNVP3d.createMask(["checkerboard"] * 4, cuda=0)

    z3d = realNVP3d.generate(x3d)
    zp3d = realNVP3d.inference(z3d)
    print(realNVP3d.logProbabilityWithInference(z3d)[1])
    assert_array_almost_equal(x3d.cpu().data.numpy(), zp3d.cpu().data.numpy())
def test_invertible():
    """Check that a default-mode RealNVP is invertible and survives save/load.

    Builds a 2-dimensional RealNVP with four coupling layers whose s- and
    t-networks are ``MLP(2, 10)``. It checks the mask shape and that
    ``inference`` inverts ``generate``. It then saves the parameters, loads
    them into a freshly constructed network, and checks that the restored
    network reproduces the original forward output.
    """
    import os
    import tempfile

    print("test realNVP")
    gaussian = Gaussian([2])
    sList = [MLP(2, 10) for _ in range(4)]
    tList = [MLP(2, 10) for _ in range(4)]
    realNVP = RealNVP([2], sList, tList, gaussian)
    print(realNVP.mask)
    print(realNVP.mask_)

    z = realNVP.prior(10)
    # Expected mask shape starts with (4, 2). Presumably this is one mask per
    # coupling layer over the 2 input dimensions.
    assert realNVP.mask.shape[0] == 4
    assert realNVP.mask.shape[1] == 2

    print("original")
    x = realNVP.generate(z)
    print("Forward")
    zp = realNVP.inference(x)
    print("Backward")
    assert_array_almost_equal(z.data.numpy(), zp.data.numpy())

    # Write the checkpoint into a throwaway directory. This leaves no file in
    # the cwd and avoids sharing a path with other tests.
    with tempfile.TemporaryDirectory() as tmpDir:
        savePath = os.path.join(tmpDir, 'saveNet.testSave')
        torch.save(realNVP.saveModel({}), savePath)
        sListp = [MLP(2, 10) for _ in range(4)]
        tListp = [MLP(2, 10) for _ in range(4)]
        realNVPp = RealNVP([2], sListp, tListp, gaussian)
        realNVPp.loadModel(torch.load(savePath))

    # The forward pass must use the *restored* network. Calling
    # realNVP.generate here would compare the original network to itself.
    xx = realNVPp.generate(z)
    print("Forward after restore")
    assert_array_almost_equal(xx.data.numpy(), x.data.numpy())
def test_3d():
    """Check invertibility and save/load of a CNN-based RealNVP on [2, 4, 4] input.

    Builds a RealNVP with four coupling layers whose s- and t-networks are
    CNNs. It checks the mask shape and that ``inference`` inverts
    ``generate``, and prints ``logProbability``. It then saves the
    parameters, loads them into a freshly constructed network, and checks
    that the restored network reproduces the original forward output.
    """
    import os
    import tempfile

    gaussian3d = Gaussian([2, 4, 4])
    x3d = gaussian3d(3)
    # Each entry is [channel, filter_size, stride, padding].
    netStructure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0], [1, 2, 1, 0]]
    sList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
    tList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
    realNVP3d = RealNVP([2, 4, 4], sList3d, tList3d, gaussian3d)
    print(realNVP3d.mask)
    # Expected mask shape: 4 coupling layers x the [2, 4, 4] input shape.
    assert realNVP3d.mask.shape[0] == 4
    assert realNVP3d.mask.shape[1] == 2
    assert realNVP3d.mask.shape[2] == 4
    assert realNVP3d.mask.shape[3] == 4

    print("test high dims")
    print("Testing 3d")
    print("3d original:")
    z3d = realNVP3d.generate(x3d)
    print("3d forward:")
    zp3d = realNVP3d.inference(z3d)
    print("Backward")
    assert_array_almost_equal(x3d.data.numpy(), zp3d.data.numpy())
    print("3d logProbability")
    print(realNVP3d.logProbability(z3d))

    # Write the checkpoint into a throwaway directory so the test leaves no
    # file in the cwd.
    with tempfile.TemporaryDirectory() as tmpDir:
        savePath = os.path.join(tmpDir, 'saveNet3d.testSave')
        torch.save(realNVP3d.saveModel({}), savePath)
        sListp3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
        tListp3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
        realNVPp3d = RealNVP([2, 4, 4], sListp3d, tListp3d, gaussian3d)
        realNVPp3d.loadModel(torch.load(savePath))

    zz3d = realNVPp3d.generate(x3d)
    print("3d Forward after restore")
    assert_array_almost_equal(zz3d.data.numpy(), z3d.data.numpy())