def setup(self):
  """Builds the input convolution and the ordered list of residual blocks.

  `self.downsampling_rates` is sorted and iterated as `(block_idx, rate)`
  pairs. Each pair with `rate != 1` places a downsampling `ResBlock` at
  position `block_idx`; every other position below `self.num_blocks` is
  filled with a `ResBlock` whose `downsampling_rate` is 1. Blocks are named
  `res_block_<position>` and stored, in order, in `self._blocks`.
  """
  self._in_conv = blocks_lib.get_vdvae_convolution(
      self.num_channels, (3, 3), name='in_conv', precision=self.precision)

  def make_res_block(position, rate):
    # All blocks share one configuration; only the position (which ends up
    # in the name) and the downsampling rate differ between them.
    return blocks_lib.ResBlock(
        self.bottlenecked_num_channels,
        self.num_channels,
        downsampling_rate=rate,
        use_residual_connection=True,
        last_weights_scale=np.sqrt(1.0 / self.num_blocks),
        precision=self.precision,
        name=f'res_block_{position}')

  ordered_rates = sorted(self.downsampling_rates)
  total_blocks = self.num_blocks

  res_blocks = []
  next_position = 0
  for downsample_position, rate in ordered_rates:
    if rate == 1:
      # A rate of 1 is not a downsampling step; that position is covered by
      # the plain blocks emitted for the surrounding gap.
      continue

    # Plain blocks for the gap preceding this downsampling position. The
    # range is empty when there is no gap.
    res_blocks.extend(
        [make_res_block(i, 1)
         for i in range(next_position, downsample_position)])

    # The downsampling block itself.
    res_blocks.append(make_res_block(downsample_position, rate))
    next_position = downsample_position + 1

  # Plain blocks after the last downsampling position.
  res_blocks.extend(
      [make_res_block(i, 1) for i in range(next_position, total_blocks)])

  self._blocks = res_blocks
def _compute_posterior(
    self,
    inputs,
    encoder_outputs,
    context_vectors,
):
  """Computes the posterior branch of the DecoderBlock.

  Args:
    inputs: Array checked to be rank 4 via `chex.assert_rank`. Its size along
      axis 1 is used as the resolution key into `encoder_outputs`.
    encoder_outputs: Mapping indexed by resolution. The entry for the
      resolution of `inputs` is concatenated with `inputs` along axis 3.
    context_vectors: Passed through unchanged as the second argument of the
      posterior `ResBlock` call.

  Returns:
    A `(posterior_mean, posterior_log_std)` tuple: the two equal halves of
    the posterior block's output, split along axis 3.

  Raises:
    KeyError: If `encoder_outputs` has no entry for the resolution of
      `inputs`. The original lookup error is attached as the cause.
  """
  chex.assert_rank(inputs, 4)
  resolution = inputs.shape[1]
  try:
    encoded_image = encoder_outputs[resolution]
  except KeyError as err:
    # Re-raise with a message listing the available resolutions, chaining the
    # original lookup failure explicitly so the traceback shows it as the
    # direct cause rather than an error raised while handling another one.
    raise KeyError(
        'encoder_outputs does not contain the required '
        f'resolution ({resolution}). encoder_outputs resolutions '
        f'are {list(encoder_outputs.keys())}.') from err

  # The block outputs twice `latent_dim` channels so they can be split into
  # a mean half and a log-std half below.
  posterior_block = blocks.ResBlock(self.bottlenecked_num_channels,
                                    self.latent_dim * 2,
                                    use_residual_connection=False,
                                    precision=self.precision,
                                    name='posterior_block')
  concatenated_inputs = jnp.concatenate([inputs, encoded_image], axis=3)
  posterior_output = posterior_block(concatenated_inputs, context_vectors)
  posterior_mean, posterior_log_std = jnp.split(posterior_output, 2, axis=3)
  return posterior_mean, posterior_log_std
def _compute_outputs(self, inputs, features, latent):
  """Combines inputs, features and the projected latent into block outputs.

  Args:
    inputs: Added element-wise to `features` and to the projected latent.
    features: Added element-wise to `inputs`.
    latent: Passed through a (1, 1) convolution producing
      `self.num_channels` channels before being added to the sum.

  Returns:
    The result of applying the `final_residual_block` ResBlock to
    `inputs + features + projected_latent`.
  """
  project_latent = blocks.get_vdvae_convolution(
      self.num_channels,
      (1, 1),
      self.weights_scale,
      name='latent_projection',
      precision=self.precision)

  # Same left-to-right association as `inputs + features + projection`.
  merged = inputs + features
  merged = merged + project_latent(latent)

  output_block = blocks.ResBlock(
      self.bottlenecked_num_channels,
      self.num_channels,
      use_residual_connection=True,
      last_weights_scale=self.weights_scale,
      precision=self.precision,
      name='final_residual_block')
  return output_block(merged)
def _compute_prior_and_features(
    self,
    inputs,
    context_vectors,
):
  """Runs the prior ResBlock and separates its output along the last axis.

  Args:
    inputs: Array checked to be rank 4 via `chex.assert_rank`; first argument
      of the prior block call.
    context_vectors: Passed through unchanged as the second argument of the
      prior block call.

  Returns:
    A `(prior_mean, prior_log_std, features)` tuple sliced from the prior
    block's output along its last axis: the first `latent_dim` entries, the
    next `latent_dim` entries, and everything after them.
  """
  chex.assert_rank(inputs, 4)

  latent_dim = self.latent_dim
  stats_width = latent_dim * 2

  # The block emits the mean and log-std channels followed by
  # `num_channels` feature channels in a single output.
  # NOTE(review): `last_weights_scale=0.0` presumably zeroes the block's
  # final-layer initialisation -- confirm against ResBlock.
  prior_net = blocks.ResBlock(
      self.bottlenecked_num_channels,
      stats_width + self.num_channels,
      use_residual_connection=False,
      last_weights_scale=0.0,
      precision=self.precision,
      name='prior_block')
  block_out = prior_net(inputs, context_vectors)

  mean = block_out[..., :latent_dim]
  log_std = block_out[..., latent_dim:stats_width]
  extra_features = block_out[..., stats_width:]
  return mean, log_std, extra_features