Example #1
0
 def _make_dataset_iterator(self, dataset):
     """Build a `values.DatasetIterator` over `dataset`.

     When a cluster spec is configured, `self._worker_devices` supplies the
     (worker, devices) pairs; otherwise every device in `self._devices` is
     attached to a single "/job:localhost" worker. The number of in-sync
     replicas is forwarded unchanged.
     """
     pairs = (self._worker_devices if self._cluster_spec
              else [("/job:localhost", self._devices)])
     num_replicas = self._num_replicas_in_sync
     return values.DatasetIterator(dataset, pairs, num_replicas)
Example #2
0
    def _make_dataset_iterator(self, dataset):
        """Make iterators for each of the TPU hosts.

        Every host index in `range(self.num_hosts)` contributes one
        (host, [host CPU device]) pair to the iterator's worker list.
        """
        host_pairs = []
        for host_id in range(self.num_hosts):
            host = self.get_host(host_id)
            host_cpu = self.get_host_cpu_device(host_id)
            host_pairs.append((host, [host_cpu]))
        return values.DatasetIterator(
            dataset, host_pairs, self._num_replicas_in_sync)
 def _make_dataset_iterator(self, dataset):
     """Build a `values.DatasetIterator` over `dataset`.

     Without a cluster spec, all of `self._devices` are paired with the
     canonicalized "/device:CPU:0" worker; with one, the pairs come from
     `self._worker_devices`.
     """
     if not self._cluster_spec:
         local_worker = device_util.canonicalize("/device:CPU:0")
         pairs = [(local_worker, self._devices)]
     else:
         pairs = self._worker_devices
     return values.DatasetIterator(
         dataset, pairs, self._num_replicas_in_sync)
    def _make_dataset_iterator(self, dataset):
        """Make an iterator over `dataset` without splitting its batches.

        Unlike `tf.distribute.MirroredStrategy`, the incoming dataset's batch
        size is treated as the per-replica batch size; this is kept for
        backward compatibility.

        Args:
          dataset: `tf.data.Dataset` for input.

        Returns:
          An input iterator (`values.DatasetIterator` over
          `self._input_workers`) that returns inputs for each step of the
          computation.
        """
        input_workers = self._input_workers
        return values.DatasetIterator(dataset, input_workers)
Example #5
0
    def _make_dataset_iterator(self, dataset):
        """Make an iterator over `dataset` without splitting its batches.

        Unlike `tf.distribute.MirroredStrategy`, the incoming dataset's batch
        size is treated as the per-replica batch size; this is kept for
        backward compatibility.

        In local mode all of `self._devices` are paired with the canonicalized
        "/device:CPU:0" worker; otherwise `self._worker_devices` is used as-is.

        Args:
          dataset: `tf.data.Dataset` for input.

        Returns:
          An input iterator (`values.DatasetIterator`) that returns inputs
          for each step of the computation.
        """
        if not self._local_mode:
            worker_device_pairs = self._worker_devices
        else:
            cpu_worker = device_util.canonicalize("/device:CPU:0")
            worker_device_pairs = [(cpu_worker, self._devices)]
        return values.DatasetIterator(dataset, worker_device_pairs)
Example #6
0
  def _test_iterator(self, input_type, dataset_fn, worker_device_pairs,
                     expected_values, sess=None, split_batch_by=None):
    """Check that an input iterator yields `expected_values`, twice.

    Builds either a `values.InputFunctionIterator` (when
    `input_type == "input_fn"`) or a `values.DatasetIterator` over the devices
    in `worker_device_pairs`. It then checks three things:

    * One full pass produces `expected_values`.
    * The following `get_next()` raises `errors.OutOfRangeError`.
    * A second pass after re-initialization produces `expected_values` again.

    Args:
      input_type: "input_fn" to build an `InputFunctionIterator`; any other
        value builds a `DatasetIterator`.
      dataset_fn: zero-argument callable returning the input dataset.
      worker_device_pairs: sequence of (worker, devices) pairs; the devices
        are flattened to determine the replicas.
      expected_values: one entry per step. Each entry is compared against the
        list of values seen by every replica at that step.
      sess: optional session. When given, tensors are evaluated with
        `sess.run`; otherwise `self.evaluate` is used.
      split_batch_by: forwarded to `DatasetIterator`; not used for "input_fn".
    """
    devices = nest.flatten([ds for _, ds in worker_device_pairs])
    num_replicas = len(devices)
    device_map = values.ReplicaDeviceMap(devices)
    input_workers = values.InputWorkers(device_map, worker_device_pairs)

    if input_type == "input_fn":
      input_contexts = [
          distribute_lib.InputContext() for _ in worker_device_pairs]

      def input_fn(_):
        # The input context is ignored; every worker gets the same dataset.
        return dataset_fn()

      iterator = values.InputFunctionIterator(
          input_fn, input_workers, input_contexts)
    else:
      iterator = values.DatasetIterator(
          dataset_fn(), input_workers, split_batch_by)

    def evaluate(x):
      # Use the explicit session when one was supplied, else the test case's
      # own evaluate().
      return sess.run(x) if sess is not None else self.evaluate(x)

    def next_per_replica():
      # Fetch one step and evaluate the value selected for every replica.
      next_element = iterator.get_next()
      return evaluate([values.select_replica(r, next_element)
                       for r in range(num_replicas)])

    def check_one_pass():
      # (Re-)initialize the iterator, then compare every expected step.
      evaluate(control_flow_ops.group(iterator.initialize()))
      for expected_value in expected_values:
        self.assertAllEqual(expected_value, next_per_replica())

    check_one_pass()

    with self.assertRaises(errors.OutOfRangeError):
      next_per_replica()

    # After re-initializing the iterator, should be able to iterate again.
    check_one_pass()
Example #7
0
 def _make_dataset_iterator(self, dataset):
     """Return a `values.DatasetIterator` over `self._input_workers`.

     The number of in-sync replicas is forwarded to the iterator as well.
     """
     return values.DatasetIterator(
         dataset,
         self._input_workers,
         self._num_replicas_in_sync,
     )
Example #8
0
 def _make_dataset_iterator(self, dataset):
   """Make iterator from dataset without splitting the batch.

   The only (worker, devices) pair is `self._device` on "/job:localhost".
   """
   local_pairs = [("/job:localhost", [self._device])]
   return values.DatasetIterator(dataset, local_pairs)
Example #9
0
 def _make_dataset_iterator(self, dataset):
     """Make iterator from dataset without splitting the batch.

     The iterator is built over `self._input_workers`; no replica count is
     passed.
     """
     workers = self._input_workers
     return values.DatasetIterator(dataset, workers)
Example #10
0
  def _make_dataset_iterator(self, dataset):
    """Make iterators for each of the TPU hosts.

    The hosts come from `self._input_workers`; the number of in-sync replicas
    is passed through to `values.DatasetIterator`.
    """
    workers = self._input_workers
    return values.DatasetIterator(dataset, workers, self._num_replicas_in_sync)
Example #11
0
 def _make_dataset_iterator(self, dataset):
     """Make iterator from dataset without splitting the batch.

     The single device `self._device` is paired with the canonicalized
     "/device:CPU:0" worker.
     """
     return values.DatasetIterator(
         dataset,
         [(device_util.canonicalize("/device:CPU:0"), [self._device])])
Example #12
0
 def _make_dataset_iterator(self, dataset):
     """Return a `values.DatasetIterator` for a single worker.

     Pairs `self._worker_device` with all of `self._devices` and forwards the
     number of in-sync replicas.
     """
     single_pair = (self._worker_device, self._devices)
     return values.DatasetIterator(dataset, [single_pair],
                                   self._num_replicas_in_sync)