Esempio n. 1
0
    def generate_models(self, args):
        """Runs a trimmed Chicago Taxi pipeline and stashes its trained models.

        Adapted from the Chicago Taxi Example pipeline:
        tfx/examples/chicago_taxi_pipeline/taxi_pipeline_beam.py

        Side effects: builds and runs a Beam-based TFX pipeline in a temp
        directory, then replaces the serving and TFMA saved-model copies at
        self.trained_saved_model_path() / self.tfma_saved_model_path().
        """
        workdir = tempfile.mkdtemp()
        pipeline_root = os.path.join(workdir, "pipeline")
        metadata_path = os.path.join(workdir, "metadata/metadata.db")
        module_file = os.path.join(
            os.path.dirname(__file__),
            "../../../examples/chicago_taxi_pipeline/taxi_utils.py")

        # Wire up the standard component graph: ingest -> stats -> schema ->
        # transform -> train.
        examples = external_input(os.path.dirname(self.dataset_path()))
        example_gen = components.ImportExampleGen(input=examples)
        statistics_gen = components.StatisticsGen(
            examples=example_gen.outputs["examples"])
        schema_gen = components.SchemaGen(
            statistics=statistics_gen.outputs["statistics"],
            infer_feature_shape=False)
        transform = components.Transform(
            examples=example_gen.outputs["examples"],
            schema=schema_gen.outputs["schema"],
            module_file=module_file)
        trainer = components.Trainer(
            module_file=module_file,
            transformed_examples=transform.outputs["transformed_examples"],
            schema=schema_gen.outputs["schema"],
            transform_graph=transform.outputs["transform_graph"],
            train_args=trainer_pb2.TrainArgs(num_steps=100),
            eval_args=trainer_pb2.EvalArgs(num_steps=50))

        pipeline_components = [
            example_gen, statistics_gen, schema_gen, transform, trainer
        ]
        taxi_pipeline = pipeline.Pipeline(
            pipeline_name="chicago_taxi_beam",
            pipeline_root=pipeline_root,
            components=pipeline_components,
            enable_cache=True,
            metadata_connection_config=(
                metadata.sqlite_metadata_connection_config(metadata_path)))
        BeamDagRunner().run(taxi_pipeline)

        def join_unique_subdir(path):
            # The run above produces exactly one (timestamped) subdirectory
            # per output; resolve it or fail loudly.
            entries = os.listdir(path)
            if len(entries) == 1:
                return os.path.join(path, entries[0])
            raise ValueError(
                "expecting there to be only one subdirectory in %s, but "
                "subdirectories were: %s" % (path, entries))

        trainer_output_dir = join_unique_subdir(
            os.path.join(pipeline_root, "Trainer/output"))
        eval_model_dir = join_unique_subdir(
            os.path.join(trainer_output_dir, "eval_model_dir"))
        serving_model_dir = join_unique_subdir(
            os.path.join(trainer_output_dir,
                         "serving_model_dir/export/chicago-taxi"))

        # Replace any previous copies before copying the fresh exports.
        shutil.rmtree(self.trained_saved_model_path(), ignore_errors=True)
        shutil.rmtree(self.tfma_saved_model_path(), ignore_errors=True)
        shutil.copytree(serving_model_dir, self.trained_saved_model_path())
        shutil.copytree(eval_model_dir, self.tfma_saved_model_path())
Esempio n. 2
0
    def testBuildFileBasedExampleGenWithInputConfig(self):
        """Checks StepBuilder output for an ImportExampleGen with an input config.

        Builds the component/task/executor specs and compares each against its
        golden .pbtxt test-data file.
        """
        splits = [
            example_gen_pb2.Input.Split(name='train', pattern='*train.tfr'),
            example_gen_pb2.Input.Split(name='eval', pattern='*test.tfr'),
        ]
        example_gen = components.ImportExampleGen(
            input_base='path/to/data/root',
            input_config=example_gen_pb2.Input(splits=splits))

        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        builder = step_builder.StepBuilder(
            node=example_gen,
            image='gcr.io/tensorflow/tfx:latest',
            deployment_config=deployment_config,
            component_defs=component_defs)
        actual_step_spec = self._sole(builder.build())
        actual_component_def = self._sole(component_defs)

        # (golden file, empty proto to parse into, actual value) triples,
        # asserted in the same order as the original checks.
        expectations = [
            ('expected_import_example_gen_component.pbtxt',
             pipeline_pb2.ComponentSpec(), actual_component_def),
            ('expected_import_example_gen_task.pbtxt',
             pipeline_pb2.PipelineTaskSpec(), actual_step_spec),
            ('expected_import_example_gen_executor.pbtxt',
             pipeline_pb2.PipelineDeploymentConfig(), deployment_config),
        ]
        for filename, empty_proto, actual in expectations:
            self.assertProtoEquals(
                test_utils.get_proto_from_test_data(filename, empty_proto),
                actual)
Esempio n. 3
0
  def _make_example_gen(self) -> base_component.BaseComponent:
    """Returns a TFX ExampleGen which produces the desired split.

    Builds one `Input.Split` per split reported by the dataset builder, using
    that split's single backing filename as the match pattern.

    Returns:
      An ImportExampleGen reading from the dataset builder's data_dir.

    Raises:
      ValueError: if any split is backed by more than (or fewer than) one
        file, since the single filename is used directly as the pattern.
    """
    splits = []
    for name, value in self._dataset_builder.info.splits.items():
      # Each split is expected to be materialized as exactly one file, e.g.
      # 'fashion_mnist-test.tfrecord-00000-of-00001'. Validate explicitly
      # rather than with `assert`, which is stripped under `python -O` and
      # would let a multi-file split silently use only its first file.
      if len(value.filenames) != 1:
        raise ValueError(
            'Expected exactly one file for split %r, got: %s' %
            (name, value.filenames))
      pattern = value.filenames[0]
      splits.append(example_gen_pb2.Input.Split(name=name, pattern=pattern))

    logging.info('Splits: %s', splits)
    input_config = example_gen_pb2.Input(splits=splits)
    return tfx.ImportExampleGen(
        input=external_input(self._dataset_builder.data_dir),
        input_config=input_config)