コード例 #1
0
def test_user_script_error_raised(run_entry_point, training_env):
    """Check that train() lets an unrelated ExecuteUserScriptError propagate.

    The mocked entry point fails with output that mentions the gloo name but
    is not the "terminate called after throwing an instance of ..." message
    used in the interception test. train() is expected to re-raise it.
    """
    message = 'Not \'gloo::EnforceNotMet\' exception.'
    # The error carries bytes on Python 3 and a native str on Python 2.
    if six.PY3:
        payload = message.encode('latin1')
    else:
        payload = message

    run_entry_point.side_effect = errors.ExecuteUserScriptError(
        cmd='Command "/usr/bin/python -m userscript"',
        output=payload)

    with pytest.raises(errors.ExecuteUserScriptError):
        train(training_env)
コード例 #2
0
def test_gloo_exception_intercepted(run_entry_point, training_env):
    """Check that train() swallows a gloo::EnforceNotMet termination error.

    The mocked entry point raises ExecuteUserScriptError whose output is the
    C++ "terminate called after throwing ..." line for gloo::EnforceNotMet.
    train() must return normally, and the entry point must have been invoked.
    """
    message = 'terminate called after throwing an instance of \'gloo::EnforceNotMet\''
    # The error carries bytes on Python 3 and a native str on Python 2.
    if six.PY3:
        payload = message.encode('latin1')
    else:
        payload = message

    run_entry_point.side_effect = errors.ExecuteUserScriptError(
        cmd='Command "/usr/bin/python -m userscript"',
        output=payload)

    # No pytest.raises here: reaching the next line means the error was intercepted.
    train(training_env)

    run_entry_point.assert_called()
コード例 #3
0
def test_train(run_module, training_env):
    """Check that train() forwards the environment to run_module positionally.

    The expected call is (module_dir, cmd args, env vars, module_name), in
    that order, taken from the training environment.
    """
    train(training_env)

    expected_positional = (
        training_env.module_dir,
        training_env.to_cmd_args(),
        training_env.to_env_vars(),
        training_env.module_name,
    )
    run_module.assert_called_with(*expected_positional)
コード例 #4
0
def test_train_with_missing_parameters(training_env, user_module):
    """Check that train() raises TypeError for a user train() it cannot call.

    The user module's train function is replaced with one requiring a
    parameter named ``missing_param``. Presumably train() does not supply
    that argument (TODO confirm against train's implementation), so the call
    is expected to fail with TypeError.
    """
    def user_module_train(missing_param):
        # Never expected to run; the return value only mirrors a real train().
        return nn.Module()

    user_module.train = user_module_train

    with pytest.raises(TypeError):
        train(user_module, training_env)
コード例 #5
0
def test_train(run_entry_point, training_env):
    """Check that train() calls the entry point with the expected keyword arguments.

    All arguments are passed by keyword. Error output is captured, and the
    runner type is ``runner.ProcessRunnerType``.
    """
    train(training_env)

    expected_kwargs = dict(
        uri=training_env.module_dir,
        user_entry_point=training_env.user_entry_point,
        args=training_env.to_cmd_args(),
        env_vars=training_env.to_env_vars(),
        capture_error=True,
        runner_type=runner.ProcessRunnerType,
    )
    run_entry_point.assert_called_with(**expected_kwargs)
コード例 #6
0
def test_environment(training_env):
    """Check the process environment variables that train() leaves behind.

    After train() returns, os.environ must hold the distributed-training
    rendezvous settings and the NCCL settings listed below.
    """
    train(training_env)

    # Pairs of (environment variable, expected value). A tuple keeps the
    # check order fixed regardless of dict ordering.
    expected_env = (
        # distributed training specific environment
        ('MASTER_PORT', MASTER_PORT),
        ('MASTER_ADDR', training_env.hosts[0]),
        # nccl specific environment
        ('NCCL_SOCKET_IFNAME', training_env.network_interface_name),
        ('NCCL_IB_DISABLE', '1'),
        ('NCCL_DEBUG', 'WARN'),
    )
    for var_name, wanted in expected_env:
        assert wanted == os.environ[var_name], var_name
コード例 #7
0
def test_train(run_entry_point, download_and_install, training_env):
    """Check that train() installs the module directory, then runs the entry point.

    download_and_install must receive the module directory. run_entry_point
    must receive four positional arguments (module dir, entry point, cmd
    args, env vars) plus capture_error=True and the process runner type as
    keywords.
    """
    train(training_env)

    module_dir = training_env.module_dir
    download_and_install.assert_called_with(module_dir)

    positional = (
        module_dir,
        training_env.user_entry_point,
        training_env.to_cmd_args(),
        training_env.to_env_vars(),
    )
    run_entry_point.assert_called_with(
        *positional,
        capture_error=True,
        runner=framework.runner.ProcessRunnerType)