Example #1
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text, serving_model_dir: Text,
                     metadata_path: Text) -> pipeline.Pipeline:
    """Implements the cifar10 pipeline with TFX."""
    examples = external_input(data_root)
    input_split = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(name='train', pattern='train.tfrecord'),
        example_gen_pb2.Input.Split(name='eval', pattern='test.tfrecord')
    ])
    example_gen = ImportExampleGen(input=examples, input_config=input_split)
    # Computes statistics over data for visualization and example validation.
    statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

    # Generates schema based on statistics files.
    infer_schema = SchemaGen(statistics=statistics_gen.outputs['statistics'],
                             infer_feature_shape=True)

    # Performs anomaly detection based on statistics and data schema.
    validate_stats = ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=infer_schema.outputs['schema'])

    # Performs transformations and feature engineering in training and serving.
    transform = Transform(examples=example_gen.outputs['examples'],
                          schema=infer_schema.outputs['schema'],
                          module_file=module_file)

    # Uses a user-provided Python function that implements a model using TF-Learn.
    trainer = Trainer(module_file=module_file,
                      examples=transform.outputs['transformed_examples'],
                      schema=infer_schema.outputs['schema'],
                      transform_graph=transform.outputs['transform_graph'],
                      train_args=trainer_pb2.TrainArgs(num_steps=1000),
                      eval_args=trainer_pb2.EvalArgs(num_steps=500))

    # Uses TFMA to compute evaluation statistics over features of a model.
    evaluator = Evaluator(
        examples=example_gen.outputs['examples'],
        model=trainer.outputs['model'],
        feature_slicing_spec=evaluator_pb2.FeatureSlicingSpec(
            specs=[evaluator_pb2.SingleSlicingSpec()]))

    # Performs quality validation of a candidate model (compared to a baseline).
    model_validator = ModelValidator(examples=example_gen.outputs['examples'],
                                     model=trainer.outputs['model'])

    # Checks whether the model passed the validation steps and pushes the model
    # to a file destination if the check passed.
    pusher = Pusher(model=trainer.outputs['model'],
                    model_blessing=model_validator.outputs['blessing'],
                    push_destination=pusher_pb2.PushDestination(
                        filesystem=pusher_pb2.PushDestination.Filesystem(
                            base_directory=serving_model_dir)))

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[
            example_gen, statistics_gen, infer_schema, validate_stats,
            transform, trainer, evaluator, model_validator, pusher
        ],
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
    )
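The function above only builds the pipeline object; a runner still has to execute it. Below is a minimal sketch (not part of the original example) of driving it with TFX's BeamDagRunner; all path constants are assumptions.

# Sketch: running the pipeline locally. Every literal below is hypothetical.
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner

BeamDagRunner().run(
    _create_pipeline(
        pipeline_name='cifar10',
        pipeline_root='/tmp/tfx/pipelines/cifar10',      # hypothetical path
        data_root='/tmp/tfx/data/cifar10',               # hypothetical path
        module_file='/tmp/tfx/cifar10_utils.py',         # hypothetical path
        serving_model_dir='/tmp/tfx/serving_model',      # hypothetical path
        metadata_path='/tmp/tfx/metadata/metadata.db'))  # hypothetical path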
Example #2
def create_pipeline(pipeline_name: Text,
                    pipeline_root: Text,
                    data_root_uri: data_types.RuntimeParameter,
                    train_steps: data_types.RuntimeParameter,
                    eval_steps: data_types.RuntimeParameter,
                    enable_tuning: bool,
                    ai_platform_training_args: Dict[Text, Text],
                    ai_platform_serving_args: Dict[Text, Text],
                    beam_pipeline_args: List[Text],
                    enable_cache: Optional[bool] = False) -> pipeline.Pipeline:
    """Trains and deploys the Keras Covertype Classifier with TFX and Kubeflow Pipeline on Google Cloud.
  Args:
    pipeline_name: name of the TFX pipeline being created.
    pipeline_root: root directory of the pipeline. Should be a valid GCS path.
    data_root_uri: uri of the dataset.
    train_steps: runtime parameter for number of model training steps for the Trainer component.
    eval_steps: runtime parameter for number of model evaluation steps for the Trainer component.
    enable_tuning: If True, hyperparameter tuning through CloudTuner is
      enabled.
    ai_platform_training_args: Args of CAIP training job. Please refer to
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#Job
      for detailed description.
    ai_platform_serving_args: Args of CAIP model deployment. Please refer to
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.models
      for detailed description.
    beam_pipeline_args: Optional list of beam pipeline options. Please refer to
      https://cloud.google.com/dataflow/docs/guides/specifying-exec-params#setting-other-cloud-dataflow-pipeline-options.
      When this argument is not provided, the default is to use the GCP
      DataflowRunner with a 50GB disk size, as specified in this function. If
      an empty list is passed in, the defaults specified by Beam will be used.
    enable_cache: Optional boolean, whether to enable caching of pipeline executions.
  Returns:
    A TFX pipeline object.
  """

    # Brings data into the pipeline and splits it into training and eval splits.
    output_config = example_gen_pb2.Output(
        split_config=example_gen_pb2.SplitConfig(splits=[
            example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=4),
            example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1)
        ]))

    examplegen = CsvExampleGen(input_base=data_root_uri,
                               output_config=output_config)

    # Computes statistics over data for visualization and example validation.
    statisticsgen = StatisticsGen(examples=examplegen.outputs.examples)

    # Import a user-provided schema
    import_schema = ImporterNode(instance_name='import_user_schema',
                                 source_uri=SCHEMA_FOLDER,
                                 artifact_type=Schema)

    # Generates schema based on statistics files. Even though we use a
    # user-provided schema, we still want to generate the schema of the newest
    # data for tracking and comparison.
    schemagen = SchemaGen(statistics=statisticsgen.outputs.statistics)

    # Performs anomaly detection based on statistics and data schema.
    examplevalidator = ExampleValidator(
        statistics=statisticsgen.outputs.statistics,
        schema=import_schema.outputs.result)

    # Performs transformations and feature engineering in training and serving.
    transform = Transform(examples=examplegen.outputs.examples,
                          schema=import_schema.outputs.result,
                          module_file=TRANSFORM_MODULE_FILE)

    # Tunes the hyperparameters for model training based on a user-provided
    # Python function. Note that once the hyperparameters are tuned, you can
    # drop the Tuner component from the pipeline and feed Trainer the tuned
    # hyperparameters.
    if enable_tuning:
        # The Tuner component launches 1 AI Platform Training job for flock
        # management. For example, 3 workers (defined by num_parallel_trials)
        # in the flock management job each run Tuner.Executor.
        tuner = Tuner(
            module_file=TRAIN_MODULE_FILE,
            examples=transform.outputs.transformed_examples,
            transform_graph=transform.outputs.transform_graph,
            train_args={'num_steps': train_steps},
            eval_args={'num_steps': eval_steps},
            tune_args=tuner_pb2.TuneArgs(
                # num_parallel_trials=3 means that 3 search loops are running in parallel.
                num_parallel_trials=3),
            custom_config={
                # Configures Cloud AI Platform-specific configs. For details, see
                # https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#traininginput.
                ai_platform_trainer_executor.TRAINING_ARGS_KEY:
                ai_platform_training_args
            })

    # Trains the model using a user-provided trainer function.
    trainer = Trainer(
        custom_executor_spec=executor_spec.ExecutorClassSpec(
            ai_platform_trainer_executor.GenericExecutor),
        module_file=TRAIN_MODULE_FILE,
        transformed_examples=transform.outputs.transformed_examples,
        schema=import_schema.outputs.result,
        transform_graph=transform.outputs.transform_graph,
        hyperparameters=(tuner.outputs.best_hyperparameters
                         if enable_tuning else None),
        train_args={'num_steps': train_steps},
        eval_args={'num_steps': eval_steps},
        custom_config={'ai_platform_training_args': ai_platform_training_args})

    # Get the latest blessed model for model validation.
    resolver = ResolverNode(instance_name='latest_blessed_model_resolver',
                            resolver_class=latest_blessed_model_resolver.
                            LatestBlessedModelResolver,
                            model=Channel(type=Model),
                            model_blessing=Channel(type=ModelBlessing))

    # Uses TFMA to compute evaluation statistics over features of a model.
    accuracy_threshold = tfma.MetricThreshold(
        value_threshold=tfma.GenericValueThreshold(
            lower_bound={'value': 0.5},
            upper_bound={'value': 0.99}))

    metrics_specs = tfma.MetricsSpec(metrics=[
        tfma.MetricConfig(class_name='SparseCategoricalAccuracy',
                          threshold=accuracy_threshold),
        tfma.MetricConfig(class_name='ExampleCount')
    ])

    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(label_key='Cover_Type')],
        metrics_specs=[metrics_specs],
        slicing_specs=[
            tfma.SlicingSpec(),
            tfma.SlicingSpec(feature_keys=['Wilderness_Area'])
        ])

    evaluator = Evaluator(examples=examplegen.outputs.examples,
                          model=trainer.outputs.model,
                          baseline_model=resolver.outputs.model,
                          eval_config=eval_config)

    # Validates that the model can be loaded and queried in a sand-boxed
    # environment mirroring production.
    serving_config = infra_validator_pb2.ServingSpec(
        tensorflow_serving=infra_validator_pb2.TensorFlowServing(
            tags=['latest']),
        kubernetes=infra_validator_pb2.KubernetesConfig(),
    )

    validation_config = infra_validator_pb2.ValidationSpec(
        max_loading_time_seconds=60,
        num_tries=3,
    )

    request_config = infra_validator_pb2.RequestSpec(
        tensorflow_serving=infra_validator_pb2.TensorFlowServingRequestSpec(),
        num_examples=3,
    )

    infravalidator = InfraValidator(
        model=trainer.outputs.model,
        examples=examplegen.outputs.examples,
        serving_spec=serving_config,
        validation_spec=validation_config,
        request_spec=request_config,
    )

    # Checks whether the model passed the validation steps and pushes the model
    # to CAIP Prediction if the checks passed.
    pusher = Pusher(custom_executor_spec=executor_spec.ExecutorClassSpec(
        ai_platform_pusher_executor.Executor),
                    model=trainer.outputs.model,
                    model_blessing=evaluator.outputs.blessing,
                    infra_blessing=infravalidator.outputs.blessing,
                    custom_config={
                        ai_platform_pusher_executor.SERVING_ARGS_KEY:
                        ai_platform_serving_args
                    })

    components = [
        examplegen, statisticsgen, import_schema, schemagen, examplevalidator,
        transform, trainer, resolver, evaluator, infravalidator, pusher
    ]

    if enable_tuning:
        components.append(tuner)

    return pipeline.Pipeline(pipeline_name=pipeline_name,
                             pipeline_root=pipeline_root,
                             components=components,
                             enable_cache=enable_cache,
                             beam_pipeline_args=beam_pipeline_args)
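This pipeline targets Kubeflow Pipelines, so it is compiled rather than executed directly. A minimal sketch of compiling it with TFX's KubeflowDagRunner follows; every literal value is an assumption, and the RuntimeParameter arguments stay unbound until a run is created in the Kubeflow Pipelines UI.

# Sketch: compiling the pipeline above for Kubeflow Pipelines.
from typing import Text
from tfx.orchestration import data_types
from tfx.orchestration.kubeflow import kubeflow_dag_runner

runner_config = kubeflow_dag_runner.KubeflowDagRunnerConfig(
    kubeflow_metadata_config=(
        kubeflow_dag_runner.get_default_kubeflow_metadata_config()))

kubeflow_dag_runner.KubeflowDagRunner(config=runner_config).run(
    create_pipeline(
        pipeline_name='covertype-classifier',                  # hypothetical
        pipeline_root='gs://my-bucket/pipelines',              # hypothetical
        data_root_uri=data_types.RuntimeParameter(
            name='data_root_uri', ptype=Text,
            default='gs://my-bucket/covertype'),               # hypothetical
        train_steps=data_types.RuntimeParameter(
            name='train_steps', ptype=int, default=5000),
        eval_steps=data_types.RuntimeParameter(
            name='eval_steps', ptype=int, default=500),
        enable_tuning=False,
        ai_platform_training_args={'project': 'my-project'},   # hypothetical
        ai_platform_serving_args={'model_name': 'covertype'},  # hypothetical
        beam_pipeline_args=[]))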
Example #3
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text, serving_model_dir_lite: Text,
                     metadata_path: Text, labels_path: Text,
                     beam_pipeline_args: List[Text]) -> pipeline.Pipeline:
    """Implements the CIFAR10 image classification pipeline using TFX."""
    # This is needed for datasets with pre-defined splits
    # Change the pattern argument to train_whole/* and test_whole/* to train
    # on the whole CIFAR-10 dataset
    input_config = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(name='train', pattern='train/*'),
        example_gen_pb2.Input.Split(name='eval', pattern='test/*')
    ])

    examples = external_input(data_root)

    # Brings data into the pipeline.
    example_gen = ImportExampleGen(input=examples, input_config=input_config)

    # Computes statistics over data for visualization and example validation.
    statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

    # Generates schema based on statistics files.
    schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'],
                           infer_feature_shape=True)

    # Performs anomaly detection based on statistics and data schema.
    example_validator = ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_gen.outputs['schema'])

    # Performs transformations and feature engineering in training and serving.
    transform = Transform(examples=example_gen.outputs['examples'],
                          schema=schema_gen.outputs['schema'],
                          module_file=module_file)

    # Uses a user-provided Python function that trains a model.
    # When training on the whole dataset, use 18744 for train steps, 156 for eval
    # steps. 18744 train steps correspond to 24 epochs on the whole train set, and
    # 156 eval steps correspond to 1 epoch on the whole test set. The
    # configuration below is for training on the dataset we provided in the data
    # folder, which has 128 train and 128 test samples. The 160 train steps
    # correspond to 40 epochs on this tiny train set, and 4 eval steps correspond
    # to 1 epoch on this tiny test set.
    trainer = Trainer(
        module_file=module_file,
        custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
        examples=transform.outputs['transformed_examples'],
        transform_graph=transform.outputs['transform_graph'],
        schema=schema_gen.outputs['schema'],
        train_args=trainer_pb2.TrainArgs(num_steps=160),
        eval_args=trainer_pb2.EvalArgs(num_steps=4),
        custom_config={'labels_path': labels_path})

    # Get the latest blessed model for model validation.
    model_resolver = ResolverNode(
        instance_name='latest_blessed_model_resolver',
        resolver_class=latest_blessed_model_resolver.
        LatestBlessedModelResolver,
        model=Channel(type=Model),
        model_blessing=Channel(type=ModelBlessing))

    # Uses TFMA to compute evaluation statistics over features of a model and
    # perform quality validation of a candidate model (compared to a baseline).
    eval_config = tfma.EvalConfig(
        model_specs=[
            tfma.ModelSpec(label_key='label_xf', model_type='tf_lite')
        ],
        slicing_specs=[tfma.SlicingSpec()],
        metrics_specs=[
            tfma.MetricsSpec(metrics=[
                tfma.MetricConfig(
                    class_name='SparseCategoricalAccuracy',
                    threshold=tfma.MetricThreshold(
                        value_threshold=tfma.GenericValueThreshold(
                            lower_bound={'value': 0.55}),
                        change_threshold=tfma.GenericChangeThreshold(
                            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                            absolute={'value': -1e-3})))
            ])
        ])

    # Uses TFMA to compute the evaluation statistics over features of a model.
    # We evaluate using the materialized examples that are output by Transform
    # because
    # 1. the decoding_png function currently performed within Transform is not
    # compatible with TFLite.
    # 2. MLKit requires deserialized (float32) tensor image inputs.
    # Note that for deployment, the same logic that is performed within Transform
    # must be reproduced client-side.
    evaluator = Evaluator(examples=transform.outputs['transformed_examples'],
                          model=trainer.outputs['model'],
                          baseline_model=model_resolver.outputs['model'],
                          eval_config=eval_config)

    # Checks whether the model passed the validation steps and pushes the model
    # to a file destination if the check passed.
    pusher = Pusher(model=trainer.outputs['model'],
                    model_blessing=evaluator.outputs['blessing'],
                    push_destination=pusher_pb2.PushDestination(
                        filesystem=pusher_pb2.PushDestination.Filesystem(
                            base_directory=serving_model_dir_lite)))

    components = [
        example_gen, statistics_gen, schema_gen, example_validator, transform,
        trainer, model_resolver, evaluator, pusher
    ]

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=components,
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
        beam_pipeline_args=beam_pipeline_args)
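Because the Trainer above runs with the GenericExecutor, module_file must expose a run_fn that TFX calls with the training parameters. A skeletal sketch of that contract follows; _build_keras_model and _input_fn are hypothetical helpers, not part of the example.

# Sketch: the run_fn contract expected by the GenericExecutor.
def run_fn(fn_args):
    # fn_args.custom_config carries {'labels_path': ...} from the Trainer.
    train_ds = _input_fn(fn_args.train_files)  # hypothetical helper
    eval_ds = _input_fn(fn_args.eval_files)    # hypothetical helper
    model = _build_keras_model()               # hypothetical helper
    model.fit(train_ds,
              steps_per_epoch=fn_args.train_steps,
              validation_data=eval_ds,
              validation_steps=fn_args.eval_steps)
    model.save(fn_args.serving_model_dir, save_format='tf')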
Example #4
def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root: str,
                     module_file: str, serving_model_dir: str,
                     metadata_path: str,
                     beam_pipeline_args: List[str]) -> pipeline.Pipeline:
    """Implements the chicago taxi pipeline with TFX."""

    # Brings data into the pipeline or otherwise joins/converts training data.
    example_gen = CsvExampleGen(input_base=data_root)

    # Computes statistics over data for visualization and example validation.
    statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

    # Generates schema based on statistics files.
    schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'],
                           infer_feature_shape=False)

    # Performs anomaly detection based on statistics and data schema.
    example_validator = ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_gen.outputs['schema'])

    # Performs transformations and feature engineering in training and serving.
    transform = Transform(examples=example_gen.outputs['examples'],
                          schema=schema_gen.outputs['schema'],
                          module_file=module_file)

    # Get the latest model so that we can warm start from the model.
    latest_model_resolver = resolver.Resolver(
        strategy_class=latest_artifacts_resolver.LatestArtifactsResolver,
        latest_model=Channel(type=Model)).with_id('latest_model_resolver')

    # Uses a user-provided Python function that implements a model.
    trainer = Trainer(
        module_file=module_file,
        custom_executor_spec=executor_spec.ExecutorClassSpec(Executor),
        transformed_examples=transform.outputs['transformed_examples'],
        schema=schema_gen.outputs['schema'],
        base_model=latest_model_resolver.outputs['latest_model'],
        transform_graph=transform.outputs['transform_graph'],
        train_args=trainer_pb2.TrainArgs(num_steps=10000),
        eval_args=trainer_pb2.EvalArgs(num_steps=5000))

    # Get the latest blessed model for model validation.
    model_resolver = resolver.Resolver(
        strategy_class=latest_blessed_model_resolver.
        LatestBlessedModelResolver,
        model=Channel(type=Model),
        model_blessing=Channel(
            type=ModelBlessing)).with_id('latest_blessed_model_resolver')

    # Uses TFMA to compute evaluation statistics over features of a model and
    # perform quality validation of a candidate model (compared to a baseline).
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(signature_name='eval')],
        slicing_specs=[
            tfma.SlicingSpec(),
            tfma.SlicingSpec(feature_keys=['trip_start_hour'])
        ],
        metrics_specs=[
            tfma.MetricsSpec(
                thresholds={
                    'accuracy':
                    tfma.MetricThreshold(
                        value_threshold=tfma.GenericValueThreshold(
                            lower_bound={'value': 0.6}),
                        change_threshold=tfma.GenericChangeThreshold(
                            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                            absolute={'value': -1e-10}))
                })
        ])
    evaluator = Evaluator(
        examples=example_gen.outputs['examples'],
        model=trainer.outputs['model'],
        baseline_model=model_resolver.outputs['model'],
        # Change threshold will be ignored if there is no baseline (first run).
        eval_config=eval_config)

    # Checks whether the model passed the validation steps and pushes the model
    # to a file destination if the check passed.
    pusher = Pusher(model=trainer.outputs['model'],
                    model_blessing=evaluator.outputs['blessing'],
                    push_destination=pusher_pb2.PushDestination(
                        filesystem=pusher_pb2.PushDestination.Filesystem(
                            base_directory=serving_model_dir)))

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[
            example_gen,
            statistics_gen,
            schema_gen,
            example_validator,
            transform,
            latest_model_resolver,
            trainer,
            model_resolver,
            evaluator,
            pusher,
        ],
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
        beam_pipeline_args=beam_pipeline_args)
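For intuition, the two thresholds in eval_config combine as follows. This is a simplified sketch of the blessing decision, not TFMA's actual implementation.

# Sketch: how the Evaluator's accuracy thresholds gate the blessing.
def is_blessed(candidate_accuracy, baseline_accuracy=None):
    # value_threshold: an absolute floor the candidate must clear.
    if candidate_accuracy <= 0.6:
        return False
    # change_threshold: the candidate may not regress against the baseline
    # by more than 1e-10; skipped on the first run, when no baseline exists.
    if (baseline_accuracy is not None
            and candidate_accuracy - baseline_accuracy <= -1e-10):
        return False
    return True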
Example #5
def create_pipeline(
    pipeline_name: Text,
    pipeline_root: Text,
    module_file: Text,
    ai_platform_training_args: Dict[Text, Text],
    ai_platform_serving_args: Dict[Text, Text],
    beam_pipeline_args: List[Text],
) -> pipeline.Pipeline:
    """Implements the chicago taxi pipeline with TFX and Kubeflow Pipelines.

  Args:
    pipeline_name: name of the TFX pipeline being created.
    pipeline_root: root directory of the pipeline. Should be a valid GCS path.
    module_file: uri of the module files used in Trainer and Transform
      components.
    ai_platform_training_args: Args of CAIP training job. Please refer to
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#Job
      for detailed description.
    ai_platform_serving_args: Args of CAIP model deployment. Please refer to
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.models
      for detailed description.
    beam_pipeline_args: List of beam pipeline options. Please refer to
      https://cloud.google.com/dataflow/docs/guides/specifying-exec-params#setting-other-cloud-dataflow-pipeline-options.

  Returns:
    A TFX pipeline object.
  """

    # The rate at which to sample rows from the Taxi dataset using BigQuery.
    # The full taxi dataset is > 200M records. In the interest of resource
    # savings and time, we've set the default for this example to be much smaller.
    # Feel free to crank it up and process the full dataset!
    # By default it generates a 0.1% random sample.
    query_sample_rate = data_types.RuntimeParameter(name='query_sample_rate',
                                                    ptype=float,
                                                    default=0.001)

    # This is the upper bound of FARM_FINGERPRINT in BigQuery (i.e. the max
    # value of a signed int64).
    max_int64 = '0x7FFFFFFFFFFFFFFF'

    # The query that extracts the examples from BigQuery. The Chicago Taxi dataset
    # used for this example is a public dataset available on Google AI Platform.
    # https://console.cloud.google.com/marketplace/details/city-of-chicago-public-data/chicago-taxi-trips
    query = """
          SELECT
            pickup_community_area,
            fare,
            EXTRACT(MONTH FROM trip_start_timestamp) AS trip_start_month,
            EXTRACT(HOUR FROM trip_start_timestamp) AS trip_start_hour,
            EXTRACT(DAYOFWEEK FROM trip_start_timestamp) AS trip_start_day,
            UNIX_SECONDS(trip_start_timestamp) AS trip_start_timestamp,
            pickup_latitude,
            pickup_longitude,
            dropoff_latitude,
            dropoff_longitude,
            trip_miles,
            pickup_census_tract,
            dropoff_census_tract,
            payment_type,
            company,
            trip_seconds,
            dropoff_community_area,
            tips
          FROM `bigquery-public-data.chicago_taxi_trips.taxi_trips`
          WHERE (ABS(FARM_FINGERPRINT(unique_key)) / {max_int64})
            < {query_sample_rate}""".format(
        max_int64=max_int64, query_sample_rate=str(query_sample_rate))

    # Number of steps for model training.
    train_steps = data_types.RuntimeParameter(
        name='train_steps',
        default=10000,
        ptype=int,
    )

    # Number of steps for model evaluation.
    eval_steps = data_types.RuntimeParameter(
        name='eval_steps',
        default=5000,
        ptype=int,
    )

    # Brings data into the pipeline or otherwise joins/converts training data.
    example_gen = big_query_example_gen_component.BigQueryExampleGen(
        query=query)

    # Computes statistics over data for visualization and example validation.
    statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

    # Generates schema based on statistics files.
    schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'],
                           infer_feature_shape=False)

    # Performs anomaly detection based on statistics and data schema.
    example_validator = ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_gen.outputs['schema'])

    # Performs transformations and feature engineering in training and serving.
    transform = Transform(examples=example_gen.outputs['examples'],
                          schema=schema_gen.outputs['schema'],
                          module_file=module_file)

    # Update ai_platform_training_args if distributed training was enabled.
    # Number of worker machines used in distributed training.
    worker_count = data_types.RuntimeParameter(
        name='worker_count',
        default=2,
        ptype=int,
    )

    # Type of worker machines used in distributed training.
    worker_type = data_types.RuntimeParameter(
        name='worker_type',
        default='standard',
        ptype=str,
    )

    ai_platform_training_args = copy.copy(ai_platform_training_args)
    if FLAGS.distributed_training:
        ai_platform_training_args.update({
            # You can specify the machine types, the number of replicas for workers
            # and parameter servers.
            # https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#ScaleTier
            'scaleTier': 'CUSTOM',
            'masterType': 'large_model',
            'workerType': worker_type,
            'parameterServerType': 'standard',
            'workerCount': worker_count,
            'parameterServerCount': 1
        })

    # Uses a user-provided Python function that implements a model using TF-Learn
    # to train a model on Google Cloud AI Platform.
    trainer = Trainer(
        custom_executor_spec=executor_spec.ExecutorClassSpec(
            ai_platform_trainer_executor.Executor),
        module_file=module_file,
        transformed_examples=transform.outputs['transformed_examples'],
        schema=schema_gen.outputs['schema'],
        transform_graph=transform.outputs['transform_graph'],
        train_args={'num_steps': train_steps},
        eval_args={'num_steps': eval_steps},
        custom_config={
            ai_platform_trainer_executor.TRAINING_ARGS_KEY:
            ai_platform_training_args
        })

    # Get the latest blessed model for model validation.
    model_resolver = resolver.Resolver(
        strategy_class=latest_blessed_model_resolver.
        LatestBlessedModelResolver,
        model=Channel(type=Model),
        model_blessing=Channel(
            type=ModelBlessing)).with_id('latest_blessed_model_resolver')

    # Uses TFMA to compute evaluation statistics over features of a model and
    # perform quality validation of a candidate model (compared to a baseline).
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(signature_name='eval')],
        slicing_specs=[
            tfma.SlicingSpec(),
            tfma.SlicingSpec(feature_keys=['trip_start_hour'])
        ],
        metrics_specs=[
            tfma.MetricsSpec(
                thresholds={
                    'accuracy':
                    tfma.config.MetricThreshold(
                        value_threshold=tfma.GenericValueThreshold(
                            lower_bound={'value': 0.6}),
                        # Change threshold will be ignored if there is no
                        # baseline model resolved from MLMD (first run).
                        change_threshold=tfma.GenericChangeThreshold(
                            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                            absolute={'value': -1e-10}))
                })
        ])
    evaluator = Evaluator(examples=example_gen.outputs['examples'],
                          model=trainer.outputs['model'],
                          baseline_model=model_resolver.outputs['model'],
                          eval_config=eval_config)

    # Checks whether the model passed the validation steps and pushes the model
    # to Google Cloud AI Platform if the check passed.
    # TODO(b/162451308): Add pusher back to components list once AIP Prediction
    # Service supports TF>=2.3.
    _ = Pusher(custom_executor_spec=executor_spec.ExecutorClassSpec(
        ai_platform_pusher_executor.Executor),
               model=trainer.outputs['model'],
               model_blessing=evaluator.outputs['blessing'],
               custom_config={
                   ai_platform_pusher_executor.SERVING_ARGS_KEY:
                   ai_platform_serving_args
               })

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[
            example_gen, statistics_gen, schema_gen, example_validator,
            transform, trainer, model_resolver, evaluator
        ],
        beam_pipeline_args=beam_pipeline_args,
    )
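The RuntimeParameters declared above (query_sample_rate, train_steps, eval_steps, worker_count, worker_type) are placeholders at compile time and get bound per run. A minimal sketch with the kfp SDK follows; the endpoint and package file name are assumptions.

# Sketch: launching a run and binding the runtime parameters.
import kfp

client = kfp.Client(host='https://my-kfp-endpoint')   # hypothetical host
client.create_run_from_pipeline_package(
    'chicago_taxi_pipeline_kubeflow_gcp.tar.gz',      # hypothetical package
    arguments={
        'query_sample_rate': 0.001,
        'train_steps': 10000,
        'eval_steps': 5000,
    })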
Example #6
def create_pipeline(
    pipeline_name: Text,
    pipeline_root: Text,
    data_root: Text,
    module_file: Text,
    ai_platform_training_args: Dict[Text, Text],
    ai_platform_serving_args: Dict[Text, Text],
    enable_tuning: bool,
    beam_pipeline_args: List[Text],
) -> pipeline.Pipeline:
    """Implements the penguin pipeline with TFX and Kubeflow Pipeline.

  Args:
    pipeline_name: name of the TFX pipeline being created.
    pipeline_root: root directory of the pipeline. Should be a valid GCS path.
    data_root: uri of the penguin data.
    module_file: uri of the module files used in Trainer and Transform
      components.
    ai_platform_training_args: Args of CAIP training job. Please refer to
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#Job
      for detailed description.
    ai_platform_serving_args: Args of CAIP model deployment. Please refer to
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.models
      for detailed description.
    enable_tuning: If True, the hyperparameter tuning through CloudTuner is
      enabled.
    beam_pipeline_args: List of beam pipeline options. Please refer to
      https://cloud.google.com/dataflow/docs/guides/specifying-exec-params#setting-other-cloud-dataflow-pipeline-options.

  Returns:
    A TFX pipeline object.
  """
    # Number of steps for model training.
    train_steps = data_types.RuntimeParameter(
        name='train_steps',
        default=100,
        ptype=int,
    )

    # Number of steps for model evaluation.
    eval_steps = data_types.RuntimeParameter(
        name='eval_steps',
        default=50,
        ptype=int,
    )

    # Brings data into the pipeline or otherwise joins/converts training data.
    example_gen = CsvExampleGen(input_base=data_root)

    # Computes statistics over data for visualization and example validation.
    statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

    # Generates schema based on statistics files.
    schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'],
                           infer_feature_shape=True)

    # Performs anomaly detection based on statistics and data schema.
    example_validator = ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_gen.outputs['schema'])

    # Performs transformations and feature engineering in training and serving.
    transform = Transform(examples=example_gen.outputs['examples'],
                          schema=schema_gen.outputs['schema'],
                          module_file=module_file)

    # Update ai_platform_training_args if distributed training was enabled.
    # Number of worker machines used in distributed training.
    worker_count = data_types.RuntimeParameter(
        name='worker_count',
        default=2,
        ptype=int,
    )

    # Type of worker machines used in distributed training.
    worker_type = data_types.RuntimeParameter(
        name='worker_type',
        default='standard',
        ptype=str,
    )

    ai_platform_training_args = copy.copy(ai_platform_training_args)
    if FLAGS.distributed_training:
        ai_platform_training_args.update({
            # You can specify the machine types, the number of replicas for workers
            # and parameter servers.
            # https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#ScaleTier
            'scaleTier': 'CUSTOM',
            'masterType': 'large_model',
            'workerType': worker_type,
            'parameterServerType': 'standard',
            'workerCount': worker_count,
            'parameterServerCount': 1,
        })

    # Tunes the hyperparameters for model training based on a user-provided
    # Python function. Note that once the hyperparameters are tuned, you can
    # drop the Tuner component from the pipeline and feed Trainer the tuned
    # hyperparameters.
    if enable_tuning:
        # The Tuner component launches 1 CAIP Training job for flock management.
        # For example, 3 workers (defined by num_parallel_trials) in the flock
        # management CAIP Training job, each runs Tuner.Executor.
        # Then, 3 CAIP Training Jobs (defined by training_args) are invoked
        # from each worker in the flock management Job for Trial execution.
        tuner = Tuner(
            module_file=module_file,
            examples=transform.outputs['transformed_examples'],
            transform_graph=transform.outputs['transform_graph'],
            train_args={'num_steps': train_steps},
            eval_args={'num_steps': eval_steps},
            tune_args=tuner_pb2.TuneArgs(
                # num_parallel_trials=3 means that 3 search loops are
                # running in parallel.
                # Each tuner may include a distributed training job which can be
                # specified in training_args above (e.g. 1 PS + 2 workers).
                num_parallel_trials=3),
            custom_config={
                # Configures Cloud AI Platform-specific configs. For details, see
                # https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#traininginput.
                ai_platform_trainer_executor.TRAINING_ARGS_KEY:
                ai_platform_training_args
            })

    # Uses a user-provided Python function that trains a model.
    trainer = Trainer(
        custom_executor_spec=executor_spec.ExecutorClassSpec(
            ai_platform_trainer_executor.GenericExecutor),
        module_file=module_file,
        examples=transform.outputs['transformed_examples'],
        transform_graph=transform.outputs['transform_graph'],
        schema=schema_gen.outputs['schema'],
        # If Tuner is in the pipeline, Trainer can take Tuner's output
        # best_hyperparameters artifact as input and utilize it in the user module
        # code.
        #
        # If there is no Tuner in the pipeline, either use ImporterNode to
        # import a previous Tuner's output to feed to Trainer, or use the tuned
        # hyperparameters directly in the user module code and set
        # hyperparameters to None here.
        #
        # Example of ImporterNode,
        #   hparams_importer = ImporterNode(
        #     instance_name='import_hparams',
        #     source_uri='path/to/best_hyperparameters.txt',
        #     artifact_type=HyperParameters)
        #   ...
        #   hyperparameters = hparams_importer.outputs['result'],
        hyperparameters=(tuner.outputs['best_hyperparameters']
                         if enable_tuning else None),
        train_args={'num_steps': train_steps},
        eval_args={'num_steps': eval_steps},
        custom_config={
            ai_platform_trainer_executor.TRAINING_ARGS_KEY:
            ai_platform_training_args
        })

    # Get the latest blessed model for model validation.
    model_resolver = ResolverNode(
        instance_name='latest_blessed_model_resolver',
        resolver_class=latest_blessed_model_resolver.
        LatestBlessedModelResolver,
        model=Channel(type=Model),
        model_blessing=Channel(type=ModelBlessing))

    # Uses TFMA to compute evaluation statistics over features of a model and
    # perform quality validation of a candidate model (compared to a baseline).
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(label_key='species')],
        slicing_specs=[tfma.SlicingSpec()],
        metrics_specs=[
            tfma.MetricsSpec(metrics=[
                tfma.MetricConfig(
                    class_name='SparseCategoricalAccuracy',
                    threshold=tfma.MetricThreshold(
                        value_threshold=tfma.GenericValueThreshold(
                            lower_bound={'value': 0.6}),
                        # Change threshold will be ignored if there is no
                        # baseline model resolved from MLMD (first run).
                        change_threshold=tfma.GenericChangeThreshold(
                            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                            absolute={'value': -1e-10})))
            ])
        ])

    evaluator = Evaluator(examples=example_gen.outputs['examples'],
                          model=trainer.outputs['model'],
                          baseline_model=model_resolver.outputs['model'],
                          eval_config=eval_config)

    pusher = Pusher(
        custom_executor_spec=executor_spec.ExecutorClassSpec(
            ai_platform_pusher_executor.Executor),
        model=trainer.outputs['model'],
        model_blessing=evaluator.outputs['blessing'],
        custom_config={
            ai_platform_pusher_executor.SERVING_ARGS_KEY:
            ai_platform_serving_args
        },
    )

    components = [
        example_gen,
        statistics_gen,
        schema_gen,
        example_validator,
        transform,
        trainer,
        model_resolver,
        evaluator,
        pusher,
    ]
    if enable_tuning:
        components.append(tuner)

    return pipeline.Pipeline(pipeline_name=pipeline_name,
                             pipeline_root=pipeline_root,
                             components=components,
                             enable_cache=True,
                             beam_pipeline_args=beam_pipeline_args)
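The Tuner component above loads a tuner_fn from module_file and expects a TunerFnResult back. A minimal sketch of that contract follows, substituting a local kerastuner search where the pipeline would construct a CloudTuner; _build_keras_model and _input_fn are hypothetical helpers.

# Sketch: the tuner_fn contract expected by the Tuner component.
from kerastuner import RandomSearch
from tfx.components.tuner.component import TunerFnResult

def tuner_fn(fn_args):
    tuner = RandomSearch(
        _build_keras_model,  # hypothetical hypermodel builder
        objective='val_sparse_categorical_accuracy',
        max_trials=10,
        directory=fn_args.working_dir,
        project_name='penguin_tuning')
    return TunerFnResult(
        tuner=tuner,
        fit_kwargs={
            'x': _input_fn(fn_args.train_files),               # hypothetical
            'validation_data': _input_fn(fn_args.eval_files),  # hypothetical
            'steps_per_epoch': fn_args.train_steps,
            'validation_steps': fn_args.eval_steps,
        })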