def _get_by_runner_type(
    identifier, user_entry_point=None, args=None, env_vars=None, extra_opts=None
):
    """Build the runner that matches ``identifier`` for the current host.

    Args:
        identifier: a ``RunnerType`` member (``SMDataParallel``, ``MPI`` or ``Process``).
        user_entry_point: entry point to execute; defaults to ``env.user_entry_point``.
        args: command-line arguments; defaults to ``env.to_cmd_args()``.
        env_vars: environment variables; defaults to ``env.to_env_vars()``.
        extra_opts: mapping of MPI-related options, read through ``_mpi_param_value``.

    Returns:
        For ``SMDataParallel``: an ``smdataparallel.SMDataParallelRunner`` on the master
        host, otherwise an ``mpi.WorkerRunner``.
        For ``MPI``: an ``mpi.MasterRunner`` on the master host, otherwise an
        ``mpi.WorkerRunner``.
        For ``Process``: a ``process.ProcessRunner``.

    Raises:
        ValueError: if ``identifier`` is none of the supported runner types.
    """
    env = environment.Environment()
    # Any falsy argument (None, but also e.g. an empty list/dict) falls back to the
    # value derived from the training environment.
    user_entry_point = user_entry_point or env.user_entry_point
    args = args or env.to_cmd_args()
    env_vars = env_vars or env.to_env_vars()
    mpi_args = extra_opts or {}

    # Default to one process per GPU; fall back to a single process on CPU-only hosts.
    num_gpus = int(env.num_gpus)
    default_processes_per_host = num_gpus if num_gpus > 0 else 1
    processes_per_host = _mpi_param_value(
        mpi_args, env, params.MPI_PROCESSES_PER_HOST, default_processes_per_host
    )

    if identifier is RunnerType.SMDataParallel and env.is_master:
        custom_mpi_options = _mpi_param_value(
            mpi_args, env, params.SMDATAPARALLEL_CUSTOM_MPI_OPTIONS, ""
        )
        return smdataparallel.SMDataParallelRunner(
            user_entry_point,
            args,
            env_vars,
            processes_per_host,
            env.master_hostname,
            env.distribution_hosts,
            custom_mpi_options,
            env.network_interface_name,
        )
    elif identifier is RunnerType.SMDataParallel:
        # Non-master hosts get a worker runner pointed at the master host.
        return mpi.WorkerRunner(
            user_entry_point, args, env_vars, processes_per_host, env.master_hostname
        )
    elif identifier is RunnerType.MPI and env.is_master:
        num_processes = _mpi_param_value(mpi_args, env, params.MPI_NUM_PROCESSES)
        custom_mpi_options = _mpi_param_value(
            mpi_args, env, params.MPI_CUSTOM_OPTIONS, ""
        )
        return mpi.MasterRunner(
            user_entry_point,
            args,
            env_vars,
            processes_per_host,
            env.master_hostname,
            env.distribution_hosts,
            custom_mpi_options,
            env.network_interface_name,
            num_processes=num_processes,
        )
    elif identifier is RunnerType.MPI:
        # Non-master hosts get a worker runner pointed at the master host.
        return mpi.WorkerRunner(
            user_entry_point, args, env_vars, processes_per_host, env.master_hostname
        )
    elif identifier is RunnerType.Process:
        return process.ProcessRunner(user_entry_point, args, env_vars, processes_per_host)
    else:
        # Wrap in a 1-tuple so a tuple-valued identifier is reported in the message
        # instead of making the %-formatting itself raise TypeError.
        raise ValueError("Invalid identifier %s" % (identifier,))
def test_mpi_master_run_python(training_env, popen, policy, ssh_client, python_executable, path_exists):
    """MasterRunner launches a Python entry point via ``mpirun ... -m mpi4py train.py``.

    The collaborators (``popen``, ``ssh_client``, ``policy``, ``path_exists``, ...) are
    presumably mocks supplied by fixtures/patches defined outside this view.
    """
    with patch.dict(os.environ, clear=True):
        master = mpi.MasterRunner(
            user_entry_point="train.py",
            args=["-v", "--lr", "35"],
            env_vars={"LD_CONFIG_PATH": "/etc/ld"},
            master_hostname="algo-1",
            hosts=["algo-1", "algo-2"],
            # Spelled ``processes_per_host`` to match the keyword used by
            # test_mpi_master_run; ``process_per_host`` does not match that signature.
            processes_per_host=2,
            custom_mpi_options="-v --lr 35",
            network_interface_name="ethw3",
        )

        process = master.run(wait=False)

        # The master is expected to connect over SSH to the other host before launching.
        ssh_client().load_system_host_keys.assert_called()
        ssh_client().set_missing_host_key_policy.assert_called_with(policy())
        ssh_client().connect.assert_called_with("algo-2", port=22)
        ssh_client().close.assert_called()

        # NOTE(review): unlike test_mpi_master_run, this expected command has no
        # "-mca btl_vader_single_copy_mechanism none" and is checked against popen
        # rather than an async shell; "usr/bin/python3" (no leading slash) presumably
        # comes from the python_executable fixture. Confirm this matches the current
        # MasterRunner implementation.
        popen.assert_called_with(
            [
                "mpirun",
                "--host",
                "algo-1:2,algo-2:2",
                "-np",
                "4",
                "--allow-run-as-root",
                "--display-map",
                "--tag-output",
                "-mca",
                "btl_tcp_if_include",
                "ethw3",
                "-mca",
                "oob_tcp_if_include",
                "ethw3",
                "-mca",
                "plm_rsh_no_tree_spawn",
                "1",
                "-bind-to",
                "none",
                "-map-by",
                "slot",
                "-mca",
                "pml",
                "ob1",
                "-mca",
                "btl",
                "^openib",
                "-mca",
                "orte_abort_on_non_zero_status",
                "1",
                "-x",
                "NCCL_MIN_NRINGS=4",
                "-x",
                "NCCL_SOCKET_IFNAME=ethw3",
                "-x",
                "NCCL_DEBUG=INFO",
                "-x",
                "LD_LIBRARY_PATH",
                "-x",
                "PATH",
                "-x",
                "LD_PRELOAD=%s" % inspect.getfile(gethostname),
                "-v",
                "--lr",
                "35",
                "-x",
                "LD_CONFIG_PATH",
                "usr/bin/python3",
                "-m",
                "mpi4py",
                "train.py",
                "-v",
                "--lr",
                "35",
            ],
            cwd=environment.code_dir,
            env=ANY,
            stderr=None,
        )

        assert process == popen()
        path_exists.assert_called_with("/usr/sbin/sshd")
def test_mpi_master_run(training_env, async_shell, policy, ssh_client, path_exists, async_gather, event_loop):
    """MasterRunner launches a shell entry point through a single async ``mpirun`` call.

    Checks the SSH handshake with the non-master host, the exact space-joined
    ``mpirun`` command line passed to the async shell, and that the launch is
    gathered exactly once. The collaborators are presumably mocks supplied by
    fixtures/patches defined outside this view.
    """
    with patch.dict(os.environ, clear=True):
        # This variable is expected to be forwarded to the workers with "-x".
        os.environ["AWS_ACCESS_KEY_ID"] = "ABCD"

        runner = mpi.MasterRunner(
            user_entry_point="train.sh",
            args=["-v", "--lr", "35"],
            env_vars={"LD_CONFIG_PATH": "/etc/ld"},
            processes_per_host=2,
            master_hostname="algo-1",
            hosts=["algo-1", "algo-2"],
            custom_mpi_options="-v --lr 35",
            network_interface_name="ethw3",
        )
        launched = runner.run(wait=False)

        client = ssh_client()
        client.load_system_host_keys.assert_called()
        client.set_missing_host_key_policy.assert_called_with(policy())
        client.connect.assert_called_with("algo-2", port=22)
        client.close.assert_called()

        # Assemble the expected mpirun invocation segment by segment.
        expected_tokens = [
            "mpirun",
            "--host",
            "algo-1:2,algo-2:2",
            "-np",
            "4",
            "--allow-run-as-root",
            "--display-map",
            "--tag-output",
        ]
        expected_tokens += ["-mca", "btl_tcp_if_include", "ethw3"]
        expected_tokens += ["-mca", "oob_tcp_if_include", "ethw3"]
        expected_tokens += ["-mca", "plm_rsh_no_tree_spawn", "1"]
        expected_tokens += ["-bind-to", "none", "-map-by", "slot"]
        for mca_key, mca_value in (
            ("pml", "ob1"),
            ("btl", "^openib"),
            ("orte_abort_on_non_zero_status", "1"),
            ("btl_vader_single_copy_mechanism", "none"),
        ):
            expected_tokens += ["-mca", mca_key, mca_value]
        for exported in (
            "NCCL_MIN_NRINGS=4",
            "NCCL_SOCKET_IFNAME=ethw3",
            "NCCL_DEBUG=INFO",
            "LD_LIBRARY_PATH",
            "PATH",
            "LD_PRELOAD=%s" % inspect.getfile(gethostname),
        ):
            expected_tokens += ["-x", exported]
        # Presumably the custom_mpi_options string, spliced in verbatim.
        expected_tokens += ["-v", "--lr", "35"]
        expected_tokens += ["-x", "AWS_ACCESS_KEY_ID", "-x", "LD_CONFIG_PATH"]
        expected_tokens += ["/bin/sh", "-c", '"./train.sh -v --lr 35"']

        async_shell.assert_called_with(
            " ".join(expected_tokens),
            env=ANY,
            cwd=environment.code_dir,
            stdout=asyncio.subprocess.PIPE,
            stderr=None,
        )
        async_shell.assert_called_once()
        async_gather.assert_called_once()

        assert launched == async_shell.return_value
        path_exists.assert_called_with("/usr/sbin/sshd")