コード例 #1
0
ファイル: flowBuilder.py プロジェクト: wwang2/neuralCT
def flowBuilder(n,numFlow,innerBuilder=None,typeLayer=3,relax=False,shift=False):
    """Build a double-precision ``flow.FlowNet`` on a ``2*n``-dimensional space.

    The prior is ``source.Gaussian([2*n])`` cast to float64. The first layer is
    always a ``flow.DiagScaling`` with random initial values
    (``0.1 * np.random.randn(2*n)``). It is followed by layers selected by
    ``typeLayer``:

    * ``0`` -- one ``flow.Symplectic`` layer (``numFlow`` is ignored).
    * ``1`` -- ``numFlow`` ``flow.PointTransformation`` layers.
    * ``2`` -- ``numFlow`` ``flow.Symplectic`` layers.
    * ``3`` -- ``numFlow`` (PointTransformation, Symplectic) pairs, in that order.

    Args:
        n: Half of the total dimension; prior and scaling act on ``2*n`` entries.
        numFlow: Number of repetitions of the layer pattern for types 1-3.
        innerBuilder: Callable ``innerBuilder(n)`` whose result is wrapped by each
            ``flow.PointTransformation``. Always required, even for types 0 and 2
            which never call it.
        typeLayer: Layer-pattern selector; must be one of 0, 1, 2, 3.
        relax: When False, the second ``n`` entries of the DiagScaling ``fix`` list
            are 1; when True, every entry is 0. (Presumably 1 marks an entry as
            fixed/non-trainable -- confirm against ``flow.DiagScaling``.)
        shift: Forwarded unchanged to ``flow.DiagScaling``.

    Returns:
        The assembled ``flow.FlowNet``, converted with ``.double()``.

    Raises:
        Exception: ``"innerBuilder is None"`` when no builder is supplied, or
            ``"No such type"`` when ``typeLayer`` is not 0-3.
    """
    if innerBuilder is None:
        raise Exception("innerBuilder is None")
    # Validate before building anything: checking inside the numFlow loop would
    # let an invalid typeLayer through whenever numFlow <= 0.
    if typeLayer not in (0, 1, 2, 3):
        raise Exception("No such type")

    # Named `dim` rather than `nn` so it does not shadow the torch.nn module name.
    dim = n*2
    op = source.Gaussian([dim]).to(torch.float64)

    # The first half of the scaling entries always gets fix=0; the second half
    # gets fix=1 unless relax is requested.
    fix = [0]*n + [0 if relax else 1]*n
    f3 = flow.DiagScaling(dim,initValue=0.1*np.random.randn(dim),fix=fix,shift=shift)
    layers = [f3]

    if typeLayer == 0:
        layers.append(flow.Symplectic(dim))
    else:
        for _ in range(numFlow):
            # Type 3 interleaves both kinds, PointTransformation first.
            if typeLayer in (1, 3):
                layers.append(flow.PointTransformation(innerBuilder(n)))
            if typeLayer in (2, 3):
                layers.append(flow.Symplectic(dim))
    return flow.FlowNet(layers,op).double()
コード例 #2
0
def _build_point_transformation():
    """Construct a fresh ``flow.PointTransformation`` around a 4-coupling ``flow.RNVP``.

    The prior passed to the transformation is ``source.Gaussian([8])``. The RNVP
    receives a (4, 4) float32 mask tensor:

    * Even rows set a random half of the 4 entries to 1.
    * Each odd row is ``1 - previous row``.

    It also receives two lists of four ``utils.SimpleMLPreshape([4,32,32,4], ...)``
    networks:

    * The first list has no output activation.
    * The second list ends in ``utils.ScalableTanh(4)``. Presumably these are the
      scale networks -- confirm against ``flow.RNVP``.

    Every call draws fresh random masks and weights.
    """
    p = source.Gaussian([8])
    maskList = []
    for n in range(4):
        if n % 2 == 0:
            b = torch.zeros(4)
            # Pick a random half of the positions to switch on.
            i = torch.randperm(b.numel()).narrow(0, 0, b.numel() // 2)
            b[i] = 1
            b = b.reshape(1, 4)
        else:
            # Complement of the previous (even) row.
            b = 1 - b
        maskList.append(b)
    maskList = torch.cat(maskList, 0).to(torch.float32)

    fl = flow.RNVP(
        maskList,
        [utils.SimpleMLPreshape([4,32,32,4],[nn.ELU(),nn.ELU(),None]) for _ in range(4)],
        [utils.SimpleMLPreshape([4,32,32,4],[nn.ELU(),nn.ELU(),utils.ScalableTanh(4)]) for _ in range(4)],
    )
    return flow.PointTransformation(fl, p)


def test_saveload():
    """Round-trip a ``flow.PointTransformation`` through the ``saveload`` helper.

    Builds two independently initialised PointTransformations with identical
    architecture and hands them to ``saveload``: the first as the source, the
    second as the blank target. ``saveload`` is defined elsewhere; presumably it
    copies state from ``f`` into ``blankf`` and checks they match -- confirm there.
    """
    f = _build_point_transformation()
    blankf = _build_point_transformation()
    saveload(f, blankf)