Example #1
def _aggregation_type(
    spec: config.MetricsSpec) -> Optional[metric_types.AggregationType]:
  """Returns AggregationType associated with AggregationOptions at offset."""
  if spec.aggregate.micro_average:
    return metric_types.AggregationType(micro_average=True)
  if spec.aggregate.macro_average:
    return metric_types.AggregationType(macro_average=True)
  if spec.aggregate.weighted_macro_average:
    return metric_types.AggregationType(weighted_macro_average=True)
  return None
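A minimal usage sketch for the helper above; the import paths and in-scope `_aggregation_type` are assumptions based on the tensorflow_model_analysis layout used throughout these examples.

# Assumed imports; adjust to your TFMA version/layout.
from tensorflow_model_analysis import config
from tensorflow_model_analysis.metrics import metric_types

# A spec whose AggregationOptions request macro averaging maps to the
# corresponding AggregationType.
spec = config.MetricsSpec(
    aggregate=config.AggregationOptions(macro_average=True))
assert _aggregation_type(spec) == metric_types.AggregationType(
    macro_average=True)
# A spec without any aggregation option set yields None.
assert _aggregation_type(config.MetricsSpec()) is None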
Example #2
 def testRaisesErrorForInvalidNonSparseSettings(self):
   with self.assertRaises(ValueError):
     tf_metric_wrapper.tf_metric_computations(
         [
             tf.keras.metrics.SparseCategoricalCrossentropy(
                 name='sparse_categorical_crossentropy')
         ],
         aggregation_type=metric_types.AggregationType(micro_average=True))
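For contrast, a configuration the same wrapper accepts (mirroring the class-weight test further below): a dense metric such as MeanSquaredError combined with micro averaging. The class weights here are illustrative, not taken from TFMA.

computations = tf_metric_wrapper.tf_metric_computations(
    [tf.keras.metrics.MeanSquaredError(name='mse')],
    aggregation_type=metric_types.AggregationType(micro_average=True),
    class_weights={0: 0.5, 1: 0.5})  # illustrative weights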
Example #3
    def testStandardMetricInputsWithTopKAndAggregationTypeToNumpy(self):
        example = metric_types.StandardMetricInputs(
            labels={'output_name': np.array([1])},
            predictions={'output_name': np.array([0, 0.5, 0.3, 0.9])},
            example_weights={'output_name': np.array([1.0])})
        iterator = metric_util.to_label_prediction_example_weight(
            example,
            output_name='output_name',
            sub_key=metric_types.SubKey(top_k=2),
            aggregation_type=metric_types.AggregationType(micro_average=True))

        for expected_label, expected_prediction in zip((1.0, 0.0), (0.5, 0.9)):
            got_label, got_pred, got_example_weight = next(iterator)
            self.assertAllClose(got_label, np.array([expected_label]))
            self.assertAllClose(got_pred, np.array([expected_prediction]))
            self.assertAllClose(got_example_weight, np.array([1.0]))
Example #4
    def testMetricWithClassWeights(self):
        computation = tf_metric_wrapper.tf_metric_computations(
            [tf.keras.metrics.MeanSquaredError(name='mse')],
            aggregation_type=metric_types.AggregationType(micro_average=True),
            class_weights={
                0: 0.1,
                1: 0.2,
                2: 0.3,
                3: 0.4
            })[0]

        # Simulate a multi-class problem with 4 labels. The use of class weights
        # implies micro averaging which only makes sense for multi-class metrics.
        example = {
            'labels': [0, 0, 1, 0],
            'predictions': [0, 0.5, 0.3, 0.9],
            'example_weights': [1.0]
        }

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            result = (
                pipeline
                | 'Create' >> beam.Create([example])
                | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
                | 'AddSlice' >> beam.Map(lambda x: ((), x))
                | 'Combine' >> beam.CombinePerKey(computation.combiner))

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())
                    mse_key = metric_types.MetricKey(name='mse')
                    # numerator = (0.1*0**2 + 0.2*0.5**2 + 0.3*0.7**2 + 0.4*0.9**2)
                    # denominator = (.1 + .2 + 0.3 + 0.4)
                    # numerator / denominator = 0.521
                    self.assertDictElementsAlmostEqual(got_metrics,
                                                       {mse_key: 0.521})

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result, label='result')
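The expected value asserted above can be re-derived without TFMA; this is a plain-Python restatement of the arithmetic in the comment (weighted squared errors over the sum of class weights).

labels = [0, 0, 1, 0]
predictions = [0, 0.5, 0.3, 0.9]
class_weights = {0: 0.1, 1: 0.2, 2: 0.3, 3: 0.4}

numerator = sum(
    class_weights[i] * (label - pred) ** 2
    for i, (label, pred) in enumerate(zip(labels, predictions)))
denominator = sum(class_weights.values())
print(round(numerator / denominator, 3))  # 0.521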
Example #5
 def testMetricKeyFromProto(self):
     metric_keys = [
         metric_types.MetricKey(name='metric_name'),
         metric_types.MetricKey(name='metric_name',
                                model_name='model_name',
                                output_name='output_name',
                                sub_key=metric_types.SubKey(class_id=1),
                                is_diff=True),
         metric_types.MetricKey(
             name='metric_name',
             model_name='model_name',
             output_name='output_name',
             sub_key=metric_types.SubKey(top_k=2),
             aggregation_type=metric_types.AggregationType(
                 micro_average=True))
     ]
     for key in metric_keys:
         got_key = metric_types.MetricKey.from_proto(key.to_proto())
         self.assertEqual(key, got_key, '{} != {}'.format(key, got_key))
Example #6
    def testStandardMetricInputsWithClassWeights(self):
        example = metric_types.StandardMetricInputs(
            labels={'output_name': np.array([2])},
            predictions={'output_name': np.array([0, 0.5, 0.3, 0.9])},
            example_weights={'output_name': np.array([1.0])})
        iterator = metric_util.to_label_prediction_example_weight(
            example,
            output_name='output_name',
            aggregation_type=metric_types.AggregationType(micro_average=True),
            class_weights={
                0: 1.0,
                1: 0.5,
                2: 0.25,
                3: 1.0
            },
            flatten=True)

        for expected_label, expected_prediction, expected_weight in zip(
            (0.0, 0.0, 1.0, 0.0), (0.0, 0.5, 0.3, 0.9), (1.0, 0.5, 0.25, 1.0)):
            got_label, got_pred, got_example_weight = next(iterator)
            self.assertAllClose(got_label, np.array([expected_label]))
            self.assertAllClose(got_pred, np.array([expected_prediction]))
            self.assertAllClose(got_example_weight,
                                np.array([expected_weight]))
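As a sanity check, the flattened values iterated above can be reproduced with plain NumPy: label 2 is one-hot encoded against the four classes and each position picks up its class weight. Only the numpy import is assumed.

import numpy as np

label = 2
predictions = np.array([0, 0.5, 0.3, 0.9])
class_weights = {0: 1.0, 1: 0.5, 2: 0.25, 3: 1.0}

one_hot_label = np.zeros_like(predictions)
one_hot_label[label] = 1.0
weights = np.array([class_weights[i] for i in range(len(predictions))])
print(one_hot_label)  # [0. 0. 1. 0.]
print(predictions)    # [0.  0.5 0.3 0.9]
print(weights)        # [1.   0.5  0.25 1.  ]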
Example #7
 def testAggregationTypeLessThan(self):
     self.assertLess(metric_types.AggregationType(macro_average=True),
                     metric_types.AggregationType(micro_average=True))
     self.assertLess(
         metric_types.AggregationType(weighted_macro_average=True),
         metric_types.AggregationType(macro_average=True))
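The same ordering can be seen by sorting, assuming the less-than comparisons asserted above are the ones used for ordering: weighted macro sorts before macro, and micro sorts last. This is only an illustration of the test, not additional API.

aggregation_types = [
    metric_types.AggregationType(micro_average=True),
    metric_types.AggregationType(weighted_macro_average=True),
    metric_types.AggregationType(macro_average=True),
]
# Expected order per the assertions above:
# weighted_macro_average < macro_average < micro_average.
print(sorted(aggregation_types))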
Example #8
    def testMetricThresholdsFromMetricsSpecs(self):
        slice_specs = [
            config.SlicingSpec(feature_keys=['feature1']),
            config.SlicingSpec(feature_values={'feature2': 'value1'})
        ]

        # For cross slice tests.
        baseline_slice_spec = config.SlicingSpec(feature_keys=['feature3'])

        metrics_specs = [
            config.MetricsSpec(
                thresholds={
                    'auc':
                    config.MetricThreshold(
                        value_threshold=config.GenericValueThreshold()),
                    'mean/label':
                    config.MetricThreshold(
                        value_threshold=config.GenericValueThreshold(),
                        change_threshold=config.GenericChangeThreshold()),
                    'mse':
                    config.MetricThreshold(
                        change_threshold=config.GenericChangeThreshold())
                },
                per_slice_thresholds={
                    'auc':
                    config.PerSliceMetricThresholds(thresholds=[
                        config.PerSliceMetricThreshold(
                            slicing_specs=slice_specs,
                            threshold=config.MetricThreshold(
                                value_threshold=config.GenericValueThreshold(
                                )))
                    ]),
                    'mean/label':
                    config.PerSliceMetricThresholds(thresholds=[
                        config.PerSliceMetricThreshold(
                            slicing_specs=slice_specs,
                            threshold=config.MetricThreshold(
                                value_threshold=config.GenericValueThreshold(),
                                change_threshold=config.GenericChangeThreshold(
                                )))
                    ])
                },
                cross_slice_thresholds={
                    'auc':
                    config.CrossSliceMetricThresholds(thresholds=[
                        config.CrossSliceMetricThreshold(
                            cross_slicing_specs=[
                                config.CrossSlicingSpec(
                                    baseline_spec=baseline_slice_spec,
                                    slicing_specs=slice_specs)
                            ],
                            threshold=config.MetricThreshold(
                                value_threshold=config.GenericValueThreshold(),
                                change_threshold=config.GenericChangeThreshold(
                                )))
                    ]),
                    'mse':
                    config.CrossSliceMetricThresholds(thresholds=[
                        config.CrossSliceMetricThreshold(
                            cross_slicing_specs=[
                                config.CrossSlicingSpec(
                                    baseline_spec=baseline_slice_spec,
                                    slicing_specs=slice_specs)
                            ],
                            threshold=config.MetricThreshold(
                                change_threshold=config.GenericChangeThreshold(
                                ))),
                        # Test for duplicate cross_slicing_spec.
                        config.CrossSliceMetricThreshold(
                            cross_slicing_specs=[
                                config.CrossSlicingSpec(
                                    baseline_spec=baseline_slice_spec,
                                    slicing_specs=slice_specs)
                            ],
                            threshold=config.MetricThreshold(
                                value_threshold=config.GenericValueThreshold())
                        )
                    ])
                },
                model_names=['model_name'],
                output_names=['output_name']),
            config.MetricsSpec(metrics=[
                config.MetricConfig(
                    class_name='ExampleCount',
                    config=json.dumps({'name': 'example_count'}),
                    threshold=config.MetricThreshold(
                        value_threshold=config.GenericValueThreshold()))
            ],
                               model_names=['model_name1', 'model_name2'],
                               output_names=['output_name1', 'output_name2']),
            config.MetricsSpec(metrics=[
                config.MetricConfig(
                    class_name='WeightedExampleCount',
                    config=json.dumps({'name': 'weighted_example_count'}),
                    threshold=config.MetricThreshold(
                        value_threshold=config.GenericValueThreshold()))
            ],
                               model_names=['model_name1', 'model_name2'],
                               output_names=['output_name1', 'output_name2']),
            config.MetricsSpec(metrics=[
                config.MetricConfig(
                    class_name='MeanSquaredError',
                    config=json.dumps({'name': 'mse'}),
                    threshold=config.MetricThreshold(
                        change_threshold=config.GenericChangeThreshold())),
                config.MetricConfig(
                    class_name='MeanLabel',
                    config=json.dumps({'name': 'mean_label'}),
                    threshold=config.MetricThreshold(
                        change_threshold=config.GenericChangeThreshold()),
                    per_slice_thresholds=[
                        config.PerSliceMetricThreshold(
                            slicing_specs=slice_specs,
                            threshold=config.MetricThreshold(
                                change_threshold=config.GenericChangeThreshold(
                                ))),
                    ],
                    cross_slice_thresholds=[
                        config.CrossSliceMetricThreshold(
                            cross_slicing_specs=[
                                config.CrossSlicingSpec(
                                    baseline_spec=baseline_slice_spec,
                                    slicing_specs=slice_specs)
                            ],
                            threshold=config.MetricThreshold(
                                change_threshold=config.GenericChangeThreshold(
                                )))
                    ]),
            ],
                               model_names=['model_name'],
                               output_names=['output_name'],
                               binarize=config.BinarizationOptions(
                                   class_ids={'values': [0, 1]}),
                               aggregate=config.AggregationOptions(
                                   macro_average=True,
                                   class_weights={
                                       0: 1.0,
                                       1: 1.0
                                   }))
        ]

        thresholds = metric_specs.metric_thresholds_from_metrics_specs(
            metrics_specs)

        expected_keys_and_threshold_counts = {
            metric_types.MetricKey(name='auc',
                                   model_name='model_name',
                                   output_name='output_name',
                                   is_diff=False):
            4,
            metric_types.MetricKey(name='auc',
                                   model_name='model_name',
                                   output_name='output_name',
                                   is_diff=True):
            1,
            metric_types.MetricKey(name='mean/label',
                                   model_name='model_name',
                                   output_name='output_name',
                                   is_diff=True):
            3,
            metric_types.MetricKey(name='mean/label',
                                   model_name='model_name',
                                   output_name='output_name',
                                   is_diff=False):
            3,
            metric_types.MetricKey(name='example_count',
                                   model_name='model_name1',
                                   output_name='output_name1'):
            1,
            metric_types.MetricKey(name='example_count',
                                   model_name='model_name1',
                                   output_name='output_name2'):
            1,
            metric_types.MetricKey(name='example_count',
                                   model_name='model_name2',
                                   output_name='output_name1'):
            1,
            metric_types.MetricKey(name='example_count',
                                   model_name='model_name2',
                                   output_name='output_name2'):
            1,
            metric_types.MetricKey(name='weighted_example_count',
                                   model_name='model_name1',
                                   output_name='output_name1'):
            1,
            metric_types.MetricKey(name='weighted_example_count',
                                   model_name='model_name1',
                                   output_name='output_name2'):
            1,
            metric_types.MetricKey(name='weighted_example_count',
                                   model_name='model_name2',
                                   output_name='output_name1'):
            1,
            metric_types.MetricKey(name='weighted_example_count',
                                   model_name='model_name2',
                                   output_name='output_name2'):
            1,
            metric_types.MetricKey(name='mse',
                                   model_name='model_name',
                                   output_name='output_name',
                                   sub_key=metric_types.SubKey(class_id=0),
                                   is_diff=True):
            1,
            metric_types.MetricKey(name='mse',
                                   model_name='model_name',
                                   output_name='output_name',
                                   sub_key=metric_types.SubKey(class_id=1),
                                   is_diff=True):
            1,
            metric_types.MetricKey(name='mse',
                                   model_name='model_name',
                                   output_name='output_name',
                                   is_diff=True):
            2,
            metric_types.MetricKey(name='mse',
                                   model_name='model_name',
                                   output_name='output_name',
                                   is_diff=False):
            1,
            metric_types.MetricKey(name='mse',
                                   model_name='model_name',
                                   output_name='output_name',
                                   aggregation_type=metric_types.AggregationType(macro_average=True),
                                   is_diff=True):
            1,
            metric_types.MetricKey(name='mean_label',
                                   model_name='model_name',
                                   output_name='output_name',
                                   sub_key=metric_types.SubKey(class_id=0),
                                   is_diff=True):
            4,
            metric_types.MetricKey(name='mean_label',
                                   model_name='model_name',
                                   output_name='output_name',
                                   sub_key=metric_types.SubKey(class_id=1),
                                   is_diff=True):
            4,
            metric_types.MetricKey(name='mean_label',
                                   model_name='model_name',
                                   output_name='output_name',
                                   aggregation_type=metric_types.AggregationType(macro_average=True),
                                   is_diff=True):
            4
        }
        self.assertLen(thresholds, len(expected_keys_and_threshold_counts))
        for key, count in expected_keys_and_threshold_counts.items():
            self.assertIn(key, thresholds)
            self.assertLen(thresholds[key], count,
                           'failed for key {}'.format(key))
Example #9
def macro_average(
    metric_name: Text,
    sub_keys: Iterable[metric_types.SubKey],
    eval_config: Optional[config_pb2.EvalConfig] = None,
    model_name: Text = '',
    output_name: Text = '',
    sub_key: Optional[metric_types.SubKey] = None,
    class_weights: Optional[Dict[int, float]] = None
) -> metric_types.MetricComputations:
    """Returns metric computations for computing macro average of given metric.

  Args:
    metric_name: Name of underlying metric average is being computed for.
    sub_keys: Sub keys used to compute the metric (e.g. class_ids, etc).
    eval_config: Eval config.
    model_name: Optional model name.
    output_name: Optional output name.
    sub_key: Optional sub key associated with aggregation metric (e.g. top_k).
    class_weights: Optional class weights to apply. Required if sub_key is not
      provided. If class_weights are provided, but a sub_key.class_id (if
      sub_key is None) or sub_key.k (if sub_key is top_k) is not set or not
      found in the dictionary then 0.0 is assumed.

  Returns:
    Computation for performing the macro average.
  """
    del eval_config

    key = metric_types.MetricKey(
        name=metric_name,
        model_name=model_name,
        output_name=output_name,
        sub_key=sub_key,
        aggregation_type=metric_types.AggregationType(macro_average=True))

    def result(
        metrics: Dict[metric_types.MetricKey, float]
    ) -> Dict[metric_types.MetricKey, float]:
        """Returns macro average."""
        total_value = 0.0
        total_weight = 0.0
        for sub_key in sub_keys:
            child_key = metric_types.MetricKey(name=metric_name,
                                               model_name=model_name,
                                               output_name=output_name,
                                               sub_key=sub_key)
            if child_key not in metrics:
                # Use private name if not found under metric name
                child_key = metric_types.MetricKey(name='_' + metric_name,
                                                   model_name=model_name,
                                                   output_name=output_name,
                                                   sub_key=sub_key)
            weight = 1.0 if not class_weights else 0.0
            offset = None
            if (child_key.sub_key is not None
                    and child_key.sub_key.class_id is not None):
                offset = child_key.sub_key.class_id
            elif child_key.sub_key is not None and child_key.sub_key.k is not None:
                offset = child_key.sub_key.k
            if offset is not None and offset in class_weights:
                weight = class_weights[offset]
            total_value += _to_float(metrics[child_key]) * weight
            total_weight += weight
        average = total_value / total_weight if total_weight else float('nan')
        return {key: average}

    return [metric_types.DerivedMetricComputation(keys=[key], result=result)]
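A worked illustration of the weighted sum the `result` callback above performs, using hypothetical per-class metric values and class weights (not values taken from TFMA itself).

per_class_metric = {0: 0.8, 1: 0.6, 2: 0.4}  # hypothetical child metric values
class_weights = {0: 1.0, 1: 1.0, 2: 2.0}     # hypothetical class weights

total_value = sum(
    per_class_metric[c] * class_weights[c] for c in per_class_metric)
total_weight = sum(class_weights[c] for c in per_class_metric)
print(total_value / total_weight)  # (0.8 + 0.6 + 0.8) / 4.0 = 0.55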
Example #10
def weighted_macro_average(
    metric_name: Text,
    sub_keys: Iterable[metric_types.SubKey],
    eval_config: Optional[config_pb2.EvalConfig] = None,
    model_name: Text = '',
    output_name: Text = '',
    sub_key: Optional[metric_types.SubKey] = None,
    class_weights: Optional[Dict[int, float]] = None
) -> metric_types.MetricComputations:
    """Returns metric computations for computing weighted macro average of metric.

  The weights per class are based on the percentage of positive labels for each
  class.

  Args:
    metric_name: Name of metric weighted average is being computed for.
    sub_keys: Sub keys used to compute the metric (e.g. class_ids, etc).
    eval_config: Eval config.
    model_name: Optional model name.
    output_name: Optional output name.
    sub_key: Optional sub key associated with aggregation metric (e.g. top_k).
    class_weights: Optional class weights to apply. Required if sub_key is not
      provided. If class_weights are provided, but a sub_key.class_id (if
      sub_key is None) or sub_key.k (if sub_key is top_k) is not set or not
      found in the dictionary then 0.0 is assumed. Note that these weights are
      applied in addition to the weights based on the positive labels for each
      class.

  Returns:
    Computation for performing the weighted macro average.
  """
    key = metric_types.MetricKey(
        name=metric_name,
        model_name=model_name,
        output_name=output_name,
        sub_key=sub_key,
        aggregation_type=metric_types.AggregationType(macro_average=True))

    class_ids = [k.class_id for k in sub_keys if k.class_id is not None]

    # Compute the weights for labels.
    computations = _class_weights_from_labels(class_ids=class_ids,
                                              eval_config=eval_config,
                                              model_name=model_name,
                                              output_name=output_name)
    # Class weights metrics are based on a single computation and key.
    class_weights_from_labels_key = computations[0].keys[0]

    def result(
        metrics: Dict[metric_types.MetricKey, Any]
    ) -> Dict[metric_types.MetricKey, float]:
        """Returns weighted macro average."""
        class_weights_from_labels = metrics[class_weights_from_labels_key]
        total_value = 0.0
        total_weight = 0.0
        for sub_key in sub_keys:
            child_key = metric_types.MetricKey(name=metric_name,
                                               model_name=model_name,
                                               output_name=output_name,
                                               sub_key=sub_key)
            if child_key not in metrics:
                # Use private name if not found under metric name
                child_key = metric_types.MetricKey(name='_' + metric_name,
                                                   model_name=model_name,
                                                   output_name=output_name,
                                                   sub_key=sub_key)
            weight = 1.0 if not class_weights else 0.0
            offset = None
            if (child_key.sub_key is not None
                    and child_key.sub_key.class_id is not None):
                offset = child_key.sub_key.class_id
            elif child_key.sub_key is not None and child_key.sub_key.k is not None:
                offset = child_key.sub_key.k
            if offset is not None:
                if (class_weights_from_labels and child_key.sub_key.class_id
                        in class_weights_from_labels):
                    weight = class_weights_from_labels[offset]
                if class_weights and child_key.sub_key.class_id in class_weights:
                    weight *= class_weights[offset]
            total_value += _to_float(metrics[child_key]) * weight
            total_weight += weight
        average = total_value / total_weight if total_weight else float('nan')
        return {key: average}

    derived_computation = metric_types.DerivedMetricComputation(keys=[key],
                                                                result=result)
    computations.append(derived_computation)
    return computations
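To make the weight combination concrete: in the `result` callback above, the label-derived weight for each class is multiplied by any user-supplied class weight before the macro average is taken. The numbers below are hypothetical.

class_weights_from_labels = {0: 0.7, 1: 0.2, 2: 0.1}  # hypothetical label fractions
user_class_weights = {0: 1.0, 1: 1.0, 2: 2.0}         # hypothetical class_weights arg
per_class_metric = {0: 0.9, 1: 0.5, 2: 0.3}           # hypothetical child metric values

total_value, total_weight = 0.0, 0.0
for class_id, value in per_class_metric.items():
  weight = class_weights_from_labels[class_id] * user_class_weights[class_id]
  total_value += value * weight
  total_weight += weight
print(total_value / total_weight)  # ~0.718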
Example #11
    def testToComputations(self):
        computations = metric_specs.to_computations(
            metric_specs.specs_from_metrics(
                [
                    tf.keras.metrics.MeanSquaredError('mse'),
                    # Add a loss exactly same as metric
                    # (https://github.com/tensorflow/tfx/issues/1550)
                    tf.keras.losses.MeanSquaredError(name='loss'),
                    calibration.MeanLabel('mean_label')
                ],
                model_names=['model_name'],
                output_names=['output_1', 'output_2'],
                output_weights={
                    'output_1': 1.0,
                    'output_2': 1.0
                },
                binarize=config_pb2.BinarizationOptions(
                    class_ids={'values': [0, 1]}),
                aggregate=config_pb2.AggregationOptions(macro_average=True,
                                                        class_weights={
                                                            0: 1.0,
                                                            1: 1.0
                                                        })),
            config_pb2.EvalConfig())

        keys = []
        for m in computations:
            for k in m.keys:
                if not k.name.startswith('_'):
                    keys.append(k)
        self.assertLen(keys, 31)
        self.assertIn(
            metric_types.MetricKey(name='example_count',
                                   model_name='model_name'), keys)
        for output_name in ('output_1', 'output_2', ''):
            self.assertIn(
                metric_types.MetricKey(name='weighted_example_count',
                                       model_name='model_name',
                                       output_name=output_name,
                                       example_weighted=True), keys)
            self.assertIn(
                metric_types.MetricKey(
                    name='mse',
                    model_name='model_name',
                    output_name=output_name,
                    sub_key=metric_types.SubKey(class_id=0)), keys)
            self.assertIn(
                metric_types.MetricKey(
                    name='mse',
                    model_name='model_name',
                    output_name=output_name,
                    sub_key=metric_types.SubKey(class_id=1)), keys)
            aggregation_type = metric_types.AggregationType(
                macro_average=True) if output_name else None
            self.assertIn(
                metric_types.MetricKey(name='mse',
                                       model_name='model_name',
                                       output_name=output_name,
                                       aggregation_type=aggregation_type),
                keys)
            self.assertIn(
                metric_types.MetricKey(
                    name='loss',
                    model_name='model_name',
                    output_name=output_name,
                    sub_key=metric_types.SubKey(class_id=0)), keys)
            self.assertIn(
                metric_types.MetricKey(
                    name='loss',
                    model_name='model_name',
                    output_name=output_name,
                    sub_key=metric_types.SubKey(class_id=1)), keys)
            aggregation_type = metric_types.AggregationType(
                macro_average=True) if output_name else None
            self.assertIn(
                metric_types.MetricKey(name='loss',
                                       model_name='model_name',
                                       output_name=output_name,
                                       aggregation_type=aggregation_type),
                keys)
            self.assertIn(
                metric_types.MetricKey(
                    name='mean_label',
                    model_name='model_name',
                    output_name=output_name,
                    sub_key=metric_types.SubKey(class_id=0)), keys)
            self.assertIn(
                metric_types.MetricKey(
                    name='mean_label',
                    model_name='model_name',
                    output_name=output_name,
                    sub_key=metric_types.SubKey(class_id=1)), keys)
            aggregation_type = metric_types.AggregationType(
                macro_average=True) if output_name else None
            self.assertIn(
                metric_types.MetricKey(name='mean_label',
                                       model_name='model_name',
                                       output_name=output_name,
                                       aggregation_type=aggregation_type),
                keys)