def test_with_dynamic_ranks(self, gamma, from_logits):
    # y_true must have defined rank
    y_true = tf.keras.backend.placeholder(None, dtype=tf.int64)
    y_pred = tf.keras.backend.placeholder((None, 2), dtype=tf.float32)
    with self.assertRaises(NotImplementedError):
        sparse_categorical_focal_loss(y_true, y_pred, gamma=gamma,
                                      from_logits=from_logits)

    # If axis is specified, y_pred must have a defined rank
    y_true = tf.keras.backend.placeholder((None,), dtype=tf.int64)
    y_pred = tf.keras.backend.placeholder(None, dtype=tf.float32)
    with self.assertRaises(ValueError):
        sparse_categorical_focal_loss(y_true, y_pred, gamma=gamma,
                                      from_logits=from_logits, axis=0)

    # It's fine if y_pred has undefined rank if axis=-1
    graph = tf.Graph()
    with graph.as_default():
        y_true = tf.keras.backend.placeholder((None,), dtype=tf.int64)
        y_pred = tf.keras.backend.placeholder(None, dtype=tf.float32)
        focal_loss = sparse_categorical_focal_loss(y_true, y_pred,
                                                   gamma=gamma,
                                                   from_logits=from_logits)

        labels = [0, 0, 1]
        logits = [[10., 0.], [5., -5.], [0., 10.]]
        probs = softmax(logits, axis=-1)
        pred = logits if from_logits else probs
        loss_numpy = numpy_sparse_categorical_focal_loss(
            labels, pred, gamma=gamma, from_logits=from_logits)

        with tf.compat.v1.Session(graph=graph) as sess:
            loss = sess.run(focal_loss,
                            feed_dict={y_true: labels, y_pred: pred})

    self.assertAllClose(loss, loss_numpy)
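# Background note on the rank checks above (illustrative, assuming TF 2.x
# with compat.v1 graph mode): a placeholder created with shape=None has an
# unknown *rank*, not merely unknown dimensions:
#
#     x = tf.keras.backend.placeholder(None)       # x.shape == TensorShape(None)
#     y = tf.keras.backend.placeholder((None, 2))  # y.shape == TensorShape([None, 2])
#     x.shape.rank  # None (undefined rank)
#     y.shape.rank  # 2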
def test_reduce_to_multiclass_crossentropy_from_probabilities(
        self, y_true, y_pred):
    """Focal loss with gamma=0 should be the same as cross-entropy."""
    focal_loss = sparse_categorical_focal_loss(y_true=y_true, y_pred=y_pred,
                                               gamma=0)
    ce = tf.keras.losses.sparse_categorical_crossentropy(y_true=y_true,
                                                         y_pred=y_pred)
    self.assertAllClose(focal_loss, ce)
def test_reduce_to_multiclass_crossentropy_from_logits(
        self, y_true, y_pred):
    """Focal loss with gamma=0 should be the same as cross-entropy."""
    focal_loss = sparse_categorical_focal_loss(y_true=y_true, y_pred=y_pred,
                                               gamma=0, from_logits=True)
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.dtypes.cast(y_true, dtype=tf.dtypes.int64),
        logits=tf.dtypes.cast(y_pred, dtype=tf.dtypes.float32),
    )
    self.assertAllClose(focal_loss, ce)
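# Why the two reduction tests above hold (per Lin et al., "Focal Loss for
# Dense Object Detection"): the focal loss for a true class y with predicted
# probability p_y is
#
#     FL(p_y) = -(1 - p_y) ** gamma * log(p_y)
#
# so gamma=0 leaves the ordinary cross-entropy -log(p_y), which is exactly
# what sparse_categorical_crossentropy (and, for logits,
# sparse_softmax_cross_entropy_with_logits) computes.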
def test_class_weight(self, y_true, y_pred, gamma):
    rng = np.random.default_rng(0)
    for _ in range(10):
        class_weight = rng.uniform(size=np.shape(y_pred)[-1])
        loss_without_weight = sparse_categorical_focal_loss(
            y_true=y_true,
            y_pred=y_pred,
            gamma=gamma,
        )
        loss_with_weight = sparse_categorical_focal_loss(
            y_true=y_true,
            y_pred=y_pred,
            gamma=gamma,
            class_weight=class_weight,
        )

        # Apply class weights to loss computed without class_weight
        loss_without_weight = loss_without_weight.numpy()
        loss_without_weight *= np.take(class_weight, y_true)
        self.assertAllClose(loss_with_weight, loss_without_weight)
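# Note on the check above: with class weights, the loss should scale to
#
#     FL(p_y) = -w_y * (1 - p_y) ** gamma * log(p_y),
#
# i.e., each example's loss is multiplied by the weight of its true class.
# np.take(class_weight, y_true) gathers exactly those per-example weights;
# a minimal sketch:
#
#     np.take([0.5, 2.0], [0, 1, 1])  # -> array([0.5, 2. , 2. ])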
def test_computation_sanity_checks(self, y_true, y_pred_logits, y_pred_prob,
                                   gamma):
    """Make sure the focal loss computation behaves as expected."""
    focal_loss_prob = sparse_categorical_focal_loss(
        y_true=y_true,
        y_pred=y_pred_prob,
        gamma=gamma,
        from_logits=False,
    )
    focal_loss_logits = sparse_categorical_focal_loss(
        y_true=y_true,
        y_pred=y_pred_logits,
        gamma=gamma,
        from_logits=True,
    )
    losses = [focal_loss_prob, focal_loss_logits]

    # Only compare against the NumPy reference implementation when the
    # inputs are not tensors
    if not (isinstance(y_true, tf.Tensor)
            or isinstance(y_pred_logits, tf.Tensor)):
        numpy_focal_loss_logits = numpy_sparse_categorical_focal_loss(
            y_true=y_true,
            y_pred=y_pred_logits,
            gamma=gamma,
            from_logits=True,
        )
        losses.append(numpy_focal_loss_logits)

    if not (isinstance(y_true, tf.Tensor)
            or isinstance(y_pred_prob, tf.Tensor)):
        numpy_focal_loss_prob = numpy_sparse_categorical_focal_loss(
            y_true=y_true,
            y_pred=y_pred_prob,
            gamma=gamma,
            from_logits=False,
        )
        losses.append(numpy_focal_loss_prob)

    # All computed losses should agree pairwise
    for i, loss_1 in enumerate(losses):
        for loss_2 in losses[(i + 1):]:
            self.assertAllClose(loss_1, loss_2, atol=1e-5, rtol=1e-5)
def test_higher_rank_sanity_checks(self, gamma, axis, from_logits):
    labels = tf.convert_to_tensor([[0, 1, 2], [0, 0, 0], [1, 1, 1]],
                                  dtype=tf.dtypes.int64)
    logits = tf.reshape(tf.range(27, dtype=tf.dtypes.float32),
                        shape=[3, 3, 3])
    probs = tf.nn.softmax(logits, axis=axis)
    y_pred = logits if from_logits else probs
    numpy_loss = numpy_sparse_categorical_focal_loss(
        labels, y_pred, gamma=gamma, from_logits=from_logits, axis=axis)
    focal_loss = sparse_categorical_focal_loss(labels, y_pred, gamma=gamma,
                                               from_logits=from_logits,
                                               axis=axis)
    self.assertAllClose(focal_loss, numpy_loss)
def test_reduce_to_keras_with_higher_rank_and_axis(self, axis, from_logits):
    labels = tf.convert_to_tensor([[0, 1, 2], [0, 0, 0], [1, 1, 1]],
                                  dtype=tf.dtypes.int64)
    logits = tf.reshape(tf.range(27, dtype=tf.dtypes.float32),
                        shape=[3, 3, 3])
    probs = tf.nn.softmax(logits, axis=axis)
    y_pred = logits if from_logits else probs
    keras_loss = tf.keras.losses.sparse_categorical_crossentropy(
        labels, y_pred, from_logits=from_logits, axis=axis)
    focal_loss = sparse_categorical_focal_loss(labels, y_pred, gamma=0,
                                               from_logits=from_logits,
                                               axis=axis)
    self.assertAllClose(focal_loss, keras_loss)
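# Note on the axis tests above: `axis` selects which dimension of y_pred
# holds the class scores, and the returned loss keeps the remaining
# dimensions. A sketch under the shapes used above (illustrative only):
#
#     labels = tf.constant([[0, 1], [1, 0]])   # shape [2, 2]
#     logits = tf.random.normal([2, 3, 2])     # 3 classes along axis=1
#     loss = sparse_categorical_focal_loss(labels, logits, gamma=2,
#                                          from_logits=True, axis=1)
#     loss.shape  # [2, 2]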
def test_train_dummy_multiclass_classifier(self, n_examples, n_features,
                                           n_classes, epochs, gamma,
                                           from_logits, random_state):
    # Generate some fake data
    x = random_state.binomial(n=n_classes, p=0.5,
                              size=(n_examples, n_features))
    x = 2.0 * x / n_classes - 1.0
    weights = 100.0 * np.ones(shape=(n_features, n_classes))
    y = np.argmax(x.dot(weights), axis=-1)

    model = get_dummy_sparse_multiclass_classifier(n_features=n_features,
                                                   n_classes=n_classes,
                                                   gamma=gamma,
                                                   from_logits=from_logits)
    history = model.fit(x, y, batch_size=n_examples, epochs=epochs,
                        callbacks=[tf.keras.callbacks.TerminateOnNaN()])

    # Check that we didn't stop early: if we did, then we encountered NaNs
    # during training, and that shouldn't happen
    self.assertEqual(len(history.history['loss']), epochs)

    # Check that SparseCategoricalFocalLoss and
    # sparse_categorical_focal_loss agree (at least when averaged)
    model_loss, *_ = model.evaluate(x, y)

    y_pred = model.predict(x)
    loss = sparse_categorical_focal_loss(y_true=y, y_pred=y_pred,
                                         gamma=gamma,
                                         from_logits=from_logits)
    loss = tf.math.reduce_mean(loss)
    self.assertAllClose(loss, model_loss)
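# For reference, a helper like get_dummy_sparse_multiclass_classifier (a
# hypothetical sketch -- the actual helper is defined elsewhere in the test
# suite) might build a single-layer model compiled with the focal loss:
#
#     def get_dummy_sparse_multiclass_classifier(n_features, n_classes,
#                                                gamma, from_logits):
#         activation = None if from_logits else 'softmax'
#         model = tf.keras.Sequential([
#             tf.keras.layers.Dense(n_classes, input_shape=(n_features,),
#                                   activation=activation),
#         ])
#         model.compile(
#             optimizer='adam',
#             loss=SparseCategoricalFocalLoss(gamma=gamma,
#                                             from_logits=from_logits),
#             metrics=['accuracy'],
#         )
#         return model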