def default_multi_class_classification_specs(
    model_names: Optional[List[Text]] = None,
    output_names: Optional[List[Text]] = None,
    output_weights: Optional[Dict[Text, float]] = None,
    binarize: Optional[config_pb2.BinarizationOptions] = None,
    aggregate: Optional[config_pb2.AggregationOptions] = None,
    sparse: bool = True) -> List[config_pb2.MetricsSpec]:
  """Returns default metric specs for multi-class classification problems.

  Args:
    model_names: Optional model names if multi-model evaluation.
    output_names: Optional list of output names (if multi-output model).
    output_weights: Optional output weights for creating an overall metric
      aggregated across outputs (if multi-output model). If a weight is not
      provided for an output, its weight defaults to 0.0 (i.e. the output is
      ignored).
    binarize: Optional settings for binarizing multi-class/multi-label metrics.
    aggregate: Optional settings for aggregating multi-class/multi-label
      metrics.
    sparse: True if the labels are sparse.

  Returns:
    A list of MetricsSpec protos.
  """
  if sparse:
    metrics = [
        tf.keras.metrics.SparseCategoricalCrossentropy(name='loss'),
        tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')
    ]
  else:
    metrics = [
        tf.keras.metrics.CategoricalCrossentropy(name='loss'),
        tf.keras.metrics.CategoricalAccuracy(name='accuracy')
    ]
  metrics.append(
      multi_class_confusion_matrix_plot.MultiClassConfusionMatrixPlot())
  if binarize is not None:
    for top_k in binarize.top_k_list.values:
      metrics.extend([
          tf.keras.metrics.Precision(name='precision', top_k=top_k),
          tf.keras.metrics.Recall(name='recall', top_k=top_k)
      ])
    binarize_without_top_k = config_pb2.BinarizationOptions()
    binarize_without_top_k.CopyFrom(binarize)
    binarize_without_top_k.ClearField('top_k_list')
    binarize = binarize_without_top_k
  multi_class_metrics = specs_from_metrics(
      metrics,
      model_names=model_names,
      output_names=output_names,
      output_weights=output_weights)
  if aggregate is None:
    aggregate = config_pb2.AggregationOptions(micro_average=True)
  multi_class_metrics.extend(
      default_binary_classification_specs(
          model_names=model_names,
          output_names=output_names,
          output_weights=output_weights,
          binarize=binarize,
          aggregate=aggregate))
  return multi_class_metrics
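# Usage sketch (illustrative only, not part of the module): build the default
# multi-class specs with top-k binarization and wrap them in an EvalConfig.
# The model name and top_k value below are made up for the example.
#
#   specs = default_multi_class_classification_specs(
#       model_names=['candidate'],
#       binarize=config_pb2.BinarizationOptions(top_k_list={'values': [3]}),
#       sparse=True)
#   eval_config = config_pb2.EvalConfig(metrics_specs=specs)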
def testToComputationsWithMixedAggregationAndNonAggregationMetrics(self):
  computations = metric_specs.to_computations(
      [
          config_pb2.MetricsSpec(metrics=[
              config_pb2.MetricConfig(class_name='CategoricalAccuracy')
          ]),
          config_pb2.MetricsSpec(
              metrics=[
                  config_pb2.MetricConfig(class_name='BinaryCrossentropy')
              ],
              binarize=config_pb2.BinarizationOptions(
                  class_ids={'values': [1]}),
              aggregate=config_pb2.AggregationOptions(micro_average=True))
      ], config_pb2.EvalConfig())
  # 3 separate computations should be used (one for aggregated metrics, one
  # for non-aggregated metrics, and one for metrics associated with class 1).
  self.assertLen(computations, 3)
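# For reference, a sketch of the equivalent user-facing EvalConfig text proto
# for the two specs above (field names follow config_pb2 as used in the test):
#
#   metrics_specs {
#     metrics { class_name: "CategoricalAccuracy" }
#   }
#   metrics_specs {
#     metrics { class_name: "BinaryCrossentropy" }
#     binarize { class_ids { values: 1 } }
#     aggregate { micro_average: true }
#   }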
def _aggregation_options(
    self, me_aggregation: me_proto.Aggregation
) -> Optional[config_pb2.AggregationOptions]:
  """Converts a ME Aggregation into TFMA AggregationOptions.

  Args:
    me_aggregation: Input ME Aggregation.

  Returns:
    TFMA AggregationOptions, or None if no aggregation is specified.
  """
  if not me_aggregation:
    return None
  tfma_aggregation = config_pb2.AggregationOptions()
  if me_aggregation.micro_average:
    tfma_aggregation.micro_average = True
  if me_aggregation.macro_average:
    tfma_aggregation.macro_average = True
  if me_aggregation.class_weights:
    tfma_aggregation.class_weights.update(me_aggregation.class_weights)
  return tfma_aggregation
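# Illustrative sketch of the conversion (assumes an instance of the enclosing
# converter class is available as `converter`, a hypothetical name; the
# Aggregation fields mirror those read above):
#
#   me_aggregation = me_proto.Aggregation(
#       macro_average=True, class_weights={0: 0.5, 1: 0.5})
#   tfma_aggregation = converter._aggregation_options(me_aggregation)
#   # -> config_pb2.AggregationOptions(
#   #        macro_average=True, class_weights={0: 0.5, 1: 0.5})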
def testToComputations(self):
  computations = metric_specs.to_computations(
      metric_specs.specs_from_metrics(
          [
              tf.keras.metrics.MeanSquaredError('mse'),
              # Add a loss exactly the same as the metric
              # (https://github.com/tensorflow/tfx/issues/1550).
              tf.keras.losses.MeanSquaredError(name='loss'),
              calibration.MeanLabel('mean_label')
          ],
          model_names=['model_name'],
          output_names=['output_1', 'output_2'],
          output_weights={
              'output_1': 1.0,
              'output_2': 1.0
          },
          binarize=config_pb2.BinarizationOptions(
              class_ids={'values': [0, 1]}),
          aggregate=config_pb2.AggregationOptions(
              macro_average=True, class_weights={
                  0: 1.0,
                  1: 1.0
              })), config_pb2.EvalConfig())

  keys = []
  for m in computations:
    for k in m.keys:
      if not k.name.startswith('_'):
        keys.append(k)
  self.assertLen(keys, 31)
  self.assertIn(
      metric_types.MetricKey(name='example_count', model_name='model_name'),
      keys)
  for output_name in ('output_1', 'output_2', ''):
    self.assertIn(
        metric_types.MetricKey(
            name='weighted_example_count',
            model_name='model_name',
            output_name=output_name,
            example_weighted=True), keys)
    self.assertIn(
        metric_types.MetricKey(
            name='mse',
            model_name='model_name',
            output_name=output_name,
            sub_key=metric_types.SubKey(class_id=0)), keys)
    self.assertIn(
        metric_types.MetricKey(
            name='mse',
            model_name='model_name',
            output_name=output_name,
            sub_key=metric_types.SubKey(class_id=1)), keys)
    aggregation_type = metric_types.AggregationType(
        macro_average=True) if output_name else None
    self.assertIn(
        metric_types.MetricKey(
            name='mse',
            model_name='model_name',
            output_name=output_name,
            aggregation_type=aggregation_type), keys)
    self.assertIn(
        metric_types.MetricKey(
            name='loss',
            model_name='model_name',
            output_name=output_name,
            sub_key=metric_types.SubKey(class_id=0)), keys)
    self.assertIn(
        metric_types.MetricKey(
            name='loss',
            model_name='model_name',
            output_name=output_name,
            sub_key=metric_types.SubKey(class_id=1)), keys)
    self.assertIn(
        metric_types.MetricKey(
            name='loss',
            model_name='model_name',
            output_name=output_name,
            aggregation_type=aggregation_type), keys)
    self.assertIn(
        metric_types.MetricKey(
            name='mean_label',
            model_name='model_name',
            output_name=output_name,
            sub_key=metric_types.SubKey(class_id=0)), keys)
    self.assertIn(
        metric_types.MetricKey(
            name='mean_label',
            model_name='model_name',
            output_name=output_name,
            sub_key=metric_types.SubKey(class_id=1)), keys)
    self.assertIn(
        metric_types.MetricKey(
            name='mean_label',
            model_name='model_name',
            output_name=output_name,
            aggregation_type=aggregation_type), keys)
def testSpecsFromMetrics(self):
  metrics_specs = metric_specs.specs_from_metrics(
      {
          'output_name1': [
              tf.keras.metrics.Precision(name='precision'),
              tf.keras.metrics.MeanSquaredError('mse'),
              tf.keras.losses.MeanAbsoluteError(name='mae'),
          ],
          'output_name2': [
              confusion_matrix_metrics.Precision(name='precision'),
              tf.keras.losses.MeanAbsolutePercentageError(name='mape'),
              calibration.MeanPrediction('mean_prediction')
          ]
      },
      unweighted_metrics={
          'output_name1': [calibration.MeanLabel('mean_label')],
          'output_name2': [tf.keras.metrics.RootMeanSquaredError('rmse')]
      },
      model_names=['model_name1', 'model_name2'],
      binarize=config_pb2.BinarizationOptions(class_ids={'values': [0, 1]}),
      aggregate=config_pb2.AggregationOptions(macro_average=True))

  self.assertLen(metrics_specs, 7)
  self.assertProtoEquals(
      metrics_specs[0],
      config_pb2.MetricsSpec(
          metrics=[
              config_pb2.MetricConfig(
                  class_name='ExampleCount',
                  config=json.dumps({'name': 'example_count'})),
          ],
          model_names=['model_name1', 'model_name2'],
          example_weights=config_pb2.ExampleWeightOptions(unweighted=True)))
  self.assertProtoEquals(
      metrics_specs[1],
      config_pb2.MetricsSpec(
          metrics=[
              config_pb2.MetricConfig(
                  class_name='WeightedExampleCount',
                  config=json.dumps({'name': 'weighted_example_count'})),
          ],
          model_names=['model_name1', 'model_name2'],
          output_names=['output_name1'],
          example_weights=config_pb2.ExampleWeightOptions(weighted=True)))
  self.assertProtoEquals(
      metrics_specs[2],
      config_pb2.MetricsSpec(
          metrics=[
              config_pb2.MetricConfig(
                  class_name='Precision',
                  config=json.dumps(
                      {
                          'name': 'precision',
                          'class_id': None,
                          'thresholds': None,
                          'top_k': None
                      },
                      sort_keys=True)),
              config_pb2.MetricConfig(
                  class_name='MeanSquaredError',
                  config=json.dumps(
                      {
                          'name': 'mse',
                          'dtype': 'float32',
                      }, sort_keys=True)),
              config_pb2.MetricConfig(
                  class_name='MeanAbsoluteError',
                  module=metric_specs._TF_LOSSES_MODULE,
                  config=json.dumps({
                      'reduction': 'auto',
                      'name': 'mae'
                  }, sort_keys=True))
          ],
          model_names=['model_name1', 'model_name2'],
          output_names=['output_name1'],
          binarize=config_pb2.BinarizationOptions(
              class_ids={'values': [0, 1]}),
          aggregate=config_pb2.AggregationOptions(macro_average=True)))
  self.assertProtoEquals(
      metrics_specs[3],
      config_pb2.MetricsSpec(
          metrics=[
              config_pb2.MetricConfig(
                  class_name='MeanLabel',
                  config=json.dumps({'name': 'mean_label'}))
          ],
          model_names=['model_name1', 'model_name2'],
          output_names=['output_name1'],
          binarize=config_pb2.BinarizationOptions(
              class_ids={'values': [0, 1]}),
          aggregate=config_pb2.AggregationOptions(macro_average=True),
          example_weights=config_pb2.ExampleWeightOptions(unweighted=True)))
  self.assertProtoEquals(
      metrics_specs[4],
      config_pb2.MetricsSpec(
          metrics=[
              config_pb2.MetricConfig(
                  class_name='WeightedExampleCount',
                  config=json.dumps({'name': 'weighted_example_count'})),
          ],
          model_names=['model_name1', 'model_name2'],
          output_names=['output_name2'],
          example_weights=config_pb2.ExampleWeightOptions(weighted=True)))
  self.assertProtoEquals(
      metrics_specs[5],
      config_pb2.MetricsSpec(
          metrics=[
              config_pb2.MetricConfig(
                  class_name='Precision',
                  config=json.dumps({
                      'name': 'precision',
                  }, sort_keys=True)),
              config_pb2.MetricConfig(
                  class_name='MeanAbsolutePercentageError',
                  module=metric_specs._TF_LOSSES_MODULE,
                  config=json.dumps({
                      'reduction': 'auto',
                      'name': 'mape'
                  }, sort_keys=True)),
              config_pb2.MetricConfig(
                  class_name='MeanPrediction',
                  config=json.dumps({'name': 'mean_prediction'}))
          ],
          model_names=['model_name1', 'model_name2'],
          output_names=['output_name2'],
          binarize=config_pb2.BinarizationOptions(
              class_ids={'values': [0, 1]}),
          aggregate=config_pb2.AggregationOptions(macro_average=True)))
  self.assertProtoEquals(
      metrics_specs[6],
      config_pb2.MetricsSpec(
          metrics=[
              config_pb2.MetricConfig(
                  class_name='RootMeanSquaredError',
                  config=json.dumps({
                      'name': 'rmse',
                      'dtype': 'float32'
                  }, sort_keys=True))
          ],
          model_names=['model_name1', 'model_name2'],
          output_names=['output_name2'],
          binarize=config_pb2.BinarizationOptions(
              class_ids={'values': [0, 1]}),
          aggregate=config_pb2.AggregationOptions(macro_average=True),
          example_weights=config_pb2.ExampleWeightOptions(unweighted=True)))
def testMetricThresholdsFromMetricsSpecs(self):
  slice_specs = [
      config_pb2.SlicingSpec(feature_keys=['feature1']),
      config_pb2.SlicingSpec(feature_values={'feature2': 'value1'})
  ]
  # For cross slice tests.
  baseline_slice_spec = config_pb2.SlicingSpec(feature_keys=['feature3'])

  metrics_specs = [
      config_pb2.MetricsSpec(
          thresholds={
              'auc':
                  config_pb2.MetricThreshold(
                      value_threshold=config_pb2.GenericValueThreshold()),
              'mean/label':
                  config_pb2.MetricThreshold(
                      value_threshold=config_pb2.GenericValueThreshold(),
                      change_threshold=config_pb2.GenericChangeThreshold()),
              'mse':
                  config_pb2.MetricThreshold(
                      change_threshold=config_pb2.GenericChangeThreshold())
          },
          per_slice_thresholds={
              'auc':
                  config_pb2.PerSliceMetricThresholds(thresholds=[
                      config_pb2.PerSliceMetricThreshold(
                          slicing_specs=slice_specs,
                          threshold=config_pb2.MetricThreshold(
                              value_threshold=config_pb2
                              .GenericValueThreshold()))
                  ]),
              'mean/label':
                  config_pb2.PerSliceMetricThresholds(thresholds=[
                      config_pb2.PerSliceMetricThreshold(
                          slicing_specs=slice_specs,
                          threshold=config_pb2.MetricThreshold(
                              value_threshold=config_pb2
                              .GenericValueThreshold(),
                              change_threshold=config_pb2
                              .GenericChangeThreshold()))
                  ])
          },
          cross_slice_thresholds={
              'auc':
                  config_pb2.CrossSliceMetricThresholds(thresholds=[
                      config_pb2.CrossSliceMetricThreshold(
                          cross_slicing_specs=[
                              config_pb2.CrossSlicingSpec(
                                  baseline_spec=baseline_slice_spec,
                                  slicing_specs=slice_specs)
                          ],
                          threshold=config_pb2.MetricThreshold(
                              value_threshold=config_pb2
                              .GenericValueThreshold(),
                              change_threshold=config_pb2
                              .GenericChangeThreshold()))
                  ]),
              'mse':
                  config_pb2.CrossSliceMetricThresholds(thresholds=[
                      config_pb2.CrossSliceMetricThreshold(
                          cross_slicing_specs=[
                              config_pb2.CrossSlicingSpec(
                                  baseline_spec=baseline_slice_spec,
                                  slicing_specs=slice_specs)
                          ],
                          threshold=config_pb2.MetricThreshold(
                              change_threshold=config_pb2
                              .GenericChangeThreshold())),
                      # Test for duplicate cross_slicing_spec.
                      config_pb2.CrossSliceMetricThreshold(
                          cross_slicing_specs=[
                              config_pb2.CrossSlicingSpec(
                                  baseline_spec=baseline_slice_spec,
                                  slicing_specs=slice_specs)
                          ],
                          threshold=config_pb2.MetricThreshold(
                              value_threshold=config_pb2
                              .GenericValueThreshold()))
                  ])
          },
          model_names=['model_name'],
          output_names=['output_name']),
      config_pb2.MetricsSpec(
          metrics=[
              config_pb2.MetricConfig(
                  class_name='ExampleCount',
                  config=json.dumps({'name': 'example_count'}),
                  threshold=config_pb2.MetricThreshold(
                      value_threshold=config_pb2.GenericValueThreshold()))
          ],
          model_names=['model_name1', 'model_name2'],
          example_weights=config_pb2.ExampleWeightOptions(unweighted=True)),
      config_pb2.MetricsSpec(
          metrics=[
              config_pb2.MetricConfig(
                  class_name='WeightedExampleCount',
                  config=json.dumps({'name': 'weighted_example_count'}),
                  threshold=config_pb2.MetricThreshold(
                      value_threshold=config_pb2.GenericValueThreshold()))
          ],
          model_names=['model_name1', 'model_name2'],
          output_names=['output_name1', 'output_name2'],
          example_weights=config_pb2.ExampleWeightOptions(weighted=True)),
      config_pb2.MetricsSpec(
          metrics=[
              config_pb2.MetricConfig(
                  class_name='MeanSquaredError',
                  config=json.dumps({'name': 'mse'}),
                  threshold=config_pb2.MetricThreshold(
                      change_threshold=config_pb2.GenericChangeThreshold())),
              config_pb2.MetricConfig(
                  class_name='MeanLabel',
                  config=json.dumps({'name': 'mean_label'}),
                  threshold=config_pb2.MetricThreshold(
                      change_threshold=config_pb2.GenericChangeThreshold()),
                  per_slice_thresholds=[
                      config_pb2.PerSliceMetricThreshold(
                          slicing_specs=slice_specs,
                          threshold=config_pb2.MetricThreshold(
                              change_threshold=config_pb2
                              .GenericChangeThreshold())),
                  ],
                  cross_slice_thresholds=[
                      config_pb2.CrossSliceMetricThreshold(
                          cross_slicing_specs=[
                              config_pb2.CrossSlicingSpec(
                                  baseline_spec=baseline_slice_spec,
                                  slicing_specs=slice_specs)
                          ],
                          threshold=config_pb2.MetricThreshold(
                              change_threshold=config_pb2
                              .GenericChangeThreshold()))
                  ]),
          ],
          model_names=['model_name'],
          output_names=['output_name'],
          binarize=config_pb2.BinarizationOptions(
              class_ids={'values': [0, 1]}),
          aggregate=config_pb2.AggregationOptions(
              macro_average=True, class_weights={
                  0: 1.0,
                  1: 1.0
              }))
  ]

  thresholds = metric_specs.metric_thresholds_from_metrics_specs(
      metrics_specs, eval_config=config_pb2.EvalConfig())

  expected_keys_and_threshold_counts = {
      metric_types.MetricKey(
          name='auc', model_name='model_name', output_name='output_name',
          is_diff=False, example_weighted=None): 4,
      metric_types.MetricKey(
          name='auc', model_name='model_name', output_name='output_name',
          is_diff=True, example_weighted=None): 1,
      metric_types.MetricKey(
          name='mean/label', model_name='model_name',
          output_name='output_name', is_diff=True, example_weighted=None): 3,
      metric_types.MetricKey(
          name='mean/label', model_name='model_name',
          output_name='output_name', is_diff=False, example_weighted=None): 3,
      metric_types.MetricKey(
          name='example_count', model_name='model_name1'): 1,
      metric_types.MetricKey(
          name='example_count', model_name='model_name2'): 1,
      metric_types.MetricKey(
          name='weighted_example_count', model_name='model_name1',
          output_name='output_name1', example_weighted=True): 1,
      metric_types.MetricKey(
          name='weighted_example_count', model_name='model_name1',
          output_name='output_name2', example_weighted=True): 1,
      metric_types.MetricKey(
          name='weighted_example_count', model_name='model_name2',
          output_name='output_name1', example_weighted=True): 1,
      metric_types.MetricKey(
          name='weighted_example_count', model_name='model_name2',
          output_name='output_name2', example_weighted=True): 1,
      metric_types.MetricKey(
          name='mse', model_name='model_name', output_name='output_name',
          sub_key=metric_types.SubKey(class_id=0), is_diff=True): 1,
      metric_types.MetricKey(
          name='mse', model_name='model_name', output_name='output_name',
          sub_key=metric_types.SubKey(class_id=1), is_diff=True): 1,
      metric_types.MetricKey(
          name='mse', model_name='model_name', output_name='output_name',
          is_diff=False, example_weighted=None): 1,
      metric_types.MetricKey(
          name='mse', model_name='model_name', output_name='output_name',
          is_diff=True, example_weighted=None): 2,
      metric_types.MetricKey(
          name='mse', model_name='model_name', output_name='output_name',
          aggregation_type=metric_types.AggregationType(macro_average=True),
          is_diff=True): 1,
      metric_types.MetricKey(
          name='mean_label', model_name='model_name',
          output_name='output_name',
          sub_key=metric_types.SubKey(class_id=0), is_diff=True): 4,
      metric_types.MetricKey(
          name='mean_label', model_name='model_name',
          output_name='output_name',
          sub_key=metric_types.SubKey(class_id=1), is_diff=True): 4,
      metric_types.MetricKey(
          name='mean_label', model_name='model_name',
          output_name='output_name',
          aggregation_type=metric_types.AggregationType(macro_average=True),
          is_diff=True): 4
  }
  self.assertLen(thresholds, len(expected_keys_and_threshold_counts))
  for key, count in expected_keys_and_threshold_counts.items():
    self.assertIn(key, thresholds)
    self.assertLen(thresholds[key], count, 'failed for key {}'.format(key))