Ejemplo n.º 1
0
    def forward_step(self, x, squeeze_factor, es=None):
        """Run one multi-scale level of the flow, optionally injecting noise.

        The input is squeezed by ``squeeze_factor`` and passed through every
        flow step in ``self.flows``, summing their log-determinants. The
        result is then either split with ``split_channel`` (when
        ``self.split_output``) or emitted whole as the final latent.

        Args:
            x: Input to this level.
            squeeze_factor: Factor used for both ``squeeze`` and ``unsqueeze``.
            es: Optional noise term. It is only used when
                ``self.split_output`` is true. When given, it is added to the
                flow output after that output has been un-squeezed. It is
                presumably expected to have the un-squeezed shape (like
                ``x``) — TODO confirm.

        Returns:
            ``(out, (zi, mean, ln_var), sum_logdet)``:

            - ``out`` is the part passed on to the next level (``None`` when
              there is no split).
            - ``zi`` is the latent produced at this level.
            - ``mean``/``ln_var`` are the two outputs of ``split_channel``
              applied to ``self.prior``'s output.
            - ``sum_logdet`` is the accumulated log-determinant of all flow
              steps.
        """
        sum_logdet = 0
        out = squeeze(x, factor=squeeze_factor)

        for flow in self.flows:
            out, logdet = flow.forward_step(out)
            sum_logdet += logdet

        if self.split_output:
            if es is not None:
                # Add the noise at the un-squeezed resolution, then re-squeeze
                # with the SAME factor. Using squeeze()'s default factor here
                # would produce a different layout whenever squeeze_factor
                # differs from that default.
                out = unsqueeze(out, factor=squeeze_factor)
                out = out + es
                out = squeeze(out, factor=squeeze_factor)
            zi, out = split_channel(out)
            # The prior parameters returned with zi are computed from the
            # part that continues to the next level.
            prior_in = out
        else:
            # No split: everything becomes the latent and the prior is
            # evaluated on zeros shaped like it.
            zi = out
            out = None
            prior_in = zeros_like(zi)

        z_distribution = self.prior(prior_in)
        mean, ln_var = split_channel(z_distribution)

        return out, (zi, mean, ln_var), sum_logdet
Ejemplo n.º 2
0
def forward_blocks(x, encoder: InferenceModel, decoder: GenerativeModel):
    """Push ``x`` through paired encoder/decoder blocks, factoring out latents.

    The i-th encoder block is paired with the i-th block of
    ``reversed(decoder.blocks)``. At every level the running tensor is
    squeezed with factor 2 and passed to ``forward_flows`` together with both
    blocks of the pair.

    At every level except the one whose index is ``len(encoder.blocks) - 1``,
    the tensor is cut at ``shape[1] // 2`` along axis 1. The first part is
    recorded as that level's latent and the second part continues to the
    next level. At that last index, the whole tensor is recorded.

    Returns:
        ``(z, sum_logdet)``: the list of per-level latents in level order
        and the summed log-determinant returned by ``forward_flows``.
    """
    latents = []
    total_logdet = 0
    h = x
    last_level = len(encoder.blocks) - 1
    paired_blocks = zip(encoder.blocks, reversed(decoder.blocks))

    for level, (enc_block, gen_block) in enumerate(paired_blocks):
        h = squeeze(h, factor=2)

        h, level_logdet = forward_flows(h, enc_block, gen_block)
        total_logdet += level_logdet

        if level < last_level:
            half = h.shape[1] // 2
            latents.append(h[:, :half])
            h = h[:, half:]
        else:
            latents.append(h)

    return latents, total_logdet
Ejemplo n.º 3
0
 def factor_z(self, z):
     """Break ``z`` into one latent per level of the multi-scale model.

     ``z`` is squeezed ``self.hyperparams.levels`` times.

     - After each squeeze except the last, the first output of
       ``split_channel`` is kept as that level's factor, and processing
       continues with the second output.
     - At the last level, the whole squeezed tensor is kept.

     Returns:
         A list with one entry per level, in level order.
     """
     # NOTE(review): squeeze() is called with its default factor here, while
     # initialize_actnorm_weights squeezes by self.hyperparams.squeeze_factor;
     # confirm the two agree for this model.
     num_levels = self.hyperparams.levels
     factors = []
     for level in range(num_levels):
         z = squeeze(z)
         if level < num_levels - 1:
             head, z = split_channel(z)
             factors.append(head)
         else:
             factors.append(z)
     return factors
Ejemplo n.º 4
0
    def initialize_actnorm_weights(self, x, reduce_memory=False, eps=1e-6):
        """Data-dependent initialization of every actnorm layer.

        ``x`` is propagated through ``self.blocks`` in order. At each level
        the activation is first squeezed by ``self.hyperparams.squeeze_factor``.

        Before each flow step, the mean and standard deviation of the current
        activation are taken over axes (0, 2, 3). They are written into that
        step's actnorm as ``scale = 1 / std`` and ``bias = -mean``, and the
        activation is then passed through the step.

        After every block except the last, only the second output of
        ``split_channel`` is carried on.

        Args:
            x: A batch of data. Reducing over axes (0, 2, 3) presumably means
                NCHW layout with per-channel statistics — TODO confirm.
            reduce_memory: Forwarded to each flow step's ``forward_step``.
            eps: Lower bound applied to ``std`` before taking its reciprocal.
                Without it, a zero-variance slice gives ``scale = inf``, which
                then poisons every later step's statistics.
        """
        xp = cuda.get_array_module(x)
        levels = len(self.blocks)
        out = x

        for level, block in enumerate(self.blocks):
            out = squeeze(out, factor=self.hyperparams.squeeze_factor)

            for flow in block.flows:
                mean = xp.mean(out.data, axis=(0, 2, 3), keepdims=True)
                std = xp.std(out.data, axis=(0, 2, 3), keepdims=True)
                # Guard the reciprocal below against zero (or tiny) variance.
                std = xp.maximum(std, eps)

                flow.actnorm.scale.data = 1.0 / std
                flow.actnorm.bias.data = -mean

                out, _ = flow.forward_step(out, reduce_memory=reduce_memory)

            if level < levels - 1:
                _, out = split_channel(out)
Ejemplo n.º 5
0
    def forward_step(self, x, squeeze_factor, reduce_memory=False):
        """Apply one level of the multi-scale flow.

        ``x`` is squeezed by ``squeeze_factor`` and then run through each flow
        in ``self.flows`` (forwarding ``reduce_memory``), summing their
        log-determinants. The result is then either split with
        ``split_channel`` (when ``self.split_output``) or emitted whole as the
        last latent.

        Returns:
            ``(out, (zi, mean, ln_var), sum_logdet)``:

            - ``out`` is the part that continues to the next level (``None``
              when there is no split).
            - ``zi`` is the latent produced here.
            - ``mean``/``ln_var`` are the two outputs of ``split_channel``
              applied to ``self.prior`` of either ``out`` or zeros shaped
              like ``zi``.
            - ``sum_logdet`` is the accumulated log-determinant.
        """
        h = squeeze(x, factor=squeeze_factor)

        total_logdet = 0
        for step in self.flows:
            h, step_logdet = step.forward_step(h, reduce_memory=reduce_memory)
            total_logdet += step_logdet

        if not self.split_output:
            # No split: the whole tensor is the latent; the prior sees zeros.
            zi = h
            next_in = None
            prior_in = zeros_like(zi)
        else:
            zi, next_in = split_channel(h)
            prior_in = next_in

        mean, ln_var = split_channel(self.prior(prior_in))
        return next_in, (zi, mean, ln_var), total_logdet