    def testLoadSavedModelDisallowsAdditionalFetchesWithFeatures(self):
        temp_eval_export_dir = self._getEvalExportDir()
        _, eval_export_dir = multi_head.simple_multi_head(
            None, temp_eval_export_dir)
        with self.assertRaisesRegexp(
                ValueError,
                'additional_fetches should not contain "features"'):
            load.EvalSavedModel(eval_export_dir,
                                additional_fetches=['features'])
    def testEvaluateWithAdditionalMetricsBasic(self):
        temp_eval_export_dir = self._getEvalExportDir()
        _, eval_export_dir = multi_head.simple_multi_head(
            None, temp_eval_export_dir)

        eval_saved_model = load.EvalSavedModel(eval_export_dir)
        _, prediction_dict, label_dict = (
            eval_saved_model.get_features_predictions_labels_dicts())
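        # Additional metric ops have to be built in the EvalSavedModel's graph
        # before they can be registered alongside the exported metrics.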
        with eval_saved_model.graph_as_default():
            metric_ops = {}
            value_op, update_op = tf.metrics.mean_absolute_error(
                label_dict['english_head'][0][0],
                prediction_dict['english_head/probabilities'][0][1])
            metric_ops['mean_absolute_error/english_head'] = (value_op,
                                                              update_op)

            value_op, update_op = tf.contrib.metrics.count(
                prediction_dict['english_head/logits'])
            metric_ops['example_count/english_head'] = (value_op, update_op)

            eval_saved_model.register_additional_metric_ops(metric_ops)

        example1 = self._makeMultiHeadExample('english')
        features_predictions_labels = self.predict_injective_single_example(
            eval_saved_model, example1.SerializeToString())
        eval_saved_model.perform_metrics_update(features_predictions_labels)

        example2 = self._makeMultiHeadExample('chinese')
        features_predictions_labels = self.predict_injective_single_example(
            eval_saved_model, example2.SerializeToString())
        eval_saved_model.perform_metrics_update(features_predictions_labels)

        metric_values = eval_saved_model.get_metric_values()

        # Check that the original metrics are still there.
        self.assertDictElementsAlmostEqual(
            metric_values, {
                'accuracy/english_head': 1.0,
                'accuracy/chinese_head': 1.0,
                'accuracy/other_head': 1.0,
                'auc/english_head': 1.0,
                'auc/chinese_head': 1.0,
                'auc/other_head': 1.0,
                'label/mean/english_head': 0.5,
                'label/mean/chinese_head': 0.5,
                'label/mean/other_head': 0.0
            })

        # Check the added metrics.
        # We don't control the trained model's weights fully, but it should
        # predict probabilities > 0.7.
        self.assertIn('mean_absolute_error/english_head', metric_values)
        self.assertLess(metric_values['mean_absolute_error/english_head'], 0.3)

        self.assertHasKeyWithValueAlmostEqual(metric_values,
                                              'example_count/english_head',
                                              2.0)
    def testGetAndSetMetricVariables(self):
        temp_eval_export_dir = self._getEvalExportDir()
        _, eval_export_dir = multi_head.simple_multi_head(
            None, temp_eval_export_dir)

        eval_saved_model = load.EvalSavedModel(eval_export_dir)
        _, prediction_dict, _ = (
            eval_saved_model.get_features_predictions_labels_dicts())
        with eval_saved_model.graph_as_default():
            metric_ops = {}
            value_op, update_op = tf.contrib.metrics.count(
                prediction_dict['english_head/logits'])
            metric_ops['example_count/english_head'] = (value_op, update_op)

            eval_saved_model.register_additional_metric_ops(metric_ops)

        example1 = self._makeMultiHeadExample('english')
        features_predictions_labels = self.predict_injective_single_example(
            eval_saved_model, example1.SerializeToString())
        eval_saved_model.perform_metrics_update(features_predictions_labels)
        metric_values = eval_saved_model.get_metric_values()
        self.assertDictElementsAlmostEqual(
            metric_values, {
                'label/mean/english_head': 1.0,
                'label/mean/chinese_head': 0.0,
                'label/mean/other_head': 0.0,
                'example_count/english_head': 1.0
            })
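        # Snapshot the metric accumulator state after the first example so it
        # can be restored further below.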
        metric_variables = eval_saved_model.get_metric_variables()

        example2 = self._makeMultiHeadExample('chinese')
        features_predictions_labels = self.predict_injective_single_example(
            eval_saved_model, example2.SerializeToString())
        eval_saved_model.perform_metrics_update(features_predictions_labels)
        metric_values = eval_saved_model.get_metric_values()
        self.assertDictElementsAlmostEqual(
            metric_values, {
                'label/mean/english_head': 0.5,
                'label/mean/chinese_head': 0.5,
                'label/mean/other_head': 0.0,
                'example_count/english_head': 2.0
            })

        # Now set metric variables to what they were after the first example.
        eval_saved_model.set_metric_variables(metric_variables)
        metric_values = eval_saved_model.get_metric_values()
        self.assertDictElementsAlmostEqual(
            metric_values, {
                'label/mean/english_head': 1.0,
                'label/mean/chinese_head': 0.0,
                'label/mean/other_head': 0.0,
                'example_count/english_head': 1.0
            })
    def testFairnessIndicatorsCounters(self):
        temp_eval_export_dir = self._getEvalExportDir()
        _, eval_export_dir = (multi_head.simple_multi_head(
            None, temp_eval_export_dir))

        examples = [
            self._makeExample(age=3.0,
                              language='english',
                              english_label=1.0,
                              chinese_label=0.0,
                              other_label=0.0),
            self._makeExample(age=3.0,
                              language='chinese',
                              english_label=0.0,
                              chinese_label=1.0,
                              other_label=0.0),
            self._makeExample(age=4.0,
                              language='english',
                              english_label=1.0,
                              chinese_label=0.0,
                              other_label=0.0),
            self._makeExample(age=5.0,
                              language='chinese',
                              english_label=0.0,
                              chinese_label=1.0,
                              other_label=0.0),
            self._makeExample(age=6.0,
                              language='chinese',
                              english_label=0.0,
                              chinese_label=1.0,
                              other_label=0.0),
        ]
        fairness_english = post_export_metrics.fairness_indicators(
            target_prediction_keys=['english_head/logistic'],
            labels_key='english_head')
        fairness_chinese = post_export_metrics.fairness_indicators(
            target_prediction_keys=['chinese_head/logistic'],
            labels_key='chinese_head')

        def check_metric_counter(result):
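            # Two fairness_indicators metrics were added (one per head), so
            # the 'metric_computed_fairness_indicators' counter should be 2.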
            metric_filter = beam.metrics.metric.MetricsFilter().with_name(
                'metric_computed_fairness_indicators')
            actual_metrics_count = result.metrics().query(
                filter=metric_filter)['counters'][0].committed
            self.assertEqual(actual_metrics_count, 2)

        self._runTestWithCustomCheck(examples,
                                     eval_export_dir, [
                                         fairness_english,
                                         fairness_chinese,
                                     ],
                                     custom_result_check=check_metric_counter)
    def testEvaluateWithOnlyAdditionalMetricsBasic(self):
        temp_eval_export_dir = self._getEvalExportDir()
        _, eval_export_dir = multi_head.simple_multi_head(
            None, temp_eval_export_dir)

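        # Load the model without its default metrics so that only the metrics
        # registered below are reported.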
        eval_saved_model = load.EvalSavedModel(eval_export_dir,
                                               include_default_metrics=False)
        _, prediction_dict, label_dict = (
            eval_saved_model.get_features_predictions_labels_dicts())
        with eval_saved_model.graph_as_default():
            metric_ops = {}
            value_op, update_op = tf.compat.v1.metrics.mean_absolute_error(
                label_dict['english_head'][0][0],
                prediction_dict['english_head/probabilities'][0][1])
            metric_ops['mean_absolute_error/english_head'] = (value_op,
                                                              update_op)

            value_op, update_op = metrics.total(
                tf.shape(input=prediction_dict['english_head/logits'])[0])
            metric_ops['example_count/english_head'] = (value_op, update_op)

            eval_saved_model.register_additional_metric_ops(metric_ops)

        example1 = self._makeMultiHeadExample('english').SerializeToString()
        example2 = self._makeMultiHeadExample('chinese').SerializeToString()
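        # Reset the accumulators, then update them with both serialized
        # examples in a single call.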
        eval_saved_model.metrics_reset_update_get_list([example1, example2])

        metric_values = eval_saved_model.get_metric_values()

        # Check that the original metrics are not there.
        self.assertNotIn('accuracy/english_head', metric_values)
        self.assertNotIn('accuracy/chinese_head', metric_values)
        self.assertNotIn('accuracy/other_head', metric_values)
        self.assertNotIn('auc/english_head', metric_values)
        self.assertNotIn('auc/chinese_head', metric_values)
        self.assertNotIn('auc/other_head', metric_values)
        self.assertNotIn('label/mean/english_head', metric_values)
        self.assertNotIn('label/mean/chinese_head', metric_values)
        self.assertNotIn('label/mean/other_head', metric_values)

        # Check the added metrics.
        # We don't control the trained model's weights fully, but it should
        # predict probabilities > 0.7.
        self.assertIn('mean_absolute_error/english_head', metric_values)
        self.assertLess(metric_values['mean_absolute_error/english_head'], 0.3)

        self.assertHasKeyWithValueAlmostEqual(metric_values,
                                              'example_count/english_head',
                                              2.0)
    def testResetMetricVariables(self):
        _, temp_eval_export_dir = self._getExportDirs()
        _, eval_export_dir = multi_head.simple_multi_head(
            None, temp_eval_export_dir)

        eval_saved_model = load.EvalSavedModel(eval_export_dir)
        _, prediction_dict, _ = (
            eval_saved_model.get_features_predictions_labels_dicts())
        with eval_saved_model.graph_as_default():
            metric_ops = {}
            value_op, update_op = tf.contrib.metrics.count(
                prediction_dict[('english_head', 'logits')])
            metric_ops['example_count/english_head'] = (value_op, update_op)

            eval_saved_model.register_additional_metric_ops(metric_ops)

        example1 = self._makeMultiHeadExample('english')
        features_predictions_labels = eval_saved_model.predict(
            example1.SerializeToString())
        eval_saved_model.perform_metrics_update(features_predictions_labels)
        metric_values = eval_saved_model.get_metric_values()
        self.assertDictElementsAlmostEqual(
            metric_values, {
                'labels/actual_label_mean/english_head': 1.0,
                'labels/actual_label_mean/chinese_head': 0.0,
                'labels/actual_label_mean/other_head': 0.0,
                'example_count/english_head': 1.0
            })
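        # Clear the accumulated metric state so the second example is measured
        # in isolation.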
        eval_saved_model.reset_metric_variables()

        example2 = self._makeMultiHeadExample('chinese')
        features_predictions_labels = eval_saved_model.predict(
            example2.SerializeToString())
        eval_saved_model.perform_metrics_update(features_predictions_labels)
        metric_values = eval_saved_model.get_metric_values()
        self.assertDictElementsAlmostEqual(
            metric_values, {
                'labels/actual_label_mean/english_head': 0.0,
                'labels/actual_label_mean/chinese_head': 1.0,
                'labels/actual_label_mean/other_head': 0.0,
                'example_count/english_head': 1.0
            })
    def testResetMetricVariables(self):
        temp_eval_export_dir = self._getEvalExportDir()
        _, eval_export_dir = multi_head.simple_multi_head(
            None, temp_eval_export_dir)

        eval_saved_model = load.EvalSavedModel(eval_export_dir)
        _, prediction_dict, _ = (
            eval_saved_model.get_features_predictions_labels_dicts())
        with eval_saved_model.graph_as_default():
            metric_ops = {}
            value_op, update_op = metrics.total(
                tf.shape(input=prediction_dict['english_head/logits'])[0])
            metric_ops['example_count/english_head'] = (value_op, update_op)

            eval_saved_model.register_additional_metric_ops(metric_ops)

        example1 = self._makeMultiHeadExample('english').SerializeToString()
        eval_saved_model.metrics_reset_update_get(example1)
        metric_values = eval_saved_model.get_metric_values()
        self.assertDictElementsAlmostEqual(
            metric_values, {
                'label/mean/english_head': 1.0,
                'label/mean/chinese_head': 0.0,
                'label/mean/other_head': 0.0,
                'example_count/english_head': 1.0
            })
        eval_saved_model.reset_metric_variables()

        example2 = self._makeMultiHeadExample('chinese').SerializeToString()
        eval_saved_model.metrics_reset_update_get(example2)
        metric_values = eval_saved_model.get_metric_values()
        self.assertDictElementsAlmostEqual(
            metric_values, {
                'label/mean/english_head': 0.0,
                'label/mean/chinese_head': 1.0,
                'label/mean/other_head': 0.0,
                'example_count/english_head': 1.0
            })
    def testMetricsResetUpdateGetList(self):
        temp_eval_export_dir = self._getEvalExportDir()
        _, eval_export_dir = multi_head.simple_multi_head(
            None, temp_eval_export_dir)

        eval_saved_model = load.EvalSavedModel(eval_export_dir)
        _, prediction_dict, _ = (
            eval_saved_model.get_features_predictions_labels_dicts())
        with eval_saved_model.graph_as_default():
            metric_ops = {}
            value_op, update_op = tf.contrib.metrics.count(
                prediction_dict['english_head/logits'])
            metric_ops['example_count/english_head'] = (value_op, update_op)

            eval_saved_model.register_additional_metric_ops(metric_ops)

        example1 = self._makeMultiHeadExample('english')
        features_predictions_labels1 = self.predict_injective_single_example(
            eval_saved_model, example1.SerializeToString())
        metric_variables1 = eval_saved_model.metrics_reset_update_get(
            features_predictions_labels1)

        example2 = self._makeMultiHeadExample('chinese')
        features_predictions_labels2 = self.predict_injective_single_example(
            eval_saved_model, example2.SerializeToString())
        metric_variables2 = eval_saved_model.metrics_reset_update_get(
            features_predictions_labels2)

        example3 = self._makeMultiHeadExample('other')
        features_predictions_labels3 = self.predict_injective_single_example(
            eval_saved_model, example3.SerializeToString())
        metric_variables3 = eval_saved_model.metrics_reset_update_get(
            features_predictions_labels3)

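        # Each metrics_reset_update_get call above returned the accumulator
        # state for a single example; restoring a state reproduces that
        # example's metrics.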
        eval_saved_model.set_metric_variables(metric_variables1)
        metric_values1 = eval_saved_model.get_metric_values()
        self.assertDictElementsAlmostEqual(
            metric_values1, {
                'label/mean/english_head': 1.0,
                'label/mean/chinese_head': 0.0,
                'label/mean/other_head': 0.0,
                'example_count/english_head': 1.0
            })

        eval_saved_model.set_metric_variables(metric_variables2)
        metric_values2 = eval_saved_model.get_metric_values()
        self.assertDictElementsAlmostEqual(
            metric_values2, {
                'label/mean/english_head': 0.0,
                'label/mean/chinese_head': 1.0,
                'label/mean/other_head': 0.0,
                'example_count/english_head': 1.0
            })

        eval_saved_model.set_metric_variables(metric_variables3)
        metric_values3 = eval_saved_model.get_metric_values()
        self.assertDictElementsAlmostEqual(
            metric_values3, {
                'label/mean/english_head': 0.0,
                'label/mean/chinese_head': 0.0,
                'label/mean/other_head': 1.0,
                'example_count/english_head': 1.0
            })

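        # Updating with all three examples in one call accumulates metrics
        # across them.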
        eval_saved_model.metrics_reset_update_get_list([
            features_predictions_labels1, features_predictions_labels2,
            features_predictions_labels3
        ])
        metric_values_combined = eval_saved_model.get_metric_values()
        self.assertDictElementsAlmostEqual(
            metric_values_combined, {
                'label/mean/english_head': 1.0 / 3.0,
                'label/mean/chinese_head': 1.0 / 3.0,
                'label/mean/other_head': 1.0 / 3.0,
                'example_count/english_head': 3.0
            })
  def testFairnessIndicatorsMultiHead(self):
    temp_eval_export_dir = self._getEvalExportDir()
    _, eval_export_dir = (
        multi_head.simple_multi_head(None, temp_eval_export_dir))

    examples = [
        self._makeExample(
            age=3.0,
            language='english',
            english_label=1.0,
            chinese_label=0.0,
            other_label=0.0),
        self._makeExample(
            age=3.0,
            language='chinese',
            english_label=0.0,
            chinese_label=1.0,
            other_label=0.0),
        self._makeExample(
            age=4.0,
            language='english',
            english_label=1.0,
            chinese_label=0.0,
            other_label=0.0),
        self._makeExample(
            age=5.0,
            language='chinese',
            english_label=0.0,
            chinese_label=1.0,
            other_label=0.0),
        self._makeExample(
            age=6.0,
            language='chinese',
            english_label=0.0,
            chinese_label=1.0,
            other_label=0.0),
    ]
    fairness_english = post_export_metrics.fairness_indicators(
        target_prediction_keys=['english_head/logistic'],
        labels_key='english_head')
    fairness_chinese = post_export_metrics.fairness_indicators(
        target_prediction_keys=['chinese_head/logistic'],
        labels_key='chinese_head')

    def check_metric_result(got):
      try:
        self.assertEqual(1, len(got), 'got: %s' % got)
        (slice_key, value) = got[0]
        self.assertEqual((), slice_key)
        expected_values_dict = {
            metric_keys.base_key(
                'english_head/logistic/[email protected]'):
                1.0,
            metric_keys.base_key(
                'chinese_head/logistic/[email protected]'):
                1.0,
        }
        self.assertDictElementsAlmostEqual(value, expected_values_dict)
      except AssertionError as err:
        raise util.BeamAssertException(err)

    self._runTestWithCustomCheck(
        examples,
        eval_export_dir, [
            fairness_english,
            fairness_chinese,
        ],
        custom_metrics_check=check_metric_result)
    def testNativeEvalSavedModelMetricComputationsWithMultiHead(self):
        temp_export_dir = self._getExportDir()
        _, export_dir = multi_head.simple_multi_head(None, temp_export_dir)

        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir)

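        # Wrap the EvalSavedModel's metrics as a single metric computation
        # (preprocessor + combiner) that can run in the Beam pipeline below.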
        computation = (
            eval_saved_model_util.metric_computations_using_eval_saved_model(
                '', eval_shared_model.model_loader)[0])

        examples = [
            self._makeExample(age=1.0,
                              language='english',
                              english_label=1.0,
                              chinese_label=0.0,
                              other_label=0.0),
            self._makeExample(age=1.0,
                              language='chinese',
                              english_label=0.0,
                              chinese_label=1.0,
                              other_label=0.0),
            self._makeExample(age=2.0,
                              language='english',
                              english_label=1.0,
                              chinese_label=0.0,
                              other_label=0.0),
            self._makeExample(age=2.0,
                              language='other',
                              english_label=0.0,
                              chinese_label=1.0,
                              other_label=1.0),
        ]

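        # Wrap each serialized example in an extract under the raw input key
        # so the preprocessor can read it.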
        extracts = []
        for e in examples:
            extracts.append({constants.INPUT_KEY: e.SerializeToString()})

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            result = (
                pipeline
                | 'Create' >> beam.Create(extracts)
                | 'Process' >> beam.ParDo(computation.preprocessor)
                | 'ToStandardMetricInputs' >> beam.Map(
                    metric_types.StandardMetricInputs)
                | 'AddSlice' >> beam.Map(lambda x: ((), x))
                | 'ComputeMetric' >> beam.CombinePerKey(computation.combiner))

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())
                    chinese_accuracy_key = metric_types.MetricKey(
                        name='accuracy',
                        output_name='chinese_head',
                        example_weighted=None)
                    chinese_mean_label_key = metric_types.MetricKey(
                        name='label/mean',
                        output_name='chinese_head',
                        example_weighted=None)
                    english_accuracy_key = metric_types.MetricKey(
                        name='accuracy',
                        output_name='english_head',
                        example_weighted=None)
                    english_mean_label_key = metric_types.MetricKey(
                        name='label/mean',
                        output_name='english_head',
                        example_weighted=None)
                    other_accuracy_key = metric_types.MetricKey(
                        name='accuracy',
                        output_name='other_head',
                        example_weighted=None)
                    other_mean_label_key = metric_types.MetricKey(
                        name='label/mean',
                        output_name='other_head',
                        example_weighted=None)
                    self.assertDictElementsAlmostEqual(
                        got_metrics, {
                            chinese_accuracy_key: 0.75,
                            chinese_mean_label_key: 0.5,
                            english_accuracy_key: 1.0,
                            english_mean_label_key: 0.5,
                            other_accuracy_key: 1.0,
                            other_mean_label_key: 0.25
                        })

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result, label='result')
    def testPredictExtractorWithMultiModels(self):
        temp_export_dir = self._getExportDir()
        export_dir1, _ = multi_head.simple_multi_head(temp_export_dir, None)
        export_dir2, _ = multi_head.simple_multi_head(temp_export_dir, None)

        eval_config = config.EvalConfig(model_specs=[
            config.ModelSpec(name='model1'),
            config.ModelSpec(name='model2')
        ])
        eval_shared_model1 = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir1, tags=[tf.saved_model.SERVING])
        eval_shared_model2 = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir2, tags=[tf.saved_model.SERVING])
        schema = text_format.Parse(
            """
        feature {
          name: "age"
          type: FLOAT
        }
        feature {
          name: "langauge"
          type: BYTES
        }
        feature {
          name: "english_label"
          type: FLOAT
        }
        feature {
          name: "chinese_label"
          type: FLOAT
        }
        feature {
          name: "other_label"
          type: FLOAT
        }
        """, schema_pb2.Schema())
        tfx_io = test_util.InMemoryTFExampleRecord(
            schema=schema, raw_record_column_name=constants.BATCHED_INPUT_KEY)
        tensor_adapter_config = tensor_adapter.TensorAdapterConfig(
            arrow_schema=tfx_io.ArrowSchema(),
            tensor_representations=tfx_io.TensorRepresentations())
        input_extractor = batched_input_extractor.BatchedInputExtractor(
            eval_config)
        predict_extractor = batched_predict_extractor_v2.BatchedPredictExtractor(
            eval_config=eval_config,
            eval_shared_model={
                'model1': eval_shared_model1,
                'model2': eval_shared_model2
            },
            tensor_adapter_config=tensor_adapter_config)

        examples = [
            self._makeExample(age=1.0,
                              language='english',
                              english_label=1.0,
                              chinese_label=0.0,
                              other_label=0.0),
            self._makeExample(age=1.0,
                              language='chinese',
                              english_label=0.0,
                              chinese_label=1.0,
                              other_label=0.0),
            self._makeExample(age=2.0,
                              language='english',
                              english_label=1.0,
                              chinese_label=0.0,
                              other_label=0.0),
            self._makeExample(age=2.0,
                              language='other',
                              english_label=0.0,
                              chinese_label=1.0,
                              other_label=1.0)
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            result = (
                pipeline
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples], reshuffle=False)
                | 'BatchExamples' >> tfx_io.BeamSource(batch_size=4)
                |
                'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts()
                | input_extractor.stage_name >> input_extractor.ptransform
                | predict_extractor.stage_name >> predict_extractor.ptransform)

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    for item in got:
                        # We can't verify the actual predictions, but we can verify the keys
                        self.assertIn(constants.BATCHED_PREDICTIONS_KEY, item)
                        for pred in item[constants.BATCHED_PREDICTIONS_KEY]:
                            for model_name in ('model1', 'model2'):
                                self.assertIn(model_name, pred)
                                for output_name in ('chinese_head',
                                                    'english_head',
                                                    'other_head'):
                                    for pred_key in ('logistic',
                                                     'probabilities',
                                                     'all_classes'):
                                        self.assertIn(
                                            output_name + '/' + pred_key,
                                            pred[model_name])

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result, label='result')
  def testPredictExtractorWithMultiModels(self):
    temp_export_dir = self._getExportDir()
    export_dir1, _ = multi_head.simple_multi_head(temp_export_dir, None)
    export_dir2, _ = multi_head.simple_multi_head(temp_export_dir, None)

    eval_config = config.EvalConfig(model_specs=[
        config.ModelSpec(name='model1'),
        config.ModelSpec(name='model2')
    ])
    eval_shared_model1 = self.createTestEvalSharedModel(
        eval_saved_model_path=export_dir1, tags=[tf.saved_model.SERVING])
    eval_shared_model2 = self.createTestEvalSharedModel(
        eval_saved_model_path=export_dir2, tags=[tf.saved_model.SERVING])
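    # A single PredictExtractor can run inference for multiple models keyed by
    # the model names from the EvalConfig.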
    predict_extractor = predict_extractor_v2.PredictExtractor(
        eval_config=eval_config,
        eval_shared_model={
            'model1': eval_shared_model1,
            'model2': eval_shared_model2
        })

    examples = [
        self._makeExample(
            age=1.0,
            language='english',
            english_label=1.0,
            chinese_label=0.0,
            other_label=0.0),
        self._makeExample(
            age=1.0,
            language='chinese',
            english_label=0.0,
            chinese_label=1.0,
            other_label=0.0),
        self._makeExample(
            age=2.0,
            language='english',
            english_label=1.0,
            chinese_label=0.0,
            other_label=0.0),
        self._makeExample(
            age=2.0,
            language='other',
            english_label=0.0,
            chinese_label=1.0,
            other_label=1.0)
    ]

    with beam.Pipeline() as pipeline:
      # pylint: disable=no-value-for-parameter
      result = (
          pipeline
          | 'Create' >> beam.Create([e.SerializeToString() for e in examples],
                                    reshuffle=False)
          | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
          | predict_extractor.stage_name >> predict_extractor.ptransform)

      # pylint: enable=no-value-for-parameter

      def check_result(got):
        try:
          self.assertLen(got, 4)
          for item in got:
            # We can't verify the actual predictions, but we can verify the keys
            self.assertIn(constants.PREDICTIONS_KEY, item)
            for model_name in ('model1', 'model2'):
              self.assertIn(model_name, item[constants.PREDICTIONS_KEY])
              for output_name in ('chinese_head', 'english_head', 'other_head'):
                for pred_key in ('logistic', 'probabilities', 'all_classes'):
                  self.assertIn(output_name + '/' + pred_key,
                                item[constants.PREDICTIONS_KEY][model_name])

        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(result, check_result, label='result')
    def testEvaluateWithMultiOutputModel(self):
        temp_export_dir = self._getExportDir()
        _, export_dir = multi_head.simple_multi_head(None, temp_export_dir)

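        # Each head gets its own label key and uses 'age' as its example
        # weight, so metrics are computed per output.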
        eval_config = config.EvalConfig(
            model_specs=[
                config.ModelSpec(location=export_dir,
                                 label_keys={
                                     'chinese_head': 'chinese_label',
                                     'english_head': 'english_label',
                                     'other_head': 'other_label'
                                 },
                                 example_weight_keys={
                                     'chinese_head': 'age',
                                     'english_head': 'age',
                                     'other_head': 'age'
                                 })
            ],
            slicing_specs=[config.SlicingSpec()],
            metrics_specs=metric_specs.specs_from_metrics({
                'chinese_head': [calibration.MeanLabel('mean_label')],
                'english_head': [calibration.MeanLabel('mean_label')],
                'other_head': [calibration.MeanLabel('mean_label')],
            }))
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])

        slice_spec = [
            slicer.SingleSliceSpec(spec=s) for s in eval_config.slicing_specs
        ]
        extractors = [
            input_extractor.InputExtractor(eval_config=eval_config),
            predict_extractor_v2.PredictExtractor(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model]),
            slice_key_extractor.SliceKeyExtractor(slice_spec=slice_spec)
        ]
        evaluators = [
            metrics_and_plots_evaluator_v2.MetricsAndPlotsEvaluator(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model])
        ]

        examples = [
            self._makeExample(age=1.0,
                              language='english',
                              english_label=1.0,
                              chinese_label=0.0,
                              other_label=0.0),
            self._makeExample(age=1.0,
                              language='chinese',
                              english_label=0.0,
                              chinese_label=1.0,
                              other_label=0.0),
            self._makeExample(age=2.0,
                              language='english',
                              english_label=1.0,
                              chinese_label=0.0,
                              other_label=0.0),
            self._makeExample(age=2.0,
                              language='other',
                              english_label=0.0,
                              chinese_label=1.0,
                              other_label=1.0),
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            metrics = (
                pipeline
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples])
                | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
                | 'ExtractAndEvaluate' >> model_eval_lib.ExtractAndEvaluate(
                    extractors=extractors, evaluators=evaluators))

            # pylint: enable=no-value-for-parameter

            def check_metrics(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())
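                    # With 'age' as the example weight, the weighted example
                    # counts and mean labels below are age-weighted sums.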
                    example_count_key = metric_types.MetricKey(
                        name='example_count')
                    chinese_weighted_example_count_key = metric_types.MetricKey(
                        name='weighted_example_count',
                        output_name='chinese_head')
                    chinese_label_key = metric_types.MetricKey(
                        name='mean_label', output_name='chinese_head')
                    english_weighted_example_count_key = metric_types.MetricKey(
                        name='weighted_example_count',
                        output_name='english_head')
                    english_label_key = metric_types.MetricKey(
                        name='mean_label', output_name='english_head')
                    other_weighted_example_count_key = metric_types.MetricKey(
                        name='weighted_example_count',
                        output_name='other_head')
                    other_label_key = metric_types.MetricKey(
                        name='mean_label', output_name='other_head')
                    self.assertDictElementsAlmostEqual(
                        got_metrics, {
                            example_count_key: 4,
                            chinese_label_key:
                                (0.0 + 1.0 + 2 * 0.0 + 2 * 1.0) /
                                (1.0 + 1.0 + 2.0 + 2.0),
                            chinese_weighted_example_count_key:
                                (1.0 + 1.0 + 2.0 + 2.0),
                            english_label_key:
                                (1.0 + 0.0 + 2 * 1.0 + 2 * 0.0) /
                                (1.0 + 1.0 + 2.0 + 2.0),
                            english_weighted_example_count_key:
                                (1.0 + 1.0 + 2.0 + 2.0),
                            other_label_key:
                                (0.0 + 0.0 + 2 * 0.0 + 2 * 1.0) /
                                (1.0 + 1.0 + 2.0 + 2.0),
                            other_weighted_example_count_key:
                                (1.0 + 1.0 + 2.0 + 2.0)
                        })

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(metrics[constants.METRICS_KEY],
                             check_metrics,
                             label='metrics')