def testMetricKeysToSkipForConfidenceIntervals(self):
  metrics_specs = [
      config_pb2.MetricsSpec(
          metrics=[
              config_pb2.MetricConfig(
                  class_name='ExampleCount',
                  config=json.dumps({'name': 'example_count'}),
                  threshold=config_pb2.MetricThreshold(
                      value_threshold=config_pb2.GenericValueThreshold())),
              config_pb2.MetricConfig(
                  class_name='MeanLabel',
                  config=json.dumps({'name': 'mean_label'}),
                  threshold=config_pb2.MetricThreshold(
                      change_threshold=config_pb2.GenericChangeThreshold())),
              config_pb2.MetricConfig(
                  class_name='MeanSquaredError',
                  config=json.dumps({'name': 'mse'}),
                  threshold=config_pb2.MetricThreshold(
                      change_threshold=config_pb2.GenericChangeThreshold()))
          ],
          model_names=['model_name1', 'model_name2'],
          output_names=['output_name1', 'output_name2']),
  ]
  metrics_specs += metric_specs.specs_from_metrics(
      [tf.keras.metrics.MeanSquaredError('mse')],
      model_names=['model_name1', 'model_name2'])
  keys = metric_specs.metric_keys_to_skip_for_confidence_intervals(
      metrics_specs, eval_config=config_pb2.EvalConfig())
  self.assertLen(keys, 8)
  self.assertIn(
      metric_types.MetricKey(
          name='example_count',
          model_name='model_name1',
          output_name='output_name1'), keys)
  self.assertIn(
      metric_types.MetricKey(
          name='example_count',
          model_name='model_name1',
          output_name='output_name2'), keys)
  self.assertIn(
      metric_types.MetricKey(
          name='example_count',
          model_name='model_name2',
          output_name='output_name1'), keys)
  self.assertIn(
      metric_types.MetricKey(
          name='example_count',
          model_name='model_name2',
          output_name='output_name2'), keys)
  self.assertIn(
      metric_types.MetricKey(name='example_count', model_name='model_name1'),
      keys)
  self.assertIn(
      metric_types.MetricKey(
          name='weighted_example_count',
          model_name='model_name1',
          example_weighted=True), keys)
  self.assertIn(
      metric_types.MetricKey(name='example_count', model_name='model_name2'),
      keys)
  self.assertIn(
      metric_types.MetricKey(
          name='weighted_example_count',
          model_name='model_name2',
          example_weighted=True), keys)
def _get_confidence_interval_params(
    eval_config: config_pb2.EvalConfig,
    metrics_specs: Iterable[config_pb2.MetricsSpec]
) -> _ConfidenceIntervalParams:
  """Helper method for extracting confidence interval info from configs.

  Args:
    eval_config: The eval_config.
    metrics_specs: The metrics_specs containing either all metrics, or the ones
      which share a query key.

  Returns:
    A _ConfidenceIntervalParams object containing the number of jackknife
    samples to use for computing a jackknife confidence interval, the number of
    bootstrap samples to use for computing Poisson bootstrap confidence
    intervals, and the set of metric keys which should not have confidence
    intervals displayed in the output.
  """
  skip_ci_metric_keys = (
      metric_specs.metric_keys_to_skip_for_confidence_intervals(metrics_specs))
  num_jackknife_samples = 0
  num_bootstrap_samples = 0
  ci_method = eval_config.options.confidence_intervals.method
  if eval_config.options.compute_confidence_intervals.value:
    if ci_method == config_pb2.ConfidenceIntervalOptions.JACKKNIFE:
      num_jackknife_samples = _DEFAULT_NUM_JACKKNIFE_BUCKETS
    elif ci_method == config_pb2.ConfidenceIntervalOptions.POISSON_BOOTSTRAP:
      num_bootstrap_samples = _DEFAULT_NUM_BOOTSTRAP_SAMPLES
  return _ConfidenceIntervalParams(num_jackknife_samples,
                                   num_bootstrap_samples, skip_ci_metric_keys)
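# The helper above is driven entirely by the confidence-interval options on
# the EvalConfig: sampling counts stay at zero unless
# compute_confidence_intervals is set, and the method enum selects which
# count is populated. The function below is a minimal usage sketch, not an
# excerpt from the source; it reuses only names that already appear above
# (_get_confidence_interval_params, config_pb2, and the module constants),
# and the expected field values are inferred from the branch taken in the
# helper.
def _example_confidence_interval_params_usage():
  eval_config = config_pb2.EvalConfig()
  # compute_confidence_intervals is a BoolValue wrapper, hence `.value`.
  eval_config.options.compute_confidence_intervals.value = True
  eval_config.options.confidence_intervals.method = (
      config_pb2.ConfidenceIntervalOptions.JACKKNIFE)
  params = _get_confidence_interval_params(
      eval_config, eval_config.metrics_specs)
  # With this config the JACKKNIFE branch is taken, so
  # params.num_jackknife_samples == _DEFAULT_NUM_JACKKNIFE_BUCKETS and
  # params.num_bootstrap_samples == 0.
  return params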
def testMetricKeysToSkipForConfidenceIntervals(self):
  metrics_specs = [
      config.MetricsSpec(
          metrics=[
              config.MetricConfig(
                  class_name='ExampleCount',
                  config=json.dumps({'name': 'example_count'}),
                  threshold=config.MetricThreshold(
                      value_threshold=config.GenericValueThreshold())),
              config.MetricConfig(
                  class_name='MeanLabel',
                  config=json.dumps({'name': 'mean_label'}),
                  threshold=config.MetricThreshold(
                      change_threshold=config.GenericChangeThreshold())),
              config.MetricConfig(
                  class_name='MeanSquaredError',
                  config=json.dumps({'name': 'mse'}),
                  threshold=config.MetricThreshold(
                      change_threshold=config.GenericChangeThreshold()))
          ],
          # Model names and output_names should be ignored because
          # ExampleCount is model independent.
          model_names=['model_name1', 'model_name2'],
          output_names=['output_name1', 'output_name2']),
  ]
  metrics_specs += metric_specs.specs_from_metrics(
      [tf.keras.metrics.MeanSquaredError('mse')])
  keys = metric_specs.metric_keys_to_skip_for_confidence_intervals(
      metrics_specs)
  self.assertLen(keys, 1)
  self.assertIn(metric_types.MetricKey(name='example_count'), keys)