Example #1
  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    return values.PerReplicaDataset(
        self._call_dataset_fn(input_fn, distribute_lib.InputContext()),
        [self._device])
Example #2
  def testInitializableIterator(self):
    with context.graph_mode():
      devices = ["/device:CPU:0"]
      # Use random input since it is only allowed with an initializable
      # iterator.
      dataset = dataset_ops.Dataset.from_tensor_slices(
          random_ops.random_uniform((10,)))

      device_map = values.ReplicaDeviceMap(devices)
      input_workers = values.InputWorkers(device_map)
      per_replica_dataset = values.PerReplicaDataset(dataset, input_workers, 0)
      iterator = per_replica_dataset.make_initializable_iterator()

      self.evaluate(iterator.initializer)
      next_element = iterator.get_next_as_list()
      for _ in range(10):
        self.evaluate(next_element)

      # Should fail after the input is finished.
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element)

      # After re-initializing the iterator, should be able to iterate again.
      self.evaluate(iterator.initializer)
      for _ in range(10):
        self.evaluate(next_element)
Example #3
  def _distribute_dataset(self, dataset_fn):
    """Distributes the dataset to each local GPU."""
    # TODO(yuefengz): shard the dataset.
    worker_index = 0
    return values.PerReplicaDataset(self._call_dataset_fn(dataset_fn),
                                    self._input_workers,
                                    worker_index,
                                    prefetch_on_device=True)
Example #4
  def _distribute_dataset(self, dataset_fn):
    if self._local_mode:
      return values.PerReplicaDataset(self._call_dataset_fn(dataset_fn),
                                      self._input_workers, 0)
    else:
      return values.MultiWorkerDataset(
          functools.partial(self._call_dataset_fn, dataset_fn),
          self._input_workers,
          auto_shard=self._auto_shard_dataset)
Example #5
  def _distribute_dataset(self, dataset_fn):
    if self._local_mode:
      return values.PerReplicaDataset(self._call_dataset_fn(dataset_fn),
                                      self._devices)
    else:
      return values.MultiWorkerDataset(
          functools.partial(self._call_dataset_fn, dataset_fn),
          self._worker_devices,
          auto_shard=False)
Example #6
  def _distribute_dataset(self, dataset_fn):
    if self._cluster_spec:
      return values.MultiWorkerDataset(
          partial(self._call_dataset_fn, dataset_fn),
          self._worker_devices,
          auto_shard=self._auto_shard_dataset)
    else:
      return values.PerReplicaDataset(self._call_dataset_fn(dataset_fn),
                                      self._devices)
Example #7
  def _distribute_dataset(self, dataset_fn):
    if self._local_mode:
      # Prefetching onto the devices is disabled for the local case.
      return values.PerReplicaDataset(
          self._call_dataset_fn(dataset_fn), self._devices,
          prefetch_on_device=False)
    else:
      return values.MultiWorkerDataset(
          functools.partial(self._call_dataset_fn, dataset_fn),
          self._worker_devices,
          auto_shard=self._auto_shard_dataset)
Example #8
  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    """Distributes the dataset to each local GPU."""
    if self._cluster_spec is None:
      input_pipeline_id = 0
    else:
      input_pipeline_id = multi_worker_util.id_in_cluster(
          self._cluster_spec, self._task_type, self._task_id)
    input_context = distribute_lib.InputContext(
        num_input_pipelines=self._num_workers,
        input_pipeline_id=input_pipeline_id,
        num_replicas_in_sync=self._num_replicas_in_sync)
    return values.PerReplicaDataset(
        self._call_dataset_fn(input_fn, input_context), self._devices, True)
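For orientation only, here is a minimal sketch, not taken from the original source, of an input_fn that the snippet above could receive. It uses only the num_replicas_in_sync field set on the InputContext constructed there to derive a per-replica batch size from an assumed global batch size of 64; the function name, the toy dataset, and the batch size are illustrative, and dataset_ops is the module already used in Example #2.

def example_batching_input_fn(input_context):
  # Hypothetical input_fn: split an assumed global batch size of 64 evenly
  # across the replicas that are kept in sync.
  per_replica_batch_size = 64 // input_context.num_replicas_in_sync
  return dataset_ops.Dataset.range(1000).batch(per_replica_batch_size)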
Example #9
    def _test_iterator(self, devices, dataset, expected_values):
        per_replica_dataset = values.PerReplicaDataset(dataset, devices)
        if context.executing_eagerly():
            iterator = per_replica_dataset.make_one_shot_iterator()
        else:
            iterator = per_replica_dataset.make_initializable_iterator()
            self.evaluate([iterator.initializer])

        for expected_value in expected_values:
            next_element = iterator.get_next()
            computed_value = self.evaluate(
                [values.select_device(d, next_element) for d in devices])
            self.assertEqual(expected_value, computed_value)

        with self.assertRaises(errors.OutOfRangeError):
            next_element = iterator.get_next()
            self.evaluate(
                [values.select_device(d, next_element) for d in devices])
Example #10
  def _test_iterator(self, devices, dataset, expected_values):
    device_map = values.ReplicaDeviceMap(devices)
    input_workers = values.InputWorkers(device_map)
    per_replica_dataset = values.PerReplicaDataset(dataset, input_workers, 0)
    if context.executing_eagerly():
      iterator = per_replica_dataset.make_one_shot_iterator()
    else:
      iterator = per_replica_dataset.make_initializable_iterator()
      self.evaluate([iterator.initializer])

    for expected_value in expected_values:
      next_element = iterator.get_next_as_list()
      computed_value = self.evaluate(next_element)
      self.assertEqual(expected_value, computed_value)

    with self.assertRaises(errors.OutOfRangeError):
      next_element = iterator.get_next_as_list()
      self.evaluate(next_element)
Example #11
    def _make_input_fn_iterator(
            self,
            input_fn,
            replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
        if self._cluster_spec:
            input_fns = []
            for i in range(len(self._worker_devices)):
                input_context = distribute_lib.InputContext(
                    num_input_pipelines=len(self._worker_devices),
                    input_pipeline_id=i,
                    num_replicas_in_sync=self._num_replicas_in_sync)
                input_fns.append(
                    partial(self._call_dataset_fn, input_fn, input_context))

            return values.MultiWorkerDataset(input_fns, self._worker_devices,
                                             self._auto_shard_dataset)
        else:
            input_context = distribute_lib.InputContext(
                num_input_pipelines=1,
                input_pipeline_id=0,
                num_replicas_in_sync=self._num_replicas_in_sync)
            return values.PerReplicaDataset(
                self._call_dataset_fn(input_fn, input_context), self._devices)
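As a companion to the multi-worker branch above, a minimal sketch, again not from the original source, of an input_fn that shards its data by the InputContext it is handed, so that each of the per-worker pipelines built in the loop reads a disjoint slice. The function name and the range size are illustrative assumptions; dataset_ops is the module already used in Example #2.

def example_sharding_input_fn(input_context):
  # Hypothetical input_fn: each input pipeline keeps only the elements
  # assigned to its shard index.
  dataset = dataset_ops.Dataset.range(100)
  return dataset.shard(input_context.num_input_pipelines,
                       input_context.input_pipeline_id)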
Example #12
  def _distribute_dataset(self, dataset_fn):
    """Distributes the dataset to each local GPU."""
    return values.PerReplicaDataset(self._call_dataset_fn(dataset_fn),
                                    self._compute_devices, True)
Example #13
  def _distribute_dataset(self, dataset_fn):
    """Distributes the dataset to each local GPU."""
    # TODO(yuefengz): shard the dataset.
    return values.PerReplicaDataset(self._call_dataset_fn(dataset_fn),
                                    self._devices, True)
Example #14
  def _distribute_dataset(self, dataset_fn):
    return values.PerReplicaDataset(
        self._call_dataset_fn(dataset_fn), [self._device])
Example #15
  def _make_dataset_iterator(self, dataset):
    distributed_dataset = values.PerReplicaDataset(dataset, [self._device])
    # TODO(priyag): Return distribution strategy specific InputIterator
    return distributed_dataset.make_initializable_iterator()
Example #16
  def _distribute_dataset(self, dataset_fn):
    return values.PerReplicaDataset(self._call_dataset_fn(dataset_fn),
                                    self._input_workers, 0)
Example #17
  def _distribute_dataset(self, dataset_fn):
    """Distributes the dataset to each local GPU."""
    return values.PerReplicaDataset(self._call_dataset_fn(dataset_fn),
                                    self._input_workers,
                                    0,
                                    prefetch_on_device=True)