Example #1
    def testRunModelAnalysisWithMultiplePlots(self):
        model_location = self._exportEvalSavedModel(
            fixed_prediction_estimator.simple_fixed_prediction_estimator)
        examples = [
            self._makeExample(prediction=0.0, label=1.0),
            self._makeExample(prediction=0.7, label=0.0),
            self._makeExample(prediction=0.8, label=1.0),
            self._makeExample(prediction=1.0, label=1.0),
            self._makeExample(prediction=1.0, label=1.0)
        ]
        data_location = self._writeTFExamplesToTFRecords(examples)
        eval_config = config.EvalConfig(
            input_data_specs=[config.InputDataSpec(location=data_location)],
            model_specs=[config.ModelSpec(location=model_location)],
            output_data_specs=[
                config.OutputDataSpec(default_location=self._getTempDir())
            ])
        eval_shared_model = model_eval_lib.default_eval_shared_model(
            eval_saved_model_path=model_location,
            add_metrics_callbacks=[
                post_export_metrics.auc_plots(),
                post_export_metrics.auc_plots(metric_tag='test')
            ])
        eval_result = model_eval_lib.run_model_analysis(
            eval_config=eval_config, eval_shared_models=[eval_shared_model])

        # We only check some of the metrics to ensure that the end-to-end
        # pipeline works.
        expected_metrics = {
            (): {
                metric_keys.EXAMPLE_COUNT: {
                    'doubleValue': 5.0
                },
            }
        }
        expected_matrix = {
            'threshold': 0.8,
            'falseNegatives': 2.0,
            'trueNegatives': 1.0,
            'truePositives': 2.0,
            'precision': 1.0,
            'recall': 0.5
        }
        self.assertMetricsAlmostEqual(eval_result.slicing_metrics,
                                      expected_metrics)
        self.assertEqual(len(eval_result.plots), 1)
        slice_key, plots = eval_result.plots[0]
        self.assertEqual((), slice_key)
        tf.compat.v1.logging.info(plots.keys())
        self.assertDictElementsAlmostEqual(
            plots['']['']['post_export_metrics']['confusionMatrixAtThresholds']
            ['matrices'][8001], expected_matrix)
        self.assertDictElementsAlmostEqual(
            plots['']['']['post_export_metrics/test']
            ['confusionMatrixAtThresholds']['matrices'][8001], expected_matrix)
 def testRunModelAnalysisWithMultiplePlots(self):
     model_location = self._exportEvalSavedModel(
         fixed_prediction_estimator.simple_fixed_prediction_estimator)
     examples = [
         self._makeExample(prediction=0.0, label=1.0),
         self._makeExample(prediction=0.7, label=0.0),
         self._makeExample(prediction=0.8, label=1.0),
         self._makeExample(prediction=1.0, label=1.0),
         self._makeExample(prediction=1.0, label=1.0)
     ]
     eval_shared_model = model_eval_lib.default_eval_shared_model(
         eval_saved_model_path=model_location,
         add_metrics_callbacks=[
             post_export_metrics.auc_plots(),
             post_export_metrics.auc_plots(metric_tag='test')
         ])
     data_location = self._writeTFExamplesToTFRecords(examples)
     eval_result = model_eval_lib.run_model_analysis(
         eval_shared_model, data_location)
     # We only check some of the metrics to ensure that the end-to-end
     # pipeline works.
     expected_metrics = {
         (): {
             metric_keys.EXAMPLE_COUNT: {
                 'doubleValue': 5.0
             },
         }
     }
     expected_matrix = {
         'threshold': 0.8,
         'falseNegatives': 2.0,
         'trueNegatives': 1.0,
         'truePositives': 2.0,
         'precision': 1.0,
         'recall': 0.5
     }
     self.assertMetricsAlmostEqual(eval_result.slicing_metrics,
                                   expected_metrics)
     self.assertEqual(len(eval_result.plots), 1)
     slice_key, plots = eval_result.plots[0]
     self.assertEqual((), slice_key)
     tf.logging.info(plots.keys())
     self.assertDictElementsAlmostEqual(
         plots['post_export_metrics']['confusionMatrixAtThresholds']
         ['matrices'][8001], expected_matrix)
     self.assertDictElementsAlmostEqual(
         plots['post_export_metrics/test']['confusionMatrixAtThresholds']
         ['matrices'][8001], expected_matrix)
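
Both versions of testRunModelAnalysisWithMultiplePlots above assert the same expected_matrix for the 0.8 threshold bucket. As a quick sanity check, those counts can be rederived by hand from the five (prediction, label) pairs the test writes out. The sketch below is a standalone illustration, not TFMA code, and it assumes an example counts as predicted positive when its prediction is strictly greater than the threshold.

# Standalone sketch: rederive expected_matrix from the test's five examples,
# assuming prediction > threshold means "predicted positive".
examples = [(0.0, 1.0), (0.7, 0.0), (0.8, 1.0), (1.0, 1.0), (1.0, 1.0)]
threshold = 0.8
tp = sum(p > threshold and l == 1.0 for p, l in examples)   # 2 (the two 1.0 predictions)
fp = sum(p > threshold and l == 0.0 for p, l in examples)   # 0
tn = sum(p <= threshold and l == 0.0 for p, l in examples)  # 1 (prediction 0.7)
fn = sum(p <= threshold and l == 1.0 for p, l in examples)  # 2 (predictions 0.0 and 0.8)
precision = tp / (tp + fp)  # 2 / 2 = 1.0
recall = tp / (tp + fn)     # 2 / 4 = 0.5

These values line up with the falseNegatives, trueNegatives, truePositives, precision, and recall fields asserted above.
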
    def testEvaluateWithPlots(self):
        temp_eval_export_dir = self._getEvalExportDir()
        _, eval_export_dir = (
            fixed_prediction_estimator.simple_fixed_prediction_estimator(
                None, temp_eval_export_dir))
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=eval_export_dir,
            add_metrics_callbacks=[
                post_export_metrics.example_count(),
                post_export_metrics.auc_plots()
            ])
        extractors = [
            predict_extractor.PredictExtractor(eval_shared_model),
            slice_key_extractor.SliceKeyExtractor()
        ]

        with beam.Pipeline() as pipeline:
            example1 = self._makeExample(prediction=0.0, label=1.0)
            example2 = self._makeExample(prediction=0.7, label=0.0)
            example3 = self._makeExample(prediction=0.8, label=1.0)
            example4 = self._makeExample(prediction=1.0, label=1.0)

            metrics, plots = (
                pipeline
                | 'Create' >> beam.Create([
                    example1.SerializeToString(),
                    example2.SerializeToString(),
                    example3.SerializeToString(),
                    example4.SerializeToString()
                ])
                | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
                | 'Extract' >> tfma_unit.Extract(extractors=extractors)  # pylint: disable=no-value-for-parameter
                | 'ComputeMetricsAndPlots' >> metrics_and_plots_evaluator.
                ComputeMetricsAndPlots(eval_shared_model=eval_shared_model))

            def check_metrics(got):
                try:
                    self.assertEqual(1, len(got), 'got: %s' % got)
                    (slice_key, value) = got[0]
                    self.assertEqual((), slice_key)
                    self.assertDictElementsAlmostEqual(
                        got_values_dict=value,
                        expected_values_dict={
                            metric_keys.EXAMPLE_COUNT: 4.0,
                        })
                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(metrics, check_metrics, label='metrics')

            def check_plots(got):
                try:
                    self.assertEqual(1, len(got), 'got: %s' % got)
                    (slice_key, value) = got[0]
                    self.assertEqual((), slice_key)
                    self.assertDictMatrixRowsAlmostEqual(
                        got_values_dict=value,
                        expected_values_dict={
                            _full_key(metric_keys.AUC_PLOTS_MATRICES):
                            [(8001, [2, 1, 0, 1, 1.0 / 1.0, 1.0 / 3.0])],
                        })
                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(plots, check_plots, label='plots')
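
In check_plots, the expected row [(8001, [2, 1, 0, 1, 1.0 / 1.0, 1.0 / 3.0])] uses the column order that the assertions elsewhere in this section also rely on: [false_negatives, true_negatives, false_positives, true_positives, precision, recall]. The small check below is an illustration of that assumed layout, derived from the test's own numbers rather than from any TFMA API.

# Illustrative check of the assumed row layout [fn, tn, fp, tp, precision, recall]
# for the four examples (0.0/+, 0.7/-, 0.8/+, 1.0/+) at threshold 0.8.
fn, tn, fp, tp, precision, recall = [2, 1, 0, 1, 1.0 / 1.0, 1.0 / 3.0]
assert (tp, fp, tn, fn) == (1, 0, 1, 2)  # one TP (1.0), one TN (0.7), two FN (0.0, 0.8)
assert precision == tp / (tp + fp)       # 1.0
assert recall == tp / (tp + fn)          # 1/3
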
 def testAucPlotSerialization(self):
   # AUC plots for a model that scores
   # {prediction:0.3, true_label:+},
   # {prediction:0.7, true_label:-}.
   #
   # These plots were generated by hand. For this test to make sense,
   # it must actually match the kind of output TFMA produces.
   tfma_plots = {
       metric_keys.AUC_PLOTS_MATRICES:
           np.array([
               [0.0, 0.0, 1.0, 1.0, 0.5, 1.0],
               [0.0, 0.0, 1.0, 1.0, 0.5, 1.0],
               [1.0, 0.0, 1.0, 0.0, 0.0, 0.0],
               [1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
           ]),
       metric_keys.AUC_PLOTS_THRESHOLDS:
           np.array([1e-6, 0, 0.5, 1.0]),
   }
   expected_plot_data = """
     confusion_matrix_at_thresholds {
       matrices {
         threshold: 1e-6
         true_positives: 1.0
         false_positives: 1.0
         true_negatives: 0.0
         false_negatives: 0.0
         precision: 0.5
         recall: 1.0
       }
     }
     confusion_matrix_at_thresholds {
       matrices {
         threshold: 0
         true_positives: 1.0
         false_positives: 1.0
         true_negatives: 0.0
         false_negatives: 0.0
         precision: 0.5
         recall: 1.0
       }
     }
     confusion_matrix_at_thresholds {
       matrices {
         threshold: 0.5
         true_positives: 0.0
         false_positives: 1.0
         true_negatives: 0.0
         false_negatives: 1.0
         precision: 0.0
         recall: 0.0
       }
     }
     confusion_matrix_at_thresholds {
       matrices {
         threshold: 1.0
         true_positives: 0.0
         false_positives: 0.0
         true_negatives: 1.0
         false_negatives: 1.0
         precision: 0.0
         recall: 0.0
       }
     }
   """
   plot_data = metrics_for_slice_pb2.PlotData()
   auc_plots = post_export_metrics.auc_plots()
   auc_plots.populate_plots_and_pop(tfma_plots, plot_data)
   self.assertProtoEquals(expected_plot_data, plot_data)
   self.assertFalse(metric_keys.AUC_PLOTS_MATRICES in tfma_plots)
   self.assertFalse(metric_keys.AUC_PLOTS_THRESHOLDS in tfma_plots)
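
As the expected proto above shows, row i of AUC_PLOTS_MATRICES is paired with AUC_PLOTS_THRESHOLDS[i], and each row is read in the [fn, tn, fp, tp, precision, recall] order. The standalone sketch below only illustrates that data layout; it is not the populate_plots_and_pop implementation.

# Standalone sketch: pair each matrix row with its threshold and name the
# columns, reproducing the fields listed in expected_plot_data above.
def rows_to_matrix_dicts(matrices, thresholds):
  result = []
  for threshold, (fn, tn, fp, tp, precision, recall) in zip(thresholds, matrices):
    result.append({
        'threshold': threshold,
        'false_negatives': fn,
        'true_negatives': tn,
        'false_positives': fp,
        'true_positives': tp,
        'precision': precision,
        'recall': recall,
    })
  return result

Applied to the first row above, this yields threshold 1e-6 with true_positives 1.0, false_positives 1.0, precision 0.5, and recall 1.0, matching the first matrices block in expected_plot_data.
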
  def testAucPlotsUnweighted(self):
    temp_eval_export_dir = self._getEvalExportDir()
    _, eval_export_dir = (
        fixed_prediction_estimator.simple_fixed_prediction_estimator(
            None, temp_eval_export_dir))
    examples = [
        self._makeExample(prediction=0.0000, label=0.0000),
        self._makeExample(prediction=0.0000, label=1.0000),
        self._makeExample(prediction=0.7000, label=1.0000),
        self._makeExample(prediction=0.8000, label=0.0000),
        self._makeExample(prediction=1.0000, label=1.0000),
    ]

    auc_plots = post_export_metrics.auc_plots()

    def check_result(got):  # pylint: disable=invalid-name
      try:
        self.assertEqual(1, len(got), 'got: %s' % got)
        (slice_key, value) = got[0]
        self.assertEqual((), slice_key)
        self.assertIn(metric_keys.AUC_PLOTS_MATRICES, value)
        matrices = value[metric_keys.AUC_PLOTS_MATRICES]
        #            |      | --------- Threshold -----------
        # true label | pred | -1e-6 | 0.0 | 0.7 | 0.8 | 1.0
        #     -      | 0.0  | FP    | TN  | TN  | TN  | TN
        #     +      | 0.0  | TP    | FN  | FN  | FN  | FN
        #     +      | 0.7  | TP    | TP  | FN  | FN  | FN
        #     -      | 0.8  | FP    | FP  | FP  | TN  | TN
        #     +      | 1.0  | TP    | TP  | TP  | TP  | FN
        self.assertSequenceAlmostEqual(matrices[0],
                                       [0, 0, 2, 3, 3.0 / 5.0, 1.0])
        self.assertSequenceAlmostEqual(matrices[1],
                                       [1, 1, 1, 2, 2.0 / 3.0, 2.0 / 3.0])
        self.assertSequenceAlmostEqual(matrices[7001],
                                       [2, 1, 1, 1, 1.0 / 2.0, 1.0 / 3.0])
        self.assertSequenceAlmostEqual(matrices[8001],
                                       [2, 2, 0, 1, 1.0 / 1.0, 1.0 / 3.0])
        self.assertSequenceAlmostEqual(
            matrices[10001], [3, 2, 0, 0, float('nan'), 0.0])
        self.assertIn(metric_keys.AUC_PLOTS_THRESHOLDS, value)
        thresholds = value[metric_keys.AUC_PLOTS_THRESHOLDS]
        self.assertAlmostEqual(0.0, thresholds[1])
        self.assertAlmostEqual(0.001, thresholds[11])
        self.assertAlmostEqual(0.005, thresholds[51])
        self.assertAlmostEqual(0.010, thresholds[101])
        self.assertAlmostEqual(0.100, thresholds[1001])
        self.assertAlmostEqual(0.800, thresholds[8001])
        self.assertAlmostEqual(1.000, thresholds[10001])
        plot_data = metrics_for_slice_pb2.PlotData()
        auc_plots.populate_plots_and_pop(value, plot_data)
        self.assertProtoEquals(
            """threshold: 1.0
            false_negatives: 3.0
            true_negatives: 2.0
            precision: nan""",
            plot_data.confusion_matrix_at_thresholds.matrices[10001])
      except AssertionError as err:
        raise util.BeamAssertException(err)

    self._runTestWithCustomCheck(
        examples, eval_export_dir, [auc_plots], custom_plots_check=check_result)
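
The index arithmetic used throughout these tests (matrices[8001], thresholds[10001], and so on) is consistent with a threshold list that starts with a small negative epsilon and then steps from 0.0 to 1.0 in increments of 1/10000. Assuming 10000 buckets, which is what thresholds[10001] == 1.0 implies, index i >= 1 corresponds to threshold (i - 1) / 10000. The sketch below only illustrates that assumption and is not the TFMA construction itself.

# Hedged sketch of a threshold list consistent with the assertions above,
# assuming 10000 buckets; index 0 holds a small negative epsilon so that every
# prediction counts as positive there (the "-1e-6" column in the comment table).
num_buckets = 10000
thresholds = [-1e-6] + [i / num_buckets for i in range(num_buckets + 1)]
assert thresholds[1] == 0.0
assert abs(thresholds[11] - 0.001) < 1e-12
assert abs(thresholds[8001] - 0.8) < 1e-12
assert thresholds[10001] == 1.0
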