def example_count_specs(
    model_names: Optional[List[Text]] = None,
    output_names: Optional[List[Text]] = None,
    output_weights: Optional[Dict[Text, float]] = None,
    include_example_count: bool = True,
    include_weighted_example_count: bool = True
) -> List[config_pb2.MetricsSpec]:
  """Returns metric specs for example count and weighted example counts.

  Args:
    model_names: Optional list of model names (if multi-model evaluation).
    output_names: Optional list of output names (if multi-output model).
    output_weights: Optional output weights for creating overall metric
      aggregated across outputs (if multi-output model). If a weight is not
      provided for an output, its weight defaults to 0.0 (i.e. output ignored).
    include_example_count: True to add example_count metric.
    include_weighted_example_count: True to add weighted_example_count metric.
      A weighted example count will be added per output for multi-output
      models.
  """
  specs = []
  if include_example_count:
    metric_config = _serialize_tfma_metric(example_count.ExampleCount())
    specs.append(
        config_pb2.MetricsSpec(metrics=[metric_config],
                               model_names=model_names))
  if include_weighted_example_count:
    metric_config = _serialize_tfma_metric(
        weighted_example_count.WeightedExampleCount())
    specs.append(
        config_pb2.MetricsSpec(
            metrics=[metric_config],
            model_names=model_names,
            output_names=output_names,
            output_weights=output_weights))
  return specs
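# Illustrative sketch only (not part of the library code above): how
# example_count_specs might be called for a hypothetical two-model,
# two-output evaluation. The model/output names and weights below are
# assumptions made for this example, and the module-level imports above
# (config_pb2, etc.) are assumed to be in scope.
count_specs = example_count_specs(
    model_names=['candidate', 'baseline'],
    output_names=['output_1', 'output_2'],
    output_weights={'output_1': 1.0, 'output_2': 1.0})
# Per the function body above, this yields two specs: the first holds an
# ExampleCount config shared by both models; the second holds a
# WeightedExampleCount config that also carries the output names and weights.
assert len(count_specs) == 2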
def metrics_spec(
    self, me_metrics_spec: me_proto.MetricsSpec) -> config_pb2.MetricsSpec:
  """Converts an ME MetricsSpec into a TFMA MetricsSpec.

  Args:
    me_metrics_spec: Input ME MetricsSpec.

  Returns:
    TFMA MetricsSpec.
  """
  if not me_metrics_spec:
    return None
  tfma_metrics_spec = config_pb2.MetricsSpec()
  for metric_config in me_metrics_spec.metrics:
    tfma_metrics_spec.metrics.append(self._metric_config(metric_config))
  if me_metrics_spec.HasField('binarize'):
    tfma_metrics_spec.binarize.CopyFrom(
        self._binarization_options(me_metrics_spec.binarize))
  if me_metrics_spec.HasField('aggregate'):
    tfma_metrics_spec.aggregate.CopyFrom(
        self._aggregation_options(me_metrics_spec.aggregate))
  tfma_metrics_spec.model_names.extend(me_metrics_spec.model_names)
  return tfma_metrics_spec
def testMetricKeysToSkipForConfidenceIntervals(self):
  metrics_specs = [
      config_pb2.MetricsSpec(
          metrics=[
              config_pb2.MetricConfig(
                  class_name='ExampleCount',
                  config=json.dumps({'name': 'example_count'}),
                  threshold=config_pb2.MetricThreshold(
                      value_threshold=config_pb2.GenericValueThreshold())),
              config_pb2.MetricConfig(
                  class_name='MeanLabel',
                  config=json.dumps({'name': 'mean_label'}),
                  threshold=config_pb2.MetricThreshold(
                      change_threshold=config_pb2.GenericChangeThreshold())),
              config_pb2.MetricConfig(
                  class_name='MeanSquaredError',
                  config=json.dumps({'name': 'mse'}),
                  threshold=config_pb2.MetricThreshold(
                      change_threshold=config_pb2.GenericChangeThreshold()))
          ],
          model_names=['model_name1', 'model_name2'],
          output_names=['output_name1', 'output_name2']),
  ]
  metrics_specs += metric_specs.specs_from_metrics(
      [tf.keras.metrics.MeanSquaredError('mse')],
      model_names=['model_name1', 'model_name2'])
  keys = metric_specs.metric_keys_to_skip_for_confidence_intervals(
      metrics_specs, eval_config=config_pb2.EvalConfig())
  self.assertLen(keys, 8)
  self.assertIn(
      metric_types.MetricKey(
          name='example_count',
          model_name='model_name1',
          output_name='output_name1'), keys)
  self.assertIn(
      metric_types.MetricKey(
          name='example_count',
          model_name='model_name1',
          output_name='output_name2'), keys)
  self.assertIn(
      metric_types.MetricKey(
          name='example_count',
          model_name='model_name2',
          output_name='output_name1'), keys)
  self.assertIn(
      metric_types.MetricKey(
          name='example_count',
          model_name='model_name2',
          output_name='output_name2'), keys)
  self.assertIn(
      metric_types.MetricKey(name='example_count', model_name='model_name1'),
      keys)
  self.assertIn(
      metric_types.MetricKey(
          name='weighted_example_count',
          model_name='model_name1',
          example_weighted=True), keys)
  self.assertIn(
      metric_types.MetricKey(name='example_count', model_name='model_name2'),
      keys)
  self.assertIn(
      metric_types.MetricKey(
          name='weighted_example_count',
          model_name='model_name2',
          example_weighted=True), keys)
def testToComputationsWithMixedAggregationAndNonAggregationMetrics(self):
  computations = metric_specs.to_computations([
      config_pb2.MetricsSpec(
          metrics=[config_pb2.MetricConfig(class_name='CategoricalAccuracy')]),
      config_pb2.MetricsSpec(
          metrics=[config_pb2.MetricConfig(class_name='BinaryCrossentropy')],
          binarize=config_pb2.BinarizationOptions(class_ids={'values': [1]}),
          aggregate=config_pb2.AggregationOptions(micro_average=True))
  ], config_pb2.EvalConfig())
  # 3 separate computations should be used (one for aggregated metrics, one
  # for non-aggregated metrics, and one for metrics associated with class 1).
  self.assertLen(computations, 3)
def testMetricsSpecBeamCounter(self):
  with beam.Pipeline() as pipeline:
    metrics_spec = config_pb2.MetricsSpec(
        metrics=[config_pb2.MetricConfig(class_name='FairnessIndicators')])
    model_types = set(['tf_js', 'tf_keras'])
    _ = pipeline | counter_util.IncrementMetricsSpecsCounters([metrics_spec],
                                                              model_types)

  result = pipeline.run()

  for model_type in model_types:
    metric_filter = beam.metrics.metric.MetricsFilter().with_namespace(
        constants.METRICS_NAMESPACE).with_name(
            'metric_computed_FairnessIndicators_v2_' + model_type)
    actual_metrics_count = result.metrics().query(
        filter=metric_filter)['counters'][0].committed
    self.assertEqual(actual_metrics_count, 1)
def _EvaluateMetricsPlotsAndValidations(  # pylint: disable=invalid-name
    extracts: beam.pvalue.PCollection,
    eval_config: config_pb2.EvalConfig,
    eval_shared_models: Optional[Dict[Text, types.EvalSharedModel]] = None,
    metrics_key: Text = constants.METRICS_KEY,
    plots_key: Text = constants.PLOTS_KEY,
    attributions_key: Text = constants.ATTRIBUTIONS_KEY,
    validations_key: Text = constants.VALIDATIONS_KEY,
    schema: Optional[schema_pb2.Schema] = None,
    random_seed_for_testing: Optional[int] = None,
    tensor_adapter_config: Optional[tensor_adapter.TensorAdapterConfig] = None
) -> evaluator.Evaluation:
  """Evaluates metrics, plots, and validations.

  Args:
    extracts: PCollection of Extracts. The extracts must contain a list of
      slices of type SliceKeyType keyed by tfma.SLICE_KEY_TYPES_KEY as well as
      any extracts required by the metric implementations (typically this will
      include labels keyed by tfma.LABELS_KEY, predictions keyed by
      tfma.PREDICTIONS_KEY, and example weights keyed by
      tfma.EXAMPLE_WEIGHTS_KEY). Usually these will be added by calling the
      default_extractors function.
    eval_config: Eval config.
    eval_shared_models: Optional dict of shared models keyed by model name.
      Only required if there are metrics to be computed in-graph using the
      model.
    metrics_key: Name to use for metrics key in Evaluation output.
    plots_key: Name to use for plots key in Evaluation output.
    attributions_key: Name to use for attributions key in Evaluation output.
    validations_key: Name to use for validations key in Evaluation output.
    schema: A schema to use for customizing metrics and plots.
    random_seed_for_testing: Seed to use for unit testing.
    tensor_adapter_config: Tensor adapter config which specifies how to obtain
      tensors from the Arrow RecordBatch. The model's signature will be
      invoked with those tensors (matched by names). If None, an attempt will
      be made to create an adapter based on the model's input signature,
      otherwise the model will be invoked with raw examples (assuming a
      signature of a single 1-D string tensor).

  Returns:
    Evaluation containing dict of PCollections of (slice_key, results_dict)
    tuples where the dict is keyed by either the metrics_key (e.g. 'metrics'),
    plots_key (e.g. 'plots'), attributions_key (e.g. 'attributions'), or
    validations_key (e.g. 'validations') depending on what the results_dict
    contains.
  """
  # Separate metrics based on query_key (which may be None).
  metrics_specs_by_query_key = {}
  for spec in eval_config.metrics_specs:
    if spec.query_key not in metrics_specs_by_query_key:
      metrics_specs_by_query_key[spec.query_key] = []
    metrics_specs_by_query_key[spec.query_key].append(spec)

  # If there are no metrics specs then add an empty one (this is required for
  # cases where only the default metrics from the model are used).
  if not metrics_specs_by_query_key:
    metrics_specs_by_query_key[''] = [config_pb2.MetricsSpec()]

  # pylint: disable=no-value-for-parameter
  evaluations = {}
  for query_key, metrics_specs in metrics_specs_by_query_key.items():
    query_key_text = query_key if query_key else ''
    if query_key:
      extracts_for_evaluation = (
          extracts
          | 'GroupByQueryKey({})'.format(query_key_text) >>
          _GroupByQueryKey(query_key))
      include_default_metrics = False
    else:
      extracts_for_evaluation = extracts
      include_default_metrics = (
          eval_config and
          (not eval_config.options.HasField('include_default_metrics') or
           eval_config.options.include_default_metrics.value))
    evaluation = (
        extracts_for_evaluation
        | 'ComputeMetricsAndPlots({})'.format(query_key_text) >>
        _ComputeMetricsAndPlots(
            eval_config=eval_config,
            metrics_specs=metrics_specs,
            eval_shared_models=(eval_shared_models
                                if include_default_metrics else None),
            metrics_key=metrics_key,
            plots_key=plots_key,
            attributions_key=attributions_key,
            schema=schema,
            random_seed_for_testing=random_seed_for_testing,
            tensor_adapter_config=tensor_adapter_config))
    for k, v in evaluation.items():
      if k not in evaluations:
        evaluations[k] = []
      evaluations[k].append(v)
  evaluation_results = evaluator.combine_dict_based_evaluations(evaluations)

  validations = (
      evaluation_results[metrics_key]
      | 'ValidateMetrics' >> beam.Map(metrics_validator.validate_metrics,
                                      eval_config))
  evaluation_results[validations_key] = validations

  return evaluation_results
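# Illustrative sketch (an assumption, not taken from the code above): an
# EvalConfig whose metrics_specs mix a query/ranking spec with a plain spec.
# Per the grouping logic in _EvaluateMetricsPlotsAndValidations, the spec with
# query_key='query_id' would be evaluated on extracts grouped by that key
# (with default model metrics disabled), while the other spec would be
# evaluated on ungrouped extracts. The metric choices, model name, and query
# key below are hypothetical.
example_eval_config = config_pb2.EvalConfig(
    model_specs=[config_pb2.ModelSpec(name='candidate')],
    metrics_specs=[
        config_pb2.MetricsSpec(
            metrics=[config_pb2.MetricConfig(class_name='NDCG')],
            query_key='query_id'),
        config_pb2.MetricsSpec(
            metrics=[config_pb2.MetricConfig(class_name='ExampleCount')]),
    ])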
def specs_from_metrics(
    metrics: Union[List[_TFOrTFMAMetric], Dict[Text, List[_TFOrTFMAMetric]]],
    model_names: Optional[List[Text]] = None,
    output_names: Optional[List[Text]] = None,
    output_weights: Optional[Dict[Text, float]] = None,
    binarize: Optional[config_pb2.BinarizationOptions] = None,
    aggregate: Optional[config_pb2.AggregationOptions] = None,
    query_key: Optional[Text] = None,
    include_example_count: Optional[bool] = None,
    include_weighted_example_count: Optional[bool] = None
) -> List[config_pb2.MetricsSpec]:
  """Returns specs for tf.keras.metrics/losses or tfma.metrics classes.

  Examples:

    metrics_specs = specs_from_metrics([
        tf.keras.metrics.BinaryAccuracy(),
        tf.keras.metrics.AUC(),
        tf.keras.metrics.Precision(),
        tf.keras.metrics.Recall(),
        tfma.metrics.MeanLabel(),
        tfma.metrics.MeanPrediction()
        ...
    ])

    metrics_specs = specs_from_metrics({
        'output1': [
            tf.keras.metrics.BinaryAccuracy(),
            tf.keras.metrics.AUC(),
            tfma.metrics.MeanLabel(),
            tfma.metrics.MeanPrediction()
            ...
        ],
        'output2': [
            tf.keras.metrics.Precision(),
            tf.keras.metrics.Recall(),
        ]
    })

  Args:
    metrics: List of tf.keras.metrics.Metric, tf.keras.losses.Loss, or
      tfma.metrics.Metric. For multi-output models, a dict of metric lists
      keyed by output_name may be passed instead.
    model_names: Optional model names (if multi-model evaluation).
    output_names: Optional output names (if multi-output models). If the
      metrics are a dict this should not be set.
    output_weights: Optional output weights for creating overall metric
      aggregated across outputs (if multi-output model). If a weight is not
      provided for an output, its weight defaults to 0.0 (i.e. output ignored).
    binarize: Optional settings for binarizing multi-class/multi-label metrics.
    aggregate: Optional settings for aggregating multi-class/multi-label
      metrics.
    query_key: Optional query key for query/ranking based metrics.
    include_example_count: True to add example_count metric. Default is True.
    include_weighted_example_count: True to add weighted_example_count metric.
      Default is True. A weighted example count will be added per output for
      multi-output models.
  """
  if isinstance(metrics, dict) and output_names:
    raise ValueError('metrics cannot be a dict when output_names is used: '
                     'metrics={}, output_names={}'.format(
                         metrics, output_names))
  if isinstance(metrics, dict):
    specs = []
    for output_name in sorted(metrics.keys()):
      specs.extend(
          specs_from_metrics(
              metrics[output_name],
              model_names=model_names,
              output_names=[output_name],
              binarize=binarize,
              aggregate=aggregate,
              include_example_count=include_example_count,
              include_weighted_example_count=include_weighted_example_count))
      include_example_count = False
    return specs

  if include_example_count is None:
    include_example_count = True
  if include_weighted_example_count is None:
    include_weighted_example_count = True

  # Add the computations for the example counts and weights since they are
  # independent of the model and class ID.
  specs = example_count_specs(
      model_names=model_names,
      output_names=output_names,
      output_weights=output_weights,
      include_example_count=include_example_count,
      include_weighted_example_count=include_weighted_example_count)

  metric_configs = []
  for metric in metrics:
    if isinstance(metric, tf.keras.metrics.Metric):
      metric_configs.append(_serialize_tf_metric(metric))
    elif isinstance(metric, tf.keras.losses.Loss):
      metric_configs.append(_serialize_tf_loss(metric))
    elif isinstance(metric, metric_types.Metric):
      metric_configs.append(_serialize_tfma_metric(metric))
    else:
      raise NotImplementedError('unknown metric type {}: metric={}'.format(
          type(metric), metric))
  specs.append(
      config_pb2.MetricsSpec(
          metrics=metric_configs,
          model_names=model_names,
          output_names=output_names,
          output_weights=output_weights,
          binarize=binarize,
          aggregate=aggregate,
          query_key=query_key))

  return specs
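# Illustrative sketch only (not part of the module above): specs built by
# specs_from_metrics are typically placed into an EvalConfig's metrics_specs.
# The metric choices and model name below are assumptions for this example,
# and the module-level imports above (tf, config_pb2) are assumed in scope.
example_specs = specs_from_metrics(
    [tf.keras.metrics.AUC(name='auc'),
     tf.keras.metrics.BinaryAccuracy(name='accuracy')],
    model_names=['candidate'],
    binarize=config_pb2.BinarizationOptions(class_ids={'values': [0, 1]}))
example_eval_config = config_pb2.EvalConfig(
    model_specs=[config_pb2.ModelSpec(name='candidate')],
    metrics_specs=example_specs)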
def to_computations(
    metrics_specs: List[config_pb2.MetricsSpec],
    eval_config: Optional[config_pb2.EvalConfig] = None,
    schema: Optional[schema_pb2.Schema] = None
) -> metric_types.MetricComputations:
  """Returns computations associated with given metrics specs."""
  computations = []

  #
  # Split into TF metrics and TFMA metrics
  #

  # Dict[Text, Type[tf.keras.metrics.Metric]]
  tf_metric_classes = {}  # class_name -> class
  # Dict[Text, Type[tf.keras.losses.Loss]]
  tf_loss_classes = {}  # class_name -> class
  # List[metric_types.MetricsSpec]
  tf_metrics_specs = []
  # Dict[Text, Type[metric_types.Metric]]
  tfma_metric_classes = metric_types.registered_metrics()  # class_name -> class
  # List[metric_types.MetricsSpec]
  tfma_metrics_specs = []
  #
  # Note: Lists are used instead of Dicts for the following items because
  # protos are not hashable.
  #
  # List[List[_TFOrTFMAMetric]] (offsets align with metrics_specs).
  per_spec_metric_instances = []
  # List[List[_TFMetricOrLoss]] (offsets align with tf_metrics_specs).
  per_tf_spec_metric_instances = []
  # List[List[metric_types.Metric]] (offsets align with tfma_metrics_specs).
  per_tfma_spec_metric_instances = []
  for spec in metrics_specs:
    tf_spec = config_pb2.MetricsSpec()
    tf_spec.CopyFrom(spec)
    del tf_spec.metrics[:]
    tfma_spec = config_pb2.MetricsSpec()
    tfma_spec.CopyFrom(spec)
    del tfma_spec.metrics[:]
    for metric in spec.metrics:
      if metric.class_name in tfma_metric_classes:
        tfma_spec.metrics.append(metric)
      elif not metric.module:
        tf_spec.metrics.append(metric)
      else:
        cls = getattr(
            importlib.import_module(metric.module), metric.class_name)
        if issubclass(cls, tf.keras.metrics.Metric):
          tf_metric_classes[metric.class_name] = cls
          tf_spec.metrics.append(metric)
        elif issubclass(cls, tf.keras.losses.Loss):
          tf_loss_classes[metric.class_name] = cls
          tf_spec.metrics.append(metric)
        else:
          tfma_metric_classes[metric.class_name] = cls
          tfma_spec.metrics.append(metric)

    metric_instances = []
    if tf_spec.metrics:
      tf_metrics_specs.append(tf_spec)
      tf_metric_instances = []
      for m in tf_spec.metrics:
        # To distinguish losses from metrics, losses are required to set the
        # module name.
        if m.module == _TF_LOSSES_MODULE:
          tf_metric_instances.append(_deserialize_tf_loss(m, tf_loss_classes))
        else:
          tf_metric_instances.append(
              _deserialize_tf_metric(m, tf_metric_classes))
      per_tf_spec_metric_instances.append(tf_metric_instances)
      metric_instances.extend(tf_metric_instances)
    if tfma_spec.metrics:
      tfma_metrics_specs.append(tfma_spec)
      tfma_metric_instances = [
          _deserialize_tfma_metric(m, tfma_metric_classes)
          for m in tfma_spec.metrics
      ]
      per_tfma_spec_metric_instances.append(tfma_metric_instances)
      metric_instances.extend(tfma_metric_instances)
    per_spec_metric_instances.append(metric_instances)

  # Process TF specs
  computations.extend(
      _process_tf_metrics_specs(tf_metrics_specs, per_tf_spec_metric_instances,
                                eval_config))

  # Process TFMA specs
  computations.extend(
      _process_tfma_metrics_specs(tfma_metrics_specs,
                                  per_tfma_spec_metric_instances, eval_config,
                                  schema))

  # Process aggregation based metrics (output aggregation and macro averaging).
  # Note that the processing of the TF and TFMA specs above was set up to
  # create the binarized metrics that macro averaging depends on.
  for i, spec in enumerate(metrics_specs):
    for aggregation_type, sub_keys in _create_sub_keys(spec).items():
      output_names = spec.output_names or ['']
      output_weights = dict(spec.output_weights)
      if not set(output_weights.keys()).issubset(output_names):
        raise ValueError(
            'one or more output_names used in output_weights does not exist: '
            'output_names={}, output_weights={}'.format(output_names,
                                                        output_weights))
      for model_name in spec.model_names or ['']:
        for sub_key in sub_keys:
          for metric in per_spec_metric_instances[i]:
            if (aggregation_type and
                (aggregation_type.macro_average or
                 aggregation_type.weighted_macro_average)):
              class_weights = _class_weights(spec) or {}
              for output_name in output_names:
                macro_average_sub_keys = _macro_average_sub_keys(
                    sub_key, class_weights)
                if aggregation_type.macro_average:
                  computations.extend(
                      aggregation.macro_average(
                          metric.get_config()['name'],
                          sub_keys=macro_average_sub_keys,
                          eval_config=eval_config,
                          model_name=model_name,
                          output_name=output_name,
                          sub_key=sub_key,
                          class_weights=class_weights))
                elif aggregation_type.weighted_macro_average:
                  computations.extend(
                      aggregation.weighted_macro_average(
                          metric.get_config()['name'],
                          sub_keys=macro_average_sub_keys,
                          eval_config=eval_config,
                          model_name=model_name,
                          output_name=output_name,
                          sub_key=sub_key,
                          class_weights=class_weights))
            if output_weights:
              computations.extend(
                  aggregation.output_average(
                      metric.get_config()['name'],
                      output_weights=output_weights,
                      eval_config=eval_config,
                      model_name=model_name,
                      sub_key=sub_key))

  return computations
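# Illustrative sketch only (an assumption, not taken from the module above):
# turning a list of MetricsSpec protos into the computations that are later
# evaluated, mirroring the pattern used in the tests in this section. The
# metric choice and model name are hypothetical.
example_computations = to_computations(
    specs_from_metrics([tf.keras.metrics.AUC(name='auc')],
                       model_names=['candidate']),
    eval_config=config_pb2.EvalConfig())
# Each returned item is a metric computation (or derived computation) from
# metric_types, keyed by MetricKey when results are produced.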
def testSpecsFromMetrics(self):
  metrics_specs = metric_specs.specs_from_metrics(
      {
          'output_name1': [
              tf.keras.metrics.Precision(name='precision'),
              tf.keras.metrics.MeanSquaredError('mse'),
              tf.keras.losses.MeanAbsoluteError(name='mae'),
          ],
          'output_name2': [
              confusion_matrix_metrics.Precision(name='precision'),
              tf.keras.losses.MeanAbsolutePercentageError(name='mape'),
              calibration.MeanPrediction('mean_prediction')
          ]
      },
      unweighted_metrics={
          'output_name1': [calibration.MeanLabel('mean_label')],
          'output_name2': [tf.keras.metrics.RootMeanSquaredError('rmse')]
      },
      model_names=['model_name1', 'model_name2'],
      binarize=config_pb2.BinarizationOptions(class_ids={'values': [0, 1]}),
      aggregate=config_pb2.AggregationOptions(macro_average=True))

  self.assertLen(metrics_specs, 7)
  self.assertProtoEquals(
      metrics_specs[0],
      config_pb2.MetricsSpec(
          metrics=[
              config_pb2.MetricConfig(
                  class_name='ExampleCount',
                  config=json.dumps({'name': 'example_count'})),
          ],
          model_names=['model_name1', 'model_name2'],
          example_weights=config_pb2.ExampleWeightOptions(unweighted=True)))
  self.assertProtoEquals(
      metrics_specs[1],
      config_pb2.MetricsSpec(
          metrics=[
              config_pb2.MetricConfig(
                  class_name='WeightedExampleCount',
                  config=json.dumps({'name': 'weighted_example_count'})),
          ],
          model_names=['model_name1', 'model_name2'],
          output_names=['output_name1'],
          example_weights=config_pb2.ExampleWeightOptions(weighted=True)))
  self.assertProtoEquals(
      metrics_specs[2],
      config_pb2.MetricsSpec(
          metrics=[
              config_pb2.MetricConfig(
                  class_name='Precision',
                  config=json.dumps(
                      {
                          'name': 'precision',
                          'class_id': None,
                          'thresholds': None,
                          'top_k': None
                      },
                      sort_keys=True)),
              config_pb2.MetricConfig(
                  class_name='MeanSquaredError',
                  config=json.dumps(
                      {
                          'name': 'mse',
                          'dtype': 'float32',
                      }, sort_keys=True)),
              config_pb2.MetricConfig(
                  class_name='MeanAbsoluteError',
                  module=metric_specs._TF_LOSSES_MODULE,
                  config=json.dumps({
                      'reduction': 'auto',
                      'name': 'mae'
                  }, sort_keys=True))
          ],
          model_names=['model_name1', 'model_name2'],
          output_names=['output_name1'],
          binarize=config_pb2.BinarizationOptions(class_ids={'values': [0, 1]}),
          aggregate=config_pb2.AggregationOptions(macro_average=True)))
  self.assertProtoEquals(
      metrics_specs[3],
      config_pb2.MetricsSpec(
          metrics=[
              config_pb2.MetricConfig(
                  class_name='MeanLabel',
                  config=json.dumps({'name': 'mean_label'}))
          ],
          model_names=['model_name1', 'model_name2'],
          output_names=['output_name1'],
          binarize=config_pb2.BinarizationOptions(class_ids={'values': [0, 1]}),
          aggregate=config_pb2.AggregationOptions(macro_average=True),
          example_weights=config_pb2.ExampleWeightOptions(unweighted=True)))
  self.assertProtoEquals(
      metrics_specs[4],
      config_pb2.MetricsSpec(
          metrics=[
              config_pb2.MetricConfig(
                  class_name='WeightedExampleCount',
                  config=json.dumps({'name': 'weighted_example_count'})),
          ],
          model_names=['model_name1', 'model_name2'],
          output_names=['output_name2'],
          example_weights=config_pb2.ExampleWeightOptions(weighted=True)))
  self.assertProtoEquals(
      metrics_specs[5],
      config_pb2.MetricsSpec(
          metrics=[
              config_pb2.MetricConfig(
                  class_name='Precision',
                  config=json.dumps({
                      'name': 'precision',
                  }, sort_keys=True)),
              config_pb2.MetricConfig(
                  class_name='MeanAbsolutePercentageError',
                  module=metric_specs._TF_LOSSES_MODULE,
                  config=json.dumps({
                      'reduction': 'auto',
                      'name': 'mape'
                  }, sort_keys=True)),
              config_pb2.MetricConfig(
                  class_name='MeanPrediction',
                  config=json.dumps({'name': 'mean_prediction'}))
          ],
          model_names=['model_name1', 'model_name2'],
          output_names=['output_name2'],
          binarize=config_pb2.BinarizationOptions(class_ids={'values': [0, 1]}),
          aggregate=config_pb2.AggregationOptions(macro_average=True)))
  self.assertProtoEquals(
      metrics_specs[6],
      config_pb2.MetricsSpec(
          metrics=[
              config_pb2.MetricConfig(
                  class_name='RootMeanSquaredError',
                  config=json.dumps({
                      'name': 'rmse',
                      'dtype': 'float32'
                  }, sort_keys=True))
          ],
          model_names=['model_name1', 'model_name2'],
          output_names=['output_name2'],
          binarize=config_pb2.BinarizationOptions(class_ids={'values': [0, 1]}),
          aggregate=config_pb2.AggregationOptions(macro_average=True),
          example_weights=config_pb2.ExampleWeightOptions(unweighted=True)))
def testMetricThresholdsFromMetricsSpecs(self):
  slice_specs = [
      config_pb2.SlicingSpec(feature_keys=['feature1']),
      config_pb2.SlicingSpec(feature_values={'feature2': 'value1'})
  ]

  # For cross slice tests.
  baseline_slice_spec = config_pb2.SlicingSpec(feature_keys=['feature3'])

  metrics_specs = [
      config_pb2.MetricsSpec(
          thresholds={
              'auc':
                  config_pb2.MetricThreshold(
                      value_threshold=config_pb2.GenericValueThreshold()),
              'mean/label':
                  config_pb2.MetricThreshold(
                      value_threshold=config_pb2.GenericValueThreshold(),
                      change_threshold=config_pb2.GenericChangeThreshold()),
              'mse':
                  config_pb2.MetricThreshold(
                      change_threshold=config_pb2.GenericChangeThreshold())
          },
          per_slice_thresholds={
              'auc':
                  config_pb2.PerSliceMetricThresholds(thresholds=[
                      config_pb2.PerSliceMetricThreshold(
                          slicing_specs=slice_specs,
                          threshold=config_pb2.MetricThreshold(
                              value_threshold=config_pb2
                              .GenericValueThreshold()))
                  ]),
              'mean/label':
                  config_pb2.PerSliceMetricThresholds(thresholds=[
                      config_pb2.PerSliceMetricThreshold(
                          slicing_specs=slice_specs,
                          threshold=config_pb2.MetricThreshold(
                              value_threshold=config_pb2
                              .GenericValueThreshold(),
                              change_threshold=config_pb2
                              .GenericChangeThreshold()))
                  ])
          },
          cross_slice_thresholds={
              'auc':
                  config_pb2.CrossSliceMetricThresholds(thresholds=[
                      config_pb2.CrossSliceMetricThreshold(
                          cross_slicing_specs=[
                              config_pb2.CrossSlicingSpec(
                                  baseline_spec=baseline_slice_spec,
                                  slicing_specs=slice_specs)
                          ],
                          threshold=config_pb2.MetricThreshold(
                              value_threshold=config_pb2
                              .GenericValueThreshold(),
                              change_threshold=config_pb2
                              .GenericChangeThreshold()))
                  ]),
              'mse':
                  config_pb2.CrossSliceMetricThresholds(thresholds=[
                      config_pb2.CrossSliceMetricThreshold(
                          cross_slicing_specs=[
                              config_pb2.CrossSlicingSpec(
                                  baseline_spec=baseline_slice_spec,
                                  slicing_specs=slice_specs)
                          ],
                          threshold=config_pb2.MetricThreshold(
                              change_threshold=config_pb2
                              .GenericChangeThreshold())),
                      # Test for duplicate cross_slicing_spec.
                      config_pb2.CrossSliceMetricThreshold(
                          cross_slicing_specs=[
                              config_pb2.CrossSlicingSpec(
                                  baseline_spec=baseline_slice_spec,
                                  slicing_specs=slice_specs)
                          ],
                          threshold=config_pb2.MetricThreshold(
                              value_threshold=config_pb2
                              .GenericValueThreshold()))
                  ])
          },
          model_names=['model_name'],
          output_names=['output_name']),
      config_pb2.MetricsSpec(
          metrics=[
              config_pb2.MetricConfig(
                  class_name='ExampleCount',
                  config=json.dumps({'name': 'example_count'}),
                  threshold=config_pb2.MetricThreshold(
                      value_threshold=config_pb2.GenericValueThreshold()))
          ],
          model_names=['model_name1', 'model_name2'],
          example_weights=config_pb2.ExampleWeightOptions(unweighted=True)),
      config_pb2.MetricsSpec(
          metrics=[
              config_pb2.MetricConfig(
                  class_name='WeightedExampleCount',
                  config=json.dumps({'name': 'weighted_example_count'}),
                  threshold=config_pb2.MetricThreshold(
                      value_threshold=config_pb2.GenericValueThreshold()))
          ],
          model_names=['model_name1', 'model_name2'],
          output_names=['output_name1', 'output_name2'],
          example_weights=config_pb2.ExampleWeightOptions(weighted=True)),
      config_pb2.MetricsSpec(
          metrics=[
              config_pb2.MetricConfig(
                  class_name='MeanSquaredError',
                  config=json.dumps({'name': 'mse'}),
                  threshold=config_pb2.MetricThreshold(
                      change_threshold=config_pb2.GenericChangeThreshold())),
              config_pb2.MetricConfig(
                  class_name='MeanLabel',
                  config=json.dumps({'name': 'mean_label'}),
                  threshold=config_pb2.MetricThreshold(
                      change_threshold=config_pb2.GenericChangeThreshold()),
                  per_slice_thresholds=[
                      config_pb2.PerSliceMetricThreshold(
                          slicing_specs=slice_specs,
                          threshold=config_pb2.MetricThreshold(
                              change_threshold=config_pb2
                              .GenericChangeThreshold())),
                  ],
                  cross_slice_thresholds=[
                      config_pb2.CrossSliceMetricThreshold(
                          cross_slicing_specs=[
                              config_pb2.CrossSlicingSpec(
                                  baseline_spec=baseline_slice_spec,
                                  slicing_specs=slice_specs)
                          ],
                          threshold=config_pb2.MetricThreshold(
                              change_threshold=config_pb2
                              .GenericChangeThreshold()))
                  ]),
          ],
          model_names=['model_name'],
          output_names=['output_name'],
          binarize=config_pb2.BinarizationOptions(class_ids={'values': [0, 1]}),
          aggregate=config_pb2.AggregationOptions(
              macro_average=True, class_weights={
                  0: 1.0,
                  1: 1.0
              }))
  ]

  thresholds = metric_specs.metric_thresholds_from_metrics_specs(
      metrics_specs, eval_config=config_pb2.EvalConfig())

  expected_keys_and_threshold_counts = {
      metric_types.MetricKey(
          name='auc',
          model_name='model_name',
          output_name='output_name',
          is_diff=False,
          example_weighted=None):
          4,
      metric_types.MetricKey(
          name='auc',
          model_name='model_name',
          output_name='output_name',
          is_diff=True,
          example_weighted=None):
          1,
      metric_types.MetricKey(
          name='mean/label',
          model_name='model_name',
          output_name='output_name',
          is_diff=True,
          example_weighted=None):
          3,
      metric_types.MetricKey(
          name='mean/label',
          model_name='model_name',
          output_name='output_name',
          is_diff=False,
          example_weighted=None):
          3,
      metric_types.MetricKey(name='example_count', model_name='model_name1'):
          1,
      metric_types.MetricKey(name='example_count', model_name='model_name2'):
          1,
      metric_types.MetricKey(
          name='weighted_example_count',
          model_name='model_name1',
          output_name='output_name1',
          example_weighted=True):
          1,
      metric_types.MetricKey(
          name='weighted_example_count',
          model_name='model_name1',
          output_name='output_name2',
          example_weighted=True):
          1,
      metric_types.MetricKey(
          name='weighted_example_count',
          model_name='model_name2',
          output_name='output_name1',
          example_weighted=True):
          1,
      metric_types.MetricKey(
          name='weighted_example_count',
          model_name='model_name2',
          output_name='output_name2',
          example_weighted=True):
          1,
      metric_types.MetricKey(
          name='mse',
          model_name='model_name',
          output_name='output_name',
          sub_key=metric_types.SubKey(class_id=0),
          is_diff=True):
          1,
      metric_types.MetricKey(
          name='mse',
          model_name='model_name',
          output_name='output_name',
          sub_key=metric_types.SubKey(class_id=1),
          is_diff=True):
          1,
      metric_types.MetricKey(
          name='mse',
          model_name='model_name',
          output_name='output_name',
          is_diff=False,
          example_weighted=None):
          1,
      metric_types.MetricKey(
          name='mse',
          model_name='model_name',
          output_name='output_name',
          is_diff=True,
          example_weighted=None):
          2,
      metric_types.MetricKey(
          name='mse',
          model_name='model_name',
          output_name='output_name',
          aggregation_type=metric_types.AggregationType(macro_average=True),
          is_diff=True):
          1,
      metric_types.MetricKey(
          name='mean_label',
          model_name='model_name',
          output_name='output_name',
          sub_key=metric_types.SubKey(class_id=0),
          is_diff=True):
          4,
      metric_types.MetricKey(
          name='mean_label',
          model_name='model_name',
          output_name='output_name',
          sub_key=metric_types.SubKey(class_id=1),
          is_diff=True):
          4,
      metric_types.MetricKey(
          name='mean_label',
          model_name='model_name',
          output_name='output_name',
          aggregation_type=metric_types.AggregationType(macro_average=True),
          is_diff=True):
          4
  }

  self.assertLen(thresholds, len(expected_keys_and_threshold_counts))
  for key, count in expected_keys_and_threshold_counts.items():
    self.assertIn(key, thresholds)
    self.assertLen(thresholds[key], count, 'failed for key {}'.format(key))