def test_tempalte_contraction_mlp():
    """Check forward/backward contraction log-Jacobians of an MLP RealNVP agree.

    Builds a RealNVP over a [2]-shaped prior with four MLP(1, 10) networks in
    sList and four in tList, creates four "channel" masks, runs
    _generateWithContraction followed by _inferenceWithContraction, and
    asserts the generate log-Jacobian equals the negated inference
    log-Jacobian.

    NOTE(review): the name contains a typo ("tempalte"); renaming would change
    the test id, so it is left as-is.
    NOTE(review): unlike the CNN tests, the inverse output ``zp`` is never
    compared with ``x``, so the round trip itself is not asserted.
    NOTE(review): this function is followed by a bare triple-quote delimiter,
    so it may be inside (or adjacent to) a string literal used to disable
    code -- confirm whether this test is actually collected.
    """
    gaussian = Gaussian([2])

    sList = [MLP(1, 10), MLP(1, 10), MLP(1, 10), MLP(1, 10)]
    tList = [MLP(1, 10), MLP(1, 10), MLP(1, 10), MLP(1, 10)]

    realNVP = RealNVP([2], sList, tList, gaussian)
    # Presumably a batch of 10 samples from the prior -- TODO confirm.
    x = realNVP.prior(10)
    # Return value is unused; the call is presumably kept for its side effect
    # of populating realNVP.mask / realNVP.mask_, which are read below.
    mask = realNVP.createMask(["channel"] * 4, ifByte=1)
    print("original")
    #print(x)
    # The trailing 0 and True arguments are passed through unchanged;
    # their meaning is defined in RealNVP (not visible here).
    z = realNVP._generateWithContraction(x, realNVP.mask, realNVP.mask_, 0, True)
    print("Forward")
    #print(z)
    zp = realNVP._inferenceWithContraction(z, realNVP.mask, realNVP.mask_, 0, True)
    print("Backward")
    #print(zp)
    # Forward and inverse log-Jacobians must cancel.
    assert_array_almost_equal(realNVP._generateLogjac.data.numpy(), -realNVP._inferenceLogjac.data.numpy())
    # NOTE(review): x_data / y_data are computed but never checked; the test
    # looks unfinished here.
    x_data = realNVP.prior(10)
    y_data = realNVP.prior.logProbability(x_data)
    print("logProbability")
'''
def test_contraction_cuda_withDifferentMasks():
    """Round-trip a CNN-based RealNVP through contraction on the GPU.

    Builds a RealNVP over a Gaussian([2, 4, 4]) prior with four CNN networks
    in sList3d and four in tList3d, moves the model to CUDA, creates
    alternating channel/checkerboard masks, and runs
    _generateWithContraction followed by _inferenceWithContraction. It
    asserts that:

    * the inverse pass recovers the original input ``x``, and
    * the generate log-Jacobian equals the negated inference log-Jacobian.

    The contracted log-probability is printed for inspection only.

    Raises ``unittest.SkipTest`` (reported as a skip by pytest and nose) when
    CUDA is unavailable, instead of failing on CPU-only machines.
    """
    import unittest

    import torch

    # Every step below calls .cuda(), which errors without a GPU.
    if not torch.cuda.is_available():
        raise unittest.SkipTest("CUDA is not available")

    gaussian3d = Gaussian([2, 4, 4])
    x = gaussian3d(3).cuda()

    # Each row: [channel, filter_size, stride, padding]
    netStructure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0], [1, 2, 1, 0]]
    # Four independent networks each; built in the same order as before.
    sList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
    tList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]

    realNVP = RealNVP([2, 4, 4], sList3d, tList3d, gaussian3d)
    realNVP = realNVP.cuda()
    # Called for its side effect of setting realNVP.mask / realNVP.mask_
    # (read below). The meaning of cuda=0 is defined by createMask --
    # presumably a device index; TODO confirm.
    realNVP.createMask(
        ["channel", "checkerboard", "channel", "checkerboard"], 1, cuda=0)

    z = realNVP._generateWithContraction(x, realNVP.mask, realNVP.mask_, 2,
                                         True)
    print(
        realNVP._logProbabilityWithContraction(z, realNVP.mask, realNVP.mask_,
                                               2))
    zz = realNVP._inferenceWithContraction(z, realNVP.mask, realNVP.mask_, 2,
                                           True)

    # The inverse must undo the forward pass, and the log-Jacobians cancel.
    assert_array_almost_equal(x.cpu().data.numpy(), zz.cpu().data.numpy())
    assert_array_almost_equal(realNVP._generateLogjac.data.cpu().numpy(),
                              -realNVP._inferenceLogjac.data.cpu().numpy())
def test_multiplyMask_generateWithContraction_CNN():
    """CPU round trip of a CNN-based RealNVP with alternating masks.

    Draws ``sample`` by calling a Gaussian([2, 4, 4]) prior with 3 (presumably
    a batch size). Builds a RealNVP from two groups of four CNN networks and
    installs channel/checkerboard masks. Runs the contracted generate pass
    and then the contracted inference pass.

    Asserts that the inverse recovers ``sample`` and that the two
    log-Jacobians are negatives of each other.
    """
    prior = Gaussian([2, 4, 4])
    sample = prior(3)

    # Per-layer spec: [channel, filter_size, stride, padding]
    layers = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0], [1, 2, 1, 0]]

    def build_nets():
        # Four fresh, independent CNNs sharing the same layer spec.
        return [CNN(layers, inchannel=2) for _ in range(4)]

    s_nets = build_nets()
    t_nets = build_nets()

    flow = RealNVP([2, 4, 4], s_nets, t_nets, prior)
    # Populates flow.mask / flow.mask_, which both passes below consume.
    flow.createMask(
        ["channel", "checkerboard", "channel", "checkerboard"], ifByte=1)

    forward = flow._generateWithContraction(sample, flow.mask, flow.mask_, 2,
                                            True)
    backward = flow._inferenceWithContraction(forward, flow.mask, flow.mask_,
                                              2, True)

    assert_array_almost_equal(sample.data.numpy(), backward.data.numpy())
    assert_array_almost_equal(flow._generateLogjac.data.numpy(),
                              -flow._inferenceLogjac.data.numpy())