def _serialize_plots(
    plots: Tuple[slicer.SliceKeyType, Dict[Any, Any]],
    add_metrics_callbacks: List[types.AddMetricsCallbackType]) -> bytes:
  """Serializes the plots of a single slice as a PlotsForSlice proto.

  Args:
    plots: A (slice key, slice plots) pair.
    add_metrics_callbacks: A list of metric callbacks. This should be the same
      list as the one passed to tfma.Evaluate().

  Returns:
    The serialized PlotsForSlice proto. When the slice plots contain
    metric_keys.ERROR_METRIC, the proto carries only the slice key and the
    error's debug message.
  """
  slice_key, slice_plots = plots
  slice_failed = metric_keys.ERROR_METRIC in slice_plots
  if slice_failed:
    tf.compat.v1.logging.warning(
        'Error for slice: %s with error message: %s ', slice_key,
        slice_plots[metric_keys.ERROR_METRIC])

  plots_proto = metrics_for_slice_pb2.PlotsForSlice()
  plots_proto.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))
  if slice_failed:
    # Only the error is reported; no plot conversion is attempted.
    plots_proto.plots[metric_keys.ERROR_METRIC].debug_message = slice_plots[
        metric_keys.ERROR_METRIC]
  else:
    _convert_slice_plots(slice_plots, add_metrics_callbacks, plots_proto)  # pytype: disable=wrong-arg-types
  return plots_proto.SerializeToString()
def testValidateMetricsMetricTDistributionValueAndThreshold(
    self, slicing_specs, slice_key):
  # The unsampled AUC (0.8) is below the 0.9 lower bound, so validation must
  # fail and report the unsampled value.
  threshold = config.MetricThreshold(
      value_threshold=config.GenericValueThreshold(
          lower_bound={'value': 0.9}))
  auc_config = config.MetricConfig(
      class_name='AUC',
      threshold=threshold if slicing_specs is None else None,
      per_slice_thresholds=[
          config.PerSliceMetricThreshold(
              slicing_specs=slicing_specs, threshold=threshold)
      ])
  eval_config = config.EvalConfig(
      model_specs=[config.ModelSpec()],
      slicing_specs=slicing_specs,
      metrics_specs=[
          config.MetricsSpec(metrics=[auc_config], model_names=['']),
      ],
  )
  auc_value = types.ValueWithTDistribution(
      sample_mean=0.91, unsampled_value=0.8)
  sliced_metrics = (slice_key, {metric_types.MetricKey(name='auc'): auc_value})

  result = metrics_validator.validate_metrics(sliced_metrics, eval_config)
  self.assertFalse(result.validation_ok)

  expected = text_format.Parse(
      """ metric_validations_per_slice { failures { metric_key { name: "auc" } metric_value { double_value { value: 0.8 } } } }""",
      validation_result_pb2.ValidationResult())
  slice_validation = expected.metric_validations_per_slice[0]
  slice_validation.failures[0].metric_threshold.CopyFrom(threshold)
  slice_validation.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))
  # One slicing_details entry is expected per spec that applies to the slice;
  # with no slicing specs, a single entry with an empty SlicingSpec.
  for spec in slicing_specs or [None]:
    if spec is not None and not slicer.SingleSliceSpec(
        spec=spec).is_slice_applicable(slice_key):
      continue
    slicing_details = expected.validation_details.slicing_details.add()
    if spec is None:
      slicing_details.slicing_spec.CopyFrom(config.SlicingSpec())
    else:
      slicing_details.slicing_spec.CopyFrom(spec)
    slicing_details.num_matching_slices = 1
  self.assertEqual(result, expected)
def testConvertSlicePlotsToProtoEmptyPlot(self):
  # Plots holding only an error entry convert to a proto with the slice key
  # and the error's debug message.
  slice_key = _make_slice_key('fruit', 'apple')

  expected_plot = metrics_for_slice_pb2.PlotsForSlice()
  expected_plot.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))
  error_entry = expected_plot.plots[metric_keys.ERROR_METRIC]
  error_entry.debug_message = 'error_message'

  slice_plots = (slice_key, {metric_keys.ERROR_METRIC: 'error_message'})
  actual_plot = (
      metrics_plots_and_validations_writer.convert_slice_plots_to_proto(
          slice_plots, []))
  self.assertProtoEquals(expected_plot, actual_plot)
def testSerializePlots_emptyPlot(self):
  # Plots holding only an error entry serialize to a proto with the slice key
  # and the error's debug message.
  slice_key = _make_slice_key('fruit', 'apple')

  expected_plot = metrics_for_slice_pb2.PlotsForSlice()
  expected_plot.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))
  error_entry = expected_plot.plots[metric_keys.ERROR_METRIC]
  error_entry.debug_message = 'error_message'

  slice_plots = (slice_key, {metric_keys.ERROR_METRIC: 'error_message'})
  serialized_plot = metrics_and_plots_serialization._serialize_plots(
      slice_plots, [])
  self.assertProtoEquals(
      expected_plot,
      metrics_for_slice_pb2.PlotsForSlice.FromString(serialized_plot))
def convert_slice_attributions_to_proto(
    attributions: Tuple[slicer.SliceKeyOrCrossSliceKeyType,
                        Dict[Any, Dict[Text, Any]]]
) -> metrics_for_slice_pb2.AttributionsForSlice:
  """Converts the given slice attributions into an AttributionsForSlice proto.

  Args:
    attributions: A (slice key, slice attributions) pair. The attributions map
      each attribution key to a dict of per-feature values.

  Returns:
    The AttributionsForSlice proto.

  Raises:
    TypeError: If the type of the feature value in slice key cannot be
      recognized.
  """
  result = metrics_for_slice_pb2.AttributionsForSlice()
  slice_key, slice_attributions = attributions

  if slicer.is_cross_slice_key(slice_key):
    result.cross_slice_key.CopyFrom(
        slicer.serialize_cross_slice_key(slice_key))
  else:
    result.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))

  slice_attributions = slice_attributions.copy()
  # Keys are emitted in sorted order so the output does not depend on dict
  # iteration order.
  for key in sorted(slice_attributions.keys()):
    key_and_value = result.attributions_keys_and_values.add()
    key_and_value.key.CopyFrom(key.to_proto())
    for feature, value in slice_attributions[key].items():
      attribution_value = metrics_for_slice_pb2.MetricValue()
      if isinstance(value, six.binary_type):
        # Byte strings are stored as-is.
        attribution_value.bytes_value = value
      elif isinstance(value, six.text_type):
        # Text is stored as UTF-8 encoded bytes.
        attribution_value.bytes_value = value.encode('utf8')
      elif isinstance(value, np.ndarray) and value.size != 1:
        # Convert NumPy arrays to ArrayValue. Single-element arrays fall
        # through to the float conversion below.
        attribution_value.array_value.CopyFrom(
            _convert_to_array_value(value))
      else:
        # We try to convert to float values.
        try:
          attribution_value.double_value.value = float(value)
        except (TypeError, ValueError) as e:
          attribution_value.unknown_type.value = str(value)
          # Exceptions have no `message` attribute on Python 3; str(e) works
          # on both Python 2 and 3.
          attribution_value.unknown_type.error = str(e)
      key_and_value.values[feature].CopyFrom(attribution_value)

  return result
def _serialize_metrics(
    metrics: Tuple[slicer.SliceKeyType, Dict[Any, Any]],
    add_metrics_callbacks: List[types.AddMetricsCallbackType]) -> bytes:
  """Serializes the metrics of a single slice as a MetricsForSlice proto.

  Args:
    metrics: A (slice key, slice metrics) pair.
    add_metrics_callbacks: A list of metric callbacks. This should be the same
      list as the one passed to tfma.Evaluate().

  Returns:
    The serialized MetricsForSlice proto. When the slice metrics contain
    metric_keys.ERROR_METRIC, the proto carries only the slice key and the
    error's debug message.

  Raises:
    TypeError: If the type of the feature value in slice key cannot be
      recognized.
  """
  slice_key, slice_metrics = metrics
  slice_failed = metric_keys.ERROR_METRIC in slice_metrics
  if slice_failed:
    tf.compat.v1.logging.warning(
        'Error for slice: %s with error message: %s ', slice_key,
        slice_metrics[metric_keys.ERROR_METRIC])

  metrics_proto = metrics_for_slice_pb2.MetricsForSlice()
  metrics_proto.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))
  if slice_failed:
    # Only the error is reported; no metric conversion is attempted.
    error_entry = metrics_proto.metrics[metric_keys.ERROR_METRIC]
    error_entry.debug_message = slice_metrics[metric_keys.ERROR_METRIC]
  else:
    convert_slice_metrics(slice_key, slice_metrics, add_metrics_callbacks,
                          metrics_proto)
  return metrics_proto.SerializeToString()
def testValidateMetricsMetricValueAndThreshold(self, slicing_specs, slice_key):
  # The metric value (1.5) exceeds the upper bound (1), so validation must
  # fail and report the offending value together with the threshold.
  threshold = config.MetricThreshold(
      value_threshold=config.GenericValueThreshold(upper_bound={'value': 1}))
  example_count_config = config.MetricConfig(
      class_name='WeightedExampleCount',
      # 1.5 < 1, NOT OK.
      threshold=threshold if slicing_specs is None else None,
      per_slice_thresholds=[
          config.PerSliceMetricThreshold(
              slicing_specs=slicing_specs, threshold=threshold)
      ])
  eval_config = config.EvalConfig(
      model_specs=[config.ModelSpec()],
      slicing_specs=slicing_specs,
      metrics_specs=[
          config.MetricsSpec(
              metrics=[example_count_config], model_names=['']),
      ],
  )
  metric_key = metric_types.MetricKey(name='weighted_example_count')
  sliced_metrics = (slice_key, {metric_key: 1.5})

  result = metrics_validator.validate_metrics(sliced_metrics, eval_config)
  self.assertFalse(result.validation_ok)

  expected = text_format.Parse(
      """ metric_validations_per_slice { failures { metric_key { name: "weighted_example_count" } metric_value { double_value { value: 1.5 } } } }""",
      validation_result_pb2.ValidationResult())
  slice_validation = expected.metric_validations_per_slice[0]
  slice_validation.failures[0].metric_threshold.CopyFrom(threshold)
  slice_validation.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))
  self.assertEqual(result, expected)
def testConvertSliceMetricsToProtoEmptyMetrics(self):
  # Metrics holding only an error entry convert to a proto with the slice key
  # and the error's debug message, even when callbacks are supplied.
  slice_key = _make_slice_key('age', 5, 'language', 'english', 'price', 0.3)

  expected_metrics = metrics_for_slice_pb2.MetricsForSlice()
  expected_metrics.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))
  error_entry = expected_metrics.metrics[metric_keys.ERROR_METRIC]
  error_entry.debug_message = 'error_message'

  sliced_metrics = (slice_key, {metric_keys.ERROR_METRIC: 'error_message'})
  callbacks = [post_export_metrics.auc(), post_export_metrics.auc(curve='PR')]
  actual_metrics = (
      metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
          sliced_metrics, callbacks))
  self.assertProtoEquals(expected_metrics, actual_metrics)
def testSerializeMetrics_emptyMetrics(self):
  # Metrics holding only an error entry serialize to a proto with the slice
  # key and the error's debug message, even when callbacks are supplied.
  slice_key = _make_slice_key('age', 5, 'language', 'english', 'price', 0.3)

  expected_metrics = metrics_for_slice_pb2.MetricsForSlice()
  expected_metrics.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))
  error_entry = expected_metrics.metrics[metric_keys.ERROR_METRIC]
  error_entry.debug_message = 'error_message'

  sliced_metrics = (slice_key, {metric_keys.ERROR_METRIC: 'error_message'})
  callbacks = [post_export_metrics.auc(), post_export_metrics.auc(curve='PR')]
  serialized_metrics = metrics_and_plots_serialization._serialize_metrics(
      sliced_metrics, callbacks)
  self.assertProtoEquals(
      expected_metrics,
      metrics_for_slice_pb2.MetricsForSlice.FromString(serialized_metrics))
def convert_slice_plots_to_proto(
    plots: Tuple[slicer.SliceKeyType, Dict[Any, Any]],
    add_metrics_callbacks: List[types.AddMetricsCallbackType]
) -> metrics_for_slice_pb2.PlotsForSlice:
  """Builds a PlotsForSlice proto from the plots of a single slice.

  Args:
    plots: A (slice key, slice plots) pair.
    add_metrics_callbacks: A list of metric callbacks. This should be the same
      list as the one passed to tfma.Evaluate().

  Returns:
    The PlotsForSlice proto. When the slice plots contain
    metric_keys.ERROR_METRIC, only the slice key and the error's debug message
    are populated.

  Raises:
    NotImplementedError: If some plots remain after conversion, i.e. they were
      neither converted here nor popped by a callback.
  """
  plots_proto = metrics_for_slice_pb2.PlotsForSlice()
  slice_key, remaining_plots = plots
  plots_proto.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))

  # Work on a copy; entries are popped from it as they are converted.
  remaining_plots = remaining_plots.copy()

  if metric_keys.ERROR_METRIC in remaining_plots:
    logging.warning('Error for slice: %s with error message: %s ', slice_key,
                    remaining_plots[metric_keys.ERROR_METRIC])
    plots_proto.plots[metric_keys.ERROR_METRIC].debug_message = (
        remaining_plots.pop(metric_keys.ERROR_METRIC))
    return plots_proto

  # Callbacks are consulted only when no plot is keyed by a MetricKey.
  if add_metrics_callbacks and all(
      not isinstance(k, metric_types.MetricKey) for k in remaining_plots):
    for callback in add_metrics_callbacks:
      if hasattr(callback, 'populate_plots_and_pop'):
        callback.populate_plots_and_pop(remaining_plots, plots_proto.plots)

  # Supported plot proto types and the plot data field each is copied into.
  # Checked in order; the first matching type wins.
  plot_fields = (
      (metrics_for_slice_pb2.CalibrationHistogramBuckets,
       'calibration_histogram_buckets'),
      (metrics_for_slice_pb2.ConfusionMatrixAtThresholds,
       'confusion_matrix_at_thresholds'),
      (metrics_for_slice_pb2.MultiClassConfusionMatrixAtThresholds,
       'multi_class_confusion_matrix_at_thresholds'),
      (metrics_for_slice_pb2.MultiLabelConfusionMatrixAtThresholds,
       'multi_label_confusion_matrix_at_thresholds'),
  )

  plot_data_by_parent = {}
  for key in sorted(remaining_plots):
    if not isinstance(key, metric_types.MetricKey):
      continue
    value = remaining_plots[key]
    # Remove plot name from key (multiple plots are combined into a single
    # proto).
    parent_key = key._replace(name=None)
    if parent_key not in plot_data_by_parent:
      entry = plots_proto.plot_keys_and_values.add()
      entry.key.CopyFrom(parent_key.to_proto())
      plot_data_by_parent[parent_key] = entry.value
    plot_data = plot_data_by_parent[parent_key]

    for plot_type, field_name in plot_fields:
      if isinstance(value, plot_type):
        getattr(plot_data, field_name).CopyFrom(value)
        remaining_plots.pop(key)
        break

  if remaining_plots:
    callbacks = [] if add_metrics_callbacks is None else add_metrics_callbacks
    raise NotImplementedError(
        'some plots were not converted or popped. keys: %s. '
        'add_metrics_callbacks were: %s' % (
            remaining_plots.keys(),
            [
                x.name for x in callbacks  # pytype: disable=attribute-error
            ]))

  return plots_proto
def convert_slice_metrics_to_proto(
    metrics: Tuple[slicer.SliceKeyType, Dict[Any, Any]],
    add_metrics_callbacks: List[types.AddMetricsCallbackType]
) -> metrics_for_slice_pb2.MetricsForSlice:
  """Converts the given slice metrics into a MetricsForSlice proto.

  Args:
    metrics: A (slice key, slice metrics) pair.
    add_metrics_callbacks: A list of metric callbacks. This should be the same
      list as the one passed to tfma.Evaluate().

  Returns:
    The MetricsForSlice proto. When the slice metrics contain
    metric_keys.ERROR_METRIC, only the slice key and the error's debug message
    are populated.

  Raises:
    TypeError: If the type of the feature value in slice key cannot be
      recognized.
  """
  result = metrics_for_slice_pb2.MetricsForSlice()
  slice_key, slice_metrics = metrics

  result.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))

  # Work on a copy; callbacks may pop entries from it.
  slice_metrics = slice_metrics.copy()

  if metric_keys.ERROR_METRIC in slice_metrics:
    logging.warning('Error for slice: %s with error message: %s ', slice_key,
                    slice_metrics[metric_keys.ERROR_METRIC])
    result.metrics[metric_keys.ERROR_METRIC].debug_message = slice_metrics[
        metric_keys.ERROR_METRIC]
    return result

  # Convert the metrics from add_metrics_callbacks to the structured output if
  # defined.
  if add_metrics_callbacks and (not any(
      isinstance(k, metric_types.MetricKey) for k in slice_metrics.keys())):
    for add_metrics_callback in add_metrics_callbacks:
      if hasattr(add_metrics_callback, 'populate_stats_and_pop'):
        add_metrics_callback.populate_stats_and_pop(slice_key, slice_metrics,
                                                    result.metrics)

  for key in sorted(slice_metrics.keys()):
    value = slice_metrics[key]
    metric_value = metrics_for_slice_pb2.MetricValue()
    if isinstance(value, metrics_for_slice_pb2.ConfusionMatrixAtThresholds):
      metric_value.confusion_matrix_at_thresholds.CopyFrom(value)
    elif isinstance(
        value, metrics_for_slice_pb2.MultiClassConfusionMatrixAtThresholds):
      metric_value.multi_class_confusion_matrix_at_thresholds.CopyFrom(value)
    elif isinstance(value, types.ValueWithTDistribution):
      # Currently we populate both bounded_value and confidence_interval.
      # Avoid populating bounded_value once the UI handles confidence_interval.
      # Convert to a bounded value. 95% confidence level is computed here.
      _, lower_bound, upper_bound = (
          math_util.calculate_confidence_interval(value))
      metric_value.bounded_value.value.value = value.unsampled_value
      metric_value.bounded_value.lower_bound.value = lower_bound
      metric_value.bounded_value.upper_bound.value = upper_bound
      metric_value.bounded_value.methodology = (
          metrics_for_slice_pb2.BoundedValue.POISSON_BOOTSTRAP)
      # Populate confidence_interval
      metric_value.confidence_interval.lower_bound.value = lower_bound
      metric_value.confidence_interval.upper_bound.value = upper_bound
      t_dist_value = metrics_for_slice_pb2.TDistributionValue()
      t_dist_value.sample_mean.value = value.sample_mean
      t_dist_value.sample_standard_deviation.value = (
          value.sample_standard_deviation)
      t_dist_value.sample_degrees_of_freedom.value = (
          value.sample_degrees_of_freedom)
      # Once the UI handles confidence interval, we will avoid setting this and
      # instead use the double_value.
      t_dist_value.unsampled_value.value = value.unsampled_value
      metric_value.confidence_interval.t_distribution_value.CopyFrom(
          t_dist_value)
    elif isinstance(value, six.binary_type):
      # Byte strings are stored as-is.
      metric_value.bytes_value = value
    elif isinstance(value, six.text_type):
      # Text is stored as UTF-8 encoded bytes.
      metric_value.bytes_value = value.encode('utf8')
    elif isinstance(value, np.ndarray):
      # Convert NumPy arrays to ArrayValue.
      metric_value.array_value.CopyFrom(_convert_to_array_value(value))
    else:
      # We try to convert to float values.
      try:
        metric_value.double_value.value = float(value)
      except (TypeError, ValueError) as e:
        metric_value.unknown_type.value = str(value)
        # Exceptions have no `message` attribute on Python 3; str(e) works on
        # both Python 2 and 3.
        metric_value.unknown_type.error = str(e)

    # MetricKey-keyed metrics go to the repeated key/value field; any other
    # key indexes the `metrics` map directly.
    if isinstance(key, metric_types.MetricKey):
      key_and_value = result.metric_keys_and_values.add()
      key_and_value.key.CopyFrom(key.to_proto())
      key_and_value.value.CopyFrom(metric_value)
    else:
      result.metrics[key].CopyFrom(metric_value)

  return result
def testValidateMetricsMetricTDistributionChangeAndThreshold(
    self, slicing_specs, slice_key):
  # The diff metric (0.1) exceeds the absolute change threshold (-1) for a
  # LOWER_IS_BETTER metric, so validation must fail on the diff key.
  threshold = config.MetricThreshold(
      change_threshold=config.GenericChangeThreshold(
          direction=config.MetricDirection.LOWER_IS_BETTER,
          absolute={'value': -1}))
  eval_config = config.EvalConfig(
      model_specs=[
          config.ModelSpec(),
          config.ModelSpec(name='baseline', is_baseline=True)
      ],
      slicing_specs=slicing_specs,
      metrics_specs=[
          config.MetricsSpec(
              metrics=[
                  config.MetricConfig(
                      class_name='AUC',
                      threshold=threshold if slicing_specs is None else None,
                      per_slice_thresholds=[
                          config.PerSliceMetricThreshold(
                              slicing_specs=slicing_specs,
                              threshold=threshold)
                      ]),
              ],
              model_names=['']),
      ],
  )
  sliced_metrics = (
      slice_key,
      {
          # This is the mean of the diff.
          metric_types.MetricKey(name='auc', model_name='baseline'):
              types.ValueWithTDistribution(
                  sample_mean=0.91, unsampled_value=0.6),
          metric_types.MetricKey(name='auc', is_diff=True):
              types.ValueWithTDistribution(
                  sample_mean=0.1, unsampled_value=0.1),
      })
  result = metrics_validator.validate_metrics(sliced_metrics, eval_config)
  self.assertFalse(result.validation_ok)
  expected = text_format.Parse(
      """ metric_validations_per_slice { failures { metric_key { name: "auc" is_diff: true } metric_value { double_value { value: 0.1 } } } }""",
      validation_result_pb2.ValidationResult())
  expected.metric_validations_per_slice[0].failures[
      0].metric_threshold.CopyFrom(threshold)
  expected.metric_validations_per_slice[0].slice_key.CopyFrom(
      slicer.serialize_slice_key(slice_key))
  # One slicing_details entry is expected per spec that applies to the slice;
  # with no slicing specs, a single entry with an empty SlicingSpec.
  for spec in slicing_specs or [None]:
    if (spec is None or
        slicer.SingleSliceSpec(spec=spec).is_slice_applicable(slice_key)):
      slicing_details = expected.validation_details.slicing_details.add()
      if spec is not None:
        slicing_details.slicing_spec.CopyFrom(spec)
      else:
        slicing_details.slicing_spec.CopyFrom(config.SlicingSpec())
      slicing_details.num_matching_slices = 1
  # assertEqual rather than assertAlmostEqual: the latter computes
  # abs(first - second) on a mismatch, which is undefined for proto messages
  # and would surface as a TypeError instead of a readable diff.
  self.assertEqual(result, expected)
def validate_metrics(
    sliced_metrics: Tuple[Union[slicer.SliceKeyType,
                                slicer.CrossSliceKeyType],
                          Dict['metric_types.MetricKey', Any]],
    eval_config: config_pb2.EvalConfig
) -> validation_result_pb2.ValidationResult:
  """Validates the metrics of one slice against the configured thresholds.

  Thresholds are derived from `eval_config.metrics_specs`. Each threshold is
  paired with an optional slicing spec (or cross slicing spec) and is only
  checked when that spec applies to the given slice key. Metrics belonging to
  the baseline model are never checked. Thresholds whose metric key is absent
  from the slice metrics are reported as failures with the message
  'Metric not found.'.

  Args:
    sliced_metrics: A (slice key or cross slice key, metrics) pair, where
      metrics maps metric keys to metric values.
    eval_config: The eval config supplying the model specs and metrics specs.

  Returns:
    A ValidationResult. `validation_ok` is True unless at least one failure
    was recorded, in which case the failures for this slice are appended to
    `metric_validations_per_slice`. `validation_details.slicing_details`
    receives one entry per threshold check performed on a present metric.
  """
  # Find out which model is baseline.
  baseline_spec = model_util.get_baseline_model_spec(eval_config)
  baseline_model_name = baseline_spec.name if baseline_spec else None

  sliced_key, metrics = sliced_metrics
  thresholds = metric_specs.metric_thresholds_from_metrics_specs(
      eval_config.metrics_specs)
  is_cross_slice = slicer.is_cross_slice_key(sliced_key)

  def _check_threshold(key: metric_types.MetricKey, threshold: _ThresholdType,
                       metric: Any) -> bool:
    """Verify a metric given its metric key and metric value.

    Value thresholds use inclusive bounds. Change thresholds treat `metric` as
    the diff and compare both the diff and its ratio to the baseline metric.
    Raises ValueError for an unknown threshold type or direction; float()
    may also raise for a metric that cannot be converted.
    """
    metric = float(metric)
    if isinstance(threshold, config_pb2.GenericValueThreshold):
      # Unset bounds default to +/- infinity, i.e. they never fail.
      lower_bound, upper_bound = -np.inf, np.inf
      if threshold.HasField('lower_bound'):
        lower_bound = threshold.lower_bound.value
      if threshold.HasField('upper_bound'):
        upper_bound = threshold.upper_bound.value
      return metric >= lower_bound and metric <= upper_bound
    elif isinstance(threshold, config_pb2.GenericChangeThreshold):
      diff = metric
      metric_baseline = float(
          metrics[key.make_baseline_key(baseline_model_name)])
      # Avoid dividing by a (near) zero baseline. Comparisons against NaN
      # evaluate to False, so the ratio part of the check fails in that case.
      if math.isclose(metric_baseline, 0.0):
        ratio = float('nan')
      else:
        ratio = diff / metric_baseline
      # Defaults are chosen so that an unset absolute/relative bound always
      # passes for the configured direction.
      if threshold.direction == config_pb2.MetricDirection.LOWER_IS_BETTER:
        absolute, relative = np.inf, np.inf
      elif threshold.direction == config_pb2.MetricDirection.HIGHER_IS_BETTER:
        absolute, relative = -np.inf, -np.inf
      else:
        raise ValueError(
            '"UNKNOWN" direction for change threshold: {}.'.format(threshold))
      if threshold.HasField('absolute'):
        absolute = threshold.absolute.value
      if threshold.HasField('relative'):
        relative = threshold.relative.value
      if threshold.direction == config_pb2.MetricDirection.LOWER_IS_BETTER:
        return diff <= absolute and ratio <= relative
      elif threshold.direction == config_pb2.MetricDirection.HIGHER_IS_BETTER:
        return diff >= absolute and ratio >= relative
    else:
      raise ValueError('Unknown threshold: {}'.format(threshold))

  def _copy_metric(metric, to):
    """Stores the metric as a double value on the given proto."""
    # Will add more types when more MetricValue are supported.
    to.double_value.value = float(metric)

  def _copy_threshold(threshold, to):
    """Copies the threshold into the matching field of the given proto."""
    if isinstance(threshold, config_pb2.GenericValueThreshold):
      to.value_threshold.CopyFrom(threshold)
    if isinstance(threshold, config_pb2.GenericChangeThreshold):
      to.change_threshold.CopyFrom(threshold)

  def _add_to_set(s, v):
    """Adds value to set. Returns true if didn't exist."""
    if v in s:
      return False
    else:
      s.add(v)
      return True

  # Empty metrics per slice is considered validated.
  result = validation_result_pb2.ValidationResult(validation_ok=True)
  validation_for_slice = validation_result_pb2.MetricsValidationForSlice()
  # Entries are removed as their metrics are seen; what remains afterwards are
  # thresholds whose metric is missing from this slice.
  unchecked_thresholds = dict(thresholds)
  for metric_key, metric in metrics.items():
    if metric_key not in thresholds:
      continue
    del unchecked_thresholds[metric_key]
    # Not meaningful to check threshold for baseline model, thus always return
    # True if such threshold is configured. We also do not compare Message type
    # metrics.
    if metric_key.model_name == baseline_model_name:
      continue
    msg = ''
    existing_failures = set()
    for slice_spec, threshold in thresholds[metric_key]:
      # Skip thresholds whose spec does not apply to this slice: a SlicingSpec
      # never applies to a cross slice key and vice versa, and a threshold
      # without a spec is not applied to cross slices.
      if slice_spec is not None:
        if (isinstance(slice_spec, config_pb2.SlicingSpec) and
            (is_cross_slice or not slicer.SingleSliceSpec(
                spec=slice_spec).is_slice_applicable(sliced_key))):
          continue
        if (isinstance(slice_spec, config_pb2.CrossSlicingSpec) and
            (not is_cross_slice or not slicer.is_cross_slice_applicable(
                cross_slice_key=sliced_key, cross_slicing_spec=slice_spec))):
          continue
      elif is_cross_slice:
        continue
      try:
        check_result = _check_threshold(metric_key, threshold, metric)
      except ValueError:
        # Only ValueError is turned into a failure; other exceptions (e.g. a
        # TypeError from float()) propagate to the caller.
        msg = """ Invalid metrics or threshold for comparison: The type of the metric is: {}, the metric value is: {}, and the threshold is: {}. """.format(type(metric), metric, threshold)
        check_result = False
      else:
        msg = ''
      if not check_result:
        # The same threshold values could be set for multiple matching slice
        # specs. Only store the first match.
        #
        # Note that hashing by SerializeToString() is only safe if used within
        # the same process.
        if not _add_to_set(existing_failures, threshold.SerializeToString()):
          continue
        failure = validation_for_slice.failures.add()
        failure.metric_key.CopyFrom(metric_key.to_proto())
        _copy_metric(metric, failure.metric_value)
        _copy_threshold(threshold, failure.metric_threshold)
        failure.message = msg
      # Track we have completed a validation check for slice spec and metric
      slicing_details = result.validation_details.slicing_details.add()
      if slice_spec is not None:
        if isinstance(slice_spec, config_pb2.SlicingSpec):
          slicing_details.slicing_spec.CopyFrom(slice_spec)
        else:
          slicing_details.cross_slicing_spec.CopyFrom(slice_spec)
      else:
        # A threshold without a spec is recorded under an empty SlicingSpec.
        slicing_details.slicing_spec.CopyFrom(config_pb2.SlicingSpec())
      slicing_details.num_matching_slices = 1
  # All unchecked thresholds are considered failures.
  # NOTE(review): this loop rebinds the name `thresholds`; the full mapping is
  # not used again after this point.
  for metric_key, thresholds in unchecked_thresholds.items():
    if metric_key.model_name == baseline_model_name:
      continue
    existing_failures = set()
    for slice_spec, threshold in thresholds:
      if slice_spec is not None:
        # The spec kind must match the key kind (cross slice vs. single
        # slice).
        if is_cross_slice != isinstance(slice_spec,
                                        config_pb2.CrossSlicingSpec):
          continue
        if (is_cross_slice and not slicer.is_cross_slice_applicable(
            cross_slice_key=sliced_key, cross_slicing_spec=slice_spec)):
          continue
      elif is_cross_slice:
        continue
      # The same threshold values could be set for multiple matching slice
      # specs. Only store the first match.
      #
      # Note that hashing by SerializeToString() is only safe if used within
      # the same process.
      if not _add_to_set(existing_failures, threshold.SerializeToString()):
        continue
      failure = validation_for_slice.failures.add()
      failure.metric_key.CopyFrom(metric_key.to_proto())
      _copy_threshold(threshold, failure.metric_threshold)
      failure.message = 'Metric not found.'
  # Any failure leads to overall failure.
  if validation_for_slice.failures:
    if not is_cross_slice:
      validation_for_slice.slice_key.CopyFrom(
          slicer.serialize_slice_key(sliced_key))
    else:
      validation_for_slice.cross_slice_key.CopyFrom(
          slicer.serialize_cross_slice_key(sliced_key))
    result.validation_ok = False
    result.metric_validations_per_slice.append(validation_for_slice)
  return result
def validate_metrics(
    sliced_metrics: Tuple[slicer.SliceKeyType, Dict[metric_types.MetricKey,
                                                    Any]],
    eval_config: config.EvalConfig
) -> validation_result_pb2.ValidationResult:
  """Validates the metrics of one slice against the configured thresholds.

  Thresholds are derived from `eval_config.metrics_specs`. Metrics that belong
  to the baseline model or have no configured threshold are skipped. A metric
  that cannot be converted to float is reported as a failure with an
  explanatory message and no metric value.

  Args:
    sliced_metrics: A (slice key, metrics) pair, where metrics maps metric
      keys to metric values.
    eval_config: The eval config supplying the model specs and metrics specs.

  Returns:
    A ValidationResult. `validation_ok` is True unless at least one failure
    was recorded, in which case the failures for this slice are appended to
    `metric_validations_per_slice`.

  Raises:
    ValueError: If a change threshold has an unknown direction.
  """
  # Find out which model is baseline.
  baseline_spec = model_util.get_baseline_model_spec(eval_config)
  baseline_model_name = baseline_spec.name if baseline_spec else None

  sliced_key, metrics = sliced_metrics
  thresholds = metric_specs.metric_thresholds_from_metrics_specs(
      eval_config.metrics_specs)

  def _check_threshold(key: metric_types.MetricKey, metric: Any) -> bool:
    """Verify a metric given its metric key and metric value."""
    threshold = thresholds[key]
    if isinstance(threshold, config.GenericValueThreshold):
      # Unset bounds default to +/- infinity, i.e. they never fail.
      lower_bound, upper_bound = -np.inf, np.inf
      if threshold.HasField('lower_bound'):
        lower_bound = threshold.lower_bound.value
      if threshold.HasField('upper_bound'):
        upper_bound = threshold.upper_bound.value
      return metric > lower_bound and metric < upper_bound
    elif isinstance(threshold, config.GenericChangeThreshold):
      # `metric` is the diff; the ratio is taken against the baseline metric.
      diff = metric
      ratio = diff / metrics[key.make_baseline_key(baseline_model_name)]
      # Defaults are chosen so that an unset absolute/relative bound always
      # passes for the configured direction.
      if threshold.direction == config.MetricDirection.LOWER_IS_BETTER:
        absolute, relative = np.inf, np.inf
      elif threshold.direction == config.MetricDirection.HIGHER_IS_BETTER:
        absolute, relative = -np.inf, -np.inf
      else:
        raise ValueError('"UNKNOWN" direction for change threshold.')
      if threshold.HasField('absolute'):
        absolute = threshold.absolute.value
      if threshold.HasField('relative'):
        relative = threshold.relative.value
      if threshold.direction == config.MetricDirection.LOWER_IS_BETTER:
        return diff < absolute and ratio < relative
      elif threshold.direction == config.MetricDirection.HIGHER_IS_BETTER:
        return diff > absolute and ratio > relative
    # Unsupported threshold types are treated as a failed check.
    return False

  def _copy_metric(metric, to):
    """Stores the metric as a double value on the given proto."""
    # Will add more types when more MetricValue are supported.
    to.double_value.value = float(metric)

  def _copy_threshold(threshold, to):
    """Copies the threshold into the matching field of the given proto."""
    if isinstance(threshold, config.GenericValueThreshold):
      to.value_threshold.CopyFrom(threshold)
    if isinstance(threshold, config.GenericChangeThreshold):
      to.change_threshold.CopyFrom(threshold)

  # Empty metrics per slice is considered validated.
  result = validation_result_pb2.ValidationResult(validation_ok=True)
  validation_for_slice = validation_result_pb2.MetricsValidationForSlice()
  for metric_key, metric in metrics.items():
    # Not meaningful to check threshold for baseline model, thus always return
    # True if such threshold is configured. We also do not compare Message type
    # metrics.
    if (metric_key.model_name == baseline_model_name or
        metric_key not in thresholds):
      continue
    # We try to convert to float values.
    try:
      metric_value = float(metric)
    except (TypeError, ValueError):
      # The metric cannot be compared to a numeric threshold. Record the
      # failure here: comparing or copying the raw value below would only
      # raise again and lose this message.
      failure = validation_for_slice.failures.add()
      failure.metric_key.CopyFrom(metric_key.to_proto())
      _copy_threshold(thresholds[metric_key], failure.metric_threshold)
      failure.message = """ Invalid threshold config: This metric is not comparable to the threshold. The type of the metric is: {}, and the metric value is: \n{}""".format(type(metric), metric)
      continue
    if not _check_threshold(metric_key, metric_value):
      failure = validation_for_slice.failures.add()
      failure.metric_key.CopyFrom(metric_key.to_proto())
      _copy_metric(metric_value, failure.metric_value)
      _copy_threshold(thresholds[metric_key], failure.metric_threshold)
      failure.message = ''
  # Any failure leads to overall failure.
  if validation_for_slice.failures:
    validation_for_slice.slice_key.CopyFrom(
        slicer.serialize_slice_key(sliced_key))
    result.validation_ok = False
    result.metric_validations_per_slice.append(validation_for_slice)
  return result
def convert_slice_metrics_to_proto(
    metrics: Tuple[slicer.SliceKeyOrCrossSliceKeyType, Dict[Any, Any]],
    add_metrics_callbacks: List[types.AddMetricsCallbackType]
) -> metrics_for_slice_pb2.MetricsForSlice:
  """Builds a MetricsForSlice proto from the metrics of a single slice.

  Args:
    metrics: A (slice key or cross slice key, slice metrics) pair.
    add_metrics_callbacks: A list of metric callbacks. This should be the same
      list as the one passed to tfma.Evaluate().

  Returns:
    The MetricsForSlice proto. When the slice metrics contain
    metric_keys.ERROR_METRIC, only the key and the error's debug message are
    populated.

  Raises:
    TypeError: If the type of the feature value in slice key cannot be
      recognized.
  """
  metrics_proto = metrics_for_slice_pb2.MetricsForSlice()
  slice_key, slice_metrics = metrics

  if slicer.is_cross_slice_key(slice_key):
    serialized_key = slicer.serialize_cross_slice_key(slice_key)
    metrics_proto.cross_slice_key.CopyFrom(serialized_key)
  else:
    serialized_key = slicer.serialize_slice_key(slice_key)
    metrics_proto.slice_key.CopyFrom(serialized_key)

  # Work on a copy; callbacks may pop entries from it.
  slice_metrics = slice_metrics.copy()

  if metric_keys.ERROR_METRIC in slice_metrics:
    logging.warning('Error for slice: %s with error message: %s ', slice_key,
                    slice_metrics[metric_keys.ERROR_METRIC])
    error_entry = metrics_proto.metrics[metric_keys.ERROR_METRIC]
    error_entry.debug_message = slice_metrics[metric_keys.ERROR_METRIC]
    return metrics_proto

  # Convert the metrics from add_metrics_callbacks to the structured output if
  # defined. Callbacks are consulted only when no metric is keyed by a
  # MetricKey.
  if add_metrics_callbacks and all(
      not isinstance(k, metric_types.MetricKey) for k in slice_metrics):
    for callback in add_metrics_callbacks:
      if hasattr(callback, 'populate_stats_and_pop'):
        callback.populate_stats_and_pop(slice_key, slice_metrics,
                                        metrics_proto.metrics)

  for key in sorted(slice_metrics):
    value = slice_metrics[key]
    confidence_interval = None
    if not isinstance(value, types.ValueWithTDistribution):
      metric_value = convert_metric_value_to_proto(value)
    else:
      unsampled_value = value.unsampled_value
      _, lower_bound, upper_bound = (
          math_util.calculate_confidence_interval(value))
      lower_proto = convert_metric_value_to_proto(lower_bound)
      upper_proto = convert_metric_value_to_proto(upper_bound)
      standard_error_proto = convert_metric_value_to_proto(
          value.sample_standard_deviation)
      confidence_interval = metrics_for_slice_pb2.ConfidenceInterval(
          lower_bound=lower_proto,
          upper_bound=upper_proto,
          standard_error=standard_error_proto,
          degrees_of_freedom={'value': value.sample_degrees_of_freedom})
      metric_value = convert_metric_value_to_proto(unsampled_value)
      # If metric can be stored to double_value metrics, replace it with a
      # bounded_value for backwards compatibility.
      # TODO(b/188575688): remove this logic to stop populating bounded_value
      if metric_value.WhichOneof('type') == 'double_value':
        # setting bounded_value clears double_value in the same oneof scope.
        bounded = metric_value.bounded_value
        bounded.value.value = unsampled_value
        bounded.lower_bound.value = lower_bound
        bounded.upper_bound.value = upper_bound
        bounded.methodology = (
            metrics_for_slice_pb2.BoundedValue.POISSON_BOOTSTRAP)

    # MetricKey-keyed metrics go to the repeated key/value field; any other
    # key indexes the `metrics` map directly.
    if isinstance(key, metric_types.MetricKey):
      metrics_proto.metric_keys_and_values.add(
          key=key.to_proto(),
          value=metric_value,
          confidence_interval=confidence_interval)
    else:
      metrics_proto.metrics[key].CopyFrom(metric_value)

  return metrics_proto
def validate_metrics(
    sliced_metrics: Tuple[slicer.SliceKeyType, Dict['metric_types.MetricKey',
                                                    Any]],
    eval_config: config.EvalConfig
) -> validation_result_pb2.ValidationResult:
  """Validates the metrics of one slice against the configured thresholds.

  Thresholds are derived from `eval_config.metrics_specs`. Each threshold is
  paired with an optional slicing spec and is only enforced when that spec
  applies to the given slice key. Metrics belonging to the baseline model are
  never checked. A metric that cannot be converted to float fails every
  threshold that applies to the slice, with an explanatory message and no
  metric value. Thresholds whose metric key is absent from the slice metrics
  are reported as failures with the message 'Metric not found.'.

  Args:
    sliced_metrics: A (slice key, metrics) pair, where metrics maps metric
      keys to metric values.
    eval_config: The eval config supplying the model specs and metrics specs.

  Returns:
    A ValidationResult. `validation_ok` is True unless at least one failure
    was recorded, in which case the failures for this slice are appended to
    `metric_validations_per_slice`.

  Raises:
    ValueError: If a change threshold has an unknown direction.
  """
  # Find out which model is baseline.
  baseline_spec = model_util.get_baseline_model_spec(eval_config)
  baseline_model_name = baseline_spec.name if baseline_spec else None

  sliced_key, metrics = sliced_metrics
  thresholds = metric_specs.metric_thresholds_from_metrics_specs(
      eval_config.metrics_specs)  # pytype: disable=wrong-arg-types

  def _applies_to_slice(slicing_spec: Optional[config.SlicingSpec]) -> bool:
    """Returns whether a threshold with this slicing spec covers the slice."""
    return (slicing_spec is None or slicer.SingleSliceSpec(
        spec=slicing_spec).is_slice_applicable(sliced_key))

  def _check_threshold(key: metric_types.MetricKey,
                       slicing_spec: Optional[config.SlicingSpec],
                       threshold: _ThresholdType, metric: Any) -> bool:
    """Verify a metric given its metric key and metric value."""
    # Thresholds for other slices trivially pass.
    if not _applies_to_slice(slicing_spec):
      return True
    if isinstance(threshold, config.GenericValueThreshold):
      # Unset bounds default to +/- infinity, i.e. they never fail.
      lower_bound, upper_bound = -np.inf, np.inf
      if threshold.HasField('lower_bound'):
        lower_bound = threshold.lower_bound.value
      if threshold.HasField('upper_bound'):
        upper_bound = threshold.upper_bound.value
      return metric > lower_bound and metric < upper_bound
    elif isinstance(threshold, config.GenericChangeThreshold):
      # `metric` is the diff; the ratio is taken against the baseline metric.
      diff = metric
      ratio = diff / metrics[key.make_baseline_key(baseline_model_name)]
      # Defaults are chosen so that an unset absolute/relative bound always
      # passes for the configured direction.
      if threshold.direction == config.MetricDirection.LOWER_IS_BETTER:
        absolute, relative = np.inf, np.inf
      elif threshold.direction == config.MetricDirection.HIGHER_IS_BETTER:
        absolute, relative = -np.inf, -np.inf
      else:
        raise ValueError('"UNKNOWN" direction for change threshold.')
      if threshold.HasField('absolute'):
        absolute = threshold.absolute.value
      if threshold.HasField('relative'):
        relative = threshold.relative.value
      if threshold.direction == config.MetricDirection.LOWER_IS_BETTER:
        return diff < absolute and ratio < relative
      elif threshold.direction == config.MetricDirection.HIGHER_IS_BETTER:
        return diff > absolute and ratio > relative
    # Unsupported threshold types are treated as a failed check.
    return False

  def _copy_metric(metric, to):
    """Stores the metric as a double value on the given proto."""
    # Will add more types when more MetricValue are supported.
    to.double_value.value = float(metric)

  def _copy_threshold(threshold, to):
    """Copies the threshold into the matching field of the given proto."""
    if isinstance(threshold, config.GenericValueThreshold):
      to.value_threshold.CopyFrom(threshold)
    if isinstance(threshold, config.GenericChangeThreshold):
      to.change_threshold.CopyFrom(threshold)

  def _add_to_set(s, v):
    """Adds value to set. Returns true if didn't exist."""
    if v in s:
      return False
    else:
      s.add(v)
      return True

  # Empty metrics per slice is considered validated.
  result = validation_result_pb2.ValidationResult(validation_ok=True)
  validation_for_slice = validation_result_pb2.MetricsValidationForSlice()
  # Entries are removed as their metrics are seen; what remains afterwards are
  # thresholds whose metric is missing from this slice.
  unchecked_thresholds = dict(thresholds)
  for metric_key, metric in metrics.items():
    if metric_key not in thresholds:
      continue
    del unchecked_thresholds[metric_key]
    # Not meaningful to check threshold for baseline model, thus always return
    # True if such threshold is configured. We also do not compare Message type
    # metrics.
    if metric_key.model_name == baseline_model_name:
      continue
    # We try to convert to float values.
    try:
      metric = float(metric)
    except (TypeError, ValueError):
      # The raw value can be neither compared nor copied into the result (both
      # would raise again), so it is reported through the message only.
      comparable = False
      msg = """ Invalid threshold config: This metric is not comparable to the threshold. The type of the metric is: {}, and the metric value is: \n{}""".format(type(metric), metric)
    else:
      comparable = True
      msg = ''
    existing_failures = set()
    for slice_spec, threshold in thresholds[metric_key]:
      if comparable:
        passed = _check_threshold(metric_key, slice_spec, threshold, metric)
      else:
        # A non-numeric metric fails every threshold that covers this slice.
        passed = not _applies_to_slice(slice_spec)
      if passed:
        continue
      # The same threshold values could be set for multiple matching slice
      # specs. Only store the first match.
      #
      # Note that hashing by SerializeToString() is only safe if used within
      # the same process.
      if not _add_to_set(existing_failures, threshold.SerializeToString()):
        continue
      failure = validation_for_slice.failures.add()
      failure.metric_key.CopyFrom(metric_key.to_proto())
      if comparable:
        _copy_metric(metric, failure.metric_value)
      _copy_threshold(threshold, failure.metric_threshold)
      failure.message = msg
  # All unchecked thresholds are considered failures.
  for metric_key, missing_thresholds in unchecked_thresholds.items():
    if metric_key.model_name == baseline_model_name:
      continue
    existing_failures = set()
    for _, threshold in missing_thresholds:
      # The same threshold values could be set for multiple matching slice
      # specs. Only store the first match.
      #
      # Note that hashing by SerializeToString() is only safe if used within
      # the same process.
      if not _add_to_set(existing_failures, threshold.SerializeToString()):
        continue
      failure = validation_for_slice.failures.add()
      failure.metric_key.CopyFrom(metric_key.to_proto())
      _copy_threshold(threshold, failure.metric_threshold)
      failure.message = 'Metric not found.'
  # Any failure leads to overall failure.
  if validation_for_slice.failures:
    validation_for_slice.slice_key.CopyFrom(
        slicer.serialize_slice_key(sliced_key))
    result.validation_ok = False
    result.metric_validations_per_slice.append(validation_for_slice)
  return result