def test_training_step(pca_estimator_fixture, record_set_fixture, sfn_client, sfn_role_arn):
    """Run a one-step training workflow and check the job reports "Completed".

    The test wraps a single ``TrainingStep`` in a ``Chain``, creates the
    workflow, executes it, and asserts on the ``TrainingJobStatus`` field of
    the execution output. The state machine is deleted afterwards even when
    the execution or the assertion fails.

    Args:
        pca_estimator_fixture: Estimator handed to the ``TrainingStep``.
        record_set_fixture: Training data handed to the ``TrainingStep``.
        sfn_client: Client used to create and delete the state machine.
        sfn_role_arn: Role ARN the workflow is created with.
    """
    # Build workflow definition
    job_name = generate_job_name()
    training_step = TrainingStep(
        'create_training_job_step',
        estimator=pca_estimator_fixture,
        job_name=job_name,
        data=record_set_fixture,
        mini_batch_size=200,
    )
    training_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
    workflow_graph = Chain([training_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base("integ-test-training-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn,
        )

        # The try block starts only after the workflow exists, so the cleanup
        # in ``finally`` always has a state machine ARN to delete.
        try:
            # Execute workflow
            execution = workflow.execute()
            execution_output = execution.get_output(wait=True)

            # Check workflow output
            assert execution_output.get("TrainingJobStatus") == "Completed"
        finally:
            # Cleanup runs on failure too, so a failing test does not leak
            # the state machine.
            # NOTE(review): this now also runs when the timeout fires; confirm
            # state_machine_delete_wait terminates while an execution is
            # still in flight.
            state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
def test_training_step_with_placeholders(pca_estimator_fixture, record_set_fixture, sfn_client, sfn_role_arn):
    """Run a training workflow whose job settings come from execution input.

    The job name, hyperparameters, instance count, instance type and maximum
    runtime are declared as ``ExecutionInput`` placeholders and passed to the
    ``TrainingStep``; concrete values are supplied when the workflow is
    executed. The test asserts that the execution output reports a
    ``TrainingJobStatus`` of ``"Completed"``. The state machine is deleted
    afterwards even when the execution or the assertion fails.

    Args:
        pca_estimator_fixture: Estimator handed to the ``TrainingStep``.
        record_set_fixture: Training data handed to the ``TrainingStep``.
        sfn_client: Client used to create and delete the state machine.
        sfn_role_arn: Role ARN the workflow is created with.
    """
    # NOTE(review): 'HyperParameters' is declared as ``str`` here, but the
    # execution input below supplies a dict -- confirm whether the schema is
    # enforced and which type is intended.
    execution_input = ExecutionInput(
        schema={
            'JobName': str,
            'HyperParameters': str,
            'InstanceCount': int,
            'InstanceType': str,
            'MaxRun': int,
        }
    )

    # Training-job parameters resolved from the execution input at run time.
    parameters = {
        'HyperParameters': execution_input['HyperParameters'],
        'ResourceConfig': {
            'InstanceCount': execution_input['InstanceCount'],
            'InstanceType': execution_input['InstanceType'],
        },
        'StoppingCondition': {
            'MaxRuntimeInSeconds': execution_input['MaxRun'],
        },
    }

    # Build workflow definition
    training_step = TrainingStep(
        'create_training_job_step',
        estimator=pca_estimator_fixture,
        job_name=execution_input['JobName'],
        data=record_set_fixture,
        mini_batch_size=200,
        parameters=parameters,
    )
    training_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
    workflow_graph = Chain([training_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base("integ-test-training-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn,
        )

        # Concrete values substituted for the placeholders declared above.
        inputs = {
            'JobName': generate_job_name(),
            'HyperParameters': {
                "num_components": "48",
                "feature_dim": "784",
                "mini_batch_size": "250",
            },
            'InstanceCount': INSTANCE_COUNT,
            'InstanceType': INSTANCE_TYPE,
            'MaxRun': 100000,
        }

        # The try block starts only after the workflow exists, so the cleanup
        # in ``finally`` always has a state machine ARN to delete.
        try:
            # Execute workflow
            execution = workflow.execute(inputs=inputs)
            execution_output = execution.get_output(wait=True)

            # Check workflow output
            assert execution_output.get("TrainingJobStatus") == "Completed"
        finally:
            # Cleanup runs on failure too, so a failing test does not leak
            # the state machine.
            # NOTE(review): this now also runs when the timeout fires; confirm
            # state_machine_delete_wait terminates while an execution is
            # still in flight.
            state_machine_delete_wait(sfn_client, workflow.state_machine_arn)