示例#1
0
class UniqueTest(test_base.DatasetTestBase, parameterized.TestCase):
    """Tests for the `unique()` dataset transformation."""

    def _testSimpleHelper(self, dtype, test_cases):
        """Runs `unique()` over each test input and checks what it produces.

        Args:
          dtype: The `dtype` of the elements in each test case.
          test_cases: A list of pairs of lists. The first component is the
            input fed to the transformation; the second component is the
            sequence of outputs the transformation is expected to produce.
        """
        # The generator lambda closes over the `pending_input` variable itself
        # (not its current value), so rebinding it in the loop below changes
        # what the single dataset built here yields the next time it is run.
        pending_input = []
        source = dataset_ops.Dataset.from_generator(lambda: pending_input,
                                                    dtype)
        dataset = source.unique()

        for case_input, case_output in test_cases:
            pending_input = case_input
            if dtype == dtypes.string:
                # Expected string elements are converted to bytes before the
                # comparison.
                want = [compat.as_bytes(element) for element in case_output]
            else:
                want = list(case_output)
            self.assertDatasetProduces(dataset, want)

    @combinations.generate(test_base.graph_only_combinations())
    def testSimpleInt(self):
        """Runs the integer test table for both int32 and int64."""
        cases = [
            ([], []),
            ([1], [1]),
            ([1, 1, 1, 1, 1, 1, 1], [1]),
            ([1, 1, 1, 1, 0], [1, 0]),
            ([1, 2, 3, 4], [1, 2, 3, 4]),
            ([1, 2, 4, 3, 2, 1, 2, 3, 4], [1, 2, 4, 3]),
            ([[1], [1, 1], [1, 1, 1]], [[1], [1, 1], [1, 1, 1]]),
            ([[1, 1], [1, 1], [2, 2], [3, 3], [1, 1]],
             [[1, 1], [2, 2], [3, 3]]),
        ]
        for dtype in (dtypes.int32, dtypes.int64):
            self._testSimpleHelper(dtype, cases)

    @combinations.generate(test_base.graph_only_combinations())
    def testSimpleString(self):
        """Runs the string test table."""
        cases = [
            ([], []),
            (["hello"], ["hello"]),
            (["hello", "hello", "hello"], ["hello"]),
            (["hello", "world"], ["hello", "world"]),
            (["foo", "bar", "baz", "baz", "bar", "foo"],
             ["foo", "bar", "baz"]),
        ]
        self._testSimpleHelper(dtypes.string, cases)

    @combinations.generate(test_base.graph_only_combinations())
    def testUnsupportedTypes(self):
        """Expects a `TypeError` from `unique()` for each listed dtype."""
        unsupported = (
            dtypes.bool, dtypes.double, dtypes.complex64, dtypes.float32,
            dtypes.float64, dtypes.qint16, dtypes.qint32,
        )
        for dtype in unsupported:
            with self.assertRaises(TypeError):
                dataset_ops.Dataset.from_generator(lambda: [], dtype).unique()
示例#2
0
class LenTest(test_base.DatasetTestBase, parameterized.TestCase):
    """Tests for calling `len()` on a dataset."""

    @combinations.generate(test_base.eager_only_combinations())
    def testKnown(self):
        """A plain range dataset has the length it was built with."""
        num_elements = 10
        dataset = dataset_ops.Dataset.range(num_elements)
        self.assertLen(dataset, num_elements)

    @combinations.generate(test_base.eager_only_combinations())
    def testInfinite(self):
        """`len()` of a `repeat()`-ed range raises an "infinite" TypeError."""
        num_elements = 10
        dataset = dataset_ops.Dataset.range(num_elements)
        dataset = dataset.repeat()
        with self.assertRaisesRegex(TypeError, "infinite"):
            len(dataset)

    @combinations.generate(test_base.eager_only_combinations())
    def testUnknown(self):
        """`len()` of a filtered range raises an "unknown" TypeError."""
        num_elements = 10
        dataset = dataset_ops.Dataset.range(num_elements)
        dataset = dataset.filter(lambda x: True)
        with self.assertRaisesRegex(TypeError, "unknown"):
            len(dataset)

    @combinations.generate(test_base.graph_only_combinations())
    def testGraphMode(self):
        """In graph mode `len()` raises and points at `cardinality()`."""
        num_elements = 10
        dataset = dataset_ops.Dataset.range(num_elements)
        expected_message = (
            r"`tf.data.Dataset` only supports `len` in eager mode. Use "
            r"`tf.data.Dataset.cardinality\(\)` instead.")
        with self.assertRaisesRegex(TypeError, expected_message):
            len(dataset)
示例#3
0
文件: len_test.py 项目: MFChunga/poo
class LenTest(test_base.DatasetTestBase, parameterized.TestCase):
    """Checks the behaviour of the built-in `len()` applied to datasets."""

    @combinations.generate(test_base.eager_only_combinations())
    def testKnown(self):
        """The length of a range dataset equals its element count."""
        num_elements = 10
        self.assertLen(dataset_ops.Dataset.range(num_elements), num_elements)

    @combinations.generate(test_base.eager_only_combinations())
    def testInfinite(self):
        """A range followed by `repeat()` makes `len()` raise TypeError."""
        num_elements = 10
        repeated = dataset_ops.Dataset.range(num_elements).repeat()
        with self.assertRaisesRegex(TypeError, 'infinite'):
            len(repeated)

    @combinations.generate(test_base.eager_only_combinations())
    def testUnknown(self):
        """A range followed by `filter()` makes `len()` raise TypeError."""
        num_elements = 10
        filtered = dataset_ops.Dataset.range(num_elements).filter(
            lambda x: True)
        with self.assertRaisesRegex(TypeError, 'unknown'):
            len(filtered)

    @combinations.generate(test_base.graph_only_combinations())
    def testGraphMode(self):
        """In graph mode `len()` raises a "while tracing" TypeError."""
        num_elements = 10
        ranged = dataset_ops.Dataset.range(num_elements)
        with self.assertRaisesRegex(TypeError, 'not supported while tracing'):
            len(ranged)
示例#4
0
class RandomSeedTest(test_base.DatasetTestBase, parameterized.TestCase):
    """Tests for `data_random_seed.get_seed`."""

    def _checkEqual(self, tinput, toutput):
        """Sets the global seed, derives the seed pair and compares it.

        Args:
          tinput: A pair whose first component is passed to
            `random_seed.set_random_seed` and whose second component is
            passed to `data_random_seed.get_seed`.
          toutput: The pair the evaluated `get_seed` result must equal.
        """
        random_seed.set_random_seed(tinput[0])
        g_seed, op_seed = data_random_seed.get_seed(tinput[1])
        g_seed = self.evaluate(g_seed)
        op_seed = self.evaluate(op_seed)
        msg = "test_case = {0}, got {1}, want {2}".format(
            tinput, (g_seed, op_seed), toutput)
        self.assertEqual((g_seed, op_seed), toutput, msg=msg)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           _test_random_seed_combinations()))
    def testRandomSeed(self, input_fn, output_fn):
        """Checks one generated (input, expected output) seed combination.

        Args:
          input_fn: Zero-argument callable returning the `tinput` pair.
          output_fn: Zero-argument callable returning the `toutput` pair.
        """
        tinput, toutput = input_fn(), output_fn()
        try:
            self._checkEqual(tinput=tinput, toutput=toutput)
        finally:
            # `_checkEqual` sets the global seed. Clear it even when the
            # assertion fails, so a failing case cannot leak its seed into
            # whatever runs next.
            random_seed.set_random_seed(None)

    @combinations.generate(test_base.graph_only_combinations())
    def testIncrementalRandomSeed(self):
        """Expects the op seed to be 0, 1, ..., 9 on ten successive calls."""
        random_seed.set_random_seed(1)
        for i in range(10):
            tinput = (1, None)
            toutput = (1, i)
            self._checkEqual(tinput=tinput, toutput=toutput)
示例#5
0
class AsNumpyIteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
    """Tests for `Dataset.as_numpy_iterator()`."""

    @combinations.generate(test_base.eager_only_combinations())
    def testBasic(self):
        """Iterating a three-element range yields 0, 1, 2."""
        dataset = dataset_ops.Dataset.range(3)
        produced = list(dataset.as_numpy_iterator())
        self.assertEqual([0, 1, 2], produced)

    @combinations.generate(test_base.eager_only_combinations())
    def testNestedStructure(self):
        """Dict, tuple and namedtuple nesting survives iteration."""
        point = collections.namedtuple('Point', ['x', 'y'])
        dataset = dataset_ops.Dataset.from_tensor_slices({
            'a': ([1, 2], [3, 4]),
            'b': [5, 6],
            'c': point([7, 8], [9, 10]),
        })
        expected = [
            {'a': (1, 3), 'b': 5, 'c': point(7, 9)},
            {'a': (2, 4), 'b': 6, 'c': point(8, 10)},
        ]
        self.assertEqual(expected, list(dataset.as_numpy_iterator()))

    @combinations.generate(test_base.graph_only_combinations())
    def testNonEager(self):
        """In graph mode `as_numpy_iterator()` raises `RuntimeError`."""
        dataset = dataset_ops.Dataset.range(10)
        with self.assertRaises(RuntimeError):
            dataset.as_numpy_iterator()

    def _testInvalidElement(self, element):
        """Expects `as_numpy_iterator()` to reject a dataset of `element`.

        Args:
          element: Value wrapped with `Dataset.from_tensors` before the call.
        """
        dataset = dataset_ops.Dataset.from_tensors(element)
        expected_message = '.*does not support datasets containing.*'
        with self.assertRaisesRegex(TypeError, expected_message):
            dataset.as_numpy_iterator()

    @combinations.generate(test_base.eager_only_combinations())
    def testSparseElement(self):
        """A `SparseTensorValue` element is rejected."""
        sparse = sparse_tensor.SparseTensorValue([[0]], [0], [1])
        self._testInvalidElement(sparse)

    @combinations.generate(test_base.eager_only_combinations())
    def testRaggedElement(self):
        """Rows sliced from a ragged constant come back as equal arrays."""
        rows = [[1, 2], [3], [4, 5, 6]]
        ragged = ragged_factory_ops.constant(rows)
        dataset = dataset_ops.Dataset.from_tensor_slices(ragged)
        for actual, expected in zip(dataset.as_numpy_iterator(), rows):
            self.assertTrue(np.array_equal(actual, expected))

    @combinations.generate(test_base.eager_only_combinations())
    def testDatasetElement(self):
        """A dataset used as an element is rejected."""
        nested = dataset_ops.Dataset.range(3)
        self._testInvalidElement(nested)

    @combinations.generate(test_base.eager_only_combinations())
    def testNestedNonTensorElement(self):
        """A tuple holding a tensor and a dataset is rejected."""
        tensor_part = constant_op.constant([1, 2, 3])
        dataset_part = dataset_ops.Dataset.range(3)
        self._testInvalidElement((tensor_part, dataset_part))
示例#6
0
class AsNumpyIteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
    """Tests for `Dataset.as_numpy_iterator()`."""

    def _assertAsNumpyIteratorRaisesTypeError(self, element):
        """Wraps `element` in a dataset and expects `TypeError` on iteration.

        Args:
          element: Value passed to `Dataset.from_tensors`. The dataset is
            built outside the `assertRaises` scope, so only the
            `as_numpy_iterator()` call itself may raise.
        """
        dataset = dataset_ops.Dataset.from_tensors(element)
        with self.assertRaises(TypeError):
            dataset.as_numpy_iterator()

    @combinations.generate(test_base.eager_only_combinations())
    def testBasic(self):
        """Iterating a three-element range yields 0, 1, 2."""
        numpy_iterator = dataset_ops.Dataset.range(3).as_numpy_iterator()
        self.assertEqual([0, 1, 2], list(numpy_iterator))

    @combinations.generate(test_base.eager_only_combinations())
    def testNestedStructure(self):
        """Dict, tuple and namedtuple nesting survives iteration."""
        point = collections.namedtuple('Point', ['x', 'y'])
        components = {
            'a': ([1, 2], [3, 4]),
            'b': [5, 6],
            'c': point([7, 8], [9, 10])
        }
        first = {'a': (1, 3), 'b': 5, 'c': point(7, 9)}
        second = {'a': (2, 4), 'b': 6, 'c': point(8, 10)}
        dataset = dataset_ops.Dataset.from_tensor_slices(components)
        self.assertEqual([first, second], list(dataset.as_numpy_iterator()))

    @combinations.generate(test_base.graph_only_combinations())
    def testNonEager(self):
        """In graph mode `as_numpy_iterator()` raises `RuntimeError`."""
        dataset = dataset_ops.Dataset.range(10)
        with self.assertRaises(RuntimeError):
            dataset.as_numpy_iterator()

    @combinations.generate(test_base.eager_only_combinations())
    def testSparseElement(self):
        """A `SparseTensorValue` element is rejected with `TypeError`."""
        self._assertAsNumpyIteratorRaisesTypeError(
            sparse_tensor.SparseTensorValue([[0]], [0], [1]))

    @combinations.generate(test_base.eager_only_combinations())
    def testRaggedElement(self):
        """A `RaggedTensorValue` element is rejected with `TypeError`."""
        values = np.array([0, 1, 2])
        row_splits = np.array([0, 1, 3], dtype=np.int64)
        self._assertAsNumpyIteratorRaisesTypeError(
            ragged_tensor_value.RaggedTensorValue(values, row_splits))

    @combinations.generate(test_base.eager_only_combinations())
    def testDatasetElement(self):
        """A dataset used as an element is rejected with `TypeError`."""
        self._assertAsNumpyIteratorRaisesTypeError(
            dataset_ops.Dataset.range(3))

    @combinations.generate(test_base.eager_only_combinations())
    def testNestedNonTensorElement(self):
        """A (tensor, dataset) tuple element is rejected with `TypeError`."""
        elem = (constant_op.constant([1, 2, 3]), dataset_ops.Dataset.range(3))
        self._assertAsNumpyIteratorRaisesTypeError(elem)
示例#7
0
class WrapUnwrapTest(test_base.DatasetTestBase, parameterized.TestCase):
    """Tests for the wrap/unwrap dataset-variant ops."""

    # TODO(b/182414964): After making options persistent across tf.function is
    # enabled, the ModelDatasetOp and MaxIntraParallelismOp are no longer present
    # in Python. As a result, the FinalizeDataset is placed on GPU because of
    # colocation constraint on the iterator. It then requires a registered copy
    # operation from CPU to GPU for RangeDataset that does not exist and the test
    # fails. Fix this test and re-enable it.
    @combinations.generate(test_base.default_test_combinations())
    def DISABLED_testBasic(self):
        """Wraps then unwraps a range dataset's variant and iterates it."""
        num_elements = 100
        dataset = dataset_ops.Dataset.range(num_elements)
        variant = dataset._variant_tensor  # pylint: disable=protected-access

        wrapped = gen_dataset_ops.wrap_dataset_variant(variant)
        unwrapped = gen_dataset_ops.unwrap_dataset_variant(wrapped)

        restored = dataset_ops._VariantDataset(unwrapped,
                                               dataset.element_spec)
        get_next = self.getNext(restored, requires_initialization=True)
        for expected in range(num_elements):
            self.assertEqual(expected, self.evaluate(get_next()))

    @combinations.generate(test_base.graph_only_combinations())
    def testGPU(self):
        """Passes the wrapped variant through an identity op under /gpu:0."""
        num_elements = 100
        dataset = dataset_ops.Dataset.range(num_elements)
        variant = dataset._variant_tensor  # pylint: disable=protected-access
        wrapped = gen_dataset_ops.wrap_dataset_variant(variant)

        with ops.device("/gpu:0"):
            gpu_wrapped = array_ops.identity(wrapped)

        unwrapped = gen_dataset_ops.unwrap_dataset_variant(gpu_wrapped)
        restored = dataset_ops._VariantDataset(unwrapped,
                                               dataset.element_spec)
        iterator = dataset_ops.make_initializable_iterator(restored)
        get_next = iterator.get_next()

        with self.cached_session():
            self.evaluate(iterator.initializer)
            for expected in range(num_elements):
                self.assertEqual(expected, self.evaluate(get_next))
示例#8
0
class WrapDatasetVariantTest(test_base.DatasetTestBase,
                             parameterized.TestCase):
    """Round-trips a dataset variant through the wrap and unwrap ops."""

    @combinations.generate(test_base.default_test_combinations())
    def testBasic(self):
        """An unwrapped variant yields the same 100 range elements."""
        source = dataset_ops.Dataset.range(100)
        source_variant = source._variant_tensor  # pylint: disable=protected-access

        round_tripped = gen_dataset_ops.unwrap_dataset_variant(
            gen_dataset_ops.wrap_dataset_variant(source_variant))

        variant_ds = dataset_ops._VariantDataset(round_tripped,
                                                 source.element_spec)
        get_next = self.getNext(variant_ds, requires_initialization=True)
        for index in range(100):
            produced = self.evaluate(get_next())
            self.assertEqual(index, produced)

    @combinations.generate(test_base.graph_only_combinations())
    def testGPU(self):
        """The wrapped variant goes through an identity op under /gpu:0."""
        source = dataset_ops.Dataset.range(100)
        source_variant = source._variant_tensor  # pylint: disable=protected-access
        wrapped_variant = gen_dataset_ops.wrap_dataset_variant(source_variant)

        with ops.device("/gpu:0"):
            copied_variant = array_ops.identity(wrapped_variant)

        variant_ds = dataset_ops._VariantDataset(
            gen_dataset_ops.unwrap_dataset_variant(copied_variant),
            source.element_spec)
        iterator = dataset_ops.make_initializable_iterator(variant_ds)
        get_next = iterator.get_next()

        with self.cached_session():
            self.evaluate(iterator.initializer)
            for index in range(100):
                produced = self.evaluate(get_next)
                self.assertEqual(index, produced)
示例#9
0
class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
    @combinations.generate(test_base.graph_only_combinations())
    def testNoGradients(self):
        """Gradients of the iterator output w.r.t. its inputs are `None`."""
        component = constant_op.constant([1.])
        side = constant_op.constant(0.)
        dataset = dataset_ops.Dataset.from_tensor_slices(component)
        dataset = dataset.map(lambda x: x + side)
        value = dataset_ops.make_one_shot_iterator(dataset).get_next()
        # Differentiate against each input alone, then against both together.
        for sources in (component, side, [component, side]):
            self.assertIsNone(gradients_impl.gradients(value, sources)[0])

    @combinations.generate(test_base.graph_only_combinations())
    def testCapturingStateInOneShotRaisesException(self):
        """A map function capturing a variable makes one-shot creation fail."""
        var = variables.Variable(37.0, name="myvar")
        elements = dataset_ops.Dataset.from_tensor_slices([0.0, 1.0, 2.0])
        dataset = elements.map(lambda x: x + var)
        # The message must also name the captured variable.
        expected_message = (
            r"`Dataset.make_one_shot_iterator\(\)` does not support "
            "datasets that capture stateful objects.+myvar")
        with self.assertRaisesRegex(ValueError, expected_message):
            dataset_ops.make_one_shot_iterator(dataset)

    @combinations.generate(test_base.graph_only_combinations())
    def testOneShotIterator(self):
        """Drains a mapped, repeated one-shot iterator built from arrays."""
        num_rows = 7
        num_epochs = 14
        components = (
            np.arange(num_rows),
            np.array([[1, 2, 3]]) * np.arange(num_rows)[:, np.newaxis],
            np.array(37.0) * np.arange(num_rows))

        def _map_fn(x, y, z):
            return math_ops.square(x), math_ops.square(y), math_ops.square(z)

        dataset = dataset_ops.Dataset.from_tensor_slices(components)
        dataset = dataset.map(_map_fn).repeat(num_epochs)
        get_next = dataset_ops.make_one_shot_iterator(dataset).get_next()

        # Slicing removes the leading dimension from every component.
        expected_shapes = [c.shape[1:] for c in components]
        self.assertEqual(expected_shapes, [t.shape for t in get_next])

        with self.cached_session() as sess:
            for _ in range(num_epochs):
                for row in range(num_rows):
                    produced = sess.run(get_next)
                    for source, squared in zip(components, produced):
                        self.assertAllEqual(source[row]**2, squared)
            # All rows of all epochs were consumed above.
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)

    @combinations.generate(test_base.graph_only_combinations())
    def testOneShotIteratorCaptureByValue(self):
        """Same as `testOneShotIterator`, but slices pre-built tensors."""
        num_rows = 7
        num_epochs = 14
        components = (
            np.arange(num_rows),
            np.array([[1, 2, 3]]) * np.arange(num_rows)[:, np.newaxis],
            np.array(37.0) * np.arange(num_rows))
        # The dataset is built from tensors rather than the NumPy arrays.
        tensor_components = tuple(
            ops.convert_to_tensor(c) for c in components)

        def _map_fn(x, y, z):
            return math_ops.square(x), math_ops.square(y), math_ops.square(z)

        dataset = dataset_ops.Dataset.from_tensor_slices(tensor_components)
        dataset = dataset.map(_map_fn).repeat(num_epochs)
        get_next = dataset_ops.make_one_shot_iterator(dataset).get_next()

        expected_shapes = [c.shape[1:] for c in components]
        self.assertEqual(expected_shapes, [t.shape for t in get_next])

        with self.cached_session() as sess:
            for _ in range(num_epochs):
                for row in range(num_rows):
                    produced = sess.run(get_next)
                    for source, squared in zip(components, produced):
                        self.assertAllEqual(source[row]**2, squared)
            # All rows of all epochs were consumed above.
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)

    @combinations.generate(test_base.default_test_combinations())
    def testOneShotIteratorInsideContainer(self):
        """Runs one-shot iterators created under two different containers."""
        num_rows = 7
        num_epochs = 14
        components = (
            np.arange(num_rows),
            np.array([[1, 2, 3]]) * np.arange(num_rows)[:, np.newaxis],
            np.array(37.0) * np.arange(num_rows))

        def within_container():
            def _map_fn(x, y, z):
                return (math_ops.square(x), math_ops.square(y),
                        math_ops.square(z))

            dataset = dataset_ops.Dataset.from_tensor_slices(components)
            dataset = dataset.map(_map_fn).repeat(num_epochs)
            return dataset_ops.make_one_shot_iterator(dataset).get_next()

        server = server_lib.Server.create_local_server()

        # Create two iterators within unique containers, and run them to
        # make sure that the resources aren't shared.
        #
        # The test below would fail if cname were the same across both
        # sessions.
        for j in range(2):
            with session.Session(server.target) as sess:
                cname = "iteration%d" % j
                with ops.container(cname):
                    get_next = within_container()

                for _ in range(num_epochs):
                    for row in range(num_rows):
                        produced = sess.run(get_next)
                        for source, squared in zip(components, produced):
                            self.assertAllEqual(source[row]**2, squared)
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(get_next)

    @combinations.generate(test_base.graph_only_combinations())
    def testOneShotIteratorNonBlocking(self):
        """Uses a one-shot iterator on sessions with one inter-op thread."""
        dataset = dataset_ops.Dataset.from_tensors([1, 2, 3])
        dataset = dataset.map(lambda x: x * x)
        next_element = dataset_ops.make_one_shot_iterator(dataset).get_next()

        # Create a session with a single thread to ensure that the
        # one-shot iterator initializer does not deadlock.
        config = config_pb2.ConfigProto(inter_op_parallelism_threads=1,
                                        use_per_session_threads=True)
        with session.Session(config=config) as sess:
            self.assertAllEqual([1, 4, 9], sess.run(next_element))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)

        # Test with multiple threads invoking the one-shot iterator concurrently.
        with session.Session(config=config) as sess:
            results = []

            def consumer_thread():
                # Record `None` for a consumer that finds the iterator empty.
                try:
                    outcome = sess.run(next_element)
                except errors.OutOfRangeError:
                    outcome = None
                results.append(outcome)

            num_threads = 8
            threads = []
            for _ in range(num_threads):
                threads.append(self.checkedThread(consumer_thread))
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

            # The dataset has one element, so exactly one consumer gets it.
            exhausted = [r for r in results if r is None]
            delivered = [r for r in results if r is not None]
            self.assertLen(results, num_threads)
            self.assertLen(exhausted, num_threads - 1)
            self.assertAllEqual([[1, 4, 9]], delivered)

    @combinations.generate(test_base.graph_only_combinations())
    def testOneShotIteratorInitializerFails(self):
        """Every use of an iterator whose initialization fails must raise."""
        # Define a dataset whose initialization will always fail.
        failing_tensor = array_ops.gather([0], [4])
        dataset = dataset_ops.Dataset.from_tensors(failing_tensor)
        next_element = dataset_ops.make_one_shot_iterator(dataset).get_next()

        def expect_failure(sess):
            with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
                sess.run(next_element)

        with self.cached_session() as sess:
            expect_failure(sess)

            # Test that subsequent attempts to use the iterator also fail.
            expect_failure(sess)

        with self.cached_session() as sess:

            def consumer_thread():
                expect_failure(sess)

            num_threads = 8
            threads = []
            for _ in range(num_threads):
                threads.append(self.checkedThread(consumer_thread))
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

    @combinations.generate(test_base.graph_only_combinations())
    def testSimpleSharedResource(self):
        """Shares one named iterator between two sessions on one server."""
        components = (np.array(1, dtype=np.int64),
                      np.array([1, 2, 3], dtype=np.int64),
                      np.array(37.0, dtype=np.float64))

        server = server_lib.Server.create_local_server()

        # Create two non-overlapping sessions that share the same iterator
        # resource on the same server, and verify that an action of the
        # first session (initializing the iterator) is visible in the
        # second session.
        with ops.Graph().as_default():
            dataset = dataset_ops.Dataset.from_tensors(components)
            dataset = dataset.map(lambda x, y, z: (x, y, z))
            iterator = dataset_ops.make_initializable_iterator(
                dataset, shared_name="shared_iterator")
            init_op = iterator.initializer
            get_next = iterator.get_next()

            with session.Session(server.target) as sess:
                sess.run(init_op)
                produced = sess.run(get_next)
                for expected, actual in zip(components, produced):
                    self.assertAllEqual(expected, actual)
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(get_next)

                # Re-initialize the iterator in the first session.
                sess.run(init_op)

        with ops.Graph().as_default():
            # Re-define the iterator manually, without defining any of the
            # functions in this graph, to ensure that we are not
            # accidentally redefining functions with the same names in the
            # new graph.
            iterator = iterator_ops.Iterator.from_structure(
                shared_name="shared_iterator",
                output_types=(dtypes.int64, dtypes.int64, dtypes.float64),
                output_shapes=([], [3], []))
            get_next = iterator.get_next()

            with session.Session(server.target) as sess:
                # Use the iterator without re-initializing in the second session.
                produced = sess.run(get_next)
                for expected, actual in zip(components, produced):
                    self.assertAllEqual(expected, actual)
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(get_next)

    @combinations.generate(test_base.graph_only_combinations())
    def testNotInitializedError(self):
        """Pulling from an iterator before running its initializer fails."""
        components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
        dataset = dataset_ops.Dataset.from_tensors(components)
        iterator = dataset_ops.make_initializable_iterator(dataset)
        get_next = iterator.get_next()

        # `iterator.initializer` is deliberately never run.
        expected_message = "iterator has not been initialized"
        with self.cached_session() as sess:
            with self.assertRaisesRegex(errors.FailedPreconditionError,
                                        expected_message):
                sess.run(get_next)

    @combinations.generate(test_base.graph_only_combinations())
    def testReinitializableIterator(self):
        """Switches one iterator between two datasets of different shapes."""
        dataset_3 = dataset_ops.Dataset.from_tensors(
            constant_op.constant([1, 2, 3]))
        dataset_4 = dataset_ops.Dataset.from_tensors(
            constant_op.constant([4, 5, 6, 7]))
        # The `[None]` shape leaves the element length unspecified.
        iterator = iterator_ops.Iterator.from_structure(
            dataset_ops.get_legacy_output_types(dataset_3), [None])

        dataset_3_init_op = iterator.make_initializer(dataset_3)
        dataset_4_init_op = iterator.make_initializer(dataset_4)
        get_next = iterator.get_next()

        for dataset in (dataset_3, dataset_4):
            self.assertEqual(dataset_ops.get_legacy_output_types(dataset),
                             dataset_ops.get_legacy_output_types(iterator))
        self.assertEqual(
            [None],
            dataset_ops.get_legacy_output_shapes(iterator).as_list())

        # Initialize with one dataset, then with a different dataset, then
        # reinitialize with the first dataset.
        rounds = (
            (dataset_3_init_op, [1, 2, 3]),
            (dataset_4_init_op, [4, 5, 6, 7]),
            (dataset_3_init_op, [1, 2, 3]),
        )

        with self.cached_session() as sess:
            # The iterator is initially uninitialized.
            with self.assertRaises(errors.FailedPreconditionError):
                sess.run(get_next)

            for init_op, expected in rounds:
                sess.run(init_op)
                self.assertAllEqual(expected, sess.run(get_next))
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(get_next)

    @combinations.generate(test_base.graph_only_combinations())
    def testReinitializableIteratorWithFunctions(self):
        """Reinitializes one iterator from two generator-backed datasets."""
        def g():
            for i in range(10):
                yield i

        iterator = iterator_ops.Iterator.from_structure(dtypes.int64, [])
        next_element = iterator.get_next()

        with self.cached_session() as sess:
            # Two rounds, each with a freshly built dataset over the same
            # generator function.
            for _ in range(2):
                dataset = dataset_ops.Dataset.from_generator(
                    g, output_types=dtypes.int64)
                sess.run(iterator.make_initializer(dataset))
                for expected in range(10):
                    self.assertEqual(expected, sess.run(next_element))
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(next_element)

    @combinations.generate(test_base.default_test_combinations())
    def testReinitializableIteratorStaticErrors(self):
        """Mismatched structures, types and shapes raise at build time."""
        int64_and_float64 = (dtypes.int64, dtypes.float64)

        # Non-matching structure for types and shapes.
        with self.assertRaises(TypeError):
            iterator_ops.Iterator.from_structure(int64_and_float64, [None])

        # Test validation of dataset argument.
        iterator = iterator_ops.Iterator.from_structure(int64_and_float64)

        # Incompatible structure.
        with self.assertRaises(ValueError):
            nested_components = (
                (constant_op.constant([1, 2, 3], dtype=dtypes.int64), ),
                (constant_op.constant([4., 5., 6., 7.],
                                      dtype=dtypes.float64), ))
            iterator.make_initializer(
                dataset_ops.Dataset.from_tensors(nested_components))

        # Incompatible types.
        with self.assertRaises(TypeError):
            narrow_components = (
                constant_op.constant([1, 2, 3], dtype=dtypes.int32),
                constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float32))
            iterator.make_initializer(
                dataset_ops.Dataset.from_tensors(narrow_components))

        # Incompatible shapes.
        iterator = iterator_ops.Iterator.from_structure(
            int64_and_float64, ([None], []))
        with self.assertRaises(TypeError):
            vector_components = (
                constant_op.constant([1, 2, 3], dtype=dtypes.int64),
                constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float64))
            iterator.make_initializer(
                dataset_ops.Dataset.from_tensors(vector_components))

    @combinations.generate(test_base.graph_only_combinations())
    def testIteratorStringHandle(self):
        """Feeds two iterators' string handles through one placeholder."""
        dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
        dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

        iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
        iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)

        handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
        feedable_iterator = iterator_ops.Iterator.from_string_handle(
            handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
            dataset_ops.get_legacy_output_shapes(dataset_3))
        next_element = feedable_iterator.get_next()

        self.assertTrue(
            structure.are_compatible(
                dataset_ops.get_structure(dataset_3),
                dataset_ops.get_structure(feedable_iterator)))

        with self.cached_session() as sess:
            iterator_3_handle = sess.run(iterator_3.string_handle())
            iterator_4_handle = sess.run(iterator_4.string_handle())

            def pull(handle):
                return sess.run(next_element,
                                feed_dict={handle_placeholder: handle})

            # Alternate between the two handles; each underlying iterator
            # must advance independently of the other.
            interleaved = (
                (10, iterator_4_handle),
                (1, iterator_3_handle),
                (20, iterator_4_handle),
                (2, iterator_3_handle),
                (30, iterator_4_handle),
                (3, iterator_3_handle),
                (40, iterator_4_handle),
            )
            for expected, handle in interleaved:
                self.assertEqual(expected, pull(handle))

            # Both iterators are now exhausted.
            for handle in (iterator_3_handle, iterator_4_handle):
                with self.assertRaises(errors.OutOfRangeError):
                    pull(handle)

    @combinations.generate(test_base.graph_only_combinations())
    def testIteratorStringHandleFuture(self):
        """Interleaves reads from two iterators via a fed string handle.

        Builds one feedable iterator from a string placeholder, then feeds it
        the handles of a 3-element and a 4-element one-shot iterator in turn,
        checking every value and the final `OutOfRangeError` for each source.

        NOTE(review): the body performs the same steps as
        `testIteratorStringHandle`; the "Future" suffix presumably referred to
        a variant that is no longer distinguished here -- confirm.
        """
        three_elements = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
        four_elements = dataset_ops.Dataset.from_tensor_slices(
            [10, 20, 30, 40])

        three_iterator = dataset_ops.make_one_shot_iterator(three_elements)
        four_iterator = dataset_ops.make_one_shot_iterator(four_elements)

        placeholder = array_ops.placeholder(dtypes.string, shape=[])
        legacy_types = dataset_ops.get_legacy_output_types(three_elements)
        legacy_shapes = dataset_ops.get_legacy_output_shapes(three_elements)
        feedable = iterator_ops.Iterator.from_string_handle(
            placeholder, legacy_types, legacy_shapes)
        next_op = feedable.get_next()

        self.assertTrue(
            structure.are_compatible(dataset_ops.get_structure(three_elements),
                                     dataset_ops.get_structure(feedable)))

        with self.cached_session() as sess:
            three_handle = sess.run(three_iterator.string_handle())
            four_handle = sess.run(four_iterator.string_handle())

            feed_three = {placeholder: three_handle}
            feed_four = {placeholder: four_handle}

            # Reads alternate four/three, starting with the longer source.
            reads = (
                (feed_four, 10),
                (feed_three, 1),
                (feed_four, 20),
                (feed_three, 2),
                (feed_four, 30),
                (feed_three, 3),
                (feed_four, 40),
            )
            for feed, wanted in reads:
                self.assertEqual(wanted, sess.run(next_op, feed_dict=feed))

            # Every element has been consumed from both sources.
            for feed in (feed_three, feed_four):
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(next_op, feed_dict=feed)

    @combinations.generate(test_base.graph_only_combinations())
    def testIteratorStringHandleReuseTensorObject(self):
        """Checks that the default `string_handle()` tensor is reused.

        For one-shot, initializable and from-structure iterators, calling
        `string_handle()` twice without a name must return the same object and
        add no ops to the default graph. Passing an explicit name must yield a
        new, distinct tensor each time, with the second use of the same name
        ending up as "foo_1".
        """
        dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
        one_shot = dataset_ops.make_one_shot_iterator(dataset)
        initializable = dataset_ops.make_initializable_iterator(dataset)
        from_structure = iterator_ops.Iterator.from_structure(
            dataset_ops.get_legacy_output_types(dataset))

        graph = ops.get_default_graph()
        op_count_before = len(graph.get_operations())

        for candidate in (one_shot, initializable, from_structure):
            self.assertIs(candidate.string_handle(),
                          candidate.string_handle())

        # Assert that getting the (default) string handle creates no ops.
        self.assertEqual(op_count_before, len(graph.get_operations()))

        # Specifying an explicit name will create a new op.
        first_named = one_shot.string_handle(name="foo")
        self.assertEqual("foo", first_named.op.name)
        self.assertIsNot(one_shot.string_handle(), first_named)

        # Asking for the same name again yields yet another op, uniquified.
        second_named = one_shot.string_handle(name="foo")
        self.assertEqual("foo_1", second_named.op.name)
        self.assertIsNot(first_named, second_named)

    @combinations.generate(test_base.graph_only_combinations())
    def testIteratorStringHandleError(self):
        """Feeds handles into feedable iterators with differing signatures.

        Three feedable iterators share one handle placeholder: int32 scalar,
        int32 vector, and int32 with no shape given. The handle of an
        int-scalar dataset must work with the scalar and the shape-less
        iterators, while the vector iterator must reject both the int-scalar
        handle and a float-vector handle with `InvalidArgumentError`.
        """
        int_scalar_dataset = dataset_ops.Dataset.from_tensor_slices(
            [1, 2, 3]).repeat()
        float_vector_dataset = dataset_ops.Dataset.from_tensors(
            [1.0, 2.0, 3.0])

        handle_ph = array_ops.placeholder(dtypes.string, shape=[])

        from_handle = iterator_ops.Iterator.from_string_handle
        scalar_view = from_handle(handle_ph, dtypes.int32, [])
        vector_view = from_handle(handle_ph, dtypes.int32, [None])
        shapeless_view = from_handle(handle_ph, dtypes.int32)

        with self.cached_session() as sess:
            int_scalar_iterator = dataset_ops.make_one_shot_iterator(
                int_scalar_dataset)
            int_scalar_handle = sess.run(int_scalar_iterator.string_handle())
            float_vector_iterator = dataset_ops.make_one_shot_iterator(
                float_vector_dataset)
            float_vector_handle = sess.run(
                float_vector_iterator.string_handle())

            int_scalar_feed = {handle_ph: int_scalar_handle}

            # Both compatible views pull from the same underlying iterator,
            # so the second read sees the next element.
            self.assertEqual(
                1, sess.run(scalar_view.get_next(),
                            feed_dict=int_scalar_feed))
            self.assertEqual(
                2,
                sess.run(shapeless_view.get_next(),
                         feed_dict=int_scalar_feed))

            # The vector view accepts neither handle.
            for rejected_handle in (int_scalar_handle, float_vector_handle):
                with self.assertRaises(errors.InvalidArgumentError):
                    print(
                        sess.run(vector_view.get_next(),
                                 feed_dict={handle_ph: rejected_handle}))

    @combinations.generate(test_base.graph_only_combinations())
    def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
        """Reads an iterator created on cpu:1 through `remote_call`.

        The iterator and its string handle are placed on cpu:1; the
        `remote_call` op and its target placeholder are placed on cpu:0.
        Targeting cpu:1 must yield 1, 2, 3 and then `OutOfRangeError`;
        targeting cpu:2 must fail with `InvalidArgumentError` without
        consuming an element.
        """
        cpu0 = "/job:localhost/replica:0/task:0/cpu:0"
        cpu1 = "/job:localhost/replica:0/task:0/cpu:1"
        cpu2 = "/job:localhost/replica:0/task:0/cpu:2"

        worker_config = config_pb2.ConfigProto()
        worker_config.device_count["CPU"] = 3

        with ops.device(cpu1):
            dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
            iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
            iterator_3_handle = iterator_3.string_handle()

        @function.Defun(dtypes.string)
        def _remote_fn(h):
            # Rebuild an iterator from the handle and pull one element.
            output_types = dataset_ops.get_legacy_output_types(dataset_3)
            output_shapes = dataset_ops.get_legacy_output_shapes(dataset_3)
            remote_iterator = iterator_ops.Iterator.from_string_handle(
                h, output_types, output_shapes)
            return remote_iterator.get_next()

        with ops.device(cpu0):
            target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
            remote_op = functional_ops.remote_call(
                args=[iterator_3_handle],
                Tout=[dtypes.int32],
                f=_remote_fn,
                target=target_placeholder)

        with self.session(config=worker_config) as sess:

            def run_on(device):
                return sess.run(remote_op,
                                feed_dict={target_placeholder: device})

            self.assertEqual(run_on(cpu1), [1])
            # Fails when target is cpu:2 where the resource is not located.
            with self.assertRaises(errors.InvalidArgumentError):
                run_on(cpu2)
            self.assertEqual(run_on(cpu1), [2])
            self.assertEqual(run_on(cpu1), [3])
            with self.assertRaises(errors.OutOfRangeError):
                run_on(cpu1)

    @combinations.generate(test_base.graph_only_combinations())
    def testRemoteIteratorUsingRemoteCallOpMultiWorkers(self):
        """Reads per-worker iterators from a client job via `remote_call`.

        Three local servers are started: two are registered as tasks of a
        "worker" job and one as the single task of a "client" job. Each worker
        device hosts a one-shot iterator over a dataset containing that
        device's own name. A client-side dataset zips the device names with
        the iterator handles and maps `remote_call` over them; the client
        iterator must then yield each worker device name once, in order,
        followed by `OutOfRangeError`.
        """
        s1 = server_lib.Server.create_local_server()
        s2 = server_lib.Server.create_local_server()
        s3 = server_lib.Server.create_local_server()

        # Describe the cluster: s1/s2 are worker tasks 0/1, s3 is the client.
        # The first len("grpc://") characters of each target are dropped
        # (presumably the URL scheme, leaving host:port -- confirm).
        cluster_def = cluster_pb2.ClusterDef()
        workers = cluster_def.job.add()
        workers.name = "worker"
        workers.tasks[0] = s1.target[len("grpc://"):]
        workers.tasks[1] = s2.target[len("grpc://"):]
        client = cluster_def.job.add()
        client.name = "client"
        client.tasks[0] = s3.target[len("grpc://"):]
        config = config_pb2.ConfigProto(cluster_def=cluster_def)

        worker_devices = [
            "/job:worker/replica:0/task:%d/cpu:0" % i for i in range(2)
        ]
        # One iterator per worker device; its only element is the device name.
        itr_handles = []
        for device in worker_devices:
            with ops.device(device):
                src = dataset_ops.Dataset.from_tensor_slices([device])
                itr = dataset_ops.make_one_shot_iterator(src)
                itr_handles.append(itr.string_handle())

        targets = dataset_ops.Dataset.from_tensor_slices(worker_devices)
        handles = dataset_ops.Dataset.from_tensor_slices(itr_handles)

        @function.Defun(dtypes.string)
        def loading_func(h):
            # NOTE(review): `itr` is a closure variable that is rebound to the
            # client iterator further below. Which object is seen here depends
            # on when this Defun body is traced; it presumably still refers to
            # the last worker iterator from the loop above -- confirm before
            # renaming or reordering anything in this test.
            remote_itr = iterator_ops.Iterator.from_string_handle(
                h, dataset_ops.get_legacy_output_types(itr),
                dataset_ops.get_legacy_output_shapes(itr))
            return remote_itr.get_next()

        def map_fn(target, handle):
            # Run `loading_func` on the device named by `target`, passing the
            # iterator handle paired with it.
            return functional_ops.remote_call(args=[handle],
                                              Tout=[dtypes.string],
                                              f=loading_func,
                                              target=target)

        with ops.device("/job:client"):
            client_dataset = dataset_ops.Dataset.zip(
                (targets, handles)).map(map_fn)
            itr = dataset_ops.make_initializable_iterator(client_dataset)
            n = itr.get_next()

        # Connect through the client server (s3) with the cluster config.
        with session.Session(s3.target, config=config) as sess:
            sess.run(itr.initializer)
            expected_values = worker_devices
            for expected in expected_values:
                # Each result is compared as a 1-tuple of bytes.
                self.assertEqual((compat.as_bytes(expected), ), sess.run(n))

            with self.assertRaises(errors.OutOfRangeError):
                sess.run(n)

    @combinations.generate(test_base.graph_only_combinations())
    def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self):
        """Reads a CPU-resident iterator via a `remote_call` op placed on GPU.

        The iterator handle is created on cpu:0, decoded to uint8 on GPU:0 and
        passed to a remote function that re-encodes it to a string with a
        `py_func` before rebuilding the iterator. Targeting cpu:0 must yield
        1, 2, 3 and then `OutOfRangeError`. Skipped when no GPU is available.
        """
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        cpu0 = "/job:localhost/replica:0/task:0/cpu:0"
        gpu0 = "/job:localhost/replica:0/task:0/device:GPU:0"

        with ops.device(cpu0):
            dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
            iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
            iterator_3_handle = iterator_3.string_handle()

        def _encode_raw(byte_array):
            # Turn the sequence of byte values back into a bytes object.
            return bytes(bytearray(byte_array))

        @function.Defun(dtypes.uint8)
        def _remote_fn(h):
            handle = script_ops.py_func(_encode_raw, [h], dtypes.string)
            output_types = dataset_ops.get_legacy_output_types(dataset_3)
            output_shapes = dataset_ops.get_legacy_output_shapes(dataset_3)
            remote_iterator = iterator_ops.Iterator.from_string_handle(
                handle, output_types, output_shapes)
            return remote_iterator.get_next()

        with ops.device(gpu0):
            target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
            iterator_3_handle_uint8 = parsing_ops.decode_raw(
                input_bytes=iterator_3_handle, out_type=dtypes.uint8)
            remote_op = functional_ops.remote_call(
                args=[iterator_3_handle_uint8],
                Tout=[dtypes.int32],
                f=_remote_fn,
                target=target_placeholder)

        with self.cached_session() as sess:

            def run_on_cpu():
                return sess.run(remote_op,
                                feed_dict={target_placeholder: cpu0})

            for expected in ([1], [2], [3]):
                elem = run_on_cpu()
                self.assertEqual(elem, expected)
            with self.assertRaises(errors.OutOfRangeError):
                run_on_cpu()

    @combinations.generate(test_base.graph_only_combinations())
    def testRepeatedGetNextWarning(self):
        """Checks that excessive `get_next()` calls emit the expected warning.

        Calls `get_next()` 100 times on one iterator and expects one recorded
        warning for every call beyond
        `iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD`, each containing
        `iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE`.
        """
        iterator = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.range(10))
        num_calls = 100
        with warnings.catch_warnings(record=True) as w:
            # Install the "always" filter inside the context manager so the
            # process-wide warning filters are restored on exit instead of
            # leaking into tests that run afterwards.
            warnings.simplefilter("always")
            for _ in range(num_calls):
                iterator.get_next()
        self.assertEqual(
            num_calls - iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD, len(w))
        for warning in w:
            self.assertIn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE,
                          str(warning.message))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(
                expected_element_structure=tensor_spec.TensorSpec(
                    [], dtypes.float32),
                expected_output_classes=ops.Tensor,
                expected_output_types=dtypes.float32,
                expected_output_shapes=[[]])))
    def testTensorIteratorStructure(self, expected_element_structure,
                                    expected_output_classes,
                                    expected_output_types,
                                    expected_output_shapes):
        """Checks structure metadata of an iterator over one scalar tensor.

        Args:
          expected_element_structure: Structure the iterator's element
            structure must be compatible with.
          expected_output_classes: Expected legacy output classes.
          expected_output_types: Expected legacy output types.
          expected_output_shapes: Expected legacy output shapes.
        """
        scalar = constant_op.constant(37.0)
        source = dataset_ops.Dataset.from_tensors(scalar)
        iterator = dataset_ops.make_one_shot_iterator(source)

        actual_structure = dataset_ops.get_structure(iterator)
        self.assertTrue(
            structure.are_compatible(actual_structure,
                                     expected_element_structure))

        # (expected value, accessor) pairs for the legacy metadata.
        legacy_checks = (
            (expected_output_classes, dataset_ops.get_legacy_output_classes),
            (expected_output_types, dataset_ops.get_legacy_output_types),
            (expected_output_shapes, dataset_ops.get_legacy_output_shapes),
        )
        for expected, accessor in legacy_checks:
            self.assertEqual(expected, accessor(iterator))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(
                expected_element_structure=sparse_tensor.SparseTensorSpec(
                    [1], dtypes.int32),
                expected_output_classes=sparse_tensor.SparseTensor,
                expected_output_types=dtypes.int32,
                expected_output_shapes=[[1]])))
    def testSparseTensorIteratorStructure(self, expected_element_structure,
                                          expected_output_classes,
                                          expected_output_types,
                                          expected_output_shapes):
        """Checks structure metadata of an iterator over one sparse tensor.

        Args:
          expected_element_structure: Structure the iterator's element
            structure must be compatible with.
          expected_output_classes: Expected legacy output classes.
          expected_output_types: Expected legacy output types.
          expected_output_shapes: Expected legacy output shapes.
        """
        sparse_values = constant_op.constant([0], dtype=dtypes.int32)
        sparse_value = sparse_tensor.SparseTensor(indices=[[0]],
                                                  values=sparse_values,
                                                  dense_shape=[1])
        source = dataset_ops.Dataset.from_tensors(sparse_value)
        iterator = dataset_ops.make_one_shot_iterator(source)

        actual_structure = dataset_ops.get_structure(iterator)
        self.assertTrue(
            structure.are_compatible(actual_structure,
                                     expected_element_structure))

        # (expected value, accessor) pairs for the legacy metadata.
        legacy_checks = (
            (expected_output_classes, dataset_ops.get_legacy_output_classes),
            (expected_output_types, dataset_ops.get_legacy_output_types),
            (expected_output_shapes, dataset_ops.get_legacy_output_shapes),
        )
        for expected, accessor in legacy_checks:
            self.assertEqual(expected, accessor(iterator))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(
                expected_element_structure={
                    "a": tensor_spec.TensorSpec([], dtypes.float32),
                    "b": (tensor_spec.TensorSpec([1], dtypes.string),
                          tensor_spec.TensorSpec([], dtypes.string))
                },
                expected_output_classes={
                    "a": ops.Tensor,
                    "b": (ops.Tensor, ops.Tensor)
                },
                expected_output_types={
                    "a": dtypes.float32,
                    "b": (dtypes.string, dtypes.string)
                },
                expected_output_shapes={
                    "a": [],
                    "b": ([1], [])
                })))
    def testNestedTensorIteratorStructure(self, expected_element_structure,
                                          expected_output_classes,
                                          expected_output_types,
                                          expected_output_shapes):
        """Checks structure metadata of an iterator over a nested element.

        The element is a dict whose "a" entry is a float scalar and whose "b"
        entry is a pair of string tensors (a length-1 vector and a scalar).

        Args:
          expected_element_structure: Structure the iterator's element
            structure must be compatible with.
          expected_output_classes: Expected legacy output classes.
          expected_output_types: Expected legacy output types.
          expected_output_shapes: Expected legacy output shapes.
        """
        float_scalar = constant_op.constant(37.0)
        string_pair = (constant_op.constant(["Foo"]),
                       constant_op.constant("Bar"))
        nested_value = {"a": float_scalar, "b": string_pair}

        source = dataset_ops.Dataset.from_tensors(nested_value)
        iterator = dataset_ops.make_one_shot_iterator(source)

        actual_structure = dataset_ops.get_structure(iterator)
        self.assertTrue(
            structure.are_compatible(actual_structure,
                                     expected_element_structure))

        # (expected value, accessor) pairs for the legacy metadata.
        legacy_checks = (
            (expected_output_classes, dataset_ops.get_legacy_output_classes),
            (expected_output_types, dataset_ops.get_legacy_output_types),
            (expected_output_shapes, dataset_ops.get_legacy_output_shapes),
        )
        for expected, accessor in legacy_checks:
            self.assertEqual(expected, accessor(iterator))

    @combinations.generate(test_base.default_test_combinations())
    def testIteratorGetNextName(self):
        """Checks that `get_next(name=...)` names the op it creates."""
        requested_name = "overridden_name"
        # Build inside a fresh graph so the op name is not uniquified against
        # ops created elsewhere.
        with ops.Graph().as_default():
            source = dataset_ops.Dataset.from_tensors(37.0)
            iterator = dataset_ops.make_one_shot_iterator(source)
            named_next = iterator.get_next(name=requested_name)
            self.assertEqual("overridden_name", named_next.op.name)

    @combinations.generate(
        combinations.combine(tf_api_version=[1, 2],
                             mode="eager",
                             execution_mode=[context.ASYNC, context.SYNC]))
    def testIteratorEagerIteration(self, execution_mode):
        """Iterates a range dataset with a Python `for` loop in eager mode.

        Args:
          execution_mode: Eager execution mode (`context.ASYNC` or
            `context.SYNC`) to run the iteration under.
        """
        with context.eager_mode(), context.execution_mode(execution_mode):
            elements = iter(dataset_ops.Dataset.range(10))
            # Each element must equal its position in the sequence.
            for position, element in enumerate(elements):
                self.assertEqual(position, element.numpy())

    @combinations.generate(test_base.eager_only_combinations())
    def testOwnedIteratorFunction(self):
        """Creates and consumes an iterator entirely inside a `function`.

        The traced function enqueues ten elements pulled from a range dataset
        into a FIFO queue; the queue must then hand back 0..9 in order.
        """
        num_elements = 10
        queue = data_flow_ops.FIFOQueue(10, dtypes.int64)

        @def_function.function
        def fill_queue():
            source = dataset_ops.Dataset.range(10)
            source_iterator = iter(source)
            for _ in range(10):
                queue.enqueue(next(source_iterator))

        fill_queue()

        for expected in range(num_elements):
            self.assertEqual(queue.dequeue().numpy(), expected)

    @combinations.generate(test_base.eager_only_combinations())
    def testOwnedIteratorFunctionError(self):
        """Verifies the iterator resource is deallocated when a function fails.

        The generator dataset's `next` function pulls from an empty dataset,
        so the traced function raises `OutOfRangeError`. Its finalize function
        enqueues one item; together with the item enqueued up front, a final
        queue size of 2 is the evidence that finalization ran.
        """
        queue = data_flow_ops.FIFOQueue(10, dtypes.int64)
        queue.enqueue(0)

        def init_fn(n):
            return n

        def next_fn(_):
            # An empty range has no first element to return.
            empty = dataset_ops.Dataset.range(0)
            return next(iter(empty))

        def finalize_fn(n):
            # Observable side effect used by the final assertion.
            queue.enqueue(0)
            return n

        @def_function.function
        def consume_once():
            generator_dataset = dataset_ops._GeneratorDataset(
                1, init_fn, next_fn, finalize_fn)
            generator_iterator = iter(generator_dataset)
            next(generator_iterator)

        with self.assertRaises(errors.OutOfRangeError):
            consume_once()

        self.assertEqual(queue.size().numpy(), 2)

    @combinations.generate(test_base.eager_only_combinations())
    def testLimitedRetracing(self):
        """Checks a `function` taking an iterator is traced only once.

        The function sums all elements of the iterator it is given. It is
        called repeatedly with fresh iterators over two different range
        datasets, and the Python body must have run (been traced) just once.
        """
        # A one-element list lets the traced body mutate the counter.
        trace_count = [0]

        @def_function.function
        def f(iterator):
            trace_count[0] += 1
            total = np.int64(0)
            for value in iterator:
                total += value
            return total

        range_five = dataset_ops.Dataset.range(5)
        range_ten = dataset_ops.Dataset.range(10)

        for _ in range(10):
            sum_five = self.evaluate(f(iter(range_five)))
            self.assertEqual(sum_five, 10)
            sum_ten = self.evaluate(f(iter(range_ten)))
            self.assertEqual(sum_ten, 45)
            self.assertEqual(trace_count[0], 1)

    @combinations.generate(test_base.eager_only_combinations())
    def testNestedFunctionsIteratorResource(self):
        """Passes an iterator from an outer `function` into a nested one.

        The outer function creates the iterator and calls a nested `function`
        ten times to pull one element each; the sum over `range(10)` must be
        45 on both of two consecutive calls.
        """

        @def_function.function
        def sum_dataset(ds):

            @def_function.function
            def next_element(iterator):
                return next(iterator)

            outer_iterator = iter(ds)
            running_total = 0
            for _ in range(10):
                running_total += next_element(outer_iterator)
            return running_total

        ds = dataset_ops.Dataset.range(10)
        for _ in range(2):
            self.assertEqual(sum_dataset(ds).numpy(), 45)

    @combinations.generate(test_base.default_test_combinations())
    def testNestedAutomaticControlDependencies(self):
        """Checks a stateful map function runs for every element pulled.

        The map function increments a variable each time it is applied. A
        `function` pulls ten elements from the mapped dataset and returns the
        variable, which must then evaluate to 10.
        """
        counter_var = variables.Variable(0)

        def map_fn(x):
            # Side effect only; the element passes through unchanged.
            counter_var.assign_add(1)
            return x

        @def_function.function
        def fn():
            mapped = dataset_ops.Dataset.range(10).map(map_fn)
            mapped_iterator = iter(mapped)
            for _ in range(10):
                _ = next(mapped_iterator)
            return counter_var

        self.evaluate(counter_var.initializer)
        final_count = self.evaluate(fn())
        self.assertEqual(final_count, 10)
示例#10
0
class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase):
    """Tests for `prefetching_ops.prefetch_to_device`."""

    def _skipUnlessGpu(self):
        """Skips the running test when `test_util` reports no GPU."""
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

    def _prefetchTo(self, host_dataset, device):
        """Returns `host_dataset` with `prefetch_to_device(device)` applied."""
        return host_dataset.apply(prefetching_ops.prefetch_to_device(device))

    def _assertStructuresCompatible(self, host_dataset, device_dataset):
        """Asserts both datasets report compatible element structures."""
        host_structure = dataset_ops.get_structure(host_dataset)
        device_structure = dataset_ops.get_structure(device_dataset)
        self.assertTrue(
            structure.are_compatible(host_structure, device_structure))

    def _twoCpuSession(self):
        """Returns a `test_session` configured with two CPU devices."""
        worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
        return self.test_session(config=worker_config)

    def _assertYieldsRange(self, next_element, stop):
        """Evaluates `next_element` `stop` times, expecting 0..stop-1."""
        for expected in range(stop):
            self.assertEqual(expected, self.evaluate(next_element))

    def _assertExhausted(self, next_element):
        """Asserts evaluating `next_element` raises `OutOfRangeError`."""
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(next_element)

    @combinations.generate(test_base.graph_only_combinations())
    def testPrefetchToDevice(self):
        """Prefetches `range(10)` to "/cpu:1" and reads it back in order."""
        host_dataset = dataset_ops.Dataset.range(10)
        device_dataset = self._prefetchTo(host_dataset, "/cpu:1")

        with ops.device("/cpu:1"):
            iterator = dataset_ops.make_one_shot_iterator(device_dataset)
            next_element = iterator.get_next()

        self._assertStructuresCompatible(host_dataset, device_dataset)

        self.assertEqual(dtypes.int64, next_element.dtype)
        self.assertEqual([], next_element.shape)

        with self._twoCpuSession():
            self._assertYieldsRange(next_element, 10)
            self._assertExhausted(next_element)

    @combinations.generate(test_base.graph_only_combinations())
    def testPrefetchToSameDevice(self):
        """Prefetches to a fully specified CPU:0 device name.

        The iterator itself is still created under "/cpu:1".
        """
        host_dataset = dataset_ops.Dataset.range(10)
        device_dataset = self._prefetchTo(
            host_dataset, "/job:localhost/replica:0/task:0/device:CPU:0")

        with ops.device("/cpu:1"):
            iterator = dataset_ops.make_one_shot_iterator(device_dataset)
            next_element = iterator.get_next()

        self._assertStructuresCompatible(host_dataset, device_dataset)

        self.assertEqual(dtypes.int64, next_element.dtype)
        self.assertEqual([], next_element.shape)

        with self._twoCpuSession():
            self._assertYieldsRange(next_element, 10)
            self._assertExhausted(next_element)

    @combinations.generate(test_base.graph_only_combinations())
    def testPrefetchDictToDevice(self):
        """Prefetches dict-structured elements (`{"a": x}`) to "/cpu:1"."""
        host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
        device_dataset = self._prefetchTo(host_dataset, "/cpu:1")

        with ops.device("/cpu:1"):
            iterator = dataset_ops.make_one_shot_iterator(device_dataset)
            next_element = iterator.get_next()

        self._assertStructuresCompatible(host_dataset, device_dataset)

        component = next_element["a"]
        self.assertEqual(dtypes.int64, component.dtype)
        self.assertEqual([], component.shape)

        with self._twoCpuSession():
            for expected in range(10):
                self.assertEqual({"a": expected},
                                 self.evaluate(next_element))
            self._assertExhausted(next_element)

    @combinations.generate(test_base.graph_only_combinations())
    def testPrefetchSparseTensorsToDevice(self):
        """Prefetches sparse elements to "/cpu:1" and checks each component."""

        def make_tensor(i):
            # One stored value `i` at position [0, 0] of a 2x2 sparse tensor.
            return sparse_tensor.SparseTensorValue(indices=[[0, 0]],
                                                   values=(i * [1]),
                                                   dense_shape=[2, 2])

        host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
        device_dataset = self._prefetchTo(host_dataset, "/cpu:1")

        with ops.device("/cpu:1"):
            iterator = dataset_ops.make_one_shot_iterator(device_dataset)
            next_element = iterator.get_next()

        self._assertStructuresCompatible(host_dataset, device_dataset)

        self.assertEqual(dtypes.int64, next_element.dtype)

        with self._twoCpuSession():
            for expected in range(10):
                actual = self.evaluate(next_element)
                self.assertAllEqual([expected], actual.values)
                self.assertAllEqual([[0, 0]], actual.indices)
                self.assertAllEqual([2, 2], actual.dense_shape)
            self._assertExhausted(next_element)

    @combinations.generate(test_base.default_test_combinations())
    def testPrefetchToDeviceGpu(self):
        """Prefetches `range(10)` to "/gpu:0"; skipped without a GPU."""
        self._skipUnlessGpu()

        host_dataset = dataset_ops.Dataset.range(10)
        device_dataset = self._prefetchTo(host_dataset, "/gpu:0")

        self.assertDatasetProduces(device_dataset, list(range(10)))

    @combinations.generate(test_base.graph_only_combinations())
    def testPrefetchToDeviceWithReInit(self):
        """Re-initializes a prefetching iterator part-way through.

        After five elements the iterator is initialized again and must then
        produce the full sequence from the start.
        """
        host_dataset = dataset_ops.Dataset.range(10)
        device_dataset = self._prefetchTo(host_dataset, "/cpu:1")

        with ops.device("/cpu:1"):
            iterator = dataset_ops.make_initializable_iterator(device_dataset)
            next_element = iterator.get_next()

        self._assertStructuresCompatible(host_dataset, device_dataset)

        self.assertEqual(dtypes.int64, next_element.dtype)
        self.assertEqual([], next_element.shape)

        with self._twoCpuSession():
            self.evaluate(iterator.initializer)
            self._assertYieldsRange(next_element, 5)
            self.evaluate(iterator.initializer)
            self._assertYieldsRange(next_element, 10)
            self._assertExhausted(next_element)

    @combinations.generate(test_base.graph_only_combinations())
    def testPrefetchToDeviceGpuWithReInit(self):
        """GPU variant of the re-initialization test; skipped without a GPU.

        The session is created with `allow_soft_placement=False`.
        """
        self._skipUnlessGpu()

        host_dataset = dataset_ops.Dataset.range(10)
        device_dataset = self._prefetchTo(host_dataset, "/gpu:0")

        iterator = dataset_ops.make_initializable_iterator(device_dataset)
        next_element = iterator.get_next()

        strict_placement = config_pb2.ConfigProto(allow_soft_placement=False)
        with self.cached_session(config=strict_placement):
            self.evaluate(iterator.initializer)
            self._assertYieldsRange(next_element, 5)
            self.evaluate(iterator.initializer)
            self._assertYieldsRange(next_element, 10)
            self._assertExhausted(next_element)

    @combinations.generate(test_base.eager_only_combinations())
    def testPrefetchToDevicePlacement(self):
        """Checks the prefetched dataset's variant tensor lives on GPU:0."""
        self._skipUnlessGpu()

        host_dataset = dataset_ops.Dataset.range(10)
        device_dataset = self._prefetchTo(host_dataset, "/gpu:0")

        self.assertEqual(device_dataset._variant_tensor.device,
                         "/job:localhost/replica:0/task:0/device:GPU:0")
示例#11
0
class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
    """Tests for the `Dataset.padded_batch()` transformation."""

    def _assertPaddedBatch(self, result, expected_seq_lens, padded_shapes):
        """Checks one batch produced in `testPaddedBatchDataset`.

        Args:
          result: The evaluated batch, indexed as `result[row, column]`.
          expected_seq_lens: The sequence lengths of the elements that make up
            this batch, in order. Row `j` must hold `expected_seq_lens[j]`
            copies of that length followed by zero padding.
          padded_shapes: The `padded_shapes` argument that was passed to
            `padded_batch()`. A first entry of `None` or `-1` means the padded
            length is taken from the batch contents.
        """
        padded_len = padded_shapes[0]
        if padded_len is None or padded_len == -1:
            # Each element is filled with its own length, so the largest value
            # in the batch is also the longest sequence in the batch.
            padded_len = np.max(result) if result.size > 0 else 0
        self.assertEqual((len(expected_seq_lens), padded_len), result.shape)
        for j, seq_len in enumerate(expected_seq_lens):
            self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
            self.assertAllEqual(result[j, seq_len:],
                                [0] * (padded_len - seq_len))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(count=[32, 34],
                                 padded_shapes=[[None], [25]],
                                 drop_remainder=[True, False])))
    def testPaddedBatchDataset(self, count, padded_shapes, drop_remainder):
        """Batches random-length sequences and checks contents and padding.

        Args:
          count: Number of input sequences.
          padded_shapes: Value forwarded to `padded_batch(padded_shapes=...)`.
          drop_remainder: Value forwarded to
            `padded_batch(drop_remainder=...)`.
        """
        seq_lens = np.random.randint(20, size=(count, )).astype(np.int32)
        batch_size = 4
        dataset = dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
            lambda x: array_ops.fill([x], x)).padded_batch(
                batch_size=batch_size,
                drop_remainder=drop_remainder,
                padded_shapes=padded_shapes)

        num_full_batches = len(seq_lens) // batch_size
        get_next = self.getNext(dataset)
        for i in range(num_full_batches):
            start = i * batch_size
            self._assertPaddedBatch(self.evaluate(get_next()),
                                    seq_lens[start:start + batch_size],
                                    padded_shapes)

        # A final, smaller batch is only produced when it is not dropped.
        if not drop_remainder and len(seq_lens) % batch_size > 0:
            self._assertPaddedBatch(self.evaluate(get_next()),
                                    seq_lens[num_full_batches * batch_size:],
                                    padded_shapes)

        # The end-of-sequence error must be raised on every further call.
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    @combinations.generate(test_base.default_test_combinations())
    def testPaddedBatchShortPadding(self):
        """An element longer than the padded shape yields `DataLossError`."""
        dataset = (dataset_ops.Dataset.from_tensor_slices(
            [6, 5, 5, 5,
             5]).map(lambda x: array_ops.fill([x], x)).padded_batch(
                 batch_size=4, padded_shapes=[5]))
        self.assertDatasetProduces(dataset,
                                   expected_error=(errors.DataLossError, ''))

    @combinations.generate(test_base.default_test_combinations())
    def testPaddedBatchEmptyTensors(self):
        """A batch of zero-length elements is a batch of empty rows."""
        dataset = (dataset_ops.Dataset.from_tensor_slices(
            [0, 0, 0, 0]).map(lambda x: array_ops.fill([x], x)).padded_batch(
                batch_size=4, padded_shapes=[-1]))
        self.assertDatasetProduces(dataset, expected_output=[[[], [], [], []]])

    @combinations.generate(test_base.default_test_combinations())
    def testDefaultPaddedShapes(self):
        """Without `padded_shapes`, each batch pads to its longest element."""
        def fill(x):
            return array_ops.fill([x], x)

        dataset = (dataset_ops.Dataset.from_tensor_slices(
            [1, 2, 3, 4]).map(fill).padded_batch(batch_size=2))
        self.assertDatasetProduces(dataset,
                                   expected_output=[[[1, 0], [2, 2]],
                                                    [[3, 3, 3, 0],
                                                     [4, 4, 4, 4]]])

    @combinations.generate(test_base.default_test_combinations())
    def testNestedDefaultPaddedShapes(self):
        """Default padded shapes also work for tuple-structured elements."""
        def fill_tuple(x):
            return (x, array_ops.fill([x], x))

        dataset = (dataset_ops.Dataset.from_tensor_slices(
            [1, 2, 3, 4]).map(fill_tuple).padded_batch(batch_size=2))
        self.assertDatasetProduces(dataset,
                                   expected_output=[([1, 2], [[1, 0], [2, 2]]),
                                                    ([3, 4], [[3, 3, 3, 0],
                                                              [4, 4, 4, 4]])])

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(padding_values=[(-1, '<end>', {
                'structure': ''
            }), (-1, '<end>', None)])))
    def testPaddedBatchDatasetNonDefaultPadding(self, padding_values):
        """Custom `padding_values` are applied per component.

        Args:
          padding_values: Value forwarded to
            `padded_batch(padding_values=...)`. Both parameterizations are
            expected to pad the `'structure'` component with empty strings.
        """
        def fill_tuple(x):
            filled = array_ops.fill([x], x)
            return (filled, string_ops.as_string(filled), {
                'structure': string_ops.as_string(filled)
            })

        random_seq_lens = np.random.randint(20, size=(32, )).astype(np.int32)
        dataset = (dataset_ops.Dataset.from_tensor_slices(random_seq_lens).map(
            fill_tuple).padded_batch(4,
                                     padded_shapes=([-1], [-1], {
                                         'structure': [-1]
                                     }),
                                     padding_values=padding_values))

        get_next = self.getNext(dataset)
        for i in range(8):
            result = self.evaluate(get_next())
            # The integer component is filled with each element's length, so
            # its maximum is the padded length of this batch.
            padded_len = np.max(result[0])
            self.assertEqual((4, padded_len), result[0].shape)
            self.assertEqual((4, padded_len), result[1].shape)
            self.assertEqual((4, padded_len), result[2]['structure'].shape)
            for j in range(4):
                seq_len = random_seq_lens[(i * 4) + j]
                self.assertAllEqual(result[0][j, :seq_len],
                                    [seq_len] * seq_len)
                self.assertAllEqual(result[0][j, seq_len:],
                                    [-1] * (padded_len - seq_len))
                self.assertAllEqual(result[1][j, :seq_len],
                                    [compat.as_bytes(str(seq_len))] * seq_len)
                self.assertAllEqual(result[1][j, seq_len:],
                                    [b'<end>'] * (padded_len - seq_len))
                self.assertAllEqual(result[2]['structure'][j, :seq_len],
                                    [compat.as_bytes(str(seq_len))] * seq_len)
                self.assertAllEqual(result[2]['structure'][j, seq_len:],
                                    [b''] * (padded_len - seq_len))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    @combinations.generate(test_base.default_test_combinations())
    def testPaddedBatchDatasetUnicode(self):
        """Padding a batch of non-ASCII strings evaluates without error."""
        # See GitHub issue 16149
        def generator():
            data = [[u'Простой', u'тест', u'юникода'],
                    [u'никогда', u'не', u'бывает', u'простым']]

            for seq in data:
                yield seq, [0, 1, 2, 3]

        dataset = dataset_ops.Dataset.from_generator(
            generator, (dtypes.string, dtypes.int32),
            (tensor_shape.TensorShape([None]), tensor_shape.TensorShape([None
                                                                         ])))
        padded_dataset = dataset.padded_batch(2,
                                              padded_shapes=([None], [None]),
                                              padding_values=('', 0))
        next_element = self.getNext(padded_dataset)
        self.evaluate(next_element())

    @combinations.generate(test_base.graph_only_combinations())
    def testPaddedBatchDatasetShapeSpecifications(self):
        """Equivalent `padded_shapes` spellings give the same output shapes."""
        int_placeholder = array_ops.placeholder(dtypes.int32)
        float_placeholder = array_ops.placeholder(dtypes.float32)
        string_placeholder = array_ops.placeholder(dtypes.string)
        input_dataset = dataset_ops.Dataset.from_tensors(
            (int_placeholder, float_placeholder, string_placeholder))

        # Test different ways of specifying the `padded_shapes` argument.
        dynamic_padding_from_tensor_shapes = input_dataset.padded_batch(
            32,
            padded_shapes=(tensor_shape.TensorShape([None]),
                           tensor_shape.TensorShape([None, None]),
                           tensor_shape.TensorShape([37])))
        dynamic_padding_from_lists = input_dataset.padded_batch(
            32, padded_shapes=([None], [None, None], [37]))
        dynamic_padding_from_lists_with_minus_one = input_dataset.padded_batch(
            32, padded_shapes=([-1], [-1, -1], [37]))
        dynamic_padding_from_tensors = input_dataset.padded_batch(
            32,
            padded_shapes=(constant_op.constant([-1], dtype=dtypes.int64),
                           constant_op.constant([-1, -1], dtype=dtypes.int64),
                           constant_op.constant([37], dtype=dtypes.int64)))

        for dataset in [
                dynamic_padding_from_tensor_shapes, dynamic_padding_from_lists,
                dynamic_padding_from_lists_with_minus_one,
                dynamic_padding_from_tensors
        ]:
            dataset_output_shapes = dataset_ops.get_legacy_output_shapes(
                dataset)
            self.assertEqual([None, None], dataset_output_shapes[0].as_list())
            self.assertEqual([None, None, None],
                             dataset_output_shapes[1].as_list())
            self.assertEqual([None, 37], dataset_output_shapes[2].as_list())

    @combinations.generate(test_base.default_test_combinations())
    def testPaddedBatchSparseError(self):
        """`padded_batch()` raises `TypeError` for sparse tensor elements."""

        st = sparse_tensor.SparseTensorValue(indices=[[0, 0]],
                                             values=([42]),
                                             dense_shape=[1, 1])

        with self.assertRaises(TypeError):
            _ = dataset_ops.Dataset.from_tensors(st).repeat(10).padded_batch(
                10)

    @combinations.generate(test_base.default_test_combinations())
    def testPaddedBatchRaggedError(self):
        """`padded_batch()` raises `TypeError` for ragged tensor elements."""

        rt = ragged_tensor_value.RaggedTensorValue(
            np.array([0, 42]), np.array([0, 2], dtype=np.int64))

        with self.assertRaises(TypeError):
            _ = dataset_ops.Dataset.from_tensors(rt).repeat(10).padded_batch(
                10)

    @combinations.generate(test_base.default_test_combinations())
    def testPaddedBatchShapeErrorWrongRank(self):
        """A rank-1 padded shape is rejected for scalar elements."""
        with self.assertRaisesRegex(
                ValueError,
                r'The padded shape \(1,\) is not compatible with the '
                r'corresponding input component shape \(\).'):
            _ = dataset_ops.Dataset.range(10).padded_batch(5,
                                                           padded_shapes=[1])

    @combinations.generate(test_base.default_test_combinations())
    def testPaddedBatchShapeErrorTooSmall(self):
        """A padded shape smaller than the static element shape is rejected."""
        with self.assertRaisesRegex(
                ValueError,
                r'The padded shape \(1,\) is not compatible with the '
                r'corresponding input component shape \(3,\).'):
            _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
                5, padded_shapes=[1])

    @combinations.generate(test_base.default_test_combinations())
    def testPaddedBatchShapeErrorShapeNotRank1(self):
        """A padded shape that is not one-dimensional is rejected."""
        with self.assertRaisesRegex(
                ValueError, r'Padded shape .* must be a 1-D tensor '
                r'of tf.int64 values, but its shape was \(2, 2\).'):
            _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
                5, padded_shapes=[[1, 1], [1, 1]])

    @combinations.generate(test_base.default_test_combinations())
    def testPaddedBatchShapeErrorShapeNotInt(self):
        """A floating-point padded shape tensor is rejected."""
        with self.assertRaisesRegex(
                TypeError, r'Padded shape .* must be a 1-D tensor '
                r'of tf.int64 values, but its element type was float32.'):
            _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
                5, padded_shapes=constant_op.constant([1.5, 2., 3.]))

    @combinations.generate(test_base.default_test_combinations())
    def testPaddedBatchShapeErrorWrongRankFromTensor(self):
        """The rank check also applies when the shape is given as a tensor."""
        with self.assertRaisesRegex(
                ValueError,
                r'The padded shape \(1,\) is not compatible with the '
                r'corresponding input component shape \(\).'):
            shape_as_tensor = constant_op.constant([1], dtype=dtypes.int64)
            _ = dataset_ops.Dataset.range(10).padded_batch(
                5, padded_shapes=shape_as_tensor)

    @combinations.generate(test_base.default_test_combinations())
    def testPaddedBatchShapeErrorDefaultShapeWithUnknownRank(self):
        """Default padded shapes are rejected for unknown-rank elements."""
        with self.assertRaisesRegex(ValueError,
                                    r'`padded_shapes`.*unknown rank'):
            ds = dataset_ops.Dataset.from_generator(lambda: iter([1, 2, 3]),
                                                    output_types=dtypes.int32)
            ds.padded_batch(2)

    @combinations.generate(test_base.graph_only_combinations())
    def testPaddedBatchShapeErrorPlaceholder(self):
        """A length-2 placeholder shape is rejected for scalar elements."""
        with self.assertRaisesRegex(
                ValueError,
                r'The padded shape \((\?|None), (\?|None)\) is not compatible with the '
                r'corresponding input component shape \(\).'):
            shape_as_tensor = array_ops.placeholder(dtypes.int64, shape=[2])
            _ = dataset_ops.Dataset.range(10).padded_batch(
                5, padded_shapes=shape_as_tensor)

    @combinations.generate(test_base.default_test_combinations())
    def testPaddedBatchBfloat16(self):
        """`padded_batch()` works on bfloat16 elements."""
        ds = dataset_ops.Dataset.range(5)
        ds = ds.map(lambda x: math_ops.cast(x, dtypes.bfloat16))
        ds = ds.padded_batch(10)
        self.assertDatasetProduces(ds,
                                   expected_output=[[0.0, 1.0, 2.0, 3.0, 4.0]])

    @combinations.generate(test_base.default_test_combinations())
    def testDefaultPaddedValueShapes(self):
        """A scalar `padding_values` is applied to every tuple component."""
        def fill(x):
            return array_ops.fill([x], x)

        dataset = dataset_ops.Dataset.zip(
            (dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4]).map(fill),
             dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4]).map(fill)))
        dataset = dataset.padded_batch(batch_size=2, padding_values=-1)
        self.assertDatasetProduces(dataset,
                                   expected_output=[([[1, -1], [2,
                                                                2]], [[1, -1],
                                                                      [2, 2]]),
                                                    ([[3, 3, 3, -1],
                                                      [4, 4, 4,
                                                       4]], [[3, 3, 3, -1],
                                                             [4, 4, 4, 4]])])
示例#12
0
class FromTensorsTest(test_base.DatasetTestBase, parameterized.TestCase):
    """Tests for `Dataset.from_tensors()`."""

    @combinations.generate(test_base.default_test_combinations())
    def testFromTensors(self):
        """A tuple of dense arrays keeps its shapes and is produced once."""
        tensors = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
        ds = dataset_ops.Dataset.from_tensors(tensors)

        expected_shapes = [t.shape for t in tensors]
        actual_shapes = nest.flatten(dataset_ops.get_legacy_output_shapes(ds))
        self.assertEqual(expected_shapes, actual_shapes)

        self.assertDatasetProduces(ds, expected_output=[tensors])

    @combinations.generate(test_base.default_test_combinations())
    def testFromTensorsDataset(self):
        """A dataset wrapped as a single element can be flattened back."""
        inner = dataset_ops.Dataset.range(10)
        outer = dataset_ops.Dataset.from_tensors(inner)
        flattened = outer.flat_map(lambda x: x)
        self.assertDatasetProduces(flattened, expected_output=range(10))

    @combinations.generate(test_base.default_test_combinations())
    def testFromTensorsTensorArray(self):
        """A `TensorArray` can be used as the single element."""
        tensor_array = tensor_array_ops.TensorArray(dtypes.float32,
                                                    element_shape=(),
                                                    size=2)
        components = (tensor_array.unstack([1.0, 2.0]))

        ds = dataset_ops.Dataset.from_tensors(components)

        self.assertDatasetProduces(ds,
                                   expected_output=[[1.0, 2.0]],
                                   requires_initialization=True)

    @combinations.generate(test_base.default_test_combinations())
    def testFromTensorsSparse(self):
        """A tuple of sparse tensor values keeps its dense shapes."""
        rank1 = sparse_tensor.SparseTensorValue(indices=np.array([[0]]),
                                                values=np.array([0]),
                                                dense_shape=np.array([1]))
        rank2 = sparse_tensor.SparseTensorValue(
            indices=np.array([[0, 0], [1, 1]]),
            values=np.array([-1, 1]),
            dense_shape=np.array([2, 2]))
        sparse_components = (rank1, rank2)

        ds = dataset_ops.Dataset.from_tensors(sparse_components)

        expected_shapes = [
            tensor_shape.TensorShape(s.dense_shape) for s in sparse_components
        ]
        self.assertEqual(expected_shapes,
                         list(dataset_ops.get_legacy_output_shapes(ds)))
        self.assertDatasetProduces(ds, expected_output=[sparse_components])

    @combinations.generate(test_base.default_test_combinations())
    def testFromTensorsMixed(self):
        """A tuple mixing dense arrays and sparse values is supported."""
        rank1 = sparse_tensor.SparseTensorValue(indices=np.array([[0]]),
                                                values=np.array([0]),
                                                dense_shape=np.array([1]))
        rank2 = sparse_tensor.SparseTensorValue(
            indices=np.array([[0, 0], [1, 1]]),
            values=np.array([-1, 1]),
            dense_shape=np.array([2, 2]))
        mixed = (np.array(1), np.array([1, 2, 3]), np.array(37.0), rank1,
                 rank2)

        ds = dataset_ops.Dataset.from_tensors(mixed)

        # Sparse components are compared by dense shape, dense ones by their
        # array shape.
        expected_shapes = []
        for part in mixed:
            if sparse_tensor.is_sparse(part):
                expected_shapes.append(
                    tensor_shape.TensorShape(part.dense_shape))
            else:
                expected_shapes.append(part.shape)
        self.assertEqual(expected_shapes,
                         list(dataset_ops.get_legacy_output_shapes(ds)))

        self.assertDatasetProduces(ds, expected_output=[mixed])

    @combinations.generate(test_base.default_test_combinations())
    def testFromTensorsRagged(self):
        """A tuple of ragged tensor values is produced unchanged."""
        ragged_pair = (
            ragged_factory_ops.constant_value([[[0]], [[1]], [[2]]]),
            ragged_factory_ops.constant_value([[[3]], [[4]], [[5]]]),
        )

        ds = dataset_ops.Dataset.from_tensors(ragged_pair)

        self.assertDatasetProduces(ds, expected_output=[ragged_pair])

    @combinations.generate(test_base.default_test_combinations())
    def testFromTensorsNamedTuple(self):
        """A namedtuple element is produced unchanged."""
        Foo = collections.namedtuple("Foo", ["x", "y"])
        record = Foo(x=1, y=2)
        ds = dataset_ops.Dataset.from_tensors(record)
        self.assertDatasetProduces(ds, expected_output=[record])

    @combinations.generate(test_base.default_test_combinations())
    def testFromTensorsAttrs(self):
        """An `attr.s` class instance is produced unchanged.

        Skipped when the `attr` module is not available.
        """
        if attr is None:
            self.skipTest("attr module is not available.")

        @attr.s
        class Foo(object):
            x = attr.ib()
            y = attr.ib()

        record = Foo(x=1, y=2)
        ds = dataset_ops.Dataset.from_tensors(record)
        self.assertDatasetProduces(ds, expected_output=[record])

    @combinations.generate(test_base.default_test_combinations())
    def testFromTensorsMixedRagged(self):
        """Dense, sparse and ragged components can share one tuple."""
        rank1 = sparse_tensor.SparseTensorValue(indices=np.array([[0]]),
                                                values=np.array([0]),
                                                dense_shape=np.array([1]))
        rank2 = sparse_tensor.SparseTensorValue(
            indices=np.array([[0, 0], [1, 1]]),
            values=np.array([-1, 1]),
            dense_shape=np.array([2, 2]))
        ragged = ragged_factory_ops.constant_value([[[0]], [[1]], [[2]]])
        mixed = (np.array(1), np.array([1, 2, 3]), np.array(37.0), rank1,
                 rank2, ragged)

        ds = dataset_ops.Dataset.from_tensors(mixed)

        self.assertDatasetProduces(ds, expected_output=[mixed])

    @combinations.generate(
        combinations.combine(
            tf_api_version=[1],
            mode=["graph"],
            components=(np.array([1, 2, 3], dtype=np.int64),
                        (np.array([4., 5.]), np.array(
                            [6., 7.])), np.array([8, 9, 10], dtype=np.int64)),
            expected_shapes=[[[None, 3], [None, 3], [None, 2], [None, 2]]]) +
        combinations.combine(
            tf_api_version=[1],
            mode=["eager"],
            components=(np.array([1, 2, 3], dtype=np.int64),
                        (np.array([4., 5.]), np.array(
                            [6., 7.])), np.array([8, 9, 10], dtype=np.int64)),
            expected_shapes=[[[1, 3], [1, 3], [1, 2], [1, 2]]]))
    def testNestedStructure(self, components, expected_shapes):
        """Nested tuples keep dtypes and shapes through map/flat_map/batch.

        Args:
          components: The nested structure passed to `from_tensors()`.
          expected_shapes: The shapes expected for the four flattened outputs
            after batching, in `w, x, y, z` order.
        """
        ds = dataset_ops.Dataset.from_tensors(components)
        ds = ds.map(lambda a, b, c: ((a, c), (b[0], b[1])))
        ds = ds.flat_map(
            lambda left, right: dataset_ops.Dataset.from_tensors(
                ((left[0], left[1]), (right[0], right[1]))))
        ds = ds.batch(32)

        # The same checks are run against two separately created iterators.
        for _ in range(2):
            get_next = self.getNext(ds)
            (w, x), (y, z) = get_next()
            for int_output in (w, x):
                self.assertEqual(dtypes.int64, int_output.dtype)
            for float_output in (y, z):
                self.assertEqual(dtypes.float64, float_output.dtype)
            actual_shapes = [
                output.shape.as_list() for output in (w, x, y, z)
            ]
            self.assertEqual(expected_shapes, actual_shapes)

    @combinations.generate(test_base.default_test_combinations())
    def testNestedDict(self):
        """Nested dict elements report per-leaf output types and shapes."""
        nested = {"a": {"aa": 1, "ab": [2.0, 2.0]}, "b": [3, 3, 3]}
        ds = dataset_ops.Dataset.from_tensors(nested)

        output_types = dataset_ops.get_legacy_output_types(ds)
        self.assertEqual(dtypes.int32, output_types["a"]["aa"])
        self.assertEqual(dtypes.float32, output_types["a"]["ab"])
        self.assertEqual(dtypes.int32, output_types["b"])

        output_shapes = dataset_ops.get_legacy_output_shapes(ds)
        self.assertEqual([], output_shapes["a"]["aa"])
        self.assertEqual([2], output_shapes["a"]["ab"])
        self.assertEqual([3], output_shapes["b"])

    @combinations.generate(test_base.default_test_combinations())
    def testNonSequenceNestedStructure(self):
        """A single array element keeps int64 type through several stages."""
        vector = np.array([1, 2, 3], dtype=np.int64)

        def assert_int64_with_shape(ds, expected_shape):
            # Every stage must report int64 elements of the given shape.
            self.assertEqual(dtypes.int64,
                             dataset_ops.get_legacy_output_types(ds))
            self.assertEqual(expected_shape,
                             dataset_ops.get_legacy_output_shapes(ds))

        pipeline = dataset_ops.Dataset.from_tensors(vector)
        assert_int64_with_shape(pipeline, [3])

        pipeline = pipeline.filter(
            lambda x: math_ops.reduce_all(math_ops.equal(x, vector)))
        assert_int64_with_shape(pipeline, [3])

        pipeline = pipeline.map(lambda x: array_ops.stack([x, x]))
        assert_int64_with_shape(pipeline, [2, 3])

        pipeline = pipeline.flat_map(
            lambda x: dataset_ops.Dataset.from_tensor_slices(x))
        assert_int64_with_shape(pipeline, [3])

        get_next = self.getNext(pipeline)
        self.assertEqual(dtypes.int64, get_next().dtype)
        self.assertEqual([3], get_next().shape)

    # TODO(b/121264236): needs mechanism for multiple device in eager mode.
    @combinations.generate(test_base.graph_only_combinations())
    def testSplitPipeline(self):
        """A pipeline may read variables created under two CPU devices."""
        two_cpu_config = config_pb2.ConfigProto(device_count={"CPU": 2})
        with session.Session(target="", config=two_cpu_config) as sess:

            pipeline = dataset_ops.Dataset.from_tensors(0)

            # Define a pipeline that attempts to use variables on two
            # different devices.
            #
            # Initialize the variables before creating the iterator, to avoid
            # the placement algorithm overriding the DT_RESOURCE colocation
            # constraints.
            with ops.device("/cpu:0"):
                first_var = resource_variable_ops.ResourceVariable(
                    initial_value=1)
            pipeline = pipeline.map(lambda x: x + first_var.read_value())
            sess.run(first_var.initializer)

            with ops.device("/cpu:1"):
                second_var = resource_variable_ops.ResourceVariable(
                    initial_value=1)
            pipeline = pipeline.map(lambda x: x + second_var.read_value())
            sess.run(second_var.initializer)

            iterator = dataset_ops.make_initializable_iterator(pipeline)
            sess.run(iterator.initializer)

            self.assertEqual(sess.run(iterator.get_next()), 2)

    @combinations.generate(test_base.default_test_combinations())
    def testName(self):
        """`from_tensors()` accepts a `name` argument."""
        named = dataset_ops.Dataset.from_tensors(42, name="from_tensors")
        self.assertDatasetProduces(named, [42])
class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationStatefulFunction(self):
        """Pulls one batch from a pipeline whose map function is random.

        Default optimizations are switched off; the test only checks that
        evaluating the first element does not raise.
        """
        random_values = dataset_ops.Dataset.range(10).map(
            lambda _: random_ops.random_uniform([]))
        batched = random_values.batch(10)

        opts = dataset_ops.Options()
        opts.experimental_optimization.apply_default_optimizations = False

        next_element = self.getNext(batched.with_options(opts))
        self.evaluate(next_element())

    # TODO(b/123902160)
    @combinations.generate(test_base.graph_only_combinations())
    def testOptimizationLargeInputFromTensor(self):
        """Feeds a large placeholder-backed tensor through `from_tensors`.

        Default optimizations are disabled. The fed array holds
        512 * 1024 * 1025 int32 values, i.e. slightly more than 2**31 bytes.
        """
        placeholder = array_ops.placeholder(dtypes.int32, (None, None, None))
        dataset = dataset_ops.Dataset.from_tensors(placeholder)

        opts = dataset_ops.Options()
        opts.experimental_optimization.apply_default_optimizations = False

        iterator = dataset_ops.make_initializable_iterator(
            dataset.with_options(opts))
        initializer = iterator.initializer
        next_element = iterator.get_next()

        with self.cached_session() as sess:
            # The array is built inline so it is not kept alive by a local
            # while the element is being evaluated.
            sess.run(
                initializer,
                feed_dict={placeholder: np.ones([512, 1024, 1025], np.int32)})
            self.evaluate(next_element)

    # TODO(b/123902160)
    @combinations.generate(test_base.graph_only_combinations())
    def testOptimizationLargeInputFromTensorSlices(self):
        """Feeds a large placeholder-backed tensor through `from_tensor_slices`.

        Default optimizations are disabled. The fed array has a leading
        dimension of 1 followed by 512 * 1024 * 1025 int32 values, i.e.
        slightly more than 2**31 bytes.
        """
        placeholder = array_ops.placeholder(dtypes.int32,
                                            (None, None, None, None))
        dataset = dataset_ops.Dataset.from_tensor_slices(placeholder)

        opts = dataset_ops.Options()
        opts.experimental_optimization.apply_default_optimizations = False

        iterator = dataset_ops.make_initializable_iterator(
            dataset.with_options(opts))
        initializer = iterator.initializer
        next_element = iterator.get_next()

        with self.cached_session() as sess:
            # The array is built inline so it is not kept alive by a local
            # while the element is being evaluated.
            sess.run(
                initializer,
                feed_dict={
                    placeholder: np.ones([1, 512, 1024, 1025], np.int32)
                })
            self.evaluate(next_element)

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationNestedDataset(self):
        """Checks that noop elimination is applied inside a `flat_map` function.

        The nested pipeline asserts that the dataset following the
        `assert_next` marker is "MemoryCacheImpl", which only holds if the
        intervening `skip(0)` has been removed.
        """
        def make_nested(_):
            nested = dataset_ops.Dataset.from_tensors(0)
            nested = nested.apply(testing.assert_next(["MemoryCacheImpl"]))
            # `skip(0)` should be removed by noop elimination, leaving the
            # cache directly after the assertion.
            return nested.skip(0).cache()

        opts = dataset_ops.Options()
        opts.experimental_optimization.apply_default_optimizations = False
        opts.experimental_optimization.noop_elimination = True

        outer = dataset_ops.Dataset.range(1).flat_map(make_nested)
        self.assertDatasetProduces(outer.with_options(opts),
                                   expected_output=[0])

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationNestedDatasetWithModifiedRetval(self):
        """Checks map-and-batch fusion inside a `flat_map` function.

        The nested pipeline asserts that the dataset following the
        `assert_next` marker is "MapAndBatch", i.e. that the identity `map`
        and the `batch(1)` were fused.
        """
        def make_nested(_):
            nested = dataset_ops.Dataset.from_tensors(0)
            nested = nested.apply(testing.assert_next(["MapAndBatch"]))
            # The map/batch pair below should be fused by map and batch
            # fusion.
            return nested.map(lambda x: x).batch(1)

        opts = dataset_ops.Options()
        opts.experimental_optimization.apply_default_optimizations = False
        opts.experimental_optimization.map_and_batch_fusion = True

        outer = dataset_ops.Dataset.range(1).flat_map(make_nested)
        self.assertDatasetProduces(outer.with_options(opts),
                                   expected_output=[[0]])

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationDoubleOptimizeDatasetNested(self):
        """Stacks two `_OptimizeDataset` wrappers over a nested pipeline.

        The outer dataset is wrapped first with "map_and_batch_fusion" and
        then with "noop_elimination"; the nested pipeline asserts that
        "MapAndBatch" directly follows the `assert_next` marker.
        """
        def make_nested(_):
            nested = dataset_ops.Dataset.from_tensors(0)
            nested = nested.apply(testing.assert_next(["MapAndBatch"]))
            nested = nested.skip(0)
            # The map/batch pair below should be fused by map and batch
            # fusion.
            return nested.map(lambda x: x).batch(1)

        dataset = dataset_ops.Dataset.from_tensors(0).flat_map(make_nested)
        # Apply the two rewrites as separate, successive optimize wrappers.
        for rewrite in ("map_and_batch_fusion", "noop_elimination"):
            dataset = dataset_ops._OptimizeDataset(dataset, [rewrite], [], [])

        self.assertDatasetProduces(dataset, expected_output=[[0]])

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationDifferentOrderOptionsCompareEqual(self):
        """Checks that the order of requested rewrites does not change the graph.

        Two fresh graphs are built with the same pair of rewrites listed in
        opposite orders; their graph defs are expected to compare equal.
        """
        def graph_def_for(rewrites):
            # Build the optimize wrapper in its own graph and serialize it.
            with ops.Graph().as_default() as graph:
                source = dataset_ops.Dataset.from_tensors(0)
                dataset_ops._OptimizeDataset(source, rewrites, [], [])
            return graph.as_graph_def()

        first_def = graph_def_for(["map_and_batch_fusion", "noop_elimination"])
        second_def = graph_def_for(
            ["noop_elimination", "map_and_batch_fusion"])

        self.assertEqual(first_def, second_def)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            _disable_intra_op_parallelism_test_combinations(),
            combinations.combine(apply_autotune=[None, True, False])))
    def testOptimizationDisableIntraOpParallelism(self, dataset_fn,
                                                  expected_output,
                                                  apply_autotune):
        """Checks that "MaxIntraOpParallelism" follows the user's pipeline.

        Args:
          dataset_fn: Zero-argument callable returning the dataset under test.
          expected_output: Elements the dataset is expected to produce.
          apply_autotune: If not None, the value assigned to
            `experimental_optimization.autotune`; if None, no options are
            attached.
        """
        pipeline = dataset_fn().apply(
            testing.assert_next(["MaxIntraOpParallelism"]))

        if apply_autotune is not None:
            opts = dataset_ops.Options()
            opts.experimental_optimization.autotune = apply_autotune
            pipeline = pipeline.with_options(opts)

        self.assertDatasetProduces(pipeline, expected_output=expected_output)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(autotune=False, autotune_buffers=False) +
            combinations.combine(autotune=True, autotune_buffers=False) +
            combinations.combine(autotune=True, autotune_buffers=True),
            combinations.combine(set_env=[False, True])))
    def testOptimizationEnableGradientDescent(self, autotune, autotune_buffers,
                                              set_env):
        """Runs a prefetch/map pipeline with and without the experiment opt-in.

        Args:
          autotune: Value assigned to `experimental_optimization.autotune`.
          autotune_buffers: Value assigned to
            `experimental_optimization.autotune_buffers`.
          set_env: If true, `TF_DATA_EXPERIMENT_OPT_IN` and `TF_JOB_NAME` are
            set in `os.environ` for the duration of the test.
        """
        env_overrides = {}
        if set_env:
            env_overrides = {
                "TF_DATA_EXPERIMENT_OPT_IN": "enable_gradient_descent",
                "TF_JOB_NAME": "test_job",
            }
        os.environ.update(env_overrides)

        # `os.environ` is process-global: undo the overrides in `finally` so a
        # failing assertion cannot leak the opt-in into later tests.
        try:
            dataset = dataset_ops.Dataset.range(5)
            dataset = dataset.prefetch(buffer_size=-1)
            dataset = dataset.map(lambda x: x + 1, num_parallel_calls=2)
            dataset = dataset.map(lambda x: x + 1, num_parallel_calls=-1)
            dataset = dataset.prefetch(buffer_size=3)
            dataset = dataset.map(lambda x: x + 1, num_parallel_calls=-1)
            dataset = dataset.prefetch(buffer_size=1)

            options = dataset_ops.Options()
            options.experimental_optimization.autotune = autotune
            options.experimental_optimization.autotune_buffers = (
                autotune_buffers)
            dataset = dataset.with_options(options)

            # Three `+ 1` maps over range(5) yield 3..7.
            self.assertDatasetProduces(dataset,
                                       expected_output=list(range(3, 8)))
        finally:
            for name in env_overrides:
                os.environ.pop(name, None)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(autotune=[True, False, None]),
            combinations.combine(map_parallelization=[True, False, None])))
    def testOptimizationMapParallelization(self, autotune,
                                           map_parallelization):
        """Checks whether a sequential `map` is rewritten to a parallel one.

        "ParallelMap" is expected unless `autotune` or `map_parallelization`
        is explicitly False, in which case "Map" is expected.

        Args:
          autotune: True, False, or None (leave the option unset).
          map_parallelization: True, False, or None (leave the option unset).
        """
        if autotune is False or map_parallelization is False:  # pylint: disable=g-bool-id-comparison
            expected_op = "Map"
        else:
            expected_op = "ParallelMap"

        dataset = dataset_ops.Dataset.range(5).apply(
            testing.assert_next([expected_op]))
        dataset = dataset.map(lambda x: x + 1)

        opts = dataset_ops.Options()
        # Only touch the options that the combination explicitly specifies.
        if autotune is not None:
            opts.experimental_optimization.autotune = autotune
        if map_parallelization is not None:
            opts.experimental_optimization.map_parallelization = (
                map_parallelization)

        self.assertDatasetProduces(dataset.with_options(opts),
                                   expected_output=list(range(1, 6)))

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(set_env=[True, False])))
    def testOptimizationUsePrivateThreadPool(self, set_env):
        """Checks the datasets appended to the pipeline with/without the opt-in.

        With the opt-in set, the chain after the user's pipeline is expected
        to be "MaxIntraOpParallelism", "PrivateThreadPool", "Model"; without
        it, "MaxIntraOpParallelism", "Model".

        Args:
          set_env: If true, `TF_DATA_EXPERIMENT_OPT_IN` and `TF_JOB_NAME` are
            set in `os.environ` for the duration of the test.
        """
        env_overrides = {}
        if set_env:
            env_overrides = {
                "TF_DATA_EXPERIMENT_OPT_IN": "use_private_thread_pool",
                "TF_JOB_NAME": "test_job",
            }
        os.environ.update(env_overrides)

        # `os.environ` is process-global: undo the overrides in `finally` so a
        # failing assertion cannot leak the opt-in into later tests.
        try:
            dataset = dataset_ops.Dataset.range(6)
            if set_env:
                dataset = dataset.apply(
                    testing.assert_next([
                        "MaxIntraOpParallelism", "PrivateThreadPool", "Model"
                    ]))
            else:
                dataset = dataset.apply(
                    testing.assert_next(["MaxIntraOpParallelism", "Model"]))

            self.assertDatasetProduces(dataset,
                                       expected_output=list(range(6)))
        finally:
            for name in env_overrides:
                os.environ.pop(name, None)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(autotune=False, autotune_buffers=False) +
            combinations.combine(autotune=True, autotune_buffers=False) +
            combinations.combine(autotune=True, autotune_buffers=True),
            combinations.combine(first_buffer_sizes=[(1, -1, -1,
                                                      4), (2, -1, 3,
                                                           -1), (2, 1, -1,
                                                                 -1)]),
            combinations.combine(second_buffer_sizes=[(1, -1, -1,
                                                       4), (2, -1, 3,
                                                            -1), (2, 1, -1,
                                                                  -1)])))
    def testOptimizationAutotuneBuffers(self, autotune, autotune_buffers,
                                        first_buffer_sizes,
                                        second_buffer_sizes):
        """Runs stacked `prefetch` stages around a `map` under autotune options.

        Args:
          autotune: Value assigned to `experimental_optimization.autotune`.
          autotune_buffers: Value assigned to
            `experimental_optimization.autotune_buffers`.
          first_buffer_sizes: Buffer sizes of the prefetches placed before the
            map, applied in order.
          second_buffer_sizes: Buffer sizes of the prefetches placed after the
            map, applied in order.
        """
        def with_prefetches(ds, buffer_sizes):
            # Stack one `prefetch` per requested buffer size, in order.
            for size in buffer_sizes:
                ds = ds.prefetch(buffer_size=size)
            return ds

        dataset = with_prefetches(dataset_ops.Dataset.range(10),
                                  first_buffer_sizes)
        dataset = dataset.map(lambda x: x + 1)
        dataset = with_prefetches(dataset, second_buffer_sizes)

        opts = dataset_ops.Options()
        opts.experimental_optimization.autotune = autotune
        opts.experimental_optimization.autotune_buffers = autotune_buffers

        self.assertDatasetProduces(dataset.with_options(opts),
                                   expected_output=list(range(1, 11)))

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationThreadPoolDataset(self):
        """Runs a batched range on a private thread pool, optimizations off."""
        batched = dataset_ops.Dataset.range(10).batch(10)

        pool = threadpool.PrivateThreadPool(
            2, display_name="private_thread_pool_%d" % 2)
        pooled = threadpool.override_threadpool(batched, pool)

        opts = dataset_ops.Options()
        opts.experimental_optimization.apply_default_optimizations = False

        self.assertDatasetProduces(pooled.with_options(opts),
                                   expected_output=[list(range(10))],
                                   requires_initialization=True)

    # Reference variables are not supported in eager mode.
    @combinations.generate(
        combinations.times(test_base.graph_only_combinations(),
                           _captured_refvar_test_combinations()))
    def testOptimizationWithCapturedRefVar(self, dataset_fn):
        """Tests that default optimizations are disabled with ref variables.

        Args:
          dataset_fn: Callable that takes the ref variable and returns the
            dataset under test.
        """
        variable = variable_scope.get_variable("v",
                                               initializer=0,
                                               use_resource=False)
        assign_op = variable.assign_add(1)

        # Check that warning is logged.
        with warnings.catch_warnings(record=True) as w:
            # Install the "always" filter inside the context manager so it is
            # undone on exit rather than permanently changing the process-wide
            # warning filters for every test that runs afterwards.
            warnings.simplefilter("always")
            unoptimized_dataset = dataset_fn(variable)

            options = dataset_ops.Options()
            options.experimental_optimization.apply_default_optimizations = False
            options.experimental_optimization.noop_elimination = True
            options.experimental_optimization.map_and_batch_fusion = True
            optimized_dataset = unoptimized_dataset.with_options(options)
            optimized_it = dataset_ops.make_initializable_iterator(
                optimized_dataset)

        self.assertGreaterEqual(len(w), 1)
        graph_rewrites = options._graph_rewrites()
        expected = (
            "tf.data graph rewrites are not compatible with "
            "tf.Variable. The following rewrites will be disabled: %s."
            " To enable rewrites, use resource variables instead by "
            "calling `tf.enable_resource_variables()` at the start of the "
            "program." %
            (", ".join(graph_rewrites.enabled + graph_rewrites.default)))
        self.assertTrue(any(expected in str(warning) for warning in w))

        # Check that outputs are the same in the optimized and unoptimized cases,
        # when the variable value is changing.
        unoptimized_it = dataset_ops.make_initializable_iterator(
            unoptimized_dataset)
        with ops.control_dependencies([assign_op]):
            unoptimized_output = unoptimized_it.get_next()
            optimized_output = optimized_it.get_next()

        self.evaluate(variable.initializer)
        self.evaluate((unoptimized_it.initializer, optimized_it.initializer))
        # Compare the two iterators element by element until both run out.
        while True:
            try:
                unoptimized, optimized = self.evaluate(
                    (unoptimized_output, optimized_output))
                self.assertEqual(unoptimized, optimized)
            except errors.OutOfRangeError:
                break

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationDefault(self):
        """Tests the optimization settings by default."""
        options = dataset_ops.Options()
        default_rewrites = [
            "map_and_batch_fusion",
            "map_parallelization",
            "noop_elimination",
            "shuffle_and_repeat_fusion",
        ]

        def check_rewrites(expected_default):
            # Nothing is explicitly enabled or disabled in this test; only
            # the set of default rewrites varies between the three checks.
            rewrites = options._graph_rewrites()
            self.assertEqual(set(rewrites.enabled), set())
            self.assertEqual(set(rewrites.disabled), set())
            self.assertEqual(set(rewrites.default), set(expected_default))

        # Untouched options.
        check_rewrites(default_rewrites)

        # Explicitly requesting the default optimizations changes nothing.
        options.experimental_optimization.apply_default_optimizations = True
        check_rewrites(default_rewrites)

        # Turning them off empties the default set.
        options.experimental_optimization.apply_default_optimizations = False
        check_rewrites([])

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationEnabled(self):
        """Tests the optimization settings by enabling all."""
        # Options whose attribute name is also the name of the rewrite.
        rewrite_flags = (
            "filter_fusion",
            "filter_with_random_uniform_fusion",
            "hoist_random_uniform",
            "map_and_batch_fusion",
            "map_and_filter_fusion",
            "map_parallelization",
            "map_fusion",
            "noop_elimination",
            "parallel_batch",
            "shuffle_and_repeat_fusion",
        )

        options = dataset_ops.Options()
        optimization = options.experimental_optimization
        for flag in rewrite_flags:
            setattr(optimization, flag, True)
        optimization.map_vectorization.enabled = True
        optimization.autotune_buffers = True
        options.experimental_deterministic = False
        options.experimental_stats.latency_all_edges = True
        options.experimental_slack = True

        expected_enabled = set(rewrite_flags) | {
            "map_vectorization",
            "autotune_buffer_sizes",
            "make_sloppy",
            "latency_all_edges",
            "slack",
            "disable_prefetch_legacy_autotune",
        }

        rewrites = options._graph_rewrites()
        self.assertEqual(set(rewrites.enabled), expected_enabled)
        self.assertEqual(set(rewrites.disabled), set())
        self.assertEqual(set(rewrites.default), set())

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationDisabled(self):
        """Tests the optimization settings by disabling all."""
        # Options whose attribute name is also the name of the rewrite.
        rewrite_flags = (
            "filter_fusion",
            "filter_with_random_uniform_fusion",
            "hoist_random_uniform",
            "map_and_batch_fusion",
            "map_and_filter_fusion",
            "map_parallelization",
            "map_fusion",
            "noop_elimination",
            "parallel_batch",
            "shuffle_and_repeat_fusion",
        )

        options = dataset_ops.Options()
        optimization = options.experimental_optimization
        for flag in rewrite_flags:
            setattr(optimization, flag, False)
        optimization.map_vectorization.enabled = False
        optimization.autotune = False
        options.experimental_deterministic = True
        options.experimental_stats.latency_all_edges = False
        options.experimental_slack = False

        expected_disabled = set(rewrite_flags) | {
            "map_vectorization",
            "autotune_buffer_sizes",
            "make_sloppy",
            "latency_all_edges",
            "slack",
            "disable_prefetch_legacy_autotune",
        }

        rewrites = options._graph_rewrites()
        self.assertEqual(set(rewrites.enabled), set())
        self.assertEqual(set(rewrites.disabled), expected_disabled)
        self.assertEqual(set(rewrites.default), set())

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(autotune=[True, False, None]),
            combinations.combine(autotune_buffers=[True, False, None])))
    def testAutotuningSettings(self, autotune, autotune_buffers):
        """Checks the resolved autotune flags for set and unset options.

        Autotuning is expected to be on unless `autotune` is explicitly
        False; buffer autotuning is expected to be on only when
        `autotune_buffers` is explicitly True.

        Args:
          autotune: True, False, or None (leave the option unset).
          autotune_buffers: True, False, or None (leave the option unset).
        """
        options = dataset_ops.Options()
        optimization = options.experimental_optimization
        if autotune is not None:
            optimization.autotune = autotune
        if autotune_buffers is not None:
            optimization.autotune_buffers = autotune_buffers

        # Check defaults
        resolved_autotune = options._autotune_settings()[0]
        resolved_buffers = optimization._autotune_buffers()

        if autotune is False:  # pylint: disable=g-bool-id-comparison
            self.assertFalse(resolved_autotune)
        else:
            self.assertTrue(resolved_autotune)

        if autotune_buffers is True:  # pylint: disable=g-bool-id-comparison
            self.assertTrue(resolved_buffers)
        else:
            self.assertFalse(resolved_buffers)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(autotune_buffers=[True, False, None])))
    def testAutotuneBuffersSettings(self, autotune_buffers):
        """Checks the rewrites and algorithm selected by `autotune_buffers`.

        When the option is explicitly True, both buffer-related rewrites are
        expected to be enabled and the algorithm to be GRADIENT_DESCENT;
        otherwise neither rewrite is expected and the algorithm is
        HILL_CLIMB.

        Args:
          autotune_buffers: True, False, or None (leave the option unset).
        """
        options = dataset_ops.Options()
        if autotune_buffers is not None:
            options.experimental_optimization.autotune_buffers = autotune_buffers

        enabled_rewrites = options._graph_rewrites().enabled
        algorithm = options._autotune_settings()[1]

        if autotune_buffers is True:  # pylint: disable=g-bool-id-comparison
            check_membership = self.assertIn
            expected_algorithm = (
                optimization_options._AutotuneAlgorithm.GRADIENT_DESCENT)
        else:
            check_membership = self.assertNotIn
            expected_algorithm = (
                optimization_options._AutotuneAlgorithm.HILL_CLIMB)

        for rewrite in ("autotune_buffer_sizes",
                        "disable_prefetch_legacy_autotune"):
            check_membership(rewrite, enabled_rewrites)
        self.assertEqual(algorithm, expected_algorithm)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(set_budget=[True, False]),
        ))
    def testResourceBudgets(self, set_budget):
        """Checks the CPU and RAM budgets reported by `_autotune_settings`.

        Args:
          set_budget: If true, explicit budgets are assigned and expected
            back; otherwise both budgets are expected to be 0.
        """
        options = dataset_ops.Options()
        expected_cpu, expected_ram = 0, 0
        if set_budget:
            expected_cpu, expected_ram = 1000, 999999999
            options.experimental_optimization.autotune_cpu_budget = 1000
            options.experimental_optimization.autotune_ram_budget = 999999999

        settings = options._autotune_settings()
        # Positions 2 and 3 of the settings hold the CPU and RAM budgets.
        cpu_budget, ram_budget = settings[2], settings[3]

        self.assertEqual(cpu_budget, expected_cpu)
        self.assertEqual(ram_budget, expected_ram)
示例#14
0
class DataServiceOpsTest(data_service_test_base.TestBase,
                         parameterized.TestCase):

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         data_service_test_base.all_cluster_configurations()))
  def testDistributeBasic(self, work_dir, fault_tolerant_mode):
    """Reads a distributed range dataset from a one-worker cluster.

    Args:
      work_dir: Forwarded to `TestCluster` as `work_dir`.
      fault_tolerant_mode: Forwarded to `TestCluster` as
        `fault_tolerant_mode`.
    """
    element_count = 10
    service = data_service_test_base.TestCluster(
        num_workers=1,
        work_dir=work_dir,
        fault_tolerant_mode=fault_tolerant_mode)
    distributed = self.make_distributed_range_dataset(element_count, service)
    self.assertDatasetProduces(distributed, list(range(element_count)))

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(compression=[None, "AUTO"])))
  def testDistributeCompression(self, compression):
    """Reads a distributed range dataset with the given compression setting.

    Args:
      compression: Forwarded to `make_distributed_range_dataset`.
    """
    element_count = 10
    service = data_service_test_base.TestCluster(num_workers=1)
    distributed = self.make_distributed_range_dataset(
        element_count, service, compression=compression)
    self.assertDatasetProduces(distributed, list(range(element_count)))

  @combinations.generate(test_base.default_test_combinations())
  def testDistributeInvalidCompression(self):
    """Checks that an unknown compression value raises a `ValueError`."""
    service = data_service_test_base.TestCluster(num_workers=1)
    expected_message = "Invalid compression argument"
    with self.assertRaisesRegex(ValueError, expected_message):
      self.make_distributed_range_dataset(10, service, compression="foo")

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeSparse(self):
    """Sends a single `SparseTensor` element through a distributed dataset."""
    service = data_service_test_base.TestCluster(num_workers=1)
    sparse_element = sparse_tensor.SparseTensor(
        indices=[[0]],
        values=constant_op.constant([0], dtype=dtypes.int32),
        dense_shape=[1])
    distributed = self.make_distributed_dataset(
        dataset_ops.Dataset.from_tensors(sparse_element), service)
    # Densify each received element so it can be compared directly.
    dense_results = [
        sparse_ops.sparse_tensor_to_dense(received) for received in distributed
    ]
    self.assertAllEqual(dense_results, [[0]])

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeRagged(self):
    """Sends ragged batches of ranges through a distributed dataset.

    Ranges of length 1, 5, 3, 2 and 8 are grouped into ragged batches of
    two; the first three received batches are compared after `to_tensor()`.
    """
    service = data_service_test_base.TestCluster(num_workers=1)
    ragged = dataset_ops.Dataset.from_tensor_slices([1, 5, 3, 2, 8])
    ragged = ragged.map(math_ops.range)
    ragged = ragged.apply(batching.dense_to_ragged_batch(2))
    distributed = self.make_distributed_dataset(ragged, service)
    padded_batches = [batch.to_tensor() for batch in distributed]

    expected_batches = [
        [[0, 0, 0, 0, 0], [0, 1, 2, 3, 4]],
        [[0, 1, 2], [0, 1, 0]],
        [[0, 1, 2, 3, 4, 5, 6, 7]],
    ]
    for position, expected in enumerate(expected_batches):
      self.assertAllEqual(padded_batches[position], expected)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              init_source=["textfile", "keyvaluetensor", "dataset"])))
  def testDistributeLookupTable(self, init_source):
    """Looks up keys in a `StaticHashTable` through a distributed dataset.

    Keys 0 and 1 are expected to map to 10 and 11, and key 2 to the table's
    default value of -1.

    Args:
      init_source: Forwarded to `lookupTableInitializer` to select how the
        table is initialized.
    """
    service = data_service_test_base.TestCluster(num_workers=1)
    table_init = self.lookupTableInitializer(init_source, [10, 11])
    table = lookup_ops.StaticHashTable(table_init, -1)
    lookups = dataset_ops.Dataset.range(3).map(table.lookup)
    lookups = self.make_distributed_dataset(lookups, service)
    self.evaluate(lookup_ops.tables_initializer())
    self.assertDatasetProduces(
        lookups, [10, 11, -1], requires_initialization=True)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(value_rank=[0, 1])))
  def testDistributeMutableHashTable(self, value_rank):
    """Looks up keys in a `MutableHashTable` through a distributed dataset.

    Args:
      value_rank: Number of times each scalar value is wrapped into a
        two-element list before being stored in the table.
    """

    def nest(scalar):
      # Wrap the value in `[v, v]` once per requested rank.
      nested = scalar
      for _ in range(value_rank):
        nested = [nested, nested]
      return nested

    first_value, second_value, missing_value = nest(10), nest(11), nest(-1)

    service = data_service_test_base.TestCluster(num_workers=1)
    table = lookup_ops.MutableHashTable(dtypes.int64, dtypes.int64,
                                        missing_value)
    self.evaluate(table.insert([0, 1], [first_value, second_value]))
    lookups = dataset_ops.Dataset.range(3).map(table.lookup)
    lookups = self.make_distributed_dataset(lookups, service)
    self.assertDatasetProduces(
        lookups, [first_value, second_value, missing_value],
        requires_initialization=True)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(shuffle_seed=[None, 10])))
  def testShuffleOrder(self, shuffle_seed):
    """Compares the shuffle orders seen from two workers.

    Args:
      shuffle_seed: Seed passed to `shuffle`. With None the two orders are
        expected to differ; with a fixed seed they are expected to match.
    """
    random_seed.set_random_seed(None)
    element_count = 100
    service = data_service_test_base.TestCluster(num_workers=2)
    shuffled = dataset_ops.Dataset.range(element_count).shuffle(
        element_count, seed=shuffle_seed)
    output = self.getDatasetOutput(
        self.make_distributed_dataset(shuffled, service))

    # The output will be two sequences of range(element_count)
    # non-deterministically interleaved together. Record the position of the
    # first and of the second occurrence of every element; if both sequences
    # had the same order, the two mappings will be equal.
    first_seen = {}
    second_seen = {}
    for element in output:
      positions = second_seen if element in first_seen else first_seen
      positions[element] = len(positions)

    if shuffle_seed is None:
      compare = self.assertNotEqual
    else:
      compare = self.assertEqual
    compare(first_seen, second_seen)

  @combinations.generate(test_base.default_test_combinations())
  def testMultipleEpochs(self):
    """Reads the same distributed dataset ten times in a row."""
    element_count = 3
    epoch_count = 10
    service = data_service_test_base.TestCluster(num_workers=1)
    distributed = self.make_distributed_range_dataset(element_count, service)
    expected = list(range(element_count))
    for _ in range(epoch_count):
      self.assertDatasetProduces(distributed, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testRepeatedDataset(self):
    """Applies `repeat` on top of a distributed range dataset."""
    element_count = 10
    repetitions = 5
    service = data_service_test_base.TestCluster(num_workers=1)
    repeated = self.make_distributed_range_dataset(
        element_count, service).repeat(repetitions)
    one_pass = list(range(element_count))
    self.assertDatasetProduces(
        repeated, expected_output=repetitions * one_pass)

  @combinations.generate(test_base.default_test_combinations())
  def testConcurrentEpoch(self):
    """Reads three distributed datasets in lockstep from one cluster.

    One element is pulled from each dataset in turn; every dataset is
    expected to yield the full range in order.
    """
    service = data_service_test_base.TestCluster(num_workers=1)
    element_count = 10
    dataset_count = 3

    next_fns = [
        self.getNext(
            self.make_distributed_range_dataset(element_count, service))
        for _ in range(dataset_count)
    ]
    collected = [[] for _ in next_fns]

    # Round-robin over the iterators so that the epochs overlap in time.
    for _ in range(element_count):
      for next_fn, values in zip(next_fns, collected):
        values.append(self.evaluate(next_fn()))

    expected = list(range(element_count))
    for values in collected:
      self.assertEqual(expected, values)

  @combinations.generate(test_base.default_test_combinations())
  def testMultiWorker(self):
    """Reads a distributed range dataset from a three-worker cluster.

    The output is expected to contain the range once per worker, in any
    order.
    """
    worker_count = 3
    element_count = 10
    service = data_service_test_base.TestCluster(num_workers=worker_count)
    distributed = self.make_distributed_range_dataset(element_count, service)
    expected = worker_count * list(range(element_count))
    self.assertDatasetProduces(
        distributed, expected, assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testMaxOutstandingRequests(self):
    """Reads from three workers with `max_outstanding_requests=1`.

    The output is expected to contain the range once per worker, in any
    order.
    """
    worker_count = 3
    element_count = 10
    service = data_service_test_base.TestCluster(num_workers=worker_count)
    distributed = self.make_distributed_range_dataset(
        element_count, service, max_outstanding_requests=1)
    expected = worker_count * list(range(element_count))
    self.assertDatasetProduces(
        distributed, expected, assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testInsideFunction(self):
    """Iterates a distributed dataset from inside a `def_function.function`."""
    num_workers = 3
    num_elements = 10
    cluster = data_service_test_base.TestCluster(num_workers=num_workers)

    @def_function.function
    def collect():
      # Every element is written into a TensorArray that is stacked at the end.
      dataset = self.make_distributed_range_dataset(num_elements, cluster)
      collected = tensor_array_ops.TensorArray(
          dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
      index = 0
      for element in dataset:
        collected = collected.write(index, element)
        index += 1
      return collected.stack()

    produced = list(collect().numpy())
    expected = num_workers * list(range(num_elements))
    self.assertCountEqual(expected, produced)

  @combinations.generate(test_base.default_test_combinations())
  def testSharedJobName(self):
    """Expects two datasets with one job name to split a single epoch.

    The union of what both iterators produce must be exactly the range, with
    no element seen twice.
    """
    num_elements = 1000
    cluster = data_service_test_base.TestCluster(num_workers=1)

    def make_ds():
      return dataset_ops.Dataset.range(num_elements).shuffle(num_elements)

    shared = [
        self.make_distributed_dataset(make_ds(), cluster, job_name="job_name")
        for _ in range(2)
    ]
    get_next_1, get_next_2 = [self.getNext(ds) for ds in shared]

    # Alternate between the two iterators for a while, then drain each one.
    results = []
    for _ in range(num_elements // 5):
      results.extend(
          [self.evaluate(get_next_1()), self.evaluate(get_next_2())])
    results.extend(self.getIteratorOutput(get_next_1))
    results.extend(self.getIteratorOutput(get_next_2))
    self.assertCountEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.default_test_combinations())
  def testDifferentJobNames(self):
    """Expects datasets with different job names to each yield the range."""
    num_elements = 10
    cluster = data_service_test_base.TestCluster(num_workers=1)
    # Both datasets are built before either one is consumed.
    datasets = [
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name=name)
        for name in ("job_name1", "job_name2")
    ]
    for ds in datasets:
      self.assertDatasetProduces(ds, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobNameMultiIteration(self):
    """Checks that each iteration of a shared job is consumed only once.

    Whichever dataset is read first in a round is expected to receive every
    element, leaving nothing for the other dataset in that round.
    """
    num_elements = 10
    cluster = data_service_test_base.TestCluster(num_workers=1)
    first = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    second = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    # Round 1: `first` drains the job, so `second` sees no elements.
    self.assertDatasetProduces(first, list(range(num_elements)))
    self.assertDatasetProduces(second, [])
    # Round 2: the roles are swapped.
    self.assertDatasetProduces(second, list(range(num_elements)))
    self.assertDatasetProduces(first, [])

  @combinations.generate(test_base.default_test_combinations())
  def testSharedJobNameRepeat(self):
    """Expects repeated datasets sharing a job name to split every epoch."""
    num_elements = 100
    num_repetitions = 3
    cluster = data_service_test_base.TestCluster(num_workers=1)
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name").repeat(num_repetitions)
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name").repeat(num_repetitions)
    get_next_1 = self.getNext(ds1)
    get_next_2 = self.getNext(ds2)

    # Read a fifth of the total from each iterator in turn, then drain both.
    partial_reads = (num_elements * num_repetitions) // 5
    results = [self.evaluate(get_next_1()) for _ in range(partial_reads)]
    results.extend(self.evaluate(get_next_2()) for _ in range(partial_reads))
    results.extend(self.getIteratorOutput(get_next_1))
    results.extend(self.getIteratorOutput(get_next_2))
    self.assertCountEqual(num_repetitions * list(range(num_elements)), results)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(job_name=[None, "test"])))
  def testGcUnusedJob(self, job_name):
    """Checks that a job with no remaining iterator is garbage collected.

    Args:
      job_name: Job name passed to the distributed dataset; the test
        combinations supply `None` and `"test"`.
    """
    cluster = data_service_test_base.TestCluster(
        num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
    num_elements = 100
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, job_name=job_name)
    it = iter(ds)
    self.assertEqual(next(it).numpy(), 0)
    self.assertEqual(cluster.workers[0].num_tasks(), 1)
    # Dropping the only iterator leaves the job without any client.
    del it
    # Poll until the worker drops the task, but give up after a generous
    # deadline so that a GC regression fails the test instead of hanging it.
    deadline = time.monotonic() + 60
    while cluster.workers[0].num_tasks() > 0:
      self.assertLess(
          time.monotonic(), deadline,
          "Timed out waiting for the unused job to be garbage collected.")
      time.sleep(0.1)

  @combinations.generate(test_base.eager_only_combinations())
  def testDontGcUsedJob(self):
    """Checks that only jobs without a live iterator are garbage collected."""
    cluster = data_service_test_base.TestCluster(
        num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
    num_elements = 10
    it1 = iter(
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test1"))
    it2 = iter(
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test2"))
    it3 = iter(  # this iterator keeps the task alive. pylint: disable=unused-variable
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test2"))
    self.assertEqual(cluster.workers[0].num_tasks(), 2)
    del it1
    del it2
    # Check that only the first job is gced. The second job will not be gced
    # because there is still an outstanding iterator for it.
    # The wait is bounded so that a GC regression fails the test instead of
    # hanging it.
    deadline = time.monotonic() + 60
    while cluster.workers[0].num_tasks() > 1:
      self.assertLess(
          time.monotonic(), deadline,
          "Timed out waiting for the unused job to be garbage collected.")
      time.sleep(0.1)
    self.assertEqual(cluster.workers[0].num_tasks(), 1)

  @combinations.generate(test_base.eager_only_combinations())
  def testGcErrorMessage(self):
    """Checks the error raised when reading a job that was garbage collected."""
    cluster = data_service_test_base.TestCluster(
        num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
    num_elements = 100
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="test")
    it = iter(ds)
    self.assertEqual(next(it).numpy(), 0)
    self.assertEqual(cluster.workers[0].num_tasks(), 1)
    del it
    # Wait for the worker to drop the task. The wait is bounded so that a GC
    # regression fails the test instead of hanging it.
    deadline = time.monotonic() + 60
    while cluster.workers[0].num_tasks() > 0:
      self.assertLess(
          time.monotonic(), deadline,
          "Timed out waiting for the unused job to be garbage collected.")
      time.sleep(0.1)

    # Reusing the same job name after collection is expected to fail.
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="test")
    with self.assertRaisesRegex(
        errors.FailedPreconditionError,
        "The requested job has been garbage collected due to inactivity"):
      list(ds)

  @combinations.generate(test_base.default_test_combinations())
  def testApplyDeterminismOption(self):
    """Runs `checkDeterminism` on a distributed non-deterministic interleave.

    The dataset sets `experimental_deterministic = False` before it is
    distributed, and `checkDeterminism` is called with
    `expect_determinism=False`.
    """
    elements = list(range(10))
    cluster = data_service_test_base.TestCluster(num_workers=1)

    def dataset_fn(delay_ms):
      # Builds the dataset under test; `delay_ms` is applied to element 0 only.

      def interleave_fn(x):
        ds = dataset_ops.Dataset.from_tensors(x)
        # Element 0 sleeps for `delay_ms * 1000`; every other element sleeps
        # for 0. Presumably `testing.sleep` takes microseconds -- TODO confirm.
        if math_ops.equal(x, 0):
          ds = ds.apply(testing.sleep(delay_ms * 1000))
        else:
          ds = ds.apply(testing.sleep(0))
        return ds

      ds = dataset_ops.Dataset.from_tensor_slices(elements)
      ds = ds.interleave(interleave_fn, cycle_length=10, num_parallel_calls=10)
      opts = dataset_ops.Options()
      opts.experimental_deterministic = False
      ds = ds.with_options(opts)
      # The options are attached before the dataset is handed to the service.
      ds = self.make_distributed_dataset(ds, cluster)
      return ds

    self.checkDeterminism(
        dataset_fn=dataset_fn,
        expect_determinism=False,
        expected_elements=elements)

  def run_stateful(self, external_state_policy):
    """Reads a dataset with a random-op map through a three-worker cluster.

    Args:
      external_state_policy: Value assigned to
        `experimental_external_state_policy` on the dataset options.
    """
    num_elements = 10
    stateful = dataset_ops.Dataset.range(num_elements).map(
        lambda _: random_ops.random_uniform(()))

    policy_options = dataset_ops.Options()
    policy_options.experimental_external_state_policy = external_state_policy
    stateful = stateful.with_options(policy_options)

    cluster = data_service_test_base.TestCluster(num_workers=3)
    distributed = self.make_distributed_dataset(stateful, cluster)
    self.getDatasetOutput(distributed)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(external_state_policy=[
              distribute_options.ExternalStatePolicy.IGNORE,
              distribute_options.ExternalStatePolicy.WARN
          ])))
  def testStatefulNoError(self, external_state_policy):
    """Expects `run_stateful` to complete without raising.

    Args:
      external_state_policy: Policy under test; the test combinations supply
        `IGNORE` and `WARN`.
    """
    self.run_stateful(external_state_policy=external_state_policy)

  @combinations.generate(test_base.default_test_combinations())
  def testStatefulError(self):
    """Expects `FailedPreconditionError` under the `FAIL` state policy."""
    failing_policy = distribute_options.ExternalStatePolicy.FAIL
    with self.assertRaises(errors.FailedPreconditionError):
      self.run_stateful(failing_policy)

  @combinations.generate(test_base.default_test_combinations())
  def testDistributeFromInterleave(self):
    """Distributes a dataset from inside an `interleave` map function."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    outer = dataset_ops.Dataset.range(2)

    def distributed_range(_):
      # The outer element is ignored; each call distributes a fresh range(2).
      inner = dataset_ops.Dataset.range(2)
      return self.make_distributed_dataset(inner, cluster)

    interleaved = outer.interleave(distributed_range, cycle_length=2)
    self.assertDatasetProduces(interleaved, [0, 0, 1, 1])

  @combinations.generate(test_base.default_test_combinations())
  def testDistributeNonStringAddresses(self):
    """Expects a `ValueError` when `service` is not a string."""
    dataset = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(ValueError, "service must be a string"):
      # The result is discarded: the call itself is expected to raise.
      dataset.apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs", service=1))

  @combinations.generate(test_base.default_test_combinations())
  def testDistributeEmptyAddress(self):
    """Expects a `ValueError` when `service` is the empty string."""
    dataset = dataset_ops.Dataset.range(10)
    with self.assertRaisesWithLiteralMatch(ValueError,
                                           "service must not be empty"):
      # The result is discarded: the call itself is expected to raise.
      dataset.apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs", service=""))

  @combinations.generate(test_base.default_test_combinations())
  def testDistributeExplicitProtocol(self):
    """Connects with a dispatcher address carrying a `grpc://` prefix."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    source = dataset_ops.Dataset.range(10)
    address = "grpc://" + cluster.dispatcher_address()
    distribute_fn = data_service_ops.distribute(
        processing_mode="parallel_epochs", service=address)
    self.assertDatasetProduces(source.apply(distribute_fn), list(range(10)))

  @combinations.generate(test_base.default_test_combinations())
  def testDistributeInvalidProtocol(self):
    """Expects `NotFoundError` for the unregistered `grp://` protocol."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    source = dataset_ops.Dataset.range(10)
    # Both building and reading the dataset happen inside the assertion
    # context, so the error may surface at either step.
    with self.assertRaisesRegex(
        errors.NotFoundError,
        "No credentials factory has been registered for protocol grp"):
      distributed = source.apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs",
              service="grp://" + cluster.dispatcher_address()))
      self.getDatasetOutput(distributed)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeInvalidProcessingMode(self):
    """Expects a `ValueError` for an unknown `processing_mode`."""
    dataset = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(ValueError,
                                "invalid is not a valid processing mode"):
      # The result is discarded: the call itself is expected to raise.
      dataset.apply(
          data_service_ops.distribute(
              processing_mode="invalid", service="grpc://localhost:5000"))

  @combinations.generate(test_base.default_test_combinations())
  def testZipDifferentProcessingModesDatasets(self):
    """Zips a `distributed_epoch` dataset with a `parallel_epochs` dataset."""
    num_elements = 100
    cluster = data_service_test_base.TestCluster(num_workers=1)
    by_mode = [
        self.make_distributed_dataset(
            dataset_ops.Dataset.range(num_elements),
            cluster,
            processing_mode=mode)
        for mode in ("distributed_epoch", "parallel_epochs")
    ]
    zipped = dataset_ops.Dataset.zip(tuple(by_mode))
    expected_pairs = list(zip(range(num_elements), range(num_elements)))
    self.assertDatasetProduces(
        zipped, expected_pairs, assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testZipDifferentProcessingModesDatasetsSharedJobName(self):
    """Expects an error when one job name is used with two processing modes."""
    num_elements = 100
    cluster = data_service_test_base.TestCluster(num_workers=1)
    by_mode = [
        self.make_distributed_dataset(
            dataset_ops.Dataset.range(num_elements),
            cluster,
            processing_mode=mode,
            job_name="job_name")
        for mode in ("distributed_epoch", "parallel_epochs")
    ]
    zipped = dataset_ops.Dataset.zip(tuple(by_mode))
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "but there is already an existing job"):
      self.getDatasetOutput(zipped)

  @combinations.generate(test_base.default_test_combinations())
  def testFromDatasetId(self):
    """Registers a dataset and reads it back through `from_dataset_id`."""
    num_elements = 10
    cluster = data_service_test_base.TestCluster(num_workers=1)

    source = dataset_ops.Dataset.range(num_elements)
    registered_id = data_service_ops.register_dataset(
        cluster.dispatcher_address(), source)
    restored = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.dispatcher_address(), registered_id,
        source.element_spec)
    self.assertDatasetProduces(restored, list(range(num_elements)))

  @combinations.generate(test_base.default_test_combinations())
  def testRegisteringDatasetAsTfFunction(self):
    """Calls `register_dataset` wrapped in a `def_function.function`."""
    num_elements = 10
    cluster = data_service_test_base.TestCluster(num_workers=1)

    source = dataset_ops.Dataset.range(num_elements)
    traced_register = def_function.function(data_service_ops.register_dataset)
    # The service is passed as a pair of constant tensors: "grpc" and the
    # dispatcher address.
    service = (constant_op.constant("grpc"),
               constant_op.constant(cluster.dispatcher_address()))
    registered_id = traced_register(service, source)
    restored = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.dispatcher_address(), registered_id,
        source.element_spec)
    self.assertDatasetProduces(restored, list(range(num_elements)))

  @combinations.generate(test_base.default_test_combinations())
  def testFromDatasetIdMultipleComponents(self):
    """Round-trips a nested-structure dataset through `from_dataset_id`."""
    num_elements = 10
    cluster = data_service_test_base.TestCluster(num_workers=1)

    base = dataset_ops.Dataset.range(num_elements)
    nested = dataset_ops.Dataset.zip({"a": (base, base), "b": base})
    registered_id = data_service_ops.register_dataset(
        cluster.dispatcher_address(), nested)
    restored = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.dispatcher_address(), registered_id,
        nested.element_spec)
    output = self.getDatasetOutput(restored)
    # Every component of element `i` is expected to carry the value `i`.
    for i in range(num_elements):
      element = output[i]
      self.assertEqual(i, element["a"][0])
      self.assertEqual(i, element["a"][1])
      self.assertEqual(i, element["b"])

  @combinations.generate(test_base.default_test_combinations())
  def testFromDatasetIdWrongElementSpec(self):
    """Expects an error when `from_dataset_id` gets a mismatched spec."""
    num_elements = 10
    cluster = data_service_test_base.TestCluster(num_workers=1)

    source = dataset_ops.Dataset.range(num_elements)
    registered_id = data_service_ops.register_dataset(
        cluster.dispatcher_address(), source)
    # The registered dataset is a range, but a variant spec is declared.
    wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
    mismatched = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.dispatcher_address(), registered_id,
        wrong_spec)
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "Expected a tensor of type variant"):
      get_next = self.getNext(mismatched)
      self.evaluate(get_next())

  @combinations.generate(test_base.default_test_combinations())
  def testFromDatasetIdNotRegistered(self):
    """Expects `NotFoundError` when reading an id that was never registered."""
    cluster = data_service_test_base.TestCluster(num_workers=1)

    # Nothing is registered with this cluster before id 0 is requested.
    unregistered_id = 0
    spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
    missing = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.dispatcher_address(), unregistered_id, spec)
    with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"):
      get_next = self.getNext(missing)
      self.evaluate(get_next())

  @combinations.generate(test_base.default_test_combinations())
  def testCancellation(self):
    """Checks that a pending slow prefetch does not block iterator cleanup.

    Currently skipped unconditionally; see b/162521601.
    """
    self.skipTest("b/162521601")
    sleep_microseconds = 1000 * int(1e6)

    cluster = data_service_test_base.TestCluster(num_workers=1)
    # The first element is produced quickly and the second one slowly.
    # Fetching the first element triggers prefetching of the second, which
    # should be cancellable.
    slow_tail = dataset_ops.Dataset.range(1).apply(
        testing.sleep(sleep_microseconds))
    combined = dataset_ops.Dataset.range(1).concatenate(slow_tail)
    combined = self.make_distributed_dataset(combined, cluster).prefetch(1)
    get_next = self.getNext(combined)
    self.assertEqual(0, self.evaluate(get_next()))
    # Without properly implemented cancellation, the test hangs here while the
    # dataset iterator is garbage collected.

  @combinations.generate(test_base.default_test_combinations())
  def testRegisterEquivalentDatasets(self):
    """Expects two identical range datasets to register under the same id."""
    first = dataset_ops.Dataset.range(10)
    second = dataset_ops.Dataset.range(10)
    cluster = data_service_test_base.TestCluster(num_workers=1)
    first_id = data_service_ops.register_dataset(
        cluster.dispatcher_address(), first)
    second_id = data_service_ops.register_dataset(
        cluster.dispatcher_address(), second)
    self.assertEqual(self.evaluate(first_id), self.evaluate(second_id))

  @combinations.generate(test_base.default_test_combinations())
  def testRegisterDifferentDatasets(self):
    """Expects range datasets of different sizes to get different ids."""
    shorter = dataset_ops.Dataset.range(10)
    longer = dataset_ops.Dataset.range(20)
    cluster = data_service_test_base.TestCluster(num_workers=1)
    shorter_id = data_service_ops.register_dataset(
        cluster.dispatcher_address(), shorter)
    longer_id = data_service_ops.register_dataset(
        cluster.dispatcher_address(), longer)
    self.assertNotEqual(self.evaluate(shorter_id), self.evaluate(longer_id))

  @combinations.generate(test_base.default_test_combinations())
  def testTwoLevelDistribute(self):
    """Distributes, groups by string length, then distributes again.

    The first cluster has three workers, so each group of equal-length strings
    is expected to show up as three identical consecutive batches.
    """
    cluster_1_size = 3
    cluster_1 = data_service_test_base.TestCluster(num_workers=cluster_1_size)
    cluster_2 = data_service_test_base.TestCluster(num_workers=1)
    num_sizes = 10
    size_repeats = 5
    strings = size_repeats * ["a" * length for length in range(num_sizes)]
    ds = dataset_ops.Dataset.from_tensor_slices(strings)
    ds = ds.shuffle(len(strings))
    ds = self.make_distributed_dataset(ds, cluster_1)

    def key_func(x):
      # Group key: the string length as int64.
      return math_ops.cast(string_ops.string_length_v2(x), dtypes.int64)

    batch_size = size_repeats
    # Large enough so that all strings of the same size are windowed together.
    window_size = cluster_1_size * size_repeats
    group_fn = grouping.group_by_window(
        key_func=key_func,
        reduce_func=lambda _, x: x.batch(batch_size),
        window_size=window_size)
    ds = self.make_distributed_dataset(ds.apply(group_fn), cluster_2)

    get_next = self.getNext(ds)
    for _ in range(num_sizes):
      first_batch = self.evaluate(get_next())
      # The following `cluster_1_size - 1` batches must equal the first one.
      for _ in range(cluster_1_size - 1):
        self.assertAllEqual(self.evaluate(get_next()), first_batch)
    self.assertEmpty(self.getIteratorOutput(get_next))

  @combinations.generate(
      combinations.times(test_base.default_test_combinations()))
  def testDistributeLargeGraph(self):
    """Distributes a dataset that embeds a large constant tensor."""
    cluster = data_service_test_base.TestCluster(
        num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
    # Larger than default OSS grpc message size limit of 4MB.
    big_tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
    source = dataset_ops.Dataset.from_tensors(big_tensor)
    distributed = self.make_distributed_dataset(source, cluster)
    self.assertDatasetProduces(distributed, [big_tensor])

  @combinations.generate(
      combinations.times(test_base.graph_only_combinations(),
                         combinations.combine(use_resource=False)) +
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(use_resource=True)))
  def testVariables(self, use_resource):
    """Distributes a dataset whose map function reads a variable.

    Args:
      use_resource: If true, creates a `variables.Variable`; otherwise creates
        a `variables.VariableV1` inside a `use_resource=False` variable scope.
    """
    cluster = data_service_test_base.TestCluster(num_workers=1)
    if use_resource:
      v = variables.Variable(10, dtype=dtypes.int64)
    else:
      with variable_scope.variable_scope("foo", use_resource=False):
        v = variables.VariableV1(10, dtype=dtypes.int64)

    shifted = dataset_ops.Dataset.range(3).map(lambda x: x + v)
    distributed = self.make_distributed_dataset(shifted, cluster)
    self.evaluate(v.initializer)
    # range(3) shifted by the variable's value of 10.
    self.assertDatasetProduces(
        distributed, list(range(10, 13)), requires_initialization=True)
示例#15
0
class StructureTest(test_base.DatasetTestBase, parameterized.TestCase):

    # pylint: disable=g-long-lambda,protected-access
    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           _test_flat_structure_combinations()))
    def testFlatStructure(self, value_fn, expected_structure_fn,
                          expected_types_fn, expected_shapes_fn):
        """Checks the spec class, flat dtypes and flat shapes of a value.

        Args:
          value_fn: Zero-argument callable returning the value under test.
          expected_structure_fn: Callable returning the class that the
            value's type spec must be an instance of.
          expected_types_fn: Callable returning the expected flat tensor
            types.
          expected_shapes_fn: Callable returning the expected flat shapes as
            lists; a `None` entry means the shape's `ndims` must be `None`.
        """
        value = value_fn()
        expected_structure = expected_structure_fn()
        expected_types = expected_types_fn()
        expected_shapes = expected_shapes_fn()

        spec = structure.type_spec_from_value(value)
        self.assertIsInstance(spec, expected_structure)
        self.assertEqual(expected_types,
                         structure.get_flat_tensor_types(spec))

        flat_shapes = structure.get_flat_tensor_shapes(spec)
        self.assertLen(flat_shapes, len(expected_shapes))
        for want, got in zip(expected_shapes, flat_shapes):
            if want is not None:
                self.assertEqual(got.as_list(), want)
            else:
                self.assertEqual(got.ndims, None)

    @combinations.generate(
        combinations.times(test_base.graph_only_combinations(),
                           _test_is_compatible_with_structure_combinations()))
    def testIsCompatibleWithStructure(self, original_value_fn,
                                      compatible_values_fn,
                                      incompatible_values_fn):
        """Checks `structure.are_compatible` against known value lists.

        Args:
          original_value_fn: Callable returning the reference value.
          compatible_values_fn: Callable returning values whose specs must be
            compatible with the reference spec.
          incompatible_values_fn: Callable returning values whose specs must
            not be compatible with the reference spec.
        """
        original_value = original_value_fn()
        compatible_values = compatible_values_fn()
        incompatible_values = incompatible_values_fn()

        reference_spec = structure.type_spec_from_value(original_value)
        for candidate in compatible_values:
            candidate_spec = structure.type_spec_from_value(candidate)
            self.assertTrue(
                structure.are_compatible(reference_spec, candidate_spec))
        for candidate in incompatible_values:
            candidate_spec = structure.type_spec_from_value(candidate)
            self.assertFalse(
                structure.are_compatible(reference_spec, candidate_spec))

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           _test_structure_from_value_equality_combinations()))
    def testStructureFromValueEquality(self, value1_fn, value2_fn,
                                       not_equal_value_fns):
        """Checks `==`, `!=` and `hash` on type specs built from values.

        Args:
          value1_fn: Callable returning the first value.
          value2_fn: Callable returning a value whose spec must equal the
            spec of the first value.
          not_equal_value_fns: Wrapper whose `_obj` attribute holds callables
            returning values whose specs must differ from both.
        """
        # pylint: disable=g-generic-assert
        unequal_value_fns = not_equal_value_fns._obj
        s1 = structure.type_spec_from_value(value1_fn())
        s2 = structure.type_spec_from_value(value2_fn())
        # Equal specs: exercise __eq__ first, then __ne__.
        self.assertEqual(s1, s1)
        self.assertEqual(s1, s2)
        self.assertFalse(s1 != s1)
        self.assertFalse(s1 != s2)
        # Equal specs must also hash equally, component by component.
        for c1, c2 in zip(nest.flatten(s1), nest.flatten(s2)):
            self.assertEqual(hash(c1), hash(c1))
            self.assertEqual(hash(c1), hash(c2))
        for unequal_fn in unequal_value_fns:
            s3 = structure.type_spec_from_value(unequal_fn())
            # Unequal specs: exercise __ne__ first, then __eq__.
            self.assertNotEqual(s1, s3)
            self.assertNotEqual(s2, s3)
            self.assertFalse(s1 == s3)
            self.assertFalse(s2 == s3)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           _test_ragged_structure_inequality_combinations()))
    def testRaggedStructureInequality(self, spec1, spec2):
        """Checks that the two given specs compare unequal.

        Args:
          spec1: First spec.
          spec2: Second spec, expected to differ from `spec1`.
        """
        # pylint: disable=g-generic-assert
        # Exercise __ne__ through assertNotEqual and __eq__ directly.
        self.assertNotEqual(spec1, spec2)
        self.assertFalse(spec1 == spec2)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           _test_hash_combinations()))
    def testHash(self, value1_fn, value2_fn, value3_fn):
        """Checks component hashes of type specs built from three values.

        Args:
          value1_fn: Callable returning the first value.
          value2_fn: Callable returning a value whose spec components must
            hash like those of the first value.
          value3_fn: Callable returning a value whose spec components must
            hash differently from the other two.
        """
        s1, s2, s3 = (structure.type_spec_from_value(value_fn())
                      for value_fn in (value1_fn, value2_fn, value3_fn))
        components = zip(nest.flatten(s1), nest.flatten(s2),
                         nest.flatten(s3))
        for c1, c2, c3 in components:
            self.assertEqual(hash(c1), hash(c1))
            self.assertEqual(hash(c1), hash(c2))
            self.assertNotEqual(hash(c1), hash(c3))
            self.assertNotEqual(hash(c2), hash(c3))

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           _test_round_trip_conversion_combinations()))
    def testRoundTripConversion(self, value_fn):
        """Checks that `to_tensor_list` then `from_tensor_list` keeps values.

        Args:
          value_fn: Zero-argument callable returning the value to round-trip.
        """
        value = value_fn()
        spec = structure.type_spec_from_value(value)

        def maybe_stack_ta(v):
            # TensorArrays are stacked before evaluation; others pass through.
            is_tensor_array = isinstance(v, tensor_array_ops.TensorArray)
            return v.stack() if is_tensor_array else v

        before = self.evaluate(maybe_stack_ta(value))
        round_tripped = structure.from_tensor_list(
            spec, structure.to_tensor_list(spec, value))
        after = self.evaluate(maybe_stack_ta(round_tripped))

        for b, a in zip(nest.flatten(before), nest.flatten(after)):
            if isinstance(b, sparse_tensor.SparseTensorValue):
                # Sparse values are compared field by field.
                self.assertAllEqual(b.indices, a.indices)
                self.assertAllEqual(b.values, a.values)
                self.assertAllEqual(b.dense_shape, a.dense_shape)
            else:
                # Ragged and all remaining values are compared directly.
                self.assertAllEqual(b, a)

    # pylint: enable=g-long-lambda

    def preserveStaticShape(self):
        """Checks static shapes after a tensor-list round trip.

        Covers a ragged tensor and a sparse tensor.

        NOTE(review): the name lacks the `test` prefix and there is no
        `combinations.generate` decorator, so standard test discovery
        presumably never runs this method -- confirm whether that is
        intentional.
        """
        ragged = ragged_factory_ops.constant([[1, 2], [], [3]])
        ragged_spec = structure.type_spec_from_value(ragged)
        ragged_after = structure.from_tensor_list(
            ragged_spec, structure.to_tensor_list(ragged_spec, ragged))
        self.assertEqual(ragged_after.row_splits.shape.as_list(),
                         ragged.row_splits.shape.as_list())
        self.assertEqual(ragged_after.values.shape.as_list(), [None])

        sparse = sparse_tensor.SparseTensor(indices=[[3, 4]],
                                            values=[-1],
                                            dense_shape=[4, 5])
        sparse_spec = structure.type_spec_from_value(sparse)
        sparse_after = structure.from_tensor_list(
            sparse_spec, structure.to_tensor_list(sparse_spec, sparse))
        self.assertEqual(sparse_after.indices.shape.as_list(), [None, 2])
        self.assertEqual(sparse_after.values.shape.as_list(), [None])
        self.assertEqual(sparse_after.dense_shape.shape.as_list(),
                         sparse.dense_shape.shape.as_list())

    @combinations.generate(test_base.default_test_combinations())
    def testPreserveTensorArrayShape(self):
        """Checks that an explicit `element_shape` survives a round trip."""
        tensor_array = tensor_array_ops.TensorArray(
            dtype=dtypes.int32, size=1, element_shape=(3, ))
        spec = structure.type_spec_from_value(tensor_array)
        flat = structure.to_tensor_list(spec, tensor_array)
        restored = structure.from_tensor_list(spec, flat)
        self.assertEqual(restored.element_shape.as_list(), [3])

    @combinations.generate(test_base.default_test_combinations())
    def testPreserveInferredTensorArrayShape(self):
        """Checks that an inferred `element_shape` survives a round trip."""
        # No element_shape is given; it is inferred from the write below.
        tensor_array = tensor_array_ops.TensorArray(dtype=dtypes.int32,
                                                    size=1)
        tensor_array = tensor_array.write(0, [1, 2, 3])
        spec = structure.type_spec_from_value(tensor_array)
        flat = structure.to_tensor_list(spec, tensor_array)
        restored = structure.from_tensor_list(spec, flat)
        self.assertEqual(restored.element_shape.as_list(), [3])

    @combinations.generate(test_base.default_test_combinations())
    def testIncompatibleStructure(self):
        """Checks errors raised when a spec meets an incompatible value.

        Three mutually incompatible values are defined: a dense tensor, a
        sparse tensor and a dict of tensors. The test asserts that:
        1. Flattening a value with the spec of an incompatible value fails.
        2. Restructuring a flat tensor list with the spec of an incompatible
           value fails.
        """
        nested_mismatch = (
            "The two structures don't have the same nested structure.")
        got_two_expected_one = "Expected 1 tensors but got 2."
        got_one_expected_two = "Expected 2 tensors but got 1."

        dense_value = constant_op.constant(42.0)
        dense_spec = structure.type_spec_from_value(dense_value)
        dense_flat = structure.to_tensor_list(dense_spec, dense_value)

        sparse_value = sparse_tensor.SparseTensor(indices=[[0, 0]],
                                                  values=[1],
                                                  dense_shape=[1, 1])
        sparse_spec = structure.type_spec_from_value(sparse_value)
        sparse_flat = structure.to_tensor_list(sparse_spec, sparse_value)

        nest_value = {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }
        nest_spec = structure.type_spec_from_value(nest_value)
        nest_flat = structure.to_tensor_list(nest_spec, nest_value)

        # Part 1: flattening a value with a mismatched spec.
        with self.assertRaisesRegex(
                ValueError,
                r"SparseTensor.* is not convertible to a tensor with "
                r"dtype.*float32.* and shape \(\)"):
            structure.to_tensor_list(dense_spec, sparse_value)
        with self.assertRaisesRegex(ValueError, nested_mismatch):
            structure.to_tensor_list(dense_spec, nest_value)

        with self.assertRaisesRegex(
                TypeError, "neither a SparseTensor nor SparseTensorValue"):
            structure.to_tensor_list(sparse_spec, dense_value)

        with self.assertRaisesRegex(ValueError, nested_mismatch):
            structure.to_tensor_list(sparse_spec, nest_value)

        with self.assertRaisesRegex(ValueError, nested_mismatch):
            structure.to_tensor_list(nest_spec, dense_value)

        with self.assertRaisesRegex(ValueError, nested_mismatch):
            structure.to_tensor_list(nest_spec, sparse_value)

        # Part 2: restructuring a flat list with a mismatched spec.
        with self.assertRaisesRegex(
                ValueError,
                "Cannot create a Tensor from the tensor list because item 0 "
                ".*tf.Tensor.* is incompatible with the expected TypeSpec "
                ".*TensorSpec.*"):
            structure.from_tensor_list(dense_spec, sparse_flat)

        with self.assertRaisesRegex(ValueError, got_two_expected_one):
            structure.from_tensor_list(dense_spec, nest_flat)

        with self.assertRaisesRegex(
                ValueError,
                "Cannot create a SparseTensor from the tensor list because "
                "item 0 .*tf.Tensor.* is incompatible with the expected TypeSpec "
                ".*TensorSpec.*"):
            structure.from_tensor_list(sparse_spec, dense_flat)

        with self.assertRaisesRegex(ValueError, got_two_expected_one):
            structure.from_tensor_list(sparse_spec, nest_flat)

        with self.assertRaisesRegex(ValueError, got_one_expected_two):
            structure.from_tensor_list(nest_spec, dense_flat)

        with self.assertRaisesRegex(ValueError, got_one_expected_two):
            structure.from_tensor_list(nest_spec, sparse_flat)

    @combinations.generate(test_base.default_test_combinations())
    def testIncompatibleNestedStructure(self):
        """Verifies that mismatched specs, values and tensor lists are rejected.

        Three mutually incompatible nested values are built together with
        their specs and flat tensor lists. The test then asserts that:
          1. flattening a value with the spec of a different value raises, and
          2. rebuilding from the flat tensor list of a different value raises.
        """
        # This pattern is a fixed string, so it does not depend on the order
        # in which dictionary keys would be printed.
        nested_mismatch = (
            "The two structures don't have the same nested structure.")

        # Baseline: a dict holding two dense tensors.
        dense_value = {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3]),
        }
        dense_spec = structure.type_spec_from_value(dense_value)
        dense_flat = structure.to_tensor_list(dense_spec, dense_value)

        # Same nesting as `dense_value`, but "b" is a SparseTensor, so only
        # the component classes differ.
        sparse_value = {
            "a": constant_op.constant(37.0),
            "b": sparse_tensor.SparseTensor(
                indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
        }
        sparse_spec = structure.type_spec_from_value(sparse_value)
        sparse_flat = structure.to_tensor_list(sparse_spec, sparse_value)

        # "b" is a pair of SparseTensors, so the nesting itself differs from
        # both values above.
        pair_value = {
            "a": constant_op.constant(37.0),
            "b": (
                sparse_tensor.SparseTensor(
                    indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
            ),
        }
        pair_spec = structure.type_spec_from_value(pair_value)
        pair_flat = structure.to_tensor_list(pair_spec, pair_value)

        # (expected exception, message pattern, spec, value to flatten)
        flatten_failures = (
            (ValueError,
             r"SparseTensor.* is not convertible to a tensor with "
             r"dtype.*int32.* and shape \(3,\)",
             dense_spec, sparse_value),
            (ValueError, nested_mismatch, dense_spec, pair_value),
            (TypeError, "neither a SparseTensor nor SparseTensorValue",
             sparse_spec, dense_value),
            (ValueError, nested_mismatch, sparse_spec, pair_value),
            (ValueError, nested_mismatch, pair_spec, dense_value),
            (ValueError, nested_mismatch, pair_spec, sparse_value),
        )
        for error_type, pattern, spec, value in flatten_failures:
            with self.assertRaisesRegex(error_type, pattern):
                structure.to_tensor_list(spec, value)

        # (message pattern, spec, flat tensor list to rebuild from); every
        # case is expected to raise ValueError.
        rebuild_failures = (
            (r"Cannot create a Tensor from the tensor list",
             dense_spec, sparse_flat),
            ("Expected 2 tensors but got 3", dense_spec, pair_flat),
            ("Cannot create a SparseTensor from the tensor list",
             sparse_spec, dense_flat),
            ("Expected 2 tensors but got 3", sparse_spec, pair_flat),
            ("Expected 3 tensors but got 2", pair_spec, dense_flat),
            ("Expected 3 tensors but got 2", pair_spec, sparse_flat),
        )
        for pattern, spec, flat in rebuild_failures:
            with self.assertRaisesRegex(ValueError, pattern):
                structure.from_tensor_list(spec, flat)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           _test_convert_legacy_structure_combinations()))
    def testConvertLegacyStructure(self, output_types, output_shapes,
                                   output_classes, expected_structure):
        """Checks the structure built from legacy types, shapes and classes."""
        self.assertEqual(
            structure.convert_legacy_structure(output_types, output_shapes,
                                               output_classes),
            expected_structure)

    @combinations.generate(test_base.default_test_combinations())
    def testConvertLegacyStructureFail(self):
        """Checks that an unsupported output class raises a `TypeError`."""
        expected_message = (
            "Could not build a structure for output class "
            "_EagerTensorArray. Make sure any component class in "
            "`output_classes` inherits from one of the following classes: "
            "`tf.TypeSpec`, `tf.sparse.SparseTensor`, `tf.Tensor`, "
            "`tf.TensorArray`.")
        with self.assertRaisesRegex(TypeError, expected_message):
            legacy_shape = tensor_shape.TensorShape([2, None])
            unsupported_class = tensor_array_ops._EagerTensorArray
            structure.convert_legacy_structure(dtypes.int32, legacy_shape,
                                               unsupported_class)

    @combinations.generate(test_base.default_test_combinations())
    def testNestedNestedStructure(self):
        """Round-trips a tuple nested inside a tuple through a tensor list."""
        spec = (tensor_spec.TensorSpec([], dtypes.int64),
                (tensor_spec.TensorSpec([], dtypes.float32),
                 tensor_spec.TensorSpec([], dtypes.string)))

        int64_t = constant_op.constant(37, dtype=dtypes.int64)
        float32_t = constant_op.constant(42.0)
        string_t = constant_op.constant("Foo")

        flat = structure.to_tensor_list(spec,
                                        (int64_t, (float32_t, string_t)))

        # Flattening must hand back the very same tensor objects, in order.
        for want, got in zip((int64_t, float32_t, string_t), flat):
            self.assertIs(want, got)

        # Both reconstruction entry points must restore the nesting and
        # return the original tensor objects.
        for rebuild in (structure.from_tensor_list,
                        structure.from_compatible_tensor_list):
            got_int64, (got_float32, got_string) = rebuild(spec, flat)
            self.assertIs(int64_t, got_int64)
            self.assertIs(float32_t, got_float32)
            self.assertIs(string_t, got_string)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           _test_batch_combinations()))
    def testBatch(self, element_structure, batch_size,
                  expected_batched_structure):
        """Checks that batching every component spec gives the expected spec."""

        def batch_component(component_spec):
            return component_spec._batch(batch_size)

        actual = nest.map_structure(batch_component, element_structure)
        self.assertEqual(actual, expected_batched_structure)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           _test_unbatch_combinations()))
    def testUnbatch(self, element_structure, expected_unbatched_structure):
        """Checks that unbatching every component spec gives the expected spec."""

        def unbatch_component(component_spec):
            return component_spec._unbatch()

        actual = nest.map_structure(unbatch_component, element_structure)
        self.assertEqual(actual, expected_unbatched_structure)

    # pylint: disable=g-long-lambda
    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           _test_to_batched_tensor_list_combinations()))
    def testToBatchedTensorList(self, value_fn, element_0_fn):
        """Checks batched flattening and recovery of the first element.

        Args:
          value_fn: Callable returning the batched value under test.
          element_0_fn: Callable returning the expected 0th element of that
            batched value.
        """
        value = value_fn()
        spec = structure.type_spec_from_value(value)
        flat_batched = structure.to_batched_tensor_list(spec, value)

        # The batch dimension is 2 for all of the test cases.
        # NOTE(mrry): `tf.shape()` does not currently work for the DT_VARIANT
        # tensors in which we store sparse tensors, so those are skipped.
        for tensor in flat_batched:
            if tensor.dtype == dtypes.variant:
                continue
            leading_dim = self.evaluate(array_ops.shape(tensor)[0])
            self.assertEqual(2, leading_dim)

        # Take row 0 of every flat tensor, rebuild it with the unbatched spec
        # and compare the result against the expected 0th element.
        want_element_0 = self.evaluate(element_0_fn())

        def unbatch_component(component_spec):
            return component_spec._unbatch()

        element_spec = nest.map_structure(unbatch_component, spec)
        first_rows = [tensor[0] for tensor in flat_batched]
        got_element_0 = structure.from_tensor_list(element_spec, first_rows)

        for want, got in zip(nest.flatten(want_element_0),
                             nest.flatten(got_element_0)):
            self.assertValuesEqual(want, got)

    # pylint: enable=g-long-lambda

    @combinations.generate(test_base.default_test_combinations())
    def testDatasetSpecConstructor(self):
        """Checks that `DatasetSpec` stores its element spec and dataset shape."""
        component_specs = {
            "rt": ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32),
            "st": sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32),
            "t": tensor_spec.TensorSpec([10, 8], dtypes.string),
        }
        dataset_spec = dataset_ops.DatasetSpec(component_specs, [5])
        self.assertEqual(dataset_spec._element_spec, component_specs)
        # The shape was passed in as a plain list; the stored value is
        # compared against the equivalent `TensorShape`.
        expected_shape = tensor_shape.TensorShape([5])
        self.assertEqual(dataset_spec._dataset_shape, expected_shape)

    @combinations.generate(test_base.default_test_combinations())
    def testCustomMapping(self):
        """Checks that a `CustomMap` value yields a `CustomMap` of specs."""
        mapping = CustomMap(foo=constant_op.constant(37.))
        mapping_spec = structure.type_spec_from_value(mapping)
        self.assertIsInstance(mapping_spec, CustomMap)
        self.assertEqual(
            mapping_spec["foo"],
            tensor_spec.TensorSpec([], dtypes.float32))

    @combinations.generate(test_base.default_test_combinations())
    def testObjectProxy(self):
        """Checks that a proxied namedtuple gets the same spec as a bare one."""
        pair_type = collections.namedtuple("A", ["x", "y"])
        spec_via_proxy = structure.type_spec_from_value(
            wrapt.ObjectProxy(pair_type(1, 2)))
        spec_direct = structure.type_spec_from_value(pair_type(1, 2))
        self.assertEqual(spec_direct, spec_via_proxy)

    @combinations.generate(test_base.default_test_combinations())
    def testTypeSpecNotBuild(self):
        """Checks that a plain int is rejected when the fallback is disabled."""
        expected_message = "Could not build a `TypeSpec` for 100 with type int"
        with self.assertRaisesRegex(TypeError, expected_message):
            structure.type_spec_from_value(100, use_fallback=False)

    @combinations.generate(test_base.default_test_combinations())
    def testTypeSpecNotCompatible(self):
        """Checks `NoneTensorSpec` compatibility with a non-spec and itself."""
        none_spec = structure.NoneTensorSpec()
        expected_message = (
            r"No `TypeSpec` is compatible with both NoneTensorSpec\(\) "
            "and 100")
        # A non-spec argument must be rejected.
        with self.assertRaisesRegex(ValueError, expected_message):
            none_spec.most_specific_compatible_shape(100)
        # Merging the spec with itself must give back an equal spec.
        merged = none_spec.most_specific_compatible_shape(none_spec)
        self.assertEqual(none_spec, merged)
示例#16
0
class IOTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests saving datasets with `io.save` and reading them with `io.load`."""

  def setUp(self):
    """Creates the scratch directories used by the tests."""
    super(IOTest, self).setUp()
    self._test_dir = os.path.join(self.get_temp_dir(), "io_test")
    os.mkdir(self._test_dir)

    self._checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    os.mkdir(self._checkpoint_prefix)
    self._save_dir = os.path.join(self.get_temp_dir(), "save")
    os.mkdir(self._save_dir)

  def tearDown(self):
    """Removes every directory created in `setUp`."""
    super(IOTest, self).tearDown()
    for scratch_dir in (self._test_dir, self._checkpoint_prefix,
                        self._save_dir):
      shutil.rmtree(scratch_dir)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(compression=[None, "GZIP"])))
  def testBasic(self, compression):
    """Saves a range dataset and checks the loaded copy, per compression."""
    original = dataset_ops.Dataset.range(42)
    io.save(original, self._test_dir, compression=compression)
    restored = io.load(
        self._test_dir, original.element_spec, compression=compression)
    self.assertDatasetProduces(restored, range(42))

  @combinations.generate(test_base.eager_only_combinations())
  def testCardinality(self):
    """Checks that a loaded dataset reports the saved number of elements."""
    original = dataset_ops.Dataset.range(42)
    io.save(original, self._test_dir)
    restored = io.load(self._test_dir, original.element_spec)
    cardinality = self.evaluate(restored.cardinality())
    self.assertEqual(cardinality, 42)

  @combinations.generate(test_base.eager_only_combinations())
  def testCustomShardFunction(self):
    """Checks the element order produced after saving with two shards."""
    original = dataset_ops.Dataset.range(42)
    io.save(original, self._test_dir, shard_func=lambda x: x // 21)
    restored = io.load(self._test_dir, original.element_spec)
    # Expected order alternates between the two shards: 0, 21, 1, 22, ...
    expected = [value for i in range(21) for value in (i, i + 21)]
    self.assertDatasetProduces(restored, expected)

  @combinations.generate(test_base.eager_only_combinations())
  def testCustomReaderFunction(self):
    """Checks the element order when shards are read one after another."""
    original = dataset_ops.Dataset.range(42)
    io.save(original, self._test_dir, shard_func=lambda x: x % 7)
    restored = io.load(
        self._test_dir,
        original.element_spec,
        reader_func=lambda x: x.flat_map(lambda y: y))
    # Expected order is shard by shard: 0, 7, 14, ..., then 1, 8, 15, ...
    expected = [
        value for shard in range(7) for value in range(shard, 42, 7)
    ]
    self.assertDatasetProduces(restored, expected)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(compression=[None, "GZIP"])))
  def testSaveInsideFunction(self, compression):
    """Saves from inside a `def_function.function` and loads the result."""
    original = dataset_ops.Dataset.range(42)

    @def_function.function
    def save_fn():
      io.save(original, self._test_dir, compression=compression)

    save_fn()
    restored = io.load(
        self._test_dir, original.element_spec, compression=compression)
    self.assertDatasetProduces(restored, range(42))

  @combinations.generate(test_base.eager_only_combinations())
  def testElementSpecOptional(self):
    """Loads a nested dataset in eager mode without passing `element_spec`."""
    range_dataset = dataset_ops.Dataset.range(42)
    dict_dataset = dataset_ops.Dataset.from_tensor_slices({"a": [1, 2],
                                                           "b": [3, 4]})
    tuple_dataset = dataset_ops.Dataset.from_tensor_slices(([1, 2], [3, 4]))
    zipped = dataset_ops.Dataset.zip(
        (range_dataset, dict_dataset, tuple_dataset))
    io.save(zipped, self._test_dir)
    self.assertDatasetsEqual(zipped, io.load(self._test_dir))

  @combinations.generate(test_base.graph_only_combinations())
  def testElementSpecRequired(self):
    """Checks that graph-mode loading without `element_spec` raises."""
    io.save(dataset_ops.Dataset.range(42), self._test_dir)
    with self.assertRaises(ValueError):
      io.load(self._test_dir)

  @combinations.generate(test_base.eager_only_combinations())
  def testRepeatAndPrefetch(self):
    """This test reproduces github.com/tensorflow/tensorflow/issues/49165."""
    source = dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32))
    io.save(source, self._test_dir)
    pipeline = (
        io.load(self._test_dir)
        .shuffle(buffer_size=16)
        .batch(16)
        .repeat()
        .prefetch(1))
    get_next = self.getNext(pipeline)
    for _ in range(30):
      self.evaluate(get_next())
示例#17
0
class PlacementTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for tf.data placement within tf.functions.

  Specifically, tf.data dataset tensors cannot be copied between devices. These
  tests verify the ops are placed in a way that avoids this.
  """

  def setUp(self):
    """Disables the Grappler meta-optimizer before every test."""
    super(PlacementTest, self).setUp()
    # Grappler optimizations can affect whether the placement issues occur,
    # since they may inadvertently rewrite nodes and edges in a way that removes
    # cross-device copies.
    config.set_optimizer_experimental_options({"disable_meta_optimizer": True})

  @combinations.generate(test_base.eager_only_combinations())
  def testWhileWithCapturedDataset(self):
    """Sums a dataset created outside the function from inside nested loops."""
    # The dataset is built here, outside `f`, so `f` captures it.
    dataset = dataset_ops.Dataset.range(10)

    @def_function.function
    def f():
      total = constant_op.constant(0, dtypes.int64)
      # The outer loop iterates over a tensor range of length 1; the inner loop
      # iterates over the captured dataset.
      for _ in math_ops.range(1):
        for elem in dataset:
          total += elem
      return total

    # 0 + 1 + ... + 9 == 45.
    self.assertEqual(f().numpy(), 45)

  @combinations.generate(test_base.eager_only_combinations())
  def testWhile(self):
    """Sums a dataset created inside the function from inside nested loops."""

    @def_function.function
    def f():
      # Unlike `testWhileWithCapturedDataset`, the dataset is created inside
      # the traced function.
      dataset = dataset_ops.Dataset.range(10)
      total = constant_op.constant(0, dtypes.int64)
      for _ in math_ops.range(1):
        for elem in dataset:
          total += elem
      return total

    # 0 + 1 + ... + 9 == 45.
    self.assertEqual(f().numpy(), 45)

  @combinations.generate(test_base.eager_only_combinations())
  def testCondWithPlacement(self):
    """Reads the first element of a dataset returned by a cond on the CPU."""
    # When the cond op is explicitly placed, there shouldn't be cross-device
    # copies.
    @def_function.function
    def f():
      dataset = dataset_ops.Dataset.range(10)

      def fn():
        return dataset.map(lambda x: x+1)

      c = constant_op.constant(2)
      with ops.device("/cpu:0"):
        # Both branches are the same function, so the result does not depend
        # on the predicate.
        a = control_flow_ops.cond(math_ops.equal(c, 2), fn, fn)
        iterator = iter(a)
        nxt = next(iterator)
      return nxt

    # First element of range(10) mapped through x + 1.
    self.assertEqual(f().numpy(), 1)

  @combinations.generate(test_base.eager_only_combinations())
  def testCondWithColocation(self):
    """Reads the first element of a cond result colocated with the dataset."""
    # When the cond op is colocated with the dataset, there shouldn't be
    # cross-device copies.
    @def_function.function
    def f():
      dataset = dataset_ops.Dataset.range(8)

      def fn():
        return dataset.map(lambda x: x+1)

      c = constant_op.constant(2)
      with ops.colocate_with(dataset._variant_tensor):  # pylint:disable=protected-access
        # Both branches are the same function, so the result does not depend
        # on the predicate.
        a = control_flow_ops.cond(math_ops.equal(c, 2), fn, fn)
        iterator = iter(a)
        nxt = next(iterator)
      return nxt

    # First element of range(8) mapped through x + 1.
    self.assertEqual(f().numpy(), 1)

  @combinations.generate(test_base.eager_only_combinations())
  def testCond(self):
    """Reads the first element of a cond result with no placement hints."""
    # Ideally, placer should avoid cross-device copies even when the cond op
    # has no placement constraints.
    @def_function.function
    def f():
      dataset = dataset_ops.Dataset.range(8)

      def fn():
        return dataset.map(lambda x: x+1)

      c = constant_op.constant(2)
      # Both branches are the same function, so the result does not depend on
      # the predicate.
      a = control_flow_ops.cond(math_ops.equal(c, 2), fn, fn)
      iterator = iter(a)
      nxt = next(iterator)
      return nxt

    # First element of range(8) mapped through x + 1.
    self.assertEqual(f().numpy(), 1)

  @combinations.generate(test_base.eager_only_combinations())
  def testId(self):
    """Passes a dataset through `array_ops.identity` inside a function."""
    # Ideally, placer should know that Identity(dataset) should be on the same
    # device as the dataset.
    @def_function.function
    def f():
      dataset = dataset_ops.Dataset.range(10)
      dataset = array_ops.identity(dataset)
      return dataset
    # There is no assertion on the result; the test only requires that the
    # call completes without raising.
    f()

  @combinations.generate(test_base.eager_only_combinations())
  def testIteratorOnDeviceEagerMode(self):
    """Checks device strings after `prefetch_to_device` in eager mode.

    Skipped when no GPU is available.
    """
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0"))
    iterator = iter(dataset)
    data = next(iterator)
    optional_data = iterator.get_next_as_optional()

    # The dataset variant, the iterator resource, the fetched element and both
    # parts of the optional must all report a device containing "gpu:0".
    self.assertIn("gpu:0", dataset._variant_tensor.device.lower())
    self.assertIn("gpu:0", iterator._iterator_resource.device.lower())
    self.assertIn("gpu:0", data.device.lower())
    self.assertIn("gpu:0", optional_data.get_value().device.lower())
    self.assertIn("gpu:0", optional_data.has_value().device.lower())

  @combinations.generate(test_base.graph_only_combinations())
  def testIteratorOnDeviceGraphModeOneShotIterator(self):
    """Checks placement with a one-shot iterator in graph mode.

    Currently always skipped: either because no GPU is available or because
    of the unconditional `skipTest` below.
    """
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    # NOTE: this skip is unconditional, so the statements after it do not run
    # until the TODO is resolved and the skip is removed.
    self.skipTest("TODO(b/169429285): tf.data.Dataset.make_one_shot_iterator "
                  "does not support GPU placement.")

    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0"))
    iterator = dataset_ops.make_one_shot_iterator(dataset)
    data = iterator.get_next()
    optional_data = iterator.get_next_as_optional()

    # Each check colocates a `device_placement_op` with the tensor of interest
    # and asserts that its evaluated output contains b"GPU:0".
    # NOTE(review): presumably the op outputs the name of the device it runs
    # on - confirm against `test_ops`.
    with ops.colocate_with(dataset._variant_tensor):
      dataset_device = test_ops.device_placement_op()
    self.assertIn(b"GPU:0", self.evaluate(dataset_device))

    with ops.colocate_with(iterator._iterator_resource):
      iterator_device = test_ops.device_placement_op()
    self.assertIn(b"GPU:0", self.evaluate(iterator_device))

    with ops.colocate_with(data):
      data_device = test_ops.device_placement_op()
    self.assertIn(b"GPU:0", self.evaluate(data_device))

    with ops.colocate_with(optional_data.get_value()):
      get_value_device = test_ops.device_placement_op()
    self.assertIn(b"GPU:0", self.evaluate(get_value_device))

    with ops.colocate_with(optional_data.has_value()):
      has_value_device = test_ops.device_placement_op()
    self.assertIn(b"GPU:0", self.evaluate(has_value_device))

  @combinations.generate(test_base.graph_only_combinations())
  def testIteratorOnDeviceGraphModeInitializableIterator(self):
    """Checks placement with an initializable iterator in graph mode.

    Skipped when no GPU is available.
    """
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0"))
    iterator = dataset_ops.make_initializable_iterator(dataset)
    data = iterator.get_next()
    optional_data = iterator.get_next_as_optional()

    # Each check colocates a `device_placement_op` with the tensor of interest
    # and asserts that its evaluated output contains b"GPU:0".
    # NOTE(review): presumably the op outputs the name of the device it runs
    # on - confirm against `test_ops`.
    with ops.colocate_with(dataset._variant_tensor):
      dataset_device = test_ops.device_placement_op()
    self.assertIn(b"GPU:0", self.evaluate(dataset_device))

    with ops.colocate_with(iterator._iterator_resource):
      iterator_device = test_ops.device_placement_op()
    self.assertIn(b"GPU:0", self.evaluate(iterator_device))

    with ops.colocate_with(data):
      data_device = test_ops.device_placement_op()
    self.assertIn(b"GPU:0", self.evaluate(data_device))

    with ops.colocate_with(optional_data.get_value()):
      get_value_device = test_ops.device_placement_op()
    self.assertIn(b"GPU:0", self.evaluate(get_value_device))

    with ops.colocate_with(optional_data.has_value()):
      has_value_device = test_ops.device_placement_op()
    self.assertIn(b"GPU:0", self.evaluate(has_value_device))

  @combinations.generate(test_base.eager_only_combinations())
  def testIterDatasetEagerModeWithExplicitDevice(self):
    """Sums a dataset inside a function called under a GPU device scope.

    Skipped when no GPU is available.
    """
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    @def_function.function
    def comp():
      value = constant_op.constant(0, dtype=dtypes.int64)
      for d in iter(dataset_ops.Dataset.range(10)):
        value += d
      return value

    # The function is invoked inside an explicit "/gpu:0" device scope.
    with ops.device("/gpu:0"):
      result = comp()
    # 0 + 1 + ... + 9 == 45.
    self.assertEqual(result.numpy(), 45)

  @combinations.generate(test_base.eager_only_combinations())
  def testFunctionInliningColocation(self):
    """Passes a dataset between nested functions under a GPU device scope.

    Skipped when no GPU is available.
    """
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    @def_function.function
    def f(ds):
      return next(iter(ds))

    @def_function.function
    def g():
      # The dataset is created in `g` and handed to the nested function `f`.
      dataset = dataset_ops.Dataset.range(10)
      return f(dataset)

    with ops.device("/gpu:0"):
      # The first element of range(10) is 0.
      self.assertEqual(self.evaluate(g()), 0)
示例#18
0
class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
    """Tests for the `unbatch()` dataset transformation."""

    @combinations.generate(test_base.default_test_combinations())
    def testUnbatchWithUnknownRankInput(self):
        """Unbatches a single vector element into its scalars."""
        unbatched = dataset_ops.Dataset.from_tensors([0, 1, 2, 3]).unbatch()
        self.assertDatasetProduces(unbatched, range(4))

    @combinations.generate(test_base.default_test_combinations())
    def testUnbatchScalarDataset(self):
        """Batches then unbatches a tuple of three int ranges."""
        components = tuple(math_ops.range(10) for _ in range(3))
        want_types = (dtypes.int32, ) * 3

        batched = dataset_ops.Dataset.from_tensor_slices(components).batch(2)
        self.assertEqual(want_types,
                         dataset_ops.get_legacy_output_types(batched))

        unbatched = batched.unbatch()
        self.assertEqual(want_types,
                         dataset_ops.get_legacy_output_types(unbatched))
        self.assertDatasetProduces(unbatched,
                                   [(i, ) * 3 for i in range(10)])

    @combinations.generate(test_base.default_test_combinations())
    def testUnbatchNestedDataset(self):
        """Unbatches a list of datasets and flattens the result."""
        inner_datasets = [dataset_ops.Dataset.range(10) for _ in range(10)]
        flattened = dataset_ops.Dataset.from_tensors(
            inner_datasets).unbatch().flat_map(lambda ds: ds)
        self.assertDatasetProduces(flattened, list(range(10)) * 10)

    @combinations.generate(test_base.default_test_combinations())
    def testUnbatchDatasetWithStrings(self):
        """Batches then unbatches elements that contain a string component."""
        components = tuple(math_ops.range(10) for _ in range(3))
        with_strings = dataset_ops.Dataset.from_tensor_slices(components).map(
            lambda x, y, z: (x, string_ops.as_string(y), z))
        want_types = (dtypes.int32, dtypes.string, dtypes.int32)

        batched = with_strings.batch(2)
        self.assertEqual(want_types,
                         dataset_ops.get_legacy_output_types(batched))

        unbatched = batched.unbatch()
        self.assertEqual(want_types,
                         dataset_ops.get_legacy_output_types(unbatched))

        want_elements = [(i, compat.as_bytes(str(i)), i) for i in range(10)]
        self.assertDatasetProduces(unbatched, want_elements)

    @combinations.generate(test_base.default_test_combinations())
    def testUnbatchDatasetWithSparseTensor(self):
        """Unbatches, rebatches and unbatches a diagonal sparse tensor."""
        diagonal = sparse_tensor.SparseTensorValue(
            indices=[[i, i] for i in range(10)],
            values=list(range(10)),
            dense_shape=[10, 10])
        rows = dataset_ops.Dataset.from_tensors(
            diagonal).unbatch().batch(5).unbatch()
        want_rows = [
            sparse_tensor.SparseTensorValue([[i]], [i], [10])
            for i in range(10)
        ]
        self.assertDatasetProduces(rows, expected_output=want_rows)

    @combinations.generate(test_base.default_test_combinations())
    def testUnbatchDatasetWithDenseSparseAndRaggedTensor(self):
        """Round-trips a (dense, sparse, ragged) tuple through unbatch."""
        diagonal = sparse_tensor.SparseTensorValue(
            indices=[[i, i] for i in range(10)],
            values=list(range(10)),
            dense_shape=[10, 10])
        ragged = ragged_factory_ops.constant_value(
            [[[i]] for i in range(10)])
        rows = dataset_ops.Dataset.from_tensors(
            (list(range(10)), diagonal, ragged))
        rows = rows.unbatch().batch(5).unbatch()
        want_rows = [
            (i,
             sparse_tensor.SparseTensorValue([[i]], [i], [10]),
             ragged_factory_ops.constant_value([[i]]))
            for i in range(10)
        ]
        self.assertDatasetProduces(rows, expected_output=want_rows)

    @combinations.generate(test_base.default_test_combinations())
    def testUnbatchDatasetWithRaggedTensor(self):
        """Unbatches a doubly batched ragged dataset by one level."""
        ragged = ragged_factory_ops.constant_value(
            [[[i]] for i in range(10)])
        groups = dataset_ops.Dataset.from_tensors(
            ragged).unbatch().batch(5).batch(2).unbatch()
        # Two groups of five rows each: values 0-4 and values 5-9.
        want_groups = [
            ragged_factory_ops.constant_value(
                [[[i]] for i in range(start, start + 5)])
            for start in (0, 5)
        ]
        self.assertDatasetProduces(groups, expected_output=want_groups)

    @combinations.generate(test_base.default_test_combinations())
    def testUnbatchSingleElementTupleDataset(self):
        """Batches then unbatches a tuple of one-element tuples."""
        components = tuple((math_ops.range(10), ) for _ in range(3))
        want_types = ((dtypes.int32, ), ) * 3

        batched = dataset_ops.Dataset.from_tensor_slices(components).batch(2)
        self.assertEqual(want_types,
                         dataset_ops.get_legacy_output_types(batched))

        unbatched = batched.unbatch()
        self.assertEqual(want_types,
                         dataset_ops.get_legacy_output_types(unbatched))
        self.assertDatasetProduces(unbatched,
                                   [((i, ), ) * 3 for i in range(10)])

    @combinations.generate(test_base.default_test_combinations())
    def testUnbatchMultiElementTupleDataset(self):
        """Batches then unbatches a tuple of (int, string) pairs."""
        components = tuple(
            (math_ops.range(10 * i, 10 * i + 10), array_ops.fill([10], "hi"))
            for i in range(3))
        want_types = ((dtypes.int32, dtypes.string), ) * 3

        batched = dataset_ops.Dataset.from_tensor_slices(components).batch(2)
        self.assertAllEqual(want_types,
                            dataset_ops.get_legacy_output_types(batched))

        unbatched = batched.unbatch()
        self.assertAllEqual(want_types,
                            dataset_ops.get_legacy_output_types(unbatched))

        want_elements = [
            tuple((10 * k + i, b"hi") for k in range(3)) for i in range(10)
        ]
        self.assertDatasetProduces(unbatched, want_elements)

    @combinations.generate(test_base.default_test_combinations())
    def testUnbatchEmpty(self):
        """Unbatching components with a zero leading dimension yields nothing."""
        empty_components = (
            constant_op.constant([]),
            constant_op.constant([], shape=[0, 4]),
            constant_op.constant([], shape=[0, 4, 0]),
        )
        unbatched = dataset_ops.Dataset.from_tensors(
            empty_components).unbatch()
        self.assertDatasetProduces(unbatched, [])

    @combinations.generate(test_base.default_test_combinations())
    def testUnbatchStaticShapeMismatch(self):
        """Components with different static lengths are rejected up front."""
        mismatched = dataset_ops.Dataset.from_tensors(
            (np.arange(7), np.arange(8), np.arange(9)))
        with self.assertRaises(ValueError):
            mismatched.unbatch()

    @combinations.generate(test_base.graph_only_combinations())
    def testUnbatchDynamicShapeMismatch(self):
        """Mismatches only known at run time fail when an element is read."""
        ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
        ph2 = array_ops.placeholder(dtypes.int32, shape=None)
        unbatched = dataset_ops.Dataset.from_tensors((ph1, ph2)).unbatch()
        iterator = dataset_ops.make_initializable_iterator(unbatched)
        next_element = iterator.get_next()

        bad_second_components = (
            # Mismatch in the 0th dimension.
            np.arange(8).astype(np.int32),
            # No 0th dimension (i.e. scalar value) for one component.
            7,
        )
        with self.cached_session() as sess:
            for bad_value in bad_second_components:
                sess.run(iterator.initializer,
                         feed_dict={
                             ph1: np.arange(7).astype(np.int32),
                             ph2: bad_value
                         })
                with self.assertRaises(errors.InvalidArgumentError):
                    self.evaluate(next_element)

    @combinations.generate(test_base.default_test_combinations())
    def testUnbatchDatasetWithUintDtypes(self):
        """Batches then unbatches components of every unsigned int width."""
        column_values = (
            (np.uint8, [0, 1, 2, 3]),
            (np.uint16, [1, 2, 3, 256]),
            (np.uint32, [2, 3, 4, 65536]),
            (np.uint64, [3, 4, 5, 4294967296]),
        )
        # Each component is a 4x2 array whose two columns are identical.
        components = tuple(
            np.tile(np.array([[v] for v in values], dtype=np_dtype), 2)
            for np_dtype, values in column_values)
        want_types = (dtypes.uint8, dtypes.uint16, dtypes.uint32,
                      dtypes.uint64)
        want_rows = [tuple(c[i] for c in components) for i in range(4)]

        batched = dataset_ops.Dataset.from_tensor_slices(components).batch(2)
        self.assertEqual(want_types,
                         dataset_ops.get_legacy_output_types(batched))

        unbatched = batched.unbatch()
        self.assertEqual(want_types,
                         dataset_ops.get_legacy_output_types(unbatched))
        self.assertDatasetProduces(unbatched, want_rows)

    @combinations.generate(test_base.default_test_combinations())
    def testNoneComponent(self):
        """Unbatches an element whose second component is `None`."""
        with_none = dataset_ops.Dataset.from_tensors((list(range(10)), None))
        first_only = with_none.unbatch().map(lambda x, y: x)
        self.assertDatasetProduces(first_only, expected_output=range(10))

    @combinations.generate(test_base.default_test_combinations())
    def testName(self):
        """Checks that `unbatch` accepts a `name` argument."""
        named = dataset_ops.Dataset.from_tensors([42]).unbatch(name="unbatch")
        self.assertDatasetProduces(named, [42])
示例#19
0
class DataServiceMetadataTest(data_service_test_base.TestBase,
                              parameterized.TestCase):
    """Tests propagating data service metadata through tf.data service."""
    @combinations.generate(_cardinality_test_combinations())
    def testCardinality(self, dataset_fn, sharding_policy, expected_result):
        """Compares a distributed dataset's cardinality to `expected_result`.

        Args:
          dataset_fn: Zero-argument callable that builds the input dataset.
          sharding_policy: Passed as `processing_mode` when distributing.
          expected_result: The cardinality the distributed dataset must report.
        """
        test_cluster = data_service_test_base.TestCluster(num_workers=2)
        source = dataset_fn()
        distributed = self.make_distributed_dataset(
            source, cluster=test_cluster, processing_mode=sharding_policy)
        reported = self.evaluate(distributed.cardinality())
        self.assertEqual(reported, expected_result)

    @combinations.generate(_cardinality_test_combinations())
    def testFromDatasetIdCardinality(self, dataset_fn, sharding_policy,
                                     expected_result):
        """Checks cardinality of a dataset rebuilt via `from_dataset_id`.

        Args:
          dataset_fn: Zero-argument callable that builds the input dataset.
          sharding_policy: Passed as `processing_mode` to `from_dataset_id`.
          expected_result: The cardinality the rebuilt dataset must report.
        """
        test_cluster = data_service_test_base.TestCluster(num_workers=2)
        source = dataset_fn()
        registered_id = data_service_ops.register_dataset(
            test_cluster.dispatcher.target, dataset=source)
        rebuilt = data_service_ops.from_dataset_id(
            processing_mode=sharding_policy,
            service=test_cluster.dispatcher.target,
            dataset_id=registered_id,
            element_spec=source.element_spec)
        reported = self.evaluate(rebuilt.cardinality())
        self.assertEqual(reported, expected_result)

    @combinations.generate(test_base.eager_only_combinations())
    def testFromDatasetIdDoesntRequireElementSpec(self):
        """Reads a registered range dataset without passing `element_spec`."""
        test_cluster = data_service_test_base.TestCluster(
            num_workers=1,
            work_dir=data_service_test_base.NO_WORK_DIR,
            fault_tolerant_mode=False,
            data_transfer_protocol="grpc")
        element_count = 10
        source = dataset_ops.Dataset.range(element_count)

        registered_id = data_service_ops.register_dataset(
            test_cluster.dispatcher_address(), source)
        # No `element_spec` argument is supplied here.
        rebuilt = data_service_ops.from_dataset_id(
            processing_mode=data_service_ops.ShardingPolicy.OFF,
            service=test_cluster.dispatcher_address(),
            dataset_id=registered_id)
        self.assertDatasetProduces(rebuilt, list(range(element_count)))

    @combinations.generate(test_base.graph_only_combinations())
    def testElementSpecGraphMode(self):
        """Expects a `ValueError` when `element_spec` is omitted in graph mode."""
        test_cluster = data_service_test_base.TestCluster(
            num_workers=1,
            work_dir=data_service_test_base.NO_WORK_DIR,
            fault_tolerant_mode=False)
        element_count = 10
        source = dataset_ops.Dataset.range(element_count)
        registered_id = data_service_ops.register_dataset(
            test_cluster.dispatcher_address(), source)

        expected_error = (
            "In graph mode `element_spec` must be provided manually.")
        with self.assertRaisesRegex(ValueError, expected_error):
            _ = data_service_ops.from_dataset_id(
                processing_mode=data_service_ops.ShardingPolicy.OFF,
                service=test_cluster.dispatcher_address(),
                dataset_id=registered_id)

    @combinations.generate(test_base.eager_only_combinations())
    def testElementSpecMixedMode(self):
        """Registers inside a `tf.function`, then reads without `element_spec`.

        The `from_dataset_id` call is expected to raise a `ValueError` whose
        message names the dataset id.
        """
        test_cluster = data_service_test_base.TestCluster(
            num_workers=1,
            work_dir=data_service_test_base.NO_WORK_DIR,
            fault_tolerant_mode=False)
        element_count = 10
        source = dataset_ops.Dataset.range(element_count)

        @def_function.function
        def get_dataset_id():
            return data_service_ops.register_dataset(
                test_cluster.dispatcher_address(), source)

        registered_id = get_dataset_id()
        registered_id_value = tensor_util.constant_value(registered_id)

        expected_error = (
            f"Failed to fetch element spec for dataset id {registered_id_value} from "
            "tf.data service. If the dataset was registered in graph mode or "
            "inside a tf.function, the `element_spec` must be specified as an "
            "argument to `from_dataset_id`.")
        with self.assertRaisesRegex(ValueError, expected_error):
            _ = data_service_ops.from_dataset_id(
                processing_mode=data_service_ops.ShardingPolicy.OFF,
                service=test_cluster.dispatcher_address(),
                dataset_id=registered_id)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(compression=[None, "AUTO"])))
    def testFromDatasetIdOmitsCompression(self, compression):
        """Round-trips an upper-casing dataset registered with `compression`.

        `from_dataset_id` is called without a `compression` argument.

        Args:
          compression: Value forwarded to `register_dataset`.
        """
        test_cluster = data_service_test_base.TestCluster(
            num_workers=1, data_transfer_protocol="grpc")
        lowercase = dataset_ops.Dataset.from_tensor_slices(
            list("abcdefghijklmnopqrstuvwxyz"))

        def to_upper(letter):
            return script_ops.numpy_function(
                func=lambda raw: raw.decode("utf-8").upper(),
                inp=[letter],
                Tout=dtypes.string)

        uppercase = lowercase.map(
            to_upper, num_parallel_calls=dataset_ops.AUTOTUNE)
        forward_compat_on = mock.patch.object(
            compat, "forward_compatible", return_value=True)
        with forward_compat_on:
            registered_id = data_service_ops.register_dataset(
                test_cluster.dispatcher.target,
                dataset=uppercase,
                compression=compression)
            rebuilt = data_service_ops.from_dataset_id(
                processing_mode=data_service_ops.ShardingPolicy.OFF,
                service=test_cluster.dispatcher.target,
                dataset_id=registered_id,
                element_spec=uppercase.element_spec)
            self.assertDatasetProduces(rebuilt,
                                       list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))

    # Eager-only as querying `element_spec` is only supported in the eager mode.
    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           combinations.combine(compression=[None, "AUTO"])))
    def testFromDatasetIdOmitsElementSpecAndCompression(self, compression):
        """Reads a registered dataset passing neither spec nor compression.

        Args:
          compression: Value forwarded to `register_dataset`.
        """
        test_cluster = data_service_test_base.TestCluster(
            num_workers=1, data_transfer_protocol="grpc")
        letters = dataset_ops.Dataset.from_tensor_slices(
            list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
        forward_compat_on = mock.patch.object(
            compat, "forward_compatible", return_value=True)
        with forward_compat_on:
            registered_id = data_service_ops.register_dataset(
                test_cluster.dispatcher.target,
                dataset=letters,
                compression=compression)
            rebuilt = data_service_ops.from_dataset_id(
                processing_mode=data_service_ops.ShardingPolicy.OFF,
                service=test_cluster.dispatcher.target,
                dataset_id=registered_id)
            self.assertDatasetProduces(rebuilt,
                                       list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))

    def _testCompressionMismatch(self, dataset):
        """Registers `dataset` with `compression=None`, reads with the default.

        Reading the resulting dataset is expected to raise
        `InvalidArgumentError`.

        Args:
          dataset: The dataset to register with the tf.data service.
        """
        test_cluster = data_service_test_base.TestCluster(
            num_workers=1, data_transfer_protocol="grpc")
        forward_compat_off = mock.patch.object(
            compat, "forward_compatible", return_value=False)
        with forward_compat_off:
            registered_id = data_service_ops._register_dataset(
                test_cluster.dispatcher.target,
                dataset=dataset,
                compression=None)
            # `compression` is "AUTO" by default.
            mismatched = data_service_ops._from_dataset_id(
                processing_mode=data_service_ops.ShardingPolicy.OFF,
                service=test_cluster.dispatcher.target,
                dataset_id=registered_id,
                element_spec=dataset.element_spec)
            with self.assertRaises(errors.InvalidArgumentError):
                self.getDatasetOutput(mismatched)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations()))
    def testCompressionDtypeMismatch(self):
        """Runs the compression-mismatch check on a dataset of letters."""
        letters = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
        self._testCompressionMismatch(
            dataset_ops.Dataset.from_tensor_slices(letters))

    @combinations.generate(
        combinations.times(test_base.default_test_combinations()))
    def testCompressionShapeMismatch(self):
        """Runs the compression-mismatch check on a dataset of integer pairs."""
        pairs = [[1, 2], [3, 4]]
        self._testCompressionMismatch(
            dataset_ops.Dataset.from_tensor_slices(pairs))

    # Only test eager mode since nested datasets are not allowed in graph mode.
    @combinations.generate(
        combinations.times(test_base.eager_only_combinations()))
    def testCompressionVariantMismatch(self):
        """Runs the compression-mismatch check on a nested dataset."""
        # A nested dataset serves as an example of a variant element.
        inner = dataset_ops.Dataset.range(10)
        nested = dataset_ops.Dataset.from_tensors(inner)
        self._testCompressionMismatch(nested)
示例#20
0
class CopyToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase):
    @combinations.generate(test_base.graph_only_combinations())
    def testCopyToDevice(self):
        """Copies a range dataset to `/cpu:1` and reads all ten elements."""
        source = dataset_ops.Dataset.range(10)
        copied = source.apply(prefetching_ops.copy_to_device("/cpu:1"))

        with ops.device("/cpu:1"):
            one_shot = dataset_ops.make_one_shot_iterator(copied)
            element_t = one_shot.get_next()

        # The copy must keep a structure compatible with the source.
        source_structure = dataset_ops.get_structure(source)
        copied_structure = dataset_ops.get_structure(copied)
        self.assertTrue(
            structure.are_compatible(source_structure, copied_structure))

        self.assertEqual(dtypes.int64, element_t.dtype)
        self.assertEqual([], element_t.shape)

        # The session is configured with two CPU devices.
        two_cpu_config = config_pb2.ConfigProto(device_count={"CPU": 2})
        with self.test_session(config=two_cpu_config):
            for expected in range(10):
                self.assertEqual(expected, self.evaluate(element_t))
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(element_t)

    @combinations.generate(test_base.graph_only_combinations())
    def testCopyToDeviceHostOptimizations(self):
        """Copies a map+batch pipeline that asserts `MapAndBatch` comes next."""
        source = dataset_ops.Dataset.range(10)
        source = source.apply(testing.assert_next(["MapAndBatch"]))
        squared_batch = source.map(lambda value: value * value).batch(10)
        copied = squared_batch.apply(prefetching_ops.copy_to_device("/cpu:1"))

        with ops.device("/cpu:1"):
            one_shot = dataset_ops.make_one_shot_iterator(copied)
            element_t = one_shot.get_next()

        # The session is configured with two CPU devices.
        two_cpu_config = config_pb2.ConfigProto(device_count={"CPU": 2})
        with self.test_session(config=two_cpu_config):
            want_squares = [value * value for value in range(10)]
            self.assertAllEqual(want_squares, self.evaluate(element_t))
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(element_t)

    @combinations.generate(test_base.graph_only_combinations())
    def testCopyToDeviceInt32(self):
        """Copies a single int32 vector element to `/cpu:1`."""
        source = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
        copied = source.apply(prefetching_ops.copy_to_device("/cpu:1"))

        with ops.device("/cpu:1"):
            one_shot = dataset_ops.make_one_shot_iterator(copied)
            element_t = one_shot.get_next()

        # The copy must keep a structure compatible with the source.
        source_structure = dataset_ops.get_structure(source)
        copied_structure = dataset_ops.get_structure(copied)
        self.assertTrue(
            structure.are_compatible(source_structure, copied_structure))

        self.assertEqual(dtypes.int32, element_t.dtype)
        self.assertEqual((4, ), element_t.shape)

        # The session is configured with two CPU devices.
        two_cpu_config = config_pb2.ConfigProto(device_count={"CPU": 2})
        with self.test_session(config=two_cpu_config):
            self.assertAllEqual([0, 1, 2, 3], self.evaluate(element_t))
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(element_t)

    @combinations.generate(test_base.graph_only_combinations())
    def testCopyToSameDevice(self):
        """Copies a range dataset to `/cpu:0` and reads all ten elements."""
        source = dataset_ops.Dataset.range(10)
        copied = source.apply(prefetching_ops.copy_to_device("/cpu:0"))

        with ops.device("/cpu:0"):
            one_shot = dataset_ops.make_one_shot_iterator(copied)
            element_t = one_shot.get_next()

        # The copy must keep a structure compatible with the source.
        source_structure = dataset_ops.get_structure(source)
        copied_structure = dataset_ops.get_structure(copied)
        self.assertTrue(
            structure.are_compatible(source_structure, copied_structure))

        self.assertEqual(dtypes.int64, element_t.dtype)
        self.assertEqual([], element_t.shape)

        # The session is configured with two CPU devices.
        two_cpu_config = config_pb2.ConfigProto(device_count={"CPU": 2})
        with self.test_session(config=two_cpu_config):
            for expected in range(10):
                self.assertEqual(expected, self.evaluate(element_t))
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(element_t)

    @combinations.generate(test_base.graph_only_combinations())
    def testCopyToDeviceWithPrefetch(self):
        """Copies a range dataset to `/cpu:1`, then prefetches one element."""
        source = dataset_ops.Dataset.range(10)
        copied = source.apply(prefetching_ops.copy_to_device("/cpu:1"))
        prefetched = copied.prefetch(1)

        with ops.device("/cpu:1"):
            one_shot = dataset_ops.make_one_shot_iterator(prefetched)
            element_t = one_shot.get_next()

        # The prefetched copy must keep a structure compatible with the source.
        source_structure = dataset_ops.get_structure(source)
        prefetched_structure = dataset_ops.get_structure(prefetched)
        self.assertTrue(
            structure.are_compatible(source_structure, prefetched_structure))

        self.assertEqual(dtypes.int64, element_t.dtype)
        self.assertEqual([], element_t.shape)

        # The session is configured with two CPU devices.
        two_cpu_config = config_pb2.ConfigProto(device_count={"CPU": 2})
        with self.test_session(config=two_cpu_config):
            for expected in range(10):
                self.assertEqual(expected, self.evaluate(element_t))
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(element_t)

    @combinations.generate(test_base.graph_only_combinations())
    def testCopyDictToDevice(self):
        """Copies a dataset of `{"a": x}` dict elements to `/cpu:1`."""
        source = dataset_ops.Dataset.range(10).map(lambda value: {"a": value})
        copied = source.apply(prefetching_ops.copy_to_device("/cpu:1"))

        with ops.device("/cpu:1"):
            one_shot = dataset_ops.make_one_shot_iterator(copied)
            element_t = one_shot.get_next()

        # The copy must keep a structure compatible with the source.
        source_structure = dataset_ops.get_structure(source)
        copied_structure = dataset_ops.get_structure(copied)
        self.assertTrue(
            structure.are_compatible(source_structure, copied_structure))

        self.assertEqual(dtypes.int64, element_t["a"].dtype)
        self.assertEqual([], element_t["a"].shape)

        # The session is configured with two CPU devices.
        two_cpu_config = config_pb2.ConfigProto(device_count={"CPU": 2})
        with self.test_session(config=two_cpu_config):
            for expected in range(10):
                self.assertEqual({"a": expected}, self.evaluate(element_t))
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(element_t)

    @combinations.generate(test_base.graph_only_combinations())
    def testCopyDictToDeviceWithPrefetch(self):
        """Copies `{"a": x}` dict elements to `/cpu:1`, then prefetches one."""
        source = dataset_ops.Dataset.range(10).map(lambda value: {"a": value})
        copied = source.apply(prefetching_ops.copy_to_device("/cpu:1"))
        prefetched = copied.prefetch(1)

        with ops.device("/cpu:1"):
            one_shot = dataset_ops.make_one_shot_iterator(prefetched)
            element_t = one_shot.get_next()

        # The prefetched copy must keep a structure compatible with the source.
        source_structure = dataset_ops.get_structure(source)
        prefetched_structure = dataset_ops.get_structure(prefetched)
        self.assertTrue(
            structure.are_compatible(source_structure, prefetched_structure))

        self.assertEqual(dtypes.int64, element_t["a"].dtype)
        self.assertEqual([], element_t["a"].shape)

        # The session is configured with two CPU devices.
        two_cpu_config = config_pb2.ConfigProto(device_count={"CPU": 2})
        with self.test_session(config=two_cpu_config):
            for expected in range(10):
                self.assertEqual({"a": expected}, self.evaluate(element_t))
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(element_t)

    @combinations.generate(test_base.graph_only_combinations())
    def testCopySparseTensorsToDevice(self):
        """Copies a dataset of 2x2 sparse elements to `/cpu:1`."""

        def make_tensor(index):
            # One entry at [0, 0] whose value comes from the element index.
            return sparse_tensor.SparseTensorValue(
                indices=[[0, 0]], values=(index * [1]), dense_shape=[2, 2])

        source = dataset_ops.Dataset.range(10).map(make_tensor)
        copied = source.apply(prefetching_ops.copy_to_device("/cpu:1"))

        with ops.device("/cpu:1"):
            one_shot = dataset_ops.make_one_shot_iterator(copied)
            element_t = one_shot.get_next()

        # The copy must keep a structure compatible with the source.
        source_structure = dataset_ops.get_structure(source)
        copied_structure = dataset_ops.get_structure(copied)
        self.assertTrue(
            structure.are_compatible(source_structure, copied_structure))

        self.assertEqual(dtypes.int64, element_t.dtype)

        # The session is configured with two CPU devices.
        two_cpu_config = config_pb2.ConfigProto(device_count={"CPU": 2})
        with self.test_session(config=two_cpu_config):
            for expected in range(10):
                sparse_value = self.evaluate(element_t)
                self.assertAllEqual([expected], sparse_value.values)
                self.assertAllEqual([[0, 0]], sparse_value.indices)
                self.assertAllEqual([2, 2], sparse_value.dense_shape)
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(element_t)

    @combinations.generate(test_base.graph_only_combinations())
    def testCopySparseTensorsToDeviceWithPrefetch(self):
        """Copies 2x2 sparse elements to `/cpu:1`, then prefetches one."""

        def make_tensor(index):
            # One entry at [0, 0] whose value comes from the element index.
            return sparse_tensor.SparseTensorValue(
                indices=[[0, 0]], values=(index * [1]), dense_shape=[2, 2])

        source = dataset_ops.Dataset.range(10).map(make_tensor)
        copied = source.apply(prefetching_ops.copy_to_device("/cpu:1"))
        prefetched = copied.prefetch(1)

        with ops.device("/cpu:1"):
            one_shot = dataset_ops.make_one_shot_iterator(prefetched)
            element_t = one_shot.get_next()

        # The prefetched copy must keep a structure compatible with the source.
        source_structure = dataset_ops.get_structure(source)
        prefetched_structure = dataset_ops.get_structure(prefetched)
        self.assertTrue(
            structure.are_compatible(source_structure, prefetched_structure))

        self.assertEqual(dtypes.int64, element_t.dtype)

        # The session is configured with two CPU devices.
        two_cpu_config = config_pb2.ConfigProto(device_count={"CPU": 2})
        with self.test_session(config=two_cpu_config):
            for expected in range(10):
                sparse_value = self.evaluate(element_t)
                self.assertAllEqual([expected], sparse_value.values)
                self.assertAllEqual([[0, 0]], sparse_value.indices)
                self.assertAllEqual([2, 2], sparse_value.dense_shape)
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(element_t)

    @combinations.generate(test_base.default_test_combinations())
    def testCopyToDeviceGpu(self):
        """Copies a range dataset to `/gpu:0`; skipped when no GPU exists."""
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        source = dataset_ops.Dataset.range(10)
        on_gpu = source.apply(prefetching_ops.copy_to_device("/gpu:0"))

        with ops.device("/gpu:0"):
            self.assertDatasetProduces(on_gpu, list(range(10)))

    @combinations.generate(test_base.default_test_combinations())
    def testCopyToDeviceGpuWithPrefetch(self):
        """Copies a range dataset to `/gpu:0`, then prefetches one element."""
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        source = dataset_ops.Dataset.range(10)
        on_gpu = source.apply(prefetching_ops.copy_to_device("/gpu:0"))
        prefetched = on_gpu.prefetch(1)

        with ops.device("/gpu:0"):
            self.assertDatasetProduces(prefetched, list(range(10)))

    @combinations.generate(test_base.graph_only_combinations())
    def testCopyToDeviceGpuWithMap(self):
        """Applies `map_on_gpu` to (int, float, str) elements copied to GPU."""
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        def generator():
            for value in range(10):
                yield value, float(value), str(value)

        element_types = (dtypes.int32, dtypes.float32, dtypes.string)
        source = dataset_ops.Dataset.from_generator(
            generator, output_types=element_types)
        on_gpu = source.apply(prefetching_ops.copy_to_device("/gpu:0"))

        def gpu_map_func(x, y, z):
            # Square the numeric components; pass the string through as-is.
            return math_ops.square(x), math_ops.square(y), z

        on_gpu = on_gpu.apply(prefetching_ops.map_on_gpu(gpu_map_func))
        # Autotuning is switched off for this pipeline.
        no_autotune = options_lib.Options()
        no_autotune.autotune.enabled = False
        on_gpu = on_gpu.with_options(no_autotune)

        with ops.device("/gpu:0"):
            gpu_iterator = dataset_ops.make_initializable_iterator(on_gpu)
            element_t = gpu_iterator.get_next()

        strict_placement = config_pb2.ConfigProto(allow_soft_placement=False)
        with self.cached_session(config=strict_placement):
            self.evaluate(gpu_iterator.initializer)
            for value in range(10):
                int_out, float_out, text_out = self.evaluate(element_t)
                self.assertEqual(value**2, int_out)
                self.assertEqual(float(value**2), float_out)
                self.assertEqual(util_compat.as_bytes(str(value)), text_out)
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(element_t)

    @combinations.generate(test_base.graph_only_combinations())
    def testCopyToDeviceGpuInt32(self):
        """Copies a single int vector element to `/gpu:0` and reads it back."""
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        source = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
        on_gpu = source.apply(prefetching_ops.copy_to_device("/gpu:0"))

        with ops.device("/gpu:0"):
            gpu_iterator = dataset_ops.make_initializable_iterator(on_gpu)
            element_t = gpu_iterator.get_next()

        strict_placement = config_pb2.ConfigProto(allow_soft_placement=False)
        with self.cached_session(config=strict_placement):
            self.evaluate(gpu_iterator.initializer)
            self.assertAllEqual([0, 1, 2, 3], self.evaluate(element_t))
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(element_t)

    @combinations.generate(test_base.graph_only_combinations())
    def testCopyToDeviceGpuInt32AndPrefetch(self):
        """Copies a single int vector to `/gpu:0`, then prefetches it."""
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        source = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
        on_gpu = source.apply(prefetching_ops.copy_to_device("/gpu:0"))
        prefetched = on_gpu.prefetch(1)

        with ops.device("/gpu:0"):
            gpu_iterator = dataset_ops.make_initializable_iterator(prefetched)
            element_t = gpu_iterator.get_next()

        strict_placement = config_pb2.ConfigProto(allow_soft_placement=False)
        with self.cached_session(config=strict_placement):
            self.evaluate(gpu_iterator.initializer)
            self.assertAllEqual([0, 1, 2, 3], self.evaluate(element_t))
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(element_t)

    @combinations.generate(test_base.graph_only_combinations())
    def testCopyToDeviceGpuStrings(self):
        """Copies a single string vector element to `/gpu:0`."""
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        source = dataset_ops.Dataset.from_tensors(["a", "b", "c"])
        on_gpu = source.apply(prefetching_ops.copy_to_device("/gpu:0"))

        with ops.device("/gpu:0"):
            gpu_iterator = dataset_ops.make_initializable_iterator(on_gpu)
            element_t = gpu_iterator.get_next()

        strict_placement = config_pb2.ConfigProto(allow_soft_placement=False)
        with self.cached_session(config=strict_placement):
            self.evaluate(gpu_iterator.initializer)
            # The strings are compared against their bytes form.
            self.assertAllEqual([b"a", b"b", b"c"], self.evaluate(element_t))
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(element_t)

    @combinations.generate(test_base.graph_only_combinations())
    def testCopyToDeviceGpuStringsAndPrefetch(self):
        """Copies a string vector to `/gpu:0`, prefetches it, and reads it back.

        The pipeline applies `.prefetch(1)` after `copy_to_device`, matching the
        other `...AndPrefetch` GPU tests in this class. Without it this test
        would exercise exactly the same pipeline as
        `testCopyToDeviceGpuStrings`.
        """
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        host_dataset = dataset_ops.Dataset.from_tensors(["a", "b", "c"])
        device_dataset = host_dataset.apply(
            prefetching_ops.copy_to_device("/gpu:0")).prefetch(1)

        with ops.device("/gpu:0"):
            iterator = dataset_ops.make_initializable_iterator(device_dataset)
            next_element = iterator.get_next()

        with self.cached_session(config=config_pb2.ConfigProto(
                allow_soft_placement=False)):
            self.evaluate(iterator.initializer)
            # The strings are compared against their bytes form.
            self.assertAllEqual([b"a", b"b", b"c"],
                                self.evaluate(next_element))
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(next_element)

    @combinations.generate(test_base.graph_only_combinations())
    def testCopyToDevicePingPongCPUGPU(self):
        """Copies a range dataset CPU -> GPU -> CPU and reads it on the CPU."""
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        source = dataset_ops.Dataset.range(10)
        to_gpu = prefetching_ops.copy_to_device(
            "/gpu:0", source_device="/cpu:0")
        on_gpu = source.apply(to_gpu)
        to_cpu = prefetching_ops.copy_to_device(
            "/cpu:0", source_device="/gpu:0")
        back_on_cpu = on_gpu.apply(to_cpu)

        with ops.device("/cpu:0"):
            cpu_iterator = dataset_ops.make_initializable_iterator(back_on_cpu)
            element_t = cpu_iterator.get_next()

        strict_placement = config_pb2.ConfigProto(allow_soft_placement=False)
        with self.cached_session(config=strict_placement):
            self.evaluate(cpu_iterator.initializer)
            for expected in range(10):
                self.assertEqual(expected, self.evaluate(element_t))
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(element_t)

    @combinations.generate(test_base.graph_only_combinations())
    def testCopyToDeviceWithReInit(self):
        """Re-initializes a `/cpu:1` iterator midway and reads from the start."""
        source = dataset_ops.Dataset.range(10)
        copied = source.apply(prefetching_ops.copy_to_device("/cpu:1"))

        with ops.device("/cpu:1"):
            cpu_iterator = dataset_ops.make_initializable_iterator(copied)
            element_t = cpu_iterator.get_next()

        # The copy must keep a structure compatible with the source.
        source_structure = dataset_ops.get_structure(source)
        copied_structure = dataset_ops.get_structure(copied)
        self.assertTrue(
            structure.are_compatible(source_structure, copied_structure))

        self.assertEqual(dtypes.int64, element_t.dtype)
        self.assertEqual([], element_t.shape)

        # The session is configured with two CPU devices.
        two_cpu_config = config_pb2.ConfigProto(device_count={"CPU": 2})
        with self.test_session(config=two_cpu_config):
            # First pass: consume only half of the elements.
            self.evaluate(cpu_iterator.initializer)
            for expected in range(5):
                self.assertEqual(expected, self.evaluate(element_t))
            # Second pass: after re-initializing, all ten are produced again.
            self.evaluate(cpu_iterator.initializer)
            for expected in range(10):
                self.assertEqual(expected, self.evaluate(element_t))
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(element_t)

    @combinations.generate(test_base.graph_only_combinations())
    def testCopyToDeviceWithReInitAndPrefetch(self):
        """Re-initializes a prefetching `/cpu:1` iterator midway."""
        source = dataset_ops.Dataset.range(10)
        copied = source.apply(prefetching_ops.copy_to_device("/cpu:1"))
        prefetched = copied.prefetch(1)

        with ops.device("/cpu:1"):
            cpu_iterator = dataset_ops.make_initializable_iterator(prefetched)
            element_t = cpu_iterator.get_next()

        # The prefetched copy must keep a structure compatible with the source.
        source_structure = dataset_ops.get_structure(source)
        prefetched_structure = dataset_ops.get_structure(prefetched)
        self.assertTrue(
            structure.are_compatible(source_structure, prefetched_structure))

        self.assertEqual(dtypes.int64, element_t.dtype)
        self.assertEqual([], element_t.shape)

        # The session is configured with two CPU devices.
        two_cpu_config = config_pb2.ConfigProto(device_count={"CPU": 2})
        with self.test_session(config=two_cpu_config):
            # First pass: consume only half of the elements.
            self.evaluate(cpu_iterator.initializer)
            for expected in range(5):
                self.assertEqual(expected, self.evaluate(element_t))
            # Second pass: after re-initializing, all ten are produced again.
            self.evaluate(cpu_iterator.initializer)
            for expected in range(10):
                self.assertEqual(expected, self.evaluate(element_t))
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(element_t)

    @combinations.generate(test_base.graph_only_combinations())
    def testCopyToDeviceGpuWithReInit(self):
        """Re-initializes a `/gpu:0` iterator midway and reads from the start."""
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        source = dataset_ops.Dataset.range(10)
        on_gpu = source.apply(prefetching_ops.copy_to_device("/gpu:0"))

        with ops.device("/gpu:0"):
            gpu_iterator = dataset_ops.make_initializable_iterator(on_gpu)
            element_t = gpu_iterator.get_next()

        strict_placement = config_pb2.ConfigProto(allow_soft_placement=False)
        with self.cached_session(config=strict_placement):
            # First pass: consume only half of the elements.
            self.evaluate(gpu_iterator.initializer)
            for expected in range(5):
                self.assertEqual(expected, self.evaluate(element_t))
            # Second pass: after re-initializing, all ten are produced again.
            self.evaluate(gpu_iterator.initializer)
            for expected in range(10):
                self.assertEqual(expected, self.evaluate(element_t))
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(element_t)

    @combinations.generate(test_base.graph_only_combinations())
    def testCopyToDeviceGpuWithReInitAndPrefetch(self):
        """Re-initializes a prefetching `/gpu:0` iterator midway."""
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        source = dataset_ops.Dataset.range(10)
        on_gpu = source.apply(prefetching_ops.copy_to_device("/gpu:0"))
        prefetched = on_gpu.prefetch(1)

        with ops.device("/gpu:0"):
            gpu_iterator = dataset_ops.make_initializable_iterator(prefetched)
            element_t = gpu_iterator.get_next()

        strict_placement = config_pb2.ConfigProto(allow_soft_placement=False)
        with self.cached_session(config=strict_placement):
            # First pass: consume only half of the elements.
            self.evaluate(gpu_iterator.initializer)
            for expected in range(5):
                self.assertEqual(expected, self.evaluate(element_t))
            # Second pass: after re-initializing, all ten are produced again.
            self.evaluate(gpu_iterator.initializer)
            for expected in range(10):
                self.assertEqual(expected, self.evaluate(element_t))
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(element_t)

    @combinations.generate(test_base.graph_only_combinations())
    def testIteratorGetNextAsOptionalOnGPU(self):
        """Exercises `get_next_as_optional` on an iterator placed on `/gpu:0`.

        Covers three phases: before initialization, while elements remain, and
        after the iterator is exhausted.
        """
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        source = dataset_ops.Dataset.range(3)
        on_gpu = source.apply(prefetching_ops.copy_to_device("/gpu:0"))
        with ops.device("/gpu:0"):
            gpu_iterator = dataset_ops.make_initializable_iterator(on_gpu)
            optional = iterator_ops.get_next_as_optional(gpu_iterator)
            has_value_t = optional.has_value()
            value_t = optional.get_value()

        strict_placement = config_pb2.ConfigProto(allow_soft_placement=False)
        with self.cached_session(config=strict_placement):
            # An uninitialized iterator makes both tensors fail with
            # FailedPreconditionError.
            with self.assertRaises(errors.FailedPreconditionError):
                self.evaluate(has_value_t)
            with self.assertRaises(errors.FailedPreconditionError):
                self.evaluate(value_t)

            # While elements remain, the optional holds the expected value.
            self.evaluate(gpu_iterator.initializer)
            for expected in range(3):
                has_value, value = self.evaluate([has_value_t, value_t])
                self.assertTrue(has_value)
                self.assertEqual(expected, value)

            # Once exhausted, `has_value()` is false and `get_value()` fails;
            # this is checked twice in a row.
            for _ in range(2):
                self.assertFalse(self.evaluate(has_value_t))
                with self.assertRaises(errors.InvalidArgumentError):
                    self.evaluate(value_t)
示例#21
0
class OptimizationTest(test_base.DatasetTestBase, parameterized.TestCase):
    """Tests that run dataset pipelines under explicit optimization options."""

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationStatefulFunction(self):
        """Runs a pipeline whose map function calls a random op."""
        dataset = dataset_ops.Dataset.range(10).map(
            lambda _: random_ops.random_uniform([])).batch(10)
        options = options_lib.Options()
        options.experimental_optimization.apply_default_optimizations = False
        dataset = dataset.with_options(options)
        get_next = self.getNext(dataset)
        self.evaluate(get_next())

    # TODO(b/123354468)
    @combinations.generate(test_base.graph_only_combinations())
    def testOptimizationLargeInputFromTensor(self):
        """Feeds a large placeholder value into a `from_tensors` dataset."""
        input_t = array_ops.placeholder(dtypes.int32, (None, None, None))
        dataset = dataset_ops.Dataset.from_tensors(input_t)
        options = options_lib.Options()
        options.experimental_optimization.apply_default_optimizations = False
        dataset = dataset.with_options(options)
        iterator = dataset_ops.make_initializable_iterator(dataset)
        init_op = iterator.initializer
        get_next = iterator.get_next()

        with self.cached_session() as sess:
            # The placeholder is fed only when the iterator is initialized.
            sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
            self.evaluate(get_next)

    # TODO(b/123354468)
    @combinations.generate(test_base.graph_only_combinations())
    def testOptimizationLargeInputFromTensorSlices(self):
        """Feeds a large placeholder value into a `from_tensor_slices` dataset."""
        input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None))
        dataset = dataset_ops.Dataset.from_tensor_slices(input_t)
        options = options_lib.Options()
        options.experimental_optimization.apply_default_optimizations = False
        dataset = dataset.with_options(options)
        iterator = dataset_ops.make_initializable_iterator(dataset)
        init_op = iterator.initializer
        get_next = iterator.get_next()

        with self.cached_session() as sess:
            # The placeholder is fed only when the iterator is initialized.
            sess.run(init_op,
                     {input_t: np.ones([1, 512, 1024, 1025], np.int32)})
            self.evaluate(get_next)

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationNestedDataset(self):
        """Enables noop elimination for a dataset built inside `flat_map`."""
        def flat_map_fn(_):
            dataset = dataset_ops.Dataset.from_tensors(0)
            dataset = dataset.apply(testing.assert_next(["MemoryCacheImpl"]))
            dataset = dataset.skip(0)  # Should be removed by noop elimination
            dataset = dataset.cache()
            return dataset

        dataset = dataset_ops.Dataset.range(1)
        dataset = dataset.flat_map(flat_map_fn)
        options = options_lib.Options()
        options.experimental_optimization.apply_default_optimizations = False
        options.experimental_optimization.noop_elimination = True
        dataset = dataset.with_options(options)
        self.assertDatasetProduces(dataset, expected_output=[0])

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationNestedDatasetWithModifiedRetval(self):
        """Enables map+batch fusion for a dataset built inside `flat_map`."""
        def flat_map_fn(_):
            dataset = dataset_ops.Dataset.from_tensors(0)
            dataset = dataset.apply(testing.assert_next(["MapAndBatch"]))
            # Should be fused by map and batch fusion
            dataset = dataset.map(lambda x: x)
            dataset = dataset.batch(1)
            return dataset

        dataset = dataset_ops.Dataset.range(1)
        dataset = dataset.flat_map(flat_map_fn)

        options = options_lib.Options()
        options.experimental_optimization.apply_default_optimizations = False
        options.experimental_optimization.map_and_batch_fusion = True
        dataset = dataset.with_options(options)
        self.assertDatasetProduces(dataset, expected_output=[[0]])

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(autotune=[True, False, None]),
            combinations.combine(map_parallelization=[True, False, None])))
    def testOptimizationMapParallelization(self, autotune,
                                           map_parallelization):
        """Checks which map op follows, per autotune/map_parallelization.

        Args:
          autotune: True/False to set `options.autotune.enabled`; None leaves
            the option unset.
          map_parallelization: True/False to set the `map_parallelization`
            optimization option; None leaves the option unset.
        """
        dataset = dataset_ops.Dataset.range(5)
        # "ParallelMap" is expected unless either option is explicitly False;
        # an unset (None) option counts the same as True here.
        if autotune is not False and map_parallelization is not False:  # pylint: disable=g-bool-id-comparison
            dataset = dataset.apply(testing.assert_next(["ParallelMap"]))
        else:
            dataset = dataset.apply(testing.assert_next(["Map"]))
        dataset = dataset.map(lambda x: x + 1)

        options = options_lib.Options()
        if autotune is not None:
            options.autotune.enabled = autotune
        if map_parallelization is not None:
            options.experimental_optimization.map_parallelization = (
                map_parallelization)
        dataset = dataset.with_options(options)

        self.assertDatasetProduces(dataset, expected_output=list(range(1, 6)))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(existing_prefetch=[True, False]),
            combinations.combine(autotune=[True, False]),
            combinations.combine(set_env=[True, False])))
    def testOptimizationInjectPrefetch(self, existing_prefetch, autotune,
                                       set_env):
        """Checks when a "Prefetch" op is expected ahead of "Root".

        Args:
          existing_prefetch: Whether the pipeline already ends in
            `prefetch(1)`.
          autotune: Value assigned to `options.autotune.enabled`.
          set_env: Whether to set the `TF_DATA_EXPERIMENT_OPT_IN` and
            `TF_JOB_NAME` environment variables for the duration of the test.
        """
        env_overrides = {}
        if set_env:
            env_overrides = {
                "TF_DATA_EXPERIMENT_OPT_IN": "inject_prefetch",
                "TF_JOB_NAME": "test_job",
            }
        os.environ.update(env_overrides)

        try:
            dataset = dataset_ops.Dataset.range(5)
            dataset = dataset.map(lambda x: x + 1,
                                  num_parallel_calls=dataset_ops.AUTOTUNE)
            if existing_prefetch:
                dataset = dataset.prefetch(1)
            if autotune and set_env and not existing_prefetch:
                dataset = dataset.apply(
                    testing.assert_next(["Prefetch", "Root"]))
            else:
                dataset = dataset.apply(testing.assert_next(["Root"]))

            options = options_lib.Options()
            options.autotune.enabled = autotune
            dataset = dataset.with_options(options)

            self.assertDatasetProduces(dataset,
                                       expected_output=list(range(1, 6)))
        finally:
            # Remove the variables even when an assertion above fails, so the
            # opt-in does not leak into tests that run afterwards.
            for name in env_overrides:
                os.environ.pop(name, None)

    # Reference variables are not supported in eager mode.
    @combinations.generate(
        combinations.times(test_base.graph_only_combinations(),
                           _captured_refvar_test_combinations()))
    def testOptimizationWithCapturedRefVar(self, dataset_fn):
        """Tests that default optimizations are disabled with ref variables.

        Args:
          dataset_fn: Callable that takes the reference variable and returns
            the dataset under test.
        """
        variable = variable_scope.get_variable("v",
                                               initializer=0,
                                               use_resource=False)
        assign_op = variable.assign_add(1)
        unoptimized_dataset = dataset_fn(variable)

        options = options_lib.Options()
        options.experimental_optimization.apply_default_optimizations = False
        options.experimental_optimization.noop_elimination = True
        options.experimental_optimization.map_and_batch_fusion = True
        optimized_dataset = unoptimized_dataset.with_options(options)
        optimized_it = dataset_ops.make_initializable_iterator(
            optimized_dataset)

        # Check that outputs are the same in the optimized and unoptimized cases,
        # when the variable value is changing.
        unoptimized_it = dataset_ops.make_initializable_iterator(
            unoptimized_dataset)
        with ops.control_dependencies([assign_op]):
            unoptimized_output = unoptimized_it.get_next()
            optimized_output = optimized_it.get_next()

        self.evaluate(variable.initializer)
        self.evaluate((unoptimized_it.initializer, optimized_it.initializer))
        # Compare element by element until the iterators are exhausted.
        while True:
            try:
                unoptimized, optimized = self.evaluate(
                    (unoptimized_output, optimized_output))
                self.assertEqual(unoptimized, optimized)
            except errors.OutOfRangeError:
                break
class ParseExampleDatasetTest(test_base.DatasetTestBase,
                              parameterized.TestCase):

  def _compare_output_to_expected(self, dict_tensors, expected_tensors):
    """Asserts that two dicts have the same keys and equal values per key.

    Args:
      dict_tensors: Dict of actual values, keyed by feature name.
      expected_tensors: Dict of expected values with the same keys.
    """
    actual_keys = set(dict_tensors.keys())
    expected_keys = set(expected_tensors.keys())
    self.assertEqual(actual_keys, expected_keys)

    # Walk the keys in sorted order so values are compared in a stable order.
    for key in sorted(dict_tensors.keys()):
      self.assertValuesEqual(expected_tensors[key], dict_tensors[key])

  def _test(self,
            input_tensor,
            feature_val,
            expected_values=None,
            expected_err=None,
            create_iterator_twice=False):
    """Parses `input_tensor` with `parse_example_dataset` and checks the result.

    Args:
      input_tensor: Serialized input handed to `Dataset.from_tensors`.
      feature_val: Dict of feature specs handed to `parse_example_dataset`.
      expected_values: Dict of expected parsed values; used only when
        `expected_err` is not given.
      expected_err: Optional pair whose two items are forwarded to
        `assertRaisesWithPredicateMatch`. When given, building or running the
        dataset must raise and nothing else is checked.
      create_iterator_twice: If true, a second `getNext` is created on the same
        dataset and the output is checked again.
    """

    def make_dataset():
      return dataset_ops.Dataset.from_tensors(input_tensor).apply(
          contrib_parsing_ops.parse_example_dataset(feature_val))

    if expected_err:
      # The dataset is built inside the context manager because the expected
      # error may be raised at construction time rather than at evaluation.
      with self.assertRaisesWithPredicateMatch(expected_err[0],
                                               expected_err[1]):
        next_fn = self.getNext(make_dataset())
        self.evaluate(next_fn())
      return

    # The parsed element is a dict of values; compare it with the expectation.
    dataset = make_dataset()
    next_fn = self.getNext(dataset)
    parsed = self.evaluate(next_fn())
    self._compare_output_to_expected(parsed, expected_values)
    # The dataset holds one element, so further reads must keep failing.
    for _ in range(2):
      with self.assertRaises(errors_impl.OutOfRangeError):
        self.evaluate(next_fn())

    if create_iterator_twice:
      next_fn = self.getNext(dataset)
      parsed = self.evaluate(next_fn())
      self._compare_output_to_expected(parsed, expected_values)
      with self.assertRaises(errors_impl.OutOfRangeError):
        self.evaluate(next_fn())

    # Check shapes; when the serialized input is a Tensor it has to be
    # evaluated to learn its size.
    if isinstance(input_tensor, ops.Tensor):
      batch_size = self.evaluate(input_tensor).size
    else:
      batch_size = np.asarray(input_tensor).size

    for key, spec in feature_val.items():
      if isinstance(spec, parsing_ops.FixedLenFeature) and (
          spec.shape is not None):
        static_dims = dataset_ops.get_legacy_output_shapes(
            dataset)[key].as_list()
        self.assertEqual(static_dims[0], batch_size)
      elif isinstance(spec, parsing_ops.VarLenFeature):
        static_dims = dataset_ops.get_legacy_output_shapes(
            dataset)[key].as_list()
        self.assertEqual(static_dims[1], None)

  @combinations.generate(test_base.default_test_combinations())
  def testEmptySerializedWithAllDefaults(self):
    """Empty serialized inputs parse to each dense feature's default value."""
    sparse_key = "st_a"
    key_a, key_b, key_c = "a", "b", "c:has_a_tricky_name"
    default_a = [0, 42, 0]
    default_b = np.random.rand(3, 3).astype(bytes)
    default_c = np.random.rand(2).astype(np.float32)

    # Nothing is present in the input, so the sparse output has no entries.
    empty_sparse = sparse_tensor.SparseTensorValue(
        np.empty((0, 2), dtype=np.int64),  # indices
        np.empty((0,), dtype=np.int64),  # values (int64 feature)
        np.array([2, 0], dtype=np.int64))  # dense shape: 2 rows, 0 columns

    # Each of the two rows of every dense output holds the feature's default.
    expected = {
        sparse_key: empty_sparse,
        key_a: np.array(2 * [[default_a]]),
        key_b: np.array(2 * [default_b]),
        key_c: np.array(2 * [default_c]),
    }

    serialized = ops.convert_to_tensor(["", ""])
    feature_spec = {
        sparse_key: parsing_ops.VarLenFeature(dtypes.int64),
        key_a: parsing_ops.FixedLenFeature(
            (1, 3), dtypes.int64, default_value=default_a),
        key_b: parsing_ops.FixedLenFeature(
            (3, 3), dtypes.string, default_value=default_b),
        key_c: parsing_ops.FixedLenFeature(
            (2,), dtypes.float32, default_value=default_c),
    }
    self._test(
        serialized,
        feature_spec,
        expected_values=expected,
        create_iterator_twice=True)

  @combinations.generate(test_base.graph_only_combinations())
  def testEmptySerializedWithoutDefaultsShouldFail(self):
    """A dense feature with neither a value nor a default must raise."""
    feature_spec = {
        "st_a": parsing_ops.VarLenFeature(dtypes.int64),
        "a": parsing_ops.FixedLenFeature(
            (1, 3), dtypes.int64, default_value=[0, 42, 0]),
        "b": parsing_ops.FixedLenFeature(
            (3, 3),
            dtypes.string,
            default_value=np.random.rand(3, 3).astype(bytes)),
        # "c" deliberately has no default; that gap is what causes the failure.
        "c": parsing_ops.FixedLenFeature((2,), dtype=dtypes.float32),
    }
    missing_c_error = (errors_impl.InvalidArgumentError,
                       "Feature: c \\(data type: float\\) is required")

    key_without_value = example(features=features({"c": feature()}))
    failing_inputs = (
        # Edge case: the key is present but its feature value is empty.
        [key_without_value.SerializeToString()],
        # Standard case: both key and value are missing.
        ["", ""],
    )
    for serialized in failing_inputs:
      self._test(serialized, feature_spec, expected_err=missing_c_error)

  @combinations.generate(test_base.graph_only_combinations())
  def testDenseNotMatchingShapeShouldFail(self):
    """A dense value count that does not fit the declared shape must raise."""
    three_values = example(features=features({
        "a": float_feature([1, 1, 3]),
    }))
    # Only two values, while the spec below declares shape (1, 3).
    two_values = example(features=features({
        "a": float_feature([-1, -1]),
    }))
    serialized = [
        proto.SerializeToString() for proto in (three_values, two_values)
    ]

    feature_spec = {"a": parsing_ops.FixedLenFeature((1, 3), dtypes.float32)}
    shape_error = (errors_impl.InvalidArgumentError,
                   "Key: a, Index: 1.  Number of float values")
    self._test(
        ops.convert_to_tensor(serialized),
        feature_spec,
        expected_err=shape_error)

  @combinations.generate(test_base.default_test_combinations())
  def testDenseDefaultNoShapeShouldFail(self):
    """A `FixedLenFeature` declared with shape `None` must raise ValueError."""
    proto = example(features=features({"a": float_feature([1, 1, 3])}))
    serialized = [proto.SerializeToString()]

    shapeless_spec = {"a": parsing_ops.FixedLenFeature(None, dtypes.float32)}
    self._test(
        ops.convert_to_tensor(serialized),
        shapeless_spec,
        expected_err=(ValueError, "Missing shape for feature a"))

  @combinations.generate(test_base.default_test_combinations())
  def testSerializedContainingSparse(self):
    """Parses `VarLenFeature`s from examples with missing and empty values."""
    protos = [
        example(features=features({
            "st_c": float_feature([3, 4])
        })),
        # "st_c" is present but holds an empty float list.
        example(features=features({
            "st_c": float_feature([]),
        })),
        # "st_d" is present but the feature has nothing in it.
        example(features=features({
            "st_d": feature(),
        })),
        example(features=features({
            "st_c": float_feature([1, 2, -1]),
            "st_d": bytes_feature([b"hi"])
        }))
    ]
    serialized = [proto.SerializeToString() for proto in protos]

    # Each SparseTensorValue below is (indices, values, dense shape).
    sparse_c = sparse_tensor.SparseTensorValue(
        np.array([[0, 0], [0, 1], [3, 0], [3, 1], [3, 2]], dtype=np.int64),
        np.array([3.0, 4.0, 1.0, 2.0, -1.0], dtype=np.float32),
        np.array([4, 3], dtype=np.int64))  # 4 rows, up to 3 values per row
    sparse_d = sparse_tensor.SparseTensorValue(
        np.array([[3, 0]], dtype=np.int64), np.array(["hi"], dtype=bytes),
        np.array([4, 1], dtype=np.int64))  # 4 rows, up to 1 value per row

    feature_spec = {
        "st_c": parsing_ops.VarLenFeature(dtypes.float32),
        "st_d": parsing_ops.VarLenFeature(dtypes.string)
    }
    self._test(
        ops.convert_to_tensor(serialized),
        feature_spec,
        expected_values={"st_c": sparse_c, "st_d": sparse_d},
        create_iterator_twice=True)

  @combinations.generate(test_base.default_test_combinations())
  def testSerializedContainingSparseFeature(self):
    """Parses a `SparseFeature` built from an index key and a value key."""
    protos = [
        example(features=features({
            "val": float_feature([3, 4]),
            "idx": int64_feature([5, 10])
        })),
        # Both lists are present but empty.
        example(features=features({
            "val": float_feature([]),
            "idx": int64_feature([])
        })),
        # "val" has nothing in it and "idx" is missing altogether.
        example(features=features({
            "val": feature(),
        })),
        # The indices of this example are not sorted.
        example(features=features({
            "val": float_feature([1, 2, -1]),
            "idx": int64_feature([0, 9, 3])
        }))
    ]
    serialized = [proto.SerializeToString() for proto in protos]

    # (indices, values, dense shape); the last row is listed in index order.
    sparse_result = sparse_tensor.SparseTensorValue(
        np.array([[0, 5], [0, 10], [3, 0], [3, 3], [3, 9]], dtype=np.int64),
        np.array([3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32),
        np.array([4, 13], dtype=np.int64))  # 4 rows, declared size 13

    feature_spec = {
        "sp": parsing_ops.SparseFeature(["idx"], "val", dtypes.float32, [13])
    }
    self._test(
        ops.convert_to_tensor(serialized),
        feature_spec,
        expected_values={"sp": sparse_result},
        create_iterator_twice=True)

  @combinations.generate(test_base.default_test_combinations())
  def testSerializedContainingSparseFeatureReuse(self):
    """Two `SparseFeature`s may share the same index key ("idx")."""
    protos = [
        example(features=features({
            "val1": float_feature([3, 4]),
            "val2": float_feature([5, 6]),
            "idx": int64_feature([5, 10])
        })),
        # Empty float list with an empty index list; "val2" is absent.
        example(features=features({
            "val1": float_feature([]),
            "idx": int64_feature([])
        })),
    ]
    serialized = [proto.SerializeToString() for proto in protos]

    # Both outputs use the same indices; only values and declared size differ.
    shared_indices = [[0, 5], [0, 10]]
    sparse_from_val1 = sparse_tensor.SparseTensorValue(
        np.array(shared_indices, dtype=np.int64),
        np.array([3.0, 4.0], dtype=np.float32),
        np.array([2, 13], dtype=np.int64))  # 2 rows, declared size 13
    sparse_from_val2 = sparse_tensor.SparseTensorValue(
        np.array(shared_indices, dtype=np.int64),
        np.array([5.0, 6.0], dtype=np.float32),
        np.array([2, 7], dtype=np.int64))  # 2 rows, declared size 7

    feature_spec = {
        "sp1":
            parsing_ops.SparseFeature("idx", "val1", dtypes.float32, 13),
        "sp2":
            parsing_ops.SparseFeature(
                "idx", "val2", dtypes.float32, size=7, already_sorted=True)
    }
    self._test(
        ops.convert_to_tensor(serialized),
        feature_spec,
        expected_values={"sp1": sparse_from_val1, "sp2": sparse_from_val2},
        create_iterator_twice=True)

  @combinations.generate(test_base.default_test_combinations())
  def testSerializedContaining3DSparseFeature(self):
    """Parses a `SparseFeature` that uses two index keys."""
    protos = [
        example(features=features({
            "val": float_feature([3, 4]),
            "idx0": int64_feature([5, 10]),
            "idx1": int64_feature([0, 2]),
        })),
        # All three lists are present but empty.
        example(features=features({
            "val": float_feature([]),
            "idx0": int64_feature([]),
            "idx1": int64_feature([]),
        })),
        # "val" has nothing in it and both index features are missing.
        example(features=features({
            "val": feature(),
        })),
        # "idx0" of this example is not sorted.
        example(features=features({
            "val": float_feature([1, 2, -1]),
            "idx0": int64_feature([0, 9, 3]),
            "idx1": int64_feature([1, 0, 2]),
        }))
    ]
    serialized = [proto.SerializeToString() for proto in protos]

    indices = np.array(
        [[0, 5, 0], [0, 10, 2], [3, 0, 1], [3, 3, 2], [3, 9, 0]],
        dtype=np.int64)
    values = np.array([3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32)
    # 4 rows, followed by the declared sizes 13 and 3.
    dense_shape = np.array([4, 13, 3], dtype=np.int64)
    sparse_result = sparse_tensor.SparseTensorValue(indices, values,
                                                    dense_shape)

    feature_spec = {
        "sp":
            parsing_ops.SparseFeature(["idx0", "idx1"], "val",
                                      dtypes.float32, [13, 3])
    }
    self._test(
        ops.convert_to_tensor(serialized),
        feature_spec,
        expected_values={"sp": sparse_result},
        create_iterator_twice=True)

  @combinations.generate(test_base.default_test_combinations())
  def testSerializedContainingDense(self):
    """Parses dense features that have no defaults from two full examples."""
    key_a = "a"
    key_b = "b*has+a:tricky_name"
    protos = [
        example(features=features({
            key_a: float_feature([1, 1]),
            key_b: bytes_feature([b"b0_str"]),
        })),
        example(features=features({
            key_a: float_feature([-1, -1]),
            key_b: bytes_feature([b""]),
        })),
    ]
    serialized = [proto.SerializeToString() for proto in protos]

    # Outputs are reshaped to (2,) + the shape declared in the spec below.
    expected_a = np.array([[1, 1], [-1, -1]], dtype=np.float32)
    expected_b = np.array(["b0_str", ""], dtype=bytes)
    expected = {
        key_a: expected_a.reshape(2, 1, 2, 1),
        key_b: expected_b.reshape(2, 1, 1, 1, 1),
    }

    # No defaults are given, so every example has to supply both features.
    feature_spec = {
        key_a:
            parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32),
        key_b:
            parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string),
    }
    self._test(
        ops.convert_to_tensor(serialized),
        feature_spec,
        expected_values=expected,
        create_iterator_twice=True)

  # This test is identical to the previous one except
  # for the creation of 'serialized'.
  @combinations.generate(test_base.default_test_combinations())
  def testSerializedContainingDenseWithConcat(self):
    """Parses dense features from pairs of concatenated serialized protos."""
    key_a = "a"
    key_b = "b*has+a:tricky_name"
    # TODO(lew): Feature appearing twice should be an error in future.
    proto_pairs = [
        (
            example(features=features({
                key_a: float_feature([10, 10]),
            })),
            example(features=features({
                key_a: float_feature([1, 1]),
                key_b: bytes_feature([b"b0_str"]),
            })),
        ),
        (
            example(features=features({
                key_b: bytes_feature([b"b100"]),
            })),
            example(features=features({
                key_a: float_feature([-1, -1]),
                key_b: bytes_feature([b"b1"]),
            })),
        ),
    ]

    # Each input string is the serialization of both protos of a pair, joined.
    serialized = [
        head.SerializeToString() + tail.SerializeToString()
        for head, tail in proto_pairs
    ]

    # The expected values are the ones from the second proto of each pair.
    expected_a = np.array([[1, 1], [-1, -1]], dtype=np.float32)
    expected_b = np.array(["b0_str", "b1"], dtype=bytes)
    expected = {
        key_a: expected_a.reshape(2, 1, 2, 1),
        key_b: expected_b.reshape(2, 1, 1, 1, 1),
    }

    # No defaults are given, so both features are required.
    feature_spec = {
        key_a:
            parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32),
        key_b:
            parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string),
    }
    self._test(
        ops.convert_to_tensor(serialized),
        feature_spec,
        expected_values=expected,
        create_iterator_twice=True)

  @combinations.generate(test_base.default_test_combinations())
  def testSerializedContainingDenseScalar(self):
    """A missing single-value dense feature falls back to its default."""
    with_value = example(features=features({
        "a": float_feature([1]),
    }))
    without_value = example(features=features({}))
    serialized = [
        proto.SerializeToString() for proto in (with_value, without_value)
    ]

    # 2x1 column vector: the parsed value, then the default (-1).
    expected = {"a": np.array([[1], [-1]], dtype=np.float32)}

    feature_spec = {
        "a":
            parsing_ops.FixedLenFeature(
                (1,), dtype=dtypes.float32, default_value=-1),
    }
    self._test(
        ops.convert_to_tensor(serialized),
        feature_spec,
        expected_values=expected,
        create_iterator_twice=True)

  @combinations.generate(test_base.default_test_combinations())
  def testSerializedContainingDenseWithDefaults(self):
    """Dense features that are missing or empty take their default values."""
    protos = [
        # Only "a" is present.
        example(features=features({
            "a": float_feature([1, 1]),
        })),
        # Only "b" is present.
        example(features=features({
            "b": bytes_feature([b"b1"]),
        })),
        # "b" is present but has nothing in it.
        example(features=features({
            "b": feature()
        })),
    ]
    serialized = [proto.SerializeToString() for proto in protos]

    expected_a = np.array([[1, 1], [3, -3], [3, -3]], dtype=np.float32)
    expected_b = np.array(["tmp_str", "b1", "tmp_str"], dtype=bytes)
    expected = {
        "a": expected_a.reshape(3, 1, 2, 1),
        "b": expected_b.reshape(3, 1, 1, 1, 1),
    }

    feature_spec = {
        "a":
            parsing_ops.FixedLenFeature(
                (1, 2, 1), dtype=dtypes.float32, default_value=[3.0, -3.0]),
        "b":
            parsing_ops.FixedLenFeature(
                (1, 1, 1, 1), dtype=dtypes.string, default_value="tmp_str"),
    }
    self._test(
        ops.convert_to_tensor(serialized),
        feature_spec,
        expected_values=expected,
        create_iterator_twice=True)

  @combinations.generate(test_base.default_test_combinations())
  def testSerializedSparseAndSparseFeatureAndDenseWithNoDefault(self):
    """Mixes VarLen, Sparse and dense features, one dense without a default."""
    # "st_a" never appears in the input, so its sparse output is empty.
    empty_st_a = sparse_tensor.SparseTensorValue(
        np.empty((0, 2), dtype=np.int64),  # indices
        np.empty((0,), dtype=np.int64),  # values (int64 feature)
        np.array([2, 0], dtype=np.int64))  # dense shape: 2 rows, 0 columns
    parsed_sp = sparse_tensor.SparseTensorValue(
        np.array([[0, 0], [0, 3], [1, 7]], dtype=np.int64),
        np.array(["a", "b", "c"], dtype="|S"),
        np.array([2, 13], dtype=np.int64))  # 2 rows, declared size 13

    protos = [
        example(features=features({
            "c": float_feature([3, 4]),
            "val": bytes_feature([b"a", b"b"]),
            "idx": int64_feature([0, 3])
        })),
        example(features=features({
            "c": float_feature([1, 2]),
            "val": bytes_feature([b"c"]),
            "idx": int64_feature([7])
        })),
    ]
    serialized = [proto.SerializeToString() for proto in protos]

    default_a = [1, 2, 3]
    default_b = np.random.rand(3, 3).astype(bytes)
    # "a" and "b" are absent from the input, so both rows hold their defaults.
    expected = {
        "st_a": empty_st_a,
        "sp": parsed_sp,
        "a": np.array(2 * [[default_a]]),
        "b": np.array(2 * [default_b]),
        "c": np.array([[3, 4], [1, 2]], dtype=np.float32),
    }

    feature_spec = {
        "st_a":
            parsing_ops.VarLenFeature(dtypes.int64),
        "sp":
            parsing_ops.SparseFeature("idx", "val", dtypes.string, 13),
        "a":
            parsing_ops.FixedLenFeature(
                (1, 3), dtypes.int64, default_value=default_a),
        "b":
            parsing_ops.FixedLenFeature(
                (3, 3), dtypes.string, default_value=default_b),
        # "c" has no default_value, so every example must provide it.
        "c":
            parsing_ops.FixedLenFeature((2,), dtypes.float32),
    }
    self._test(
        ops.convert_to_tensor(serialized),
        feature_spec,
        expected_values=expected,
        create_iterator_twice=True)

  @combinations.generate(test_base.default_test_combinations())
  def testSerializedContainingSparseAndSparseFeatureWithReuse(self):
    """The "idx" key feeds both a `VarLenFeature` and a `SparseFeature`."""
    # As a VarLenFeature, "idx" keeps the order in which values were written.
    idx_as_varlen = sparse_tensor.SparseTensorValue(
        np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64),
        np.array([0, 3, 7, 1]),
        np.array([2, 2], dtype=np.int64))  # 2 rows, 2 values per row

    # As a SparseFeature, entries are listed in index order ("d" before "c").
    sparse_result = sparse_tensor.SparseTensorValue(
        np.array([[0, 0], [0, 3], [1, 1], [1, 7]], dtype=np.int64),
        np.array(["a", "b", "d", "c"], dtype="|S"),
        np.array([2, 13], dtype=np.int64))  # 2 rows, declared size 13

    protos = [
        example(features=features({
            "val": bytes_feature([b"a", b"b"]),
            "idx": int64_feature([0, 3])
        })),
        example(features=features({
            "val": bytes_feature([b"c", b"d"]),
            "idx": int64_feature([7, 1])
        })),
    ]
    serialized = [proto.SerializeToString() for proto in protos]

    feature_spec = {
        "idx":
            parsing_ops.VarLenFeature(dtypes.int64),
        "sp":
            parsing_ops.SparseFeature(["idx"], "val", dtypes.string, [13]),
    }
    self._test(
        ops.convert_to_tensor(serialized),
        feature_spec,
        expected_values={"idx": idx_as_varlen, "sp": sparse_result},
        create_iterator_twice=True)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(batch_size=[1, 10, 20, 100, 256]))
  )
  def testSerializedContainingVarLenDenseLargerBatch(self, batch_size):
    """Parses variable-length dense features across a range of batch sizes.

    Args:
      batch_size: Number of serialized examples to build and parse.
    """
    np.random.seed(3456)
    # During parsing, data read from the serialized proto is stored in buffers.
    # For small batch sizes a buffer contains one minibatch entry; for larger
    # batch sizes it may contain several.  This test identified a bug where the
    # code copying data out of the buffers into the output tensors assumed one
    # minibatch entry per buffer.  The bug has since been fixed.
    int_values = list(range(batch_size))
    written_strs = [[("foo%d" % i).encode(), ("bar%d" % i).encode()]
                    for i in range(batch_size)]
    padded_strs = copy.deepcopy(written_strs)

    # Randomly drop trailing entries from what is written; the expected output
    # keeps two columns, with dropped positions replaced by the default.
    for written_row, padded_row in zip(written_strs, padded_strs):
      last = 1
      if np.random.rand() < 0.25:
        # With probability 25%, drop the second entry.
        padded_row[last] = b"default"
        last -= 1
        written_row.pop()
      if np.random.rand() < 0.25:
        # With probability 25%, drop the last remaining entry as well.
        padded_row[last] = b"default"
        written_row.pop()

    expected = {
        # batch_size rows, 1 time step.
        "a": np.array(int_values, dtype=np.int64).reshape(batch_size, 1),
        # batch_size rows, 2 time steps.
        "b": np.array(padded_strs, dtype="|S").reshape(batch_size, 2),
    }

    protos = [
        example(features=features({
            "a": int64_feature([int_value]),
            "b": bytes_feature(written_row),
        })) for int_value, written_row in zip(int_values, written_strs)
    ]
    serialized = [proto.SerializeToString() for proto in protos]

    feature_spec = {
        "a":
            parsing_ops.FixedLenSequenceFeature(
                shape=(),
                dtype=dtypes.int64,
                allow_missing=True,
                default_value=-1),
        "b":
            parsing_ops.FixedLenSequenceFeature(
                shape=[],
                dtype=dtypes.string,
                allow_missing=True,
                default_value="default"),
    }
    self._test(
        ops.convert_to_tensor(serialized, dtype=dtypes.string),
        feature_spec,
        expected_values=expected,
        create_iterator_twice=True)

  @combinations.generate(test_base.default_test_combinations())
  def testSerializedShapeMismatch(self):
    """An empty default for a (2, 1) sequence feature must be rejected.

    The same feature spec is used in both execution modes; only the expected
    exception type and message differ between eager and graph execution.
    """
    key_a, key_b, key_c = "a", "b", "c"
    protos = [
        example(features=features({
            key_c: int64_feature([2]),
        })),
        example(features=features({
            key_a: float_feature([1, 1]),
            key_b: bytes_feature([b"b0_str", b"b1_str"]),
        })),
        example(features=features({
            key_a: float_feature([-1, -1, 2, 2]),
            key_b: bytes_feature([b"b1"]),
        })),
        example(features=features({
            key_a: float_feature([]),
            key_c: int64_feature([3]),
        })),
    ]
    serialized = [proto.SerializeToString() for proto in protos]

    if context.executing_eagerly():
      reshape_error = (errors_impl.InvalidArgumentError,
                       "Input to reshape is a tensor with 0 values")
    else:
      reshape_error = (ValueError,
                       "Cannot reshape a tensor with 0 elements to shape")

    feature_spec = {
        key_a:
            parsing_ops.FixedLenSequenceFeature((2, 1),
                                                dtype=dtypes.float32,
                                                allow_missing=True,
                                                default_value=[]),
        key_b:
            parsing_ops.FixedLenSequenceFeature(
                (2, 1, 1), dtype=dtypes.string, allow_missing=True),
    }
    self._test(
        ops.convert_to_tensor(serialized),
        feature_spec,
        expected_err=reshape_error)

  @combinations.generate(test_base.graph_only_combinations())
  def testSerializedContainingVarLenDense(self):
    """Parses variable-length dense features; checks padding and errors.

    Covers the default padding, a custom `default_value` for feature "a", and
    the errors expected for a stride mismatch, for `FixedLenFeature` shapes
    with unknown dimensions, and for `allow_missing=False`.
    """
    key_a, key_b, key_c, key_d = "a", "b", "c", "d"

    def seq_feature(shape, dtype, allow_missing=True, **extra):
      # Shorthand for the sequence-feature specs used throughout this test.
      return parsing_ops.FixedLenSequenceFeature(
          shape, dtype=dtype, allow_missing=allow_missing, **extra)

    def four_key_spec(a_feature, c_allow_missing=True):
      # Builds a fresh spec dict per call so nothing is shared between runs.
      return {
          key_a: a_feature,
          key_b: seq_feature((1, 1, 1), dtypes.string),
          key_c: seq_feature([], dtypes.int64, allow_missing=c_allow_missing),
          key_d: seq_feature([], dtypes.string),
      }

    protos = [
        example(features=features({key_c: int64_feature([2])})),
        example(features=features({
            key_a: float_feature([1, 1]),
            key_b: bytes_feature([b"b0_str", b"b1_str"]),
        })),
        example(features=features({
            key_a: float_feature([-1, -1, 2, 2]),
            key_b: bytes_feature([b"b1"]),
        })),
        example(features=features({
            key_a: float_feature([]),
            key_c: int64_feature([3]),
        })),
    ]
    serialized = [proto.SerializeToString() for proto in protos]

    zero_padded_a = np.array(
        [
            [0, 0, 0, 0],
            [1, 1, 0, 0],
            [-1, -1, 2, 2],
            [0, 0, 0, 0],
        ],
        dtype=np.float32).reshape(4, 2, 2, 1)
    padded_b = np.array(
        [["", ""], ["b0_str", "b1_str"], ["b1", ""], ["", ""]],
        dtype=bytes).reshape(4, 2, 1, 1, 1)
    expected_output = {
        key_a: zero_padded_a,
        key_b: padded_b,
        key_c: np.array([2, 0, 0, 3], dtype=np.int64).reshape(4, 1),
        key_d: np.empty(shape=(4, 0), dtype=bytes),
    }

    self._test(
        ops.convert_to_tensor(serialized),
        four_key_spec(seq_feature((2, 1), dtypes.float32)),
        expected_values=expected_output,
        create_iterator_twice=True)

    # Test with padding values.
    expected_output_custom_padding = dict(expected_output)
    expected_output_custom_padding[key_a] = np.array(
        [
            [-2, -2, -2, -2],
            [1, 1, -2, -2],
            [-1, -1, 2, 2],
            [-2, -2, -2, -2],
        ],
        dtype=np.float32).reshape(4, 2, 2, 1)

    self._test(
        ops.convert_to_tensor(serialized),
        four_key_spec(seq_feature((2, 1), dtypes.float32, default_value=-2.0)),
        expected_output_custom_padding)

    # Change number of required values so the inputs are not a
    # multiple of this size.
    stride_error = (
        errors_impl.OpError, "Key: b, Index: 2.  "
        "Number of bytes values is not a multiple of stride length.")
    self._test(
        ops.convert_to_tensor(serialized), {
            key_a: seq_feature((2, 1), dtypes.float32),
            key_b: seq_feature((2, 1, 1), dtypes.string),
        },
        expected_err=stride_error)

    # A `FixedLenFeature` whose first dimension is unknown is rejected.
    unknown_first_dim_error = (
        ValueError, "First dimension of shape for feature a unknown. "
        "Consider using FixedLenSequenceFeature.")
    self._test(
        ops.convert_to_tensor(serialized), {
            key_a:
                parsing_ops.FixedLenFeature((None, 2, 1), dtype=dtypes.float32),
            key_b:
                seq_feature((2, 1, 1), dtypes.string),
        },
        expected_err=unknown_first_dim_error)

    # A `FixedLenFeature` with an unknown non-leading dimension is rejected.
    unknown_dim_error = (
        ValueError, "All dimensions of shape for feature c need to be known "
        r"but received \(1, None\).")
    self._test(
        ops.convert_to_tensor(serialized), {
            key_c:
                parsing_ops.FixedLenFeature(
                    (1, None), dtype=dtypes.int64, default_value=[[1]]),
        },
        expected_err=unknown_dim_error)

    # `allow_missing=False` on a sequence feature is rejected.
    allow_missing_error = (
        ValueError, "Unsupported: FixedLenSequenceFeature requires "
        "allow_missing to be True.")
    self._test(
        ops.convert_to_tensor(serialized),
        four_key_spec(
            seq_feature((2, 1), dtypes.float32), c_allow_missing=False),
        expected_err=allow_missing_error)

  @combinations.generate(test_base.default_test_combinations())
  def testSerializedContainingRaggedFeatureWithNoPartitions(self):
    """Parses `RaggedFeature`s that declare no partitions.

    The inputs include an empty float list, a feature with nothing in it, and
    examples where one of the two keys is absent.
    """
    protos = [
        example(
            features=features({"rt_c": float_feature([3, 4, 5, 6, 7, 8])})),
        # Empty float list.
        example(features=features({"rt_c": float_feature([])})),
        # Feature with nothing in it.
        example(features=features({"rt_d": feature()})),
        example(features=features({
            "rt_c": float_feature([1, 2, -1]),
            "rt_d": bytes_feature([b"hi"]),
        })),
    ]
    serialized = [proto.SerializeToString() for proto in protos]

    expected_output = {
        "rt_c":
            ragged_factory_ops.constant_value(
                [[3.0, 4.0, 5.0, 6.0, 7.0, 8.0], [], [], [1.0, 2.0, -1.0]],
                row_splits_dtype=dtypes.int32),
        "rt_d":
            ragged_factory_ops.constant_value(
                [[], [], [], [b"hi"]], row_splits_dtype=dtypes.int64),
    }

    feature_spec = {
        "rt_c":
            parsing_ops.RaggedFeature(dtypes.float32),
        "rt_d":
            parsing_ops.RaggedFeature(
                dtypes.string, row_splits_dtype=dtypes.int64),
    }

    self._test(
        ops.convert_to_tensor(serialized),
        feature_spec,
        expected_values=expected_output,
        create_iterator_twice=True)

  @combinations.generate(test_base.default_test_combinations())
  def testSerializedContainingRaggedFeatureWithOnePartition(self):
    """Parses `RaggedFeature`s built from one partition over "rt_values".

    Each example stores the same row partitioning in five encodings (splits,
    lengths, starts, limits, value row ids); "rt1".."rt5" are expected to
    produce the same ragged tensor. Two extra features use a uniform row
    length of 2, alone and combined with the row splits.
    """
    protos = [
        example(
            features=features({
                # rt = [[3], [4, 5, 6]]
                "rt_values": float_feature([3, 4, 5, 6]),
                "rt_splits": int64_feature([0, 1, 4]),
                "rt_lengths": int64_feature([1, 3]),
                "rt_starts": int64_feature([0, 1]),
                "rt_limits": int64_feature([1, 4]),
                "rt_rowids": int64_feature([0, 1, 1, 1]),
            })),
        example(
            features=features({
                # rt = []
                "rt_values": float_feature([]),
                "rt_splits": int64_feature([0]),
                "rt_lengths": int64_feature([]),
                "rt_starts": int64_feature([]),
                "rt_limits": int64_feature([]),
                "rt_rowids": int64_feature([]),
            })),
        example(
            features=features({
                # rt = []
                "rt_values": feature(),  # feature with nothing in it
                "rt_splits": int64_feature([0]),
                "rt_lengths": feature(),
                "rt_starts": feature(),
                "rt_limits": feature(),
                "rt_rowids": feature(),
            })),
        example(
            features=features({
                # rt = [[1.0, 2.0, -1.0], [], [8.0, 9.0], [5.0]]
                "rt_values": float_feature([1, 2, -1, 8, 9, 5]),
                "rt_splits": int64_feature([0, 3, 3, 5, 6]),
                "rt_lengths": int64_feature([3, 0, 2, 1]),
                "rt_starts": int64_feature([0, 3, 3, 5]),
                "rt_limits": int64_feature([3, 3, 5, 6]),
                "rt_rowids": int64_feature([0, 0, 0, 2, 2, 3]),
            })),
    ]
    serialized = [proto.SerializeToString() for proto in protos]

    def ragged_over_values(partitions):
      # Every feature under test reads its flat values from "rt_values".
      return parsing_ops.RaggedFeature(
          value_key="rt_values", partitions=partitions, dtype=dtypes.float32)

    ragged = parsing_ops.RaggedFeature
    test_features = {
        "rt1": ragged_over_values([ragged.RowSplits("rt_splits")]),
        "rt2": ragged_over_values([ragged.RowLengths("rt_lengths")]),
        "rt3": ragged_over_values([ragged.RowStarts("rt_starts")]),
        "rt4": ragged_over_values([ragged.RowLimits("rt_limits")]),
        "rt5": ragged_over_values([ragged.ValueRowIds("rt_rowids")]),
        "uniform1": ragged_over_values([ragged.UniformRowLength(2)]),
        "uniform2": ragged_over_values([
            ragged.UniformRowLength(2),
            ragged.RowSplits("rt_splits"),
        ]),
    }

    expected_rt = ragged_factory_ops.constant(
        [[[3], [4, 5, 6]], [], [], [[1, 2, -1], [], [8, 9], [5]]],
        dtype=dtypes.float32,
        row_splits_dtype=dtypes.int32)

    expected_uniform1 = ragged_factory_ops.constant(
        [[[3, 4], [5, 6]], [], [], [[1, 2], [-1, 8], [9, 5]]],
        ragged_rank=1,
        dtype=dtypes.float32,
        row_splits_dtype=dtypes.int32)

    expected_uniform2 = ragged_factory_ops.constant(
        [[[[3], [4, 5, 6]]], [], [], [[[1, 2, -1], []], [[8, 9], [5]]]],
        dtype=dtypes.float32,
        row_splits_dtype=dtypes.int32)

    # All five single-partition encodings share one expected value.
    expected_output = {
        key: expected_rt for key in ("rt1", "rt2", "rt3", "rt4", "rt5")
    }
    expected_output["uniform1"] = expected_uniform1
    expected_output["uniform2"] = expected_uniform2

    self._test(
        ops.convert_to_tensor(serialized),
        test_features,
        expected_values=expected_output,
        create_iterator_twice=True)

  @combinations.generate(test_base.default_test_combinations())
  def testSerializedContainingRaggedFeatureWithMultiplePartitions(self):
    """Parses a `RaggedFeature` that stacks three partitions.

    The partitions are a uniform row length of 2, row lengths read from
    "lengths_axis2", and row splits read from "splits_axis3".
    """
    # rt shape: [(batch), 2, None, None]
    first_example = example(
        features=features({
            # rt = [[[[1]], [[2, 3], [4]]], [[], [[5, 6, 7]]]]
            "rt_values": float_feature([1, 2, 3, 4, 5, 6, 7]),
            "lengths_axis2": int64_feature([1, 2, 0, 1]),
            "lengths_axis3": int64_feature([1, 2, 1, 3]),
            "splits_axis3": int64_feature([0, 1, 3, 4, 7]),
        }))
    second_example = example(
        features=features({
            # rt = [[[[1, 2, 3], [4]], [[5], [6], [7, 8]]]]
            "rt_values": float_feature([1, 2, 3, 4, 5, 6, 7, 8]),
            "lengths_axis2": int64_feature([2, 3]),
            "lengths_axis3": int64_feature([3, 1, 1, 1, 2]),
            "splits_axis3": int64_feature([0, 3, 4, 5, 6, 8]),
        }))
    serialized = [
        proto.SerializeToString() for proto in (first_example, second_example)
    ]

    stacked_partitions = [
        parsing_ops.RaggedFeature.UniformRowLength(2),
        parsing_ops.RaggedFeature.RowLengths("lengths_axis2"),
        parsing_ops.RaggedFeature.RowSplits("splits_axis3"),
    ]
    test_features = {
        "rt1":
            parsing_ops.RaggedFeature(
                value_key="rt_values",
                partitions=stacked_partitions,
                dtype=dtypes.float32,
                row_splits_dtype=dtypes.int64,
            ),
    }

    expected_rt = ragged_factory_ops.constant(
        [[[[[1]], [[2, 3], [4]]], [[], [[5, 6, 7]]]],
         [[[[1, 2, 3], [4]], [[5], [6], [7, 8]]]]],
        dtype=dtypes.float32,
        row_splits_dtype=dtypes.int64)

    self._test(
        ops.convert_to_tensor(serialized),
        test_features,
        expected_values={"rt1": expected_rt},
        create_iterator_twice=True)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              local_determinism=[None, True, False],
              global_determinism=[True, False])))
  def testDeterminism(self, local_determinism, global_determinism):
    """Checks output ordering under local and global determinism settings.

    Args:
      local_determinism: Value passed as `deterministic` to
        `parse_example_dataset`; one of `None`, `True`, `False`.
      global_determinism: Value assigned to the dataset option
        `experimental_deterministic`.
    """
    num_elements = 1000
    # One single-example batch per element, each carrying its own index.
    batches = [
        [example(features=features({
            "a": int64_feature([index]),
        })).SerializeToString()]
        for index in range(num_elements)
    ]

    test_features = {"a": parsing_ops.FixedLenFeature((), dtype=dtypes.int64)}
    parse_fn = contrib_parsing_ops.parse_example_dataset(
        test_features, num_parallel_calls=10, deterministic=local_determinism)
    dataset = dataset_ops.Dataset.from_tensor_slices(batches).apply(parse_fn)

    options = dataset_ops.Options()
    options.experimental_deterministic = global_determinism
    dataset = dataset.with_options(options)

    expected = list(range(num_elements))
    actual = [parsed["a"][0] for parsed in self.getDatasetOutput(dataset)]

    # The local setting decides when it is set; otherwise the global one does.
    if local_determinism is None:
      require_order = global_determinism
    else:
      require_order = local_determinism

    if require_order:
      self.assertAllEqual(expected, actual)
    else:
      self.assertCountEqual(expected, actual)
示例#23
0
class TraverseTest(test_base.DatasetTestBase, parameterized.TestCase):
    """Tests for the op-collection helpers in `traverse`.

    Each test builds a small pipeline and compares the names of the ops
    returned by `traverse` against the expected op names.
    """

    @combinations.generate(test_base.graph_only_combinations())
    def testOnlySource(self):
        """A lone range dataset yields exactly one op, "RangeDataset"."""
        found = traverse.obtain_all_variant_tensor_ops(
            dataset_ops.Dataset.range(10))
        self.assertAllEqual(["RangeDataset"], [op.name for op in found])

    @combinations.generate(test_base.graph_only_combinations())
    def testSimplePipeline(self):
        """A range followed by a map yields both dataset ops."""
        pipeline = dataset_ops.Dataset.range(10).map(math_ops.square)
        found = traverse.obtain_all_variant_tensor_ops(pipeline)
        self.assertSetEqual({"MapDataset", "RangeDataset"},
                            {op.name for op in found})

    @combinations.generate(test_base.graph_only_combinations())
    def testConcat(self):
        """Concatenating two ranges yields both inputs and the concat op."""
        left = dataset_ops.Dataset.range(10)
        right = dataset_ops.Dataset.range(10)
        found = traverse.obtain_all_variant_tensor_ops(left.concatenate(right))
        self.assertSetEqual(
            {"ConcatenateDataset", "RangeDataset", "RangeDataset_1"},
            {op.name for op in found})

    @combinations.generate(test_base.graph_only_combinations())
    def testZip(self):
        """Zipping two ranges yields both inputs and the zip op."""
        left = dataset_ops.Dataset.range(10)
        right = dataset_ops.Dataset.range(10)
        zipped = dataset_ops.Dataset.zip((left, right))
        found = traverse.obtain_all_variant_tensor_ops(zipped)
        self.assertSetEqual(
            {"ZipDataset", "RangeDataset", "RangeDataset_1"},
            {op.name for op in found})

    @combinations.generate(test_base.graph_only_combinations())
    def testMultipleVariantTensors(self):
        """Wrapping a range in `_TestDataset` yields three dataset ops."""
        wrapped = _TestDataset(dataset_ops.Dataset.range(10))
        found = traverse.obtain_all_variant_tensor_ops(wrapped)
        self.assertSetEqual(
            {"RangeDataset", "ModelDataset", "PrefetchDataset"},
            {op.name for op in found})

    @combinations.generate(test_base.graph_only_combinations())
    def testFlatMap(self):
        """Ops of a dataset captured by a `flat_map` function are included."""
        captured = dataset_ops.Dataset.range(10).repeat(10)

        outer = dataset_ops.Dataset.range(20).prefetch(1)
        # The mapped function closes over `captured` and batches it.
        outer = outer.flat_map(lambda x: captured.batch(x))
        found = traverse.obtain_all_variant_tensor_ops(outer)
        expected_names = {
            "FlatMapDataset", "PrefetchDataset", "RepeatDataset",
            "RangeDataset", "RangeDataset_1"
        }
        self.assertSetEqual(expected_names, {op.name for op in found})

    @combinations.generate(test_base.graph_only_combinations())
    def testTfDataService(self):
        """Capture-by-value ops of a distributed dataset contain known names."""
        distributed = dataset_ops.Dataset.range(10).apply(
            data_service_ops.distribute("parallel_epochs", "grpc://foo:0"))
        # Named `captured_ops` so the local does not reuse the name `ops`.
        captured_ops = traverse.obtain_capture_by_value_ops(distributed)
        self.assertContainsSubset(
            ["RangeDataset", "DataServiceDatasetV2", "DummyIterationCounter"],
            {op.name for op in captured_ops})
示例#24
0
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
    @combinations.generate(test_base.default_test_combinations())
    def testAsSerializedGraph(self):
        """The serialized graph of a range dataset has a "RangeDataset" node."""
        serialized = self.evaluate(
            dataset_ops.Dataset.range(10)._as_serialized_graph())
        graph_def = graph_pb2.GraphDef().FromString(serialized)
        has_range_node = any(
            node.op == "RangeDataset" for node in graph_def.node)
        self.assertTrue(has_range_node)

    @combinations.generate(test_base.default_test_combinations())
    def testAsSerializedGraphStateful(self):
        """Serializing a stateful dataset under the FAIL policy raises.

        The mapped function calls `random_ops.random_uniform`, so evaluating
        the serialized graph with `ExternalStatePolicy.FAIL` is expected to
        raise `FailedPreconditionError`.
        """
        # Decorated like every other test in this class so that it is
        # parameterized over the default test combinations as well.
        dataset = dataset_ops.Dataset.range(10).map(
            lambda _: random_ops.random_uniform(()))
        with self.assertRaises(errors.FailedPreconditionError):
            self.evaluate(
                dataset._as_serialized_graph(
                    external_state_policy=distribute_options.
                    ExternalStatePolicy.FAIL))

    @combinations.generate(test_base.default_test_combinations())
    def testAsFunctionWithMap(self):
        """A dataset revived from its traced variant yields the same values."""
        if not context.executing_eagerly():
            self.skipTest("Only works executing eagerly")
        with ops.device("CPU"):
            source = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
            make_variant = source._trace_variant_creation()
            # Rebuild a dataset from the variant the traced function returns.
            revived = dataset_ops._VariantDataset(make_variant(),
                                                  source.element_spec)
            self.assertDatasetProduces(revived, range(0, 10, 2))

    @combinations.generate(test_base.default_test_combinations())
    def testAsFunctionWithMapInFlatMap(self):
        """Variant tracing also works for a `map` nested inside `flat_map`."""
        if not context.executing_eagerly():
            self.skipTest("Only works executing eagerly")
        with ops.device("CPU"):
            source = dataset_ops.Dataset.range(5).flat_map(
                lambda x: dataset_ops.Dataset.range(5).map(lambda x: x * 2))
            make_variant = source._trace_variant_creation()
            revived = dataset_ops._VariantDataset(make_variant(),
                                                  source.element_spec)
            # The revived dataset must reproduce the source's own elements.
            self.assertDatasetProduces(revived, list(source))

    def _testNumInputs(self, dataset, num_inputs):
        """Asserts that `dataset._inputs()` has length `num_inputs`.

        Args:
          dataset: Object exposing an `_inputs()` method.
          num_inputs: Expected number of inputs.
        """
        actual_inputs = dataset._inputs()
        self.assertLen(actual_inputs, num_inputs)

    @combinations.generate(test_base.default_test_combinations())
    def testFixedLengthRecordInputs(self):
        """A `FixedLengthRecordDataset` reports zero inputs."""
        self._testNumInputs(readers.FixedLengthRecordDataset("", 42), 0)

    @combinations.generate(test_base.default_test_combinations())
    def testFromGeneratorInputs(self):
        """A `from_generator` dataset reports exactly one input."""

        def single_value():
            yield 42

        generated = dataset_ops.Dataset.from_generator(single_value,
                                                       dtypes.int32)
        self._testNumInputs(generated, 1)

    @combinations.generate(test_base.default_test_combinations())
    def testFromTensorsInputs(self):
        """A `from_tensors` dataset reports zero inputs."""
        self._testNumInputs(dataset_ops.Dataset.from_tensors([42]), 0)

    @combinations.generate(test_base.default_test_combinations())
    def testRangeInputs(self):
        """A `range` dataset reports zero inputs."""
        self._testNumInputs(dataset_ops.Dataset.range(10), 0)

    @combinations.generate(test_base.default_test_combinations())
    def testTextLineInputs(self):
        """A `TextLineDataset` reports zero inputs."""
        self._testNumInputs(readers.TextLineDataset(""), 0)

    @combinations.generate(test_base.default_test_combinations())
    def testTFRecordInputs(self):
        """A `TFRecordDataset` reports exactly one input."""
        self._testNumInputs(readers.TFRecordDataset(""), 1)

    @combinations.generate(
        combinations.combine(tf_api_version=1, mode=["eager", "graph"]))
    def testDatasetComplexSourceInputs(self):
        """A `from_sparse_tensor_slices` dataset reports no inputs."""
        sparse_input = sparse_tensor.SparseTensor(
            indices=np.array([[0, 0], [1, 0], [2, 0]]),
            values=np.array([0, 0, 0]),
            dense_shape=np.array([3, 1]))
        sliced = dataset_ops.Dataset.from_sparse_tensor_slices(sparse_input)
        self.assertEmpty(sliced._inputs())

    def _testUnaryInputs(self, dataset_fn):
        """Asserts that `dataset_fn(ds)._inputs()` equals `[ds]`.

        Args:
          dataset_fn: Callable applied to an empty range dataset; its result
            must expose `_inputs()`.
        """
        source = dataset_ops.Dataset.range(0)
        transformed = dataset_fn(source)
        self.assertEqual([source], transformed._inputs())

    @combinations.generate(test_base.default_test_combinations())
    def testBatchInputs(self):
        """`batch` has its source dataset as its only input."""
        self._testUnaryInputs(lambda ds: ds.batch(10))

    @combinations.generate(test_base.default_test_combinations())
    def testCacheInputs(self):
        """`cache` has its source dataset as its only input."""
        self._testUnaryInputs(lambda ds: ds.cache())

    @combinations.generate(test_base.default_test_combinations())
    def testFilterInputs(self):
        """`filter` has its source dataset as its only input."""
        self._testUnaryInputs(lambda ds: ds.filter(lambda _: True))

    @combinations.generate(test_base.default_test_combinations())
    def testFlatMapInputs(self):
        """`flat_map` has its source dataset as its only input."""

        def empty_range(_):
            return dataset_ops.Dataset.range(0)

        self._testUnaryInputs(lambda ds: ds.flat_map(empty_range))

    @combinations.generate(test_base.default_test_combinations())
    def testMapInputs(self):
        """`map` has its source dataset as its only input."""
        self._testUnaryInputs(lambda ds: ds.map(lambda elem: elem))

    @combinations.generate(test_base.default_test_combinations())
    def testPaddedBatchInputs(self):
        """`padded_batch` has its source dataset as its only input."""
        self._testUnaryInputs(lambda ds: ds.padded_batch(10, []))

    @combinations.generate(test_base.default_test_combinations())
    def testParallelMapInputs(self):
        """`map` with `num_parallel_calls` has its source as its only input."""
        self._testUnaryInputs(
            lambda ds: ds.map(lambda elem: elem, num_parallel_calls=2))

    @combinations.generate(test_base.default_test_combinations())
    def testRepeatInputs(self):
        """`repeat` has its source dataset as its only input."""
        self._testUnaryInputs(lambda ds: ds.repeat())

    @combinations.generate(test_base.default_test_combinations())
    def testShuffleInputs(self):
        """`shuffle` has its source dataset as its only input."""
        self._testUnaryInputs(lambda ds: ds.shuffle(10))

    @combinations.generate(test_base.default_test_combinations())
    def testSkipInputs(self):
        """`skip` has its source dataset as its only input."""
        self._testUnaryInputs(lambda ds: ds.skip(1))

    @combinations.generate(test_base.default_test_combinations())
    def testTakeInputs(self):
        """`take` has its source dataset as its only input."""
        self._testUnaryInputs(lambda ds: ds.take(1))

    @combinations.generate(test_base.default_test_combinations())
    def testWindowInputs(self):
        """`window` has its source dataset as its only input."""
        self._testUnaryInputs(lambda ds: ds.window(10))

    @combinations.generate(test_base.default_test_combinations())
    def testUnaryTransformationInputsApply(self):
        """A transformation applied via `apply` keeps the source as input."""
        source = dataset_ops.Dataset.range(0)
        cached = source.apply(lambda ds: ds.cache())
        self.assertEqual([source], cached._inputs())

    def _testInputsWithInterleaveFn(self, dataset_fn, interleave_parallelism):
        """Asserts that an interleaved dataset's only input is its source.

        Args:
          dataset_fn: Not used by this helper; the interleave function is
            hard-coded below. NOTE(review): callers still pass a value, so
            the parameter is kept.
          interleave_parallelism: Forwarded as `num_parallel_calls` to
            `interleave`.
        """
        source = dataset_ops.Dataset.range(0)
        interleaved = source.interleave(
            lambda _: dataset_ops.Dataset.range(0),
            cycle_length=2,
            num_parallel_calls=interleave_parallelism)
        self.assertEqual([source], interleaved._inputs())

    @combinations.generate(test_base.default_test_combinations())
    def testParallelInterleaveInputs(self):
        """Interleave with a parallelism of 2 keeps the source as its input."""
        self._testInputsWithInterleaveFn(
            lambda: dataset_ops.range(0), interleave_parallelism=2)

    @combinations.generate(test_base.default_test_combinations())
    def testInterleaveInputs(self):
        """Interleave with parallelism `None` keeps the source as its input."""
        self._testInputsWithInterleaveFn(
            lambda: dataset_ops.range(0), interleave_parallelism=None)

    @combinations.generate(test_base.default_test_combinations())
    def testNoWarnings(self):
        """Building an interleave dataset does not call `warnings.warn`."""
        with test.mock.patch.object(warnings, "warn") as mocked_warn:
            source = dataset_ops.Dataset.range(0)
            source.interleave(
                lambda _: dataset_ops.Dataset.range(0), cycle_length=2)
            self.assertEmpty(mocked_warn.call_args_list)

    def _testBinaryInputs(self, dataset_fn):
        """Asserts that `dataset_fn(a, b)._inputs()` equals `[a, b]`.

        Args:
          dataset_fn: Callable taking two datasets; its result must expose
            `_inputs()`.
        """
        first = dataset_ops.Dataset.range(0)
        second = dataset_ops.Dataset.range(1)
        combined = dataset_fn(first, second)
        self.assertEqual([first, second], combined._inputs())

    @combinations.generate(test_base.default_test_combinations())
    def testConcatenateInputs(self):
        """`concatenate` reports both of its operands as inputs, in order."""
        self._testBinaryInputs(lambda first, second: first.concatenate(second))

    def _testVariadicInputs(self, dataset_fn, input_datasets):
        """Asserts the inputs of `dataset_fn(input_datasets)`.

        Args:
          dataset_fn: Callable applied to `input_datasets`; its result must
            expose `_inputs()`.
          input_datasets: Structure of datasets; its `nest.flatten` result is
            the expected list of inputs.
        """
        expected_inputs = nest.flatten(input_datasets)
        self.assertEqual(expected_inputs,
                         dataset_fn(input_datasets)._inputs())

    @combinations.generate(test_base.default_test_combinations())
    def testZipOneInputs(self):
        """`zip` over a single dataset reports that dataset as its input."""
        self._testVariadicInputs(dataset_ops.Dataset.zip,
                                 dataset_ops.Dataset.range(0))

    @combinations.generate(test_base.default_test_combinations())
    def testZipNestInputs(self):
        """`zip` over a nested tuple reports the flattened datasets."""
        inner_pair = (dataset_ops.Dataset.range(1),
                      dataset_ops.Dataset.range(2))
        nested = (dataset_ops.Dataset.range(0), inner_pair)
        self._testVariadicInputs(dataset_ops.Dataset.zip, nested)

    @combinations.generate(test_base.default_test_combinations())
    def testZipTupleInputs(self):
        """`zip` over a flat tuple reports both datasets as inputs."""
        pair = (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))
        self._testVariadicInputs(dataset_ops.Dataset.zip, pair)

    @combinations.generate(test_base.default_test_combinations())
    def testFunctions(self):
        """A mapped dataset reports exactly one function."""
        doubled = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
        functions = doubled._functions()
        self.assertLen(functions, 1)

    @combinations.generate(test_base.default_test_combinations())
    def testCollectInputs(self):
        """Walks `_inputs()` breadth-first and counts each dataset's visits."""
        ds1 = dataset_ops.Dataset.range(0)
        ds2 = ds1.concatenate(ds1)
        ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

        # `visited` doubles as the FIFO queue: entries at or after `cursor`
        # are still waiting to be expanded.
        visited = [ds3]
        cursor = 0
        while cursor < len(visited):
            visited.extend(visited[cursor]._inputs())
            cursor += 1

        for expected_count, ds in ((5, ds1), (2, ds2), (1, ds3)):
            self.assertEqual(expected_count, visited.count(ds))

    def _testDatasetSpec(self, tf_value, expected_element_structure):
        """Checks the `DatasetSpec` of a dataset that produces `tf_value`.

        Args:
          tf_value: Value returned by the map function of a one-element
            dataset.
          expected_element_structure: Structure the dataset's element
            structure must be compatible with.
        """
        dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value)
        spec = structure.type_spec_from_value(dataset)
        self.assertIsInstance(spec, dataset_ops.DatasetSpec)

        element_structure = dataset_ops.get_structure(dataset)
        self.assertTrue(
            structure.are_compatible(element_structure,
                                     expected_element_structure))
        self.assertEqual([dtypes.variant],
                         structure.get_flat_tensor_types(spec))
        self.assertEqual([tensor_shape.TensorShape([])],
                         structure.get_flat_tensor_shapes(spec))

        # Assert that the `Dataset` survives a round-trip via
        # _from_tensor_list() and _to_tensor_list().
        round_trip_dataset = spec._from_tensor_list(
            spec._to_tensor_list(dataset))

        if isinstance(tf_value, dataset_ops.Dataset):
            # NOTE(review): this branch compares against `dataset`, not
            # `round_trip_dataset` -- confirm that is intended.
            self.assertDatasetsEqual(tf_value, dataset.flat_map(lambda x: x))
            return

        if isinstance(tf_value, optional_ops.Optional):
            produced = round_trip_dataset.map(lambda opt: opt.get_value())
            expected = [self.evaluate(tf_value.get_value())]
        else:
            produced = round_trip_dataset
            expected = [self.evaluate(tf_value)]
        self.assertDatasetProduces(
            produced, expected, requires_initialization=True)

    @combinations.generate(test_base.default_test_combinations())
    def testTensorDatasetSpec(self):
        """Checks the dataset spec for a scalar float32 tensor element."""
        scalar = constant_op.constant(37.0)
        scalar_spec = tensor_spec.TensorSpec([], dtypes.float32)
        self._testDatasetSpec(scalar, scalar_spec)

    @combinations.generate(test_base.default_test_combinations())
    def testSparseTensorDatasetSpec(self):
        """Checks the dataset spec for a sparse int32 tensor element."""
        sparse_values = constant_op.constant([0], dtype=dtypes.int32)
        sparse_value = sparse_tensor.SparseTensor(
            indices=[[0]], values=sparse_values, dense_shape=[1])
        self._testDatasetSpec(
            sparse_value, sparse_tensor.SparseTensorSpec([1], dtypes.int32))

    @combinations.generate(test_base.default_test_combinations())
    def testNestDatasetSpec(self):
        """Checks the dataset spec for a dict element holding a tuple."""
        nested_value = {
            "a": constant_op.constant(37.0),
            "b": (constant_op.constant(["Foo"]),
                  constant_op.constant("Bar")),
        }
        nested_spec = {
            "a": tensor_spec.TensorSpec([], dtypes.float32),
            "b": (
                tensor_spec.TensorSpec([1], dtypes.string),
                tensor_spec.TensorSpec([], dtypes.string),
            ),
        }
        self._testDatasetSpec(nested_value, nested_spec)

    @combinations.generate(test_base.default_test_combinations())
    def testDatasetDatasetSpec(self):
        """Checks the dataset spec when each element is itself a dataset."""
        inner_dataset = dataset_ops.Dataset.from_tensor_slices(
            constant_op.constant([1, 2, 3]))
        inner_spec = dataset_ops.DatasetSpec(
            tensor_spec.TensorSpec([], dtypes.int32))
        self._testDatasetSpec(inner_dataset, inner_spec)

    @combinations.generate(test_base.default_test_combinations())
    def testOptionalDatasetSpec(self):
        """Checks the dataset spec when each element is an `Optional`."""
        optional_value = optional_ops.Optional.from_value(37.0)
        optional_spec = optional_ops.OptionalSpec(
            tensor_spec.TensorSpec([], dtypes.float32))
        self._testDatasetSpec(optional_value, optional_spec)

    @combinations.generate(test_base.graph_only_combinations())
    def testSameGraphError(self):
        """Batching a dataset while another graph is default raises."""
        outer_dataset = dataset_ops.Dataset.range(10)
        expected_message = "must be from the same graph"
        # The dataset is created before the fresh graph becomes the default,
        # so the transformation below is expected to be rejected.
        with ops.Graph().as_default(), self.assertRaisesRegex(
                ValueError, expected_message):
            outer_dataset.batch(2)

    @combinations.generate(
        combinations.combine(tf_api_version=[1], mode=["graph"]))
    def testSameGraphErrorOneShot(self):
        """Making a one-shot iterator while another graph is default raises."""
        outer_dataset = dataset_ops.Dataset.range(10)
        expected_message = (
            "Please ensure that all datasets in the pipeline are "
            "created in the same graph as the iterator.")
        with ops.Graph().as_default(), self.assertRaisesRegex(
                ValueError, expected_message):
            dataset_ops.make_one_shot_iterator(outer_dataset)

    @combinations.generate(
        combinations.combine(tf_api_version=[1], mode=["graph"]))
    def testSameGraphErrorInitializable(self):
        """Making an initializable iterator in another graph raises."""
        outer_dataset = dataset_ops.Dataset.range(10)
        expected_message = (
            "Please ensure that all datasets in the pipeline are "
            "created in the same graph as the iterator.")
        with ops.Graph().as_default(), self.assertRaisesRegex(
                ValueError, expected_message):
            dataset_ops.make_initializable_iterator(outer_dataset)

    @combinations.generate(
        combinations.times(
            test_base.eager_only_combinations(),
            combinations.combine(execution_mode=[context.ASYNC, context.SYNC]))
    )
    def testEagerIteration(self, execution_mode):
        """Iterates over `range(10)` with a Python loop and checks the order.

        Args:
          execution_mode: The mode passed to `context.execution_mode` for the
            duration of the iteration.
        """
        with context.execution_mode(execution_mode):
            source = dataset_ops.Dataset.range(10)
            # `enumerate` supplies the running index each element must equal.
            for expected, element in enumerate(source):
                self.assertEqual(expected, element.numpy())

    @combinations.generate(test_base.default_test_combinations())
    def testDatasetAsFunctionArgument(self):
        """Passes datasets as arguments to a traced function.

        Checks the summed results, that the dataset is not captured by the
        concrete function's graph, and which dataset arguments share a trace.
        """
        @def_function.function
        def _sum_elements(ds):
            total = array_ops.zeros([], dtype=dtypes.int64)
            for element in ds:
                total += element
            return total

        with ops.device("CPU"):
            range_ten = dataset_ops.Dataset.range(10)
            self.assertEqual(45, self.evaluate(_sum_elements(range_ten)))
            range_eleven = dataset_ops.Dataset.range(11)
            self.assertEqual(55, self.evaluate(_sum_elements(range_eleven)))

            trace_for_ten = _sum_elements.get_concrete_function(range_ten)
            # The dataset argument must not show up as a captured input.
            self.assertEmpty(trace_for_ten.graph.captures)

            # Same element structure: the existing trace is expected back.
            trace_for_eleven = _sum_elements.get_concrete_function(
                range_eleven)
            self.assertIs(trace_for_ten, trace_for_eleven)

            # A zipped dataset has a different structure: a new trace.
            zipped = dataset_ops.Dataset.zip((range_ten, range_eleven))
            trace_for_zipped = _sum_elements.get_concrete_function(zipped)
            self.assertIsNot(trace_for_ten, trace_for_zipped)

    @combinations.generate(test_base.default_test_combinations())
    def testLimitedRetracing(self):
        """Calls a traced function with two datasets and counts the traces."""
        # A one-element list lets the traced body bump the counter in place.
        num_traces = [0]

        @def_function.function
        def sum_dataset(ds):
            num_traces[0] += 1
            running_total = np.int64(0)
            for element in ds:
                running_total += element
            return running_total

        short_range = dataset_ops.Dataset.range(5)
        long_range = dataset_ops.Dataset.range(10)
        cases = ((short_range, 10), (long_range, 45))

        for _ in range(10):
            for ds, expected_sum in cases:
                self.assertEqual(self.evaluate(sum_dataset(ds)), expected_sum)
            # The Python body must only have run for the very first call.
            self.assertEqual(num_traces[0], 1)

    # pylint: disable=g-long-lambda,unnecessary-lambda
    @combinations.generate(test_base.default_test_combinations())
    def testLegacyStructureAPI(self):
        """Checks legacy output types/shapes along a chain of transformations.

        Starts from a three-component element, applies several
        transformations one after another, and after each step compares
        `get_legacy_output_types` / `get_legacy_output_shapes` with the
        expected nested structure.
        """
        def assert_legacy_types(ds, expected_types):
            self.assertEqual(expected_types,
                             dataset_ops.get_legacy_output_types(ds))

        def assert_legacy_shapes(ds, expected_shapes):
            self.assertEqual(expected_shapes,
                             dataset_ops.get_legacy_output_shapes(ds))

        leading_ints = np.array([1, 2, 3], dtype=np.int64)
        float_pair = (np.array([4., 5.]), np.array([6., 7.]))
        trailing_ints = np.array([8, 9, 10], dtype=np.int64)
        components = (leading_ints, float_pair, trailing_ints)

        triple_types = (dtypes.int64, (dtypes.float64, dtypes.float64),
                        dtypes.int64)
        triple_shapes = ([3], ([2], [2]), [3])

        dataset = dataset_ops.Dataset.from_tensors(components)
        assert_legacy_types(dataset, triple_types)
        assert_legacy_shapes(dataset, triple_shapes)

        # Each of these steps is expected to leave the structure untouched.
        structure_preserving_steps = (
            lambda ds: ds.shuffle(10, 10),
            lambda ds: ds.repeat(-1),
            lambda ds: ds.filter(lambda x, y, z: True),
            lambda ds: ds.take(5),
        )
        for step in structure_preserving_steps:
            dataset = step(dataset)
            assert_legacy_types(dataset, triple_types)
            assert_legacy_shapes(dataset, triple_shapes)

        paired_types = ((dtypes.int64, dtypes.int64), (dtypes.float64,
                                                       dtypes.float64))
        paired_shapes = (([3], [3]), ([2], [2]))

        dataset = dataset.map(lambda x, y, z: ((x, z), (y[0], y[1])))
        assert_legacy_types(dataset, paired_types)
        assert_legacy_shapes(dataset, paired_shapes)

        dataset = dataset.flat_map(lambda x, y: dataset_ops.Dataset.
                                   from_tensors(((x[0], x[1]), (y[0], y[1]))))
        assert_legacy_types(dataset, paired_types)
        assert_legacy_shapes(dataset, paired_shapes)

        dataset = dataset.batch(32)
        assert_legacy_types(dataset, paired_types)
        # Convert every shape to a list so it can be compared with `None`
        # in the batch dimension.
        batched_shapes = dataset_ops.get_legacy_output_shapes(dataset)
        shapes_as_lists = nest.pack_sequence_as(
            batched_shapes,
            [shape.as_list() for shape in nest.flatten(batched_shapes)])
        self.assertEqual((([None, 3], [None, 3]), ([None, 2], [None, 2])),
                         shapes_as_lists)

        # `from_tensor_slices` slices along the first dimension, so this
        # input uses components that all share the same leading size.
        components_for_slices = (
            np.array([1, 2, 3], dtype=np.int64),
            (np.array([4., 5., 6.]), np.array([7., 8., 9.])),
            np.array([10, 11, 12], dtype=np.int64),
        )

        dataset = dataset_ops.Dataset.from_tensor_slices(components_for_slices)
        assert_legacy_types(dataset, triple_types)
        assert_legacy_shapes(dataset, ([], ([], []), []))

    @combinations.generate(test_base.default_test_combinations())
    def testNoneComponent(self):
        """Checks that a `None` component passes through `from_tensors`.

        When executing eagerly the dataset output is compared directly.
        Otherwise a one-shot iterator is built; its second component must be
        `None` and its first component is run in a session and compared to 42.
        """
        dataset = dataset_ops.Dataset.from_tensors((42, None))
        if context.executing_eagerly():
            self.assertDatasetProduces(dataset, expected_output=[(42, None)])
        else:
            iterator = dataset_ops.make_one_shot_iterator(dataset)
            next_first, next_second = iterator.get_next()
            # Use an identity check: `assertEqual(x, None)` would go through
            # the operand's `__eq__` if a non-None object were returned.
            self.assertIsNone(next_second)
            with self.cached_session() as sess:
                self.assertEqual(sess.run(next_first), 42)

    @combinations.generate(test_base.default_test_combinations())
    def testNoneComponentInFunction(self):
        """Sums the first component of `(x, None)` elements inside a function."""
        @def_function.function
        def sum_first_components(ds):
            running_sum = 0
            # Keep the explicit `iter()` call: the loop runs over the
            # iterator object rather than over the dataset itself.
            ds_iterator = iter(ds)
            for pair in ds_iterator:
                first, _ = pair
                running_sum += first
            return running_sum

        int_range = dataset_ops.Dataset.range(10, output_type=dtypes.int32)
        with_none = int_range.map(lambda x: (x, None))
        self.assertEqual(self.evaluate(sum_first_components(with_none)), 45)

    @combinations.generate(test_base.default_test_combinations())
    def testIncorrectPythonStructure(self):
        """An incorrect Python structure must raise rather than crash.

        A single-component range dataset is relabelled with a two-component
        structure; consuming it is expected to raise an op error (as opposed
        to a segfault).
        """
        single_component = dataset_ops.Dataset.range(10)
        scalar_spec = tensor_spec.TensorSpec([], dtypes.int64)
        mislabelled = dataset_ops._RestructuredDataset(
            single_component, (scalar_spec, scalar_spec))
        second_only = mislabelled.map(lambda x, y: y)

        with self.assertRaisesOpError(""):
            self.getDatasetOutput(second_only)
class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
    """Tests for dataset optimization options and the rewrites they select."""

    def _assertGraphRewrites(self, options, enabled, disabled, default):
        """Asserts the rewrite sets reported by `options._graph_rewrites()`.

        The comparison is order-insensitive (both sides are turned into sets).

        Args:
          options: The `dataset_ops.Options` object under test.
          enabled: Rewrite names expected in the `enabled` field.
          disabled: Rewrite names expected in the `disabled` field.
          default: Rewrite names expected in the `default` field.
        """
        graph_rewrites = options._graph_rewrites()
        self.assertEqual(set(graph_rewrites.enabled), set(enabled))
        self.assertEqual(set(graph_rewrites.disabled), set(disabled))
        self.assertEqual(set(graph_rewrites.default), set(default))

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationStatefulFunction(self):
        """Runs a pipeline whose map function calls `random_uniform`."""
        dataset = dataset_ops.Dataset.range(10).map(
            lambda _: random_ops.random_uniform([])).batch(10)
        options = dataset_ops.Options()
        options.experimental_optimization.apply_default_optimizations = False
        dataset = dataset.with_options(options)
        get_next = self.getNext(dataset)
        self.evaluate(get_next())

    # TODO(b/123902160)
    @combinations.generate(test_base.graph_only_combinations())
    def testOptimizationLargeInputFromTensor(self):
        """Feeds a large placeholder value into a `from_tensors` dataset."""
        input_t = array_ops.placeholder(dtypes.int32, (None, None, None))
        dataset = dataset_ops.Dataset.from_tensors(input_t)
        options = dataset_ops.Options()
        options.experimental_optimization.apply_default_optimizations = False
        dataset = dataset.with_options(options)
        iterator = dataset_ops.make_initializable_iterator(dataset)
        init_op = iterator.initializer
        get_next = iterator.get_next()

        with self.cached_session() as sess:
            sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
            self.evaluate(get_next)

    # TODO(b/123902160)
    @combinations.generate(test_base.graph_only_combinations())
    def testOptimizationLargeInputFromTensorSlices(self):
        """Feeds a large placeholder value into a `from_tensor_slices` dataset."""
        input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None))
        dataset = dataset_ops.Dataset.from_tensor_slices(input_t)
        options = dataset_ops.Options()
        options.experimental_optimization.apply_default_optimizations = False
        dataset = dataset.with_options(options)
        iterator = dataset_ops.make_initializable_iterator(dataset)
        init_op = iterator.initializer
        get_next = iterator.get_next()

        with self.cached_session() as sess:
            sess.run(init_op,
                     {input_t: np.ones([1, 512, 1024, 1025], np.int32)})
            self.evaluate(get_next)

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationNestedDataset(self):
        """Checks noop elimination inside a `flat_map` function."""
        def flat_map_fn(_):
            dataset = dataset_ops.Dataset.from_tensors(0)
            dataset = dataset.apply(testing.assert_next(["MemoryCacheImpl"]))
            dataset = dataset.skip(0)  # Should be removed by noop elimination
            dataset = dataset.cache()
            return dataset

        dataset = dataset_ops.Dataset.range(1)
        dataset = dataset.flat_map(flat_map_fn)
        options = dataset_ops.Options()
        options.experimental_optimization.apply_default_optimizations = False
        options.experimental_optimization.noop_elimination = True
        dataset = dataset.with_options(options)
        self.assertDatasetProduces(dataset, expected_output=[0])

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationNestedDatasetWithModifiedRetval(self):
        """Checks map-and-batch fusion inside a `flat_map` function."""
        def flat_map_fn(_):
            dataset = dataset_ops.Dataset.from_tensors(0)
            dataset = dataset.apply(testing.assert_next(["MapAndBatch"]))
            # Should be fused by map and batch fusion
            dataset = dataset.map(lambda x: x)
            dataset = dataset.batch(1)
            return dataset

        dataset = dataset_ops.Dataset.range(1)
        dataset = dataset.flat_map(flat_map_fn)

        options = dataset_ops.Options()
        options.experimental_optimization.apply_default_optimizations = False
        options.experimental_optimization.map_and_batch_fusion = True
        dataset = dataset.with_options(options)
        self.assertDatasetProduces(dataset, expected_output=[[0]])

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationDisableIntraOpParallelism(self):
        """Opts in to `disable_intra_op_parallelism` via environment variables.

        The two environment variables are removed in a `finally` clause so
        that a failing assertion cannot leak them into later tests.
        """
        os.environ[
            "TF_DATA_EXPERIMENT_OPT_IN"] = "disable_intra_op_parallelism"
        os.environ["TF_JOB_NAME"] = "test_job"
        try:
            dataset = dataset_ops.Dataset.range(10)
            dataset = dataset.apply(
                testing.assert_next(["MaxIntraOpParallelism"]))

            options = dataset_ops.Options()
            dataset = dataset.with_options(options)
            self.assertDatasetProduces(dataset,
                                       expected_output=list(range(10)))
        finally:
            # Always undo the process-wide environment changes made above.
            os.environ.pop("TF_DATA_EXPERIMENT_OPT_IN", None)
            os.environ.pop("TF_JOB_NAME", None)

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationThreadPoolDataset(self):
        """Runs a batched range dataset on a private thread pool."""
        dataset = dataset_ops.Dataset.range(10).batch(10)

        dataset = threadpool.override_threadpool(
            dataset,
            threadpool.PrivateThreadPool(
                2, display_name="private_thread_pool_%d" % 2))

        options = dataset_ops.Options()
        options.experimental_optimization.apply_default_optimizations = False
        dataset = dataset.with_options(options)
        self.assertDatasetProduces(dataset,
                                   expected_output=[list(range(10))],
                                   requires_initialization=True)

    # Reference variables are not supported in eager mode.
    @combinations.generate(
        combinations.times(test_base.graph_only_combinations(),
                           _captured_refvar_test_combinations()))
    def testOptimizationWithCapturedRefVar(self, dataset_fn):
        """Tests that default optimizations are disabled with ref variables.

        Args:
          dataset_fn: Callable that receives the reference variable and
            returns the dataset to test.
        """
        variable = variable_scope.get_variable("v",
                                               initializer=0,
                                               use_resource=False)
        assign_op = variable.assign_add(1)

        # Check that warning is logged. The filter is changed inside
        # `catch_warnings` so the global warning filters are restored on exit.
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            unoptimized_dataset = dataset_fn(variable)

            options = dataset_ops.Options()
            options.experimental_optimization.apply_default_optimizations = False
            options.experimental_optimization.noop_elimination = True
            options.experimental_optimization.map_and_batch_fusion = True
            optimized_dataset = unoptimized_dataset.with_options(options)
            optimized_it = dataset_ops.make_initializable_iterator(
                optimized_dataset)

        self.assertGreaterEqual(len(w), 1)
        graph_rewrites = options._graph_rewrites()
        expected = (
            "tf.data graph rewrites are not compatible with "
            "tf.Variable. The following rewrites will be disabled: %s."
            " To enable rewrites, use resource variables instead by "
            "calling `tf.enable_resource_variables()` at the start of the "
            "program." %
            (", ".join(graph_rewrites.enabled + graph_rewrites.default)))
        self.assertTrue(any(expected in str(warning) for warning in w))

        # Check that outputs are the same in the optimized and unoptimized cases,
        # when the variable value is changing.
        unoptimized_it = dataset_ops.make_initializable_iterator(
            unoptimized_dataset)
        with ops.control_dependencies([assign_op]):
            unoptimized_output = unoptimized_it.get_next()
            optimized_output = optimized_it.get_next()

        self.evaluate(variable.initializer)
        self.evaluate((unoptimized_it.initializer, optimized_it.initializer))
        while True:
            # Only the evaluation can signal end-of-sequence, so keep the
            # comparison outside of the `try` body.
            try:
                unoptimized, optimized = self.evaluate(
                    (unoptimized_output, optimized_output))
            except errors.OutOfRangeError:
                break
            self.assertEqual(unoptimized, optimized)

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationDefault(self):
        """Tests the optimization settings by default."""
        options = dataset_ops.Options()
        expected_optimizations_default = [
            "map_and_batch_fusion",
            "noop_elimination",
            "shuffle_and_repeat_fusion",
        ]
        # Untouched options: only the default rewrites are reported.
        self._assertGraphRewrites(options,
                                  enabled=[],
                                  disabled=[],
                                  default=expected_optimizations_default)

        # Explicitly requesting the defaults gives the same result.
        options.experimental_optimization.apply_default_optimizations = True
        self._assertGraphRewrites(options,
                                  enabled=[],
                                  disabled=[],
                                  default=expected_optimizations_default)

        # Turning the defaults off leaves every rewrite set empty.
        options.experimental_optimization.apply_default_optimizations = False
        self._assertGraphRewrites(options,
                                  enabled=[],
                                  disabled=[],
                                  default=[])

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationEnabled(self):
        """Tests the optimization settings by enabling all."""
        options = dataset_ops.Options()
        options.experimental_optimization.filter_fusion = True
        options.experimental_optimization.filter_with_random_uniform_fusion = True
        options.experimental_optimization.hoist_random_uniform = True
        options.experimental_optimization.map_and_batch_fusion = True
        options.experimental_optimization.map_and_filter_fusion = True
        options.experimental_optimization.map_parallelization = True
        options.experimental_optimization.map_fusion = True
        options.experimental_optimization.noop_elimination = True
        options.experimental_optimization.parallel_batch = True
        options.experimental_optimization.shuffle_and_repeat_fusion = True
        options.experimental_optimization.map_vectorization.enabled = True
        options.experimental_optimization.autotune_buffers = True
        options.experimental_deterministic = False
        options.experimental_stats.latency_all_edges = True
        options.experimental_slack = True

        expected_optimizations_enabled = [
            "filter_fusion",
            "filter_with_random_uniform_fusion",
            "hoist_random_uniform",
            "map_and_batch_fusion",
            "map_and_filter_fusion",
            "map_parallelization",
            "map_fusion",
            "noop_elimination",
            "parallel_batch",
            "shuffle_and_repeat_fusion",
            "map_vectorization",
            "inject_prefetch",
            "make_sloppy",
            "latency_all_edges",
            "slack",
        ]
        self._assertGraphRewrites(options,
                                  enabled=expected_optimizations_enabled,
                                  disabled=[],
                                  default=[])

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationDisabled(self):
        """Tests the optimization settings by disabling all."""
        options = dataset_ops.Options()
        options.experimental_optimization.filter_fusion = False
        options.experimental_optimization.filter_with_random_uniform_fusion = False
        options.experimental_optimization.hoist_random_uniform = False
        options.experimental_optimization.map_and_batch_fusion = False
        options.experimental_optimization.map_and_filter_fusion = False
        options.experimental_optimization.map_parallelization = False
        options.experimental_optimization.map_fusion = False
        options.experimental_optimization.noop_elimination = False
        options.experimental_optimization.parallel_batch = False
        options.experimental_optimization.shuffle_and_repeat_fusion = False
        options.experimental_optimization.map_vectorization.enabled = False
        options.experimental_optimization.autotune = False
        options.experimental_deterministic = True
        options.experimental_stats.latency_all_edges = False
        options.experimental_slack = False

        expected_optimizations_disabled = [
            "filter_fusion",
            "filter_with_random_uniform_fusion",
            "hoist_random_uniform",
            "map_and_batch_fusion",
            "map_and_filter_fusion",
            "map_parallelization",
            "map_fusion",
            "noop_elimination",
            "parallel_batch",
            "shuffle_and_repeat_fusion",
            "map_vectorization",
            "inject_prefetch",
            "make_sloppy",
            "latency_all_edges",
            "slack",
        ]
        self._assertGraphRewrites(options,
                                  enabled=[],
                                  disabled=expected_optimizations_disabled,
                                  default=[])

    @combinations.generate(test_base.default_test_combinations())
    def testAutotuningDefaults(self):
        """Checks the autotune settings of freshly created options."""
        options = dataset_ops.Options()

        # Check defaults
        autotune, algorithm, cpu_budget = options._autotune_settings()
        self.assertTrue(autotune)
        self.assertEqual(algorithm,
                         optimization_options._AutotuneAlgorithm.HILL_CLIMB)
        self.assertEqual(cpu_budget, 0)

    @combinations.generate(test_base.default_test_combinations())
    def testAutotuningBufferSizes(self):
        """Checks the settings reported once `autotune_buffers` is enabled."""
        options = dataset_ops.Options()
        options.experimental_optimization.autotune_buffers = True
        self.assertIn("inject_prefetch", options._graph_rewrites().enabled)
        autotune, algorithm, cpu_budget = options._autotune_settings()
        self.assertTrue(autotune)
        self.assertEqual(
            algorithm,
            optimization_options._AutotuneAlgorithm.GRADIENT_DESCENT)
        self.assertEqual(cpu_budget, 0)
class MultiDeviceIteratorTest(test_base.DatasetTestBase,
                              parameterized.TestCase):
    """Tests for `MultiDeviceIterator` over the devices set up in `setUp`."""

    def setUp(self):
        """Stores the result of `configureDevicesForMultiDeviceTest(3)`."""
        super(MultiDeviceIteratorTest, self).setUp()
        self._devices = self.configureDevicesForMultiDeviceTest(3)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(num_inits=[0, 1, 42])))
    def testInitOnly(self, num_inits):
        """Runs the initializer `num_inits` times without reading elements."""
        source = dataset_ops.Dataset.range(10)
        targets = [self._devices[1], self._devices[2]]
        iterator = multi_device_iterator_ops.MultiDeviceIterator(
            source, targets)

        for _ in range(num_inits):
            self.evaluate(iterator.initializer)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(max_buffer_size=[0, 1, 10],
                                 prefetch_buffer_size=[0, 1, 10])))
    def testBasic(self, prefetch_buffer_size, max_buffer_size):
        """Reads ten elements alternately from two devices, then exhausts."""
        source = dataset_ops.Dataset.range(10)
        targets = [self._devices[1], self._devices[2]]
        iterator = multi_device_iterator_ops.MultiDeviceIterator(
            source,
            targets,
            max_buffer_size=max_buffer_size,
            prefetch_buffer_size=prefetch_buffer_size)

        self.evaluate(iterator.initializer)
        for base in range(0, 10, 2):
            on_first, on_second = iterator.get_next()
            self.assertEqual(base, self.evaluate(on_first))
            self.assertEqual(base + 1, self.evaluate(on_second))
        with self.assertRaises(errors.OutOfRangeError):
            on_first, on_second = iterator.get_next()
            self.evaluate(on_first)
            self.evaluate(on_second)

    @combinations.generate(test_base.default_test_combinations())
    def testOneOnSameDevice(self):
        """Uses all three devices, including `self._devices[0]`."""
        source = dataset_ops.Dataset.range(12)
        targets = [self._devices[0], self._devices[1], self._devices[2]]
        iterator = multi_device_iterator_ops.MultiDeviceIterator(
            source, targets)

        self.evaluate(iterator.initializer)
        for base in range(0, 12, 3):
            on_zeroth, on_first, on_second = iterator.get_next()
            self.assertEqual(base, self.evaluate(on_zeroth))
            self.assertEqual(base + 1, self.evaluate(on_first))
            self.assertEqual(base + 2, self.evaluate(on_second))
        with self.assertRaises(errors.OutOfRangeError):
            on_zeroth, on_first, on_second = iterator.get_next()
            self.evaluate(on_zeroth)
            self.evaluate(on_first)
            self.evaluate(on_second)

    @combinations.generate(test_base.default_test_combinations())
    def testRepeatDevices(self):
        """Lists the same device twice in the device list."""
        source = dataset_ops.Dataset.range(10)
        targets = [self._devices[1], self._devices[1]]
        iterator = multi_device_iterator_ops.MultiDeviceIterator(
            source, targets)

        self.evaluate(iterator.initializer)
        for base in range(0, 10, 2):
            per_device = iterator.get_next()
            on_first, on_second = per_device
            self.assertEqual(base, self.evaluate(on_first))
            self.assertEqual(base + 1, self.evaluate(on_second))
        with self.assertRaises(errors.OutOfRangeError):
            per_device = iterator.get_next()
            on_first, on_second = per_device
            self.evaluate(on_first)
            self.evaluate(on_second)

    @combinations.generate(test_base.default_test_combinations())
    def testNotFullyDivisible(self):
        """Reads nine elements with two devices; the last one is read alone."""
        source = dataset_ops.Dataset.range(9)
        targets = [self._devices[1], self._devices[2]]
        iterator = multi_device_iterator_ops.MultiDeviceIterator(
            source, targets)

        self.evaluate(iterator.initializer)
        for base in range(0, 8, 2):
            on_first, on_second = iterator.get_next()
            self.assertEqual(base, self.evaluate(on_first))
            self.assertEqual(base + 1, self.evaluate(on_second))
        # Element 8 is fetched from the first listed device only.
        on_first = iterator.get_next(self._devices[1])
        self.assertEqual(8, self.evaluate(on_first))
        with self.assertRaises(errors.OutOfRangeError):
            on_first, on_second = iterator.get_next()
            self.evaluate(on_first)
            self.evaluate(on_second)

    @combinations.generate(test_base.default_test_combinations())
    def testGetNextAsOptional(self):
        """Reads elements as optionals and checks the exhausted state."""
        source = dataset_ops.Dataset.range(10)
        targets = [self._devices[1], self._devices[2]]
        iterator = multi_device_iterator_ops.MultiDeviceIterator(
            source, targets)

        self.evaluate(iterator.initializer)
        for base in range(0, 10, 2):
            opt_first, opt_second = iterator.get_next_as_optional()
            has_first, value_first = self.evaluate(
                [opt_first.has_value(),
                 opt_first.get_value()])
            has_second, value_second = self.evaluate(
                [opt_second.has_value(),
                 opt_second.get_value()])
            self.assertTrue(has_first)
            self.assertEqual(base, value_first)
            self.assertTrue(has_second)
            self.assertEqual(base + 1, value_second)

        # After ten elements both optionals must report no value.
        opt_first, opt_second = iterator.get_next_as_optional()
        has_first = opt_first.has_value()
        has_second = opt_second.has_value()
        self.assertFalse(self.evaluate(has_first))
        self.assertFalse(self.evaluate(has_second))
        with self.assertRaises(errors.InvalidArgumentError):
            self.evaluate(opt_first.get_value())
        with self.assertRaises(errors.InvalidArgumentError):
            self.evaluate(opt_second.get_value())

    @combinations.generate(test_base.default_test_combinations())
    def testUneven(self):
        """Drains one device completely before reading from the other."""
        source = dataset_ops.Dataset.range(10)
        targets = [self._devices[1], self._devices[2]]
        iterator = multi_device_iterator_ops.MultiDeviceIterator(
            source, targets, max_buffer_size=4)

        self.evaluate(iterator.initializer)
        for expected_even in range(0, 10, 2):
            on_first = iterator.get_next(self._devices[1])
            self.assertEqual(expected_even, self.evaluate(on_first))
        for expected_odd in range(1, 11, 2):
            on_second = iterator.get_next(self._devices[2])
            self.assertEqual(expected_odd, self.evaluate(on_second))
        with self.assertRaises(errors.OutOfRangeError):
            on_first, on_second = iterator.get_next()
            self.evaluate(on_first)
            self.evaluate(on_second)

    @combinations.generate(test_base.graph_only_combinations())
    def testMultipleInitializationsGraph(self):
        """Re-runs the initializer and checks reading restarts each time."""
        left = dataset_ops.Dataset.range(1000)
        right = dataset_ops.Dataset.range(1000)
        zipped = dataset_ops.Dataset.zip((left, right))
        targets = [self._devices[1], self._devices[2]]
        iterator = multi_device_iterator_ops.MultiDeviceIterator(
            zipped, targets, prefetch_buffer_size=4)
        # The same two outputs are evaluated after every re-initialization.
        on_first, on_second = iterator.get_next()

        for _ in range(5):
            self.evaluate(iterator.initializer)
            self.assertEqual([(0, 0), (1, 1)],
                             self.evaluate([on_first, on_second]))

    @combinations.generate(test_base.eager_only_combinations())
    def testMultipleInitializationsEager(self):
        """Builds a fresh iterator five times and reads its first elements."""
        left = dataset_ops.Dataset.range(1000)
        right = dataset_ops.Dataset.range(1000)
        zipped = dataset_ops.Dataset.zip((left, right))

        for _ in range(5):
            targets = [self._devices[1], self._devices[2]]
            iterator = multi_device_iterator_ops.MultiDeviceIterator(
                zipped, targets, prefetch_buffer_size=4)
            self.evaluate(iterator.initializer)
            on_first, on_second = iterator.get_next()
            self.assertEqual([(0, 0), (1, 1)],
                             self.evaluate([on_first, on_second]))

    @combinations.generate(test_base.default_test_combinations())
    def testOptimization(self):
        """Checks noop elimination applies when using a multi-device iterator."""
        source = dataset_ops.Dataset.range(10)
        source = source.apply(testing.assert_next(["MemoryCacheImpl"]))
        source = source.skip(0)  # this should be optimized away
        source = source.cache()

        noop_options = options_lib.Options()
        noop_options.experimental_optimization.noop_elimination = True
        source = source.with_options(noop_options)

        targets = [self._devices[1], self._devices[2]]
        iterator = multi_device_iterator_ops.MultiDeviceIterator(
            source, targets)

        self.evaluate(iterator.initializer)
        for base in range(0, 10, 2):
            on_first, on_second = iterator.get_next()
            self.assertEqual(base, self.evaluate(on_first))
            self.assertEqual(base + 1, self.evaluate(on_second))
        with self.assertRaises(errors.OutOfRangeError):
            on_first, on_second = iterator.get_next()
            self.evaluate(on_first)
            self.evaluate(on_second)
# Example #27
# 0
class IteratorClusterTest(test.TestCase, parameterized.TestCase):
    """Iterator tests that run against a cluster from `create_local_cluster`."""

    @combinations.generate(test_base.graph_only_combinations())
    def testRemoteIteratorWithoutRemoteCallFail(self):
        """Reading an iterator through a handle from another device must fail.

        The iterator is created on cpu:1, while the iterator rebuilt from its
        string handle is placed on cpu:0. Running that `get_next` is expected
        to raise `InvalidArgumentError`.
        """
        two_cpu_config = config_pb2.ConfigProto()
        two_cpu_config.device_count["CPU"] = 2
        workers, _ = test_util.create_local_cluster(
            1, 1, worker_config=two_cpu_config)

        with ops.device("/job:worker/replica:0/task:0/cpu:1"):
            source = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
            owner_iterator = dataset_ops.make_one_shot_iterator(source)
            handle = owner_iterator.string_handle()

        with ops.device("/job:worker/replica:0/task:0/cpu:0"):
            foreign_iterator = iterator_ops.Iterator.from_string_handle(
                handle,
                dataset_ops.get_legacy_output_types(source),
                dataset_ops.get_legacy_output_shapes(source))
            fetch = foreign_iterator.get_next()

        with session.Session(workers[0].target) as sess:
            with self.assertRaises(errors.InvalidArgumentError):
                sess.run(fetch)

    def _testRemoteIteratorHelper(self, device0, device1, target):
        """Drives an iterator created on `device1` through a `remote_call` op.

        Args:
          device0: Device on which the `remote_call` op is placed.
          device1: Device on which the dataset iterator is created.
          target: Session target used to run the ops.
        """
        with ops.device(device1):
            source = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
            owner_iterator = dataset_ops.make_one_shot_iterator(source)
            handle = owner_iterator.string_handle()

        @function.Defun(dtypes.string)
        def _remote_fn(h):
            rebuilt = iterator_ops.Iterator.from_string_handle(
                h, dataset_ops.get_legacy_output_types(source),
                dataset_ops.get_legacy_output_shapes(source))
            return rebuilt.get_next()

        with ops.device(device0):
            target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
            remote_op = functional_ops.remote_call(
                args=[handle],
                Tout=[dtypes.int32],
                f=_remote_fn,
                target=target_placeholder)

        with session.Session(target) as sess:

            def call_on(device):
                # Runs the remote call with `device` fed as its target.
                return sess.run(remote_op,
                                feed_dict={target_placeholder: device})

            self.assertEqual(call_on(device1), [1])
            # Targeting `device0` fails: the iterator resource is not there.
            with self.assertRaises(errors.InvalidArgumentError):
                call_on(device0)
            # A direct read and the remote call advance the same iterator.
            self.assertEqual(sess.run(owner_iterator.get_next()), [2])
            self.assertEqual(call_on(device1), [3])
            with self.assertRaises(errors.OutOfRangeError):
                call_on(device1)

    @combinations.generate(test_base.graph_only_combinations())
    def testRemoteIteratorUsingRemoteCallOp(self):
        """Runs the remote-call helper across two CPUs of a single worker."""
        two_cpu_config = config_pb2.ConfigProto()
        two_cpu_config.device_count["CPU"] = 2
        workers, _ = test_util.create_local_cluster(
            1, 1, worker_config=two_cpu_config)

        self._testRemoteIteratorHelper(
            "/job:worker/replica:0/task:0/cpu:0",
            "/job:worker/replica:0/task:0/cpu:1",
            workers[0].target)

    @combinations.generate(test_base.graph_only_combinations())
    def testRemoteIteratorUsingRemoteCallOpCrossProcess(self):
        """Runs the remote-call helper across two separate worker tasks."""
        workers, _ = test_util.create_local_cluster(2, 1)

        self._testRemoteIteratorHelper(
            "/job:worker/replica:0/task:0/cpu:0",
            "/job:worker/replica:0/task:1/cpu:0",
            workers[0].target)

    @combinations.generate(test_base.graph_only_combinations())
    def testCaptureHashTableInSharedIterator(self):
        """Uses a shared-name iterator with a table lookup from two sessions.

        The first session initializes the table and iterator and reads the
        first element; a second session on the same worker reads the
        remaining element and then hits the end of the sequence.
        """
        workers, _ = test_util.create_local_cluster(1, 1)

        # NOTE(review): the note that used to sit here (NOTE(mrry)) said the
        # V2 `HashTable` variants must be used because they produce a
        # `tf.resource`-typed output compatible with the in-graph function
        # implementation, yet `StaticHashTableV1` is what is constructed
        # below. Confirm which one is intended.
        missing_id = -1
        vocab = constant_op.constant(["brain", "salad", "surgery"])
        ids = constant_op.constant([0, 1, 2], dtypes.int64)
        table = lookup_ops.StaticHashTableV1(
            lookup_ops.KeyValueTensorInitializer(vocab, ids), missing_id)

        sentences = dataset_ops.Dataset.from_tensor_slices(
            ["brain brain tank salad surgery", "surgery brain"])
        tokens = sentences.map(lambda x: string_ops.string_split([x]).values)
        token_ids = tokens.map(table.lookup)

        iterator = dataset_ops.make_initializable_iterator(
            token_ids, shared_name="shared_iterator")
        init_op = iterator.initializer
        get_next = iterator.get_next()

        with session.Session(workers[0].target) as sess:
            sess.run(table.initializer)
            sess.run(init_op)
            # "tank" is not in the table, so it maps to the default -1.
            self.assertAllEqual([0, 0, -1, 1, 2], sess.run(get_next))

        # No re-initialization here: the second session continues where the
        # first one stopped.
        with session.Session(workers[0].target) as sess:
            self.assertAllEqual([2, 0], sess.run(get_next))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)

    @combinations.generate(test_base.graph_only_combinations())
    def testImplicitDisposeParallelMapDataset(self):
        """Reads only a few elements from an endlessly repeating pipeline.

        The pipeline is `from_tensor_slices` -> `map(square)` ->
        `repeat(None)` -> `prefetch(10000)`. Only three elements are
        consumed, so the pipeline is never run to exhaustion; the intent is
        to check that it is still cleaned up correctly.
        """
        workers, _ = test_util.create_local_cluster(1, 1)

        components = (
            np.arange(1000),
            np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
            np.array(37.0) * np.arange(1000),
        )

        def _map_fn(x, y, z):
            return math_ops.square(x), math_ops.square(y), math_ops.square(z)

        dataset = dataset_ops.Dataset.from_tensor_slices(components)
        dataset = dataset.map(_map_fn).repeat(None).prefetch(10000)

        iterator = dataset_ops.make_initializable_iterator(dataset)
        init_op = iterator.initializer
        get_next = iterator.get_next()

        with session.Session(workers[0].target) as sess:
            sess.run(init_op)
            for _ in range(3):
                sess.run(get_next)
class OwnedMultiDeviceIteratorTest(test_base.DatasetTestBase,
                                   parameterized.TestCase):
    """Tests for `multi_device_iterator_ops.OwnedMultiDeviceIterator`."""

    def setUp(self):
        """Configures three devices; the tests iterate on the last two."""
        super(OwnedMultiDeviceIteratorTest, self).setUp()
        self._devices = self.configureDevicesForMultiDeviceTest(3)

    @combinations.generate(
        combinations.times(
            test_base.eager_only_combinations(),
            combinations.combine(max_buffer_size=[0, 1, 10],
                                 prefetch_buffer_size=[0, 1, 10])))
    def testBasic(self, max_buffer_size, prefetch_buffer_size):
        """Iterates the whole dataset for each buffer-size combination.

        Args:
          max_buffer_size: Forwarded to the iterator constructor.
          prefetch_buffer_size: Forwarded to the iterator constructor.
        """
        iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
            dataset_ops.Dataset.range(1000),
            [self._devices[1], self._devices[2]],
            max_buffer_size=max_buffer_size,
            prefetch_buffer_size=prefetch_buffer_size)

        # Step `step` must yield elements 2*step and 2*step + 1.
        for step, per_device in enumerate(iterator):
            self.assertEqual(
                [step * 2, step * 2 + 1],
                [per_device[0].numpy(), per_device[1].numpy()])

    @combinations.generate(test_base.eager_only_combinations())
    def testBasicFunction(self):
        """Creates and consumes the iterator entirely inside a `tf.function`.

        Elements are pushed to a FIFO queue from within the function, and the
        queue is then expected to hold 0..9 in order.
        """
        queue = data_flow_ops.FIFOQueue(10, dtypes.int64)

        @def_function.function
        def fn():
            with ops.device(self._devices[0]):
                dataset = dataset_ops.Dataset.range(10)
            iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
                dataset, [self._devices[1], self._devices[2]])
            for _ in range(5):
                el0, el1 = next(iterator)
                queue.enqueue(el0)
                queue.enqueue(el1)

        fn()

        for expected in range(10):
            self.assertEqual(queue.dequeue().numpy(), expected)

    @combinations.generate(test_base.eager_only_combinations())
    def testFunctionError(self):
        """Verifies the iterator resource is released when the function fails.

        The queue starts with one element and `finalize_fn` enqueues another,
        so a final queue size of 2 shows that the finalizer ran even though
        `fn` raised `OutOfRangeError`.
        """
        queue = data_flow_ops.FIFOQueue(10, dtypes.int64)
        queue.enqueue(0)

        def init_fn(n):
            return n

        def next_fn(_):
            # `range(0)` is empty, so pulling its first element fails.
            ds = dataset_ops.Dataset.range(0)
            return next(iter(ds))

        def finalize_fn(n):
            queue.enqueue(0)
            return n

        @def_function.function
        def fn():
            dataset = dataset_ops._GeneratorDataset(
                1,
                init_fn,
                next_fn,
                finalize_fn,
                output_signature=tensor_spec.TensorSpec([], dtypes.int64))
            iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
                dataset, [self._devices[1], self._devices[2]])
            next(iterator)

        with self.assertRaises(errors.OutOfRangeError):
            fn()

        self.assertEqual(queue.size().numpy(), 2)

    @combinations.generate(test_base.eager_only_combinations())
    def testMultipleInitializations(self):
        """Builds five iterators over one dataset; each must start at zero."""
        dataset = dataset_ops.Dataset.range(1000)

        for _ in range(5):
            iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
                dataset, [self._devices[1], self._devices[2]])
            for step, per_device in enumerate(iterator):
                self.assertEqual(
                    [step * 2, step * 2 + 1],
                    [per_device[0].numpy(), per_device[1].numpy()])

    @combinations.generate(test_base.eager_only_combinations())
    def testLimitedRetracing(self):
        """Passing different iterators to one function must trace it once."""
        # One-element list so the traced function can mutate the counter.
        trace_count = [0]

        @def_function.function
        def f(iterator):
            # This Python-side increment counts calls of the function body.
            trace_count[0] += 1
            counter = np.int64(0)
            for _ in range(5):
                elem = next(iterator)
                counter += elem[0]
                counter += elem[1]
            return counter

        short_dataset = dataset_ops.Dataset.range(10)
        long_dataset = dataset_ops.Dataset.range(20)

        for _ in range(10):
            # Both iterators yield 0..9 over five steps, which sums to 45.
            first_iterator = (
                multi_device_iterator_ops.OwnedMultiDeviceIterator(
                    short_dataset, [self._devices[1], self._devices[2]]))
            self.assertEqual(self.evaluate(f(first_iterator)), 45)
            second_iterator = (
                multi_device_iterator_ops.OwnedMultiDeviceIterator(
                    long_dataset, [self._devices[1], self._devices[2]]))
            self.assertEqual(self.evaluate(f(second_iterator)), 45)
            self.assertEqual(trace_count[0], 1)

    @combinations.generate(test_base.eager_only_combinations())
    def testMissingDevices(self):
        """Omitting `devices` must raise `ValueError`."""
        dataset = dataset_ops.Dataset.range(1000)
        expected_message = "`devices` must be provided."
        with self.assertRaisesRegex(ValueError, expected_message):
            multi_device_iterator_ops.OwnedMultiDeviceIterator(dataset)

    @combinations.generate(test_base.eager_only_combinations())
    def testMissingInput(self):
        """Passing neither a dataset nor components/spec must raise."""
        expected_message = (
            "When `dataset` is not provided, both `components` and `element_spec` "
            "must be specified.")
        with self.assertRaisesRegex(ValueError, expected_message):
            multi_device_iterator_ops.OwnedMultiDeviceIterator(
                dataset=None, devices=[self._devices[1], self._devices[2]])

    @combinations.generate(test_base.eager_only_combinations())
    def testExtraElementSpecInput(self):
        """Passing `element_spec` together with a dataset must raise."""
        dataset = dataset_ops.Dataset.range(1000)
        expected_message = (
            "When `dataset` is provided, `element_spec` and `components` must "
            "not be specified.")
        with self.assertRaisesRegex(ValueError, expected_message):
            multi_device_iterator_ops.OwnedMultiDeviceIterator(
                dataset,
                devices=[self._devices[1], self._devices[2]],
                element_spec=dataset.element_spec)

    @combinations.generate(test_base.graph_only_combinations())
    def testGraphMode(self):
        """Constructing the iterator in graph mode must raise `RuntimeError`."""
        dataset = dataset_ops.Dataset.range(1000)
        expected_message = (
            "OwnedMultiDeviceIterator is only supported inside of tf.function or "
            "when eager execution is enabled.")
        with self.assertRaisesRegex(RuntimeError, expected_message):
            multi_device_iterator_ops.OwnedMultiDeviceIterator(
                dataset, devices=[self._devices[1], self._devices[2]])
示例#29
0
class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for the `shuffle()` dataset transformation."""

  @combinations.generate(test_base.default_test_combinations())
  def testBasic(self):
    """Compares shuffled output with the unshuffled "ground truth"."""
    components = (
        np.array([1, 2, 3, 4]),
        np.array([5, 6, 7, 8]),
        np.array([9.0, 10.0, 11.0, 12.0]),
    )

    def dataset_fn(count=5, buffer_size=None, seed=0):
      repeated = (
          dataset_ops.Dataset.from_tensor_slices(components).repeat(count))
      if not buffer_size:
        return repeated

      shuffled = repeated.shuffle(buffer_size, seed)
      # Shuffling must not alter the per-element shapes.
      self.assertEqual(
          tuple(c.shape[1:] for c in components),
          dataset_ops.get_legacy_output_shapes(shuffled))
      return shuffled

    def read_all(dataset, num_elements=20):
      # Reads `num_elements` items, then checks that the next read raises.
      # The `get_next` callable is returned too, so callers can probe the
      # exhausted iterator again.
      get_next = self.getNext(dataset)
      items = [self.evaluate(get_next()) for _ in range(num_elements)]
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())
      return items, get_next

    # First run without shuffling to collect the "ground truth".
    unshuffled_elements, _ = read_all(dataset_fn())

    # The shuffled dataset must contain the same elements as the ground truth.
    shuffled_elements, get_next = read_all(
        dataset_fn(buffer_size=100, seed=37))
    # A second read past the end must raise as well.
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    self.assertAllEqual(sorted(unshuffled_elements), sorted(shuffled_elements))

    # Shuffling twice with the same seed must give the same sequence.
    reshuffled_elements_same_seed, _ = read_all(
        dataset_fn(buffer_size=100, seed=37))
    self.assertEqual(shuffled_elements, reshuffled_elements_same_seed)

    # A different seed must give a different permutation of the same
    # elements.
    reshuffled_elements_different_seed, _ = read_all(
        dataset_fn(buffer_size=100, seed=137))
    self.assertNotEqual(shuffled_elements, reshuffled_elements_different_seed)
    self.assertAllEqual(
        sorted(shuffled_elements), sorted(reshuffled_elements_different_seed))

    # A buffer smaller than the input dataset must still produce every
    # ground-truth element.
    reshuffled_elements_small_buffer, _ = read_all(
        dataset_fn(buffer_size=2, seed=37))
    self.assertAllEqual(
        sorted(unshuffled_elements), sorted(reshuffled_elements_small_buffer))

    # Shuffling an empty dataset: the very first read is out of range.
    read_all(dataset_fn(count=0, buffer_size=100, seed=37), num_elements=0)

  @combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
  def testSeedZero(self):
    """Test for same behavior when the seed is a Python or Tensor zero."""
    python_seed_iterator = dataset_ops.make_one_shot_iterator(
        dataset_ops.Dataset.range(10).shuffle(10, seed=0))
    get_next = python_seed_iterator.get_next()

    with self.cached_session() as sess:
      elems = [sess.run(get_next) for _ in range(10)]
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

    # Same pipeline, but the zero seed is fed through a placeholder.
    seed_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
    tensor_seed_iterator = dataset_ops.make_initializable_iterator(
        dataset_ops.Dataset.range(10).shuffle(10, seed=seed_placeholder))
    get_next = tensor_seed_iterator.get_next()

    with self.cached_session() as sess:
      sess.run(tensor_seed_iterator.initializer,
               feed_dict={seed_placeholder: 0})
      for expected in elems:
        self.assertEqual(expected, sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  @combinations.generate(test_base.default_test_combinations())
  def testDefaultArguments(self):
    """With default arguments, each value appears once per repetition."""
    components = [0, 1, 2, 3, 4]
    dataset = dataset_ops.Dataset.from_tensor_slices(components)
    dataset = dataset.shuffle(5).repeat()
    get_next = self.getNext(dataset)

    # Ten passes over five elements: every value must be seen ten times.
    counts = collections.defaultdict(int)
    for _ in range(10 * 5):
      counts[self.evaluate(get_next())] += 1

    for value in range(5):
      self.assertEqual(10, counts[value])

  @combinations.generate(
      combinations.times(
          test_base.graph_only_combinations(),
          combinations.combine(reshuffle=[True, False]),
          combinations.combine(graph_seed=38, op_seed=None) +
          combinations.combine(graph_seed=None, op_seed=42) +
          combinations.combine(graph_seed=38, op_seed=42)))
  def testShuffleSeed(self, reshuffle, graph_seed, op_seed):
    """Two fresh graphs with the same seeds must produce the same order.

    Args:
      reshuffle: Value passed as `reshuffle_each_iteration`.
      graph_seed: Graph-level seed passed to `set_random_seed`.
      op_seed: Seed passed to `shuffle`.
    """
    results = []
    for _ in range(2):
      with ops.Graph().as_default() as g:
        random_seed.set_random_seed(graph_seed)
        dataset = dataset_ops.Dataset.range(10).shuffle(
            10, seed=op_seed, reshuffle_each_iteration=reshuffle).repeat(3)
        iterator = dataset_ops.make_one_shot_iterator(dataset)
        next_element = iterator.get_next()

        with self.session(graph=g) as sess:
          run_results = [sess.run(next_element) for _ in range(30)]
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(next_element)
        results.append(run_results)

    self.assertAllEqual(results[0], results[1])

  # TODO(b/117581999): enable this test for eager-mode.
  @combinations.generate(
      combinations.times(
          test_base.graph_only_combinations(),
          combinations.combine(
              reshuffle=[True, False], initializable=[True, False])))
  def testMultipleIterators(self, reshuffle, initializable):
    """Two iterators over one unseeded dataset must differ in order.

    Args:
      reshuffle: Value passed as `reshuffle_each_iteration`.
      initializable: Whether to use initializable rather than one-shot
        iterators.
    """
    with ops.Graph().as_default() as g:
      dataset = dataset_ops.Dataset.range(100).shuffle(
          10, reshuffle_each_iteration=reshuffle).repeat(3)

      make_iterator = (
          dataset_ops.make_initializable_iterator
          if initializable else dataset_ops.make_one_shot_iterator)
      iterators = [make_iterator(dataset) for _ in range(2)]

      results = []
      with self.session(graph=g) as sess:
        for iterator in iterators:
          if initializable:
            sess.run(iterator.initializer)
          next_element = iterator.get_next()
          run_results = [sess.run(next_element) for _ in range(300)]
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(next_element)

          results.append(run_results)

        self.assertNotEqual(results[0], results[1])

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(reshuffle=[True, False], seed=[None, 42])))
  def testReshuffleRepeatEpochs(self, reshuffle, seed):
    """Epochs from `repeat(2)` match exactly when `reshuffle` is False."""
    dataset = dataset_ops.Dataset.range(10).shuffle(
        10, seed=seed, reshuffle_each_iteration=reshuffle).repeat(2)
    next_element = self.getNext(dataset)

    first_epoch = [self.evaluate(next_element()) for _ in range(10)]
    second_epoch = [self.evaluate(next_element()) for _ in range(10)]

    self.assertEqual(first_epoch == second_epoch, not reshuffle)

  @combinations.generate(
      combinations.times(
          combinations.combine(tf_api_version=2, mode="eager"),
          combinations.combine(reshuffle=[True, False], seed=[None, 42])))
  def testReshuffleIterationEpochs(self, reshuffle, seed):
    """Two iterations of one dataset match exactly when `reshuffle` is False."""
    # TensorFlow unit tests set the global graph seed. We unset it here so that
    # we can control determinism via the `seed` parameter.
    random_seed.set_random_seed(None)
    dataset = dataset_ops.Dataset.range(10).shuffle(
        10, seed=seed, reshuffle_each_iteration=reshuffle)

    first_epoch = self.getDatasetOutput(dataset)
    second_epoch = self.getDatasetOutput(dataset)

    self.assertEqual(first_epoch == second_epoch, not reshuffle)

  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
  def testShuffleV2ResourceCapture(self):
    """A shuffled dataset bound into an interleave function stays usable."""

    def make_dataset():
      ids = dataset_ops.Dataset.range(10)
      ids = ids.shuffle(1)

      def interleave_fn(dataset, _):
        return dataset

      # `ids` is bound through `functools.partial`, not passed as an element.
      dataset = dataset_ops.Dataset.range(1)
      dataset = dataset.interleave(functools.partial(interleave_fn, ids))
      return dataset

    results = [elem.numpy() for elem in make_dataset()]

    self.assertAllEqual(results, range(10))

  @combinations.generate(
      combinations.times(
          test_base.eager_only_combinations(),
          combinations.combine(reshuffle=[True, False], seed=[None, 42])))
  def testReshuffleSeparateTransformations(self, reshuffle, seed):
    """Separate `shuffle` calls differ in order only when no seed is given."""
    dataset = dataset_ops.Dataset.range(10)

    first_epoch = [
        elem.numpy() for elem in dataset.shuffle(
            10, seed=seed, reshuffle_each_iteration=reshuffle)
    ]
    second_epoch = [
        elem.numpy() for elem in dataset.shuffle(
            10, seed=seed, reshuffle_each_iteration=reshuffle)
    ]

    self.assertEqual(first_epoch != second_epoch, seed is None)

  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
  def testShuffleV2InFunction(self):
    """Iterates a shuffled dataset inside a `defun`-wrapped function."""
    counter_var = variables.Variable(0)

    @function.defun
    def consume():
      ds = dataset_ops.Dataset.range(10)
      ds = ds.shuffle(1)
      for _ in ds:
        counter_var.assign(counter_var + 1)

    consume()
    # One increment per element of the ten-element dataset.
    self.assertAllEqual(self.evaluate(counter_var), 10)

  @combinations.generate(test_base.default_test_combinations())
  def testEmptyDataset(self):
    """Shuffling a cache left empty by a failed map ends the sequence."""
    dataset = dataset_ops.Dataset.from_tensors(1)

    def map_fn(x):
      # The single element is 1, so this assertion always fails.
      with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
        return x

    dataset = dataset.map(map_fn)
    dataset = dataset.cache()
    dataset = dataset.shuffle(buffer_size=10).repeat()

    get_next = self.getNext(dataset)

    # First time around, we get an error for the failed assertion.
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(get_next())

    # Second time around, we get an EOF because the cached dataset is empty.
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(reshuffle=[True, False])))
  def testRerandomizeOnReplicate(self, reshuffle):
    """A graph round trip of an unseeded shuffle must change the order."""
    random_seed.set_random_seed(None)
    # When no seeds are fixed, each instantiation of the shuffle dataset should
    # produce elements in a different order.
    num_elements = 100
    dataset = dataset_ops.Dataset.range(num_elements).shuffle(
        num_elements, reshuffle_each_iteration=reshuffle)

    before_round_trip = self.getDatasetOutput(dataset)
    dataset = self.graphRoundTrip(dataset, allow_stateful=True)
    after_round_trip = self.getDatasetOutput(dataset)

    # Same elements, different order.
    self.assertCountEqual(before_round_trip, after_round_trip)
    self.assertNotEqual(before_round_trip, after_round_trip)
示例#30
0
class LocalTaskGarbageCollectTest(data_service_test_base.TestBase,
                                  parameterized.TestCase):
    """Tests garbage collecting unused local worker tasks.

    The user typically creates an iterator in each epoch. Doing so should
    delete the previous iterator and release its resources.
    """

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(num_remote_workers=[0, 3])))
    def testMultipleEpochs(self, num_remote_workers):
        """Each epoch's new iterator must start reading from zero."""
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=1, num_remote_workers=num_remote_workers)
        dataset = self._make_distributed_infinite_range_dataset(cluster)

        num_epochs = 5
        num_steps = 5
        for _ in range(num_epochs):
            # For each iteration, the previous iterator is garbage collected.
            get_next = self.getNext(dataset)
            for expected in range(num_steps):
                self.assertEqual(self.evaluate(get_next()), expected)

    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           combinations.combine(num_remote_workers=[0, 3])))
    def testMultipleEpochsSharedJob(self, num_remote_workers):
        """Same as `testMultipleEpochs`, but with a named (shared) job."""
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=1, num_remote_workers=num_remote_workers)
        dataset = self._make_distributed_infinite_range_dataset(
            cluster, job_name="shared_job_name")

        num_epochs = 5
        num_steps = 5
        for _ in range(num_epochs):
            # For each iteration, the previous iterator is garbage collected.
            get_next = self.getNext(dataset)
            for expected in range(num_steps):
                self.assertEqual(self.evaluate(get_next()), expected)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(num_remote_workers=[0, 3],
                                 job_name=[None, "shared_job_name"])))
    def testRepeatDistributedDataset(self, num_remote_workers, job_name):
        """Repeating a distributed range dataset yields the range each time."""
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=1, num_remote_workers=num_remote_workers)
        dataset = self.make_distributed_range_dataset(
            10, cluster, job_name=job_name, target_workers="LOCAL")
        dataset = dataset.repeat(3)

        self.assertDatasetProduces(dataset, list(range(10)) * 3)

    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           combinations.combine(num_remote_workers=[0, 3])))
    def testReadFromDeletedTask(self, num_remote_workers):
        """Reading from an already-deleted task must fail."""
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=1, num_remote_workers=num_remote_workers)

        num_steps = 10
        dataset = self._make_distributed_infinite_range_dataset(
            cluster, job_name="shared_job_name")
        get_next = self.getNext(dataset)
        for expected in range(num_steps):
            self.assertEqual(self.evaluate(get_next()), expected)

        # Re-creating the dataset resets the iterator index, so the second
        # iterator reads from the same task as the first, which has been
        # deleted.
        dataset = self._make_distributed_infinite_range_dataset(
            cluster, job_name="shared_job_name")
        get_next = self.getNext(dataset)
        with self.assertRaisesRegex(errors.FailedPreconditionError,
                                    "which has been deleted."):
            self.evaluate(get_next())

    @combinations.generate(
        combinations.times(test_base.graph_only_combinations(),
                           combinations.combine(num_remote_workers=[0, 3])))
    def testReadFromDeletedTask_GraphMode(self, num_remote_workers):
        """Graph-mode variant of `testReadFromDeletedTask`."""
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=1, num_remote_workers=num_remote_workers)

        num_steps = 10
        dataset = self._make_distributed_infinite_range_dataset(
            cluster, job_name="shared_job_name")
        with self.session() as sess:
            get_next = self.getNext(dataset)
            for expected in range(num_steps):
                self.assertEqual(sess.run(get_next()), expected)

        # Re-creating the dataset resets the iterator index, so the second
        # iterator reads from the same task as the first, which has been
        # deleted.
        dataset = self._make_distributed_infinite_range_dataset(
            cluster, job_name="shared_job_name")
        with self.assertRaisesRegex(errors.FailedPreconditionError,
                                    "which has been deleted."):
            with self.session() as sess:
                get_next = self.getNext(dataset)
                sess.run(get_next())

    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           combinations.combine(num_remote_workers=[0, 3])))
    def testMultipleEpochs_WorkerRestart(self, num_remote_workers):
        """A new iterator reads from zero after the local workers restart."""
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=1, num_remote_workers=num_remote_workers)

        num_steps = 10
        dataset = self._make_distributed_infinite_range_dataset(
            cluster, job_name="shared_job_name")

        get_next = self.getNext(dataset)
        for expected in range(num_steps):
            self.assertEqual(self.evaluate(get_next()), expected)

        # Verifies the worker re-creates the task after the iterator is
        # deleted and the worker restarts.
        del get_next
        cluster.restart_local_workers()

        get_next = self.getNext(dataset)
        for expected in range(num_steps):
            self.assertEqual(self.evaluate(get_next()), expected)

    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           combinations.combine(num_remote_workers=[0, 3])))
    def testMultipleEpochs_DispatcherRestart(self, num_remote_workers):
        """A new iterator reads from zero after the dispatcher restarts."""
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=1, num_remote_workers=num_remote_workers)

        num_steps = 10
        dataset = self._make_distributed_infinite_range_dataset(
            cluster, job_name="shared_job_name")
        get_next = self.getNext(dataset)
        for expected in range(num_steps):
            self.assertEqual(self.evaluate(get_next()), expected)

        # Verifies the worker re-creates the task after the iterator is
        # deleted and the dispatcher restarts.
        del get_next
        cluster.restart_dispatcher()

        get_next = self.getNext(dataset)
        for expected in range(num_steps):
            self.assertEqual(self.evaluate(get_next()), expected)

    def _make_distributed_infinite_range_dataset(self, cluster, job_name=None):
        """Builds an endlessly repeating range dataset read from local workers.

        Args:
          cluster: The cluster the dataset is distributed over.
          job_name: Optional job name forwarded to `make_distributed_dataset`.

        Returns:
          The dataset returned by `make_distributed_dataset`, configured with
          `ShardingPolicy.OFF` and `target_workers="LOCAL"`.
        """
        infinite_range = dataset_ops.Dataset.range(1000000).repeat()
        return self.make_distributed_dataset(
            infinite_range,
            cluster=cluster,
            job_name=job_name,
            processing_mode=ShardingPolicy.OFF,
            target_workers="LOCAL")