Example #1
  def testDisablingOwnedIteratorsInTF2(self, distribution, input_type):
    if not tf2.enabled():
      self.skipTest("unsupported test combination")

    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    input_workers = input_lib.InputWorkers(worker_device_pairs)
    dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    if input_type == "dataset":
      dist_dataset = input_lib.get_distributed_dataset(dataset_or_input_fn,
                                                       input_workers,
                                                       distribution)
    else:
      dist_dataset = input_lib.get_distributed_datasets_from_function(
          dataset_or_input_fn, input_workers, [distribute_lib.InputContext()],
          distribution)

    # Default Iterator types in TF2.
    iterator = iter(dist_dataset)
    self.assertIsInstance(iterator, input_lib.DistributedIterator)
    self.assertIsInstance(iterator._iterators[0],
                          input_lib._SingleWorkerOwnedDatasetIterator)

    # Disable creating owned iterators by setting a property on the strategy.
    distribution._enable_legacy_iterators = True
    iterator = iter(dist_dataset)
    self.assertIsInstance(iterator, input_lib.DistributedIteratorV1)
    self.assertIsInstance(iterator._iterators[0],
                          input_lib._SingleWorkerDatasetIterator)
Example #2
 def _experimental_distribute_dataset(self, dataset, options):
   # Note that split_batch_by argument is not passed because it is always 1 in
   # this strategy, and adding it adds unnecessary overhead to the dataset.
   return input_lib.get_distributed_dataset(
       dataset,
       self._input_workers_with_options(options),
       self._container_strategy())
Example #3
 def _experimental_distribute_dataset(self, dataset, options):
     return input_lib.get_distributed_dataset(
         dataset,
         self._input_workers_with_options(options),
         self._container_strategy(),
         num_replicas_in_sync=self._num_replicas_in_sync,
         options=options)
Example #4
 def _experimental_distribute_dataset(self, dataset):
   input_context = self._make_input_context()
   return input_lib.get_distributed_dataset(
       dataset,
       self._input_workers,
       self._container_strategy(),
       split_batch_by=self._num_replicas_in_sync,
       input_context=input_context)
Example #5
 def _experimental_distribute_dataset(self, dataset):
     input_context = self._make_input_context()
     return input_lib.get_distributed_dataset(
         dataset,
         self._input_workers,
         self._container_strategy(),
         split_batch_by=self._num_replicas_in_sync,
         input_context=input_context)
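Note: Examples #4 and #5 thread an input_context into get_distributed_dataset. For reference, below is a minimal sketch of how a tf.distribute.InputContext is typically built and consumed for per-worker sharding; the concrete numbers (four input pipelines, eight replicas, a global batch of 128) are illustrative assumptions, not values taken from these strategies.

import tensorflow as tf

# Illustrative values; a real strategy derives these from its cluster spec.
ctx = tf.distribute.InputContext(
    num_input_pipelines=4,    # e.g. one input pipeline per worker
    input_pipeline_id=0,      # index of this worker's pipeline
    num_replicas_in_sync=8)   # total replicas across the cluster

dataset = tf.data.Dataset.range(1000)
# Shard so each worker reads a disjoint slice of the data.
dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
# Convert the global batch size into this worker's per-replica batch size.
dataset = dataset.batch(ctx.get_per_replica_batch_size(128))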
Example #6
    def _experimental_distribute_dataset(self, dataset, options):
        if options is None or options.experimental_prefetch_to_device:
            self._check_spec(dataset.element_spec)

        return input_lib.get_distributed_dataset(
            dataset,
            self._get_input_workers(options),
            self._container_strategy(),
            split_batch_by=self._num_replicas_in_sync)
Example #7
  def testDatasetV2IterError(self, distribution):
    worker_device_pairs = [("", ["/device:CPU:0"])]
    input_workers = input_lib.InputWorkers(worker_device_pairs)
    dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2)

    dist_dataset = input_lib.get_distributed_dataset(
        dataset_fn(distribute_lib.InputContext()), input_workers, distribution)

    with self.assertRaisesRegexp(RuntimeError,
                                 "or when eager execution is enabled"):
      iter(dist_dataset)
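Note: the RuntimeError this test expects is raised because iter() on a distributed dataset is only supported inside a tf.function or with eager execution enabled. For contrast, here is a minimal sketch of the supported eager path at the public API level; the strategy and device list are placeholder assumptions.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["/cpu:0"])  # placeholder strategy
dist_dataset = strategy.experimental_distribute_dataset(
    tf.data.Dataset.range(10).batch(2))

# Eager execution (the TF2 default) allows plain Python iteration.
for batch in dist_dataset:
  strategy.run(lambda x: x + 1, args=(batch,))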
Example #8
    def testIterableIterator(self, distribution):
        worker_device_pairs = [("", ["/device:CPU:0"])]
        input_workers = input_lib.InputWorkers(worker_device_pairs)

        dataset = dataset_ops.DatasetV2.range(10)
        dist_dataset = input_lib.get_distributed_dataset(
            dataset, input_workers, distribution)

        iterator = iter(dist_dataset)
        for i, element in enumerate(iterator):
            self.assertEqual(i, element.numpy())
Example #9
 def _experimental_distribute_dataset(self, dataset, options):
     if (options and options.experimental_replication_mode
             == distribute_lib.InputReplicationMode.PER_REPLICA):
         raise NotImplementedError("InputReplicationMode.PER_REPLICA "
                                   "is only supported in "
                                   "`distribute_datasets_from_function`.")
     return input_lib.get_distributed_dataset(
         dataset,
         self._input_workers_with_options(options),
         self._container_strategy(),
         num_replicas_in_sync=self._num_replicas_in_sync,
         options=options)
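Note: this variant rejects InputReplicationMode.PER_REPLICA and points callers at distribute_datasets_from_function. Below is a hedged sketch of the call shape that message suggests; the strategy object and the dataset contents are placeholders for illustration.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["/cpu:0"])  # placeholder strategy

options = tf.distribute.InputOptions(
    experimental_replication_mode=tf.distribute.InputReplicationMode.PER_REPLICA)

def dataset_fn(input_context):
  # Under PER_REPLICA replication each replica gets its own dataset instance.
  return tf.data.Dataset.range(10).batch(2)

dist_datasets = strategy.distribute_datasets_from_function(dataset_fn, options)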
Example #10
  def testDatasetV2IterError(self, distribution):
    worker_device_pairs = [("", ["/device:CPU:0"])]
    devices = nest.flatten([ds for _, ds in worker_device_pairs])
    device_map = values.ReplicaDeviceMap(devices)
    input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)
    dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2)

    dist_dataset = input_lib.get_distributed_dataset(
        dataset_fn(distribute_lib.InputContext()), input_workers, distribution)

    with self.assertRaisesRegexp(RuntimeError,
                                 "or when eager execution is enabled"):
      iter(dist_dataset)
Example #11
    def _experimental_distribute_dataset(self, dataset, options):
        input_workers_devices = self._input_workers_with_options()

        # If this DistributedDataset is created outside ClusterCoordinator, i.e.,
        # outside a tf.function, we don't build its underlying datasets until it
        # is passed to ClusterCoordinator.create_per_worker_dataset.
        return input_lib.get_distributed_dataset(
            dataset,
            input_workers_devices,
            self._container_strategy(),
            num_replicas_in_sync=self._num_replicas_in_sync,
            options=options,
            build=ops.inside_function())  # will be built by ClusterCoordinator
Example #12
    def testIterableIterator(self, distribution):
        worker_device_pairs = [("", ["/device:CPU:0"])]
        devices = nest.flatten([ds for _, ds in worker_device_pairs])
        device_map = values.ReplicaDeviceMap(devices)
        input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)

        dataset = dataset_ops.DatasetV2.range(10)
        dist_dataset = input_lib.get_distributed_dataset(
            dataset, input_workers, distribution)

        iterator = iter(dist_dataset)
        for i, element in enumerate(iterator):
            self.assertEqual(i, element.numpy())
Example #13
 def _experimental_distribute_dataset(self, dataset, options):
     # Note that split_batch_by argument is not passed because it is always 1 in
     # this strategy, and adding it adds unnecessary overhead to the dataset.
     if (options and options.experimental_replication_mode
             == distribute_lib.InputReplicationMode.PER_REPLICA):
         raise NotImplementedError(
             "InputReplicationMode.PER_REPLICA "
             "is only supported in  "
             "`experimental_distribute_datasets_from_function`.")
     return input_lib.get_distributed_dataset(
         dataset,
         self._input_workers_with_options(options),
         self._container_strategy(),
         options=options)
Example #14
  def testMultiDeviceIterInitialize(self, distribution):
    worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)

    input_workers = input_lib.InputWorkers(worker_device_pairs)

    dist_dataset = input_lib.get_distributed_dataset(
        dataset_fn(distribute_lib.InputContext()), input_workers, distribution)

    iterator = dataset_ops.make_one_shot_iterator(dist_dataset)

    @def_function.function
    def init_func_for_iter():
      self.evaluate(iterator.initializer)

    init_func_for_iter()
Example #15
    def _experimental_distribute_dataset(self, dataset, options):
        self._assert_used_with_cluster_coordinator()
        if not ops.get_default_graph().building_function:
            raise ValueError(
                "The `experimental_distribute_dataset` method must be called inside "
                "a `tf.function` passed to `create_per_worker_dataset` of "
                "`tf.distribute.experimental.coordinator.ClusterCoordinator`")

        input_workers_devices = self._input_workers_with_options()

        return input_lib.get_distributed_dataset(
            dataset,
            input_workers_devices,
            self._container_strategy(),
            num_replicas_in_sync=self._num_replicas_in_sync,
            options=options)
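Note: the check above reflects how parameter-server training feeds data: the distribution call has to run inside the tf.function that ClusterCoordinator.create_per_worker_dataset executes on each worker. A minimal sketch of that flow follows; the cluster resolver and strategy construction are illustrative and assume a TF_CONFIG-described cluster.

import tensorflow as tf

cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

@tf.function
def per_worker_dataset_fn():
  ds = tf.data.Dataset.range(100).batch(8)
  # The distribution call happens inside the tf.function that the coordinator
  # ships to each worker, which is what the ValueError above enforces.
  return strategy.experimental_distribute_dataset(ds)

per_worker_dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)
per_worker_iterator = iter(per_worker_dataset)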
Example #16
 def _experimental_distribute_dataset(self, dataset):
   # Note that split_batch_by argument is not passed because it is always 1 in
   # this strategy, and adding it adds unnecessary overhead to the dataset.
   return input_lib.get_distributed_dataset(dataset, self._input_workers,
                                            self._container_strategy())
Example #17
 def _experimental_distribute_dataset(self, dataset):
     return input_lib.get_distributed_dataset(
         dataset, self._input_workers, self._container_strategy(),
         split_batch_by=self._num_replicas_in_sync)
Example #18
 def _experimental_distribute_dataset(self, dataset, options):
     return input_lib.get_distributed_dataset(
         dataset,
         self._get_input_workers(options),
         self._container_strategy(),
         split_batch_by=self._num_replicas_in_sync)
Example #19
 def _experimental_distribute_dataset(self, dataset):
   return input_lib.get_distributed_dataset(
       dataset, self._input_workers, self._container_strategy(),
       split_batch_by=self._num_replicas_in_sync)
Example #20
 def _experimental_distribute_dataset(self, dataset):
     return input_lib.get_distributed_dataset(dataset, self._input_workers,
                                              self._container_strategy())
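Note: taken together, these call sites suggest that input_lib.get_distributed_dataset takes a dataset, an InputWorkers mapping, and the owning strategy, plus optional arguments whose name shifted from split_batch_by to num_replicas_in_sync across revisions. The stub below is only an inference from the usages above, not the library's actual definition.

def get_distributed_dataset(dataset,            # tf.data.Dataset to distribute
                            input_workers,      # input_lib.InputWorkers instance
                            strategy,           # owning tf.distribute.Strategy
                            num_replicas_in_sync=None,  # older name: split_batch_by
                            input_context=None,         # per-worker sharding info
                            options=None,               # tf.distribute.InputOptions
                            build=True):                # deferred build, see Example #11
  ...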