def test_one_hot(self, smooth_eps, K):
    """Compare ``losses.one_hot`` with a NumPy reference encoding.

    Draws 32 random integer labels in ``[0, K)`` and encodes them with
    ``losses.one_hot(labels, K, smooth_eps)``. The result is checked against
    ``np.eye(K)[labels] * (1 - smooth_eps) + smooth_eps / (K - 1)``.

    ``smooth_eps`` and ``K`` are supplied from outside this view (presumably a
    ``pytest.mark.parametrize`` decorator -- confirm). ``K == 1`` would divide
    by zero in the reference below.
    """
    n_samples = 32
    labels = torch.randint(K, (n_samples, ))
    label_array = labels.numpy()
    encoded = losses.one_hot(labels, K, smooth_eps)

    # Row i of the identity matrix is the hard one-hot vector for class i.
    hard_targets = np.eye(K)[label_array]
    # NOTE(review): this reference adds smooth_eps / (K - 1) to *every* entry,
    # including the true class. Each row therefore sums to
    # 1 + smooth_eps / (K - 1) rather than 1. Confirm this is the smoothing
    # losses.one_hot is meant to produce.
    reference = hard_targets * (1 - smooth_eps) + (smooth_eps / (K - 1))

    assert_allclose(encoded, reference)
def test_one_hot_eps_out_of_bounds(self, smooth_eps):
    """An out-of-range ``smooth_eps`` must make ``losses.one_hot`` raise ``AssertionError``.

    ``smooth_eps`` comes from outside this view (presumably a parametrize
    decorator supplying invalid values -- confirm).
    """
    # The labels are deliberately a valid rank-1 tensor with values in [0, 5),
    # so that smooth_eps is the only invalid argument. A (2, 2) tensor would
    # also trip the rank-one check exercised by test_one_hot_not_rank_one, and
    # the test would then pass even if the eps validation were missing.
    labels = torch.randint(5, (2, ))
    with pytest.raises(AssertionError):
        losses.one_hot(labels, 5, smooth_eps)
def test_one_hot_not_rank_one(self):
    """Labels that are not rank-1 must make ``losses.one_hot`` raise ``AssertionError``.

    A (2, 2) label tensor with values in ``[0, 5)`` is passed with 5 classes and
    a smoothing eps of 0 (presumably a valid eps -- confirm). The tensor's rank
    is therefore the argument under test.
    """
    with pytest.raises(AssertionError):
        matrix_labels = torch.randint(5, (2, 2))
        losses.one_hot(matrix_labels, 5, 0)