Exemplo n.º 1
0
 def test_ce_loss_no_weights_targets_zero(self):
     """Unweighted cross-entropy where half of the targets are class 0.

     The loss is compared with a relative tolerance rather than exact
     equality: exact float equality depends on the dtype the implementation
     computes in and on how its result compares against a Python float.
     """
     logits = np.array([
         [[-0.1, 0.9], [0.1, 0.9]],
         [[-0.1, 0.9], [0.1, 0.9]],
     ])
     targets = np.array([
         [1, 0],
         [1, 0],
     ])
     loss, factor = utils.compute_weighted_cross_entropy(logits, targets)
     # Reference value looks like a float32 repr — allow rounding slack.
     np.testing.assert_allclose(loss, 5.937449, rtol=1e-6)
     self.assertEqual(factor, 2)
Exemplo n.º 2
0
 def test_ce_loss_no_weights(self):
     """Unweighted cross-entropy where every target is class 1.

     The loss is compared with a relative tolerance rather than exact
     equality: exact float equality depends on the dtype the implementation
     computes in and on how its result compares against a Python float.
     """
     logits = np.array([
         [[-0.1, 0.9], [0.1, 0.9]],
         [[-0.1, 0.9], [0.1, 0.9]],
     ])
     targets = np.array([
         [1, 1],
         [1, 1],
     ])
     loss, factor = utils.compute_weighted_cross_entropy(logits, targets)
     # Reference value looks like a float32 repr — allow rounding slack.
     np.testing.assert_allclose(loss, 5.4748983, rtol=1e-6)
     self.assertEqual(factor, 4)
Exemplo n.º 3
0
 def test_ce_mismatched_shapes_raises_error(self):
     """Shape-incompatible logits/targets must raise a ValueError.

     Logits are (2, 2, 2) while targets are (1,). The error message is
     expected to mention incorrect shapes and targets (case-insensitive).
     """
     example = [[-0.1, 0.9], [0.1, 0.9]]
     logits = np.array([example, example])
     weights = np.array([[1, 0], [1, 0]])
     targets = np.array([1])  # deliberately incompatible with logits
     expected_pattern = '(?i)Incorrect shapes.*targets'
     with self.assertRaisesRegex(ValueError, expected_pattern):
         utils.compute_weighted_cross_entropy(logits, targets, weights)
Exemplo n.º 4
0
 def test_ce_loss_some_weights_zero(self):
     """Weighted cross-entropy where the second position has weight 0.

     The loss is compared with a relative tolerance rather than exact
     equality: exact float equality depends on the dtype the implementation
     computes in and on how its result compares against a Python float.
     """
     logits = np.array([
         [[-0.1, 0.9], [0.1, 0.9]],
         [[-0.1, 0.9], [0.1, 0.9]],
     ])
     targets = np.array([
         [1, 1],
         [1, 1],
     ])
     weights = np.array([
         [1, 0],
         [1, 0],
     ])
     loss, factor = utils.compute_weighted_cross_entropy(
         logits, targets, weights)
     # Reference value looks like a float32 repr — allow rounding slack.
     np.testing.assert_allclose(loss, 0.62652326, rtol=1e-6)
     self.assertEqual(factor, 2)
Exemplo n.º 5
0
 def test_ce_loss(self):
     """Weighted cross-entropy with all weights equal to 1.0.

     The loss is compared with a relative tolerance rather than exact
     equality: exact float equality depends on the dtype the implementation
     computes in and on how its result compares against a Python float.
     """
     logits = np.array([
         [[-0.1, 0.9], [0.1, 0.9]],
         [[-0.1, 0.9], [0.1, 0.9]],
     ])
     targets = np.array([
         [1, 1],
         [1, 1],
     ])
     weights = np.array([
         [1., 1.],
         [1., 1.],
     ])
     loss, factor = utils.compute_weighted_cross_entropy(
         logits, targets, weights)
     # Reference value looks like a float32 repr — allow rounding slack.
     np.testing.assert_allclose(loss, 1.3687246, rtol=1e-6)
     self.assertEqual(factor, 4)