示例#1
0
def _to_proto(
    thresholds: List[float], histogram: calibration_histogram.Histogram
) -> metrics_for_slice_pb2.CalibrationHistogramBuckets:
  """Builds a CalibrationHistogramBuckets proto from a calibration histogram.

  Bucket ``i`` of the result spans ``[lower, upper)`` where:
    * the first bucket's lower edge is ``-inf``;
    * every later bucket's lower edge is the previous bucket's upper edge;
    * the upper edge is ``thresholds[i + 1]``, or ``+inf`` once ``i`` reaches
      the last threshold index (so ``thresholds[0]`` itself is never read).

  Args:
    thresholds: Thresholds associated with histogram buckets.
    histogram: Calibration histogram; each bucket is read for its
      ``weighted_labels``, ``weighted_predictions`` and ``weighted_examples``.

  Returns:
    A histogram in CalibrationHistogramBuckets proto format, with one proto
    bucket per histogram bucket, in iteration order.
  """
  result = metrics_for_slice_pb2.CalibrationHistogramBuckets()
  last_index = len(thresholds) - 1
  previous_edge = float('-inf')
  for index, bucket in enumerate(histogram):
    # Buckets at or beyond the final threshold are open-ended on the right.
    next_edge = thresholds[index + 1] if index < last_index else float('inf')
    result.buckets.add(
        lower_threshold_inclusive=previous_edge,
        upper_threshold_exclusive=next_edge,
        total_weighted_label={'value': bucket.weighted_labels},
        total_weighted_refined_prediction={
            'value': bucket.weighted_predictions
        },
        num_weighted_examples={'value': bucket.weighted_examples})
    # Buckets are contiguous: this bucket's upper edge starts the next one.
    previous_edge = next_edge
  return result
示例#2
0
    def testSerializePlots(self):
        """Checks that a keyed calibration plot serializes to PlotsForSlice.

        The input is a four-bucket CalibrationHistogramBuckets proto keyed by
        a PlotKey named 'calibration_plot' with output_name 'output_name'.
        The serialized result, parsed back as PlotsForSlice, is expected to
        carry the ('fruit', 'apple') slice key, a plot key containing only
        the output_name, and the same four buckets under
        calibration_histogram_buckets.
        """
        fruit_slice = _make_slice_key('fruit', 'apple')
        calibration_key = metric_types.PlotKey(name='calibration_plot',
                                               output_name='output_name')

        # Input histogram: an empty (-inf, 0) bucket, two populated buckets
        # on [0, 0.5) and [0.5, 1.0), and an empty [1.0, inf) bucket.
        input_buckets_text = """
        buckets {
          lower_threshold_inclusive: -inf
          upper_threshold_exclusive: 0.0
          num_weighted_examples { value: 0.0 }
          total_weighted_label { value: 0.0 }
          total_weighted_refined_prediction { value: 0.0 }
        }
        buckets {
          lower_threshold_inclusive: 0.0
          upper_threshold_exclusive: 0.5
          num_weighted_examples { value: 1.0 }
          total_weighted_label { value: 1.0 }
          total_weighted_refined_prediction { value: 0.3 }
        }
        buckets {
          lower_threshold_inclusive: 0.5
          upper_threshold_exclusive: 1.0
          num_weighted_examples { value: 1.0 }
          total_weighted_label { value: 0.0 }
          total_weighted_refined_prediction { value: 0.7 }
        }
        buckets {
          lower_threshold_inclusive: 1.0
          upper_threshold_exclusive: inf
          num_weighted_examples { value: 0.0 }
          total_weighted_label { value: 0.0 }
          total_weighted_refined_prediction { value: 0.0 }
        }
     """
        histogram_buckets = text_format.Parse(
            input_buckets_text,
            metrics_for_slice_pb2.CalibrationHistogramBuckets())

        # Expected output: note the plot key keeps only output_name.
        expected_text = """
      slice_key {
        single_slice_keys {
          column: 'fruit'
          bytes_value: 'apple'
        }
      }
      plot_keys_and_values {
        key {
          output_name: "output_name"
        }
        value {
          calibration_histogram_buckets {
            buckets {
              lower_threshold_inclusive: -inf
              upper_threshold_exclusive: 0.0
              num_weighted_examples { value: 0.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.0 }
            }
            buckets {
              lower_threshold_inclusive: 0.0
              upper_threshold_exclusive: 0.5
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 1.0 }
              total_weighted_refined_prediction { value: 0.3 }
            }
            buckets {
              lower_threshold_inclusive: 0.5
              upper_threshold_exclusive: 1.0
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.7 }
            }
            buckets {
              lower_threshold_inclusive: 1.0
              upper_threshold_exclusive: inf
              num_weighted_examples { value: 0.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.0 }
            }
          }
        }
      }
    """
        expected_plots_for_slice = text_format.Parse(
            expected_text, metrics_for_slice_pb2.PlotsForSlice())

        plots_by_key = {calibration_key: histogram_buckets}
        serialized = metrics_and_plots_serialization._serialize_plots(
            (fruit_slice, plots_by_key), None)
        actual_plots_for_slice = (
            metrics_for_slice_pb2.PlotsForSlice.FromString(serialized))
        self.assertProtoEquals(expected_plots_for_slice,
                               actual_plots_for_slice)