Example #1
    def start(self):
        """Starts processes, one for each task in `cluster_spec`."""
        for task_type, addresses in self._cluster_spec.items():
            for task_id, _ in enumerate(addresses):
                p = multi_process_lib.Process(target=self._proc_func_wrapper,
                                              args=(task_type, task_id) +
                                              self._args,
                                              kwargs=self._kwargs)
                p.start()
                self._processes.append(p)
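As a quick illustration of the fan-out above: iterating a cluster_spec dict (using the address-list format shown in the `run` docstrings further down this page) yields one (task_type, task_id) pair per address, and the runner starts one process per pair. This is a standalone sketch, not the runner's code; the print stands in for `multi_process_lib.Process(...).start()`.

cluster_spec = {
    'worker': ['worker0.example.com:2222', 'worker1.example.com:2222'],
    'ps': ['ps0.example.com:2222'],
}

for task_type, addresses in cluster_spec.items():
  for task_id, _ in enumerate(addresses):
    # One subprocess would be started per (task_type, task_id).
    print(task_type, task_id)
# Prints: worker 0, worker 1, ps 0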
Example #2
    def run(self, proc_func, args=None, kwargs=None):
        """Runs `proc_func` with `args` and `kwargs` on all jobs.

    Args:
      proc_func: The function to be run.
      args: Optional positional arguments to be supplied in `proc_func`.
      kwargs: Optional keyword arguments to be supplied in `proc_func`.

    Returns:
      A list of return values.
    """
        # TODO(b/150264776): skip in OSS until it's implemented.
        multi_process_lib.Process()
        if self._runner is None:
            self._start()

        # Since we start the processes as daemons, they're going to be killed
        # by SIGTERM when the program exits. We only turn on streaming during
        # run() to avoid printing the stack trace caused by the SIGTERM.
        self._runner._stream_stdout = True  # pylint: disable=protected-access

        try:
            proc_func = dill.dumps(proc_func, dill.HIGHEST_PROTOCOL)
            for conn in self._conn.values():
                conn.send((proc_func, args or [], kwargs or {}))

            process_statuses = []
            for (task_type, task_id), conn in self._conn.items():
                logging.info('Waiting for the result from %s-%d', task_type,
                             task_id)
                try:
                    process_statuses.append(conn.recv())
                except EOFError:
                    # An EOFError here shouldn't be caused by exceptions in
                    # proc_func; it usually means a bug in the runner.
                    self._reset()
                    raise RuntimeError(
                        'Unexpected EOF. Worker process may have died. '
                        'Please report a bug')

            return_values = []
            for process_status in process_statuses:
                assert isinstance(process_status, _ProcessStatusInfo)
                if not process_status.is_successful:
                    six.reraise(*process_status.exc_info)
                if process_status.return_value is not None:
                    return_values.append(process_status.return_value)

            return return_values
        finally:
            self._runner._stream_stdout = False  # pylint: disable=protected-access
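The `run` method above relies on a simple request/response protocol over multiprocessing connections: the parent dill-serializes `proc_func` (plain pickle cannot handle lambdas or closures), sends it with args/kwargs, and each worker sends back a status object. Below is a minimal standalone sketch of that protocol, not TF's implementation; `WorkerStatus` is a hypothetical stand-in for `_ProcessStatusInfo`.

import collections
import multiprocessing

import dill

# Hypothetical stand-in for _ProcessStatusInfo.
WorkerStatus = collections.namedtuple(
    'WorkerStatus', ['is_successful', 'exc_info', 'return_value'])

def _worker(conn):
  serialized_fn, args, kwargs = conn.recv()
  fn = dill.loads(serialized_fn)
  try:
    conn.send(WorkerStatus(True, None, fn(*args, **kwargs)))
  except Exception:  # pylint: disable=broad-except
    conn.send(WorkerStatus(False, None, None))

if __name__ == '__main__':
  parent_conn, child_conn = multiprocessing.Pipe()
  p = multiprocessing.Process(target=_worker, args=(child_conn,))
  p.start()
  parent_conn.send((dill.dumps(lambda x: x * 2, dill.HIGHEST_PROTOCOL), [21], {}))
  print(parent_conn.recv().return_value)  # 42
  p.join()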
Example #3
    def start(self):
        """Starts processes, one for each task in `cluster_spec`.

    If 'chief' job exists in the cluster, it is guaranteed that 'chief'
    process exits before other jobs to prevent chief from continuing to connect
    to them which causes error.
    """
        for task_type, addresses in self._cluster_spec.items():
            for task_id, _ in enumerate(addresses):
                p = multi_process_lib.Process(target=self._proc_func_wrapper,
                                              args=(task_type, task_id) +
                                              self._args,
                                              kwargs=self._kwargs)
                p.start()
                self._processes.append(p)
Example #4
    def run(self, proc_func, args=None, kwargs=None):
        """Runs `proc_func` with `args` and `kwargs` on all jobs.

    Args:
      proc_func: The function to be run.
      args: Optional positional arguments to be supplied in `proc_func`.
      kwargs: Optional keyword arguments to be supplied in `proc_func`.

    Returns:
      A list of return values.
    """
        # TODO(b/150264776): skip in OSS until it's implemented.
        multi_process_lib.Process()
        if self._runner is None:
            self._start()

        proc_func = dill.dumps(proc_func, dill.HIGHEST_PROTOCOL)
        for conn in self._conn.values():
            conn.send((proc_func, args or [], kwargs or {}))

        process_statuses = []
        for (task_type, task_id), conn in self._conn.items():
            logging.info('Waiting for the result from %s-%d', task_type,
                         task_id)
            try:
                process_statuses.append(conn.recv())
            except EOFError:
                # An EOFError here shouldn't be caused by exceptions in
                # proc_func; it usually means a bug in the runner.
                self.shutdown()
                raise RuntimeError(
                    'Unexpected EOF. Worker process may have died. '
                    'Please report a bug')

        return_values = []
        for process_status in process_statuses:
            assert isinstance(process_status, _ProcessStatusInfo)
            if not process_status.is_successful:
                six.reraise(*process_status.exc_info)
            if process_status.return_value is not None:
                return_values.append(process_status.return_value)

        return return_values
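The status objects carry `sys.exc_info()` from failed workers, and `six.reraise(*exc_info)` re-raises a previously captured exception with its original type and message. Here is a tiny single-process demonstration of that pattern (unrelated to any TF class):

import sys

import six

def flaky():
  raise ValueError('boom')

exc_info = None
try:
  flaky()
except Exception:  # pylint: disable=broad-except
  exc_info = sys.exc_info()  # captured where the error happened

try:
  six.reraise(*exc_info)  # re-raised later, e.g. after collecting statuses
except ValueError as e:
  print('re-raised:', e)  # re-raised: boom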
Example #5
  def _start_subprocess_and_reading_thread(self, proc_func, task_type, task_id,
                                           cluster_spec, args, kwargs):
    """Start a subprocess and a thread the reads lines from the subprocess."""
    global _next_pipe_index
    pipe_r, pipe_w = _resource(STREAMING_PIPE)[_next_pipe_index]
    _next_pipe_index += 1

    p = multi_process_lib.Process(
        target=_Subprocess(),
        args=(proc_func, task_type, task_id, cluster_spec, self._rpc_layer,
              self._grpc_fail_fast, self._v2_enabled, self._executing_eagerly,
              pipe_w) + args,
        kwargs=kwargs)
    p.start()
    self._outstanding_subprocess_count += 1

    # For each subprocess, we dedicate a thread that continuously reads lines
    # from it.
    thread = threading.Thread(  # pylint: disable=unexpected-keyword-arg
        target=self._continuously_readline_from_sub,
        args=(pipe_r, task_type, task_id))
    thread.start()
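The pipe-plus-reader-thread pattern above can be reproduced outside of TF with nothing but `os.pipe`, `multiprocessing`, and `threading`: the subprocess writes lines to the write end while a dedicated thread in the parent streams the read end until EOF. A minimal sketch, assuming the default fork start method on Linux (raw file descriptors are not passed to the child under spawn):

import multiprocessing
import os
import threading

def _child(write_fd):
  with os.fdopen(write_fd, 'w') as w:
    w.write('[worker-0] hello\n')
    w.write('[worker-0] done\n')

def _reader(read_fd):
  with os.fdopen(read_fd, 'r') as r:
    for line in r:  # ends at EOF, once every write end is closed
      print('streamed:', line, end='')

if __name__ == '__main__':
  read_fd, write_fd = os.pipe()
  p = multiprocessing.Process(target=_child, args=(write_fd,))
  p.start()
  os.close(write_fd)  # the parent must close its copy so the reader sees EOF
  t = threading.Thread(target=_reader, args=(read_fd,), daemon=True)
  t.start()
  p.join()
  t.join()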
Example #6
    def start_single_process(self,
                             task_type,
                             task_id,
                             proc_func=None,
                             updated_cluster_spec=None,
                             args=None,
                             kwargs=None):
        """Starts a single process.

    This starts a process in the cluster with the task type, task id, and the
    process function (`proc_func`). If process function is `None`, the function
    provided at `__init__` will be used. If `updated_cluster_spec` is not
    `None`, the cluster spec used by this subprocess will be updated.

    TODO(rchao): It is meant that all subprocesses will be updated with the new
    cluster spec, but this has yet to be implemented. At this time only the
    newly started subprocess picks up this updated cluster spec.

    Args:
      task_type: The task type.
      task_id: The task id.
      proc_func: The process function to be run on the newly started
        process. If `None`, the function provided at `__init__` will be used.
      updated_cluster_spec: If not `None`, the cluster spec used by this
        subprocess will be updated.
      args: Optional positional arguments to be supplied in `proc_func`.
      kwargs: Optional keyword arguments to be supplied in `proc_func`.
    """
        self._cluster_spec = updated_cluster_spec or self._cluster_spec
        proc_func = proc_func or self._proc_func
        p = multi_process_lib.Process(
            target=self._proc_func_wrapper,
            args=(proc_func, task_type, task_id, self._cluster_spec) + (args or
                                                                        ()),
            kwargs=(kwargs or {}))
        p.start()
        self._outstanding_subprocess_count += 1
Example #7
    def _start_subprocess_and_reading_thread(self, proc_func, task_type,
                                             task_id, args, kwargs):
        """Start a subprocess and a thread the reads lines from the subprocess."""
        global _next_pipe_index
        pipe_r, pipe_w = multi_process_lib.get_user_data(
        )[STREAMING_PIPE][_next_pipe_index]
        _next_pipe_index += 1

        p = multi_process_lib.Process(
            target=self._proc_func_wrapper,
            args=(proc_func, task_type, task_id, self._cluster_spec,
                  self._rpc_layer, pipe_w) + args,
            kwargs=kwargs)
        p.start()
        self._outstanding_subprocess_count += 1

        # For each subprocess, we dedicate a thread that continuously reads
        # lines from it.
        thread = threading.Thread(  # pylint: disable=unexpected-keyword-arg
            target=self._continuously_readline_from_sub,
            args=(pipe_r, task_type, task_id),
            daemon=True)
        thread.start()
        _threads.append(thread)
Example #8
    def run(self,
            proc_func,
            cluster_spec,
            proc_flags=None,
            timeout=200,
            time_to_exit=None,
            return_std_stream=False,
            args=None,
            kwargs=None):
        """Run functions on local sub-processes.

    Experimental. API subject to change. To fully inspect logging from
    subprocesses, use `--test_arg=--logtostderr` flag with bazel test.

    Args:
      proc_func: Function to be run on the processes. This will be run on
        processes for all task types.
      cluster_spec: Dict for cluster spec. The following is an example of
        cluster with three workers and two ps's.
        {"worker": ["worker0.example.com:2222",
                    "worker1.example.com:2222",
                    "worker2.example.com:2222"],
         "ps": ["ps0.example.com:2222",
                "ps1.example.com:2222"]}
      proc_flags: Dict that contains the key/values of the flags used on the
        processes.
      timeout: Time out in seconds. If the sub-process takes more than this time
        to complete, raise an error.
      time_to_exit: If set, sub-processes is forced to exit at approximately
        this many seconds after `run()` is called, through `signal.alarm()` api.
        This is for simulation of interruption on a process so in such cases no
        error is raised. Note that this is best effort at Python level since
        Python signal handler does not get executed inside the low-level (C)
        signal handler, so it can be delayed.
      return_std_stream: Boolean, whether the messages streamed to stdout and
        stderr in subprocesses are captured. If True, the messages are stored in
        a list returned as the second element.
      args: Positional arguments to be sent to functions run on processes.
      kwargs: Keyword arguments to be sent to functions run on processes.

    Returns:
      If `return_std_stream` is False, a list that stores the return data added
      by subprocesses through `multi_process_runner._add_return_data(data)`
      call,
      or through normal function return; if `return_std_stream` is True, a
      two-element tuple of `(return_data_list, std_stream_data_list)`, where
      `return_data_list` stores the return data added by processes through
      `multi_process_runner._add_return_data(data)` call or through normal
      function
      return, and `std_stream_data_list` stores the messages streamed to stdout
      and stderr in the subprocesses.

    Raises:
      RuntimeError: If any of the subprocesses raise an error, or if any of the
        subprocesses does not return or error out within `timeout` seconds.
    """

        assert cluster_spec is not None
        assert callable(proc_func)
        processes = []
        args = args or ()
        kwargs = kwargs or {}

        def wrapper_func(tf_config_as_json, proc_func, proc_flags,
                         time_to_exit, executing_eagerly, *arg, **kwargs):
            """The wrapper function that actually gets run on the process(es)."""
            @contextlib.contextmanager
            def runtime_mode(executing_eagerly):
                if executing_eagerly:
                    with context.eager_mode():
                        yield
                else:
                    with context.graph_mode():
                        yield

            with runtime_mode(executing_eagerly):
                os.environ['TF_CONFIG'] = tf_config_as_json
                if proc_flags is not None:
                    for flag_key, flag_value in proc_flags.items():
                        setattr(flags.FLAGS, flag_key, flag_value)

                stdout_collector = _LogCollector(
                    sys.__stdout__) if return_std_stream else None
                stderr_collector = _LogCollector(
                    sys.__stderr__) if return_std_stream else None

                def finish_wrapper_func_properly(func_result):
                    """Call to finish `wrapper_func` properly."""
                    # Clear the alarm.
                    signal.alarm(0)
                    if (return_std_stream and stdout_collector is not None
                            and stderr_collector is not None):
                        # If stdout and stderr are to be collected, add them to std stream
                        # queue.
                        self._add_std_stream_data_flattened(
                            stdout_collector.log)
                        self._add_std_stream_data_flattened(
                            stderr_collector.log)
                        # Un-redirect stdout and stderr.
                        sys.stdout = sys.__stdout__
                        sys.stderr = sys.__stderr__
                    self._get_internal_queue().put(func_result)

                if time_to_exit is not None:

                    def handler(signum, frame):
                        del signum, frame
                        finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE)
                        # pylint: disable=protected-access
                        os._exit(0)

                    signal.signal(signal.SIGALRM, handler)
                    signal.alarm(time_to_exit)

                if return_std_stream:
                    sys.stdout = stdout_collector
                    sys.stderr = stderr_collector

                try:
                    return_data = proc_func(*arg, **kwargs)
                    if return_data is not None:
                        self._add_return_data(return_data)
                # pylint: disable=broad-except
                except Exception:
                    # Capture all exceptions to be reported to parent process.
                    finish_wrapper_func_properly(
                        _ExcInfoWrapper(sys.exc_info()))

                    # Re-raise the exception in addition to reporting it to the parent
                    # process, so that even if `--test_timeout` flag is set and the
                    # error doesn't make it to be shown in parent process before bazel's
                    # timeout, the log would still show what happens in this subprocess,
                    # instead of silently suppressing the error due to early bazel
                    # timeout. Raising an error in the subprocess produces a
                    # stack trace in the log, but the program continues running.
                    raise

                finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE)

        # Start one process for each task in `cluster_spec`.
        for job_type, addresses in cluster_spec.items():
            for task_id, _ in enumerate(addresses):
                tf_config_as_json = json.dumps({
                    'cluster': cluster_spec,
                    'task': {
                        'type': job_type,
                        'index': task_id
                    }
                })
                p = multi_process_lib.Process(
                    target=wrapper_func,
                    args=(tf_config_as_json, proc_func, proc_flags,
                          time_to_exit, context.executing_eagerly()) + args,
                    kwargs=kwargs)
                p.start()
                processes.append(p)

        internal_queue_results = []
        for _ in range(len(processes)):
            try:
                internal_queue_results.append(
                    self._get_internal_queue().get(timeout=timeout))
            except Queue.Empty:
                # First check if any of the subprocesses raised exception.
                for internal_queue_result in internal_queue_results:
                    if isinstance(internal_queue_result, _ExcInfoWrapper):
                        six.reraise(*internal_queue_result.exc_info)
                # If none of those did, report time out to user.
                raise RuntimeError(
                    'One or more subprocesses timed out. Please use '
                    '`--test_arg=--logtostderr` bazel flag to inspect logs for '
                    'subprocess debugging info. Timeout = {} sec.'.format(
                        timeout))

        for internal_queue_result in internal_queue_results:
            if isinstance(internal_queue_result, _ExcInfoWrapper):
                six.reraise(*internal_queue_result.exc_info)
            assert internal_queue_result == _FINISH_PROPERLY_MESSAGE

        def queue_to_list(queue_to_convert):
            """Convert `queue.Queue` to `list`."""
            list_to_return = []
            while True:
                try:
                    list_to_return.append(queue_to_convert.get(block=False))
                except Queue.Empty:
                    break
            return list_to_return

        if return_std_stream:
            return tuple(
                queue_to_list(multi_process_lib.get_user_data()[queue_name])
                for queue_name in [
                    _AvailableQueues.PUBLIC_QUEUE,
                    _AvailableQueues.STD_STREAM_QUEUE
                ])
        else:
            return queue_to_list(multi_process_lib.get_user_data()[
                _AvailableQueues.PUBLIC_QUEUE])
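For reference, this is how the `TF_CONFIG` payload that each subprocess receives is assembled from the cluster spec and the task assignment; the sketch below simply prints the JSON instead of exporting it, since the real runner sets the environment variable inside each subprocess rather than in the parent.

import json

cluster_spec = {
    'worker': ['worker0.example.com:2222', 'worker1.example.com:2222'],
    'ps': ['ps0.example.com:2222'],
}

for task_type, addresses in cluster_spec.items():
  for task_id, _ in enumerate(addresses):
    tf_config = {'cluster': cluster_spec,
                 'task': {'type': task_type, 'index': task_id}}
    # In the runner, the subprocess does os.environ['TF_CONFIG'] = json.dumps(...)
    print(json.dumps(tf_config))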
Example #9
def run(proc_func,
        count_dict,
        proc_flags=None,
        timeout=200,
        time_to_exit=None,
        return_std_stream=False,
        args=None,
        kwargs=None):
  """Run functions on local sub-processes.

  Args:
    proc_func: Function to be run on the processes. This will be run on
      processes for all task types.
    count_dict: Dict for task_type/count of such task type.
    proc_flags: Dict that contains the key/values of the flags used on the
      processes.
    timeout: Timeout in seconds. If a sub-process takes more than this time
      to complete, raise an error.
    time_to_exit: If set, sub-processes are forced to exit at approximately
      this many seconds after `run()` is called, through the `signal.alarm()`
      API. This simulates an interruption of a process, so in such cases no
      error is raised. Note that this is best effort at the Python level,
      since the Python signal handler does not get executed inside the
      low-level (C) signal handler, so it can be delayed.
    return_std_stream: Boolean, whether the messages streamed to stdout and
      stderr in subprocesses are captured. If True, the messages are stored
      in a list returned as the second element.
    args: Positional arguments to be sent to functions run on processes.
    kwargs: Keyword arguments to be sent to functions run on processes.

  Returns:
    If `return_std_stream` is False, a list that stores the return data added
    by processes through `multi_process_runner.add_return_data(data)` call;
    if `return_std_stream` is True, a two-element tuple of
    `(return_data_list, std_stream_data_list)`, where `return_data_list` stores
    the return data added by processes through
    `multi_process_runner.add_return_data(data)` call, and
    `std_stream_data_list` stores the messages streamed to stdout and stderr
    in the subprocesses.

  Raises:
    RuntimeError: If any of the subprocesses raises an error, or if any of the
      subprocesses does not return or error out within `timeout` seconds.

  TODO(rchao): Open source this with a solution to handle multi_process_lib.
  """

  assert callable(proc_func)
  processes = []
  cluster_spec = {}
  args = args or ()
  kwargs = kwargs or {}

  for task_type, count in count_dict.items():
    cluster_spec[task_type] = [
        'localhost:{}'.format(multi_worker_test_base.pick_unused_port())
        for _ in range(count)
    ]

  def wrapper_func(tf_config_as_json, proc_func, proc_flags, time_to_exit,
                   *arg, **kwargs):
    """The wrapper function that actually gets run on the process(es)."""

    os.environ['TF_CONFIG'] = tf_config_as_json
    if proc_flags is not None:
      for flag_key, flag_value in proc_flags.items():
        setattr(flags.FLAGS, flag_key, flag_value)

    stdout_collector = _LogCollector(
        sys.__stdout__) if return_std_stream else None
    stderr_collector = _LogCollector(
        sys.__stderr__) if return_std_stream else None

    def finish_wrapper_func_properly(finish_message=_FINISH_PROPERLY_MESSAGE):
      """Call to finish `wrapper_func` properly."""
      # Clear the alarm.
      signal.alarm(0)
      if (return_std_stream and stdout_collector is not None and
          stderr_collector is not None):
        # If stdout and stderr are to be collected, add them to std stream
        # queue.
        _add_std_stream_data_flattened(stdout_collector.log)
        _add_std_stream_data_flattened(stderr_collector.log)
        # Un-redirect stdout and stderr.
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__
      _get_internal_queue().put(finish_message)

    if time_to_exit is not None:

      def handler(signum, frame):
        del signum, frame
        finish_wrapper_func_properly()
        # pylint: disable=protected-access
        os._exit(0)

      signal.signal(signal.SIGALRM, handler)
      signal.alarm(time_to_exit)

    if return_std_stream:
      sys.stdout = stdout_collector
      sys.stderr = stderr_collector

    try:
      proc_func(*arg, **kwargs)
    # pylint: disable=broad-except
    except Exception as e:
      # Capture all exceptions to be reported to parent process.
      finish_wrapper_func_properly(
          'Exception raised by subprocess: {}: {} {}'.format(
              e.__class__.__name__, str(e), traceback.format_exc()))
      return

    finish_wrapper_func_properly()

  # Start the number of processes specified in `count_dict`.
  for task_type, count in count_dict.items():
    for task_id in range(count):
      tf_config_as_json = json.dumps({
          'cluster': cluster_spec,
          'task': {
              'type': task_type,
              'index': task_id
          }
      })
      p = multi_process_lib.Process(
          target=wrapper_func,
          args=(tf_config_as_json, proc_func, proc_flags, time_to_exit) + args,
          kwargs=kwargs)
      p.start()
      processes.append(p)

  internal_queue_results = []
  for _ in range(len(processes)):
    try:
      internal_queue_results.append(
          _get_internal_queue().get(timeout=timeout))
    except Queue.Empty:
      raise RuntimeError(
          'One or more subprocesses timed out. Please inspect logs for '
          'subprocess debugging info. Timeout = {} sec.'.format(timeout))

  for internal_queue_result in internal_queue_results:
    if internal_queue_result.startswith('Exception raised by subprocess'):
      raise RuntimeError(internal_queue_result)
    assert internal_queue_result == _FINISH_PROPERLY_MESSAGE

  def queue_to_list(queue_to_convert):
    """Convert `queue.Queue` to `list`."""
    list_to_return = []
    while True:
      try:
        list_to_return.append(queue_to_convert.get(block=False))
      except Queue.Empty:
        break
    return list_to_return

  if return_std_stream:
    return tuple(
        queue_to_list(multi_process_lib.get_user_data()[queue_name])
        for queue_name in
        [AvailableQueues.PUBLIC_QUEUE, AvailableQueues.STD_STREAM_QUEUE])
  else:
    return queue_to_list(
        multi_process_lib.get_user_data()[AvailableQueues.PUBLIC_QUEUE])
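The `time_to_exit` mechanism used by both `run` variants above boils down to `signal.alarm`: SIGALRM is scheduled N seconds out, and the handler reports a clean finish and force-exits with `os._exit(0)`, so the simulated interruption raises no error. A standalone, Unix-only sketch of that mechanism:

import os
import signal
import time

def _handler(signum, frame):
  del signum, frame
  print('simulated interruption; exiting cleanly')
  os._exit(0)  # pylint: disable=protected-access

if __name__ == '__main__':
  signal.signal(signal.SIGALRM, _handler)
  signal.alarm(2)     # deliver SIGALRM roughly 2 seconds from now
  while True:         # stand-in for a long-running proc_func
    time.sleep(0.1)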