Example #1
0
  def testBuildLatestBlessedModelResolverSucceed(self):
    """Builds a LatestBlessedModelResolver node and checks the emitted protos.

    The step builder output is compared positionally against two expected
    task specs, and the deployment config it populated is compared against
    the expected executor config.
    """
    resolver_node = components.ResolverNode(
        instance_name='my_resolver2',
        resolver_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
        model=channel.Channel(type=standard_artifacts.Model),
        model_blessing=channel.Channel(type=standard_artifacts.ModelBlessing))
    pipeline_info = data_types.PipelineInfo(
        pipeline_name='test-pipeline', pipeline_root='gs://path/to/my/root')
    deployment_config = pipeline_pb2.PipelineDeploymentConfig()

    built_specs = step_builder.StepBuilder(
        node=resolver_node,
        deployment_config=deployment_config,
        pipeline_info=pipeline_info).build()

    # Expected task specs, listed in the order the builder must return them.
    expected_task_spec_texts = (
        _EXPECTED_LATEST_BLESSED_MODEL_RESOLVER_1,
        _EXPECTED_LATEST_BLESSED_MODEL_RESOLVER_2,
    )
    for position, expected_text in enumerate(expected_task_spec_texts):
      expected_task_spec = text_format.Parse(expected_text,
                                             pipeline_pb2.PipelineTaskSpec())
      self.assertProtoEquals(expected_task_spec, built_specs[position])

    # The deployment config is passed into the builder and verified afterwards.
    expected_deployment_config = text_format.Parse(
        _EXPECTED_LATEST_BLESSED_MODEL_EXECUTOR,
        pipeline_pb2.PipelineDeploymentConfig())
    self.assertProtoEquals(expected_deployment_config, deployment_config)
Example #2
0
  def testBuildLatestArtifactResolverSucceed(self):
    """Builds a LatestArtifactsResolver node and checks the emitted protos.

    Both the task spec returned by the builder and the deployment config
    handed to it are compared against expected protos loaded from test data.
    """
    resolver_node = components.ResolverNode(
        instance_name='my_resolver',
        resolver_class=latest_artifacts_resolver.LatestArtifactsResolver,
        model=channel.Channel(type=standard_artifacts.Model),
        examples=channel.Channel(type=standard_artifacts.Examples))
    pipeline_info = data_types.PipelineInfo(
        pipeline_name='test-pipeline', pipeline_root='gs://path/to/my/root')
    deployment_config = pipeline_pb2.PipelineDeploymentConfig()

    builder = step_builder.StepBuilder(
        node=resolver_node,
        deployment_config=deployment_config,
        pipeline_info=pipeline_info)
    built_spec = self._sole(builder.build())

    expected_task_spec = test_utils.get_proto_from_test_data(
        'expected_latest_artifact_resolver.pbtxt',
        pipeline_pb2.PipelineTaskSpec())
    self.assertProtoEquals(expected_task_spec, built_spec)

    expected_deployment_config = test_utils.get_proto_from_test_data(
        'expected_latest_artifact_resolver_executor.pbtxt',
        pipeline_pb2.PipelineDeploymentConfig())
    self.assertProtoEquals(expected_deployment_config, deployment_config)
Example #3
0
def create_pipeline_components(
    pipeline_root: Text,
    transform_module: Text,
    trainer_module: Text,
    bigquery_query: Text = '',
    csv_input_location: Text = '',
) -> List[base_node.BaseNode]:
  """Creates components for a simple Chicago Taxi TFX pipeline for testing.

  Exactly one of `bigquery_query` and `csv_input_location` must be non-empty;
  it selects which example gen component feeds the rest of the pipeline.

  Args:
    pipeline_root: The root of the pipeline output. The pusher's serving
      directory is `model_serving` under this root.
    transform_module: The location of the transform module file.
    trainer_module: The location of the trainer module file.
    bigquery_query: The query to get input data from BigQuery. If not empty,
      BigQueryExampleGen will be used.
    csv_input_location: The location of the input data directory. If not
      empty, CsvExampleGen will be used.

  Returns:
    A list of TFX components that constitutes an end-to-end test pipeline:
    example gen, statistics gen, schema gen, example validator, transform,
    latest-model resolver, trainer, latest-blessed-model resolver, evaluator
    and pusher, in that order.

  Raises:
    ValueError: If both or neither of `bigquery_query` and
      `csv_input_location` are provided.
  """

  # Equal truthiness means both sources or neither source was supplied.
  if bool(bigquery_query) == bool(csv_input_location):
    # The two literals are concatenated into ONE message. Passing them as
    # separate arguments would make the error render as a tuple.
    raise ValueError(
        'Exactly one example gen is expected. '
        'Please provide either bigquery_query or csv_input_location.')

  if bigquery_query:
    example_gen = big_query_example_gen_component.BigQueryExampleGen(
        query=bigquery_query)
  else:
    examples = dsl_utils.external_input(csv_input_location)
    example_gen = components.CsvExampleGen(input=examples)

  statistics_gen = components.StatisticsGen(
      examples=example_gen.outputs['examples'])
  schema_gen = components.SchemaGen(
      statistics=statistics_gen.outputs['statistics'],
      infer_feature_shape=False)
  example_validator = components.ExampleValidator(
      statistics=statistics_gen.outputs['statistics'],
      schema=schema_gen.outputs['schema'])
  transform = components.Transform(
      examples=example_gen.outputs['examples'],
      schema=schema_gen.outputs['schema'],
      module_file=transform_module)
  # Resolve the latest Model artifact; its output is the trainer's base model.
  latest_model_resolver = components.ResolverNode(
      instance_name='latest_model_resolver',
      resolver_class=latest_artifacts_resolver.LatestArtifactsResolver,
      model=channel.Channel(type=standard_artifacts.Model))
  trainer = components.Trainer(
      transformed_examples=transform.outputs['transformed_examples'],
      schema=schema_gen.outputs['schema'],
      base_model=latest_model_resolver.outputs['model'],
      transform_graph=transform.outputs['transform_graph'],
      train_args=trainer_pb2.TrainArgs(num_steps=10),
      eval_args=trainer_pb2.EvalArgs(num_steps=5),
      module_file=trainer_module,
  )
  # Get the latest blessed model for model validation.
  model_resolver = components.ResolverNode(
      instance_name='latest_blessed_model_resolver',
      resolver_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
      model=channel.Channel(type=standard_artifacts.Model),
      model_blessing=channel.Channel(type=standard_artifacts.ModelBlessing))
  # Set the TFMA config for Model Evaluation and Validation.
  # NOTE(review): only 'ExampleCount' is listed under `metrics`, while the
  # threshold is keyed on 'binary_accuracy' -- confirm that metric is made
  # available elsewhere (e.g. by the trained model itself).
  eval_config = tfma.EvalConfig(
      model_specs=[tfma.ModelSpec(signature_name='eval')],
      metrics_specs=[
          tfma.MetricsSpec(
              metrics=[tfma.MetricConfig(class_name='ExampleCount')],
              thresholds={
                  'binary_accuracy':
                      tfma.MetricThreshold(
                          value_threshold=tfma.GenericValueThreshold(
                              lower_bound={'value': 0.5}),
                          change_threshold=tfma.GenericChangeThreshold(
                              direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                              absolute={'value': -1e-10}))
              })
      ],
      slicing_specs=[
          # An empty spec plus a slice keyed on the 'trip_start_hour' feature.
          tfma.SlicingSpec(),
          tfma.SlicingSpec(feature_keys=['trip_start_hour'])
      ])
  # The freshly trained model is evaluated with the resolved blessed model as
  # its baseline.
  evaluator = components.Evaluator(
      examples=example_gen.outputs['examples'],
      model=trainer.outputs['model'],
      baseline_model=model_resolver.outputs['model'],
      eval_config=eval_config)

  # The pusher receives the evaluator's blessing output along with the model.
  pusher = components.Pusher(
      model=trainer.outputs['model'],
      model_blessing=evaluator.outputs['blessing'],
      push_destination=pusher_pb2.PushDestination(
          filesystem=pusher_pb2.PushDestination.Filesystem(
              base_directory=os.path.join(pipeline_root, 'model_serving'))))

  return [
      example_gen, statistics_gen, schema_gen, example_validator, transform,
      latest_model_resolver, trainer, model_resolver, evaluator, pusher
  ]