def setUp(self):
    """Build the ``SageMakerEndpointOperator`` fixture shared by the tests.

    The operator is constructed from the ``config`` name found in the
    enclosing scope, with ``operation='create'``,
    ``wait_for_completion=False`` and ``check_interval=5``, and is stored
    on ``self.sagemaker``.
    """
    operator = SageMakerEndpointOperator(
        task_id='test_sagemaker_operator',
        aws_conn_id='sagemaker_test_id',
        config=config,
        wait_for_completion=False,
        check_interval=5,
        operation='create',
    )
    self.sagemaker = operator
class TestSageMakerEndpointOperator(unittest.TestCase):
    """Unit tests for ``SageMakerEndpointOperator`` with ``SageMakerHook`` mocked."""

    def setUp(self):
        """Create a fresh operator in ``create`` mode before every test."""
        operator = SageMakerEndpointOperator(
            task_id='test_sagemaker_operator',
            aws_conn_id='sagemaker_test_id',
            config=config,
            wait_for_completion=False,
            check_interval=5,
            operation='create',
        )
        self.sagemaker = operator

    def test_parse_config_integers(self):
        """After parsing, every variant's instance count equals its ``int`` value."""
        self.sagemaker.parse_config_integers()
        endpoint_config = self.sagemaker.config['EndpointConfig']
        for variant in endpoint_config['ProductionVariants']:
            self.assertEqual(
                variant['InitialInstanceCount'],
                int(variant['InitialInstanceCount']),
            )

    # Decorators apply bottom-up, so the mocks are passed in the order:
    # create_endpoint, create_endpoint_config, create_model, get_conn.
    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_model')
    @mock.patch.object(SageMakerHook, 'create_endpoint_config')
    @mock.patch.object(SageMakerHook, 'create_endpoint')
    def test_execute(self, mock_endpoint, mock_endpoint_config, mock_model,
                     mock_client):
        """``execute`` calls each hook method once with the expected parameters."""
        success_response = {
            'EndpointArn': 'testarn',
            'ResponseMetadata': {'HTTPStatusCode': 200},
        }
        mock_endpoint.return_value = success_response

        self.sagemaker.execute(None)

        mock_model.assert_called_once_with(create_model_params)
        mock_endpoint_config.assert_called_once_with(
            create_endpoint_config_params,
        )
        mock_endpoint.assert_called_once_with(
            create_endpoint_params,
            wait_for_completion=False,
            check_interval=5,
            max_ingestion_time=None,
        )

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_model')
    @mock.patch.object(SageMakerHook, 'create_endpoint_config')
    @mock.patch.object(SageMakerHook, 'create_endpoint')
    def test_execute_with_failure(self, mock_endpoint, mock_endpoint_config,
                                  mock_model, mock_client):
        """``execute`` raises ``AirflowException`` when the endpoint response is 404."""
        failure_response = {
            'EndpointArn': 'testarn',
            'ResponseMetadata': {'HTTPStatusCode': 404},
        }
        mock_endpoint.return_value = failure_response

        with self.assertRaises(AirflowException):
            self.sagemaker.execute(None)
# Deploy the trained model: the config describes the SageMaker model, the
# endpoint configuration and the endpoint, and the operator is run in
# "update" mode, waiting for completion. All timestamped names are Jinja
# templates built from the execution date.
sagemaker_deploy_model = SageMakerEndpointOperator(
    task_id="sagemaker_deploy_model",
    operation="update",
    wait_for_completion=True,
    config={
        "Model": {
            "ModelName":
            "mnistclassifier-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}",
            "PrimaryContainer": {
                "Image":
                "438346466558.dkr.ecr.eu-west-1.amazonaws.com/kmeans:1",
                "ModelDataUrl":
                ("s3://your-bucket/mnistclassifier-output/mnistclassifier"
                 "-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}/"
                 "output/model.tar.gz"
                 ),  # this will link the model and the training job
            },
            "ExecutionRoleArn":
            ("arn:aws:iam::297623009465:role/service-role/"
             "AmazonSageMaker-ExecutionRole-20180905T153196"),
        },
        "EndpointConfig": {
            "EndpointConfigName":
            "mnistclassifier-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}",
            "ProductionVariants": [{
                "InitialInstanceCount": 1,
                "InstanceType": "ml.t2.medium",
                # Must match Model.ModelName above: a production variant
                # refers to its model by name, and the model defined in this
                # config carries the timestamped name, not the bare
                # "mnistclassifier".
                "ModelName":
                "mnistclassifier-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}",
                "VariantName": "AllTraffic",
            }],
        },
        "Endpoint": {
            # Same template as EndpointConfig.EndpointConfigName above.
            "EndpointConfigName":
            "mnistclassifier-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}",
            # The endpoint name is fixed (not timestamped).
            "EndpointName": "mnistclassifier",
        },
    },
    dag=dag,
)
Example no. 4
0
) as dag:
    # Step 1: run the preprocess_glue callable.
    process_task = PythonOperator(
        task_id="process",
        python_callable=preprocess_glue,
        # provide_context=False,
        dag=dag,
    )

    # Step 2: SageMaker training job described by training_config.
    # NOTE(review): this task uses aws_conn_id "airflow-sagemaker" while the
    # deploy task below uses "sagemaker-airflow" -- confirm both connection
    # ids exist and that the difference is intentional.
    train_task = SageMakerTrainingOperator(
        task_id="train",
        config=training_config,
        aws_conn_id="airflow-sagemaker",
        wait_for_completion=True,
        # Check the status of the job every minute.
        check_interval=60,
        # Allow the training job to run as long as it needs; change this
        # for an early stop.
        max_ingestion_time=None,
    )

    # Step 3: SageMaker endpoint described by endpoint_config.
    endpoint_deploy_task = SageMakerEndpointOperator(
        task_id="endpoint-deploy",
        config=endpoint_config,
        aws_conn_id="sagemaker-airflow",
        wait_for_completion=True,
        # Check the status of the endpoint deployment every minute.
        check_interval=60,
        max_ingestion_time=None,
        # Change to 'update' if you are updating rather than creating an
        # endpoint.
        operation='create',
    )

    # Set the dependencies between tasks: process -> train -> deploy.
    process_task.set_downstream(train_task)
    train_task.set_downstream(endpoint_deploy_task)
class TestSageMakerEndpointOperator(unittest.TestCase):
    """Tests for ``SageMakerEndpointOperator`` with every ``SageMakerHook`` call mocked."""

    def setUp(self):
        """Create a fresh operator in ``create`` mode before every test."""
        operator = SageMakerEndpointOperator(
            task_id='test_sagemaker_operator',
            aws_conn_id='sagemaker_test_id',
            config=config,
            wait_for_completion=False,
            check_interval=5,
            operation='create',
        )
        self.sagemaker = operator

    def test_parse_config_integers(self):
        """After parsing, every variant's instance count equals its ``int`` value."""
        self.sagemaker.parse_config_integers()
        endpoint_config = self.sagemaker.config['EndpointConfig']
        for variant in endpoint_config['ProductionVariants']:
            assert variant['InitialInstanceCount'] == int(
                variant['InitialInstanceCount']
            )

    # Decorators apply bottom-up, so the mocks are passed in the order:
    # create_endpoint, create_endpoint_config, create_model, get_conn.
    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_model')
    @mock.patch.object(SageMakerHook, 'create_endpoint_config')
    @mock.patch.object(SageMakerHook, 'create_endpoint')
    def test_execute(self, mock_endpoint, mock_endpoint_config, mock_model,
                     mock_client):
        """``execute`` calls each hook method once with the expected parameters."""
        success_response = {
            'EndpointArn': 'testarn',
            'ResponseMetadata': {'HTTPStatusCode': 200},
        }
        mock_endpoint.return_value = success_response

        self.sagemaker.execute(None)

        mock_model.assert_called_once_with(create_model_params)
        mock_endpoint_config.assert_called_once_with(
            create_endpoint_config_params,
        )
        mock_endpoint.assert_called_once_with(
            create_endpoint_params,
            wait_for_completion=False,
            check_interval=5,
            max_ingestion_time=None,
        )

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_model')
    @mock.patch.object(SageMakerHook, 'create_endpoint_config')
    @mock.patch.object(SageMakerHook, 'create_endpoint')
    def test_execute_with_failure(self, mock_endpoint, mock_endpoint_config,
                                  mock_model, mock_client):
        """``execute`` raises ``AirflowException`` when the endpoint response is 404."""
        failure_response = {
            'EndpointArn': 'testarn',
            'ResponseMetadata': {'HTTPStatusCode': 404},
        }
        mock_endpoint.return_value = failure_response

        with pytest.raises(AirflowException):
            self.sagemaker.execute(None)

    # Mock order here: update_endpoint, create_endpoint,
    # create_endpoint_config, create_model, get_conn.
    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_model')
    @mock.patch.object(SageMakerHook, 'create_endpoint_config')
    @mock.patch.object(SageMakerHook, 'create_endpoint')
    @mock.patch.object(SageMakerHook, 'update_endpoint')
    def test_execute_with_duplicate_endpoint_creation(self,
                                                      mock_endpoint_update,
                                                      mock_endpoint,
                                                      mock_endpoint_config,
                                                      mock_model, mock_client):
        """``execute`` completes when ``create_endpoint`` raises ``ClientError``.

        ``create_endpoint`` is made to raise a ``ValidationException``
        client error and ``update_endpoint`` returns a 200 response. The
        test only requires that ``execute`` does not raise; presumably the
        operator falls back to ``update_endpoint``, but that is not
        asserted here.
        """
        already_exists_error = {
            "Error": {
                "Code": "ValidationException",
                "Message": "Cannot create already existing endpoint."
            }
        }
        mock_endpoint.side_effect = ClientError(
            error_response=already_exists_error,
            operation_name="CreateEndpoint",
        )
        update_response = {
            'EndpointArn': 'testarn',
            'ResponseMetadata': {'HTTPStatusCode': 200},
        }
        mock_endpoint_update.return_value = update_response

        self.sagemaker.execute(None)