Esempio n. 1
0
  def testSerializePlots_emptyPlot(self):
    """An error-only plots dict serializes to a debug_message entry.

    The input maps metric_keys.ERROR_METRIC to a plain string. The serialized
    PlotsForSlice is expected to hold the serialized slice key and, in its
    `plots` map under ERROR_METRIC, that same string as `debug_message`.
    """
    fruit_slice = _make_slice_key('fruit', 'apple')
    error_only_plots = {metric_keys.ERROR_METRIC: 'error_message'}

    serialized = metrics_and_plots_serialization._serialize_plots(
        (fruit_slice, error_only_plots), [])

    want = metrics_for_slice_pb2.PlotsForSlice()
    want.slice_key.CopyFrom(slicer.serialize_slice_key(fruit_slice))
    want.plots[metric_keys.ERROR_METRIC].debug_message = 'error_message'

    got = metrics_for_slice_pb2.PlotsForSlice.FromString(serialized)
    self.assertProtoEquals(want, got)
Esempio n. 2
0
    def testSerializePlots(self):
        """A PlotKey -> CalibrationHistogramBuckets dict serializes to PlotsForSlice.

        The expected proto holds the slice key plus a single
        plot_keys_and_values entry. That entry's key carries only
        `output_name`; the PlotKey's `name` does not appear in the expected
        key. Its value wraps the same four calibration buckets used as input.
        """
        fruit_slice = _make_slice_key('fruit', 'apple')
        calibration_key = metric_types.PlotKey(
            name='calibration_plot', output_name='output_name')
        calibration_buckets = text_format.Parse(
            """
        buckets {
          lower_threshold_inclusive: -inf
          upper_threshold_exclusive: 0.0
          num_weighted_examples { value: 0.0 }
          total_weighted_label { value: 0.0 }
          total_weighted_refined_prediction { value: 0.0 }
        }
        buckets {
          lower_threshold_inclusive: 0.0
          upper_threshold_exclusive: 0.5
          num_weighted_examples { value: 1.0 }
          total_weighted_label { value: 1.0 }
          total_weighted_refined_prediction { value: 0.3 }
        }
        buckets {
          lower_threshold_inclusive: 0.5
          upper_threshold_exclusive: 1.0
          num_weighted_examples { value: 1.0 }
          total_weighted_label { value: 0.0 }
          total_weighted_refined_prediction { value: 0.7 }
        }
        buckets {
          lower_threshold_inclusive: 1.0
          upper_threshold_exclusive: inf
          num_weighted_examples { value: 0.0 }
          total_weighted_label { value: 0.0 }
          total_weighted_refined_prediction { value: 0.0 }
        }
     """, metrics_for_slice_pb2.CalibrationHistogramBuckets())
        plots_by_key = {calibration_key: calibration_buckets}

        serialized = metrics_and_plots_serialization._serialize_plots(
            (fruit_slice, plots_by_key), None)

        want = text_format.Parse(
            """
      slice_key {
        single_slice_keys {
          column: 'fruit'
          bytes_value: 'apple'
        }
      }
      plot_keys_and_values {
        key {
          output_name: "output_name"
        }
        value {
          calibration_histogram_buckets {
            buckets {
              lower_threshold_inclusive: -inf
              upper_threshold_exclusive: 0.0
              num_weighted_examples { value: 0.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.0 }
            }
            buckets {
              lower_threshold_inclusive: 0.0
              upper_threshold_exclusive: 0.5
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 1.0 }
              total_weighted_refined_prediction { value: 0.3 }
            }
            buckets {
              lower_threshold_inclusive: 0.5
              upper_threshold_exclusive: 1.0
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.7 }
            }
            buckets {
              lower_threshold_inclusive: 1.0
              upper_threshold_exclusive: inf
              num_weighted_examples { value: 0.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.0 }
            }
          }
        }
      }
    """, metrics_for_slice_pb2.PlotsForSlice())
        got = metrics_for_slice_pb2.PlotsForSlice.FromString(serialized)
        self.assertProtoEquals(want, got)
Esempio n. 3
0
 def testSerializePlotsLegacyStringKeys(self):
     """Legacy string-keyed calibration arrays serialize to histogram buckets.

     The plots dict is keyed by metric_keys.CALIBRATION_PLOT_MATRICES and
     metric_keys.CALIBRATION_PLOT_BOUNDARIES. The expected PlotsForSlice
     stores the result as calibration_histogram_buckets under the
     "post_export_metrics" key of its `plots` map.
     """
     fruit_slice = _make_slice_key('fruit', 'apple')
     # Three boundaries and four matrix rows; the expected output has four
     # buckets: [-inf, 0.0), [0.0, 0.5), [0.5, 1.0), [1.0, inf).
     # Judging by the expected buckets, each row reads as (weighted refined
     # prediction, weighted label, weighted example count). Confirm against
     # the serializer before relying on this.
     calibration_matrices = np.array([
         [0.0, 0.0, 0.0],
         [0.3, 1.0, 1.0],
         [0.7, 0.0, 1.0],
         [0.0, 0.0, 0.0],
     ])
     calibration_boundaries = np.array([0.0, 0.5, 1.0])
     legacy_plots = {
         metric_keys.CALIBRATION_PLOT_MATRICES: calibration_matrices,
         metric_keys.CALIBRATION_PLOT_BOUNDARIES: calibration_boundaries,
     }
     expected_plot_data = """
   slice_key {
     single_slice_keys {
       column: 'fruit'
       bytes_value: 'apple'
     }
   }
   plots {
     key: "post_export_metrics"
     value {
       calibration_histogram_buckets {
         buckets {
           lower_threshold_inclusive: -inf
           upper_threshold_exclusive: 0.0
           num_weighted_examples { value: 0.0 }
           total_weighted_label { value: 0.0 }
           total_weighted_refined_prediction { value: 0.0 }
         }
         buckets {
           lower_threshold_inclusive: 0.0
           upper_threshold_exclusive: 0.5
           num_weighted_examples { value: 1.0 }
           total_weighted_label { value: 1.0 }
           total_weighted_refined_prediction { value: 0.3 }
         }
         buckets {
           lower_threshold_inclusive: 0.5
           upper_threshold_exclusive: 1.0
           num_weighted_examples { value: 1.0 }
           total_weighted_label { value: 0.0 }
           total_weighted_refined_prediction { value: 0.7 }
         }
         buckets {
           lower_threshold_inclusive: 1.0
           upper_threshold_exclusive: inf
           num_weighted_examples { value: 0.0 }
           total_weighted_label { value: 0.0 }
           total_weighted_refined_prediction { value: 0.0 }
         }
       }
     }
   }
 """
     post_export = [
         post_export_metrics.calibration_plot_and_prediction_histogram()]
     serialized = metrics_and_plots_serialization._serialize_plots(
         (fruit_slice, legacy_plots), post_export)
     got = metrics_for_slice_pb2.PlotsForSlice.FromString(serialized)
     self.assertProtoEquals(expected_plot_data, got)