def test_user_script_error_raised(run_entry_point, training_env):
    """A user-script failure unrelated to gloo::EnforceNotMet must propagate out of train()."""
    message = "Not 'gloo::EnforceNotMet' exception."
    if six.PY3:
        # On Python 3 the error's output is supplied as latin-1 encoded bytes;
        # on Python 2 the plain str is used.
        message = message.encode('latin1')

    failing_cmd = 'Command "/usr/bin/python -m userscript"'
    run_entry_point.side_effect = errors.ExecuteUserScriptError(cmd=failing_cmd, output=message)

    with pytest.raises(errors.ExecuteUserScriptError):
        train(training_env)
def test_gloo_exception_intercepted(run_entry_point, training_env):
    """A gloo::EnforceNotMet crash reported by the user script must not make train() raise."""
    message = "terminate called after throwing an instance of 'gloo::EnforceNotMet'"
    if six.PY3:
        # On Python 3 the error's output is supplied as latin-1 encoded bytes.
        message = message.encode('latin1')

    failing_cmd = 'Command "/usr/bin/python -m userscript"'
    run_entry_point.side_effect = errors.ExecuteUserScriptError(cmd=failing_cmd, output=message)

    # No pytest.raises here: the test passes only if train() swallows this error.
    train(training_env)

    run_entry_point.assert_called()
def test_train(run_module, training_env):
    """train() forwards module dir, CLI args, env vars and module name to run_module."""
    train(training_env)

    expected_args = (
        training_env.module_dir,
        training_env.to_cmd_args(),
        training_env.to_env_vars(),
        training_env.module_name,
    )
    run_module.assert_called_with(*expected_args)
def test_train_with_missing_parameters(training_env, user_module):
    """train() raises TypeError when the user's train function has an unsupported signature.

    The replacement function requires a parameter named ``missing_param``;
    presumably train() cannot supply it, hence the TypeError.
    """
    # Name and parameter kept as in the original in case train() inspects them.
    def user_module_train(missing_param):
        return nn.Module()

    user_module.train = user_module_train

    with pytest.raises(TypeError):
        train(user_module, training_env)
def test_train(run_entry_point, training_env):
    """train() invokes run_entry_point with keyword arguments taken from training_env."""
    train(training_env)

    expected_kwargs = dict(
        uri=training_env.module_dir,
        user_entry_point=training_env.user_entry_point,
        args=training_env.to_cmd_args(),
        env_vars=training_env.to_env_vars(),
        capture_error=True,
        runner_type=runner.ProcessRunnerType,
    )
    run_entry_point.assert_called_with(**expected_kwargs)
def test_environment(training_env):
    """After train(), the distributed-training and NCCL environment variables are set."""
    train(training_env)

    env = os.environ
    # distributed training specific environment
    assert env['MASTER_PORT'] == MASTER_PORT
    assert env['MASTER_ADDR'] == training_env.hosts[0]
    # nccl specific environment
    assert env['NCCL_SOCKET_IFNAME'] == training_env.network_interface_name
    assert env['NCCL_IB_DISABLE'] == '1'
    assert env['NCCL_DEBUG'] == 'WARN'
def test_train(run_entry_point, download_and_install, training_env):
    """train() installs the user module, then runs its entry point with capture_error=True."""
    train(training_env)

    download_and_install.assert_called_with(training_env.module_dir)

    positional = (
        training_env.module_dir,
        training_env.user_entry_point,
        training_env.to_cmd_args(),
        training_env.to_env_vars(),
    )
    run_entry_point.assert_called_with(
        *positional,
        capture_error=True,
        runner=framework.runner.ProcessRunnerType,
    )