Example #1
def default_multi_class_classification_specs(
    model_names: Optional[List[Text]] = None,
    output_names: Optional[List[Text]] = None,
    output_weights: Optional[Dict[Text, float]] = None,
    binarize: Optional[config_pb2.BinarizationOptions] = None,
    aggregate: Optional[config_pb2.AggregationOptions] = None,
    sparse: bool = True) -> List[config_pb2.MetricsSpec]:
  """Returns default metric specs for multi-class classification problems.

  Args:
    model_names: Optional list of model names (if multi-model evaluation).
    output_names: Optional list of output names (if multi-output model).
    output_weights: Optional output weights for creating overall metric
      aggregated across outputs (if multi-output model). If a weight is not
      provided for an output, its weight defaults to 0.0 (i.e., the output is
      ignored).
    binarize: Optional settings for binarizing multi-class/multi-label metrics.
    aggregate: Optional settings for aggregating multi-class/multi-label
      metrics.
    sparse: True if the labels are sparse.
  """

  if sparse:
    metrics = [
        tf.keras.metrics.SparseCategoricalCrossentropy(name='loss'),
        tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')
    ]
  else:
    metrics = [
        tf.keras.metrics.CategoricalCrossentropy(name='loss'),
        tf.keras.metrics.CategoricalAccuracy(name='accuracy')
    ]
  metrics.append(
      multi_class_confusion_matrix_plot.MultiClassConfusionMatrixPlot())
  if binarize is not None:
    for top_k in binarize.top_k_list.values:
      metrics.extend([
          tf.keras.metrics.Precision(name='precision', top_k=top_k),
          tf.keras.metrics.Recall(name='recall', top_k=top_k)
      ])
    binarize_without_top_k = config_pb2.BinarizationOptions()
    binarize_without_top_k.CopyFrom(binarize)
    binarize_without_top_k.ClearField('top_k_list')
    binarize = binarize_without_top_k
  multi_class_metrics = specs_from_metrics(
      metrics,
      model_names=model_names,
      output_names=output_names,
      output_weights=output_weights)
  if aggregate is None:
    aggregate = config_pb2.AggregationOptions(micro_average=True)
  multi_class_metrics.extend(
      default_binary_classification_specs(
          model_names=model_names,
          output_names=output_names,
          output_weights=output_weights,
          binarize=binarize,
          aggregate=aggregate))
  return multi_class_metrics
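
For reference, here is a minimal sketch of how these default specs might be requested and plugged into an EvalConfig (assuming the function is exposed as tfma.metrics.default_multi_class_classification_specs; the model name and option values are illustrative):

import tensorflow_model_analysis as tfma
from tensorflow_model_analysis.proto import config_pb2

# Binarize per-class metrics for classes 0 and 1, and micro-average them too.
specs = tfma.metrics.default_multi_class_classification_specs(
    model_names=['candidate'],  # illustrative model name
    binarize=config_pb2.BinarizationOptions(class_ids={'values': [0, 1]}),
    aggregate=config_pb2.AggregationOptions(micro_average=True),
    sparse=True)
eval_config = config_pb2.EvalConfig(metrics_specs=specs)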
Example #2
    def testToComputationsWithMixedAggregationAndNonAggregationMetrics(self):
        computations = metric_specs.to_computations([
            config_pb2.MetricsSpec(metrics=[
                config_pb2.MetricConfig(class_name='CategoricalAccuracy')
            ]),
            config_pb2.MetricsSpec(
                metrics=[
                    config_pb2.MetricConfig(class_name='BinaryCrossentropy')
                ],
                binarize=config_pb2.BinarizationOptions(
                    class_ids={'values': [1]}),
                aggregate=config_pb2.AggregationOptions(micro_average=True))
        ], config_pb2.EvalConfig())

        # 3 separate computations should be used (one for aggregated metrics, one
        # for non-aggregated metrics, and one for metrics associated with class 1)
        self.assertLen(computations, 3)
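
One way to see this split (a sketch, reusing the computations produced above) is to group the metric keys by computation:

# Each computation lists the MetricKeys it will produce; printing them
# makes the aggregated / non-aggregated / per-class split visible.
for i, computation in enumerate(computations):
    print(i, [str(k) for k in computation.keys])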
Example #3
    def _aggregation_options(
        self, me_aggregation: me_proto.Aggregation
    ) -> Optional[config_pb2.AggregationOptions]:
        """Convert ME Aggregation into TFMA AggregationOptions.

    Args:
      me_aggregation: Input ME Aggregation.

    Returns:
      TFMA AggregationOptions.
    """
        if not me_aggregation:
            return None
        tfma_aggregation = config_pb2.AggregationOptions()
        if me_aggregation.micro_average:
            tfma_aggregation.micro_average = True
        if me_aggregation.macro_average:
            tfma_aggregation.macro_average = True
        if me_aggregation.class_weights:
            tfma_aggregation.class_weights.update(me_aggregation.class_weights)

        return tfma_aggregation
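
For context, the TFMA message being produced here is an ordinary proto that can also be built directly; a minimal sketch (field values are illustrative):

from tensorflow_model_analysis.proto import config_pb2

# Macro-average across classes 0 and 1 with equal weight.
tfma_aggregation = config_pb2.AggregationOptions(
    macro_average=True,
    class_weights={0: 1.0, 1: 1.0})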
Example #4
    def testToComputations(self):
        computations = metric_specs.to_computations(
            metric_specs.specs_from_metrics(
                [
                    tf.keras.metrics.MeanSquaredError('mse'),
                    # Add a loss that is exactly the same as a metric
                    # (https://github.com/tensorflow/tfx/issues/1550).
                    tf.keras.losses.MeanSquaredError(name='loss'),
                    calibration.MeanLabel('mean_label')
                ],
                model_names=['model_name'],
                output_names=['output_1', 'output_2'],
                output_weights={
                    'output_1': 1.0,
                    'output_2': 1.0
                },
                binarize=config_pb2.BinarizationOptions(
                    class_ids={'values': [0, 1]}),
                aggregate=config_pb2.AggregationOptions(macro_average=True,
                                                        class_weights={
                                                            0: 1.0,
                                                            1: 1.0
                                                        })),
            config_pb2.EvalConfig())

        keys = []
        for m in computations:
            for k in m.keys:
                if not k.name.startswith('_'):
                    keys.append(k)
        self.assertLen(keys, 31)
        self.assertIn(
            metric_types.MetricKey(name='example_count',
                                   model_name='model_name'), keys)
        for output_name in ('output_1', 'output_2', ''):
            self.assertIn(
                metric_types.MetricKey(name='weighted_example_count',
                                       model_name='model_name',
                                       output_name=output_name,
                                       example_weighted=True), keys)
            self.assertIn(
                metric_types.MetricKey(
                    name='mse',
                    model_name='model_name',
                    output_name=output_name,
                    sub_key=metric_types.SubKey(class_id=0)), keys)
            self.assertIn(
                metric_types.MetricKey(
                    name='mse',
                    model_name='model_name',
                    output_name=output_name,
                    sub_key=metric_types.SubKey(class_id=1)), keys)
            aggregation_type = metric_types.AggregationType(
                macro_average=True) if output_name else None
            self.assertIn(
                metric_types.MetricKey(name='mse',
                                       model_name='model_name',
                                       output_name=output_name,
                                       aggregation_type=aggregation_type),
                keys)
            self.assertIn(
                metric_types.MetricKey(
                    name='loss',
                    model_name='model_name',
                    output_name=output_name,
                    sub_key=metric_types.SubKey(class_id=0)), keys)
            self.assertIn(
                metric_types.MetricKey(
                    name='loss',
                    model_name='model_name',
                    output_name=output_name,
                    sub_key=metric_types.SubKey(class_id=1)), keys)
            aggregation_type = metric_types.AggregationType(
                macro_average=True) if output_name else None
            self.assertIn(
                metric_types.MetricKey(name='loss',
                                       model_name='model_name',
                                       output_name=output_name,
                                       aggregation_type=aggregation_type),
                keys)
            self.assertIn(
                metric_types.MetricKey(
                    name='mean_label',
                    model_name='model_name',
                    output_name=output_name,
                    sub_key=metric_types.SubKey(class_id=0)), keys)
            self.assertIn(
                metric_types.MetricKey(
                    name='mean_label',
                    model_name='model_name',
                    output_name=output_name,
                    sub_key=metric_types.SubKey(class_id=1)), keys)
            aggregation_type = metric_types.AggregationType(
                macro_average=True) if output_name else None
            self.assertIn(
                metric_types.MetricKey(name='mean_label',
                                       model_name='model_name',
                                       output_name=output_name,
                                       aggregation_type=aggregation_type),
                keys)
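
The MetricKeys asserted above are the same keys used to index evaluation results. A minimal lookup sketch (metrics_by_key is a hypothetical dict of results keyed by MetricKey):

key = metric_types.MetricKey(
    name='mse',
    model_name='model_name',
    output_name='output_1',
    sub_key=metric_types.SubKey(class_id=0))
# metrics_by_key is hypothetical; in practice the values come from running
# the computations over a dataset.
mse_for_class_0 = metrics_by_key[key]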
Example #5
    def testSpecsFromMetrics(self):
        metrics_specs = metric_specs.specs_from_metrics(
            {
                'output_name1': [
                    tf.keras.metrics.Precision(name='precision'),
                    tf.keras.metrics.MeanSquaredError('mse'),
                    tf.keras.losses.MeanAbsoluteError(name='mae'),
                ],
                'output_name2': [
                    confusion_matrix_metrics.Precision(name='precision'),
                    tf.keras.losses.MeanAbsolutePercentageError(name='mape'),
                    calibration.MeanPrediction('mean_prediction')
                ]
            },
            unweighted_metrics={
                'output_name1': [calibration.MeanLabel('mean_label')],
                'output_name2':
                [tf.keras.metrics.RootMeanSquaredError('rmse')]
            },
            model_names=['model_name1', 'model_name2'],
            binarize=config_pb2.BinarizationOptions(
                class_ids={'values': [0, 1]}),
            aggregate=config_pb2.AggregationOptions(macro_average=True))

        self.assertLen(metrics_specs, 7)
        self.assertProtoEquals(
            metrics_specs[0],
            config_pb2.MetricsSpec(
                metrics=[
                    config_pb2.MetricConfig(class_name='ExampleCount',
                                            config=json.dumps(
                                                {'name': 'example_count'})),
                ],
                model_names=['model_name1', 'model_name2'],
                example_weights=config_pb2.ExampleWeightOptions(
                    unweighted=True)))
        self.assertProtoEquals(
            metrics_specs[1],
            config_pb2.MetricsSpec(
                metrics=[
                    config_pb2.MetricConfig(
                        class_name='WeightedExampleCount',
                        config=json.dumps({'name': 'weighted_example_count'})),
                ],
                model_names=['model_name1', 'model_name2'],
                output_names=['output_name1'],
                example_weights=config_pb2.ExampleWeightOptions(
                    weighted=True)))
        self.assertProtoEquals(
            metrics_specs[2],
            config_pb2.MetricsSpec(
                metrics=[
                    config_pb2.MetricConfig(class_name='Precision',
                                            config=json.dumps(
                                                {
                                                    'name': 'precision',
                                                    'class_id': None,
                                                    'thresholds': None,
                                                    'top_k': None
                                                },
                                                sort_keys=True)),
                    config_pb2.MetricConfig(class_name='MeanSquaredError',
                                            config=json.dumps(
                                                {
                                                    'name': 'mse',
                                                    'dtype': 'float32',
                                                },
                                                sort_keys=True)),
                    config_pb2.MetricConfig(
                        class_name='MeanAbsoluteError',
                        module=metric_specs._TF_LOSSES_MODULE,
                        config=json.dumps({
                            'reduction': 'auto',
                            'name': 'mae'
                        },
                                          sort_keys=True))
                ],
                model_names=['model_name1', 'model_name2'],
                output_names=['output_name1'],
                binarize=config_pb2.BinarizationOptions(
                    class_ids={'values': [0, 1]}),
                aggregate=config_pb2.AggregationOptions(macro_average=True)))
        self.assertProtoEquals(
            metrics_specs[3],
            config_pb2.MetricsSpec(
                metrics=[
                    config_pb2.MetricConfig(class_name='MeanLabel',
                                            config=json.dumps(
                                                {'name': 'mean_label'}))
                ],
                model_names=['model_name1', 'model_name2'],
                output_names=['output_name1'],
                binarize=config_pb2.BinarizationOptions(
                    class_ids={'values': [0, 1]}),
                aggregate=config_pb2.AggregationOptions(macro_average=True),
                example_weights=config_pb2.ExampleWeightOptions(
                    unweighted=True)))
        self.assertProtoEquals(
            metrics_specs[4],
            config_pb2.MetricsSpec(
                metrics=[
                    config_pb2.MetricConfig(
                        class_name='WeightedExampleCount',
                        config=json.dumps({'name': 'weighted_example_count'})),
                ],
                model_names=['model_name1', 'model_name2'],
                output_names=['output_name2'],
                example_weights=config_pb2.ExampleWeightOptions(
                    weighted=True)))
        self.assertProtoEquals(
            metrics_specs[5],
            config_pb2.MetricsSpec(
                metrics=[
                    config_pb2.MetricConfig(class_name='Precision',
                                            config=json.dumps(
                                                {
                                                    'name': 'precision',
                                                },
                                                sort_keys=True)),
                    config_pb2.MetricConfig(
                        class_name='MeanAbsolutePercentageError',
                        module=metric_specs._TF_LOSSES_MODULE,
                        config=json.dumps({
                            'reduction': 'auto',
                            'name': 'mape'
                        },
                                          sort_keys=True)),
                    config_pb2.MetricConfig(class_name='MeanPrediction',
                                            config=json.dumps(
                                                {'name': 'mean_prediction'}))
                ],
                model_names=['model_name1', 'model_name2'],
                output_names=['output_name2'],
                binarize=config_pb2.BinarizationOptions(
                    class_ids={'values': [0, 1]}),
                aggregate=config_pb2.AggregationOptions(macro_average=True)))
        self.assertProtoEquals(
            metrics_specs[6],
            config_pb2.MetricsSpec(
                metrics=[
                    config_pb2.MetricConfig(class_name='RootMeanSquaredError',
                                            config=json.dumps(
                                                {
                                                    'name': 'rmse',
                                                    'dtype': 'float32'
                                                },
                                                sort_keys=True))
                ],
                model_names=['model_name1', 'model_name2'],
                output_names=['output_name2'],
                binarize=config_pb2.BinarizationOptions(
                    class_ids={'values': [0, 1]}),
                aggregate=config_pb2.AggregationOptions(macro_average=True),
                example_weights=config_pb2.ExampleWeightOptions(
                    unweighted=True)))
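
As in the earlier examples, these specs are only declarations; a minimal sketch of materializing them into computations (reusing the names above):

computations = metric_specs.to_computations(
    metrics_specs, config_pb2.EvalConfig())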
Example #6
    def testMetricThresholdsFromMetricsSpecs(self):
        slice_specs = [
            config_pb2.SlicingSpec(feature_keys=['feature1']),
            config_pb2.SlicingSpec(feature_values={'feature2': 'value1'})
        ]

        # For cross-slice tests.
        baseline_slice_spec = config_pb2.SlicingSpec(feature_keys=['feature3'])

        metrics_specs = [
            config_pb2.MetricsSpec(
                thresholds={
                    'auc':
                    config_pb2.MetricThreshold(
                        value_threshold=config_pb2.GenericValueThreshold()),
                    'mean/label':
                    config_pb2.MetricThreshold(
                        value_threshold=config_pb2.GenericValueThreshold(),
                        change_threshold=config_pb2.GenericChangeThreshold()),
                    'mse':
                    config_pb2.MetricThreshold(
                        change_threshold=config_pb2.GenericChangeThreshold())
                },
                per_slice_thresholds={
                    'auc':
                    config_pb2.PerSliceMetricThresholds(thresholds=[
                        config_pb2.PerSliceMetricThreshold(
                            slicing_specs=slice_specs,
                            threshold=config_pb2.MetricThreshold(
                                value_threshold=config_pb2.
                                GenericValueThreshold()))
                    ]),
                    'mean/label':
                    config_pb2.PerSliceMetricThresholds(thresholds=[
                        config_pb2.PerSliceMetricThreshold(
                            slicing_specs=slice_specs,
                            threshold=config_pb2.MetricThreshold(
                                value_threshold=config_pb2.
                                GenericValueThreshold(),
                                change_threshold=config_pb2.
                                GenericChangeThreshold()))
                    ])
                },
                cross_slice_thresholds={
                    'auc':
                    config_pb2.CrossSliceMetricThresholds(thresholds=[
                        config_pb2.CrossSliceMetricThreshold(
                            cross_slicing_specs=[
                                config_pb2.CrossSlicingSpec(
                                    baseline_spec=baseline_slice_spec,
                                    slicing_specs=slice_specs)
                            ],
                            threshold=config_pb2.MetricThreshold(
                                value_threshold=config_pb2.
                                GenericValueThreshold(),
                                change_threshold=config_pb2.
                                GenericChangeThreshold()))
                    ]),
                    'mse':
                    config_pb2.CrossSliceMetricThresholds(thresholds=[
                        config_pb2.CrossSliceMetricThreshold(
                            cross_slicing_specs=[
                                config_pb2.CrossSlicingSpec(
                                    baseline_spec=baseline_slice_spec,
                                    slicing_specs=slice_specs)
                            ],
                            threshold=config_pb2.MetricThreshold(
                                change_threshold=config_pb2.
                                GenericChangeThreshold())),
                        # Test for duplicate cross_slicing_spec.
                        config_pb2.CrossSliceMetricThreshold(
                            cross_slicing_specs=[
                                config_pb2.CrossSlicingSpec(
                                    baseline_spec=baseline_slice_spec,
                                    slicing_specs=slice_specs)
                            ],
                            threshold=config_pb2.MetricThreshold(
                                value_threshold=config_pb2.
                                GenericValueThreshold()))
                    ])
                },
                model_names=['model_name'],
                output_names=['output_name']),
            config_pb2.MetricsSpec(metrics=[
                config_pb2.MetricConfig(
                    class_name='ExampleCount',
                    config=json.dumps({'name': 'example_count'}),
                    threshold=config_pb2.MetricThreshold(
                        value_threshold=config_pb2.GenericValueThreshold()))
            ],
                                   model_names=['model_name1', 'model_name2'],
                                   example_weights=config_pb2.
                                   ExampleWeightOptions(unweighted=True)),
            config_pb2.MetricsSpec(metrics=[
                config_pb2.MetricConfig(
                    class_name='WeightedExampleCount',
                    config=json.dumps({'name': 'weighted_example_count'}),
                    threshold=config_pb2.MetricThreshold(
                        value_threshold=config_pb2.GenericValueThreshold()))
            ],
                                   model_names=['model_name1', 'model_name2'],
                                   output_names=[
                                       'output_name1', 'output_name2'
                                   ],
                                   example_weights=config_pb2.
                                   ExampleWeightOptions(weighted=True)),
            config_pb2.MetricsSpec(metrics=[
                config_pb2.MetricConfig(
                    class_name='MeanSquaredError',
                    config=json.dumps({'name': 'mse'}),
                    threshold=config_pb2.MetricThreshold(
                        change_threshold=config_pb2.GenericChangeThreshold())),
                config_pb2.MetricConfig(
                    class_name='MeanLabel',
                    config=json.dumps({'name': 'mean_label'}),
                    threshold=config_pb2.MetricThreshold(
                        change_threshold=config_pb2.GenericChangeThreshold()),
                    per_slice_thresholds=[
                        config_pb2.PerSliceMetricThreshold(
                            slicing_specs=slice_specs,
                            threshold=config_pb2.MetricThreshold(
                                change_threshold=config_pb2.
                                GenericChangeThreshold())),
                    ],
                    cross_slice_thresholds=[
                        config_pb2.CrossSliceMetricThreshold(
                            cross_slicing_specs=[
                                config_pb2.CrossSlicingSpec(
                                    baseline_spec=baseline_slice_spec,
                                    slicing_specs=slice_specs)
                            ],
                            threshold=config_pb2.MetricThreshold(
                                change_threshold=config_pb2.
                                GenericChangeThreshold()))
                    ]),
            ],
                                   model_names=['model_name'],
                                   output_names=['output_name'],
                                   binarize=config_pb2.BinarizationOptions(
                                       class_ids={'values': [0, 1]}),
                                   aggregate=config_pb2.AggregationOptions(
                                       macro_average=True,
                                       class_weights={
                                           0: 1.0,
                                           1: 1.0
                                       }))
        ]

        thresholds = metric_specs.metric_thresholds_from_metrics_specs(
            metrics_specs, eval_config=config_pb2.EvalConfig())

        expected_keys_and_threshold_counts = {
            metric_types.MetricKey(name='auc',
                                   model_name='model_name',
                                   output_name='output_name',
                                   is_diff=False,
                                   example_weighted=None):
            4,
            metric_types.MetricKey(name='auc',
                                   model_name='model_name',
                                   output_name='output_name',
                                   is_diff=True,
                                   example_weighted=None):
            1,
            metric_types.MetricKey(name='mean/label',
                                   model_name='model_name',
                                   output_name='output_name',
                                   is_diff=True,
                                   example_weighted=None):
            3,
            metric_types.MetricKey(name='mean/label',
                                   model_name='model_name',
                                   output_name='output_name',
                                   is_diff=False,
                                   example_weighted=None):
            3,
            metric_types.MetricKey(name='example_count',
                                   model_name='model_name1'):
            1,
            metric_types.MetricKey(name='example_count',
                                   model_name='model_name2'):
            1,
            metric_types.MetricKey(name='weighted_example_count',
                                   model_name='model_name1',
                                   output_name='output_name1',
                                   example_weighted=True):
            1,
            metric_types.MetricKey(name='weighted_example_count',
                                   model_name='model_name1',
                                   output_name='output_name2',
                                   example_weighted=True):
            1,
            metric_types.MetricKey(name='weighted_example_count',
                                   model_name='model_name2',
                                   output_name='output_name1',
                                   example_weighted=True):
            1,
            metric_types.MetricKey(name='weighted_example_count',
                                   model_name='model_name2',
                                   output_name='output_name2',
                                   example_weighted=True):
            1,
            metric_types.MetricKey(name='mse',
                                   model_name='model_name',
                                   output_name='output_name',
                                   sub_key=metric_types.SubKey(class_id=0),
                                   is_diff=True):
            1,
            metric_types.MetricKey(name='mse',
                                   model_name='model_name',
                                   output_name='output_name',
                                   sub_key=metric_types.SubKey(class_id=1),
                                   is_diff=True):
            1,
            metric_types.MetricKey(name='mse',
                                   model_name='model_name',
                                   output_name='output_name',
                                   is_diff=False,
                                   example_weighted=None):
            1,
            metric_types.MetricKey(name='mse',
                                   model_name='model_name',
                                   output_name='output_name',
                                   is_diff=True,
                                   example_weighted=None):
            2,
            metric_types.MetricKey(name='mse',
                                   model_name='model_name',
                                   output_name='output_name',
                                   aggregation_type=metric_types.AggregationType(macro_average=True),
                                   is_diff=True):
            1,
            metric_types.MetricKey(name='mean_label',
                                   model_name='model_name',
                                   output_name='output_name',
                                   sub_key=metric_types.SubKey(class_id=0),
                                   is_diff=True):
            4,
            metric_types.MetricKey(name='mean_label',
                                   model_name='model_name',
                                   output_name='output_name',
                                   sub_key=metric_types.SubKey(class_id=1),
                                   is_diff=True):
            4,
            metric_types.MetricKey(name='mean_label',
                                   model_name='model_name',
                                   output_name='output_name',
                                   aggregation_type=metric_types.AggregationType(macro_average=True),
                                   is_diff=True):
            4
        }
        self.assertLen(thresholds, len(expected_keys_and_threshold_counts))
        for key, count in expected_keys_and_threshold_counts.items():
            self.assertIn(key, thresholds)
            self.assertLen(thresholds[key], count,
                           'failed for key {}'.format(key))
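
The thresholds in this test are intentionally left empty because the test only counts them. For reference, a populated threshold might look like the following sketch (the bound values are illustrative):

threshold = config_pb2.MetricThreshold(
    # Fail validation if the metric value drops below 0.7.
    value_threshold=config_pb2.GenericValueThreshold(
        lower_bound={'value': 0.7}),
    # Require the candidate not to regress relative to the baseline.
    change_threshold=config_pb2.GenericChangeThreshold(
        direction=config_pb2.MetricDirection.HIGHER_IS_BETTER,
        absolute={'value': -1e-10}))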