def json_to_proto(response_json, response_cls, ignore_unknown_fields=True):
    """
    Converts a JSON-compliant dictionary into a `protobuf` `Message` object.

    Parameters
    ----------
    response_json : dict
        JSON object representing a Protocol Buffer message.
    response_cls : type
        `protobuf` `Message` subclass, e.g. ``CreateProject.Response``.
    ignore_unknown_fields : bool, default True
        Whether to allow (and ignore) fields in `response_json` that are not
        defined in `response_cls`. This is for forward compatibility with the
        back end; if the Client protos are outdated and we get a response with
        new fields, ``True`` prevents an error.

    Returns
    -------
    google.protobuf.message.Message
        `protobuf` `Message` object represented by `response_json`.

    """
    return json_format.Parse(json.dumps(response_json),
                             response_cls(),
                             ignore_unknown_fields=ignore_unknown_fields)
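# A minimal sketch of the forward-compatibility behavior described in the
# docstring above, using FileDescriptorProto (bundled with protobuf) so no
# custom .proto is needed; 'futureField' is a hypothetical unknown key, and
# the json/json_format imports mirror the module-level imports the snippet
# above relies on:
import json
from google.protobuf import descriptor_pb2, json_format

msg = json_to_proto({"name": "a.proto", "futureField": 1},
                    descriptor_pb2.FileDescriptorProto)
assert msg.name == "a.proto"  # the unknown key was silently dropped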
def testParseNull(self):
  message = json_format_proto3_pb2.TestMessage()
  parsed_message = json_format_proto3_pb2.TestMessage()
  self.FillAllFields(parsed_message)
  json_format.Parse('{"int32Value": null, '
                    '"int64Value": null, '
                    '"uint32Value": null,'
                    '"uint64Value": null,'
                    '"floatValue": null,'
                    '"doubleValue": null,'
                    '"boolValue": null,'
                    '"stringValue": null,'
                    '"bytesValue": null,'
                    '"messageValue": null,'
                    '"enumValue": null,'
                    '"repeatedInt32Value": null,'
                    '"repeatedInt64Value": null,'
                    '"repeatedUint32Value": null,'
                    '"repeatedUint64Value": null,'
                    '"repeatedFloatValue": null,'
                    '"repeatedDoubleValue": null,'
                    '"repeatedBoolValue": null,'
                    '"repeatedStringValue": null,'
                    '"repeatedBytesValue": null,'
                    '"repeatedMessageValue": null,'
                    '"repeatedEnumValue": null'
                    '}',
                    parsed_message)
  self.assertEqual(message, parsed_message)
  self.assertRaisesRegex(
      json_format.ParseError,
      'Failed to parse repeatedInt32Value field: '
      'null is not allowed to be used as an element in a repeated field.',
      json_format.Parse,
      '{"repeatedInt32Value":[1, null]}',
      parsed_message)
def do_test(request):
  test_message = conformance_pb2.TestAllTypes()
  response = conformance_pb2.ConformanceResponse()

  try:
    if request.WhichOneof('payload') == 'protobuf_payload':
      try:
        test_message.ParseFromString(request.protobuf_payload)
      except message.DecodeError as e:
        response.parse_error = str(e)
        return response
    elif request.WhichOneof('payload') == 'json_payload':
      try:
        json_format.Parse(request.json_payload, test_message)
      except json_format.ParseError as e:
        response.parse_error = str(e)
        return response
    else:
      raise ProtocolError("Request didn't have payload.")

    if request.requested_output_format == conformance_pb2.UNSPECIFIED:
      raise ProtocolError("Unspecified output format")
    elif request.requested_output_format == conformance_pb2.PROTOBUF:
      response.protobuf_payload = test_message.SerializeToString()
    elif request.requested_output_format == conformance_pb2.JSON:
      response.json_payload = json_format.MessageToJson(test_message)
  except Exception as e:
    response.runtime_error = str(e)

  return response
storage_client = storage.Client()

match = re.match(r'gs://([^/]+)/(.+)', gcs_destination_uri)
bucket_name = match.group(1)
prefix = match.group(2)

bucket = storage_client.get_bucket(bucket_name)

# List objects with the given prefix.
blob_list = list(bucket.list_blobs(prefix=prefix))
print('Output files:')
# for blob in blob_list:
#     print(blob.name)

output = blob_list[0]
json_string = output.download_as_string()
response = json_format.Parse(json_string,
                             vision.types.AnnotateFileResponse())

# batch_size and new_files come from the enclosing scope of this fragment.
for i in range(batch_size):
    first_page_response = response.responses[i]
    annotation = first_page_response.full_text_annotation
    print("\n\n")
    with open(str(new_files) + ".docx", 'a') as f:
        print(annotation.text + "\n\n\n\n\n", file=f)
print("Output printed to " + str(new_files) + " doc file")
def async_detect_document(gcs_source_uri, gcs_destination_uri):
    """OCR with PDF/TIFF as source files on GCS."""
    from google.cloud import vision
    from google.cloud import storage
    from google.protobuf import json_format
    import re

    # Supported mime_types are: 'application/pdf' and 'image/tiff'
    mime_type = 'application/pdf'

    # How many pages should be grouped into each json output file.
    batch_size = 10

    client = vision.ImageAnnotatorClient()

    feature = vision.types.Feature(
        type=vision.enums.Feature.Type.DOCUMENT_TEXT_DETECTION)

    gcs_source = vision.types.GcsSource(uri=gcs_source_uri)
    input_config = vision.types.InputConfig(gcs_source=gcs_source,
                                            mime_type=mime_type)

    gcs_destination = vision.types.GcsDestination(uri=gcs_destination_uri)
    output_config = vision.types.OutputConfig(gcs_destination=gcs_destination,
                                              batch_size=batch_size)

    async_request = vision.types.AsyncAnnotateFileRequest(
        features=[feature],
        input_config=input_config,
        output_config=output_config)

    operation = client.async_batch_annotate_files(requests=[async_request])

    print('Waiting for the operation to finish.')
    operation.result(timeout=180)

    # Once the request has completed and the output has been
    # written to GCS, we can list all the output files.
    storage_client = storage.Client()

    # The bucket and prefix are hardcoded here; the commented regex below
    # would extract them from gcs_destination_uri instead.
    # match = re.match(r'gs://([^/]+)/(.+)', gcs_destination_uri)
    # bucket_name = match.group(1)
    # prefix = match.group(2)
    bucket = storage_client.get_bucket("prep-to-certification")

    # List objects with the given prefix.
    blob_list = list(bucket.list_blobs(prefix="output"))
    print('Output files:')
    for blob in blob_list:
        print(blob.name)

    # Process the first output file from GCS.
    output = blob_list[0]
    json_string = output.download_as_string()
    response = json_format.Parse(json_string,
                                 vision.types.AnnotateFileResponse())

    # Write the full text annotation of each page response to its own file.
    # The response contains more information:
    # annotation/pages/blocks/paragraphs/words/symbols
    # including confidence scores and bounding boxes.
    page_number = 1
    for page in response.responses:
        annotation = page.full_text_annotation
        with open("Page" + str(page_number) + ".txt", "w") as text_file:
            print(u'Full text:\n{}'.format(annotation.text), file=text_file)
        page_number += 1
def Do(self, input_dict: Dict[Text, List[types.Artifact]],
       output_dict: Dict[Text, List[types.Artifact]],
       exec_properties: Dict[Text, Any]) -> None:
  """Uses a user-supplied tf.estimator to train a TensorFlow model locally.

  The Trainer Executor invokes a training_fn callback function provided by
  the user via the module_file parameter. With the tf.estimator returned by
  this function, the Trainer Executor then builds a TensorFlow model using
  the user-provided tf.estimator.

  Args:
    input_dict: Input dict from input key to a list of ML-Metadata Artifacts.
      - examples: Examples used for training, must include 'train' and 'eval'
        splits.
      - transform_output: Optional input transform graph.
      - schema: Schema of the data.
    output_dict: Output dict from output key to a list of Artifacts.
      - output: Exported model.
    exec_properties: A dict of execution properties.
      - train_args: JSON string of trainer_pb2.TrainArgs instance, providing
        args for training.
      - eval_args: JSON string of trainer_pb2.EvalArgs instance, providing
        args for eval.
      - module_file: Python module file containing UDF model definition.
      - warm_starting: Whether or not we need to do warm starting.
      - warm_start_from: Optional. If warm_starting is True, this is the
        directory to find previous model to warm start on.

  Returns:
    None

  Raises:
    ValueError: When neither or both of 'module_file' and 'trainer_fn'
      are present in 'exec_properties'.
  """
  self._log_startup(input_dict, output_dict, exec_properties)

  # TODO(zhitaoli): Deprecate this in a future version.
  if exec_properties.get('custom_config', None):
    cmle_args = exec_properties.get('custom_config',
                                    {}).get('cmle_training_args')
    if cmle_args:
      executor_class_path = '.'.join([Executor.__module__, Executor.__name__])
      absl.logging.warn(
          'Passing \'cmle_training_args\' to trainer directly is deprecated, '
          'please use extension executor at '
          'tfx.extensions.google_cloud_ai_platform.trainer.executor instead')
      return runner.start_cmle_training(input_dict, output_dict,
                                        exec_properties, executor_class_path,
                                        cmle_args)

  trainer_fn = self._GetTrainerFn(exec_properties)

  # Set up training parameters
  train_files = [
      _all_files_pattern(
          artifact_utils.get_split_uri(input_dict['examples'], 'train'))
  ]
  transform_output = artifact_utils.get_single_uri(
      input_dict['transform_output']) if input_dict.get(
          'transform_output', None) else None
  eval_files = [
      _all_files_pattern(
          artifact_utils.get_split_uri(input_dict['examples'], 'eval'))
  ]
  schema_file = io_utils.get_only_uri_in_dir(
      artifact_utils.get_single_uri(input_dict['schema']))

  train_args = trainer_pb2.TrainArgs()
  eval_args = trainer_pb2.EvalArgs()
  json_format.Parse(exec_properties['train_args'], train_args)
  json_format.Parse(exec_properties['eval_args'], eval_args)

  # https://github.com/tensorflow/tfx/issues/45: Replace num_steps=0 with
  # num_steps=None. Conversion of the proto to python will set the default
  # value of an int as 0 so modify the value here. Tensorflow will raise an
  # error if num_steps <= 0.
  train_steps = train_args.num_steps or None
  eval_steps = eval_args.num_steps or None

  output_path = artifact_utils.get_single_uri(output_dict['output'])
  serving_model_dir = path_utils.serving_model_dir(output_path)
  eval_model_dir = path_utils.eval_model_dir(output_path)

  # Assemble warm start path if needed.
  warm_start_from = None
  if exec_properties.get('warm_starting') and exec_properties.get(
      'warm_start_from'):
    previous_model_dir = os.path.join(exec_properties['warm_start_from'],
                                      path_utils.SERVING_MODEL_DIR)
    if previous_model_dir and tf.io.gfile.exists(
        os.path.join(previous_model_dir, self._CHECKPOINT_FILE_NAME)):
      warm_start_from = previous_model_dir

  # TODO(b/126242806) Use PipelineInputs when it is available in third_party.
  hparams = _HParamWrapper(
      # A list of uris for train files.
      train_files=train_files,
      # An optional single uri for transform graph produced by TFT. Will be
      # None if not specified.
      transform_output=transform_output,
      # A single uri for the output directory of the serving model.
      serving_model_dir=serving_model_dir,
      # A list of uris for eval files.
      eval_files=eval_files,
      # A single uri for schema file.
      schema_file=schema_file,
      # Number of train steps.
      train_steps=train_steps,
      # Number of eval steps.
      eval_steps=eval_steps,
      # A single uri for the model directory to warm start from.
      warm_start_from=warm_start_from)

  schema = io_utils.parse_pbtxt_file(schema_file, schema_pb2.Schema())

  training_spec = trainer_fn(hparams, schema)

  # Train the model
  absl.logging.info('Training model.')
  tf.estimator.train_and_evaluate(training_spec['estimator'],
                                  training_spec['train_spec'],
                                  training_spec['eval_spec'])
  absl.logging.info('Training complete. Model written to %s',
                    serving_model_dir)

  # Export an eval savedmodel for TFMA
  absl.logging.info('Exporting eval_savedmodel for TFMA.')
  tfma.export.export_eval_savedmodel(
      estimator=training_spec['estimator'],
      export_dir_base=eval_model_dir,
      eval_input_receiver_fn=training_spec['eval_input_receiver_fn'])

  absl.logging.info('Exported eval_savedmodel to %s.', eval_model_dir)
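# A hedged sketch of the exec-properties parsing used above (assumes a TFX
# installation; json_format accepts both original snake_case field names and
# their camelCase JSON names, and the step count is illustrative):
from google.protobuf import json_format
from tfx.proto import trainer_pb2

train_args = json_format.Parse('{"num_steps": 1000}', trainer_pb2.TrainArgs())
assert train_args.num_steps == 1000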
properties={
    'int1': artifact.Property(type=artifact.PropertyType.INT),
    'int2': artifact.Property(type=artifact.PropertyType.INT),
    'float1': artifact.Property(type=artifact.PropertyType.FLOAT),
    'float2': artifact.Property(type=artifact.PropertyType.FLOAT),
    'string1': artifact.Property(type=artifact.PropertyType.STRING),
    'string2': artifact.Property(type=artifact.PropertyType.STRING),
})

_mlmd_artifact_type = metadata_store_pb2.ArtifactType()
json_format.Parse(
    json.dumps({
        'name': 'MyTypeName3',
        'properties': {
            'int1': 'INT',
            'int2': 'INT',
            'float1': 'DOUBLE',
            'float2': 'DOUBLE',
            'string1': 'STRING',
            'string2': 'STRING'
        }
    }), _mlmd_artifact_type)
_MyArtifact3 = artifact._ArtifactType(  # pylint: disable=invalid-name
    mlmd_artifact_type=_mlmd_artifact_type)


class _MyValueArtifact(artifact.ValueArtifact):
  TYPE_NAME = 'MyValueTypeName'

  def encode(self, value: Text):
    assert isinstance(value, Text), value
    return value.encode('utf-8')
def Do(self, input_dict: Dict[Text, List[types.Artifact]],
       output_dict: Dict[Text, List[types.Artifact]],
       exec_properties: Dict[Text, Any]) -> None:
  """Runs batch inference on a given model with given input examples.

  Args:
    input_dict: Input dict from input key to a list of Artifacts.
      - examples: examples for inference.
      - model: exported model.
      - model_blessing: model blessing result.
    output_dict: Output dict from output key to a list of Artifacts.
      - output: bulk inference results.
    exec_properties: A dict of execution properties.
      - model_spec: JSON string of bulk_inferrer_pb2.ModelSpec instance.
      - data_spec: JSON string of bulk_inferrer_pb2.DataSpec instance.

  Returns:
    None
  """
  self._log_startup(input_dict, output_dict, exec_properties)

  if 'examples' not in input_dict:
    raise ValueError('\'examples\' is missing in input dict.')
  if 'inference_result' not in output_dict:
    raise ValueError('\'inference_result\' is missing in output dict.')
  output = artifact_utils.get_single_instance(output_dict['inference_result'])
  if 'model' not in input_dict:
    raise ValueError('Input models are not valid, model '
                     'need to be specified.')
  if 'model_blessing' in input_dict:
    model_blessing = artifact_utils.get_single_instance(
        input_dict['model_blessing'])
    if not model_utils.is_model_blessed(model_blessing):
      output.set_int_custom_property('inferred', 0)
      logging.info('Model on %s was not blessed', model_blessing.uri)
      return
  else:
    logging.info('Model blessing is not provided, exported model will be '
                 'used.')

  model = artifact_utils.get_single_instance(input_dict['model'])
  model_path = path_utils.serving_model_path(model.uri)
  logging.info('Use exported model from %s.', model_path)

  data_spec = bulk_inferrer_pb2.DataSpec()
  json_format.Parse(exec_properties['data_spec'], data_spec)
  example_uris = {}
  if data_spec.example_splits:
    for example in input_dict['examples']:
      for split in artifact_utils.decode_split_names(example.split_names):
        if split in data_spec.example_splits:
          example_uris[split] = os.path.join(example.uri, split)
  else:
    for example in input_dict['examples']:
      for split in artifact_utils.decode_split_names(example.split_names):
        example_uris[split] = os.path.join(example.uri, split)

  model_spec = bulk_inferrer_pb2.ModelSpec()
  json_format.Parse(exec_properties['model_spec'], model_spec)
  output_path = os.path.join(output.uri, _PREDICTION_LOGS_DIR_NAME)
  self._run_model_inference(model_path, example_uris, output_path, model_spec)
  logging.info('BulkInferrer generates prediction log to %s', output_path)
  output.set_int_custom_property('inferred', 1)
def _get_artman_config(artman_yaml):
    config_pb = Config()
    with open(artman_yaml, 'r') as f:
        # Round-trip the YAML through JSON so json_format can populate the
        # Config proto.
        artman_config_json_string = json.dumps(yaml.load(f))
        json_format.Parse(artman_config_json_string, config_pb)
    return config_pb
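# A general sketch of the YAML -> JSON -> proto round trip above, with the
# well-known Struct standing in for the artman Config proto so the snippet
# runs with only protobuf and PyYAML installed (the YAML content is
# illustrative):
import json
import yaml
from google.protobuf import json_format, struct_pb2

cfg = json_format.Parse(json.dumps(yaml.safe_load('name: demo\nreplicas: 3')),
                        struct_pb2.Struct())
assert cfg['replicas'] == 3  # Struct stores numbers as doubles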
def main():
  # Log to the container's stdout so Kubeflow Pipelines UI can display logs to
  # the user.
  logging.basicConfig(stream=sys.stdout, level=logging.INFO)
  logging.getLogger().setLevel(logging.INFO)

  parser = argparse.ArgumentParser()
  parser.add_argument('--pipeline_name', type=str, required=True)
  parser.add_argument('--pipeline_root', type=str, required=True)
  parser.add_argument('--kubeflow_metadata_config', type=str, required=True)
  parser.add_argument('--beam_pipeline_args', type=str, required=True)
  parser.add_argument('--additional_pipeline_args', type=str, required=True)
  parser.add_argument(
      '--component_launcher_class_path', type=str, required=True)
  parser.add_argument('--enable_cache', action='store_true')
  parser.add_argument('--serialized_component', type=str, required=True)
  parser.add_argument('--component_config', type=str, required=True)

  args = parser.parse_args()

  component = json_utils.loads(args.serialized_component)
  component_config = json_utils.loads(args.component_config)
  component_launcher_class = import_utils.import_class_by_path(
      args.component_launcher_class_path)
  if not issubclass(component_launcher_class,
                    base_component_launcher.BaseComponentLauncher):
    raise TypeError(
        'component_launcher_class "%s" is not subclass of '
        'base_component_launcher.BaseComponentLauncher' %
        component_launcher_class)

  kubeflow_metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
  json_format.Parse(args.kubeflow_metadata_config, kubeflow_metadata_config)
  metadata_connection = kubeflow_metadata_adapter.KubeflowMetadataAdapter(
      _get_metadata_connection_config(kubeflow_metadata_config))
  driver_args = data_types.DriverArgs(enable_cache=args.enable_cache)

  beam_pipeline_args = _make_beam_pipeline_args(args.beam_pipeline_args)

  additional_pipeline_args = json.loads(args.additional_pipeline_args)

  launcher = component_launcher_class.create(
      component=component,
      pipeline_info=data_types.PipelineInfo(
          pipeline_name=args.pipeline_name,
          pipeline_root=args.pipeline_root,
          run_id=os.environ['WORKFLOW_ID']),
      driver_args=driver_args,
      metadata_connection=metadata_connection,
      beam_pipeline_args=beam_pipeline_args,
      additional_pipeline_args=additional_pipeline_args,
      component_config=component_config)

  # Attach necessary labels to distinguish different runner and DSL.
  # TODO(zhitaoli): Pass this from KFP runner side when the same container
  # entrypoint can be used by a different runner.
  with telemetry_utils.scoped_labels({
      telemetry_utils.TFX_RUNNER: 'kfp',
  }):
    execution_info = launcher.launch()

  # Dump the UI metadata.
  _dump_ui_metadata(component, execution_info)
def GenerateExamplesByBeam(
    self,
    pipeline: beam.Pipeline,
    exec_properties: Dict[Text, Any],
) -> Dict[Text, beam.pvalue.PCollection]:
  """Converts input source to serialized record splits based on configs.

  Custom ExampleGen executors should provide
  GetInputSourceToExamplePTransform for converting an input split to
  serialized records. Override this 'GenerateExamplesByBeam' method instead
  if complex logic is needed, e.g., custom splitting logic.

  Args:
    pipeline: Beam pipeline.
    exec_properties: A dict of execution properties. Depends on detailed
      example gen implementation.
      - input_base: an external directory containing the data files.
      - input_config: JSON string of example_gen_pb2.Input instance,
        providing input configuration.
      - output_config: JSON string of example_gen_pb2.Output instance,
        providing output configuration.
      - output_data_format: Payload format of generated data in output
        artifact, one of example_gen_pb2.PayloadFormat enum.

  Returns:
    Dict of beam PCollection with split name as key, each PCollection is a
    single output split that contains serialized records.
  """
  # Get input split information.
  input_config = example_gen_pb2.Input()
  json_format.Parse(exec_properties[utils.INPUT_CONFIG_KEY], input_config)
  # Get output split information.
  output_config = example_gen_pb2.Output()
  json_format.Parse(exec_properties[utils.OUTPUT_CONFIG_KEY], output_config)
  # Get output split names.
  split_names = utils.generate_output_split_names(input_config, output_config)
  # Make beam_pipeline_args available in exec_properties since certain
  # example_gen executors need this information.
  # TODO(b/155441037): Revisit necessity of this when BigQueryExampleGen
  # does not branch on runner anymore.
  exec_properties['_beam_pipeline_args'] = self._beam_pipeline_args or []

  example_splits = []
  input_to_record = self.GetInputSourceToExamplePTransform()
  if output_config.split_config.splits:
    # Use output splits, input must have only one split.
    assert len(
        input_config.splits
    ) == 1, 'input must have only one split when output split is specified.'
    # Calculate split buckets.
    buckets = []
    total_buckets = 0
    for split in output_config.split_config.splits:
      total_buckets += split.hash_buckets
      buckets.append(total_buckets)
    example_splits = (
        pipeline
        | 'InputToRecord' >>  # pylint: disable=no-value-for-parameter
        input_to_record(exec_properties, input_config.splits[0].pattern)
        | 'SplitData' >> beam.Partition(_PartitionFn, len(buckets), buckets,
                                        output_config.split_config))
  else:
    # Use input splits.
    for split in input_config.splits:
      examples = (
          pipeline
          | 'InputToRecord[{}]'.format(split.name) >>  # pylint: disable=no-value-for-parameter
          input_to_record(exec_properties, split.pattern))
      example_splits.append(examples)

  result = {}
  for index, example_split in enumerate(example_splits):
    result[split_names[index]] = example_split
  return result
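# A hedged sketch of the input_config JSON consumed above (assumes a TFX
# installation; the split name and pattern are illustrative):
from google.protobuf import json_format
from tfx.proto import example_gen_pb2

input_config = json_format.Parse(
    '{"splits": [{"name": "single_split", "pattern": "*"}]}',
    example_gen_pb2.Input())
assert input_config.splits[0].name == 'single_split'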
def ImportExampleGen(
    input_base_path: InputPath('ExternalPath'),
    # input_path: InputPath('ExternalPath'),
    examples_path: OutputPath('Examples'),
    input_config: 'JsonObject: example_gen_pb2.Input' = None,
    output_config: 'JsonObject: example_gen_pb2.Output' = None,
):
    """TFX ImportExampleGen component.

    The ImportExampleGen component takes TFRecord files with TF Example data
    format, and generates train and eval examples for downstream components.
    This component provides consistent and configurable partitioning, and it
    also shuffles the dataset for ML best practice.

    Args:
        input: A Channel of 'ExternalPath' type, which includes one artifact
            whose uri is an external directory with TFRecord files inside
            (required).
        input_config: An example_gen_pb2.Input instance, providing input
            configuration. If unset, the files under input_base will be
            treated as a single split.
        output_config: An example_gen_pb2.Output instance, providing output
            configuration. If unset, default splits will be 'train' and
            'eval' with size 2:1.

    Returns:
        examples: Optional channel of 'ExamplesPath' for output train and
            eval examples.

    Raises:
        RuntimeError: Only one of query and input_config should be set.
    """
    from tfx.components.example_gen.import_example_gen.component import ImportExampleGen as component_class

    # Generated code
    import json
    import os
    import tensorflow
    from google.protobuf import json_format, message
    from tfx.types import Artifact, channel_utils, artifact_utils

    arguments = locals().copy()

    component_class_args = {}

    for name, execution_parameter in component_class.SPEC_CLASS.PARAMETERS.items():
        argument_value_obj = argument_value = arguments.get(name, None)
        if argument_value is None:
            continue
        parameter_type = execution_parameter.type
        if isinstance(parameter_type, type) and issubclass(parameter_type, message.Message):
            # Maybe FIX: execution_parameter.type can also be a tuple
            argument_value_obj = parameter_type()
            json_format.Parse(argument_value, argument_value_obj)
        component_class_args[name] = argument_value_obj

    for name, channel_parameter in component_class.SPEC_CLASS.INPUTS.items():
        artifact_path = arguments[name + '_path']
        if artifact_path:
            artifact = channel_parameter.type()
            artifact.uri = artifact_path + '/'  # ?
            if channel_parameter.type.PROPERTIES and 'split_names' in channel_parameter.type.PROPERTIES:
                # Recovering splits
                subdirs = tensorflow.io.gfile.listdir(artifact_path)
                artifact.split_names = artifact_utils.encode_split_names(sorted(subdirs))
            component_class_args[name] = channel_utils.as_channel([artifact])

    component_class_instance = component_class(**component_class_args)

    input_dict = {name: channel.get() for name, channel in
                  component_class_instance.inputs.get_all().items()}
    output_dict = {name: channel.get() for name, channel in
                   component_class_instance.outputs.get_all().items()}
    exec_properties = component_class_instance.exec_properties

    # Generating paths for output artifacts
    for name, artifacts in output_dict.items():
        base_artifact_path = arguments[name + '_path']
        # Are there still cases where an output channel has multiple artifacts?
        for idx, artifact in enumerate(artifacts):
            subdir = str(idx + 1) if idx > 0 else ''
            artifact.uri = os.path.join(base_artifact_path, subdir)  # Ends with '/'

    print('component instance: ' + str(component_class_instance))

    # executor = component_class.EXECUTOR_SPEC.executor_class()  # Same
    executor = component_class_instance.executor_spec.executor_class()
    executor.Do(
        input_dict=input_dict,
        output_dict=output_dict,
        exec_properties=exec_properties,
    )
def _get_manifest_proxy(self, retrieval_token):
  # type: (str) -> beam_artifact_api_pb2.ProxyManifest
  with self._open(self._manifest_path(retrieval_token), 'r') as fin:
    return json_format.Parse(
        fin.read().decode('utf-8'), beam_artifact_api_pb2.ProxyManifest())
def parse_json_file(file_name: Text, message: ProtoMessage) -> ProtoMessage:
  """Parses a protobuf message from a JSON file and returns it."""
  contents = fileio.open(file_name).read()
  json_format.Parse(contents, message)
  return message
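# A usage sketch for parse_json_file, assuming fileio.open behaves like the
# builtin open() for local paths (true for tfx.dsl.io.fileio) and using the
# well-known Struct so no custom message type is required; the path and key
# are hypothetical:
from google.protobuf import struct_pb2

with open('/tmp/config.json', 'w') as f:  # hypothetical path
    f.write('{"key": "value"}')
config = parse_json_file('/tmp/config.json', struct_pb2.Struct())
assert config['key'] == 'value'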
def to_flyte_idl(self):
    return _json_format.Parse(_json.dumps(self.to_dict()), _struct.Struct())
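# The dumps()+Parse() round trip above can also be written with ParseDict,
# which protobuf provides for exactly this dict -> message case; a minimal
# sketch with an illustrative dict:
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Struct

struct_msg = json_format.ParseDict({'retries': 3}, Struct())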
def do_GoogleOCR(gcs_source_uri, gcs_destination_uri):  # batch_size parameter
    """
    Perform OCR on a PDF uploaded in Google Cloud Storage, generate output as
    JSON responses and save it in a destination URI.
    """
    # Supported mime_types are: 'application/pdf' and 'image/tiff'
    mime_type = 'application/pdf'

    # How many pages should be grouped into each json output file.
    batch_size = 1

    client = vision.ImageAnnotatorClient()

    feature = vision.types.Feature(
        type=vision.enums.Feature.Type.DOCUMENT_TEXT_DETECTION)

    gcs_source = vision.types.GcsSource(uri=gcs_source_uri)
    input_config = vision.types.InputConfig(gcs_source=gcs_source,
                                            mime_type=mime_type)

    gcs_destination = vision.types.GcsDestination(uri=gcs_destination_uri)
    output_config = vision.types.OutputConfig(gcs_destination=gcs_destination,
                                              batch_size=batch_size)

    async_request = vision.types.AsyncAnnotateFileRequest(
        features=[feature],
        input_config=input_config,
        output_config=output_config)

    operation = client.async_batch_annotate_files(requests=[async_request])

    print('Waiting for the operation to finish.')
    operation.result(timeout=180)

    # Once the request has completed and the output has been
    # written to GCS, we can list all the output files.
    storage_client = storage.Client()

    match = re.match(r'gs://([^/]+)/(.+)', gcs_destination_uri)
    bucket_name = match.group(1)
    print("bucket_name", bucket_name)
    prefix = match.group(2)

    bucket = storage_client.get_bucket(bucket_name)  # 1.16.0: get_bucket(bucket_or_name=bucket_name)

    # List objects with the given prefix, keeping only the JSON outputs.
    blob_list = list(bucket.list_blobs(prefix=prefix))
    new_list = []
    for blob in blob_list:
        name = str(blob).replace("<", "").replace(">", "").split(", ")[1]
        if not name.endswith("pdf"):
            new_list.append(name)

    # Sort the blobs in natural page order.
    sorted_blob_list = [
        bucket.blob(i) for i in natsorted(new_list, reverse=False)
    ]

    all_text = ""
    for i in range(len(sorted_blob_list)):
        try:
            output = sorted_blob_list[i]
            json_string = output.download_as_string()
            response = json_format.Parse(json_string,
                                         vision.types.AnnotateFileResponse())
            first_page_response = response.responses[0]
            annotation = first_page_response.full_text_annotation
            all_text += annotation.text
        except Exception:
            print("SKIP---->", i)

    return prefix[:-1], all_text
def json_to_proto(self, json):
    ex = (tf.train.SequenceExample()
          if self.config.get('are_sequence_examples')
          else tf.train.Example())
    json_format.Parse(json, ex)
    return ex
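# A standalone sketch of the same parse without the wrapping class; the JSON
# layout follows tf.train.Example's canonical proto JSON mapping (requires
# TensorFlow; the 'age' feature is illustrative):
import tensorflow as tf
from google.protobuf import json_format

ex = json_format.Parse(
    '{"features": {"feature": {"age": {"int64List": {"value": ["42"]}}}}}',
    tf.train.Example())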
def Do(self, input_dict: Dict[Text, List[types.Artifact]],
       output_dict: Dict[Text, List[types.Artifact]],
       exec_properties: Dict[Text, Any]) -> None:
  """Runs batch inference on a given model with given input examples.

  This function creates a new model (if necessary) and a new model version
  before inference, and cleans up resources after inference. It provides
  re-executability as it cleans up (only) the model resources that are
  created during the process, even if the inference job failed.

  Args:
    input_dict: Input dict from input key to a list of Artifacts.
      - examples: examples for inference.
      - model: exported model.
      - model_blessing: model blessing result.
    output_dict: Output dict from output key to a list of Artifacts.
      - output: bulk inference results.
    exec_properties: A dict of execution properties.
      - data_spec: JSON string of bulk_inferrer_pb2.DataSpec instance.
      - custom_config: custom_config.ai_platform_serving_args need to contain
        the serving job parameters sent to Google Cloud AI Platform. For the
        full set of parameters, refer to
        https://cloud.google.com/ml-engine/reference/rest/v1/projects.models

  Returns:
    None
  """
  self._log_startup(input_dict, output_dict, exec_properties)

  if 'examples' not in input_dict:
    raise ValueError('\'examples\' is missing in input dict.')
  if 'inference_result' not in output_dict:
    raise ValueError('\'inference_result\' is missing in output dict.')
  output = artifact_utils.get_single_instance(output_dict['inference_result'])
  if 'model' not in input_dict:
    raise ValueError('Input models are not valid, model '
                     'need to be specified.')
  if 'model_blessing' in input_dict:
    model_blessing = artifact_utils.get_single_instance(
        input_dict['model_blessing'])
    if not model_utils.is_model_blessed(model_blessing):
      output.set_int_custom_property('inferred', 0)
      logging.info('Model on %s was not blessed', model_blessing.uri)
      return
  else:
    logging.info('Model blessing is not provided, exported model will be '
                 'used.')
  if _CUSTOM_CONFIG_KEY not in exec_properties:
    raise ValueError('Input exec properties are not valid, {} '
                     'need to be specified.'.format(_CUSTOM_CONFIG_KEY))

  custom_config = json_utils.loads(
      exec_properties.get(_CUSTOM_CONFIG_KEY, 'null'))
  if custom_config is not None and not isinstance(custom_config, Dict):
    raise ValueError('custom_config in execution properties needs to be a '
                     'dict.')
  ai_platform_serving_args = custom_config.get(SERVING_ARGS_KEY)
  if not ai_platform_serving_args:
    raise ValueError(
        '\'ai_platform_serving_args\' is missing in \'custom_config\'')
  service_name, api_version = runner.get_service_name_and_api_version(
      ai_platform_serving_args)
  executor_class_path = '%s.%s' % (self.__class__.__module__,
                                   self.__class__.__name__)
  with telemetry_utils.scoped_labels(
      {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
    job_labels = telemetry_utils.get_labels_dict()
  model = artifact_utils.get_single_instance(input_dict['model'])
  model_path = path_utils.serving_model_path(model.uri)
  logging.info('Use exported model from %s.', model_path)
  # Use model artifact uri to generate model version to guarantee the
  # 1:1 mapping from model version to model.
  model_version = 'version_' + hashlib.sha256(model.uri.encode()).hexdigest()
  inference_spec = self._get_inference_spec(model_path, model_version,
                                            ai_platform_serving_args)
  data_spec = bulk_inferrer_pb2.DataSpec()
  json_format.Parse(exec_properties['data_spec'], data_spec)
  api = discovery.build(service_name, api_version)
  new_model_created = False
  try:
    new_model_created = runner.create_model_for_aip_prediction_if_not_exist(
        api, job_labels, ai_platform_serving_args)
    runner.deploy_model_for_aip_prediction(
        api,
        model_path,
        model_version,
        ai_platform_serving_args,
        job_labels,
        skip_model_creation=True,
        set_default_version=False,
    )
    self._run_model_inference(data_spec, input_dict['examples'], output.uri,
                              inference_spec)
  except Exception as e:
    logging.error('Error in executing CloudAIBulkInferrerComponent: %s',
                  str(e))
    output.set_int_custom_property('inferred', 0)
    raise
  finally:
    # Guarantee newly created resources are cleaned up even if the inference
    # job failed.

    # Clean up the newly deployed model.
    runner.delete_model_version_from_aip_if_exists(api, model_version,
                                                   ai_platform_serving_args)
    if new_model_created:
      runner.delete_model_from_aip_if_exists(api, ai_platform_serving_args)
  # Mark the inference as successful after resources are cleaned up.
  output.set_int_custom_property('inferred', 1)
def deserialize(cls, data: str) -> Any:
  """Deserializes an Artifact object from a JSON string."""
  artifact = pipeline_spec_pb2.RuntimeArtifact()
  json_format.Parse(data, artifact, ignore_unknown_fields=True)
  return cls.get_from_runtime_artifact(artifact)
def train(self, progress_bar=False):
    checkpoint_params = self.checkpoint_params

    train_start_time = time.time() + self.checkpoint_params.total_time

    self.dataset.load_samples(processes=1, progress_bar=progress_bar)
    datas, txts = self.dataset.train_samples(skip_empty=checkpoint_params.skip_invalid_gt)
    if len(datas) == 0:
        raise Exception("Empty dataset is not allowed. Check if the data is at the correct location")

    if self.validation_dataset:
        self.validation_dataset.load_samples(processes=1, progress_bar=progress_bar)
        validation_datas, validation_txts = self.validation_dataset.train_samples(skip_empty=checkpoint_params.skip_invalid_gt)
        if len(validation_datas) == 0:
            raise Exception("Validation dataset is empty. Provide valid validation data for early stopping.")
    else:
        validation_datas, validation_txts = [], []

    # preprocessing steps
    texts = self.txt_preproc.apply(txts, processes=checkpoint_params.processes, progress_bar=progress_bar)
    datas = self.data_preproc.apply(datas, processes=checkpoint_params.processes, progress_bar=progress_bar)
    validation_txts = self.txt_preproc.apply(validation_txts, processes=checkpoint_params.processes, progress_bar=progress_bar)
    validation_datas = self.data_preproc.apply(validation_datas, processes=checkpoint_params.processes, progress_bar=progress_bar)

    # compute the codec
    codec = self.codec if self.codec else Codec.from_texts(texts, whitelist=self.codec_whitelist)

    # data augmentation on preprocessed data
    if self.data_augmenter:
        datas, texts = self.data_augmenter.augment_datas(datas, texts, n_augmentations=self.n_augmentations,
                                                         processes=checkpoint_params.processes, progress_bar=progress_bar)

        # TODO: validation data augmentation
        # validation_datas, validation_txts = self.data_augmenter.augment_datas(validation_datas, validation_txts, n_augmentations=0,
        #                                                                       processes=checkpoint_params.processes, progress_bar=progress_bar)

    # create backend
    network_params = checkpoint_params.model.network
    network_params.features = checkpoint_params.model.line_height
    network_params.classes = len(codec)
    if self.weights:
        # if we load the weights, take care of codec changes as well
        with open(self.weights + '.json', 'r') as f:
            restore_checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
            restore_model_params = restore_checkpoint_params.model

        # checks
        if checkpoint_params.model.line_height != network_params.features:
            raise Exception("The model to restore has a line height of {} but a line height of {} is requested".format(
                network_params.features, checkpoint_params.model.line_height
            ))

        # create codec of the same type
        restore_codec = codec.__class__(restore_model_params.codec.charset)
        # the codec changes as tuple (deletions/insertions), and the new codec is the changed old one
        codec_changes = restore_codec.align(codec)
        codec = restore_codec
        print("Codec changes: {} deletions, {} appends".format(len(codec_changes[0]), len(codec_changes[1])))
        # The actual weight/bias matrix will be changed after loading the old weights
    else:
        codec_changes = None

    # store the new codec
    checkpoint_params.model.codec.charset[:] = codec.charset
    print("CODEC: {}".format(codec.charset))

    # compute the labels with (new/current) codec
    labels = [codec.encode(txt) for txt in texts]

    backend = create_backend_from_proto(network_params, weights=self.weights)
    backend.set_train_data(datas, labels)
    backend.set_prediction_data(validation_datas)
    if codec_changes:
        backend.realign_model_labels(*codec_changes)

    backend.prepare(train=True)

    loss_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.loss_stats)
    ler_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.ler_stats)
    dt_stats = RunningStatistics(checkpoint_params.stats_size, checkpoint_params.dt_stats)

    early_stopping_enabled = self.validation_dataset is not None \
        and checkpoint_params.early_stopping_frequency > 0 \
        and checkpoint_params.early_stopping_nbest > 1
    early_stopping_best_accuracy = checkpoint_params.early_stopping_best_accuracy
    early_stopping_best_cur_nbest = checkpoint_params.early_stopping_best_cur_nbest
    early_stopping_best_at_iter = checkpoint_params.early_stopping_best_at_iter

    early_stopping_predictor = Predictor(codec=codec, text_postproc=self.txt_postproc,
                                         backend=backend)

    # Start the actual training
    # ====================================================================================

    iter = checkpoint_params.iter

    # helper function to write a checkpoint
    def make_checkpoint(base_dir, prefix, version=None):
        if version:
            checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{}.ckpt".format(prefix, version)))
        else:
            checkpoint_path = os.path.abspath(os.path.join(base_dir, "{}{:08d}.ckpt".format(prefix, iter + 1)))
        print("Storing checkpoint to '{}'".format(checkpoint_path))
        backend.save_checkpoint(checkpoint_path)
        checkpoint_params.iter = iter
        checkpoint_params.loss_stats[:] = loss_stats.values
        checkpoint_params.ler_stats[:] = ler_stats.values
        checkpoint_params.dt_stats[:] = dt_stats.values
        checkpoint_params.total_time = time.time() - train_start_time
        checkpoint_params.early_stopping_best_accuracy = early_stopping_best_accuracy
        checkpoint_params.early_stopping_best_cur_nbest = early_stopping_best_cur_nbest
        checkpoint_params.early_stopping_best_at_iter = early_stopping_best_at_iter

        with open(checkpoint_path + ".json", 'w') as f:
            f.write(json_format.MessageToJson(checkpoint_params))

        return checkpoint_path

    try:
        last_checkpoint = None

        # Training loop, can be interrupted by early stopping
        for iter in range(iter, checkpoint_params.max_iters):
            checkpoint_params.iter = iter

            iter_start_time = time.time()
            result = backend.train_step(checkpoint_params.batch_size)

            if not np.isfinite(result['loss']):
                print("Error: Loss is not finite! Trying to restart from last checkpoint.")
                if not last_checkpoint:
                    raise Exception("No checkpoint written yet. Training must be stopped.")
                else:
                    # reload also non trainable weights, such as solver-specific variables
                    backend.load_checkpoint_weights(last_checkpoint, restore_only_trainable=False)
                    continue

            loss_stats.push(result['loss'])
            ler_stats.push(result['ler'])

            dt_stats.push(time.time() - iter_start_time)

            if iter % checkpoint_params.display == 0:
                pred_sentence = self.txt_postproc.apply("".join(codec.decode(result["decoded"][0])))
                gt_sentence = self.txt_postproc.apply("".join(codec.decode(result["gt"][0])))

                print("#{:08d}: loss={:.8f} ler={:.8f} dt={:.8f}s".format(iter, loss_stats.mean(), ler_stats.mean(), dt_stats.mean()))
                print(" PRED: '{}'".format(pred_sentence))
                print(" TRUE: '{}'".format(gt_sentence))

            if (iter + 1) % checkpoint_params.checkpoint_frequency == 0:
                last_checkpoint = make_checkpoint(checkpoint_params.output_dir, checkpoint_params.output_model_prefix)

            if early_stopping_enabled and (iter + 1) % checkpoint_params.early_stopping_frequency == 0:
                print("Checking early stopping model")

                out = early_stopping_predictor.predict_raw(validation_datas, batch_size=checkpoint_params.batch_size,
                                                           progress_bar=progress_bar, apply_preproc=False)
                pred_texts = [d.sentence for d in out]
                result = Evaluator.evaluate(gt_data=validation_txts, pred_data=pred_texts, progress_bar=progress_bar)
                accuracy = 1 - result["avg_ler"]

                if accuracy > early_stopping_best_accuracy:
                    early_stopping_best_accuracy = accuracy
                    early_stopping_best_cur_nbest = 1
                    early_stopping_best_at_iter = iter + 1
                    # overwrite as best model
                    last_checkpoint = make_checkpoint(
                        checkpoint_params.early_stopping_best_model_output_dir,
                        prefix="",
                        version=checkpoint_params.early_stopping_best_model_prefix,
                    )
                    print("Found better model with accuracy of {:%}".format(early_stopping_best_accuracy))
                else:
                    early_stopping_best_cur_nbest += 1
                    print("No better model found. Currently accuracy of {:%} at iter {} (remaining nbest = {})".
                          format(early_stopping_best_accuracy, early_stopping_best_at_iter,
                                 checkpoint_params.early_stopping_nbest - early_stopping_best_cur_nbest))

                if accuracy > 0 and early_stopping_best_cur_nbest >= checkpoint_params.early_stopping_nbest:
                    print("Early stopping now.")
                    break
    except KeyboardInterrupt as e:
        print("Storing interrupted checkpoint")
        make_checkpoint(checkpoint_params.output_dir,
                        checkpoint_params.output_model_prefix,
                        "interrupted")
        raise e

    print("Total time {}s for {} iterations.".format(time.time() - train_start_time, iter))
def read_encrypted(self) -> tink_pb2.EncryptedKeyset:
  try:
    return json_format.Parse(self._serialized_keyset,
                             tink_pb2.EncryptedKeyset())
  except json_format.ParseError as e:
    raise core.TinkError(e)
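# A sketch of the failure mode rewrapped above: json_format.Parse raises
# json_format.ParseError on malformed input (Struct used here for brevity):
from google.protobuf import json_format, struct_pb2

try:
    json_format.Parse('{not valid json', struct_pb2.Struct())
except json_format.ParseError as e:
    print('parse failed:', e)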
def main():
    # Setup signal handlers
    signal.signal(signal.SIGTERM, atexit_function)
    signal.signal(signal.SIGHUP, atexit_function)
    signal.signal(signal.SIGINT, atexit_function)

    parser = argparse.ArgumentParser(description='Pulsar Functions Python Instance')
    parser.add_argument('--function_details', required=True, help='Function Details Json String')
    parser.add_argument('--py', required=True, help='Full Path of Function Code File')
    parser.add_argument('--instance_id', required=True, help='Instance Id')
    parser.add_argument('--function_id', required=True, help='Function Id')
    parser.add_argument('--function_version', required=True, help='Function Version')
    parser.add_argument('--pulsar_serviceurl', required=True, help='Pulsar Service Url')
    parser.add_argument('--client_auth_plugin', required=False, help='Client authentication plugin')
    parser.add_argument('--client_auth_params', required=False, help='Client authentication params')
    parser.add_argument('--use_tls', required=False, help='Use tls')
    parser.add_argument('--tls_allow_insecure_connection', required=False, help='Tls allow insecure connection')
    parser.add_argument('--hostname_verification_enabled', required=False, help='Enable hostname verification')
    parser.add_argument('--tls_trust_cert_path', required=False, help='Tls trust cert file path')
    parser.add_argument('--port', required=True, help='Instance Port', type=int)
    parser.add_argument('--max_buffered_tuples', required=True, help='Maximum number of Buffered tuples')
    parser.add_argument('--logging_directory', required=True, help='Logging Directory')
    parser.add_argument('--logging_file', required=True, help='Log file name')

    args = parser.parse_args()
    function_details = Function_pb2.FunctionDetails()
    json_format.Parse(args.function_details, function_details)

    log_file = os.path.join(
        args.logging_directory,
        util.getFullyQualifiedFunctionName(function_details.tenant, function_details.namespace, function_details.name),
        "%s-%s.log" % (args.logging_file, args.instance_id))
    log.init_rotating_logger(level=logging.INFO, logfile=log_file,
                             max_files=5, max_bytes=10 * 1024 * 1024)

    Log.info("Starting Python instance with %s" % str(args))

    authentication = None
    use_tls = False
    tls_allow_insecure_connection = False
    tls_trust_cert_path = None
    if args.client_auth_plugin and args.client_auth_params:
        authentication = pulsar.Authentication(args.client_auth_plugin, args.client_auth_params)
    if args.use_tls == "true":
        use_tls = True
    if args.tls_allow_insecure_connection == "true":
        tls_allow_insecure_connection = True
    if args.tls_trust_cert_path:
        tls_trust_cert_path = args.tls_trust_cert_path

    pulsar_client = pulsar.Client(args.pulsar_serviceurl, authentication, 30, 1, 1, 50000,
                                  None, use_tls, tls_trust_cert_path, tls_allow_insecure_connection)

    pyinstance = python_instance.PythonInstance(str(args.instance_id), str(args.function_id),
                                                str(args.function_version), function_details,
                                                int(args.max_buffered_tuples),
                                                str(args.py), pulsar_client)
    pyinstance.run()
    server_instance = server.serve(args.port, pyinstance)

    global to_run
    while to_run:
        time.sleep(1)

    pyinstance.join()
    sys.exit(1)
def _resolve_proto_operator(
    self, op: placeholder_pb2.ProtoOperator
) -> Union[int, float, str, bool, bytes]:
  """Evaluates the proto operator."""
  raw_message = self.resolve(op.expression)
  if raw_message is None:
    raise NullDereferenceError(op.expression)

  if isinstance(raw_message, str):
    # We need descriptor pool to parse encoded raw messages.
    pool = descriptor_pool.Default()
    for file_descriptor in op.proto_schema.file_descriptors.file:
      pool.Add(file_descriptor)
    message_descriptor = pool.FindMessageTypeByName(
        op.proto_schema.message_type)
    factory = message_factory.MessageFactory(pool)
    message_type = factory.GetPrototype(message_descriptor)
    value = message_type()
    json_format.Parse(raw_message, value, descriptor_pool=pool)
  elif isinstance(raw_message, message.Message):
    # Message such as platform config should not be encoded.
    value = raw_message
  else:
    raise ValueError(
        f"Got unsupported value type for proto operator: {type(raw_message)}."
    )

  if op.proto_field_path:
    for field in op.proto_field_path:
      if field.startswith("."):
        value = getattr(value, field[1:])
        continue
      map_key = re.findall(r"\[['\"](.+)['\"]\]", field)
      if len(map_key) == 1:
        value = value[map_key[0]]
        continue
      index = re.findall(r"\[(\d+)\]", field)
      if index and str.isdecimal(index[0]):
        value = value[int(index[0])]
        continue
      raise ValueError(f"Got unsupported proto field path: {field}")

  # Non-message primitive values are returned directly.
  if isinstance(value, (int, float, str, bool, bytes)):
    return value

  if not isinstance(value, message.Message):
    raise ValueError(f"Got unsupported value type {type(value)} "
                     "from accessing proto field path.")

  # For message-typed values, we need to consider serialization format.
  if op.serialization_format:
    if op.serialization_format == placeholder_pb2.ProtoOperator.JSON:
      return json_format.MessageToJson(
          message=value, sort_keys=True, preserving_proto_field_name=True)
    if op.serialization_format == placeholder_pb2.ProtoOperator.TEXT_FORMAT:
      return text_format.MessageToString(value)
    if op.serialization_format == placeholder_pb2.ProtoOperator.BINARY:
      return value.SerializeToString()

  raise ValueError(
      "Proto operator resolves to a proto message value. A serialization "
      "format is needed to render it.")
def async_detect_document_tibetan(args):
    """OCR with PDF/TIFF as source files on GCS."""
    book_name = os.path.splitext(args.filepath)[0]
    gcs_source_uri = "gs://" + args.bucket + "/" + args.filepath
    gcs_destination_uri = "gs://" + args.bucket + "/" + book_name + "/"

    # Supported mime_types are: 'application/pdf' and 'image/tiff'
    mime_type = 'image/tiff' if args.tiff else 'application/pdf'

    # How many pages should be grouped into each json output file.
    batch_size = 1

    client = vision.ImageAnnotatorClient()

    feature = vision.types.Feature(
        type=vision.enums.Feature.Type.DOCUMENT_TEXT_DETECTION)

    gcs_source = vision.types.GcsSource(uri=gcs_source_uri)
    input_config = vision.types.InputConfig(gcs_source=gcs_source,
                                            mime_type=mime_type)

    gcs_destination = vision.types.GcsDestination(uri=gcs_destination_uri)
    output_config = vision.types.OutputConfig(gcs_destination=gcs_destination,
                                              batch_size=batch_size)

    image_context = vision.types.ImageContext(language_hints=["en"])

    async_request = vision.types.AsyncAnnotateFileRequest(
        features=[feature],
        input_config=input_config,
        output_config=output_config,
        image_context=image_context)

    operation = client.async_batch_annotate_files(requests=[async_request])

    print('Waiting for the operation to finish.')
    operation.result()

    # Once the request has completed and the output has been
    # written to GCS, we can list all the output files.
    storage_client = storage.Client()
    bucket = storage_client.get_bucket(bucket_name=args.bucket)

    # List objects with the given prefix, sorted by page number.
    get_file_number = lambda name: int(
        re.search(r'(\d+)\.json', name).group(1))
    blob_list = sorted([(get_file_number(blob.name), blob)
                        for blob in bucket.list_blobs(prefix=book_name + "/")])
    print('Output files:')
    for _, blob in blob_list:
        print(blob.name)

    output_name = os.path.join(
        args.output_dir,
        book_name + ".txt") if args.output_dir else book_name + ".txt"
    print('Collecting all text locally as {}'.format(output_name))
    with open(output_name, "w") as f:
        # Collect all text from outputs
        for i, output in blob_list:
            json_string = output.download_as_string()
            response = json_format.Parse(json_string,
                                         vision.types.AnnotateFileResponse())

            # The actual response for the first page of the input file.
            first_page_response = response.responses[0]
            annotation = first_page_response.full_text_annotation
            f.write(annotation.text)

            if i == 1:
                # Here we print the full text from the first page.
                # The response contains more information:
                # annotation/pages/blocks/paragraphs/words/symbols
                # including confidence scores and bounding boxes
                print(u'Full text:\n{}'.format(annotation.text))
def generic_command(client, args):
    params = json_format.Parse(args.params, Struct())
    response = client.GenericCommand(
        magmad_pb2.GenericCommandParams(command=args.command, params=params),
    )
    print(response)
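# A minimal sketch of the params parsing above: an arbitrary JSON object
# string becomes a queryable Struct (the keys here are illustrative):
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Struct

params = json_format.Parse('{"timeout": 5, "verbose": true}', Struct())
assert params['verbose'] is True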
def Do(self, input_dict, output_dict, exec_properties):
  """Push model to target if blessed.

  Args:
    input_dict: Input dict from input key to a list of artifacts, including:
      - model_export: exported model from trainer.
      - model_blessing: model blessing path from model_validator.
    output_dict: Output dict from key to a list of artifacts, including:
      - model_push: A list of 'ModelPushPath' artifact of size one. It will
        include the model in this push execution if the model was pushed.
    exec_properties: A dict of execution properties, including:
      - push_destination: JSON string of pusher_pb2.PushDestination instance,
        providing instruction of destination to push model.

  Returns:
    None
  """
  self._log_startup(input_dict, output_dict, exec_properties)
  model_export = types.get_single_instance(input_dict['model_export'])
  model_export_uri = model_export.uri
  model_blessing_uri = types.get_single_uri(input_dict['model_blessing'])
  model_push = types.get_single_instance(output_dict['model_push'])
  model_push_uri = model_push.uri
  # TODO(jyzhao): should this be in driver or executor.
  if not tf.gfile.Exists(os.path.join(model_blessing_uri, 'BLESSED')):
    model_push.set_int_custom_property('pushed', 0)
    tf.logging.info('Model on %s was not blessed', model_blessing_uri)
    return
  tf.logging.info('Model pushing.')
  # Copy the model we are pushing into
  model_path = path_utils.serving_model_path(model_export_uri)
  # Note: we do not have a logical model version right now. This
  # model_version is a timestamp mapped to trainer's exporter.
  model_version = os.path.basename(model_path)
  tf.logging.info('Model version is %s', model_version)
  io_utils.copy_dir(model_path, os.path.join(model_push_uri, model_version))
  tf.logging.info('Model written to %s.', model_push_uri)

  # Copied to a fixed outside path, which can be listened to by model server.
  #
  # If the model was already successfully copied outside before, stop copying.
  # This is because the model validator might bless the same model twice
  # (check mv driver) with different blessing output; we still want Pusher to
  # handle the mv output again to keep metadata tracking, but there is no need
  # to copy to the outside path again.
  # TODO(jyzhao): support rpc push and verification.
  push_destination = pusher_pb2.PushDestination()
  json_format.Parse(exec_properties['push_destination'], push_destination)
  serving_path = os.path.join(push_destination.filesystem.base_directory,
                              model_version)
  if tf.gfile.Exists(serving_path):
    tf.logging.info(
        'Destination directory %s already exists, skipping current push.',
        serving_path)
  else:
    # tf.serving won't load a partial model; it will retry until fully copied.
    io_utils.copy_dir(model_path, serving_path)
    tf.logging.info('Model written to serving path %s.', serving_path)

  model_push.set_int_custom_property('pushed', 1)
  model_push.set_string_custom_property('pushed_model', model_export_uri)
  model_push.set_int_custom_property('pushed_model_id', model_export.id)
  tf.logging.info('Model pushed to %s.', serving_path)

  if exec_properties.get('custom_config'):
    cmle_serving_args = exec_properties.get('custom_config',
                                            {}).get('cmle_serving_args')
    if cmle_serving_args is not None:
      return cmle_runner.deploy_model_for_serving(serving_path, model_version,
                                                  cmle_serving_args,
                                                  exec_properties['log_root'])
def fromJson(json, protoClass):
    """Deserialise json into an instance of protobuf class."""
    return json_format.Parse(json, protoClass())
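# A usage sketch (any Message subclass works; the well-known Struct is used
# so the snippet runs with only protobuf installed, and it assumes the
# module-level json_format import the helper above relies on):
from google.protobuf.struct_pb2 import Struct

msg = fromJson('{"greeting": "hi"}', Struct)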
def resolve_input_artifacts(
    self,
    input_channels: Dict[Text, types.Channel],
    exec_properties: Dict[Text, Any],
    driver_args: data_types.DriverArgs,
    pipeline_info: data_types.PipelineInfo,
) -> Dict[Text, List[types.Artifact]]:
  """Overrides BaseDriver.resolve_input_artifacts()."""
  del driver_args  # unused
  del pipeline_info  # unused

  input_config = example_gen_pb2.Input()
  json_format.Parse(exec_properties['input_config'], input_config)

  input_dict = channel_utils.unwrap_channel_dict(input_channels)
  for input_list in input_dict.values():
    for single_input in input_list:
      absl.logging.debug('Processing input %s.' % single_input.uri)
      absl.logging.debug('single_input %s.' % single_input)
      absl.logging.debug('single_input.artifact %s.' % single_input.artifact)

      # Set the fingerprint of input.
      split_fingerprints = []
      select_span = None
      for split in input_config.splits:
        # If SPAN is specified, pipeline will process the latest span, note
        # that this span number must be the same for all splits and it will
        # be stored in metadata as the span of input artifact.
        if _SPAN_SPEC in split.pattern:
          latest_span = self._retrieve_latest_span(single_input.uri, split)
          if select_span is None:
            select_span = latest_span
          if select_span != latest_span:
            raise ValueError(
                'Latest span should be the same for each split: %s != %s' %
                (select_span, latest_span))
          split.pattern = split.pattern.replace(_SPAN_SPEC, select_span)

        pattern = os.path.join(single_input.uri, split.pattern)
        split_fingerprints.append(
            io_utils.generate_fingerprint(split.name, pattern))
      fingerprint = '\n'.join(split_fingerprints)
      single_input.set_string_custom_property(_FINGERPRINT, fingerprint)
      if select_span is None:
        select_span = '0'
      single_input.set_string_custom_property(_SPAN, select_span)

      matched_artifacts = []
      for artifact in self._metadata_handler.get_artifacts_by_uri(
          single_input.uri):
        if (artifact.custom_properties[_FINGERPRINT].string_value ==
            fingerprint) and (artifact.custom_properties[_SPAN].string_value
                              == select_span):
          matched_artifacts.append(artifact)

      if matched_artifacts:
        # TODO(b/138845899): consider use span instead of id.
        # If there are multiple matches, get the latest one for caching.
        # Using id because spans are the same for matched artifacts.
        latest_artifact = max(matched_artifacts,
                              key=lambda artifact: artifact.id)
        absl.logging.debug('latest_artifact %s.' % (latest_artifact))
        absl.logging.debug('type(latest_artifact) %s.' %
                           type(latest_artifact))
        single_input.set_artifact(latest_artifact)
      else:
        # TODO(jyzhao): whether driver should be read-only for metadata.
        [new_artifact] = self._metadata_handler.publish_artifacts(
            [single_input])  # pylint: disable=unbalanced-tuple-unpacking
        absl.logging.debug('Registered new input: %s' % (new_artifact))
        single_input.set_artifact(new_artifact)

  exec_properties['input_config'] = json_format.MessageToJson(
      input_config, sort_keys=True)
  return input_dict
def CheckParseBack(self, message, parsed_message):
  json_format.Parse(json_format.MessageToJson(message), parsed_message)
  self.assertEqual(message, parsed_message)
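# A self-contained round-trip sketch of the check above, using the
# well-known Value message so it runs with only protobuf installed:
from google.protobuf import json_format, struct_pb2

original = struct_pb2.Value(string_value='hello')
parsed = struct_pb2.Value()
json_format.Parse(json_format.MessageToJson(original), parsed)
assert original == parsed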
def read_from_json(path: type_utils.PathLike) -> dataset_info_pb2.DatasetInfo:
  """Read JSON-formatted proto into DatasetInfo proto."""
  json_str = utils.as_path(path).read_text()
  # Parse it back into a proto.
  parsed_proto = json_format.Parse(json_str, dataset_info_pb2.DatasetInfo())
  return parsed_proto
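# A hedged file round-trip sketch of the reader above, with the well-known
# Struct standing in for dataset_info_pb2.DatasetInfo so it runs without
# TFDS; the path and field are hypothetical:
import pathlib
from google.protobuf import json_format, struct_pb2

info = json_format.Parse('{"version": "1.0.0"}', struct_pb2.Struct())
path = pathlib.Path('/tmp/dataset_info.json')  # hypothetical path
path.write_text(json_format.MessageToJson(info))
assert json_format.Parse(path.read_text(), struct_pb2.Struct()) == info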