コード例 #1
0
  def testConfusionMatrixAtThresholdsWeighted(self):
    """Verifies confusion matrices at thresholds with an example weight key.

    Evaluates `post_export_metrics.confusion_matrix_at_thresholds` with
    `example_weight_key='fixed_float'` at five thresholds, then checks the
    per-threshold matrices and the reported thresholds for the single slice
    whose key is `()`.
    """
    export_base = self._getEvalExportDir()
    _, eval_export_dir = (
        fixed_prediction_estimator_extra_fields
        .simple_fixed_prediction_estimator_extra_fields(None, export_base))
    examples = self.makeConfusionMatrixExamples()

    expected_thresholds = (-1e-6, 0.0, 0.7, 0.8, 1.0)
    # One expected row per threshold, in the same order as the thresholds.
    # NOTE(review): rows look like [FN, TN, FP, TP, precision, recall] --
    # confirm against the metric implementation.
    expected_matrices = (
        [0.0, 0.0, 3.0, 7.0, 7.0 / 10.0, 1.0],
        [1.0, 1.0, 2.0, 6.0, 6.0 / 8.0, 6.0 / 7.0],
        [4.0, 1.0, 2.0, 3.0, 3.0 / 5.0, 3.0 / 7.0],
        [4.0, 3.0, 0.0, 3.0, 1.0, 3.0 / 7.0],
        [7.0, 3.0, 0.0, 0.0, float('nan'), 0.0],
    )

    def check_result(got):  # pylint: disable=invalid-name
      try:
        self.assertEqual(1, len(got), 'got: %s' % got)
        slice_key, metrics = got[0]
        self.assertEqual((), slice_key)

        matrices_key = metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_MATRICES
        self.assertIn(matrices_key, metrics)
        actual_matrices = metrics[matrices_key]
        for index, expected_row in enumerate(expected_matrices):
          self.assertSequenceAlmostEqual(actual_matrices[index], expected_row)

        thresholds_key = metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_THRESHOLDS
        self.assertIn(thresholds_key, metrics)
        actual_thresholds = metrics[thresholds_key]
        for index, expected_threshold in enumerate(expected_thresholds):
          self.assertAlmostEqual(expected_threshold, actual_thresholds[index])
      except AssertionError as err:
        # Re-raised as a Beam assertion error for the pipeline check.
        raise util.BeamAssertException(err)

    weighted_metric = post_export_metrics.confusion_matrix_at_thresholds(
        example_weight_key='fixed_float',
        thresholds=list(expected_thresholds))
    self._runTestWithCustomCheck(
        examples,
        eval_export_dir, [weighted_metric],
        custom_metrics_check=check_result)
コード例 #2
0
    def testSerializeConfusionMatrices(self):
        """Checks serialization of confusion-matrix metrics for one slice.

        Passes hand-written matrices and thresholds, stored under their
        `_full_key` names, through
        `metrics_and_plots_evaluator._serialize_metrics`, then compares the
        parsed output with a text-format `MetricsForSlice` proto.
        """
        slice_key = _make_slice_key()

        thresholds = [0.25, 0.75, 1.00]
        # One row per threshold; per the expected proto below, each row
        # reads [FN, TN, FP, TP, precision, recall].
        matrices = [
            [0.0, 1.0, 0.0, 2.0, 1.0, 1.0],
            [1.0, 1.0, 0.0, 1.0, 1.0, 0.5],
            [2.0, 1.0, 0.0, 0.0, float('nan'), 0.0],
        ]

        matrices_key = _full_key(
            metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_MATRICES)
        thresholds_key = _full_key(
            metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_THRESHOLDS)
        slice_metrics = {matrices_key: matrices, thresholds_key: thresholds}

        expected_text = """
        slice_key {}
        metrics {
          key: "post_export_metrics/confusion_matrix_at_thresholds"
          value {
            confusion_matrix_at_thresholds {
              matrices {
                threshold: 0.25
                false_negatives: 0.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 2.0
                precision: 1.0
                recall: 1.0
                bounded_false_negatives {
                  value {
                    value: 0.0
                  }
                }
                bounded_true_negatives {
                  value {
                    value: 1.0
                  }
                }
                bounded_true_positives {
                  value {
                    value: 2.0
                  }
                }
                bounded_false_positives {
                  value {
                    value: 0.0
                  }
                }
                bounded_precision {
                  value {
                    value: 1.0
                  }
                }
                bounded_recall {
                  value {
                    value: 1.0
                  }
                }
              }
              matrices {
                threshold: 0.75
                false_negatives: 1.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 1.0
                precision: 1.0
                recall: 0.5
                bounded_false_negatives {
                  value {
                    value: 1.0
                  }
                }
                bounded_true_negatives {
                  value {
                    value: 1.0
                  }
                }
                bounded_true_positives {
                  value {
                    value: 1.0
                  }
                }
                bounded_false_positives {
                  value {
                    value: 0.0
                  }
                }
                bounded_precision {
                  value {
                    value: 1.0
                  }
                }
                bounded_recall {
                  value {
                    value: 0.5
                  }
                }
              }
              matrices {
                threshold: 1.00
                false_negatives: 2.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 0.0
                precision: nan
                recall: 0.0
                bounded_false_negatives {
                  value {
                    value: 2.0
                  }
                }
                bounded_true_negatives {
                  value {
                    value: 1.0
                  }
                }
                bounded_true_positives {
                  value {
                    value: 0.0
                  }
                }
                bounded_false_positives {
                  value {
                    value: 0.0
                  }
                }
                bounded_precision {
                  value {
                    value: nan
                  }
                }
                bounded_recall {
                  value {
                    value: 0.0
                  }
                }
              }
            }
          }
        }
        """
        expected_proto = text_format.Parse(
            expected_text, metrics_for_slice_pb2.MetricsForSlice())

        metric = post_export_metrics.confusion_matrix_at_thresholds(thresholds)
        serialized = metrics_and_plots_evaluator._serialize_metrics(
            (slice_key, slice_metrics), [metric])

        # The serializer output is parsed back so protos are compared,
        # not raw bytes.
        actual_proto = metrics_for_slice_pb2.MetricsForSlice.FromString(
            serialized)
        self.assertProtoEquals(expected_proto, actual_proto)
コード例 #3
0
  def testConvertSliceMetricsToProtoConfusionMatrices(self):
    """Checks conversion of confusion-matrix slice metrics to a proto.

    Passes hand-written matrices and thresholds through
    `metrics_plots_and_validations_writer.convert_slice_metrics_to_proto` and
    compares the result with a text-format `MetricsForSlice` proto that
    carries plain, `bounded_*` and `t_distribution_*` fields per threshold.
    """
    slice_key = _make_slice_key()

    thresholds = [0.25, 0.75, 1.00]
    # One row per threshold; per the expected proto below, each row reads
    # [FN, TN, FP, TP, precision, recall].
    matrices = [
        [0.0, 1.0, 0.0, 2.0, 1.0, 1.0],
        [1.0, 1.0, 0.0, 1.0, 1.0, 0.5],
        [2.0, 1.0, 0.0, 0.0, float('nan'), 0.0],
    ]

    slice_metrics = {
        metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_MATRICES: matrices,
        metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_THRESHOLDS: thresholds,
    }

    expected_text = """
        slice_key {}
        metrics {
          key: "post_export_metrics/confusion_matrix_at_thresholds"
          value {
            confusion_matrix_at_thresholds {
              matrices {
                threshold: 0.25
                false_negatives: 0.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 2.0
                precision: 1.0
                recall: 1.0
                bounded_false_negatives {
                  value {
                    value: 0.0
                  }
                }
                bounded_true_negatives {
                  value {
                    value: 1.0
                  }
                }
                bounded_true_positives {
                  value {
                    value: 2.0
                  }
                }
                bounded_false_positives {
                  value {
                    value: 0.0
                  }
                }
                bounded_precision {
                  value {
                    value: 1.0
                  }
                }
                bounded_recall {
                  value {
                    value: 1.0
                  }
                }
                t_distribution_false_negatives {
                  unsampled_value {
                    value: 0.0
                  }
                }
                t_distribution_true_negatives {
                  unsampled_value {
                    value: 1.0
                  }
                }
                t_distribution_true_positives {
                  unsampled_value {
                    value: 2.0
                  }
                }
                t_distribution_false_positives {
                  unsampled_value {
                    value: 0.0
                  }
                }
                t_distribution_precision {
                  unsampled_value {
                    value: 1.0
                  }
                }
                t_distribution_recall {
                  unsampled_value {
                    value: 1.0
                  }
                }
              }
              matrices {
                threshold: 0.75
                false_negatives: 1.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 1.0
                precision: 1.0
                recall: 0.5
                bounded_false_negatives {
                  value {
                    value: 1.0
                  }
                }
                bounded_true_negatives {
                  value {
                    value: 1.0
                  }
                }
                bounded_true_positives {
                  value {
                    value: 1.0
                  }
                }
                bounded_false_positives {
                  value {
                    value: 0.0
                  }
                }
                bounded_precision {
                  value {
                    value: 1.0
                  }
                }
                bounded_recall {
                  value {
                    value: 0.5
                  }
                }
                t_distribution_false_negatives {
                  unsampled_value {
                    value: 1.0
                  }
                }
                t_distribution_true_negatives {
                  unsampled_value {
                    value: 1.0
                  }
                }
                t_distribution_true_positives {
                  unsampled_value {
                    value: 1.0
                  }
                }
                t_distribution_false_positives {
                  unsampled_value {
                    value: 0.0
                  }
                }
                t_distribution_precision {
                  unsampled_value {
                    value: 1.0
                  }
                }
                t_distribution_recall {
                  unsampled_value {
                    value: 0.5
                  }
                }
              }
              matrices {
                threshold: 1.00
                false_negatives: 2.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 0.0
                precision: nan
                recall: 0.0
                bounded_false_negatives {
                  value {
                    value: 2.0
                  }
                }
                bounded_true_negatives {
                  value {
                    value: 1.0
                  }
                }
                bounded_true_positives {
                  value {
                    value: 0.0
                  }
                }
                bounded_false_positives {
                  value {
                    value: 0.0
                  }
                }
                bounded_precision {
                  value {
                    value: nan
                  }
                }
                bounded_recall {
                  value {
                    value: 0.0
                  }
                }
                t_distribution_false_negatives {
                  unsampled_value {
                    value: 2.0
                  }
                }
                t_distribution_true_negatives {
                  unsampled_value {
                    value: 1.0
                  }
                }
                t_distribution_true_positives {
                  unsampled_value {
                    value: 0.0
                  }
                }
                t_distribution_false_positives {
                  unsampled_value {
                    value: 0.0
                  }
                }
                t_distribution_precision {
                  unsampled_value {
                    value: nan
                  }
                }
                t_distribution_recall {
                  unsampled_value {
                    value: 0.0
                  }
                }
              }
            }
          }
        }
        """
    expected_proto = text_format.Parse(
        expected_text, metrics_for_slice_pb2.MetricsForSlice())

    metric = post_export_metrics.confusion_matrix_at_thresholds(thresholds)
    actual_proto = (
        metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
            (slice_key, slice_metrics), [metric]))
    self.assertProtoEquals(expected_proto, actual_proto)
コード例 #4
0
  def testConfusionMatrixAtThresholdsSerialization(self):
    """Verifies confusion matrices at thresholds and their proto form.

    Evaluates three examples at thresholds 0.25, 0.75 and 1.00, checks the
    matrices and thresholds for the single slice whose key is `()`, then
    checks what `populate_stats_and_pop` writes into a `MetricsForSlice`
    metrics map.
    """
    export_base = self._getEvalExportDir()
    _, eval_export_dir = (
        fixed_prediction_estimator.simple_fixed_prediction_estimator(
            None, export_base))
    examples = [
        self._makeExample(prediction=prediction, label=label)
        for prediction, label in ((0.0000, 0.0000), (0.5000, 1.0000),
                                  (1.0000, 1.0000))
    ]

    metric = post_export_metrics.confusion_matrix_at_thresholds(
        thresholds=[0.25, 0.75, 1.00])

    expected_thresholds = (0.25, 0.75, 1.00)
    #            |      | ---- Threshold ----
    # true label | pred | 0.25 | 0.75 | 1.00
    #     -      | 0.0  | TN   | TN   | TN
    #     +      | 0.5  | TP   | FN   | FN
    #     +      | 1.0  | TP   | TP   | FN
    # One row per threshold; per the serialized proto checked below, each
    # row reads [FN, TN, FP, TP, precision, recall].
    expected_matrices = (
        [0.0, 1.0, 0.0, 2.0, 1.0, 1.0],
        [1.0, 1.0, 0.0, 1.0, 1.0, 0.5],
        [2.0, 1.0, 0.0, 0.0, float('nan'), 0.0],
    )

    def check_result(got):  # pylint: disable=invalid-name
      try:
        self.assertEqual(1, len(got), 'got: %s' % got)
        slice_key, metrics = got[0]
        self.assertEqual((), slice_key)

        matrices_key = metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_MATRICES
        self.assertIn(matrices_key, metrics)
        actual_matrices = metrics[matrices_key]
        for index, expected_row in enumerate(expected_matrices):
          self.assertSequenceAlmostEqual(actual_matrices[index], expected_row)

        thresholds_key = metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_THRESHOLDS
        self.assertIn(thresholds_key, metrics)
        actual_thresholds = metrics[thresholds_key]
        for index, expected_threshold in enumerate(expected_thresholds):
          self.assertAlmostEqual(expected_threshold, actual_thresholds[index])

        # Check serialization too.
        # Note that we can't just make this a dict, since proto maps
        # allow uninitialized key access, i.e. they act like defaultdicts.
        proto_metrics = metrics_for_slice_pb2.MetricsForSlice().metrics
        metric.populate_stats_and_pop(metrics, proto_metrics)
        self.assertProtoEquals(
            """
            confusion_matrix_at_thresholds {
              matrices {
                threshold: 0.25
                false_negatives: 0.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 2.0
                precision: 1.0
                recall: 1.0
              }
              matrices {
                threshold: 0.75
                false_negatives: 1.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 1.0
                precision: 1.0
                recall: 0.5
              }
              matrices {
                threshold: 1.00
                false_negatives: 2.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 0.0
                precision: nan
                recall: 0.0
              }
            }
            """, proto_metrics[metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS])
      except AssertionError as err:
        # Re-raised as a Beam assertion error for the pipeline check.
        raise util.BeamAssertException(err)

    self._runTestWithCustomCheck(
        examples,
        eval_export_dir, [metric],
        custom_metrics_check=check_result)