  def testBatchSizeLimit(self):
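    """Tests the legacy (v1) predict extractor against a batch-size-limited model.

    The exported test model (simple_batch_size_limited_classifier) restricts
    the batch sizes it accepts; the extractor should still emit one prediction
    extract per input example.
    """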
    temp_eval_export_dir = self._getEvalExportDir()
    _, eval_export_dir = batch_size_limited_classifier.simple_batch_size_limited_classifier(
        None, temp_eval_export_dir)
    eval_shared_model = model_eval_lib.default_eval_shared_model(
        eval_saved_model_path=eval_export_dir)
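    # No serving tags are passed, so the model is loaded as a legacy
    # EvalSavedModel (the format the v1 predict extractor consumes).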
    with beam.Pipeline() as pipeline:
      examples = [
          self._makeExample(classes='first', scores=0.0, labels='third'),
          self._makeExample(classes='first', scores=0.0, labels='third'),
          self._makeExample(classes='first', scores=0.0, labels='third'),
          self._makeExample(classes='first', scores=0.0, labels='third'),
      ]
      serialized_examples = [e.SerializeToString() for e in examples]

      predict_extracts = (
          pipeline
          | beam.Create(serialized_examples, reshuffle=False)
          # Our diagnostic outputs pass types.Extracts throughout; however, our
          # aggregating functions do not use this interface.
          | beam.Map(lambda x: {constants.INPUT_KEY: x})
          | 'Predict' >> predict_extractor._TFMAPredict(
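              # The empty-string key identifies the default (unnamed) model.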
              eval_shared_models={'': eval_shared_model}))

      def check_result(got):
        self.assertLen(got, 4)
        for item in got:
          self.assertIn(constants.PREDICTIONS_KEY, item)

      util.assert_that(predict_extracts, check_result)

  def testBatchSizeLimit(self):
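    """Tests the v2 predict extractor against the same batch-size-limited model.

    Unlike the v1 test above, the model is loaded from its serving signature
    and the raw inputs are converted with InputsToExtracts before prediction.
    """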
    temp_export_dir = self._getExportDir()
    _, export_dir = batch_size_limited_classifier.simple_batch_size_limited_classifier(
        None, temp_export_dir)
    eval_shared_model = self.createTestEvalSharedModel(
        eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])
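    # The SERVING tag loads the plain serving SavedModel rather than the
    # legacy EvalSavedModel format used by the v1 extractor.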
    eval_config = config.EvalConfig(model_specs=[config.ModelSpec()])
    predict_extractor = predict_extractor_v2.PredictExtractor(
        eval_config=eval_config, eval_shared_model=eval_shared_model)
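    # The extractor exposes a stage name and PTransform that are wired into
    # the Beam pipeline below.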
    with beam.Pipeline() as pipeline:
      examples = [
          self._makeExample(classes='first', scores=0.0, labels='third'),
          self._makeExample(classes='first', scores=0.0, labels='third'),
          self._makeExample(classes='first', scores=0.0, labels='third'),
          self._makeExample(classes='first', scores=0.0, labels='third'),
      ]

      predict_extracts = (
          pipeline
          | 'Create' >> beam.Create([e.SerializeToString() for e in examples],
                                    reshuffle=False)
          | 'FeaturesToExtracts' >> model_eval_lib.InputsToExtracts()
          | predict_extractor.stage_name >> predict_extractor.ptransform)

      def check_result(got):
        try:
          self.assertLen(got, 4)
          # We can't verify the actual predictions, but we can verify the keys.
          for item in got:
            self.assertIn(constants.PREDICTIONS_KEY, item)

        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(predict_extracts, check_result, label='result')
  def testBatchSizeLimit(self):
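    """Tests the batched predictions extractor with the batch-size-limited model.

    This variant decodes the examples with a TFXIO source into Arrow
    RecordBatches and supplies a tensor adapter config so the extractor can
    feed the model from those batches rather than from serialized inputs.
    """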
    temp_export_dir = self._getExportDir()
    _, export_dir = batch_size_limited_classifier.simple_batch_size_limited_classifier(
        None, temp_export_dir)
    eval_shared_model = self.createTestEvalSharedModel(
        eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])
    eval_config = config.EvalConfig(model_specs=[config.ModelSpec()])
    schema = text_format.Parse(
        """
        feature {
          name: "classes"
          type: BYTES
        }
        feature {
          name: "scores"
          type: FLOAT
        }
        feature {
          name: "labels"
          type: BYTES
        }
        """, schema_pb2.Schema())
    tfx_io = test_util.InMemoryTFExampleRecord(
        schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN)
    tensor_adapter_config = tensor_adapter.TensorAdapterConfig(
        arrow_schema=tfx_io.ArrowSchema(),
        tensor_representations=tfx_io.TensorRepresentations())
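    # The tensor adapter config describes how to convert Arrow columns into
    # the tensors expected by the model's signature.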
    feature_extractor = features_extractor.FeaturesExtractor(eval_config)
    prediction_extractor = predictions_extractor.PredictionsExtractor(
        eval_config=eval_config,
        eval_shared_model=eval_shared_model,
        tensor_adapter_config=tensor_adapter_config)
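    # FeaturesExtractor materializes batched features from the RecordBatches;
    # PredictionsExtractor then runs the model over those batches.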

    examples = []
    for _ in range(4):
      examples.append(
          self._makeExample(classes='first', scores=0.0, labels='third'))

    with beam.Pipeline() as pipeline:
      predict_extracts = (
          pipeline
          | 'Create' >> beam.Create([e.SerializeToString() for e in examples],
                                    reshuffle=False)
          | 'BatchExamples' >> tfx_io.BeamSource(batch_size=1)
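          # batch_size=1 presumably keeps each RecordBatch within the
          # classifier's batch size limit.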
          | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts()
          | feature_extractor.stage_name >> feature_extractor.ptransform
          | prediction_extractor.stage_name >> prediction_extractor.ptransform)

      def check_result(got):
        try:
          self.assertLen(got, 4)
          # We can't verify the actual predictions, but we can verify the keys.
          for item in got:
            self.assertIn(constants.PREDICTIONS_KEY, item)

        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(predict_extracts, check_result, label='result')