Example #1
    def compute_representation_loss(self, pose_embeddings, view_embeddings,
                                    subencoder_embeddings,
                                    positive_indicator_matrix):
        """Computes the representation loss.

    Args:
      pose_embeddings: A tensor for the pose embeddings. Shape = [batch_size,
        1, pose_embedding_dim].
      view_embeddings: A tensor for the view embeddings. Shape = [batch_size,
        1, view_embedding_dim].
      subencoder_embeddings: A tensor for the embedding of the subencoder.
        Shape = [batch_size, 1, embedding_dim].
      positive_indicator_matrix: A tensor for positive indicator matrix. The
        positive correspondences will have value 1.0 and otherwise 0.0. Shape =
        [batch_size, batch_size].

    Returns:
      representation_loss: A scalar for the representation loss.
    """
        if self._fusion_op_type == TYPE_FUSION_OP_CAT:
            # Concatenation fusion.
            fusion_embeddings = tf.concat([pose_embeddings, view_embeddings],
                                          axis=-1)
        elif self._fusion_op_type == TYPE_FUSION_OP_POE:
            # Product-of-experts fusion: element-wise product.
            fusion_embeddings = pose_embeddings * view_embeddings
        elif self._fusion_op_type == TYPE_FUSION_OP_MOE:
            # Mixture-of-experts fusion: element-wise average.
            fusion_embeddings = 0.5 * (pose_embeddings + view_embeddings)
        else:
            raise ValueError(
                'Unknown fusion operation: {}'.format(self._fusion_op_type))

        representation_loss = losses.compute_fenchel_dual_loss(
            subencoder_embeddings, fusion_embeddings, losses.TYPE_MEASURE_JSD,
            positive_indicator_matrix)
        return representation_loss
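
For reference, losses.compute_fenchel_dual_loss itself is not among these examples. The sketch below shows one way such a Fenchel-dual mutual information loss can be written, following the JSD/W1 measure functions of Deep InfoMax (Hjelm et al., 2019) and assuming a dot-product critic averaged over point pairs; the constant values and the function body are illustrative assumptions, not the library implementation.

import math

import tensorflow as tf

TYPE_MEASURE_JSD = 'jsd'  # Assumed values; the actual constants may differ.
TYPE_MEASURE_W1 = 'w1'


def compute_fenchel_dual_loss_sketch(x, y, measure,
                                     positive_indicator_matrix=None):
  """Fenchel-dual MI loss sketch. x, y: [batch_size, num_points, dim]."""
  batch_size = tf.shape(x)[0]
  # Critic score for every sample pair (i, j): the mean inner product over
  # all point pairs. This choice reproduces the constants in the W1 test in
  # Example #2 below.
  scores = tf.reduce_mean(tf.einsum('ipd,jqd->ijpq', x, y), axis=[2, 3])
  if positive_indicator_matrix is None:
    positive_indicator_matrix = tf.eye(batch_size, dtype=x.dtype)
  negative_indicator_matrix = 1.0 - positive_indicator_matrix

  if measure == TYPE_MEASURE_JSD:
    # Deep InfoMax JSD measure functions.
    positive_expectation = math.log(2.0) - tf.math.softplus(-scores)
    negative_expectation = tf.math.softplus(scores) - math.log(2.0)
  elif measure == TYPE_MEASURE_W1:
    positive_expectation = scores
    negative_expectation = scores
  else:
    raise ValueError('Unknown measure: {}'.format(measure))

  positive_mean = tf.math.divide_no_nan(
      tf.reduce_sum(positive_expectation * positive_indicator_matrix),
      tf.reduce_sum(positive_indicator_matrix))
  negative_mean = tf.math.divide_no_nan(
      tf.reduce_sum(negative_expectation * negative_indicator_matrix),
      tf.reduce_sum(negative_indicator_matrix))
  # Maximizing the dual lower bound is minimizing negatives minus positives.
  return negative_mean - positive_mean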
Example #2
  def test_compute_fenchel_dual_loss(self):
    input_features = tf.ones([4, 3, 5], tf.float32)
    loss = losses.compute_fenchel_dual_loss(
        input_features, input_features, losses.TYPE_MEASURE_W1)
    self.assertAllEqual(loss, 0.)

    input_features = tf.constant([[[1., 1.], [0., 1.]], [[1., 0.], [1., 1.]],
                                  [[1., 0.], [1., 1.]], [[1., 1.], [1., 1.]]])
    loss = losses.compute_fenchel_dual_loss(input_features, input_features,
                                            losses.TYPE_MEASURE_W1)
    self.assertAllClose(loss, 15.5 / 12 - 5.75 / 4)

    loss = losses.compute_fenchel_dual_loss(
        input_features, input_features, losses.TYPE_MEASURE_W1,
        tf.eye(4, dtype=tf.dtypes.float32))
    self.assertAllClose(loss, 15.5 / 12 - 5.75 / 4)

    loss = losses.compute_fenchel_dual_loss(
        input_features, input_features, losses.TYPE_MEASURE_W1,
        tf.ones((4, 4), dtype=tf.dtypes.float32))
    self.assertAllClose(loss, -(15.5 + 5.75) / 16)
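
The expected values above follow directly from the mean-inner-product scoring rule: the four diagonal (positive) pair scores sum to 5.75, the twelve off-diagonal (negative) pair scores sum to 15.5, and the W1 loss is the mean negative score minus the mean positive score. A quick NumPy check under that assumed scoring rule:

import numpy as np

x = np.array([[[1., 1.], [0., 1.]], [[1., 0.], [1., 1.]],
              [[1., 0.], [1., 1.]], [[1., 1.], [1., 1.]]])
# scores[i, j] = mean of the 2x2 point-pair inner products of samples i, j.
scores = np.einsum('ipd,jqd->ijpq', x, x).mean(axis=(2, 3))
positive_sum = np.trace(scores)              # 5.75
negative_sum = scores.sum() - positive_sum   # 15.5
print(negative_sum / 12 - positive_sum / 4)  # Matches 15.5 / 12 - 5.75 / 4.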
Example #3
    def compute_representation_loss(self, inputs, positive_indicator_matrix):
        """Computes the representation loss.

    Args:
      inputs: An input tensor. Shape = [batch_size, num_points, ...].
      positive_indicator_matrix: A tensor for positive indicator matrix. The
        positive correspondences will have value 1.0 and otherwise 0.0. Shape =
        [batch_size, batch_size].

    Returns:
      output_embeddings: A tensor for the embedding. Shape = [batch_size, 1,
        embedding_dim].
      representation_loss: A scalar for the representation loss.
      regularization_loss: A scalar for the regularization loss.
    """
        output_embeddings, subencoder_output_embeddings = self(inputs,
                                                               training=True)

        if self.fusion_op_type == TYPE_FUSION_OP_CAT:
            # The encoder output is already the concatenation of the pose and
            # view embeddings, which is exactly the CAT fusion.
            fusion_embeddings = output_embeddings
        else:
            # Split the joint embedding back into its pose and view parts
            # before applying an element-wise fusion.
            pose_embeddings, view_embeddings = tf.split(
                output_embeddings,
                num_or_size_splits=[
                    self.pose_embedding_dim, self.view_embedding_dim
                ],
                axis=-1)
            if self.fusion_op_type == TYPE_FUSION_OP_POE:
                fusion_embeddings = pose_embeddings * view_embeddings
            elif self.fusion_op_type == TYPE_FUSION_OP_MOE:
                fusion_embeddings = 0.5 * (pose_embeddings + view_embeddings)
            else:
                raise ValueError('Unknown fusion operation: {}'.format(
                    self.fusion_op_type))

        representation_loss = losses.compute_fenchel_dual_loss(
            subencoder_output_embeddings, fusion_embeddings,
            losses.TYPE_MEASURE_JSD, positive_indicator_matrix)
        regularization_loss = sum(self.encoder.losses)
        return output_embeddings, representation_loss, regularization_loss
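
The TYPE_FUSION_OP_* constants used in Examples #1 and #3 are defined elsewhere in the module; one plausible definition, with the string values assumed:

TYPE_FUSION_OP_CAT = 'cat'  # Concatenate pose and view embeddings.
TYPE_FUSION_OP_POE = 'poe'  # Product of experts: element-wise product.
TYPE_FUSION_OP_MOE = 'moe'  # Mixture of experts: element-wise average.

Because the POE and MOE paths combine the two embeddings element-wise, both implicitly require pose_embedding_dim == view_embedding_dim; only CAT allows the two dimensions to differ.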
Example #4
    def compute_representation_loss(self, inputs, positive_indicator_matrix):
        """Computes the representation loss.

    Args:
      inputs: An input tensor. Shape = [batch_size, num_points, ...].
      positive_indicator_matrix: A tensor for positive indicator matrix. The
        positive correspondences will have value 1.0 and otherwise 0.0. Shape =
        [batch_size, batch_size].

    Returns:
      output_embeddings: A tensor for the embedding. Shape = [batch_size, 1,
        embedding_dim].
      representation_loss: A scalar for the representation loss.
      regularization_loss: A scalar for the regularization loss.
    """
        output_embeddings, subencoder_output_embeddings = self(inputs,
                                                               training=True)
        representation_loss = losses.compute_fenchel_dual_loss(
            subencoder_output_embeddings, output_embeddings,
            losses.TYPE_MEASURE_JSD, positive_indicator_matrix)
        regularization_loss = sum(self.encoder.losses)
        return output_embeddings, representation_loss, regularization_loss
Example #5
    def train(self, inputs, encoder_optimizer, estimator_optimizer,
              discriminator_optimizer):
        """Trains the model for one step.

    Args:
      inputs: A list of input tensors containing 2D and 3D keypoints. Shape = [
        batch_size, num_instances, num_joints, {2|3}]
      encoder_optimizer: An optimizer object for ecnoder.
      estimator_optimizer: An optimizer object for estimator.
      discriminator_optimizer: An optimizer object for discriminator.

    Returns:
      A dictionary for all losses.
    """
        keypoints_2d, keypoints_3d = inputs
        anchor_keypoints_2d, positive_keypoints_2d = tf.split(
            keypoints_2d, num_or_size_splits=[1, 1], axis=1)
        anchor_keypoints_3d, positive_keypoints_3d = tf.split(
            keypoints_3d, num_or_size_splits=[1, 1], axis=1)

        anchor_keypoints_2d = tf.squeeze(anchor_keypoints_2d, axis=1)
        positive_keypoints_2d = tf.squeeze(positive_keypoints_2d, axis=1)
        anchor_keypoints_3d = tf.squeeze(anchor_keypoints_3d, axis=1)
        positive_keypoints_3d = tf.squeeze(positive_keypoints_3d, axis=1)

        if MAX_POSITIVE_KEYPOINT_MPJPE_2D is None:
            anchor_indicator_matrix = None
            positive_indicator_matrix = None
        else:
            anchor_indicator_matrix = compute_positive_indicator_matrix(
                anchor_keypoints_2d,
                anchor_keypoints_2d,
                distance_fn=keypoint_utils.compute_mpjpes,
                max_positive_distance=MAX_POSITIVE_KEYPOINT_MPJPE_2D)
            positive_indicator_matrix = compute_positive_indicator_matrix(
                positive_keypoints_2d,
                positive_keypoints_2d,
                distance_fn=keypoint_utils.compute_mpjpes,
                max_positive_distance=MAX_POSITIVE_KEYPOINT_MPJPE_2D)

        if MAX_POSITIVE_KEYPOINT_MPJPE_3D is None:
            view_indicator_matrix = None
        else:
            view_indicator_matrix = compute_positive_indicator_matrix(
                anchor_keypoints_3d,
                positive_keypoints_3d,
                distance_fn=keypoint_utils.compute_procrustes_aligned_mpjpes,
                max_positive_distance=MAX_POSITIVE_KEYPOINT_MPJPE_3D)

        def compute_estimator_loss(estimator, x, y, positive_indicator_matrix):
            # The estimator predicts a Gaussian over y from x; return both its
            # log-likelihood and the contrastive log-ratio bound.
            x_mean, x_logvar = estimator(x, training=True)
            likelihood = losses.compute_log_likelihood(x_mean, x_logvar, y)
            bound = losses.compute_contrastive_log_ratio(
                x_mean, x_logvar, y, positive_indicator_matrix)
            return likelihood, bound

        with tf.GradientTape() as encoder_tape, \
                tf.GradientTape() as estimator_tape, \
                tf.GradientTape() as discriminator_tape:
            (anchor_embeddings, anchor_representation_loss,
             anchor_regularization_loss) = self.compute_representation_loss(
                 anchor_keypoints_2d, anchor_indicator_matrix)
            (positive_embeddings, positive_representation_loss,
             positive_regularization_loss) = self.compute_representation_loss(
                 positive_keypoints_2d, positive_indicator_matrix)

            representation_loss = (anchor_representation_loss +
                                   positive_representation_loss)
            regularization_loss = self.regularization_loss_weight * (
                anchor_regularization_loss + positive_regularization_loss)
            encoder_total_loss = representation_loss + regularization_loss

            anchor_pose_embeddings, anchor_view_embeddings = tf.split(
                anchor_embeddings,
                num_or_size_splits=[
                    self.pose_embedding_dim, self.view_embedding_dim
                ],
                axis=-1)
            positive_pose_embeddings, positive_view_embeddings = tf.split(
                positive_embeddings,
                num_or_size_splits=[
                    self.pose_embedding_dim, self.view_embedding_dim
                ],
                axis=-1)

            view_loss = self.view_loss_weight * losses.compute_fenchel_dual_loss(
                anchor_pose_embeddings, positive_pose_embeddings,
                losses.TYPE_MEASURE_JSD, view_indicator_matrix) * 2.0
            encoder_total_loss += view_loss

            anchor_view_embeddings = tf.squeeze(anchor_view_embeddings, axis=1)
            anchor_pose_embeddings = tf.squeeze(anchor_pose_embeddings, axis=1)
            positive_view_embeddings = tf.squeeze(positive_view_embeddings,
                                                  axis=1)
            positive_pose_embeddings = tf.squeeze(positive_pose_embeddings,
                                                  axis=1)

            (anchor_prior_loss,
             anchor_discriminator_loss) = self.compute_uniform_prior_loss(
                 anchor_pose_embeddings, anchor_view_embeddings)
            (positive_prior_loss,
             positive_discriminator_loss) = self.compute_uniform_prior_loss(
                 positive_pose_embeddings, positive_view_embeddings)
            prior_loss = anchor_prior_loss + positive_prior_loss
            encoder_total_loss += prior_loss
            discriminator_total_loss = (anchor_discriminator_loss +
                                        positive_discriminator_loss)

            inter_likelihood, inter_bound = compute_estimator_loss(
                self.inter_likelihood_estimator, anchor_view_embeddings,
                positive_view_embeddings, view_indicator_matrix)
            anchor_intra_likelihood, anchor_intra_bound = compute_estimator_loss(
                self.intra_likelihood_estimator, anchor_pose_embeddings,
                anchor_view_embeddings, anchor_indicator_matrix)
            positive_intra_likelihood, positive_intra_bound = compute_estimator_loss(
                self.intra_likelihood_estimator, positive_pose_embeddings,
                positive_view_embeddings, positive_indicator_matrix)

            intra_bound_loss = self.disentangle_loss_weight * (
                anchor_intra_bound + positive_intra_bound)
            inter_bound_loss = self.disentangle_loss_weight * inter_bound
            disentangle_loss = intra_bound_loss + inter_bound_loss * 2.0
            encoder_total_loss += disentangle_loss

            anchor_intra_likelihood_loss = -anchor_intra_likelihood
            positive_intra_likelihood_loss = -positive_intra_likelihood
            inter_likelihood_loss = -inter_likelihood
            estimator_total_loss = (anchor_intra_likelihood_loss +
                                    positive_intra_likelihood_loss +
                                    inter_likelihood_loss * 2.0)

        encoder_trainable_variables = (self.encoder.trainable_variables +
                                       self.subencoder.trainable_variables)
        encoder_grads = encoder_tape.gradient(encoder_total_loss,
                                              encoder_trainable_variables)
        estimator_trainable_variables = (
            self.intra_likelihood_estimator.trainable_variables +
            self.inter_likelihood_estimator.trainable_variables)
        estimator_grads = estimator_tape.gradient(
            estimator_total_loss, estimator_trainable_variables)

        discriminator_trainable_variables = self.discriminator.trainable_variables
        discriminator_grads = discriminator_tape.gradient(
            discriminator_total_loss, discriminator_trainable_variables)

        encoder_optimizer.apply_gradients(
            zip(encoder_grads, encoder_trainable_variables))
        estimator_optimizer.apply_gradients(
            zip(estimator_grads, estimator_trainable_variables))
        discriminator_optimizer.apply_gradients(
            zip(discriminator_grads, discriminator_trainable_variables))

        encoder_losses = dict(total_loss=encoder_total_loss,
                              representation_loss=representation_loss,
                              regularization_loss=regularization_loss,
                              view_loss=view_loss,
                              prior_loss=prior_loss,
                              disentangle_loss=disentangle_loss,
                              intra_bound_loss=intra_bound_loss,
                              inter_bound_loss=inter_bound_loss)

        estimator_losses = dict(
            total_loss=estimator_total_loss,
            anchor_intra_likelihood_loss=anchor_intra_likelihood_loss,
            positive_intra_likelihood_loss=positive_intra_likelihood_loss,
            inter_likelihood_loss=inter_likelihood_loss)

        discriminator_losses = dict(total_loss=discriminator_total_loss)

        return dict(encoder=encoder_losses,
                    estimator=estimator_losses,
                    discriminator=discriminator_losses)
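
A hypothetical single training step with the method above; model, keypoints_2d, and keypoints_3d are placeholders, and the optimizer choice is an assumption:

# keypoints_2d: [batch_size, 2, num_joints, 2] (anchor/positive instances).
# keypoints_3d: [batch_size, 2, num_joints, 3].
encoder_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
estimator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

step_losses = model.train((keypoints_2d, keypoints_3d), encoder_optimizer,
                          estimator_optimizer, discriminator_optimizer)
print(step_losses['encoder']['total_loss'])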
Example #6
    def train(self, inputs, encoder_optimizer):
        """Trains the model for one step.

    Args:
      inputs: A list of input tensors containing 2D and 3D keypoints. Shape = [
        batch_size, num_instances, num_joints, {2|3}]
      encoder_optimizer: An optimizer object.

    Returns:
      A dictionary for all losses.
    """
        keypoints_2d, keypoints_3d = inputs
        anchor_keypoints_2d, positive_keypoints_2d = tf.split(
            keypoints_2d, num_or_size_splits=[1, 1], axis=1)
        anchor_keypoints_3d, positive_keypoints_3d = tf.split(
            keypoints_3d, num_or_size_splits=[1, 1], axis=1)

        anchor_keypoints_2d = tf.squeeze(anchor_keypoints_2d, axis=1)
        positive_keypoints_2d = tf.squeeze(positive_keypoints_2d, axis=1)
        anchor_keypoints_3d = tf.squeeze(anchor_keypoints_3d, axis=1)
        positive_keypoints_3d = tf.squeeze(positive_keypoints_3d, axis=1)

        if MAX_POSITIVE_KEYPOINT_MPJPE_2D is None:
            anchor_indicator_matrix = None
            positive_indicator_matrix = None
        else:
            anchor_indicator_matrix = compute_positive_indicator_matrix(
                anchor_keypoints_2d,
                anchor_keypoints_2d,
                distance_fn=keypoint_utils.compute_mpjpes,
                max_positive_distance=MAX_POSITIVE_KEYPOINT_MPJPE_2D)
            positive_indicator_matrix = compute_positive_indicator_matrix(
                positive_keypoints_2d,
                positive_keypoints_2d,
                distance_fn=keypoint_utils.compute_mpjpes,
                max_positive_distance=MAX_POSITIVE_KEYPOINT_MPJPE_2D)

        if MAX_POSITIVE_KEYPOINT_MPJPE_3D is None:
            view_indicator_matrix = None
        else:
            view_indicator_matrix = compute_positive_indicator_matrix(
                anchor_keypoints_3d,
                positive_keypoints_3d,
                distance_fn=keypoint_utils.compute_procrustes_aligned_mpjpes,
                max_positive_distance=MAX_POSITIVE_KEYPOINT_MPJPE_3D)

        with tf.GradientTape() as tape:
            (anchor_embeddings, anchor_representation_loss,
             anchor_regularization_loss) = self.compute_representation_loss(
                 anchor_keypoints_2d, anchor_indicator_matrix)
            (positive_embeddings, positive_representation_loss,
             positive_regularization_loss) = self.compute_representation_loss(
                 positive_keypoints_2d, positive_indicator_matrix)

            representation_loss = (anchor_representation_loss +
                                   positive_representation_loss)
            regularization_loss = self.regularization_loss_weight * (
                anchor_regularization_loss + positive_regularization_loss)
            total_loss = representation_loss + regularization_loss

            view_loss = self.view_loss_weight * losses.compute_fenchel_dual_loss(
                anchor_embeddings, positive_embeddings,
                losses.TYPE_MEASURE_JSD, view_indicator_matrix) * 2.0
            total_loss += view_loss

        grads = tape.gradient(total_loss, self.trainable_variables)
        encoder_optimizer.apply_gradients(zip(grads, self.trainable_variables))
        encoder_losses = dict(total_loss=total_loss,
                              representation_loss=representation_loss,
                              regularization_loss=regularization_loss,
                              view_loss=view_loss)

        return dict(encoder=encoder_losses)
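
Both train steps build their positive indicator matrices with compute_positive_indicator_matrix, which is not among these examples. A minimal sketch, assuming distance_fn broadcasts over leading batch dimensions:

def compute_positive_indicator_matrix_sketch(anchors, matches, distance_fn,
                                             max_positive_distance):
  """Marks pair (i, j) positive if distance_fn(anchors[i], matches[j]) is at
  most max_positive_distance. Output shape = [batch_size, batch_size]."""
  distance_matrix = distance_fn(anchors[:, tf.newaxis, ...],
                                matches[tf.newaxis, ...])
  indicators = tf.cast(distance_matrix <= max_positive_distance, tf.float32)
  # The indicators act as labels, so no gradient should flow through them.
  return tf.stop_gradient(indicators)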