예제 #1
0
    def test_health_check_post_connect(self) -> None:
        """An AssertionError raised by the health check must propagate out of ``run``.

        Two workers are started against the server. ``health_check`` passes for
        roughly the first 0.2s and then starts raising ``AssertionError``, which
        ``PIDServer.run`` is expected to let propagate. Because the run is aborted
        by the health check, no graceful shutdowns should have been recorded.
        """
        with ipc.PIDServer(addr=0, num_clients=2) as pid_server:
            assert pid_server.listener
            # addr=0 lets the OS pick a port; read it back so the workers can connect.
            _, port = pid_server.listener.getsockname()

            # Use a monotonic clock: wall-clock adjustments (NTP, manual changes) must
            # not move the moment at which the health check starts failing.
            fail_time = time.monotonic() + 0.2

            def health_check() -> None:
                assert time.monotonic() < fail_time

            procs = [
                multiprocessing.Process(target=TestPIDServer._worker_proc,
                                        args=(port, False)),
                multiprocessing.Process(target=TestPIDServer._worker_proc,
                                        args=(port, False)),
            ]

            for p in procs:
                p.start()

            with pytest.raises(AssertionError):
                pid_server.run(health_check, poll_period=0.05)

            for p in procs:
                p.join()

            assert len(pid_server.graceful_shutdowns) == 0
예제 #2
0
    def test_return_code_on_worker_error(self) -> None:
        """``run_subprocess`` must return a nonzero code when a worker crashes.

        The main subprocess (``sleep 2``) would exit zero on its own, but one of
        the two workers is configured to fail (presumably via its trailing ``True``
        argument to ``_worker_proc`` — confirm against its signature), and
        ``run_subprocess`` is expected to report exit code 79. The deadline check
        ensures the crash, not the other (longer-running) worker, ended the run.
        """
        with ipc.PIDServer(addr=0, num_clients=2) as pid_server:
            assert pid_server.listener
            _, port = pid_server.listener.getsockname()

            # Enforce that the crashed worker causes the exit before the other worker exits.
            # Monotonic clock, so wall-clock jumps cannot make this check pass or fail spuriously.
            deadline = time.monotonic() + 20

            # Enforce that run_subprocess exits nonzero on a worker failure, even if the main
            # subprocess exits zero.
            procs = [
                multiprocessing.Process(target=TestPIDServer._worker_proc,
                                        args=(port, False, 30)),
                multiprocessing.Process(target=TestPIDServer._worker_proc,
                                        args=(port, False, 0.5, 1, True)),
            ]

            # Start outside the try: terminate()/join() on a never-started Process raise.
            for p in procs:
                p.start()

            try:
                error_code = pid_server.run_subprocess(["sleep", "2"])

                assert error_code == 79

                assert time.monotonic() < deadline, "crashing worker did not trigger exit"
            finally:
                # Always reap the workers, even when an assertion above fails; otherwise
                # a long-running worker outlives the test and multiprocessing joins it
                # at interpreter exit, stalling the test session.
                for p in procs:
                    p.terminate()
                    p.join()
예제 #3
0
    def test_worker_crashes(self) -> None:
        """A crashing worker must make ``PIDServer.run`` raise ``WorkerError``.

        One of the two workers is configured to fail (presumably via its trailing
        ``True`` argument to ``_worker_proc`` — confirm against its signature).
        ``run`` must raise ``det.errors.WorkerError`` before the deadline, i.e.
        before the other, longer-running worker would finish. No graceful
        shutdowns should be recorded.
        """
        with ipc.PIDServer(addr=0, num_clients=2) as pid_server:
            assert pid_server.listener
            _, port = pid_server.listener.getsockname()

            # Enforce that the crashed worker causes the exit before the other worker exits.
            # Monotonic clock, so wall-clock jumps cannot make this check pass or fail spuriously.
            deadline = time.monotonic() + 20

            procs = [
                multiprocessing.Process(target=TestPIDServer._worker_proc,
                                        args=(port, False, 30)),
                multiprocessing.Process(target=TestPIDServer._worker_proc,
                                        args=(port, False, 0.5, 1, True)),
            ]

            # Start outside the try: terminate()/join() on a never-started Process raise.
            for p in procs:
                p.start()

            try:
                with pytest.raises(det.errors.WorkerError):
                    pid_server.run()

                assert time.monotonic() < deadline, "crashing worker did not trigger exit"
            finally:
                # Always reap the workers, even when a check above fails, so no worker
                # process outlives the test.
                for p in procs:
                    p.terminate()
                    p.join()

            assert len(pid_server.graceful_shutdowns) == 0
예제 #4
0
    def test_single_worker_failure_is_caught(self) -> None:
        """A failure of the only worker must still surface as ``WorkerError``."""
        # This is a regression test; there used to be a codepath where we would stop checking pid's
        # after the last pidclient disconnected, even if it disconnected with a failure.
        with ipc.PIDServer(addr=0, num_clients=1) as pid_server:
            assert pid_server.listener
            _, port = pid_server.listener.getsockname()

            p = multiprocessing.Process(target=TestPIDServer._worker_proc,
                                        args=(port, False, 0.5, 1, True))

            # Start outside the try: terminate()/join() on a never-started Process raise.
            p.start()

            try:
                with pytest.raises(det.errors.WorkerError):
                    pid_server.run()
            finally:
                # Reap the worker even if pytest.raises fails, so it cannot outlive the test.
                p.terminate()
                p.join()
예제 #5
0
    def test_normal_execution(self) -> None:
        """Well-behaved workers let ``run`` return and each shuts down gracefully.

        Two identically configured workers connect to the server. ``PIDServer.run``
        must return without raising, and exactly two graceful shutdowns (one per
        worker) must be recorded afterwards.
        """
        worker_count = 2
        with ipc.PIDServer(addr=0, num_clients=worker_count) as server:
            assert server.listener
            _host, port = server.listener.getsockname()

            # Both workers receive the same (immutable) argument tuple.
            worker_args = (port, True, 0.1, 5)
            workers = [
                multiprocessing.Process(target=TestPIDServer._worker_proc, args=worker_args)
                for _ in range(worker_count)
            ]

            for worker in workers:
                worker.start()

            server.run()

            for worker in workers:
                worker.join()

            assert len(server.graceful_shutdowns) == worker_count
예제 #6
0
            f"{opt} argument '{val}' is not valid; it should be a signal name ('SIGTERM', "
            "'SIGKILL', etc) or 'WAIT'"
        )
    return out


if __name__ == "__main__":
    # Command-line entry point: run `cmd cmd_args...` as a subprocess under a
    # PIDServer listening at `addr` for `num_workers` clients, and exit with the
    # code returned by run_subprocess.
    #
    # NOTE(review): tokens starting with "-" that follow `cmd` are parsed as
    # options of this parser, not passed through in `cmd_args`; callers presumably
    # need a "--" separator before `cmd` for such commands — confirm with callers.
    parser = argparse.ArgumentParser()

    # Actions, validated by read_action: a signal name or "WAIT".
    parser.add_argument("-x", "--on-fail", dest="on_fail", action="store", default="SIGTERM")
    parser.add_argument("-e", "--on-exit", dest="on_exit", action="store", default="WAIT")
    parser.add_argument("--grace-period", dest="grace_period", type=int, default=3)

    # Positionals, in order: server address, expected worker count, then the command.
    parser.add_argument("addr")
    parser.add_argument("num_workers", type=int)
    parser.add_argument("cmd")
    parser.add_argument("cmd_args", nargs="*")

    args = parser.parse_args()

    on_fail = read_action("--on-fail", args.on_fail)
    on_exit = read_action("--on-exit", args.on_exit)
    addr = ipc.read_pid_server_addr(args.addr)
    full_cmd = [args.cmd, *args.cmd_args]

    with ipc.PIDServer(addr, args.num_workers) as pid_server:
        exit_code = pid_server.run_subprocess(
            cmd=full_cmd,
            on_fail=on_fail,
            on_exit=on_exit,
            grace_period=args.grace_period,
        )
        # Exit while still inside the `with`, so PIDServer.__exit__ sees the SystemExit.
        sys.exit(exit_code)