def join(self, timeout=None):
  """Joins all the processes with timeout.

  Args:
    timeout: if set and not all processes report status within roughly
      `timeout` seconds, a `RuntimeError` exception will be thrown.

  Returns:
    A tuple. The first element is a list that stores the return data added
    by subprocesses through `_add_return_data` or through normal function
    return; the second element is a list of the messages streamed to stdout
    and stderr in the subprocesses if `capture_std_stream` is True, or
    `None` otherwise.

  Raises:
    RuntimeError: if not all processes report status within `timeout`
      seconds, or the exception propagated from any child process.
  """
  if not timeout:
    if self._max_run_time:
      timeout = self._max_run_time + 10  # add 10 seconds grace period
    else:
      timeout = float('inf')
  start_time = time.time()
  while self._outstanding_subprocess_count > 0:
    while True:
      try:
        process_status = self._get_process_status_queue().get(timeout=10)
        break
      except Queue.Empty:
        if time.time() - start_time > timeout:
          # Report timeout to the user.
          raise RuntimeError(
              'One or more subprocesses timed out. Please use '
              '`--test_arg=--logtostderr` bazel flag to inspect logs for '
              'subprocess debugging info. Number of outstanding '
              'subprocesses is %d.' % self._outstanding_subprocess_count)

    self._outstanding_subprocess_count -= 1
    assert isinstance(process_status, _ProcessStatusInfo)
    if not process_status.is_successful:
      six.reraise(*process_status.exc_info)

  if self._capture_std_stream:
    # TODO(yuefengz): we need to make sure elements match the same process in
    # the two returned lists so as to not surprise users. Consider creating a
    # `ReturnData` class.
    return tuple(
        self._queue_to_list(multi_process_lib.get_user_data()[queue_name])
        for queue_name in [RETURN_VALUE_QUEUE, STD_STREAM_QUEUE])
  else:
    return (self._queue_to_list(
        multi_process_lib.get_user_data()[RETURN_VALUE_QUEUE]), None)
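# A minimal usage sketch for the `join` above (hedged: how the runner object
# is constructed is not shown here, so it is taken as a parameter; only the
# documented return contract of `join` is relied on).
def _example_join_usage(runner):
  """Collects results from a runner whose subprocesses are already started."""
  return_data, std_stream_data = runner.join(timeout=120)
  for value in return_data:
    print(value)  # data added via `_add_return_data` or normal returns
  if std_stream_data is not None:
    # Only populated when the runner captures std streams.
    for line in std_stream_data:
      print(line)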
def _add_return_data(self, data):
  """Adds return data that will be returned by `join`.

  The function provides a way for child processes to communicate with the
  parent process. Data passed to `_add_return_data` will be available in a
  Python Queue.Queue that is eventually returned by `join`.

  Args:
    data: data to be made available in the queue returned by `join`.
  """
  # TODO(rchao): Incorporate the task type and id information in a data
  # wrapper that becomes what is stored in the queue so we can tell where
  # the data is from.
  multi_process_lib.get_user_data()[RETURN_VALUE_QUEUE].put(data)
def _add_std_stream_data_flattened(self, data):
  # TODO(yuefengz): currently the same queue is used by multiple processes.
  # It is difficult for users to distinguish between logs from different
  # processes.
  std_stream_queue = multi_process_lib.get_user_data()[STD_STREAM_QUEUE]
  for d in list(data):
    std_stream_queue.put(d)
def add_return_data(data):
  """Adds return data that will be returned by `multi_process_runner.run()`.

  The function provides a way for processes started by
  `multi_process_runner.run()` to communicate with the original process that
  started the sub-processes. Data passed to `add_return_data` will be
  available in a Python Queue.Queue that is eventually returned by
  `multi_process_runner.run()`.

  Args:
    data: data to be made available in the queue returned by
      `multi_process_runner.run()`.
  """
  # TODO(rchao): Incorporate the task type and id information in a data
  # wrapper that becomes what is stored in the queue so we can tell where
  # the data is from.
  multi_process_lib.get_user_data()[_AvailableQueues.PUBLIC_QUEUE].put(data)
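# A hedged sketch of the intended flow: the function run in each subprocess
# calls `add_return_data`, and the parent reads the values from the list
# returned by `run()`. The `count_dict={'worker': 2}` argument follows the
# module-level `run(proc_func, count_dict, ...)` signature defined later in
# this file; `json` and `os` are already used elsewhere in the module.
def _example_proc_func():
  task_index = json.loads(os.environ['TF_CONFIG'])['task']['index']
  add_return_data('worker %d done' % task_index)


def _example_add_return_data_usage():
  return_data = run(_example_proc_func, count_dict={'worker': 2}, timeout=60)
  # One entry per worker, in arbitrary order,
  # e.g. ['worker 0 done', 'worker 1 done'].
  return return_data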
def _start_subprocess_and_reading_thread(self, proc_func, task_type, task_id,
                                         args, kwargs):
  """Starts a subprocess and a thread that reads lines from the subprocess."""
  global _next_pipe_index
  pipe_r, pipe_w = multi_process_lib.get_user_data(
  )[STREAMING_PIPE][_next_pipe_index]
  _next_pipe_index += 1

  p = multi_process_lib.Process(
      target=self._proc_func_wrapper,
      args=(proc_func, task_type, task_id, self._cluster_spec,
            self._rpc_layer, pipe_w) + args,
      kwargs=kwargs)
  p.start()
  self._outstanding_subprocess_count += 1

  # For each subprocess, we dedicate a thread that continuously reads lines
  # from it.
  thread = threading.Thread(  # pylint: disable=unexpected-keyword-arg
      target=self._continuously_readline_from_sub,
      args=(pipe_r, task_type, task_id),
      daemon=True)
  thread.start()
  _threads.append(thread)
def _get_parent_to_sub_queue(self):
  return multi_process_lib.get_user_data()[PARENT_TO_SUB_QUEUE]
def _get_inter_process_queue(self):
  return multi_process_lib.get_user_data()[INTER_PROCESS_QUEUE]
def _get_process_status_queue(self):
  return multi_process_lib.get_user_data()[PROCESS_STATUS_QUEUE]
def run(self,
        proc_func,
        cluster_spec,
        proc_flags=None,
        timeout=200,
        time_to_exit=None,
        return_std_stream=False,
        args=None,
        kwargs=None):
  """Runs functions on local sub-processes.

  Experimental. API subject to change. To fully inspect logging from
  subprocesses, use `--test_arg=--logtostderr` flag with bazel test.

  Args:
    proc_func: Function to be run on the processes. This will be run on
      processes for all task types.
    cluster_spec: Dict for cluster spec. The following is an example of a
      cluster with three workers and two ps's.
      {"worker": ["worker0.example.com:2222",
                  "worker1.example.com:2222",
                  "worker2.example.com:2222"],
       "ps": ["ps0.example.com:2222",
              "ps1.example.com:2222"]}
    proc_flags: Dict that contains the key/values of the flags used on the
      processes.
    timeout: Timeout in seconds. If a sub-process takes more than this time
      to complete, an error is raised.
    time_to_exit: If set, sub-processes are forced to exit at approximately
      this many seconds after `run()` is called, through the `signal.alarm()`
      api. This is for simulation of interruption on a process so in such
      cases no error is raised. Note that this is best effort at the Python
      level since the Python signal handler does not get executed inside the
      low-level (C) signal handler, so it can be delayed.
    return_std_stream: Boolean, whether the messages streamed to stdout and
      stderr in subprocesses are captured. If True, the messages are stored
      in a list returned as the second element.
    args: Positional arguments to be sent to functions run on processes.
    kwargs: Keyword arguments to be sent to functions run on processes.

  Returns:
    If `return_std_stream` is False, a list that stores the return data added
    by subprocesses through `multi_process_runner._add_return_data(data)`
    call, or through normal function return; if `return_std_stream` is True,
    a two-element tuple of `(return_data_list, std_stream_data_list)`, where
    `return_data_list` stores the return data added by processes through
    `multi_process_runner._add_return_data(data)` call or through normal
    function return, and `std_stream_data_list` stores the messages streamed
    to stdout and stderr in the subprocesses.

  Raises:
    RuntimeError: If any of the subprocesses raise an error, or if any of the
      subprocesses does not return or error out within `timeout` seconds.
  """
  assert cluster_spec is not None
  assert callable(proc_func)

  processes = []
  args = args or ()
  kwargs = kwargs or {}

  def wrapper_func(tf_config_as_json, proc_func, proc_flags, time_to_exit,
                   executing_eagerly, *arg, **kwargs):
    """The wrapper function that actually gets run on the process(es)."""

    @contextlib.contextmanager
    def runtime_mode(executing_eagerly):
      if executing_eagerly:
        with context.eager_mode():
          yield
      else:
        with context.graph_mode():
          yield

    with runtime_mode(executing_eagerly):
      os.environ['TF_CONFIG'] = tf_config_as_json
      if proc_flags is not None:
        for flag_key, flag_value in proc_flags.items():
          setattr(flags.FLAGS, flag_key, flag_value)

      stdout_collector = _LogCollector(
          sys.__stdout__) if return_std_stream else None
      stderr_collector = _LogCollector(
          sys.__stderr__) if return_std_stream else None

      def finish_wrapper_func_properly(func_result):
        """Call to finish `wrapper_func` properly."""
        # Clear the alarm.
        signal.alarm(0)
        if (return_std_stream and stdout_collector is not None and
            stderr_collector is not None):
          # If stdout and stderr are to be collected, add them to std stream
          # queue.
          self._add_std_stream_data_flattened(stdout_collector.log)
          self._add_std_stream_data_flattened(stderr_collector.log)
          # Un-redirect stdout and stderr.
          sys.stdout = sys.__stdout__
          sys.stderr = sys.__stderr__
        self._get_internal_queue().put(func_result)

      if time_to_exit is not None:

        def handler(signum, frame):
          del signum, frame
          finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE)
          # pylint: disable=protected-access
          os._exit(0)

        signal.signal(signal.SIGALRM, handler)
        signal.alarm(time_to_exit)

      if return_std_stream:
        sys.stdout = stdout_collector
        sys.stderr = stderr_collector

      try:
        return_data = proc_func(*arg, **kwargs)
        if return_data is not None:
          self._add_return_data(return_data)
      # pylint: disable=broad-except
      except Exception:
        # Capture all exceptions to be reported to parent process.
        finish_wrapper_func_properly(_ExcInfoWrapper(sys.exc_info()))

        # Re-raise the exception in addition to reporting it to the parent
        # process, so that even if `--test_timeout` flag is set and the error
        # doesn't make it to be shown in parent process before bazel's
        # timeout, the log would still show what happens in this subprocess,
        # instead of silently suppressing the error due to early bazel
        # timeout. Raising an error in the subprocess produces stack trace in
        # the log, but the program continues running.
        raise

      finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE)

  # Start one process per task defined in `cluster_spec`.
  for job_type, addresses in cluster_spec.items():
    for task_id, _ in enumerate(addresses):
      tf_config_as_json = json.dumps({
          'cluster': cluster_spec,
          'task': {
              'type': job_type,
              'index': task_id
          }
      })
      p = multi_process_lib.Process(
          target=wrapper_func,
          args=(tf_config_as_json, proc_func, proc_flags, time_to_exit,
                context.executing_eagerly()) + args,
          kwargs=kwargs)
      p.start()
      processes.append(p)

  internal_queue_results = []
  for _ in range(len(processes)):
    try:
      internal_queue_results.append(
          self._get_internal_queue().get(timeout=timeout))
    except Queue.Empty:
      # First check if any of the subprocesses raised an exception.
      for internal_queue_result in internal_queue_results:
        if isinstance(internal_queue_result, _ExcInfoWrapper):
          six.reraise(*internal_queue_result.exc_info)
      # If none of those did, report timeout to user.
      raise RuntimeError(
          'One or more subprocesses timed out. Please use '
          '`--test_arg=--logtostderr` bazel flag to inspect logs for '
          'subprocess debugging info. Timeout = {} sec.'.format(timeout))

  for internal_queue_result in internal_queue_results:
    if isinstance(internal_queue_result, _ExcInfoWrapper):
      six.reraise(*internal_queue_result.exc_info)
    assert internal_queue_result == _FINISH_PROPERLY_MESSAGE

  def queue_to_list(queue_to_convert):
    """Convert `queue.Queue` to `list`."""
    list_to_return = []
    while True:
      try:
        list_to_return.append(queue_to_convert.get(block=False))
      except Queue.Empty:
        break
    return list_to_return

  if return_std_stream:
    return tuple(
        queue_to_list(multi_process_lib.get_user_data()[queue_name])
        for queue_name in [
            _AvailableQueues.PUBLIC_QUEUE, _AvailableQueues.STD_STREAM_QUEUE
        ])
  else:
    return queue_to_list(
        multi_process_lib.get_user_data()[_AvailableQueues.PUBLIC_QUEUE])
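# A hedged usage sketch for the `run` method above: build a local
# cluster_spec and run a simple function on every task. The runner instance
# is assumed to be constructed elsewhere and is taken as a parameter; the
# addresses below are placeholders.
def _example_run_usage(runner):
  def _proc_func():
    return json.loads(os.environ['TF_CONFIG'])['task']['type']

  cluster_spec = {
      'worker': ['localhost:12345', 'localhost:23456'],
      'ps': ['localhost:34567'],
  }
  return_data, std_stream_data = runner.run(
      _proc_func, cluster_spec, return_std_stream=True, timeout=120)
  # `return_data` holds the per-task return values, e.g. ['worker', 'worker',
  # 'ps'] in arbitrary order; `std_stream_data` holds captured stdout/stderr.
  return return_data, std_stream_data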
def _add_std_stream_data_flattened(self, data):
  std_stream_queue = multi_process_lib.get_user_data()[
      _AvailableQueues.STD_STREAM_QUEUE]
  for d in list(data):
    std_stream_queue.put(d)
def _get_subprocess_info_queue(self):
  return multi_process_lib.get_user_data()[SUBPROCESS_INFO_QUEUE]
def join(self, timeout=None):
  """Joins all the processes with timeout.

  Args:
    timeout: if set and not all processes report status within roughly
      `timeout` seconds, a `RuntimeError` exception will be thrown.

  Returns:
    A MultiProcessRunnerResult object, which has two attributes,
    `return_value` and `stdout`. `return_value` always contains the return
    values from the subprocesses. If the `list_stdout` argument is True at
    `__init__`, `stdout` is available and contains a list of all messages
    from the subprocesses' stdout and stderr.

  Raises:
    RuntimeError: if not all processes report status approximately within
      `timeout` seconds, or there's an exception propagated from any
      subprocess.
  """
  if not timeout:
    timeout = float('inf')
  start_time = time.time()
  while self._outstanding_subprocess_count > 0:
    while True:
      try:
        process_status = self._get_process_status_queue().get(timeout=10)
        break
      except Queue.Empty:
        if self._all_forced_terminated:
          break
        if time.time() - start_time > timeout:
          # Report timeout to the user.
          raise RuntimeError(
              'One or more subprocesses timed out. '
              'Number of outstanding subprocesses '
              'is %d.' % self._outstanding_subprocess_count)

    if self._all_forced_terminated:
      break
    self._outstanding_subprocess_count -= 1
    assert isinstance(process_status, _ProcessStatusInfo)
    if not process_status.is_successful:
      six.reraise(*process_status.exc_info)

    if self._dependence_on_chief and process_status.task_type == 'chief':
      self.terminate_all()
      break

  # Give the threads some time to finish reading messages from subprocesses.
  time.sleep(5)
  stdout = self._queue_to_list(
      multi_process_lib.get_user_data()[STREAMING_QUEUE])
  return_value = self._queue_to_list(
      multi_process_lib.get_user_data()[RETURN_VALUE_QUEUE])

  # Notify the line-reading threads that they should stop.
  for pipe_index in range(self._starting_pipe_index, _next_pipe_index):  # pylint: disable=protected-access
    _, pipe_w = multi_process_lib.get_user_data()[STREAMING_PIPE][pipe_index]
    writer = os.fdopen(pipe_w.fileno(), 'w')
    # Write an end-of-file message so the threads that are actively reading
    # lines know to stop.
    writer.writelines(['EOF'])
    writer.close()

  for thread in _threads:
    thread.join(5)

  return MultiProcessRunnerResult(stdout=stdout, return_value=return_value)
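# A hedged sketch for the newer `join` above, which returns a
# `MultiProcessRunnerResult` rather than a tuple (runner construction and the
# `list_stdout` option are assumed to be handled by the caller).
def _example_join_result_usage(runner):
  result = runner.join(timeout=300)
  for value in result.return_value:
    print(value)  # return values from the subprocesses
  if result.stdout:
    # Populated when the runner was created with `list_stdout=True`.
    for line in result.stdout:
      print(line)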
def _add_stdout_in_queue(self, formatted_line, task_type, task_id):
  del task_type, task_id
  # A queue instead of a simple list is used here due to b/150652733.
  multi_process_lib.get_user_data()[STREAMING_QUEUE].put(formatted_line)
def barrier():
  return multi_process_lib.get_user_data()[BARRIER]
def _get_internal_queue(self):
  return multi_process_lib.get_user_data()[_AvailableQueues.INTERNAL_QUEUE]
def _resource(resource_name):
  return multi_process_lib.get_user_data()[resource_name]
def run(proc_func,
        count_dict,
        proc_flags=None,
        timeout=200,
        time_to_exit=None,
        return_std_stream=False,
        args=None,
        kwargs=None):
  """Runs functions on local sub-processes.

  Args:
    proc_func: Function to be run on the processes. This will be run on
      processes for all task types.
    count_dict: Dict mapping task_type to the number of tasks of that type.
    proc_flags: Dict that contains the key/values of the flags used on the
      processes.
    timeout: Timeout in seconds. If a sub-process takes more than this time
      to complete, an error is raised.
    time_to_exit: If set, sub-processes are forced to exit at approximately
      this many seconds after `run()` is called, through the `signal.alarm()`
      api. This is for simulation of interruption on a process so in such
      cases no error is raised. Note that this is best effort at the Python
      level since the Python signal handler does not get executed inside the
      low-level (C) signal handler, so it can be delayed.
    return_std_stream: Boolean, whether the messages streamed to stdout and
      stderr in subprocesses are captured. If True, the messages are stored
      in a list returned as the second element.
    args: Positional arguments to be sent to functions run on processes.
    kwargs: Keyword arguments to be sent to functions run on processes.

  Returns:
    If `return_std_stream` is False, a list that stores the return data added
    by processes through `multi_process_runner.add_return_data(data)` call;
    if `return_std_stream` is True, a two-element tuple of
    `(return_data_list, std_stream_data_list)`, where `return_data_list`
    stores the return data added by processes through
    `multi_process_runner.add_return_data(data)` call, and
    `std_stream_data_list` stores the messages streamed to stdout and stderr
    in the subprocesses.

  Raises:
    RuntimeError: If any of the subprocesses raise an error, or if any of the
      subprocesses does not return or error out within `timeout` seconds.

  TODO(rchao): Open source this with a solution to handle multi_process_lib.
  """
  assert callable(proc_func)
  processes = []
  cluster_spec = {}
  args = args or ()
  kwargs = kwargs or {}

  for task_type, count in count_dict.items():
    cluster_spec[task_type] = [
        'localhost:{}'.format(multi_worker_test_base.pick_unused_port())
        for _ in range(count)
    ]

  def wrapper_func(tf_config_as_json, proc_func, proc_flags, time_to_exit,
                   *arg, **kwargs):
    """The wrapper function that actually gets run on the process(es)."""
    os.environ['TF_CONFIG'] = tf_config_as_json
    if proc_flags is not None:
      for flag_key, flag_value in proc_flags.items():
        setattr(flags.FLAGS, flag_key, flag_value)

    stdout_collector = _LogCollector(
        sys.__stdout__) if return_std_stream else None
    stderr_collector = _LogCollector(
        sys.__stderr__) if return_std_stream else None

    def finish_wrapper_func_properly(finish_message=_FINISH_PROPERLY_MESSAGE):
      """Call to finish `wrapper_func` properly."""
      # Clear the alarm.
      signal.alarm(0)
      if (return_std_stream and stdout_collector is not None and
          stderr_collector is not None):
        # If stdout and stderr are to be collected, add them to std stream
        # queue.
        _add_std_stream_data_flattened(stdout_collector.log)
        _add_std_stream_data_flattened(stderr_collector.log)
        # Un-redirect stdout and stderr.
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__
      _get_internal_queue().put(finish_message)

    if time_to_exit is not None:

      def handler(signum, frame):
        del signum, frame
        finish_wrapper_func_properly()
        # pylint: disable=protected-access
        os._exit(0)

      signal.signal(signal.SIGALRM, handler)
      signal.alarm(time_to_exit)

    if return_std_stream:
      sys.stdout = stdout_collector
      sys.stderr = stderr_collector

    try:
      proc_func(*arg, **kwargs)
    # pylint: disable=broad-except
    except Exception as e:
      # Capture all exceptions to be reported to parent process.
      finish_wrapper_func_properly(
          'Exception raised by subprocess: {}: {} {}'.format(
              e.__class__.__name__, str(e), traceback.format_exc()))
      return

    finish_wrapper_func_properly()

  # Start the number of processes according to `count_dict`.
  for task_type, count in count_dict.items():
    for task_id in range(count):
      tf_config_as_json = json.dumps({
          'cluster': cluster_spec,
          'task': {
              'type': task_type,
              'index': task_id
          }
      })
      p = multi_process_lib.Process(
          target=wrapper_func,
          args=(tf_config_as_json, proc_func, proc_flags, time_to_exit) + args,
          kwargs=kwargs)
      p.start()
      processes.append(p)

  internal_queue_results = []
  for _ in range(len(processes)):
    try:
      internal_queue_results.append(
          _get_internal_queue().get(timeout=timeout))
    except Queue.Empty:
      raise RuntimeError(
          'One or more subprocesses timed out. Please inspect logs for '
          'subprocess debugging info. Timeout = {} sec.'.format(timeout))

  for internal_queue_result in internal_queue_results:
    if internal_queue_result.startswith('Exception raised by subprocess'):
      raise RuntimeError(internal_queue_result)
    assert internal_queue_result == _FINISH_PROPERLY_MESSAGE

  def queue_to_list(queue_to_convert):
    """Convert `queue.Queue` to `list`."""
    list_to_return = []
    while True:
      try:
        list_to_return.append(queue_to_convert.get(block=False))
      except Queue.Empty:
        break
    return list_to_return

  if return_std_stream:
    return tuple(
        queue_to_list(multi_process_lib.get_user_data()[queue_name])
        for queue_name in
        [AvailableQueues.PUBLIC_QUEUE, AvailableQueues.STD_STREAM_QUEUE])
  else:
    return queue_to_list(
        multi_process_lib.get_user_data()[AvailableQueues.PUBLIC_QUEUE])
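# A hedged end-to-end sketch for the module-level `run` above, using only the
# `count_dict` and `return_std_stream` arguments documented in its docstring;
# the worker/ps counts are illustrative.
def _example_module_run_usage():
  def _print_tf_config():
    print('TF_CONFIG is %s' % os.environ['TF_CONFIG'])

  return_data, std_stream_data = run(
      _print_tf_config,
      count_dict={'worker': 2, 'ps': 1},
      return_std_stream=True,
      timeout=60)
  # `std_stream_data` contains the captured stdout lines from all three
  # subprocesses; `return_data` holds anything added via `add_return_data`.
  return return_data, std_stream_data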