Exemplo n.º 1
0
    def testResolverDefinition_ExactlyOneOfStrategyClassOrFunction(self):
        """Resolver() must get exactly one of strategy_class= / function=.

        Both the "neither given" and the "both given" constructions are
        expected to raise ValueError with the same message.
        """
        expected_message = ('Exactly one of strategy_class= or function= '
                            'argument should be given.')

        # Case 1: neither argument is supplied.
        with self.assertRaisesRegex(ValueError, expected_message):
            resolver.Resolver()

        # Case 2: both arguments are supplied together.
        with self.assertRaisesRegex(ValueError, expected_message):
            resolver.Resolver(
                strategy_class=latest_artifact_strategy.LatestArtifactStrategy,
                function=resolver_function.ResolverFunction(
                    lambda value: value))
Exemplo n.º 2
0
    def testBuildLatestArtifactResolverSucceed(self):
        """Builds a step for a LatestArtifactsResolver node.

        The resulting component definition, task spec and deployment config
        are each compared with a golden proto loaded from test data.
        """
        resolver_node = resolver.Resolver(
            instance_name='my_resolver',
            strategy_class=latest_artifacts_resolver.LatestArtifactsResolver,
            model=channel.Channel(type=standard_artifacts.Model),
            examples=channel.Channel(type=standard_artifacts.Examples))
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        builder = step_builder.StepBuilder(
            node=resolver_node,
            deployment_config=deployment_config,
            pipeline_info=data_types.PipelineInfo(
                pipeline_name='test-pipeline',
                pipeline_root='gs://path/to/my/root'),
            component_defs=component_defs)

        # NOTE(review): self._sole presumably returns the single entry of its
        # argument -- confirm against the test base class.
        step_spec = self._sole(builder.build())
        component_def = self._sole(component_defs)

        # (golden file, empty message to parse into, actual value)
        golden_checks = (
            ('expected_latest_artifact_resolver_component.pbtxt',
             pipeline_pb2.ComponentSpec(), component_def),
            ('expected_latest_artifact_resolver_task.pbtxt',
             pipeline_pb2.PipelineTaskSpec(), step_spec),
            ('expected_latest_artifact_resolver_executor.pbtxt',
             pipeline_pb2.PipelineDeploymentConfig(), deployment_config),
        )
        for golden_file, empty_message, actual in golden_checks:
            expected = test_utils.get_proto_from_test_data(
                golden_file, empty_message)
            self.assertProtoEquals(expected, actual)
Exemplo n.º 3
0
  def testIsResolver(self):
    """is_resolver() is True for a Resolver node, False for CsvExampleGen."""
    is_resolver = compiler_utils.is_resolver

    resolver_node = resolver.Resolver(
        strategy_class=latest_blessed_model_resolver.LatestBlessedModelResolver)
    self.assertTrue(is_resolver(resolver_node))

    plain_component = CsvExampleGen(input_base="data_path")
    self.assertFalse(is_resolver(plain_component))
Exemplo n.º 4
0
def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root_1: str,
                     data_root_2: str) -> pipeline.Pipeline:
    """Builds a pipeline that exercises channel.union().

    Args:
      pipeline_name: Name given to the returned pipeline.
      pipeline_root: Root path given to the returned pipeline.
      data_root_1: input_base of the first CsvExampleGen.
      data_root_2: input_base of the second CsvExampleGen.

    Returns:
      A pipeline.Pipeline holding both example gens, the channel-union
      component, the resolver node and StatisticsGen.
    """
    # Two independent example sources.
    first_gen = CsvExampleGen(input_base=data_root_1).with_id('example_gen_1')
    second_gen = CsvExampleGen(input_base=data_root_2).with_id('example_gen_2')

    # Feed the union of both example channels into a single component.
    both_examples = channel.union(
        [first_gen.outputs['examples'], second_gen.outputs['examples']])
    # pylint: disable=no-value-for-parameter
    channel_union = ChannelUnionComponent(
        input_data=both_examples, name='channel_union_input')

    # Resolve over the union of the first gen's examples and the component's
    # output using LatestArtifactStrategy.
    resolver_inputs = channel.union([
        first_gen.outputs['examples'],
        channel_union.outputs['output_data'],
    ])
    latest_resolver = resolver.Resolver(
        strategy_class=latest_artifact_strategy.LatestArtifactStrategy,
        resolved_channels=resolver_inputs)
    latest_resolver = latest_resolver.with_id('latest_artifacts_resolver')

    # Statistics are computed over whatever the resolver produced.
    statistics_gen = StatisticsGen(
        examples=latest_resolver.outputs['resolved_channels'])

    nodes = [
        first_gen, second_gen, channel_union, latest_resolver, statistics_gen
    ]
    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=nodes)
Exemplo n.º 5
0
    def test_functional_resolver_node(self):
        """Compiles a function-based Resolver and inspects its resolver steps.

        The resolver function nests Foo inside Bar, and the compiled node is
        expected to carry two steps in that order: Foo, then Bar.
        """
        # The resolver_function decorator changes the call signature in a way
        # the linter cannot follow, so the related checks are disabled.
        # pylint: disable=invalid-name
        # pylint: disable=unexpected-keyword-arg
        # pylint: disable=no-value-for-parameter

        @resolver_function.resolver_function
        def my_resolver_fn(input_dict):
            return Bar(Foo(input_dict, foo=1), bar='z')

        producer = DummyComponents.A()
        resolver_node = resolver.Resolver(
            function=my_resolver_fn, x=producer.outputs['x']).with_id('R')
        pipeline_ir = self.compile_sync_pipeline([producer, resolver_node])

        node_ir = self.extract_pipeline_node(pipeline_ir, 'R')
        resolver_steps = node_ir.inputs.resolver_config.resolver_steps
        self.assertLen(resolver_steps, 2)

        # (class path suffix, decoded config) for steps[0] and steps[1].
        expected_steps = (('.Foo', {'foo': 1}), ('.Bar', {'bar': 'z'}))
        for step, (path_suffix, config) in zip(resolver_steps, expected_steps):
            self.assertEndsWith(step.class_path, path_suffix)
            self.assertEqual(json.loads(step.config_json), config)
            self.assertEqual(step.input_keys, ['x'])
Exemplo n.º 6
0
    def testBuildLatestBlessedModelResolverSucceed(self):
        """Builds steps for a LatestBlessedModelResolver node.

        The first two built task specs and the deployment config are compared
        with the module-level expected text protos.
        """
        blessed_resolver = resolver.Resolver(
            instance_name='my_resolver2',
            strategy_class=(
                latest_blessed_model_resolver.LatestBlessedModelResolver),
            model=channel.Channel(type=standard_artifacts.Model),
            model_blessing=channel.Channel(
                type=standard_artifacts.ModelBlessing))
        pipeline_info = data_types.PipelineInfo(
            pipeline_name='test-pipeline',
            pipeline_root='gs://path/to/my/root')
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()

        built_specs = step_builder.StepBuilder(
            node=blessed_resolver,
            deployment_config=deployment_config,
            pipeline_info=pipeline_info).build()

        # Task specs are checked positionally: index 0, then index 1.
        expected_task_texts = (
            _EXPECTED_LATEST_BLESSED_MODEL_RESOLVER_1,
            _EXPECTED_LATEST_BLESSED_MODEL_RESOLVER_2,
        )
        for position, expected_text in enumerate(expected_task_texts):
            expected_task = text_format.Parse(
                expected_text, pipeline_pb2.PipelineTaskSpec())
            self.assertProtoEquals(expected_task, built_specs[position])

        expected_deployment = text_format.Parse(
            _EXPECTED_LATEST_BLESSED_MODEL_EXECUTOR,
            pipeline_pb2.PipelineDeploymentConfig())
        self.assertProtoEquals(expected_deployment, deployment_config)
Exemplo n.º 7
0
    def testResolverDefinition_BadStrategyClass(self):
        """Passing an arbitrary class as strategy_class raises TypeError."""

        class NotAStrategy:
            pass

        expected_message = 'strategy_class should be ResolverStrategy'
        with self.assertRaisesRegex(TypeError, expected_message):
            resolver.Resolver(strategy_class=NotAStrategy)
Exemplo n.º 8
0
 def testResolverDefinitionBadArgs(self):
     """An extra kwarg that is not a Channel makes Resolver raise ValueError.

     Uses assertRaisesRegex: the older assertRaisesRegexp spelling is a
     deprecated alias that no longer exists in Python 3.12+.
     """
     with self.assertRaisesRegex(
             ValueError,
             'Expected extra kwarg .* to be of type .*tfx.types.Channel'):
         # `blah` is a plain object, not a Channel, so construction must fail.
         _ = resolver.Resolver(
             strategy_class=latest_artifacts_resolver.LatestArtifactsResolver,
             config={'desired_num_of_artifacts': 5},
             blah=object())
Exemplo n.º 9
0
 def testResolverDefinition_BadChannel(self):
     """A non-Channel extra kwarg makes Resolver raise ValueError."""
     expected_message = (
         'Expected extra kwarg .* to be of type .*tfx.types.Channel')
     with self.assertRaisesRegex(ValueError, expected_message):
         # `not_a_channel` is a plain object rather than a Channel.
         resolver.Resolver(
             strategy_class=latest_artifact_strategy.LatestArtifactStrategy,
             config={'desired_num_of_artifacts': 5},
             not_a_channel=object())
Exemplo n.º 10
0
    def testBuildLatestBlessedModelResolverSucceed(self):
        """Builds steps for a LatestBlessedModelResolver node.

        The built step specs must be keyed by exactly two ids. Each id's
        component definition and task spec, plus the deployment config, are
        compared with golden protos loaded from test data.
        """
        blessed_resolver = resolver.Resolver(
            instance_name='my_resolver2',
            strategy_class=(
                latest_blessed_model_resolver.LatestBlessedModelResolver),
            model=channel.Channel(type=standard_artifacts.Model),
            model_blessing=channel.Channel(
                type=standard_artifacts.ModelBlessing))
        pipeline_info = data_types.PipelineInfo(
            pipeline_name='test-pipeline',
            pipeline_root='gs://path/to/my/root')

        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        built_specs = step_builder.StepBuilder(
            node=blessed_resolver,
            deployment_config=deployment_config,
            pipeline_info=pipeline_info,
            component_defs=component_defs).build()

        blessing_id = 'Resolver.my_resolver2-model-blessing-resolver'
        model_id = 'Resolver.my_resolver2-model-resolver'
        self.assertSameElements(built_specs.keys(), [blessing_id, model_id])

        load_golden = test_utils.get_proto_from_test_data

        # (node id, component golden file, task golden file)
        per_node_goldens = (
            (blessing_id,
             'expected_latest_blessed_model_resolver_component_1.pbtxt',
             'expected_latest_blessed_model_resolver_task_1.pbtxt'),
            (model_id,
             'expected_latest_blessed_model_resolver_component_2.pbtxt',
             'expected_latest_blessed_model_resolver_task_2.pbtxt'),
        )
        for node_id, component_file, task_file in per_node_goldens:
            self.assertProtoEquals(
                load_golden(component_file, pipeline_pb2.ComponentSpec()),
                component_defs[node_id])
            self.assertProtoEquals(
                load_golden(task_file, pipeline_pb2.PipelineTaskSpec()),
                built_specs[node_id])

        self.assertProtoEquals(
            load_golden('expected_latest_blessed_model_resolver_executor.pbtxt',
                        pipeline_pb2.PipelineDeploymentConfig()),
            deployment_config)
Exemplo n.º 11
0
    def testIsResolver(self):
        """is_resolver() accepts Resolver and legacy ResolverNode only.

        A CsvExampleGen component must not be reported as a resolver.
        """
        strategy = latest_blessed_model_resolver.LatestBlessedModelResolver

        current_node = resolver.Resolver(
            instance_name="test_resolver_name", strategy_class=strategy)
        self.assertTrue(compiler_utils.is_resolver(current_node))

        legacy_node = legacy_resolver_node.ResolverNode(
            instance_name="test_resolver_name", resolver_class=strategy)
        self.assertTrue(compiler_utils.is_resolver(legacy_node))

        plain_component = CsvExampleGen(input_base="data_path")
        self.assertFalse(compiler_utils.is_resolver(plain_component))
Exemplo n.º 12
0
def create_test_pipeline():
    """Builds a sample SYNC pipeline that uses ForEach contexts.

    Returns:
      A pipeline.Pipeline named 'foreach' with caching enabled, made of
      CsvExampleGen, StatisticsGen, a resolver node, SchemaGen, Trainer and
      Pusher.
    """
    example_gen = CsvExampleGen(input_base='/data/mydummy_dataset')
    examples = example_gen.outputs['examples']

    # StatisticsGen is declared inside a ForEach over the example channel.
    with for_each.ForEach(examples) as one_example:
        statistics_gen = StatisticsGen(examples=one_example)

    # Resolve the statistics channel with LatestArtifactStrategy.
    latest_stats_resolver = resolver.Resolver(
        statistics=statistics_gen.outputs['statistics'],
        strategy_class=latest_artifact_strategy.LatestArtifactStrategy,
    ).with_id('latest_stats_resolver')

    schema_gen = SchemaGen(
        statistics=latest_stats_resolver.outputs['statistics'])

    # Trainer is declared inside a second ForEach over the same examples.
    with for_each.ForEach(examples) as one_example:
        trainer = Trainer(
            module_file='/src/train.py',
            examples=one_example,
            schema=schema_gen.outputs['schema'],
            train_args=trainer_pb2.TrainArgs(num_steps=2000),
        )

    # Pusher is declared inside a ForEach over the trainer's model channel.
    filesystem_destination = pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory='/models'))
    with for_each.ForEach(trainer.outputs['model']) as one_model:
        pusher = Pusher(
            model=one_model,
            push_destination=filesystem_destination,
        )

    nodes = [
        example_gen,
        statistics_gen,
        latest_stats_resolver,
        schema_gen,
        trainer,
        pusher,
    ]
    return pipeline.Pipeline(
        pipeline_name='foreach',
        pipeline_root='/tfx/pipelines/foreach',
        components=nodes,
        enable_cache=True,
        execution_mode=pipeline.ExecutionMode.SYNC,
    )
Exemplo n.º 13
0
 def testResolverDefinition(self):
     """Checks exec_properties, inputs and outputs of a strategy Resolver."""
     strategy = latest_artifact_strategy.LatestArtifactStrategy
     input_channel = types.Channel(type=standard_artifacts.Examples)

     rnode = resolver.Resolver(
         strategy_class=strategy,
         config={'desired_num_of_artifacts': 5},
         channel_to_resolve=input_channel)

     # The strategy class and its config are exposed as exec properties.
     expected_exec_properties = {
         resolver.RESOLVER_STRATEGY_CLASS: strategy,
         resolver.RESOLVER_CONFIG: {'desired_num_of_artifacts': 5},
     }
     self.assertDictEqual(rnode.exec_properties, expected_exec_properties)

     # The extra kwarg appears as an input and as a same-typed output.
     self.assertEqual(rnode.inputs['channel_to_resolve'], input_channel)
     self.assertEqual(rnode.outputs['channel_to_resolve'].type_name,
                      input_channel.type_name)
Exemplo n.º 14
0
 def testResolverUnionChannel(self):
     """Checks a Resolver whose input is a union of two Examples channels."""
     strategy = latest_artifact_strategy.LatestArtifactStrategy
     first_channel = types.Channel(type=standard_artifacts.Examples)
     second_channel = types.Channel(type=standard_artifacts.Examples)
     union_input = channel.union([first_channel, second_channel])

     rnode = resolver.Resolver(
         strategy_class=strategy,
         config={'desired_num_of_artifacts': 5},
         unioned_channel=union_input)

     # The strategy class and its config are exposed as exec properties.
     expected_exec_properties = {
         resolver.RESOLVER_STRATEGY_CLASS: strategy,
         resolver.RESOLVER_CONFIG: {'desired_num_of_artifacts': 5},
     }
     self.assertDictEqual(rnode.exec_properties, expected_exec_properties)

     # The union channel appears as an input and as a same-typed output.
     self.assertEqual(rnode.inputs['unioned_channel'], union_input)
     self.assertEqual(rnode.outputs['unioned_channel'].type_name,
                      union_input.type_name)
Exemplo n.º 15
0
def create_pipeline() -> pipeline_pb2.Pipeline:
  """Compiles a SYNC test pipeline: trainer -> resolver node -> consumer.

  Returns:
    The result of compiling the pipeline with compiler.Compiler().
  """
  trainer = _trainer().with_id('my_trainer')  # pylint: disable=no-value-for-parameter

  # Resolver over a Model channel, configured for a single artifact.
  model_channel = types.Channel(type=standard_artifacts.Model)
  rnode = resolver.Resolver(
      strategy_class=latest_artifact_strategy.LatestArtifactStrategy,
      config={'desired_num_of_artifacts': 1},
      resolved_model=model_channel)
  rnode = rnode.with_id('my_resolver')
  # The resolver takes no trainer output, so the dependency is added by hand.
  rnode.add_upstream_node(trainer)

  consumer = _consumer(resolved_model=rnode.outputs['resolved_model'])
  consumer = consumer.with_id('my_consumer')

  test_pipeline = pipeline_lib.Pipeline(
      pipeline_name='my_pipeline',
      pipeline_root='/path/to/root',
      components=[trainer, rnode, consumer],
      execution_mode=pipeline_lib.ExecutionMode.SYNC)
  return compiler.Compiler().compile(test_pipeline)
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     trainer_module_file: Text, evaluator_module_file: Text,
                     serving_model_dir: Text, metadata_path: Text,
                     beam_pipeline_args: List[Text]) -> pipeline.Pipeline:
    """Assembles the Penguin pipeline with TFX.

    Args:
      pipeline_name: Name given to the returned pipeline.
      pipeline_root: Root path given to the returned pipeline.
      data_root: input_base of CsvExampleGen.
      trainer_module_file: module_file handed to Trainer.
      evaluator_module_file: module_file handed to Evaluator.
      serving_model_dir: Base directory of the Pusher's filesystem destination.
      metadata_path: Path passed to metadata.sqlite_metadata_connection_config.
      beam_pipeline_args: Forwarded unchanged to the pipeline.

    Returns:
      A pipeline.Pipeline with caching enabled.
    """
    # Data ingestion.
    example_gen = CsvExampleGen(input_base=data_root)
    examples = example_gen.outputs['examples']

    # Statistics over the ingested examples.
    statistics_gen = StatisticsGen(examples=examples)
    statistics = statistics_gen.outputs['statistics']

    # Schema inferred from the statistics.
    schema_gen = SchemaGen(statistics=statistics, infer_feature_shape=True)
    schema = schema_gen.outputs['schema']

    # Validation of the statistics against the schema.
    example_validator = ExampleValidator(statistics=statistics, schema=schema)

    # TODO(humichael): Handle applying transformation component in Milestone 3.

    # Training driven by the user module. EvalArgs carries no num_steps
    # because the scikit-learn model loads and evaluates the entire test set
    # at once.
    trainer = Trainer(
        module_file=trainer_module_file,
        examples=examples,
        schema=schema,
        train_args=trainer_pb2.TrainArgs(num_steps=2000),
        eval_args=trainer_pb2.EvalArgs())
    trained_model = trainer.outputs['model']

    # Latest blessed model, used as the evaluation baseline.
    model_resolver = resolver.Resolver(
        strategy_class=(
            latest_blessed_model_resolver.LatestBlessedModelResolver),
        model=Channel(type=Model),
        model_blessing=Channel(type=ModelBlessing))
    model_resolver = model_resolver.with_id('latest_blessed_model_resolver')

    # TFMA config: Accuracy needs a lower bound of 0.6 and is compared with
    # the baseline through a change threshold.
    accuracy_threshold = tfma.MetricThreshold(
        value_threshold=tfma.GenericValueThreshold(
            lower_bound={'value': 0.6}),
        change_threshold=tfma.GenericChangeThreshold(
            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
            absolute={'value': -1e-10}))
    accuracy_metric = tfma.MetricConfig(
        class_name='Accuracy', threshold=accuracy_threshold)
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(label_key='species')],
        slicing_specs=[tfma.SlicingSpec()],
        metrics_specs=[tfma.MetricsSpec(metrics=[accuracy_metric])])
    evaluator = Evaluator(
        module_file=evaluator_module_file,
        examples=examples,
        model=trained_model,
        baseline_model=model_resolver.outputs['model'],
        eval_config=eval_config)

    # Pusher gets the evaluator's blessing and a filesystem destination.
    serving_destination = pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory=serving_model_dir))
    pusher = Pusher(
        model=trained_model,
        model_blessing=evaluator.outputs['blessing'],
        push_destination=serving_destination)

    nodes = [
        example_gen,
        statistics_gen,
        schema_gen,
        example_validator,
        trainer,
        model_resolver,
        evaluator,
        pusher,
    ]
    metadata_config = metadata.sqlite_metadata_connection_config(
        metadata_path)
    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=nodes,
        enable_cache=True,
        metadata_connection_config=metadata_config,
        beam_pipeline_args=beam_pipeline_args,
    )
Exemplo n.º 17
0
def create_pipeline_components(
    pipeline_root: Text,
    transform_module: Text,
    trainer_module: Text,
    bigquery_query: Text = '',
    csv_input_location: Text = '',
) -> List[base_node.BaseNode]:
    """Creates components for a simple Chicago Taxi TFX pipeline for testing.

    Args:
      pipeline_root: The root of the pipeline output.
      transform_module: The location of the transform module file.
      trainer_module: The location of the trainer module file.
      bigquery_query: The query to get input data from BigQuery. If not
        empty, BigQueryExampleGen will be used.
      csv_input_location: The location of the input data directory.

    Returns:
      A list of TFX components that constitutes an end-to-end test pipeline.

    Raises:
      ValueError: If both or neither of bigquery_query and
        csv_input_location are provided.
    """
    # Exactly one source must be set: both-empty and both-set are rejected.
    if bool(bigquery_query) == bool(csv_input_location):
        # The two literals are concatenated into ONE message. A comma between
        # them would give the exception a 2-tuple of args.
        raise ValueError(
            'Exactly one example gen is expected. '
            'Please provide either bigquery_query or csv_input_location.')

    if bigquery_query:
        example_gen = big_query_example_gen_component.BigQueryExampleGen(
            query=bigquery_query)
    else:
        example_gen = components.CsvExampleGen(input_base=csv_input_location)

    statistics_gen = components.StatisticsGen(
        examples=example_gen.outputs['examples'])
    schema_gen = components.SchemaGen(
        statistics=statistics_gen.outputs['statistics'],
        infer_feature_shape=False)
    example_validator = components.ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_gen.outputs['schema'])
    transform = components.Transform(examples=example_gen.outputs['examples'],
                                     schema=schema_gen.outputs['schema'],
                                     module_file=transform_module)
    # Resolve a previous model; it is handed to the trainer as base_model.
    latest_model_resolver = resolver.Resolver(
        strategy_class=latest_artifacts_resolver.LatestArtifactsResolver,
        model=channel.Channel(type=standard_artifacts.Model)).with_id(
            'Resolver.latest_model_resolver')
    trainer = components.Trainer(
        custom_executor_spec=executor_spec.ExecutorClassSpec(Executor),
        transformed_examples=transform.outputs['transformed_examples'],
        schema=schema_gen.outputs['schema'],
        base_model=latest_model_resolver.outputs['model'],
        transform_graph=transform.outputs['transform_graph'],
        train_args=trainer_pb2.TrainArgs(num_steps=10),
        eval_args=trainer_pb2.EvalArgs(num_steps=5),
        module_file=trainer_module,
    )
    # Get the latest blessed model for model validation.
    model_resolver = resolver.Resolver(
        strategy_class=latest_blessed_model_resolver.
        LatestBlessedModelResolver,
        model=channel.Channel(type=standard_artifacts.Model),
        model_blessing=channel.Channel(
            type=standard_artifacts.ModelBlessing)).with_id(
                'Resolver.latest_blessed_model_resolver')
    # Set the TFMA config for Model Evaluation and Validation.
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(signature_name='eval')],
        metrics_specs=[
            tfma.MetricsSpec(
                metrics=[tfma.MetricConfig(class_name='ExampleCount')],
                thresholds={
                    'binary_accuracy':
                    tfma.MetricThreshold(
                        value_threshold=tfma.GenericValueThreshold(
                            lower_bound={'value': 0.5}),
                        change_threshold=tfma.GenericChangeThreshold(
                            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                            absolute={'value': -1e-10}))
                })
        ],
        slicing_specs=[
            tfma.SlicingSpec(),
            tfma.SlicingSpec(feature_keys=['trip_start_hour'])
        ])
    evaluator = components.Evaluator(
        examples=example_gen.outputs['examples'],
        model=trainer.outputs['model'],
        baseline_model=model_resolver.outputs['model'],
        eval_config=eval_config)

    # Models are pushed to <pipeline_root>/model_serving on the filesystem.
    pusher = components.Pusher(
        model=trainer.outputs['model'],
        model_blessing=evaluator.outputs['blessing'],
        push_destination=pusher_pb2.PushDestination(
            filesystem=pusher_pb2.PushDestination.Filesystem(
                base_directory=os.path.join(pipeline_root, 'model_serving'))))

    return [
        example_gen, statistics_gen, schema_gen, example_validator, transform,
        latest_model_resolver, trainer, model_resolver, evaluator, pusher
    ]
Exemplo n.º 18
0
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text, serving_model_dir: Text,
                     beam_pipeline_args: List[Text]) -> pipeline.Pipeline:
    """Assembles the Chicago taxi pipeline with TFX and Kubeflow Pipelines.

    Args:
      pipeline_name: Name given to the returned pipeline.
      pipeline_root: Root path given to the returned pipeline.
      data_root: input_base of CsvExampleGen.
      module_file: User module handed to both Transform and Trainer.
      serving_model_dir: Base directory of the Pusher's filesystem destination.
      beam_pipeline_args: Forwarded unchanged to the pipeline.

    Returns:
      A pipeline.Pipeline containing all ten nodes built here.
    """
    # Data ingestion.
    example_gen = CsvExampleGen(input_base=data_root)
    examples = example_gen.outputs['examples']

    # Statistics over the ingested examples.
    statistics_gen = StatisticsGen(examples=examples)
    statistics = statistics_gen.outputs['statistics']

    # Schema inferred from the statistics.
    schema_gen = SchemaGen(statistics=statistics, infer_feature_shape=False)
    schema = schema_gen.outputs['schema']

    # Validation of the statistics against the schema.
    example_validator = ExampleValidator(statistics=statistics, schema=schema)

    # Transformations and feature engineering from the user module.
    transform = Transform(
        examples=examples, schema=schema, module_file=module_file)

    # Training from the user module, run through the custom Executor class.
    trainer = Trainer(
        module_file=module_file,
        custom_executor_spec=executor_spec.ExecutorClassSpec(Executor),
        transformed_examples=transform.outputs['transformed_examples'],
        schema=schema,
        transform_graph=transform.outputs['transform_graph'],
        train_args=trainer_pb2.TrainArgs(num_steps=10000),
        eval_args=trainer_pb2.EvalArgs(num_steps=5000),
    )
    trained_model = trainer.outputs['model']

    # Latest blessed model, used as the evaluation baseline.
    model_resolver = resolver.Resolver(
        strategy_class=(
            latest_blessed_model_resolver.LatestBlessedModelResolver),
        model=Channel(type=Model),
        model_blessing=Channel(type=ModelBlessing))
    model_resolver = model_resolver.with_id('latest_blessed_model_resolver')

    # TFMA config: evaluation statistics over model features plus quality
    # validation of the candidate model against the baseline.
    accuracy_threshold = tfma.config.MetricThreshold(
        value_threshold=tfma.GenericValueThreshold(
            lower_bound={'value': 0.6}),
        # Change threshold will be ignored if there is no baseline model
        # resolved from MLMD (first run).
        change_threshold=tfma.GenericChangeThreshold(
            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
            absolute={'value': -1e-10}))
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(signature_name='eval')],
        slicing_specs=[
            tfma.SlicingSpec(),
            tfma.SlicingSpec(feature_keys=['trip_start_hour']),
        ],
        metrics_specs=[
            tfma.MetricsSpec(thresholds={'accuracy': accuracy_threshold}),
        ])
    evaluator = Evaluator(
        examples=examples,
        model=trained_model,
        baseline_model=model_resolver.outputs['model'],
        eval_config=eval_config)

    # Infra validation keeps an unservable model from being pushed. To use
    # InfraValidator, the persistent volume and claim used by the pipeline
    # should be in ReadWriteMany access mode.
    serving_spec = infra_validator_pb2.ServingSpec(
        tensorflow_serving=infra_validator_pb2.TensorFlowServing(
            tags=['latest']),
        kubernetes=infra_validator_pb2.KubernetesConfig())
    request_spec = infra_validator_pb2.RequestSpec(
        tensorflow_serving=(
            infra_validator_pb2.TensorFlowServingRequestSpec()))
    infra_validator = InfraValidator(
        model=trained_model,
        examples=examples,
        serving_spec=serving_spec,
        request_spec=request_spec)

    # Pusher gets both blessings and a filesystem destination.
    serving_destination = pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory=serving_model_dir))
    pusher = Pusher(
        model=trained_model,
        model_blessing=evaluator.outputs['blessing'],
        infra_blessing=infra_validator.outputs['blessing'],
        push_destination=serving_destination)

    nodes = [
        example_gen,
        statistics_gen,
        schema_gen,
        example_validator,
        transform,
        trainer,
        model_resolver,
        evaluator,
        infra_validator,
        pusher,
    ]
    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=nodes,
        beam_pipeline_args=beam_pipeline_args)
Exemplo n.º 19
0
 def testResolverDefinition_BadFunction(self):
     """A function= value that is not a ResolverFunction raises TypeError."""
     expected_message = 'function should be ResolverFunction'
     with self.assertRaisesRegex(TypeError, expected_message):
         # An int stands in for "anything that is not a ResolverFunction".
         resolver.Resolver(function=42)
Exemplo n.º 20
0
def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root: str,
                     module_file: str, serving_model_dir_lite: str,
                     metadata_path: str, labels_path: str,
                     beam_pipeline_args: List[str]) -> pipeline.Pipeline:
  """Builds the TFX pipeline for CIFAR10 image classification.

  Args:
    pipeline_name: Name assigned to the returned pipeline.
    pipeline_root: Root location forwarded to `pipeline.Pipeline`.
    data_root: Input base directory read by ImportExampleGen.
    module_file: User module file given to both Transform and Trainer.
    serving_model_dir_lite: Filesystem base directory used by Pusher.
    metadata_path: Path used to build the SQLite metadata connection config.
    labels_path: Forwarded to Trainer as `custom_config['labels_path']`.
    beam_pipeline_args: Beam arguments forwarded to `pipeline.Pipeline`.

  Returns:
    A cache-enabled `pipeline.Pipeline` made of ExampleGen, StatisticsGen,
    SchemaGen, ExampleValidator, Transform, Trainer, the latest-blessed-model
    Resolver, Evaluator and Pusher.
  """
  # The dataset comes with pre-defined splits. To train on the whole CIFAR-10
  # dataset, switch the patterns to train_whole/* and test_whole/*.
  train_split = example_gen_pb2.Input.Split(name='train', pattern='train/*')
  eval_split = example_gen_pb2.Input.Split(name='eval', pattern='test/*')
  input_config = example_gen_pb2.Input(splits=[train_split, eval_split])

  # Ingest the data.
  example_gen = ImportExampleGen(
      input_base=data_root, input_config=input_config)
  raw_examples = example_gen.outputs['examples']

  # Statistics feed both visualization and example validation.
  statistics_gen = StatisticsGen(examples=raw_examples)
  statistics = statistics_gen.outputs['statistics']

  # Schema is inferred from the statistics.
  schema_gen = SchemaGen(statistics=statistics, infer_feature_shape=True)
  schema = schema_gen.outputs['schema']

  # Anomaly detection against the statistics and the schema.
  example_validator = ExampleValidator(statistics=statistics, schema=schema)

  # Feature engineering shared between training and serving.
  transform = Transform(
      examples=raw_examples, schema=schema, module_file=module_file)
  transformed_examples = transform.outputs['transformed_examples']

  # Training runs the user-provided function from `module_file`.
  # The step counts below fit the tiny dataset shipped in the data folder
  # (128 train / 128 test samples): 160 train steps are 40 epochs and 4 eval
  # steps are 1 epoch. For the whole dataset use 18744 train steps (24 epochs
  # on the whole train set) and 156 eval steps (1 epoch on the whole test set).
  trainer = Trainer(
      module_file=module_file,
      examples=transformed_examples,
      transform_graph=transform.outputs['transform_graph'],
      schema=schema,
      train_args=trainer_pb2.TrainArgs(num_steps=160),
      eval_args=trainer_pb2.EvalArgs(num_steps=4),
      custom_config={'labels_path': labels_path})
  trained_model = trainer.outputs['model']

  # Resolve the latest blessed model, used as the evaluation baseline.
  resolver_node = resolver.Resolver(
      strategy_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
      model=Channel(type=Model),
      model_blessing=Channel(type=ModelBlessing))
  model_resolver = resolver_node.with_id('latest_blessed_model_resolver')

  # TFMA configuration: the candidate must reach the absolute accuracy bound
  # and is also compared with the baseline. The change threshold is ignored
  # when no baseline model is resolved from MLMD (first run).
  accuracy_threshold = tfma.MetricThreshold(
      value_threshold=tfma.GenericValueThreshold(lower_bound={'value': 0.55}),
      change_threshold=tfma.GenericChangeThreshold(
          direction=tfma.MetricDirection.HIGHER_IS_BETTER,
          absolute={'value': -1e-3}))
  accuracy_metric = tfma.MetricConfig(
      class_name='SparseCategoricalAccuracy', threshold=accuracy_threshold)
  eval_config = tfma.EvalConfig(
      model_specs=[tfma.ModelSpec(label_key='label_xf', model_type='tf_lite')],
      slicing_specs=[tfma.SlicingSpec()],
      metrics_specs=[tfma.MetricsSpec(metrics=[accuracy_metric])])

  # Evaluation uses the materialized examples produced by Transform because
  # 1. the decoding_png function currently performed within Transform is not
  #    compatible with TFLite, and
  # 2. MLKit requires deserialized (float32) tensor image inputs.
  # For deployment, the same logic that is performed within Transform must be
  # reproduced client-side.
  evaluator = Evaluator(
      examples=transformed_examples,
      model=trained_model,
      baseline_model=model_resolver.outputs['model'],
      eval_config=eval_config)

  # Push to a file destination, gated on the Evaluator's blessing.
  push_destination = pusher_pb2.PushDestination(
      filesystem=pusher_pb2.PushDestination.Filesystem(
          base_directory=serving_model_dir_lite))
  pusher = Pusher(
      model=trained_model,
      model_blessing=evaluator.outputs['blessing'],
      push_destination=push_destination)

  return pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=[
          example_gen,
          statistics_gen,
          schema_gen,
          example_validator,
          transform,
          trainer,
          model_resolver,
          evaluator,
          pusher,
      ],
      enable_cache=True,
      metadata_connection_config=metadata.sqlite_metadata_connection_config(
          metadata_path),
      beam_pipeline_args=beam_pipeline_args)
Exemplo n.º 21
0
def create_pipeline(
    pipeline_name: Text,
    pipeline_root: Text,
    data_path: Text,
    # TODO(step 7): (Optional) Uncomment here to use BigQuery as a data source.
    # query: Text,
    preprocessing_fn: Text,
    run_fn: Text,
    train_args: trainer_pb2.TrainArgs,
    eval_args: trainer_pb2.EvalArgs,
    eval_accuracy_threshold: float,
    serving_model_dir: Text,
    metadata_connection_config: Optional[
        metadata_store_pb2.ConnectionConfig] = None,
    beam_pipeline_args: Optional[List[Text]] = None,
    ai_platform_training_args: Optional[Dict[Text, Text]] = None,
    ai_platform_serving_args: Optional[Dict[Text, Any]] = None,
) -> pipeline.Pipeline:
    """Implements the chicago taxi pipeline with TFX.

    This is a step-by-step template: every component is constructed, but as
    written only `example_gen` is appended to the returned pipeline. The other
    `components.append(...)` calls are commented out behind TODO(step N)
    markers.

    Args:
      pipeline_name: Name forwarded to `pipeline.Pipeline`.
      pipeline_root: Root location forwarded to `pipeline.Pipeline`.
      data_path: Input base directory read by CsvExampleGen.
      preprocessing_fn: Passed to Transform as `preprocessing_fn`.
      run_fn: Passed to Trainer as `run_fn`.
      train_args: Training arguments passed to Trainer.
      eval_args: Evaluation arguments passed to Trainer.
      eval_accuracy_threshold: Lower bound of the BinaryAccuracy value
        threshold in the TFMA eval config.
      serving_model_dir: Filesystem base directory used by Pusher.
      metadata_connection_config: Forwarded to `pipeline.Pipeline`.
      beam_pipeline_args: Forwarded to `pipeline.Pipeline`.
      ai_platform_training_args: When not None, Trainer is given the AI
        Platform trainer `GenericExecutor` as custom executor and these
        arguments under `TRAINING_ARGS_KEY` in its custom config.
      ai_platform_serving_args: When not None, Pusher is given the AI Platform
        pusher `Executor` as custom executor and these arguments under
        `SERVING_ARGS_KEY` in its custom config.

    Returns:
      A `pipeline.Pipeline` built from the `components` list.
    """

    # Only components appended to this list become part of the pipeline.
    components = []

    # Brings data into the pipeline or otherwise joins/converts training data.
    example_gen = CsvExampleGen(input_base=data_path)
    # TODO(step 7): (Optional) Uncomment here to use BigQuery as a data source.
    # example_gen = big_query_example_gen_component.BigQueryExampleGen(
    #     query=query)
    components.append(example_gen)

    # Computes statistics over data for visualization and example validation.
    statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])
    # TODO(step 5): Uncomment here to add StatisticsGen to the pipeline.
    # components.append(statistics_gen)

    # Generates schema based on statistics files.
    schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'],
                           infer_feature_shape=True)
    # TODO(step 5): Uncomment here to add SchemaGen to the pipeline.
    # components.append(schema_gen)

    # Performs anomaly detection based on statistics and data schema.
    example_validator = ExampleValidator(  # pylint: disable=unused-variable
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_gen.outputs['schema'])
    # TODO(step 5): Uncomment here to add ExampleValidator to the pipeline.
    # components.append(example_validator)

    # Performs transformations and feature engineering in training and serving.
    transform = Transform(examples=example_gen.outputs['examples'],
                          schema=schema_gen.outputs['schema'],
                          preprocessing_fn=preprocessing_fn)
    # TODO(step 6): Uncomment here to add Transform to the pipeline.
    # components.append(transform)

    # Uses user-provided Python function that implements a model using TF-Learn.
    # The keyword arguments are collected in a dict so that the AI Platform
    # executor settings can be added conditionally below.
    trainer_args = {
        'run_fn': run_fn,
        'transformed_examples': transform.outputs['transformed_examples'],
        'schema': schema_gen.outputs['schema'],
        'transform_graph': transform.outputs['transform_graph'],
        'train_args': train_args,
        'eval_args': eval_args,
    }
    if ai_platform_training_args is not None:
        trainer_args.update({
            'custom_executor_spec':
            executor_spec.ExecutorClassSpec(
                ai_platform_trainer_executor.GenericExecutor),
            'custom_config': {
                ai_platform_trainer_executor.TRAINING_ARGS_KEY:
                ai_platform_training_args,
            }
        })
    trainer = Trainer(**trainer_args)
    # TODO(step 6): Uncomment here to add Trainer to the pipeline.
    # components.append(trainer)

    # Get the latest blessed model for model validation.
    model_resolver = resolver.Resolver(
        strategy_class=latest_blessed_model_resolver.
        LatestBlessedModelResolver,
        model=Channel(type=Model),
        model_blessing=Channel(
            type=ModelBlessing)).with_id('latest_blessed_model_resolver')
    # TODO(step 6): Uncomment here to add ResolverNode to the pipeline.
    # components.append(model_resolver)

    # Uses TFMA to compute evaluation statistics over features of a model and
    # perform quality validation of a candidate model (compared to a baseline).
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(label_key='big_tipper')],
        slicing_specs=[tfma.SlicingSpec()],
        metrics_specs=[
            tfma.MetricsSpec(metrics=[
                tfma.MetricConfig(
                    class_name='BinaryAccuracy',
                    threshold=tfma.MetricThreshold(
                        value_threshold=tfma.GenericValueThreshold(
                            lower_bound={'value': eval_accuracy_threshold}),
                        change_threshold=tfma.GenericChangeThreshold(
                            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                            absolute={'value': -1e-10})))
            ])
        ])
    evaluator = Evaluator(
        examples=example_gen.outputs['examples'],
        model=trainer.outputs['model'],
        baseline_model=model_resolver.outputs['model'],
        # Change threshold will be ignored if there is no baseline (first run).
        eval_config=eval_config)
    # TODO(step 6): Uncomment here to add Evaluator to the pipeline.
    # components.append(evaluator)

    # Checks whether the model passed the validation steps and pushes the model
    # to a file destination if check passed.
    # As with the Trainer, the keyword arguments are collected in a dict so the
    # AI Platform executor settings can be added conditionally.
    pusher_args = {
        'model':
        trainer.outputs['model'],
        'model_blessing':
        evaluator.outputs['blessing'],
        'push_destination':
        pusher_pb2.PushDestination(
            filesystem=pusher_pb2.PushDestination.Filesystem(
                base_directory=serving_model_dir)),
    }
    if ai_platform_serving_args is not None:
        pusher_args.update({
            'custom_executor_spec':
            executor_spec.ExecutorClassSpec(
                ai_platform_pusher_executor.Executor),
            'custom_config': {
                ai_platform_pusher_executor.SERVING_ARGS_KEY:
                ai_platform_serving_args
            },
        })
    pusher = Pusher(**pusher_args)  # pylint: disable=unused-variable
    # TODO(step 6): Uncomment here to add Pusher to the pipeline.
    # components.append(pusher)

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=components,
        # Change this value to control caching of execution results. Default value
        # is `False`.
        # enable_cache=True,
        metadata_connection_config=metadata_connection_config,
        beam_pipeline_args=beam_pipeline_args,
    )
Exemplo n.º 22
0
def create_test_pipeline():
    """Assembles a slightly modified Iris example pipeline for testing.

    All locations are fixed relative placeholder paths built from
    "iris_root" and "tfx_root".

    Returns:
      A `pipeline.Pipeline` named "iris" with caching disabled, ASYNC
      execution mode, and a pipeline-level platform config.
    """
    pipeline_name = "iris"
    iris_root = "iris_root"
    tfx_root = "tfx_root"
    serving_model_dir = os.path.join(iris_root, "serving_model", pipeline_name)
    data_path = os.path.join(tfx_root, "data_path")
    pipeline_root = os.path.join(tfx_root, "pipelines", pipeline_name)

    example_gen = CsvExampleGen(input_base=data_path)
    examples = example_gen.outputs["examples"]

    statistics_gen = StatisticsGen(examples=examples)
    statistics = statistics_gen.outputs["statistics"]

    # Importer with both standard and custom properties; it is listed in the
    # pipeline but none of the other components consume its output.
    importer_node = ImporterNode(
        source_uri="m/y/u/r/i",
        properties={"split_names": "['train', 'eval']"},
        custom_properties={
            "int_custom_property": 42,
            "str_custom_property": "42",
        },
        artifact_type=standard_artifacts.Examples)
    importer = importer_node.with_id("my_importer")

    schema_gen = SchemaGen(statistics=statistics, infer_feature_shape=True)
    schema = schema_gen.outputs["schema"]

    example_validator = ExampleValidator(statistics=statistics, schema=schema)

    # The module file is a RuntimeParameter so that RuntimeParameter handling
    # in the compiler gets exercised.
    module_file = data_types.RuntimeParameter(
        name="module_file",
        default=os.path.join(iris_root, "iris_utils.py"),
        ptype=str)
    trainer_node = Trainer(
        module_file=module_file,
        custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
        examples=examples,
        schema=schema,
        train_args=trainer_pb2.TrainArgs(num_steps=2000),
        eval_args=trainer_pb2.EvalArgs(num_steps=5))
    # Using `TrainArgs` as a platform config is not sensible practice; it is
    # done here for testing purposes only.
    trainer = trainer_node.with_platform_config(
        config=trainer_pb2.TrainArgs(num_steps=2000))
    trained_model = trainer.outputs["model"]

    # producer_component_id="Evaluator" cannot be set on the model_blessing
    # channel: it raises a "producer component should have already been
    # compiled" error.
    resolver_node = resolver.Resolver(
        strategy_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
        baseline_model=Channel(type=standard_artifacts.Model,
                               producer_component_id="Trainer"),
        model_blessing=Channel(type=standard_artifacts.ModelBlessing))
    model_resolver = resolver_node.with_id("latest_blessed_model_resolver")

    accuracy_threshold = tfma.config.MetricThreshold(
        value_threshold=tfma.GenericValueThreshold(lower_bound={"value": 0.6}),
        change_threshold=tfma.GenericChangeThreshold(
            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
            absolute={"value": -1e-10}))
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(signature_name="eval")],
        slicing_specs=[tfma.SlicingSpec()],
        metrics_specs=[
            tfma.MetricsSpec(
                thresholds={"sparse_categorical_accuracy": accuracy_threshold})
        ])
    evaluator = Evaluator(
        examples=examples,
        model=trained_model,
        baseline_model=model_resolver.outputs["baseline_model"],
        eval_config=eval_config)

    push_destination = pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory=serving_model_dir))
    pusher = Pusher(
        model=trained_model,
        model_blessing=evaluator.outputs["blessing"],
        push_destination=push_destination)

    pipeline_components = [
        example_gen,
        statistics_gen,
        importer,
        schema_gen,
        example_validator,
        trainer,
        model_resolver,
        evaluator,
        pusher,
    ]
    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=pipeline_components,
        enable_cache=False,
        beam_pipeline_args=["--my_testing_beam_pipeline_args=bar"],
        # As with the Trainer above, `TrainArgs` as platform config is for
        # testing purposes only.
        platform_config=trainer_pb2.TrainArgs(num_steps=2000),
        execution_mode=pipeline.ExecutionMode.ASYNC)
Exemplo n.º 23
0
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text, serving_model_dir: Text,
                     metadata_path: Text,
                     beam_pipeline_args: List[Text]) -> pipeline.Pipeline:
  """Builds the Chicago taxi pipeline with TFX.

  Args:
    pipeline_name: Name assigned to the returned pipeline.
    pipeline_root: Root location forwarded to `pipeline.Pipeline`.
    data_root: Input base directory read by CsvExampleGen.
    module_file: User module file given to both Transform and Trainer.
    serving_model_dir: Filesystem base directory used by Pusher.
    metadata_path: Path used to build the SQLite metadata connection config.
    beam_pipeline_args: Beam arguments forwarded to `pipeline.Pipeline`.

  Returns:
    A cache-enabled `pipeline.Pipeline` made of ExampleGen, StatisticsGen,
    SchemaGen, ExampleValidator, Transform, Trainer, the latest-blessed-model
    Resolver, Evaluator and Pusher.
  """
  # Ingest the CSV data.
  example_gen = CsvExampleGen(input_base=data_root)
  raw_examples = example_gen.outputs['examples']

  # Statistics feed both visualization and example validation.
  statistics_gen = StatisticsGen(examples=raw_examples)
  statistics = statistics_gen.outputs['statistics']

  # Schema is inferred from the statistics.
  schema_gen = SchemaGen(statistics=statistics, infer_feature_shape=False)
  schema = schema_gen.outputs['schema']

  # Anomaly detection against the statistics and the schema.
  example_validator = ExampleValidator(statistics=statistics, schema=schema)

  # Feature engineering shared between training and serving.
  transform = Transform(
      examples=raw_examples, schema=schema, module_file=module_file)

  # Training runs the user-provided model code from `module_file`.
  trainer = Trainer(
      module_file=module_file,
      custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
      examples=transform.outputs['transformed_examples'],
      transform_graph=transform.outputs['transform_graph'],
      schema=schema,
      train_args=trainer_pb2.TrainArgs(num_steps=10000),
      eval_args=trainer_pb2.EvalArgs(num_steps=5000))
  trained_model = trainer.outputs['model']

  # Resolve the latest blessed model, used as the evaluation baseline.
  resolver_node = resolver.Resolver(
      strategy_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
      model=Channel(type=Model),
      model_blessing=Channel(type=ModelBlessing))
  model_resolver = resolver_node.with_id('latest_blessed_model_resolver')

  # TFMA configuration: an absolute accuracy bound plus a comparison with the
  # baseline. The change threshold is ignored when there is no baseline
  # (first run).
  accuracy_threshold = tfma.MetricThreshold(
      value_threshold=tfma.GenericValueThreshold(lower_bound={'value': 0.6}),
      change_threshold=tfma.GenericChangeThreshold(
          direction=tfma.MetricDirection.HIGHER_IS_BETTER,
          absolute={'value': -1e-10}))
  accuracy_metric = tfma.MetricConfig(
      class_name='BinaryAccuracy', threshold=accuracy_threshold)
  eval_config = tfma.EvalConfig(
      model_specs=[tfma.ModelSpec(label_key='tips')],
      slicing_specs=[tfma.SlicingSpec()],
      metrics_specs=[tfma.MetricsSpec(metrics=[accuracy_metric])])

  evaluator = Evaluator(
      examples=raw_examples,
      model=trained_model,
      baseline_model=model_resolver.outputs['model'],
      eval_config=eval_config)

  # Push to a file destination, gated on the Evaluator's blessing.
  push_destination = pusher_pb2.PushDestination(
      filesystem=pusher_pb2.PushDestination.Filesystem(
          base_directory=serving_model_dir))
  pusher = Pusher(
      model=trained_model,
      model_blessing=evaluator.outputs['blessing'],
      push_destination=push_destination)

  pipeline_components = [
      example_gen,
      statistics_gen,
      schema_gen,
      example_validator,
      transform,
      trainer,
      model_resolver,
      evaluator,
      pusher,
  ]
  return pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=pipeline_components,
      enable_cache=True,
      metadata_connection_config=metadata.sqlite_metadata_connection_config(
          metadata_path),
      beam_pipeline_args=beam_pipeline_args)
Exemplo n.º 24
0
def _create_parameterized_pipeline(
        pipeline_name: Text, pipeline_root: Text, enable_cache: bool,
        beam_pipeline_args: List[Text]) -> pipeline.Pipeline:
    """Creates a simple TFX pipeline driven by RuntimeParameters.

    The data location, the two module files and the train/eval step counts
    are declared as `data_types.RuntimeParameter` objects with defaults.

    Args:
      pipeline_name: The name of the pipeline.
      pipeline_root: The root of the pipeline output.
      enable_cache: Whether to enable cache in this pipeline.
      beam_pipeline_args: Pipeline arguments for Beam powered Components.

    Returns:
      A logical TFX pipeline.Pipeline object.
    """
    # --- Pipeline parameters -------------------------------------------------
    # Location of the CSV data; a data.csv file is expected underneath it.
    data_root = data_types.RuntimeParameter(
        name='data-root', default='gs://my-bucket/data', ptype=Text)

    # Module file for Transform.
    transform_module_file = data_types.RuntimeParameter(
        name='transform-module',
        default='gs://my-bucket/modules/transform_module.py',
        ptype=Text)

    # Module file for Trainer.
    trainer_module_file = data_types.RuntimeParameter(
        name='trainer-module',
        default='gs://my-bucket/modules/trainer_module.py',
        ptype=Text)

    # Step counts handed to the Trainer as `num_steps`.
    train_steps = data_types.RuntimeParameter(
        name='train-steps', default=10, ptype=int)
    eval_steps = data_types.RuntimeParameter(
        name='eval-steps', default=5, ptype=int)

    # --- Components ----------------------------------------------------------
    # The input location comes from the `data_root` parameter.
    example_gen = CsvExampleGen(input_base=data_root)
    raw_examples = example_gen.outputs['examples']

    statistics_gen = StatisticsGen(examples=raw_examples)
    statistics = statistics_gen.outputs['statistics']

    schema_gen = SchemaGen(statistics=statistics, infer_feature_shape=False)
    schema = schema_gen.outputs['schema']

    example_validator = ExampleValidator(statistics=statistics, schema=schema)

    # Transform reads its module file from the `transform_module_file`
    # parameter.
    transform = Transform(examples=raw_examples,
                          schema=schema,
                          module_file=transform_module_file)

    # Trainer reads its module file from `trainer_module_file`; its step counts
    # come from the 'train-steps' and 'eval-steps' parameters.
    trainer = Trainer(
        module_file=trainer_module_file,
        custom_executor_spec=executor_spec.ExecutorClassSpec(Executor),
        transformed_examples=transform.outputs['transformed_examples'],
        schema=schema,
        transform_graph=transform.outputs['transform_graph'],
        train_args={'num_steps': train_steps},
        eval_args={'num_steps': eval_steps})
    trained_model = trainer.outputs['model']

    # Resolve the latest blessed model, used as the evaluation baseline.
    resolver_node = resolver.Resolver(
        strategy_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
        model=Channel(type=Model),
        model_blessing=Channel(type=ModelBlessing))
    model_resolver = resolver_node.with_id('latest_blessed_model_resolver')

    # TFMA configuration: an absolute accuracy bound plus a comparison with the
    # baseline. The change threshold is ignored when no baseline model is
    # resolved from MLMD (first run).
    accuracy_threshold = tfma.config.MetricThreshold(
        value_threshold=tfma.GenericValueThreshold(lower_bound={'value': 0.6}),
        change_threshold=tfma.GenericChangeThreshold(
            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
            absolute={'value': -1e-10}))
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(signature_name='eval')],
        slicing_specs=[
            tfma.SlicingSpec(),
            tfma.SlicingSpec(feature_keys=['trip_start_hour']),
        ],
        metrics_specs=[
            tfma.MetricsSpec(thresholds={'accuracy': accuracy_threshold})
        ])
    evaluator = Evaluator(
        examples=raw_examples,
        model=trained_model,
        baseline_model=model_resolver.outputs['model'],
        eval_config=eval_config)

    # The push destination is built from the stringified pipeline root
    # parameter.
    serving_dir = os.path.join(str(pipeline.ROOT_PARAMETER), 'model_serving')
    pusher = Pusher(
        model=trained_model,
        model_blessing=evaluator.outputs['blessing'],
        push_destination=pusher_pb2.PushDestination(
            filesystem=pusher_pb2.PushDestination.Filesystem(
                base_directory=serving_dir)))

    pipeline_components = [
        example_gen,
        statistics_gen,
        schema_gen,
        example_validator,
        transform,
        trainer,
        model_resolver,
        evaluator,
        pusher,
    ]
    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=pipeline_components,
        enable_cache=enable_cache,
        beam_pipeline_args=beam_pipeline_args)
Exemplo n.º 25
0
def create_e2e_components(
    pipeline_root: Text,
    csv_input_location: Text,
    transform_module: Text,
    trainer_module: Text,
) -> List[BaseComponent]:
    """Creates components for a simple Chicago Taxi TFX pipeline for testing.

    Args:
      pipeline_root: The root of the pipeline output; the Pusher destination
        is `<pipeline_root>/model_serving`.
      csv_input_location: The location of the input data directory.
      transform_module: The location of the transform module file.
      trainer_module: The location of the trainer module file.

    Returns:
      A list of TFX components that constitutes an end-to-end test pipeline.
    """
    example_gen = CsvExampleGen(input_base=csv_input_location)
    raw_examples = example_gen.outputs['examples']

    statistics_gen = StatisticsGen(examples=raw_examples)
    statistics = statistics_gen.outputs['statistics']

    schema_gen = SchemaGen(statistics=statistics)
    schema = schema_gen.outputs['schema']

    example_validator = ExampleValidator(statistics=statistics, schema=schema)

    transform = Transform(examples=raw_examples,
                          schema=schema,
                          module_file=transform_module)

    # The most recent Model artifact is resolved and handed to the Trainer as
    # its base model.
    resolver_node = resolver.Resolver(
        strategy_class=latest_artifact_strategy.LatestArtifactStrategy,
        latest_model=Channel(type=Model))
    latest_model_resolver = resolver_node.with_id('latest_model_resolver')

    trainer = Trainer(
        transformed_examples=transform.outputs['transformed_examples'],
        schema=schema,
        base_model=latest_model_resolver.outputs['latest_model'],
        transform_graph=transform.outputs['transform_graph'],
        train_args=trainer_pb2.TrainArgs(num_steps=10),
        eval_args=trainer_pb2.EvalArgs(num_steps=5),
        module_file=trainer_module,
    )
    trained_model = trainer.outputs['model']

    # TFMA config for model evaluation and validation.
    accuracy_threshold = tfma.MetricThreshold(
        value_threshold=tfma.GenericValueThreshold(lower_bound={'value': 0.5}),
        change_threshold=tfma.GenericChangeThreshold(
            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
            absolute={'value': -1e-10}))
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(signature_name='eval')],
        metrics_specs=[
            tfma.MetricsSpec(
                metrics=[tfma.MetricConfig(class_name='ExampleCount')],
                thresholds={'accuracy': accuracy_threshold})
        ],
        slicing_specs=[
            tfma.SlicingSpec(),
            tfma.SlicingSpec(feature_keys=['trip_start_hour']),
        ])
    # No baseline model is passed to the Evaluator here.
    evaluator = Evaluator(examples=raw_examples,
                          model=trained_model,
                          eval_config=eval_config)

    serving_spec = infra_validator_pb2.ServingSpec(
        tensorflow_serving=infra_validator_pb2.TensorFlowServing(
            tags=['latest']),
        kubernetes=infra_validator_pb2.KubernetesConfig())
    request_spec = infra_validator_pb2.RequestSpec(
        tensorflow_serving=infra_validator_pb2.TensorFlowServingRequestSpec())
    infra_validator = InfraValidator(
        model=trained_model,
        examples=raw_examples,
        serving_spec=serving_spec,
        request_spec=request_spec)

    # The Pusher is gated on the Evaluator's blessing only; the InfraValidator
    # output is not wired into it.
    serving_dir = os.path.join(pipeline_root, 'model_serving')
    pusher = Pusher(
        model=trained_model,
        model_blessing=evaluator.outputs['blessing'],
        push_destination=pusher_pb2.PushDestination(
            filesystem=pusher_pb2.PushDestination.Filesystem(
                base_directory=serving_dir)))

    e2e_components = [
        example_gen,
        statistics_gen,
        schema_gen,
        example_validator,
        transform,
        latest_model_resolver,
        trainer,
        evaluator,
        infra_validator,
        pusher,
    ]
    return e2e_components
Exemplo n.º 26
0
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text, serving_model_dir: Text,
                     metadata_path: Text,
                     beam_pipeline_args: List[Text]) -> pipeline.Pipeline:
  """Builds the TFX pipeline for BERT classification on the MRPC dataset.

  Args:
    pipeline_name: Name assigned to the returned pipeline.
    pipeline_root: Root location forwarded to `pipeline.Pipeline`.
    data_root: Input base directory read by CsvExampleGen.
    module_file: User module file given to both Transform and Trainer.
    serving_model_dir: Filesystem base directory used by Pusher.
    metadata_path: Path used to build the SQLite metadata connection config.
    beam_pipeline_args: Beam arguments forwarded to `pipeline.Pipeline`.

  Returns:
    A cache-enabled `pipeline.Pipeline` made of ExampleGen, StatisticsGen,
    SchemaGen, ExampleValidator, Transform, Trainer, the latest-blessed-model
    Resolver, Evaluator and Pusher.
  """
  train_split = example_gen_pb2.Input.Split(name='train', pattern='train/*')
  eval_split = example_gen_pb2.Input.Split(
      name='eval', pattern='validation/*')
  input_config = example_gen_pb2.Input(splits=[train_split, eval_split])

  # Ingest the CSV data.
  example_gen = CsvExampleGen(input_base=data_root, input_config=input_config)
  raw_examples = example_gen.outputs['examples']

  # Statistics feed both visualization and example validation.
  statistics_gen = StatisticsGen(examples=raw_examples)
  statistics = statistics_gen.outputs['statistics']

  # Schema is inferred from the statistics.
  schema_gen = SchemaGen(statistics=statistics, infer_feature_shape=True)
  schema = schema_gen.outputs['schema']

  # Anomaly detection against the statistics and the schema.
  example_validator = ExampleValidator(statistics=statistics, schema=schema)

  # Feature engineering shared between training and serving.
  transform = Transform(
      examples=raw_examples, schema=schema, module_file=module_file)

  # Training runs the user-provided function from `module_file`.
  # Adjust the step counts when training on the full dataset.
  trainer = Trainer(
      module_file=module_file,
      examples=transform.outputs['transformed_examples'],
      transform_graph=transform.outputs['transform_graph'],
      schema=schema,
      train_args=trainer_pb2.TrainArgs(num_steps=1),
      eval_args=trainer_pb2.EvalArgs(num_steps=1))
  trained_model = trainer.outputs['model']

  # Resolve the latest blessed model, used as the evaluation baseline.
  resolver_node = resolver.Resolver(
      strategy_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
      model=Channel(type=Model),
      model_blessing=Channel(type=ModelBlessing))
  model_resolver = resolver_node.with_id('latest_blessed_model_resolver')

  # TFMA configuration: an absolute accuracy bound (adjust it when training on
  # the full dataset) plus a comparison with the baseline. The change
  # threshold is ignored when no baseline model is resolved from MLMD
  # (first run).
  accuracy_threshold = tfma.MetricThreshold(
      value_threshold=tfma.GenericValueThreshold(lower_bound={'value': 0.5}),
      change_threshold=tfma.GenericChangeThreshold(
          direction=tfma.MetricDirection.HIGHER_IS_BETTER,
          absolute={'value': -1e-2}))
  accuracy_metric = tfma.MetricConfig(
      class_name='SparseCategoricalAccuracy', threshold=accuracy_threshold)
  eval_config = tfma.EvalConfig(
      model_specs=[tfma.ModelSpec(label_key='label')],
      slicing_specs=[tfma.SlicingSpec()],
      metrics_specs=[tfma.MetricsSpec(metrics=[accuracy_metric])])
  evaluator = Evaluator(
      examples=raw_examples,
      model=trained_model,
      baseline_model=model_resolver.outputs['model'],
      eval_config=eval_config)

  # Push to a file destination, gated on the Evaluator's blessing.
  push_destination = pusher_pb2.PushDestination(
      filesystem=pusher_pb2.PushDestination.Filesystem(
          base_directory=serving_model_dir))
  pusher = Pusher(
      model=trained_model,
      model_blessing=evaluator.outputs['blessing'],
      push_destination=push_destination)

  return pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=[
          example_gen,
          statistics_gen,
          schema_gen,
          example_validator,
          transform,
          trainer,
          model_resolver,
          evaluator,
          pusher,
      ],
      metadata_connection_config=metadata.sqlite_metadata_connection_config(
          metadata_path),
      enable_cache=True,
      beam_pipeline_args=beam_pipeline_args,
  )
Exemplo n.º 27
0
def create_pipeline(pipeline_name: str, pipeline_root: str, data_root: str,
                    module_file: str, serving_model_dir: str,
                    beam_pipeline_args: List[str]) -> pipeline.Pipeline:
    """Builds the Chicago taxi TFX pipeline.

    The pipeline ingests CSV data, derives statistics and a schema, validates
    the examples, transforms them, trains a model, evaluates it against the
    latest blessed model and pushes it to a filesystem destination.

    Args:
      pipeline_name: name given to the resulting pipeline.
      pipeline_root: root directory handed to the pipeline.
      data_root: `input_base` of the CSV example generator.
      module_file: user module used by both Transform and Trainer.
      serving_model_dir: base directory of the Pusher's filesystem
        destination.
      beam_pipeline_args: Beam options forwarded to the pipeline.

    Returns:
      The assembled `pipeline.Pipeline`; caching is disabled and the metadata
      connection is the default Kubernetes metadata config.
    """
    # --- Data ingestion ------------------------------------------------------
    example_gen = CsvExampleGen(input_base=data_root)
    raw_examples = example_gen.outputs['examples']

    # --- Statistics, schema and anomaly detection ----------------------------
    statistics_gen = StatisticsGen(examples=raw_examples)
    statistics_channel = statistics_gen.outputs['statistics']

    schema_gen = SchemaGen(
        statistics=statistics_channel,
        infer_feature_shape=False,
    )
    schema_channel = schema_gen.outputs['schema']

    example_validator = ExampleValidator(
        statistics=statistics_channel,
        schema=schema_channel,
    )

    # --- Feature engineering shared by training and serving ------------------
    transform = Transform(
        examples=raw_examples,
        schema=schema_channel,
        module_file=module_file,
    )

    # --- Model training through the user-provided module ---------------------
    trainer = Trainer(
        module_file=module_file,
        transformed_examples=transform.outputs['transformed_examples'],
        schema=schema_channel,
        transform_graph=transform.outputs['transform_graph'],
        train_args=trainer_pb2.TrainArgs(num_steps=10000),
        eval_args=trainer_pb2.EvalArgs(num_steps=5000),
    )
    trained_model = trainer.outputs['model']

    # --- Baseline: the latest blessed model -----------------------------------
    blessed_model_lookup = resolver.Resolver(
        strategy_class=(
            latest_blessed_model_resolver.LatestBlessedModelResolver),
        model=Channel(type=Model),
        model_blessing=Channel(type=ModelBlessing),
    )
    model_resolver = blessed_model_lookup.with_id(
        'latest_blessed_model_resolver')

    # --- Evaluation with TFMA --------------------------------------------------
    # The change threshold is ignored when no baseline model can be resolved
    # from MLMD, i.e. on the first run.
    accuracy_gate = tfma.MetricThreshold(
        value_threshold=tfma.GenericValueThreshold(
            lower_bound={'value': 0.6}),
        change_threshold=tfma.GenericChangeThreshold(
            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
            absolute={'value': -1e-10}),
    )
    # Metrics are computed overall and sliced by trip start hour.
    slices = [
        tfma.SlicingSpec(),
        tfma.SlicingSpec(feature_keys=['trip_start_hour']),
    ]
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(signature_name='eval')],
        slicing_specs=slices,
        metrics_specs=[
            tfma.MetricsSpec(thresholds={'accuracy': accuracy_gate}),
        ],
    )
    evaluator = Evaluator(
        examples=raw_examples,
        model=trained_model,
        baseline_model=model_resolver.outputs['model'],
        eval_config=eval_config,
    )

    # --- Push to the serving directory once the model is blessed -------------
    serving_destination = pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory=serving_model_dir))
    pusher = Pusher(
        model=trained_model,
        model_blessing=evaluator.outputs['blessing'],
        push_destination=serving_destination,
    )

    config = kubernetes_dag_runner.get_default_kubernetes_metadata_config()

    pipeline_nodes = [
        example_gen,
        statistics_gen,
        schema_gen,
        example_validator,
        transform,
        trainer,
        model_resolver,
        evaluator,
        pusher,
    ]
    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=pipeline_nodes,
        enable_cache=False,
        metadata_connection_config=config,
        beam_pipeline_args=beam_pipeline_args,
    )
Exemplo n.º 28
0
def _create_pipeline(
    pipeline_name: Text,
    pipeline_root: Text,
    data_root: Text,
    module_file: Text,
    accuracy_threshold: float,
    serving_model_dir: Text,
    metadata_path: Text,
    enable_tuning: bool,
    examplegen_input_config: Optional[example_gen_pb2.Input],
    examplegen_range_config: Optional[range_config_pb2.RangeConfig],
    resolver_range_config: Optional[range_config_pb2.RangeConfig],
    beam_pipeline_args: List[Text],
) -> pipeline.Pipeline:
  """Builds the penguin example pipeline.

  Two nodes are optional: a span resolver that feeds Transform when
  `resolver_range_config` is given, and a Tuner whose best hyperparameters
  feed the Trainer when `enable_tuning` is set.

  Args:
    pipeline_name: name of the TFX pipeline being created.
    pipeline_root: root directory of the pipeline.
    data_root: `input_base` of the CSV example generator.
    module_file: user module used by Transform, Tuner and Trainer.
    accuracy_threshold: lower bound on SparseCategoricalAccuracy used as the
      Evaluator's value threshold.
    serving_model_dir: base directory of the Pusher's filesystem destination.
    metadata_path: path used for the SQLite metadata connection config.
    enable_tuning: when true, a Tuner is added and its
      `best_hyperparameters` output is passed to the Trainer; otherwise the
      Trainer receives `hyperparameters=None`.
    examplegen_input_config: `input_config` of the example generator.
    examplegen_range_config: `range_config` of the example generator.
    resolver_range_config: when truthy, a SpansResolver configured with this
      range config is added and Transform reads the resolver's examples
      instead of the example generator's.
    beam_pipeline_args: Beam options forwarded to the pipeline.

  Returns:
    The assembled `pipeline.Pipeline`, with caching enabled.
  """
  # Optional nodes are remembered in creation order and placed after the
  # fixed nodes when the component list is put together.
  optional_nodes = []

  # Data ingestion.
  example_gen = CsvExampleGen(
      input_base=data_root,
      input_config=examplegen_input_config,
      range_config=examplegen_range_config)
  raw_examples = example_gen.outputs['examples']

  # Statistics for visualization and example validation.
  statistics_gen = StatisticsGen(examples=raw_examples)
  statistics_channel = statistics_gen.outputs['statistics']

  # Schema inferred from the statistics.
  schema_gen = SchemaGen(
      statistics=statistics_channel, infer_feature_shape=True)
  schema_channel = schema_gen.outputs['schema']

  # Anomaly detection against the inferred schema.
  example_validator = ExampleValidator(
      statistics=statistics_channel, schema=schema_channel)

  # Transform reads either a window of spans picked by the resolver or the
  # example generator's output directly.
  if resolver_range_config:
    examples_resolver = resolver.Resolver(
        instance_name='span_resolver',
        strategy_class=spans_resolver.SpansResolver,
        config={'range_config': resolver_range_config},
        examples=Channel(type=Examples, producer_component_id=example_gen.id))
    optional_nodes.append(examples_resolver)
    transform_input = examples_resolver.outputs['examples']
  else:
    transform_input = raw_examples

  # Feature engineering shared by training and serving.
  transform = Transform(
      examples=transform_input,
      schema=schema_channel,
      module_file=module_file)
  transformed_examples = transform.outputs['transformed_examples']
  transform_graph = transform.outputs['transform_graph']

  # Hyperparameter search through the user-provided module. Once good values
  # are known, the Tuner can be dropped and the Trainer fed with them.
  if enable_tuning:
    tuner = Tuner(
        module_file=module_file,
        examples=transformed_examples,
        transform_graph=transform_graph,
        train_args=trainer_pb2.TrainArgs(num_steps=20),
        eval_args=trainer_pb2.EvalArgs(num_steps=5))
    optional_nodes.append(tuner)
    hyperparameters = tuner.outputs['best_hyperparameters']
  else:
    # Without a Tuner, either import an earlier Tuner's output, e.g.
    #   hparams_importer = ImporterNode(
    #     instance_name='import_hparams',
    #     source_uri='path/to/best_hyperparameters.txt',
    #     artifact_type=HyperParameters)
    #   ...
    #   hyperparameters = hparams_importer.outputs['result'],
    # or hard-code the tuned values in the user module and pass None here.
    hyperparameters = None

  # Model training through the user-provided module.
  trainer = Trainer(
      module_file=module_file,
      examples=transformed_examples,
      transform_graph=transform_graph,
      schema=schema_channel,
      hyperparameters=hyperparameters,
      train_args=trainer_pb2.TrainArgs(num_steps=100),
      eval_args=trainer_pb2.EvalArgs(num_steps=5))
  trained_model = trainer.outputs['model']

  # Baseline for validation: the latest blessed model.
  model_resolver = resolver.Resolver(
      instance_name='latest_blessed_model_resolver',
      strategy_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
      model=Channel(type=Model),
      model_blessing=Channel(type=ModelBlessing))

  # TFMA evaluation: the candidate must reach `accuracy_threshold` and must
  # not fall behind the baseline. The change threshold is ignored when no
  # baseline model is resolved from MLMD (first run).
  accuracy_gate = tfma.MetricThreshold(
      value_threshold=tfma.GenericValueThreshold(
          lower_bound={'value': accuracy_threshold}),
      change_threshold=tfma.GenericChangeThreshold(
          direction=tfma.MetricDirection.HIGHER_IS_BETTER,
          absolute={'value': -1e-10}))
  accuracy_metric = tfma.MetricConfig(
      class_name='SparseCategoricalAccuracy', threshold=accuracy_gate)
  eval_config = tfma.EvalConfig(
      model_specs=[tfma.ModelSpec(label_key='species')],
      slicing_specs=[tfma.SlicingSpec()],
      metrics_specs=[tfma.MetricsSpec(metrics=[accuracy_metric])])
  evaluator = Evaluator(
      examples=raw_examples,
      model=trained_model,
      baseline_model=model_resolver.outputs['model'],
      eval_config=eval_config)

  # Push to the serving directory once the model is blessed.
  serving_destination = pusher_pb2.PushDestination(
      filesystem=pusher_pb2.PushDestination.Filesystem(
          base_directory=serving_model_dir))
  pusher = Pusher(
      model=trained_model,
      model_blessing=evaluator.outputs['blessing'],
      push_destination=serving_destination)

  components = [
      example_gen,
      statistics_gen,
      schema_gen,
      example_validator,
      transform,
      trainer,
      model_resolver,
      evaluator,
      pusher,
  ]
  components.extend(optional_nodes)

  metadata_config = metadata.sqlite_metadata_connection_config(metadata_path)
  return pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=components,
      enable_cache=True,
      metadata_connection_config=metadata_config,
      beam_pipeline_args=beam_pipeline_args)
Exemplo n.º 29
0
def create_pipeline(
    pipeline_name: Text,
    pipeline_root: Text,
    data_path: Text,
    preprocessing_fn: Text,
    run_fn: Text,
    train_args: trainer_pb2.TrainArgs,
    eval_args: trainer_pb2.EvalArgs,
    eval_accuracy_threshold: float,
    serving_model_dir: Text,
    metadata_connection_config: Optional[
        metadata_store_pb2.ConnectionConfig] = None,
    beam_pipeline_args: Optional[List[Text]] = None,
) -> pipeline.Pipeline:
    """Builds the penguin template pipeline.

    As written, only the example generator, statistics, schema and example
    validator nodes are registered. The remaining nodes are constructed but
    stay out of `components` until the matching `TODO(step N)` line is
    uncommented.

    Args:
      pipeline_name: name given to the resulting pipeline.
      pipeline_root: root directory handed to the pipeline.
      data_path: `input_base` of the CSV example generator.
      preprocessing_fn: `preprocessing_fn` argument of Transform.
      run_fn: `run_fn` argument of Trainer.
      train_args: training arguments passed to Trainer.
      eval_args: evaluation arguments passed to Trainer.
      eval_accuracy_threshold: lower bound on SparseCategoricalAccuracy used
        as the Evaluator's value threshold.
      serving_model_dir: base directory of the Pusher's filesystem
        destination.
      metadata_connection_config: metadata connection forwarded to the
        pipeline.
      beam_pipeline_args: Beam options forwarded to the pipeline.

    Returns:
      The assembled `pipeline.Pipeline`.
    """
    # Brings data into the pipeline or otherwise joins/converts training data.
    # TODO(step 2): Might use another ExampleGen class for your data.
    example_gen = CsvExampleGen(input_base=data_path)
    raw_examples = example_gen.outputs['examples']

    # Statistics for visualization and example validation.
    statistics_gen = StatisticsGen(examples=raw_examples)
    statistics_channel = statistics_gen.outputs['statistics']

    # Schema inferred from the statistics.
    schema_gen = SchemaGen(
        statistics=statistics_channel,
        infer_feature_shape=True,
    )
    schema_channel = schema_gen.outputs['schema']

    # Anomaly detection against the inferred schema.
    example_validator = ExampleValidator(  # pylint: disable=unused-variable
        statistics=statistics_channel,
        schema=schema_channel,
    )

    # Nodes that are part of the pipeline from the start; later steps append
    # to this list.
    components = [
        example_gen,
        statistics_gen,
        schema_gen,
        example_validator,
    ]

    # Feature engineering shared by training and serving.
    transform = Transform(  # pylint: disable=unused-variable
        examples=raw_examples,
        schema=schema_channel,
        preprocessing_fn=preprocessing_fn,
    )
    # TODO(step 3): Uncomment here to add Transform to the pipeline.
    # components.append(transform)

    # Model training through the user-provided `run_fn`.
    trainer = Trainer(
        run_fn=run_fn,
        examples=raw_examples,
        # Use outputs of Transform as training inputs if Transform is used.
        # examples=transform.outputs['transformed_examples'],
        # transform_graph=transform.outputs['transform_graph'],
        schema=schema_channel,
        train_args=train_args,
        eval_args=eval_args,
    )
    trained_model = trainer.outputs['model']
    # TODO(step 4): Uncomment here to add Trainer to the pipeline.
    # components.append(trainer)

    # Baseline for validation: the latest blessed model.
    blessed_model_lookup = resolver.Resolver(
        strategy_class=(
            latest_blessed_model_resolver.LatestBlessedModelResolver),
        model=Channel(type=Model),
        model_blessing=Channel(type=ModelBlessing),
    )
    model_resolver = blessed_model_lookup.with_id(
        'latest_blessed_model_resolver')
    # TODO(step 5): Uncomment here to add Resolver to the pipeline.
    # components.append(model_resolver)

    # TFMA evaluation: the candidate must reach `eval_accuracy_threshold` and
    # must not fall behind the baseline. The change threshold is ignored when
    # there is no baseline (first run).
    accuracy_gate = tfma.MetricThreshold(
        value_threshold=tfma.GenericValueThreshold(
            lower_bound={'value': eval_accuracy_threshold}),
        change_threshold=tfma.GenericChangeThreshold(
            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
            absolute={'value': -1e-10}),
    )
    accuracy_metric = tfma.MetricConfig(
        class_name='SparseCategoricalAccuracy',
        threshold=accuracy_gate,
    )
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(label_key=features.LABEL_KEY)],
        slicing_specs=[tfma.SlicingSpec()],
        metrics_specs=[tfma.MetricsSpec(metrics=[accuracy_metric])],
    )
    evaluator = Evaluator(  # pylint: disable=unused-variable
        examples=raw_examples,
        model=trained_model,
        baseline_model=model_resolver.outputs['model'],
        eval_config=eval_config,
    )
    # TODO(step 5): Uncomment here to add Evaluator to the pipeline.
    # components.append(evaluator)

    # Push to the serving directory once the model is blessed.
    serving_destination = pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory=serving_model_dir))
    pusher = Pusher(  # pylint: disable=unused-variable
        model=trained_model,
        model_blessing=evaluator.outputs['blessing'],
        push_destination=serving_destination,
    )
    # TODO(step 5): Uncomment here to add Pusher to the pipeline.
    # components.append(pusher)

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=components,
        # Change this value to control caching of execution results. Default
        # value is `False`.
        # enable_cache=True,
        metadata_connection_config=metadata_connection_config,
        beam_pipeline_args=beam_pipeline_args,
    )
def create_pipeline(pipeline_name: Text,
                    pipeline_root: Text,
                    data_root_uri,
                    trainer_config: TrainerConfig,
                    tuner_config: TunerConfig,
                    pusher_config: PusherConfig,
                    runtime_parameters_config: Optional[RuntimeParametersConfig] = None,
                    str_runtime_parameters_supported: bool = False,
                    int_runtime_parameters_supported: bool = False,
                    local_run: bool = True,
                    beam_pipeline_args: Optional[List[Text]] = None,
                    enable_cache: Optional[bool] = True,
                    code_folder: Text = '',
                    metadata_connection_config: Optional[metadata_store_pb2.ConnectionConfig] = None
                    ) -> pipeline.Pipeline:
    """Builds the TFX pipeline that trains, validates and pushes the Titanic classifier.

    Args:
      pipeline_name: name of the TFX pipeline being created.
      pipeline_root: root directory of the pipeline.
      data_root_uri: `input_base` of the CSV example generator. Replaced by
        `runtime_parameters_config.data_root_runtime` when
        `str_runtime_parameters_supported` is set and a runtime config is
        given.
      trainer_config: supplies `train_steps`, `eval_steps`, `epochs`,
        `train_batch_size`, `eval_batch_size` and the optional
        `ai_platform_training_args`. When the latter is not None the AI
        Platform trainer executor is used and the args are added to the
        Trainer's custom config.
      tuner_config: supplies `enable_tuning`, `tuner_steps`,
        `eval_tuner_steps`, `max_trials` and the optional
        `ai_platform_tuner_args`. With tuning disabled, hyperparameters are
        imported from the hyperparameters folder instead.
      pusher_config: supplies `serving_model_dir` (used for local runs) and
        the optional `ai_platform_serving_args` (switches the Pusher to the
        AI Platform executor).
      runtime_parameters_config: optional holder of `data_root_runtime`,
        `train_steps_runtime` and `eval_steps_runtime`.
      str_runtime_parameters_supported: use the runtime data root when a
        runtime config is given.
      int_runtime_parameters_supported: use the runtime train/eval step
        counts when a runtime config is given; otherwise the values from
        `trainer_config` are used.
      local_run: when true, InfraValidator uses local Docker and the Pusher
        gets a filesystem destination; when false, InfraValidator uses the
        Kubernetes serving config.
      beam_pipeline_args: optional Beam options forwarded to the pipeline.
      enable_cache: forwarded to the pipeline's `enable_cache`.
      code_folder: folder holding the schema, the transform/train modules and
        the hyperparameters. Paths are built with
        `os.path.join(os.sep, code_folder, ...)`, so a relative value is
        anchored at the filesystem root.
      metadata_connection_config: metadata connection forwarded to the
        pipeline.

    Returns:
      A TFX pipeline object.
    """
    absl.logging.info('pipeline_name: %s', pipeline_name)
    absl.logging.info('pipeline root: %s', pipeline_root)
    absl.logging.info('data_root_uri for training: %s', data_root_uri)

    absl.logging.info('train_steps for training: %s', trainer_config.train_steps)
    absl.logging.info('tuner_steps for tuning: %s', tuner_config.tuner_steps)
    absl.logging.info('eval_steps for evaluating: %s', trainer_config.eval_steps)

    absl.logging.info('os default list dir: %s', os.listdir('.'))

    schema_proper_folder = os.path.join(os.sep, code_folder, SCHEMA_FOLDER)
    absl.logging.info('schema_proper_folder: %s', schema_proper_folder)

    preprocessing_proper_file = os.path.join(os.sep, code_folder, TRANSFORM_MODULE_FILE)
    absl.logging.info('preprocessing_proper_file: %s', preprocessing_proper_file)

    model_proper_file = os.path.join(os.sep, code_folder, TRAIN_MODULE_FILE)
    absl.logging.info('model_proper_file: %s', model_proper_file)

    hyperparameters_proper_folder = os.path.join(os.sep, code_folder, HYPERPARAMETERS_FOLDER)
    absl.logging.info('hyperparameters_proper_folder: %s', hyperparameters_proper_folder)

    # Brings data into the pipeline and splits it 4:1 into train and eval.
    output_config = example_gen_pb2.Output(
        split_config=example_gen_pb2.SplitConfig(splits=[
            example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=4),
            example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1)
        ]))

    if str_runtime_parameters_supported and runtime_parameters_config is not None:
        data_root_uri = runtime_parameters_config.data_root_runtime

    examplegen = CsvExampleGen(input_base=data_root_uri, output_config=output_config)

    # Computes statistics over data for visualization and example validation.
    statisticsgen = StatisticsGen(examples=examplegen.outputs.examples)

    # Generates a schema from the statistics. The user-provided schema below
    # is the one actually consumed; this one is kept for tracking and
    # comparison of the newest data.
    schemagen = SchemaGen(statistics=statisticsgen.outputs.statistics)

    # Imports the user-provided schema.
    import_schema = Importer(
        source_uri=schema_proper_folder,
        artifact_type=Schema).with_id('import_user_schema')

    # Performs anomaly detection based on statistics and data schema.
    examplevalidator = ExampleValidator(
        statistics=statisticsgen.outputs.statistics,
        schema=import_schema.outputs.result)

    # Performs transformations and feature engineering in training and serving.
    transform = Transform(
        examples=examplegen.outputs.examples,
        schema=import_schema.outputs.result,
        module_file=preprocessing_proper_file)

    # Apparently only str RuntimeParameters are supported in airflow, hence
    # the separate switch for the integer step counts.
    if int_runtime_parameters_supported and runtime_parameters_config is not None:
        train_steps = runtime_parameters_config.train_steps_runtime
        eval_steps = runtime_parameters_config.eval_steps_runtime
    else:
        train_steps = trainer_config.train_steps
        eval_steps = trainer_config.eval_steps

    absl.logging.info('train_steps: %s', train_steps)
    absl.logging.info('eval_steps: %s', eval_steps)

    # Hyperparameters come either from a Tuner run inside this pipeline or
    # from a previously stored result. Once the hyperparameters are tuned, the
    # Tuner can be dropped and the Trainer fed through the importer.
    if tuner_config.enable_tuning:
        tuner_custom_config = {
            'max_trials': tuner_config.max_trials,
            'is_local_run': local_run,
        }
        tuner_args = {
            'module_file': model_proper_file,
            'examples': transform.outputs.transformed_examples,
            'transform_graph': transform.outputs.transform_graph,
            'train_args': {'num_steps': tuner_config.tuner_steps},
            'eval_args': {'num_steps': tuner_config.eval_tuner_steps},
            'custom_config': tuner_custom_config,
        }

        if tuner_config.ai_platform_tuner_args is not None:
            # Merge into the existing custom config instead of replacing it,
            # so 'max_trials' and 'is_local_run' still reach the tuner module
            # (same approach as the Trainer's custom config below).
            tuner_custom_config[ai_platform_trainer_executor.TRAINING_ARGS_KEY] = (
                tuner_config.ai_platform_tuner_args)
            tuner_args['tune_args'] = tuner_pb2.TuneArgs(num_parallel_trials=3)

        absl.logging.info('tuner_args: %s', tuner_args)
        hyperparameters_node = Tuner(**tuner_args)
        hyperparameters = hyperparameters_node.outputs.best_hyperparameters
    else:
        hyperparameters_node = Importer(
            source_uri=hyperparameters_proper_folder,
            artifact_type=HyperParameters).with_id('import_hparams')
        hyperparameters = hyperparameters_node.outputs['result']

    # Trains the model using a user provided trainer function.
    trainer_custom_config = {
        'epochs': trainer_config.epochs,
        'train_batch_size': trainer_config.train_batch_size,
        'eval_batch_size': trainer_config.eval_batch_size,
    }
    if trainer_config.ai_platform_training_args is not None:
        trainer_custom_config[ai_platform_trainer_executor.TRAINING_ARGS_KEY] = (
            trainer_config.ai_platform_training_args)
        trainer_executor_class = ai_platform_trainer_executor.GenericExecutor
    else:
        trainer_executor_class = trainer_executor.GenericExecutor

    trainer = Trainer(
        module_file=model_proper_file,
        transformed_examples=transform.outputs.transformed_examples,
        schema=import_schema.outputs.result,
        transform_graph=transform.outputs.transform_graph,
        train_args={'num_steps': train_steps},
        eval_args={'num_steps': eval_steps},
        hyperparameters=hyperparameters,
        custom_config=trainer_custom_config,
        custom_executor_spec=executor_spec.ExecutorClassSpec(trainer_executor_class))

    # Gets the latest blessed model for model validation.
    # `instance_name` is deprecated, so the id is set through with_id().
    model_resolver = resolver.Resolver(
        strategy_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
        model=Channel(type=Model),
        model_blessing=Channel(type=ModelBlessing)
    ).with_id('latest_blessed_model_resolver')

    # Uses TFMA to compute evaluation statistics over features of a model.
    # Accuracy has to fall inside [0.5, 0.995] for the model to be blessed.
    accuracy_threshold = tfma.MetricThreshold(
        value_threshold=tfma.GenericValueThreshold(
            lower_bound={'value': 0.5},
            upper_bound={'value': 0.995}),
    )

    metrics_specs = tfma.MetricsSpec(
        metrics=[
            tfma.MetricConfig(class_name='BinaryAccuracy',
                              threshold=accuracy_threshold),
            tfma.MetricConfig(class_name='ExampleCount')])

    eval_config = tfma.EvalConfig(
        model_specs=[
            tfma.ModelSpec(label_key='Survived')
        ],
        metrics_specs=[metrics_specs],
        slicing_specs=[
            tfma.SlicingSpec(),
            tfma.SlicingSpec(feature_keys=['Sex']),
            tfma.SlicingSpec(feature_keys=['Age']),
            tfma.SlicingSpec(feature_keys=['Parch']),
        ]
    )

    evaluator = Evaluator(
        examples=examplegen.outputs.examples,
        model=trainer.outputs.model,
        baseline_model=model_resolver.outputs.model,
        eval_config=eval_config
    )

    # Validates that the model can be loaded and queried in a sand-boxed
    # environment mirroring production.
    if local_run:
        serving_config = infra_validator_pb2.ServingSpec(
            tensorflow_serving=infra_validator_pb2.TensorFlowServing(tags=['latest']),
            local_docker=infra_validator_pb2.LocalDockerConfig()  # Running on local docker.
        )
    else:
        serving_config = infra_validator_pb2.ServingSpec(
            tensorflow_serving=infra_validator_pb2.TensorFlowServing(tags=['latest']),
            kubernetes=infra_validator_pb2.KubernetesConfig()  # Running on K8s.
        )

    validation_config = infra_validator_pb2.ValidationSpec(
        max_loading_time_seconds=60,
        num_tries=3,
    )

    request_config = infra_validator_pb2.RequestSpec(
        tensorflow_serving=infra_validator_pb2.TensorFlowServingRequestSpec(),
        num_examples=3,
    )

    infravalidator = InfraValidator(
        model=trainer.outputs.model,
        examples=examplegen.outputs.examples,
        serving_spec=serving_config,
        validation_spec=validation_config,
        request_spec=request_config,
    )

    # Pushes the model only when both the Evaluator and the InfraValidator
    # have blessed it.
    pusher_args = {
        'model': trainer.outputs.model,
        'model_blessing': evaluator.outputs.blessing,
        'infra_blessing': infravalidator.outputs.blessing
    }

    if local_run:
        pusher_args['push_destination'] = pusher_pb2.PushDestination(
            filesystem=pusher_pb2.PushDestination.Filesystem(
                base_directory=pusher_config.serving_model_dir))

    if pusher_config.ai_platform_serving_args is not None:
        pusher_args.update({
            'custom_executor_spec':
                executor_spec.ExecutorClassSpec(ai_platform_pusher_executor.Executor),
            'custom_config': {
                ai_platform_pusher_executor.SERVING_ARGS_KEY:
                    pusher_config.ai_platform_serving_args
            },
        })

    pusher = Pusher(**pusher_args)

    components = [
        examplegen,
        statisticsgen,
        schemagen,
        import_schema,
        examplevalidator,
        transform,
        trainer,
        model_resolver,
        evaluator,
        infravalidator,
        pusher,
        # Either the Tuner or the hyperparameters Importer, whichever feeds
        # the Trainer.
        hyperparameters_node,
    ]

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=components,
        enable_cache=enable_cache,
        metadata_connection_config=metadata_connection_config,
        beam_pipeline_args=beam_pipeline_args
    )