def test_tempalte_contraction_mlp():
    """Forward/backward consistency check of RealNVP's contraction code path.

    Builds a RealNVP over a 2-dim Gaussian prior with four MLP(1, 10)
    networks each for the s- and t-lists, pushes a 10-sample prior draw
    through ``_generateWithContraction`` and back through
    ``_inferenceWithContraction``, and asserts that the recorded forward and
    inverse log-Jacobians are negatives of each other.
    """
    gaussian = Gaussian([2])
    sList = [MLP(1, 10), MLP(1, 10), MLP(1, 10), MLP(1, 10)]
    tList = [MLP(1, 10), MLP(1, 10), MLP(1, 10), MLP(1, 10)]
    realNVP = RealNVP([2], sList, tList, gaussian)
    x = realNVP.prior(10)
    # The returned ``mask`` is never read below; the passes use
    # realNVP.mask / realNVP.mask_ instead (presumably populated by
    # createMask as a side effect -- TODO confirm).
    mask = realNVP.createMask(["channel"] * 4, ifByte=1)
    print("original")
    #print(x)
    # The trailing positional ``0, True`` arguments are forwarded unchanged;
    # their meaning is defined inside RealNVP (not visible here).
    z = realNVP._generateWithContraction(x, realNVP.mask, realNVP.mask_, 0, True)
    print("Forward")
    #print(z)
    zp = realNVP._inferenceWithContraction(z, realNVP.mask, realNVP.mask_, 0, True)
    print("Backward")
    #print(zp)
    # Forward and inverse log-Jacobians recorded by the two passes must cancel.
    assert_array_almost_equal(realNVP._generateLogjac.data.numpy(), -realNVP._inferenceLogjac.data.numpy())
    x_data = realNVP.prior(10)
    y_data = realNVP.prior.logProbability(x_data)
    print("logProbability")
    # NOTE(review): the triple quote below opens a string literal whose body
    # and closing quotes are outside this view (likely a disabled section).
    '''
def test_tempalte_invertibleMLP():
    """Round-trip a channel-masked RealNVP through ``_generate``/``_inference``.

    Uses four MLP(2, 10) networks per s/t list on a 2-dim Gaussian prior.
    Asserts that the forward and inverse log-Jacobians cancel and that the
    inverse pass reproduces the original prior sample.
    """
    print("test mlp")
    prior = Gaussian([2])
    scaleNets = [MLP(2, 10) for _ in range(4)]
    shiftNets = [MLP(2, 10) for _ in range(4)]
    flow = RealNVP([2], scaleNets, shiftNets, prior)

    sample = flow.prior(10)
    # The return value is not needed: the passes below read flow.mask and
    # flow.mask_ directly.
    flow.createMask(["channel"] * 4, ifByte=0)
    print("original")

    latent = flow._generate(sample, flow.mask, flow.mask_, True)
    print("Forward")

    recovered = flow._inference(latent, flow.mask, flow.mask_, True)
    print("Backward")

    forwardLogjac = flow._generateLogjac.data.numpy()
    inverseLogjac = flow._inferenceLogjac.data.numpy()
    assert_array_almost_equal(forwardLogjac, -inverseLogjac)

    print("logProbability")
    print(flow._logProbability(latent, flow.mask, flow.mask_))

    assert_array_almost_equal(sample.data.numpy(), recovered.data.numpy())
def test_workmode2():
    """Check mode-2 RealNVP invertibility and save/load round-tripping.

    1. A mode-2 flow with four MLP(1, 10) s/t networks must satisfy
       ``inference(generate(z)) == z`` (both sliced along dim 0).
    2. A freshly constructed mode-2 flow restored from the saved state must
       reproduce the original model's forward output for the same ``z``.
    """
    import os
    import tempfile

    gaussian = Gaussian([2])
    sList = [MLP(1, 10) for _ in range(4)]
    tList = [MLP(1, 10) for _ in range(4)]
    realNVP = RealNVP([2], sList, tList, gaussian, mode=2)

    z = realNVP.prior(10)
    x = realNVP.generate(z, sliceDim=0)
    zp = realNVP.inference(x, sliceDim=0)
    assert_array_almost_equal(z.data.numpy(), zp.data.numpy())

    # Save into a private temp dir so the test leaves no file in the CWD and
    # cannot collide with other tests that save a model under the same name.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "saveNet.testSave")
        torch.save(realNVP.saveModel({}), path)

        sListp = [MLP(1, 10) for _ in range(4)]
        tListp = [MLP(1, 10) for _ in range(4)]
        # Must be built in the same mode as the saved model (was missing).
        realNVPp = RealNVP([2], sListp, tListp, gaussian, mode=2)
        realNVPp.loadModel(torch.load(path))

    # Run the *restored* model; the old code re-ran ``realNVP`` here, so the
    # save/load path was never actually tested.
    xx = realNVPp.generate(z, sliceDim=0)
    print("Forward after restore")
    assert_array_almost_equal(xx.data.numpy(), x.data.numpy())
def test_invertible():
    """Check default-mode RealNVP invertibility and save/load round-tripping.

    1. The constructed mask has 4 rows and 2 columns.
    2. ``inference(generate(z))`` recovers ``z``.
    3. A freshly constructed flow restored from the saved state reproduces
       the original model's forward output for the same ``z``.
    """
    import os
    import tempfile

    print("test realNVP")
    gaussian = Gaussian([2])
    sList = [MLP(2, 10) for _ in range(4)]
    tList = [MLP(2, 10) for _ in range(4)]
    realNVP = RealNVP([2], sList, tList, gaussian)
    print(realNVP.mask)
    print(realNVP.mask_)
    z = realNVP.prior(10)

    # Presumably one mask row per s/t pair (4 here) over the 2 input dims.
    assert realNVP.mask.shape[0] == 4
    assert realNVP.mask.shape[1] == 2

    print("original")
    x = realNVP.generate(z)
    print("Forward")
    zp = realNVP.inference(x)
    print("Backward")
    assert_array_almost_equal(z.data.numpy(), zp.data.numpy())

    # Save into a private temp dir so the test leaves no file in the CWD and
    # cannot collide with other tests that save a model under the same name.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "saveNet.testSave")
        torch.save(realNVP.saveModel({}), path)

        sListp = [MLP(2, 10) for _ in range(4)]
        tListp = [MLP(2, 10) for _ in range(4)]
        realNVPp = RealNVP([2], sListp, tListp, gaussian)
        realNVPp.loadModel(torch.load(path))

    # Run the *restored* model; the old code re-ran ``realNVP`` here, so the
    # save/load path was never actually tested.
    xx = realNVPp.generate(z)
    print("Forward after restore")
    assert_array_almost_equal(xx.data.numpy(), x.data.numpy())