Code Example #1
def create_pipeline(
    pipeline_name: Text,
    pipeline_root: Text,
    module_file: Text,
    ai_platform_training_args: Dict[Text, Text],
    ai_platform_serving_args: Dict[Text, Text],
    beam_pipeline_args: List[Text],
) -> pipeline.Pipeline:
    """Implements the chicago taxi pipeline with TFX and Kubeflow Pipelines.

  Args:
    pipeline_name: name of the TFX pipeline being created.
    pipeline_root: root directory of the pipeline. Should be a valid GCS path.
    module_file: uri of the module files used in Trainer and Transform
      components.
    ai_platform_training_args: Args of CAIP training job. Please refer to
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#Job
      for detailed description.
    ai_platform_serving_args: Args of CAIP model deployment. Please refer to
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.models
      for detailed description.
    beam_pipeline_args: List of beam pipeline options. Please refer to
      https://cloud.google.com/dataflow/docs/guides/specifying-exec-params#setting-other-cloud-dataflow-pipeline-options.

  Returns:
    A TFX pipeline object.
  """

    # The rate at which to sample rows from the Taxi dataset using BigQuery.
    # The full taxi dataset is > 200M records.  In the interest of resource
    # savings and time, we've set the default for this example to be much smaller.
    # Feel free to crank it up and process the full dataset!
    # By default it generates a 0.1% random sample.
    query_sample_rate = data_types.RuntimeParameter(name='query_sample_rate',
                                                    ptype=float,
                                                    default=0.001)

    # This is the upper bound of FARM_FINGERPRINT in BigQuery (i.e. the max
    # value of a signed int64).
    max_int64 = '0x7FFFFFFFFFFFFFFF'

    # The query that extracts the examples from BigQuery. The Chicago Taxi dataset
    # used for this example is a public BigQuery dataset available through the
    # Google Cloud Marketplace.
    # https://console.cloud.google.com/marketplace/details/city-of-chicago-public-data/chicago-taxi-trips
    query = """
          SELECT
            pickup_community_area,
            fare,
            EXTRACT(MONTH FROM trip_start_timestamp) AS trip_start_month,
            EXTRACT(HOUR FROM trip_start_timestamp) AS trip_start_hour,
            EXTRACT(DAYOFWEEK FROM trip_start_timestamp) AS trip_start_day,
            UNIX_SECONDS(trip_start_timestamp) AS trip_start_timestamp,
            pickup_latitude,
            pickup_longitude,
            dropoff_latitude,
            dropoff_longitude,
            trip_miles,
            pickup_census_tract,
            dropoff_census_tract,
            payment_type,
            company,
            trip_seconds,
            dropoff_community_area,
            tips
          FROM `bigquery-public-data.chicago_taxi_trips.taxi_trips`
          WHERE (ABS(FARM_FINGERPRINT(unique_key)) / {max_int64})
            < {query_sample_rate}""".format(
        max_int64=max_int64, query_sample_rate=str(query_sample_rate))

    # Number of training steps.
    train_steps = data_types.RuntimeParameter(
        name='train_steps',
        default=10000,
        ptype=int,
    )

    # Number of evaluation steps.
    eval_steps = data_types.RuntimeParameter(
        name='eval_steps',
        default=5000,
        ptype=int,
    )

    # Brings data into the pipeline or otherwise joins/converts training data.
    example_gen = big_query_example_gen_component.BigQueryExampleGen(
        query=query)

    # Computes statistics over data for visualization and example validation.
    statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

    # Generates schema based on statistics files.
    schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'],
                           infer_feature_shape=False)

    # Performs anomaly detection based on statistics and data schema.
    example_validator = ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_gen.outputs['schema'])

    # Performs transformations and feature engineering in training and serving.
    transform = Transform(examples=example_gen.outputs['examples'],
                          schema=schema_gen.outputs['schema'],
                          module_file=module_file)

    # Update ai_platform_training_args if distributed training was enabled.
    # Number of worker machines used in distributed training.
    worker_count = data_types.RuntimeParameter(
        name='worker_count',
        default=2,
        ptype=int,
    )

    # Type of worker machines used in distributed training.
    worker_type = data_types.RuntimeParameter(
        name='worker_type',
        default='standard',
        ptype=str,
    )

    ai_platform_training_args = copy.copy(ai_platform_training_args)
    if FLAGS.distributed_training:
        ai_platform_training_args.update({
            # You can specify the machine types, the number of replicas for workers
            # and parameter servers.
            # https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#ScaleTier
            'scaleTier': 'CUSTOM',
            'masterType': 'large_model',
            'workerType': worker_type,
            'parameterServerType': 'standard',
            'workerCount': worker_count,
            'parameterServerCount': 1
        })

    # Uses user-provided Python function that implements a model using TF-Learn
    # to train a model on Google Cloud AI Platform.
    trainer = Trainer(
        custom_executor_spec=executor_spec.ExecutorClassSpec(
            ai_platform_trainer_executor.Executor),
        module_file=module_file,
        transformed_examples=transform.outputs['transformed_examples'],
        schema=schema_gen.outputs['schema'],
        transform_graph=transform.outputs['transform_graph'],
        train_args={'num_steps': train_steps},
        eval_args={'num_steps': eval_steps},
        custom_config={
            ai_platform_trainer_executor.TRAINING_ARGS_KEY:
            ai_platform_training_args
        })

    # Get the latest blessed model for model validation.
    model_resolver = ResolverNode(
        instance_name='latest_blessed_model_resolver',
        resolver_class=latest_blessed_model_resolver.
        LatestBlessedModelResolver,
        model=Channel(type=Model),
        model_blessing=Channel(type=ModelBlessing))

    # Uses TFMA to compute evaluation statistics over features of a model and
    # perform quality validation of a candidate model (compared to a baseline).
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(signature_name='eval')],
        slicing_specs=[
            tfma.SlicingSpec(),
            tfma.SlicingSpec(feature_keys=['trip_start_hour'])
        ],
        metrics_specs=[
            tfma.MetricsSpec(
                thresholds={
                    'accuracy':
                    tfma.config.MetricThreshold(
                        value_threshold=tfma.GenericValueThreshold(
                            lower_bound={'value': 0.6}),
                        # Change threshold will be ignored if there is no
                        # baseline model resolved from MLMD (first run).
                        change_threshold=tfma.GenericChangeThreshold(
                            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                            absolute={'value': -1e-10}))
                })
        ])
    evaluator = Evaluator(examples=example_gen.outputs['examples'],
                          model=trainer.outputs['model'],
                          baseline_model=model_resolver.outputs['model'],
                          eval_config=eval_config)

    # Checks whether the model passed the validation steps and pushes the model
    # to Google Cloud AI Platform if the check passed.
    # TODO(b/162451308): Add pusher back to components list once AIP Prediction
    # Service supports TF>=2.3.
    _ = Pusher(custom_executor_spec=executor_spec.ExecutorClassSpec(
        ai_platform_pusher_executor.Executor),
               model=trainer.outputs['model'],
               model_blessing=evaluator.outputs['blessing'],
               custom_config={
                   ai_platform_pusher_executor.SERVING_ARGS_KEY:
                   ai_platform_serving_args
               })

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[
            example_gen, statistics_gen, schema_gen, example_validator,
            transform, trainer, model_resolver, evaluator
        ],
        beam_pipeline_args=beam_pipeline_args,
    )
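
The function above only constructs the pipeline object; submitting it to Kubeflow Pipelines requires a runner. Below is a minimal, hedged sketch of compiling it with the TFX Kubeflow runner of the same era; every GCP project, bucket, and argument value is a placeholder assumption, not taken from the original example.

# A minimal sketch (not from the original source): compile the pipeline above
# with the Kubeflow runner. All GCP values below are placeholders.
from tfx.orchestration.kubeflow import kubeflow_dag_runner

runner_config = kubeflow_dag_runner.KubeflowDagRunnerConfig(
    kubeflow_metadata_config=(
        kubeflow_dag_runner.get_default_kubeflow_metadata_config()))

kubeflow_dag_runner.KubeflowDagRunner(config=runner_config).run(
    create_pipeline(
        pipeline_name='chicago_taxi_pipeline_kubeflow',
        pipeline_root='gs://my-bucket/tfx',                      # placeholder
        module_file='gs://my-bucket/taxi_utils.py',              # placeholder
        ai_platform_training_args={'project': 'my-gcp-project',  # placeholder
                                   'region': 'us-central1'},
        ai_platform_serving_args={'model_name': 'chicago_taxi',  # placeholder
                                  'project_id': 'my-gcp-project'},
        beam_pipeline_args=['--project=my-gcp-project']))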
Code Example #2
File: component.py Project: vikrosj/tfx
class ImportExampleGen(component.FileBasedExampleGen):  # pylint: disable=protected-access
  """Official TFX ImportExampleGen component.

  The ImportExampleGen component takes TFRecord files with TF Example data
  format, and generates train and eval examples for downstream components.
  This component provides consistent and configurable partitioning, and it
  also shuffles the dataset following ML best practice.
  """

  EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(executor.Executor)

  def __init__(
      self,
      # TODO(b/159467778): deprecate this, use input_base instead.
      input: Optional[types.Channel] = None,  # pylint: disable=redefined-builtin
      input_base: Optional[Text] = None,
      input_config: Optional[Union[example_gen_pb2.Input, Dict[Text,
                                                               Any]]] = None,
      output_config: Optional[Union[example_gen_pb2.Output, Dict[Text,
                                                                 Any]]] = None,
      range_config: Optional[Union[range_config_pb2.RangeConfig,
                                   Dict[Text, Any]]] = None,
      payload_format: Optional[int] = example_gen_pb2.FORMAT_TF_EXAMPLE,
      example_artifacts: Optional[types.Channel] = None,
      instance_name: Optional[Text] = None):
    """Construct an ImportExampleGen component.

    Args:
      input: A Channel of type `standard_artifacts.ExternalArtifact`, which
        includes one artifact whose uri is an external directory containing the
        TFRecord files. (Deprecated by input_base)
      input_base: an external directory containing the TFRecord files.
      input_config: An example_gen_pb2.Input instance, providing input
        configuration. If unset, the files under input_base will be treated as a
        single split. If any field is provided as a RuntimeParameter,
        input_config should be constructed as a dict with the same field names
        as Input proto message.
      output_config: An example_gen_pb2.Output instance, providing output
        configuration. If unset, default splits will be 'train' and 'eval' with
        size 2:1. If any field is provided as a RuntimeParameter,
        output_config should be constructed as a dict with the same field names
        as Output proto message.
      range_config: An optional range_config_pb2.RangeConfig instance,
        specifying the range of span values to consider. If unset, driver will
        default to searching for latest span with no restrictions.
      payload_format: Payload format of input data. Should be one of
        example_gen_pb2.PayloadFormat enum. Note that payload format of output
        data is the same as input.
      example_artifacts: Optional channel of 'ExamplesPath' for output train and
        eval examples.
      instance_name: Optional unique instance name. Necessary if multiple
        ImportExampleGen components are declared in the same pipeline.
    """
    if input:
      logging.warning(
          'The "input" argument to the ImportExampleGen component has been '
          'deprecated by "input_base". Please update your usage as support for '
          'this argument will be removed soon.')
      input_base = artifact_utils.get_single_uri(list(input.get()))
    super(ImportExampleGen, self).__init__(
        input_base=input_base,
        input_config=input_config,
        output_config=output_config,
        range_config=range_config,
        example_artifacts=example_artifacts,
        output_data_format=payload_format,
        instance_name=instance_name)
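
As a usage illustration, the constructor above pairs naturally with an `example_gen_pb2.Input` describing pre-split TFRecord data (as in Code Example #12 below). A minimal sketch; the input path and split patterns are placeholder assumptions:

# A minimal sketch: ingest pre-split TFRecord files with ImportExampleGen.
# The input path and split patterns are placeholder assumptions.
from tfx.proto import example_gen_pb2

input_config = example_gen_pb2.Input(splits=[
    example_gen_pb2.Input.Split(name='train', pattern='train/*'),
    example_gen_pb2.Input.Split(name='eval', pattern='eval/*'),
])
example_gen = ImportExampleGen(
    input_base='/path/to/tfrecords',  # placeholder
    input_config=input_config)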
Code Example #3
File: component.py Project: vikrosj/tfx
class ExampleValidator(base_component.BaseComponent):
  """A TFX component to validate input examples.

  The ExampleValidator component uses [Tensorflow Data
  Validation](https://www.tensorflow.org/tfx/data_validation) to
  validate the statistics of some splits on input examples against a schema.

  The ExampleValidator component identifies anomalies in training and serving
  data. The component can be configured to detect different classes of anomalies
  in the data. It can:
    - perform validity checks by comparing data statistics against a schema that
      codifies expectations of the user.

  Schema Based Example Validation
  The ExampleValidator component identifies any anomalies in the example data by
  comparing data statistics computed by the StatisticsGen component against a
  schema. The schema codifies properties which the input data is expected to
  satisfy, and is provided and maintained by the user.

  Please see https://www.tensorflow.org/tfx/data_validation for more details.

  ## Example
  ```
  # Performs anomaly detection based on statistics and data schema.
  validate_stats = ExampleValidator(
      statistics=statistics_gen.outputs['statistics'],
      schema=infer_schema.outputs['schema'])
  ```
  """

  SPEC_CLASS = ExampleValidatorSpec
  EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(executor.Executor)

  def __init__(self,
               statistics: types.Channel = None,
               schema: types.Channel = None,
               exclude_splits: Optional[List[Text]] = None,
               anomalies: Optional[Text] = None,
               instance_name: Optional[Text] = None):
    """Construct an ExampleValidator component.

    Args:
      statistics: A Channel of type `standard_artifacts.ExampleStatistics`.
      schema: A Channel of type `standard_artifacts.Schema`. _required_
      exclude_splits: Names of splits that the example validator should not
        validate. Default behavior (when exclude_splits is set to None)
        is excluding no splits.
      anomalies: Output channel of type `standard_artifacts.ExampleAnomalies`.
      instance_name: Optional name assigned to this specific instance of
        ExampleValidator. Required only if multiple ExampleValidator components
        are declared in the same pipeline.
    """
    if exclude_splits is None:
      exclude_splits = []
      logging.info('Excluding no splits because exclude_splits is not set.')
    if not anomalies:
      anomalies = types.Channel(type=standard_artifacts.ExampleAnomalies)
    spec = ExampleValidatorSpec(
        statistics=statistics,
        schema=schema,
        exclude_splits=json_utils.dumps(exclude_splits),
        anomalies=anomalies)
    super(ExampleValidator, self).__init__(
        spec=spec, instance_name=instance_name)
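
The `exclude_splits` argument documented above is the one knob beyond the two input channels. A minimal sketch that validates only the train split, assuming the `statistics_gen` and `schema_gen` components used throughout these examples:

# A minimal sketch: skip validation of the 'eval' split.
# statistics_gen and schema_gen are assumed to exist as in the other examples.
example_validator = ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema'],
    exclude_splits=['eval'])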
Code Example #4
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text, serving_model_dir: Text,
                     metadata_path: Text,
                     beam_pipeline_args: List[Text]) -> pipeline.Pipeline:
    """Implements the Bert classication on Cola dataset pipline with TFX."""
    input_config = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(name='train', pattern='train/*'),
        example_gen_pb2.Input.Split(name='eval', pattern='validation/*')
    ])

    # Brings data into the pipeline.
    example_gen = CsvExampleGen(input_base=data_root,
                                input_config=input_config)

    # Computes statistics over data for visualization and example validation.
    statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

    # Generates schema based on statistics files.
    schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'],
                           infer_feature_shape=True)

    # Performs anomaly detection based on statistics and data schema.
    example_validator = ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_gen.outputs['schema'])

    # Performs transformations and feature engineering in training and serving.
    transform = Transform(examples=example_gen.outputs['examples'],
                          schema=schema_gen.outputs['schema'],
                          module_file=module_file)

    # Uses user-provided Python function that trains a model using TF-Learn.
    trainer = Trainer(
        module_file=module_file,
        custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
        examples=transform.outputs['transformed_examples'],
        transform_graph=transform.outputs['transform_graph'],
        schema=schema_gen.outputs['schema'],
        # Adjust these steps when training on the full dataset.
        train_args=trainer_pb2.TrainArgs(num_steps=2),
        eval_args=trainer_pb2.EvalArgs(num_steps=1))

    # Get the latest blessed model for model validation.
    model_resolver = ResolverNode(
        instance_name='latest_blessed_model_resolver',
        resolver_class=latest_blessed_model_resolver.
        LatestBlessedModelResolver,
        model=Channel(type=Model),
        model_blessing=Channel(type=ModelBlessing))

    # Uses TFMA to compute evaluation statistics over features of a model and
    # perform quality validation of a candidate model (compared to a baseline).
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(label_key='label')],
        slicing_specs=[tfma.SlicingSpec()],
        metrics_specs=[
            tfma.MetricsSpec(metrics=[
                tfma.MetricConfig(
                    class_name='SparseCategoricalAccuracy',
                    threshold=tfma.MetricThreshold(
                        value_threshold=tfma.GenericValueThreshold(
                            # Adjust the threshold when training on the
                            # full dataset.
                            lower_bound={'value': 0.5}),
                        # Change threshold will be ignored if there is no
                        # baseline model resolved from MLMD (first run).
                        change_threshold=tfma.GenericChangeThreshold(
                            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                            absolute={'value': -1e-2})))
            ])
        ])
    evaluator = Evaluator(examples=example_gen.outputs['examples'],
                          model=trainer.outputs['model'],
                          baseline_model=model_resolver.outputs['model'],
                          eval_config=eval_config)

    # Checks whether the model passed the validation steps and pushes the model
    # to a file destination if check passed.
    pusher = Pusher(model=trainer.outputs['model'],
                    model_blessing=evaluator.outputs['blessing'],
                    push_destination=pusher_pb2.PushDestination(
                        filesystem=pusher_pb2.PushDestination.Filesystem(
                            base_directory=serving_model_dir)))

    components = [
        example_gen,
        statistics_gen,
        schema_gen,
        example_validator,
        transform,
        trainer,
        model_resolver,
        evaluator,
        pusher,
    ]

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=components,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
        enable_cache=True,
        beam_pipeline_args=beam_pipeline_args,
    )
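
Because this pipeline wires up a local SQLite MLMD store, it can be executed directly with the Beam orchestrator. A minimal, hedged sketch; all paths are placeholder assumptions:

# A minimal sketch (not from the original source): run the pipeline locally.
# All paths below are placeholders.
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner

BeamDagRunner().run(
    _create_pipeline(
        pipeline_name='bert_cola',
        pipeline_root='/tmp/tfx/pipelines/bert_cola',  # placeholder
        data_root='/tmp/cola_data',                    # placeholder
        module_file='/path/to/bert_cola_utils.py',     # placeholder
        serving_model_dir='/tmp/serving_model',        # placeholder
        metadata_path='/tmp/tfx/metadata.db',          # placeholder
        beam_pipeline_args=['--direct_num_workers=1']))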
Code Example #5
File: test_utils.py Project: jay90099/tfx
def create_pipeline_components(
    pipeline_root: str,
    transform_module: str,
    trainer_module: str,
    bigquery_query: str = '',
    csv_input_location: str = '',
) -> List[base_node.BaseNode]:
    """Creates components for a simple Chicago Taxi TFX pipeline for testing.

  Args:
    pipeline_root: The root of the pipeline output.
    transform_module: The location of the transform module file.
    trainer_module: The location of the trainer module file.
    bigquery_query: The query to get input data from BigQuery. If not empty,
      BigQueryExampleGen will be used.
    csv_input_location: The location of the input data directory.

  Returns:
    A list of TFX components that constitutes an end-to-end test pipeline.
  """

    if bool(bigquery_query) == bool(csv_input_location):
        # Implicit string concatenation keeps this a single message; passing
        # two arguments to ValueError would render the message as a tuple.
        raise ValueError(
            'Exactly one example gen is expected. '
            'Please provide either bigquery_query or csv_input_location.')

    if bigquery_query:
        example_gen = tfx.extensions.google_cloud_big_query.BigQueryExampleGen(
            query=bigquery_query)
    else:
        example_gen = tfx.components.CsvExampleGen(
            input_base=csv_input_location)

    statistics_gen = tfx.components.StatisticsGen(
        examples=example_gen.outputs['examples'])
    schema_gen = tfx.components.SchemaGen(
        statistics=statistics_gen.outputs['statistics'],
        infer_feature_shape=False)
    example_validator = tfx.components.ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_gen.outputs['schema'])
    transform = tfx.components.Transform(
        examples=example_gen.outputs['examples'],
        schema=schema_gen.outputs['schema'],
        module_file=transform_module)
    latest_model_resolver = tfx.dsl.Resolver(
        strategy_class=tfx.dsl.experimental.LatestArtifactStrategy,
        model=tfx.dsl.Channel(type=tfx.types.standard_artifacts.Model
                              )).with_id('Resolver.latest_model_resolver')
    trainer = tfx.components.Trainer(
        custom_executor_spec=executor_spec.ExecutorClassSpec(Executor),
        examples=transform.outputs['transformed_examples'],
        schema=schema_gen.outputs['schema'],
        base_model=latest_model_resolver.outputs['model'],
        transform_graph=transform.outputs['transform_graph'],
        train_args=tfx.proto.TrainArgs(num_steps=10),
        eval_args=tfx.proto.EvalArgs(num_steps=5),
        module_file=trainer_module,
    )
    # Get the latest blessed model for model validation.
    model_resolver = tfx.dsl.Resolver(
        strategy_class=tfx.dsl.experimental.LatestBlessedModelStrategy,
        model=tfx.dsl.Channel(type=tfx.types.standard_artifacts.Model),
        model_blessing=tfx.dsl.Channel(
            type=tfx.types.standard_artifacts.ModelBlessing)).with_id(
                'Resolver.latest_blessed_model_resolver')
    # Set the TFMA config for Model Evaluation and Validation.
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(signature_name='eval')],
        metrics_specs=[
            tfma.MetricsSpec(
                metrics=[tfma.MetricConfig(class_name='ExampleCount')],
                thresholds={
                    'binary_accuracy':
                    tfma.MetricThreshold(
                        value_threshold=tfma.GenericValueThreshold(
                            lower_bound={'value': 0.5}),
                        change_threshold=tfma.GenericChangeThreshold(
                            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                            absolute={'value': -1e-10}))
                })
        ],
        slicing_specs=[
            tfma.SlicingSpec(),
            tfma.SlicingSpec(feature_keys=['trip_start_hour'])
        ])
    evaluator = tfx.components.Evaluator(
        examples=example_gen.outputs['examples'],
        model=trainer.outputs['model'],
        baseline_model=model_resolver.outputs['model'],
        eval_config=eval_config)

    with conditional.Cond(evaluator.outputs['blessing'].future()
                          [0].custom_property('blessed') == 1):
        pusher = tfx.components.Pusher(
            model=trainer.outputs['model'],
            push_destination=tfx.proto.
            PushDestination(filesystem=tfx.proto.PushDestination.Filesystem(
                base_directory=os.path.join(pipeline_root, 'model_serving'))))

    return [
        example_gen, statistics_gen, schema_gen, example_validator, transform,
        latest_model_resolver, trainer, model_resolver, evaluator, pusher
    ]
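
A hedged sketch of consuming the returned component list in a test, assuming the same `tfx` namespace as the function above; names and paths are placeholders:

# A minimal sketch: assemble the returned components into a test pipeline.
# All paths and module locations are placeholder assumptions.
components = create_pipeline_components(
    pipeline_root='/tmp/tfx/taxi_test',             # placeholder
    transform_module='/path/to/taxi_transform.py',  # placeholder
    trainer_module='/path/to/taxi_trainer.py',      # placeholder
    csv_input_location='/tmp/taxi_data')            # placeholder

test_pipeline = tfx.dsl.Pipeline(
    pipeline_name='taxi_test_pipeline',
    pipeline_root='/tmp/tfx/taxi_test',
    components=components)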
Code Example #6
File: decorators_test.py Project: jeongukjae/tfx
class _MySimpleComponent(_SimpleComponent):
  SPEC_CLASS = _BasicComponentSpec
  EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(
      base_executor.BaseExecutor)
Code Example #7
def create_test_pipeline():
    """Builds an Iris example pipeline with slight changes."""
    pipeline_name = "iris"
    iris_root = "iris_root"
    serving_model_dir = os.path.join(iris_root, "serving_model", pipeline_name)
    tfx_root = "tfx_root"
    data_path = os.path.join(tfx_root, "data_path")
    pipeline_root = os.path.join(tfx_root, "pipelines", pipeline_name)

    example_gen = CsvExampleGen(input=external_input(data_path))

    statistics_gen = StatisticsGen(examples=example_gen.outputs["examples"])

    importer = ImporterNode(instance_name="my_importer",
                            source_uri="m/y/u/r/i",
                            properties={
                                "split_names": "['train', 'eval']",
                            },
                            custom_properties={
                                "int_custom_property": 42,
                                "str_custom_property": "42",
                            },
                            artifact_type=standard_artifacts.Examples)

    schema_gen = SchemaGen(statistics=statistics_gen.outputs["statistics"],
                           infer_feature_shape=True)

    example_validator = ExampleValidator(
        statistics=statistics_gen.outputs["statistics"],
        schema=schema_gen.outputs["schema"])

    trainer = Trainer(
        # Use RuntimeParameter as module_file to test out RuntimeParameter in
        # compiler.
        module_file=data_types.RuntimeParameter(name="module_file",
                                                default=os.path.join(
                                                    iris_root,
                                                    "iris_utils.py"),
                                                ptype=str),
        custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
        examples=example_gen.outputs["examples"],
        schema=schema_gen.outputs["schema"],
        train_args=trainer_pb2.TrainArgs(num_steps=2000),
        # Attaching `TrainArgs` as platform config is not a sensible practice;
        # it is done here only for testing purposes.
        eval_args=trainer_pb2.EvalArgs(num_steps=5)).with_platform_config(
            config=trainer_pb2.TrainArgs(num_steps=2000))

    model_resolver = ResolverNode(
        instance_name="latest_blessed_model_resolver",
        resolver_class=latest_blessed_model_resolver.
        LatestBlessedModelResolver,
        baseline_model=Channel(type=standard_artifacts.Model,
                               producer_component_id="Trainer"),
        # Cannot add producer_component_id="Evaluator" for model_blessing as it
        # raises "producer component should have already been compiled" error.
        model_blessing=Channel(type=standard_artifacts.ModelBlessing))

    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(signature_name="eval")],
        slicing_specs=[tfma.SlicingSpec()],
        metrics_specs=[
            tfma.MetricsSpec(
                thresholds={
                    "sparse_categorical_accuracy":
                    tfma.config.MetricThreshold(
                        value_threshold=tfma.GenericValueThreshold(
                            lower_bound={"value": 0.6}),
                        change_threshold=tfma.GenericChangeThreshold(
                            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                            absolute={"value": -1e-10}))
                })
        ])
    evaluator = Evaluator(
        examples=example_gen.outputs["examples"],
        model=trainer.outputs["model"],
        baseline_model=model_resolver.outputs["baseline_model"],
        eval_config=eval_config)

    pusher = Pusher(model=trainer.outputs["model"],
                    model_blessing=evaluator.outputs["blessing"],
                    push_destination=pusher_pb2.PushDestination(
                        filesystem=pusher_pb2.PushDestination.Filesystem(
                            base_directory=serving_model_dir)))

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[
            example_gen,
            statistics_gen,
            importer,
            schema_gen,
            example_validator,
            trainer,
            model_resolver,
            evaluator,
            pusher,
        ],
        enable_cache=False,
        beam_pipeline_args=["--my_testing_beam_pipeline_args=bar"],
        # Attaching `TrainArgs` as platform config is not a sensible practice;
        # it is done here only for testing purposes.
        platform_config=trainer_pb2.TrainArgs(num_steps=2000),
        execution_mode=pipeline.ExecutionMode.ASYNC)
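
Since this builder exists to exercise the compiler (RuntimeParameter `module_file`, platform config, ASYNC execution mode), here is a hedged sketch of compiling it to the portable pipeline IR with the TFX DSL compiler:

# A minimal sketch: compile the test pipeline into the portable pipeline IR.
from tfx.dsl.compiler import compiler

pipeline_ir = compiler.Compiler().compile(create_test_pipeline())
print(pipeline_ir.pipeline_info.id)  # -> 'iris'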
Code Example #8
def _create_pipeline(pipeline_name: Text, pipeline_root: Text,
                     training_data_root: Text, inference_data_root: Text,
                     module_file: Text, metadata_path: Text,
                     beam_pipeline_args: List[Text]) -> pipeline.Pipeline:
    """Implements the chicago taxi pipeline with TFX."""
    # Brings training data into the pipeline or otherwise joins/converts
    # training data.
    training_example_gen = CsvExampleGen(
        input_base=training_data_root).with_id('training_example_gen')

    # Computes statistics over data for visualization and example validation.
    statistics_gen = StatisticsGen(
        examples=training_example_gen.outputs['examples'])

    # Generates schema based on statistics files.
    schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'],
                           infer_feature_shape=False)

    # Performs anomaly detection based on statistics and data schema.
    example_validator = ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_gen.outputs['schema'])

    # Performs transformations and feature engineering in training and serving.
    transform = Transform(examples=training_example_gen.outputs['examples'],
                          schema=schema_gen.outputs['schema'],
                          module_file=module_file)

    # Uses user-provided Python function that implements a model using TF-Learn.
    trainer = Trainer(
        module_file=module_file,
        custom_executor_spec=executor_spec.ExecutorClassSpec(Executor),
        transformed_examples=transform.outputs['transformed_examples'],
        schema=schema_gen.outputs['schema'],
        transform_graph=transform.outputs['transform_graph'],
        train_args=trainer_pb2.TrainArgs(num_steps=10000),
        eval_args=trainer_pb2.EvalArgs(num_steps=5000))

    # Get the latest blessed model for model validation.
    model_resolver = ResolverNode(
        resolver_class=latest_blessed_model_resolver.
        LatestBlessedModelResolver,
        model=Channel(type=Model),
        model_blessing=Channel(
            type=ModelBlessing)).with_id('latest_blessed_model_resolver')

    # Uses TFMA to compute evaluation statistics over features of a model and
    # perform quality validation of a candidate model (compared to a baseline).
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(signature_name='eval')],
        slicing_specs=[
            tfma.SlicingSpec(),
            tfma.SlicingSpec(feature_keys=['trip_start_hour'])
        ],
        metrics_specs=[
            tfma.MetricsSpec(
                thresholds={
                    'accuracy':
                    tfma.config.MetricThreshold(
                        value_threshold=tfma.GenericValueThreshold(
                            lower_bound={'value': 0.6}),
                        # Change threshold will be ignored if there is no
                        # baseline model resolved from MLMD (first run).
                        change_threshold=tfma.GenericChangeThreshold(
                            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                            absolute={'value': -1e-10}))
                })
        ])
    evaluator = Evaluator(examples=training_example_gen.outputs['examples'],
                          model=trainer.outputs['model'],
                          baseline_model=model_resolver.outputs['model'],
                          eval_config=eval_config)

    # Brings inference data into the pipeline.
    inference_example_gen = CsvExampleGen(
        input_base=inference_data_root,
        output_config=example_gen_pb2.Output(
            split_config=example_gen_pb2.SplitConfig(splits=[
                example_gen_pb2.SplitConfig.Split(name='unlabelled',
                                                  hash_buckets=100)
            ]))).with_id('inference_example_gen')

    # Performs offline batch inference over inference examples.
    bulk_inferrer = BulkInferrer(
        examples=inference_example_gen.outputs['examples'],
        model=trainer.outputs['model'],
        model_blessing=evaluator.outputs['blessing'],
        # Empty data_spec.example_splits will result in using all splits.
        data_spec=bulk_inferrer_pb2.DataSpec(),
        model_spec=bulk_inferrer_pb2.ModelSpec())

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[
            training_example_gen, inference_example_gen, statistics_gen,
            schema_gen, example_validator, transform, trainer, model_resolver,
            evaluator, bulk_inferrer
        ],
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
        beam_pipeline_args=beam_pipeline_args)
Code Example #9
class SchemaGen(base_component.BaseComponent):
    """A TFX SchemaGen component to generate a schema from the training data.

  The SchemaGen component uses [TensorFlow Data
  Validation](https://www.tensorflow.org/tfx/data_validation) to
  generate a schema from input statistics.  The following TFX libraries use the
  schema:
    - TensorFlow Data Validation
    - TensorFlow Transform
    - TensorFlow Model Analysis

  In a typical TFX pipeline, the SchemaGen component generates a schema which
  is consumed by the other pipeline components.

  Please see https://www.tensorflow.org/tfx/data_validation for more details.

  ## Example
  ```
    # Generates schema based on statistics files.
    infer_schema = SchemaGen(statistics=statistics_gen.outputs['statistics'])
  ```
  """
    # TODO(b/123941608): Update pydoc about how to use a user provided schema

    SPEC_CLASS = SchemaGenSpec
    EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(executor.Executor)

    def __init__(self,
                 statistics: Optional[types.Channel] = None,
                 infer_feature_shape: Optional[Union[
                     bool, data_types.RuntimeParameter]] = False,
                 exclude_splits: Optional[List[Text]] = None,
                 output: Optional[types.Channel] = None,
                 stats: Optional[types.Channel] = None,
                 instance_name: Optional[Text] = None):
        """Constructs a SchemaGen component.

    Args:
      statistics: A Channel of `ExampleStatistics` type (required if spec is not
        passed). This should contain at least a `train` split. Other splits are
        currently ignored. _required_
      infer_feature_shape: Boolean (or RuntimeParameter) value indicating
        whether or not to infer the shape of features. If the feature shape is
        not inferred, downstream Tensorflow Transform component using the schema
        will parse input as tf.SparseTensor.
      exclude_splits: Names of splits that will not be taken into consideration
        when auto-generating a schema. Default behavior (when exclude_splits is
        set to None) is excluding no splits.
      output: Output `Schema` channel for schema result.
      stats: Backwards compatibility alias for the 'statistics' argument.
      instance_name: Optional name assigned to this specific instance of
        SchemaGen.  Required only if multiple SchemaGen components are declared
        in the same pipeline.  Either `statistics` or `stats` must be present in
        the input arguments.
    """
        if stats:
            logging.warning(
                'The "stats" argument to the SchemaGen component has '
                'been renamed to "statistics" and is deprecated. Please update your '
                'usage as support for this argument will be removed soon.')
            statistics = stats
        if exclude_splits is None:
            exclude_splits = []
            logging.info(
                'Excluding no splits because exclude_splits is not set.')
        schema = output or types.Channel(type=standard_artifacts.Schema)
        if isinstance(infer_feature_shape, bool):
            infer_feature_shape = int(infer_feature_shape)
        spec = SchemaGenSpec(statistics=statistics,
                             infer_feature_shape=infer_feature_shape,
                             exclude_splits=json_utils.dumps(exclude_splits),
                             schema=schema)
        super(SchemaGen, self).__init__(spec=spec, instance_name=instance_name)
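
Because the constructor above coerces a plain bool to int, a RuntimeParameter for `infer_feature_shape` should use an int `ptype`. A minimal sketch:

# A minimal sketch: make infer_feature_shape tunable at pipeline run time.
# ptype is int because the constructor above converts bool to int anyway.
from tfx.orchestration import data_types

infer_shape = data_types.RuntimeParameter(
    name='infer_feature_shape', ptype=int, default=1)

schema_gen = SchemaGen(
    statistics=statistics_gen.outputs['statistics'],
    infer_feature_shape=infer_shape)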
Code Example #10
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     custom_config: Dict[Text, Any], module_file: Text,
                     serving_model_dir: Text, metadata_path: Text,
                     beam_pipeline_args: List[Text]) -> pipeline.Pipeline:
    """Implements the handwritten digit classification example using TFX."""
    # Store the configuration along with the pipeline run so results can be reproduced
    pipeline_configuration = FromCustomConfig(custom_config=custom_config)

    # Brings data into the pipeline.
    example_gen = ImportExampleGen(input_base=data_root)

    # Computes statistics over data for visualization and example validation.
    statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

    # Generates schema based on statistics files.
    schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'],
                           infer_feature_shape=True)

    # Performs anomaly detection based on statistics and data schema.
    example_validator = ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_gen.outputs['schema'])

    # Create a filtered dataset - today we only want a model for small digits.
    # Renamed from `filter` to avoid shadowing the Python builtin.
    example_filter = Filter(
        examples=example_gen.outputs['examples'],
        pipeline_configuration=pipeline_configuration.
        outputs['pipeline_configuration'],
        splits_to_transform=['train', 'eval'],
        splits_to_copy=[])

    # Create a stratified dataset for evaluation
    stratified_examples = StratifiedSampler(
        examples=example_filter.outputs['filtered_examples'],
        pipeline_configuration=pipeline_configuration.
        outputs['pipeline_configuration'],
        samples_per_key=1200,
        splits_to_transform=['eval'],
        splits_to_copy=['train'])

    # Performs transformations and feature engineering in training and serving.
    transform = Transform(
        examples=example_filter.outputs['filtered_examples'],
        schema=schema_gen.outputs['schema'],
        module_file=module_file)

    # Uses user-provided Python function that trains a Keras model.
    trainer = Trainer(
        module_file=module_file,
        custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
        custom_config=custom_config,
        examples=transform.outputs['transformed_examples'],
        transform_graph=transform.outputs['transform_graph'],
        schema=schema_gen.outputs['schema'],
        train_args=trainer_pb2.TrainArgs(num_steps=5000),
        eval_args=trainer_pb2.EvalArgs(num_steps=100)).with_id(u'trainer')

    # Uses TFMA to compute evaluation statistics over features of a model and
    # performs quality validation of a candidate model.
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(label_key='image_class')],
        slicing_specs=[tfma.SlicingSpec()],
        metrics_specs=[
            tfma.MetricsSpec(metrics=[
                tfma.MetricConfig(
                    class_name='SparseCategoricalAccuracy',
                    threshold=tfma.config.MetricThreshold(
                        value_threshold=tfma.GenericValueThreshold(
                            lower_bound={'value': 0.8})))
            ])
        ])

    # Uses TFMA to compute the evaluation statistics over features of a model.
    evaluator = Evaluator(
        examples=stratified_examples.outputs['stratified_examples'],
        model=trainer.outputs['model'],
        eval_config=eval_config).with_id(u'evaluator')

    # Checks whether the model passed the validation steps and pushes the model
    # to a file destination if check passed.
    pusher = Pusher(
        model=trainer.outputs['model'],
        model_blessing=evaluator.outputs['blessing'],
        push_destination=pusher_pb2.PushDestination(
            filesystem=pusher_pb2.PushDestination.Filesystem(
                base_directory=serving_model_dir))).with_id(u'pusher')

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[
            pipeline_configuration,
            example_gen,
            example_filter,
            stratified_examples,
            statistics_gen,
            schema_gen,
            example_validator,
            transform,
            trainer,
            evaluator,
            pusher,
        ],
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
        beam_pipeline_args=beam_pipeline_args)
Code Example #11
File: decorators.py Project: jay90099/tfx
def component(
    func: Optional[types.FunctionType] = None,
    component_annotation: Optional[Type[SystemExecution]] = None,
) -> Callable[..., Any]:
    """Decorator: creates a component from a typehint-annotated Python function.

  This decorator creates a component based on typehint annotations specified for
  the arguments and return value for a Python function. The decorator can be
  supplied with a parameter `component_annotation` to specify the annotation for
  this component decorator. This annotation hints which system execution type
  this python function-based component belongs to.
  Specifically, function arguments can be annotated with the following types and
  associated semantics:

  * `Parameter[T]` where `T` is `int`, `float`, `str`, or `bytes`: indicates
    that a primitive type execution parameter, whose value is known at pipeline
    construction time, will be passed for this argument. These parameters will
    be recorded in ML Metadata as part of the component's execution record. Can
    be an optional argument.
  * `int`, `float`, `str`, `bytes`: indicates that a primitive type value will
    be passed for this argument. This value is tracked as an `Integer`, `Float`,
    `String` or `Bytes` artifact (see `tfx.types.standard_artifacts`) whose
    value is read and passed into the given Python component function. Can be
    an optional argument.
  * `InputArtifact[ArtifactType]`: indicates that an input artifact object of
    type `ArtifactType` (deriving from `tfx.types.Artifact`) will be passed for
    this argument. This artifact is intended to be consumed as an input by this
    component (possibly reading from the path specified by its `.uri`). Can be
    an optional argument by specifying a default value of `None`.
  * `OutputArtifact[ArtifactType]`: indicates that an output artifact object of
    type `ArtifactType` (deriving from `tfx.types.Artifact`) will be passed for
    this argument. This artifact is intended to be emitted as an output by this
    component (and written to the path specified by its `.uri`). Cannot be an
    optional argument.

  The return value typehint should be either empty or `None`, in the case of a
  component function that has no return values, or an instance of
  `OutputDict(key_1=type_1, ...)`, where each key maps to a given type (each
  type is a primitive value type, i.e. `int`, `float`, `str` or `bytes`; or
  `Optional[T]`, where T is a primitive type value, in which case `None` can be
  returned), to indicate that the return value is a dictionary with specified
  keys and value types.

  Note that output artifacts should not be included in the return value
  typehint; they should be included as `OutputArtifact` annotations in the
  function inputs, as described above.

  The function to which this decorator is applied must be at the top level of
  its Python module (it may not be defined within nested classes or function
  closures).

  Here is an example of a component definition using this decorator:

      from tfx.dsl.components.base.annotations import OutputDict
      from tfx.dsl.components.base.annotations import InputArtifact
      from tfx.dsl.components.base.annotations import OutputArtifact
      from tfx.dsl.components.base.annotations import Parameter
      from tfx.dsl.components.base.decorators import component
      from tfx.types.standard_artifacts import Examples
      from tfx.types.standard_artifacts import Model

      @component(component_annotation=system_executions.Train)
      def MyTrainerComponent(
          training_data: InputArtifact[Examples],
          model: OutputArtifact[Model],
          dropout_hyperparameter: float,
          num_iterations: Parameter[int] = 10
          ) -> OutputDict(loss=float, accuracy=float):
        '''My simple trainer component.'''

        records = read_examples(training_data.uri)
        model_obj = train_model(records, num_iterations, dropout_hyperparameter)
        model_obj.write_to(model.uri)

        return {
          'loss': model_obj.loss,
          'accuracy': model_obj.accuracy
        }

      # Example usage in a pipeline graph definition:
      # ...
      trainer = MyTrainerComponent(
          training_data=example_gen.outputs['examples'],
          dropout_hyperparameter=other_component.outputs['dropout'],
          num_iterations=1000)
      pusher = Pusher(model=trainer.outputs['model'])
      # ...

  When the parameter `component_annotation` is not supplied, the default value
  is None. This is another example usage with `component_annotation` = None:

      @component
      def MyTrainerComponent(
          training_data: InputArtifact[Examples],
          model: OutputArtifact[Model],
          dropout_hyperparameter: float,
          num_iterations: Parameter[int] = 10
          ) -> OutputDict(loss=float, accuracy=float):
        '''My simple trainer component.'''

        records = read_examples(training_data.uri)
        model_obj = train_model(records, num_iterations, dropout_hyperparameter)
        model_obj.write_to(model.uri)

        return {
          'loss': model_obj.loss,
          'accuracy': model_obj.accuracy
        }

  Experimental: no backwards compatibility guarantees.

  Args:
    func: Typehint-annotated component executor function.
    component_annotation: used to annotate the python function-based component.
      It is a subclass of SystemExecution from
      third_party/py/tfx/types/system_executions.py; it can be None.

  Returns:
    `base_component.BaseComponent` subclass for the given component executor
    function.

  Raises:
    EnvironmentError: if the current Python interpreter is not Python 3.
  """
    if func is None:
        return functools.partial(component,
                                 component_annotation=component_annotation)

    # Defining a component within a nested class or function closure causes
    # problems because in this case, the generated component classes can't be
    # referenced via their qualified module path.
    #
    # See https://www.python.org/dev/peps/pep-3155/ for details about the special
    # '<locals>' namespace marker.
    if '<locals>' in func.__qualname__.split('.'):
        raise ValueError(
            'The @component decorator can only be applied to a function defined '
            'at the module level. It cannot be used to construct a component for a '
            'function defined in a nested class or function closure.')

    inputs, outputs, parameters, arg_formats, arg_defaults, returned_values = (
        function_parser.parse_typehint_component_function(func))

    spec_inputs = {}
    spec_outputs = {}
    spec_parameters = {}
    for key, artifact_type in inputs.items():
        spec_inputs[key] = component_spec.ChannelParameter(
            type=artifact_type, optional=(key in arg_defaults))
    for key, artifact_type in outputs.items():
        assert key not in arg_defaults, 'Optional outputs are not supported.'
        spec_outputs[key] = component_spec.ChannelParameter(type=artifact_type)
    for key, primitive_type in parameters.items():
        spec_parameters[key] = component_spec.ExecutionParameter(
            type=primitive_type, optional=(key in arg_defaults))
    component_spec_class = type(
        '%s_Spec' % func.__name__, (tfx_types.ComponentSpec, ), {
            'INPUTS': spec_inputs,
            'OUTPUTS': spec_outputs,
            'PARAMETERS': spec_parameters,
            'TYPE_ANNOTATION': component_annotation,
        })

    executor_class = type(
        '%s_Executor' % func.__name__,
        (_FunctionExecutor, ),
        {
            '_ARG_FORMATS': arg_formats,
            '_ARG_DEFAULTS': arg_defaults,
            # The function needs to be marked with `staticmethod` so that later
            # references of `self._FUNCTION` do not result in a bound method (i.e.
            # one with `self` as its first parameter).
            '_FUNCTION': staticmethod(func),
            '_RETURNED_VALUES': returned_values,
            '__module__': func.__module__,
        })

    # Expose the generated executor class in the same module as the decorated
    # function. This is needed so that the executor class can be accessed at the
    # proper module path. One place this is needed is in the Dill pickler used by
    # Apache Beam serialization.
    module = sys.modules[func.__module__]
    setattr(module, '%s_Executor' % func.__name__, executor_class)

    executor_spec_instance = executor_spec.ExecutorClassSpec(
        executor_class=executor_class)

    return type(
        func.__name__, (_SimpleComponent, ), {
            'SPEC_CLASS': component_spec_class,
            'EXECUTOR_SPEC': executor_spec_instance,
            '__module__': func.__module__,
        })
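
To make the dynamic class generation above concrete, here is a hedged sketch of what the decorator produces for a trivial function. The import paths follow the docstring above and may differ between TFX versions:

# A minimal sketch: inspect the classes that @component generates.
# Import paths follow the docstring above; they may vary by TFX version.
from tfx.dsl.components.base.annotations import Parameter
from tfx.dsl.components.base.decorators import component


@component
def Doubler(value: Parameter[int] = 2) -> None:
  """A trivial component used only to show the generated classes."""
  print(value * 2)


# The decorator returned a BaseComponent subclass, not the original function:
print(Doubler.SPEC_CLASS.__name__)  # -> 'Doubler_Spec'
print(Doubler.EXECUTOR_SPEC)        # ExecutorClassSpec wrapping Doubler_Executor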
Code Example #12
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text, serving_model_dir_lite: Text,
                     metadata_path: Text, labels_path: Text,
                     beam_pipeline_args: List[Text]) -> pipeline.Pipeline:
    """Implements the CIFAR10 image classification pipeline using TFX."""
    # This is needed for datasets with pre-defined splits
    # Change the pattern argument to train_whole/* and test_whole/* to train
    # on the whole CIFAR-10 dataset
    input_config = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(name='train', pattern='train/*'),
        example_gen_pb2.Input.Split(name='eval', pattern='test/*')
    ])

    # Brings data into the pipeline.
    example_gen = ImportExampleGen(input_base=data_root,
                                   input_config=input_config)

    # Computes statistics over data for visualization and example validation.
    statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

    # Generates schema based on statistics files.
    schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'],
                           infer_feature_shape=True)

    # Performs anomaly detection based on statistics and data schema.
    example_validator = ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_gen.outputs['schema'])

    # Performs transformations and feature engineering in training and serving.
    transform = Transform(examples=example_gen.outputs['examples'],
                          schema=schema_gen.outputs['schema'],
                          module_file=module_file)

    # Uses user-provided Python function that trains a model.
    # When training on the whole dataset, use 18744 for train steps, 156 for eval
    # steps. 18744 train steps correspond to 24 epochs on the whole train set, and
    # 156 eval steps correspond to 1 epoch on the whole test set. The
    # configuration below is for training on the dataset we provided in the data
    # folder, which has 128 train and 128 test samples. The 160 train steps
    # correspond to 40 epochs on this tiny train set, and 4 eval steps correspond
    # to 1 epoch on this tiny test set.
    trainer = Trainer(
        module_file=module_file,
        custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
        examples=transform.outputs['transformed_examples'],
        transform_graph=transform.outputs['transform_graph'],
        schema=schema_gen.outputs['schema'],
        train_args=trainer_pb2.TrainArgs(num_steps=160),
        eval_args=trainer_pb2.EvalArgs(num_steps=4),
        custom_config={'labels_path': labels_path})

    # Get the latest blessed model for model validation.
    model_resolver = ResolverNode(
        instance_name='latest_blessed_model_resolver',
        resolver_class=latest_blessed_model_resolver.
        LatestBlessedModelResolver,
        model=Channel(type=Model),
        model_blessing=Channel(type=ModelBlessing))

    # Uses TFMA to compute evaluation statistics over features of a model and
    # perform quality validation of a candidate model (compare to a baseline).
    eval_config = tfma.EvalConfig(
        model_specs=[
            tfma.ModelSpec(label_key='label_xf', model_type='tf_lite')
        ],
        slicing_specs=[tfma.SlicingSpec()],
        metrics_specs=[
            tfma.MetricsSpec(metrics=[
                tfma.MetricConfig(
                    class_name='SparseCategoricalAccuracy',
                    threshold=tfma.MetricThreshold(
                        value_threshold=tfma.GenericValueThreshold(
                            lower_bound={'value': 0.55}),
                        change_threshold=tfma.GenericChangeThreshold(
                            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                            absolute={'value': -1e-3})))
            ])
        ])

    # Uses TFMA to compute the evaluation statistics over features of a model.
    # We evaluate using the materialized examples that are output by Transform
    # because
    # 1. the decoding_png function currently performed within Transform is not
    #    compatible with TFLite.
    # 2. MLKit requires deserialized (float32) tensor image inputs
    # Note that for deployment, the same logic that is performed within Transform
    # must be reproduced client-side.
    evaluator = Evaluator(examples=transform.outputs['transformed_examples'],
                          model=trainer.outputs['model'],
                          baseline_model=model_resolver.outputs['model'],
                          eval_config=eval_config)

    # Checks whether the model passed the validation steps and pushes the model
    # to a file destination if check passed.
    pusher = Pusher(model=trainer.outputs['model'],
                    model_blessing=evaluator.outputs['blessing'],
                    push_destination=pusher_pb2.PushDestination(
                        filesystem=pusher_pb2.PushDestination.Filesystem(
                            base_directory=serving_model_dir_lite)))

    components = [
        example_gen, statistics_gen, schema_gen, example_validator, transform,
        trainer, model_resolver, evaluator, pusher
    ]

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=components,
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
        beam_pipeline_args=beam_pipeline_args)
Code example #13
File: component.py  Project: jay90099/tfx
class SchemaGen(base_component.BaseComponent):
  """A TFX SchemaGen component to generate a schema from the training data.

  The SchemaGen component uses [TensorFlow Data
  Validation](https://www.tensorflow.org/tfx/data_validation/api_docs/python/tfdv)
  to generate a schema from input statistics. The following TFX libraries use
  the schema:
    - TensorFlow Data Validation
    - TensorFlow Transform
    - TensorFlow Model Analysis

  In a typical TFX pipeline, the SchemaGen component generates a schema which
  is consumed by the other pipeline components.

  ## Example
  ```
    # Generates schema based on statistics files.
    infer_schema = SchemaGen(statistics=statistics_gen.outputs['statistics'])
  ```

  Component `outputs` contains:
   - `schema`: Channel of type `standard_artifacts.Schema` for schema
   result.

  See [the SchemaGen guide](https://www.tensorflow.org/tfx/guide/schemagen)
  for more details.
  """
  SPEC_CLASS = standard_component_specs.SchemaGenSpec
  EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(executor.Executor)

  def __init__(
      self,
      statistics: types.BaseChannel,
      infer_feature_shape: Optional[Union[bool,
                                          data_types.RuntimeParameter]] = True,
      exclude_splits: Optional[List[str]] = None):
    """Constructs a SchemaGen component.

    Args:
      statistics: A BaseChannel of `ExampleStatistics` type (required if spec is
        not passed). This should contain at least a `train` split. Other splits
        are currently ignored. _required_
      infer_feature_shape: Boolean (or RuntimeParameter) value indicating
        whether or not to infer the shape of features. If the feature shape is
        not inferred, the downstream TensorFlow Transform component using the
        schema will parse input as tf.SparseTensor. Defaults to True if not set.
      exclude_splits: Names of splits that will not be taken into consideration
        when auto-generating a schema. Default behavior (when exclude_splits is
        set to None) is excluding no splits.
    """
    if exclude_splits is None:
      exclude_splits = []
      logging.info('Excluding no splits because exclude_splits is not set.')
    schema = types.Channel(type=standard_artifacts.Schema)
    if isinstance(infer_feature_shape, bool):
      infer_feature_shape = int(infer_feature_shape)
    spec = standard_component_specs.SchemaGenSpec(
        statistics=statistics,
        infer_feature_shape=infer_feature_shape,
        exclude_splits=json_utils.dumps(exclude_splits),
        schema=schema)
    super().__init__(spec=spec)
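
As a usage note, the following is a minimal sketch of the constructor options documented above. It assumes a `statistics_gen` component upstream, and the choice to exclude the `eval` split is purely illustrative; `infer_feature_shape` can also be supplied as a RuntimeParameter so that it can be set at run time.

```
# Minimal sketch (assumes an upstream StatisticsGen named statistics_gen).
from tfx.orchestration import data_types

# Hypothetical runtime parameter. SchemaGen converts bool values to int
# internally, so an int-typed parameter (0 or 1) is used here.
infer_shape = data_types.RuntimeParameter(
    name='infer_feature_shape', default=1, ptype=int)

schema_gen = SchemaGen(
    statistics=statistics_gen.outputs['statistics'],
    infer_feature_shape=infer_shape,
    exclude_splits=['eval'])  # illustrative: skip the eval split
```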
Code example #14
class Trainer(base_component.BaseComponent):
    """A TFX component to train a TensorFlow model.

  The Trainer component is used to train and evaluate a model using given
  inputs and a user-supplied run_fn function.

  An example of `run_fn()` can be found in the [user-supplied
  code](https://github.com/tensorflow/tfx/blob/master/tfx/examples/penguin/penguin_utils_keras.py)
  of the TFX penguin pipeline example.

  *Note:* The default executor for this component trains locally.  This can be
  overridden to enable the model to be trained on other platforms.  The [Cloud
  AI Platform custom
  executor](https://github.com/tensorflow/tfx/tree/master/tfx/extensions/google_cloud_ai_platform/trainer)
  provides an example of how to implement this.

  ## Example 1: Training locally
  ```
  # Uses user-provided Python function that trains a model using TF.
  trainer = Trainer(
      module_file=module_file,
      transformed_examples=transform.outputs['transformed_examples'],
      schema=infer_schema.outputs['schema'],
      transform_graph=transform.outputs['transform_graph'],
      train_args=proto.TrainArgs(splits=['train'], num_steps=10000),
      eval_args=proto.EvalArgs(splits=['eval'], num_steps=5000))
  ```

  ## Example 2: Training through a cloud provider
  ```
  from tfx.extensions.google_cloud_ai_platform.trainer import executor as ai_platform_trainer_executor
  # Train using Google Cloud AI Platform.
  trainer = Trainer(
      custom_executor_spec=executor_spec.ExecutorClassSpec(
          ai_platform_trainer_executor.GenericExecutor),
      module_file=module_file,
      transformed_examples=transform.outputs['transformed_examples'],
      schema=infer_schema.outputs['schema'],
      transform_graph=transform.outputs['transform_graph'],
      train_args=proto.TrainArgs(splits=['train'], num_steps=10000),
      eval_args=proto.EvalArgs(splits=['eval'], num_steps=5000))
  ```

  Component `outputs` contains:
   - `model`: Channel of type `standard_artifacts.Model` for trained model.
   - `model_run`: Channel of type `standard_artifacts.ModelRun`, the working
                  directory of the model, which can be used for non-model
                  outputs (e.g., TensorBoard logs).

  Please see [the Trainer guide](https://www.tensorflow.org/tfx/guide/trainer)
  for more details.
  """

    SPEC_CLASS = standard_component_specs.TrainerSpec
    EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(executor.GenericExecutor)

    def __init__(
            self,
            examples: types.Channel = None,
            transformed_examples: Optional[types.Channel] = None,
            transform_graph: Optional[types.Channel] = None,
            schema: Optional[types.Channel] = None,
            base_model: Optional[types.Channel] = None,
            hyperparameters: Optional[types.Channel] = None,
            module_file: Optional[Union[Text,
                                        data_types.RuntimeParameter]] = None,
            run_fn: Optional[Union[Text, data_types.RuntimeParameter]] = None,
            # TODO(b/147702778): deprecate trainer_fn.
            trainer_fn: Optional[Union[Text,
                                       data_types.RuntimeParameter]] = None,
            train_args: Union[trainer_pb2.TrainArgs, Dict[Text, Any]] = None,
            eval_args: Union[trainer_pb2.EvalArgs, Dict[Text, Any]] = None,
            custom_config: Optional[Dict[Text, Any]] = None,
            custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None):
        """Construct a Trainer component.

    Args:
      examples: A Channel of type `standard_artifacts.Examples`, serving as
        the source of examples used in training (required). May be raw or
        transformed.
      transformed_examples: Deprecated field. Please set 'examples' instead.
      transform_graph: An optional Channel of type
        `standard_artifacts.TransformGraph`, serving as the input transform
        graph if present.
      schema:  An optional Channel of type `standard_artifacts.Schema`, serving
        as the schema of training and eval data. The schema is optional when
        1) transform_graph is provided, which already contains the schema, or
        2) the user module bypasses the usage of schema, e.g., hardcodes it.
      base_model: A Channel of type `Model`, containing model that will be used
        for training. This can be used for warmstart, transfer learning or
        model ensembling.
      hyperparameters: A Channel of type `standard_artifacts.HyperParameters`,
        serving as the hyperparameters for the training module. Tuner's output
        best hyperparameters can be fed into this.
      module_file: A path to python module file containing UDF model definition.
        The module_file must implement a function named `run_fn` at its top
        level with function signature:
          `def run_fn(trainer.fn_args_utils.FnArgs)`,
        and the trained model must be saved to FnArgs.serving_model_dir when
        this function is executed.

        For Estimator based Executor, the module_file must implement a function
        named `trainer_fn` at its top level. The function must have the
        following signature.
          def trainer_fn(trainer.fn_args_utils.FnArgs,
                         tensorflow_metadata.proto.v0.schema_pb2) -> Dict:
            ...
          where the returned Dict has the following key-values.
            'estimator': an instance of tf.estimator.Estimator
            'train_spec': an instance of tf.estimator.TrainSpec
            'eval_spec': an instance of tf.estimator.EvalSpec
            'eval_input_receiver_fn': an instance of tfma EvalInputReceiver.
        Exactly one of 'module_file' or 'run_fn' must be supplied if Trainer
        uses GenericExecutor (default). Use of a RuntimeParameter for this
        argument is experimental.
      run_fn:  A python path to UDF model definition function for generic
        trainer. See 'module_file' for details. Exactly one of 'module_file' or
        'run_fn' must be supplied if Trainer uses GenericExecutor (default).
         Use of a RuntimeParameter for this argument is experimental.
      trainer_fn:  A python path to UDF model definition function for estimator
        based trainer. See 'module_file' for the required signature of the UDF.
        Exactly one of 'module_file' or 'trainer_fn' must be supplied if Trainer
        uses Estimator based Executor. Use of a RuntimeParameter for this
        argument is experimental.
      train_args: A proto.TrainArgs instance or a dict, containing args
        used for training. Currently only splits and num_steps are available. If
        it's provided as a dict and any field is a RuntimeParameter, it should
        have the same field names as a TrainArgs proto message. Default
        behavior (when splits is empty) is train on `train` split.
      eval_args: A proto.EvalArgs instance or a dict, containing args
        used for evaluation. Currently only splits and num_steps are available.
        If it's provided as a dict and any field is a RuntimeParameter, it
        should have the same field names as a EvalArgs proto message. Default
        behavior (when splits is empty) is evaluate on `eval` split.
      custom_config: A dict which contains additional training job parameters
        that will be passed into the user module.
      custom_executor_spec: Optional custom executor spec. This is experimental
        and is subject to change in the future.

    Raises:
      ValueError:
        - When both or neither of 'module_file' and user function
          (e.g., trainer_fn and run_fn) is supplied.
        - When both or neither of 'examples' and 'transformed_examples'
            is supplied.
        - When 'transformed_examples' is supplied but 'transform_graph'
            is not supplied.
    """
        if [bool(module_file),
                bool(run_fn),
                bool(trainer_fn)].count(True) != 1:
            raise ValueError(
                "Exactly one of 'module_file', 'trainer_fn', or 'run_fn' must be "
                "supplied.")

        if bool(examples) == bool(transformed_examples):
            raise ValueError(
                "Exactly one of 'examples' or 'transformed_examples' must be "
                "supplied.")

        if transformed_examples and not transform_graph:
            raise ValueError("If 'transformed_examples' is supplied, "
                             "'transform_graph' must be supplied too.")

        if custom_executor_spec:
            logging.warning(
                "`custom_executor_spec` is going to be deprecated.")
        examples = examples or transformed_examples
        model = types.Channel(type=standard_artifacts.Model)
        model_run = types.Channel(type=standard_artifacts.ModelRun)
        spec = standard_component_specs.TrainerSpec(
            examples=examples,
            transform_graph=transform_graph,
            schema=schema,
            base_model=base_model,
            hyperparameters=hyperparameters,
            train_args=train_args,
            eval_args=eval_args,
            module_file=module_file,
            run_fn=run_fn,
            trainer_fn=trainer_fn,
            custom_config=json_utils.dumps(custom_config),
            model=model,
            model_run=model_run)
        super(Trainer,
              self).__init__(spec=spec,
                             custom_executor_spec=custom_executor_spec)

        if udf_utils.should_package_user_modules():
            # In this case, the `MODULE_PATH_KEY` execution property will be injected
            # as a reference to the given user module file after packaging, at which
            # point the `MODULE_FILE_KEY` execution property will be removed.
            udf_utils.add_user_module_dependency(
                self, standard_component_specs.MODULE_FILE_KEY,
                standard_component_specs.MODULE_PATH_KEY)
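
The docstring above specifies the `run_fn` contract but no module file is shown, so here is a minimal hedged sketch of what such a module might contain. The architecture and input shape are placeholders, and a real module would also build `tf.data` datasets from `fn_args.train_files` and `fn_args.eval_files` (as the linked penguin example does).

```
# Sketch of a module_file satisfying the GenericExecutor contract:
# run_fn(FnArgs) trains a model and saves it to fn_args.serving_model_dir.
import tensorflow as tf
from tfx.components.trainer.fn_args_utils import FnArgs


def run_fn(fn_args: FnArgs):
  # Placeholder architecture; a real module would derive the input shape
  # from the schema or the transform graph.
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(8, activation='relu', input_shape=(4,)),
      tf.keras.layers.Dense(3, activation='softmax'),
  ])
  model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
  # Dataset construction and model.fit(...) on fn_args.train_files /
  # fn_args.eval_files are omitted for brevity.
  model.save(fn_args.serving_model_dir, save_format='tf')
```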
Code example #15
File: iris_pipeline_sklearn_gcp.py  Project: lre/tfx
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text, metadata_path: Text,
                     ai_platform_training_args: Optional[Dict[Text, Text]],
                     ai_platform_serving_args: Optional[Dict[Text, Text]],
                     beam_pipeline_args: List[Text]) -> pipeline.Pipeline:
  """Implements the Iris flowers pipeline with TFX."""
  # Brings data into the pipeline or otherwise joins/converts training data.
  example_gen = CsvExampleGen(input_base=data_root)

  # Computes statistics over data for visualization and example validation.
  statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

  # Generates schema based on statistics files.
  schema_gen = SchemaGen(
      statistics=statistics_gen.outputs['statistics'], infer_feature_shape=True)

  # Performs anomaly detection based on statistics and data schema.
  example_validator = ExampleValidator(
      statistics=statistics_gen.outputs['statistics'],
      schema=schema_gen.outputs['schema'])

  # TODO(humichael): Handle applying transformation component in Milestone 3.

  # Uses user-provided Python function that trains a model with scikit-learn.
  # Num_steps is not provided during evaluation because the scikit-learn model
  # loads and evaluates the entire test set at once.
  # TODO(b/159470716): Make schema optional in Trainer.
  trainer = Trainer(
      module_file=module_file,
      custom_executor_spec=executor_spec.ExecutorClassSpec(
          ai_platform_trainer_executor.GenericExecutor),
      examples=example_gen.outputs['examples'],
      schema=schema_gen.outputs['schema'],
      train_args=trainer_pb2.TrainArgs(num_steps=2000),
      eval_args=trainer_pb2.EvalArgs(),
      custom_config={
          ai_platform_trainer_executor.TRAINING_ARGS_KEY:
          ai_platform_training_args,
      })

  # TODO(humichael): Add Evaluator once it's decided how to proceed with
  # Milestone 2.

  pusher = Pusher(
      custom_executor_spec=executor_spec.ExecutorClassSpec(
          ai_platform_pusher_executor.Executor),
      model=trainer.outputs['model'],
      custom_config={
          ai_platform_pusher_executor.SERVING_ARGS_KEY:
          ai_platform_serving_args,
      })

  return pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=[
          example_gen,
          statistics_gen,
          schema_gen,
          example_validator,
          trainer,
          pusher,
      ],
      enable_cache=True,
      metadata_connection_config=metadata.sqlite_metadata_connection_config(
          metadata_path),
      beam_pipeline_args=beam_pipeline_args,
  )
Code example #16
def _get_executor_spec(self):
    return executor_spec.ExecutorClassSpec(self.get_executor())
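
For context on this two-line helper: it feeds the same pattern the SchemaGen and Trainer classes above use, where a component declares its executor by wrapping an executor class in `executor_spec.ExecutorClassSpec`. A hedged sketch with hypothetical names:

```
# Hypothetical no-op component showing where an ExecutorClassSpec lands.
from tfx import types
from tfx.components.base import base_component, base_executor, executor_spec


class _NoOpExecutor(base_executor.BaseExecutor):
  def Do(self, input_dict, output_dict, exec_properties):
    pass  # illustration only; a real executor performs the component's work


class NoOpComponent(base_component.BaseComponent):
  SPEC_CLASS = types.ComponentSpec  # a concrete ComponentSpec subclass in practice
  EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(_NoOpExecutor)
```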
Code example #17
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text, serving_model_dir: Text,
                     metadata_path: Text,
                     beam_pipeline_args: List[Text]) -> pipeline.Pipeline:
    """Implements the penguin pipeline with TFX."""
    # Brings data into the pipeline or otherwise joins/converts training data.
    example_gen = CsvExampleGen(input_base=data_root)

    # Computes statistics over data for visualization and example validation.
    statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

    # Generates schema based on statistics files.
    schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'],
                           infer_feature_shape=True)

    # Performs anomaly detection based on statistics and data schema.
    example_validator = ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_gen.outputs['schema'])

    # Performs transformations and feature engineering in training and serving.
    transform = Transform(examples=example_gen.outputs['examples'],
                          schema=schema_gen.outputs['schema'],
                          module_file=module_file)

    # Uses user-provided Python function that trains a model.
    trainer = Trainer(
        module_file=module_file,
        custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
        examples=transform.outputs['transformed_examples'],
        transform_graph=transform.outputs['transform_graph'],
        schema=schema_gen.outputs['schema'],
        train_args=trainer_pb2.TrainArgs(num_steps=2000),
        eval_args=trainer_pb2.EvalArgs(num_steps=5))

    # Get the latest blessed model for model validation.
    model_resolver = ResolverNode(
        instance_name='latest_blessed_model_resolver',
        resolver_class=latest_blessed_model_resolver.
        LatestBlessedModelResolver,
        model=Channel(type=Model),
        model_blessing=Channel(type=ModelBlessing))

    # Uses TFMA to compute evaluation statistics over features of a model and
    # perform quality validation of a candidate model (compared to a baseline).
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(label_key='species')],
        slicing_specs=[tfma.SlicingSpec()],
        metrics_specs=[
            tfma.MetricsSpec(metrics=[
                tfma.MetricConfig(
                    class_name='SparseCategoricalAccuracy',
                    threshold=tfma.MetricThreshold(
                        value_threshold=tfma.GenericValueThreshold(
                            lower_bound={'value': 0.6}),
                        change_threshold=tfma.GenericChangeThreshold(
                            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                            absolute={'value': -1e-10})))
            ])
        ])
    evaluator = Evaluator(
        examples=example_gen.outputs['examples'],
        model=trainer.outputs['model'],
        baseline_model=model_resolver.outputs['model'],
        # Change threshold will be ignored if there is no baseline (first run).
        eval_config=eval_config)

    # Performs infra validation of a candidate model to prevent unservable model
    # from being pushed. This config will launch a model server of the latest
    # TensorFlow Serving image in a local docker engine.
    infra_validator = InfraValidator(
        model=trainer.outputs['model'],
        examples=example_gen.outputs['examples'],
        serving_spec=infra_validator_pb2.ServingSpec(
            tensorflow_serving=infra_validator_pb2.TensorFlowServing(
                tags=['latest']),
            local_docker=infra_validator_pb2.LocalDockerConfig()),
        request_spec=infra_validator_pb2.RequestSpec(
            tensorflow_serving=infra_validator_pb2.
            TensorFlowServingRequestSpec()))

    # Checks whether the model passed the validation steps and pushes the model
    # to a file destination if check passed.
    pusher = Pusher(model=trainer.outputs['model'],
                    model_blessing=evaluator.outputs['blessing'],
                    infra_blessing=infra_validator.outputs['blessing'],
                    push_destination=pusher_pb2.PushDestination(
                        filesystem=pusher_pb2.PushDestination.Filesystem(
                            base_directory=serving_model_dir)))

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[
            example_gen,
            statistics_gen,
            schema_gen,
            example_validator,
            transform,
            trainer,
            model_resolver,
            evaluator,
            infra_validator,
            pusher,
        ],
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
        beam_pipeline_args=beam_pipeline_args)
Code example #18
def create_pipeline(
        pipeline_name: Text,
        pipeline_root: Text,
        data_root: Text,
        module_file: Text,
        ai_platform_training_args: Dict[Text, Text],
        pusher_custom_config: Dict[Text, Text],
        enable_tuning: bool,
        beam_pipeline_args: Optional[List[Text]] = None) -> pipeline.Pipeline:
    """Implements the penguin pipeline with TFX and Kubeflow Pipeline.

  Args:
    pipeline_name: name of the TFX pipeline being created.
    pipeline_root: root directory of the pipeline. Should be a valid GCS path.
    data_root: uri of the penguin data.
    module_file: uri of the module files used in Trainer and Transform
      components.
    ai_platform_training_args: Args of CAIP training job. Please refer to
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#Job
      for detailed description.
    pusher_custom_config: Custom configs passed to pusher.
    enable_tuning: If True, the hyperparameter tuning through CloudTuner is
      enabled.
    beam_pipeline_args: Optional list of beam pipeline options. Please refer to
      https://cloud.google.com/dataflow/docs/guides/specifying-exec-params#setting-other-cloud-dataflow-pipeline-options.
      When this argument is not provided, the default is to use GCP
      DataflowRunner with 50GB disk size as specified in this function. If an
      empty list is passed in, default specified by Beam will be used, which can
      be found at
      https://cloud.google.com/dataflow/docs/guides/specifying-exec-params#setting-other-cloud-dataflow-pipeline-options

  Returns:
    A TFX pipeline object.
  """
    examples = external_input(data_root)

    # Beam args to run data processing on DataflowRunner.
    #
    # TODO(b/151114974): Remove `disk_size_gb` flag after default is increased.
    # TODO(b/151116587): Remove `shuffle_mode` flag after default is changed.
    # TODO(b/156874687): Remove `machine_type` after IP addresses are no longer a
    #                    scaling bottleneck.
    if beam_pipeline_args is None:
        beam_pipeline_args = [
            '--runner=DataflowRunner',
            '--project=' + _project_id,
            '--temp_location=' + os.path.join(_output_bucket, 'tmp'),
            '--region=' + _gcp_region,

            # Temporary overrides of defaults.
            '--disk_size_gb=50',
            '--experiments=shuffle_mode=auto',
            '--machine_type=e2-standard-8',
        ]

    # Number of training steps.
    train_steps = data_types.RuntimeParameter(
        name='train_steps',
        default=100,
        ptype=int,
    )

    # Number of evaluation steps.
    eval_steps = data_types.RuntimeParameter(
        name='eval_steps',
        default=50,
        ptype=int,
    )

    # Brings data into the pipeline or otherwise joins/converts training data.
    example_gen = CsvExampleGen(input=examples)

    # Computes statistics over data for visualization and example validation.
    statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

    # Generates schema based on statistics files.
    schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'],
                           infer_feature_shape=True)

    # Performs anomaly detection based on statistics and data schema.
    example_validator = ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_gen.outputs['schema'])

    # Performs transformations and feature engineering in training and serving.
    transform = Transform(examples=example_gen.outputs['examples'],
                          schema=schema_gen.outputs['schema'],
                          module_file=module_file)

    # Update ai_platform_training_args if distributed training is enabled.
    # Number of worker machines used in distributed training.
    worker_count = data_types.RuntimeParameter(
        name='worker_count',
        default=2,
        ptype=int,
    )

    # Type of worker machines used in distributed training.
    worker_type = data_types.RuntimeParameter(
        name='worker_type',
        default='standard',
        ptype=str,
    )

    local_training_args = copy.deepcopy(ai_platform_training_args)
    if FLAGS.distributed_training:
        local_training_args.update({
            # You can specify the machine types, the number of replicas for workers
            # and parameter servers.
            # https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#ScaleTier
            'scaleTier': 'CUSTOM',
            'masterType': 'large_model',
            'workerType': worker_type,
            'parameterServerType': 'standard',
            'workerCount': worker_count,
            'parameterServerCount': 1,
        })

    # Tunes the hyperparameters for model training based on a user-provided
    # Python function. Note that once the hyperparameters are tuned, you can
    # drop the Tuner component from the pipeline and feed Trainer the tuned
    # hyperparameters.
    if enable_tuning:
        # The Tuner component launches 1 AIP Training job for flock management.
        # For example, 3 workers (defined by num_parallel_trials) in the flock
        # management AIP Training job each run Tuner.Executor.
        # Then, 3 AIP Training jobs (defined by local_training_args) are invoked
        # from each worker in the flock management job for trial execution.
        tuner = Tuner(
            module_file=module_file,
            examples=transform.outputs['transformed_examples'],
            transform_graph=transform.outputs['transform_graph'],
            train_args={'num_steps': train_steps},
            eval_args={'num_steps': eval_steps},
            tune_args=tuner_pb2.TuneArgs(
                # num_parallel_trials=3 means that 3 search loops are
                # running in parallel.
                # Each tuner may include a distributed training job which can be
                # specified in local_training_args above (e.g. 1 PS + 2 workers).
                num_parallel_trials=3),
            custom_config={
                # Configures Cloud AI Platform-specific configs. For details, see
                # https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#traininginput.
                ai_platform_trainer_executor.TRAINING_ARGS_KEY:
                local_training_args
            })

    # Uses user-provided Python function that trains a model.
    trainer = Trainer(
        custom_executor_spec=executor_spec.ExecutorClassSpec(
            ai_platform_trainer_executor.GenericExecutor),
        module_file=module_file,
        examples=transform.outputs['transformed_examples'],
        transform_graph=transform.outputs['transform_graph'],
        schema=schema_gen.outputs['schema'],
        # If Tuner is in the pipeline, Trainer can take Tuner's output
        # best_hyperparameters artifact as input and utilize it in the user module
        # code.
        #
        # If there isn't Tuner in the pipeline, either use ImporterNode to import
        # a previous Tuner's output to feed to Trainer, or directly use the tuned
        # hyperparameters in user module code and set hyperparameters to None
        # here.
        #
        # Example of ImporterNode,
        #   hparams_importer = ImporterNode(
        #     instance_name='import_hparams',
        #     source_uri='path/to/best_hyperparameters.txt',
        #     artifact_type=HyperParameters)
        #   ...
        #   hyperparameters = hparams_importer.outputs['result'],
        hyperparameters=(tuner.outputs['best_hyperparameters']
                         if enable_tuning else None),
        train_args={'num_steps': train_steps},
        eval_args={'num_steps': eval_steps},
        custom_config={
            ai_platform_trainer_executor.TRAINING_ARGS_KEY: local_training_args
        })

    # Get the latest blessed model for model validation.
    model_resolver = ResolverNode(
        instance_name='latest_blessed_model_resolver',
        resolver_class=latest_blessed_model_resolver.
        LatestBlessedModelResolver,
        model=Channel(type=Model),
        model_blessing=Channel(type=ModelBlessing))

    # Uses TFMA to compute evaluation statistics over features of a model and
    # perform quality validation of a candidate model (compared to a baseline).
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(label_key='species')],
        slicing_specs=[tfma.SlicingSpec()],
        metrics_specs=[
            tfma.MetricsSpec(metrics=[
                tfma.MetricConfig(
                    class_name='SparseCategoricalAccuracy',
                    threshold=tfma.MetricThreshold(
                        value_threshold=tfma.GenericValueThreshold(
                            lower_bound={'value': 0.6}),
                        change_threshold=tfma.GenericChangeThreshold(
                            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                            absolute={'value': -1e-10})))
            ])
        ])
    evaluator = Evaluator(
        examples=example_gen.outputs['examples'],
        model=trainer.outputs['model'],
        baseline_model=model_resolver.outputs['model'],
        # Change threshold will be ignored if there is no baseline (first run).
        eval_config=eval_config)

    pusher = Pusher(
        custom_executor_spec=executor_spec.ExecutorClassSpec(
            ai_platform_pusher_executor.Executor),
        model=trainer.outputs['model'],
        model_blessing=evaluator.outputs['blessing'],
        custom_config=pusher_custom_config,
    )

    components = [
        example_gen,
        statistics_gen,
        schema_gen,
        example_validator,
        transform,
        trainer,
        model_resolver,
        evaluator,
        pusher,
    ]
    if enable_tuning:
        components.append(tuner)

    return pipeline.Pipeline(pipeline_name=pipeline_name,
                             pipeline_root=pipeline_root,
                             components=components,
                             enable_cache=True,
                             beam_pipeline_args=beam_pipeline_args)
Code example #19
    def benchmark(self,
                  data_dir: str = None,
                  use_keras: bool = True,
                  enable_tuning: bool = True):
        # Use TFDSTask to define the task for the titanic dataset.
        task = nitroml.tasks.TFDSTask(
            tfds.builder('titanic', data_dir=data_dir))

        autodata = nitroml.autodata.AutoData(
            task.problem_statement,
            examples=task.train_and_eval_examples,
            preprocessor=nitroml.autodata.BasicPreprocessor())

        pipeline = task.components + autodata.components

        if enable_tuning:
            # Search over search space of model hyperparameters.
            tuner = tfx.Tuner(
                tuner_fn='examples.auto_trainer.tuner_fn',
                examples=autodata.transformed_examples,
                transform_graph=autodata.transform_graph,
                train_args=trainer_pb2.TrainArgs(num_steps=100),
                eval_args=trainer_pb2.EvalArgs(num_steps=50),
                custom_config={
                    # Pass the problem statement proto as a text proto. Required
                    # since custom_config must be JSON-serializable.
                    'problem_statement':
                    text_format.MessageToString(message=task.problem_statement,
                                                as_utf8=True),
                })
            pipeline.append(tuner)

        # Define a Trainer to train our model on the given task.
        trainer = tfx.Trainer(
            run_fn='examples.auto_trainer.run_fn'
            if use_keras else 'examples.auto_estimator_trainer.run_fn',
            custom_executor_spec=(executor_spec.ExecutorClassSpec(
                trainer_executor.GenericExecutor)),
            transformed_examples=autodata.transformed_examples,
            transform_graph=autodata.transform_graph,
            schema=autodata.schema,
            train_args=trainer_pb2.TrainArgs(num_steps=1000),
            eval_args=trainer_pb2.EvalArgs(num_steps=500),
            hyperparameters=(tuner.outputs.best_hyperparameters
                             if enable_tuning else None),
            custom_config={
                # Pass the problem statement proto as a text proto. Required
                # since custom_config must be JSON-serializable.
                'problem_statement':
                text_format.MessageToString(message=task.problem_statement,
                                            as_utf8=True),
            })

        pipeline.append(trainer)

        # Finally, call evaluate() on the workflow DAG outputs. This will
        # automatically append Evaluators to compute metrics from the given
        # SavedModel and 'eval' TF Examples.
        self.evaluate(pipeline,
                      examples=task.train_and_eval_examples,
                      model=trainer.outputs.model)
Code example #20
def create_pipeline(pipeline_name: Text,
                    pipeline_root: Text,
                    data_root_uri: data_types.RuntimeParameter,
                    train_steps: data_types.RuntimeParameter,
                    eval_steps: data_types.RuntimeParameter,
                    enable_tuning: bool,
                    ai_platform_training_args: Dict[Text, Text],
                    ai_platform_serving_args: Dict[Text, Text],
                    beam_pipeline_args: List[Text],
                    enable_cache: Optional[bool] = False) -> pipeline.Pipeline:
    """Trains and deploys the Keras Covertype Classifier with TFX and Kubeflow Pipeline on Google Cloud.
  Args:
    pipeline_name: name of the TFX pipeline being created.
    pipeline_root: root directory of the pipeline. Should be a valid GCS path.
    data_root_uri: uri of the dataset.
    train_steps: runtime parameter for number of model training steps for the Trainer component.
    eval_steps: runtime parameter for number of model evaluation steps for the Trainer component.
    enable_tuning: If True, the hyperparameter tuning through CloudTuner is
      enabled.
    ai_platform_training_args: Args of CAIP training job. Please refer to
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#Job
      for detailed description.
    ai_platform_serving_args: Args of CAIP model deployment. Please refer to
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.models
      for detailed description.
    beam_pipeline_args: Optional list of beam pipeline options. Please refer to
      https://cloud.google.com/dataflow/docs/guides/specifying-exec-params#setting-other-cloud-dataflow-pipeline-options.
      When this argument is not provided, the default is to use GCP
      DataflowRunner with 50GB disk size as specified in this function. If an
      empty list is passed in, default specified by Beam will be used, which can
      be found at
      https://cloud.google.com/dataflow/docs/guides/specifying-exec-params#setting-other-cloud-dataflow-pipeline-options
    enable_cache: Optional boolean; when True, component-level caching is enabled.
  Returns:
    A TFX pipeline object.
  """

    # Brings data into the pipeline and splits the data into training and eval splits
    output_config = example_gen_pb2.Output(
        split_config=example_gen_pb2.SplitConfig(splits=[
            example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=4),
            example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1)
        ]))

    examplegen = CsvExampleGen(input_base=data_root_uri, output_config=output_config)

    # Computes statistics over data for visualization and example validation.
    statisticsgen = StatisticsGen(examples=examplegen.outputs.examples)

    # Generates schema based on statistics files. Even though we use a
    # user-provided schema, we still want to generate the schema of the newest
    # data for tracking and comparison.
    schemagen = SchemaGen(statistics=statisticsgen.outputs.statistics)

    # Import a user-provided schema
    import_schema = ImporterNode(instance_name='import_user_schema',
                                 source_uri=SCHEMA_FOLDER,
                                 artifact_type=Schema)

    # Performs anomaly detection based on statistics and data schema.
    examplevalidator = ExampleValidator(
        statistics=statisticsgen.outputs.statistics,
        schema=import_schema.outputs.result)

    # Performs transformations and feature engineering in training and serving.
    transform = Transform(examples=examplegen.outputs.examples,
                          schema=import_schema.outputs.result,
                          module_file=TRANSFORM_MODULE_FILE)

    # Tunes the hyperparameters for model training based on a user-provided
    # Python function. Note that once the hyperparameters are tuned, you can
    # drop the Tuner component from the pipeline and feed Trainer the tuned
    # hyperparameters.
    if enable_tuning:
        # The Tuner component launches 1 AI Platform Training job for flock
        # management. For example, the n_workers (defined by
        # num_parallel_trials) in the flock management job each run
        # Tuner.Executor in parallel.
        tuner = Tuner(
            module_file=TRAIN_MODULE_FILE,
            examples=transform.outputs.transformed_examples,
            transform_graph=transform.outputs.transform_graph,
            train_args={'num_steps': train_steps},
            eval_args={'num_steps': eval_steps},
            tune_args=tuner_pb2.TuneArgs(
                # num_parallel_trials can be configured for distributed training.
                num_parallel_trials=1),
            custom_config={
                # Configures Cloud AI Platform-specific configs. For details, see
                # https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#traininginput.
                ai_platform_trainer_executor.TRAINING_ARGS_KEY:
                ai_platform_training_args
            })

    # Trains the model using a user-provided trainer function.
    trainer = Trainer(
        custom_executor_spec=executor_spec.ExecutorClassSpec(
            ai_platform_trainer_executor.GenericExecutor),
        module_file=TRAIN_MODULE_FILE,
        transformed_examples=transform.outputs.transformed_examples,
        schema=import_schema.outputs.result,
        transform_graph=transform.outputs.transform_graph,
        hyperparameters=(tuner.outputs.best_hyperparameters
                         if enable_tuning else None),
        train_args={'num_steps': train_steps},
        eval_args={'num_steps': eval_steps},
        custom_config={'ai_platform_training_args': ai_platform_training_args})

    # Get the latest blessed model for model validation.
    resolver = ResolverNode(instance_name='latest_blessed_model_resolver',
                            resolver_class=latest_blessed_model_resolver.
                            LatestBlessedModelResolver,
                            model=Channel(type=Model),
                            model_blessing=Channel(type=ModelBlessing))

    # Uses TFMA to compute evaluation statistics over features of a model.
    accuracy_threshold = tfma.MetricThreshold(
        value_threshold=tfma.GenericValueThreshold(
            lower_bound={'value': 0.5},
            upper_bound={'value': 0.99}))

    metrics_specs = tfma.MetricsSpec(metrics=[
        tfma.MetricConfig(class_name='SparseCategoricalAccuracy',
                          threshold=accuracy_threshold),
        tfma.MetricConfig(class_name='ExampleCount')
    ])

    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(label_key='Cover_Type')],
        metrics_specs=[metrics_specs],
        slicing_specs=[
            tfma.SlicingSpec(),
            tfma.SlicingSpec(feature_keys=['Wilderness_Area'])
        ])

    evaluator = Evaluator(examples=examplegen.outputs.examples,
                          model=trainer.outputs.model,
                          baseline_model=resolver.outputs.model,
                          eval_config=eval_config)

    # Validates that the model can be loaded and queried in a sandboxed
    # environment mirroring production.
    serving_config = infra_validator_pb2.ServingSpec(
        tensorflow_serving=infra_validator_pb2.TensorFlowServing(
            tags=['latest']),
        kubernetes=infra_validator_pb2.KubernetesConfig(),
    )

    validation_config = infra_validator_pb2.ValidationSpec(
        max_loading_time_seconds=60,
        num_tries=3,
    )

    request_config = infra_validator_pb2.RequestSpec(
        tensorflow_serving=infra_validator_pb2.TensorFlowServingRequestSpec(),
        num_examples=3,
    )

    infravalidator = InfraValidator(
        model=trainer.outputs.model,
        examples=examplegen.outputs.examples,
        serving_spec=serving_config,
        validation_spec=validation_config,
        request_spec=request_config,
    )

    # Checks whether the model passed the validation steps and pushes the model
    # to CAIP Prediction if checks are passed.
    pusher = Pusher(custom_executor_spec=executor_spec.ExecutorClassSpec(
        ai_platform_pusher_executor.Executor),
                    model=trainer.outputs.model,
                    model_blessing=evaluator.outputs.blessing,
                    infra_blessing=infravalidator.outputs.blessing,
                    custom_config={
                        ai_platform_pusher_executor.SERVING_ARGS_KEY:
                        ai_platform_serving_args
                    })

    components = [
        examplegen, statisticsgen, schemagen, import_schema, examplevalidator,
        transform, trainer, resolver, evaluator, infravalidator, pusher
    ]

    if enable_tuning:
        components.append(tuner)

    return pipeline.Pipeline(pipeline_name=pipeline_name,
                             pipeline_root=pipeline_root,
                             components=components,
                             enable_cache=enable_cache,
                             beam_pipeline_args=beam_pipeline_args)
Code example #21
class _FakeComponent(base_component.BaseComponent):
    SPEC_CLASS = types.ComponentSpec
    EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(
        _get_fake_executor(label))
Code example #22
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text, serving_model_dir: Text,
                     metadata_path: Text, enable_tuning: bool,
                     beam_pipeline_args: List[Text]) -> pipeline.Pipeline:
    """Implements the Iris flowers pipeline with TFX."""
    examples = external_input(data_root)

    # Brings data into the pipeline or otherwise joins/converts training data.
    example_gen = CsvExampleGen(input=examples)

    # Computes statistics over data for visualization and example validation.
    statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

    # Generates schema based on statistics files.
    schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'],
                           infer_feature_shape=True)

    # Performs anomaly detection based on statistics and data schema.
    example_validator = ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_gen.outputs['schema'])

    # Performs transformations and feature engineering in training and serving.
    transform = Transform(examples=example_gen.outputs['examples'],
                          schema=schema_gen.outputs['schema'],
                          module_file=module_file)

    # Tunes the hyperparameters for model training based on a user-provided
    # Python function. Note that once the hyperparameters are tuned, you can
    # drop the Tuner component from the pipeline and feed Trainer the tuned
    # hyperparameters.
    if enable_tuning:
        tuner = Tuner(module_file=module_file,
                      examples=transform.outputs['transformed_examples'],
                      transform_graph=transform.outputs['transform_graph'],
                      train_args=trainer_pb2.TrainArgs(num_steps=20),
                      eval_args=trainer_pb2.EvalArgs(num_steps=5))

    # Uses user-provided Python function that trains a model.
    trainer = Trainer(
        module_file=module_file,
        custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
        examples=transform.outputs['transformed_examples'],
        transform_graph=transform.outputs['transform_graph'],
        schema=schema_gen.outputs['schema'],
        # If Tuner is in the pipeline, Trainer can take Tuner's output
        # best_hyperparameters artifact as input and utilize it in the user module
        # code.
        #
        # If there isn't Tuner in the pipeline, either use ImporterNode to import
        # a previous Tuner's output to feed to Trainer, or directly use the tuned
        # hyperparameters in user module code and set hyperparameters to None
        # here.
        #
        # Example of ImporterNode,
        #   hparams_importer = ImporterNode(
        #     instance_name='import_hparams',
        #     source_uri='path/to/best_hyperparameters.txt',
        #     artifact_type=HyperParameters)
        #   ...
        #   hyperparameters = hparams_importer.outputs['result'],
        hyperparameters=(tuner.outputs['best_hyperparameters']
                         if enable_tuning else None),
        train_args=trainer_pb2.TrainArgs(num_steps=100),
        eval_args=trainer_pb2.EvalArgs(num_steps=5))

    # Get the latest blessed model for model validation.
    model_resolver = ResolverNode(
        instance_name='latest_blessed_model_resolver',
        resolver_class=latest_blessed_model_resolver.
        LatestBlessedModelResolver,
        model=Channel(type=Model),
        model_blessing=Channel(type=ModelBlessing))

    # Uses TFMA to compute evaluation statistics over features of a model and
    # perform quality validation of a candidate model (compared to a baseline).
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(label_key='variety')],
        slicing_specs=[tfma.SlicingSpec()],
        metrics_specs=[
            tfma.MetricsSpec(metrics=[
                tfma.MetricConfig(
                    class_name='SparseCategoricalAccuracy',
                    threshold=tfma.MetricThreshold(
                        value_threshold=tfma.GenericValueThreshold(
                            lower_bound={'value': 0.6}),
                        change_threshold=tfma.GenericChangeThreshold(
                            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                            absolute={'value': -1e-10})))
            ])
        ])
    evaluator = Evaluator(
        examples=example_gen.outputs['examples'],
        model=trainer.outputs['model'],
        baseline_model=model_resolver.outputs['model'],
        # Change threshold will be ignored if there is no baseline (first run).
        eval_config=eval_config)

    # Checks whether the model passed the validation steps and pushes the model
    # to a file destination if check passed.
    pusher = Pusher(model=trainer.outputs['model'],
                    model_blessing=evaluator.outputs['blessing'],
                    push_destination=pusher_pb2.PushDestination(
                        filesystem=pusher_pb2.PushDestination.Filesystem(
                            base_directory=serving_model_dir)))

    components = [
        example_gen,
        statistics_gen,
        schema_gen,
        example_validator,
        transform,
        trainer,
        model_resolver,
        evaluator,
        pusher,
    ]
    if enable_tuning:
        components.append(tuner)

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=components,
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
        beam_pipeline_args=beam_pipeline_args)
Code example #23
class CsvExampleGen(component.FileBasedExampleGen):  # pylint: disable=protected-access
  """Official TFX CsvExampleGen component.

  The csv examplegen component takes csv data, and generates train
  and eval examples for downstream components.

  The csv examplegen encodes column values to tf.Example int/float/byte features.
  For the case when there are missing cells, the csv examplegen uses:
  -- tf.train.Feature(`type`_list=tf.train.`type`List(value=[])), when the
     `type` can be inferred.
  -- tf.train.Feature() when it cannot infer the `type` from the column.

  Note that type inference is performed per input split. If the input has more
  than one split, users need to ensure that the column types align across the
  pre-split files.

  For example, given the following csv rows of a split:

    header:A,B,C,D
    row1:  1,,x,0.1
    row2:  2,,y,0.2
    row3:  3,,,0.3
    row4:

  The output example will be
    example1: 1(int), empty feature(no type), x(string), 0.1(float)
    example2: 2(int), empty feature(no type), y(string), 0.2(float)
    example3: 3(int), empty feature(no type), empty list(string), 0.3(float)

    Note that the empty feature is `tf.train.Feature()` while empty list string
    feature is `tf.train.Feature(bytes_list=tf.train.BytesList(value=[]))`.
  """

  EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(executor.Executor)

  def __init__(
      self,
      # TODO(b/159467778): deprecate this, use input_base instead.
      input: Optional[types.Channel] = None,  # pylint: disable=redefined-builtin
      input_base: Optional[Text] = None,
      input_config: Optional[Union[example_gen_pb2.Input, Dict[Text,
                                                               Any]]] = None,
      output_config: Optional[Union[example_gen_pb2.Output, Dict[Text,
                                                                 Any]]] = None,
      range_config: Optional[Union[range_config_pb2.RangeConfig,
                                   Dict[Text, Any]]] = None,
      example_artifacts: Optional[types.Channel] = None,
      instance_name: Optional[Text] = None):
    """Construct a CsvExampleGen component.

    Args:
      input: A Channel of type `standard_artifacts.ExternalArtifact`, which
        includes one artifact whose uri is an external directory containing the
        CSV files. (Deprecated by input_base)
      input_base: an external directory containing the CSV files.
      input_config: An example_gen_pb2.Input instance, providing input
        configuration. If unset, the files under input_base will be treated as a
        single split. If any field is provided as a RuntimeParameter,
        input_config should be constructed as a dict with the same field names
        as Input proto message.
      output_config: An example_gen_pb2.Output instance, providing output
        configuration. If unset, default splits will be 'train' and 'eval' with
        size 2:1. If any field is provided as a RuntimeParameter,
        output_config should be constructed as a dict with the same field names
        as Output proto message.
      range_config: An optional range_config_pb2.RangeConfig instance,
        specifying the range of span values to consider. If unset, driver will
        default to searching for latest span with no restrictions.
      example_artifacts: Optional channel of 'ExamplesPath' for output train and
        eval examples.
      instance_name: Optional unique instance name. Necessary if multiple
        CsvExampleGen components are declared in the same pipeline.
    """
    if input:
      logging.warning(
          'The "input" argument to the CsvExampleGen component has been '
          'deprecated by "input_base". Please update your usage as support for '
          'this argument will be removed soon.')
      input_base = artifact_utils.get_single_uri(list(input.get()))
    super(CsvExampleGen, self).__init__(
        input_base=input_base,
        input_config=input_config,
        output_config=output_config,
        range_config=range_config,
        example_artifacts=example_artifacts,
        instance_name=instance_name)
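
To make the missing-cell encodings described in the docstring concrete, here is a small sketch using plain TensorFlow; the feature names refer to the sample rows in the docstring above.

```
# The two encodings used for missing cells, built explicitly for row3.
import tensorflow as tf

# Column B is empty in every row, so no type can be inferred:
untyped_empty = tf.train.Feature()
# Column C has a known type (string) but is missing in row3:
typed_empty = tf.train.Feature(bytes_list=tf.train.BytesList(value=[]))

example3 = tf.train.Example(features=tf.train.Features(feature={
    'A': tf.train.Feature(int64_list=tf.train.Int64List(value=[3])),
    'B': untyped_empty,
    'C': typed_empty,
    'D': tf.train.Feature(float_list=tf.train.FloatList(value=[0.3])),
}))
```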
Code example #24
class ModelValidator(base_component.BaseComponent):
    """DEPRECATED: Please use `Evaluator` instead.

  The model validator component can be used to check model metrics thresholds
  and validate the current model against a previously validated model. If there
  isn't a prior validated model, model validator will just make sure the
  thresholds are passed.  Otherwise, ModelValidator compares a newly trained
  model against a known good model, specifically the last model "blessed" by
  this component.  A model is "blessed" if the exported model's metrics are
  within predefined thresholds around the prior model's metrics.

  *Note:* This component includes a driver to resolve last blessed model.

  ## Possible causes why model validation fails
  Model validation can fail for many reasons, but these are the most common:

  - problems with training data.  For example, negative examples are dropped or
    features are missing.
  - problems with the test or evaluation data.  For example, skew exists between
    the training and evaluation data.
  - changes in data distribution.  This indicates the user behavior may have
    changed over time.
  - problems with the trainer.  For example, the trainer was stopped before the
    model converged, or the model is unstable.

  ## Example
  ```
    # Performs quality validation of a candidate model (compared to a baseline).
    model_validator = ModelValidator(
        examples=example_gen.outputs['examples'],
        model=trainer.outputs['model'])
  ```
  """

    SPEC_CLASS = ModelValidatorSpec
    EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(executor.Executor)
    DRIVER_CLASS = driver.Driver

    @deprecation_utils.deprecated(
        None, 'ModelValidator is deprecated, use Evaluator instead.')
    def __init__(self,
                 examples: types.Channel,
                 model: types.Channel,
                 blessing: Optional[types.Channel] = None,
                 instance_name: Optional[Text] = None):
        """Construct a ModelValidator component.

    Args:
      examples: A Channel of type `standard_artifacts.Examples`, usually
        produced by an
        [ExampleGen](https://www.tensorflow.org/tfx/guide/examplegen) component.
        _required_
      model: A Channel of type `standard_artifacts.Model`, usually produced by
        a [Trainer](https://www.tensorflow.org/tfx/guide/trainer) component.
        _required_
      blessing: Output channel of type `standard_artifacts.ModelBlessing`
        that contains the validation result.
      instance_name: Optional name assigned to this specific instance of
        ModelValidator.  Required only if multiple ModelValidator components are
        declared in the same pipeline.
    """
        blessing = blessing or types.Channel(
            type=standard_artifacts.ModelBlessing)
        spec = ModelValidatorSpec(examples=examples,
                                  model=model,
                                  blessing=blessing)
        super(ModelValidator, self).__init__(spec=spec,
                                             instance_name=instance_name)
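
As a follow-up to the docstring example, here is a hedged sketch of consuming the blessing downstream, mirroring the Pusher usage elsewhere in this document; `trainer` and `serving_model_dir` are assumed context, and new pipelines should prefer Evaluator.

```
# Sketch: gate a Pusher on ModelValidator's blessing (legacy pattern).
pusher = Pusher(
    model=trainer.outputs['model'],
    model_blessing=model_validator.outputs['blessing'],
    push_destination=pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory=serving_model_dir)))  # assumed directory
```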
Code example #25
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     user_schema_path: Text, module_file: Text,
                     serving_model_dir: Text, metadata_path: Text,
                     beam_pipeline_args: List[Text]) -> pipeline.Pipeline:
    """Implements the chicago taxi pipeline with TFX."""
    # Brings data into the pipeline or otherwise joins/converts training data.
    example_gen = CsvExampleGen(input_base=data_root)

    # Computes statistics over data for visualization and example validation.
    statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

    # Import user-provided schema.
    user_schema_importer = ImporterNode(
        source_uri=user_schema_path,
        artifact_type=Schema).with_id('import_user_schema')

    # Generates schema based on statistics files. Even though we use the
    # user-provided schema in downstream components, we still want to generate
    # the schema of the newest data so that the user can compare and optionally
    # update the schema in use.
    schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'],
                           infer_feature_shape=False)

    # Performs anomaly detection based on statistics and data schema.
    example_validator = ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=user_schema_importer.outputs['result'])

    # Performs transformations and feature engineering in training and serving.
    transform = Transform(examples=example_gen.outputs['examples'],
                          schema=user_schema_importer.outputs['result'],
                          module_file=module_file)

    # Uses user-provided Python function that implements a model using TF-Learn.
    trainer = Trainer(
        module_file=module_file,
        custom_executor_spec=executor_spec.ExecutorClassSpec(Executor),
        transformed_examples=transform.outputs['transformed_examples'],
        schema=user_schema_importer.outputs['result'],
        transform_graph=transform.outputs['transform_graph'],
        train_args=trainer_pb2.TrainArgs(num_steps=10000),
        eval_args=trainer_pb2.EvalArgs(num_steps=5000))

    # Get the latest blessed model for model validation.
    model_resolver = ResolverNode(
        resolver_class=latest_blessed_model_resolver.
        LatestBlessedModelResolver,
        model=Channel(type=Model),
        model_blessing=Channel(
            type=ModelBlessing)).with_id('latest_blessed_model_resolver')

    # Uses TFMA to compute evaluation statistics over features of a model and
    # perform quality validation of a candidate model (compared to a baseline).
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(signature_name='eval')],
        slicing_specs=[
            tfma.SlicingSpec(),
            tfma.SlicingSpec(feature_keys=['trip_start_hour'])
        ],
        metrics_specs=[
            tfma.MetricsSpec(
                thresholds={
                    'accuracy':
                    tfma.config.MetricThreshold(
                        value_threshold=tfma.GenericValueThreshold(
                            lower_bound={'value': 0.6}),
                        # Change threshold will be ignored if there is no
                        # baseline model resolved from MLMD (first run).
                        change_threshold=tfma.GenericChangeThreshold(
                            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                            absolute={'value': -1e-10}))
                })
        ])
    evaluator = Evaluator(examples=example_gen.outputs['examples'],
                          model=trainer.outputs['model'],
                          baseline_model=model_resolver.outputs['model'],
                          eval_config=eval_config)

    # Checks whether the model passed the validation steps and pushes the model
    # to a file destination if check passed.
    pusher = Pusher(model=trainer.outputs['model'],
                    model_blessing=evaluator.outputs['blessing'],
                    push_destination=pusher_pb2.PushDestination(
                        filesystem=pusher_pb2.PushDestination.Filesystem(
                            base_directory=serving_model_dir)))

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[
            example_gen, statistics_gen, user_schema_importer, schema_gen,
            example_validator, transform, trainer, model_resolver, evaluator,
            pusher
        ],
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
        beam_pipeline_args=beam_pipeline_args)
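
A pipeline function like `_create_pipeline` above is normally handed to an orchestrator. As a minimal sketch (all paths below are hypothetical placeholders), the local Beam orchestrator can run it directly:

```
# Minimal sketch: run the pipeline above with the local Beam orchestrator.
# Every path here is a hypothetical placeholder.
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner

BeamDagRunner().run(
    _create_pipeline(
        pipeline_name='chicago_taxi_local',
        pipeline_root='/tmp/pipelines/chicago_taxi_local',
        data_root='/tmp/data/taxi',
        user_schema_path='/tmp/schema',
        module_file='/tmp/taxi_utils.py',
        serving_model_dir='/tmp/serving_model',
        metadata_path='/tmp/metadata.db',
        beam_pipeline_args=['--direct_num_workers=1']))
```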
Code example #26
def create_pipeline(
    pipeline_name: Text,
    pipeline_root: Text,
    data_root: Text,
    module_file: Text,
    ai_platform_training_args: Dict[Text, Text],
    ai_platform_serving_args: Dict[Text, Text],
    enable_tuning: bool,
    beam_pipeline_args: List[Text],
) -> pipeline.Pipeline:
  """Implements the penguin pipeline with TFX and Kubeflow Pipeline.

  Args:
    pipeline_name: name of the TFX pipeline being created.
    pipeline_root: root directory of the pipeline. Should be a valid GCS path.
    data_root: uri of the penguin data.
    module_file: uri of the module files used in Trainer and Transform
      components.
    ai_platform_training_args: Args of CAIP training job. Please refer to
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#Job
      for detailed description.
    ai_platform_serving_args: Args of CAIP model deployment. Please refer to
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.models
      for detailed description.
    enable_tuning: If True, hyperparameter tuning through CloudTuner is
      enabled.
    beam_pipeline_args: List of beam pipeline options. Please refer to
      https://cloud.google.com/dataflow/docs/guides/specifying-exec-params#setting-other-cloud-dataflow-pipeline-options.

  Returns:
    A TFX pipeline object.
  """
  # Number of steps for training.
  train_steps = data_types.RuntimeParameter(
      name='train_steps',
      default=100,
      ptype=int,
  )

  # Number of steps for evaluation.
  eval_steps = data_types.RuntimeParameter(
      name='eval_steps',
      default=50,
      ptype=int,
  )

  # Brings data into the pipeline or otherwise joins/converts training data.
  example_gen = CsvExampleGen(input_base=data_root)

  # Computes statistics over data for visualization and example validation.
  statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

  # Generates schema based on statistics files.
  schema_gen = SchemaGen(
      statistics=statistics_gen.outputs['statistics'], infer_feature_shape=True)

  # Performs anomaly detection based on statistics and data schema.
  example_validator = ExampleValidator(
      statistics=statistics_gen.outputs['statistics'],
      schema=schema_gen.outputs['schema'])

  # Performs transformations and feature engineering in training and serving.
  transform = Transform(
      examples=example_gen.outputs['examples'],
      schema=schema_gen.outputs['schema'],
      module_file=module_file)

  # Tunes the hyperparameters for model training based on a user-provided
  # Python function. Note that once the hyperparameters are tuned, you can drop
  # the Tuner component from the pipeline and feed the Trainer with the tuned
  # hyperparameters.
  if enable_tuning:
    # The Tuner component launches one AIP Training job to manage a flock of
    # parallel tuning workers. For example, with 2 workers (defined by
    # num_parallel_trials) in the flock-management AIP Training job, each
    # worker runs its own search loop over trials, as shown below.
    #   Tuner component -> CAIP job X -> CloudTunerA -> tuning trials
    #                                 -> CloudTunerB -> tuning trials
    #
    # Distributed training for each trial depends on the Tuner
    # (kerastuner.BaseTuner) set up in tuner_fn. Currently CloudTuner performs
    # single-worker training per trial. DistributingCloudTuner (a subclass of
    # CloudTuner) launches a remote distributed training job per trial.
    #
    # E.g., single worker training per trial
    #   ... -> CloudTunerA -> single worker training
    #       -> CloudTunerB -> single worker training
    # vs distributed training per trial
    #   ... -> DistributingCloudTunerA -> CAIP job Y -> master,worker1,2,3
    #       -> DistributingCloudTunerB -> CAIP job Z -> master,worker1,2,3
    tuner = Tuner(
        module_file=module_file,
        examples=transform.outputs['transformed_examples'],
        transform_graph=transform.outputs['transform_graph'],
        train_args={'num_steps': train_steps},
        eval_args={'num_steps': eval_steps},
        tune_args=tuner_pb2.TuneArgs(
            # num_parallel_trials=3 means that 3 search loops are
            # running in parallel.
            num_parallel_trials=3),
        custom_config={
            # Note that this TUNING_ARGS_KEY will be used to start the CAIP job
            # for parallel tuning (CAIP job X above).
            #
            # num_parallel_trials will be used to fill/overwrite the
            # workerCount specified by TUNING_ARGS_KEY:
            #   num_parallel_trials = workerCount + 1 (for master)
            ai_platform_tuner_executor.TUNING_ARGS_KEY:
                ai_platform_training_args,
            # This working directory has to be a valid GCS path and will be used
            # to launch remote training job per trial.
            ai_platform_tuner_executor.REMOTE_TRIALS_WORKING_DIR_KEY:
                os.path.join(pipeline_root, 'trials'),
        })

  # Uses user-provided Python function that trains a model.
  trainer = Trainer(
      custom_executor_spec=executor_spec.ExecutorClassSpec(
          ai_platform_trainer_executor.GenericExecutor),
      module_file=module_file,
      examples=transform.outputs['transformed_examples'],
      transform_graph=transform.outputs['transform_graph'],
      schema=schema_gen.outputs['schema'],
      # If Tuner is in the pipeline, Trainer can take Tuner's output
      # best_hyperparameters artifact as input and utilize it in the user module
      # code.
      #
      # If there isn't Tuner in the pipeline, either use ImporterNode to import
      # a previous Tuner's output to feed to Trainer, or directly use the tuned
      # hyperparameters in user module code and set hyperparameters to None
      # here.
      #
      # Example of ImporterNode,
      #   hparams_importer = ImporterNode(
      #     source_uri='path/to/best_hyperparameters.txt',
      #     artifact_type=HyperParameters).with_id('import_hparams')
      #   ...
      #   hyperparameters = hparams_importer.outputs['result'],
      hyperparameters=(tuner.outputs['best_hyperparameters']
                       if enable_tuning else None),
      train_args={'num_steps': train_steps},
      eval_args={'num_steps': eval_steps},
      custom_config={
          ai_platform_trainer_executor.TRAINING_ARGS_KEY:
              ai_platform_training_args
      })

  # Get the latest blessed model for model validation.
  model_resolver = ResolverNode(
      resolver_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
      model=Channel(type=Model),
      model_blessing=Channel(
          type=ModelBlessing)).with_id('latest_blessed_model_resolver')

  # Uses TFMA to compute evaluation statistics over features of a model and
  # perform quality validation of a candidate model (compared to a baseline).
  eval_config = tfma.EvalConfig(
      model_specs=[tfma.ModelSpec(label_key='species')],
      slicing_specs=[tfma.SlicingSpec()],
      metrics_specs=[
          tfma.MetricsSpec(metrics=[
              tfma.MetricConfig(
                  class_name='SparseCategoricalAccuracy',
                  threshold=tfma.MetricThreshold(
                      value_threshold=tfma.GenericValueThreshold(
                          lower_bound={'value': 0.6}),
                      # Change threshold will be ignored if there is no
                      # baseline model resolved from MLMD (first run).
                      change_threshold=tfma.GenericChangeThreshold(
                          direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                          absolute={'value': -1e-10})))
          ])
      ])

  evaluator = Evaluator(
      examples=example_gen.outputs['examples'],
      model=trainer.outputs['model'],
      baseline_model=model_resolver.outputs['model'],
      eval_config=eval_config)

  pusher = Pusher(
      custom_executor_spec=executor_spec.ExecutorClassSpec(
          ai_platform_pusher_executor.Executor),
      model=trainer.outputs['model'],
      model_blessing=evaluator.outputs['blessing'],
      custom_config={
          ai_platform_pusher_executor.SERVING_ARGS_KEY: ai_platform_serving_args
      },
  )

  components = [
      example_gen,
      statistics_gen,
      schema_gen,
      example_validator,
      transform,
      trainer,
      model_resolver,
      evaluator,
      pusher,
  ]
  if enable_tuning:
    components.append(tuner)

  return pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=components,
      enable_cache=True,
      beam_pipeline_args=beam_pipeline_args)
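
Because this pipeline targets Kubeflow Pipelines, it is compiled by a runner rather than executed in-process. A hedged sketch of that step, with hypothetical argument values and the default runner configuration:

```
# Minimal sketch: compile the pipeline above for Kubeflow Pipelines.
# All argument values are hypothetical placeholders; a custom
# KubeflowDagRunnerConfig can be passed via the `config` argument.
from tfx.orchestration.kubeflow import kubeflow_dag_runner

kubeflow_dag_runner.KubeflowDagRunner().run(
    create_pipeline(
        pipeline_name='penguin_kubeflow',
        pipeline_root='gs://my-bucket/pipelines/penguin',
        data_root='gs://my-bucket/data/penguin',
        module_file='gs://my-bucket/penguin_utils.py',
        ai_platform_training_args={'project': 'my-gcp-project'},
        ai_platform_serving_args={'model_name': 'penguin'},
        enable_tuning=False,
        beam_pipeline_args=['--project=my-gcp-project']))
```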
Code example #27
File: component.py  Project: jay90099/tfx
class InfraValidator(base_component.BaseComponent):
    """A TFX component to validate the model against the serving infrastructure.

  Infra validation is done by loading the model into exactly the same serving
  binary that is used in production, and additionally sending some requests to
  the model server. Such requests can be built from an Examples artifact.

  ## Examples

  Full example using TensorFlowServing binary running on local docker.

  ```
  infra_validator = InfraValidator(
      model=trainer.outputs['model'],
      examples=test_example_gen.outputs['examples'],
      serving_spec=ServingSpec(
          tensorflow_serving=TensorFlowServing(  # Using TF Serving.
              tags=['latest']
          ),
          local_docker=LocalDockerConfig(),  # Running on local docker.
      ),
      validation_spec=ValidationSpec(
          max_loading_time_seconds=60,
          num_tries=5,
      ),
      request_spec=RequestSpec(
          tensorflow_serving=TensorFlowServingRequestSpec(),
          num_examples=1,
      )
  )
  ```

  Minimal example when running on Kubernetes.

  ```
  infra_validator = InfraValidator(
      model=trainer.outputs['model'],
      examples=test_example_gen.outputs['examples'],
      serving_spec=ServingSpec(
          tensorflow_serving=TensorFlowServing(
              tags=['latest']
          ),
          kubernetes=KubernetesConfig(),  # Running on Kubernetes.
      ),
  )
  ```

  Component `outputs` contains:
   - `blessing`: Channel of type `standard_artifacts.InfraBlessing` that
                 contains the validation result.

  See [the InfraValidator
  guide](https://www.tensorflow.org/tfx/guide/infra_validator) for more
  details.
  """

    SPEC_CLASS = standard_component_specs.InfraValidatorSpec
    EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(executor.Executor)
    DRIVER_CLASS = base_driver.BaseDriver

    def __init__(
            self,
            model: types.BaseChannel,
            serving_spec: infra_validator_pb2.ServingSpec,
            examples: Optional[types.BaseChannel] = None,
            request_spec: Optional[infra_validator_pb2.RequestSpec] = None,
            validation_spec: Optional[
                infra_validator_pb2.ValidationSpec] = None):
        """Construct a InfraValidator component.

    Args:
      model: A `BaseChannel` of `ModelExportPath` type, usually produced by
        [Trainer](https://www.tensorflow.org/tfx/guide/trainer) component.
          _required_
      serving_spec: A `ServingSpec` configuration about serving binary and test
        platform config to launch model server for validation. _required_
      examples: A `BaseChannel` of `ExamplesPath` type, usually produced by
        [ExampleGen](https://www.tensorflow.org/tfx/guide/examplegen) component.
          If not specified, InfraValidator does not issue requests for
          validation.
      request_spec: Optional `RequestSpec` configuration about making requests
        from `examples` input. If not specified, InfraValidator does not issue
        requests for validation.
      validation_spec: Optional `ValidationSpec` configuration.
    """
        blessing = types.Channel(type=standard_artifacts.InfraBlessing)
        spec = standard_component_specs.InfraValidatorSpec(
            model=model,
            examples=examples,
            blessing=blessing,
            serving_spec=serving_spec,
            validation_spec=validation_spec,
            request_spec=request_spec)
        super().__init__(spec=spec)
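
The `blessing` output is usually consumed downstream. As a hedged sketch (assuming a TFX version in which Pusher accepts an `infra_blessing` argument, and the `trainer`, `evaluator`, and `serving_model_dir` names from the pipeline examples above), only models that pass both quality and infra validation get pushed:

```
# Sketch: gate the Pusher on both model quality and infra validation.
pusher = Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],        # quality blessing
    infra_blessing=infra_validator.outputs['blessing'],  # serving blessing
    push_destination=pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory=serving_model_dir)))
```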
Code example #28
class Evaluator(base_component.BaseComponent):
    """A TFX component to evaluate models trained by a TFX Trainer component.

  See [Evaluator](https://www.tensorflow.org/tfx/guide/evaluator) for more
  information on what this component's required inputs are, how to configure it,
  and what outputs it produces.
  """

    SPEC_CLASS = EvaluatorSpec
    EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(executor.Executor)

    def __init__(
            self,
            examples: types.Channel = None,
            model: types.Channel = None,
            baseline_model: Optional[types.Channel] = None,
            # TODO(b/148618405): deprecate feature_slicing_spec.
            feature_slicing_spec: Optional[Union[
                evaluator_pb2.FeatureSlicingSpec, Dict[Text, Any]]] = None,
            fairness_indicator_thresholds: Optional[List[Union[
                float, data_types.RuntimeParameter]]] = None,
            example_splits: Optional[List[Text]] = None,
            output: Optional[types.Channel] = None,
            model_exports: Optional[types.Channel] = None,
            instance_name: Optional[Text] = None,
            eval_config: Optional[tfma.EvalConfig] = None,
            blessing: Optional[types.Channel] = None,
            schema: Optional[types.Channel] = None,
            module_file: Optional[Text] = None):
        """Construct an Evaluator component.

    Args:
      examples: A Channel of type `standard_artifacts.Examples`, usually
        produced by an ExampleGen component. _required_
      model: A Channel of type `standard_artifacts.Model`, usually produced by
        a Trainer component.
      baseline_model: An optional channel of type 'standard_artifacts.Model' as
        the baseline model for model diff and model validation purpose.
      feature_slicing_spec:
        Deprecated, please use eval_config instead. Only supports estimators.
        [evaluator_pb2.FeatureSlicingSpec](https://github.com/tensorflow/tfx/blob/master/tfx/proto/evaluator.proto)
          instance that describes how Evaluator should slice the data. If any
          field is provided as a RuntimeParameter, feature_slicing_spec should
          be constructed as a dict with the same field names as
          FeatureSlicingSpec proto message.
      fairness_indicator_thresholds: Optional list of float (or
        RuntimeParameter) threshold values for use with TFMA fairness
          indicators. Experimental functionality: this interface and
          functionality may change at any time. TODO(b/142653905): add a link
          to additional documentation for TFMA fairness indicators here.
      example_splits: Names of splits on which the metrics are computed.
        Default behavior (when example_splits is set to None or empty) is to
        use the 'eval' split.
      output: Channel of `ModelEvaluation` to store the evaluation results.
      model_exports: Backwards compatibility alias for the `model` argument.
      instance_name: Optional name assigned to this specific instance of
        Evaluator. Required only if multiple Evaluator components are declared
        in the same pipeline.  Either `model_exports` or `model` must be present
        in the input arguments.
      eval_config: Instance of tfma.EvalConfig containing configuration settings
        for running the evaluation. This config has options for both estimator
        and Keras.
      blessing: Output channel of 'ModelBlessing' that contains the
        blessing result.
      schema: A `Schema` channel to use for TFXIO.
      module_file: A path to python module file containing UDFs for Evaluator
        customization. The module_file can implement the following functions
        at its top level.
          def custom_eval_shared_model(
             eval_saved_model_path, model_name, eval_config, **kwargs,
          ) -> tfma.EvalSharedModel:
          def custom_extractors(
            eval_shared_model, eval_config, tensor_adapter_config,
          ) -> List[tfma.extractors.Extractor]:
    """
        if eval_config is not None and feature_slicing_spec is not None:
            raise ValueError(
                "Exactly one of 'eval_config' or 'feature_slicing_spec' "
                "must be supplied.")
        if eval_config is None and feature_slicing_spec is None:
            feature_slicing_spec = evaluator_pb2.FeatureSlicingSpec()
            logging.info(
                'Neither eval_config nor feature_slicing_spec is passed, '
                'the model is treated as estimator.')

        if model_exports:
            logging.warning(
                'The "model_exports" argument to the Evaluator component has '
                'been renamed to "model" and is deprecated. Please update your '
                'usage as support for this argument will be removed soon.')
            model = model_exports

        if feature_slicing_spec:
            logging.warning('feature_slicing_spec is deprecated, please use '
                            'eval_config instead.')

        blessing = blessing or types.Channel(
            type=standard_artifacts.ModelBlessing)
        evaluation = output or types.Channel(
            type=standard_artifacts.ModelEvaluation)
        spec = EvaluatorSpec(
            examples=examples,
            model=model,
            baseline_model=baseline_model,
            feature_slicing_spec=feature_slicing_spec,
            fairness_indicator_thresholds=fairness_indicator_thresholds,
            example_splits=json_utils.dumps(example_splits),
            evaluation=evaluation,
            eval_config=eval_config,
            blessing=blessing,
            schema=schema,
            module_file=module_file)
        super(Evaluator, self).__init__(spec=spec, instance_name=instance_name)
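
For the `module_file` customization hooks named in the docstring, a minimal sketch of such a module might look as follows; it assumes no custom extractors are needed, so only `custom_eval_shared_model` is defined, and it simply delegates to the TFMA default loader:

```
# Sketch of an Evaluator customization module (module_file contents).
import tensorflow_model_analysis as tfma


def custom_eval_shared_model(eval_saved_model_path, model_name, eval_config,
                             **kwargs) -> tfma.EvalSharedModel:
  # Delegate to the TFMA default; a real module could wrap the model,
  # change loading tags, or attach extra assets here.
  return tfma.default_eval_shared_model(
      eval_saved_model_path=eval_saved_model_path,
      model_name=model_name,
      eval_config=eval_config)
```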
Code example #29
class Trainer(base_component.BaseComponent):
  """A TFX component to train a TensorFlow model.

  The Trainer component is used to train and evaluate a model using given
  inputs and a user-supplied estimator.

  ## Providing an estimator
  The TFX executor will use the estimator provided in the `module_file` file
  to train the model.  The Trainer executor will look specifically for the
  `trainer_fn()` function within that file.  Before training, the executor will
  call that function expecting the following returned as a dictionary:

    - estimator: The
    [estimator](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator)
    to be used by TensorFlow to train the model.
    - train_spec: The
    [configuration](https://www.tensorflow.org/api_docs/python/tf/estimator/TrainSpec)
    to be used by the "train" part of the TensorFlow `train_and_evaluate()`
    call.
    - eval_spec: The
    [configuration](https://www.tensorflow.org/api_docs/python/tf/estimator/EvalSpec)
    to be used by the "eval" part of the TensorFlow `train_and_evaluate()` call.
    - eval_input_receiver_fn: The
    [configuration](https://www.tensorflow.org/tfx/model_analysis/get_started#modify_an_existing_model)
    to be used
    by the [ModelValidator](https://www.tensorflow.org/tfx/guide/modelval)
    component when validating the model.

  An example of `trainer_fn()` can be found in the [user-supplied
  code](https://github.com/tensorflow/tfx/blob/master/tfx/examples/chicago_taxi_pipeline/taxi_utils.py)
  of the TFX Chicago Taxi pipeline example.

  *Note:* The default executor for this component trains locally.  This can be
  overridden to enable the model to be trained on other platforms.  The [Cloud
  AI Platform custom
  executor](https://github.com/tensorflow/tfx/tree/master/tfx/extensions/google_cloud_ai_platform/trainer)
  provides an example of how to implement this.

  Please see https://www.tensorflow.org/guide/estimators for more details.

  ## Example 1: Training locally
  ```
  # Uses user-provided Python function that implements a model using TF-Learn.
  trainer = Trainer(
      module_file=module_file,
      transformed_examples=transform.outputs['transformed_examples'],
      schema=infer_schema.outputs['schema'],
      transform_graph=transform.outputs['transform_graph'],
      train_args=trainer_pb2.TrainArgs(splits=['train'], num_steps=10000),
      eval_args=trainer_pb2.EvalArgs(splits=['eval'], num_steps=5000))
  ```

  ## Example 2: Training through a cloud provider
  ```
  from tfx.extensions.google_cloud_ai_platform.trainer import (
      executor as ai_platform_trainer_executor)
  # Train using Google Cloud AI Platform.
  trainer = Trainer(
      custom_executor_spec=executor_spec.ExecutorClassSpec(
          ai_platform_trainer_executor.Executor),
      module_file=module_file,
      transformed_examples=transform.outputs['transformed_examples'],
      schema=infer_schema.outputs['schema'],
      transform_graph=transform.outputs['transform_graph'],
      train_args=trainer_pb2.TrainArgs(splits=['train'], num_steps=10000),
      eval_args=trainer_pb2.EvalArgs(splits=['eval'], num_steps=5000))
  ```
  """

  SPEC_CLASS = TrainerSpec
  EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(executor.Executor)

  def __init__(
      self,
      examples: types.Channel = None,
      transformed_examples: Optional[types.Channel] = None,
      transform_graph: Optional[types.Channel] = None,
      schema: types.Channel = None,
      base_model: Optional[types.Channel] = None,
      hyperparameters: Optional[types.Channel] = None,
      module_file: Optional[Union[Text, data_types.RuntimeParameter]] = None,
      run_fn: Optional[Union[Text, data_types.RuntimeParameter]] = None,
      # TODO(b/147702778): deprecate trainer_fn.
      trainer_fn: Optional[Union[Text, data_types.RuntimeParameter]] = None,
      train_args: Union[trainer_pb2.TrainArgs, Dict[Text, Any]] = None,
      eval_args: Union[trainer_pb2.EvalArgs, Dict[Text, Any]] = None,
      custom_config: Optional[Dict[Text, Any]] = None,
      custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
      output: Optional[types.Channel] = None,
      model_run: Optional[types.Channel] = None,
      transform_output: Optional[types.Channel] = None,
      instance_name: Optional[Text] = None):
    """Construct a Trainer component.

    Args:
      examples: A Channel of type `standard_artifacts.Examples`, serving as
        the source of examples used in training (required). May be raw or
        transformed.
      transformed_examples: Deprecated field. Please set 'examples' instead.
      transform_graph: An optional Channel of type
        `standard_artifacts.TransformGraph`, serving as the input transform
        graph if present.
      schema:  A Channel of type `standard_artifacts.Schema`, serving as the
        schema of training and eval data.
      base_model: A Channel of type `Model`, containing model that will be used
        for training. This can be used for warmstart, transfer learning or
        model ensembling.
      hyperparameters: A Channel of type `standard_artifacts.HyperParameters`,
        serving as the hyperparameters for the training module. The best
        hyperparameters from a Tuner's output can be fed into this.
      module_file: A path to python module file containing UDF model definition.

        For the default executor, the module_file must implement a function
        named `trainer_fn` at its top level. The function must have the
        following signature.

        def trainer_fn(trainer.executor.TrainerFnArgs,
                       tensorflow_metadata.proto.v0.schema_pb2) -> Dict:
          ...

        where the returned Dict has the following key-values.
          'estimator': an instance of tf.estimator.Estimator
          'train_spec': an instance of tf.estimator.TrainSpec
          'eval_spec': an instance of tf.estimator.EvalSpec
          'eval_input_receiver_fn': an instance of
            tfma.export.EvalInputReceiver.
        Exactly one of 'module_file' or 'trainer_fn' must be supplied.

        For the generic executor, the module_file must implement a function
        named `run_fn` at its top level with function signature:
        `def run_fn(trainer.executor.TrainerFnArgs)`, and the trained model
        must be saved to TrainerFnArgs.serving_model_dir when this function is
        executed.
      run_fn:  A python path to UDF model definition function for generic
        trainer. See 'module_file' for details. Exactly one of 'module_file' or
        'run_fn' must be supplied if Trainer uses GenericExecutor.
      trainer_fn:  A python path to UDF model definition function for estimator
        based trainer. See 'module_file' for the required signature of the UDF.
        Exactly one of 'module_file' or 'trainer_fn' must be supplied.
      train_args: A trainer_pb2.TrainArgs instance or a dict, containing args
        used for training. Currently only splits and num_steps are available.
        If it's provided as a dict and any field is a RuntimeParameter, it
        should have the same field names as a TrainArgs proto message. Default
        behavior (when splits is empty) is to train on the `train` split.
      eval_args: A trainer_pb2.EvalArgs instance or a dict, containing args
        used for evaluation. Currently only splits and num_steps are available.
        If it's provided as a dict and any field is a RuntimeParameter, it
        should have the same field names as an EvalArgs proto message. Default
        behavior (when splits is empty) is to evaluate on the `eval` split.
      custom_config: A dict which contains additional training job parameters
        that will be passed into the user module.
      custom_executor_spec: Optional custom executor spec.
      output: Optional `Model` channel for result of exported models.
      model_run: Optional `ModelRun` channel, as the working dir of models,
        can be used to output non-model related output (e.g., TensorBoard logs).
      transform_output: Backwards compatibility alias for the 'transform_graph'
        argument.
      instance_name: Optional unique instance name. Necessary iff multiple
        Trainer components are declared in the same pipeline.

    Raises:
      ValueError:
        - When both or neither of 'module_file' and user function
          (e.g., trainer_fn and run_fn) is supplied.
        - When both or neither of 'examples' and 'transformed_examples'
            is supplied.
        - When 'transformed_examples' is supplied but 'transform_graph'
            is not supplied.
    """
    if [bool(module_file), bool(run_fn), bool(trainer_fn)].count(True) != 1:
      raise ValueError(
          "Exactly one of 'module_file', 'trainer_fn', or 'run_fn' must be "
          "supplied.")

    if bool(examples) == bool(transformed_examples):
      raise ValueError(
          "Exactly one of 'example' or 'transformed_example' must be supplied.")

    if transform_output:
      absl.logging.warning(
          'The "transform_output" argument to the Trainer component has '
          'been renamed to "transform_graph" and is deprecated. Please update '
          'your usage as support for this argument will be removed soon.')
      transform_graph = transform_output
    if transformed_examples and not transform_graph:
      raise ValueError("If 'transformed_examples' is supplied, "
                       "'transform_graph' must be supplied too.")
    examples = examples or transformed_examples
    output = output or types.Channel(type=standard_artifacts.Model)
    model_run = model_run or types.Channel(type=standard_artifacts.ModelRun)
    spec = TrainerSpec(
        examples=examples,
        transform_graph=transform_graph,
        schema=schema,
        base_model=base_model,
        hyperparameters=hyperparameters,
        train_args=train_args,
        eval_args=eval_args,
        module_file=module_file,
        run_fn=run_fn,
        trainer_fn=trainer_fn,
        custom_config=json_utils.dumps(custom_config),
        model=output,
        # TODO(b/158106209): change the model_run as optional output artifact
        model_run=model_run)
    super(Trainer, self).__init__(
        spec=spec,
        custom_executor_spec=custom_executor_spec,
        instance_name=instance_name)
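
For the generic executor path, here is a hedged sketch of the `run_fn` contract described in the docstring. The model builder `_build_keras_model` and the dataset helper `_input_fn` are hypothetical user functions, not part of the TFX API:

```
# Sketch of a module_file for the generic Trainer executor.
from tfx.components.trainer.executor import TrainerFnArgs


def run_fn(fn_args: TrainerFnArgs):
  # _input_fn and _build_keras_model are hypothetical user helpers.
  train_dataset = _input_fn(fn_args.train_files, batch_size=64)
  eval_dataset = _input_fn(fn_args.eval_files, batch_size=64)
  model = _build_keras_model()
  model.fit(
      train_dataset,
      steps_per_epoch=fn_args.train_steps,
      validation_data=eval_dataset,
      validation_steps=fn_args.eval_steps)
  # The contract: the trained model must be saved to serving_model_dir.
  model.save(fn_args.serving_model_dir, save_format='tf')
```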
Code example #30
File: component.py  Project: konny0311/tfx
class Tuner(base_component.BaseComponent):
    """A TFX component for model hyperparameter tuning."""

    SPEC_CLASS = TunerSpec
    EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(executor.Executor)

    def __init__(self,
                 examples: types.Channel = None,
                 schema: Optional[types.Channel] = None,
                 transform_graph: Optional[types.Channel] = None,
                 module_file: Optional[Text] = None,
                 tuner_fn: Optional[Text] = None,
                 train_args: trainer_pb2.TrainArgs = None,
                 eval_args: trainer_pb2.EvalArgs = None,
                 tune_args: Optional[tuner_pb2.TuneArgs] = None,
                 custom_config: Optional[Dict[Text, Any]] = None,
                 best_hyperparameters: Optional[types.Channel] = None,
                 instance_name: Optional[Text] = None):
        """Construct a Tuner component.

    Args:
      examples: A Channel of type `standard_artifacts.Examples`, serving as the
        source of examples that are used in tuning (required).
      schema:  An optional Channel of type `standard_artifacts.Schema`, serving
        as the schema of training and eval data. This is used when raw examples
        are provided.
      transform_graph: An optional Channel of type
        `standard_artifacts.TransformGraph`, serving as the input transform
        graph if present. This is used when transformed examples are provided.
      module_file: A path to python module file containing UDF tuner definition.
        The module_file must implement a function named `tuner_fn` at its top
        level. The function must have the following signature.
            def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
        Exactly one of 'module_file' or 'tuner_fn' must be supplied.
      tuner_fn:  A python path to UDF model definition function. See
        'module_file' for the required signature of the UDF. Exactly one of
        'module_file' or 'tuner_fn' must be supplied.
      train_args: A trainer_pb2.TrainArgs instance, containing args used for
        training. Currently only splits and num_steps are available. Default
        behavior (when splits is empty) is train on `train` split.
      eval_args: A trainer_pb2.EvalArgs instance, containing args used for eval.
        Currently only splits and num_steps are available. Default behavior
        (when splits is empty) is evaluate on `eval` split.
      tune_args: A tuner_pb2.TuneArgs instance, containing args used for tuning.
        Currently only num_parallel_trials is available.
      custom_config: A dict which contains additional training job parameters
        that will be passed into the user module.
      best_hyperparameters: Optional Channel of type
        `standard_artifacts.HyperParameters` for result of the best hparams.
      instance_name: Optional unique instance name. Necessary if multiple Tuner
        components are declared in the same pipeline.
    """
        if bool(module_file) == bool(tuner_fn):
            raise ValueError(
                "Exactly one of 'module_file' or 'tuner_fn' must be supplied")

        best_hyperparameters = best_hyperparameters or types.Channel(
            type=standard_artifacts.HyperParameters)
        spec = TunerSpec(
            examples=examples,
            schema=schema,
            transform_graph=transform_graph,
            module_file=module_file,
            tuner_fn=tuner_fn,
            train_args=train_args,
            eval_args=eval_args,
            tune_args=tune_args,
            best_hyperparameters=best_hyperparameters,
            custom_config=json_utils.dumps(custom_config),
        )
        super(Tuner, self).__init__(spec=spec, instance_name=instance_name)
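
A hedged sketch of the `tuner_fn` contract follows. The search space inside `_build_keras_model` and the dataset helper `_input_fn` are hypothetical user functions; only the TunerFnResult return shape is fixed by the component:

```
# Sketch of a module_file for the Tuner component.
import kerastuner
from tfx.components.tuner.component import TunerFnResult


def tuner_fn(fn_args) -> TunerFnResult:
  # _build_keras_model and _input_fn are hypothetical user helpers.
  tuner = kerastuner.RandomSearch(
      _build_keras_model,
      max_trials=10,
      objective='val_sparse_categorical_accuracy',
      directory=fn_args.working_dir,
      project_name='tuning')
  return TunerFnResult(
      tuner=tuner,
      fit_kwargs={
          'x': _input_fn(fn_args.train_files),
          'validation_data': _input_fn(fn_args.eval_files),
          'steps_per_epoch': fn_args.train_steps,
          'validation_steps': fn_args.eval_steps,
      })
```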