def backward_attn(self, input: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
    # re-create the multi-scale packing (squeeze/split) that forward applied,
    # so each block receives a tensor of the shape it produced
    outputs = []
    out = input
    for i in range(self.levels - 1):
        if i > 0:
            out1, out2 = split2d(out, self.blocks[i].z_channels)
            outputs.append(out2)
            out = out1
        out = squeeze2d(out, factor=2)
        if self.squeeze_h:
            h = squeeze2d(h, factor=2)

    # run the blocks in reverse, undoing the squeeze/split between levels
    logdet_accum = input.new_zeros(input.size(0))
    for i, block in enumerate(reversed(self.blocks)):
        if i > 0:
            out = unsqueeze2d(out, factor=2)
            if self.squeeze_h:
                h = unsqueeze2d(h, factor=2)
            if i < self.levels - 1:
                out2 = outputs.pop()
                out = unsplit2d([out, out2])
        if i < self.levels - 1:
            out, logdet = block.backward(out, h=h)
        else:
            # the bottom block also returns its attention weights
            out, logdet, attn = block.backward_attn(out, h=h)
        logdet_accum = logdet_accum + logdet
    assert len(outputs) == 0
    return out, logdet_accum, attn
def forward_attn(self, input: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
    logdet_accum = input.new_zeros(input.size(0))
    out = input
    outputs = []
    attns = None
    for i, block in enumerate(self.blocks):
        if i == 0:
            # only the bottom block returns attention weights
            out, logdet, attns = block.forward_attn(out, h=h)
        else:
            out, logdet = block.forward(out, h=h)
        logdet_accum = logdet_accum + logdet
        if i < self.levels - 1:
            if i > 0:
                # split when block is neither bottom nor top
                out1, out2 = split2d(out, block.z_channels)
                outputs.append(out2)
                out = out1
            # squeeze when block is not top
            out = squeeze2d(out, factor=2)
            if self.squeeze_h:
                h = squeeze2d(h, factor=2)

    # undo the squeezes/splits so the output keeps the input's spatial layout
    out = unsqueeze2d(out, factor=2)
    for _ in range(self.internals):
        out2 = outputs.pop()
        out = unsqueeze2d(unsplit2d([out, out2]), factor=2)
    assert len(outputs) == 0
    return out, logdet_accum, attns
def init(self, data: torch.Tensor, h=None, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
    logdet_accum = data.new_zeros(data.size(0))
    out = data
    outputs = []
    for i, block in enumerate(self.blocks):
        out, logdet = block.init(out, h=h, init_scale=init_scale)
        logdet_accum = logdet_accum + logdet
        if i < self.levels - 1:
            if i > 0:
                # split when block is neither bottom nor top
                out1, out2 = split2d(out, block.z_channels)
                outputs.append(out2)
                out = out1
            # squeeze when block is not top
            out = squeeze2d(out, factor=2)
            if self.squeeze_h:
                h = squeeze2d(h, factor=2)

    out = unsqueeze2d(out, factor=2)
    for _ in range(self.internals):
        out2 = outputs.pop()
        out = unsqueeze2d(unsplit2d([out, out2]), factor=2)
    assert len(outputs) == 0
    return out, logdet_accum
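
# --- illustrative sketch (not part of the original code) -----------------
# The multi-scale methods above rely on split2d/unsplit2d and
# squeeze2d/unsqueeze2d being exact inverses. A minimal round-trip check,
# assuming these are the module-level helpers already used above; their
# exact channel conventions are an assumption (squeeze2d with factor=2
# typically maps [B, C, H, W] -> [B, 4C, H/2, W/2]).
def _check_2d_helpers_roundtrip():
    import torch
    x = torch.randn(2, 8, 16, 16)
    # squeeze trades spatial resolution for channels; unsqueeze undoes it
    assert torch.equal(unsqueeze2d(squeeze2d(x, factor=2), factor=2), x)
    # split cuts along the channel dimension; unsplit concatenates back
    z1, z2 = split2d(x, 3)
    assert torch.equal(unsplit2d([z1, z2]), x)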
def init(self, data, h=None, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
    # conv1x1
    out, logdet_accum = self.conv1x1.init(data, init_scale=init_scale)
    # coupling
    out, logdet = self.coupling.init(out, h=h, init_scale=init_scale)
    logdet_accum = logdet_accum + logdet
    # actnorm (applied to the second channel split only)
    out1, out2 = split2d(out, self.z1_channels)
    out2, logdet = self.actnorm.init(out2, init_scale=init_scale)
    logdet_accum = logdet_accum + logdet
    out = unsplit2d([out1, out2])
    return out, logdet_accum
def backward(self, input: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]:
    # actnorm
    out1, out2 = split2d(input, self.z1_channels)
    out2, logdet_accum = self.actnorm.backward(out2)
    out = unsplit2d([out1, out2])
    # coupling
    out, logdet = self.coupling.backward(out, h=h)
    logdet_accum = logdet_accum + logdet
    # conv1x1
    out, logdet = self.conv1x1.backward(out)
    logdet_accum = logdet_accum + logdet
    return out, logdet_accum
def forward_attn(self, input: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    attn = None
    # conv1x1
    out, logdet_accum = self.conv1x1.forward(input)
    # coupling (also returns the attention weights)
    out, logdet, attn = self.coupling.forward_attn(out, h=h)
    logdet_accum = logdet_accum + logdet
    # actnorm
    out1, out2 = split2d(out, self.z1_channels)
    out2, logdet = self.actnorm.forward(out2)
    logdet_accum = logdet_accum + logdet
    out = unsplit2d([out1, out2])
    return out, logdet_accum, attn
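
# --- illustrative sketch (not part of the original code) -----------------
# The unit above applies conv1x1 -> coupling -> actnorm (on the second
# channel split) in the forward direction and undoes them in reverse order
# in backward(). A small consistency check, assuming the unit also exposes
# the plain forward() counterpart of backward(); `step` and `x` are
# hypothetical (any instance of this unit and a [B, C, H, W] batch).
def _check_step_invertibility(step, x, h=None, atol=1e-5):
    import torch
    y, logdet_fwd = step.forward(x, h=h)
    x_rec, logdet_bwd = step.backward(y, h=h)
    # backward should reconstruct the input up to numerical error
    assert torch.allclose(x_rec, x, atol=atol)
    # if backward reports the log-determinant of the inverse map (the usual
    # convention), logdet_fwd + logdet_bwd should also be close to zero
    return x_rec, logdet_fwd, logdet_bwd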
def init(self, data, h=None, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
    out = data
    # [batch]
    logdet_accum = data.new_zeros(data.size(0))
    outputs = []
    for layer, prior in zip(self.layers, self.priors):
        for step in layer:
            out, logdet = step.init(out, h=h, init_scale=init_scale)
            logdet_accum = logdet_accum + logdet
        out, logdet = prior.init(out, h=h, init_scale=init_scale)
        logdet_accum = logdet_accum + logdet
        # split
        out1, out2 = split2d(out, prior.z1_channels)
        outputs.append(out2)
        out = out1

    outputs.append(out)
    outputs.reverse()
    out = unsplit2d(outputs)
    return out, logdet_accum
def forward(self, input: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]:
    out = input
    # [batch]
    logdet_accum = input.new_zeros(input.size(0))
    outputs = []
    for layer, prior in zip(self.layers, self.priors):
        for step in layer:
            out, logdet = step.forward(out, h=h)
            logdet_accum = logdet_accum + logdet
        out, logdet = prior.forward(out, h=h)
        logdet_accum = logdet_accum + logdet
        # split
        out1, out2 = split2d(out, prior.z1_channels)
        outputs.append(out2)
        out = out1

    outputs.append(out)
    outputs.reverse()
    out = unsplit2d(outputs)
    return out, logdet_accum
def backward(self, input: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]:
    # re-split the packed channels produced by forward
    out = input
    outputs = []
    for prior in self.priors:
        out1, out2 = split2d(out, prior.z1_channels)
        outputs.append(out2)
        out = out1

    # [batch]
    logdet_accum = out.new_zeros(out.size(0))
    for layer, prior in zip(reversed(self.layers), reversed(self.priors)):
        out2 = outputs.pop()
        out = unsplit2d([out, out2])
        out, logdet = prior.backward(out, h=h)
        logdet_accum = logdet_accum + logdet
        for step in reversed(layer):
            out, logdet = step.backward(out, h=h)
            logdet_accum = logdet_accum + logdet
    assert len(outputs) == 0
    return out, logdet_accum
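
# --- illustrative sketch (not part of the original code) -----------------
# Usage pattern implied by the init/forward/backward triple above, with
# `flow` a hypothetical instance of this block and `warmup_batch` a
# representative data batch: run the data-dependent init() once before
# training, then encode with forward() and invert with backward(). The
# no_grad wrapper and the tolerance are assumptions; deep models may need
# a looser atol.
def _warmup_and_roundtrip(flow, warmup_batch, h=None, atol=1e-4):
    import torch
    with torch.no_grad():
        flow.init(warmup_batch, h=h, init_scale=1.0)  # one-time data-dependent init
    z, logdet = flow.forward(warmup_batch, h=h)       # data -> latent, with log-det
    x_rec, _ = flow.backward(z, h=h)                  # latent -> data
    assert torch.allclose(x_rec, warmup_batch, atol=atol)
    return z, logdet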