def init(self, data, s=None, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """Data-dependent initialization pass through the multi-scale flow.

    Mirrors ``forward``: squeezes spatially before each internal/top block,
    splits off half of the channels (``z1_channels``) after each internal
    block, then re-assembles the factored-out tensors at the end so the
    returned tensor has the same spatial layout as ``data``.

    Args:
        data: input tensor (assumes NCHW layout — squeeze2d/split2d operate
            on spatial and channel dims; TODO confirm against squeeze2d).
        s: optional conditioning tensor, squeezed in lockstep with ``data``.
        init_scale: scale forwarded to each block's ``init``.

    Returns:
        Tuple of (output tensor, accumulated log-determinant per sample).
    """
    logdet_accum = data.new_zeros(data.size(0))
    out = data
    outputs = []
    for block in self.blocks:  # index was unused; iterate directly
        if isinstance(block, (MaCowInternalBlock, MaCowTopBlock)):
            # A new scale starts: trade spatial resolution for channels.
            if s is not None:
                s = squeeze2d(s, factor=2)
            out = squeeze2d(out, factor=2)
        out, logdet = block.init(out, s=s, init_scale=init_scale)
        logdet_accum = logdet_accum + logdet
        if isinstance(block, MaCowInternalBlock):
            # Factor out z2; only z1 flows on to the next scale.
            out1, out2 = split2d(out, block.z1_channels)
            outputs.append(out2)
            out = out1
    # Undo the squeezes in reverse order, re-attaching each factored tensor.
    out = unsqueeze2d(out, factor=2)
    for _ in range(self.internals):
        out2 = outputs.pop()
        out = unsqueeze2d(unsplit2d([out, out2]), factor=2)
    assert len(outputs) == 0
    return out, logdet_accum
def backward(self, input: torch.Tensor, s=None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Inverse pass through the multi-scale flow.

    First replays the forward squeeze/split schedule to recover the
    per-scale factored tensors from ``input``, then runs the blocks in
    reverse, re-attaching each factored tensor before its internal block
    and unsqueezing after each internal/top block.

    Args:
        input: tensor in the flow's output layout (assumes NCHW — TODO
            confirm against squeeze2d/split2d).
        s: optional conditioning tensor, squeezed/unsqueezed in lockstep.

    Returns:
        Tuple of (reconstructed input-space tensor, accumulated
        log-determinant per sample).
    """
    # Phase 1: replay the forward squeeze/split schedule to peel off the
    # tensors that were factored out at each internal block.
    outputs = []
    if s is not None:
        s = squeeze2d(s, factor=2)
    out = squeeze2d(input, factor=2)
    for block in self.blocks:
        if isinstance(block, MaCowInternalBlock):
            if s is not None:
                s = squeeze2d(s, factor=2)
            out1, out2 = split2d(out, block.z1_channels)
            outputs.append(out2)
            out = squeeze2d(out1, factor=2)
    # Phase 2: invert the blocks in reverse order.
    logdet_accum = input.new_zeros(input.size(0))
    for block in reversed(self.blocks):  # index was unused; iterate directly
        if isinstance(block, MaCowInternalBlock):
            # Re-attach the tensor factored out at this scale.
            out2 = outputs.pop()
            out = unsplit2d([out, out2])
        out, logdet = block.backward(out, s=s)
        logdet_accum = logdet_accum + logdet
        if isinstance(block, (MaCowInternalBlock, MaCowTopBlock)):
            # Leaving this scale: restore spatial resolution.
            if s is not None:
                s = unsqueeze2d(s, factor=2)
            out = unsqueeze2d(out, factor=2)
    assert len(outputs) == 0
    return out, logdet_accum
def backward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Inverse pass through the multi-scale Glow model.

    Replays the forward squeeze/split schedule (splitting channels in half
    at each of the ``levels - 1`` scale transitions) to recover the
    factored tensors, then inverts the blocks in reverse order,
    re-attaching each factored tensor before its internal block.

    Args:
        input: tensor in the flow's output layout (assumes NCHW — TODO
            confirm against squeeze2d/split2d).

    Returns:
        Tuple of (reconstructed input-space tensor, accumulated
        log-determinant per sample).
    """
    # Phase 1: peel off the per-scale factored tensors.
    outputs = []
    out = squeeze2d(input, factor=2)
    for _ in range(self.levels - 1):
        out1, out2 = split2d(out, out.size(1) // 2)
        outputs.append(out2)
        out = squeeze2d(out1, factor=2)
    # Phase 2: invert the blocks in reverse order.
    logdet_accum = input.new_zeros(input.size(0))
    for block in reversed(self.blocks):  # index was unused; iterate directly
        if isinstance(block, GlowInternalBlock):
            # Re-attach the tensor factored out at this scale.
            out2 = outputs.pop()
            out = unsplit2d([out, out2])
        out, logdet = block.backward(out)
        logdet_accum = logdet_accum + logdet
        out = unsqueeze2d(out, factor=2)
    assert len(outputs) == 0
    return out, logdet_accum
def init(self, data, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """Data-dependent initialization pass through the multi-scale Glow model.

    Mirrors ``forward``: squeezes spatially before every block, splits the
    channels in half after each internal block (factoring out one half),
    then re-assembles the factored tensors at the end so the returned
    tensor has the same spatial layout as ``data``.

    Args:
        data: input tensor (assumes NCHW layout — squeeze2d/split2d operate
            on spatial and channel dims; TODO confirm against squeeze2d).
        init_scale: scale forwarded to each block's ``init``.

    Returns:
        Tuple of (output tensor, accumulated log-determinant per sample).
    """
    logdet_accum = data.new_zeros(data.size(0))
    out = data
    outputs = []
    for block in self.blocks:  # index was unused; iterate directly
        out = squeeze2d(out, factor=2)
        out, logdet = block.init(out, init_scale=init_scale)
        logdet_accum = logdet_accum + logdet
        if isinstance(block, GlowInternalBlock):
            # Factor out half of the channels; the rest flows onward.
            out1, out2 = split2d(out, out.size(1) // 2)
            outputs.append(out2)
            out = out1
    # Undo the squeezes in reverse order, re-attaching each factored tensor.
    out = unsqueeze2d(out, factor=2)
    for _ in range(self.levels - 1):
        out2 = outputs.pop()
        out = unsqueeze2d(unsplit2d([out, out2]), factor=2)
    assert len(outputs) == 0
    return out, logdet_accum