Example #1
    def _build(self, encoder_features, decoder_features):
        """Build-method that returns the segmentation logits.

        Args:
          encoder_features: A list of tensors of shape (b,h_i,w_i,c_i).
          decoder_features: A tensor of shape (b,h,w,c).

        Returns:
          Logits, i.e. a tensor of shape (b,h,w,num_classes).
        """
        num_latents = len(self._latent_dims)
        start_level = num_latents + 1
        num_levels = len(self._channels_per_block)

        # Iterate the ascending levels of the U-Net decoder that lie above the
        # latent scales, i.e. the levels not covered by the latent-conditioned
        # decoder.
        for level in range(start_level, num_levels):
            decoder_features = unet_utils.resize_up(decoder_features, scale=2)
            decoder_features = tf.concat(
                [decoder_features, encoder_features[::-1][level]], axis=-1)
            for _ in range(self._blocks_per_level):
                decoder_features = unet_utils.res_block(
                    input_features=decoder_features,
                    n_channels=self._channels_per_block[::-1][level],
                    n_down_channels=self._down_channels_per_block[::-1][level],
                    activation_fn=self._activation_fn,
                    initializers=self._initializers,
                    regularizers=self._regularizers,
                    convs_per_block=self._convs_per_block)

        # A final 1x1 convolution maps the decoder features to per-pixel
        # class logits.
        return snt.Conv2D(output_channels=self._num_classes,
                          kernel_shape=(1, 1),
                          padding='SAME',
                          initializers=self._initializers,
                          regularizers=self._regularizers,
                          name='logits')(decoder_features)
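
The level indexing above can be sanity-checked with a short pure-Python sketch; the latent dimensions and channel counts below are assumed example values, not taken from the snippet. It shows that this stitching decoder resumes the upsampling path at the first resolution level above the latent scales (the levels handled by the hierarchical core shown in Example #2).

# Minimal sketch of the level indexing, under an assumed example configuration.
latent_dims = (3, 2, 1)                      # assumed: one latent scale per entry
channels_per_block = (24, 48, 96, 192, 192)  # assumed channel widths, shallow to deep

num_latents = len(latent_dims)        # 3 latent scales
start_level = num_latents + 1         # 4: first level handled by this decoder
num_levels = len(channels_per_block)  # 5 resolution levels in total
print(list(range(start_level, num_levels)))  # [4] -> one remaining upsampling level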
Example #2
    def _build(self, inputs, mean=False, z_q=None):
        """A build-method allowing to sample from the module as specified.

    Args:
      inputs: A tensor of shape (b,h,w,c). When using the module as a prior the
      `inputs` tensor should be a batch of images. When using it as a posterior
      the tensor should be a (batched) concatentation of images and
      segmentations.
      mean: A boolean or a list of booleans. If a boolean, it specifies whether
        or not to use the distributions' means in ALL latent scales. If a list,
        each bool therein specifies whether or not to use the scale's mean. If
        False, the latents of the scale are sampled.
      z_q: None or a list of tensors. If not None, z_q provides external latents
        to be used instead of sampling them. This is used to employ posterior
        latents in the prior during training. Therefore, if z_q is not None, the
        value of `mean` is ignored. If z_q is None, either the distributions
        mean is used (in case `mean` for the respective scale is True) or else
        a sample from the distribution is drawn.
    Returns:
      A Dictionary holding the output feature map of the truncated U-Net
      decoder under key 'decoder_features', a list of the U-Net encoder features
      produced at the end of each encoder scale under key 'encoder_outputs', a
      list of the predicted distributions at each scale under key
      'distributions', a list of the used latents at each scale under the key
      'used_latents'.
    """
        encoder_features = inputs
        encoder_outputs = []
        num_levels = len(self._channels_per_block)
        num_latent_levels = len(self._latent_dims)
        if isinstance(mean, bool):
            mean = [mean] * num_latent_levels
        distributions = []
        used_latents = []

        # Iterate the descending levels in the U-Net encoder.
        for level in range(num_levels):
            # Iterate the residual blocks in each level.
            for _ in range(self._blocks_per_level):
                encoder_features = unet_utils.res_block(
                    input_features=encoder_features,
                    n_channels=self._channels_per_block[level],
                    n_down_channels=self._down_channels_per_block[level],
                    activation_fn=self._activation_fn,
                    initializers=self._initializers,
                    regularizers=self._regularizers,
                    convs_per_block=self._convs_per_block)

            encoder_outputs.append(encoder_features)
            if level != num_levels - 1:
                encoder_features = unet_utils.resize_down(encoder_features,
                                                          scale=2)

        # Iterate the ascending levels in the (truncated) U-Net decoder.
        decoder_features = encoder_outputs[-1]
        for level in range(num_latent_levels):

            # Predict a Gaussian distribution for each pixel in the feature map.
            latent_dim = self._latent_dims[level]
            mu_logsigma = snt.Conv2D(
                2 * latent_dim,
                (1, 1),
                padding='SAME',
                initializers=self._initializers,
                regularizers=self._regularizers,
            )(decoder_features)

            mu = mu_logsigma[..., :latent_dim]
            logsigma = mu_logsigma[..., latent_dim:]
            dist = tfd.MultivariateNormalDiag(loc=mu,
                                              scale_diag=tf.exp(logsigma))
            distributions.append(dist)

            # Get the latents to condition on.
            if z_q is not None:
                z = z_q[level]
            elif mean[level]:
                z = dist.loc
            else:
                z = dist.sample()
            used_latents.append(z)

            # Concatenate the latents with the previous decoder features,
            # upsample, and merge with the corresponding encoder skip features.
            decoder_output_lo = tf.concat([z, decoder_features], axis=-1)
            decoder_output_hi = unet_utils.resize_up(decoder_output_lo,
                                                     scale=2)
            decoder_features = tf.concat(
                [decoder_output_hi, encoder_outputs[::-1][level + 1]], axis=-1)

            # Iterate the residual blocks in each level.
            for _ in range(self._blocks_per_level):
                decoder_features = unet_utils.res_block(
                    input_features=decoder_features,
                    n_channels=self._channels_per_block[::-1][level + 1],
                    n_down_channels=(
                        self._down_channels_per_block[::-1][level + 1]),
                    activation_fn=self._activation_fn,
                    initializers=self._initializers,
                    regularizers=self._regularizers,
                    convs_per_block=self._convs_per_block)

        return {
            'decoder_features': decoder_features,
            'encoder_features': encoder_outputs,
            'distributions': distributions,
            'used_latents': used_latents
        }
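
As a standalone illustration of the per-pixel latent distribution built in the decoder loop above, the following sketch uses assumed shapes and a zero tensor as a stand-in for the 1x1 convolution output `mu_logsigma`; it also assumes `tfd` is `tensorflow_probability.distributions`, as the alias in the snippet suggests. It shows that `tfd.MultivariateNormalDiag` over the channel axis yields one diagonal Gaussian of dimension `latent_dim` per spatial location.

# Minimal sketch with assumed shapes; the zero tensor stands in for the
# 1x1 convolution output of Example #2.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

latent_dim = 3
mu_logsigma = tf.zeros([1, 8, 8, 2 * latent_dim])
mu = mu_logsigma[..., :latent_dim]
logsigma = mu_logsigma[..., latent_dim:]
dist = tfd.MultivariateNormalDiag(loc=mu, scale_diag=tf.exp(logsigma))

print(dist.batch_shape)     # (1, 8, 8): one Gaussian per spatial location
print(dist.event_shape)     # (3,): each latent has `latent_dim` components
print(dist.sample().shape)  # (1, 8, 8, 3)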