def setUp(self):
    """Build the ``SageMakerTuningOperator`` under test.

    Loads the Airflow test configuration first, matching the ``setUp`` of
    the ``TestSageMakerTuningOperator`` classes in this module; this copy
    previously skipped that step.

    NOTE(review): this function sits at module level yet takes ``self``;
    it looks like a stray copy of the test class's ``setUp`` — confirm
    whether anything actually calls it before keeping it.
    """
    configuration.load_test_config()
    self.sagemaker = SageMakerTuningOperator(
        task_id='test_sagemaker_operator',
        aws_conn_id='sagemaker_test_conn',
        config=create_tuning_params,
        wait_for_completion=False,
        check_interval=5)
class TestSageMakerTuningOperator(unittest.TestCase):
    """Tests for ``SageMakerTuningOperator`` with the SageMaker hook patched.

    NOTE(review): this module defines a second class with exactly this name
    further down. That later definition rebinds the module attribute, so
    test loaders that collect from the module namespace will only run the
    later copy. One of the two copies should probably be removed.
    """

    def setUp(self):
        configuration.load_test_config()
        operator_kwargs = dict(
            task_id='test_sagemaker_operator',
            aws_conn_id='sagemaker_test_conn',
            config=create_tuning_params,
            wait_for_completion=False,
            check_interval=5,
        )
        self.sagemaker = SageMakerTuningOperator(**operator_kwargs)

    def test_parse_config_integers(self):
        self.sagemaker.parse_config_integers()
        checked_paths = [
            ('TrainingJobDefinition', 'ResourceConfig', 'InstanceCount'),
            ('TrainingJobDefinition', 'ResourceConfig', 'VolumeSizeInGB'),
            ('HyperParameterTuningJobConfig', 'ResourceLimits',
             'MaxNumberOfTrainingJobs'),
            ('HyperParameterTuningJobConfig', 'ResourceLimits',
             'MaxParallelTrainingJobs'),
        ]
        for section, group, field in checked_paths:
            actual = self.sagemaker.config[section][group][field]
            # After parsing, the stored value must equal its int() form.
            self.assertEqual(actual, int(actual))

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_tuning_job')
    def test_execute(self, mock_create_tuning_job, mock_get_conn):
        ok_response = {'TrainingJobArn': 'testarn',
                       'ResponseMetadata': {'HTTPStatusCode': 200}}
        mock_create_tuning_job.return_value = ok_response
        self.sagemaker.execute(None)
        mock_create_tuning_job.assert_called_once_with(
            create_tuning_params,
            wait_for_completion=False,
            check_interval=5,
            max_ingestion_time=None,
        )

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_tuning_job')
    def test_execute_with_failure(self, mock_create_tuning_job, mock_get_conn):
        # A non-200 status code from the (mocked) hook must surface as an
        # AirflowException from execute().
        mock_create_tuning_job.return_value = {
            'TrainingJobArn': 'testarn',
            'ResponseMetadata': {'HTTPStatusCode': 404},
        }
        with self.assertRaises(AirflowException):
            self.sagemaker.execute(None)
class TestSageMakerTuningOperator(unittest.TestCase):
    """Unit tests for ``SageMakerTuningOperator``.

    The operator is built from the shared ``create_tuning_params`` fixture.
    The execute tests patch ``SageMakerHook.get_conn`` and
    ``SageMakerHook.create_tuning_job``, so the hook's responses are fully
    controlled by each test.
    """

    # Nested config keys that ``parse_config_integers`` must turn into ints.
    _INTEGER_FIELD_PATHS = (
        ('TrainingJobDefinition', 'ResourceConfig', 'InstanceCount'),
        ('TrainingJobDefinition', 'ResourceConfig', 'VolumeSizeInGB'),
        ('HyperParameterTuningJobConfig', 'ResourceLimits',
         'MaxNumberOfTrainingJobs'),
        ('HyperParameterTuningJobConfig', 'ResourceLimits',
         'MaxParallelTrainingJobs'),
    )

    def setUp(self):
        configuration.load_test_config()
        # NOTE(review): the fixture dict is passed by reference; presumably the
        # operator keeps it as ``self.config``, so parsing mutates the shared
        # module-level fixture — confirm before relying on it elsewhere.
        self.sagemaker = SageMakerTuningOperator(
            task_id='test_sagemaker_operator',
            aws_conn_id='sagemaker_test_conn',
            config=create_tuning_params,
            wait_for_completion=False,
            check_interval=5
        )

    def test_parse_config_integers(self):
        """Each numeric config field must be an ``int`` after parsing."""
        self.sagemaker.parse_config_integers()
        for path in self._INTEGER_FIELD_PATHS:
            value = self.sagemaker.config
            for key in path:
                value = value[key]
            # The previous check, assertEqual(value, int(value)), also passed
            # for non-int numbers such as 2.0; assert the type directly.
            self.assertIsInstance(value, int, msg='/'.join(path))

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_tuning_job')
    def test_execute(self, mock_tuning, mock_client):
        """execute() forwards the config and wait settings to the hook."""
        # NOTE(review): boto3's CreateHyperParameterTuningJob response carries
        # 'HyperParameterTuningJobArn'; presumably only ResponseMetadata is
        # read here — confirm before depending on the ARN field.
        mock_tuning.return_value = {'TrainingJobArn': 'testarn',
                                    'ResponseMetadata':
                                    {'HTTPStatusCode': 200}}
        self.sagemaker.execute(None)
        mock_tuning.assert_called_once_with(create_tuning_params,
                                            wait_for_completion=False,
                                            check_interval=5,
                                            max_ingestion_time=None
                                            )

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_tuning_job')
    def test_execute_with_failure(self, mock_tuning, mock_client):
        """A non-200 status from the hook makes execute() raise."""
        mock_tuning.return_value = {'TrainingJobArn': 'testarn',
                                    'ResponseMetadata':
                                    {'HTTPStatusCode': 404}}
        self.assertRaises(AirflowException, self.sagemaker.execute, None)
 def setUp(self):
     """Load the test config, then create the tuning operator under test."""
     configuration.load_test_config()
     tuning_op = SageMakerTuningOperator(
         task_id='test_sagemaker_operator',
         aws_conn_id='sagemaker_test_conn',
         config=create_tuning_params,
         wait_for_completion=False,
         check_interval=5,
     )
     self.sagemaker = tuning_op
                                 dag=dag,
                                 python_callable=lambda: "model_tuning"
                                 if hpo_enabled else "model_training")

# Training path: start the SageMaker training job and block until it is done.
train_model_task = SageMakerTrainingOperator(
    task_id='model_training',
    dag=dag,
    config=train_config,
    aws_conn_id='airflow-sagemaker',
    wait_for_completion=True,
    check_interval=30,
)

# Tuning path: start the SageMaker hyperparameter tuning job and block until
# it is done.
tune_model_task = SageMakerTuningOperator(
    task_id='model_tuning',
    dag=dag,
    config=tuner_config,
    aws_conn_id='airflow-sagemaker',
    wait_for_completion=True,
    check_interval=30,
)

# Prediction: start the SageMaker batch transform job and block until it is
# done. ONE_SUCCESS lets it run once at least one upstream task has succeeded
# (presumably whichever of training/tuning the branch selected).
batch_transform_task = SageMakerTransformOperator(
    task_id='predicting',
    dag=dag,
    config=transform_config,
    aws_conn_id='airflow-sagemaker',
    wait_for_completion=True,
    check_interval=30,
    trigger_rule=TriggerRule.ONE_SUCCESS,
)

# NOTE(review): a bare 5-second sleep task; its purpose is not visible here.
basher_task = BashOperator(
    task_id='sleep',
    bash_command='sleep 5',
    dag=dag,
)
                                 dag=dag,
                                 python_callable=lambda: "model_tuning"
                                 if hpo_enabled else "model_training")

# NOTE(review): this file appears to define these same task_ids on `dag`
# more than once; confirm which copy is intended and drop the other.

# Training path: start the SageMaker training job and block until it is done.
train_model_task = SageMakerTrainingOperator(
    task_id='model_training',
    dag=dag,
    config=train_config,
    aws_conn_id='airflow-sagemaker',
    wait_for_completion=True,
    check_interval=30,
)

# Tuning path: start the SageMaker hyperparameter tuning job and block until
# it is done.
tune_model_task = SageMakerTuningOperator(
    task_id='model_tuning',
    dag=dag,
    config=tuner_config,
    aws_conn_id='airflow-sagemaker',
    wait_for_completion=True,
    check_interval=30,
)

# Prediction: start the SageMaker batch transform job and block until it is
# done. ONE_SUCCESS lets it run once at least one upstream task has succeeded.
batch_transform_task = SageMakerTransformOperator(
    task_id='predicting',
    dag=dag,
    config=transform_config,
    aws_conn_id='airflow-sagemaker',
    wait_for_completion=True,
    check_interval=30,
    trigger_rule=TriggerRule.ONE_SUCCESS,
)

# Terminal no-op task marking the end of the pipeline.
cleanup_task = DummyOperator(
    task_id='cleaning_up',
    dag=dag,
)