def setUp(self):
    """Attach a fresh ``SageMakerTuningOperator`` to ``self.sagemaker``.

    The operator is built against the 'sagemaker_test_conn' connection with
    the module-level ``create_tuning_params`` as its config,
    ``wait_for_completion`` disabled and a 5-second ``check_interval``.
    """
    operator_kwargs = dict(
        task_id='test_sagemaker_operator',
        aws_conn_id='sagemaker_test_conn',
        config=create_tuning_params,
        wait_for_completion=False,
        check_interval=5,
    )
    self.sagemaker = SageMakerTuningOperator(**operator_kwargs)
class TestSageMakerTuningOperator(unittest.TestCase):
    """Unit tests for ``SageMakerTuningOperator``.

    Both ``execute`` tests patch ``SageMakerHook.get_conn`` and
    ``SageMakerHook.create_tuning_job``, so the hook is never asked to build
    a real client.
    """

    def setUp(self):
        """Build the operator under test on the test connection, without waiting."""
        self.sagemaker = SageMakerTuningOperator(
            task_id='test_sagemaker_operator',
            aws_conn_id='sagemaker_test_conn',
            config=create_tuning_params,
            wait_for_completion=False,
            check_interval=5)

    def test_parse_config_integers(self):
        """parse_config_integers() leaves every listed numeric field as an int."""
        # (section, group, field) paths into the config that must hold ints.
        integer_field_paths = (
            ('TrainingJobDefinition', 'ResourceConfig', 'InstanceCount'),
            ('TrainingJobDefinition', 'ResourceConfig', 'VolumeSizeInGB'),
            ('HyperParameterTuningJobConfig', 'ResourceLimits',
             'MaxNumberOfTrainingJobs'),
            ('HyperParameterTuningJobConfig', 'ResourceLimits',
             'MaxParallelTrainingJobs'),
        )

        self.sagemaker.parse_config_integers()

        for section, group, field in integer_field_paths:
            value = self.sagemaker.config[section][group][field]
            # An isinstance check states the intent directly; comparing the
            # value to int(value) would also accept integral floats like 2.0.
            self.assertIsInstance(
                value, int,
                msg='%s.%s.%s is %r, expected an int'
                    % (section, group, field, value))

    # patch.object decorators inject mocks bottom-up: the create_tuning_job
    # mock arrives first (mock_tuning), the get_conn mock second (mock_client).
    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_tuning_job')
    def test_execute(self, mock_tuning, mock_client):
        """execute() forwards the config and wait settings to create_tuning_job."""
        # Shaped like boto3's CreateHyperParameterTuningJob response.
        mock_tuning.return_value = {
            'HyperParameterTuningJobArn': 'testarn',
            'ResponseMetadata': {'HTTPStatusCode': 200},
        }

        self.sagemaker.execute(None)

        mock_tuning.assert_called_once_with(
            create_tuning_params,
            wait_for_completion=False,
            check_interval=5,
            max_ingestion_time=None)

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_tuning_job')
    def test_execute_with_failure(self, mock_tuning, mock_client):
        """A 404 status from create_tuning_job is expected to raise AirflowException."""
        mock_tuning.return_value = {
            'HyperParameterTuningJobArn': 'testarn',
            'ResponseMetadata': {'HTTPStatusCode': 404},
        }

        self.assertRaises(AirflowException, self.sagemaker.execute, None)
if hpo_enabled else "model_training") # launch sagemaker training job and wait until it completes train_model_task = SageMakerTrainingOperator( task_id='model_training', dag=dag, config=train_config, # aws_conn_id='airflow-sagemaker', wait_for_completion=True, check_interval=30) # launch sagemaker hyperparameter job and wait until it completes tune_model_task = SageMakerTuningOperator( task_id='model_tuning', dag=dag, config=tuner_config, # aws_conn_id='airflow-sagemaker', wait_for_completion=True, check_interval=30) # launch sagemaker batch transform job and wait until it completes batch_transform_task = SageMakerTransformOperator( task_id='predicting', dag=dag, config=transform_config, # aws_conn_id='airflow-sagemaker', wait_for_completion=True, check_interval=30, trigger_rule=TriggerRule.ONE_SUCCESS) cleanup_task = DummyOperator(task_id='cleaning_up', dag=dag)