def flowBuilder(n, numFlow, innerBuilder=None, typeLayer=3, relax=False, shift=False):
    """Assemble a ``flow.FlowNet`` acting on ``2*n`` variables with a Gaussian prior.

    The network always starts with a ``flow.DiagScaling`` layer, followed by
    layers chosen by ``typeLayer``:

    * ``0`` -- a single ``flow.Symplectic`` layer (``numFlow`` is ignored);
    * ``1`` -- ``numFlow`` ``flow.PointTransformation`` layers;
    * ``2`` -- ``numFlow`` ``flow.Symplectic`` layers;
    * ``3`` -- ``numFlow`` (``PointTransformation``, ``Symplectic``) pairs.

    Args:
        n: half of the total dimension; the prior and every layer use ``2*n``.
        numFlow: number of repeated layer groups (unused when ``typeLayer == 0``).
        innerBuilder: callable invoked as ``innerBuilder(n)``; its result is
            wrapped in each ``flow.PointTransformation``. It is required for
            every ``typeLayer``, even those that never call it.
        typeLayer: layer layout, one of 0, 1, 2, 3 (see above).
        relax: selects the ``fix`` mask passed to ``flow.DiagScaling``:
            all zeros when True, ``[0]*n + [1]*n`` otherwise. The meaning of
            ``fix`` is defined by ``flow.DiagScaling`` (presumably 1 marks a
            frozen entry -- confirm there).
        shift: forwarded unchanged to ``flow.DiagScaling``.

    Returns:
        The assembled ``flow.FlowNet``, converted with ``.double()``.

    Raises:
        ValueError: if ``innerBuilder`` is None or ``typeLayer`` is not 0-3.
            ``ValueError`` subclasses ``Exception``, so callers catching
            ``Exception`` are unaffected.
    """
    # Validate before building anything, so a bad typeLayer is rejected even
    # when numFlow <= 0 (previously the check lived inside the loop body).
    if innerBuilder is None:
        raise ValueError("innerBuilder is None")
    if typeLayer not in (0, 1, 2, 3):
        raise ValueError("No such type")

    dim = n * 2  # named `dim` to avoid shadowing the `torch.nn` alias
    op = source.Gaussian([dim]).to(torch.float64)

    fixMask = [0] * n + ([0] if relax else [1]) * n
    layers = [
        flow.DiagScaling(dim, initValue=0.1 * np.random.randn(dim), fix=fixMask, shift=shift)
    ]

    if typeLayer == 0:
        layers.append(flow.Symplectic(dim))
    else:
        for _ in range(numFlow):
            # For type 3 the PointTransformation precedes the Symplectic layer.
            if typeLayer in (1, 3):
                layers.append(flow.PointTransformation(innerBuilder(n)))
            if typeLayer in (2, 3):
                layers.append(flow.Symplectic(dim))
    return flow.FlowNet(layers, op).double()
def _buildRealNVPPointTransformation():
    """Build a freshly initialised ``flow.PointTransformation`` around a ``flow.RNVP``.

    The construction order (prior, masks, the two network lists, RNVP,
    wrapper) matches the inline code this helper replaces. Random-number
    consumption per call is therefore unchanged.
    """
    prior = source.Gaussian([8])

    # Two random masks of width 4 with exactly two ones, each followed by its
    # complement, giving a (4, 4) mask matrix.
    masks = []
    for _ in range(2):
        mask = torch.zeros(1, 4)
        mask[0, torch.randperm(4)[:2]] = 1
        masks += [mask, 1 - mask]
    maskList = torch.cat(masks, 0).to(torch.float32)

    # NOTE(review): presumably the translation and scale networks of RealNVP,
    # in that order -- confirm against flow.RNVP's signature.
    firstNets = [
        utils.SimpleMLPreshape([4, 32, 32, 4], [nn.ELU(), nn.ELU(), None])
        for _ in range(4)
    ]
    secondNets = [
        utils.SimpleMLPreshape([4, 32, 32, 4], [nn.ELU(), nn.ELU(), utils.ScalableTanh(4)])
        for _ in range(4)
    ]
    rnvp = flow.RNVP(maskList, firstNets, secondNets)
    return flow.PointTransformation(rnvp, prior)


def test_saveload():
    """Exercise ``saveload`` on two independently built models of one architecture.

    ``f`` and ``blankf`` are built by identical, separate calls, so each has
    its own random masks and initialisation. What ``saveload`` checks is
    defined elsewhere in the test suite.
    """
    f = _buildRealNVPPointTransformation()
    blankf = _buildRealNVPPointTransformation()
    saveload(f, blankf)