def _get_by_runner_type(identifier, user_entry_point=None, args=None, env_vars=None, extra_opts=None):
    """Build the runner object that corresponds to ``identifier``.

    Any of ``user_entry_point``, ``args`` and ``env_vars`` that is falsy is
    replaced by the value taken from a freshly created
    ``environment.Environment``.

    Args:
        identifier: A ``RunnerType`` member selecting the kind of runner.
        user_entry_point: Entry point handed to the runner; falls back to
            ``env.user_entry_point``.
        args: Command-line arguments; falls back to ``env.to_cmd_args()``.
        env_vars: Environment variables; falls back to ``env.to_env_vars()``.
        extra_opts: Optional mapping of MPI-related overrides that is passed
            to ``_mpi_param_value``; an empty dict is used when falsy.

    Returns:
        A runner instance: ``SMDataParallelRunner`` or ``MasterRunner`` on the
        master host, ``WorkerRunner`` on the other hosts of an
        SMDataParallel/MPI job, or ``ProcessRunner`` for ``RunnerType.Process``.

    Raises:
        ValueError: If ``identifier`` is none of the handled runner types.
    """
    env = environment.Environment()

    if not user_entry_point:
        user_entry_point = env.user_entry_point
    if not args:
        args = env.to_cmd_args()
    if not env_vars:
        env_vars = env.to_env_vars()
    mpi_args = extra_opts if extra_opts else {}

    # One process per GPU; a host without GPUs (CPU) gets a single process.
    default_processes_per_host = max(int(env.num_gpus), 1)
    processes_per_host = _mpi_param_value(
        mpi_args, env, params.MPI_PROCESSES_PER_HOST, default_processes_per_host
    )

    if identifier is RunnerType.SMDataParallel:
        if env.is_master:
            custom_mpi_options = _mpi_param_value(
                mpi_args, env, params.SMDATAPARALLEL_CUSTOM_MPI_OPTIONS, ""
            )
            return smdataparallel.SMDataParallelRunner(
                user_entry_point,
                args,
                env_vars,
                processes_per_host,
                env.master_hostname,
                env.distribution_hosts,
                custom_mpi_options,
                env.network_interface_name,
            )
        # Non-master hosts of an SMDataParallel job get the MPI worker runner.
        return mpi.WorkerRunner(
            user_entry_point, args, env_vars, processes_per_host, env.master_hostname
        )

    if identifier is RunnerType.MPI:
        if env.is_master:
            num_processes = _mpi_param_value(mpi_args, env, params.MPI_NUM_PROCESSES)
            custom_mpi_options = _mpi_param_value(
                mpi_args, env, params.MPI_CUSTOM_OPTIONS, ""
            )
            return mpi.MasterRunner(
                user_entry_point,
                args,
                env_vars,
                processes_per_host,
                env.master_hostname,
                env.distribution_hosts,
                custom_mpi_options,
                env.network_interface_name,
                num_processes=num_processes,
            )
        return mpi.WorkerRunner(
            user_entry_point, args, env_vars, processes_per_host, env.master_hostname
        )

    if identifier is RunnerType.Process:
        return process.ProcessRunner(user_entry_point, args, env_vars, processes_per_host)

    raise ValueError("Invalid identifier %s" % identifier)
def test_smdataparallel_run_multi_node_python(
    training_env, popen, policy, ssh_client, python_executable, path_exists
):
    """Check the mpirun invocation issued for a two-host SMDataParallel job.

    The runner is started with ``wait=False``; the test then verifies the SSH
    client interactions, the exact argument list passed to ``popen``, the
    returned process object and the sshd path lookup.
    """
    with patch.dict(os.environ, clear=True):
        cluster = ["algo-1", "algo-2"]
        leader = cluster[0]
        host_count = len(cluster)
        # NOTE(review): the runner below is built without processes_per_host,
        # yet the expected "-np" assumes 8 per host -- confirm the runner default.
        per_host = 8
        total_processes = per_host * host_count
        nic = "ethw3"
        server_addr = leader
        server_port = 7592

        runner = smdataparallel.SMDataParallelRunner(
            user_entry_point="train.py",
            args=["-v", "--lr", "35"],
            env_vars={
                "SM_TRAINING_ENV": '{"additional_framework_parameters":{"sagemaker_instance_type":"ml.p3.16xlarge"}}'
            },
            master_hostname=leader,
            hosts=cluster,
            network_interface_name=nic,
        )
        launched = runner.run(wait=False)

        ssh_client().load_system_host_keys.assert_called()
        ssh_client().set_missing_host_key_policy.assert_called_with(policy())
        ssh_client().connect.assert_called_with("algo-2", port=22)
        ssh_client().close.assert_called()

        # Multi-node runs list each host as "<host>:<slots>".
        host_spec = ",".join("{}:{}".format(host, per_host) for host in cluster)

        # (name, value) pairs, each emitted as "-mca <name> <value>".
        mca_settings = [
            ("btl_tcp_if_include", nic),
            ("oob_tcp_if_include", nic),
            ("plm_rsh_no_tree_spawn", "1"),
            ("pml", "ob1"),
            ("btl", "^openib"),
            ("orte_abort_on_non_zero_status", "1"),
            ("plm_rsh_num_concurrent", str(host_count)),
        ]
        # Entries emitted as "-x <entry>", in this order.
        exported = [
            "NCCL_SOCKET_IFNAME=%s" % nic,
            "LD_LIBRARY_PATH",
            "PATH",
            "SMDATAPARALLEL_USE_HOMOGENEOUS=1",
            "FI_PROVIDER=sockets",
            "RDMAV_FORK_SAFE=1",
            "LD_PRELOAD=%s" % inspect.getfile(gethostname),
            "SMDATAPARALLEL_SERVER_ADDR=%s" % server_addr,
            "SMDATAPARALLEL_SERVER_PORT=%s" % str(server_port),
            "SAGEMAKER_INSTANCE_TYPE=ml.p3.16xlarge",
        ]

        expected_cmd = [
            "mpirun",
            "--host",
            host_spec,
            "-np",
            str(total_processes),
            "--allow-run-as-root",
            "--tag-output",
            "--oversubscribe",
        ]
        for mca_name, mca_value in mca_settings:
            expected_cmd.extend(["-mca", mca_name, mca_value])
        for entry in exported:
            expected_cmd.extend(["-x", entry])
        expected_cmd.extend(
            ["smddprun", "usr/bin/python3", "-m", "mpi4py", "train.py", "-v", "--lr", "35"]
        )

        popen.assert_called_with(
            expected_cmd,
            cwd=environment.code_dir,
            env=ANY,
            stderr=None,
        )
        assert launched == popen()
        path_exists.assert_called_with("/usr/sbin/sshd")
def test_smdataparallel_run_single_node_python(
    training_env, popen, policy, ssh_client, python_executable, path_exists
):
    """Check the mpirun invocation issued for a single-host SMDataParallel job.

    The runner is started with ``wait=False``; the test then verifies the exact
    argument list passed to ``popen``, the returned process object and the
    sshd path lookup.
    """
    with patch.dict(os.environ, clear=True):
        cluster = ["algo-1"]
        leader = cluster[0]
        host_count = len(cluster)
        # NOTE(review): the runner below is built without processes_per_host,
        # yet the expected "-np" assumes 8 per host -- confirm the runner default.
        per_host = 8
        total_processes = per_host * host_count
        nic = "ethw3"

        runner = smdataparallel.SMDataParallelRunner(
            user_entry_point="train.py",
            args=["-v", "--lr", "35"],
            env_vars={},
            master_hostname=leader,
            hosts=cluster,
            network_interface_name=nic,
        )
        launched = runner.run(wait=False)

        # (name, value) pairs, each emitted as "-mca <name> <value>".
        mca_settings = [
            ("btl_tcp_if_include", nic),
            ("oob_tcp_if_include", nic),
            ("plm_rsh_no_tree_spawn", "1"),
            ("pml", "ob1"),
            ("btl", "^openib"),
            ("orte_abort_on_non_zero_status", "1"),
            ("plm_rsh_num_concurrent", str(host_count)),
        ]
        # Entries emitted as "-x <entry>", in this order.
        exported = [
            "NCCL_SOCKET_IFNAME=%s" % nic,
            "LD_LIBRARY_PATH",
            "PATH",
            "SMDATAPARALLEL_USE_SINGLENODE=1",
            "FI_PROVIDER=sockets",
            "RDMAV_FORK_SAFE=1",
            "LD_PRELOAD=%s" % inspect.getfile(gethostname),
        ]

        # Single-node runs pass the bare host names, without a slot suffix.
        expected_cmd = [
            "mpirun",
            "--host",
            ",".join(cluster),
            "-np",
            str(total_processes),
            "--allow-run-as-root",
            "--tag-output",
            "--oversubscribe",
        ]
        for mca_name, mca_value in mca_settings:
            expected_cmd.extend(["-mca", mca_name, mca_value])
        for entry in exported:
            expected_cmd.extend(["-x", entry])
        expected_cmd.extend(
            ["smddprun", "usr/bin/python3", "-m", "mpi4py", "train.py", "-v", "--lr", "35"]
        )

        popen.assert_called_with(
            expected_cmd,
            cwd=environment.code_dir,
            env=ANY,
            stderr=None,
        )
        assert launched == popen()
        path_exists.assert_called_with("/usr/sbin/sshd")
def test_hc_smdataparallel_run_single_node_python(
    training_env,
    async_shell,
    policy,
    ssh_client,
    python_executable,
    path_exists,
    async_gather,
    event_loop,
):
    """Check the shell command issued for a single-host run on ml.p4d.24xlarge.

    The runner is built with ``custom_mpi_options="--verbose"`` and started
    with ``wait=False``. The test verifies that ``async_shell`` was called
    exactly once with the space-joined mpirun command, that ``async_gather``
    was called once, that the third element returned by ``run`` is
    ``async_shell.return_value``, and that the sshd path was looked up.
    """
    with patch.dict(os.environ, clear=True):
        hosts = ["algo-1"]
        master_hostname = hosts[0]
        num_hosts = len(hosts)
        num_processes_per_host = 8
        num_processes = num_processes_per_host * num_hosts
        host_list = hosts
        network_interface_name = "ethw3"
        smdataparallel_flag = "SMDATAPARALLEL_USE_SINGLENODE=1"

        # Adjacent literals keep the JSON valid. A backslash followed by a
        # space inside one literal is not a line continuation: it leaves a
        # literal backslash in the value and the JSON can no longer be parsed.
        sm_training_env = (
            '{"additional_framework_parameters":'
            '{"sagemaker_distributed_dataparallel_enabled":"true"},'
            ' "current_instance_type": "ml.p4d.24xlarge"}'
        )

        smdataparallel_runner = smdataparallel.SMDataParallelRunner(
            user_entry_point="train.py",
            args=["-v", "--lr", "35"],
            env_vars={"SM_TRAINING_ENV": sm_training_env},
            processes_per_host=num_processes_per_host,
            master_hostname=master_hostname,
            hosts=hosts,
            custom_mpi_options="--verbose",
            network_interface_name=network_interface_name,
        )
        # run() returns a 3-tuple here; only the last element is checked.
        _, _, process = smdataparallel_runner.run(wait=False)

        cmd = [
            "mpirun",
            "--host",
            ",".join(host_list),
            "-np",
            str(num_processes),
            "--allow-run-as-root",
            "--tag-output",
            "--oversubscribe",
            "-mca",
            "btl_tcp_if_include",
            network_interface_name,
            "-mca",
            "oob_tcp_if_include",
            network_interface_name,
            "-mca",
            "plm_rsh_no_tree_spawn",
            "1",
            "-mca",
            "pml",
            "ob1",
            "-mca",
            "btl",
            "^openib",
            "-mca",
            "orte_abort_on_non_zero_status",
            "1",
            "-mca",
            "btl_vader_single_copy_mechanism",
            "none",
            "-mca",
            "plm_rsh_num_concurrent",
            str(num_hosts),
            "-x",
            "NCCL_SOCKET_IFNAME=%s" % network_interface_name,
            "-x",
            "NCCL_DEBUG=INFO",
            "-x",
            "LD_LIBRARY_PATH",
            "-x",
            "PATH",
            "-x",
            smdataparallel_flag,
            "-x",
            "FI_PROVIDER=efa",
            "-x",
            "RDMAV_FORK_SAFE=1",
            "-x",
            "LD_PRELOAD=%s" % inspect.getfile(gethostname),
            # The custom MPI option is expected between the exported variables.
            "--verbose",
            "-x",
            "FI_EFA_USE_DEVICE_RDMA=1",
            "smddprun",
            "usr/bin/python3",
            "-m",
            "mpi4py",
            "train.py",
            "-v",
            "--lr",
            "35",
        ]

        # This path hands the command to the shell as one space-joined string.
        async_shell.assert_called_with(
            " ".join(cmd),
            cwd=environment.code_dir,
            env=ANY,
            stdout=asyncio.subprocess.PIPE,
            stderr=None,
        )
        async_shell.assert_called_once()
        async_gather.assert_called_once()
        assert process == async_shell.return_value
        path_exists.assert_called_with("/usr/sbin/sshd")