@dataclass
class Conf:
    """
    Holds arguments to launch an agent (e.g. simulates an agent run on a node).

    """

    entrypoint: Callable
    local_world_size: int
    args: Tuple = ()
    role: str = "default"
    redirects: Std = Std.NONE
    tee: Std = Std.NONE
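
# A minimal construction sketch (hypothetical, not part of the original tests):
# ``_echo`` stands in for any callable entrypoint. Conf simply bundles the
# per-role launch arguments that a test expands when simulating an agent run.
def _echo(msg):
    return msg

example_conf = Conf(entrypoint=_echo, local_world_size=2, args=("hello",), role="trainer")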


@unittest.skipIf(is_tsan(), "tests incompatible with tsan")
class LocalElasticAgentTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # start a standalone, single process etcd server to use for all tests
        cls._etcd_server = EtcdServer()
        cls._etcd_server.start()

    @classmethod
    def tearDownClass(cls):
        # stop the standalone etcd server
        cls._etcd_server.stop()

    def setUp(self):
        self._test_dir = tempfile.mkdtemp(prefix=self.__class__.__name__)
        self._run_id = str(uuid.uuid4()).split("-")[0]
# Example #2
class LaunchTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # start a standalone, single process etcd server to use for all tests
        cls._etcd_server = EtcdServer()
        cls._etcd_server.start()
        cls._etcd_endpoint = cls._etcd_server.get_endpoint()

    @classmethod
    def tearDownClass(cls):
        # stop the standalone etcd server
        cls._etcd_server.stop()

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()

        # remove any lingering environment variables
        for env in os.environ.keys():
            if env.startswith("PET_"):
                del os.environ[env]

        # set a sentinel env var on the parent proc
        # this should be present on the child and gets
        # asserted in ``bin/test_script.py``
        os.environ["TEST_SENTINEL_PARENT"] = "FOOBAR"

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    def test_launch_user_script_python(self):
        run_id = str(uuid.uuid4().int)
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        launch.main(args)

        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        self.assertSetEqual({str(i)
                             for i in range(world_size)},
                            set(os.listdir(self.test_dir)))

    def test_launch_user_script_bash(self):
        run_id = str(uuid.uuid4().int)
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node

        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            "--no_python",
        ]

        script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]

        with self.assertRaises(ValueError):
            # --no_python cannot be used with --module
            launch.main(args + ["--module"] + script_args)

        launch.main(args + script_args)

        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        self.assertSetEqual({str(i)
                             for i in range(world_size)},
                            set(os.listdir(self.test_dir)))

    def test_launch_with_env_vars(self):
        run_id = str(uuid.uuid4().int)
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node

        os.environ["PET_NNODES"] = str(nnodes)
        os.environ["PET_NPROC_PER_NODE"] = str(nproc_per_node)
        os.environ["PET_RDZV_BACKEND"] = "etcd"
        os.environ["PET_RDZV_ENDPOINT"] = self._etcd_endpoint
        os.environ["PET_RDZV_ID"] = run_id
        os.environ["PET_MONITOR_INTERVAL"] = "1"
        os.environ["PET_START_METHOD"] = "fork"
        os.environ["PET_NO_PYTHON"] = "1"

        script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]

        with self.assertRaises(ValueError):
            # --no_python cannot be used with --module
            os.environ["PET_MODULE"] = "1"
            launch.main(script_args)

        os.environ["PET_MODULE"] = "0"
        launch.main(script_args)

        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        self.assertSetEqual({str(i)
                             for i in range(world_size)},
                            set(os.listdir(self.test_dir)))

    def _test_nproc_launch_configuration(self, nproc_type, expected_number):
        run_id = str(uuid.uuid4().int)
        nnodes = 1

        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_type}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            "--no_python",
        ]

        script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]

        launch.main(args + script_args)

        world_size = nnodes * expected_number
        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        self.assertSetEqual({str(i)
                             for i in range(world_size)},
                            set(os.listdir(self.test_dir)))

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_nproc_launch_auto_configurations(self):
        self._test_nproc_launch_configuration("auto", os.cpu_count())

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_nproc_launch_number_configurations(self):
        self._test_nproc_launch_configuration("4", 4)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_nproc_launch_unknown_configurations(self):
        with self.assertRaises(ValueError):
            self._test_nproc_launch_configuration("unknown", 4)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    @patch("torch.cuda.is_available", return_value=True)
    @patch("torch.cuda.device_count", return_value=3)
    def test_nproc_gpu_launch_configurations(self, _mock1, _mock2):
        self._test_nproc_launch_configuration("auto", 3)
        self._test_nproc_launch_configuration("gpu", 3)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_elastic(self):
        run_id = str(uuid.uuid4().int)
        min_nodes = 1
        max_nodes = 2
        nproc_per_node = 4
        # we are only launching 1 node (even though max = 2)
        world_size = nproc_per_node
        args = [
            f"--nnodes={min_nodes}:{max_nodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        launch.main(args)

        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        self.assertSetEqual({str(i)
                             for i in range(world_size)},
                            set(os.listdir(self.test_dir)))

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_elastic_worker_raise_exception(self):
        """
        Asserts that when the worker program fails and lancher raieses exception
        to indicate that worker process failed

        """
        run_id = str(uuid.uuid4().int)
        min_nodes = 1
        max_nodes = 2
        nproc_per_node = 4
        args = [
            f"--nnodes={min_nodes}:{max_nodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--max_restarts=0",
            "--start_method=fork",
            path("bin/test_script.py"),
            "--fail",
        ]
        proc = mp.Process(target=launch_in_proc, args=(args, ))
        proc.start()
        proc.join()
        self.assertEqual(1, proc.exitcode)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    @mock.patch(
        "torchelastic.agent.server.local_elastic_agent.LocalElasticAgent.run")
    def test_launch_elastic_agent_raise_exception(self, mock_agent_run):
        """
        Asserts that when the agent raises an exception
        the launcher re-raises the original exception
        """
        run_id = str(uuid.uuid4().int)
        min_nodes = 1
        max_nodes = 2
        nproc_per_node = 4
        args = [
            f"--nnodes={min_nodes}:{max_nodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--max_restarts=0",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]

        mock_agent_run.side_effect = MockException
        with self.assertRaises(MockException):
            launch.main(args)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_standalone(self):
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--standalone",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        launch.main(args)

        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        self.assertSetEqual({str(i)
                             for i in range(world_size)},
                            set(os.listdir(self.test_dir)))

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_elastic_multiple_agents(self):
        run_id = str(uuid.uuid4().int)
        min_nodes = 1
        max_nodes = 2
        nproc_per_node = 4
        nnodes = 2
        world_size = nnodes * nproc_per_node
        args = [
            f"--nnodes={min_nodes}:{max_nodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        procs = []
        for _ in range(nnodes - 1):
            p = mp.Process(target=launch.main, args=[args])
            procs.append(p)
            p.start()
        launch.main(args)
        for i in range(nnodes - 1):
            p = procs[i]
            p.join()
            self.assertEqual(0, p.exitcode)

        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        self.assertSetEqual({str(i)
                             for i in range(world_size)},
                            set(os.listdir(self.test_dir)))

    def test_min_max_nodes_parse(self):
        min_nodes, max_nodes = launch.parse_min_max_nnodes("1")
        self.assertEqual(1, min_nodes)
        self.assertEqual(1, max_nodes)
        min_nodes, max_nodes = launch.parse_min_max_nnodes("2:20")
        self.assertEqual(2, min_nodes)
        self.assertEqual(20, max_nodes)
        with self.assertRaises(RuntimeError):
            launch.parse_min_max_nnodes("2:20:30")

    @patch("torchelastic.distributed.launch.LocalElasticAgent")
    def test_launch_rdzv_shutdown(self, agent_mock_cls):
        nnodes = 1
        nproc_per_node = 4
        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        agent_mock = Mock()
        agent_mock.run.return_value = RunResult(WorkerState.SUCCEEDED)
        agent_mock_cls.return_value = agent_mock
        rdzv_handler_mock = Mock()
        with patch("torchelastic.rendezvous.registry.get_rendezvous_handler"
                   ) as param_mock:
            param_mock.return_value = rdzv_handler_mock
            launch.main(args)
            rdzv_handler_mock.shutdown.assert_called_once()
# Example #3
class LaunchTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # start a standalone, single process etcd server to use for all tests
        cls._etcd_server = EtcdServer()
        cls._etcd_server.start()
        cls._etcd_endpoint = cls._etcd_server.get_endpoint()

    @classmethod
    def tearDownClass(cls):
        # stop the standalone etcd server
        cls._etcd_server.stop()

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_user_script_python(self):
        run_id = str(uuid.uuid4().int)
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        launch.main(args)

        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_user_script_bash(self):
        run_id = str(uuid.uuid4().int)
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node

        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            "--no_python",
        ]

        script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]

        with self.assertRaises(ValueError):
            # --no_python cannot be used with --module
            launch.main(args + ["--module"] + script_args)

        launch.main(args + script_args)

        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    # @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_wrapper_fn_kill_script_process(self):
        """
        tests that the wrapper_fn properly terminates
        the script process (the script process is the sub_sub_process of
        the agent
        """
        nprocs = 2
        sleep = 300

        # wraps wrapper_fn to be torch.multiprocessing compatible,
        # which requires rank to be passed as the first argument
        def wrap_wrap(rank, *args):
            launch.wrapper_fn(*args)

        context = start_processes(
            fn=wrap_wrap,
            args=(None, (path("bin/sleep_script.py"), "--sleep", f"{sleep}")),
            nprocs=nprocs,
            join=False,
            start_method="fork",
        )
        # quick check to see that the wrapper_fn started running;
        # without this join() call, typos and other silly mistakes
        # fail silently (no exception is surfaced)
        context.join(timeout=-1)

        script_pids = []
        for wrapper_fn_pid in context.pids():
            script_pid = get_child_pids(wrapper_fn_pid)
            # there should only be one child of wrapper_fn
            self.assertEqual(1, len(script_pid))
            script_pids.append(script_pid[0])

        for wrapper_fn_proc in context.processes:
            wrapper_fn_proc.terminate()
            wrapper_fn_proc.join()

        for script_pid in script_pids:
            self.assertFalse(pid_exists(script_pid))

    def _test_nproc_launch_configuration(self, nproc_type, expected_number):
        run_id = str(uuid.uuid4().int)
        nnodes = 1

        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_type}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            "--no_python",
        ]

        script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]

        launch.main(args + script_args)

        world_size = nnodes * expected_number
        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_nproc_launch_auto_configurations(self):
        self._test_nproc_launch_configuration("auto", os.cpu_count())

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_nproc_launch_number_configurations(self):
        self._test_nproc_launch_configuration("4", 4)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_nproc_launch_unknown_configurations(self):
        with self.assertRaises(ValueError):
            self._test_nproc_launch_configuration("unknown", 4)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    @patch("torch.cuda.is_available", return_value=True)
    @patch("torch.cuda.device_count", return_value=3)
    def test_nproc_gpu_launch_configurations(self, _mock1, _mock2):
        self._test_nproc_launch_configuration("auto", 3)
        self._test_nproc_launch_configuration("gpu", 3)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_elastic(self):
        run_id = str(uuid.uuid4().int)
        min_nodes = 1
        max_nodes = 2
        nproc_per_node = 4
        # we are only launching 1 node (even though max = 2)
        world_size = nproc_per_node
        args = [
            f"--nnodes={min_nodes}:{max_nodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        launch.main(args)

        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_standalone(self):
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node
        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--standalone",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        launch.main(args)

        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_launch_elastic_multiple_agents(self):
        run_id = str(uuid.uuid4().int)
        min_nodes = 1
        max_nodes = 2
        nproc_per_node = 4
        nnodes = 2
        world_size = nnodes * nproc_per_node
        args = [
            f"--nnodes={min_nodes}:{max_nodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--rdzv_backend=etcd",
            f"--rdzv_endpoint={self._etcd_endpoint}",
            f"--rdzv_id={run_id}",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        procs = []
        for _ in range(nnodes - 1):
            p = mp.Process(target=launch.main, args=[args])
            procs.append(p)
            p.start()
        launch.main(args)
        for i in range(nnodes - 1):
            p = procs[i]
            p.join()
            self.assertEqual(0, p.exitcode)

        # make sure all the workers ran
        # each worker touches a file with its global rank as the name
        self.assertSetEqual(
            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
        )

    def test_min_max_nodes_parse(self):
        min_nodes, max_nodes = launch.parse_min_max_nnodes("1")
        self.assertEqual(1, min_nodes)
        self.assertEqual(1, max_nodes)
        min_nodes, max_nodes = launch.parse_min_max_nnodes("2:20")
        self.assertEqual(2, min_nodes)
        self.assertEqual(20, max_nodes)
        with self.assertRaises(RuntimeError):
            launch.parse_min_max_nnodes("2:20:30")

    @patch("torchelastic.distributed.launch.LocalElasticAgent")
    def test_launch_rdzv_shutdown(self, _):
        nnodes = 1
        nproc_per_node = 4
        args = [
            f"--nnodes={nnodes}",
            f"--nproc_per_node={nproc_per_node}",
            "--monitor_interval=1",
            "--start_method=fork",
            path("bin/test_script.py"),
            f"--touch_file_dir={self.test_dir}",
        ]
        rdzv_handler_mock = Mock()
        with patch(
            "torchelastic.rendezvous.registry.get_rendezvous_handler"
        ) as param_mock:
            param_mock.return_value = rdzv_handler_mock
            launch.main(args)
            rdzv_handler_mock.shutdown.assert_called_once()
# Example #4
class LocalElasticAgentTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # start a standalone, single process etcd server to use for all tests
        cls._etcd_server = EtcdServer()
        cls._etcd_server.start()

    @classmethod
    def tearDownClass(cls):
        # stop the standalone etcd server
        cls._etcd_server.stop()

    def setUp(self):
        # clear env vars
        os.environ.pop("TORCHELASTIC_ERROR_FILE", None)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_happy_function(self):
        spec = self._get_worker_spec(fn=_happy_function)
        agent = LocalElasticAgent(spec, start_method="fork")
        agent.run()

    def _get_worker_spec(
        self,
        fn=None,
        cmd=None,
        args=(),
        max_restarts=1,
        num_agents=1,
        monitor_interval=0.1,
        local_world_size=8,
    ):
        run_id = str(uuid.uuid4().int)

        rdzv_params = RendezvousParameters(
            backend="etcd",
            endpoint=f"{self._etcd_server.get_endpoint()}",
            run_id=run_id,
            min_nodes=num_agents,
            max_nodes=num_agents,
        )
        rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_params)
        spec = WorkerSpec(
            role="test_trainer",
            local_world_size=local_world_size,
            fn=fn,
            cmd=cmd,
            args=args,
            rdzv_handler=rdzv_handler,
            max_restarts=max_restarts,
            monitor_interval=monitor_interval,
        )
        return spec

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_check_role_name(self):
        spec = self._get_worker_spec(fn=_get_env_var, args=("ROLE_NAME",))
        agent = LocalElasticAgent(spec, start_method="fork")
        group_result = agent.run()
        results = group_result.return_values
        for role_name in results.values():
            self.assertEqual(spec.role, role_name)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_distributed_sum(self):
        spec = self._get_worker_spec(fn=_distributed_sum, args=(0,))
        agent = LocalElasticAgent(spec, start_method="fork")
        agent.run()

    class RoleConfig:
        __slots__ = ["role", "workers", "num_agents", "workers_num", "role_size"]

        def __init__(
            self, role: str, workers=None, num_agents: int = 0, workers_num: int = 0
        ):
            self.role = role
            self.workers = workers
            if workers_num != 0 and num_agents != 0:
                self.workers = [workers_num] * num_agents
            self.role_size = sum(self.workers)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_correct_rank_assignment_heterogeneous(self):
        roles_config = [
            self.RoleConfig("trainer", workers=[1, 2, 3, 4]),
            self.RoleConfig("ps", workers=[5, 2]),
            # split configuration to run the last one on the main process
            self.RoleConfig("master", workers=[8]),
        ]
        self.run_configuration(roles_config, 25)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_correct_rank_assignment_homogeneous(self):
        num_workers = 4
        roles_config = [
            self.RoleConfig("trainer", num_agents=4, workers_num=num_workers),
            self.RoleConfig("ps", num_agents=2, workers_num=num_workers),
            # split configuration to run the last one on the main process
            self.RoleConfig("master", num_agents=1, workers_num=num_workers),
        ]
        self.run_configuration(roles_config, 28)

    def run_configuration(self, roles_config, expected_world_size):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = sum(len(cfg.workers) for cfg in roles_config)
        run_id = str(uuid.uuid4().int)

        procs = []
        manager = multiprocessing.Manager()
        return_dict = manager.dict()
        default_args = (run_id, host, port, nnodes, nnodes, _check_rank_assignment, ())
        for ind in range(len(roles_config) - 1):
            config = roles_config[ind]
            for num_workers in config.workers:
                p = multiprocessing.Process(
                    target=_run_agent,
                    args=(*default_args, num_workers, config.role, return_dict),
                )
                procs.append(p)
                p.start()

        # run one on the main process for debugging
        config = roles_config[len(roles_config) - 1]
        _run_agent(*default_args, config.workers[0], config.role, return_dict)

        for i in range(nnodes - 1):
            p = procs[i]
            p.join()
            self.assertEqual(0, p.exitcode)
        role_info_dict = {role_info.role: role_info for role_info in roles_config}
        self.verify_rank_consistency(return_dict, role_info_dict, expected_world_size)

    def verify_rank_consistency(self, return_dict, role_info_dict, expected_world_size):
        role_ranks = {}
        global_ranks = []
        grouped_ranks = {}
        for role, group_result in return_dict.values():
            res = group_result.return_values
            for (
                group_rank,
                rank,
                world_size,
                role_rank,
                role_world_size,
            ) in res.values():
                role_info_config = role_info_dict[role]
                self.assertEqual(expected_world_size, world_size)
                self.assertEqual(role_info_config.role_size, role_world_size)
                if group_rank not in grouped_ranks:
                    grouped_ranks[group_rank] = []
                grouped_ranks[group_rank].append((rank, role_rank))
                global_ranks.append(rank)
                if role not in role_ranks:
                    role_ranks[role] = []
                role_ranks[role].append(role_rank)
        global_ranks = sorted(global_ranks)
        self.assertEqual(list(range(0, expected_world_size)), global_ranks)
        for role, role_config_info in role_info_dict.items():
            self.assertEqual(
                list(range(0, role_config_info.role_size)), sorted(role_ranks[role])
            )
        # Make sure that each agent assigns consecutive ranks to workers.
        # The first element of each pair is the global rank and the second
        # is the role rank.
        for ranks_lst in grouped_ranks.values():
            self.verify_ranks_sequential(ranks_lst, 0)
            self.verify_ranks_sequential(ranks_lst, 1)

    def verify_ranks_sequential(self, ranks_pairs, rank_idx):
        ranks = sorted(rank_pair[rank_idx] for rank_pair in ranks_pairs)
        start_rank, end_rank = ranks[0], ranks[-1]
        self.assertEqual(list(range(start_rank, end_rank + 1)), ranks)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_distributed_sum_heterogenous(self):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = 4
        run_id = str(uuid.uuid4().int)

        procs = []
        default_args = (run_id, host, port, nnodes, nnodes, _distributed_sum, (0,))
        for ind in range(nnodes - 1):
            p = multiprocessing.Process(
                target=_run_agent, args=(*default_args, ind + 1)
            )
            procs.append(p)
            p.start()

        # run one on the main process for debugging
        _run_agent(*default_args, 8)

        for i in range(nnodes - 1):
            p = procs[i]
            p.join()
            self.assertEqual(0, p.exitcode)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_sad_function(self):
        self._test_run_sad_function()

    @record
    def _test_run_sad_function(self):
        spec = self._get_worker_spec(fn=_sad_function, max_restarts=0)
        agent = LocalElasticAgent(spec, start_method="fork")
        group_results = agent.run()
        failed_results = group_results.failures
        self.assertEqual(spec.local_world_size, len(failed_results))
        # all ranks will have the same result
        for result in failed_results.values():
            self.assertTrue(os.path.exists(result.error_file))
            with open(result.error_file, "r") as f:
                data = f.read().replace("\n", "")
                self.assertTrue("RuntimeError: sad because i throw" in data)

        self.assertEqual(WorkerState.FAILED, agent.get_worker_group().state)
        self.assertEqual(0, agent._remaining_restarts)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_bipolar_function(self):
        spec = self._get_worker_spec(fn=_bipolar_function, max_restarts=2)
        agent = LocalElasticAgent(spec, start_method="fork")
        agent.run()
        self.assertEqual(WorkerState.FAILED, agent.get_worker_group().state)
        self.assertEqual(0, agent._remaining_restarts)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_check_env_function(self):
        spec = self._get_worker_spec(fn=_check_env_function, max_restarts=2)
        agent = LocalElasticAgent(spec, start_method="fork")
        agent.run()

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_check_run_id(self):
        def return_run_id():
            return os.environ["TORCHELASTIC_RUN_ID"]

        spec = self._get_worker_spec(fn=return_run_id, max_restarts=0)
        agent = LocalElasticAgent(spec, start_method="fork")
        group_result = agent.run()
        results = group_result.return_values

        for i in range(spec.local_world_size):
            self.assertEqual(spec.rdzv_handler.get_run_id(), results[i])

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_get_worker_return_values(self):
        spec = self._get_worker_spec(fn=_return_rank_times, args=(2,))
        agent = LocalElasticAgent(spec, start_method="fork")
        group_result = agent.run()
        results = group_result.return_values

        self.assertEqual(spec.local_world_size, len(results))
        for i in range(spec.local_world_size):
            self.assertEqual(i * 2, results[i])

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_double_agent_happy(self):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = 2
        run_id = str(uuid.uuid4().int)

        procs = []
        for _ in range(nnodes - 1):
            p = multiprocessing.Process(
                target=_run_agent,
                args=(run_id, host, port, nnodes, nnodes, _distributed_sum, (0,)),
            )
            procs.append(p)
            p.start()

        # run one on the main process for debugging
        _run_agent(run_id, host, port, nnodes, nnodes, _distributed_sum, (0,))

        for i in range(nnodes - 1):
            p = procs[i]
            p.join()
            self.assertEqual(0, p.exitcode)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_double_agent_fault_tolerance(self):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = 2
        run_id = str(uuid.uuid4().int)

        procs = []
        for _ in range(nnodes):
            p = multiprocessing.Process(
                target=_run_agent,
                args=(run_id, host, port, nnodes, nnodes, _distributed_sum, (0,)),
            )
            procs.append(p)
            p.start()

        # restart odd agents
        for i in range(nnodes):
            if i % 2 != 0:
                procs[i].kill()
                p = multiprocessing.Process(
                    target=_run_agent,
                    args=(run_id, host, port, nnodes, nnodes, _distributed_sum, (0,)),
                )
                procs[i] = p
                p.start()

        for i in range(nnodes):
            p = procs[i]
            p.join()
            self.assertEqual(0, p.exitcode)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_double_agent_elastic(self):
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        min_size = 1
        max_size = 2
        run_id = str(uuid.uuid4().int)

        procs = []
        for _ in range(max_size):
            p = multiprocessing.Process(
                target=_run_agent,
                args=(run_id, host, port, min_size, max_size, _distributed_sum, (0,)),
            )
            procs.append(p)
            p.start()

        # kill odd agents
        for i in range(max_size):
            if i % 2 != 0:
                procs[i].kill()

        for i in range(max_size):
            if i % 2 == 0:
                p = procs[i]
                p.join()
                self.assertEqual(0, p.exitcode)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_torch_rpc(self):
        """
        Simple torch rpc example with torchelastic.
        Creates two agents (to simulate two node job),
        each agent runs a single worker. worker0 calls an rpc_sync on
        worker1.
        """

        # TODO upstream this to torch.distributed.rpc so that users do not have
        # to redundantly set rank as part of name (e.g. worker0) AND also pass
        # it explicitly as an argument to rpc.init_rpc
        def init_rpc(name_prefix, backend):
            rank = int(os.environ["RANK"])
            world_size = int(os.environ["WORLD_SIZE"])
            rpc.init_rpc(
                name=f"{name_prefix}{rank}",
                backend=backend,
                rank=rank,
                world_size=world_size,
            )

        def worker_0(queue, msg):
            init_rpc("worker", BackendType.PROCESS_GROUP)
            ret = rpc.rpc_sync(to="worker1", func=echo, args=(msg,))
            queue.put(ret)
            rpc.shutdown()

        def worker_1():
            init_rpc("worker", BackendType.PROCESS_GROUP)
            rpc.shutdown()

        def run_agent(
            run_id, etcd_host, etcd_port, start_method, worker_fn, worker_args=()
        ):
            rdzv_params = RendezvousParameters(
                backend="etcd",
                endpoint=f"{etcd_host}:{etcd_port}",
                run_id=run_id,
                min_nodes=2,
                max_nodes=2,
            )
            rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_params)

            spec = WorkerSpec(
                role="test_trainer",
                local_world_size=1,
                fn=worker_fn,
                args=worker_args,
                rdzv_handler=rdzv_handler,
                max_restarts=3,
                monitor_interval=1,
            )

            agent = LocalElasticAgent(spec, start_method)
            agent.run()

        run_id = str(uuid.uuid4().int)
        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        start_method = "fork"
        msg = "hello world"
        mp_queue = multiprocessing.get_context(start_method).Queue()

        agent0 = multiprocessing.Process(
            target=run_agent,
            args=(run_id, host, port, start_method, worker_0, (mp_queue, msg)),
        )
        agent1 = multiprocessing.Process(
            target=run_agent, args=(run_id, host, port, start_method, worker_1, ())
        )

        agent0.start()
        agent1.start()

        agent0.join()
        agent1.join()

        self.assertEqual(msg, mp_queue.get())

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_workers_drift_success(self):

        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = 2
        run_id = str(uuid.uuid4().int)

        procs = []
        default_args = (run_id, host, port, nnodes, nnodes, _simulate_work)
        for _ in range(nnodes - 1):
            p = multiprocessing.Process(
                target=_run_agent,
                args=(*default_args, (10,), 2, "test_trainer", {}, 30),
            )
            procs.append(p)
            p.start()

        _run_agent(*default_args, (1,), 2, "test_trainer", {}, 30)

        for i in range(nnodes - 1):
            p = procs[i]
            p.join()
            self.assertEqual(0, p.exitcode)

    @patch("torchelastic.utils.store.barrier")
    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_workers_drift_fail(self, barrier_mock):

        host = self._etcd_server.get_host()
        port = self._etcd_server.get_port()
        nnodes = 2
        run_id = str(uuid.uuid4().int)

        procs = []
        default_args = (run_id, host, port, nnodes, nnodes, _simulate_work)
        for _ in range(nnodes - 1):
            p = multiprocessing.Process(
                target=_run_agent,
                args=(*default_args, (60,), 2, "test_trainer", {}, 10),
            )
            procs.append(p)
            p.start()

        _run_agent(*default_args, (1,), 2, "test_trainer", {}, 10)
        barrier_mock.assert_called_once()

    @patch("torchelastic.utils.store.barrier")
    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_barrier_failed(self, barrier_mock):
        barrier_mock.side_effect = RuntimeError("test error")
        spec = self._get_worker_spec(fn=_happy_function)
        agent = LocalElasticAgent(spec, start_method="fork")
        agent.run()
        barrier_mock.assert_called_once()

    def test_provide_fn_and_cmd(self):
        with self.assertRaises(AssertionError):
            self._get_worker_spec(
                fn=_bipolar_function, cmd=["test.bin"], max_restarts=2
            )

    def test_provide_none(self):
        with self.assertRaises(AssertionError):
            self._get_worker_spec(max_restarts=2)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_failed_result_with_run_id(self):
        temp_dir = tempfile.mkdtemp()
        os.environ["TORCHELASTIC_ERROR_FILE"] = f"{temp_dir}/error.log"
        self._test_failed_result_with_run_id()
        shutil.rmtree(temp_dir)

    @record
    def _test_failed_result_with_run_id(self):
        max_restarts = 3
        spec = self._get_worker_spec(fn=_sad_function, max_restarts=max_restarts)
        agent = LocalElasticAgent(spec, start_method="fork")
        run_result = agent.run()
        for failure in run_result.failures.values():
            error_file = failure.error_file
            self.assertTrue(error_file.endswith(f"_{max_restarts}"))

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_transient_bug(self):
        temp_dir = tempfile.mkdtemp()
        os.environ["TORCHELASTIC_ERROR_FILE"] = f"{temp_dir}/error.log"
        self._test_transient_bug(temp_dir)
        shutil.rmtree(temp_dir)

    @record
    def _test_transient_bug(self, error_dir: str):
        max_restarts = 3
        spec = self._get_worker_spec(fn=_transient_bug, max_restarts=max_restarts)
        agent = LocalElasticAgent(spec, start_method="fork")
        run_result = agent.run()
        self.assertEqual(WorkerState.SUCCEEDED, run_result.state)
        for rank in range(len(run_result.return_values)):
            error_file_0 = os.path.join(error_dir, str(rank), "error.log_0")
            self.assertTrue(os.path.exists(error_file_0))
            error_file_1 = os.path.join(error_dir, str(rank), "error.log_1")
            self.assertFalse(os.path.exists(error_file_1))
# Example #5
class LocalTimerTest(unittest.TestCase):
    def setUp(self):
        self.mp_queue = mp.Queue()
        self.max_interval = 0.01
        self.server = timer.LocalTimerServer(self.mp_queue, self.max_interval)
        self.server.start()

    def tearDown(self):
        self.server.stop()

    def test_exception_propagation(self):
        with self.assertRaisesRegex(Exception, "foobar"):
            with timer.expires(after=1):
                raise Exception("foobar")

    def test_no_client(self):
        # no timer client configured; exception expected
        timer.configure(None)
        with self.assertRaises(RuntimeError):
            with timer.expires(after=1):
                pass

    def test_client_interaction(self):
        # no timer client configured but one passed in explicitly
        # no exception expected
        timer_client = timer.LocalTimerClient(self.mp_queue)
        timer_client.acquire = mock.MagicMock(wraps=timer_client.acquire)
        timer_client.release = mock.MagicMock(wraps=timer_client.release)
        with timer.expires(after=1, scope="test", client=timer_client):
            pass

        timer_client.acquire.assert_called_once_with("test", mock.ANY)
        timer_client.release.assert_called_once_with("test")

    def test_happy_path(self):
        timer.configure(timer.LocalTimerClient(self.mp_queue))
        with timer.expires(after=0.5):
            time.sleep(0.1)

    @unittest.skipIf(is_tsan(), "test is tsan incompatible")
    def test_get_timer_recursive(self):
        """
        If a function acquires a countdown timer with default scope,
        then recursive calls to the function should re-acquire the
        timer rather than creating a new one. That is only the last
        recursive call's timer will take effect.
        """
        self.server.start()
        timer.configure(timer.LocalTimerClient(self.mp_queue))

        # func should not time out
        def func(n):
            if n > 0:
                with timer.expires(after=0.1):
                    func(n - 1)
                    time.sleep(0.05)

        func(4)

        # func2 should time out
        def func2(n):
            if n > 0:
                with timer.expires(after=0.1):
                    func2(n - 1)
                    time.sleep(0.2)

        p = mp.Process(target=func2, args=(2, ))
        p.start()
        p.join()
        self.assertEqual(-signal.SIGKILL, p.exitcode)

    @staticmethod
    def _run(mp_queue, timeout, duration):
        client = timer.LocalTimerClient(mp_queue)
        timer.configure(client)

        with timer.expires(after=timeout):
            time.sleep(duration)

    @unittest.skipIf(is_tsan(), "test is tsan incompatible")
    def test_timer(self):
        timeout = 0.1
        duration = 1
        p = mp.Process(target=self._run,
                       args=(self.mp_queue, timeout, duration))
        p.start()
        p.join()
        self.assertEqual(-signal.SIGKILL, p.exitcode)
# Example #6
class LocalTimerServerTest(unittest.TestCase):
    def setUp(self):
        self.mp_queue = mp.Queue()
        self.max_interval = 0.01
        self.server = timer.LocalTimerServer(self.mp_queue, self.max_interval)

    def tearDown(self):
        self.server.stop()

    @unittest.skipIf(is_tsan(), "test is tsan incompatible")
    def test_watchdog_call_count(self):
        """
        checks that the watchdog function ran wait/interval +- 1 times
        """
        self.server._run_watchdog = mock.MagicMock(
            wraps=self.server._run_watchdog)

        wait = 0.1

        self.server.start()
        time.sleep(wait)
        self.server.stop()
        watchdog_call_count = self.server._run_watchdog.call_count
        self.assertGreaterEqual(watchdog_call_count,
                                int(wait / self.max_interval) - 1)
        self.assertLessEqual(watchdog_call_count,
                             int(wait / self.max_interval) + 1)

    def test_watchdog_empty_queue(self):
        """
        checks that the watchdog can run on an empty queue
        """
        self.server._run_watchdog()

    def _expired_timer(self, pid, scope):
        expired = time.time() - 60
        return TimerRequest(worker_id=pid,
                            scope_id=scope,
                            expiration_time=expired)

    def _valid_timer(self, pid, scope):
        valid = time.time() + 60
        return TimerRequest(worker_id=pid,
                            scope_id=scope,
                            expiration_time=valid)

    def _release_timer(self, pid, scope):
        return TimerRequest(worker_id=pid, scope_id=scope, expiration_time=-1)

    @unittest.skipIf(is_tsan(), "test is tsan incompatible")
    @mock.patch("os.kill")
    def test_expired_timers(self, mock_os_kill):
        """
        tests that a single expired timer on a process should terminate
        the process and clean up all pending timers that was owned by the process
        """
        test_pid = -3
        self.mp_queue.put(self._expired_timer(pid=test_pid, scope="test1"))
        self.mp_queue.put(self._valid_timer(pid=test_pid, scope="test2"))

        self.server._run_watchdog()

        self.assertEqual(0, len(self.server._timers))
        mock_os_kill.assert_called_once_with(test_pid, signal.SIGKILL)

    @mock.patch("os.kill")
    def test_acquire_release(self, mock_os_kill):
        """
        tests that:
          1. a timer can be acquired then released (should not terminate process)
          2. a timer can be vacuously released (e.g. no-op)
        """
        test_pid = -3
        self.mp_queue.put(self._valid_timer(pid=test_pid, scope="test1"))
        self.mp_queue.put(self._release_timer(pid=test_pid, scope="test1"))
        self.mp_queue.put(self._release_timer(pid=test_pid, scope="test2"))

        self.server._run_watchdog()

        self.assertEqual(0, len(self.server._timers))
        mock_os_kill.assert_not_called()

    @mock.patch("os.kill")
    def test_valid_timers(self, mock_os_kill):
        """
        tests that valid timers are processed correctly and the process is left alone
        """
        self.mp_queue.put(self._valid_timer(pid=-3, scope="test1"))
        self.mp_queue.put(self._valid_timer(pid=-3, scope="test2"))
        self.mp_queue.put(self._valid_timer(pid=-2, scope="test1"))
        self.mp_queue.put(self._valid_timer(pid=-2, scope="test2"))

        self.server._run_watchdog()

        self.assertEqual(4, len(self.server._timers))
        self.assertTrue((-3, "test1") in self.server._timers)
        self.assertTrue((-3, "test2") in self.server._timers)
        self.assertTrue((-2, "test1") in self.server._timers)
        self.assertTrue((-2, "test2") in self.server._timers)
        mock_os_kill.assert_not_called()
# Example #7
class LocalTimerExample(unittest.TestCase):
    """
    Demonstrates how to use LocalTimerServer and LocalTimerClient
    to enforce expiration of code-blocks.

    Since torch multiprocessing's ``start_process`` method currently
    does not take the multiprocessing context as parameter argument
    there is no way to create the mp.Queue in the correct
    context BEFORE spawning child processes. Once the ``start_process``
    API is changed in torch, then re-enable ``test_torch_mp_example``
    unittest. As of now this will SIGSEGV.
    """
    @unittest.skipIf(is_asan_or_tsan(), "test is a/tsan incompatible")
    def test_torch_mp_example(self):
        # in practice set the max_interval to a larger value (e.g. 60 seconds)
        mp_queue = mp.get_context("spawn").Queue()
        server = timer.LocalTimerServer(mp_queue, max_interval=0.01)
        server.start()

        world_size = 8

        # all processes should complete successfully
        # since start_process does NOT take context as parameter argument yet
        # this method WILL FAIL (hence the test is disabled)
        torch_mp.spawn(fn=_happy_function,
                       args=(mp_queue, ),
                       nprocs=world_size,
                       join=True)

        with self.assertRaises(Exception):
            # torch.multiprocessing.spawn kills all sub-procs
            # if one of them gets killed
            torch_mp.spawn(fn=_stuck_function,
                           args=(mp_queue, ),
                           nprocs=world_size,
                           join=True)

        server.stop()

    @unittest.skipIf(is_asan_or_tsan(), "test is a/tsan incompatible")
    def test_example_start_method_spawn(self):
        self._run_example_with(start_method="spawn")

    @unittest.skipIf(is_asan_or_tsan(), "test is a/tsan incompatible")
    def test_example_start_method_forkserver(self):
        self._run_example_with(start_method="forkserver")

    @unittest.skipIf(is_tsan(), "test is tsan incompatible")
    def test_example_start_method_fork(self):
        self._run_example_with(start_method="fork")

    def _run_example_with(self, start_method):
        spawn_ctx = mp.get_context(start_method)
        mp_queue = spawn_ctx.Queue()
        server = timer.LocalTimerServer(mp_queue, max_interval=0.01)
        server.start()

        world_size = 8
        processes = []
        for i in range(0, world_size):
            if i % 2 == 0:
                p = spawn_ctx.Process(target=_stuck_function,
                                      args=(i, mp_queue))
            else:
                p = spawn_ctx.Process(target=_happy_function,
                                      args=(i, mp_queue))
            p.start()
            processes.append(p)

        for i in range(0, world_size):
            p = processes[i]
            p.join()
            if i % 2 == 0:
                self.assertEqual(-signal.SIGKILL, p.exitcode)
            else:
                self.assertEqual(0, p.exitcode)

        server.stop()
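
# A minimal end-to-end sketch of the pattern that the LocalTimerExample class
# above demonstrates. ``_sleepy_worker`` and ``_timer_usage_sketch`` are
# hypothetical helpers, not part of the original tests: a LocalTimerServer in
# the parent watches a shared queue, each child configures a LocalTimerClient
# and wraps its work in ``timer.expires()``, and a child that overruns its
# deadline is SIGKILLed by the server's watchdog.
def _sleepy_worker(mp_queue, duration):
    timer.configure(timer.LocalTimerClient(mp_queue))
    with timer.expires(after=0.5):
        time.sleep(duration)

def _timer_usage_sketch():
    ctx = mp.get_context("spawn")
    mp_queue = ctx.Queue()
    server = timer.LocalTimerServer(mp_queue, max_interval=0.01)
    server.start()

    p = ctx.Process(target=_sleepy_worker, args=(mp_queue, 10))  # overruns the 0.5s deadline
    p.start()
    p.join()
    assert p.exitcode == -signal.SIGKILL  # killed by the watchdog
    server.stop()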
# Example #8
class MpProcessContextTest(unittest.TestCase):
    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_success(self):
        nprocs = 4
        mult = 2
        params = [MpParameters(fn=run_compute, args=(mult, ))] * nprocs
        proc_context = start_processes(params, start_method="spawn")
        ret_vals = proc_context.wait()
        while not ret_vals:
            ret_vals = proc_context.wait()
        self.assertEqual(4, len(ret_vals))
        for local_rank, ret_val in ret_vals.items():
            self.assertEqual(mult * local_rank, ret_val)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_success_no_return_func(self):
        nprocs = 4
        params = [MpParameters(fn=run_dummy, args=())] * nprocs
        proc_context = start_processes(params, start_method="spawn")
        ret_vals = proc_context.wait()
        while not ret_vals:
            ret_vals = proc_context.wait()
        self.assertEqual(4, len(ret_vals))
        for ret_val in ret_vals.values():
            self.assertEqual(None, ret_val)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_huge_output(self):
        # Python's multiprocessing.Queue is backed by pipes. This means that
        # if a single object is larger than the pipe buffer, the writer
        # process blocks until the reader process starts reading the pipe.
        # This test makes the worker fn return a huge output, around ~10 MB
        # (see the standalone sketch after this class).
        nprocs = 4
        size = 200000
        params = [MpParameters(fn=fill_dict, args=(size, ))] * nprocs
        proc_context = start_processes(params, start_method="spawn")
        ret_vals = proc_context.wait()
        while not ret_vals:
            ret_vals = proc_context.wait()
        self.assertEqual(4, len(ret_vals))
        for ret_val in ret_vals.values():
            self.assertEqual(size, len(ret_val))

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_failure(self):
        nprocs = 4
        params = [MpParameters(fn=run_failure, args=())] * nprocs
        proc_context = start_processes(params, start_method="spawn")
        with self.assertRaises(ProcessRaisedException):
            ret_vals = proc_context.wait()
            while not ret_vals:
                ret_vals = proc_context.wait()

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_termination(self):
        nprocs = 5
        params = [MpParameters(fn=run_infinite, args=())] * nprocs
        proc_context = start_processes(params, start_method="spawn")
        proc_context.terminate()
        # Processes should terminate with SIGTERM
        with self.assertRaises(Exception):
            proc_context.wait()

    def test_wait_busy_loop(self):
        nprocs = 2
        wait_time = 10  # seconds
        params = [MpParameters(fn=run_with_wait, args=(wait_time, ))] * nprocs
        proc_context = start_processes(params, start_method="spawn")
        self.assertIsNone(proc_context.wait(1))
        while proc_context.wait(1) is None:
            pass

    def test_wrap_fn(self):
        nprocs = 2
        start_method = "spawn"
        out_queues: Dict[int, mp.SimpleQueue] = {
            i: mp.get_context(start_method).SimpleQueue()
            for i in range(0, nprocs)
        }
        params = [MpParameters(fn=run_compute, args=(1, ))] * nprocs
        for idx in range(nprocs):
            _wrap(idx, params, out_queues)
        for idx, out_q in out_queues.items():
            self.assertFalse(out_q.empty(), "out queue should not be empty")
            self.assertEqual(idx, out_q.get())
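
# A standalone sketch (hypothetical helpers, not part of the original tests) of
# the pipe behavior that test_run_huge_output above exercises: a multiprocessing
# pipe has a bounded OS buffer, so sending an object larger than that buffer
# blocks the writer until the reader starts draining the other end.
import multiprocessing

def _pipe_writer(conn):
    # ~10 MB payload; send() only completes once the reader consumes the pipe
    conn.send("x" * (10 * 1024 * 1024))
    conn.close()

def _pipe_blocking_sketch():
    parent_conn, child_conn = multiprocessing.Pipe()
    p = multiprocessing.Process(target=_pipe_writer, args=(child_conn,))
    p.start()
    payload = parent_conn.recv()  # unblocks the writer
    p.join()
    assert len(payload) == 10 * 1024 * 1024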
# Example #9
class MpProcessContextTest(unittest.TestCase):
    def setUp(self):
        self.test_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_success(self):
        nprocs = 4
        mult = 2
        params = [MpParameters(fn=run_compute, args=(mult,))] * nprocs
        proc_context = start_processes(params, start_method="spawn")
        proc_group_result = self._get_result(proc_context)
        ret_vals = proc_group_result.return_values
        self.assertEqual(4, len(ret_vals))
        for local_rank, ret_val in ret_vals.items():
            self.assertEqual(mult * local_rank, ret_val)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_success_no_return_func(self):
        nprocs = 4
        params = [MpParameters(fn=run_dummy, args=())] * nprocs
        proc_context = start_processes(params, start_method="spawn")
        proc_group_result = self._get_result(proc_context)
        ret_vals = proc_group_result.return_values
        self.assertEqual(4, len(ret_vals))
        for ret_val in ret_vals.values():
            self.assertEqual(None, ret_val)

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_huge_output(self):
        # Python's multiprocessing.Queue is backed by pipes. This means that
        # if a single object is larger than the pipe buffer, the writer
        # process blocks until the reader process starts reading the pipe.
        # This test makes the worker fn return a huge output, around ~10 MB
        nprocs = 4
        size = 200000
        params = [MpParameters(fn=fill_dict, args=(size,))] * nprocs
        proc_context = start_processes(params, start_method="spawn")
        proc_group_result = self._get_result(proc_context)
        ret_vals = proc_group_result.return_values
        self.assertEqual(4, len(ret_vals))
        for ret_val in ret_vals.values():
            self.assertEqual(size, len(ret_val))

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_run_failure(self):
        os.environ["TORCHELASTIC_ERROR_FILE"] = f"{self.test_dir}/error.log"
        _process_error_handler.configure()
        nprocs = 4
        params = [MpParameters(fn=run_failure, args=())] * nprocs
        proc_context = start_processes(params, start_method="spawn")
        proc_group_result = self._get_result(proc_context)
        failed_result = proc_group_result.failure
        self.assertTrue(os.path.exists(failed_result.error_file))
        with open(failed_result.error_file, "r") as f:
            data = f.read().replace("\n", "")
        self.assertTrue("RuntimeError: Test error" in data)
        _process_error_handler.cleanup()

    def _get_result(self, proc_context) -> ProcessGroupResult:
        proc_group_result = proc_context.wait()
        while not proc_group_result:
            proc_group_result = proc_context.wait()
        return proc_group_result

    @unittest.skipIf(is_tsan(), "test incompatible with tsan")
    def test_failure_signal(self):
        os.environ["TORCHELASTIC_ERROR_FILE"] = f"{self.test_dir}/error.log"
        _process_error_handler.configure()
        nprocs = 5
        params = [MpParameters(fn=run_failure_signal, args=())] * nprocs
        proc_context = start_processes(params, start_method="spawn")
        # Processes should terminate with SIGSEGV
        proc_group_result = proc_context.wait()
        failure = proc_group_result.failure
        self.assertTrue(os.path.exists(failure.error_file))
        self.assertEqual("SIGSEGV", failure.get_signal_name())
        with open(failure.error_file, "r") as f:
            data = f.read().replace("\n", "")
        self.assertTrue("string_at" in data)
        _process_error_handler.cleanup()

    def test_wait_busy_loop(self):
        nprocs = 2
        wait_time = 10  # seconds
        params = [MpParameters(fn=run_with_wait, args=(wait_time,))] * nprocs
        proc_context = start_processes(params, start_method="spawn")
        self.assertIsNone(proc_context.wait(1))
        while proc_context.wait(1) is None:
            pass

    def test_wrap_fn(self):
        nprocs = 2
        start_method = "spawn"
        out_queues: Dict[int, mp.SimpleQueue] = {
            i: mp.get_context(start_method).SimpleQueue() for i in range(0, nprocs)
        }
        error_files = ["error.log" for _ in range(0, nprocs)]
        params = [MpParameters(fn=run_compute, args=(1,))] * nprocs
        for idx in range(nprocs):
            _wrap(idx, error_files, params, out_queues)
        for idx, out_q in out_queues.items():
            self.assertFalse(out_q.empty(), "out queue should not be empty")
            self.assertEqual(idx, out_q.get())