Exemplo n.º 1
0
class LocalElasticAgentZeusTest(unittest.TestCase):
    """
    Sets up and tears down a local zeus server for testing.
    Tests should connect to localhost:self._mock_zeus_port
    """

    def setUp(self):
        """Start a mock zeus server and record the port clients connect to."""
        self._mock_zeus = ZeusServerFixture()
        try:
            self._mock_zeus.wait_for_ready(max_secs=10)
            self._mock_zeus_port = self._mock_zeus.get_client_port()
        except BaseException:
            # unittest only calls tearDown when setUp succeeds, so the server
            # has to be killed here or it outlives the failed test.
            self._mock_zeus.kill()
            raise

    def tearDown(self):
        """Kill the mock zeus server started in ``setUp``."""
        self._mock_zeus.kill()
        self._mock_zeus_port = None

    def _get_worker_spec(
        self,
        fn,
        args=(),
        max_restarts=1,
        num_agents=1,
        monitor_interval=0.1,
        local_world_size=8,
    ):
        """Build a ``WorkerSpec`` whose rendezvous goes through the mock zeus server.

        Args:
            fn: worker function handed to ``WorkerSpec``.
            args: positional arguments for ``fn``.
            max_restarts: passed through to ``WorkerSpec``.
            num_agents: used as both ``min_size`` and ``max_size`` of the
                rendezvous URL.
            monitor_interval: passed through to ``WorkerSpec``.
            local_world_size: passed through to ``WorkerSpec``.

        Returns:
            The new ``WorkerSpec``; each call uses a fresh random run id.
        """
        run_id = str(uuid.uuid4().int)
        rdzv_handler = dist.rendezvous(
            f"zeus-adapter://localhost:{self._mock_zeus_port}/{run_id}"
            f"?min_size={num_agents}"
            f"&max_size={num_agents}")
        spec = WorkerSpec(
            role="test_trainer",
            local_world_size=local_world_size,
            fn=fn,
            args=args,
            rdzv_handler=rdzv_handler,
            max_restarts=max_restarts,
            monitor_interval=monitor_interval,
        )
        return spec

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_sad_function(self):
        """Run ``_sad_function`` in every worker with two restarts allowed.

        The agent must raise ``WorkerGroupFailureException`` carrying one
        exception per local worker, end in ``WorkerState.FAILED`` and have no
        restarts left.
        """
        spec = self._get_worker_spec(fn=_sad_function, max_restarts=2)
        try:
            agent = LocalElasticAgent(spec, start_method="spawn")
            with self.assertRaises(WorkerGroupFailureException) as cm:
                agent.run()
        finally:
            # release the rendezvous even if the agent misbehaves
            spec.rdzv_handler.shutdown()

        excs = cm.exception.get_worker_exceptions()
        for i in range(spec.local_world_size):
            self.assertIsInstance(excs[i], Exception)

        self.assertEqual(WorkerState.FAILED, agent.get_worker_group().state)
        self.assertEqual(0, agent._remaining_restarts)
Exemplo n.º 2
0
class LaunchTest(unittest.TestCase):
    """End-to-end tests for ``launch.main``.

    A single etcd server is shared by all tests in the class. Each test gets
    its own temporary directory; the launched worker scripts touch one file
    per global rank in it, which the tests then check.
    """

    @classmethod
    def setUpClass(cls):
        # start a standalone, single process etcd server to use for all tests
        cls._etcd_server = EtcdServer()
        cls._etcd_server.start()
        cls._etcd_endpoint = cls._etcd_server.get_endpoint()

    @classmethod
    def tearDownClass(cls):
        # stop the standalone etcd server
        cls._etcd_server.stop()

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_user_script_python(self):
        run_id = str(uuid.uuid4().int)
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        launch.main(args)

        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_user_script_bash(self):
        run_id = str(uuid.uuid4().int)
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node

        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            "--no_python",
        ]

        script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]

        with self.assertRaises(ValueError):
            # --no_python cannot be used with --module
            launch.main(args + ["--module"] + script_args)

        launch.main(args + script_args)

        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    # NOTE(review): the tsan skip decorator was commented out on this test,
    # so it also runs under tsan; confirm that this is intended.
    def test_wrapper_fn_kill_script_process(self):
        """
        Tests that the wrapper_fn properly terminates
        the script process (the script process is the sub_sub_process of
        the agent).
        """
        nprocs = 2
        sleep = 300

        # wraps wrapper_fn to be torch.multiprocessing compatible
        # which requires rank to be passed as first argument
        def wrap_wrap(rank, *args):
            launch.wrapper_fn(*args)

        context = start_processes(
            fn=wrap_wrap,
            args=(None, (path("bin/sleep_script.py"), "--sleep", f"{sleep}")),
            nprocs=nprocs,
            join=False,
            start_method="fork",
        )
        # quick check to see that the wrapper_fn started running
        # without this join() call we don't see an exception on typos
        # and other silly mistakes (silently fails)
        context.join(timeout=-1)

        script_pids = []
        for wrapper_fn_pid in context.pids():
            script_pid = get_child_pids(wrapper_fn_pid)
            # there should only be one child of wrapper_fn
            self.assertEqual(1, len(script_pid))
            script_pids.append(script_pid[0])

        for wrapper_fn_proc in context.processes:
            wrapper_fn_proc.terminate()
            wrapper_fn_proc.join()

        for script_pid in script_pids:
            self.assertFalse(pid_exists(script_pid))

    def _test_nproc_launch_configuration(self, nproc_type, expected_number):
        """Launch ``bin/test_script.sh`` with ``--nproc_per_node=nproc_type``
        and assert that ``expected_number`` workers ran on the single node."""
        run_id = str(uuid.uuid4().int)
        nnodes = 1

        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_type}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            "--no_python",
        ]

        script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]

        launch.main(args + script_args)

        world_size = nnodes * expected_number
        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_nproc_launch_auto_configurations(self):
        self._test_nproc_launch_configuration("auto", os.cpu_count())

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_nproc_launch_number_configurations(self):
        self._test_nproc_launch_configuration("4", 4)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_nproc_launch_unknown_configurations(self):
        with self.assertRaises(ValueError):
            self._test_nproc_launch_configuration("unknown", 4)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    @patch("torch.cuda.is_available", return_value=True)
    @patch("torch.cuda.device_count", return_value=3)
    def test_nproc_gpu_launch_configurations(self, _mock1, _mock2):
        self._test_nproc_launch_configuration("auto", 3)
        self._test_nproc_launch_configuration("gpu", 3)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_elastic(self):
        run_id = str(uuid.uuid4().int)
        min_nodes = 1
        max_nodes = 2
        nproc_per_node = 4
        # we are only launching 1 node (even though max = 2)
        world_size = nproc_per_node
        args = [
            f"--nnodes={min_nodes}:{max_nodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        launch.main(args)

        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_standalone(self):
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--standalone",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        launch.main(args)

        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_elastic_multiple_agents(self):
        run_id = str(uuid.uuid4().int)
        min_nodes = 1
        max_nodes = 2
        nproc_per_node = 4
        nnodes = 2
        world_size = nnodes * nproc_per_node
        args = [
            f"--nnodes={min_nodes}:{max_nodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        procs = []
        for _ in range(nnodes - 1):
            p = mp.Process(target=launch.main, args=[args])
            procs.append(p)
            p.start()
        # the last agent runs on the main process
        launch.main(args)
        for p in procs:
            p.join()
            self.assertEqual(0, p.exitcode)

        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    def test_min_max_nodes_parse(self):
        # assertTrue(a, b) only checks that ``a`` is truthy (``b`` is just the
        # failure message), so the parsed values must be checked with
        # assertEqual.
        min_nodes, max_nodes = launch.parse_min_max_nnodes("1")
        self.assertEqual(min_nodes, max_nodes)
        self.assertEqual(1, min_nodes)
        min_nodes, max_nodes = launch.parse_min_max_nnodes("2:20")
        self.assertEqual(2, min_nodes)
        self.assertEqual(20, max_nodes)
        with self.assertRaises(RuntimeError):
            launch.parse_min_max_nnodes("2:20:30")
Exemplo n.º 3
0
class LocalElasticAgentTest(unittest.TestCase):
    """Tests ``LocalElasticAgent`` against a standalone etcd rendezvous.

    One etcd server fixture is started for the whole class. Every worker spec
    and every multi-agent test uses a fresh random run id.
    """

    @classmethod
    def setUpClass(cls):
        # start a standalone, single process etcd server to use for all tests
        cls._etcd_server = EtcdServerFixture()
        cls._etcd_server.start()

    @classmethod
    def tearDownClass(cls):
        # stop the standalone etcd server
        cls._etcd_server.stop()

    def _get_worker_spec(self,
                         fn,
                         args=(),
                         max_restarts=1,
                         num_agents=1,
                         monitor_interval=0.1,
                         local_world_size=8):
        """Build a ``WorkerSpec`` that rendezvous through the class etcd server.

        Args:
            fn: worker function handed to ``WorkerSpec``.
            args: positional arguments for ``fn``.
            max_restarts: passed through to ``WorkerSpec``.
            num_agents: used as both ``min_workers`` and ``max_workers`` of
                the etcd rendezvous URL.
            monitor_interval: passed through to ``WorkerSpec``.
            local_world_size: passed through to ``WorkerSpec`` (defaults to 8,
                the previously hard-coded value).
        """
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        run_id = str(uuid.uuid4().int)
        rdzv_handler = dist.rendezvous(f"etcd://{host}:{port}/{run_id}"
                                       f"?min_workers={num_agents}"
                                       f"&max_workers={num_agents}")
        spec = WorkerSpec(
            role="test_trainer",
            local_world_size=local_world_size,
            fn=fn,
            args=args,
            rdzv_handler=rdzv_handler,
            max_restarts=max_restarts,
            monitor_interval=monitor_interval,
        )
        return spec

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_happy_function(self):
        spec = self._get_worker_spec(fn=_happy_function)
        agent = LocalElasticAgent(spec, start_method="fork")
        agent.run()

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_distributed_sum(self):
        spec = self._get_worker_spec(fn=_distributed_sum, args=(0, ))
        agent = LocalElasticAgent(spec, start_method="fork")
        agent.run()

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_sad_function(self):
        spec = self._get_worker_spec(fn=_sad_function, max_restarts=2)
        agent = LocalElasticAgent(spec, start_method="fork")
        with self.assertRaises(WorkerGroupFailureException) as cm:
            agent.run()

        excs = cm.exception.get_worker_exceptions()
        for i in range(spec.local_world_size):
            self.assertIsInstance(excs[i], Exception)

        self.assertEqual(WorkerState.FAILED, agent.get_worker_group().state)
        self.assertEqual(0, agent._remaining_restarts)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_bipolar_function(self):
        spec = self._get_worker_spec(fn=_bipolar_function, max_restarts=2)
        agent = LocalElasticAgent(spec, start_method="fork")
        with self.assertRaises(Exception):
            agent.run()
        self.assertEqual(WorkerState.FAILED, agent.get_worker_group().state)
        self.assertEqual(0, agent._remaining_restarts)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_check_env_function(self):
        spec = self._get_worker_spec(fn=_check_env_function, max_restarts=2)
        agent = LocalElasticAgent(spec, start_method="fork")
        agent.run()

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_get_worker_return_values(self):
        spec = self._get_worker_spec(fn=_return_rank_times, args=(2, ))
        agent = LocalElasticAgent(spec, start_method="fork")
        ret_vals = agent.run()

        self.assertEqual(spec.local_world_size, len(ret_vals))
        for i in range(spec.local_world_size):
            self.assertEqual(i * 2, ret_vals[i])

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_double_agent_happy(self):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = 2
        run_id = str(uuid.uuid4().int)

        procs = []
        for _ in range(nnodes - 1):
            p = multiprocessing.Process(target=_run_agent,
                                        args=(run_id, host, port, nnodes,
                                              nnodes))
            procs.append(p)
            p.start()

        # run one on the main process for debugging
        _run_agent(run_id, host, port, nnodes, nnodes)

        for p in procs:
            p.join()
            self.assertEqual(0, p.exitcode)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_double_agent_fault_tolerance(self):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = 2
        run_id = str(uuid.uuid4().int)

        procs = []
        for _ in range(nnodes):
            p = multiprocessing.Process(target=_run_agent,
                                        args=(run_id, host, port, nnodes,
                                              nnodes))
            procs.append(p)
            p.start()

        # restart odd agents
        for i in range(nnodes):
            if i % 2 != 0:
                procs[i].kill()
                # wait for the killed agent to actually exit (and reap it)
                # before its replacement joins the rendezvous; it is dropped
                # from ``procs`` below and would otherwise never be joined
                procs[i].join()
                p = multiprocessing.Process(target=_run_agent,
                                            args=(run_id, host, port, nnodes,
                                                  nnodes))
                procs[i] = p
                p.start()

        for p in procs:
            p.join()
            self.assertEqual(0, p.exitcode)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_double_agent_elastic(self):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        min_size = 1
        max_size = 2
        run_id = str(uuid.uuid4().int)

        procs = []
        for _ in range(max_size):
            p = multiprocessing.Process(target=_run_agent,
                                        args=(run_id, host, port, min_size,
                                              max_size))
            procs.append(p)
            p.start()

        # kill odd agents
        for i in range(max_size):
            if i % 2 != 0:
                procs[i].kill()
                # wait until the killed agent has exited and reap it; it is
                # not joined anywhere else
                procs[i].join()

        for i in range(max_size):
            if i % 2 == 0:
                p = procs[i]
                p.join()
                self.assertEqual(0, p.exitcode)
Exemplo n.º 4
0
class LocalTimerExample(unittest.TestCase):
    """
    Demonstrates how to use LocalTimerServer and LocalTimerClient
    to enforce expiration of code-blocks.

    Since torch multiprocessing's ``start_process`` method currently
    does not take the multiprocessing context as parameter argument
    there is no way to create the mp.Queue in the correct
    context BEFORE spawning child processes. Once the ``start_process``
    API is changed in torch, then re-enable ``test_torch_mp_example``
    unittest. As of now this will SIGSEGV.

    NOTE(review): ``test_torch_mp_example`` is currently only skipped under
    asan/tsan, not disabled outright; confirm whether the paragraph above is
    still accurate.
    """

    @unittest.skipIf(is_asan_or_tsan(), "test is a/tsan incompatible")
    def test_torch_mp_example(self):
        # in practice set the max_interval to a larger value (e.g. 60 seconds)
        mp_queue = mp.get_context("spawn").Queue()
        server = timer.LocalTimerServer(mp_queue, max_interval=0.01)
        server.start()
        try:
            world_size = 8
            # all processes should complete successfully
            # (see the class docstring for why this may fail on torch
            # versions whose start_process does not take an mp context)
            torch_mp.spawn(
                fn=_happy_function, args=(mp_queue,), nprocs=world_size, join=True
            )

            with self.assertRaises(Exception):
                # torch.multiprocessing.spawn kills all sub-procs
                # if one of them gets killed
                torch_mp.spawn(
                    fn=_stuck_function, args=(mp_queue,), nprocs=world_size, join=True
                )
        finally:
            # stop the timer server even when an assertion above fails
            server.stop()

    @unittest.skipIf(is_asan_or_tsan(), "test is a/tsan incompatible")
    def test_example_start_method_spawn(self):
        self._run_example_with(start_method="spawn")

    @unittest.skipIf(is_asan_or_tsan(), "test is a/tsan incompatible")
    def test_example_start_method_forkserver(self):
        self._run_example_with(start_method="forkserver")

    @unittest.skipIf(is_tsan(), "test is tsan incompatible")
    def test_example_start_method_fork(self):
        self._run_example_with(start_method="fork")

    def _run_example_with(self, start_method):
        """Run the timer example with processes created via ``start_method``.

        Even-indexed processes run ``_stuck_function`` and are expected to end
        with exit code ``-SIGKILL``. Odd-indexed processes run
        ``_happy_function`` and are expected to exit with code 0.
        """
        spawn_ctx = mp.get_context(start_method)
        mp_queue = spawn_ctx.Queue()
        server = timer.LocalTimerServer(mp_queue, max_interval=0.01)
        server.start()

        world_size = 8
        processes = []
        try:
            for i in range(world_size):
                target = _stuck_function if i % 2 == 0 else _happy_function
                p = spawn_ctx.Process(target=target, args=(i, mp_queue))
                p.start()
                processes.append(p)

            for i, p in enumerate(processes):
                p.join()
                if i % 2 == 0:
                    self.assertEqual(-signal.SIGKILL, p.exitcode)
                else:
                    self.assertEqual(0, p.exitcode)
        finally:
            # On an early assertion failure some children may still be
            # running; kill them before stopping the server so none are left
            # behind (no-op on the success path, where all were joined).
            for p in processes:
                if p.is_alive():
                    p.kill()
                    p.join()
            server.stop()
Exemplo n.º 5
0
class LocalElasticAgentTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        """Start one standalone, single-process etcd server that every test in
        this class shares."""
        cls._etcd_server = server = EtcdServer()
        server.start()

    @classmethod
    def tearDownClass(cls):
        """Stop the shared etcd server started in ``setUpClass``."""
        etcd_server = cls._etcd_server
        etcd_server.stop()

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_happy_function(self):
        """Run ``_happy_function`` in every worker; fails if ``run()`` raises."""
        worker_spec = self._get_worker_spec(fn=_happy_function)
        LocalElasticAgent(worker_spec, start_method="fork").run()

    def _get_worker_spec(
        self,
        fn,
        args=(),
        max_restarts=1,
        num_agents=1,
        monitor_interval=0.1,
        local_world_size=8,
    ):
        """Create a ``WorkerSpec`` that rendezvous through the shared etcd server.

        Args:
            fn: worker function handed to ``WorkerSpec``.
            args: positional arguments for ``fn``.
            max_restarts: passed through to ``WorkerSpec``.
            num_agents: used as both ``min_workers`` and ``max_workers`` in
                the etcd rendezvous URL.
            monitor_interval: passed through to ``WorkerSpec``.
            local_world_size: passed through to ``WorkerSpec``.

        Returns:
            A new ``WorkerSpec``; every call uses a fresh random run id.
        """
        run_id = str(uuid.uuid4().int)
        rdzv_url = (
            f"etcd://{self._etcd_server.get_endpoint()}/{run_id}"
            f"?min_workers={num_agents}"
            f"&max_workers={num_agents}"
        )
        return WorkerSpec(
            role="test_trainer",
            local_world_size=local_world_size,
            fn=fn,
            args=args,
            rdzv_handler=dist.rendezvous(rdzv_url),
            max_restarts=max_restarts,
            monitor_interval=monitor_interval,
        )

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_distributed_sum(self):
        """Run ``_distributed_sum(0)`` in every worker; fails if ``run()`` raises."""
        worker_spec = self._get_worker_spec(fn=_distributed_sum, args=(0, ))
        LocalElasticAgent(worker_spec, start_method="fork").run()

    class RoleConfig:
        """Describe the agents that run one role.

        ``workers`` holds one entry per agent: the number of workers that
        agent runs. Either pass ``workers`` explicitly, or pass both
        ``num_agents`` and ``workers_num`` to get ``num_agents`` agents with
        ``workers_num`` workers each (this form takes precedence when both
        counts are non-zero). ``role_size`` is the total worker count.

        Raises:
            ValueError: if neither form describes any workers.
        """

        __slots__ = [
            "role", "workers", "num_agents", "workers_num", "role_size"
        ]

        def __init__(self,
                     role: str,
                     workers=None,
                     num_agents: int = 0,
                     workers_num: int = 0):
            self.role = role
            # every declared slot is assigned; previously these two were left
            # unset and reading them raised AttributeError
            self.num_agents = num_agents
            self.workers_num = workers_num
            self.workers = workers
            if workers_num != 0 and num_agents != 0:
                self.workers = [workers_num] * num_agents
            if self.workers is None:
                raise ValueError(
                    f"role {role!r}: pass either workers or both "
                    "num_agents and workers_num")
            self.role_size = sum(self.workers)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_correct_rank_assignment_heterogeneous(self):
        """Roles whose agents run different worker counts (total 25 workers)."""
        trainer = self.RoleConfig("trainer", workers=[1, 2, 3, 4])
        ps = self.RoleConfig("ps", workers=[5, 2])
        # kept last: run_configuration runs the last role's agent on the main
        # process
        master = self.RoleConfig("master", workers=[8])
        self.run_configuration([trainer, ps, master], 25)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_correct_rank_assignment_homogeneous(self):
        """Every agent runs 4 workers; 4 + 2 + 1 agents give 28 workers."""
        num_workers = 4
        # (role, number of agents); "master" is last so that its agent runs
        # on the main process in run_configuration
        agents_per_role = (("trainer", 4), ("ps", 2), ("master", 1))
        roles_config = [
            self.RoleConfig(role, num_agents=agent_count, workers_num=num_workers)
            for role, agent_count in agents_per_role
        ]
        self.run_configuration(roles_config, 28)

    def run_configuration(self, roles_config, expected_world_size):
        """Launch one agent per entry of every role's ``workers`` list and
        verify the ranks they report.

        Every agent is started through ``_run_agent`` with
        ``_check_rank_assignment``; results come back through a manager dict
        and are checked by ``verify_rank_consistency``. The last agent runs
        on the main process (for easier debugging); all others run in child
        processes that must exit with code 0.
        """
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = sum(len(cfg.workers) for cfg in roles_config)
        run_id = str(uuid.uuid4().int)

        # One (role, num_workers) pair per agent. Every pair must be launched
        # because the rendezvous waits for exactly ``nnodes`` agents.
        agents = [(cfg.role, num_workers)
                  for cfg in roles_config
                  for num_workers in cfg.workers]
        *child_agents, (main_role, main_num_workers) = agents

        default_args = (run_id, host, port, nnodes, nnodes,
                        _check_rank_assignment, ())
        # The manager runs its own server process; ``with`` guarantees it is
        # shut down even when one of the assertions below fails. The result
        # proxy must be read before the manager is shut down.
        with multiprocessing.Manager() as manager:
            return_dict = manager.dict()
            procs = []
            for role, num_workers in child_agents:
                p = multiprocessing.Process(
                    target=_run_agent,
                    args=(*default_args, num_workers, role, return_dict),
                )
                procs.append(p)
                p.start()

            # run one on the main process for debugging
            _run_agent(*default_args, main_num_workers, main_role, return_dict)

            for p in procs:
                p.join()
                self.assertEqual(0, p.exitcode)

            role_info_dict = {cfg.role: cfg for cfg in roles_config}
            self.verify_rank_consistency(return_dict, role_info_dict,
                                         expected_world_size)

    def verify_rank_consistency(self, return_dict, role_info_dict,
                                expected_world_size):
        """Assert that the ranks reported by all agents are consistent.

        Each value of ``return_dict`` is a ``(role, per_worker)`` pair. Each
        value of ``per_worker`` is a
        ``(group_rank, rank, world_size, role_rank, role_world_size)`` tuple.
        The checks are:

        * every worker saw ``expected_world_size`` and its role's ``role_size``;
        * the global ranks are exactly ``0 .. expected_world_size - 1``;
        * each role's role ranks are exactly ``0 .. role_size - 1``;
        * within one group rank, both the global ranks and the role ranks
          form a consecutive run.
        """
        ranks_by_role = {}
        all_global_ranks = []
        ranks_by_group = {}
        for role, per_worker in return_dict.values():
            for entry in per_worker.values():
                group_rank, rank, world_size, role_rank, role_world_size = entry
                role_cfg = role_info_dict[role]
                self.assertEqual(expected_world_size, world_size)
                self.assertEqual(role_cfg.role_size, role_world_size)
                ranks_by_group.setdefault(group_rank, []).append((rank, role_rank))
                all_global_ranks.append(rank)
                ranks_by_role.setdefault(role, []).append(role_rank)

        self.assertEqual(list(range(0, expected_world_size)),
                         sorted(all_global_ranks))
        for role, role_cfg in role_info_dict.items():
            self.assertEqual(list(range(0, role_cfg.role_size)),
                             sorted(ranks_by_role[role]))

        # Each agent must assign consecutive ranks to its workers: index 0 of
        # every pair is the global rank, index 1 the role rank.
        for rank_pairs in ranks_by_group.values():
            for rank_idx in (0, 1):
                self.verify_ranks_sequential(rank_pairs, rank_idx)

    def verify_ranks_sequential(self, ranks_pairs, rank_idx):
        """Assert that element ``rank_idx`` of the pairs, once sorted, is a
        gap-free, duplicate-free run of integers."""
        values = sorted(pair[rank_idx] for pair in ranks_pairs)
        lowest, highest = values[0], values[-1]
        expected = list(range(lowest, highest + 1))
        self.assertEqual(expected, values)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_distributed_sum_heterogenous(self):
        """Four agents with 1, 2, 3 and 8 workers run ``_distributed_sum(0)``;
        every child agent must exit with code 0."""
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = 4
        run_id = str(uuid.uuid4().int)
        default_args = (run_id, host, port, nnodes, nnodes, _distributed_sum,
                        (0, ))

        children = []
        for num_workers in range(1, nnodes):
            child = multiprocessing.Process(target=_run_agent,
                                            args=(*default_args, num_workers))
            children.append(child)
            child.start()

        # the last agent runs on the main process for easier debugging
        _run_agent(*default_args, 8)

        for child in children:
            child.join()
            self.assertEqual(0, child.exitcode)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_sad_function(self):
        """Run ``_sad_function`` in every worker with two restarts allowed.

        The agent must raise ``WorkerGroupFailureException`` holding one
        exception per local worker, end in ``WorkerState.FAILED`` and have no
        restarts left.
        """
        spec = self._get_worker_spec(fn=_sad_function, max_restarts=2)
        agent = LocalElasticAgent(spec, start_method="fork")
        with self.assertRaises(WorkerGroupFailureException) as cm:
            agent.run()

        excs = cm.exception.get_worker_exceptions()
        for i in range(spec.local_world_size):
            # assertIsInstance reports the offending object on failure
            self.assertIsInstance(excs[i], Exception)

        self.assertEqual(WorkerState.FAILED, agent.get_worker_group().state)
        self.assertEqual(0, agent._remaining_restarts)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_bipolar_function(self):
        """``_bipolar_function`` workers with two restarts: ``run()`` must raise,
        leaving the group FAILED with no restarts remaining."""
        agent = LocalElasticAgent(
            self._get_worker_spec(fn=_bipolar_function, max_restarts=2),
            start_method="fork",
        )
        with self.assertRaises(Exception):
            agent.run()
        group_state = agent.get_worker_group().state
        self.assertEqual(WorkerState.FAILED, group_state)
        self.assertEqual(0, agent._remaining_restarts)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_check_env_function(self):
        """Run ``_check_env_function`` in every worker (two restarts allowed);
        fails if ``run()`` raises."""
        worker_spec = self._get_worker_spec(fn=_check_env_function,
                                            max_restarts=2)
        LocalElasticAgent(worker_spec, start_method="fork").run()

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_check_run_id(self):
        """Each worker must see the rendezvous run id in ``TORCHELASTIC_RUN_ID``."""

        def return_run_id():
            return os.environ["TORCHELASTIC_RUN_ID"]

        spec = self._get_worker_spec(fn=return_run_id, max_restarts=0)
        results = LocalElasticAgent(spec, start_method="fork").run()

        for local_rank in range(spec.local_world_size):
            expected = spec.rdzv_handler.get_run_id()
            self.assertEqual(expected, results[local_rank])

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_get_worker_return_values(self):
        """``run()`` returns one value per local worker; worker ``i`` running
        ``_return_rank_times(2)`` is expected to produce ``2 * i``."""
        spec = self._get_worker_spec(fn=_return_rank_times, args=(2, ))
        results = LocalElasticAgent(spec, start_method="fork").run()

        self.assertEqual(spec.local_world_size, len(results))
        for local_rank in range(spec.local_world_size):
            self.assertEqual(local_rank * 2, results[local_rank])

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_double_agent_happy(self):
        """Two agents share one run id, with ``_distributed_sum`` and ``(0,)``
        passed to ``_run_agent``; every child agent must exit with code 0."""
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = 2
        run_id = str(uuid.uuid4().int)
        agent_args = (run_id, host, port, nnodes, nnodes, _distributed_sum,
                      (0, ))

        children = []
        for _ in range(nnodes - 1):
            child = multiprocessing.Process(target=_run_agent, args=agent_args)
            children.append(child)
            child.start()

        # the last agent runs on the main process for easier debugging
        _run_agent(*agent_args)

        for child in children:
            child.join()
            self.assertEqual(0, child.exitcode)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_double_agent_fault_tolerance(self):
        """Start ``nnodes`` agents, kill and replace every odd-indexed one, and
        require every surviving or replacement agent to exit with code 0."""
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = 2
        run_id = str(uuid.uuid4().int)

        procs = []
        for _ in range(nnodes):
            p = multiprocessing.Process(
                target=_run_agent,
                args=(run_id, host, port, nnodes, nnodes, _distributed_sum,
                      (0, )),
            )
            procs.append(p)
            p.start()

        # restart odd agents
        for i in range(nnodes):
            if i % 2 != 0:
                procs[i].kill()
                # wait for the killed agent to actually exit (and reap it)
                # before its replacement joins the rendezvous; it is dropped
                # from ``procs`` below and would otherwise never be joined
                procs[i].join()
                p = multiprocessing.Process(
                    target=_run_agent,
                    args=(run_id, host, port, nnodes, nnodes, _distributed_sum,
                          (0, )),
                )
                procs[i] = p
                p.start()

        for p in procs:
            p.join()
            self.assertEqual(0, p.exitcode)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_double_agent_elastic(self):
        """Start ``max_size`` agents with ``min_size=1``, kill the odd-indexed
        ones, and require the remaining even-indexed agents to exit with 0."""
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        min_size = 1
        max_size = 2
        run_id = str(uuid.uuid4().int)

        procs = []
        for _ in range(max_size):
            p = multiprocessing.Process(
                target=_run_agent,
                args=(run_id, host, port, min_size, max_size, _distributed_sum,
                      (0, )),
            )
            procs.append(p)
            p.start()

        # kill odd agents
        for i in range(max_size):
            if i % 2 != 0:
                procs[i].kill()
                # wait until the killed agent has exited and reap it; it is
                # not joined anywhere else in this test
                procs[i].join()

        for i in range(max_size):
            if i % 2 == 0:
                p = procs[i]
                p.join()
                self.assertEqual(0, p.exitcode)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_torch_rpc(self):
        """
        Simple torch rpc example with torchelastic.
        Creates two agents (to simulate two node job),
        each agent runs a single worker. worker0 calls an rpc_sync on
        worker1 and sends the echoed result back through a queue.
        """

        # TODO upstream this to torch.distributed.rpc so that users do not have
        # to redundantly set rank as part of name (e.g. worker0) AND also pass
        # it explicitly as an argument to rpc.init_rpc
        def init_rpc(name_prefix, backend):
            # rank / world size come from the env vars set for each worker
            rank = int(os.environ["RANK"])
            world_size = int(os.environ["WORLD_SIZE"])
            rpc.init_rpc(
                name=f"{name_prefix}{rank}",
                backend=backend,
                rank=rank,
                world_size=world_size,
            )

        def worker_0(queue, msg):
            init_rpc("worker", BackendType.PROCESS_GROUP)
            ret = rpc.rpc_sync(to="worker1", func=echo, args=(msg, ))
            queue.put(ret)
            rpc.shutdown()

        def worker_1():
            init_rpc("worker", BackendType.PROCESS_GROUP)
            rpc.shutdown()

        def run_agent(run_id,
                      etcd_host,
                      etcd_port,
                      start_method,
                      worker_fn,
                      worker_args=()):
            # one agent with a single local worker; two agents form the job
            rdzv_handler = dist.rendezvous(
                f"etcd://{etcd_host}:{etcd_port}/{run_id}"
                f"?min_workers=2"
                f"&max_workers=2")
            spec = WorkerSpec(
                role="test_trainer",
                local_world_size=1,
                fn=worker_fn,
                args=worker_args,
                rdzv_handler=rdzv_handler,
                max_restarts=3,
                monitor_interval=1,
            )

            agent = LocalElasticAgent(spec, start_method)
            agent.run()

        run_id = str(uuid.uuid4().int)
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        start_method = "fork"
        msg = "hello world"
        mp_queue = multiprocessing.get_context(start_method).Queue()

        agent0 = multiprocessing.Process(
            target=run_agent,
            args=(run_id, host, port, start_method, worker_0, (mp_queue, msg)),
        )
        agent1 = multiprocessing.Process(target=run_agent,
                                         args=(run_id, host, port,
                                               start_method, worker_1, ()))

        agent0.start()
        agent1.start()

        agent0.join()
        agent1.join()

        # Check the exit codes first: if an agent failed, the result was
        # likely never put on the queue and an unbounded get() would hang.
        self.assertEqual(0, agent0.exitcode)
        self.assertEqual(0, agent1.exitcode)
        # bounded wait so a missing result fails the test instead of hanging
        self.assertEqual(msg, mp_queue.get(timeout=10))

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_workers_drift_success(self):
        """
        Runs ``nnodes - 1`` agents in child processes with ``_simulate_work``
        args ``(10, )`` and one agent in this process with ``(1, )``; every
        child agent must exit with code 0.
        """
        nnodes = 2
        run_id = str(uuid.uuid4().int)
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()

        def agent_args(work_args):
            # build a fresh argument tuple (including a new kwargs dict) for
            # each agent; trailing values are forwarded verbatim to _run_agent
            return (run_id, host, port, nnodes, nnodes, _simulate_work,
                    work_args, 2, "test_trainer", {}, 30)

        children = []
        for _ in range(nnodes - 1):
            child = multiprocessing.Process(target=_run_agent,
                                            args=agent_args((10, )))
            children.append(child)
            child.start()

        _run_agent(*agent_args((1, )))

        for child in children:
            child.join()
            self.assertEqual(0, child.exitcode)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_workers_drift_fail(self):
        """
        Runs ``nnodes - 1`` agents in child processes with ``_simulate_work``
        args ``(60, )`` and one agent in this process with ``(1, )`` (trailing
        ``_run_agent`` args ``2, "test_trainer", {}, 10``), and expects the
        in-process agent to raise ``LookupError``.
        """
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = 2
        run_id = str(uuid.uuid4().int)

        procs = []
        default_args = (run_id, host, port, nnodes, nnodes, _simulate_work)
        for _ in range(nnodes - 1):
            p = multiprocessing.Process(
                target=_run_agent,
                args=(*default_args, (60, ), 2, "test_trainer", {}, 10),
            )
            procs.append(p)
            p.start()

        try:
            # TODO(aivanou): standardize error between different rendezvous stores
            with self.assertRaises(LookupError):
                _run_agent(*default_args, (1, ), 2, "test_trainer", {}, 10)
        finally:
            # The child agents' outcome is not part of this test. Stop and
            # reap them so they do not keep running after the test finishes
            # (non-daemon children are otherwise waited on at interpreter exit).
            for p in procs:
                p.kill()
                p.join()
# Example No. 6
# 0
class LaunchTest(unittest.TestCase):
    """
    End-to-end tests for ``launch.main`` using etcd rendezvous.

    The launched scripts have each worker touch a file named after its global
    rank into a directory given on the command line; the tests compare that
    directory's contents against the expected set of ranks.
    """

    @classmethod
    def setUpClass(cls):
        # start a standalone, single process etcd server to use for all tests
        cls._etcd_server = EtcdServer()
        cls._etcd_server.start()
        cls._etcd_endpoint = cls._etcd_server.get_endpoint()

    @classmethod
    def tearDownClass(cls):
        # stop the standalone etcd server
        cls._etcd_server.stop()

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    def _etcd_launch_args(self, run_id, nnodes, nproc_per_node):
        """
        Return the launcher flags shared by the tests that rendezvous through
        the class-wide etcd server (``--start_method=fork``,
        ``--monitor_interval=1``).
        """
        return [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
        ]

    def _assert_all_workers_ran(self, world_size, touch_dir=None):
        """
        Assert that ``touch_dir`` (default: ``self.test_dir``) contains
        exactly one file per global rank ``0 .. world_size - 1``.
        """
        if touch_dir is None:
            touch_dir = self.test_dir
        self.assertSetEqual({str(i) for i in range(world_size)},
                            set(os.listdir(touch_dir)))

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_user_script_python(self):
        run_id = str(uuid.uuid4().int)
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        args = self._etcd_launch_args(run_id, nnodes, nproc_per_node) + [
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        launch.main(args)

        # make sure all the workers ran
        self._assert_all_workers_ran(world_size)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_user_script_bash(self):
        run_id = str(uuid.uuid4().int)
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node

        args = self._etcd_launch_args(run_id, nnodes, nproc_per_node)
        args.append("--no_python")

        script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]

        with self.assertRaises(ValueError):
            # --no_python cannot be used with --module
            launch.main(args + ["--module"] + script_args)

        launch.main(args + script_args)

        # make sure all the workers ran
        self._assert_all_workers_ran(world_size)

    def _test_nproc_launch_configuration(self, nproc_type, expected_number):
        """
        Launch ``bin/test_script.sh`` on one node with
        ``--nproc_per_node=<nproc_type>`` and assert that exactly
        ``expected_number`` workers ran.

        Every call writes into its own fresh directory under ``self.test_dir``
        so that a test calling this helper more than once cannot pass on files
        left behind by an earlier launch.
        """
        run_id = str(uuid.uuid4().int)
        nnodes = 1
        touch_dir = tempfile.mkdtemp(dir=self.test_dir)

        args = self._etcd_launch_args(run_id, nnodes, nproc_type)
        args.append("--no_python")

        script_args = [path("bin/test_script.sh"), f"{touch_dir}"]

        launch.main(args + script_args)

        world_size = nnodes * expected_number
        # make sure all the workers ran
        self._assert_all_workers_ran(world_size, touch_dir)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_nproc_launch_auto_configurations(self):
        self._test_nproc_launch_configuration("auto", os.cpu_count())

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_nproc_launch_number_configurations(self):
        self._test_nproc_launch_configuration("4", 4)

    def test_nproc_launch_unknown_configurations(self):
        with self.assertRaises(ValueError):
            self._test_nproc_launch_configuration("unknown", 4)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    @patch("torch.cuda.is_available", return_value=True)
    @patch("torch.cuda.device_count", return_value=3)
    def test_nproc_gpu_launch_configurations(self, _mock1, _mock2):
        # each call checks its own output directory, so the "gpu" launch is
        # verified independently of the "auto" launch before it
        self._test_nproc_launch_configuration("auto", 3)
        self._test_nproc_launch_configuration("gpu", 3)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_elastic(self):
        run_id = str(uuid.uuid4().int)
        min_nodes = 1
        max_nodes = 2
        nproc_per_node = 4
        # we are only launching 1 node (even though max = 2)
        world_size = nproc_per_node
        args = self._etcd_launch_args(run_id, f"{min_nodes}:{max_nodes}",
                                      nproc_per_node) + [
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        launch.main(args)

        # make sure all the workers ran
        self._assert_all_workers_ran(world_size)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_with_etcd(self):
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        # --with_etcd: the launcher is responsible for providing etcd itself
        # instead of using the class-wide server
        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--with_etcd",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        launch.main(args)

        # make sure all the workers ran
        self._assert_all_workers_ran(world_size)
# Example No. 7
# 0
class TestRpc(unittest.TestCase):
    """
    Tests for the rpc helpers in ``app`` (``init_app`` / ``init_rpc``).

    Environment variables set by a test are scoped with ``patch.dict`` so
    they do not leak into other tests running in the same process.
    """

    def test_init_app(self):
        app.init_app(role="trainer",
                     backend=BackendType.PROCESS_GROUP,
                     backend_options=None)

    @patch("torch.distributed.autograd._init")
    @patch("torch.distributed.rpc.api._init_rpc_backend")
    def test_init_rpc(self, rpc_backend_mock, autograd_mock):
        # RANK / WORLD_SIZE are only needed while init_rpc runs; restore the
        # previous environment afterwards
        with patch.dict(os.environ, {"RANK": "0", "WORLD_SIZE": "1"}):
            store = TestStore()
            app.init_rpc(
                name="trainer_worker",
                backend=BackendType.PROCESS_GROUP,
                backend_options=None,
                store=store,
            )
        autograd_mock.assert_called_once()
        rpc_backend_mock.assert_called_once()

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_custom_init_rpc(self):
        def init_rpc(rank, world_size, port, name):
            # env:// rendezvous reads these variables
            os.environ["RANK"] = str(rank)
            os.environ["WORLD_SIZE"] = str(world_size)
            os.environ["MASTER_ADDR"] = "localhost"
            os.environ["MASTER_PORT"] = str(port)
            rendezvous_iterator = dist.rendezvous("env://",
                                                  rank=rank,
                                                  world_size=world_size)
            store, _, _ = next(rendezvous_iterator)
            app.init_rpc(
                name=name,
                backend=BackendType.PROCESS_GROUP,
                backend_options=None,
                store=store,
            )

        def master(msg, port):
            init_rpc(rank=0, world_size=2, port=port, name="master")
            ret = rpc.rpc_sync(to="worker", func=echo, args=(msg, ))
            rpc.shutdown()
            return ret

        def worker(port):
            init_rpc(rank=1, world_size=2, port=port, name="worker")
            rpc.shutdown()

        # grab a free port, then release it for the rendezvous store to use
        sock = find_free_port()
        port = sock.getsockname()[1]
        sock.close()

        worker_proc = multiprocessing.Process(target=worker, args=(port, ))
        worker_proc.start()
        expected_msg = "test_message_on_worker"

        # master() runs in this process and writes RANK / WORLD_SIZE /
        # MASTER_ADDR / MASTER_PORT; restore the environment afterwards
        with patch.dict(os.environ):
            actual_msg = master(expected_msg, port)
        worker_proc.join()
        self.assertEqual(0, worker_proc.exitcode)
        self.assertEqual(expected_msg, actual_msg)

    # NOTE: the tests below are placeholders with no body; they always pass.
    def test_get_worker_names(self):
        pass

    def test_get_role_info(self):
        pass

    def test_get_all_roles(self):
        pass

    def test_wait_all(self):
        pass

    def test_rpc_sync_on_role(self):
        pass

    def test_rpc_async_on_role(self):
        pass

    def test_rpc_remote_on_role(self):
        pass

    def test_init_process_group(self):
        pass
# Example No. 8
# 0
class LaunchTest(unittest.TestCase):
    """
    Runs ``launch.main`` end to end against a class-wide etcd server and
    checks which ranks' marker files show up in a per-test temp directory.
    """

    @classmethod
    def setUpClass(cls):
        # one standalone, single-process etcd server shared by every test
        cls._etcd_server = server = EtcdServerFixture()
        server.start()
        host, port = server.get_host(), server.get_port()
        cls._etcd_endpoint = f"{host}:{port}"

    @classmethod
    def tearDownClass(cls):
        # shut the shared etcd server down once every test has run
        cls._etcd_server.stop()

    def setUp(self):
        # fresh directory that launched workers touch their rank files into
        self.test_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    def _base_args(self, run_id, nnodes, nproc_per_node):
        """Return a new list of the launcher flags common to these tests."""
        return [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
        ]

    def _check_ranks_touched(self, world_size):
        """
        Every worker touches a file named after its global rank; verify the
        temp directory holds exactly ranks ``0 .. world_size - 1``.
        """
        expected = set(map(str, range(world_size)))
        self.assertSetEqual(expected, set(os.listdir(self.test_dir)))

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_user_script_python(self):
        nnodes, nproc_per_node = 1, 4
        run_id = str(uuid.uuid4().int)

        launch_args = self._base_args(run_id, nnodes, nproc_per_node)
        launch_args += [
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        launch.main(launch_args)

        self._check_ranks_touched(nnodes * nproc_per_node)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_user_script_python_use_env(self):
        nnodes, nproc_per_node = 1, 4
        run_id = str(uuid.uuid4().int)

        launch_args = self._base_args(run_id, nnodes, nproc_per_node)
        # keep --use_env directly after --nproc_per_node
        launch_args.insert(2, "--use_env")
        launch_args += [
            path("bin/test_script_use_env.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        launch.main(launch_args)

        self._check_ranks_touched(nnodes * nproc_per_node)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_user_script_bash(self):
        nnodes, nproc_per_node = 1, 4
        run_id = str(uuid.uuid4().int)

        launch_args = self._base_args(run_id, nnodes, nproc_per_node)
        launch_args.append("--no_python")
        script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]

        # --no_python also requires --use_env
        with self.assertRaises(ValueError):
            launch.main(launch_args + script_args)

        # --no_python cannot be used with --module
        with self.assertRaises(ValueError):
            launch.main(launch_args + ["--module"] + script_args)

        launch.main(launch_args + ["--use_env"] + script_args)

        self._check_ranks_touched(nnodes * nproc_per_node)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_elastic(self):
        min_nodes, max_nodes = 1, 2
        nproc_per_node = 4
        run_id = str(uuid.uuid4().int)

        launch_args = self._base_args(run_id, f"{min_nodes}:{max_nodes}",
                                      nproc_per_node)
        launch_args += [
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        launch.main(launch_args)

        # only one node is launched (even though max = 2), so the world size
        # is just this node's workers
        self._check_ranks_touched(nproc_per_node)
# Example No. 9
# 0
class LocalTimerTest(unittest.TestCase):
    """
    Tests for ``timer.expires`` backed by a ``timer.LocalTimerServer`` that
    is started in ``setUp`` and consumes requests from ``self.mp_queue``.
    """

    def setUp(self):
        self.mp_queue = mp.Queue()
        self.max_interval = 0.01
        self.server = timer.LocalTimerServer(self.mp_queue, self.max_interval)
        self.server.start()

    def tearDown(self):
        self.server.stop()
        # several tests install a process-global timer client; clear it so it
        # cannot leak into (and change the outcome of) later tests
        timer.configure(None)

    def test_exception_propagation(self):
        # Pass the client explicitly: without a configured client, expires()
        # raises its own RuntimeError before the body runs (see
        # test_no_client), which a bare assertRaises(Exception) would accept.
        # ``msg=`` on assertRaises is only the failure message, so match the
        # propagated exception's text with assertRaisesRegex instead.
        timer_client = timer.LocalTimerClient(self.mp_queue)
        with self.assertRaisesRegex(Exception, "foobar"):
            with timer.expires(after=1, client=timer_client):
                raise Exception("foobar")

    def test_no_client(self):
        # no timer client configured; exception expected
        timer.configure(None)
        with self.assertRaises(RuntimeError):
            with timer.expires(after=1):
                pass

    def test_client_interaction(self):
        # no timer client configured but one passed in explicitly
        # no exception expected
        timer_client = timer.LocalTimerClient(self.mp_queue)
        timer_client.acquire = mock.MagicMock(wraps=timer_client.acquire)
        timer_client.release = mock.MagicMock(wraps=timer_client.release)
        with timer.expires(after=1, scope="test", client=timer_client):
            pass

        timer_client.acquire.assert_called_once_with("test", mock.ANY)
        timer_client.release.assert_called_once_with("test")

    def test_happy_path(self):
        timer.configure(timer.LocalTimerClient(self.mp_queue))
        with timer.expires(after=0.5):
            time.sleep(0.1)

    @unittest.skipIf(is_tsan(), "test is tsan incompatible")
    def test_get_timer_recursive(self):
        """
        If a function acquires a countdown timer with default scope,
        then recursive calls to the function should re-acquire the
        timer rather than creating a new one. That is only the last
        recursive call's timer will take effect.
        """
        # The server is already started in setUp; starting it a second time
        # is redundant and, depending on the server implementation, may start
        # a second watchdog over the same queue.
        timer.configure(timer.LocalTimerClient(self.mp_queue))

        # func should not time out
        def func(n):
            if n > 0:
                with timer.expires(after=0.1):
                    func(n - 1)
                    time.sleep(0.05)

        func(4)

        # func2 should time out
        def func2(n):
            if n > 0:
                with timer.expires(after=0.1):
                    func2(n - 1)
                    time.sleep(0.2)

        p = mp.Process(target=func2, args=(2, ))
        p.start()
        p.join()
        self.assertEqual(-signal.SIGKILL, p.exitcode)

    @staticmethod
    def _run(mp_queue, timeout, duration):
        # child-process body: hold a timer of ``timeout`` for ``duration`` secs
        client = timer.LocalTimerClient(mp_queue)
        timer.configure(client)

        with timer.expires(after=timeout):
            time.sleep(duration)

    @unittest.skipIf(is_tsan(), "test is tsan incompatible")
    def test_timer(self):
        timeout = 0.1
        duration = 1
        p = mp.Process(target=self._run,
                       args=(self.mp_queue, timeout, duration))
        p.start()
        p.join()
        self.assertEqual(-signal.SIGKILL, p.exitcode)
# Example No. 10
# 0
class LocalTimerServerTest(unittest.TestCase):
    def setUp(self):
        """Build (but do not start) a server reading from a fresh queue."""
        queue = mp.Queue()
        interval = 0.01
        self.mp_queue = queue
        self.max_interval = interval
        self.server = timer.LocalTimerServer(queue, interval)

    def tearDown(self):
        """Stop the server; individual tests may or may not have started it."""
        server = self.server
        server.stop()

    @unittest.skipIf(is_tsan(), "test is tsan incompatible")
    def test_watchdog_call_count(self):
        """
        Run the server for ``wait`` seconds and verify the watchdog executed
        roughly ``wait / max_interval`` times (allowing one either way).
        """
        server = self.server
        server._run_watchdog = mock.MagicMock(wraps=server._run_watchdog)

        wait = 0.1
        expected_calls = int(wait / self.max_interval)

        server.start()
        time.sleep(wait)
        server.stop()

        actual_calls = server._run_watchdog.call_count
        self.assertGreaterEqual(actual_calls, expected_calls - 1)
        self.assertLessEqual(actual_calls, expected_calls + 1)

    def test_watchdog_empty_queue(self):
        """
        A single watchdog pass over an empty queue must not raise.
        """
        run_once = self.server._run_watchdog
        run_once()

    @staticmethod
    def _timer_request(pid, scope, expiration_time):
        # single construction point shared by the request helpers below
        return TimerRequest(worker_id=pid,
                            scope_id=scope,
                            expiration_time=expiration_time)

    def _expired_timer(self, pid, scope):
        """Request whose expiration time is 60 seconds in the past."""
        return self._timer_request(pid, scope, time.time() - 60)

    def _valid_timer(self, pid, scope):
        """Request whose expiration time is 60 seconds in the future."""
        return self._timer_request(pid, scope, time.time() + 60)

    def _release_timer(self, pid, scope):
        """Release request, encoded as ``expiration_time=-1``."""
        return self._timer_request(pid, scope, -1)

    @unittest.skipIf(is_tsan(), "test is tsan incompatible")
    @mock.patch("os.kill")
    def test_expired_timers(self, mock_os_kill):
        """
        One expired timer on a process must get that process SIGKILLed and
        drop every pending timer the process owned, expired or not.
        """
        victim_pid = -3
        pending = (
            self._expired_timer(pid=victim_pid, scope="test1"),
            self._valid_timer(pid=victim_pid, scope="test2"),
        )
        for request in pending:
            self.mp_queue.put(request)

        self.server._run_watchdog()

        self.assertEqual(0, len(self.server._timers))
        mock_os_kill.assert_called_once_with(victim_pid, signal.SIGKILL)

    @mock.patch("os.kill")
    def test_acquire_release(self, mock_os_kill):
        """
        Verifies that:
          1. an acquired timer can be released without killing the process
          2. releasing a timer that was never acquired is a harmless no-op
        """
        pid = -3
        requests = [
            self._valid_timer(pid=pid, scope="test1"),    # acquire
            self._release_timer(pid=pid, scope="test1"),  # release it
            self._release_timer(pid=pid, scope="test2"),  # vacuous release
        ]
        for request in requests:
            self.mp_queue.put(request)

        self.server._run_watchdog()

        self.assertEqual(0, len(self.server._timers))
        mock_os_kill.assert_not_called()

    @mock.patch("os.kill")
    def test_valid_timers(self, mock_os_kill):
        """
        tests that valid timers are processed correctly and the process is left alone
        """
        # (worker_id, scope_id) pairs; each gets a not-yet-expired timer
        keys = [(-3, "test1"), (-3, "test2"), (-2, "test1"), (-2, "test2")]
        for pid, scope in keys:
            self.mp_queue.put(self._valid_timer(pid=pid, scope=scope))

        self.server._run_watchdog()

        self.assertEqual(len(keys), len(self.server._timers))
        for key in keys:
            # assertIn reports the missing key and the container on failure,
            # unlike assertTrue(key in ...)
            self.assertIn(key, self.server._timers)
        mock_os_kill.assert_not_called()