Exemplo n.º 1
0
  def __init__(self, mixing_logits, component_stats, component_class,
               presence=None):
    """Builds the module.

    Args:
      mixing_logits: tensor [B, k, ...] with k the number of components.
      component_stats: list of tensors of shape [B, k, ...] or broadcastable
        to these shapes; they are argument to the chosen distribution class.
        The structure is flattened with `nest.flatten` and its leaves are
        passed positionally to `component_class`.
      component_class: callable; returns a distribution object.
      presence: [B, k] tensor of floats in [0, 1] or None. When given, its
        log (via `safe_log`) is added to `mixing_logits`, which down-weights
        components whose presence is close to 0. It is also stored for later
        use.
    """
    super(MixtureDistribution, self).__init__()
    if presence is not None:
      # Build a new value instead of using `+=`: for mutable array inputs
      # (e.g. numpy arrays) an augmented assignment would silently modify
      # the caller's array in place.
      mixing_logits = mixing_logits + make_brodcastable(
          safe_log(presence), mixing_logits)

    self._mixing_logits = mixing_logits

    component_stats = nest.flatten(component_stats)
    # A single distribution object represents all k components at once.
    self._distributions = component_class(*component_stats)
    self._presence = presence
Exemplo n.º 2
0
  def _maybe_mask(self, tensor):
    """Scales `tensor` by the stored presence values, if any.

    Args:
      tensor: value to mask; the stored presence is made broadcastable to it
        with `make_brodcastable`.

    Returns:
      `tensor` itself when no presence was supplied at construction time,
      otherwise `tensor * presence` with presence broadcast to `tensor`.
    """
    mask = self._presence
    if mask is not None:
      tensor = tensor * make_brodcastable(mask, tensor)
    return tensor
Exemplo n.º 3
0
    def _build(self,
               pose,
               presence=None,
               template_feature=None,
               bg_image=None,
               img_embedding=None):
        """Builds the module.

        Warps the templates with the given poses, appends a background
        component as the last mixture component, and builds a mixture
        distribution over the output pixels.

        Args:
          pose: [B, n_templates, 6] tensor.
          presence: [B, n_templates] tensor, or None. When None, no
            log-presence term is added to the mixing logits.
          template_feature: [B, n_templates, n_features] tensor; these
            features are used to change templates based on the input, if
            present.
          bg_image: [B, *output_size] tensor representing the background. If
            None, a learned constant background is used.
          img_embedding: [B, d] tensor containing image embeddings. Not used
            by this implementation.

        Returns:
          AttrDict with fields:
            raw_templates: `self._templates` with its leading axis squeezed.
            transformed_templates: [B, n_templates, *output_size, n_channels]
              warped templates (background component excluded).
            mixing_logits: per-template mixing logits (background excluded).
            pdf: `prob.MixtureDistribution` over all components, including
              the background.

        Raises:
          ValueError: if `self._output_pdf_type` is not 'mixture'.
        """
        batch_size, n_templates = pose.shape[:2].as_list()
        templates = self.make_templates(n_templates, template_feature)

        if templates.shape[0] == 1:
            # Templates are shared across the batch; replicate per example.
            templates = snt.TileByDim([0], [batch_size])(templates)

        # it's easier for me to think in inverse coordinates
        warper = snt.AffineGridWarper(self._output_size, self._template_size)
        warper = warper.inverse()

        grid_coords = snt.BatchApply(warper)(pose)
        resampler = snt.BatchApply(contrib_resampler.resampler)
        transformed_templates = resampler(templates, grid_coords)

        if bg_image is not None:
            bg_image = tf.expand_dims(bg_image, axis=1)
        else:
            # Learned constant background value squashed into (0, 1).
            bg_image = tf.nn.sigmoid(tf.get_variable('bg_value', shape=[1]))
            bg_image = tf.zeros_like(transformed_templates[:, :1]) + bg_image

        # The background becomes the last mixture component.
        transformed_templates = tf.concat([transformed_templates, bg_image],
                                          axis=1)

        if presence is not None:
            # The background component is always present.
            presence = tf.concat([presence, tf.ones([batch_size, 1])], axis=1)

        if self._use_alpha_channel:
            template_mixing_logits = snt.TileByDim([0], [batch_size])(
                self._templates_alpha)
            template_mixing_logits = resampler(template_mixing_logits,
                                               grid_coords)

            bg_mixing_logit = tf.nn.softplus(
                tf.get_variable('bg_mixing_logit', initializer=[0.]))
            bg_mixing_logit = (
                tf.zeros_like(template_mixing_logits[:, :1]) + bg_mixing_logit)

            template_mixing_logits = tf.concat(
                [template_mixing_logits, bg_mixing_logit], 1)
        else:
            # Pixel values act as logits, divided by a learned temperature
            # that is kept strictly positive (softplus + epsilon).
            temperature_logit = tf.get_variable('temperature_logit',
                                                shape=[1])
            temperature = tf.nn.softplus(temperature_logit + .5) + 1e-4
            template_mixing_logits = transformed_templates / temperature

        scale = 1.
        if self._learn_output_scale:
            scale = tf.get_variable('scale', shape=[1])
            scale = tf.nn.softplus(scale) + 1e-4

        if self._output_pdf_type != 'mixture':
            raise ValueError('Unknown pdf type: "{}".'.format(
                self._output_pdf_type))

        # `presence` defaults to None; taking its log unconditionally would
        # crash on the default call path.
        if presence is not None:
            template_mixing_logits = template_mixing_logits + make_brodcastable(
                math_ops.safe_log(presence), template_mixing_logits)

        rec_pdf = prob.MixtureDistribution(template_mixing_logits,
                                           [transformed_templates, scale],
                                           tfd.Normal)

        # Slice off the trailing background component for the outputs.
        return AttrDict(raw_templates=tf.squeeze(self._templates, 0),
                        transformed_templates=transformed_templates[:, :-1],
                        mixing_logits=template_mixing_logits[:, :-1],
                        pdf=rec_pdf)