Example #1
 def test_cross_entropy(self, smooth_eps, reduction):
     batch_size = 2
     K = 5
     labels = torch.randint(K, (batch_size, ))
     logits = torch.randn(batch_size, K)
     actual = losses.cross_entropy(logits, labels, smooth_eps, reduction)
     logits_np = logits.numpy()
     labels_np = labels.numpy()
     labels_np_one_hot = np.eye(K)[labels_np] * (1 - smooth_eps) \
         + (smooth_eps / (K - 1))
     # Compute log softmax of logits.
     log_probs_np = log_softmax(logits_np)
     loss_np = (-labels_np_one_hot * log_probs_np).sum(axis=-1)
     if reduction == "mean":
         expected = loss_np.mean()
     elif reduction == "sum":
         expected = loss_np.sum()  # loss_np is 1-D here, so this sums over the batch
     else:  # none
         expected = loss_np
     assert_allclose(actual, expected)
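Example #1 receives smooth_eps and reduction as test parameters, so the original test class presumably declares a pytest.mark.parametrize grid plus file-level imports (numpy as np, torch, assert_allclose from numpy.testing, and a row-wise log_softmax helper) that the snippet does not show; the import of the project's losses module is omitted here because its path is project-specific. The following is a hedged sketch of that wiring: the class name, the parameter values, and the log_softmax helper are assumptions, not the project's actual code.

 import pytest
 import torch
 import numpy as np
 from numpy.testing import assert_allclose

 def log_softmax(x, axis=-1):
     # Assumed row-wise log-softmax helper (numerically stable form).
     shifted = x - x.max(axis=axis, keepdims=True)
     return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

 class TestCrossEntropy:  # hypothetical class name; the snippets show only the methods
     # Hypothetical parameter grid; the original values are not visible in the snippet.
     @pytest.mark.parametrize("smooth_eps", [0.0, 0.1])
     @pytest.mark.parametrize("reduction", ["mean", "sum", "none"])
     def test_cross_entropy(self, smooth_eps, reduction):
         ...  # body as in Example #1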
Example #2
 def test_cross_entropy_labels_dim(self):
     K, batch_size = 5, 2
     with pytest.raises(AssertionError):
         labels = torch.randint(K, (batch_size, 2))
         logits = torch.randn(batch_size, K)
         losses.cross_entropy(logits, labels, reduction="average")
Example #3
 def test_cross_entropy_unsupported_reduction(self):
     K, batch_size = 50, 200
     with pytest.raises(AssertionError):
         labels = torch.randint(K, (batch_size, ))
         logits = torch.randn(batch_size, K)
         losses.cross_entropy(logits, labels, reduction="average")