Esempio n. 1
0
    def benchmarkMiniPipelineBatched(self):
        """Benchmark a batched "mini" TFMA run: predict, slice, compute metrics.

        Assembles a Beam pipeline that feeds up to MAX_NUM_EXAMPLES serialized
        records through TFXIO batching and then through the batched input,
        batched predict, unbatch, slice-key and metrics/plots/validations
        stages. Only running the pipeline to completion is timed; the
        `_init_model()` call and the pipeline construction happen before the
        clock starts. The wall time is handed to `report_benchmark` together
        with the dataset's example count.
        """
        self._init_model()
        pipeline = self._create_beam_pipeline()

        schema = benchmark_utils.read_schema(
            self._dataset.tf_metadata_schema_path())
        tfx_io = test_util.InMemoryTFExampleRecord(
            schema=schema,
            raw_record_column_name=constants.ARROW_INPUT_COLUMN)

        # Records are read without deserialization and handed to the TFXIO
        # source as-is.
        serialized_records = self._dataset.read_raw_dataset(
            deserialize=False, limit=MAX_NUM_EXAMPLES)
        extracts = (
            pipeline
            | "Examples" >> beam.Create(serialized_records)
            | "BatchExamples" >> tfx_io.BeamSource()
            | "InputsToExtracts" >> tfma.BatchedInputsToExtracts())

        # Each stage is built and applied in turn so the pipeline graph is
        # assembled in the same order as a single chained expression.
        input_stage = batched_input_extractor.BatchedInputExtractor(
            eval_config=self._eval_config)
        extracts = (extracts
                    | "BatchedInputExtractor" >> input_stage.ptransform)

        predict_stage = batched_predict_extractor_v2.BatchedPredictExtractor(
            eval_config=self._eval_config,
            eval_shared_model=self._eval_shared_model)
        extracts = (extracts
                    | "V2BatchedPredictExtractor" >> predict_stage.ptransform)

        unbatch_stage = unbatch_extractor.UnbatchExtractor()
        extracts = extracts | "UnbatchExtractor" >> unbatch_stage.ptransform

        slice_stage = tfma.extractors.SliceKeyExtractor()
        extracts = extracts | "SliceKeyExtractor" >> slice_stage.ptransform

        evaluator = (metrics_plots_and_validations_evaluator.
                     MetricsPlotsAndValidationsEvaluator(
                         eval_config=self._eval_config,
                         eval_shared_model=self._eval_shared_model))
        # The evaluator output is never inspected; only the run time matters.
        _ = extracts | "V2ComputeMetricsAndPlots" >> evaluator.ptransform

        started_at = time.time()
        pipeline.run().wait_until_finish()
        elapsed = time.time() - started_at

        num_examples = self._dataset.num_examples(limit=MAX_NUM_EXAMPLES)
        self.report_benchmark(iters=1,
                              wall_time=elapsed,
                              extras={"num_examples": num_examples})
    def testBatchSizeLimit(self):
        """Checks prediction over four batches that each hold one example.

        Four identical examples are fed through TFXIO with `batch_size=1`,
        then through the batched input and predict extractors. The pipeline
        must emit four extracts, each containing
        `constants.BATCHED_PREDICTIONS_KEY`.
        """
        _, export_dir = (
            batch_size_limited_classifier.simple_batch_size_limited_classifier(
                None, self._getExportDir()))
        shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])
        eval_config = config.EvalConfig(model_specs=[config.ModelSpec()])

        schema = text_format.Parse(
            """
        feature {
          name: "classes"
          type: BYTES
        }
        feature {
          name: "scores"
          type: FLOAT
        }
        feature {
          name: "labels"
          type: BYTES
        }
        """, schema_pb2.Schema())
        tfx_io = test_util.InMemoryTFExampleRecord(
            schema=schema, raw_record_column_name=constants.BATCHED_INPUT_KEY)
        adapter_config = tensor_adapter.TensorAdapterConfig(
            arrow_schema=tfx_io.ArrowSchema(),
            tensor_representations=tfx_io.TensorRepresentations())

        input_stage = batched_input_extractor.BatchedInputExtractor(
            eval_config)
        predict_stage = batched_predict_extractor_v2.BatchedPredictExtractor(
            eval_config=eval_config,
            eval_shared_model=shared_model,
            tensor_adapter_config=adapter_config)

        examples = [
            self._makeExample(classes='first', scores=0.0, labels='third')
            for _ in range(4)
        ]

        with beam.Pipeline() as pipeline:
            serialized = [e.SerializeToString() for e in examples]
            batches = (
                pipeline
                | 'Create' >> beam.Create(serialized, reshuffle=False)
                # batch_size=1 yields one batch per example.
                | 'BatchExamples' >> tfx_io.BeamSource(batch_size=1))
            extracts = (
                batches
                | 'InputsToExtracts' >>
                model_eval_lib.BatchedInputsToExtracts())
            extracts = extracts | input_stage.stage_name >> input_stage.ptransform
            predictions = (
                extracts
                | predict_stage.stage_name >> predict_stage.ptransform)

            def verify_extracts(got):
                try:
                    self.assertLen(got, 4)
                    # Prediction values are not checked, only that the key
                    # is present in every extract.
                    for extract in got:
                        self.assertIn(constants.BATCHED_PREDICTIONS_KEY,
                                      extract)

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(predictions, verify_extracts, label='result')
    def testPredictExtractorWithSequentialKerasModel(self):
        """Checks batched prediction with a saved Sequential Keras model.

        A one-layer Sequential model is fitted for a single step, saved in
        the 'tf' format and wrapped as an eval shared model. Two examples are
        sent as one batch (`batch_size=2`) through the batched input and
        predict extractors; the single resulting extract must contain
        `constants.BATCHED_PREDICTIONS_KEY`. The example feature is named
        'test' while the model input is expected to be 'test_input', so the
        test also covers that name mismatch.
        """
        # Note that the input will be called 'test_input'
        model = tf.keras.models.Sequential([
            tf.keras.layers.Dense(1,
                                  activation=tf.nn.sigmoid,
                                  input_shape=(2, ),
                                  name='test')
        ])
        # `learning_rate` is the supported argument name; the `lr` alias is
        # deprecated and rejected by newer Keras releases.
        model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=.001),
                      loss=tf.keras.losses.binary_crossentropy,
                      metrics=['accuracy'])

        train_features = {'test_input': [[0.0, 0.0], [1.0, 1.0]]}
        labels = [[1], [0]]
        example_weights = [1.0, 0.5]
        dataset = tf.data.Dataset.from_tensor_slices(
            (train_features, labels, example_weights))
        dataset = dataset.shuffle(buffer_size=1).repeat().batch(2)
        model.fit(dataset, steps_per_epoch=1)

        export_dir = self._getExportDir()
        model.save(export_dir, save_format='tf')

        eval_config = config.EvalConfig(model_specs=[config.ModelSpec()])
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])
        schema = text_format.Parse(
            """
        tensor_representation_group {
          key: ""
          value {
            tensor_representation {
              key: "test"
              value {
                dense_tensor {
                  column_name: "test"
                  shape { dim { size: 2 } }
                }
              }
            }
          }
        }
        feature {
          name: "test"
          type: FLOAT
        }
        feature {
          name: "non_model_feature"
          type: INT
        }
        """, schema_pb2.Schema())
        tfx_io = test_util.InMemoryTFExampleRecord(
            schema=schema, raw_record_column_name=constants.BATCHED_INPUT_KEY)
        tensor_adapter_config = tensor_adapter.TensorAdapterConfig(
            arrow_schema=tfx_io.ArrowSchema(),
            tensor_representations=tfx_io.TensorRepresentations())
        input_extractor = batched_input_extractor.BatchedInputExtractor(
            eval_config)
        predict_extractor = batched_predict_extractor_v2.BatchedPredictExtractor(
            eval_config=eval_config,
            eval_shared_model=eval_shared_model,
            tensor_adapter_config=tensor_adapter_config)

        # Notice that the features are 'test' but the model expects 'test_input'.
        # This tests that the PredictExtractor properly handles this case.
        examples = [
            self._makeExample(
                test=[0.0,
                      0.0], non_model_feature=0),  # should be ignored by model
            self._makeExample(
                test=[1.0,
                      1.0], non_model_feature=1),  # should be ignored by model
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            result = (
                pipeline
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples], reshuffle=False)
                | 'BatchExamples' >> tfx_io.BeamSource(batch_size=2)
                |
                'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts()
                | input_extractor.stage_name >> input_extractor.ptransform
                | predict_extractor.stage_name >> predict_extractor.ptransform)

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    # Both examples travel in one batch, so one extract.
                    self.assertLen(got, 1)
                    # We can't verify the actual predictions, but we can verify the keys.
                    for item in got:
                        self.assertIn(constants.BATCHED_PREDICTIONS_KEY, item)

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result, label='result')
    def testPredictExtractorWithRegressionModel(self):
        """Checks batched predictions from a fixed-prediction estimator.

        Three examples are sent as one batch (`batch_size=3`) through the
        batched input and predict extractors. The single resulting extract
        must hold `constants.BATCHED_PREDICTIONS_KEY`, and its value is
        compared against the `prediction` feature values of the examples.
        """
        export_dir, _ = (fixed_prediction_estimator_extra_fields.
                         simple_fixed_prediction_estimator_extra_fields(
                             self._getExportDir(), None))

        eval_config = config.EvalConfig(model_specs=[config.ModelSpec()])
        shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])

        schema = text_format.Parse(
            """
        feature {
          name: "prediction"
          type: FLOAT
        }
        feature {
          name: "label"
          type: FLOAT
        }
        feature {
          name: "fixed_int"
          type: INT
        }
        feature {
          name: "fixed_float"
          type: FLOAT
        }
        feature {
          name: "fixed_string"
          type: BYTES
        }
        """, schema_pb2.Schema())
        tfx_io = test_util.InMemoryTFExampleRecord(
            schema=schema, raw_record_column_name=constants.BATCHED_INPUT_KEY)
        adapter_config = tensor_adapter.TensorAdapterConfig(
            arrow_schema=tfx_io.ArrowSchema(),
            tensor_representations=tfx_io.TensorRepresentations())

        input_stage = batched_input_extractor.BatchedInputExtractor(
            eval_config)
        predict_stage = batched_predict_extractor_v2.BatchedPredictExtractor(
            eval_config=eval_config,
            eval_shared_model=shared_model,
            tensor_adapter_config=adapter_config)

        # (prediction, label, fixed_int, fixed_string); fixed_float is
        # always 1.0.
        rows = (
            (0.2, 1.0, 1, 'fixed_string1'),
            (0.8, 0.0, 1, 'fixed_string2'),
            (0.5, 0.0, 2, 'fixed_string3'),
        )
        examples = [
            self._makeExample(prediction=prediction,
                              label=label,
                              fixed_int=fixed_int,
                              fixed_float=1.0,
                              fixed_string=fixed_string)
            for prediction, label, fixed_int, fixed_string in rows
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            serialized = [e.SerializeToString() for e in examples]
            batches = (
                pipeline
                | 'Create' >> beam.Create(serialized, reshuffle=False)
                | 'BatchExamples' >> tfx_io.BeamSource(batch_size=3))
            extracts = (
                batches
                | 'InputsToExtracts' >>
                model_eval_lib.BatchedInputsToExtracts())
            extracts = extracts | input_stage.stage_name >> input_stage.ptransform
            predictions = (
                extracts
                | predict_stage.stage_name >> predict_stage.ptransform)
            # pylint: enable=no-value-for-parameter

            def verify_batch(got):
                try:
                    self.assertLen(got, 1)
                    extract = got[0]
                    self.assertIn(constants.BATCHED_PREDICTIONS_KEY, extract)
                    # NOTE(review): assertAlmostEqual is handed sequences
                    # here; confirm the test base class supports that, or
                    # whether an element-wise closeness check was intended.
                    self.assertAlmostEqual(
                        extract[constants.BATCHED_PREDICTIONS_KEY],
                        [0.2, 0.8, 0.5])

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(predictions, verify_batch, label='result')
    def testPredictExtractorWithMultiModels(self):
        """Checks batched prediction when two named models are evaluated.

        Two multi-head models are exported and registered as 'model1' and
        'model2'. Four examples are sent as one batch (`batch_size=4`)
        through the batched input and predict extractors. The single
        resulting extract must contain `constants.BATCHED_PREDICTIONS_KEY`,
        and every prediction in it must carry, for each model, the
        '<head>/<key>' outputs for the three heads.
        """
        temp_export_dir = self._getExportDir()
        export_dir1, _ = multi_head.simple_multi_head(temp_export_dir, None)
        export_dir2, _ = multi_head.simple_multi_head(temp_export_dir, None)

        eval_config = config.EvalConfig(model_specs=[
            config.ModelSpec(name='model1'),
            config.ModelSpec(name='model2')
        ])
        eval_shared_model1 = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir1, tags=[tf.saved_model.SERVING])
        eval_shared_model2 = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir2, tags=[tf.saved_model.SERVING])
        # The schema's feature names must match the features set on the
        # examples below ('language' was previously misspelled here).
        schema = text_format.Parse(
            """
        feature {
          name: "age"
          type: FLOAT
        }
        feature {
          name: "language"
          type: BYTES
        }
        feature {
          name: "english_label"
          type: FLOAT
        }
        feature {
          name: "chinese_label"
          type: FLOAT
        }
        feature {
          name: "other_label"
          type: FLOAT
        }
        """, schema_pb2.Schema())
        tfx_io = test_util.InMemoryTFExampleRecord(
            schema=schema, raw_record_column_name=constants.BATCHED_INPUT_KEY)
        tensor_adapter_config = tensor_adapter.TensorAdapterConfig(
            arrow_schema=tfx_io.ArrowSchema(),
            tensor_representations=tfx_io.TensorRepresentations())
        input_extractor = batched_input_extractor.BatchedInputExtractor(
            eval_config)
        predict_extractor = batched_predict_extractor_v2.BatchedPredictExtractor(
            eval_config=eval_config,
            eval_shared_model={
                'model1': eval_shared_model1,
                'model2': eval_shared_model2
            },
            tensor_adapter_config=tensor_adapter_config)

        examples = [
            self._makeExample(age=1.0,
                              language='english',
                              english_label=1.0,
                              chinese_label=0.0,
                              other_label=0.0),
            self._makeExample(age=1.0,
                              language='chinese',
                              english_label=0.0,
                              chinese_label=1.0,
                              other_label=0.0),
            self._makeExample(age=2.0,
                              language='english',
                              english_label=1.0,
                              chinese_label=0.0,
                              other_label=0.0),
            self._makeExample(age=2.0,
                              language='other',
                              english_label=0.0,
                              chinese_label=1.0,
                              other_label=1.0)
        ]

        # Output keys expected under each model, in head-major order.
        expected_output_keys = [
            output_name + '/' + pred_key
            for output_name in ('chinese_head', 'english_head', 'other_head')
            for pred_key in ('logistic', 'probabilities', 'all_classes')
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            result = (
                pipeline
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples], reshuffle=False)
                | 'BatchExamples' >> tfx_io.BeamSource(batch_size=4)
                |
                'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts()
                | input_extractor.stage_name >> input_extractor.ptransform
                | predict_extractor.stage_name >> predict_extractor.ptransform)

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    # All four examples travel in one batch, so one extract.
                    self.assertLen(got, 1)
                    for item in got:
                        # We can't verify the actual predictions, but we can verify the keys
                        self.assertIn(constants.BATCHED_PREDICTIONS_KEY, item)
                        for pred in item[constants.BATCHED_PREDICTIONS_KEY]:
                            for model_name in ('model1', 'model2'):
                                self.assertIn(model_name, pred)
                                for output_key in expected_output_keys:
                                    self.assertIn(output_key,
                                                  pred[model_name])

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result, label='result')
    def testWriteValidationResults(self):
        """Checks that failed threshold validations are written and reloaded.

        A candidate model (`mul=0`) and a baseline model (`mul=1`) are
        evaluated on two weighted examples under four metric thresholds.
        Three thresholds are expected to fail (WeightedExampleCount,
        ExampleCount, MeanLabel) and one to pass (MeanPrediction). The
        validation result is written by the metrics/plots/validations
        writer, loaded back from the output directory, and compared with the
        expected failures.
        """
        model_dir = self._getExportDir()
        baseline_dir = self._getBaselineDir()
        candidate_model = self._build_keras_model(model_dir, mul=0)
        baseline_model = self._build_keras_model(baseline_dir, mul=1)
        validations_file = os.path.join(self._getTempDir(),
                                        constants.VALIDATIONS_KEY)

        schema = text_format.Parse(
            """
        tensor_representation_group {
          key: ""
          value {
            tensor_representation {
              key: "input"
              value {
                dense_tensor {
                  column_name: "input"
                  shape { dim { size: 1 } }
                }
              }
            }
          }
        }
        feature {
          name: "input"
          type: FLOAT
        }
        feature {
          name: "label"
          type: FLOAT
        }
        feature {
          name: "example_weight"
          type: FLOAT
        }
        feature {
          name: "extra_feature"
          type: BYTES
        }
        """, schema_pb2.Schema())
        tfx_io = test_util.InMemoryTFExampleRecord(
            schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN)
        adapter_config = tensor_adapter.TensorAdapterConfig(
            arrow_schema=tfx_io.ArrowSchema(),
            tensor_representations=tfx_io.TensorRepresentations())

        # (input, label, example_weight) per example.
        examples = [
            self._makeExample(input=input_value,
                              label=label,
                              example_weight=weight,
                              extra_feature='non_model_feature')
            for input_value, label, weight in ((0.0, 1.0, 1.0),
                                               (1.0, 0.0, 0.5))
        ]

        # 1.5 < 1, NOT OK.
        weighted_example_count = config.MetricConfig(
            class_name='WeightedExampleCount',
            threshold=config.MetricThreshold(
                value_threshold=config.GenericValueThreshold(
                    upper_bound={'value': 1})))
        # 2 > 10, NOT OK.
        example_count = config.MetricConfig(
            class_name='ExampleCount',
            threshold=config.MetricThreshold(
                value_threshold=config.GenericValueThreshold(
                    lower_bound={'value': 10})))
        # 0 > 0 and 0 > 0%?: NOT OK.
        mean_label = config.MetricConfig(
            class_name='MeanLabel',
            threshold=config.MetricThreshold(
                change_threshold=config.GenericChangeThreshold(
                    direction=config.MetricDirection.HIGHER_IS_BETTER,
                    relative={'value': 0},
                    absolute={'value': 0})))
        # MeanPrediction = (0+0)/(1+0.5) = 0
        # -.01 < 0 < .01, OK.
        # Diff% = -.333/.333 = -100% < -99%, OK.
        # Diff = 0 - .333 = -.333 < 0, OK.
        mean_prediction = config.MetricConfig(
            class_name='MeanPrediction',
            threshold=config.MetricThreshold(
                value_threshold=config.GenericValueThreshold(
                    upper_bound={'value': .01},
                    lower_bound={'value': -.01}),
                change_threshold=config.GenericChangeThreshold(
                    direction=config.MetricDirection.LOWER_IS_BETTER,
                    relative={'value': -.99},
                    absolute={'value': 0})))

        eval_config = config.EvalConfig(
            model_specs=[
                config.ModelSpec(name='candidate',
                                 label_key='label',
                                 example_weight_key='example_weight'),
                config.ModelSpec(name='baseline',
                                 label_key='label',
                                 example_weight_key='example_weight',
                                 is_baseline=True)
            ],
            slicing_specs=[config.SlicingSpec()],
            metrics_specs=[
                config.MetricsSpec(
                    metrics=[
                        weighted_example_count,
                        example_count,
                        mean_label,
                        mean_prediction,
                    ],
                    model_names=['candidate', 'baseline']),
            ],
            options=config.Options(
                disabled_outputs={'values': ['eval_config.json']}),
        )

        slice_spec = [
            slicer.SingleSliceSpec(spec=spec)
            for spec in eval_config.slicing_specs
        ]
        models_by_name = {
            'candidate': candidate_model,
            'baseline': baseline_model
        }
        extractors = [
            batched_input_extractor.BatchedInputExtractor(eval_config),
            batched_predict_extractor_v2.BatchedPredictExtractor(
                eval_shared_model=models_by_name,
                eval_config=eval_config,
                tensor_adapter_config=adapter_config),
            unbatch_extractor.UnbatchExtractor(),
            slice_key_extractor.SliceKeyExtractor(slice_spec=slice_spec)
        ]
        evaluators = [
            metrics_and_plots_evaluator_v2.MetricsAndPlotsEvaluator(
                eval_config=eval_config, eval_shared_model=models_by_name)
        ]
        # Only the validations output is configured for writing.
        output_paths = {constants.VALIDATIONS_KEY: validations_file}
        writers = [
            metrics_plots_and_validations_writer.
            MetricsPlotsAndValidationsWriter(output_paths,
                                             add_metrics_callbacks=[])
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            serialized = [e.SerializeToString() for e in examples]
            extracts = (
                pipeline
                | 'Create' >> beam.Create(serialized)
                | 'BatchExamples' >> tfx_io.BeamSource()
                | 'InputsToExtracts' >>
                model_eval_lib.BatchedInputsToExtracts())
            evaluations = (
                extracts
                | 'ExtractEvaluate' >> model_eval_lib.ExtractAndEvaluate(
                    extractors=extractors, evaluators=evaluators))
            _ = (evaluations
                 | 'WriteResults' >>
                 model_eval_lib.WriteResults(writers=writers))
            # pylint: enable=no-value-for-parameter

        # The loader is pointed at the directory holding the validations
        # file, not at the file itself.
        validation_result = model_eval_lib.load_validation_result(
            os.path.dirname(validations_file))

        def parse_failure(text):
            # Parse one expected failure from its text-format proto.
            return text_format.Parse(
                text, validation_result_pb2.ValidationFailure())

        expected_validations = [
            parse_failure(
                """
            metric_key {
              name: "weighted_example_count"
              model_name: "candidate"
            }
            metric_threshold {
              value_threshold {
                upper_bound {
                  value: 1.0
                }
              }
            }
            metric_value {
              double_value {
                value: 1.5
              }
            }
            """),
            parse_failure(
                """
            metric_key {
              name: "example_count"
            }
            metric_threshold {
              value_threshold {
                lower_bound {
                  value: 10.0
                }
              }
            }
            metric_value {
              double_value {
                value: 2.0
              }
            }
            """),
            parse_failure(
                """
            metric_key {
              name: "mean_label"
              model_name: "candidate"
              is_diff: true
            }
            metric_threshold {
              change_threshold {
                absolute {
                  value: 0.0
                }
                relative {
                  value: 0.0
                }
                direction: HIGHER_IS_BETTER
              }
            }
            metric_value {
              double_value {
                value: 0.0
              }
            }
            """),
        ]

        self.assertFalse(validation_result.validation_ok)
        per_slice = validation_result.metric_validations_per_slice
        self.assertLen(per_slice, 1)
        self.assertCountEqual(expected_validations, per_slice[0].failures)
Esempio n. 7
0
    def testBatchSizeLimitWithKerasModel(self):
        """Checks single-example batches against a Keras model with batch size 1.

        The model has two inputs declared with `batch_size=1`, and a Lambda
        layer that adds a constant of fixed shape (1, 2) to their
        concatenation. Four examples are fed through TFXIO with
        `batch_size=1`, then through the batched input and predict
        extractors. The pipeline must emit four extracts, each containing
        `constants.BATCHED_PREDICTIONS_KEY`.
        """
        input1 = tf.keras.layers.Input(shape=(1, ),
                                       batch_size=1,
                                       name='input1')
        input2 = tf.keras.layers.Input(shape=(1, ),
                                       batch_size=1,
                                       name='input2')

        inputs = [input1, input2]
        input_layer = tf.keras.layers.concatenate(inputs)

        def add_1(tensor):
            # The constant has a fixed (1, 2) shape and is combined with the
            # input via add_n; presumably this is what makes larger batches
            # fail - TODO confirm.
            return tf.add_n([tensor, tf.constant(1.0, shape=(1, 2))])

        assert_layer = tf.keras.layers.Lambda(add_1)(input_layer)

        model = tf.keras.models.Model(inputs, assert_layer)
        # `learning_rate` is the supported argument name; the `lr` alias is
        # deprecated and rejected by newer Keras releases.
        model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=.001),
                      loss=tf.keras.losses.binary_crossentropy,
                      metrics=['accuracy'])

        export_dir = self._getExportDir()
        model.save(export_dir, save_format='tf')

        eval_config = config.EvalConfig(model_specs=[config.ModelSpec()])
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])
        schema = text_format.Parse(
            """
        tensor_representation_group {
          key: ""
          value {
            tensor_representation {
              key: "input1"
              value {
                dense_tensor {
                  column_name: "input1"
                  shape { dim { size: 1 } }
                }
              }
            }
            tensor_representation {
              key: "input2"
              value {
                dense_tensor {
                  column_name: "input2"
                  shape { dim { size: 1 } }
                }
              }
            }
          }
        }
        feature {
          name: "input1"
          type: FLOAT
        }
        feature {
          name: "input2"
          type: FLOAT
        }
        """, schema_pb2.Schema())
        tfx_io = test_util.InMemoryTFExampleRecord(
            schema=schema, raw_record_column_name=constants.BATCHED_INPUT_KEY)
        tensor_adapter_config = tensor_adapter.TensorAdapterConfig(
            arrow_schema=tfx_io.ArrowSchema(),
            tensor_representations=tfx_io.TensorRepresentations())
        input_extractor = batched_input_extractor.BatchedInputExtractor(
            eval_config)
        predict_extractor = batched_predict_extractor_v2.BatchedPredictExtractor(
            eval_config=eval_config,
            eval_shared_model=eval_shared_model,
            tensor_adapter_config=tensor_adapter_config)

        examples = [
            self._makeExample(input1=0.0, input2=1.0) for _ in range(4)
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            predict_extracts = (
                pipeline
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples], reshuffle=False)
                # batch_size=1 yields one batch per example.
                | 'BatchExamples' >> tfx_io.BeamSource(batch_size=1)
                |
                'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts()
                | input_extractor.stage_name >> input_extractor.ptransform
                | predict_extractor.stage_name >> predict_extractor.ptransform)

            # pylint: enable=no-value-for-parameter
            def check_result(got):
                try:
                    self.assertLen(got, 4)
                    # We can't verify the actual predictions, but we can verify the keys.
                    for item in got:
                        self.assertIn(constants.BATCHED_PREDICTIONS_KEY, item)

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(predict_extracts, check_result, label='result')
Esempio n. 8
0
    def testPredictExtractorWithMultiClassModel(self):
        """Checks batched prediction with a three-class DNN classifier.

        Four examples are sent as one batch (`batch_size=4`) through the
        batched input and predict extractors. The single resulting extract
        must contain a predictions entry, and every prediction in it must
        carry the 'probabilities' and 'all_classes' keys.
        """
        temp_export_dir = self._getExportDir()
        export_dir, _ = dnn_classifier.simple_dnn_classifier(temp_export_dir,
                                                             None,
                                                             n_classes=3)

        eval_config = config.EvalConfig(model_specs=[config.ModelSpec()])
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])
        # The schema's feature names must match the features set on the
        # examples below ('language' was previously misspelled here).
        schema = text_format.Parse(
            """
        feature {
          name: "age"
          type: FLOAT
        }
        feature {
          name: "language"
          type: BYTES
        }
        feature {
          name: "label"
          type: INT
        }
        """, schema_pb2.Schema())
        tfx_io = test_util.InMemoryTFExampleRecord(
            schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN)
        tensor_adapter_config = tensor_adapter.TensorAdapterConfig(
            arrow_schema=tfx_io.ArrowSchema(),
            tensor_representations=tfx_io.TensorRepresentations())
        input_extractor = batched_input_extractor.BatchedInputExtractor(
            eval_config)
        predict_extractor = batched_predict_extractor_v2.BatchedPredictExtractor(
            eval_config=eval_config,
            eval_shared_model=eval_shared_model,
            tensor_adapter_config=tensor_adapter_config)

        examples = [
            self._makeExample(age=1.0, language='english', label=0),
            self._makeExample(age=2.0, language='chinese', label=1),
            self._makeExample(age=3.0, language='english', label=2),
            self._makeExample(age=4.0, language='chinese', label=1),
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            result = (
                pipeline
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples], reshuffle=False)
                | 'BatchExamples' >> tfx_io.BeamSource(batch_size=4)
                |
                'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts()
                | input_extractor.stage_name >> input_extractor.ptransform
                | predict_extractor.stage_name >> predict_extractor.ptransform)

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    # All four examples travel in one batch, so one extract.
                    self.assertLen(got, 1)
                    # We can't verify the actual predictions, but we can verify the keys.
                    # NOTE(review): the other batched predict tests in this
                    # file look up constants.BATCHED_PREDICTIONS_KEY; confirm
                    # that PREDICTIONS_KEY is the intended key here.
                    for item in got:
                        self.assertIn(constants.PREDICTIONS_KEY, item)
                        for pred in item[constants.PREDICTIONS_KEY]:
                            for pred_key in ('probabilities', 'all_classes'):
                                self.assertIn(pred_key, pred)

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result, label='result')