def validate_metrics(
    sliced_metrics: Tuple[slicer.SliceKeyType,
                          Dict[metric_types.MetricKey, Any]],
    eval_config: config.EvalConfig) -> validation_result_pb2.ValidationResult:
  """Checks the metrics and determines whether they pass validation."""
  # Find out which model is the baseline.
  baseline_spec = model_util.get_baseline_model_spec(eval_config)
  baseline_model_name = baseline_spec.name if baseline_spec else None
  sliced_key, metrics = sliced_metrics
  thresholds = metric_specs.metric_thresholds_from_metrics_specs(
      eval_config.metrics_specs)

  def _check_threshold(key: metric_types.MetricKey, metric: Any) -> bool:
    """Verifies a metric given its metric key and metric value."""
    threshold = thresholds[key]
    if isinstance(threshold, config.GenericValueThreshold):
      lower_bound, upper_bound = -np.inf, np.inf
      if threshold.HasField('lower_bound'):
        lower_bound = threshold.lower_bound.value
      if threshold.HasField('upper_bound'):
        upper_bound = threshold.upper_bound.value
      return metric > lower_bound and metric < upper_bound
    elif isinstance(threshold, config.GenericChangeThreshold):
      diff = metric
      ratio = diff / metrics[key.make_baseline_key(baseline_model_name)]
      if threshold.direction == config.MetricDirection.LOWER_IS_BETTER:
        absolute, relative = np.inf, np.inf
      elif threshold.direction == config.MetricDirection.HIGHER_IS_BETTER:
        absolute, relative = -np.inf, -np.inf
      else:
        raise ValueError('"UNKNOWN" direction for change threshold.')
      if threshold.HasField('absolute'):
        absolute = threshold.absolute.value
      if threshold.HasField('relative'):
        relative = threshold.relative.value
      if threshold.direction == config.MetricDirection.LOWER_IS_BETTER:
        return diff < absolute and ratio < relative
      elif threshold.direction == config.MetricDirection.HIGHER_IS_BETTER:
        return diff > absolute and ratio > relative

  def _copy_metric(metric, to):
    # Will add more types when more MetricValue types are supported.
    to.double_value.value = float(metric)

  def _copy_threshold(threshold, to):
    if isinstance(threshold, config.GenericValueThreshold):
      to.value_threshold.CopyFrom(threshold)
    if isinstance(threshold, config.GenericChangeThreshold):
      to.change_threshold.CopyFrom(threshold)

  # A slice with no metrics is considered validated.
  result = validation_result_pb2.ValidationResult(validation_ok=True)
  validation_for_slice = validation_result_pb2.MetricsValidationForSlice()
  for metric_key, metric in metrics.items():
    # It is not meaningful to check a threshold for the baseline model, so any
    # such threshold is treated as passing. Message-type metrics are also not
    # compared.
    if (metric_key.model_name == baseline_model_name or
        metric_key not in thresholds):
      continue
    msg = ''
    # Try to convert the metric to a float value.
    try:
      metric = float(metric)
    except (TypeError, ValueError):
      msg = """
        Invalid threshold config: This metric is not comparable to the
        threshold. The type of the metric is: {}, and the metric value is:
        \n{}""".format(type(metric), metric)
    if not _check_threshold(metric_key, metric):
      failure = validation_for_slice.failures.add()
      failure.metric_key.CopyFrom(metric_key.to_proto())
      _copy_metric(metric, failure.metric_value)
      _copy_threshold(thresholds[metric_key], failure.metric_threshold)
      failure.message = msg
  # Any failure leads to overall failure.
  if validation_for_slice.failures:
    validation_for_slice.slice_key.CopyFrom(
        slicer.serialize_slice_key(sliced_key))
    result.validation_ok = False
    result.metric_validations_per_slice.append(validation_for_slice)
  return result
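
# --- Illustrative usage sketch (not part of the original module) -------------
# A minimal, hedged example of the inputs the version above expects, assuming
# the same module-level imports it relies on (config, metric_types, slicer).
# The model name 'candidate', the 'AUC' metric, the 0.7 lower bound and the
# slice key below are made-up values; whether the metric is actually checked
# depends on the constructed MetricKey matching the keys produced by
# metric_specs.metric_thresholds_from_metrics_specs for this config.
def _example_value_threshold_usage():
  """Builds an example EvalConfig and runs validate_metrics once."""
  example_eval_config = config.EvalConfig(
      model_specs=[config.ModelSpec(name='candidate')],
      metrics_specs=[
          config.MetricsSpec(
              model_names=['candidate'],
              metrics=[
                  config.MetricConfig(
                      class_name='AUC',
                      threshold=config.MetricThreshold(
                          value_threshold=config.GenericValueThreshold(
                              lower_bound={'value': 0.7})))
              ])
      ])
  example_sliced_metrics = (
      (('age', 35),),  # slicer.SliceKeyType: a tuple of (column, value) pairs.
      {metric_types.MetricKey(name='auc', model_name='candidate'): 0.65})
  return validate_metrics(example_sliced_metrics, example_eval_config)
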

def validate_metrics(
    sliced_metrics: Tuple[Union[slicer.SliceKeyType,
                                slicer.CrossSliceKeyType],
                          Dict['metric_types.MetricKey', Any]],
    eval_config: config_pb2.EvalConfig
) -> validation_result_pb2.ValidationResult:
  """Checks the metrics and determines whether they pass validation."""
  # Find out which model is the baseline.
  baseline_spec = model_util.get_baseline_model_spec(eval_config)
  baseline_model_name = baseline_spec.name if baseline_spec else None
  sliced_key, metrics = sliced_metrics
  thresholds = metric_specs.metric_thresholds_from_metrics_specs(
      eval_config.metrics_specs)
  is_cross_slice = slicer.is_cross_slice_key(sliced_key)

  def _check_threshold(key: metric_types.MetricKey, threshold: _ThresholdType,
                       metric: Any) -> bool:
    """Verifies a metric given its metric key, threshold, and value."""
    metric = float(metric)
    if isinstance(threshold, config_pb2.GenericValueThreshold):
      lower_bound, upper_bound = -np.inf, np.inf
      if threshold.HasField('lower_bound'):
        lower_bound = threshold.lower_bound.value
      if threshold.HasField('upper_bound'):
        upper_bound = threshold.upper_bound.value
      return metric >= lower_bound and metric <= upper_bound
    elif isinstance(threshold, config_pb2.GenericChangeThreshold):
      diff = metric
      metric_baseline = float(
          metrics[key.make_baseline_key(baseline_model_name)])
      if math.isclose(metric_baseline, 0.0):
        ratio = float('nan')
      else:
        ratio = diff / metric_baseline
      if threshold.direction == config_pb2.MetricDirection.LOWER_IS_BETTER:
        absolute, relative = np.inf, np.inf
      elif threshold.direction == config_pb2.MetricDirection.HIGHER_IS_BETTER:
        absolute, relative = -np.inf, -np.inf
      else:
        raise ValueError(
            '"UNKNOWN" direction for change threshold: {}.'.format(threshold))
      if threshold.HasField('absolute'):
        absolute = threshold.absolute.value
      if threshold.HasField('relative'):
        relative = threshold.relative.value
      if threshold.direction == config_pb2.MetricDirection.LOWER_IS_BETTER:
        return diff <= absolute and ratio <= relative
      elif threshold.direction == config_pb2.MetricDirection.HIGHER_IS_BETTER:
        return diff >= absolute and ratio >= relative
    else:
      raise ValueError('Unknown threshold: {}'.format(threshold))

  def _copy_metric(metric, to):
    # Will add more types when more MetricValue types are supported.
    to.double_value.value = float(metric)

  def _copy_threshold(threshold, to):
    if isinstance(threshold, config_pb2.GenericValueThreshold):
      to.value_threshold.CopyFrom(threshold)
    if isinstance(threshold, config_pb2.GenericChangeThreshold):
      to.change_threshold.CopyFrom(threshold)

  def _add_to_set(s, v):
    """Adds value to set. Returns True if it was not already present."""
    if v in s:
      return False
    else:
      s.add(v)
      return True

  # A slice with no metrics is considered validated.
  result = validation_result_pb2.ValidationResult(validation_ok=True)
  validation_for_slice = validation_result_pb2.MetricsValidationForSlice()
  unchecked_thresholds = dict(thresholds)
  for metric_key, metric in metrics.items():
    if metric_key not in thresholds:
      continue
    del unchecked_thresholds[metric_key]
    # It is not meaningful to check a threshold for the baseline model, so any
    # such threshold is treated as passing. Message-type metrics are also not
    # compared.
    if metric_key.model_name == baseline_model_name:
      continue
    msg = ''
    existing_failures = set()
    for slice_spec, threshold in thresholds[metric_key]:
      if slice_spec is not None:
        if (isinstance(slice_spec, config_pb2.SlicingSpec) and
            (is_cross_slice or not slicer.SingleSliceSpec(
                spec=slice_spec).is_slice_applicable(sliced_key))):
          continue
        if (isinstance(slice_spec, config_pb2.CrossSlicingSpec) and
            (not is_cross_slice or not slicer.is_cross_slice_applicable(
                cross_slice_key=sliced_key, cross_slicing_spec=slice_spec))):
          continue
      elif is_cross_slice:
        continue
      try:
        check_result = _check_threshold(metric_key, threshold, metric)
      except ValueError:
        msg = """
          Invalid metrics or threshold for comparison: The type of the metric
          is: {}, the metric value is: {}, and the threshold is: {}.
          """.format(type(metric), metric, threshold)
        check_result = False
      else:
        msg = ''
      if not check_result:
        # The same threshold values could be set for multiple matching slice
        # specs. Only store the first match.
        #
        # Note that hashing by SerializeToString() is only safe if used within
        # the same process.
        if not _add_to_set(existing_failures, threshold.SerializeToString()):
          continue
        failure = validation_for_slice.failures.add()
        failure.metric_key.CopyFrom(metric_key.to_proto())
        _copy_metric(metric, failure.metric_value)
        _copy_threshold(threshold, failure.metric_threshold)
        failure.message = msg
      # Track that a validation check was completed for this slice spec and
      # metric.
      slicing_details = result.validation_details.slicing_details.add()
      if slice_spec is not None:
        if isinstance(slice_spec, config_pb2.SlicingSpec):
          slicing_details.slicing_spec.CopyFrom(slice_spec)
        else:
          slicing_details.cross_slicing_spec.CopyFrom(slice_spec)
      else:
        slicing_details.slicing_spec.CopyFrom(config_pb2.SlicingSpec())
      slicing_details.num_matching_slices = 1
  # All unchecked thresholds are considered failures.
  for metric_key, thresholds in unchecked_thresholds.items():
    if metric_key.model_name == baseline_model_name:
      continue
    existing_failures = set()
    for slice_spec, threshold in thresholds:
      if slice_spec is not None:
        if is_cross_slice != isinstance(slice_spec,
                                        config_pb2.CrossSlicingSpec):
          continue
        if (is_cross_slice and not slicer.is_cross_slice_applicable(
            cross_slice_key=sliced_key, cross_slicing_spec=slice_spec)):
          continue
      elif is_cross_slice:
        continue
      # The same threshold values could be set for multiple matching slice
      # specs. Only store the first match.
      #
      # Note that hashing by SerializeToString() is only safe if used within
      # the same process.
      if not _add_to_set(existing_failures, threshold.SerializeToString()):
        continue
      failure = validation_for_slice.failures.add()
      failure.metric_key.CopyFrom(metric_key.to_proto())
      _copy_threshold(threshold, failure.metric_threshold)
      failure.message = 'Metric not found.'
  # Any failure leads to overall failure.
  if validation_for_slice.failures:
    if not is_cross_slice:
      validation_for_slice.slice_key.CopyFrom(
          slicer.serialize_slice_key(sliced_key))
    else:
      validation_for_slice.cross_slice_key.CopyFrom(
          slicer.serialize_cross_slice_key(sliced_key))
    result.validation_ok = False
    result.metric_validations_per_slice.append(validation_for_slice)
  return result
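
# --- Worked example of the change-threshold arithmetic above (not part of ----
# the original module). With direction HIGHER_IS_BETTER, the configured
# absolute/relative values act as lower bounds on the candidate-vs-baseline
# diff and on diff / baseline. In the function above, `metric` already holds
# the diff and the baseline value is looked up via key.make_baseline_key(...);
# the numbers below are made up for illustration only.
def _change_threshold_arithmetic_example():
  """Mirrors the HIGHER_IS_BETTER branch of _check_threshold with floats."""
  diff = 0.02  # Candidate minus baseline, e.g. for an 'auc' metric.
  baseline_value = 0.80  # Value stored under the baseline MetricKey.
  ratio = diff / baseline_value  # 0.025
  absolute, relative = -0.001, -0.001  # Thresholds allowing a tiny regression.
  # Passes: 0.02 >= -0.001 and 0.025 >= -0.001.
  return diff >= absolute and ratio >= relative
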

def validate_metrics(
    sliced_metrics: Tuple[slicer.SliceKeyType,
                          Dict['metric_types.MetricKey', Any]],
    eval_config: config.EvalConfig) -> validation_result_pb2.ValidationResult:
  """Checks the metrics and determines whether they pass validation."""
  # Find out which model is the baseline.
  baseline_spec = model_util.get_baseline_model_spec(eval_config)
  baseline_model_name = baseline_spec.name if baseline_spec else None
  sliced_key, metrics = sliced_metrics
  thresholds = metric_specs.metric_thresholds_from_metrics_specs(
      eval_config.metrics_specs)  # pytype: disable=wrong-arg-types

  def _check_threshold(key: metric_types.MetricKey,
                       slicing_spec: Optional[config.SlicingSpec],
                       threshold: _ThresholdType, metric: Any) -> bool:
    """Verifies a metric given its key, slicing spec, threshold, and value."""
    if (slicing_spec is not None and
        not slicer.SingleSliceSpec(
            spec=slicing_spec).is_slice_applicable(sliced_key)):
      return True
    if isinstance(threshold, config.GenericValueThreshold):
      lower_bound, upper_bound = -np.inf, np.inf
      if threshold.HasField('lower_bound'):
        lower_bound = threshold.lower_bound.value
      if threshold.HasField('upper_bound'):
        upper_bound = threshold.upper_bound.value
      return metric > lower_bound and metric < upper_bound
    elif isinstance(threshold, config.GenericChangeThreshold):
      diff = metric
      ratio = diff / metrics[key.make_baseline_key(baseline_model_name)]
      if threshold.direction == config.MetricDirection.LOWER_IS_BETTER:
        absolute, relative = np.inf, np.inf
      elif threshold.direction == config.MetricDirection.HIGHER_IS_BETTER:
        absolute, relative = -np.inf, -np.inf
      else:
        raise ValueError('"UNKNOWN" direction for change threshold.')
      if threshold.HasField('absolute'):
        absolute = threshold.absolute.value
      if threshold.HasField('relative'):
        relative = threshold.relative.value
      if threshold.direction == config.MetricDirection.LOWER_IS_BETTER:
        return diff < absolute and ratio < relative
      elif threshold.direction == config.MetricDirection.HIGHER_IS_BETTER:
        return diff > absolute and ratio > relative

  def _copy_metric(metric, to):
    # Will add more types when more MetricValue types are supported.
    to.double_value.value = float(metric)

  def _copy_threshold(threshold, to):
    if isinstance(threshold, config.GenericValueThreshold):
      to.value_threshold.CopyFrom(threshold)
    if isinstance(threshold, config.GenericChangeThreshold):
      to.change_threshold.CopyFrom(threshold)

  def _add_to_set(s, v):
    """Adds value to set. Returns True if it was not already present."""
    if v in s:
      return False
    else:
      s.add(v)
      return True

  # A slice with no metrics is considered validated.
  result = validation_result_pb2.ValidationResult(validation_ok=True)
  validation_for_slice = validation_result_pb2.MetricsValidationForSlice()
  unchecked_thresholds = dict(thresholds)
  for metric_key, metric in metrics.items():
    if metric_key not in thresholds:
      continue
    del unchecked_thresholds[metric_key]
    # It is not meaningful to check a threshold for the baseline model, so any
    # such threshold is treated as passing. Message-type metrics are also not
    # compared.
    if metric_key.model_name == baseline_model_name:
      continue
    msg = ''
    # Try to convert the metric to a float value.
    try:
      metric = float(metric)
    except (TypeError, ValueError):
      msg = """
        Invalid threshold config: This metric is not comparable to the
        threshold. The type of the metric is: {}, and the metric value is:
        \n{}""".format(type(metric), metric)
    existing_failures = set()
    for slice_spec, threshold in thresholds[metric_key]:
      if not _check_threshold(metric_key, slice_spec, threshold, metric):
        # The same threshold values could be set for multiple matching slice
        # specs. Only store the first match.
        #
        # Note that hashing by SerializeToString() is only safe if used within
        # the same process.
        if not _add_to_set(existing_failures, threshold.SerializeToString()):
          continue
        failure = validation_for_slice.failures.add()
        failure.metric_key.CopyFrom(metric_key.to_proto())
        _copy_metric(metric, failure.metric_value)
        _copy_threshold(threshold, failure.metric_threshold)
        failure.message = msg
  # All unchecked thresholds are considered failures.
  for metric_key, thresholds in unchecked_thresholds.items():
    if metric_key.model_name == baseline_model_name:
      continue
    existing_failures = set()
    for _, threshold in thresholds:
      # The same threshold values could be set for multiple matching slice
      # specs. Only store the first match.
      #
      # Note that hashing by SerializeToString() is only safe if used within
      # the same process.
      if not _add_to_set(existing_failures, threshold.SerializeToString()):
        continue
      failure = validation_for_slice.failures.add()
      failure.metric_key.CopyFrom(metric_key.to_proto())
      _copy_threshold(threshold, failure.metric_threshold)
      failure.message = 'Metric not found.'
  # Any failure leads to overall failure.
  if validation_for_slice.failures:
    validation_for_slice.slice_key.CopyFrom(
        slicer.serialize_slice_key(sliced_key))
    result.validation_ok = False
    result.metric_validations_per_slice.append(validation_for_slice)
  return result
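
# --- Illustrative sketch (not part of the original module) -------------------
# Shows how a caller might inspect the ValidationResult produced by the
# version above; only fields the function itself populates are read here.
# Note that a metric with a configured threshold that is missing from the
# metrics dict is reported with the message 'Metric not found.'.
def _print_validation_failures(sliced_metrics, eval_config):
  """Prints every per-slice failure recorded by validate_metrics."""
  result = validate_metrics(sliced_metrics, eval_config)
  if result.validation_ok:
    return
  for per_slice in result.metric_validations_per_slice:
    for failure in per_slice.failures:
      print(failure.metric_key, failure.message)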