def test_OneHotEncode_values():
    """Test that OneHotEncode() returns correct values for a known input.

    Each input index becomes one row containing a single 1 at that index.
    The vocabulary size (4) exceeds the largest index (2), so the last
    column of the expected result is all zeros.
    """
    data = data_generators.OneHotEncode(np.array([0, 1, 2]), 4)
    expected = np.array([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
    ])
    np.testing.assert_array_equal(expected, data)
def test_OneHotEncode_dtype():
    """Check that OneHotEncode() produces an array whose dtype is float64."""
    indices = np.array([0, 1, 2])
    encoded = data_generators.OneHotEncode(indices, 3)
    assert encoded.dtype == np.float64
def test_OneHotEncode_shape():
    """Check that OneHotEncode() appends a trailing axis of length vocabulary_size.

    A 1-D input of 3 indices should come back with shape (3, vocabulary_size).
    """
    indices = np.array([0, 1, 2])
    vocabulary_size = 4
    encoded = data_generators.OneHotEncode(indices, vocabulary_size)
    assert encoded.shape == (indices.size, vocabulary_size)
def test_OneHotEncode_empty_input():
    """Check that OneHotEncode() raises IndexError when given no indices."""
    empty = np.array([])
    with pytest.raises(IndexError):
        data_generators.OneHotEncode(empty, 3)