def _build(self, h, x, presence=None):
    """Builds the capsule layer and scores the inputs under its votes.

    Args:
      h: Tensor of encodings of shape [B, n_enc_dims].
      x: Tensor of inputs of shape [B, n_points, n_input_dims].
      presence: Tensor of shape [B, n_points, 1] or None; if given, it marks
        which input points exist.

    Returns:
      The capsule layer output with `vote` reshaped to
      [B, n_caps, n_votes, 6], updated with the fields of the likelihood
      result, and with an added `caps_presence_prob` of shape [B, n_caps].
    """
    n_batch = int(x.shape[0])

    capsule_layer = _capsule.CapsuleLayer(
        self._n_caps, self._n_caps_dims, self._n_votes,
        **self._capsule_kwargs)
    out = capsule_layer(h)

    # Drop the last row of every vote matrix and flatten the remainder into
    # 6 numbers per vote. NOTE(review): presumably the votes are 3x3 affine
    # matrices whose last row carries no information -- confirm against
    # CapsuleLayer.
    trimmed_votes = out.vote[..., :-1, :]
    out.vote = tf.reshape(
        trimmed_votes, [n_batch, self._n_caps, self._n_votes, 6])

    # Captured before `update` below, so later keys cannot shadow it.
    presence_prob = out.vote_presence
    scorer = _capsule.CapsuleLikelihood(out.vote, out.scale, presence_prob)
    scored = scorer(x, presence)
    out.update(scored._asdict())

    # A capsule is as present as its most present vote.
    per_caps_presence = tf.reshape(
        presence_prob, [n_batch, self._n_caps, self._n_votes])
    out.caps_presence_prob = tf.reduce_max(per_caps_presence, 2)
    return out
def _build(self, h, x, presence=None):
    """Builds the module.

    Runs a capsule layer on `h`, merges axes 1 and 2 of its non-scalar
    outputs, scores `x` under an order-invariant capsule likelihood and adds
    auxiliary terms (`mixing_kl`, `sparsity_loss`, `caps_presence_prob`,
    `mean_scale`) to the result.

    Args:
      h: Tensor of encodings of shape [B, n_enc_dims].
      x: Tensor of inputs of shape [B, n_points, n_input_dims]. Its static
        shape must be fully known: `as_list()` returns None for unknown
        dimensions, and `1. / n_input_points` below would then fail.
      presence: Tensor of shape [B, n_points, 1] or None; if it exists, it
        indicates which input points exist.
        NOTE(review): `tf.expand_dims(presence, -1)` below broadcasts
        correctly against the [B, n_points, n_caps] one-hot wins only if
        presence is [B, n_points]; confirm which shape callers pass.

    Returns:
      The capsule layer output, updated as follows:
      - `transform` holds the raw votes.
      - `vote` holds `math_ops.apply_transform` of the raw votes.
      - Every entry with rank > 0 has axes 1 and 2 merged.
      - It is updated with the fields of the likelihood result.
      - It gains `mixing_kl`, `sparsity_loss`, `caps_presence_prob` and
        `mean_scale`.
    """
    batch_size, n_input_points, _ = x.shape.as_list()
    capsule = _capsule.CapsuleLayer(self._n_caps, self._n_caps_dims,
                                    self._n_votes, **self._capsule_kwargs)

    res = capsule(h)
    # Keep the raw votes under `transform`; `vote` becomes the transformed one.
    res.transform = res.vote
    res.vote = math_ops.apply_transform(transform=res.vote)

    # Merge axes 1 and 2 of every non-scalar output (presumably
    # [B, n_caps, n_votes, ...] -> [B, n_caps * n_votes, ...]; confirm).
    # Only existing keys are reassigned, so the dict does not change size
    # while it is being iterated.
    # NOTE(review): an entry of unknown rank has `ndims` None, and
    # `None > 0` raises TypeError in Python 3.
    for k, v in res.items():
        if v.shape.ndims > 0:
            res[k] = snt.MergeDims(1, 2)(v)

    likelihood = _capsule.OrderInvariantCapsuleLikelihood(
        self._n_votes, res.vote, res.scale, res.vote_presence)
    ll_res = likelihood(x, presence)
    res.update(ll_res._asdict())

    # post processing
    # KL-style penalty: sum over the last axis of
    # p * (log p - log(1 / n_input_points)), averaged over the rest, where
    # p is the softmax of the mixing logits over axis 1.
    mixing_probs = tf.nn.softmax(ll_res.mixing_logits, 1)
    prior_mixing_log_prob = tf.log(1. / n_input_points)
    mixing_kl = mixing_probs * (ll_res.mixing_log_prob - prior_mixing_log_prob)
    mixing_kl = tf.reduce_mean(tf.reduce_sum(mixing_kl, -1))

    # One-hot over capsules of the index each point is attributed to
    # (presumably; see `is_from_capsule`), masked by presence when given,
    # then summed over axis 1 to get a per-capsule win count.
    wins_per_caps = tf.one_hot(ll_res.is_from_capsule, depth=self._n_caps)

    if presence is not None:
        wins_per_caps *= tf.expand_dims(presence, -1)

    wins_per_caps = tf.reduce_sum(wins_per_caps, 1)

    # Label a capsule "active" when it wins more than one point. The
    # cross-entropy is counted only for capsules with at least one win.
    has_any_wins = tf.to_float(tf.greater(wins_per_caps, 0))
    should_be_active = tf.to_float(tf.greater(wins_per_caps, 1))
    sparsity_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=should_be_active, logits=res.pres_logit_per_caps)

    sparsity_loss = tf.reduce_sum(sparsity_loss * has_any_wins, -1)
    sparsity_loss = tf.reduce_mean(sparsity_loss)

    # Undo the axis merge for vote presence and take the max over votes.
    caps_presence_prob = tf.reduce_max(
        tf.reshape(res.vote_presence,
                   [batch_size, self._n_caps, self._n_votes]), 2)

    res.update(dict(
        mixing_kl=mixing_kl,
        sparsity_loss=sparsity_loss,
        caps_presence_prob=caps_presence_prob,
        mean_scale=tf.reduce_mean(res.scale)
    ))
    return res