def setUp(self):
    """Build the SageMakerTransformOperator under test and store it on self.

    The operator is created from the module-level ``config`` with
    wait_for_completion=False and check_interval=5.
    """
    operator_kwargs = dict(
        task_id='test_sagemaker_operator',
        aws_conn_id='sagemaker_test_id',
        config=config,
        wait_for_completion=False,
        check_interval=5,
    )
    self.sagemaker = SageMakerTransformOperator(**operator_kwargs)
Example #2
0
class TestSageMakerTransformOperator(unittest.TestCase):
    """Unit tests for ``SageMakerTransformOperator``.

    The execute tests patch the SageMakerHook methods (get_conn,
    create_model, create_transform_job), so the real hook methods are
    never invoked.
    """

    def setUp(self):
        # A fresh operator per test, built from the module-level ``config``.
        # wait_for_completion / check_interval are asserted as pass-through
        # arguments in test_execute.
        self.sagemaker = SageMakerTransformOperator(
            task_id='test_sagemaker_operator',
            aws_conn_id='sagemaker_test_id',
            config=config,
            wait_for_completion=False,
            check_interval=5,
        )

    def test_parse_config_integers(self):
        """parse_config_integers() leaves the numeric Transform fields as ints."""
        self.sagemaker.parse_config_integers()
        test_config = self.sagemaker.config['Transform']
        integer_fields = {
            'TransformResources.InstanceCount':
                test_config['TransformResources']['InstanceCount'],
            'MaxConcurrentTransforms': test_config['MaxConcurrentTransforms'],
            'MaxPayloadInMB': test_config['MaxPayloadInMB'],
        }
        for field_name, value in integer_fields.items():
            with self.subTest(field=field_name):
                # Check the type directly: a `value == int(value)` comparison
                # would also pass for floats such as 2.0.
                self.assertIsInstance(value, int)

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_model')
    @mock.patch.object(SageMakerHook, 'create_transform_job')
    def test_execute(self, mock_transform, mock_model, mock_client):
        """execute() forwards the expected params to create_model and create_transform_job."""
        # Patch decorators apply bottom-up, so the innermost patch
        # (create_transform_job) arrives as the first mock argument.
        mock_transform.return_value = {
            'TransformJobArn': 'testarn',
            'ResponseMetadata': {'HTTPStatusCode': 200},
        }
        self.sagemaker.execute(None)
        mock_model.assert_called_once_with(create_model_params)
        mock_transform.assert_called_once_with(
            create_transform_params,
            wait_for_completion=False,
            check_interval=5,
            max_ingestion_time=None,
        )

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_model')
    @mock.patch.object(SageMakerHook, 'create_transform_job')
    def test_execute_with_failure(self, mock_transform, mock_model,
                                  mock_client):
        """A 404 HTTP status from create_transform_job makes execute() raise AirflowException."""
        mock_transform.return_value = {
            'TransformJobArn': 'testarn',
            'ResponseMetadata': {'HTTPStatusCode': 404},
        }
        with self.assertRaises(AirflowException):
            self.sagemaker.execute(None)
Example #3
0
    check_interval=30)

# Keyword arguments shared by the two SageMaker job tasks below: both attach
# to `dag`, block until their SageMaker job completes, and re-check the job
# status with check_interval=30. To target a specific AWS connection, add
# e.g. aws_conn_id='airflow-sagemaker' to this dict.
_SAGEMAKER_JOB_KWARGS = dict(
    dag=dag,
    wait_for_completion=True,
    check_interval=30,
)

# Hyperparameter-tuning path: launch a SageMaker tuning job.
tune_model_task = SageMakerTuningOperator(
    task_id='model_tuning',
    config=tuner_config,
    **_SAGEMAKER_JOB_KWARGS
)

# Batch-transform (prediction) step, fed by either training path.
# NOTE(review): `branching` presumably selects only one of tune_model_task /
# train_model_task, leaving the other parent skipped; ONE_SUCCESS is what
# lets this task run in that case — confirm against the branching task.
batch_transform_task = SageMakerTransformOperator(
    task_id='predicting',
    config=transform_config,
    trigger_rule=TriggerRule.ONE_SUCCESS,
    **_SAGEMAKER_JOB_KWARGS
)

cleanup_task = DummyOperator(task_id='cleaning_up', dag=dag)

# Task dependencies as (upstream, downstream) pairs, applied in order:
# init -> preprocess -> prepare -> branching -> {tune, train} -> transform.
# NOTE(review): cleanup_task is not wired to anything in these lines; if it
# is meant to run last, batch_transform_task.set_downstream(cleanup_task)
# appears to be missing — verify against the rest of the file.
_TASK_EDGES = (
    (init, preprocess_task),
    (preprocess_task, prepare_task),
    (prepare_task, branching),
    (branching, tune_model_task),
    (branching, train_model_task),
    (tune_model_task, batch_transform_task),
    (train_model_task, batch_transform_task),
)
for _upstream, _downstream in _TASK_EDGES:
    _upstream.set_downstream(_downstream)