def default_eval_shared_model(eval_saved_model_path,
                              add_metrics_callbacks=None,
                              example_weight_key=None):
    """Returns default EvalSharedModel.

  Args:
    eval_saved_model_path: Path to EvalSavedModel.
    add_metrics_callbacks: Optional list of callbacks for adding additional
      metrics to the graph (see EvalSharedModel for more information on how to
      configure additional metrics). Metrics for example counts and example
      weight will be added automatically.
    example_weight_key: The key of the example weight column. If None, weight
      will be 1 for each example.
  """
    # Always compute example weight and example count.
    # pytype: disable=module-attr
    if not add_metrics_callbacks:
        add_metrics_callbacks = []
    example_count_callback = post_export_metrics.example_count()
    add_metrics_callbacks.append(example_count_callback)
    if example_weight_key:
        example_weight_callback = post_export_metrics.example_weight(
            example_weight_key)
        add_metrics_callbacks.append(example_weight_callback)
    # pytype: enable=module-attr

    return types.EvalSharedModel(model_path=eval_saved_model_path,
                                 add_metrics_callbacks=add_metrics_callbacks,
                                 example_weight_key=example_weight_key)
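# A minimal usage sketch for the function above, assuming it is re-exported by
# the tensorflow_model_analysis package as tfma.default_eval_shared_model; the
# export path and the 'age' weight column are hypothetical.
import tensorflow_model_analysis as tfma

eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path='/tmp/eval_saved_model',  # hypothetical path
    example_weight_key='age')                       # hypothetical weight column
# The resulting EvalSharedModel is then handed to the extractors and evaluators
# exercised in the test below.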
    def testEvaluateNoSlicingAddPostExportAndCustomMetricsUnsupervisedModel(
            self):
        # Mainly for testing that the ExampleCount post export metric works with
        # unsupervised models.
        temp_eval_export_dir = self._getEvalExportDir()
        _, eval_export_dir = (
            fixed_prediction_estimator_no_labels
            .simple_fixed_prediction_estimator_no_labels(
                None, temp_eval_export_dir))
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=eval_export_dir,
            add_metrics_callbacks=[
                post_export_metrics.example_count(),
                post_export_metrics.example_weight(
                    example_weight_key='prediction')
            ])
        extractors = [
            predict_extractor.PredictExtractor(eval_shared_model),
            slice_key_extractor.SliceKeyExtractor()
        ]

        with beam.Pipeline() as pipeline:
            example1 = self._makeExample(prediction=1.0)
            example2 = self._makeExample(prediction=2.0)

            metrics, plots = (
                pipeline
                | 'Create' >> beam.Create([
                    example1.SerializeToString(),
                    example2.SerializeToString(),
                ])
                | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
                | 'Extract' >> tfma_unit.Extract(extractors=extractors)  # pylint: disable=no-value-for-parameter
                | 'ComputeMetricsAndPlots' >>
                metrics_and_plots_evaluator.ComputeMetricsAndPlots(
                    eval_shared_model=eval_shared_model))

            def check_result(got):
                try:
                    self.assertEqual(1, len(got), 'got: %s' % got)
                    (slice_key, value) = got[0]
                    self.assertEqual((), slice_key)
                    self.assertDictElementsAlmostEqual(
                        got_values_dict=value,
                        expected_values_dict={
                            'average_loss': 2.5,
                            metric_keys.EXAMPLE_COUNT: 2.0,
                            metric_keys.EXAMPLE_WEIGHT: 3.0
                        })
                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(metrics, check_result, label='metrics')
            util.assert_that(plots, util.is_empty(), label='plots')
Example 3
def default_eval_shared_model(
        eval_saved_model_path: Text,
        add_metrics_callbacks: Optional[List[
            types.AddMetricsCallbackType]] = None,
        include_default_metrics: Optional[bool] = True,
        example_weight_key: Optional[Text] = None,
        additional_fetches: Optional[List[Text]] = None
) -> types.EvalSharedModel:
    """Returns default EvalSharedModel.

  Args:
    eval_saved_model_path: Path to EvalSavedModel.
    add_metrics_callbacks: Optional list of callbacks for adding additional
      metrics to the graph (see EvalSharedModel for more information on how to
      configure additional metrics). Metrics for example counts and example
      weight will be added automatically.
    include_default_metrics: True to include the default metrics that are part
      of the saved model graph during evaluation.
    example_weight_key: Deprecated.
    additional_fetches: Prefixes of additional tensors stored in
      signature_def.inputs that should be fetched at prediction time. The
      "features" and "labels" tensors are handled automatically and should not
      be included.
  """
    # Always compute example weight and example count.
    # PyType doesn't know about the magic exports we do in post_export_metrics.
    # Additionally, the lines seem to get reordered in compilation, so we can't
    # just put the disable-attr on the add_metrics_callbacks lines.
    # pytype: disable=module-attr
    if not add_metrics_callbacks:
        add_metrics_callbacks = []
    example_count_callback = post_export_metrics.example_count()
    add_metrics_callbacks.append(example_count_callback)
    # TODO(b/126924645): Remove
    if example_weight_key:
        example_weight_callback = post_export_metrics.example_weight(
            example_weight_key)
        add_metrics_callbacks.append(example_weight_callback)
    # pytype: enable=module-attr

    return types.EvalSharedModel(
        model_path=eval_saved_model_path,
        add_metrics_callbacks=add_metrics_callbacks,
        include_default_metrics=include_default_metrics,
        example_weight_key=example_weight_key,
        additional_fetches=additional_fetches,
        construct_fn=dofn.make_construct_fn(
            eval_saved_model_path,
            add_metrics_callbacks,
            include_default_metrics,
            additional_fetches=additional_fetches))
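# A hedged sketch of the extra knobs this version adds: include_default_metrics
# toggles the metrics baked into the EvalSavedModel graph, and additional_fetches
# names extra signature_def.inputs prefixes to fetch at prediction time. The
# export path and the 'embedding' prefix are hypothetical.
import tensorflow_model_analysis as tfma

eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path='/tmp/eval_saved_model',  # hypothetical path
    include_default_metrics=True,
    additional_fetches=['embedding'])               # hypothetical tensor prefix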
 def testExampleCountNoStandardKeys(self):
   # Test ExampleCount with a custom Estimator that doesn't have any of the
   # standard PredictionKeys.
   temp_eval_export_dir = self._getEvalExportDir()
   _, eval_export_dir = (
       fixed_prediction_estimator.simple_fixed_prediction_estimator(
           None, temp_eval_export_dir, output_prediction_key='non_standard'))
   examples = [
       self._makeExample(prediction=5.0, label=5.0),
       self._makeExample(prediction=6.0, label=6.0),
       self._makeExample(prediction=7.0, label=7.0),
   ]
   expected_values_dict = {
       metric_keys.EXAMPLE_COUNT: 3.0,
   }
   self._runTest(examples, eval_export_dir, [
       post_export_metrics.example_count(),
   ], expected_values_dict)
 def testPostExportMetricsDNNRegressor(self):
   temp_eval_export_dir = self._getEvalExportDir()
   _, eval_export_dir = dnn_regressor.simple_dnn_regressor(
       None, temp_eval_export_dir)
   examples = [
       self._makeExample(age=3.0, language='english', label=1.0),
       self._makeExample(age=3.0, language='chinese', label=0.0),
       self._makeExample(age=4.0, language='english', label=1.0),
       self._makeExample(age=5.0, language='chinese', label=0.0)
   ]
   expected_values_dict = {
       metric_keys.EXAMPLE_COUNT: 4.0,
       metric_keys.EXAMPLE_WEIGHT: 15.0,
   }
   self._runTest(examples, eval_export_dir, [
       post_export_metrics.example_count(),
       post_export_metrics.example_weight('age')
   ], expected_values_dict)
 def testExampleCountEmptyPredictionsDict(self):
   # Test ExampleCount with a custom Estimator that has empty predictions dict.
   # This is possible if the Estimator doesn't return the predictions dict
   # in EVAL mode, but computes predictions and feeds them into the metrics
   # internally.
   temp_eval_export_dir = self._getEvalExportDir()
   _, eval_export_dir = (
       fixed_prediction_estimator.simple_fixed_prediction_estimator(
           None, temp_eval_export_dir, output_prediction_key=None))
   examples = [
       self._makeExample(prediction=5.0, label=5.0),
       self._makeExample(prediction=6.0, label=6.0),
       self._makeExample(prediction=7.0, label=7.0),
   ]
   expected_values_dict = {
       metric_keys.EXAMPLE_COUNT: 3.0,
   }
   self._runTest(examples, eval_export_dir, [
       post_export_metrics.example_count(),
   ], expected_values_dict)
Example 7
def default_eval_shared_model(eval_saved_model_path,
                              add_metrics_callbacks=None,
                              include_default_metrics=True,
                              example_weight_key=None):
    """Returns default EvalSharedModel.

  Args:
    eval_saved_model_path: Path to EvalSavedModel.
    add_metrics_callbacks: Optional list of callbacks for adding additional
      metrics to the graph (see EvalSharedModel for more information on how to
      configure additional metrics). Metrics for example counts and example
      weight will be added automatically.
    include_default_metrics: True to include the default metrics that are part
      of the saved model graph during evaluation.
    example_weight_key: The key of the example weight column. If None, weight
      will be 1 for each example.
  """
    # Always compute example weight and example count.
    # PyType doesn't know about the magic exports we do in post_export_metrics.
    # Additionally, the lines seem to get reordered in compilation, so we can't
    # just put the disable-attr on the add_metrics_callbacks lines.
    # pytype: disable=module-attr
    if not add_metrics_callbacks:
        add_metrics_callbacks = []
    example_count_callback = post_export_metrics.example_count()
    add_metrics_callbacks.append(example_count_callback)
    if example_weight_key:
        example_weight_callback = post_export_metrics.example_weight(
            example_weight_key)
        add_metrics_callbacks.append(example_weight_callback)
    # pytype: enable=module-attr

    return types.EvalSharedModel(
        model_path=eval_saved_model_path,
        add_metrics_callbacks=add_metrics_callbacks,
        include_default_metrics=include_default_metrics,
        example_weight_key=example_weight_key,
        construct_fn=dofn.make_construct_fn(eval_saved_model_path,
                                            add_metrics_callbacks,
                                            include_default_metrics))
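# A hedged sketch of disabling the metrics baked into the saved graph while
# still adding post-export metrics via callbacks; assumes the callbacks module
# is exposed as tfma.post_export_metrics, and the path and weight column are
# hypothetical.
import tensorflow_model_analysis as tfma

eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path='/tmp/eval_saved_model',   # hypothetical path
    include_default_metrics=False,                   # skip metrics from the saved graph
    add_metrics_callbacks=[tfma.post_export_metrics.auc()],
    example_weight_key='age')                        # hypothetical weight column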
    def testEvaluateWithPlots(self):
        temp_eval_export_dir = self._getEvalExportDir()
        _, eval_export_dir = (
            fixed_prediction_estimator.simple_fixed_prediction_estimator(
                None, temp_eval_export_dir))
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=eval_export_dir,
            add_metrics_callbacks=[
                post_export_metrics.example_count(),
                post_export_metrics.auc_plots()
            ])
        extractors = [
            predict_extractor.PredictExtractor(eval_shared_model),
            slice_key_extractor.SliceKeyExtractor()
        ]

        with beam.Pipeline() as pipeline:
            example1 = self._makeExample(prediction=0.0, label=1.0)
            example2 = self._makeExample(prediction=0.7, label=0.0)
            example3 = self._makeExample(prediction=0.8, label=1.0)
            example4 = self._makeExample(prediction=1.0, label=1.0)

            metrics, plots = (
                pipeline
                | 'Create' >> beam.Create([
                    example1.SerializeToString(),
                    example2.SerializeToString(),
                    example3.SerializeToString(),
                    example4.SerializeToString()
                ])
                | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
                | 'Extract' >> tfma_unit.Extract(extractors=extractors)  # pylint: disable=no-value-for-parameter
                | 'ComputeMetricsAndPlots' >>
                metrics_and_plots_evaluator.ComputeMetricsAndPlots(
                    eval_shared_model=eval_shared_model))

            def check_metrics(got):
                try:
                    self.assertEqual(1, len(got), 'got: %s' % got)
                    (slice_key, value) = got[0]
                    self.assertEqual((), slice_key)
                    self.assertDictElementsAlmostEqual(
                        got_values_dict=value,
                        expected_values_dict={
                            metric_keys.EXAMPLE_COUNT: 4.0,
                        })
                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(metrics, check_metrics, label='metrics')

            def check_plots(got):
                try:
                    self.assertEqual(1, len(got), 'got: %s' % got)
                    (slice_key, value) = got[0]
                    self.assertEqual((), slice_key)
                    self.assertDictMatrixRowsAlmostEqual(
                        got_values_dict=value,
                        expected_values_dict={
                            _full_key(metric_keys.AUC_PLOTS_MATRICES):
                            [(8001, [2, 1, 0, 1, 1.0 / 1.0, 1.0 / 3.0])],
                        })
                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(plots, check_plots, label='plots')
    def testEvaluateNoSlicingAddPostExportAndCustomMetrics(self):
        temp_eval_export_dir = self._getEvalExportDir()
        _, eval_export_dir = linear_classifier.simple_linear_classifier(
            None, temp_eval_export_dir)
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=eval_export_dir,
            add_metrics_callbacks=[
                _addExampleCountMetricCallback,
                # Note that since everything runs in-process this doesn't
                # actually test that the py_func can be correctly recreated
                # on workers in a distributed context.
                _addPyFuncMetricCallback,
                post_export_metrics.example_count(),
                post_export_metrics.example_weight(example_weight_key='age')
            ])
        extractors = [
            predict_extractor.PredictExtractor(eval_shared_model),
            slice_key_extractor.SliceKeyExtractor()
        ]

        with beam.Pipeline() as pipeline:
            example1 = self._makeExample(age=3.0,
                                         language='english',
                                         label=1.0)
            example2 = self._makeExample(age=3.0,
                                         language='chinese',
                                         label=0.0)
            example3 = self._makeExample(age=4.0,
                                         language='english',
                                         label=1.0)
            example4 = self._makeExample(age=5.0,
                                         language='chinese',
                                         label=0.0)

            metrics, plots = (
                pipeline
                | 'Create' >> beam.Create([
                    example1.SerializeToString(),
                    example2.SerializeToString(),
                    example3.SerializeToString(),
                    example4.SerializeToString()
                ])
                | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
                | 'Extract' >> tfma_unit.Extract(extractors=extractors)  # pylint: disable=no-value-for-parameter
                | 'ComputeMetricsAndPlots' >>
                metrics_and_plots_evaluator.ComputeMetricsAndPlots(
                    eval_shared_model=eval_shared_model))

            def check_result(got):
                try:
                    self.assertEqual(1, len(got), 'got: %s' % got)
                    (slice_key, value) = got[0]
                    self.assertEqual((), slice_key)
                    self.assertDictElementsAlmostEqual(
                        got_values_dict=value,
                        expected_values_dict={
                            'accuracy': 1.0,
                            'label/mean': 0.5,
                            'my_mean_age': 3.75,
                            'my_mean_age_times_label': 1.75,
                            'added_example_count': 4.0,
                            'py_func_label_sum': 2.0,
                            metric_keys.EXAMPLE_COUNT: 4.0,
                            metric_keys.EXAMPLE_WEIGHT: 15.0
                        })
                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(metrics, check_result, label='metrics')
            util.assert_that(plots, util.is_empty(), label='plots')
  def testWriteMetricsAndPlots(self):
    metrics_file = os.path.join(self._getTempDir(), 'metrics')
    plots_file = os.path.join(self._getTempDir(), 'plots')
    temp_eval_export_dir = os.path.join(self._getTempDir(), 'eval_export_dir')

    _, eval_export_dir = (
        fixed_prediction_estimator.simple_fixed_prediction_estimator(
            None, temp_eval_export_dir))
    eval_config = config.EvalConfig(
        model_specs=[config.ModelSpec()],
        options=config.Options(
            disabled_outputs={'values': ['eval_config.json']}))
    eval_shared_model = self.createTestEvalSharedModel(
        eval_saved_model_path=eval_export_dir,
        add_metrics_callbacks=[
            post_export_metrics.example_count(),
            post_export_metrics.calibration_plot_and_prediction_histogram(
                num_buckets=2)
        ])
    extractors = [
        predict_extractor.PredictExtractor(eval_shared_model),
        slice_key_extractor.SliceKeyExtractor()
    ]
    evaluators = [
        metrics_and_plots_evaluator.MetricsAndPlotsEvaluator(eval_shared_model)
    ]
    output_paths = {
        constants.METRICS_KEY: metrics_file,
        constants.PLOTS_KEY: plots_file
    }
    writers = [
        metrics_plots_and_validations_writer.MetricsPlotsAndValidationsWriter(
            output_paths, eval_shared_model.add_metrics_callbacks)
    ]

    with beam.Pipeline() as pipeline:
      example1 = self._makeExample(prediction=0.0, label=1.0)
      example2 = self._makeExample(prediction=1.0, label=1.0)

      # pylint: disable=no-value-for-parameter
      _ = (
          pipeline
          | 'Create' >> beam.Create([
              example1.SerializeToString(),
              example2.SerializeToString(),
          ])
          | 'ExtractEvaluateAndWriteResults' >>
          model_eval_lib.ExtractEvaluateAndWriteResults(
              eval_config=eval_config,
              eval_shared_model=eval_shared_model,
              extractors=extractors,
              evaluators=evaluators,
              writers=writers))
      # pylint: enable=no-value-for-parameter

    expected_metrics_for_slice = text_format.Parse(
        """
        slice_key {}
        metrics {
          key: "average_loss"
          value {
            double_value {
              value: 0.5
            }
          }
        }
        metrics {
          key: "post_export_metrics/example_count"
          value {
            double_value {
              value: 2.0
            }
          }
        }
        """, metrics_for_slice_pb2.MetricsForSlice())

    metric_records = []
    for record in tf.compat.v1.python_io.tf_record_iterator(metrics_file):
      metric_records.append(
          metrics_for_slice_pb2.MetricsForSlice.FromString(record))
    self.assertEqual(1, len(metric_records), 'metrics: %s' % metric_records)
    self.assertProtoEquals(expected_metrics_for_slice, metric_records[0])

    expected_plots_for_slice = text_format.Parse(
        """
      slice_key {}
      plots {
        key: "post_export_metrics"
        value {
          calibration_histogram_buckets {
            buckets {
              lower_threshold_inclusive: -inf
              num_weighted_examples {}
              total_weighted_label {}
              total_weighted_refined_prediction {}
            }
            buckets {
              upper_threshold_exclusive: 0.5
              num_weighted_examples {
                value: 1.0
              }
              total_weighted_label {
                value: 1.0
              }
              total_weighted_refined_prediction {}
            }
            buckets {
              lower_threshold_inclusive: 0.5
              upper_threshold_exclusive: 1.0
              num_weighted_examples {
              }
              total_weighted_label {}
              total_weighted_refined_prediction {}
            }
            buckets {
              lower_threshold_inclusive: 1.0
              upper_threshold_exclusive: inf
              num_weighted_examples {
                value: 1.0
              }
              total_weighted_label {
                value: 1.0
              }
              total_weighted_refined_prediction {
                value: 1.0
              }
            }
         }
        }
      }
    """, metrics_for_slice_pb2.PlotsForSlice())

    plot_records = []
    for record in tf.compat.v1.python_io.tf_record_iterator(plots_file):
      plot_records.append(
          metrics_for_slice_pb2.PlotsForSlice.FromString(record))
    self.assertEqual(1, len(plot_records), 'plots: %s' % plot_records)
    self.assertProtoEquals(expected_plots_for_slice, plot_records[0])
Example 11
def default_eval_shared_model(
    eval_saved_model_path: Text,
    add_metrics_callbacks: Optional[List[types.AddMetricsCallbackType]] = None,
    include_default_metrics: Optional[bool] = True,
    example_weight_key: Optional[Union[Text, Dict[Text, Text]]] = None,
    additional_fetches: Optional[List[Text]] = None,
    blacklist_feature_fetches: Optional[List[Text]] = None,
    tags: Optional[List[Text]] = None,
    eval_config: Optional[config.EvalConfig] = None) -> types.EvalSharedModel:
  """Returns default EvalSharedModel.

  Args:
    eval_saved_model_path: Path to EvalSavedModel.
    add_metrics_callbacks: Optional list of callbacks for adding additional
      metrics to the graph (see EvalSharedModel for more information on how to
      configure additional metrics). Metrics for example count and example
      weights will be added automatically.
    include_default_metrics: True to include the default metrics that are part
      of the saved model graph during evaluation. Note that
      eval_config.options.include_default_metrics must also be true.
    example_weight_key: Example weight key (single-output model) or dict of
      example weight keys (multi-output model) keyed by output name.
    additional_fetches: Prefixes of additional tensors stored in
      signature_def.inputs that should be fetched at prediction time. The
      "features" and "labels" tensors are handled automatically and should not
      be included.
    blacklist_feature_fetches: List of tensor names in the features dictionary
      which should be excluded from the fetches request. This is useful in
      scenarios where features are large (e.g. images) and can lead to excessive
      memory use if stored.
    tags: Model tags (e.g. 'serve' for serving or 'eval' for EvalSavedModel).
    eval_config: Eval config. Only used for setting default tags.
  """
  if tags is None:
    if eval_config:
      # Default to serving unless all the signature_names are eval. We do not
      # support running with a mixture of eval and non-eval tags.
      signatures = [s.signature_name for s in eval_config.model_specs]
      if eval_constants.EVAL_TAG in signatures:
        if not all(s == eval_constants.EVAL_TAG for s in signatures):
          tf.compat.v1.logging.warning(
              'mixture of eval and non-eval signatures used: '
              'eval_config={}'.format(eval_config))
        tags = [eval_constants.EVAL_TAG]
      else:
        tags = [tf.saved_model.SERVING]
    else:
      tags = [eval_constants.EVAL_TAG]

  # Backwards compatibility for legacy add_metrics_callbacks implementation.
  if tags == [eval_constants.EVAL_TAG]:
    # PyType doesn't know about the magic exports we do in post_export_metrics.
    # Additionally, the lines seem to get reordered in compilation, so we can't
    # just put the disable-attr on the add_metrics_callbacks lines.
    # pytype: disable=module-attr
    if not add_metrics_callbacks:
      add_metrics_callbacks = []
    # Always compute example weight and example count.
    example_count_callback = post_export_metrics.example_count()
    add_metrics_callbacks.append(example_count_callback)
    if example_weight_key:
      if isinstance(example_weight_key, dict):
        for output_name, key in example_weight_key.items():
          example_weight_callback = post_export_metrics.example_weight(
              key, metric_tag=output_name)
          add_metrics_callbacks.append(example_weight_callback)
      else:
        example_weight_callback = post_export_metrics.example_weight(
            example_weight_key)
        add_metrics_callbacks.append(example_weight_callback)
    # pytype: enable=module-attr

  return types.EvalSharedModel(
      model_path=eval_saved_model_path,
      add_metrics_callbacks=add_metrics_callbacks,
      include_default_metrics=include_default_metrics,
      example_weight_key=example_weight_key,
      additional_fetches=additional_fetches,
      model_loader=types.ModelLoader(
          tags=tags,
          construct_fn=model_util.model_construct_fn(
              eval_saved_model_path=eval_saved_model_path,
              add_metrics_callbacks=add_metrics_callbacks,
              include_default_metrics=include_default_metrics,
              additional_fetches=additional_fetches,
              blacklist_feature_fetches=blacklist_feature_fetches,
              tags=tags)))
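# A hedged sketch of the multi-output and tag handling added in this version:
# example_weight_key may be a dict keyed by output name, and tags can force the
# legacy EvalSavedModel loading path (assuming eval_constants.EVAL_TAG is the
# literal string 'eval'). Output names, weight columns, and the path are
# hypothetical.
import tensorflow_model_analysis as tfma

eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path='/tmp/eval_saved_model',    # hypothetical path
    example_weight_key={'output_a': 'weight_a',       # hypothetical output names
                        'output_b': 'weight_b'},      # and weight columns
    tags=['eval'])                                     # assumed EvalSavedModel tag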
def default_eval_shared_model(
    eval_saved_model_path: Text,
    add_metrics_callbacks: Optional[List[types.AddMetricsCallbackType]] = None,
    include_default_metrics: Optional[bool] = True,
    example_weight_key: Optional[Union[Text, Dict[Text, Text]]] = None,
    additional_fetches: Optional[List[Text]] = None,
    blacklist_feature_fetches: Optional[List[Text]] = None
) -> types.EvalSharedModel:
  """Returns default EvalSharedModel.

  Args:
    eval_saved_model_path: Path to EvalSavedModel.
    add_metrics_callbacks: Optional list of callbacks for adding additional
      metrics to the graph (see EvalSharedModel for more information on how to
      configure additional metrics). Metrics for example count and example
      weights will be added automatically.
    include_default_metrics: True to include the default metrics that are part
      of the saved model graph during evaluation.
    example_weight_key: Example weight key (single-output model) or dict of
      example weight keys (multi-output model) keyed by output name.
    additional_fetches: Prefixes of additional tensors stored in
      signature_def.inputs that should be fetched at prediction time. The
      "features" and "labels" tensors are handled automatically and should not
      be included.
    blacklist_feature_fetches: List of tensor names in the features dictionary
      which should be excluded from the fetches request. This is useful in
      scenarios where features are large (e.g. images) and can lead to excessive
      memory use if stored.
  """
  # Always compute example weight and example count.
  # PyType doesn't know about the magic exports we do in post_export_metrics.
  # Additionally, the lines seem to get reordered in compilation, so we can't
  # just put the disable-attr on the add_metrics_callbacks lines.
  # pytype: disable=module-attr
  if not add_metrics_callbacks:
    add_metrics_callbacks = []
  example_count_callback = post_export_metrics.example_count()
  add_metrics_callbacks.append(example_count_callback)
  if example_weight_key:
    if isinstance(example_weight_key, dict):
      for output_name, key in example_weight_key.items():
        example_weight_callback = post_export_metrics.example_weight(
            key, metric_tag=output_name)
        add_metrics_callbacks.append(example_weight_callback)
    else:
      example_weight_callback = post_export_metrics.example_weight(
          example_weight_key)
      add_metrics_callbacks.append(example_weight_callback)
  # pytype: enable=module-attr

  return types.EvalSharedModel(
      model_path=eval_saved_model_path,
      add_metrics_callbacks=add_metrics_callbacks,
      include_default_metrics=include_default_metrics,
      example_weight_key=example_weight_key,
      additional_fetches=additional_fetches,
      construct_fn=dofn.make_construct_fn(
          eval_saved_model_path,
          add_metrics_callbacks,
          include_default_metrics,
          additional_fetches=additional_fetches,
          blacklist_feature_fetches=blacklist_feature_fetches))
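# A hedged sketch of blacklist_feature_fetches: large features (e.g. raw image
# bytes) can be dropped from the fetched features dict to limit memory use. The
# export path and feature name are hypothetical.
import tensorflow_model_analysis as tfma

eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path='/tmp/eval_saved_model',    # hypothetical path
    example_weight_key='age',                         # hypothetical weight column
    blacklist_feature_fetches=['image/encoded'])      # hypothetical large feature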