Ejemplo n.º 1
0
 def test_ignore_index_label_smoothing_cross_entropy(self):
     """Check the loss for a sample whose only target equals ``ignore_index``.

     The single class-index target (-1) matches the configured
     ``ignore_index``, yet the asserted loss is non-zero. The expected value is
     consistent with cross entropy of logits [0, 7] against a uniform
     [0.5, 0.5] target: log(1 + e^7) - 3.5 ~= 3.500911. Presumably the ignored
     target is zeroed and smoothing then spreads the mass evenly over both
     classes -- confirm against the loss implementation.
     """
     config = {
         "name": "label_smoothing_cross_entropy",
         "ignore_index": -1,
         "smoothing_param": 0.2,
     }
     crit = LabelSmoothingCrossEntropyLoss.from_config(config)
     outputs = torch.tensor([[0.0, 7.0]])
     targets = torch.tensor([[-1]])
     # `outputs` is a default-dtype (float32) tensor. The default places=7
     # tolerance (5e-8) is finer than one float32 ULP near 3.5 (~2.4e-7), so
     # the check would be bit-exact and brittle across kernels/platforms.
     # places=5 matches the unnormalized-target test.
     self.assertAlmostEqual(crit(outputs, targets).item(), 3.50090909, places=5)
Ejemplo n.º 2
0
 def test_class_integer_label_smoothing_cross_entropy(self):
     """Check the loss when targets are class indices of shape (N, 1).

     Two samples over 2 classes, with targets [[0], [1]] and
     smoothing_param=0.2. The expected value is consistent with the mean cross
     entropy of the logits against [11/12, 1/12] and [1/12, 11/12]. That is,
     0.2 / 2 = 0.1 is added to each one-hot entry and each row is
     renormalized; presumably this is how the loss smooths -- confirm against
     the loss implementation.
     """
     config = {
         "name": "label_smoothing_cross_entropy",
         "ignore_index": -1,
         "smoothing_param": 0.2,
     }
     crit = LabelSmoothingCrossEntropyLoss.from_config(config)
     outputs = torch.tensor([[1.0, 2.0], [0.0, 2.0]])
     targets = torch.tensor([[0], [1]])
     # `outputs` is a default-dtype (float32) tensor. The default places=7
     # tolerance (5e-8) is below one float32 ULP near 0.76 (~6e-8), so the
     # check would be bit-exact and brittle. places=5 matches the
     # unnormalized-target test.
     self.assertAlmostEqual(crit(outputs, targets).item(), 0.76176142, places=5)
Ejemplo n.º 3
0
 def test_unnormalized_label_smoothing_cross_entropy(self):
     """Check the loss when the target is a row with one entry per class.

     The target [[0, 0, 0, 0, 1]] has the same shape as the logits (1 x 5), and
     smoothing_param=0.5. The expected value is consistent with cross entropy
     of the logits against [1, 1, 1, 1, 11] / 15. That is, 0.5 / 5 = 0.1 is
     added per class and the row is renormalized; presumably this is how the
     loss smooths -- confirm against the loss implementation.
     """
     loss_fn = LabelSmoothingCrossEntropyLoss.from_config(
         {
             "name": "label_smoothing_cross_entropy",
             "ignore_index": -1,
             "smoothing_param": 0.5,
         }
     )
     logits = torch.tensor([[0.0, 7.0, 0.0, 0.0, 2.0]])
     target_row = torch.tensor([[0, 0, 0, 0, 1]])
     loss_value = loss_fn(logits, target_row).item()
     self.assertAlmostEqual(loss_value, 5.07609558, places=5)