def test_transform_step_creation(pca_transformer):
    step = TransformStep('Inference',
                         transformer=pca_transformer,
                         data='s3://sagemaker/inference',
                         job_name='transform-job',
                         model_name='pca-model')
    assert step.to_dict() == {
        'Type': 'Task',
        'Parameters': {
            'ModelName': 'pca-model',
            'TransformInput': {
                'DataSource': {
                    'S3DataSource': {
                        'S3DataType': 'S3Prefix',
                        'S3Uri': 's3://sagemaker/inference'
                    }
                }
            },
            'TransformOutput': {
                'S3OutputPath': 's3://sagemaker/transform-output'
            },
            'TransformJobName': 'transform-job',
            'TransformResources': {
                'InstanceCount': 1,
                'InstanceType': 'ml.c4.xlarge'
            }
        },
        'Resource': 'arn:aws:states:::sagemaker:createTransformJob.sync',
        'End': True
    }
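

# The tests in this module reference a `pca_transformer` fixture and tag
# constants that live elsewhere in the suite. A minimal sketch of both is
# included here so the expected dicts are self-contained; the values and
# imports below are assumptions, not the real definitions.
import pytest
from unittest.mock import MagicMock

from sagemaker.transformer import Transformer

DEFAULT_TAGS = {'Purpose': 'unittests'}  # hypothetical tag values
DEFAULT_TAGS_LIST = [{'Key': 'Purpose', 'Value': 'unittests'}]


@pytest.fixture
def pca_transformer():
    # Settings mirror the expected request dicts in these tests: a single
    # ml.c4.xlarge instance writing to s3://sagemaker/transform-output.
    return Transformer(
        model_name='pca-model',
        instance_count=1,
        instance_type='ml.c4.xlarge',
        output_path='s3://sagemaker/transform-output',
        sagemaker_session=MagicMock(),
    )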


def test_transform_step_creation_with_all_parameters(pca_transformer):
    step = TransformStep('Inference',
        transformer=pca_transformer,
        data='s3://sagemaker/inference',
        job_name='transform-job',
        model_name='pca-model',
        experiment_config={
            'ExperimentName': 'pca_experiment',
            'TrialName': 'pca_trial',
            'TrialComponentDisplayName': 'Transform'
        },
        tags=DEFAULT_TAGS,
        join_source='Input',
        output_filter='$[2:]',
        input_filter='$[1:]'
    )
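    # join_source, input_filter and output_filter are emitted as the
    # DataProcessing block of the CreateTransformJob request, asserted below.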
    assert step.to_dict() == {
        'Type': 'Task',
        'Parameters': {
            'ModelName': 'pca-model',
            'TransformInput': {
                'DataSource': {
                    'S3DataSource': {
                        'S3DataType': 'S3Prefix',
                        'S3Uri': 's3://sagemaker/inference'
                    }
                }
            },
            'TransformOutput': {
                'S3OutputPath': 's3://sagemaker/transform-output'
            },
            'TransformJobName': 'transform-job',
            'TransformResources': {
                'InstanceCount': 1,
                'InstanceType': 'ml.c4.xlarge'
            },
            'ExperimentConfig': {
                'ExperimentName': 'pca_experiment',
                'TrialName': 'pca_trial',
                'TrialComponentDisplayName': 'Transform'
            },
            'DataProcessing': {
                'InputFilter': '$[1:]',
                'OutputFilter': '$[2:]',
                'JoinSource': 'Input',
            },
            'Tags': DEFAULT_TAGS_LIST
        },
        'Resource': 'arn:aws:states:::sagemaker:createTransformJob.sync',
        'End': True
    }
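

# The integration tests below use helpers from the shared integ-test
# utilities. The sketches here are plausible stand-ins, not the real
# implementations: a unique job-name generator built on sagemaker's
# unique_name_from_base, and a retry policy for transient SageMaker errors.
from sagemaker.utils import unique_name_from_base
from stepfunctions.steps import Retry


def generate_job_name():
    # SageMaker job names must be unique across runs, so derive one from a
    # timestamped base name.
    return unique_name_from_base("integ-test-job")


# The error name and retry cadence here are illustrative assumptions.
SAGEMAKER_RETRY_STRATEGY = Retry(
    error_equals=["SageMaker.AmazonSageMakerException"],
    interval_seconds=5,
    max_attempts=5,
    backoff_rate=2
)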


def test_transform_step(trained_estimator, sfn_client, sfn_role_arn):
    # Create transformer from previously created estimator
    job_name = generate_job_name()
    pca_transformer = trained_estimator.transformer(
        instance_count=INSTANCE_COUNT, instance_type=INSTANCE_TYPE)

    # Create a model step to save the model
    model_step = ModelStep('create_model_step',
                           model=trained_estimator.create_model(),
                           model_name=job_name)
    model_step.add_retry(SAGEMAKER_RETRY_STRATEGY)

    # Upload data for transformation to S3
    data_path = os.path.join(DATA_DIR, "one_p_mnist")
    transform_input_path = os.path.join(data_path, "transform_input.csv")
    transform_input_key_prefix = "integ-test-data/one_p_mnist/transform"
    transform_input = pca_transformer.sagemaker_session.upload_data(
        path=transform_input_path, key_prefix=transform_input_key_prefix)

    # Build workflow definition
    transform_step = TransformStep('create_transform_job_step',
                                   transformer=pca_transformer,
                                   job_name=job_name,
                                   model_name=job_name,
                                   data=transform_input,
                                   content_type="text/csv")
    transform_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
    workflow_graph = Chain([model_step, transform_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base(
                "integ-test-transform-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn)

        # Execute workflow
        execution = workflow.execute()
        execution_output = execution.get_output(wait=True)

        # Check workflow output
        assert execution_output.get("TransformJobStatus") == "Completed"

        # Cleanup
        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)


def test_transform_step_with_placeholder(trained_estimator, sfn_client,
                                         sfn_role_arn):
    # Create transformer from supplied estimator
    job_name = generate_job_name()
    pca_transformer = trained_estimator.transformer(
        instance_count=INSTANCE_COUNT, instance_type=INSTANCE_TYPE)

    # Create a model step to save the model
    model_step = ModelStep('create_model_step',
                           model=trained_estimator.create_model(),
                           model_name=job_name)
    model_step.add_retry(SAGEMAKER_RETRY_STRATEGY)

    # Upload data for transformation to S3
    data_path = os.path.join(DATA_DIR, "one_p_mnist")
    transform_input_path = os.path.join(data_path, "transform_input.csv")
    transform_input_key_prefix = "integ-test-data/one_p_mnist/transform"
    transform_input = pca_transformer.sagemaker_session.upload_data(
        path=transform_input_path, key_prefix=transform_input_key_prefix)

    execution_input = ExecutionInput(
        schema={
            'data': str,
            'content_type': str,
            'split_type': str,
            'job_name': str,
            'model_name': str,
            'instance_count': int,
            'instance_type': str,
            'strategy': str,
            'max_concurrent_transforms': int,
            'max_payload': int,
        })

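    # Each placeholder below is rendered as a JSONPath reference into the
    # execution input, so these values override the transformer's defaults
    # at run time rather than at definition time.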
    parameters = {
        'BatchStrategy': execution_input['strategy'],
        'TransformInput': {
            'SplitType': execution_input['split_type'],
        },
        'TransformResources': {
            'InstanceCount': execution_input['instance_count'],
            'InstanceType': execution_input['instance_type'],
        },
        'MaxConcurrentTransforms':
            execution_input['max_concurrent_transforms'],
        'MaxPayloadInMB': execution_input['max_payload']
    }

    # Build workflow definition
    transform_step = TransformStep(
        'create_transform_job_step',
        transformer=pca_transformer,
        job_name=execution_input['job_name'],
        model_name=execution_input['model_name'],
        data=execution_input['data'],
        content_type=execution_input['content_type'],
        parameters=parameters)
    transform_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
    workflow_graph = Chain([model_step, transform_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base(
                "integ-test-transform-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn)

        # Concrete values for the placeholders declared in the schema above;
        # use a separate name so the ExecutionInput object is not shadowed.
        inputs = {
            'job_name': job_name,
            'model_name': job_name,
            'data': transform_input,
            'content_type': "text/csv",
            'instance_count': INSTANCE_COUNT,
            'instance_type': INSTANCE_TYPE,
            'split_type': 'Line',
            'strategy': 'SingleRecord',
            'max_concurrent_transforms': 2,
            'max_payload': 5
        }

        # Execute workflow
        execution = workflow.execute(inputs=inputs)
        execution_output = execution.get_output(wait=True)

        # Check workflow output
        assert execution_output.get("TransformJobStatus") == "Completed"

        # Cleanup
        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
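

# A plausible sketch of the state_machine_delete_wait cleanup helper used
# above (the real one lives in the integ-test utilities). Deletion is
# asynchronous, so it polls until the state machine is actually gone.
import time


def state_machine_delete_wait(sfn_client, state_machine_arn,
                              timeout_seconds=300):
    sfn_client.delete_state_machine(stateMachineArn=state_machine_arn)
    deadline = time.time() + timeout_seconds
    while time.time() < deadline:
        try:
            sfn_client.describe_state_machine(
                stateMachineArn=state_machine_arn)
        except sfn_client.exceptions.StateMachineDoesNotExist:
            return
        time.sleep(5)
    raise TimeoutError(f"{state_machine_arn} was not deleted in time")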