def test_with_dynamic_ranks(self, gamma, from_logits):
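        """Check graceful handling of tensors with dynamic (unknown) ranks."""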
        # y_true must have a defined rank
        y_true = tf.keras.backend.placeholder(None, dtype=tf.int64)
        y_pred = tf.keras.backend.placeholder((None, 2), dtype=tf.float32)
        with self.assertRaises(NotImplementedError):
            sparse_categorical_focal_loss(y_true,
                                          y_pred,
                                          gamma=gamma,
                                          from_logits=from_logits)

        # If axis is specified, y_pred must have a defined rank
        y_true = tf.keras.backend.placeholder((None, ), dtype=tf.int64)
        y_pred = tf.keras.backend.placeholder(None, dtype=tf.float32)
        with self.assertRaises(ValueError):
            sparse_categorical_focal_loss(y_true,
                                          y_pred,
                                          gamma=gamma,
                                          from_logits=from_logits,
                                          axis=0)

        # It's fine if y_pred has undefined rank when axis=-1 (the default)
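        # Build the loss in a dedicated graph so it can be evaluated below
        # with a TF1-style session and placeholder feeds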
        graph = tf.Graph()
        with graph.as_default():
            y_true = tf.keras.backend.placeholder((None, ), dtype=tf.int64)
            y_pred = tf.keras.backend.placeholder(None, dtype=tf.float32)
            focal_loss = sparse_categorical_focal_loss(y_true,
                                                       y_pred,
                                                       gamma=gamma,
                                                       from_logits=from_logits)

        labels = [0, 0, 1]
        logits = [[10., 0.], [5., -5.], [0., 10.]]
        probs = softmax(logits, axis=-1)

        pred = logits if from_logits else probs
        loss_numpy = numpy_sparse_categorical_focal_loss(
            labels, pred, gamma=gamma, from_logits=from_logits)

        with tf.compat.v1.Session(graph=graph) as sess:
            loss = sess.run(focal_loss,
                            feed_dict={
                                y_true: labels,
                                y_pred: pred
                            })

        self.assertAllClose(loss, loss_numpy)

    def test_reduce_to_multiclass_crossentropy_from_probabilities(
            self, y_true, y_pred):
        """Focal loss with gamma=0 should be the same as cross-entropy."""
        focal_loss = sparse_categorical_focal_loss(y_true=y_true,
                                                   y_pred=y_pred,
                                                   gamma=0)
        ce = tf.keras.losses.sparse_categorical_crossentropy(y_true=y_true,
                                                             y_pred=y_pred)
        self.assertAllClose(focal_loss, ce)

    def test_reduce_to_multiclass_crossentropy_from_logits(
            self, y_true, y_pred):
        """Focal loss with gamma=0 should be the same as cross-entropy."""
        focal_loss = sparse_categorical_focal_loss(y_true=y_true,
                                                   y_pred=y_pred,
                                                   gamma=0,
                                                   from_logits=True)
        ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.dtypes.cast(y_true, dtype=tf.dtypes.int64),
            logits=tf.dtypes.cast(y_pred, dtype=tf.dtypes.float32),
        )
        self.assertAllClose(focal_loss, ce)

    def test_class_weight(self, y_true, y_pred, gamma):
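        """Passing class_weight should match scaling the loss manually."""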
        rng = np.random.default_rng(0)
        for _ in range(10):
            class_weight = rng.uniform(size=np.shape(y_pred)[-1])

            loss_without_weight = sparse_categorical_focal_loss(
                y_true=y_true,
                y_pred=y_pred,
                gamma=gamma,
            )
            loss_with_weight = sparse_categorical_focal_loss(
                y_true=y_true,
                y_pred=y_pred,
                gamma=gamma,
                class_weight=class_weight,
            )

            # Apply class weights to loss computed without class_weight
            loss_without_weight = loss_without_weight.numpy()
            loss_without_weight *= np.take(class_weight, y_true)

            self.assertAllClose(loss_with_weight, loss_without_weight)

    def test_computation_sanity_checks(self, y_true, y_pred_logits,
                                       y_pred_prob, gamma):
        """Make sure the focal loss computation behaves as expected."""
        focal_loss_prob = sparse_categorical_focal_loss(
            y_true=y_true,
            y_pred=y_pred_prob,
            gamma=gamma,
            from_logits=False,
        )
        focal_loss_logits = sparse_categorical_focal_loss(
            y_true=y_true,
            y_pred=y_pred_logits,
            gamma=gamma,
            from_logits=True,
        )
        losses = [focal_loss_prob, focal_loss_logits]
        if not (isinstance(y_true, tf.Tensor)
                or isinstance(y_pred_logits, tf.Tensor)):
            numpy_focal_loss_logits = numpy_sparse_categorical_focal_loss(
                y_true=y_true,
                y_pred=y_pred_logits,
                gamma=gamma,
                from_logits=True,
            )
            losses.append(numpy_focal_loss_logits)
        if not (isinstance(y_true, tf.Tensor)
                or isinstance(y_pred_prob, tf.Tensor)):
            numpy_focal_loss_prob = numpy_sparse_categorical_focal_loss(
                y_true=y_true,
                y_pred=y_pred_prob,
                gamma=gamma,
                from_logits=False,
            )
            losses.append(numpy_focal_loss_prob)

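        # All pairwise combinations of the computed losses should agree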
        for i, loss_1 in enumerate(losses):
            for loss_2 in losses[(i + 1):]:
                self.assertAllClose(loss_1, loss_2, atol=1e-5, rtol=1e-5)

    def test_higher_rank_sanity_checks(self, gamma, axis, from_logits):
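        """The loss for rank-3 inputs should match the NumPy reference."""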
        labels = tf.convert_to_tensor([[0, 1, 2], [0, 0, 0], [1, 1, 1]],
                                      dtype=tf.dtypes.int64)
        logits = tf.reshape(tf.range(27, dtype=tf.dtypes.float32),
                            shape=[3, 3, 3])
        probs = tf.nn.softmax(logits, axis=axis)

        y_pred = logits if from_logits else probs
        numpy_loss = numpy_sparse_categorical_focal_loss(
            labels, y_pred, gamma=gamma, from_logits=from_logits, axis=axis)
        focal_loss = sparse_categorical_focal_loss(labels,
                                                   y_pred,
                                                   gamma=gamma,
                                                   from_logits=from_logits,
                                                   axis=axis)
        self.assertAllClose(focal_loss, numpy_loss)

    def test_reduce_to_keras_with_higher_rank_and_axis(self, axis,
                                                       from_logits):
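        """With gamma=0, the loss should reduce to the Keras cross-entropy
        for any axis.
        """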
        labels = tf.convert_to_tensor([[0, 1, 2], [0, 0, 0], [1, 1, 1]],
                                      dtype=tf.dtypes.int64)
        logits = tf.reshape(tf.range(27, dtype=tf.dtypes.float32),
                            shape=[3, 3, 3])
        probs = tf.nn.softmax(logits, axis=axis)

        y_pred = logits if from_logits else probs
        keras_loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, y_pred, from_logits=from_logits, axis=axis)
        focal_loss = sparse_categorical_focal_loss(labels,
                                                   y_pred,
                                                   gamma=0,
                                                   from_logits=from_logits,
                                                   axis=axis)
        self.assertAllClose(focal_loss, keras_loss)

    def test_train_dummy_multiclass_classifier(self, n_examples, n_features,
                                               n_classes, epochs, gamma,
                                               from_logits, random_state):
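        """Training on easy synthetic data should not produce NaNs, and the
        model's reported loss should match a direct computation.
        """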
        # Generate some fake data
        x = random_state.binomial(n=n_classes,
                                  p=0.5,
                                  size=(n_examples, n_features))
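        # Scale the binomial samples from [0, n_classes] to [-1, 1]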
        x = 2.0 * x / n_classes - 1.0
        weights = 100.0 * np.ones(shape=(n_features, n_classes))
        y = np.argmax(x.dot(weights), axis=-1)

        model = get_dummy_sparse_multiclass_classifier(n_features=n_features,
                                                       n_classes=n_classes,
                                                       gamma=gamma,
                                                       from_logits=from_logits)
        history = model.fit(x,
                            y,
                            batch_size=n_examples,
                            epochs=epochs,
                            callbacks=[tf.keras.callbacks.TerminateOnNaN()])

        # Check that we didn't stop early: if we did then we
        # encountered NaNs during training, and that shouldn't happen
        self.assertEqual(len(history.history['loss']), epochs)

        # Check that the model's reported loss and a direct call to
        # sparse_categorical_focal_loss agree (at least when averaged)
        model_loss, *_ = model.evaluate(x, y)

        y_pred = model.predict(x)
        loss = sparse_categorical_focal_loss(y_true=y,
                                             y_pred=y_pred,
                                             gamma=gamma,
                                             from_logits=from_logits)
        loss = tf.math.reduce_mean(loss)
        self.assertAllClose(loss, model_loss)