示例#1
0
def test_checkerboard_cuda_cudaNot0():
    """Check that a checkerboard-masked RealNVP round-trips on CUDA device ``maxGPU // 2``.

    The prior sample, the model and its masks are all placed on that device;
    ``generate`` followed by ``inference`` must reproduce the original input
    to within ``assert_array_almost_equal``'s default precision.
    """
    # NOTE(review): if maxGPU < 2 this resolves to device 0, so the test would
    # not actually exercise a non-zero device as its name suggests — confirm
    # how maxGPU is defined.
    device = maxGPU // 2

    gaussian3d = Gaussian([2, 4, 4])
    x3d = gaussian3d(3).cuda(device)
    # Each row is [channel, filter_size, stride, padding].
    netStructure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0], [1, 2, 1, 0]]
    # Four separate CNN instances per list (not one shared object); presumably
    # the s/t networks of four coupling layers.
    sList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
    tList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]

    realNVP3d = RealNVP([2, 4, 4], sList3d, tList3d,
                        gaussian3d).cuda(device)
    # The returned mask was never used; the call is kept because it presumably
    # configures the model's masks on `device` — confirm against RealNVP.
    realNVP3d.createMask(["checkerboard"] * 4, cuda=device)

    z3d = realNVP3d.generate(x3d)
    zp3d = realNVP3d.inference(z3d)

    print(realNVP3d.logProbability(z3d))

    # generate followed by inference must recover the input.
    assert_array_almost_equal(x3d.cpu().data.numpy(), zp3d.cpu().data.numpy())
示例#2
0
def test_sample():
    """Smoke-test ``RealNVP.sample`` with both values of its second argument.

    Only checks that sampling and ``logProbability`` run without raising;
    no numerical property of the samples is asserted.
    """
    gaussian3d = Gaussian([2, 4, 4])
    # Each row is [channel, filter_size, stride, padding].
    netStructure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0], [1, 2, 1, 0]]
    # Four separate CNN instances per list (not one shared object).
    sList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
    tList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]

    # The fifth argument is presumably the mask type — confirm against RealNVP.
    realNVP3d = RealNVP([2, 4, 4], sList3d, tList3d, gaussian3d,
                        "checkerboard")

    z3d = realNVP3d.sample(100, True)
    # Exercise the other branch as well; its result was never inspected.
    realNVP3d.sample(100, False)

    print(realNVP3d.logProbability(z3d))
示例#3
0
def test_3d():
    """Check RealNVP mask shape, invertibility and save/load round-tripping.

    - The default mask's first four dimensions are (4, 2, 4, 4).
    - ``inference(generate(x))`` reproduces ``x``.
    - A freshly built model restored via ``saveModel``/``loadModel`` gives the
      same ``generate`` output as the original model.

    The checkpoint is written inside a temporary directory that is removed
    afterwards, rather than leaving ``./saveNet3d.testSave`` behind in the
    current working directory.
    """
    import os
    import tempfile

    gaussian3d = Gaussian([2, 4, 4])
    x3d = gaussian3d(3)
    # Each row is [channel, filter_size, stride, padding].
    netStructure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0],
                    [1, 2, 1, 0]]

    # Four separate CNN instances per list (not one shared object).
    sList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
    tList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]

    realNVP3d = RealNVP([2, 4, 4], sList3d, tList3d, gaussian3d)
    print(realNVP3d.mask)
    # Only the first four dimensions are checked.
    assert tuple(realNVP3d.mask.shape)[:4] == (4, 2, 4, 4)

    z3d = realNVP3d.generate(x3d)
    zp3d = realNVP3d.inference(z3d)
    # generate followed by inference must recover the input.
    assert_array_almost_equal(x3d.data.numpy(), zp3d.data.numpy())

    print("3d logProbability")
    print(realNVP3d.logProbability(z3d))

    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint = os.path.join(tmpdir, "saveNet3d.testSave")
        torch.save(realNVP3d.saveModel({}), checkpoint)

        # Build a second model with the same architecture and load the saved
        # state into it.
        sListp3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
        tListp3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
        realNVPp3d = RealNVP([2, 4, 4], sListp3d, tListp3d, gaussian3d)
        realNVPp3d.loadModel(torch.load(checkpoint))

    zz3d = realNVPp3d.generate(x3d)
    # The restored model must reproduce the original model's forward output.
    assert_array_almost_equal(zz3d.data.numpy(), z3d.data.numpy())