def _get_by_runner_type(
    identifier, user_entry_point=None, args=None, env_vars=None, extra_opts=None
):
    """Build the runner object matching ``identifier`` for the current host.

    Any of ``user_entry_point``, ``args`` and ``env_vars`` that is falsy is
    filled in from a freshly created ``environment.Environment()``.

    Args:
        identifier: a ``RunnerType`` member (compared by identity) selecting
            SMDataParallel, MPI or plain-process execution.
        user_entry_point: entry point to run; defaults to
            ``env.user_entry_point``.
        args: command-line arguments; defaults to ``env.to_cmd_args()``.
        env_vars: environment variables; defaults to ``env.to_env_vars()``.
        extra_opts: mapping of MPI / SMDataParallel options, read through
            ``_mpi_param_value``; defaults to an empty dict.

    Returns:
        ``smdataparallel.SMDataParallelRunner`` or ``mpi.MasterRunner`` on the
        master host, ``mpi.WorkerRunner`` on other hosts for the distributed
        runner types, or ``process.ProcessRunner`` for ``RunnerType.Process``.

    Raises:
        ValueError: if ``identifier`` is not one of the supported runner types.
    """
    env = environment.Environment()
    # NOTE(review): ``or`` means an explicitly passed empty list/dict is also
    # replaced by the environment default, not only ``None``.
    user_entry_point = user_entry_point or env.user_entry_point
    args = args or env.to_cmd_args()
    env_vars = env_vars or env.to_env_vars()
    mpi_args = extra_opts or {}

    # One process per GPU; a CPU-only host (0 GPUs) defaults to a single process.
    num_gpus = int(env.num_gpus)
    default_processes_per_host = num_gpus if num_gpus > 0 else 1

    processes_per_host = _mpi_param_value(
        mpi_args, env, params.MPI_PROCESSES_PER_HOST, default_processes_per_host
    )

    if identifier is RunnerType.SMDataParallel and env.is_master:
        custom_mpi_options = _mpi_param_value(
            mpi_args, env, params.SMDATAPARALLEL_CUSTOM_MPI_OPTIONS, ""
        )
        return smdataparallel.SMDataParallelRunner(
            user_entry_point,
            args,
            env_vars,
            processes_per_host,
            env.master_hostname,
            env.distribution_hosts,
            custom_mpi_options,
            env.network_interface_name,
        )

    if identifier is RunnerType.MPI and env.is_master:
        num_processes = _mpi_param_value(mpi_args, env, params.MPI_NUM_PROCESSES)
        custom_mpi_options = _mpi_param_value(
            mpi_args, env, params.MPI_CUSTOM_OPTIONS, ""
        )
        return mpi.MasterRunner(
            user_entry_point,
            args,
            env_vars,
            processes_per_host,
            env.master_hostname,
            env.distribution_hosts,
            custom_mpi_options,
            env.network_interface_name,
            num_processes=num_processes,
        )

    # Non-master hosts of both distributed runner types use the same MPI worker.
    if identifier is RunnerType.SMDataParallel or identifier is RunnerType.MPI:
        return mpi.WorkerRunner(
            user_entry_point, args, env_vars, processes_per_host, env.master_hostname
        )

    if identifier is RunnerType.Process:
        return process.ProcessRunner(
            user_entry_point, args, env_vars, processes_per_host
        )

    # Wrap in a 1-tuple so a tuple-valued identifier still formats correctly
    # and callers get the documented ValueError rather than a TypeError.
    raise ValueError("Invalid identifier %s" % (identifier,))
def test_mpi_worker_run(popen, policy, process_iter, wait_procs, ssh_client, sleep, path_exists, write_env_vars):
    """Check that a blocking ``mpi.WorkerRunner.run()`` performs the worker steps.

    The test asserts that the run:
    - writes env vars once;
    - connects to the master ``algo-1`` on port 22 through the SSH client mock;
    - launches ``/usr/sbin/sshd -D``;
    - waits on the ``orted`` process reported by ``process_iter``.

    The mock arguments are presumably injected by ``@patch`` decorators
    outside this view; ``sleep`` is unused but kept for the patch signature.
    """
    orted_proc = MagicMock(info={"name": "orted"})

    def _only_orted(attrs):
        return [orted_proc]

    process_iter.side_effect = _only_orted

    runner = mpi.WorkerRunner(
        user_entry_point="train.sh",
        args=["-v", "--lr", "35"],
        env_vars={"LD_CONFIG_PATH": "/etc/ld"},
        master_hostname="algo-1",
    )
    runner.run()

    write_env_vars.assert_called_once()

    # Interactions with the SSH client instance used to reach the master.
    client = ssh_client()
    client.load_system_host_keys.assert_called()
    client.set_missing_host_key_policy.assert_called_with(policy())
    client.connect.assert_called_with("algo-1", port=22)
    client.close.assert_called()

    wait_procs.assert_called_with([orted_proc])
    popen.assert_called_with(["/usr/sbin/sshd", "-D"])
    path_exists.assert_called_with("/usr/sbin/sshd")
def test_mpi_worker_run_no_wait(popen, ssh_client, path_exists):
    """Check ``mpi.WorkerRunner.run(wait=False)`` starts sshd but skips SSH.

    The mock arguments are presumably injected by ``@patch`` decorators
    outside this view.
    """
    runner_kwargs = dict(
        user_entry_point="train.sh",
        args=["-v", "--lr", "35"],
        env_vars={"LD_CONFIG_PATH": "/etc/ld"},
        master_hostname="algo-1",
    )
    mpi.WorkerRunner(**runner_kwargs).run(wait=False)

    # With wait=False the SSH client mock must never be instantiated.
    ssh_client.assert_not_called()
    popen.assert_called_with(["/usr/sbin/sshd", "-D"])
    path_exists.assert_called_with("/usr/sbin/sshd")