Esempio n. 1
0
 def test_ignore_row(self):
     """A sample with no valid targets must be left out of the reduction.

     The second row consists only of -1 (the value this test treats as
     invalid), so the batch loss is expected to equal ``self._get_loss()``
     -- presumably the loss of the first row on its own.
     """
     criterion = SoftTargetCrossEntropyLoss.from_config(self._get_config())
     logits = torch.tensor(
         [[1.0, 7.0, 0.0, 0.0, 2.0], [4.0, 2.0, 1.0, 6.0, 0.5]]
     )
     # Second row: every entry carries the ignore value.
     labels = torch.tensor([[1, 0, 0, 0, 1], [-1] * 5])
     actual = criterion(logits, labels).item()
     self.assertAlmostEqual(actual, self._get_loss())
Esempio n. 2
0
    def test_soft_target_cross_entropy(self):
        """Check the loss on the reference inputs, then with an ignored entry.

        The second check replaces the first target with -1 (the ignore
        index) and compares against the hard-coded value 5.01097918.
        """
        loss_fn = SoftTargetCrossEntropyLoss.from_config(self._get_config())

        # Reference outputs/targets must reproduce the reference loss.
        logits, labels = self._get_outputs(), self._get_targets()
        self.assertAlmostEqual(loss_fn(logits, labels).item(), self._get_loss())

        # Masking the first entry with the ignore index.
        logits = self._get_outputs()
        masked_labels = torch.tensor([[-1, 0, 0, 0, 1]])
        masked_loss = loss_fn(logits, masked_labels)
        self.assertAlmostEqual(masked_loss.item(), 5.01097918)
Esempio n. 3
0
    def test_soft_target_cross_entropy_none_reduction(self):
        """With reduction="none", the loss holds one value per sample.

        Besides the element count, each per-sample value is compared with
        the same criterion in "mean" mode applied to that sample alone.
        Assuming "mean" averages over samples, a batch of one (with at
        least one valid target) reduces to exactly that sample's loss.
        """
        config = self._get_config()
        config["reduction"] = "none"

        crit = SoftTargetCrossEntropyLoss.from_config(config)

        outputs = torch.tensor([[1.0, 7.0, 0.0, 0.0, 2.0],
                                [4.0, 2.0, 1.0, 6.0, 0.5]])
        targets = torch.tensor([[1, 0, 0, 0, 1], [0, 1, 0, 1, 0]])
        loss = crit(outputs, targets)
        self.assertEqual(loss.numel(), outputs.size(0))

        # Build the reference criterion from a fresh config, after `crit`
        # exists, so the "none" setting cannot leak into it even if the
        # helper happens to return a shared dict.
        mean_config = self._get_config()
        mean_config["reduction"] = "mean"
        mean_crit = SoftTargetCrossEntropyLoss.from_config(mean_config)

        for sample_loss, sample_out, sample_tgt in zip(loss, outputs, targets):
            # unsqueeze(0) restores the batch dimension: (5,) -> (1, 5).
            expected = mean_crit(sample_out.unsqueeze(0), sample_tgt.unsqueeze(0))
            # places=5: batched vs. single-row float32 math may differ in
            # the last bits.
            self.assertAlmostEqual(sample_loss.item(), expected.item(), places=5)
Esempio n. 4
0
    def test_unnormalized_soft_target_cross_entropy(self):
        """Check the loss when target vectors are used without normalization.

        The criterion is configured explicitly with ignore_index=-1,
        "mean" reduction and normalize_targets=False. Both expected values
        are hard-coded.
        """
        settings = dict(
            name="soft_target_cross_entropy",
            ignore_index=-1,
            reduction="mean",
            normalize_targets=False,
        )
        loss_fn = SoftTargetCrossEntropyLoss.from_config(settings)

        raw = loss_fn(self._get_outputs(), self._get_targets())
        self.assertAlmostEqual(raw.item(), 11.0219593)

        # The -1 entry matches the configured ignore_index.
        masked = loss_fn(self._get_outputs(), torch.tensor([[-1, 0, 0, 0, 1]]))
        self.assertAlmostEqual(masked.item(), 5.01097965)
Esempio n. 5
0
 def test_soft_target_cross_entropy_integer_label(self):
     """Targets may be given as a 1-D tensor of class indices.

     A single sample labelled with class 4 is compared against the
     hard-coded loss 5.01097918.
     """
     criterion = SoftTargetCrossEntropyLoss.from_config(self._get_config())
     logits = self._get_outputs()
     value = criterion(logits, torch.tensor([4])).item()
     self.assertAlmostEqual(value, 5.01097918)