def testIsCrossSliceApplicable(self):
  """Checks slicer.is_cross_slice_applicable on pass/fail cross-slice cases.

  Each case is (expected_result, name, cross_slice_key, cross_slicing_spec).
  The cross slice key is a pair whose first element is matched against the
  spec's baseline_spec and whose second element is matched against its
  slicing_specs.
  """
  test_cases = [
      (True, 'overall pass',
       ((), (('b', 2),)),
       config_pb2.CrossSlicingSpec(
           baseline_spec=config_pb2.SlicingSpec(),
           slicing_specs=[config_pb2.SlicingSpec(feature_values={'b': '2'})])),
      (True, 'value pass',
       ((('a', 1),), (('b', 2),)),
       config_pb2.CrossSlicingSpec(
           baseline_spec=config_pb2.SlicingSpec(feature_values={'a': '1'}),
           slicing_specs=[config_pb2.SlicingSpec(feature_values={'b': '2'})])),
      (True, 'baseline key pass',
       ((('a', 1),), (('b', 2),)),
       config_pb2.CrossSlicingSpec(
           baseline_spec=config_pb2.SlicingSpec(feature_keys=['a']),
           slicing_specs=[config_pb2.SlicingSpec(feature_values={'b': '2'})])),
      (True, 'comparison key pass',
       ((('a', 1),), (('b', 2),)),
       config_pb2.CrossSlicingSpec(
           baseline_spec=config_pb2.SlicingSpec(feature_values={'a': '1'}),
           slicing_specs=[config_pb2.SlicingSpec(feature_keys=['b'])])),
      (True, 'comparison multiple key pass',
       ((('a', 1),), (('c', 3),)),
       config_pb2.CrossSlicingSpec(
           baseline_spec=config_pb2.SlicingSpec(feature_values={'a': '1'}),
           slicing_specs=[config_pb2.SlicingSpec(feature_keys=['b']),
                          config_pb2.SlicingSpec(feature_keys=['c'])])),
      (False, 'overall fail',
       ((('a', 1),), (('b', 2),)),
       config_pb2.CrossSlicingSpec(
           baseline_spec=config_pb2.SlicingSpec(),
           slicing_specs=[config_pb2.SlicingSpec(feature_values={'b': '2'})])),
      (False, 'value fail',
       ((('a', 1),), (('b', 3),)),
       config_pb2.CrossSlicingSpec(
           baseline_spec=config_pb2.SlicingSpec(feature_values={'a': '1'}),
           slicing_specs=[config_pb2.SlicingSpec(feature_values={'b': '2'})])),
      (False, 'baseline key fail',
       ((('c', 1),), (('b', 2),)),
       config_pb2.CrossSlicingSpec(
           baseline_spec=config_pb2.SlicingSpec(feature_keys=['a']),
           slicing_specs=[config_pb2.SlicingSpec(feature_values={'b': '2'})])),
      (False, 'comparison key fail',
       ((('a', 1),), (('c', 3),)),
       config_pb2.CrossSlicingSpec(
           baseline_spec=config_pb2.SlicingSpec(feature_values={'a': '1'}),
           slicing_specs=[config_pb2.SlicingSpec(feature_keys=['b'])])),
      (False, 'comparison multiple key fail',
       ((('a', 1),), (('d', 3),)),
       config_pb2.CrossSlicingSpec(
           baseline_spec=config_pb2.SlicingSpec(feature_values={'a': '1'}),
           slicing_specs=[config_pb2.SlicingSpec(feature_keys=['b']),
                          config_pb2.SlicingSpec(feature_keys=['c'])])),
  ]  # pyformat: disable
  for expected_result, name, cross_slice_key, cross_slicing_spec in test_cases:
    # subTest keeps iterating after a failing case, so every mismatching case
    # is reported instead of only the first one.
    with self.subTest(name=name):
      self.assertEqual(
          expected_result,
          slicer.is_cross_slice_applicable(
              cross_slice_key=cross_slice_key,
              cross_slicing_spec=cross_slicing_spec),
          msg=name)
def validate_metrics(
    sliced_metrics: Tuple[Union[slicer.SliceKeyType, slicer.CrossSliceKeyType],
                          Dict['metric_types.MetricKey', Any]],
    eval_config: config_pb2.EvalConfig
) -> validation_result_pb2.ValidationResult:
  """Checks the metrics of one slice against the configured thresholds.

  Args:
    sliced_metrics: A (slice_key, metrics) pair. The key may be a single slice
      key or a cross slice key; metrics maps MetricKey to the metric value.
    eval_config: Eval config whose metrics_specs define the thresholds and
      whose model specs identify the baseline model.

  Returns:
    A ValidationResult. validation_ok is set to False, and one
    MetricsValidationForSlice listing the failures is appended, when any
    applicable threshold is violated, cannot be evaluated, or refers to a
    metric missing from this slice. Every evaluated (slice spec, threshold)
    pair is recorded in validation_details.slicing_details.

  Raises:
    KeyError: If a change threshold is evaluated but the baseline model's
      metric is absent from `metrics`.
  """
  # Find out which model is baseline.
  baseline_spec = model_util.get_baseline_model_spec(eval_config)
  baseline_model_name = baseline_spec.name if baseline_spec else None

  sliced_key, metrics = sliced_metrics
  thresholds = metric_specs.metric_thresholds_from_metrics_specs(
      eval_config.metrics_specs)
  is_cross_slice = slicer.is_cross_slice_key(sliced_key)

  def _is_threshold_applicable(slice_spec) -> bool:
    """Returns whether a threshold scoped to slice_spec applies to sliced_key.

    Shared by the checked and unchecked loops so both agree on applicability.
    """
    if slice_spec is None:
      # Unscoped thresholds apply to every single (non-cross) slice only.
      return not is_cross_slice
    if isinstance(slice_spec, config_pb2.SlicingSpec):
      return (not is_cross_slice and slicer.SingleSliceSpec(
          spec=slice_spec).is_slice_applicable(sliced_key))
    if isinstance(slice_spec, config_pb2.CrossSlicingSpec):
      return is_cross_slice and slicer.is_cross_slice_applicable(
          cross_slice_key=sliced_key, cross_slicing_spec=slice_spec)
    return True

  def _check_threshold(key: metric_types.MetricKey, threshold: _ThresholdType,
                       metric: Any) -> bool:
    """Verify a metric given its metric key and metric value.

    Raises:
      ValueError: For an unknown threshold type or change direction, or when
        the metric is a string that float() cannot parse.
      TypeError: When the metric (or baseline) is not convertible to float,
        e.g. a proto message or None.
    """
    metric = float(metric)
    if isinstance(threshold, config_pb2.GenericValueThreshold):
      lower_bound, upper_bound = -np.inf, np.inf
      if threshold.HasField('lower_bound'):
        lower_bound = threshold.lower_bound.value
      if threshold.HasField('upper_bound'):
        upper_bound = threshold.upper_bound.value
      return lower_bound <= metric <= upper_bound
    elif isinstance(threshold, config_pb2.GenericChangeThreshold):
      # The metric value is used directly as the change versus the baseline
      # model; the baseline's own value is only needed for the relative change.
      diff = metric
      metric_baseline = float(
          metrics[key.make_baseline_key(baseline_model_name)])
      # With the default abs_tol of 0, isclose(x, 0.0) is only True for an
      # exact zero, so this just avoids a division by zero.
      if math.isclose(metric_baseline, 0.0):
        ratio = float('nan')
      else:
        ratio = diff / metric_baseline
      if threshold.direction == config_pb2.MetricDirection.LOWER_IS_BETTER:
        absolute = np.inf
      elif threshold.direction == config_pb2.MetricDirection.HIGHER_IS_BETTER:
        absolute = -np.inf
      else:
        raise ValueError(
            '"UNKNOWN" direction for change threshold: {}.'.format(threshold))
      if threshold.HasField('absolute'):
        absolute = threshold.absolute.value
      # Only compare the ratio when a relative bound is configured: an unset
      # bound must not constrain the result, and a NaN ratio (zero baseline)
      # would otherwise fail the comparison against the +/-inf default.
      check_relative = threshold.HasField('relative')
      relative = threshold.relative.value if check_relative else None
      if threshold.direction == config_pb2.MetricDirection.LOWER_IS_BETTER:
        return diff <= absolute and (not check_relative or ratio <= relative)
      return diff >= absolute and (not check_relative or ratio >= relative)
    else:
      raise ValueError('Unknown threshold: {}'.format(threshold))

  def _copy_metric(metric, to):
    """Copies a numeric metric into `to`; non-numeric metrics are left unset.

    The conversion is guarded because this runs for metrics that already
    failed float() inside _check_threshold; re-raising here would crash the
    validation instead of recording the failure.
    """
    # Will add more types when more MetricValue are supported.
    try:
      value = float(metric)
    except (TypeError, ValueError):
      return
    to.double_value.value = value

  def _copy_threshold(threshold, to):
    """Copies a value or change threshold into the matching oneof field."""
    if isinstance(threshold, config_pb2.GenericValueThreshold):
      to.value_threshold.CopyFrom(threshold)
    elif isinstance(threshold, config_pb2.GenericChangeThreshold):
      to.change_threshold.CopyFrom(threshold)

  def _add_to_set(s, v):
    """Adds value to set. Returns true if didn't exist."""
    if v in s:
      return False
    s.add(v)
    return True

  # Empty metrics per slice is considered validated.
  result = validation_result_pb2.ValidationResult(validation_ok=True)
  validation_for_slice = validation_result_pb2.MetricsValidationForSlice()
  unchecked_thresholds = dict(thresholds)
  for metric_key, metric in metrics.items():
    if metric_key not in thresholds:
      continue
    del unchecked_thresholds[metric_key]
    # Not meaningful to check threshold for baseline model, thus always return
    # True if such threshold is configured. Metrics that cannot be converted to
    # float are reported below as invalid-comparison failures.
    if metric_key.model_name == baseline_model_name:
      continue
    existing_failures = set()
    for slice_spec, threshold in thresholds[metric_key]:
      if not _is_threshold_applicable(slice_spec):
        continue
      try:
        check_result = _check_threshold(metric_key, threshold, metric)
      except (ValueError, TypeError):
        msg = ('Invalid metrics or threshold for comparison: The type of the '
               'metric is: {}, the metric value is: {}, and the threshold '
               'is: {}.').format(type(metric), metric, threshold)
        check_result = False
      else:
        msg = ''
      # The same threshold values could be set for multiple matching slice
      # specs. Only store the first failure; the slicing details below are
      # still recorded for every evaluated spec.
      #
      # Note that hashing by SerializeToString() is only safe if used within
      # the same process.
      if not check_result and _add_to_set(existing_failures,
                                          threshold.SerializeToString()):
        failure = validation_for_slice.failures.add()
        failure.metric_key.CopyFrom(metric_key.to_proto())
        _copy_metric(metric, failure.metric_value)
        _copy_threshold(threshold, failure.metric_threshold)
        failure.message = msg
      # Track we have completed a validation check for slice spec and metric.
      slicing_details = result.validation_details.slicing_details.add()
      if slice_spec is None:
        slicing_details.slicing_spec.CopyFrom(config_pb2.SlicingSpec())
      elif isinstance(slice_spec, config_pb2.SlicingSpec):
        slicing_details.slicing_spec.CopyFrom(slice_spec)
      else:
        slicing_details.cross_slicing_spec.CopyFrom(slice_spec)
      slicing_details.num_matching_slices = 1

  # All unchecked thresholds that apply to this slice are considered failures.
  for metric_key, metric_thresholds in unchecked_thresholds.items():
    if metric_key.model_name == baseline_model_name:
      continue
    existing_failures = set()
    for slice_spec, threshold in metric_thresholds:
      if not _is_threshold_applicable(slice_spec):
        continue
      # The same threshold values could be set for multiple matching slice
      # specs. Only store the first match.
      #
      # Note that hashing by SerializeToString() is only safe if used within
      # the same process.
      if not _add_to_set(existing_failures, threshold.SerializeToString()):
        continue
      failure = validation_for_slice.failures.add()
      failure.metric_key.CopyFrom(metric_key.to_proto())
      _copy_threshold(threshold, failure.metric_threshold)
      failure.message = 'Metric not found.'

  # Any failure leads to overall failure.
  if validation_for_slice.failures:
    if is_cross_slice:
      validation_for_slice.cross_slice_key.CopyFrom(
          slicer.serialize_cross_slice_key(sliced_key))
    else:
      validation_for_slice.slice_key.CopyFrom(
          slicer.serialize_slice_key(sliced_key))
    result.validation_ok = False
    result.metric_validations_per_slice.append(validation_for_slice)
  return result