def init(self, data, s=None, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the ``init`` pass of every block and reassemble the full tensor.

    Before each ``MaCowInternalBlock`` or ``MaCowTopBlock`` the running
    tensor (and ``s``, when given) is passed through ``squeeze2d`` with
    ``factor=2``. After each ``MaCowInternalBlock`` the result is split at
    ``block.z1_channels``; the second part is set aside and the first part
    continues through the remaining blocks. Once all blocks have run, the
    set-aside parts are merged back in reverse order, with an
    ``unsqueeze2d`` after every merge.

    Args:
        data: Input tensor. ``data.size(0)`` sizes the accumulator
            (presumably the batch dimension -- confirm against callers).
        s: Optional tensor forwarded to every block as ``s=``; ``None``
            skips all handling of it.
        init_scale: Forwarded unchanged to every ``block.init`` call.

    Returns:
        ``(out, logdet_accum)``: the reassembled tensor and the sum of the
        second value returned by each ``block.init``.
    """
    squeezing_blocks = (MaCowInternalBlock, MaCowTopBlock)
    total_logdet = data.new_zeros(data.size(0))
    stash = []
    h = data
    for blk in self.blocks:
        if isinstance(blk, squeezing_blocks):
            if s is not None:
                s = squeeze2d(s, factor=2)
            h = squeeze2d(h, factor=2)
        h, blk_logdet = blk.init(h, s=s, init_scale=init_scale)
        # Out-of-place add on purpose: never mutate the accumulator tensor.
        total_logdet = total_logdet + blk_logdet
        if isinstance(blk, MaCowInternalBlock):
            # Keep the first part flowing; park the second part for later.
            h, parked = split2d(h, blk.z1_channels)
            stash.append(parked)

    h = unsqueeze2d(h, factor=2)
    # One merge per internal block, most recently parked part first.
    for _ in range(self.internals):
        parked = stash.pop()
        merged = unsplit2d([h, parked])
        h = unsqueeze2d(merged, factor=2)
    assert not stash
    return h, total_logdet
def backward(self, input: torch.Tensor, s=None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run every block's ``backward`` in reverse block order.

    The method works in two phases:

    1. ``input`` (and ``s``, when given) is squeezed once; then, for each
       ``MaCowInternalBlock`` in forward order, the tensor is split at
       ``block.z1_channels``, the second part is set aside, and the first
       part is squeezed again (``s`` is squeezed alongside).
    2. Blocks are visited in reverse. An internal block first gets its
       set-aside part merged back in, then ``block.backward`` runs. After
       an internal or top block the tensor and ``s`` are unsqueezed.

    Args:
        input: Tensor to push backwards through the blocks.
            ``input.size(0)`` sizes the accumulator (presumably the batch
            dimension -- confirm against callers).
        s: Optional tensor forwarded to every block as ``s=``.

    Returns:
        ``(out, logdet_accum)``: the resulting tensor and the sum of the
        second value returned by each ``block.backward``.
    """
    stash = []
    if s is not None:
        s = squeeze2d(s, factor=2)
    h = squeeze2d(input, factor=2)
    for blk in self.blocks:
        if not isinstance(blk, MaCowInternalBlock):
            continue
        if s is not None:
            s = squeeze2d(s, factor=2)
        h, parked = split2d(h, blk.z1_channels)
        stash.append(parked)
        h = squeeze2d(h, factor=2)

    total_logdet = input.new_zeros(input.size(0))
    for blk in reversed(self.blocks):
        is_internal = isinstance(blk, MaCowInternalBlock)
        if is_internal:
            # Restore the part that was split off for this block.
            parked = stash.pop()
            h = unsplit2d([h, parked])
        h, blk_logdet = blk.backward(h, s=s)
        # Out-of-place add on purpose: never mutate the accumulator tensor.
        total_logdet = total_logdet + blk_logdet
        if is_internal or isinstance(blk, MaCowTopBlock):
            if s is not None:
                s = unsqueeze2d(s, factor=2)
            h = unsqueeze2d(h, factor=2)
    assert not stash
    return h, total_logdet
def init(self, data, s=None, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the ``init`` pass through every (layer, prior) pair.

    For each pair, every step of the layer is initialised in order, then
    the prior. The prior's output is split at ``prior.z1_channels``: the
    second part is collected and the first part feeds the next pair. At
    the end the remaining first part is appended, the collected list is
    reversed, and everything is merged with ``unsplit2d``.

    The return annotation previously read
    ``Tuple[Tuple[Tensor, Tensor], Tensor]``, which did not match the value
    actually returned nor the structurally identical ``forward``; it now
    describes the real ``(tensor, tensor)`` result.

    Args:
        data: Input tensor; ``data.size(0)`` is the batch size used for
            the accumulator.
        s: Optional tensor forwarded as ``s=`` to every step and prior.
        init_scale: Forwarded unchanged to every ``init`` call.

    Returns:
        ``(out, logdet_accum)``: the merged tensor and the sum of the
        second value returned by each step/prior ``init``.
    """
    out = data
    # [batch]
    logdet_accum = data.new_zeros(data.size(0))
    outputs = []
    for layer, prior in zip(self.layers, self.priors):
        for step in layer:
            out, logdet = step.init(out, s=s, init_scale=init_scale)
            logdet_accum = logdet_accum + logdet
        out, logdet = prior.init(out, s=s, init_scale=init_scale)
        logdet_accum = logdet_accum + logdet
        # split: first part continues, second part is kept for the merge
        out1, out2 = split2d(out, prior.z1_channels)
        outputs.append(out2)
        out = out1

    # The innermost remainder goes last, so after reversing it comes first.
    outputs.append(out)
    outputs.reverse()
    out = unsplit2d(outputs)
    return out, logdet_accum
def init(self, data, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the ``init`` pass of every block and reassemble the full tensor.

    Each block receives the running tensor after ``squeeze2d`` with
    ``factor=2``. After a ``GlowInternalBlock`` the result is split at half
    of its ``size(1)``; the second half is set aside and the first half
    continues. After the loop the tensor is unsqueezed once, then
    ``self.levels - 1`` set-aside halves are merged back in reverse order,
    each merge followed by another ``unsqueeze2d``.

    Args:
        data: Input tensor. ``data.size(0)`` sizes the accumulator
            (presumably the batch dimension -- confirm against callers).
        init_scale: Forwarded unchanged to every ``block.init`` call.

    Returns:
        ``(out, logdet_accum)``: the reassembled tensor and the sum of the
        second value returned by each ``block.init``.
    """
    h = data
    total_logdet = data.new_zeros(data.size(0))
    stash = []
    for blk in self.blocks:
        squeezed = squeeze2d(h, factor=2)
        h, blk_logdet = blk.init(squeezed, init_scale=init_scale)
        # Out-of-place add on purpose: never mutate the accumulator tensor.
        total_logdet = total_logdet + blk_logdet
        if isinstance(blk, GlowInternalBlock):
            # Keep the first half flowing; park the second half.
            h, parked = split2d(h, h.size(1) // 2)
            stash.append(parked)

    h = unsqueeze2d(h, factor=2)
    for _ in range(self.levels - 1):
        parked = stash.pop()
        merged = unsplit2d([h, parked])
        h = unsqueeze2d(merged, factor=2)
    assert not stash
    return h, total_logdet
def backward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run every block's ``backward`` in reverse block order.

    First ``input`` is squeezed and then, ``self.levels - 1`` times, split
    at half of its ``size(1)``: the second half is set aside and the first
    half is squeezed again. Blocks are then visited in reverse; a
    ``GlowInternalBlock`` gets its set-aside half merged back before its
    ``backward`` runs, and the tensor is unsqueezed after every block.

    Args:
        input: Tensor to push backwards through the blocks.
            ``input.size(0)`` sizes the accumulator (presumably the batch
            dimension -- confirm against callers).

    Returns:
        ``(out, logdet_accum)``: the resulting tensor and the sum of the
        second value returned by each ``block.backward``.
    """
    stash = []
    h = squeeze2d(input, factor=2)
    for _ in range(self.levels - 1):
        h, parked = split2d(h, h.size(1) // 2)
        stash.append(parked)
        h = squeeze2d(h, factor=2)

    total_logdet = input.new_zeros(input.size(0))
    for blk in reversed(self.blocks):
        if isinstance(blk, GlowInternalBlock):
            # Restore the half that was split off for this block.
            parked = stash.pop()
            h = unsplit2d([h, parked])
        h, blk_logdet = blk.backward(h)
        # Out-of-place add on purpose: never mutate the accumulator tensor.
        total_logdet = total_logdet + blk_logdet
        h = unsqueeze2d(h, factor=2)
    assert not stash
    return h, total_logdet
def forward(self, input: torch.Tensor, s=None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Push ``input`` through every (layer, prior) pair.

    For each pair, all steps of the layer run in order, followed by the
    prior. The prior's output is split at ``prior.z1_channels``; the second
    part is collected and the first part feeds the next pair. Finally the
    remaining first part and the collected parts are merged with
    ``unsplit2d``, remainder first and collected parts in reverse order of
    collection.

    Args:
        input: Input tensor; ``input.size(0)`` is the batch size used for
            the accumulator.
        s: Optional tensor forwarded as ``s=`` to every step and prior.

    Returns:
        ``(out, logdet_accum)``: the merged tensor and the sum of the
        second value returned by each step/prior ``forward``.
    """
    h = input
    # [batch]
    total_logdet = input.new_zeros(input.size(0))
    pieces = []
    for steps, prior in zip(self.layers, self.priors):
        for flow in steps:
            h, flow_logdet = flow.forward(h, s=s)
            total_logdet = total_logdet + flow_logdet
        h, prior_logdet = prior.forward(h, s=s)
        total_logdet = total_logdet + prior_logdet
        # split: first part continues, second part is kept for the merge
        h, kept = split2d(h, prior.z1_channels)
        pieces.append(kept)

    pieces.append(h)
    # Reversed copy: the final remainder first, earliest split-off part last.
    merged = unsplit2d(pieces[::-1])
    return merged, total_logdet
def backward(self, input: torch.Tensor, s=None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Invert the (layer, prior) pipeline on ``input``.

    ``input`` is first peeled apart: for each prior in forward order the
    tensor is split at ``prior.z1_channels``, the second part is set aside
    and the first part is split further. The pairs are then visited in
    reverse; each visit merges the most recently set-aside part back in,
    runs ``prior.backward``, and then runs the layer's steps in reverse.

    Args:
        input: Tensor to push backwards through the pairs.
        s: Optional tensor forwarded as ``s=`` to every prior and step.

    Returns:
        ``(out, logdet_accum)``: the resulting tensor and the sum of the
        second value returned by each prior/step ``backward``.
    """
    h = input
    stash = []
    for prior in self.priors:
        h, parked = split2d(h, prior.z1_channels)
        stash.append(parked)

    # [batch] -- sized from the innermost remainder, as in the original
    total_logdet = h.new_zeros(h.size(0))
    for steps, prior in zip(reversed(self.layers), reversed(self.priors)):
        parked = stash.pop()
        h = unsplit2d([h, parked])
        h, prior_logdet = prior.backward(h, s=s)
        total_logdet = total_logdet + prior_logdet
        for flow in reversed(steps):
            h, flow_logdet = flow.backward(h, s=s)
            total_logdet = total_logdet + flow_logdet

    assert not stash
    return h, total_logdet