Exemplo n.º 1
0
def _get_by_runner_type(identifier,
                        user_entry_point=None,
                        args=None,
                        env_vars=None,
                        extra_opts=None):
    """Construct the runner object selected by ``identifier``.

    Falsy ``user_entry_point``, ``args`` and ``env_vars`` are replaced by the
    corresponding values taken from a fresh ``environment.Environment()``;
    a falsy ``extra_opts`` becomes an empty options dict.

    Args:
        identifier (RunnerType): which runner to build (SMDataParallel, MPI
            or Process).
        user_entry_point: entry point to execute.
        args: command-line arguments for the entry point.
        env_vars: environment variables for the launched process(es).
        extra_opts (dict): options looked up through ``_mpi_param_value``
            (processes per host, process count, custom MPI options).

    Returns:
        - SMDataParallel: ``smdataparallel.SMDataParallelRunner`` when
          ``env.is_master`` is truthy, otherwise ``mpi.WorkerRunner``.
        - MPI: ``mpi.MasterRunner`` when ``env.is_master`` is truthy,
          otherwise ``mpi.WorkerRunner``.
        - Process: ``process.ProcessRunner``.

    Raises:
        ValueError: if ``identifier`` is none of the handled RunnerType members.
    """
    env = environment.Environment()
    entry_point = user_entry_point or env.user_entry_point
    cmd_args = args or env.to_cmd_args()
    runner_env = env_vars or env.to_env_vars()
    mpi_options = extra_opts or {}

    # Default: one process per GPU; a host reporting no GPUs gets one process.
    processes_per_host = _mpi_param_value(
        mpi_options, env, params.MPI_PROCESSES_PER_HOST,
        max(int(env.num_gpus), 1))

    if identifier is RunnerType.SMDataParallel:
        if not env.is_master:
            return mpi.WorkerRunner(entry_point, cmd_args, runner_env,
                                    processes_per_host, env.master_hostname)
        smdp_mpi_options = _mpi_param_value(
            mpi_options, env, params.SMDATAPARALLEL_CUSTOM_MPI_OPTIONS, "")
        return smdataparallel.SMDataParallelRunner(
            entry_point, cmd_args, runner_env, processes_per_host,
            env.master_hostname, env.distribution_hosts, smdp_mpi_options,
            env.network_interface_name)

    if identifier is RunnerType.MPI:
        if not env.is_master:
            return mpi.WorkerRunner(entry_point, cmd_args, runner_env,
                                    processes_per_host, env.master_hostname)
        total_processes = _mpi_param_value(mpi_options, env,
                                           params.MPI_NUM_PROCESSES)
        mpi_custom_options = _mpi_param_value(mpi_options, env,
                                              params.MPI_CUSTOM_OPTIONS, "")
        return mpi.MasterRunner(
            entry_point, cmd_args, runner_env, processes_per_host,
            env.master_hostname, env.distribution_hosts, mpi_custom_options,
            env.network_interface_name, num_processes=total_processes)

    if identifier is RunnerType.Process:
        return process.ProcessRunner(entry_point, cmd_args, runner_env,
                                     processes_per_host)

    raise ValueError("Invalid identifier %s" % identifier)
Exemplo n.º 2
0
def test_mpi_master_run_python(training_env, popen, policy, ssh_client,
                               python_executable, path_exists):
    """MasterRunner launches a Python entry point through ``mpirun``.

    Verifies that the SSH client mock was used to reach ``algo-2`` on port 22,
    that ``popen`` was last called with the expected ``mpirun`` argv, that
    ``run(wait=False)`` returns the ``popen`` result, and that ``path_exists``
    was last called with the sshd path.
    """
    with patch.dict(os.environ, clear=True):
        runner = mpi.MasterRunner(
            user_entry_point="train.py",
            args=["-v", "--lr", "35"],
            env_vars={"LD_CONFIG_PATH": "/etc/ld"},
            master_hostname="algo-1",
            hosts=["algo-1", "algo-2"],
            # NOTE(review): other call sites use ``processes_per_host``; this
            # spelling presumably matches an older MasterRunner signature —
            # confirm against the mpi module under test.
            process_per_host=2,
            custom_mpi_options="-v --lr 35",
            network_interface_name="ethw3",
        )

        launched = runner.run(wait=False)

        ssh = ssh_client()
        ssh.load_system_host_keys.assert_called()
        ssh.set_missing_host_key_policy.assert_called_with(policy())
        ssh.connect.assert_called_with("algo-2", port=22)
        ssh.close.assert_called()

        expected_cmd = [
            "mpirun", "--host", "algo-1:2,algo-2:2", "-np", "4",
            "--allow-run-as-root", "--display-map", "--tag-output",
            "-mca", "btl_tcp_if_include", "ethw3",
            "-mca", "oob_tcp_if_include", "ethw3",
            "-mca", "plm_rsh_no_tree_spawn", "1",
            "-bind-to", "none",
            "-map-by", "slot",
            "-mca", "pml", "ob1",
            "-mca", "btl", "^openib",
            "-mca", "orte_abort_on_non_zero_status", "1",
            "-x", "NCCL_MIN_NRINGS=4",
            "-x", "NCCL_SOCKET_IFNAME=ethw3",
            "-x", "NCCL_DEBUG=INFO",
            "-x", "LD_LIBRARY_PATH",
            "-x", "PATH",
            "-x", "LD_PRELOAD=%s" % inspect.getfile(gethostname),
            # presumably the tokens of custom_mpi_options
            "-v", "--lr", "35",
            "-x", "LD_CONFIG_PATH",
            # interpreter path presumably comes from the python_executable fixture
            "usr/bin/python3", "-m", "mpi4py", "train.py",
            "-v", "--lr", "35",
        ]
        # Must run before ``popen()`` below, which records another call.
        popen.assert_called_with(expected_cmd, cwd=environment.code_dir,
                                 env=ANY, stderr=None)

        assert launched == popen()
        path_exists.assert_called_with("/usr/sbin/sshd")
Exemplo n.º 3
0
def test_mpi_master_run(training_env, async_shell, policy, ssh_client,
                        path_exists, async_gather, event_loop):
    """MasterRunner launches a shell entry point as one ``mpirun`` shell command.

    Verifies the SSH client mock usage toward ``algo-2``, that ``async_shell``
    was called exactly once with the expected command string (stdout piped),
    that ``async_gather`` was called once, that ``run(wait=False)`` returns
    ``async_shell.return_value``, and that ``path_exists`` was last called
    with the sshd path.
    """
    runner_kwargs = dict(
        user_entry_point="train.sh",
        args=["-v", "--lr", "35"],
        env_vars={"LD_CONFIG_PATH": "/etc/ld"},
        processes_per_host=2,
        master_hostname="algo-1",
        hosts=["algo-1", "algo-2"],
        custom_mpi_options="-v --lr 35",
        network_interface_name="ethw3",
    )

    with patch.dict(os.environ, clear=True):
        # Set before the runner is built; the test expects it re-exported via -x.
        os.environ["AWS_ACCESS_KEY_ID"] = "ABCD"
        runner = mpi.MasterRunner(**runner_kwargs)
        launched = runner.run(wait=False)

        ssh = ssh_client()
        ssh.load_system_host_keys.assert_called()
        ssh.set_missing_host_key_policy.assert_called_with(policy())
        ssh.connect.assert_called_with("algo-2", port=22)
        ssh.close.assert_called()

        # Each fragment is one flag group; joining with single spaces yields
        # the same string as joining every individual token.
        expected_shell_cmd = " ".join([
            "mpirun --host algo-1:2,algo-2:2 -np 4",
            "--allow-run-as-root --display-map --tag-output",
            "-mca btl_tcp_if_include ethw3",
            "-mca oob_tcp_if_include ethw3",
            "-mca plm_rsh_no_tree_spawn 1",
            "-bind-to none -map-by slot",
            "-mca pml ob1",
            "-mca btl ^openib",
            "-mca orte_abort_on_non_zero_status 1",
            "-mca btl_vader_single_copy_mechanism none",
            "-x NCCL_MIN_NRINGS=4",
            "-x NCCL_SOCKET_IFNAME=ethw3",
            "-x NCCL_DEBUG=INFO",
            "-x LD_LIBRARY_PATH -x PATH",
            "-x LD_PRELOAD=%s" % inspect.getfile(gethostname),
            "-v --lr 35",
            "-x AWS_ACCESS_KEY_ID -x LD_CONFIG_PATH",
            '/bin/sh -c "./train.sh -v --lr 35"',
        ])
        async_shell.assert_called_with(
            expected_shell_cmd,
            env=ANY,
            cwd=environment.code_dir,
            stdout=asyncio.subprocess.PIPE,
            stderr=None,
        )
        async_shell.assert_called_once()
        async_gather.assert_called_once()
        assert launched == async_shell.return_value
        path_exists.assert_called_with("/usr/sbin/sshd")