def test_ce_loss_no_weights_targets_zero(self):
    """Checks loss and normalizing factor when no weights are given and some targets are 0.

    Both batch rows use the same logits; the targets are [1, 0] in each row.
    """
    logits = np.array([
        [[-0.1, 0.9], [0.1, 0.9]],
        [[-0.1, 0.9], [0.1, 0.9]],
    ])
    targets = np.array([
        [1, 0],
        [1, 0],
    ])
    loss, factor = utils.compute_weighted_cross_entropy(logits, targets)
    # NOTE(review): computed by hand, the per-token cross-entropies summed over
    # all four positions are ~2.968724, and 5.937449 == that sum * factor (2).
    # Zero targets therefore appear to count toward the loss but not toward the
    # factor. Confirm this is the intended contract of
    # compute_weighted_cross_entropy.
    # The loss is a floating-point reduction whose last bits depend on backend
    # and summation order, so compare with a tolerance, not exact equality.
    np.testing.assert_allclose(loss, 5.937449, rtol=1e-6)
    self.assertEqual(factor, 2)
def test_ce_loss_no_weights(self):
    """Checks loss and normalizing factor when no weights are given and all targets are 1."""
    logits = np.array([
        [[-0.1, 0.9], [0.1, 0.9]],
        [[-0.1, 0.9], [0.1, 0.9]],
    ])
    targets = np.array([
        [1, 1],
        [1, 1],
    ])
    loss, factor = utils.compute_weighted_cross_entropy(logits, targets)
    # NOTE(review): 5.4748983 is 4x the loss that test_ce_loss expects
    # (1.3687246) for the same logits/targets with all-ones weights, i.e. the
    # summed loss multiplied by the factor. Confirm that omitting `weights` is
    # meant to behave differently from passing all-ones weights.
    # Tolerance-based compare: exact float equality is backend-dependent.
    np.testing.assert_allclose(loss, 5.4748983, rtol=1e-6)
    self.assertEqual(factor, 4)
def test_ce_mismatched_shapes_raises_error(self):
    """A (1,)-shaped targets array against (2, 2, 2) logits must raise ValueError.

    The error message must match '(?i)Incorrect shapes.*targets'
    (case-insensitive).
    """
    batch_row = [[-0.1, 0.9], [0.1, 0.9]]
    logits = np.array([batch_row, batch_row])
    weights = np.array([[1, 0]] * 2)
    targets = np.array([1])  # deliberately the wrong rank for these logits
    with self.assertRaisesRegex(ValueError, '(?i)Incorrect shapes.*targets'):
        unused_loss, unused_factor = utils.compute_weighted_cross_entropy(
            logits, targets, weights)
def test_ce_loss_some_weights_zero(self):
    """Checks that zero-weighted positions are excluded from loss and factor.

    Only the first position of each row has weight 1; the second has weight 0.
    """
    logits = np.array([
        [[-0.1, 0.9], [0.1, 0.9]],
        [[-0.1, 0.9], [0.1, 0.9]],
    ])
    targets = np.array([
        [1, 1],
        [1, 1],
    ])
    weights = np.array([
        [1, 0],
        [1, 0],
    ])
    loss, factor = utils.compute_weighted_cross_entropy(
        logits, targets, weights)
    # By hand: CE of target 1 under logits [-0.1, 0.9] is ~0.313262. Two such
    # weighted positions give ~0.626524. The expected factor equals sum(weights).
    # Tolerance-based compare: exact float equality is backend-dependent.
    np.testing.assert_allclose(loss, 0.62652326, rtol=1e-6)
    self.assertEqual(factor, 2)
def test_ce_loss(self):
    """Checks loss and normalizing factor with all-ones float weights."""
    logits = np.array([
        [[-0.1, 0.9], [0.1, 0.9]],
        [[-0.1, 0.9], [0.1, 0.9]],
    ])
    targets = np.array([
        [1, 1],
        [1, 1],
    ])
    weights = np.array([
        [1., 1.],
        [1., 1.],
    ])
    loss, factor = utils.compute_weighted_cross_entropy(
        logits, targets, weights)
    # By hand: per-row CE is ~0.313262 + ~0.371100, and two rows give ~1.368724,
    # i.e. the plain sum of per-token cross-entropies. The expected factor
    # equals sum(weights).
    # Tolerance-based compare: exact float equality is backend-dependent.
    np.testing.assert_allclose(loss, 1.3687246, rtol=1e-6)
    self.assertEqual(factor, 4)