Code example #1
 def distribute_dataset(self, dataset_fn):
   worker_devices = [
       (self.get_host(hid), [self.get_host_cpu_device(hid)])
       for hid in range(self.num_hosts)
   ]
   return values.MultiWorkerDataset(
       functools.partial(self._call_dataset_fn, dataset_fn), worker_devices)
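For orientation, a minimal sketch of the worker_devices structure the comprehension above produces, assuming a hypothetical two-host job (the job/task strings are illustrative, not taken from the snippet):

# Hypothetical result of the comprehension above for num_hosts == 2:
# one (host, [host CPU device]) pair per worker.
worker_devices = [
    ("/job:worker/task:0", ["/job:worker/task:0/device:CPU:0"]),
    ("/job:worker/task:1", ["/job:worker/task:1/device:CPU:0"]),
]

The test examples further down use the same pairing in dict form, keyed by the worker device string.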
Code example #2
 def distribute_dataset(self, dataset_fn):
   if self._cluster_spec:
     return values.MultiWorkerDataset(
         partial(self._call_dataset_fn, dataset_fn), self._worker_devices,
         auto_shard=self._auto_shard_dataset)
   else:
     return values.PerReplicaDataset(
         self._call_dataset_fn(dataset_fn), self._devices)
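The MultiWorkerDataset branch above is taken only when the strategy was configured with a cluster spec; a minimal sketch of such a spec, using the standard tf.train.ClusterSpec API with made-up worker addresses:

import tensorflow as tf

# With a cluster spec like this configured on the strategy,
# distribute_dataset returns a MultiWorkerDataset; without one it falls
# back to a single-worker PerReplicaDataset.
cluster_spec = tf.train.ClusterSpec({
    "worker": ["worker0.example.com:2222", "worker1.example.com:2222"],
})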
Code example #3
 def distribute_dataset(self, dataset_fn):
   if self._cluster_spec:
     return values.MultiWorkerDataset(
         partial(self._call_dataset_fn, dataset_fn), self._worker_devices,
         self._prefetch_on_device, self._auto_shard_dataset)
   else:
     return values.PerDeviceDataset(
         self._call_dataset_fn(dataset_fn), self._devices,
         self._prefetch_on_device)
Code example #4
 def distribute_dataset(self, dataset_fn):
   if self._cluster_spec:
     return values.MultiWorkerDataset(
         partial(self._call_dataset_fn, dataset_fn), self._worker_device_map,
         self._prefetch_on_device)
   else:
     return values.PerDeviceDataset(
         self._call_dataset_fn(dataset_fn),
         self._devices,
         self._prefetch_on_device,
         source_device=device_util.resolve("/device:CPU:0"))
Code example #5
  def make_dataset_iterator(self, dataset):
    """Make iterators for each of the TPU hosts.

    We first unbatch the user's input dataset and then rebatch it with the
    per-replica batch size that is calculated using
    `global_batch_size // num_replicas_in_sync`. The currently supported cases
    are as follows:
    `dataset.batch()` is the last operation on the dataset.
    `dataset.apply(map_and_batch)` is the last operation on the dataset.
    `dataset.batch().prefetch()` are the last 2 operations on the dataset.
    `dataset.apply(map_and_batch).prefetch()` are the last 2 operations.

    Args:
      dataset: The `tf.data` dataset passed by the user.

    Returns:
      iterator: InputIterator created for each of the host machines.
    """
    # TODO(sourabhbajaj): Remove this in favor of distributed datasets
    def _get_dataset_batch_size(dataset):
      """Get the global batch size from the dataset object."""
      # pylint: disable=protected-access
      if isinstance(dataset, dataset_ops.DatasetV1Adapter):
        dataset = dataset._dataset
      if isinstance(dataset, dataset_ops.BatchDataset):
        return tensor_util.constant_value(dataset._batch_size)
      elif isinstance(dataset, batching._MapAndBatchDataset):
        return dataset._batch_size
      elif isinstance(dataset, dataset_ops.PrefetchDataset):
        return _get_dataset_batch_size(dataset._input_dataset)
      # pylint: enable=protected-access
      raise ValueError(
          "Unable to fetch the batch size from the input dataset. `batch` "
          "`map_and_batch` need to be the last operations on the dataset. "
          "The batch operations can be followed by a prefetch.")

    global_batch_size = _get_dataset_batch_size(dataset)
    if global_batch_size % self.num_replicas_in_sync:
      raise ValueError(
          "Batch size %s cannot be sharded evenly across replicas %s" % (
              global_batch_size, self.num_replicas_in_sync))
    per_replica_batch_size = global_batch_size // self.num_replicas_in_sync
    dataset = dataset.apply(batching.unbatch())
    dataset = dataset.batch(per_replica_batch_size, drop_remainder=True)

    worker_devices = [
        (self.get_host(hid), [self.get_host_cpu_device(hid)])
        for hid in range(self.num_hosts)
    ]
    distributed_dataset = values.MultiWorkerDataset(
        functools.partial(self._call_dataset_fn, lambda: dataset),
        worker_devices)
    # TODO(priyag): Return distribution strategy specific InputIterator
    return distributed_dataset.make_initializable_iterator()
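The docstring above describes the rebatching that make_dataset_iterator performs; as a self-contained illustration of the same arithmetic on a plain tf.data pipeline (the batch size and replica count are made up, and Dataset.unbatch() stands in for batching.unbatch()):

import tensorflow as tf

global_batch_size = 64
num_replicas_in_sync = 8
assert global_batch_size % num_replicas_in_sync == 0
per_replica_batch_size = global_batch_size // num_replicas_in_sync  # 8

# Undo the user's global batching, then rebatch per replica, dropping any
# remainder so every replica receives a full batch.
dataset = tf.data.Dataset.range(1024).batch(global_batch_size)
dataset = dataset.unbatch().batch(per_replica_batch_size, drop_remainder=True)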
Code example #6
 def _test_dataset(self,
                   dataset_fn,
                   worker_devices,
                   devices,
                   expected_values,
                   auto_shard=True):
     multi_worker_dataset = values.MultiWorkerDataset(dataset_fn,
                                                      worker_devices,
                                                      auto_shard=auto_shard)
     multi_worker_iterator = (
         multi_worker_dataset.make_initializable_iterator())
     with self.cached_session() as sess:
         sess.run(multi_worker_iterator.initializer)
         self._test_iterator(sess, multi_worker_iterator, devices,
                             expected_values)
Code example #7
File: values_test.py  Project: sgcm520/tensorflow2
  def testValueErrorForIterator(self):
    # Incompatible arguments.
    with self.assertRaises(ValueError):
      values.MultiWorkerDataIterator({"w1": None}, {"w1": "d1", "w2": "d2"})

    # Test duplicated devices under the same worker.
    worker_device_map, _ = self._cpu_devices()
    worker_device_map["/job:worker/replica:0/task:0"].append(
        "/job:worker/replica:0/task:0/device:CPU:0")
    with context.graph_mode():
      dataset_fn = lambda: dataset_ops.Dataset.range(8)
      multi_worker_dataset = values.MultiWorkerDataset(
          dataset_fn, worker_device_map, prefetch_on_device=False)
      multi_worker_iterator = multi_worker_dataset.make_initializable_iterator()
      with self.assertRaises(ValueError):
        multi_worker_iterator.get_next()
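The duplicated-device check above operates on the dict-style worker_device_map built by these tests; a reconstructed sketch of its shape and of the append that triggers the error (device strings follow the test's naming):

# Reconstructed shape of the map returned by _cpu_devices() in the test:
worker_device_map = {
    "/job:worker/replica:0/task:0": [
        "/job:worker/replica:0/task:0/device:CPU:0"
    ],
    "/job:worker/replica:0/task:1": [
        "/job:worker/replica:0/task:1/device:CPU:0"
    ],
}
# Appending the same CPU device again under task:0 duplicates a device
# within one worker, so get_next() raises ValueError.
worker_device_map["/job:worker/replica:0/task:0"].append(
    "/job:worker/replica:0/task:0/device:CPU:0")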
Code example #8
File: values_test.py  Project: sgcm520/tensorflow2
  def testInitializableIterator(self):
    worker_device_map, devices = self._cpu_devices()
    with context.graph_mode():
      dataset_fn = lambda: dataset_ops.Dataset.range(8)
      multi_worker_dataset = values.MultiWorkerDataset(
          dataset_fn, worker_device_map, prefetch_on_device=False)
      multi_worker_iterator = multi_worker_dataset.make_initializable_iterator()

      self.evaluate(multi_worker_iterator.initializer)
      self._test_iterator(multi_worker_iterator, devices,
                          [[0, 1], [2, 3], [4, 5], [6, 7]])

      # After re-initializing the iterator, we should be able to iterate again.
      self.evaluate(multi_worker_iterator.initializer)
      self._test_iterator(multi_worker_iterator, devices,
                          [[0, 1], [2, 3], [4, 5], [6, 7]])
Code example #9
    def testInitializableIterator(self):
        worker_devices, devices = self._cpu_devices()
        with context.graph_mode(), self.cached_session() as sess:
            dataset_fn = lambda: dataset_ops.Dataset.range(8)
            multi_worker_dataset = values.MultiWorkerDataset(dataset_fn,
                                                             worker_devices,
                                                             auto_shard=True)
            multi_worker_iterator = (
                multi_worker_dataset.make_initializable_iterator())

            sess.run(multi_worker_iterator.initializer)
            self._test_iterator(sess, multi_worker_iterator, devices,
                                [[0, 1], [2, 3], [4, 5], [6, 7]])

            # After re-initializing the iterator, we should be able to iterate again.
            sess.run(multi_worker_iterator.initializer)
            self._test_iterator(sess, multi_worker_iterator, devices,
                                [[0, 1], [2, 3], [4, 5], [6, 7]])
Code example #10
    def _make_input_fn_iterator(
            self,
            input_fn,
            replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
        if self._cluster_spec:
            input_fns = []
            for i in range(len(self._worker_devices)):
                input_context = distribute_lib.InputContext(
                    num_input_pipelines=len(self._worker_devices),
                    input_pipeline_id=i,
                    num_replicas_in_sync=self.num_replicas_in_sync)
                input_fns.append(
                    partial(self._call_dataset_fn, input_fn, input_context))

            return values.MultiWorkerDataset(input_fns, self._worker_devices,
                                             self._auto_shard_dataset)
        else:
            input_context = distribute_lib.InputContext(
                num_input_pipelines=1,
                input_pipeline_id=0,
                num_replicas_in_sync=self.num_replicas_in_sync)
            return values.PerReplicaDataset(
                self._call_dataset_fn(input_fn, input_context), self._devices)
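Each per-worker input_fn above is called with an InputContext; a minimal sketch of how an input_fn might use that context to shard and batch its data (the shard and batch choices are illustrative, not taken from the snippet):

import tensorflow as tf

def input_fn(input_context):
  # Derive the per-replica batch size from a hypothetical global batch size.
  batch_size = input_context.get_per_replica_batch_size(global_batch_size=64)
  dataset = tf.data.Dataset.range(1024)
  # Give each input pipeline a disjoint shard of the data.
  dataset = dataset.shard(input_context.num_input_pipelines,
                          input_context.input_pipeline_id)
  return dataset.batch(batch_size)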
Code example #11
 def distribute_dataset(self, dataset_fn):
   return values.MultiWorkerDataset(
       partial(self._call_dataset_fn, dataset_fn), self._worker_device_map,
       self._prefetch_on_device)
Code example #12
 def _test_dataset(self, dataset_fn, worker_device_map, devices,
                   expected_values):
     multi_worker_dataset = values.MultiWorkerDataset(
         dataset_fn, worker_device_map, prefetch_on_device=False)
     multi_worker_iterator = multi_worker_dataset.make_one_shot_iterator()
     self._test_iterator(multi_worker_iterator, devices, expected_values)
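For comparison with the initializable-iterator helpers in the earlier examples, a minimal sketch of driving this one-shot variant directly in graph mode (worker and device names are made up):

worker_device_map = {
    "/job:worker/task:0": ["/job:worker/task:0/device:CPU:0"],
    "/job:worker/task:1": ["/job:worker/task:1/device:CPU:0"],
}
multi_worker_dataset = values.MultiWorkerDataset(
    lambda: dataset_ops.Dataset.range(8), worker_device_map,
    prefetch_on_device=False)
# Unlike make_initializable_iterator(), the one-shot iterator needs no
# explicit initializer run before get_next().
multi_worker_iterator = multi_worker_dataset.make_one_shot_iterator()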