import botocore.exceptions
import pytest
from botocore.stub import Stubber
from sagemaker.estimator import Framework
from unittest.mock import PropertyMock

import sagemaker

# NOTE: project-local imports are assumed here; names such as
# SageMakerExecutionEngine, SageMakerExecutorConfig, EstimatorFactory,
# FetcherBenchmarkEvent, BenchmarkDescriptor, BenchmarkJob, DescriptorError,
# ExecutionEngineException, NoResourcesFoundException, create_estimator,
# sm_job_description, and the constants (ACTION_ID, CLIENT_ID, DATASET_ID,
# DATASET_S3_URI, CREATED_JOB_ID, TMP_DIR, SCRIPTS, MOCK_ERROR_RESPONSE)
# come from the surrounding test package.


def test_run_fail_create_estimator(
    mock_estimator_factory: EstimatorFactory,
    sm_execution_engine_to_test: SageMakerExecutionEngine,
    fetcher_event: FetcherBenchmarkEvent,
):
    mock_estimator_factory.side_effect = DescriptorError("Missing framework")

    with pytest.raises(ExecutionEngineException):
        sm_execution_engine_to_test.run(fetcher_event)
def test_run_fail_from_sagemaker(
    mock_estimator: Framework,
    sm_execution_engine_to_test: SageMakerExecutionEngine,
    fetcher_event: FetcherBenchmarkEvent,
):
    mock_estimator.fit.side_effect = botocore.exceptions.ClientError(MOCK_ERROR_RESPONSE, "start job")

    with pytest.raises(ExecutionEngineException) as err:
        sm_execution_engine_to_test.run(fetcher_event)

    assert str(err.value) == "Benchmark creation failed. SageMaker returned error: Something is wrong"
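# A hedged sketch of the wrapping that test_run_fail_from_sagemaker implies.
# The message format comes from the assertion above; the helper name and
# structure are assumptions, not the engine's actual code.
def _wrap_sagemaker_error_sketch(err: botocore.exceptions.ClientError) -> ExecutionEngineException:
    # Hypothetical: surface the service's error message inside the engine's
    # own exception type so callers see one consistent failure mode.
    message = err.response["Error"]["Message"]
    return ExecutionEngineException(f"Benchmark creation failed. SageMaker returned error: {message}")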
# NOTE: the signature implies parametrization; the value pair below is an
# illustrative placeholder, not the project's actual test data.
@pytest.mark.parametrize(
    ["file_size", "expected_train_volume_size"],
    [(100 * 1024 ** 3, 101)],  # hypothetical: 100 GiB of data -> 101 GB volume
)
def test_volume_size(
    file_size: int,
    expected_train_volume_size: int,
    prop_train_volume_size: PropertyMock,
    sm_execution_engine_to_test: SageMakerExecutionEngine,
    fetcher_event: FetcherBenchmarkEvent,
):
    fetcher_event.payload.datasets[0].size_info.max_size = file_size
    fetcher_event.payload.datasets[0].size_info.total_size = file_size

    sm_execution_engine_to_test.run(fetcher_event)

    prop_train_volume_size.assert_called_with(expected_train_volume_size)
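# The expected volume size presumably derives from the dataset size in bytes.
# A minimal sketch of one plausible conversion; the helper name, rounding
# policy, and headroom are assumptions, not the engine's actual logic.
import math


def estimated_train_volume_size_gb(total_size_bytes: int, headroom_gb: int = 1) -> int:
    # Hypothetical: round the dataset size up to whole gigabytes, then add a
    # fixed headroom so the training volume never runs completely full.
    return math.ceil(total_size_bytes / 1024 ** 3) + headroom_gb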
def test_no_data(
    sagemaker_config: SageMakerExecutorConfig,
    sm_execution_engine_to_test: SageMakerExecutionEngine,
    fetcher_event: FetcherBenchmarkEvent,
    mock_estimator: Framework,
):
    fetcher_event.payload.datasets = []

    sm_execution_engine_to_test.run(fetcher_event)

    mock_estimator.fit.assert_called_with({"src0": sagemaker_config.s3_nodata}, job_name=ACTION_ID, wait=False, logs=False)
def test_cancel_raises(sm_execution_engine_to_test: SageMakerExecutionEngine):
    stub = Stubber(sm_execution_engine_to_test.sagemaker_client)
    stub.add_response(method="describe_training_job", service_response=sm_job_description(status="InProgress"))
    stub.add_client_error(
        "stop_training_job",
        service_error_code="SomeRandomError",
        service_message="Some random error has occurred",
        expected_params={"TrainingJobName": ACTION_ID},
    )
    stub.activate()

    with pytest.raises(ExecutionEngineException):
        sm_execution_engine_to_test.cancel(CLIENT_ID, ACTION_ID)
def test_cancel_raises_not_found(sm_execution_engine_to_test: SageMakerExecutionEngine):
    stub = Stubber(sm_execution_engine_to_test.sagemaker_client)
    stub.add_response(method="describe_training_job", service_response=sm_job_description(status="InProgress"))
    stub.add_client_error(
        "stop_training_job",
        service_error_code="ValidationException",
        service_message="Requested resource not found.",
        expected_params={"TrainingJobName": ACTION_ID},
    )
    stub.activate()

    with pytest.raises(NoResourcesFoundException):
        sm_execution_engine_to_test.cancel(CLIENT_ID, ACTION_ID)
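# The two cancel tests above pin down an error-translation branch. A hedged
# sketch of the handler they imply; the exception names come from the tests,
# but the helper name and branching structure are assumptions.
def _raise_for_stop_error_sketch(err: botocore.exceptions.ClientError) -> None:
    # Hypothetical: a ValidationException whose message says the resource was
    # not found becomes NoResourcesFoundException; everything else becomes a
    # generic ExecutionEngineException.
    error = err.response["Error"]
    if error["Code"] == "ValidationException" and "not found" in error["Message"]:
        raise NoResourcesFoundException(error["Message"]) from err
    raise ExecutionEngineException(error["Message"]) from err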
def test_cancel_not_called_on_non_in_progress_status(sm_execution_engine_to_test: SageMakerExecutionEngine):
    stub = Stubber(sm_execution_engine_to_test.sagemaker_client)
    stub.add_response(method="describe_training_job", service_response=sm_job_description(status="Failed"))
    # The stub raises if stop_training_job is ever called, so a clean run
    # proves cancel() skips the stop call for a job that is not InProgress.
    stub.add_client_error(
        "stop_training_job",
        service_error_code="ValidationException",
        service_message="The job status is Failed",
        expected_params={"TrainingJobName": ACTION_ID},
    )
    stub.activate()

    try:
        sm_execution_engine_to_test.cancel(CLIENT_ID, ACTION_ID)
    except Exception as err:
        pytest.fail(f"stop training got called: {err}")
@pytest.fixture
def sm_execution_engine_to_test(
    mock_session_factory,
    mock_estimator_factory,
    sagemaker_config: SageMakerExecutorConfig,
    sagemaker_client,
) -> SageMakerExecutionEngine:
    return SageMakerExecutionEngine(
        session_factory=mock_session_factory,
        estimator_factory=mock_estimator_factory,
        config=sagemaker_config,
        sagemaker_client=sagemaker_client,
    )
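# The collaborating fixtures are not shown in this section. A minimal sketch
# of how they might be built with unittest.mock and boto3; the fixture bodies
# are assumptions, and mock_session_factory / sagemaker_config would be
# defined analogously in the real conftest.
from unittest.mock import MagicMock, create_autospec

import boto3


@pytest.fixture
def mock_estimator_factory():
    # Hypothetical: a plain MagicMock whose return value the tests inspect.
    return MagicMock()


@pytest.fixture
def mock_estimator(mock_estimator_factory) -> Framework:
    # Hypothetical: the factory hands back one autospecced estimator per call,
    # so fit() can be asserted against without touching SageMaker.
    estimator = create_autospec(Framework, instance=True)
    mock_estimator_factory.return_value = estimator
    return estimator


@pytest.fixture
def sagemaker_client():
    # A real botocore client so that Stubber can intercept its calls; no AWS
    # credentials are needed because no request ever leaves the process.
    return boto3.client("sagemaker", region_name="us-east-1")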
def test_run(
    descriptor: BenchmarkDescriptor,
    sm_execution_engine_to_test: SageMakerExecutionEngine,
    fetcher_event: FetcherBenchmarkEvent,
    mock_create_source_dir,
    mock_estimator: Framework,
):
    job = sm_execution_engine_to_test.run(fetcher_event)

    mock_estimator.fit.assert_called_with({DATASET_ID: DATASET_S3_URI}, job_name=ACTION_ID, wait=False, logs=False)
    mock_create_source_dir.assert_called_with(descriptor, TMP_DIR, SCRIPTS)
    assert job == BenchmarkJob(CREATED_JOB_ID)
def test_merge_metrics(
    sm_execution_engine_to_test: SageMakerExecutionEngine,
    customparams_descriptor: BenchmarkDescriptor,
):
    metric_data = [
        {"MetricName": "iter", "Value": 51.900001525878906, "Timestamp": "1970-01-19T03:48:31.114000-08:00"},
        {"MetricName": "accuracy", "Value": 51.900001525878906, "Timestamp": "1970-01-19T03:48:31.114000-08:00"},
    ]
    expected_dimensions = [
        {"Name": "task_name", "Value": "exampleTask"},
        {"Name": "batch_size", "Value": "64"},
    ]
    metrics_with_dimensions = [
        {"MetricName": "iter", "Value": 51.900001525878906, "Dimensions": expected_dimensions},
        {"MetricName": "accuracy", "Value": 51.900001525878906, "Dimensions": expected_dimensions},
    ]

    assert sm_execution_engine_to_test.tag_dimensions(customparams_descriptor, metric_data) == metrics_with_dimensions
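# test_merge_metrics expects dimensions derived from the descriptor's custom
# parameters. A standalone sketch of that merge; the function body and the
# flat custom-params dict it takes are assumptions, not the engine's code.
def tag_dimensions_sketch(custom_params: dict, metric_data: list) -> list:
    # Hypothetical re-statement of the behavior the test pins down: attach
    # every custom parameter as a CloudWatch-style dimension and drop the
    # Timestamp field that SageMaker's metric query added.
    dimensions = [{"Name": key, "Value": str(value)} for key, value in custom_params.items()]
    return [
        {"MetricName": m["MetricName"], "Value": m["Value"], "Dimensions": dimensions}
        for m in metric_data
    ]


# With custom_params = {"task_name": "exampleTask", "batch_size": 64}, this
# sketch reproduces metrics_with_dimensions from the test above.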
def test_run_invalid_descriptor(
    sm_execution_engine_to_test: SageMakerExecutionEngine,
    fetcher_event: FetcherBenchmarkEvent,
    mock_create_source_dir,
):
    fetcher_event.payload.toml.contents["hardware"] = {}

    with pytest.raises(ExecutionEngineException):
        sm_execution_engine_to_test.run(fetcher_event)
def create_execution_engines(sagemaker_config: SageMakerExecutorConfig):
    sm_engine = SageMakerExecutionEngine(
        session_factory=sagemaker.Session,
        estimator_factory=create_estimator,
        config=sagemaker_config,
    )
    return {SageMakerExecutionEngine.ENGINE_ID: sm_engine}
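# A hedged usage sketch for the factory above; the bootstrap function and the
# way the config is obtained are assumptions about the surrounding service.
def bootstrap_sketch(sagemaker_config: SageMakerExecutorConfig) -> SageMakerExecutionEngine:
    # Hypothetical entry point: build the engine registry and look the
    # SageMaker engine up by its registered ID.
    engines = create_execution_engines(sagemaker_config)
    return engines[SageMakerExecutionEngine.ENGINE_ID]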