Example #1
 def testPlotKeyFromProto(self):
     plot_keys = [
         metric_types.PlotKey(name=''),
         metric_types.PlotKey(name='',
                              model_name='model_name',
                              output_name='output_name',
                              sub_key=metric_types.SubKey(class_id=1)),
         metric_types.PlotKey(name='',
                              model_name='model_name',
                              output_name='output_name',
                              sub_key=metric_types.SubKey(top_k=2))
     ]
     for key in plot_keys:
         got_key = metric_types.PlotKey.from_proto(key.to_proto())
         self.assertEqual(key, got_key, '{} != {}'.format(key, got_key))
Example #2
            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_plots = got[0]
                    self.assertEqual(got_slice_key, ())
                    self.assertLen(got_plots, 1)
                    key = metric_types.PlotKey(
                        name='multi_class_confusion_matrix_plot')
                    got_matrix = got_plots[key]
                    self.assertProtoEquals(
                        """
              matrices {
                threshold: 0.0
                entries {
                  actual_class_id: 0
                  predicted_class_id: 0
                  num_weighted_examples: 0.5
                }
                entries {
                  actual_class_id: 2
                  predicted_class_id: 2
                  num_weighted_examples: 1.0
                }
              }
          """, got_matrix)

                except AssertionError as err:
                    raise util.BeamAssertException(err)
Example #3
def _multi_label_confusion_matrix_plot(
    thresholds: Optional[List[float]] = None,
    num_thresholds: Optional[int] = None,
    name: Text = MULTI_LABEL_CONFUSION_MATRIX_PLOT_NAME,
    eval_config: Optional[config.EvalConfig] = None,
    model_name: Text = '',
    output_name: Text = '',
) -> metric_types.MetricComputations:
    """Returns computations for multi-label confusion matrix at thresholds."""
    if num_thresholds is not None and thresholds is not None:
        raise ValueError(
            'only one of thresholds or num_thresholds can be set at a time')
    if num_thresholds is None and thresholds is None:
        thresholds = [0.5]
    if num_thresholds is not None:
        thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                      for i in range(num_thresholds - 2)]
        thresholds = [-_EPSILON] + thresholds + [1.0 + _EPSILON]

    key = metric_types.PlotKey(name=name,
                               model_name=model_name,
                               output_name=output_name)
    return [
        metric_types.MetricComputation(
            keys=[key],
            preprocessor=None,
            combiner=_MultiLabelConfusionMatrixPlotCombiner(
                key=key, eval_config=eval_config, thresholds=thresholds))
    ]
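
The threshold handling above is easiest to see with a concrete value of num_thresholds. A minimal sketch, assuming _EPSILON is some small module constant (the actual value is defined elsewhere in the module):

_EPSILON = 1e-7  # assumed for illustration; the module defines its own constant

num_thresholds = 5
thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
              for i in range(num_thresholds - 2)]
# interior points -> [0.25, 0.5, 0.75]
thresholds = [-_EPSILON] + thresholds + [1.0 + _EPSILON]
# final list -> [-1e-07, 0.25, 0.5, 0.75, 1.0000001], i.e. num_thresholds values
# bracketing the prediction range just below 0.0 and just above 1.0.
print(thresholds)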
Example #4
            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_plots = got[0]
                    self.assertEqual(got_slice_key, ())
                    self.assertLen(got_plots, 1)
                    key = metric_types.PlotKey(name='auc_plot')
                    self.assertIn(key, got_plots)
                    got_plot = got_plots[key]
                    self.assertProtoEquals(
                        """
              matrices {
                threshold: -1e-06
                false_positives: 2.0
                true_positives: 2.0
                precision: 0.5
                recall: 1.0
              }
              matrices {
                true_negatives: 1.0
                false_positives: 1.0
                true_positives: 2.0
                precision: 0.6666667
                recall: 1.0
              }
              matrices {
                threshold: 0.25
                true_negatives: 1.0
                false_positives: 1.0
                true_positives: 2.0
                precision: 0.6666667
                recall: 1.0
              }
              matrices {
                threshold: 0.5
                false_negatives: 1.0
                true_negatives: 2.0
                true_positives: 1.0
                precision: 1.0
                recall: 0.5
              }
              matrices {
                threshold: 0.75
                false_negatives: 1.0
                true_negatives: 2.0
                true_positives: 1.0
                precision: 1.0
                recall: 0.5
              }
              matrices {
                threshold: 1.0
                false_negatives: 2.0
                true_negatives: 2.0
                precision: 1.0
                recall: 0.0
              }
          """, got_plot)

                except AssertionError as err:
                    raise util.BeamAssertException(err)
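
The precision and recall values in the expected proto above follow from the usual confusion-matrix definitions applied to each threshold's counts. A quick sanity check (the threshold 1.0 row is the only special case):

# precision = TP / (TP + FP), recall = TP / (TP + FN)
def precision_recall(tp, fp, fn):
    return tp / (tp + fp), tp / (tp + fn)

print(precision_recall(tp=2.0, fp=2.0, fn=0.0))  # threshold -1e-06 -> (0.5, 1.0)
print(precision_recall(tp=2.0, fp=1.0, fn=0.0))  # threshold 0.0    -> (0.666..., 1.0)
print(precision_recall(tp=1.0, fp=0.0, fn=1.0))  # threshold 0.5    -> (1.0, 0.5)
# At threshold 1.0 nothing is predicted positive (TP + FP == 0), and the expected
# proto reports precision 1.0 there by convention rather than dividing by zero.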
Example #5
def _calibration_plot(
    num_buckets: int = DEFAULT_NUM_BUCKETS,
    left: Optional[float] = None,
    right: Optional[float] = None,
    name: Text = CALIBRATION_PLOT_NAME,
    eval_config: Optional[config_pb2.EvalConfig] = None,
    schema: Optional[schema_pb2.Schema] = None,
    model_name: Text = '',
    output_name: Text = '',
    sub_key: Optional[metric_types.SubKey] = None,
    aggregation_type: Optional[metric_types.AggregationType] = None,
    class_weights: Optional[Dict[int, float]] = None
) -> metric_types.MetricComputations:
  """Returns metric computations for calibration plot."""
  key = metric_types.PlotKey(
      name=name,
      model_name=model_name,
      output_name=output_name,
      sub_key=sub_key)

  label_left, label_right = None, None
  if (left is None or right is None) and eval_config and schema:
    label_left, label_right = _find_label_domain(eval_config, schema,
                                                 model_name, output_name)
  if left is None:
    left = label_left if label_left is not None else 0.0
  if right is None:
    right = label_right if label_right is not None else 1.0

  # Make sure calibration histogram is calculated. Note we are using the default
  # number of buckets assigned to the histogram instead of the value used for
  # the plots just in case the computation is shared with other metrics and
  # plots that need higher precision. It will be downsampled later.
  computations = calibration_histogram.calibration_histogram(
      eval_config=eval_config,
      model_name=model_name,
      output_name=output_name,
      sub_key=sub_key,
      left=left,
      right=right,
      aggregation_type=aggregation_type,
      class_weights=class_weights)
  histogram_key = computations[-1].keys[-1]

  def result(
      metrics: Dict[metric_types.MetricKey, Any]
  ) -> Dict[metric_types.MetricKey, Any]:
    thresholds = [
        left + i * (right - left) / num_buckets for i in range(num_buckets + 1)
    ]
    thresholds = [float('-inf')] + thresholds
    histogram = calibration_histogram.rebin(
        thresholds, metrics[histogram_key], left=left, right=right)
    return {key: _to_proto(thresholds, histogram)}

  derived_computation = metric_types.DerivedMetricComputation(
      keys=[key], result=result)
  computations.append(derived_computation)
  return computations
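
The derived result() above downsamples the shared high-resolution histogram into the plot's buckets by rebinning at evenly spaced thresholds, plus a leading -inf threshold for out-of-range predictions. A small sketch of the thresholds that rebin() receives, assuming the default [0, 1] interval and num_buckets=10:

left, right, num_buckets = 0.0, 1.0, 10
thresholds = [left + i * (right - left) / num_buckets
              for i in range(num_buckets + 1)]
thresholds = [float('-inf')] + thresholds
# -> [-inf, 0.0, 0.1, 0.2, ..., 0.9, 1.0]
# Each threshold starts a bucket, so the rebinned plot ends up with 12 buckets:
# the 10 regular buckets plus the two edge buckets (below 0.0 and at/above 1.0),
# matching the "10 buckets + 2 for edge cases" assertion in the calibration plot
# test later on this page.
print(thresholds)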
Example #6
def calibration_histogram(
    num_buckets: Optional[int] = None,
    left: Optional[float] = None,
    right: Optional[float] = None,
    name: Optional[Text] = None,
    eval_config: Optional[config.EvalConfig] = None,
    model_name: Text = '',
    output_name: Text = '',
    sub_key: Optional[metric_types.SubKey] = None,
    aggregation_type: Optional[metric_types.AggregationType] = None,
    class_weights: Optional[Dict[int, float]] = None
) -> metric_types.MetricComputations:
  """Returns metric computations for calibration histogram.

  Args:
    num_buckets: Number of buckets to use. Note that the actual number of
      buckets will be num_buckets + 2 to account for the edge cases.
    left: Start of predictions interval.
    right: End of predictions interval.
    name: Metric name.
    eval_config: Eval config.
    model_name: Optional model name (if multi-model evaluation).
    output_name: Optional output name (if multi-output model type).
    sub_key: Optional sub key.
    aggregation_type: Optional aggregation type.
    class_weights: Optional class weights to apply to multi-class / multi-label
      labels and predictions prior to flattening (when micro averaging is used).

  Returns:
    MetricComputations for computing the histogram(s).
  """
  if num_buckets is None:
    num_buckets = DEFAULT_NUM_BUCKETS
  if left is None:
    left = 0.0
  if right is None:
    right = 1.0
  if name is None:
    name = '{}_{}'.format(CALIBRATION_HISTOGRAM_NAME, num_buckets)
  key = metric_types.PlotKey(
      name=name,
      model_name=model_name,
      output_name=output_name,
      sub_key=sub_key)
  return [
      metric_types.MetricComputation(
          keys=[key],
          preprocessor=None,
          combiner=_CalibrationHistogramCombiner(
              key=key,
              eval_config=eval_config,
              aggregation_type=aggregation_type,
              class_weights=class_weights,
              num_buckets=num_buckets,
              left=left,
              right=right))
  ]
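
A brief usage sketch for the helper above (illustrative only, not a documented API surface): the returned list holds a single MetricComputation whose PlotKey can be recovered the same way the plot helpers on this page do it.

# Illustrative call of the calibration_histogram() defined above.
computations = calibration_histogram(num_buckets=100, left=0.0, right=1.0)
histogram_key = computations[-1].keys[-1]  # same pattern _calibration_plot uses
# With the default name the key ends up as '<CALIBRATION_HISTOGRAM_NAME>_100'
# (the tests on this page show '_calibration_histogram_10000' for the default
# 10000 buckets); the attached combiner emits the per-bucket counts under it.
print(histogram_key)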
Example #7
      def check_result(got):
        try:
          self.assertLen(got, 1)
          got_slice_key, got_plots = got[0]
          self.assertEqual(got_slice_key, ())
          self.assertLen(got_plots, 1)
          key = metric_types.PlotKey(
              name='_calibration_histogram_10000',
              sub_key=metric_types.SubKey(top_k=2),
              example_weighted=True)
          self.assertIn(key, got_plots)
          got_histogram = got_plots[key]
          self.assertLen(got_histogram, 5)
          self.assertEqual(
              got_histogram[0],
              calibration_histogram.Bucket(
                  bucket_id=0,
                  weighted_labels=3.0 + 4.0,
                  weighted_predictions=(2 * 1.0 * float('-inf') +
                                        2 * 2.0 * float('-inf') +
                                        2 * 3.0 * float('-inf') +
                                        2 * 4.0 * float('-inf') + -0.1 * 4.0),
                  weighted_examples=(1.0 * 2.0 + 2.0 * 2.0 + 3.0 * 2.0 +
                                     4.0 * 3.0)))
          self.assertEqual(
              got_histogram[1],
              calibration_histogram.Bucket(
                  bucket_id=2001,
                  weighted_labels=0.0 + 0.0,
                  weighted_predictions=0.2 + 3 * 0.2,
                  weighted_examples=1.0 + 3.0))
          self.assertEqual(
              got_histogram[2],
              calibration_histogram.Bucket(
                  bucket_id=5001,
                  weighted_labels=1.0 + 0.0 * 3.0,
                  weighted_predictions=0.5 * 1.0 + 0.5 * 3.0,
                  weighted_examples=1.0 + 3.0))
          self.assertEqual(
              got_histogram[3],
              calibration_histogram.Bucket(
                  bucket_id=8001,
                  weighted_labels=0.0 * 2.0 + 1.0 * 2.0,
                  weighted_predictions=0.8 * 2.0 + 0.8 * 2.0,
                  weighted_examples=2.0 + 2.0))
          self.assertEqual(
              got_histogram[4],
              calibration_histogram.Bucket(
                  bucket_id=10001,
                  weighted_labels=0.0 * 4.0,
                  weighted_predictions=1.1 * 4.0,
                  weighted_examples=4.0))

        except AssertionError as err:
          raise util.BeamAssertException(err)
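
The bucket ids asserted above come from how the histogram maps a prediction into one of num_buckets + 2 slots. A back-of-the-envelope sketch of that mapping, assuming the default 10000 buckets over [0, 1] and simple clamping (the combiner's actual edge handling may differ slightly):

def bucket_id(prediction, num_buckets=10000, left=0.0, right=1.0):
    # Out-of-range predictions land in the two edge buckets; everything else is
    # offset by one so that bucket 0 stays reserved for underflow.
    if prediction < left:
        return 0
    if prediction >= right:
        return num_buckets + 1
    return int(num_buckets * (prediction - left) / (right - left)) + 1

# Reproduces the ids asserted above: -0.1 (and the -inf top_k fillers) -> 0,
# 0.2 -> 2001, 0.5 -> 5001, 0.8 -> 8001, 1.1 -> 10001.
print([bucket_id(p) for p in (-0.1, 0.2, 0.5, 0.8, 1.1)])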
Example #8
def _confusion_matrix_plot(
    num_thresholds: int = DEFAULT_NUM_THRESHOLDS,
    name: Text = CONFUSION_MATRIX_PLOT_NAME,
    eval_config: Optional[config.EvalConfig] = None,
    model_name: Text = '',
    output_name: Text = '',
    sub_key: Optional[metric_types.SubKey] = None,
    aggregation_type: Optional[metric_types.AggregationType] = None,
    class_weights: Optional[Dict[int, float]] = None
) -> metric_types.MetricComputations:
    """Returns metric computations for confusion matrix plots."""
    key = metric_types.PlotKey(name=name,
                               model_name=model_name,
                               output_name=output_name,
                               sub_key=sub_key)

    # The interpolation strategy used here matches how the legacy post export
    # metrics calculated their plots.
    thresholds = [
        i * 1.0 / num_thresholds for i in range(0, num_thresholds + 1)
    ]
    thresholds = [-1e-6] + thresholds

    # Make sure matrices are calculated.
    matrices_computations = binary_confusion_matrices.binary_confusion_matrices(
        # Use a custom name since we have a custom interpolation strategy which
        # will cause the default naming used by the binary confusion matrix to be
        # very long.
        name=(binary_confusion_matrices.BINARY_CONFUSION_MATRICES_NAME + '_' +
              name),
        eval_config=eval_config,
        model_name=model_name,
        output_name=output_name,
        sub_key=sub_key,
        aggregation_type=aggregation_type,
        class_weights=class_weights,
        thresholds=thresholds)
    matrices_key = matrices_computations[-1].keys[-1]

    def result(
        metrics: Dict[metric_types.MetricKey, Any]
    ) -> Dict[metric_types.MetricKey,
              metrics_for_slice_pb2.ConfusionMatrixAtThresholds]:
        return {
            key:
            confusion_matrix_metrics.to_proto(thresholds,
                                              metrics[matrices_key])
        }

    derived_computation = metric_types.DerivedMetricComputation(keys=[key],
                                                                result=result)
    computations = matrices_computations
    computations.append(derived_computation)
    return computations
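
The interpolation note above boils down to num_thresholds + 2 thresholds: evenly spaced points from 0.0 to 1.0 inclusive, with an extra -1e-6 entry in front so the first matrix captures everything. A small sketch, using num_thresholds=4 for readability:

num_thresholds = 4  # illustration only; DEFAULT_NUM_THRESHOLDS is larger
thresholds = [i * 1.0 / num_thresholds for i in range(0, num_thresholds + 1)]
thresholds = [-1e-6] + thresholds
# -> [-1e-06, 0.0, 0.25, 0.5, 0.75, 1.0]
# The AUC plot test earlier on this page (Example #4) shows exactly this layout:
# six matrices whose thresholds start at -1e-06 and step through 0.25 increments.
print(thresholds)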
Example #9
            def check_plots(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_plots = got[0]
                    self.assertEqual(got_slice_key, ())
                    plot_key = metric_types.PlotKey('calibration_plot')
                    self.assertIn(plot_key, got_plots)
                    # 10 buckets + 2 for edge cases
                    self.assertLen(got_plots[plot_key].buckets, 12)

                except AssertionError as err:
                    raise util.BeamAssertException(err)
Example #10
      def check_result(got):
        try:
          self.assertLen(got, 1)
          got_slice_key, got_plots = got[0]
          self.assertEqual(got_slice_key, ())
          self.assertLen(got_plots, 1)
          key = metric_types.PlotKey(
              name='_calibration_histogram_10000',
              sub_key=metric_types.SubKey(k=2),
              example_weighted=True)
          self.assertIn(key, got_plots)
          got_histogram = got_plots[key]
          self.assertLen(got_histogram, 5)
          self.assertEqual(
              got_histogram[0],
              calibration_histogram.Bucket(
                  bucket_id=0,
                  weighted_labels=0.0 * 4.0,
                  weighted_predictions=-0.2 * 4.0,
                  weighted_examples=4.0))
          self.assertEqual(
              got_histogram[1],
              calibration_histogram.Bucket(
                  bucket_id=1001,
                  weighted_labels=1.0 + 7 * 1.0,
                  weighted_predictions=0.1 + 7 * 0.1,
                  weighted_examples=1.0 + 7.0))
          self.assertEqual(
              got_histogram[2],
              calibration_histogram.Bucket(
                  bucket_id=4001,
                  weighted_labels=1.0 * 3.0 + 0.0 * 5.0,
                  weighted_predictions=0.4 * 3.0 + 0.4 * 5.0,
                  weighted_examples=3.0 + 5.0))
          self.assertEqual(
              got_histogram[3],
              calibration_histogram.Bucket(
                  bucket_id=7001,
                  weighted_labels=0.0 * 2.0 + 0.0 * 6.0,
                  weighted_predictions=0.7 * 2.0 + 0.7 * 6.0,
                  weighted_examples=2.0 + 6.0))
          self.assertEqual(
              got_histogram[4],
              calibration_histogram.Bucket(
                  bucket_id=10001,
                  weighted_labels=0.0 * 8.0,
                  weighted_predictions=1.05 * 8.0,
                  weighted_examples=8.0))

        except AssertionError as err:
          raise util.BeamAssertException(err)
Example #11
            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_plots = got[0]
                    self.assertEqual(got_slice_key, ())
                    self.assertLen(got_plots, 1)
                    key = metric_types.PlotKey('_calibration_histogram_10000')
                    self.assertIn(key, got_plots)
                    got_histogram = got_plots[key]
                    self.assertLen(got_histogram, 5)
                    self.assertEqual(
                        got_histogram[0],
                        calibration_histogram.Bucket(
                            bucket_id=0,
                            weighted_labels=1.0 * 4.0,
                            weighted_predictions=-0.1 * 4.0,
                            weighted_examples=4.0))
                    self.assertEqual(
                        got_histogram[1],
                        calibration_histogram.Bucket(
                            bucket_id=2001,
                            weighted_labels=0.0 + 0.0,
                            weighted_predictions=0.2 + 7 * 0.2,
                            weighted_examples=1.0 + 7.0))
                    self.assertEqual(
                        got_histogram[2],
                        calibration_histogram.Bucket(
                            bucket_id=5001,
                            weighted_labels=1.0 * 5.0,
                            weighted_predictions=0.5 * 3.0 + 0.5 * 5.0,
                            weighted_examples=3.0 + 5.0))
                    self.assertEqual(
                        got_histogram[3],
                        calibration_histogram.Bucket(
                            bucket_id=8001,
                            weighted_labels=1.0 * 2.0 + 1.0 * 6.0,
                            weighted_predictions=0.8 * 2.0 + 0.8 * 6.0,
                            weighted_examples=2.0 + 6.0))
                    self.assertEqual(
                        got_histogram[4],
                        calibration_histogram.Bucket(bucket_id=10001,
                                                     weighted_labels=1.0 * 8.0,
                                                     weighted_predictions=1.1 *
                                                     8.0,
                                                     weighted_examples=8.0))

                except AssertionError as err:
                    raise util.BeamAssertException(err)
Example #12
def _auc_plot(
    num_thresholds: int = DEFAULT_NUM_THRESHOLDS,
    name: Text = AUC_PLOT_NAME,
    eval_config: Optional[config.EvalConfig] = None,
    model_name: Text = '',
    output_name: Text = '',
    sub_key: Optional[metric_types.SubKey] = None,
    class_weights: Optional[Dict[int, float]] = None
) -> metric_types.MetricComputations:
    """Returns metric computations for AUC plots."""
    key = metric_types.PlotKey(name=name,
                               model_name=model_name,
                               output_name=output_name,
                               sub_key=sub_key)

    # The interpolation strategy used here matches how the legacy post export
    # metrics calculated their plots.
    thresholds = [
        i * 1.0 / num_thresholds for i in range(0, num_thresholds + 1)
    ]
    thresholds = [-1e-6] + thresholds

    # Make sure matrices are calculated.
    matrices_computations = binary_confusion_matrices.binary_confusion_matrices(
        eval_config=eval_config,
        model_name=model_name,
        output_name=output_name,
        sub_key=sub_key,
        class_weights=class_weights,
        thresholds=thresholds)
    matrices_key = matrices_computations[-1].keys[-1]

    def result(
        metrics: Dict[metric_types.MetricKey, Any]
    ) -> Dict[metric_types.MetricKey,
              metrics_for_slice_pb2.ConfusionMatrixAtThresholds]:
        return {
            key:
            confusion_matrix_at_thresholds.to_proto(thresholds,
                                                    metrics[matrices_key])
        }

    derived_computation = metric_types.DerivedMetricComputation(keys=[key],
                                                                result=result)
    computations = matrices_computations
    computations.append(derived_computation)
    return computations
Example #13
def _calibration_plot(
    num_buckets: int = DEFAULT_NUM_BUCKETS,
    left: float = 0.0,
    right: float = 1.0,
    name: Text = CALIBRATION_PLOT_NAME,
    eval_config: Optional[config.EvalConfig] = None,
    model_name: Text = '',
    output_name: Text = '',
    sub_key: Optional[metric_types.SubKey] = None
) -> metric_types.MetricComputations:
    """Returns metric computations for calibration plot."""
    key = metric_types.PlotKey(name=name,
                               model_name=model_name,
                               output_name=output_name,
                               sub_key=sub_key)

    # Make sure calibration histogram is calculated. Note we are using the default
    # number of buckets assigned to the histogram instead of the value used for
    # the plots just in case the computation is shared with other metrics and
    # plots that need higher precision. It will be downsampled later.
    computations = calibration_histogram.calibration_histogram(
        eval_config=eval_config,
        model_name=model_name,
        output_name=output_name,
        sub_key=sub_key,
        left=left,
        right=right)
    histogram_key = computations[-1].keys[-1]

    def result(
        metrics: Dict[metric_types.MetricKey, Any]
    ) -> Dict[metric_types.MetricKey, Any]:
        thresholds = [
            left + i * (right - left) / num_buckets
            for i in range(num_buckets + 1)
        ]
        thresholds = [float('-inf')] + thresholds
        histogram = calibration_histogram.rebin(thresholds,
                                                metrics[histogram_key],
                                                left=left,
                                                right=right)
        return {key: _to_proto(thresholds, histogram)}

    derived_computation = metric_types.DerivedMetricComputation(keys=[key],
                                                                result=result)
    computations.append(derived_computation)
    return computations
Example #14
def _multi_class_confusion_matrix_at_thresholds(
    thresholds: Optional[List[float]] = None,
    name: Text = MULTI_CLASS_CONFUSION_MATRIX_AT_THRESHOLDS_NAME,
    eval_config: Optional[config.EvalConfig] = None,
    model_name: Text = '',
    output_name: Text = '',
) -> metric_types.MetricComputations:
  """Returns computations for multi-class confusion matrix at thresholds."""
  key = metric_types.PlotKey(
      name=name, model_name=model_name, output_name=output_name)
  return [
      metric_types.MetricComputation(
          keys=[key],
          preprocessor=None,
          combiner=_MultiClassConfusionMatrixAtThresholdsCombiner(
              key=key, eval_config=eval_config, thresholds=thresholds))
  ]
Example #15
def _multi_class_confusion_matrix_plot(
    thresholds: Optional[List[float]] = None,
    num_thresholds: Optional[int] = None,
    name: str = MULTI_CLASS_CONFUSION_MATRIX_PLOT_NAME,
    eval_config: Optional[config_pb2.EvalConfig] = None,
    model_name: str = '',
    output_name: str = '',
    example_weighted: bool = False) -> metric_types.MetricComputations:
  """Returns computations for multi-class confusion matrix plot."""
  if num_thresholds is None and thresholds is None:
    thresholds = [0.0]

  key = metric_types.PlotKey(
      name=name,
      model_name=model_name,
      output_name=output_name,
      example_weighted=example_weighted)

  # Make sure matrices are calculated.
  matrices_computations = (
      multi_class_confusion_matrix_metrics.multi_class_confusion_matrices(
          thresholds=thresholds,
          num_thresholds=num_thresholds,
          eval_config=eval_config,
          model_name=model_name,
          output_name=output_name,
          example_weighted=example_weighted))
  matrices_key = matrices_computations[-1].keys[-1]

  def result(
      metrics: Dict[metric_types.MetricKey,
                    multi_class_confusion_matrix_metrics.Matrices]
  ) -> Dict[metric_types.PlotKey,
            metrics_for_slice_pb2.MultiClassConfusionMatrixAtThresholds]:
    return {
        key:
            metrics[matrices_key].to_proto()
            .multi_class_confusion_matrix_at_thresholds
    }

  derived_computation = metric_types.DerivedMetricComputation(
      keys=[key], result=result)
  computations = matrices_computations
  computations.append(derived_computation)
  return computations
Example #16
def _multi_label_confusion_matrix_plot(
    thresholds: Optional[List[float]] = None,
    name: Text = MULTI_LABEL_CONFUSION_MATRIX_PLOT_NAME,
    eval_config: Optional[config.EvalConfig] = None,
    model_name: Text = '',
    output_name: Text = '',
) -> metric_types.MetricComputations:
    """Returns computations for multi-label confusion matrix at thresholds."""
    key = metric_types.PlotKey(name=name,
                               model_name=model_name,
                               output_name=output_name)
    return [
        metric_types.MetricComputation(
            keys=[key],
            preprocessor=None,
            combiner=_MultiLabelConfusionMatrixPlotCombiner(
                key=key, eval_config=eval_config, thresholds=thresholds))
    ]
Example #17
            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_plots = got[0]
                    self.assertEqual(got_slice_key, ())
                    self.assertLen(got_plots, 1)
                    key = metric_types.PlotKey(
                        name='multi_label_confusion_matrix_plot')
                    got_matrix = got_plots[key]
                    self.assertProtoEquals(
                        """
              matrices {
                threshold: 0.5
                entries {
                  actual_class_id: 0
                  predicted_class_id: 0
                  false_negatives: 0.0
                  true_negatives: 0.0
                  false_positives: 0.0
                  true_positives: 2.0
                }
                entries {
                  actual_class_id: 0
                  predicted_class_id: 1
                  false_negatives: 1.0
                  true_negatives: 1.0
                  false_positives: 0.0
                  true_positives: 0.0
                }
                entries {
                  actual_class_id: 0
                  predicted_class_id: 2
                  false_negatives: 0.0
                  true_negatives: 2.0
                  false_positives: 0.0
                  true_positives: 0.0
                }
                entries {
                  actual_class_id: 1
                  predicted_class_id: 0
                  false_negatives: 0.0
                  true_negatives: 1.0
                  false_positives: 0.0
                  true_positives: 1.0
                }
                entries {
                  actual_class_id: 1
                  predicted_class_id: 1
                  false_negatives: 1.0
                  true_negatives: 0.0
                  false_positives: 0.0
                  true_positives: 1.0
                }
                entries {
                  actual_class_id: 1
                  predicted_class_id: 2
                  false_negatives: 0.0
                  false_positives: 0.0
                  true_negatives: 2.0
                  true_positives: 0.0
                }
              }
          """, got_matrix)

                except AssertionError as err:
                    raise util.BeamAssertException(err)
Example #18
            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_plots = got[0]
                    self.assertEqual(got_slice_key, ())
                    self.assertLen(got_plots, 1)
                    key = metric_types.PlotKey(name='calibration_plot')
                    self.assertIn(key, got_plots)
                    got_plot = got_plots[key]
                    self.assertProtoEquals(
                        """
              buckets {
                lower_threshold_inclusive: -inf
                upper_threshold_exclusive: 0.0
                total_weighted_label {
                  value: 4.0
                }
                total_weighted_refined_prediction {
                  value: -0.4
                }
                num_weighted_examples {
                  value: 4.0
                }
              }
              buckets {
                lower_threshold_inclusive: 0.0
                upper_threshold_exclusive: 0.1
                total_weighted_label {
                }
                total_weighted_refined_prediction {
                }
                num_weighted_examples {
                }
              }
              buckets {
                lower_threshold_inclusive: 0.1
                upper_threshold_exclusive: 0.2
                total_weighted_label {
                }
                total_weighted_refined_prediction {
                }
                num_weighted_examples {
                }
              }
              buckets {
                lower_threshold_inclusive: 0.2
                upper_threshold_exclusive: 0.3
                total_weighted_label {
                }
                total_weighted_refined_prediction {
                  value: 1.6
                }
                num_weighted_examples {
                  value: 8.0
                }
              }
              buckets {
                lower_threshold_inclusive: 0.3
                upper_threshold_exclusive: 0.4
                total_weighted_label {
                }
                total_weighted_refined_prediction {
                }
                num_weighted_examples {
                }
              }
              buckets {
                lower_threshold_inclusive: 0.4
                upper_threshold_exclusive: 0.5
                total_weighted_label {
                }
                total_weighted_refined_prediction {
                }
                num_weighted_examples {
                }
              }
              buckets {
                lower_threshold_inclusive: 0.5
                upper_threshold_exclusive: 0.6
                total_weighted_label {
                  value: 5.0
                }
                total_weighted_refined_prediction {
                  value: 4.0
                }
                num_weighted_examples {
                  value: 8.0
                }
              }
              buckets {
                lower_threshold_inclusive: 0.6
                upper_threshold_exclusive: 0.7
                total_weighted_label {
                }
                total_weighted_refined_prediction {
                }
                num_weighted_examples {
                }
              }
              buckets {
                lower_threshold_inclusive: 0.7
                upper_threshold_exclusive: 0.8
                total_weighted_label {
                }
                total_weighted_refined_prediction {
                }
                num_weighted_examples {
                }
              }
              buckets {
                lower_threshold_inclusive: 0.8
                upper_threshold_exclusive: 0.9
                total_weighted_label {
                  value: 8.0
                }
                total_weighted_refined_prediction {
                  value: 6.4
                }
                num_weighted_examples {
                  value: 8.0
                }
              }
              buckets {
                lower_threshold_inclusive: 0.9
                upper_threshold_exclusive: 1.0
                total_weighted_label {
                }
                total_weighted_refined_prediction {
                }
                num_weighted_examples {
                }
              }
              buckets {
                lower_threshold_inclusive: 1.0
                upper_threshold_exclusive: inf
                total_weighted_label {
                  value: 8.0
                }
                total_weighted_refined_prediction {
                  value: 8.8
                }
                num_weighted_examples {
                  value: 8.0
                }
              }
          """, got_plot)

                except AssertionError as err:
                    raise util.BeamAssertException(err)
Example #19
    def testSerializePlots(self):
        slice_key = _make_slice_key('fruit', 'apple')
        plot_key = metric_types.PlotKey(name='calibration_plot',
                                        output_name='output_name')
        calibration_plot = text_format.Parse(
            """
        buckets {
          lower_threshold_inclusive: -inf
          upper_threshold_exclusive: 0.0
          num_weighted_examples { value: 0.0 }
          total_weighted_label { value: 0.0 }
          total_weighted_refined_prediction { value: 0.0 }
        }
        buckets {
          lower_threshold_inclusive: 0.0
          upper_threshold_exclusive: 0.5
          num_weighted_examples { value: 1.0 }
          total_weighted_label { value: 1.0 }
          total_weighted_refined_prediction { value: 0.3 }
        }
        buckets {
          lower_threshold_inclusive: 0.5
          upper_threshold_exclusive: 1.0
          num_weighted_examples { value: 1.0 }
          total_weighted_label { value: 0.0 }
          total_weighted_refined_prediction { value: 0.7 }
        }
        buckets {
          lower_threshold_inclusive: 1.0
          upper_threshold_exclusive: inf
          num_weighted_examples { value: 0.0 }
          total_weighted_label { value: 0.0 }
          total_weighted_refined_prediction { value: 0.0 }
        }
     """, metrics_for_slice_pb2.CalibrationHistogramBuckets())

        expected_plots_for_slice = text_format.Parse(
            """
      slice_key {
        single_slice_keys {
          column: 'fruit'
          bytes_value: 'apple'
        }
      }
      plot_keys_and_values {
        key {
          output_name: "output_name"
        }
        value {
          calibration_histogram_buckets {
            buckets {
              lower_threshold_inclusive: -inf
              upper_threshold_exclusive: 0.0
              num_weighted_examples { value: 0.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.0 }
            }
            buckets {
              lower_threshold_inclusive: 0.0
              upper_threshold_exclusive: 0.5
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 1.0 }
              total_weighted_refined_prediction { value: 0.3 }
            }
            buckets {
              lower_threshold_inclusive: 0.5
              upper_threshold_exclusive: 1.0
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.7 }
            }
            buckets {
              lower_threshold_inclusive: 1.0
              upper_threshold_exclusive: inf
              num_weighted_examples { value: 0.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.0 }
            }
          }
        }
      }
    """, metrics_for_slice_pb2.PlotsForSlice())

        got = metrics_and_plots_serialization._serialize_plots(
            (slice_key, {
                plot_key: calibration_plot
            }), None)
        self.assertProtoEquals(
            expected_plots_for_slice,
            metrics_for_slice_pb2.PlotsForSlice.FromString(got))
Example #20
def calibration_histogram(
    num_buckets: Optional[int] = None,
    left: Optional[float] = None,
    right: Optional[float] = None,
    name: Optional[str] = None,
    eval_config: Optional[config_pb2.EvalConfig] = None,
    model_name: str = '',
    output_name: str = '',
    sub_key: Optional[metric_types.SubKey] = None,
    aggregation_type: Optional[metric_types.AggregationType] = None,
    class_weights: Optional[Dict[int, float]] = None,
    example_weighted: bool = False,
    prediction_based_bucketing: bool = True,
    fractional_labels: Optional[bool] = None,
) -> metric_types.MetricComputations:
    """Returns metric computations for calibration histogram.

  Args:
    num_buckets: Number of buckets to use. Note that the actual number of
      buckets will be num_buckets + 2 to account for the edge cases.
    left: Start of predictions interval.
    right: End of predictions interval.
    name: Metric name.
    eval_config: Eval config.
    model_name: Optional model name (if multi-model evaluation).
    output_name: Optional output name (if multi-output model type).
    sub_key: Optional sub key.
    aggregation_type: Optional aggregation type.
    class_weights: Optional class weights to apply to multi-class / multi-label
      labels and predictions prior to flattening (when micro averaging is used).
    example_weighted:  True if example weights should be applied.
    prediction_based_bucketing: If true, create buckets based on predictions
      else use labels to perform bucketing.
    fractional_labels: If true, each incoming tuple of (label, prediction, and
      example weight) will be split into two tuples as follows (where l, p, w
      represent the resulting label, prediction, and example weight values):
        (1) l = 0.0, p = prediction, and w = example_weight * (1.0 - label)
        (2) l = 1.0, p = prediction, and w = example_weight * label
      If enabled, an exception will be raised if labels are not within [0, 1].
      The implementation is such that tuples associated with a weight of zero
      are not yielded. This means it is safe to enable fractional_labels even
      when the labels only take on the values of 0.0 or 1.0.

  Returns:
    MetricComputations for computing the histogram(s).
  """
    if num_buckets is None:
        num_buckets = DEFAULT_NUM_BUCKETS
    if left is None:
        left = 0.0
    if right is None:
        right = 1.0
    if fractional_labels is None:
        fractional_labels = (left == 0.0 and right == 1.0)
    if name is None:
        name = f'{CALIBRATION_HISTOGRAM_NAME}_{num_buckets}'
    key = metric_types.PlotKey(name=name,
                               model_name=model_name,
                               output_name=output_name,
                               sub_key=sub_key,
                               example_weighted=example_weighted)
    return [
        metric_types.MetricComputation(
            keys=[key],
            preprocessor=None,
            combiner=_CalibrationHistogramCombiner(
                key=key,
                eval_config=eval_config,
                aggregation_type=aggregation_type,
                class_weights=class_weights,
                example_weighted=example_weighted,
                num_buckets=num_buckets,
                left=left,
                right=right,
                prediction_based_bucketing=prediction_based_bucketing,
                fractional_labels=fractional_labels))
    ]
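
The fractional_labels splitting rule in the docstring above is plain arithmetic; the hypothetical helper below (not part of the module) just spells it out for one tuple:

def split_fractional_label(label, prediction, example_weight):
    # Split into a weighted negative tuple and a weighted positive tuple,
    # dropping any tuple whose resulting weight is zero.
    candidates = [
        (0.0, prediction, example_weight * (1.0 - label)),
        (1.0, prediction, example_weight * label),
    ]
    return [t for t in candidates if t[2] != 0.0]

# label 0.3, prediction 0.8, weight 2.0 -> [(0.0, 0.8, 1.4), (1.0, 0.8, 0.6)]
# A hard 0.0 or 1.0 label yields a single tuple, which is why enabling
# fractional_labels is safe even when labels only take on 0.0 or 1.0.
print(split_fractional_label(0.3, 0.8, 2.0))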