Example #1
def to_standard_metric_inputs(
    extracts: types.Extracts,
    include_features: bool = False,
    include_transformed_features: bool = False,
    include_attributions: bool = False) -> metric_types.StandardMetricInputs:
  """Verifies extract keys and converts extracts to StandardMetricInputs."""
  if constants.LABELS_KEY not in extracts:
    raise ValueError('"{}" key not found in extracts. Check that the '
                     'configuration is setup properly to specify the name of '
                     'label input and that the proper extractor has been '
                     'configured to extract the labels from the inputs.'.format(
                         constants.LABELS_KEY))
  if constants.PREDICTIONS_KEY not in extracts:
    raise ValueError('"{}" key not found in extracts. Check that the proper '
                     'extractor has been configured to perform model '
                     'inference.'.format(constants.PREDICTIONS_KEY))
  if include_features and constants.FEATURES_KEY not in extracts:
    raise ValueError('"{}" key not found in extracts. Check that the proper '
                     'extractor has been configured to extract the features '
                     'from the inputs.'.format(constants.FEATURES_KEY))
  if (include_transformed_features and
      constants.TRANSFORMED_FEATURES_KEY not in extracts):
    raise ValueError('"{}" key not found in extracts. Check that the proper '
                     'extractor has been configured to extract the transformed '
                     'features from the inputs.'.format(
                         constants.TRANSFORMED_FEATURES_KEY))
  if (include_attributions and constants.ATTRIBUTIONS_KEY not in extracts):
    raise ValueError('"{}" key not found in extracts. Check that the proper '
                     'extractor has been configured to extract the '
                     'attributions from the inputs.'.format(
                         constants.ATTRIBUTIONS_KEY))
  return metric_types.StandardMetricInputs(extracts)
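
# A minimal usage sketch (illustrative, not from the source): build an
# extracts dict keyed by the same tfma constants the function checks, then
# convert it. The variable names here are hypothetical.
example_extracts = {
    constants.LABELS_KEY: np.array([1]),
    constants.PREDICTIONS_KEY: np.array([0.9]),
}
standard_inputs = to_standard_metric_inputs(example_extracts)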
Example #2
def to_standard_metric_inputs(
    extracts: types.Extracts,
    include_features: bool = False) -> metric_types.StandardMetricInputs:
  """Filters and converts extracts to StandardMetricInputs."""
  if constants.LABELS_KEY not in extracts:
    raise ValueError('"{}" key not found in extracts. Check that the '
                     'configuration is setup properly to specify the name of '
                     'label input and that the proper extractor has been '
                     'configured to extract the labels from the inputs.'.format(
                         constants.LABELS_KEY))
  if constants.PREDICTIONS_KEY not in extracts:
    raise ValueError('"{}" key not found in extracts. Check that the proper '
                     'extractor has been configured to perform model '
                     'inference.'.format(constants.PREDICTIONS_KEY))
  example_weights = None
  if constants.EXAMPLE_WEIGHTS_KEY in extracts:
    example_weights = extracts[constants.EXAMPLE_WEIGHTS_KEY]
  features = None
  if include_features:
    if constants.FEATURES_KEY not in extracts:
      raise ValueError('"{}" key not found in extracts. Check that the proper '
                       'extractor has been configured to extract the features '
                       'from the inputs.'.format(constants.FEATURES_KEY))
    features = extracts[constants.FEATURES_KEY]
  return metric_types.StandardMetricInputs(extracts[constants.LABELS_KEY],
                                           extracts[constants.PREDICTIONS_KEY],
                                           example_weights, features)
Example #3
    def testStandardMetricInputsWithMultipleOutputs(self):
        example = metric_types.StandardMetricInputs(
            label={
                'output1': np.array([0, 1]),
                'output2': np.array([1, 1])
            },
            prediction={
                'output1': np.array([0, 0.5]),
                'output2': np.array([0.2, 0.8])
            },
            example_weight={
                'output1': np.array([0.5]),
                'output2': np.array([1.0])
            })

        for output in ('output1', 'output2'):
            iterator = metric_util.to_label_prediction_example_weight(
                example, output_name=output, flatten=False)
            got_label, got_pred, got_example_weight = next(iterator)
            self.assertAllClose(got_label, example.label[output])
            self.assertAllEqual(got_pred, example.prediction[output])
            self.assertAllClose(got_example_weight,
                                example.example_weight[output])
Example #4
    def testStandardMetricInputsWithCustomLabelKeys(self):
        example = metric_types.StandardMetricInputs(
            labels={
                'custom_label': np.array([2]),
                'other_label': np.array([0])
            },
            predictions={'custom_prediction': np.array([0, 0.5, 0.3, 0.9])},
            example_weights=np.array([1.0]))
        eval_config = config_pb2.EvalConfig(model_specs=[
            config_pb2.ModelSpec(label_key='custom_label',
                                 prediction_key='custom_prediction')
        ])
        iterator = metric_util.to_label_prediction_example_weight(
            example, eval_config=eval_config)

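        # The ModelSpec's label_key/prediction_key select the 'custom_label'
        # and 'custom_prediction' entries; the sparse label 2 is one-hot
        # encoded across the four prediction scores.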
        for expected_label, expected_prediction in zip((0.0, 0.0, 1.0, 0.0),
                                                       (0.0, 0.5, 0.3, 0.9)):
            got_label, got_pred, got_example_weight = next(iterator)
            self.assertAllClose(got_label,
                                np.array([expected_label]),
                                atol=0,
                                rtol=0)
            self.assertAllClose(got_pred,
                                np.array([expected_prediction]),
                                atol=0,
                                rtol=0)
            self.assertAllClose(got_example_weight,
                                np.array([1.0]),
                                atol=0,
                                rtol=0)
Example #5
    def testStandardMetricInputsWithMissingStringLabel(self):
        example = metric_types.StandardMetricInputs(
            label=np.array(['d']),
            prediction={
                'scores': np.array([0.2, 0.7, 0.1]),
                'classes': np.array(['a', 'b', 'c'])
            },
            example_weight=np.array([1.0]))
        iterator = metric_util.to_label_prediction_example_weight(example)

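        # The label 'd' does not appear in prediction['classes'], so the
        # derived one-hot label is all zeros.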
        for expected_label, expected_prediction in zip((0.0, 0.0, 0.0),
                                                       (0.2, 0.7, 0.1)):
            got_label, got_pred, got_example_weight = next(iterator)
            self.assertAllClose(got_label,
                                np.array([expected_label]),
                                atol=0,
                                rtol=0)
            self.assertAllClose(got_pred,
                                np.array([expected_prediction]),
                                atol=0,
                                rtol=0)
            self.assertAllClose(got_example_weight,
                                np.array([1.0]),
                                atol=0,
                                rtol=0)
Example #6
    def testStandardMetricInputsRequiringSingleExampleWeightRaisesError(self):
        with self.assertRaises(ValueError):
            example = metric_types.StandardMetricInputs(
                labels=np.array([2]),
                predictions=np.array([0, 0.5, 0.3, 0.9]),
                example_weights=np.array([1.0, 0.0]))
            next(
                metric_util.to_label_prediction_example_weight(
                    example, require_single_example_weight=True))
    def testStandardMetricInputsWithZeroWeightsToNumpyWithoutFlatten(self):
        example = metric_types.StandardMetricInputs(
            np.array([2]), np.array([0, 0.5, 0.3, 0.9]), np.array([0.0]))
        got_label, got_pred, got_example_weight = next(
            metric_util.to_label_prediction_example_weight(
                example, flatten=False))

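        # A zero example weight is yielded as-is rather than filtered out.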
        self.assertAllClose(got_label, np.array([2]))
        self.assertAllClose(got_pred, np.array([0, 0.5, 0.3, 0.9]))
        self.assertAllClose(got_example_weight, np.array([0.0]))
    def testStandardMetricInputsWithClassWeightsRaisesErrorWithoutFlatten(self):
        with self.assertRaises(ValueError):
            example = metric_types.StandardMetricInputs(
                np.array([2]), np.array([0, 0.5, 0.3, 0.9]), np.array([1.0]))
            next(
                metric_util.to_label_prediction_example_weight(
                    example,
                    class_weights={
                        1: 0.5,
                        2: 0.25
                    },
                    flatten=False))
Example #9
    def testStandardMetricInputsWithNonScalarWeightsNoFlatten(self):
        example = metric_types.StandardMetricInputs(
            label=np.array([2]),
            prediction=np.array([0, 0.5, 0.3, 0.9]),
            example_weight=np.array([1.0, 0.0, 1.0, 1.0]))
        got_label, got_pred, got_example_weight = next(
            metric_util.to_label_prediction_example_weight(
                example, flatten=False, require_single_example_weight=False))
        self.assertAllClose(got_label, np.array([2]))
        self.assertAllEqual(got_pred, np.array([0, 0.5, 0.3, 0.9]))
        self.assertAllClose(got_example_weight, np.array([1.0, 0.0, 1.0, 1.0]))
    def testStandardMetricInputsWithZeroWeightsToNumpy(self):
        example = metric_types.StandardMetricInputs(
            np.array([2]), np.array([0, 0.5, 0.3, 0.9]), np.array([0.0]))
        iterable = metric_util.to_label_prediction_example_weight(example)

        for expected_label, expected_prediction in zip((0.0, 0.0, 1.0, 0.0),
                                                       (0.0, 0.5, 0.3, 0.9)):
            got_label, got_pred, got_example_weight = next(iterable)
            self.assertAllClose(got_label, np.array([expected_label]))
            self.assertAllClose(got_pred, np.array([expected_prediction]))
            self.assertAllClose(got_example_weight, np.array([0.0]))
    def testStandardMetricInputsToNumpyWithoutFlatten(self):
        example = metric_types.StandardMetricInputs(
            label={'output_name': np.array([2])},
            prediction={'output_name': np.array([0, 0.5, 0.3, 0.9])},
            example_weight={'output_name': np.array([1.0])})
        got_label, got_pred, got_example_weight = next(
            metric_util.to_label_prediction_example_weight(
                example, output_name='output_name', flatten=False))

        self.assertAllClose(got_label, np.array([2]))
        self.assertAllClose(got_pred, np.array([0, 0.5, 0.3, 0.9]))
        self.assertAllClose(got_example_weight, np.array([1.0]))
Example #12
    def testStandardMetricInputsToNumpy(self):
        example = metric_types.StandardMetricInputs(
            {'output_name': np.array([2])},
            {'output_name': np.array([0, 0.5, 0.3, 0.9])},
            {'output_name': np.array([1.0])})
        got_label, got_pred, got_example_weight = next(
            metric_util.to_label_prediction_example_weight(
                example, output_name='output_name', flatten=False))

        self.assertAllClose(got_label, np.array([2]))
        self.assertAllClose(got_pred, np.array([0, 0.5, 0.3, 0.9]))
        self.assertAllClose(got_example_weight, np.array([1.0]))
Example #13
    def testStandardMetricInputsWithoutPredictions(self):
        example = metric_types.StandardMetricInputs(
            label={'output_name': np.array([0, 0.5, 0.3, 0.9])},
            prediction={'output_name': np.array([])},
            example_weight={'output_name': np.array([1.0])})
        iterator = metric_util.to_label_prediction_example_weight(
            example, output_name='output_name')

        for expected_label in (0.0, 0.5, 0.3, 0.9):
            got_label, got_pred, got_example_weight = next(iterator)
            self.assertAllClose(got_label, np.array([expected_label]))
            self.assertAllEqual(got_pred, np.array([]))
            self.assertAllClose(got_example_weight, np.array([1.0]))
    def testStandardMetricInputsWithSparseTensorValue(self):
        example = metric_types.StandardMetricInputs(
            tf.compat.v1.SparseTensorValue(values=np.array([1]),
                                           indices=np.array([2]),
                                           dense_shape=(0, 1)),
            np.array([0, 0.5, 0.3, 0.9]), np.array([0.0]))
        iterable = metric_util.to_label_prediction_example_weight(example)

        for expected_label, expected_prediction in zip((0.0, 0.0, 1.0, 0.0),
                                                       (0.0, 0.5, 0.3, 0.9)):
            got_label, got_pred, got_example_weight = next(iterable)
            self.assertAllClose(got_label, np.array([expected_label]))
            self.assertAllClose(got_pred, np.array([expected_prediction]))
            self.assertAllClose(got_example_weight, np.array([0.0]))
Example #15
    def testStandardMetricInputsWithMissingLabelsAndExampleWeights(self):
        example = metric_types.StandardMetricInputs(
            prediction={
                'output1': np.array([0, 0.5]),
                'output2': np.array([0.2, 0.8])
            })

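        # With allow_none=True, a missing label comes back as an empty array
        # and the example weight defaults to 1.0.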
        for output in ('output1', 'output2'):
            iterator = metric_util.to_label_prediction_example_weight(
                example, output_name=output, flatten=False, allow_none=True)
            got_label, got_pred, got_example_weight = next(iterator)
            self.assertAllEqual(got_label, np.array([]))
            self.assertAllEqual(got_pred, example.prediction[output])
            self.assertAllEqual(got_example_weight, np.array([1.0]))
    def testStandardMetricInputsToNumpy(self):
        example = metric_types.StandardMetricInputs(
            label={'output_name': np.array([2])},
            prediction={'output_name': np.array([0, 0.5, 0.3, 0.9])},
            example_weight={'output_name': np.array([1.0])})
        iterable = metric_util.to_label_prediction_example_weight(
            example, output_name='output_name')

        for expected_label, expected_prediction in zip((0.0, 0.0, 1.0, 0.0),
                                                       (0.0, 0.5, 0.3, 0.9)):
            got_label, got_pred, got_example_weight = next(iterable)
            self.assertAllClose(got_label, np.array([expected_label]))
            self.assertAllClose(got_pred, np.array([expected_prediction]))
            self.assertAllClose(got_example_weight, np.array([1.0]))
Example #17
    def testStandardMetricInputsWithTopKToNumpy(self):
        example = metric_types.StandardMetricInputs(
            {'output_name': np.array([1])},
            {'output_name': np.array([0, 0.5, 0.3, 0.9])},
            {'output_name': np.array([1.0])})
        got_label, got_pred, got_example_weight = (
            metric_util.to_label_prediction_example_weight(
                example,
                output_name='output_name',
                sub_key=metric_types.SubKey(top_k=2)))

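        # With top_k=2, the two largest predictions come back in descending
        # order (0.9, 0.5) and the label is one-hot encoded against them;
        # class 1 corresponds to the 0.5 slot.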
        self.assertAllClose(got_label, np.array([0.0, 1.0]))
        self.assertAllClose(got_pred, np.array([0.9, 0.5]))
        self.assertAllClose(got_example_weight, np.array([1.0]))
    def testStandardMetricInputsWithClassIDToNumpy(self):
        example = metric_types.StandardMetricInputs(
            label={'output_name': np.array([2])},
            prediction={'output_name': np.array([0, 0.5, 0.3, 0.9])},
            example_weight={'output_name': np.array([1.0])})
        got_label, got_pred, got_example_weight = next(
            metric_util.to_label_prediction_example_weight(
                example,
                output_name='output_name',
                sub_key=metric_types.SubKey(class_id=2)))

        self.assertAllClose(got_label, np.array([1.0]))
        self.assertAllClose(got_pred, np.array([0.3]))
        self.assertAllClose(got_example_weight, np.array([1.0]))
    def testStandardMetricInputsWithTopKToNumpy(self):
        example = metric_types.StandardMetricInputs(
            {'output_name': np.array([1])},
            {'output_name': np.array([0, 0.5, 0.3, 0.9])},
            {'output_name': np.array([1.0])})
        iterable = metric_util.to_label_prediction_example_weight(
            example,
            output_name='output_name',
            sub_key=metric_types.SubKey(top_k=2))

        for expected_label, expected_prediction in zip((0.0, 1.0), (0.9, 0.5)):
            got_label, got_pred, got_example_weight = next(iterable)
            self.assertAllClose(got_label, np.array([expected_label]))
            self.assertAllClose(got_pred, np.array([expected_prediction]))
            self.assertAllClose(got_example_weight, np.array([1.0]))
Example #20
    def testStandardMetricInputsWithMissingPredictionRaisesError(self):
        example = metric_types.StandardMetricInputs(
            label={
                'output1': np.array([0, 1]),
                'output2': np.array([1, 1])
            },
            prediction={'output2': np.array([0.8])},
            example_weight={
                'output1': np.array([0.5]),
                'output2': np.array([1.0])
            })
        with self.assertRaisesRegex(ValueError, '"output1" key not found.*'):
            next(
                metric_util.to_label_prediction_example_weight(
                    example, output_name='output1'))
Example #21
    def testStandardMetricInputsWithTopKAndAggregationTypeToNumpy(self):
        example = metric_types.StandardMetricInputs(
            labels={'output_name': np.array([1])},
            predictions={'output_name': np.array([0, 0.5, 0.3, 0.9])},
            example_weights={'output_name': np.array([1.0])})
        iterator = metric_util.to_label_prediction_example_weight(
            example,
            output_name='output_name',
            sub_key=metric_types.SubKey(top_k=2),
            aggregation_type=metric_types.AggregationType(micro_average=True))

        for expected_label, expected_prediction in zip((1.0, 0.0), (0.5, 0.9)):
            got_label, got_pred, got_example_weight = next(iterator)
            self.assertAllClose(got_label, np.array([expected_label]))
            self.assertAllClose(got_pred, np.array([expected_prediction]))
            self.assertAllClose(got_example_weight, np.array([1.0]))
Example #22
    def testStandardMetricInputsWithNonScalarWeights(self):
        example = metric_types.StandardMetricInputs(
            label={'output_name': np.array([2])},
            prediction={'output_name': np.array([0, 0.5, 0.3, 0.9])},
            example_weight={'output_name': np.array([1.0, 0.0, 1.0, 1.0])})
        iterable = metric_util.to_label_prediction_example_weight(
            example,
            output_name='output_name',
            require_single_example_weight=False)

        for expected_label, expected_prediction, expected_weight in zip(
            (0.0, 0.0, 1.0, 0.0), (0.0, 0.5, 0.3, 0.9), (1.0, 0.0, 1.0, 1.0)):
            got_label, got_pred, got_example_weight = next(iterable)
            self.assertAllClose(got_label, np.array([expected_label]))
            self.assertAllEqual(got_pred, np.array([expected_prediction]))
            self.assertAllClose(got_example_weight,
                                np.array([expected_weight]))
Example #23
    def testStandardMetricInputsWithMissingExampleWeightKeyRaisesError(self):
        example = metric_types.StandardMetricInputs(
            label={
                'output1': np.array([0, 1]),
                'output2': np.array([1, 1])
            },
            prediction={
                'output1': np.array([0.5]),
                'output2': np.array([0.8])
            },
            example_weight={'output2': np.array([1.0])})
        with self.assertRaisesRegex(
                ValueError,
                'unable to prepare example_weight for metric computation.*'):
            next(
                metric_util.to_label_prediction_example_weight(
                    example, output_name='output1'))
    def testStandardMetricInputsWithClassWeightsWithoutFlatten(self):
        example = metric_types.StandardMetricInputs(
            {'output_name': np.array([1.0, 1.0, 1.0, 0.0])},
            {'output_name': np.array([0, 0.5, 0.3, 0.9])},
            {'output_name': np.array([1.0])})
        got_label, got_pred, got_example_weight = next(
            metric_util.to_label_prediction_example_weight(
                example,
                output_name='output_name',
                class_weights={
                    0: 0.5,
                    2: 0.25
                },
                flatten=False))

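        # class_weights scale both the label and the prediction at the
        # matching class index (index 0 by 0.5, index 2 by 0.25); classes
        # without a weight pass through unchanged.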
        self.assertAllClose(got_label, np.array([0.5, 1.0, 0.25, 0.0]))
        self.assertAllClose(got_pred, np.array([0, 0.5, 0.075, 0.9]))
        self.assertAllClose(got_example_weight, np.array([1.0]))
Example #25
    def testStandardMetricInputsWithKToNumpy2D(self):
        example = metric_types.StandardMetricInputs(
            labels={'output_name': np.array([1, 2])},
            predictions={
                'output_name': np.array([[0, 0.5, 0.3, 0.9],
                                         [0.1, 0.4, 0.2, 0.3]])
            },
            example_weights={'output_name': np.array([1.0])})
        got_label, got_pred, got_example_weight = next(
            metric_util.to_label_prediction_example_weight(
                example,
                output_name='output_name',
                sub_key=metric_types.SubKey(k=2),
                flatten=False,
                squeeze=False))

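        # SubKey(k=2) selects the kth-largest prediction in each row (0.5 and
        # 0.3 here); the label becomes 1 only when the true class sits at that
        # position, which holds for the first row but not the second.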
        self.assertAllClose(got_label, np.array([[1], [0]]))
        self.assertAllClose(got_pred, np.array([[0.5], [0.3]]))
        self.assertAllClose(got_example_weight, np.array([1.0]))
    def testStandardMetricInputsWithClassWeights(self):
        example = metric_types.StandardMetricInputs(
            {'output_name': np.array([2])},
            {'output_name': np.array([0, 0.5, 0.3, 0.9])},
            {'output_name': np.array([1.0])})
        iterable = metric_util.to_label_prediction_example_weight(
            example,
            output_name='output_name',
            class_weights={
                1: 0.5,
                2: 0.25
            })

        for expected_label, expected_prediction in zip(
            (0.0, 0.0, 0.25, 0.0), (0.0, 0.25, 0.075, 0.9)):
            got_label, got_pred, got_example_weight = next(iterable)
            self.assertAllClose(got_label, np.array([expected_label]))
            self.assertAllClose(got_pred, np.array([expected_prediction]))
            self.assertAllClose(got_example_weight, np.array([1.0]))
Example #27
def to_standard_metric_inputs(
        extracts: types.Extracts,
        include_labels: bool = True,
        include_predictions: bool = True,
        include_features: bool = False,
        include_transformed_features: bool = False,
        include_attributions: bool = False
) -> metric_types.StandardMetricInputs:
    """Verifies extract keys and converts extracts to StandardMetricInputs."""
    if include_labels and constants.LABELS_KEY not in extracts:
        raise ValueError(
            f'"{constants.LABELS_KEY}" key not found in extracts. '
            'Check that the configuration is set up properly to '
            'specify the name of the label input and that the proper '
            'extractor has been configured to extract the labels from '
            f'the inputs. Existing keys: {extracts.keys()}')
    if include_predictions and constants.PREDICTIONS_KEY not in extracts:
        raise ValueError(f'"{constants.PREDICTIONS_KEY}" key not found in '
                         'extracts. Check that the proper extractor has been '
                         'configured to perform model inference.')
    if include_features and constants.FEATURES_KEY not in extracts:
        raise ValueError(
            f'"{constants.FEATURES_KEY}" key not found in extracts. '
            'Check that the proper extractor has been configured to '
            'extract the features from the inputs. Existing keys: '
            f'{extracts.keys()}')
    if (include_transformed_features
            and constants.TRANSFORMED_FEATURES_KEY not in extracts):
        raise ValueError(
            f'"{constants.TRANSFORMED_FEATURES_KEY}" key not found in '
            'extracts. Check that the proper extractor has been '
            'configured to extract the transformed features from the '
            f'inputs. Existing keys: {extracts.keys()}')
    if (include_attributions and constants.ATTRIBUTIONS_KEY not in extracts):
        raise ValueError(
            f'"{constants.ATTRIBUTIONS_KEY}" key not found in '
            'extracts. Check that the proper extractor has been '
            'configured to extract the attributions from the inputs. '
            f'Existing keys: {extracts.keys()}')
    return metric_types.StandardMetricInputs(extracts)
Example #28
    def testStandardMetricInputsWithTopKToNumpyWithoutFlatten(self):
        example = metric_types.StandardMetricInputs(
            label={'output_name': np.array([1, 2])},
            prediction={
                'output_name': np.array([[0, 0.5, 0.3, 0.9],
                                         [0.1, 0.4, 0.2, 0.3]])
            },
            example_weight={'output_name': np.array([1.0])})
        got_label, got_pred, got_example_weight = next(
            metric_util.to_label_prediction_example_weight(
                example,
                output_name='output_name',
                sub_key=metric_types.SubKey(top_k=2),
                flatten=False))

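        # Without flattening, top_k keeps the original array shapes and masks
        # the non-top-k predictions with -inf.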
        self.assertAllClose(got_label, np.array([1, 2]))
        self.assertAllClose(
            got_pred,
            np.array([[float('-inf'), 0.5, float('-inf'), 0.9],
                      [float('-inf'), 0.4, float('-inf'), 0.3]]))
        self.assertAllClose(got_example_weight, np.array([1.0]))
Example #29
    def testStandardMetricInputsWithClassWeights(self):
        example = metric_types.StandardMetricInputs(
            labels={'output_name': np.array([2])},
            predictions={'output_name': np.array([0, 0.5, 0.3, 0.9])},
            example_weights={'output_name': np.array([1.0])})
        iterator = metric_util.to_label_prediction_example_weight(
            example,
            output_name='output_name',
            aggregation_type=metric_types.AggregationType(micro_average=True),
            class_weights={
                0: 1.0,
                1: 0.5,
                2: 0.25,
                3: 1.0
            },
            flatten=True)

        for expected_label, expected_prediction, expected_weight in zip(
            (0.0, 0.0, 1.0, 0.0), (0.0, 0.5, 0.3, 0.9), (1.0, 0.5, 0.25, 1.0)):
            got_label, got_pred, got_example_weight = next(iterator)
            self.assertAllClose(got_label, np.array([expected_label]))
            self.assertAllClose(got_pred, np.array([expected_prediction]))
            self.assertAllClose(got_example_weight,
                                np.array([expected_weight]))
    def testWithMultiClassClassificationMultiOutput(self, add_custom_metrics):
        export_dir = self._createMultiClassClassificationModel(
            sequential=False,
            output_names=('output_1', 'output_2'),
            add_custom_metrics=add_custom_metrics)
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir)
        computation = keras_util.metric_computations_using_keras_saved_model(
            '', eval_shared_model.model_loader, None)[0]

        inputs = [
            metric_types.StandardMetricInputs(
                labels={
                    'output_1': np.array([0, 0, 1, 0, 0]),
                    'output_2': np.array([0, 0, 1, 0, 0]),
                },
                predictions={
                    'output_1': np.array([0.1, 0.2, 0.1, 0.25, 0.35]),
                    'output_2': np.array([0.1, 0.2, 0.1, 0.25, 0.35]),
                },
                example_weights={
                    'output_1': np.array([0.5]),
                    'output_2': np.array([0.5]),
                },
                input=self._makeExample(output_1=[0.1, 0.2, 0.1, 0.25, 0.35],
                                        output_2=[0.1, 0.2, 0.1, 0.25,
                                                  0.35]).SerializeToString()),
            metric_types.StandardMetricInputs(
                labels={
                    'output_1': np.array([0, 1, 0, 0, 0]),
                    'output_2': np.array([0, 1, 0, 0, 0]),
                },
                predictions={
                    'output_1': np.array([0.2, 0.3, 0.05, 0.15, 0.3]),
                    'output_2': np.array([0.2, 0.3, 0.05, 0.15, 0.3]),
                },
                example_weights={
                    'output_1': np.array([0.7]),
                    'output_2': np.array([0.7]),
                },
                input=self._makeExample(output_1=[0.2, 0.3, 0.05, 0.15, 0.3],
                                        output_2=[0.2, 0.3, 0.05, 0.15,
                                                  0.3]).SerializeToString()),
            metric_types.StandardMetricInputs(
                labels={
                    'output_1': np.array([0, 0, 0, 1, 0]),
                    'output_2': np.array([0, 0, 0, 1, 0]),
                },
                predictions={
                    'output_1': np.array([0.01, 0.2, 0.09, 0.5, 0.2]),
                    'output_2': np.array([0.01, 0.2, 0.09, 0.5, 0.2]),
                },
                example_weights={
                    'output_1': np.array([0.9]),
                    'output_2': np.array([0.9]),
                },
                input=self._makeExample(output_1=[0.01, 0.2, 0.09, 0.5, 0.2],
                                        output_2=[0.01, 0.2, 0.09, 0.5,
                                                  0.2]).SerializeToString()),
            metric_types.StandardMetricInputs(
                labels={
                    'output_1': np.array([0, 1, 0, 0, 0]),
                    'output_2': np.array([0, 1, 0, 0, 0]),
                },
                predictions={
                    'output_1': np.array([0.3, 0.2, 0.05, 0.4, 0.05]),
                    'output_2': np.array([0.3, 0.2, 0.05, 0.4, 0.05]),
                },
                example_weights={
                    'output_1': np.array([0.3]),
                    'output_2': np.array([0.3]),
                },
                input=self._makeExample(output_1=[0.3, 0.2, 0.05, 0.4, 0.05],
                                        output_2=[0.3, 0.2, 0.05, 0.4,
                                                  0.05]).SerializeToString())
        ]

        # Unweighted:
        #   top_k = 2
        #     TP = 0 + 1 + 1 + 0 = 2
        #     FP = 2 + 1 + 1 + 2 = 6
        #     FN = 1 + 0 + 0 + 1 = 2
        #
        #   top_k = 3
        #     TP = 0 + 1 + 1 + 1 = 3
        #     FP = 3 + 2 + 2 + 2 = 9
        #     FN = 1 + 0 + 0 + 0 = 1
        #
        # Weighted:
        #   top_k = 2
        #     TP = 0.5*0 + 0.7*1 + 0.9*1 + 0.3*0 = 1.6
        #     FP = 0.5*2 + 0.7*1 + 0.9*1 + 0.3*2 = 3.2
        #     FN = 0.5*1 + 0.7*0 + 0.9*0 + 0.3*1 = 0.8
        #
        #   top_k = 3
        #     TP = 0.5*0 + 0.7*1 + 0.9*1 + 0.3*1 = 1.9
        #     FP = 0.5*3 + 0.7*2 + 0.9*2 + 0.3*2 = 5.3
        #     FN = 0.5*1 + 0.7*0 + 0.9*0 + 0.3*0 = 0.5
        expected_values = {
            'output_1': {
                'precision@2': 2 / (2 + 6),
                'precision@3': 3 / (3 + 9),
                'recall@2': 2 / (2 + 2),
                'recall@3': 3 / (3 + 1),
                'weighted_precision@2': 1.6 / (1.6 + 3.2),
                'weighted_precision@3': 1.9 / (1.9 + 5.3),
                'weighted_recall@2': 1.6 / (1.6 + 0.8),
                'weighted_recall@3': 1.9 / (1.9 + 0.5),
                'loss': 0.77518433
            },
            'output_2': {
                'precision@2': 2 / (2 + 6),
                'precision@3': 3 / (3 + 9),
                'recall@2': 2 / (2 + 2),
                'recall@3': 3 / (3 + 1),
                'weighted_precision@2': 1.6 / (1.6 + 3.2),
                'weighted_precision@3': 1.9 / (1.9 + 5.3),
                'weighted_recall@2': 1.6 / (1.6 + 0.8),
                'weighted_recall@3': 1.9 / (1.9 + 0.5),
                'loss': 0.77518433
            },
            '': {
                'loss': 0.77518433 + 0.77518433
            }
        }
        if add_custom_metrics:
            expected_values['']['custom_output_1'] = 4.0
            expected_values['']['custom_output_2'] = 4.0

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            result = (
                pipeline
                | 'Create' >> beam.Create(inputs)
                | 'AddSlice' >> beam.Map(lambda x: ((), x))
                | 'ComputeMetric' >> beam.CombinePerKey(computation.combiner))

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())
                    expected = {}
                    for output_name, per_output_values in (
                            expected_values.items()):
                        for name, value in per_output_values.items():
                            sub_key = None
                            if '@' in name:
                                sub_key = metric_types.SubKey(
                                    top_k=int(name.split('@')[1]))
                            key = metric_types.MetricKey(
                                name=name,
                                output_name=output_name,
                                sub_key=sub_key)
                            expected[key] = value
                    self.assertDictElementsAlmostEqual(got_metrics, expected)

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result, label='result')