Example #1
0
    def resolve_input_artifacts(
        self,
        input_channels: Dict[Text, types.Channel],
        exec_properties: Dict[Text, Any],
        driver_args: data_types.DriverArgs,
        pipeline_info: data_types.PipelineInfo,
    ) -> Dict[Text, List[types.Artifact]]:
        """Overrides BaseDriver.resolve_input_artifacts().

        For every input artifact, resolves the span and fingerprint described
        by the split patterns in `exec_properties['input_config']`. It then
        either reuses the newest artifact already registered in metadata with
        the same fingerprint and span, or publishes the input as a new artifact.

        Args:
          input_channels: Input channels keyed by name; unwrapped into artifact
            lists with `channel_utils.unwrap_channel_dict`.
          exec_properties: Execution properties. `'input_config'` is parsed as
            a JSON `example_gen_pb2.Input` and written back (JSON, sorted keys)
            with any `_SPAN_SPEC` placeholder replaced by the resolved span.
          driver_args: Unused.
          pipeline_info: Unused.

        Returns:
          The unwrapped input dict. Each artifact has its `_FINGERPRINT` and
          `_SPAN` custom properties set, and its underlying metadata artifact is
          replaced by the matched or newly published one.

        Raises:
          ValueError: If splits using `_SPAN_SPEC` resolve to different latest
            spans.
        """
        del driver_args  # unused
        del pipeline_info  # unused

        input_config = example_gen_pb2.Input()
        json_format.Parse(exec_properties['input_config'], input_config)

        input_dict = channel_utils.unwrap_channel_dict(input_channels)
        for input_list in input_dict.values():
            for single_input in input_list:
                # Arguments are handed to the logger rather than pre-formatted
                # with '%': formatting is skipped when DEBUG is off, and a
                # tuple-valued argument cannot break the format string.
                absl.logging.debug('Processing input %s.', single_input.uri)
                absl.logging.debug('single_input %s.', single_input)
                absl.logging.debug('single_input.artifact %s.',
                                   single_input.artifact)

                # Set the fingerprint of input.
                split_fingerprints = []
                select_span = None
                for split in input_config.splits:
                    # If SPAN is specified, pipeline will process the latest span, note
                    # that this span number must be the same for all splits and it will
                    # be stored in metadata as the span of input artifact.
                    # NOTE(review): `split.pattern` is rewritten in place on the
                    # shared `input_config`, so if `input_dict` holds more than
                    # one artifact, later artifacts no longer see `_SPAN_SPEC`
                    # and fall back to span '0' — confirm a single input
                    # artifact is expected here.
                    if _SPAN_SPEC in split.pattern:
                        latest_span = self._retrieve_latest_span(
                            single_input.uri, split)
                        if select_span is None:
                            select_span = latest_span
                        if select_span != latest_span:
                            raise ValueError(
                                'Latest span should be the same for each split: %s != %s'
                                % (select_span, latest_span))
                        split.pattern = split.pattern.replace(
                            _SPAN_SPEC, select_span)

                    pattern = os.path.join(single_input.uri, split.pattern)
                    split_fingerprints.append(
                        io_utils.generate_fingerprint(split.name, pattern))
                fingerprint = '\n'.join(split_fingerprints)
                single_input.set_string_custom_property(
                    _FINGERPRINT, fingerprint)
                # Inputs without a span placeholder are recorded as span '0'.
                if select_span is None:
                    select_span = '0'
                single_input.set_string_custom_property(_SPAN, select_span)

                # Artifacts already registered for this URI whose fingerprint
                # and span both match are candidates for reuse.
                matched_artifacts = []
                for artifact in self._metadata_handler.get_artifacts_by_uri(
                        single_input.uri):
                    if (artifact.custom_properties[_FINGERPRINT].string_value
                            == fingerprint) and (
                                artifact.custom_properties[_SPAN].string_value
                                == select_span):
                        matched_artifacts.append(artifact)

                if matched_artifacts:
                    # TODO(b/138845899): consider use span instead of id.
                    # If there are multiple matches, get the latest one for caching.
                    # Using id because spans are the same for matched artifacts.
                    latest_artifact = max(matched_artifacts,
                                          key=lambda artifact: artifact.id)
                    absl.logging.debug('latest_artifact %s.', latest_artifact)
                    absl.logging.debug('type(latest_artifact) %s.',
                                       type(latest_artifact))

                    single_input.set_artifact(latest_artifact)
                else:
                    # TODO(jyzhao): whether driver should be read-only for metadata.
                    [new_artifact] = self._metadata_handler.publish_artifacts(
                        [single_input])  # pylint: disable=unbalanced-tuple-unpacking
                    absl.logging.debug('Registered new input: %s', new_artifact)
                    single_input.set_artifact(new_artifact)

        exec_properties['input_config'] = json_format.MessageToJson(
            input_config, sort_keys=True)
        return input_dict
Example #2
0
def run_component(full_component_class_name: str,
                  temp_directory_path: Optional[str] = None,
                  beam_pipeline_args: Optional[List[str]] = None,
                  **arguments):
    r"""Loads a component, instantiates it with arguments and runs its executor.

    The component class is instantiated, so the component code is executed,
    not just the executor code.

    To pass artifact URI, use <input_name>_uri argument name.
    To pass artifact property, use <input_name>_<property> argument name.
    Protobuf property values can be passed as JSON-serialized protobufs.

    # pylint: disable=line-too-long

    Example::

      # When run as a script:
      python3 scripts/run_component.py \
        --full-component-class-name tfx.components.StatisticsGen \
        --examples-uri gs://my_bucket/chicago_taxi_simple/CsvExamplesGen/examples/1/ \
        --examples-split-names '["train", "eval"]' \
        --output-uri gs://my_bucket/chicago_taxi_simple/StatisticsGen/output/1/

      # When run as a function:
      run_component(
        full_component_class_name='tfx.components.StatisticsGen',
        examples_uri='gs://my_bucket/chicago_taxi_simple/CsvExamplesGen/examples/1/',
        examples_split_names='["train", "eval"]',
        output_uri='gs://my_bucket/chicago_taxi_simple/StatisticsGen/output/1/',
      )

    Args:
      full_component_class_name: The component class name including module name.
      temp_directory_path: Optional. Temporary directory path for the executor.
      beam_pipeline_args: Optional. Arguments to pass to the Beam pipeline.
      **arguments: Key-value pairs with component arguments. Output artifact
        URIs are read from `output_<name>_uri`, `<name>_uri` or `<name>_path`.
        If `<output_name>_<property>_path` is given, that output artifact
        property is written to the named file after the executor runs.
    """
    component_class = import_utils.import_class_by_path(
        full_component_class_name)

    component_arguments = {}

    # Execution parameters: convert the supplied values to the declared type.
    for name, execution_param in component_class.SPEC_CLASS.PARAMETERS.items():
        argument_value = arguments.get(name, None)
        if argument_value is None:
            continue
        param_type = execution_param.type
        if (isinstance(param_type, type)
                and issubclass(param_type, message.Message)):
            argument_value_obj = param_type()
            proto_utils.json_to_proto(argument_value, argument_value_obj)
        elif param_type is int:
            argument_value_obj = int(argument_value)
        elif param_type is float:
            argument_value_obj = float(argument_value)
        else:
            argument_value_obj = argument_value
        component_arguments[name] = argument_value_obj

    # Inputs: build a single artifact for each input whose URI was supplied.
    for input_name, channel_param in component_class.SPEC_CLASS.INPUTS.items():
        uri = (arguments.get(input_name + '_uri')
               or arguments.get(input_name + '_path'))
        if uri:
            artifact = channel_param.type()
            artifact.uri = uri
            # Setting the artifact properties
            for property_name, property_spec in (channel_param.type.PROPERTIES
                                                 or {}).items():
                property_arg_name = input_name + '_' + property_name
                if property_arg_name in arguments:
                    property_value = arguments[property_arg_name]
                    if property_spec.type == PropertyType.INT:
                        property_value = int(property_value)
                    elif property_spec.type == PropertyType.FLOAT:
                        property_value = float(property_value)
                    setattr(artifact, property_name, property_value)
            component_arguments[input_name] = channel_utils.as_channel(
                [artifact])

    component_instance = component_class(**component_arguments)

    input_dict = channel_utils.unwrap_channel_dict(component_instance.inputs)
    output_dict = channel_utils.unwrap_channel_dict(component_instance.outputs)
    exec_properties = component_instance.exec_properties

    # Generating paths for output artifacts
    for output_name, channel_param in component_class.SPEC_CLASS.OUTPUTS.items():
        uri = (arguments.get('output_' + output_name + '_uri')
               or arguments.get(output_name + '_uri')
               or arguments.get(output_name + '_path'))
        if uri:
            artifacts = output_dict[output_name]
            # Create an artifact when the channel came back empty so the
            # executor still receives an output to write to.
            if not artifacts:
                artifacts.append(channel_param.type())
            for artifact in artifacts:
                artifact.uri = uri

    if issubclass(component_instance.executor_spec.executor_class,
                  base_beam_executor.BaseBeamExecutor):
        executor_context = base_beam_executor.BaseBeamExecutor.Context(
            beam_pipeline_args=beam_pipeline_args,
            tmp_dir=temp_directory_path,
            unique_id='',
        )
    else:
        executor_context = base_executor.BaseExecutor.Context(
            extra_flags=beam_pipeline_args,
            tmp_dir=temp_directory_path,
            unique_id='',
        )
    executor = component_instance.executor_spec.executor_class(
        executor_context)
    executor.Do(
        input_dict=input_dict,
        output_dict=output_dict,
        exec_properties=exec_properties,
    )

    # Writing out the output artifact properties
    for output_name, channel_param in component_class.SPEC_CLASS.OUTPUTS.items():
        for property_name in channel_param.type.PROPERTIES or []:
            property_path = arguments.get(
                output_name + '_' + property_name + '_path')
            if not property_path:
                continue
            # os.path.dirname() is '' for a bare file name and
            # os.makedirs('') raises FileNotFoundError, so only create a
            # directory when the path actually names one.
            property_dir = os.path.dirname(property_path)
            for artifact in output_dict[output_name]:
                property_value = getattr(artifact, property_name)
                if property_dir:
                    os.makedirs(property_dir, exist_ok=True)
                # Every artifact writes the same file, so with several
                # artifacts the last one's value is what remains.
                with open(property_path, 'w', encoding='utf-8') as f:
                    f.write(str(property_value))
Example #3
0
def StatisticsGen(
    examples_uri: 'ExamplesUri',
    output_statistics_uri: 'ExampleStatisticsUri',
    schema_uri: 'SchemaUri' = None,
    exclude_splits: str = None,
    beam_pipeline_args: list = None,
) -> NamedTuple('Outputs', [
    ('statistics_uri', 'ExampleStatisticsUri'),
]):
    """Runs the TFX StatisticsGen executor directly on the given artifact URIs.

    Generated wrapper. It builds input artifacts from `<input>_uri` /
    `<input>_path` arguments and recovers split names from the input
    directory's subdirectories. It forwards arguments whose names the
    component spec declares as parameters, parsing proto-typed ones from JSON.
    It then points each output artifact at its `output_<name>_uri` /
    `<name>_path` argument and calls the executor's `Do()`.

    Returns:
      A one-tuple `(output_statistics_uri,)`.
    """
    from tfx.components.statistics_gen.component import StatisticsGen as component_class

    #Generated code
    import os
    import tempfile
    from tensorflow.io import gfile
    from google.protobuf import json_format, message
    from tfx.types import channel_utils, artifact_utils
    from tfx.components.base import base_executor

    # Snapshot of the locals bound so far (parameters plus the imports above);
    # arguments are looked up by name in this dict below.
    arguments = locals().copy()

    component_class_args = {}

    for name, execution_parameter in component_class.SPEC_CLASS.PARAMETERS.items():
        argument_value = arguments.get(name, None)
        if argument_value is None:
            continue
        parameter_type = execution_parameter.type
        if isinstance(parameter_type, type) and issubclass(
                parameter_type, message.Message):
            # Proto-typed parameters arrive as JSON text.
            argument_value_obj = parameter_type()
            json_format.Parse(argument_value, argument_value_obj)
        else:
            argument_value_obj = argument_value
        component_class_args[name] = argument_value_obj

    for name, channel_parameter in component_class.SPEC_CLASS.INPUTS.items():
        artifact_path = arguments.get(name + '_uri') or arguments.get(name + '_path')
        if artifact_path:
            artifact = channel_parameter.type()
            # Some TFX components require that the artifact URIs end with a slash
            artifact.uri = artifact_path.rstrip('/') + '/'
            if channel_parameter.type.PROPERTIES and 'split_names' in channel_parameter.type.PROPERTIES:
                # Recovering splits
                subdirs = gfile.listdir(artifact_path)
                # Workaround for https://github.com/tensorflow/tensorflow/issues/39167
                subdirs = [subdir.rstrip('/') for subdir in subdirs]
                # Strip only a leading 'Split-' prefix; str.replace() would
                # also delete that text from inside a split name.
                split_names = [
                    subdir[len('Split-'):] if subdir.startswith('Split-') else subdir
                    for subdir in subdirs
                ]
                artifact.split_names = artifact_utils.encode_split_names(
                    sorted(split_names))
            component_class_args[name] = channel_utils.as_channel([artifact])

    component_class_instance = component_class(**component_class_args)

    input_dict = channel_utils.unwrap_channel_dict(
        component_class_instance.inputs.get_all())
    output_dict = {}
    exec_properties = component_class_instance.exec_properties

    # Generating paths for output artifacts
    for name, channel in component_class_instance.outputs.items():
        artifact_path = arguments.get('output_' + name + '_uri') or arguments.get(name + '_path')
        if artifact_path:
            artifact = channel.type()
            # Some TFX components require that the artifact URIs end with a slash
            artifact.uri = artifact_path.rstrip('/') + '/'
            artifact_list = [artifact]
            channel._artifacts = artifact_list
            output_dict[name] = artifact_list

    print('component instance: ' + str(component_class_instance))

    executor_context = base_executor.BaseExecutor.Context(
        beam_pipeline_args=arguments.get('beam_pipeline_args'),
        tmp_dir=tempfile.gettempdir(),
        unique_id='tfx_component',
    )
    executor = component_class_instance.executor_spec.executor_class(
        executor_context)
    executor.Do(
        input_dict=input_dict,
        output_dict=output_dict,
        exec_properties=exec_properties,
    )

    return (output_statistics_uri, )
Example #4
0
def unwrap_channel_dict(
    channel_dict: Dict[Text, Channel]) -> Dict[Text, List[types.Artifact]]:
  """Maps each channel in `channel_dict` to its list of artifacts.

  Thin wrapper that delegates to `channel_utils.unwrap_channel_dict`.
  """
  unwrapped = channel_utils.unwrap_channel_dict(channel_dict)
  return unwrapped
Example #5
0
def run_component(
    full_component_class_name: Text,
    temp_directory_path: Text = None,
    beam_pipeline_args: List[Text] = None,
    **arguments
):
  r"""Loads a component, instantiates it with arguments and runs its executor.

  The component class is instantiated, so the component code is executed,
  not just the executor code.

  To pass artifact URI, use <input_name>_uri argument name.
  To pass artifact property, use <input_name>_<property> argument name.
  Protobuf property values can be passed as JSON-serialized protobufs.

  # pylint: disable=line-too-long

  Example::

    # When run as a script:
    python3 scripts/run_component.py \
      --full-component-class-name tfx.components.StatisticsGen \
      --examples-uri gs://my_bucket/chicago_taxi_simple/CsvExamplesGen/examples/1/ \
      --examples-split-names '["train", "eval"]' \
      --output-uri gs://my_bucket/chicago_taxi_simple/StatisticsGen/output/1/

    # When run as a function:
    run_component(
      full_component_class_name='tfx.components.StatisticsGen',
      examples_uri='gs://my_bucket/chicago_taxi_simple/CsvExamplesGen/examples/1/',
      examples_split_names='["train", "eval"]',
      output_uri='gs://my_bucket/chicago_taxi_simple/StatisticsGen/output/1/',
    )

  Args:
    full_component_class_name: The component class name including module name.
    temp_directory_path: Optional. Temporary directory path for the executor.
    beam_pipeline_args: Optional. Arguments to pass to the Beam pipeline.
    **arguments: Key-value pairs with component arguments. Output artifact
      URIs are read from `output_<name>_uri`, `<name>_uri` or `<name>_path`.
  """
  component_class = import_utils.import_class_by_path(full_component_class_name)

  component_arguments = {}

  # Execution parameters; proto-typed ones are parsed from JSON text.
  for name, execution_param in component_class.SPEC_CLASS.PARAMETERS.items():
    argument_value = arguments.get(name, None)
    if argument_value is None:
      continue
    param_type = execution_param.type
    if (isinstance(param_type, type) and
        issubclass(param_type, message.Message)):
      argument_value_obj = param_type()
      json_format.Parse(argument_value, argument_value_obj)
    else:
      argument_value_obj = argument_value
    component_arguments[name] = argument_value_obj

  # Inputs: build a single artifact for each input whose URI was supplied.
  for input_name, channel_param in component_class.SPEC_CLASS.INPUTS.items():
    uri = (arguments.get(input_name + '_uri') or
           arguments.get(input_name + '_path'))
    if uri:
      artifact = channel_param.type()
      artifact.uri = uri
      # Setting the artifact properties. PROPERTIES may be None for artifact
      # types that declare none; iterating None would raise TypeError.
      for property_name in channel_param.type.PROPERTIES or {}:
        property_arg_name = input_name + '_' + property_name
        if property_arg_name in arguments:
          setattr(artifact, property_name, arguments[property_arg_name])
      component_arguments[input_name] = channel_utils.as_channel([artifact])

  component_instance = component_class(**component_arguments)

  input_dict = channel_utils.unwrap_channel_dict(
      component_instance.inputs.get_all())
  output_dict = channel_utils.unwrap_channel_dict(
      component_instance.outputs.get_all())
  exec_properties = component_instance.exec_properties

  # Generating paths for output artifacts
  for output_name, artifacts in output_dict.items():
    uri = (arguments.get('output_' + output_name + '_uri') or
           arguments.get(output_name + '_uri') or
           arguments.get(output_name + '_path'))
    if uri:
      for artifact in artifacts:
        artifact.uri = uri

  executor_context = base_executor.BaseExecutor.Context(
      beam_pipeline_args=beam_pipeline_args,
      tmp_dir=temp_directory_path,
      unique_id='',
  )
  executor = component_instance.executor_spec.executor_class(executor_context)
  executor.Do(
      input_dict=input_dict,
      output_dict=output_dict,
      exec_properties=exec_properties,
  )
Example #6
0
def Trainer(
    examples_uri: 'ExamplesUri',
    schema_uri: 'SchemaUri',
    output_model_uri: 'ModelUri',
    train_args: {'JsonObject': {'data_type': 'proto:tfx.components.trainer.TrainArgs'}},
    eval_args: {'JsonObject': {'data_type': 'proto:tfx.components.trainer.EvalArgs'}},
    transform_graph_uri: 'TransformGraphUri' = None,
    base_model_uri: 'ModelUri' = None,
    hyperparameters_uri: 'HyperParametersUri' = None,
    module_file: str = None,
    run_fn: str = None,
    trainer_fn: str = None,
    custom_config: dict = None,
    beam_pipeline_args: list = None,
) -> NamedTuple('Outputs', [
    ('model_uri', 'ModelUri'),
]):
    """Runs the TFX Trainer executor directly on the given artifact URIs.

    Generated wrapper. It builds input artifacts from `<input>_uri` /
    `<input>_path` arguments and recovers split names from the input
    directory's subdirectories. It forwards arguments whose names the
    component spec declares as parameters, parsing proto-typed ones from JSON.
    It then points output artifacts at `output_<name>_uri` / `<name>_path` and
    calls the executor's `Do()`.

    Returns:
      A one-tuple `(output_model_uri,)`.
    """
    from tfx.components import Trainer as component_class

    #Generated code
    import json
    import os
    import tempfile
    import tensorflow
    from google.protobuf import json_format, message
    from tfx.types import channel_utils, artifact_utils
    from tfx.components.base import base_executor

    # Snapshot of the locals bound so far (parameters plus the imports above);
    # arguments are looked up by name in this dict below.
    arguments = locals().copy()

    component_class_args = {}

    for name, execution_parameter in component_class.SPEC_CLASS.PARAMETERS.items():
        argument_value = arguments.get(name, None)
        if argument_value is None:
            continue
        parameter_type = execution_parameter.type
        if isinstance(parameter_type, type) and issubclass(parameter_type, message.Message):
            # Proto-typed parameters arrive as JSON text.
            argument_value_obj = parameter_type()
            json_format.Parse(argument_value, argument_value_obj)
        else:
            argument_value_obj = argument_value
        component_class_args[name] = argument_value_obj

    for name, channel_parameter in component_class.SPEC_CLASS.INPUTS.items():
        artifact_path = arguments.get(name + '_uri') or arguments.get(name + '_path')
        if artifact_path:
            artifact = channel_parameter.type()
            artifact.uri = artifact_path.rstrip('/') + '/'  # Some TFX components require that the artifact URIs end with a slash
            if channel_parameter.type.PROPERTIES and 'split_names' in channel_parameter.type.PROPERTIES:
                # Recovering splits
                subdirs = tensorflow.io.gfile.listdir(artifact_path)
                # Workaround for https://github.com/tensorflow/tensorflow/issues/39167
                subdirs = [subdir.rstrip('/') for subdir in subdirs]
                # Strip a leading 'Split-' directory prefix when present.
                # Unprefixed names pass through unchanged.
                split_names = [
                    subdir[len('Split-'):] if subdir.startswith('Split-') else subdir
                    for subdir in subdirs
                ]
                artifact.split_names = artifact_utils.encode_split_names(sorted(split_names))
            component_class_args[name] = channel_utils.as_channel([artifact])

    component_class_instance = component_class(**component_class_args)

    input_dict = channel_utils.unwrap_channel_dict(component_class_instance.inputs.get_all())
    output_dict = channel_utils.unwrap_channel_dict(component_class_instance.outputs.get_all())
    exec_properties = component_class_instance.exec_properties

    # Generating paths for output artifacts
    for name, artifacts in output_dict.items():
        base_artifact_path = arguments.get('output_' + name + '_uri') or arguments.get(name + '_path')
        if base_artifact_path:
            # Are there still cases where output channel has multiple artifacts?
            for idx, artifact in enumerate(artifacts):
                subdir = str(idx + 1) if idx > 0 else ''
                # Joining '' gives the first artifact a trailing '/'; any
                # further artifacts get numbered subdirectories ('2', '3', ...).
                artifact.uri = os.path.join(base_artifact_path, subdir)

    print('component instance: ' + str(component_class_instance))

    # Workaround for a TFX+Beam bug to make DataflowRunner work.
    # Remove after the next release that has https://github.com/tensorflow/tfx/commit/ddb01c02426d59e8bd541e3fd3cbaaf68779b2df
    import tfx
    # Append the suffix only once so repeated calls in one process do not
    # keep growing the version string.
    if not tfx.version.__version__.endswith('dev'):
        tfx.version.__version__ += 'dev'

    executor_context = base_executor.BaseExecutor.Context(
        beam_pipeline_args=beam_pipeline_args,
        tmp_dir=tempfile.gettempdir(),
        unique_id='tfx_component',
    )
    executor = component_class_instance.executor_spec.executor_class(executor_context)
    executor.Do(
        input_dict=input_dict,
        output_dict=output_dict,
        exec_properties=exec_properties,
    )

    return (output_model_uri, )
Example #7
0
def Transform(
    examples_path: InputPath('Examples'),
    schema_path: InputPath('Schema'),
    transform_graph_path: OutputPath('TransformGraph'),
    transformed_examples_path: OutputPath('Examples'),
    updated_analyzer_cache_path: OutputPath('TransformCache'),
    analyzer_cache_path: InputPath('TransformCache') = None,
    module_file: str = None,
    preprocessing_fn: str = None,
    force_tf_compat_v1: int = None,
    custom_config: str = None,
    splits_config: {
        'JsonObject': {
            'data_type': 'proto:tfx.components.transform.SplitsConfig'
        }
    } = None,
):
    """Runs the TFX Transform executor directly on the given artifact paths.

    Generated wrapper. It builds input artifacts from `<input>_uri` /
    `<input>_path` arguments and recovers split names from the input
    directory's subdirectories. It forwards arguments whose names the
    component spec declares as parameters, parsing proto-typed ones from JSON.
    It then points each output artifact at its `output_<name>_uri` /
    `<name>_path` argument and calls the executor's `Do()`.
    """
    from tfx.components.transform.component import Transform as component_class

    #Generated code
    import os
    import tempfile
    from tensorflow.io import gfile
    from google.protobuf import json_format, message
    from tfx.types import channel_utils, artifact_utils
    from tfx.components.base import base_executor

    # Snapshot of the locals bound so far (parameters plus the imports above);
    # arguments are looked up by name in this dict below.
    arguments = locals().copy()

    component_class_args = {}

    for name, execution_parameter in component_class.SPEC_CLASS.PARAMETERS.items():
        argument_value = arguments.get(name, None)
        if argument_value is None:
            continue
        parameter_type = execution_parameter.type
        if isinstance(parameter_type, type) and issubclass(
                parameter_type, message.Message):
            # Proto-typed parameters arrive as JSON text.
            argument_value_obj = parameter_type()
            json_format.Parse(argument_value, argument_value_obj)
        else:
            argument_value_obj = argument_value
        component_class_args[name] = argument_value_obj

    for name, channel_parameter in component_class.SPEC_CLASS.INPUTS.items():
        artifact_path = arguments.get(name + '_uri') or arguments.get(name + '_path')
        if artifact_path:
            artifact = channel_parameter.type()
            # Some TFX components require that the artifact URIs end with a slash
            artifact.uri = artifact_path.rstrip('/') + '/'
            if channel_parameter.type.PROPERTIES and 'split_names' in channel_parameter.type.PROPERTIES:
                # Recovering splits
                subdirs = gfile.listdir(artifact_path)
                # Workaround for https://github.com/tensorflow/tensorflow/issues/39167
                subdirs = [subdir.rstrip('/') for subdir in subdirs]
                # Strip only a leading 'Split-' prefix; str.replace() would
                # also delete that text from inside a split name.
                split_names = [
                    subdir[len('Split-'):] if subdir.startswith('Split-') else subdir
                    for subdir in subdirs
                ]
                artifact.split_names = artifact_utils.encode_split_names(
                    sorted(split_names))
            component_class_args[name] = channel_utils.as_channel([artifact])

    component_class_instance = component_class(**component_class_args)

    input_dict = channel_utils.unwrap_channel_dict(
        component_class_instance.inputs.get_all())
    output_dict = {}
    exec_properties = component_class_instance.exec_properties

    # Generating paths for output artifacts
    for name, channel in component_class_instance.outputs.items():
        artifact_path = arguments.get('output_' + name + '_uri') or arguments.get(name + '_path')
        if artifact_path:
            artifact = channel.type()
            # Some TFX components require that the artifact URIs end with a slash
            artifact.uri = artifact_path.rstrip('/') + '/'
            artifact_list = [artifact]
            channel._artifacts = artifact_list
            output_dict[name] = artifact_list

    print('component instance: ' + str(component_class_instance))

    executor_context = base_executor.BaseExecutor.Context(
        # NOTE(review): this function has no `beam_pipeline_args` parameter,
        # so this lookup always yields None.
        beam_pipeline_args=arguments.get('beam_pipeline_args'),
        tmp_dir=tempfile.gettempdir(),
        unique_id='tfx_component',
    )
    executor = component_class_instance.executor_spec.executor_class(
        executor_context)
    executor.Do(
        input_dict=input_dict,
        output_dict=output_dict,
        exec_properties=exec_properties,
    )