コード例 #1
0
ファイル: scae.py プロジェクト: zzhaozeng/google-research
    def _build(self, h, x, presence=None):
        """Runs the capsule layer on `h` and scores the inputs `x` under it.

        Args:
          h: Tensor of encodings of shape [B, n_enc_dims].
          x: Tensor of inputs of shape [B, n_points, n_input_dims].
          presence: Tensor of shape [B, n_points, 1] or None; if it exists, it
            indicates which input points exist.

        Returns:
          The capsule layer's output object, modified in place:
            * `vote` is replaced by the votes with their last row dropped,
              reshaped to [B, n_caps, n_votes, 6];
            * every field of the `CapsuleLikelihood` result is merged in;
            * `caps_presence_prob` of shape [B, n_caps] is added: the maximum
              vote presence probability over each capsule's votes.
        """
        n_batch = int(x.shape[0])
        n_caps, n_votes = self._n_caps, self._n_votes

        capsule_layer = _capsule.CapsuleLayer(
            n_caps, self._n_caps_dims, n_votes, **self._capsule_kwargs)
        out = capsule_layer(h)

        # Drop the last row of every vote matrix (presumably the constant
        # bottom row of a 3x3 affine matrix -- confirm) and pack the remaining
        # entries into 6 numbers per vote.
        trimmed_votes = out.vote[Ellipsis, :-1, :]
        out.vote = tf.reshape(trimmed_votes, [n_batch, n_caps, n_votes, 6])

        # Hold on to the capsule layer's vote presence: the update below may
        # replace `out.vote_presence` with a likelihood field of the same name,
        # and the per-capsule presence must come from the layer's values.
        vote_presence = out.vote_presence
        likelihood = _capsule.CapsuleLikelihood(out.vote, out.scale,
                                                vote_presence)
        out.update(likelihood(x, presence)._asdict())

        # A capsule is as present as its most present vote.
        per_caps_presence = tf.reshape(vote_presence,
                                       [n_batch, n_caps, n_votes])
        out.caps_presence_prob = tf.reduce_max(per_caps_presence, 2)
        return out
コード例 #2
0
  def _build(self, h, x, presence=None):
    """Runs the capsule layer, scores the inputs and adds auxiliary losses.

    Args:
      h: Tensor of encodings of shape [B, n_enc_dims].
      x: Tensor of inputs of shape [B, n_points, n_input_dims].
      presence: Tensor or None; if it exists, it indicates which input points
        exist. Documented as shape [B, n_points, 1].
        NOTE(review): below it gets a trailing axis via `expand_dims(..., -1)`
        before multiplying the one-hot capsule assignments, which suggests a
        shape of [B, n_points] instead -- confirm against the callers.

    Returns:
      The capsule layer's output object, modified in place:
        * `transform`: the layer's raw `vote` output;
        * `vote`: `math_ops.apply_transform` applied to that raw output;
        * every non-scalar field passed through `snt.MergeDims(1, 2)`;
        * every field of the `OrderInvariantCapsuleLikelihood` result;
        * `mixing_kl`: scalar, batch mean of the last-axis sum of
          softmax(mixing_logits, 1) * (mixing_log_prob - log(1 / n_points));
        * `sparsity_loss`: scalar, sigmoid cross-entropy between
          `pres_logit_per_caps` and a target of 1 for capsules winning more
          than one point, counted only for capsules winning at least one
          point, summed over the last axis and averaged over the batch;
        * `caps_presence_prob`: [B, n_caps], max `vote_presence` per capsule;
        * `mean_scale`: scalar mean of `scale`.
    """
    n_batch, n_points, _ = x.shape.as_list()
    n_caps, n_votes = self._n_caps, self._n_votes

    capsule_layer = _capsule.CapsuleLayer(
        n_caps, self._n_caps_dims, n_votes, **self._capsule_kwargs)
    out = capsule_layer(h)

    # Keep the raw transformation parameters; expose the transformed votes.
    raw_votes = out.vote
    out.transform = raw_votes
    out.vote = math_ops.apply_transform(transform=raw_votes)

    # Merge axes 1 and 2 of every non-scalar field. A fresh MergeDims module
    # is built per field, exactly one per merged tensor.
    for key, value in list(out.items()):
      if value.shape.ndims <= 0:
        continue
      out[key] = snt.MergeDims(1, 2)(value)

    likelihood = _capsule.OrderInvariantCapsuleLikelihood(
        n_votes, out.vote, out.scale, out.vote_presence)
    ll = likelihood(x, presence)
    out.update(ll._asdict())

    # Mixing term measured against a uniform prior over the input points.
    mixing_weights = tf.nn.softmax(ll.mixing_logits, 1)
    log_uniform_prior = tf.log(1. / n_points)
    mixing_kl = tf.reduce_mean(tf.reduce_sum(
        mixing_weights * (ll.mixing_log_prob - log_uniform_prior), -1))

    # Count, per capsule, how many (present) points it was assigned.
    winners = tf.one_hot(ll.is_from_capsule, depth=n_caps)
    if presence is not None:
      winners = winners * tf.expand_dims(presence, -1)
    wins_per_caps = tf.reduce_sum(winners, 1)

    won_any = tf.to_float(tf.greater(wins_per_caps, 0))
    target_active = tf.to_float(tf.greater(wins_per_caps, 1))

    # Only capsules that won at least one point contribute to the loss.
    per_caps_xent = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=target_active, logits=out.pres_logit_per_caps)
    sparsity_loss = tf.reduce_mean(tf.reduce_sum(per_caps_xent * won_any, -1))

    caps_presence_prob = tf.reduce_max(
        tf.reshape(out.vote_presence, [n_batch, n_caps, n_votes]), 2)

    out.update({
        'mixing_kl': mixing_kl,
        'sparsity_loss': sparsity_loss,
        'caps_presence_prob': caps_presence_prob,
        'mean_scale': tf.reduce_mean(out.scale),
    })
    return out