def test_clarify_check_step_properties(
    check_job_config,
    model_package_group_name,
    data_config,
    model_config,
    shap_config,
):
    """The step's baseline-constraint properties resolve to ``Get`` expressions."""
    explainability_config = ModelExplainabilityCheckConfig(
        data_config=data_config,
        model_config=model_config,
        explainability_config=shap_config,
    )
    step = ClarifyCheckStep(
        name="ModelExplainabilityCheckStep",
        clarify_check_config=explainability_config,
        check_job_config=check_job_config,
        skip_check=False,
        register_new_baseline=False,
        model_package_group_name=model_package_group_name,
        supplied_baseline_constraints="supplied_baseline_constraints",
    )

    # Each property expression must point at the step's own output path.
    expected_paths = {
        "CalculatedBaselineConstraints":
        "Steps.ModelExplainabilityCheckStep.CalculatedBaselineConstraints",
        "BaselineUsedForDriftCheckConstraints":
        "Steps.ModelExplainabilityCheckStep.BaselineUsedForDriftCheckConstraints",
    }
    for prop_name, get_path in expected_paths.items():
        assert getattr(step.properties, prop_name).expr == {"Get": get_path}
def test_clarify_check_step_invalid_config(
    check_job_config,
    model_package_group_name,
    data_config,
):
    """A bare ``ClarifyCheckConfig`` (not one of the concrete subclasses) is rejected."""
    invalid_config = ClarifyCheckConfig(data_config=data_config)
    expected_message = (
        "The clarify_check_config can only be object of DataBiasCheckConfig, ModelBiasCheckConfig"
        " or ModelExplainabilityCheckConfig")

    with pytest.raises(Exception) as error:
        ClarifyCheckStep(
            name="ClarifyCheckStep",
            clarify_check_config=invalid_config,
            check_job_config=check_job_config,
            skip_check=False,
            register_new_baseline=False,
            model_package_group_name=model_package_group_name,
            supplied_baseline_constraints="supplied_baseline_constraints",
        )

    assert str(error.value) == expected_message
def test_model_bias_check_step(
    sagemaker_session,
    check_job_config,
    model_package_group_name,
    data_config,
    bias_config,
    model_config,
    predictions_config,
):
    """A ModelBias check step serializes to the expected DSL and writes its
    monitoring analysis config under the default bucket prefix."""
    bias_check_config = ModelBiasCheckConfig(
        data_config=data_config,
        data_bias_config=bias_config,
        model_config=model_config,
        model_predicted_label_config=predictions_config,
        methods="all",
    )
    bias_check_step = ClarifyCheckStep(
        name="ModelBiasCheckStep",
        clarify_check_config=bias_check_config,
        check_job_config=check_job_config,
        skip_check=False,
        register_new_baseline=False,
        model_package_group_name=model_package_group_name,
        supplied_baseline_constraints="supplied_baseline_constraints",
    )
    pipeline = Pipeline(
        name="MyPipeline",
        parameters=[model_package_group_name],
        steps=[bias_check_step],
        sagemaker_session=sagemaker_session,
    )

    step_dsl = json.loads(pipeline.definition())["Steps"][0]
    assert step_dsl == _expected_model_bias_dsl

    # The analysis-config URI falls back to the default bucket / model-monitor path.
    config_uri_pattern = (
        f"s3://{_DEFAULT_BUCKET}/{_MODEL_MONITOR_S3_PATH}"
        f"/{_BIAS_MONITORING_CFG_BASE_NAME}-configuration"
        f"/{_BIAS_MONITORING_CFG_BASE_NAME}-config.*/.*/analysis_config.json")
    assert re.match(config_uri_pattern,
                    bias_check_config.monitoring_analysis_config_uri)
def test_data_bias_check_step(sagemaker_session, check_job_config,
                              model_package_group_name, bias_config):
    """A DataBias check step serializes to the expected DSL and writes its
    monitoring analysis config under the explicit analysis-config output path."""
    bias_data_config = DataConfig(
        s3_data_input_path=_S3_INPUT_PATH,
        s3_output_path=_S3_OUTPUT_PATH,
        s3_analysis_config_output_path=_S3_ANALYSIS_CONFIG_OUTPUT_PATH,
        label="fraud",
        dataset_type="text/csv",
    )
    bias_check_config = DataBiasCheckConfig(
        data_config=bias_data_config,
        data_bias_config=bias_config,
        methods="all",
        kms_key="kms_key",
    )
    bias_check_step = ClarifyCheckStep(
        name="DataBiasCheckStep",
        clarify_check_config=bias_check_config,
        check_job_config=check_job_config,
        skip_check=False,
        register_new_baseline=False,
        model_package_group_name=model_package_group_name,
        supplied_baseline_constraints="supplied_baseline_constraints",
        cache_config=CacheConfig(enable_caching=True, expire_after="PT1H"),
    )
    pipeline = Pipeline(
        name="MyPipeline",
        parameters=[model_package_group_name],
        steps=[bias_check_step],
        sagemaker_session=sagemaker_session,
    )

    step_dsl = json.loads(pipeline.definition())["Steps"][0]
    assert step_dsl == _expected_data_bias_dsl

    # The explicit s3_analysis_config_output_path takes precedence over the default bucket.
    config_uri_pattern = (
        f"{_S3_ANALYSIS_CONFIG_OUTPUT_PATH}/{_BIAS_MONITORING_CFG_BASE_NAME}-configuration"
        f"/{_BIAS_MONITORING_CFG_BASE_NAME}-config.*/.*/analysis_config.json")
    assert re.match(config_uri_pattern,
                    bias_check_config.monitoring_analysis_config_uri)
def test_one_step_data_bias_pipeline_constraint_violation(
    sagemaker_session,
    role,
    pipeline_name,
    check_job_config,
    data_bias_check_config,
    supplied_baseline_constraints_uri_param,
):
    """End-to-end: a data-bias check run against bad baseline constraints fails.

    Creates a one-step pipeline, starts it with constraints known to violate
    the check, and retries (executions can fail for transient reasons) until
    the step fails with the expected constraint-violation message. The
    pipeline is deleted on the way out regardless of outcome.
    """
    # Upload the intentionally-bad constraints file and keep its S3 URI.
    data_bias_supplied_baseline_constraints = Constraints.from_file_path(
        constraints_file_path=os.path.join(
            DATA_DIR,
            "pipeline/clarify_check_step/data_bias/bad_cases/analysis.json"),
        sagemaker_session=sagemaker_session,
    ).file_s3_uri
    data_bias_check_step = ClarifyCheckStep(
        name="DataBiasCheckStep",
        clarify_check_config=data_bias_check_config,
        check_job_config=check_job_config,
        skip_check=False,
        register_new_baseline=False,
        supplied_baseline_constraints=supplied_baseline_constraints_uri_param,
    )
    pipeline = Pipeline(
        name=pipeline_name,
        steps=[data_bias_check_step],
        parameters=[supplied_baseline_constraints_uri_param],
        sagemaker_session=sagemaker_session,
    )

    try:
        response = pipeline.create(role)
        create_arn = response["PipelineArn"]
        # Creating the step must have materialized a non-empty monitoring
        # analysis config in S3.
        monitoring_analysis_cfg_json = S3Downloader.read_file(
            data_bias_check_config.monitoring_analysis_config_uri,
            sagemaker_session,
        )
        monitoring_analysis_cfg = json.loads(monitoring_analysis_cfg_json)

        assert monitoring_analysis_cfg is not None and len(
            monitoring_analysis_cfg) > 0

        for _ in retries(
                max_retry_count=5,
                exception_message_prefix=
                "Waiting for a successful execution of pipeline",
                seconds_to_sleep=10,
        ):
            execution = pipeline.start(parameters={
                "SuppliedBaselineConstraintsUri":
                data_bias_supplied_baseline_constraints
            }, )
            response = execution.describe()

            assert response["PipelineArn"] == create_arn

            try:
                execution.wait(delay=30, max_attempts=60)
            except WaiterError:
                # Timeout waiting is tolerated; the step list below tells us
                # what actually happened.
                pass
            execution_steps = execution.list_steps()

            assert len(execution_steps) == 1
            failure_reason = execution_steps[0].get("FailureReason", "")
            if _CHECK_FAIL_ERROR_MSG not in failure_reason:
                # Transient failure (not the expected constraint violation):
                # log and retry. Lazy %-style args avoid building the message
                # unless the record is actually emitted.
                logging.error(
                    "Pipeline execution failed with error: %s. Retrying..",
                    failure_reason,
                )
                continue
            # The expected outcome: the check step itself failed its check.
            assert execution_steps[0]["StepName"] == "DataBiasCheckStep"
            assert execution_steps[0]["StepStatus"] == "Failed"
            break
    finally:
        try:
            pipeline.delete()
        except Exception:
            # Best-effort cleanup; a delete failure must not mask the result.
            pass
def test_get_s3_base_uri_for_monitoring_analysis_config(
    check_job_config,
    data_config,
    bias_config,
    model_config,
    shap_config,
    predictions_config,
):
    """`_get_s3_base_uri_for_monitoring_analysis_config` picks the right base.

    Without ``s3_analysis_config_output_path`` the default bucket's
    model-monitor prefix is used; with it, the explicit path wins.
    """
    default_base = f"s3://{_DEFAULT_BUCKET}/{_MODEL_MONITOR_S3_PATH}"

    # Case 1: ModelExplainabilityCheckStep, no s3_analysis_config_output_path.
    explainability_step_default = ClarifyCheckStep(
        name="ModelExplainabilityCheckStep",
        clarify_check_config=ModelExplainabilityCheckConfig(
            data_config=data_config,
            model_config=model_config,
            explainability_config=shap_config,
        ),
        check_job_config=check_job_config,
    )
    assert (
        explainability_step_default._get_s3_base_uri_for_monitoring_analysis_config()
        == f"{default_base}/{_EXPLAINABILITY_MONITORING_CFG_BASE_NAME}-configuration")

    # Case 2: ModelExplainabilityCheckStep with an explicit analysis-config path.
    explainability_data_config = DataConfig(
        s3_data_input_path=_S3_INPUT_PATH,
        s3_output_path=ParameterString(name="S3OutputPath",
                                       default_value=_S3_OUTPUT_PATH),
        s3_analysis_config_output_path=_S3_ANALYSIS_CONFIG_OUTPUT_PATH,
    )
    explainability_step_explicit = ClarifyCheckStep(
        name="ModelExplainabilityCheckStep",
        clarify_check_config=ModelExplainabilityCheckConfig(
            data_config=explainability_data_config,
            model_config=model_config,
            explainability_config=shap_config,
        ),
        check_job_config=check_job_config,
    )
    assert (
        explainability_step_explicit._get_s3_base_uri_for_monitoring_analysis_config()
        == f"{_S3_ANALYSIS_CONFIG_OUTPUT_PATH}/{_EXPLAINABILITY_MONITORING_CFG_BASE_NAME}-configuration")

    # Case 3: ModelBiasCheckStep with an explicit analysis-config path.
    bias_data_config = DataConfig(
        s3_data_input_path=_S3_INPUT_PATH,
        s3_output_path=_S3_OUTPUT_PATH,
        s3_analysis_config_output_path=_S3_ANALYSIS_CONFIG_OUTPUT_PATH,
    )
    model_bias_step = ClarifyCheckStep(
        name="ModelBiasCheckStep",
        clarify_check_config=ModelBiasCheckConfig(
            data_config=bias_data_config,
            data_bias_config=bias_config,
            model_config=model_config,
            model_predicted_label_config=predictions_config,
        ),
        check_job_config=check_job_config,
    )
    assert (
        model_bias_step._get_s3_base_uri_for_monitoring_analysis_config()
        == f"{_S3_ANALYSIS_CONFIG_OUTPUT_PATH}/{_BIAS_MONITORING_CFG_BASE_NAME}-configuration")

    # Case 4: DataBiasCheckStep, no s3_analysis_config_output_path.
    data_bias_step = ClarifyCheckStep(
        name="DataBiasCheckStep",
        clarify_check_config=DataBiasCheckConfig(
            data_config=data_config,
            data_bias_config=bias_config,
        ),
        check_job_config=check_job_config,
    )
    assert (
        data_bias_step._get_s3_base_uri_for_monitoring_analysis_config()
        == f"{default_base}/{_BIAS_MONITORING_CFG_BASE_NAME}-configuration")
def test_clarify_check_step_with_none_or_invalid_s3_analysis_config_output_uri(
    bias_config,
    check_job_config,
    model_package_group_name,
):
    """Construction-time validation of ``s3_analysis_config_output_path``.

    A missing analysis-config path is fine when ``s3_output_path`` is a plain
    string, but not when it is a pipeline Parameter; the analysis-config path
    itself may never be a Parameter.
    """

    def _build_step(config):
        # All three cases construct the step with identical fixed arguments.
        return ClarifyCheckStep(
            name="ClarifyCheckStep",
            clarify_check_config=config,
            check_job_config=check_job_config,
            skip_check=False,
            register_new_baseline=False,
            model_package_group_name=model_package_group_name,
            supplied_baseline_constraints="supplied_baseline_constraints",
        )

    # Valid: analysis path omitted, s3_output_path is a plain S3 string.
    plain_output_config = DataBiasCheckConfig(
        data_config=DataConfig(
            s3_data_input_path=_S3_INPUT_PATH,
            s3_output_path=_S3_OUTPUT_PATH,
            label="fraud",
            dataset_type="text/csv",
        ),
        data_bias_config=bias_config,
    )
    _build_step(plain_output_config)

    # Invalid: analysis path empty while s3_output_path is a Parameter.
    parameter_output_config = DataBiasCheckConfig(
        data_config=DataConfig(
            s3_data_input_path=_S3_INPUT_PATH,
            s3_output_path=ParameterString(name="S3OutputPath",
                                           default_value=_S3_OUTPUT_PATH),
            s3_analysis_config_output_path="",
            label="fraud",
            dataset_type="text/csv",
        ),
        data_bias_config=bias_config,
    )
    with pytest.raises(Exception) as error:
        _build_step(parameter_output_config)
    assert (
        str(error.value) ==
        "`s3_output_path` cannot be of type ExecutionVariable/Expression/Parameter/Properties "
        + "if `s3_analysis_config_output_path` is none or empty ")

    # Invalid: the analysis-config path itself is a Parameter.
    parameter_analysis_config = DataBiasCheckConfig(
        data_config=DataConfig(
            s3_data_input_path=_S3_INPUT_PATH,
            s3_output_path=ParameterString(name="S3OutputPath",
                                           default_value=_S3_OUTPUT_PATH),
            s3_analysis_config_output_path=ParameterString(
                name="S3OAnalysisCfgOutput"),
            label="fraud",
            dataset_type="text/csv",
        ),
        data_bias_config=bias_config,
    )
    with pytest.raises(Exception) as error:
        _build_step(parameter_analysis_config)
    assert (str(error.value) ==
            "s3_analysis_config_output_path cannot be of type "
            "ExecutionVariable/Expression/Parameter/Properties")