Example #1
def _serialize_plots(
        plots: Tuple[slicer.SliceKeyType, Dict[Any, Any]],
        add_metrics_callbacks: List[types.AddMetricsCallbackType]) -> bytes:
    """Converts the given slice plots into serialized proto PlotsForSlice..

  Args:
    plots: The slice plots.
    add_metrics_callbacks: A list of metric callbacks. This should be the same
      list as the one passed to tfma.Evaluate().

  Returns:
    The serialized proto PlotsForSlice.
  """
    result = metrics_for_slice_pb2.PlotsForSlice()
    slice_key, slice_plots = plots

    if metric_keys.ERROR_METRIC in slice_plots:
        tf.compat.v1.logging.warning(
            'Error for slice: %s with error message: %s ', slice_key,
            slice_plots[metric_keys.ERROR_METRIC])
        metrics = metrics_for_slice_pb2.PlotsForSlice()
        metrics.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))
        metrics.plots[metric_keys.ERROR_METRIC].debug_message = slice_plots[
            metric_keys.ERROR_METRIC]
        return metrics.SerializeToString()

    # Convert the slice key.
    result.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))

    # Convert the slice plots.
    _convert_slice_plots(slice_plots, add_metrics_callbacks, result)  # pytype: disable=wrong-arg-types

    return result.SerializeToString()
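
A minimal round-trip sketch (hypothetical slice key and error message; assumes the same TFMA modules as the snippet above): the returned bytes parse back into a PlotsForSlice proto.

# Hypothetical usage sketch, not part of the library.
slice_key = (('fruit', 'apple'),)  # a SliceKeyType: tuple of (column, value) pairs
serialized = _serialize_plots((slice_key, {metric_keys.ERROR_METRIC: 'oops'}),
                              add_metrics_callbacks=[])
plots = metrics_for_slice_pb2.PlotsForSlice.FromString(serialized)
assert plots.plots[metric_keys.ERROR_METRIC].debug_message == 'oops'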
Example #2
 def testValidateMetricsMetricTDistributionValueAndThreshold(
         self, slicing_specs, slice_key):
     threshold = config.MetricThreshold(
         value_threshold=config.GenericValueThreshold(
             lower_bound={'value': 0.9}))
     eval_config = config.EvalConfig(
         model_specs=[
             config.ModelSpec(),
         ],
         slicing_specs=slicing_specs,
         metrics_specs=[
             config.MetricsSpec(metrics=[
                 config.MetricConfig(
                     class_name='AUC',
                     threshold=threshold if slicing_specs is None else None,
                     per_slice_thresholds=[
                         config.PerSliceMetricThreshold(
                             slicing_specs=slicing_specs,
                             threshold=threshold)
                     ]),
             ],
                                model_names=['']),
         ],
     )
     sliced_metrics = (slice_key, {
         metric_types.MetricKey(name='auc'):
         types.ValueWithTDistribution(sample_mean=0.91, unsampled_value=0.8)
     })
     result = metrics_validator.validate_metrics(sliced_metrics,
                                                 eval_config)
     self.assertFalse(result.validation_ok)
     expected = text_format.Parse(
         """
     metric_validations_per_slice {
       failures {
         metric_key {
           name: "auc"
         }
         metric_value {
           double_value {
             value: 0.8
           }
         }
       }
     }""", validation_result_pb2.ValidationResult())
     expected.metric_validations_per_slice[0].failures[
         0].metric_threshold.CopyFrom(threshold)
     expected.metric_validations_per_slice[0].slice_key.CopyFrom(
         slicer.serialize_slice_key(slice_key))
     for spec in slicing_specs or [None]:
         if (spec is None or slicer.SingleSliceSpec(
                 spec=spec).is_slice_applicable(slice_key)):
              slicing_details = expected.validation_details.slicing_details.add()
             if spec is not None:
                 slicing_details.slicing_spec.CopyFrom(spec)
             else:
                 slicing_details.slicing_spec.CopyFrom(config.SlicingSpec())
             slicing_details.num_matching_slices = 1
     self.assertEqual(result, expected)
Example #3
  def testConvertSlicePlotsToProtoEmptyPlot(self):
    slice_key = _make_slice_key('fruit', 'apple')
    tfma_plots = {metric_keys.ERROR_METRIC: 'error_message'}

    actual_plot = metrics_plots_and_validations_writer.convert_slice_plots_to_proto(
        (slice_key, tfma_plots), [])
    expected_plot = metrics_for_slice_pb2.PlotsForSlice()
    expected_plot.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))
    expected_plot.plots[
        metric_keys.ERROR_METRIC].debug_message = 'error_message'
    self.assertProtoEquals(expected_plot, actual_plot)
Example #4
  def testSerializePlots_emptyPlot(self):
    slice_key = _make_slice_key('fruit', 'apple')
    tfma_plots = {metric_keys.ERROR_METRIC: 'error_message'}

    actual_plot = metrics_and_plots_serialization._serialize_plots(
        (slice_key, tfma_plots), [])
    expected_plot = metrics_for_slice_pb2.PlotsForSlice()
    expected_plot.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))
    expected_plot.plots[
        metric_keys.ERROR_METRIC].debug_message = 'error_message'
    self.assertProtoEquals(
        expected_plot,
        metrics_for_slice_pb2.PlotsForSlice.FromString(actual_plot))
Example #5
def convert_slice_attributions_to_proto(
    attributions: Tuple[slicer.SliceKeyOrCrossSliceKeyType,
                        Dict[Any, Dict[Text, Any]]]
) -> metrics_for_slice_pb2.AttributionsForSlice:
    """Converts the given slice attributions into serialized AtributionsForSlice.

  Args:
    attributions: The slice attributions.

  Returns:
    The AttributionsForSlice proto.

  Raises:
    TypeError: If the type of the feature value in slice key cannot be
      recognized.
  """
    result = metrics_for_slice_pb2.AttributionsForSlice()
    slice_key, slice_attributions = attributions

    if slicer.is_cross_slice_key(slice_key):
        result.cross_slice_key.CopyFrom(
            slicer.serialize_cross_slice_key(slice_key))
    else:
        result.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))

    slice_attributions = slice_attributions.copy()
    for key in sorted(slice_attributions.keys()):
        key_and_value = result.attributions_keys_and_values.add()
        key_and_value.key.CopyFrom(key.to_proto())
        for feature, value in slice_attributions[key].items():
            attribution_value = metrics_for_slice_pb2.MetricValue()
            if isinstance(value, six.binary_type):
                # Convert textual types to string metrics.
                attribution_value.bytes_value = value
            elif isinstance(value, six.text_type):
                # Convert textual types to string metrics.
                attribution_value.bytes_value = value.encode('utf8')
            elif isinstance(value, np.ndarray) and value.size != 1:
                # Convert NumPy arrays to ArrayValue.
                attribution_value.array_value.CopyFrom(
                    _convert_to_array_value(value))
            else:
                # We try to convert to float values.
                try:
                    attribution_value.double_value.value = float(value)
                except (TypeError, ValueError) as e:
                    attribution_value.unknown_type.value = str(value)
                    attribution_value.unknown_type.error = str(e)
            key_and_value.values[feature].CopyFrom(attribution_value)

    return result
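
A rough usage sketch (hypothetical attribution key name and feature values; per the branches above, values may be bytes, text, arrays, or anything float()-convertible):

# Hypothetical usage sketch, not part of the library.
slice_key = (('fruit', 'apple'),)
key = metric_types.MetricKey(name='total_attributions')  # illustrative name
proto = convert_slice_attributions_to_proto(
    (slice_key, {key: {'age': 1.5, 'language': 0.25}}))
assert proto.attributions_keys_and_values[0].values['age'].double_value.value == 1.5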
Example #6
def _serialize_metrics(
        metrics: Tuple[slicer.SliceKeyType, Dict[Any, Any]],
        add_metrics_callbacks: List[types.AddMetricsCallbackType]) -> bytes:
    """Converts the given slice metrics into serialized proto MetricsForSlice.

  Args:
    metrics: The slice metrics.
    add_metrics_callbacks: A list of metric callbacks. This should be the same
      list as the one passed to tfma.Evaluate().

  Returns:
    The serialized proto MetricsForSlice.

  Raises:
    TypeError: If the type of the feature value in slice key cannot be
      recognized.
  """
    result = metrics_for_slice_pb2.MetricsForSlice()
    slice_key, slice_metrics = metrics

    if metric_keys.ERROR_METRIC in slice_metrics:
        tf.compat.v1.logging.warning(
            'Error for slice: %s with error message: %s ', slice_key,
            slice_metrics[metric_keys.ERROR_METRIC])
        metrics = metrics_for_slice_pb2.MetricsForSlice()
        metrics.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))
        metrics.metrics[
            metric_keys.ERROR_METRIC].debug_message = slice_metrics[
                metric_keys.ERROR_METRIC]
        return metrics.SerializeToString()

    # Convert the slice key.
    result.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))

    # Convert the slice metrics.
    convert_slice_metrics(slice_key, slice_metrics, add_metrics_callbacks,
                          result)
    return result.SerializeToString()
Example #7
 def testValidateMetricsMetricValueAndThreshold(self, slicing_specs,
                                                slice_key):
   threshold = config.MetricThreshold(
       value_threshold=config.GenericValueThreshold(upper_bound={'value': 1}))
   eval_config = config.EvalConfig(
       model_specs=[
           config.ModelSpec(),
       ],
       slicing_specs=slicing_specs,
       metrics_specs=[
           config.MetricsSpec(
               metrics=[
                   config.MetricConfig(
                       class_name='WeightedExampleCount',
                       # 1.5 < 1, NOT OK.
                       threshold=threshold if slicing_specs is None else None,
                       per_slice_thresholds=[
                           config.PerSliceMetricThreshold(
                               slicing_specs=slicing_specs,
                               threshold=threshold)
                       ]),
               ],
               model_names=['']),
       ],
   )
   sliced_metrics = (slice_key, {
       metric_types.MetricKey(name='weighted_example_count'): 1.5,
   })
   result = metrics_validator.validate_metrics(sliced_metrics, eval_config)
   self.assertFalse(result.validation_ok)
   expected = text_format.Parse(
       """
       metric_validations_per_slice {
         failures {
           metric_key {
             name: "weighted_example_count"
           }
           metric_value {
             double_value {
               value: 1.5
             }
           }
         }
       }""", validation_result_pb2.ValidationResult())
   expected.metric_validations_per_slice[0].failures[
       0].metric_threshold.CopyFrom(threshold)
   expected.metric_validations_per_slice[0].slice_key.CopyFrom(
       slicer.serialize_slice_key(slice_key))
   self.assertEqual(result, expected)
Example #8
  def testConvertSliceMetricsToProtoEmptyMetrics(self):
    slice_key = _make_slice_key('age', 5, 'language', 'english', 'price', 0.3)
    slice_metrics = {metric_keys.ERROR_METRIC: 'error_message'}

    actual_metrics = (
        metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
            (slice_key, slice_metrics),
            [post_export_metrics.auc(),
             post_export_metrics.auc(curve='PR')]))

    expected_metrics = metrics_for_slice_pb2.MetricsForSlice()
    expected_metrics.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))
    expected_metrics.metrics[
        metric_keys.ERROR_METRIC].debug_message = 'error_message'
    self.assertProtoEquals(expected_metrics, actual_metrics)
Example #9
  def testSerializeMetrics_emptyMetrics(self):
    slice_key = _make_slice_key('age', 5, 'language', 'english', 'price', 0.3)
    slice_metrics = {metric_keys.ERROR_METRIC: 'error_message'}

    actual_metrics = metrics_and_plots_serialization._serialize_metrics(
        (slice_key, slice_metrics),
        [post_export_metrics.auc(),
         post_export_metrics.auc(curve='PR')])

    expected_metrics = metrics_for_slice_pb2.MetricsForSlice()
    expected_metrics.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))
    expected_metrics.metrics[
        metric_keys.ERROR_METRIC].debug_message = 'error_message'
    self.assertProtoEquals(
        expected_metrics,
        metrics_for_slice_pb2.MetricsForSlice.FromString(actual_metrics))
Example #10
def convert_slice_plots_to_proto(
    plots: Tuple[slicer.SliceKeyType, Dict[Any, Any]],
    add_metrics_callbacks: List[types.AddMetricsCallbackType]
) -> metrics_for_slice_pb2.PlotsForSlice:
    """Converts the given slice plots into PlotsForSlice proto.

  Args:
    plots: The slice plots.
    add_metrics_callbacks: A list of metric callbacks. This should be the same
      list as the one passed to tfma.Evaluate().

  Returns:
    The PlotsForSlice proto.
  """
    result = metrics_for_slice_pb2.PlotsForSlice()
    slice_key, slice_plots = plots

    result.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))

    slice_plots = slice_plots.copy()

    if metric_keys.ERROR_METRIC in slice_plots:
        logging.warning('Error for slice: %s with error message: %s ',
                        slice_key, slice_plots[metric_keys.ERROR_METRIC])
        error_metric = slice_plots.pop(metric_keys.ERROR_METRIC)
        result.plots[metric_keys.ERROR_METRIC].debug_message = error_metric
        return result

    if add_metrics_callbacks and (not any(
            isinstance(k, metric_types.MetricKey)
            for k in slice_plots.keys())):
        for add_metrics_callback in add_metrics_callbacks:
            if hasattr(add_metrics_callback, 'populate_plots_and_pop'):
                add_metrics_callback.populate_plots_and_pop(
                    slice_plots, result.plots)
    plots_by_key = {}
    for key in sorted(slice_plots.keys()):
        value = slice_plots[key]
        # Remove plot name from key (multiple plots are combined into a single
        # proto).
        if isinstance(key, metric_types.MetricKey):
            parent_key = key._replace(name=None)
        else:
            continue
        if parent_key not in plots_by_key:
            key_and_value = result.plot_keys_and_values.add()
            key_and_value.key.CopyFrom(parent_key.to_proto())
            plots_by_key[parent_key] = key_and_value.value

        if isinstance(value,
                      metrics_for_slice_pb2.CalibrationHistogramBuckets):
            plots_by_key[parent_key].calibration_histogram_buckets.CopyFrom(
                value)
            slice_plots.pop(key)
        elif isinstance(value,
                        metrics_for_slice_pb2.ConfusionMatrixAtThresholds):
            plots_by_key[parent_key].confusion_matrix_at_thresholds.CopyFrom(
                value)
            slice_plots.pop(key)
        elif isinstance(
                value,
                metrics_for_slice_pb2.MultiClassConfusionMatrixAtThresholds):
            plots_by_key[
                parent_key].multi_class_confusion_matrix_at_thresholds.CopyFrom(
                    value)
            slice_plots.pop(key)
        elif isinstance(
                value,
                metrics_for_slice_pb2.MultiLabelConfusionMatrixAtThresholds):
            plots_by_key[
                parent_key].multi_label_confusion_matrix_at_thresholds.CopyFrom(
                    value)
            slice_plots.pop(key)

    if slice_plots:
        if add_metrics_callbacks is None:
            add_metrics_callbacks = []
        raise NotImplementedError(
            'some plots were not converted or popped. keys: %s. '
            'add_metrics_callbacks were: %s' % (
                slice_plots.keys(),
                [
                    x.name for x in add_metrics_callbacks  # pytype: disable=attribute-error
                ]))

    return result
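
Because the loop strips the plot name from each key, sub-plots that share a parent key merge into a single plot_keys_and_values entry. A hedged sketch, assuming a CalibrationHistogramBuckets plot value and an illustrative key name:

# Hypothetical usage sketch, not part of the library.
key = metric_types.MetricKey(name='calibration_histogram')  # illustrative name
plots = {key: metrics_for_slice_pb2.CalibrationHistogramBuckets()}
proto = convert_slice_plots_to_proto(((('fruit', 'apple'),), plots), [])
# The plot is stored under key._replace(name=None), serialized via to_proto().
assert len(proto.plot_keys_and_values) == 1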
Example #11
def convert_slice_metrics_to_proto(
    metrics: Tuple[slicer.SliceKeyType, Dict[Any, Any]],
    add_metrics_callbacks: List[types.AddMetricsCallbackType]
) -> metrics_for_slice_pb2.MetricsForSlice:
    """Converts the given slice metrics into serialized proto MetricsForSlice.

  Args:
    metrics: The slice metrics.
    add_metrics_callbacks: A list of metric callbacks. This should be the same
      list as the one passed to tfma.Evaluate().

  Returns:
    The MetricsForSlice proto.

  Raises:
    TypeError: If the type of the feature value in slice key cannot be
      recognized.
  """
    result = metrics_for_slice_pb2.MetricsForSlice()
    slice_key, slice_metrics = metrics

    result.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))

    slice_metrics = slice_metrics.copy()

    if metric_keys.ERROR_METRIC in slice_metrics:
        logging.warning('Error for slice: %s with error message: %s ',
                        slice_key, slice_metrics[metric_keys.ERROR_METRIC])
        result.metrics[metric_keys.ERROR_METRIC].debug_message = slice_metrics[
            metric_keys.ERROR_METRIC]
        return result

    # Convert the metrics from add_metrics_callbacks to the structured output if
    # defined.
    if add_metrics_callbacks and (not any(
            isinstance(k, metric_types.MetricKey)
            for k in slice_metrics.keys())):
        for add_metrics_callback in add_metrics_callbacks:
            if hasattr(add_metrics_callback, 'populate_stats_and_pop'):
                add_metrics_callback.populate_stats_and_pop(
                    slice_key, slice_metrics, result.metrics)
    for key in sorted(slice_metrics.keys()):
        value = slice_metrics[key]
        metric_value = metrics_for_slice_pb2.MetricValue()
        if isinstance(value,
                      metrics_for_slice_pb2.ConfusionMatrixAtThresholds):
            metric_value.confusion_matrix_at_thresholds.CopyFrom(value)
        elif isinstance(
                value,
                metrics_for_slice_pb2.MultiClassConfusionMatrixAtThresholds):
            metric_value.multi_class_confusion_matrix_at_thresholds.CopyFrom(
                value)
        elif isinstance(value, types.ValueWithTDistribution):
            # Currently we populate both bounded_value and confidence_interval.
            # Avoid populating bounded_value once the UI handles confidence_interval.
            # Convert to a bounded value. 95% confidence level is computed here.
            _, lower_bound, upper_bound = (
                math_util.calculate_confidence_interval(value))
            metric_value.bounded_value.value.value = value.unsampled_value
            metric_value.bounded_value.lower_bound.value = lower_bound
            metric_value.bounded_value.upper_bound.value = upper_bound
            metric_value.bounded_value.methodology = (
                metrics_for_slice_pb2.BoundedValue.POISSON_BOOTSTRAP)
            # Populate confidence_interval
            metric_value.confidence_interval.lower_bound.value = lower_bound
            metric_value.confidence_interval.upper_bound.value = upper_bound
            t_dist_value = metrics_for_slice_pb2.TDistributionValue()
            t_dist_value.sample_mean.value = value.sample_mean
            t_dist_value.sample_standard_deviation.value = (
                value.sample_standard_deviation)
            t_dist_value.sample_degrees_of_freedom.value = (
                value.sample_degrees_of_freedom)
            # Once the UI handles confidence interval, we will avoid setting this and
            # instead use the double_value.
            t_dist_value.unsampled_value.value = value.unsampled_value
            metric_value.confidence_interval.t_distribution_value.CopyFrom(
                t_dist_value)
        elif isinstance(value, six.binary_type):
            # Convert textual types to string metrics.
            metric_value.bytes_value = value
        elif isinstance(value, six.text_type):
            # Convert textual types to string metrics.
            metric_value.bytes_value = value.encode('utf8')
        elif isinstance(value, np.ndarray):
            # Convert NumPy arrays to ArrayValue.
            metric_value.array_value.CopyFrom(_convert_to_array_value(value))
        else:
            # We try to convert to float values.
            try:
                metric_value.double_value.value = float(value)
            except (TypeError, ValueError) as e:
                metric_value.unknown_type.value = str(value)
                metric_value.unknown_type.error = str(e)

        if isinstance(key, metric_types.MetricKey):
            key_and_value = result.metric_keys_and_values.add()
            key_and_value.key.CopyFrom(key.to_proto())
            key_and_value.value.CopyFrom(metric_value)
        else:
            result.metrics[key].CopyFrom(metric_value)

    return result
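
For the ValueWithTDistribution branch, a minimal sketch (illustrative sample statistics) of what the resulting proto ends up holding:

# Hypothetical usage sketch, not part of the library.
metrics = {
    metric_types.MetricKey(name='auc'):
        types.ValueWithTDistribution(
            sample_mean=0.79,
            sample_standard_deviation=0.02,
            sample_degrees_of_freedom=19,
            unsampled_value=0.8),
}
proto = convert_slice_metrics_to_proto(((('fruit', 'apple'),), metrics), [])
value = proto.metric_keys_and_values[0].value
# bounded_value carries the unsampled value plus the 95% interval computed above.
assert value.bounded_value.value.value == 0.8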
Example #12
 def testValidateMetricsMetricTDistributionChangeAndThreshold(
         self, slicing_specs, slice_key):
     threshold = config.MetricThreshold(
         change_threshold=config.GenericChangeThreshold(
             direction=config.MetricDirection.LOWER_IS_BETTER,
             absolute={'value': -1}))
     eval_config = config.EvalConfig(
         model_specs=[
             config.ModelSpec(),
             config.ModelSpec(name='baseline', is_baseline=True)
         ],
         slicing_specs=slicing_specs,
         metrics_specs=[
             config.MetricsSpec(metrics=[
                 config.MetricConfig(
                     class_name='AUC',
                     threshold=threshold if slicing_specs is None else None,
                     per_slice_thresholds=[
                         config.PerSliceMetricThreshold(
                             slicing_specs=slicing_specs,
                             threshold=threshold)
                     ]),
             ],
                                model_names=['']),
         ],
     )
     sliced_metrics = (
         slice_key,
         {
             # This is the mean of the diff.
             metric_types.MetricKey(name='auc', model_name='baseline'):
             types.ValueWithTDistribution(sample_mean=0.91,
                                          unsampled_value=0.6),
             metric_types.MetricKey(name='auc', is_diff=True):
             types.ValueWithTDistribution(sample_mean=0.1,
                                          unsampled_value=0.1),
         })
     result = metrics_validator.validate_metrics(sliced_metrics,
                                                 eval_config)
     self.assertFalse(result.validation_ok)
     expected = text_format.Parse(
         """
     metric_validations_per_slice {
       failures {
         metric_key {
           name: "auc"
           is_diff: true
         }
         metric_value {
           double_value {
             value: 0.1
           }
         }
       }
     }""", validation_result_pb2.ValidationResult())
     expected.metric_validations_per_slice[0].failures[
         0].metric_threshold.CopyFrom(threshold)
     expected.metric_validations_per_slice[0].slice_key.CopyFrom(
         slicer.serialize_slice_key(slice_key))
     for spec in slicing_specs or [None]:
         if (spec is None or slicer.SingleSliceSpec(
                 spec=spec).is_slice_applicable(slice_key)):
              slicing_details = expected.validation_details.slicing_details.add()
             if spec is not None:
                 slicing_details.slicing_spec.CopyFrom(spec)
             else:
                 slicing_details.slicing_spec.CopyFrom(config.SlicingSpec())
             slicing_details.num_matching_slices = 1
      self.assertEqual(result, expected)
Example #13
def validate_metrics(
    sliced_metrics: Tuple[Union[slicer.SliceKeyType, slicer.CrossSliceKeyType],
                          Dict['metric_types.MetricKey',
                               Any]], eval_config: config_pb2.EvalConfig
) -> validation_result_pb2.ValidationResult:
    """Check the metrics and check whether they should be validated."""
    # Find out which model is baseline.
    baseline_spec = model_util.get_baseline_model_spec(eval_config)
    baseline_model_name = baseline_spec.name if baseline_spec else None

    sliced_key, metrics = sliced_metrics
    thresholds = metric_specs.metric_thresholds_from_metrics_specs(
        eval_config.metrics_specs)
    is_cross_slice = slicer.is_cross_slice_key(sliced_key)

    def _check_threshold(key: metric_types.MetricKey,
                         threshold: _ThresholdType, metric: Any) -> bool:
        """Verify a metric given its metric key and metric value."""
        metric = float(metric)
        if isinstance(threshold, config_pb2.GenericValueThreshold):
            lower_bound, upper_bound = -np.inf, np.inf
            if threshold.HasField('lower_bound'):
                lower_bound = threshold.lower_bound.value
            if threshold.HasField('upper_bound'):
                upper_bound = threshold.upper_bound.value
            return metric >= lower_bound and metric <= upper_bound
        elif isinstance(threshold, config_pb2.GenericChangeThreshold):
            diff = metric
            metric_baseline = float(
                metrics[key.make_baseline_key(baseline_model_name)])
            if math.isclose(metric_baseline, 0.0):
                ratio = float('nan')
            else:
                ratio = diff / metric_baseline
            if threshold.direction == config_pb2.MetricDirection.LOWER_IS_BETTER:
                absolute, relative = np.inf, np.inf
            elif threshold.direction == config_pb2.MetricDirection.HIGHER_IS_BETTER:
                absolute, relative = -np.inf, -np.inf
            else:
                raise ValueError(
                    '"UNKNOWN" direction for change threshold: {}.'.format(
                        threshold))
            if threshold.HasField('absolute'):
                absolute = threshold.absolute.value
            if threshold.HasField('relative'):
                relative = threshold.relative.value
            if threshold.direction == config_pb2.MetricDirection.LOWER_IS_BETTER:
                return diff <= absolute and ratio <= relative
            elif threshold.direction == config_pb2.MetricDirection.HIGHER_IS_BETTER:
                return diff >= absolute and ratio >= relative
        else:
            raise ValueError('Unknown threshold: {}'.format(threshold))

    def _copy_metric(metric, to):
        # Will add more types when more MetricValue are supported.
        to.double_value.value = float(metric)

    def _copy_threshold(threshold, to):
        if isinstance(threshold, config_pb2.GenericValueThreshold):
            to.value_threshold.CopyFrom(threshold)
        if isinstance(threshold, config_pb2.GenericChangeThreshold):
            to.change_threshold.CopyFrom(threshold)

    def _add_to_set(s, v):
        """Adds value to set. Returns true if didn't exist."""
        if v in s:
            return False
        else:
            s.add(v)
            return True

    # Empty metrics per slice is considered validated.
    result = validation_result_pb2.ValidationResult(validation_ok=True)
    validation_for_slice = validation_result_pb2.MetricsValidationForSlice()
    unchecked_thresholds = dict(thresholds)
    for metric_key, metric in metrics.items():
        if metric_key not in thresholds:
            continue
        del unchecked_thresholds[metric_key]
        # Not meaningful to check threshold for baseline model, thus always return
        # True if such threshold is configured. We also do not compare Message type
        # metrics.
        if metric_key.model_name == baseline_model_name:
            continue
        msg = ''
        existing_failures = set()
        for slice_spec, threshold in thresholds[metric_key]:
            if slice_spec is not None:
                if (isinstance(slice_spec, config_pb2.SlicingSpec)
                        and (is_cross_slice or not slicer.SingleSliceSpec(
                            spec=slice_spec).is_slice_applicable(sliced_key))):
                    continue
                if (isinstance(slice_spec, config_pb2.CrossSlicingSpec)
                        and (not is_cross_slice
                             or not slicer.is_cross_slice_applicable(
                                 cross_slice_key=sliced_key,
                                 cross_slicing_spec=slice_spec))):
                    continue
            elif is_cross_slice:
                continue
            try:
                check_result = _check_threshold(metric_key, threshold, metric)
            except ValueError:
                msg = """
          Invalid metrics or threshold for comparison: The type of the metric
          is: {}, the metric value is: {}, and the threshold is: {}.
          """.format(type(metric), metric, threshold)
                check_result = False
            else:
                msg = ''
            if not check_result:
                # The same threshold values could be set for multiple matching slice
                # specs. Only store the first match.
                #
                # Note that hashing by SerializeToString() is only safe if used within
                # the same process.
                if not _add_to_set(existing_failures,
                                   threshold.SerializeToString()):
                    continue
                failure = validation_for_slice.failures.add()
                failure.metric_key.CopyFrom(metric_key.to_proto())
                _copy_metric(metric, failure.metric_value)
                _copy_threshold(threshold, failure.metric_threshold)
                failure.message = msg
            # Track we have completed a validation check for slice spec and metric
            slicing_details = result.validation_details.slicing_details.add()
            if slice_spec is not None:
                if isinstance(slice_spec, config_pb2.SlicingSpec):
                    slicing_details.slicing_spec.CopyFrom(slice_spec)
                else:
                    slicing_details.cross_slicing_spec.CopyFrom(slice_spec)
            else:
                slicing_details.slicing_spec.CopyFrom(config_pb2.SlicingSpec())
            slicing_details.num_matching_slices = 1
    # All unchecked thresholds are considered failures.
    for metric_key, thresholds in unchecked_thresholds.items():
        if metric_key.model_name == baseline_model_name:
            continue
        existing_failures = set()
        for slice_spec, threshold in thresholds:
            if slice_spec is not None:
                if is_cross_slice != isinstance(slice_spec,
                                                config_pb2.CrossSlicingSpec):
                    continue
                if (is_cross_slice
                        and not slicer.is_cross_slice_applicable(
                            cross_slice_key=sliced_key,
                            cross_slicing_spec=slice_spec)):
                    continue
            elif is_cross_slice:
                continue
            # The same threshold values could be set for multiple matching slice
            # specs. Only store the first match.
            #
            # Note that hashing by SerializeToString() is only safe if used within
            # the same process.
            if not _add_to_set(existing_failures,
                               threshold.SerializeToString()):
                continue
            failure = validation_for_slice.failures.add()
            failure.metric_key.CopyFrom(metric_key.to_proto())
            _copy_threshold(threshold, failure.metric_threshold)
            failure.message = 'Metric not found.'
    # Any failure leads to overall failure.
    if validation_for_slice.failures:
        if not is_cross_slice:
            validation_for_slice.slice_key.CopyFrom(
                slicer.serialize_slice_key(sliced_key))
        else:
            validation_for_slice.cross_slice_key.CopyFrom(
                slicer.serialize_cross_slice_key(sliced_key))
        result.validation_ok = False
        result.metric_validations_per_slice.append(validation_for_slice)
    return result
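
A hedged end-to-end sketch of the value-threshold path, mirroring the tests above (the config field values are illustrative, not from the source):

# Hypothetical usage sketch, not part of the library.
eval_config = config_pb2.EvalConfig(metrics_specs=[
    config_pb2.MetricsSpec(
        model_names=[''],
        metrics=[
            config_pb2.MetricConfig(
                class_name='AUC',
                threshold=config_pb2.MetricThreshold(
                    value_threshold=config_pb2.GenericValueThreshold(
                        lower_bound={'value': 0.9})))
        ])
])
sliced_metrics = ((('fruit', 'apple'),), {metric_types.MetricKey(name='auc'): 0.5})
result = validate_metrics(sliced_metrics, eval_config)
assert not result.validation_ok  # 0.5 falls below the 0.9 lower bound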
Example #14
def validate_metrics(
        sliced_metrics: Tuple[slicer.SliceKeyType,
                              Dict[metric_types.MetricKey,
                                   Any]], eval_config: config.EvalConfig
) -> validation_result_pb2.ValidationResult:
    """Check the metrics and check whether they should be validated."""
    # Find out which model is baseline.
    baseline_spec = model_util.get_baseline_model_spec(eval_config)
    baseline_model_name = baseline_spec.name if baseline_spec else None

    sliced_key, metrics = sliced_metrics
    thresholds = metric_specs.metric_thresholds_from_metrics_specs(
        eval_config.metrics_specs)

    def _check_threshold(key: metric_types.MetricKey, metric: Any) -> bool:
        """Verify a metric given its metric key and metric value."""
        threshold = thresholds[key]
        if isinstance(threshold, config.GenericValueThreshold):
            lower_bound, upper_bound = -np.inf, np.inf
            if threshold.HasField('lower_bound'):
                lower_bound = threshold.lower_bound.value
            if threshold.HasField('upper_bound'):
                upper_bound = threshold.upper_bound.value
            return metric > lower_bound and metric < upper_bound
        elif isinstance(threshold, config.GenericChangeThreshold):
            diff = metric
            ratio = diff / metrics[key.make_baseline_key(baseline_model_name)]
            if threshold.direction == config.MetricDirection.LOWER_IS_BETTER:
                absolute, relative = np.inf, np.inf
            elif threshold.direction == config.MetricDirection.HIGHER_IS_BETTER:
                absolute, relative = -np.inf, -np.inf
            else:
                raise ValueError('"UNKNOWN" direction for change threshold.')
            if threshold.HasField('absolute'):
                absolute = threshold.absolute.value
            if threshold.HasField('relative'):
                relative = threshold.relative.value
            if threshold.direction == config.MetricDirection.LOWER_IS_BETTER:
                return diff < absolute and ratio < relative
            elif threshold.direction == config.MetricDirection.HIGHER_IS_BETTER:
                return diff > absolute and ratio > relative

    def _copy_metric(metric, to):
        # Will add more types when more MetricValue are supported.
        to.double_value.value = float(metric)

    def _copy_threshold(threshold, to):
        if isinstance(threshold, config.GenericValueThreshold):
            to.value_threshold.CopyFrom(threshold)
        if isinstance(threshold, config.GenericChangeThreshold):
            to.change_threshold.CopyFrom(threshold)

    # Empty metrics per slice is considered validated.
    result = validation_result_pb2.ValidationResult(validation_ok=True)
    validation_for_slice = validation_result_pb2.MetricsValidationForSlice()
    for metric_key, metric in metrics.items():
        # Not meaningful to check threshold for baseline model, thus always return
        # True if such threshold is configured. We also do not compare Message type
        # metrics.
        if (metric_key.model_name == baseline_model_name
                or metric_key not in thresholds):
            continue
        msg = ''
        # We try to convert to float values.
        try:
            metric = float(metric)
        except (TypeError, ValueError):
            msg = """
        Invalid threshold config: This metric is not comparable to the
        threshold. The type of the threshold is: {}, and the metric value is:
        \n{}""".format(type(metric), metric)
        if not _check_threshold(metric_key, metric):
            failure = validation_for_slice.failures.add()
            failure.metric_key.CopyFrom(metric_key.to_proto())
            _copy_metric(metric, failure.metric_value)
            _copy_threshold(thresholds[metric_key], failure.metric_threshold)
            failure.message = msg
    # Any failure leads to overall failure.
    if validation_for_slice.failures:
        validation_for_slice.slice_key.CopyFrom(
            slicer.serialize_slice_key(sliced_key))
        result.validation_ok = False
        result.metric_validations_per_slice.append(validation_for_slice)
    return result
Example #15
def convert_slice_metrics_to_proto(
    metrics: Tuple[slicer.SliceKeyOrCrossSliceKeyType, Dict[Any, Any]],
    add_metrics_callbacks: List[types.AddMetricsCallbackType]
) -> metrics_for_slice_pb2.MetricsForSlice:
    """Converts the given slice metrics into serialized proto MetricsForSlice.

  Args:
    metrics: The slice metrics.
    add_metrics_callbacks: A list of metric callbacks. This should be the same
      list as the one passed to tfma.Evaluate().

  Returns:
    The MetricsForSlice proto.

  Raises:
    TypeError: If the type of the feature value in slice key cannot be
      recognized.
  """
    result = metrics_for_slice_pb2.MetricsForSlice()
    slice_key, slice_metrics = metrics

    if slicer.is_cross_slice_key(slice_key):
        result.cross_slice_key.CopyFrom(
            slicer.serialize_cross_slice_key(slice_key))
    else:
        result.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))

    slice_metrics = slice_metrics.copy()

    if metric_keys.ERROR_METRIC in slice_metrics:
        logging.warning('Error for slice: %s with error message: %s ',
                        slice_key, slice_metrics[metric_keys.ERROR_METRIC])
        result.metrics[metric_keys.ERROR_METRIC].debug_message = slice_metrics[
            metric_keys.ERROR_METRIC]
        return result

    # Convert the metrics from add_metrics_callbacks to the structured output if
    # defined.
    if add_metrics_callbacks and (not any(
            isinstance(k, metric_types.MetricKey)
            for k in slice_metrics.keys())):
        for add_metrics_callback in add_metrics_callbacks:
            if hasattr(add_metrics_callback, 'populate_stats_and_pop'):
                add_metrics_callback.populate_stats_and_pop(
                    slice_key, slice_metrics, result.metrics)
    for key in sorted(slice_metrics.keys()):
        value = slice_metrics[key]
        if isinstance(value, types.ValueWithTDistribution):
            unsampled_value = value.unsampled_value
            _, lower_bound, upper_bound = (
                math_util.calculate_confidence_interval(value))
            confidence_interval = metrics_for_slice_pb2.ConfidenceInterval(
                lower_bound=convert_metric_value_to_proto(lower_bound),
                upper_bound=convert_metric_value_to_proto(upper_bound),
                standard_error=convert_metric_value_to_proto(
                    value.sample_standard_deviation),
                degrees_of_freedom={'value': value.sample_degrees_of_freedom})
            metric_value = convert_metric_value_to_proto(unsampled_value)

            # If metric can be stored to double_value metrics, replace it with a
            # bounded_value for backwards compatibility.
            # TODO(b/188575688): remove this logic to stop populating bounded_value
            if metric_value.WhichOneof('type') == 'double_value':
                # setting bounded_value clears double_value in the same oneof scope.
                metric_value.bounded_value.value.value = unsampled_value
                metric_value.bounded_value.lower_bound.value = lower_bound
                metric_value.bounded_value.upper_bound.value = upper_bound
                metric_value.bounded_value.methodology = (
                    metrics_for_slice_pb2.BoundedValue.POISSON_BOOTSTRAP)
        else:
            metric_value = convert_metric_value_to_proto(value)
            confidence_interval = None

        if isinstance(key, metric_types.MetricKey):
            result.metric_keys_and_values.add(
                key=key.to_proto(),
                value=metric_value,
                confidence_interval=confidence_interval)
        else:
            result.metrics[key].CopyFrom(metric_value)

    return result
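
This variant also accepts cross-slice keys; a hedged sketch, assuming a cross-slice key is an ordered pair of ordinary slice keys (as suggested by slicer.is_cross_slice_key and serialize_cross_slice_key above):

# Hypothetical usage sketch, not part of the library.
cross_key = ((('fruit', 'apple'),), (('fruit', 'pear'),))  # (baseline, comparison)
proto = convert_slice_metrics_to_proto(
    (cross_key, {metric_types.MetricKey(name='auc'): 0.8}), [])
assert proto.HasField('cross_slice_key')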
Example #16
def validate_metrics(
        sliced_metrics: Tuple[slicer.SliceKeyType,
                              Dict['metric_types.MetricKey',
                                   Any]], eval_config: config.EvalConfig
) -> validation_result_pb2.ValidationResult:
    """Check the metrics and check whether they should be validated."""
    # Find out which model is baseline.
    baseline_spec = model_util.get_baseline_model_spec(eval_config)
    baseline_model_name = baseline_spec.name if baseline_spec else None

    sliced_key, metrics = sliced_metrics
    thresholds = metric_specs.metric_thresholds_from_metrics_specs(
        eval_config.metrics_specs)  # pytype: disable=wrong-arg-types

    def _check_threshold(key: metric_types.MetricKey,
                         slicing_spec: Optional[config.SlicingSpec],
                         threshold: _ThresholdType, metric: Any) -> bool:
        """Verify a metric given its metric key and metric value."""
        if (slicing_spec is not None and not slicer.SingleSliceSpec(
                spec=slicing_spec).is_slice_applicable(sliced_key)):
            return True

        if isinstance(threshold, config.GenericValueThreshold):
            lower_bound, upper_bound = -np.inf, np.inf
            if threshold.HasField('lower_bound'):
                lower_bound = threshold.lower_bound.value
            if threshold.HasField('upper_bound'):
                upper_bound = threshold.upper_bound.value
            return metric > lower_bound and metric < upper_bound
        elif isinstance(threshold, config.GenericChangeThreshold):
            diff = metric
            ratio = diff / metrics[key.make_baseline_key(baseline_model_name)]
            if threshold.direction == config.MetricDirection.LOWER_IS_BETTER:
                absolute, relative = np.inf, np.inf
            elif threshold.direction == config.MetricDirection.HIGHER_IS_BETTER:
                absolute, relative = -np.inf, -np.inf
            else:
                raise ValueError('"UNKNOWN" direction for change threshold.')
            if threshold.HasField('absolute'):
                absolute = threshold.absolute.value
            if threshold.HasField('relative'):
                relative = threshold.relative.value
            if threshold.direction == config.MetricDirection.LOWER_IS_BETTER:
                return diff < absolute and ratio < relative
            elif threshold.direction == config.MetricDirection.HIGHER_IS_BETTER:
                return diff > absolute and ratio > relative

    def _copy_metric(metric, to):
        # Will add more types when more MetricValue are supported.
        to.double_value.value = float(metric)

    def _copy_threshold(threshold, to):
        if isinstance(threshold, config.GenericValueThreshold):
            to.value_threshold.CopyFrom(threshold)
        if isinstance(threshold, config.GenericChangeThreshold):
            to.change_threshold.CopyFrom(threshold)

    def _add_to_set(s, v):
        """Adds value to set. Returns true if didn't exist."""
        if v in s:
            return False
        else:
            s.add(v)
            return True

    # Empty metrics per slice is considered validated.
    result = validation_result_pb2.ValidationResult(validation_ok=True)
    validation_for_slice = validation_result_pb2.MetricsValidationForSlice()
    unchecked_thresholds = dict(thresholds)
    for metric_key, metric in metrics.items():
        if metric_key not in thresholds:
            continue
        del unchecked_thresholds[metric_key]
        # Not meaningful to check threshold for baseline model, thus always return
        # True if such threshold is configured. We also do not compare Message type
        # metrics.
        if metric_key.model_name == baseline_model_name:
            continue
        msg = ''
        # We try to convert to float values.
        try:
            metric = float(metric)
        except (TypeError, ValueError):
            msg = """
        Invalid threshold config: This metric is not comparable to the
        threshold. The type of the threshold is: {}, and the metric value is:
        \n{}""".format(type(metric), metric)
        existing_failures = set()
        for slice_spec, threshold in thresholds[metric_key]:
            if not _check_threshold(metric_key, slice_spec, threshold, metric):
                # The same threshold values could be set for multiple matching slice
                # specs. Only store the first match.
                #
                # Note that hashing by SerializeToString() is only safe if used within
                # the same process.
                if not _add_to_set(existing_failures,
                                   threshold.SerializeToString()):
                    continue
                failure = validation_for_slice.failures.add()
                failure.metric_key.CopyFrom(metric_key.to_proto())
                _copy_metric(metric, failure.metric_value)
                _copy_threshold(threshold, failure.metric_threshold)
                failure.message = msg
    # All unchecked thresholds are considered failures.
    for metric_key, thresholds in unchecked_thresholds.items():
        if metric_key.model_name == baseline_model_name:
            continue
        existing_failures = set()
        for _, threshold in thresholds:
            # The same threshold values could be set for multiple matching slice
            # specs. Only store the first match.
            #
            # Note that hashing by SerializeToString() is only safe if used within
            # the same process.
            if not _add_to_set(existing_failures,
                               threshold.SerializeToString()):
                continue
            failure = validation_for_slice.failures.add()
            failure.metric_key.CopyFrom(metric_key.to_proto())
            _copy_threshold(threshold, failure.metric_threshold)
            failure.message = 'Metric not found.'
    # Any failure leads to overall failure.
    if validation_for_slice.failures:
        validation_for_slice.slice_key.CopyFrom(
            slicer.serialize_slice_key(sliced_key))
        result.validation_ok = False
        result.metric_validations_per_slice.append(validation_for_slice)
    return result
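
For the change-threshold branch above, the check compares both the absolute diff and the diff-to-baseline ratio, using strict inequalities; a worked sketch of the arithmetic (illustrative numbers only):

# Illustrative arithmetic mirroring the LOWER_IS_BETTER branch above.
diff = -0.02                      # candidate minus baseline (the *_diff metric)
baseline = 0.40
ratio = diff / baseline           # -0.05
ok = diff < -0.01 and ratio < -0.05
print(ok)  # False: the ratio must be strictly below the relative threshold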