def testMultiClassMetrics(self, metric_name, expected_value):
    computations = tf_metric_wrapper.tf_metric_computations(
        [self._tf_metric_by_name(metric_name)])
    histogram = computations[0]
    matrix = computations[1]
    metric = computations[2]

    example1 = {
        'labels': np.array([2]),
        'predictions': np.array([0.1, 0.2, 0.1, 0.25, 0.35]),
        'example_weights': np.array([0.5]),
    }
    example2 = {
        'labels': np.array([1]),
        'predictions': np.array([0.2, 0.3, 0.05, 0.15, 0.3]),
        'example_weights': np.array([0.7]),
    }
    example3 = {
        'labels': np.array([3]),
        'predictions': np.array([0.01, 0.2, 0.09, 0.5, 0.2]),
        'example_weights': np.array([0.9]),
    }
    example4 = {
        'labels': np.array([4]),
        'predictions': np.array([0.3, 0.2, 0.05, 0.4, 0.05]),
        'example_weights': np.array([0.3]),
    }

    with beam.Pipeline() as pipeline:
      # pylint: disable=no-value-for-parameter
      result = (
          pipeline
          | 'Create' >> beam.Create([example1, example2, example3, example4])
          | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
          | 'AddSlice' >> beam.Map(lambda x: ((), x))
          | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner)
          | 'ComputeConfusionMatrix' >> beam.Map(
              lambda x: (x[0], matrix.result(x[1])))  # pyformat: disable
          | 'ComputeMetric' >> beam.Map(
              lambda x: (x[0], metric.result(x[1]))))  # pyformat: disable

      # pylint: enable=no-value-for-parameter

      def check_result(got):
        try:
          self.assertLen(got, 1)
          got_slice_key, got_metrics = got[0]
          self.assertEqual(got_slice_key, ())
          top_k = int(metric_name.split('@')[1])
          key = metric_types.MetricKey(
              name=metric_name, sub_key=metric_types.SubKey(top_k=top_k))
          self.assertDictElementsAlmostEqual(
              got_metrics, {key: expected_value}, places=5)

        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(result, check_result, label='result')
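# A standalone sketch (not from the TFMA sources) of the weighted top_k
# confusion counts that the parameterized top_k metrics above are derived
# from. The helper below is an illustrative assumption, not TFMA's
# implementation.
import numpy as np


def _weighted_top_k_counts(labels, predictions, weights, k):
  """Returns weighted (TP, FP, FN) for single-label multi-class examples."""
  tp = fp = fn = 0.0
  for label, preds, weight in zip(labels, predictions, weights):
    hit = float(label in np.argsort(preds)[-k:])
    tp += weight * hit
    fp += weight * (k - hit)
    fn += weight * (1.0 - hit)
  return tp, fp, fn


# For the four examples above with k=2 this yields TP=1.6, FP=3.2, FN=0.8,
# from which precision@2 = 1.6 / 4.8 and recall@2 = 1.6 / 2.4 follow.
print(_weighted_top_k_counts(
    labels=[2, 1, 3, 4],
    predictions=[[0.1, 0.2, 0.1, 0.25, 0.35], [0.2, 0.3, 0.05, 0.15, 0.3],
                 [0.01, 0.2, 0.09, 0.5, 0.2], [0.3, 0.2, 0.05, 0.4, 0.05]],
    weights=[0.5, 0.7, 0.9, 0.3],
    k=2))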
Example #2
  def testMetricsWithWeights(self, metric_name, expected_value):
    # TODO (b/151636380): remove when CL/299961405 is propagated through Kokoro.
    if metric_name == 'specificity_at_sensitivity':
      fix_present = hasattr(tf.keras.metrics.SpecificityAtSensitivity,
                            '_find_max_under_constraint')
      if not fix_present:
        expected_value = 0.0

    computations = tf_metric_wrapper.tf_metric_computations(
        [self._tf_metric_by_name(metric_name)])
    histogram = computations[0]
    matrix = computations[1]
    metric = computations[2]

    example1 = {
        'labels': np.array([0.0]),
        'predictions': np.array([1.0]),
        'example_weights': np.array([0.5]),
    }
    example2 = {
        'labels': np.array([1.0]),
        'predictions': np.array([0.7]),
        'example_weights': np.array([0.7]),
    }
    example3 = {
        'labels': np.array([0.0]),
        'predictions': np.array([0.5]),
        'example_weights': np.array([0.9]),
    }

    with beam.Pipeline() as pipeline:
      # pylint: disable=no-value-for-parameter
      result = (
          pipeline
          | 'Create' >> beam.Create([example1, example2, example3])
          | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
          | 'AddSlice' >> beam.Map(lambda x: ((), x))
          | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner)
          | 'ComputeConfusionMatrix' >> beam.Map(
              lambda x: (x[0], matrix.result(x[1])))  # pyformat: disable
          | 'ComputeMetric' >> beam.Map(
              lambda x: (x[0], metric.result(x[1]))))  # pyformat: disable

      # pylint: enable=no-value-for-parameter

      def check_result(got):
        try:
          self.assertLen(got, 1)
          got_slice_key, got_metrics = got[0]
          self.assertEqual(got_slice_key, ())
          key = metric_types.MetricKey(name=metric_name)
          self.assertDictElementsAlmostEqual(
              got_metrics, {key: expected_value}, places=5)

        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(result, check_result, label='result')
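    # Hand check: at a 0.5 decision threshold (assuming predictions strictly
    # above the threshold count as positive), the weighted confusion counts
    # for the three examples are TP=0.7, FP=0.5, TN=0.9, FN=0.0, which the
    # parameterized expected values can be checked against.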
    def testTFMetricWithClassID(self):
        computation = tf_metric_wrapper.tf_metric_computations(
            [tf.keras.metrics.MeanSquaredError(name='mse')],
            sub_key=metric_types.SubKey(class_id=1),
            example_weighted=False)[0]

        example1 = {
            'labels': [2],
            'predictions': [0.5, 0.0, 0.5],
            'example_weights': [0.1]  # ignored, example_weighted=False
        }
        example2 = {
            'labels': [0],
            'predictions': [0.2, 0.5, 0.3],
            'example_weights': [0.2]  # ignored, example_weighted=False
        }
        example3 = {
            'labels': [1],
            'predictions': [0.2, 0.3, 0.5],
            'example_weights': [0.3]  # ignored, example_weighted=False
        }
        example4 = {
            'labels': [1],
            'predictions': [0.0, 0.9, 0.1],
            'example_weights': [0.4]  # ignored, example_weighted=False
        }

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            result = (
                pipeline
                | 'Create' >> beam.Create(
                    [example1, example2, example3, example4])
                | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
                | 'AddSlice' >> beam.Map(lambda x: ((), x))
                | 'Combine' >> beam.CombinePerKey(computation.combiner))

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())
                    mse_key = metric_types.MetricKey(
                        name='mse',
                        sub_key=metric_types.SubKey(class_id=1),
                        example_weighted=False)
                    self.assertDictElementsAlmostEqual(got_metrics, {
                        mse_key: 0.1875,
                    })

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result, label='result')
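        # Hand check: with sub_key class_id=1 the metric only sees the class-1
        # prediction against a one-hot class-1 label for each example, so
        # mse = (0.0**2 + 0.5**2 + 0.7**2 + 0.1**2) / 4 = 0.1875 over the four
        # unweighted examples.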
    def testCustomTFMetricWithPadding(self, example_indices, expected):
        computation = tf_metric_wrapper.tf_metric_computations(
            [
                _CustomMetric(name='custom_label', update_y_pred=False),
                _CustomMetric(name='custom_pred', update_y_pred=True),
            ],
            eval_config=config_pb2.EvalConfig(model_specs=[
                config_pb2.ModelSpec(padding_options=config_pb2.PaddingOptions(
                    label_int_padding=-1,
                    prediction_float_padding=-1.0,
                ))
            ]),
            example_weighted=True)[0]

        examples = [{
            'labels': np.array([1], dtype=np.int64),
            'predictions': np.array([0.1, 0.2, 0.3, 0.0]),
            'example_weights': np.array([1.0])
        }, {
            'labels': np.array([1, 2], dtype=np.int64),
            'predictions': np.array([0.1, 0.2, 0.0]),
            'example_weights': np.array([1.0])
        }, {
            'labels': np.array([1, 2, 3], dtype=np.int64),
            'predictions': np.array([0.1, 0.2, 0.3]),
            'example_weights': np.array([2.0])
        }]
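        # Rough illustration of what the combiner should feed the metrics when
        # all three examples fall in one batch (assuming labels and predictions
        # are padded to a common length with the configured -1 / -1.0 values):
        #   labels      -> [[1, -1, -1, -1], [1, 2, -1, -1], [1, 2, 3, -1]]
        #   predictions -> [[0.1, 0.2, 0.3, 0.0],
        #                   [0.1, 0.2, 0.0, -1.0],
        #                   [0.1, 0.2, 0.3, -1.0]]
        # The parameterized `expected` values would be computed from batches
        # like this for the selected `example_indices`.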

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            result = (
                pipeline
                | 'Create' >> beam.Create(
                    [examples[i] for i in example_indices])
                | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
                | 'AddSlice' >> beam.Map(lambda x: ((), x))
                | 'Combine' >> beam.CombinePerKey(computation.combiner))

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())

                    custom_label_key = metric_types.MetricKey(
                        name='custom_label', example_weighted=True)
                    custom_pred_key = metric_types.MetricKey(
                        name='custom_pred', example_weighted=True)
                    self.assertDictElementsAlmostEqual(got_metrics, expected)

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result, label='result')
    def testMergeAccumulators(self):
        computation = tf_metric_wrapper.tf_metric_computations(
            [tf.keras.metrics.MeanSquaredError(name='mse')],
            desired_batch_size=2,
            example_weighted=True)[0]

        example1 = {
            'labels': [0.0],
            'predictions': [0.0],
            'example_weights': [1.0]
        }
        example2 = {
            'labels': [0.0],
            'predictions': [0.5],
            'example_weights': [1.0]
        }
        example3 = {
            'labels': [1.0],
            'predictions': [0.3],
            'example_weights': [1.0]
        }
        example4 = {
            'labels': [1.0],
            'predictions': [0.9],
            'example_weights': [1.0]
        }
        example5 = {
            'labels': [1.0],
            'predictions': [0.5],
            'example_weights': [0.0]
        }

        computation.combiner.setup()
        combiner_inputs = []
        for e in (example1, example2, example3, example4, example5):
            combiner_inputs.append(metric_util.to_standard_metric_inputs(e))
        acc1 = computation.combiner.create_accumulator()
        acc1 = computation.combiner.add_input(acc1, combiner_inputs[0])
        acc1 = computation.combiner.add_input(acc1, combiner_inputs[1])
        acc1 = computation.combiner.add_input(acc1, combiner_inputs[2])
        acc2 = computation.combiner.create_accumulator()
        acc2 = computation.combiner.add_input(acc2, combiner_inputs[3])
        acc2 = computation.combiner.add_input(acc2, combiner_inputs[4])
        acc = computation.combiner.merge_accumulators([acc1, acc2])

        got_metrics = computation.combiner.extract_output(acc)
        mse_key = metric_types.MetricKey(name='mse', example_weighted=True)
        self.assertDictElementsAlmostEqual(got_metrics, {mse_key: 0.1875})
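        # Hand check: example5 carries zero weight, so the merged accumulator
        # reduces to the first four examples and
        # mse = (0.0**2 + 0.5**2 + 0.7**2 + 0.1**2) / 4 = 0.1875.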
    def testCustomTFMetric(self):
        metric = tf_metric_wrapper.tf_metric_computations(
            [_CustomMetric()], example_weighted=True)[0]

        example1 = {
            'labels': [0.0],
            'predictions': [0.2],
            'example_weights': [1.0]
        }
        example2 = {
            'labels': [0.0],
            'predictions': [0.8],
            'example_weights': [1.0]
        }
        example3 = {
            'labels': [0.0],
            'predictions': [0.5],
            'example_weights': [2.0]
        }

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            result = (
                pipeline
                | 'Create' >> beam.Create([example1, example2, example3])
                | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
                | 'AddSlice' >> beam.Map(lambda x: ((), x))
                | 'Combine' >> beam.CombinePerKey(metric.combiner))

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())

                    custom_key = metric_types.MetricKey(name='custom',
                                                        example_weighted=True)
                    self.assertDictElementsAlmostEqual(got_metrics, {
                        custom_key: (0.2 + 0.8 + 2 * 0.5) / (1.0 + 1.0 + 2.0)
                    })

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result, label='result')
    def testMetricWithClassWeights(self):
        computation = tf_metric_wrapper.tf_metric_computations(
            [tf.keras.metrics.MeanSquaredError(name='mse')],
            aggregation_type=metric_types.AggregationType(micro_average=True),
            class_weights={
                0: 0.1,
                1: 0.2,
                2: 0.3,
                3: 0.4
            })[0]

        # Simulate a multi-class problem with 4 labels. The use of class weights
        # implies micro averaging which only makes sense for multi-class metrics.
        example = {
            'labels': [0, 0, 1, 0],
            'predictions': [0, 0.5, 0.3, 0.9],
            'example_weights': [1.0]
        }

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            result = (
                pipeline
                | 'Create' >> beam.Create([example])
                | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
                | 'AddSlice' >> beam.Map(lambda x: ((), x))
                | 'Combine' >> beam.CombinePerKey(computation.combiner))

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())
                    mse_key = metric_types.MetricKey(name='mse')
                    # numerator = (0.1*0**2 + 0.2*0.5**2 + 0.3*0.7**2 + 0.4*0.9**2)
                    # denominator = (.1 + .2 + 0.3 + 0.4)
                    # numerator / denominator = 0.521
                    self.assertDictElementsAlmostEqual(got_metrics,
                                                       {mse_key: 0.521})

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result, label='result')
Example #8
  def testBatching(self):
    options = config.Options()
    options.desired_batch_size.value = 2
    computation = tf_metric_wrapper.tf_metric_computations(
        [_CustomMetric(),
         tf.keras.metrics.MeanSquaredError(name='mse')],
        config.EvalConfig(options=options))[0]

    example1 = {'labels': [0.0], 'predictions': [0.0], 'example_weights': [1.0]}
    example2 = {'labels': [0.0], 'predictions': [0.5], 'example_weights': [1.0]}
    example3 = {'labels': [1.0], 'predictions': [0.3], 'example_weights': [1.0]}
    example4 = {'labels': [1.0], 'predictions': [0.9], 'example_weights': [1.0]}
    example5 = {'labels': [1.0], 'predictions': [0.5], 'example_weights': [0.0]}

    with beam.Pipeline() as pipeline:
      # pylint: disable=no-value-for-parameter
      result = (
          pipeline
          | 'Create' >> beam.Create(
              [example1, example2, example3, example4, example5])
          | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
          | 'AddSlice' >> beam.Map(lambda x: ((), x))
          | 'Combine' >> beam.CombinePerKey(computation.combiner))

      # pylint: enable=no-value-for-parameter

      def check_result(got):
        try:
          self.assertEqual(1, len(got), 'got: %s' % got)
          got_slice_key, got_metrics = got[0]
          self.assertEqual(got_slice_key, ())

          custom_key = metric_types.MetricKey(name='custom')
          mse_key = metric_types.MetricKey(name='mse')
          self.assertDictElementsAlmostEqual(
              got_metrics, {
                  custom_key: (0.0 + 0.5 + 0.3 + 0.9 + 0.0) /
                              (1.0 + 1.0 + 1.0 + 1.0 + 0.0),
                  mse_key:
                      0.1875,
              })

        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(result, check_result, label='result')
    def testMultiOutputTFMetric(self):
        computation = tf_metric_wrapper.tf_metric_computations(
            {
                'output_name': [tf.keras.metrics.MeanSquaredError(name='mse')],
            }, config.EvalConfig())[0]

        extracts = {
            'labels': {
                'output_name': [0, 0, 1, 1],
            },
            'predictions': {
                'output_name': [0, 0.5, 0.3, 0.9],
            },
            'example_weights': {
                'output_name': [1.0]
            }
        }

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            result = (
                pipeline
                | 'Create' >> beam.Create([extracts])
                | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
                | 'AddSlice' >> beam.Map(lambda x: ((), x))
                | 'Combine' >> beam.CombinePerKey(computation.combiner))

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())
                    mse_key = metric_types.MetricKey(name='mse',
                                                     output_name='output_name')
                    self.assertDictElementsAlmostEqual(got_metrics, {
                        mse_key: 0.1875,
                    })

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result, label='result')
    def testSparseMetric(self):
        computation = tf_metric_wrapper.tf_metric_computations([
            tf.keras.metrics.SparseCategoricalCrossentropy(
                name='sparse_categorical_crossentropy')
        ])[0]

        # Simulate a multi-class problem with 3 labels.
        example = {
            'labels': [1],
            'predictions': [0.3, 0.6, 0.1],
            'example_weights': [1.0]
        }

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            result = (
                pipeline
                | 'Create' >> beam.Create([example])
                | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
                | 'AddSlice' >> beam.Map(lambda x: ((), x))
                | 'Combine' >> beam.CombinePerKey(computation.combiner))

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())
                    key = metric_types.MetricKey(
                        name='sparse_categorical_crossentropy')
                    # -(0*log(0.3) + 1*log(0.6) + 0*log(0.1)) = 0.51083
                    self.assertDictElementsAlmostEqual(got_metrics,
                                                       {key: 0.51083})

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result, label='result')
    def testMultiClassMetricsUsingKerasConfig(self, metric_name,
                                              expected_value):
        metric = tf_metric_wrapper.tf_metric_computations(
            [self._tf_metric_by_name(metric_name)], example_weighted=True)[0]

        # top_k = 2
        #   TP = 0.5*0 + 0.7*1 + 0.9*1 + 0.3*0 = 1.6
        #   FP = 0.5*2 + 0.7*1 + 0.9*1 + 0.3*2 = 3.2
        #   FN = 0.5*1 + 0.7*0 + 0.9*0 + 0.3*1 = 0.8
        #
        # top_k = 3
        #   TP = 0.5*0 + 0.7*1 + 0.9*1 + 0.3*1 = 1.9
        #   FP = 0.5*3 + 0.7*2 + 0.9*2 + 0.3*2 = 5.3
        #   FN = 0.5*1 + 0.7*0 + 0.9*0 + 0.3*0 = 0.5
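        #
        # From these counts, a precision@k style metric is TP / (TP + FP) and
        # a recall@k style metric is TP / (TP + FN), e.g. 1.6 / 4.8 and
        # 1.6 / 2.4 respectively for top_k = 2.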
        example1 = {
            'labels': np.array([2]),
            'predictions': np.array([0.1, 0.2, 0.1, 0.25, 0.35]),
            'example_weights': np.array([0.5]),
        }
        example2 = {
            'labels': np.array([1]),
            'predictions': np.array([0.2, 0.3, 0.05, 0.15, 0.3]),
            'example_weights': np.array([0.7]),
        }
        example3 = {
            'labels': np.array([3]),
            'predictions': np.array([0.01, 0.2, 0.09, 0.5, 0.2]),
            'example_weights': np.array([0.9]),
        }
        example4 = {
            'labels': np.array([1]),
            'predictions': np.array([0.3, 0.2, 0.05, 0.4, 0.05]),
            'example_weights': np.array([0.3]),
        }

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            result = (
                pipeline
                | 'Create' >> beam.Create(
                    [example1, example2, example3, example4])
                | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
                | 'AddSlice' >> beam.Map(lambda x: ((), x))
                | 'Combine' >> beam.CombinePerKey(metric.combiner))

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())
                    top_k = int(metric_name.split('@')[1])
                    key = metric_types.MetricKey(
                        name=metric_name,
                        sub_key=metric_types.SubKey(top_k=top_k),
                        example_weighted=True)
                    self.assertDictElementsAlmostEqual(got_metrics,
                                                       {key: expected_value},
                                                       places=5)

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result, label='result')
    def testWithMixedMetrics(self):
        computations = tf_metric_wrapper.tf_metric_computations([
            tf.keras.metrics.AUC(name='auc'),
            tf.keras.losses.BinaryCrossentropy(name='binary_crossentropy'),
            tf.keras.metrics.MeanSquaredError(name='mse')
        ])

        confusion_histogram = computations[0]
        confusion_matrix = computations[1].result
        confusion_metrics = computations[2].result
        non_confusion_metrics = computations[3]

        example1 = {
            'labels': np.array([0.0]),
            'predictions': np.array([0.0]),
            'example_weights': np.array([1.0]),
        }
        example2 = {
            'labels': np.array([0.0]),
            'predictions': np.array([0.5]),
            'example_weights': np.array([1.0]),
        }
        example3 = {
            'labels': np.array([1.0]),
            'predictions': np.array([0.3]),
            'example_weights': np.array([1.0]),
        }
        example4 = {
            'labels': np.array([1.0]),
            'predictions': np.array([0.9]),
            'example_weights': np.array([1.0]),
        }

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            sliced_examples = (
                pipeline
                | 'Create' >> beam.Create(
                    [example1, example2, example3, example4])
                | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
                | 'AddSlice' >> beam.Map(lambda x: ((), x)))

            confusion_result = (
                sliced_examples
                | 'ComputeHistogram' >> beam.CombinePerKey(
                    confusion_histogram.combiner)
                | 'ComputeConfusionMatrix' >>
                beam.Map(lambda x:
                         (x[0], confusion_matrix(x[1])))  # pyformat: disable
                | 'ComputeMetric' >> beam.Map(lambda x:
                                              (x[0], confusion_metrics(x[1])))
            )  # pyformat: disable

            non_confusion_result = (sliced_examples
                                    | 'Combine' >> beam.CombinePerKey(
                                        non_confusion_metrics.combiner))

            # pylint: enable=no-value-for-parameter

            def check_confusion_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())
                    auc_key = metric_types.MetricKey(name='auc')
                    self.assertDictElementsAlmostEqual(got_metrics,
                                                       {auc_key: 0.75},
                                                       places=5)

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            def check_non_confusion_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())
                    mse_key = metric_types.MetricKey(name='mse')
                    binary_crossentropy_key = metric_types.MetricKey(
                        name='binary_crossentropy')
                    self.assertDictElementsAlmostEqual(
                        got_metrics, {
                            mse_key: 0.1875,
                            binary_crossentropy_key: 0.50061995
                        },
                        places=5)

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(confusion_result,
                             check_confusion_result,
                             label='confusion')
            util.assert_that(non_confusion_result,
                             check_non_confusion_result,
                             label='non_confusion')
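# A standalone sanity check for the AUC value asserted above (not from the
# TFMA sources): exact ROC AUC equals the fraction of (negative, positive)
# prediction pairs that are ordered correctly, counting ties as half.
def _pairwise_auc(neg_scores, pos_scores):
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0
               for n in neg_scores for p in pos_scores)
    return wins / (len(neg_scores) * len(pos_scores))


# Negatives scored [0.0, 0.5] and positives scored [0.3, 0.9]: 3 of the 4
# pairs are ordered correctly, so AUC = 0.75 as checked in
# check_confusion_result.
assert _pairwise_auc([0.0, 0.5], [0.3, 0.9]) == 0.75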
Example #13
def to_computations(
    metrics_specs: List[config.MetricsSpec],
    eval_config: Optional[config.EvalConfig] = None,
    schema: Optional[schema_pb2.Schema] = None
) -> metric_types.MetricComputations:
    """Returns computations associated with given metrics specs."""
    computations = []

    #
    # Split into TF metrics and TFMA metrics
    #

    # Dict[Text, Type[tf.keras.metrics.Metric]]
    tf_metric_classes = {}  # class_name -> class
    # Dict[Text, Type[tf.keras.losses.Loss]]
    tf_loss_classes = {}  # class_name -> class
    # List[metric_types.MetricsSpec]
    tf_metrics_specs = []
    # Dict[Text, Type[metric_types.Metric]]
    # class_name -> class
    tfma_metric_classes = metric_types.registered_metrics()
    # List[metric_types.MetricsSpec]
    tfma_metrics_specs = []
    #
    # Note: Lists are used instead of Dicts for the following items because
    # protos are not hashable.
    #
    # List[List[_TFOrTFMAMetric]] (offsets align with metrics_specs).
    per_spec_metric_instances = []
    # List[List[_TFMetricOrLoss]] (offsets align with tf_metrics_specs).
    per_tf_spec_metric_instances = []
    # List[List[metric_types.Metric]] (offsets align with tfma_metrics_specs).
    per_tfma_spec_metric_instances = []
    for spec in metrics_specs:
        tf_spec = config.MetricsSpec()
        tf_spec.CopyFrom(spec)
        del tf_spec.metrics[:]
        tfma_spec = config.MetricsSpec()
        tfma_spec.CopyFrom(spec)
        del tfma_spec.metrics[:]
        for metric in spec.metrics:
            if metric.class_name in tfma_metric_classes:
                tfma_spec.metrics.append(metric)
            elif not metric.module:
                tf_spec.metrics.append(metric)
            else:
                cls = getattr(importlib.import_module(metric.module),
                              metric.class_name)
                if issubclass(cls, tf.keras.metrics.Metric):
                    tf_metric_classes[metric.class_name] = cls
                    tf_spec.metrics.append(metric)
                elif issubclass(cls, tf.keras.losses.Loss):
                    tf_loss_classes[metric.class_name] = cls
                    tf_spec.metrics.append(metric)
                else:
                    tfma_metric_classes[metric.class_name] = cls
                    tfma_spec.metrics.append(metric)

        metric_instances = []
        if tf_spec.metrics:
            tf_metrics_specs.append(tf_spec)
            tf_metric_instances = []
            for m in tf_spec.metrics:
                # To distinguish losses from metrics, losses are required to set the
                # module name.
                if m.module == _TF_LOSSES_MODULE:
                    tf_metric_instances.append(
                        _deserialize_tf_loss(m, tf_loss_classes))
                else:
                    tf_metric_instances.append(
                        _deserialize_tf_metric(m, tf_metric_classes))
            per_tf_spec_metric_instances.append(tf_metric_instances)
            metric_instances.extend(tf_metric_instances)
        if tfma_spec.metrics:
            tfma_metrics_specs.append(tfma_spec)
            tfma_metric_instances = [
                _deserialize_tfma_metric(m, tfma_metric_classes)
                for m in tfma_spec.metrics
            ]
            per_tfma_spec_metric_instances.append(tfma_metric_instances)
            metric_instances.extend(tfma_metric_instances)
        per_spec_metric_instances.append(metric_instances)

    #
    # Group TF metrics by the subkeys, models and outputs. This is done in reverse
    # because model and subkey processing is done outside of TF and so each unique
    # sub key combination needs to be run through a separate model instance. Note
    # that output_names are handled by the tf_metric_computation since all the
    # outputs are batch calculated in a single model evaluation call.
    #

    # Dict[metric_types.SubKey, Dict[Text, List[int]]]
    # SubKey -> model_name -> [index(MetricsSpec)]
    tf_spec_indices_by_subkey = {}
    for i, spec in enumerate(tf_metrics_specs):
        sub_keys = _create_sub_keys(spec)
        if not sub_keys:
            sub_keys = [None]
        for sub_key in sub_keys:
            if sub_key not in tf_spec_indices_by_subkey:
                tf_spec_indices_by_subkey[sub_key] = {}
            # Dict[Text, List[int]] -- model_name -> [index(MetricsSpec)]
            tf_spec_indices_by_model = tf_spec_indices_by_subkey[sub_key]
            model_names = spec.model_names
            if not model_names:
                # '' is name used when only one model is used.
                model_names = ['']
            for model_name in model_names:
                if model_name not in tf_spec_indices_by_model:
                    tf_spec_indices_by_model[model_name] = []
                tf_spec_indices_by_model[model_name].append(i)
    for sub_key, spec_indices_by_model in tf_spec_indices_by_subkey.items():
        for model_name, indices in spec_indices_by_model.items():
            # Class weights are a dict that is not hashable, so we store index to spec
            # containing class weights.
            metrics_by_class_weights_by_output = collections.defaultdict(dict)
            for i in indices:
                class_weights_i = None
                if tf_metrics_specs[i].HasField('aggregate'):
                    class_weights_i = i
                metrics_by_output = metrics_by_class_weights_by_output[
                    class_weights_i]
                output_names = ['']  # '' is name used when only one output
                if tf_metrics_specs[i].output_names:
                    output_names = tf_metrics_specs[i].output_names
                for output_name in output_names:
                    if output_name not in metrics_by_output:
                        metrics_by_output[output_name] = []
                    metrics_by_output[output_name].extend(
                        per_tf_spec_metric_instances[i])
            for i, metrics_by_output in (
                metrics_by_class_weights_by_output.items()):
                class_weights = None
                if i is not None:
                    class_weights = dict(
                        tf_metrics_specs[i].aggregate.class_weights)
                computations.extend(
                    tf_metric_wrapper.tf_metric_computations(
                        metrics_by_output,
                        eval_config=eval_config,
                        model_name=model_name,
                        sub_key=sub_key,
                        class_weights=class_weights))

    #
    # Group TFMA metric specs by the metric classes
    #

    # Dict[bytes, List[config.MetricSpec]]
    tfma_specs_by_metric_config = {}  # hash(MetricConfig) -> [MetricSpec]
    # Dict[bytes, metric_types.Metric]
    hashed_metrics = {}  # hash(MetricConfig) -> Metric
    for i, spec in enumerate(tfma_metrics_specs):
        for metric_config, metric in zip(spec.metrics,
                                         per_tfma_spec_metric_instances[i]):
            # Note that hashing by SerializeToString() is only safe if used within the
            # same process.
            config_hash = metric_config.SerializeToString()
            if config_hash not in tfma_specs_by_metric_config:
                hashed_metrics[config_hash] = metric
                tfma_specs_by_metric_config[config_hash] = []
            tfma_specs_by_metric_config[config_hash].append(spec)
    for config_hash, specs in tfma_specs_by_metric_config.items():
        metric = hashed_metrics[config_hash]
        for spec in specs:
            sub_keys = _create_sub_keys(spec)
            class_weights = None
            if spec.HasField('aggregate'):
                class_weights = dict(spec.aggregate.class_weights)
            computations.extend(
                metric.computations(
                    eval_config=eval_config,
                    schema=schema,
                    model_names=spec.model_names if spec.model_names else [''],
                    output_names=spec.output_names
                    if spec.output_names else [''],
                    sub_keys=sub_keys,
                    class_weights=class_weights,
                    query_key=spec.query_key))

    #
    # Create macro averaging metrics
    #

    for i, spec in enumerate(metrics_specs):
        if spec.aggregate.macro_average or spec.aggregate.weighted_macro_average:
            sub_keys = _create_sub_keys(spec)
            if sub_keys is None:
                raise ValueError(
                    'binarize settings are required when aggregate.macro_average or '
                    'aggregate.weighted_macro_average is used: spec={}'.format(
                        spec))
            for model_name in spec.model_names or ['']:
                for output_name in spec.output_names or ['']:
                    for metric in per_spec_metric_instances[i]:
                        if spec.aggregate.macro_average:
                            computations.extend(
                                aggregation.macro_average(
                                    metric.get_config()['name'],
                                    eval_config=eval_config,
                                    model_name=model_name,
                                    output_name=output_name,
                                    sub_keys=sub_keys,
                                    class_weights=dict(
                                        spec.aggregate.class_weights)))
                        elif spec.aggregate.weighted_macro_average:
                            computations.extend(
                                aggregation.weighted_macro_average(
                                    metric.get_config()['name'],
                                    eval_config=eval_config,
                                    model_name=model_name,
                                    output_name=output_name,
                                    sub_keys=sub_keys,
                                    class_weights=dict(
                                        spec.aggregate.class_weights)))

    return computations
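# A minimal usage sketch for to_computations (not from the TFMA sources). The
# metric class names are illustrative assumptions: 'AUC' should resolve to
# tf.keras.metrics.AUC, while 'ExampleCount' is assumed to be present in
# metric_types.registered_metrics().
def _example_to_computations_usage():
    spec = config.MetricsSpec(metrics=[
        config.MetricConfig(class_name='AUC'),
        config.MetricConfig(class_name='ExampleCount'),
    ])
    return to_computations([spec], eval_config=config.EvalConfig())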
Example #14
def to_computations(
    metrics_specs: List[config.MetricsSpec],
    eval_config: Optional[config.EvalConfig] = None,
    model_loaders: Optional[Dict[Text, types.ModelLoader]] = None
) -> metric_types.MetricComputations:
    """Returns computations associated with given metrics specs."""
    computations = []

    #
    # Split into TF metrics and TFMA metrics
    #

    # Dict[Text, Type[tf.keras.metrics.Metric]]
    tf_metric_classes = {}  # class_name -> class
    # List[metric_types.MetricsSpec]
    tf_metrics_specs = []
    # Dict[Text, Type[metric_types.Metric]]
    # class_name -> class
    tfma_metric_classes = metric_types.registered_metrics()
    # List[metric_types.MetricsSpec]
    tfma_metrics_specs = []
    for spec in metrics_specs:
        tf_spec = config.MetricsSpec()
        tf_spec.CopyFrom(spec)
        del tf_spec.metrics[:]
        tfma_spec = config.MetricsSpec()
        tfma_spec.CopyFrom(spec)
        del tfma_spec.metrics[:]
        for metric in spec.metrics:
            if metric.class_name in tfma_metric_classes:
                tfma_spec.metrics.append(metric)
            elif not metric.module:
                tf_spec.metrics.append(metric)
            else:
                cls = getattr(importlib.import_module(metric.module),
                              metric.class_name)
                if issubclass(cls, tf.keras.metrics.Metric):
                    tf_metric_classes[metric.class_name] = cls
                    tf_spec.metrics.append(metric)
                else:
                    tfma_metric_classes[metric.class_name] = cls
                    tfma_spec.metrics.append(metric)
        if tf_spec.metrics:
            tf_metrics_specs.append(tf_spec)
        if tfma_spec.metrics:
            tfma_metrics_specs.append(tfma_spec)

    #
    # Group TF metrics by the subkeys, models and outputs. This is done in reverse
    # because model and subkey processing is done outside of TF and so each unique
    # sub key combination needs to be run through a separate model instance. Note
    # that output_names are handled by the tf_metric_computation since all the
    # outputs are batch calculated in a single model evaluation call.
    #

    # Dict[metric_types.SubKey, Dict[Text, List[config.MetricsSpec]]]
    tf_specs_by_subkey = {}  # SubKey -> model_name -> [MetricSpec]
    for spec in tf_metrics_specs:
        sub_keys = _create_sub_keys(spec)
        if not sub_keys:
            sub_keys = [None]
        for sub_key in sub_keys:
            if sub_key not in tf_specs_by_subkey:
                tf_specs_by_subkey[sub_key] = {}
            # Dict[Text, List[config.MetricSpec]]
            tf_specs_by_model = tf_specs_by_subkey[sub_key]  # name -> [specs]
            model_names = spec.model_names
            if not model_names:
                # '' is name used when only one model is used.
                model_names = ['']
            for model_name in model_names:
                if model_name not in tf_specs_by_model:
                    tf_specs_by_model[model_name] = []
                tf_specs_by_model[model_name].append(spec)
    for sub_key, specs_by_model in tf_specs_by_subkey.items():
        for model_name, specs in specs_by_model.items():
            metrics_by_output = {}
            for spec in specs:
                metrics = [
                    _deserialize_tf_metric(m, tf_metric_classes)
                    for m in spec.metrics
                ]
                if spec.output_names:
                    for output_name in spec.output_names:
                        if output_name not in metrics_by_output:
                            metrics_by_output[output_name] = []
                        metrics_by_output[output_name].extend(metrics)
                else:
                    if '' not in metrics_by_output:
                        # '' is name used when only one output.
                        metrics_by_output[''] = []
                    metrics_by_output[''].extend(metrics)
            model_loader = None
            if model_loaders and model_name in model_loaders:
                model_loader = model_loaders[model_name]
            computations.extend(
                tf_metric_wrapper.tf_metric_computations(
                    metrics_by_output,
                    eval_config=eval_config,
                    model_name=model_name,
                    sub_key=sub_key,
                    model_loader=model_loader))

    #
    # Group TFMA metric specs by the metric classes
    #

    # Dict[bytes, List[config.MetricSpec]]
    tfma_specs_by_metric_config = {}  # hash(MetricConfig) -> [MetricSpec]
    # Dict[bytes, config.MetricConfig]
    hashed_metric_configs = {}  # hash(MetricConfig) -> MetricConfig
    for spec in tfma_metrics_specs:
        for metric_config in spec.metrics:
            # Note that hashing by SerializeToString() is only safe if used within the
            # same process.
            config_hash = metric_config.SerializeToString()
            if config_hash not in tfma_specs_by_metric_config:
                hashed_metric_configs[config_hash] = metric_config
                tfma_specs_by_metric_config[config_hash] = []
            tfma_specs_by_metric_config[config_hash].append(spec)
    for config_hash, specs in tfma_specs_by_metric_config.items():
        metric = _deserialize_tfma_metric(hashed_metric_configs[config_hash],
                                          tfma_metric_classes)
        for spec in specs:
            sub_keys = _create_sub_keys(spec)
            computations.extend(
                metric.computations(
                    eval_config=eval_config,
                    model_names=spec.model_names if spec.model_names else [''],
                    output_names=spec.output_names
                    if spec.output_names else [''],
                    sub_keys=sub_keys,
                    query_key=spec.query_key))
    return computations
Example #15
    def testCustomTFMetricWithPadding(self):
        computation = tf_metric_wrapper.tf_metric_computations([
            _CustomMetric(name='custom_label', update_y_pred=False),
            _CustomMetric(name='custom_pred', update_y_pred=True)
        ])[0]

        # label_sum = (1 - 1 - 1 - 1) * 1.0 +
        #             (1 + 2 - 1.0 - 1) * 1.0 +
        #             (1 + 2 + 3 - 1) * 2.0
        #           = 9.0
        #
        # pred_sum = (0.1 + 0.2 + 0.3 + 0.0) * 1.0 +
        #            (0.1 + 0.2 + 0.0 - 1.0) * 1.0 +
        #            (0.1 + 0.2 + 0.3 - 1.0) * 2.0
        #           = -0.9
        #
        # weights_total = (1.0 * 4 + 1.0 * 4 + 2.0 * 4) = 16.0
        example1 = {
            'labels': np.array([1], dtype=np.int64),
            'predictions': np.array([0.1, 0.2, 0.3, 0.0]),
            'example_weights': np.array([1.0])
        }
        example2 = {
            'labels': np.array([1, 2], dtype=np.int64),
            'predictions': np.array([0.1, 0.2, 0.0]),
            'example_weights': np.array([1.0])
        }
        example3 = {
            'labels': np.array([1, 2, 3], dtype=np.int64),
            'predictions': np.array([0.1, 0.2, 0.3]),
            'example_weights': np.array([2.0])
        }

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            result = (
                pipeline
                | 'Create' >> beam.Create([example1, example2, example3])
                | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
                | 'AddSlice' >> beam.Map(lambda x: ((), x))
                | 'Combine' >> beam.CombinePerKey(computation.combiner))

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())

                    custom_label_key = metric_types.MetricKey(
                        name='custom_label')
                    custom_pred_key = metric_types.MetricKey(
                        name='custom_pred')
                    self.assertDictElementsAlmostEqual(
                        got_metrics, {
                            custom_label_key: 9.0 / 16.0,
                            custom_pred_key: -0.9 / 16.0
                        })

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result, label='result')
    def testWithDefaultMetricsProvidedByModel(self):
        export_dir = os.path.join(self._getTempDir(), 'export_dir')
        dummy_layer = tf.keras.layers.Input(shape=(1, ))
        model = tf.keras.models.Model([dummy_layer], [dummy_layer])
        model.compile(loss=tf.keras.losses.BinaryCrossentropy(),
                      metrics=[tf.keras.metrics.MeanSquaredError(name='mse')])
        model.save(export_dir, save_format='tf')
        model_loader = types.ModelLoader(
            tags=[tf.saved_model.SERVING],
            construct_fn=model_util.model_construct_fn(
                eval_saved_model_path=export_dir,
                tags=[tf.saved_model.SERVING]))

        computations = tf_metric_wrapper.tf_metric_computations(
            [tf.keras.metrics.AUC(name='auc')],
            config.EvalConfig(),
            model_loader=model_loader)

        confusion_histogram = computations[0]
        confusion_matrix = computations[1].result
        confusion_metrics = computations[2].result
        non_confusion_metrics = computations[3]

        example1 = {
            'labels': np.array([0.0]),
            'predictions': np.array([0.0]),
            'example_weights': np.array([1.0]),
        }
        example2 = {
            'labels': np.array([0.0]),
            'predictions': np.array([0.5]),
            'example_weights': np.array([1.0]),
        }
        example3 = {
            'labels': np.array([1.0]),
            'predictions': np.array([0.3]),
            'example_weights': np.array([1.0]),
        }
        example4 = {
            'labels': np.array([1.0]),
            'predictions': np.array([0.9]),
            'example_weights': np.array([1.0]),
        }

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            sliced_examples = (
                pipeline
                | 'Create' >> beam.Create(
                    [example1, example2, example3, example4])
                | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
                | 'AddSlice' >> beam.Map(lambda x: ((), x)))

            confusion_result = (
                sliced_examples
                | 'ComputeHistogram' >> beam.CombinePerKey(
                    confusion_histogram.combiner)
                | 'ComputeConfusionMatrix' >>
                beam.Map(lambda x:
                         (x[0], confusion_matrix(x[1])))  # pyformat: disable
                | 'ComputeMetric' >> beam.Map(lambda x:
                                              (x[0], confusion_metrics(x[1])))
            )  # pyformat: disable

            non_confusion_result = (sliced_examples
                                    | 'Combine' >> beam.CombinePerKey(
                                        non_confusion_metrics.combiner))

            # pylint: enable=no-value-for-parameter

            def check_confusion_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())
                    auc_key = metric_types.MetricKey(name='auc')
                    self.assertDictElementsAlmostEqual(got_metrics,
                                                       {auc_key: 0.75},
                                                       places=5)

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            def check_non_confusion_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())
                    mse_key = metric_types.MetricKey(name='mse')
                    binary_crossentropy_key = metric_types.MetricKey(
                        name='binary_crossentropy')
                    self.assertDictElementsAlmostEqual(
                        got_metrics, {
                            mse_key: 0.1875,
                            binary_crossentropy_key: 0.0
                        },
                        places=5)

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(confusion_result,
                             check_confusion_result,
                             label='confusion')
            util.assert_that(non_confusion_result,
                             check_non_confusion_result,
                             label='non_confusion')
Example #17
def _process_tf_metrics_specs(
    tf_metrics_specs: List[config.MetricsSpec],
    per_tf_spec_metric_instances: List[List[_TFMetricOrLoss]],
    eval_config: config.EvalConfig) -> metric_types.MetricComputations:
  """Processes list of TF MetricsSpecs to create computations."""

  # Wrap args into structure that is hashable so we can track unique arg sets.
  class UniqueArgs(
      NamedTuple('UniqueArgs',
                 [('model_name', Text),
                  ('sub_key', Optional[metric_types.SubKey]),
                  ('aggregation_type', Optional[metric_types.AggregationType]),
                  ('class_weights', Tuple[Tuple[int, float], ...])])):
    pass

  def _create_private_tf_metrics(
      metrics: List[_TFMetricOrLoss]) -> List[_TFMetricOrLoss]:
    """Creates private versions of TF metrics."""
    result = []
    for m in metrics:
      if isinstance(m, tf.keras.metrics.Metric):
        result.append(_private_tf_metric(m))
      else:
        result.append(_private_tf_loss(m))
    return result

  #
  # Group TF metrics by the subkeys, models and outputs. This is done in reverse
  # because model and subkey processing is done outside of TF and so each unique
  # sub key combination needs to be run through a separate model instance. Note
  # that output_names are handled by the tf_metric_computation since all the
  # outputs are batch calculated in a single model evaluation call.
  #

  # UniqueArgs -> output_name -> [_TFMetricOrLoss]
  metrics_by_unique_args = collections.defaultdict(dict)
  for i, spec in enumerate(tf_metrics_specs):
    metrics = per_tf_spec_metric_instances[i]
    sub_keys_by_aggregation_type = _create_sub_keys(spec)
    # Keep track of metrics that can be shared between macro averaging and
    # binarization. For example, if macro averaging is being performed over 10
    # classes and 5 of the classes are also being binarized, then those 5
    # classes can be re-used by the macro averaging calculation. The remaining
    # 5 classes need to be added as private metrics since those classes were
    # not requested but are still needed for the macro averaging calculation.
    if None in sub_keys_by_aggregation_type:
      shared_sub_keys = set(sub_keys_by_aggregation_type[None])
    else:
      shared_sub_keys = set()
    for aggregation_type, sub_keys in sub_keys_by_aggregation_type.items():
      if aggregation_type:
        class_weights = tuple(sorted((_class_weights(spec) or {}).items()))
      else:
        class_weights = ()
      is_macro = (
          aggregation_type and (aggregation_type.macro_average or
                                aggregation_type.weighted_macro_average))
      for parent_sub_key in sub_keys:
        if is_macro:
          child_sub_keys = _macro_average_sub_keys(parent_sub_key,
                                                   _class_weights(spec))
        else:
          child_sub_keys = [parent_sub_key]
        for output_name in spec.output_names or ['']:
          for sub_key in child_sub_keys:
            if is_macro and sub_key not in shared_sub_keys:
              # Create private metrics for all non-shared metrics.
              instances = _create_private_tf_metrics(metrics)
            else:
              instances = metrics
            for model_name in spec.model_names or ['']:
              unique_args = UniqueArgs(
                  model_name, sub_key,
                  aggregation_type if not is_macro else None,
                  class_weights if not is_macro else ())
              if output_name not in metrics_by_unique_args[unique_args]:
                metrics_by_unique_args[unique_args][output_name] = []
              metrics_by_unique_args[unique_args][output_name].extend(instances)

  # Convert Unique args and outputs to calls to compute TF metrics
  result = []
  for args, metrics_by_output in metrics_by_unique_args.items():
    class_weights = dict(args.class_weights) if args.class_weights else None
    result.extend(
        tf_metric_wrapper.tf_metric_computations(
            metrics_by_output,
            eval_config=eval_config,
            model_name=args.model_name,
            sub_key=args.sub_key,
            aggregation_type=args.aggregation_type,
            class_weights=class_weights))
  return result
    def testMetricsWithFractionalLabels(self, metric_name, expected_value):
        computations = tf_metric_wrapper.tf_metric_computations(
            [self._tf_metric_by_name(metric_name)])
        histogram = computations[0]
        matrix = computations[1]
        metric = computations[2]

        # The following examples will be expanded to:
        #
        # prediction | label | weight
        #     0.0    |   -   |  1.0
        #     0.7    |   -   |  0.4
        #     0.7    |   +   |  0.6
        #     1.0    |   -   |  0.2
        #     1.0    |   +   |  0.8
        example1 = {
            'labels': np.array([0.0]),
            'predictions': np.array([0.0]),
            'example_weights': np.array([1.0]),
        }
        example2 = {
            'labels': np.array([0.6]),
            'predictions': np.array([0.7]),
            'example_weights': np.array([1.0]),
        }
        example3 = {
            'labels': np.array([0.8]),
            'predictions': np.array([1.0]),
            'example_weights': np.array([1.0]),
        }

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            result = (
                pipeline
                | 'Create' >> beam.Create([example1, example2, example3])
                | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
                | 'AddSlice' >> beam.Map(lambda x: ((), x))
                | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner)
                | 'ComputeConfusionMatrix' >> beam.Map(
                    lambda x: (x[0], matrix.result(x[1])))  # pyformat: disable
                | 'ComputeMetric' >> beam.Map(lambda x:
                                              (x[0], metric.result(x[1])))
            )  # pyformat: disable

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())
                    key = metric_types.MetricKey(name=metric_name)
                    self.assertDictElementsAlmostEqual(got_metrics,
                                                       {key: expected_value},
                                                       places=5)

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result, label='result')
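# A standalone sketch of the fractional-label expansion shown in the table
# above: a binary label l in [0, 1] with weight w is split into a negative
# example weighted w * (1 - l) and a positive example weighted w * l. The
# helper name is illustrative, not part of TFMA's API.
def expand_fractional_label(label, weight):
    return [(0.0, weight * (1.0 - label)), (1.0, weight * label)]


# expand_fractional_label(0.6, 1.0) -> [(0.0, 0.4), (1.0, 0.6)]
# expand_fractional_label(0.0, 1.0) -> [(0.0, 1.0), (1.0, 0.0)]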