Ejemplo n.º 1
0
 def _distribute_dataset(self, dataset_fn):
   """Builds a `values.MultiWorkerDataset` with one input worker per host.

   Every host index in `range(self.num_hosts)` becomes a worker entry that
   pairs `self.get_host(index)` with a one-element device list holding
   `self.get_host_cpu_device(index)`.

   Args:
     dataset_fn: Callable bound as the first argument of
       `self._call_dataset_fn`; the resulting partial is what
       `MultiWorkerDataset` receives.

   Returns:
     The constructed `values.MultiWorkerDataset`.
   """
   per_host_devices = [
       (self.get_host(host_id), [self.get_host_cpu_device(host_id)])
       for host_id in range(self.num_hosts)
   ]
   bound_dataset_fn = functools.partial(self._call_dataset_fn, dataset_fn)
   return values.MultiWorkerDataset(bound_dataset_fn, per_host_devices)
Ejemplo n.º 2
0
 def _test_dataset(self, dataset_fn, worker_devices, devices,
                   expected_values, auto_shard=True):
   """Runs a `MultiWorkerDataset` built from `dataset_fn` and checks its output.

   Args:
     dataset_fn: Callable passed straight to `values.MultiWorkerDataset`.
     worker_devices: Worker/device specification for the dataset.
     devices: Devices handed to `self._test_iterator` for the comparison.
     expected_values: Values `self._test_iterator` should observe.
     auto_shard: Forwarded as the `auto_shard` keyword of the dataset.
   """
   dataset = values.MultiWorkerDataset(
       dataset_fn, worker_devices, auto_shard=auto_shard)
   iterator = dataset.make_initializable_iterator()
   with self.cached_session() as session:
     # The iterator must be initialized before it can be consumed.
     session.run(iterator.initializer)
     self._test_iterator(session, iterator, devices, expected_values)
Ejemplo n.º 3
0
 def _distribute_dataset(self, dataset_fn):
     """Distributes the dataset produced by `dataset_fn`.

     Outside local mode, `dataset_fn` is not called here. A partial of
     `self._call_dataset_fn` is handed to `values.MultiWorkerDataset` together
     with `self._worker_devices`, with `auto_shard=False`. In local mode the
     dataset is built right away via `self._call_dataset_fn` and wrapped in a
     `values.PerReplicaDataset` over `self._devices`.

     Args:
       dataset_fn: Callable forwarded to `self._call_dataset_fn`.

     Returns:
       A `values.MultiWorkerDataset` or a `values.PerReplicaDataset`.
     """
     if not self._local_mode:
         bound_fn = functools.partial(self._call_dataset_fn, dataset_fn)
         return values.MultiWorkerDataset(
             bound_fn, self._worker_devices, auto_shard=False)
     local_dataset = self._call_dataset_fn(dataset_fn)
     return values.PerReplicaDataset(local_dataset, self._devices)
Ejemplo n.º 4
0
 def _distribute_dataset(self, dataset_fn):
     """Distributes the dataset produced by `dataset_fn`.

     Without a cluster spec, the dataset is built immediately through
     `self._call_dataset_fn` and wrapped in a `values.PerReplicaDataset` over
     `self._devices`. With a cluster spec, a partial of
     `self._call_dataset_fn` goes to `values.MultiWorkerDataset` along with
     `self._worker_devices`. In that case sharding follows
     `self._auto_shard_dataset`.

     Args:
       dataset_fn: Callable forwarded to `self._call_dataset_fn`.

     Returns:
       A `values.PerReplicaDataset` or a `values.MultiWorkerDataset`.
     """
     if not self._cluster_spec:
         local_dataset = self._call_dataset_fn(dataset_fn)
         return values.PerReplicaDataset(local_dataset, self._devices)
     bound_fn = partial(self._call_dataset_fn, dataset_fn)
     return values.MultiWorkerDataset(
         bound_fn,
         self._worker_devices,
         auto_shard=self._auto_shard_dataset)
Ejemplo n.º 5
0
 def _distribute_dataset(self, dataset_fn):
     """Distributes the dataset produced by `dataset_fn` over the input workers.

     Outside local mode, a partial of `self._call_dataset_fn` is passed to
     `values.MultiWorkerDataset` with `self._input_workers`. In that case
     sharding follows `self._auto_shard_dataset`. In local mode the dataset is
     built immediately and wrapped in a `values.PerReplicaDataset`. That
     wrapper receives `self._input_workers` and a positional `0`.

     Args:
       dataset_fn: Callable forwarded to `self._call_dataset_fn`.

     Returns:
       A `values.MultiWorkerDataset` or a `values.PerReplicaDataset`.
     """
     if not self._local_mode:
         bound_fn = functools.partial(self._call_dataset_fn, dataset_fn)
         return values.MultiWorkerDataset(
             bound_fn, self._input_workers,
             auto_shard=self._auto_shard_dataset)
     local_dataset = self._call_dataset_fn(dataset_fn)
     # NOTE(review): the meaning of the trailing positional 0 is defined by
     # `values.PerReplicaDataset`; confirm against its signature.
     return values.PerReplicaDataset(local_dataset, self._input_workers, 0)
Ejemplo n.º 6
0
 def _test_dataset(self, dataset_fn, worker_devices, devices,
                   expected_values):
   """Runs a `MultiWorkerDataset` built from `dataset_fn` and checks its output.

   The dataset's input workers are assembled from a
   `values.ReplicaDeviceMap` over `devices` plus `worker_devices`.

   Args:
     dataset_fn: Callable passed to `values.MultiWorkerDataset`.
     worker_devices: Worker/device specification for `values.InputWorkers`.
     devices: Devices used for the replica device map and the comparison.
     expected_values: Values `self._test_iterator` should observe.
   """
   input_workers = values.InputWorkers(
       values.ReplicaDeviceMap(devices), worker_devices)
   dataset = values.MultiWorkerDataset(dataset_fn, input_workers)
   iterator = dataset.make_initializable_iterator()
   with self.cached_session() as session:
     session.run(iterator.initializer)
     self._test_iterator(session, iterator, devices, expected_values)
Ejemplo n.º 7
0
 def _distribute_dataset(self, dataset_fn):
     """Distributes the dataset produced by `dataset_fn`.

     Outside local mode, a partial of `self._call_dataset_fn` is passed to
     `values.MultiWorkerDataset` with `self._worker_devices`. In that case
     sharding follows `self._auto_shard_dataset`. In local mode the dataset is
     built immediately and wrapped in a `values.PerReplicaDataset` over
     `self._devices`, with `prefetch_on_device=False`.

     Args:
       dataset_fn: Callable forwarded to `self._call_dataset_fn`.

     Returns:
       A `values.MultiWorkerDataset` or a `values.PerReplicaDataset`.
     """
     if not self._local_mode:
         bound_fn = functools.partial(self._call_dataset_fn, dataset_fn)
         return values.MultiWorkerDataset(
             bound_fn,
             self._worker_devices,
             auto_shard=self._auto_shard_dataset)
     local_dataset = self._call_dataset_fn(dataset_fn)
     # On-device prefetching is explicitly switched off for the local case.
     return values.PerReplicaDataset(
         local_dataset, self._devices, prefetch_on_device=False)
Ejemplo n.º 8
0
    def _make_dataset_iterator(self, dataset):
        """Make iterators for each of the TPU hosts.

        The user's input dataset is first unbatched and then rebatched with
        the per-replica batch size, computed as
        `global_batch_size // num_replicas_in_sync`. The currently supported
        cases are:

        * `dataset.batch()` is the last operation on the dataset.
        * `dataset.apply(map_and_batch)` is the last operation on the dataset.
        * `dataset.batch().prefetch()` are the last 2 operations on the dataset.
        * `dataset.apply(map_and_batch).prefetch()` are the last 2 operations.

        Args:
          dataset: The `tf.data` dataset passed by the user.

        Returns:
          The initializable iterator of a `values.MultiWorkerDataset` with one
          input worker per host. Each worker is the host paired with a
          one-element list holding its CPU device.

        Raises:
          ValueError: If the batch size cannot be read from `dataset`, is not
            statically known, or does not divide evenly across
            `self._num_replicas_in_sync` replicas.
        """

        # TODO(sourabhbajaj): Remove this in lieu of distributed datasets
        def _get_dataset_batch_size(dataset):
            """Get the global batch size from the dataset object.

            Unwraps a `DatasetV1Adapter` and looks through a trailing
            `PrefetchDataset` to reach the batching dataset underneath.
            """
            # pylint: disable=protected-access
            if isinstance(dataset, dataset_ops.DatasetV1Adapter):
                dataset = dataset._dataset
            if isinstance(dataset, dataset_ops.BatchDataset):
                # `constant_value` yields None when the batch size is not a
                # statically known constant; the caller checks for that.
                return tensor_util.constant_value(dataset._batch_size)
            elif isinstance(dataset, batching._MapAndBatchDataset):
                # NOTE(review): the raw `_batch_size` attribute is returned
                # as-is; confirm it is a Python int rather than a tensor.
                return dataset._batch_size
            elif isinstance(dataset, dataset_ops.PrefetchDataset):
                return _get_dataset_batch_size(dataset._input_dataset)
            # pylint: enable=protected-access
            raise ValueError(
                "Unable to fetch the batch size from the input dataset. `batch` "
                "`map_and_batch` need to be the last operations on the dataset. "
                "The batch operations can be followed by a prefetch.")

        global_batch_size = _get_dataset_batch_size(dataset)
        if global_batch_size is None:
            # Without a static batch size the per-replica size cannot be
            # computed. Fail clearly here rather than with a TypeError from
            # the modulo below.
            raise ValueError(
                "Unable to determine a static batch size for the input "
                "dataset. The batch size must be statically known.")

        # Read the replica count from the same attribute everywhere so the
        # error path cannot fail on a differently named attribute.
        num_replicas = self._num_replicas_in_sync
        if global_batch_size % num_replicas:
            raise ValueError(
                "Batch size %s cannot be sharded evenly across replicas %s" %
                (global_batch_size, num_replicas))
        per_replica_batch_size = global_batch_size // num_replicas
        dataset = dataset.apply(batching.unbatch())
        # drop_remainder=True keeps every emitted batch at exactly
        # per_replica_batch_size elements.
        dataset = dataset.batch(per_replica_batch_size, drop_remainder=True)

        worker_devices = [(self.get_host(hid), [self.get_host_cpu_device(hid)])
                          for hid in range(self.num_hosts)]
        distributed_dataset = values.MultiWorkerDataset(
            functools.partial(self._call_dataset_fn, lambda: dataset),
            worker_devices)
        # TODO(priyag): Return distribution strategy specific InputIterator
        return distributed_dataset.make_initializable_iterator()
Ejemplo n.º 9
0
  def testValueErrorForIterator(self):
    # Incompatible arguments: the two dicts disagree on the set of workers.
    with self.assertRaises(ValueError):
      values.MultiWorkerDataIterator({"w1": None}, {"w1": "d1", "w2": "d2"})

    # Duplicated devices under the same worker must be rejected when
    # get_next() is called on the iterator.
    worker_devices, _ = self._cpu_devices()
    extra_device = "/job:worker/replica:0/task:0/device:CPU:0"
    worker_devices[0][1].append(extra_device)
    with context.graph_mode():
      def make_dataset():
        return dataset_ops.Dataset.range(8)

      dataset = values.MultiWorkerDataset(
          make_dataset, worker_devices, auto_shard=True)
      iterator = dataset.make_initializable_iterator()
      with self.assertRaises(ValueError):
        iterator.get_next()
Ejemplo n.º 10
0
  def testInitializableIterator(self):
    worker_devices, devices = self._cpu_devices()
    with context.graph_mode(), self.cached_session() as sess:
      def make_dataset():
        return dataset_ops.Dataset.range(8)

      dataset = values.MultiWorkerDataset(
          make_dataset, worker_devices, auto_shard=True)
      iterator = dataset.make_initializable_iterator()

      # The second pass checks that re-running the initializer lets the
      # iterator be consumed again from the start.
      for _ in range(2):
        sess.run(iterator.initializer)
        self._test_iterator(sess, iterator, devices,
                            [[2 * i, 2 * i + 1] for i in range(4)])
Ejemplo n.º 11
0
  def testInitializableIterator(self):
    worker_devices, devices = self._cpu_devices()
    with context.graph_mode(), self.cached_session() as sess:
      def make_dataset():
        return dataset_ops.Dataset.range(8)

      input_workers = values.InputWorkers(
          values.ReplicaDeviceMap(devices), worker_devices)
      dataset = values.MultiWorkerDataset(make_dataset, input_workers)
      iterator = dataset.make_initializable_iterator()

      # The second pass checks that re-running the initializer lets the
      # iterator be consumed again from the start.
      for _ in range(2):
        sess.run(iterator.initializer)
        self._test_iterator(
            sess, iterator, devices, [[i, i] for i in range(8)])
Ejemplo n.º 12
0
    def _make_input_fn_iterator(
            self,
            input_fn,
            replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
        """Wraps `input_fn` into a distributed dataset object.

        Without a cluster spec, `input_fn` is invoked once through
        `self._call_dataset_fn` with a single-pipeline `InputContext`. The
        result is wrapped in a `values.PerReplicaDataset` over
        `self._devices`.

        With a cluster spec, one partial per entry of `self._worker_devices`
        is built. Each partial binds `input_fn` and an `InputContext` whose
        `input_pipeline_id` is that entry's index. The list of partials is
        handed to `values.MultiWorkerDataset` together with
        `self._worker_devices` and `self._auto_shard_dataset`.

        NOTE(review): `replication_mode` is accepted but not consulted here,
        and despite the method name a dataset object (not an iterator) is
        returned; confirm callers expect this.

        Args:
          input_fn: Callable forwarded to `self._call_dataset_fn` along with
            an `InputContext`.
          replication_mode: Unused by this implementation.

        Returns:
          A `values.PerReplicaDataset` or a `values.MultiWorkerDataset`.
        """
        if not self._cluster_spec:
            single_context = distribute_lib.InputContext(
                num_input_pipelines=1,
                input_pipeline_id=0,
                num_replicas_in_sync=self._num_replicas_in_sync)
            return values.PerReplicaDataset(
                self._call_dataset_fn(input_fn, single_context), self._devices)

        num_workers = len(self._worker_devices)
        per_worker_fns = [
            partial(
                self._call_dataset_fn,
                input_fn,
                distribute_lib.InputContext(
                    num_input_pipelines=num_workers,
                    input_pipeline_id=worker_id,
                    num_replicas_in_sync=self._num_replicas_in_sync))
            for worker_id in range(num_workers)
        ]
        return values.MultiWorkerDataset(per_worker_fns, self._worker_devices,
                                         self._auto_shard_dataset)
Ejemplo n.º 13
0
 def _distribute_dataset(self, dataset_fn):
   """Builds a `values.MultiWorkerDataset` over `self._input_workers`.

   Args:
     dataset_fn: Callable bound as the first argument of
       `self._call_dataset_fn`; the resulting partial is what
       `MultiWorkerDataset` receives.

   Returns:
     The constructed `values.MultiWorkerDataset`.
   """
   bound_fn = functools.partial(self._call_dataset_fn, dataset_fn)
   return values.MultiWorkerDataset(bound_fn, self._input_workers)