Example #1
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  return values.InputFunctionIterator(
      input_fn, [("/job:localhost", [self._device])],
      [distribute_lib.InputContext()])
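
For reference, the input_fn handed to these constructors is called with a distribute_lib.InputContext and must return a tf.data.Dataset (Example #9 below shows the call explicitly). A minimal sketch, assuming a global batch size of 64 and a toy range dataset, both invented for illustration:

import tensorflow as tf

def input_fn(input_context):
  # Per-replica batch size, derived from the hypothetical global batch of 64.
  batch_size = input_context.get_per_replica_batch_size(64)
  dataset = tf.data.Dataset.range(1000)
  # Each input pipeline reads a disjoint shard of the data.
  dataset = dataset.shard(input_context.num_input_pipelines,
                          input_context.input_pipeline_id)
  return dataset.batch(batch_size)
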
Example #2
    def _test_iterator(self,
                       input_fn,
                       worker_device_pairs,
                       expected_values,
                       sess=None):
        devices = nest.flatten([ds for _, ds in worker_device_pairs])
        iterator = values.InputFunctionIterator(input_fn, worker_device_pairs)

        evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)

        evaluate(iterator.initialize())

        for expected_value in expected_values:
            next_element = iterator.get_next()
            computed_value = evaluate(
                [values.select_device(d, next_element) for d in devices])
            self.assertEqual(expected_value, computed_value)

        with self.assertRaises(errors.OutOfRangeError):
            next_element = iterator.get_next()
            evaluate([values.select_device(d, next_element) for d in devices])

        # After re-initializing the iterator, should be able to iterate again.
        evaluate(iterator.initialize())

        for expected_value in expected_values:
            next_element = iterator.get_next()
            computed_value = evaluate(
                [values.select_device(d, next_element) for d in devices])
            self.assertEqual(expected_value, computed_value)
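
A hypothetical invocation of this helper, assuming one local worker feeding two devices; the device strings, the range(4) dataset, and the expected values are invented for illustration. Each get_next() then yields one element per device:

from tensorflow.python.data.ops import dataset_ops

worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", "/device:CPU:0"])]
# Accept an optional argument, so the lambda works whether or not the
# iterator passes an InputContext to it.
input_fn = lambda *_: dataset_ops.Dataset.range(4)
self._test_iterator(input_fn, worker_device_pairs,
                    expected_values=[[0, 1], [2, 3]])
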
Example #3
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  worker = device_util.canonicalize("/device:CPU:0")
  worker_device_pairs = [(worker, [self._device])]
  return values.InputFunctionIterator(input_fn, worker_device_pairs,
                                      [distribute_lib.InputContext()])
Example #4
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  input_contexts = []
  num_workers = self._input_workers.num_workers
  for i in range(num_workers):
    input_contexts.append(distribute_lib.InputContext(
        num_input_pipelines=num_workers,
        input_pipeline_id=i,
        num_replicas_in_sync=self._num_replicas_in_sync))
  return values.InputFunctionIterator(
      input_fn, self._input_workers, input_contexts)
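
Note that every worker receives its own InputContext here: the contexts share num_input_pipelines and num_replicas_in_sync and differ only in input_pipeline_id, which is what an input_fn (such as the sketch after Example #1) uses to pick its shard of the data.
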
Example #5
    def _make_input_fn_iterator(
            self,
            input_fn,
            replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
        """Distributes the dataset to each local GPU."""
        if self._cluster_spec is None:
            input_pipeline_id = 0
        else:
            input_pipeline_id = multi_worker_util.id_in_cluster(
                self._cluster_spec, self._task_type, self._task_id)
        input_context = distribute_lib.InputContext(
            num_input_pipelines=self._num_workers,
            input_pipeline_id=input_pipeline_id,
            num_replicas_in_sync=self._num_replicas_in_sync)

        return values.InputFunctionIterator(input_fn, self._input_workers,
                                            [input_context])
Example #6
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  input_contexts = []
  if self._cluster_spec:
    num_workers = len(self._worker_devices)
    worker_device_pairs = self._worker_devices
  else:
    num_workers = 1
    worker_device_pairs = [("/job:localhost", self._devices)]
  for i in range(num_workers):
    input_contexts.append(
        distribute_lib.InputContext(
            num_input_pipelines=num_workers,
            input_pipeline_id=i,
            num_replicas_in_sync=self._num_replicas_in_sync))
  return values.InputFunctionIterator(input_fn, worker_device_pairs,
                                      input_contexts)
Example #7
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  """Distributes the dataset to each local GPU."""
  if self._cluster_spec:
    input_pipeline_id = multi_worker_util.id_in_cluster(
        self._cluster_spec, self._task_type, self._task_id)
    num_input_pipelines = multi_worker_util.worker_count(
        self._cluster_spec, self._task_type)
  else:
    input_pipeline_id = 0
    num_input_pipelines = 1
  input_context = distribute_lib.InputContext(
      num_input_pipelines=num_input_pipelines,
      input_pipeline_id=input_pipeline_id,
      num_replicas_in_sync=self._num_replicas_in_sync)
  worker_device_pairs = [(self._worker_device, self._compute_devices)]
  return values.InputFunctionIterator(input_fn, worker_device_pairs,
                                      [input_context])
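
For context, the _cluster_spec consulted above is an ordinary tf.train.ClusterSpec. A minimal sketch of the multi-worker branch, with hostnames and the task assignment invented for illustration:

import tensorflow as tf

cluster_spec = tf.train.ClusterSpec({
    "worker": ["host1:2222", "host2:2222"],
})
task_type = "worker"
task_id = 1
# With this spec (and no chief task), the branch above would compute
# num_input_pipelines = 2 and input_pipeline_id = 1.
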
Example #8
  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    input_contexts = []
    if self._local_mode:
      num_workers = 1
      worker = device_util.canonicalize("/device:CPU:0")
      worker_device_pairs = [(worker, self._devices)]
    else:
      num_workers = len(self._worker_devices)
      worker_device_pairs = self._worker_devices

    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))
    return values.InputFunctionIterator(
        input_fn, worker_device_pairs, input_contexts)
Example #9
  def _test_iterator(self, input_type, dataset_fn, worker_device_pairs,
                     expected_values, sess=None, split_batch_by=None):
    devices = nest.flatten([ds for _, ds in worker_device_pairs])
    device_map = values.ReplicaDeviceMap(devices)
    input_workers = values.InputWorkers(device_map, worker_device_pairs)

    if input_type == "input_fn":
      input_contexts = [
          distribute_lib.InputContext() for _ in worker_device_pairs]
      input_fn = lambda _: dataset_fn()
      iterator = values.InputFunctionIterator(
          input_fn, input_workers, input_contexts)
    else:
      iterator = values.DatasetIterator(
          dataset_fn(), input_workers, split_batch_by)

    evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)

    evaluate(control_flow_ops.group(iterator.initialize()))

    for expected_value in expected_values:
      next_element = iterator.get_next()
      computed_value = evaluate(
          [values.select_replica(r, next_element) for r in range(len(devices))])
      self.assertAllEqual(expected_value, computed_value)

    with self.assertRaises(errors.OutOfRangeError):
      next_element = iterator.get_next()
      evaluate([values.select_replica(r, next_element)
                for r in range(len(devices))])

    # After re-initializing the iterator, should be able to iterate again.
    evaluate(control_flow_ops.group(iterator.initialize()))

    for expected_value in expected_values:
      next_element = iterator.get_next()
      computed_value = evaluate(
          [values.select_replica(r, next_element) for r in range(len(devices))])
      self.assertAllEqual(expected_value, computed_value)
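
Unlike the helper in Example #2, this variant builds explicit values.InputWorkers and values.ReplicaDeviceMap objects and selects results by replica index via values.select_replica instead of by device string via values.select_device; it also exercises the values.DatasetIterator code path alongside InputFunctionIterator.
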
Example #10
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  return values.InputFunctionIterator(input_fn, self._input_workers,
                                      [distribute_lib.InputContext()])