def test_endpoint_config_step_creation_with_data_capture(pca_model):
    """EndpointConfigStep serializes data capture settings and tags.

    Builds a step with a ``DataCaptureConfig`` and ``DEFAULT_TAGS`` and checks
    that ``to_dict()`` produces a ``createEndpointConfig`` Task whose
    parameters include the production variant, the data capture block and
    the tag list.

    This test previously shared its name with the minimal endpoint-config
    test, which shadowed it so pytest never collected it.
    """
    data_capture_config = DataCaptureConfig(
        enable_capture=True,
        sampling_percentage=100,
        destination_s3_uri='s3://sagemaker/datacapture')
    step = EndpointConfigStep(
        'Endpoint Config',
        endpoint_config_name='MyEndpointConfig',
        model_name='pca-model',
        initial_instance_count=1,
        instance_type='ml.p2.xlarge',
        data_capture_config=data_capture_config,
        tags=DEFAULT_TAGS,
    )
    assert step.to_dict() == {
        'Type': 'Task',
        'Parameters': {
            'EndpointConfigName':
            'MyEndpointConfig',
            'ProductionVariants': [{
                'InitialInstanceCount': 1,
                'InstanceType': 'ml.p2.xlarge',
                'ModelName': 'pca-model',
                'VariantName': 'AllTraffic'
            }],
            'DataCaptureConfig': {
                'EnableCapture':
                True,
                'InitialSamplingPercentage':
                100,
                'DestinationS3Uri':
                's3://sagemaker/datacapture',
                # CaptureOptions and CaptureContentTypeHeader are not passed
                # to DataCaptureConfig above; this test expects them to be
                # emitted as defaults.
                'CaptureOptions': [{
                    'CaptureMode': 'Input'
                }, {
                    'CaptureMode': 'Output'
                }],
                'CaptureContentTypeHeader': {
                    'CsvContentTypes': ['text/csv'],
                    'JsonContentTypes': ['application/json']
                }
            },
            'Tags':
            DEFAULT_TAGS_LIST
        },
        'Resource': 'arn:aws:states:::sagemaker:createEndpointConfig',
        'End': True
    }
def test_endpoint_config_step_creation(pca_model):
    """EndpointConfigStep without data capture or tags yields a minimal Task.

    Only the endpoint config name and a single production variant are
    expected in the serialized ``createEndpointConfig`` parameters.
    """
    step = EndpointConfigStep(
        'Endpoint Config',
        endpoint_config_name='MyEndpointConfig',
        model_name='pca-model',
        initial_instance_count=1,
        instance_type='ml.p2.xlarge',
    )

    # The single variant that the step should emit for the given model.
    expected_variant = {
        'InitialInstanceCount': 1,
        'InstanceType': 'ml.p2.xlarge',
        'ModelName': 'pca-model',
        'VariantName': 'AllTraffic',
    }
    expected_parameters = {
        'EndpointConfigName': 'MyEndpointConfig',
        'ProductionVariants': [expected_variant],
    }
    expected_state = {
        'Type': 'Task',
        'Parameters': expected_parameters,
        'Resource': 'arn:aws:states:::sagemaker:createEndpointConfig',
        'End': True,
    }

    assert step.to_dict() == expected_state
def test_endpoint_config_step(trained_estimator, sfn_client, sagemaker_session,
                              sfn_role_arn):
    """Run a one-step workflow that creates an endpoint config and check it.

    A SageMaker model is first created from ``trained_estimator`` so the
    ``EndpointConfigStep`` has a model to reference. The workflow is executed
    and its output must contain an ``EndpointConfigArn`` and an HTTP 200
    status.

    Cleanup (state machine, endpoint config, model) runs in ``finally``
    blocks so that a failed execution or assertion does not leak AWS
    resources.
    """
    # Setup: Create model for trained estimator in SageMaker
    model = trained_estimator.create_model()
    model._create_sagemaker_model(instance_type=INSTANCE_TYPE)
    # End of Setup

    try:
        # Build workflow definition
        endpoint_config_name = unique_name_from_base("integ-test-endpoint-config")
        endpoint_config_step = EndpointConfigStep(
            'create_endpoint_config_step',
            endpoint_config_name=endpoint_config_name,
            model_name=model.name,
            initial_instance_count=INSTANCE_COUNT,
            instance_type=INSTANCE_TYPE)
        endpoint_config_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
        workflow_graph = Chain([endpoint_config_step])

        with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
            # Create workflow and check definition
            workflow = create_workflow_and_check_definition(
                workflow_graph=workflow_graph,
                workflow_name=unique_name_from_base(
                    "integ-test-create-endpoint-config-step-workflow"),
                sfn_client=sfn_client,
                sfn_role_arn=sfn_role_arn)

            try:
                # Execute workflow
                execution = workflow.execute()
                execution_output = execution.get_output(wait=True)

                # Check workflow output
                assert execution_output.get("EndpointConfigArn") is not None
                assert execution_output["SdkHttpMetadata"]["HttpStatusCode"] == 200
            finally:
                # Cleanup runs even if execution or an assertion failed.
                # NOTE(review): if the execution failed before the endpoint
                # config was created, the delete helper may raise; the
                # original failure is still chained in the traceback.
                state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
                delete_sagemaker_endpoint_config(endpoint_config_name,
                                                 sagemaker_session)
    finally:
        # The model exists from setup onward, so always remove it, last,
        # after the endpoint config that references it.
        delete_sagemaker_model(model.name, sagemaker_session)