Example #1
0
 def setUp(self):
     """Build a fresh SageMakerEndpointOperator fixture before each test.

     The operator runs a 'create' operation against the module-level
     ``config`` with ``wait_for_completion`` disabled.
     """
     operator_kwargs = {
         'task_id': 'test_sagemaker_operator',
         'aws_conn_id': 'sagemaker_test_id',
         'config': config,
         'wait_for_completion': False,
         'check_interval': 5,
         'operation': 'create',
     }
     self.sagemaker = SageMakerEndpointOperator(**operator_kwargs)
Example #2
0
class TestSageMakerEndpointOperator(unittest.TestCase):
    """Unit tests for SageMakerEndpointOperator with SageMakerHook calls mocked out."""

    def setUp(self):
        # wait_for_completion=False and check_interval=5 are the values
        # test_execute expects to be forwarded to create_endpoint.
        self.sagemaker = SageMakerEndpointOperator(
            operation='create',
            task_id='test_sagemaker_operator',
            aws_conn_id='sagemaker_test_id',
            config=config,
            check_interval=5,
            wait_for_completion=False)

    @staticmethod
    def _endpoint_response(status_code):
        """Return a minimal create_endpoint response carrying ``status_code``."""
        return {
            'EndpointArn': 'testarn',
            'ResponseMetadata': {'HTTPStatusCode': status_code},
        }

    def test_parse_config_integers(self):
        self.sagemaker.parse_config_integers()
        variants = self.sagemaker.config['EndpointConfig']['ProductionVariants']
        for variant in variants:
            count = variant['InitialInstanceCount']
            # A string such as '1' compares unequal to int('1'), so this fails
            # if parse_config_integers left the value as a string.
            self.assertEqual(count, int(count))

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_model')
    @mock.patch.object(SageMakerHook, 'create_endpoint_config')
    @mock.patch.object(SageMakerHook, 'create_endpoint')
    def test_execute(self, mock_endpoint, mock_endpoint_config, mock_model,
                     mock_client):
        # Patch decorators apply bottom-up, so the create_endpoint mock is
        # the first injected argument.
        mock_endpoint.return_value = self._endpoint_response(200)

        self.sagemaker.execute(None)

        mock_model.assert_called_once_with(create_model_params)
        mock_endpoint_config.assert_called_once_with(create_endpoint_config_params)
        mock_endpoint.assert_called_once_with(
            create_endpoint_params,
            wait_for_completion=False,
            check_interval=5,
            max_ingestion_time=None,
        )

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_model')
    @mock.patch.object(SageMakerHook, 'create_endpoint_config')
    @mock.patch.object(SageMakerHook, 'create_endpoint')
    def test_execute_with_failure(self, mock_endpoint, mock_endpoint_config,
                                  mock_model, mock_client):
        # A non-200 status from create_endpoint must surface as AirflowException.
        mock_endpoint.return_value = self._endpoint_response(404)
        with self.assertRaises(AirflowException):
            self.sagemaker.execute(None)
class TestSageMakerEndpointOperator(unittest.TestCase):
    """Tests for SageMakerEndpointOperator; every SageMakerHook call is mocked."""

    def setUp(self):
        # Load the Airflow test configuration first, presumably so test
        # settings are in effect when the operator is constructed.
        configuration.load_test_config()
        operator_kwargs = {
            'task_id': 'test_sagemaker_operator',
            'aws_conn_id': 'sagemaker_test_id',
            'config': config,
            'wait_for_completion': False,
            'check_interval': 5,
            'operation': 'create',
        }
        self.sagemaker = SageMakerEndpointOperator(**operator_kwargs)

    def test_parse_config_integers(self):
        self.sagemaker.parse_config_integers()
        endpoint_config = self.sagemaker.config['EndpointConfig']
        for production_variant in endpoint_config['ProductionVariants']:
            instance_count = production_variant['InitialInstanceCount']
            self.assertEqual(instance_count, int(instance_count))

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_model')
    @mock.patch.object(SageMakerHook, 'create_endpoint_config')
    @mock.patch.object(SageMakerHook, 'create_endpoint')
    def test_execute(self, mock_endpoint, mock_endpoint_config,
                     mock_model, mock_client):
        ok_response = {
            'EndpointArn': 'testarn',
            'ResponseMetadata': {'HTTPStatusCode': 200},
        }
        mock_endpoint.return_value = ok_response
        self.sagemaker.execute(None)

        # The operator's wait settings are forwarded verbatim to the hook.
        expected_wait_kwargs = {
            'wait_for_completion': False,
            'check_interval': 5,
            'max_ingestion_time': None,
        }
        mock_model.assert_called_once_with(create_model_params)
        mock_endpoint_config.assert_called_once_with(create_endpoint_config_params)
        mock_endpoint.assert_called_once_with(create_endpoint_params,
                                              **expected_wait_kwargs)

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_model')
    @mock.patch.object(SageMakerHook, 'create_endpoint_config')
    @mock.patch.object(SageMakerHook, 'create_endpoint')
    def test_execute_with_failure(self, mock_endpoint, mock_endpoint_config,
                                  mock_model, mock_client):
        not_found_response = {
            'EndpointArn': 'testarn',
            'ResponseMetadata': {'HTTPStatusCode': 404},
        }
        mock_endpoint.return_value = not_found_response
        with self.assertRaises(AirflowException):
            self.sagemaker.execute(None)
 def setUp(self):
     """Load the Airflow test configuration, then build the operator under test."""
     configuration.load_test_config()
     self.sagemaker = SageMakerEndpointOperator(
         operation='create',
         check_interval=5,
         wait_for_completion=False,
         config=config,
         aws_conn_id='sagemaker_test_id',
         task_id='test_sagemaker_operator',
     )
Example #5
0
# Deploy the trained model. ``config`` has three sections: "Model",
# "EndpointConfig" and "Endpoint". The Model and EndpointConfig names are
# templated with the run's execution_date so each run gets unique names, while
# the endpoint name stays fixed so operation="update" re-points the same
# "mnistclassifier" endpoint at the new configuration.
sagemaker_deploy_model = SageMakerEndpointOperator(
    task_id="sagemaker_deploy_model",
    operation="update",
    wait_for_completion=True,
    config={
        "Model": {
            "ModelName":
            "mnistclassifier-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}",
            "PrimaryContainer": {
                "Image":
                "438346466558.dkr.ecr.eu-west-1.amazonaws.com/kmeans:1",
                "ModelDataUrl":
                ("s3://your-bucket/mnistclassifier-output/mnistclassifier"
                 "-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}/"
                 "output/model.tar.gz"
                 ),  # this will link the model and the training job
            },
            "ExecutionRoleArn":
            ("arn:aws:iam::297623009465:role/service-role/"
             "AmazonSageMaker-ExecutionRole-20180905T153196"),
        },
        "EndpointConfig": {
            "EndpointConfigName":
            "mnistclassifier-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}",
            "ProductionVariants": [{
                "InitialInstanceCount": 1,
                "InstanceType": "ml.t2.medium",
                # Must name the model defined in the "Model" section above
                # (same templated name). A bare "mnistclassifier" would point
                # at a different (possibly nonexistent) model, leaving the
                # model created by this run unused.
                "ModelName":
                "mnistclassifier-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}",
                "VariantName": "AllTraffic",
            }],
        },
        "Endpoint": {
            "EndpointConfigName":
            "mnistclassifier-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}",
            "EndpointName": "mnistclassifier",
        },
    },
    dag=dag,
)
# Batch transform: run the SageMaker transform job and block until it
# finishes, checking its status every 30 seconds.
batch_transform_task = SageMakerTransformOperator(
    task_id='predicting',
    dag=dag,
    aws_conn_id='airflow-sagemaker',
    config=transform_config,
    check_interval=30,
    wait_for_completion=True,
    # ONE_SUCCESS: start as soon as any single upstream task succeeds --
    # presumably whichever of the tuning/training branches ran (TODO confirm
    # against the rest of the DAG wiring).
    trigger_rule=TriggerRule.ONE_SUCCESS,
)

# Endpoint deployment: create the SageMaker endpoint and wait for it,
# checking its status every 30 seconds.
deploy_model_task = SageMakerEndpointOperator(
    task_id='deployment',
    dag=dag,
    aws_conn_id='airflow-sagemaker',
    config=deploy_endpoint_config,
    check_interval=30,
    wait_for_completion=True,
)

# No-op placeholder task for the clean-up step.
cleanup_task = DummyOperator(task_id='cleaning_up', dag=dag)

# Task dependencies: a linear chain up to ``branching``, which is upstream of
# both the tuning and the training task.
init >> preprocess_task >> prepare_task >> branching
branching >> tune_model_task
branching >> train_model_task
) as dag:
    # Three-step pipeline inside the enclosing DAG context:
    # preprocess (Python callable) -> SageMaker training -> endpoint deployment.
    process_task = PythonOperator(
        task_id="process",
        dag=dag,
        python_callable=preprocess_glue,
    )

    train_task = SageMakerTrainingOperator(
        task_id="train",
        config=training_config,
        aws_conn_id="airflow-sagemaker",
        wait_for_completion=True,
        check_interval=60,  # check status of the job every minute
        # None = allow the training job to run as long as it needs;
        # change this to stop it early.
        max_ingestion_time=None,
    )

    endpoint_deploy_task = SageMakerEndpointOperator(
        task_id="endpoint-deploy",
        config=endpoint_config,
        # NOTE(review): train_task uses "airflow-sagemaker"; this
        # "sagemaker-airflow" id looks like a transposition typo -- confirm
        # which Airflow connection actually exists.
        aws_conn_id="sagemaker-airflow",
        wait_for_completion=True,
        check_interval=60,  # check status of endpoint deployment every minute
        max_ingestion_time=None,
        # Change to "update" when updating an existing endpoint rather than
        # creating a new one.
        operation="create",
    )

    # Set the dependencies between tasks.
    process_task >> train_task >> endpoint_deploy_task