# Example 1
class DataServiceOpsTest(data_service_test_base.TestBase,
                         parameterized.TestCase):
    """Fault-tolerance tests for the tf.data service.

    These tests stop, restart, and add dispatchers/workers while clients are
    reading, and verify that reads either continue transparently or resume
    from a well-defined state.
    """

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherStop(self):
        """Workers keep serving a running job after the dispatcher dies."""
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        iterator = iter(ds)
        results = []
        results.append(next(iterator).numpy())
        cluster.stop_dispatcher()
        # After the dispatcher dies, the worker should continue providing the rest
        # of the dataset's elements.
        for _ in range(num_elements - 1):
            results.append(next(iterator).numpy())
        self.assertEqual(results, list(range(num_elements)))

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherRestartBeforeReading(self):
        """A dataset created before a dispatcher restart is still readable."""
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        cluster.restart_dispatcher()

        self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherRestartDuringReading(self):
        """A dispatcher restart mid-read loses or duplicates no elements."""
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        iterator = iter(ds)
        results = []
        for _ in range(num_elements // 2):
            results.append(next(iterator).numpy())
        cluster.restart_dispatcher()
        for elem in iterator:
            results.append(elem.numpy())

        self.assertEqual(list(range(num_elements)), results)

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherRestartDuringDistributedEpoch(self):
        """Dispatcher restart mid-read under "distributed_epoch" mode."""
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        ds = self.make_distributed_range_dataset(
            num_elements, cluster, processing_mode="distributed_epoch")
        iterator = iter(ds)
        results = []
        for _ in range(num_elements // 2):
            results.append(next(iterator).numpy())
        cluster.restart_dispatcher()
        for elem in iterator:
            results.append(elem.numpy())

        self.assertEqual(list(range(num_elements)), results)

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherRestartDuringDistributedEpochRepeat(self):
        """Multiple dispatcher restarts during a repeated distributed epoch."""
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        repetitions = 5
        breakpoints = [50, 250, 450, 500]
        ds = dataset_ops.Dataset.range(num_elements)
        ds = ds.repeat(repetitions)
        ds = self.make_distributed_dataset(ds,
                                           cluster,
                                           processing_mode="distributed_epoch")

        iterator = iter(ds)
        results = []
        # Loop variable renamed from `breakpoint`, which shadowed the builtin
        # breakpoint() (PEP 553).
        for stop in breakpoints:
            for _ in range(len(results), stop):
                results.append(next(iterator).numpy())
            cluster.restart_dispatcher()

        self.assertCountEqual(repetitions * list(range(num_elements)), results)

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherRestartBetweenIterations(self):
        """A dataset can be re-iterated after a dispatcher restart."""
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        # Pass `num_elements` (was a duplicated literal 100) so the dataset
        # size and the expected output cannot drift apart.
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        self.assertDatasetProduces(ds, list(range(num_elements)))
        cluster.restart_dispatcher()
        self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherManyRestarts(self):
        """Datasets registered across many restarts all remain readable."""
        cluster = self.create_cluster(num_workers=1)
        num_elements_start = 10
        num_elements_end = 15
        datasets = []
        for num_elements in range(num_elements_start, num_elements_end):
            datasets.append(
                self.make_distributed_range_dataset(num_elements, cluster))
            cluster.restart_dispatcher()
        for ds, num_elements in zip(
                datasets, range(num_elements_start, num_elements_end)):
            self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherAndWorkerRestart(self):
        """Restarting both dispatcher and worker between reads is safe."""
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)

        cluster.restart_dispatcher()
        cluster.restart_worker()
        self.assertDatasetProduces(ds, list(range(num_elements)))
        cluster.restart_dispatcher()
        cluster.restart_worker()
        self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(workers_to_add=[1, 3])))
    def testRoundRobinAddWorkers(self, workers_to_add):
        """Round-robin reads stay consistent while workers join the cluster."""
        starting_workers = 3
        cluster = self.create_cluster(num_workers=starting_workers)
        # Round robin reads can cause slow cluster shutdown.
        data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
        num_consumers = 7
        ds = dataset_ops.Dataset.range(100000000)
        ds = ds.repeat()
        consumers = []
        for consumer_index in range(num_consumers):
            consumers.append(
                self.make_distributed_dataset(ds,
                                              cluster,
                                              job_name="test",
                                              consumer_index=consumer_index,
                                              num_consumers=num_consumers))
        # Use parallel interleave to read from consumers in parallel.
        ds = dataset_ops.Dataset.from_tensor_slices(consumers)
        ds = ds.interleave(lambda x: x,
                           cycle_length=num_consumers,
                           num_parallel_calls=num_consumers)

        get_next = self.getNext(ds, requires_initialization=True)
        results = []
        zeros_seen = 0
        for _ in range(50):
            results.append(self.evaluate(get_next()))
            if results[-1] == 0:
                zeros_seen += 1

        for _ in range(workers_to_add):
            cluster.add_worker()

        # Read until all new workers have joined.
        while zeros_seen < starting_workers + workers_to_add:
            results.append(self.evaluate(get_next()))
            if results[-1] == 0:
                zeros_seen += 1
        # Read some more.
        for _ in range(100):
            results.append(self.evaluate(get_next()))

        for i in range(0, len(results), num_consumers):
            self.assertEqual(0, results[i] % num_consumers)
            # Check that each group of `num_consumers` results are consecutive.
            for offset in range(1, num_consumers):
                if i + offset < len(results):
                    self.assertEqual(results[i] + offset, results[i + offset])

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherAndMultiWorkerRestart(self):
        """Restarting the dispatcher and every worker mid-read is safe."""
        num_workers = 2
        cluster = self.create_cluster(num_workers=num_workers)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        iterator = iter(ds)
        results = []

        cluster.restart_dispatcher()
        for worker_index in range(num_workers):
            cluster.restart_worker(worker_index=worker_index)
        for elem in iterator:
            results.append(elem.numpy())
        self.assertCountEqual(num_workers * list(range(num_elements)), results)
        cluster.restart_dispatcher()
        for worker_index in range(num_workers):
            cluster.restart_worker(worker_index=worker_index)
        for elem in iterator:
            results.append(elem.numpy())
        self.assertCountEqual(num_workers * list(range(num_elements)), results)

    @combinations.generate(test_base.eager_only_combinations())
    def testStartServersLate(self):
        # Test that the data service client performs retries instead of failing when
        # the dataset is created before the master and worker are started.
        try:
            import portpicker  # pylint: disable=g-import-not-at-top
            dispatcher_port = portpicker.pick_unused_port()
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate; anything else is treated as a portpicker flake.
            raise self.skipTest(
                "Flakes in portpicker library do not represent "
                "TensorFlow errors.")
        cluster = self.create_cluster(num_workers=1,
                                      dispatcher_port=dispatcher_port,
                                      start=False)

        def start_servers():
            time.sleep(0.5)
            cluster.start_dispatcher()
            cluster.start_workers()

        start_servers_thread = threading.Thread(target=start_servers,
                                                daemon=True)
        start_servers_thread.start()

        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        results = [elem.numpy() for elem in ds]
        self.assertEqual(list(range(num_elements)), results)
        start_servers_thread.join()

    @combinations.generate(test_base.eager_only_combinations())
    def testAddWorkerMidJob(self):
        """A worker added mid-job contributes a full copy of the dataset."""
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        iterator = iter(ds)
        results = []
        # Read halfway through the dataset.
        for _ in range(num_elements // 2):
            results.append(next(iterator).numpy())

        cluster.add_worker()
        # Wait for the new worker to register with the dispatcher.
        while cluster.num_registered_workers() < 2:
            time.sleep(10 / 1000)  # 10ms

        for elem in iterator:
            results.append(elem.numpy())

        self.assertCountEqual(2 * list(range(num_elements)), results)

    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           combinations.combine(use_same_port=[True, False]),
                           data_service_test_base.all_cluster_configurations())
    )
    def testRestartWorker(self, use_same_port, work_dir, fault_tolerant_mode):
        """After a worker restart, reading restarts from element 0."""
        cluster = self.create_cluster(num_workers=1,
                                      work_dir=work_dir,
                                      fault_tolerant_mode=fault_tolerant_mode)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        iterator = iter(ds)
        # Read halfway through the dataset.
        midpoint = num_elements // 2
        for i in range(midpoint):
            self.assertEqual(i, next(iterator).numpy())

        # Stop the original worker and start a new one.
        cluster.restart_worker(use_same_port=use_same_port)

        # There may have been some elements prefetched from the first worker
        # before it was stopped.
        while True:
            val = next(iterator).numpy()
            if val == 0:
                break

        # The dataset starts over now that we read from the new worker.
        # TODO(b/157086991): Iterate until end of sequence when we support
        # detecting lost workers.
        for i in range(1, num_elements // 2):
            val = next(iterator).numpy()
            self.assertEqual(i, val)

    @combinations.generate(test_base.eager_only_combinations())
    def testChangeProcessingModeAfterRestart(self):
        """Reusing a job name with a different processing mode is rejected."""
        self.skipTest("b/170910141")
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        range_dataset = dataset_ops.Dataset.range(num_elements)
        ds = range_dataset.apply(
            data_service_ops.distribute(processing_mode="parallel_epochs",
                                        service=cluster.target,
                                        job_name="test"))
        iterator = iter(ds)
        for i in range(num_elements // 2):
            self.assertEqual(i, next(iterator).numpy())
        cluster.restart_dispatcher()
        ds = range_dataset.apply(
            data_service_ops.distribute(processing_mode="distributed_epoch",
                                        service=cluster.target,
                                        job_name="test"))
        with self.assertRaisesOpError(
                "already an existing job with that name "
                "using processing mode <parallel_epochs>"):
            next(iter(ds)).numpy()

    @combinations.generate(
        combinations.times(
            test_base.eager_only_combinations(),
            combinations.combine(work_dir=[TMP_WORK_DIR, NO_WORK_DIR])))
    def testDistributeLargeGraphThenRegisterWorker(self, work_dir):
        """A graph above the 4MB gRPC limit still reaches a late worker."""
        cluster = self.create_cluster(num_workers=0,
                                      work_dir=work_dir,
                                      fault_tolerant_mode=False)
        # Larger than default OSS grpc message size limit of 4MB.
        tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
        ds = dataset_ops.Dataset.from_tensors(tensor)
        ds = self.make_distributed_dataset(ds, cluster)
        it = iter(ds)
        cluster.add_worker()
        self.assertAllEqual(next(it), tensor)
class DataServiceOpsTest(data_service_test_base.TestBase,
                         parameterized.TestCase):

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         data_service_test_base.all_cluster_configurations()))
  def testDistributeBasic(self, work_dir, fault_tolerant_mode):
    """Basic distributed read across all cluster configurations."""
    cluster = self.create_cluster(
        num_workers=1,
        work_dir=work_dir,
        fault_tolerant_mode=fault_tolerant_mode)
    num_elements = 10
    # Use `num_elements` (was a duplicated literal 10) so the dataset size and
    # the expected output cannot drift apart.
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    results = [elem.numpy() for elem in ds]
    self.assertEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeSparse(self):
    """Sparse tensors survive a round trip through the data service."""
    cluster = self.create_cluster(num_workers=1)
    sparse_element = sparse_tensor.SparseTensor(
        indices=[[0]],
        values=constant_op.constant([0], dtype=dtypes.int32),
        dense_shape=[1])
    distributed = self.make_distributed_dataset(
        dataset_ops.Dataset.from_tensors(sparse_element), cluster)
    dense_results = [
        sparse_ops.sparse_tensor_to_dense(elem) for elem in distributed
    ]
    self.assertAllEqual(dense_results, [[0]])

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeRagged(self):
    """Ragged batches survive a round trip through the data service."""
    cluster = self.create_cluster(num_workers=1)
    lengths = [1, 5, 3, 2, 8]
    ranges = dataset_ops.Dataset.from_tensor_slices(lengths).map(
        math_ops.range)
    batched = ranges.apply(batching.dense_to_ragged_batch(2))
    distributed = self.make_distributed_dataset(batched, cluster)
    dense_batches = [elem.to_tensor() for elem in distributed]
    expected_batches = [
        [[0, 0, 0, 0, 0], [0, 1, 2, 3, 4]],
        [[0, 1, 2], [0, 1, 0]],
        [[0, 1, 2, 3, 4, 5, 6, 7]],
    ]
    for got, want in zip(dense_batches, expected_batches):
      self.assertAllEqual(got, want)

  @combinations.generate(test_base.eager_only_combinations())
  def testDifferentShuffleOrders(self):
    """With two workers, the two shuffled copies should differ in order."""
    random_seed.set_random_seed(None)
    num_elements = 100
    cluster = self.create_cluster(num_workers=2)
    shuffled = dataset_ops.Dataset.range(num_elements).shuffle(num_elements)
    distributed = self.make_distributed_dataset(shuffled, cluster)
    output = [elem.numpy() for elem in distributed]

    # The output will be two sequences of range(num_elements)
    # non-deterministically interleaved together. Record the position at which
    # each value is seen the first time and the second time; if the two shuffle
    # orders were the same, the two mappings would be equal.
    first_seen = {}
    second_seen = {}
    for value in output:
      if value in first_seen:
        second_seen[value] = len(second_seen)
      else:
        first_seen[value] = len(first_seen)
    self.assertNotEqual(first_seen, second_seen)

  @combinations.generate(test_base.eager_only_combinations())
  def testMultipleEpochs(self):
    """A distributed dataset can be iterated many times over."""
    cluster = self.create_cluster(num_workers=1)
    num_elements = 3
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    expected = list(range(num_elements))
    for _ in range(10):
      self.assertEqual(expected, [elem.numpy() for elem in ds])

  @combinations.generate(test_base.eager_only_combinations())
  def testRepeatedDataset(self):
    """`repeat` applied on top of a distributed dataset works as usual."""
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    num_repetitions = 5
    distributed = self.make_distributed_range_dataset(num_elements, cluster)
    repeated = distributed.repeat(num_repetitions)
    expected = num_repetitions * list(range(num_elements))
    self.assertDatasetProduces(repeated, expected_output=expected)

  @combinations.generate(test_base.eager_only_combinations())
  def testConcurrentEpoch(self):
    """Several jobs can be read concurrently from one cluster."""
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    num_datasets = 3
    iterators = []
    results = []
    for _ in range(num_datasets):
      ds = self.make_distributed_range_dataset(num_elements, cluster)
      iterators.append(iter(ds))
      results.append([])

    # Interleave reads across the jobs, one element at a time.
    for _ in range(num_elements):
      for it, collected in zip(iterators, results):
        collected.append(next(it).numpy())
    expected = list(range(num_elements))
    for collected in results:
      self.assertEqual(expected, collected)

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedEpoch(self):
    """Multiple iterators over one dataset should share a single epoch."""
    self.skipTest("Not yet implemented")
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    num_iterators = 3
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    result = []
    iterators = [iter(ds) for _ in range(num_iterators)]

    # Alternate reading between the iterators.
    for _ in range(2):
      for it in iterators:
        result.append(next(it).numpy())

    # Drain the rest of the elements.
    for it in iterators:
      result.extend(elem.numpy() for elem in it)

    self.assertCountEqual(list(range(num_elements)), result)

  @combinations.generate(test_base.eager_only_combinations())
  def testMultiWorker(self):
    """Each worker produces its own full copy of the dataset."""
    num_workers = 3
    cluster = self.create_cluster(num_workers=num_workers)
    num_elements = 10
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    expected = num_workers * list(range(num_elements))
    self.assertCountEqual(expected, [elem.numpy() for elem in ds])

  @combinations.generate(test_base.eager_only_combinations())
  def testMaxOutstandingRequests(self):
    """Reading works under the most restrictive outstanding-request limit."""
    num_workers = 3
    cluster = self.create_cluster(num_workers=num_workers)
    num_elements = 10
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, max_outstanding_requests=1)
    expected = num_workers * list(range(num_elements))
    self.assertCountEqual(expected, self.getDatasetOutput(ds))

  @combinations.generate(test_base.eager_only_combinations())
  def testInsideFunction(self):
    """A distributed dataset can be consumed inside a tf.function."""
    num_workers = 3
    cluster = self.create_cluster(num_workers=num_workers)
    num_elements = 10

    @def_function.function
    def f():
      ds = self.make_distributed_range_dataset(num_elements, cluster)
      # Accumulate into a TensorArray: a Python list cannot be mutated inside
      # a traced tf.function loop.
      result = tensor_array_ops.TensorArray(
          dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
      i = 0
      for elem in ds:
        result = result.write(i, elem)
        i += 1
      return result.stack()

    result = list(f().numpy())
    self.assertCountEqual(num_workers * list(range(num_elements)), result)

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobName(self):
    """Two datasets with the same job name split one stream of elements."""
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100

    def make_ds():
      return dataset_ops.Dataset.range(num_elements).shuffle(num_elements)

    ds1 = self.make_distributed_dataset(make_ds(), cluster, job_name="job_name")
    ds2 = self.make_distributed_dataset(make_ds(), cluster, job_name="job_name")
    iter1 = iter(ds1)
    iter2 = iter(ds2)
    results = []
    # Alternate reads between the two consumers for a while...
    for _ in range(num_elements // 5):
      results.append(next(iter1).numpy())
      results.append(next(iter2).numpy())
    # ...then drain whatever remains from each.
    for it in (iter1, iter2):
      results.extend(elem.numpy() for elem in it)
    self.assertCountEqual(list(range(num_elements)), results)

  @combinations.generate(test_base.eager_only_combinations())
  def testDifferentJobNames(self):
    """Distinct job names produce independent full copies of the data."""
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name1")
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name2")
    expected = list(range(num_elements))
    self.assertDatasetProduces(ds1, expected)
    self.assertDatasetProduces(ds2, expected)

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobNameMultiIteration(self):
    """With a shared job, one full iteration drains the job for the other."""
    cluster = self.create_cluster(num_workers=1)
    num_elements = 10
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    expected = list(range(num_elements))
    # First iteration: ds1 consumes everything, leaving nothing for ds2.
    self.assertDatasetProduces(ds1, expected)
    self.assertDatasetProduces(ds2, [])
    # Second iteration: roles reversed.
    self.assertDatasetProduces(ds2, expected)
    self.assertDatasetProduces(ds1, [])

  @combinations.generate(test_base.eager_only_combinations())
  def testSharedJobNameRepeat(self):
    """Repeated datasets sharing a job name split the repeated stream."""
    cluster = self.create_cluster(num_workers=1)
    num_elements = 100
    num_repetitions = 3
    total = num_elements * num_repetitions
    ds1 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    ds1 = ds1.repeat(num_repetitions)
    ds2 = self.make_distributed_range_dataset(
        num_elements, cluster, job_name="job_name")
    ds2 = ds2.repeat(num_repetitions)
    results = []
    iter1 = iter(ds1)
    iter2 = iter(ds2)
    # Read a fifth of the total from each consumer, then drain both.
    for _ in range(total // 5):
      results.append(next(iter1).numpy())
    for _ in range(total // 5):
      results.append(next(iter2).numpy())
    for it in (iter1, iter2):
      results.extend(elem.numpy() for elem in it)
    self.assertCountEqual(num_repetitions * list(range(num_elements)), results)

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations(),
                         combinations.combine(job_name=[None, "test"])))
  def testGcUnusedJob(self, job_name):
    """A job whose only iterator is dropped gets garbage-collected."""
    cluster = self.create_cluster(
        num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
    num_elements = 100
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, job_name=job_name)
    it = iter(ds)
    self.assertEqual(next(it).numpy(), 0)
    self.assertEqual(cluster.num_tasks_on_worker(), 1)
    # Dropping the only iterator releases the job; poll until the worker's
    # task has been cleaned up.
    del it
    while cluster.num_tasks_on_worker() > 0:
      time.sleep(0.1)

  @combinations.generate(test_base.eager_only_combinations())
  def testDontGcUsedJob(self):
    """A job with a live iterator must not be garbage-collected."""
    cluster = self.create_cluster(
        num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
    num_elements = 10
    it1 = iter(
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test1"))
    it2 = iter(
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test2"))
    it3 = iter(  # this iterator keeps the task alive. pylint: disable=unused-variable
        self.make_distributed_range_dataset(
            num_elements, cluster, job_name="test2"))
    self.assertEqual(2, cluster.num_tasks_on_worker())
    del it1
    del it2
    # Check that only the first job is gced. The second job will not be gced
    # because there is still an outstanding iterator for it.
    while cluster.num_tasks_on_worker() > 1:
      time.sleep(0.1)
    self.assertEqual(1, cluster.num_tasks_on_worker())

  @combinations.generate(test_base.eager_only_combinations())
  def testApplyDeterminismOption(self):
    """experimental_deterministic=False is honored through distribution."""
    elements = list(range(10))
    cluster = self.create_cluster(num_workers=1)

    def dataset_fn(delay_ms):
      # Builds a distributed dataset in which element 0 is delayed by
      # `delay_ms`, so a non-deterministic interleave may emit it out of order.

      def interleave_fn(x):
        ds = dataset_ops.Dataset.from_tensors(x)
        if math_ops.equal(x, 0):
          ds = ds.apply(testing.sleep(delay_ms * 1000))
        else:
          ds = ds.apply(testing.sleep(0))
        return ds

      ds = dataset_ops.Dataset.from_tensor_slices(elements)
      ds = ds.interleave(interleave_fn, cycle_length=10, num_parallel_calls=10)
      opts = dataset_ops.Options()
      opts.experimental_deterministic = False
      ds = ds.with_options(opts)
      ds = self.make_distributed_dataset(ds, cluster)
      return ds

    self.checkDeterminism(
        dataset_fn=dataset_fn,
        expect_determinism=False,
        expected_elements=elements)

  def run_stateful(self, external_state_policy):
    """Distributes a stateful (random-valued) dataset under the given policy."""
    num_elements = 10
    stateful_ds = dataset_ops.Dataset.range(num_elements).map(
        lambda _: random_ops.random_uniform(()))

    options = dataset_ops.Options()
    options.experimental_external_state_policy = external_state_policy
    stateful_ds = stateful_ds.with_options(options)

    cluster = self.create_cluster(num_workers=3)
    distributed = self.make_distributed_dataset(stateful_ds, cluster)
    # Pull a single element; policies that forbid state will raise here.
    next(iter(distributed))

  @combinations.generate(
      combinations.times(
          test_base.eager_only_combinations(),
          combinations.combine(external_state_policy=[
              distribute_options.ExternalStatePolicy.IGNORE,
              distribute_options.ExternalStatePolicy.WARN
          ])))
  def testStatefulNoError(self, external_state_policy):
    """IGNORE and WARN policies allow distributing a stateful dataset."""
    self.run_stateful(external_state_policy)

  @combinations.generate(test_base.eager_only_combinations())
  def testStatefulError(self):
    """The FAIL policy rejects distributing a stateful dataset."""
    with self.assertRaises(errors.FailedPreconditionError):
      self.run_stateful(distribute_options.ExternalStatePolicy.FAIL)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochTensorSlices(self):
    """distributed_epoch mode splits a tensor-slices dataset across workers."""
    cluster = self.create_cluster(num_workers=2)
    vals = [5, 1, 2, 4]
    slices = dataset_ops.Dataset.from_tensor_slices(vals)
    distributed = self.make_distributed_dataset(
        slices, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(distributed, vals, assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochInterleave(self):
    """distributed_epoch mode supports datasets built with interleave."""
    cluster = self.create_cluster(num_workers=2)
    elements = [1, 5, 0]
    base = dataset_ops.Dataset.from_tensor_slices(elements)
    interleaved = base.interleave(
        lambda x: dataset_ops.Dataset.from_tensor_slices([x]))
    distributed = self.make_distributed_dataset(
        interleaved, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(distributed, elements, assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochParallelInterleave(self):
    """distributed_epoch mode supports parallel (autotuned) interleave."""
    cluster = self.create_cluster(num_workers=2)
    elements = [1, 5, 0]
    base = dataset_ops.Dataset.from_tensor_slices(elements)
    interleaved = base.interleave(
        lambda x: dataset_ops.Dataset.from_tensor_slices([x]),
        num_parallel_calls=dataset_ops.AUTOTUNE)
    distributed = self.make_distributed_dataset(
        interleaved, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(distributed, elements, assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochFlatMap(self):
    """distributed_epoch mode supports datasets built with flat_map."""
    cluster = self.create_cluster(num_workers=2)
    elements = [1, 5, 0]
    base = dataset_ops.Dataset.from_tensor_slices(elements)
    flattened = base.flat_map(
        lambda x: dataset_ops.Dataset.from_tensor_slices([x]))
    distributed = self.make_distributed_dataset(
        flattened, cluster, processing_mode="distributed_epoch")
    self.assertDatasetProduces(distributed, elements, assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochRepeat(self):
    """distributed_epoch mode handles repeated datasets."""
    cluster = self.create_cluster(num_workers=2)
    num_repeats = 5
    num_elements = 20
    repeated = dataset_ops.Dataset.range(num_elements).repeat(num_repeats)
    distributed = self.make_distributed_dataset(
        repeated, cluster, processing_mode="distributed_epoch")
    expected = num_repeats * list(range(num_elements))
    self.assertDatasetProduces(distributed, expected, assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpochShuffleAndRepeat(self):
    """distributed_epoch mode handles shuffled and repeated datasets."""
    cluster = self.create_cluster(num_workers=2)
    num_repeats = 5
    num_elements = 20
    shuffled_repeated = dataset_ops.Dataset.range(num_elements).shuffle(
        num_elements).repeat(num_repeats)
    distributed = self.make_distributed_dataset(
        shuffled_repeated, cluster, processing_mode="distributed_epoch")
    expected = num_repeats * list(range(num_elements))
    self.assertDatasetProduces(distributed, expected, assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeFromInterleave(self):
    """Creating a distributed dataset inside an interleave fn is harmless.

    Every other test in this class carries a `@combinations.generate`
    decorator; this one was missing it, so it would not run under the
    parameterized combinations framework like its siblings. The decorator is
    added for consistency.
    """
    cluster = self.create_cluster(num_workers=1)
    ds = dataset_ops.Dataset.range(2)

    def interleave_fn(_):
      dataset = dataset_ops.Dataset.range(2)
      # The distributed dataset is built but intentionally discarded; the test
      # only checks that constructing it inside interleave does not break the
      # returned (non-distributed) dataset.
      self.make_distributed_dataset(dataset, cluster)
      return dataset

    ds = ds.interleave(interleave_fn, cycle_length=2)
    self.assertDatasetProduces(ds, [0, 0, 1, 1])

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeDistributedEpoch(self):
    """distributed_epoch mode produces each element exactly once overall."""
    cluster = self.create_cluster(num_workers=2)
    num_elements = 100
    distributed = self.make_distributed_dataset(
        dataset_ops.Dataset.range(num_elements),
        cluster,
        processing_mode="distributed_epoch")
    self.assertDatasetProduces(
        distributed, list(range(num_elements)), assert_items_equal=True)

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeNonStringAddresses(self):
    """A non-string `service` argument is rejected with a ValueError."""
    dataset = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(ValueError, "service must be a string"):
      dataset.apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs", service=1))

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeEmptyAddress(self):
    """An empty `service` argument is rejected with a ValueError."""
    dataset = dataset_ops.Dataset.range(10)
    with self.assertRaisesWithLiteralMatch(ValueError,
                                           "service must not be empty"):
      dataset.apply(
          data_service_ops.distribute(
              processing_mode="parallel_epochs", service=""))

  @combinations.generate(test_base.eager_only_combinations())
  def testDistributeInvalidProcessingMode(self):
    """An unknown processing mode is rejected with a ValueError."""
    dataset = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegex(ValueError,
                                "invalid is not a valid processing mode"):
      dataset.apply(
          data_service_ops.distribute(
              processing_mode="invalid", service="grpc://localhost:5000"))

  @combinations.generate(test_base.eager_only_combinations())
  def testFromDatasetId(self):
    """A registered dataset can be read back via its dataset id."""
    cluster = self.create_cluster(num_workers=1)

    num_elements = 10
    registered = dataset_ops.Dataset.range(num_elements)
    dataset_id = data_service_ops.register_dataset(cluster.target, registered)
    readback = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.target, dataset_id, registered.element_spec)
    self.assertDatasetProduces(readback, list(range(num_elements)))

  @combinations.generate(test_base.eager_only_combinations())
  def testFromDatasetIdMultipleComponents(self):
    """from_dataset_id preserves nested element structure."""
    cluster = self.create_cluster(num_workers=1)

    num_elements = 10
    base = dataset_ops.Dataset.range(num_elements)
    zipped = dataset_ops.Dataset.zip({"a": (base, base), "b": base})
    dataset_id = data_service_ops.register_dataset(cluster.target, zipped)
    readback = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.target, dataset_id, zipped.element_spec)
    output = self.getDatasetOutput(readback)
    for i in range(num_elements):
      # Each element mirrors the zipped structure: {"a": (i, i), "b": i}.
      self.assertEqual(i, output[i]["a"][0])
      self.assertEqual(i, output[i]["a"][1])
      self.assertEqual(i, output[i]["b"])

  @combinations.generate(test_base.eager_only_combinations())
  def testFromDatasetIdWrongElementSpec(self):
    """Reading with a mismatched element_spec fails at iteration time."""
    cluster = self.create_cluster(num_workers=1)
    element_count = 10
    dataset = dataset_ops.Dataset.range(element_count)
    registered_id = data_service_ops.register_dataset(cluster.target, dataset)
    # Deliberately claim the wrong dtype for the registered dataset.
    wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
    result_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.target, registered_id, wrong_spec)
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "Expected a tensor of type variant"):
      self.evaluate(self.getNext(result_ds)())

  @combinations.generate(test_base.eager_only_combinations())
  def testFromDatasetIdNotRegistered(self):
    """Reading a dataset id that was never registered raises NotFoundError."""
    cluster = self.create_cluster(num_workers=1)
    unregistered_id = 0
    element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
    result_ds = data_service_ops.from_dataset_id(
        "parallel_epochs", cluster.target, unregistered_id, element_spec)
    with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"):
      self.evaluate(self.getNext(result_ds)())

  @combinations.generate(test_base.default_test_combinations())
  def testCancellation(self):
    """Checks that an in-flight prefetch request can be cancelled.

    Currently skipped pending b/162521601.
    """
    self.skipTest("b/162521601")
    sleep_microseconds = int(1e6) * 1000

    cluster = self.create_cluster(num_workers=1)
    # Create a dataset which produces the first element quickly, and the second
    # element slowly. Fetching the first element triggers prefetching of the
    # second element, which we should be able to cancel.
    slow = dataset_ops.Dataset.range(1)
    slow = slow.apply(testing.sleep(sleep_microseconds))
    ds = dataset_ops.Dataset.range(1).concatenate(slow)
    ds = self.make_distributed_dataset(ds, cluster)
    ds = ds.prefetch(1)
    get_next = self.getNext(ds, requires_initialization=True)
    self.assertEqual(0, self.evaluate(get_next()))
    # Without properly implemented cancellation, we will hang here while trying
    # to garbage collect the dataset iterator.

  @combinations.generate(test_base.eager_only_combinations())
  def testRegisterEquivalentDatasets(self):
    """Registering two structurally identical datasets yields the same id."""
    cluster = self.create_cluster(num_workers=1)
    first_id = data_service_ops.register_dataset(cluster.target,
                                                 dataset_ops.Dataset.range(10))
    second_id = data_service_ops.register_dataset(cluster.target,
                                                  dataset_ops.Dataset.range(10))
    self.assertEqual(first_id.numpy(), second_id.numpy())

  @combinations.generate(test_base.eager_only_combinations())
  def testRegisterDifferentDatasets(self):
    """Registering two different datasets yields distinct ids."""
    cluster = self.create_cluster(num_workers=1)
    first_id = data_service_ops.register_dataset(cluster.target,
                                                 dataset_ops.Dataset.range(10))
    second_id = data_service_ops.register_dataset(cluster.target,
                                                  dataset_ops.Dataset.range(20))
    self.assertNotEqual(first_id.numpy(), second_id.numpy())

  @combinations.generate(test_base.eager_only_combinations())
  def testTwoLevelDistribute(self):
    """Chains two data service clusters with a group_by_window in between.

    Strings are distributed across cluster_1's three workers, grouped into
    same-length batches, then distributed again through cluster_2. Each batch
    is expected to appear `cluster_1_size` times in a row in the final output.
    """
    cluster_1_size = 3
    cluster_1 = self.create_cluster(num_workers=cluster_1_size)
    cluster_2 = self.create_cluster(num_workers=1)
    num_sizes = 10
    size_repeats = 5
    # `num_sizes` distinct string lengths, each repeated `size_repeats` times.
    strings = ["a" * i for i in range(num_sizes)] * size_repeats
    ds = dataset_ops.Dataset.from_tensor_slices(strings)
    ds = ds.shuffle(len(strings))
    ds = self.make_distributed_dataset(ds, cluster_1)
    # Large enough so that all strings of the same size are windowed together.
    window_size = cluster_1_size * size_repeats
    batch_size = size_repeats

    def key_func(x):
      # Group strings by their length.
      return math_ops.cast(string_ops.string_length_v2(x), dtypes.int64)

    ds = ds.apply(
        grouping.group_by_window(
            key_func=key_func,
            reduce_func=lambda _, x: x.batch(batch_size),
            window_size=window_size))
    ds = self.make_distributed_dataset(ds, cluster_2)

    it = iter(ds)
    for _ in range(num_sizes):
      element = next(it).numpy()
      # Each batch should repeat once per cluster_1 worker.
      for _ in range(1, cluster_1_size):
        self.assertAllEqual(next(it).numpy(), element)
    self.assertEmpty(list(it))

  @combinations.generate(
      combinations.times(test_base.eager_only_combinations()))
  def testDistributeLargeGraph(self):
    """A dataset graph bigger than gRPC's default message limit still works."""
    cluster = self.create_cluster(
        num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
    # 2 * 1000 * 1000 float32 values: larger than default OSS grpc message
    # size limit of 4MB.
    payload = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
    dataset = dataset_ops.Dataset.from_tensors(payload)
    dataset = self.make_distributed_dataset(dataset, cluster)
    self.assertDatasetProduces(dataset, [payload])
class DataServiceOpsTest(data_service_test_base.TestBase,
                         parameterized.TestCase):
    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           data_service_test_base.all_cluster_configurations())
    )
    def testDistributeBasic(self, work_dir, fault_tolerant_mode):
        """Basic end-to-end read of a distributed range dataset.

        Runs across all cluster configurations (work dir / fault tolerance).
        """
        cluster = self.create_cluster(num_workers=1,
                                      work_dir=work_dir,
                                      fault_tolerant_mode=fault_tolerant_mode)
        num_elements = 10
        # Pass `num_elements` instead of a duplicated literal so the dataset
        # size and the expected output below cannot drift apart.
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        results = [elem.numpy() for elem in ds]
        self.assertEqual(list(range(num_elements)), results)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeSparse(self):
        """Sparse tensors survive a round trip through the data service."""
        cluster = self.create_cluster(num_workers=1)
        sparse_elem = sparse_tensor.SparseTensor(
            indices=[[0]],
            values=constant_op.constant([0], dtype=dtypes.int32),
            dense_shape=[1])
        dataset = dataset_ops.Dataset.from_tensors(sparse_elem)
        dataset = self.make_distributed_dataset(dataset, cluster)
        densified = [
            sparse_ops.sparse_tensor_to_dense(elem) for elem in dataset
        ]
        self.assertAllEqual(densified, [[0]])

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeRagged(self):
        """Ragged tensors survive a round trip through the data service."""
        cluster = self.create_cluster(num_workers=1)
        dataset = dataset_ops.Dataset.from_tensor_slices([1, 5, 3, 2, 8])
        dataset = dataset.map(math_ops.range)
        dataset = dataset.apply(batching.dense_to_ragged_batch(2))
        dataset = self.make_distributed_dataset(dataset, cluster)
        batches = [elem.to_tensor() for elem in dataset]
        self.assertAllEqual(batches[0], [[0, 0, 0, 0, 0], [0, 1, 2, 3, 4]])
        self.assertAllEqual(batches[1], [[0, 1, 2], [0, 1, 0]])
        self.assertAllEqual(batches[2], [[0, 1, 2, 3, 4, 5, 6, 7]])

    @combinations.generate(test_base.eager_only_combinations())
    def testDifferentShuffleOrders(self):
        """Unseeded shuffles on different workers should produce different orders."""
        random_seed.set_random_seed(None)
        num_elements = 100
        cluster = self.create_cluster(num_workers=2)
        ds = dataset_ops.Dataset.range(num_elements)
        ds = ds.shuffle(num_elements)
        ds = self.make_distributed_dataset(ds, cluster)
        output = [elem.numpy() for elem in ds]

        # The output will be two sequences of range(num_elements)
        # non-deterministically interleaved together. If the orders of the elements
        # were the same, first_order and second_order computed below will be equal.
        first_order = {}
        second_order = {}
        for element in output:
            # First sighting of a value records its position in worker 1's
            # order; the second sighting records worker 2's order.
            if element in first_order:
                second_order[element] = len(second_order)
            else:
                first_order[element] = len(first_order)
        self.assertNotEqual(first_order, second_order)

    @combinations.generate(test_base.eager_only_combinations())
    def testMultipleEpochs(self):
        """The same distributed dataset can be iterated many times over."""
        cluster = self.create_cluster(num_workers=1)
        num_elements = 3
        dataset = self.make_distributed_range_dataset(num_elements, cluster)
        expected = list(range(num_elements))
        for _ in range(10):
            self.assertEqual(expected, [elem.numpy() for elem in dataset])

    @combinations.generate(test_base.eager_only_combinations())
    def testRepeatedDataset(self):
        """repeat() applied after distribution produces every repetition."""
        cluster = self.create_cluster(num_workers=1)
        num_elements = 10
        num_repetitions = 5
        dataset = self.make_distributed_range_dataset(num_elements, cluster)
        dataset = dataset.repeat(num_repetitions)
        expected = num_repetitions * list(range(num_elements))
        self.assertDatasetProduces(dataset, expected_output=expected)

    @combinations.generate(test_base.eager_only_combinations())
    def testConcurrentEpoch(self):
        """Interleaved reads from independent datasets each stay in order."""
        cluster = self.create_cluster(num_workers=1)
        num_elements = 10
        num_datasets = 3
        iterators = []
        results = []
        for _ in range(num_datasets):
            ds = self.make_distributed_range_dataset(num_elements, cluster)
            iterators.append(iter(ds))
            results.append([])

        # Read round-robin across the iterators; each one must still yield its
        # own elements in order despite the concurrent reads.
        for _ in range(num_elements):
            for dataset_ind in range(num_datasets):
                result = next(iterators[dataset_ind]).numpy()
                results[dataset_ind].append(result)
        for result in results:
            self.assertEqual(list(range(num_elements)), result)

    @combinations.generate(test_base.eager_only_combinations())
    def testSharedEpoch(self):
        """Multiple iterators over one dataset should share a single epoch.

        Currently skipped: the shared-epoch behavior is not yet implemented.
        """
        self.skipTest("Not yet implemented")
        cluster = self.create_cluster(num_workers=1)
        num_elements = 10
        num_iterators = 3
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        result = []
        iterators = []
        for _ in range(num_iterators):
            iterators.append(iter(ds))

        # Alternate reading between the iterators.
        for _ in range(2):
            for it in iterators:
                result.append(next(it).numpy())

        # Drain the rest of the elements.
        for it in iterators:
            for elem in it:
                result.append(elem.numpy())

        # Together the iterators should produce each element exactly once.
        self.assertCountEqual(list(range(num_elements)), result)

    @combinations.generate(test_base.eager_only_combinations())
    def testMultiWorker(self):
        """Every worker contributes a full copy of the dataset's elements."""
        num_workers = 3
        cluster = self.create_cluster(num_workers=num_workers)
        num_elements = 10
        dataset = self.make_distributed_range_dataset(num_elements, cluster)
        observed = [elem.numpy() for elem in dataset]
        expected = num_workers * list(range(num_elements))
        self.assertCountEqual(expected, observed)

    @combinations.generate(test_base.eager_only_combinations())
    def testMaxOutstandingRequests(self):
        """Reading still completes when limited to one outstanding request."""
        num_workers = 3
        cluster = self.create_cluster(num_workers=num_workers)
        num_elements = 10
        dataset = self.make_distributed_range_dataset(
            num_elements, cluster, max_outstanding_requests=1)
        expected = num_workers * list(range(num_elements))
        self.assertCountEqual(expected, self.getDatasetOutput(dataset))

    @combinations.generate(test_base.eager_only_combinations())
    def testInsideFunction(self):
        """A distributed dataset can be consumed inside a tf.function."""
        num_workers = 3
        cluster = self.create_cluster(num_workers=num_workers)
        num_elements = 10

        @def_function.function
        def f():
            ds = self.make_distributed_range_dataset(num_elements, cluster)
            # Accumulate every element into a TensorArray so the function can
            # return them all at once.
            result = tensor_array_ops.TensorArray(dtypes.int64,
                                                  size=num_workers *
                                                  num_elements,
                                                  dynamic_size=True)
            i = 0
            for elem in ds:
                result = result.write(i, elem)
                i += 1
            return result.stack()

        result = list(f().numpy())
        self.assertCountEqual(num_workers * list(range(num_elements)), result)

    @combinations.generate(test_base.eager_only_combinations())
    def testSharedJobName(self):
        """Two datasets with the same job name share one pool of elements."""
        cluster = self.create_cluster(num_workers=1)
        num_elements = 1000

        def make_ds():
            return dataset_ops.Dataset.range(num_elements).shuffle(
                num_elements)

        ds1 = self.make_distributed_dataset(make_ds(),
                                            cluster,
                                            job_name="job_name")
        ds2 = self.make_distributed_dataset(make_ds(),
                                            cluster,
                                            job_name="job_name")
        iter1 = iter(ds1)
        iter2 = iter(ds2)
        results = []
        # Alternate reads between the two iterators, then drain each; together
        # they should consume every element exactly once.
        for _ in range(num_elements // 5):
            results.append(next(iter1).numpy())
            results.append(next(iter2).numpy())
        for elem in iter1:
            results.append(elem.numpy())
        for elem in iter2:
            results.append(elem.numpy())
        self.assertCountEqual(list(range(num_elements)), results)

    @combinations.generate(test_base.eager_only_combinations())
    def testDifferentJobNames(self):
        """Distinct job names produce independent jobs with full output each."""
        cluster = self.create_cluster(num_workers=1)
        num_elements = 10
        expected = list(range(num_elements))
        ds1 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name1")
        ds2 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name2")
        self.assertDatasetProduces(ds1, expected)
        self.assertDatasetProduces(ds2, expected)

    @combinations.generate(test_base.eager_only_combinations())
    def testSharedJobNameMultiIteration(self):
        """With a shared job, whichever dataset reads first drains the job."""
        cluster = self.create_cluster(num_workers=1)
        num_elements = 10
        ds1 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name")
        ds2 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name")
        # iteration 1: ds1 consumes everything, leaving nothing for ds2.
        self.assertDatasetProduces(ds1, list(range(num_elements)))
        self.assertDatasetProduces(ds2, [])
        # iteration 2: roles reversed.
        self.assertDatasetProduces(ds2, list(range(num_elements)))
        self.assertDatasetProduces(ds1, [])

    @combinations.generate(test_base.eager_only_combinations())
    def testSharedJobNameRepeat(self):
        """Repeated datasets sharing a job split the repetitions between them."""
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        num_repetitions = 3
        ds1 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name")
        ds1 = ds1.repeat(num_repetitions)
        ds2 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name")
        ds2 = ds2.repeat(num_repetitions)
        results = []
        iter1 = iter(ds1)
        iter2 = iter(ds2)
        # Read a prefix from each iterator, then drain both; combined they
        # should cover every repetition exactly once.
        for _ in range((num_elements * num_repetitions) // 5):
            results.append(next(iter1).numpy())
        for _ in range((num_elements * num_repetitions) // 5):
            results.append(next(iter2).numpy())
        for elem in iter1:
            results.append(elem.numpy())
        for elem in iter2:
            results.append(elem.numpy())
        self.assertCountEqual(num_repetitions * list(range(num_elements)),
                              results)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(num_workers=[1, 3], num_consumers=[1, 2, 5])))
    def testRoundRobin(self, num_workers, num_consumers):
        """Round robin consumers receive consecutive elements each round."""
        cluster = self.create_cluster(num_workers=num_workers)
        # Round robin reads can cause slow cluster shutdown.
        data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
        ds = dataset_ops.Dataset.range(10000000)
        ds = ds.repeat()
        consumers = []
        for consumer_index in range(num_consumers):
            consumers.append(
                self.make_distributed_dataset(ds,
                                              cluster,
                                              job_name="test",
                                              consumer_index=consumer_index,
                                              num_consumers=num_consumers))
        # Use parallel interleave to read from consumers in parallel.
        ds = dataset_ops.Dataset.from_tensor_slices(consumers)
        ds = ds.interleave(lambda x: x,
                           cycle_length=num_consumers,
                           num_parallel_calls=num_consumers)
        ds = ds.take(1000)
        results = self.getDatasetOutput(ds, requires_initialization=True)

        for i in range(0, len(results), num_consumers):
            # Each round starts at a multiple of num_consumers.
            self.assertEqual(0, results[i] % num_consumers)
            # Check that each group of `num_consumers` results are consecutive.
            for offset in range(1, num_consumers):
                if i + offset < len(results):
                    self.assertEqual(results[i] + offset, results[i + offset])

    @combinations.generate(test_base.default_test_combinations())
    def testRoundRobinBucketizing(self):
        """All round robin consumers get same-bucket batches at each step.

        Tests a common use case for round robin reads: bucketizing elements by
        sequence length. At each step, every one of the `num_consumers`
        consumers should receive a batch drawn from the same bucket.
        """
        cluster = self.create_cluster(num_workers=4)
        # Round robin reads can cause slow cluster shutdown.
        data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
        num_elements = 100
        low_bucket_max = 30
        mid_bucket_max = 60
        bucket_boundaries = [low_bucket_max, mid_bucket_max]
        batch_size = 10
        num_consumer_hosts = 3
        replicas_per_consumer_host = 5
        num_consumers = num_consumer_hosts * replicas_per_consumer_host
        bucket_batch_sizes = [batch_size] * (len(bucket_boundaries) + 1)
        # Set up the dataset that will run on the tf.data workers.
        ds = dataset_ops.Dataset.range(num_elements, output_type=dtypes.int32)
        ds = ds.shuffle(num_elements)
        ds = ds.repeat()
        ds = ds.apply(
            grouping.bucket_by_sequence_length(lambda x: x,
                                               bucket_boundaries,
                                               bucket_batch_sizes,
                                               drop_remainder=True))
        # Window `num_consumers` batches of the same bucket together so that
        # one round robin round is served entirely from a single bucket.
        ds = ds.apply(
            grouping.group_by_window(
                lambda x: math_ops.cast(x[1], dtypes.int64),
                lambda _, x: dataset_ops.Dataset.from_tensors(x),
                window_size=num_consumers))
        ds = ds.flat_map(lambda x: x)

        # Set up the per-consumer-host datasets. During each global step, we pull
        # `replicas_per_consumer_host` batches from each of these datasets.
        host_datasets = []
        for host_index in range(num_consumer_hosts):
            per_replica_datasets = []
            for i in range(replicas_per_consumer_host):
                consumer_index = host_index * replicas_per_consumer_host + i
                per_replica_datasets.append(
                    self.make_distributed_dataset(
                        ds,
                        cluster,
                        job_name="test",
                        consumer_index=consumer_index,
                        num_consumers=num_consumers))
            host_dataset = dataset_ops.Dataset.from_tensor_slices(
                per_replica_datasets)
            host_dataset = host_dataset.interleave(
                lambda x: x,
                cycle_length=len(per_replica_datasets),
                num_parallel_calls=len(per_replica_datasets),
                deterministic=True)
            host_datasets.append(host_dataset)

        # Use parallel interleave to read from host datasets in parallel.
        ds = dataset_ops.Dataset.from_tensor_slices(host_datasets)
        ds = ds.interleave(lambda x: x,
                           block_length=replicas_per_consumer_host,
                           cycle_length=len(host_datasets),
                           num_parallel_calls=len(host_datasets),
                           deterministic=True)

        num_rounds = 10
        get_next = self.getNext(ds, requires_initialization=True)
        results = []
        for _ in range(num_rounds * num_consumers):
            results.append(self.evaluate(get_next()))

        def get_bucket(elem):
            # Index of the bucket `elem` falls into, per `bucket_boundaries`.
            bucket_ind = 0
            while bucket_ind < len(bucket_boundaries
                                   ) and elem >= bucket_boundaries[bucket_ind]:
                bucket_ind += 1
            return bucket_ind

        # Check that the batches for each step contain elements from the same
        # bucket. `i` already advances in steps of `num_consumers`, so each
        # round's batches are results[i:i + num_consumers]. (The previous
        # slice, results[num_consumers * i:num_consumers * (i + 1)], was out
        # of range — and therefore empty — for every round after the first,
        # which made the check vacuous.)
        for i in range(0, len(results), num_consumers):
            batches = results[i:i + num_consumers]
            bucket_inds = [get_bucket(batch[0]) for batch in batches]
            for bucket_ind in bucket_inds[1:]:
                self.assertEqual(bucket_inds[0], bucket_ind)

    @combinations.generate(test_base.v1_only_combinations())
    def testRoundRobinFiniteV1(self):
        """In V1, a finite dataset fails round robin reads at end-of-sequence."""
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        ds = dataset_ops.Dataset.range(num_elements)
        ds = self.make_distributed_dataset(ds,
                                           cluster,
                                           job_name="test",
                                           consumer_index=0,
                                           num_consumers=1)

        with self.assertRaisesRegex(
                errors.FailedPreconditionError,
                "Encountered end of sequence on a "
                "round-robin read iterator"):
            self.getDatasetOutput(ds, requires_initialization=True)

    @combinations.generate(test_base.v2_only_combinations())
    def testRoundRobinFiniteV2(self):
        """In V2, round robin reads reject finite-cardinality datasets upfront."""
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        ds = dataset_ops.Dataset.range(num_elements)
        ds = self.make_distributed_dataset(ds,
                                           cluster,
                                           job_name="test",
                                           consumer_index=0,
                                           num_consumers=1)

        with self.assertRaisesRegex(
                errors.FailedPreconditionError, "Round robin reads "
                "require that the input dataset has infinite "
                "cardinality, but the dataset has cardinality " +
                str(num_elements)):
            self.getDatasetOutput(ds, requires_initialization=True)

    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           combinations.combine(job_name=[None, "test"])))
    def testGcUnusedJob(self, job_name):
        """A job with no outstanding iterators is garbage collected."""
        # Aggressive GC intervals so the test observes collection quickly.
        cluster = self.create_cluster(num_workers=1,
                                      job_gc_check_interval_ms=50,
                                      job_gc_timeout_ms=20)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements,
                                                 cluster,
                                                 job_name=job_name)
        it = iter(ds)
        self.assertEqual(next(it).numpy(), 0)
        self.assertEqual(cluster.num_tasks_on_worker(), 1)
        del it
        # Poll until the worker's task disappears, which signals the job was
        # garbage collected.
        while cluster.num_tasks_on_worker() > 0:
            time.sleep(0.1)

    @combinations.generate(test_base.eager_only_combinations())
    def testDontGcUsedJob(self):
        """A job with a live iterator must not be garbage collected."""
        # Aggressive GC intervals so the test observes collection quickly.
        cluster = self.create_cluster(num_workers=1,
                                      job_gc_check_interval_ms=50,
                                      job_gc_timeout_ms=20)
        num_elements = 10
        it1 = iter(
            self.make_distributed_range_dataset(num_elements,
                                                cluster,
                                                job_name="test1"))
        it2 = iter(
            self.make_distributed_range_dataset(num_elements,
                                                cluster,
                                                job_name="test2"))
        it3 = iter(  # this iterator keeps the task alive. pylint: disable=unused-variable
            self.make_distributed_range_dataset(num_elements,
                                                cluster,
                                                job_name="test2"))
        self.assertEqual(2, cluster.num_tasks_on_worker())
        del it1
        del it2
        # Check that only the first job is gced. The second job will not be gced
        # because there is still an outstanding iterator for it.
        while cluster.num_tasks_on_worker() > 1:
            time.sleep(0.1)
        self.assertEqual(1, cluster.num_tasks_on_worker())

    @combinations.generate(test_base.eager_only_combinations())
    def testApplyDeterminismOption(self):
        """The experimental_deterministic=False option survives distribution."""
        elements = list(range(10))
        cluster = self.create_cluster(num_workers=1)

        def dataset_fn(delay_ms):
            def interleave_fn(x):
                # Delay only element 0 so a non-deterministic interleave can
                # reorder it behind the others.
                ds = dataset_ops.Dataset.from_tensors(x)
                if math_ops.equal(x, 0):
                    ds = ds.apply(testing.sleep(delay_ms * 1000))
                else:
                    ds = ds.apply(testing.sleep(0))
                return ds

            ds = dataset_ops.Dataset.from_tensor_slices(elements)
            ds = ds.interleave(interleave_fn,
                               cycle_length=10,
                               num_parallel_calls=10)
            opts = dataset_ops.Options()
            opts.experimental_deterministic = False
            ds = ds.with_options(opts)
            ds = self.make_distributed_dataset(ds, cluster)
            return ds

        self.checkDeterminism(dataset_fn=dataset_fn,
                              expect_determinism=False,
                              expected_elements=elements)

    def run_stateful(self, external_state_policy):
        """Distributes a stateful (random-valued) dataset under `external_state_policy`.

        Helper for the testStateful* cases: builds a dataset whose map function
        uses random ops, applies the given external-state policy, distributes
        it, and pulls one element so any policy error surfaces.
        """
        num_elements = 10
        dataset = dataset_ops.Dataset.range(num_elements).map(
            lambda _: random_ops.random_uniform(()))

        opts = dataset_ops.Options()
        opts.experimental_external_state_policy = external_state_policy
        dataset = dataset.with_options(opts)

        cluster = self.create_cluster(num_workers=3)
        dataset = self.make_distributed_dataset(dataset, cluster)
        next(iter(dataset))

    @combinations.generate(
        combinations.times(
            test_base.eager_only_combinations(),
            combinations.combine(external_state_policy=[
                distribute_options.ExternalStatePolicy.IGNORE,
                distribute_options.ExternalStatePolicy.WARN
            ])))
    def testStatefulNoError(self, external_state_policy):
        """IGNORE and WARN policies allow distributing a stateful dataset."""
        self.run_stateful(external_state_policy)

    @combinations.generate(test_base.eager_only_combinations())
    def testStatefulError(self):
        """The FAIL policy rejects distributing a stateful dataset."""
        with self.assertRaises(errors.FailedPreconditionError):
            self.run_stateful(distribute_options.ExternalStatePolicy.FAIL)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpochTensorSlices(self):
        """distributed_epoch mode yields each tensor slice exactly once."""
        cluster = self.create_cluster(num_workers=2)
        vals = [5, 1, 2, 4]
        dataset = dataset_ops.Dataset.from_tensor_slices(vals)
        dataset = self.make_distributed_dataset(
            dataset, cluster, processing_mode="distributed_epoch")
        self.assertDatasetProduces(dataset, vals, assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpochInterleave(self):
        """distributed_epoch mode handles a sequential interleave."""
        cluster = self.create_cluster(num_workers=2)
        elements = [1, 5, 0]
        dataset = dataset_ops.Dataset.from_tensor_slices(elements)
        dataset = dataset.interleave(
            lambda x: dataset_ops.Dataset.from_tensor_slices([x]))
        dataset = self.make_distributed_dataset(
            dataset, cluster, processing_mode="distributed_epoch")
        self.assertDatasetProduces(dataset, elements, assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpochParallelInterleave(self):
        """distributed_epoch mode handles a parallel (autotuned) interleave."""
        cluster = self.create_cluster(num_workers=2)
        elements = [1, 5, 0]
        dataset = dataset_ops.Dataset.from_tensor_slices(elements)
        dataset = dataset.interleave(
            lambda x: dataset_ops.Dataset.from_tensor_slices([x]),
            num_parallel_calls=dataset_ops.AUTOTUNE)
        dataset = self.make_distributed_dataset(
            dataset, cluster, processing_mode="distributed_epoch")
        self.assertDatasetProduces(dataset, elements, assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpochFlatMap(self):
        """distributed_epoch mode handles a flat_map transformation."""
        cluster = self.create_cluster(num_workers=2)
        elements = [1, 5, 0]
        dataset = dataset_ops.Dataset.from_tensor_slices(elements)
        dataset = dataset.flat_map(
            lambda x: dataset_ops.Dataset.from_tensor_slices([x]))
        dataset = self.make_distributed_dataset(
            dataset, cluster, processing_mode="distributed_epoch")
        self.assertDatasetProduces(dataset, elements, assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpochRepeat(self):
        """distributed_epoch mode handles a finite repeat."""
        cluster = self.create_cluster(num_workers=2)
        num_repeats = 5
        num_elements = 20
        dataset = dataset_ops.Dataset.range(num_elements).repeat(num_repeats)
        dataset = self.make_distributed_dataset(
            dataset, cluster, processing_mode="distributed_epoch")
        expected = num_repeats * list(range(num_elements))
        self.assertDatasetProduces(dataset, expected, assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpochForeverRepeat(self):
        """distributed_epoch mode keeps all elements flowing under repeat()."""
        cluster = self.create_cluster(num_workers=2)
        num_elements = 20
        elements_to_read = 1000
        ds = dataset_ops.Dataset.range(num_elements).repeat()
        ds = self.make_distributed_dataset(ds,
                                           cluster,
                                           processing_mode="distributed_epoch")
        it = iter(ds)
        # Histogram of how often each value appears in the reads.
        results = {}
        for _ in range(elements_to_read):
            val = next(it).numpy()
            if val not in results:
                results[val] = 0
            results[val] += 1
        # Each value should appear at least half as often as a perfectly
        # uniform split would give it.
        for i in range(num_elements):
            self.assertGreater(results[i], elements_to_read / num_elements / 2)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpochForeverRepeatFewElements(self):
        """Reads survive worker loss when elements are fewer than workers."""
        num_workers = 5
        cluster = self.create_cluster(num_workers=num_workers)
        # Less than the number of workers, so that some workers get zero elements on
        # the first repetition.
        num_elements = 1
        ds = dataset_ops.Dataset.range(num_elements).repeat()
        ds = self.make_distributed_dataset(ds,
                                           cluster,
                                           processing_mode="distributed_epoch")
        it = iter(ds)
        for _ in range(100):
            self.assertEqual(next(it).numpy(), 0)

        # Stop all but one worker and check that we can still read.
        # NOTE: uses the worker's private _stop() to simulate failure.
        for i in range(num_workers - 1):
            cluster.workers[i]._stop()
        for _ in range(100):
            self.assertEqual(next(it).numpy(), 0)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpochShuffleAndRepeat(self):
        """Shuffle + repeat under distributed_epoch yields every repetition."""
        cluster = self.create_cluster(num_workers=2)
        repeats = 5
        cardinality = 20
        base = dataset_ops.Dataset.range(cardinality)
        base = base.shuffle(cardinality).repeat(repeats)
        distributed = self.make_distributed_dataset(
            base, cluster, processing_mode="distributed_epoch")
        # Shuffling plus two workers makes order arbitrary; compare contents.
        self.assertDatasetProduces(distributed,
                                   list(range(cardinality)) * repeats,
                                   assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeFromInterleave(self):
        """Registering a distributed dataset inside `interleave` is harmless.

        Every other test method in this class carries a
        `combinations.generate` decorator; this one was missing it, so it was
        added for consistency. Note that the distributed dataset built by
        `make_distributed_dataset` is deliberately discarded — the test only
        checks that creating one inside `interleave_fn` does not break the
        surrounding interleave, whose output is the plain `range(2)` datasets.
        """
        cluster = self.create_cluster(num_workers=1)
        ds = dataset_ops.Dataset.range(2)

        def interleave_fn(_):
            dataset = dataset_ops.Dataset.range(2)
            # Return value intentionally unused; see docstring.
            self.make_distributed_dataset(dataset, cluster)
            return dataset

        ds = ds.interleave(interleave_fn, cycle_length=2)
        self.assertDatasetProduces(ds, [0, 0, 1, 1])

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpoch(self):
        """A plain range dataset is fully produced under distributed_epoch."""
        cluster = self.create_cluster(num_workers=2)
        total = 100
        distributed = self.make_distributed_dataset(
            dataset_ops.Dataset.range(total),
            cluster,
            processing_mode="distributed_epoch")
        # Two workers make ordering nondeterministic.
        self.assertDatasetProduces(distributed,
                                   list(range(total)),
                                   assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeNonStringAddresses(self):
        # A non-string service address must be rejected with a ValueError.
        dataset = dataset_ops.Dataset.range(10)
        with self.assertRaisesRegex(ValueError, "service must be a string"):
            dataset = dataset.apply(
                data_service_ops.distribute(
                    processing_mode="parallel_epochs", service=1))

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeEmptyAddress(self):
        # An empty service address must be rejected with an exact message.
        dataset = dataset_ops.Dataset.range(10)
        with self.assertRaisesWithLiteralMatch(ValueError,
                                               "service must not be empty"):
            dataset = dataset.apply(
                data_service_ops.distribute(
                    processing_mode="parallel_epochs", service=""))

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeInvalidProcessingMode(self):
        # An unrecognized processing mode must be rejected.
        dataset = dataset_ops.Dataset.range(10)
        with self.assertRaisesRegex(ValueError,
                                    "invalid is not a valid processing mode"):
            dataset = dataset.apply(
                data_service_ops.distribute(
                    processing_mode="invalid",
                    service="grpc://localhost:5000"))

    @combinations.generate(test_base.eager_only_combinations())
    def testZipDifferentProcessingModesDatasets(self):
        """Zipping datasets distributed under different modes works."""
        cluster = self.create_cluster(num_workers=1)
        total = 100
        first = self.make_distributed_dataset(
            dataset_ops.Dataset.range(total),
            cluster,
            processing_mode="distributed_epoch")
        second = self.make_distributed_dataset(
            dataset_ops.Dataset.range(total),
            cluster,
            processing_mode="parallel_epochs")
        zipped = dataset_ops.Dataset.zip((first, second))
        expected = list(zip(range(total), range(total)))
        self.assertDatasetProduces(zipped, expected, assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testZipDifferentProcessingModesDatasetsSharedJobName(self):
        """Sharing one job name across two processing modes is an error."""
        cluster = self.create_cluster(num_workers=1)
        total = 100
        first = self.make_distributed_dataset(
            dataset_ops.Dataset.range(total),
            cluster,
            processing_mode="distributed_epoch",
            job_name="job_name")
        second = self.make_distributed_dataset(
            dataset_ops.Dataset.range(total),
            cluster,
            processing_mode="parallel_epochs",
            job_name="job_name")
        zipped = dataset_ops.Dataset.zip((first, second))
        with self.assertRaisesRegex(errors.FailedPreconditionError,
                                    "but there is already an existing job"):
            self.getDatasetOutput(zipped)

    @combinations.generate(test_base.eager_only_combinations())
    def testFromDatasetId(self):
        """A registered dataset can be read back via its dataset id."""
        cluster = self.create_cluster(num_workers=1)
        total = 10
        registered = dataset_ops.Dataset.range(total)
        dataset_id = data_service_ops.register_dataset(cluster.target,
                                                       registered)
        restored = data_service_ops.from_dataset_id("parallel_epochs",
                                                    cluster.target, dataset_id,
                                                    registered.element_spec)
        self.assertDatasetProduces(restored, list(range(total)))

    @combinations.generate(test_base.eager_only_combinations())
    def testFromDatasetIdMultipleComponents(self):
        """from_dataset_id preserves nested element structure."""
        cluster = self.create_cluster(num_workers=1)
        total = 10
        base = dataset_ops.Dataset.range(total)
        nested = dataset_ops.Dataset.zip({"a": (base, base), "b": base})
        dataset_id = data_service_ops.register_dataset(cluster.target, nested)
        restored = data_service_ops.from_dataset_id("parallel_epochs",
                                                    cluster.target, dataset_id,
                                                    nested.element_spec)
        # Each element should carry the same index through every component.
        for index, element in enumerate(self.getDatasetOutput(restored)):
            self.assertEqual(index, element["a"][0])
            self.assertEqual(index, element["a"][1])
            self.assertEqual(index, element["b"])

    @combinations.generate(test_base.eager_only_combinations())
    def testFromDatasetIdWrongElementSpec(self):
        """Supplying a mismatched element_spec fails at read time."""
        cluster = self.create_cluster(num_workers=1)
        total = 10
        registered = dataset_ops.Dataset.range(total)
        dataset_id = data_service_ops.register_dataset(cluster.target,
                                                       registered)
        # The registered dataset produces int64 scalars, not variants.
        bad_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
        mismatched = data_service_ops.from_dataset_id("parallel_epochs",
                                                      cluster.target,
                                                      dataset_id, bad_spec)
        with self.assertRaisesRegex(errors.FailedPreconditionError,
                                    "Expected a tensor of type variant"):
            self.evaluate(self.getNext(mismatched)())

    @combinations.generate(test_base.eager_only_combinations())
    def testFromDatasetIdNotRegistered(self):
        """Reading an unregistered dataset id raises NotFoundError."""
        cluster = self.create_cluster(num_workers=1)
        missing_id = 0
        spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
        dataset = data_service_ops.from_dataset_id("parallel_epochs",
                                                   cluster.target, missing_id,
                                                   spec)
        with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"):
            self.evaluate(self.getNext(dataset)())

    @combinations.generate(test_base.default_test_combinations())
    def testCancellation(self):
        """Prefetch of a slow element can be cancelled (currently skipped)."""
        self.skipTest("b/162521601")
        sleep_microseconds = int(1e6) * 1000

        cluster = self.create_cluster(num_workers=1)
        # The first element arrives quickly; fetching it triggers a prefetch
        # of the second, slow element, which must be cancellable.
        slow = dataset_ops.Dataset.range(1).apply(
            testing.sleep(sleep_microseconds))
        dataset = dataset_ops.Dataset.range(1).concatenate(slow)
        dataset = self.make_distributed_dataset(dataset, cluster)
        dataset = dataset.prefetch(1)
        get_next = self.getNext(dataset, requires_initialization=True)
        self.assertEqual(0, self.evaluate(get_next()))
        # With broken cancellation, garbage-collecting the iterator here
        # would hang.

    @combinations.generate(test_base.eager_only_combinations())
    def testRegisterEquivalentDatasets(self):
        # Structurally identical datasets should deduplicate to one id.
        cluster = self.create_cluster(num_workers=1)
        first = data_service_ops.register_dataset(
            cluster.target, dataset_ops.Dataset.range(10))
        second = data_service_ops.register_dataset(
            cluster.target, dataset_ops.Dataset.range(10))
        self.assertEqual(first.numpy(), second.numpy())

    @combinations.generate(test_base.eager_only_combinations())
    def testRegisterDifferentDatasets(self):
        # Distinct datasets must receive distinct ids.
        cluster = self.create_cluster(num_workers=1)
        first = data_service_ops.register_dataset(
            cluster.target, dataset_ops.Dataset.range(10))
        second = data_service_ops.register_dataset(
            cluster.target, dataset_ops.Dataset.range(20))
        self.assertNotEqual(first.numpy(), second.numpy())

    @combinations.generate(test_base.default_test_combinations())
    def testDistributedEpochOnZippedDataset(self):
        """distributed_epoch cannot create a split provider for zip."""
        cluster = self.create_cluster(num_workers=1)
        zipped = dataset_ops.Dataset.zip(
            (dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10)))
        distributed = self.make_distributed_dataset(
            zipped, cluster, processing_mode="distributed_epoch")

        expected_error = ("Cannot create a split provider for dataset "
                          "of type ZipDataset")
        with self.assertRaisesRegex(errors.UnimplementedError, expected_error):
            self.getDatasetOutput(distributed, requires_initialization=True)

    @combinations.generate(test_base.default_test_combinations())
    def testDistributedEpochOnDistributedDataset(self):
        """distributed_epoch cannot split an already-distributed dataset."""
        cluster_1 = self.create_cluster(num_workers=1)
        cluster_2 = self.create_cluster(num_workers=1)
        num_sizes = 10
        size_repeats = 5
        numbers = list(range(num_sizes)) * size_repeats
        dataset = dataset_ops.Dataset.from_tensor_slices(numbers)
        dataset = self.make_distributed_dataset(
            dataset, cluster_1, processing_mode="parallel_epochs")
        dataset = dataset.map(lambda x: x + 1)
        dataset = self.make_distributed_dataset(
            dataset, cluster_2, processing_mode="distributed_epoch")

        expected_error = ("Cannot create a split provider for dataset "
                          "of type DataServiceDataset")
        with self.assertRaisesRegex(errors.UnimplementedError, expected_error):
            self.getDatasetOutput(dataset, requires_initialization=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testTwoLevelDistribute(self):
        """Distribute, group by string length, then distribute again."""
        cluster_1_size = 3
        cluster_1 = self.create_cluster(num_workers=cluster_1_size)
        cluster_2 = self.create_cluster(num_workers=1)
        num_sizes = 10
        size_repeats = 5
        strings = ["a" * i for i in range(num_sizes)] * size_repeats
        dataset = dataset_ops.Dataset.from_tensor_slices(strings)
        dataset = dataset.shuffle(len(strings))
        dataset = self.make_distributed_dataset(dataset, cluster_1)
        # A window large enough that all strings of one length land in the
        # same group.
        window_size = cluster_1_size * size_repeats
        batch_size = size_repeats

        def key_func(x):
            return math_ops.cast(string_ops.string_length_v2(x), dtypes.int64)

        dataset = dataset.apply(
            grouping.group_by_window(
                key_func=key_func,
                reduce_func=lambda _, window: window.batch(batch_size),
                window_size=window_size))
        dataset = self.make_distributed_dataset(dataset, cluster_2)

        iterator = iter(dataset)
        for _ in range(num_sizes):
            first_batch = next(iterator).numpy()
            # Each string length yields cluster_1_size identical batches.
            for _ in range(1, cluster_1_size):
                self.assertAllEqual(next(iterator).numpy(), first_batch)
        self.assertEmpty(list(iterator))

    @combinations.generate(
        combinations.times(test_base.eager_only_combinations()))
    def testDistributeLargeGraph(self):
        """A dataset graph above the 4MB gRPC default still transfers."""
        cluster = self.create_cluster(num_workers=1,
                                      work_dir=NO_WORK_DIR,
                                      fault_tolerant_mode=False)
        # 2 * 1000 * 1000 float32s (~8MB) exceed the default OSS gRPC
        # message size limit of 4MB.
        tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
        dataset = dataset_ops.Dataset.from_tensors(tensor)
        dataset = self.make_distributed_dataset(dataset, cluster)
        self.assertDatasetProduces(dataset, [tensor])
class DataServiceOpsTest(data_service_test_base.TestBase,
                         parameterized.TestCase):
    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherStop(self):
        """The worker keeps serving elements after the dispatcher goes down."""
        cluster = self.create_cluster(num_workers=1)
        total = 100
        dataset = self.make_distributed_range_dataset(total, cluster)
        iterator = iter(dataset)
        collected = [next(iterator).numpy()]
        cluster.stop_dispatcher()
        # After the dispatcher dies, the worker should continue providing the
        # rest of the dataset's elements.
        collected.extend(next(iterator).numpy() for _ in range(total - 1))
        self.assertEqual(collected, list(range(total)))

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherRestartBeforeReading(self):
        """Restarting the dispatcher before any reads is transparent."""
        cluster = self.create_cluster(num_workers=1)
        total = 100
        dataset = self.make_distributed_range_dataset(total, cluster)
        cluster.restart_dispatcher()
        self.assertDatasetProduces(dataset, list(range(total)))

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherRestartDuringReading(self):
        """A mid-stream dispatcher restart loses no elements."""
        cluster = self.create_cluster(num_workers=1)
        total = 100
        dataset = self.make_distributed_range_dataset(total, cluster)
        iterator = iter(dataset)
        collected = [next(iterator).numpy() for _ in range(total // 2)]
        cluster.restart_dispatcher()
        collected.extend(element.numpy() for element in iterator)

        self.assertEqual(list(range(total)), collected)

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherRestartBetweenIterations(self):
        """The dataset stays readable across a restart between iterations."""
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        # Use num_elements here instead of repeating the literal 100, so the
        # dataset size stays in sync with the assertions below.
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        self.assertDatasetProduces(ds, list(range(num_elements)))
        cluster.restart_dispatcher()
        self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherManyRestarts(self):
        """Datasets created across repeated restarts all remain readable."""
        cluster = self.create_cluster(num_workers=1)
        first_size = 10
        last_size = 15
        sizes = range(first_size, last_size)
        datasets = []
        for size in sizes:
            datasets.append(
                self.make_distributed_range_dataset(size, cluster))
            cluster.restart_dispatcher()
        for dataset, size in zip(datasets, sizes):
            self.assertDatasetProduces(dataset, list(range(size)))

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherAndWorkerRestart(self):
        """Restarting both servers leaves the dataset fully readable, twice."""
        cluster = self.create_cluster(num_workers=1)
        total = 100
        dataset = self.make_distributed_dataset(
            dataset_ops.Dataset.range(total), cluster)
        expected = list(range(total))

        for _ in range(2):
            cluster.restart_dispatcher()
            cluster.restart_worker()
            self.assertDatasetProduces(dataset, expected)

    @combinations.generate(test_base.eager_only_combinations())
    def testStartServersLate(self):
        """The client retries instead of failing when servers start late.

        The dataset is created before the dispatcher and worker are started;
        reads must block and retry until both come up on a background thread.
        """
        try:
            import portpicker  # pylint: disable=g-import-not-at-top
            dispatcher_port = portpicker.pick_unused_port()
        # Narrowed from a bare `except:`, which would also swallow
        # KeyboardInterrupt and SystemExit.
        except Exception:
            raise self.skipTest(
                "Flakes in portpicker library do not represent "
                "TensorFlow errors.")
        cluster = self.create_cluster(num_workers=1,
                                      dispatcher_port=dispatcher_port,
                                      start=False)

        def start_servers():
            # Give the client a head start so the retry path is exercised.
            time.sleep(0.5)
            cluster.start_dispatcher()
            cluster.start_workers()

        start_servers_thread = threading.Thread(target=start_servers,
                                                daemon=True)
        start_servers_thread.start()

        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        results = [elem.numpy() for elem in ds]
        self.assertEqual(list(range(num_elements)), results)
        start_servers_thread.join()

    @combinations.generate(test_base.eager_only_combinations())
    def testAddWorkerMidJob(self):
        """A worker added mid-job replays the dataset, doubling the output."""
        cluster = self.create_cluster(num_workers=1)
        total = 100
        dataset = self.make_distributed_range_dataset(total, cluster)
        iterator = iter(dataset)
        # Consume the first half before growing the cluster.
        collected = [next(iterator).numpy() for _ in range(total // 2)]

        cluster.add_worker()
        # Block until the dispatcher has seen the new worker.
        while cluster.num_registered_workers() < 2:
            time.sleep(10 / 1000)  # 10ms

        collected.extend(element.numpy() for element in iterator)

        self.assertCountEqual(2 * list(range(total)), collected)

    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           combinations.combine(use_same_port=[True, False]),
                           data_service_test_base.all_cluster_configurations())
    )
    def testRestartWorker(self, use_same_port, work_dir, fault_tolerant_mode):
        """After a worker restart, reading resumes from the dataset's start."""
        cluster = self.create_cluster(num_workers=1,
                                      work_dir=work_dir,
                                      fault_tolerant_mode=fault_tolerant_mode)
        total = 100
        dataset = self.make_distributed_range_dataset(total, cluster)
        iterator = iter(dataset)
        # Consume the first half from the original worker.
        for expected in range(total // 2):
            self.assertEqual(expected, next(iterator).numpy())

        # Replace the original worker with a fresh one.
        cluster.restart_worker(use_same_port=use_same_port)

        # Skip over anything prefetched from the old worker before it was
        # stopped; the new worker restarts the sequence at 0.
        while next(iterator).numpy() != 0:
            pass

        # The dataset starts over now that we read from the new worker.
        # TODO(b/157086991): Iterate until end of sequence when we support
        # detecting lost workers.
        for expected in range(1, total // 2):
            self.assertEqual(expected, next(iterator).numpy())

    @combinations.generate(test_base.eager_only_combinations())
    def testChangeProcessingModeAfterRestart(self):
        """Reusing a job name with a different processing mode is rejected."""
        cluster = self.create_cluster(num_workers=1)
        total = 100
        base = dataset_ops.Dataset.range(total)
        first = base.apply(
            data_service_ops.distribute(processing_mode="parallel_epochs",
                                        service=cluster.target,
                                        job_name="test"))
        iterator = iter(first)
        for expected in range(total // 2):
            self.assertEqual(expected, next(iterator).numpy())
        cluster.restart_dispatcher()
        second = base.apply(
            data_service_ops.distribute(processing_mode="distributed_epoch",
                                        service=cluster.target,
                                        job_name="test"))
        with self.assertRaisesOpError(
                "already an existing job with that name "
                "using processing mode <parallel_epochs>"):
            next(iter(second)).numpy()

    @combinations.generate(
        combinations.times(
            test_base.eager_only_combinations(),
            combinations.combine(work_dir=[TMP_WORK_DIR, NO_WORK_DIR])))
    def testDistributeLargeGraphThenRegisterWorker(self, work_dir):
        """A large dataset graph transfers to a worker that joins later."""
        cluster = self.create_cluster(num_workers=0,
                                      work_dir=work_dir,
                                      fault_tolerant_mode=False)
        # 2 * 1000 * 1000 float32s (~8MB) exceed the default OSS gRPC
        # message size limit of 4MB.
        tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
        dataset = dataset_ops.Dataset.from_tensors(tensor)
        dataset = self.make_distributed_dataset(dataset, cluster)
        iterator = iter(dataset)
        cluster.add_worker()
        self.assertAllEqual(next(iterator), tensor)