def testConstructSubclassQueryBasedWithRangeConfig(self):
  """A query-based subclass keeps the RangeConfig it is constructed with."""
  # @span_yyyymmdd_utc will be replaced with '19700103' in the query, and
  # span `2` will be recorded in the output Example artifact.
  expected_range_config = range_config_pb2.RangeConfig(
      static_range=range_config_pb2.StaticRange(
          start_span_number=2, end_span_number=2))
  query_split = example_gen_pb2.Input.Split(
      name='single',
      pattern='select * from table where date=@span_yyyymmdd_utc')
  example_gen = TestQueryBasedExampleGenComponent(
      input_config=example_gen_pb2.Input(splits=[query_split]),
      range_config=expected_range_config)

  self.assertEqual({}, example_gen.inputs)
  self.assertEqual(driver.QueryBasedDriver, example_gen.driver_class)
  examples_output = example_gen.outputs[standard_component_specs.EXAMPLES_KEY]
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   examples_output.type_name)

  # Deserialize the stored exec property and compare with what was passed in.
  serialized_range_config = example_gen.exec_properties[
      standard_component_specs.RANGE_CONFIG_KEY]
  actual_range_config = range_config_pb2.RangeConfig()
  proto_utils.json_to_proto(serialized_range_config, actual_range_config)
  self.assertEqual(expected_range_config, actual_range_config)
def _extract_conn_config(self, custom_config):
  """Decodes the PrestoConnConfig packed inside a JSON CustomConfig.

  Args:
    custom_config: JSON representation of an example_gen_pb2.CustomConfig
      whose `custom_config` Any field holds a PrestoConnConfig.

  Returns:
    The unpacked presto_config_pb2.PrestoConnConfig.
  """
  wrapper = example_gen_pb2.CustomConfig()
  proto_utils.json_to_proto(custom_config, wrapper)

  presto_conn_config = presto_config_pb2.PrestoConnConfig()
  wrapper.custom_config.Unpack(presto_conn_config)
  return presto_conn_config
def ResolveSplitsConfig(
    splits_config_str: Optional[str],
    examples: List[types.Artifact]) -> transform_pb2.SplitsConfig:
  """Resolve the SplitsConfig proto for the transform request.

  Args:
    splits_config_str: Optional JSON representation of a
      transform_pb2.SplitsConfig. When set, it is parsed and returned as is.
    examples: Input artifacts whose `split_names` are used to build the
      default config when `splits_config_str` is not set.

  Returns:
    The parsed config when `splits_config_str` is set; otherwise a config
    that analyzes the 'train' split and transforms every split found on the
    input artifacts.

  Raises:
    ValueError: If the parsed config has an empty `analyze` field, if
      `examples` is empty while no config is given, or if the input artifacts
      do not all carry the same set of split names.
  """
  result = transform_pb2.SplitsConfig()
  if splits_config_str:
    proto_utils.json_to_proto(splits_config_str, result)
    if not result.analyze:
      raise ValueError(
          'analyze cannot be empty when splits_config is set.')
    return result

  # The default config is derived from the artifacts, so at least one is
  # needed; fail with a clear message instead of an IndexError below.
  if not examples:
    raise ValueError(
        'examples cannot be empty when splits_config is not set.')

  result.analyze.append('train')

  # All input artifacts should have the same set of split names.
  split_names = set(
      artifact_utils.decode_split_names(examples[0].split_names))
  for artifact in examples:
    artifact_split_names = set(
        artifact_utils.decode_split_names(artifact.split_names))
    if split_names != artifact_split_names:
      raise ValueError(
          'Not all input artifacts have the same split names: (%s, %s)' %
          (split_names, artifact_split_names))

  result.transform.extend(split_names)
  logging.info("Analyze the 'train' split and transform all splits when "
               'splits_config is not set.')
  return result
def testQueryBasedDriver(self):
  """Runs QueryBasedDriver with a static range and checks the resolved span."""
  # Create exec properties.
  query_splits = [
      example_gen_pb2.Input.Split(
          name='s1',
          pattern="select * from table where span={SPAN} and split='s1'"),
      example_gen_pb2.Input.Split(
          name='s2',
          pattern="select * from table where span={SPAN} and split='s2'"),
  ]
  span_two_only = range_config_pb2.RangeConfig(
      static_range=range_config_pb2.StaticRange(
          start_span_number=2, end_span_number=2))
  exec_properties = {
      standard_component_specs.INPUT_CONFIG_KEY:
          proto_utils.proto_to_json(
              example_gen_pb2.Input(splits=query_splits)),
      standard_component_specs.RANGE_CONFIG_KEY:
          proto_utils.proto_to_json(span_two_only),
  }

  # Prepare output_dict.
  example = standard_artifacts.Examples()
  example.uri = 'my_uri'
  output_dict = {standard_component_specs.EXAMPLES_KEY: [example]}

  query_based_driver = driver.QueryBasedDriver(self._mock_metadata)
  execution_info = portable_data_types.ExecutionInfo(
      output_dict=output_dict, exec_properties=exec_properties)
  result = query_based_driver.run(execution_info)

  # The driver records span/version/fingerprint on the exec properties.
  self.assertEqual(exec_properties[utils.SPAN_PROPERTY_NAME], 2)
  self.assertIsNone(exec_properties[utils.VERSION_PROPERTY_NAME])
  self.assertIsNone(exec_properties[utils.FINGERPRINT_PROPERTY_NAME])

  # The {SPAN} placeholder in every split pattern is replaced by the span.
  updated_input_config = example_gen_pb2.Input()
  proto_utils.json_to_proto(
      exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
      updated_input_config)
  self.assertProtoEquals(
      """
      splits {
        name: "s1"
        pattern: "select * from table where span=2 and split='s1'"
      }
      splits {
        name: "s2"
        pattern: "select * from table where span=2 and split='s2'"
      }""", updated_input_config)

  examples_output = result.output_artifacts[
      standard_component_specs.EXAMPLES_KEY].artifacts
  self.assertLen(examples_output, 1)
  output_example = result.output_artifacts[
      standard_component_specs.EXAMPLES_KEY].artifacts[0]
  self.assertEqual(output_example.uri, example.uri)
  span_property = output_example.custom_properties[utils.SPAN_PROPERTY_NAME]
  self.assertEqual(span_property.string_value, '2')
def get_tune_args(
    exec_properties: Dict[str, Any]) -> Optional[tuner_pb2.TuneArgs]:
  """Returns the TuneArgs proto stored in execution properties, if present.

  Args:
    exec_properties: Execution properties; the value under
      standard_component_specs.TUNE_ARGS_KEY, when truthy, is parsed as a
      JSON representation of tuner_pb2.TuneArgs.

  Returns:
    The parsed TuneArgs, or None when the entry is missing or falsy.
  """
  serialized_tune_args = exec_properties.get(
      standard_component_specs.TUNE_ARGS_KEY)
  if serialized_tune_args:
    parsed = tuner_pb2.TuneArgs()
    proto_utils.json_to_proto(serialized_tune_args, parsed)
    return parsed
  return None
def _get_inference_spec(
    self, model_path: Text,
    exec_properties: Dict[Text, Any]) -> model_spec_pb2.InferenceSpecType:
  """Builds the InferenceSpecType for a saved model.

  Args:
    model_path: Path of the model, copied into the SavedModelSpec.
    exec_properties: Execution properties; 'model_spec' holds the JSON
      representation of a bulk_inferrer_pb2.ModelSpec.

  Returns:
    An InferenceSpecType whose `saved_model_spec` carries the model path
    together with the tag and signature name from the ModelSpec.
  """
  parsed_model_spec = bulk_inferrer_pb2.ModelSpec()
  proto_utils.json_to_proto(exec_properties['model_spec'], parsed_model_spec)

  inference_spec = model_spec_pb2.InferenceSpecType()
  inference_spec.saved_model_spec.CopyFrom(
      model_spec_pb2.SavedModelSpec(
          model_path=model_path,
          tag=parsed_model_spec.tag,
          signature_name=parsed_model_spec.model_signature_name))
  return inference_spec
def testConstructWithStaticRangeConfig(self):
  """FileBasedExampleGen stores the static RangeConfig it was given."""
  expected_range_config = range_config_pb2.RangeConfig(
      static_range=range_config_pb2.StaticRange(
          start_span_number=1, end_span_number=1))
  example_gen = component.FileBasedExampleGen(
      input_base='path',
      range_config=expected_range_config,
      custom_executor_spec=executor_spec.ExecutorClassSpec(
          TestExampleGenExecutor))

  # Deserialize the stored exec property and compare with the original proto.
  actual_range_config = range_config_pb2.RangeConfig()
  proto_utils.json_to_proto(example_gen.exec_properties['range_config'],
                            actual_range_config)
  self.assertEqual(expected_range_config, actual_range_config)
def testConstructWithCustomConfig(self):
  """FileBasedExampleGen stores the CustomConfig it was given."""
  expected_custom_config = example_gen_pb2.CustomConfig(
      custom_config=any_pb2.Any())
  beam_executor_spec = executor_spec.BeamExecutorSpec(TestExampleGenExecutor)
  example_gen = component.FileBasedExampleGen(
      input_base='path',
      custom_config=expected_custom_config,
      custom_executor_spec=beam_executor_spec)

  # Deserialize the stored exec property and compare with the original proto.
  serialized_custom_config = example_gen.exec_properties[
      standard_component_specs.CUSTOM_CONFIG_KEY]
  actual_custom_config = example_gen_pb2.CustomConfig()
  proto_utils.json_to_proto(serialized_custom_config, actual_custom_config)
  self.assertEqual(expected_custom_config, actual_custom_config)
def _get_inference_spec(
    self, model_path: str,
    exec_properties: Dict[str, Any]) -> model_spec_pb2.InferenceSpecType:
  """Builds the InferenceSpecType for a saved model.

  Args:
    model_path: Path of the model, copied into the SavedModelSpec.
    exec_properties: Execution properties; the entry under
      standard_component_specs.MODEL_SPEC_KEY holds the JSON representation
      of a bulk_inferrer_pb2.ModelSpec.

  Returns:
    An InferenceSpecType whose `saved_model_spec` carries the model path
    together with the tag and signature name from the ModelSpec.
  """
  parsed_model_spec = bulk_inferrer_pb2.ModelSpec()
  serialized_model_spec = exec_properties[
      standard_component_specs.MODEL_SPEC_KEY]
  proto_utils.json_to_proto(serialized_model_spec, parsed_model_spec)

  inference_spec = model_spec_pb2.InferenceSpecType()
  inference_spec.saved_model_spec.CopyFrom(
      model_spec_pb2.SavedModelSpec(
          model_path=model_path,
          tag=parsed_model_spec.tag,
          signature_name=parsed_model_spec.model_signature_name))
  return inference_spec
def testConstructWithInputConfig(self):
  """The component stores the Input config it was constructed with."""
  split_patterns = (('train', 'train/*'), ('eval', 'eval/*'),
                    ('test', 'test/*'))
  expected_input_config = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(name=split_name, pattern=split_pattern)
      for split_name, split_pattern in split_patterns
  ])
  example_gen = TestFileBasedExampleGenComponent(
      input_base='path', input_config=expected_input_config)
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   example_gen.outputs['examples'].type_name)

  # Deserialize the stored exec property and compare with the original proto.
  actual_input_config = example_gen_pb2.Input()
  proto_utils.json_to_proto(example_gen.exec_properties['input_config'],
                            actual_input_config)
  self.assertEqual(expected_input_config, actual_input_config)
def _dict_to_object(self, dict_data: Dict[Text, Any]) -> Any:
  """Converts a decoded JSON dictionary back into the object it encodes.

  Dictionaries without the TFX object-type marker are returned unchanged.
  Marked dictionaries are consumed in place: the marker, module and class
  keys are popped from `dict_data` while the object is rebuilt.

  Args:
    dict_data: A dictionary produced by JSON decoding.

  Returns:
    `dict_data` itself when it carries no object-type marker; a Jsonable
    instance, a class, or a proto message for the corresponding object
    types; None for any other object type.

  Raises:
    ValueError: If the referenced class is not a subclass of the expected
      base type, or a proto entry has no serialized value.
  """
  if _TFX_OBJECT_TYPE_KEY not in dict_data:
    return dict_data

  def _load_class(source):
    # Pop module first, then class, so both keys leave the dict in that order.
    module_path = source.pop(_MODULE_KEY)
    attribute = source.pop(_CLASS_KEY)
    return getattr(importlib.import_module(module_path), attribute)

  kind = dict_data.pop(_TFX_OBJECT_TYPE_KEY)

  if kind == _ObjectType.JSONABLE:
    jsonable_cls = _load_class(dict_data)
    if issubclass(jsonable_cls, Jsonable):
      return jsonable_cls.from_json_dict(dict_data)
    raise ValueError('Class %s must be a subclass of Jsonable' %
                     jsonable_cls)

  if kind == _ObjectType.CLASS:
    return _load_class(dict_data)

  if kind == _ObjectType.PROTO:
    proto_cls = _load_class(dict_data)
    if not issubclass(proto_cls, message.Message):
      raise ValueError(
          'Class %s must be a subclass of proto.Message' % proto_cls)
    if _PROTO_VALUE_KEY not in dict_data:
      raise ValueError('Missing proto value in json dict')
    serialized_proto = dict_data[_PROTO_VALUE_KEY]
    return proto_utils.json_to_proto(serialized_proto, proto_cls())

  # Unrecognized object types yield None.
  return None
def resolve_exec_properties(
    self,
    exec_properties: Dict[Text, Any],
    pipeline_info: data_types.PipelineInfo,
    component_info: data_types.ComponentInfo,
) -> Dict[Text, Any]:
  """Overrides BaseDriver.resolve_exec_properties().

  Parses the input config (and the range config, when present), computes the
  fingerprint, span and version of the input, and writes them back to
  `exec_properties` together with the updated input config.

  Args:
    exec_properties: Execution properties; updated in place.
    pipeline_info: Unused.
    component_info: Unused.

  Returns:
    The same `exec_properties` dict, after the update.

  Raises:
    ValueError: If a static range is given whose start and end span numbers
      differ.
  """
  del pipeline_info, component_info

  input_config = example_gen_pb2.Input()
  proto_utils.json_to_proto(
      exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
      input_config)

  input_base = exec_properties[standard_component_specs.INPUT_BASE_KEY]
  logging.debug('Processing input %s.', input_base)

  serialized_range_config = exec_properties.get(
      standard_component_specs.RANGE_CONFIG_KEY)
  if not serialized_range_config:
    range_config = None
  else:
    range_config = range_config_pb2.RangeConfig()
    proto_utils.json_to_proto(serialized_range_config, range_config)

    if range_config.HasField('static_range'):
      # For ExampleGen, StaticRange must specify an exact span to look for,
      # since only one span is processed at a time.
      first_span = range_config.static_range.start_span_number
      last_span = range_config.static_range.end_span_number
      if first_span != last_span:
        raise ValueError(
            'Start and end span numbers for RangeConfig.static_range must '
            'be equal: (%s, %s)' % (first_span, last_span))

  # Note that this function updates the input_config.splits.pattern.
  fingerprint, span, version = (
      utils.calculate_splits_fingerprint_span_and_version(
          input_base, input_config.splits, range_config))

  exec_properties.update({
      standard_component_specs.INPUT_CONFIG_KEY:
          proto_utils.proto_to_json(input_config),
      utils.SPAN_PROPERTY_NAME: span,
      utils.VERSION_PROPERTY_NAME: version,
      utils.FINGERPRINT_PROPERTY_NAME: fingerprint,
  })
  return exec_properties
def _parse_parameters(self, raw_args: Mapping[str, Any]):
  """Parses raw arguments into exec properties, inputs and outputs.

  Checks that every non-optional argument declared in PARAMETERS, INPUTS and
  OUTPUTS is present and type checks the supplied values, then fills
  `self.exec_properties`, `self.inputs` and `self.outputs`.

  Args:
    raw_args: Mapping from argument name to the value that was passed.

  Raises:
    ValueError: If a non-optional argument is missing.
  """
  pending_names = set(raw_args.keys())
  resolved_inputs = {}
  resolved_outputs = {}
  self.exec_properties = {}

  # First, check that the arguments are set.
  declared_args = itertools.chain(self.PARAMETERS.items(),
                                  self.INPUTS.items(),
                                  self.OUTPUTS.items())
  for arg_name, arg in declared_args:
    if arg_name not in pending_names:
      if not arg.optional:
        raise ValueError('Missing argument %r to %s.' %
                         (arg_name, self.__class__))
      continue
    pending_names.remove(arg_name)

    # Type check the argument; an optional argument passed as None is skipped.
    value = raw_args[arg_name]
    if not (arg.optional and value is None):
      arg.type_check(arg_name, value)

  # Populate the appropriate dictionary for each parameter type.
  for arg_name, arg in self.PARAMETERS.items():
    if arg.optional and arg_name not in raw_args:
      continue
    value = raw_args[arg_name]

    needs_proto_conversion = (
        inspect.isclass(arg.type) and
        issubclass(arg.type, message.Message) and value and
        not _is_runtime_param(value))
    if needs_proto_conversion and arg.use_proto:
      if isinstance(value, dict):
        value = proto_utils.dict_to_proto(value, arg.type())
      elif isinstance(value, str):
        value = proto_utils.json_to_proto(value, arg.type())
    elif needs_proto_conversion:
      # Create deterministic json string as it will be stored in metadata
      # for cache check.
      if isinstance(value, dict):
        value = json_utils.dumps(value)
      elif not isinstance(value, str):
        value = proto_utils.proto_to_json(value)

    self.exec_properties[arg_name] = value

  channel_groups = ((self.INPUTS, resolved_inputs),
                    (self.OUTPUTS, resolved_outputs))
  for declared, resolved in channel_groups:
    for arg_name, arg in declared.items():
      # Optional channels that were not supplied (or are falsy) are left out.
      if arg.optional and not raw_args.get(arg_name):
        continue
      resolved[arg_name] = raw_args[arg_name]

  self.inputs = resolved_inputs
  self.outputs = resolved_outputs
def testConstructWithOutputConfig(self):
  """The component stores the Output config it was constructed with."""
  bucket_counts = (('train', 2), ('eval', 1), ('test', 1))
  expected_output_config = example_gen_pb2.Output(
      split_config=example_gen_pb2.SplitConfig(splits=[
          example_gen_pb2.SplitConfig.Split(
              name=split_name, hash_buckets=bucket_count)
          for split_name, bucket_count in bucket_counts
      ]))
  example_gen = TestFileBasedExampleGenComponent(
      input_base='path', output_config=expected_output_config)
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   example_gen.outputs['examples'].type_name)

  # Deserialize the stored exec property and compare with the original proto.
  actual_output_config = example_gen_pb2.Output()
  proto_utils.json_to_proto(example_gen.exec_properties['output_config'],
                            actual_output_config)
  self.assertEqual(expected_output_config, actual_output_config)
def resolve_exec_properties(
    self,
    exec_properties: Dict[Text, Any],
    pipeline_info: data_types.PipelineInfo,
    component_info: data_types.ComponentInfo,
) -> Dict[Text, Any]:
  """Overrides BaseDriver.resolve_exec_properties().

  Parses the input config (and the range config, when present), asks the
  input processor for span, version and fingerprint, rewrites every split
  pattern for that span/version, and writes the results back to
  `exec_properties`.

  Args:
    exec_properties: Execution properties; updated in place.
    pipeline_info: Unused.
    component_info: Unused.

  Returns:
    The same `exec_properties` dict, after the update.
  """
  del pipeline_info, component_info

  input_config = example_gen_pb2.Input()
  proto_utils.json_to_proto(
      exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
      input_config)

  input_base = exec_properties.get(standard_component_specs.INPUT_BASE_KEY)
  logging.debug('Processing input %s.', input_base)

  serialized_range_config = exec_properties.get(
      standard_component_specs.RANGE_CONFIG_KEY)
  if serialized_range_config:
    range_config = range_config_pb2.RangeConfig()
    proto_utils.json_to_proto(serialized_range_config, range_config)
  else:
    range_config = None

  processor = self.get_input_processor(
      splits=input_config.splits,
      range_config=range_config,
      input_base_uri=input_base)
  span, version = processor.resolve_span_and_version()
  fingerprint = processor.get_input_fingerprint(span, version)

  # Updates the input_config.splits.pattern.
  for split in input_config.splits:
    resolved_pattern = processor.get_pattern_for_span_version(
        split.pattern, span, version)
    split.pattern = resolved_pattern

  exec_properties.update({
      standard_component_specs.INPUT_CONFIG_KEY:
          proto_utils.proto_to_json(input_config),
      utils.SPAN_PROPERTY_NAME: span,
      utils.VERSION_PROPERTY_NAME: version,
      utils.FINGERPRINT_PROPERTY_NAME: fingerprint,
  })
  return exec_properties
def testConstructWithRangeConfig(self):
  """BigQueryExampleGen stores the RangeConfig it was constructed with."""
  expected_range_config = range_config_pb2.RangeConfig(
      static_range=range_config_pb2.StaticRange(
          start_span_number=2, end_span_number=2))

  # @span_yyyymmdd_utc will be replaced with '19700103' in the query, and
  # span `2` will be recorded in the output Example artifact.
  big_query_example_gen = component.BigQueryExampleGen(
      query='select * from table where date=@span_yyyymmdd_utc',
      range_config=expected_range_config)
  examples_output = big_query_example_gen.outputs[
      standard_component_specs.EXAMPLES_KEY]
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   examples_output.type_name)

  # Deserialize the stored exec property and compare with the original proto.
  serialized_range_config = big_query_example_gen.exec_properties[
      standard_component_specs.RANGE_CONFIG_KEY]
  actual_range_config = range_config_pb2.RangeConfig()
  proto_utils.json_to_proto(serialized_range_config, actual_range_config)
  self.assertEqual(expected_range_config, actual_range_config)
def Do(self, input_dict: Dict[Text, List[types.Artifact]],
       output_dict: Dict[Text, List[types.Artifact]],
       exec_properties: Dict[Text, Any]) -> None:
  """Contract for running InfraValidator Executor.

  Args:
    input_dict:
      - `model`: Single `Model` artifact that we're validating.
      - `examples`: `Examples` artifacts to be used for test requests.
    output_dict:
      - `blessing`: Single `InfraBlessing` artifact containing the validated
        result. It is an empty file with the name either of INFRA_BLESSED or
        INFRA_NOT_BLESSED.
    exec_properties:
      - `serving_spec`: Serialized `ServingSpec` configuration.
      - `validation_spec`: Serialized `ValidationSpec` configuration.
      - `request_spec`: Serialized `RequestSpec` configuration.
  """
  self._log_startup(input_dict, output_dict, exec_properties)

  model = artifact_utils.get_single_instance(input_dict[MODEL_KEY])
  blessing = artifact_utils.get_single_instance(output_dict[BLESSING_KEY])
  # Examples are optional; without them there is no artifact to resolve.
  examples = (
      artifact_utils.get_single_instance(input_dict[EXAMPLES_KEY])
      if input_dict.get(EXAMPLES_KEY) else None)

  serving_spec = infra_validator_pb2.ServingSpec()
  proto_utils.json_to_proto(exec_properties[SERVING_SPEC_KEY], serving_spec)
  if not serving_spec.model_name:
    serving_spec.model_name = _DEFAULT_MODEL_NAME

  # Unset (zero) validation fields fall back to the module defaults.
  validation_spec = infra_validator_pb2.ValidationSpec()
  serialized_validation_spec = exec_properties.get(VALIDATION_SPEC_KEY)
  if serialized_validation_spec:
    proto_utils.json_to_proto(exec_properties[VALIDATION_SPEC_KEY],
                              validation_spec)
  if not validation_spec.num_tries:
    validation_spec.num_tries = _DEFAULT_NUM_TRIES
  if not validation_spec.max_loading_time_seconds:
    validation_spec.max_loading_time_seconds = _DEFAULT_MAX_LOADING_TIME_SEC

  request_spec = None
  if exec_properties.get(REQUEST_SPEC_KEY):
    request_spec = infra_validator_pb2.RequestSpec()
    proto_utils.json_to_proto(exec_properties[REQUEST_SPEC_KEY],
                              request_spec)

  with self._InstallGracefulShutdownHandler():
    self._Do(
        model=model,
        examples=examples,
        blessing=blessing,
        serving_spec=serving_spec,
        validation_spec=validation_spec,
        request_spec=request_spec,
    )
def _PrestoToExample(  # pylint: disable=invalid-name
    pipeline: beam.Pipeline, exec_properties: Dict[str, Any],
    split_pattern: str) -> beam.pvalue.PCollection:
  """Read from Presto and transform to TF examples.

  Args:
    pipeline: beam pipeline.
    exec_properties: A dict of execution properties. Its 'custom_config'
      entry is a JSON CustomConfig wrapping a PrestoConnConfig.
    split_pattern: Split.pattern in Input config, a Presto sql string.

  Returns:
    PCollection of TF examples.
  """
  custom_config = example_gen_pb2.CustomConfig()
  proto_utils.json_to_proto(exec_properties['custom_config'], custom_config)

  presto_conn_config = presto_config_pb2.PrestoConnConfig()
  custom_config.custom_config.Unpack(presto_conn_config)
  client = _deserialize_conn_config(presto_conn_config)

  queries = pipeline | 'Query' >> beam.Create([split_pattern])
  rows = queries | 'QueryTable' >> beam.ParDo(_ReadPrestoDoFn(client))
  return rows | 'ToTFExample' >> beam.Map(_row_to_example)
def extract_properties(
    execution: metadata_store_pb2.Execution
) -> Dict[str, types.ExecPropertyTypes]:
  """Extracts execution properties from mlmd Execution.

  Walks both `properties` and `custom_properties`, skipping schema keys, and
  parses each value using the schema stored under its schema key in
  `custom_properties` when there is one.

  Args:
    execution: The MLMD execution to read from.

  Returns:
    Mapping from property key to parsed value.

  Raises:
    ValueError: If a property parses to None.
  """
  extracted = {}
  all_properties = itertools.chain(execution.properties.items(),
                                   execution.custom_properties.items())
  for key, prop in all_properties:
    if execution_lib.is_schema_key(key):
      continue

    schema_key = execution_lib.get_schema_key(key)
    if schema_key in execution.custom_properties:
      serialized_schema = data_types_utils.get_metadata_value(
          execution.custom_properties[schema_key])
      schema = proto_utils.json_to_proto(serialized_schema,
                                         pipeline_pb2.Value.Schema())
    else:
      schema = None

    parsed = data_types_utils.get_parsed_value(prop, schema)
    if parsed is None:
      raise ValueError(
          f'Unexpected property with empty value; key: {key}')
    extracted[key] = parsed
  return extracted
def _run_driver(exec_properties: Dict[str, Any],
                outputs_dict: Dict[str, List[artifact.Artifact]],
                output_metadata_uri: str,
                name_from_id: Optional[Dict[int, str]] = None) -> None:
  """Runs the driver, writing its output as a ExecutorOutput proto.

  The main goal of this driver is to calculate the span and fingerprint of
  input data, allowing for the executor invocation to be skipped if the
  ExampleGen component has been previously run on the same data with the same
  configuration. This span and fingerprint are added as new custom execution
  properties to an ExecutorOutput proto and written to a GCS path. The CAIP
  pipelines system reads this file and updates MLMD with the new execution
  properties.

  Args:
    exec_properties: Execution properties; updated in place with the
      calculated span, fingerprint and version. The following entries are
      read by this function:
      utils.INPUT_BASE_KEY: A path from which files will be read and their
        span/fingerprint calculated.
      utils.INPUT_CONFIG_KEY: A json-serialized
        tfx.proto.example_gen_pb2.Input proto message. See
        https://www.tensorflow.org/tfx/guide/examplegen for more details.
    outputs_dict: The mapping of the output artifacts.
    output_metadata_uri: A path at which an ExecutorOutput message will be
      written with updated execution properties and output artifacts. The CAIP
      Pipelines service will update the task's properties and artifacts prior
      to running the executor.
    name_from_id: Optional. Mapping from the converted int-typed id to
      str-typed runtime artifact name, which should be unique.

  Raises:
    ValueError: If `outputs_dict` has no entry under utils.EXAMPLES_KEY.
  """
  if name_from_id is None:
    name_from_id = {}

  logging.set_verbosity(logging.INFO)
  logging.info('exec_properties = %s\noutput_metadata_uri = %s',
               exec_properties, output_metadata_uri)

  input_base_uri = exec_properties[utils.INPUT_BASE_KEY]

  input_config = example_gen_pb2.Input()
  proto_utils.json_to_proto(exec_properties[utils.INPUT_CONFIG_KEY],
                            input_config)

  # TODO(b/161734559): Support range config.
  fingerprint, select_span, version = (
      utils.calculate_splits_fingerprint_span_and_version(
          input_base_uri, input_config.splits))
  logging.info('Calculated span: %s', select_span)
  logging.info('Calculated fingerprint: %s', fingerprint)

  exec_properties[utils.SPAN_PROPERTY_NAME] = select_span
  exec_properties[utils.FINGERPRINT_PROPERTY_NAME] = fingerprint
  exec_properties[utils.VERSION_PROPERTY_NAME] = version

  if utils.EXAMPLES_KEY not in outputs_dict:
    raise ValueError(
        'Example artifact was missing in the ExampleGen outputs.')
  example_artifact = artifact_utils.get_single_instance(
      outputs_dict[utils.EXAMPLES_KEY])

  driver.update_output_artifact(
      exec_properties=exec_properties,
      output_artifact=example_artifact.mlmd_artifact)

  # Log the output metadata file
  output_metadata = pipeline_pb2.ExecutorOutput()
  output_metadata.parameters[
      utils.FINGERPRINT_PROPERTY_NAME].string_value = fingerprint
  output_metadata.parameters[utils.SPAN_PROPERTY_NAME].string_value = str(
      select_span)
  output_metadata.parameters[
      utils.INPUT_CONFIG_KEY].string_value = json_format.MessageToJson(
          input_config)
  output_metadata.artifacts[utils.EXAMPLES_KEY].artifacts.add().CopyFrom(
      kubeflow_v2_entrypoint_utils.to_runtime_artifact(
          example_artifact, name_from_id))

  # Serialize before opening the file so a serialization failure does not
  # leave an empty output file behind.
  serialized_metadata = json_format.MessageToJson(
      output_metadata, sort_keys=True)

  fileio.makedirs(os.path.dirname(output_metadata_uri))
  with fileio.open(output_metadata_uri, 'wb') as f:
    # The file is opened in binary mode, so write bytes rather than str.
    f.write(serialized_metadata.encode('utf-8'))
def Do(self, input_dict: Dict[Text, List[types.Artifact]],
       output_dict: Dict[Text, List[types.Artifact]],
       exec_properties: Dict[Text, Any]) -> None:
  """Runs batch inference on a given model with given input examples.

  Inference is skipped when a model blessing is supplied and the model is
  not blessed.

  Args:
    input_dict: Input dict from input key to a list of Artifacts.
      - examples: examples for inference.
      - model: exported model.
      - model_blessing: model blessing result, optional.
    output_dict: Output dict from output key to a list of Artifacts.
      - output: bulk inference results.
    exec_properties: A dict of execution properties.
      - model_spec: JSON string of bulk_inferrer_pb2.ModelSpec instance.
      - data_spec: JSON string of bulk_inferrer_pb2.DataSpec instance.

  Returns:
    None

  Raises:
    ValueError: If 'examples' or 'model' is missing from `input_dict`.
  """
  self._log_startup(input_dict, output_dict, exec_properties)

  # Both outputs are optional; resolve only the ones that were supplied.
  inference_result = (
      artifact_utils.get_single_instance(
          output_dict[standard_component_specs.INFERENCE_RESULT_KEY])
      if output_dict.get(standard_component_specs.INFERENCE_RESULT_KEY)
      else None)
  output_examples = (
      artifact_utils.get_single_instance(
          output_dict[standard_component_specs.OUTPUT_EXAMPLES_KEY])
      if output_dict.get(standard_component_specs.OUTPUT_EXAMPLES_KEY)
      else None)

  if 'examples' not in input_dict:
    raise ValueError('\'examples\' is missing in input dict.')
  if 'model' not in input_dict:
    raise ValueError('Input models are not valid, model '
                     'need to be specified.')

  if standard_component_specs.MODEL_BLESSING_KEY not in input_dict:
    logging.info(
        'Model blessing is not provided, exported model will be '
        'used.')
  else:
    model_blessing = artifact_utils.get_single_instance(
        input_dict[standard_component_specs.MODEL_BLESSING_KEY])
    if not model_utils.is_model_blessed(model_blessing):
      logging.info('Model on %s was not blessed', model_blessing.uri)
      return

  model = artifact_utils.get_single_instance(
      input_dict[standard_component_specs.MODEL_KEY])
  model_path = path_utils.serving_model_path(
      model.uri, path_utils.is_old_model_artifact(model))
  logging.info('Use exported model from %s.', model_path)

  data_spec = bulk_inferrer_pb2.DataSpec()
  proto_utils.json_to_proto(
      exec_properties[standard_component_specs.DATA_SPEC_KEY], data_spec)

  # The output example spec is optional and stays empty when not supplied.
  output_example_spec = bulk_inferrer_pb2.OutputExampleSpec()
  if exec_properties.get(standard_component_specs.OUTPUT_EXAMPLE_SPEC_KEY):
    proto_utils.json_to_proto(
        exec_properties[standard_component_specs.OUTPUT_EXAMPLE_SPEC_KEY],
        output_example_spec)

  self._run_model_inference(
      data_spec, output_example_spec,
      input_dict[standard_component_specs.EXAMPLES_KEY], output_examples,
      inference_result,
      self._get_inference_spec(model_path, exec_properties))
def Do(
    self,
    input_dict: Dict[Text, List[types.Artifact]],
    output_dict: Dict[Text, List[types.Artifact]],
    exec_properties: Dict[Text, Any],
) -> None:
  """Take input data source and generates serialized data splits.

  The output is intended to be serialized tf.train.Examples or
  tf.train.SequenceExamples protocol buffer in gzipped TFRecord format, but
  subclasses can choose to override to write to any serialized records
  payload into gzipped TFRecord as specified, so long as downstream
  component can consume it. The format of payload is added to
  `payload_format` custom property of the output Example artifact.

  Args:
    input_dict: Input dict from input key to a list of Artifacts. Depends on
      detailed example gen implementation.
    output_dict: Output dict from output key to a list of Artifacts.
      - examples: splits of serialized records.
    exec_properties: A dict of execution properties. Depends on detailed
      example gen implementation.
      - input_base: an external directory containing the data files.
      - input_config: JSON string of example_gen_pb2.Input instance,
        providing input configuration.
      - output_config: JSON string of example_gen_pb2.Output instance,
        providing output configuration.
      - output_data_format: Payload format of generated data in output
        artifact, one of example_gen_pb2.PayloadFormat enum.

  Returns:
    None
  """
  self._log_startup(input_dict, output_dict, exec_properties)

  input_config = example_gen_pb2.Input()
  proto_utils.json_to_proto(exec_properties[utils.INPUT_CONFIG_KEY],
                            input_config)
  output_config = example_gen_pb2.Output()
  proto_utils.json_to_proto(exec_properties[utils.OUTPUT_CONFIG_KEY],
                            output_config)

  # Record the generated split names on the single output artifact.
  examples_artifact = artifact_utils.get_single_instance(
      output_dict[utils.EXAMPLES_KEY])
  output_split_names = utils.generate_output_split_names(
      input_config, output_config)
  examples_artifact.split_names = artifact_utils.encode_split_names(
      output_split_names)

  logging.info('Generating examples.')
  with self._make_beam_pipeline() as pipeline:
    example_splits = self.GenerateExamplesByBeam(pipeline, exec_properties)

    # pylint: disable=expression-not-assigned, no-value-for-parameter
    for split_name, example_split in example_splits.items():
      write_label = 'WriteSplit[{}]'.format(split_name)
      _ = (
          example_split
          | write_label >> _WriteSplit(
              artifact_utils.get_split_uri(output_dict[utils.EXAMPLES_KEY],
                                           split_name)))
    # pylint: enable=expression-not-assigned, no-value-for-parameter

  output_payload_format = exec_properties.get(utils.OUTPUT_DATA_FORMAT_KEY)
  if output_payload_format:
    for output_examples_artifact in output_dict[utils.EXAMPLES_KEY]:
      examples_utils.set_payload_format(output_examples_artifact,
                                        output_payload_format)
  logging.info('Examples generated.')
def GenerateExamplesByBeam(
    self,
    pipeline: beam.Pipeline,
    exec_properties: Dict[Text, Any],
) -> Dict[Text, beam.pvalue.PCollection]:
  """Converts input source to serialized record splits based on configs.

  Custom ExampleGen executor should provide GetInputSourceToExamplePTransform
  for converting input split to serialized records. Override this
  'GenerateExamplesByBeam' method instead if complex logic is needed, e.g.,
  custom splitting logic.

  Args:
    pipeline: Beam pipeline.
    exec_properties: A dict of execution properties. Depends on detailed
      example gen implementation. `_beam_pipeline_args` is added to it by
      this method.
      - input_base: an external directory containing the data files.
      - input_config: JSON string of example_gen_pb2.Input instance,
        providing input configuration.
      - output_config: JSON string of example_gen_pb2.Output instance,
        providing output configuration.
      - output_data_format: Payload format of generated data in output
        artifact, one of example_gen_pb2.PayloadFormat enum.

  Returns:
    Dict of beam PCollection with split name as key, each PCollection is a
    single output split that contains serialized records.
  """
  # Get input split information.
  input_config = example_gen_pb2.Input()
  proto_utils.json_to_proto(exec_properties[utils.INPUT_CONFIG_KEY],
                            input_config)
  # Get output split information.
  output_config = example_gen_pb2.Output()
  proto_utils.json_to_proto(exec_properties[utils.OUTPUT_CONFIG_KEY],
                            output_config)
  # Get output split names.
  split_names = utils.generate_output_split_names(input_config,
                                                  output_config)
  # Make beam_pipeline_args available in exec_properties since certain
  # example_gen executors need this information.
  exec_properties['_beam_pipeline_args'] = self._beam_pipeline_args or []

  input_to_record = self.GetInputSourceToExamplePTransform()
  output_splits = output_config.split_config.splits
  if output_splits:
    # Use output splits, input must have only one split.
    assert len(
        input_config.splits
    ) == 1, 'input must have only one split when output split is specified.'
    # Calculate split buckets as running totals of the hash bucket counts.
    buckets = []
    running_total = 0
    for output_split in output_splits:
      running_total += output_split.hash_buckets
      buckets.append(running_total)
    example_splits = (
        pipeline
        | 'InputToRecord' >>
        # pylint: disable=no-value-for-parameter
        input_to_record(exec_properties, input_config.splits[0].pattern)
        | 'SplitData' >> beam.Partition(_PartitionFn, len(buckets), buckets,
                                        output_config.split_config))
  else:
    # Use input splits.
    example_splits = [
        (pipeline
         | 'InputToRecord[{}]'.format(input_split.name) >>
         # pylint: disable=no-value-for-parameter
         input_to_record(exec_properties, input_split.pattern))
        for input_split in input_config.splits
    ]

  # Pair each produced split with its name by position.
  return {
      split_names[index]: example_split
      for index, example_split in enumerate(example_splits)
  }
def testResolveExecProperties(self):
  """Checks span/version resolution of the file based driver.

  Covers mismatched latest span across splits, mismatched latest version
  across splits, selection of the latest aligned span and version, and
  selection driven by a static RangeConfig.
  """
  # Create input dir.
  self._input_base_path = os.path.join(self._test_dir, 'input_base')
  fileio.makedirs(self._input_base_path)

  def write_data_file(span_dir, version_dir, split_dir, content):
    # Writes <input_base>/<span>/<version>/<split>/data with `content`.
    io_utils.write_string_file(
        os.path.join(self._input_base_path, span_dir, version_dir,
                     split_dir, 'data'), content)

  def span_version_input_config_json():
    # Input config whose patterns carry two-digit SPAN/VERSION placeholders.
    return proto_utils.proto_to_json(
        example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(
                name='s1',
                pattern='span{SPAN:2}/version{VERSION:2}/split1/*'),
            example_gen_pb2.Input.Split(
                name='s2',
                pattern='span{SPAN:2}/version{VERSION:2}/split2/*'),
        ]))

  def static_range_json(start_span_number, end_span_number):
    return proto_utils.proto_to_json(
        range_config_pb2.RangeConfig(
            static_range=range_config_pb2.StaticRange(
                start_span_number=start_span_number,
                end_span_number=end_span_number)))

  # Create exec properties.
  self._exec_properties = {
      standard_component_specs.INPUT_BASE_KEY: self._input_base_path,
      standard_component_specs.INPUT_CONFIG_KEY:
          span_version_input_config_json(),
      standard_component_specs.RANGE_CONFIG_KEY: None,
  }

  # Test align of span number.
  write_data_file('span01', 'version01', 'split1', 'testing11')
  write_data_file('span01', 'version01', 'split2', 'testing12')
  write_data_file('span02', 'version01', 'split1', 'testing21')

  # Check that error raised when span does not match.
  with self.assertRaisesRegex(
      ValueError, 'Latest span should be the same for each split'):
    self._file_based_driver.resolve_exec_properties(
        self._exec_properties, None, None)

  write_data_file('span02', 'version01', 'split2', 'testing22')
  write_data_file('span02', 'version02', 'split1', 'testing21')

  # Check that error raised when span matches, but version does not match.
  with self.assertRaisesRegex(
      ValueError, 'Latest version should be the same for each split'):
    self._file_based_driver.resolve_exec_properties(
        self._exec_properties, None, None)

  write_data_file('span02', 'version02', 'split2', 'testing22')

  # Test if latest span and version selected when span and version aligns
  # for each split.
  self._file_based_driver.resolve_exec_properties(self._exec_properties,
                                                  None, None)
  self.assertEqual(self._exec_properties[utils.SPAN_PROPERTY_NAME], 2)
  self.assertEqual(self._exec_properties[utils.VERSION_PROPERTY_NAME], 2)
  self.assertRegex(
      self._exec_properties[utils.FINGERPRINT_PROPERTY_NAME],
      r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*\nsplit:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'
  )
  updated_input_config = example_gen_pb2.Input()
  proto_utils.json_to_proto(
      self._exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
      updated_input_config)

  # Check if latest span is selected.
  self.assertProtoEquals(
      """
      splits {
        name: "s1"
        pattern: "span02/version02/split1/*"
      }
      splits {
        name: "s2"
        pattern: "span02/version02/split2/*"
      }""", updated_input_config)

  # Test driver behavior using RangeConfig with static range. The input
  # config is reset because the previous run rewrote its patterns.
  self._exec_properties[standard_component_specs.INPUT_CONFIG_KEY] = (
      span_version_input_config_json())

  self._exec_properties[standard_component_specs.RANGE_CONFIG_KEY] = (
      static_range_json(start_span_number=1, end_span_number=2))
  with self.assertRaisesRegex(
      ValueError, 'For ExampleGen, start and end span numbers'):
    self._file_based_driver.resolve_exec_properties(
        self._exec_properties, None, None)

  self._exec_properties[standard_component_specs.RANGE_CONFIG_KEY] = (
      static_range_json(start_span_number=1, end_span_number=1))
  self._file_based_driver.resolve_exec_properties(self._exec_properties,
                                                  None, None)
  self.assertEqual(self._exec_properties[utils.SPAN_PROPERTY_NAME], 1)
  self.assertEqual(self._exec_properties[utils.VERSION_PROPERTY_NAME], 1)
  self.assertRegex(
      self._exec_properties[utils.FINGERPRINT_PROPERTY_NAME],
      r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*\nsplit:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'
  )
  updated_input_config = example_gen_pb2.Input()
  proto_utils.json_to_proto(
      self._exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
      updated_input_config)

  # Check if correct span inside static range is selected.
  self.assertProtoEquals(
      """
      splits {
        name: "s1"
        pattern: "span01/version01/split1/*"
      }
      splits {
        name: "s2"
        pattern: "span01/version01/split2/*"
      }""", updated_input_config)
def run_component(full_component_class_name: str,
                  temp_directory_path: Optional[str] = None,
                  beam_pipeline_args: Optional[List[str]] = None,
                  **arguments):
    r"""Loads a component, instantiates it with arguments and runs its executor.

    The component class is instantiated, so the component code is executed,
    not just the executor code.

    To pass artifact URI, use <input_name>_uri (or <input_name>_path) argument
    name.
    To pass artifact property, use <input_name>_<property> argument name.
    Protobuf property values can be passed as JSON-serialized protobufs.
    To set an output artifact URI, use output_<output_name>_uri,
    <output_name>_uri or <output_name>_path argument name.
    To write an output artifact property to a file after execution, use
    <output_name>_<property>_path argument name.

    # pylint: disable=line-too-long

    Example::

      # When run as a script:
      python3 scripts/run_component.py \
        --full-component-class-name tfx.components.StatisticsGen \
        --examples-uri gs://my_bucket/chicago_taxi_simple/CsvExamplesGen/examples/1/ \
        --examples-split-names '["train", "eval"]' \
        --output-uri gs://my_bucket/chicago_taxi_simple/StatisticsGen/output/1/

      # When run as a function:
      run_component(
        full_component_class_name='tfx.components.StatisticsGen',
        examples_uri='gs://my_bucket/chicago_taxi_simple/CsvExamplesGen/examples/1/',
        examples_split_names='["train", "eval"]',
        output_uri='gs://my_bucket/chicago_taxi_simple/StatisticsGen/output/1/',
      )

    Args:
      full_component_class_name: The component class name including module name.
      temp_directory_path: Optional. Temporary directory path for the executor.
      beam_pipeline_args: Optional. Arguments to pass to the Beam pipeline.
      **arguments: Key-value pairs with component arguments.
    """
    component_class = import_utils.import_class_by_path(
        full_component_class_name)

    component_arguments = {}

    # Convert execution parameters from their (string) argument form into the
    # type declared by the component spec.
    for name, execution_param in component_class.SPEC_CLASS.PARAMETERS.items():
        argument_value = arguments.get(name, None)
        if argument_value is None:
            continue
        param_type = execution_param.type
        if (isinstance(param_type, type) and
                issubclass(param_type, message.Message)):
            # Protobuf parameters arrive as JSON-serialized messages.
            argument_value_obj = param_type()
            proto_utils.json_to_proto(argument_value, argument_value_obj)
        elif param_type is int:
            argument_value_obj = int(argument_value)
        elif param_type is float:
            argument_value_obj = float(argument_value)
        else:
            argument_value_obj = argument_value
        component_arguments[name] = argument_value_obj

    # Build one input artifact per input whose URI was supplied.
    for input_name, channel_param in component_class.SPEC_CLASS.INPUTS.items():
        uri = (arguments.get(input_name + '_uri') or
               arguments.get(input_name + '_path'))
        if uri:
            artifact = channel_param.type()
            artifact.uri = uri
            # Setting the artifact properties
            artifact_properties = channel_param.type.PROPERTIES or {}
            for property_name, property_spec in artifact_properties.items():
                property_arg_name = input_name + '_' + property_name
                if property_arg_name in arguments:
                    property_value = arguments[property_arg_name]
                    if property_spec.type == PropertyType.INT:
                        property_value = int(property_value)
                    elif property_spec.type == PropertyType.FLOAT:
                        property_value = float(property_value)
                    setattr(artifact, property_name, property_value)
            component_arguments[input_name] = channel_utils.as_channel(
                [artifact])

    component_instance = component_class(**component_arguments)

    input_dict = channel_utils.unwrap_channel_dict(component_instance.inputs)
    output_dict = channel_utils.unwrap_channel_dict(
        component_instance.outputs)
    exec_properties = component_instance.exec_properties

    # Generating paths for output artifacts
    for output_name, channel_param in (
            component_class.SPEC_CLASS.OUTPUTS.items()):
        uri = (arguments.get('output_' + output_name + '_uri') or
               arguments.get(output_name + '_uri') or
               arguments.get(output_name + '_path'))
        if uri:
            artifacts = output_dict[output_name]
            if not artifacts:
                artifacts.append(channel_param.type())
            for artifact in artifacts:
                artifact.uri = uri

    # Beam-based executors take the pipeline args under a different keyword
    # than plain executors.
    if issubclass(component_instance.executor_spec.executor_class,
                  base_beam_executor.BaseBeamExecutor):
        executor_context = base_beam_executor.BaseBeamExecutor.Context(
            beam_pipeline_args=beam_pipeline_args,
            tmp_dir=temp_directory_path,
            unique_id='',
        )
    else:
        executor_context = base_executor.BaseExecutor.Context(
            extra_flags=beam_pipeline_args,
            tmp_dir=temp_directory_path,
            unique_id='',
        )

    executor = component_instance.executor_spec.executor_class(
        executor_context)
    executor.Do(
        input_dict=input_dict,
        output_dict=output_dict,
        exec_properties=exec_properties,
    )

    # Writing out the output artifact properties
    for output_name, channel_param in (
            component_class.SPEC_CLASS.OUTPUTS.items()):
        for property_name in channel_param.type.PROPERTIES or []:
            property_path_arg_name = (
                output_name + '_' + property_name + '_path')
            property_path = arguments.get(property_path_arg_name)
            if property_path:
                # A bare file name has an empty dirname, and os.makedirs('')
                # raises FileNotFoundError, so only create a directory when
                # the path actually has one.
                parent_dir = os.path.dirname(property_path)
                artifacts = output_dict[output_name]
                for artifact in artifacts:
                    property_value = getattr(artifact, property_name)
                    if parent_dir:
                        os.makedirs(parent_dir, exist_ok=True)
                    with open(property_path, 'w') as f:
                        f.write(str(property_value))
def Do(self, input_dict: Dict[str, List[types.Artifact]],
       output_dict: Dict[str, List[types.Artifact]],
       exec_properties: Dict[str, Any]) -> None:
    """Runs batch inference on a given model with given input examples.

    This function creates a new model (if necessary) and a new model version
    before inference, and cleans up resources after inference. It provides
    re-executability as it cleans up (only) the model resources that are
    created during the process even inference job failed.

    Args:
      input_dict: Input dict from input key to a list of Artifacts.
        - examples: examples for inference.
        - model: exported model.
        - model_blessing: model blessing result
      output_dict: Output dict from output key to a list of Artifacts.
        - output: bulk inference results.
      exec_properties: A dict of execution properties.
        - data_spec: JSON string of bulk_inferrer_pb2.DataSpec instance.
        - custom_config: custom_config.ai_platform_serving_args need to
          contain the serving job parameters sent to Google Cloud AI
          Platform. For the full set of parameters, refer to
          https://cloud.google.com/ml-engine/reference/rest/v1/projects.models

    Returns:
      None

    Raises:
      ValueError: If `examples` or `model` is missing from `input_dict`, if
        the custom config is missing from `exec_properties` or does not
        deserialize to a dict, if `ai_platform_serving_args` is missing from
        the custom config, or if both an endpoint and
        `ai_platform_serving_args.regions` are set.
    """
    self._log_startup(input_dict, output_dict, exec_properties)

    if output_dict.get('inference_result'):
        inference_result = artifact_utils.get_single_instance(
            output_dict['inference_result'])
    else:
        inference_result = None
    if output_dict.get('output_examples'):
        output_examples = artifact_utils.get_single_instance(
            output_dict['output_examples'])
    else:
        output_examples = None

    if 'examples' not in input_dict:
        raise ValueError('`examples` is missing in input dict.')
    if 'model' not in input_dict:
        raise ValueError('Input models are not valid, model '
                         'need to be specified.')
    if 'model_blessing' in input_dict:
        model_blessing = artifact_utils.get_single_instance(
            input_dict['model_blessing'])
        if not model_utils.is_model_blessed(model_blessing):
            logging.info('Model on %s was not blessed', model_blessing.uri)
            return
    else:
        logging.info('Model blessing is not provided, exported model will be '
                     'used.')
    if _CUSTOM_CONFIG_KEY not in exec_properties:
        raise ValueError('Input exec properties are not valid, {} '
                         'need to be specified.'.format(_CUSTOM_CONFIG_KEY))

    custom_config = json_utils.loads(
        exec_properties.get(_CUSTOM_CONFIG_KEY, 'null'))
    # A serialized 'null' deserializes to None; reject it here together with
    # every other non-dict value so the `.get()` calls below cannot fail with
    # an AttributeError.
    if not isinstance(custom_config, dict):
        raise ValueError('custom_config in execution properties needs to be a '
                         'dict.')
    ai_platform_serving_args = custom_config.get(SERVING_ARGS_KEY)
    if not ai_platform_serving_args:
        raise ValueError(
            '`ai_platform_serving_args` is missing in `custom_config`')
    service_name, api_version = runner.get_service_name_and_api_version(
        ai_platform_serving_args)
    executor_class_path = '%s.%s' % (self.__class__.__module__,
                                     self.__class__.__name__)
    with telemetry_utils.scoped_labels(
            {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
        job_labels = telemetry_utils.make_labels_dict()
    model = artifact_utils.get_single_instance(input_dict['model'])
    model_path = path_utils.serving_model_path(
        model.uri, path_utils.is_old_model_artifact(model))
    logging.info('Use exported model from %s.', model_path)
    # Use model artifact uri to generate model version to guarantee the
    # 1:1 mapping from model version to model.
    model_version = 'version_' + hashlib.sha256(
        model.uri.encode()).hexdigest()
    inference_spec = self._get_inference_spec(model_path, model_version,
                                              ai_platform_serving_args)
    data_spec = bulk_inferrer_pb2.DataSpec()
    proto_utils.json_to_proto(exec_properties['data_spec'], data_spec)
    output_example_spec = bulk_inferrer_pb2.OutputExampleSpec()
    if exec_properties.get('output_example_spec'):
        proto_utils.json_to_proto(exec_properties['output_example_spec'],
                                  output_example_spec)
    endpoint = custom_config.get(constants.ENDPOINT_ARGS_KEY)
    if endpoint and 'regions' in ai_platform_serving_args:
        raise ValueError(
            '`endpoint` and `ai_platform_serving_args.regions` cannot be set simultaneously'
        )
    api = discovery.build(
        service_name,
        api_version,
        requestBuilder=telemetry_utils.TFXHttpRequest,
        client_options=client_options.ClientOptions(api_endpoint=endpoint),
    )
    new_model_endpoint_created = False
    try:
        new_model_endpoint_created = (
            runner.create_model_for_aip_prediction_if_not_exist(
                job_labels, ai_platform_serving_args, api))
        runner.deploy_model_for_aip_prediction(
            serving_path=model_path,
            model_version_name=model_version,
            ai_platform_serving_args=ai_platform_serving_args,
            api=api,
            labels=job_labels,
            skip_model_endpoint_creation=True,
            set_default=False,
        )
        self._run_model_inference(data_spec, output_example_spec,
                                  input_dict['examples'], output_examples,
                                  inference_result, inference_spec)
    except Exception as e:
        logging.error('Error in executing CloudAIBulkInferrerComponent: %s',
                      str(e))
        raise
    finally:
        # Guarantee newly created resources are cleaned up even if the
        # inference job failed.

        # Clean up the newly deployed model. The model endpoint itself is
        # only deleted when it was created by this execution.
        runner.delete_model_from_aip_if_exists(
            model_version_name=model_version,
            ai_platform_serving_args=ai_platform_serving_args,
            api=api,
            delete_model_endpoint=new_model_endpoint_created)
def Do(self, input_dict: Dict[Text, List[types.Artifact]],
       output_dict: Dict[Text, List[types.Artifact]],
       exec_properties: Dict[Text, Any]) -> None:
    """Pushes the exported model to its destination when it is blessed.

    Args:
      input_dict: Maps input keys to lists of artifacts:
        - model: the model exported by the trainer.
        - model_blessing: the blessing produced by the model validator. The
          push only happens when the blessing check passes.
      output_dict: Maps output keys to lists of artifacts:
        - pushed_model: a single-element list holding the 'ModelPushPath'
          artifact. It receives a copy of the model when a push happens.
      exec_properties: Execution properties:
        - push_destination: JSON string of a pusher_pb2.PushDestination
          message describing where the model should be pushed.

    Returns:
      None

    Raises:
      NotImplementedError: If the destination is not a filesystem, or the
        requested versioning scheme is not supported.
    """
    self._log_startup(input_dict, output_dict, exec_properties)
    pushed_model = artifact_utils.get_single_instance(
        output_dict[standard_component_specs.PUSHED_MODEL_KEY])

    # Nothing to push for a model that did not pass the blessing check.
    if not self.CheckBlessing(input_dict):
        self._MarkNotPushed(pushed_model)
        return

    exported_model = artifact_utils.get_single_instance(
        input_dict[standard_component_specs.MODEL_KEY])
    source_path = path_utils.serving_model_path(
        exported_model.uri,
        path_utils.is_old_model_artifact(exported_model))

    # The destination is a location a model server can watch.
    #
    # A model that was already copied there is not copied again: the model
    # validator may bless the same model twice (with different blessing
    # outputs), and the Pusher still has to process that output for metadata
    # tracking, but the external copy does not need to be repeated.
    # TODO(jyzhao): support rpc push and verification.
    destination = pusher_pb2.PushDestination()
    proto_utils.json_to_proto(
        exec_properties[standard_component_specs.PUSH_DESTINATION_KEY],
        destination)
    kind = destination.WhichOneof('destination')
    if kind != 'filesystem':
        raise NotImplementedError(
            'Invalid push destination {}'.format(kind))

    filesystem = destination.filesystem
    # AUTO versioning resolves to a UNIX timestamp.
    if filesystem.versioning == _Versioning.AUTO:
        filesystem.versioning = _Versioning.UNIX_TIMESTAMP
    if filesystem.versioning != _Versioning.UNIX_TIMESTAMP:
        raise NotImplementedError(
            'Invalid Versioning {}'.format(filesystem.versioning))
    version = str(int(time.time()))
    logging.info('Model version: %s', version)

    target_path = os.path.join(filesystem.base_directory, version)
    if not fileio.exists(target_path):
        # tf.serving won't load partial model, it will retry until fully
        # copied.
        io_utils.copy_dir(source_path, target_path)
        logging.info('Model written to serving path %s.', target_path)
    else:
        logging.info(
            'Destination directory %s already exists, skipping current push.',
            target_path)

    # Keep an archived copy of the model under the output artifact URI.
    io_utils.copy_dir(source_path, pushed_model.uri)
    self._MarkPushed(
        pushed_model,
        pushed_destination=target_path,
        pushed_version=version)
    logging.info('Model pushed to %s.', pushed_model.uri)
def testDriverRunFn(self):
    """Runs FileBasedDriver end to end on a single-span, two-split input."""
    fingerprint_regex = (
        r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*\nsplit:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*'
    )
    examples_key = standard_component_specs.EXAMPLES_KEY
    input_config_key = standard_component_specs.INPUT_CONFIG_KEY

    # Set up the input directory with one data file per split under span 1.
    self._input_base_path = os.path.join(self._test_dir, 'input_base')
    fileio.makedirs(self._input_base_path)
    for split_dir, content in (('split1', 'testing11'),
                               ('split2', 'testing12')):
        io_utils.write_string_file(
            os.path.join(self._input_base_path, 'span01', split_dir, 'data'),
            content)

    ir_driver = driver.FileBasedDriver(self._mock_metadata)

    # The output artifact carries a preset uri; the driver must keep it.
    example = standard_artifacts.Examples()
    example.uri = 'my_uri'
    output_dic = {examples_key: [example]}

    # Execution properties with a span placeholder in each split pattern.
    input_config = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(
            name='s1', pattern='span{SPAN:2}/split1/*'),
        example_gen_pb2.Input.Split(
            name='s2', pattern='span{SPAN:2}/split2/*'),
    ])
    exec_properties = {
        standard_component_specs.INPUT_BASE_KEY: self._input_base_path,
        input_config_key: proto_utils.proto_to_json(input_config),
    }

    result = ir_driver.run(
        portable_data_types.ExecutionInfo(
            output_dict=output_dic, exec_properties=exec_properties))

    # Check the execution properties returned by the driver.
    exec_properties = result.exec_properties
    self.assertEqual(
        exec_properties[utils.SPAN_PROPERTY_NAME].int_value, 1)
    updated_input_config = example_gen_pb2.Input()
    proto_utils.json_to_proto(
        exec_properties[input_config_key].string_value,
        updated_input_config)
    self.assertProtoEquals(
        ' splits { name: "s1" pattern: "span01/split1/*" }'
        ' splits { name: "s2" pattern: "span01/split2/*" }',
        updated_input_config)
    self.assertRegex(
        exec_properties[utils.FINGERPRINT_PROPERTY_NAME].string_value,
        fingerprint_regex)

    # Check the single output artifact returned by the driver.
    self.assertLen(result.output_artifacts[examples_key].artifacts, 1)
    output_example = result.output_artifacts[examples_key].artifacts[0]
    self.assertEqual(output_example.uri, example.uri)
    self.assertEqual(
        output_example.custom_properties[
            utils.SPAN_PROPERTY_NAME].string_value, '1')
    self.assertRegex(
        output_example.custom_properties[
            utils.FINGERPRINT_PROPERTY_NAME].string_value,
        fingerprint_regex)
def get_common_fn_args(input_dict: Dict[Text, List[types.Artifact]],
                       exec_properties: Dict[Text, Any],
                       working_dir: Text = None) -> FnArgs:
    """Builds the FnArgs shared by training and tuning.

    Args:
      input_dict: Maps input keys to lists of artifacts. The examples entry is
        required; the transform graph and schema entries are optional.
      exec_properties: Execution properties holding the JSON-serialized train
        and eval args and, optionally, a JSON-serialized custom config.
      working_dir: Passed through unchanged as `FnArgs.working_dir`.

    Returns:
      A FnArgs populated with file patterns, step counts, optional schema and
      transform graph paths, a data accessor and the custom config.
    """
    transform_graph_key = standard_component_specs.TRANSFORM_GRAPH_KEY
    schema_key = standard_component_specs.SCHEMA_KEY

    transform_graph_path = (
        artifact_utils.get_single_uri(input_dict[transform_graph_key])
        if input_dict.get(transform_graph_key) else None)
    schema_path = (
        io_utils.get_only_uri_in_dir(
            artifact_utils.get_single_uri(input_dict[schema_key]))
        if input_dict.get(schema_key) else None)

    train_args = trainer_pb2.TrainArgs()
    proto_utils.json_to_proto(
        exec_properties[standard_component_specs.TRAIN_ARGS_KEY], train_args)
    eval_args = trainer_pb2.EvalArgs()
    proto_utils.json_to_proto(
        exec_properties[standard_component_specs.EVAL_ARGS_KEY], eval_args)

    # With no splits configured, training falls back to the `train` split and
    # evaluation falls back to the `eval` split.
    if not train_args.splits:
        train_args.splits.append('train')
        absl.logging.info(
            "Train on the 'train' split when train_args.splits is "
            'not set.')
    if not eval_args.splits:
        eval_args.splits.append('eval')
        absl.logging.info(
            "Evaluate on the 'eval' split when eval_args.splits is "
            'not set.')

    examples = input_dict[standard_component_specs.EXAMPLES_KEY]

    # One file pattern per (split, example uri) pair, in split order.
    train_files = [
        io_utils.all_files_pattern(uri)
        for split in train_args.splits
        for uri in artifact_utils.get_split_uris(examples, split)
    ]
    eval_files = [
        io_utils.all_files_pattern(uri)
        for split in eval_args.splits
        for uri in artifact_utils.get_split_uris(examples, split)
    ]

    tf_dataset_factory = tfxio_utils.get_tf_dataset_factory_from_artifact(
        examples, _TELEMETRY_DESCRIPTORS)
    record_batch_factory = tfxio_utils.get_record_batch_factory_from_artifact(
        examples, _TELEMETRY_DESCRIPTORS)
    data_view_decode_fn = tfxio_utils.get_data_view_decode_fn_from_artifact(
        examples, _TELEMETRY_DESCRIPTORS)
    data_accessor = DataAccessor(
        tf_dataset_factory=tf_dataset_factory,
        record_batch_factory=record_batch_factory,
        data_view_decode_fn=data_view_decode_fn)

    # https://github.com/tensorflow/tfx/issues/45: an unset int proto field
    # reads back as 0, and Tensorflow raises an error if num_steps <= 0, so a
    # step count of 0 is replaced with None.
    train_steps = train_args.num_steps if train_args.num_steps else None
    eval_steps = eval_args.num_steps if eval_args.num_steps else None

    # A missing custom config is read as the JSON literal 'null' rather than
    # '{}', matching the default serialization in the component interface.
    serialized_custom_config = exec_properties.get(
        standard_component_specs.CUSTOM_CONFIG_KEY, 'null')
    custom_config = json_utils.loads(serialized_custom_config)

    return FnArgs(
        working_dir=working_dir,
        train_files=train_files,
        eval_files=eval_files,
        train_steps=train_steps,
        eval_steps=eval_steps,
        schema_path=schema_path,
        transform_graph_path=transform_graph_path,
        data_accessor=data_accessor,
        custom_config=custom_config,
    )
def test_json_to_proto(self):
    """Parses JSON carrying an extra `obsolete_field` into a TestProto."""
    parsed = proto_utils.json_to_proto(
        '{"obsolete_field":2,"string_value":"x"}', foo_pb2.TestProto())
    expected = foo_pb2.TestProto(string_value='x')
    self.assertEqual(parsed, expected)
    # Reading a concrete field confirms the returned object is the specific
    # proto type and not a bare message.Message.
    self.assertEqual(parsed.string_value, 'x')