Example #1
0
def _RunInference(
    pipeline: beam.Pipeline, payload_format: int,
    inference_endpoint: model_spec_pb2.InferenceSpecType
) -> beam.pvalue.PCollection:
    """Parses incoming payloads into examples and runs model inference.

    Args:
      pipeline: Input that the 'ParseExamples' map is applied to.
        NOTE(review): it is piped straight into beam.Map, so it presumably
        carries raw payloads (e.g. when this is wrapped as a
        beam.ptransform_fn) rather than being a bare beam.Pipeline as the
        annotation says -- confirm against callers.
      payload_format: Value passed to _MakeParseFn, whose result is used as
        the per-element parse function.
      inference_endpoint: Inference spec forwarded to RunInference.

    Returns:
      The PCollection emitted by the 'RunInference' transform.
    """
    parse_fn = _MakeParseFn(payload_format)
    parsed = pipeline | 'ParseExamples' >> beam.Map(parse_fn)
    inference = run_inference.RunInference(inference_endpoint)
    return parsed | 'RunInference' >> inference
Example #2
0
    def _run_model_inference(self, model_path: Text,
                             example_uris: Mapping[Text, Text],
                             output_path: Text,
                             model_spec: bulk_inferrer_pb2.ModelSpec) -> None:
        """Runs model inference on given example data.

        Every split in `example_uris` is read as TFRecords. The splits are
        flattened into a single PCollection and each record is parsed as a
        `tf.train.Example`. Inference is run against the SavedModel at
        `model_path`, and the resulting `PredictionLog` protos are written
        to `output_path` as TFRecord files with a '.gz' suffix.

        Args:
          model_path: Path to model.
          example_uris: Mapping of example split name to example uri.
          output_path: Path to output generated prediction logs.
          model_spec: bulk_inferrer_pb2.ModelSpec instance; its `tag` and
            `model_signature_name` become the SavedModelSpec's `tag` and
            `signature_name`.

        Returns:
          None
        """
        try:
            from tfx_bsl.public.beam import run_inference
            from tfx_bsl.public.proto import model_spec_pb2
        except ImportError:
            # TODO(b/151468119): Remove this branch after next release.
            # Fall back to the non-`public` module paths of tfx_bsl.
            run_inference = importlib.import_module(
                'tfx_bsl.beam.run_inference')
            model_spec_pb2 = importlib.import_module(
                'tfx_bsl.proto.model_spec_pb2')

        saved_model_spec = model_spec_pb2.SavedModelSpec(
            model_path=model_path,
            tag=model_spec.tag,
            signature_name=model_spec.model_signature_name)

        # TODO(b/151468119): Remove this branch after next release.
        # Compatibility shim: depending on the tfx_bsl version, the spec
        # message is exposed as `InferenceEndpoint` or `InferenceSpecType`.
        # One getattr replaces the old existence check + second lookup;
        # `or` keeps the original "fall back when missing/falsy" semantics.
        endpoint_cls = (getattr(model_spec_pb2, 'InferenceEndpoint', None) or
                        model_spec_pb2.InferenceSpecType)
        inference_endpoint = endpoint_cls()
        inference_endpoint.saved_model_spec.CopyFrom(saved_model_spec)

        with self._make_beam_pipeline() as pipeline:
            data_list = []
            for split, example_uri in example_uris.items():
                data = (
                    pipeline
                    | 'ReadData[{}]'.format(split) >> beam.io.ReadFromTFRecord(
                        file_pattern=io_utils.all_files_pattern(example_uri)))
                data_list.append(data)
            # The list is passed directly (no copy needed). `pipeline=` lets
            # Flatten work even when `data_list` is empty.
            _ = (data_list
                 | 'FlattenExamples' >> beam.Flatten(pipeline=pipeline)
                 # Records are always parsed as tf.train.Example;
                 # SequenceExample inputs are not handled here.
                 | 'ParseExamples' >> beam.Map(tf.train.Example.FromString)
                 | 'RunInference' >>
                 run_inference.RunInference(inference_endpoint)
                 | 'WritePredictionLogs' >> beam.io.WriteToTFRecord(
                     output_path,
                     file_name_suffix='.gz',
                     coder=beam.coders.ProtoCoder(
                         prediction_log_pb2.PredictionLog)))
        logging.info('Inference result written to %s.', output_path)
Example #3
0
    def _run_model_inference(
            self, data_spec: bulk_inferrer_pb2.DataSpec,
            examples: List[types.Artifact], output_uri: Text,
            inference_endpoint: model_spec_pb2.InferenceSpecType) -> bool:
        """Runs model inference on given example data.

        The method selects the example splits to score and reads them as
        TFRecords. It parses each record as a `tf.train.Example`, runs
        inference, and writes `PredictionLog` protos under
        `<output_uri>/<_PREDICTION_LOGS_DIR_NAME>` as '.gz'-suffixed
        TFRecord files.

        Args:
          data_spec: bulk_inferrer_pb2.DataSpec instance. If its
            `example_splits` is non-empty, only those splits are scored;
            otherwise every split of every artifact is scored.
          examples: List of example artifacts.
          output_uri: Output artifact uri.
          inference_endpoint: Model inference endpoint.

        Returns:
          True once the pipeline block has exited without raising.
          NOTE(review): assuming `_make_beam_pipeline()` yields a
          beam.Pipeline, a failed run surfaces as an exception from the
          `with` block rather than as a False return.
        """
        wanted_splits = set(data_spec.example_splits)
        example_uris = {}
        for example in examples:
            for split in artifact_utils.decode_split_names(
                    example.split_names):
                # An empty selection means "score every split".
                if not wanted_splits or split in wanted_splits:
                    # NOTE(review): if several artifacts share a split name,
                    # the last one wins and the earlier ones are not scored.
                    example_uris[split] = os.path.join(example.uri, split)
        output_path = os.path.join(output_uri, _PREDICTION_LOGS_DIR_NAME)
        logging.info('BulkInferrer generates prediction log to %s',
                     output_path)

        with self._make_beam_pipeline() as pipeline:
            data_list = []
            for split, example_uri in example_uris.items():
                data = (
                    pipeline
                    | 'ReadData[{}]'.format(split) >> beam.io.ReadFromTFRecord(
                        file_pattern=io_utils.all_files_pattern(example_uri)))
                data_list.append(data)
            _ = (
                data_list
                # `pipeline=` lets Flatten work even when no split matched.
                | 'FlattenExamples' >> beam.Flatten(pipeline=pipeline)
                # TODO(b/131873699): Use the correct Example type here, which
                # is either Example or SequenceExample.
                | 'ParseExamples' >> beam.Map(tf.train.Example.FromString)
                | 'RunInference' >>
                run_inference.RunInference(inference_endpoint)
                | 'WritePredictionLogs' >> beam.io.WriteToTFRecord(
                    output_path,
                    file_name_suffix='.gz',
                    coder=beam.coders.ProtoCoder(
                        prediction_log_pb2.PredictionLog)))
        logging.info('Inference result written to %s.', output_path)
        # The signature and docstring promise a bool. The method used to fall
        # off the end and return None (falsy) even after a successful run.
        return True
Example #4
0
def _RunInference(
    pipeline: beam.Pipeline, example_uri: Text,
    inference_endpoint: model_spec_pb2.InferenceSpecType
) -> beam.pvalue.PCollection:
  """Reads serialized examples from `example_uri` and runs inference on them.

  Args:
    pipeline: Pipeline that the 'ReadData' source transform is applied to.
    example_uri: URI handed to io_utils.all_files_pattern to build the
      TFRecord file pattern that is read.
    inference_endpoint: Inference spec forwarded to RunInference.

  Returns:
    The PCollection emitted by the 'RunInference' transform.
  """
  # TODO(b/174703893): adopt standardized input.
  file_pattern = io_utils.all_files_pattern(example_uri)
  raw_records = pipeline | 'ReadData' >> beam.io.ReadFromTFRecord(
      file_pattern=file_pattern)
  # TODO(b/131873699): Use the correct Example type here, which
  # is either Example or SequenceExample.
  parsed = raw_records | 'ParseExamples' >> beam.Map(
      tf.train.Example.FromString)
  inference = run_inference.RunInference(inference_endpoint)
  return parsed | 'RunInference' >> inference