示例#1
0
  def testDoWithTuneArgsAndTrainingInputOverride(self):
    """Checks training inputs when the user overrides machine settings.

    TuneArgs asks for six parallel trials while the custom config pins the
    scale tier, the machine types and a worker count of two.
    """
    self._exec_properties['tune_args'] = proto_utils.proto_to_json(
        tuner_pb2.TuneArgs(num_parallel_trials=6))

    # User-provided training input placed into the custom config.
    user_training_input = {
        'scaleTier': 'CUSTOM',
        'masterType': 'n1-highmem-16',
        'workerType': 'n1-highmem-16',
        'workerCount': 2,
    }
    training_args = self._exec_properties['custom_config'][
        ai_platform_trainer_executor.TRAINING_ARGS_KEY]
    training_args.update(user_training_input)

    tuner_executor = ai_platform_tuner_executor.Executor()
    tuner_executor.Do(self._inputs, self._outputs,
                      self._serialize_custom_config_under_test())

    expected_training_input = {
        'project': self._project_id,
        'jobDir': self._job_dir,
        # The scale tier and machine types must not be overwritten.
        'scaleTier': 'CUSTOM',
        'masterType': 'n1-highmem-16',
        'workerType': 'n1-highmem-16',
        # workerCount is expected to go from 2 to 5 for num_parallel_trials=6
        # (presumably the master accounts for the sixth trial).
        'workerCount': 5,
    }
    self.mock_runner.start_aip_training.assert_called_with(
        self._inputs,
        self._outputs,
        self._serialize_custom_config_under_test(),
        self._executor_class_path,
        expected_training_input,
        mock.ANY)
示例#2
0
 def testVertexDistributedTunerPipeline(self):
     """Runs a Tuner-only pipeline with a distributed flock on Vertex AI."""
     pipeline_name = self._make_unique_pipeline_name(
         'kubeflow-vertex-dist-tuner')

     # Custom config carrying the tuning args plus the Vertex switches.
     vertex_custom_config = {
         ai_platform_tuner_executor.TUNING_ARGS_KEY:
             self._getVertexTrainingArgs(pipeline_name),
         constants.ENABLE_VERTEX_KEY: True,
         constants.VERTEX_REGION_KEY: self._GCP_REGION,
     }

     tuner = ai_platform_tuner_component.Tuner(
         examples=self.penguin_examples_importer.outputs['result'],
         module_file=self._penguin_tuner_module,
         schema=self.penguin_schema_importer.outputs['result'],
         train_args=trainer_pb2.TrainArgs(num_steps=10),
         eval_args=trainer_pb2.EvalArgs(num_steps=5),
         # Run three trials in parallel.
         tune_args=tuner_pb2.TuneArgs(num_parallel_trials=3),
         custom_config=vertex_custom_config)

     components = [
         self.penguin_examples_importer,
         self.penguin_schema_importer,
         tuner,
     ]
     pipeline = self._create_pipeline(pipeline_name, components)

     self._compile_and_run_pipeline(pipeline)
     self._assertHyperparametersAreWritten(pipeline_name)
示例#3
0
    def testTuneArgs(self):
        """Verifies that Do() raises ValueError when TuneArgs is supplied."""
        # Arrange outside the assertRaises context: a ValueError raised while
        # serializing the proto or constructing the executor must fail the
        # test instead of silently satisfying it.
        self._exec_properties['tune_args'] = proto_utils.proto_to_json(
            tuner_pb2.TuneArgs(num_parallel_trials=3))
        tuner = executor.Executor(self._context)

        with self.assertRaises(ValueError):
            tuner.Do(input_dict=self._input_dict,
                     output_dict=self._output_dict,
                     exec_properties=self._exec_properties)
示例#4
0
 def setUp(self):
   """Prepares input channels and argument protos as instance attributes."""
   super(TunerTest, self).setUp()

   # Each channel wraps a single, freshly created artifact.
   examples_artifact = standard_artifacts.Examples()
   self.examples = channel_utils.as_channel([examples_artifact])

   schema_artifact = standard_artifacts.Schema()
   self.schema = channel_utils.as_channel([schema_artifact])

   transform_graph_artifact = standard_artifacts.TransformGraph()
   self.transform_graph = channel_utils.as_channel([transform_graph_artifact])

   # Train/eval arguments name their split explicitly.
   self.train_args = trainer_pb2.TrainArgs(splits=['train'], num_steps=100)
   self.eval_args = trainer_pb2.EvalArgs(splits=['eval'], num_steps=50)
   self.tune_args = tuner_pb2.TuneArgs(num_parallel_trials=3)
示例#5
0
    def testTuneArgs(self):
        """Verifies that Do() raises ValueError when TuneArgs is supplied."""
        # Arrange outside the assertRaises context: a ValueError raised while
        # serializing the proto or constructing the executor must fail the
        # test instead of silently satisfying it.
        self._exec_properties['tune_args'] = json_format.MessageToJson(
            tuner_pb2.TuneArgs(num_parallel_trials=3),
            preserving_proto_field_name=True)
        tuner = executor.Executor(self._context)

        with self.assertRaises(ValueError):
            tuner.Do(input_dict=self._input_dict,
                     output_dict=self._output_dict,
                     exec_properties=self._exec_properties)
示例#6
0
文件: executor.py 项目: yifanmai/tfx
def get_tune_args(
        exec_properties: Dict[Text, Any]) -> Optional[tuner_pb2.TuneArgs]:
    """Parses the TuneArgs proto stored in the execution properties.

    Args:
      exec_properties: execution properties; the serialized TuneArgs JSON is
        looked up under `_TUNE_ARGS_KEY`.

    Returns:
      The parsed TuneArgs proto, or None when the entry is missing or falsy.
    """
    serialized_tune_args = exec_properties.get(_TUNE_ARGS_KEY)
    if serialized_tune_args:
        parsed = tuner_pb2.TuneArgs()
        json_format.Parse(serialized_tune_args, parsed)
        return parsed

    return None
示例#7
0
文件: executor.py 项目: jay90099/tfx
def get_tune_args(
        exec_properties: Dict[str, Any]) -> Optional[tuner_pb2.TuneArgs]:
    """Parses the TuneArgs proto stored in the execution properties.

    Args:
      exec_properties: execution properties; the serialized TuneArgs JSON is
        looked up under `standard_component_specs.TUNE_ARGS_KEY`.

    Returns:
      The parsed TuneArgs proto, or None when the entry is missing or falsy.
    """
    serialized_tune_args = exec_properties.get(
        standard_component_specs.TUNE_ARGS_KEY)
    if serialized_tune_args:
        parsed = tuner_pb2.TuneArgs()
        proto_utils.json_to_proto(serialized_tune_args, parsed)
        return parsed

    return None
示例#8
0
    def testTuneArgs(self):
        """Verifies that Do() raises ValueError when TuneArgs is supplied."""
        # Arrange outside the assertRaises context: a ValueError raised while
        # serializing the proto or constructing the executor must fail the
        # test instead of silently satisfying it.
        tune_args = tuner_pb2.TuneArgs(num_parallel_trials=3)
        self._exec_properties[standard_component_specs.TUNE_ARGS_KEY] = (
            proto_utils.proto_to_json(tune_args))
        tuner = executor.Executor(self._context)

        with self.assertRaises(ValueError):
            tuner.Do(input_dict=self._input_dict,
                     output_dict=self._output_dict,
                     exec_properties=self._exec_properties)
示例#9
0
 def setUp(self):
     """Prepares input channels, argument protos and a custom config."""
     super().setUp()

     # Each channel wraps a single, freshly created artifact.
     examples_artifact = standard_artifacts.Examples()
     self.examples = channel_utils.as_channel([examples_artifact])

     schema_artifact = standard_artifacts.Schema()
     self.schema = channel_utils.as_channel([schema_artifact])

     transform_graph_artifact = standard_artifacts.TransformGraph()
     self.transform_graph = channel_utils.as_channel(
         [transform_graph_artifact])

     self.train_args = trainer_pb2.TrainArgs(num_steps=100)
     self.eval_args = trainer_pb2.EvalArgs(num_steps=50)
     self.tune_args = tuner_pb2.TuneArgs(num_parallel_trials=3)
     self.custom_config = {'key': 'value'}
示例#10
0
  def testDoWithTuneArgs(self):
    """Checks the training inputs sent for num_parallel_trials=3."""
    self._exec_properties['tune_args'] = proto_utils.proto_to_json(
        tuner_pb2.TuneArgs(num_parallel_trials=3))

    tuner_executor = ai_platform_tuner_executor.Executor()
    tuner_executor.Do(self._inputs, self._outputs,
                      self._serialize_custom_config_under_test())

    # Training input the runner is expected to receive.
    expected_training_input = {
        'project': self._project_id,
        'jobDir': self._job_dir,
        'scaleTier': 'CUSTOM',
        'masterType': 'standard',
        'workerType': 'standard',
        'workerCount': 2,
    }
    self.mock_runner.start_aip_training.assert_called_with(
        self._inputs,
        self._outputs,
        self._serialize_custom_config_under_test(),
        self._executor_class_path,
        expected_training_input,
        mock.ANY)
示例#11
0
    def setUp(self):
        """Prepares channels, argument protos and a custom config."""
        super(ComponentTest, self).setUp()

        # Examples artifact advertising a 'train' and an 'eval' split.
        examples = standard_artifacts.Examples()
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        transform_graph = standard_artifacts.TransformGraph()

        # Input channels, each wrapping one artifact.
        self.examples = channel_utils.as_channel([examples])
        self.schema = channel_utils.as_channel(
            [standard_artifacts.Schema()])
        self.transform_graph = channel_utils.as_channel([transform_graph])

        self.custom_config = {'some': 'thing', 'some other': 1, 'thing': 2}

        self.train_args = trainer_pb2.TrainArgs(num_steps=100)
        self.eval_args = trainer_pb2.EvalArgs(num_steps=50)
        self.tune_args = tuner_pb2.TuneArgs(num_parallel_trials=3)

        # Additional input channels: warm-up hyperparameters and a meta model.
        warmup_artifact = artifacts.KCandidateHyperParameters()
        self.warmup_hyperparams = channel_utils.as_channel([warmup_artifact])
        meta_model_artifact = standard_artifacts.Model()
        self.meta_model = channel_utils.as_channel([meta_model_artifact])
 def testAIPlatformDistributedTunerPipeline(self):
     """Runs a Tuner-only pipeline with a distributed flock on AIP Training."""
     pipeline_name = 'kubeflow-aip-dist-tuner-test-{}'.format(
         test_utils.random_id())

     # Custom config carrying the AI Platform training args.
     caip_custom_config = {
         ai_platform_trainer_executor.TRAINING_ARGS_KEY:
             self._getCaipTrainingArgs(pipeline_name),
     }

     tuner = ai_platform_tuner_component.Tuner(
         examples=self.iris_examples_importer.outputs['result'],
         module_file=self._iris_tuner_module,
         schema=self.iris_schema_importer.outputs['result'],
         train_args=trainer_pb2.TrainArgs(num_steps=10),
         eval_args=trainer_pb2.EvalArgs(num_steps=5),
         # Run three trials in parallel.
         tune_args=tuner_pb2.TuneArgs(num_parallel_trials=3),
         custom_config=caip_custom_config)

     components = [
         self.iris_examples_importer,
         self.iris_schema_importer,
         tuner,
     ]
     pipeline = self._create_pipeline(pipeline_name, components)

     self._compile_and_run_pipeline(pipeline)
     self._assertHyperparametersAreWritten(pipeline_name)
def create_pipeline(pipeline_name: Text,
                    pipeline_root: Text,
                    data_root_uri,
                    trainer_config: TrainerConfig,
                    tuner_config: TunerConfig,
                    pusher_config: PusherConfig,
                    runtime_parameters_config: Optional[RuntimeParametersConfig] = None,
                    str_runtime_parameters_supported: bool = False,
                    int_runtime_parameters_supported: bool = False,
                    local_run: bool = True,
                    beam_pipeline_args: Optional[List[Text]] = None,
                    enable_cache: Optional[bool] = True,
                    code_folder: Text = '',
                    metadata_connection_config: Optional[metadata_store_pb2.ConnectionConfig] = None
                    ) -> pipeline.Pipeline:
    """Builds the TFX pipeline that trains and deploys the Keras Titanic classifier.

    The pipeline ingests CSV data, computes statistics, validates the data
    against a user-provided schema, transforms it, optionally tunes
    hyperparameters, trains, evaluates, infra-validates and pushes the model.

    Args:
      pipeline_name: name of the TFX pipeline being created.
      pipeline_root: root directory of the pipeline.
      data_root_uri: uri of the dataset, used as `input_base` of CsvExampleGen.
        Replaced by `runtime_parameters_config.data_root_runtime` when
        `str_runtime_parameters_supported` is set and a runtime config is
        given.
      trainer_config: provides `train_steps`, `eval_steps`, `epochs`,
        `train_batch_size`, `eval_batch_size` and the optional
        `ai_platform_training_args` (when not None, the AI Platform trainer
        executor is used).
      tuner_config: provides `enable_tuning`, `tuner_steps`,
        `eval_tuner_steps`, `max_trials` and the optional
        `ai_platform_tuner_args`.
      pusher_config: provides `serving_model_dir` (used for local runs) and
        the optional `ai_platform_serving_args` (when not None, the AI
        Platform pusher executor is used).
      runtime_parameters_config: optional holder of `data_root_runtime`,
        `train_steps_runtime` and `eval_steps_runtime`.
      str_runtime_parameters_supported: whether the data root may be taken
        from `runtime_parameters_config`.
      int_runtime_parameters_supported: whether the train/eval step counts may
        be taken from `runtime_parameters_config`.
      local_run: if True, InfraValidator uses local Docker and Pusher writes
        to `pusher_config.serving_model_dir`; otherwise InfraValidator uses
        Kubernetes. The flag is also forwarded to the Tuner as
        `custom_config['is_local_run']`.
      beam_pipeline_args: optional list of beam pipeline options, passed to
        the pipeline unchanged.
      enable_cache: passed to the pipeline unchanged.
      code_folder: folder holding the schema, hyperparameters, transform
        module and training module; joined below the filesystem root.
      metadata_connection_config: optional ML Metadata connection config,
        passed to the pipeline unchanged.

    Returns:
      A TFX pipeline object.
    """
    # Logging arguments are passed lazily; eager '%' formatting would raise a
    # TypeError if any of these values happened to be a tuple.
    absl.logging.info('pipeline_name: %s', pipeline_name)
    absl.logging.info('pipeline root: %s', pipeline_root)
    absl.logging.info('data_root_uri for training: %s', data_root_uri)

    absl.logging.info('train_steps for training: %s', trainer_config.train_steps)
    absl.logging.info('tuner_steps for tuning: %s', tuner_config.tuner_steps)
    absl.logging.info('eval_steps for evaluating: %s', trainer_config.eval_steps)

    absl.logging.info('os default list dir: %s', os.listdir('.'))

    # Resolve the code artifacts relative to code_folder, anchored at the
    # filesystem root.
    schema_proper_folder = os.path.join(os.sep, code_folder, SCHEMA_FOLDER)
    absl.logging.info('schema_proper_folder: %s', schema_proper_folder)

    preprocessing_proper_file = os.path.join(os.sep, code_folder, TRANSFORM_MODULE_FILE)
    absl.logging.info('preprocessing_proper_file: %s', preprocessing_proper_file)

    model_proper_file = os.path.join(os.sep, code_folder, TRAIN_MODULE_FILE)
    absl.logging.info('model_proper_file: %s', model_proper_file)

    hyperparameters_proper_folder = os.path.join(os.sep, code_folder, HYPERPARAMETERS_FOLDER)
    absl.logging.info('hyperparameters_proper_folder: %s', hyperparameters_proper_folder)

    # Brings data into the pipeline and splits the data into training and eval
    # splits (hash buckets 4:1).
    output_config = example_gen_pb2.Output(
        split_config=example_gen_pb2.SplitConfig(splits=[
            example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=4),
            example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1)
        ]))

    if str_runtime_parameters_supported and runtime_parameters_config is not None:
        data_root_uri = runtime_parameters_config.data_root_runtime

    examplegen = CsvExampleGen(input_base=data_root_uri, output_config=output_config)

    # Computes statistics over data for visualization and example validation.
    statisticsgen = StatisticsGen(examples=examplegen.outputs.examples)

    # Generates schema based on statistics files. Even though we use a
    # user-provided schema, we still want to generate the schema of the newest
    # data for tracking and comparison.
    schemagen = SchemaGen(statistics=statisticsgen.outputs.statistics)

    # Import a user-provided schema.
    import_schema = Importer(
        source_uri=schema_proper_folder,
        artifact_type=Schema).with_id('import_user_schema')

    # Performs anomaly detection based on statistics and data schema.
    examplevalidator = ExampleValidator(
        statistics=statisticsgen.outputs.statistics,
        schema=import_schema.outputs.result)

    # Performs transformations and feature engineering in training and serving.
    transform = Transform(
        examples=examplegen.outputs.examples,
        schema=import_schema.outputs.result,
        module_file=preprocessing_proper_file)

    # Imports previously tuned hyperparameters; only added to the pipeline
    # when tuning is disabled.
    hparams_importer = Importer(
        source_uri=hyperparameters_proper_folder,
        artifact_type=HyperParameters).with_id('import_hparams')

    # NOTE(review): per the original author only str RuntimeParameters appear
    # to be supported on Airflow, hence the separate opt-in flag for ints.
    if int_runtime_parameters_supported and runtime_parameters_config is not None:
        train_steps = runtime_parameters_config.train_steps_runtime
        eval_steps = runtime_parameters_config.eval_steps_runtime
    else:
        train_steps = trainer_config.train_steps
        eval_steps = trainer_config.eval_steps

    absl.logging.info('train_steps: %s', train_steps)
    absl.logging.info('eval_steps: %s', eval_steps)

    # Tunes the hyperparameters for model training based on user-provided
    # Python function. Note that once the hyperparameters are tuned, you can
    # drop the Tuner component from pipeline and feed Trainer with tuned
    # hyperparameters.
    if tuner_config.enable_tuning:
        tuner_args = {
            'module_file': model_proper_file,
            'examples': transform.outputs.transformed_examples,
            'transform_graph': transform.outputs.transform_graph,
            'train_args': {'num_steps': tuner_config.tuner_steps},
            'eval_args': {'num_steps': tuner_config.eval_tuner_steps},
            'custom_config': {'max_trials': tuner_config.max_trials, 'is_local_run': local_run},
        }

        if tuner_config.ai_platform_tuner_args is not None:
            # Merge into the existing custom_config (as done for the Trainer
            # below) so that max_trials / is_local_run are not dropped.
            tuner_args['custom_config'].update({
                ai_platform_trainer_executor.TRAINING_ARGS_KEY: tuner_config.ai_platform_tuner_args
            })
            tuner_args['tune_args'] = tuner_pb2.TuneArgs(num_parallel_trials=3)

        absl.logging.info('tuner_args: %s', tuner_args)
        tuner = Tuner(**tuner_args)
        hyperparameters = tuner.outputs.best_hyperparameters
    else:
        hyperparameters = hparams_importer.outputs['result']

    # Trains the model using a user provided trainer function.
    trainer_args = {
        'module_file': model_proper_file,
        'transformed_examples': transform.outputs.transformed_examples,
        'schema': import_schema.outputs.result,
        'transform_graph': transform.outputs.transform_graph,
        'train_args': {'num_steps': train_steps},
        'eval_args': {'num_steps': eval_steps},
        'hyperparameters': hyperparameters,
        'custom_config': {
            'epochs': trainer_config.epochs,
            'train_batch_size': trainer_config.train_batch_size,
            'eval_batch_size': trainer_config.eval_batch_size,
        },
    }

    if trainer_config.ai_platform_training_args is not None:
        # Train on AI Platform: pass the job args and swap in its executor.
        trainer_args['custom_config'].update({
            ai_platform_trainer_executor.TRAINING_ARGS_KEY:
                trainer_config.ai_platform_training_args,
        })
        trainer_args['custom_executor_spec'] = executor_spec.ExecutorClassSpec(
            ai_platform_trainer_executor.GenericExecutor)
    else:
        trainer_args['custom_executor_spec'] = executor_spec.ExecutorClassSpec(
            trainer_executor.GenericExecutor)

    trainer = Trainer(**trainer_args)

    # Get the latest blessed model for model validation.
    model_resolver = resolver.Resolver(
        strategy_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
        model=Channel(type=Model),
        model_blessing=Channel(type=ModelBlessing),
    ).with_id('latest_blessed_model_resolver')

    # Uses TFMA to compute evaluation statistics over features of a model.
    accuracy_threshold = tfma.MetricThreshold(
        value_threshold=tfma.GenericValueThreshold(
            lower_bound={'value': 0.5},
            upper_bound={'value': 0.995}),
    )

    metrics_specs = tfma.MetricsSpec(
        metrics=[
            tfma.MetricConfig(class_name='BinaryAccuracy',
                              threshold=accuracy_threshold),
            tfma.MetricConfig(class_name='ExampleCount')])

    eval_config = tfma.EvalConfig(
        model_specs=[
            tfma.ModelSpec(label_key='Survived')
        ],
        metrics_specs=[metrics_specs],
        slicing_specs=[
            tfma.SlicingSpec(),
            tfma.SlicingSpec(feature_keys=['Sex']),
            tfma.SlicingSpec(feature_keys=['Age']),
            tfma.SlicingSpec(feature_keys=['Parch']),
        ]
    )

    evaluator = Evaluator(
        examples=examplegen.outputs.examples,
        model=trainer.outputs.model,
        baseline_model=model_resolver.outputs.model,
        eval_config=eval_config
    )

    # Validate model can be loaded and queried in sand-boxed environment
    # mirroring production.
    if local_run:
        serving_config = infra_validator_pb2.ServingSpec(
            tensorflow_serving=infra_validator_pb2.TensorFlowServing(tags=['latest']),
            local_docker=infra_validator_pb2.LocalDockerConfig()  # Running on local docker.
        )
    else:
        serving_config = infra_validator_pb2.ServingSpec(
            tensorflow_serving=infra_validator_pb2.TensorFlowServing(tags=['latest']),
            kubernetes=infra_validator_pb2.KubernetesConfig()  # Running on K8s.
        )

    validation_config = infra_validator_pb2.ValidationSpec(
        max_loading_time_seconds=60,
        num_tries=3,
    )

    request_config = infra_validator_pb2.RequestSpec(
        tensorflow_serving=infra_validator_pb2.TensorFlowServingRequestSpec(),
        num_examples=3,
    )

    infravalidator = InfraValidator(
        model=trainer.outputs.model,
        examples=examplegen.outputs.examples,
        serving_spec=serving_config,
        validation_spec=validation_config,
        request_spec=request_config,
    )

    # Checks whether the model passed the validation steps and pushes the
    # model if checks are passed.
    pusher_args = {
        'model': trainer.outputs.model,
        'model_blessing': evaluator.outputs.blessing,
        'infra_blessing': infravalidator.outputs.blessing
    }

    if local_run:
        # Local runs push to a directory on the filesystem.
        pusher_args['push_destination'] = pusher_pb2.PushDestination(
            filesystem=pusher_pb2.PushDestination.Filesystem(
                base_directory=pusher_config.serving_model_dir))

    if pusher_config.ai_platform_serving_args is not None:
        # Push through the AI Platform pusher executor.
        pusher_args.update({
            'custom_executor_spec':
                executor_spec.ExecutorClassSpec(ai_platform_pusher_executor.Executor),
            'custom_config': {
                ai_platform_pusher_executor.SERVING_ARGS_KEY:
                    pusher_config.ai_platform_serving_args
            },
        })

    pusher = Pusher(**pusher_args)

    components = [
        examplegen,
        statisticsgen,
        schemagen,
        import_schema,
        examplevalidator,
        transform,
        trainer,
        model_resolver,
        evaluator,
        infravalidator,
        pusher
    ]

    # Exactly one hyperparameter source is part of the pipeline.
    if tuner_config.enable_tuning:
        components.append(tuner)
    else:
        components.append(hparams_importer)

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=components,
        enable_cache=enable_cache,
        metadata_connection_config=metadata_connection_config,
        beam_pipeline_args=beam_pipeline_args
    )
def create_pipeline(
    pipeline_name: Text,
    pipeline_root: Text,
    data_root: Text,
    module_file: Text,
    ai_platform_training_args: Dict[Text, Text],
    ai_platform_serving_args: Dict[Text, Text],
    enable_tuning: bool,
    beam_pipeline_args: Optional[List[Text]] = None) -> pipeline.Pipeline:
  """Assembles the Iris flowers pipeline with TFX and Kubeflow Pipeline.

  Args:
    pipeline_name: name of the TFX pipeline being created.
    pipeline_root: root directory of the pipeline. Should be a valid GCS path.
    data_root: uri of the Iris flowers data.
    module_file: uri of the module file shared by the Transform, Tuner and
      Trainer components.
    ai_platform_training_args: Args of CAIP training job. Please refer to
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#Job
      for detailed description. The mapping is deep-copied before use, so the
      caller's object is not modified.
    ai_platform_serving_args: Args of CAIP model deployment. Please refer to
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.models
      for detailed description.
    enable_tuning: If True, a Tuner component is added and its best
      hyperparameters are fed to the Trainer.
    beam_pipeline_args: Optional list of beam pipeline options. Please refer to
      https://cloud.google.com/dataflow/docs/guides/specifying-exec-params#setting-other-cloud-dataflow-pipeline-options.
      When None, DataflowRunner options with a 50GB disk size are used as
      specified in this function. An empty list is passed through unchanged,
      leaving Beam's own defaults in effect.

  Returns:
    A TFX pipeline object.
  """
  input_examples = external_input(data_root)

  # Beam args to run data processing on DataflowRunner; only applied when the
  # caller passed None.
  #
  # TODO(b/151114974): Remove `disk_size_gb` flag after default is increased.
  # TODO(b/151116587): Remove `shuffle_mode` flag after default is changed.
  # TODO(b/156874687): Remove `machine_type` after IP addresses are no longer a
  #                    scaling bottleneck.
  if beam_pipeline_args is None:
    dataflow_options = [
        '--runner=DataflowRunner',
        '--project=' + _project_id,
        '--temp_location=' + os.path.join(_output_bucket, 'tmp'),
        '--region=' + _gcp_region,
    ]
    # Temporary overrides of defaults.
    temporary_overrides = [
        '--disk_size_gb=50',
        '--experiments=shuffle_mode=auto',
        '--machine_type=n1-standard-8',
    ]
    beam_pipeline_args = dataflow_options + temporary_overrides

  # Number of training steps, adjustable at runtime.
  train_steps = data_types.RuntimeParameter(
      name='train_steps', default=100, ptype=int)

  # Number of evaluation steps, adjustable at runtime.
  eval_steps = data_types.RuntimeParameter(
      name='eval_steps', default=50, ptype=int)

  # Brings data into the pipeline or otherwise joins/converts training data.
  example_gen = CsvExampleGen(input=input_examples)
  raw_examples = example_gen.outputs['examples']

  # Computes statistics over data for visualization and example validation.
  statistics_gen = StatisticsGen(examples=raw_examples)
  statistics = statistics_gen.outputs['statistics']

  # Generates schema based on statistics files.
  schema_gen = SchemaGen(statistics=statistics, infer_feature_shape=True)
  schema = schema_gen.outputs['schema']

  # Performs anomaly detection based on statistics and data schema.
  example_validator = ExampleValidator(statistics=statistics, schema=schema)

  # Performs transformations and feature engineering in training and serving.
  transform = Transform(
      examples=raw_examples, schema=schema, module_file=module_file)
  transformed_examples = transform.outputs['transformed_examples']
  transform_graph = transform.outputs['transform_graph']

  # Number of worker machines used in distributed training.
  worker_count = data_types.RuntimeParameter(
      name='worker_count', default=2, ptype=int)

  # Type of worker machines used in distributed training.
  worker_type = data_types.RuntimeParameter(
      name='worker_type', default='standard', ptype=str)

  # Work on a private copy; extend it when distributed training is requested.
  training_args = copy.deepcopy(ai_platform_training_args)
  if FLAGS.distributed_training:
    # You can specify the machine types, the number of replicas for workers
    # and parameter servers.
    # https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#ScaleTier
    training_args.update(
        scaleTier='CUSTOM',
        masterType='large_model',
        workerType=worker_type,
        parameterServerType='standard',
        workerCount=worker_count,
        parameterServerCount=1)

  # Tunes the hyperparameters for model training based on user-provided Python
  # function. Note that once the hyperparameters are tuned, you can drop the
  # Tuner component from pipeline and feed Trainer with tuned hyperparameters.
  tuner = None
  hyperparameters = None
  if enable_tuning:
    # The Tuner component launches 1 AIP Training job for flock management.
    # For example, 3 workers (defined by num_parallel_trials) in the flock
    # management AIP Training job, each runs Tuner.Executor.
    # Then, 3 AIP Training Jobs (defined by training_args) are invoked
    # from each worker in the flock management Job for Trial execution.
    #
    # num_parallel_trials=3 means that 3 search loops are running in
    # parallel. Each tuner may include a distributed training job which can
    # be specified in training_args above (e.g. 1 PS + 2 workers).
    parallel_tune_args = tuner_pb2.TuneArgs(num_parallel_trials=3)
    tuner = Tuner(
        module_file=module_file,
        examples=transformed_examples,
        transform_graph=transform_graph,
        train_args={'num_steps': train_steps},
        eval_args={'num_steps': eval_steps},
        tune_args=parallel_tune_args,
        custom_config={
            # Configures Cloud AI Platform-specific configs. For details, see
            # https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#traininginput.
            ai_platform_trainer_executor.TRAINING_ARGS_KEY: training_args
        })
    hyperparameters = tuner.outputs['best_hyperparameters']

  # Uses user-provided Python function that trains a model.
  #
  # If Tuner is in the pipeline, Trainer takes Tuner's best_hyperparameters
  # artifact as input and can utilize it in the user module code.
  #
  # If there isn't Tuner in the pipeline, either use ImporterNode to import
  # a previous Tuner's output to feed to Trainer, or directly use the tuned
  # hyperparameters in user module code and leave hyperparameters as None.
  #
  # Example of ImporterNode,
  #   hparams_importer = ImporterNode(
  #     instance_name='import_hparams',
  #     source_uri='path/to/best_hyperparameters.txt',
  #     artifact_type=HyperParameters)
  #   ...
  #   hyperparameters = hparams_importer.outputs['result'],
  trainer = Trainer(
      custom_executor_spec=executor_spec.ExecutorClassSpec(
          ai_platform_trainer_executor.GenericExecutor),
      module_file=module_file,
      examples=transformed_examples,
      transform_graph=transform_graph,
      schema=schema,
      hyperparameters=hyperparameters,
      train_args={'num_steps': train_steps},
      eval_args={'num_steps': eval_steps},
      custom_config={
          ai_platform_trainer_executor.TRAINING_ARGS_KEY: training_args
      })
  trained_model = trainer.outputs['model']

  # Get the latest blessed model for model validation.
  model_resolver = ResolverNode(
      instance_name='latest_blessed_model_resolver',
      resolver_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
      model=Channel(type=Model),
      model_blessing=Channel(type=ModelBlessing))

  # Uses TFMA to compute an evaluation statistics over features of a model and
  # perform quality validation of a candidate model (compared to a baseline).
  accuracy_threshold = tfma.MetricThreshold(
      value_threshold=tfma.GenericValueThreshold(
          lower_bound={'value': 0.6}),
      change_threshold=tfma.GenericChangeThreshold(
          direction=tfma.MetricDirection.HIGHER_IS_BETTER,
          absolute={'value': -1e-10}))
  accuracy_metric = tfma.MetricConfig(
      class_name='SparseCategoricalAccuracy', threshold=accuracy_threshold)
  eval_config = tfma.EvalConfig(
      model_specs=[tfma.ModelSpec(label_key='variety')],
      slicing_specs=[tfma.SlicingSpec()],
      metrics_specs=[tfma.MetricsSpec(metrics=[accuracy_metric])])

  # Change threshold will be ignored if there is no baseline (first run).
  evaluator = Evaluator(
      examples=raw_examples,
      model=trained_model,
      baseline_model=model_resolver.outputs['model'],
      eval_config=eval_config)

  # Pushes the model through the AI Platform pusher executor, gated on the
  # Evaluator's blessing.
  pusher = Pusher(
      custom_executor_spec=executor_spec.ExecutorClassSpec(
          ai_platform_pusher_executor.Executor),
      model=trained_model,
      model_blessing=evaluator.outputs['blessing'],
      custom_config={
          ai_platform_pusher_executor.SERVING_ARGS_KEY: ai_platform_serving_args
      })

  components = [
      example_gen,
      statistics_gen,
      schema_gen,
      example_validator,
      transform,
      trainer,
      model_resolver,
      evaluator,
      pusher,
  ]
  # The Tuner, when present, is listed last.
  if tuner is not None:
    components.append(tuner)

  return pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=components,
      enable_cache=True,
      beam_pipeline_args=beam_pipeline_args)
def create_pipeline(
    pipeline_name: Text,
    pipeline_root: Text,
    data_root: Text,
    module_file: Text,
    ai_platform_training_args: Dict[Text, Text],
    ai_platform_serving_args: Dict[Text, Text],
    enable_tuning: bool,
    beam_pipeline_args: List[Text],
) -> pipeline.Pipeline:
  """Implements the penguin pipeline with TFX and Kubeflow Pipeline.

  Args:
    pipeline_name: name of the TFX pipeline being created.
    pipeline_root: root directory of the pipeline. Should be a valid GCS path.
      When tuning is enabled, `<pipeline_root>/trials` is additionally used as
      the working directory for remote tuning trials.
    data_root: uri of the penguin data.
    module_file: uri of the module files used in the Transform, Tuner and
      Trainer components.
    ai_platform_training_args: Args of CAIP training job. Used for the Trainer
      and, when tuning is enabled, for the Tuner as well. Please refer to
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#Job
      for detailed description.
    ai_platform_serving_args: Args of CAIP model deployment. Please refer to
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.models
      for detailed description.
    enable_tuning: If True, the hyperparameter tuning through CloudTuner is
      enabled and the Trainer consumes the Tuner's best hyperparameters.
    beam_pipeline_args: List of beam pipeline options. Please refer to
      https://cloud.google.com/dataflow/docs/guides/specifying-exec-params#setting-other-cloud-dataflow-pipeline-options.

  Returns:
    A TFX pipeline object.
  """
  # Number of training steps, exposed as a runtime parameter.
  train_steps = data_types.RuntimeParameter(
      name='train_steps',
      default=100,
      ptype=int,
  )

  # Number of evaluation steps, exposed as a runtime parameter.
  eval_steps = data_types.RuntimeParameter(
      name='eval_steps',
      default=50,
      ptype=int,
  )

  # Brings data into the pipeline or otherwise joins/converts training data.
  example_gen = CsvExampleGen(input_base=data_root)

  # Computes statistics over data for visualization and example validation.
  statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

  # Generates schema based on statistics files.
  schema_gen = SchemaGen(
      statistics=statistics_gen.outputs['statistics'], infer_feature_shape=True)

  # Performs anomaly detection based on statistics and data schema.
  example_validator = ExampleValidator(
      statistics=statistics_gen.outputs['statistics'],
      schema=schema_gen.outputs['schema'])

  # Performs transformations and feature engineering in training and serving.
  transform = Transform(
      examples=example_gen.outputs['examples'],
      schema=schema_gen.outputs['schema'],
      module_file=module_file)

  # Tunes the hyperparameters for model training based on user-provided Python
  # function. Note that once the hyperparameters are tuned, you can drop the
  # Tuner component from pipeline and feed Trainer with tuned hyperparameters.
  if enable_tuning:
    # The Tuner component launches 1 AIP Training job for flock management of
    # parallel tuning. For example, 2 workers (defined by num_parallel_trials)
    # in the flock management AIP Training job, each runs a search loop for
    # trials as shown below.
    #   Tuner component -> CAIP job X -> CloudTunerA -> tuning trials
    #                                 -> CloudTunerB -> tuning trials
    #
    # Distributed training for each trial depends on the Tuner
    # (kerastuner.BaseTuner) setup in tuner_fn. Currently CloudTuner is single
    # worker training per trial. DistributingCloudTuner (a subclass of
    # CloudTuner) launches remote distributed training job per trial.
    #
    # E.g., single worker training per trial
    #   ... -> CloudTunerA -> single worker training
    #       -> CloudTunerB -> single worker training
    # vs distributed training per trial
    #   ... -> DistributingCloudTunerA -> CAIP job Y -> master,worker1,2,3
    #       -> DistributingCloudTunerB -> CAIP job Z -> master,worker1,2,3
    tuner = Tuner(
        module_file=module_file,
        examples=transform.outputs['transformed_examples'],
        transform_graph=transform.outputs['transform_graph'],
        train_args={'num_steps': train_steps},
        eval_args={'num_steps': eval_steps},
        tune_args=tuner_pb2.TuneArgs(
            # num_parallel_trials=3 means that 3 search loops are
            # running in parallel.
            num_parallel_trials=3),
        custom_config={
            # Note that this TUNING_ARGS_KEY will be used to start the CAIP job
            # for parallel tuning (CAIP job X above).
            #
            # num_parallel_trials will be used to fill/overwrite the
            # workerCount specified by TUNING_ARGS_KEY:
            #   num_parallel_trials = workerCount + 1 (for master)
            ai_platform_tuner_executor.TUNING_ARGS_KEY:
                ai_platform_training_args,
            # This working directory has to be a valid GCS path and will be used
            # to launch remote training job per trial. It is derived from the
            # `pipeline_root` argument so that it follows whatever root the
            # caller passed in rather than a module-level global.
            ai_platform_tuner_executor.REMOTE_TRIALS_WORKING_DIR_KEY:
                os.path.join(pipeline_root, 'trials'),
        })

  # Uses user-provided Python function that trains a model.
  trainer = Trainer(
      custom_executor_spec=executor_spec.ExecutorClassSpec(
          ai_platform_trainer_executor.GenericExecutor),
      module_file=module_file,
      examples=transform.outputs['transformed_examples'],
      transform_graph=transform.outputs['transform_graph'],
      schema=schema_gen.outputs['schema'],
      # If Tuner is in the pipeline, Trainer can take Tuner's output
      # best_hyperparameters artifact as input and utilize it in the user module
      # code.
      #
      # If there isn't Tuner in the pipeline, either use ImporterNode to import
      # a previous Tuner's output to feed to Trainer, or directly use the tuned
      # hyperparameters in user module code and set hyperparameters to None
      # here.
      #
      # Example of ImporterNode,
      #   hparams_importer = ImporterNode(
      #     instance_name='import_hparams',
      #     source_uri='path/to/best_hyperparameters.txt',
      #     artifact_type=HyperParameters)
      #   ...
      #   hyperparameters = hparams_importer.outputs['result'],
      hyperparameters=(tuner.outputs['best_hyperparameters']
                       if enable_tuning else None),
      train_args={'num_steps': train_steps},
      eval_args={'num_steps': eval_steps},
      custom_config={
          ai_platform_trainer_executor.TRAINING_ARGS_KEY:
              ai_platform_training_args
      })

  # Get the latest blessed model for model validation.
  model_resolver = ResolverNode(
      instance_name='latest_blessed_model_resolver',
      resolver_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
      model=Channel(type=Model),
      model_blessing=Channel(type=ModelBlessing))

  # Uses TFMA to compute evaluation statistics over features of a model and
  # perform quality validation of a candidate model (compared to a baseline).
  eval_config = tfma.EvalConfig(
      model_specs=[tfma.ModelSpec(label_key='species')],
      slicing_specs=[tfma.SlicingSpec()],
      metrics_specs=[
          tfma.MetricsSpec(metrics=[
              tfma.MetricConfig(
                  class_name='SparseCategoricalAccuracy',
                  threshold=tfma.MetricThreshold(
                      value_threshold=tfma.GenericValueThreshold(
                          lower_bound={'value': 0.6}),
                      # Change threshold will be ignored if there is no
                      # baseline model resolved from MLMD (first run).
                      change_threshold=tfma.GenericChangeThreshold(
                          direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                          absolute={'value': -1e-10})))
          ])
      ])

  evaluator = Evaluator(
      examples=example_gen.outputs['examples'],
      model=trainer.outputs['model'],
      baseline_model=model_resolver.outputs['model'],
      eval_config=eval_config)

  # Pushes the trained model using the AI Platform pusher executor; the
  # Evaluator's blessing is supplied as the gating input.
  pusher = Pusher(
      custom_executor_spec=executor_spec.ExecutorClassSpec(
          ai_platform_pusher_executor.Executor),
      model=trainer.outputs['model'],
      model_blessing=evaluator.outputs['blessing'],
      custom_config={
          ai_platform_pusher_executor.SERVING_ARGS_KEY: ai_platform_serving_args
      },
  )

  components = [
      example_gen,
      statistics_gen,
      schema_gen,
      example_validator,
      transform,
      trainer,
      model_resolver,
      evaluator,
      pusher,
  ]
  # The Tuner only exists when tuning was requested.
  if enable_tuning:
    components.append(tuner)

  return pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=components,
      enable_cache=True,
      beam_pipeline_args=beam_pipeline_args)
示例#16
0
def create_pipeline(
    pipeline_name: Text,
    pipeline_root: Text,
    serving_model_uri: Text,
    data_root_uri: Union[Text, data_types.RuntimeParameter],
    schema_folder_uri: Union[Text, data_types.RuntimeParameter],
    train_steps: Union[int, data_types.RuntimeParameter],
    eval_steps: Union[int, data_types.RuntimeParameter],
    beam_pipeline_args: List[Text],
    trainer_custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
    trainer_custom_config: Optional[Dict[Text, Any]] = None,
    enable_tuning: Optional[bool] = False,
    enable_cache: Optional[bool] = False,
    metadata_connection_config: Optional[
        metadata_store_pb2.ConnectionConfig] = None
) -> pipeline.Pipeline:
    """Trains and deploys the Keras Covertype Classifier with TFX and AI Platform Pipelines.

    Args:
      pipeline_name: Name of the TFX pipeline being created.
      pipeline_root: Root directory of the pipeline.
      serving_model_uri: Base directory of the filesystem destination the
        Pusher writes the model to.
      data_root_uri: Input location handed to CsvExampleGen.
      schema_folder_uri: Location of the user-provided schema that is imported
        and used by ExampleValidator, Transform and Trainer.
      train_steps: Number of training steps for the Trainer (and the Tuner
        when tuning is enabled).
      eval_steps: Number of evaluation steps for the Trainer (and the Tuner
        when tuning is enabled).
      beam_pipeline_args: List of Beam pipeline options forwarded to the
        pipeline.
      trainer_custom_executor_spec: Optional executor spec passed to the
        Trainer as `custom_executor_spec`.
      trainer_custom_config: Optional `custom_config` passed to the Trainer
        and, when tuning is enabled, to the Tuner.
      enable_tuning: If true, a Tuner is added to the pipeline and its best
        hyperparameters are fed to the Trainer.
      enable_cache: Forwarded to the pipeline as `enable_cache`.
      metadata_connection_config: Optional ML Metadata connection config
        forwarded to the pipeline.

    Returns:
      A TFX pipeline object.
    """

    # Splits the ingested data into training and eval splits, using 4 hash
    # buckets for 'train' and 1 for 'eval'.
    output_config = example_gen_pb2.Output(
        split_config=example_gen_pb2.SplitConfig(splits=[
            example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=4),
            example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1)
        ]))

    # Brings data into the pipeline. The split configuration has to be passed
    # to the component explicitly; building it alone has no effect.
    examplegen = CsvExampleGen(input_base=data_root_uri,
                               output_config=output_config)

    # Computes statistics over data for visualization and example validation.
    statisticsgen = StatisticsGen(examples=examplegen.outputs.examples)

    # Generates schema based on statistics files. Even though we use a
    # user-provided schema, we still want to generate the schema of the newest
    # data for tracking and comparison.
    schemagen = SchemaGen(statistics=statisticsgen.outputs.statistics)

    # Import a user-provided schema.
    import_schema = ImporterNode(
        instance_name='import_user_schema',
        source_uri=schema_folder_uri,
        artifact_type=Schema)

    # Performs anomaly detection based on statistics and data schema.
    examplevalidator = ExampleValidator(
        statistics=statisticsgen.outputs.statistics,
        schema=import_schema.outputs.result)

    # Performs transformations and feature engineering in training and serving.
    transform = Transform(examples=examplegen.outputs.examples,
                          schema=import_schema.outputs.result,
                          module_file=TRANSFORM_MODULE_FILE)

    # Tunes the hyperparameters for model training based on user-provided Python
    # function. Note that once the hyperparameters are tuned, you can drop the
    # Tuner component from pipeline and feed Trainer with tuned hyperparameters.
    if enable_tuning:
        # The Tuner component launches 1 AI Platform Training job for flock management.
        # For example, 3 workers (defined by num_parallel_trials) in the flock
        # management AI Platform Training job, each runs Tuner.Executor.
        tuner = Tuner(
            module_file=TRAIN_MODULE_FILE,
            examples=transform.outputs.transformed_examples,
            transform_graph=transform.outputs.transform_graph,
            train_args={'num_steps': train_steps},
            eval_args={'num_steps': eval_steps},
            tune_args=tuner_pb2.TuneArgs(
                # num_parallel_trials=3 means that 3 search loops are running in parallel.
                num_parallel_trials=3),
            # Reuse the caller-supplied config; this function has no other
            # custom config in scope.
            custom_config=trainer_custom_config)

    # Trains the model using a user provided trainer function.
    trainer = Trainer(
        custom_executor_spec=trainer_custom_executor_spec,
        module_file=TRAIN_MODULE_FILE,
        transformed_examples=transform.outputs.transformed_examples,
        schema=import_schema.outputs.result,
        transform_graph=transform.outputs.transform_graph,
        # The Tuner only exists when tuning was requested.
        hyperparameters=(tuner.outputs.best_hyperparameters
                         if enable_tuning else None),
        train_args={'num_steps': train_steps},
        eval_args={'num_steps': eval_steps},
        custom_config=trainer_custom_config)

    # Get the latest blessed model for model validation.
    resolver = ResolverNode(instance_name='latest_blessed_model_resolver',
                            resolver_class=latest_blessed_model_resolver.
                            LatestBlessedModelResolver,
                            model=Channel(type=Model),
                            model_blessing=Channel(type=ModelBlessing))

    # Uses TFMA to compute evaluation statistics over features of a model.
    # Accuracy is bounded on both sides: at least 0.5 and at most 0.99.
    accuracy_threshold = tfma.MetricThreshold(
        value_threshold=tfma.GenericValueThreshold(lower_bound={'value': 0.5},
                                                   upper_bound={'value':
                                                                0.99}), )

    metrics_specs = tfma.MetricsSpec(metrics=[
        tfma.MetricConfig(class_name='SparseCategoricalAccuracy',
                          threshold=accuracy_threshold),
        tfma.MetricConfig(class_name='ExampleCount')
    ])

    # Metrics are computed overall and sliced by 'Wilderness_Area'.
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(label_key='Cover_Type')],
        metrics_specs=[metrics_specs],
        slicing_specs=[
            tfma.SlicingSpec(),
            tfma.SlicingSpec(feature_keys=['Wilderness_Area'])
        ])

    evaluator = Evaluator(examples=examplegen.outputs.examples,
                          model=trainer.outputs.model,
                          baseline_model=resolver.outputs.model,
                          eval_config=eval_config)

    # Pushes the model to a filesystem destination rooted at
    # `serving_model_uri`; the Evaluator's blessing is supplied as input.
    pusher = Pusher(model=trainer.outputs['model'],
                    model_blessing=evaluator.outputs['blessing'],
                    push_destination=pusher_pb2.PushDestination(
                        filesystem=pusher_pb2.PushDestination.Filesystem(
                            base_directory=serving_model_uri)))

    components = [
        examplegen, statisticsgen, schemagen, import_schema, examplevalidator,
        transform, trainer, resolver, evaluator, pusher
    ]

    if enable_tuning:
        components.append(tuner)

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=components,
        enable_cache=enable_cache,
        beam_pipeline_args=beam_pipeline_args,
        metadata_connection_config=metadata_connection_config)
示例#17
0
def create_pipeline(
        pipeline_name: Text,
        pipeline_root: Text,
        data_root_uri: data_types.RuntimeParameter,
        train_steps: data_types.RuntimeParameter,
        eval_steps: data_types.RuntimeParameter,
        enable_tuning: bool,
        ai_platform_training_args: Dict[Text, Text],
        ai_platform_serving_args: Dict[Text, Text],
        beam_pipeline_args: List[Text],
        enable_cache: Optional[bool] = False) -> pipeline.Pipeline:
    """Builds the Keras Covertype classifier pipeline for Kubeflow on Google Cloud.

    Args:
      pipeline_name: Name of the TFX pipeline being created.
      pipeline_root: Root directory of the pipeline. Should be a valid GCS
        path.
      data_root_uri: Runtime parameter holding the uri of the dataset.
      train_steps: Runtime parameter for the number of training steps.
      eval_steps: Runtime parameter for the number of evaluation steps.
      enable_tuning: If True, a Tuner is added and its best hyperparameters
        are fed to the Trainer.
      ai_platform_training_args: Args of the CAIP training job. See
        https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#Job
      ai_platform_serving_args: Args of the CAIP model deployment. See
        https://cloud.google.com/ml-engine/reference/rest/v1/projects.models
      beam_pipeline_args: List of Beam pipeline options, forwarded to the
        pipeline. See
        https://cloud.google.com/dataflow/docs/guides/specifying-exec-params#setting-other-cloud-dataflow-pipeline-options
      enable_cache: Forwarded to the pipeline as `enable_cache`.

    Returns:
      A TFX pipeline object.
    """
    # --- Ingestion: 4 hash buckets go to 'train', 1 to 'eval'. ---
    train_split = example_gen_pb2.SplitConfig.Split(name='train',
                                                    hash_buckets=4)
    eval_split = example_gen_pb2.SplitConfig.Split(name='eval',
                                                   hash_buckets=1)
    split_output = example_gen_pb2.Output(
        split_config=example_gen_pb2.SplitConfig(
            splits=[train_split, eval_split]))
    example_gen = CsvExampleGen(input_base=data_root_uri,
                                output_config=split_output)

    # --- Statistics over the ingested examples. ---
    statistics_gen = StatisticsGen(examples=example_gen.outputs.examples)

    # --- Schema inferred from the newest data. It is kept for tracking and
    # comparison only; downstream components use the imported schema. ---
    schema_gen = SchemaGen(statistics=statistics_gen.outputs.statistics)

    # --- User-provided schema. ---
    schema_importer = ImporterNode(
        instance_name='import_user_schema',
        source_uri=SCHEMA_FOLDER,
        artifact_type=Schema,
    )

    # --- Anomaly detection against the user-provided schema. ---
    example_validator = ExampleValidator(
        statistics=statistics_gen.outputs.statistics,
        schema=schema_importer.outputs.result,
    )

    # --- Feature engineering shared by training and serving. ---
    transform = Transform(
        examples=example_gen.outputs.examples,
        schema=schema_importer.outputs.result,
        module_file=TRANSFORM_MODULE_FILE,
    )

    # --- Optional hyperparameter tuning. Once tuned, the Tuner can be dropped
    # and the Trainer fed with the tuned hyperparameters directly. ---
    best_hyperparameters = None
    if enable_tuning:
        # Three search loops run in parallel.
        parallel_tune_args = tuner_pb2.TuneArgs(num_parallel_trials=3)
        tuner = Tuner(
            module_file=TRAIN_MODULE_FILE,
            examples=transform.outputs.transformed_examples,
            transform_graph=transform.outputs.transform_graph,
            train_args={'num_steps': train_steps},
            eval_args={'num_steps': eval_steps},
            tune_args=parallel_tune_args,
            custom_config={
                # Cloud AI Platform-specific configuration; see
                # https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#traininginput.
                ai_platform_trainer_executor.TRAINING_ARGS_KEY:
                    ai_platform_training_args,
            },
        )
        best_hyperparameters = tuner.outputs.best_hyperparameters

    # --- Training with the AI Platform generic executor. ---
    trainer = Trainer(
        custom_executor_spec=executor_spec.ExecutorClassSpec(
            ai_platform_trainer_executor.GenericExecutor),
        module_file=TRAIN_MODULE_FILE,
        transformed_examples=transform.outputs.transformed_examples,
        schema=schema_importer.outputs.result,
        transform_graph=transform.outputs.transform_graph,
        hyperparameters=best_hyperparameters,
        train_args={'num_steps': train_steps},
        eval_args={'num_steps': eval_steps},
        custom_config={'ai_platform_training_args': ai_platform_training_args},
    )

    # --- Latest blessed model, used as the evaluation baseline. ---
    model_resolver = ResolverNode(
        instance_name='latest_blessed_model_resolver',
        resolver_class=(
            latest_blessed_model_resolver.LatestBlessedModelResolver),
        model=Channel(type=Model),
        model_blessing=Channel(type=ModelBlessing),
    )

    # --- TFMA evaluation: accuracy must lie within [0.5, 0.99]; metrics are
    # computed overall and sliced by 'Wilderness_Area'. ---
    accuracy_bounds = tfma.GenericValueThreshold(
        lower_bound={'value': 0.5},
        upper_bound={'value': 0.99},
    )
    accuracy_metric = tfma.MetricConfig(
        class_name='SparseCategoricalAccuracy',
        threshold=tfma.MetricThreshold(value_threshold=accuracy_bounds),
    )
    example_count_metric = tfma.MetricConfig(class_name='ExampleCount')
    metrics_spec = tfma.MetricsSpec(
        metrics=[accuracy_metric, example_count_metric])

    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(label_key='Cover_Type')],
        metrics_specs=[metrics_spec],
        slicing_specs=[
            tfma.SlicingSpec(),
            tfma.SlicingSpec(feature_keys=['Wilderness_Area']),
        ],
    )

    evaluator = Evaluator(
        examples=example_gen.outputs.examples,
        model=trainer.outputs.model,
        baseline_model=model_resolver.outputs.model,
        eval_config=eval_config,
    )

    # --- Infra validation: check that the model can be loaded and queried in
    # a sandboxed serving environment before it is pushed. ---
    infra_validator = InfraValidator(
        model=trainer.outputs.model,
        examples=example_gen.outputs.examples,
        serving_spec=infra_validator_pb2.ServingSpec(
            tensorflow_serving=infra_validator_pb2.TensorFlowServing(
                tags=['latest']),
            kubernetes=infra_validator_pb2.KubernetesConfig(),
        ),
        validation_spec=infra_validator_pb2.ValidationSpec(
            max_loading_time_seconds=60,
            num_tries=3,
        ),
        request_spec=infra_validator_pb2.RequestSpec(
            tensorflow_serving=(
                infra_validator_pb2.TensorFlowServingRequestSpec()),
            num_examples=3,
        ),
    )

    # --- Push to CAIP Prediction; both the Evaluator's and the
    # InfraValidator's blessings are supplied as inputs. ---
    pusher = Pusher(
        custom_executor_spec=executor_spec.ExecutorClassSpec(
            ai_platform_pusher_executor.Executor),
        model=trainer.outputs.model,
        model_blessing=evaluator.outputs.blessing,
        infra_blessing=infra_validator.outputs.blessing,
        custom_config={
            ai_platform_pusher_executor.SERVING_ARGS_KEY:
                ai_platform_serving_args,
        },
    )

    # The Tuner, when present, is listed last.
    pipeline_components = [
        example_gen,
        statistics_gen,
        schema_gen,
        schema_importer,
        example_validator,
        transform,
        trainer,
        model_resolver,
        evaluator,
        infra_validator,
        pusher,
    ]
    if enable_tuning:
        pipeline_components.append(tuner)

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=pipeline_components,
        enable_cache=enable_cache,
        beam_pipeline_args=beam_pipeline_args,
    )