Example 1
 def testOneLocalWorker(self):
   cluster = multi_process_cluster.MultiProcessCluster(
       num_local_workers=1, num_remote_workers=5)
   num_elements = 10
   ds = self.make_distributed_range_dataset(
       num_elements, cluster, target_workers="local")
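   # Only the single local worker is read, so each element appears once.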
   self.assertDatasetProduces(ds, list(range(num_elements)))
Example 2
    def testAddRemoteWorkersMidJob(self, num_local_workers,
                                   num_remote_workers):
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=num_local_workers,
            num_remote_workers=num_remote_workers,
            worker_tags=[_COLOCATED_WORKER_TAG])

        num_elements = 300
        dataset = self.make_distributed_range_dataset(num_elements, cluster)
        get_next = self.getNext(dataset)
        results = [self.evaluate(get_next()) for _ in range(100)]

        # Of the four workers added below, only the two untagged (non-TPU)
        # workers will be read, in addition to the local workers.
        cluster.start_remote_worker(worker_tags=None)
        cluster.start_remote_worker(worker_tags=[_COLOCATED_WORKER_TAG])
        cluster.start_remote_worker(worker_tags=None)
        cluster.start_remote_worker(worker_tags=[_COLOCATED_WORKER_TAG])
        expect_num_workers_to_read = num_local_workers + 2

        # Wait for the new workers to register with the dispatcher.
        while cluster._dispatcher._num_workers() < (num_local_workers +
                                                    num_remote_workers + 4):
            time.sleep(10 / 1000)  # 10ms

        results += self.getIteratorOutput(get_next)
        self.assertCountEqual(
            results, expect_num_workers_to_read * list(range(num_elements)))
Example 3
 def testInvalidTag(self):
     with self.assertRaisesRegex(RuntimeError,
                                 "Worker tags cannot be empty."):
         _ = multi_process_cluster.MultiProcessCluster(
             num_local_workers=1,
             num_remote_workers=3,
             worker_tags=["", _COLOCATED_WORKER_TAG])
Example 4
    def testAnonymousJobWithDifferentTargetWorkers(self):
        num_local_workers, num_remote_workers = (3, 3)
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers, num_remote_workers)
        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        datasets = {
            target_workers:
            self.make_distributed_dataset(ds,
                                          cluster,
                                          target_workers=target_workers)
            for target_workers in ["AUTO", "ANY", "LOCAL"]
        }

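        # "AUTO" and "ANY" read from every worker, while "LOCAL" reads only
        # from the local workers.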
        num_workers = num_local_workers + num_remote_workers
        self.assertDatasetProduces(datasets["AUTO"],
                                   num_workers * list(range(num_elements)),
                                   assert_items_equal=True)
        self.assertDatasetProduces(datasets["ANY"],
                                   num_workers * list(range(num_elements)),
                                   assert_items_equal=True)
        self.assertDatasetProduces(datasets["LOCAL"],
                                   num_local_workers *
                                   list(range(num_elements)),
                                   assert_items_equal=True)
Example 5
 def testEmptyDataset(self, num_local_workers, num_remote_workers):
   cluster = multi_process_cluster.MultiProcessCluster(
       num_local_workers=num_local_workers,
       num_remote_workers=num_remote_workers)
   num_elements = 0
   ds = self.make_distributed_range_dataset(
       num_elements, cluster, target_workers="LOCAL")
   self.assertDatasetProduces(ds, [])
Example 6
 def testRepeatDistributedDataset(self, num_remote_workers, job_name):
   num_local_workers = 1
   cluster = multi_process_cluster.MultiProcessCluster(
       num_local_workers=num_local_workers,
       num_remote_workers=num_remote_workers)
   dataset = self.make_distributed_range_dataset(
       10, cluster, job_name=job_name, target_workers="LOCAL")
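   # repeat() applies to the distributed dataset, so the local worker's
   # output is produced three times.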
   dataset = dataset.repeat(3)
   self.assertDatasetProduces(dataset, list(range(10)) * 3)
Example 7
 def testReadFromLocalWorker(self, num_remote_workers):
     cluster = multi_process_cluster.MultiProcessCluster(
         num_local_workers=1,
         num_remote_workers=num_remote_workers,
         worker_tags=[_COLOCATED_WORKER_TAG])
     num_elements = 100
     dataset = self.make_distributed_range_dataset(num_elements, cluster)
     # Only reads from the local worker.
     self.assertDatasetProduces(dataset, list(range(num_elements)))
Example 8
 def testNoLocalWorker(self):
   cluster = multi_process_cluster.MultiProcessCluster(
       num_local_workers=0, num_remote_workers=3)
   num_elements = 10
   ds = self.make_distributed_range_dataset(
       num_elements, cluster, target_workers="LOCAL")
   with self.assertRaisesRegex(errors.InvalidArgumentError,
                               "no local worker is found"):
     get_next = self.getNext(ds)
     self.evaluate(get_next())
Example 9
 def testCoordinatedRead(self, num_local_workers, num_remote_workers):
     cluster = multi_process_cluster.MultiProcessCluster(
         num_local_workers=num_local_workers,
         num_remote_workers=num_remote_workers,
         worker_tags=[_COLOCATED_WORKER_TAG])
     num_consumers = 4
     dataset = self.make_coordinated_read_dataset(cluster, num_consumers)
     get_next = self.getNext(dataset)
     results = [self.evaluate(get_next()) for _ in range(200)]
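     # Coordinated reads serve the consumers in rounds; verify the results
     # form consistent groups of num_consumers elements.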
     self.checkCoordinatedReadGroups(results, num_consumers)
Example 10
 def testLocalWorkers(self, num_local_workers, num_remote_workers):
   cluster = multi_process_cluster.MultiProcessCluster(
       num_local_workers=num_local_workers,
       num_remote_workers=num_remote_workers)
   num_elements = 10
   ds = self.make_distributed_range_dataset(
       num_elements, cluster, target_workers="LOCAL")
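   # Each local worker produces a full copy of the range; remote workers are
   # not read.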
   self.assertDatasetProduces(
       ds,
       num_local_workers * list(range(num_elements)),
       assert_items_equal=True)
Example 11
 def testUnusedTags(self):
     cluster = multi_process_cluster.MultiProcessCluster(
         num_local_workers=1,
         num_remote_workers=3,
         worker_tags=["Unused tag 1", "Unused tag 2", "Unused tag 3"])
     num_elements = 100
     dataset = self.make_distributed_range_dataset(num_elements, cluster)
     # The tags have no effect; the tf.data service reads from all workers.
     self.assertDatasetProduces(dataset,
                                4 * list(range(num_elements)),
                                assert_items_equal=True)
Example 12
  def testNonLocalRead(self, num_local_workers, num_remote_workers):
    """This test ensures the remote workers are running and producing data."""

    cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=num_local_workers,
        num_remote_workers=num_remote_workers)
    num_elements = 10
    ds = self.make_distributed_range_dataset(num_elements, cluster)
    num_workers = num_local_workers + num_remote_workers
    self.assertDatasetProduces(
        ds, num_workers * list(range(num_elements)), assert_items_equal=True)
Example 13
 def testDynamicSharding(self, num_local_workers, num_remote_workers):
   cluster = multi_process_cluster.MultiProcessCluster(
       num_local_workers=num_local_workers,
       num_remote_workers=num_remote_workers)
   num_elements = 100
   ds = self.make_distributed_range_dataset(
       num_elements,
       cluster,
       processing_mode=data_service_ops.ShardingPolicy.DYNAMIC,
       target_workers="LOCAL")
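   # With dynamic sharding, each element is assigned to exactly one worker,
   # so the full range is produced exactly once.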
   self.assertDatasetProduces(
       ds, list(range(num_elements)), assert_items_equal=True)
Example 14
 def testCluster(self, num_local_workers, num_remote_workers):
     cluster = multi_process_cluster.MultiProcessCluster(
         num_local_workers=num_local_workers,
         num_remote_workers=num_remote_workers)
     num_elements = 10
     num_workers = num_local_workers + num_remote_workers
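      # A cluster without workers has nothing to read from; skip this case.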
     if num_workers == 0:
         return
     dataset = self.make_distributed_range_dataset(num_elements, cluster)
     self.assertDatasetProduces(dataset,
                                num_workers * list(range(num_elements)),
                                assert_items_equal=True)
Example 15
  def testNoLocalWorker(self):
    cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=0, num_remote_workers=3)
    num_elements = 10
    ds = self.make_distributed_range_dataset(
        num_elements, cluster, target_workers="LOCAL")

    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "Local reads require local tf.data workers, but no local worker is "
        "found."):
      self.getDatasetOutput(ds)
Example 16
    def testMultipleEpochs(self, num_remote_workers):
        num_local_workers = 1
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=num_local_workers,
            num_remote_workers=num_remote_workers)

        num_epochs, num_steps = 5, 5
        dataset = self._make_distributed_infinite_range_dataset(cluster)
        for _ in range(num_epochs):
            # For each iteration, the previous iterator is garbage collected.
            get_next = self.getNext(dataset)
            for i in range(num_steps):
                self.assertEqual(self.evaluate(get_next()), i)
Example 17
 def testRepeatedDataset(self, num_local_workers, num_remote_workers):
     cluster = multi_process_cluster.MultiProcessCluster(
         num_local_workers=num_local_workers,
         num_remote_workers=num_remote_workers)
     num_elements = 10
     num_repetitions = 5
     ds = self.make_distributed_range_dataset(num_elements,
                                              cluster,
                                              target_workers="LOCAL")
     ds = ds.repeat(num_repetitions)
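     # Every local worker yields the full range once per repetition.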
     self.assertDatasetProduces(ds,
                                expected_output=num_local_workers *
                                num_repetitions * list(range(num_elements)),
                                assert_items_equal=True)
Example 18
 def testCoordinatedRead(self):
     cluster = multi_process_cluster.MultiProcessCluster(
         num_local_workers=3, num_remote_workers=3)
     ds = dataset_ops.Dataset.range(10).repeat()
     ds = self.make_distributed_dataset(ds,
                                        cluster,
                                        job_name="test_job",
                                        consumer_index=0,
                                        num_consumers=3,
                                        target_workers="LOCAL")
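     # Coordinated reads cannot target only local workers.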
     with self.assertRaisesRegex(
             errors.InvalidArgumentError,
             "Coordinated reads require non-local workers"):
         self.getDatasetOutput(ds)
Example 19
    def testReadFromLocalAndNonTpuWorkers(self, num_local_workers,
                                          num_remote_workers):
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=num_local_workers,
            num_remote_workers=num_remote_workers,
            worker_tags=[_COLOCATED_WORKER_TAG])
        cluster.start_remote_worker(worker_tags=None)

        num_elements = 100
        dataset = self.make_distributed_range_dataset(num_elements, cluster)
        # Reads from the local workers and the one non-colocated worker.
        self.assertDatasetProduces(dataset, (num_local_workers + 1) *
                                   list(range(num_elements)),
                                   assert_items_equal=True)
Example 20
 def testNoLocalWorkers(self):
     cluster = multi_process_cluster.MultiProcessCluster(
         num_local_workers=0, num_remote_workers=3)
     dataset = dataset_ops.Dataset.list_files(self._filenames,
                                              shuffle=False)
     dataset = dataset.flat_map(readers.TFRecordDataset)
     dataset = self.make_distributed_dataset(
         dataset,
         cluster=cluster,
         processing_mode=ShardingPolicy.FILE_OR_DATA)
     with self.assertRaisesRegex(
             errors.InvalidArgumentError,
             "Local reads or static sharding require local tf.data workers"
     ):
         self.getDatasetOutput(dataset)
Example 21
    def testReadFromLocalWorker_StaticSharding(self):
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=1,
            num_remote_workers=3,
            worker_addresses=["localhost:%port%"] * 5,
            worker_tags=[_COLOCATED_WORKER_TAG])
        cluster.start_remote_worker(worker_tags=None)

        num_elements = 100
        dataset = self.make_distributed_range_dataset(
            num_elements,
            cluster,
            processing_mode=data_service_ops.ShardingPolicy.FILE_OR_DATA)

        # Static sharding will only read from the local worker.
        self.assertDatasetProduces(dataset, list(range(0, num_elements, 5)))
Example 22
    def testReadFromLocalAndNonTpuWorkers_DynamicSharding(
            self, num_local_workers, num_remote_workers):
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=num_local_workers,
            num_remote_workers=num_remote_workers,
            worker_tags=[_COLOCATED_WORKER_TAG])
        cluster.start_remote_worker(worker_tags=None)

        num_elements = 100
        dataset = self.make_distributed_range_dataset(
            num_elements,
            cluster,
            processing_mode=data_service_ops.ShardingPolicy.DYNAMIC)
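        # Dynamic sharding assigns each element to exactly one worker, so
        # every element appears once in total.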
        self.assertDatasetProduces(dataset,
                                   list(range(num_elements)),
                                   assert_items_equal=True)
Example 23
  def testInconsistentTargetWorkers(self):
    cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=3, num_remote_workers=3)
    ds = dataset_ops.Dataset.range(10)
    datasets = [
        self.make_distributed_dataset(
            ds, cluster, job_name="test_job", target_workers=target_workers)
        for target_workers in ["AUTO", "ANY", "LOCAL"]
    ]

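    # Reusing the job with a different target_workers setting should fail.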
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "but found an existing job with different parameters: "
        "Existing target workers: <AUTO>"):
      for dataset in datasets:
        self.getDatasetOutput(dataset)
Example 24
  def testPreferLocalRead(self):
    cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=1, num_remote_workers=0)
    num_elements = 100
    dataset = self.make_distributed_range_dataset(num_elements, cluster)
    get_next = self.getNext(dataset)
    self.assertEqual(self.evaluate(get_next()), 0)

    for i in range(1, 4):
      cluster.start_remote_worker()
      # Waits for the new worker to register with the dispatcher.
      while cluster._dispatcher._num_workers() < i + 1:
        time.sleep(10 / 1000)  # 10ms
      # Prefers reading from the local worker.
      self.assertEqual(self.evaluate(get_next()), i)
    self.assertCountEqual(
        self.getIteratorOutput(get_next),
        list(range(4, num_elements)) + 3 * list(range(num_elements)))
Example 25
    def testMultipleEpochs_DispatcherRestart(self, num_remote_workers):
        num_local_workers = 1
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=num_local_workers,
            num_remote_workers=num_remote_workers)

        num_steps = 10
        dataset = self._make_distributed_infinite_range_dataset(
            cluster, job_name="shared_job_name")
        get_next = self.getNext(dataset)
        for i in range(num_steps):
            self.assertEqual(self.evaluate(get_next()), i)

        # Verifies the worker re-creates the task after the iterator is deleted and
        # the dispatcher restarts.
        del get_next
        cluster.restart_dispatcher()

        get_next = self.getNext(dataset)
        for i in range(num_steps):
            self.assertEqual(self.evaluate(get_next()), i)
Example 26
    def testReadFromDeletedTask(self, num_remote_workers):
        num_local_workers = 1
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=num_local_workers,
            num_remote_workers=num_remote_workers)

        num_steps = 10
        dataset = self._make_distributed_infinite_range_dataset(
            cluster, job_name="shared_job_name")
        get_next = self.getNext(dataset)
        for i in range(num_steps):
            self.assertEqual(self.evaluate(get_next()), i)

        # Re-creating the dataset resets the iterator index, so the second iterator
        # reads from the same task as the first, which has been deleted.
        dataset = self._make_distributed_infinite_range_dataset(
            cluster, job_name="shared_job_name")
        get_next = self.getNext(dataset)
        with self.assertRaisesRegex(errors.FailedPreconditionError,
                                    "which has been deleted."):
            _ = self.evaluate(get_next())
Example 27
  def testMultipleConsumers(self):
    num_local_workers, num_remote_workers = 1, 3
    cluster = multi_process_cluster.MultiProcessCluster(
        num_local_workers=num_local_workers,
        num_remote_workers=num_remote_workers)
    num_elements = 300
    num_consumers = 8
    iterators = []
    for _ in range(num_consumers):
      dataset = self.make_distributed_range_dataset(
          num_elements, cluster, job_name="shared_job")
      iterators.append(self.getNext(dataset))

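    # The iterators share one job, so together they consume each worker's
    # copy of the range exactly once.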
    results = []
    for _ in range(10):
      for it in iterators:
        results.append(self.evaluate(it()))
    for it in iterators:
      results.extend(self.getIteratorOutput(it))

    self.assertCountEqual(results, (num_local_workers + num_remote_workers) *
                          list(range(num_elements)))