예제 #1
0
def test_create_transfer_learning_tuner(sagemaker_session,
                                        kmeans_train_set,
                                        kmeans_estimator,
                                        hyperparameter_ranges):
    """Run a transfer-learning child tuning job that warm-starts from two parent jobs.

    Both parents are launched via ``_tune`` (one job each). The child tuner is then
    built with ``create_transfer_learning_tuner()``, using the first parent as
    ``parent`` and the second as an additional parent. After the child runs, the
    test checks that the ``WarmStartConfig`` returned by
    ``describe_hyper_parameter_tuning_job`` matches the tuner's local
    ``warm_start_config`` in both type and parent set.
    """
    # Job names are generated in the order parent1, parent2, child.
    parent_names = [name_from_base(prefix, max_length=32, short=True)
                    for prefix in ("km-tran2-parent1", "km-tran2-parent2")]
    child_name = name_from_base("km-tran2-child", max_length=32, short=True)

    # Each parent is a single-job, single-parallelism tuning run.
    first_parent, second_parent = [
        _tune(kmeans_estimator, kmeans_train_set, job_name=job_name,
              hyperparameter_ranges=hyperparameter_ranges, max_parallel_jobs=1, max_jobs=1)
        for job_name in parent_names
    ]

    child_tuner = create_transfer_learning_tuner(
        parent=first_parent.latest_tuning_job.name,
        sagemaker_session=sagemaker_session,
        estimator=kmeans_estimator,
        additional_parents={second_parent.latest_tuning_job.name},
    )
    _tune(kmeans_estimator, kmeans_train_set, job_name=child_name, tuner=child_tuner)

    client = sagemaker_session.sagemaker_client
    description = client.describe_hyper_parameter_tuning_job(
        HyperParameterTuningJobName=child_name)
    reported = WarmStartConfig.from_job_desc(description["WarmStartConfig"])
    local = child_tuner.warm_start_config

    assert reported.type == local.type
    assert reported.parents == local.parents
def test_create_transfer_learning_tuner(sagemaker_session, estimator, additional_parents):
    """Check that ``create_transfer_learning_tuner()`` builds a TRANSFER_LEARNING tuner.

    ``describe_hyper_parameter_tuning_job`` on the session's client is replaced by a
    ``Mock`` that returns a deep copy of ``TUNING_JOB_DETAILS``. The test asserts:
      * the warm-start type is ``WarmStartTypes.TRANSFER_LEARNING``;
      * the supplied ``estimator`` is kept on the tuner;
      * the warm-start parents are ``JOB_NAME`` plus any ``additional_parents``.

    ``additional_parents`` may be falsy (e.g. ``None``), in which case only
    ``JOB_NAME`` is expected.
    """
    # Deep copy so the shared module-level fixture dict cannot be mutated through the mock.
    job_details = copy.deepcopy(TUNING_JOB_DETAILS)
    sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(
        name='describe_tuning_job', return_value=job_details)

    tuner = create_transfer_learning_tuner(parent=JOB_NAME,
                                           additional_parents=additional_parents,
                                           sagemaker_session=sagemaker_session,
                                           estimator=estimator)

    assert tuner.warm_start_config.type == WarmStartTypes.TRANSFER_LEARNING
    assert tuner.estimator == estimator

    # Build the expectation as a fresh set instead of calling ``additional_parents.add``.
    # Mutating the argument would leak JOB_NAME into a parametrize value that may be
    # reused by other cases. It would also make this check compare a set to itself if
    # the tuner ever kept a reference to the caller's set.
    expected_parents = {JOB_NAME} | set(additional_parents or ())
    assert tuner.warm_start_config.parents == expected_parents