Example #1
    def __init__(self, path, element_spec, compression=None, reader_func=None):

        if reader_func is None:
            reader_func = lambda datasets: datasets.interleave(  # pylint:disable=g-long-lambda
                lambda x: x,
                cycle_length=multiprocessing.cpu_count(),
                num_parallel_calls=dataset_ops.AUTOTUNE)

        self._path = path
        self._element_spec = element_spec
        self._compression = compression

        self._reader_func = dataset_ops.StructuredFunctionWrapper(
            reader_func,
            "load()",
            # Dataset of datasets of input elements
            input_structure=dataset_ops.DatasetSpec(
                dataset_ops.DatasetSpec(element_spec)))

        variant_tensor = gen_experimental_dataset_ops.load_dataset(
            path,
            reader_func_other_args=self._reader_func.function.captured_inputs,
            compression=compression,
            reader_func=self._reader_func.function,
            **self._flat_structure)
        super(_LoadDataset, self).__init__(variant_tensor)
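The `_LoadDataset` class above backs the public `tf.data.experimental.load` API. As a hedged sketch of the same round trip through public symbols only (assuming TF 2.4+ where `tf.data.experimental.save`/`load` and `tf.data.AUTOTUNE` exist; the path is purely illustrative):

import multiprocessing
import tensorflow as tf

# Write a small dataset to disk, then read it back. The element_spec passed to
# load() corresponds to the `element_spec` argument of _LoadDataset above, and
# the custom reader_func mirrors the default interleave-based reader.
tf.data.experimental.save(tf.data.Dataset.range(10), "/tmp/saved_range")

restored = tf.data.experimental.load(
    "/tmp/saved_range",
    element_spec=tf.TensorSpec(shape=(), dtype=tf.int64),
    reader_func=lambda datasets: datasets.interleave(
        lambda x: x,
        cycle_length=multiprocessing.cpu_count(),
        num_parallel_calls=tf.data.AUTOTUNE))

print(list(restored.as_numpy_iterator()))  # the ten elements, order may differ across shards
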
Example #2
  def __init__(self, path, element_spec=None, compression=None,
               reader_func=None):

    if reader_func is None:
      reader_func = lambda datasets: datasets.interleave(  # pylint:disable=g-long-lambda
          lambda x: x,
          cycle_length=multiprocessing.cpu_count(),
          num_parallel_calls=dataset_ops.AUTOTUNE)

    self._path = path
    if element_spec is None:
      with gfile.GFile(os.path.join(path, DATASET_SPEC_FILENAME), "rb") as f:
        encoded_spec = f.read()
      struct_pb = nested_structure_coder.struct_pb2.StructuredValue()
      struct_pb.ParseFromString(encoded_spec)
      coder = nested_structure_coder.StructureCoder()
      spec = coder.decode_proto(struct_pb)
      self._element_spec = spec
    else:
      self._element_spec = element_spec
    self._compression = compression
    self._reader_func = dataset_ops.StructuredFunctionWrapper(
        reader_func,
        "load()",
        # Dataset of datasets of input elements
        input_structure=dataset_ops.DatasetSpec(
            dataset_ops.DatasetSpec(self._element_spec)))

    variant_tensor = gen_experimental_dataset_ops.load_dataset(
        path,
        reader_func_other_args=self._reader_func.function.captured_inputs,
        compression=compression,
        reader_func=self._reader_func.function,
        **self._flat_structure)
    super(_LoadDataset, self).__init__(variant_tensor)
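The only difference from Example #1 is the `element_spec=None` branch, which recovers the spec from the metadata written at save time. A minimal sketch of what this enables at the public API level (assuming a TF version, roughly 2.5+, where `element_spec` may be omitted from `tf.data.experimental.load`; the path is illustrative):

import tensorflow as tf

tf.data.experimental.save(tf.data.Dataset.range(5), "/tmp/saved_range")

# No element_spec: load() reads the encoded DatasetSpec from the saved
# directory, which is exactly the branch above that parses DATASET_SPEC_FILENAME.
restored = tf.data.experimental.load("/tmp/saved_range")
print(restored.element_spec)  # TensorSpec(shape=(), dtype=tf.int64, name=None)
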
Example #3
  def testDatasetSpecHierarchical(self):
    spec_1 = dataset_ops.DatasetSpec(
        tensor_spec.TensorSpec(shape=(1, None), dtype=dtypes.int32),
        [5, None, 2])
    spec_2 = dataset_ops.DatasetSpec(
        tensor_spec.TensorSpec(shape=(None, None), dtype=dtypes.int32),
        [None, None, None])
    spec_3 = dataset_ops.DatasetSpec(
        tensor_spec.TensorSpec(shape=(1, 2), dtype=dtypes.int32),
        [5, 3, 2])
    spec_4 = dataset_ops.DatasetSpec(
        tensor_spec.TensorSpec(shape=(None, 2), dtype=dtypes.int32),
        [None, 1, None])

    self.assertTrue(spec_1.is_subtype_of(spec_1))

    self.assertTrue(spec_1.is_subtype_of(spec_2))
    self.assertTrue(spec_3.is_subtype_of(spec_2))
    self.assertTrue(spec_4.is_subtype_of(spec_2))

    self.assertFalse(spec_2.is_subtype_of(spec_1))
    self.assertFalse(spec_2.is_subtype_of(spec_3))
    self.assertFalse(spec_2.is_subtype_of(spec_4))

    self.assertEqual(spec_1.most_specific_common_supertype([]), spec_1)
    self.assertEqual(spec_1.most_specific_common_supertype([spec_4]), spec_2)
    self.assertEqual(
        spec_1.most_specific_common_supertype([spec_3, spec_4]), spec_2)
    self.assertEqual(
        spec_1.most_specific_common_supertype([spec_2, spec_3, spec_4]), spec_2)
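The subtyping rules exercised above mirror shape compatibility: a spec with more fully known dimensions is a subtype of one with unknown dimensions. A minimal sketch with the public `tf.data.DatasetSpec` (assuming a TF version, roughly 2.9+, where TypeSpecs implement the TraceType protocol with `is_subtype_of`):

import tensorflow as tf

specific = tf.data.DatasetSpec(
    tf.TensorSpec(shape=(1, 2), dtype=tf.int32), [5, 3, 2])
general = tf.data.DatasetSpec(
    tf.TensorSpec(shape=(None, None), dtype=tf.int32), [None, None, None])

# Every dataset matching `specific` also matches `general`, but not vice versa.
assert specific.is_subtype_of(general)
assert not general.is_subtype_of(specific)
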
Example #4
    def __init__(self,
                 input_dataset,
                 path,
                 shard_func,
                 compression=None,
                 reader_func=None,
                 pending_snapshot_expiry_seconds=None,
                 use_legacy_function=False):

        if reader_func is None:
            reader_func = lambda datasets: datasets.interleave(  # pylint:disable=g-long-lambda
                lambda x: x,
                cycle_length=multiprocessing.cpu_count(),
                num_parallel_calls=dataset_ops.AUTOTUNE)

        self._input_dataset = input_dataset
        self._path = path
        self._compression = compression

        self._reader_func = dataset_ops.StructuredFunctionWrapper(
            reader_func,
            self._transformation_name() + ".reader_func",
            # Dataset of datasets of input elements
            input_structure=dataset_ops.DatasetSpec(
                dataset_ops.DatasetSpec(input_dataset.element_spec)),
            use_legacy_function=use_legacy_function)
        self._shard_func = dataset_ops.StructuredFunctionWrapper(
            shard_func,
            self._transformation_name() + ".shard_func",
            dataset=input_dataset,
            use_legacy_function=use_legacy_function)

        if ((not self._shard_func.output_structure.is_compatible_with(
                tensor_spec.TensorSpec([], dtypes.int32)))
                and (not self._shard_func.output_structure.is_compatible_with(
                    tensor_spec.TensorSpec([], dtypes.int64)))):
            raise TypeError(
                "shard_func must return a 0-dimension tensor containing an int."
            )

        variant_tensor = ged_ops.snapshot_dataset_v2(
            input_dataset._variant_tensor,  # pylint: disable=protected-access
            path,
            self._reader_func.function.captured_inputs,
            self._shard_func.function.captured_inputs,
            compression=compression,
            reader_func=self._reader_func.function,
            shard_func=self._shard_func.function,
            **self._flat_structure)
        super(_SnapshotDataset, self).__init__(input_dataset, variant_tensor)
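`_SnapshotDataset` backs the snapshot transformation. A hedged usage sketch (assuming TF 2.4+, where `tf.data.experimental.snapshot` accepts `reader_func` and `shard_func`; the path is illustrative) with a `shard_func` that satisfies the scalar-int check above:

import multiprocessing
import tensorflow as tf

dataset = tf.data.Dataset.range(100)

# shard_func must return a scalar int32/int64 (see the type check above); here
# it spreads elements across 4 shards. reader_func receives a dataset of the
# per-shard datasets and decides how to merge them back together.
dataset = dataset.apply(
    tf.data.experimental.snapshot(
        "/tmp/snapshot_dir",
        compression="GZIP",
        shard_func=lambda x: x % 4,
        reader_func=lambda datasets: datasets.interleave(
            lambda x: x,
            cycle_length=multiprocessing.cpu_count(),
            num_parallel_calls=tf.data.AUTOTUNE)))
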
Example #5
  def testDatasetSpecConstructor(self):
    rt_spec = ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32)
    st_spec = sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32)
    t_spec = tensor_spec.TensorSpec([10, 8], dtypes.string)
    element_spec = {"rt": rt_spec, "st": st_spec, "t": t_spec}
    ds_struct = dataset_ops.DatasetSpec(element_spec, [5])
    self.assertEqual(ds_struct._element_spec, element_spec)
    # Note: shape was automatically converted from a list to a TensorShape.
    self.assertEqual(ds_struct._dataset_shape, tensor_shape.TensorShape([5]))
Example #6
  def testEncodeDataSetSpec(self):
    structure = [dataset_ops.DatasetSpec(
        {"rt": ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32),
         "st": sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32),
         "t": tensor_spec.TensorSpec([10, 8], dtypes.string)})]
    self.assertTrue(self._coder.can_encode(structure))
    encoded = self._coder.encode_structure(structure)
    decoded = self._coder.decode_proto(encoded)
    self.assertEqual(structure, decoded)
Example #7
  def testDatasetSpecTraceType(self):
    trace_type_1 = dataset_ops.DatasetSpec(
        tensor_spec.TensorSpec(shape=(), dtype=dtypes.int32),
        [5])
    trace_type_2 = dataset_ops.DatasetSpec(
        tensor_spec.TensorSpec(shape=(), dtype=dtypes.int32),
        [5])

    self.assertEqual(trace_type_1, trace_type_2)
    self.assertEqual(hash(trace_type_1), hash(trace_type_2))
    self.assertTrue(trace_type_1.is_subtype_of(trace_type_2))
    self.assertTrue(trace_type_2.is_subtype_of(trace_type_1))

    trace_type_3 = dataset_ops.DatasetSpec(
        tensor_spec.TensorSpec(shape=(), dtype=dtypes.int32),
        [6])
    self.assertNotEqual(trace_type_1, trace_type_3)
    self.assertFalse(trace_type_1.is_subtype_of(trace_type_3))
    self.assertFalse(trace_type_3.is_subtype_of(trace_type_1))
Example #8
  def testInputSignature(self):
    dataset = dataset_ops.Dataset.from_tensor_slices(
        np.arange(10).astype(np.int32)).batch(5)

    @def_function.function(input_signature=[
        dataset_ops.DatasetSpec(
            tensor_spec.TensorSpec(
                shape=(None,), dtype=dtypes.int32, name=None),
            tensor_shape.TensorShape([]))
    ])
    def fn(_):
      pass

    fn(dataset)
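The same pattern with only public symbols: a `tf.data.DatasetSpec` in the `input_signature` lets the traced function accept any dataset whose elements match the inner `TensorSpec`. A minimal sketch (the function name and body are illustrative):

import numpy as np
import tensorflow as tf

@tf.function(input_signature=[
    tf.data.DatasetSpec(tf.TensorSpec(shape=(None,), dtype=tf.int32))
])
def total(ds):
  acc = tf.constant(0, dtype=tf.int32)
  for batch in ds:
    acc += tf.reduce_sum(batch)
  return acc

dataset = tf.data.Dataset.from_tensor_slices(
    np.arange(10).astype(np.int32)).batch(5)
print(total(dataset).numpy())  # 45
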
Example #9
  def _make_reduce_func(self, reduce_func, input_dataset):
    """Make wrapping defun for reduce_func."""
    nested_dataset = dataset_ops.DatasetSpec(input_dataset.element_spec)
    input_structure = (tensor_spec.TensorSpec([], dtypes.int64),
                       nested_dataset)
    self._reduce_func = dataset_ops.StructuredFunctionWrapper(
        reduce_func,
        self._transformation_name(),
        input_structure=input_structure)
    if not isinstance(self._reduce_func.output_structure,
                      dataset_ops.DatasetSpec):
      raise TypeError("`reduce_func` must return a `Dataset` object.")
    # pylint: disable=protected-access
    self._element_spec = (self._reduce_func.output_structure._element_spec)
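`_make_reduce_func` enforces that the reduce function maps `(key, window_dataset)` to another `Dataset`. A hedged sketch of the public `group_by_window` transformation that this wrapper supports (assuming `tf.data.experimental.group_by_window`, available across many TF 2.x releases):

import tensorflow as tf

# Batches even and odd numbers separately; reduce_func must return a Dataset,
# otherwise the TypeError raised above fires.
dataset = tf.data.Dataset.range(10).apply(
    tf.data.experimental.group_by_window(
        key_func=lambda x: x % 2,
        reduce_func=lambda key, window: window.batch(5),
        window_size=5))

for batch in dataset:
  print(batch.numpy())  # [0 2 4 6 8] and [1 3 5 7 9]
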
Example #10
  def testDatasetSpecHierarchicalDict(self):
    spec_1 = dataset_ops.DatasetSpec(
        {"a": tensor_spec.TensorSpec(shape=(1, None), dtype=dtypes.int32)},
        [])
    spec_2 = dataset_ops.DatasetSpec(
        {"a": tensor_spec.TensorSpec(shape=(None, None), dtype=dtypes.int32)},
        [])
    spec_3 = dataset_ops.DatasetSpec(
        {"b": tensor_spec.TensorSpec(shape=(1, None), dtype=dtypes.int32)},
        [])
    spec_4 = dataset_ops.DatasetSpec({"b": None}, [])

    self.assertTrue(spec_1.is_subtype_of(spec_1))
    self.assertTrue(spec_1.is_subtype_of(spec_2))
    self.assertFalse(spec_2.is_subtype_of(spec_1))

    self.assertFalse(spec_1.is_subtype_of(spec_3))
    self.assertFalse(spec_3.is_subtype_of(spec_1))
    self.assertFalse(spec_2.is_subtype_of(spec_3))
    self.assertFalse(spec_3.is_subtype_of(spec_2))

    self.assertTrue(spec_4.is_subtype_of(spec_4))
    self.assertEqual(spec_4.most_specific_common_supertype([]), spec_4)
    self.assertEqual(spec_4.most_specific_common_supertype([spec_4]), spec_4)
Example #11
  def testDatasetDatasetSpec(self):
    self.checkDatasetSpec(
        dataset_ops.Dataset.from_tensor_slices(
            constant_op.constant([1, 2, 3])),
        dataset_ops.DatasetSpec(tensor_spec.TensorSpec([], dtypes.int32)))
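For reference, the spec checked above can also be built directly from a dataset via the public classmethod `tf.data.DatasetSpec.from_value`; a minimal sketch:

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(tf.constant([1, 2, 3]))
spec = tf.data.DatasetSpec.from_value(dataset)
print(spec == tf.data.DatasetSpec(
    tf.TensorSpec(shape=(), dtype=tf.int32)))  # True
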
Example #12
    def __init__(self,
                 input_dataset,
                 functions,
                 ratio_numerator=1,
                 ratio_denominator=1,
                 num_elements_per_branch=None):
        """Chooses the fastest of some dataset functions.

    Given dataset functions that take input_dataset as input and output
    another dataset, produces elements as quickly as the fastest of these
    output datasets. The datasets built by the functions are assumed to be
    stateless, so iterators created from their output datasets will, given the
    same input elements, all produce the same output elements. Each function's
    dataset is also expected to iterate over the input dataset at most once.
    Violating these conditions may lead to undefined behavior.

    For example:
    ```python
    dataset = tf.data.Dataset.range(100)
    dataset = _ChooseFastestBranchDataset(
        dataset,
        [
            lambda ds: ds.map(lambda x: tf.reshape(x, [1])).batch(10),
            lambda ds: ds.batch(10).map(lambda x: tf.reshape(x, [10, 1]))
        ],
        ratio_numerator=10,
        num_elements_per_branch=10
    )
    ```
    The resulting dataset will produce elements equivalent to
    `tf.data.Dataset.range(100).map(lambda x: tf.reshape(x, [1])).batch(10)`, or
    `tf.data.Dataset.range(100).batch(10).map(lambda x: tf.reshape(x, [10, 1]))`

    Note that the first `num_elements_per_branch` iterations may be slower due
    to the overhead of dynamically picking the fastest dataset. During these
    iterations, the dataset produces elements from any of the branches in order
    to determine which input is the fastest. For all subsequent iterations,
    that input is used.

    Args:
      input_dataset: A `Dataset` that can be used as input to `functions`.
      functions: A list of callables, each of which takes a `Dataset` as input
        and returns a `Dataset`.
      ratio_numerator: The numerator in the ratio of input elements consumed to
        output elements produced for each function. This should be the same for
        all functions. For example, if the function is
        `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input dataset
          must produce 10 elements for every element of the output dataset. In
          this case, ratio_numerator should be 10.
      ratio_denominator: The denominator in the ratio of input elements consumed
        to output elements produced for each function. This should be the same
        for all functions. For example, if the function is
        `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input dataset
          must produce 10 elements for every element of the output dataset. In
          this case, ratio_denominator should be 1.
      num_elements_per_branch: The number of elements to draw from each branch
        before deciding which dataset is fastest. In the first len(functions) *
        num_elements_per_branch iterations, the dataset pulls elements from one
        of the branches and updates its estimate of which input is the fastest.
        Note that (num_elements_per_branch * ratio) is expected to be an
        integer.

    Returns:
      A `Dataset` that has the same elements as the inputs.
    """
        input_structure = dataset_ops.DatasetSpec(input_dataset.element_spec)
        self._funcs = [
            dataset_ops.StructuredFunctionWrapper(
                f, "ChooseFastestV2", input_structure=input_structure)
            for f in functions
        ]
        self._element_spec = self._funcs[0].output_structure._element_spec  # pylint: disable=protected-access

        self._captured_arguments = []
        for f in self._funcs:
            self._captured_arguments.extend(f.function.captured_inputs)
        self._capture_lengths = [
            len(f.function.captured_inputs) for f in self._funcs
        ]

        if ratio_numerator <= 0 or ratio_denominator <= 0:
            raise ValueError("ratio must be positive.")

        if num_elements_per_branch is None:
            # Pick a sensible default based on `ratio_denominator`
            num_elements_per_branch = 10 * ratio_denominator

        variant_tensor = (
            gen_experimental_dataset_ops.choose_fastest_branch_dataset(
                input_dataset._variant_tensor,  # pylint: disable=protected-access
                ratio_numerator=ratio_numerator,
                ratio_denominator=ratio_denominator,
                other_arguments=self._captured_arguments,
                num_elements_per_branch=num_elements_per_branch,
                branches=[f.function for f in self._funcs],
                other_arguments_lengths=self._capture_lengths,
                **self._flat_structure))
        super(_ChooseFastestBranchDataset,
              self).__init__(input_dataset, variant_tensor)
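A small worked sketch of the ratio arguments for the example in the docstring above (both branches turn ten input elements into one output batch); the names are just local variables, not library API:

ratio_numerator = 10    # input elements consumed per output element
ratio_denominator = 1   # output elements produced
num_elements_per_branch = 10 * ratio_denominator  # the default chosen above

# The docstring requires num_elements_per_branch * ratio to be an integer:
# 10 * (10 / 1) == 100, so the constraint holds.
assert (num_elements_per_branch * ratio_numerator) % ratio_denominator == 0
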
Example #13
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
    def testAsSerializedGraph(self):
        dataset = dataset_ops.Dataset.range(10)
        graph = graph_pb2.GraphDef().FromString(
            self.evaluate(dataset._as_serialized_graph()))
        self.assertTrue(any(node.op == "RangeDataset"
                            for node in graph.node))

    def testAsFunctionWithMap(self):
        if not context.executing_eagerly():
            self.skipTest("Only works executing eagerly")
        with ops.device("CPU"):
            original_dataset = dataset_ops.Dataset.range(5).map(
                lambda x: x * 2)
            fn = original_dataset._trace_variant_creation()
            variant = fn()

            revived_dataset = dataset_ops._VariantDataset(
                variant, original_dataset.element_spec)
            self.assertDatasetProduces(revived_dataset, range(0, 10, 2))

    def testAsFunctionWithMapInFlatMap(self):
        if not context.executing_eagerly():
            self.skipTest("Only works executing eagerly")
        with ops.device("CPU"):
            original_dataset = dataset_ops.Dataset.range(5).flat_map(
                lambda x: dataset_ops.Dataset.range(5).map(lambda x: x * 2))
            fn = original_dataset._trace_variant_creation()
            variant = fn()

            revived_dataset = dataset_ops._VariantDataset(
                variant, original_dataset.element_spec)
            self.assertDatasetProduces(revived_dataset, list(original_dataset))

    @staticmethod
    def make_apply_fn(dataset):
        def apply_fn(dataset):
            def _apply_fn(dataset):
                return dataset.cache()

            return dataset.apply(_apply_fn)

        return apply_fn

    @staticmethod
    def make_gen():
        def gen():
            yield 42

        return gen

    @staticmethod
    def make_interleave_fn(dataset, num_parallel_calls=None):
        def interleave_fn(dataset):
            return dataset.interleave(lambda x: dataset_ops.Dataset.range(0),
                                      cycle_length=2,
                                      num_parallel_calls=num_parallel_calls)

        return interleave_fn

    @parameterized.named_parameters(
        ("FixedLengthRecord",
         lambda: readers.FixedLengthRecordDataset("", 42)),
        ("FromGenerator", lambda: dataset_ops.Dataset.from_generator(
            DatasetTest.make_gen(), dtypes.int32), 1),
        ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
        ("FromTensorSlices", lambda: dataset_ops.Dataset.from_tensors([42])),
        ("Range", lambda: dataset_ops.Dataset.range(10)),
        ("TextLine", lambda: readers.TextLineDataset("")),
        ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
    )
    def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
        self.assertLen(dataset_fn()._inputs(), num_inputs)

    @test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
    def testDatasetComplexSourceInputs(self):
        dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
            sparse_tensor.SparseTensor(indices=np.array([[0, 0], [1, 0],
                                                         [2, 0]]),
                                       values=np.array([0, 0, 0]),
                                       dense_shape=np.array([3, 1])))
        self.assertEmpty(dataset_fn._inputs())

    @parameterized.named_parameters(
        ("Batch", lambda x: x.batch(10), lambda: dataset_ops.Dataset.range(0)),
        ("Cache", lambda x: x.cache(), lambda: dataset_ops.Dataset.range(0)),
        ("Filter", lambda x: x.filter(lambda x: True),
         lambda: dataset_ops.Dataset.range(0)),
        ("FlatMap",
         lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
         lambda: dataset_ops.Dataset.range(0)),
        ("Map", lambda x: x.map(lambda x: x),
         lambda: dataset_ops.Dataset.range(0)),
        ("PaddedBatch", lambda x: x.padded_batch(10, []),
         lambda: dataset_ops.Dataset.range(0)),
        ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
         lambda: dataset_ops.Dataset.range(0)),
        ("Repeat", lambda x: x.repeat(), lambda: dataset_ops.Dataset.range(0)),
        ("Shuffle", lambda x: x.shuffle(10),
         lambda: dataset_ops.Dataset.range(0)),
        ("Skip", lambda x: x.skip(1), lambda: dataset_ops.Dataset.range(0)),
        ("Take", lambda x: x.take(1), lambda: dataset_ops.Dataset.range(0)),
        ("Window", lambda x: x.window(10),
         lambda: dataset_ops.Dataset.range(0)),
    )
    def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
        input_dataset = input_dataset_fn()
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    def testUnaryTransformationInputsApply(self):
        input_dataset = dataset_ops.Dataset.range(0)
        dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    @parameterized.named_parameters(
        ("ParallelInterleave", [lambda: dataset_ops.Dataset.range(0), 2
                                ], lambda: dataset_ops.Dataset.range(0)),
        ("Interleave", [lambda: dataset_ops.Dataset.range(0), None
                        ], lambda: dataset_ops.Dataset.range(0)),
    )
    def testUnaryTransformationInputsWithInterleaveFn(self, interleave_fn_args,
                                                      input_dataset_fn):
        input_dataset = input_dataset_fn()
        dataset_fn = self.make_interleave_fn(*interleave_fn_args)
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    def testNoWarnings(self):
        with test.mock.patch.object(warnings, "warn") as mock_log:
            dataset_fn = self.make_interleave_fn(dataset_ops.Dataset.range(10))
            dataset_fn(dataset_ops.Dataset.range(10))
            self.assertEmpty(mock_log.call_args_list)

    @parameterized.named_parameters(
        ("Concatenate", lambda x, y: x.concatenate(y),
         lambda: dataset_ops.Dataset.range(0),
         lambda: dataset_ops.Dataset.range(1)))
    def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
        input1 = input1_fn()
        input2 = input2_fn()
        self.assertEqual([input1, input2],
                         dataset_fn(input1, input2)._inputs())

    @parameterized.named_parameters(
        ("ZipOne", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0))),
        ("ZipNest", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0),
          (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))),
        ("ZipTuple", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))),
    )
    def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
        input_datasets = input_datasets_fn()
        self.assertEqual(nest.flatten(input_datasets),
                         dataset_fn(input_datasets)._inputs())

    def testFunctions(self):
        dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
        self.assertLen(dataset._functions(), 1)

    def testCollectInputs(self):
        ds1 = dataset_ops.Dataset.range(0)
        ds2 = ds1.concatenate(ds1)
        ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

        inputs = []
        queue = [ds3]
        while queue:
            ds = queue[0]
            queue = queue[1:]
            queue.extend(ds._inputs())
            inputs.append(ds)

        self.assertEqual(5, inputs.count(ds1))
        self.assertEqual(2, inputs.count(ds2))
        self.assertEqual(1, inputs.count(ds3))

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0),
         tensor_spec.TensorSpec([], dtypes.float32)),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[0]],
            values=constant_op.constant([0], dtype=dtypes.int32),
            dense_shape=[1]), sparse_tensor.SparseTensorSpec([1],
                                                             dtypes.int32)),
        ("Nest", lambda: {
            "a": constant_op.constant(37.0),
            "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
        }, {
            "a":
            tensor_spec.TensorSpec([], dtypes.float32),
            "b": (
                tensor_spec.TensorSpec([1], dtypes.string),
                tensor_spec.TensorSpec([], dtypes.string),
            )
        }),
        ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
            constant_op.constant([1, 2, 3])),
         dataset_ops.DatasetSpec(tensor_spec.TensorSpec([], dtypes.int32))),
        ("Optional", lambda: optional_ops.Optional.from_value(37.0),
         optional_ops.OptionalSpec(tensor_spec.TensorSpec([],
                                                          dtypes.float32))),
    )
    def testDatasetSpec(self, tf_value_fn, expected_element_structure):
        dataset = dataset_ops.Dataset.from_tensors(0).map(
            lambda _: tf_value_fn())
        dataset_structure = structure.type_spec_from_value(dataset)
        self.assertIsInstance(dataset_structure, dataset_ops.DatasetSpec)

        self.assertTrue(
            structure.are_compatible(dataset_ops.get_structure(dataset),
                                     expected_element_structure))
        self.assertEqual([dtypes.variant],
                         structure.get_flat_tensor_types(dataset_structure))
        self.assertEqual([tensor_shape.scalar()],
                         structure.get_flat_tensor_shapes(dataset_structure))

        # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
        # and _to_tensor_list().
        round_trip_dataset = dataset_structure._from_tensor_list(
            dataset_structure._to_tensor_list(dataset))

        value = tf_value_fn()

        if isinstance(value, dataset_ops.Dataset):
            self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
        elif isinstance(value, optional_ops.Optional):
            self.assertDatasetProduces(
                round_trip_dataset.map(lambda opt: opt.get_value()),
                [self.evaluate(value.get_value())],
                requires_initialization=True)
        else:
            self.assertDatasetProduces(round_trip_dataset,
                                       [self.evaluate(tf_value_fn())],
                                       requires_initialization=True)

    @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
    def testSkipEagerSameGraphErrorOneShot(self):
        dataset = dataset_ops.Dataset.range(10)
        with ops.Graph().as_default():
            with self.assertRaisesRegexp(ValueError,
                                         "must be from the same graph"):
                dataset = dataset.batch(2)

    @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
    def testSkipEagerSameGraphErrorOneShotSimple(self):
        dataset = dataset_ops.Dataset.range(10)
        with ops.Graph().as_default():
            with test.mock.patch.object(logging, "warning") as mock_log:
                _ = dataset_ops.make_one_shot_iterator(dataset)
                self.assertRegexpMatches(
                    str(mock_log.call_args),
                    "Please ensure that all datasets in the "
                    "pipeline are created in the same graph as the iterator.")

    @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
    def testSkipEagerSameGraphErrorInitializable(self):
        dataset = dataset_ops.Dataset.range(10)
        with ops.Graph().as_default():
            with self.assertRaisesRegexp(ValueError,
                                         "must be from the same graph"):
                dataset = dataset.batch(2)

    @parameterized.named_parameters(
        ("Async", context.ASYNC),
        ("Sync", context.SYNC),
    )
    def testDatasetEagerIteration(self, execution_mode):
        with context.eager_mode(), context.execution_mode(execution_mode):
            val = 0
            dataset = dataset_ops.Dataset.range(10)
            for foo in dataset:
                self.assertEqual(val, foo.numpy())
                val += 1

    def testDatasetAsFunctionArgument(self):
        @def_function.function
        def _uses_dataset(d):
            accumulator = array_ops.zeros([], dtype=dtypes.int64)
            for value in d:
                accumulator += value
            return accumulator

        with ops.device("CPU"):
            first_dataset = dataset_ops.Dataset.range(10)
            self.assertEqual(45, self.evaluate(_uses_dataset(first_dataset)))
            second_dataset = dataset_ops.Dataset.range(11)
            self.assertEqual(55, self.evaluate(_uses_dataset(second_dataset)))
            first_concrete = _uses_dataset.get_concrete_function(first_dataset)
            # The dataset should not be a captured input
            self.assertEmpty(first_concrete.graph.captures)
            # The two datasets have the same structure and so should re-use a trace.
            self.assertIs(first_concrete,
                          _uses_dataset.get_concrete_function(second_dataset))
            # With a different structure we should use a different trace.
            self.assertIsNot(
                first_concrete,
                _uses_dataset.get_concrete_function(
                    dataset_ops.Dataset.zip((first_dataset, second_dataset))))

    def testLimitedRetracing(self):
        trace_count = [0]

        @def_function.function
        def f(ds):
            trace_count[0] += 1
            counter = np.int64(0)
            for elem in ds:
                counter += elem
            return counter

        dataset = dataset_ops.Dataset.range(5)
        dataset2 = dataset_ops.Dataset.range(10)

        for _ in range(10):
            self.assertEqual(self.evaluate(f(dataset)), 10)
            self.assertEqual(self.evaluate(f(dataset2)), 45)
            self.assertEqual(trace_count[0], 1)
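The retracing behaviour tested above follows directly from `DatasetSpec`: datasets with the same element spec produce the same trace signature. A minimal sketch using public APIs only:

import tensorflow as tf

trace_count = [0]

@tf.function
def sum_elements(ds):
  trace_count[0] += 1  # Python side effect: runs only when tracing
  total = tf.constant(0, dtype=tf.int64)
  for elem in ds:
    total += elem
  return total

print(sum_elements(tf.data.Dataset.range(5)).numpy())   # 10
print(sum_elements(tf.data.Dataset.range(10)).numpy())  # 45
print(trace_count[0])  # 1 -- same DatasetSpec, so the trace is reused
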
Example #14
  def testDatasetSpecInnerSpec(self):
    inner_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.int32)
    ds_spec = dataset_ops.DatasetSpec(inner_spec)
    self.assertEqual(ds_spec.element_spec, inner_spec)