def test_quality_check_step_invalid_config(
    check_job_config,
    model_package_group_name,
    supplied_baseline_statistics_uri,
    supplied_baseline_constraints_uri,
):
    # The base QualityCheckConfig is rejected; the step accepts only the
    # DataQualityCheckConfig or ModelQualityCheckConfig subclasses.
    quality_check_config = QualityCheckConfig(
        baseline_dataset="baseline_dataset_s3_url",
        dataset_format=DatasetFormat.csv(header=True),
        post_analytics_processor_script="s3://my_bucket/data_quality/postprocessor.py",
        output_s3_uri="",
    )

    with pytest.raises(Exception) as error:
        QualityCheckStep(
            name="QualityCheckStep",
            register_new_baseline=False,
            skip_check=False,
            quality_check_config=quality_check_config,
            check_job_config=check_job_config,
            model_package_group_name=model_package_group_name,
            supplied_baseline_statistics=supplied_baseline_statistics_uri,
            supplied_baseline_constraints=supplied_baseline_constraints_uri,
        )
    assert str(error.value) == (
        "The quality_check_config can only be object of "
        "DataQualityCheckConfig or ModelQualityCheckConfig"
    )
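

# The fixtures used above and below (check_job_config, model_package_group_name,
# supplied_baseline_statistics_uri, supplied_baseline_constraints_uri) are
# defined elsewhere in the test suite. The sketch below is hypothetical: it
# shows one plausible shape for them using standard SDK types (CheckJobConfig,
# ParameterString); the real fixture bodies, names, and defaults may differ.
from sagemaker.workflow.check_job_config import CheckJobConfig
from sagemaker.workflow.parameters import ParameterString


@pytest.fixture
def check_job_config(sagemaker_session):
    # Bundles the processing-job settings shared by all check steps.
    return CheckJobConfig(
        role="arn:aws:iam::012345678901:role/SageMakerRole",  # hypothetical ARN
        instance_count=1,
        instance_type="ml.m5.xlarge",
        volume_size_in_gb=60,
        sagemaker_session=sagemaker_session,
    )


@pytest.fixture
def model_package_group_name():
    # Passed to Pipeline(parameters=...) as well as to the step, so it must be
    # a pipeline Parameter rather than a plain string.
    return ParameterString(name="MyModelPackageGroup", default_value="")


@pytest.fixture
def supplied_baseline_statistics_uri():
    return ParameterString(name="SuppliedBaselineStatisticsUri", default_value="")


@pytest.fixture
def supplied_baseline_constraints_uri():
    return ParameterString(name="SuppliedBaselineConstraintsUri", default_value="")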


def test_model_quality_check_step(
    sagemaker_session,
    check_job_config,
    model_package_group_name,
    supplied_baseline_statistics_uri,
    supplied_baseline_constraints_uri,
):
    model_quality_check_config = ModelQualityCheckConfig(
        baseline_dataset="baseline_dataset_s3_url",
        dataset_format=DatasetFormat.csv(header=True),
        problem_type="BinaryClassification",
        probability_attribute=0,  # the integer should be converted to str by the SDK
        ground_truth_attribute=None,
        probability_threshold_attribute=0.5,  # the float should be converted to str by the SDK
        post_analytics_processor_script="s3://my_bucket/data_quality/postprocessor.py",
        output_s3_uri="",
    )
    model_quality_check_step = QualityCheckStep(
        name="ModelQualityCheckStep",
        register_new_baseline=False,
        skip_check=False,
        quality_check_config=model_quality_check_config,
        check_job_config=check_job_config,
        model_package_group_name=model_package_group_name,
        supplied_baseline_statistics=supplied_baseline_statistics_uri,
        supplied_baseline_constraints=supplied_baseline_constraints_uri,
    )
    pipeline = Pipeline(
        name="MyPipeline",
        parameters=[
            supplied_baseline_statistics_uri,
            supplied_baseline_constraints_uri,
            model_package_group_name,
        ],
        steps=[model_quality_check_step],
        sagemaker_session=sagemaker_session,
    )

    step_definition = _get_step_definition_for_test(pipeline)
    assert step_definition == _expected_model_quality_dsl


def test_quality_check_step_properties(
    check_job_config,
    model_package_group_name,
    supplied_baseline_statistics_uri,
    supplied_baseline_constraints_uri,
):
    model_quality_check_config = ModelQualityCheckConfig(
        baseline_dataset="baseline_dataset_s3_url",
        dataset_format=DatasetFormat.csv(header=True),
        problem_type="BinaryClassification",
        probability_attribute="0",
        probability_threshold_attribute="0.5",
        post_analytics_processor_script="s3://my_bucket/data_quality/postprocessor.py",
        output_s3_uri="",
    )
    model_quality_check_step = QualityCheckStep(
        name="ModelQualityCheckStep",
        register_new_baseline=False,
        skip_check=False,
        quality_check_config=model_quality_check_config,
        check_job_config=check_job_config,
        model_package_group_name=model_package_group_name,
        supplied_baseline_statistics=supplied_baseline_statistics_uri,
        supplied_baseline_constraints=supplied_baseline_constraints_uri,
    )

    assert model_quality_check_step.properties.CalculatedBaselineConstraints.expr == {
        "Get": "Steps.ModelQualityCheckStep.CalculatedBaselineConstraints"
    }
    assert model_quality_check_step.properties.CalculatedBaselineStatistics.expr == {
        "Get": "Steps.ModelQualityCheckStep.CalculatedBaselineStatistics"
    }
    assert model_quality_check_step.properties.BaselineUsedForDriftCheckStatistics.expr == {
        "Get": "Steps.ModelQualityCheckStep.BaselineUsedForDriftCheckStatistics"
    }
    assert model_quality_check_step.properties.BaselineUsedForDriftCheckConstraints.expr == {
        "Get": "Steps.ModelQualityCheckStep.BaselineUsedForDriftCheckConstraints"
    }


def test_data_quality_check_step(
    sagemaker_session,
    check_job_config,
    model_package_group_name,
    supplied_baseline_statistics_uri,
    supplied_baseline_constraints_uri,
):
    data_quality_check_config = DataQualityCheckConfig(
        baseline_dataset=ParameterString(name="BaselineDataset"),
        dataset_format=DatasetFormat.csv(header=True),
        output_s3_uri="s3://...",
        record_preprocessor_script="s3://my_bucket/data_quality/preprocessor.py",
        post_analytics_processor_script="s3://my_bucket/data_quality/postprocessor.py",
    )
    data_quality_check_step = QualityCheckStep(
        name="DataQualityCheckStep",
        skip_check=False,
        register_new_baseline=False,
        quality_check_config=data_quality_check_config,
        check_job_config=check_job_config,
        model_package_group_name=model_package_group_name,
        supplied_baseline_statistics=supplied_baseline_statistics_uri,
        supplied_baseline_constraints=supplied_baseline_constraints_uri,
        cache_config=CacheConfig(enable_caching=True, expire_after="PT1H"),
    )
    pipeline = Pipeline(
        name="MyPipeline",
        parameters=[
            supplied_baseline_statistics_uri,
            supplied_baseline_constraints_uri,
            model_package_group_name,
        ],
        steps=[data_quality_check_step],
        sagemaker_session=sagemaker_session,
    )

    step_definition = _get_step_definition_for_test(
        pipeline, ["baseline_dataset_input", "quality_check_output"]
    )
    assert step_definition == _expected_data_quality_dsl
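

# `_get_step_definition_for_test` and the `_expected_*_dsl` constants are
# defined elsewhere in this module. The sketch below is hypothetical and only
# illustrates the idea: pull the single step out of the rendered pipeline
# definition and blank the run-specific S3 URIs named by the caller, so the
# result can be compared to a fixed expected DSL. The real helper may
# normalize additional fields.
import json  # would normally sit at the top of the module


def _get_step_definition_for_test(pipeline, non_deterministic_names=None):
    step = json.loads(pipeline.definition())["Steps"][0]
    args = step["Arguments"]
    names = non_deterministic_names or []
    # Blank auto-generated S3 URIs on the named processing inputs/outputs.
    for processing_input in args.get("ProcessingInputs", []):
        if processing_input["InputName"] in names:
            processing_input["S3Input"]["S3Uri"] = ""
    for processing_output in args.get("ProcessingOutputConfig", {}).get("Outputs", []):
        if processing_output["OutputName"] in names:
            processing_output["S3Output"]["S3Uri"] = ""
    return step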


def test_one_step_model_quality_pipeline_constraint_violation(
    sagemaker_session,
    role,
    pipeline_name,
    check_job_config,
    supplied_baseline_statistics_uri_param,
    supplied_baseline_constraints_uri_param,
    model_quality_check_config,
    model_quality_supplied_baseline_statistics,
):
    # The "bad_cases" constraints are expected to violate the model quality
    # baseline, so the check step should fail with _CHECK_FAIL_ERROR_MSG.
    model_quality_supplied_baseline_constraints = Constraints.from_file_path(
        constraints_file_path=os.path.join(
            DATA_DIR, "pipeline/quality_check_step/model_quality/bad_cases/constraints.json"
        ),
        sagemaker_session=sagemaker_session,
    ).file_s3_uri
    model_quality_check_step = QualityCheckStep(
        name="ModelQualityCheckStep",
        register_new_baseline=False,
        skip_check=False,
        quality_check_config=model_quality_check_config,
        check_job_config=check_job_config,
        supplied_baseline_statistics=supplied_baseline_statistics_uri_param,
        supplied_baseline_constraints=supplied_baseline_constraints_uri_param,
    )
    pipeline = Pipeline(
        name=pipeline_name,
        steps=[model_quality_check_step],
        parameters=[
            supplied_baseline_statistics_uri_param,
            supplied_baseline_constraints_uri_param,
        ],
        sagemaker_session=sagemaker_session,
    )

    try:
        response = pipeline.create(role)
        create_arn = response["PipelineArn"]

        for _ in retries(
            max_retry_count=5,
            exception_message_prefix="Waiting for a successful execution of pipeline",
            seconds_to_sleep=10,
        ):
            execution = pipeline.start(
                parameters={
                    "SuppliedBaselineStatisticsUri": model_quality_supplied_baseline_statistics,
                    "SuppliedBaselineConstraintsUri": model_quality_supplied_baseline_constraints,
                }
            )
            response = execution.describe()
            assert response["PipelineArn"] == create_arn

            try:
                execution.wait(delay=30, max_attempts=60)
            except WaiterError:
                pass
            execution_steps = execution.list_steps()

            assert len(execution_steps) == 1
            failure_reason = execution_steps[0].get("FailureReason", "")
            if _CHECK_FAIL_ERROR_MSG not in failure_reason:
                logging.error(
                    f"Pipeline execution failed with error: {failure_reason}. Retrying..."
                )
                continue
            assert execution_steps[0]["StepName"] == "ModelQualityCheckStep"
            assert execution_steps[0]["StepStatus"] == "Failed"
            break
    finally:
        try:
            pipeline.delete()
        except Exception:
            pass


def test_one_step_data_quality_pipeline_happycase(
    sagemaker_session,
    role,
    pipeline_name,
    check_job_config,
    supplied_baseline_statistics_uri_param,
    supplied_baseline_constraints_uri_param,
    data_quality_check_config,
    data_quality_supplied_baseline_statistics,
):
    # The "good_cases" constraints match the baseline, so the check step is
    # expected to pass and record the supplied baselines as the drift-check
    # baselines in its metadata.
    data_quality_supplied_baseline_constraints = Constraints.from_file_path(
        constraints_file_path=os.path.join(
            DATA_DIR, "pipeline/quality_check_step/data_quality/good_cases/constraints.json"
        ),
        sagemaker_session=sagemaker_session,
    ).file_s3_uri
    data_quality_check_step = QualityCheckStep(
        name="DataQualityCheckStep",
        skip_check=False,
        register_new_baseline=False,
        quality_check_config=data_quality_check_config,
        check_job_config=check_job_config,
        supplied_baseline_statistics=supplied_baseline_statistics_uri_param,
        supplied_baseline_constraints=supplied_baseline_constraints_uri_param,
    )
    pipeline = Pipeline(
        name=pipeline_name,
        steps=[data_quality_check_step],
        parameters=[
            supplied_baseline_statistics_uri_param,
            supplied_baseline_constraints_uri_param,
        ],
        sagemaker_session=sagemaker_session,
    )

    try:
        response = pipeline.create(role)
        create_arn = response["PipelineArn"]

        for _ in retries(
            max_retry_count=5,
            exception_message_prefix="Waiting for a successful execution of pipeline",
            seconds_to_sleep=10,
        ):
            execution = pipeline.start(
                parameters={
                    "SuppliedBaselineStatisticsUri": data_quality_supplied_baseline_statistics,
                    "SuppliedBaselineConstraintsUri": data_quality_supplied_baseline_constraints,
                }
            )
            response = execution.describe()
            assert response["PipelineArn"] == create_arn

            try:
                execution.wait(delay=30, max_attempts=60)
            except WaiterError:
                pass
            execution_steps = execution.list_steps()

            assert len(execution_steps) == 1
            failure_reason = execution_steps[0].get("FailureReason", "")
            if failure_reason != "":
                logging.error(
                    f"Pipeline execution failed with error: {failure_reason}. Retrying..."
                )
                continue
            assert execution_steps[0]["StepName"] == "DataQualityCheckStep"
            assert execution_steps[0]["StepStatus"] == "Succeeded"

            data_qual_metadata = execution_steps[0]["Metadata"]["QualityCheck"]
            assert not data_qual_metadata["SkipCheck"]
            assert not data_qual_metadata["RegisterNewBaseline"]
            assert not data_qual_metadata.get("ViolationReport", "")
            assert (
                data_qual_metadata["BaselineUsedForDriftCheckConstraints"]
                == data_quality_supplied_baseline_constraints
            )
            assert (
                data_qual_metadata["BaselineUsedForDriftCheckStatistics"]
                == data_quality_supplied_baseline_statistics
            )
            assert (
                data_qual_metadata["BaselineUsedForDriftCheckConstraints"]
                != data_qual_metadata["CalculatedBaselineConstraints"]
            )
            assert (
                data_qual_metadata["BaselineUsedForDriftCheckStatistics"]
                != data_qual_metadata["CalculatedBaselineStatistics"]
            )
            break
    finally:
        try:
            pipeline.delete()
        except Exception:
            pass
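

# The integration tests above also depend on fixtures defined elsewhere, e.g.
# supplied_baseline_statistics_uri_param and
# data_quality_supplied_baseline_statistics. Hypothetical sketches follow,
# mirroring the Constraints.from_file_path call used inside the tests; the
# actual parameter defaults and data-file paths may differ.
from sagemaker.model_monitor import Statistics


@pytest.fixture
def supplied_baseline_statistics_uri_param():
    # Name must match the "SuppliedBaselineStatisticsUri" key passed to
    # pipeline.start() above.
    return ParameterString(name="SuppliedBaselineStatisticsUri", default_value="")


@pytest.fixture
def supplied_baseline_constraints_uri_param():
    return ParameterString(name="SuppliedBaselineConstraintsUri", default_value="")


@pytest.fixture
def data_quality_supplied_baseline_statistics(sagemaker_session):
    # Uploads a known-good statistics file and returns its S3 URI.
    return Statistics.from_file_path(
        statistics_file_path=os.path.join(
            DATA_DIR,
            "pipeline/quality_check_step/data_quality/good_cases/statistics.json",  # hypothetical path
        ),
        sagemaker_session=sagemaker_session,
    ).file_s3_uri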