def test_get_expected_model_with_framework_estimator(tensorflow_estimator):
    """An entry point overridden on the expected model should propagate
    into the generated CreateModel container environment."""
    training_step = TrainingStep('Training',
                                 estimator=tensorflow_estimator,
                                 data={'train': 's3://sagemaker/train'},
                                 job_name='tensorflow-job',
                                 mini_batch_size=1024)
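    # get_expected_model() builds a Model whose artifact location points at
    # the training step's output rather than a static S3 URI.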
    expected_model = training_step.get_expected_model()
    expected_model.entry_point = 'tf_train.py'
    model_step = ModelStep('Create model',
                           model=expected_model,
                           model_name='tf-model')
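    # In the rendered definition, keys suffixed with '.$' are JSONPath
    # references resolved from the state's input when the execution runs.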
    assert model_step.to_dict() == {
        'Type': 'Task',
        'Parameters': {
            'ExecutionRoleArn': EXECUTION_ROLE,
            'ModelName': 'tf-model',
            'PrimaryContainer': {
                'Environment': {
                    'SAGEMAKER_PROGRAM': 'tf_train.py',
                    'SAGEMAKER_SUBMIT_DIRECTORY':
                    's3://sagemaker/tensorflow-job/source/sourcedir.tar.gz',
                    'SAGEMAKER_CONTAINER_LOG_LEVEL': '20',
                    'SAGEMAKER_REGION': 'us-east-1',
                },
                'Image': expected_model.image_uri,
                'ModelDataUrl.$': "$['ModelArtifacts']['S3ModelArtifacts']"
            }
        },
        'Resource': 'arn:aws:states:::sagemaker:createModel',
        'End': True
    }


def test_model_step(trained_estimator, sfn_client, sagemaker_session,
                    sfn_role_arn):
    """Create a SageMaker model via a one-step workflow and verify the
    execution output."""
    # Build workflow definition
    model_name = generate_job_name()
    model_step = ModelStep('create_model_step',
                           model=trained_estimator.create_model(),
                           model_name=model_name)
    model_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
    workflow_graph = Chain([model_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base(
                "integ-test-model-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn)

        # Execute workflow
        execution = workflow.execute()
        execution_output = execution.get_output(wait=True)

        # Check workflow output
        assert execution_output.get("ModelArn") is not None
        assert execution_output["SdkHttpMetadata"]["HttpStatusCode"] == 200

        # Cleanup
        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
        model_name = get_resource_name_from_arn(
            execution_output.get("ModelArn")).split("/")[1]
        delete_sagemaker_model(model_name, sagemaker_session)


def test_training_step_creation_with_model(pca_estimator):
    """Chaining a TrainingStep into a ModelStep should emit a 'Next'
    transition and JSONPath references for the model name and artifacts."""
    training_step = TrainingStep('Training',
                                 estimator=pca_estimator,
                                 job_name='TrainingJob')
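    # Passing the training step's output as the model name turns ModelName
    # into a runtime JSONPath reference ('ModelName.$') in the definition.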
    model_step = ModelStep(
        'Training - Save Model',
        training_step.get_expected_model(
            model_name=training_step.output()['TrainingJobName']))
    training_step.next(model_step)
    assert training_step.to_dict() == {
        'Type': 'Task',
        'Parameters': {
            'AlgorithmSpecification': {
                'TrainingImage': PCA_IMAGE,
                'TrainingInputMode': 'File'
            },
            'OutputDataConfig': {
                'S3OutputPath': 's3://sagemaker/models'
            },
            'StoppingCondition': {
                'MaxRuntimeInSeconds': 86400
            },
            'ResourceConfig': {
                'InstanceCount': 1,
                'InstanceType': 'ml.c4.xlarge',
                'VolumeSizeInGB': 30
            },
            'RoleArn': EXECUTION_ROLE,
            'HyperParameters': {
                'feature_dim': '50000',
                'num_components': '10',
                'subtract_mean': 'True',
                'algorithm_mode': 'randomized',
                'mini_batch_size': '200'
            },
            'TrainingJobName': 'TrainingJob'
        },
        'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
        'Next': 'Training - Save Model'
    }

    assert model_step.to_dict() == {
        'Type': 'Task',
        'Resource': 'arn:aws:states:::sagemaker:createModel',
        'Parameters': {
            'ExecutionRoleArn': EXECUTION_ROLE,
            'ModelName.$': "$['TrainingJobName']",
            'PrimaryContainer': {
                'Environment': {},
                'Image': PCA_IMAGE,
                'ModelDataUrl.$': "$['ModelArtifacts']['S3ModelArtifacts']"
            }
        },
        'End': True
    }


def test_model_step_with_placeholders(trained_estimator, sfn_client,
                                      sagemaker_session, sfn_role_arn):
    """Create a model whose name, mode, and tags come from the execution
    input rather than being baked into the workflow definition."""
    # Build workflow definition
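    # ExecutionInput declares typed placeholders that are bound to values
    # from the execution's input JSON when the workflow runs.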
    execution_input = ExecutionInput(schema={
        'ModelName': str,
        'Mode': str,
        'Tags': list
    })

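    # Extra parameters are merged into the CreateModel request that
    # ModelStep generates, so placeholders can override generated fields.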
    parameters = {
        'PrimaryContainer': {
            'Mode': execution_input['Mode']
        },
        'Tags': execution_input['Tags']
    }

    model_step = ModelStep('create_model_step',
                           model=trained_estimator.create_model(),
                           model_name=execution_input['ModelName'],
                           parameters=parameters)
    model_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
    workflow_graph = Chain([model_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base(
                "integ-test-model-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn)

        inputs = {
            'ModelName': generate_job_name(),
            'Mode': 'SingleModel',
            'Tags': [{
                'Key': 'Environment',
                'Value': 'test'
            }]
        }

        # Execute workflow
        execution = workflow.execute(inputs=inputs)
        execution_output = execution.get_output(wait=True)

        # Check workflow output
        assert execution_output.get("ModelArn") is not None
        assert execution_output["SdkHttpMetadata"]["HttpStatusCode"] == 200

        # Cleanup
        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
        model_name = get_resource_name_from_arn(
            execution_output.get("ModelArn")).split("/")[1]
        delete_sagemaker_model(model_name, sagemaker_session)


def test_model_step_creation(pca_model):
    """A ModelStep built from a concrete model should render static
    parameters, with no JSONPath references."""
    step = ModelStep('Create model', model=pca_model, model_name='pca-model')
    assert step.to_dict() == {
        'Type': 'Task',
        'Parameters': {
            'ExecutionRoleArn': EXECUTION_ROLE,
            'ModelName': 'pca-model',
            'PrimaryContainer': {
                'Environment': {},
                'Image': pca_model.image_uri,
                'ModelDataUrl': pca_model.model_data
            }
        },
        'Resource': 'arn:aws:states:::sagemaker:createModel',
        'End': True
    }


def test_transform_step(trained_estimator, sfn_client, sfn_role_arn):
    """Run a model step followed by a transform step and verify the
    transform job completes."""
    # Create transformer from previously created estimator
    job_name = generate_job_name()
    pca_transformer = trained_estimator.transformer(
        instance_count=INSTANCE_COUNT, instance_type=INSTANCE_TYPE)

    # Create a model step to save the model
    model_step = ModelStep('create_model_step',
                           model=trained_estimator.create_model(),
                           model_name=job_name)
    model_step.add_retry(SAGEMAKER_RETRY_STRATEGY)

    # Upload data for transformation to S3
    data_path = os.path.join(DATA_DIR, "one_p_mnist")
    transform_input_path = os.path.join(data_path, "transform_input.csv")
    transform_input_key_prefix = "integ-test-data/one_p_mnist/transform"
    transform_input = pca_transformer.sagemaker_session.upload_data(
        path=transform_input_path, key_prefix=transform_input_key_prefix)

    # Build workflow definition
    transform_step = TransformStep('create_transform_job_step',
                                   pca_transformer,
                                   job_name=job_name,
                                   model_name=job_name,
                                   data=transform_input,
                                   content_type="text/csv")
    transform_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
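    # The model step must complete before the transform step, which looks
    # up the model by name.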
    workflow_graph = Chain([model_step, transform_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base(
                "integ-test-transform-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn)

        # Execute workflow
        execution = workflow.execute()
        execution_output = execution.get_output(wait=True)

        # Check workflow output
        assert execution_output.get("TransformJobStatus") == "Completed"

        # Cleanup
        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)


def test_get_expected_model(pca_estimator):
    """The expected model's ModelDataUrl should be a JSONPath reference to
    the training step's artifact output."""
    training_step = TrainingStep('Training',
                                 estimator=pca_estimator,
                                 job_name='TrainingJob')
    expected_model = training_step.get_expected_model()
    model_step = ModelStep('Create model',
                           model=expected_model,
                           model_name='pca-model')
    assert model_step.to_dict() == {
        'Type': 'Task',
        'Parameters': {
            'ExecutionRoleArn': EXECUTION_ROLE,
            'ModelName': 'pca-model',
            'PrimaryContainer': {
                'Environment': {},
                'Image': expected_model.image_uri,
                'ModelDataUrl.$': "$['ModelArtifacts']['S3ModelArtifacts']"
            }
        },
        'Resource': 'arn:aws:states:::sagemaker:createModel',
        'End': True
    }


def test_transform_step_with_placeholder(trained_estimator, sfn_client,
                                         sfn_role_arn):
    """Run a transform workflow whose job settings are supplied through
    execution-input placeholders."""
    # Create transformer from supplied estimator
    job_name = generate_job_name()
    pca_transformer = trained_estimator.transformer(
        instance_count=INSTANCE_COUNT, instance_type=INSTANCE_TYPE)

    # Create a model step to save the model
    model_step = ModelStep('create_model_step',
                           model=trained_estimator.create_model(),
                           model_name=job_name)
    model_step.add_retry(SAGEMAKER_RETRY_STRATEGY)

    # Upload data for transformation to S3
    data_path = os.path.join(DATA_DIR, "one_p_mnist")
    transform_input_path = os.path.join(data_path, "transform_input.csv")
    transform_input_key_prefix = "integ-test-data/one_p_mnist/transform"
    transform_input = pca_transformer.sagemaker_session.upload_data(
        path=transform_input_path, key_prefix=transform_input_key_prefix)

    execution_input = ExecutionInput(
        schema={
            'data': str,
            'content_type': str,
            'split_type': str,
            'job_name': str,
            'model_name': str,
            'instance_count': int,
            'instance_type': str,
            'strategy': str,
            'max_concurrent_transforms': int,
            'max_payload': int,
        })

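    # Placeholder-backed overrides below are merged into the generated
    # CreateTransformJob request; their values resolve at execution time.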
    parameters = {
        'BatchStrategy': execution_input['strategy'],
        'TransformInput': {
            'SplitType': execution_input['split_type'],
        },
        'TransformResources': {
            'InstanceCount': execution_input['instance_count'],
            'InstanceType': execution_input['instance_type'],
        },
        'MaxConcurrentTransforms':
            execution_input['max_concurrent_transforms'],
        'MaxPayloadInMB': execution_input['max_payload']
    }

    # Build workflow definition
    transform_step = TransformStep(
        'create_transform_job_step',
        pca_transformer,
        job_name=execution_input['job_name'],
        model_name=execution_input['model_name'],
        data=execution_input['data'],
        content_type=execution_input['content_type'],
        parameters=parameters)
    transform_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
    workflow_graph = Chain([model_step, transform_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base(
                "integ-test-transform-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn)

        inputs = {
            'job_name': job_name,
            'model_name': job_name,
            'data': transform_input,
            'content_type': "text/csv",
            'instance_count': INSTANCE_COUNT,
            'instance_type': INSTANCE_TYPE,
            'split_type': 'Line',
            'strategy': 'SingleRecord',
            'max_concurrent_transforms': 2,
            'max_payload': 5
        }

        # Execute workflow
        execution = workflow.execute(inputs=inputs)
        execution_output = execution.get_output(wait=True)

        # Check workflow output
        assert execution_output.get("TransformJobStatus") == "Completed"

        # Cleanup
        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)