Example #1
    def __init__(self,
                 container_strategy,
                 tpu_cluster_resolver=None,
                 steps_per_run=None,
                 device_assignment=None):
        super(TPUExtended, self).__init__(container_strategy)

        if tpu_cluster_resolver is None:
            tpu_cluster_resolver = TPUClusterResolver("")

        if steps_per_run is None:
            # TODO(frankchn): Warn when we are being used by DS/Keras and this is
            # not specified.
            steps_per_run = 1

        self._tpu_cluster_resolver = tpu_cluster_resolver
        self._tpu_metadata = get_tpu_system_metadata(
            self._tpu_cluster_resolver)
        self._device_assignment = device_assignment

        # Device assignment is currently only supported for the single-core case.
        if self._device_assignment:
            assert isinstance(self._device_assignment,
                              device_assignment_lib.DeviceAssignment)
            if self._device_assignment.num_replicas != 1:
                raise ValueError(
                    "Device assignment is only supported for a single "
                    "core single replica case currently.")
            if self._device_assignment.num_cores_per_replica != 1:
                raise ValueError(
                    "Device assignment is only supported for a single "
                    "core single replica case currently.")
            if not all(self._device_assignment.core_assignment[0][0] ==
                       [0, 0, 0]):
                raise ValueError(
                    "Device assignment is only supported for a single "
                    "core single replica case currently.")

        # TODO(jhseu): Switch to DeviceAssignment to support pods and model
        # parallelism.
        self._device_index = {
            d.name: i
            for i, d in enumerate(self._tpu_metadata.devices)
            if "device:TPU:" in d.name
        }
        self._host_device = tpu_strategy_util.get_first_tpu_host_device(
            self._tpu_cluster_resolver)
        self._tpu_devices = tuple(sorted(self._device_index.keys()))
        # Only create variables for the number of replicas we're running.
        self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync]
        self._device_map = values.ReplicaDeviceMap(self._tpu_devices)

        # Preload the data onto the TPUs.
        input_worker_devices = collections.OrderedDict()
        for tpu_device in self._tpu_devices:
            host_device = _get_host_for_device(tpu_device)
            input_worker_devices.setdefault(host_device, [])
            input_worker_devices[host_device].append(tpu_device)
        self._input_workers = input_lib.InputWorkers(
            self._device_map, tuple(input_worker_devices.items()))

        # TODO(sourabhbajaj): Remove this once performance of running one step
        # at a time is comparable to multiple steps.
        self.steps_per_run = steps_per_run
        self._require_static_shapes = True
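
The "Preload the data onto the TPUs" block above is the heart of this example: every TPU device is bucketed under the host that feeds it, and the resulting (host, devices) pairs are what InputWorkers receives. Below is a minimal, TensorFlow-free sketch of that grouping; group_devices_by_host and host_of are hypothetical names, and deriving the host by swapping the device suffix for CPU:0 is only an assumption about what _get_host_for_device roughly does.

import collections

def group_devices_by_host(device_names):
    # Illustrative only: mirrors the OrderedDict loop in the example, grouping
    # accelerator device names under the host device that feeds them.
    def host_of(device_name):
        # Assumption: the host is the same job/replica/task, device CPU:0.
        return device_name.rsplit("/device:", 1)[0] + "/device:CPU:0"

    per_host = collections.OrderedDict()
    for name in device_names:
        per_host.setdefault(host_of(name), []).append(name)
    return tuple(per_host.items())

# Toy usage with made-up device names: both TPUs map to the same host worker.
print(group_devices_by_host([
    "/job:worker/replica:0/task:0/device:TPU:0",
    "/job:worker/replica:0/task:0/device:TPU:1",
]))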
Example #2
  def _initialize_multi_worker(self, cluster_resolver):
    """Initialize devices for multiple workers.

    It creates variable devices and compute devices. Variables and operations
    will be assigned to them respectively. We have one compute device per
    replica. The variable device is a device function or device string. The
    default variable device assigns variables to parameter servers in a
    round-robin fashion.

    Args:
      cluster_resolver: a descendant of `ClusterResolver` object.

    Raises:
      ValueError: if the cluster doesn't have ps jobs.
    """
    num_gpus = cluster_resolver.num_accelerators()
    cluster_spec = cluster_resolver.cluster_spec()
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_index
    if not task_type or task_id is None:
      raise ValueError("When `cluster_spec` is given, you must also specify "
                       "`task_type` and `task_id`")
    cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
    assert cluster_spec.as_dict()

    worker_device = "/job:%s/task:%d" % (task_type, task_id)
    self._input_host_device = numpy_dataset.SingleDevice(worker_device)

    # Define compute devices, a list of device strings with one entry per
    # replica. When there are GPUs, replicate operations on these GPUs;
    # otherwise, place operations on the CPU.
    if num_gpus > 0:
      compute_devices = tuple(
          "%s/device:GPU:%d" % (worker_device, i) for i in range(num_gpus))
    else:
      compute_devices = (worker_device,)

    self._device_map = values.ReplicaDeviceMap(compute_devices)
    self._input_workers = input_lib.InputWorkers(
        self._device_map, [(worker_device, compute_devices)])

    # In distributed mode, place variables on ps jobs in a round-robin fashion.
    # Note that devices returned from `replica_device_setter` are not
    # canonical and therefore we don't canonicalize all variable devices to
    # make them consistent.
    # TODO(yuefengz): support passing a strategy object to control variable
    # assignment.
    # TODO(yuefengz): merge the logic of replica_device_setter into this
    # class.
    num_ps_replicas = len(cluster_spec.as_dict().get("ps", []))
    if num_ps_replicas == 0:
      raise ValueError("The cluster spec needs to have `ps` jobs.")
    self._variable_device = device_setter.replica_device_setter(
        ps_tasks=num_ps_replicas,
        worker_device=worker_device,
        merge_devices=True,
        cluster=cluster_spec)

    # The `_parameter_devices` is needed for the `parameter_devices` property
    # and is a list of all variable devices. Here parameter devices are all
    # tasks of the "ps" job.
    self._parameter_devices = tuple(map("/job:ps/task:{}".format,
                                        range(num_ps_replicas)))

    # Add a default device so that ops without specified devices will not end up
    # on other workers.
    self._default_device = worker_device

    self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                task_id)
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id

    logging.info(
        "Multi-worker ParameterServerStrategy with "
        "cluster_spec = %r, task_type = %r, task_id = %r, "
        "num_ps_replicas = %r, is_chief = %r, device_map = %r, "
        "variable_device = %r", cluster_spec.as_dict(), task_type, task_id,
        num_ps_replicas, self._is_chief, self._device_map,
        self._variable_device)
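
The docstring above says variables go to the `ps` tasks in a round-robin fashion while each replica keeps its own compute device. The placement rule itself is tiny; here is a self-contained sketch of it, with round_robin_ps_device as a hypothetical helper name (in the example the real work is delegated to device_setter.replica_device_setter, whose default placement is approximately this):

def round_robin_ps_device(variable_index, num_ps_tasks):
    # Illustrative only: variable i lands on parameter server task i mod N.
    return "/job:ps/task:%d" % (variable_index % num_ps_tasks)

# With 3 ps tasks, five variables spread as ps:0, ps:1, ps:2, ps:0, ps:1.
print([round_robin_ps_device(i, 3) for i in range(5)])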
Example #3
  def __init__(self, distribute):
    super(_TestExtended, self).__init__(distribute)
    worker_device_pairs = [("", ["/device:CPU:0"])]
    self._input_workers = input_lib.InputWorkers(worker_device_pairs)
Example #4
  def _input_workers(self):
    if self._input_workers_obj is None:
      self._input_workers_obj = input_lib.InputWorkers(
          self._input_worker_devices)
    return self._input_workers_obj
Example #5
  def __init__(self, distribute):
    super(_TestExtended, self).__init__(distribute)
    device_map = values.ReplicaDeviceMap(["/device:CPU:0"])
    worker_device_pairs = [("", ["/device:CPU:0"])]
    self._input_workers = input_lib.InputWorkers(
        device_map, worker_device_pairs)
  def _test_input_iteration(self,
                            input_type,
                            api_type,
                            iteration_type,
                            dataset_or_input_fn,
                            worker_device_pairs,
                            expected_values,
                            strategy,
                            sess=None,
                            split_batch_by=None,
                            input_context=None):
    if iteration_type == "for_loop" and not context.executing_eagerly():
      self.skipTest("unsupported test combination.")

    if api_type == "wrap_into_iterator" and iteration_type == "for_loop":
      self.skipTest("unsupported test combination.")

    devices = nest.flatten([ds for _, ds in worker_device_pairs])
    input_workers = input_lib.InputWorkers(worker_device_pairs)

    if api_type == "wrap_into_iterator":
      iterator = self._wrap_iterator(
          input_type,
          dataset_or_input_fn,
          input_workers,
          devices,
          split_batch_by,
          strategy,
          input_context=input_context)
    else:
      # wrapping into a dataset:
      dataset = self._wrap_dataset(
          input_type,
          dataset_or_input_fn,
          input_workers,
          split_batch_by,
          strategy,
          input_context=input_context)

      if context.executing_eagerly():
        iterator = iter(dataset)
      else:
        if isinstance(dataset, input_lib.DistributedDatasetV1):
          iterator = dataset.make_initializable_iterator()
        else:
          self.skipTest("unsupported test combination")

    if iteration_type == "get_next":
      evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
      if isinstance(iterator, input_lib.DistributedIteratorV1):
        evaluate(control_flow_ops.group(iterator.initializer))

      for expected_value in expected_values:
        next_element = iterator.get_next()
        computed_value = evaluate(
            [values.select_replica(r,
                                   next_element) for r in range(len(devices))])
        self.assertEqual(len(expected_value), len(computed_value))
        for i in range(len(expected_value)):
          self.assertAllEqual(expected_value[i], computed_value[i])

      with self.assertRaises(errors.OutOfRangeError):
        next_element = iterator.get_next()
        evaluate(
            [values.select_replica(r,
                                   next_element) for r in range(len(devices))])

      # After re-initializing the iterator, should be able to iterate again.
      if isinstance(iterator, input_lib.DistributedIteratorV1):
        evaluate(control_flow_ops.group(iterator.initializer))
      else:
        evaluate(control_flow_ops.group(iterator._initializer))

      for expected_value in expected_values:
        next_element = iterator.get_next()
        computed_value = evaluate(
            [values.select_replica(r,
                                   next_element) for r in range(len(devices))])
        self.assertEqual(len(expected_value), len(computed_value))
        for i in range(len(expected_value)):
          self.assertAllEqual(expected_value[i], computed_value[i])

    if iteration_type == "for_loop" and context.executing_eagerly():
      actual_values = []
      for x in dataset:
        computed_value = self.evaluate(
            [values.select_replica(r, x) for r in range(len(devices))])
        actual_values.append(computed_value)
      for i, expected_value in enumerate(expected_values):
        self.assertEqual(len(expected_value), len(actual_values[i]))
        for j in range(len(expected_value)):
          self.assertAllEqual(expected_value[j], actual_values[i][j])

  def _input_workers_with_options(self, options=None):
    input_workers_devices = (
        ("/device:CPU:0", self.worker_devices),)
    return input_lib.InputWorkers(
        input_workers_devices, canonicalize_devices=False)
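
The `get_next` branch of the `_test_input_iteration` helper above drains the iterator, expects an out-of-range error, re-initializes, and checks that the same values come back. The control flow is easier to see without the distribution machinery; below is a tiny plain-Python sketch of that same shape, with StopIteration standing in for errors.OutOfRangeError and a fresh iterator standing in for re-initialization.

expected_values = [[0], [1], [2]]

# First pass: pull one element per expected value, as the test does with
# iterator.get_next().
iterator = iter(expected_values)
for expected in expected_values:
    assert next(iterator) == expected

# The next pull must fail, playing the role of errors.OutOfRangeError.
try:
    next(iterator)
    raise AssertionError("iterator should have been exhausted")
except StopIteration:
    pass

# "Re-initializing" gives a fresh pass over the same data, mirroring the
# second loop in the test.
iterator = iter(expected_values)
assert [next(iterator) for _ in expected_values] == expected_values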
Example #8
  def _initialize_multi_worker(self, cluster_resolver):
    """Initializes the object for multi-worker training."""
    cluster_spec = multi_worker_util.normalize_cluster_spec(
        cluster_resolver.cluster_spec())
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    if task_type is None or task_id is None:
      raise ValueError("When `cluster_spec` is given, you must also specify "
                       "`task_type` and `task_id`.")
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id

    self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type)
    if not self._num_workers:
      raise ValueError("No `worker`, `chief` or `evaluator` tasks can be found "
                       "in `cluster_spec`.")

    self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                task_id)

    self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
    self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)

    if (ops.executing_eagerly_outside_functions() and
        not getattr(self, "_local_or_standalone_client_mode", False)):
      context.context().configure_collective_ops(
          collective_leader=multi_worker_util.collective_leader(
              cluster_spec, task_type, task_id),
          scoped_allocator_enabled_ops=("CollectiveReduce",),
          use_nccl_communication=(self._communication == cross_device_ops_lib
                                  .CollectiveCommunication.NCCL),
          device_filters=("/job:%s/task:%d" % (task_type, task_id),))
      self._collective_ops_configured = True

    # Starting a std server in eager mode and in independent worker mode.
    if (context.executing_eagerly() and
        not getattr(self, "_std_server_started", False) and
        not getattr(self, "_local_or_standalone_client_mode", False)):
      # Checking _local_or_standalone_client_mode as well because we should not
      # create the std server in standalone client mode.
      config_proto = config_pb2.ConfigProto()
      config_proto = self._update_config_proto(config_proto)
      server_def = tensorflow_server_pb2.ServerDef(
          cluster=cluster_spec.as_cluster_def(),
          default_session_config=config_proto,
          job_name=task_type,
          task_index=task_id,
          protocol=cluster_resolver.rpc_layer or "grpc")
      context.context().enable_collective_ops(server_def)
      self._std_server_started = True
      # The `ensure_initialized` is needed before calling
      # `context.context().devices()`.
      context.context().ensure_initialized()
      logging.info(
          "Enabled multi-worker collective ops with available devices: %r",
          context.context().devices())

    # TODO(yuefengz): The `num_gpus` is only for this particular task. It
    # assumes all workers have the same number of GPUs. We should remove this
    # assumption by querying all tasks for their numbers of GPUs.
    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
    # some cases.
    if isinstance(cluster_resolver, TFConfigClusterResolver):
      num_gpus = context.num_gpus()
    else:
      num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

    if num_gpus:
      local_devices = tuple("%s/device:GPU:%d" % (self._worker_device, i)
                            for i in range(num_gpus))
    else:
      local_devices = (self._worker_device,)

    self._collective_keys = cross_device_utils.CollectiveKeys()
    super(CollectiveAllReduceExtended, self)._initialize_local(local_devices)
    self._input_workers = input_lib.InputWorkers(
        self._device_map, [(self._worker_device, self.worker_devices)])
    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        num_workers=self._num_workers,
        num_gpus_per_worker=num_gpus,
        collective_keys=self._collective_keys)

    # Add a default device so that ops without specified devices will not end up
    # on other workers.
    self._default_device = "/job:%s/task:%d" % (task_type, task_id)

    # Save the num_gpus_per_worker and rpc_layer for configure method.
    self._num_gpus_per_worker = num_gpus
    self._rpc_layer = cluster_resolver.rpc_layer

    logging.info(
        "Multi-worker CollectiveAllReduceStrategy with cluster_spec = %r, "
        "task_type = %r, task_id = %r, num_workers = %r, local_devices = %r, "
        "communication = %s", cluster_spec.as_dict(), task_type,
        task_id, self._num_workers, local_devices,
        self._communication)
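
Most of this example is cluster bookkeeping; the device list itself is simple: each worker builds its own device string from task_type/task_id and lists one compute device per local GPU, or just itself when there are none. A self-contained sketch of that step, with local_compute_devices as a hypothetical name rather than a TensorFlow API:

def local_compute_devices(task_type, task_id, num_gpus):
    # Illustrative only: reproduces the local_devices construction above.
    worker_device = "/job:%s/task:%d" % (task_type, task_id)
    if num_gpus:
        return tuple("%s/device:GPU:%d" % (worker_device, i)
                     for i in range(num_gpus))
    return (worker_device,)

print(local_compute_devices("worker", 1, 2))  # two GPU devices on task 1
print(local_compute_devices("worker", 0, 0))  # falls back to the worker device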
Example #9
    def __init__(self,
                 container_strategy,
                 tpu_cluster_resolver=None,
                 steps_per_run=None,
                 device_assignment=None):
        super(TPUExtended, self).__init__(container_strategy)

        if tpu_cluster_resolver is None:
            tpu_cluster_resolver = TPUClusterResolver("")

        if steps_per_run is None:
            # TODO(frankchn): Warn when we are being used by DS/Keras and this is
            # not specified.
            steps_per_run = 1

        self._tpu_function_cache = weakref.WeakKeyDictionary()
        self._tpu_cluster_resolver = tpu_cluster_resolver
        self._tpu_metadata = get_tpu_system_metadata(
            self._tpu_cluster_resolver)
        self._device_assignment = device_assignment

        self._tpu_devices = [
            d.name for d in self._tpu_metadata.devices
            if "device:TPU:" in d.name
        ]

        # Only create variables for the number of replicas we're running.
        if device_assignment is not None:
            job_name = device_spec.DeviceSpecV2.from_string(
                self._tpu_devices[0]).job

            self._tpu_devices = []
            for replica_id in range(device_assignment.num_replicas):
                tpu_device = device_assignment.tpu_device(replica=replica_id,
                                                          logical_core=0,
                                                          job=job_name)
                tpu_device = device_util.canonicalize(tpu_device)
                self._tpu_devices.append(tpu_device)

        self._host_device = device_util.get_host_for_device(
            self._tpu_devices[0])

        self._device_map = values.ReplicaDeviceMap(self._tpu_devices)

        # Preload the data onto the TPUs.
        input_worker_devices = collections.OrderedDict()
        for tpu_device in self._tpu_devices:
            host_device = device_util.get_host_for_device(tpu_device)
            input_worker_devices.setdefault(host_device, [])
            input_worker_devices[host_device].append(tpu_device)
        self._input_workers = input_lib.InputWorkers(
            self._device_map, tuple(input_worker_devices.items()))

        # TODO(sourabhbajaj): Remove this once performance of running one step
        # at a time is comparable to multiple steps.
        self.steps_per_run = steps_per_run
        self._require_static_shapes = True

        # TPUStrategy handles the graph replication in TF-XLA bridge, so we don't
        # need to retrace functions for each device.
        self._retrace_functions_for_each_device = False

        self.experimental_enable_get_next_as_optional = True
        self.experimental_enable_dynamic_batch_size = True
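
Compared with Example #1, this version handles a device_assignment by asking it for the device that backs logical core 0 of each replica, instead of taking the first N TPU devices. Below is a toy sketch of that selection with a made-up assignment table; the real DeviceAssignment answers the same question through tpu_device(replica=..., logical_core=0, ...), as shown above.

def devices_for_assignment(core_assignment_devices):
    # Illustrative only: one device per replica, always its logical core 0.
    # core_assignment_devices is a toy stand-in for DeviceAssignment: a list
    # with one entry per replica, each entry listing that replica's core
    # device names.
    return [replica_cores[0] for replica_cores in core_assignment_devices]

toy_assignment = [
    ["/job:worker/task:0/device:TPU:0", "/job:worker/task:0/device:TPU:1"],
    ["/job:worker/task:0/device:TPU:2", "/job:worker/task:0/device:TPU:3"],
]
# Two replicas with two logical cores each; input feeding targets core 0 of
# each replica.
print(devices_for_assignment(toy_assignment))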