def setUp(self):
    """Create the SageMakerTransformOperator under test and store it on ``self.sagemaker``.

    The operator receives the module-level ``config`` and is built with
    ``wait_for_completion=False`` and ``check_interval=5``.
    """
    operator_kwargs = {
        'task_id': 'test_sagemaker_operator',
        'aws_conn_id': 'sagemaker_test_id',
        'config': config,
        'wait_for_completion': False,
        'check_interval': 5,
    }
    self.sagemaker = SageMakerTransformOperator(**operator_kwargs)
class TestSageMakerTransformOperator(unittest.TestCase):
    """Unit tests for SageMakerTransformOperator with SageMakerHook methods patched out."""

    def setUp(self):
        # A fresh operator per test: it does not wait for completion and uses check_interval=5.
        operator_kwargs = {
            'task_id': 'test_sagemaker_operator',
            'aws_conn_id': 'sagemaker_test_id',
            'config': config,
            'wait_for_completion': False,
            'check_interval': 5,
        }
        self.sagemaker = SageMakerTransformOperator(**operator_kwargs)

    @staticmethod
    def _transform_response(status_code):
        """Return a fake ``create_transform_job`` response carrying ``status_code``."""
        return {
            'TransformJobArn': 'testarn',
            'ResponseMetadata': {'HTTPStatusCode': status_code},
        }

    def test_parse_config_integers(self):
        """After parse_config_integers(), each listed 'Transform' field equals its int() value."""
        self.sagemaker.parse_config_integers()
        transform = self.sagemaker.config['Transform']
        # Key paths into the 'Transform' section; each is resolved only when checked.
        checked_paths = (
            ('TransformResources', 'InstanceCount'),
            ('MaxConcurrentTransforms',),
            ('MaxPayloadInMB',),
        )
        for path in checked_paths:
            value = transform
            for key in path:
                value = value[key]
            self.assertEqual(value, int(value))

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_model')
    @mock.patch.object(SageMakerHook, 'create_transform_job')
    def test_execute(self, mock_transform, mock_model, mock_client):
        """execute() forwards the expected params to create_model and create_transform_job."""
        # Bottom-most patch is injected first: create_transform_job -> mock_transform.
        mock_transform.return_value = self._transform_response(200)
        self.sagemaker.execute(None)
        mock_model.assert_called_once_with(create_model_params)
        mock_transform.assert_called_once_with(
            create_transform_params,
            wait_for_completion=False,
            check_interval=5,
            max_ingestion_time=None,
        )

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_model')
    @mock.patch.object(SageMakerHook, 'create_transform_job')
    def test_execute_with_failure(self, mock_transform, mock_model, mock_client):
        """execute() raises AirflowException when the transform job response is not HTTP 200."""
        mock_transform.return_value = self._transform_response(404)
        with self.assertRaises(AirflowException):
            self.sagemaker.execute(None)
check_interval=30)

# Launch the SageMaker hyperparameter tuning job (tuner_config) and block
# until it completes (wait_for_completion=True).
tune_model_task = SageMakerTuningOperator(
    task_id='model_tuning',
    dag=dag,
    config=tuner_config,
    # Optional: uncomment to pass an explicit AWS connection id.
    # aws_conn_id='airflow-sagemaker',
    wait_for_completion=True,
    check_interval=30)

# Launch the SageMaker batch transform job (transform_config) and block
# until it completes (wait_for_completion=True).
batch_transform_task = SageMakerTransformOperator(
    task_id='predicting',
    dag=dag,
    config=transform_config,
    # Optional: uncomment to pass an explicit AWS connection id.
    # aws_conn_id='airflow-sagemaker',
    wait_for_completion=True,
    check_interval=30,
    # This task has two upstreams (tune_model_task and train_model_task, both
    # downstream of `branching`, see below). ONE_SUCCESS lets it run once at
    # least one of them succeeds. Presumably `branching` selects only one of
    # the two paths, so the default all-success rule would never be met.
    # TODO confirm.
    trigger_rule=TriggerRule.ONE_SUCCESS)

# Placeholder (DummyOperator) task.
# NOTE(review): none of the dependencies set below attach it; confirm it is
# wired up elsewhere in the file.
cleanup_task = DummyOperator(task_id='cleaning_up', dag=dag)

# Set the dependencies between tasks:
# init -> preprocess -> prepare -> branching -> {tune_model, train_model} -> batch_transform
init.set_downstream(preprocess_task)
preprocess_task.set_downstream(prepare_task)
prepare_task.set_downstream(branching)
branching.set_downstream(tune_model_task)
branching.set_downstream(train_model_task)
tune_model_task.set_downstream(batch_transform_task)
train_model_task.set_downstream(batch_transform_task)