def testDistributeInvalidProcessingMode(self):
     ds = dataset_ops.Dataset.range(10)
     with self.assertRaisesRegex(ValueError,
                                 "invalid is not a valid processing mode"):
         ds = ds.apply(
             data_service_ops.distribute(processing_mode="invalid",
                                         service="grpc://localhost:5000"))
 def testDistributeEmptyAddress(self):
     ds = dataset_ops.Dataset.range(10)
     with self.assertRaisesWithLiteralMatch(ValueError,
                                            "service must not be empty"):
         ds = ds.apply(
             data_service_ops.distribute(processing_mode="parallel_epochs",
                                         service=""))
Example #3
 def testDistributeDistributedEpochTensorSlices(self):
     cluster = self.create_cluster(2)
     vals = [5, 1, 2, 4]
     ds = dataset_ops.Dataset.from_tensor_slices(vals)
     ds = ds.apply(
         data_service_ops.distribute(processing_mode="distributed_epoch",
                                     service=cluster.target))
     self.assertDatasetProduces(ds, vals, assert_items_equal=True)
 def testDistributeInvalidProcessingMode(self):
   ds = dataset_ops.Dataset.range(10)
   with self.assertRaisesRegex(
       ValueError, "should be a ShardingPolicy, `\"parallel_epochs\"`, or "
       "`\"distributed_epoch\"`. Got 'invalid'."):
     ds = ds.apply(
         data_service_ops.distribute(
             processing_mode="invalid", service="grpc://localhost:5000"))
 def testNonStringJobNameDistribute(self):
     cluster = data_service_test_base.TestCluster(num_workers=1)
     with self.assertRaisesRegex(ValueError, "job_name must be a string"):
         dataset_ops.Dataset.range(10).apply(
             data_service_ops.distribute(
                 processing_mode="parallel_epochs",
                 service=cluster.dispatcher.target,
                 job_name=constant_op.constant("foo")))
Example #6
 def testEmptyJobNameDistribute(self):
     cluster = data_service_test_base.TestCluster(num_workers=1)
     with self.assertRaisesRegex(ValueError,
                                 "`job_name` must not be empty"):
         dataset_ops.Dataset.range(10).apply(
             data_service_ops.distribute(processing_mode="parallel_epochs",
                                         service=cluster.dispatcher.target,
                                         job_name=""))
Example #7
 def testDistributeDistributedEpochTensorSlices(self):
     dispatcher, workers = self.start_cluster(2)  # to avoid gcing workers, pylint: disable=unused-variable
     vals = [5, 1, 2, 4]
     ds = dataset_ops.Dataset.from_tensor_slices(vals)
     ds = ds.apply(
         data_service_ops.distribute(processing_mode="distributed_epoch",
                                     service=dispatcher.target))
     self.assertDatasetProduces(ds, vals, assert_items_equal=True)
Example #8
 def testTfDataService(self):
     ds = dataset_ops.Dataset.range(10)
     ds = ds.apply(
         data_service_ops.distribute("parallel_epochs", "grpc://foo:0"))
     ops = traverse.obtain_capture_by_value_ops(ds)
     self.assertContainsSubset(
         ["RangeDataset", "DataServiceDatasetV2", "DummyIterationCounter"],
         set(x.name for x in ops))
Example #9
 def testDistributeExplicitProtocol(self):
     cluster = data_service_test_base.TestCluster(num_workers=1)
     ds = dataset_ops.Dataset.range(10)
     ds = ds.apply(
         data_service_ops.distribute(processing_mode="parallel_epochs",
                                     service="grpc://" +
                                     cluster.dispatcher_address()))
     self.assertDatasetProduces(ds, list(range(10)))
Example #10
 def testDistributeDistributedEpoch(self):
   cluster = self.create_cluster(num_workers=2)
   num_elements = 100
   ds = dataset_ops.Dataset.range(num_elements)
   ds = ds.apply(
       data_service_ops.distribute(
           processing_mode="distributed_epoch", service=cluster.target))
   self.assertDatasetProduces(
       ds, list(range(num_elements)), assert_items_equal=True)
Example #11
 def testMultipleEpochs(self):
     service = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(3)
     ds = ds.apply(data_service_ops.distribute(service))
     for _ in range(10):
         token = data_service_ops.create_job(
             ds, processing_mode="parallel_epochs")
         it = data_service_ops.create_iterator(ds, token)
         self.assertEqual(list(range(3)), [t.numpy() for t in it])
 def testDistributeDistributedEpoch(self):
   dispatcher, workers = self.start_cluster(2)  # to avoid gcing workers, pylint: disable=unused-variable
   num_elements = 100
   ds = dataset_ops.Dataset.range(num_elements)
   ds = ds.apply(
       data_service_ops.distribute(
           processing_mode="distributed_epoch", service=dispatcher.target))
   self.assertDatasetProduces(
       ds, list(range(num_elements)), assert_items_equal=True)
Example #13
 def testDistributeBasic(self):
     num_elements = 10
     service = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(num_elements)
     ds = ds.apply(data_service_ops.distribute(service))
     token = data_service_ops.create_job(ds,
                                         processing_mode="parallel_epochs")
     it = data_service_ops.create_iterator(ds, token)
     results = [t.numpy() for t in it]
     self.assertEqual(list(range(num_elements)), results)
Example #14
 def testChangeProcessingModeAfterRestart(self):
     dispatcher, workers = self.start_cluster(1)  # to avoid gcing workers, pylint: disable=unused-variable
     num_elements = 100
     range_dataset = dataset_ops.Dataset.range(num_elements)
     ds = range_dataset.apply(
         data_service_ops.distribute(processing_mode="parallel_epochs",
                                     service=dispatcher.target,
                                     job_name="test"))
     iterator = iter(ds)
     for i in range(num_elements // 2):
         self.assertEqual(i, next(iterator).numpy())
     dispatcher = self.restart_dispatcher(dispatcher)
     ds = range_dataset.apply(
         data_service_ops.distribute(processing_mode="distributed_epoch",
                                     service=dispatcher.target,
                                     job_name="test"))
     with self.assertRaisesOpError(
             "already an existing job with that name "
             "using processing mode <parallel_epochs>"):
         next(iter(ds)).numpy()
Example #15
 def testDistributeInvalidProtocol(self):
     cluster = data_service_test_base.TestCluster(num_workers=1)
     ds = dataset_ops.Dataset.range(10)
     with self.assertRaisesRegex(
             errors.NotFoundError,
             "No credentials factory has been registered for protocol grp"):
         ds = ds.apply(
             data_service_ops.distribute(processing_mode="parallel_epochs",
                                         service="grp://" +
                                         cluster.dispatcher_address()))
         self.getDatasetOutput(ds)
Example #16
 def testDistributeDistributedEpochShuffleAndRepeat(self):
   cluster = self.create_cluster(2)
   num_repeats = 5
   num_elements = 20
   ds = dataset_ops.Dataset.range(num_elements).shuffle(num_elements).repeat(
       num_repeats)
   ds = ds.apply(
       data_service_ops.distribute(
           processing_mode="distributed_epoch", service=cluster.target))
   self.assertDatasetProduces(
       ds, num_repeats * list(range(num_elements)), assert_items_equal=True)
Example #17
 def testTfDataService(self):
     ds = dataset_ops.Dataset.range(10)
     ds = ds.apply(
         data_service_ops.distribute("parallel_epochs", "grpc://foo:0"))
     ops = traverse.obtain_capture_by_value_ops(ds)
     data_service_dataset_op = ("DataServiceDatasetV4"
                                if compat.forward_compatible(2022, 8, 31)
                                else "DataServiceDatasetV3")
     self.assertContainsSubset(
         ["RangeDataset", data_service_dataset_op, "DummyIterationCounter"],
         set(x.name for x in ops))
 def testChangeProcessingModeAfterRestart(self):
     cluster = self.create_cluster(num_workers=1)
     num_elements = 100
     range_dataset = dataset_ops.Dataset.range(num_elements)
     ds = range_dataset.apply(
         data_service_ops.distribute(processing_mode="parallel_epochs",
                                     service=cluster.target,
                                     job_name="test"))
     iterator = iter(ds)
     for i in range(num_elements // 2):
         self.assertEqual(i, next(iterator).numpy())
     cluster.restart_dispatcher()
     ds = range_dataset.apply(
         data_service_ops.distribute(processing_mode="distributed_epoch",
                                     service=cluster.target,
                                     job_name="test"))
     with self.assertRaisesOpError(
             "already an existing job with that name "
             "using processing mode <parallel_epochs>"):
         next(iter(ds)).numpy()
Example #19
 def testMultiWorker(self):
     num_workers = 3
     num_elements = 10
     service = self.create_cluster(num_workers)
     ds = dataset_ops.Dataset.range(num_elements)
     ds = ds.apply(data_service_ops.distribute(service))
     token = data_service_ops.create_job(ds,
                                         processing_mode="parallel_epochs")
     iterator = data_service_ops.create_iterator(ds, token)
     results = [elem.numpy() for elem in iterator]
     self.assertCountEqual(num_workers * list(range(num_elements)), results)
Example #20
    def __iter__(self):
        datasets: List[Tuple[int, dataset_ops.DatasetV2]] = []

        # Start with the batched dataset.
        local_dataset = self._batched_dataset

        # If a replica is split over multiple clients then each batch needs to be
        # repeated before distribution as many times as there are clients
        # corresponding to that replica.
        if self._batch_dim is not None:
            local_dataset = self._repeat_batch(local_dataset,
                                               self._num_clients_per_replica)

        # Apply distribution here (if specified) so all remaining transformations
        # are executed locally.
        if self._tf_data_service_config is not None:
            if self._batch_dim is None:
                sharding_policy = data_service_ops.ShardingPolicy.OFF
            else:
                sharding_policy = data_service_ops.ShardingPolicy.FILE_OR_DATA

            local_dataset = local_dataset.apply(
                data_service_ops.distribute(
                    processing_mode=sharding_policy,
                    service=self._tf_data_service_config.dispatcher_address,
                    job_name=
                    f'{self._tf_data_service_config.job_name}_{api.client_id()}',
                    target_workers='LOCAL'))

        for local_replica_idx, replica_id in enumerate(
                self._local_replica_ids):
            # Select the shard for the corresponding replica.
            dataset = local_dataset.shard(self._num_local_replicas,
                                          local_replica_idx)

            # Repeat each batch for each local device in the replica.
            dataset = self._repeat_batch(dataset,
                                         self._num_local_devices_per_replica)

            # Slice each shard further for all non-batch dim shards. If there is no
            # non-batch dim sharding, this slice is essentially a no-op.
            dataset = self._partition(dataset)

            # Apply prefetch as the last step. Since each batch is repeated, the
            # number of elements to prefetch has to be scaled by the same factor.
            if self._prefetch is not None:
                dataset = dataset.prefetch(self._prefetch *
                                           self._num_local_devices_per_replica)

            datasets.append((replica_id, dataset))

        return _DTensorIterator(datasets, self._element_spec, self._layouts,
                                self._num_local_devices_per_replica)
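
The distribution step inside __iter__ reduces to a single distribute() call that selects a ShardingPolicy and restricts reads to co-located workers. A stripped-down sketch of that call on a plain dataset (the dispatcher address and job name below are placeholders, not values from the original code):

# Minimal sketch of the distribute() call made in __iter__ above: sharding is
# disabled (the no-batch-dim case) and reads are limited to workers co-located
# with this client via target_workers="LOCAL".
ds = dataset_ops.Dataset.range(8)
ds = ds.apply(
    data_service_ops.distribute(
        processing_mode=data_service_ops.ShardingPolicy.OFF,
        service="grpc://localhost:5000",  # placeholder dispatcher address
        job_name="dtensor_input_job_0",  # placeholder job name
        target_workers="LOCAL"))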
Example #21
 def testRequiresInfiniteDataset(self, range_):
     cluster = self._create_cluster(num_workers=1)
     dataset = dataset_ops.Dataset.range(range_).map(lambda x: x + 1)
     with self.assertRaisesRegex(
             errors.InvalidArgumentError,
             "Cross-trainer caching requires the input dataset to be infinite."
     ):
         dataset = dataset.apply(
             data_service_ops.distribute(
                 processing_mode=data_service_ops.ShardingPolicy.OFF,
                 service=cluster.dispatcher.target,
                 job_name="job_name",
                 cross_trainer_cache=data_service_ops.CrossTrainerCache(
                     trainer_id="Trainer ID")))
         self.getDatasetOutput(dataset)
Example #22
 def f():
     ds = dataset_ops.Dataset.range(num_elements)
     ds = ds.apply(data_service_ops.distribute(service))
     token = data_service_ops.create_job(
         ds, processing_mode="parallel_epochs")
     it = data_service_ops.create_iterator(ds, token)
     result = tensor_array_ops.TensorArray(dtypes.int64,
                                           size=num_workers *
                                           num_elements,
                                           dynamic_size=True)
     i = 0
     for elem in it:
         result = result.write(i, elem)
         i += 1
     return result.stack()
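
The helper f only defines the traversal; a hedged sketch of a driver (assuming num_workers, num_elements, and service are defined as in the surrounding test, and that def_function is used to trace f as a graph function) might look like:

# Hypothetical driver for f(): trace it with tf.function and verify that the
# workers together produced num_workers copies of the range dataset
# (element order across workers is not guaranteed).
from tensorflow.python.eager import def_function

results = def_function.function(f)().numpy().tolist()
assert sorted(results) == sorted(num_workers * list(range(num_elements)))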
Example #23
    def run_stateful(self, external_state_policy):
        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements).map(
            lambda _: random_ops.random_uniform(()))

        options = dataset_ops.Options()
        options.experimental_external_state_policy = external_state_policy
        ds = ds.with_options(options)

        service = self.create_cluster(3)
        ds = ds.apply(data_service_ops.distribute(service))
        token = data_service_ops.create_job(ds,
                                            processing_mode="parallel_epochs")
        iterator = data_service_ops.create_iterator(ds, token)
        next(iterator)
Example #24
    def testConcurrentEpoch(self):
        num_elements = 10
        num_datasets = 3
        service = self.create_cluster(1)
        iterators = []
        results = []
        for _ in range(num_datasets):
            ds = dataset_ops.Dataset.range(num_elements)
            ds = ds.apply(data_service_ops.distribute(service))
            token = data_service_ops.create_job(
                ds, processing_mode="parallel_epochs")
            it = data_service_ops.create_iterator(ds, token)
            iterators.append(it)
            results.append([])

        for _ in range(num_elements):
            for dataset_ind in range(num_datasets):
                result = next(iterators[dataset_ind]).numpy()
                results[dataset_ind].append(result)
        for result in results:
            self.assertEqual(list(range(num_elements)), result)
Example #25
    def testSharedEpoch(self):
        num_elements = 10
        num_iterators = 3
        service = self.create_cluster(1)
        ds = dataset_ops.Dataset.range(num_elements)
        ds = ds.apply(data_service_ops.distribute(service))
        result = []
        iterators = []
        token = data_service_ops.create_job(ds,
                                            processing_mode="parallel_epochs")
        for _ in range(num_iterators):
            iterators.append(data_service_ops.create_iterator(ds, token))

        # Alternate reading between the iterators.
        for _ in range(2):
            for it in iterators:
                result.append(next(it).numpy())

        # Drain the rest of the elements.
        for it in iterators:
            for elem in it:
                result.append(elem.numpy())

        self.assertCountEqual(list(range(num_elements)), result)
 def testDistributeNonStringAddresses(self):
     ds = dataset_ops.Dataset.range(10)
     with self.assertRaisesRegex(ValueError, "service must be a string"):
         ds = ds.apply(
             data_service_ops.distribute(processing_mode="parallel_epochs",
                                         service=1))
Example #27
 def interleave_fn(_):
     ds = dataset_ops.Dataset.range(2)
     ds = ds.apply(data_service_ops.distribute(service))
     return ds
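
interleave_fn only builds the per-element dataset; a minimal usage sketch (the outer range size and cycle_length are illustrative, and service is assumed to point at a running dispatcher as in the surrounding test) would hand it to Dataset.interleave:

# Hypothetical usage of interleave_fn: each outer element opens its own
# distributed Dataset.range(2), and the resulting elements are flattened
# into a single stream.
ds = dataset_ops.Dataset.range(5)
ds = ds.interleave(interleave_fn, cycle_length=2)
results = [elem.numpy() for elem in ds]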