Code example #1
 def _make_reduce_func(self, reduce_func, input_dataset):
   """Make wrapping defun for reduce_func."""
   nested_dataset = dataset_ops.DatasetStructure(
       input_dataset._element_structure)  # pylint: disable=protected-access
   input_structure = structure.NestedStructure(
       (structure.TensorStructure(dtypes.int64, []), nested_dataset))
   self._reduce_func = dataset_ops.StructuredFunctionWrapper(
       reduce_func, self._transformation_name(),
       input_structure=input_structure)
   if not isinstance(
       self._reduce_func.output_structure, dataset_ops.DatasetStructure):
     raise TypeError("`reduce_func` must return a `Dataset` object.")
   # pylint: disable=protected-access
   self._structure = (
       self._reduce_func.output_structure._element_structure)
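This wrapper enforces the `group_by_window` contract: `reduce_func` receives an `int64` key together with a window `Dataset` and must itself return a `Dataset`, otherwise the `TypeError` above is raised. A minimal usage sketch of that contract through the public API (this assumes `tf.data.experimental.group_by_window`; none of the names below appear in the snippet above):

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
dataset = dataset.apply(
    tf.data.experimental.group_by_window(
        key_func=lambda x: x % 2,                         # int64 key per element
        reduce_func=lambda key, window: window.batch(5),  # must return a Dataset
        window_size=5))
# Yields one batch of even values and one batch of odd values.
```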
Code example #2
File: grouping.py  Project: zuqiutxy/tensorflow
 def _make_reduce_func(self, reduce_func, input_dataset):
   """Make wrapping defun for reduce_func."""
   nested_dataset = dataset_ops.DatasetStructure(
       structure.Structure._from_legacy_structure(  # pylint: disable=protected-access
           input_dataset.output_types, input_dataset.output_shapes,
           input_dataset.output_classes))
   wrapped_func = dataset_ops.StructuredFunctionWrapper(
       reduce_func,
       self._transformation_name(),
       input_classes=(ops.Tensor, nested_dataset),
       input_shapes=(tensor_shape.scalar(), nested_dataset),
       input_types=(dtypes.int64, nested_dataset))
   if not isinstance(
       wrapped_func.output_structure, dataset_ops.DatasetStructure):
     raise TypeError("`reduce_func` must return a `Dataset` object.")
   # pylint: disable=protected-access
   element_structure = wrapped_func.output_structure._element_structure
   self._output_classes = element_structure._to_legacy_output_classes()
   self._output_types = element_structure._to_legacy_output_types()
   self._output_shapes = element_structure._to_legacy_output_shapes()
   self._reduce_func = wrapped_func.function
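Code example #2 is the older variant of the same wrapper: it first converts the dataset's legacy (output_classes, output_types, output_shapes) triple into a single structure object before wrapping `reduce_func`, then converts the resulting element structure back into the legacy properties. A small sketch of that legacy triple using the compat accessors (a sketch only, assuming TF 1.14+/2.x; these accessors are not used in the snippet above):

```python
import tensorflow as tf

ds = tf.data.Dataset.range(5)
# The three legacy properties that the _from_legacy_structure() /
# _to_legacy_output_*() calls above map to and from one structure object:
print(tf.compat.v1.data.get_output_types(ds))    # tf.int64
print(tf.compat.v1.data.get_output_shapes(ds))   # TensorShape([])
print(tf.compat.v1.data.get_output_classes(ds))  # <class '...ops.Tensor'>
```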
Code example #3
File: optimization.py  Project: flavz27/master_PA
    def __init__(self,
                 input_dataset,
                 functions,
                 ratio_numerator=1,
                 ratio_denominator=1,
                 num_elements_per_branch=None):
        """Chooses the fastest of some dataset functions.

    Given dataset functions that take input_dataset as input and output
    another dataset, produces elements as quickly as the fastest of these
    output datasets. Note that the datasets produced by these functions are
    assumed to be stateless: given the same input elements, the iterators
    created from their output datasets all produce the same output elements.
    Each function's dataset is also expected to iterate over the input dataset
    at most once. Violating these conditions may lead to undefined behavior.

    For example:
    ```python
    dataset = tf.data.Dataset.range(100)
    dataset = _ChooseFastestBranchDataset(
        dataset,
        [
            lambda ds: ds.map(lambda x: tf.reshape(x, [1])).batch(10),
            lambda ds: ds.batch(10).map(lambda x: tf.reshape(x, [10, 1]))
        ],
        ratio_numerator=10,
        num_elements_per_branch=10
    )
    ```
    The resulting dataset will produce elements equivalent to
    `tf.data.Dataset.range(100).map(lambda x: tf.reshape(x, [1])).batch(10)`, or
    `tf.data.Dataset.range(100).batch(10).map(lambda x: tf.reshape(x, [10, 1]))`

    Note that the first `num_elements_per_branch` iterations may be slower due
    to the overhead of dynamically picking the fastest dataset. Namely, for
    these iterations, the dataset produces elements from any of the branches to
    determine which input is the fastest. For all subsequent iterations, that
    input will be used.

    Args:
      input_dataset: A `Dataset` that can be used as input to `functions`.
      functions: A list of callables, each of which takes a `Dataset` as input
        and returns a `Dataset`.
      ratio_numerator: The numerator in the ratio of input elements consumed to
        output elements produced for each function. This should be the same for
        all functions. For example, if the function is
        `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input dataset
          must produce 10 elements for every element of the output dataset. In
          this case, ratio_numerator should be 10.
      ratio_denominator: The denominator in the ratio of input elements consumed
        to output elements produced for each function. This should be the same
        for all functions. For example, if the function is
        `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input dataset
          must produce 10 elements for every element of the output dataset. In
          this case, ratio_denominator should be 1.
      num_elements_per_branch: The number of elements to get from each branch
        before deciding which dataset is fastest. In the first len(functions) *
        num_elements_per_branch iterations, the dataset will draw elements from
        one of the branches and update its knowledge of which input is the
        fastest.
        Note that (num_elements_per_branch * ratio) is expected to be an
        integer.

    Returns:
      A `Dataset` that has the same elements as the inputs.
    """
        nested_structure = structure_lib.NestedStructure(
            dataset_ops.DatasetStructure(
                structure_lib.convert_legacy_structure(
                    input_dataset.output_types, input_dataset.output_shapes,
                    input_dataset.output_classes)))
        self._funcs = [
            dataset_ops.StructuredFunctionWrapper(
                f, "ChooseFastestV2", input_structure=nested_structure)
            for f in functions
        ]
        self._structure = self._funcs[0].output_structure._element_structure  # pylint: disable=protected-access

        self._captured_arguments = []
        for f in self._funcs:
            self._captured_arguments.extend(f.function.captured_inputs)
        self._capture_lengths = [
            len(f.function.captured_inputs) for f in self._funcs
        ]

        if ratio_numerator <= 0 or ratio_denominator <= 0:
            raise ValueError("ratio must be positive.")

        if num_elements_per_branch is None:
            # Pick a sensible default based on `ratio_denominator`
            num_elements_per_branch = 10 * ratio_denominator

        variant_tensor = (
            gen_experimental_dataset_ops.choose_fastest_branch_dataset(
                input_dataset._variant_tensor,  # pylint: disable=protected-access
                ratio_numerator=ratio_numerator,
                ratio_denominator=ratio_denominator,
                other_arguments=self._captured_arguments,
                num_elements_per_branch=num_elements_per_branch,
                branches=[f.function for f in self._funcs],
                other_arguments_lengths=self._capture_lengths,
                **dataset_ops.flat_structure(self)))
        super(_ChooseFastestBranchDataset,
              self).__init__(input_dataset, variant_tensor)
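The constructor above concatenates every branch function's captured tensors into one flat `other_arguments` list and records the per-branch counts in `other_arguments_lengths`, presumably so the kernel can slice the flat list back apart per branch. A plain-Python illustration of that bookkeeping (illustrative only, not TensorFlow code):

```python
# Per-branch captured inputs, flattened list, and per-branch lengths.
captured = [["a", "b"], [], ["c"]]
flat = [t for branch in captured for t in branch]  # like other_arguments
lengths = [len(branch) for branch in captured]     # like other_arguments_lengths

# Recover each branch's slice from the flat list using the recorded lengths.
start, recovered = 0, []
for n in lengths:
    recovered.append(flat[start:start + n])
    start += n
assert recovered == captured
```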
Code example #4
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

  def testAsSerializedGraph(self):
    dataset = dataset_ops.Dataset.range(10)
    graph = graph_pb2.GraphDef().FromString(
        self.evaluate(dataset._as_serialized_graph()))
    self.assertTrue(any([node.op != "RangeDataset" for node in graph.node]))

  def testAsFunctionWithMap(self):
    if not context.executing_eagerly():
      self.skipTest("Only works executing eagerly")
    with ops.device("CPU"):
      original_dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
      fn = original_dataset._trace_variant_creation()
      variant = fn()

      revived_dataset = _RevivedDataset(
          variant, original_dataset._element_structure)
      self.assertDatasetProduces(revived_dataset, range(0, 10, 2))

  def testAsFunctionWithMapInFlatMap(self):
    if not context.executing_eagerly():
      self.skipTest("Only works executing eagerly")
    with ops.device("CPU"):
      original_dataset = dataset_ops.Dataset.range(5).flat_map(
          lambda x: dataset_ops.Dataset.range(5).map(lambda x: x * 2))
      fn = original_dataset._trace_variant_creation()
      variant = fn()

      revived_dataset = _RevivedDataset(
          variant, original_dataset._element_structure)
      self.assertDatasetProduces(revived_dataset, list(original_dataset))

  @staticmethod
  def make_apply_fn(dataset):

    def apply_fn(dataset):

      def _apply_fn(dataset):
        return dataset.cache()

      return dataset.apply(_apply_fn)

    return apply_fn

  @staticmethod
  def make_gen():

    def gen():
      yield 42

    return gen

  @staticmethod
  def make_interleave_fn(dataset, num_parallel_calls=None):

    def interleave_fn(dataset):
      return dataset.interleave(
          lambda x: dataset_ops.Dataset.range(0),
          cycle_length=2,
          num_parallel_calls=num_parallel_calls)

    return interleave_fn

  @parameterized.named_parameters(
      ("FixedLengthRecord",
       lambda: readers.FixedLengthRecordDataset("", 42)),
      ("FromGenerator",
       lambda: dataset_ops.Dataset.from_generator(
           DatasetTest.make_gen(), dtypes.int32),
       1),
      ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
      ("FromTensorSlices", lambda: dataset_ops.Dataset.from_tensors([42])),
      ("Range", lambda: dataset_ops.Dataset.range(10)),
      ("TextLine", lambda: readers.TextLineDataset("")),
      ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
  )
  def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
    self.assertLen(dataset_fn()._inputs(), num_inputs)

  @test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
  def testDatasetComplexSourceInputs(self):
    dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
        sparse_tensor.SparseTensor(
            indices=np.array([[0, 0], [1, 0], [2, 0]]),
            values=np.array([0, 0, 0]),
            dense_shape=np.array([3, 1])))
    self.assertEmpty(dataset_fn._inputs())

  @parameterized.named_parameters(
      ("Batch",
       lambda x: x.batch(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Cache",
       lambda x: x.cache(),
       lambda: dataset_ops.Dataset.range(0)),
      ("Filter",
       lambda x: x.filter(lambda x: True),
       lambda: dataset_ops.Dataset.range(0)),
      ("FlatMap",
       lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
       lambda: dataset_ops.Dataset.range(0)),
      ("Map",
       lambda x: x.map(lambda x: x),
       lambda: dataset_ops.Dataset.range(0)),
      ("PaddedBatch",
       lambda x: x.padded_batch(10, []),
       lambda: dataset_ops.Dataset.range(0)),
      ("ParallelMap",
       lambda x: x.map(lambda x: x, num_parallel_calls=2),
       lambda: dataset_ops.Dataset.range(0)),
      ("Repeat",
       lambda x: x.repeat(),
       lambda: dataset_ops.Dataset.range(0)),
      ("Shuffle",
       lambda x: x.shuffle(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Skip",
       lambda x: x.skip(1),
       lambda: dataset_ops.Dataset.range(0)),
      ("Take",
       lambda x: x.take(1),
       lambda: dataset_ops.Dataset.range(0)),
      ("Window",
       lambda x: x.window(10),
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
    input_dataset = input_dataset_fn()
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testUnaryTransformationInputsApply(self):
    input_dataset = dataset_ops.Dataset.range(0)
    dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("ParallelInterleave",
       [lambda: dataset_ops.Dataset.range(0), 2],
       lambda: dataset_ops.Dataset.range(0)),
      ("Interleave",
       [lambda: dataset_ops.Dataset.range(0), None],
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputsWithInterleaveFn(
      self, interleave_fn_args, input_dataset_fn):
    input_dataset = input_dataset_fn()
    dataset_fn = self.make_interleave_fn(*interleave_fn_args)
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testNoWarnings(self):
    with test.mock.patch.object(warnings, "warn") as mock_log:
      dataset_fn = self.make_interleave_fn(dataset_ops.Dataset.range(10))
      dataset_fn(dataset_ops.Dataset.range(10))
      self.assertEmpty(mock_log.call_args_list)

  @parameterized.named_parameters(
      ("Concatenate", lambda x, y: x.concatenate(y),
       lambda: dataset_ops.Dataset.range(0),
       lambda: dataset_ops.Dataset.range(1)))
  def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
    input1 = input1_fn()
    input2 = input2_fn()
    self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())

  @parameterized.named_parameters(
      ("ZipOne",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0))),
      ("ZipNest",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                (dataset_ops.Dataset.range(1),
                 dataset_ops.Dataset.range(2)))),
      ("ZipTuple",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                dataset_ops.Dataset.range(1))),
  )
  def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
    input_datasets = input_datasets_fn()
    self.assertEqual(
        nest.flatten(input_datasets),
        dataset_fn(input_datasets)._inputs())

  def testFunctions(self):
    dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
    self.assertLen(dataset._functions(), 1)

  def testCollectInputs(self):
    ds1 = dataset_ops.Dataset.range(0)
    ds2 = ds1.concatenate(ds1)
    ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

    inputs = []
    queue = [ds3]
    while queue:
      ds = queue[0]
      queue = queue[1:]
      queue.extend(ds._inputs())
      inputs.append(ds)

    self.assertEqual(5, inputs.count(ds1))
    self.assertEqual(2, inputs.count(ds2))
    self.assertEqual(1, inputs.count(ds3))

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1]),
       structure.SparseTensorStructure(dtypes.int32, [1])),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))},
       structure.NestedStructure({
           "a": structure.TensorStructure(dtypes.float32, []),
           "b": (structure.TensorStructure(dtypes.string, [1]),
                 structure.TensorStructure(dtypes.string, []))})),
      ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
          constant_op.constant([1, 2, 3])),
       dataset_ops.DatasetStructure(
           structure.TensorStructure(dtypes.int32, []))),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalStructure(
           structure.TensorStructure(dtypes.float32, []))),
  )
  def testDatasetStructure(self, tf_value_fn, expected_element_structure):
    dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value_fn())
    dataset_structure = structure.Structure.from_value(dataset)
    self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

    # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
    # the element structure.
    self.assertTrue(expected_element_structure.is_compatible_with(
        dataset_structure._element_structure))
    self.assertTrue(dataset_structure._element_structure.is_compatible_with(
        expected_element_structure))

    self.assertEqual([dtypes.variant], dataset_structure._flat_types)
    self.assertEqual([tensor_shape.scalar()], dataset_structure._flat_shapes)

    # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    round_trip_dataset = dataset_structure._from_tensor_list(
        dataset_structure._to_tensor_list(dataset))

    value = tf_value_fn()

    if isinstance(value, dataset_ops.Dataset):
      self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
    elif isinstance(value, optional_ops.Optional):
      self.assertDatasetProduces(
          round_trip_dataset.map(lambda opt: opt.get_value()),
          [self.evaluate(value.get_value())],
          requires_initialization=True)
    else:
      self.assertDatasetProduces(
          round_trip_dataset, [self.evaluate(tf_value_fn())],
          requires_initialization=True)

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
  def testSkipEagerSameGraphErrorOneShot(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        dataset = dataset.batch(2)

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
  def testSkipEagerSameGraphErrorOneShotSimple(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with test.mock.patch.object(logging, "warning") as mock_log:
        _ = dataset_ops.make_one_shot_iterator(dataset)
        self.assertRegexpMatches(
            str(mock_log.call_args), "Please ensure that all datasets in the "
            "pipeline are created in the same graph as the iterator.")

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
  def testSkipEagerSameGraphErrorInitializable(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        dataset = dataset.batch(2)

  @parameterized.named_parameters(
      ("Async", context.ASYNC),
      ("Sync", context.SYNC),
  )
  def testDatasetEagerIteration(self, execution_mode):
    with context.eager_mode(), context.execution_mode(execution_mode):
      val = 0
      dataset = dataset_ops.Dataset.range(10)
      for foo in dataset:
        self.assertEqual(val, foo.numpy())
        val += 1

  def testDatasetAsFunctionArgument(self):

    @def_function.function
    def _uses_dataset(d):
      accumulator = array_ops.zeros([], dtype=dtypes.int64)
      for value in d:
        accumulator += value
      return accumulator

    with ops.device("CPU"):
      first_dataset = dataset_ops.Dataset.range(10)
      self.assertEqual(45, self.evaluate(_uses_dataset(first_dataset)))
      second_dataset = dataset_ops.Dataset.range(11)
      self.assertEqual(55, self.evaluate(_uses_dataset(second_dataset)))
      first_concrete = _uses_dataset.get_concrete_function(first_dataset)
      # The dataset should not be a captured input
      self.assertEmpty(first_concrete.graph.captures)
      # The two datasets have the same structure and so should re-use a trace.
      self.assertIs(first_concrete,
                    _uses_dataset.get_concrete_function(second_dataset))
      # With a different structure we should use a different trace.
      self.assertIsNot(
          first_concrete,
          _uses_dataset.get_concrete_function(
              dataset_ops.Dataset.zip((first_dataset, second_dataset))))

  def testLimitedRetracing(self):
    trace_count = [0]

    @def_function.function
    def f(ds):
      trace_count[0] += 1
      counter = np.int64(0)
      for elem in ds:
        counter += elem
      return counter

    dataset = dataset_ops.Dataset.range(5)
    dataset2 = dataset_ops.Dataset.range(10)

    for _ in range(10):
      self.assertEqual(self.evaluate(f(dataset)), 10)
      self.assertEqual(self.evaluate(f(dataset2)), 45)
      self.assertEqual(trace_count[0], 1)
Code example #5
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

  def testAsSerializedGraph(self):
    dataset = dataset_ops.Dataset.range(10)
    graph = graph_pb2.GraphDef().FromString(
        self.evaluate(dataset._as_serialized_graph()))
    self.assertTrue(any([node.op != "RangeDataset" for node in graph.node]))

  @staticmethod
  def make_apply_fn(dataset):

    def apply_fn(dataset):

      def _apply_fn(dataset):
        return dataset.cache()

      return dataset.apply(_apply_fn)

    return apply_fn

  @staticmethod
  def make_gen():

    def gen():
      yield 42

    return gen

  @staticmethod
  def make_interleave_fn(dataset, num_parallel_calls=None):

    def interleave_fn(dataset):
      return dataset.interleave(
          lambda x: dataset_ops.Dataset.range(0),
          cycle_length=2,
          num_parallel_calls=num_parallel_calls)

    return interleave_fn

  @parameterized.named_parameters(
      ("FixedLengthRecord",
       lambda: readers.FixedLengthRecordDataset("", 42)),
      ("FromGenerator",
       lambda: dataset_ops.Dataset.from_generator(
           DatasetTest.make_gen(), dtypes.int32),
       1),
      ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
      ("FromTensorSlices", lambda: dataset_ops.Dataset.from_tensors([42])),
      ("Range", lambda: dataset_ops.Dataset.range(10)),
      ("TextLine", lambda: readers.TextLineDataset("")),
      ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
  )
  def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
    self.assertEqual(num_inputs, len(dataset_fn()._inputs()))

  def testDatasetComplexSourceInputs(self):
    dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
        sparse_tensor.SparseTensor(
            indices=np.array([[0, 0], [1, 0], [2, 0]]),
            values=np.array([0, 0, 0]),
            dense_shape=np.array([3, 1])))
    self.assertEqual(0, len(dataset_fn._inputs()))

  @parameterized.named_parameters(
      ("Batch",
       lambda x: x.batch(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Cache",
       lambda x: x.cache(),
       lambda: dataset_ops.Dataset.range(0)),
      ("Filter",
       lambda x: x.filter(lambda x: True),
       lambda: dataset_ops.Dataset.range(0)),
      ("FlatMap",
       lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
       lambda: dataset_ops.Dataset.range(0)),
      ("Map",
       lambda x: x.map(lambda x: x),
       lambda: dataset_ops.Dataset.range(0)),
      ("PaddedBatch",
       lambda x: x.padded_batch(10, []),
       lambda: dataset_ops.Dataset.range(0)),
      ("ParallelMap",
       lambda x: x.map(lambda x: x, num_parallel_calls=2),
       lambda: dataset_ops.Dataset.range(0)),
      ("Repeat",
       lambda x: x.repeat(),
       lambda: dataset_ops.Dataset.range(0)),
      ("Shuffle",
       lambda x: x.shuffle(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Skip",
       lambda x: x.skip(1),
       lambda: dataset_ops.Dataset.range(0)),
      ("Take",
       lambda x: x.take(1),
       lambda: dataset_ops.Dataset.range(0)),
      ("Window",
       lambda x: x.window(10),
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
    input_dataset = input_dataset_fn()
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testUnaryTransformationInputsApply(self):
    input_dataset = dataset_ops.Dataset.range(0)
    dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("ParallelInterleave",
       [lambda: dataset_ops.Dataset.range(0), 2],
       lambda: dataset_ops.Dataset.range(0)),
      ("Interleave",
       [lambda: dataset_ops.Dataset.range(0), None],
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputsWithInterleaveFn(
      self, interleave_fn_args, input_dataset_fn):
    input_dataset = input_dataset_fn()
    dataset_fn = self.make_interleave_fn(*interleave_fn_args)
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("Concatenate", lambda x, y: x.concatenate(y),
       lambda: dataset_ops.Dataset.range(0),
       lambda: dataset_ops.Dataset.range(1)))
  def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
    input1 = input1_fn()
    input2 = input2_fn()
    self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())

  @parameterized.named_parameters(
      ("ZipOne",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0))),
      ("ZipNest",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                (dataset_ops.Dataset.range(1),
                 dataset_ops.Dataset.range(2)))),
      ("ZipTuple",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                dataset_ops.Dataset.range(1))),
  )
  def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
    input_datasets = input_datasets_fn()
    self.assertEqual(
        nest.flatten(input_datasets),
        dataset_fn(input_datasets)._inputs())

  def testCollectInputs(self):
    ds1 = dataset_ops.Dataset.range(0)
    ds2 = ds1.concatenate(ds1)
    ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

    inputs = []
    queue = [ds3]
    while queue:
      ds = queue[0]
      queue = queue[1:]
      queue.extend(ds._inputs())
      inputs.append(ds)

    self.assertEqual(5, inputs.count(ds1))
    self.assertEqual(2, inputs.count(ds2))
    self.assertEqual(1, inputs.count(ds3))

  # TODO(b/119882922): use-after-free bug in eager mode.
  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1]),
       structure.SparseTensorStructure(dtypes.int32, [1])),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))},
       structure.NestedStructure({
           "a": structure.TensorStructure(dtypes.float32, []),
           "b": (structure.TensorStructure(dtypes.string, [1]),
                 structure.TensorStructure(dtypes.string, []))})),
      ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
          constant_op.constant([1, 2, 3])),
       dataset_ops.DatasetStructure(
           structure.TensorStructure(dtypes.int32, []))),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalStructure(
           structure.TensorStructure(dtypes.float32, []))),
  )
  def testSkipEagerDatasetStructure(self, tf_value_fn,
                                    expected_element_structure):
    dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value_fn())
    dataset_structure = structure.Structure.from_value(dataset)
    self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

    # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
    # the element structure.
    self.assertTrue(expected_element_structure.is_compatible_with(
        dataset_structure._element_structure))
    self.assertTrue(dataset_structure._element_structure.is_compatible_with(
        expected_element_structure))

    self.assertEqual([dtypes.variant], dataset_structure._flat_types)
    self.assertEqual([tensor_shape.scalar()], dataset_structure._flat_shapes)

    # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    round_trip_dataset = dataset_structure._from_tensor_list(
        dataset_structure._to_tensor_list(dataset))

    value = tf_value_fn()

    if isinstance(value, dataset_ops.Dataset):
      self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
    elif isinstance(value, optional_ops.Optional):
      self.assertDatasetProduces(
          round_trip_dataset.map(lambda opt: opt.get_value()),
          [self.evaluate(value.get_value())],
          requires_initialization=True)
    else:
      self.assertDatasetProduces(
          round_trip_dataset, [self.evaluate(tf_value_fn())],
          requires_initialization=True)

  @test_util.run_deprecated_v1
  def testSkipEagerSameGraphErrorOneShot(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      dataset = dataset.batch(2)
      with test.mock.patch.object(logging, "warning") as mock_log:
        _ = dataset.make_one_shot_iterator()
        self.assertRegexpMatches(
            str(mock_log.call_args), "Please ensure that all datasets in the "
            "pipeline are created in the same graph as the iterator.")

  @test_util.run_deprecated_v1
  def testSkipEagerSameGraphErrorOneShotSimple(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with test.mock.patch.object(logging, "warning") as mock_log:
        _ = dataset.make_one_shot_iterator()
        self.assertRegexpMatches(
            str(mock_log.call_args), "Please ensure that all datasets in the "
            "pipeline are created in the same graph as the iterator.")

  @test_util.run_deprecated_v1
  def testSkipEagerSameGraphErrorInitializable(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      dataset = dataset.batch(2)
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        _ = dataset.make_initializable_iterator()
Code example #6
class DatasetOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
    def testAsSerializedGraph(self):
        dataset = dataset_ops.Dataset.range(10)
        with self.cached_session() as sess:
            graph = graph_pb2.GraphDef().FromString(
                sess.run(dataset._as_serialized_graph()))
            self.assertTrue(
                any(node.op != "RangeDataset" for node in graph.node))

    @staticmethod
    def make_apply_fn(dataset):
        def apply_fn(dataset):
            def _apply_fn(dataset):
                return dataset.cache()

            return dataset.apply(_apply_fn)

        return apply_fn

    @staticmethod
    def make_gen():
        def gen():
            yield 42

        return gen

    @staticmethod
    def make_interleave_fn(dataset, num_parallel_calls=None):
        def interleave_fn(dataset):
            return dataset.interleave(lambda x: dataset_ops.Dataset.range(0),
                                      cycle_length=2,
                                      num_parallel_calls=num_parallel_calls)

        return interleave_fn

    @parameterized.named_parameters(
        ("FixedLengthRecord",
         lambda: readers.FixedLengthRecordDataset("", 42)),
        ("FromGenerator", lambda: dataset_ops.Dataset.from_generator(
            DatasetOpsTest.make_gen(), dtypes.int32), 1),
        ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
        ("FromTensorSlices", lambda: dataset_ops.Dataset.from_tensors([42])),
        ("Range", lambda: dataset_ops.Dataset.range(10)),
        ("TextLine", lambda: readers.TextLineDataset("")),
        ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
    )
    def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
        self.assertEqual(num_inputs, len(dataset_fn()._inputs()))

    def testDatasetComplexSourceInputs(self):
        dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
            sparse_tensor.SparseTensor(indices=np.array([[0, 0], [1, 0],
                                                         [2, 0]]),
                                       values=np.array([0, 0, 0]),
                                       dense_shape=np.array([3, 1])))
        self.assertEqual(0, len(dataset_fn._inputs()))

    @parameterized.named_parameters(
        ("Batch", lambda x: x.batch(10), lambda: dataset_ops.Dataset.range(0)),
        ("Cache", lambda x: x.cache(), lambda: dataset_ops.Dataset.range(0)),
        ("Filter", lambda x: x.filter(lambda x: True),
         lambda: dataset_ops.Dataset.range(0)),
        ("FlatMap",
         lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
         lambda: dataset_ops.Dataset.range(0)),
        ("Map", lambda x: x.map(lambda x: x),
         lambda: dataset_ops.Dataset.range(0)),
        ("PaddedBatch", lambda x: x.padded_batch(10, []),
         lambda: dataset_ops.Dataset.range(0)),
        ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
         lambda: dataset_ops.Dataset.range(0)),
        ("Repeat", lambda x: x.repeat(), lambda: dataset_ops.Dataset.range(0)),
        ("Shuffle", lambda x: x.shuffle(10),
         lambda: dataset_ops.Dataset.range(0)),
        ("Skip", lambda x: x.skip(1), lambda: dataset_ops.Dataset.range(0)),
        ("Take", lambda x: x.take(1), lambda: dataset_ops.Dataset.range(0)),
        ("Window", lambda x: x.window(10),
         lambda: dataset_ops.Dataset.range(0)),
    )
    def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
        input_dataset = input_dataset_fn()
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    def testUnaryTransformationInputsApply(self):
        input_dataset = dataset_ops.Dataset.range(0)
        dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    @parameterized.named_parameters(
        ("ParallelInterleave", [lambda: dataset_ops.Dataset.range(0), 2
                                ], lambda: dataset_ops.Dataset.range(0)),
        ("Interleave", [lambda: dataset_ops.Dataset.range(0), None
                        ], lambda: dataset_ops.Dataset.range(0)),
    )
    def testUnaryTransformationInputsWithInterleaveFn(self, interleave_fn_args,
                                                      input_dataset_fn):
        input_dataset = input_dataset_fn()
        dataset_fn = self.make_interleave_fn(*interleave_fn_args)
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    @parameterized.named_parameters(
        ("Concatenate", lambda x, y: x.concatenate(y),
         lambda: dataset_ops.Dataset.range(0),
         lambda: dataset_ops.Dataset.range(1)))
    def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
        input1 = input1_fn()
        input2 = input2_fn()
        self.assertEqual([input1, input2],
                         dataset_fn(input1, input2)._inputs())

    @parameterized.named_parameters(
        ("ZipOne", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0))),
        ("ZipNest", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0),
          (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))),
        ("ZipTuple", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))),
    )
    def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
        input_datasets = input_datasets_fn()
        self.assertEqual(nest.flatten(input_datasets),
                         dataset_fn(input_datasets)._inputs())

    def testCollectInputs(self):
        ds1 = dataset_ops.Dataset.range(0)
        ds2 = ds1.concatenate(ds1)
        ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

        inputs = []
        queue = [ds3]
        while queue:
            ds = queue[0]
            queue = queue[1:]
            queue.extend(ds._inputs())
            inputs.append(ds)

        self.assertEqual(5, inputs.count(ds1))
        self.assertEqual(2, inputs.count(ds2))
        self.assertEqual(1, inputs.count(ds3))

    def testOptionsDefault(self):
        ds = dataset_ops.Dataset.range(0)
        self.assertEqual(dataset_ops.Options(), ds.options())

    def testOptionsOnce(self):
        options = dataset_ops.Options()
        ds = dataset_ops.Dataset.range(0).with_options(options).cache()
        self.assertEqual(options, ds.options())

    def testOptionsTwiceSame(self):
        options = dataset_ops.Options()
        options.experimental_autotune = True
        ds = dataset_ops.Dataset.range(0).with_options(options).with_options(
            options)
        self.assertEqual(options, ds.options())

    def testOptionsTwiceDifferent(self):
        options1 = dataset_ops.Options()
        options1.experimental_autotune = True
        options2 = dataset_ops.Options()
        options2.experimental_filter_fusion = False
        ds = dataset_ops.Dataset.range(0).with_options(options1).with_options(
            options2)
        self.assertTrue(ds.options().experimental_autotune)
        # Explicitly check that flag is False since assertFalse allows None
        self.assertIs(ds.options().experimental_filter_fusion, False)

    def testOptionsTwiceDifferentError(self):
        options1 = dataset_ops.Options()
        options1.experimental_autotune = True
        options2 = dataset_ops.Options()
        options2.experimental_autotune = False
        with self.assertRaisesRegexp(ValueError,
                                     "Cannot merge incompatible values"):
            dataset_ops.Dataset.range(0).with_options(options1).with_options(
                options2)

    def testOptionsMergeOptionsFromMultipleInputs(self):
        options1 = dataset_ops.Options()
        options1.experimental_autotune = True
        options2 = dataset_ops.Options()
        options2.experimental_filter_fusion = True
        ds = dataset_ops.Dataset.zip(
            (dataset_ops.Dataset.range(0).with_options(options1),
             dataset_ops.Dataset.range(0).with_options(options2)))
        self.assertTrue(ds.options().experimental_autotune)
        self.assertTrue(ds.options().experimental_filter_fusion)

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0),
         structure.TensorStructure(dtypes.float32, [])),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[0]],
            values=constant_op.constant([0], dtype=dtypes.int32),
            dense_shape=[1]), structure.SparseTensorStructure(
                dtypes.int32, [1])),
        ("Nest", lambda: {
            "a": constant_op.constant(37.0),
            "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
        },
         structure.NestedStructure({
             "a":
             structure.TensorStructure(dtypes.float32, []),
             "b": (structure.TensorStructure(dtypes.string, [1]),
                   structure.TensorStructure(dtypes.string, []))
         })),
        ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
            constant_op.constant([1, 2, 3])),
         dataset_ops.DatasetStructure(
             structure.TensorStructure(dtypes.int32, []))),
        ("Optional", lambda: optional_ops.Optional.from_value(37.0),
         optional_ops.OptionalStructure(
             structure.TensorStructure(dtypes.float32, []))),
    )
    def testDatasetStructure(self, tf_value_fn, expected_element_structure):
        dataset = dataset_ops.Dataset.from_tensors(0).map(
            lambda _: tf_value_fn())
        dataset_structure = structure.Structure.from_value(dataset)
        self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

        # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
        # the element structure.
        self.assertTrue(
            expected_element_structure.is_compatible_with(
                dataset_structure._element_structure))
        self.assertTrue(
            dataset_structure._element_structure.is_compatible_with(
                expected_element_structure))

        self.assertEqual([dtypes.variant], dataset_structure._flat_types)
        self.assertEqual([tensor_shape.scalar()],
                         dataset_structure._flat_shapes)

        # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
        # and _to_tensor_list().
        round_trip_dataset = dataset_structure._from_tensor_list(
            dataset_structure._to_tensor_list(dataset))

        value = tf_value_fn()

        if isinstance(value, dataset_ops.Dataset):
            self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
        elif isinstance(value, optional_ops.Optional):
            self.assertDatasetProduces(
                round_trip_dataset.map(lambda opt: opt.get_value()),
                [self.evaluate(value.get_value())],
                requires_initialization=True)
        else:
            self.assertDatasetProduces(round_trip_dataset,
                                       [self.evaluate(tf_value_fn())],
                                       requires_initialization=True)