def compute_representation_loss(self, pose_embeddings, view_embeddings,
                                subencoder_embeddings,
                                positive_indicator_matrix):
  """Computes the representation loss.

  Args:
    pose_embeddings: A tensor for the pose embeddings. Shape = [batch_size,
      1, pose_embedding_dim].
    view_embeddings: A tensor for the view embeddings. Shape = [batch_size,
      1, view_embedding_dim].
    subencoder_embeddings: A tensor for the subencoder embeddings. Shape =
      [batch_size, 1, embedding_dim].
    positive_indicator_matrix: A tensor for the positive indicator matrix.
      Positive correspondences have value 1.0 and all others have value 0.0.
      Shape = [batch_size, batch_size].

  Returns:
    representation_loss: A scalar for the representation loss.
  """
  if self._fusion_op_type == TYPE_FUSION_OP_CAT:
    fusion_embeddings = tf.concat([pose_embeddings, view_embeddings],
                                  axis=-1)
  elif self._fusion_op_type == TYPE_FUSION_OP_POE:
    fusion_embeddings = pose_embeddings * view_embeddings
  elif self._fusion_op_type == TYPE_FUSION_OP_MOE:
    fusion_embeddings = 0.5 * (pose_embeddings + view_embeddings)
  else:
    raise ValueError('Unknown fusion operation: {}'.format(
        self._fusion_op_type))
  representation_loss = losses.compute_fenchel_dual_loss(
      subencoder_embeddings, fusion_embeddings, losses.TYPE_MEASURE_JSD,
      positive_indicator_matrix)
  return representation_loss
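# A toy illustration of the three fusion operations above (names and values
# are made up): CAT concatenates the pose and view embeddings along the
# channel axis, POE multiplies them elementwise (product-of-experts), and
# MOE averages them (mixture-of-experts). Note that POE and MOE require
# pose_embedding_dim == view_embedding_dim, while CAT does not.
import tensorflow as tf

pose_embeddings = tf.constant([[[1., 2.]]])  # Shape = [1, 1, 2].
view_embeddings = tf.constant([[[3., 4.]]])  # Shape = [1, 1, 2].
tf.concat([pose_embeddings, view_embeddings], axis=-1)  # [[[1., 2., 3., 4.]]]
pose_embeddings * view_embeddings                       # [[[3., 8.]]]
0.5 * (pose_embeddings + view_embeddings)               # [[[2., 3.]]]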
def test_compute_fenchel_dual_loss(self):
  input_features = tf.ones([4, 3, 5], tf.float32)
  loss = losses.compute_fenchel_dual_loss(input_features, input_features,
                                          losses.TYPE_MEASURE_W1)
  self.assertAllEqual(loss, 0.)

  input_features = tf.constant([
      [[1., 1.], [0., 1.]],
      [[1., 0.], [1., 1.]],
      [[1., 0.], [1., 1.]],
      [[1., 1.], [1., 1.]],
  ])
  loss = losses.compute_fenchel_dual_loss(input_features, input_features,
                                          losses.TYPE_MEASURE_W1)
  self.assertAllClose(loss, 15.5 / 12 - 5.75 / 4)

  loss = losses.compute_fenchel_dual_loss(
      input_features, input_features, losses.TYPE_MEASURE_W1,
      tf.eye(4, dtype=tf.dtypes.float32))
  self.assertAllClose(loss, 15.5 / 12 - 5.75 / 4)

  loss = losses.compute_fenchel_dual_loss(
      input_features, input_features, losses.TYPE_MEASURE_W1,
      tf.ones((4, 4), dtype=tf.dtypes.float32))
  self.assertAllClose(loss, -(15.5 + 5.75) / 16)
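# The expected values in this test can be reproduced with the following
# minimal sketch of a Fenchel-dual loss. It mirrors the test semantics but
# is not necessarily the library's implementation: scores are all
# point-to-point dot products across batch pairs, diagonal batch pairs are
# positives by default, and the loss is the mean negative-pair measure
# minus the mean positive-pair measure (for W1, the raw scores; for JSD,
# the softplus-based dual bound as in Deep InfoMax).
import tensorflow as tf


def fenchel_dual_loss_sketch(x, y, measure='W1',
                             positive_indicator_matrix=None):
  """A hedged re-implementation sketch of `compute_fenchel_dual_loss`."""
  batch_size = tf.shape(x)[0]
  num_point_pairs = tf.cast(tf.shape(x)[1] * tf.shape(y)[1], tf.float32)
  if positive_indicator_matrix is None:
    positive_indicator_matrix = tf.eye(batch_size)
  # Shape = [batch_size, batch_size, num_points_x, num_points_y].
  scores = tf.einsum('ipd,jqd->ijpq', x, y)
  if measure == 'JSD':
    # Fenchel-dual bound on the Jensen-Shannon divergence.
    positive_measure = -tf.math.softplus(-scores)
    negative_measure = tf.math.softplus(scores)
  elif measure == 'W1':
    positive_measure = negative_measure = scores
  else:
    raise ValueError('Unknown measure: {}'.format(measure))
  pos_mask = positive_indicator_matrix[:, :, tf.newaxis, tf.newaxis]
  neg_mask = 1. - pos_mask
  e_pos = tf.reduce_sum(positive_measure * pos_mask) / tf.maximum(
      tf.reduce_sum(pos_mask) * num_point_pairs, 1e-12)
  e_neg = tf.reduce_sum(negative_measure * neg_mask) / tf.maximum(
      tf.reduce_sum(neg_mask) * num_point_pairs, 1e-12)
  # Negative of the dual lower bound: minimized when positives score high
  # and negatives score low.
  return e_neg - e_pos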
def compute_representation_loss(self, inputs, positive_indicator_matrix):
  """Computes the representation loss.

  Args:
    inputs: An input tensor. Shape = [batch_size, num_points, ...].
    positive_indicator_matrix: A tensor for the positive indicator matrix.
      Positive correspondences have value 1.0 and all others have value 0.0.
      Shape = [batch_size, batch_size].

  Returns:
    output_embeddings: A tensor for the embeddings. Shape = [batch_size, 1,
      embedding_dim].
    representation_loss: A scalar for the representation loss.
    regularization_loss: A scalar for the regularization loss.
  """
  output_embeddings, subencoder_output_embeddings = self(
      inputs, training=True)

  if self.fusion_op_type == TYPE_FUSION_OP_CAT:
    fusion_embeddings = output_embeddings
  else:
    pose_embeddings, view_embeddings = tf.split(
        output_embeddings,
        num_or_size_splits=[self.pose_embedding_dim, self.view_embedding_dim],
        axis=-1)
    if self.fusion_op_type == TYPE_FUSION_OP_POE:
      fusion_embeddings = pose_embeddings * view_embeddings
    elif self.fusion_op_type == TYPE_FUSION_OP_MOE:
      fusion_embeddings = 0.5 * (pose_embeddings + view_embeddings)
    else:
      raise ValueError('Unknown fusion operation: {}'.format(
          self.fusion_op_type))

  representation_loss = losses.compute_fenchel_dual_loss(
      subencoder_output_embeddings, fusion_embeddings,
      losses.TYPE_MEASURE_JSD, positive_indicator_matrix)
  regularization_loss = sum(self.encoder.losses)
  return output_embeddings, representation_loss, regularization_loss
def compute_representation_loss(self, inputs, positive_indicator_matrix):
  """Computes the representation loss.

  Args:
    inputs: An input tensor. Shape = [batch_size, num_points, ...].
    positive_indicator_matrix: A tensor for the positive indicator matrix.
      Positive correspondences have value 1.0 and all others have value 0.0.
      Shape = [batch_size, batch_size].

  Returns:
    output_embeddings: A tensor for the embeddings. Shape = [batch_size, 1,
      embedding_dim].
    representation_loss: A scalar for the representation loss.
    regularization_loss: A scalar for the regularization loss.
  """
  output_embeddings, subencoder_output_embeddings = self(
      inputs, training=True)
  representation_loss = losses.compute_fenchel_dual_loss(
      subencoder_output_embeddings, output_embeddings,
      losses.TYPE_MEASURE_JSD, positive_indicator_matrix)
  regularization_loss = sum(self.encoder.losses)
  return output_embeddings, representation_loss, regularization_loss
def train(self, inputs, encoder_optimizer, estimator_optimizer,
          discriminator_optimizer):
  """Trains the model for one step.

  Args:
    inputs: A list of input tensors containing 2D and 3D keypoints. Shape =
      [batch_size, num_instances, num_joints, {2|3}].
    encoder_optimizer: An optimizer object for the encoder.
    estimator_optimizer: An optimizer object for the estimators.
    discriminator_optimizer: An optimizer object for the discriminator.

  Returns:
    A dictionary of all losses.
  """
  keypoints_2d, keypoints_3d = inputs
  anchor_keypoints_2d, positive_keypoints_2d = tf.split(
      keypoints_2d, num_or_size_splits=[1, 1], axis=1)
  anchor_keypoints_3d, positive_keypoints_3d = tf.split(
      keypoints_3d, num_or_size_splits=[1, 1], axis=1)

  anchor_keypoints_2d = tf.squeeze(anchor_keypoints_2d, axis=1)
  positive_keypoints_2d = tf.squeeze(positive_keypoints_2d, axis=1)
  anchor_keypoints_3d = tf.squeeze(anchor_keypoints_3d, axis=1)
  positive_keypoints_3d = tf.squeeze(positive_keypoints_3d, axis=1)

  if MAX_POSITIVE_KEYPOINT_MPJPE_2D is None:
    anchor_indicator_matrix = None
    positive_indicator_matrix = None
  else:
    anchor_indicator_matrix = compute_positive_indicator_matrix(
        anchor_keypoints_2d,
        anchor_keypoints_2d,
        distance_fn=keypoint_utils.compute_mpjpes,
        max_positive_distance=MAX_POSITIVE_KEYPOINT_MPJPE_2D)
    positive_indicator_matrix = compute_positive_indicator_matrix(
        positive_keypoints_2d,
        positive_keypoints_2d,
        distance_fn=keypoint_utils.compute_mpjpes,
        max_positive_distance=MAX_POSITIVE_KEYPOINT_MPJPE_2D)

  if MAX_POSITIVE_KEYPOINT_MPJPE_3D is None:
    view_indicator_matrix = None
  else:
    view_indicator_matrix = compute_positive_indicator_matrix(
        anchor_keypoints_3d,
        positive_keypoints_3d,
        distance_fn=keypoint_utils.compute_procrustes_aligned_mpjpes,
        max_positive_distance=MAX_POSITIVE_KEYPOINT_MPJPE_3D)

  def compute_estimator_loss(estimator, x, y, positive_indicator_matrix):
    x_mean, x_logvar = estimator(x, training=True)
    likelihood = losses.compute_log_likelihood(x_mean, x_logvar, y)
    bound = losses.compute_contrastive_log_ratio(x_mean, x_logvar, y,
                                                 positive_indicator_matrix)
    return likelihood, bound

  with tf.GradientTape() as encoder_tape, tf.GradientTape(
  ) as estimator_tape, tf.GradientTape() as discriminator_tape:
    (anchor_embeddings, anchor_representation_loss,
     anchor_regularization_loss) = self.compute_representation_loss(
         anchor_keypoints_2d, anchor_indicator_matrix)
    (positive_embeddings, positive_representation_loss,
     positive_regularization_loss) = self.compute_representation_loss(
         positive_keypoints_2d, positive_indicator_matrix)
    representation_loss = (anchor_representation_loss +
                           positive_representation_loss)
    regularization_loss = self.regularization_loss_weight * (
        anchor_regularization_loss + positive_regularization_loss)
    encoder_total_loss = representation_loss + regularization_loss

    anchor_pose_embeddings, anchor_view_embeddings = tf.split(
        anchor_embeddings,
        num_or_size_splits=[self.pose_embedding_dim, self.view_embedding_dim],
        axis=-1)
    positive_pose_embeddings, positive_view_embeddings = tf.split(
        positive_embeddings,
        num_or_size_splits=[self.pose_embedding_dim, self.view_embedding_dim],
        axis=-1)

    view_loss = self.view_loss_weight * losses.compute_fenchel_dual_loss(
        anchor_pose_embeddings, positive_pose_embeddings,
        losses.TYPE_MEASURE_JSD, view_indicator_matrix) * 2.0
    encoder_total_loss += view_loss

    anchor_view_embeddings = tf.squeeze(anchor_view_embeddings, axis=1)
    anchor_pose_embeddings = tf.squeeze(anchor_pose_embeddings, axis=1)
    positive_view_embeddings = tf.squeeze(positive_view_embeddings, axis=1)
    positive_pose_embeddings = tf.squeeze(positive_pose_embeddings, axis=1)

    (anchor_prior_loss,
     anchor_discriminator_loss) = self.compute_uniform_prior_loss(
         anchor_pose_embeddings, anchor_view_embeddings)
    (positive_prior_loss,
     positive_discriminator_loss) = self.compute_uniform_prior_loss(
         positive_pose_embeddings, positive_view_embeddings)
    prior_loss = anchor_prior_loss + positive_prior_loss
    encoder_total_loss += prior_loss
    discriminator_total_loss = (anchor_discriminator_loss +
                                positive_discriminator_loss)

    inter_likelihood, inter_bound = compute_estimator_loss(
        self.inter_likelihood_estimator, anchor_view_embeddings,
        positive_view_embeddings, view_indicator_matrix)
    anchor_intra_likelihood, anchor_intra_bound = compute_estimator_loss(
        self.intra_likelihood_estimator, anchor_pose_embeddings,
        anchor_view_embeddings, anchor_indicator_matrix)
    positive_intra_likelihood, positive_intra_bound = compute_estimator_loss(
        self.intra_likelihood_estimator, positive_pose_embeddings,
        positive_view_embeddings, positive_indicator_matrix)

    intra_bound_loss = self.disentangle_loss_weight * (
        anchor_intra_bound + positive_intra_bound)
    inter_bound_loss = self.disentangle_loss_weight * inter_bound
    disentangle_loss = intra_bound_loss + inter_bound_loss * 2.0
    encoder_total_loss += disentangle_loss

    anchor_intra_likelihood_loss = -anchor_intra_likelihood
    positive_intra_likelihood_loss = -positive_intra_likelihood
    inter_likelihood_loss = -inter_likelihood
    estimator_total_loss = (anchor_intra_likelihood_loss +
                            positive_intra_likelihood_loss +
                            inter_likelihood_loss * 2.0)

  encoder_trainable_variables = (self.encoder.trainable_variables +
                                 self.subencoder.trainable_variables)
  encoder_grads = encoder_tape.gradient(encoder_total_loss,
                                        encoder_trainable_variables)
  estimator_trainable_variables = (
      self.intra_likelihood_estimator.trainable_variables +
      self.inter_likelihood_estimator.trainable_variables)
  estimator_grads = estimator_tape.gradient(estimator_total_loss,
                                            estimator_trainable_variables)
  discriminator_trainable_variables = self.discriminator.trainable_variables
  discriminator_grads = discriminator_tape.gradient(
      discriminator_total_loss, discriminator_trainable_variables)

  encoder_optimizer.apply_gradients(
      zip(encoder_grads, encoder_trainable_variables))
  estimator_optimizer.apply_gradients(
      zip(estimator_grads, estimator_trainable_variables))
  discriminator_optimizer.apply_gradients(
      zip(discriminator_grads, discriminator_trainable_variables))

  encoder_losses = dict(
      total_loss=encoder_total_loss,
      representation_loss=representation_loss,
      regularization_loss=regularization_loss,
      view_loss=view_loss,
      prior_loss=prior_loss,
      disentangle_loss=disentangle_loss,
      intra_bound_loss=intra_bound_loss,
      inter_bound_loss=inter_bound_loss)
  estimator_losses = dict(
      total_loss=estimator_total_loss,
      anchor_intra_likelihood_loss=anchor_intra_likelihood_loss,
      positive_intra_likelihood_loss=positive_intra_likelihood_loss,
      inter_likelihood_loss=inter_likelihood_loss)
  discriminator_losses = dict(total_loss=discriminator_total_loss)
  return dict(
      encoder=encoder_losses,
      estimator=estimator_losses,
      discriminator=discriminator_losses)
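# `compute_positive_indicator_matrix` is called above but not shown. Below
# is a plausible sketch, assuming `distance_fn` (e.g.,
# `keypoint_utils.compute_mpjpes`) broadcasts over leading dimensions and
# returns one distance per keypoint-set pair; this is an illustration, not
# necessarily the actual helper.
def compute_positive_indicator_matrix(anchor_points, match_points,
                                      distance_fn, max_positive_distance):
  """Thresholds pairwise distances into a {0.0, 1.0} indicator matrix."""
  # Shape = [batch_size, batch_size]: distance between every anchor/match
  # keypoint-set pair.
  distance_matrix = distance_fn(
      tf.expand_dims(anchor_points, axis=1),  # [batch, 1, num_joints, dim]
      tf.expand_dims(match_points, axis=0))   # [1, batch, num_joints, dim]
  indicators = tf.cast(distance_matrix <= max_positive_distance, tf.float32)
  # Symmetrize so that entries (i, j) and (j, i) agree.
  return indicators * tf.transpose(indicators)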
def train(self, inputs, encoder_optimizer):
  """Trains the model for one step.

  Args:
    inputs: A list of input tensors containing 2D and 3D keypoints. Shape =
      [batch_size, num_instances, num_joints, {2|3}].
    encoder_optimizer: An optimizer object.

  Returns:
    A dictionary of all losses.
  """
  keypoints_2d, keypoints_3d = inputs
  anchor_keypoints_2d, positive_keypoints_2d = tf.split(
      keypoints_2d, num_or_size_splits=[1, 1], axis=1)
  anchor_keypoints_3d, positive_keypoints_3d = tf.split(
      keypoints_3d, num_or_size_splits=[1, 1], axis=1)

  anchor_keypoints_2d = tf.squeeze(anchor_keypoints_2d, axis=1)
  positive_keypoints_2d = tf.squeeze(positive_keypoints_2d, axis=1)
  anchor_keypoints_3d = tf.squeeze(anchor_keypoints_3d, axis=1)
  positive_keypoints_3d = tf.squeeze(positive_keypoints_3d, axis=1)

  if MAX_POSITIVE_KEYPOINT_MPJPE_2D is None:
    anchor_indicator_matrix = None
    positive_indicator_matrix = None
  else:
    anchor_indicator_matrix = compute_positive_indicator_matrix(
        anchor_keypoints_2d,
        anchor_keypoints_2d,
        distance_fn=keypoint_utils.compute_mpjpes,
        max_positive_distance=MAX_POSITIVE_KEYPOINT_MPJPE_2D)
    positive_indicator_matrix = compute_positive_indicator_matrix(
        positive_keypoints_2d,
        positive_keypoints_2d,
        distance_fn=keypoint_utils.compute_mpjpes,
        max_positive_distance=MAX_POSITIVE_KEYPOINT_MPJPE_2D)

  if MAX_POSITIVE_KEYPOINT_MPJPE_3D is None:
    view_indicator_matrix = None
  else:
    view_indicator_matrix = compute_positive_indicator_matrix(
        anchor_keypoints_3d,
        positive_keypoints_3d,
        distance_fn=keypoint_utils.compute_procrustes_aligned_mpjpes,
        max_positive_distance=MAX_POSITIVE_KEYPOINT_MPJPE_3D)

  with tf.GradientTape() as tape:
    (anchor_embeddings, anchor_representation_loss,
     anchor_regularization_loss) = self.compute_representation_loss(
         anchor_keypoints_2d, anchor_indicator_matrix)
    (positive_embeddings, positive_representation_loss,
     positive_regularization_loss) = self.compute_representation_loss(
         positive_keypoints_2d, positive_indicator_matrix)
    representation_loss = (anchor_representation_loss +
                           positive_representation_loss)
    regularization_loss = self.regularization_loss_weight * (
        anchor_regularization_loss + positive_regularization_loss)
    total_loss = representation_loss + regularization_loss

    view_loss = self.view_loss_weight * losses.compute_fenchel_dual_loss(
        anchor_embeddings, positive_embeddings, losses.TYPE_MEASURE_JSD,
        view_indicator_matrix) * 2.0
    total_loss += view_loss

  grads = tape.gradient(total_loss, self.trainable_variables)
  encoder_optimizer.apply_gradients(zip(grads, self.trainable_variables))

  encoder_losses = dict(
      total_loss=total_loss,
      representation_loss=representation_loss,
      regularization_loss=regularization_loss,
      view_loss=view_loss)
  return dict(encoder=encoder_losses)
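# A hypothetical smoke test of the single-optimizer training step above.
# `Model` (a stand-in for the class defining `train`), the batch size, and
# the 17-joint keypoint layout are assumptions for illustration only.
model = Model()  # Hypothetical constructor; see the actual class definition.
keypoints_2d = tf.random.uniform([8, 2, 17, 2])  # [batch, instances, joints, 2]
keypoints_3d = tf.random.uniform([8, 2, 17, 3])  # [batch, instances, joints, 3]
encoder_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
losses_dict = model.train((keypoints_2d, keypoints_3d), encoder_optimizer)
print({name: float(value) for name, value in losses_dict['encoder'].items()})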