示例#1
0
def test_OneHotEncode_values():
  """Test that OneHotEncode() returns correct values for a known input.

  Each input index i is expected to become a row with a single 1 at column i
  and zeros elsewhere; with vocabulary_size=4 the last column stays all zero.
  """
  data = data_generators.OneHotEncode(np.array([0, 1, 2]), 4)
  # numpy.testing.assert_array_equal takes (actual, desired); passing them in
  # that order keeps the "x"/"y" labels in a failure message accurate.
  np.testing.assert_array_equal(
      data, np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]))
示例#2
0
def test_OneHotEncode_dtype():
  """Check that OneHotEncode() yields an array whose dtype is float64."""
  indices = np.array([0, 1, 2])
  encoded = data_generators.OneHotEncode(indices, 3)
  expected_dtype = np.float64
  assert encoded.dtype == expected_dtype
示例#3
0
def test_OneHotEncode_shape():
  """Check that OneHotEncode() appends an axis of length vocabulary_size.

  Encoding three indices with vocabulary_size=4 should give a (3, 4) array.
  """
  indices = np.array([0, 1, 2])
  vocabulary_size = 4
  encoded = data_generators.OneHotEncode(indices, vocabulary_size)
  assert encoded.shape == (len(indices), vocabulary_size)
示例#4
0
def test_OneHotEncode_empty_input():
  """Check that OneHotEncode() raises IndexError when given no indices."""
  no_indices = np.array([])
  with pytest.raises(IndexError):
    data_generators.OneHotEncode(no_indices, 3)