def test_run_fail_create_estimator(
    mock_estimator_factory: EstimatorFactory,
    sm_execution_engine_to_test: SageMakerExecutionEngine,
    fetcher_event: FetcherBenchmarkEvent,
):
    """run() wraps a DescriptorError from the estimator factory in ExecutionEngineException."""
    # Make the estimator factory blow up the moment the engine asks for an estimator.
    factory_failure = DescriptorError("Missing framework")
    mock_estimator_factory.side_effect = factory_failure

    with pytest.raises(ExecutionEngineException):
        sm_execution_engine_to_test.run(fetcher_event)
def test_run_fail_from_sagemaker(
    mock_estimator: Framework,
    sm_execution_engine_to_test: SageMakerExecutionEngine,
    fetcher_event: FetcherBenchmarkEvent,
):
    """A botocore ClientError raised by estimator.fit() is reported as ExecutionEngineException.

    The resulting message must embed the SageMaker error text; the expected
    text presumably comes from MOCK_ERROR_RESPONSE (defined elsewhere).
    """
    sagemaker_failure = botocore.exceptions.ClientError(MOCK_ERROR_RESPONSE, "start job")
    mock_estimator.fit.side_effect = sagemaker_failure

    with pytest.raises(ExecutionEngineException) as err:
        sm_execution_engine_to_test.run(fetcher_event)

    reported = str(err.value)
    assert reported == "Benchmark creation failed. SageMaker returned error: Something is wrong"
def test_volume_size(
    file_size: int,
    expected_train_volume_size: int,
    prop_train_volume_size: PropertyMock,
    sm_execution_engine_to_test: SageMakerExecutionEngine,
    fetcher_event: FetcherBenchmarkEvent,
):
    """The training volume size set on the estimator follows the dataset size.

    Both the max and total size of the first dataset are set to ``file_size``;
    the engine is then expected to assign ``expected_train_volume_size`` to the
    estimator's train-volume-size property (observed via a PropertyMock).
    """
    size_info = fetcher_event.payload.datasets[0].size_info
    size_info.max_size = file_size
    size_info.total_size = file_size

    sm_execution_engine_to_test.run(fetcher_event)

    prop_train_volume_size.assert_called_with(expected_train_volume_size)
def test_no_data(
    sagemaker_config: SageMakerExecutorConfig,
    sm_execution_engine_to_test: SageMakerExecutionEngine,
    fetcher_event: FetcherBenchmarkEvent,
    mock_estimator: Framework,
):
    """With no datasets, fit() receives the configured "no data" S3 location as channel src0."""
    fetcher_event.payload.datasets = []

    sm_execution_engine_to_test.run(fetcher_event)

    expected_channels = {"src0": sagemaker_config.s3_nodata}
    mock_estimator.fit.assert_called_with(
        expected_channels, job_name=ACTION_ID, wait=False, logs=False
    )
def test_cancel_raises(sm_execution_engine_to_test: SageMakerExecutionEngine):
    """cancel() turns an arbitrary StopTrainingJob client error into ExecutionEngineException."""
    sagemaker_stub = Stubber(sm_execution_engine_to_test.sagemaker_client)

    # The job is still running, so cancel() is expected to go on to stop it...
    in_progress = sm_job_description(status="InProgress")
    sagemaker_stub.add_response(method="describe_training_job", service_response=in_progress)

    # ...and the stop request fails with an unrecognised error code.
    sagemaker_stub.add_client_error(
        "stop_training_job",
        service_error_code="SomeRandomError",
        service_message="Some random error has occured",
        expected_params={"TrainingJobName": ACTION_ID},
    )
    sagemaker_stub.activate()

    with pytest.raises(ExecutionEngineException):
        sm_execution_engine_to_test.cancel(CLIENT_ID, ACTION_ID)
def test_cancel_raises_not_found(sm_execution_engine_to_test: SageMakerExecutionEngine):
    """A "Requested resource not found." ValidationException maps to NoResourcesFoundException."""
    sagemaker_stub = Stubber(sm_execution_engine_to_test.sagemaker_client)

    in_progress = sm_job_description(status="InProgress")
    sagemaker_stub.add_response(method="describe_training_job", service_response=in_progress)

    sagemaker_stub.add_client_error(
        "stop_training_job",
        service_error_code="ValidationException",
        service_message="Requested resource not found.",
        expected_params={"TrainingJobName": ACTION_ID},
    )
    sagemaker_stub.activate()

    with pytest.raises(NoResourcesFoundException):
        sm_execution_engine_to_test.cancel(CLIENT_ID, ACTION_ID)
def test_cancel_not_called_on_non_in_progress_status(
        sm_execution_engine_to_test: SageMakerExecutionEngine):
    """cancel() must not call StopTrainingJob once the job has left InProgress.

    DescribeTrainingJob reports status "Failed". A StopTrainingJob call is
    stubbed to fail with a ValidationException, so if cancel() did issue it the
    call would raise and this test would fail.
    """
    stub = Stubber(sm_execution_engine_to_test.sagemaker_client)
    stub.add_response(method="describe_training_job",
                      service_response=sm_job_description(status="Failed"))
    # Trap: only reached if cancel() wrongly tries to stop a finished job.
    stub.add_client_error(
        "stop_training_job",
        service_error_code="ValidationException",
        service_message="The job status is Failed",
        expected_params={"TrainingJobName": ACTION_ID},
    )
    stub.activate()

    try:
        sm_execution_engine_to_test.cancel(CLIENT_ID, ACTION_ID)
    except Exception as err:
        # pytest.fail() raises on its own, so nothing after it would run; the
        # caught exception is included so an unrelated failure is not silently
        # misreported as a stop call.
        pytest.fail(f"stop training got called: {err!r}")
def sm_execution_engine_to_test(mock_session_factory, mock_estimator_factory,
                                sagemaker_config: SageMakerExecutorConfig,
                                sagemaker_client):
    """Build the SageMakerExecutionEngine exercised by the tests in this module.

    Tests receive it by parameter name, so this is presumably registered as a
    pytest fixture (the decorator is not visible here).
    """
    engine_kwargs = dict(
        session_factory=mock_session_factory,
        estimator_factory=mock_estimator_factory,
        config=sagemaker_config,
        sagemaker_client=sagemaker_client,
    )
    return SageMakerExecutionEngine(**engine_kwargs)
def test_run(
    descriptor: BenchmarkDescriptor,
    sm_execution_engine_to_test: SageMakerExecutionEngine,
    fetcher_event: FetcherBenchmarkEvent,
    mock_create_source_dir,
    mock_estimator: Framework,
):
    """Happy path: run() fits the estimator on the dataset and returns the created job."""
    created_job = sm_execution_engine_to_test.run(fetcher_event)

    # The dataset is passed as a single input channel; the job is started asynchronously.
    expected_channels = {DATASET_ID: DATASET_S3_URI}
    mock_estimator.fit.assert_called_with(
        expected_channels, job_name=ACTION_ID, wait=False, logs=False
    )
    mock_create_source_dir.assert_called_with(descriptor, TMP_DIR, SCRIPTS)

    assert created_job == BenchmarkJob(CREATED_JOB_ID)
def test_merge_metrics(
    sm_execution_engine_to_test: SageMakerExecutionEngine,
    customparams_descriptor: BenchmarkDescriptor,
):
    """tag_dimensions() attaches descriptor-derived dimensions to each metric record.

    Input records carry MetricName/Value/Timestamp. The expected output keeps
    MetricName/Value, has no Timestamp, and adds task_name and batch_size
    dimensions. Those dimension values presumably come from
    ``customparams_descriptor``.
    """
    names = ("iter", "accuracy")
    value = 51.900001525878906
    timestamp = "1970-01-19T03:48:31.114000-08:00"

    metric_data = [
        {"MetricName": name, "Value": value, "Timestamp": timestamp}
        for name in names
    ]

    # A fresh Dimensions list per record, mirroring independently built literals.
    metrics_with_dimensions = [
        {
            "MetricName": name,
            "Value": value,
            "Dimensions": [
                {"Name": "task_name", "Value": "exampleTask"},
                {"Name": "batch_size", "Value": "64"},
            ],
        }
        for name in names
    ]

    tagged = sm_execution_engine_to_test.tag_dimensions(customparams_descriptor, metric_data)
    assert tagged == metrics_with_dimensions
def test_run_invalid_descriptor(
        sm_execution_engine_to_test: SageMakerExecutionEngine,
        fetcher_event: FetcherBenchmarkEvent, mock_create_source_dir):
    """run() raises ExecutionEngineException when the descriptor's hardware section is empty."""
    # Emptying the "hardware" table is what makes the descriptor invalid here.
    toml_contents = fetcher_event.payload.toml.contents
    toml_contents["hardware"] = {}

    with pytest.raises(ExecutionEngineException):
        sm_execution_engine_to_test.run(fetcher_event)
# Example #12
# 0
def create_execution_engines(sagemaker_config: SageMakerExecutorConfig):
    """Return a mapping from engine id to a configured execution engine.

    Currently holds a single SageMakerExecutionEngine that uses
    ``sagemaker.Session`` for sessions and ``create_estimator`` for estimators.
    """
    engines = {}
    engines[SageMakerExecutionEngine.ENGINE_ID] = SageMakerExecutionEngine(
        session_factory=sagemaker.Session,
        estimator_factory=create_estimator,
        config=sagemaker_config,
    )
    return engines