def test_create_endpoint_step(trained_estimator, record_set_fixture, sfn_client, sagemaker_session, sfn_role_arn):
    """Create a SageMaker endpoint via a single-step EndpointStep workflow.

    Setup registers a model and an endpoint config (both named ``model.name``)
    from ``trained_estimator``. A Step Functions workflow containing one
    ``EndpointStep`` is then created, executed, and its output checked for an
    ``EndpointArn`` and an HTTP 200 status. Finally the cleanup helpers are
    called for the state machine, endpoint, endpoint config and model.
    """
    # Setup: Create model and endpoint config for trained estimator in SageMaker
    model = trained_estimator.create_model()
    model._create_sagemaker_model(instance_type=INSTANCE_TYPE)
    # Only the side effect (an endpoint config named after the model) is needed;
    # the return value was previously bound to an unused local.
    model.sagemaker_session.create_endpoint_config(
        name=model.name,
        model_name=model.name,
        initial_instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
    )
    # End of Setup

    # Build workflow definition
    endpoint_name = unique_name_from_base("integ-test-endpoint")
    endpoint_step = EndpointStep('create_endpoint_step', endpoint_name=endpoint_name, endpoint_config_name=model.name)
    endpoint_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
    workflow_graph = Chain([endpoint_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base("integ-test-create-endpoint-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn,
        )

        # Execute workflow and block until its output is available
        execution = workflow.execute()
        execution_output = execution.get_output(wait=True)

        # Check workflow output
        endpoint_arn = execution_output.get("EndpointArn")
        assert endpoint_arn is not None
        assert execution_output["SdkHttpMetadata"]["HttpStatusCode"] == 200

        # Cleanup
        # NOTE(review): these calls are skipped if any assertion above fails,
        # which may leave the endpoint/config/model behind; consider moving
        # them into a try/finally once the helpers are known to tolerate
        # missing resources.
        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
        delete_sagemaker_endpoint(endpoint_name, sagemaker_session)
        delete_sagemaker_endpoint_config(model.name, sagemaker_session)
        delete_sagemaker_model(model.name, sagemaker_session)
def test_endpoint_step_creation(pca_model):
    """Verify EndpointStep serialization for its create and update variants.

    Without ``update`` the step must point at the SageMaker ``createEndpoint``
    integration; with ``update=True`` it must point at ``updateEndpoint``.
    Both variants share the same parameters and are terminal (``End: True``).
    """
    # NOTE(review): the ``pca_model`` fixture is requested but not used here.
    shared_kwargs = {
        'endpoint_name': 'MyEndPoint',
        'endpoint_config_name': 'MyEndpointConfig',
        'tags': DEFAULT_TAGS,
    }
    # Each variant: (extra constructor kwargs, expected Task resource ARN).
    variants = (
        ({}, 'arn:aws:states:::sagemaker:createEndpoint'),
        ({'update': True}, 'arn:aws:states:::sagemaker:updateEndpoint'),
    )
    for extra_kwargs, resource_arn in variants:
        call_kwargs = dict(shared_kwargs, **extra_kwargs)
        step = EndpointStep('Endpoint', **call_kwargs)
        assert step.to_dict() == {
            'Type': 'Task',
            'Parameters': {
                'EndpointConfigName': 'MyEndpointConfig',
                'EndpointName': 'MyEndPoint',
                'Tags': DEFAULT_TAGS_LIST,
            },
            'Resource': resource_arn,
            'End': True,
        }