Example #1
0
 def extract_output(
     self, accumulator: 'Optional[validation_result_pb2.ValidationResult]'
 ) -> 'Optional[validation_result_pb2.ValidationResult]':
     """Adds result-level validation outcomes to `accumulator` and returns it.

     The result is updated in place. Three outcomes are recorded: missing
     thresholds (when none are configured), slices referenced by thresholds
     but absent from the validation details, and the rubber-stamp flag.

     Args:
       accumulator: The ValidationResult to finalize, or a falsy value (e.g.
         None) when there was no input.

     Returns:
       The updated ValidationResult. A new, failing one is created when
       `accumulator` is falsy.
     """
     result = accumulator
     if not result:
         # Verification fails if there is empty input.
         result = validation_result_pb2.ValidationResult(validation_ok=False)

     configured = metric_specs.metric_thresholds_from_metrics_specs(
         self._eval_config.metrics_specs)
     if not configured:
         # With no thresholds, the outcome depends only on rubber stamping.
         # Rubber stamping passes; otherwise the thresholds are reported
         # as missing.
         result.validation_ok = self._rubber_stamp
         result.missing_thresholds = not self._rubber_stamp

     unvalidated = metrics_validator.get_missing_slices(
         result.validation_details.slicing_details, self._eval_config)
     if unvalidated:
         single_specs = [
             spec for spec in unvalidated
             if isinstance(spec, config.SlicingSpec)
         ]
         cross_specs = [
             spec for spec in unvalidated
             if isinstance(spec, config.CrossSlicingSpec)
         ]
         result.validation_ok = False
         if single_specs:
             result.missing_slices.extend(single_specs)
         if cross_specs:
             result.missing_cross_slices.extend(cross_specs)

     if self._rubber_stamp:
         result.rubber_stamp = True
     return result
Example #2
0
def get_missing_slices(
    slicing_details: Iterable[validation_result_pb2.SlicingDetails],
    eval_config: config_pb2.EvalConfig
) -> List[Union[config_pb2.SlicingSpec, config_pb2.CrossSlicingSpec]]:
    """Returns specs that are defined in the EvalConfig but not found in details.

    A spec is missing when a threshold of a non-baseline model refers to it
    and its serialized form is not among the keys produced from
    `slicing_details`. A threshold without a slice spec is treated as the
    overall (empty) SlicingSpec. Each missing spec is reported only once.

    Args:
      slicing_details: Slicing details.
      eval_config: Eval config.

    Returns:
      List of missing slices or empty list if none are missing.
    """
    seen = _hashed_slicing_details(slicing_details)
    all_thresholds = metric_specs.metric_thresholds_from_metrics_specs(
        eval_config.metrics_specs)
    baseline = model_util.get_baseline_model_spec(eval_config)
    baseline_name = baseline.name if baseline else None

    missing = []
    for key, per_slice in all_thresholds.items():
        # Thresholds on the baseline model are never checked, so skip them.
        if key.model_name == baseline_name:
            continue
        for spec, _ in per_slice:
            spec = spec or config_pb2.SlicingSpec()
            digest = spec.SerializeToString()
            if digest in seen:
                continue
            missing.append(spec)
            # Mark the slice as seen; several metrics/thresholds may share
            # it, and it should only be reported once.
            seen[digest] = validation_result_pb2.SlicingDetails()
    return missing
Example #3
0
 def testMetricThresholdsFromMetricsSpecs(self):
     """Checks the MetricKeys derived from thresholds in several MetricsSpecs.

     Covers:
     - thresholds declared directly on a spec;
     - thresholds attached to MetricConfigs;
     - a model independent metric (ExampleCount);
     - metrics expanded by binarization and macro averaging.
     """
     metrics_specs = [
         config.MetricsSpec(
             thresholds={
                 'auc':
                 config.MetricThreshold(
                     value_threshold=config.GenericValueThreshold()),
                 'mean/label':
                 config.MetricThreshold(
                     value_threshold=config.GenericValueThreshold(),
                     change_threshold=config.GenericChangeThreshold()),
                 # The mse metric will be overridden by MetricConfig below.
                 'mse':
                 config.MetricThreshold(
                     change_threshold=config.GenericChangeThreshold())
             },
             # These thresholds are keyed by this spec's model and output
             # name (see the expected 'auc' and 'mean/label' keys below).
             model_names=['model_name'],
             output_names=['output_name']),
         config.MetricsSpec(
             metrics=[
                 config.MetricConfig(
                     class_name='ExampleCount',
                     config=json.dumps({'name': 'example_count'}),
                     threshold=config.MetricThreshold(
                         value_threshold=config.GenericValueThreshold()))
             ],
             # Model names and output_names should be ignored because
             # ExampleCount is model independent.
             model_names=['model_name1', 'model_name2'],
             output_names=['output_name1', 'output_name2']),
         config.MetricsSpec(
             metrics=[
                 config.MetricConfig(
                     class_name='WeightedExampleCount',
                     config=json.dumps({'name': 'weighted_example_count'}),
                     threshold=config.MetricThreshold(
                         value_threshold=config.GenericValueThreshold()))
             ],
             model_names=['model_name1', 'model_name2'],
             output_names=['output_name1', 'output_name2']),
         config.MetricsSpec(
             metrics=[
                 config.MetricConfig(
                     class_name='MeanSquaredError',
                     config=json.dumps({'name': 'mse'}),
                     threshold=config.MetricThreshold(
                         change_threshold=config.GenericChangeThreshold())),
                 config.MetricConfig(
                     class_name='MeanLabel',
                     config=json.dumps({'name': 'mean_label'}),
                     threshold=config.MetricThreshold(
                         change_threshold=config.GenericChangeThreshold()))
             ],
             model_names=['model_name'],
             output_names=['output_name'],
             binarize=config.BinarizationOptions(
                 class_ids={'values': [0, 1]}),
             aggregate=config.AggregationOptions(macro_average=True))
     ]
     expected_keys = [
         metric_types.MetricKey(name='auc',
                                model_name='model_name',
                                output_name='output_name'),
         metric_types.MetricKey(name='mean/label',
                                model_name='model_name',
                                output_name='output_name',
                                is_diff=True),
         metric_types.MetricKey(name='mean/label',
                                model_name='model_name',
                                output_name='output_name',
                                is_diff=False),
         # Model independent: a single key without model or output name.
         metric_types.MetricKey(name='example_count'),
         # WeightedExampleCount gets one key per (model, output) pair.
         metric_types.MetricKey(name='weighted_example_count',
                                model_name='model_name1',
                                output_name='output_name1'),
         metric_types.MetricKey(name='weighted_example_count',
                                model_name='model_name1',
                                output_name='output_name2'),
         metric_types.MetricKey(name='weighted_example_count',
                                model_name='model_name2',
                                output_name='output_name1'),
         metric_types.MetricKey(name='weighted_example_count',
                                model_name='model_name2',
                                output_name='output_name2'),
         # Binarized per class id, plus the key without a sub key.
         metric_types.MetricKey(name='mse',
                                model_name='model_name',
                                output_name='output_name',
                                sub_key=metric_types.SubKey(class_id=0),
                                is_diff=True),
         metric_types.MetricKey(name='mse',
                                model_name='model_name',
                                output_name='output_name',
                                sub_key=metric_types.SubKey(class_id=1),
                                is_diff=True),
         metric_types.MetricKey(name='mse',
                                model_name='model_name',
                                output_name='output_name',
                                is_diff=True),
         metric_types.MetricKey(name='mean_label',
                                model_name='model_name',
                                output_name='output_name',
                                sub_key=metric_types.SubKey(class_id=0),
                                is_diff=True),
         metric_types.MetricKey(name='mean_label',
                                model_name='model_name',
                                output_name='output_name',
                                sub_key=metric_types.SubKey(class_id=1),
                                is_diff=True),
         metric_types.MetricKey(name='mean_label',
                                model_name='model_name',
                                output_name='output_name',
                                is_diff=True),
     ]
     thresholds = metric_specs.metric_thresholds_from_metrics_specs(
         metrics_specs)
     # Tie the size check to the expected keys so the two cannot drift apart.
     self.assertLen(thresholds, len(expected_keys))
     for key in expected_keys:
         self.assertIn(key, thresholds)
Example #4
0
def validate_metrics(
    sliced_metrics: Tuple[Union[slicer.SliceKeyType, slicer.CrossSliceKeyType],
                          Dict['metric_types.MetricKey',
                               Any]], eval_config: config_pb2.EvalConfig
) -> validation_result_pb2.ValidationResult:
    """Check the metrics and check whether they should be validated.

    Each threshold in `eval_config` whose slice spec applies to this slice is
    checked against its metric. An applicable threshold whose metric is
    absent from `metrics` is reported as a 'Metric not found.' failure.
    Thresholds on the baseline model are never checked.

    Args:
      sliced_metrics: A (slice key, metrics) pair. The key may be a single
        slice key or a cross-slice key; metrics maps MetricKey to its value.
      eval_config: Eval config whose metrics_specs define the thresholds.

    Returns:
      A ValidationResult. If any failure was recorded, validation_ok is False
      and this slice's failures are appended to metric_validations_per_slice.
      validation_details.slicing_details gets one entry per threshold that
      was checked. Failures that duplicate an earlier failure's threshold are
      not recorded.
    """
    # Find out which model is baseline.
    baseline_spec = model_util.get_baseline_model_spec(eval_config)
    baseline_model_name = baseline_spec.name if baseline_spec else None

    sliced_key, metrics = sliced_metrics
    thresholds = metric_specs.metric_thresholds_from_metrics_specs(
        eval_config.metrics_specs)
    is_cross_slice = slicer.is_cross_slice_key(sliced_key)

    def _is_applicable(slice_spec) -> bool:
        """Returns whether a threshold's slice spec applies to `sliced_key`.

        - A missing spec (None) applies to every single slice and to no
          cross slice.
        - A SlicingSpec applies only to single slices it matches.
        - A CrossSlicingSpec applies only to cross slices it matches.
        """
        if slice_spec is None:
            return not is_cross_slice
        if isinstance(slice_spec, config_pb2.SlicingSpec):
            return not is_cross_slice and slicer.SingleSliceSpec(
                spec=slice_spec).is_slice_applicable(sliced_key)
        if isinstance(slice_spec, config_pb2.CrossSlicingSpec):
            return is_cross_slice and slicer.is_cross_slice_applicable(
                cross_slice_key=sliced_key, cross_slicing_spec=slice_spec)
        return True

    def _check_threshold(key: metric_types.MetricKey,
                         threshold: _ThresholdType, metric: Any) -> bool:
        """Verify a metric given its metric key and metric value.

        Raises:
          TypeError: If the metric (or its baseline) is not float-convertible,
            e.g. a proto or a multi-element array.
          ValueError: If the metric is an unparsable value, or the threshold
            type or change direction is unknown.
        """
        metric = float(metric)
        if isinstance(threshold, config_pb2.GenericValueThreshold):
            lower_bound, upper_bound = -np.inf, np.inf
            if threshold.HasField('lower_bound'):
                lower_bound = threshold.lower_bound.value
            if threshold.HasField('upper_bound'):
                upper_bound = threshold.upper_bound.value
            return metric >= lower_bound and metric <= upper_bound
        elif isinstance(threshold, config_pb2.GenericChangeThreshold):
            diff = metric
            metric_baseline = float(
                metrics[key.make_baseline_key(baseline_model_name)])
            # A zero baseline gives a NaN ratio. That makes the ratio
            # comparisons below False instead of raising ZeroDivisionError.
            if math.isclose(metric_baseline, 0.0):
                ratio = float('nan')
            else:
                ratio = diff / metric_baseline
            if threshold.direction == config_pb2.MetricDirection.LOWER_IS_BETTER:
                absolute, relative = np.inf, np.inf
            elif threshold.direction == config_pb2.MetricDirection.HIGHER_IS_BETTER:
                absolute, relative = -np.inf, -np.inf
            else:
                raise ValueError(
                    '"UNKNOWN" direction for change threshold: {}.'.format(
                        threshold))
            if threshold.HasField('absolute'):
                absolute = threshold.absolute.value
            if threshold.HasField('relative'):
                relative = threshold.relative.value
            if threshold.direction == config_pb2.MetricDirection.LOWER_IS_BETTER:
                return diff <= absolute and ratio <= relative
            elif threshold.direction == config_pb2.MetricDirection.HIGHER_IS_BETTER:
                return diff >= absolute and ratio >= relative
        else:
            raise ValueError('Unknown threshold: {}'.format(threshold))

    def _copy_metric(metric, to):
        """Copies `metric` into MetricValue `to` when it is float-convertible.

        Non-numeric metrics are left unset. This function runs on the failure
        path, and re-raising there would abort validation instead of
        recording the failure.
        """
        # Will add more types when more MetricValue are supported.
        try:
            value = float(metric)
        except (TypeError, ValueError):
            return
        to.double_value.value = value

    def _copy_threshold(threshold, to):
        if isinstance(threshold, config_pb2.GenericValueThreshold):
            to.value_threshold.CopyFrom(threshold)
        if isinstance(threshold, config_pb2.GenericChangeThreshold):
            to.change_threshold.CopyFrom(threshold)

    def _add_to_set(s, v):
        """Adds value to set. Returns true if didn't exist."""
        if v in s:
            return False
        else:
            s.add(v)
            return True

    # Empty metrics per slice is considered validated.
    result = validation_result_pb2.ValidationResult(validation_ok=True)
    validation_for_slice = validation_result_pb2.MetricsValidationForSlice()
    unchecked_thresholds = dict(thresholds)
    for metric_key, metric in metrics.items():
        if metric_key not in thresholds:
            continue
        del unchecked_thresholds[metric_key]
        # Not meaningful to check threshold for baseline model, thus always
        # return True if such threshold is configured.
        if metric_key.model_name == baseline_model_name:
            continue
        existing_failures = set()
        for slice_spec, threshold in thresholds[metric_key]:
            if not _is_applicable(slice_spec):
                continue
            try:
                check_result = _check_threshold(metric_key, threshold, metric)
            except (TypeError, ValueError):
                # float() raises TypeError (not ValueError) for protos and
                # multi-element arrays; both are recorded as failures.
                msg = """
          Invalid metrics or threshold for comparison: The type of the metric
          is: {}, the metric value is: {}, and the threshold is: {}.
          """.format(type(metric), metric, threshold)
                check_result = False
            else:
                msg = ''
            if not check_result:
                # The same threshold values could be set for multiple matching
                # slice specs. Only store the first match.
                #
                # Note that hashing by SerializeToString() is only safe if used
                # within the same process.
                if not _add_to_set(existing_failures,
                                   threshold.SerializeToString()):
                    continue
                failure = validation_for_slice.failures.add()
                failure.metric_key.CopyFrom(metric_key.to_proto())
                _copy_metric(metric, failure.metric_value)
                _copy_threshold(threshold, failure.metric_threshold)
                failure.message = msg
            # Track we have completed a validation check for slice spec and
            # metric.
            slicing_details = result.validation_details.slicing_details.add()
            if slice_spec is not None:
                if isinstance(slice_spec, config_pb2.SlicingSpec):
                    slicing_details.slicing_spec.CopyFrom(slice_spec)
                else:
                    slicing_details.cross_slicing_spec.CopyFrom(slice_spec)
            else:
                slicing_details.slicing_spec.CopyFrom(config_pb2.SlicingSpec())
            slicing_details.num_matching_slices = 1
    # All unchecked thresholds that apply to this slice are considered
    # failures. Thresholds for other slices are not this slice's concern.
    for metric_key, sliced_thresholds in unchecked_thresholds.items():
        if metric_key.model_name == baseline_model_name:
            continue
        existing_failures = set()
        for slice_spec, threshold in sliced_thresholds:
            if not _is_applicable(slice_spec):
                continue
            # The same threshold values could be set for multiple matching
            # slice specs. Only store the first match.
            #
            # Note that hashing by SerializeToString() is only safe if used
            # within the same process.
            if not _add_to_set(existing_failures,
                               threshold.SerializeToString()):
                continue
            failure = validation_for_slice.failures.add()
            failure.metric_key.CopyFrom(metric_key.to_proto())
            _copy_threshold(threshold, failure.metric_threshold)
            failure.message = 'Metric not found.'
    # Any failure leads to overall failure.
    if validation_for_slice.failures:
        if not is_cross_slice:
            validation_for_slice.slice_key.CopyFrom(
                slicer.serialize_slice_key(sliced_key))
        else:
            validation_for_slice.cross_slice_key.CopyFrom(
                slicer.serialize_cross_slice_key(sliced_key))
        result.validation_ok = False
        result.metric_validations_per_slice.append(validation_for_slice)
    return result
def validate_metrics(
        sliced_metrics: Tuple[slicer.SliceKeyType,
                              Dict[metric_types.MetricKey,
                                   Any]], eval_config: config.EvalConfig
) -> validation_result_pb2.ValidationResult:
    """Check the metrics and check whether they should be validated.

    Args:
      sliced_metrics: A (slice key, metrics) pair; metrics maps MetricKey to
        the value computed for that slice.
      eval_config: Eval config whose metrics_specs define the thresholds.

    Returns:
      A ValidationResult. validation_ok is False in two cases: a threshold is
      not met, or a metric cannot be compared with its threshold. In either
      case this slice's failures are appended to
      metric_validations_per_slice.
    """
    import math  # Local import: used only by the zero-baseline guard below.

    # Find out which model is baseline.
    baseline_spec = model_util.get_baseline_model_spec(eval_config)
    baseline_model_name = baseline_spec.name if baseline_spec else None

    sliced_key, metrics = sliced_metrics
    thresholds = metric_specs.metric_thresholds_from_metrics_specs(
        eval_config.metrics_specs)

    def _check_threshold(key: metric_types.MetricKey, metric: float) -> bool:
        """Returns whether float `metric` satisfies the threshold for `key`.

        Raises:
          TypeError: If the baseline metric is not float-convertible.
          ValueError: If the baseline metric is an unparsable value, or the
            threshold type or change direction is unknown.
        """
        threshold = thresholds[key]
        if isinstance(threshold, config.GenericValueThreshold):
            lower_bound, upper_bound = -np.inf, np.inf
            if threshold.HasField('lower_bound'):
                lower_bound = threshold.lower_bound.value
            if threshold.HasField('upper_bound'):
                upper_bound = threshold.upper_bound.value
            # Both bounds are exclusive.
            return metric > lower_bound and metric < upper_bound
        if isinstance(threshold, config.GenericChangeThreshold):
            diff = metric
            metric_baseline = float(
                metrics[key.make_baseline_key(baseline_model_name)])
            # Dividing by a zero baseline would raise ZeroDivisionError. A NaN
            # ratio instead makes the ratio comparisons below False.
            if math.isclose(metric_baseline, 0.0):
                ratio = float('nan')
            else:
                ratio = diff / metric_baseline
            if threshold.direction == config.MetricDirection.LOWER_IS_BETTER:
                absolute, relative = np.inf, np.inf
            elif threshold.direction == config.MetricDirection.HIGHER_IS_BETTER:
                absolute, relative = -np.inf, -np.inf
            else:
                raise ValueError('"UNKNOWN" direction for change threshold.')
            if threshold.HasField('absolute'):
                absolute = threshold.absolute.value
            if threshold.HasField('relative'):
                relative = threshold.relative.value
            if threshold.direction == config.MetricDirection.LOWER_IS_BETTER:
                return diff < absolute and ratio < relative
            # Direction is HIGHER_IS_BETTER (anything else raised above).
            return diff > absolute and ratio > relative
        raise ValueError('Unknown threshold: {}'.format(threshold))

    def _copy_metric(metric, to):
        # Will add more types when more MetricValue are supported.
        to.double_value.value = float(metric)

    def _copy_threshold(threshold, to):
        if isinstance(threshold, config.GenericValueThreshold):
            to.value_threshold.CopyFrom(threshold)
        if isinstance(threshold, config.GenericChangeThreshold):
            to.change_threshold.CopyFrom(threshold)

    # Empty metrics per slice is considered validated.
    result = validation_result_pb2.ValidationResult(validation_ok=True)
    validation_for_slice = validation_result_pb2.MetricsValidationForSlice()
    for metric_key, metric in metrics.items():
        # Not meaningful to check threshold for baseline model, thus always
        # return True if such threshold is configured.
        if (metric_key.model_name == baseline_model_name
                or metric_key not in thresholds):
            continue
        threshold = thresholds[metric_key]
        value = None
        try:
            value = float(metric)
        except (TypeError, ValueError):
            # Not comparable (e.g. a proto or a multi-element array). Record a
            # failure rather than letting the comparison itself raise.
            check_result = False
            msg = ('Invalid threshold config: This metric is not comparable '
                   'to the threshold. The type of the metric is: {}, and the '
                   'metric value is:\n{}').format(type(metric), metric)
        else:
            try:
                check_result = _check_threshold(metric_key, value)
                msg = ''
            except (TypeError, ValueError) as e:
                check_result = False
                msg = ('Invalid metrics or threshold for comparison: {}. The '
                       'metric value is: {}, and the threshold is: {}.'
                      ).format(e, value, threshold)
        if not check_result:
            failure = validation_for_slice.failures.add()
            failure.metric_key.CopyFrom(metric_key.to_proto())
            # Only numeric values can be stored in the MetricValue.
            if value is not None:
                _copy_metric(value, failure.metric_value)
            _copy_threshold(threshold, failure.metric_threshold)
            failure.message = msg
    # Any failure leads to overall failure.
    if validation_for_slice.failures:
        validation_for_slice.slice_key.CopyFrom(
            slicer.serialize_slice_key(sliced_key))
        result.validation_ok = False
        result.metric_validations_per_slice.append(validation_for_slice)
    return result
Example #6
0
    def testMetricThresholdsFromMetricsSpecs(self):
        """Checks thresholds built from global, per-slice and cross-slice configs.

        Each expected MetricKey is mapped to the number of thresholds
        metric_thresholds_from_metrics_specs should return for it.
        """
        slice_specs = [
            config.SlicingSpec(feature_keys=['feature1']),
            config.SlicingSpec(feature_values={'feature2': 'value1'})
        ]

        # For cross slice tests.
        baseline_slice_spec = config.SlicingSpec(feature_keys=['feature3'])

        metrics_specs = [
            config.MetricsSpec(
                thresholds={
                    'auc':
                    config.MetricThreshold(
                        value_threshold=config.GenericValueThreshold()),
                    'mean/label':
                    config.MetricThreshold(
                        value_threshold=config.GenericValueThreshold(),
                        change_threshold=config.GenericChangeThreshold()),
                    'mse':
                    config.MetricThreshold(
                        change_threshold=config.GenericChangeThreshold())
                },
                per_slice_thresholds={
                    'auc':
                    config.PerSliceMetricThresholds(thresholds=[
                        config.PerSliceMetricThreshold(
                            slicing_specs=slice_specs,
                            threshold=config.MetricThreshold(
                                value_threshold=config.GenericValueThreshold(
                                )))
                    ]),
                    'mean/label':
                    config.PerSliceMetricThresholds(thresholds=[
                        config.PerSliceMetricThreshold(
                            slicing_specs=slice_specs,
                            threshold=config.MetricThreshold(
                                value_threshold=config.GenericValueThreshold(),
                                change_threshold=config.GenericChangeThreshold(
                                )))
                    ])
                },
                cross_slice_thresholds={
                    'auc':
                    config.CrossSliceMetricThresholds(thresholds=[
                        config.CrossSliceMetricThreshold(
                            cross_slicing_specs=[
                                config.CrossSlicingSpec(
                                    baseline_spec=baseline_slice_spec,
                                    slicing_specs=slice_specs)
                            ],
                            threshold=config.MetricThreshold(
                                value_threshold=config.GenericValueThreshold(),
                                change_threshold=config.GenericChangeThreshold(
                                )))
                    ]),
                    'mse':
                    config.CrossSliceMetricThresholds(thresholds=[
                        config.CrossSliceMetricThreshold(
                            cross_slicing_specs=[
                                config.CrossSlicingSpec(
                                    baseline_spec=baseline_slice_spec,
                                    slicing_specs=slice_specs)
                            ],
                            threshold=config.MetricThreshold(
                                change_threshold=config.GenericChangeThreshold(
                                ))),
                        # Test for duplicate cross_slicing_spec.
                        config.CrossSliceMetricThreshold(
                            cross_slicing_specs=[
                                config.CrossSlicingSpec(
                                    baseline_spec=baseline_slice_spec,
                                    slicing_specs=slice_specs)
                            ],
                            threshold=config.MetricThreshold(
                                value_threshold=config.GenericValueThreshold())
                        )
                    ])
                },
                model_names=['model_name'],
                output_names=['output_name']),
            config.MetricsSpec(metrics=[
                config.MetricConfig(
                    class_name='ExampleCount',
                    config=json.dumps({'name': 'example_count'}),
                    threshold=config.MetricThreshold(
                        value_threshold=config.GenericValueThreshold()))
            ],
                               model_names=['model_name1', 'model_name2'],
                               output_names=['output_name1', 'output_name2']),
            config.MetricsSpec(metrics=[
                config.MetricConfig(
                    class_name='WeightedExampleCount',
                    config=json.dumps({'name': 'weighted_example_count'}),
                    threshold=config.MetricThreshold(
                        value_threshold=config.GenericValueThreshold()))
            ],
                               model_names=['model_name1', 'model_name2'],
                               output_names=['output_name1', 'output_name2']),
            config.MetricsSpec(metrics=[
                config.MetricConfig(
                    class_name='MeanSquaredError',
                    config=json.dumps({'name': 'mse'}),
                    threshold=config.MetricThreshold(
                        change_threshold=config.GenericChangeThreshold())),
                config.MetricConfig(
                    class_name='MeanLabel',
                    config=json.dumps({'name': 'mean_label'}),
                    threshold=config.MetricThreshold(
                        change_threshold=config.GenericChangeThreshold()),
                    per_slice_thresholds=[
                        config.PerSliceMetricThreshold(
                            slicing_specs=slice_specs,
                            threshold=config.MetricThreshold(
                                change_threshold=config.GenericChangeThreshold(
                                ))),
                    ],
                    cross_slice_thresholds=[
                        config.CrossSliceMetricThreshold(
                            cross_slicing_specs=[
                                config.CrossSlicingSpec(
                                    baseline_spec=baseline_slice_spec,
                                    slicing_specs=slice_specs)
                            ],
                            threshold=config.MetricThreshold(
                                change_threshold=config.GenericChangeThreshold(
                                )))
                    ]),
            ],
                               model_names=['model_name'],
                               output_names=['output_name'],
                               binarize=config.BinarizationOptions(
                                   class_ids={'values': [0, 1]}),
                               aggregate=config.AggregationOptions(
                                   macro_average=True,
                                   class_weights={
                                       0: 1.0,
                                       1: 1.0
                                   }))
        ]

        thresholds = metric_specs.metric_thresholds_from_metrics_specs(
            metrics_specs)

        expected_keys_and_threshold_counts = {
            # Global value threshold, one per-slice value threshold for each
            # of the two slice_specs, and the cross-slice value threshold.
            metric_types.MetricKey(name='auc',
                                   model_name='model_name',
                                   output_name='output_name',
                                   is_diff=False):
            4,
            # Only the cross-slice 'auc' threshold has a change threshold.
            metric_types.MetricKey(name='auc',
                                   model_name='model_name',
                                   output_name='output_name',
                                   is_diff=True):
            1,
            # Global change threshold plus two per-slice change thresholds.
            metric_types.MetricKey(name='mean/label',
                                   model_name='model_name',
                                   output_name='output_name',
                                   is_diff=True):
            3,
            # Global value threshold plus two per-slice value thresholds.
            metric_types.MetricKey(name='mean/label',
                                   model_name='model_name',
                                   output_name='output_name',
                                   is_diff=False):
            3,
            metric_types.MetricKey(name='example_count',
                                   model_name='model_name1',
                                   output_name='output_name1'):
            1,
            metric_types.MetricKey(name='example_count',
                                   model_name='model_name1',
                                   output_name='output_name2'):
            1,
            metric_types.MetricKey(name='example_count',
                                   model_name='model_name2',
                                   output_name='output_name1'):
            1,
            metric_types.MetricKey(name='example_count',
                                   model_name='model_name2',
                                   output_name='output_name2'):
            1,
            metric_types.MetricKey(name='weighted_example_count',
                                   model_name='model_name1',
                                   output_name='output_name1'):
            1,
            metric_types.MetricKey(name='weighted_example_count',
                                   model_name='model_name1',
                                   output_name='output_name2'):
            1,
            metric_types.MetricKey(name='weighted_example_count',
                                   model_name='model_name2',
                                   output_name='output_name1'):
            1,
            metric_types.MetricKey(name='weighted_example_count',
                                   model_name='model_name2',
                                   output_name='output_name2'):
            1,
            metric_types.MetricKey(name='mse',
                                   model_name='model_name',
                                   output_name='output_name',
                                   sub_key=metric_types.SubKey(class_id=0),
                                   is_diff=True):
            1,
            metric_types.MetricKey(name='mse',
                                   model_name='model_name',
                                   output_name='output_name',
                                   sub_key=metric_types.SubKey(class_id=1),
                                   is_diff=True):
            1,
            # Global 'mse' change threshold plus the first cross-slice one.
            metric_types.MetricKey(name='mse',
                                   model_name='model_name',
                                   output_name='output_name',
                                   is_diff=True):
            2,
            # Value threshold from the duplicate cross-slice 'mse' entry.
            metric_types.MetricKey(name='mse',
                                   model_name='model_name',
                                   output_name='output_name',
                                   is_diff=False):
            1,
            metric_types.MetricKey(name='mse',
                                   model_name='model_name',
                                   output_name='output_name',
                                   aggregation_type=metric_types.AggregationType(macro_average=True),
                                   is_diff=True):
            1,
            # Each 'mean_label' key: global change threshold, two per-slice
            # change thresholds, and one cross-slice change threshold.
            metric_types.MetricKey(name='mean_label',
                                   model_name='model_name',
                                   output_name='output_name',
                                   sub_key=metric_types.SubKey(class_id=0),
                                   is_diff=True):
            4,
            metric_types.MetricKey(name='mean_label',
                                   model_name='model_name',
                                   output_name='output_name',
                                   sub_key=metric_types.SubKey(class_id=1),
                                   is_diff=True):
            4,
            metric_types.MetricKey(name='mean_label',
                                   model_name='model_name',
                                   output_name='output_name',
                                   aggregation_type=metric_types.AggregationType(macro_average=True),
                                   is_diff=True):
            4
        }
        self.assertLen(thresholds, len(expected_keys_and_threshold_counts))
        for key, count in expected_keys_and_threshold_counts.items():
            self.assertIn(key, thresholds)
            self.assertLen(thresholds[key], count,
                           'failed for key {}'.format(key))
def validate_metrics(
        sliced_metrics: Tuple[slicer.SliceKeyType,
                              Dict['metric_types.MetricKey',
                                   Any]], eval_config: config.EvalConfig
) -> validation_result_pb2.ValidationResult:
    """Check the metrics and check whether they should be validated.

    Every threshold derived from `eval_config.metrics_specs` is checked against
    the matching metric value of the given slice. Thresholds configured on the
    baseline model are skipped. Thresholds whose metric key is absent from the
    slice's metrics are reported as failures with the message
    'Metric not found.' (regardless of the threshold's slicing spec).

    Args:
      sliced_metrics: A (slice_key, metrics) pair, where metrics maps metric
        keys to the values computed for that slice.
      eval_config: Eval config providing the metrics specs (thresholds) and the
        model specs used to identify the baseline model.

    Returns:
      A ValidationResult with validation_ok=True if no threshold failed.
      Otherwise validation_ok=False and one MetricsValidationForSlice entry
      holding all failures for this slice.

    Raises:
      ValueError: If a change threshold has an UNKNOWN direction.
    """
    # Find out which model is baseline.
    baseline_spec = model_util.get_baseline_model_spec(eval_config)
    baseline_model_name = baseline_spec.name if baseline_spec else None

    sliced_key, metrics = sliced_metrics
    thresholds = metric_specs.metric_thresholds_from_metrics_specs(
        eval_config.metrics_specs)  # pytype: disable=wrong-arg-types

    def _is_slice_applicable(
            slicing_spec: Optional[config.SlicingSpec]) -> bool:
        """Returns True if a threshold scoped to slicing_spec applies here."""
        # A threshold without a slicing spec applies to every slice.
        return slicing_spec is None or slicer.SingleSliceSpec(
            spec=slicing_spec).is_slice_applicable(sliced_key)

    def _check_threshold(key: metric_types.MetricKey,
                         threshold: _ThresholdType, metric: float) -> bool:
        """Returns True if the float metric value satisfies the threshold."""
        if isinstance(threshold, config.GenericValueThreshold):
            lower_bound, upper_bound = -np.inf, np.inf
            if threshold.HasField('lower_bound'):
                lower_bound = threshold.lower_bound.value
            if threshold.HasField('upper_bound'):
                upper_bound = threshold.upper_bound.value
            # Both bounds are exclusive.
            return lower_bound < metric < upper_bound
        if isinstance(threshold, config.GenericChangeThreshold):
            # The value stored under a change-threshold key is used directly
            # as the difference against the baseline model.
            diff = metric
            baseline = metrics[key.make_baseline_key(baseline_model_name)]
            # numpy division: a zero baseline yields +/-inf (or nan) instead
            # of raising ZeroDivisionError when the baseline is a Python float.
            with np.errstate(divide='ignore', invalid='ignore'):
                ratio = np.float64(diff) / np.float64(baseline)
            if threshold.direction == config.MetricDirection.LOWER_IS_BETTER:
                absolute, relative = np.inf, np.inf
            elif threshold.direction == config.MetricDirection.HIGHER_IS_BETTER:
                absolute, relative = -np.inf, -np.inf
            else:
                raise ValueError('"UNKNOWN" direction for change threshold.')
            if threshold.HasField('absolute'):
                absolute = threshold.absolute.value
            if threshold.HasField('relative'):
                relative = threshold.relative.value
            if threshold.direction == config.MetricDirection.LOWER_IS_BETTER:
                return diff < absolute and ratio < relative
            return diff > absolute and ratio > relative
        # Unrecognized threshold type: treated as not satisfied.
        return False

    def _copy_metric(metric, to):
        # Will add more types when more MetricValue are supported.
        to.double_value.value = float(metric)

    def _copy_threshold(threshold, to):
        if isinstance(threshold, config.GenericValueThreshold):
            to.value_threshold.CopyFrom(threshold)
        if isinstance(threshold, config.GenericChangeThreshold):
            to.change_threshold.CopyFrom(threshold)

    def _add_to_set(s, v):
        """Adds value to set. Returns true if didn't exist."""
        if v in s:
            return False
        s.add(v)
        return True

    # Empty metrics per slice is considered validated.
    result = validation_result_pb2.ValidationResult(validation_ok=True)
    validation_for_slice = validation_result_pb2.MetricsValidationForSlice()
    unchecked_thresholds = dict(thresholds)
    for metric_key, metric in metrics.items():
        if metric_key not in thresholds:
            continue
        del unchecked_thresholds[metric_key]
        # Not meaningful to check thresholds for the baseline model, so any
        # threshold configured on it is treated as passing.
        if metric_key.model_name == baseline_model_name:
            continue
        # Thresholds are compared against float values. A value that cannot be
        # converted cannot be compared (or copied into double_value), so every
        # applicable threshold is reported as a failure with an explanation.
        try:
            metric = float(metric)
        except (TypeError, ValueError):
            comparable = False
            msg = """
        Invalid threshold config: This metric is not comparable to the
        threshold. The type of the metric is: {}, and the metric value is:
        \n{}""".format(type(metric), metric)
        else:
            comparable = True
            msg = ''
        existing_failures = set()
        for slice_spec, threshold in thresholds[metric_key]:
            if not _is_slice_applicable(slice_spec):
                continue
            if comparable and _check_threshold(metric_key, threshold, metric):
                continue
            # The same threshold values could be set for multiple matching slice
            # specs. Only store the first match.
            #
            # Note that hashing by SerializeToString() is only safe if used within
            # the same process.
            if not _add_to_set(existing_failures,
                               threshold.SerializeToString()):
                continue
            failure = validation_for_slice.failures.add()
            failure.metric_key.CopyFrom(metric_key.to_proto())
            if comparable:
                _copy_metric(metric, failure.metric_value)
            _copy_threshold(threshold, failure.metric_threshold)
            failure.message = msg
    # All unchecked thresholds are considered failures.
    for metric_key, key_thresholds in unchecked_thresholds.items():
        if metric_key.model_name == baseline_model_name:
            continue
        existing_failures = set()
        for _, threshold in key_thresholds:
            # The same threshold values could be set for multiple matching slice
            # specs. Only store the first match.
            #
            # Note that hashing by SerializeToString() is only safe if used within
            # the same process.
            if not _add_to_set(existing_failures,
                               threshold.SerializeToString()):
                continue
            failure = validation_for_slice.failures.add()
            failure.metric_key.CopyFrom(metric_key.to_proto())
            _copy_threshold(threshold, failure.metric_threshold)
            failure.message = 'Metric not found.'
    # Any failure leads to overall failure.
    if validation_for_slice.failures:
        validation_for_slice.slice_key.CopyFrom(
            slicer.serialize_slice_key(sliced_key))
        result.validation_ok = False
        result.metric_validations_per_slice.append(validation_for_slice)
    return result