コード例 #1
0
def _get_by_runner_type(identifier,
                        user_entry_point=None,
                        args=None,
                        env_vars=None,
                        extra_opts=None):
    """Build the runner object that corresponds to ``identifier``.

    Any argument left as ``None`` (or otherwise falsy) is filled in from a
    freshly constructed ``environment.Environment()``.

    Args:
        identifier (RunnerType): which kind of runner to create.
        user_entry_point: entry point to execute; defaults to
            ``env.user_entry_point``.
        args: command-line arguments; defaults to ``env.to_cmd_args()``.
        env_vars: environment variables; defaults to ``env.to_env_vars()``.
        extra_opts: MPI option overrides handed to ``_mpi_param_value``;
            defaults to an empty dict.

    Returns:
        - ``smdataparallel.SMDataParallelRunner`` for
          ``RunnerType.SMDataParallel`` on the master host,
        - ``mpi.MasterRunner`` for ``RunnerType.MPI`` on the master host,
        - ``mpi.WorkerRunner`` for either of those two types on other hosts,
        - ``process.ProcessRunner`` for ``RunnerType.Process``.

    Raises:
        ValueError: if ``identifier`` is not one of the supported types.
    """
    env = environment.Environment()
    # NOTE: `or` means explicitly passed empty values ([] / {} / "") are
    # also replaced by the environment-derived defaults.
    user_entry_point = user_entry_point or env.user_entry_point
    args = args or env.to_cmd_args()
    env_vars = env_vars or env.to_env_vars()
    mpi_args = extra_opts or {}

    # Default to one process per GPU, or a single process on CPU-only hosts.
    default_processes_per_host = max(int(env.num_gpus), 1)
    # Final value is resolved by _mpi_param_value (presumably preferring an
    # override from mpi_args / env over the default) — not visible here.
    processes_per_host = _mpi_param_value(mpi_args, env,
                                          params.MPI_PROCESSES_PER_HOST,
                                          default_processes_per_host)

    if identifier is RunnerType.Process:
        return process.ProcessRunner(user_entry_point, args, env_vars,
                                     processes_per_host)

    if (identifier is not RunnerType.SMDataParallel
            and identifier is not RunnerType.MPI):
        # Wrap in a 1-tuple so a tuple-valued identifier cannot break the
        # % formatting (which would raise TypeError instead of ValueError).
        raise ValueError("Invalid identifier %s" % (identifier,))

    if not env.is_master:
        # Non-master hosts use the same worker runner for both
        # distributed runner types.
        return mpi.WorkerRunner(user_entry_point, args, env_vars,
                                processes_per_host, env.master_hostname)

    if identifier is RunnerType.SMDataParallel:
        custom_mpi_options = _mpi_param_value(
            mpi_args, env, params.SMDATAPARALLEL_CUSTOM_MPI_OPTIONS, "")
        return smdataparallel.SMDataParallelRunner(
            user_entry_point,
            args,
            env_vars,
            processes_per_host,
            env.master_hostname,
            env.distribution_hosts,
            custom_mpi_options,
            env.network_interface_name,
        )

    # identifier is RunnerType.MPI on the master host.
    num_processes = _mpi_param_value(mpi_args, env,
                                     params.MPI_NUM_PROCESSES)
    custom_mpi_options = _mpi_param_value(mpi_args, env,
                                          params.MPI_CUSTOM_OPTIONS, "")
    return mpi.MasterRunner(
        user_entry_point,
        args,
        env_vars,
        processes_per_host,
        env.master_hostname,
        env.distribution_hosts,
        custom_mpi_options,
        env.network_interface_name,
        num_processes=num_processes,
    )
コード例 #2
0
def test_mpi_worker_run(popen, policy, process_iter, wait_procs, ssh_client,
                        sleep, path_exists, write_env_vars):
    """Exercise a blocking ``mpi.WorkerRunner.run()`` against mocked calls.

    All parameters are mocks (presumably injected by ``mock.patch``
    decorators outside this view); ``sleep`` is accepted but not inspected.
    The run is expected to write env vars, launch sshd, open an SSH
    connection to the master host, and wait on the discovered "orted"
    process.
    """
    # Fake process that process_iter will report, named "orted".
    orted_proc = MagicMock(info={"name": "orted"})
    process_iter.side_effect = lambda attrs: [orted_proc]

    runner = mpi.WorkerRunner(
        user_entry_point="train.sh",
        args=["-v", "--lr", "35"],
        env_vars={"LD_CONFIG_PATH": "/etc/ld"},
        master_hostname="algo-1",
    )
    runner.run()

    write_env_vars.assert_called_once()

    # sshd is looked up on disk and launched with "-D".
    path_exists.assert_called_with("/usr/sbin/sshd")
    popen.assert_called_with(["/usr/sbin/sshd", "-D"])

    # SSH client interaction with the master host.
    client = ssh_client.return_value
    client.load_system_host_keys.assert_called()
    client.set_missing_host_key_policy.assert_called_with(policy.return_value)
    client.connect.assert_called_with("algo-1", port=22)
    client.close.assert_called()

    # The runner waits on exactly the process list it discovered.
    wait_procs.assert_called_with([orted_proc])
コード例 #3
0
def test_mpi_worker_run_no_wait(popen, ssh_client, path_exists):
    """Check that ``mpi.WorkerRunner.run(wait=False)`` never creates an SSH client.

    sshd is still looked up and launched with "-D". All parameters are mocks
    (presumably injected by ``mock.patch`` decorators outside this view).
    """
    runner = mpi.WorkerRunner(
        master_hostname="algo-1",
        user_entry_point="train.sh",
        env_vars={"LD_CONFIG_PATH": "/etc/ld"},
        args=["-v", "--lr", "35"],
    )

    runner.run(wait=False)

    # sshd start-up happens regardless of wait.
    path_exists.assert_called_with("/usr/sbin/sshd")
    popen.assert_called_with(["/usr/sbin/sshd", "-D"])

    # Without waiting, no SSH connection to the master is attempted.
    ssh_client.assert_not_called()