def test_checkerboard_cuda_cudaNot0():
    """Round-trip a sample through a checkerboard-masked RealNVP on a CUDA device.

    The sample, the model and the mask are all placed with the device index
    ``maxGPU // 2``. The test passes when ``inference(generate(x))``
    reproduces ``x`` to the default tolerance of
    ``assert_array_almost_equal``.
    """
    device = maxGPU // 2

    prior = Gaussian([2, 4, 4])
    sample = prior(3).cuda(device)

    # Each entry is [channel, filter_size, stride, padding].
    layers = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0], [1, 2, 1, 0]]
    s_nets = [CNN(layers, inchannel=2) for _ in range(4)]
    t_nets = [CNN(layers, inchannel=2) for _ in range(4)]

    flow = RealNVP([2, 4, 4], s_nets, t_nets, prior).cuda(device)
    # The returned mask is not needed here; the call itself is kept because
    # it presumably installs the mask on the model (confirm in RealNVP).
    flow.createMask(["checkerboard"] * 4, cuda=device)

    generated = flow.generate(sample)
    recovered = flow.inference(generated)
    print(flow.logProbability(generated))

    assert_array_almost_equal(
        sample.cpu().data.numpy(),
        recovered.cpu().data.numpy(),
    )
def test_sample():
    """Smoke-test ``RealNVP.sample`` on a model built with "checkerboard".

    ``sample`` is called once with ``True`` and once with ``False`` as its
    second argument, and the log-probability of the first batch is printed.
    The test makes no assertions; it only checks that the calls run.
    """
    prior = Gaussian([2, 4, 4])
    # Result is not used below; the call is kept in case drawing a sample
    # has side effects the rest of the test depends on.
    prior(3)

    # Each entry is [channel, filter_size, stride, padding].
    layers = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0], [1, 2, 1, 0]]
    s_nets = [CNN(layers, inchannel=2) for _ in range(4)]
    t_nets = [CNN(layers, inchannel=2) for _ in range(4)]

    # The fifth positional argument is presumably the mask type -- confirm
    # against the RealNVP constructor.
    flow = RealNVP([2, 4, 4], s_nets, t_nets, prior, "checkerboard")

    first_batch = flow.sample(100, True)
    flow.sample(100, False)  # exercised only; result intentionally unused
    print(flow.logProbability(first_batch))
def test_3d():
    """Exercise a RealNVP on [2, 4, 4] inputs, including a save/load round trip.

    Steps, in order:

    1. Build a RealNVP from four "s" and four "t" CNNs and check that the
       first four dimensions of its ``mask`` are ``(4, 2, 4, 4)``.
    2. Run ``generate`` followed by ``inference`` on a sample drawn from the
       Gaussian prior and print the log-probability.
    3. Save the model with ``saveModel`` / ``torch.save``, load it into a
       freshly constructed RealNVP with ``torch.load`` / ``loadModel``.
    4. Assert that ``inference(generate(x))`` reproduces ``x`` and that the
       restored model's ``generate`` output matches the original model's.

    The checkpoint is written inside a temporary directory that is removed
    when the test finishes, so no ``saveNet3d.testSave`` file is left behind
    in the working directory.
    """
    import os
    import shutil
    import tempfile

    gaussian3d = Gaussian([2, 4, 4])
    x3d = gaussian3d(3)

    # [channel, filter_size, stride, padding]
    netStructure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0], [1, 2, 1, 0]]

    sList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
    tList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]

    realNVP3d = RealNVP([2, 4, 4], sList3d, tList3d, gaussian3d)
    print(realNVP3d.mask)

    for axis, expected in enumerate((4, 2, 4, 4)):
        assert realNVP3d.mask.shape[axis] == expected

    print("test high dims")
    print("Testing 3d")
    print("3d original:")

    z3d = realNVP3d.generate(x3d)
    print("3d forward:")

    zp3d = realNVP3d.inference(z3d)
    print("Backward")

    print("3d logProbability")
    print(realNVP3d.logProbability(z3d))

    # Keep the checkpoint out of the working directory: a private temporary
    # directory avoids leftover artifacts and collisions between test runs.
    tmp_dir = tempfile.mkdtemp()
    try:
        save_path = os.path.join(tmp_dir, "saveNet3d.testSave")

        saveDict3d = realNVP3d.saveModel({})
        torch.save(saveDict3d, save_path)

        sListp3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
        tListp3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
        realNVPp3d = RealNVP([2, 4, 4], sListp3d, tListp3d, gaussian3d)

        saveDictp3d = torch.load(save_path)
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)

    realNVPp3d.loadModel(saveDictp3d)

    zz3d = realNVPp3d.generate(x3d)
    print("3d Forward after restore")

    # generate followed by inference must reproduce the input.
    assert_array_almost_equal(x3d.data.numpy(), zp3d.data.numpy())
    # The restored model must produce the same output as the original.
    assert_array_almost_equal(zz3d.data.numpy(), z3d.data.numpy())