def test_unequal_weights_reduction_2d(self):
    """Testing unequal weights reductions for 2D data

    The reference loss is built element-wise:
    - Take the sparse categorical cross-entropy, weighted per element by the
      class weight of that element's label.
    - Multiply it by ``(1 - p_t) ** gamma``, where ``p_t`` is the predicted
      probability of the true class.

    The focal loss is checked under both SUM and NONE reductions.
    """
    labels = [[1, 2], [0, 2]]
    probs = [[[0.05, 0.95, 0], [0.1, 0.8, 0.1]],
             [[0.1, 0.2, 0.7], [0.3, 0.5, 0.2]]]
    # One weight per label, looked up from the class-weight table.
    per_label_weight = np.take(self.class_weights, labels)
    weighted_ce = self.scce(
        labels, probs, sample_weight=per_label_weight).numpy()

    # SUM reduction: compare against the summed reference.
    self.sfl_unequal = BalancedSparseFocalLoss(
        self.class_weights, gamma=self.gamma,
        reduction=losses.Reduction.SUM)
    # Assumes get_one_hot(labels) is a one-hot mask over the last axis, so
    # this picks out the probability assigned to each true class.
    true_class_prob = np.where(get_one_hot(labels), probs, 0).sum(axis=-1)
    expected_per_element = weighted_ce * (1 - true_class_prob) ** self.gamma
    np.testing.assert_allclose(
        expected_per_element.sum(),
        self.sfl_unequal(labels, probs).numpy(),
        rtol=1e-6)

    # NONE reduction: compare element-wise against the unreduced reference.
    self.sfl_unequal = BalancedSparseFocalLoss(
        self.class_weights, gamma=self.gamma,
        reduction=losses.Reduction.NONE)
    np.testing.assert_allclose(
        expected_per_element,
        self.sfl_unequal(labels, probs).numpy(),
        rtol=1e-5)
def setUp(self):
    """Setup shared by all tests

    Creates the following fixtures:
    - ``scce``: an unreduced (per-element) sparse categorical cross-entropy
      used as the reference.
    - ``gamma``: the focusing exponent.
    - ``sfl_equal``: a focal loss with uniform class weights.
    - ``sfl_unequal``: a focal loss using ``class_weights``.
    """
    gamma = 4.0
    weights = [0.2, 0.3, 0.5]
    self.scce = losses.SparseCategoricalCrossentropy(
        reduction=losses.Reduction.NONE)
    self.gamma = gamma
    self.sfl_equal = BalancedSparseFocalLoss([1, 1, 1], gamma=gamma)
    self.class_weights = weights
    self.sfl_unequal = BalancedSparseFocalLoss(weights, gamma=gamma)
def test_equal_weights_reduction_1d(self):
    """Testing equal weights reductions for 1D data

    With uniform class weights, the reference loss is the per-sample
    sparse categorical cross-entropy scaled by ``(1 - p_t) ** gamma``.
    It is checked under both SUM and NONE reductions.
    """
    labels = [1, 2]
    probs = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
    ce = self.scce(labels, probs).numpy()

    # SUM reduction: compare against the summed reference.
    self.sfl_equal = BalancedSparseFocalLoss(
        [1, 1, 1], gamma=self.gamma, reduction=losses.Reduction.SUM)
    # Assumes get_one_hot(labels) is a one-hot mask over the class axis.
    true_class_prob = np.where(get_one_hot(labels), probs, 0).sum(axis=-1)
    expected_per_sample = ce * (1 - true_class_prob) ** self.gamma
    np.testing.assert_allclose(
        expected_per_sample.sum(),
        self.sfl_equal(labels, probs).numpy(),
        rtol=1e-6)

    # NONE reduction: compare per sample against the unreduced reference.
    self.sfl_equal = BalancedSparseFocalLoss(
        [1, 1, 1], gamma=self.gamma, reduction=losses.Reduction.NONE)
    np.testing.assert_allclose(
        expected_per_sample,
        self.sfl_equal(labels, probs).numpy(),
        rtol=1e-5)
def test_equal_weights_logits_1d(self):
    """Testing equal weights logits for 1D data

    Inputs are raw logits. The reference is the from-logits cross-entropy
    scaled by ``(1 - p_t) ** gamma`` and then averaged. That average is
    expected to match the focal loss under its default reduction.
    """
    labels = [1, 2]
    logits = [[-0.05, 0.3, 0.19], [0.2, -0.4, 0.12]]
    self.scce = losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=losses.Reduction.NONE)
    ce = self.scce(labels, logits).numpy()
    self.sfl_equal = BalancedSparseFocalLoss(
        [1, 1, 1], gamma=self.gamma, from_logits=True)
    # NOTE(review): this assumes `softmax` normalises each row (last axis).
    # If it is scipy.special.softmax, its default axis=None normalises the
    # whole array instead. Confirm against the module imports.
    true_class_prob = np.where(
        get_one_hot(labels), softmax(logits), 0).sum(axis=-1)
    expected = (ce * (1 - true_class_prob) ** self.gamma).mean()
    actual = self.sfl_equal(labels, logits).numpy()
    np.testing.assert_allclose(expected, actual, rtol=1e-6)
def test_unequal_weights_logits_2d(self):
    """Testing unequal weights logits for 2D data

    Inputs are raw logits with 2D labels. The reference is built as follows:
    - Weight the from-logits cross-entropy per element by the class weight
      of that element's label.
    - Scale it by ``(1 - p_t) ** gamma``.
    - Average the result.

    That average is expected to match the focal loss under its default
    reduction.
    """
    labels = [[1, 2], [0, 2]]
    logits = [[[-0.05, 0.3, 0.19], [0.2, -0.4, 0.12]],
              [[-0.1, 0.22, -0.73], [0.23, -0.52, 0.2]]]
    per_label_weight = np.take(self.class_weights, labels)
    self.scce = losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=losses.Reduction.NONE)
    weighted_ce = self.scce(
        labels, logits, sample_weight=per_label_weight).numpy()
    self.sfl_unequal = BalancedSparseFocalLoss(
        self.class_weights, gamma=self.gamma, from_logits=True)
    # NOTE(review): this assumes `softmax` normalises along the last (class)
    # axis. scipy.special.softmax defaults to axis=None. Confirm which
    # softmax is imported.
    true_class_prob = np.where(
        get_one_hot(labels), softmax(logits), 0).sum(axis=-1)
    expected = (weighted_ce * (1 - true_class_prob) ** self.gamma).mean()
    actual = self.sfl_unequal(labels, logits).numpy()
    np.testing.assert_allclose(expected, actual, rtol=1e-6)