def execute_on(self, worker):
  """Runs this closure's function on `worker` and records the outcome.

  The per-worker slices of the stored args and kwargs are selected with
  `worker.worker_index`. If checking those inputs yields an error, that error
  (wrapped in `InputError` unless it already is one) is set on every output
  remote value and the function is not called. Otherwise the function is
  invoked under the worker's device and executor scopes, and its flattened
  outputs are paired with the flattened output remote values.

  Args:
    worker: a `Worker` object; its `worker_index`, `device_name` and
      `executor` attributes are read.
  """
  index = worker.worker_index
  args_for_worker = _select_worker_slice(index, self._args)
  kwargs_for_worker = _select_worker_slice(index, self._kwargs)

  # Positional inputs are checked first; keyword inputs are only checked when
  # the positional ones produced no error.
  input_error = _maybe_get_error_and_rebuild_remote_values(
      worker, args_for_worker)
  if not input_error:
    input_error = _maybe_get_error_and_rebuild_remote_values(
        worker, kwargs_for_worker)

  if input_error:
    if not isinstance(input_error, InputError):
      input_error = InputError(input_error)
    # Propagate the input failure to every output; the function never runs.
    for output_slot in nest.flatten(self._output_remote_values):
      output_slot._set_error(input_error)  # pylint: disable=protected-access
    return

  with ops.device(worker.device_name), \
      context.executor_scope(worker.executor), \
      metric_utils.monitored_timer("closure_execution"):
    call_args = nest.map_structure(_maybe_get_remote_value, args_for_worker)
    call_kwargs = nest.map_structure(
        _maybe_get_remote_value, kwargs_for_worker)
    result = self._function(*call_args, **call_kwargs)

  output_slots = nest.flatten(self._output_remote_values)
  result_leaves = nest.flatten(result)
  for output_slot, leaf in zip(output_slots, result_leaves):
    output_slot._set_value(leaf)  # pylint: disable=protected-access
def __init__(self, function, cancellation_mgr, args=None, kwargs=None):
  """Stores the callable and its inputs and prepares output remote values.

  Args:
    function: the callable to run. A `def_function.Function` is traced to a
      concrete function here; a `tf_function.ConcreteFunction` is used
      directly; any other callable is stored unchanged.
    cancellation_mgr: object whose `get_cancelable_function` wraps the
      concrete function before it is stored.
    args: positional inputs for `function`; a falsy value becomes `()`.
    kwargs: keyword inputs for `function`; a falsy value becomes `{}`.

  Raises:
    ValueError: if `function` is not callable.
  """
  if not callable(function):
    raise ValueError("Function passed to `Client.schedule` must be a "
                     "callable object.")

  self._args = args or ()
  self._kwargs = kwargs or {}

  _disallow_remote_value_as_input(self._args)
  _disallow_remote_value_as_input(self._kwargs)

  needs_tracing = isinstance(function, def_function.Function)
  if not needs_tracing and not isinstance(function,
                                          tf_function.ConcreteFunction):
    # Regular python functions are stored without wrapping.
    self._function = function
    # TODO(yuefengz): maybe we should trace python functions if their inputs
    # are Python primitives, tensors and composite tensors.
    self._output_remote_values = RemoteValue(self, None)
    return

  if needs_tracing:
    # Trace using the slice of the inputs selected for index 0.
    first_args = _select_worker_slice(0, self._args)
    first_kwargs = _select_worker_slice(0, self._kwargs)

    # Function registration failure is not handled here: by design the
    # runtime does not raise for it, so the client relies on later operations
    # that do raise to surface such a failure.
    #
    # The tracing count of the def_function.Function is passed as a state
    # tracker so that the metric only records time spent on actual tracing,
    # i.e. excluding function cache lookups.
    with metric_utils.monitored_timer(
        "function_tracing",
        state_tracker=function._get_tracing_count):  # pylint: disable=protected-access
      spec_args = nest.map_structure(_maybe_as_type_spec, first_args)
      spec_kwargs = nest.map_structure(_maybe_as_type_spec, first_kwargs)
      concrete_function = function.get_concrete_function(
          *spec_args, **spec_kwargs)
  else:
    concrete_function = function

  def _make_remote_value(output_spec):
    """Builds the remote value placeholder for one structured output."""
    return RemoteValue(self, output_spec)

  self._function = cancellation_mgr.get_cancelable_function(
      concrete_function)
  self._output_remote_values = nest.map_structure(
      _make_remote_value, concrete_function.structured_outputs)
def _process_closure(self, closure):
  """Executes `closure` on this worker under the cluster's failure handler.

  On success the closure's output remote values are fetched and the closure
  queue is told the closure finished. The failure handler is given a callback
  that puts the closure back on the queue. Any exception that reaches this
  method is logged, set on each output remote value, and reported to the
  queue as a failure.

  Args:
    closure: the closure to run; its `execute_on` method is called with this
      worker.
  """

  def _requeue_closure():
    """Puts the closure back on the queue; used as the failure callback."""
    return self._cluster._closure_queue.put_back(closure)  # pylint: disable=protected-access

  try:
    failure_scope = self._cluster.failure_handler.wait_on_failure(
        on_failure_fn=_requeue_closure,
        on_recovery_fn=self._set_resources_aborted,
        worker_device_name=self.device_name)
    with failure_scope:
      closure.execute_on(self)
      # TODO(yuefengz): we don't have to materialize results every step.
      with metric_utils.monitored_timer("remote_value_fetch"):
        closure._fetch_output_remote_values()  # pylint: disable=protected-access
      self._cluster._closure_queue.mark_finished()  # pylint: disable=protected-access
  except Exception as e:  # pylint: disable=broad-except
    log_format = (
        "/job:worker/task:%d encountered the following error when processing "
        "closure: %r:%s")
    logging.error(log_format, self.worker_index, e, e)

    def _fail_output(remote_value):
      """Records the caught error on one output remote value."""
      return remote_value._set_error(e)  # pylint: disable=protected-access

    nest.map_structure(_fail_output, closure._output_remote_values)  # pylint: disable=protected-access
    self._cluster._closure_queue.mark_failed(e)  # pylint: disable=protected-access