def test_training_step(pca_estimator_fixture, record_set_fixture, sfn_client,
                       sfn_role_arn):
    """Run a single-step training workflow end to end and check it completes.

    Builds a ``TrainingStep`` from the PCA estimator fixture with a generated
    job name, wraps it in a one-step ``Chain``, creates the state machine,
    executes it, and asserts the reported ``TrainingJobStatus`` is
    ``"Completed"``. The state machine is deleted afterwards whether the
    execution/assertion succeeds or fails.
    """
    # Build workflow definition
    job_name = generate_job_name()
    training_step = TrainingStep('create_training_job_step',
                                 estimator=pca_estimator_fixture,
                                 job_name=job_name,
                                 data=record_set_fixture,
                                 mini_batch_size=200)
    training_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
    workflow_graph = Chain([training_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base(
                "integ-test-training-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn)

        try:
            # Execute workflow and wait for its output
            execution = workflow.execute()
            execution_output = execution.get_output(wait=True)

            # Check workflow output
            assert execution_output.get("TrainingJobStatus") == "Completed"
        finally:
            # Cleanup in `finally` so a failed execution or assertion does
            # not leave the state machine behind.
            state_machine_delete_wait(sfn_client, workflow.state_machine_arn)


def test_training_step_with_placeholders(pca_estimator_fixture,
                                         record_set_fixture, sfn_client,
                                         sfn_role_arn):
    """Run a training workflow whose job settings come from execution input.

    The job name, hyperparameters, instance count/type and max runtime are
    declared as ``ExecutionInput`` placeholders and only bound to concrete
    values when the workflow is executed. The test asserts the reported
    ``TrainingJobStatus`` is ``"Completed"`` and deletes the state machine
    afterwards whether the execution/assertion succeeds or fails.
    """
    # NOTE(review): 'HyperParameters' is declared as str here, but the value
    # supplied at execution time below is a dict -- confirm whether the
    # schema type should be dict.
    execution_input = ExecutionInput(
        schema={
            'JobName': str,
            'HyperParameters': str,
            'InstanceCount': int,
            'InstanceType': str,
            'MaxRun': int
        })

    # Training-job fields that are filled from the execution input at runtime.
    parameters = {
        'HyperParameters': execution_input['HyperParameters'],
        'ResourceConfig': {
            'InstanceCount': execution_input['InstanceCount'],
            'InstanceType': execution_input['InstanceType']
        },
        'StoppingCondition': {
            'MaxRuntimeInSeconds': execution_input['MaxRun']
        }
    }

    training_step = TrainingStep('create_training_job_step',
                                 estimator=pca_estimator_fixture,
                                 job_name=execution_input['JobName'],
                                 data=record_set_fixture,
                                 mini_batch_size=200,
                                 parameters=parameters)
    training_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
    workflow_graph = Chain([training_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base(
                "integ-test-training-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn)

        # Concrete values bound to the placeholders above.
        inputs = {
            'JobName': generate_job_name(),
            'HyperParameters': {
                "num_components": "48",
                "feature_dim": "784",
                "mini_batch_size": "250"
            },
            'InstanceCount': INSTANCE_COUNT,
            'InstanceType': INSTANCE_TYPE,
            'MaxRun': 100000
        }

        try:
            # Execute workflow and wait for its output
            execution = workflow.execute(inputs=inputs)
            execution_output = execution.get_output(wait=True)

            # Check workflow output
            assert execution_output.get("TrainingJobStatus") == "Completed"
        finally:
            # Cleanup in `finally` so a failed execution or assertion does
            # not leave the state machine behind.
            state_machine_delete_wait(sfn_client, workflow.state_machine_arn)