예제 #1
0
  def testChooseAlgorithm(self):
    """Checks which implementation `select_cross_device_ops` returns.

    Covers three situations: a CPU-only device list, a GPU device list while
    the kernel registry reports an "NcclAllReduce" kernel, and the same GPU
    device list while the registry reports no kernels at all. The two GPU
    cases only run when at least one GPU is reported by `context.num_gpus()`.
    """
    select_ops = cross_device_ops_lib.select_cross_device_ops

    # A CPU device in the list must not result in nccl being selected.
    cpu_ops = select_ops(["/cpu:0"])
    self.assertIsInstance(cpu_ops, cross_device_ops_lib.ReductionToOneDevice)

    # Nccl should also be avoided when the requested device is not visible to
    # TensorFlow; that check stays disabled until the TODO is resolved.
    # TODO(yuefengz): make `select_cross_device_ops` work with device strings
    # self.assertIsInstance(
    #     cross_device_ops_lib.select_cross_device_ops(["/gpu:100"]),
    #     cross_device_ops_lib.ReductionToOneDevice)

    # The remaining checks need a GPU; without one the test ends here.
    has_gpu = context.num_gpus() >= 1
    if not has_gpu:
      return

    gpu_devices = ["/gpu:0"]

    def registry_with_nccl(op):
      # Reports one (placeholder) kernel for "NcclAllReduce" and none for
      # every other op name.
      return [object] if op == "NcclAllReduce" else []

    def registry_without_kernels(_):
      # Reports no kernels regardless of the op name.
      return []

    # With an nccl kernel registered, nccl all-reduce is expected.
    with test.mock.patch.object(
        kernels, "get_registered_kernels_for_op", registry_with_nccl):
      ops_with_nccl = select_ops(gpu_devices)
      self.assertIsInstance(ops_with_nccl, cross_device_ops_lib.NcclAllReduce)

    # With no nccl kernel registered, the fallback implementation is expected.
    with test.mock.patch.object(
        kernels, "get_registered_kernels_for_op", registry_without_kernels):
      ops_without_nccl = select_ops(gpu_devices)
      self.assertIsInstance(
          ops_without_nccl, cross_device_ops_lib.ReductionToOneDevice)
예제 #2
0
 def _initialize_single_worker(self, devices):
   """Sets up this object's device state for training on a single worker.

   Args:
     devices: indexable collection of device strings; `devices[0]` is used
       as the reference when canonicalizing the host CPU device, so it must
       be non-empty.

   Stores the canonicalized devices, the (host device, devices) pairing used
   for input workers, the inferred cross-device ops (only when none were
   supplied via `self._cross_device_ops`), and the host input device. It
   also marks the training as not multi-worker. When the host device belongs
   to a job other than "localhost", `self._default_device` is set as well.
   """
   canonical_devices = [device_util.canonicalize(dev) for dev in devices]
   self._devices = tuple(canonical_devices)

   # The host CPU device is canonicalized against the first device given.
   host_cpu = device_util.canonicalize("/device:CPU:0", devices[0])
   self._input_workers_devices = ((host_cpu, devices),)

   # Only infer cross-device ops when the caller did not provide any.
   if self._cross_device_ops:
     self._inferred_cross_device_ops = None
   else:
     self._inferred_cross_device_ops = (
         cross_device_ops_lib.select_cross_device_ops(devices))

   self._host_input_device = numpy_dataset.SingleDevice(host_cpu)
   self._is_multi_worker_training = False
   logging.info("Using MirroredStrategy with devices %r", devices)

   host_spec = tf_device.DeviceSpec.from_string(host_cpu)
   job_name = host_spec.job
   # Ensures the correct default device is used on entering strategy.scope()
   # when the host device lives in a named job other than "localhost".
   needs_default_device = job_name is not None and job_name != "localhost"
   if needs_default_device:
     default_device_args = (job_name, host_spec.replica, host_spec.task)
     self._default_device = (
         "/job:%s/replica:%d/task:%d" % default_device_args)