def test_assert_element_shape(self):

    def create_dataset(_):
      return (array_ops.ones(2, dtype=dtypes.float32),
              array_ops.zeros((3, 4), dtype=dtypes.int32))

    dataset = dataset_ops.Dataset.range(5).map(create_dataset)
    expected_shapes = (tensor_shape.TensorShape(2),
                       tensor_shape.TensorShape((3, 4)))
    self.assertEqual(expected_shapes,
                     dataset_ops.get_legacy_output_shapes(dataset))

    result = dataset.apply(batching.assert_element_shape(expected_shapes))
    self.assertEqual(expected_shapes,
                     dataset_ops.get_legacy_output_shapes(result))

    iterator = dataset_ops.make_initializable_iterator(result)
    init_op = iterator.initializer
    get_next = iterator.get_next()
    with self.cached_session() as sess:
      sess.run(init_op)
      for _ in range(5):
        sess.run(get_next)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
  def test_assert_element_shape_on_unknown_shape_dataset(self):

    def create_unknown_shape_dataset(x):
      return script_ops.py_func(
          lambda _: (  # pylint: disable=g-long-lambda
              np.ones(2, dtype=np.float32),
              np.zeros((3, 4), dtype=np.int32)),
          [x],
          [dtypes.float32, dtypes.int32])

    dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
    unknown_shapes = (tensor_shape.TensorShape(None),
                      tensor_shape.TensorShape(None))
    self.assertEqual(unknown_shapes,
                     dataset_ops.get_legacy_output_shapes(dataset))

    expected_shapes = (tensor_shape.TensorShape(2),
                       tensor_shape.TensorShape((3, 4)))
    result = dataset.apply(batching.assert_element_shape(expected_shapes))
    self.assertEqual(expected_shapes,
                     dataset_ops.get_legacy_output_shapes(result))

    iterator = dataset_ops.make_initializable_iterator(result)
    init_op = iterator.initializer
    get_next = iterator.get_next()
    with self.cached_session() as sess:
      sess.run(init_op)
      for _ in range(5):
        sess.run(get_next)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
  def testNonSequenceNestedStructure(self):
    components = np.array([1, 2, 3], dtype=np.int64)

    dataset = dataset_ops.Dataset.from_tensors(components)
    self.assertEqual(dtypes.int64,
                     dataset_ops.get_legacy_output_types(dataset))
    self.assertEqual([3], dataset_ops.get_legacy_output_shapes(dataset))

    dataset = dataset.filter(
        lambda x: math_ops.reduce_all(math_ops.equal(x, components)))
    self.assertEqual(dtypes.int64,
                     dataset_ops.get_legacy_output_types(dataset))
    self.assertEqual([3], dataset_ops.get_legacy_output_shapes(dataset))

    dataset = dataset.map(lambda x: array_ops.stack([x, x]))
    self.assertEqual(dtypes.int64,
                     dataset_ops.get_legacy_output_types(dataset))
    self.assertEqual([2, 3], dataset_ops.get_legacy_output_shapes(dataset))

    dataset = dataset.flat_map(
        lambda x: dataset_ops.Dataset.from_tensor_slices(x))
    self.assertEqual(dtypes.int64,
                     dataset_ops.get_legacy_output_types(dataset))
    self.assertEqual([3], dataset_ops.get_legacy_output_shapes(dataset))

    get_next = self.getNext(dataset)
    self.assertEqual(dtypes.int64, get_next().dtype)
    self.assertEqual([3], get_next().shape)
Example #4
  def testChangingStateShape(self):
    # Test the fixed-point shape invariant calculations: start with
    # initial values with known shapes, and use a scan function that
    # changes the size of the state on each element.
    def _scan_fn(state, input_value):
      # Statically known rank, but dynamic length.
      ret_longer_vector = array_ops.concat([state[0], state[0]], 0)
      # Statically unknown rank.
      ret_larger_rank = array_ops.expand_dims(state[1], 0)
      return (ret_longer_vector, ret_larger_rank), (state, input_value)

    dataset = dataset_ops.Dataset.from_tensors(0).repeat(5).apply(
        scan_ops.scan(([0], 1), _scan_fn))
    self.assertEqual(
        [None], dataset_ops.get_legacy_output_shapes(dataset)[0][0].as_list())
    self.assertIs(
        None, dataset_ops.get_legacy_output_shapes(dataset)[0][1].ndims)
    self.assertEqual(
        [], dataset_ops.get_legacy_output_shapes(dataset)[1].as_list())

    next_element = self.getNext(dataset)

    for i in range(5):
      (longer_vector_val, larger_rank_val), _ = self.evaluate(next_element())
      self.assertAllEqual([0] * (2**i), longer_vector_val)
      self.assertAllEqual(np.array(1, ndmin=i), larger_rank_val)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())
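
The comment above describes the fixed-point shape inference that runs when a scan function keeps changing the state shape. As a minimal sketch of the same behaviour with the public TF 2.x API (the `tf.data.experimental.scan` transformation, eager execution assumed), the growing dimension is reported as unknown:

import tensorflow as tf

def _growing_scan_fn(state, _):
  # Double the length of the state vector on every element and emit the
  # state as it was before the update.
  return tf.concat([state, state], axis=0), state

ds = tf.data.Dataset.from_tensors(0).repeat(5)
ds = ds.apply(tf.data.experimental.scan(tf.constant([0]), _growing_scan_fn))

print(ds.element_spec)  # TensorSpec(shape=(None,), dtype=tf.int32, ...)
for i, value in enumerate(ds):
  assert value.shape[0] == 2 ** i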

  def testNestedDict(self):
    components = {"a": {"aa": 1, "ab": [2.0, 2.0]}, "b": [3, 3, 3]}
    dataset = dataset_ops.Dataset.from_tensors(components)
    self.assertEqual(dtypes.int32,
                     dataset_ops.get_legacy_output_types(dataset)["a"]["aa"])
    self.assertEqual(dtypes.float32,
                     dataset_ops.get_legacy_output_types(dataset)["a"]["ab"])
    self.assertEqual(dtypes.int32,
                     dataset_ops.get_legacy_output_types(dataset)["b"])
    self.assertEqual([],
                     dataset_ops.get_legacy_output_shapes(dataset)["a"]["aa"])
    self.assertEqual([2],
                     dataset_ops.get_legacy_output_shapes(dataset)["a"]["ab"])
    self.assertEqual([3],
                     dataset_ops.get_legacy_output_shapes(dataset)["b"])
Example #6
  def testRepeatTensorDataset(self):
    """Test a dataset that repeats its input multiple times."""
    components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
    # The `count` argument of `do_test` below configures the number of
    # repetitions used when building a particular dataset.

    def do_test(count):
      dataset = dataset_ops.Dataset.from_tensors(components).repeat(count)
      self.assertEqual(
          [c.shape for c in components],
          [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
      self.assertDatasetProduces(dataset, [components] * count)

    # Test a finite repetition.
    do_test(3)

    # Test a different finite repetition.
    do_test(7)

    # Test an empty repetition.
    do_test(0)

    # Test an infinite repetition.
    # NOTE(mrry): There's not a good way to test that the sequence
    # actually is infinite.
    dataset = dataset_ops.Dataset.from_tensors(components).repeat(-1)
    self.assertEqual(
        [c.shape for c in components],
        [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
    get_next = self.getNext(dataset)
    for _ in range(17):
      results = self.evaluate(get_next())
      for component, result_component in zip(components, results):
        self.assertAllEqual(component, result_component)
Example #7
  def _apply_fn(dataset):
    """Function from `Dataset` to `Dataset` that applies the transformation."""
    # NOTE(mrry): We must ensure that any SparseTensors in `dataset`
    # are normalized to the rank-1 dense representation, so that the
    # sparse-oblivious unbatching logic will slice them
    # appropriately. This leads to a somewhat inefficient re-encoding step
    # for all SparseTensor components.
    # TODO(mrry): Consider optimizing this in future if it turns out to be
    # a bottleneck.
    def normalize(arg, *rest):
      # pylint: disable=protected-access
      if rest:
        return dataset._element_structure._to_batched_tensor_list((arg,) + rest)
      else:
        return dataset._element_structure._to_batched_tensor_list(arg)

    normalized_dataset = dataset.map(normalize)

    # NOTE(mrry): Our `map()` has lost information about the sparseness
    # of any SparseTensor components, so re-apply the structure of the
    # original dataset.
    restructured_dataset = _RestructuredDataset(
        normalized_dataset,
        dataset_ops.get_legacy_output_types(dataset),
        dataset_ops.get_legacy_output_shapes(dataset),
        dataset_ops.get_legacy_output_classes(dataset),
        allow_unsafe_cast=True)
    return _UnbatchDataset(restructured_dataset)
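
The notes above concern the internal normalization needed so that sparse components survive unbatching. A minimal sketch of the user-facing behaviour, assuming the public TF 2.x `Dataset.unbatch()` method (which hides that re-encoding step):

import tensorflow as tf

# A dataset with a single element of shape [2, 3].
batched = tf.data.Dataset.from_tensors([[1, 2, 3], [4, 5, 6]])

# unbatch() slices each element along its leading dimension.
unbatched = batched.unbatch()
print(unbatched.element_spec)               # TensorSpec(shape=(3,), dtype=tf.int32, ...)
print(list(unbatched.as_numpy_iterator()))  # [array([1, 2, 3], ...), array([4, 5, 6], ...)]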
Example #8
def _create_or_validate_filenames_dataset(filenames):
  """Creates (or validates) a dataset of filenames.

  Args:
    filenames: Either a list or dataset of filenames. If it is a list, it is
      converted to a dataset. If it is a dataset, its type and shape are
      validated.

  Returns:
    A dataset of filenames.
  """
  if isinstance(filenames, dataset_ops.DatasetV2):
    if dataset_ops.get_legacy_output_types(filenames) != dtypes.string:
      raise TypeError(
          "`filenames` must be a `tf.data.Dataset` of `tf.string` elements.")
    if not dataset_ops.get_legacy_output_shapes(filenames).is_compatible_with(
        tensor_shape.scalar()):
      raise TypeError(
          "`filenames` must be a `tf.data.Dataset` of scalar `tf.string` "
          "elements.")
  else:
    filenames = ops.convert_to_tensor(filenames, dtype=dtypes.string)
    filenames = array_ops.reshape(filenames, [-1], name="flat_filenames")
    filenames = dataset_ops.DatasetV2.from_tensor_slices(filenames)

  return filenames
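
A rough public-API sketch of the same normalization logic, assuming TF 2.x and a single-tensor element spec (`filenames_to_dataset` is a hypothetical name, not part of tf.data):

import tensorflow as tf

def filenames_to_dataset(filenames):
  # Hypothetical helper mirroring _create_or_validate_filenames_dataset.
  if isinstance(filenames, tf.data.Dataset):
    spec = filenames.element_spec
    if spec.dtype != tf.string or spec.shape.rank != 0:
      raise TypeError(
          "`filenames` must be a `tf.data.Dataset` of scalar `tf.string` elements.")
    return filenames
  filenames = tf.convert_to_tensor(filenames, dtype=tf.string)
  filenames = tf.reshape(filenames, [-1])  # flatten to a rank-1 vector
  return tf.data.Dataset.from_tensor_slices(filenames)

ds = filenames_to_dataset(["a.tfrecord", "b.tfrecord"])
print(list(ds.as_numpy_iterator()))  # [b'a.tfrecord', b'b.tfrecord']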

  def testChooseFastestManyInputs(self):
    dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2, 3, 4])
    merge = optimization._ChooseFastestDataset([dataset for _ in range(5)])
    self.assertDatasetProduces(
        merge,
        expected_output=[0, 1, 2, 3, 4],
        expected_shapes=dataset_ops.get_legacy_output_shapes(dataset))
Example #10
  def testConcatenateDatasetDifferentShape(self):
    input_components = (
        np.tile(np.array([[1], [2], [3], [4]]), 20),
        np.tile(np.array([[12], [13], [14], [15]]), 4))
    to_concatenate_components = (
        np.tile(np.array([[1], [2], [3], [4], [5]]), 20),
        np.tile(np.array([[12], [13], [14], [15], [16]]), 15))

    input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components)
    dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices(
        to_concatenate_components)
    concatenated = input_dataset.concatenate(dataset_to_concatenate)
    self.assertEqual(
        [ts.as_list()
         for ts in nest.flatten(
             dataset_ops.get_legacy_output_shapes(concatenated))],
        [[20], [None]])
    get_next = self.getNext(concatenated)
    for i in range(9):
      result = self.evaluate(get_next())
      if i < 4:
        for component, result_component in zip(input_components, result):
          self.assertAllEqual(component[i], result_component)
      else:
        for component, result_component in zip(to_concatenate_components,
                                               result):
          self.assertAllEqual(component[i - 4], result_component)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
Example #11
  def __init__(self, input_dataset, num_workers):
    self._input_dataset = input_dataset

    def recalculate_output_shapes(output_shapes):
      """Recalculates the output_shapes after dividing it by num_workers."""
      if len(output_shapes) < 1:
        raise ValueError("Input shape should have at least one dimension.")
      if (tensor_shape.dimension_value(output_shapes[0]) and
          tensor_shape.dimension_value(output_shapes[0]) % num_workers != 0):
        raise errors.InvalidArgumentError(
            None, None,
            "First dim of input shape: %d is not divisible by num_workers: %d" %
            (output_shapes[0], num_workers))
      output_dims = [d for d in output_shapes.dims]
      output_dims[0] = output_dims[0] // num_workers
      return tensor_shape.TensorShape(output_dims)

    input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
    input_shapes = dataset_ops.get_legacy_output_shapes(self._input_dataset)
    input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset)
    output_shapes = nest.map_structure(recalculate_output_shapes, input_shapes)

    self._structure = structure.convert_legacy_structure(
        input_types, output_shapes, input_classes)
    variant_tensor = ged_ops.experimental_rebatch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        num_workers=num_workers,
        **dataset_ops.flat_structure(self))
    super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
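
The shape recalculation above divides the leading (batch) dimension by `num_workers`. A small self-contained sketch of that rule using only the public `tf.TensorShape` API (the helper name is made up for illustration):

import tensorflow as tf

def rebatched_shape(shape, num_workers):
  shape = tf.TensorShape(shape)
  if shape.rank is None or shape.rank < 1:
    raise ValueError("Input shape should have at least one dimension.")
  dims = shape.as_list()
  if dims[0] is not None and dims[0] % num_workers != 0:
    raise ValueError("First dim %d is not divisible by num_workers %d."
                     % (dims[0], num_workers))
  dims[0] = None if dims[0] is None else dims[0] // num_workers
  return tf.TensorShape(dims)

print(rebatched_shape([64, 224, 224, 3], 4))  # (16, 224, 224, 3)
print(rebatched_shape([None, 10], 4))         # (None, 10)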
Example #12
  def testNestedZipDataset(self):

    equal_length_components = [
        np.tile(np.array([[1], [2], [3], [4]]), 20),
        np.tile(np.array([[12], [13], [14], [15]]), 22),
        np.array([37.0, 38.0, 39.0, 40.0])
    ]
    datasets = [
        dataset_ops.Dataset.from_tensor_slices(component)
        for component in equal_length_components
    ]
    dataset = dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2])))

    self.assertEqual(
        dataset_ops.get_legacy_output_shapes(dataset),
        (tensor_shape.TensorShape([20]),
         (tensor_shape.TensorShape([22]), tensor_shape.TensorShape([]))))

    get_next = self.getNext(dataset)
    for i in range(4):
      result1, (result2, result3) = self.evaluate(get_next())
      self.assertAllEqual(equal_length_components[0][i], result1)
      self.assertAllEqual(equal_length_components[1][i], result2)
      self.assertAllEqual(equal_length_components[2][i], result3)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
Example #13
  def testConcatenateDataset(self):
    input_components = (
        np.tile(np.array([[1], [2], [3], [4]]), 20),
        np.tile(np.array([[12], [13], [14], [15]]), 15),
        np.array([37.0, 38.0, 39.0, 40.0]))
    to_concatenate_components = (
        np.tile(np.array([[1], [2], [3], [4], [5]]), 20),
        np.tile(np.array([[12], [13], [14], [15], [16]]), 15),
        np.array([37.0, 38.0, 39.0, 40.0, 41.0]))

    input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components)
    dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices(
        to_concatenate_components)
    concatenated = input_dataset.concatenate(dataset_to_concatenate)
    self.assertEqual(
        dataset_ops.get_legacy_output_shapes(concatenated),
        (tensor_shape.TensorShape([20]), tensor_shape.TensorShape([15]),
         tensor_shape.TensorShape([])))

    get_next = self.getNext(concatenated)

    for i in range(9):
      result = self.evaluate(get_next())
      if i < 4:
        for component, result_component in zip(input_components, result):
          self.assertAllEqual(component[i], result_component)
      else:
        for component, result_component in zip(to_concatenate_components,
                                               result):
          self.assertAllEqual(component[i - 4], result_component)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
Example #14
  def assertDatasetProduces(self,
                            dataset,
                            expected_output=None,
                            expected_shapes=None,
                            expected_error=None,
                            requires_initialization=False,
                            num_test_iterations=1,
                            assert_items_equal=False):
    """Asserts that a dataset produces the expected output / error.

    Args:
      dataset: A dataset to check for the expected output / error.
      expected_output: A list of elements that the dataset is expected to
        produce.
      expected_shapes: A list of `TensorShape`s expected to match the output
        shapes of `dataset`.
      expected_error: A tuple `(type, predicate)` identifying the expected error
        `dataset` should raise. The `type` should match the expected exception
        type, while `predicate` should either be 1) a unary function that inputs
        the raised exception and returns a boolean indicator of success or 2) a
        regular expression that is expected to match the error message
        partially.
      requires_initialization: Indicates that when the test is executed in graph
        mode, it should use an initializable iterator to iterate through the
        dataset (e.g. when it contains stateful nodes). Defaults to False.
      num_test_iterations: Number of times `dataset` will be iterated. Defaults
        to 1.
      assert_items_equal: If `True`, checks that `dataset` produces the same
        elements as `expected_output`, regardless of order.
    """
    self.assertTrue(
        expected_error is not None or expected_output is not None,
        "Exactly one of expected_output or expected error should be provided.")
    if expected_error:
      self.assertTrue(
          expected_output is None,
          "Exactly one of expected_output or expected error should be provided."
      )
      with self.assertRaisesWithPredicateMatch(expected_error[0],
                                               expected_error[1]):
        get_next = self.getNext(
            dataset, requires_initialization=requires_initialization)
        self.evaluate(get_next())
      return
    if expected_shapes:
      self.assertEqual(expected_shapes,
                       dataset_ops.get_legacy_output_shapes(dataset))
    self.assertGreater(num_test_iterations, 0)
    for _ in range(num_test_iterations):
      get_next = self.getNext(
          dataset, requires_initialization=requires_initialization)
      result = []
      for _ in range(len(expected_output)):
        result.append(self.evaluate(get_next()))
      self._compareOutputToExpected(result, expected_output, assert_items_equal)
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())
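
As a small illustration of how a test built on this helper might call it (a hypothetical test method, assuming the same DatasetTestBase-style class and modules):

  def testRangeProducesExpected(self):
    dataset = dataset_ops.Dataset.range(4).map(lambda x: x * 2)
    self.assertDatasetProduces(
        dataset,
        expected_output=[0, 2, 4, 6],
        expected_shapes=dataset_ops.get_legacy_output_shapes(dataset))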
Example #15
  def __init__(self, input_dataset, features, num_parallel_calls):
    self._input_dataset = input_dataset
    if not input_dataset._element_structure.is_compatible_with(  # pylint: disable=protected-access
        structure.TensorStructure(dtypes.string, [None])):
      raise TypeError("Input dataset should be a dataset of vectors of strings")
    self._num_parallel_calls = num_parallel_calls
    # pylint: disable=protected-access
    self._features = parsing_ops._prepend_none_dimension(features)
    # sparse_keys and dense_keys come back sorted here.
    (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
     dense_shapes) = parsing_ops._features_to_raw_params(
         self._features, [
             parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
             parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature
         ])
    # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature.
    (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes,
     dense_shape_as_shape) = parsing_ops._process_raw_parameters(
         None, dense_defaults, sparse_keys, sparse_types, dense_keys,
         dense_types, dense_shapes)
    # pylint: enable=protected-access
    self._sparse_keys = sparse_keys
    self._sparse_types = sparse_types
    self._dense_keys = dense_keys
    self._dense_defaults = dense_defaults_vec
    self._dense_shapes = dense_shapes
    self._dense_types = dense_types
    input_dataset_shape = dataset_ops.get_legacy_output_shapes(
        self._input_dataset)
    dense_output_shapes = [input_dataset_shape.concatenate(shape)
                           for shape in dense_shape_as_shape]
    sparse_output_shapes = [input_dataset_shape.concatenate([None])
                            for _ in range(len(sparse_keys))]

    output_shapes = dict(
        zip(self._dense_keys + self._sparse_keys,
            dense_output_shapes + sparse_output_shapes))
    output_types = dict(
        zip(self._dense_keys + self._sparse_keys,
            self._dense_types + self._sparse_types))
    output_classes = dict(
        zip(self._dense_keys + self._sparse_keys,
            [ops.Tensor for _ in range(len(self._dense_defaults))] +
            [sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys))
            ]))
    self._structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)

    variant_tensor = (
        gen_experimental_dataset_ops.experimental_parse_example_dataset(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            self._num_parallel_calls,
            self._dense_defaults,
            self._sparse_keys,
            self._dense_keys,
            self._sparse_types,
            self._dense_shapes,
            **dataset_ops.flat_structure(self)))
    super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)
  def _test(self,
            input_tensor,
            feature_val,
            expected_values=None,
            expected_err=None,
            create_iterator_twice=False):

    if expected_err:
      with self.assertRaisesWithPredicateMatch(expected_err[0],
                                               expected_err[1]):
        dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply(
            contrib_parsing_ops.parse_example_dataset(feature_val))
        get_next = self.getNext(dataset)
        self.evaluate(get_next())
      return
    else:
      # Returns dict w/ Tensors and SparseTensors.
      # Check values.
      dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply(
          contrib_parsing_ops.parse_example_dataset(feature_val))
      get_next = self.getNext(dataset)
      result = self.evaluate(get_next())
      self._compare_output_to_expected(result, expected_values)
      with self.assertRaises(errors_impl.OutOfRangeError):
        self.evaluate(get_next())
      with self.assertRaises(errors_impl.OutOfRangeError):
        self.evaluate(get_next())
      if create_iterator_twice:
        get_next = self.getNext(dataset)
        result = self.evaluate(get_next())
        self._compare_output_to_expected(result, expected_values)
        with self.assertRaises(errors_impl.OutOfRangeError):
          self.evaluate(get_next())
    # Check shapes; if serialized is a Tensor we need its size to
    # properly check.
    batch_size = (
        self.evaluate(input_tensor).size if isinstance(input_tensor, ops.Tensor)
        else np.asarray(input_tensor).size)
    for k, f in feature_val.items():
      if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None:
        self.assertEqual(
            dataset_ops.get_legacy_output_shapes(dataset)[k].as_list()[0],
            batch_size)
      elif isinstance(f, parsing_ops.VarLenFeature):
        self.assertEqual(
            dataset_ops.get_legacy_output_shapes(dataset)[k].as_list()[1], None)
  def testFromTensorSlicesMixed(self):
    """Test a dataset that represents the slices from a tuple of tensors."""
    components = (np.tile(np.array([[1], [2], [3]]), 20),
                  np.tile(np.array([[12], [13], [14]]), 22),
                  np.array([37.0, 38.0, 39.0]),
                  sparse_tensor.SparseTensorValue(
                      indices=np.array([[0, 0], [1, 0], [2, 0]]),
                      values=np.array([0, 0, 0]),
                      dense_shape=np.array([3, 1])),
                  sparse_tensor.SparseTensorValue(
                      indices=np.array([[0, 0], [1, 1], [2, 2]]),
                      values=np.array([1, 2, 3]),
                      dense_shape=np.array([3, 3])))

    dataset = dataset_ops.Dataset.from_tensor_slices(components)
    get_next = self.getNext(dataset)
    self.assertEqual([
        tensor_shape.TensorShape(c.dense_shape[1:])
        if sparse_tensor.is_sparse(c) else c.shape[1:] for c in components
    ], [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])

    expected = [
        (sparse_tensor.SparseTensorValue(
            indices=np.array([[0]]),
            values=np.array([0]),
            dense_shape=np.array([1])),
         sparse_tensor.SparseTensorValue(
             indices=np.array([[0]]),
             values=np.array([1]),
             dense_shape=np.array([3]))),
        (sparse_tensor.SparseTensorValue(
            indices=np.array([[0]]),
            values=np.array([0]),
            dense_shape=np.array([1])),
         sparse_tensor.SparseTensorValue(
             indices=np.array([[1]]),
             values=np.array([2]),
             dense_shape=np.array([3]))),
        (sparse_tensor.SparseTensorValue(
            indices=np.array([[0]]),
            values=np.array([0]),
            dense_shape=np.array([1])),
         sparse_tensor.SparseTensorValue(
             indices=np.array([[2]]),
             values=np.array([3]),
             dense_shape=np.array([3]))),
    ]
    for i in range(3):
      results = self.evaluate(get_next())
      for component, result_component in zip(
          (list(zip(*components[:3]))[i] + expected[i]), results):
        if sparse_tensor.is_sparse(component):
          self.assertSparseValuesEqual(component, result_component)
        else:
          self.assertAllEqual(component, result_component)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
Example #18
  def do_test(count):
    dataset = dataset_ops.Dataset.from_tensor_slices(components).skip(count)
    self.assertEqual(
        [c.shape[1:] for c in components],
        [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
    start_range = min(count, 10) if count != -1 else 10
    self.assertDatasetProduces(
        dataset,
        [tuple(components[0][i:i + 1]) for i in range(start_range, 10)])

  def testFromTensorSlicesWithDict(self):
    components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]}
    dataset = dataset_ops.Dataset.from_tensor_slices(components)
    get_next = self.getNext(dataset)

    self.assertEqual(dtypes.int32,
                     dataset_ops.get_legacy_output_types(dataset)["foo"])
    self.assertEqual(dtypes.float32,
                     dataset_ops.get_legacy_output_types(dataset)["bar"])
    self.assertEqual((), dataset_ops.get_legacy_output_shapes(dataset)["foo"])
    self.assertEqual((1,), dataset_ops.get_legacy_output_shapes(dataset)["bar"])

    for i in range(3):
      results = self.evaluate(get_next())
      self.assertEqual(components["foo"][i], results["foo"])
      self.assertEqual(components["bar"][i], results["bar"])
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
Example #20
  def _apply_fn(dataset):
    output_shapes = _merge_output_shapes(
        dataset_ops.get_legacy_output_shapes(dataset), expected_shapes)
    # pylint: disable=protected-access
    return batching._RestructuredDataset(
        dataset.map(_check_shape),
        dataset_ops.get_legacy_output_types(dataset),
        output_shapes=output_shapes,
        output_classes=dataset_ops.get_legacy_output_classes(dataset))
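
A rough public-API analogue of this kind of per-element shape assertion in TF 2.x (not the implementation above) is `tf.ensure_shape` inside a `map`, which both checks the runtime shape and restores the static shape:

import numpy as np
import tensorflow as tf

def _unknown_shape_element(x):
  # tf.py_function erases static shape information, as in the tests above.
  vec, mat = tf.py_function(
      lambda _: (np.ones(2, np.float32), np.zeros((3, 4), np.int32)),
      [x], [tf.float32, tf.int32])
  return vec, mat

ds = tf.data.Dataset.range(5).map(_unknown_shape_element)
print(ds.element_spec)  # both component shapes are unknown

ds = ds.map(lambda a, b: (tf.ensure_shape(a, [2]), tf.ensure_shape(b, [3, 4])))
print(ds.element_spec)  # (TensorSpec(shape=(2,), ...), TensorSpec(shape=(3, 4), ...))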
Example #21
  def testIteratorStringHandle(self):
    dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
    dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

    iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
    iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)

    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
    feedable_iterator = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
        dataset_ops.get_legacy_output_shapes(dataset_3))
    next_element = feedable_iterator.get_next()

    self.assertTrue(dataset_ops.get_structure(dataset_3).is_compatible_with(
        dataset_ops.get_structure(feedable_iterator)))
    self.assertTrue(dataset_ops.get_structure(dataset_4).is_compatible_with(
        dataset_ops.get_structure(feedable_iterator)))

    with self.cached_session() as sess:
      iterator_3_handle = sess.run(iterator_3.string_handle())
      iterator_4_handle = sess.run(iterator_4.string_handle())

      self.assertEqual(10,
                       sess.run(
                           next_element,
                           feed_dict={handle_placeholder: iterator_4_handle}))
      self.assertEqual(1,
                       sess.run(
                           next_element,
                           feed_dict={handle_placeholder: iterator_3_handle}))
      self.assertEqual(20,
                       sess.run(
                           next_element,
                           feed_dict={handle_placeholder: iterator_4_handle}))
      self.assertEqual(2,
                       sess.run(
                           next_element,
                           feed_dict={handle_placeholder: iterator_3_handle}))
      self.assertEqual(30,
                       sess.run(
                           next_element,
                           feed_dict={handle_placeholder: iterator_4_handle}))
      self.assertEqual(3,
                       sess.run(
                           next_element,
                           feed_dict={handle_placeholder: iterator_3_handle}))
      self.assertEqual(40,
                       sess.run(
                           next_element,
                           feed_dict={handle_placeholder: iterator_4_handle}))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(
            next_element, feed_dict={handle_placeholder: iterator_3_handle})
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(
            next_element, feed_dict={handle_placeholder: iterator_4_handle})
  def _test_tuple_elements_by_padding(no_padding):
    dataset = build_dataset(sparse=no_padding)
    dataset = dataset.apply(grouping.bucket_by_sequence_length(
        element_length_func=_element_length_fn,
        bucket_batch_sizes=[2, 2, 2],
        bucket_boundaries=[0, 8],
        no_padding=no_padding))
    shapes = dataset_ops.get_legacy_output_shapes(dataset)
    self.assertEqual([None, None], shapes[0].as_list())
    self.assertEqual([None], shapes[1].as_list())

  def testIndefiniteRepeatShapeInference(self):
    dataset = self.make_batch_feature(
        filenames=self.test_filenames[0],
        label_key="label",
        num_epochs=None,
        batch_size=32)
    for shape, clazz in zip(
        nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)),
        nest.flatten(dataset_ops.get_legacy_output_classes(dataset))):
      if issubclass(clazz, ops.Tensor):
        self.assertEqual(32, shape[0])
Example #24
  def _buildMapDataset(self, components, count):

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
        _map_fn).repeat(count)
    self.assertEqual(
        [c.shape[1:] for c in components],
        [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
    return dataset
Example #25
  def testFromTensors(self):
    """Test a dataset that represents a single tuple of tensors."""
    components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))

    dataset = dataset_ops.Dataset.from_tensors(components)

    self.assertEqual(
        [c.shape for c in components],
        nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)))

    self.assertDatasetProduces(dataset, expected_output=[components])
Example #26
  def testRepeatRepeatTensorDataset(self):
    """Test the composition of repeat datasets."""
    components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
    inner_count, outer_count = 7, 14

    dataset = dataset_ops.Dataset.from_tensors(components).repeat(
        inner_count).repeat(outer_count)
    self.assertEqual(
        [c.shape for c in components],
        [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
    self.assertDatasetProduces(dataset,
                               [components] * (inner_count * outer_count))
Example #27
  def __init__(self, filenames, compression_type=None, buffer_size=None,
               num_parallel_reads=None):
    """Creates a `TFRecordDataset` to read one or more TFRecord files.

    NOTE: The `num_parallel_reads` argument can be used to improve performance
    when reading from a remote filesystem.

    Args:
      filenames: A `tf.string` tensor or `tf.data.Dataset` containing one or
        more filenames.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      buffer_size: (Optional.) A `tf.int64` scalar representing the number of
        bytes in the read buffer. 0 means no buffering.
      num_parallel_reads: (Optional.) A `tf.int64` scalar representing the
        number of files to read in parallel. Defaults to reading files
        sequentially.

    Raises:
      TypeError: If any argument does not have the expected type.
      ValueError: If any argument does not have the expected shape.
    """
    if isinstance(filenames, dataset_ops.DatasetV2):
      if dataset_ops.get_legacy_output_types(filenames) != dtypes.string:
        raise TypeError(
            "`filenames` must be a `tf.data.Dataset` of `tf.string` elements.")
      if not dataset_ops.get_legacy_output_shapes(filenames).is_compatible_with(
          tensor_shape.scalar()):
        raise ValueError(
            "`filenames` must be a `tf.data.Dataset` of scalar `tf.string` "
            "elements.")
    else:
      filenames = ops.convert_to_tensor(filenames, dtype=dtypes.string)
      filenames = array_ops.reshape(filenames, [-1], name="flat_filenames")
      filenames = dataset_ops.DatasetV2.from_tensor_slices(filenames)

    self._filenames = filenames
    self._compression_type = compression_type
    self._buffer_size = buffer_size
    self._num_parallel_reads = num_parallel_reads

    def read_one_file(filename):
      return _TFRecordDataset(filename, compression_type, buffer_size)

    if num_parallel_reads is None:
      self._impl = filenames.flat_map(read_one_file)
    else:
      self._impl = ParallelInterleaveDataset(
          filenames, read_one_file, cycle_length=num_parallel_reads,
          block_length=1, sloppy=False, buffer_output_elements=None,
          prefetch_input_elements=None)
    variant_tensor = self._impl._variant_tensor  # pylint: disable=protected-access
    super(TFRecordDatasetV2, self).__init__(variant_tensor)
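
A minimal usage sketch of the public `tf.data.TFRecordDataset` constructor described above (the file names are placeholders; iterating requires real TFRecord files):

import tensorflow as tf

files = ["shard-00000.tfrecord", "shard-00001.tfrecord"]  # hypothetical paths
ds = tf.data.TFRecordDataset(
    files,
    compression_type="GZIP",       # files assumed to be gzip-compressed
    buffer_size=8 * 1024 * 1024,   # 8 MB read buffer per file
    num_parallel_reads=2)          # interleave records from both files
print(ds.element_spec)  # TensorSpec(shape=(), dtype=tf.string, name=None)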
Example #28
    def dataset_fn(count=5, buffer_size=None, seed=0):
      repeat_dataset = (
          dataset_ops.Dataset.from_tensor_slices(components).repeat(count))
      if buffer_size:
        shuffle_dataset = repeat_dataset.shuffle(buffer_size, seed)

        self.assertEqual(
            tuple([c.shape[1:] for c in components]),
            dataset_ops.get_legacy_output_shapes(shuffle_dataset))
        return shuffle_dataset
      else:
        return repeat_dataset

  def testChooseFastest(self):
    dataset = dataset_ops.Dataset.range(600)
    f = lambda x: 2 * x
    dataset_a = dataset.batch(50).map(f)
    dataset_b = dataset.map(f).batch(50)
    merge = optimization._ChooseFastestDataset([dataset_a, dataset_b])
    self.assertDatasetProduces(
        merge,
        expected_output=[
            [i * 2 for i in range(j * 50, (j + 1) * 50)] for j in range(12)
        ],
        expected_shapes=dataset_ops.get_legacy_output_shapes(dataset_a))
Example #30
  def testCacheDatasetPassthrough(self):
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))

    def dataset_fn(count=5, filename=None):
      repeat_dataset = (
          dataset_ops.Dataset.from_tensor_slices(components).repeat(count))
      if filename:
        return repeat_dataset.cache(filename)
      else:
        return repeat_dataset

    self.assertEqual(
        tuple([c.shape[1:] for c in components]),
        dataset_ops.get_legacy_output_shapes(dataset_fn()))

    get_next = self.getNext(dataset_fn())

    # First run without caching to collect the "ground truth".
    elements = []
    for _ in range(20):
      elements.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

    # Assert that the cached dataset has the same elements as the
    # "ground truth".
    get_next = self.getNext(dataset_fn(filename=self.cache_prefix))
    cached_elements = []
    for _ in range(20):
      cached_elements.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    self.assertAllEqual(elements, cached_elements)

    # Re-initialize with an empty upstream (to throw errors.OutOfRangeError
    # if we didn't use the cache).
    get_next = self.getNext(dataset_fn(count=0, filename=self.cache_prefix))
    replayed_elements = []
    for _ in range(20):
      replayed_elements.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    self.assertEqual(cached_elements, replayed_elements)

    # Re-initialize with an empty upstream and a missing cache file (should
    # throw errors.OutOfRangeError immediately).
    get_next = self.getNext(
        dataset_fn(count=0, filename=self.cache_prefix + "nonsense"))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
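
The test above exercises file-backed caching: the first full pass writes the cache, and later passes replay from it even if the upstream pipeline is empty. A minimal public-API sketch, assuming TF 2.x eager execution:

import os
import tempfile
import tensorflow as tf

cache_path = os.path.join(tempfile.mkdtemp(), "cache")
ds = tf.data.Dataset.range(4).map(lambda x: x * 10).cache(cache_path)

print(list(ds.as_numpy_iterator()))  # [0, 10, 20, 30]  (first pass writes the cache files)
print(list(ds.as_numpy_iterator()))  # [0, 10, 20, 30]  (later passes read from the cache)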
Example #31
def is_dataset_shape_fully_defined(dataset):
    """Returns whether a dataset contains a final partial batch."""
    shapes = nest.flatten(dataset_ops.get_legacy_output_shapes(dataset))
    unknown_shapes = [s for s in shapes if not s.is_fully_defined()]
    return not unknown_shapes
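
A public-API sketch of the same check, assuming TF 2.x (`shapes_fully_defined` is a made-up name); `drop_remainder` is what turns the batch dimension from `None` into a static size:

import tensorflow as tf

def shapes_fully_defined(dataset):
  return all(spec.shape.is_fully_defined()
             for spec in tf.nest.flatten(dataset.element_spec))

print(shapes_fully_defined(tf.data.Dataset.range(10).batch(3)))
# False: the last batch may be partial, so the batch dim is None.
print(shapes_fully_defined(tf.data.Dataset.range(10).batch(3, drop_remainder=True)))
# True: every batch has static shape [3].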
Example #32
  def _remote_fn(h):
    handle = script_ops.py_func(_encode_raw, [h], dtypes.string)
    remote_iterator = iterator_ops.Iterator.from_string_handle(
        handle, dataset_ops.get_legacy_output_types(dataset_3),
        dataset_ops.get_legacy_output_shapes(dataset_3))
    return remote_iterator.get_next()
Example #33

  def testEmptySerializedWithAllDefaults(self):
    sparse_name = "st_a"
    a_name = "a"
    b_name = "b"
    c_name = "c:has_a_tricky_name"
    a_default = [0, 42, 0]
    b_default = np.random.rand(3, 3).astype(bytes)
    c_default = np.random.rand(2).astype(np.float32)

    expected_st_a = (  # indices, values, shape
    def testNestedStructure(self):
        components = (np.array([1, 2, 3], dtype=np.int64), (np.array([4., 5.]),
                                                            np.array([6.,
                                                                      7.])),
                      np.array([8, 9, 10], dtype=np.int64))

        dataset = dataset_ops.Dataset.from_tensors(components)
        self.assertEqual(
            (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64),
            dataset_ops.get_legacy_output_types(dataset))
        self.assertEqual(([3], ([2], [2]), [3]),
                         dataset_ops.get_legacy_output_shapes(dataset))

        dataset = dataset.shuffle(10, 10)
        self.assertEqual(
            (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64),
            dataset_ops.get_legacy_output_types(dataset))
        self.assertEqual(([3], ([2], [2]), [3]),
                         dataset_ops.get_legacy_output_shapes(dataset))

        dataset = dataset.repeat(-1)
        self.assertEqual(
            (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64),
            dataset_ops.get_legacy_output_types(dataset))
        self.assertEqual(([3], ([2], [2]), [3]),
                         dataset_ops.get_legacy_output_shapes(dataset))

        dataset = dataset.filter(lambda x, y, z: True)
        self.assertEqual(
            (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64),
            dataset_ops.get_legacy_output_types(dataset))
        self.assertEqual(([3], ([2], [2]), [3]),
                         dataset_ops.get_legacy_output_shapes(dataset))

        dataset = dataset.take(5)
        self.assertEqual(
            (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64),
            dataset_ops.get_legacy_output_types(dataset))
        self.assertEqual(([3], ([2], [2]), [3]),
                         dataset_ops.get_legacy_output_shapes(dataset))

        dataset = dataset.map(lambda x, y, z: ((x, z), (y[0], y[1])))
        self.assertEqual(
            ((dtypes.int64, dtypes.int64), (dtypes.float64, dtypes.float64)),
            dataset_ops.get_legacy_output_types(dataset))
        self.assertEqual((([3], [3]), ([2], [2])),
                         dataset_ops.get_legacy_output_shapes(dataset))

        dataset = dataset.flat_map(lambda x, y: dataset_ops.Dataset.
                                   from_tensors(((x[0], x[1]), (y[0], y[1]))))
        self.assertEqual(
            ((dtypes.int64, dtypes.int64), (dtypes.float64, dtypes.float64)),
            dataset_ops.get_legacy_output_types(dataset))
        self.assertEqual((([3], [3]), ([2], [2])),
                         dataset_ops.get_legacy_output_shapes(dataset))

        dataset = dataset.batch(32)
        self.assertEqual(
            ((dtypes.int64, dtypes.int64), (dtypes.float64, dtypes.float64)),
            dataset_ops.get_legacy_output_types(dataset))
        dataset_output_shapes = dataset_ops.get_legacy_output_shapes(dataset)
        self.assertEqual(
            (([None, 3], [None, 3]), ([None, 2], [None, 2])),
            nest.pack_sequence_as(
                dataset_output_shapes,
                [s.as_list() for s in nest.flatten(dataset_output_shapes)]))

        # Define a separate set of components with matching leading
        # dimension for the from-slices constructor.
        components_for_slices = (np.array([1, 2, 3], dtype=np.int64),
                                 (np.array([4., 5.,
                                            6.]), np.array([7., 8., 9.])),
                                 np.array([10, 11, 12], dtype=np.int64))

        dataset = dataset_ops.Dataset.from_tensor_slices(components_for_slices)
        self.assertEqual(
            (dtypes.int64, (dtypes.float64, dtypes.float64), dtypes.int64),
            dataset_ops.get_legacy_output_types(dataset))
        self.assertEqual(([], ([], []), []),
                         dataset_ops.get_legacy_output_shapes(dataset))
Example #35
  def testNonRectangularInputs(self):
    elements = [[[1]], [[2, 3]], [[4, 5, 6]]]
    dataset = from_list.from_list(elements)
    self.assertEqual(tensor_shape.Dimension(1),
                     dataset_ops.get_legacy_output_shapes(dataset)[0])
    self.assertDatasetProduces(dataset, expected_output=elements)
Example #36
    def __init__(self,
                 filenames,
                 compression_type=None,
                 buffer_size=None,
                 num_parallel_reads=None):
        """Creates a `TFRecordDataset` to read one or more TFRecord files.

    Args:
      filenames: A `tf.string` tensor or `tf.data.Dataset` containing one or
        more filenames.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      buffer_size: (Optional.) A `tf.int64` scalar representing the number of
        bytes in the read buffer. If your input pipeline is I/O bottlenecked,
        consider setting this parameter to a value 1-100 MBs. If `None`, a
        sensible default for both local and remote file systems is used.
      num_parallel_reads: (Optional.) A `tf.int64` scalar representing the
        number of files to read in parallel. If greater than one, the records of
        files read in parallel are outputted in an interleaved order. If your
        input pipeline is I/O bottlenecked, consider setting this parameter to a
        value greater than one to parallelize the I/O. If `None`, files will be
        read sequentially.

    Raises:
      TypeError: If any argument does not have the expected type.
      ValueError: If any argument does not have the expected shape.
    """
        if isinstance(filenames, dataset_ops.DatasetV2):
            if dataset_ops.get_legacy_output_types(filenames) != dtypes.string:
                raise TypeError(
                    "`filenames` must be a `tf.data.Dataset` of `tf.string` elements."
                )
            if not dataset_ops.get_legacy_output_shapes(
                    filenames).is_compatible_with(tensor_shape.scalar()):
                raise ValueError(
                    "`filenames` must be a `tf.data.Dataset` of scalar `tf.string` "
                    "elements.")
        else:
            filenames = ops.convert_to_tensor(filenames, dtype=dtypes.string)
            filenames = array_ops.reshape(filenames, [-1],
                                          name="flat_filenames")
            filenames = dataset_ops.DatasetV2.from_tensor_slices(filenames)

        self._filenames = filenames
        self._compression_type = compression_type
        self._buffer_size = buffer_size
        self._num_parallel_reads = num_parallel_reads

        def read_one_file(filename):
            return _TFRecordDataset(filename, compression_type, buffer_size)

        if num_parallel_reads is None:
            self._impl = filenames.flat_map(read_one_file)
        else:
            self._impl = ParallelInterleaveDataset(
                filenames,
                read_one_file,
                cycle_length=num_parallel_reads,
                block_length=1,
                sloppy=False,
                buffer_output_elements=None,
                prefetch_input_elements=None)
        variant_tensor = self._impl._variant_tensor  # pylint: disable=protected-access
        super(TFRecordDatasetV2, self).__init__(variant_tensor)
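
As noted in the docstring, `filenames` may itself be a `tf.data.Dataset` of scalar strings, which is what the validation branch above checks. A small sketch with the public API (hypothetical file names):

import tensorflow as tf

file_ds = tf.data.Dataset.from_tensor_slices(
    ["train-00000.tfrecord", "train-00001.tfrecord"]).shuffle(2)

# Shuffling the file list reorders which files are read; records from the
# two files are then interleaved because num_parallel_reads > 1.
records = tf.data.TFRecordDataset(file_ds, num_parallel_reads=2)
print(records.element_spec)  # TensorSpec(shape=(), dtype=tf.string, name=None)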
Example #37
  def _get_output_shapes(self, ds_fn):
    with ops.Graph().as_default():
      return dataset_ops.get_legacy_output_shapes(ds_fn())
Example #38
    def testIteratorStringHandleFuture(self):
        with forward_compat.forward_compatibility_horizon(2018, 8, 4):
            dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
            dataset_4 = dataset_ops.Dataset.from_tensor_slices(
                [10, 20, 30, 40])

            iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
            iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)

            handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
            feedable_iterator = iterator_ops.Iterator.from_string_handle(
                handle_placeholder,
                dataset_ops.get_legacy_output_types(dataset_3),
                dataset_ops.get_legacy_output_shapes(dataset_3))
            next_element = feedable_iterator.get_next()

            self.assertTrue(
                structure.are_compatible(
                    dataset_ops.get_structure(dataset_3),
                    dataset_ops.get_structure(feedable_iterator)))

            with self.cached_session() as sess:
                iterator_3_handle = sess.run(iterator_3.string_handle())
                iterator_4_handle = sess.run(iterator_4.string_handle())

                self.assertEqual(
                    10,
                    sess.run(next_element,
                             feed_dict={handle_placeholder:
                                        iterator_4_handle}))
                self.assertEqual(
                    1,
                    sess.run(next_element,
                             feed_dict={handle_placeholder:
                                        iterator_3_handle}))
                self.assertEqual(
                    20,
                    sess.run(next_element,
                             feed_dict={handle_placeholder:
                                        iterator_4_handle}))
                self.assertEqual(
                    2,
                    sess.run(next_element,
                             feed_dict={handle_placeholder:
                                        iterator_3_handle}))
                self.assertEqual(
                    30,
                    sess.run(next_element,
                             feed_dict={handle_placeholder:
                                        iterator_4_handle}))
                self.assertEqual(
                    3,
                    sess.run(next_element,
                             feed_dict={handle_placeholder:
                                        iterator_3_handle}))
                self.assertEqual(
                    40,
                    sess.run(next_element,
                             feed_dict={handle_placeholder:
                                        iterator_4_handle}))
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(next_element,
                             feed_dict={handle_placeholder: iterator_3_handle})
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(next_element,
                             feed_dict={handle_placeholder: iterator_4_handle})
Example #39
    def __init__(self, input_dataset, features, num_parallel_calls):
        self._input_dataset = input_dataset
        if not structure.are_compatible(
                input_dataset.element_spec,
                structure.TensorStructure(dtypes.string, [None])):
            raise TypeError(
                "Input dataset should be a dataset of vectors of strings")
        self._num_parallel_calls = num_parallel_calls
        # pylint: disable=protected-access
        self._features = parsing_ops._prepend_none_dimension(features)
        # sparse_keys and dense_keys come back sorted here.
        (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
         dense_shapes) = parsing_ops._features_to_raw_params(
             self._features, [
                 parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
                 parsing_ops.FixedLenFeature,
                 parsing_ops.FixedLenSequenceFeature
             ])
        # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature.
        (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys,
         dense_shapes,
         dense_shape_as_shape) = parsing_ops._process_raw_parameters(
             None, dense_defaults, sparse_keys, sparse_types, dense_keys,
             dense_types, dense_shapes)
        # pylint: enable=protected-access
        self._sparse_keys = sparse_keys
        self._sparse_types = sparse_types
        self._dense_keys = dense_keys
        self._dense_defaults = dense_defaults_vec
        self._dense_shapes = dense_shapes
        self._dense_types = dense_types
        input_dataset_shape = dataset_ops.get_legacy_output_shapes(
            self._input_dataset)
        dense_output_shapes = [
            input_dataset_shape.concatenate(shape)
            for shape in dense_shape_as_shape
        ]
        sparse_output_shapes = [
            input_dataset_shape.concatenate([None])
            for _ in range(len(sparse_keys))
        ]

        output_shapes = dict(
            zip(self._dense_keys + self._sparse_keys,
                dense_output_shapes + sparse_output_shapes))
        output_types = dict(
            zip(self._dense_keys + self._sparse_keys,
                self._dense_types + self._sparse_types))
        output_classes = dict(
            zip(self._dense_keys + self._sparse_keys,
                [ops.Tensor for _ in range(len(self._dense_defaults))] + [
                    sparse_tensor.SparseTensor
                    for _ in range(len(self._sparse_keys))
                ]))
        self._element_spec = structure.convert_legacy_structure(
            output_types, output_shapes, output_classes)

        variant_tensor = (
            gen_experimental_dataset_ops.experimental_parse_example_dataset(
                self._input_dataset._variant_tensor,  # pylint: disable=protected-access
                self._num_parallel_calls,
                self._dense_defaults,
                self._sparse_keys,
                self._dense_keys,
                self._sparse_types,
                self._dense_shapes,
                **self._flat_structure))
        super(_ParseExampleDataset, self).__init__(input_dataset,
                                                   variant_tensor)
Example #40
  def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
    """Test a dataset that maps a TF function across its input elements."""
    # The pipeline is TensorSliceDataset ->
    # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size).
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    def dataset_fn(batch_size, count):
      dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat(
          count).apply(
              batching.map_and_batch(
                  map_func=_map_fn,
                  batch_size=batch_size,
                  num_parallel_calls=num_parallel_calls,
                  num_parallel_batches=num_parallel_batches))
      return dataset

    # Batch of a finite input, where the batch_size divides the
    # total number of elements.
    dataset = dataset_fn(14, 28)
    get_next = self.getNext(dataset)
    self.assertEqual(
        [[None] + list(c.shape[1:]) for c in components],
        [shape.as_list()
         for shape in dataset_ops.get_legacy_output_shapes(dataset)])
    num_batches = (28 * 7) // 14
    for i in range(num_batches):
      result = self.evaluate(get_next())
      for component, result_component in zip(components, result):
        for j in range(14):
          self.assertAllEqual(component[(i * 14 + j) % 7]**2,
                              result_component[j])
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

    # Batch of a finite input, where the batch_size does not
    # divide the total number of elements.
    get_next = self.getNext(dataset_fn(8, 14))

    # We expect (num_batches - 1) full-sized batches.
    num_batches = int(math.ceil((14 * 7) / 8))
    for i in range(num_batches - 1):
      result = self.evaluate(get_next())
      for component, result_component in zip(components, result):
        for j in range(8):
          self.assertAllEqual(component[(i * 8 + j) % 7]**2,
                              result_component[j])

    result = self.evaluate(get_next())
    for component, result_component in zip(components, result):
      for j in range((14 * 7) % 8):
        self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
                            result_component[j])
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

    # Batch of an empty input should fail straight away.
    self.assertDatasetProduces(dataset_fn(8, 0), expected_output=[])

    # Empty batch should be an initialization time error.
    with self.assertRaises(errors.InvalidArgumentError):
      self.assertDatasetProduces(dataset_fn(0, 14), expected_output=[])
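
The pipeline comment above describes the fused MapAndBatchDataset. A rough modern equivalent (assuming TF 2.4+ for `tf.data.AUTOTUNE`) is a parallel `map` followed by `batch`; tf.data's fusion optimization handles the rest:

import tensorflow as tf

def _square_3(x, y, z):
  return tf.square(x), tf.square(y), tf.square(z)

components = (tf.range(7), tf.range(7) * 2, tf.range(7, dtype=tf.float32))
ds = (tf.data.Dataset.from_tensor_slices(components)
      .repeat(4)
      .map(_square_3, num_parallel_calls=tf.data.AUTOTUNE)
      .batch(14))

print([spec.shape.as_list() for spec in tf.nest.flatten(ds.element_spec)])
# [[None], [None], [None]]  (batch dim is None because the last batch may be partial)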
Example #41
File: ops.py  Project: tomsmoker/language
  def reduce_fn(_, x):
    return x.padded_batch(batch_size,
                          dataset_ops.get_legacy_output_shapes(x))
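
The `reduce_fn` above pads each window to the dataset's own (partially unknown) element shape. A small public-API sketch of `padded_batch` with variable-length elements, assuming TF 2.x:

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices([[1, 2, 3, 0], [4, 5, 0, 0]])
ds = ds.map(lambda x: tf.boolean_mask(x, x > 0))  # ragged lengths: 3 and 2

# Using the element's own (unknown-length) shape as padded_shapes pads every
# batch to the longest element it contains.
ds = ds.padded_batch(2, padded_shapes=ds.element_spec.shape)
print(next(iter(ds)).numpy())  # [[1 2 3], [4 5 0]] padded with zeros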
    def __init__(self, input_dataset, features, num_parallel_calls,
                 deterministic):
        self._input_dataset = input_dataset
        if not structure.are_compatible(
                input_dataset.element_spec,
                tensor_spec.TensorSpec([None], dtypes.string)):
            raise TypeError(
                "Input dataset should be a dataset of vectors of strings")
        self._num_parallel_calls = num_parallel_calls
        if deterministic is None:
            self._deterministic = "default"
        elif deterministic:
            self._deterministic = "true"
        else:
            self._deterministic = "false"
        # pylint: disable=protected-access
        self._features = parsing_ops._prepend_none_dimension(features)
        # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature
        params = parsing_ops._ParseOpParams.from_features(
            self._features, [
                parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
                parsing_ops.FixedLenFeature,
                parsing_ops.FixedLenSequenceFeature, parsing_ops.RaggedFeature
            ])
        # pylint: enable=protected-access
        self._sparse_keys = params.sparse_keys
        self._sparse_types = params.sparse_types
        self._ragged_keys = params.ragged_keys
        self._ragged_value_types = params.ragged_value_types
        self._ragged_split_types = params.ragged_split_types
        self._dense_keys = params.dense_keys
        self._dense_defaults = params.dense_defaults_vec
        self._dense_shapes = params.dense_shapes_as_proto
        self._dense_types = params.dense_types
        input_dataset_shape = dataset_ops.get_legacy_output_shapes(
            self._input_dataset)

        self._element_spec = {}

        for (key, value_type) in zip(params.sparse_keys, params.sparse_types):
            self._element_spec[key] = sparse_tensor.SparseTensorSpec(
                input_dataset_shape.concatenate([None]), value_type)

        for (key, value_type, dense_shape) in zip(params.dense_keys,
                                                  params.dense_types,
                                                  params.dense_shapes):
            self._element_spec[key] = tensor_spec.TensorSpec(
                input_dataset_shape.concatenate(dense_shape), value_type)

        for (key, value_type, splits_type) in zip(params.ragged_keys,
                                                  params.ragged_value_types,
                                                  params.ragged_split_types):
            self._element_spec[key] = ragged_tensor.RaggedTensorSpec(
                input_dataset_shape.concatenate([None]), value_type, 1,
                splits_type)

        if deterministic is not None or compat.forward_compatible(2020, 3, 6):
            variant_tensor = (
                gen_experimental_dataset_ops.parse_example_dataset_v2(
                    self._input_dataset._variant_tensor,  # pylint: disable=protected-access
                    self._num_parallel_calls,
                    self._dense_defaults,
                    self._sparse_keys,
                    self._dense_keys,
                    self._sparse_types,
                    self._dense_shapes,
                    deterministic=self._deterministic,
                    ragged_keys=self._ragged_keys,
                    ragged_value_types=self._ragged_value_types,
                    ragged_split_types=self._ragged_split_types,
                    **self._flat_structure))
        else:
            variant_tensor = (
                gen_experimental_dataset_ops.parse_example_dataset(
                    self._input_dataset._variant_tensor,  # pylint: disable=protected-access
                    self._num_parallel_calls,
                    self._dense_defaults,
                    self._sparse_keys,
                    self._dense_keys,
                    self._sparse_types,
                    self._dense_shapes,
                    ragged_keys=self._ragged_keys,
                    ragged_value_types=self._ragged_value_types,
                    ragged_split_types=self._ragged_split_types,
                    **self._flat_structure))
        super(_ParseExampleDataset, self).__init__(input_dataset,
                                                   variant_tensor)
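# A minimal sketch of how this dataset is normally reached from user code, via the
# public tf.data.experimental.parse_example_dataset transformation (the feature
# spec and toy protos below are illustrative only). As the type check in
# __init__ requires, the input must be a dataset of *vectors* of serialized
# strings, hence the batch() before apply().
import tensorflow as tf

features = {
    "label": tf.io.FixedLenFeature([], tf.int64, default_value=0),
    "tags": tf.io.VarLenFeature(tf.string),  # parsed into a SparseTensor
}

def make_example(label):
  return tf.train.Example(features=tf.train.Features(feature={
      "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
  })).SerializeToString()

serialized = tf.data.Dataset.from_tensor_slices(
    [make_example(i) for i in range(4)])

parsed = serialized.batch(2).apply(
    tf.data.experimental.parse_example_dataset(features, num_parallel_calls=2))
# Each element of `parsed` is a dict, e.g. {"label": int64 [2], "tags": SparseTensor}.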
Example #43
0
  def assertDatasetProduces(self,
                            dataset,
                            expected_output=None,
                            expected_shapes=None,
                            expected_error=None,
                            requires_initialization=False,
                            num_test_iterations=1,
                            assert_items_equal=False,
                            expected_error_iter=1):
    """Asserts that a dataset produces the expected output / error.

    Args:
      dataset: A dataset to check for the expected output / error.
      expected_output: A list of elements that the dataset is expected to
        produce.
      expected_shapes: A list of TensorShapes which is expected to match
        output_shapes of dataset.
      expected_error: A tuple `(type, predicate)` identifying the expected error
        `dataset` should raise. The `type` should match the expected exception
        type, while `predicate` should either be 1) a unary function that inputs
        the raised exception and returns a boolean indicator of success or 2) a
        regular expression that is expected to match the error message
        partially.
      requires_initialization: Indicates that when the test is executed in graph
        mode, it should use an initializable iterator to iterate through the
        dataset (e.g. when it contains stateful nodes). Defaults to False.
      num_test_iterations: Number of times `dataset` will be iterated. Defaults
        to 1.
      assert_items_equal: Tests expected_output has (only) the same elements
        regardless of order.
      expected_error_iter: How many times to iterate before expecting an error,
        if an error is expected.
    """
    self.assertTrue(
        expected_error is not None or expected_output is not None,
        "Exactly one of expected_output or expected error should be provided.")
    if expected_error:
      self.assertTrue(
          expected_output is None,
          "Exactly one of expected_output or expected error should be provided."
      )
      with self.assertRaisesWithPredicateMatch(expected_error[0],
                                               expected_error[1]):
        get_next = self.getNext(
            dataset, requires_initialization=requires_initialization)
        for _ in range(expected_error_iter):
          self.evaluate(get_next())
      return
    if expected_shapes:
      self.assertEqual(expected_shapes,
                       dataset_ops.get_legacy_output_shapes(dataset))
    self.assertGreater(num_test_iterations, 0)
    for _ in range(num_test_iterations):
      get_next = self.getNext(
          dataset, requires_initialization=requires_initialization)
      result = []
      for _ in range(len(expected_output)):
        try:
          result.append(self.evaluate(get_next()))
        except errors.OutOfRangeError:
          raise AssertionError(
              "Dataset ended early, producing %d elements out of %d. "
              "Dataset output: %s" %
              (len(result), len(expected_output), str(result)))
      self._compareOutputToExpected(result, expected_output, assert_items_equal)
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())
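      # A second call after exhaustion should keep raising OutOfRangeError.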
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())
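# Hypothetical usage from inside a test case derived from this base class
# (the datasets and expected values are illustrative only):

# Fixed order: the dataset must yield exactly these elements, in order.
self.assertDatasetProduces(
    dataset_ops.Dataset.range(3), expected_output=[0, 1, 2])

# Order-insensitive comparison, e.g. after a shuffle.
self.assertDatasetProduces(
    dataset_ops.Dataset.range(3).shuffle(3),
    expected_output=[0, 1, 2],
    assert_items_equal=True)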
Example #44
0
def _flat_shapes(dataset):
    return nest.flatten(dataset_ops.get_legacy_output_shapes(dataset))
Example #45
0
def _flat_shapes(dataset):
    return [
        ts.as_list()
        for ts in nest.flatten(dataset_ops.get_legacy_output_shapes(dataset))
    ]
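# A quick illustration (assumes the same dataset_ops import used throughout
# these snippets): the two _flat_shapes helpers differ only in whether they
# return TensorShape objects or plain Python lists.
ds = dataset_ops.Dataset.range(10).batch(4)
# First helper:  [TensorShape([None])]
# Second helper: [[None]]  (the leading batch dimension is statically unknown)
print(_flat_shapes(ds))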
Example #46
0
 def LoadingFunc(h):
   remote_iterator = iterator_ops.Iterator.from_string_handle(
       h, dataset_ops.get_legacy_output_types(source_dataset),
       dataset_ops.get_legacy_output_shapes(source_dataset))
   return remote_iterator.get_next()
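# A minimal graph-mode (TF1-style) sketch of the string-handle round trip the
# function above relies on, assuming the same internal imports as the
# surrounding snippets:
source_dataset = dataset_ops.Dataset.range(5)
iterator = dataset_ops.make_one_shot_iterator(source_dataset)
handle = iterator.string_handle()

# Rebuild an iterator from the serialized handle; the element types and shapes
# must be supplied explicitly because the handle alone does not carry them.
restored = iterator_ops.Iterator.from_string_handle(
    handle,
    dataset_ops.get_legacy_output_types(source_dataset),
    dataset_ops.get_legacy_output_shapes(source_dataset))
next_element = restored.get_next()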
Example #47
0
 def testIndefiniteRepeatShapeInference(self):
     dataset = readers.make_tf_record_dataset(
         file_pattern=self.test_filenames, num_epochs=None, batch_size=32)
     for shape in nest.flatten(
             dataset_ops.get_legacy_output_shapes(dataset)):
         self.assertEqual(32, shape[0])
Example #48
0
    def __init__(self,
                 dataset,
                 output_types,
                 output_shapes=None,
                 output_classes=None):
        """Creates a new dataset with the given output types and shapes.

    The given `dataset` must have a structure that is convertible:
    * `dataset.output_types` must be the same as `output_types` modulo nesting.
    * Each shape in `dataset.output_shapes` must be compatible with each shape
      in `output_shapes` (if given).

    Note: This helper permits "unsafe casts" for shapes, equivalent to using
    `tf.Tensor.set_shape()` where domain-specific knowledge is available.

    Args:
      dataset: A `Dataset` object.
      output_types: A nested structure of `tf.DType` objects.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
        If omitted, the shapes will be inherited from `dataset`.
      output_classes: (Optional.) A nested structure of class types. If omitted,
        the class types will be inherited from `dataset`.

    Raises:
      ValueError: If either `output_types` or `output_shapes` is not compatible
        with the structure of `dataset`.
    """
        self._input_dataset = dataset

        input_types = dataset_ops.get_legacy_output_types(dataset)
        # Validate that the types are compatible.
        output_types = nest.map_structure(dtypes.as_dtype, output_types)
        flat_original_types = nest.flatten(input_types)
        flat_new_types = nest.flatten(output_types)
        if flat_original_types != flat_new_types:
            raise ValueError(
                "Dataset with output types %r cannot be restructured to have "
                "output types %r" %
                (dataset_ops.get_legacy_output_types(dataset), output_types))

        input_shapes = dataset_ops.get_legacy_output_shapes(dataset)
        if output_shapes is None:
            # Inherit shapes from the original `dataset`.
            output_shapes = nest.pack_sequence_as(output_types,
                                                  nest.flatten(input_shapes))
        else:
            # Validate that the shapes are compatible.
            nest.assert_same_structure(output_types, output_shapes)
            flat_original_shapes = nest.flatten(input_shapes)
            flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)

            for original_shape, new_shape in zip(flat_original_shapes,
                                                 flat_new_shapes):
                if not original_shape.is_compatible_with(new_shape):
                    raise ValueError(
                        "Dataset with output shapes %r cannot be restructured to have "
                        "incompatible output shapes %r" %
                        (input_shapes, output_shapes))
            output_shapes = nest.map_structure_up_to(output_types,
                                                     tensor_shape.as_shape,
                                                     output_shapes)

        input_classes = dataset_ops.get_legacy_output_classes(dataset)
        if output_classes is None:
            # Inherit class types from the original `dataset`.
            output_classes = nest.pack_sequence_as(output_types,
                                                   nest.flatten(input_classes))

        self._element_spec = structure.convert_legacy_structure(
            output_types, output_shapes, output_classes)
        variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
        super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
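# A minimal sketch of how the class above might be used to tighten a statically
# unknown shape (the py_func dataset and the target shape are illustrative;
# assumes the same internal imports -- np, dtypes, script_ops, tensor_shape,
# dataset_ops -- as the other snippets):

# A map over a py_func loses static shape information.
ds = dataset_ops.Dataset.range(5).map(
    lambda x: script_ops.py_func(
        lambda _: np.ones(2, dtype=np.float32), [x], dtypes.float32))

# Assert that each element has shape [2], analogous to Tensor.set_shape().
tightened = _RestructuredDataset(
    ds,
    output_types=dtypes.float32,
    output_shapes=tensor_shape.TensorShape([2]))
print(dataset_ops.get_legacy_output_shapes(tightened))  # (2,)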
Example #49
0
 def output_shapes(self):
     return dataset_ops.get_legacy_output_shapes(self._iterator)
Example #50
0
 def _get_output_shapes(self, ds_fn):
     assert not context.executing_eagerly()
     with ops.Graph().as_default():
         return dataset_ops.get_legacy_output_shapes(ds_fn())
Example #51
0
def get_batch_dimension(iterator):
    shapes = nest.flatten(dataset_ops.get_legacy_output_shapes(iterator))
    # Take the batch size from the first element, as it should be the same for
    # all.
    dims = shapes[0].dims
    return dims[0] if dims else None
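# For example (illustrative values; graph-mode iterators assumed), batching with
# drop_remainder=True yields a statically known leading dimension:
ds = dataset_ops.Dataset.range(100).batch(10, drop_remainder=True)
iterator = dataset_ops.make_one_shot_iterator(ds)
print(get_batch_dimension(iterator))  # known batch dimension of 10

ds = dataset_ops.Dataset.range(100).batch(10)  # remainder kept
iterator = dataset_ops.make_one_shot_iterator(ds)
print(get_batch_dimension(iterator))  # unknown batch dimension (last batch may be smaller)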
Example #52
0
 def _remote_fn(h):
     remote_iterator = iterator_ops.Iterator.from_string_handle(
         h, dataset_ops.get_legacy_output_types(dataset_3),
         dataset_ops.get_legacy_output_shapes(dataset_3))
     return remote_iterator.get_next()
Example #53
0
 def loading_func(h):
     remote_itr = iterator_ops.Iterator.from_string_handle(
         h, dataset_ops.get_legacy_output_types(itr),
         dataset_ops.get_legacy_output_shapes(itr))
     return remote_itr.get_next()