def test_tune_warm_start(sagemaker_session, warm_start_type, parents):
    """Verify ``Session.tune`` sends the warm-start config in the tuning-job request.

    The mocked ``create_hyper_parameter_tuning_job`` client method is given a side
    effect that asserts on the keyword arguments it receives. The test then checks
    the method was actually invoked, so those assertions cannot be skipped silently.
    """

    def assert_create_tuning_job_request(**kwargs):
        assert kwargs["HyperParameterTuningJobConfig"] == SAMPLE_TUNING_JOB_REQUEST["HyperParameterTuningJobConfig"]
        assert kwargs["HyperParameterTuningJobName"] == "dummy-tuning-1"
        assert kwargs["TrainingJobDefinition"] == SAMPLE_TUNING_JOB_REQUEST["TrainingJobDefinition"]

        warm_start_req = kwargs["WarmStartConfig"]
        # The warm-start section must contain exactly these two keys, as in the original exact-dict check.
        assert set(warm_start_req) == {'WarmStartType', 'ParentHyperParameterTuningJobs'}
        assert warm_start_req['WarmStartType'] == warm_start_type
        # WarmStartConfig stores parents as a set, so the order of entries in the request
        # is not guaranteed to match ``parents``; compare the names order-insensitively.
        parent_names = [p['HyperParameterTuningJobName'] for p in warm_start_req['ParentHyperParameterTuningJobs']]
        assert sorted(parent_names) == sorted(set(parents))

    create_job = sagemaker_session.sagemaker_client.create_hyper_parameter_tuning_job
    create_job.side_effect = assert_create_tuning_job_request
    sagemaker_session.tune(job_name="dummy-tuning-1",
                           strategy="Bayesian",
                           objective_type="Maximize",
                           objective_metric_name="val-score",
                           max_jobs=100,
                           max_parallel_jobs=5,
                           parameter_ranges=SAMPLE_PARAM_RANGES,
                           static_hyperparameters=STATIC_HPs,
                           image="dummy-image-1",
                           input_mode="File",
                           metric_definitions=SAMPLE_METRIC_DEF,
                           role=EXPANDED_ROLE,
                           input_config=SAMPLE_INPUT,
                           output_config=SAMPLE_OUTPUT,
                           resource_config=RESOURCE_CONFIG,
                           stop_condition=SAMPLE_STOPPING_CONDITION,
                           tags=None,
                           warm_start_config=WarmStartConfig(warm_start_type=WarmStartTypes(warm_start_type),
                                                             parents=parents).to_input_req())

    # Without this, the side-effect assertions above would never run if tune() skipped the client call.
    assert create_job.call_count == 1
def test_prepare_warm_start_config_cls(warm_start_config_req):
    """Verify ``WarmStartConfig.from_job_desc`` restores the type and parents from a request dict."""
    warm_start_config = WarmStartConfig.from_job_desc(warm_start_config_req)

    assert warm_start_config.type == WarmStartTypes(
        warm_start_config_req["WarmStartType"]), "Warm start type initialization failed."

    # Compare as sets so that both missing parents and unexpected extra parents are caught;
    # a one-way membership loop would only detect the former.
    expected_parents = {p['HyperParameterTuningJobName']
                        for p in warm_start_config_req["ParentHyperParameterTuningJobs"]}
    assert set(warm_start_config.parents) == expected_parents, \
        "Warm start parents config initialization failed."
def test_warm_start_config_init(type, parents):
    """Verify ``WarmStartConfig`` stores its inputs and round-trips them through ``to_input_req``."""
    # NOTE(review): ``type`` shadows the builtin. It is presumably the argument name used by this
    # test's pytest parametrization (not visible here), so renaming it would require changing that too.
    warm_start_config = WarmStartConfig(warm_start_type=type, parents=parents)

    assert warm_start_config.type == type, "Warm start type initialization failed."
    assert warm_start_config.parents == set(parents), "Warm start parents config initialization failed."

    warm_start_config_req = warm_start_config.to_input_req()
    assert warm_start_config.type == WarmStartTypes(warm_start_config_req["WarmStartType"])

    parent_jobs = warm_start_config_req["ParentHyperParameterTuningJobs"]
    # Expect exactly one request entry per distinct parent. The length check catches duplicates,
    # and the set comparison catches both missing and unexpected names.
    assert len(parent_jobs) == len(set(parents))
    assert {p['HyperParameterTuningJobName'] for p in parent_jobs} == set(parents)