def test_create_transfer_learning_tuner(
    sagemaker_session, kmeans_train_set, kmeans_estimator, hyperparameter_ranges
):
    """Tests Transfer learning use case with two parents and child job launched with
    create_transfer_learning_tuner()

    Two single-job parent tuning jobs are run first. A child tuner is then built from the
    first parent, with the second parent passed as an additional parent, and its job is
    launched. The warm start configuration reported by
    ``describe_hyper_parameter_tuning_job`` for the child job must match the child
    tuner's local ``warm_start_config`` in both type and parent set.
    """

    def _short_name(base):
        # All job names share the same length/shortening settings.
        return name_from_base(base, max_length=32, short=True)

    first_parent_name = _short_name("km-tran2-parent1")
    second_parent_name = _short_name("km-tran2-parent2")
    child_name = _short_name("km-tran2-child")

    def _launch_parent(job_name):
        # Each parent is a minimal tuning job: one job, one at a time.
        return _tune(
            kmeans_estimator,
            kmeans_train_set,
            job_name=job_name,
            hyperparameter_ranges=hyperparameter_ranges,
            max_parallel_jobs=1,
            max_jobs=1,
        )

    first_parent = _launch_parent(first_parent_name)
    second_parent = _launch_parent(second_parent_name)

    child_tuner = create_transfer_learning_tuner(
        parent=first_parent.latest_tuning_job.name,
        sagemaker_session=sagemaker_session,
        estimator=kmeans_estimator,
        additional_parents={second_parent.latest_tuning_job.name},
    )
    _tune(kmeans_estimator, kmeans_train_set, job_name=child_name, tuner=child_tuner)

    describe_response = sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job(
        HyperParameterTuningJobName=child_name
    )
    remote_config = WarmStartConfig.from_job_desc(describe_response["WarmStartConfig"])
    local_config = child_tuner.warm_start_config

    assert remote_config.type == local_config.type
    assert remote_config.parents == local_config.parents
def test_create_transfer_learning_tuner(sagemaker_session, estimator, additional_parents):
    """Check that create_transfer_learning_tuner() builds a transfer-learning tuner.

    ``describe_hyper_parameter_tuning_job`` on the session's client is replaced with a
    Mock that returns a deep copy of ``TUNING_JOB_DETAILS``. The resulting tuner must
    have warm start type ``WarmStartTypes.TRANSFER_LEARNING``, keep the supplied
    ``estimator``, and list ``JOB_NAME`` plus any ``additional_parents`` as its
    warm-start parents.
    """
    job_details = copy.deepcopy(TUNING_JOB_DETAILS)
    sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(
        name="describe_tuning_job", return_value=job_details
    )

    tuner = create_transfer_learning_tuner(
        parent=JOB_NAME,
        additional_parents=additional_parents,
        sagemaker_session=sagemaker_session,
        estimator=estimator,
    )

    assert tuner.warm_start_config.type == WarmStartTypes.TRANSFER_LEARNING
    assert tuner.estimator == estimator

    # Build the expected parents as a fresh set instead of calling
    # ``additional_parents.add(JOB_NAME)``. Adding in place would mutate the test's
    # input, which may be a shared parametrize value. It could also make the assertion
    # vacuous if the tuner stored that same set object as its parents.
    expected_parents = {JOB_NAME}
    if additional_parents:
        expected_parents |= set(additional_parents)
    assert tuner.warm_start_config.parents == expected_parents