import logging
import os
from typing import Optional

import pytest

from sagemaker import utils
from sagemaker.model_monitor import (
    DefaultModelMonitor,
    EndpointInput,
    ModelBiasMonitor,
    ModelExplainabilityMonitor,
    ModelMonitor,
    ModelQualityMonitor,
)

# Shared test constants (ROLE, INSTANCE_COUNT, INSTANCE_TYPE, VOLUME_SIZE_IN_GB,
# MAX_RUNTIME_IN_SECONDS, TEST_TAGS, CRON, etc.) and the _verify_* helpers are
# assumed to be defined elsewhere in this module.


# Assumed to be a pytest fixture consumed by the scheduled_bias_monitor fixture
# used in test_bias_monitor below.
@pytest.fixture
def bias_monitor(sagemaker_session):
    monitor = ModelBiasMonitor(
        role=ROLE,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        volume_size_in_gb=VOLUME_SIZE_IN_GB,
        max_runtime_in_seconds=MAX_RUNTIME_IN_SECONDS,
        sagemaker_session=sagemaker_session,
        tags=TEST_TAGS,
    )
    return monitor
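

# End-to-end lifecycle test for a bias monitoring schedule: stop it, attach to it
# by name, update it, and delete it, verifying the job description along the way.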
def test_bias_monitor(
    sagemaker_session, scheduled_bias_monitor, endpoint_name, ground_truth_input
):
    monitor = scheduled_bias_monitor
    monitor._wait_for_schedule_changes_to_apply()

    # stop it as soon as possible to avoid any execution
    monitor.stop_monitoring_schedule()
    _verify_monitoring_schedule(
        monitor=monitor,
        schedule_status="Stopped",
    )
    _verify_bias_job_description(
        sagemaker_session=sagemaker_session,
        monitor=monitor,
        endpoint_name=endpoint_name,
        ground_truth_input=ground_truth_input,
    )

    # attach to schedule
    monitoring_schedule_name = monitor.monitoring_schedule_name
    job_definition_name = monitor.job_definition_name
    monitor = ModelBiasMonitor.attach(
        monitor_schedule_name=monitor.monitoring_schedule_name,
        sagemaker_session=sagemaker_session,
    )
    assert monitor.monitoring_schedule_name == monitoring_schedule_name
    assert monitor.job_definition_name == job_definition_name

    # update schedule
    monitor.update_monitoring_schedule(
        max_runtime_in_seconds=UPDATED_MAX_RUNTIME_IN_SECONDS,
        schedule_cron_expression=UPDATED_CRON,
    )
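    # Updating the schedule creates a new job definition, so the definition name
    # changes while the schedule name is preserved.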
    assert monitor.monitoring_schedule_name == monitoring_schedule_name
    assert monitor.job_definition_name != job_definition_name
    _verify_monitoring_schedule(
        monitor=monitor,
        schedule_status="Scheduled",
        schedule_cron_expression=UPDATED_CRON,
    )
    _verify_bias_job_description(
        sagemaker_session=sagemaker_session,
        monitor=monitor,
        endpoint_name=endpoint_name,
        ground_truth_input=ground_truth_input,
        max_runtime_in_seconds=UPDATED_MAX_RUNTIME_IN_SECONDS,
    )

    # delete schedule
    monitor.delete_monitoring_schedule()
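

    # Builds the ModelMonitor subclass named by mm_type, configured for use with
    # QualityCheckStep and ClarifyCheckStep.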
    def _generate_model_monitor(self, mm_type: str) -> Optional[ModelMonitor]:
        """Generates a ModelMonitor object

        Generates a ModelMonitor object with required config attributes for
            QualityCheckStep and ClarifyCheckStep

        Args:
            mm_type (str): The subclass type of ModelMonitor object.
                A valid mm_type should be one of the following: "DefaultModelMonitor",
                "ModelQualityMonitor", "ModelBiasMonitor", "ModelExplainabilityMonitor"

        Return:
            sagemaker.model_monitor.ModelMonitor or None if the mm_type is not valid

        """
        if mm_type == "DefaultModelMonitor":
            monitor = DefaultModelMonitor(
                role=self.role,
                instance_count=self.instance_count,
                instance_type=self.instance_type,
                volume_size_in_gb=self.volume_size_in_gb,
                volume_kms_key=self.volume_kms_key,
                output_kms_key=self.output_kms_key,
                max_runtime_in_seconds=self.max_runtime_in_seconds,
                base_job_name=self.base_job_name,
                sagemaker_session=self.sagemaker_session,
                env=self.env,
                tags=self.tags,
                network_config=self.network_config,
            )
        elif mm_type == "ModelQualityMonitor":
            monitor = ModelQualityMonitor(
                role=self.role,
                instance_count=self.instance_count,
                instance_type=self.instance_type,
                volume_size_in_gb=self.volume_size_in_gb,
                volume_kms_key=self.volume_kms_key,
                output_kms_key=self.output_kms_key,
                max_runtime_in_seconds=self.max_runtime_in_seconds,
                base_job_name=self.base_job_name,
                sagemaker_session=self.sagemaker_session,
                env=self.env,
                tags=self.tags,
                network_config=self.network_config,
            )
        elif mm_type == "ModelBiasMonitor":
            monitor = ModelBiasMonitor(
                role=self.role,
                instance_count=self.instance_count,
                instance_type=self.instance_type,
                volume_size_in_gb=self.volume_size_in_gb,
                volume_kms_key=self.volume_kms_key,
                output_kms_key=self.output_kms_key,
                max_runtime_in_seconds=self.max_runtime_in_seconds,
                base_job_name=self.base_job_name,
                sagemaker_session=self.sagemaker_session,
                env=self.env,
                tags=self.tags,
                network_config=self.network_config,
            )
        elif mm_type == "ModelExplainabilityMonitor":
            monitor = ModelExplainabilityMonitor(
                role=self.role,
                instance_count=self.instance_count,
                instance_type=self.instance_type,
                volume_size_in_gb=self.volume_size_in_gb,
                volume_kms_key=self.volume_kms_key,
                output_kms_key=self.output_kms_key,
                max_runtime_in_seconds=self.max_runtime_in_seconds,
                base_job_name=self.base_job_name,
                sagemaker_session=self.sagemaker_session,
                env=self.env,
                tags=self.tags,
                network_config=self.network_config,
            )
        else:
            logging.warning(
                'Expected model monitor types: "DefaultModelMonitor", "ModelQualityMonitor", '
                '"ModelBiasMonitor", "ModelExplainabilityMonitor"'
            )
            return None
        return monitor
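

# Baselining test: run a bias baselining job, then create a monitoring schedule
# that picks up the analysis_config produced by that job.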
def test_run_bias_monitor_baseline(
    sagemaker_session,
    data_config,
    model_config,
    bias_config,
    model_predicted_label_config,
    endpoint_name,
    ground_truth_input,
    upload_actual_data,
):
    monitor = ModelBiasMonitor(
        role=ROLE,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        volume_size_in_gb=VOLUME_SIZE_IN_GB,
        max_runtime_in_seconds=MAX_RUNTIME_IN_SECONDS,
        sagemaker_session=sagemaker_session,
        tags=TEST_TAGS,
    )

    baselining_job_name = utils.unique_name_from_base("bias-baselining-job")
    print("Creating baselining job: {}".format(baselining_job_name))
    monitor.suggest_baseline(
        data_config=data_config,
        bias_config=bias_config,
        model_config=model_config,
        model_predicted_label_config=model_predicted_label_config,
        job_name=baselining_job_name,
    )
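    # The baselining job configuration should carry over the probability threshold
    # supplied via model_predicted_label_config.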
    assert (
        monitor.latest_baselining_job_config.probability_threshold_attribute
        == BIAS_PROBABILITY_THRESHOLD
    )
    monitoring_schedule_name = utils.unique_name_from_base("bias-suggest-baseline")
    s3_uri_monitoring_output = "s3://{}/{}/{}/{}".format(
        sagemaker_session.default_bucket(),
        endpoint_name,
        monitoring_schedule_name,
        "monitor_output",
    )
    # Verify that the schedule picks up the analysis_config from the baselining job
    monitor.create_monitoring_schedule(
        output_s3_uri=s3_uri_monitoring_output,
        monitor_schedule_name=monitoring_schedule_name,
        endpoint_input=EndpointInput(
            endpoint_name=endpoint_name,
            destination=ENDPOINT_INPUT_LOCAL_PATH,
            start_time_offset=START_TIME_OFFSET,
            end_time_offset=END_TIME_OFFSET,
        ),
        ground_truth_input=ground_truth_input,
        schedule_cron_expression=CRON,
    )
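    # Verify the status of the schedule's executions.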
    _verify_execution_status(monitor)

    _verify_bias_job_description(
        sagemaker_session=sagemaker_session,
        monitor=monitor,
        endpoint_name=endpoint_name,
        ground_truth_input=ground_truth_input,
    )

    monitor.delete_monitoring_schedule()