def join(self, timeout=None):
    """Joins all the processes with timeout.

    Args:
      timeout: if set and not all processes report status within roughly
        `timeout` seconds, a `RuntimeError` exception is raised.

    Returns:
      A two-element tuple. The first element is a list of the return data added
      by subprocesses through `_add_return_data` or through normal function
      returns; the second element is a list of the messages streamed to stdout
      and stderr in the subprocesses if `capture_std_stream` is True, or `None`
      otherwise.

    Raises:
      RuntimeError: if not all processes report status within `timeout`
        seconds, or the exception propagated from any child process.
    """
    if not timeout:
      if self._max_run_time:
        timeout = self._max_run_time + 10  # add 10 seconds grace period
      else:
        timeout = float('inf')
    start_time = time.time()
    while self._outstanding_subprocess_count > 0:
      while True:
        try:
          process_status = self._get_process_status_queue().get(timeout=10)
          break
        except Queue.Empty:
          if time.time() - start_time > timeout:
            # Report timeout to the user.
            raise RuntimeError(
                'One or more subprocesses timed out. Please use '
                '`--test_arg=--logtostderr` bazel flag to inspect logs for '
                'subprocess debugging info. Number of outstanding subprocesses '
                'is %d.' % self._outstanding_subprocess_count)

      self._outstanding_subprocess_count -= 1
      assert isinstance(process_status, _ProcessStatusInfo)
      if not process_status.is_successful:
        six.reraise(*process_status.exc_info)

    if self._capture_std_stream:
      # TODO(yuefengz): we need to make sure elements match the same process in
      # the two returned lists so as to not surprise users. Consider creating a
      # `ReturnData` class.
      return tuple(
          self._queue_to_list(multi_process_lib.get_user_data()[queue_name])
          for queue_name in [RETURN_VALUE_QUEUE, STD_STREAM_QUEUE])
    else:
      return (self._queue_to_list(
          multi_process_lib.get_user_data()[RETURN_VALUE_QUEUE]), None)
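# Hypothetical usage sketch (not part of the original source): how a caller
# might consume the tuple returned by this `join()` variant. The
# `example_usage_of_join` helper and the `runner` argument are assumptions
# made only for illustration.
def example_usage_of_join(runner):
  return_data, std_stream_data = runner.join(timeout=300)
  for value in return_data:
    print('subprocess returned:', value)
  if std_stream_data is not None:
    for line in std_stream_data:
      print('captured output:', line)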
Example #2
    def _add_return_data(self, data):
        """Adds return data that will be returned by `join`.

    The function provides a way for child processes to communicate with the
    parent process. Data passed to `_add_return_data` will be available in a
    Python Queue.Queue that is eventually returned by `join`.

    Args:
      data: data to be made available in the queue returned by `join`.
    """
        # TODO(rchao): Incorporate the task type and id information in a data
        # wrapper that becomes what is stored in the queue so we can tell where
        # the data is from.
        multi_process_lib.get_user_data()[RETURN_VALUE_QUEUE].put(data)
Example #3
def _add_std_stream_data_flattened(self, data):
    # TODO(yuefengz): currently the same queue is used by multiple processes. It
    # is difficult for users to distinguish between logs from different
    # processes.
    std_stream_queue = multi_process_lib.get_user_data()[STD_STREAM_QUEUE]
    for d in list(data):
        std_stream_queue.put(d)
def add_return_data(data):
  """Add return data that will be returned by `multi_process_runner.run()`.

  The function provides a way for processes started by
  `multi_process_runner.run()` to communicate with the original process
  that started the sub-processes. Data passed to `add_return_data` will
  be available in a Python Queue.Queue that is eventually returned by
  `multi_process_runner.run()`.

  Args:
    data: data to be made available in the queue returned by
      `multi_process_runner.run()`.
  """
  # TODO(rchao): Incorporate the task type and id information in a data
  # wrapper that becomes what is stored in the queue so we can tell where
  # the data is from.
  multi_process_lib.get_user_data()[_AvailableQueues.PUBLIC_QUEUE].put(data)
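# Hypothetical usage sketch (assumption, not in the original source): a
# `proc_func` handed to `multi_process_runner.run()` can call
# `add_return_data` so the parent process later reads the value from the
# queue that `run()` returns. `example_proc_func` is a made-up name.
def example_proc_func():
  add_return_data({'tf_config': os.environ.get('TF_CONFIG', ''), 'status': 'ok'})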
    def _start_subprocess_and_reading_thread(self, proc_func, task_type,
                                             task_id, args, kwargs):
        """Start a subprocess and a thread the reads lines from the subprocess."""
        global _next_pipe_index
        pipe_r, pipe_w = multi_process_lib.get_user_data(
        )[STREAMING_PIPE][_next_pipe_index]
        _next_pipe_index += 1

        p = multi_process_lib.Process(
            target=self._proc_func_wrapper,
            args=(proc_func, task_type, task_id, self._cluster_spec,
                  self._rpc_layer, pipe_w) + args,
            kwargs=kwargs)
        p.start()
        self._outstanding_subprocess_count += 1

        # For each subprocess, we dedicate a thread that continuously reads
        # lines from it.
        thread = threading.Thread(  # pylint: disable=unexpected-keyword-arg
            target=self._continuously_readline_from_sub,
            args=(pipe_r, task_type, task_id),
            daemon=True)
        thread.start()
        _threads.append(thread)
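# Hypothetical sketch (assumption, not the actual implementation) of what a
# reader thread target such as `_continuously_readline_from_sub` might do: it
# wraps the read end of the pipe in a file object, forwards each line to the
# streaming queue, and stops at the 'EOF' sentinel that `join()` writes.
def _example_reader_loop(pipe_r, task_type, task_id):
    reader = os.fdopen(pipe_r.fileno(), 'r')
    for line in reader:
        if line == 'EOF':
            break
        formatted_line = '[{}-{}]: {}'.format(task_type, task_id, line)
        multi_process_lib.get_user_data()[STREAMING_QUEUE].put(formatted_line)
    reader.close()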
Example #6
def _get_parent_to_sub_queue(self):
    return multi_process_lib.get_user_data()[PARENT_TO_SUB_QUEUE]
Example #7
def _get_inter_process_queue(self):
    return multi_process_lib.get_user_data()[INTER_PROCESS_QUEUE]
Example #8
def _get_process_status_queue(self):
    return multi_process_lib.get_user_data()[PROCESS_STATUS_QUEUE]
    def run(self,
            proc_func,
            cluster_spec,
            proc_flags=None,
            timeout=200,
            time_to_exit=None,
            return_std_stream=False,
            args=None,
            kwargs=None):
        """Run functions on local sub-processes.

    Experimental. API subject to change. To fully inspect logging from
    subprocesses, use `--test_arg=--logtostderr` flag with bazel test.

    Args:
      proc_func: Function to be run on the processes. This will be run on
        processes for all task types.
      cluster_spec: Dict for cluster spec. The following is an example of
        cluster with three workers and two ps's.
        {"worker": ["worker0.example.com:2222",
                    "worker1.example.com:2222",
                    "worker2.example.com:2222"],
         "ps": ["ps0.example.com:2222",
                "ps1.example.com:2222"]}
      proc_flags: Dict that contains the key/values of the flags used on the
        processes.
      timeout: Time out in seconds. If the sub-process takes more than this time
        to complete, raise an error.
      time_to_exit: If set, sub-processes is forced to exit at approximately
        this many seconds after `run()` is called, through `signal.alarm()` api.
        This is for simulation of interruption on a process so in such cases no
        error is raised. Note that this is best effort at Python level since
        Python signal handler does not get executed inside the low-level (C)
        signal handler, so it can be delayed.
      return_std_stream: Boolean, whether the messages streamed to stdout and
        stderr in subprocesses are captured. If True, the messages are stored in
        a list returned as the second element.
      args: Positional arguments to be sent to functions run on processes.
      kwargs: Keyword arguments to be sent to functions run on processes.

    Returns:
      If `return_std_stream` is False, a list that stores the return data added
      by subprocesses through `multi_process_runner._add_return_data(data)`
      call,
      or through normal function return; if `return_std_stream` is True, a
      two-element tuple of `(return_data_list, std_stream_data_list)`, where
      `return_data_list` stores the return data added by processes through
      `multi_process_runner._add_return_data(data)` call or through normal
      function
      return, and `std_stream_data_list` stores the messages streamed to stdout
      and stderr in the subprocesses.

    Raises:
      RuntimeError: If any of the subprocesses raise an error, or if any of the
        subprocesses does not return or error out within `timeout` seconds.
    """

        assert cluster_spec is not None
        assert callable(proc_func)
        processes = []
        args = args or ()
        kwargs = kwargs or {}

        def wrapper_func(tf_config_as_json, proc_func, proc_flags,
                         time_to_exit, executing_eagerly, *arg, **kwargs):
            """The wrapper function that actually gets run on the process(es)."""
            @contextlib.contextmanager
            def runtime_mode(executing_eagerly):
                if executing_eagerly:
                    with context.eager_mode():
                        yield
                else:
                    with context.graph_mode():
                        yield

            with runtime_mode(executing_eagerly):
                os.environ['TF_CONFIG'] = tf_config_as_json
                if proc_flags is not None:
                    for flag_key, flag_value in proc_flags.items():
                        setattr(flags.FLAGS, flag_key, flag_value)

                stdout_collector = _LogCollector(
                    sys.__stdout__) if return_std_stream else None
                stderr_collector = _LogCollector(
                    sys.__stderr__) if return_std_stream else None

                def finish_wrapper_func_properly(func_result):
                    """Call to finish `wrapper_func` properly."""
                    # Clear the alarm.
                    signal.alarm(0)
                    if (return_std_stream and stdout_collector is not None
                            and stderr_collector is not None):
                        # If stdout and stderr are to be collected, add them to std stream
                        # queue.
                        self._add_std_stream_data_flattened(
                            stdout_collector.log)
                        self._add_std_stream_data_flattened(
                            stderr_collector.log)
                        # Un-redirect stdout and stderr.
                        sys.stdout = sys.__stdout__
                        sys.stderr = sys.__stderr__
                    self._get_internal_queue().put(func_result)

                if time_to_exit is not None:

                    def handler(signum, frame):
                        del signum, frame
                        finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE)
                        # pylint: disable=protected-access
                        os._exit(0)

                    signal.signal(signal.SIGALRM, handler)
                    signal.alarm(time_to_exit)

                if return_std_stream:
                    sys.stdout = stdout_collector
                    sys.stderr = stderr_collector

                try:
                    return_data = proc_func(*arg, **kwargs)
                    if return_data is not None:
                        self._add_return_data(return_data)
                # pylint: disable=broad-except
                except Exception:
                    # Capture all exceptions to be reported to parent process.
                    finish_wrapper_func_properly(
                        _ExcInfoWrapper(sys.exc_info()))

                    # Re-raise the exception in addition to reporting it to the parent
                    # process, so that even if `--test_timeout` flag is set and the
                    # error doesn't make it to be shown in parent process before bazel's
                    # timeout, the log would still show what happens in this subprocess,
                    # instead of silently suppressing the error due to early bazel
                    # timeout. Raising an error in the subprocess produces stack trace in
                    # the log, but the program continues running.
                    raise

                finish_wrapper_func_properly(_FINISH_PROPERLY_MESSAGE)

        # Start a process for each task according to `cluster_spec`.
        for job_type, addresses in cluster_spec.items():
            for task_id, _ in enumerate(addresses):
                tf_config_as_json = json.dumps({
                    'cluster': cluster_spec,
                    'task': {
                        'type': job_type,
                        'index': task_id
                    }
                })
                p = multi_process_lib.Process(
                    target=wrapper_func,
                    args=(tf_config_as_json, proc_func, proc_flags,
                          time_to_exit, context.executing_eagerly()) + args,
                    kwargs=kwargs)
                p.start()
                processes.append(p)

        internal_queue_results = []
        for _ in range(len(processes)):
            try:
                internal_queue_results.append(
                    self._get_internal_queue().get(timeout=timeout))
            except Queue.Empty:
                # First check if any of the subprocesses raised exception.
                for internal_queue_result in internal_queue_results:
                    if isinstance(internal_queue_result, _ExcInfoWrapper):
                        six.reraise(*internal_queue_result.exc_info)
                # If none of those did, report timeout to the user.
                raise RuntimeError(
                    'One or more subprocesses timed out. Please use '
                    '`--test_arg=--logtostderr` bazel flag to inspect logs for '
                    'subprocess debugging info. Timeout = {} sec.'.format(
                        timeout))

        for internal_queue_result in internal_queue_results:
            if isinstance(internal_queue_result, _ExcInfoWrapper):
                six.reraise(*internal_queue_result.exc_info)
            assert internal_queue_result == _FINISH_PROPERLY_MESSAGE

        def queue_to_list(queue_to_convert):
            """Convert `queue.Queue` to `list`."""
            list_to_return = []
            while True:
                try:
                    list_to_return.append(queue_to_convert.get(block=False))
                except Queue.Empty:
                    break
            return list_to_return

        if return_std_stream:
            return tuple(
                queue_to_list(multi_process_lib.get_user_data()[queue_name])
                for queue_name in [
                    _AvailableQueues.PUBLIC_QUEUE,
                    _AvailableQueues.STD_STREAM_QUEUE
                ])
        else:
            return queue_to_list(multi_process_lib.get_user_data()[
                _AvailableQueues.PUBLIC_QUEUE])
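# Hypothetical usage sketch (not from the original source): invoking the
# `run()` method above with a small cluster spec of the shape shown in its
# docstring. `example_run_with_cluster`, the `runner` argument and the local
# addresses are assumptions made only for illustration.
def example_run_with_cluster(runner):
    def proc_func():
        print('TF_CONFIG for this task:', os.environ['TF_CONFIG'])

    cluster_spec = {
        'worker': ['localhost:12345', 'localhost:23456'],
        'ps': ['localhost:34567'],
    }
    return_data, std_streams = runner.run(
        proc_func, cluster_spec, timeout=200, return_std_stream=True)
    return return_data, std_streams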
def _add_std_stream_data_flattened(self, data):
    std_stream_queue = multi_process_lib.get_user_data()[
        _AvailableQueues.STD_STREAM_QUEUE]
    for d in list(data):
        std_stream_queue.put(d)
def _get_subprocess_info_queue(self):
    return multi_process_lib.get_user_data()[SUBPROCESS_INFO_QUEUE]
    def join(self, timeout=None):
        """Joins all the processes with timeout.

    Args:
      timeout: if set and not all processes report status within roughly
        `timeout` seconds, a `RuntimeError` exception will be thrown.

    Returns:
      A MultiProcessRunnerResult object, which has two attributes,
      `return_value` and `stdout`. `return_value` always contains the return
      values from the subprocesses. If `list_stdout` argument is True at
      `__init__`, `stdout` is available that contains a list of all messages
      from subprocesses' stdout and stderr.

    Raises:
      RuntimeError: if not all processes report status approximatelty within
      `timeout` seconds, or there's an exception propagated from any subprocess.
    """

        if not timeout:
            timeout = float('inf')
        start_time = time.time()
        while self._outstanding_subprocess_count > 0:
            while True:
                try:
                    process_status = self._get_process_status_queue().get(
                        timeout=10)
                    break
                except Queue.Empty:
                    if self._all_forced_terminated:
                        break
                    if time.time() - start_time > timeout:
                        # Report timeout to the user.
                        raise RuntimeError(
                            'One or more subprocesses timed out. '
                            'Number of outstanding subprocesses '
                            'is %d.' % self._outstanding_subprocess_count)

            if self._all_forced_terminated:
                break
            self._outstanding_subprocess_count -= 1
            assert isinstance(process_status, _ProcessStatusInfo)
            if not process_status.is_successful:
                six.reraise(*process_status.exc_info)

            if self._dependence_on_chief and process_status.task_type == 'chief':
                self.terminate_all()
                break

        # Give the threads some time to finish reading messages from the subprocesses.
        time.sleep(5)

        stdout = self._queue_to_list(
            multi_process_lib.get_user_data()[STREAMING_QUEUE])
        return_value = self._queue_to_list(
            multi_process_lib.get_user_data()[RETURN_VALUE_QUEUE])

        # Notify the threads that are reading lines that they should stop.
        for pipe_index in range(self._starting_pipe_index, _next_pipe_index):  # pylint: disable=protected-access
            _, pipe_w = multi_process_lib.get_user_data(
            )[STREAMING_PIPE][pipe_index]
            writer = os.fdopen(pipe_w.fileno(), 'w')
            # Write an end-of-file message so the threads that are actively
            # reading lines know to stop.
            writer.writelines(['EOF'])
            writer.close()

        for thread in _threads:
            thread.join(5)

        return MultiProcessRunnerResult(stdout=stdout,
                                        return_value=return_value)
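# Hypothetical usage sketch (assumption, not part of the original source):
# consuming the `MultiProcessRunnerResult` returned by this newer `join()`
# variant. `example_consume_result` and the `runner` argument are made up for
# illustration.
def example_consume_result(runner):
    result = runner.join(timeout=300)
    for value in result.return_value:
        print('subprocess return value:', value)
    # `stdout` holds the streamed lines when `list_stdout=True` was passed to
    # `__init__`, per the docstring above.
    for line in result.stdout:
        print('subprocess output:', line)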
def _add_stdout_in_queue(self, formatted_line, task_type, task_id):
    del task_type, task_id
    # A queue instead of a simple list is used here due to b/150652733.
    multi_process_lib.get_user_data()[STREAMING_QUEUE].put(formatted_line)
def barrier():
  return multi_process_lib.get_user_data()[BARRIER]
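# Hypothetical usage sketch (assumption: the object stored under `BARRIER`
# behaves like `multiprocessing.Barrier`): subprocesses can block at a common
# synchronization point before proceeding.
def example_wait_at_barrier():
  barrier().wait()  # returns once every participating subprocess has arrived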
def _get_internal_queue(self):
    return multi_process_lib.get_user_data()[
        _AvailableQueues.INTERNAL_QUEUE]
def _resource(resource_name):
  return multi_process_lib.get_user_data()[resource_name]
def run(proc_func,
        count_dict,
        proc_flags=None,
        timeout=200,
        time_to_exit=None,
        return_std_stream=False,
        args=None,
        kwargs=None):
  """Run functions on local sub-processes.

  Args:
    proc_func: Function to be run on the processes. This will be run on
      processes for all task types.
    count_dict: Dict for task_type/count of such task type.
    proc_flags: Dict that contains the key/values of the flags used on the
      processes.
    timeout: Timeout in seconds. If a sub-process takes more than this time
      to complete, an error is raised.
    time_to_exit: If set, sub-processes are forced to exit at approximately this
      many seconds after `run()` is called, via the `signal.alarm()` API. This
      simulates an interruption of a process, so no error is raised in such
      cases. Note that this is best effort at the Python level, since a Python
      signal handler does not get executed inside the low-level (C) signal
      handler, so it can be delayed.
    return_std_stream: Boolean, whether the messages streamed to stdout and
      stderr in subprocesses are captured. If True, the messages are stored
      in a list returned as the second element.
    args: Positional arguments to be sent to functions run on processes.
    kwargs: Keyword arguments to be sent to functions run on processes.

  Returns:
    If `return_std_stream` is False, a list that stores the return data added
    by processes through `multi_process_runner.add_return_data(data)` call;
    if `return_std_stream` is True, a two-element tuple of
    `(return_data_list, std_stream_data_list)`, where `return_data_list` stores
    the return data added by processes through
    `multi_process_runner.add_return_data(data)` call, and
    `std_stream_data_list` stores the messages streamed to stdout and stderr
    in the subprocesses.

  Raises:
    RuntimeError: If any of the subprocesses raise an error, or if any of the
      subprocesses does not return or error out within `timeout` seconds.

  TODO(rchao): Open source this with a solution to handle multi_process_lib.
  """

  assert callable(proc_func)
  processes = []
  cluster_spec = {}
  args = args or ()
  kwargs = kwargs or {}

  for task_type, count in count_dict.items():
    cluster_spec[task_type] = [
        'localhost:{}'.format(multi_worker_test_base.pick_unused_port())
        for _ in range(count)
    ]

  def wrapper_func(tf_config_as_json, proc_func, proc_flags, time_to_exit,
                   *arg, **kwargs):
    """The wrapper function that actually gets run on the process(es)."""

    os.environ['TF_CONFIG'] = tf_config_as_json
    if proc_flags is not None:
      for flag_key, flag_value in proc_flags.items():
        setattr(flags.FLAGS, flag_key, flag_value)

    stdout_collector = _LogCollector(
        sys.__stdout__) if return_std_stream else None
    stderr_collector = _LogCollector(
        sys.__stderr__) if return_std_stream else None

    def finish_wrapper_func_properly(finish_message=_FINISH_PROPERLY_MESSAGE):
      """Call to finish `wrapper_func` properly."""
      # Clear the alarm.
      signal.alarm(0)
      if (return_std_stream and stdout_collector is not None and
          stderr_collector is not None):
        # If stdout and stderr are to be collected, add them to std stream
        # queue.
        _add_std_stream_data_flattened(stdout_collector.log)
        _add_std_stream_data_flattened(stderr_collector.log)
        # Un-redirect stdout and stderr.
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__
      _get_internal_queue().put(finish_message)

    if time_to_exit is not None:

      def handler(signum, frame):
        del signum, frame
        finish_wrapper_func_properly()
        # pylint: disable=protected-access
        os._exit(0)

      signal.signal(signal.SIGALRM, handler)
      signal.alarm(time_to_exit)

    if return_std_stream:
      sys.stdout = stdout_collector
      sys.stderr = stderr_collector

    try:
      proc_func(*arg, **kwargs)
    # pylint: disable=broad-except
    except Exception as e:
      # Capture all exceptions to be reported to parent process.
      finish_wrapper_func_properly(
          'Exception raised by subprocess: {}: {} {}'.format(
              e.__class__.__name__, str(e), traceback.format_exc()))
      return

    finish_wrapper_func_properly()

  # Start the number of processes specified by `count_dict`.
  for task_type, count in count_dict.items():
    for task_id in range(count):
      tf_config_as_json = json.dumps({
          'cluster': cluster_spec,
          'task': {
              'type': task_type,
              'index': task_id
          }
      })
      p = multi_process_lib.Process(
          target=wrapper_func,
          args=(tf_config_as_json, proc_func, proc_flags, time_to_exit) + args,
          kwargs=kwargs)
      p.start()
      processes.append(p)

  internal_queue_results = []
  for _ in range(len(processes)):
    try:
      internal_queue_results.append(
          _get_internal_queue().get(timeout=timeout))
    except Queue.Empty:
      raise RuntimeError(
          'One or more subprocesses timed out. Please inspect logs for '
          'subprocess debugging info. Timeout = {} sec.'.format(timeout))

  for internal_queue_result in internal_queue_results:
    if internal_queue_result.startswith('Exception raised by subprocess'):
      raise RuntimeError(internal_queue_result)
    assert internal_queue_result == _FINISH_PROPERLY_MESSAGE

  def queue_to_list(queue_to_convert):
    """Convert `queue.Queue` to `list`."""
    list_to_return = []
    while True:
      try:
        list_to_return.append(queue_to_convert.get(block=False))
      except Queue.Empty:
        break
    return list_to_return

  if return_std_stream:
    return tuple(
        queue_to_list(multi_process_lib.get_user_data()[queue_name])
        for queue_name in
        [AvailableQueues.PUBLIC_QUEUE, AvailableQueues.STD_STREAM_QUEUE])
  else:
    return queue_to_list(
        multi_process_lib.get_user_data()[AvailableQueues.PUBLIC_QUEUE])
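# Hypothetical usage sketch (not from the original source): driving the
# module-level `run()` with a `count_dict`, mirroring its docstring.
# `example_module_level_run` and the chosen task counts are assumptions made
# only for illustration.
def example_module_level_run():
  def proc_func():
    add_return_data(os.environ['TF_CONFIG'])

  return_data, std_streams = run(
      proc_func,
      count_dict={'worker': 2, 'ps': 1},
      return_std_stream=True,
      timeout=200)
  return return_data, std_streams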