def testPlotKeyFromProto(self):
  plot_keys = [
      metric_types.PlotKey(name=''),
      metric_types.PlotKey(
          name='',
          model_name='model_name',
          output_name='output_name',
          sub_key=metric_types.SubKey(class_id=1)),
      metric_types.PlotKey(
          name='',
          model_name='model_name',
          output_name='output_name',
          sub_key=metric_types.SubKey(top_k=2))
  ]
  for key in plot_keys:
    got_key = metric_types.PlotKey.from_proto(key.to_proto())
    self.assertEqual(key, got_key, '{} != {}'.format(key, got_key))
def check_result(got):
  try:
    self.assertLen(got, 1)
    got_slice_key, got_plots = got[0]
    self.assertEqual(got_slice_key, ())
    self.assertLen(got_plots, 1)
    key = metric_types.PlotKey(name='multi_class_confusion_matrix_plot')
    got_matrix = got_plots[key]
    self.assertProtoEquals(
        """
        matrices {
          threshold: 0.0
          entries {
            actual_class_id: 0
            predicted_class_id: 0
            num_weighted_examples: 0.5
          }
          entries {
            actual_class_id: 2
            predicted_class_id: 2
            num_weighted_examples: 1.0
          }
        }
        """, got_matrix)
  except AssertionError as err:
    raise util.BeamAssertException(err)
def _multi_label_confusion_matrix_plot(
    thresholds: Optional[List[float]] = None,
    num_thresholds: Optional[int] = None,
    name: Text = MULTI_LABEL_CONFUSION_MATRIX_PLOT_NAME,
    eval_config: Optional[config.EvalConfig] = None,
    model_name: Text = '',
    output_name: Text = '',
) -> metric_types.MetricComputations:
  """Returns computations for multi-label confusion matrix at thresholds."""
  if num_thresholds is not None and thresholds is not None:
    raise ValueError(
        'only one of thresholds or num_thresholds can be set at a time')
  if num_thresholds is None and thresholds is None:
    thresholds = [0.5]
  if num_thresholds is not None:
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds - 2)]
    thresholds = [-_EPSILON] + thresholds + [1.0 + _EPSILON]
  key = metric_types.PlotKey(
      name=name, model_name=model_name, output_name=output_name)
  return [
      metric_types.MetricComputation(
          keys=[key],
          preprocessor=None,
          combiner=_MultiLabelConfusionMatrixPlotCombiner(
              key=key, eval_config=eval_config, thresholds=thresholds))
  ]
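# A minimal sketch of the threshold derivation above, assuming _EPSILON is a
# small constant such as 1e-6 (the actual value lives elsewhere in the
# module). All `sketch_*` names here are illustrative only.
sketch_num_thresholds = 5
sketch_epsilon = 1e-6  # hypothetical stand-in for _EPSILON
sketch_thresholds = [(i + 1) * 1.0 / (sketch_num_thresholds - 1)
                     for i in range(sketch_num_thresholds - 2)]
sketch_thresholds = ([-sketch_epsilon] + sketch_thresholds +
                     [1.0 + sketch_epsilon])
# The interior values are evenly spaced, with the two extra entries covering
# the edge cases just below 0.0 and just above 1.0.
assert sketch_thresholds[1:4] == [0.25, 0.5, 0.75]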
def check_result(got):
  try:
    self.assertLen(got, 1)
    got_slice_key, got_plots = got[0]
    self.assertEqual(got_slice_key, ())
    self.assertLen(got_plots, 1)
    key = metric_types.PlotKey(name='auc_plot')
    self.assertIn(key, got_plots)
    got_plot = got_plots[key]
    self.assertProtoEquals(
        """
        matrices {
          threshold: -1e-06
          false_positives: 2.0
          true_positives: 2.0
          precision: 0.5
          recall: 1.0
        }
        matrices {
          true_negatives: 1.0
          false_positives: 1.0
          true_positives: 2.0
          precision: 0.6666667
          recall: 1.0
        }
        matrices {
          threshold: 0.25
          true_negatives: 1.0
          false_positives: 1.0
          true_positives: 2.0
          precision: 0.6666667
          recall: 1.0
        }
        matrices {
          threshold: 0.5
          false_negatives: 1.0
          true_negatives: 2.0
          true_positives: 1.0
          precision: 1.0
          recall: 0.5
        }
        matrices {
          threshold: 0.75
          false_negatives: 1.0
          true_negatives: 2.0
          true_positives: 1.0
          precision: 1.0
          recall: 0.5
        }
        matrices {
          threshold: 1.0
          false_negatives: 2.0
          true_negatives: 2.0
          precision: 1.0
          recall: 0.0
        }
        """, got_plot)
  except AssertionError as err:
    raise util.BeamAssertException(err)
def _calibration_plot(
    num_buckets: int = DEFAULT_NUM_BUCKETS,
    left: Optional[float] = None,
    right: Optional[float] = None,
    name: Text = CALIBRATION_PLOT_NAME,
    eval_config: Optional[config_pb2.EvalConfig] = None,
    schema: Optional[schema_pb2.Schema] = None,
    model_name: Text = '',
    output_name: Text = '',
    sub_key: Optional[metric_types.SubKey] = None,
    aggregation_type: Optional[metric_types.AggregationType] = None,
    class_weights: Optional[Dict[int, float]] = None
) -> metric_types.MetricComputations:
  """Returns metric computations for calibration plot."""
  key = metric_types.PlotKey(
      name=name,
      model_name=model_name,
      output_name=output_name,
      sub_key=sub_key)
  label_left, label_right = None, None
  if (left is None or right is None) and eval_config and schema:
    label_left, label_right = _find_label_domain(eval_config, schema,
                                                 model_name, output_name)
  if left is None:
    left = label_left if label_left is not None else 0.0
  if right is None:
    right = label_right if label_right is not None else 1.0
  # Make sure the calibration histogram is calculated. Note that we are using
  # the default number of buckets assigned to the histogram instead of the
  # value used for the plots, just in case the computation is shared with
  # other metrics and plots that need higher precision. It will be
  # downsampled later.
  computations = calibration_histogram.calibration_histogram(
      eval_config=eval_config,
      model_name=model_name,
      output_name=output_name,
      sub_key=sub_key,
      left=left,
      right=right,
      aggregation_type=aggregation_type,
      class_weights=class_weights)
  histogram_key = computations[-1].keys[-1]

  def result(
      metrics: Dict[metric_types.MetricKey, Any]
  ) -> Dict[metric_types.MetricKey, Any]:
    thresholds = [
        left + i * (right - left) / num_buckets for i in range(num_buckets + 1)
    ]
    thresholds = [float('-inf')] + thresholds
    histogram = calibration_histogram.rebin(
        thresholds, metrics[histogram_key], left=left, right=right)
    return {key: _to_proto(thresholds, histogram)}

  derived_computation = metric_types.DerivedMetricComputation(
      keys=[key], result=result)
  computations.append(derived_computation)
  return computations
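# A worked sketch of the rebinned plot thresholds computed in `result` above,
# assuming left=0.0, right=1.0, and num_buckets=10 (names are illustrative):
sketch_plot_thresholds = [0.0 + i * (1.0 - 0.0) / 10 for i in range(10 + 1)]
sketch_plot_thresholds = [float('-inf')] + sketch_plot_thresholds
# -> [-inf, 0.0, 0.1, ..., 0.9, 1.0]: the 10 requested buckets plus two edge
# buckets (below `left` and at/above `right`), matching the 12 buckets
# asserted in the check_plots test later in this section.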
def calibration_histogram(
    num_buckets: Optional[int] = None,
    left: Optional[float] = None,
    right: Optional[float] = None,
    name: Optional[Text] = None,
    eval_config: Optional[config.EvalConfig] = None,
    model_name: Text = '',
    output_name: Text = '',
    sub_key: Optional[metric_types.SubKey] = None,
    aggregation_type: Optional[metric_types.AggregationType] = None,
    class_weights: Optional[Dict[int, float]] = None
) -> metric_types.MetricComputations:
  """Returns metric computations for calibration histogram.

  Args:
    num_buckets: Number of buckets to use. Note that the actual number of
      buckets will be num_buckets + 2 to account for the edge cases.
    left: Start of predictions interval.
    right: End of predictions interval.
    name: Metric name.
    eval_config: Eval config.
    model_name: Optional model name (if multi-model evaluation).
    output_name: Optional output name (if multi-output model type).
    sub_key: Optional sub key.
    aggregation_type: Optional aggregation type.
    class_weights: Optional class weights to apply to multi-class /
      multi-label labels and predictions prior to flattening (when micro
      averaging is used).

  Returns:
    MetricComputations for computing the histogram(s).
  """
  if num_buckets is None:
    num_buckets = DEFAULT_NUM_BUCKETS
  if left is None:
    left = 0.0
  if right is None:
    right = 1.0
  if name is None:
    name = '{}_{}'.format(CALIBRATION_HISTOGRAM_NAME, num_buckets)
  key = metric_types.PlotKey(
      name=name,
      model_name=model_name,
      output_name=output_name,
      sub_key=sub_key)
  return [
      metric_types.MetricComputation(
          keys=[key],
          preprocessor=None,
          combiner=_CalibrationHistogramCombiner(
              key=key,
              eval_config=eval_config,
              aggregation_type=aggregation_type,
              class_weights=class_weights,
              num_buckets=num_buckets,
              left=left,
              right=right))
  ]
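# A hedged note on the bucket ids used by the histogram, inferred from the
# check_result tests elsewhere in this section rather than from this function
# body: with the defaults left=0.0, right=1.0, and num_buckets=10000, a
# prediction p in [0, 1) appears to land in bucket int(p * num_buckets) + 1
# (e.g. 0.2 -> 2001, 0.5 -> 5001), predictions below `left` go to bucket 0,
# and predictions at or above `right` go to bucket num_buckets + 1 (10001).
# This accounts for the num_buckets + 2 total buckets noted in the docstring.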
def check_result(got):
  try:
    self.assertLen(got, 1)
    got_slice_key, got_plots = got[0]
    self.assertEqual(got_slice_key, ())
    self.assertLen(got_plots, 1)
    key = metric_types.PlotKey(
        name='_calibration_histogram_10000',
        sub_key=metric_types.SubKey(top_k=2),
        example_weighted=True)
    self.assertIn(key, got_plots)
    got_histogram = got_plots[key]
    self.assertLen(got_histogram, 5)
    self.assertEqual(
        got_histogram[0],
        calibration_histogram.Bucket(
            bucket_id=0,
            weighted_labels=3.0 + 4.0,
            weighted_predictions=(2 * 1.0 * float('-inf') +
                                  2 * 2.0 * float('-inf') +
                                  2 * 3.0 * float('-inf') +
                                  2 * 4.0 * float('-inf') + -0.1 * 4.0),
            weighted_examples=(1.0 * 2.0 + 2.0 * 2.0 + 3.0 * 2.0 +
                               4.0 * 3.0)))
    self.assertEqual(
        got_histogram[1],
        calibration_histogram.Bucket(
            bucket_id=2001,
            weighted_labels=0.0 + 0.0,
            weighted_predictions=0.2 + 3 * 0.2,
            weighted_examples=1.0 + 3.0))
    self.assertEqual(
        got_histogram[2],
        calibration_histogram.Bucket(
            bucket_id=5001,
            weighted_labels=1.0 + 0.0 * 3.0,
            weighted_predictions=0.5 * 1.0 + 0.5 * 3.0,
            weighted_examples=1.0 + 3.0))
    self.assertEqual(
        got_histogram[3],
        calibration_histogram.Bucket(
            bucket_id=8001,
            weighted_labels=0.0 * 2.0 + 1.0 * 2.0,
            weighted_predictions=0.8 * 2.0 + 0.8 * 2.0,
            weighted_examples=2.0 + 2.0))
    self.assertEqual(
        got_histogram[4],
        calibration_histogram.Bucket(
            bucket_id=10001,
            weighted_labels=0.0 * 4.0,
            weighted_predictions=1.1 * 4.0,
            weighted_examples=4.0))
  except AssertionError as err:
    raise util.BeamAssertException(err)
def _confusion_matrix_plot(
    num_thresholds: int = DEFAULT_NUM_THRESHOLDS,
    name: Text = CONFUSION_MATRIX_PLOT_NAME,
    eval_config: Optional[config.EvalConfig] = None,
    model_name: Text = '',
    output_name: Text = '',
    sub_key: Optional[metric_types.SubKey] = None,
    aggregation_type: Optional[metric_types.AggregationType] = None,
    class_weights: Optional[Dict[int, float]] = None
) -> metric_types.MetricComputations:
  """Returns metric computations for confusion matrix plots."""
  key = metric_types.PlotKey(
      name=name,
      model_name=model_name,
      output_name=output_name,
      sub_key=sub_key)
  # The interpolation strategy used here matches how the legacy post export
  # metrics calculated their plots.
  thresholds = [i * 1.0 / num_thresholds for i in range(0, num_thresholds + 1)]
  thresholds = [-1e-6] + thresholds
  # Make sure matrices are calculated.
  matrices_computations = binary_confusion_matrices.binary_confusion_matrices(
      # Use a custom name since we have a custom interpolation strategy which
      # would cause the default naming used by the binary confusion matrix to
      # be very long.
      name=(binary_confusion_matrices.BINARY_CONFUSION_MATRICES_NAME + '_' +
            name),
      eval_config=eval_config,
      model_name=model_name,
      output_name=output_name,
      sub_key=sub_key,
      aggregation_type=aggregation_type,
      class_weights=class_weights,
      thresholds=thresholds)
  matrices_key = matrices_computations[-1].keys[-1]

  def result(
      metrics: Dict[metric_types.MetricKey, Any]
  ) -> Dict[metric_types.MetricKey,
            metrics_for_slice_pb2.ConfusionMatrixAtThresholds]:
    return {
        key:
            confusion_matrix_metrics.to_proto(thresholds,
                                              metrics[matrices_key])
    }

  derived_computation = metric_types.DerivedMetricComputation(
      keys=[key], result=result)
  computations = matrices_computations
  computations.append(derived_computation)
  return computations
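# A worked sketch of the interpolation thresholds above for num_thresholds=4
# (illustrative names only):
sketch_cm_thresholds = [i * 1.0 / 4 for i in range(0, 4 + 1)]
sketch_cm_thresholds = [-1e-6] + sketch_cm_thresholds
assert sketch_cm_thresholds == [-1e-6, 0.0, 0.25, 0.5, 0.75, 1.0]
# The leading -1e-6 entry yields a matrix in which every example counts as
# predicted positive; these are exactly the threshold values seen in the
# auc_plot matrices asserted in the check_result test earlier in this section.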
def check_plots(got):
  try:
    self.assertLen(got, 1)
    got_slice_key, got_plots = got[0]
    self.assertEqual(got_slice_key, ())
    plot_key = metric_types.PlotKey('calibration_plot')
    self.assertIn(plot_key, got_plots)
    # 10 buckets + 2 for edge cases
    self.assertLen(got_plots[plot_key].buckets, 12)
  except AssertionError as err:
    raise util.BeamAssertException(err)
def check_result(got):
  try:
    self.assertLen(got, 1)
    got_slice_key, got_plots = got[0]
    self.assertEqual(got_slice_key, ())
    self.assertLen(got_plots, 1)
    key = metric_types.PlotKey(
        name='_calibration_histogram_10000',
        sub_key=metric_types.SubKey(k=2),
        example_weighted=True)
    self.assertIn(key, got_plots)
    got_histogram = got_plots[key]
    self.assertLen(got_histogram, 5)
    self.assertEqual(
        got_histogram[0],
        calibration_histogram.Bucket(
            bucket_id=0,
            weighted_labels=0.0 * 4.0,
            weighted_predictions=-0.2 * 4.0,
            weighted_examples=4.0))
    self.assertEqual(
        got_histogram[1],
        calibration_histogram.Bucket(
            bucket_id=1001,
            weighted_labels=1.0 + 7 * 1.0,
            weighted_predictions=0.1 + 7 * 0.1,
            weighted_examples=1.0 + 7.0))
    self.assertEqual(
        got_histogram[2],
        calibration_histogram.Bucket(
            bucket_id=4001,
            weighted_labels=1.0 * 3.0 + 0.0 * 5.0,
            weighted_predictions=0.4 * 3.0 + 0.4 * 5.0,
            weighted_examples=3.0 + 5.0))
    self.assertEqual(
        got_histogram[3],
        calibration_histogram.Bucket(
            bucket_id=7001,
            weighted_labels=0.0 * 2.0 + 0.0 * 6.0,
            weighted_predictions=0.7 * 2.0 + 0.7 * 6.0,
            weighted_examples=2.0 + 6.0))
    self.assertEqual(
        got_histogram[4],
        calibration_histogram.Bucket(
            bucket_id=10001,
            weighted_labels=0.0 * 8.0,
            weighted_predictions=1.05 * 8.0,
            weighted_examples=8.0))
  except AssertionError as err:
    raise util.BeamAssertException(err)
def check_result(got):
  try:
    self.assertLen(got, 1)
    got_slice_key, got_plots = got[0]
    self.assertEqual(got_slice_key, ())
    self.assertLen(got_plots, 1)
    key = metric_types.PlotKey('_calibration_histogram_10000')
    self.assertIn(key, got_plots)
    got_histogram = got_plots[key]
    self.assertLen(got_histogram, 5)
    self.assertEqual(
        got_histogram[0],
        calibration_histogram.Bucket(
            bucket_id=0,
            weighted_labels=1.0 * 4.0,
            weighted_predictions=-0.1 * 4.0,
            weighted_examples=4.0))
    self.assertEqual(
        got_histogram[1],
        calibration_histogram.Bucket(
            bucket_id=2001,
            weighted_labels=0.0 + 0.0,
            weighted_predictions=0.2 + 7 * 0.2,
            weighted_examples=1.0 + 7.0))
    self.assertEqual(
        got_histogram[2],
        calibration_histogram.Bucket(
            bucket_id=5001,
            weighted_labels=1.0 * 5.0,
            weighted_predictions=0.5 * 3.0 + 0.5 * 5.0,
            weighted_examples=3.0 + 5.0))
    self.assertEqual(
        got_histogram[3],
        calibration_histogram.Bucket(
            bucket_id=8001,
            weighted_labels=1.0 * 2.0 + 1.0 * 6.0,
            weighted_predictions=0.8 * 2.0 + 0.8 * 6.0,
            weighted_examples=2.0 + 6.0))
    self.assertEqual(
        got_histogram[4],
        calibration_histogram.Bucket(
            bucket_id=10001,
            weighted_labels=1.0 * 8.0,
            weighted_predictions=1.1 * 8.0,
            weighted_examples=8.0))
  except AssertionError as err:
    raise util.BeamAssertException(err)
def _auc_plot(
    num_thresholds: int = DEFAULT_NUM_THRESHOLDS,
    name: Text = AUC_PLOT_NAME,
    eval_config: Optional[config.EvalConfig] = None,
    model_name: Text = '',
    output_name: Text = '',
    sub_key: Optional[metric_types.SubKey] = None,
    class_weights: Optional[Dict[int, float]] = None
) -> metric_types.MetricComputations:
  """Returns metric computations for AUC plots."""
  key = metric_types.PlotKey(
      name=name,
      model_name=model_name,
      output_name=output_name,
      sub_key=sub_key)
  # The interpolation strategy used here matches how the legacy post export
  # metrics calculated their plots.
  thresholds = [i * 1.0 / num_thresholds for i in range(0, num_thresholds + 1)]
  thresholds = [-1e-6] + thresholds
  # Make sure matrices are calculated.
  matrices_computations = binary_confusion_matrices.binary_confusion_matrices(
      eval_config=eval_config,
      model_name=model_name,
      output_name=output_name,
      sub_key=sub_key,
      class_weights=class_weights,
      thresholds=thresholds)
  matrices_key = matrices_computations[-1].keys[-1]

  def result(
      metrics: Dict[metric_types.MetricKey, Any]
  ) -> Dict[metric_types.MetricKey,
            metrics_for_slice_pb2.ConfusionMatrixAtThresholds]:
    return {
        key:
            confusion_matrix_at_thresholds.to_proto(thresholds,
                                                    metrics[matrices_key])
    }

  derived_computation = metric_types.DerivedMetricComputation(
      keys=[key], result=result)
  computations = matrices_computations
  computations.append(derived_computation)
  return computations
def _calibration_plot(
    num_buckets: int = DEFAULT_NUM_BUCKETS,
    left: float = 0.0,
    right: float = 1.0,
    name: Text = CALIBRATION_PLOT_NAME,
    eval_config: Optional[config.EvalConfig] = None,
    model_name: Text = '',
    output_name: Text = '',
    sub_key: Optional[metric_types.SubKey] = None
) -> metric_types.MetricComputations:
  """Returns metric computations for calibration plot."""
  key = metric_types.PlotKey(
      name=name,
      model_name=model_name,
      output_name=output_name,
      sub_key=sub_key)
  # Make sure the calibration histogram is calculated. Note that we are using
  # the default number of buckets assigned to the histogram instead of the
  # value used for the plots, just in case the computation is shared with
  # other metrics and plots that need higher precision. It will be
  # downsampled later.
  computations = calibration_histogram.calibration_histogram(
      eval_config=eval_config,
      model_name=model_name,
      output_name=output_name,
      sub_key=sub_key,
      left=left,
      right=right)
  histogram_key = computations[-1].keys[-1]

  def result(
      metrics: Dict[metric_types.MetricKey, Any]
  ) -> Dict[metric_types.MetricKey, Any]:
    thresholds = [
        left + i * (right - left) / num_buckets for i in range(num_buckets + 1)
    ]
    thresholds = [float('-inf')] + thresholds
    histogram = calibration_histogram.rebin(
        thresholds, metrics[histogram_key], left=left, right=right)
    return {key: _to_proto(thresholds, histogram)}

  derived_computation = metric_types.DerivedMetricComputation(
      keys=[key], result=result)
  computations.append(derived_computation)
  return computations
def _multi_class_confusion_matrix_at_thresholds(
    thresholds: Optional[List[float]] = None,
    name: Text = MULTI_CLASS_CONFUSION_MATRIX_AT_THRESHOLDS_NAME,
    eval_config: Optional[config.EvalConfig] = None,
    model_name: Text = '',
    output_name: Text = '',
) -> metric_types.MetricComputations:
  """Returns computations for multi-class confusion matrix at thresholds."""
  key = metric_types.PlotKey(
      name=name, model_name=model_name, output_name=output_name)
  return [
      metric_types.MetricComputation(
          keys=[key],
          preprocessor=None,
          combiner=_MultiClassConfusionMatrixAtThresholdsCombiner(
              key=key, eval_config=eval_config, thresholds=thresholds))
  ]
def _multi_class_confusion_matrix_plot(
    thresholds: Optional[List[float]] = None,
    num_thresholds: Optional[int] = None,
    name: str = MULTI_CLASS_CONFUSION_MATRIX_PLOT_NAME,
    eval_config: Optional[config_pb2.EvalConfig] = None,
    model_name: str = '',
    output_name: str = '',
    example_weighted: bool = False) -> metric_types.MetricComputations:
  """Returns computations for multi-class confusion matrix plot."""
  if num_thresholds is None and thresholds is None:
    thresholds = [0.0]
  key = metric_types.PlotKey(
      name=name,
      model_name=model_name,
      output_name=output_name,
      example_weighted=example_weighted)
  # Make sure matrices are calculated.
  matrices_computations = (
      multi_class_confusion_matrix_metrics.multi_class_confusion_matrices(
          thresholds=thresholds,
          num_thresholds=num_thresholds,
          eval_config=eval_config,
          model_name=model_name,
          output_name=output_name,
          example_weighted=example_weighted))
  matrices_key = matrices_computations[-1].keys[-1]

  def result(
      metrics: Dict[metric_types.MetricKey,
                    multi_class_confusion_matrix_metrics.Matrices]
  ) -> Dict[metric_types.PlotKey,
            metrics_for_slice_pb2.MultiClassConfusionMatrixAtThresholds]:
    return {
        key:
            metrics[matrices_key].to_proto()
            .multi_class_confusion_matrix_at_thresholds
    }

  derived_computation = metric_types.DerivedMetricComputation(
      keys=[key], result=result)
  computations = matrices_computations
  computations.append(derived_computation)
  return computations
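# Note on the default above: when neither thresholds nor num_thresholds is
# set, a single threshold of 0.0 is used, so (assuming the usual argmax
# convention for multi-class confusion matrices) every example with any
# positive score is counted under its highest-scoring class. This matches the
# single `matrices { threshold: 0.0 ... }` block asserted in the
# multi_class_confusion_matrix_plot test near the top of this section.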
def _multi_label_confusion_matrix_plot(
    thresholds: Optional[List[float]] = None,
    name: Text = MULTI_LABEL_CONFUSION_MATRIX_PLOT_NAME,
    eval_config: Optional[config.EvalConfig] = None,
    model_name: Text = '',
    output_name: Text = '',
) -> metric_types.MetricComputations:
  """Returns computations for multi-label confusion matrix at thresholds."""
  key = metric_types.PlotKey(
      name=name, model_name=model_name, output_name=output_name)
  return [
      metric_types.MetricComputation(
          keys=[key],
          preprocessor=None,
          combiner=_MultiLabelConfusionMatrixPlotCombiner(
              key=key, eval_config=eval_config, thresholds=thresholds))
  ]
def check_result(got):
  try:
    self.assertLen(got, 1)
    got_slice_key, got_plots = got[0]
    self.assertEqual(got_slice_key, ())
    self.assertLen(got_plots, 1)
    key = metric_types.PlotKey(name='multi_label_confusion_matrix_plot')
    got_matrix = got_plots[key]
    self.assertProtoEquals(
        """
        matrices {
          threshold: 0.5
          entries {
            actual_class_id: 0
            predicted_class_id: 0
            false_negatives: 0.0
            true_negatives: 0.0
            false_positives: 0.0
            true_positives: 2.0
          }
          entries {
            actual_class_id: 0
            predicted_class_id: 1
            false_negatives: 1.0
            true_negatives: 1.0
            false_positives: 0.0
            true_positives: 0.0
          }
          entries {
            actual_class_id: 0
            predicted_class_id: 2
            false_negatives: 0.0
            true_negatives: 2.0
            false_positives: 0.0
            true_positives: 0.0
          }
          entries {
            actual_class_id: 1
            predicted_class_id: 0
            false_negatives: 0.0
            true_negatives: 1.0
            false_positives: 0.0
            true_positives: 1.0
          }
          entries {
            actual_class_id: 1
            predicted_class_id: 1
            false_negatives: 1.0
            true_negatives: 0.0
            false_positives: 0.0
            true_positives: 1.0
          }
          entries {
            actual_class_id: 1
            predicted_class_id: 2
            false_negatives: 0.0
            false_positives: 0.0
            true_negatives: 2.0
            true_positives: 0.0
          }
        }
        """, got_matrix)
  except AssertionError as err:
    raise util.BeamAssertException(err)
def check_result(got):
  try:
    self.assertLen(got, 1)
    got_slice_key, got_plots = got[0]
    self.assertEqual(got_slice_key, ())
    self.assertLen(got_plots, 1)
    key = metric_types.PlotKey(name='calibration_plot')
    self.assertIn(key, got_plots)
    got_plot = got_plots[key]
    self.assertProtoEquals(
        """
        buckets {
          lower_threshold_inclusive: -inf
          upper_threshold_exclusive: 0.0
          total_weighted_label { value: 4.0 }
          total_weighted_refined_prediction { value: -0.4 }
          num_weighted_examples { value: 4.0 }
        }
        buckets {
          lower_threshold_inclusive: 0.0
          upper_threshold_exclusive: 0.1
          total_weighted_label { }
          total_weighted_refined_prediction { }
          num_weighted_examples { }
        }
        buckets {
          lower_threshold_inclusive: 0.1
          upper_threshold_exclusive: 0.2
          total_weighted_label { }
          total_weighted_refined_prediction { }
          num_weighted_examples { }
        }
        buckets {
          lower_threshold_inclusive: 0.2
          upper_threshold_exclusive: 0.3
          total_weighted_label { }
          total_weighted_refined_prediction { value: 1.6 }
          num_weighted_examples { value: 8.0 }
        }
        buckets {
          lower_threshold_inclusive: 0.3
          upper_threshold_exclusive: 0.4
          total_weighted_label { }
          total_weighted_refined_prediction { }
          num_weighted_examples { }
        }
        buckets {
          lower_threshold_inclusive: 0.4
          upper_threshold_exclusive: 0.5
          total_weighted_label { }
          total_weighted_refined_prediction { }
          num_weighted_examples { }
        }
        buckets {
          lower_threshold_inclusive: 0.5
          upper_threshold_exclusive: 0.6
          total_weighted_label { value: 5.0 }
          total_weighted_refined_prediction { value: 4.0 }
          num_weighted_examples { value: 8.0 }
        }
        buckets {
          lower_threshold_inclusive: 0.6
          upper_threshold_exclusive: 0.7
          total_weighted_label { }
          total_weighted_refined_prediction { }
          num_weighted_examples { }
        }
        buckets {
          lower_threshold_inclusive: 0.7
          upper_threshold_exclusive: 0.8
          total_weighted_label { }
          total_weighted_refined_prediction { }
          num_weighted_examples { }
        }
        buckets {
          lower_threshold_inclusive: 0.8
          upper_threshold_exclusive: 0.9
          total_weighted_label { value: 8.0 }
          total_weighted_refined_prediction { value: 6.4 }
          num_weighted_examples { value: 8.0 }
        }
        buckets {
          lower_threshold_inclusive: 0.9
          upper_threshold_exclusive: 1.0
          total_weighted_label { }
          total_weighted_refined_prediction { }
          num_weighted_examples { }
        }
        buckets {
          lower_threshold_inclusive: 1.0
          upper_threshold_exclusive: inf
          total_weighted_label { value: 8.0 }
          total_weighted_refined_prediction { value: 8.8 }
          num_weighted_examples { value: 8.0 }
        }
        """, got_plot)
  except AssertionError as err:
    raise util.BeamAssertException(err)
def testSerializePlots(self):
  slice_key = _make_slice_key('fruit', 'apple')
  plot_key = metric_types.PlotKey(
      name='calibration_plot', output_name='output_name')
  calibration_plot = text_format.Parse(
      """
      buckets {
        lower_threshold_inclusive: -inf
        upper_threshold_exclusive: 0.0
        num_weighted_examples { value: 0.0 }
        total_weighted_label { value: 0.0 }
        total_weighted_refined_prediction { value: 0.0 }
      }
      buckets {
        lower_threshold_inclusive: 0.0
        upper_threshold_exclusive: 0.5
        num_weighted_examples { value: 1.0 }
        total_weighted_label { value: 1.0 }
        total_weighted_refined_prediction { value: 0.3 }
      }
      buckets {
        lower_threshold_inclusive: 0.5
        upper_threshold_exclusive: 1.0
        num_weighted_examples { value: 1.0 }
        total_weighted_label { value: 0.0 }
        total_weighted_refined_prediction { value: 0.7 }
      }
      buckets {
        lower_threshold_inclusive: 1.0
        upper_threshold_exclusive: inf
        num_weighted_examples { value: 0.0 }
        total_weighted_label { value: 0.0 }
        total_weighted_refined_prediction { value: 0.0 }
      }
      """, metrics_for_slice_pb2.CalibrationHistogramBuckets())
  expected_plots_for_slice = text_format.Parse(
      """
      slice_key {
        single_slice_keys {
          column: 'fruit'
          bytes_value: 'apple'
        }
      }
      plot_keys_and_values {
        key {
          output_name: "output_name"
        }
        value {
          calibration_histogram_buckets {
            buckets {
              lower_threshold_inclusive: -inf
              upper_threshold_exclusive: 0.0
              num_weighted_examples { value: 0.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.0 }
            }
            buckets {
              lower_threshold_inclusive: 0.0
              upper_threshold_exclusive: 0.5
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 1.0 }
              total_weighted_refined_prediction { value: 0.3 }
            }
            buckets {
              lower_threshold_inclusive: 0.5
              upper_threshold_exclusive: 1.0
              num_weighted_examples { value: 1.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.7 }
            }
            buckets {
              lower_threshold_inclusive: 1.0
              upper_threshold_exclusive: inf
              num_weighted_examples { value: 0.0 }
              total_weighted_label { value: 0.0 }
              total_weighted_refined_prediction { value: 0.0 }
            }
          }
        }
      }
      """, metrics_for_slice_pb2.PlotsForSlice())
  got = metrics_and_plots_serialization._serialize_plots(
      (slice_key, {plot_key: calibration_plot}), None)
  self.assertProtoEquals(
      expected_plots_for_slice,
      metrics_for_slice_pb2.PlotsForSlice.FromString(got))
def calibration_histogram(
    num_buckets: Optional[int] = None,
    left: Optional[float] = None,
    right: Optional[float] = None,
    name: Optional[str] = None,
    eval_config: Optional[config_pb2.EvalConfig] = None,
    model_name: str = '',
    output_name: str = '',
    sub_key: Optional[metric_types.SubKey] = None,
    aggregation_type: Optional[metric_types.AggregationType] = None,
    class_weights: Optional[Dict[int, float]] = None,
    example_weighted: bool = False,
    prediction_based_bucketing: bool = True,
    fractional_labels: Optional[bool] = None,
) -> metric_types.MetricComputations:
  """Returns metric computations for calibration histogram.

  Args:
    num_buckets: Number of buckets to use. Note that the actual number of
      buckets will be num_buckets + 2 to account for the edge cases.
    left: Start of predictions interval.
    right: End of predictions interval.
    name: Metric name.
    eval_config: Eval config.
    model_name: Optional model name (if multi-model evaluation).
    output_name: Optional output name (if multi-output model type).
    sub_key: Optional sub key.
    aggregation_type: Optional aggregation type.
    class_weights: Optional class weights to apply to multi-class /
      multi-label labels and predictions prior to flattening (when micro
      averaging is used).
    example_weighted: True if example weights should be applied.
    prediction_based_bucketing: If true, create buckets based on predictions,
      else use labels to perform bucketing.
    fractional_labels: If true, each incoming tuple of (label, prediction, and
      example weight) will be split into two tuples as follows (where l, p, w
      represent the resulting label, prediction, and example weight values):
        (1) l = 0.0, p = prediction, and w = example_weight * (1.0 - label)
        (2) l = 1.0, p = prediction, and w = example_weight * label
      If enabled, an exception will be raised if labels are not within [0, 1].
      The implementation is such that tuples associated with a weight of zero
      are not yielded. This means it is safe to enable fractional_labels even
      when the labels only take on the values of 0.0 or 1.0.

  Returns:
    MetricComputations for computing the histogram(s).
  """
  if num_buckets is None:
    num_buckets = DEFAULT_NUM_BUCKETS
  if left is None:
    left = 0.0
  if right is None:
    right = 1.0
  if fractional_labels is None:
    fractional_labels = (left == 0.0 and right == 1.0)
  if name is None:
    name = f'{CALIBRATION_HISTOGRAM_NAME}_{num_buckets}'
  key = metric_types.PlotKey(
      name=name,
      model_name=model_name,
      output_name=output_name,
      sub_key=sub_key,
      example_weighted=example_weighted)
  return [
      metric_types.MetricComputation(
          keys=[key],
          preprocessor=None,
          combiner=_CalibrationHistogramCombiner(
              key=key,
              eval_config=eval_config,
              aggregation_type=aggregation_type,
              class_weights=class_weights,
              example_weighted=example_weighted,
              num_buckets=num_buckets,
              left=left,
              right=right,
              prediction_based_bucketing=prediction_based_bucketing,
              fractional_labels=fractional_labels))
  ]
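# A worked sketch of the fractional_labels split described in the docstring
# above, using illustrative values only:
sketch_label, sketch_pred, sketch_weight = 0.25, 0.6, 2.0
sketch_negative = (0.0, sketch_pred, sketch_weight * (1.0 - sketch_label))
sketch_positive = (1.0, sketch_pred, sketch_weight * sketch_label)
assert sketch_negative == (0.0, 0.6, 1.5)
assert sketch_positive == (1.0, 0.6, 0.5)
# A hard label of 0.0 or 1.0 would give one of the two tuples a zero weight,
# and, per the docstring, that tuple is simply not yielded.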