def testSerializePlots_emptyPlot(self):
    """Serialize a plots dict that holds only an error entry.

    The plots dict maps ``metric_keys.ERROR_METRIC`` to a message string and
    the second argument to ``_serialize_plots`` is an empty list. The
    deserialized result must equal a ``PlotsForSlice`` whose ``plots`` map
    carries that message as ``debug_message`` under the same key.
    """
    fruit_slice = _make_slice_key('fruit', 'apple')
    error_text = 'error_message'

    # Build the expectation first; it only depends on the slice key.
    want = metrics_for_slice_pb2.PlotsForSlice()
    want.slice_key.CopyFrom(slicer.serialize_slice_key(fruit_slice))
    want.plots[metric_keys.ERROR_METRIC].debug_message = error_text

    slice_and_plots = (fruit_slice, {metric_keys.ERROR_METRIC: error_text})
    serialized = metrics_and_plots_serialization._serialize_plots(
        slice_and_plots, [])

    got = metrics_for_slice_pb2.PlotsForSlice.FromString(serialized)
    self.assertProtoEquals(want, got)
def testSerializePlots(self):
    """Serialize a calibration plot stored under a ``PlotKey``.

    A ``CalibrationHistogramBuckets`` proto with four buckets is keyed by a
    ``metric_types.PlotKey`` and passed to ``_serialize_plots`` with ``None``
    as the second argument. The deserialized bytes must equal the
    text-format ``PlotsForSlice`` below.
    """
    # Text-format fixtures are kept exactly as they appear in the file.
    histogram_text = """ buckets { lower_threshold_inclusive: -inf upper_threshold_exclusive: 0.0 num_weighted_examples { value: 0.0 } total_weighted_label { value: 0.0 } total_weighted_refined_prediction { value: 0.0 } } buckets { lower_threshold_inclusive: 0.0 upper_threshold_exclusive: 0.5 num_weighted_examples { value: 1.0 } total_weighted_label { value: 1.0 } total_weighted_refined_prediction { value: 0.3 } } buckets { lower_threshold_inclusive: 0.5 upper_threshold_exclusive: 1.0 num_weighted_examples { value: 1.0 } total_weighted_label { value: 0.0 } total_weighted_refined_prediction { value: 0.7 } } buckets { lower_threshold_inclusive: 1.0 upper_threshold_exclusive: inf num_weighted_examples { value: 0.0 } total_weighted_label { value: 0.0 } total_weighted_refined_prediction { value: 0.0 } } """
    # NOTE(review): the expected key carries only output_name even though the
    # PlotKey below is built with name='calibration_plot' -- presumably the
    # name is not written to the proto key; confirm against PlotKey.
    expected_text = """ slice_key { single_slice_keys { column: 'fruit' bytes_value: 'apple' } } plot_keys_and_values { key { output_name: "output_name" } value { calibration_histogram_buckets { buckets { lower_threshold_inclusive: -inf upper_threshold_exclusive: 0.0 num_weighted_examples { value: 0.0 } total_weighted_label { value: 0.0 } total_weighted_refined_prediction { value: 0.0 } } buckets { lower_threshold_inclusive: 0.0 upper_threshold_exclusive: 0.5 num_weighted_examples { value: 1.0 } total_weighted_label { value: 1.0 } total_weighted_refined_prediction { value: 0.3 } } buckets { lower_threshold_inclusive: 0.5 upper_threshold_exclusive: 1.0 num_weighted_examples { value: 1.0 } total_weighted_label { value: 0.0 } total_weighted_refined_prediction { value: 0.7 } } buckets { lower_threshold_inclusive: 1.0 upper_threshold_exclusive: inf num_weighted_examples { value: 0.0 } total_weighted_label { value: 0.0 } 
total_weighted_refined_prediction { value: 0.0 } } } } } """

    fruit_slice = _make_slice_key('fruit', 'apple')
    key = metric_types.PlotKey(
        name='calibration_plot', output_name='output_name')
    histogram = text_format.Parse(
        histogram_text, metrics_for_slice_pb2.CalibrationHistogramBuckets())
    want = text_format.Parse(
        expected_text, metrics_for_slice_pb2.PlotsForSlice())

    slice_and_plots = (fruit_slice, {key: histogram})
    serialized = metrics_and_plots_serialization._serialize_plots(
        slice_and_plots, None)

    decoded = metrics_for_slice_pb2.PlotsForSlice.FromString(serialized)
    self.assertProtoEquals(want, decoded)
def testSerializePlotsLegacyStringKeys(self):
    """Serialize calibration plot data stored under legacy string keys.

    The plots dict holds two numpy arrays, keyed by
    ``metric_keys.CALIBRATION_PLOT_MATRICES`` and
    ``metric_keys.CALIBRATION_PLOT_BOUNDARIES``. ``_serialize_plots`` is given
    the ``calibration_plot_and_prediction_histogram`` post-export metric, and
    the deserialized bytes must match the text-format ``PlotsForSlice`` below,
    where the buckets appear under the ``"post_export_metrics"`` plots key.
    """
    # Text-format fixture is kept exactly as it appears in the file.
    want_text = """ slice_key { single_slice_keys { column: 'fruit' bytes_value: 'apple' } } plots { key: "post_export_metrics" value { calibration_histogram_buckets { buckets { lower_threshold_inclusive: -inf upper_threshold_exclusive: 0.0 num_weighted_examples { value: 0.0 } total_weighted_label { value: 0.0 } total_weighted_refined_prediction { value: 0.0 } } buckets { lower_threshold_inclusive: 0.0 upper_threshold_exclusive: 0.5 num_weighted_examples { value: 1.0 } total_weighted_label { value: 1.0 } total_weighted_refined_prediction { value: 0.3 } } buckets { lower_threshold_inclusive: 0.5 upper_threshold_exclusive: 1.0 num_weighted_examples { value: 1.0 } total_weighted_label { value: 0.0 } total_weighted_refined_prediction { value: 0.7 } } buckets { lower_threshold_inclusive: 1.0 upper_threshold_exclusive: inf num_weighted_examples { value: 0.0 } total_weighted_label { value: 0.0 } total_weighted_refined_prediction { value: 0.0 } } } } } """

    fruit_slice = _make_slice_key('fruit', 'apple')

    # Four matrix rows and three boundaries line up with the four buckets in
    # the expected proto above.
    matrices = np.array([
        [0.0, 0.0, 0.0],
        [0.3, 1.0, 1.0],
        [0.7, 0.0, 1.0],
        [0.0, 0.0, 0.0],
    ])
    boundaries = np.array([0.0, 0.5, 1.0])
    legacy_plots = {
        metric_keys.CALIBRATION_PLOT_MATRICES: matrices,
        metric_keys.CALIBRATION_PLOT_BOUNDARIES: boundaries,
    }

    plot_metric = (
        post_export_metrics.calibration_plot_and_prediction_histogram())
    payload = metrics_and_plots_serialization._serialize_plots(
        (fruit_slice, legacy_plots), [plot_metric])

    decoded = metrics_for_slice_pb2.PlotsForSlice.FromString(payload)
    self.assertProtoEquals(want_text, decoded)