def test_tune_warm_start(sagemaker_session, warm_start_type, parents):
    """Check that ``Session.tune`` forwards a warm start config to the tuning-job API.

    The client call is intercepted through the mock's ``side_effect`` so the outgoing
    request can be inspected. Parent job names are compared order-insensitively,
    because ``WarmStartConfig`` keeps its parents in a set and so emits them in
    set-iteration order.

    Args:
        sagemaker_session: session fixture whose ``sagemaker_client`` is presumably a
            mock (it must accept ``side_effect`` and expose ``call_count``).
        warm_start_type: warm start type value accepted by ``WarmStartTypes(...)``.
        parents: iterable of parent tuning-job names.
    """
    def assert_create_tuning_job_request(**kwargs):
        assert kwargs["HyperParameterTuningJobConfig"] == \
            SAMPLE_TUNING_JOB_REQUEST["HyperParameterTuningJobConfig"]
        assert kwargs["HyperParameterTuningJobName"] == "dummy-tuning-1"
        assert kwargs["TrainingJobDefinition"] == SAMPLE_TUNING_JOB_REQUEST["TrainingJobDefinition"]

        warm_start = kwargs["WarmStartConfig"]
        # Same key set as the original exact-dict comparison demanded.
        assert set(warm_start) == {'WarmStartType', 'ParentHyperParameterTuningJobs'}
        assert warm_start['WarmStartType'] == warm_start_type

        parent_entries = warm_start['ParentHyperParameterTuningJobs']
        # Each entry must be exactly {'HyperParameterTuningJobName': <name>}.
        assert all(set(entry) == {'HyperParameterTuningJobName'} for entry in parent_entries)
        parent_names = [entry['HyperParameterTuningJobName'] for entry in parent_entries]
        # Order-insensitive: the list is produced by iterating a set.
        assert len(parent_names) == len(set(parent_names)), "Duplicate parent in request."
        assert set(parent_names) == set(parents)

    create_job = sagemaker_session.sagemaker_client.create_hyper_parameter_tuning_job
    create_job.side_effect = assert_create_tuning_job_request

    sagemaker_session.tune(job_name="dummy-tuning-1", strategy="Bayesian", objective_type="Maximize",
                           objective_metric_name="val-score", max_jobs=100, max_parallel_jobs=5,
                           parameter_ranges=SAMPLE_PARAM_RANGES, static_hyperparameters=STATIC_HPs,
                           image="dummy-image-1", input_mode="File", metric_definitions=SAMPLE_METRIC_DEF,
                           role=EXPANDED_ROLE, input_config=SAMPLE_INPUT, output_config=SAMPLE_OUTPUT,
                           resource_config=RESOURCE_CONFIG, stop_condition=SAMPLE_STOPPING_CONDITION,
                           tags=None,
                           warm_start_config=WarmStartConfig(warm_start_type=WarmStartTypes(warm_start_type),
                                                             parents=parents).to_input_req())

    # The checks above only run if the client was actually invoked; without this the
    # test would pass vacuously when tune() skips the API call.
    assert create_job.call_count == 1
def test_prepare_warm_start_config_cls(warm_start_config_req):
    """Check that ``WarmStartConfig.from_job_desc`` rebuilds type and parents from a request dict.

    Args:
        warm_start_config_req: dict with ``WarmStartType`` and
            ``ParentHyperParameterTuningJobs`` keys, where each parent entry holds a
            ``HyperParameterTuningJobName`` (supplied by a fixture/parametrization
            not visible here).
    """
    warm_start_config = WarmStartConfig.from_job_desc(warm_start_config_req)

    assert warm_start_config.type == WarmStartTypes(
        warm_start_config_req["WarmStartType"]), "Warm start type initialization failed."

    expected_parents = {
        parent['HyperParameterTuningJobName']
        for parent in warm_start_config_req["ParentHyperParameterTuningJobs"]
    }
    # Set equality (not per-item membership) so that spurious extra parents also fail.
    assert set(warm_start_config.parents) == expected_parents, \
        "Warm start parents config initialization failed."
def test_warm_start_config_init(type, parents):
    """Check ``WarmStartConfig`` construction and its ``to_input_req`` round trip.

    Args:
        type: warm start type passed straight to the constructor. The name shadows
            the builtin but is kept, because it is presumably bound by a parametrize
            decorator outside this view.
        parents: iterable of parent tuning-job names.
    """
    warm_start_config = WarmStartConfig(warm_start_type=type, parents=parents)
    assert warm_start_config.type == type, "Warm start type initialization failed."
    assert warm_start_config.parents == set(parents), "Warm start parents config initialization failed."

    warm_start_config_req = warm_start_config.to_input_req()
    assert warm_start_config.type == WarmStartTypes(warm_start_config_req["WarmStartType"])

    req_parent_names = [
        parent['HyperParameterTuningJobName']
        for parent in warm_start_config_req["ParentHyperParameterTuningJobs"]
    ]
    # Round trip must neither drop nor duplicate parents; a subset check alone would
    # let an empty or partial request pass.
    assert len(req_parent_names) == len(set(req_parent_names))
    assert set(req_parent_names) == set(parents)