def test_quality_check_step_invalid_config(
    check_job_config,
    model_package_group_name,
    supplied_baseline_statistics_uri,
    supplied_baseline_constraints_uri,
):
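    """The base QualityCheckConfig is rejected; only its Data/Model subclasses are valid."""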
    quality_check_config = QualityCheckConfig(
        baseline_dataset="baseline_dataset_s3_url",
        dataset_format=DatasetFormat.csv(header=True),
        post_analytics_processor_script="s3://my_bucket/data_quality/postprocessor.py",
        output_s3_uri="",
    )
    with pytest.raises(Exception) as error:
        QualityCheckStep(
            name="QualityCheckStep",
            register_new_baseline=False,
            skip_check=False,
            quality_check_config=quality_check_config,
            check_job_config=check_job_config,
            model_package_group_name=model_package_group_name,
            supplied_baseline_statistics=supplied_baseline_statistics_uri,
            supplied_baseline_constraints=supplied_baseline_constraints_uri,
        )

    assert (
        str(error.value) == "The quality_check_config can only be object of "
        "DataQualityCheckConfig or ModelQualityCheckConfig"
    )


def test_model_quality_check_step(
    sagemaker_session,
    check_job_config,
    model_package_group_name,
    supplied_baseline_statistics_uri,
    supplied_baseline_constraints_uri,
):
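    """A step built from ModelQualityCheckConfig should produce the expected pipeline DSL."""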
    model_quality_check_config = ModelQualityCheckConfig(
        baseline_dataset="baseline_dataset_s3_url",
        dataset_format=DatasetFormat.csv(header=True),
        problem_type="BinaryClassification",
        probability_attribute=0,  # the integer should be converted to str by SDK
        ground_truth_attribute=None,
        probability_threshold_attribute=0.5,  # the float should be converted to str by SDK
        post_analytics_processor_script="s3://my_bucket/data_quality/postprocessor.py",
        output_s3_uri="",
    )
    model_quality_check_step = QualityCheckStep(
        name="ModelQualityCheckStep",
        register_new_baseline=False,
        skip_check=False,
        quality_check_config=model_quality_check_config,
        check_job_config=check_job_config,
        model_package_group_name=model_package_group_name,
        supplied_baseline_statistics=supplied_baseline_statistics_uri,
        supplied_baseline_constraints=supplied_baseline_constraints_uri,
    )
    pipeline = Pipeline(
        name="MyPipeline",
        parameters=[
            supplied_baseline_statistics_uri,
            supplied_baseline_constraints_uri,
            model_package_group_name,
        ],
        steps=[model_quality_check_step],
        sagemaker_session=sagemaker_session,
    )

    step_definition = _get_step_definition_for_test(pipeline)

    assert step_definition == _expected_model_quality_dsl


def test_quality_check_step_properties(
    check_job_config,
    model_package_group_name,
    supplied_baseline_statistics_uri,
    supplied_baseline_constraints_uri,
):
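    """The step's property references should resolve to the expected Get expressions."""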
    model_quality_check_config = ModelQualityCheckConfig(
        baseline_dataset="baseline_dataset_s3_url",
        dataset_format=DatasetFormat.csv(header=True),
        problem_type="BinaryClassification",
        probability_attribute="0",
        probability_threshold_attribute="0.5",
        post_analytics_processor_script="s3://my_bucket/data_quality/postprocessor.py",
        output_s3_uri="",
    )
    model_quality_check_step = QualityCheckStep(
        name="ModelQualityCheckStep",
        register_new_baseline=False,
        skip_check=False,
        quality_check_config=model_quality_check_config,
        check_job_config=check_job_config,
        model_package_group_name=model_package_group_name,
        supplied_baseline_statistics=supplied_baseline_statistics_uri,
        supplied_baseline_constraints=supplied_baseline_constraints_uri,
    )

    assert model_quality_check_step.properties.CalculatedBaselineConstraints.expr == {
        "Get": "Steps.ModelQualityCheckStep.CalculatedBaselineConstraints"
    }
    assert model_quality_check_step.properties.CalculatedBaselineStatistics.expr == {
        "Get": "Steps.ModelQualityCheckStep.CalculatedBaselineStatistics"
    }
    assert model_quality_check_step.properties.BaselineUsedForDriftCheckStatistics.expr == {
        "Get":
        "Steps.ModelQualityCheckStep.BaselineUsedForDriftCheckStatistics"
    }
    assert model_quality_check_step.properties.BaselineUsedForDriftCheckConstraints.expr == {
        "Get":
        "Steps.ModelQualityCheckStep.BaselineUsedForDriftCheckConstraints"
    }


def test_data_quality_check_step(
    sagemaker_session,
    check_job_config,
    model_package_group_name,
    supplied_baseline_statistics_uri,
    supplied_baseline_constraints_uri,
):
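    """A step built from DataQualityCheckConfig should produce the expected pipeline DSL."""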
    data_quality_check_config = DataQualityCheckConfig(
        baseline_dataset=ParameterString(name="BaselineDataset"),
        dataset_format=DatasetFormat.csv(header=True),
        output_s3_uri="s3://...",
        record_preprocessor_script="s3://my_bucket/data_quality/preprocessor.py",
        post_analytics_processor_script="s3://my_bucket/data_quality/postprocessor.py",
    )
    data_quality_check_step = QualityCheckStep(
        name="DataQualityCheckStep",
        skip_check=False,
        register_new_baseline=False,
        quality_check_config=data_quality_check_config,
        check_job_config=check_job_config,
        model_package_group_name=model_package_group_name,
        supplied_baseline_statistics=supplied_baseline_statistics_uri,
        supplied_baseline_constraints=supplied_baseline_constraints_uri,
        cache_config=CacheConfig(enable_caching=True, expire_after="PT1H"),
    )
    pipeline = Pipeline(
        name="MyPipeline",
        parameters=[
            supplied_baseline_statistics_uri,
            supplied_baseline_constraints_uri,
            model_package_group_name,
        ],
        steps=[data_quality_check_step],
        sagemaker_session=sagemaker_session,
    )
    step_definition = _get_step_definition_for_test(
        pipeline, ["baseline_dataset_input", "quality_check_output"]
    )

    assert step_definition == _expected_data_quality_dsl


def test_one_step_model_quality_pipeline_constraint_violation(
    sagemaker_session,
    role,
    pipeline_name,
    check_job_config,
    supplied_baseline_statistics_uri_param,
    supplied_baseline_constraints_uri_param,
    model_quality_check_config,
    model_quality_supplied_baseline_statistics,
):
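    """A run against the bad-case baseline constraints should fail the model quality check."""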
    model_quality_supplied_baseline_constraints = Constraints.from_file_path(
        constraints_file_path=os.path.join(
            DATA_DIR, "pipeline/quality_check_step/model_quality/bad_cases/constraints.json"
        ),
        sagemaker_session=sagemaker_session,
    ).file_s3_uri
    model_quality_check_step = QualityCheckStep(
        name="ModelQualityCheckStep",
        register_new_baseline=False,
        skip_check=False,
        quality_check_config=model_quality_check_config,
        check_job_config=check_job_config,
        supplied_baseline_statistics=supplied_baseline_statistics_uri_param,
        supplied_baseline_constraints=supplied_baseline_constraints_uri_param,
    )
    pipeline = Pipeline(
        name=pipeline_name,
        steps=[model_quality_check_step],
        parameters=[
            supplied_baseline_statistics_uri_param,
            supplied_baseline_constraints_uri_param,
        ],
        sagemaker_session=sagemaker_session,
    )

    try:
        response = pipeline.create(role)
        create_arn = response["PipelineArn"]

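        # Retry until the step fails with the expected constraint-violation error;
        # any other failure reason is treated as transient and retried.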
        for _ in retries(
            max_retry_count=5,
            exception_message_prefix="Waiting for a successful execution of pipeline",
            seconds_to_sleep=10,
        ):
            execution = pipeline.start(
                parameters={
                    "SuppliedBaselineStatisticsUri": model_quality_supplied_baseline_statistics,
                    "SuppliedBaselineConstraintsUri": model_quality_supplied_baseline_constraints,
                }
            )
            response = execution.describe()

            assert response["PipelineArn"] == create_arn

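            # wait() raises WaiterError when the execution fails; the failure reason is checked below.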
            try:
                execution.wait(delay=30, max_attempts=60)
            except WaiterError:
                pass
            execution_steps = execution.list_steps()

            assert len(execution_steps) == 1
            failure_reason = execution_steps[0].get("FailureReason", "")
            if _CHECK_FAIL_ERROR_MSG not in failure_reason:
                logging.error(f"Pipeline execution failed with error: {failure_reason}. Retrying..")
                continue
            assert execution_steps[0]["StepName"] == "ModelQualityCheckStep"
            assert execution_steps[0]["StepStatus"] == "Failed"
            break
    finally:
        try:
            pipeline.delete()
        except Exception:
            pass


def test_one_step_data_quality_pipeline_happycase(
    sagemaker_session,
    role,
    pipeline_name,
    check_job_config,
    supplied_baseline_statistics_uri_param,
    supplied_baseline_constraints_uri_param,
    data_quality_check_config,
    data_quality_supplied_baseline_statistics,
):
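    """A run against the good-case baseline constraints should pass the data quality check."""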
    data_quality_supplied_baseline_constraints = Constraints.from_file_path(
        constraints_file_path=os.path.join(
            DATA_DIR, "pipeline/quality_check_step/data_quality/good_cases/constraints.json"
        ),
        sagemaker_session=sagemaker_session,
    ).file_s3_uri
    data_quality_check_step = QualityCheckStep(
        name="DataQualityCheckStep",
        skip_check=False,
        register_new_baseline=False,
        quality_check_config=data_quality_check_config,
        check_job_config=check_job_config,
        supplied_baseline_statistics=supplied_baseline_statistics_uri_param,
        supplied_baseline_constraints=supplied_baseline_constraints_uri_param,
    )
    pipeline = Pipeline(
        name=pipeline_name,
        steps=[data_quality_check_step],
        parameters=[
            supplied_baseline_statistics_uri_param,
            supplied_baseline_constraints_uri_param,
        ],
        sagemaker_session=sagemaker_session,
    )
    try:
        response = pipeline.create(role)
        create_arn = response["PipelineArn"]

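        # Retry until the execution completes without a failure reason;
        # transient failures are logged and retried.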
        for _ in retries(
            max_retry_count=5,
            exception_message_prefix="Waiting for a successful execution of pipeline",
            seconds_to_sleep=10,
        ):
            execution = pipeline.start(
                parameters={
                    "SuppliedBaselineStatisticsUri": data_quality_supplied_baseline_statistics,
                    "SuppliedBaselineConstraintsUri": data_quality_supplied_baseline_constraints,
                }
            )
            response = execution.describe()

            assert response["PipelineArn"] == create_arn

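            # wait() raises WaiterError when the execution fails; the failure reason is checked below.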
            try:
                execution.wait(delay=30, max_attempts=60)
            except WaiterError:
                pass
            execution_steps = execution.list_steps()

            assert len(execution_steps) == 1
            failure_reason = execution_steps[0].get("FailureReason", "")
            if failure_reason != "":
                logging.error(f"Pipeline execution failed with error: {failure_reason}. Retrying..")
                continue
            assert execution_steps[0]["StepName"] == "DataQualityCheckStep"
            assert execution_steps[0]["StepStatus"] == "Succeeded"
            data_qual_metadata = execution_steps[0]["Metadata"]["QualityCheck"]
            assert not data_qual_metadata["SkipCheck"]
            assert not data_qual_metadata["RegisterNewBaseline"]
            assert not data_qual_metadata.get("ViolationReport", "")
            assert (
                data_qual_metadata["BaselineUsedForDriftCheckConstraints"]
                == data_quality_supplied_baseline_constraints
            )
            assert (
                data_qual_metadata["BaselineUsedForDriftCheckStatistics"]
                == data_quality_supplied_baseline_statistics
            )
            assert (
                data_qual_metadata["BaselineUsedForDriftCheckConstraints"]
                != data_qual_metadata["CalculatedBaselineConstraints"]
            )
            assert (
                data_qual_metadata["BaselineUsedForDriftCheckStatistics"]
                != data_qual_metadata["CalculatedBaselineStatistics"]
            )
            break
    finally:
        try:
            pipeline.delete()
        except Exception:
            pass