def test_read_value_serialized_and_non_value_serialized_hyperparameters():
    """A mix of JSON-encoded and raw hyperparameter values is read back as ALL_HYPERPARAMETERS.

    The SageMaker hyperparameters are written JSON-encoded, while the user
    hyperparameters are written as-is on top of them.
    """
    written = {}
    for name, value in SAGEMAKER_HYPERPARAMETERS.items():
        written[name] = json.dumps(value)
    # User values are layered over the encoded SageMaker ones without encoding.
    written.update(USER_HYPERPARAMETERS)

    test.write_json(written, environment.hyperparameters_file_dir)

    loaded = environment.read_hyperparameters()
    assert loaded == ALL_HYPERPARAMETERS
def main():
    """Training entry point.

    Reads the hyperparameters, builds an ``environment.Environment`` from them,
    configures ``s3_utils`` for the target ``model_dir``, runs ``train`` with
    the user hyperparameters converted to command-line arguments, and finally
    calls ``_log_model_missing_warning`` on ``MODEL_DIR``.

    When the raw hyperparameters contain ``_tuning_objective_metric``, the job
    is treated as one of several tuning jobs. Its ``model_dir`` is then
    rewritten to include the training job name.
    """
    hyperparameters = environment.read_hyperparameters()
    env = environment.Environment(hyperparameters=hyperparameters)

    # NOTE(review): this is ``env.hyperparameters`` itself, not a copy, so the
    # model_dir assignment below is visible through ``env`` as well (assuming
    # the attribute returns the same dict each time -- confirm).
    user_hyperparameters = env.hyperparameters

    # If the training job is part of the multiple training jobs for tuning, we need to append the training job name to
    # model_dir in case they read from/write to the same object
    if "_tuning_objective_metric" in hyperparameters:
        model_dir = _model_dir_with_training_job(hyperparameters.get("model_dir"), env.job_name)
        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logger.info("Appending the training job name to model_dir: %s", model_dir)
        user_hyperparameters["model_dir"] = model_dir

    s3_utils.configure(user_hyperparameters.get("model_dir"), os.environ.get("SAGEMAKER_REGION"))
    train(env, mapping.to_cmd_args(user_hyperparameters))
    _log_model_missing_warning(MODEL_DIR)
def test_read_exception(loads):
    """read_hyperparameters still returns {"a": 1} when ``loads`` raises ValueError.

    ``loads`` is injected from outside this function. Presumably it is a mock
    of the JSON decoder used by ``environment`` -- TODO confirm against the
    decorator or fixture that supplies it. The {"a": 1} expectation is assumed
    to come from hyperparameters already written elsewhere (verify).
    """
    failure = ValueError("Unable to read.")
    loads.side_effect = failure

    result = environment.read_hyperparameters()
    assert result == {"a": 1}
def test_read_value_serialized_hyperparameters():
    """Hyperparameters whose values are all JSON-encoded are decoded back to ALL_HYPERPARAMETERS."""
    serialized_hps = {}
    for hp_name, hp_value in ALL_HYPERPARAMETERS.items():
        serialized_hps[hp_name] = json.dumps(hp_value)

    test.write_json(serialized_hps, environment.hyperparameters_file_dir)

    decoded = environment.read_hyperparameters()
    assert decoded == ALL_HYPERPARAMETERS
def test_read_hyperparameters():
    """Hyperparameters written without JSON-encoding their values are read back unchanged."""
    target_dir = environment.hyperparameters_file_dir
    test.write_json(ALL_HYPERPARAMETERS, target_dir)

    read_back = environment.read_hyperparameters()
    assert read_back == ALL_HYPERPARAMETERS