def testMetricKeysToSkipForConfidenceIntervals(self):
     metrics_specs = [
         config_pb2.MetricsSpec(metrics=[
             config_pb2.MetricConfig(
                 class_name='ExampleCount',
                 config=json.dumps({'name': 'example_count'}),
                 threshold=config_pb2.MetricThreshold(
                     value_threshold=config_pb2.GenericValueThreshold())),
             config_pb2.MetricConfig(
                 class_name='MeanLabel',
                 config=json.dumps({'name': 'mean_label'}),
                 threshold=config_pb2.MetricThreshold(
                     change_threshold=config_pb2.GenericChangeThreshold())),
             config_pb2.MetricConfig(
                 class_name='MeanSquaredError',
                 config=json.dumps({'name': 'mse'}),
                 threshold=config_pb2.MetricThreshold(
                     change_threshold=config_pb2.GenericChangeThreshold()))
         ],
                                model_names=['model_name1', 'model_name2'],
                                output_names=[
                                    'output_name1', 'output_name2'
                                ]),
     ]
     metrics_specs += metric_specs.specs_from_metrics(
         [tf.keras.metrics.MeanSquaredError('mse')],
         model_names=['model_name1', 'model_name2'])
     keys = metric_specs.metric_keys_to_skip_for_confidence_intervals(
         metrics_specs, eval_config=config_pb2.EvalConfig())
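     # 8 keys are expected: example_count for each of the 2 model_names x
     # 2 output_names combinations in the explicit MetricsSpec above, plus the
     # per-model example_count and weighted_example_count keys added by
     # specs_from_metrics.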
     self.assertLen(keys, 8)
     self.assertIn(
         metric_types.MetricKey(name='example_count',
                                model_name='model_name1',
                                output_name='output_name1'), keys)
     self.assertIn(
         metric_types.MetricKey(name='example_count',
                                model_name='model_name1',
                                output_name='output_name2'), keys)
     self.assertIn(
         metric_types.MetricKey(name='example_count',
                                model_name='model_name2',
                                output_name='output_name1'), keys)
     self.assertIn(
         metric_types.MetricKey(name='example_count',
                                model_name='model_name2',
                                output_name='output_name2'), keys)
     self.assertIn(
         metric_types.MetricKey(name='example_count',
                                model_name='model_name1'), keys)
     self.assertIn(
         metric_types.MetricKey(name='weighted_example_count',
                                model_name='model_name1',
                                example_weighted=True), keys)
     self.assertIn(
         metric_types.MetricKey(name='example_count',
                                model_name='model_name2'), keys)
     self.assertIn(
         metric_types.MetricKey(name='weighted_example_count',
                                model_name='model_name2',
                                example_weighted=True), keys)
def _get_confidence_interval_params(
    eval_config: config_pb2.EvalConfig,
    metrics_specs: Iterable[config_pb2.MetricsSpec]
) -> _ConfidenceIntervalParams:
  """Helper method for extracting confidence interval info from configs.

  Args:
    eval_config: The eval_config.
    metrics_specs: The metrics_specs containing either all metrics, or the ones
      which share a query key.

  Returns:
    A _ConfidenceIntervalParams object containing the number of jackknife samples
    to use for computing a jackknife confidence interval, the number of
    bootstrap samples to use for computing Poisson bootstrap confidence
    intervals, and the set of metric keys which should not have confidence
    intervals displayed in the output.
  """
  skip_ci_metric_keys = (
      metric_specs.metric_keys_to_skip_for_confidence_intervals(metrics_specs))
  num_jackknife_samples = 0
  num_bootstrap_samples = 0
  ci_method = eval_config.options.confidence_intervals.method
  if eval_config.options.compute_confidence_intervals.value:
    if ci_method == config_pb2.ConfidenceIntervalOptions.JACKKNIFE:
      num_jackknife_samples = _DEFAULT_NUM_JACKKNIFE_BUCKETS
    elif ci_method == config_pb2.ConfidenceIntervalOptions.POISSON_BOOTSTRAP:
      num_bootstrap_samples = _DEFAULT_NUM_BOOTSTRAP_SAMPLES
  return _ConfidenceIntervalParams(num_jackknife_samples, num_bootstrap_samples,
                                   skip_ci_metric_keys)
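
# A minimal usage sketch of the helper above, assuming the module-level imports
# (tf, config_pb2, metric_specs) and private names referenced elsewhere in this
# listing; the variable names here are illustrative only.
eval_config = config_pb2.EvalConfig()
eval_config.options.compute_confidence_intervals.value = True
eval_config.options.confidence_intervals.method = (
    config_pb2.ConfidenceIntervalOptions.JACKKNIFE)
metrics_specs = metric_specs.specs_from_metrics(
    [tf.keras.metrics.MeanSquaredError('mse')])
ci_params = _get_confidence_interval_params(eval_config, metrics_specs)
# With this config, num_jackknife_samples is _DEFAULT_NUM_JACKKNIFE_BUCKETS,
# num_bootstrap_samples stays 0, and skip_ci_metric_keys holds the keys
# returned by metric_specs.metric_keys_to_skip_for_confidence_intervals.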
Example #3
def testMetricKeysToSkipForConfidenceIntervals(self):
   metrics_specs = [
       config.MetricsSpec(
           metrics=[
               config.MetricConfig(
                   class_name='ExampleCount',
                   config=json.dumps({'name': 'example_count'}),
                   threshold=config.MetricThreshold(
                       value_threshold=config.GenericValueThreshold())),
               config.MetricConfig(
                   class_name='MeanLabel',
                   config=json.dumps({'name': 'mean_label'}),
                   threshold=config.MetricThreshold(
                       change_threshold=config.GenericChangeThreshold())),
               config.MetricConfig(
                   class_name='MeanSquaredError',
                   config=json.dumps({'name': 'mse'}),
                   threshold=config.MetricThreshold(
                       change_threshold=config.GenericChangeThreshold()))
           ],
           # Model names and output_names should be ignored because
           # ExampleCount is model independent.
           model_names=['model_name1', 'model_name2'],
           output_names=['output_name1', 'output_name2']),
   ]
   metrics_specs += metric_specs.specs_from_metrics(
       [tf.keras.metrics.MeanSquaredError('mse')])
   keys = metric_specs.metric_keys_to_skip_for_confidence_intervals(
       metrics_specs)
   self.assertLen(keys, 1)
   self.assertIn(metric_types.MetricKey(name='example_count'), keys)