def test_train_for_distributed_scheduler(run_entry_point, verify_hosts, host_lookup, popen,
                                         distributed_training_env):
    host_lookup.return_value = IP_ADDRESS
    distributed_training_env.current_host = SCHEDULER

    training.train(distributed_training_env)

    verify_hosts.assert_called_with(MULTIPLE_HOST_LIST)

    scheduler_env = BASE_ENV_VARS.copy()
    scheduler_env.update({'DMLC_ROLE': 'scheduler'})
    server_env = BASE_ENV_VARS.copy()
    server_env.update({'DMLC_ROLE': 'server'})

    calls = [call(MXNET_COMMAND, shell=True, env=scheduler_env),
             call(MXNET_COMMAND, shell=True, env=server_env)]
    popen.assert_has_calls(calls)

    run_entry_point.assert_called_with(uri=MODULE_DIR,
                                       user_entry_point=MODULE_NAME,
                                       args=distributed_training_env.to_cmd_args(),
                                       env_vars=distributed_training_env.to_env_vars(),
                                       runner_type=runner.ProcessRunnerType)
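# ---------------------------------------------------------------------------
# The tests in this file reference module-level constants and imports that are
# defined elsewhere in the suite. A minimal sketch of plausible definitions
# follows; every concrete value below is a hypothetical placeholder, not the
# suite's actual configuration. The package imports are likewise assumed:
#
#     from sagemaker_mxnet_container import training  # module under test (assumed path)
#     from sagemaker_training import runner           # runner.ProcessRunnerType etc. (assumed path)
# ---------------------------------------------------------------------------
from unittest.mock import call  # used by popen.assert_has_calls above

MODULE_DIR = 's3://my-bucket/module'                 # hypothetical module location
MODULE_NAME = 'user_script'                          # hypothetical entry-point name
SCHEDULER = 'host-1'                                 # hypothetical scheduler host
MULTIPLE_HOST_LIST = ['host-1', 'host-2', 'host-3']  # hypothetical cluster hosts
IP_ADDRESS = '10.0.0.1'                              # hypothetical resolved scheduler address
MXNET_COMMAND = 'python -c "import mxnet"'           # hypothetical parameter-server command
BASE_ENV_VARS = {                                    # hypothetical DMLC env shared by all roles
    'DMLC_NUM_SERVER': str(len(MULTIPLE_HOST_LIST)),
    'DMLC_NUM_WORKER': str(len(MULTIPLE_HOST_LIST)),
    'DMLC_PS_ROOT_URI': IP_ADDRESS,
    'DMLC_PS_ROOT_PORT': '8000',
}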
def test_train_for_single_machine(run_entry_point, single_machine_training_env):
    training.train(single_machine_training_env)

    run_entry_point.assert_called_with(uri=MODULE_DIR,
                                       user_entry_point=MODULE_NAME,
                                       args=single_machine_training_env.to_cmd_args(),
                                       env_vars=single_machine_training_env.to_env_vars(),
                                       runner_type=runner.ProcessRunnerType)
def test_train_for_single_machine_installs_module(run_entry_point, download_and_install,
                                                  single_machine_training_env):
    training.train(single_machine_training_env)

    download_and_install.assert_called_with(MODULE_DIR)
    run_entry_point.assert_called_with(MODULE_DIR,
                                       MODULE_NAME,
                                       single_machine_training_env.to_cmd_args(),
                                       single_machine_training_env.to_env_vars(),
                                       runner=framework.runner.ProcessRunnerType)
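# The `framework` name above is assumed to come from the legacy containers
# library, e.g.:
#
#     import sagemaker_containers.beta.framework as framework
#
# which exposes framework.runner.ProcessRunnerType in that API generation,
# whereas the keyword-argument tests target the newer runner module directly.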
def test_train_horovod(run_module, single_machine_training_env):
    single_machine_training_env.additional_framework_parameters = {
        training.LAUNCH_MPI_ENV_NAME: True,
    }

    training.train(single_machine_training_env)

    run_module.assert_called_with(uri=MODULE_DIR,
                                  user_entry_point=MODULE_NAME,
                                  args=single_machine_training_env.to_cmd_args(),
                                  env_vars=single_machine_training_env.to_env_vars(),
                                  runner_type=runner.MPIRunnerType)
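# ---------------------------------------------------------------------------
# The mock arguments (run_module, run_entry_point, popen, ...) and the env
# objects are supplied by the test harness. A minimal fixture-style sketch
# follows, assuming unittest.mock; the patch target is an illustrative guess,
# not necessarily the suite's real module path.
# ---------------------------------------------------------------------------
import pytest
from unittest.mock import MagicMock, patch


@pytest.fixture
def single_machine_training_env():
    env = MagicMock()                         # stands in for the container's training env object
    env.module_dir = MODULE_DIR
    env.user_entry_point = MODULE_NAME
    env.additional_framework_parameters = {}  # overridden by test_train_horovod above
    return env


@pytest.fixture
def run_module():
    # hypothetical patch target; the real path depends on how training.py
    # imports its runner API
    with patch('sagemaker_mxnet_container.training.framework.modules.run_module') as mock:
        yield mock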
def test_train_for_distributed_worker(run_module, verify_hosts, host_lookup, popen,
                                      distributed_training_env):
    host_lookup.return_value = IP_ADDRESS
    distributed_training_env.current_host = 'host-2'

    training.train(distributed_training_env)

    verify_hosts.assert_called_with(MULTIPLE_HOST_LIST)

    server_env = BASE_ENV_VARS.copy()
    server_env.update({'DMLC_ROLE': 'server'})

    popen.assert_called_once_with(MXNET_COMMAND, shell=True, env=server_env)
def test_train_for_distributed_scheduler_run_module(run_module, verify_hosts, host_lookup,
                                                    popen, distributed_training_env):
    host_lookup.return_value = IP_ADDRESS
    distributed_training_env.current_host = SCHEDULER

    training.train(distributed_training_env)

    verify_hosts.assert_called_with(MULTIPLE_HOST_LIST)

    scheduler_env = BASE_ENV_VARS.copy()
    scheduler_env.update({'DMLC_ROLE': 'scheduler'})
    server_env = BASE_ENV_VARS.copy()
    server_env.update({'DMLC_ROLE': 'server'})

    calls = [call(MXNET_COMMAND, shell=True, env=scheduler_env),
             call(MXNET_COMMAND, shell=True, env=server_env)]
    popen.assert_has_calls(calls)
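# ---------------------------------------------------------------------------
# In the decorator style, the distributed tests' mock arguments would instead
# be injected by stacked @patch calls. The patch targets shown are
# hypothetical module paths, for illustration only:
#
#     @patch('subprocess.Popen')
#     @patch('sagemaker_mxnet_container.training._host_lookup')
#     @patch('sagemaker_mxnet_container.training._verify_hosts')
#     @patch('sagemaker_training.entry_point.run')
#     def test_train_for_distributed_scheduler(run_entry_point, verify_hosts,
#                                              host_lookup, popen,
#                                              distributed_training_env):
#         ...
#
# Note the reversal: the decorator closest to the function supplies the first
# mock argument.
# ---------------------------------------------------------------------------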
def test_train_for_distributed_worker_installs_module(run_entry_point, download_and_install,
                                                      verify_hosts, host_lookup, popen,
                                                      distributed_training_env):
    host_lookup.return_value = IP_ADDRESS
    distributed_training_env.current_host = 'host-2'

    training.train(distributed_training_env)

    verify_hosts.assert_called_with(MULTIPLE_HOST_LIST)

    server_env = BASE_ENV_VARS.copy()
    server_env.update({'DMLC_ROLE': 'server'})

    popen.assert_called_once_with(MXNET_COMMAND, shell=True, env=server_env)

    download_and_install.assert_called_with(MODULE_DIR)
    run_entry_point.assert_called_with(MODULE_DIR,
                                       MODULE_NAME,
                                       distributed_training_env.to_cmd_args(),
                                       distributed_training_env.to_env_vars(),
                                       runner=framework.runner.ProcessRunnerType)
def test_train_for_single_machine_run_module(run_module, single_machine_training_env):
    training.train(single_machine_training_env)

    run_module.assert_called_with(MODULE_DIR,
                                  single_machine_training_env.to_cmd_args(),
                                  single_machine_training_env.to_env_vars(),
                                  MODULE_NAME)
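# ---------------------------------------------------------------------------
# These are plain pytest tests; assuming a conventional layout, they can be
# run with, e.g.:
#
#     pytest test/unit/test_training.py -k train
#
# (the file path is illustrative).
# ---------------------------------------------------------------------------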