def train_step(self, video_text_pair_batch):
        """Executes one step of training.

        Args:
            video_text_pair_batch: the batch of data passed to the
                forward_pass function.
        """
        missing_experts = video_text_pair_batch[-1]

        with tf.GradientTape() as gradient_tape:
            encoder_output = self.forward_pass(video_text_pair_batch,
                                               training=True)
            # encoder_output is (video embeddings, text embeddings, mixture
            # weights); the similarity matrix additionally needs the
            # missing_experts mask.
            similarity_matrix = build_similarity_matrix(
                *encoder_output, missing_experts)
            loss = self.loss_function(similarity_matrix,
                                      self.margin_hyperparameter)

        gradients = gradient_tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients,
                                           self.trainable_variables))

        # It's wasteful to calculate ranking metrics for the entire train
        # dataset, so we just mark the values as NaN for keras.
        batch_metrics = {
            label: float("nan")
            for label in self.recall_at_k_labels
        }
        batch_metrics["median_rank"] = float("nan")
        batch_metrics["mean_rank"] = float("nan")
        batch_metrics["loss"] = loss

        return batch_metrics
    def test_computing_similarity_multiple_experts(self):
        """Tests building a similarity matrix with multiple experts."""
        tf.random.set_seed(2)

        text_embeddings_expert_one = tf.random.uniform(
            (self.mock_batch_size, self.mock_embedding_dimensionality))
        video_embeddings_expert_one = tf.random.uniform(
            (self.mock_batch_size, self.mock_embedding_dimensionality))

        text_embeddings_expert_two = tf.random.uniform(
            (self.mock_batch_size, self.mock_embedding_dimensionality))
        video_embeddings_expert_two = tf.random.uniform(
            (self.mock_batch_size, self.mock_embedding_dimensionality))

        text_embeddings_expert_three = tf.random.uniform(
            (self.mock_batch_size, self.mock_embedding_dimensionality))
        video_embeddings_expert_three = tf.random.uniform(
            (self.mock_batch_size, self.mock_embedding_dimensionality))

        mixture_weights = tf.random.uniform((self.mock_batch_size, 3))
        missing_experts = tf.constant([[False, True, False],
                                       [False, False, False],
                                       [False, False, True],
                                       [False, False, False],
                                       [False, True, True]])

        available_experts_float32 = 1 - tf.cast(missing_experts, tf.float32)

        weights, _ = tf.linalg.normalize(mixture_weights[:, None] *
                                         available_experts_float32[None, :],
                                         axis=-1,
                                         ord=1)

        expert_one_similarity = weights[:, :, 0] * tf.matmul(
            text_embeddings_expert_one,
            video_embeddings_expert_one,
            transpose_b=True)
        expert_two_similarity = weights[:, :, 1] * tf.matmul(
            text_embeddings_expert_two,
            video_embeddings_expert_two,
            transpose_b=True)
        expert_three_similarity = weights[:, :, 2] * tf.matmul(
            text_embeddings_expert_three,
            video_embeddings_expert_three,
            transpose_b=True)

        expected_matrix = (expert_one_similarity + expert_two_similarity +
                           expert_three_similarity)

        computed_matrix = build_similarity_matrix([
            video_embeddings_expert_one, video_embeddings_expert_two,
            video_embeddings_expert_three
        ], [
            text_embeddings_expert_one, text_embeddings_expert_two,
            text_embeddings_expert_three
        ], mixture_weights, missing_experts)

        self.assert_matricies_are_the_same(computed_matrix, expected_matrix)
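
    # A minimal reference sketch (not the repository's implementation) of the
    # computation this test expects from build_similarity_matrix: per-expert
    # text/video dot-product similarities, combined with mixture weights that
    # are L1-renormalized over each video's available experts. Assumes the
    # module-level `import tensorflow as tf`; the helper name is hypothetical.
    @staticmethod
    def _reference_similarity_matrix(video_embeddings, text_embeddings,
                                     mixture_weights, missing_experts):
        available_experts = 1.0 - tf.cast(missing_experts, tf.float32)
        # weights[i, j, k] weighs expert k for text i against video j,
        # renormalized so the experts available for video j sum to one.
        weights, _ = tf.linalg.normalize(
            mixture_weights[:, None] * available_experts[None, :],
            axis=-1,
            ord=1)
        similarity_matrix = 0.0
        for expert_index, (video, text) in enumerate(
                zip(video_embeddings, text_embeddings)):
            similarity_matrix += weights[:, :, expert_index] * tf.matmul(
                text, video, transpose_b=True)
        return similarity_matrix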
    def test_step(self, video_text_pair_batch):
        """Executes one test step.

        Args:
            video_text_pair_batch: input to the forward_pass function.
                Each video in this batch must have self.captions_per_video
                captions associated with it, and all video-caption pairs for
                the same video must be adjacent to one another in the batch.
        """
        video_text_pair_batch = self.remove_repeated_video_data(
            video_text_pair_batch)
        missing_experts = video_text_pair_batch[-1]

        video_results, text_results, mixture_weights = self.forward_pass(
            video_text_pair_batch, training=False)

        valid_metrics = {}
        loss = []
        ranks = []

        # Because there are multiple captions per video, we shard the text
        # embeddings and mixture weights into self.captions_per_video shards.
        # Since the video data is repeated once per caption in a given batch,
        # computing retrieval metrics on these shards is cleaner than
        # computing them on the entire validation set at once.
        for caption_index in range(self.captions_per_video):
            shard_text_results = [
                embed[caption_index::self.captions_per_video]
                for embed in text_results
            ]
            shard_mixture_weights = mixture_weights[
                caption_index::self.captions_per_video]

            similarity_matrix = build_similarity_matrix(
                video_results, shard_text_results, shard_mixture_weights,
                missing_experts)

            loss.append(
                self.loss_function(similarity_matrix,
                                   self.margin_hyperparameter))
            ranks.append(metrics.rankings.compute_ranks(similarity_matrix))

        ranks = tf.concat(ranks, axis=0)

        valid_metrics["loss"] = tf.reduce_mean(tf.stack(loss))
        valid_metrics["mean_rank"] = metrics.rankings.get_mean_rank(ranks)
        valid_metrics["median_rank"] = metrics.rankings.get_median_rank(ranks)

        for k, label in zip(self.recall_at_k_bounds, self.recall_at_k_labels):
            valid_metrics[label] = metrics.rankings.get_recall_at_k(ranks, k)

        return valid_metrics
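
    # A rough sketch of the ranking helpers used above; the real versions
    # live in metrics.rankings, so treat this as an assumption about their
    # behavior rather than their actual code. Rows of the similarity matrix
    # are text queries and columns are candidate videos, with matching pairs
    # on the diagonal; the rank of a query is the number of videos scored at
    # least as high as its matching video (a perfect retrieval has rank one),
    # and recall at k is the fraction of queries ranked at or below k.
    @staticmethod
    def _sketch_ranking_metrics(similarity_matrix, k):
        matching_scores = tf.linalg.diag_part(similarity_matrix)[:, None]
        ranks = tf.reduce_sum(
            tf.cast(similarity_matrix >= matching_scores, tf.int32), axis=1)
        recall_at_k = tf.reduce_mean(tf.cast(ranks <= k, tf.float32))
        return ranks, recall_at_k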
    def test_computing_similarity_one_expert(self):
        """Tests building a similarity matrix when there is only one expert."""
        tf.random.set_seed(1)

        text_embeddings = tf.random.uniform(
            (self.mock_batch_size, self.mock_embedding_dimensionality))
        video_embeddings = tf.random.uniform(
            (self.mock_batch_size, self.mock_embedding_dimensionality))

        mixture_weights = tf.ones((self.mock_batch_size, 1), tf.float32)
        missing_experts = tf.constant([[False]] * self.mock_batch_size)

        computed_matrix = build_similarity_matrix([video_embeddings],
                                                  [text_embeddings],
                                                  mixture_weights,
                                                  missing_experts)

        self.assertEqual(computed_matrix.shape,
                         (self.mock_batch_size, self.mock_batch_size))

        self.assert_matricies_are_the_same(
            computed_matrix,
            tf.matmul(text_embeddings, video_embeddings, transpose_b=True))
    def test_loss_calculation_with_different_embeddings(self):
        """Tests computing loss for mini-batches of varying quality embeddings.

        This tests computing the loss for three sets of embeddings: the first
        set is perfect, the second is decent, and the third is poor.
        """

        mock_missing_experts = tf.constant([[False], [False], [False]])
        mock_mixture_weights = tf.constant([[1.0], [1.0], [1.0]])

        mock_perfect_video_embeddings = [
            tf.constant([
                [-1.0, 0.0, 0.0],
                [0.0, 0.0, -1.0],
                [0.0, 1.0, 0.0],
            ])
        ]

        mock_perfect_text_embeddings = [
            tf.constant([
                [-1.0, 0.0, 0.0],
                [0.0, 0.0, -1.0],
                [0.0, 1.0, 0.0],
            ])
        ]

        similarity_matrix = build_similarity_matrix(
            mock_perfect_video_embeddings, mock_perfect_text_embeddings,
            mock_mixture_weights, mock_missing_experts)
        loss = bidirectional_max_margin_ranking_loss(similarity_matrix, 1.0)

        self.assertTrue(abs(loss.numpy() - 0.0) < self.error)

        similarity_matrix = build_similarity_matrix(
            mock_perfect_video_embeddings, mock_perfect_text_embeddings,
            mock_mixture_weights, mock_missing_experts)
        loss = bidirectional_max_margin_ranking_loss(similarity_matrix, 100.0)

        self.assertTrue(abs(loss.numpy() - 99.0) < self.error)

        mock_good_video_embeddings = [
            tf.constant([
                [-0.9938837, 0.11043153],
                [-0.70710677, 0.70710677],
                [0.0, 1.0],
            ])
        ]

        mock_good_text_embeddings = [
            tf.constant([
                [-1.0, 0.0],
                [-0.5547002, 0.8320503],
                [0.0, 1.0],
            ])
        ]

        similarity_matrix = build_similarity_matrix(mock_good_video_embeddings,
                                                    mock_good_text_embeddings,
                                                    mock_mixture_weights,
                                                    mock_missing_experts)
        loss = bidirectional_max_margin_ranking_loss(similarity_matrix, 1.0)

        self.assertTrue(abs(loss.numpy() - 0.5084931) < self.error)

        mock_missing_experts = tf.constant([[False], [False], [False],
                                            [False]])
        mock_mixture_weights = tf.constant([[1.0], [1.0], [1.0], [1.0]])

        mock_bad_video_embeddings = [
            tf.constant([
                [0.25, 0.25],
                [1.0, 1.0],
                [0.6, 0.5],
                [0.9, 0.8],
            ])
        ]

        mock_bad_text_embeddings = [
            tf.constant([
                [-1.0, 0.0],
                [0.0, 1.0],
                [-1.0, 1.0],
                [0.7, 0.6],
            ])
        ]

        similarity_matrix = build_similarity_matrix(mock_bad_video_embeddings,
                                                    mock_bad_text_embeddings,
                                                    mock_mixture_weights,
                                                    mock_missing_experts)
        loss = bidirectional_max_margin_ranking_loss(similarity_matrix, 1.5)

        self.assertTrue(abs(loss.numpy() - 1.21000000333) < self.error)
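
    # A sketch of a bidirectional max-margin ranking loss consistent with the
    # values asserted above (not necessarily this repository's exact
    # implementation): every mismatched text/video pair incurs a hinge penalty
    # in both retrieval directions relative to the matching pair's similarity,
    # and the loss is the mean over all such off-diagonal terms. Assumes the
    # module-level `import tensorflow as tf`; the helper name is hypothetical.
    @staticmethod
    def _sketch_bidirectional_max_margin_ranking_loss(similarity_matrix,
                                                      margin):
        batch_size = tf.shape(similarity_matrix)[0]
        matching_similarities = tf.linalg.diag_part(similarity_matrix)[:, None]
        # Hinge terms for text-to-video and video-to-text retrieval.
        text_to_video = tf.nn.relu(
            margin + similarity_matrix - matching_similarities)
        video_to_text = tf.nn.relu(
            margin + tf.transpose(similarity_matrix) - matching_similarities)
        # Mask out the diagonal (matching pairs) and average the rest.
        off_diagonal_mask = 1.0 - tf.eye(batch_size)
        num_terms = 2.0 * tf.cast(batch_size * (batch_size - 1), tf.float32)
        return tf.reduce_sum(
            (text_to_video + video_to_text) * off_diagonal_mask) / num_terms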