Esempio n. 1
0
    def test_invalid_log_dir(self):
        # start_processes is expected to raise a specific error for each
        # unusable log_dir value tried below.
        with tempfile.NamedTemporaryFile(dir=self.test_dir) as regular_file:
            # maps each bad log_dir to the exception type the test expects
            expectations = {
                "does_not_exist": FileNotFoundError,
                regular_file.name: NotADirectoryError,
                # test_dir is non-empty because the temp file above lives in it
                self.test_dir: RuntimeError,
            }

            for bad_dir, error_type in expectations.items():
                with self.subTest(log_dir=bad_dir, expected_error=error_type):
                    worker_args = {0: ("hello",)}
                    worker_envs = {0: {"RANK": "0"}}
                    with self.assertRaises(error_type):
                        start_processes(
                            name="echo",
                            entrypoint=echo1,
                            args=worker_args,
                            envs=worker_envs,
                            log_dir=bad_dir,
                        )
Esempio n. 2
0
    def test_args_env_len_mismatch(self):
        # start_processes is expected to raise RuntimeError when the args and
        # envs mappings do not have the same number of entries.
        mismatched = (
            # one args entry, two envs entries
            ({0: ("hello",)}, {0: {"RANK": "0"}, 1: {"RANK": "1"}}),
            # two args entries, one envs entry
            ({0: ("hello",), 1: ("world",)}, {0: {"RANK": "0"}}),
        )

        for args, envs in mismatched:
            with self.subTest(args=args, envs=envs):
                with self.assertRaises(RuntimeError):
                    start_processes(
                        name="echo",
                        entrypoint=echo1,
                        args=args,
                        envs=envs,
                        log_dir=self.log_dir(),
                    )
Esempio n. 3
0
 def test_binary_incorrect_entrypoint(self):
     # Launching a binary entrypoint whose path does not exist is expected
     # to raise FileNotFoundError.
     with self.assertRaises(FileNotFoundError):
         start_processes(
             name="echo",
             entrypoint="does_not_exist.py",
             # Each args value must be a tuple; ("foo") without the trailing
             # comma is just the string "foo".
             args={0: ("foo",), 1: ("bar",)},
             envs={0: {}, 1: {}},
             log_dir=self.log_dir(),
         )
Esempio n. 4
0
    def test_binary_exit(self):
        # Launch two copies of the echo1.py binary: rank 0 is passed
        # "--exitcode FAIL", rank 1 is passed "--exitcode 0". Only rank 0 has
        # its standard streams redirected (Std.ALL).
        FAIL = 138
        per_rank_args = {
            0: ("--exitcode", FAIL, "foo"),
            1: ("--exitcode", 0, "bar"),
        }
        per_rank_envs = {0: {"RANK": "0"}, 1: {"RANK": "1"}}
        pc = start_processes(
            name="echo",
            entrypoint=self.bin("echo1.py"),
            args=per_rank_args,
            envs=per_rank_envs,
            log_dir=self.log_dir(),
            redirects={0: Std.ALL},
        )

        results = pc.wait(period=0.1)

        # exactly one failure is expected to be reported
        self.assertTrue(results.is_failed())
        self.assertEqual(1, len(results.failures))

        rank0_failure = results.failures[0]
        self.assertEqual(FAIL, rank0_failure.exitcode)
        self.assertEqual("<N/A>", rank0_failure.signal_name())
        self.assertEqual("<NONE>", rank0_failure.error_file_data["message"])

        # rank 0 streams were redirected; rank 1 has no stream files
        self.assert_in_file([f"exit {FAIL} from 0"], results.stderrs[0])
        self.assert_in_file([], results.stdouts[0])
        self.assertFalse(results.stderrs[1])
        self.assertFalse(results.stdouts[1])

        # both tail readers must be stopped once wait() has returned
        self.assertTrue(pc._stderr_tail.stopped())
        self.assertTrue(pc._stdout_tail.stopped())
Esempio n. 5
0
    def test_binary(self):
        # Run two copies of the echo1.py binary under every redirect setting
        # and check that output files exist only for the redirected streams.
        for redirs in redirects():
            with self.subTest(redirs=redirs):
                pc = start_processes(
                    name="echo",
                    entrypoint=self.bin("echo1.py"),
                    args={0: ("hello",), 1: ("hello",)},
                    envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
                    log_dir=self.log_dir(),
                    redirects=redirs,
                )

                results = pc.wait(period=0.1)

                self.assert_pids_noexist(pc.pids())
                # a binary is not a function, so it yields no return values
                self.assertEqual(0, len(results.return_values))
                self.assertFalse(results.is_failed())

                stdout_redirected = (redirs & Std.OUT) == Std.OUT
                stderr_redirected = (redirs & Std.ERR) == Std.ERR
                for rank in range(pc.nprocs):
                    if not stdout_redirected:
                        self.assertFalse(results.stdouts[rank])
                    if not stderr_redirected:
                        self.assertFalse(results.stderrs[rank])
                    if stdout_redirected:
                        expected_out = [f"hello stdout from {rank}"]
                        self.assert_in_file(expected_out, results.stdouts[rank])
                    if stderr_redirected:
                        expected_err = [f"hello stderr from {rank}"]
                        self.assert_in_file(expected_err, results.stderrs[rank])
Esempio n. 6
0
    def test_function_signal(self):
        """
        run 2x copies of echo3, induce a segfault on first

        Repeated for every (start method, redirect setting) combination; the
        failure is expected to be reported with exit code -SIGSEGV and an
        error file path under the rank-0 log directory.
        """
        SEGFAULT = True
        for start_method, redirs in product(start_methods(), redirects()):
            # Include redirs in the subtest parameters so that a failure can
            # be attributed to the exact combination that produced it.
            with self.subTest(start_method=start_method, redirs=redirs):
                log_dir = self.log_dir()
                pc = start_processes(
                    name="echo",
                    entrypoint=echo3,
                    args={0: ("hello", SEGFAULT), 1: ("world",)},
                    envs={0: {}, 1: {}},
                    log_dir=log_dir,
                    start_method=start_method,
                    redirects=redirs,
                )

                results = pc.wait(period=0.1)

                self.assert_pids_noexist(pc.pids())
                self.assertEqual(1, len(results.failures))
                self.assertFalse(results.return_values)

                failure = results.failures[0]
                error_file = failure.error_file

                self.assertEqual(-signal.SIGSEGV, failure.exitcode)
                self.assertEqual("SIGSEGV", failure.signal_name())
                self.assertEqual(pc.pids()[0], failure.pid)
                self.assertEqual(os.path.join(log_dir, "0", "error.json"), error_file)
Esempio n. 7
0
    def test_function(self):
        # Run two copies of echo1 under every (start method, redirect setting)
        # combination; check return values and which stream files exist.
        for start_method, redirs in product(start_methods(), redirects()):
            with self.subTest(start_method=start_method, redirs=redirs):
                pc = start_processes(
                    name="echo",
                    entrypoint=echo1,
                    args={0: ("hello",), 1: ("hello",)},
                    envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
                    log_dir=self.log_dir(),
                    start_method=start_method,
                    redirects=redirs,
                )

                results = pc.wait(period=0.1)
                nprocs = pc.nprocs

                self.assert_pids_noexist(pc.pids())
                expected_returns = {rank: f"hello_{rank}" for rank in range(nprocs)}
                self.assertEqual(expected_returns, results.return_values)

                stdout_redirected = (redirs & Std.OUT) == Std.OUT
                stderr_redirected = (redirs & Std.ERR) == Std.ERR
                for rank in range(nprocs):
                    if not stdout_redirected:
                        self.assertFalse(results.stdouts[rank])
                    if not stderr_redirected:
                        self.assertFalse(results.stderrs[rank])
                    if stdout_redirected:
                        expected_out = [f"hello stdout from {rank}"]
                        self.assert_in_file(expected_out, results.stdouts[rank])
                    if stderr_redirected:
                        expected_err = [f"hello stderr from {rank}"]
                        self.assert_in_file(expected_err, results.stderrs[rank])
Esempio n. 8
0
    def test_binary_signal(self):
        # Launch two copies of the echo3.py binary; rank 0 is passed
        # "--segfault true" and is expected to be reported as the only failure.
        per_rank_args = {0: ("--segfault", "true", "foo"), 1: ("bar",)}
        per_rank_envs = {0: {"RANK": "0"}, 1: {"RANK": "1"}}
        pc = start_processes(
            name="echo",
            entrypoint=self.bin("echo3.py"),
            args=per_rank_args,
            envs=per_rank_envs,
            log_dir=self.log_dir(),
        )

        results = pc.wait(period=0.1)

        self.assert_pids_noexist(pc.pids())
        self.assertTrue(results.is_failed())
        self.assertEqual(1, len(results.failures))

        segv_failure = results.failures[0]
        # NOTE(review): this only checks that the exit code differs from
        # +SIGSEGV; presumably the real value is the negated signal number —
        # confirm before tightening the assertion.
        self.assertNotEqual(signal.SIGSEGV, segv_failure.exitcode)
        self.assertEqual("SIGSEGV", segv_failure.signal_name())
        self.assertEqual("<NONE>", segv_failure.error_file_data["message"])
Esempio n. 9
0
 def _start_mp(
     self, dist_infos: Dict[int, _DistInfo], spec: WorkerSpec
 ) -> BaseProcessContext:
     """Start ``spec.local_world_size`` processes that all run ``_wrap``.

     A single ``MpParameters`` object (wrapping ``dist_infos``, ``spec.fn``
     and ``spec.args``) is built and the same instance is used for every
     process slot.

     Returns:
         Whatever ``start_processes`` returns for those parameters and
         ``self._start_method``.
     """
     shared_params = MpParameters(fn=_wrap, args=(dist_infos, spec.fn, spec.args))
     # every entry refers to the same MpParameters instance
     per_process = [shared_params for _ in range(spec.local_world_size)]
     return start_processes(per_process, start_method=self._start_method)
Esempio n. 10
0
    def test_function_large_ret_val(self):
        # multiprocessing queues are backed by pipes: when a single object is
        # larger than the pipe buffer, the writer blocks until a reader starts
        # draining the pipe. This test has four workers each return a large
        # value (the original note estimates roughly 10 MB) to exercise that.

        size = 200000
        worker_count = 4
        for start_method in start_methods():
            with self.subTest(start_method=start_method):
                per_rank_args = {rank: (size,) for rank in range(worker_count)}
                per_rank_envs = {rank: {} for rank in range(worker_count)}
                pc = start_processes(
                    name="echo",
                    entrypoint=echo_large,
                    args=per_rank_args,
                    envs=per_rank_envs,
                    log_dir=self.log_dir(),
                    start_method=start_method,
                )

                results = pc.wait(period=0.1)
                # every worker's return value must have the requested length
                for rank in range(pc.nprocs):
                    self.assertEqual(size, len(results.return_values[rank]))
Esempio n. 11
0
    def test_function_exit(self):
        """
        run 2x copies of echo1, the first one exits with a failure code

        functions that exit from python do not generate an error file
        (even if they are decorated with @record), so the reported error
        file is expected to be the "<N/A>" placeholder
        """

        FAIL = 138
        for start_method in start_methods():
            with self.subTest(start_method=start_method):
                log_dir = self.log_dir()
                per_rank_args = {0: ("hello", FAIL), 1: ("hello",)}
                per_rank_envs = {0: {"RANK": "0"}, 1: {"RANK": "1"}}
                pc = start_processes(
                    name="echo",
                    entrypoint=echo1,
                    args=per_rank_args,
                    envs=per_rank_envs,
                    log_dir=log_dir,
                    start_method=start_method,
                    redirects={0: Std.ERR},
                )

                results = pc.wait(period=0.1)

                self.assert_pids_noexist(pc.pids())
                self.assertTrue(results.is_failed())
                self.assertEqual(1, len(results.failures))
                self.assertFalse(results.return_values)

                rank0_failure = results.failures[0]
                reported_error_file = rank0_failure.error_file

                self.assertEqual(FAIL, rank0_failure.exitcode)
                self.assertEqual("<N/A>", rank0_failure.signal_name())
                self.assertEqual(pc.pids()[0], rank0_failure.pid)
                self.assertEqual("<N/A>", reported_error_file)
                expected_message = f"Process failed with exitcode {FAIL}"
                self.assertEqual(expected_message, rank0_failure.message)
                self.assertLessEqual(rank0_failure.timestamp, int(time.time()))

                # only rank 0's stderr was redirected
                expected_stderr = [f"exit {FAIL} from 0"]
                self.assert_in_file(expected_stderr, results.stderrs[0])
                self.assertFalse(results.stdouts[0])
                self.assertFalse(results.stderrs[1])
                self.assertFalse(results.stdouts[1])
                self.assertTrue(pc._stderr_tail.stopped())
                self.assertTrue(pc._stdout_tail.stopped())
Esempio n. 12
0
    def test_pcontext_wait(self):
        # A wait() whose timeout is shorter than the worker's 1 s sleep is
        # expected to return None; a later wait() without a timeout must
        # return a result.
        pc = start_processes(
            name="sleep",
            entrypoint=time.sleep,
            args={0: (1,)},
            envs={0: {}},
            log_dir=self.log_dir(),
            start_method="fork",
        )

        timed_out = pc.wait(timeout=0.1, period=0.01)
        self.assertIsNone(timed_out)

        finished = pc.wait(period=0.1)
        self.assertIsNotNone(finished)

        # both tail readers must be stopped after the final wait()
        self.assertTrue(pc._stderr_tail.stopped())
        self.assertTrue(pc._stdout_tail.stopped())
Esempio n. 13
0
    def test_void_function(self):
        # Two copies of echo0 are run per start method; each rank's return
        # value is expected to be None.
        expected_returns = {0: None, 1: None}
        for start_method in start_methods():
            with self.subTest(start_method=start_method):
                per_rank_args = {0: ("hello",), 1: ("world",)}
                per_rank_envs = {0: {}, 1: {}}
                pc = start_processes(
                    name="echo",
                    entrypoint=echo0,
                    args=per_rank_args,
                    envs=per_rank_envs,
                    log_dir=self.log_dir(),
                    start_method=start_method,
                )

                results = pc.wait(period=0.1)
                self.assertEqual(expected_returns, results.return_values)
Esempio n. 14
0
    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        """Launch one process per worker of ``worker_group``.

        For every worker an environment dict (ranks, world sizes, master
        address/port, restart counters, run id) is built and stored, together
        with ``spec.args``, under the worker's local rank. The per-attempt log
        directory is then removed and recreated, ``start_processes`` is
        called, and the resulting context is kept on ``self._pcontext``.

        Args:
            worker_group: group whose ``spec``, ``store`` and ``workers`` are
                used to configure the processes.

        Returns:
            The value of ``self._pcontext.pids()``.
        """
        spec = worker_group.spec
        master_addr, master_port = super()._get_master_addr_port(worker_group.store)
        attempt = spec.max_restarts - self._remaining_restarts

        args: Dict[int, Tuple] = {}
        envs: Dict[int, Dict[str, str]] = {}
        for worker in worker_group.workers:
            rank_on_node = worker.local_rank
            env = {
                "LOCAL_RANK": str(rank_on_node),
                "RANK": str(worker.global_rank),
                "GROUP_RANK": str(worker_group.group_rank),
                "ROLE_RANK": str(worker.role_rank),
                "ROLE_NAME": spec.role,
                "LOCAL_WORLD_SIZE": str(spec.local_world_size),
                "WORLD_SIZE": str(worker.world_size),
                "ROLE_WORLD_SIZE": str(worker.role_world_size),
                "MASTER_ADDR": master_addr,
                "MASTER_PORT": str(master_port),
                "TORCHELASTIC_RESTART_COUNT": str(attempt),
                "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
                "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
                "NCCL_ASYNC_ERROR_HANDLING": "1",
            }
            # forward OMP_NUM_THREADS only when the parent process has it set
            omp_threads = os.environ.get("OMP_NUM_THREADS")
            if omp_threads is not None:
                env["OMP_NUM_THREADS"] = omp_threads
            envs[rank_on_node] = env
            args[rank_on_node] = spec.args

        # Scaling events do not count towards restarts, so they reuse the same
        # attempt number; wipe any log dir left over from such an event.
        attempt_log_dir = os.path.join(self._log_dir, f"attempt_{attempt}")
        shutil.rmtree(attempt_log_dir, ignore_errors=True)
        os.makedirs(attempt_log_dir)

        launch_kwargs = dict(
            name=spec.role,
            entrypoint=spec.entrypoint,
            args=args,
            envs=envs,
            log_dir=attempt_log_dir,
            start_method=self._start_method,
            redirects=spec.redirects,
            tee=spec.tee,
        )
        self._pcontext = start_processes(**launch_kwargs)

        return self._pcontext.pids()
Esempio n. 15
0
    def test_multiprocess_context_close(self):
        # After close() the worker pid must no longer exist and both tail
        # readers must report stopped.
        pc = start_processes(
            name="sleep",
            entrypoint=time.sleep,
            args={0: (1,)},
            envs={0: {}},
            log_dir=self.log_dir(),
            start_method="fork",
        )

        # capture the pids before closing the context
        live_pids = pc.pids()
        pc.close()

        self.assert_pids_noexist(live_pids)
        for tail in (pc._stderr_tail, pc._stdout_tail):
            self.assertTrue(tail.stopped())
Esempio n. 16
0
    def test_function_raise(self):
        """
        run 2x copies of echo2, the first one is asked to raise an exception

        the failure is expected to carry exit code 1 and an error file at
        <log_dir>/0/error.json whose timestamp matches the failure timestamp
        """
        RAISE = True

        for start_method in start_methods():
            with self.subTest(start_method=start_method):
                log_dir = self.log_dir()
                per_rank_args = {0: ("hello", RAISE), 1: ("world",)}
                per_rank_envs = {0: {}, 1: {}}
                pc = start_processes(
                    name="echo",
                    entrypoint=echo2,
                    args=per_rank_args,
                    envs=per_rank_envs,
                    log_dir=log_dir,
                    start_method=start_method,
                )

                results = pc.wait(period=0.1)

                self.assert_pids_noexist(pc.pids())
                self.assertEqual(1, len(results.failures))
                self.assertFalse(results.return_values)

                rank0_failure = results.failures[0]
                recorded_path = rank0_failure.error_file
                recorded_data = rank0_failure.error_file_data

                self.assertEqual(1, rank0_failure.exitcode)
                self.assertEqual("<N/A>", rank0_failure.signal_name())
                self.assertEqual(pc.pids()[0], rank0_failure.pid)
                expected_path = os.path.join(log_dir, "0", "error.json")
                self.assertEqual(expected_path, recorded_path)
                recorded_timestamp = recorded_data["message"]["extraInfo"]["timestamp"]
                self.assertEqual(
                    int(recorded_timestamp),
                    int(rank0_failure.timestamp),
                )
                self.assertTrue(pc._stderr_tail.stopped())
                self.assertTrue(pc._stdout_tail.stopped())
Esempio n. 17
0
    def test_binary_redirect_and_tee(self):
        # Mix redirects and tee per rank for the echo1.py binary and check
        # which stream files end up with content.
        redirect_spec = {0: Std.ERR, 1: Std.NONE}
        tee_spec = {0: Std.OUT, 1: Std.ERR}
        pc = start_processes(
            name="trainer",
            entrypoint=self.bin("echo1.py"),
            args={0: ("hello",), 1: ("world",)},
            envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
            log_dir=self.log_dir(),
            start_method="fork",
            redirects=redirect_spec,
            tee=tee_spec,
        )

        result = pc.wait()

        self.assertFalse(result.is_failed())
        # rank 0: stdout (tee) and stderr (redirect) are both captured
        self.assert_in_file(["hello stdout from 0"], pc.stdouts[0])
        self.assert_in_file(["hello stderr from 0"], pc.stderrs[0])
        # rank 1: only stderr (tee) is captured
        self.assert_in_file(["world stderr from 1"], pc.stderrs[1])
        self.assertFalse(pc.stdouts[1])
        self.assertTrue(pc._stderr_tail.stopped())
        self.assertTrue(pc._stdout_tail.stopped())
Esempio n. 18
0
    def test_function_redirect_and_tee(self):
        # Mix redirects and tee per rank for the echo1 function, once per
        # start method, and check which stream files end up with content.
        for start_method in start_methods():
            with self.subTest(start_method=start_method):
                log_dir = self.log_dir()
                pc = start_processes(
                    name="trainer",
                    entrypoint=echo1,
                    args={0: ("hello",), 1: ("world",)},
                    envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
                    log_dir=log_dir,
                    # Use the loop variable: a hard-coded "fork" here made
                    # every subtest exercise the same start method.
                    start_method=start_method,
                    redirects={0: Std.ERR, 1: Std.NONE},
                    tee={0: Std.OUT, 1: Std.ERR},
                )

                result = pc.wait()

                self.assertFalse(result.is_failed())
                # rank 0: stdout (tee) and stderr (redirect) are both captured
                self.assert_in_file(["hello stdout from 0"], pc.stdouts[0])
                self.assert_in_file(["hello stderr from 0"], pc.stderrs[0])
                # rank 1: only stderr (tee) is captured
                self.assert_in_file(["world stderr from 1"], pc.stderrs[1])
                self.assertFalse(pc.stdouts[1])
                self.assertTrue(pc._stderr_tail.stopped())
                self.assertTrue(pc._stdout_tail.stopped())
Esempio n. 19
0
 def test_invoke_mp(self, mp_mock):
     # start_processes is expected to call the mocked object exactly once,
     # passing the parameter list and the start method positionally.
     method = "fork"
     params = [MpParameters(fn=dummy_fn, args=())] * 4
     start_processes(params, start_method=method)
     mp_mock.assert_called_once_with(params, method)
Esempio n. 20
0
 def test_invoke_mp_no_params(self):
     # An empty parameter list is expected to be rejected with ValueError.
     self.assertRaises(ValueError, start_processes, [], start_method="fork")