def testCalibrationPlotSerialization(self):
  """Tests serialization of calibration-plot data into the PlotData proto.

  `populate_plots_and_pop` is expected to (a) translate the raw matrix and
  boundary arrays into `calibration_histogram_buckets` entries on the proto
  and (b) remove the consumed keys from the input dict.
  """
  # Calibration plots for the model
  # {prediction:0.3, true_label:+},
  # {prediction:0.7, true_label:-}
  #
  # These plots were generated by hand. For this test to make sense
  # it must actually match the kind of output the TFMA produces.
  tfma_plots = {
      # Per the expected proto below, each row's columns map to
      # (total_weighted_refined_prediction, total_weighted_label,
      # num_weighted_examples); the first and last rows are the (-inf, 0.0)
      # and [1.0, inf) overflow buckets.
      metric_keys.CALIBRATION_PLOT_MATRICES:
          np.array([
              [0.0, 0.0, 0.0],
              [0.3, 1.0, 1.0],
              [0.7, 0.0, 1.0],
              [0.0, 0.0, 0.0],
          ]),
      metric_keys.CALIBRATION_PLOT_BOUNDARIES:
          np.array([0.0, 0.5, 1.0]),
  }
  expected_plot_data = """
    calibration_histogram_buckets {
      buckets {
        lower_threshold_inclusive: -inf
        upper_threshold_exclusive: 0.0
        num_weighted_examples { value: 0.0 }
        total_weighted_label { value: 0.0 }
        total_weighted_refined_prediction { value: 0.0 }
      }
      buckets {
        lower_threshold_inclusive: 0.0
        upper_threshold_exclusive: 0.5
        num_weighted_examples { value: 1.0 }
        total_weighted_label { value: 1.0 }
        total_weighted_refined_prediction { value: 0.3 }
      }
      buckets {
        lower_threshold_inclusive: 0.5
        upper_threshold_exclusive: 1.0
        num_weighted_examples { value: 1.0 }
        total_weighted_label { value: 0.0 }
        total_weighted_refined_prediction { value: 0.7 }
      }
      buckets {
        lower_threshold_inclusive: 1.0
        upper_threshold_exclusive: inf
        num_weighted_examples { value: 0.0 }
        total_weighted_label { value: 0.0 }
        total_weighted_refined_prediction { value: 0.0 }
      }
    }
  """
  plot_data = metrics_for_slice_pb2.PlotData()
  calibration_plot = (
      post_export_metrics.calibration_plot_and_prediction_histogram())
  calibration_plot.populate_plots_and_pop(tfma_plots, plot_data)
  self.assertProtoEquals(expected_plot_data, plot_data)
  # The consumed plot keys must have been popped from the input dict.
  self.assertFalse(metric_keys.CALIBRATION_PLOT_MATRICES in tfma_plots)
  self.assertFalse(metric_keys.CALIBRATION_PLOT_BOUNDARIES in tfma_plots)
def check_result(got):  # pylint: disable=invalid-name
  """Beam-assert callback validating AUC plot matrices and thresholds.

  NOTE(review): `self`, `metric_keys`, `auc_plots`, `metrics_for_slice_pb2`
  and `util` are captured from an enclosing scope not visible in this chunk.
  """
  try:
    # Exactly one result, keyed by the empty (overall) slice.
    self.assertEqual(1, len(got), 'got: %s' % got)
    (slice_key, value) = got[0]
    self.assertEqual((), slice_key)
    self.assertIn(metric_keys.AUC_PLOTS_MATRICES, value)
    matrices = value[metric_keys.AUC_PLOTS_MATRICES]
    # Hand-computed confusion outcomes per threshold:
    #            |      | --------- Threshold -----------
    # true label | pred | -1e-6 | 0.0 | 0.7 | 0.8 | 1.0
    #     -      | 0.0  |  FP   | TN  | TN  | TN  | TN
    #     +      | 0.0  |  TP   | FN  | FN  | FN  | FN
    #     +      | 0.7  |  TP   | TP  | FN  | FN  | FN
    #     -      | 0.8  |  FP   | FP  | FP  | TN  | TN
    #     +      | 1.0  |  TP   | TP  | TP  | TP  | FN
    # Per the table, each matrix row is [fn, tn, fp, tp, precision, recall];
    # the asserted indices imply threshold t lives at index t * 10000 + 1,
    # with index 0 holding the t = -1e-6 matrix.
    self.assertSequenceAlmostEqual(matrices[0],
                                   [0, 0, 2, 3, 3.0 / 5.0, 1.0])
    self.assertSequenceAlmostEqual(
        matrices[1], [1, 1, 1, 2, 2.0 / 3.0, 2.0 / 3.0])
    self.assertSequenceAlmostEqual(
        matrices[7001], [2, 1, 1, 1, 1.0 / 2.0, 1.0 / 3.0])
    self.assertSequenceAlmostEqual(
        matrices[8001], [2, 2, 0, 1, 1.0 / 1.0, 1.0 / 3.0])
    # At threshold 1.0 there are no predicted positives, so precision is NaN.
    self.assertSequenceAlmostEqual(
        matrices[10001], [3, 2, 0, 0, float('nan'), 0.0])
    self.assertIn(metric_keys.AUC_PLOTS_THRESHOLDS, value)
    thresholds = value[metric_keys.AUC_PLOTS_THRESHOLDS]
    # Spot-check the threshold grid (thresholds[i] == (i - 1) / 10000).
    self.assertAlmostEqual(0.0, thresholds[1])
    self.assertAlmostEqual(0.001, thresholds[11])
    self.assertAlmostEqual(0.005, thresholds[51])
    self.assertAlmostEqual(0.010, thresholds[101])
    self.assertAlmostEqual(0.100, thresholds[1001])
    self.assertAlmostEqual(0.800, thresholds[8001])
    self.assertAlmostEqual(1.000, thresholds[10001])
    plot_data = metrics_for_slice_pb2.PlotData()
    auc_plots.populate_plots_and_pop(value, plot_data)
    # The serialized proto must carry the same NaN-precision matrix at
    # threshold 1.0 as asserted on the raw arrays above.
    self.assertProtoEquals(
        """threshold: 1.0
           false_negatives: 3.0
           true_negatives: 2.0
           precision: nan""",
        plot_data.confusion_matrix_at_thresholds.matrices[10001])
  except AssertionError as err:
    # Re-raise as a Beam assertion so the pipeline test reports the failure.
    raise util.BeamAssertException(err)
def check_result(got):  # pylint: disable=invalid-name
  """Beam-assert callback validating calibration-plot buckets and boundaries.

  NOTE(review): `self`, `metric_keys`, `calibration_plot`,
  `metrics_for_slice_pb2` and `util` are captured from an enclosing scope
  not visible in this chunk.
  """
  try:
    # Exactly one result, keyed by the empty (overall) slice.
    self.assertEqual(1, len(got), 'got: %s' % got)
    (slice_key, value) = got[0]
    self.assertEqual((), slice_key)
    self.assertIn(metric_keys.CALIBRATION_PLOT_MATRICES, value)
    buckets = value[metric_keys.CALIBRATION_PLOT_MATRICES]
    # Per the expected proto below, each bucket row's columns map to
    # (total_weighted_refined_prediction, total_weighted_label,
    # num_weighted_examples). Buckets 0 and 10001 are the underflow
    # (prediction < 0.0) and overflow (prediction >= 1.0) buckets.
    self.assertSequenceAlmostEqual(buckets[0], [-19.0, -17.0, 2.0])
    self.assertSequenceAlmostEqual(buckets[1], [0.0, 1.0, 1.0])
    self.assertSequenceAlmostEqual(buckets[11], [0.00303, 3.00303, 3.0])
    self.assertSequenceAlmostEqual(buckets[10000], [1.99997, 3.99997, 2.0])
    self.assertSequenceAlmostEqual(buckets[10001], [28.0, 32.0, 4.0])
    self.assertIn(metric_keys.CALIBRATION_PLOT_BOUNDARIES, value)
    boundaries = value[metric_keys.CALIBRATION_PLOT_BOUNDARIES]
    # Spot-check the bucket boundary grid (boundaries[i] == i / 10000).
    self.assertAlmostEqual(0.0, boundaries[0])
    self.assertAlmostEqual(0.001, boundaries[10])
    self.assertAlmostEqual(0.005, boundaries[50])
    self.assertAlmostEqual(0.010, boundaries[100])
    self.assertAlmostEqual(0.100, boundaries[1000])
    self.assertAlmostEqual(0.800, boundaries[8000])
    self.assertAlmostEqual(1.000, boundaries[10000])
    plot_data = metrics_for_slice_pb2.PlotData()
    calibration_plot.populate_plots_and_pop(value, plot_data)
    # The overflow bucket [28.0, 32.0, 4.0] asserted above must serialize
    # as the [1.0, inf) histogram bucket.
    self.assertProtoEquals(
        """lower_threshold_inclusive:1.0
           upper_threshold_exclusive: inf
           num_weighted_examples { value: 4.0 }
           total_weighted_label { value: 32.0 }
           total_weighted_refined_prediction { value: 28.0 }""",
        plot_data.calibration_histogram_buckets.buckets[10001])
  except AssertionError as err:
    # Re-raise as a Beam assertion so the pipeline test reports the failure.
    raise util.BeamAssertException(err)
def testAucPlotSerialization(self):
  """Tests serialization of AUC plot data into the PlotData proto.

  `populate_plots_and_pop` is expected to (a) translate the raw matrix and
  threshold arrays into `confusion_matrix_at_thresholds` entries on the
  proto and (b) remove the consumed keys from the input dict.
  """
  # Auc for the model
  # {prediction:0.3, true_label:+},
  # {prediction:0.7, true_label:-}
  #
  # These plots were generated by hand. For this test to make sense
  # it must actually match the kind of output the TFMA produces.
  tfma_plots = {
      # Per the expected proto below, each row's columns map to
      # [false_negatives, true_negatives, false_positives, true_positives,
      # precision, recall], one row per threshold in AUC_PLOTS_THRESHOLDS.
      metric_keys.AUC_PLOTS_MATRICES:
          np.array([
              [0.0, 0.0, 1.0, 1.0, 0.5, 1.0],
              [0.0, 0.0, 1.0, 1.0, 0.5, 1.0],
              [1.0, 0.0, 1.0, 0.0, 0.0, 0.0],
              [1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
          ]),
      metric_keys.AUC_PLOTS_THRESHOLDS:
          np.array([1e-6, 0, 0.5, 1.0]),
  }
  # Repeated text-format occurrences of confusion_matrix_at_thresholds
  # merge into one message, accumulating the repeated `matrices` entries.
  expected_plot_data = """
    confusion_matrix_at_thresholds {
      matrices {
        threshold: 1e-6
        true_positives: 1.0
        false_positives: 1.0
        true_negatives: 0.0
        false_negatives: 0.0
        precision: 0.5
        recall: 1.0
      }
    }
    confusion_matrix_at_thresholds {
      matrices {
        threshold: 0
        true_positives: 1.0
        false_positives: 1.0
        true_negatives: 0.0
        false_negatives: 0.0
        precision: 0.5
        recall: 1.0
      }
    }
    confusion_matrix_at_thresholds {
      matrices {
        threshold: 0.5
        true_positives: 0.0
        false_positives: 1.0
        true_negatives: 0.0
        false_negatives: 1.0
        precision: 0.0
        recall: 0.0
      }
    }
    confusion_matrix_at_thresholds {
      matrices {
        threshold: 1.0
        true_positives: 0.0
        false_positives: 0.0
        true_negatives: 1.0
        false_negatives: 1.0
        precision: 0.0
        recall: 0.0
      }
    }
  """
  plot_data = metrics_for_slice_pb2.PlotData()
  auc_plots = post_export_metrics.auc_plots()
  auc_plots.populate_plots_and_pop(tfma_plots, plot_data)
  self.assertProtoEquals(expected_plot_data, plot_data)
  # The consumed plot keys must have been popped from the input dict.
  self.assertFalse(metric_keys.AUC_PLOTS_MATRICES in tfma_plots)
  self.assertFalse(metric_keys.AUC_PLOTS_THRESHOLDS in tfma_plots)