def test_create_endpoint_step(trained_estimator, record_set_fixture,
                              sfn_client, sagemaker_session, sfn_role_arn):
    """Execute a one-step workflow built around an ``EndpointStep``.

    A model and an endpoint config (both named ``model.name``) are created
    first, then a workflow containing a single ``EndpointStep`` is created
    and executed. The test passes when the execution output contains an
    ``EndpointArn`` and reports HTTP status 200.

    Args:
        trained_estimator: Fixture whose ``create_model()`` supplies the
            model used for the endpoint config.
        record_set_fixture: Not referenced in the body; kept so the fixture
            is still requested (presumably for its setup -- confirm).
        sfn_client: Passed to the workflow-creation helper and to
            ``state_machine_delete_wait``.
        sagemaker_session: Passed to the SageMaker delete helpers.
        sfn_role_arn: Passed to the workflow-creation helper.
    """
    # Setup: Create model and endpoint config for trained estimator in SageMaker
    model = trained_estimator.create_model()
    model._create_sagemaker_model(instance_type=INSTANCE_TYPE)
    # The return value is not needed: the config is referenced below by the
    # same name it is created with (model.name).
    model.sagemaker_session.create_endpoint_config(
        name=model.name,
        model_name=model.name,
        initial_instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE)
    # End of Setup

    # Build workflow definition
    endpoint_name = unique_name_from_base("integ-test-endpoint")
    endpoint_step = EndpointStep('create_endpoint_step',
                                 endpoint_name=endpoint_name,
                                 endpoint_config_name=model.name)
    endpoint_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
    workflow_graph = Chain([endpoint_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base(
                "integ-test-create-endpoint-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn)

        # Execute workflow
        execution = workflow.execute()
        execution_output = execution.get_output(wait=True)

        # Check workflow output
        endpoint_arn = execution_output.get("EndpointArn")
        assert endpoint_arn is not None
        assert execution_output["SdkHttpMetadata"]["HttpStatusCode"] == 200

        # Cleanup
        # NOTE(review): these deletions are skipped whenever a statement
        # above raises (including a failed assertion), because nothing here
        # guards them with try/finally.
        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
        delete_sagemaker_endpoint(endpoint_name, sagemaker_session)
        delete_sagemaker_endpoint_config(model.name, sagemaker_session)
        delete_sagemaker_model(model.name, sagemaker_session)
def test_endpoint_step_creation(pca_model):
    """Check ``EndpointStep.to_dict()`` for the create and update variants.

    Both steps are built with the same endpoint name, endpoint config name
    and tags; the only expected difference in the serialized state is the
    ``Resource`` ARN, which switches to ``updateEndpoint`` when
    ``update=True`` is passed.
    """
    shared_kwargs = {
        'endpoint_name': 'MyEndPoint',
        'endpoint_config_name': 'MyEndpointConfig',
        'tags': DEFAULT_TAGS,
    }

    def expected_state(resource_arn):
        # Expected serialized Task state; only the resource ARN varies.
        return {
            'Type': 'Task',
            'Parameters': {
                'EndpointConfigName': 'MyEndpointConfig',
                'EndpointName': 'MyEndPoint',
                'Tags': DEFAULT_TAGS_LIST
            },
            'Resource': resource_arn,
            'End': True
        }

    # Default behaviour: the step targets the createEndpoint integration.
    create_step = EndpointStep('Endpoint', **shared_kwargs)
    assert create_step.to_dict() == expected_state(
        'arn:aws:states:::sagemaker:createEndpoint')

    # With update=True the step targets the updateEndpoint integration.
    update_step = EndpointStep('Endpoint', update=True, **shared_kwargs)
    assert update_step.to_dict() == expected_state(
        'arn:aws:states:::sagemaker:updateEndpoint')