Example #1
 def testCalibrationPlotSerialization(self):
     # Calibration plots for the model
     # {prediction:0.3, true_label:+},
     # {prediction:0.7, true_label:-}
     #
     # These plots were generated by hand. For this test to make sense,
     # they must actually match the kind of output TFMA produces.
     tfma_plots = {
         metric_keys.CALIBRATION_PLOT_MATRICES:
         np.array([
             [0.0, 0.0, 0.0],
             [0.3, 1.0, 1.0],
             [0.7, 0.0, 1.0],
             [0.0, 0.0, 0.0],
         ]),
         metric_keys.CALIBRATION_PLOT_BOUNDARIES:
         np.array([0.0, 0.5, 1.0]),
     }
     expected_plot_data = """
   calibration_histogram_buckets {
     buckets {
       lower_threshold_inclusive: -inf
       upper_threshold_exclusive: 0.0
       num_weighted_examples { value: 0.0 }
       total_weighted_label { value: 0.0 }
       total_weighted_refined_prediction { value: 0.0 }
     }
     buckets {
       lower_threshold_inclusive: 0.0
       upper_threshold_exclusive: 0.5
       num_weighted_examples { value: 1.0 }
       total_weighted_label { value: 1.0 }
       total_weighted_refined_prediction { value: 0.3 }
     }
     buckets {
       lower_threshold_inclusive: 0.5
       upper_threshold_exclusive: 1.0
       num_weighted_examples { value: 1.0 }
       total_weighted_label { value: 0.0 }
       total_weighted_refined_prediction { value: 0.7 }
     }
     buckets {
       lower_threshold_inclusive: 1.0
       upper_threshold_exclusive: inf
       num_weighted_examples { value: 0.0 }
       total_weighted_label { value: 0.0 }
       total_weighted_refined_prediction { value: 0.0 }
     }
   }
 """
     plot_data = metrics_for_slice_pb2.PlotData()
     calibration_plot = (
         post_export_metrics.calibration_plot_and_prediction_histogram())
     calibration_plot.populate_plots_and_pop(tfma_plots, plot_data)
     self.assertProtoEquals(expected_plot_data, plot_data)
     self.assertNotIn(metric_keys.CALIBRATION_PLOT_MATRICES, tfma_plots)
     self.assertNotIn(metric_keys.CALIBRATION_PLOT_BOUNDARIES, tfma_plots)
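Read against the expected proto, each CALIBRATION_PLOT_MATRICES row appears to hold [total_weighted_refined_prediction, total_weighted_label, num_weighted_examples], with one extra row at each end for the underflow and overflow buckets. A minimal sketch of that mapping (calibration_rows_to_buckets is a hypothetical helper inferred from this test, not a TFMA API):

def calibration_rows_to_buckets(matrices, boundaries):
    # Pair each matrix row with its bucket range. The column order and the
    # underflow/overflow rows at either end are assumptions read off the
    # test above.
    edges = [float('-inf')] + list(boundaries) + [float('inf')]
    buckets = []
    for (pred_sum, label_sum, count), lower, upper in zip(
            matrices, edges[:-1], edges[1:]):
        buckets.append({
            'lower_threshold_inclusive': lower,
            'upper_threshold_exclusive': upper,
            'total_weighted_refined_prediction': pred_sum,
            'total_weighted_label': label_sum,
            'num_weighted_examples': count,
        })
    return buckets

# With the arrays from tfma_plots above this yields four buckets; the second is
# {'lower_threshold_inclusive': 0.0, 'upper_threshold_exclusive': 0.5,
#  'total_weighted_refined_prediction': 0.3, 'total_weighted_label': 1.0,
#  'num_weighted_examples': 1.0}, matching the second bucket in expected_plot_data.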
Example #2
 def check_result(got):  # pylint: disable=invalid-name
     try:
         self.assertEqual(1, len(got), 'got: %s' % got)
         (slice_key, value) = got[0]
         self.assertEqual((), slice_key)
         self.assertIn(metric_keys.AUC_PLOTS_MATRICES, value)
         matrices = value[metric_keys.AUC_PLOTS_MATRICES]
         #            |      | --------- Threshold -----------
         # true label | pred | -1e-6 | 0.0 | 0.7 | 0.8 | 1.0
         #     -      | 0.0  | FP    | TN  | TN  | TN  | TN
         #     +      | 0.0  | TP    | FN  | FN  | FN  | FN
         #     +      | 0.7  | TP    | TP  | FN  | FN  | FN
         #     -      | 0.8  | FP    | FP  | FP  | TN  | TN
         #     +      | 1.0  | TP    | TP  | TP  | TP  | FN
         self.assertSequenceAlmostEqual(matrices[0],
                                        [0, 0, 2, 3, 3.0 / 5.0, 1.0])
         self.assertSequenceAlmostEqual(
             matrices[1], [1, 1, 1, 2, 2.0 / 3.0, 2.0 / 3.0])
         self.assertSequenceAlmostEqual(
             matrices[7001], [2, 1, 1, 1, 1.0 / 2.0, 1.0 / 3.0])
         self.assertSequenceAlmostEqual(
             matrices[8001], [2, 2, 0, 1, 1.0 / 1.0, 1.0 / 3.0])
         self.assertSequenceAlmostEqual(
             matrices[10001],
             [3, 2, 0, 0, float('nan'), 0.0])
         self.assertIn(metric_keys.AUC_PLOTS_THRESHOLDS, value)
         thresholds = value[metric_keys.AUC_PLOTS_THRESHOLDS]
         self.assertAlmostEqual(0.0, thresholds[1])
         self.assertAlmostEqual(0.001, thresholds[11])
         self.assertAlmostEqual(0.005, thresholds[51])
         self.assertAlmostEqual(0.010, thresholds[101])
         self.assertAlmostEqual(0.100, thresholds[1001])
         self.assertAlmostEqual(0.800, thresholds[8001])
         self.assertAlmostEqual(1.000, thresholds[10001])
         plot_data = metrics_for_slice_pb2.PlotData()
         auc_plots.populate_plots_and_pop(value, plot_data)
         self.assertProtoEquals(
             """threshold: 1.0
     false_negatives: 3.0
     true_negatives: 2.0
     precision: nan""",
             plot_data.confusion_matrix_at_thresholds.matrices[10001])
     except AssertionError as err:
         raise util.BeamAssertException(err)
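The comment table pins down how each AUC_PLOTS_MATRICES row is derived from the five hand-labelled examples. A minimal sketch that recomputes a row by hand, assuming the column order [fn, tn, fp, tp, precision, recall] and a strict prediction > threshold rule, both inferred from the assertions above rather than taken from the TFMA source:

def confusion_row(examples, threshold):
    # examples: list of (prediction, label) pairs with label in {0, 1}.
    tp = sum(1 for p, l in examples if p > threshold and l == 1)
    fp = sum(1 for p, l in examples if p > threshold and l == 0)
    fn = sum(1 for p, l in examples if p <= threshold and l == 1)
    tn = sum(1 for p, l in examples if p <= threshold and l == 0)
    precision = tp / (tp + fp) if tp + fp else float('nan')
    recall = tp / (tp + fn) if tp + fn else float('nan')
    return [fn, tn, fp, tp, precision, recall]

# The hand-labelled examples from the comment table:
examples = [(0.0, 0), (0.0, 1), (0.7, 1), (0.8, 0), (1.0, 1)]
confusion_row(examples, -1e-6)  # [0, 0, 2, 3, 0.6, 1.0],    cf. matrices[0]
confusion_row(examples, 0.8)    # [2, 2, 0, 1, 1.0, 0.33...], cf. matrices[8001]
confusion_row(examples, 1.0)    # [3, 2, 0, 0, nan, 0.0],     cf. matrices[10001]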
Example #3
 def check_result(got):  # pylint: disable=invalid-name
     try:
         self.assertEqual(1, len(got), 'got: %s' % got)
         (slice_key, value) = got[0]
         self.assertEqual((), slice_key)
         self.assertIn(metric_keys.CALIBRATION_PLOT_MATRICES, value)
         buckets = value[metric_keys.CALIBRATION_PLOT_MATRICES]
         self.assertSequenceAlmostEqual(buckets[0], [-19.0, -17.0, 2.0])
         self.assertSequenceAlmostEqual(buckets[1], [0.0, 1.0, 1.0])
         self.assertSequenceAlmostEqual(buckets[11],
                                        [0.00303, 3.00303, 3.0])
         self.assertSequenceAlmostEqual(buckets[10000],
                                        [1.99997, 3.99997, 2.0])
         self.assertSequenceAlmostEqual(buckets[10001],
                                        [28.0, 32.0, 4.0])
         self.assertIn(metric_keys.CALIBRATION_PLOT_BOUNDARIES, value)
         boundaries = value[metric_keys.CALIBRATION_PLOT_BOUNDARIES]
         self.assertAlmostEqual(0.0, boundaries[0])
         self.assertAlmostEqual(0.001, boundaries[10])
         self.assertAlmostEqual(0.005, boundaries[50])
         self.assertAlmostEqual(0.010, boundaries[100])
         self.assertAlmostEqual(0.100, boundaries[1000])
         self.assertAlmostEqual(0.800, boundaries[8000])
         self.assertAlmostEqual(1.000, boundaries[10000])
         plot_data = metrics_for_slice_pb2.PlotData()
         calibration_plot.populate_plots_and_pop(value, plot_data)
         self.assertProtoEquals(
             """lower_threshold_inclusive:1.0
     upper_threshold_exclusive: inf
     num_weighted_examples {
       value: 4.0
     }
     total_weighted_label {
       value: 32.0
     }
     total_weighted_refined_prediction {
       value: 28.0
     }""", plot_data.calibration_histogram_buckets.buckets[10001])
     except AssertionError as err:
         raise util.BeamAssertException(err)
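The assertions above imply the bucket layout: boundaries appears to hold 10001 values from 0.0 to 1.0 in steps of 1e-4, and the matrices array has 10002 rows, with bucket 0 collecting predictions below boundaries[0] (hence the negative totals in buckets[0]) and bucket 10001 collecting predictions at or above 1.0. A sketch of that index mapping, inferred from this test rather than from the TFMA source:

import bisect

def bucket_index(prediction, boundaries):
    # Bucket 0 is (-inf, boundaries[0]); bucket i + 1 is
    # [boundaries[i], boundaries[i + 1]); the last bucket is [boundaries[-1], inf).
    # bisect_right gives exactly this "lower bound inclusive" behaviour.
    return bisect.bisect_right(boundaries, prediction)

# With boundaries being the array asserted on above:
# bucket_index(-1.0, boundaries) == 0       # underflow bucket, cf. buckets[0]
# bucket_index(0.0, boundaries) == 1        # [0.0, 1e-4), cf. buckets[1]
# bucket_index(7.0, boundaries) == 10001    # [1.0, inf), cf. buckets[10001]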
Example #4
 def testAucPlotSerialization(self):
     # AUC plot data for the model
     # {prediction:0.3, true_label:+},
     # {prediction:0.7, true_label:-}
     #
     # These plots were generated by hand. For this test to make sense,
     # they must actually match the kind of output TFMA produces.
     tfma_plots = {
         metric_keys.AUC_PLOTS_MATRICES:
         np.array([
             [0.0, 0.0, 1.0, 1.0, 0.5, 1.0],
             [0.0, 0.0, 1.0, 1.0, 0.5, 1.0],
             [1.0, 0.0, 1.0, 0.0, 0.0, 0.0],
             [1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
         ]),
         metric_keys.AUC_PLOTS_THRESHOLDS:
         np.array([1e-6, 0, 0.5, 1.0]),
     }
     expected_plot_data = """
   confusion_matrix_at_thresholds {
     matrices {
       threshold: 1e-6
       true_positives: 1.0
       false_positives: 1.0
       true_negatives: 0.0
       false_negatives: 0.0
       precision: 0.5
       recall: 1.0
     }
   }
   confusion_matrix_at_thresholds {
     matrices {
       threshold: 0
       true_positives: 1.0
       false_positives: 1.0
       true_negatives: 0.0
       false_negatives: 0.0
       precision: 0.5
       recall: 1.0
     }
   }
   confusion_matrix_at_thresholds {
     matrices {
       threshold: 0.5
       true_positives: 0.0
       false_positives: 1.0
       true_negatives: 0.0
       false_negatives: 1.0
       precision: 0.0
       recall: 0.0
     }
   }
   confusion_matrix_at_thresholds {
     matrices {
       threshold: 1.0
       true_positives: 0.0
       false_positives: 0.0
       true_negatives: 1.0
       false_negatives: 1.0
       precision: 0.0
       recall: 0.0
     }
   }
 """
     plot_data = metrics_for_slice_pb2.PlotData()
     auc_plots = post_export_metrics.auc_plots()
     auc_plots.populate_plots_and_pop(tfma_plots, plot_data)
     self.assertProtoEquals(expected_plot_data, plot_data)
     self.assertNotIn(metric_keys.AUC_PLOTS_MATRICES, tfma_plots)
     self.assertNotIn(metric_keys.AUC_PLOTS_THRESHOLDS, tfma_plots)
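Each row of AUC_PLOTS_MATRICES above lines up with one matrices block in expected_plot_data. A small sketch of that rendering (row_to_matrix_text is a hypothetical helper; the [fn, tn, fp, tp, precision, recall] column order is an assumption read off the test data, not the TFMA source):

def row_to_matrix_text(threshold, row):
    # Format one (threshold, matrix-row) pair in the same text format as
    # expected_plot_data above.
    fn, tn, fp, tp, precision, recall = row
    return ('confusion_matrix_at_thresholds {\n'
            '  matrices {\n'
            '    threshold: %g\n'
            '    true_positives: %.1f\n'
            '    false_positives: %.1f\n'
            '    true_negatives: %.1f\n'
            '    false_negatives: %.1f\n'
            '    precision: %.1f\n'
            '    recall: %.1f\n'
            '  }\n'
            '}\n') % (threshold, tp, fp, tn, fn, precision, recall)

# Joining row_to_matrix_text over zip(thresholds, matrices) for the arrays at
# the top of this test reproduces expected_plot_data (up to number formatting).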