    def __init__(self,
                 proc_func,
                 cluster_spec,
                 max_run_time=None,
                 capture_std_stream=False,
                 grpc_fail_fast=False,
                 args=None,
                 kwargs=None):
        """Creates a multi-process runner.

    Args:
      proc_func: Function to be run on child processes. This will be run on
        processes for all task types.
      cluster_spec: Dict for the cluster spec. The following is an example of a
        cluster with three workers and two ps's.
        {"worker": ["worker0.example.com:2222",
                    "worker1.example.com:2222",
                    "worker2.example.com:2222"],
         "ps": ["ps0.example.com:2222",
                "ps1.example.com:2222"]}
      max_run_time: If set, child processes are forced to exit at approximately
        this many seconds after `start` is called. This is achieved through the
        `signal.alarm()` API. Note that this is best effort at the Python level,
        since the Python signal handler does not get executed while lower-level
        C/C++ code is running, so it can be delayed for an arbitrarily long time.
      capture_std_stream: Boolean, whether the messages streamed to stdout and
        stderr in subprocesses are captured.
      grpc_fail_fast: Whether the gRPC connection between processes should fail
        without retrying. Defaults to False.
      args: Positional arguments to be sent to functions run on processes.
      kwargs: Keyword arguments to be sent to functions run on processes.

    Raises:
      RuntimeError: if `multi_process_runner.test_main()` is not called.
    """
        assert cluster_spec is not None
        assert callable(proc_func)

        if not multi_process_lib.using_context_manager():
            raise RuntimeError(
                '`multi_process_runner` is not initialized. '
                'Please call `multi_process_runner.test_main()` '
                'within `if __name__ == \'__main__\':` block '
                'in your python module to properly initialize '
                '`multi_process_runner`.')

        self._proc_func = proc_func
        self._cluster_spec = cluster_spec
        self._max_run_time = max_run_time
        self._capture_std_stream = capture_std_stream
        self._grpc_fail_fast = grpc_fail_fast
        self._args = args or ()
        self._kwargs = kwargs or {}
        self._processes = []

        # Child processes should have the same v2 and eager behavior.
        self._v2_enabled = tf2.enabled()
        self._executing_eagerly = context.executing_eagerly()
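
For orientation, here is a minimal usage sketch of the constructor above. It is not part of the original source: the `proc_func` body, the two-worker `cluster_spec`, and the `max_run_time` value are illustrative assumptions, while `start()` and `multi_process_runner.test_main()` are used only as the docstring describes them.

def simple_proc_func():
    print('hello from a child process')

cluster_spec = {'worker': ['localhost:2222', 'localhost:2223']}

runner = MultiProcessRunner(
    simple_proc_func,
    cluster_spec,
    max_run_time=30,          # children are force-exited ~30s after start()
    capture_std_stream=True)  # capture the children's stdout/stderr
runner.start()

# The module using this class must call multi_process_runner.test_main()
# under `if __name__ == '__main__':`, or the constructor raises the
# RuntimeError documented above.

The next variant of the constructor, below, replaces `capture_std_stream` with `stream_stdout`/`list_stdout` and adds an `rpc_layer` option.
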
  def __init__(self,
               proc_func,
               cluster_spec,
               rpc_layer=None,
               max_run_time=None,
               grpc_fail_fast=None,
               stream_stdout=True,
               list_stdout=False,
               args=None,
               kwargs=None):
    """Creates a multi-process runner.

    Args:
      proc_func: Function to be run on child processes. This will be run on
        processes for all task types.
      cluster_spec: Dict for the cluster spec. The following is an example of a
        cluster with three workers and two ps's.
        {"worker": ["worker0.example.com:2222",
                    "worker1.example.com:2222",
                    "worker2.example.com:2222"],
         "ps": ["ps0.example.com:2222",
                "ps1.example.com:2222"]}
      rpc_layer: RPC layer to use. Default value is 'grpc+loas'.
      max_run_time: If set, child processes are forced to exit at approximately
        this many seconds after `start` is called. This is achieved through the
        `signal.alarm()` API. Note that this is best effort at the Python level,
        since the Python signal handler does not get executed while lower-level
        C/C++ code is running, so it can be delayed for an arbitrarily long time.
      grpc_fail_fast: Whether the gRPC connection between processes should fail
        without retrying. Defaults to None, in which case the environment
        variable is not explicitly set.
      stream_stdout: True if the output/error from the subprocesses should be
        streamed and printed to the parent process' log. Defaults to True.
      list_stdout: True if the output/error from the subprocesses should be
        collected and attached to the `MultiProcessRunnerResult` returned from
        `MultiProcessRunner.join()`. If True, the list of stdout lines can be
        retrieved via the `MultiProcessRunnerResult.stdout` attribute.
        Defaults to False.
      args: Positional arguments to be sent to functions run on processes.
      kwargs: Keyword arguments to be sent to functions run on processes.

    Raises:
      RuntimeError: if `multi_process_runner.test_main()` is not called.
      ValueError: if there is more than one chief in the `cluster_spec`.
    """
    assert cluster_spec is not None
    if 'chief' in cluster_spec and len(cluster_spec['chief']) > 1:
      raise ValueError('If a chief exists in the cluster, there must be at '
                       'most one. Current `cluster_spec` has {} chiefs.'
                       .format(len(cluster_spec['chief'])))

    assert callable(proc_func)

    if not multi_process_lib.using_context_manager():
      raise RuntimeError('`multi_process_runner` is not initialized. '
                         'Please call `multi_process_runner.test_main()` '
                         'within `if __name__ == \'__main__\':` block '
                         'in your python module to properly initialize '
                         '`multi_process_runner`.')

    self._proc_func = proc_func
    self._cluster_spec = cluster_spec
    self._rpc_layer = rpc_layer
    self._max_run_time = max_run_time
    self._grpc_fail_fast = grpc_fail_fast
    self._stream_stdout = stream_stdout
    # TODO(rchao): Revisit list_stdout argument to consider other solution.
    self._list_stdout = list_stdout
    self._dependence_on_chief = True
    self._args = args or ()
    self._kwargs = kwargs or {}

    self._outstanding_subprocess_count = 0

    # Child processes should have the same v2 and eager behavior.
    self._v2_enabled = tf2.enabled()
    self._executing_eagerly = context.executing_eagerly()

    # This flag will be set to True once terminate_all() is called.
    self._all_forced_terminated = False
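
As above, a hedged sketch (not from the original source) of how this variant might be used: `join()` and `MultiProcessRunnerResult.stdout` are taken from the `list_stdout` description, while the cluster addresses and the `proc_func` body are illustrative assumptions.

def simple_proc_func():
    print('running in a child process')

cluster_spec = {
    'worker': ['localhost:2222', 'localhost:2223'],
    'ps': ['localhost:2224'],
}

runner = MultiProcessRunner(
    simple_proc_func,
    cluster_spec,
    stream_stdout=True,  # echo the children's output in the parent's log
    list_stdout=True)    # also collect it on the returned result
runner.start()
result = runner.join()
for line in result.stdout:  # available because list_stdout=True
    print(line)
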
Example #3
    def run(self,
            proc_func,
            cluster_spec,
            proc_flags=None,
            timeout=200,
            time_to_exit=None,
            return_std_stream=False,
            args=None,
            kwargs=None):
        """Run functions on local sub-processes.

    Experimental. API subject to change. To fully inspect logging from
    subprocesses, use `--test_arg=--logtostderr` flag with bazel test.

    Args:
      proc_func: Function to be run on the processes. This will be run on
        processes for all task types.
      cluster_spec: Dict for the cluster spec. The following is an example of a
        cluster with three workers and two ps's.
        {"worker": ["worker0.example.com:2222",
                    "worker1.example.com:2222",
                    "worker2.example.com:2222"],
         "ps": ["ps0.example.com:2222",
                "ps1.example.com:2222"]}
      proc_flags: Dict that contains the key/values of the flags used on the
        processes.
      timeout: Timeout in seconds. If a sub-process takes more than this time
        to complete, an error is raised.
      time_to_exit: If set, sub-processes are forced to exit at approximately
        this many seconds after `run()` is called, through the `signal.alarm()`
        API. This is for simulating an interruption of a process, so in such
        cases no error is raised. Note that this is best effort at the Python
        level, since the Python signal handler does not get executed inside the
        low-level (C) signal handler, so it can be delayed.
      return_std_stream: Boolean, whether the messages streamed to stdout and
        stderr in subprocesses are captured. If True, the messages are stored in
        a list returned as the second element.
      args: Positional arguments to be sent to functions run on processes.
      kwargs: Keyword arguments to be sent to functions run on processes.

    Returns:
      If `return_std_stream` is False, a list that stores the return data added
      by subprocesses through `multi_process_runner._add_return_data(data)`
      call,
      or through normal function return; if `return_std_stream` is True, a
      two-element tuple of `(return_data_list, std_stream_data_list)`, where
      `return_data_list` stores the return data added by processes through
      `multi_process_runner._add_return_data(data)` call or through normal
      function
      return, and `std_stream_data_list` stores the messages streamed to stdout
      and stderr in the subprocesses.

    Raises:
      RuntimeError: If any subprocess raises an error, or if any subprocess
        does not return or error out within `timeout` seconds.
    """

        assert cluster_spec is not None
        assert callable(proc_func)

        if not multi_process_lib.using_context_manager():
            raise RuntimeError(
                '`multi_process_runner` is not initialized. '
                'Please call `multi_process_runner.test_main()` '
                'within `if __name__ == \'__main__\':` block '
                'in your python module to properly initialize '
                '`multi_process_runner`.')

        processes = []
        args = args or ()
        kwargs = kwargs or {}

        def wrapper_func(tf_config_as_json, proc_func, proc_flags,
                         time_to_exit, executing_eagerly, *arg, **kwargs):
            """The wrapper function that actually gets run on the process(es)."""
            @contextlib.contextmanager
            def runtime_mode(executing_eagerly):
                if executing_eagerly:
                    with context.eager_mode():
                        yield
                else:
                    with context.graph_mode():
                        yield

            with runtime_mode(executing_eagerly):
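                # Expose this task's cluster membership (type and index) to
                # TensorFlow in the child process via the TF_CONFIG variable.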
                os.environ['TF_CONFIG'] = tf_config_as_json
                if proc_flags is not None:
                    for flag_key, flag_value in proc_flags.items():
                        setattr(flags.FLAGS, flag_key, flag_value)

                stdout_collector = _LogCollector(
                    sys.__stdout__) if return_std_stream else None
                stderr_collector = _LogCollector(
                    sys.__stderr__) if return_std_stream else None

                def finish_wrapper_func_properly(func_result):
                    """Call to finish `wrapper_func` properly."""
                    # Clear the alarm.
                    signal.alarm(0)
                    if (return_std_stream and stdout_collector is not None
                            and stderr_collector is not None):
                        # If stdout and stderr are to be collected, add them to std stream
                        # queue.
                        self._add_std_stream_data_flattened(
                            stdout_collector.log)
                        self._add_std_stream_data_flattened(
                            stderr_collector.log)
                        # Un-redirect stdout and stderr.
                        sys.stdout = sys.__stdout__
                        sys.stderr = sys.__stderr__
                    self._get_internal_queue().put(func_result)

                if time_to_exit is not None:

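                    # Simulate an interruption: when the alarm fires, report a
                    # clean finish to the parent and exit this child process.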
                    def handler(signum, frame):
                        del signum, frame
                        finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE)
                        # pylint: disable=protected-access
                        os._exit(0)

                    signal.signal(signal.SIGALRM, handler)
                    signal.alarm(time_to_exit)

                if return_std_stream:
                    sys.stdout = stdout_collector
                    sys.stderr = stderr_collector

                try:
                    return_data = proc_func(*arg, **kwargs)
                    if return_data is not None:
                        self._add_return_data(return_data)
                # pylint: disable=broad-except
                except Exception:
                    # Capture all exceptions to be reported to parent process.
                    finish_wrapper_func_properly(
                        _ExcInfoWrapper(sys.exc_info()))

                    # Re-raise the exception in addition to reporting it to the
                    # parent process, so that even if the `--test_timeout` flag
                    # is set and the error doesn't make it to the parent process
                    # before bazel's timeout, the log would still show what
                    # happened in this subprocess, instead of silently
                    # suppressing the error due to an early bazel timeout.
                    # Raising an error in the subprocess produces a stack trace
                    # in the log, but the program continues running.
                    raise

                finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE)

        # Start one process per task specified in `cluster_spec`.
        for job_type, addresses in cluster_spec.items():
            for task_id, _ in enumerate(addresses):
                tf_config_as_json = json.dumps({
                    'cluster': cluster_spec,
                    'task': {
                        'type': job_type,
                        'index': task_id
                    }
                })
                p = multi_process_lib.Process(
                    target=wrapper_func,
                    args=(tf_config_as_json, proc_func, proc_flags,
                          time_to_exit, context.executing_eagerly()) + args,
                    kwargs=kwargs)
                p.start()
                processes.append(p)

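        # Wait for one completion message per subprocess, allowing up to
        # `timeout` seconds for each.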
        internal_queue_results = []
        for _ in range(len(processes)):
            try:
                internal_queue_results.append(
                    self._get_internal_queue().get(timeout=timeout))
            except Queue.Empty:
                # First check if any of the subprocesses raised exception.
                for internal_queue_result in internal_queue_results:
                    if isinstance(internal_queue_result, _ExcInfoWrapper):
                        six.reraise(*internal_queue_result.exc_info)
                # If none of those did, report time out to user.
                raise RuntimeError(
                    'One or more subprocesses timed out. Please use '
                    '`--test_arg=--logtostderr` bazel flag to inspect logs for '
                    'subprocess debugging info. Timeout = {} sec.'.format(
                        timeout))

        for internal_queue_result in internal_queue_results:
            if isinstance(internal_queue_result, _ExcInfoWrapper):
                six.reraise(*internal_queue_result.exc_info)
            assert internal_queue_result == _FINISH_PROPERLY_MESSAGE

        def queue_to_list(queue_to_convert):
            """Convert `queue.Queue` to `list`."""
            list_to_return = []
            while True:
                try:
                    list_to_return.append(queue_to_convert.get(block=False))
                except Queue.Empty:
                    break
            return list_to_return

        if return_std_stream:
            return tuple(
                queue_to_list(multi_process_lib.get_user_data()[queue_name])
                for queue_name in [
                    _AvailableQueues.PUBLIC_QUEUE,
                    _AvailableQueues.STD_STREAM_QUEUE
                ])
        else:
            return queue_to_list(multi_process_lib.get_user_data()[
                _AvailableQueues.PUBLIC_QUEUE])
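
Finally, a hedged usage sketch of `run()` as documented above. The construction of the enclosing runner object is not shown in this excerpt, so `runner` is assumed to be an already-created instance; the `proc_func` body, the `cluster_spec`, and the flag name in `proc_flags` are illustrative assumptions.

def simple_proc_func():
    # The returned value (or data added via
    # multi_process_runner._add_return_data) ends up in return_data_list.
    return 'done'

cluster_spec = {'worker': ['localhost:2222', 'localhost:2223']}

# `runner` is assumed to be an existing instance of the class defining run().
return_data_list, std_stream_data_list = runner.run(
    simple_proc_func,
    cluster_spec,
    proc_flags={'some_flag': True},  # hypothetical flag name
    timeout=200,
    return_std_stream=True)          # also return captured stdout/stderr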