コード例 #1
0
def _get_by_runner_type(identifier,
                        user_entry_point=None,
                        args=None,
                        env_vars=None,
                        extra_opts=None):
    """Instantiate the runner object that corresponds to ``identifier``.

    Falsy ``user_entry_point``, ``args`` and ``env_vars`` are replaced by the
    values reported by a freshly created ``environment.Environment``; a falsy
    ``extra_opts`` is treated as an empty dict.

    Args:
        identifier: A ``RunnerType`` member (compared by identity).
        user_entry_point: Entry point handed to the runner.
        args: Command-line arguments handed to the runner.
        env_vars: Environment variables handed to the runner.
        extra_opts: Optional mapping passed to ``_mpi_param_value`` when
            looking up the MPI-related settings.

    Returns:
        ``smdataparallel.SMDataParallelRunner`` (SMDataParallel, master),
        ``mpi.MasterRunner`` (MPI, master), ``mpi.WorkerRunner``
        (SMDataParallel or MPI, non-master) or ``process.ProcessRunner``
        (Process).

    Raises:
        ValueError: If ``identifier`` is none of the handled runner types.
    """
    env = environment.Environment()
    entry_point = user_entry_point or env.user_entry_point
    cmd_args = args or env.to_cmd_args()
    runner_env = env_vars or env.to_env_vars()
    overrides = extra_opts or {}

    # With no GPUs reported, fall back to a single process per host.
    fallback_per_host = max(1, int(env.num_gpus))
    processes_per_host = _mpi_param_value(overrides, env,
                                          params.MPI_PROCESSES_PER_HOST,
                                          fallback_per_host)

    def _worker():
        # Non-master hosts get the same worker runner for both the
        # SMDataParallel and the MPI runner types.
        return mpi.WorkerRunner(entry_point, cmd_args, runner_env,
                                processes_per_host, env.master_hostname)

    if identifier is RunnerType.SMDataParallel:
        if not env.is_master:
            return _worker()
        custom_mpi_options = _mpi_param_value(
            overrides, env, params.SMDATAPARALLEL_CUSTOM_MPI_OPTIONS, "")
        return smdataparallel.SMDataParallelRunner(
            entry_point,
            cmd_args,
            runner_env,
            processes_per_host,
            env.master_hostname,
            env.distribution_hosts,
            custom_mpi_options,
            env.network_interface_name,
        )

    if identifier is RunnerType.MPI:
        if not env.is_master:
            return _worker()
        num_processes = _mpi_param_value(overrides, env,
                                         params.MPI_NUM_PROCESSES)
        custom_mpi_options = _mpi_param_value(overrides, env,
                                              params.MPI_CUSTOM_OPTIONS, "")
        return mpi.MasterRunner(
            entry_point,
            cmd_args,
            runner_env,
            processes_per_host,
            env.master_hostname,
            env.distribution_hosts,
            custom_mpi_options,
            env.network_interface_name,
            num_processes=num_processes,
        )

    if identifier is RunnerType.Process:
        return process.ProcessRunner(entry_point, cmd_args, runner_env,
                                     processes_per_host)

    raise ValueError("Invalid identifier %s" % identifier)
コード例 #2
0
def test_smdataparallel_run_multi_node_python(
    training_env, popen, policy, ssh_client, python_executable, path_exists
):
    """Two-host SMDataParallel run: check the SSH calls and the mpirun argv."""
    with patch.dict(os.environ, clear=True):
        hosts = ["algo-1", "algo-2"]
        master_hostname = hosts[0]
        per_host = 8
        interface = "ethw3"
        server_port = 7592

        runner = smdataparallel.SMDataParallelRunner(
            user_entry_point="train.py",
            args=["-v", "--lr", "35"],
            env_vars={
                "SM_TRAINING_ENV": '{"additional_framework_parameters":{"sagemaker_instance_type":"ml.p3.16xlarge"}}'
            },
            master_hostname=master_hostname,
            hosts=hosts,
            network_interface_name=interface,
        )

        launched = runner.run(wait=False)

        # The SSH client mock must have been used to reach the second host.
        ssh = ssh_client()
        ssh.load_system_host_keys.assert_called()
        ssh.set_missing_host_key_policy.assert_called_with(policy())
        ssh.connect.assert_called_with("algo-2", port=22)
        ssh.close.assert_called()

        # Expected argv, assembled from its logical groups.
        host_slots = ["{}:{}".format(host, per_host) for host in hosts]
        mca_settings = [
            ("btl_tcp_if_include", interface),
            ("oob_tcp_if_include", interface),
            ("plm_rsh_no_tree_spawn", "1"),
            ("pml", "ob1"),
            ("btl", "^openib"),
            ("orte_abort_on_non_zero_status", "1"),
            ("plm_rsh_num_concurrent", str(len(hosts))),
        ]
        exported = [
            "NCCL_SOCKET_IFNAME=%s" % interface,
            "LD_LIBRARY_PATH",
            "PATH",
            "SMDATAPARALLEL_USE_HOMOGENEOUS=1",
            "FI_PROVIDER=sockets",
            "RDMAV_FORK_SAFE=1",
            "LD_PRELOAD=%s" % inspect.getfile(gethostname),
            "SMDATAPARALLEL_SERVER_ADDR=%s" % master_hostname,
            "SMDATAPARALLEL_SERVER_PORT=%s" % server_port,
            "SAGEMAKER_INSTANCE_TYPE=ml.p3.16xlarge",
        ]

        expected_cmd = [
            "mpirun",
            "--host",
            ",".join(host_slots),
            "-np",
            str(per_host * len(hosts)),
            "--allow-run-as-root",
            "--tag-output",
            "--oversubscribe",
        ]
        for name, value in mca_settings:
            expected_cmd += ["-mca", name, value]
        for variable in exported:
            expected_cmd += ["-x", variable]
        expected_cmd += [
            "smddprun",
            "usr/bin/python3",
            "-m",
            "mpi4py",
            "train.py",
            "-v",
            "--lr",
            "35",
        ]

        popen.assert_called_with(
            expected_cmd,
            cwd=environment.code_dir,
            env=ANY,
            stderr=None,
        )

        # popen() is only invoked after the call assertion above so that it
        # does not become the "last call" being checked.
        assert launched == popen()
        path_exists.assert_called_with("/usr/sbin/sshd")
コード例 #3
0
def test_smdataparallel_run_single_node_python(
    training_env, popen, policy, ssh_client, python_executable, path_exists
):
    """Single-host SMDataParallel run: check the mpirun argv given to popen."""
    with patch.dict(os.environ, clear=True):
        hosts = ["algo-1"]
        per_host = 8
        interface = "ethw3"

        runner = smdataparallel.SMDataParallelRunner(
            user_entry_point="train.py",
            args=["-v", "--lr", "35"],
            env_vars={},
            master_hostname=hosts[0],
            hosts=hosts,
            network_interface_name=interface,
        )

        launched = runner.run(wait=False)

        # Expected argv, assembled from its logical groups.
        mca_settings = [
            ("btl_tcp_if_include", interface),
            ("oob_tcp_if_include", interface),
            ("plm_rsh_no_tree_spawn", "1"),
            ("pml", "ob1"),
            ("btl", "^openib"),
            ("orte_abort_on_non_zero_status", "1"),
            ("plm_rsh_num_concurrent", str(len(hosts))),
        ]
        exported = [
            "NCCL_SOCKET_IFNAME=%s" % interface,
            "LD_LIBRARY_PATH",
            "PATH",
            "SMDATAPARALLEL_USE_SINGLENODE=1",
            "FI_PROVIDER=sockets",
            "RDMAV_FORK_SAFE=1",
            "LD_PRELOAD=%s" % inspect.getfile(gethostname),
        ]

        expected_cmd = [
            "mpirun",
            "--host",
            ",".join(hosts),
            "-np",
            str(per_host * len(hosts)),
            "--allow-run-as-root",
            "--tag-output",
            "--oversubscribe",
        ]
        for name, value in mca_settings:
            expected_cmd += ["-mca", name, value]
        for variable in exported:
            expected_cmd += ["-x", variable]
        expected_cmd += [
            "smddprun",
            "usr/bin/python3",
            "-m",
            "mpi4py",
            "train.py",
            "-v",
            "--lr",
            "35",
        ]

        popen.assert_called_with(
            expected_cmd,
            cwd=environment.code_dir,
            env=ANY,
            stderr=None,
        )

        # popen() is only invoked after the call assertion above so that it
        # does not become the "last call" being checked.
        assert launched == popen()
        path_exists.assert_called_with("/usr/sbin/sshd")
コード例 #4
0
def test_hc_smdataparallel_run_single_node_python(
    training_env,
    async_shell,
    policy,
    ssh_client,
    python_executable,
    path_exists,
    async_gather,
    event_loop,
):
    """Single-host run with custom MPI options: check the async shell command.

    The runner is given ``--verbose`` as custom MPI options and a training
    environment whose ``current_instance_type`` is ``ml.p4d.24xlarge``; the
    command is expected to be passed to ``async_shell`` as one joined string.
    """
    with patch.dict(os.environ, clear=True):
        hosts = ["algo-1"]
        per_host = 8
        interface = "ethw3"

        # NOTE: the backslash continuation below sits inside the string
        # literal, so the next line's leading whitespace is part of the value.
        runner = smdataparallel.SMDataParallelRunner(
            user_entry_point="train.py",
            args=["-v", "--lr", "35"],
            env_vars={
                "SM_TRAINING_ENV":
                '{"additional_framework_parameters":{"sagemaker_distributed_dataparallel_enabled":"true"},\
                "current_instance_type": "ml.p4d.24xlarge"}'
            },
            processes_per_host=per_host,
            master_hostname=hosts[0],
            hosts=hosts,
            custom_mpi_options="--verbose",
            network_interface_name=interface,
        )

        # Only the third element of the returned triple is checked here.
        _, _, launched = runner.run(wait=False)

        # Expected argv, assembled from its logical groups.
        mca_settings = [
            ("btl_tcp_if_include", interface),
            ("oob_tcp_if_include", interface),
            ("plm_rsh_no_tree_spawn", "1"),
            ("pml", "ob1"),
            ("btl", "^openib"),
            ("orte_abort_on_non_zero_status", "1"),
            ("btl_vader_single_copy_mechanism", "none"),
            ("plm_rsh_num_concurrent", str(len(hosts))),
        ]
        exported = [
            "NCCL_SOCKET_IFNAME=%s" % interface,
            "NCCL_DEBUG=INFO",
            "LD_LIBRARY_PATH",
            "PATH",
            "SMDATAPARALLEL_USE_SINGLENODE=1",
            "FI_PROVIDER=efa",
            "RDMAV_FORK_SAFE=1",
            "LD_PRELOAD=%s" % inspect.getfile(gethostname),
        ]

        expected_cmd = [
            "mpirun",
            "--host",
            ",".join(hosts),
            "-np",
            str(per_host * len(hosts)),
            "--allow-run-as-root",
            "--tag-output",
            "--oversubscribe",
        ]
        for name, value in mca_settings:
            expected_cmd += ["-mca", name, value]
        for variable in exported:
            expected_cmd += ["-x", variable]
        # The custom MPI option comes before the final exported variable.
        expected_cmd += ["--verbose", "-x", "FI_EFA_USE_DEVICE_RDMA=1"]
        expected_cmd += [
            "smddprun",
            "usr/bin/python3",
            "-m",
            "mpi4py",
            "train.py",
            "-v",
            "--lr",
            "35",
        ]

        async_shell.assert_called_with(
            " ".join(expected_cmd),
            cwd=environment.code_dir,
            env=ANY,
            stdout=asyncio.subprocess.PIPE,
            stderr=None,
        )
        async_shell.assert_called_once()
        async_gather.assert_called_once()
        assert launched == async_shell.return_value
        path_exists.assert_called_with("/usr/sbin/sshd")