示例#1
0
def test_tempalte_invertibleMLP():
    """Check that an MLP-based RealNVP with channel masks is invertible.

    Builds a RealNVP over 2-dimensional data with four MLP coupling
    networks per list, draws samples from the prior, runs them forward
    through ``_generate`` and back through ``_inference``, and asserts that
    the recorded log-Jacobians cancel and that the round trip reproduces
    the input.
    """
    print("test mlp")

    gaussian = Gaussian([2])

    # One network per coupling layer; presumably s = scale, t = translation.
    sList = [MLP(2, 10) for _ in range(4)]
    tList = [MLP(2, 10) for _ in range(4)]

    realNVP = RealNVP([2], sList, tList, gaussian)
    x = realNVP.prior(10)
    # Called for its side effect: the masks are read back below through
    # realNVP.mask / realNVP.mask_, so the return value is not needed.
    realNVP.createMask(["channel"] * 4, ifByte=0)
    print("original")

    # The trailing True presumably asks the model to record the
    # log-Jacobian, which is read back via _generateLogjac below.
    z = realNVP._generate(x, realNVP.mask, realNVP.mask_, True)
    print("Forward")

    zp = realNVP._inference(z, realNVP.mask, realNVP.mask_, True)
    print("Backward")

    # Forward and inverse log-determinants must be exact negatives.
    assert_array_almost_equal(realNVP._generateLogjac.data.numpy(),
                              -realNVP._inferenceLogjac.data.numpy())

    print("logProbability")
    print(realNVP._logProbability(z, realNVP.mask, realNVP.mask_))

    # Invertibility: inference(generate(x)) == x.
    assert_array_almost_equal(x.data.numpy(), zp.data.numpy())
示例#2
0
def test_tempalte_contractionCNN_checkerboard_cuda():
    """Round-trip a checkerboard-masked CNN RealNVP on the GPU.

    Moves samples of shape [2, 4, 4] and the model to CUDA, runs
    ``_generate`` then ``_inference``, prints the log-probability, and
    asserts both reconstruction of the input and cancellation of the
    forward/inverse log-Jacobians.
    """
    gaussian3d = Gaussian([2, 4, 4])
    x3d = gaussian3d(3).cuda()
    # Rows: [channel, filter_size, stride, padding].
    netStructure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0], [2, 2, 1, 0]]
    sList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
    tList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]

    realNVP3d = RealNVP([2, 4, 4], sList3d, tList3d, gaussian3d).cuda()
    # Called for its side effect on realNVP3d.mask / realNVP3d.mask_.
    realNVP3d.createMask(["checkerboard"] * 4, ifByte=0, cuda=0)

    z3d = realNVP3d._generate(x3d, realNVP3d.mask, realNVP3d.mask_, True)
    zp3d = realNVP3d._inference(z3d, realNVP3d.mask, realNVP3d.mask_, True)
    print(realNVP3d._logProbability(z3d, realNVP3d.mask, realNVP3d.mask_))

    # Tensors are copied back to the CPU before converting to NumPy.
    assert_array_almost_equal(x3d.cpu().data.numpy(), zp3d.cpu().data.numpy())
    assert_array_almost_equal(realNVP3d._generateLogjac.data.cpu().numpy(),
                              -realNVP3d._inferenceLogjac.data.cpu().numpy())
示例#3
0
def test_tempalte_contraction_mlp():
    """Check the contraction-based coupling path of an MLP RealNVP.

    Builds a RealNVP over 2-dimensional data with four MLP(1, 10) coupling
    networks per list and runs ``_generateWithContraction`` followed by
    ``_inferenceWithContraction`` along dimension 0. Asserts that the
    log-Jacobians cancel and that the round trip reproduces the input, then
    prints the prior log-probability of a fresh batch.
    """
    gaussian = Gaussian([2])

    sList = [MLP(1, 10) for _ in range(4)]
    tList = [MLP(1, 10) for _ in range(4)]

    realNVP = RealNVP([2], sList, tList, gaussian)

    x = realNVP.prior(10)
    # Called for its side effect on realNVP.mask / realNVP.mask_.
    realNVP.createMask(["channel"] * 4, ifByte=1)
    print("original")

    z = realNVP._generateWithContraction(x, realNVP.mask, realNVP.mask_, 0,
                                         True)
    print("Forward")

    zp = realNVP._inferenceWithContraction(z, realNVP.mask, realNVP.mask_, 0,
                                           True)
    print("Backward")

    assert_array_almost_equal(realNVP._generateLogjac.data.numpy(),
                              -realNVP._inferenceLogjac.data.numpy())
    # Invertibility was computed above but previously never checked.
    assert_array_almost_equal(x.data.numpy(), zp.data.numpy())

    x_data = realNVP.prior(10)
    y_data = realNVP.prior.logProbability(x_data)
    print("logProbability")
    print(y_data)
示例#4
0
def test_logProbabilityWithInference_cuda():
    """Round-trip a checkerboard CNN RealNVP on the GPU via the public API.

    Uses ``generate``/``inference`` (not the underscored variants), prints
    element ``[1]`` of ``logProbabilityWithInference(z3d)``, and asserts
    that inference recovers the original input.
    """
    gaussian3d = Gaussian([2, 4, 4])
    x3d = gaussian3d(3).cuda()
    # Rows: [channel, filter_size, stride, padding].
    netStructure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0], [1, 2, 1, 0]]
    sList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
    tList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]

    realNVP3d = RealNVP([2, 4, 4], sList3d, tList3d, gaussian3d).cuda()
    # Called for its side effect; the return value is not needed.
    realNVP3d.createMask(["checkerboard"] * 4, cuda=0)

    z3d = realNVP3d.generate(x3d)
    zp3d = realNVP3d.inference(z3d)

    print(realNVP3d.logProbabilityWithInference(z3d)[1])

    assert_array_almost_equal(x3d.cpu().data.numpy(), zp3d.cpu().data.numpy())
示例#5
0
def test_contraction_cuda_withDifferentMasks():
    """Check the contraction coupling path on the GPU with mixed masks.

    Alternates channel and checkerboard masks across four CNN coupling
    layers and runs ``_generateWithContraction`` and
    ``_inferenceWithContraction`` with contraction argument 2. Prints the
    contraction log-probability and asserts reconstruction and
    log-Jacobian cancellation.
    """
    gaussian3d = Gaussian([2, 4, 4])
    x = gaussian3d(3).cuda()

    netStructure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0],
                    [1, 2, 1, 0]]  # [channel, filter_size, stride, padding]

    sList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
    tList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]

    realNVP = RealNVP([2, 4, 4], sList3d, tList3d, gaussian3d).cuda()
    # Called for its side effect on realNVP.mask / realNVP.mask_.
    realNVP.createMask(
        ["channel", "checkerboard", "channel", "checkerboard"], 1, cuda=0)

    z = realNVP._generateWithContraction(x, realNVP.mask, realNVP.mask_, 2,
                                         True)
    print(
        realNVP._logProbabilityWithContraction(z, realNVP.mask, realNVP.mask_,
                                               2))
    zz = realNVP._inferenceWithContraction(z, realNVP.mask, realNVP.mask_, 2,
                                           True)

    assert_array_almost_equal(x.cpu().data.numpy(), zz.cpu().data.numpy())
    assert_array_almost_equal(realNVP._generateLogjac.data.cpu().numpy(),
                              -realNVP._inferenceLogjac.data.cpu().numpy())
示例#6
0
def test_multiplyMask_generateWithSlice_CNN():
    """Check the slice-based coupling path with alternating masks (CPU).

    Builds four CNN coupling networks per list (default ``inchannel``) and
    runs ``_generateWithSlice`` then ``_inferenceWithSlice`` with slice
    argument 0. Asserts that the input is reconstructed and that the
    log-Jacobians cancel.
    """
    gaussian3d = Gaussian([2, 4, 4])
    x = gaussian3d(3)

    netStructure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0],
                    [1, 2, 1, 0]]  # [channel, filter_size, stride, padding]

    sList3d = [CNN(netStructure) for _ in range(4)]
    tList3d = [CNN(netStructure) for _ in range(4)]

    realNVP = RealNVP([2, 4, 4], sList3d, tList3d, gaussian3d)
    # Called for its side effect; the slice methods below read the masks
    # from the model (presumably realNVP.mask / mask_ — confirm).
    realNVP.createMask(
        ["channel", "checkerboard", "channel", "checkerboard"], ifByte=1)

    z = realNVP._generateWithSlice(x, 0, True)
    zz = realNVP._inferenceWithSlice(z, 0, True)

    assert_array_almost_equal(x.data.numpy(), zz.data.numpy())
    assert_array_almost_equal(realNVP._generateLogjac.data.numpy(),
                              -realNVP._inferenceLogjac.data.numpy())
示例#7
0
def test_template_contraction_function_with_channel():
    """Check the contraction coupling path with channel masks only (CPU).

    Runs ``_generateWithContraction`` and ``_inferenceWithContraction``
    with contraction argument 2 over four CNN coupling layers. Asserts
    reconstruction and log-Jacobian cancellation.
    """
    gaussian3d = Gaussian([2, 4, 4])
    x = gaussian3d(3)

    netStructure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0],
                    [1, 2, 1, 0]]  # [channel, filter_size, stride, padding]

    sList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
    tList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]

    realNVP = RealNVP([2, 4, 4], sList3d, tList3d, gaussian3d)
    # Called for its side effect on realNVP.mask / realNVP.mask_.
    realNVP.createMask(["channel"] * 4, 1)

    z = realNVP._generateWithContraction(x, realNVP.mask, realNVP.mask_, 2,
                                         True)
    zz = realNVP._inferenceWithContraction(z, realNVP.mask, realNVP.mask_, 2,
                                           True)

    assert_array_almost_equal(x.data.numpy(), zz.data.numpy())
    assert_array_almost_equal(realNVP._generateLogjac.data.numpy(),
                              -realNVP._inferenceLogjac.data.numpy())
示例#8
0
def test_tempalte_invertibleCNN():
    """Check invertibility and save/load of a channel-masked CNN RealNVP.

    Round-trips samples through ``_generate``/``_inference``, checks that
    the log-Jacobians cancel, then serialises the model with ``saveModel``
    and ``torch.save``. It restores the model into a freshly built RealNVP
    and checks that the restored model reproduces the same forward output.
    The checkpoint lives in a temporary directory, so the test leaves no
    file behind. It also cannot collide with other tests that used the
    same hard-coded filename.
    """
    import os
    import tempfile

    gaussian3d = Gaussian([2, 4, 4])
    x3d = gaussian3d(3)

    netStructure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0],
                    [2, 2, 1, 0]]  # [channel, filter_size, stride, padding]

    sList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
    tList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]

    realNVP3d = RealNVP([2, 4, 4], sList3d, tList3d, gaussian3d)
    # Called for its side effect on realNVP3d.mask / realNVP3d.mask_.
    realNVP3d.createMask(["channel"] * 4, ifByte=0)

    print("Testing 3d")
    print("3d original:")

    z3d = realNVP3d._generate(x3d, realNVP3d.mask, realNVP3d.mask_, True)
    print("3d forward:")

    zp3d = realNVP3d._inference(z3d, realNVP3d.mask, realNVP3d.mask_, True)
    print("Backward")

    assert_array_almost_equal(realNVP3d._generateLogjac.data.numpy(),
                              -realNVP3d._inferenceLogjac.data.numpy())

    print("3d logProbability")
    print(realNVP3d._logProbability(z3d, realNVP3d.mask, realNVP3d.mask_))

    saveDict3d = realNVP3d.saveModel({})
    # Round-trip the checkpoint through disk inside a throwaway directory
    # instead of writing ./saveNet3d.testSave into the working directory.
    with tempfile.TemporaryDirectory() as tmpDir:
        savePath = os.path.join(tmpDir, 'saveNet3d.testSave')
        torch.save(saveDict3d, savePath)
        saveDictp3d = torch.load(savePath)

    sListp3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
    tListp3d = [CNN(netStructure, inchannel=2) for _ in range(4)]

    realNVPp3d = RealNVP([2, 4, 4], sListp3d, tListp3d, gaussian3d)
    realNVPp3d.loadModel(saveDictp3d)

    # The restored model never calls createMask; its masks are presumably
    # restored by loadModel — confirm.
    zz3d = realNVPp3d._generate(x3d, realNVPp3d.mask, realNVPp3d.mask_)
    print("3d Forward after restore")

    assert_array_almost_equal(x3d.data.numpy(), zp3d.data.numpy())
    assert_array_almost_equal(zz3d.data.numpy(), z3d.data.numpy())
示例#9
0
def test_checkerboardMask():
    """Check the public API with checkerboard masks, plus save/load.

    Round-trips samples through ``generate``/``inference`` and prints the
    masks and log-probability. It saves the model, restores it into a
    fresh RealNVP, and asserts both reconstruction and that the restored
    model reproduces the original forward output. The checkpoint is
    written to a temporary directory so nothing is left in the working
    directory.
    """
    import os
    import tempfile

    gaussian3d = Gaussian([2, 4, 4])
    x3d = gaussian3d(3)

    netStructure = [[3, 2, 1, 1], [4, 2, 1, 1], [3, 2, 1, 0],
                    [1, 2, 1, 0]]  # [channel, filter_size, stride, padding]

    sList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
    tList3d = [CNN(netStructure, inchannel=2) for _ in range(4)]

    realNVP3d = RealNVP([2, 4, 4], sList3d, tList3d, gaussian3d)
    # Called for its side effect on realNVP3d.mask.
    realNVP3d.createMask(["checkerboard"] * 4)
    print(realNVP3d.mask)

    z3d = realNVP3d.generate(x3d)
    # Printed again after generate, presumably to check it did not alter
    # the mask.
    print(realNVP3d.mask)
    print("3d forward:")

    zp3d = realNVP3d.inference(z3d)
    print("Backward")

    print("3d logProbability")
    print(realNVP3d.logProbability(z3d))

    saveDict3d = realNVP3d.saveModel({})
    # Keep the checkpoint out of the working directory.
    with tempfile.TemporaryDirectory() as tmpDir:
        savePath = os.path.join(tmpDir, 'saveNet3d.testSave')
        torch.save(saveDict3d, savePath)
        saveDictp3d = torch.load(savePath)

    sListp3d = [CNN(netStructure, inchannel=2) for _ in range(4)]
    tListp3d = [CNN(netStructure, inchannel=2) for _ in range(4)]

    realNVPp3d = RealNVP([2, 4, 4], sListp3d, tListp3d, gaussian3d)
    realNVPp3d.loadModel(saveDictp3d)

    zz3d = realNVPp3d.generate(x3d)
    print("3d Forward after restore")

    assert_array_almost_equal(x3d.data.numpy(), zp3d.data.numpy())
    assert_array_almost_equal(zz3d.data.numpy(), z3d.data.numpy())