예제 #1
0
    def setup(self):
        """Builds the input convolution and the ordered stack of ResBlocks.

        `self.downsampling_rates` is sorted and read as `(block_idx, rate)`
        pairs. Pairs whose rate is 1 are ignored. Every other pair places a
        `ResBlock` with `downsampling_rate=rate` at position `block_idx`.
        The positions before it that have not been filled yet get plain
        `ResBlock`s with `downsampling_rate=1`. After the last downsampling
        block, plain blocks are added for the remaining positions up to
        `self.num_blocks`.

        Every block is named `res_block_{position}`, uses a residual
        connection, and has `last_weights_scale=np.sqrt(1.0 / self.num_blocks)`.
        The resulting list is stored in `self._blocks`.
        """
        self._in_conv = blocks_lib.get_vdvae_convolution(
            self.num_channels, (3, 3),
            name='in_conv',
            precision=self.precision)

        ordered_rates = sorted(self.downsampling_rates)
        total_blocks = self.num_blocks

        def make_res_block(position, downsampling_rate):
            # All blocks share widths, residual connection and final-layer
            # weight scale; only the downsampling rate and the name differ.
            # NOTE(review): the sqrt(1 / num_blocks) scale presumably keeps the
            # variance of the summed residual stream bounded; confirm in
            # blocks_lib.ResBlock.
            return blocks_lib.ResBlock(
                self.bottlenecked_num_channels,
                self.num_channels,
                downsampling_rate=downsampling_rate,
                use_residual_connection=True,
                last_weights_scale=np.sqrt(1.0 / self.num_blocks),
                precision=self.precision,
                name=f'res_block_{position}')

        stack = []
        next_position = 0
        for position, factor in ordered_rates:
            if factor == 1:
                continue
            # Fill the gap before this downsampling block with plain blocks;
            # the range is empty when there is no gap.
            stack.extend(
                make_res_block(j, 1) for j in range(next_position, position))
            stack.append(make_res_block(position, factor))
            next_position = position + 1

        # Plain blocks after the last downsampling block, up to num_blocks.
        stack.extend(
            make_res_block(j, 1) for j in range(next_position, total_blocks))

        self._blocks = stack
예제 #2
0
 def _compute_posterior(
     self,
     inputs,
     encoder_outputs,
     context_vectors,
 ):
     """Computes the posterior branch of the DecoderBlock.

     Args:
       inputs: rank-4 array (checked with `chex.assert_rank`). The size of its
         axis 1 is used as the resolution key into `encoder_outputs`.
         NOTE(review): presumably (batch, height, width, channels) — confirm.
       encoder_outputs: mapping from resolution to encoder activations.
       context_vectors: forwarded unchanged to the posterior `ResBlock`.

     Returns:
       A `(posterior_mean, posterior_log_std)` pair: the posterior block's
       output split into two equal parts along axis 3.

     Raises:
       KeyError: if `encoder_outputs` has no entry for the resolution.
     """
     chex.assert_rank(inputs, 4)
     res = inputs.shape[1]
     try:
         skip = encoder_outputs[res]
     except KeyError:
         available = list(encoder_outputs.keys())
         raise KeyError(  # pylint: disable=g-doc-exception
             'encoder_outputs does not contain the required '
             f'resolution ({res}). encoder_outputs resolutions '
             f'are {available}.')
     # The block emits 2 * latent_dim channels: mean and log-std halves.
     block = blocks.ResBlock(self.bottlenecked_num_channels,
                             self.latent_dim * 2,
                             use_residual_connection=False,
                             precision=self.precision,
                             name='posterior_block')
     merged = jnp.concatenate([inputs, skip], axis=3)
     mean, log_std = jnp.split(block(merged, context_vectors), 2, axis=3)
     return mean, log_std
예제 #3
0
 def _compute_outputs(self, inputs, features, latent):
     """Computes the outputs of the DecoderBlock.

     Applies a (1, 1)-kernel convolution from `blocks.get_vdvae_convolution`
     (built with `self.num_channels` and `self.weights_scale`) to `latent`,
     adds the result to `inputs + features`, and passes the sum through a
     residual `ResBlock` named 'final_residual_block'.

     Args:
       inputs: array added to the residual sum.
       features: array added to the residual sum (shape assumed compatible
         with `inputs` — TODO confirm against caller).
       latent: array fed to the latent projection.

     Returns:
       The output of the final residual block.
     """
     projection = blocks.get_vdvae_convolution(
         self.num_channels, (1, 1),
         self.weights_scale,
         name='latent_projection',
         precision=self.precision)
     summed = (inputs + features) + projection(latent)
     # Keep submodule construction after the projection call, as before.
     res_block = blocks.ResBlock(
         self.bottlenecked_num_channels,
         self.num_channels,
         use_residual_connection=True,
         last_weights_scale=self.weights_scale,
         precision=self.precision,
         name='final_residual_block')
     outputs = res_block(summed)
     return outputs
예제 #4
0
    def _compute_prior_and_features(
        self,
        inputs,
        context_vectors,
    ):
        """Computes the prior branch of the DecoderBlock.

        Runs a non-residual `ResBlock` whose channel argument is
        `2 * latent_dim + num_channels`. Its output is sliced along the last
        axis into three consecutive parts.

        Args:
          inputs: rank-4 array (checked with `chex.assert_rank`).
          context_vectors: forwarded unchanged to the prior `ResBlock`.

        Returns:
          A `(prior_mean, prior_log_std, features)` tuple:
            - channels `[0, latent_dim)` of the block output are the mean;
            - channels `[latent_dim, 2 * latent_dim)` are the log-std;
            - everything after `2 * latent_dim` is the features.
        """
        chex.assert_rank(inputs, 4)
        d = self.latent_dim
        # NOTE(review): last_weights_scale=0.0 presumably zero-initialises the
        # block's last layer; confirm in blocks.ResBlock.
        prior_block = blocks.ResBlock(self.bottlenecked_num_channels,
                                      d * 2 + self.num_channels,
                                      use_residual_connection=False,
                                      last_weights_scale=0.0,
                                      precision=self.precision,
                                      name='prior_block')
        out = prior_block(inputs, context_vectors)
        mean = out[..., :d]
        log_std = out[..., d:d * 2]
        feats = out[..., d * 2:]
        return mean, log_std, feats