示例#1
0
def test_train_for_distributed_scheduler(run_entry_point, verify_hosts,
                                         host_lookup, popen,
                                         distributed_training_env):
    """Scheduler host: spawn scheduler and server processes, then run the user code.

    With the current host set to SCHEDULER, ``training.train`` is expected to
    verify MULTIPLE_HOST_LIST, call ``popen`` with MXNET_COMMAND once per DMLC
    role (scheduler first, then server), and finally invoke
    ``run_entry_point`` with the environment's args/env vars and
    ``runner.ProcessRunnerType``.
    """
    host_lookup.return_value = IP_ADDRESS
    env = distributed_training_env
    env.current_host = SCHEDULER

    training.train(env)

    verify_hosts.assert_called_with(MULTIPLE_HOST_LIST)

    # One popen call per DMLC role; assert_has_calls checks this order.
    expected_popen_calls = [
        call(MXNET_COMMAND, shell=True, env=dict(BASE_ENV_VARS, DMLC_ROLE=role))
        for role in ('scheduler', 'server')
    ]
    popen.assert_has_calls(expected_popen_calls)

    run_entry_point.assert_called_with(
        uri=MODULE_DIR,
        user_entry_point=MODULE_NAME,
        args=env.to_cmd_args(),
        env_vars=env.to_env_vars(),
        runner_type=runner.ProcessRunnerType,
    )
示例#2
0
def test_train_for_single_machine(run_entry_point,
                                  single_machine_training_env):
    """Single machine: train() should invoke the user entry point directly.

    The last ``run_entry_point`` call must carry the environment's command-line
    args and env vars, and must use ``runner.ProcessRunnerType``.
    """
    env = single_machine_training_env
    training.train(env)

    expected_kwargs = dict(
        uri=MODULE_DIR,
        user_entry_point=MODULE_NAME,
        args=env.to_cmd_args(),
        env_vars=env.to_env_vars(),
        runner_type=runner.ProcessRunnerType,
    )
    run_entry_point.assert_called_with(**expected_kwargs)
def test_train_for_single_machine(run_entry_point, download_and_install,
                                  single_machine_training_env):
    """Single machine: fetch/install the module, then run its entry point.

    This variant targets the positional ``run_entry_point`` API, where the
    runner is passed as the ``runner`` keyword
    (``framework.runner.ProcessRunnerType``).
    """
    env = single_machine_training_env
    training.train(env)

    download_and_install.assert_called_with(MODULE_DIR)

    positional = (MODULE_DIR, MODULE_NAME, env.to_cmd_args(), env.to_env_vars())
    run_entry_point.assert_called_with(
        *positional, runner=framework.runner.ProcessRunnerType)
示例#4
0
def test_train_horovod(run_module, single_machine_training_env):
    """When the MPI launch flag is set, train() must use the MPI runner.

    The flag is set by putting ``training.LAUNCH_MPI_ENV_NAME: True`` into
    ``additional_framework_parameters`` before training starts.
    """
    env = single_machine_training_env
    env.additional_framework_parameters = {training.LAUNCH_MPI_ENV_NAME: True}

    training.train(env)

    run_module.assert_called_with(
        uri=MODULE_DIR,
        user_entry_point=MODULE_NAME,
        args=env.to_cmd_args(),
        env_vars=env.to_env_vars(),
        runner_type=runner.MPIRunnerType,
    )
示例#5
0
def test_train_for_distributed_worker(run_module, verify_hosts, host_lookup,
                                      popen, distributed_training_env):
    """Non-scheduler host ('host-2') must launch exactly one MXNet server.

    ``run_module`` is requested but never asserted on here. Presumably the
    fixture exists only so the real call is patched out; verify against
    conftest.
    """
    host_lookup.return_value = IP_ADDRESS

    env = distributed_training_env
    env.current_host = 'host-2'
    training.train(env)

    verify_hosts.assert_called_with(MULTIPLE_HOST_LIST)

    expected_env = dict(BASE_ENV_VARS, DMLC_ROLE='server')
    popen.assert_called_once_with(MXNET_COMMAND, shell=True, env=expected_env)
def test_train_for_distributed_scheduler(run_module, verify_hosts, host_lookup, popen,
                                         distributed_training_env):
    """Scheduler host must spawn a scheduler process followed by a server process.

    Only the ``popen`` calls and host verification are checked. ``run_module``
    is requested but not asserted on (presumably it is only patched out;
    confirm in conftest).
    """
    host_lookup.return_value = IP_ADDRESS

    env = distributed_training_env
    env.current_host = SCHEDULER
    training.train(env)

    verify_hosts.assert_called_with(MULTIPLE_HOST_LIST)

    # Roles in the exact order assert_has_calls expects them.
    roles = ('scheduler', 'server')
    expected_calls = []
    for role in roles:
        role_env = dict(BASE_ENV_VARS, DMLC_ROLE=role)
        expected_calls.append(call(MXNET_COMMAND, shell=True, env=role_env))

    popen.assert_has_calls(expected_calls)
def test_train_for_distributed_worker(run_entry_point, download_and_install,
                                      verify_hosts, host_lookup, popen,
                                      distributed_training_env):
    """Worker host ('host-2'): start one server, install the module, run it.

    Checks, in this order:
      * host verification against MULTIPLE_HOST_LIST;
      * exactly one ``popen`` of MXNET_COMMAND with DMLC_ROLE='server';
      * ``download_and_install`` called with MODULE_DIR;
      * ``run_entry_point`` called positionally with
        ``runner=framework.runner.ProcessRunnerType``.
    """
    host_lookup.return_value = IP_ADDRESS

    env = distributed_training_env
    env.current_host = 'host-2'
    training.train(env)

    verify_hosts.assert_called_with(MULTIPLE_HOST_LIST)

    server_role_env = dict(BASE_ENV_VARS, DMLC_ROLE='server')
    popen.assert_called_once_with(MXNET_COMMAND, shell=True, env=server_role_env)

    download_and_install.assert_called_with(MODULE_DIR)

    positional = (MODULE_DIR, MODULE_NAME, env.to_cmd_args(), env.to_env_vars())
    run_entry_point.assert_called_with(
        *positional, runner=framework.runner.ProcessRunnerType)
示例#8
0
def test_train_for_single_machine(run_module, single_machine_training_env):
    """Single machine: train() calls run_module with the positional signature.

    Expected argument order is
    (MODULE_DIR, cmd args, env vars, MODULE_NAME).
    """
    env = single_machine_training_env
    training.train(env)

    cmd_args = env.to_cmd_args()
    env_vars = env.to_env_vars()
    run_module.assert_called_with(MODULE_DIR, cmd_args, env_vars, MODULE_NAME)