def test_tempalte_invertibleMLP():
    """Round-trip check for a RealNVP flow built from MLP sub-networks.

    Builds a RealNVP over shape ``[2]`` from four ``MLP(2, 10)`` networks in
    each of its two network lists, pushes ``flow.prior(10)`` through
    ``_generate`` and then ``_inference``, and asserts that

    * the recorded ``_generateLogjac`` equals minus ``_inferenceLogjac``;
    * the round-tripped tensor matches the original input.
    """
    print("test mlp")
    n_layers = 4
    base = Gaussian([2])
    s_nets = [MLP(2, 10) for _ in range(n_layers)]
    t_nets = [MLP(2, 10) for _ in range(n_layers)]
    flow = RealNVP([2], s_nets, t_nets, base)
    start = flow.prior(10)
    # Return value is not needed; the call is kept because flow.mask and
    # flow.mask_ are read below (presumably populated here -- confirm).
    flow.createMask(["channel"] * n_layers, ifByte=0)

    print("original")
    # Trailing True presumably switches on log-Jacobian tracking -- confirm.
    forward = flow._generate(start, flow.mask, flow.mask_, True)
    print("Forward")
    recovered = flow._inference(forward, flow.mask, flow.mask_, True)
    print("Backward")

    generate_logjac = flow._generateLogjac.data.numpy()
    inference_logjac = flow._inferenceLogjac.data.numpy()
    assert_array_almost_equal(generate_logjac, -inference_logjac)

    print("logProbability")
    log_prob = flow._logProbability(forward, flow.mask, flow.mask_)
    print(log_prob)

    assert_array_almost_equal(start.data.numpy(), recovered.data.numpy())
def test_tempalte_contractionCNN_checkerboard_cuda():
    """Round-trip check for a CNN-based RealNVP with checkerboard masks on CUDA.

    Moves both the input batch and the model to the GPU via ``.cuda()``, runs
    ``_generate`` followed by ``_inference``, prints the log-probability, and
    asserts that the input is recovered and that the two recorded
    log-Jacobians are negatives of each other.
    """
    shape = [2, 4, 4]
    n_layers = 4
    base = Gaussian(shape)
    batch = base(3).cuda()
    # Each entry is [channel, filter_size, stride, padding].
    layer_specs = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0], [2, 2, 1, 0]]
    s_nets = [CNN(layer_specs, inchannel=2) for _ in range(n_layers)]
    t_nets = [CNN(layer_specs, inchannel=2) for _ in range(n_layers)]
    flow = RealNVP([2, 4, 4], s_nets, t_nets, base).cuda()
    # Return value is not needed; the call is kept because flow.mask and
    # flow.mask_ are read below (presumably populated here -- confirm).
    flow.createMask(["checkerboard"] * n_layers, ifByte=0, cuda=0)

    # Trailing True presumably switches on log-Jacobian tracking -- confirm.
    forward = flow._generate(batch, flow.mask, flow.mask_, True)
    recovered = flow._inference(forward, flow.mask, flow.mask_, True)

    log_prob = flow._logProbability(forward, flow.mask, flow.mask_)
    print(log_prob)

    # Tensors live on the GPU, so copy to host before converting to numpy.
    assert_array_almost_equal(batch.cpu().data.numpy(),
                              recovered.cpu().data.numpy())
    generate_logjac = flow._generateLogjac.data.cpu().numpy()
    inference_logjac = flow._inferenceLogjac.data.cpu().numpy()
    assert_array_almost_equal(generate_logjac, -inference_logjac)
def test_tempalte_invertibleCNN():
    """Round-trip and save/restore check for a CNN-based RealNVP flow.

    The test has two parts:

    1. Invertibility: a batch drawn from ``Gaussian([2, 4, 4])`` is passed
       through ``_generate`` and then ``_inference``; the two recorded
       log-Jacobians must be negatives of each other and the input must be
       recovered.
    2. Persistence: the dict returned by ``saveModel`` is written with
       ``torch.save``, read back with ``torch.load`` and handed to
       ``loadModel`` on a freshly built model; that model's ``_generate``
       output must match the original model's.

    The checkpoint is written inside a temporary directory so the test leaves
    no file behind in the working directory and cannot collide with another
    run using the same path.
    """
    # Function-scope imports keep this block self-contained.
    import os
    import tempfile

    gaussian3d = Gaussian([2, 4, 4])
    x3d = gaussian3d(3)
    # Each entry is [channel, filter_size, stride, padding].
    netStructure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0], [2, 2, 1, 0]]

    sList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
    tList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]

    realNVP3d = RealNVP([2, 4, 4], sList3d, tList3d, gaussian3d)
    # Return value is not needed; the call is kept because realNVP3d.mask and
    # realNVP3d.mask_ are read below (presumably populated here -- confirm).
    realNVP3d.createMask(["channel"] * 4, ifByte=0)

    print("Testing 3d")
    print("3d original:")

    # Trailing True presumably switches on log-Jacobian tracking -- confirm.
    z3d = realNVP3d._generate(x3d, realNVP3d.mask, realNVP3d.mask_, True)
    print("3d forward:")

    zp3d = realNVP3d._inference(z3d, realNVP3d.mask, realNVP3d.mask_, True)
    print("Backward")

    assert_array_almost_equal(realNVP3d._generateLogjac.data.numpy(),
                              -realNVP3d._inferenceLogjac.data.numpy())

    print("3d logProbability")
    print(realNVP3d._logProbability(z3d, realNVP3d.mask, realNVP3d.mask_))

    saveDict3d = realNVP3d.saveModel({})

    # Build a second, independently initialised model to restore into.
    sListp3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
    tListp3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
    realNVPp3d = RealNVP([2, 4, 4], sListp3d, tListp3d, gaussian3d)

    # Round-trip the checkpoint through disk; the directory and the file in
    # it are removed when the with-block exits.
    with tempfile.TemporaryDirectory() as tmpDir:
        savePath = os.path.join(tmpDir, "saveNet3d.testSave")
        torch.save(saveDict3d, savePath)
        saveDictp3d = torch.load(savePath)

    realNVPp3d.loadModel(saveDictp3d)

    # NOTE(review): createMask is never called on the restored model, so
    # loadModel presumably restores mask / mask_ as well -- confirm.
    zz3d = realNVPp3d._generate(x3d, realNVPp3d.mask, realNVPp3d.mask_)
    print("3d Forward after restore")

    assert_array_almost_equal(x3d.data.numpy(), zp3d.data.numpy())
    assert_array_almost_equal(zz3d.data.numpy(), z3d.data.numpy())