# Example 1

    def testServingGraphAlsoExportedIfSpecified(self):
        # Most of the example trainers also pass serving_input_receiver_fn to
        # export_eval_savedmodel, so the serving graph should be included.
        temp_eval_export_dir = self._getEvalExportDir()
        _, eval_export_dir = (
            fixed_prediction_estimator.simple_fixed_prediction_estimator(
                None, temp_eval_export_dir))

        # Check the eval graph.
        eval_saved_model = load.EvalSavedModel(eval_export_dir)
        example1 = self._makeExample(prediction=0.9,
                                     label=0.0).SerializeToString()
        eval_saved_model.metrics_reset_update_get(example1)

        metric_values = eval_saved_model.get_metric_values()
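        # The expected average_loss below is the squared error of the single
        # example: (0.9 - 0.0)**2 = 0.81.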
        self.assertDictElementsAlmostEqual(metric_values,
                                           {'average_loss': 0.81})

        # Check the serving graph.
        # TODO(b/124466113): Remove tf.compat.v2 once TF 2.0 is the default.
        if hasattr(tf, 'compat') and hasattr(tf.compat, 'v2'):
            imported = tf.compat.v2.saved_model.load(
                eval_export_dir, tags=tf.saved_model.SERVING)
            predictions = imported.signatures[
                tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY](
                    inputs=tf.constant([example1]))
            self.assertAllClose(predictions['outputs'], np.array([[0.9]]))

# Example 2

    def testServingGraphAlsoExportedIfSpecified(self):
        # Most of the example trainers also pass serving_input_receiver_fn to
        # export_eval_savedmodel, so the serving graph should be included.
        temp_eval_export_dir = self._getEvalExportDir()
        _, eval_export_dir = (
            fixed_prediction_estimator.simple_fixed_prediction_estimator(
                None, temp_eval_export_dir))

        # Check the eval graph.
        eval_saved_model = load.EvalSavedModel(eval_export_dir)
        example1 = self._makeExample(prediction=0.9, label=0.0)
        features_predictions_labels = self.predict_injective_single_example(
            eval_saved_model, example1.SerializeToString())
        eval_saved_model.perform_metrics_update(features_predictions_labels)

        metric_values = eval_saved_model.get_metric_values()
        self.assertDictElementsAlmostEqual(metric_values,
                                           {'average_loss': 0.81})

        # Check the serving graph.
        estimator = tf.contrib.estimator.SavedModelEstimator(eval_export_dir)

        def predict_input_fn():
            return {'inputs': tf.constant([example1.SerializeToString()])}

        predictions = next(estimator.predict(predict_input_fn))
        self.assertAllClose(predictions['outputs'], np.array([0.9]))

# Example 3

    def testAucUnweighted(self):
        temp_eval_export_dir = self._getEvalExportDir()
        _, eval_export_dir = (
            fixed_prediction_estimator.simple_fixed_prediction_estimator(
                None, temp_eval_export_dir))
        examples = [
            self._makeExample(prediction=0.0000, label=0.0000),
            self._makeExample(prediction=0.0000, label=1.0000),
            self._makeExample(prediction=0.7000, label=1.0000),
            self._makeExample(prediction=0.8000, label=0.0000),
            self._makeExample(prediction=1.0000, label=1.0000),
        ]

        expected_values_dict = {
            metric_keys.AUC: 0.58333,
            metric_keys.lower_bound(metric_keys.AUC): 0.5,
            metric_keys.upper_bound(metric_keys.AUC): 0.66667,
            metric_keys.AUPRC: 0.74075,
            metric_keys.lower_bound(metric_keys.AUPRC): 0.70000,
            metric_keys.upper_bound(metric_keys.AUPRC): 0.77778,
        }
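        # The AUC of 0.58333 matches the pairwise-ranking view: of the 6
        # positive/negative prediction pairs, 3 are ordered correctly and the
        # tie at prediction 0.0 counts as 0.5, giving 3.5 / 6.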

        self._runTest(
            examples, eval_export_dir,
            [post_export_metrics.auc(),
             post_export_metrics.auc(curve='PR')], expected_values_dict)

# Example 4

    def testCalibrationPlotAndPredictionHistogramUnweighted(self):
        temp_eval_export_dir = self._getEvalExportDir()
        _, eval_export_dir = (
            fixed_prediction_estimator.simple_fixed_prediction_estimator(
                None, temp_eval_export_dir))
        examples = [
            # For each example, we set label to prediction + 1.
            # These two go in bucket 0: (-inf, 0)
            self._makeExample(prediction=-10.0, label=-9.0),
            self._makeExample(prediction=-9.0, label=-8.0),
            # This goes in bucket 1: [0.00000, 0.00010)
            self._makeExample(prediction=0.00000, label=1.00000),
            # These three go in bucket 11: [0.00100, 0.00110)
            self._makeExample(prediction=0.00100, label=1.00100),
            self._makeExample(prediction=0.00101, label=1.00101),
            self._makeExample(prediction=0.00102, label=1.00102),
            # These two go in bucket 10000: [0.99990, 1.00000)
            self._makeExample(prediction=0.99998, label=1.99998),
            self._makeExample(prediction=0.99999, label=1.99999),
            # These four go in bucket 10001: [1.0000, +inf)
            self._makeExample(prediction=1.0, label=2.0),
            self._makeExample(prediction=8.0, label=9.0),
            self._makeExample(prediction=9.0, label=10.0),
            self._makeExample(prediction=10.0, label=11.0),
        ]

        def check_result(got):  # pylint: disable=invalid-name
            try:
                self.assertEqual(1, len(got), 'got: %s' % got)
                (slice_key, value) = got[0]
                self.assertEqual((), slice_key)
                self.assertIn(metric_keys.CALIBRATION_PLOT_MATRICES, value)
                buckets = value[metric_keys.CALIBRATION_PLOT_MATRICES]
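                # Each bucket row is [total prediction, total label,
                # example count]; e.g. bucket 0 sums the two examples with
                # predictions -10 and -9.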
                self.assertSequenceAlmostEqual(buckets[0], [-19.0, -17.0, 2.0])
                self.assertSequenceAlmostEqual(buckets[1], [0.0, 1.0, 1.0])
                self.assertSequenceAlmostEqual(buckets[11],
                                               [0.00303, 3.00303, 3.0])
                self.assertSequenceAlmostEqual(buckets[10000],
                                               [1.99997, 3.99997, 2.0])
                self.assertSequenceAlmostEqual(buckets[10001],
                                               [28.0, 32.0, 4.0])
                self.assertIn(metric_keys.CALIBRATION_PLOT_BOUNDARIES, value)
                boundaries = value[metric_keys.CALIBRATION_PLOT_BOUNDARIES]
                self.assertAlmostEqual(0.0, boundaries[0])
                self.assertAlmostEqual(0.001, boundaries[10])
                self.assertAlmostEqual(0.005, boundaries[50])
                self.assertAlmostEqual(0.010, boundaries[100])
                self.assertAlmostEqual(0.100, boundaries[1000])
                self.assertAlmostEqual(0.800, boundaries[8000])
                self.assertAlmostEqual(1.000, boundaries[10000])
            except AssertionError as err:
                raise util.BeamAssertException(err)

        self._runTestWithCustomCheck(
            examples,
            eval_export_dir,
            [post_export_metrics.calibration_plot_and_prediction_histogram()],
            custom_plots_check=check_result)

# Example 5

  def testEvaluateWithPlots(self):
    temp_eval_export_dir = self._getEvalExportDir()
    _, eval_export_dir = (
        fixed_prediction_estimator.simple_fixed_prediction_estimator(
            None, temp_eval_export_dir))

    with beam.Pipeline() as pipeline:
      example1 = self._makeExample(prediction=0.0, label=1.0)
      example2 = self._makeExample(prediction=0.7, label=0.0)
      example3 = self._makeExample(prediction=0.8, label=1.0)
      example4 = self._makeExample(prediction=1.0, label=1.0)

      metrics, plots = (
          pipeline
          | beam.Create([
              example1.SerializeToString(),
              example2.SerializeToString(),
              example3.SerializeToString(),
              example4.SerializeToString()
          ])
          | evaluate.Evaluate(
              eval_saved_model_path=eval_export_dir,
              add_metrics_callbacks=[
                  post_export_metrics.example_count(),
                  post_export_metrics.auc_plots()
              ]))

      def check_metrics(got):
        try:
          self.assertEqual(1, len(got), 'got: %s' % got)
          (slice_key, value) = got[0]
          self.assertEqual((), slice_key)
          self.assertDictElementsAlmostEqual(
              got_values_dict=value,
              expected_values_dict={
                  metric_keys.EXAMPLE_COUNT: 4.0,
              })
        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(metrics, check_metrics, label='metrics')

      def check_plots(got):
        try:
          self.assertEqual(1, len(got), 'got: %s' % got)
          (slice_key, value) = got[0]
          self.assertEqual((), slice_key)
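          # Each AUC plot row is [FN, TN, FP, TP, precision, recall]; bucket
          # 8001 corresponds to the 0.8 threshold.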
          self.assertDictMatrixRowsAlmostEqual(
              got_values_dict=value,
              expected_values_dict={
                  metric_keys.AUC_PLOTS_MATRICES: [(8001, [
                      2, 1, 0, 1, 1.0 / 1.0, 1.0 / 3.0
                  ])],
              })
        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(plots, check_plots, label='plots')

# Example 6

    def testAucUnweightedSerialization(self):
        temp_eval_export_dir = self._getEvalExportDir()
        _, eval_export_dir = (
            fixed_prediction_estimator.simple_fixed_prediction_estimator(
                None, temp_eval_export_dir))
        examples = [
            self._makeExample(prediction=0.0000, label=0.0000),
            self._makeExample(prediction=0.0000, label=1.0000),
            self._makeExample(prediction=0.7000, label=1.0000),
            self._makeExample(prediction=0.8000, label=0.0000),
            self._makeExample(prediction=1.0000, label=1.0000),
        ]

        expected_values_dict = {
            metric_keys.AUPRC: 0.74075,
            metric_keys.lower_bound(metric_keys.AUPRC): 0.70000,
            metric_keys.upper_bound(metric_keys.AUPRC): 0.77778,
        }

        auc_metric = post_export_metrics.auc(curve='PR')

        def check_result(got):  # pylint: disable=invalid-name
            try:
                self.assertEqual(1, len(got), 'got: %s' % got)
                (slice_key, value) = got[0]
                self.assertEqual((), slice_key)
                self.assertDictElementsAlmostEqual(value, expected_values_dict)

                # Check serialization too.
                # Note that we can't just make this a dict, since proto maps
                # allow uninitialized key access, i.e. they act like defaultdicts.
                output_metrics = (
                    metrics_for_slice_pb2.MetricsForSlice().metrics)
                auc_metric.populate_stats_and_pop(value, output_metrics)
                self.assertProtoEquals(
                    """
            bounded_value {
              lower_bound {
                value: 0.6999999
              }
              upper_bound {
                value: 0.7777776
              }
              value {
                value: 0.7407472
              }
            }
            """, output_metrics[metric_keys.AUPRC])
            except AssertionError as err:
                raise util.BeamAssertException(err)

        self._runTestWithCustomCheck(examples,
                                     eval_export_dir, [auc_metric],
                                     custom_metrics_check=check_result)

# Example 7

    def testAucPlotsUnweighted(self):
        temp_eval_export_dir = self._getEvalExportDir()
        _, eval_export_dir = (
            fixed_prediction_estimator.simple_fixed_prediction_estimator(
                None, temp_eval_export_dir))
        examples = [
            self._makeExample(prediction=0.0000, label=0.0000),
            self._makeExample(prediction=0.0000, label=1.0000),
            self._makeExample(prediction=0.7000, label=1.0000),
            self._makeExample(prediction=0.8000, label=0.0000),
            self._makeExample(prediction=1.0000, label=1.0000),
        ]

        def check_result(got):  # pylint: disable=invalid-name
            try:
                self.assertEqual(1, len(got), 'got: %s' % got)
                (slice_key, value) = got[0]
                self.assertEqual((), slice_key)
                self.assertIn(metric_keys.AUC_PLOTS_MATRICES, value)
                matrices = value[metric_keys.AUC_PLOTS_MATRICES]
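                # Each row below is [FN, TN, FP, TP, precision, recall] at the
                # corresponding threshold.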
                #            |      | --------- Threshold -----------
                # true label | pred | -1e-6 | 0.0 | 0.7 | 0.8 | 1.0
                #     -      | 0.0  | FP    | TN  | TN  | TN  | TN
                #     +      | 0.0  | TP    | FN  | FN  | FN  | FN
                #     +      | 0.7  | TP    | TP  | FN  | FN  | FN
                #     -      | 0.8  | FP    | FP  | FP  | TN  | TN
                #     +      | 1.0  | TP    | TP  | TP  | TP  | FN
                self.assertSequenceAlmostEqual(matrices[0],
                                               [0, 0, 2, 3, 3.0 / 5.0, 1.0])
                self.assertSequenceAlmostEqual(
                    matrices[1], [1, 1, 1, 2, 2.0 / 3.0, 2.0 / 3.0])
                self.assertSequenceAlmostEqual(
                    matrices[7001], [2, 1, 1, 1, 1.0 / 2.0, 1.0 / 3.0])
                self.assertSequenceAlmostEqual(
                    matrices[8001], [2, 2, 0, 1, 1.0 / 1.0, 1.0 / 3.0])
                self.assertSequenceAlmostEqual(
                    matrices[10001],
                    [3, 2, 0, 0, float('nan'), 0.0])
                self.assertIn(metric_keys.AUC_PLOTS_THRESHOLDS, value)
                thresholds = value[metric_keys.AUC_PLOTS_THRESHOLDS]
                self.assertAlmostEqual(0.0, thresholds[1])
                self.assertAlmostEqual(0.001, thresholds[11])
                self.assertAlmostEqual(0.005, thresholds[51])
                self.assertAlmostEqual(0.010, thresholds[101])
                self.assertAlmostEqual(0.100, thresholds[1001])
                self.assertAlmostEqual(0.800, thresholds[8001])
                self.assertAlmostEqual(1.000, thresholds[10001])
            except AssertionError as err:
                raise util.BeamAssertException(err)

        self._runTestWithCustomCheck(examples,
                                     eval_export_dir,
                                     [post_export_metrics.auc_plots()],
                                     custom_plots_check=check_result)

# Example 8

 def testAssertMetricsComputedWithBeamAre(self):
   temp_eval_export_dir = self._getEvalExportDir()
   _, eval_export_dir = (
       fixed_prediction_estimator.simple_fixed_prediction_estimator(
           None, temp_eval_export_dir))
   examples = [
       self.makeExample(prediction=0.0, label=1.0),
       self.makeExample(prediction=0.7, label=0.0),
       self.makeExample(prediction=0.8, label=1.0),
       self.makeExample(prediction=1.0, label=1.0)
   ]
   self.assertMetricsComputedWithBeamAre(
       eval_saved_model_path=eval_export_dir,
       serialized_examples=examples,
       expected_metrics={'average_loss': (1.0 + 0.49 + 0.04 + 0.00) / 4.0})
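
# Example 9
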
 def testExampleCountNoStandardKeys(self):
   # Test ExampleCount with a custom Estimator that doesn't have any of the
   # standard PredictionKeys.
   temp_eval_export_dir = self._getEvalExportDir()
   _, eval_export_dir = (
       fixed_prediction_estimator.simple_fixed_prediction_estimator(
           None, temp_eval_export_dir, output_prediction_key='non_standard'))
   examples = [
       self._makeExample(prediction=5.0, label=5.0),
       self._makeExample(prediction=6.0, label=6.0),
       self._makeExample(prediction=7.0, label=7.0),
   ]
   expected_values_dict = {
       metric_keys.EXAMPLE_COUNT: 3.0,
   }
   self._runTest(examples, eval_export_dir, [
       post_export_metrics.example_count(),
   ], expected_values_dict)

# Example 10

 def testExampleCountEmptyPredictionsDict(self):
     # Test ExampleCount with a custom Estimator that has empty predictions dict.
     # This is possible if the Estimator doesn't return the predictions dict
     # in EVAL mode, but computes predictions and feeds them into the metrics
     # internally.
     temp_eval_export_dir = self._getEvalExportDir()
     _, eval_export_dir = (
         fixed_prediction_estimator.simple_fixed_prediction_estimator(
             None, temp_eval_export_dir, output_prediction_key=None))
     examples = [
         self._makeExample(prediction=5.0, label=5.0),
         self._makeExample(prediction=6.0, label=6.0),
         self._makeExample(prediction=7.0, label=7.0),
     ]
     expected_values_dict = {
         metric_keys.EXAMPLE_COUNT: 3.0,
     }
     self._runTest(examples, eval_export_dir, [
         post_export_metrics.example_count(),
     ], expected_values_dict)

# Example 11

    def testBoundedValueChecks(self):
        temp_eval_export_dir = self._getEvalExportDir()
        _, eval_export_dir = (
            fixed_prediction_estimator.simple_fixed_prediction_estimator(
                None, temp_eval_export_dir))
        examples = [
            self.makeExample(prediction=0.8, label=1.0),
        ]
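        # The squared error for this single example is (0.8 - 1.0)**2 = 0.04,
        # which is what average_loss reports below.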

        self.assertMetricsComputedWithBeamAre(
            eval_saved_model_path=eval_export_dir,
            serialized_examples=examples,
            expected_metrics={'average_loss': 0.04})

        self.assertMetricsComputedWithoutBeamAre(
            eval_saved_model_path=eval_export_dir,
            serialized_examples=examples,
            expected_metrics={
                'average_loss':
                tfma_unit.BoundedValue(lower_bound=0.03, upper_bound=0.05)
            })

        with self.assertRaisesRegexp(
                AssertionError,
                'expecting key average_loss to have value between'):
            self.assertMetricsComputedWithoutBeamAre(
                eval_saved_model_path=eval_export_dir,
                serialized_examples=examples,
                expected_metrics={
                    'average_loss': tfma_unit.BoundedValue(upper_bound=0.01)
                })

        with self.assertRaisesRegexp(
                AssertionError,
                'expecting key average_loss to have value between'):
            self.assertMetricsComputedWithoutBeamAre(
                eval_saved_model_path=eval_export_dir,
                serialized_examples=examples,
                expected_metrics={
                    'average_loss': tfma_unit.BoundedValue(lower_bound=0.10)
                })
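
# Example 12
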
    def testEvaluateWithPlots(self):
        temp_eval_export_dir = self._getEvalExportDir()
        _, eval_export_dir = (
            fixed_prediction_estimator.simple_fixed_prediction_estimator(
                None, temp_eval_export_dir))
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=eval_export_dir,
            add_metrics_callbacks=[
                post_export_metrics.example_count(),
                post_export_metrics.auc_plots()
            ])
        extractors = [
            predict_extractor.PredictExtractor(eval_shared_model),
            slice_key_extractor.SliceKeyExtractor()
        ]

        with beam.Pipeline() as pipeline:
            example1 = self._makeExample(prediction=0.0, label=1.0)
            example2 = self._makeExample(prediction=0.7, label=0.0)
            example3 = self._makeExample(prediction=0.8, label=1.0)
            example4 = self._makeExample(prediction=1.0, label=1.0)

            metrics, plots = (
                pipeline
                | 'Create' >> beam.Create([
                    example1.SerializeToString(),
                    example2.SerializeToString(),
                    example3.SerializeToString(),
                    example4.SerializeToString()
                ])
                | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
                | 'Extract' >> tfma_unit.Extract(extractors=extractors)  # pylint: disable=no-value-for-parameter
                | 'ComputeMetricsAndPlots' >> metrics_and_plots_evaluator.
                ComputeMetricsAndPlots(eval_shared_model=eval_shared_model))

            def check_metrics(got):
                try:
                    self.assertEqual(1, len(got), 'got: %s' % got)
                    (slice_key, value) = got[0]
                    self.assertEqual((), slice_key)
                    self.assertDictElementsAlmostEqual(
                        got_values_dict=value,
                        expected_values_dict={
                            metric_keys.EXAMPLE_COUNT: 4.0,
                        })
                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(metrics, check_metrics, label='metrics')

            def check_plots(got):
                try:
                    self.assertEqual(1, len(got), 'got: %s' % got)
                    (slice_key, value) = got[0]
                    self.assertEqual((), slice_key)
                    self.assertDictMatrixRowsAlmostEqual(
                        got_values_dict=value,
                        expected_values_dict={
                            _full_key(metric_keys.AUC_PLOTS_MATRICES):
                            [(8001, [2, 1, 0, 1, 1.0 / 1.0, 1.0 / 3.0])],
                        })
                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(plots, check_plots, label='plots')
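
# Example 13
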
  def testWriteMetricsAndPlots(self):
    metrics_file = os.path.join(self._getTempDir(), 'metrics')
    plots_file = os.path.join(self._getTempDir(), 'plots')
    temp_eval_export_dir = os.path.join(self._getTempDir(), 'eval_export_dir')

    _, eval_export_dir = (
        fixed_prediction_estimator.simple_fixed_prediction_estimator(
            None, temp_eval_export_dir))
    eval_config = config.EvalConfig(
        model_specs=[config.ModelSpec()],
        options=config.Options(
            disabled_outputs={'values': ['eval_config.json']}))
    eval_shared_model = self.createTestEvalSharedModel(
        eval_saved_model_path=eval_export_dir,
        add_metrics_callbacks=[
            post_export_metrics.example_count(),
            post_export_metrics.calibration_plot_and_prediction_histogram(
                num_buckets=2)
        ])
    extractors = [
        predict_extractor.PredictExtractor(eval_shared_model),
        slice_key_extractor.SliceKeyExtractor()
    ]
    evaluators = [
        metrics_and_plots_evaluator.MetricsAndPlotsEvaluator(eval_shared_model)
    ]
    output_paths = {
        constants.METRICS_KEY: metrics_file,
        constants.PLOTS_KEY: plots_file
    }
    writers = [
        metrics_plots_and_validations_writer.MetricsPlotsAndValidationsWriter(
            output_paths, eval_shared_model.add_metrics_callbacks)
    ]

    with beam.Pipeline() as pipeline:
      example1 = self._makeExample(prediction=0.0, label=1.0)
      example2 = self._makeExample(prediction=1.0, label=1.0)

      # pylint: disable=no-value-for-parameter
      _ = (
          pipeline
          | 'Create' >> beam.Create([
              example1.SerializeToString(),
              example2.SerializeToString(),
          ])
          | 'ExtractEvaluateAndWriteResults' >>
          model_eval_lib.ExtractEvaluateAndWriteResults(
              eval_config=eval_config,
              eval_shared_model=eval_shared_model,
              extractors=extractors,
              evaluators=evaluators,
              writers=writers))
      # pylint: enable=no-value-for-parameter

    expected_metrics_for_slice = text_format.Parse(
        """
        slice_key {}
        metrics {
          key: "average_loss"
          value {
            double_value {
              value: 0.5
            }
          }
        }
        metrics {
          key: "post_export_metrics/example_count"
          value {
            double_value {
              value: 2.0
            }
          }
        }
        """, metrics_for_slice_pb2.MetricsForSlice())

    metric_records = []
    for record in tf.compat.v1.python_io.tf_record_iterator(metrics_file):
      metric_records.append(
          metrics_for_slice_pb2.MetricsForSlice.FromString(record))
    self.assertEqual(1, len(metric_records), 'metrics: %s' % metric_records)
    self.assertProtoEquals(expected_metrics_for_slice, metric_records[0])

    expected_plots_for_slice = text_format.Parse(
        """
      slice_key {}
      plots {
        key: "post_export_metrics"
        value {
          calibration_histogram_buckets {
            buckets {
              lower_threshold_inclusive: -inf
              num_weighted_examples {}
              total_weighted_label {}
              total_weighted_refined_prediction {}
            }
            buckets {
              upper_threshold_exclusive: 0.5
              num_weighted_examples {
                value: 1.0
              }
              total_weighted_label {
                value: 1.0
              }
              total_weighted_refined_prediction {}
            }
            buckets {
              lower_threshold_inclusive: 0.5
              upper_threshold_exclusive: 1.0
              num_weighted_examples {
              }
              total_weighted_label {}
              total_weighted_refined_prediction {}
            }
            buckets {
              lower_threshold_inclusive: 1.0
              upper_threshold_exclusive: inf
              num_weighted_examples {
                value: 1.0
              }
              total_weighted_label {
                value: 1.0
              }
              total_weighted_refined_prediction {
                value: 1.0
              }
            }
         }
        }
      }
    """, metrics_for_slice_pb2.PlotsForSlice())

    plot_records = []
    for record in tf.compat.v1.python_io.tf_record_iterator(plots_file):
      plot_records.append(
          metrics_for_slice_pb2.PlotsForSlice.FromString(record))
    self.assertEqual(1, len(plot_records), 'plots: %s' % plot_records)
    self.assertProtoEquals(expected_plots_for_slice, plot_records[0])

# Example 14

    def testConfusionMatrixAtThresholdsSerialization(self):
        temp_eval_export_dir = self._getEvalExportDir()
        _, eval_export_dir = (
            fixed_prediction_estimator.simple_fixed_prediction_estimator(
                None, temp_eval_export_dir))
        examples = [
            self._makeExample(prediction=0.0000, label=0.0000),
            self._makeExample(prediction=0.5000, label=1.0000),
            self._makeExample(prediction=1.0000, label=1.0000),
        ]

        confusion_matrix_at_thresholds_metric = (
            post_export_metrics.confusion_matrix_at_thresholds(
                thresholds=[0.25, 0.75, 1.00]))

        def check_result(got):  # pylint: disable=invalid-name
            try:
                self.assertEqual(1, len(got), 'got: %s' % got)
                (slice_key, value) = got[0]
                self.assertEqual((), slice_key)
                self.assertIn(
                    metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_MATRICES, value)
                matrices = value[
                    metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_MATRICES]
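                # Each row is [FN, TN, FP, TP, precision, recall], matching the
                # serialized proto checked below.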
                #            |      | ---- Threshold ----
                # true label | pred | 0.25 | 0.75 | 1.00
                #     -      | 0.0  | TN   | TN   | TN
                #     +      | 0.5  | TP   | FN   | FN
                #     +      | 1.0  | TP   | TP   | FN
                self.assertSequenceAlmostEqual(matrices[0],
                                               [0.0, 1.0, 0.0, 2.0, 1.0, 1.0])
                self.assertSequenceAlmostEqual(matrices[1],
                                               [1.0, 1.0, 0.0, 1.0, 1.0, 0.5])
                self.assertSequenceAlmostEqual(
                    matrices[2],
                    [2.0, 1.0, 0.0, 0.0, float('nan'), 0.0])
                self.assertIn(
                    metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_THRESHOLDS,
                    value)
                thresholds = value[
                    metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS_THRESHOLDS]
                self.assertAlmostEqual(0.25, thresholds[0])
                self.assertAlmostEqual(0.75, thresholds[1])
                self.assertAlmostEqual(1.00, thresholds[2])

                # Check serialization too.
                # Note that we can't just make this a dict, since proto maps
                # allow uninitialized key access, i.e. they act like defaultdicts.
                output_metrics = (
                    metrics_for_slice_pb2.MetricsForSlice().metrics)
                confusion_matrix_at_thresholds_metric.populate_stats_and_pop(
                    value, output_metrics)
                self.assertProtoEquals(
                    """
            confusion_matrix_at_thresholds {
              matrices {
                threshold: 0.25
                false_negatives: 0.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 2.0
                precision: 1.0
                recall: 1.0
              }
              matrices {
                threshold: 0.75
                false_negatives: 1.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 1.0
                precision: 1.0
                recall: 0.5
              }
              matrices {
                threshold: 1.00
                false_negatives: 2.0
                true_negatives: 1.0
                false_positives: 0.0
                true_positives: 0.0
                precision: nan
                recall: 0.0
              }
            }
            """, output_metrics[metric_keys.CONFUSION_MATRIX_AT_THRESHOLDS])
            except AssertionError as err:
                raise util.BeamAssertException(err)

        self._runTestWithCustomCheck(examples,
                                     eval_export_dir,
                                     [confusion_matrix_at_thresholds_metric],
                                     custom_metrics_check=check_result)
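
    # The helper below writes TFMA metrics and plots under `tfma_path` and,
    # when an MLMD `store` is given, registers the output as a ModelEvaluation
    # artifact.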
    def _write_tfma(self,
                    tfma_path: str,
                    output_file_format: str,
                    store: Optional[mlmd.MetadataStore] = None):
        _, eval_saved_model_path = (
            fixed_prediction_estimator.simple_fixed_prediction_estimator(
                export_path=None,
                eval_export_path=os.path.join(self.tmpdir, 'eval_export_dir')))
        eval_config = tfma.EvalConfig(model_specs=[tfma.ModelSpec()])
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=eval_saved_model_path,
            add_metrics_callbacks=[
                tfma.post_export_metrics.example_count(),
                tfma.post_export_metrics.
                calibration_plot_and_prediction_histogram(num_buckets=2)
            ])
        extractors = [
            tfma.extractors.legacy_predict_extractor.PredictExtractor(
                eval_shared_model, eval_config=eval_config),
            tfma.extractors.unbatch_extractor.UnbatchExtractor(),
            tfma.extractors.slice_key_extractor.SliceKeyExtractor()
        ]
        evaluators = [
            tfma.evaluators.legacy_metrics_and_plots_evaluator.
            MetricsAndPlotsEvaluator(eval_shared_model)
        ]
        writers = [
            tfma.writers.MetricsPlotsAndValidationsWriter(
                output_paths={
                    'metrics': os.path.join(tfma_path, 'metrics'),
                    'plots': os.path.join(tfma_path, 'plots')
                },
                output_file_format=output_file_format,
                eval_config=eval_config,
                add_metrics_callbacks=eval_shared_model.add_metrics_callbacks)
        ]

        tfx_io = raw_tf_record.RawBeamRecordTFXIO(
            physical_format='inmemory',
            raw_record_column_name='__raw_record__',
            telemetry_descriptors=['TFMATest'])
        with beam.Pipeline() as pipeline:
            example1 = self._makeExample(prediction=0.0, label=1.0)
            example2 = self._makeExample(prediction=1.0, label=1.0)
            _ = (pipeline
                 | 'Create' >> beam.Create([
                     example1.SerializeToString(),
                     example2.SerializeToString(),
                 ])
                 | 'BatchExamples' >> tfx_io.BeamSource()
                 | 'ExtractEvaluateAndWriteResults' >>
                 tfma.ExtractEvaluateAndWriteResults(
                     eval_config=eval_config,
                     eval_shared_model=eval_shared_model,
                     extractors=extractors,
                     evaluators=evaluators,
                     writers=writers))

        if store:
            eval_type = metadata_store_pb2.ArtifactType()
            eval_type.name = standard_artifacts.ModelEvaluation.TYPE_NAME
            eval_type_id = store.put_artifact_type(eval_type)

            artifact = metadata_store_pb2.Artifact()
            artifact.uri = tfma_path
            artifact.type_id = eval_type_id
            store.put_artifacts([artifact])