def test_training_step_creation_with_framework(tensorflow_estimator):
    """A TensorFlow framework estimator serializes into a createTrainingJob.sync Task.

    The step is given one S3 input channel, an explicit job name and tags; the
    whole state definition produced by ``to_dict()`` is compared at once.
    """
    step = TrainingStep(
        'Training',
        estimator=tensorflow_estimator,
        data={'train': 's3://sagemaker/train'},
        job_name='tensorflow-job',
        mini_batch_size=1024,
        tags=DEFAULT_TAGS,
    )

    train_channel = {
        'ChannelName': 'train',
        'DataSource': {
            'S3DataSource': {
                'S3Uri': 's3://sagemaker/train',
                'S3DataType': 'S3Prefix',
                'S3DataDistributionType': 'FullyReplicated',
            }
        },
    }
    # Note: mini_batch_size=1024 is not expected to appear among these values.
    framework_hyperparameters = {
        'model_dir': '"s3://sagemaker/models/tensorflow-job/model"',
        'sagemaker_container_log_level': '20',
        'sagemaker_enable_cloudwatch_metrics': 'false',
        'sagemaker_job_name': '"tensorflow-job"',
        'sagemaker_program': '"tf_train.py"',
        'sagemaker_region': '"us-east-1"',
        'sagemaker_submit_directory': '"s3://sagemaker/source"',
    }
    expected_parameters = {
        'TrainingJobName': 'tensorflow-job',
        'RoleArn': EXECUTION_ROLE,
        'AlgorithmSpecification': {
            'TrainingInputMode': 'File',
            'TrainingImage': TENSORFLOW_IMAGE,
        },
        'InputDataConfig': [train_channel],
        'OutputDataConfig': {'S3OutputPath': 's3://sagemaker/models'},
        'DebugHookConfig': {'S3OutputPath': 's3://sagemaker/models/debug'},
        'StoppingCondition': {'MaxRuntimeInSeconds': 86400},
        'ResourceConfig': {
            'VolumeSizeInGB': 30,
            'InstanceType': 'ml.p2.xlarge',
            'InstanceCount': 1,
        },
        'HyperParameters': framework_hyperparameters,
        'Tags': DEFAULT_TAGS_LIST,
    }

    assert step.to_dict() == {
        'Type': 'Task',
        'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
        'Parameters': expected_parameters,
        'End': True,
    }
def test_training_step_merges_hyperparameters_from_constructor_and_estimator(
        tensorflow_estimator):
    """Hyperparameters passed to TrainingStep are combined with the estimator's.

    The expected ``HyperParameters`` are the estimator-derived values plus the
    ``{'key': 'value'}`` entry supplied through the step constructor.

    NOTE(review): test_training_step_creation_with_framework uses the same
    ``tensorflow_estimator`` fixture but expects 'model_dir' and
    'sagemaker_enable_cloudwatch_metrics' instead of 'checkpoint_path',
    'evaluation_steps' and 'training_steps'. Confirm against the fixture which
    set is current.
    """
    step = TrainingStep(
        'Training',
        estimator=tensorflow_estimator,
        data={'train': 's3://sagemaker/train'},
        job_name='tensorflow-job',
        mini_batch_size=1024,
        hyperparameters={'key': 'value'},
    )

    from_estimator = {
        'checkpoint_path':
        '"s3://sagemaker/models/sagemaker-tensorflow/checkpoints"',
        'evaluation_steps': '100',
        'training_steps': '1000',
        'sagemaker_container_log_level': '20',
        'sagemaker_job_name': '"tensorflow-job"',
        'sagemaker_program': '"tf_train.py"',
        'sagemaker_region': '"us-east-1"',
        'sagemaker_submit_directory': '"s3://sagemaker/source"',
    }
    from_constructor = {'key': 'value'}
    merged_hyperparameters = {**from_estimator, **from_constructor}

    train_channel = {
        'ChannelName': 'train',
        'DataSource': {
            'S3DataSource': {
                'S3Uri': 's3://sagemaker/train',
                'S3DataType': 'S3Prefix',
                'S3DataDistributionType': 'FullyReplicated',
            }
        },
    }
    expected_parameters = {
        'TrainingJobName': 'tensorflow-job',
        'RoleArn': EXECUTION_ROLE,
        'AlgorithmSpecification': {
            'TrainingInputMode': 'File',
            'TrainingImage': TENSORFLOW_IMAGE,
        },
        'InputDataConfig': [train_channel],
        'OutputDataConfig': {'S3OutputPath': 's3://sagemaker/models'},
        'DebugHookConfig': {'S3OutputPath': 's3://sagemaker/models/debug'},
        'StoppingCondition': {'MaxRuntimeInSeconds': 86400},
        'ResourceConfig': {
            'VolumeSizeInGB': 30,
            'InstanceType': 'ml.p2.xlarge',
            'InstanceCount': 1,
        },
        'HyperParameters': merged_hyperparameters,
    }

    assert step.to_dict() == {
        'Type': 'Task',
        'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
        'Parameters': expected_parameters,
        'End': True,
    }
def test_training_step_creation(pca_estimator):
    """A built-in algorithm (PCA) estimator serializes into a training Task state.

    No ``data`` argument is given, and the expected Parameters contain no
    InputDataConfig entry.
    """
    step = TrainingStep('Training', estimator=pca_estimator, job_name='TrainingJob')

    pca_hyperparameters = {
        'algorithm_mode': 'randomized',
        'feature_dim': '50000',
        'mini_batch_size': '200',
        'num_components': '10',
        'subtract_mean': 'True',
    }
    expected_parameters = {
        'TrainingJobName': 'TrainingJob',
        'RoleArn': EXECUTION_ROLE,
        'AlgorithmSpecification': {
            'TrainingInputMode': 'File',
            'TrainingImage': PCA_IMAGE,
        },
        'OutputDataConfig': {'S3OutputPath': 's3://sagemaker/models'},
        'StoppingCondition': {'MaxRuntimeInSeconds': 86400},
        'ResourceConfig': {
            'VolumeSizeInGB': 30,
            'InstanceType': 'ml.c4.xlarge',
            'InstanceCount': 1,
        },
        'HyperParameters': pca_hyperparameters,
    }

    assert step.to_dict() == {
        'Type': 'Task',
        'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
        'Parameters': expected_parameters,
        'End': True,
    }
def test_training_step_creation_with_debug_hook(pca_estimator_with_debug_hook):
    """Debugger hook and rule settings from the estimator appear in the Task Parameters.

    Besides the usual PCA training parameters, the serialized state is expected
    to carry a DebugHookConfig and a single 'Confusion' DebugRuleConfiguration.
    """
    step = TrainingStep(
        'Training',
        estimator=pca_estimator_with_debug_hook,
        job_name='TrainingJob',
    )

    debug_hook = {
        'S3OutputPath': 's3://sagemaker/output/debug',
        'HookParameters': {'save_interval': '1'},
        # List order matters for equality: hyperparameters first, then metrics.
        'CollectionConfigurations': [
            {'CollectionName': 'hyperparameters'},
            {'CollectionName': 'metrics'},
        ],
    }
    confusion_rule = {
        'RuleConfigurationName': 'Confusion',
        'RuleEvaluatorImage': '503895931360.dkr.ecr.us-east-1.amazonaws.com/sagemaker-debugger-rules:latest',
        'RuleParameters': {
            'rule_to_invoke': 'Confusion',
            'category_no': '15',
            'min_diag': '0.7',
            'max_off_diag': '0.3',
            'start_step': '17',
            'end_step': '19',
        },
    }
    pca_hyperparameters = {
        'algorithm_mode': 'randomized',
        'feature_dim': '50000',
        'mini_batch_size': '200',
        'num_components': '10',
        'subtract_mean': 'True',
    }
    expected_parameters = {
        'TrainingJobName': 'TrainingJob',
        'RoleArn': EXECUTION_ROLE,
        'AlgorithmSpecification': {
            'TrainingInputMode': 'File',
            'TrainingImage': PCA_IMAGE,
        },
        'OutputDataConfig': {'S3OutputPath': 's3://sagemaker/models'},
        'StoppingCondition': {'MaxRuntimeInSeconds': 86400},
        'ResourceConfig': {
            'VolumeSizeInGB': 30,
            'InstanceType': 'ml.c4.xlarge',
            'InstanceCount': 1,
        },
        'HyperParameters': pca_hyperparameters,
        'DebugHookConfig': debug_hook,
        'DebugRuleConfigurations': [confusion_rule],
    }

    assert step.to_dict() == {
        'Type': 'Task',
        'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
        'Parameters': expected_parameters,
        'End': True,
    }
def test_training_step_creation_with_model(pca_estimator):
    """A ModelStep chained after a TrainingStep is wired to the training output.

    After ``next()``, the training state is expected to carry ``Next`` instead of
    ``End``. The model state takes its name and model data URL from JSONPath
    references ('.$' keys) into the training step's output.
    """
    training_step = TrainingStep('Training', estimator=pca_estimator, job_name='TrainingJob')
    job_name_ref = training_step.output()['TrainingJobName']
    expected_model = training_step.get_expected_model(model_name=job_name_ref)
    model_step = ModelStep('Training - Save Model', expected_model)
    training_step.next(model_step)

    pca_hyperparameters = {
        'algorithm_mode': 'randomized',
        'feature_dim': '50000',
        'mini_batch_size': '200',
        'num_components': '10',
        'subtract_mean': 'True',
    }
    training_parameters = {
        'TrainingJobName': 'TrainingJob',
        'RoleArn': EXECUTION_ROLE,
        'AlgorithmSpecification': {
            'TrainingInputMode': 'File',
            'TrainingImage': PCA_IMAGE,
        },
        'OutputDataConfig': {'S3OutputPath': 's3://sagemaker/models'},
        'StoppingCondition': {'MaxRuntimeInSeconds': 86400},
        'ResourceConfig': {
            'VolumeSizeInGB': 30,
            'InstanceType': 'ml.c4.xlarge',
            'InstanceCount': 1,
        },
        'HyperParameters': pca_hyperparameters,
    }
    assert training_step.to_dict() == {
        'Type': 'Task',
        'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
        'Parameters': training_parameters,
        'Next': 'Training - Save Model',
    }

    primary_container = {
        'Image': PCA_IMAGE,
        'Environment': {},
        'ModelDataUrl.$': "$['ModelArtifacts']['S3ModelArtifacts']",
    }
    assert model_step.to_dict() == {
        'Type': 'Task',
        'Resource': 'arn:aws:states:::sagemaker:createModel',
        'Parameters': {
            'ModelName.$': "$['TrainingJobName']",
            'ExecutionRoleArn': EXECUTION_ROLE,
            'PrimaryContainer': primary_container,
        },
        'End': True,
    }
def test_training_step_creation_with_placeholders(pca_estimator):
    """Placeholder arguments are serialized as JSONPath ('.$') parameters.

    ExecutionInput placeholders are expected as ``$$.Execution.Input[...]``
    paths and the StepInput placeholder as a ``$[...]`` path. The literal
    experiment config and tags are expected to pass through unchanged.
    """
    execution_input = ExecutionInput(schema={'Data': str, 'OutputPath': str})
    step_input = StepInput(schema={'JobName': str})

    step = TrainingStep(
        'Training',
        estimator=pca_estimator,
        job_name=step_input['JobName'],
        data=execution_input['Data'],
        output_data_config_path=execution_input['OutputPath'],
        experiment_config={
            'ExperimentName': 'pca_experiment',
            'TrialName': 'pca_trial',
            'TrialComponentDisplayName': 'Training',
        },
        tags=DEFAULT_TAGS,
    )

    training_channel = {
        'ChannelName': 'training',
        'DataSource': {
            'S3DataSource': {
                'S3Uri.$': "$$.Execution.Input['Data']",
                'S3DataType': 'S3Prefix',
                'S3DataDistributionType': 'FullyReplicated',
            }
        },
    }
    pca_hyperparameters = {
        'algorithm_mode': 'randomized',
        'feature_dim': '50000',
        'mini_batch_size': '200',
        'num_components': '10',
        'subtract_mean': 'True',
    }
    expected_parameters = {
        'TrainingJobName.$': "$['JobName']",
        'RoleArn': EXECUTION_ROLE,
        'AlgorithmSpecification': {
            'TrainingInputMode': 'File',
            'TrainingImage': PCA_IMAGE,
        },
        'InputDataConfig': [training_channel],
        'OutputDataConfig': {
            'S3OutputPath.$': "$$.Execution.Input['OutputPath']",
        },
        'StoppingCondition': {'MaxRuntimeInSeconds': 86400},
        'ResourceConfig': {
            'VolumeSizeInGB': 30,
            'InstanceType': 'ml.c4.xlarge',
            'InstanceCount': 1,
        },
        'HyperParameters': pca_hyperparameters,
        'ExperimentConfig': {
            'ExperimentName': 'pca_experiment',
            'TrialName': 'pca_trial',
            'TrialComponentDisplayName': 'Training',
        },
        'Tags': DEFAULT_TAGS_LIST,
    }

    assert step.to_dict() == {
        'Type': 'Task',
        'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
        'Parameters': expected_parameters,
        'End': True,
    }