def testMergeDetails(self):
        a = text_format.Parse(
            """
        validation_details {
          slicing_details {
            slicing_spec {}
            num_matching_slices: 1
          }
          slicing_details {
            slicing_spec {
              feature_keys: ["x", "y"]
            }
            num_matching_slices: 1
          }
        }""", validation_result_pb2.ValidationResult())

        b = text_format.Parse(
            """
        validation_details {
          slicing_details {
            slicing_spec {
              feature_keys: ["x"]
            }
            num_matching_slices: 1
          }
          slicing_details {
            slicing_spec {
              feature_keys: ["x", "y"]
            }
            num_matching_slices: 2
          }
        }""", validation_result_pb2.ValidationResult())

        expected = text_format.Parse(
            """
        validation_details {
          slicing_details {
            slicing_spec {}
            num_matching_slices: 1
          }
          slicing_details {
            slicing_spec {
              feature_keys: ["x", "y"]
            }
            num_matching_slices: 3
          }
          slicing_details {
            slicing_spec {
              feature_keys: ["x"]
            }
            num_matching_slices: 1
          }
        }""", validation_result_pb2.ValidationResult())

        metrics_validator.merge_details(a, b)
        self.assertProtoEquals(expected, a)
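
The expected proto above pins down the merge semantics: details whose
slicing_spec matches an existing entry have their num_matching_slices summed,
while unmatched details are appended in order. A minimal sketch of that
behavior (a hypothetical reimplementation keyed by serialized slicing_spec,
not the actual metrics_validator.merge_details):

def merge_details_sketch(a, b):
    """Merges b's validation details into a, as exercised by the test above."""
    # Key existing details by serialized slicing_spec (safe within one process).
    existing = {
        d.slicing_spec.SerializeToString(): d
        for d in a.validation_details.slicing_details
    }
    for detail in b.validation_details.slicing_details:
        key = detail.slicing_spec.SerializeToString()
        if key in existing:
            # Same slicing spec: accumulate the match counts (1 + 2 = 3 above).
            existing[key].num_matching_slices += detail.num_matching_slices
        else:
            # New slicing spec: append as-is.
            a.validation_details.slicing_details.add().CopyFrom(detail)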
Example #2
 def extract_output(
     self, accumulator: 'Optional[validation_result_pb2.ValidationResult]'
 ) -> 'Optional[validation_result_pb2.ValidationResult]':
     # Verification fails if there is empty input.
     if not accumulator:
         accumulator = validation_result_pb2.ValidationResult(
             validation_ok=False)
     thresholds = metric_specs.metric_thresholds_from_metrics_specs(
         self._eval_config.metrics_specs)
     if not thresholds:
          # With no thresholds configured, validation only passes when
          # rubber stamping.
          accumulator.validation_ok = self._rubber_stamp
          # Flag the missing thresholds when not rubber stamping.
          accumulator.missing_thresholds = not self._rubber_stamp
     missing = metrics_validator.get_missing_slices(
         accumulator.validation_details.slicing_details, self._eval_config)
     if missing:
         missing_slices = []
         missing_cross_slices = []
         for m in missing:
             if isinstance(m, config.SlicingSpec):
                 missing_slices.append(m)
             elif isinstance(m, config.CrossSlicingSpec):
                 missing_cross_slices.append(m)
         accumulator.validation_ok = False
         if missing_slices:
             accumulator.missing_slices.extend(missing_slices)
         if missing_cross_slices:
             accumulator.missing_cross_slices.extend(missing_cross_slices)
     if self._rubber_stamp:
         accumulator.rubber_stamp = True
     return accumulator
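
With no thresholds configured, the branch above reduces to a two-row truth
table driven solely by the rubber-stamp flag. A standalone sketch of just that
branch (plain bools, no protos involved):

def no_thresholds_outcome(rubber_stamp):
    # Mirrors the branch above: validation passes only when rubber stamping,
    # and missing thresholds are flagged only when not rubber stamping.
    validation_ok = rubber_stamp
    missing_thresholds = not rubber_stamp
    return validation_ok, missing_thresholds

assert no_thresholds_outcome(True) == (True, False)
assert no_thresholds_outcome(False) == (False, True)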
 def testValidateMetricsMetricTDistributionValueAndThreshold(
         self, slicing_specs, slice_key):
     threshold = config.MetricThreshold(
         value_threshold=config.GenericValueThreshold(
             lower_bound={'value': 0.9}))
     eval_config = config.EvalConfig(
         model_specs=[
             config.ModelSpec(),
         ],
         slicing_specs=slicing_specs,
         metrics_specs=[
             config.MetricsSpec(metrics=[
                 config.MetricConfig(
                     class_name='AUC',
                     threshold=threshold if slicing_specs is None else None,
                     per_slice_thresholds=[
                         config.PerSliceMetricThreshold(
                             slicing_specs=slicing_specs,
                             threshold=threshold)
                     ]),
             ],
                                model_names=['']),
         ],
     )
     sliced_metrics = (slice_key, {
         metric_types.MetricKey(name='auc'):
         types.ValueWithTDistribution(sample_mean=0.91, unsampled_value=0.8)
     })
     result = metrics_validator.validate_metrics(sliced_metrics,
                                                 eval_config)
     self.assertFalse(result.validation_ok)
     expected = text_format.Parse(
         """
     metric_validations_per_slice {
       failures {
         metric_key {
           name: "auc"
         }
         metric_value {
           double_value {
             value: 0.8
           }
         }
       }
     }""", validation_result_pb2.ValidationResult())
     expected.metric_validations_per_slice[0].failures[
         0].metric_threshold.CopyFrom(threshold)
     expected.metric_validations_per_slice[0].slice_key.CopyFrom(
         slicer.serialize_slice_key(slice_key))
     for spec in slicing_specs or [None]:
         if (spec is None or slicer.SingleSliceSpec(
                 spec=spec).is_slice_applicable(slice_key)):
              slicing_details = expected.validation_details.slicing_details.add()
             if spec is not None:
                 slicing_details.slicing_spec.CopyFrom(spec)
             else:
                 slicing_details.slicing_spec.CopyFrom(config.SlicingSpec())
             slicing_details.num_matching_slices = 1
     self.assertEqual(result, expected)
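
Note that the expected failure records 0.8, the unsampled_value, even though
sample_mean is 0.91 and would clear the 0.9 lower bound: validate_metrics
coerces each metric with float(), which for a ValueWithTDistribution
apparently yields the unsampled point estimate. A quick check of that
assumption (same types alias as the test):

v = types.ValueWithTDistribution(sample_mean=0.91, unsampled_value=0.8)
assert float(v) == 0.8  # assumed: float() returns unsampled_value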
Example #4
 def extract_output(
     self, accumulator: 'Optional[validation_result_pb2.ValidationResult]'
 ) -> 'Optional[validation_result_pb2.ValidationResult]':
   # Verification fails if there is empty input.
   if not accumulator:
     result = validation_result_pb2.ValidationResult(validation_ok=False)
     return result
   return accumulator
    def testGetMissingSlices(self):
        slicing_specs = [
            config.SlicingSpec(),
            config.SlicingSpec(feature_values={'feature1': 'value1'}),
            config.SlicingSpec(feature_values={'feature2': 'value2'})
        ]
        threshold = config.MetricThreshold(
            value_threshold=config.GenericValueThreshold(
                upper_bound={'value': 1}))
        eval_config = config.EvalConfig(
            model_specs=[
                config.ModelSpec(),
            ],
            slicing_specs=slicing_specs,
            metrics_specs=[
                config.MetricsSpec(
                    metrics=[
                        config.MetricConfig(
                            class_name='WeightedExampleCount',
                            # 0 < 1, OK.
                            per_slice_thresholds=[
                                config.PerSliceMetricThreshold(
                                    slicing_specs=slicing_specs,
                                    threshold=threshold)
                            ]),
                    ],
                    model_names=['']),
            ],
        )
        sliced_metrics = ((('feature1', 'value1'), ), {
            metric_types.MetricKey(name='weighted_example_count'):
            0,
        })
        result = metrics_validator.validate_metrics(sliced_metrics,
                                                    eval_config)

        expected_checks = text_format.Parse(
            """
        validation_ok: true
        validation_details {
          slicing_details {
            slicing_spec {
              feature_values {
                key: "feature1"
                value: "value1"
              }
            }
            num_matching_slices: 1
          }
        }""", validation_result_pb2.ValidationResult())

        self.assertProtoEquals(expected_checks, result)

        missing = metrics_validator.get_missing_slices(
            result.validation_details.slicing_details, eval_config)
        self.assertLen(missing, 2)
        self.assertProtoEquals(missing[0], slicing_specs[0])
        self.assertProtoEquals(missing[1], slicing_specs[2])
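
The assertions imply that get_missing_slices reports every configured slicing
spec that has no matching slicing_details entry, in config order. A
hypothetical sketch of that comparison (an illustration, not the actual
implementation):

def get_missing_slices_sketch(slicing_details, eval_config):
    # Specs actually observed, keyed by serialized proto (single-process safe).
    observed = {d.slicing_spec.SerializeToString() for d in slicing_details}
    # Configured specs that were never matched, preserving config order.
    return [
        spec for spec in eval_config.slicing_specs
        if spec.SerializeToString() not in observed
    ]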
Example #6
 def add_input(
     self, result: 'Optional[validation_result_pb2.ValidationResult]',
     new_input: 'Optional[validation_result_pb2.ValidationResult]'
 ) -> 'Optional[validation_result_pb2.ValidationResult]':
     if new_input is None:
         return None
     if result is None:
         result = validation_result_pb2.ValidationResult(validation_ok=True)
     result.validation_ok &= new_input.validation_ok
     result.metric_validations_per_slice.extend(
         new_input.metric_validations_per_slice)
     return result
 def extract_output(
     self, accumulator: 'Optional[validation_result_pb2.ValidationResult]'
 ) -> 'Optional[validation_result_pb2.ValidationResult]':
   # Verification fails if there is empty input.
   if not accumulator:
     accumulator = validation_result_pb2.ValidationResult(validation_ok=False)
   missing = metrics_validator.get_missing_slices(
       accumulator.validation_details.slicing_details, self._eval_config)
   if missing:
     accumulator.validation_ok = False
     accumulator.missing_slices.extend(missing)
   return accumulator
Example #8
 def merge_accumulators(
     self,
     accumulators: 'List[Optional[validation_result_pb2.ValidationResult]]'
 ) -> 'Optional[validation_result_pb2.ValidationResult]':
   accumulators = [accumulator for accumulator in accumulators if accumulator]
   if not accumulators:
     return None
   result = validation_result_pb2.ValidationResult(validation_ok=True)
   for new_input in accumulators:
     result.metric_validations_per_slice.extend(
         new_input.metric_validations_per_slice)
     result.validation_ok &= new_input.validation_ok
   return result
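
add_input, merge_accumulators, and extract_output follow a simple accumulator
algebra: None marks "no input yet", validation_ok values are AND-ed together,
and per-slice validations are concatenated. A standalone check of the
AND/extend step (the proto import path below is an assumption; adjust it to
wherever validation_result_pb2 lives in your install):

from tensorflow_model_analysis.proto import validation_result_pb2

ok = validation_result_pb2.ValidationResult(validation_ok=True)
bad = validation_result_pb2.ValidationResult(validation_ok=False)

merged = validation_result_pb2.ValidationResult(validation_ok=True)
for part in (ok, bad):
    merged.validation_ok &= part.validation_ok
    merged.metric_validations_per_slice.extend(part.metric_validations_per_slice)

assert not merged.validation_ok  # one failing accumulator fails the merge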
 def testValidateMetricsMetricValueAndThreshold(self):
     eval_config = config.EvalConfig(
         model_specs=[
             config.ModelSpec(),
         ],
         slicing_specs=[config.SlicingSpec()],
         metrics_specs=[
             config.MetricsSpec(
                 metrics=[
                     config.MetricConfig(
                         class_name='WeightedExampleCount',
                          # 1.5 > 1, NOT OK.
                         threshold=config.MetricThreshold(
                             value_threshold=config.GenericValueThreshold(
                                 upper_bound={'value': 1}))),
                 ],
                 model_names=['']),
         ],
     )
     sliced_metrics = ((()), {
         metric_types.MetricKey(name='weighted_example_count'):
         1.5,
     })
     result = metrics_validator.validate_metrics(sliced_metrics,
                                                 eval_config)
     self.assertFalse(result.validation_ok)
     expected = text_format.Parse(
         """
     metric_validations_per_slice {
       slice_key {
       }
       failures {
         metric_key {
           name: "weighted_example_count"
         }
         metric_threshold {
           value_threshold {
             upper_bound {
               value: 1.0
             }
           }
         }
         metric_value {
           double_value {
             value: 1.5
           }
         }
       }
     }""", validation_result_pb2.ValidationResult())
     self.assertEqual(result, expected)
 def testValidateMetricsMetricValueAndThreshold(self, slicing_specs,
                                                slice_key):
   threshold = config.MetricThreshold(
       value_threshold=config.GenericValueThreshold(upper_bound={'value': 1}))
   eval_config = config.EvalConfig(
       model_specs=[
           config.ModelSpec(),
       ],
       slicing_specs=slicing_specs,
       metrics_specs=[
           config.MetricsSpec(
               metrics=[
                   config.MetricConfig(
                       class_name='WeightedExampleCount',
                        # 1.5 > 1, NOT OK.
                       threshold=threshold if slicing_specs is None else None,
                       per_slice_thresholds=[
                           config.PerSliceMetricThreshold(
                               slicing_specs=slicing_specs,
                               threshold=threshold)
                       ]),
               ],
               model_names=['']),
       ],
   )
   sliced_metrics = (slice_key, {
       metric_types.MetricKey(name='weighted_example_count'): 1.5,
   })
   result = metrics_validator.validate_metrics(sliced_metrics, eval_config)
   self.assertFalse(result.validation_ok)
   expected = text_format.Parse(
       """
       metric_validations_per_slice {
         failures {
           metric_key {
             name: "weighted_example_count"
           }
           metric_value {
             double_value {
               value: 1.5
             }
           }
         }
       }""", validation_result_pb2.ValidationResult())
   expected.metric_validations_per_slice[0].failures[
       0].metric_threshold.CopyFrom(threshold)
   expected.metric_validations_per_slice[0].slice_key.CopyFrom(
       slicer.serialize_slice_key(slice_key))
   self.assertEqual(result, expected)
 def merge_accumulators(
     self,
     accumulators: 'Iterable[Optional[validation_result_pb2.ValidationResult]]'
 ) -> 'Optional[validation_result_pb2.ValidationResult]':
   it = iter(accumulators)
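    # Assumes a non-empty iterable; next() would raise StopIteration otherwise.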
   result = next(it)
   for new_input in it:
     if new_input is None:
       continue
     if result is None:
       result = validation_result_pb2.ValidationResult(validation_ok=True)
     result.metric_validations_per_slice.extend(
         new_input.metric_validations_per_slice)
     metrics_validator.merge_details(result, new_input)
     result.validation_ok &= new_input.validation_ok
   return result
Example #12
 def testValidateMetricsInvalidThreshold(self):
     eval_config = config.EvalConfig(
         model_specs=[
             config.ModelSpec(),
         ],
         slicing_specs=[config.SlicingSpec()],
         metrics_specs=[
             config.MetricsSpec(
                 thresholds={
                     'invalid_threshold':
                     config.MetricThreshold(
                         value_threshold=config.GenericValueThreshold(
                             lower_bound={'value': 0.2}))
                 })
         ],
     )
     sliced_metrics = ((()), {
         metric_types.MetricKey(name='weighted_example_count'):
         1.5,
     })
     result = metrics_validator.validate_metrics(sliced_metrics,
                                                 eval_config)
     self.assertFalse(result.validation_ok)
     expected = text_format.Parse(
         """
     metric_validations_per_slice {
       slice_key {
       }
       failures {
         metric_key {
           name: "invalid_threshold"
         }
         metric_threshold {
           value_threshold {
             lower_bound {
               value: 0.2
             }
           }
         }
         message: 'Metric not found.'
       }
     }""", validation_result_pb2.ValidationResult())
     self.assertProtoEquals(expected, result)
Example #13
 def testValidateMetricsMetricTDistributionChangeAndThreshold(
         self, slicing_specs, slice_key):
     threshold = config.MetricThreshold(
         change_threshold=config.GenericChangeThreshold(
             direction=config.MetricDirection.LOWER_IS_BETTER,
             absolute={'value': -1}))
     eval_config = config.EvalConfig(
         model_specs=[
             config.ModelSpec(),
             config.ModelSpec(name='baseline', is_baseline=True)
         ],
         slicing_specs=slicing_specs,
         metrics_specs=[
             config.MetricsSpec(metrics=[
                 config.MetricConfig(
                     class_name='AUC',
                     threshold=threshold if slicing_specs is None else None,
                     per_slice_thresholds=[
                         config.PerSliceMetricThreshold(
                             slicing_specs=slicing_specs,
                             threshold=threshold)
                     ]),
             ],
                                model_names=['']),
         ],
     )
     sliced_metrics = (
         slice_key,
         {
             # This is the mean of the diff.
             metric_types.MetricKey(name='auc', model_name='baseline'):
             types.ValueWithTDistribution(sample_mean=0.91,
                                          unsampled_value=0.6),
             metric_types.MetricKey(name='auc', is_diff=True):
             types.ValueWithTDistribution(sample_mean=0.1,
                                          unsampled_value=0.1),
         })
     result = metrics_validator.validate_metrics(sliced_metrics,
                                                 eval_config)
     self.assertFalse(result.validation_ok)
     expected = text_format.Parse(
         """
     metric_validations_per_slice {
       failures {
         metric_key {
           name: "auc"
           is_diff: true
         }
         metric_value {
           double_value {
             value: 0.1
           }
         }
       }
     }""", validation_result_pb2.ValidationResult())
     expected.metric_validations_per_slice[0].failures[
         0].metric_threshold.CopyFrom(threshold)
     expected.metric_validations_per_slice[0].slice_key.CopyFrom(
         slicer.serialize_slice_key(slice_key))
     for spec in slicing_specs or [None]:
         if (spec is None or slicer.SingleSliceSpec(
                 spec=spec).is_slice_applicable(slice_key)):
              slicing_details = expected.validation_details.slicing_details.add()
             if spec is not None:
                 slicing_details.slicing_spec.CopyFrom(spec)
             else:
                 slicing_details.slicing_spec.CopyFrom(config.SlicingSpec())
             slicing_details.num_matching_slices = 1
      self.assertEqual(result, expected)
Example #14
def validate_metrics(
    sliced_metrics: Tuple[Union[slicer.SliceKeyType, slicer.CrossSliceKeyType],
                          Dict['metric_types.MetricKey',
                               Any]], eval_config: config_pb2.EvalConfig
) -> validation_result_pb2.ValidationResult:
    """Check the metrics and check whether they should be validated."""
    # Find out which model is baseline.
    baseline_spec = model_util.get_baseline_model_spec(eval_config)
    baseline_model_name = baseline_spec.name if baseline_spec else None

    sliced_key, metrics = sliced_metrics
    thresholds = metric_specs.metric_thresholds_from_metrics_specs(
        eval_config.metrics_specs)
    is_cross_slice = slicer.is_cross_slice_key(sliced_key)

    def _check_threshold(key: metric_types.MetricKey,
                         threshold: _ThresholdType, metric: Any) -> bool:
        """Verify a metric given its metric key and metric value."""
        metric = float(metric)
        if isinstance(threshold, config_pb2.GenericValueThreshold):
            lower_bound, upper_bound = -np.inf, np.inf
            if threshold.HasField('lower_bound'):
                lower_bound = threshold.lower_bound.value
            if threshold.HasField('upper_bound'):
                upper_bound = threshold.upper_bound.value
            return metric >= lower_bound and metric <= upper_bound
        elif isinstance(threshold, config_pb2.GenericChangeThreshold):
            diff = metric
            metric_baseline = float(
                metrics[key.make_baseline_key(baseline_model_name)])
            if math.isclose(metric_baseline, 0.0):
                ratio = float('nan')
            else:
                ratio = diff / metric_baseline
            if threshold.direction == config_pb2.MetricDirection.LOWER_IS_BETTER:
                absolute, relative = np.inf, np.inf
            elif threshold.direction == config_pb2.MetricDirection.HIGHER_IS_BETTER:
                absolute, relative = -np.inf, -np.inf
            else:
                raise ValueError(
                    '"UNKNOWN" direction for change threshold: {}.'.format(
                        threshold))
            if threshold.HasField('absolute'):
                absolute = threshold.absolute.value
            if threshold.HasField('relative'):
                relative = threshold.relative.value
            if threshold.direction == config_pb2.MetricDirection.LOWER_IS_BETTER:
                return diff <= absolute and ratio <= relative
            elif threshold.direction == config_pb2.MetricDirection.HIGHER_IS_BETTER:
                return diff >= absolute and ratio >= relative
        else:
            raise ValueError('Unknown threshold: {}'.format(threshold))

    def _copy_metric(metric, to):
        # More types will be added as additional MetricValue types are supported.
        to.double_value.value = float(metric)

    def _copy_threshold(threshold, to):
        if isinstance(threshold, config_pb2.GenericValueThreshold):
            to.value_threshold.CopyFrom(threshold)
        if isinstance(threshold, config_pb2.GenericChangeThreshold):
            to.change_threshold.CopyFrom(threshold)

    def _add_to_set(s, v):
        """Adds value to set. Returns true if didn't exist."""
        if v in s:
            return False
        else:
            s.add(v)
            return True

    # A slice with no metrics is considered valid.
    result = validation_result_pb2.ValidationResult(validation_ok=True)
    validation_for_slice = validation_result_pb2.MetricsValidationForSlice()
    unchecked_thresholds = dict(thresholds)
    for metric_key, metric in metrics.items():
        if metric_key not in thresholds:
            continue
        del unchecked_thresholds[metric_key]
        # Checking a threshold against the baseline model is not meaningful,
        # so such thresholds are skipped (treated as passing). Message-type
        # metrics are also not compared.
        if metric_key.model_name == baseline_model_name:
            continue
        msg = ''
        existing_failures = set()
        for slice_spec, threshold in thresholds[metric_key]:
            if slice_spec is not None:
                if (isinstance(slice_spec, config_pb2.SlicingSpec)
                        and (is_cross_slice or not slicer.SingleSliceSpec(
                            spec=slice_spec).is_slice_applicable(sliced_key))):
                    continue
                if (isinstance(slice_spec, config_pb2.CrossSlicingSpec)
                        and (not is_cross_slice
                             or not slicer.is_cross_slice_applicable(
                                 cross_slice_key=sliced_key,
                                 cross_slicing_spec=slice_spec))):
                    continue
            elif is_cross_slice:
                continue
            try:
                check_result = _check_threshold(metric_key, threshold, metric)
            except ValueError:
                msg = """
          Invalid metrics or threshold for comparison: The type of the metric
          is: {}, the metric value is: {}, and the threshold is: {}.
          """.format(type(metric), metric, threshold)
                check_result = False
            else:
                msg = ''
            if not check_result:
                # The same threshold values could be set for multiple matching slice
                # specs. Only store the first match.
                #
                # Note that hashing by SerializeToString() is only safe if used within
                # the same process.
                if not _add_to_set(existing_failures,
                                   threshold.SerializeToString()):
                    continue
                failure = validation_for_slice.failures.add()
                failure.metric_key.CopyFrom(metric_key.to_proto())
                _copy_metric(metric, failure.metric_value)
                _copy_threshold(threshold, failure.metric_threshold)
                failure.message = msg
            # Record that a validation check ran for this slice spec and metric.
            slicing_details = result.validation_details.slicing_details.add()
            if slice_spec is not None:
                if isinstance(slice_spec, config_pb2.SlicingSpec):
                    slicing_details.slicing_spec.CopyFrom(slice_spec)
                else:
                    slicing_details.cross_slicing_spec.CopyFrom(slice_spec)
            else:
                slicing_details.slicing_spec.CopyFrom(config_pb2.SlicingSpec())
            slicing_details.num_matching_slices = 1
    # All unchecked thresholds are considered failures.
    for metric_key, thresholds in unchecked_thresholds.items():
        if metric_key.model_name == baseline_model_name:
            continue
        existing_failures = set()
        for slice_spec, threshold in thresholds:
            if slice_spec is not None:
                if is_cross_slice != isinstance(slice_spec,
                                                config_pb2.CrossSlicingSpec):
                    continue
                if (is_cross_slice
                        and not slicer.is_cross_slice_applicable(
                            cross_slice_key=sliced_key,
                            cross_slicing_spec=slice_spec)):
                    continue
            elif is_cross_slice:
                continue
            # The same threshold values could be set for multiple matching slice
            # specs. Only store the first match.
            #
            # Note that hashing by SerializeToString() is only safe if used within
            # the same process.
            if not _add_to_set(existing_failures,
                               threshold.SerializeToString()):
                continue
            failure = validation_for_slice.failures.add()
            failure.metric_key.CopyFrom(metric_key.to_proto())
            _copy_threshold(threshold, failure.metric_threshold)
            failure.message = 'Metric not found.'
    # Any failure leads to overall failure.
    if validation_for_slice.failures:
        if not is_cross_slice:
            validation_for_slice.slice_key.CopyFrom(
                slicer.serialize_slice_key(sliced_key))
        else:
            validation_for_slice.cross_slice_key.CopyFrom(
                slicer.serialize_cross_slice_key(sliced_key))
        result.validation_ok = False
        result.metric_validations_per_slice.append(validation_for_slice)
    return result
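
A minimal usage sketch, mirroring the tests earlier on this page (it assumes
the same module aliases as the function above; the values are illustrative):

eval_config = config_pb2.EvalConfig(
    model_specs=[config_pb2.ModelSpec()],
    slicing_specs=[config_pb2.SlicingSpec()],
    metrics_specs=[
        config_pb2.MetricsSpec(
            model_names=[''],
            metrics=[
                config_pb2.MetricConfig(
                    class_name='WeightedExampleCount',
                    threshold=config_pb2.MetricThreshold(
                        value_threshold=config_pb2.GenericValueThreshold(
                            upper_bound={'value': 1})))
            ])
    ])
# Overall slice (empty key) with a metric that violates the upper bound.
sliced_metrics = ((), {
    metric_types.MetricKey(name='weighted_example_count'): 1.5,
})
result = validate_metrics(sliced_metrics, eval_config)
assert not result.validation_ok  # 1.5 > 1 fails the value threshold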
Example #15
def validate_metrics(
        sliced_metrics: Tuple[slicer.SliceKeyType,
                              Dict[metric_types.MetricKey,
                                   Any]], eval_config: config.EvalConfig
) -> validation_result_pb2.ValidationResult:
    """Check the metrics and check whether they should be validated."""
    # Find out which model is baseline.
    baseline_spec = model_util.get_baseline_model_spec(eval_config)
    baseline_model_name = baseline_spec.name if baseline_spec else None

    sliced_key, metrics = sliced_metrics
    thresholds = metric_specs.metric_thresholds_from_metrics_specs(
        eval_config.metrics_specs)

    def _check_threshold(key: metric_types.MetricKey, metric: Any) -> bool:
        """Verify a metric given its metric key and metric value."""
        threshold = thresholds[key]
        if isinstance(threshold, config.GenericValueThreshold):
            lower_bound, upper_bound = -np.inf, np.inf
            if threshold.HasField('lower_bound'):
                lower_bound = threshold.lower_bound.value
            if threshold.HasField('upper_bound'):
                upper_bound = threshold.upper_bound.value
            return metric > lower_bound and metric < upper_bound
        elif isinstance(threshold, config.GenericChangeThreshold):
            diff = metric
            ratio = diff / metrics[key.make_baseline_key(baseline_model_name)]
            if threshold.direction == config.MetricDirection.LOWER_IS_BETTER:
                absolute, relative = np.inf, np.inf
            elif threshold.direction == config.MetricDirection.HIGHER_IS_BETTER:
                absolute, relative = -np.inf, -np.inf
            else:
                raise ValueError('"UNKNOWN" direction for change threshold.')
            if threshold.HasField('absolute'):
                absolute = threshold.absolute.value
            if threshold.HasField('relative'):
                relative = threshold.relative.value
            if threshold.direction == config.MetricDirection.LOWER_IS_BETTER:
                return diff < absolute and ratio < relative
            elif threshold.direction == config.MetricDirection.HIGHER_IS_BETTER:
                return diff > absolute and ratio > relative

    def _copy_metric(metric, to):
        # More types will be added as additional MetricValue types are supported.
        to.double_value.value = float(metric)

    def _copy_threshold(threshold, to):
        if isinstance(threshold, config.GenericValueThreshold):
            to.value_threshold.CopyFrom(threshold)
        if isinstance(threshold, config.GenericChangeThreshold):
            to.change_threshold.CopyFrom(threshold)

    # A slice with no metrics is considered valid.
    result = validation_result_pb2.ValidationResult(validation_ok=True)
    validation_for_slice = validation_result_pb2.MetricsValidationForSlice()
    for metric_key, metric in metrics.items():
        # Checking a threshold against the baseline model is not meaningful,
        # so such thresholds are skipped (treated as passing). Message-type
        # metrics are also not compared.
        if (metric_key.model_name == baseline_model_name
                or metric_key not in thresholds):
            continue
        msg = ''
        # We try to convert to float values.
        try:
            metric = float(metric)
        except (TypeError, ValueError):
            msg = """
        Invalid threshold config: This metric is not comparable to the
        threshold. The type of the threshold is: {}, and the metric value is:
        \n{}""".format(type(metric), metric)
        if not _check_threshold(metric_key, metric):
            failure = validation_for_slice.failures.add()
            failure.metric_key.CopyFrom(metric_key.to_proto())
            _copy_metric(metric, failure.metric_value)
            _copy_threshold(thresholds[metric_key], failure.metric_threshold)
            failure.message = msg
    # Any failure leads to overall failure.
    if validation_for_slice.failures:
        validation_for_slice.slice_key.CopyFrom(
            slicer.serialize_slice_key(sliced_key))
        result.validation_ok = False
        result.metric_validations_per_slice.append(validation_for_slice)
    return result
def validate_metrics(
        sliced_metrics: Tuple[slicer.SliceKeyType,
                              Dict['metric_types.MetricKey',
                                   Any]], eval_config: config.EvalConfig
) -> validation_result_pb2.ValidationResult:
    """Check the metrics and check whether they should be validated."""
    # Find out which model is baseline.
    baseline_spec = model_util.get_baseline_model_spec(eval_config)
    baseline_model_name = baseline_spec.name if baseline_spec else None

    sliced_key, metrics = sliced_metrics
    thresholds = metric_specs.metric_thresholds_from_metrics_specs(
        eval_config.metrics_specs)  # pytype: disable=wrong-arg-types

    def _check_threshold(key: metric_types.MetricKey,
                         slicing_spec: Optional[config.SlicingSpec],
                         threshold: _ThresholdType, metric: Any) -> bool:
        """Verify a metric given its metric key and metric value."""
        if (slicing_spec is not None and not slicer.SingleSliceSpec(
                spec=slicing_spec).is_slice_applicable(sliced_key)):
            return True

        if isinstance(threshold, config.GenericValueThreshold):
            lower_bound, upper_bound = -np.inf, np.inf
            if threshold.HasField('lower_bound'):
                lower_bound = threshold.lower_bound.value
            if threshold.HasField('upper_bound'):
                upper_bound = threshold.upper_bound.value
            return metric > lower_bound and metric < upper_bound
        elif isinstance(threshold, config.GenericChangeThreshold):
            diff = metric
            ratio = diff / metrics[key.make_baseline_key(baseline_model_name)]
            if threshold.direction == config.MetricDirection.LOWER_IS_BETTER:
                absolute, relative = np.inf, np.inf
            elif threshold.direction == config.MetricDirection.HIGHER_IS_BETTER:
                absolute, relative = -np.inf, -np.inf
            else:
                raise ValueError('"UNKNOWN" direction for change threshold.')
            if threshold.HasField('absolute'):
                absolute = threshold.absolute.value
            if threshold.HasField('relative'):
                relative = threshold.relative.value
            if threshold.direction == config.MetricDirection.LOWER_IS_BETTER:
                return diff < absolute and ratio < relative
            elif threshold.direction == config.MetricDirection.HIGHER_IS_BETTER:
                return diff > absolute and ratio > relative

    def _copy_metric(metric, to):
        # More types will be added as additional MetricValue types are supported.
        to.double_value.value = float(metric)

    def _copy_threshold(threshold, to):
        if isinstance(threshold, config.GenericValueThreshold):
            to.value_threshold.CopyFrom(threshold)
        if isinstance(threshold, config.GenericChangeThreshold):
            to.change_threshold.CopyFrom(threshold)

    def _add_to_set(s, v):
        """Adds value to set. Returns true if didn't exist."""
        if v in s:
            return False
        else:
            s.add(v)
            return True

    # A slice with no metrics is considered valid.
    result = validation_result_pb2.ValidationResult(validation_ok=True)
    validation_for_slice = validation_result_pb2.MetricsValidationForSlice()
    unchecked_thresholds = dict(thresholds)
    for metric_key, metric in metrics.items():
        if metric_key not in thresholds:
            continue
        del unchecked_thresholds[metric_key]
        # Checking a threshold against the baseline model is not meaningful,
        # so such thresholds are skipped (treated as passing). Message-type
        # metrics are also not compared.
        if metric_key.model_name == baseline_model_name:
            continue
        msg = ''
        # We try to convert to float values.
        try:
            metric = float(metric)
        except (TypeError, ValueError):
            msg = """
        Invalid threshold config: This metric is not comparable to the
        threshold. The type of the threshold is: {}, and the metric value is:
        \n{}""".format(type(metric), metric)
        existing_failures = set()
        for slice_spec, threshold in thresholds[metric_key]:
            if not _check_threshold(metric_key, slice_spec, threshold, metric):
                # The same threshold values could be set for multiple matching slice
                # specs. Only store the first match.
                #
                # Note that hashing by SerializeToString() is only safe if used within
                # the same process.
                if not _add_to_set(existing_failures,
                                   threshold.SerializeToString()):
                    continue
                failure = validation_for_slice.failures.add()
                failure.metric_key.CopyFrom(metric_key.to_proto())
                _copy_metric(metric, failure.metric_value)
                _copy_threshold(threshold, failure.metric_threshold)
                failure.message = msg
    # All unchecked thresholds are considered failures.
    for metric_key, thresholds in unchecked_thresholds.items():
        if metric_key.model_name == baseline_model_name:
            continue
        existing_failures = set()
        for _, threshold in thresholds:
            # The same threshold values could be set for multiple matching slice
            # specs. Only store the first match.
            #
            # Note that hashing by SerializeToString() is only safe if used within
            # the same process.
            if not _add_to_set(existing_failures,
                               threshold.SerializeToString()):
                continue
            failure = validation_for_slice.failures.add()
            failure.metric_key.CopyFrom(metric_key.to_proto())
            _copy_threshold(threshold, failure.metric_threshold)
            failure.message = 'Metric not found.'
    # Any failure leads to overall failure.
    if validation_for_slice.failures:
        validation_for_slice.slice_key.CopyFrom(
            slicer.serialize_slice_key(sliced_key))
        result.validation_ok = False
        result.metric_validations_per_slice.append(validation_for_slice)
    return result
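
For the change-threshold branch, the check reduces to two comparisons against
the configured bounds. A worked example using the numbers from the
T-distribution change test above (LOWER_IS_BETTER with absolute = -1, relative
unset):

diff = 0.1               # the *_diff metric: candidate minus baseline
baseline = 0.6           # baseline metric value
ratio = diff / baseline  # relative change, about 0.167
# LOWER_IS_BETTER defaults both bounds to +inf, then applies the set fields.
absolute = -1.0          # from absolute={'value': -1}
relative = float('inf')  # relative bound left unset
passes = diff < absolute and ratio < relative
assert not passes  # 0.1 is not < -1, so validation fails, matching the test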