def flowBuilder(n, numFlow, innerBuilder=None, typeLayer=3, relax=False, shift=False):
    """Assemble a ``flow.FlowNet`` over a ``2 * n``-dimensional Gaussian prior.

    The network always starts with one ``flow.DiagScaling`` layer. The layers
    that follow are selected by ``typeLayer``:

    * ``0`` -- a single ``flow.Symplectic`` layer (``numFlow`` is ignored);
    * ``1`` -- ``numFlow`` ``flow.PointTransformation`` layers;
    * ``2`` -- ``numFlow`` ``flow.Symplectic`` layers;
    * ``3`` -- ``numFlow`` pairs of ``PointTransformation`` then ``Symplectic``.

    Args:
        n: Half of the working dimension; every layer is built for ``2 * n``.
        numFlow: Number of repetitions of the per-type layer pattern.
        innerBuilder: Callable invoked as ``innerBuilder(n)`` to create the
            object wrapped by each ``flow.PointTransformation``. Required for
            every ``typeLayer`` (kept for backward compatibility), even those
            that never call it.
        typeLayer: Layer pattern selector, one of ``0``, ``1``, ``2``, ``3``.
        relax: When true the ``fix`` mask handed to ``DiagScaling`` is all
            zeros; otherwise it is ``n`` zeros followed by ``n`` ones.
            Presumably a 1 marks an entry that is held fixed -- confirm
            against ``flow.DiagScaling``.
        shift: Forwarded unchanged to ``flow.DiagScaling``.

    Returns:
        The assembled ``flow.FlowNet`` after calling ``.double()`` on it.

    Raises:
        ValueError: If ``innerBuilder`` is ``None`` or ``typeLayer`` is not
            one of the supported values.
    """
    if innerBuilder is None:
        raise ValueError("innerBuilder is None")
    # Validate up front: previously an unsupported typeLayer was only detected
    # inside the loop, so it slipped through silently when numFlow == 0.
    if typeLayer not in (0, 1, 2, 3):
        raise ValueError("No such type")

    nn = 2 * n
    prior = source.Gaussian([nn]).to(torch.float64)

    # The two original branches differed only in the second half of the mask.
    fixMask = [0] * n + ([0] * n if relax else [1] * n)
    layers = [
        flow.DiagScaling(
            nn,
            initValue=0.1 * np.random.randn(nn),
            fix=fixMask,
            shift=shift,
        )
    ]

    if typeLayer == 0:
        layers.append(flow.Symplectic(nn))
    else:
        for _ in range(numFlow):
            # Type 3 appends both layers, point transformation first.
            if typeLayer in (1, 3):
                layers.append(flow.PointTransformation(innerBuilder(n)))
            if typeLayer in (2, 3):
                layers.append(flow.Symplectic(nn))

    return flow.FlowNet(layers, prior).double()
def extractPPrior(flowCon):
    """Build a one-layer FlowNet from the tail half of the leading scaling layer.

    The first layer of ``flowCon`` is read as a diagonal-scaling layer whose
    ``shift`` has length ``2 * nn``. The second half of its ``elements``
    (indices ``nn:``) initialises a fresh ``flow.DiagScaling(nn)`` that sits on
    a new ``nn``-dimensional Gaussian prior.
    (Presumably this is the momentum part of the prior -- confirm with callers.)

    Args:
        flowCon: Network exposing ``layerList``; ``layerList[0]`` must provide
            ``shift``, ``fix`` and ``elements`` tensors.

    Returns:
        A ``flow.FlowNet`` converted with ``.to(torch.float64)``.

    Raises:
        AssertionError: If the layer's ``shift`` does not sum to zero or its
            ``fix`` does not sum to ``nn``.
    """
    diagLayer = flowCon.layerList[0]
    half = diagLayer.shift.shape[0] // 2
    prior = source.Gaussian([half]).to(torch.float64)

    # Only a shift-free layer whose fix mask sums to `half` is accepted.
    assert diagLayer.shift.sum() == 0
    assert diagLayer.fix.sum() == half

    tailElements = diagLayer.elements.clone().detach()[half:]
    scaling = flow.DiagScaling(half, initValue=tailElements)
    return flow.FlowNet([scaling], prior).to(torch.float64)
def extractFlow(flowCon):
    """Build a FlowNet from the head half of the scaling layer plus the inner flow.

    The prior is a deep copy of ``flowCon.prior``. The layers are a new
    ``flow.DiagScaling`` initialised from the first half of
    ``flowCon.layerList[0].elements``, followed by a deep copy of
    ``flowCon.layerList[1].flow``.

    Args:
        flowCon: Network exposing ``prior`` and ``layerList``.

    Returns:
        A ``flow.FlowNet`` after calling ``.double()`` on it.
    """
    # NOTE(review): this file defines ``extractFlow`` a second time further
    # down; at import time that later definition replaces this one.
    from copy import deepcopy

    priorCopy = deepcopy(flowCon.prior)
    innerFlow = deepcopy(flowCon.layerList[1].flow)

    diagLayer = flowCon.layerList[0]
    half = diagLayer.shift.shape[0] // 2
    headElements = diagLayer.elements.clone().detach()[:half]
    scaling = flow.DiagScaling(half, initValue=headElements)

    return flow.FlowNet([scaling, innerFlow], priorCopy).double()
def extractFlow(flowCon):
    """Build a FlowNet from the head half of the scaling layer plus the inner flow.

    The first layer of ``flowCon`` is read as a diagonal-scaling layer whose
    ``shift`` has length ``2 * nn``. The first half of its ``elements``
    (indices ``:nn``) initialises a fresh ``flow.DiagScaling(nn)``. That layer
    is followed by a deep copy of ``flowCon.layerList[1].flow``, and both sit
    on a new ``nn``-dimensional Gaussian prior.

    Args:
        flowCon: Network exposing ``layerList``; ``layerList[0]`` must provide
            ``shift``, ``fix`` and ``elements`` tensors, and ``layerList[1]``
            must expose a ``flow`` attribute.

    Returns:
        A ``flow.FlowNet`` after calling ``.double()`` on it.

    Raises:
        AssertionError: If the scaling layer's ``shift`` does not sum to zero
            or its ``fix`` does not sum to ``nn``.
    """
    from copy import deepcopy

    innerFlow = deepcopy(flowCon.layerList[1].flow)
    diagLayer = flowCon.layerList[0]
    half = diagLayer.shift.shape[0] // 2
    prior = source.Gaussian([half]).to(torch.float64)

    # Only a shift-free layer whose fix mask sums to `half` is accepted.
    assert diagLayer.shift.sum() == 0
    assert diagLayer.fix.sum() == half

    headElements = diagLayer.elements.clone().detach()[:half]
    scaling = flow.DiagScaling(half, initValue=headElements)
    return flow.FlowNet([scaling, innerFlow], prior).double()
def _buildRandomNICE(dim=4, numLayers=4):
    """Create a ``flow.NICE`` layer with random, alternating binary masks.

    Even-indexed masks set a randomly chosen half of the ``dim`` coordinates
    to 1; each odd-indexed mask is the complement of the mask before it. The
    masks are built before the per-layer MLPs, matching the order in which the
    original inline code consumed random numbers.
    """
    masks = []
    for k in range(numLayers):
        if k % 2 == 0:
            mask = torch.zeros(1, dim)
            chosen = torch.randperm(dim).narrow(0, 0, dim // 2)
            mask[:, chosen] = 1
        else:
            # Complement of the mask produced on the preceding even step.
            mask = 1 - mask
        masks.append(mask)
    maskList = torch.cat(masks, 0).to(torch.float32)

    return flow.NICE(
        maskList,
        [
            utils.SimpleMLPreshape([dim, 32, 32, dim], [nn.ELU(), nn.ELU(), None])
            for _ in range(numLayers)
        ],
    )


def test_saveload():
    """Run ``saveload`` on two independently built Scaling + NICE networks.

    ``f`` uses ``flow.Scaling(4, [2, 3])`` and ``blankf`` uses
    ``flow.Scaling(4)``; each gets its own randomly initialised NICE layer and
    both share the same Gaussian prior object. ``saveload`` is defined outside
    this block; presumably it saves ``f`` and restores it into ``blankf`` --
    confirm against its definition.
    """
    p = source.Gaussian([4])

    # List elements are evaluated left to right, so each Scaling layer is
    # still constructed before its NICE layer, as in the original code.
    f = flow.FlowNet([flow.Scaling(4, [2, 3]), _buildRandomNICE()], p)
    blankf = flow.FlowNet([flow.Scaling(4), _buildRandomNICE()], p)

    saveload(f, blankf)
def buildSource(f):
    """Wrap the first layer of ``f`` and its prior in a new FlowNet.

    The layer and the prior are passed through without copying, so the
    returned network refers to the same objects as ``f``.

    Args:
        f: Network exposing ``layerList`` and ``prior``.

    Returns:
        A ``flow.FlowNet`` holding only ``f.layerList[0]``, after calling
        ``.double()`` on it.
    """
    firstLayer = f.layerList[0]
    sourceNet = flow.FlowNet([firstLayer], f.prior)
    return sourceNet.double()