Example #1
0
    def test_tail_logfile_never_generates(self):
        """
        Verify that stopping the tail shuts down its threadpool
        even though the logfile being tailed never generates.
        """

        # "foobar.log" is the logfile this test expects to never show up
        log_tail = TailLog("writer", log_files={0: "foobar.log"}, dst=sys.stdout)
        log_tail = log_tail.start()
        log_tail.stop()

        self.assertTrue(log_tail.stopped())
        self.assertTrue(log_tail._threadpool._shutdown)
Example #2
0
    def test_tail_logfile_error_in_tail_fn(self, mock_logger):
        """
        Ensures that an error inside the tail_fn (the function that runs in
        the threadpool) is dealt with: the (mocked) logger must receive
        exactly one ``error`` call.
        """

        # hand the tail a directory instead of a file
        # (expected to fail with an IsADirectoryError)
        bad_log_files = {0: self.test_dir}
        log_tail = TailLog("writer", log_files=bad_log_files, dst=sys.stdout)
        log_tail = log_tail.start()
        log_tail.stop()

        mock_logger.error.assert_called_once()
Example #3
0
    def test_tail(self):
        """
        writer() writes 0 - max (one number on each line) to a log file.
        Run nprocs such writers and tail the log files into a StringIO
        and validate that all lines are accounted for.
        """
        nprocs = 32
        # named max_num (not ``max``) so the builtin is not shadowed
        max_num = 1000
        interval_sec = 0.0001

        log_files = {
            local_rank: os.path.join(self.test_dir, f"{local_rank}_stdout.log")
            for local_rank in range(nprocs)
        }

        dst = io.StringIO()
        tail = TailLog("writer", log_files, dst, interval_sec).start()
        # sleep here is intentional to ensure that the log tail
        # can gracefully handle and wait for non-existent log files
        time.sleep(interval_sec * 10)

        # one writer per local rank; ``max`` here is the keyword expected by write()
        futs = [
            self.threadpool.submit(
                write, max=max_num, sleep=interval_sec * local_rank, file=file
            )
            for local_rank, file in log_files.items()
        ]

        wait(futs, return_when=ALL_COMPLETED)
        # the tail must still be running until it is stopped explicitly
        self.assertFalse(tail.stopped())
        tail.stop()

        dst.seek(0)
        # keyed by the header text that precedes ":" on each tailed line
        # (e.g. "[writer0]"), hence str keys rather than int
        actual: Dict[str, Set[int]] = {}

        for line in dst.readlines():
            header, num = line.split(":")
            actual.setdefault(header, set()).add(int(num))

        self.assertEqual(nprocs, len(actual))
        self.assertEqual(
            {f"[writer{i}]": set(range(max_num)) for i in range(nprocs)}, actual
        )
        self.assertTrue(tail.stopped())
Example #4
0
class PContext(abc.ABC):
    """
    The base class that standardizes operations over a set of processes
    that are launched via different mechanisms. The name ``PContext``
    is intentional to disambiguate with ``torch.multiprocessing.ProcessContext``.

    .. warning:: stdouts and stderrs should ALWAYS be a superset of
                 tee_stdouts and tee_stderrs (respectively) this is b/c
                 tee is implemented as a redirect + tail -f <stdout/stderr.log>
    """

    def __init__(
        self,
        name: str,
        entrypoint: Union[Callable, str],
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        stdouts: Dict[int, str],
        stderrs: Dict[int, str],
        tee_stdouts: Dict[int, str],
        tee_stderrs: Dict[int, str],
        error_files: Dict[int, str],
    ):
        """
        Store the launch parameters and create (but do not start) the log tails.

        Args:
            name: name of this context; also passed to both ``TailLog`` objects.
            entrypoint: stored as-is for use by the concrete context.
            args: per-local-rank arguments; ``len(args)`` defines ``nprocs``.
            envs: per-local-rank environment mappings, stored as-is.
            stdouts: per-local-rank stdout targets, validated against ``nprocs``.
            stderrs: per-local-rank stderr targets, validated against ``nprocs``.
            tee_stdouts: log files tailed to ``sys.stdout`` once started.
            tee_stderrs: log files tailed to ``sys.stderr`` once started.
            error_files: per-local-rank error files, stored as-is.
        """
        self.name = name
        # the number of processes is derived from ``args``; stdouts and stderrs
        # are validated against it so that all local ranks are accounted for
        nprocs = len(args)
        _validate_full_rank(stdouts, nprocs, "stdouts")
        _validate_full_rank(stderrs, nprocs, "stderrs")

        self.entrypoint = entrypoint
        self.args = args
        self.envs = envs
        self.stdouts = stdouts
        self.stderrs = stderrs
        self.error_files = error_files
        self.nprocs = nprocs

        self._stdout_tail = TailLog(name, tee_stdouts, sys.stdout)
        self._stderr_tail = TailLog(name, tee_stderrs, sys.stderr)

    def start(self) -> None:
        """
        Start processes using parameters defined in the constructor,
        then start tailing the tee'd stdout and stderr logs.
        """
        self._start()
        self._stdout_tail.start()
        self._stderr_tail.start()

    @abc.abstractmethod
    def _start(self) -> None:
        """
        Start processes using strategy defined in a particular context.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _poll(self) -> Optional[RunProcsResult]:
        """
        Polls the run status of the processes running under this context.
        This method follows an "all-or-nothing" policy and returns
        a ``RunProcsResult`` object if either all processes complete
        successfully or any process fails. Returns ``None`` if
        all processes are still running.
        """
        raise NotImplementedError()

    def wait(self, timeout: float = -1, period: float = 1) -> Optional[RunProcsResult]:
        """
        Waits for the specified ``timeout`` seconds, polling every ``period`` seconds
        for the processes to be done. Returns ``None`` if the processes are still running
        on timeout expiry. Negative timeout values are interpreted as "wait-forever".
        A timeout value of zero simply queries the status of the processes (e.g. equivalent
        to a poll).

        Args:
            timeout: seconds to wait; ``0`` polls once, negative waits forever.
            period: seconds to sleep between two consecutive polls.

        Returns:
            The result of ``_poll()`` as soon as it is truthy, otherwise ``None``.
        """

        if timeout == 0:
            return self._poll()

        if timeout < 0:
            # effectively "no deadline"
            timeout = sys.maxsize

        expiry = time.time() + timeout
        while time.time() < expiry:
            pr = self._poll()
            if pr:
                return pr
            time.sleep(period)

        return None

    @abc.abstractmethod
    def pids(self) -> Dict[int, int]:
        """
        Returns pids of processes mapped by their respective local_ranks
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _close(self) -> None:
        r"""
        Terminates all processes managed by this context and cleans up any
        meta resources (e.g. redirect, error_file files).
        """
        raise NotImplementedError()

    def close(self) -> None:
        r"""
        Terminates all processes managed by this context (via ``_close()``)
        and stops the stdout/stderr log tails.

        The tails are stopped even when ``_close()`` raises, so that they
        are not left running; the exception is still propagated.
        """
        try:
            self._close()
        finally:
            if self._stdout_tail:
                self._stdout_tail.stop()
            if self._stderr_tail:
                self._stderr_tail.stop()
Example #5
0
class PContext(abc.ABC):
    """
    The base class that standardizes operations over a set of processes
    that are launched via different mechanisms. The name ``PContext``
    is intentional to disambiguate with ``torch.multiprocessing.ProcessContext``.

    .. warning:: stdouts and stderrs should ALWAYS be a superset of
                 tee_stdouts and tee_stderrs (respectively) this is b/c
                 tee is implemented as a redirect + tail -f <stdout/stderr.log>
    """

    def __init__(
        self,
        name: str,
        entrypoint: Union[Callable, str],
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        stdouts: Dict[int, str],
        stderrs: Dict[int, str],
        tee_stdouts: Dict[int, str],
        tee_stderrs: Dict[int, str],
        error_files: Dict[int, str],
    ):
        """
        Store the launch parameters and create (but do not start) the log tails.

        Args:
            name: name of this context; also passed to both ``TailLog`` objects.
            entrypoint: stored as-is for use by the concrete context.
            args: per-local-rank arguments; ``len(args)`` defines ``nprocs``.
            envs: per-local-rank environment mappings, stored as-is.
            stdouts: per-local-rank stdout targets, validated against ``nprocs``.
            stderrs: per-local-rank stderr targets, validated against ``nprocs``.
            tee_stdouts: log files tailed to ``sys.stdout`` once started.
            tee_stderrs: log files tailed to ``sys.stderr`` once started.
            error_files: per-local-rank error files, stored as-is.
        """
        self.name = name
        # the number of processes is derived from ``args``; stdouts and stderrs
        # are validated against it so that all local ranks are accounted for
        nprocs = len(args)
        _validate_full_rank(stdouts, nprocs, "stdouts")
        _validate_full_rank(stderrs, nprocs, "stderrs")

        self.entrypoint = entrypoint
        self.args = args
        self.envs = envs
        self.stdouts = stdouts
        self.stderrs = stderrs
        self.error_files = error_files
        self.nprocs = nprocs

        self._stdout_tail = TailLog(name, tee_stdouts, sys.stdout)
        self._stderr_tail = TailLog(name, tee_stderrs, sys.stderr)

    def start(self) -> None:
        """
        Start processes using parameters defined in the constructor.

        Termination signal handlers (SIGTERM and SIGINT, plus SIGHUP and
        SIGQUIT when not on Windows) are registered first. Python only allows
        ``signal.signal`` to be called from the main thread, so when this
        method runs on any other thread the registration is skipped and a
        ``RuntimeWarning`` is emitted instead of failing with ``ValueError``.
        """
        # function-scope imports keep this block self-contained
        import threading
        import warnings

        if threading.current_thread() is threading.main_thread():
            signal.signal(signal.SIGTERM, _terminate_process_handler)
            signal.signal(signal.SIGINT, _terminate_process_handler)
            if not IS_WINDOWS:
                signal.signal(signal.SIGHUP, _terminate_process_handler)
                signal.signal(signal.SIGQUIT, _terminate_process_handler)
        else:
            warnings.warn(
                "PContext.start() was called from a non-main thread; "
                "termination signal handlers were not registered",
                RuntimeWarning,
                stacklevel=2,
            )
        self._start()
        self._stdout_tail.start()
        self._stderr_tail.start()

    @abc.abstractmethod
    def _start(self) -> None:
        """
        Start processes using strategy defined in a particular context.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _poll(self) -> Optional[RunProcsResult]:
        """
        Polls the run status of the processes running under this context.
        This method follows an "all-or-nothing" policy and returns
        a ``RunProcsResult`` object if either all processes complete
        successfully or any process fails. Returns ``None`` if
        all processes are still running.
        """
        raise NotImplementedError()

    def wait(self, timeout: float = -1, period: float = 1) -> Optional[RunProcsResult]:
        """
        Waits for the specified ``timeout`` seconds, polling every ``period`` seconds
        for the processes to be done. Returns ``None`` if the processes are still running
        on timeout expiry. Negative timeout values are interpreted as "wait-forever".
        A timeout value of zero simply queries the status of the processes (e.g. equivalent
        to a poll).

        .. note:: The multiprocessing library registers SIGTERM and SIGINT signal handlers that
                  raise ``SignalException`` when the signals are received. It is up to the consumer
                  of the code to properly handle the exception. It is important not to swallow the
                  exception, otherwise the process would not terminate. An example of the typical
                  workflow:

        .. code-block:: python

            pc = start_processes(...)
            try:
                pc.wait(1)
                .. do some other work
            except SignalException as e:
                pc.close(e.sigval, timeout=30)

        If SIGTERM or SIGINT occurs, the code above will try to shut down the child processes by
        propagating the received signal. If the child processes do not terminate within the timeout,
        the process will send SIGKILL.

        Args:
            timeout: seconds to wait; ``0`` polls once, negative waits forever.
            period: seconds to sleep between two consecutive polls.

        Returns:
            The result of ``_poll()`` as soon as it is truthy, otherwise ``None``.
        """

        if timeout == 0:
            return self._poll()

        if timeout < 0:
            # effectively "no deadline"
            timeout = sys.maxsize

        expiry = time.time() + timeout
        while time.time() < expiry:
            pr = self._poll()
            if pr:
                return pr
            time.sleep(period)

        return None

    @abc.abstractmethod
    def pids(self) -> Dict[int, int]:
        """
        Returns pids of processes mapped by their respective local_ranks
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
        r"""
        Terminates all processes managed by this context and cleans up any
        meta resources (e.g. redirect, error_file files).

        Args:
            death_sig: Death signal to terminate processes.
            timeout: Time to wait for processes to finish.
        """
        raise NotImplementedError()

    def close(
        self, death_sig: Optional[signal.Signals] = None, timeout: int = 30
    ) -> None:
        r"""
        Terminates all processes managed by this context and cleans up any
        meta resources (e.g. redirect, error_file files).

        The stdout/stderr log tails are stopped even when ``_close()`` raises,
        so that they are not left running; the exception is still propagated.

        Args:
            death_sig: Death signal to terminate processes. When ``None``,
                the value of ``_get_default_signal()`` is used.
            timeout: Time to wait for processes to finish, if process is
                still alive after this time, it will be terminated via SIGKILL.
        """
        # compare against None explicitly: a signal whose integer value is 0
        # (e.g. ``signal.CTRL_C_EVENT`` on Windows) is falsy but still valid
        if death_sig is None:
            death_sig = _get_default_signal()
        try:
            self._close(death_sig=death_sig, timeout=timeout)
        finally:
            if self._stdout_tail:
                self._stdout_tail.stop()
            if self._stderr_tail:
                self._stderr_tail.stop()