def test_legacy_lm_loss_compatibility(self):
    """Pin masked-LM loss numerics to previously recorded values.

    The logits were captured from a masked LM built with batch_size=3,
    vocab_size=5, sequence_length=4 and num_predictions=2. Any refactor of
    the loss must reproduce the expected values below.
    """
    # Logit rows, one per (example, prediction) slot. Examples 1 and 2 reuse
    # the same row for both of their prediction slots.
    ex0_pred0 = [-2.5286622, -1.0963473, -1.4925185, -2.4451098, -1.2923571]
    ex0_pred1 = [-2.7117882, -1.1205841, -4.02187, -0.9966936, -1.5119683]
    ex1_pred = [-2.5379114, -0.82479054, -2.287932, -1.3747153, -2.053741]
    ex2_pred = [-2.7760355, -1.8219438, -3.0924666, -1.0779881, -0.9407509]
    output_data = np.array([[ex0_pred0, ex0_pred1],
                            [ex1_pred, ex1_pred],
                            [ex2_pred, ex2_pred]])
    labels = np.array([[4, 0], [2, 2], [2, 1]])

    # Each expected per-example value is the negated logit at the labelled
    # index of the corresponding row above.
    per_example_loss_data = (
        weighted_sparse_categorical_crossentropy.per_example_loss(
            predictions=output_data, labels=labels))
    self.assertAllClose([[1.2923571, 2.7117882],
                         [2.287932, 2.287932],
                         [3.0924666, 1.8219438]],
                        per_example_loss_data)

    # Only the first prediction of the first example carries weight, so the
    # reference overall loss is (very close to) that single per-example value.
    weights = np.array([[1, 0], [0, 0], [0, 0]])
    loss_data = weighted_sparse_categorical_crossentropy.loss(
        predictions=output_data, labels=labels, weights=weights)
    self.assertAllClose(1.2923441, loss_data)
    def test_per_example_loss_3d_input(self):
        """Per-example loss on masked-LM output yields one value per prediction.

        Inputs are random, so the resulting losses are expected to be
        non-zero.
        """
        vocab_size, sequence_length = 100, 32
        hidden_size, num_predictions = 64, 21
        model = self.create_lm_model(
            vocab_size=vocab_size,
            sequence_length=sequence_length,
            hidden_size=hidden_size,
            num_predictions=num_predictions)

        # Feed random hidden states and masked positions through the LM.
        batch_size = 3
        prediction_shape = (batch_size, num_predictions)
        lm_input_data = np.random.random_sample(
            (batch_size, sequence_length, hidden_size)) * 10
        masked_position_data = np.random.randint(2, size=prediction_shape)
        output_data = model.predict([lm_input_data, masked_position_data])

        labels = np.random.randint(vocab_size, size=prediction_shape)
        per_example_loss_data = (
            weighted_sparse_categorical_crossentropy.per_example_loss(
                predictions=output_data, labels=labels))

        # One loss entry per prediction slot, and not all of them zero.
        self.assertEqual([batch_size, num_predictions],
                         per_example_loss_data.shape.as_list())
        self.assertNotAllClose(tf.zeros_like(per_example_loss_data),
                               per_example_loss_data)
    def test_per_example_loss_weights_3d_input(self):
        """Test weighted per-example loss with a 3-d input, from a masked LM.

        The weighted per-example loss must equal the unweighted per-example
        loss multiplied elementwise by the weights.
        """
        vocab_size = 100
        sequence_length = 32
        hidden_size = 64
        num_predictions = 21
        model = self.create_lm_model(vocab_size=vocab_size,
                                     sequence_length=sequence_length,
                                     hidden_size=hidden_size,
                                     num_predictions=num_predictions)

        # Get the output of the masked LM.
        batch_size = 3
        lm_input_data = 10 * np.random.random_sample(
            (batch_size, sequence_length, hidden_size))
        masked_position_data = np.random.randint(2,
                                                 size=(batch_size,
                                                       num_predictions))
        output_data = model.predict([lm_input_data, masked_position_data])

        labels = np.random.randint(vocab_size,
                                   size=(batch_size, num_predictions))
        weights = np.random.randint(2, size=(batch_size, num_predictions))
        # Guarantee both a masked (0) and an unmasked (1) entry so the test
        # always exercises the weighting, regardless of the random draw.
        weights.flat[0] = 0
        weights.flat[-1] = 1

        unweighted_loss_data = (
            weighted_sparse_categorical_crossentropy.per_example_loss(
                predictions=output_data, labels=labels))
        per_example_loss_data = (
            weighted_sparse_categorical_crossentropy.per_example_loss(
                predictions=output_data, labels=labels, weights=weights))

        # Compare against an independently computed reference. Multiplying
        # the *weighted* result by 0/1 weights again would be a no-op and
        # could not detect a broken weighting implementation.
        expected_weighted_loss = unweighted_loss_data * tf.cast(
            weights, unweighted_loss_data.dtype)
        self.assertAllClose(expected_weighted_loss, per_example_loss_data)
    def test_mismatched_predictions_and_labels_ranks_squeezes(self):
        """Test that labels with a trailing unit axis are squeezed, not rejected.

        Here labels are [batch_size, 1] while predictions are
        [batch_size, 10], so rank(labels) == rank(predictions). The loss is
        expected to squeeze the trailing labels axis and compute normally
        rather than fail.
        """
        batch_size = 3
        output_data = np.random.random_sample((batch_size, 10))
        labels = np.random.randint(10, size=(batch_size, 1))

        per_example_loss_data = (
            weighted_sparse_categorical_crossentropy.per_example_loss(
                predictions=output_data, labels=labels))

        # After the squeeze the labels are 1-D, so the result should hold one
        # loss value per batch item, as in the 2-d classifier case.
        self.assertEqual([batch_size], per_example_loss_data.shape.as_list())
    def test_mismatched_weights_and_labels_ranks_fail(self):
        """Test that the loss raises when rank(weights) != rank(labels).

        Labels are [batch_size, 10] while weights are [batch_size]. Both
        `per_example_loss` and `loss` must reject this combination with a
        RuntimeError whose message mentions "of the same rank".
        """
        batch_size = 3
        output_data = np.random.random_sample((batch_size, 10, 15))
        labels = np.random.randint(10, size=(batch_size, 10))
        # Deliberately one rank lower than `labels`.
        weights = np.random.randint(2, size=(batch_size,))

        with self.assertRaisesRegex(RuntimeError, ".*of the same rank.*"):
            _ = weighted_sparse_categorical_crossentropy.per_example_loss(
                predictions=output_data, labels=labels, weights=weights)
        with self.assertRaisesRegex(RuntimeError, ".*of the same rank.*"):
            _ = weighted_sparse_categorical_crossentropy.loss(
                predictions=output_data, labels=labels, weights=weights)
  def test_tf_tensor_inputs(self):
    """Both loss entry points must accept tf.Tensor arguments.

    Numerical correctness is not checked here. The test only ensures that
    passing tensors (rather than NumPy arrays) does not trip a runtime error
    in the shape-checking code.
    """
    batch_size = 3
    output_data = tf.convert_to_tensor(
        np.random.random_sample((batch_size, 10, 15)))
    labels = tf.convert_to_tensor(np.random.randint(10, size=(batch_size, 10)))
    weights = tf.convert_to_tensor(np.random.randint(2, size=(batch_size, 10)))

    loss_kwargs = dict(predictions=output_data, labels=labels, weights=weights)
    for loss_fn in (weighted_sparse_categorical_crossentropy.per_example_loss,
                    weighted_sparse_categorical_crossentropy.loss):
      loss_fn(**loss_kwargs)
    def test_per_example_loss_2d_input(self):
        """Per-example loss on classifier output yields one value per example.

        Inputs are random, so the resulting losses are expected to be
        non-zero.
        """
        input_width, num_classes = 512, 10
        model = self.create_classification_model(input_width, num_classes)

        # Run random features through the classifier to get predictions.
        batch_size = 3
        features = np.random.random_sample((batch_size, input_width)) * 10
        output_data = model.predict(features)

        labels = np.random.randint(num_classes, size=batch_size)
        per_example_loss_data = (
            weighted_sparse_categorical_crossentropy.per_example_loss(
                predictions=output_data, labels=labels))

        # A single loss entry per batch item, and not all of them zero.
        self.assertEqual([batch_size], per_example_loss_data.shape.as_list())
        self.assertNotAllClose(tf.zeros_like(per_example_loss_data),
                               per_example_loss_data)
  def test_legacy_classification_loss_compatibility(self):
    """Pin classifier loss numerics to previously recorded values.

    The logits were captured from a classifier run with batch_size=2 and
    num_classes=3. Any refactor of the loss must reproduce the values below.
    """
    output_data = np.array([
        [-1.6094601e-03, -1.0966038e+01, -6.4434357e+00],
        [-1.6975292e-03, -6.4009643e+00, -1.0226612e+01],
    ])
    labels = np.array([2, 1])

    # Expected per-example values are the negated logits at the labels.
    self.assertAllClose(
        [6.4434357, 6.4009643],
        weighted_sparse_categorical_crossentropy.per_example_loss(
            predictions=output_data, labels=labels))

    # Unweighted overall loss: 6.4222 is the mean of the two values above.
    self.assertAllClose(
        6.4222,
        weighted_sparse_categorical_crossentropy.loss(
            predictions=output_data, labels=labels, weights=None))
    def test_per_example_loss_weights_2d_input(self):
        """Test weighted per-example loss with a 2-d input, from a classifier.

        The weighted per-example loss must equal the unweighted per-example
        loss multiplied elementwise by the weights.
        """
        input_width = 512
        num_classes = 10
        model = self.create_classification_model(input_width, num_classes)

        # Invoke the network as part of a Model.
        batch_size = 3
        input_data = 10 * np.random.random_sample((batch_size, input_width))
        output_data = model.predict(input_data)

        labels = np.random.randint(num_classes, size=(batch_size,))
        weights = np.random.randint(2, size=(batch_size,))
        # With only 3 entries an all-ones draw is likely (1 in 8). Guarantee
        # both a masked (0) and an unmasked (1) entry so weighting is always
        # exercised.
        weights.flat[0] = 0
        weights.flat[-1] = 1

        unweighted_loss_data = (
            weighted_sparse_categorical_crossentropy.per_example_loss(
                predictions=output_data, labels=labels))
        per_example_loss_data = (
            weighted_sparse_categorical_crossentropy.per_example_loss(
                predictions=output_data, labels=labels, weights=weights))

        # Compare against an independently computed reference. Multiplying
        # the *weighted* result by 0/1 weights again would be a no-op and
        # could not detect a broken weighting implementation.
        expected_weighted_loss = unweighted_loss_data * tf.cast(
            weights, unweighted_loss_data.dtype)
        self.assertAllClose(expected_weighted_loss, per_example_loss_data)