Beispiel #1
0
 def test_one_hot(self, smooth_eps, K):
     """Check ``losses.one_hot`` against a NumPy reference encoding.

     ``smooth_eps`` and ``K`` are presumably injected by a pytest
     parametrization that is not visible here. ``K`` must be greater than 1,
     otherwise the reference below divides by zero.
     """
     n_samples = 32
     labels = torch.randint(K, (n_samples, ))
     labels_as_np = labels.numpy()
     got = losses.one_hot(labels, K, smooth_eps)

     # Reference: rows of the identity matrix select the hot class, scaled
     # by (1 - eps), plus a uniform eps / (K - 1) added to every entry.
     # NOTE(review): the uniform term is also added to the hot entry, so each
     # reference row sums to 1 - eps + K * eps / (K - 1), not 1. Confirm this
     # matches the intended contract of losses.one_hot.
     off_value = smooth_eps / (K - 1)
     want = np.eye(K)[labels_as_np] * (1 - smooth_eps) + off_value
     assert_allclose(got, want)
Beispiel #2
0
 def test_one_hot_eps_out_of_bounds(self, smooth_eps):
     """An out-of-range smoothing epsilon must raise ``AssertionError``.

     ``smooth_eps`` is presumably supplied by a pytest parametrization of
     invalid values that is not visible here.

     The labels are deliberately a valid rank-1 tensor. A rank-2 tensor is
     already rejected for its shape (see the rank-one test). With one, the
     expected ``AssertionError`` could come from the shape check, and the
     test would pass even if epsilon were never validated.
     """
     num_classes = 5
     labels = torch.randint(num_classes, (2,))
     # Only the call under test sits inside the raises-context, so an
     # unexpected AssertionError while building inputs is not mistaken
     # for a pass.
     with pytest.raises(AssertionError):
         losses.one_hot(labels, num_classes, smooth_eps)
Beispiel #3
0
 def test_one_hot_not_rank_one(self):
     """A label tensor that is not rank 1 must raise ``AssertionError``.

     Smoothing is fixed at 0, so only the label shape can trigger the
     failure.
     """
     num_classes = 5
     with pytest.raises(AssertionError):
         labels_2d = torch.randint(num_classes, (2, 2))
         losses.one_hot(labels_2d, num_classes, 0)