def test_get_expected_model_with_framework_estimator(tensorflow_estimator):
    """Check that the model expected from a framework-estimator TrainingStep
    renders into a createModel task.

    The entry point assigned to the model after creation is expected to show
    up as ``SAGEMAKER_PROGRAM`` in the container environment.
    """
    step = TrainingStep(
        'Training',
        estimator=tensorflow_estimator,
        data={'train': 's3://sagemaker/train'},
        job_name='tensorflow-job',
        mini_batch_size=1024,
    )
    model = step.get_expected_model()
    model.entry_point = 'tf_train.py'
    create_model = ModelStep('Create model', model=model, model_name='tf-model')

    actual = create_model.to_dict()

    environment = {
        'SAGEMAKER_PROGRAM': 'tf_train.py',
        'SAGEMAKER_SUBMIT_DIRECTORY':
        's3://sagemaker/tensorflow-job/source/sourcedir.tar.gz',
        'SAGEMAKER_CONTAINER_LOG_LEVEL': '20',
        'SAGEMAKER_REGION': 'us-east-1',
    }
    # Model data is not known until training finishes, so it is taken from
    # the state input at runtime via a JSONPath reference.
    primary_container = {
        'Environment': environment,
        'Image': model.image_uri,
        'ModelDataUrl.$': "$['ModelArtifacts']['S3ModelArtifacts']",
    }
    assert actual == {
        'Type': 'Task',
        'Resource': 'arn:aws:states:::sagemaker:createModel',
        'Parameters': {
            'ExecutionRoleArn': EXECUTION_ROLE,
            'ModelName': 'tf-model',
            'PrimaryContainer': primary_container,
        },
        'End': True,
    }
def test_training_step_creation_with_framework(tensorflow_estimator):
    """A TrainingStep built from a TensorFlow estimator serialises to a
    createTrainingJob.sync task with input channel, debug hook config,
    hyperparameters and tags."""
    training = TrainingStep(
        'Training',
        estimator=tensorflow_estimator,
        data={'train': 's3://sagemaker/train'},
        job_name='tensorflow-job',
        mini_batch_size=1024,
        tags=DEFAULT_TAGS,
    )
    actual = training.to_dict()

    train_channel = {
        'ChannelName': 'train',
        'DataSource': {
            'S3DataSource': {
                'S3DataDistributionType': 'FullyReplicated',
                'S3DataType': 'S3Prefix',
                'S3Uri': 's3://sagemaker/train',
            },
        },
    }
    # String-valued framework hyperparameters carry literal double quotes
    # (presumably JSON-encoded by the SDK).
    hyperparameters = {
        'model_dir': '"s3://sagemaker/models/tensorflow-job/model"',
        'sagemaker_container_log_level': '20',
        'sagemaker_enable_cloudwatch_metrics': 'false',
        'sagemaker_job_name': '"tensorflow-job"',
        'sagemaker_program': '"tf_train.py"',
        'sagemaker_region': '"us-east-1"',
        'sagemaker_submit_directory': '"s3://sagemaker/source"',
    }
    parameters = {
        'AlgorithmSpecification': {
            'TrainingImage': TENSORFLOW_IMAGE,
            'TrainingInputMode': 'File',
        },
        'InputDataConfig': [train_channel],
        'OutputDataConfig': {'S3OutputPath': 's3://sagemaker/models'},
        'DebugHookConfig': {'S3OutputPath': 's3://sagemaker/models/debug'},
        'StoppingCondition': {'MaxRuntimeInSeconds': 86400},
        'ResourceConfig': {
            'InstanceCount': 1,
            'InstanceType': 'ml.p2.xlarge',
            'VolumeSizeInGB': 30,
        },
        'RoleArn': EXECUTION_ROLE,
        'HyperParameters': hyperparameters,
        'TrainingJobName': 'tensorflow-job',
        'Tags': DEFAULT_TAGS_LIST,
    }
    assert actual == {
        'Type': 'Task',
        'Parameters': parameters,
        'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
        'End': True,
    }
def test_training_step_creation(pca_estimator):
    """A TrainingStep for the PCA estimator with no ``data`` argument
    serialises without an InputDataConfig section."""
    step = TrainingStep('Training', estimator=pca_estimator, job_name='TrainingJob')
    actual = step.to_dict()

    pca_hyperparameters = {
        'feature_dim': '50000',
        'num_components': '10',
        'subtract_mean': 'True',
        'algorithm_mode': 'randomized',
        'mini_batch_size': '200',
    }
    resource_config = {
        'InstanceCount': 1,
        'InstanceType': 'ml.c4.xlarge',
        'VolumeSizeInGB': 30,
    }
    expected_parameters = {
        'AlgorithmSpecification': {
            'TrainingImage': PCA_IMAGE,
            'TrainingInputMode': 'File',
        },
        'OutputDataConfig': {'S3OutputPath': 's3://sagemaker/models'},
        'StoppingCondition': {'MaxRuntimeInSeconds': 86400},
        'ResourceConfig': resource_config,
        'RoleArn': EXECUTION_ROLE,
        'HyperParameters': pca_hyperparameters,
        'TrainingJobName': 'TrainingJob',
    }
    assert actual == {
        'Type': 'Task',
        'Parameters': expected_parameters,
        'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
        'End': True,
    }
def test_training_step_merges_hyperparameters_from_constructor_and_estimator(
        tensorflow_estimator):
    """Hyperparameters passed to the TrainingStep constructor appear in the
    serialised task alongside the estimator's own hyperparameters."""
    step = TrainingStep(
        'Training',
        estimator=tensorflow_estimator,
        data={'train': 's3://sagemaker/train'},
        job_name='tensorflow-job',
        mini_batch_size=1024,
        hyperparameters={'key': 'value'},
    )
    serialized = step.to_dict()

    # 'key' is the only entry supplied through the constructor's
    # `hyperparameters` argument; the rest are produced by the estimator/SDK.
    merged_hyperparameters = {
        'checkpoint_path':
        '"s3://sagemaker/models/sagemaker-tensorflow/checkpoints"',
        'evaluation_steps': '100',
        'key': 'value',
        'sagemaker_container_log_level': '20',
        'sagemaker_job_name': '"tensorflow-job"',
        'sagemaker_program': '"tf_train.py"',
        'sagemaker_region': '"us-east-1"',
        'sagemaker_submit_directory': '"s3://sagemaker/source"',
        'training_steps': '1000',
    }
    input_channels = [{
        'ChannelName': 'train',
        'DataSource': {
            'S3DataSource': {
                'S3DataDistributionType': 'FullyReplicated',
                'S3DataType': 'S3Prefix',
                'S3Uri': 's3://sagemaker/train',
            },
        },
    }]
    expected_parameters = {
        'AlgorithmSpecification': {
            'TrainingImage': TENSORFLOW_IMAGE,
            'TrainingInputMode': 'File',
        },
        'InputDataConfig': input_channels,
        'OutputDataConfig': {'S3OutputPath': 's3://sagemaker/models'},
        'DebugHookConfig': {'S3OutputPath': 's3://sagemaker/models/debug'},
        'StoppingCondition': {'MaxRuntimeInSeconds': 86400},
        'ResourceConfig': {
            'InstanceCount': 1,
            'InstanceType': 'ml.p2.xlarge',
            'VolumeSizeInGB': 30,
        },
        'RoleArn': EXECUTION_ROLE,
        'HyperParameters': merged_hyperparameters,
        'TrainingJobName': 'tensorflow-job',
    }
    assert serialized == {
        'Type': 'Task',
        'Parameters': expected_parameters,
        'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
        'End': True,
    }
def test_training_step(pca_estimator_fixture, record_set_fixture, sfn_client,
                       sfn_role_arn):
    """Integration test: run a one-step workflow that trains the PCA
    estimator on the record set and check the training job completes.

    The created state machine is deleted whether or not the execution and
    the status assertion succeed.
    """
    # Build workflow definition
    job_name = generate_job_name()
    training_step = TrainingStep('create_training_job_step',
                                 estimator=pca_estimator_fixture,
                                 job_name=job_name,
                                 data=record_set_fixture,
                                 mini_batch_size=200)
    training_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
    workflow_graph = Chain([training_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base(
                "integ-test-training-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn)

        try:
            # Execute workflow
            execution = workflow.execute()
            execution_output = execution.get_output(wait=True)

            # Check workflow output
            assert execution_output.get("TrainingJobStatus") == "Completed"
        finally:
            # Cleanup in `finally` so a failed execution or assertion does
            # not leave the state machine behind in the account.
            state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
def test_training_step_creation_with_debug_hook(pca_estimator_with_debug_hook):
    """An estimator carrying debugger settings adds DebugHookConfig and
    DebugRuleConfigurations to the serialised training task."""
    step = TrainingStep(
        'Training',
        estimator=pca_estimator_with_debug_hook,
        job_name='TrainingJob',
    )
    result = step.to_dict()

    debug_hook_config = {
        'S3OutputPath': 's3://sagemaker/output/debug',
        'HookParameters': {'save_interval': '1'},
        'CollectionConfigurations': [
            {'CollectionName': 'hyperparameters'},
            {'CollectionName': 'metrics'},
        ],
    }
    confusion_rule = {
        'RuleConfigurationName': 'Confusion',
        'RuleEvaluatorImage': '503895931360.dkr.ecr.us-east-1.amazonaws.com/sagemaker-debugger-rules:latest',
        'RuleParameters': {
            'rule_to_invoke': 'Confusion',
            'category_no': '15',
            'min_diag': '0.7',
            'max_off_diag': '0.3',
            'start_step': '17',
            'end_step': '19',
        },
    }
    expected_parameters = {
        'AlgorithmSpecification': {
            'TrainingImage': PCA_IMAGE,
            'TrainingInputMode': 'File',
        },
        'OutputDataConfig': {'S3OutputPath': 's3://sagemaker/models'},
        'StoppingCondition': {'MaxRuntimeInSeconds': 86400},
        'ResourceConfig': {
            'InstanceCount': 1,
            'InstanceType': 'ml.c4.xlarge',
            'VolumeSizeInGB': 30,
        },
        'RoleArn': EXECUTION_ROLE,
        'HyperParameters': {
            'feature_dim': '50000',
            'num_components': '10',
            'subtract_mean': 'True',
            'algorithm_mode': 'randomized',
            'mini_batch_size': '200',
        },
        'DebugHookConfig': debug_hook_config,
        'DebugRuleConfigurations': [confusion_rule],
        'TrainingJobName': 'TrainingJob',
    }
    assert result == {
        'Type': 'Task',
        'Parameters': expected_parameters,
        'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
        'End': True,
    }
def test_get_expected_model(pca_estimator):
    """The expected model of a PCA TrainingStep renders into a createModel
    task whose image comes from the model and whose model data is read from
    the training step's output at runtime."""
    training_step = TrainingStep('Training', estimator=pca_estimator, job_name='TrainingJob')
    expected_model = training_step.get_expected_model()
    model_step = ModelStep('Create model', model=expected_model, model_name='pca-model')
    assert model_step.to_dict() == {
        'Type': 'Task',
        'Parameters': {
            'ExecutionRoleArn': EXECUTION_ROLE,
            'ModelName': 'pca-model',
            'PrimaryContainer': {
                'Environment': {},
                # `image_uri` (SageMaker SDK v2 name; formerly `image`),
                # consistent with the framework-estimator expected-model test
                # in this module.
                'Image': expected_model.image_uri,
                'ModelDataUrl.$': "$['ModelArtifacts']['S3ModelArtifacts']"
            }
        },
        'Resource': 'arn:aws:states:::sagemaker:createModel',
        'End': True
    }
def test_training_step_creation_with_model(pca_estimator):
    """Chain a ModelStep after a TrainingStep.

    The training task must point at the model step via ``Next``. The model
    step must take its name from the training job's output through a
    JSONPath reference.
    """
    train = TrainingStep('Training', estimator=pca_estimator, job_name='TrainingJob')
    name_from_output = train.output()['TrainingJobName']
    save_model = ModelStep(
        'Training - Save Model',
        train.get_expected_model(model_name=name_from_output),
    )
    train.next(save_model)

    training_parameters = {
        'AlgorithmSpecification': {
            'TrainingImage': PCA_IMAGE,
            'TrainingInputMode': 'File',
        },
        'OutputDataConfig': {'S3OutputPath': 's3://sagemaker/models'},
        'StoppingCondition': {'MaxRuntimeInSeconds': 86400},
        'ResourceConfig': {
            'InstanceCount': 1,
            'InstanceType': 'ml.c4.xlarge',
            'VolumeSizeInGB': 30,
        },
        'RoleArn': EXECUTION_ROLE,
        'HyperParameters': {
            'feature_dim': '50000',
            'num_components': '10',
            'subtract_mean': 'True',
            'algorithm_mode': 'randomized',
            'mini_batch_size': '200',
        },
        'TrainingJobName': 'TrainingJob',
    }
    # Chained with next(), so the training task has 'Next' instead of 'End'.
    assert train.to_dict() == {
        'Type': 'Task',
        'Parameters': training_parameters,
        'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
        'Next': 'Training - Save Model',
    }

    model_parameters = {
        'ExecutionRoleArn': EXECUTION_ROLE,
        'ModelName.$': "$['TrainingJobName']",
        'PrimaryContainer': {
            'Environment': {},
            'Image': PCA_IMAGE,
            'ModelDataUrl.$': "$['ModelArtifacts']['S3ModelArtifacts']",
        },
    }
    assert save_model.to_dict() == {
        'Type': 'Task',
        'Resource': 'arn:aws:states:::sagemaker:createModel',
        'Parameters': model_parameters,
        'End': True,
    }
def test_training_step_creation_with_placeholders(pca_estimator):
    """Execution-input and step-input placeholders passed to a TrainingStep
    serialise as ``.$`` JSONPath parameters.

    Execution-input placeholders render as ``$$.Execution.Input[...]`` and
    step-input placeholders render as ``$[...]``. The literal experiment
    config and tags pass through unchanged.
    """
    run_input = ExecutionInput(schema={
        'Data': str,
        'OutputPath': str,
    })
    state_input = StepInput(schema={
        'JobName': str,
    })

    step = TrainingStep(
        'Training',
        estimator=pca_estimator,
        job_name=state_input['JobName'],
        data=run_input['Data'],
        output_data_config_path=run_input['OutputPath'],
        experiment_config={
            'ExperimentName': 'pca_experiment',
            'TrialName': 'pca_trial',
            'TrialComponentDisplayName': 'Training',
        },
        tags=DEFAULT_TAGS,
    )
    actual = step.to_dict()

    training_channel = {
        'ChannelName': 'training',
        'DataSource': {
            'S3DataSource': {
                'S3DataDistributionType': 'FullyReplicated',
                'S3DataType': 'S3Prefix',
                'S3Uri.$': "$$.Execution.Input['Data']",
            },
        },
    }
    expected_parameters = {
        'AlgorithmSpecification': {
            'TrainingImage': PCA_IMAGE,
            'TrainingInputMode': 'File',
        },
        'OutputDataConfig': {
            'S3OutputPath.$': "$$.Execution.Input['OutputPath']",
        },
        'StoppingCondition': {'MaxRuntimeInSeconds': 86400},
        'ResourceConfig': {
            'InstanceCount': 1,
            'InstanceType': 'ml.c4.xlarge',
            'VolumeSizeInGB': 30,
        },
        'RoleArn': EXECUTION_ROLE,
        'HyperParameters': {
            'feature_dim': '50000',
            'num_components': '10',
            'subtract_mean': 'True',
            'algorithm_mode': 'randomized',
            'mini_batch_size': '200',
        },
        'InputDataConfig': [training_channel],
        'ExperimentConfig': {
            'ExperimentName': 'pca_experiment',
            'TrialName': 'pca_trial',
            'TrialComponentDisplayName': 'Training',
        },
        'TrainingJobName.$': "$['JobName']",
        'Tags': DEFAULT_TAGS_LIST,
    }
    assert actual == {
        'Type': 'Task',
        'Parameters': expected_parameters,
        'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
        'End': True,
    }
def test_training_step_with_placeholders(pca_estimator_fixture,
                                         record_set_fixture, sfn_client,
                                         sfn_role_arn):
    """Integration test: drive a training step from execution input.

    Hyperparameters, resource config and stopping condition are supplied
    through execution-input placeholders in ``parameters`` and filled in
    when the workflow is executed. The training job must complete.

    The created state machine is deleted whether or not the execution and
    the status assertion succeed.
    """
    execution_input = ExecutionInput(
        schema={
            'JobName': str,
            # NOTE(review): declared as str, but the execution below supplies
            # a dict for this key; confirm whether the schema should be dict.
            'HyperParameters': str,
            'InstanceCount': int,
            'InstanceType': str,
            'MaxRun': int
        })

    parameters = {
        'HyperParameters': execution_input['HyperParameters'],
        'ResourceConfig': {
            'InstanceCount': execution_input['InstanceCount'],
            'InstanceType': execution_input['InstanceType']
        },
        'StoppingCondition': {
            'MaxRuntimeInSeconds': execution_input['MaxRun']
        }
    }

    training_step = TrainingStep('create_training_job_step',
                                 estimator=pca_estimator_fixture,
                                 job_name=execution_input['JobName'],
                                 data=record_set_fixture,
                                 mini_batch_size=200,
                                 parameters=parameters)
    training_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
    workflow_graph = Chain([training_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base(
                "integ-test-training-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn)

        inputs = {
            'JobName': generate_job_name(),
            'HyperParameters': {
                "num_components": "48",
                "feature_dim": "784",
                "mini_batch_size": "250"
            },
            'InstanceCount': INSTANCE_COUNT,
            'InstanceType': INSTANCE_TYPE,
            'MaxRun': 100000
        }

        try:
            # Execute workflow
            execution = workflow.execute(inputs=inputs)
            execution_output = execution.get_output(wait=True)

            # Check workflow output
            assert execution_output.get("TrainingJobStatus") == "Completed"
        finally:
            # Cleanup in `finally` so a failed execution or assertion does
            # not leave the state machine behind in the account.
            state_machine_delete_wait(sfn_client, workflow.state_machine_arn)