def populate_stats_and_pop(
    self, slice_key: slicer.SliceKeyType, combine_metrics: Dict[Text, Any],
    output_metrics: Dict[Text, metrics_pb2.MetricValue]) -> None:
  """Moves the subgroup / BPSN / BNSP AUC entries out of `combine_metrics`.

  These three AUC metrics are computed over all examples, so they are only
  meaningful for the overall slice (an empty, falsy `slice_key`).

  * Overall slice: each metric is handed to
    `post_export_metrics._populate_to_auc_bounded_value_and_pop`, which
    receives both `combine_metrics` and `output_metrics`.
  * Any other slice: the metric value and its lower/upper bound entries are
    removed from `combine_metrics`; `output_metrics` is not touched here.

  Args:
    slice_key: Key of the slice being processed; falsy for the overall slice.
    combine_metrics: Combined metric values, mutated in place.
    output_metrics: Destination metric protos, passed through to the helper
      for the overall slice only.

  Raises:
    KeyError: For a non-overall slice, if any of the expected entries is
      absent from `combine_metrics` (the pops carry no default).
  """
  auc_metric_names = (self._subgroup_auc_metric, self._bpsn_auc_metric,
                      self._bnsp_auc_metric)

  if not slice_key:
    # Overall slice: convert each metric and move it to the output.
    for auc_name in auc_metric_names:
      post_export_metrics._populate_to_auc_bounded_value_and_pop(
          combine_metrics, output_metrics, auc_name)
    return

  # Any other slice: discard bounds first, then the value itself.
  for auc_name in auc_metric_names:
    for bound_key_of in (metric_keys.lower_bound_key,
                         metric_keys.upper_bound_key):
      combine_metrics.pop(bound_key_of(auc_name))
    combine_metrics.pop(auc_name)
    def testSerializeMetrics(self):
        """Checks the serialized MetricsForSlice for a three-column slice.

        The plain 'accuracy' float is expected as a `double_value`, while the
        AUC and AUPRC values, together with their lower/upper bound entries,
        are expected as `bounded_value` protos with RIEMANN_SUM methodology.
        """
        auprc_key, auc_key = metric_keys.AUPRC, metric_keys.AUC

        key_for_slice = _make_slice_key(
            'age', 5, 'language', 'english', 'price', 0.3)
        metric_values = {
            'accuracy': 0.8,
            auprc_key: 0.1,
            metric_keys.lower_bound_key(auprc_key): 0.05,
            metric_keys.upper_bound_key(auprc_key): 0.17,
            auc_key: 0.2,
            metric_keys.lower_bound_key(auc_key): 0.1,
            metric_keys.upper_bound_key(auc_key): 0.3
        }

        # Text-format proto; $auc / $auprc are filled in with the real keys.
        expected_template = string.Template("""
        slice_key {
          single_slice_keys {
            column: 'age'
            int64_value: 5
          }
          single_slice_keys {
            column: 'language'
            bytes_value: 'english'
          }
          single_slice_keys {
            column: 'price'
            float_value: 0.3
          }
        }
        metrics {
          key: "accuracy"
          value {
            double_value {
              value: 0.8
            }
          }
        }
        metrics {
          key: "$auc"
          value {
            bounded_value {
              lower_bound {
                value: 0.1
              }
              upper_bound {
                value: 0.3
              }
              value {
                value: 0.2
              }
              methodology: RIEMANN_SUM
            }
          }
        }
        metrics {
          key: "$auprc"
          value {
            bounded_value {
              lower_bound {
                value: 0.05
              }
              upper_bound {
                value: 0.17
              }
              value {
                value: 0.1
              }
              methodology: RIEMANN_SUM
            }
          }
        }""")
        expected_proto = text_format.Parse(
            expected_template.substitute(auc=auc_key, auprc=auprc_key),
            metrics_for_slice_pb2.MetricsForSlice())

        metric_callbacks = [
            post_export_metrics.auc(),
            post_export_metrics.auc(curve='PR'),
        ]
        serialized = metrics_and_plots_evaluator._serialize_metrics(
            (key_for_slice, metric_values), metric_callbacks)

        # The evaluator returns bytes, so parse them back before comparing.
        actual_proto = metrics_for_slice_pb2.MetricsForSlice.FromString(
            serialized)
        self.assertProtoEquals(expected_proto, actual_proto)
# Example #3
# 0
  def testConvertSliceMetricsToProtoMetricsRanges(self):
    """Checks proto conversion when a metric carries a t-distribution value.

    'accuracy' is supplied as a `types.ValueWithTDistribution`; the expected
    proto holds it both as a POISSON_BOOTSTRAP `bounded_value` and as a
    `confidence_interval` with the t-distribution fields. AUC and AUPRC, with
    their lower/upper bound entries, are expected as RIEMANN_SUM
    `bounded_value` protos.
    """
    auprc_key, auc_key = metric_keys.AUPRC, metric_keys.AUC

    key_for_slice = _make_slice_key(
        'age', 5, 'language', 'english', 'price', 0.3)
    metric_values = {
        'accuracy': types.ValueWithTDistribution(0.8, 0.1, 9, 0.8),
        auprc_key: 0.1,
        metric_keys.lower_bound_key(auprc_key): 0.05,
        metric_keys.upper_bound_key(auprc_key): 0.17,
        auc_key: 0.2,
        metric_keys.lower_bound_key(auc_key): 0.1,
        metric_keys.upper_bound_key(auc_key): 0.3
    }

    # Text-format proto; $auc / $auprc are filled in with the real keys.
    expected_template = string.Template("""
        slice_key {
          single_slice_keys {
            column: 'age'
            int64_value: 5
          }
          single_slice_keys {
            column: 'language'
            bytes_value: 'english'
          }
          single_slice_keys {
            column: 'price'
            float_value: 0.3
          }
        }
        metrics {
          key: "accuracy"
          value {
            bounded_value {
              value {
                value: 0.8
              }
              lower_bound {
                value: 0.5737843
              }
              upper_bound {
                value: 1.0262157
              }
              methodology: POISSON_BOOTSTRAP
            }
            confidence_interval {
              lower_bound {
                value: 0.5737843
              }
              upper_bound {
                value: 1.0262157
              }
              t_distribution_value {
                sample_mean {
                  value: 0.8
                }
                sample_standard_deviation {
                  value: 0.1
                }
                sample_degrees_of_freedom {
                  value: 9
                }
                unsampled_value {
                  value: 0.8
                }
              }
            }
          }
        }
        metrics {
          key: "$auc"
          value {
            bounded_value {
              lower_bound {
                value: 0.1
              }
              upper_bound {
                value: 0.3
              }
              value {
                value: 0.2
              }
              methodology: RIEMANN_SUM
            }
          }
        }
        metrics {
          key: "$auprc"
          value {
            bounded_value {
              lower_bound {
                value: 0.05
              }
              upper_bound {
                value: 0.17
              }
              value {
                value: 0.1
              }
              methodology: RIEMANN_SUM
            }
          }
        }""")
    expected_proto = text_format.Parse(
        expected_template.substitute(auc=auc_key, auprc=auprc_key),
        metrics_for_slice_pb2.MetricsForSlice())

    metric_callbacks = [
        post_export_metrics.auc(),
        post_export_metrics.auc(curve='PR'),
    ]
    actual_proto = (
        metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
            (key_for_slice, metric_values), metric_callbacks))
    self.assertProtoEquals(expected_proto, actual_proto)