Example #1
def default_extractors(  # pylint: disable=invalid-name
        eval_shared_model: Optional[Union[types.EvalSharedModel,
                                          Dict[Text, types.EvalSharedModel]]] = None,
        eval_config: Optional[config.EvalConfig] = None,
        slice_spec: Optional[List[slicer.SingleSliceSpec]] = None,
        desired_batch_size: Optional[int] = None,
        materialize: Optional[bool] = True) -> List[extractor.Extractor]:
    """Returns the default extractors for use in ExtractAndEvaluate.

    Args:
      eval_shared_model: Shared model (single-model evaluation) or dict of shared
        models keyed by model name (multi-model evaluation). Required unless the
        predictions are provided alongside the features (i.e. model-agnostic
        evaluations).
      eval_config: Eval config.
      slice_spec: Deprecated (use EvalConfig).
      desired_batch_size: Optional batch size for batching in Predict.
      materialize: True to have extractors create materialized output.
    """
    if eval_config is not None:
        slice_spec = [
            slicer.SingleSliceSpec(spec=spec)
            for spec in eval_config.slicing_specs
        ]
    if (eval_shared_model and not isinstance(eval_shared_model, dict) and
        (not eval_shared_model.model_loader.tags
         or eval_constants.EVAL_TAG in eval_shared_model.model_loader.tags)):
        # Backwards compatibility for previous add_metrics_callbacks implementation.
        return [
            predict_extractor.PredictExtractor(eval_shared_model,
                                               desired_batch_size,
                                               materialize=materialize),
            slice_key_extractor.SliceKeyExtractor(slice_spec,
                                                  materialize=materialize)
        ]
    elif eval_shared_model:
        return [
            input_extractor.InputExtractor(eval_config=eval_config),
            predict_extractor_v2.PredictExtractor(
                eval_config=eval_config,
                eval_shared_model=eval_shared_model,
                desired_batch_size=desired_batch_size),
            slice_key_extractor.SliceKeyExtractor(slice_spec,
                                                  materialize=materialize)
        ]
    else:
        return [
            input_extractor.InputExtractor(eval_config=eval_config),
            slice_key_extractor.SliceKeyExtractor(slice_spec,
                                                  materialize=materialize)
        ]
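
A minimal sketch of how the extractor list returned above is typically consumed, assuming `tfma` is the top-level `tensorflow_model_analysis` import, that `eval_shared_model` and `serialized_examples` already exist, and that `tfma.default_evaluators` is available in this version; it reuses the snippet's `config` module.

import apache_beam as beam
import tensorflow_model_analysis as tfma

# Assumed to exist: eval_shared_model (as described in the docstring above)
# and serialized_examples (a list of serialized tf.Examples).
eval_config = config.EvalConfig(slicing_specs=[config.SlicingSpec()])
extractors = default_extractors(
    eval_shared_model=eval_shared_model, eval_config=eval_config)
evaluators = tfma.default_evaluators(
    eval_shared_model=eval_shared_model, eval_config=eval_config)

with beam.Pipeline() as pipeline:
  _ = (pipeline
       | 'Create' >> beam.Create(serialized_examples)
       | 'InputsToExtracts' >> tfma.InputsToExtracts()
       | 'ExtractAndEvaluate' >> tfma.ExtractAndEvaluate(
           extractors=extractors, evaluators=evaluators))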
def default_extractors(  # pylint: disable=invalid-name
        eval_shared_model: Optional[types.EvalSharedModel] = None,
        eval_shared_models: Optional[List[types.EvalSharedModel]] = None,
        eval_config: Optional[config.EvalConfig] = None,
        slice_spec: Optional[List[slicer.SingleSliceSpec]] = None,
        desired_batch_size: Optional[int] = None,
        materialize: Optional[bool] = True) -> List[extractor.Extractor]:
    """Returns the default extractors for use in ExtractAndEvaluate.

    Args:
      eval_shared_model: Shared model (single-model evaluation).
      eval_shared_models: Shared models (multi-model evaluation).
      eval_config: Eval config.
      slice_spec: Deprecated (use EvalConfig).
      desired_batch_size: Deprecated (use EvalConfig).
      materialize: True to have extractors create materialized output.
    """
    # TODO(b/141016373): Add support for multiple models.
    if eval_config is not None:
        slice_spec = [
            slicer.SingleSliceSpec(spec=spec)
            for spec in eval_config.slicing_specs
        ]
        if eval_config.options.HasField('desired_batch_size'):
            desired_batch_size = eval_config.options.desired_batch_size.value
    if eval_shared_model is not None:
        eval_shared_models = [eval_shared_model]
    if (not eval_shared_models[0].model_loader.tags or eval_constants.EVAL_TAG
            in eval_shared_models[0].model_loader.tags):
        # Backwards compatibility for previous EvalSavedModel implementation.
        return [
            predict_extractor.PredictExtractor(eval_shared_models[0],
                                               desired_batch_size,
                                               materialize=materialize),
            slice_key_extractor.SliceKeyExtractor(slice_spec,
                                                  materialize=materialize)
        ]
    else:
        return [
            input_extractor.InputExtractor(eval_config=eval_config),
            predict_extractor_v2.PredictExtractor(
                eval_config=eval_config,
                eval_shared_models=eval_shared_models),
            slice_key_extractor.SliceKeyExtractor(slice_spec,
                                                  materialize=materialize)
        ]
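
This variant reads the batch size from `EvalConfig.options` rather than from the deprecated argument. A hedged illustration of wiring that up; it reuses the snippet's `config` module, and the wrapper-dict syntax for `desired_batch_size` mirrors the threshold bounds used elsewhere in these examples.

# Illustrative only; assumes eval_shared_model was created elsewhere, e.g. via
# tfma.default_eval_shared_model(...).
eval_config = config.EvalConfig(
    model_specs=[config.ModelSpec()],
    slicing_specs=[config.SlicingSpec()],
    options=config.Options(desired_batch_size={'value': 128}))
extractors = default_extractors(
    eval_shared_model=eval_shared_model, eval_config=eval_config)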
Example #3
    def benchmarkMiniPipelineUnbatched(self):
        """Benchmark an unbatched "mini" TFMA - predict, slice and compute metrics.

    Runs a "mini" version of TFMA in a Beam pipeline. Records the wall time
    taken for the whole pipeline.
    """
        self._init_model()
        pipeline = beam.Pipeline(runner=fn_api_runner.FnApiRunner())
        raw_data = (pipeline
                    | "Examples" >> beam.Create(
                        self._dataset.read_raw_dataset(deserialize=False,
                                                       limit=MAX_NUM_EXAMPLES))
                    | "InputsToExtracts" >> tfma.InputsToExtracts())

        _ = (raw_data
             | "InputExtractor" >> input_extractor.InputExtractor(
                 eval_config=self._eval_config).ptransform
             | "V2PredictExtractor" >> predict_extractor_v2.PredictExtractor(
                 eval_config=self._eval_config,
                 eval_shared_model=self._eval_shared_model).ptransform
             | "SliceKeyExtractor" >>
             tfma.extractors.SliceKeyExtractor().ptransform
             | "V2ComputeMetricsAndPlots" >>
             metrics_and_plots_evaluator_v2.MetricsAndPlotsEvaluator(
                 eval_config=self._eval_config,
                 eval_shared_model=self._eval_shared_model).ptransform)

        start = time.time()
        result = pipeline.run()
        result.wait_until_finish()
        end = time.time()
        delta = end - start

        self.report_benchmark(
            iters=1,
            wall_time=delta,
            extras={
                "num_examples":
                self._dataset.num_examples(limit=MAX_NUM_EXAMPLES)
            })
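
The benchmark above measures wall time with a plain `time.time()` bracket around `pipeline.run()`. A small standalone helper expressing the same idea (the name `run_and_time` is illustrative, not part of the benchmark suite):

import time

import apache_beam as beam


def run_and_time(pipeline: beam.Pipeline) -> float:
  """Runs a Beam pipeline to completion and returns the wall time in seconds."""
  start = time.time()
  result = pipeline.run()
  result.wait_until_finish()
  return time.time() - start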
  def testWriteValidationResults(self):
    model_dir, baseline_dir = self._getExportDir(), self._getBaselineDir()
    eval_shared_model = self._build_keras_model(model_dir, mul=0)
    baseline_eval_shared_model = self._build_keras_model(baseline_dir, mul=1)
    validations_file = os.path.join(self._getTempDir(),
                                    constants.VALIDATIONS_KEY)
    examples = [
        self._makeExample(
            input=0.0,
            label=1.0,
            example_weight=1.0,
            extra_feature='non_model_feature'),
        self._makeExample(
            input=1.0,
            label=0.0,
            example_weight=0.5,
            extra_feature='non_model_feature'),
    ]

    eval_config = config.EvalConfig(
        model_specs=[
            config.ModelSpec(
                name='candidate',
                label_key='label',
                example_weight_key='example_weight'),
            config.ModelSpec(
                name='baseline',
                label_key='label',
                example_weight_key='example_weight',
                is_baseline=True)
        ],
        slicing_specs=[config.SlicingSpec()],
        metrics_specs=[
            config.MetricsSpec(
                metrics=[
                    config.MetricConfig(
                        class_name='WeightedExampleCount',
                        # 1.5 < 1, NOT OK.
                        threshold=config.MetricThreshold(
                            value_threshold=config.GenericValueThreshold(
                                upper_bound={'value': 1}))),
                    config.MetricConfig(
                        class_name='ExampleCount',
                        # 2 > 10, NOT OK.
                        threshold=config.MetricThreshold(
                            value_threshold=config.GenericValueThreshold(
                                lower_bound={'value': 10}))),
                    config.MetricConfig(
                        class_name='MeanLabel',
                        # 0 > 0 and 0 > 0%: NOT OK.
                        threshold=config.MetricThreshold(
                            change_threshold=config.GenericChangeThreshold(
                                direction=config.MetricDirection
                                .HIGHER_IS_BETTER,
                                relative={'value': 0},
                                absolute={'value': 0}))),
                    config.MetricConfig(
                        # MeanPrediction = (0+0)/(1+0.5) = 0
                        class_name='MeanPrediction',
                        # -.01 < 0 < .01, OK.
                        # Diff% = -.333/.333 = -100% < -99%, OK.
                        # Diff = 0 - .333 = -.333 < 0, OK.
                        threshold=config.MetricThreshold(
                            value_threshold=config.GenericValueThreshold(
                                upper_bound={'value': .01},
                                lower_bound={'value': -.01}),
                            change_threshold=config.GenericChangeThreshold(
                                direction=config.MetricDirection
                                .LOWER_IS_BETTER,
                                relative={'value': -.99},
                                absolute={'value': 0})))
                ],
                model_names=['candidate', 'baseline']),
        ],
        options=config.Options(
            disabled_outputs={'values': ['eval_config.json']}),
    )
    slice_spec = [
        slicer.SingleSliceSpec(spec=s) for s in eval_config.slicing_specs
    ]
    eval_shared_models = {
        'candidate': eval_shared_model,
        'baseline': baseline_eval_shared_model
    }
    extractors = [
        input_extractor.InputExtractor(eval_config),
        predict_extractor_v2.PredictExtractor(
            eval_shared_model=eval_shared_models, eval_config=eval_config),
        slice_key_extractor.SliceKeyExtractor(slice_spec=slice_spec)
    ]
    evaluators = [
        metrics_and_plots_evaluator_v2.MetricsAndPlotsEvaluator(
            eval_config=eval_config, eval_shared_model=eval_shared_models)
    ]
    output_paths = {
        constants.VALIDATIONS_KEY: validations_file,
    }
    writers = [
        metrics_plots_and_validations_writer.MetricsPlotsAndValidationsWriter(
            output_paths, add_metrics_callbacks=[])
    ]

    with beam.Pipeline() as pipeline:

      # pylint: disable=no-value-for-parameter
      _ = (
          pipeline
          | 'Create' >> beam.Create([e.SerializeToString() for e in examples])
          | 'ExtractEvaluateAndWriteResults' >>
          model_eval_lib.ExtractEvaluateAndWriteResults(
              eval_config=eval_config,
              eval_shared_model=eval_shared_model,
              extractors=extractors,
              evaluators=evaluators,
              writers=writers))
      # pylint: enable=no-value-for-parameter

    validation_result = model_eval_lib.load_validation_result(
        os.path.dirname(validations_file))

    expected_validations = [
        text_format.Parse(
            """
            metric_key {
              name: "weighted_example_count"
              model_name: "candidate"
            }
            metric_threshold {
              value_threshold {
                upper_bound {
                  value: 1.0
                }
              }
            }
            metric_value {
              double_value {
                value: 1.5
              }
            }
            """, validation_result_pb2.ValidationFailure()),
        text_format.Parse(
            """
            metric_key {
              name: "example_count"
            }
            metric_threshold {
              value_threshold {
                lower_bound {
                  value: 10.0
                }
              }
            }
            metric_value {
              double_value {
                value: 2.0
              }
            }
            """, validation_result_pb2.ValidationFailure()),
        text_format.Parse(
            """
            metric_key {
              name: "mean_label"
              model_name: "candidate"
              is_diff: true
            }
            metric_threshold {
              change_threshold {
                absolute {
                  value: 0.0
                }
                relative {
                  value: 0.0
                }
                direction: HIGHER_IS_BETTER
              }
            }
            metric_value {
              double_value {
                value: 0.0
              }
            }
            """, validation_result_pb2.ValidationFailure()),
    ]
    self.assertFalse(validation_result.validation_ok)
    self.assertLen(validation_result.metric_validations_per_slice, 1)
    self.assertCountEqual(
        expected_validations,
        validation_result.metric_validations_per_slice[0].failures)
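
Outside of a test, the written validation output could be inspected roughly as sketched below; it reuses `model_eval_lib.load_validation_result` and the `validations_file` path from the test above, and the field names come from the `ValidationFailure` protos parsed above.

import os

# Illustrative only; reuses names defined in the test above.
validation_result = model_eval_lib.load_validation_result(
    os.path.dirname(validations_file))
if not validation_result.validation_ok:
  for validation in validation_result.metric_validations_per_slice:
    for failure in validation.failures:
      print(failure.metric_key.name, failure.metric_value)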
Example #5
    def testInputExtractor(self):
        model_spec = config.ModelSpec(label_key='label',
                                      example_weight_key='example_weight')
        extractor = input_extractor.InputExtractor(
            eval_config=config.EvalConfig(model_specs=[model_spec]))

        examples = [
            self._makeExample(label=1.0,
                              example_weight=0.5,
                              fixed_int=1,
                              fixed_float=1.0,
                              fixed_string='fixed_string1'),
            self._makeExample(label=0.0,
                              example_weight=0.0,
                              fixed_int=1,
                              fixed_float=1.0,
                              fixed_string='fixed_string2'),
            self._makeExample(label=0.0,
                              example_weight=1.0,
                              fixed_int=2,
                              fixed_float=0.0,
                              fixed_string='fixed_string3')
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            result = (
                pipeline
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples], reshuffle=False)
                | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
                | extractor.stage_name >> extractor.ptransform)

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    self.assertLen(got, 3)
                    self.assertDictElementsAlmostEqual(
                        got[0][constants.FEATURES_KEY], {
                            'fixed_int': np.array([1]),
                            'fixed_float': np.array([1.0]),
                        })
                    self.assertEqual(
                        got[0][constants.FEATURES_KEY]['fixed_string'],
                        np.array([b'fixed_string1']))
                    self.assertAlmostEqual(got[0][constants.LABELS_KEY],
                                           np.array([1.0]))
                    self.assertAlmostEqual(
                        got[0][constants.EXAMPLE_WEIGHTS_KEY], np.array([0.5]))
                    self.assertDictElementsAlmostEqual(
                        got[1][constants.FEATURES_KEY], {
                            'fixed_int': np.array([1]),
                            'fixed_float': np.array([1.0]),
                        })
                    self.assertEqual(
                        got[1][constants.FEATURES_KEY]['fixed_string'],
                        np.array([b'fixed_string2']))
                    self.assertAlmostEqual(got[1][constants.LABELS_KEY],
                                           np.array([0.0]))
                    self.assertAlmostEqual(
                        got[1][constants.EXAMPLE_WEIGHTS_KEY], np.array([0.0]))
                    self.assertDictElementsAlmostEqual(
                        got[2][constants.FEATURES_KEY], {
                            'fixed_int': np.array([2]),
                            'fixed_float': np.array([0.0]),
                        })
                    self.assertEqual(
                        got[2][constants.FEATURES_KEY]['fixed_string'],
                        np.array([b'fixed_string3']))
                    self.assertAlmostEqual(got[2][constants.LABELS_KEY],
                                           np.array([0.0]))
                    self.assertAlmostEqual(
                        got[2][constants.EXAMPLE_WEIGHTS_KEY], np.array([1.0]))

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result, label='result')
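
For reference, each extract produced by InputExtractor in this test is a plain dict keyed by the `constants` used in the assertions; the first record has the shape sketched below (values copied from the assertions above, not newly derived).

import numpy as np

# Illustrative literal only; assumes the same `constants` module as above.
expected_first_extract = {
    constants.FEATURES_KEY: {
        'fixed_int': np.array([1]),
        'fixed_float': np.array([1.0]),
        'fixed_string': np.array([b'fixed_string1']),
    },
    constants.LABELS_KEY: np.array([1.0]),
    constants.EXAMPLE_WEIGHTS_KEY: np.array([0.5]),
}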
Example #6
    def testInputExtractorMultiModel(self):
        model_spec1 = config.ModelSpec(name='model1',
                                       label_key='label',
                                       example_weight_key='example_weight',
                                       prediction_key='fixed_float')
        model_spec2 = config.ModelSpec(name='model2',
                                       label_keys={
                                           'output1': 'label1',
                                           'output2': 'label2'
                                       },
                                       example_weight_keys={
                                           'output1': 'example_weight1',
                                           'output2': 'example_weight2'
                                       },
                                       prediction_keys={
                                           'output1': 'fixed_float',
                                           'output2': 'fixed_float'
                                       })
        extractor = input_extractor.InputExtractor(
            eval_config=config.EvalConfig(
                model_specs=[model_spec1, model_spec2]))

        examples = [
            self._makeExample(label=1.0,
                              label1=1.0,
                              label2=0.0,
                              example_weight=0.5,
                              example_weight1=0.5,
                              example_weight2=0.5,
                              fixed_int=1,
                              fixed_float=1.0,
                              fixed_string='fixed_string1'),
            self._makeExample(label=1.0,
                              label1=1.0,
                              label2=1.0,
                              example_weight=0.0,
                              example_weight1=0.0,
                              example_weight2=1.0,
                              fixed_int=1,
                              fixed_float=2.0,
                              fixed_string='fixed_string2'),
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            result = (
                pipeline
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples], reshuffle=False)
                | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
                | extractor.stage_name >> extractor.ptransform)

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    self.assertLen(got, 2)
                    self.assertDictElementsAlmostEqual(
                        got[0][constants.FEATURES_KEY], {
                            'fixed_int': np.array([1]),
                        })
                    self.assertEqual(
                        got[0][constants.FEATURES_KEY]['fixed_string'],
                        np.array([b'fixed_string1']))
                    for model_name in ('model1', 'model2'):
                        self.assertIn(model_name, got[0][constants.LABELS_KEY])
                        self.assertIn(model_name,
                                      got[0][constants.EXAMPLE_WEIGHTS_KEY])
                        self.assertIn(model_name,
                                      got[0][constants.PREDICTIONS_KEY])
                    self.assertAlmostEqual(
                        got[0][constants.LABELS_KEY]['model1'],
                        np.array([1.0]))
                    self.assertDictElementsAlmostEqual(
                        got[0][constants.LABELS_KEY]['model2'], {
                            'output1': np.array([1.0]),
                            'output2': np.array([0.0])
                        })
                    self.assertAlmostEqual(
                        got[0][constants.EXAMPLE_WEIGHTS_KEY]['model1'],
                        np.array([0.5]))
                    self.assertDictElementsAlmostEqual(
                        got[0][constants.EXAMPLE_WEIGHTS_KEY]['model2'], {
                            'output1': np.array([0.5]),
                            'output2': np.array([0.5])
                        })
                    self.assertAlmostEqual(
                        got[0][constants.PREDICTIONS_KEY]['model1'],
                        np.array([1.0]))
                    self.assertDictElementsAlmostEqual(
                        got[0][constants.PREDICTIONS_KEY]['model2'], {
                            'output1': np.array([1.0]),
                            'output2': np.array([1.0])
                        })

                    self.assertDictElementsAlmostEqual(
                        got[1][constants.FEATURES_KEY], {
                            'fixed_int': np.array([1]),
                        })
                    self.assertEqual(
                        got[1][constants.FEATURES_KEY]['fixed_string'],
                        np.array([b'fixed_string2']))
                    for model_name in ('model1', 'model2'):
                        self.assertIn(model_name, got[1][constants.LABELS_KEY])
                        self.assertIn(model_name,
                                      got[1][constants.EXAMPLE_WEIGHTS_KEY])
                        self.assertIn(model_name,
                                      got[1][constants.PREDICTIONS_KEY])
                    self.assertAlmostEqual(
                        got[1][constants.LABELS_KEY]['model1'],
                        np.array([1.0]))
                    self.assertDictElementsAlmostEqual(
                        got[1][constants.LABELS_KEY]['model2'], {
                            'output1': np.array([1.0]),
                            'output2': np.array([1.0])
                        })
                    self.assertAlmostEqual(
                        got[1][constants.EXAMPLE_WEIGHTS_KEY]['model1'],
                        np.array([0.0]))
                    self.assertDictElementsAlmostEqual(
                        got[1][constants.EXAMPLE_WEIGHTS_KEY]['model2'], {
                            'output1': np.array([0.0]),
                            'output2': np.array([1.0])
                        })
                    self.assertAlmostEqual(
                        got[1][constants.PREDICTIONS_KEY]['model1'],
                        np.array([2.0]))
                    self.assertDictElementsAlmostEqual(
                        got[1][constants.PREDICTIONS_KEY]['model2'], {
                            'output1': np.array([2.0]),
                            'output2': np.array([2.0])
                        })

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result, label='result')
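
In the multi-model case the extracts gain one extra level of nesting: labels, example weights, and predictions are keyed by model name, and for the multi-output `model2` additionally by output name. A sketch of the first record's labels, copied from the assertions above:

import numpy as np

# Illustrative literal only; assumes the same `constants` module as above.
first_extract_labels = {
    'model1': np.array([1.0]),
    'model2': {'output1': np.array([1.0]), 'output2': np.array([0.0])},
}
# e.g. got[0][constants.LABELS_KEY]['model2']['output2'] -> np.array([0.0])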
Example #7
def default_extractors(  # pylint: disable=invalid-name
    eval_shared_model: Optional[Union[types.EvalSharedModel,
                                      Dict[Text, types.EvalSharedModel]]] = None,
    eval_config: Optional[config.EvalConfig] = None,
    slice_spec: Optional[List[slicer.SingleSliceSpec]] = None,
    desired_batch_size: Optional[int] = None,
    materialize: Optional[bool] = True) -> List[extractor.Extractor]:
  """Returns the default extractors for use in ExtractAndEvaluate.

  Args:
    eval_shared_model: Shared model (single-model evaluation) or dict of shared
      models keyed by model name (multi-model evaluation). Required unless the
      predictions are provided alongside the features (i.e. model-agnostic
      evaluations).
    eval_config: Eval config.
    slice_spec: Deprecated (use EvalConfig).
    desired_batch_size: Optional batch size for batching in Predict.
    materialize: True to have extractors create materialized output.

  Raises:
    NotImplementedError: If eval_config contains mixed serving and eval models.
  """
  if eval_config is not None:
    eval_config = config.update_eval_config_with_defaults(eval_config)
    slice_spec = [
        slicer.SingleSliceSpec(spec=spec) for spec in eval_config.slicing_specs
    ]
  if _is_legacy_eval(eval_shared_model, eval_config):
    # Backwards compatibility for previous add_metrics_callbacks implementation.
    return [
        predict_extractor.PredictExtractor(
            eval_shared_model, desired_batch_size, materialize=materialize),
        slice_key_extractor.SliceKeyExtractor(
            slice_spec, materialize=materialize)
    ]
  elif eval_shared_model:
    model_types = model_util.get_model_types(eval_config)
    if not model_types.issubset(constants.VALID_MODEL_TYPES):
      raise NotImplementedError(
          'model type must be one of: {}. evalconfig={}'.format(
              str(constants.VALID_MODEL_TYPES), eval_config))
    if model_types == set([constants.TF_LITE]):
      return [
          input_extractor.InputExtractor(eval_config=eval_config),
          tflite_predict_extractor.TFLitePredictExtractor(
              eval_config=eval_config,
              eval_shared_model=eval_shared_model,
              desired_batch_size=desired_batch_size),
          slice_key_extractor.SliceKeyExtractor(
              slice_spec, materialize=materialize)
      ]
    elif constants.TF_LITE in model_types:
      raise NotImplementedError(
          'support for mixing tf_lite and non-tf_lite models is not '
          'implemented: eval_config={}'.format(eval_config))

    elif (eval_config and all(s.signature_name == eval_constants.EVAL_TAG
                              for s in eval_config.model_specs)):
      return [
          predict_extractor.PredictExtractor(
              eval_shared_model,
              desired_batch_size,
              materialize=materialize,
              eval_config=eval_config),
          slice_key_extractor.SliceKeyExtractor(
              slice_spec, materialize=materialize)
      ]
    elif (eval_config and any(s.signature_name == eval_constants.EVAL_TAG
                              for s in eval_config.model_specs)):
      raise NotImplementedError(
          'support for mixing eval and non-eval models is not implemented: '
          'eval_config={}'.format(eval_config))
    else:
      return [
          input_extractor.InputExtractor(eval_config=eval_config),
          predict_extractor_v2.PredictExtractor(
              eval_config=eval_config,
              eval_shared_model=eval_shared_model,
              desired_batch_size=desired_batch_size),
          slice_key_extractor.SliceKeyExtractor(
              slice_spec, materialize=materialize)
      ]
  else:
    return [
        input_extractor.InputExtractor(eval_config=eval_config),
        slice_key_extractor.SliceKeyExtractor(
            slice_spec, materialize=materialize)
    ]
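
Which branch `default_extractors` takes depends on the shared model's loader tags and the signature names in `eval_config`. A hedged sketch of producing the two kinds of shared models, assuming `tfma.default_eval_shared_model` accepts a `tags` argument in this version; the paths are placeholders, and `eval_config` is assumed to be defined as in the tests below.

import tensorflow as tf
import tensorflow_model_analysis as tfma

# Serving SavedModel: no eval tag, so the InputExtractor + v2 PredictExtractor
# branch above is selected.
serving_model = tfma.default_eval_shared_model(
    eval_saved_model_path='/path/to/serving_model',  # placeholder
    tags=[tf.saved_model.SERVING])

# EvalSavedModel exported with the eval tag: the legacy PredictExtractor
# branch above is selected instead.
legacy_model = tfma.default_eval_shared_model(
    eval_saved_model_path='/path/to/eval_saved_model')  # placeholder

extractors = default_extractors(
    eval_shared_model=serving_model, eval_config=eval_config)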
    def testEvaluateWithQueryBasedMetrics(self):
        temp_export_dir = self._getExportDir()
        _, export_dir = (fixed_prediction_estimator_extra_fields.
                         simple_fixed_prediction_estimator_extra_fields(
                             None, temp_export_dir))
        eval_config = config.EvalConfig(
            model_specs=[
                config.ModelSpec(location=export_dir,
                                 label_key='label',
                                 example_weight_key='fixed_int')
            ],
            slicing_specs=[
                config.SlicingSpec(),
                config.SlicingSpec(feature_keys=['fixed_string']),
            ],
            metrics_specs=metric_specs.specs_from_metrics(
                [ndcg.NDCG(gain_key='fixed_float', name='ndcg')],
                binarize=config.BinarizationOptions(top_k_list=[1, 2]),
                query_key='fixed_string'))
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])
        slice_spec = [
            slicer.SingleSliceSpec(spec=s) for s in eval_config.slicing_specs
        ]
        extractors = [
            input_extractor.InputExtractor(eval_config=eval_config),
            predict_extractor_v2.PredictExtractor(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model]),
            slice_key_extractor.SliceKeyExtractor(slice_spec=slice_spec)
        ]
        evaluators = [
            metrics_and_plots_evaluator_v2.MetricsAndPlotsEvaluator(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model])
        ]

        # fixed_string used as query_key
        # fixed_float used as gain_key for NDCG
        # fixed_int used as example_weight_key for NDCG
        examples = [
            self._makeExample(prediction=0.2,
                              label=1.0,
                              fixed_float=1.0,
                              fixed_string='query1',
                              fixed_int=1),
            self._makeExample(prediction=0.8,
                              label=0.0,
                              fixed_float=0.5,
                              fixed_string='query1',
                              fixed_int=1),
            self._makeExample(prediction=0.5,
                              label=0.0,
                              fixed_float=0.5,
                              fixed_string='query2',
                              fixed_int=2),
            self._makeExample(prediction=0.9,
                              label=1.0,
                              fixed_float=1.0,
                              fixed_string='query2',
                              fixed_int=2),
            self._makeExample(prediction=0.1,
                              label=0.0,
                              fixed_float=0.1,
                              fixed_string='query2',
                              fixed_int=2),
            self._makeExample(prediction=0.9,
                              label=1.0,
                              fixed_float=1.0,
                              fixed_string='query3',
                              fixed_int=3)
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            metrics = (
                pipeline
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples])
                | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
                | 'ExtractAndEvaluate' >> model_eval_lib.ExtractAndEvaluate(
                    extractors=extractors, evaluators=evaluators))

            # pylint: enable=no-value-for-parameter

            def check_metrics(got):
                try:
                    self.assertLen(got, 4)
                    slices = {}
                    for slice_key, value in got:
                        slices[slice_key] = value
                    overall_slice = ()
                    query1_slice = (('fixed_string', b'query1'), )
                    query2_slice = (('fixed_string', b'query2'), )
                    query3_slice = (('fixed_string', b'query3'), )
                    self.assertCountEqual(list(slices.keys()), [
                        overall_slice, query1_slice, query2_slice, query3_slice
                    ])
                    example_count_key = metric_types.MetricKey(
                        name='example_count')
                    weighted_example_count_key = metric_types.MetricKey(
                        name='weighted_example_count')
                    ndcg1_key = metric_types.MetricKey(
                        name='ndcg', sub_key=metric_types.SubKey(top_k=1))
                    ndcg2_key = metric_types.MetricKey(
                        name='ndcg', sub_key=metric_types.SubKey(top_k=2))
                    # Query1 (weight=1): (p=0.8, g=0.5) (p=0.2, g=1.0)
                    # Query2 (weight=2): (p=0.9, g=1.0) (p=0.5, g=0.5) (p=0.1, g=0.1)
                    # Query3 (weight=3): (p=0.9, g=1.0)
                    #
                    # DCG@1:  0.5, 1.0, 1.0
                    # NDCG@1: 0.5, 1.0, 1.0
                    # Average NDCG@1: (1 * 0.5 + 2 * 1.0 + 3 * 1.0) / (1 + 2 + 3) ~ 0.92
                    #
                    # DCG@2:  0.5 + 1.0/log(3) ~ 1.130930
                    #         1.0 + 0.5/log(3) ~ 1.315465
                    #         1.0
                    # NDCG@2: (0.5 + 1.0/log(3)) / (1.0 + 0.5/log(3)) ~ 0.85972
                    #         (1.0 + 0.5/log(3)) / (1.0 + 0.5/log(3)) = 1.0
                    #         1.0
                    # Average NDCG@2: (1 * 0.860 + 2 * 1.0 + 3 * 1.0) / (1 + 2 + 3) ~ 0.97
                    self.assertDictElementsAlmostEqual(
                        slices[overall_slice], {
                            example_count_key: 6,
                            weighted_example_count_key: 11.0,
                            ndcg1_key: 0.9166667,
                            ndcg2_key: 0.9766198
                        })
                    self.assertDictElementsAlmostEqual(
                        slices[query1_slice], {
                            example_count_key: 2,
                            weighted_example_count_key: 2.0,
                            ndcg1_key: 0.5,
                            ndcg2_key: 0.85972
                        })
                    self.assertDictElementsAlmostEqual(
                        slices[query2_slice], {
                            example_count_key: 3,
                            weighted_example_count_key: 6.0,
                            ndcg1_key: 1.0,
                            ndcg2_key: 1.0
                        })
                    self.assertDictElementsAlmostEqual(
                        slices[query3_slice], {
                            example_count_key: 1,
                            weighted_example_count_key: 3.0,
                            ndcg1_key: 1.0,
                            ndcg2_key: 1.0
                        })

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(metrics[constants.METRICS_KEY],
                             check_metrics,
                             label='metrics')
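
The NDCG figures asserted above can be re-derived with a few lines of standalone Python (no TFMA involved); the gains below are ordered by descending prediction within each query, matching the comments in the test.

import math


def dcg(gains, k):
  """Discounted cumulative gain over the first k items."""
  return sum(g / math.log2(i + 2) for i, g in enumerate(gains[:k]))


def ndcg(gains, k):
  return dcg(gains, k) / dcg(sorted(gains, reverse=True), k)


queries = {  # query -> (example weight, gains ordered by descending prediction)
    'query1': (1, [0.5, 1.0]),
    'query2': (2, [1.0, 0.5, 0.1]),
    'query3': (3, [1.0]),
}
total_weight = sum(w for w, _ in queries.values())
for k in (1, 2):
  weighted = sum(w * ndcg(g, k) for w, g in queries.values())
  print('NDCG@%d ~ %.7f' % (k, weighted / total_weight))
# Prints ~0.9166667 and ~0.9766198, matching the overall-slice assertions.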
    def testEvaluateWithKerasModel(self):
        input1 = tf.keras.layers.Input(shape=(1, ), name='input1')
        input2 = tf.keras.layers.Input(shape=(1, ), name='input2')
        inputs = [input1, input2]
        input_layer = tf.keras.layers.concatenate(inputs)
        output_layer = tf.keras.layers.Dense(1,
                                             activation=tf.nn.sigmoid,
                                             name='output')(input_layer)
        model = tf.keras.models.Model(inputs, output_layer)
        model.compile(optimizer=tf.keras.optimizers.Adam(lr=.001),
                      loss=tf.keras.losses.binary_crossentropy,
                      metrics=['accuracy'])

        features = {'input1': [[0.0], [1.0]], 'input2': [[1.0], [0.0]]}
        labels = [[1], [0]]
        example_weights = [1.0, 0.5]
        dataset = tf.data.Dataset.from_tensor_slices(
            (features, labels, example_weights))
        dataset = dataset.shuffle(buffer_size=1).repeat().batch(2)
        model.fit(dataset, steps_per_epoch=1)

        export_dir = self._getExportDir()
        model.save(export_dir, save_format='tf')

        eval_config = config.EvalConfig(
            model_specs=[
                config.ModelSpec(location=export_dir,
                                 label_key='label',
                                 example_weight_key='example_weight')
            ],
            slicing_specs=[config.SlicingSpec()],
            metrics_specs=metric_specs.specs_from_metrics(
                [calibration.MeanLabel('mean_label')]))
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])

        slice_spec = [
            slicer.SingleSliceSpec(spec=s) for s in eval_config.slicing_specs
        ]
        extractors = [
            input_extractor.InputExtractor(eval_config=eval_config),
            predict_extractor_v2.PredictExtractor(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model]),
            slice_key_extractor.SliceKeyExtractor(slice_spec=slice_spec)
        ]
        evaluators = [
            metrics_and_plots_evaluator_v2.MetricsAndPlotsEvaluator(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model])
        ]

        examples = [
            self._makeExample(input1=0.0,
                              input2=1.0,
                              label=1.0,
                              example_weight=1.0,
                              extra_feature='non_model_feature'),
            self._makeExample(input1=1.0,
                              input2=0.0,
                              label=0.0,
                              example_weight=0.5,
                              extra_feature='non_model_feature'),
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            metrics = (
                pipeline
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples])
                | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
                | 'ExtractAndEvaluate' >> model_eval_lib.ExtractAndEvaluate(
                    extractors=extractors, evaluators=evaluators))

            # pylint: enable=no-value-for-parameter

            def check_metrics(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())
                    example_count_key = metric_types.MetricKey(
                        name='example_count')
                    weighted_example_count_key = metric_types.MetricKey(
                        name='weighted_example_count')
                    label_key = metric_types.MetricKey(name='mean_label')
                    self.assertDictElementsAlmostEqual(
                        got_metrics, {
                            example_count_key: 2,
                            weighted_example_count_key: (1.0 + 0.5),
                            label_key: (1.0 * 1.0 + 0.0 * 0.5) / (1.0 + 0.5),
                        })

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(metrics[constants.METRICS_KEY],
                             check_metrics,
                             label='metrics')
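
The expected values in `check_metrics` follow directly from the two examples; a runnable recomputation with NumPy:

import numpy as np

labels = np.array([1.0, 0.0])
example_weights = np.array([1.0, 0.5])
print(len(labels))                                  # example_count -> 2
print(example_weights.sum())                        # weighted_example_count -> 1.5
print(np.average(labels, weights=example_weights))  # mean_label -> 2/3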
    def testEvaluateWithMultiOutputModel(self):
        temp_export_dir = self._getExportDir()
        _, export_dir = multi_head.simple_multi_head(None, temp_export_dir)

        eval_config = config.EvalConfig(
            model_specs=[
                config.ModelSpec(location=export_dir,
                                 label_keys={
                                     'chinese_head': 'chinese_label',
                                     'english_head': 'english_label',
                                     'other_head': 'other_label'
                                 },
                                 example_weight_keys={
                                     'chinese_head': 'age',
                                     'english_head': 'age',
                                     'other_head': 'age'
                                 })
            ],
            slicing_specs=[config.SlicingSpec()],
            metrics_specs=metric_specs.specs_from_metrics({
                'chinese_head': [calibration.MeanLabel('mean_label')],
                'english_head': [calibration.MeanLabel('mean_label')],
                'other_head': [calibration.MeanLabel('mean_label')],
            }))
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])

        slice_spec = [
            slicer.SingleSliceSpec(spec=s) for s in eval_config.slicing_specs
        ]
        extractors = [
            input_extractor.InputExtractor(eval_config=eval_config),
            predict_extractor_v2.PredictExtractor(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model]),
            slice_key_extractor.SliceKeyExtractor(slice_spec=slice_spec)
        ]
        evaluators = [
            metrics_and_plots_evaluator_v2.MetricsAndPlotsEvaluator(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model])
        ]

        examples = [
            self._makeExample(age=1.0,
                              language='english',
                              english_label=1.0,
                              chinese_label=0.0,
                              other_label=0.0),
            self._makeExample(age=1.0,
                              language='chinese',
                              english_label=0.0,
                              chinese_label=1.0,
                              other_label=0.0),
            self._makeExample(age=2.0,
                              language='english',
                              english_label=1.0,
                              chinese_label=0.0,
                              other_label=0.0),
            self._makeExample(age=2.0,
                              language='other',
                              english_label=0.0,
                              chinese_label=1.0,
                              other_label=1.0),
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            metrics = (
                pipeline
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples])
                | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
                | 'ExtractAndEvaluate' >> model_eval_lib.ExtractAndEvaluate(
                    extractors=extractors, evaluators=evaluators))

            # pylint: enable=no-value-for-parameter

            def check_metrics(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())
                    example_count_key = metric_types.MetricKey(
                        name='example_count')
                    chinese_weighted_example_count_key = metric_types.MetricKey(
                        name='weighted_example_count',
                        output_name='chinese_head')
                    chinese_label_key = metric_types.MetricKey(
                        name='mean_label', output_name='chinese_head')
                    english_weighted_example_count_key = metric_types.MetricKey(
                        name='weighted_example_count',
                        output_name='english_head')
                    english_label_key = metric_types.MetricKey(
                        name='mean_label', output_name='english_head')
                    other_weighted_example_count_key = metric_types.MetricKey(
                        name='weighted_example_count',
                        output_name='other_head')
                    other_label_key = metric_types.MetricKey(
                        name='mean_label', output_name='other_head')
                    self.assertDictElementsAlmostEqual(
                        got_metrics, {
                            example_count_key:
                            4,
                            chinese_label_key:
                            (0.0 + 1.0 + 2 * 0.0 + 2 * 1.0) /
                            (1.0 + 1.0 + 2.0 + 2.0),
                            chinese_weighted_example_count_key:
                            (1.0 + 1.0 + 2.0 + 2.0),
                            english_label_key:
                            (1.0 + 0.0 + 2 * 1.0 + 2 * 0.0) /
                            (1.0 + 1.0 + 2.0 + 2.0),
                            english_weighted_example_count_key:
                            (1.0 + 1.0 + 2.0 + 2.0),
                            other_label_key: (0.0 + 0.0 + 2 * 0.0 + 2 * 1.0) /
                            (1.0 + 1.0 + 2.0 + 2.0),
                            other_weighted_example_count_key:
                            (1.0 + 1.0 + 2.0 + 2.0)
                        })

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(metrics[constants.METRICS_KEY],
                             check_metrics,
                             label='metrics')
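
The per-head expectations above share the same example weights ('age'); only the labels differ per head. A runnable recomputation:

import numpy as np

age = np.array([1.0, 1.0, 2.0, 2.0])
labels_by_head = {
    'chinese_head': np.array([0.0, 1.0, 0.0, 1.0]),
    'english_head': np.array([1.0, 0.0, 1.0, 0.0]),
    'other_head': np.array([0.0, 0.0, 0.0, 1.0]),
}
for head, labels in labels_by_head.items():
  # weighted_example_count -> 6.0 for every head; mean_label is the
  # age-weighted mean of that head's labels.
  print(head, age.sum(), np.average(labels, weights=age))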
    def testEvaluateWithMultiClassModel(self):
        n_classes = 3
        temp_export_dir = self._getExportDir()
        _, export_dir = dnn_classifier.simple_dnn_classifier(
            None, temp_export_dir, n_classes=n_classes)

        # Add example_count and weighted_example_count
        eval_config = config.EvalConfig(
            model_specs=[
                config.ModelSpec(location=export_dir,
                                 label_key='label',
                                 example_weight_key='age')
            ],
            slicing_specs=[config.SlicingSpec()],
            metrics_specs=metric_specs.specs_from_metrics(
                [calibration.MeanLabel('mean_label')],
                binarize=config.BinarizationOptions(
                    class_ids=range(n_classes))))
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])

        slice_spec = [
            slicer.SingleSliceSpec(spec=s) for s in eval_config.slicing_specs
        ]
        extractors = [
            input_extractor.InputExtractor(eval_config=eval_config),
            predict_extractor_v2.PredictExtractor(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model]),
            slice_key_extractor.SliceKeyExtractor(slice_spec=slice_spec)
        ]
        evaluators = [
            metrics_and_plots_evaluator_v2.MetricsAndPlotsEvaluator(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model])
        ]

        examples = [
            self._makeExample(age=1.0, language='english', label=0),
            self._makeExample(age=2.0, language='chinese', label=1),
            self._makeExample(age=3.0, language='english', label=2),
            self._makeExample(age=4.0, language='chinese', label=1),
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            metrics = (
                pipeline
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples])
                | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
                | 'ExtractAndEvaluate' >> model_eval_lib.ExtractAndEvaluate(
                    extractors=extractors, evaluators=evaluators))

            # pylint: enable=no-value-for-parameter

            def check_metrics(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    example_count_key = metric_types.MetricKey(
                        name='example_count')
                    weighted_example_count_key = metric_types.MetricKey(
                        name='weighted_example_count')
                    label_key_class_0 = metric_types.MetricKey(
                        name='mean_label',
                        sub_key=metric_types.SubKey(class_id=0))
                    label_key_class_1 = metric_types.MetricKey(
                        name='mean_label',
                        sub_key=metric_types.SubKey(class_id=1))
                    label_key_class_2 = metric_types.MetricKey(
                        name='mean_label',
                        sub_key=metric_types.SubKey(class_id=2))
                    self.assertEqual(got_slice_key, ())
                    self.assertDictElementsAlmostEqual(
                        got_metrics, {
                            example_count_key:
                            4,
                            weighted_example_count_key:
                            (1.0 + 2.0 + 3.0 + 4.0),
                            label_key_class_0:
                            (1 * 1.0 + 0 * 2.0 + 0 * 3.0 + 0 * 4.0) /
                            (1.0 + 2.0 + 3.0 + 4.0),
                            label_key_class_1:
                            (0 * 1.0 + 1 * 2.0 + 0 * 3.0 + 1 * 4.0) /
                            (1.0 + 2.0 + 3.0 + 4.0),
                            label_key_class_2:
                            (0 * 1.0 + 0 * 2.0 + 1 * 3.0 + 0 * 4.0) /
                            (1.0 + 2.0 + 3.0 + 4.0)
                        })

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(metrics[constants.METRICS_KEY],
                             check_metrics,
                             label='metrics')
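
With `BinarizationOptions(class_ids=...)`, `mean_label` for class k is the age-weighted mean of the indicator `label == k`; a runnable recomputation of the three expected values:

import numpy as np

labels = np.array([0, 1, 2, 1])
age = np.array([1.0, 2.0, 3.0, 4.0])
for class_id in range(3):
  mean_label = np.average((labels == class_id).astype(float), weights=age)
  print(class_id, mean_label)  # 0.1, 0.6, 0.3 for classes 0, 1, 2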
    def testEvaluateWithBinaryClassificationModel(self):
        n_classes = 2
        temp_export_dir = self._getExportDir()
        _, export_dir = dnn_classifier.simple_dnn_classifier(
            None, temp_export_dir, n_classes=n_classes)

        # Add mean_label, example_count, weighted_example_count, calibration_plot
        eval_config = config.EvalConfig(
            model_specs=[
                config.ModelSpec(location=export_dir,
                                 label_key='label',
                                 example_weight_key='age')
            ],
            slicing_specs=[config.SlicingSpec()],
            metrics_specs=metric_specs.specs_from_metrics([
                calibration.MeanLabel('mean_label'),
                calibration_plot.CalibrationPlot(name='calibration_plot',
                                                 num_buckets=10)
            ]))
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])

        slice_spec = [
            slicer.SingleSliceSpec(spec=s) for s in eval_config.slicing_specs
        ]
        extractors = [
            input_extractor.InputExtractor(eval_config=eval_config),
            predict_extractor_v2.PredictExtractor(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model]),
            slice_key_extractor.SliceKeyExtractor(slice_spec=slice_spec)
        ]
        evaluators = [
            metrics_and_plots_evaluator_v2.MetricsAndPlotsEvaluator(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model])
        ]

        examples = [
            self._makeExample(age=1.0, language='english', label=0.0),
            self._makeExample(age=2.0, language='chinese', label=1.0),
            self._makeExample(age=3.0, language='chinese', label=0.0),
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            metrics_and_plots = (
                pipeline
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples])
                | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
                | 'ExtractAndEvaluate' >> model_eval_lib.ExtractAndEvaluate(
                    extractors=extractors, evaluators=evaluators))

            # pylint: enable=no-value-for-parameter

            def check_metrics(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())
                    example_count_key = metric_types.MetricKey(
                        name='example_count')
                    weighted_example_count_key = metric_types.MetricKey(
                        name='weighted_example_count')
                    label_key = metric_types.MetricKey(name='mean_label')
                    self.assertDictElementsAlmostEqual(
                        got_metrics, {
                            example_count_key:
                            3,
                            weighted_example_count_key: (1.0 + 2.0 + 3.0),
                            label_key:
                            (0 * 1.0 + 1 * 2.0 + 0 * 3.0) / (1.0 + 2.0 + 3.0),
                        })

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            def check_plots(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_plots = got[0]
                    self.assertEqual(got_slice_key, ())
                    plot_key = metric_types.PlotKey('calibration_plot')
                    self.assertIn(plot_key, got_plots)
                    # 10 buckets + 2 for edge cases
                    self.assertLen(got_plots[plot_key].buckets, 12)

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(metrics_and_plots[constants.METRICS_KEY],
                             check_metrics,
                             label='metrics')
            util.assert_that(metrics_and_plots[constants.PLOTS_KEY],
                             check_plots,
                             label='plots')
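
On the "10 buckets + 2 for edge cases" assertion: the calibration plot covers [0, 1] in ten equal buckets plus an underflow and an overflow bucket. The boundary sketch below illustrates that count only, not TFMA's exact bucket representation.

import numpy as np

inner_edges = np.linspace(0.0, 1.0, 11)
buckets = ([(-np.inf, 0.0)] +
           list(zip(inner_edges[:-1], inner_edges[1:])) +
           [(1.0, np.inf)])
print(len(buckets))  # -> 12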
    def testEvaluateWithRegressionModel(self):
        temp_export_dir = self._getExportDir()
        _, export_dir = (fixed_prediction_estimator_extra_fields.
                         simple_fixed_prediction_estimator_extra_fields(
                             None, temp_export_dir))
        eval_config = config.EvalConfig(
            model_specs=[
                config.ModelSpec(location=export_dir,
                                 label_key='label',
                                 example_weight_key='fixed_float')
            ],
            slicing_specs=[config.SlicingSpec()],
            metrics_specs=metric_specs.specs_from_metrics([
                calibration.MeanLabel('mean_label'),
                calibration.MeanPrediction('mean_prediction')
            ]))
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])

        slice_spec = [
            slicer.SingleSliceSpec(spec=s) for s in eval_config.slicing_specs
        ]
        extractors = [
            input_extractor.InputExtractor(eval_config=eval_config),
            predict_extractor_v2.PredictExtractor(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model]),
            slice_key_extractor.SliceKeyExtractor(slice_spec=slice_spec)
        ]
        evaluators = [
            metrics_and_plots_evaluator_v2.MetricsAndPlotsEvaluator(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model])
        ]

        # fixed_float used as example_weight key
        examples = [
            self._makeExample(prediction=0.2,
                              label=1.0,
                              fixed_int=1,
                              fixed_float=1.0,
                              fixed_string='fixed_string1'),
            self._makeExample(prediction=0.8,
                              label=0.0,
                              fixed_int=1,
                              fixed_float=1.0,
                              fixed_string='fixed_string1'),
            self._makeExample(prediction=0.5,
                              label=0.0,
                              fixed_int=2,
                              fixed_float=2.0,
                              fixed_string='fixed_string2')
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            metrics = (
                pipeline
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples])
                | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
                | 'ExtractAndEvaluate' >> model_eval_lib.ExtractAndEvaluate(
                    extractors=extractors, evaluators=evaluators))

            # pylint: enable=no-value-for-parameter

            def check_metrics(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    example_count_key = metric_types.MetricKey(
                        name='example_count')
                    weighted_example_count_key = metric_types.MetricKey(
                        name='weighted_example_count')
                    label_key = metric_types.MetricKey(name='mean_label')
                    pred_key = metric_types.MetricKey(name='mean_prediction')
                    self.assertEqual(got_slice_key, ())
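                    # Expected values are weighted by fixed_float, e.g.
                    # mean_label = (1*1.0 + 0*1.0 + 0*2.0) / (1.0 + 1.0 + 2.0).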
                    self.assertDictElementsAlmostEqual(
                        got_metrics, {
                            example_count_key: 3,
                            weighted_example_count_key: 4.0,
                            label_key:
                            (1.0 + 0.0 + 2 * 0.0) / (1.0 + 1.0 + 2.0),
                            pred_key:
                            (0.2 + 0.8 + 2 * 0.5) / (1.0 + 1.0 + 2.0),
                        })

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(metrics[constants.METRICS_KEY],
                             check_metrics,
                             label='metrics')
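
    # Not part of the original test: a minimal sketch of the weighted-mean
    # arithmetic asserted in testEvaluateWithRegressionModel above. All names
    # here are illustrative; fixed_float supplies the per-example weights.
    def _weighted_mean_sketch(self):
        labels = [1.0, 0.0, 0.0]
        predictions = [0.2, 0.8, 0.5]
        weights = [1.0, 1.0, 2.0]  # fixed_float values of the three examples
        total_weight = sum(weights)  # weighted_example_count == 4.0
        mean_label = sum(l * w for l, w in zip(labels, weights)) / total_weight
        mean_pred = sum(p * w for p, w in zip(predictions, weights)) / total_weight
        return mean_label, mean_pred  # (0.25, 0.5), matching check_metrics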

    def testEvaluateWithConfidenceIntervals(self):
        # NOTE: This test does not actually test that the confidence intervals
        #   are accurate; it only tests that the proto output by the test is
        #   well formed. This test would pass if the confidence interval
        #   implementation did nothing at all except compute the unsampled value.
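        # With options.compute_confidence_intervals enabled, the metric values
        # carry sampling information; check_metrics below only compares them
        # against the expected unsampled values via
        # assertDictElementsWithTDistributionAlmostEqual.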
        temp_export_dir = self._getExportDir()
        _, export_dir = (fixed_prediction_estimator_extra_fields.
                         simple_fixed_prediction_estimator_extra_fields(
                             None, temp_export_dir))
        options = config.Options()
        options.compute_confidence_intervals.value = True
        eval_config = config.EvalConfig(
            model_specs=[
                config.ModelSpec(location=export_dir,
                                 label_key='label',
                                 example_weight_key='fixed_float')
            ],
            slicing_specs=[
                config.SlicingSpec(),
                config.SlicingSpec(feature_keys=['fixed_string']),
            ],
            metrics_specs=metric_specs.specs_from_metrics([
                calibration.MeanLabel('mean_label'),
                calibration.MeanPrediction('mean_prediction')
            ]),
            options=options)
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])

        slice_spec = [
            slicer.SingleSliceSpec(spec=s) for s in eval_config.slicing_specs
        ]
        extractors = [
            input_extractor.InputExtractor(eval_config=eval_config),
            predict_extractor_v2.PredictExtractor(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model]),
            slice_key_extractor.SliceKeyExtractor(slice_spec=slice_spec)
        ]
        evaluators = [
            metrics_and_plots_evaluator_v2.MetricsAndPlotsEvaluator(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model])
        ]

        # fixed_float used as example_weight key
        examples = [
            self._makeExample(prediction=0.2,
                              label=1.0,
                              fixed_int=1,
                              fixed_float=1.0,
                              fixed_string='fixed_string1'),
            self._makeExample(prediction=0.8,
                              label=0.0,
                              fixed_int=1,
                              fixed_float=1.0,
                              fixed_string='fixed_string1'),
            self._makeExample(prediction=0.5,
                              label=0.0,
                              fixed_int=2,
                              fixed_float=2.0,
                              fixed_string='fixed_string2')
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            metrics = (
                pipeline
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples])
                | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
                | 'ExtractAndEvaluate' >> model_eval_lib.ExtractAndEvaluate(
                    extractors=extractors, evaluators=evaluators))

            # pylint: enable=no-value-for-parameter

            def check_metrics(got):
                try:
                    self.assertLen(got, 3)
                    slices = {}
                    for slice_key, value in got:
                        slices[slice_key] = value
                    overall_slice = ()
                    fixed_string1_slice = (('fixed_string',
                                            b'fixed_string1'), )
                    fixed_string2_slice = (('fixed_string',
                                            b'fixed_string2'), )
                    self.assertCountEqual(list(slices.keys()), [
                        overall_slice, fixed_string1_slice, fixed_string2_slice
                    ])
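                    # fixed_string1 covers the first two examples (weights
                    # 1.0 + 1.0), fixed_string2 covers the third (weight 2.0);
                    # the per-slice expectations below follow from that split.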
                    example_count_key = metric_types.MetricKey(
                        name='example_count')
                    weighted_example_count_key = metric_types.MetricKey(
                        name='weighted_example_count')
                    label_key = metric_types.MetricKey(name='mean_label')
                    pred_key = metric_types.MetricKey(name='mean_prediction')
                    self.assertDictElementsWithTDistributionAlmostEqual(
                        slices[overall_slice], {
                            example_count_key: 3,
                            weighted_example_count_key: 4.0,
                            label_key:
                            (1.0 + 0.0 + 2 * 0.0) / (1.0 + 1.0 + 2.0),
                            pred_key:
                            (0.2 + 0.8 + 2 * 0.5) / (1.0 + 1.0 + 2.0),
                        })
                    self.assertDictElementsWithTDistributionAlmostEqual(
                        slices[fixed_string1_slice], {
                            example_count_key: 2,
                            weighted_example_count_key: 2.0,
                            label_key: (1.0 + 0.0) / (1.0 + 1.0),
                            pred_key: (0.2 + 0.8) / (1.0 + 1.0),
                        })
                    self.assertDictElementsWithTDistributionAlmostEqual(
                        slices[fixed_string2_slice], {
                            example_count_key: 1,
                            weighted_example_count_key: 2.0,
                            label_key: (2 * 0.0) / 2.0,
                            pred_key: (2 * 0.5) / 2.0,
                        })

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(metrics[constants.METRICS_KEY],
                             check_metrics,
                             label='metrics')