예제 #1
0
def training_env():
    """Build a TrainingEnv from the configuration readers in ``_env``.

    The resource config, input data config and hyperparameters are each
    obtained from the corresponding ``_env.read_*`` helper, in that order,
    and passed to the ``_env.TrainingEnv`` constructor.

    Returns:
        TrainingEnv: a newly constructed TrainingEnv instance.
    """
    # Imported lazily so the package is only required when this is called.
    from sagemaker_containers import _env

    resource_config = _env.read_resource_config()
    input_data_config = _env.read_input_data_config()
    hyperparameters = _env.read_hyperparameters()

    return _env.TrainingEnv(
        resource_config=resource_config,
        input_data_config=input_data_config,
        hyperparameters=hyperparameters,
    )
예제 #2
0
def training_env(resource_config=None,
                 input_data_config=None,
                 hyperparameters=None):
    """Create a TrainingEnv, reading any configuration the caller omitted.

    Args:
        resource_config: resource configuration. When None, it is obtained
            from ``env.read_resource_config()``.
        input_data_config: input data configuration. When None, it is
            obtained from ``env.read_input_data_config()``.
        hyperparameters: hyperparameters. When None, they are obtained from
            ``env.read_hyperparameters()``.

    Returns:
        env.TrainingEnv: the environment built from the three configurations.
    """
    # Test against None instead of relying on truthiness. Otherwise an
    # explicitly supplied empty config (e.g. ``hyperparameters={}``) would be
    # silently discarded and replaced by the value read via ``env``.
    if resource_config is None:
        resource_config = env.read_resource_config()
    if input_data_config is None:
        input_data_config = env.read_input_data_config()
    if hyperparameters is None:
        hyperparameters = env.read_hyperparameters()

    return env.TrainingEnv(resource_config=resource_config,
                           input_data_config=input_data_config,
                           hyperparameters=hyperparameters)
예제 #3
0
def test_read_value_serialized_and_non_value_serialized_hyperparameters():
    """A mix of JSON-encoded and raw hyperparameters reads back as ALL_HYPERPARAMETERS.

    Values from SAGEMAKER_HYPERPARAMETERS are JSON-encoded before being
    written, while USER_HYPERPARAMETERS are written unchanged (and win on
    any key collision, since they are merged last).
    """
    mixed = {}
    for name, value in SAGEMAKER_HYPERPARAMETERS.items():
        mixed[name] = json.dumps(value)
    mixed.update(USER_HYPERPARAMETERS)

    test.write_json(mixed, _env.hyperparameters_file_dir)

    loaded = _env.read_hyperparameters()
    assert loaded == ALL_HYPERPARAMETERS
예제 #4
0
def training_env(resource_config=None, input_data_config=None, hyperparameters=None):
    """Create a TrainingEnv, reading any configuration the caller omitted.

    Args:
        resource_config: resource configuration. When None, it is obtained
            from ``env.read_resource_config()``.
        input_data_config: input data configuration. When None, it is
            obtained from ``env.read_input_data_config()``.
        hyperparameters: hyperparameters. When None, they are obtained from
            ``env.read_hyperparameters()``.

    Returns:
        env.TrainingEnv: the environment built from the three configurations.
    """
    # Test against None instead of relying on truthiness. Otherwise an
    # explicitly supplied empty config (e.g. ``hyperparameters={}``) would be
    # silently discarded and replaced by the value read via ``env``.
    if resource_config is None:
        resource_config = env.read_resource_config()
    if input_data_config is None:
        input_data_config = env.read_input_data_config()
    if hyperparameters is None:
        hyperparameters = env.read_hyperparameters()

    return env.TrainingEnv(
        resource_config=resource_config,
        input_data_config=input_data_config,
        hyperparameters=hyperparameters,
    )
예제 #5
0
def test_read_exception(loads):
    """read_hyperparameters still returns {'a': 1} when ``loads`` raises ValueError."""
    # ``loads`` is injected by the caller (presumably a patch decorator or
    # fixture outside this view); every call to it will now raise.
    failure = ValueError('Unable to read.')
    loads.side_effect = failure

    result = _env.read_hyperparameters()
    assert result == {'a': 1}
예제 #6
0
def test_read_value_serialized_hyperparameters():
    """Hyperparameters that are all JSON-encoded on disk read back as ALL_HYPERPARAMETERS."""
    encoded = {}
    for name, value in ALL_HYPERPARAMETERS.items():
        encoded[name] = json.dumps(value)

    test.write_json(encoded, _env.hyperparameters_file_dir)

    loaded = _env.read_hyperparameters()
    assert loaded == ALL_HYPERPARAMETERS
예제 #7
0
def test_read_hyperparameters():
    """Hyperparameters written as-is read back unchanged via read_hyperparameters."""
    target_dir = _env.hyperparameters_file_dir
    test.write_json(ALL_HYPERPARAMETERS, target_dir)

    loaded = _env.read_hyperparameters()
    assert loaded == ALL_HYPERPARAMETERS