Esempio n. 1
0
 def testIsCrossSliceApplicable(self):
   """Checks is_cross_slice_applicable on matching and non-matching keys.

   Every case pairs a cross slice key (a 2-tuple of slice keys) with a
   CrossSlicingSpec and the boolean the slicer is expected to return for it.
   The case label is passed as the assertion message so a failure names the
   offending case.
   """

   def by_values(**feature_values):
     # Slicing spec that pins the given feature values.
     return config_pb2.SlicingSpec(feature_values=feature_values)

   def by_keys(*feature_keys):
     # Slicing spec that only names feature keys.
     return config_pb2.SlicingSpec(feature_keys=list(feature_keys))

   def cross(baseline, *comparisons):
     # Cross slicing spec built from a baseline spec and comparison specs.
     return config_pb2.CrossSlicingSpec(
         baseline_spec=baseline, slicing_specs=list(comparisons))

   a1_vs_b2 = ((('a', 1),), (('b', 2),))
   a1_vs_c3 = ((('a', 1),), (('c', 3),))

   applicable = [
       ('overall pass', ((), (('b', 2),)),
        cross(config_pb2.SlicingSpec(), by_values(b='2'))),
       ('value pass', a1_vs_b2,
        cross(by_values(a='1'), by_values(b='2'))),
       ('baseline key pass', a1_vs_b2,
        cross(by_keys('a'), by_values(b='2'))),
       ('comparison key pass', a1_vs_b2,
        cross(by_values(a='1'), by_keys('b'))),
       ('comparison multiple key pass', a1_vs_c3,
        cross(by_values(a='1'), by_keys('b'), by_keys('c'))),
   ]
   not_applicable = [
       ('overall fail', a1_vs_b2,
        cross(config_pb2.SlicingSpec(), by_values(b='2'))),
       ('value fail', ((('a', 1),), (('b', 3),)),
        cross(by_values(a='1'), by_values(b='2'))),
       ('baseline key fail', ((('c', 1),), (('b', 2),)),
        cross(by_keys('a'), by_values(b='2'))),
       ('comparison key fail', a1_vs_c3,
        cross(by_values(a='1'), by_keys('b'))),
       ('comparison multiple key fail', ((('a', 1),), (('d', 3),)),
        cross(by_values(a='1'), by_keys('b'), by_keys('c'))),
   ]

   for expected, cases in ((True, applicable), (False, not_applicable)):
     for label, cross_key, cross_spec in cases:
       actual = slicer.is_cross_slice_applicable(
           cross_slice_key=cross_key, cross_slicing_spec=cross_spec)
       self.assertEqual(expected, actual, msg=label)
Esempio n. 2
0
def validate_metrics(
    sliced_metrics: Tuple[Union[slicer.SliceKeyType, slicer.CrossSliceKeyType],
                          Dict['metric_types.MetricKey',
                               Any]], eval_config: config_pb2.EvalConfig
) -> validation_result_pb2.ValidationResult:
    """Validate the metrics of one slice against the configured thresholds.

    Args:
      sliced_metrics: A `(slice_key, metrics)` pair. `slice_key` is either a
        single slice key or a cross slice key, and `metrics` maps metric keys
        to metric values for that slice.
      eval_config: Eval config whose `metrics_specs` provide the thresholds
        and from which the baseline model spec is looked up.

    Returns:
      A ValidationResult. `validation_ok` stays True when no applicable
      threshold fails (an empty metrics dict with no thresholds is considered
      validated). Otherwise it is False and `metric_validations_per_slice`
      holds one entry listing every failure for this slice, including
      thresholds whose metric was not found. `validation_details` records
      each (slice spec, threshold) check that was actually performed.
    """
    # Find out which model is baseline.
    baseline_spec = model_util.get_baseline_model_spec(eval_config)
    baseline_model_name = baseline_spec.name if baseline_spec else None

    sliced_key, metrics = sliced_metrics
    thresholds = metric_specs.metric_thresholds_from_metrics_specs(
        eval_config.metrics_specs)
    is_cross_slice = slicer.is_cross_slice_key(sliced_key)

    def _check_threshold(key: metric_types.MetricKey,
                         threshold: _ThresholdType, metric: Any) -> bool:
        """Verify a metric given its metric key and metric value.

        Raises ValueError for an unknown threshold type or change direction,
        and whatever `float()` raises when the metric is not convertible.
        """
        metric = float(metric)
        if isinstance(threshold, config_pb2.GenericValueThreshold):
            lower_bound, upper_bound = -np.inf, np.inf
            if threshold.HasField('lower_bound'):
                lower_bound = threshold.lower_bound.value
            if threshold.HasField('upper_bound'):
                upper_bound = threshold.upper_bound.value
            return lower_bound <= metric <= upper_bound
        if isinstance(threshold, config_pb2.GenericChangeThreshold):
            # For change thresholds the metric value itself is treated as the
            # difference; the baseline value is only needed for the ratio.
            diff = metric
            metric_baseline = float(
                metrics[key.make_baseline_key(baseline_model_name)])
            # Without abs_tol, math.isclose(x, 0.0) only holds for an exact
            # zero, so this guards the division below.
            if math.isclose(metric_baseline, 0.0):
                ratio = float('nan')
            else:
                ratio = diff / metric_baseline
            direction = threshold.direction
            if direction == config_pb2.MetricDirection.LOWER_IS_BETTER:
                lower_is_better = True
                absolute = np.inf
            elif direction == config_pb2.MetricDirection.HIGHER_IS_BETTER:
                lower_is_better = False
                absolute = -np.inf
            else:
                raise ValueError(
                    '"UNKNOWN" direction for change threshold: {}.'.format(
                        threshold))
            if threshold.HasField('absolute'):
                absolute = threshold.absolute.value
            # Only apply the relative bound when it is configured. With a zero
            # baseline the ratio is NaN, and NaN compares False against
            # everything (even an infinite "no bound" sentinel), which would
            # otherwise fail thresholds that only set an absolute bound.
            has_relative = threshold.HasField('relative')
            if lower_is_better:
                absolute_ok = diff <= absolute
                relative_ok = (not has_relative
                               or ratio <= threshold.relative.value)
            else:
                absolute_ok = diff >= absolute
                relative_ok = (not has_relative
                               or ratio >= threshold.relative.value)
            return absolute_ok and relative_ok
        raise ValueError('Unknown threshold: {}'.format(threshold))

    def _copy_metric(metric, to):
        """Copies the metric into `to`; leaves it unset if not numeric."""
        # Will add more types when more MetricValue are supported.
        try:
            value = float(metric)
        except (TypeError, ValueError):
            # The failure message already describes the unusable value; do not
            # let recording the failure raise a second time.
            return
        to.double_value.value = value

    def _copy_threshold(threshold, to):
        """Copies the threshold into the matching field of `to`."""
        if isinstance(threshold, config_pb2.GenericValueThreshold):
            to.value_threshold.CopyFrom(threshold)
        elif isinstance(threshold, config_pb2.GenericChangeThreshold):
            to.change_threshold.CopyFrom(threshold)

    def _add_to_set(s, v):
        """Adds value to set. Returns true if didn't exist."""
        if v in s:
            return False
        s.add(v)
        return True

    # Empty metrics per slice is considered validated.
    result = validation_result_pb2.ValidationResult(validation_ok=True)
    validation_for_slice = validation_result_pb2.MetricsValidationForSlice()
    unchecked_thresholds = dict(thresholds)
    for metric_key, metric in metrics.items():
        if metric_key not in thresholds:
            continue
        del unchecked_thresholds[metric_key]
        # Not meaningful to check threshold for baseline model, thus always return
        # True if such threshold is configured. We also do not compare Message type
        # metrics.
        if metric_key.model_name == baseline_model_name:
            continue
        existing_failures = set()
        for slice_spec, threshold in thresholds[metric_key]:
            if slice_spec is not None:
                if (isinstance(slice_spec, config_pb2.SlicingSpec)
                        and (is_cross_slice or not slicer.SingleSliceSpec(
                            spec=slice_spec).is_slice_applicable(sliced_key))):
                    continue
                if (isinstance(slice_spec, config_pb2.CrossSlicingSpec)
                        and (not is_cross_slice
                             or not slicer.is_cross_slice_applicable(
                                 cross_slice_key=sliced_key,
                                 cross_slicing_spec=slice_spec))):
                    continue
            elif is_cross_slice:
                # Thresholds without a slice spec are not applied to cross
                # slices.
                continue
            msg = ''
            try:
                check_result = _check_threshold(metric_key, threshold, metric)
            except (TypeError, ValueError):
                # float() reports a wrong type with TypeError and an
                # unparsable value with ValueError; both are recorded as a
                # failed check rather than aborting the validation.
                msg = """
          Invalid metrics or threshold for comparison: The type of the metric
          is: {}, the metric value is: {}, and the threshold is: {}.
          """.format(type(metric), metric, threshold)
                check_result = False
            if not check_result:
                # The same threshold values could be set for multiple matching slice
                # specs. Only store the first match.
                #
                # Note that hashing by SerializeToString() is only safe if used within
                # the same process.
                if not _add_to_set(existing_failures,
                                   threshold.SerializeToString()):
                    continue
                failure = validation_for_slice.failures.add()
                failure.metric_key.CopyFrom(metric_key.to_proto())
                _copy_metric(metric, failure.metric_value)
                _copy_threshold(threshold, failure.metric_threshold)
                failure.message = msg
            # Track we have completed a validation check for slice spec and metric
            slicing_details = result.validation_details.slicing_details.add()
            if slice_spec is not None:
                if isinstance(slice_spec, config_pb2.SlicingSpec):
                    slicing_details.slicing_spec.CopyFrom(slice_spec)
                else:
                    slicing_details.cross_slicing_spec.CopyFrom(slice_spec)
            else:
                slicing_details.slicing_spec.CopyFrom(config_pb2.SlicingSpec())
            slicing_details.num_matching_slices = 1
    # All unchecked thresholds are considered failures.
    for metric_key, missing_thresholds in unchecked_thresholds.items():
        if metric_key.model_name == baseline_model_name:
            continue
        existing_failures = set()
        for slice_spec, threshold in missing_thresholds:
            if slice_spec is not None:
                if is_cross_slice != isinstance(slice_spec,
                                                config_pb2.CrossSlicingSpec):
                    continue
                if (is_cross_slice
                        and not slicer.is_cross_slice_applicable(
                            cross_slice_key=sliced_key,
                            cross_slicing_spec=slice_spec)):
                    continue
            elif is_cross_slice:
                continue
            # The same threshold values could be set for multiple matching slice
            # specs. Only store the first match.
            #
            # Note that hashing by SerializeToString() is only safe if used within
            # the same process.
            if not _add_to_set(existing_failures,
                               threshold.SerializeToString()):
                continue
            failure = validation_for_slice.failures.add()
            failure.metric_key.CopyFrom(metric_key.to_proto())
            _copy_threshold(threshold, failure.metric_threshold)
            failure.message = 'Metric not found.'
    # Any failure leads to overall failure.
    if validation_for_slice.failures:
        if not is_cross_slice:
            validation_for_slice.slice_key.CopyFrom(
                slicer.serialize_slice_key(sliced_key))
        else:
            validation_for_slice.cross_slice_key.CopyFrom(
                slicer.serialize_cross_slice_key(sliced_key))
        result.validation_ok = False
        result.metric_validations_per_slice.append(validation_for_slice)
    return result