def test_tail_logfile_never_generates(self):
    """
    Verify that stopping a tail whose log file is never created still
    shuts the threadpool down.
    """
    # "foobar.log" is never written by this test, so the tail has
    # nothing to read before it is stopped.
    missing_logs = {0: "foobar.log"}
    log_tail = TailLog("writer", log_files=missing_logs, dst=sys.stdout).start()

    log_tail.stop()

    self.assertTrue(log_tail.stopped())
    self.assertTrue(log_tail._threadpool._shutdown)
def test_tail_logfile_error_in_tail_fn(self, mock_logger):
    """
    Verify that an error inside the tail function (the one that runs in
    the threadpool) is handled and reported through the logger exactly
    once.

    ``mock_logger`` is the patched logger handed to this test; the
    patching itself happens outside this method.
    """
    # Hand the tail a directory instead of a file; reading it should
    # fail (with an IsADirectoryError).
    directory_as_log = {0: self.test_dir}
    log_tail = TailLog("writer", log_files=directory_as_log, dst=sys.stdout).start()

    log_tail.stop()

    mock_logger.error.assert_called_once()
def test_tail(self):
    """
    Run ``nprocs`` writers, each writing the numbers in
    ``range(max_num)`` (one number per line) to its own log file, tail
    all the log files into a ``StringIO`` and validate that every line
    of every writer is accounted for.
    """
    nprocs = 32
    # Named ``max_num`` rather than ``max`` so the builtin is not shadowed;
    # ``write`` still receives it through its ``max`` keyword.
    max_num = 1000
    interval_sec = 0.0001

    log_files = {
        local_rank: os.path.join(self.test_dir, f"{local_rank}_stdout.log")
        for local_rank in range(nprocs)
    }

    dst = io.StringIO()
    tail = TailLog("writer", log_files, dst, interval_sec).start()
    # sleep here is intentional to ensure that the log tail
    # can gracefully handle and wait for non-existent log files
    time.sleep(interval_sec * 10)

    futs = [
        self.threadpool.submit(
            write, max=max_num, sleep=interval_sec * local_rank, file=file
        )
        for local_rank, file in log_files.items()
    ]
    wait(futs, return_when=ALL_COMPLETED)

    # The tail must keep running until it is explicitly stopped.
    self.assertFalse(tail.stopped())
    tail.stop()

    # Group the tailed numbers by the header that precedes the ":" on each
    # line. The keys are header strings such as "[writer0]".
    dst.seek(0)
    actual: Dict[str, Set[int]] = {}
    for line in dst.readlines():
        header, num = line.split(":")
        actual.setdefault(header, set()).add(int(num))

    self.assertEqual(nprocs, len(actual))
    self.assertEqual(
        {f"[writer{i}]": set(range(max_num)) for i in range(nprocs)}, actual
    )
    self.assertTrue(tail.stopped())
class PContext(abc.ABC):
    """
    Base class providing a uniform set of operations over a group of
    processes, regardless of the mechanism used to launch them.
    The name ``PContext`` is deliberate, to avoid confusion with
    ``torch.multiprocessing.ProcessContext``.

    .. warning:: stdouts and stderrs should ALWAYS be a superset of
                 tee_stdouts and tee_stderrs (respectively) because tee is
                 implemented as a redirect + tail -f <stdout/stderr.log>
    """

    def __init__(
        self,
        name: str,
        entrypoint: Union[Callable, str],
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        stdouts: Dict[int, str],
        stderrs: Dict[int, str],
        tee_stdouts: Dict[int, str],
        tee_stderrs: Dict[int, str],
        error_files: Dict[int, str],
    ):
        self.name = name

        # The number of processes is taken from ``args``; the redirect
        # mappings are validated against it (stdouts first, then stderrs).
        nprocs = len(args)
        for mapping, label in ((stdouts, "stdouts"), (stderrs, "stderrs")):
            _validate_full_rank(mapping, nprocs, label)

        self.nprocs = nprocs
        self.entrypoint = entrypoint
        self.args = args
        self.envs = envs
        self.error_files = error_files
        self.stdouts = stdouts
        self.stderrs = stderrs

        # Tails are only created here; they are started in ``start()``.
        self._stdout_tail = TailLog(name, tee_stdouts, sys.stdout)
        self._stderr_tail = TailLog(name, tee_stderrs, sys.stderr)

    def start(self) -> None:
        """
        Launch the processes described by the constructor arguments, then
        start tailing their stdout and stderr logs.
        """
        self._start()
        for log_tail in (self._stdout_tail, self._stderr_tail):
            log_tail.start()

    @abc.abstractmethod
    def _start(self) -> None:
        """
        Launch the processes using the strategy of the concrete context.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _poll(self) -> Optional[RunProcsResult]:
        """
        Check the run status of the processes managed by this context.
        The policy is "all-or-nothing": a ``RunProcsResult`` is returned
        once all processes have completed successfully or as soon as any
        process has failed. ``None`` is returned while all processes are
        still running.
        """
        raise NotImplementedError()

    def wait(self, timeout: float = -1, period: float = 1) -> Optional[RunProcsResult]:
        """
        Block for up to ``timeout`` seconds, polling every ``period``
        seconds until the processes are done. ``None`` is returned if the
        processes are still running when the timeout expires.

        A negative timeout means "wait-forever". A timeout of zero only
        queries the status of the processes (equivalent to a single poll).
        """
        if timeout == 0:
            return self._poll()

        # A negative timeout is mapped to a practically unreachable deadline.
        deadline = time.time() + (sys.maxsize if timeout < 0 else timeout)
        while time.time() < deadline:
            result = self._poll()
            if result:
                return result
            time.sleep(period)
        return None

    @abc.abstractmethod
    def pids(self) -> Dict[int, int]:
        """
        Return the pids of the processes, keyed by local rank.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _close(self) -> None:
        r"""
        Terminate every process managed by this context and clean up any
        meta resources (e.g. redirect, error_file files).
        """
        raise NotImplementedError()

    def close(self) -> None:
        """
        Terminate the processes via ``_close()`` and then stop the stdout
        and stderr log tails.
        """
        self._close()
        for log_tail in (self._stdout_tail, self._stderr_tail):
            if log_tail:
                log_tail.stop()
class PContext(abc.ABC):
    """
    The base class that standardizes operations over a set of processes
    that are launched via different mechanisms. The name ``PContext``
    is intentional to disambiguate with ``torch.multiprocessing.ProcessContext``.

    .. warning:: stdouts and stderrs should ALWAYS be a superset of
                 tee_stdouts and tee_stderrs (respectively) this is b/c
                 tee is implemented as a redirect + tail -f <stdout/stderr.log>
    """

    def __init__(
        self,
        name: str,
        entrypoint: Union[Callable, str],
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        stdouts: Dict[int, str],
        stderrs: Dict[int, str],
        tee_stdouts: Dict[int, str],
        tee_stderrs: Dict[int, str],
        error_files: Dict[int, str],
    ):
        self.name = name
        # validate that all mappings have the same number of keys and
        # all local ranks are accounted for
        nprocs = len(args)
        _validate_full_rank(stdouts, nprocs, "stdouts")
        _validate_full_rank(stderrs, nprocs, "stderrs")

        self.entrypoint = entrypoint
        self.args = args
        self.envs = envs
        self.stdouts = stdouts
        self.stderrs = stderrs
        self.error_files = error_files
        self.nprocs = nprocs

        # The tails are only constructed here; ``start()`` starts them.
        self._stdout_tail = TailLog(name, tee_stdouts, sys.stdout)
        self._stderr_tail = TailLog(name, tee_stderrs, sys.stderr)

    def start(self) -> None:
        """
        Start processes using parameters defined in the constructor.

        Before the processes are started, termination signal handlers are
        registered: SIGTERM and SIGINT everywhere, plus SIGHUP and SIGQUIT
        when not on Windows. Registration is skipped (with a logged
        warning) when this method is called from a non-main thread.
        """
        # Function-scope imports: these modules are needed only for the
        # main-thread guard below.
        import logging
        import threading

        # ``signal.signal`` raises ValueError when called outside the main
        # thread, so only register the handlers when it is safe to do so.
        if threading.current_thread() is threading.main_thread():
            signal.signal(signal.SIGTERM, _terminate_process_handler)
            signal.signal(signal.SIGINT, _terminate_process_handler)
            if not IS_WINDOWS:
                signal.signal(signal.SIGHUP, _terminate_process_handler)
                signal.signal(signal.SIGQUIT, _terminate_process_handler)
        else:
            logging.getLogger(__name__).warning(
                "Failed to register signal handlers since PContext.start() was "
                "called from a non-main thread. Child processes may be orphaned "
                "if this process is terminated by a signal."
            )

        self._start()
        self._stdout_tail.start()
        self._stderr_tail.start()

    @abc.abstractmethod
    def _start(self) -> None:
        """
        Start processes using strategy defined in a particular context.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _poll(self) -> Optional[RunProcsResult]:
        """
        Polls the run status of the processes running under this context.
        This method follows an "all-or-nothing" policy and returns
        a ``RunProcsResult`` object if either all processes complete
        successfully or any process fails. Returns ``None`` if
        all processes are still running.
        """
        raise NotImplementedError()

    def wait(self, timeout: float = -1, period: float = 1) -> Optional[RunProcsResult]:
        """
        Waits for the specified ``timeout`` seconds, polling every ``period`` seconds
        for the processes to be done. Returns ``None`` if the processes are still running
        on timeout expiry. Negative timeout values are interpreted as "wait-forever".
        A timeout value of zero simply queries the status of the processes (e.g. equivalent
        to a poll).

        .. note:: ``start()`` registers signal handlers (SIGTERM and SIGINT, plus
                  SIGHUP and SIGQUIT when not on Windows) that raise
                  ``SignalException`` when the signals are received. It is up to the
                  consumer of the code to properly handle the exception. It is
                  important not to swallow the exception, otherwise the process
                  would not terminate. An example of the typical workflow:

        .. code-block:: python

            pc = start_processes(...)
            try:
                pc.wait(1)
                .. do some other work
            except SignalException as e:
                pc.close(e.sigval, timeout=30)

        If one of the handled signals occurs, the code above will try to shut down
        child processes by propagating the received signal. If the child processes do
        not terminate within the timeout, they will be sent SIGKILL.
        """
        if timeout == 0:
            return self._poll()

        # A negative timeout means "wait forever": use a deadline that is
        # practically unreachable.
        if timeout < 0:
            timeout = sys.maxsize

        expiry = time.time() + timeout
        while time.time() < expiry:
            pr = self._poll()
            if pr:
                return pr
            time.sleep(period)

        return None

    @abc.abstractmethod
    def pids(self) -> Dict[int, int]:
        """
        Returns pids of processes mapped by their respective local_ranks
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
        r"""
        Terminates all processes managed by this context and cleans up any
        meta resources (e.g. redirect, error_file files).
        """
        raise NotImplementedError()

    def close(self, death_sig: Optional[signal.Signals] = None, timeout: int = 30) -> None:
        r"""
        Terminates all processes managed by this context and cleans up any
        meta resources (e.g. redirect, error_file files).

        Args:
            death_sig: Death signal to terminate processes. When ``None``,
                       the value of ``_get_default_signal()`` is used.
            timeout: Time to wait for processes to finish, if process is
                     still alive after this time, it will be terminated via SIGKILL.
        """
        # Compare against ``None`` explicitly: a ``signal.Signals`` member
        # whose integer value is 0 is falsy, and a truthiness test would
        # silently replace it with the default signal.
        if death_sig is None:
            death_sig = _get_default_signal()
        self._close(death_sig=death_sig, timeout=timeout)
        if self._stdout_tail:
            self._stdout_tail.stop()
        if self._stderr_tail:
            self._stderr_tail.stop()