def forward_attn(self, input: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
    """Multi-scale forward pass that also returns the attention of the first block.

    Follows the same split/squeeze schedule as the plain forward pass. The one
    difference is that ``self.blocks[0]`` is invoked through its ``forward_attn``
    method, so its attention output can be returned. Every other block is
    invoked through ``forward``.

    Args:
        input: batch tensor; ``input.size(0)`` is used as the batch size for
            the log-determinant accumulator.
        h: optional conditioning passed to every block. When
            ``self.squeeze_h`` is set it is squeezed together with the
            activations.

    Returns:
        ``(out, logdet_accum, attns)``:
        - ``out`` is the re-assembled output after undoing the final squeeze
          and re-joining the factored-out parts.
        - ``logdet_accum`` is the sum of the per-block log-determinants. It
          starts as zeros of length ``input.size(0)``.
        - ``attns`` is whatever ``self.blocks[0].forward_attn`` returned, or
          ``None`` if ``self.blocks`` is empty.
    """
    logdet_accum = input.new_zeros(input.size(0))
    out = input
    outputs = []
    attns = None
    for i, block in enumerate(self.blocks):
        if i == 0:
            # Only the bottom block exposes its attention.
            out, logdet, attns = block.forward_attn(out, h=h)
        else:
            out, logdet = block.forward(out, h=h)
        logdet_accum = logdet_accum + logdet
        if i < self.levels - 1:
            if i > 0:
                # split when block is not bottom or top
                out1, out2 = split2d(out, block.z_channels)
                outputs.append(out2)
                out = out1
            # squeeze when block is not top
            out = squeeze2d(out, factor=2)
            if self.squeeze_h:
                h = squeeze2d(h, factor=2)

    # Undo the squeezes and re-attach the factored-out halves, innermost first.
    out = unsqueeze2d(out, factor=2)
    for _ in range(self.internals):
        out2 = outputs.pop()
        out = unsqueeze2d(unsplit2d([out, out2]), factor=2)
    assert len(outputs) == 0
    return out, logdet_accum, attns
def backward_attn(self, input: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
    """Multi-scale inverse pass that also returns attention from the last-inverted block.

    First re-creates the split/squeeze structure that the forward pass
    produced. It then inverts the blocks top-down (``reversed(self.blocks)``),
    unsqueezing and re-joining the factored-out parts along the way.

    The block visited at ``i == self.levels - 1`` is called through
    ``backward_attn``. When ``len(self.blocks) == self.levels`` this is
    presumably ``self.blocks[0]``, the counterpart of ``forward_attn`` (verify
    against the caller). All other blocks are called through ``backward``.

    Args:
        input: batch tensor; ``input.size(0)`` is used as the batch size for
            the log-determinant accumulator.
        h: optional conditioning passed to every block. When
            ``self.squeeze_h`` is set it is squeezed and unsqueezed in step
            with the activations.

    Returns:
        ``(out, logdet_accum, attn)``:
        - ``attn`` is what ``block.backward_attn`` returned, or ``None`` if
          that branch never ran (e.g. ``self.blocks`` is empty).
        - The annotation mirrors ``forward_attn``. The actual attention type
          is whatever the block returns, so confirm it against the block
          implementation.
    """
    outputs = []
    out = input
    # Re-create the forward pass's split/squeeze layout before inverting.
    for i in range(self.levels - 1):
        if i > 0:
            out1, out2 = split2d(out, self.blocks[i].z_channels)
            outputs.append(out2)
            out = out1
        out = squeeze2d(out, factor=2)
        if self.squeeze_h:
            h = squeeze2d(h, factor=2)

    logdet_accum = input.new_zeros(input.size(0))
    # Bind up front: the attention branch below may never run, and returning
    # an unbound name would raise UnboundLocalError.
    attn = None
    for i, block in enumerate(reversed(self.blocks)):
        if i > 0:
            out = unsqueeze2d(out, factor=2)
            if self.squeeze_h:
                h = unsqueeze2d(h, factor=2)
            if i < self.levels - 1:
                # Re-attach the part split off at this level in the layout loop.
                out2 = outputs.pop()
                out = unsplit2d([out, out2])
        if i < self.levels - 1:
            out, logdet = block.backward(out, h=h)
        else:
            out, logdet, attn = block.backward_attn(out, h=h)
        logdet_accum = logdet_accum + logdet
    assert len(outputs) == 0
    return out, logdet_accum, attn
def init(self, data: torch.Tensor, h=None, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """Initialisation pass through the multi-scale stack.

    Calls ``init`` on every block, in order, with the current activations. It
    uses the same split/squeeze schedule as the forward pass, then undoes the
    final squeeze and re-joins the factored-out parts.

    Args:
        data: batch tensor; ``data.size(0)`` is used as the batch size for
            the log-determinant accumulator.
        h: optional conditioning passed to every block. When
            ``self.squeeze_h`` is set it is squeezed together with the
            activations.
        init_scale: forwarded unchanged to each ``block.init``.

    Returns:
        ``(out, logdet)``, where ``logdet`` is the sum of the per-block
        log-determinants.
    """
    z = data
    total_logdet = data.new_zeros(data.size(0))
    factored = []
    top = self.levels - 1
    for level, blk in enumerate(self.blocks):
        z, level_logdet = blk.init(z, h=h, init_scale=init_scale)
        total_logdet = total_logdet + level_logdet
        if level >= top:
            # The top block is neither split nor squeezed.
            continue
        if level > 0:
            # Intermediate levels factor out part of their channels; the bottom
            # level does not.
            z, z_out = split2d(z, blk.z_channels)
            factored.append(z_out)
        z = squeeze2d(z, factor=2)
        if self.squeeze_h:
            h = squeeze2d(h, factor=2)

    # Undo the squeezes, re-attaching factored parts in reverse order.
    z = unsqueeze2d(z, factor=2)
    for _ in range(self.internals):
        z = unsqueeze2d(unsplit2d([z, factored.pop()]), factor=2)
    assert not factored
    return z, total_logdet