import os

from stepfunctions.inputs import ExecutionInput
from stepfunctions.steps import Chain, ModelStep, TransformStep

# Shared fixtures, helpers, and constants referenced below (pca_transformer,
# trained_estimator, DEFAULT_TAGS, DEFAULT_TAGS_LIST, SAGEMAKER_RETRY_STRATEGY,
# DATA_DIR, INSTANCE_COUNT, INSTANCE_TYPE, DEFAULT_TIMEOUT_MINUTES,
# generate_job_name, create_workflow_and_check_definition, unique_name_from_base,
# timeout, state_machine_delete_wait) are assumed to come from the suite's
# conftest and shared integration-test utilities.


def test_transform_step_creation(pca_transformer):
    step = TransformStep(
        'Inference',
        transformer=pca_transformer,
        data='s3://sagemaker/inference',
        job_name='transform-job',
        model_name='pca-model'
    )

    assert step.to_dict() == {
        'Type': 'Task',
        'Parameters': {
            'ModelName': 'pca-model',
            'TransformInput': {
                'DataSource': {
                    'S3DataSource': {
                        'S3DataType': 'S3Prefix',
                        'S3Uri': 's3://sagemaker/inference'
                    }
                }
            },
            'TransformOutput': {
                'S3OutputPath': 's3://sagemaker/transform-output'
            },
            'TransformJobName': 'transform-job',
            'TransformResources': {
                'InstanceCount': 1,
                'InstanceType': 'ml.c4.xlarge'
            }
        },
        'Resource': 'arn:aws:states:::sagemaker:createTransformJob.sync',
        'End': True
    }
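# to_dict() above is the step's Amazon States Language representation: a Task
# state whose Resource uses the createTransformJob.sync service integration,
# meaning Step Functions waits for the transform job to complete before
# transitioning, and 'End': True marks the state as terminal.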
def test_transform_step_creation_with_optional_fields(pca_transformer):
    # Note: this test needs a distinct name; reusing test_transform_step_creation
    # would shadow the test above and pytest would never collect it.
    step = TransformStep(
        'Inference',
        transformer=pca_transformer,
        data='s3://sagemaker/inference',
        job_name='transform-job',
        model_name='pca-model',
        experiment_config={
            'ExperimentName': 'pca_experiment',
            'TrialName': 'pca_trial',
            'TrialComponentDisplayName': 'Transform'
        },
        tags=DEFAULT_TAGS,
        join_source='Input',
        output_filter='$[2:]',
        input_filter='$[1:]'
    )

    assert step.to_dict() == {
        'Type': 'Task',
        'Parameters': {
            'ModelName': 'pca-model',
            'TransformInput': {
                'DataSource': {
                    'S3DataSource': {
                        'S3DataType': 'S3Prefix',
                        'S3Uri': 's3://sagemaker/inference'
                    }
                }
            },
            'TransformOutput': {
                'S3OutputPath': 's3://sagemaker/transform-output'
            },
            'TransformJobName': 'transform-job',
            'TransformResources': {
                'InstanceCount': 1,
                'InstanceType': 'ml.c4.xlarge'
            },
            'ExperimentConfig': {
                'ExperimentName': 'pca_experiment',
                'TrialName': 'pca_trial',
                'TrialComponentDisplayName': 'Transform'
            },
            'DataProcessing': {
                'InputFilter': '$[1:]',
                'OutputFilter': '$[2:]',
                'JoinSource': 'Input',
            },
            'Tags': DEFAULT_TAGS_LIST
        },
        'Resource': 'arn:aws:states:::sagemaker:createTransformJob.sync',
        'End': True
    }
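# The expected dicts above pin down what the pca_transformer fixture must
# supply: a SageMaker Transformer with instance_count=1,
# instance_type='ml.c4.xlarge', and output_path='s3://sagemaker/transform-output'.
# A minimal sketch of such a fixture, assuming a mocked session; the mock and
# the exact conftest wiring are assumptions, not the suite's actual code:

import pytest
from unittest.mock import MagicMock
from sagemaker.transformer import Transformer


@pytest.fixture
def pca_transformer():
    # Mocked session keeps the unit tests free of real AWS calls (assumption)
    return Transformer(
        model_name='pca-model',
        instance_count=1,
        instance_type='ml.c4.xlarge',
        output_path='s3://sagemaker/transform-output',
        sagemaker_session=MagicMock()
    )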
def test_transform_step(trained_estimator, sfn_client, sfn_role_arn):
    # Create transformer from previously created estimator
    job_name = generate_job_name()
    pca_transformer = trained_estimator.transformer(
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE
    )

    # Create a model step to save the model
    model_step = ModelStep(
        'create_model_step',
        model=trained_estimator.create_model(),
        model_name=job_name
    )
    model_step.add_retry(SAGEMAKER_RETRY_STRATEGY)

    # Upload data for transformation to S3
    data_path = os.path.join(DATA_DIR, "one_p_mnist")
    transform_input_path = os.path.join(data_path, "transform_input.csv")
    transform_input_key_prefix = "integ-test-data/one_p_mnist/transform"
    transform_input = pca_transformer.sagemaker_session.upload_data(
        path=transform_input_path,
        key_prefix=transform_input_key_prefix
    )

    # Build workflow definition
    transform_step = TransformStep(
        'create_transform_job_step',
        pca_transformer,
        job_name=job_name,
        model_name=job_name,
        data=transform_input,
        content_type="text/csv"
    )
    transform_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
    workflow_graph = Chain([model_step, transform_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base("integ-test-transform-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn
        )

        # Execute workflow
        execution = workflow.execute()
        execution_output = execution.get_output(wait=True)

        # Check workflow output
        assert execution_output.get("TransformJobStatus") == "Completed"

        # Cleanup
        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
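# SAGEMAKER_RETRY_STRATEGY used in these integration tests comes from the
# suite's shared constants. One plausible definition using the SDK's Retry
# field; the specific error code and backoff numbers here are assumptions,
# not the suite's actual values:

from stepfunctions.steps import Retry

SAGEMAKER_RETRY_STRATEGY = Retry(
    error_equals=["SageMaker.AmazonSageMakerException"],  # assumed error code
    interval_seconds=5,
    max_attempts=5,
    backoff_rate=2
)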
def test_transform_step_with_placeholder(trained_estimator, sfn_client, sfn_role_arn):
    # Create transformer from supplied estimator
    job_name = generate_job_name()
    pca_transformer = trained_estimator.transformer(
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE
    )

    # Create a model step to save the model
    model_step = ModelStep(
        'create_model_step',
        model=trained_estimator.create_model(),
        model_name=job_name
    )
    model_step.add_retry(SAGEMAKER_RETRY_STRATEGY)

    # Upload data for transformation to S3
    data_path = os.path.join(DATA_DIR, "one_p_mnist")
    transform_input_path = os.path.join(data_path, "transform_input.csv")
    transform_input_key_prefix = "integ-test-data/one_p_mnist/transform"
    transform_input = pca_transformer.sagemaker_session.upload_data(
        path=transform_input_path,
        key_prefix=transform_input_key_prefix
    )

    execution_input = ExecutionInput(schema={
        'data': str,
        'content_type': str,
        'split_type': str,
        'job_name': str,
        'model_name': str,
        'instance_count': int,
        'instance_type': str,
        'strategy': str,
        'max_concurrent_transforms': int,
        'max_payload': int,
    })

    parameters = {
        'BatchStrategy': execution_input['strategy'],
        'TransformInput': {
            'SplitType': execution_input['split_type'],
        },
        'TransformResources': {
            'InstanceCount': execution_input['instance_count'],
            'InstanceType': execution_input['instance_type'],
        },
        'MaxConcurrentTransforms': execution_input['max_concurrent_transforms'],
        'MaxPayloadInMB': execution_input['max_payload']
    }

    # Build workflow definition
    transform_step = TransformStep(
        'create_transform_job_step',
        pca_transformer,
        job_name=execution_input['job_name'],
        model_name=execution_input['model_name'],
        data=execution_input['data'],
        content_type=execution_input['content_type'],
        parameters=parameters
    )
    transform_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
    workflow_graph = Chain([model_step, transform_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base("integ-test-transform-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn
        )

        # Concrete values to bind to the placeholders; kept in a separate dict
        # so it does not shadow the ExecutionInput schema above
        inputs = {
            'job_name': job_name,
            'model_name': job_name,
            'data': transform_input,
            'content_type': "text/csv",
            'instance_count': INSTANCE_COUNT,
            'instance_type': INSTANCE_TYPE,
            'split_type': 'Line',
            'strategy': 'SingleRecord',
            'max_concurrent_transforms': 2,
            'max_payload': 5
        }

        # Execute workflow
        execution = workflow.execute(inputs=inputs)
        execution_output = execution.get_output(wait=True)

        # Check workflow output
        assert execution_output.get("TransformJobStatus") == "Completed"

        # Cleanup
        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
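# Note on the placeholder mechanics exercised above: each ExecutionInput lookup
# (e.g. execution_input['job_name']) is compiled into the generated Amazon
# States Language definition as a JsonPath reference into the execution's
# input, along the lines of "$$.Execution.Input['job_name']", so a single
# state machine definition can be reused across transform jobs. The
# `parameters` argument is merged into the step's generated Parameters dict,
# which is how BatchStrategy, SplitType, and the resource placeholders reach
# request fields that TransformStep does not expose as keyword arguments.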