def resolve_input_artifacts(
    self,
    input_channels: Dict[Text, types.Channel],
    exec_properties: Dict[Text, Any],
    driver_args: data_types.DriverArgs,
    pipeline_info: data_types.PipelineInfo,
) -> Dict[Text, List[types.Artifact]]:
  """Overrides BaseDriver.resolve_input_artifacts()."""
  del driver_args  # unused
  del pipeline_info  # unused

  input_config = example_gen_pb2.Input()
  json_format.Parse(exec_properties['input_config'], input_config)

  input_dict = channel_utils.unwrap_channel_dict(input_channels)
  for input_list in input_dict.values():
    for single_input in input_list:
      absl.logging.debug('Processing input %s.' % single_input.uri)
      absl.logging.debug('single_input %s.' % single_input)
      absl.logging.debug('single_input.artifact %s.' % single_input.artifact)

      # Set the fingerprint of input.
      split_fingerprints = []
      select_span = None
      for split in input_config.splits:
        # If SPAN is specified, pipeline will process the latest span, note
        # that this span number must be the same for all splits and it will
        # be stored in metadata as the span of input artifact.
        if _SPAN_SPEC in split.pattern:
          latest_span = self._retrieve_latest_span(single_input.uri, split)
          if select_span is None:
            select_span = latest_span
          if select_span != latest_span:
            raise ValueError(
                'Latest span should be the same for each split: %s != %s' %
                (select_span, latest_span))
          split.pattern = split.pattern.replace(_SPAN_SPEC, select_span)

        pattern = os.path.join(single_input.uri, split.pattern)
        split_fingerprints.append(
            io_utils.generate_fingerprint(split.name, pattern))
      fingerprint = '\n'.join(split_fingerprints)
      single_input.set_string_custom_property(_FINGERPRINT, fingerprint)
      if select_span is None:
        select_span = '0'
      single_input.set_string_custom_property(_SPAN, select_span)

      matched_artifacts = []
      for artifact in self._metadata_handler.get_artifacts_by_uri(
          single_input.uri):
        if (artifact.custom_properties[_FINGERPRINT].string_value ==
            fingerprint) and (
                artifact.custom_properties[_SPAN].string_value == select_span):
          matched_artifacts.append(artifact)

      if matched_artifacts:
        # TODO(b/138845899): consider use span instead of id.
        # If there are multiple matches, get the latest one for caching.
        # Using id because spans are the same for matched artifacts.
        latest_artifact = max(
            matched_artifacts, key=lambda artifact: artifact.id)
        absl.logging.debug('latest_artifact %s.' % latest_artifact)
        absl.logging.debug('type(latest_artifact) %s.' % type(latest_artifact))
        single_input.set_artifact(latest_artifact)
      else:
        # TODO(jyzhao): whether driver should be read-only for metadata.
        [new_artifact] = self._metadata_handler.publish_artifacts(
            [single_input])  # pylint: disable=unbalanced-tuple-unpacking
        absl.logging.debug('Registered new input: %s' % new_artifact)
        single_input.set_artifact(new_artifact)

  exec_properties['input_config'] = json_format.MessageToJson(
      input_config, sort_keys=True)
  return input_dict

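# A minimal sketch of the serialized input_config that this driver parses,
# assuming the standard tfx.proto.example_gen_pb2 definitions. The '{SPAN}'
# placeholder in a split pattern is what _SPAN_SPEC matches above; the driver
# resolves it to the latest span before fingerprinting. Bucket and directory
# names here are hypothetical.
from google.protobuf import json_format
from tfx.proto import example_gen_pb2

input_config = example_gen_pb2.Input(splits=[
    example_gen_pb2.Input.Split(name='train', pattern='span-{SPAN}/train/*'),
    example_gen_pb2.Input.Split(name='eval', pattern='span-{SPAN}/eval/*'),
])
exec_properties = {
    'input_config': json_format.MessageToJson(input_config, sort_keys=True),
}
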
def run_component(full_component_class_name: str,
                  temp_directory_path: Optional[str] = None,
                  beam_pipeline_args: Optional[List[str]] = None,
                  **arguments):
  r"""Loads a component, instantiates it with arguments and runs its executor.

  The component class is instantiated, so the component code is executed,
  not just the executor code.

  To pass an artifact URI, use the <input_name>_uri argument name.
  To pass an artifact property, use the <input_name>_<property> argument name.
  Protobuf property values can be passed as JSON-serialized protobufs.

  # pylint: disable=line-too-long
  Example::

    # When run as a script:
    python3 scripts/run_component.py \
      --full-component-class-name tfx.components.StatisticsGen \
      --examples-uri gs://my_bucket/chicago_taxi_simple/CsvExamplesGen/examples/1/ \
      --examples-split-names '["train", "eval"]' \
      --output-uri gs://my_bucket/chicago_taxi_simple/StatisticsGen/output/1/

    # When run as a function:
    run_component(
        full_component_class_name='tfx.components.StatisticsGen',
        examples_uri='gs://my_bucket/chicago_taxi_simple/CsvExamplesGen/examples/1/',
        examples_split_names='["train", "eval"]',
        output_uri='gs://my_bucket/chicago_taxi_simple/StatisticsGen/output/1/',
    )

  Args:
    full_component_class_name: The component class name including module name.
    temp_directory_path: Optional. Temporary directory path for the executor.
    beam_pipeline_args: Optional. Arguments to pass to the Beam pipeline.
    **arguments: Key-value pairs with component arguments.
  """
  component_class = import_utils.import_class_by_path(
      full_component_class_name)

  component_arguments = {}

  for name, execution_param in component_class.SPEC_CLASS.PARAMETERS.items():
    argument_value = arguments.get(name, None)
    if argument_value is None:
      continue
    param_type = execution_param.type
    if (isinstance(param_type, type) and
        issubclass(param_type, message.Message)):
      argument_value_obj = param_type()
      proto_utils.json_to_proto(argument_value, argument_value_obj)
    elif param_type is int:
      argument_value_obj = int(argument_value)
    elif param_type is float:
      argument_value_obj = float(argument_value)
    else:
      argument_value_obj = argument_value
    component_arguments[name] = argument_value_obj

  for input_name, channel_param in component_class.SPEC_CLASS.INPUTS.items():
    uri = (arguments.get(input_name + '_uri') or
           arguments.get(input_name + '_path'))
    if uri:
      artifact = channel_param.type()
      artifact.uri = uri
      # Setting the artifact properties.
      for property_name, property_spec in (channel_param.type.PROPERTIES or
                                           {}).items():
        property_arg_name = input_name + '_' + property_name
        if property_arg_name in arguments:
          property_value = arguments[property_arg_name]
          if property_spec.type == PropertyType.INT:
            property_value = int(property_value)
          if property_spec.type == PropertyType.FLOAT:
            property_value = float(property_value)
          setattr(artifact, property_name, property_value)
      component_arguments[input_name] = channel_utils.as_channel([artifact])

  component_instance = component_class(**component_arguments)

  input_dict = channel_utils.unwrap_channel_dict(component_instance.inputs)
  output_dict = channel_utils.unwrap_channel_dict(component_instance.outputs)
  exec_properties = component_instance.exec_properties

  # Generating paths for output artifacts.
  for output_name, channel_param in component_class.SPEC_CLASS.OUTPUTS.items():
    uri = (arguments.get('output_' + output_name + '_uri') or
           arguments.get(output_name + '_uri') or
           arguments.get(output_name + '_path'))
    if uri:
      artifacts = output_dict[output_name]
      if not artifacts:
        artifacts.append(channel_param.type())
      for artifact in artifacts:
        artifact.uri = uri

  if issubclass(component_instance.executor_spec.executor_class,
                base_beam_executor.BaseBeamExecutor):
    executor_context = base_beam_executor.BaseBeamExecutor.Context(
        beam_pipeline_args=beam_pipeline_args,
        tmp_dir=temp_directory_path,
        unique_id='',
    )
  else:
    executor_context = base_executor.BaseExecutor.Context(
        extra_flags=beam_pipeline_args,
        tmp_dir=temp_directory_path,
        unique_id='',
    )
  executor = component_instance.executor_spec.executor_class(executor_context)
  executor.Do(
      input_dict=input_dict,
      output_dict=output_dict,
      exec_properties=exec_properties,
  )

  # Writing out the output artifact properties.
  for output_name, channel_param in component_class.SPEC_CLASS.OUTPUTS.items():
    for property_name in channel_param.type.PROPERTIES or []:
      property_path_arg_name = output_name + '_' + property_name + '_path'
      property_path = arguments.get(property_path_arg_name)
      if property_path:
        artifacts = output_dict[output_name]
        for artifact in artifacts:
          property_value = getattr(artifact, property_name)
          os.makedirs(os.path.dirname(property_path), exist_ok=True)
          with open(property_path, 'w') as f:
            f.write(str(property_value))

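# A hedged usage sketch, not part of the module above: besides '<input>_uri'
# and '<input>_<property>' arguments, a '<output>_<property>_path' keyword asks
# run_component to write that output artifact property to a local file after
# the executor finishes. The URIs, the '/tmp/...' path, the output name
# 'output' (mirroring the docstring example), and the assumption that the
# input Examples artifact exposes an integer 'span' property and the output
# exposes 'split_names' are all illustrative.
run_component(
    full_component_class_name='tfx.components.StatisticsGen',
    examples_uri='gs://my_bucket/chicago_taxi_simple/CsvExamplesGen/examples/1/',
    examples_split_names='["train", "eval"]',
    examples_span='2',  # Coerced to int when the property type is PropertyType.INT.
    output_uri='gs://my_bucket/chicago_taxi_simple/StatisticsGen/output/1/',
    output_split_names_path='/tmp/statistics_split_names.txt',
)
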
def StatisticsGen(
    examples_uri: 'ExamplesUri',
    output_statistics_uri: 'ExampleStatisticsUri',
    schema_uri: 'SchemaUri' = None,
    exclude_splits: str = None,
    beam_pipeline_args: list = None,
) -> NamedTuple('Outputs', [
    ('statistics_uri', 'ExampleStatisticsUri'),
]):
  from tfx.components.statistics_gen.component import StatisticsGen as component_class

  # Generated code
  import os
  import tempfile
  from tensorflow.io import gfile
  from google.protobuf import json_format, message
  from tfx.types import channel_utils, artifact_utils
  from tfx.components.base import base_executor

  arguments = locals().copy()

  component_class_args = {}

  for name, execution_parameter in component_class.SPEC_CLASS.PARAMETERS.items():
    argument_value = arguments.get(name, None)
    if argument_value is None:
      continue
    parameter_type = execution_parameter.type
    if isinstance(parameter_type, type) and issubclass(parameter_type,
                                                       message.Message):
      argument_value_obj = parameter_type()
      json_format.Parse(argument_value, argument_value_obj)
    else:
      argument_value_obj = argument_value
    component_class_args[name] = argument_value_obj

  for name, channel_parameter in component_class.SPEC_CLASS.INPUTS.items():
    artifact_path = arguments.get(name + '_uri') or arguments.get(name + '_path')
    if artifact_path:
      artifact = channel_parameter.type()
      # Some TFX components require that the artifact URIs end with a slash.
      artifact.uri = artifact_path.rstrip('/') + '/'
      if channel_parameter.type.PROPERTIES and 'split_names' in channel_parameter.type.PROPERTIES:
        # Recovering splits.
        subdirs = gfile.listdir(artifact_path)
        # Workaround for https://github.com/tensorflow/tensorflow/issues/39167
        subdirs = [subdir.rstrip('/') for subdir in subdirs]
        split_names = [subdir.replace('Split-', '') for subdir in subdirs]
        artifact.split_names = artifact_utils.encode_split_names(
            sorted(split_names))
      component_class_args[name] = channel_utils.as_channel([artifact])

  component_class_instance = component_class(**component_class_args)

  input_dict = channel_utils.unwrap_channel_dict(
      component_class_instance.inputs.get_all())
  output_dict = {}
  exec_properties = component_class_instance.exec_properties

  # Generating paths for output artifacts.
  for name, channel in component_class_instance.outputs.items():
    artifact_path = arguments.get('output_' + name + '_uri') or arguments.get(
        name + '_path')
    if artifact_path:
      artifact = channel.type()
      # Some TFX components require that the artifact URIs end with a slash.
      artifact.uri = artifact_path.rstrip('/') + '/'
      artifact_list = [artifact]
      channel._artifacts = artifact_list
      output_dict[name] = artifact_list

  print('component instance: ' + str(component_class_instance))

  executor_context = base_executor.BaseExecutor.Context(
      beam_pipeline_args=arguments.get('beam_pipeline_args'),
      tmp_dir=tempfile.gettempdir(),
      unique_id='tfx_component',
  )
  executor = component_class_instance.executor_spec.executor_class(
      executor_context)
  executor.Do(
      input_dict=input_dict,
      output_dict=output_dict,
      exec_properties=exec_properties,
  )

  return (output_statistics_uri, )

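# A small illustration of the split-recovery step above, assuming an examples
# directory laid out with the newer 'Split-<name>' subdirectories.
# artifact_utils.encode_split_names stores the list as a JSON string on the
# artifact; the directory names here are hypothetical.
from tfx.types import artifact_utils

subdirs = ['Split-eval/', 'Split-train/']  # as returned by gfile.listdir
subdirs = [subdir.rstrip('/') for subdir in subdirs]
split_names = [subdir.replace('Split-', '') for subdir in subdirs]
print(artifact_utils.encode_split_names(sorted(split_names)))  # '["eval", "train"]'
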
def unwrap_channel_dict(
    channel_dict: Dict[Text, Channel]) -> Dict[Text, List[types.Artifact]]:
  """Unwraps a dict of Channels into a dict of their artifact lists."""
  return channel_utils.unwrap_channel_dict(channel_dict)

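# A hedged sketch of the shape this helper produces, assuming standard TFX
# Examples artifacts; the URI is hypothetical. Each Channel is unwrapped into
# the plain list of artifacts it currently contains, keyed by the same name.
from tfx.types import channel_utils, standard_artifacts

examples = standard_artifacts.Examples()
examples.uri = '/tmp/pipeline/CsvExampleGen/examples/1/'
channel_dict = {'examples': channel_utils.as_channel([examples])}
artifact_dict = channel_utils.unwrap_channel_dict(channel_dict)
# artifact_dict == {'examples': [examples]}
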
def run_component(
    full_component_class_name: Text,
    temp_directory_path: Text = None,
    beam_pipeline_args: List[Text] = None,
    **arguments
):
  r"""Loads a component, instantiates it with arguments and runs its executor.

  The component class is instantiated, so the component code is executed,
  not just the executor code.

  To pass an artifact URI, use the <input_name>_uri argument name.
  To pass an artifact property, use the <input_name>_<property> argument name.
  Protobuf property values can be passed as JSON-serialized protobufs.

  # pylint: disable=line-too-long
  Example::

    # When run as a script:
    python3 scripts/run_component.py \
      --full-component-class-name tfx.components.StatisticsGen \
      --examples-uri gs://my_bucket/chicago_taxi_simple/CsvExamplesGen/examples/1/ \
      --examples-split-names '["train", "eval"]' \
      --output-uri gs://my_bucket/chicago_taxi_simple/StatisticsGen/output/1/

    # When run as a function:
    run_component(
        full_component_class_name='tfx.components.StatisticsGen',
        examples_uri='gs://my_bucket/chicago_taxi_simple/CsvExamplesGen/examples/1/',
        examples_split_names='["train", "eval"]',
        output_uri='gs://my_bucket/chicago_taxi_simple/StatisticsGen/output/1/',
    )

  Args:
    full_component_class_name: The component class name including module name.
    temp_directory_path: Optional. Temporary directory path for the executor.
    beam_pipeline_args: Optional. Arguments to pass to the Beam pipeline.
    **arguments: Key-value pairs with component arguments.
  """
  component_class = import_utils.import_class_by_path(full_component_class_name)

  component_arguments = {}

  for name, execution_param in component_class.SPEC_CLASS.PARAMETERS.items():
    argument_value = arguments.get(name, None)
    if argument_value is None:
      continue
    param_type = execution_param.type
    if (isinstance(param_type, type) and
        issubclass(param_type, message.Message)):
      argument_value_obj = param_type()
      json_format.Parse(argument_value, argument_value_obj)
    else:
      argument_value_obj = argument_value
    component_arguments[name] = argument_value_obj

  for input_name, channel_param in component_class.SPEC_CLASS.INPUTS.items():
    uri = (arguments.get(input_name + '_uri') or
           arguments.get(input_name + '_path'))
    if uri:
      artifact = channel_param.type()
      artifact.uri = uri
      # Setting the artifact properties.
      for property_name in channel_param.type.PROPERTIES:
        property_arg_name = input_name + '_' + property_name
        if property_arg_name in arguments:
          setattr(artifact, property_name, arguments[property_arg_name])
      component_arguments[input_name] = channel_utils.as_channel([artifact])

  component_instance = component_class(**component_arguments)

  input_dict = channel_utils.unwrap_channel_dict(
      component_instance.inputs.get_all())
  output_dict = channel_utils.unwrap_channel_dict(
      component_instance.outputs.get_all())
  exec_properties = component_instance.exec_properties

  # Generating paths for output artifacts.
  for output_name, artifacts in output_dict.items():
    uri = (arguments.get('output_' + output_name + '_uri') or
           arguments.get(output_name + '_uri') or
           arguments.get(output_name + '_path'))
    if uri:
      for artifact in artifacts:
        artifact.uri = uri

  executor_context = base_executor.BaseExecutor.Context(
      beam_pipeline_args=beam_pipeline_args,
      tmp_dir=temp_directory_path,
      unique_id='',
  )
  executor = component_instance.executor_spec.executor_class(executor_context)
  executor.Do(
      input_dict=input_dict,
      output_dict=output_dict,
      exec_properties=exec_properties,
  )

def Trainer(
    examples_uri: 'ExamplesUri',
    schema_uri: 'SchemaUri',
    output_model_uri: 'ModelUri',
    train_args: {'JsonObject': {'data_type': 'proto:tfx.components.trainer.TrainArgs'}},
    eval_args: {'JsonObject': {'data_type': 'proto:tfx.components.trainer.EvalArgs'}},
    transform_graph_uri: 'TransformGraphUri' = None,
    base_model_uri: 'ModelUri' = None,
    hyperparameters_uri: 'HyperParametersUri' = None,
    module_file: str = None,
    run_fn: str = None,
    trainer_fn: str = None,
    custom_config: dict = None,
    beam_pipeline_args: list = None,
) -> NamedTuple('Outputs', [
    ('model_uri', 'ModelUri'),
]):
  from tfx.components import Trainer as component_class

  # Generated code
  import json
  import os
  import tempfile
  import tensorflow
  from google.protobuf import json_format, message
  from tfx.types import channel_utils, artifact_utils
  from tfx.components.base import base_executor

  arguments = locals().copy()

  component_class_args = {}

  for name, execution_parameter in component_class.SPEC_CLASS.PARAMETERS.items():
    argument_value = arguments.get(name, None)
    if argument_value is None:
      continue
    parameter_type = execution_parameter.type
    if isinstance(parameter_type, type) and issubclass(parameter_type,
                                                       message.Message):
      argument_value_obj = parameter_type()
      json_format.Parse(argument_value, argument_value_obj)
    else:
      argument_value_obj = argument_value
    component_class_args[name] = argument_value_obj

  for name, channel_parameter in component_class.SPEC_CLASS.INPUTS.items():
    artifact_path = arguments.get(name + '_uri') or arguments.get(name + '_path')
    if artifact_path:
      artifact = channel_parameter.type()
      # Some TFX components require that the artifact URIs end with a slash.
      artifact.uri = artifact_path.rstrip('/') + '/'
      if channel_parameter.type.PROPERTIES and 'split_names' in channel_parameter.type.PROPERTIES:
        # Recovering splits.
        subdirs = tensorflow.io.gfile.listdir(artifact_path)
        # Workaround for https://github.com/tensorflow/tensorflow/issues/39167
        subdirs = [subdir.rstrip('/') for subdir in subdirs]
        artifact.split_names = artifact_utils.encode_split_names(sorted(subdirs))
      component_class_args[name] = channel_utils.as_channel([artifact])

  component_class_instance = component_class(**component_class_args)

  input_dict = channel_utils.unwrap_channel_dict(
      component_class_instance.inputs.get_all())
  output_dict = channel_utils.unwrap_channel_dict(
      component_class_instance.outputs.get_all())
  exec_properties = component_class_instance.exec_properties

  # Generating paths for output artifacts.
  for name, artifacts in output_dict.items():
    base_artifact_path = arguments.get('output_' + name + '_uri') or arguments.get(
        name + '_path')
    if base_artifact_path:
      # Are there still cases where output channel has multiple artifacts?
      for idx, artifact in enumerate(artifacts):
        subdir = str(idx + 1) if idx > 0 else ''
        artifact.uri = os.path.join(base_artifact_path, subdir)  # Ends with '/'

  print('component instance: ' + str(component_class_instance))

  # Workaround for a TFX+Beam bug to make DataflowRunner work.
  # Remove after the next release that has https://github.com/tensorflow/tfx/commit/ddb01c02426d59e8bd541e3fd3cbaaf68779b2df
  import tfx
  tfx.version.__version__ += 'dev'

  executor_context = base_executor.BaseExecutor.Context(
      beam_pipeline_args=beam_pipeline_args,
      tmp_dir=tempfile.gettempdir(),
      unique_id='tfx_component',
  )
  executor = component_class_instance.executor_spec.executor_class(
      executor_context)
  executor.Do(
      input_dict=input_dict,
      output_dict=output_dict,
      exec_properties=exec_properties,
  )

  return (output_model_uri, )

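# A hedged usage sketch of the wrapper above: train_args and eval_args are
# passed as JSON-serialized tfx.components.trainer.TrainArgs/EvalArgs protos,
# which json_format.Parse turns back into messages. The URIs, module file, and
# step counts are hypothetical.
model_uri, = Trainer(
    examples_uri='gs://my_bucket/pipeline/CsvExampleGen/examples/1/',
    schema_uri='gs://my_bucket/pipeline/SchemaGen/schema/1/',
    output_model_uri='gs://my_bucket/pipeline/Trainer/model/1/',
    train_args='{"num_steps": 10000}',
    eval_args='{"num_steps": 1000}',
    module_file='gs://my_bucket/pipeline/modules/trainer_module.py',
)
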
def Transform(
    examples_path: InputPath('Examples'),
    schema_path: InputPath('Schema'),
    transform_graph_path: OutputPath('TransformGraph'),
    transformed_examples_path: OutputPath('Examples'),
    updated_analyzer_cache_path: OutputPath('TransformCache'),
    analyzer_cache_path: InputPath('TransformCache') = None,
    module_file: str = None,
    preprocessing_fn: str = None,
    force_tf_compat_v1: int = None,
    custom_config: str = None,
    splits_config: {'JsonObject': {'data_type': 'proto:tfx.components.transform.SplitsConfig'}} = None,
):
  from tfx.components.transform.component import Transform as component_class

  # Generated code
  import os
  import tempfile
  from tensorflow.io import gfile
  from google.protobuf import json_format, message
  from tfx.types import channel_utils, artifact_utils
  from tfx.components.base import base_executor

  arguments = locals().copy()

  component_class_args = {}

  for name, execution_parameter in component_class.SPEC_CLASS.PARAMETERS.items():
    argument_value = arguments.get(name, None)
    if argument_value is None:
      continue
    parameter_type = execution_parameter.type
    if isinstance(parameter_type, type) and issubclass(parameter_type,
                                                       message.Message):
      argument_value_obj = parameter_type()
      json_format.Parse(argument_value, argument_value_obj)
    else:
      argument_value_obj = argument_value
    component_class_args[name] = argument_value_obj

  for name, channel_parameter in component_class.SPEC_CLASS.INPUTS.items():
    artifact_path = arguments.get(name + '_uri') or arguments.get(name + '_path')
    if artifact_path:
      artifact = channel_parameter.type()
      # Some TFX components require that the artifact URIs end with a slash.
      artifact.uri = artifact_path.rstrip('/') + '/'
      if channel_parameter.type.PROPERTIES and 'split_names' in channel_parameter.type.PROPERTIES:
        # Recovering splits.
        subdirs = gfile.listdir(artifact_path)
        # Workaround for https://github.com/tensorflow/tensorflow/issues/39167
        subdirs = [subdir.rstrip('/') for subdir in subdirs]
        split_names = [subdir.replace('Split-', '') for subdir in subdirs]
        artifact.split_names = artifact_utils.encode_split_names(
            sorted(split_names))
      component_class_args[name] = channel_utils.as_channel([artifact])

  component_class_instance = component_class(**component_class_args)

  input_dict = channel_utils.unwrap_channel_dict(
      component_class_instance.inputs.get_all())
  output_dict = {}
  exec_properties = component_class_instance.exec_properties

  # Generating paths for output artifacts.
  for name, channel in component_class_instance.outputs.items():
    artifact_path = arguments.get('output_' + name + '_uri') or arguments.get(
        name + '_path')
    if artifact_path:
      artifact = channel.type()
      # Some TFX components require that the artifact URIs end with a slash.
      artifact.uri = artifact_path.rstrip('/') + '/'
      artifact_list = [artifact]
      channel._artifacts = artifact_list
      output_dict[name] = artifact_list

  print('component instance: ' + str(component_class_instance))

  executor_context = base_executor.BaseExecutor.Context(
      beam_pipeline_args=arguments.get('beam_pipeline_args'),
      tmp_dir=tempfile.gettempdir(),
      unique_id='tfx_component',
  )
  executor = component_class_instance.executor_spec.executor_class(
      executor_context)
  executor.Do(
      input_dict=input_dict,
      output_dict=output_dict,
      exec_properties=exec_properties,
  )

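# A hedged usage sketch of the wrapper above, with hypothetical local paths.
# splits_config is assumed to be a JSON-serialized
# tfx.components.transform.SplitsConfig proto; here 'train' is analyzed and
# both splits are transformed.
Transform(
    examples_path='/tmp/pipeline/CsvExampleGen/examples/1',
    schema_path='/tmp/pipeline/SchemaGen/schema/1',
    transform_graph_path='/tmp/pipeline/Transform/transform_graph/1',
    transformed_examples_path='/tmp/pipeline/Transform/transformed_examples/1',
    updated_analyzer_cache_path='/tmp/pipeline/Transform/updated_analyzer_cache/1',
    module_file='/tmp/modules/preprocessing.py',
    splits_config='{"analyze": ["train"], "transform": ["train", "eval"]}',
)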