Beispiel #1
0
  def execute_on(self, worker):
    """Executes the closure on the given worker.

    First, the stored args and kwargs are reduced to the slice for
    `worker.worker_index` using `_select_worker_slice`. Then
    `_maybe_get_error_and_rebuild_remote_values` is applied to the args and,
    only if that returns nothing truthy, to the kwargs.

    If an error comes back, it is wrapped in `InputError` (unless it already
    is one) and set on every flattened output remote value. The function is
    then not called.

    Otherwise, the function is called inside `ops.device(worker.device_name)`
    and `context.executor_scope(worker.executor)`, with its inputs mapped
    through `_maybe_get_remote_value`. Each flattened output is then stored
    into the flattened output remote value at the same position.

    Args:
      worker: a `Worker` object.
    """
    worker_args = _select_worker_slice(worker.worker_index, self._args)
    worker_kwargs = _select_worker_slice(worker.worker_index, self._kwargs)

    # `or` short-circuits: the kwargs are only examined when the args yielded
    # no error.
    input_error = (
        _maybe_get_error_and_rebuild_remote_values(worker, worker_args) or
        _maybe_get_error_and_rebuild_remote_values(worker, worker_kwargs))
    if input_error:
      # Report the failure through the outputs instead of running the function.
      if not isinstance(input_error, InputError):
        input_error = InputError(input_error)
      for output_rv in nest.flatten(self._output_remote_values):
        output_rv._set_error(input_error)  # pylint: disable=protected-access
      return

    with ops.device(worker.device_name):
      with context.executor_scope(worker.executor):
        with metric_utils.monitored_timer("closure_execution"):
          # Inputs are mapped through `_maybe_get_remote_value` while the
          # device and executor scopes are active, as is the call itself.
          call_args = nest.map_structure(_maybe_get_remote_value, worker_args)
          call_kwargs = nest.map_structure(_maybe_get_remote_value,
                                           worker_kwargs)
          outputs = self._function(*call_args, **call_kwargs)

    # NOTE(review): zip() stops at the shorter sequence. If there are fewer
    # output remote values than flattened outputs (e.g. one remote value for
    # a structured result), the extra outputs are dropped silently. Confirm
    # this is intended.
    output_rvs = nest.flatten(self._output_remote_values)
    for output_rv, output in zip(output_rvs, nest.flatten(outputs)):
      output_rv._set_value(output)  # pylint: disable=protected-access
Beispiel #2
0
  def __init__(self, function, cancellation_mgr, args=None, kwargs=None):
    """Creates a closure that will run `function` with `args` and `kwargs`.

    A `def_function.Function` is traced to a concrete function. The tracing
    uses the worker-0 slice of the inputs, converted with
    `_maybe_as_type_spec`. A `tf_function.ConcreteFunction` is used directly.

    Either kind is wrapped with `cancellation_mgr.get_cancelable_function`.
    Its output remote values come from mapping `RemoteValue(self, ...)` over
    the concrete function's `structured_outputs` with `nest.map_structure`.

    Any other callable is stored unchanged, with a single
    `RemoteValue(self, None)` as its output.

    Args:
      function: the callable to schedule.
      cancellation_mgr: object whose `get_cancelable_function` wraps concrete
        functions.
      args: positional arguments; a falsy value is replaced with `()`.
      kwargs: keyword arguments; a falsy value is replaced with `{}`.

    Raises:
      ValueError: if `function` is not callable.
    """
    if not callable(function):
      raise ValueError("Function passed to `Client.schedule` must be a "
                       "callable object.")
    self._args = args or ()
    self._kwargs = kwargs or {}

    # Validate the inputs before any tracing happens.
    _disallow_remote_value_as_input(self._args)
    _disallow_remote_value_as_input(self._kwargs)

    if isinstance(function, def_function.Function):
      first_worker_args = _select_worker_slice(0, self._args)
      first_worker_kwargs = _select_worker_slice(0, self._kwargs)

      # Function registration failures are not handled here. By design the
      # runtime does not raise on them, so the client only learns of one
      # through a later operation that raises.

      # Time the tracing step. The Function's tracing count is passed as the
      # state tracker so that only real traces are recorded, not function
      # cache lookups.
      with metric_utils.monitored_timer(
          "function_tracing", state_tracker=function._get_tracing_count):  # pylint: disable=protected-access
        concrete_fn = function.get_concrete_function(
            *nest.map_structure(_maybe_as_type_spec, first_worker_args),
            **nest.map_structure(_maybe_as_type_spec, first_worker_kwargs))
    elif isinstance(function, tf_function.ConcreteFunction):
      concrete_fn = function
    else:
      # Plain Python callable: stored unwrapped, with a single remote value.
      self._function = function
      # TODO(yuefengz): maybe we should trace python functions if their inputs
      # are Python primitives, tensors and composite tensors.
      self._output_remote_values = RemoteValue(self, None)
      return

    # Shared by both TF-function branches above.
    self._function = cancellation_mgr.get_cancelable_function(concrete_fn)
    self._output_remote_values = nest.map_structure(
        lambda spec: RemoteValue(self, spec), concrete_fn.structured_outputs)
Beispiel #3
0
 def _process_closure(self, closure):
   """Executes `closure` on this worker and records the outcome.

   Execution happens inside `self._cluster.failure_handler.wait_on_failure`.
   The failure callback puts the closure back on the cluster's closure queue,
   and the recovery callback is `self._set_resources_aborted`.

   After `closure.execute_on(self)` returns, the closure's output remote
   values are fetched (timed as "remote_value_fetch"). The closure is then
   marked finished on the queue.

   Any exception that propagates out is handled in three steps: it is logged,
   set as the error on every output remote value of the closure, and passed
   to the queue's `mark_failed`.

   Args:
     closure: the closure to run. It must provide `execute_on`,
       `_fetch_output_remote_values` and `_output_remote_values`.
   """
   def put_closure_back():
     # Passed as `on_failure_fn`: returns the closure to the cluster's queue.
     return self._cluster._closure_queue.put_back(closure)  # pylint: disable=protected-access

   try:
     with self._cluster.failure_handler.wait_on_failure(
         on_failure_fn=put_closure_back,
         on_recovery_fn=self._set_resources_aborted,
         worker_device_name=self.device_name):
       closure.execute_on(self)
       # TODO(yuefengz): we don't have to materialize results every step.
       with metric_utils.monitored_timer("remote_value_fetch"):
         closure._fetch_output_remote_values()  # pylint: disable=protected-access
       self._cluster._closure_queue.mark_finished()  # pylint: disable=protected-access
   except Exception as err:  # pylint: disable=broad-except
     logging.error(
         "/job:worker/task:%d encountered the following error when "
         "processing closure: %r:%s", self.worker_index, err, err)
     # Surface the error through every output of the closure, then report it.
     for output_rv in nest.flatten(closure._output_remote_values):  # pylint: disable=protected-access
       output_rv._set_error(err)  # pylint: disable=protected-access
     self._cluster._closure_queue.mark_failed(err)  # pylint: disable=protected-access