def start(self):
  """Launches one subprocess for every task listed in `self._cluster_spec`.

  Task ids are the positional indices within each job's address list. The
  subprocess for task `index` of job `job_name` runs
  `self._proc_func_wrapper(job_name, index, *self._args, **self._kwargs)`,
  and every started process is appended to `self._processes`.
  """
  # Enumerate all (job, index) pairs up front, then spawn them in order.
  task_keys = [(job_name, index)
               for job_name, job_addresses in self._cluster_spec.items()
               for index, _ in enumerate(job_addresses)]
  for job_name, index in task_keys:
    process = multi_process_lib.Process(
        target=self._proc_func_wrapper,
        args=(job_name, index) + self._args,
        kwargs=self._kwargs)
    process.start()
    self._processes.append(process)
def run(self, proc_func, args=None, kwargs=None):
  """Runs `proc_func` with `args` and `kwargs` on all jobs.

  The runner is started on first use (when `self._runner` is `None`).
  `proc_func` is serialized with `dill` and sent over every connection in
  `self._conn`, and then one status is received from each connection in turn.

  Args:
    proc_func: The function to be run.
    args: Optional positional arguments to be supplied in `proc_func`.
    kwargs: Optional keyword arguments to be supplied in `proc_func`.

  Returns:
    A list of return values. Workers whose `proc_func` returned `None` do not
    contribute an entry.

  Raises:
    RuntimeError: If a connection reaches EOF before delivering a status.
    Any exception reported by a worker's status is re-raised here with its
    original type.
  """
  # TODO(b/150264776): skip in OSS until it's implemented.
  multi_process_lib.Process()
  if self._runner is None:
    self._start()

  # Since we start the processes as daemon they're going to be killed by
  # SIGTERM when the program exits. We only turn on streaming during run() to
  # avoid printing the stacktrace caused by the SIGTERM.
  #
  # Hold a local reference to the runner. The EOF path below calls
  # `self._reset()`, which presumably clears `self._runner` (TODO confirm). The
  # `finally` clause must still be able to switch streaming off without an
  # AttributeError that would mask the intended RuntimeError.
  runner = self._runner
  runner._stream_stdout = True  # pylint: disable=protected-access
  try:
    proc_func = dill.dumps(proc_func, dill.HIGHEST_PROTOCOL)
    for conn in self._conn.values():
      conn.send((proc_func, args or [], kwargs or {}))

    process_statuses = []
    for (task_type, task_id), conn in self._conn.items():
      logging.info('Waiting for the result from %s-%d', task_type, task_id)
      try:
        process_statuses.append(conn.recv())
      except EOFError:
        # This shouldn't happen due to exceptions in proc_func. This usually
        # means bugs in the runner.
        self._reset()
        raise RuntimeError(
            'Unexpected EOF. Worker process may have died. '
            'Please report a bug')

    return_values = []
    for process_status in process_statuses:
      assert isinstance(process_status, _ProcessStatusInfo)
      if not process_status.is_successful:
        six.reraise(*process_status.exc_info)
      if process_status.return_value is not None:
        return_values.append(process_status.return_value)
    return return_values
  finally:
    runner._stream_stdout = False  # pylint: disable=protected-access
def start(self):
  """Spawns one subprocess per task described by `self._cluster_spec`.

  The subprocess for task `i` of job `task_type` runs
  `self._proc_func_wrapper(task_type, i, *self._args, **self._kwargs)` and is
  appended to `self._processes` once started.

  If a 'chief' job exists in the cluster, the 'chief' process is meant to exit
  before the other jobs. This prevents the chief from continuing to connect to
  them, which causes errors.
  NOTE(review): nothing in this method enforces that ordering. Presumably
  `_proc_func_wrapper` or the joining logic does; confirm.
  """
  pending_tasks = [
      (job, task_index)
      for job, job_addresses in self._cluster_spec.items()
      for task_index, _ in enumerate(job_addresses)
  ]
  for job, task_index in pending_tasks:
    worker = multi_process_lib.Process(
        target=self._proc_func_wrapper,
        args=(job, task_index) + self._args,
        kwargs=self._kwargs)
    worker.start()
    self._processes.append(worker)
def run(self, proc_func, args=None, kwargs=None):
  """Executes `proc_func(*args, **kwargs)` on every connected job.

  The runner is started lazily when `self._runner` is `None`. `proc_func` is
  serialized with `dill`, sent to every connection, and one status is then
  received from each connection in turn.

  Args:
    proc_func: The function to be run.
    args: Optional positional arguments to be supplied in `proc_func`.
    kwargs: Optional keyword arguments to be supplied in `proc_func`.

  Returns:
    A list of the non-`None` return values reported by the jobs.
  """
  # TODO(b/150264776): skip in OSS until it's implemented.
  multi_process_lib.Process()
  if self._runner is None:
    self._start()

  pickled_func = dill.dumps(proc_func, dill.HIGHEST_PROTOCOL)
  # `send` serializes the payload per connection, so one tuple can be reused.
  payload = (pickled_func, args or [], kwargs or {})
  for connection in self._conn.values():
    connection.send(payload)

  statuses = []
  for (task_type, task_id), connection in self._conn.items():
    logging.info('Waiting for the result from %s-%d', task_type, task_id)
    try:
      status = connection.recv()
    except EOFError:
      # Exceptions from proc_func arrive as statuses, so an EOF here points at
      # a bug in the runner itself (e.g. a worker process that died).
      self.shutdown()
      raise RuntimeError(
          'Unexpected EOF. Worker process may have died. '
          'Please report a bug')
    statuses.append(status)

  # Surface the first failure (in connection order) before collecting values.
  for status in statuses:
    assert isinstance(status, _ProcessStatusInfo)
    if not status.is_successful:
      six.reraise(*status.exc_info)
  return [
      status.return_value
      for status in statuses
      if status.return_value is not None
  ]
def _start_subprocess_and_reading_thread(self, proc_func, task_type, task_id,
                                         cluster_spec, args, kwargs):
  """Starts a subprocess and a thread that reads lines from the subprocess.

  The next `(read_end, write_end)` pair is taken from
  `_resource(STREAMING_PIPE)`, and the module-level `_next_pipe_index` is
  advanced. The write end is passed to a `_Subprocess()` target together with
  this runner's RPC/eager settings. The read end is consumed by a thread
  running `self._continuously_readline_from_sub`.

  NOTE(review): the thread handle is not retained and no `daemon` flag is
  passed, so its daemon status is inherited from the calling thread. Confirm
  that this is intended.

  Args:
    proc_func: Function to run in the subprocess.
    task_type: Task type of the subprocess.
    task_id: Task index of the subprocess.
    cluster_spec: Cluster spec handed to the subprocess.
    args: Extra positional arguments. This must be a tuple, since it is
      concatenated to a tuple.
    kwargs: Keyword arguments forwarded to the subprocess target.
  """
  global _next_pipe_index
  read_end, write_end = _resource(STREAMING_PIPE)[_next_pipe_index]
  _next_pipe_index += 1

  subprocess_target = _Subprocess()
  subprocess_args = (proc_func, task_type, task_id, cluster_spec,
                     self._rpc_layer, self._grpc_fail_fast, self._v2_enabled,
                     self._executing_eagerly, write_end) + args
  process = multi_process_lib.Process(
      target=subprocess_target, args=subprocess_args, kwargs=kwargs)
  process.start()
  self._outstanding_subprocess_count += 1

  # One dedicated reader thread per subprocess keeps draining its pipe.
  reader = threading.Thread(  # pylint: disable=unexpected-keyword-arg
      target=self._continuously_readline_from_sub,
      args=(read_end, task_type, task_id))
  reader.start()
def start_single_process(self,
                         task_type,
                         task_id,
                         proc_func=None,
                         updated_cluster_spec=None,
                         args=None,
                         kwargs=None):
  """Starts a single process.

  This starts a process in the cluster with the task type, task id, and the
  process function (`proc_func`). If the process function is `None`, the
  function provided at `__init__` (`self._proc_func`) is used. If
  `updated_cluster_spec` is provided (truthy), it replaces
  `self._cluster_spec` before the process is started.

  TODO(rchao): It is meant that all subprocesses will be updated with the new
  cluster spec, but this has yet to be implemented. At this time only the
  newly started subprocess picks up this updated cluster spec.

  Args:
    task_type: The task type.
    task_id: The task id.
    proc_func: The process function to be run on the newly started process.
      If `None`, the function provided at `__init__` will be used.
    updated_cluster_spec: If provided, the cluster spec used by this
      subprocess will be updated.
    args: Optional positional arguments to be supplied in `proc_func`. Any
      sequence (tuple or list) is accepted.
    kwargs: Optional keyword arguments to be supplied in `proc_func`.
  """
  self._cluster_spec = updated_cluster_spec or self._cluster_spec
  proc_func = proc_func or self._proc_func
  # Normalize to a tuple: a list here would make the `+` below raise
  # TypeError ("can only concatenate tuple (not "list") to tuple").
  extra_args = tuple(args) if args else ()
  p = multi_process_lib.Process(
      target=self._proc_func_wrapper,
      args=(proc_func, task_type, task_id, self._cluster_spec) + extra_args,
      kwargs=(kwargs or {}))
  p.start()
  self._outstanding_subprocess_count += 1
def _start_subprocess_and_reading_thread(self, proc_func, task_type, task_id,
                                         args, kwargs):
  """Starts a subprocess together with a thread that reads lines from it.

  The next `(read_end, write_end)` pipe pair is taken from the
  `STREAMING_PIPE` entry of `multi_process_lib.get_user_data()`, and the
  module-level `_next_pipe_index` is advanced. The write end goes to the
  subprocess, which runs `self._proc_func_wrapper`. The read end is consumed
  by a daemon thread running `self._continuously_readline_from_sub`, and that
  thread is appended to the module-level `_threads` list.

  Args:
    proc_func: Function to run in the subprocess.
    task_type: Task type of the subprocess.
    task_id: Task index of the subprocess.
    args: Extra positional arguments. This must be a tuple, since it is
      concatenated to a tuple.
    kwargs: Keyword arguments forwarded to the subprocess target.
  """
  global _next_pipe_index
  streaming_pipes = multi_process_lib.get_user_data()[STREAMING_PIPE]
  read_end, write_end = streaming_pipes[_next_pipe_index]
  _next_pipe_index += 1

  wrapper = self._proc_func_wrapper
  wrapper_args = (proc_func, task_type, task_id, self._cluster_spec,
                  self._rpc_layer, write_end) + args
  process = multi_process_lib.Process(
      target=wrapper, args=wrapper_args, kwargs=kwargs)
  process.start()
  self._outstanding_subprocess_count += 1

  # Every subprocess gets its own reader thread. Setting `daemon` before
  # `start()` is equivalent to passing `daemon=True` to the constructor.
  reader = threading.Thread(
      target=self._continuously_readline_from_sub,
      args=(read_end, task_type, task_id))
  reader.daemon = True
  reader.start()
  _threads.append(reader)
def run(self,
        proc_func,
        cluster_spec,
        proc_flags=None,
        timeout=200,
        time_to_exit=None,
        return_std_stream=False,
        args=None,
        kwargs=None):
  """Run functions on local sub-processes.

  Experimental. API subject to change.

  To fully inspect logging from subprocesses, use `--test_arg=--logtostderr`
  flag with bazel test.

  One subprocess is started per task (address) in `cluster_spec`. Each
  subprocess:
    * gets a `TF_CONFIG` environment variable describing the whole cluster
      and its own task;
    * runs `proc_func` in the caller's eager/graph mode;
    * reports either a normal-finish message or the raised exception back to
      this process through the internal queue.

  Args:
    proc_func: Function to be run on the processes. This will be run on
      processes for all task types.
    cluster_spec: Dict for cluster spec. The following is an example of a
      cluster with three workers and two ps's.
      {"worker": ["worker0.example.com:2222",
                  "worker1.example.com:2222",
                  "worker2.example.com:2222"],
       "ps": ["ps0.example.com:2222",
              "ps1.example.com:2222"]}
    proc_flags: Dict that contains the key/values of the flags used on the
      processes.
    timeout: Time out in seconds. If the sub-process takes more than this time
      to complete, raise an error. The timeout applies to each wait for the
      next subprocess result, so the total wait can reach roughly `timeout`
      times the number of subprocesses.
    time_to_exit: If set, sub-processes are forced to exit at approximately
      this many seconds after `run()` is called, through the `signal.alarm()`
      API. This is for simulating an interruption of a process, so in such
      cases no error is raised. Note that this is best effort at Python level,
      since the Python signal handler does not get executed inside the
      low-level (C) signal handler, so it can be delayed.
    return_std_stream: Boolean, whether the messages streamed to stdout and
      stderr in subprocesses are captured. If True, the messages are stored in
      a list returned as the second element.
    args: Positional arguments to be sent to functions run on processes.
    kwargs: Keyword arguments to be sent to functions run on processes.

  Returns:
    If `return_std_stream` is False, a list that stores the return data added
    by subprocesses through the `multi_process_runner._add_return_data(data)`
    call, or through normal function return.
    If `return_std_stream` is True, a two-element tuple of
    `(return_data_list, std_stream_data_list)`:
      * `return_data_list` stores the return data added by processes through
        the `multi_process_runner._add_return_data(data)` call or through
        normal function return;
      * `std_stream_data_list` stores the messages streamed to stdout and
        stderr in the subprocesses.

  Raises:
    RuntimeError: If any of the subprocesses does not return or error out
      within `timeout` seconds.
    Exceptions raised by `proc_func` in a subprocess are re-raised in this
    process with their original type (via `six.reraise`). This happens both
    on the normal path and, before reporting a timeout, for results already
    received.
  """
  assert cluster_spec is not None
  assert callable(proc_func)
  processes = []
  args = args or ()
  kwargs = kwargs or {}

  def wrapper_func(tf_config_as_json, proc_func, proc_flags, time_to_exit,
                   executing_eagerly, *arg, **kwargs):
    """The wrapper function that actually gets run on the process(es).

    Sets up `TF_CONFIG`, flags, optional stdout/stderr capture and the
    optional exit alarm, and then runs `proc_func`. Exactly one message is put
    on the internal queue: `_FINISH_PROPERLY_MESSAGE` on success or alarm, or
    an `_ExcInfoWrapper` on exception.
    """

    @contextlib.contextmanager
    def runtime_mode(executing_eagerly):
      # Reproduce the parent's execution mode (eager vs. graph) in the child.
      if executing_eagerly:
        with context.eager_mode():
          yield
      else:
        with context.graph_mode():
          yield

    with runtime_mode(executing_eagerly):
      os.environ['TF_CONFIG'] = tf_config_as_json
      if proc_flags is not None:
        for flag_key, flag_value in proc_flags.items():
          setattr(flags.FLAGS, flag_key, flag_value)

      stdout_collector = _LogCollector(
          sys.__stdout__) if return_std_stream else None
      stderr_collector = _LogCollector(
          sys.__stderr__) if return_std_stream else None

      def finish_wrapper_func_properly(func_result):
        """Call to finish `wrapper_func` properly.

        Cancels any pending alarm. When `return_std_stream` is set, it also
        hands the collected stdout/stderr to the std stream queue and restores
        the real streams. Finally it puts `func_result` on the internal queue
        that the parent waits on.
        """
        # Clear the alarm.
        signal.alarm(0)
        if (return_std_stream and stdout_collector is not None and
            stderr_collector is not None):
          # If stdout and stderr are to be collected, add them to std stream
          # queue.
          self._add_std_stream_data_flattened(
              stdout_collector.log)
          self._add_std_stream_data_flattened(
              stderr_collector.log)
          # Un-redirect stdout and stderr.
          sys.stdout = sys.__stdout__
          sys.stderr = sys.__stderr__
        self._get_internal_queue().put(func_result)

      if time_to_exit is not None:

        # When the alarm fires, report a normal finish and exit immediately.
        # This simulates an interruption without surfacing an error.
        # `os._exit` skips normal interpreter cleanup.
        def handler(signum, frame):
          del signum, frame
          finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE)
          # pylint: disable=protected-access
          os._exit(0)

        signal.signal(signal.SIGALRM, handler)
        signal.alarm(time_to_exit)

      if return_std_stream:
        sys.stdout = stdout_collector
        sys.stderr = stderr_collector

      try:
        return_data = proc_func(*arg, **kwargs)
        if return_data is not None:
          self._add_return_data(return_data)
      # pylint: disable=broad-except
      except Exception:
        # Capture all exceptions to be reported to parent process.
        finish_wrapper_func_properly(
            _ExcInfoWrapper(sys.exc_info()))

        # Re-raise the exception in addition to reporting it to the parent
        # process. Even if the `--test_timeout` flag is set and the error
        # doesn't reach the parent process before bazel's timeout, the log
        # will still show what happened in this subprocess instead of the
        # error being silently suppressed by the early bazel timeout. Raising
        # an error in the subprocess produces a stack trace in the log, but
        # the program continues running.
        raise

      finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE)

  # Start one process per task (address) listed in `cluster_spec`.
  for job_type, addresses in cluster_spec.items():
    for task_id, _ in enumerate(addresses):
      tf_config_as_json = json.dumps({
          'cluster': cluster_spec,
          'task': {
              'type': job_type,
              'index': task_id
          }
      })
      p = multi_process_lib.Process(
          target=wrapper_func,
          args=(tf_config_as_json, proc_func, proc_flags, time_to_exit,
                context.executing_eagerly()) + args,
          kwargs=kwargs)
      p.start()
      processes.append(p)

  # Collect one message per subprocess. NOTE: `timeout` bounds each
  # individual `get()`, so the total wait can approach
  # `len(processes) * timeout`.
  internal_queue_results = []
  for _ in range(len(processes)):
    try:
      internal_queue_results.append(
          self._get_internal_queue().get(timeout=timeout))
    except Queue.Empty:
      # First check if any of the subprocesses raised exception.
      for internal_queue_result in internal_queue_results:
        if isinstance(internal_queue_result, _ExcInfoWrapper):
          six.reraise(*internal_queue_result.exc_info)
      # If none of those did, report time out to user.
      raise RuntimeError(
          'One or more subprocesses timed out. Please use '
          '`--test_arg=--logtostderr` bazel flag to inspect logs for '
          'subprocess debugging info. Timeout = {} sec.'.format(
              timeout))

  # Re-raise any subprocess exception (in result order); every other result
  # must be the normal-finish sentinel.
  for internal_queue_result in internal_queue_results:
    if isinstance(internal_queue_result, _ExcInfoWrapper):
      six.reraise(*internal_queue_result.exc_info)
    assert internal_queue_result == _FINISH_PROPERLY_MESSAGE

  def queue_to_list(queue_to_convert):
    """Convert `queue.Queue` to `list`."""
    # Drain without blocking; stop at the first `Queue.Empty`.
    list_to_return = []
    while True:
      try:
        list_to_return.append(queue_to_convert.get(block=False))
      except Queue.Empty:
        break
    return list_to_return

  if return_std_stream:
    return tuple(
        queue_to_list(multi_process_lib.get_user_data()[queue_name])
        for queue_name in [
            _AvailableQueues.PUBLIC_QUEUE, _AvailableQueues.STD_STREAM_QUEUE
        ])
  else:
    return queue_to_list(multi_process_lib.get_user_data()[
        _AvailableQueues.PUBLIC_QUEUE])
def run(proc_func,
        count_dict,
        proc_flags=None,
        timeout=200,
        time_to_exit=None,
        return_std_stream=False,
        args=None,
        kwargs=None):
  """Run functions on local sub-processes.

  A cluster spec is built from `count_dict` by giving every task a
  `localhost` address on an unused port. One subprocess is started per task,
  with `TF_CONFIG` describing the cluster and that subprocess's own task.

  Args:
    proc_func: Function to be run on the processes. This will be run on
      processes for all task types.
    count_dict: Dict for task_type/count of such task type.
    proc_flags: Dict that contains the key/values of the flags used on the
      processes.
    timeout: Time out in seconds. If the sub-process takes more than this time
      to complete, raise an error. The timeout applies to each wait for the
      next subprocess result.
    time_to_exit: If set, sub-processes are forced to exit at approximately
      this many seconds after `run()` is called, through the `signal.alarm()`
      API. This is for simulating an interruption of a process, so in such
      cases no error is raised. Note that this is best effort at Python level,
      since the Python signal handler does not get executed inside the
      low-level (C) signal handler, so it can be delayed.
    return_std_stream: Boolean, whether the messages streamed to stdout and
      stderr in subprocesses are captured. If True, the messages are stored in
      a list returned as the second element.
    args: Positional arguments to be sent to functions run on processes.
    kwargs: Keyword arguments to be sent to functions run on processes.

  Returns:
    If `return_std_stream` is False, a list that stores the return data added
    by processes through the `multi_process_runner.add_return_data(data)`
    call.
    If `return_std_stream` is True, a two-element tuple of
    `(return_data_list, std_stream_data_list)`:
      * `return_data_list` stores the return data added by processes through
        the `multi_process_runner.add_return_data(data)` call;
      * `std_stream_data_list` stores the messages streamed to stdout and
        stderr in the subprocesses.

  Raises:
    RuntimeError: If any of the subprocesses raise an error, or if any of the
      subprocesses does not return or error out within `timeout` seconds. If
      a timeout occurs after some subprocess has already reported an error,
      that error is raised instead of the generic timeout.

  TODO(rchao): Open source this with a solution to handle multi_process_lib.
  """
  assert callable(proc_func)
  processes = []
  cluster_spec = {}
  args = args or ()
  kwargs = kwargs or {}

  for task_type, count in count_dict.items():
    cluster_spec[task_type] = [
        'localhost:{}'.format(multi_worker_test_base.pick_unused_port())
        for _ in range(count)
    ]

  def wrapper_func(tf_config_as_json, proc_func, proc_flags, time_to_exit,
                   *arg, **kwargs):
    """The wrapper function that actually gets run on the process(es).

    Exactly one string is put on the internal queue: the finish sentinel on
    success or alarm, or an 'Exception raised by subprocess: ...' message if
    `proc_func` raises.
    """
    os.environ['TF_CONFIG'] = tf_config_as_json
    if proc_flags is not None:
      for flag_key, flag_value in proc_flags.items():
        setattr(flags.FLAGS, flag_key, flag_value)

    stdout_collector = _LogCollector(
        sys.__stdout__) if return_std_stream else None
    stderr_collector = _LogCollector(
        sys.__stderr__) if return_std_stream else None

    def finish_wrapper_func_properly(finish_message=_FINISH_PROPERLY_MESSAGE):
      """Reports `finish_message` to the parent and restores stream state."""
      # Clear the alarm.
      signal.alarm(0)
      if (return_std_stream and stdout_collector is not None and
          stderr_collector is not None):
        # If stdout and stderr are to be collected, add them to std stream
        # queue.
        _add_std_stream_data_flattened(stdout_collector.log)
        _add_std_stream_data_flattened(stderr_collector.log)
        # Un-redirect stdout and stderr.
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__
      _get_internal_queue().put(finish_message)

    if time_to_exit is not None:

      # When the alarm fires, report a normal finish and exit immediately to
      # simulate an interruption without surfacing an error.
      def handler(signum, frame):
        del signum, frame
        finish_wrapper_func_properly()
        # pylint: disable=protected-access
        os._exit(0)

      signal.signal(signal.SIGALRM, handler)
      signal.alarm(time_to_exit)

    if return_std_stream:
      sys.stdout = stdout_collector
      sys.stderr = stderr_collector

    try:
      proc_func(*arg, **kwargs)
    # pylint: disable=broad-except
    except Exception as e:
      # Capture all exceptions to be reported to parent process.
      finish_wrapper_func_properly(
          'Exception raised by subprocess: {}: {} {}'.format(
              e.__class__.__name__, str(e), traceback.format_exc()))
      return

    finish_wrapper_func_properly()

  # Start one process per task according to `count_dict`.
  for task_type, count in count_dict.items():
    for task_id in range(count):
      tf_config_as_json = json.dumps({
          'cluster': cluster_spec,
          'task': {
              'type': task_type,
              'index': task_id
          }
      })
      p = multi_process_lib.Process(
          target=wrapper_func,
          args=(tf_config_as_json, proc_func, proc_flags, time_to_exit) + args,
          kwargs=kwargs)
      p.start()
      processes.append(p)

  internal_queue_results = []
  for _ in range(len(processes)):
    try:
      internal_queue_results.append(
          _get_internal_queue().get(timeout=timeout))
    except Queue.Empty:
      # First check whether any subprocess already reported an exception. That
      # failure is the likely root cause and far more useful than a generic
      # timeout message.
      for internal_queue_result in internal_queue_results:
        if internal_queue_result.startswith('Exception raised by subprocess'):
          raise RuntimeError(internal_queue_result)
      # If none of those did, report time out to user.
      raise RuntimeError(
          'One or more subprocesses timed out. Please inspect logs for '
          'subprocess debugging info. Timeout = {} sec.'.format(timeout))

  for internal_queue_result in internal_queue_results:
    if internal_queue_result.startswith('Exception raised by subprocess'):
      raise RuntimeError(internal_queue_result)
    assert internal_queue_result == _FINISH_PROPERLY_MESSAGE

  def queue_to_list(queue_to_convert):
    """Convert `queue.Queue` to `list`."""
    # Drain without blocking; stop at the first `Queue.Empty`.
    list_to_return = []
    while True:
      try:
        list_to_return.append(queue_to_convert.get(block=False))
      except Queue.Empty:
        break
    return list_to_return

  if return_std_stream:
    return tuple(
        queue_to_list(multi_process_lib.get_user_data()[queue_name])
        for queue_name in
        [AvailableQueues.PUBLIC_QUEUE, AvailableQueues.STD_STREAM_QUEUE])
  else:
    return queue_to_list(
        multi_process_lib.get_user_data()[AvailableQueues.PUBLIC_QUEUE])