コード例 #1
0
ファイル: dataset.py プロジェクト: ysjeon7/tfx
    def generate_models(self, args):
        """Trains a small Chicago Taxi model and exports its SavedModels.

        Modified version of the Chicago Taxi Example pipeline
        (tfx/examples/chicago_taxi_pipeline/taxi_pipeline_beam.py). Runs
        ImportExampleGen -> StatisticsGen -> SchemaGen -> Transform -> Trainer
        with BeamDagRunner over the directory containing
        ``self.dataset_path()``. It then copies the serving model to
        ``self.trained_saved_model_path()`` and the eval model to
        ``self.tfma_saved_model_path()``, replacing any previous contents.

        Intermediate pipeline artifacts and the metadata store are written to
        a scratch directory, which is removed (best effort) before returning.

        Args:
          args: Not used by this method.

        Raises:
          ValueError: If a Trainer output directory that should hold exactly
            one entry holds zero or several.
        """

        def join_unique_subdir(path):
            """Returns `path` joined with its single entry; raises otherwise."""
            dirs = os.listdir(path)
            if len(dirs) != 1:
                raise ValueError(
                    "expecting there to be only one subdirectory in %s, but "
                    "subdirectories were: %s" % (path, dirs))
            return os.path.join(path, dirs[0])

        # Scratch space for pipeline outputs and the metadata DB. It is
        # deleted in the `finally` below so repeated calls do not leak
        # temporary directories. Everything callers need is copied out first.
        root = tempfile.mkdtemp()
        try:
            pipeline_root = os.path.join(root, "pipeline")
            metadata_path = os.path.join(root, "metadata/metadata.db")
            module_file = os.path.join(
                os.path.dirname(__file__),
                "../../../examples/chicago_taxi_pipeline/taxi_utils.py")

            examples = external_input(os.path.dirname(self.dataset_path()))
            example_gen = components.ImportExampleGen(input=examples)
            statistics_gen = components.StatisticsGen(
                examples=example_gen.outputs["examples"])
            schema_gen = components.SchemaGen(
                statistics=statistics_gen.outputs["statistics"],
                infer_feature_shape=False)
            transform = components.Transform(
                examples=example_gen.outputs["examples"],
                schema=schema_gen.outputs["schema"],
                module_file=module_file)
            trainer = components.Trainer(
                module_file=module_file,
                transformed_examples=transform.outputs["transformed_examples"],
                schema=schema_gen.outputs["schema"],
                transform_graph=transform.outputs["transform_graph"],
                train_args=trainer_pb2.TrainArgs(num_steps=100),
                eval_args=trainer_pb2.EvalArgs(num_steps=50))
            p = pipeline.Pipeline(
                pipeline_name="chicago_taxi_beam",
                pipeline_root=pipeline_root,
                components=[
                    example_gen, statistics_gen, schema_gen, transform,
                    trainer
                ],
                enable_cache=True,
                metadata_connection_config=metadata.
                sqlite_metadata_connection_config(metadata_path))
            BeamDagRunner().run(p)

            # The Trainer nests each model under a single run-specific
            # subdirectory; resolve it rather than hard-coding its name.
            trainer_output_dir = join_unique_subdir(
                os.path.join(pipeline_root, "Trainer/output"))
            eval_model_dir = join_unique_subdir(
                os.path.join(trainer_output_dir, "eval_model_dir"))
            serving_model_dir = join_unique_subdir(
                os.path.join(trainer_output_dir,
                             "serving_model_dir/export/chicago-taxi"))

            # copytree refuses to overwrite, so clear the destinations first.
            shutil.rmtree(self.trained_saved_model_path(), ignore_errors=True)
            shutil.rmtree(self.tfma_saved_model_path(), ignore_errors=True)
            shutil.copytree(serving_model_dir, self.trained_saved_model_path())
            shutil.copytree(eval_model_dir, self.tfma_saved_model_path())
        finally:
            # Best-effort cleanup: a failure to delete scratch files must not
            # mask the real result or an exception raised above.
            shutil.rmtree(root, ignore_errors=True)
コード例 #2
0
    def _build_statistics_gen(
            self, examples: types.Channel,
            instance_name: Optional[str]) -> tfx.StatisticsGen:
        """Creates the StatisticsGen node for the given examples channel.

        Args:
          examples: Channel of examples to compute statistics over.
          instance_name: Optional instance name forwarded to the component.

        Returns:
          A new ``tfx.StatisticsGen`` built from `examples` and `instance_name`.
        """
        # TODO(b/156134844): Allow passing TFDV StatsOptions to automatically infer
        # useful semantic types
        stats_component = tfx.StatisticsGen(
            examples=examples,
            instance_name=instance_name,
        )
        return stats_component
コード例 #3
0
def two_step_pipeline() -> tfx_pipeline.Pipeline:
    """Builds the minimal two-component pipeline used under test.

    A BigQueryExampleGen runs a fixed query, and its ``examples`` output feeds
    a StatisticsGen.

    Returns:
      A ``tfx_pipeline.Pipeline`` named ``_TEST_TWO_STEP_PIPELINE_NAME`` and
      rooted at ``_TEST_PIPELINE_ROOT``.
    """
    bq_gen = big_query_example_gen_component.BigQueryExampleGen(
        query='SELECT * FROM TABLE')
    stats_gen = components.StatisticsGen(examples=bq_gen.outputs['examples'])
    pipeline_nodes = [bq_gen, stats_gen]
    # BigQuery access needs a GCP project, so hand one to Beam.
    beam_args = ['--project=my-gcp-project']
    return tfx_pipeline.Pipeline(
        pipeline_name=_TEST_TWO_STEP_PIPELINE_NAME,
        pipeline_root=_TEST_PIPELINE_ROOT,
        components=pipeline_nodes,
        beam_pipeline_args=beam_args)
コード例 #4
0
ファイル: test_utils.py プロジェクト: jay90099/tfx
def create_e2e_components(csv_input_location: str) -> List[BaseComponent]:
    """Creates components for a short Chicago Taxi TFX test pipeline.

    The tests do not need the whole pipeline, so this builds a toy one:
    CsvExampleGen -> StatisticsGen -> SchemaGen.

    Args:
      csv_input_location: The location of the input data directory.

    Returns:
      A list of TFX components that constitute an end-to-end test pipeline,
      ordered so each component follows the one it consumes.
    """
    csv_gen = components.CsvExampleGen(input_base=csv_input_location)
    stats_gen = components.StatisticsGen(examples=csv_gen.outputs['examples'])
    schema = components.SchemaGen(
        statistics=stats_gen.outputs['statistics'],
        infer_feature_shape=False)
    return [csv_gen, stats_gen, schema]
コード例 #5
0
def create_pipeline_components(
    pipeline_root: Text,
    transform_module: Text,
    trainer_module: Text,
    bigquery_query: Text = '',
    csv_input_location: Text = '',
) -> List[base_node.BaseNode]:
    """Creates components for a simple Chicago Taxi TFX pipeline for testing.

    Args:
      pipeline_root: The root of the pipeline output.
      transform_module: The location of the transform module file.
      trainer_module: The location of the trainer module file.
      bigquery_query: The query to get input data from BigQuery. If not empty,
        BigQueryExampleGen will be used.
      csv_input_location: The location of the input data directory.

    Returns:
      A list of TFX components that constitute an end-to-end test pipeline.

    Raises:
      ValueError: If both or neither of `bigquery_query` and
        `csv_input_location` are provided.
    """

    if bool(bigquery_query) == bool(csv_input_location):
        # Adjacent literals concatenate into ONE message. The previous comma
        # passed two args, so the exception rendered as a tuple.
        raise ValueError(
            'Exactly one example gen is expected. '
            'Please provide either bigquery_query or csv_input_location.')

    if bigquery_query:
        example_gen = big_query_example_gen_component.BigQueryExampleGen(
            query=bigquery_query)
    else:
        example_gen = components.CsvExampleGen(input_base=csv_input_location)

    statistics_gen = components.StatisticsGen(
        examples=example_gen.outputs['examples'])
    schema_gen = components.SchemaGen(
        statistics=statistics_gen.outputs['statistics'],
        infer_feature_shape=False)
    example_validator = components.ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_gen.outputs['schema'])
    transform = components.Transform(examples=example_gen.outputs['examples'],
                                     schema=schema_gen.outputs['schema'],
                                     module_file=transform_module)
    # Latest Model artifact, fed to the Trainer as `base_model`.
    latest_model_resolver = resolver.Resolver(
        strategy_class=latest_artifacts_resolver.LatestArtifactsResolver,
        model=channel.Channel(type=standard_artifacts.Model)).with_id(
            'Resolver.latest_model_resolver')
    trainer = components.Trainer(
        custom_executor_spec=executor_spec.ExecutorClassSpec(Executor),
        transformed_examples=transform.outputs['transformed_examples'],
        schema=schema_gen.outputs['schema'],
        base_model=latest_model_resolver.outputs['model'],
        transform_graph=transform.outputs['transform_graph'],
        train_args=trainer_pb2.TrainArgs(num_steps=10),
        eval_args=trainer_pb2.EvalArgs(num_steps=5),
        module_file=trainer_module,
    )
    # Get the latest blessed model for model validation; it is the
    # Evaluator's baseline below.
    model_resolver = resolver.Resolver(
        strategy_class=latest_blessed_model_resolver.
        LatestBlessedModelResolver,
        model=channel.Channel(type=standard_artifacts.Model),
        model_blessing=channel.Channel(
            type=standard_artifacts.ModelBlessing)).with_id(
                'Resolver.latest_blessed_model_resolver')
    # Set the TFMA config for Model Evaluation and Validation.
    # NOTE(review): 'binary_accuracy' is thresholded but not listed in
    # `metrics`; presumably it comes from the model's own eval metrics. Confirm.
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(signature_name='eval')],
        metrics_specs=[
            tfma.MetricsSpec(
                metrics=[tfma.MetricConfig(class_name='ExampleCount')],
                thresholds={
                    'binary_accuracy':
                    tfma.MetricThreshold(
                        value_threshold=tfma.GenericValueThreshold(
                            lower_bound={'value': 0.5}),
                        # Tiny negative tolerance: the candidate may match
                        # the baseline but must not regress.
                        change_threshold=tfma.GenericChangeThreshold(
                            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                            absolute={'value': -1e-10}))
                })
        ],
        slicing_specs=[
            tfma.SlicingSpec(),
            tfma.SlicingSpec(feature_keys=['trip_start_hour'])
        ])
    evaluator = components.Evaluator(
        examples=example_gen.outputs['examples'],
        model=trainer.outputs['model'],
        baseline_model=model_resolver.outputs['model'],
        eval_config=eval_config)

    # Pusher is wired to the Evaluator's blessing and writes under
    # <pipeline_root>/model_serving.
    pusher = components.Pusher(
        model=trainer.outputs['model'],
        model_blessing=evaluator.outputs['blessing'],
        push_destination=pusher_pb2.PushDestination(
            filesystem=pusher_pb2.PushDestination.Filesystem(
                base_directory=os.path.join(pipeline_root, 'model_serving'))))

    return [
        example_gen, statistics_gen, schema_gen, example_validator, transform,
        latest_model_resolver, trainer, model_resolver, evaluator, pusher
    ]