Example 1
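All three examples are excerpts from TensorFlow's tf.data service test suite and rely on that file's surrounding imports. The block below is a plausible reconstruction of those imports (module paths follow TensorFlow's internal layout circa TF 2.4–2.7 and may differ between versions; the `TMP_WORK_DIR`, `NO_WORK_DIR`, and `ShardingPolicy` aliases are assumptions based on how the tests use them):

import multiprocessing
import threading
import time
from unittest import mock

from absl.testing import parameterized

from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.kernel_tests import data_service_test_base
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import data_service_ops
from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import def_function
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import tensor_array_ops

# Assumed aliases, based on how the tests below use these names.
TMP_WORK_DIR = data_service_test_base.TMP_WORK_DIR
NO_WORK_DIR = data_service_test_base.NO_WORK_DIR
ShardingPolicy = data_service_ops.ShardingPolicy
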
class FaultToleranceTest(data_service_test_base.TestBase,
                         parameterized.TestCase):
    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherStop(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        iterator = iter(ds)
        results = []
        results.append(next(iterator).numpy())
        cluster.stop_dispatcher()
        # After the dispatcher dies, the worker should continue providing the rest
        # of the dataset's elements.
        for _ in range(num_elements - 1):
            results.append(next(iterator).numpy())
        self.assertEqual(results, list(range(num_elements)))

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherRestartBeforeReading(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        cluster.restart_dispatcher()

        self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherRestartDuringReading(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        iterator = iter(ds)
        results = []
        for _ in range(num_elements // 2):
            results.append(next(iterator).numpy())
        cluster.restart_dispatcher()
        for elem in iterator:
            results.append(elem.numpy())

        self.assertEqual(list(range(num_elements)), results)

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherRestartDuringDistributedEpoch(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 100
        ds = self.make_distributed_range_dataset(
            num_elements, cluster, processing_mode="distributed_epoch")
        iterator = iter(ds)
        results = []
        for _ in range(num_elements // 2):
            results.append(next(iterator).numpy())
        cluster.restart_dispatcher()
        for elem in iterator:
            results.append(elem.numpy())

        self.assertEqual(list(range(num_elements)), results)

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherRestartDuringDistributedEpochRepeat(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 100
        repetitions = 5
        breakpoints = [50, 250, 450, 500]
        ds = dataset_ops.Dataset.range(num_elements)
        ds = ds.repeat(repetitions)
        ds = self.make_distributed_dataset(ds,
                                           cluster,
                                           processing_mode="distributed_epoch")

        iterator = iter(ds)
        results = []
        for breakpoint_ in breakpoints:
            for _ in range(len(results), breakpoint_):
                results.append(next(iterator).numpy())
            cluster.restart_dispatcher()

        self.assertCountEqual(repetitions * list(range(num_elements)), results)

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherRestartBetweenIterations(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        self.assertDatasetProduces(ds, list(range(num_elements)))
        cluster.restart_dispatcher()
        self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherManyRestarts(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements_start = 10
        num_elements_end = 15
        datasets = []
        for num_elements in range(num_elements_start, num_elements_end):
            datasets.append(
                self.make_distributed_range_dataset(num_elements, cluster))
            cluster.restart_dispatcher()
        for ds, num_elements in zip(
                datasets, range(num_elements_start, num_elements_end)):
            self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherAndWorkerRestart(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)

        cluster.restart_dispatcher()
        cluster.workers[0].restart()
        self.assertDatasetProduces(ds, list(range(num_elements)))
        cluster.restart_dispatcher()
        cluster.workers[0].restart()
        self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherAndMultiWorkerRestart(self):
        num_workers = 2
        cluster = data_service_test_base.TestCluster(num_workers=num_workers)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        iterator = iter(ds)
        results = []

        cluster.restart_dispatcher()
        for worker_index in range(num_workers):
            cluster.workers[worker_index].restart()
        for elem in iterator:
            results.append(elem.numpy())
        self.assertCountEqual(num_workers * list(range(num_elements)), results)
        cluster.restart_dispatcher()
        for worker_index in range(num_workers):
            cluster.workers[worker_index].restart()
        for elem in iterator:
            results.append(elem.numpy())
        self.assertCountEqual(num_workers * list(range(num_elements)), results)

    @combinations.generate(test_base.eager_only_combinations())
    def testStartServersLate(self):
        # Test that the data service client retries, instead of failing, when
        # the dataset is created before the dispatcher and workers are started.
        try:
            import portpicker  # pylint: disable=g-import-not-at-top
            dispatcher_port = portpicker.pick_unused_port()
        except Exception:
            raise self.skipTest(
                "Flakes in portpicker library do not represent "
                "TensorFlow errors.")
        cluster = data_service_test_base.TestCluster(
            num_workers=1, dispatcher_port=dispatcher_port, start=False)

        def start_servers():
            time.sleep(0.5)
            cluster.start_dispatcher()
            cluster.start_workers()

        start_servers_thread = threading.Thread(target=start_servers,
                                                daemon=True)
        start_servers_thread.start()

        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        results = [elem.numpy() for elem in ds]
        self.assertEqual(list(range(num_elements)), results)
        start_servers_thread.join()

    @combinations.generate(test_base.eager_only_combinations())
    def testAddWorkerMidJob(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 2 * multiprocessing.cpu_count() + 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        iterator = iter(ds)
        results = []
        # Read halfway through the dataset.
        for _ in range(num_elements // 2):
            results.append(next(iterator).numpy())

        cluster.add_worker()
        # Wait for the new worker to register with the dispatcher.
        while cluster.num_registered_workers() < 2:
            time.sleep(10 / 1000)  # 10ms

        for elem in iterator:
            results.append(elem.numpy())

        self.assertCountEqual(2 * list(range(num_elements)), results)

    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           combinations.combine(use_same_port=[True, False]),
                           data_service_test_base.all_cluster_configurations())
    )
    def testRestartWorker(self, use_same_port, work_dir, fault_tolerant_mode):
        cluster = data_service_test_base.TestCluster(
            num_workers=1,
            work_dir=work_dir,
            fault_tolerant_mode=fault_tolerant_mode)
        num_elements = 2 * multiprocessing.cpu_count() + 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        iterator = iter(ds)
        # Read halfway through the dataset.
        midpoint = num_elements // 2
        for i in range(midpoint):
            self.assertEqual(i, next(iterator).numpy())

        # Stop the original worker and start a new one.
        cluster.workers[0].restart(use_same_port=use_same_port)

        # There may have been some elements prefetched from the first worker
        # before it was stopped.
        while True:
            val = next(iterator).numpy()
            if val == 0:
                break

        # The dataset starts over now that we read from the new worker.
        # TODO(b/157086991): Iterate until end of sequence when we support
        # detecting lost workers.
        for i in range(1, num_elements // 2):
            val = next(iterator).numpy()
            self.assertEqual(i, val)

    @combinations.generate(test_base.eager_only_combinations())
    def testChangeProcessingModeAfterRestart(self):
        self.skipTest("b/170910141")
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 100
        range_dataset = dataset_ops.Dataset.range(num_elements)
        ds = range_dataset.apply(
            data_service_ops.distribute(processing_mode="parallel_epochs",
                                        service=cluster.dispatcher_address(),
                                        job_name="test"))
        iterator = iter(ds)
        for i in range(num_elements // 2):
            self.assertEqual(i, next(iterator).numpy())
        cluster.restart_dispatcher()
        ds = range_dataset.apply(
            data_service_ops.distribute(processing_mode="distributed_epoch",
                                        service=cluster.dispatcher_address(),
                                        job_name="test"))
        with self.assertRaisesOpError(
                "already an existing job with that name "
                "using processing mode <parallel_epochs>"):
            next(iter(ds)).numpy()

    @combinations.generate(
        combinations.times(
            test_base.eager_only_combinations(),
            combinations.combine(work_dir=[TMP_WORK_DIR, NO_WORK_DIR])))
    def testDistributeLargeGraphThenRegisterWorker(self, work_dir):
        cluster = data_service_test_base.TestCluster(num_workers=0,
                                                     work_dir=work_dir,
                                                     fault_tolerant_mode=False)
        # Larger than default OSS grpc message size limit of 4MB.
        tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
        ds = dataset_ops.Dataset.from_tensors(tensor)
        ds = self.make_distributed_dataset(ds, cluster)
        it = iter(ds)
        cluster.add_worker()
        self.assertAllEqual(next(it), tensor)
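
For context, the TestCluster helper used above wraps the public tf.data service API. A minimal one-dispatcher, one-worker setup using the public `tf.data.experimental.service` endpoints looks roughly like the sketch below (API shape as of ~TF 2.4; an illustration, not part of the test suite):

import tensorflow as tf

# Start an in-process dispatcher and one worker.
dispatcher = tf.data.experimental.service.DispatchServer()
dispatcher_address = dispatcher.target.split("://")[1]
worker = tf.data.experimental.service.WorkerServer(
    tf.data.experimental.service.WorkerConfig(
        dispatcher_address=dispatcher_address))

# Route a dataset through the service, as make_distributed_range_dataset()
# does in the tests above.
ds = tf.data.Dataset.range(10)
ds = ds.apply(
    tf.data.experimental.service.distribute(
        processing_mode="parallel_epochs", service=dispatcher.target))
print(list(ds.as_numpy_iterator()))  # [0, 1, ..., 9]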
Example 2
class DataServiceOpsTest(data_service_test_base.TestBase,
                         parameterized.TestCase):
    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           data_service_test_base.all_cluster_configurations())
    )
    def testDistributeBasic(self, work_dir, fault_tolerant_mode):
        cluster = data_service_test_base.TestCluster(
            num_workers=1,
            work_dir=work_dir,
            fault_tolerant_mode=fault_tolerant_mode)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(compression=[None, "AUTO"])))
    def testDistributeCompression(self, compression):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements,
                                                 cluster,
                                                 compression=compression)
        self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeInvalidCompression(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        with self.assertRaisesRegex(ValueError,
                                    "Invalid compression argument"):
            self.make_distributed_range_dataset(10, cluster, compression="foo")

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeSparse(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        element = sparse_tensor.SparseTensor(indices=[[0]],
                                             values=constant_op.constant(
                                                 [0], dtype=dtypes.int32),
                                             dense_shape=[1])
        ds = dataset_ops.Dataset.from_tensors(element)
        ds = self.make_distributed_dataset(ds, cluster)
        results = [sparse_ops.sparse_tensor_to_dense(elem) for elem in ds]
        self.assertAllEqual(results, [[0]])

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeRagged(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        ds = dataset_ops.Dataset.from_tensor_slices([1, 5, 3, 2, 8])
        ds = ds.map(math_ops.range)
        ds = ds.apply(batching.dense_to_ragged_batch(2))
        ds = self.make_distributed_dataset(ds, cluster)
        results = [elem.to_tensor() for elem in ds]
        self.assertAllEqual(results[0], [[0, 0, 0, 0, 0], [0, 1, 2, 3, 4]])
        self.assertAllEqual(results[1], [[0, 1, 2], [0, 1, 0]])
        self.assertAllEqual(results[2], [[0, 1, 2, 3, 4, 5, 6, 7]])

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(
                init_source=["textfile", "keyvaluetensor", "dataset"])))
    def testDistributeLookupTable(self, init_source):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        initializer = self.lookupTableInitializer(init_source, [10, 11])
        table = lookup_ops.StaticHashTable(initializer, -1)
        ds = dataset_ops.Dataset.range(3)
        ds = ds.map(table.lookup)
        ds = self.make_distributed_dataset(ds, cluster)
        self.evaluate(lookup_ops.tables_initializer())
        self.assertDatasetProduces(ds, [10, 11, -1],
                                   requires_initialization=True)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(value_rank=[0, 1])))
    def testDistributeMutableHashTable(self, value_rank):
        def value(v):
            for _ in range(value_rank):
                v = [v, v]
            return v

        v1 = value(10)
        v2 = value(11)
        default_value = value(-1)

        cluster = data_service_test_base.TestCluster(num_workers=1)
        table = lookup_ops.MutableHashTable(dtypes.int64, dtypes.int64,
                                            default_value)
        self.evaluate(table.insert([0, 1], [v1, v2]))
        ds = dataset_ops.Dataset.range(3)
        ds = ds.map(table.lookup)
        ds = self.make_distributed_dataset(ds, cluster)
        self.assertDatasetProduces(ds, [v1, v2, default_value],
                                   requires_initialization=True)

    @combinations.generate(test_base.default_test_combinations())
    def testDifferentShuffleOrders(self):
        random_seed.set_random_seed(None)
        num_elements = 100
        cluster = data_service_test_base.TestCluster(num_workers=2)
        ds = dataset_ops.Dataset.range(num_elements)
        ds = ds.shuffle(num_elements)
        ds = self.make_distributed_dataset(ds, cluster)
        output = self.getDatasetOutput(ds)

        # The output will be two sequences of range(num_elements)
        # non-deterministically interleaved together. If the two sequences were
        # produced in the same order, the first_order and second_order maps
        # computed below would be equal.
        first_order = {}
        second_order = {}
        for element in output:
            if element in first_order:
                second_order[element] = len(second_order)
            else:
                first_order[element] = len(first_order)
        self.assertNotEqual(first_order, second_order)

    @combinations.generate(test_base.default_test_combinations())
    def testMultipleEpochs(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 3
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        for _ in range(10):
            self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(test_base.default_test_combinations())
    def testRepeatedDataset(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 10
        num_repetitions = 5
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        ds = ds.repeat(num_repetitions)
        self.assertDatasetProduces(ds,
                                   expected_output=num_repetitions *
                                   list(range(num_elements)))

    @combinations.generate(test_base.default_test_combinations())
    def testConcurrentEpoch(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 10
        num_datasets = 3
        get_nexts = []
        results = []
        for _ in range(num_datasets):
            ds = self.make_distributed_range_dataset(num_elements, cluster)
            get_nexts.append(self.getNext(ds))
            results.append([])

        for _ in range(num_elements):
            for dataset_ind in range(num_datasets):
                result = self.evaluate(get_nexts[dataset_ind]())
                results[dataset_ind].append(result)
        for result in results:
            self.assertEqual(list(range(num_elements)), result)

    @combinations.generate(test_base.default_test_combinations())
    def testMultiWorker(self):
        num_workers = 3
        cluster = data_service_test_base.TestCluster(num_workers=num_workers)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        self.assertDatasetProduces(ds,
                                   num_workers * list(range(num_elements)),
                                   assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testMaxOutstandingRequests(self):
        num_workers = 3
        cluster = data_service_test_base.TestCluster(num_workers=num_workers)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements,
                                                 cluster,
                                                 max_outstanding_requests=1)
        self.assertDatasetProduces(ds,
                                   num_workers * list(range(num_elements)),
                                   assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testInsideFunction(self):
        num_workers = 3
        cluster = data_service_test_base.TestCluster(num_workers=num_workers)
        num_elements = 10

        @def_function.function
        def f():
            ds = self.make_distributed_range_dataset(num_elements, cluster)
            result = tensor_array_ops.TensorArray(dtypes.int64,
                                                  size=num_workers *
                                                  num_elements,
                                                  dynamic_size=True)
            i = 0
            for elem in ds:
                result = result.write(i, elem)
                i += 1
            return result.stack()

        result = list(f().numpy())
        self.assertCountEqual(num_workers * list(range(num_elements)), result)

    @combinations.generate(test_base.default_test_combinations())
    def testSharedJobName(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 1000

        def make_ds():
            return dataset_ops.Dataset.range(num_elements).shuffle(
                num_elements)

        ds1 = self.make_distributed_dataset(make_ds(),
                                            cluster,
                                            job_name="job_name")
        ds2 = self.make_distributed_dataset(make_ds(),
                                            cluster,
                                            job_name="job_name")
        get_next_1 = self.getNext(ds1)
        get_next_2 = self.getNext(ds2)
        results = []
        for _ in range(num_elements // 5):
            results.append(self.evaluate(get_next_1()))
            results.append(self.evaluate(get_next_2()))
        results += self.getIteratorOutput(get_next_1)
        results += self.getIteratorOutput(get_next_2)
        self.assertCountEqual(list(range(num_elements)), results)

    @combinations.generate(test_base.default_test_combinations())
    def testDifferentJobNames(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 10
        ds1 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name1")
        ds2 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name2")
        self.assertDatasetProduces(ds1, list(range(num_elements)))
        self.assertDatasetProduces(ds2, list(range(num_elements)))

    @combinations.generate(test_base.eager_only_combinations())
    def testSharedJobNameMultiIteration(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 10
        ds1 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name")
        ds2 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name")
        # iteration 1
        self.assertDatasetProduces(ds1, list(range(num_elements)))
        self.assertDatasetProduces(ds2, [])
        # iteration 2
        self.assertDatasetProduces(ds2, list(range(num_elements)))
        self.assertDatasetProduces(ds1, [])

    @combinations.generate(test_base.default_test_combinations())
    def testSharedJobNameRepeat(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 100
        num_repetitions = 3
        ds1 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name")
        ds1 = ds1.repeat(num_repetitions)
        ds2 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name")
        ds2 = ds2.repeat(num_repetitions)
        results = []
        get_next_1 = self.getNext(ds1)
        get_next_2 = self.getNext(ds2)
        for _ in range((num_elements * num_repetitions) // 5):
            results.append(self.evaluate(get_next_1()))
        for _ in range((num_elements * num_repetitions) // 5):
            results.append(self.evaluate(get_next_2()))
        results += self.getIteratorOutput(get_next_1)
        results += self.getIteratorOutput(get_next_2)
        self.assertCountEqual(num_repetitions * list(range(num_elements)),
                              results)

    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           combinations.combine(job_name=[None, "test"])))
    def testGcUnusedJob(self, job_name):
        cluster = data_service_test_base.TestCluster(
            num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements,
                                                 cluster,
                                                 job_name=job_name)
        it = iter(ds)
        self.assertEqual(next(it).numpy(), 0)
        self.assertEqual(cluster.workers[0].num_tasks(), 1)
        del it
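        # With its only iterator deleted, the job is unused and should be
        # garbage-collected, removing the task from the worker.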
        while cluster.workers[0].num_tasks() > 0:
            time.sleep(0.1)

    @combinations.generate(test_base.eager_only_combinations())
    def testDontGcUsedJob(self):
        cluster = data_service_test_base.TestCluster(
            num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
        num_elements = 10
        it1 = iter(
            self.make_distributed_range_dataset(num_elements,
                                                cluster,
                                                job_name="test1"))
        it2 = iter(
            self.make_distributed_range_dataset(num_elements,
                                                cluster,
                                                job_name="test2"))
        it3 = iter(  # this iterator keeps the task alive. pylint: disable=unused-variable
            self.make_distributed_range_dataset(num_elements,
                                                cluster,
                                                job_name="test2"))
        self.assertEqual(cluster.workers[0].num_tasks(), 2)
        del it1
        del it2
        # Check that only the first job is gced. The second job will not be gced
        # because there is still an outstanding iterator for it.
        while cluster.workers[0].num_tasks() > 1:
            time.sleep(0.1)
        self.assertEqual(cluster.workers[0].num_tasks(), 1)

    @combinations.generate(test_base.default_test_combinations())
    def testApplyDeterminismOption(self):
        elements = list(range(10))
        cluster = data_service_test_base.TestCluster(num_workers=1)

        def dataset_fn(delay_ms):
            def interleave_fn(x):
                ds = dataset_ops.Dataset.from_tensors(x)
                if math_ops.equal(x, 0):
                    ds = ds.apply(testing.sleep(delay_ms * 1000))
                else:
                    ds = ds.apply(testing.sleep(0))
                return ds

            ds = dataset_ops.Dataset.from_tensor_slices(elements)
            ds = ds.interleave(interleave_fn,
                               cycle_length=10,
                               num_parallel_calls=10)
            opts = dataset_ops.Options()
            opts.experimental_deterministic = False
            ds = ds.with_options(opts)
            ds = self.make_distributed_dataset(ds, cluster)
            return ds

        self.checkDeterminism(dataset_fn=dataset_fn,
                              expect_determinism=False,
                              expected_elements=elements)

    def run_stateful(self, external_state_policy):
        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements).map(
            lambda _: random_ops.random_uniform(()))

        options = dataset_ops.Options()
        options.experimental_external_state_policy = external_state_policy
        ds = ds.with_options(options)

        cluster = data_service_test_base.TestCluster(num_workers=3)
        ds = self.make_distributed_dataset(ds, cluster)
        self.getDatasetOutput(ds)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(external_state_policy=[
                distribute_options.ExternalStatePolicy.IGNORE,
                distribute_options.ExternalStatePolicy.WARN
            ])))
    def testStatefulNoError(self, external_state_policy):
        self.run_stateful(external_state_policy)

    @combinations.generate(test_base.default_test_combinations())
    def testStatefulError(self):
        with self.assertRaises(errors.FailedPreconditionError):
            self.run_stateful(distribute_options.ExternalStatePolicy.FAIL)

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeFromInterleave(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        ds = dataset_ops.Dataset.range(2)

        def interleave_fn(_):
            dataset = dataset_ops.Dataset.range(2)
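            # Creating (but not consuming) a distributed dataset inside the
            # interleave function should not interfere with building `dataset`.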
            self.make_distributed_dataset(dataset, cluster)
            return dataset

        ds = ds.interleave(interleave_fn, cycle_length=2)
        self.assertDatasetProduces(ds, [0, 0, 1, 1])

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeNonStringAddresses(self):
        ds = dataset_ops.Dataset.range(10)
        with self.assertRaisesRegex(ValueError, "service must be a string"):
            ds = ds.apply(
                data_service_ops.distribute(processing_mode="parallel_epochs",
                                            service=1))

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeEmptyAddress(self):
        ds = dataset_ops.Dataset.range(10)
        with self.assertRaisesWithLiteralMatch(ValueError,
                                               "service must not be empty"):
            ds = ds.apply(
                data_service_ops.distribute(processing_mode="parallel_epochs",
                                            service=""))

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeExplicitProtocol(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        ds = dataset_ops.Dataset.range(10)
        ds = ds.apply(
            data_service_ops.distribute(processing_mode="parallel_epochs",
                                        service="grpc://" +
                                        cluster.dispatcher_address()))
        self.assertDatasetProduces(ds, list(range(10)))

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeInvalidProtocol(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        ds = dataset_ops.Dataset.range(10)
        with self.assertRaisesRegex(
                errors.NotFoundError,
                "No credentials factory has been registered for protocol grp"):
            ds = ds.apply(
                data_service_ops.distribute(processing_mode="parallel_epochs",
                                            service="grp://" +
                                            cluster.dispatcher_address()))
            self.getDatasetOutput(ds)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeInvalidProcessingMode(self):
        ds = dataset_ops.Dataset.range(10)
        with self.assertRaisesRegex(ValueError,
                                    "invalid is not a valid processing mode"):
            ds = ds.apply(
                data_service_ops.distribute(processing_mode="invalid",
                                            service="grpc://localhost:5000"))

    @combinations.generate(test_base.default_test_combinations())
    def testZipDifferentProcessingModesDatasets(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 100
        ds1 = dataset_ops.Dataset.range(num_elements)
        ds1 = self.make_distributed_dataset(
            ds1, cluster, processing_mode="distributed_epoch")
        ds2 = dataset_ops.Dataset.range(num_elements)
        ds2 = self.make_distributed_dataset(ds2,
                                            cluster,
                                            processing_mode="parallel_epochs")
        ds = dataset_ops.Dataset.zip((ds1, ds2))
        self.assertDatasetProduces(ds,
                                   list(
                                       zip(range(num_elements),
                                           range(num_elements))),
                                   assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testZipDifferentProcessingModesDatasetsSharedJobName(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 100
        ds1 = dataset_ops.Dataset.range(num_elements)
        ds1 = self.make_distributed_dataset(
            ds1,
            cluster,
            processing_mode="distributed_epoch",
            job_name="job_name")
        ds2 = dataset_ops.Dataset.range(num_elements)
        ds2 = self.make_distributed_dataset(ds2,
                                            cluster,
                                            processing_mode="parallel_epochs",
                                            job_name="job_name")
        ds = dataset_ops.Dataset.zip((ds1, ds2))
        with self.assertRaisesRegex(errors.FailedPreconditionError,
                                    "but there is already an existing job"):
            self.getDatasetOutput(ds)

    @combinations.generate(test_base.default_test_combinations())
    def testFromDatasetId(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        dataset_id = data_service_ops.register_dataset(
            cluster.dispatcher_address(), ds)
        from_dataset_id_ds = data_service_ops.from_dataset_id(
            "parallel_epochs", cluster.dispatcher_address(), dataset_id,
            ds.element_spec)
        self.assertDatasetProduces(from_dataset_id_ds,
                                   list(range(num_elements)))

    @combinations.generate(test_base.default_test_combinations())
    def testFromDatasetIdMultipleComponents(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        ds = dataset_ops.Dataset.zip({"a": (ds, ds), "b": ds})
        dataset_id = data_service_ops.register_dataset(
            cluster.dispatcher_address(), ds)
        from_dataset_id_ds = data_service_ops.from_dataset_id(
            "parallel_epochs", cluster.dispatcher_address(), dataset_id,
            ds.element_spec)
        output = self.getDatasetOutput(from_dataset_id_ds)
        for i in range(num_elements):
            self.assertEqual(i, output[i]["a"][0])
            self.assertEqual(i, output[i]["a"][1])
            self.assertEqual(i, output[i]["b"])

    @combinations.generate(test_base.default_test_combinations())
    def testFromDatasetIdWrongElementSpec(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        dataset_id = data_service_ops.register_dataset(
            cluster.dispatcher_address(), ds)
        wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
        from_dataset_id_ds = data_service_ops.from_dataset_id(
            "parallel_epochs", cluster.dispatcher_address(), dataset_id,
            wrong_spec)
        with self.assertRaisesRegex(errors.FailedPreconditionError,
                                    "Expected a tensor of type variant"):
            self.evaluate(self.getNext(from_dataset_id_ds)())

    @combinations.generate(test_base.default_test_combinations())
    def testFromDatasetIdNotRegistered(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)

        dataset_id = 0
        element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
        from_dataset_id_ds = data_service_ops.from_dataset_id(
            "parallel_epochs", cluster.dispatcher_address(), dataset_id,
            element_spec)
        with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"):
            self.evaluate(self.getNext(from_dataset_id_ds)())

    @combinations.generate(test_base.default_test_combinations())
    def testCancellation(self):
        self.skipTest("b/162521601")
        sleep_microseconds = int(1e6) * 1000

        cluster = data_service_test_base.TestCluster(num_workers=1)
        # Create a dataset which produces the first element quickly, and the second
        # element slowly. Fetching the first element triggers prefetching of the
        # second element, which we should be able to cancel.
        slow = dataset_ops.Dataset.range(1)
        slow = slow.apply(testing.sleep(sleep_microseconds))
        ds = dataset_ops.Dataset.range(1).concatenate(slow)
        ds = self.make_distributed_dataset(ds, cluster)
        ds = ds.prefetch(1)
        get_next = self.getNext(ds)
        self.assertEqual(0, self.evaluate(get_next()))
        # Without properly implemented cancellation, we will hang here while trying
        # to garbage collect the dataset iterator.

    @combinations.generate(test_base.default_test_combinations())
    def testRegisterEquivalentDatasets(self):
        ds_1 = dataset_ops.Dataset.range(10)
        ds_2 = dataset_ops.Dataset.range(10)
        cluster = data_service_test_base.TestCluster(num_workers=1)
        id_1 = data_service_ops.register_dataset(cluster.dispatcher_address(),
                                                 ds_1)
        id_2 = data_service_ops.register_dataset(cluster.dispatcher_address(),
                                                 ds_2)
        self.assertEqual(self.evaluate(id_1), self.evaluate(id_2))

    @combinations.generate(test_base.default_test_combinations())
    def testRegisterDifferentDatasets(self):
        ds_1 = dataset_ops.Dataset.range(10)
        ds_2 = dataset_ops.Dataset.range(20)
        cluster = data_service_test_base.TestCluster(num_workers=1)
        id_1 = data_service_ops.register_dataset(cluster.dispatcher_address(),
                                                 ds_1)
        id_2 = data_service_ops.register_dataset(cluster.dispatcher_address(),
                                                 ds_2)
        self.assertNotEqual(self.evaluate(id_1), self.evaluate(id_2))

    @combinations.generate(test_base.default_test_combinations())
    def testTwoLevelDistribute(self):
        cluster_1_size = 3
        cluster_1 = data_service_test_base.TestCluster(
            num_workers=cluster_1_size)
        cluster_2 = data_service_test_base.TestCluster(num_workers=1)
        num_sizes = 10
        size_repeats = 5
        strings = ["a" * i for i in range(num_sizes)] * size_repeats
        ds = dataset_ops.Dataset.from_tensor_slices(strings)
        ds = ds.shuffle(len(strings))
        ds = self.make_distributed_dataset(ds, cluster_1)
        # Large enough so that all strings of the same size are windowed together.
        window_size = cluster_1_size * size_repeats
        batch_size = size_repeats

        def key_func(x):
            return math_ops.cast(string_ops.string_length_v2(x), dtypes.int64)

        ds = ds.apply(
            grouping.group_by_window(
                key_func=key_func,
                reduce_func=lambda _, x: x.batch(batch_size),
                window_size=window_size))
        ds = self.make_distributed_dataset(ds, cluster_2)

        get_next = self.getNext(ds)
        for _ in range(num_sizes):
            element = self.evaluate(get_next())
            for _ in range(1, cluster_1_size):
                self.assertAllEqual(self.evaluate(get_next()), element)
        self.assertEmpty(self.getIteratorOutput(get_next))

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeLargeGraph(self):
        cluster = data_service_test_base.TestCluster(num_workers=1,
                                                     work_dir=NO_WORK_DIR,
                                                     fault_tolerant_mode=False)
        # Larger than default OSS grpc message size limit of 4MB.
        tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
        ds = dataset_ops.Dataset.from_tensors(tensor)
        ds = self.make_distributed_dataset(ds, cluster)
        self.assertDatasetProduces(ds, [tensor])
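
The register_dataset/from_dataset_id pair exercised above also has public counterparts under `tf.data.experimental.service`. A minimal sketch (reusing the `dispatcher` from the previous sketch, again illustrative rather than part of the tests):

import tensorflow as tf

ds = tf.data.Dataset.range(10)
# Register the dataset once with the dispatcher...
dataset_id = tf.data.experimental.service.register_dataset(
    dispatcher.target, ds)
# ...then any number of consumers can read it by id.
consumer = tf.data.experimental.service.from_dataset_id(
    processing_mode="parallel_epochs",
    service=dispatcher.target,
    dataset_id=dataset_id,
    element_spec=ds.element_spec)
print(list(consumer.as_numpy_iterator()))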
Example 3
class DataServiceOpsTest(data_service_test_base.TestBase,
                         parameterized.TestCase):
    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           data_service_test_base.all_cluster_configurations())
    )
    def testDistributeBasic(self, work_dir, fault_tolerant_mode):
        cluster = data_service_test_base.TestCluster(
            num_workers=1,
            work_dir=work_dir,
            fault_tolerant_mode=fault_tolerant_mode)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(compression=[None, "AUTO"])))
    def testDistributeCompression(self, compression):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements,
                                                 cluster,
                                                 compression=compression)
        self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(compression=[None, "AUTO"])))
    def testFromDatasetIdOmitsCompression(self, compression):
        cluster = data_service_test_base.TestCluster(
            num_workers=1, data_transfer_protocol="grpc")
        dataset = dataset_ops.Dataset.from_tensor_slices(
            list("abcdefghijklmnopqrstuvwxyz"))

        def to_upper(x):
            return script_ops.numpy_function(
                func=lambda x: x.decode("utf-8").upper(),
                inp=[x],
                Tout=dtypes.string)

        dataset = dataset.map(to_upper,
                              num_parallel_calls=dataset_ops.AUTOTUNE)
        with mock.patch.object(compat, "forward_compatible",
                               return_value=True):
            dataset_id = data_service_ops.register_dataset(
                cluster.dispatcher.target,
                dataset=dataset,
                compression=compression)
            dataset = data_service_ops.from_dataset_id(
                processing_mode=ShardingPolicy.OFF,
                service=cluster.dispatcher.target,
                dataset_id=dataset_id,
                element_spec=dataset.element_spec)
            self.assertDatasetProduces(dataset,
                                       list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))

    # Eager-only, as querying `element_spec` is only supported in eager mode.
    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           combinations.combine(compression=[None, "AUTO"])))
    def testFromDatasetIdOmitsElementSpecAndCompression(self, compression):
        cluster = data_service_test_base.TestCluster(
            num_workers=1, data_transfer_protocol="grpc")
        dataset = dataset_ops.Dataset.from_tensor_slices(
            list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
        with mock.patch.object(compat, "forward_compatible",
                               return_value=True):
            dataset_id = data_service_ops.register_dataset(
                cluster.dispatcher.target,
                dataset=dataset,
                compression=compression)
            dataset = data_service_ops.from_dataset_id(
                processing_mode=ShardingPolicy.OFF,
                service=cluster.dispatcher.target,
                dataset_id=dataset_id)
            self.assertDatasetProduces(dataset,
                                       list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))

    def _testCompressionMismatch(self, dataset):
        cluster = data_service_test_base.TestCluster(
            num_workers=1, data_transfer_protocol="grpc")
        with mock.patch.object(compat,
                               "forward_compatible",
                               return_value=False):
            dataset_id = data_service_ops._register_dataset(
                cluster.dispatcher.target, dataset=dataset, compression=None)
            # `compression` is "AUTO" by default.
            dataset = data_service_ops._from_dataset_id(
                processing_mode=ShardingPolicy.OFF,
                service=cluster.dispatcher.target,
                dataset_id=dataset_id,
                element_spec=dataset.element_spec)
            with self.assertRaises(errors.InvalidArgumentError):
                self.getDatasetOutput(dataset)

    @combinations.generate(test_base.default_test_combinations())
    def testCompressionDtypeMismatch(self):
        dataset = dataset_ops.Dataset.from_tensor_slices(
            list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
        self._testCompressionMismatch(dataset)

    @combinations.generate(test_base.default_test_combinations())
    def testCompressionShapeMismatch(self):
        dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2], [3, 4]])
        self._testCompressionMismatch(dataset)

    # Only test eager mode since nested datasets are not allowed in graph mode.
    @combinations.generate(test_base.eager_only_combinations())
    def testCompressionVariantMismatch(self):
        # Use a nested dataset as an example of a variant.
        dataset = dataset_ops.Dataset.from_tensors(
            dataset_ops.Dataset.range(10))
        self._testCompressionMismatch(dataset)

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeInvalidCompression(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        with self.assertRaisesRegex(ValueError,
                                    "Invalid `compression` argument"):
            self.make_distributed_range_dataset(10, cluster, compression="foo")

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeSparse(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        element = sparse_tensor.SparseTensor(indices=[[0]],
                                             values=constant_op.constant(
                                                 [0], dtype=dtypes.int32),
                                             dense_shape=[1])
        ds = dataset_ops.Dataset.from_tensors(element)
        ds = self.make_distributed_dataset(ds, cluster)
        results = [sparse_ops.sparse_tensor_to_dense(elem) for elem in ds]
        self.assertAllEqual(results, [[0]])

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeRagged(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        ds = dataset_ops.Dataset.from_tensor_slices([1, 5, 3, 2, 8])
        ds = ds.map(math_ops.range)
        ds = ds.apply(batching.dense_to_ragged_batch(2))
        ds = self.make_distributed_dataset(ds, cluster)
        results = [elem.to_tensor() for elem in ds]
        self.assertAllEqual(results[0], [[0, 0, 0, 0, 0], [0, 1, 2, 3, 4]])
        self.assertAllEqual(results[1], [[0, 1, 2], [0, 1, 0]])
        self.assertAllEqual(results[2], [[0, 1, 2, 3, 4, 5, 6, 7]])

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(
                init_source=["textfile", "keyvaluetensor", "dataset"])))
    def testDistributeLookupTable(self, init_source):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        initializer = self.lookupTableInitializer(init_source, [10, 11])
        table = lookup_ops.StaticHashTable(initializer, -1)
        ds = dataset_ops.Dataset.range(3)
        ds = ds.map(table.lookup)
        ds = self.make_distributed_dataset(ds, cluster)
        self.evaluate(lookup_ops.tables_initializer())
        self.assertDatasetProduces(ds, [10, 11, -1],
                                   requires_initialization=True)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(value_rank=[0, 1])))
    def testDistributeMutableHashTable(self, value_rank):
        def value(v):
            for _ in range(value_rank):
                v = [v, v]
            return v

        v1 = value(10)
        v2 = value(11)
        default_value = value(-1)

        cluster = data_service_test_base.TestCluster(num_workers=1)
        table = lookup_ops.MutableHashTable(dtypes.int64, dtypes.int64,
                                            default_value)
        self.evaluate(table.insert([0, 1], [v1, v2]))
        ds = dataset_ops.Dataset.range(3)
        ds = ds.map(table.lookup)
        ds = self.make_distributed_dataset(ds, cluster)
        self.assertDatasetProduces(ds, [v1, v2, default_value],
                                   requires_initialization=True)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(shuffle_seed=[None, 10])))
    def testShuffleOrder(self, shuffle_seed):
        random_seed.set_random_seed(None)
        num_elements = 100
        cluster = data_service_test_base.TestCluster(num_workers=2)
        ds = dataset_ops.Dataset.range(num_elements)
        ds = ds.shuffle(num_elements, seed=shuffle_seed)
        ds = self.make_distributed_dataset(ds, cluster)
        output = self.getDatasetOutput(ds)

        # The output will be two sequences of range(num_elements)
        # non-deterministically interleaved together. If the two sequences were
        # produced in the same order, the first_order and second_order maps
        # computed below would be equal.
        first_order = {}
        second_order = {}
        for element in output:
            if element in first_order:
                second_order[element] = len(second_order)
            else:
                first_order[element] = len(first_order)
        if shuffle_seed is None:
            self.assertNotEqual(first_order, second_order)
        else:
            self.assertEqual(first_order, second_order)

    @combinations.generate(test_base.default_test_combinations())
    def testMultipleEpochs(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 3
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        for _ in range(10):
            self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(test_base.default_test_combinations())
    def testRepeatedDataset(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 10
        num_repetitions = 5
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        ds = ds.repeat(num_repetitions)
        self.assertDatasetProduces(ds,
                                   expected_output=num_repetitions *
                                   list(range(num_elements)))

    @combinations.generate(test_base.default_test_combinations())
    def testConcurrentEpoch(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 10
        num_datasets = 3
        get_nexts = []
        results = []
        for _ in range(num_datasets):
            ds = self.make_distributed_range_dataset(num_elements, cluster)
            get_nexts.append(self.getNext(ds))
            results.append([])

        for _ in range(num_elements):
            for dataset_ind in range(num_datasets):
                result = self.evaluate(get_nexts[dataset_ind]())
                results[dataset_ind].append(result)
        for result in results:
            self.assertEqual(list(range(num_elements)), result)

    @combinations.generate(test_base.default_test_combinations())
    def testMultiWorker(self):
        num_workers = 3
        cluster = data_service_test_base.TestCluster(num_workers=num_workers)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        self.assertDatasetProduces(ds,
                                   num_workers * list(range(num_elements)),
                                   assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testMaxOutstandingRequests(self):
        num_workers = 3
        cluster = data_service_test_base.TestCluster(num_workers=num_workers)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements,
                                                 cluster,
                                                 max_outstanding_requests=1)
        self.assertDatasetProduces(ds,
                                   num_workers * list(range(num_elements)),
                                   assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testInsideFunction(self):
        num_workers = 3
        cluster = data_service_test_base.TestCluster(num_workers=num_workers)
        num_elements = 10

        @def_function.function
        def f():
            ds = self.make_distributed_range_dataset(num_elements, cluster)
            result = tensor_array_ops.TensorArray(dtypes.int64,
                                                  size=num_workers *
                                                  num_elements,
                                                  dynamic_size=True)
            i = 0
            for elem in ds:
                result = result.write(i, elem)
                i += 1
            return result.stack()

        result = list(f().numpy())
        self.assertCountEqual(num_workers * list(range(num_elements)), result)

    @combinations.generate(test_base.default_test_combinations())
    def testEmptyJobNameDistribute(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        with self.assertRaisesRegex(ValueError,
                                    "`job_name` must not be empty"):
            dataset_ops.Dataset.range(10).apply(
                data_service_ops.distribute(processing_mode="parallel_epochs",
                                            service=cluster.dispatcher.target,
                                            job_name=""))

    @combinations.generate(test_base.default_test_combinations())
    def testEmptyJobNameFromDatasetId(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        dataset_id = data_service_ops.register_dataset(
            cluster.dispatcher.target, dataset_ops.Dataset.range(10))
        with self.assertRaisesRegex(ValueError,
                                    "`job_name` must not be empty"):
            data_service_ops.from_dataset_id(dataset_id=dataset_id,
                                             processing_mode="parallel_epochs",
                                             service=cluster.dispatcher.target,
                                             job_name="")

    @combinations.generate(test_base.default_test_combinations())
    def testExplicitProtocolFromDatasetId(self):
        cluster = data_service_test_base.TestCluster(
            num_workers=1, data_transfer_protocol="grpc")
        range_ds = dataset_ops.Dataset.range(10)
        dataset_id = data_service_ops.register_dataset(
            cluster.dispatcher.target, range_ds)
        ds = data_service_ops.from_dataset_id(
            dataset_id=dataset_id,
            processing_mode="parallel_epochs",
            element_spec=range_ds.element_spec,
            service=cluster.dispatcher.target,
            data_transfer_protocol="grpc")
        self.assertDatasetProduces(ds, list(range(10)))

    @combinations.generate(test_base.default_test_combinations())
    def testNonStringJobNameDistribute(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        with self.assertRaisesRegex(ValueError, "`job_name` must be a string"):
            dataset_ops.Dataset.range(10).apply(
                data_service_ops.distribute(
                    processing_mode="parallel_epochs",
                    service=cluster.dispatcher.target,
                    job_name=constant_op.constant("foo")))

    @combinations.generate(test_base.default_test_combinations())
    def testNonStringJobNameFromDatasetId(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        dataset_id = data_service_ops.register_dataset(
            cluster.dispatcher.target, dataset_ops.Dataset.range(10))
        with self.assertRaisesRegex(ValueError, "`job_name` must be a string"):
            data_service_ops.from_dataset_id(
                dataset_id=dataset_id,
                processing_mode="parallel_epochs",
                service=cluster.dispatcher.target,
                job_name=constant_op.constant("foo"))

    @combinations.generate(test_base.default_test_combinations())
    def testSharedJobName(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 1000

        def make_ds():
            return dataset_ops.Dataset.range(num_elements).shuffle(
                num_elements)

        ds1 = self.make_distributed_dataset(make_ds(),
                                            cluster,
                                            job_name="job_name")
        ds2 = self.make_distributed_dataset(make_ds(),
                                            cluster,
                                            job_name="job_name")
        get_next_1 = self.getNext(ds1)
        get_next_2 = self.getNext(ds2)
        results = []
        for _ in range(num_elements // 5):
            results.append(self.evaluate(get_next_1()))
            results.append(self.evaluate(get_next_2()))
        results += self.getIteratorOutput(get_next_1)
        results += self.getIteratorOutput(get_next_2)
        self.assertCountEqual(list(range(num_elements)), results)

    @combinations.generate(test_base.default_test_combinations())
    def testDifferentJobNames(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 10
        ds1 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name1")
        ds2 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name2")
        self.assertDatasetProduces(ds1, list(range(num_elements)))
        self.assertDatasetProduces(ds2, list(range(num_elements)))

    @combinations.generate(test_base.eager_only_combinations())
    def testSharedJobNameMultiIteration(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 10
        ds1 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name")
        ds2 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name")
        # First iteration: both datasets attach to the same job, so ds1
        # drains it and ds2 produces nothing.
        self.assertDatasetProduces(ds1, list(range(num_elements)))
        self.assertDatasetProduces(ds2, [])
        # Second iteration: the job is recreated, and this time ds2 drains it
        # while ds1 produces nothing.
        self.assertDatasetProduces(ds2, list(range(num_elements)))
        self.assertDatasetProduces(ds1, [])
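
    # Consumers that pass the same `job_name` attach to a single shared job
    # and divide its output between them instead of each receiving a full
    # copy; the two tests above rely on that sharing.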

    @combinations.generate(test_base.default_test_combinations())
    def testSharedJobNameRepeat(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 100
        num_repetitions = 3
        ds1 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name")
        ds1 = ds1.repeat(num_repetitions)
        ds2 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name")
        ds2 = ds2.repeat(num_repetitions)
        results = []
        get_next_1 = self.getNext(ds1)
        get_next_2 = self.getNext(ds2)
        for _ in range((num_elements * num_repetitions) // 5):
            results.append(self.evaluate(get_next_1()))
        for _ in range((num_elements * num_repetitions) // 5):
            results.append(self.evaluate(get_next_2()))
        results += self.getIteratorOutput(get_next_1)
        results += self.getIteratorOutput(get_next_2)
        self.assertCountEqual(num_repetitions * list(range(num_elements)),
                              results)

    @combinations.generate(test_base.eager_only_combinations())
    def testSharedJobNameMultipleEpochs(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        dataset = self.make_distributed_range_dataset(10,
                                                      cluster,
                                                      job_name="job_name")

        num_epochs = 5
        for _ in range(num_epochs):
            get_next = self.getNext(dataset)
            self.assertEqual(self.getIteratorOutput(get_next), list(range(10)))

    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           combinations.combine(job_name=[None, "test"])))
    def testGcUnusedJob(self, job_name):
        cluster = data_service_test_base.TestCluster(
            num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements,
                                                 cluster,
                                                 job_name=job_name)
        it = iter(ds)
        self.assertEqual(next(it).numpy(), 0)
        self.assertEqual(cluster.workers[0].num_tasks(), 1)
        del it
        while cluster.workers[0].num_tasks() > 0:
            time.sleep(0.1)
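
    # The GC knobs above are dispatcher settings: `job_gc_check_interval_ms`
    # is how often the dispatcher scans for expired jobs, and
    # `job_gc_timeout_ms` is how long a job may go without active clients
    # before it is reclaimed. A sketch of the equivalent standalone-server
    # setup (mirroring the servers constructed later in this file):
    #
    #   dispatcher = server_lib.DispatchServer(
    #       service_config_pb2.DispatcherConfig(protocol="grpc",
    #                                           job_gc_check_interval_ms=50,
    #                                           job_gc_timeout_ms=20))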

    @combinations.generate(test_base.eager_only_combinations())
    def testDontGcUsedJob(self):
        cluster = data_service_test_base.TestCluster(
            num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
        num_elements = 10
        it1 = iter(
            self.make_distributed_range_dataset(num_elements,
                                                cluster,
                                                job_name="test1"))
        it2 = iter(
            self.make_distributed_range_dataset(num_elements,
                                                cluster,
                                                job_name="test2"))
        it3 = iter(  # This iterator keeps the task alive. pylint: disable=unused-variable
            self.make_distributed_range_dataset(num_elements,
                                                cluster,
                                                job_name="test2"))
        self.assertEqual(cluster.workers[0].num_tasks(), 2)
        del it1
        del it2
        # Check that only the first job is garbage-collected. The second job
        # will not be collected because it still has an outstanding iterator.
        while cluster.workers[0].num_tasks() > 1:
            time.sleep(0.1)
        self.assertEqual(cluster.workers[0].num_tasks(), 1)

    @combinations.generate(test_base.eager_only_combinations())
    def testGcAndRecreate(self):
        cluster = data_service_test_base.TestCluster(
            num_workers=3, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
        num_elements = 1000
        # Repeatedly create and garbage-collect the same job.
        for _ in range(3):
            ds = self.make_distributed_range_dataset(num_elements,
                                                     cluster,
                                                     job_name="test")
            it = iter(ds)
            for _ in range(50):
                next(it)
            del it
            # Wait for the task to be garbage-collected on all workers.
            while cluster.num_tasks_on_workers() > 0:
                time.sleep(0.1)

    @combinations.generate(test_base.eager_only_combinations())
    def testGcClient(self):
        dispatcher = server_lib.DispatchServer(
            service_config_pb2.DispatcherConfig(protocol="grpc",
                                                job_gc_check_interval_ms=50,
                                                job_gc_timeout_ms=20,
                                                client_timeout_ms=50))
        dispatcher_address = dispatcher.target.split("://")[1]
        _ = server_lib.WorkerServer(
            server_lib.WorkerConfig(dispatcher_address=dispatcher_address,
                                    heartbeat_interval_ms=100))

        num_elements = 1000
        dataset = dataset_ops.Dataset.range(num_elements)
        dataset = dataset.apply(
            data_service_ops._distribute(processing_mode=ShardingPolicy.OFF,
                                         service=dispatcher.target,
                                         task_refresh_interval_hint_ms=10000))
        get_next = self.getNext(dataset)

        # With a 10-second task refresh interval, the client fails to
        # heartbeat within the 50ms client timeout, so the dispatcher
        # garbage-collects it.
        with self.assertRaisesRegex(errors.NotFoundError,
                                    "Unknown job client id"):
            self.evaluate(get_next())
            time.sleep(3)
            self.getIteratorOutput(get_next)

    @combinations.generate(test_base.eager_only_combinations())
    def testKeepClientAliveBeforeReading(self):
        dispatcher = server_lib.DispatchServer(
            service_config_pb2.DispatcherConfig(protocol="grpc",
                                                job_gc_check_interval_ms=50,
                                                job_gc_timeout_ms=20,
                                                client_timeout_ms=1000))
        dispatcher_address = dispatcher.target.split("://")[1]
        _ = server_lib.WorkerServer(
            server_lib.WorkerConfig(dispatcher_address=dispatcher_address,
                                    heartbeat_interval_ms=100))

        num_elements = 1000
        dataset = dataset_ops.Dataset.range(num_elements)
        dataset = dataset.apply(
            data_service_ops._distribute(processing_mode=ShardingPolicy.OFF,
                                         service=dispatcher.target,
                                         task_refresh_interval_hint_ms=100))
        get_next = self.getNext(dataset)

        # The client heartbeats every 100 milliseconds, so it should not be
        # garbage-collected even though it does not start reading for 3
        # seconds.
        time.sleep(3)
        self.assertEqual(self.getIteratorOutput(get_next),
                         list(range(num_elements)))

    @combinations.generate(test_base.default_test_combinations())
    def testApplyDeterminismOption(self):
        elements = list(range(10))
        cluster = data_service_test_base.TestCluster(num_workers=1)

        def dataset_fn(delay_ms):
            def interleave_fn(x):
                ds = dataset_ops.Dataset.from_tensors(x)
                if math_ops.equal(x, 0):
                    ds = ds.apply(testing.sleep(delay_ms * 1000))
                else:
                    ds = ds.apply(testing.sleep(0))
                return ds

            ds = dataset_ops.Dataset.from_tensor_slices(elements)
            ds = ds.interleave(interleave_fn,
                               cycle_length=10,
                               num_parallel_calls=10)
            opts = options_lib.Options()
            opts.deterministic = False
            ds = ds.with_options(opts)
            ds = self.make_distributed_dataset(ds, cluster)
            return ds

        self.checkDeterminism(dataset_fn=dataset_fn,
                              expect_determinism=False,
                              expected_elements=elements)
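
    # With `deterministic = False`, the parallel interleave may emit whichever
    # branch is ready first, so the element delayed by `testing.sleep` can
    # arrive late; `checkDeterminism` verifies that this reordering survives
    # the round trip through the service.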

    def run_stateful(self, external_state_policy):
        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements).map(
            lambda _: random_ops.random_uniform(()))

        options = options_lib.Options()
        options.experimental_external_state_policy = external_state_policy
        ds = ds.with_options(options)

        cluster = data_service_test_base.TestCluster(num_workers=3)
        ds = self.make_distributed_dataset(ds, cluster)
        self.getDatasetOutput(ds)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(external_state_policy=[
                options_lib.ExternalStatePolicy.IGNORE,
                options_lib.ExternalStatePolicy.WARN
            ])))
    def testStatefulNoError(self, external_state_policy):
        self.run_stateful(external_state_policy)

    @combinations.generate(test_base.default_test_combinations())
    def testStatefulError(self):
        with self.assertRaises(errors.FailedPreconditionError):
            self.run_stateful(options_lib.ExternalStatePolicy.FAIL)

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeFromInterleave(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        ds = dataset_ops.Dataset.range(2)

        def interleave_fn(x):
            dataset = dataset_ops.Dataset.range(10 * x, 10 * x + 2)
            dataset = self.make_distributed_dataset(dataset, cluster)
            return dataset

        ds = ds.interleave(interleave_fn, cycle_length=2)
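        # With cycle_length=2, interleave alternates between the two
        # distributed ranges [0, 1] and [10, 11], producing 0, 10, 1, 11.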
        self.assertDatasetProduces(ds, [0, 10, 1, 11])

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeNonStringAddresses(self):
        ds = dataset_ops.Dataset.range(10)
        with self.assertRaisesRegex(ValueError, "`service` must be a string"):
            ds = ds.apply(
                data_service_ops.distribute(processing_mode="parallel_epochs",
                                            service=1))

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeEmptyAddress(self):
        ds = dataset_ops.Dataset.range(10)
        with self.assertRaisesWithLiteralMatch(ValueError,
                                               "`service` must not be empty"):
            ds = ds.apply(
                data_service_ops.distribute(processing_mode="parallel_epochs",
                                            service=""))

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeExplicitProtocol(self):
        cluster = data_service_test_base.TestCluster(
            num_workers=1, data_transfer_protocol="grpc")
        ds = dataset_ops.Dataset.range(10)
        ds = ds.apply(
            data_service_ops.distribute(processing_mode="parallel_epochs",
                                        service="grpc://" +
                                        cluster.dispatcher_address()))
        self.assertDatasetProduces(ds, list(range(10)))

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeInvalidProtocol(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        ds = dataset_ops.Dataset.range(10)
        with self.assertRaisesRegex(
                errors.NotFoundError,
                "No credentials factory has been registered for protocol grp"):
            ds = ds.apply(
                data_service_ops.distribute(processing_mode="parallel_epochs",
                                            service="grp://" +
                                            cluster.dispatcher_address()))
            self.getDatasetOutput(ds)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeInvalidProcessingMode(self):
        ds = dataset_ops.Dataset.range(10)
        with self.assertRaisesRegex(
                ValueError,
                "should be a `tf.data.experimental.service.ShardingPolicy`, "
                "`\"parallel_epochs\"`, or "
                "`\"distributed_epoch\"`. Got 'invalid'."):
            ds = ds.apply(
                data_service_ops.distribute(processing_mode="invalid",
                                            service="grpc://localhost:5000"))

    @combinations.generate(test_base.default_test_combinations())
    def testZipDifferentProcessingModesDatasets(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 100
        ds1 = dataset_ops.Dataset.range(num_elements)
        ds1 = self.make_distributed_dataset(
            ds1, cluster, processing_mode="distributed_epoch")
        ds2 = dataset_ops.Dataset.range(num_elements)
        ds2 = self.make_distributed_dataset(ds2,
                                            cluster,
                                            processing_mode="parallel_epochs")
        ds = dataset_ops.Dataset.zip((ds1, ds2))
        self.assertDatasetProduces(ds,
                                   list(
                                       zip(range(num_elements),
                                           range(num_elements))),
                                   assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testZipDifferentProcessingModesDatasetsSharedJobName(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 100
        ds1 = dataset_ops.Dataset.range(num_elements)
        ds1 = self.make_distributed_dataset(
            ds1,
            cluster,
            processing_mode="distributed_epoch",
            job_name="job_name")
        ds2 = dataset_ops.Dataset.range(num_elements)
        ds2 = self.make_distributed_dataset(ds2,
                                            cluster,
                                            processing_mode="parallel_epochs",
                                            job_name="job_name")
        ds = dataset_ops.Dataset.zip((ds1, ds2))
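        # Reusing a job name with a different processing mode conflicts with
        # the job already created for ds1, so iterating the zip should fail.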
        with self.assertRaisesRegex(errors.FailedPreconditionError,
                                    "but there is already an existing job"):
            self.getDatasetOutput(ds)

    @combinations.generate(test_base.default_test_combinations())
    def testFromDatasetId(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        dataset_id = self.register_dataset(cluster.dispatcher_address(), ds)
        from_dataset_id_ds = self.from_dataset_id("parallel_epochs", cluster,
                                                  dataset_id, ds.element_spec)
        self.assertDatasetProduces(from_dataset_id_ds,
                                   list(range(num_elements)))

    @combinations.generate(test_base.default_test_combinations())
    def testFromDatasetIdSharedJobs(self):
        cluster = data_service_test_base.TestCluster(num_workers=2)

        datasets = [
            dataset_ops.Dataset.range(20, output_type=dtypes.int32),
            dataset_ops.Dataset.from_tensor_slices(list(range(20, 40)))
        ]
        dataset_ids = []

        for ds in datasets:
            dataset_id = self.register_dataset(cluster.dispatcher_address(),
                                               ds)
            dataset_ids.append(dataset_id)

        # Read from both jobs in parallel, with 2 consumers for each job.
        data_service_datasets = []
        for _ in range(2):
            for dataset, dataset_id in zip(datasets, dataset_ids):
                ds = self.from_dataset_id("distributed_epoch",
                                          cluster,
                                          dataset_id,
                                          dataset.element_spec,
                                          job_name="shared_job")
                data_service_datasets.append(ds)
        ds = dataset_ops.Dataset.from_tensor_slices(data_service_datasets)
        ds = ds.interleave(lambda x: x,
                           cycle_length=len(data_service_datasets))

        self.assertDatasetProduces(ds,
                                   list(range(40)),
                                   assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testRegisteringDatasetAsTfFunction(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        register_func = def_function.function(self.register_dataset)
        dataset_id = register_func(
            (constant_op.constant("grpc"),
             constant_op.constant(cluster.dispatcher_address())), ds)
        from_dataset_id_ds = self.from_dataset_id("parallel_epochs", cluster,
                                                  dataset_id, ds.element_spec)
        self.assertDatasetProduces(from_dataset_id_ds,
                                   list(range(num_elements)))

    @combinations.generate(test_base.default_test_combinations())
    def testFromDatasetIdMultipleComponents(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        ds = dataset_ops.Dataset.zip({"a": (ds, ds), "b": ds})
        dataset_id = self.register_dataset(cluster.dispatcher_address(), ds)
        from_dataset_id_ds = self.from_dataset_id("parallel_epochs", cluster,
                                                  dataset_id, ds.element_spec)
        output = self.getDatasetOutput(from_dataset_id_ds)
        for i in range(num_elements):
            self.assertEqual(i, output[i]["a"][0])
            self.assertEqual(i, output[i]["a"][1])
            self.assertEqual(i, output[i]["b"])

    @combinations.generate(test_base.default_test_combinations())
    def testFromDatasetIdWrongElementSpec(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        dataset_id = self.register_dataset(cluster.dispatcher_address(), ds)
        wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
        from_dataset_id_ds = self.from_dataset_id("parallel_epochs", cluster,
                                                  dataset_id, wrong_spec)

        if data_service_test_base.TRANSFER_PROTOCOL.value:
            with self.assertRaisesRegex(errors.InvalidArgumentError,
                                        "Data type mismatch at component 0"):
                self.evaluate(self.getNext(from_dataset_id_ds)())
        else:
            with self.assertRaisesRegex(errors.FailedPreconditionError,
                                        "Expected a tensor of type variant"):
                self.evaluate(self.getNext(from_dataset_id_ds)())

    @combinations.generate(test_base.default_test_combinations())
    def testFromDatasetIdNotRegistered(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)

        dataset_id = 0
        element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
        with self.assertRaisesRegex(errors.NotFoundError,
                                    "Dataset id 0 not found"):
            from_dataset_id_ds = self.from_dataset_id("parallel_epochs",
                                                      cluster, dataset_id,
                                                      element_spec)
            self.evaluate(self.getNext(from_dataset_id_ds)())

    @combinations.generate(test_base.default_test_combinations())
    def testCancellation(self):
        self.skipTest("b/162521601")
        sleep_microseconds = int(1e6) * 1000  # 1000 seconds.

        cluster = data_service_test_base.TestCluster(num_workers=1)
        # Create a dataset which produces the first element quickly, and the second
        # element slowly. Fetching the first element triggers prefetching of the
        # second element, which we should be able to cancel.
        slow = dataset_ops.Dataset.range(1)
        slow = slow.apply(testing.sleep(sleep_microseconds))
        ds = dataset_ops.Dataset.range(1).concatenate(slow)
        ds = self.make_distributed_dataset(ds, cluster)
        ds = ds.prefetch(1)
        get_next = self.getNext(ds)
        self.assertEqual(0, self.evaluate(get_next()))
        # Without properly implemented cancellation, we will hang here while trying
        # to garbage collect the dataset iterator.

    @combinations.generate(test_base.default_test_combinations())
    def testRegisterEquivalentDatasets(self):
        ds_1 = dataset_ops.Dataset.range(10)
        ds_2 = dataset_ops.Dataset.range(10)
        cluster = data_service_test_base.TestCluster(num_workers=1)
        id_1 = self.register_dataset(cluster.dispatcher_address(), ds_1)
        id_2 = self.register_dataset(cluster.dispatcher_address(), ds_2)
        self.assertEqual(self.evaluate(id_1), self.evaluate(id_2))
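
    # The dispatcher appears to deduplicate registrations by the serialized
    # dataset graph, so the two identical ranges above map to one dataset id
    # while the differing ranges below receive distinct ids.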

    @combinations.generate(test_base.default_test_combinations())
    def testRegisterDifferentDatasets(self):
        ds_1 = dataset_ops.Dataset.range(10)
        ds_2 = dataset_ops.Dataset.range(20)
        cluster = data_service_test_base.TestCluster(num_workers=1)
        id_1 = self.register_dataset(cluster.dispatcher_address(), ds_1)
        id_2 = self.register_dataset(cluster.dispatcher_address(), ds_2)
        self.assertNotEqual(self.evaluate(id_1), self.evaluate(id_2))

    @combinations.generate(test_base.default_test_combinations())
    def testTwoLevelDistribute(self):
        cluster_1_size = 3
        cluster_1 = data_service_test_base.TestCluster(
            num_workers=cluster_1_size)
        cluster_2 = data_service_test_base.TestCluster(num_workers=1)
        num_sizes = 10
        size_repeats = 5
        strings = ["a" * i for i in range(num_sizes)] * size_repeats
        ds = dataset_ops.Dataset.from_tensor_slices(strings)
        ds = ds.shuffle(len(strings))
        ds = self.make_distributed_dataset(ds, cluster_1)
        # Large enough so that all strings of the same size are windowed together.
        window_size = cluster_1_size * size_repeats
        batch_size = size_repeats

        def key_func(x):
            return math_ops.cast(string_ops.string_length_v2(x), dtypes.int64)

        ds = ds.apply(
            grouping.group_by_window(
                key_func=key_func,
                reduce_func=lambda _, x: x.batch(batch_size),
                window_size=window_size))
        ds = self.make_distributed_dataset(ds, cluster_2)

        get_next = self.getNext(ds)
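        # Under the default parallel_epochs mode, each of cluster_1's three
        # workers produces a full copy of the shuffled strings, so every
        # same-length batch should come back cluster_1_size times.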
        for _ in range(num_sizes):
            element = self.evaluate(get_next())
            for _ in range(1, cluster_1_size):
                self.assertAllEqual(self.evaluate(get_next()), element)
        self.assertEmpty(self.getIteratorOutput(get_next))

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeLargeGraph(self):
        cluster = data_service_test_base.TestCluster(num_workers=1,
                                                     work_dir=NO_WORK_DIR,
                                                     fault_tolerant_mode=False)
        # Larger than the default OSS gRPC message size limit of 4 MB.
        tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
        ds = dataset_ops.Dataset.from_tensors(tensor)
        ds = self.make_distributed_dataset(ds, cluster)
        self.assertDatasetProduces(ds, [tensor])

    @combinations.generate(
        combinations.times(test_base.graph_only_combinations(),
                           combinations.combine(use_resource=False)) +
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(use_resource=True)))
    def testVariables(self, use_resource):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        if not use_resource:
            with variable_scope.variable_scope("foo", use_resource=False):
                v = variables.VariableV1(10, dtype=dtypes.int64)
        else:
            v = variables.Variable(10, dtype=dtypes.int64)

        ds = dataset_ops.Dataset.range(3)
        ds = ds.map(lambda x: x + v)
        ds = self.make_distributed_dataset(ds, cluster)
        self.evaluate(v.initializer)
        self.assertDatasetProduces(ds,
                                   list(range(10, 13)),
                                   requires_initialization=True)

    @combinations.generate(test_base.graph_only_combinations())
    def testElementSpecGraphMode(self):
        cluster = data_service_test_base.TestCluster(num_workers=1,
                                                     work_dir=NO_WORK_DIR,
                                                     fault_tolerant_mode=False)
        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        dataset_id = data_service_ops.register_dataset(
            cluster.dispatcher_address(), ds)
        with self.assertRaisesRegex(
                ValueError,
                "In graph mode `element_spec` must be provided manually."):
            ds = data_service_ops.from_dataset_id("parallel_epochs",
                                                  cluster.dispatcher_address(),
                                                  dataset_id)

    @combinations.generate(test_base.eager_only_combinations())
    def testFromDatasetIdDoesntRequireElementSpec(self):
        cluster = data_service_test_base.TestCluster(
            num_workers=1,
            work_dir=NO_WORK_DIR,
            fault_tolerant_mode=False,
            data_transfer_protocol="grpc")
        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)

        dataset_id = data_service_ops.register_dataset(
            cluster.dispatcher_address(), ds)
        ds = data_service_ops.from_dataset_id("parallel_epochs",
                                              cluster.dispatcher_address(),
                                              dataset_id)
        self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(test_base.eager_only_combinations())
    def testElementSpecMixedMode(self):
        cluster = data_service_test_base.TestCluster(num_workers=1,
                                                     work_dir=NO_WORK_DIR,
                                                     fault_tolerant_mode=False)
        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)

        @def_function.function
        def get_dataset_id():
            return data_service_ops.register_dataset(
                cluster.dispatcher_address(), ds)

        dataset_id = get_dataset_id()
        dataset_id_val = tensor_util.constant_value(dataset_id)

        with self.assertRaisesRegex(
                ValueError, "Failed to fetch element spec for dataset id " +
                str(dataset_id_val) + " from tf.data service. If the "
                "dataset was registered in graph mode or inside a "
                "tf.function, the `element_spec` must be specified as "
                "an argument to `from_dataset_id`."):
            ds = data_service_ops.from_dataset_id("parallel_epochs",
                                                  cluster.dispatcher_address(),
                                                  dataset_id)

    @combinations.generate(test_base.default_test_combinations())
    def testNoShardingPolicy(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        dataset = dataset_ops.Dataset.range(20)
        dataset = self.make_distributed_dataset(
            dataset, cluster=cluster, processing_mode=ShardingPolicy.OFF)
        self.assertDatasetProduces(dataset, list(range(20)))

    @combinations.generate(test_base.default_test_combinations())
    def testCardinality(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        dataset = self.make_distributed_range_dataset(10, cluster)
        self.assertEqual(self.evaluate(dataset.cardinality()),
                         dataset_ops.UNKNOWN)
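
    # Cardinality is reported as unknown because the number of elements
    # delivered depends on how the service processes the dataset (for
    # example, how many workers each produce a copy), which the client
    # cannot know statically.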