def testChooseAlgorithm(self):
  """Checks which cross-device ops `select_cross_device_ops` picks.

  A device list containing a CPU must fall back to `ReductionToOneDevice`.
  For a GPU-only device list, NCCL is chosen only when an `NcclAllReduce`
  kernel is registered; otherwise the fallback is `ReductionToOneDevice`.
  The GPU scenarios only run when at least one GPU is visible.
  """
  # Any CPU device in the list rules out NCCL.
  cpu_choice = cross_device_ops_lib.select_cross_device_ops(["/cpu:0"])
  self.assertIsInstance(cpu_choice, cross_device_ops_lib.ReductionToOneDevice)

  # TODO(yuefengz): make `select_cross_device_ops` work with device strings so
  # that a device TensorFlow cannot see (e.g. "/gpu:100") can also be checked
  # to fall back to `ReductionToOneDevice`.

  # The remaining scenarios need at least one visible GPU.
  if context.num_gpus() < 1:
    return

  gpu_devices = ["/gpu:0"]
  # Each entry: (replacement for kernels.get_registered_kernels_for_op,
  # ops class expected from select_cross_device_ops). Checked in order.
  scenarios = (
      # An NcclAllReduce kernel is registered -> NCCL is selected.
      (lambda op: [object] if op == "NcclAllReduce" else [],
       cross_device_ops_lib.NcclAllReduce),
      # No kernel is registered -> fall back to reducing on one device.
      (lambda _: [],
       cross_device_ops_lib.ReductionToOneDevice),
  )
  for fake_registry, expected_cls in scenarios:
    with test.mock.patch.object(kernels, "get_registered_kernels_for_op",
                                fake_registry):
      self.assertIsInstance(
          cross_device_ops_lib.select_cross_device_ops(gpu_devices),
          expected_cls)
def _initialize_single_worker(self, devices):
  """Initializes the object for single-worker training.

  Args:
    devices: an indexable sequence of device strings for the local replicas.
      Every entry is canonicalized into `self._devices`. `devices[0]` is
      passed to `device_util.canonicalize` when building the host CPU input
      device (presumably as the default for missing job/replica/task fields;
      NOTE(review): confirm against `device_util.canonicalize`).
  """
  self._devices = tuple(device_util.canonicalize(d) for d in devices)
  # Single input worker: the host CPU device paired with the compute devices.
  # Use the canonicalized `self._devices` so this mapping agrees with the
  # device strings stored on the strategy, rather than the caller's raw
  # (possibly non-canonical) spellings.
  self._input_workers_devices = (
      (device_util.canonicalize("/device:CPU:0", devices[0]), self._devices),)
  # Only infer cross-device ops when the caller did not supply any.
  self._inferred_cross_device_ops = None if self._cross_device_ops else (
      cross_device_ops_lib.select_cross_device_ops(devices))
  self._host_input_device = numpy_dataset.SingleDevice(
      self._input_workers_devices[0][0])
  self._is_multi_worker_training = False
  logging.info("Using MirroredStrategy with devices %r", devices)
  device_spec = tf_device.DeviceSpec.from_string(
      self._input_workers_devices[0][0])
  # Ensures when we enter strategy.scope() we use the correct default device.
  # Only a non-local job needs an explicit default; "localhost" is left as is.
  # NOTE(review): "%d" assumes replica and task are populated on the
  # canonicalized host device spec -- confirm canonicalize always sets them.
  if device_spec.job is not None and device_spec.job != "localhost":
    self._default_device = "/job:%s/replica:%d/task:%d" % (
        device_spec.job, device_spec.replica, device_spec.task)