def forward_step(self, x, squeeze_factor, es=None):
    """Run one level of the flow: squeeze, apply every flow step, optionally split.

    Args:
        x: Input to this level.
        squeeze_factor: Factor passed to ``squeeze``/``unsqueeze`` for this level.
        es: Optional noise added to the flow output before the channel split.
            It is only used when ``self.split_output`` is true; otherwise it
            is ignored.

    Returns:
        A tuple ``(out, (zi, mean, ln_var), sum_logdet)``:
        ``out`` is the half passed on to the next level (``None`` when
        ``self.split_output`` is false); ``zi`` is the latent split off at this
        level; ``mean``/``ln_var`` are the two channel halves of
        ``self.prior(prior_in)``; ``sum_logdet`` is the sum of the per-flow
        log-determinants.
    """
    out = squeeze(x, factor=squeeze_factor)
    sum_logdet = 0
    for flow in self.flows:
        out, logdet = flow.forward_step(out)
        sum_logdet += logdet

    if self.split_output:
        if es is not None:
            # Add the noise in the unsqueezed layout, then re-squeeze with the
            # SAME factor so the tensor returns to the shape the flows produced.
            out = unsqueeze(out, factor=squeeze_factor)
            # Out-of-place add: never mutate a buffer that may be shared.
            out = out + es
            out = squeeze(out, factor=squeeze_factor)
        zi, out = split_channel(out)
        # The prior for zi is conditioned on the half that continues onward.
        prior_in = out
    else:
        # No split: the whole tensor is the latent and the prior gets zeros.
        zi = out
        out = None
        prior_in = zeros_like(zi)

    z_distribution = self.prior(prior_in)
    mean, ln_var = split_channel(z_distribution)
    return out, (zi, mean, ln_var), sum_logdet
def forward_blocks(x, encoder: InferenceModel, decoder: GenerativeModel,
                   squeeze_factor=2):
    """Push ``x`` through every level, splitting off one latent per level.

    Encoder block ``i`` is paired with decoder block ``len - 1 - i`` (the decoder
    blocks are iterated in reverse) and both are handed to ``forward_flows``.

    Args:
        x: Input tensor; presumably channel axis 1 (it is split via
            ``out[:, :n // 2]``) — TODO confirm layout.
        encoder: Model whose ``blocks`` are walked in order.
        decoder: Model whose ``blocks`` are walked in reverse order.
        squeeze_factor: Factor passed to ``squeeze`` at every level
            (defaults to 2, the previously hard-coded value).

    Returns:
        ``(z, sum_logdet)``: ``z`` is the list of per-level latents (the first
        channel half at each non-final level, the whole tensor at the final
        level); ``sum_logdet`` is the sum of the log-determinants returned by
        ``forward_flows``.

    Raises:
        ValueError: if ``encoder`` and ``decoder`` have different numbers of
            blocks. Otherwise ``zip`` would silently truncate, and the last
            latent would be lost.
    """
    num_levels = len(encoder.blocks)
    if num_levels != len(decoder.blocks):
        raise ValueError(
            "encoder and decoder must have the same number of blocks "
            f"(got {num_levels} and {len(decoder.blocks)})")

    z = []
    sum_logdet = 0
    out = x
    for level, (block_enc, block_gen) in enumerate(
            zip(encoder.blocks, reversed(decoder.blocks))):
        out = squeeze(out, factor=squeeze_factor)
        out, logdet = forward_flows(out, block_enc, block_gen)
        sum_logdet += logdet
        if level == num_levels - 1:
            # Final level: nothing is split off; the whole tensor is the latent.
            z.append(out)
        else:
            # Split along axis 1: the first half becomes this level's latent,
            # the second half continues to the next level.
            n = out.shape[1]
            zi, out = out[:, :n // 2], out[:, n // 2:]
            z.append(zi)
    return z, sum_logdet
def factor_z(self, z):
    """Decompose a latent ``z`` into the list of per-level latents.

    At each of ``self.hyperparams.levels`` levels, ``z`` is squeezed. At every
    level but the last it is then split with ``split_channel``: the first part
    is kept for this level and the second part carries on. At the last level
    the whole squeezed tensor is kept.

    Returns:
        List of per-level latents, ordered from the first level to the last.
    """
    levels = self.hyperparams.levels
    # Squeeze with the model's configured factor when one exists, so this
    # decomposition matches the squeezes done elsewhere with
    # ``hyperparams.squeeze_factor``. Fall back to squeeze's own default
    # otherwise, which was the previous behavior.
    squeeze_factor = getattr(self.hyperparams, "squeeze_factor", None)

    factorized_z = []
    for level in range(levels):
        if squeeze_factor is None:
            z = squeeze(z)
        else:
            z = squeeze(z, factor=squeeze_factor)
        if level == levels - 1:
            factorized_z.append(z)
        else:
            zi, z = split_channel(z)
            factorized_z.append(zi)
    return factorized_z
def initialize_actnorm_weights(self, x, reduce_memory=False, eps=1e-6):
    """Data-dependent initialization of every flow's actnorm parameters.

    ``x`` is propagated through all blocks exactly as in the forward pass
    (squeeze, flows, then a channel split between blocks). Just before each
    flow runs, that flow's actnorm ``bias`` is set to ``-mean`` and its
    ``scale`` to ``1 / (std + eps)``. Both statistics come from the tensor
    entering the flow, so the flow sees its input standardized per channel.

    Args:
        x: Input batch. It is indexed with ``axis=(0, 2, 3)``, so it is
            presumably 4-D with channels on axis 1 — TODO confirm.
        reduce_memory: Forwarded unchanged to each ``flow.forward_step``.
        eps: Added to the standard deviation so a constant channel
            (std == 0) yields a finite scale instead of ``inf``.
    """
    xp = cuda.get_array_module(x)
    levels = len(self.blocks)
    out = x
    for level, block in enumerate(self.blocks):
        out = squeeze(out, factor=self.hyperparams.squeeze_factor)
        for flow in block.flows:
            # Statistics reduced over axes 0, 2, 3 (kept as size-1 dims so they
            # broadcast against the parameters).
            mean = xp.mean(out.data, axis=(0, 2, 3), keepdims=True)
            std = xp.std(out.data, axis=(0, 2, 3), keepdims=True)
            flow.actnorm.scale.data = 1.0 / (std + eps)
            flow.actnorm.bias.data = -mean
            # Run the freshly-initialized flow so the next flow is initialized
            # on the distribution it will actually receive.
            out, _ = flow.forward_step(out, reduce_memory=reduce_memory)
        if level < levels - 1:
            # Keep only the half that continues to the next block.
            _, out = split_channel(out)
def forward_step(self, x, squeeze_factor, reduce_memory=False):
    """Apply one level: squeeze, run all flows, then split or pass through.

    Args:
        x: Input to this level.
        squeeze_factor: Factor handed to ``squeeze``.
        reduce_memory: Forwarded to every ``flow.forward_step``.

    Returns:
        ``(out, (zi, mean, ln_var), sum_logdet)``, where ``out`` is ``None``
        when ``self.split_output`` is false. ``mean`` and ``ln_var`` are the
        two channel halves of ``self.prior`` evaluated on the remaining half,
        or on zeros when there is no split.
    """
    h = squeeze(x, factor=squeeze_factor)
    total_logdet = 0
    for flow in self.flows:
        h, step_logdet = flow.forward_step(h, reduce_memory=reduce_memory)
        total_logdet += step_logdet

    if not self.split_output:
        # Nothing continues onward; the prior is evaluated on zeros.
        zi, h = h, None
        prior_in = zeros_like(zi)
    else:
        zi, h = split_channel(h)
        prior_in = h

    mean, ln_var = split_channel(self.prior(prior_in))
    return h, (zi, mean, ln_var), total_logdet