def extract_output(
    self, accumulator: 'Optional[validation_result_pb2.ValidationResult]'
) -> 'Optional[validation_result_pb2.ValidationResult]':
  # Verification fails if there is empty input.
  if not accumulator:
    accumulator = validation_result_pb2.ValidationResult(validation_ok=False)
    thresholds = metric_specs.metric_thresholds_from_metrics_specs(
        self._eval_config.metrics_specs)
    if not thresholds:
      # Validation defaults to NOT ok when not rubber stamping.
      accumulator.validation_ok = self._rubber_stamp
      # Thresholds default to missing when not rubber stamping.
      accumulator.missing_thresholds = not self._rubber_stamp
  missing = metrics_validator.get_missing_slices(
      accumulator.validation_details.slicing_details, self._eval_config)
  if missing:
    missing_slices = []
    missing_cross_slices = []
    for m in missing:
      if isinstance(m, config.SlicingSpec):
        missing_slices.append(m)
      elif isinstance(m, config.CrossSlicingSpec):
        missing_cross_slices.append(m)
    accumulator.validation_ok = False
    if missing_slices:
      accumulator.missing_slices.extend(missing_slices)
    if missing_cross_slices:
      accumulator.missing_cross_slices.extend(missing_cross_slices)
  if self._rubber_stamp:
    accumulator.rubber_stamp = True
  return accumulator
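For orientation, here is a minimal standalone sketch of the empty-input defaulting above: with no configured thresholds, validation only passes when rubber stamping. SimpleResult and resolve_empty_result are hypothetical names standing in for the ValidationResult proto and the combiner logic; they are not part of TFMA.

import dataclasses


@dataclasses.dataclass
class SimpleResult:
  # Hypothetical stand-in for validation_result_pb2.ValidationResult.
  validation_ok: bool = False
  missing_thresholds: bool = False
  rubber_stamp: bool = False


def resolve_empty_result(has_thresholds: bool,
                         rubber_stamp: bool) -> SimpleResult:
  """Mirrors the empty-accumulator branch above (a sketch, not TFMA code)."""
  result = SimpleResult(validation_ok=False)
  if not has_thresholds:
    # No thresholds configured: pass only when rubber stamping, and flag the
    # missing thresholds otherwise.
    result.validation_ok = rubber_stamp
    result.missing_thresholds = not rubber_stamp
  result.rubber_stamp = rubber_stamp
  return result


assert resolve_empty_result(False, True).validation_ok
assert not resolve_empty_result(False, False).validation_ok
assert resolve_empty_result(False, False).missing_thresholds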
def get_missing_slices(
    slicing_details: Iterable[validation_result_pb2.SlicingDetails],
    eval_config: config_pb2.EvalConfig
) -> List[Union[config_pb2.SlicingSpec, config_pb2.CrossSlicingSpec]]:
  """Returns specs that are defined in the EvalConfig but not found in details.

  Args:
    slicing_details: Slicing details.
    eval_config: Eval config.

  Returns:
    List of missing slices or empty list if none are missing.
  """
  hashed_details = _hashed_slicing_details(slicing_details)
  thresholds = metric_specs.metric_thresholds_from_metrics_specs(
      eval_config.metrics_specs)
  baseline_spec = model_util.get_baseline_model_spec(eval_config)
  baseline_model_name = baseline_spec.name if baseline_spec else None
  missing_slices = []
  for metric_key, sliced_thresholds in thresholds.items():
    # Skip baseline.
    if metric_key.model_name == baseline_model_name:
      continue
    for slice_spec, _ in sliced_thresholds:
      if not slice_spec:
        slice_spec = config_pb2.SlicingSpec()
      slice_hash = slice_spec.SerializeToString()
      if slice_hash not in hashed_details:
        missing_slices.append(slice_spec)
        # The same slice may be used by other metrics/thresholds; only add it
        # once.
        hashed_details[slice_hash] = validation_result_pb2.SlicingDetails()
  return missing_slices
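The function above deduplicates by keying a dict with SerializeToString(), so a slice spec referenced by several metrics is reported missing only once. A minimal sketch of the same idea with hashable tuples in place of protos; find_missing is a hypothetical name, not a TFMA API.

def find_missing(configured_specs, observed_specs):
  """Returns each configured spec absent from observed_specs, at most once."""
  observed = set(observed_specs)
  missing, seen = [], set()
  for spec in configured_specs:
    if spec not in observed and spec not in seen:
      missing.append(spec)
      # The same spec may back several metrics/thresholds; record it once.
      seen.add(spec)
  return missing


assert find_missing([('f1',), ('f1',), ('f2',)], [('f2',)]) == [('f1',)]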
def testMetricThresholdsFromMetricsSpecs(self):
  metrics_specs = [
      config.MetricsSpec(
          thresholds={
              'auc':
                  config.MetricThreshold(
                      value_threshold=config.GenericValueThreshold()),
              'mean/label':
                  config.MetricThreshold(
                      value_threshold=config.GenericValueThreshold(),
                      change_threshold=config.GenericChangeThreshold()),
              # The mse metric will be overridden by the MetricConfig below.
              'mse':
                  config.MetricThreshold(
                      change_threshold=config.GenericChangeThreshold())
          },
          model_names=['model_name'],
          output_names=['output_name']),
      config.MetricsSpec(
          metrics=[
              config.MetricConfig(
                  class_name='ExampleCount',
                  config=json.dumps({'name': 'example_count'}),
                  threshold=config.MetricThreshold(
                      value_threshold=config.GenericValueThreshold()))
          ],
          # Model names and output_names should be ignored because
          # ExampleCount is model independent.
          model_names=['model_name1', 'model_name2'],
          output_names=['output_name1', 'output_name2']),
      config.MetricsSpec(
          metrics=[
              config.MetricConfig(
                  class_name='WeightedExampleCount',
                  config=json.dumps({'name': 'weighted_example_count'}),
                  threshold=config.MetricThreshold(
                      value_threshold=config.GenericValueThreshold()))
          ],
          model_names=['model_name1', 'model_name2'],
          output_names=['output_name1', 'output_name2']),
      config.MetricsSpec(
          metrics=[
              config.MetricConfig(
                  class_name='MeanSquaredError',
                  config=json.dumps({'name': 'mse'}),
                  threshold=config.MetricThreshold(
                      change_threshold=config.GenericChangeThreshold())),
              config.MetricConfig(
                  class_name='MeanLabel',
                  config=json.dumps({'name': 'mean_label'}),
                  threshold=config.MetricThreshold(
                      change_threshold=config.GenericChangeThreshold()))
          ],
          model_names=['model_name'],
          output_names=['output_name'],
          binarize=config.BinarizationOptions(class_ids={'values': [0, 1]}),
          aggregate=config.AggregationOptions(macro_average=True))
  ]

  thresholds = metric_specs.metric_thresholds_from_metrics_specs(metrics_specs)

  self.assertLen(thresholds, 14)
  self.assertIn(
      metric_types.MetricKey(
          name='auc', model_name='model_name', output_name='output_name'),
      thresholds)
  self.assertIn(
      metric_types.MetricKey(
          name='mean/label',
          model_name='model_name',
          output_name='output_name',
          is_diff=True), thresholds)
  self.assertIn(
      metric_types.MetricKey(
          name='mean/label',
          model_name='model_name',
          output_name='output_name',
          is_diff=False), thresholds)
  self.assertIn(metric_types.MetricKey(name='example_count'), thresholds)
  self.assertIn(
      metric_types.MetricKey(
          name='weighted_example_count',
          model_name='model_name1',
          output_name='output_name1'), thresholds)
  self.assertIn(
      metric_types.MetricKey(
          name='weighted_example_count',
          model_name='model_name1',
          output_name='output_name2'), thresholds)
  self.assertIn(
      metric_types.MetricKey(
          name='weighted_example_count',
          model_name='model_name2',
          output_name='output_name1'), thresholds)
  self.assertIn(
      metric_types.MetricKey(
          name='weighted_example_count',
          model_name='model_name2',
          output_name='output_name2'), thresholds)
  self.assertIn(
      metric_types.MetricKey(
          name='mse',
          model_name='model_name',
          output_name='output_name',
          sub_key=metric_types.SubKey(class_id=0),
          is_diff=True), thresholds)
  self.assertIn(
      metric_types.MetricKey(
          name='mse',
          model_name='model_name',
          output_name='output_name',
          sub_key=metric_types.SubKey(class_id=1),
          is_diff=True), thresholds)
  self.assertIn(
      metric_types.MetricKey(
          name='mse',
          model_name='model_name',
          output_name='output_name',
          is_diff=True), thresholds)
  self.assertIn(
      metric_types.MetricKey(
          name='mean_label',
          model_name='model_name',
          output_name='output_name',
          sub_key=metric_types.SubKey(class_id=0),
          is_diff=True), thresholds)
  self.assertIn(
      metric_types.MetricKey(
          name='mean_label',
          model_name='model_name',
          output_name='output_name',
          sub_key=metric_types.SubKey(class_id=1),
          is_diff=True), thresholds)
  self.assertIn(
      metric_types.MetricKey(
          name='mean_label',
          model_name='model_name',
          output_name='output_name',
          is_diff=True), thresholds)
def validate_metrics(
    sliced_metrics: Tuple[Union[slicer.SliceKeyType,
                                slicer.CrossSliceKeyType],
                          Dict['metric_types.MetricKey', Any]],
    eval_config: config_pb2.EvalConfig
) -> validation_result_pb2.ValidationResult:
  """Check the metrics and check whether they should be validated."""
  # Find out which model is baseline.
  baseline_spec = model_util.get_baseline_model_spec(eval_config)
  baseline_model_name = baseline_spec.name if baseline_spec else None
  sliced_key, metrics = sliced_metrics
  thresholds = metric_specs.metric_thresholds_from_metrics_specs(
      eval_config.metrics_specs)
  is_cross_slice = slicer.is_cross_slice_key(sliced_key)

  def _check_threshold(key: metric_types.MetricKey, threshold: _ThresholdType,
                       metric: Any) -> bool:
    """Verify a metric given its metric key and metric value."""
    metric = float(metric)
    if isinstance(threshold, config_pb2.GenericValueThreshold):
      lower_bound, upper_bound = -np.inf, np.inf
      if threshold.HasField('lower_bound'):
        lower_bound = threshold.lower_bound.value
      if threshold.HasField('upper_bound'):
        upper_bound = threshold.upper_bound.value
      return metric >= lower_bound and metric <= upper_bound
    elif isinstance(threshold, config_pb2.GenericChangeThreshold):
      diff = metric
      metric_baseline = float(
          metrics[key.make_baseline_key(baseline_model_name)])
      if math.isclose(metric_baseline, 0.0):
        ratio = float('nan')
      else:
        ratio = diff / metric_baseline
      if threshold.direction == config_pb2.MetricDirection.LOWER_IS_BETTER:
        absolute, relative = np.inf, np.inf
      elif threshold.direction == config_pb2.MetricDirection.HIGHER_IS_BETTER:
        absolute, relative = -np.inf, -np.inf
      else:
        raise ValueError(
            '"UNKNOWN" direction for change threshold: {}.'.format(threshold))
      if threshold.HasField('absolute'):
        absolute = threshold.absolute.value
      if threshold.HasField('relative'):
        relative = threshold.relative.value
      if threshold.direction == config_pb2.MetricDirection.LOWER_IS_BETTER:
        return diff <= absolute and ratio <= relative
      elif threshold.direction == config_pb2.MetricDirection.HIGHER_IS_BETTER:
        return diff >= absolute and ratio >= relative
    else:
      raise ValueError('Unknown threshold: {}'.format(threshold))

  def _copy_metric(metric, to):
    # Will add more types when more MetricValue are supported.
    to.double_value.value = float(metric)

  def _copy_threshold(threshold, to):
    if isinstance(threshold, config_pb2.GenericValueThreshold):
      to.value_threshold.CopyFrom(threshold)
    if isinstance(threshold, config_pb2.GenericChangeThreshold):
      to.change_threshold.CopyFrom(threshold)

  def _add_to_set(s, v):
    """Adds value to set. Returns true if it didn't already exist."""
    if v in s:
      return False
    else:
      s.add(v)
      return True

  # Empty metrics per slice is considered validated.
  result = validation_result_pb2.ValidationResult(validation_ok=True)
  validation_for_slice = validation_result_pb2.MetricsValidationForSlice()
  unchecked_thresholds = dict(thresholds)
  for metric_key, metric in metrics.items():
    if metric_key not in thresholds:
      continue
    del unchecked_thresholds[metric_key]
    # It is not meaningful to check a threshold against the baseline model, so
    # any such threshold is treated as passing. We also do not compare Message
    # type metrics.
    if metric_key.model_name == baseline_model_name:
      continue
    msg = ''
    existing_failures = set()
    for slice_spec, threshold in thresholds[metric_key]:
      if slice_spec is not None:
        if (isinstance(slice_spec, config_pb2.SlicingSpec) and
            (is_cross_slice or not slicer.SingleSliceSpec(
                spec=slice_spec).is_slice_applicable(sliced_key))):
          continue
        if (isinstance(slice_spec, config_pb2.CrossSlicingSpec) and
            (not is_cross_slice or not slicer.is_cross_slice_applicable(
                cross_slice_key=sliced_key, cross_slicing_spec=slice_spec))):
          continue
      elif is_cross_slice:
        continue
      try:
        check_result = _check_threshold(metric_key, threshold, metric)
      except ValueError:
        msg = """
          Invalid metrics or threshold for comparison: the type of the metric
          is: {}, the metric value is: {}, and the threshold is: {}.
          """.format(type(metric), metric, threshold)
        check_result = False
      else:
        msg = ''
      if not check_result:
        # The same threshold values could be set for multiple matching slice
        # specs. Only store the first match.
        #
        # Note that hashing by SerializeToString() is only safe if used within
        # the same process.
        if not _add_to_set(existing_failures, threshold.SerializeToString()):
          continue
        failure = validation_for_slice.failures.add()
        failure.metric_key.CopyFrom(metric_key.to_proto())
        _copy_metric(metric, failure.metric_value)
        _copy_threshold(threshold, failure.metric_threshold)
        failure.message = msg
      # Track that we have completed a validation check for this slice spec
      # and metric.
      slicing_details = result.validation_details.slicing_details.add()
      if slice_spec is not None:
        if isinstance(slice_spec, config_pb2.SlicingSpec):
          slicing_details.slicing_spec.CopyFrom(slice_spec)
        else:
          slicing_details.cross_slicing_spec.CopyFrom(slice_spec)
      else:
        slicing_details.slicing_spec.CopyFrom(config_pb2.SlicingSpec())
      slicing_details.num_matching_slices = 1
  # All unchecked thresholds are considered failures.
  for metric_key, sliced_thresholds in unchecked_thresholds.items():
    if metric_key.model_name == baseline_model_name:
      continue
    existing_failures = set()
    for slice_spec, threshold in sliced_thresholds:
      if slice_spec is not None:
        if is_cross_slice != isinstance(slice_spec,
                                        config_pb2.CrossSlicingSpec):
          continue
        if (is_cross_slice and not slicer.is_cross_slice_applicable(
            cross_slice_key=sliced_key, cross_slicing_spec=slice_spec)):
          continue
      elif is_cross_slice:
        continue
      # The same threshold values could be set for multiple matching slice
      # specs. Only store the first match.
      #
      # Note that hashing by SerializeToString() is only safe if used within
      # the same process.
      if not _add_to_set(existing_failures, threshold.SerializeToString()):
        continue
      failure = validation_for_slice.failures.add()
      failure.metric_key.CopyFrom(metric_key.to_proto())
      _copy_threshold(threshold, failure.metric_threshold)
      failure.message = 'Metric not found.'
  # Any failure leads to overall failure.
  if validation_for_slice.failures:
    if not is_cross_slice:
      validation_for_slice.slice_key.CopyFrom(
          slicer.serialize_slice_key(sliced_key))
    else:
      validation_for_slice.cross_slice_key.CopyFrom(
          slicer.serialize_cross_slice_key(sliced_key))
    result.validation_ok = False
    result.metric_validations_per_slice.append(validation_for_slice)
  return result
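The change-threshold arithmetic above is easy to lose inside the proto plumbing. A simplified sketch of the same rule on plain floats; check_change is a hypothetical name, and the direction strings stand in for the MetricDirection enum values.

import math


def check_change(diff, baseline, direction, absolute=None, relative=None):
  """LOWER_IS_BETTER passes when the diff and the diff/baseline ratio are at
  most the configured bounds; HIGHER_IS_BETTER when they are at least them."""
  # A zero baseline yields a NaN ratio, which fails every comparison below,
  # mirroring the math.isclose guard in the code above.
  ratio = float('nan') if math.isclose(baseline, 0.0) else diff / baseline
  if direction == 'LOWER_IS_BETTER':
    abs_bound = absolute if absolute is not None else math.inf
    rel_bound = relative if relative is not None else math.inf
    return diff <= abs_bound and ratio <= rel_bound
  elif direction == 'HIGHER_IS_BETTER':
    abs_bound = absolute if absolute is not None else -math.inf
    rel_bound = relative if relative is not None else -math.inf
    return diff >= abs_bound and ratio >= rel_bound
  raise ValueError('"UNKNOWN" direction for change threshold.')


# A loss that dropped by 0.05 against a 0.5 baseline passes a LOWER_IS_BETTER
# threshold of absolute=0 (any decrease is acceptable).
assert check_change(-0.05, 0.5, 'LOWER_IS_BETTER', absolute=0.0)
# A NaN ratio (zero baseline) fails any relative bound.
assert not check_change(-0.05, 0.0, 'LOWER_IS_BETTER', relative=0.0)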
def validate_metrics(
    sliced_metrics: Tuple[slicer.SliceKeyType,
                          Dict[metric_types.MetricKey, Any]],
    eval_config: config.EvalConfig
) -> validation_result_pb2.ValidationResult:
  """Check the metrics and check whether they should be validated."""
  # Find out which model is baseline.
  baseline_spec = model_util.get_baseline_model_spec(eval_config)
  baseline_model_name = baseline_spec.name if baseline_spec else None
  sliced_key, metrics = sliced_metrics
  thresholds = metric_specs.metric_thresholds_from_metrics_specs(
      eval_config.metrics_specs)

  def _check_threshold(key: metric_types.MetricKey, metric: Any) -> bool:
    """Verify a metric given its metric key and metric value."""
    threshold = thresholds[key]
    if isinstance(threshold, config.GenericValueThreshold):
      lower_bound, upper_bound = -np.inf, np.inf
      if threshold.HasField('lower_bound'):
        lower_bound = threshold.lower_bound.value
      if threshold.HasField('upper_bound'):
        upper_bound = threshold.upper_bound.value
      return metric > lower_bound and metric < upper_bound
    elif isinstance(threshold, config.GenericChangeThreshold):
      diff = metric
      ratio = diff / metrics[key.make_baseline_key(baseline_model_name)]
      if threshold.direction == config.MetricDirection.LOWER_IS_BETTER:
        absolute, relative = np.inf, np.inf
      elif threshold.direction == config.MetricDirection.HIGHER_IS_BETTER:
        absolute, relative = -np.inf, -np.inf
      else:
        raise ValueError('"UNKNOWN" direction for change threshold.')
      if threshold.HasField('absolute'):
        absolute = threshold.absolute.value
      if threshold.HasField('relative'):
        relative = threshold.relative.value
      if threshold.direction == config.MetricDirection.LOWER_IS_BETTER:
        return diff < absolute and ratio < relative
      elif threshold.direction == config.MetricDirection.HIGHER_IS_BETTER:
        return diff > absolute and ratio > relative

  def _copy_metric(metric, to):
    # Will add more types when more MetricValue are supported.
    to.double_value.value = float(metric)

  def _copy_threshold(threshold, to):
    if isinstance(threshold, config.GenericValueThreshold):
      to.value_threshold.CopyFrom(threshold)
    if isinstance(threshold, config.GenericChangeThreshold):
      to.change_threshold.CopyFrom(threshold)

  # Empty metrics per slice is considered validated.
  result = validation_result_pb2.ValidationResult(validation_ok=True)
  validation_for_slice = validation_result_pb2.MetricsValidationForSlice()
  for metric_key, metric in metrics.items():
    # It is not meaningful to check a threshold against the baseline model, so
    # any such threshold is treated as passing. We also do not compare Message
    # type metrics.
    if (metric_key.model_name == baseline_model_name or
        metric_key not in thresholds):
      continue
    msg = ''
    # We try to convert to float values.
    try:
      metric = float(metric)
    except (TypeError, ValueError):
      msg = """
        Invalid threshold config: this metric is not comparable to the
        threshold. The type of the metric is: {}, and the metric value is:
        \n{}""".format(type(metric), metric)
    if not _check_threshold(metric_key, metric):
      failure = validation_for_slice.failures.add()
      failure.metric_key.CopyFrom(metric_key.to_proto())
      _copy_metric(metric, failure.metric_value)
      _copy_threshold(thresholds[metric_key], failure.metric_threshold)
      failure.message = msg
  # Any failure leads to overall failure.
  if validation_for_slice.failures:
    validation_for_slice.slice_key.CopyFrom(
        slicer.serialize_slice_key(sliced_key))
    result.validation_ok = False
    result.metric_validations_per_slice.append(validation_for_slice)
  return result
def testMetricThresholdsFromMetricsSpecs(self):
  slice_specs = [
      config.SlicingSpec(feature_keys=['feature1']),
      config.SlicingSpec(feature_values={'feature2': 'value1'})
  ]
  # For cross slice tests.
  baseline_slice_spec = config.SlicingSpec(feature_keys=['feature3'])

  metrics_specs = [
      config.MetricsSpec(
          thresholds={
              'auc':
                  config.MetricThreshold(
                      value_threshold=config.GenericValueThreshold()),
              'mean/label':
                  config.MetricThreshold(
                      value_threshold=config.GenericValueThreshold(),
                      change_threshold=config.GenericChangeThreshold()),
              'mse':
                  config.MetricThreshold(
                      change_threshold=config.GenericChangeThreshold())
          },
          per_slice_thresholds={
              'auc':
                  config.PerSliceMetricThresholds(thresholds=[
                      config.PerSliceMetricThreshold(
                          slicing_specs=slice_specs,
                          threshold=config.MetricThreshold(
                              value_threshold=config.GenericValueThreshold()))
                  ]),
              'mean/label':
                  config.PerSliceMetricThresholds(thresholds=[
                      config.PerSliceMetricThreshold(
                          slicing_specs=slice_specs,
                          threshold=config.MetricThreshold(
                              value_threshold=config.GenericValueThreshold(),
                              change_threshold=config.GenericChangeThreshold())
                      )
                  ])
          },
          cross_slice_thresholds={
              'auc':
                  config.CrossSliceMetricThresholds(thresholds=[
                      config.CrossSliceMetricThreshold(
                          cross_slicing_specs=[
                              config.CrossSlicingSpec(
                                  baseline_spec=baseline_slice_spec,
                                  slicing_specs=slice_specs)
                          ],
                          threshold=config.MetricThreshold(
                              value_threshold=config.GenericValueThreshold(),
                              change_threshold=config.GenericChangeThreshold())
                      )
                  ]),
              'mse':
                  config.CrossSliceMetricThresholds(thresholds=[
                      config.CrossSliceMetricThreshold(
                          cross_slicing_specs=[
                              config.CrossSlicingSpec(
                                  baseline_spec=baseline_slice_spec,
                                  slicing_specs=slice_specs)
                          ],
                          threshold=config.MetricThreshold(
                              change_threshold=config.GenericChangeThreshold())
                      ),
                      # Test for duplicate cross_slicing_spec.
                      config.CrossSliceMetricThreshold(
                          cross_slicing_specs=[
                              config.CrossSlicingSpec(
                                  baseline_spec=baseline_slice_spec,
                                  slicing_specs=slice_specs)
                          ],
                          threshold=config.MetricThreshold(
                              value_threshold=config.GenericValueThreshold()))
                  ])
          },
          model_names=['model_name'],
          output_names=['output_name']),
      config.MetricsSpec(
          metrics=[
              config.MetricConfig(
                  class_name='ExampleCount',
                  config=json.dumps({'name': 'example_count'}),
                  threshold=config.MetricThreshold(
                      value_threshold=config.GenericValueThreshold()))
          ],
          model_names=['model_name1', 'model_name2'],
          output_names=['output_name1', 'output_name2']),
      config.MetricsSpec(
          metrics=[
              config.MetricConfig(
                  class_name='WeightedExampleCount',
                  config=json.dumps({'name': 'weighted_example_count'}),
                  threshold=config.MetricThreshold(
                      value_threshold=config.GenericValueThreshold()))
          ],
          model_names=['model_name1', 'model_name2'],
          output_names=['output_name1', 'output_name2']),
      config.MetricsSpec(
          metrics=[
              config.MetricConfig(
                  class_name='MeanSquaredError',
                  config=json.dumps({'name': 'mse'}),
                  threshold=config.MetricThreshold(
                      change_threshold=config.GenericChangeThreshold())),
              config.MetricConfig(
                  class_name='MeanLabel',
                  config=json.dumps({'name': 'mean_label'}),
                  threshold=config.MetricThreshold(
                      change_threshold=config.GenericChangeThreshold()),
                  per_slice_thresholds=[
                      config.PerSliceMetricThreshold(
                          slicing_specs=slice_specs,
                          threshold=config.MetricThreshold(
                              change_threshold=config.GenericChangeThreshold())
                      ),
                  ],
                  cross_slice_thresholds=[
                      config.CrossSliceMetricThreshold(
                          cross_slicing_specs=[
                              config.CrossSlicingSpec(
                                  baseline_spec=baseline_slice_spec,
                                  slicing_specs=slice_specs)
                          ],
                          threshold=config.MetricThreshold(
                              change_threshold=config.GenericChangeThreshold())
                      )
                  ]),
          ],
          model_names=['model_name'],
          output_names=['output_name'],
          binarize=config.BinarizationOptions(class_ids={'values': [0, 1]}),
          aggregate=config.AggregationOptions(
              macro_average=True, class_weights={
                  0: 1.0,
                  1: 1.0
              }))
  ]

  thresholds = metric_specs.metric_thresholds_from_metrics_specs(metrics_specs)

  expected_keys_and_threshold_counts = {
      # For example, auc (is_diff=False) collects 4 thresholds: one global
      # value threshold, one per-slice value threshold for each of the two
      # slicing specs, and one cross-slice value threshold.
      metric_types.MetricKey(
          name='auc',
          model_name='model_name',
          output_name='output_name',
          is_diff=False): 4,
      metric_types.MetricKey(
          name='auc',
          model_name='model_name',
          output_name='output_name',
          is_diff=True): 1,
      metric_types.MetricKey(
          name='mean/label',
          model_name='model_name',
          output_name='output_name',
          is_diff=True): 3,
      metric_types.MetricKey(
          name='mean/label',
          model_name='model_name',
          output_name='output_name',
          is_diff=False): 3,
      metric_types.MetricKey(
          name='example_count',
          model_name='model_name1',
          output_name='output_name1'): 1,
      metric_types.MetricKey(
          name='example_count',
          model_name='model_name1',
          output_name='output_name2'): 1,
      metric_types.MetricKey(
          name='example_count',
          model_name='model_name2',
          output_name='output_name1'): 1,
      metric_types.MetricKey(
          name='example_count',
          model_name='model_name2',
          output_name='output_name2'): 1,
      metric_types.MetricKey(
          name='weighted_example_count',
          model_name='model_name1',
          output_name='output_name1'): 1,
      metric_types.MetricKey(
          name='weighted_example_count',
          model_name='model_name1',
          output_name='output_name2'): 1,
      metric_types.MetricKey(
          name='weighted_example_count',
          model_name='model_name2',
          output_name='output_name1'): 1,
      metric_types.MetricKey(
          name='weighted_example_count',
          model_name='model_name2',
          output_name='output_name2'): 1,
      metric_types.MetricKey(
          name='mse',
          model_name='model_name',
          output_name='output_name',
          sub_key=metric_types.SubKey(class_id=0),
          is_diff=True): 1,
      metric_types.MetricKey(
          name='mse',
          model_name='model_name',
          output_name='output_name',
          sub_key=metric_types.SubKey(class_id=1),
          is_diff=True): 1,
      metric_types.MetricKey(
          name='mse',
          model_name='model_name',
          output_name='output_name',
          is_diff=True): 2,
      metric_types.MetricKey(
          name='mse',
          model_name='model_name',
          output_name='output_name',
          is_diff=False): 1,
      metric_types.MetricKey(
          name='mse',
          model_name='model_name',
          output_name='output_name',
          aggregation_type=metric_types.AggregationType(macro_average=True),
          is_diff=True): 1,
      metric_types.MetricKey(
          name='mean_label',
          model_name='model_name',
          output_name='output_name',
          sub_key=metric_types.SubKey(class_id=0),
          is_diff=True): 4,
      metric_types.MetricKey(
          name='mean_label',
          model_name='model_name',
          output_name='output_name',
          sub_key=metric_types.SubKey(class_id=1),
          is_diff=True): 4,
      metric_types.MetricKey(
          name='mean_label',
          model_name='model_name',
          output_name='output_name',
          aggregation_type=metric_types.AggregationType(macro_average=True),
          is_diff=True): 4
  }

  self.assertLen(thresholds, len(expected_keys_and_threshold_counts))
  for key, count in expected_keys_and_threshold_counts.items():
    self.assertIn(key, thresholds)
    self.assertLen(thresholds[key], count, 'failed for key {}'.format(key))
def validate_metrics(
    sliced_metrics: Tuple[slicer.SliceKeyType,
                          Dict['metric_types.MetricKey', Any]],
    eval_config: config.EvalConfig
) -> validation_result_pb2.ValidationResult:
  """Check the metrics and check whether they should be validated."""
  # Find out which model is baseline.
  baseline_spec = model_util.get_baseline_model_spec(eval_config)
  baseline_model_name = baseline_spec.name if baseline_spec else None
  sliced_key, metrics = sliced_metrics
  thresholds = metric_specs.metric_thresholds_from_metrics_specs(
      eval_config.metrics_specs)  # pytype: disable=wrong-arg-types

  def _check_threshold(key: metric_types.MetricKey,
                       slicing_spec: Optional[config.SlicingSpec],
                       threshold: _ThresholdType, metric: Any) -> bool:
    """Verify a metric given its metric key and metric value."""
    if (slicing_spec is not None and
        not slicer.SingleSliceSpec(
            spec=slicing_spec).is_slice_applicable(sliced_key)):
      return True
    if isinstance(threshold, config.GenericValueThreshold):
      lower_bound, upper_bound = -np.inf, np.inf
      if threshold.HasField('lower_bound'):
        lower_bound = threshold.lower_bound.value
      if threshold.HasField('upper_bound'):
        upper_bound = threshold.upper_bound.value
      return metric > lower_bound and metric < upper_bound
    elif isinstance(threshold, config.GenericChangeThreshold):
      diff = metric
      ratio = diff / metrics[key.make_baseline_key(baseline_model_name)]
      if threshold.direction == config.MetricDirection.LOWER_IS_BETTER:
        absolute, relative = np.inf, np.inf
      elif threshold.direction == config.MetricDirection.HIGHER_IS_BETTER:
        absolute, relative = -np.inf, -np.inf
      else:
        raise ValueError('"UNKNOWN" direction for change threshold.')
      if threshold.HasField('absolute'):
        absolute = threshold.absolute.value
      if threshold.HasField('relative'):
        relative = threshold.relative.value
      if threshold.direction == config.MetricDirection.LOWER_IS_BETTER:
        return diff < absolute and ratio < relative
      elif threshold.direction == config.MetricDirection.HIGHER_IS_BETTER:
        return diff > absolute and ratio > relative

  def _copy_metric(metric, to):
    # Will add more types when more MetricValue are supported.
    to.double_value.value = float(metric)

  def _copy_threshold(threshold, to):
    if isinstance(threshold, config.GenericValueThreshold):
      to.value_threshold.CopyFrom(threshold)
    if isinstance(threshold, config.GenericChangeThreshold):
      to.change_threshold.CopyFrom(threshold)

  def _add_to_set(s, v):
    """Adds value to set. Returns true if it didn't already exist."""
    if v in s:
      return False
    else:
      s.add(v)
      return True

  # Empty metrics per slice is considered validated.
  result = validation_result_pb2.ValidationResult(validation_ok=True)
  validation_for_slice = validation_result_pb2.MetricsValidationForSlice()
  unchecked_thresholds = dict(thresholds)
  for metric_key, metric in metrics.items():
    if metric_key not in thresholds:
      continue
    del unchecked_thresholds[metric_key]
    # It is not meaningful to check a threshold against the baseline model, so
    # any such threshold is treated as passing. We also do not compare Message
    # type metrics.
    if metric_key.model_name == baseline_model_name:
      continue
    msg = ''
    # We try to convert to float values.
    try:
      metric = float(metric)
    except (TypeError, ValueError):
      msg = """
        Invalid threshold config: this metric is not comparable to the
        threshold. The type of the metric is: {}, and the metric value is:
        \n{}""".format(type(metric), metric)
    existing_failures = set()
    for slice_spec, threshold in thresholds[metric_key]:
      if not _check_threshold(metric_key, slice_spec, threshold, metric):
        # The same threshold values could be set for multiple matching slice
        # specs. Only store the first match.
        #
        # Note that hashing by SerializeToString() is only safe if used within
        # the same process.
        if not _add_to_set(existing_failures, threshold.SerializeToString()):
          continue
        failure = validation_for_slice.failures.add()
        failure.metric_key.CopyFrom(metric_key.to_proto())
        _copy_metric(metric, failure.metric_value)
        _copy_threshold(threshold, failure.metric_threshold)
        failure.message = msg
  # All unchecked thresholds are considered failures.
  for metric_key, sliced_thresholds in unchecked_thresholds.items():
    if metric_key.model_name == baseline_model_name:
      continue
    existing_failures = set()
    for _, threshold in sliced_thresholds:
      # The same threshold values could be set for multiple matching slice
      # specs. Only store the first match.
      #
      # Note that hashing by SerializeToString() is only safe if used within
      # the same process.
      if not _add_to_set(existing_failures, threshold.SerializeToString()):
        continue
      failure = validation_for_slice.failures.add()
      failure.metric_key.CopyFrom(metric_key.to_proto())
      _copy_threshold(threshold, failure.metric_threshold)
      failure.message = 'Metric not found.'
  # Any failure leads to overall failure.
  if validation_for_slice.failures:
    validation_for_slice.slice_key.CopyFrom(
        slicer.serialize_slice_key(sliced_key))
    result.validation_ok = False
    result.metric_validations_per_slice.append(validation_for_slice)
  return result
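Per-slice thresholds only apply when the slice key matches the slicing spec, which is what the is_slice_applicable guard at the top of _check_threshold decides. A simplified illustration of that matching, with slice keys as tuples of (feature, value) pairs; matches_spec is a hypothetical name and only covers the feature_keys/feature_values parts of a SlicingSpec.

def matches_spec(slice_key, feature_keys=(), feature_values=None):
  """True when the slice key is made up of exactly the spec's features and
  every feature_values entry appears in the key with the given value."""
  feature_values = feature_values or {}
  wanted = set(feature_keys) | set(feature_values)
  # The key must slice on exactly the features the spec names.
  if {k for k, _ in slice_key} != wanted:
    return False
  # Pinned values must match as well.
  return all((k, v) in slice_key for k, v in feature_values.items())


assert matches_spec((('age', 30),), feature_keys=['age'])
assert matches_spec((('country', 'us'),), feature_values={'country': 'us'})
assert not matches_spec((('country', 'gb'),), feature_values={'country': 'us'})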