Example #1
def run_fn(fn_args: TrainerFnArgs):
    """Train the model based on given args.

    Args:
        fn_args: Holds args used to train the model as name/value pairs.
    """

    # todo: if your model uses hyperparameters, uncomment below
    # hparams = fn_args.hyperparameters
    # if isinstance(hparams, dict) and 'values' in hparams:
    #     hparams = hparams['values']

    # Need the schema? Uncomment below.
    # schema = schema_pb2.Schema()
    # schema_text = file_io.read_file_to_string(fn_args.schema_file)
    # text_format.Parse(schema_text, schema)
    # feature_spec = schema_utils.schema_as_feature_spec(schema).feature_spec

    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

    train_dataset = _input_fn(fn_args.train_files, tf_transform_output)
    eval_dataset = _input_fn(fn_args.eval_files, tf_transform_output)

    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        # Build the model inside the distribution strategy scope. If your
        # model takes hyperparameters, use the second variant instead.
        model = _build_keras_model()
        # model = _build_keras_model(hparams=hparams)

    try:
        log_dir = fn_args.model_run_dir
    except AttributeError:
        # model_run_dir may be absent on some TFX versions; fall back to a
        # logs directory next to the serving model.
        log_dir = os.path.join(os.path.dirname(
            fn_args.serving_model_dir), 'logs')

    # Write logs to path
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir, update_freq='batch')

    model.fit(
        train_dataset,
        steps_per_epoch=fn_args.train_steps,
        validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps,
        callbacks=[tensorboard_callback])

    signatures = {
        'serving_default':
        _get_serve_tf_examples_fn(model,
                                  tf_transform_output).get_concrete_function(
            tf.TensorSpec(
                shape=[None],
                dtype=tf.string,
                name='examples'))
    }
    model.save(fn_args.serving_model_dir,
               save_format='tf', signatures=signatures)

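# The helpers referenced above (_input_fn, _get_serve_tf_examples_fn and
# _build_keras_model) are not part of this snippet. Below is a minimal, hedged
# sketch of the first two, assuming the usual TFX/TFT pattern; _LABEL_KEY,
# _transformed_name and the batch size are illustrative placeholders, not
# values from the original module.

# Placeholder constants; the original module defines its own.
_LABEL_KEY = 'label'


def _transformed_name(key):
    return key + '_xf'


def _input_fn(file_pattern, tf_transform_output, batch_size=200):
    """Builds a batched tf.data.Dataset over the transformed examples."""
    transformed_feature_spec = (
        tf_transform_output.transformed_feature_spec().copy())
    return tf.data.experimental.make_batched_features_dataset(
        file_pattern=file_pattern,
        batch_size=batch_size,
        features=transformed_feature_spec,
        reader=lambda filenames: tf.data.TFRecordDataset(
            filenames, compression_type='GZIP'),
        label_key=_transformed_name(_LABEL_KEY))


def _get_serve_tf_examples_fn(model, tf_transform_output):
    """Returns a fn that parses serialized tf.Examples and applies TFT."""
    model.tft_layer = tf_transform_output.transform_features_layer()

    @tf.function
    def serve_tf_examples_fn(serialized_tf_examples):
        feature_spec = tf_transform_output.raw_feature_spec()
        feature_spec.pop(_LABEL_KEY)
        parsed_features = tf.io.parse_example(
            serialized_tf_examples, feature_spec)
        transformed_features = model.tft_layer(parsed_features)
        return model(transformed_features)

    return serve_tf_examples_fn
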

def get_pipeline():
    # todo: specify evaluation configuration (optional)
    # eval_config = tfma.EvalConfig(
    #     model_specs=[tfma.ModelSpec(label_key=LABEL_KEY)],
    #     slicing_specs=[tfma.SlicingSpec()],
    #     metrics_specs=[
    #         tfma.MetricsSpec(metrics=[
    #             tfma.MetricConfig(
    #                 class_name='BinaryAccuracy',
    #                 threshold=tfma.MetricThreshold(
    #                     value_threshold=tfma.GenericValueThreshold(
    #                         lower_bound={'value': 0.6}),
    #                     change_threshold=tfma.GenericChangeThreshold(
    #                         direction=tfma.MetricDirection.HIGHER_IS_BETTER,
    #                         absolute={'value': -1e-10})))
    #         ])
    #     ])

    # This path is relative to the working directory of the Python process
    # that runs the pipeline.
    # todo:
    # fn_file = 'examples/chicago_taxi_pipeline/pipeline.py'

    # todo: mix and match
    # data sources:
    # from csv
    # from tfrecords
    # from bigquery
    # from custom component

    # if you want to infer a schema, you have to generate statistics

    # model evaluation is not required for pusher, but strongly recommended

    # Depending on your target platform, you can pass custom parameters to
    # Trainer, Tuner and InfraValidator (GCP AI Platform supports this, for
    # example).

    return (
        ftfx.PipelineDef(name='<fill in>')
        #     .from_csv(uri=<specify>)
        #     .generate_statistics()
        #     .infer_schema()
        #     .preprocess(fn_file)
        #     .train(fn_file,
        #            train_args=trainer_pb2.TrainArgs(num_steps=1000),
        #            eval_args=trainer_pb2.EvalArgs(num_steps=150))
        #     .evaluate_model(eval_config=eval_config)
        #     .push_to(relative_push_uri='serving_model')
        #     .cache()
        #     .with_sqlite_ml_metadata()
        #     .with_beam_pipeline_args([
        #         '--direct_running_mode=multi_processing',
        #         '--direct_num_workers=0',
        #     ])
        .build())


if __name__ == '__main__':
    # absl.logging.set_verbosity(absl.logging.INFO)
    pipeline = get_pipeline()
    BeamDagRunner().run(pipeline)
Example #2
def run_pipeline_on_beam():
    """Runs the pipeline on Beam."""
    pipeline = create_pipeline()
    BeamDagRunner().run(pipeline)
Example #3
            example_validator,
            transform,
            trainer,
            model_resolver,
            evaluator,
            pusher,
        ],
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
        # TODO(b/142684737): The multi-processing API might change.
        beam_pipeline_args=['--direct_num_workers=%d' % direct_num_workers],
    )


# To run this pipeline from the python CLI:
#   $python iris_pipeline_native_keras.py
if __name__ == '__main__':
    absl.logging.set_verbosity(absl.logging.INFO)
    BeamDagRunner().run(
        _create_pipeline(
            pipeline_name=_pipeline_name,
            pipeline_root=_pipeline_root,
            data_root=_data_root,
            module_file=_module_file,
            serving_model_dir=_serving_model_dir,
            metadata_path=_metadata_path,
            # 0 means auto-detect based on the number of CPUs available during
            # execution time.
            direct_num_workers=0))
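# The direct_num_workers=0 above lets Beam auto-detect the worker count. If you
# prefer an explicit value, a minimal sketch of the usual pattern (the same one
# used in the last example on this page):
#
#     import multiprocessing
#     try:
#         parallelism = multiprocessing.cpu_count()
#     except NotImplementedError:
#         parallelism = 1
#     # ...then pass direct_num_workers=parallelism to _create_pipeline().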
Example #4
    def testStubbedImdbPipelineBeam(self):
        pipeline_ir = compiler.Compiler().compile(self.imdb_pipeline)

        pipeline_mock.replace_executor_with_stub(pipeline_ir,
                                                 self._recorded_output_dir, [])

        BeamDagRunner().run_with_ir(pipeline_ir)

        self.assertTrue(fileio.exists(self._metadata_path))

        metadata_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)

        # Verify that recorded files are successfully copied to the output uris.
        with metadata.Metadata(metadata_config) as m:
            for execution in m.store.get_executions():
                component_id = pipeline_recorder_utils.get_component_id_from_execution(
                    m, execution)
                if component_id.startswith('Resolver'):
                    continue
                eid = [execution.id]
                events = m.store.get_events_by_execution_ids(eid)
                output_events = [
                    x for x in events
                    if x.type == metadata_store_pb2.Event.OUTPUT
                ]
                for event in output_events:
                    steps = event.path.steps
                    assert steps[0].HasField('key')
                    name = steps[0].key
                    artifacts = m.store.get_artifacts_by_id(
                        [event.artifact_id])
                    for idx, artifact in enumerate(artifacts):
                        self.assertDirectoryEqual(
                            artifact.uri,
                            os.path.join(self._recorded_output_dir,
                                         component_id, name, str(idx)))

        # Calls verifier for pipeline output artifacts, excluding the resolver node.
        BeamDagRunner().run(self.imdb_pipeline)
        pipeline_outputs = executor_verifier_utils.get_pipeline_outputs(
            self.imdb_pipeline.metadata_connection_config, self._pipeline_name)

        verifier_map = {
            'model': self._verify_model,
            'model_run': self._verify_model,
            'examples': self._verify_examples,
            'schema': self._verify_schema,
            'anomalies': self._verify_anomalies,
            'evaluation': self._verify_evaluation,
            # A subdirectory of updated_analyzer_cache has changing name.
            'updated_analyzer_cache': self._veryify_root_dir,
        }

        # List of components to verify. Resolver is ignored because it
        # doesn't have an executor.
        verify_component_ids = [
            component.id for component in self.imdb_pipeline.components
            if not component.id.startswith('Resolver')
        ]

        for component_id in verify_component_ids:
            for key, artifact_dict in pipeline_outputs[component_id].items():
                for idx, artifact in artifact_dict.items():
                    logging.info('Verifying %s', component_id)
                    recorded_uri = os.path.join(self._recorded_output_dir,
                                                component_id, key, str(idx))
                    verifier_map.get(key, self._verify_file_path)(artifact.uri,
                                                                  recorded_uri)
Example #5
    def testTaxiPipelineWithImporter(self):
        BeamDagRunner().run(
            taxi_pipeline_importer._create_pipeline(
                pipeline_name=self._pipeline_name,
                data_root=self._data_root,
                user_schema_path=self._user_schema_path,
                module_file=self._module_file,
                serving_model_dir=self._serving_model_dir,
                pipeline_root=self._pipeline_root,
                metadata_path=self._metadata_path,
                beam_pipeline_args=[]))

        self.assertTrue(tf.io.gfile.exists(self._serving_model_dir))
        self.assertTrue(tf.io.gfile.exists(self._metadata_path))
        metadata_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)
        with metadata.Metadata(metadata_config) as m:
            artifact_count = len(m.store.get_artifacts())
            execution_count = len(m.store.get_executions())
            self.assertGreaterEqual(artifact_count, execution_count)
            self.assertEqual(10, execution_count)

        self.assertPipelineExecution()

        # Runs the pipeline again.
        BeamDagRunner().run(
            taxi_pipeline_importer._create_pipeline(
                pipeline_name=self._pipeline_name,
                data_root=self._data_root,
                user_schema_path=self._user_schema_path,
                module_file=self._module_file,
                serving_model_dir=self._serving_model_dir,
                pipeline_root=self._pipeline_root,
                metadata_path=self._metadata_path,
                beam_pipeline_args=[]))

        # All executions but Evaluator and Pusher are cached.
        # Note that Resolver will always execute.
        with metadata.Metadata(metadata_config) as m:
            # Artifact count is increased by 3 caused by Evaluator and Pusher.
            self.assertEqual(artifact_count + 3, len(m.store.get_artifacts()))
            artifact_count = len(m.store.get_artifacts())
            self.assertEqual(20, len(m.store.get_executions()))

        # Runs the pipeline the third time.
        BeamDagRunner().run(
            taxi_pipeline_importer._create_pipeline(
                pipeline_name=self._pipeline_name,
                data_root=self._data_root,
                user_schema_path=self._user_schema_path,
                module_file=self._module_file,
                serving_model_dir=self._serving_model_dir,
                pipeline_root=self._pipeline_root,
                metadata_path=self._metadata_path,
                beam_pipeline_args=[]))

        # Asserts cache execution.
        with metadata.Metadata(metadata_config) as m:
            # Artifact count is unchanged.
            self.assertEqual(artifact_count, len(m.store.get_artifacts()))
            self.assertEqual(30, len(m.store.get_executions()))
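        # assertPipelineExecution (used above) is defined elsewhere in this
        # test class. A hedged sketch of what such a helper typically verifies,
        # assuming each component writes outputs under the pipeline root (the
        # component list is illustrative, not taken from this test):
        #
        #     def assertPipelineExecution(self):
        #         for component_id in ('CsvExampleGen', 'StatisticsGen',
        #                              'SchemaGen', 'ExampleValidator',
        #                              'Transform', 'Trainer', 'Evaluator',
        #                              'Pusher'):
        #             component_path = os.path.join(self._pipeline_root,
        #                                           component_id)
        #             self.assertTrue(tf.io.gfile.exists(component_path))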
Example #6
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[
            example_gen,
            statistics_gen,
            infer_schema,
            validate_stats,
            tuner,
            trainer,
            model_analyzer,
            model_validator,
            pusher,
        ],
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
    )


# To run this pipeline from the python CLI:
#   $python iris_pipeline_tuner.py
if __name__ == '__main__':
    absl.logging.set_verbosity(absl.logging.INFO)
    BeamDagRunner().run(
        _create_pipeline(pipeline_name=_pipeline_name,
                         pipeline_root=_pipeline_root,
                         data_root=_data_root,
                         module_file=_module_file,
                         serving_model_dir=_serving_model_dir,
                         metadata_path=_metadata_path))
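# The tuner component listed above is constructed earlier in
# iris_pipeline_tuner.py and is not shown in this fragment. A hedged sketch of
# how a TFX Tuner is typically wired (argument values are illustrative):
#
#     from tfx.components import Tuner
#
#     tuner = Tuner(
#         module_file=module_file,
#         examples=transform.outputs['transformed_examples'],
#         transform_graph=transform.outputs['transform_graph'],
#         train_args=trainer_pb2.TrainArgs(num_steps=20),
#         eval_args=trainer_pb2.EvalArgs(num_steps=5))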
Example #7
    def testIrisPipelineNativeKeras(self):
        BeamDagRunner().run(
            iris_pipeline_native_keras._create_pipeline(
                pipeline_name=self._pipeline_name,
                data_root=self._data_root,
                module_file=self._module_file,
                serving_model_dir=self._serving_model_dir,
                pipeline_root=self._pipeline_root,
                metadata_path=self._metadata_path,
                direct_num_workers=1))

        self.assertTrue(tf.io.gfile.exists(self._serving_model_dir))
        self.assertTrue(tf.io.gfile.exists(self._metadata_path))
        expected_execution_count = 10  # 9 components + 1 resolver
        metadata_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)
        with metadata.Metadata(metadata_config) as m:
            artifact_count = len(m.store.get_artifacts())
            execution_count = len(m.store.get_executions())
            self.assertGreaterEqual(artifact_count, execution_count)
            self.assertEqual(expected_execution_count, execution_count)

        self.assertPipelineExecution()
        self.assertInfraValidatorPassed()

        # Runs pipeline the second time.
        BeamDagRunner().run(
            iris_pipeline_native_keras._create_pipeline(
                pipeline_name=self._pipeline_name,
                data_root=self._data_root,
                module_file=self._module_file,
                serving_model_dir=self._serving_model_dir,
                pipeline_root=self._pipeline_root,
                metadata_path=self._metadata_path,
                direct_num_workers=1))

        # All executions but Evaluator and Pusher are cached.
        with metadata.Metadata(metadata_config) as m:
            # Artifact count is increased by 3 caused by Evaluator and Pusher.
            self.assertEqual(artifact_count + 3, len(m.store.get_artifacts()))
            artifact_count = len(m.store.get_artifacts())
            self.assertEqual(expected_execution_count * 2,
                             len(m.store.get_executions()))

        # Runs pipeline the third time.
        BeamDagRunner().run(
            iris_pipeline_native_keras._create_pipeline(
                pipeline_name=self._pipeline_name,
                data_root=self._data_root,
                module_file=self._module_file,
                serving_model_dir=self._serving_model_dir,
                pipeline_root=self._pipeline_root,
                metadata_path=self._metadata_path,
                direct_num_workers=1))

        # Asserts cache execution.
        with metadata.Metadata(metadata_config) as m:
            # Artifact count is unchanged.
            self.assertEqual(artifact_count, len(m.store.get_artifacts()))
            self.assertEqual(expected_execution_count * 3,
                             len(m.store.get_executions()))
Example #8
  save_summary_steps=_SAVE_SUMMARY_STEPS,
  save_checkpoints_secs=_SAVE_CHECKPOINT_SECS
)

preprocessing_fn = preprocess.preprocess_factory(
  categorical_feature_keys=full._CATEGORICAL_FEATURE_KEYS,
  numerical_feature_keys=full._NUMERICAL_FEATURE_KEYS,
  label_key=full._LABEL_KEY,
)

beam_pipeline_args = [
  '--project=' + _GCP_PROJECT
]

if __name__ == "__main__":
  DAG = BeamDagRunner().run(
    create_pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      pipeline_mod=pipeline_mod,
      schema_uri=full.schema_uri,
      transform_graph_uri=full.transform_graph_uri,
      model_uri=full.model_uri,
      query=_QUERY,
      num_train_steps=_NUM_TRAIN_STEPS,
      num_eval_steps=_NUM_EVAL_STEPS,
      beam_pipeline_args=beam_pipeline_args,
      metadata_path=os.path.join(pipeline_root, 'metadata', 'metadata.db')
      )
  )
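# A minimal, hedged sketch of what a factory like preprocess.preprocess_factory
# might return: a TFT preprocessing_fn that z-score scales numeric features and
# vocabulary-encodes categorical ones. The chosen transforms and the helper
# name are illustrative assumptions, not the original implementation.
import tensorflow_transform as tft


def _example_preprocess_factory(categorical_feature_keys,
                                numerical_feature_keys, label_key):
  def preprocessing_fn(inputs):
    outputs = {}
    for key in numerical_feature_keys:
      # Scale numeric features to zero mean and unit variance.
      outputs[key] = tft.scale_to_z_score(inputs[key])
    for key in categorical_feature_keys:
      # Map categorical values to integer ids via a learned vocabulary.
      outputs[key] = tft.compute_and_apply_vocabulary(inputs[key])
    outputs[label_key] = inputs[label_key]
    return outputs

  return preprocessing_fn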
Example #9
                          module_file=module_file)

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[
            example_gen,
            statistics_gen,
            schema_gen,
            transform,
        ],
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
        enable_cache=True,
        beam_pipeline_args=['--direct_num_workers=%d' % direct_num_workers],
    )


if __name__ == '__main__':
    absl.logging.set_verbosity(absl.logging.INFO)
    BeamDagRunner().run(
        _create_pipeline(
            pipeline_name=_pipeline_name,
            pipeline_root=_pipeline_root,
            data_root=_data_root,
            module_file=_module_file,
            serving_model_dir=_serving_model_dir,
            metadata_path=_metadata_path,
            # An error occurs if direct_num_workers is set to 0.
            direct_num_workers=1))
Example #10
  def testIrisPipelineSklearn(self, mock_runner):
    BeamDagRunner().run(
        iris_pipeline_sklearn._create_pipeline(
            pipeline_name=self._pipeline_name,
            data_root=self._data_root,
            module_file=self._module_file,
            serving_model_dir=self._serving_model_dir,
            pipeline_root=self._pipeline_root,
            metadata_path=self._metadata_path,
            ai_platform_serving_args=self._ai_platform_serving_args,
            direct_num_workers=1))

    self.assertTrue(tf.io.gfile.exists(self._serving_model_dir))
    self.assertTrue(tf.io.gfile.exists(self._metadata_path))
    mock_runner.deploy_model_for_aip_prediction.assert_called_once()
    expected_execution_count = 8  # 8 components
    metadata_config = metadata.sqlite_metadata_connection_config(
        self._metadata_path)
    with metadata.Metadata(metadata_config) as m:
      artifact_count = len(m.store.get_artifacts())
      execution_count = len(m.store.get_executions())
      self.assertGreaterEqual(artifact_count, execution_count)
      self.assertEqual(expected_execution_count, execution_count)

    self.assertPipelineExecution()

    # Runs pipeline the second time.
    BeamDagRunner().run(
        iris_pipeline_sklearn._create_pipeline(
            pipeline_name=self._pipeline_name,
            data_root=self._data_root,
            module_file=self._module_file,
            serving_model_dir=self._serving_model_dir,
            pipeline_root=self._pipeline_root,
            metadata_path=self._metadata_path,
            ai_platform_serving_args=self._ai_platform_serving_args,
            direct_num_workers=1))

    # All executions but Evaluator and Pusher are cached.
    with metadata.Metadata(metadata_config) as m:
      self.assertEqual(artifact_count, len(m.store.get_artifacts()))
      artifact_count = len(m.store.get_artifacts())
      self.assertEqual(expected_execution_count * 2,
                       len(m.store.get_executions()))

    # Runs pipeline the third time.
    BeamDagRunner().run(
        iris_pipeline_sklearn._create_pipeline(
            pipeline_name=self._pipeline_name,
            data_root=self._data_root,
            module_file=self._module_file,
            serving_model_dir=self._serving_model_dir,
            pipeline_root=self._pipeline_root,
            metadata_path=self._metadata_path,
            ai_platform_serving_args=self._ai_platform_serving_args,
            direct_num_workers=1))

    # Asserts cache execution.
    with metadata.Metadata(metadata_config) as m:
      # Artifact count is unchanged.
      self.assertEqual(artifact_count, len(m.store.get_artifacts()))
      self.assertEqual(expected_execution_count * 3,
                       len(m.store.get_executions()))
Example #11
  vc.save_summary_steps = 1
  vc.save_checkpoints_secs = 14400
  vc.learning_rate = 2e-5

  var_names = pipeline_var_names(
    vc.run_dir,
    vc.run_str,
    vc.mlp_project,
    vc.mlp_subproject,
    vc.runner,
    vc.pipeline_type
  )
  vc.add_vars(**var_names)

  vc.base_model_uri = latest_artifact_path(prev_run_root, 'data/Trainer/model')
  vc.write(vc.vc_config_path)

  DAG = BeamDagRunner().run(
    create_pipeline(
      prev_run_root=vc.prev_run_root,
      run_root=vc.run_root,
      pipeline_name=vc.pipeline_name,
      pipeline_mod=vc.pipeline_mod,
      query=vc.query,
      base_model_uri=vc.base_model_uri,
      beam_pipeline_args=vc.beam_pipeline_args,
      metadata_path=vc.metadata_path,
      custom_config=vc.get_vars(),
    )
  )
Example #12
import datetime
import logging
import shutil

from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner

from beam_local import config, builder

conf = config.load()

# Timestamp like '202001311530' (year, month, day, hour, minute).
timestamp = datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d%H%M')

if __name__ == "__main__":

    # Optionally clear any previous ML Metadata store before a fresh run:
    # try:
    #     shutil.rmtree(conf['metadata_path'])
    # except FileNotFoundError:
    #     pass

    logging.basicConfig(level='INFO')

    tfx_pipeline = builder.build_pipeline(timestamp)

    BeamDagRunner().run(tfx_pipeline)
Example #13
    def testTaxiPipelineBeam(self):
        num_components = 10

        BeamDagRunner().run(
            taxi_pipeline_infraval_beam._create_pipeline(
                pipeline_name=self._pipeline_name,
                data_root=self._data_root,
                module_file=self._module_file,
                serving_model_dir=self._serving_model_dir,
                pipeline_root=self._pipeline_root,
                metadata_path=self._metadata_path,
                beam_pipeline_args=[]))

        self.assertTrue(tf.io.gfile.exists(self._serving_model_dir))
        self.assertTrue(tf.io.gfile.exists(self._metadata_path))
        metadata_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)
        with metadata.Metadata(metadata_config) as m:
            artifact_count = len(m.store.get_artifacts())
            execution_count = len(m.store.get_executions())
            self.assertGreaterEqual(artifact_count, execution_count)
            self.assertEqual(num_components, execution_count)

        self.assertPipelineExecution()
        self.assertInfraValidatorPassed()

        # Runs pipeline the second time.
        BeamDagRunner().run(
            taxi_pipeline_infraval_beam._create_pipeline(
                pipeline_name=self._pipeline_name,
                data_root=self._data_root,
                module_file=self._module_file,
                serving_model_dir=self._serving_model_dir,
                pipeline_root=self._pipeline_root,
                metadata_path=self._metadata_path,
                beam_pipeline_args=[]))

        # All executions but Evaluator and Pusher are cached.
        # Note that Resolver will always execute.
        with metadata.Metadata(metadata_config) as m:
            # Artifact count is increased by 3 caused by Evaluator and Pusher.
            self.assertEqual(artifact_count + 3, len(m.store.get_artifacts()))
            artifact_count = len(m.store.get_artifacts())
            # 10 more cached executions.
            self.assertEqual(num_components * 2, len(m.store.get_executions()))

        # Runs pipeline the third time.
        BeamDagRunner().run(
            taxi_pipeline_infraval_beam._create_pipeline(
                pipeline_name=self._pipeline_name,
                data_root=self._data_root,
                module_file=self._module_file,
                serving_model_dir=self._serving_model_dir,
                pipeline_root=self._pipeline_root,
                metadata_path=self._metadata_path,
                beam_pipeline_args=[]))

        # Asserts cache execution.
        with metadata.Metadata(metadata_config) as m:
            # Artifact count is unchanged.
            self.assertEqual(artifact_count, len(m.store.get_artifacts()))
            # 10 more cached executions.
            self.assertEqual(num_components * 3, len(m.store.get_executions()))
Example #14
                            base_directory=serving_model_dir_lite)))

    components = [
        example_gen, statistics_gen, schema_gen, example_validator, transform,
        trainer, model_resolver, evaluator, pusher
    ]

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=components,
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
        beam_pipeline_args=beam_pipeline_args)


# To run this pipeline from the python CLI:
#   $python cifar_pipeline_native_keras.py
if __name__ == '__main__':
    absl.logging.set_verbosity(absl.logging.INFO)
    BeamDagRunner().run(
        _create_pipeline(pipeline_name=_pipeline_name,
                         pipeline_root=_pipeline_root,
                         data_root=_data_root,
                         module_file=_module_file,
                         serving_model_dir_lite=_serving_model_dir_lite,
                         metadata_path=_metadata_path,
                         labels_path=_labels_path,
                         beam_pipeline_args=_beam_pipeline_args))
Example #15
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=[
          training_example_gen, inference_example_gen, statistics_gen,
          infer_schema, validate_stats, transform, trainer, model_analyzer,
          model_validator, bulk_inferrer
      ],
      enable_cache=True,
      metadata_connection_config=metadata.sqlite_metadata_connection_config(
          metadata_path),
      # TODO(b/141578059): The multi-processing API might change.
      beam_pipeline_args=['--direct_num_workers=%d' % direct_num_workers])


# To run this pipeline from the python CLI:
#   $python taxi_pipeline_with_inference.py
if __name__ == '__main__':
  absl.logging.set_verbosity(absl.logging.INFO)

  BeamDagRunner().run(
      _create_pipeline(
          pipeline_name=_pipeline_name,
          pipeline_root=_pipeline_root,
          training_data_root=_training_data_root,
          inference_data_root=_inference_data_root,
          module_file=_module_file,
          metadata_path=_metadata_path,
          # 0 means auto-detect based on the number of CPUs available during
          # execution time.
          direct_num_workers=0))
Example #16
        data_spec=bulk_inferrer_pb2.DataSpec(),
        model_spec=bulk_inferrer_pb2.ModelSpec())

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[
            training_example_gen, inference_example_gen, statistics_gen,
            schema_gen, example_validator, transform, trainer, model_resolver,
            evaluator, bulk_inferrer
        ],
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
        beam_pipeline_args=beam_pipeline_args)


# To run this pipeline from the python CLI:
#   $python taxi_pipeline_with_inference.py
if __name__ == '__main__':
    absl.logging.set_verbosity(absl.logging.INFO)

    BeamDagRunner().run(
        _create_pipeline(pipeline_name=_pipeline_name,
                         pipeline_root=_pipeline_root,
                         training_data_root=_training_data_root,
                         inference_data_root=_inference_data_root,
                         module_file=_module_file,
                         metadata_path=_metadata_path,
                         beam_pipeline_args=_beam_pipeline_args))
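# The snippet above opens mid-way through constructing the bulk_inferrer
# component. A hedged sketch of the full construction, assuming the standard
# TFX BulkInferrer wiring (the input names are illustrative):
#
#     from tfx.components import BulkInferrer
#
#     bulk_inferrer = BulkInferrer(
#         examples=inference_example_gen.outputs['examples'],
#         model=trainer.outputs['model'],
#         model_blessing=evaluator.outputs['blessing'],
#         data_spec=bulk_inferrer_pb2.DataSpec(),
#         model_spec=bulk_inferrer_pb2.ModelSpec())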
Example #17
  def generate_models(self, args, force_tf_compat_v1=True):
    # Modified version of Chicago Taxi Example pipeline
    # tfx/examples/chicago_taxi_pipeline/taxi_pipeline_beam.py

    root = tempfile.mkdtemp()
    pipeline_root = os.path.join(root, "pipeline")
    metadata_path = os.path.join(root, "metadata/metadata.db")
    module_file = os.path.join(
        os.path.dirname(__file__),
        "../../../examples/chicago_taxi_pipeline/taxi_utils.py")

    examples = external_input(os.path.dirname(self.dataset_path()))
    example_gen = components.ImportExampleGen(input=examples)
    statistics_gen = components.StatisticsGen(
        examples=example_gen.outputs["examples"])
    schema_gen = components.SchemaGen(
        statistics=statistics_gen.outputs["statistics"],
        infer_feature_shape=False)
    transform = components.Transform(
        examples=example_gen.outputs["examples"],
        schema=schema_gen.outputs["schema"],
        module_file=module_file,
        force_tf_compat_v1=force_tf_compat_v1)
    trainer = components.Trainer(
        module_file=module_file,
        transformed_examples=transform.outputs["transformed_examples"],
        schema=schema_gen.outputs["schema"],
        transform_graph=transform.outputs["transform_graph"],
        train_args=trainer_pb2.TrainArgs(num_steps=100),
        eval_args=trainer_pb2.EvalArgs(num_steps=50))
    p = pipeline.Pipeline(
        pipeline_name="chicago_taxi_beam",
        pipeline_root=pipeline_root,
        components=[
            example_gen, statistics_gen, schema_gen, transform, trainer
        ],
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path))
    BeamDagRunner().run(p)

    def join_unique_subdir(path):
      dirs = os.listdir(path)
      if len(dirs) != 1:
        raise ValueError(
            "expecting there to be only one subdirectory in %s, but "
            "subdirectories were: %s" % (path, dirs))
      return os.path.join(path, dirs[0])

    trainer_output_dir = join_unique_subdir(
        os.path.join(pipeline_root, "Trainer/model"))
    eval_model_dir = join_unique_subdir(
        os.path.join(trainer_output_dir, "eval_model_dir"))
    serving_model_dir = join_unique_subdir(
        os.path.join(trainer_output_dir,
                     "serving_model_dir/export/chicago-taxi"))
    transform_output_dir = join_unique_subdir(
        os.path.join(pipeline_root, "Transform/transform_graph"))
    transform_model_dir = os.path.join(transform_output_dir, "transform_fn")
    tft_saved_model_path = self.tft_saved_model_path(force_tf_compat_v1)

    shutil.rmtree(self.trained_saved_model_path(), ignore_errors=True)
    shutil.rmtree(self.tfma_saved_model_path(), ignore_errors=True)
    shutil.rmtree(tft_saved_model_path, ignore_errors=True)
    shutil.copytree(serving_model_dir, self.trained_saved_model_path())
    shutil.copytree(eval_model_dir, self.tfma_saved_model_path())
    shutil.copytree(transform_model_dir, tft_saved_model_path)
Example #18
    # Define the training/model parameters
    vc.hidden_layer_dims = [10]
    vc.batch_size = 32
    vc.num_train_steps = 1000
    vc.num_eval_steps = 100
    vc.warmup_prop = 0.1
    vc.cooldown_prop = 0.1
    vc.warm_start_from = None
    vc.save_summary_steps = 100
    vc.save_checkpoints_secs = 3600
    vc.learning_rate = 2e-5
    vc.num_gpus = 1

    vc.add_vars(
        **pipeline_var_names(vc.run_dir, vc.run_str, vc.mlp_project,
                             vc.mlp_subproject, vc.runner, vc.pipeline_type))

    vc.beam_pipeline_args = ['--project=' + vc.gcp_project]

    vc.write(vc.vc_config_path)
    DAG = BeamDagRunner().run(
        create_pipeline(run_root=vc.run_root,
                        pipeline_name=vc.pipeline_name,
                        pipeline_mod=vc.pipeline_mod,
                        query=vc.query,
                        beam_pipeline_args=vc.beam_pipeline_args,
                        metadata_path=os.path.join(vc.run_root, 'metadata',
                                                   'metadata.db'),
                        custom_config=vc.get_vars()))
Example #19
    'batch_size': 1,  # Batch size
    'learning_rate': 0.00002,  # Learning rate for Adam
    'accumulate_gradients': 1,  # Accumulate gradients across N minibatches.
    'memory_saving_gradients':
    False,  # Use gradient checkpointing to reduce vram usage.
    'only_train_transformer_layers':
    False,  # Restrict training to the transformer blocks.
    'optimizer': 'adam',  # Optimizer. <adam|sgd>.
    'noise':
    0.0,  # Add noise to input training data to regularize against typos.
    'top_k': 40,  # K for top-k sampling.
    'top_p': 0.0,  # P for top-p sampling. Overrides top_k if set > 0.
    'sample_every': 100,  # Generate samples every N steps
    'sample_length': 1023,  # Sample this many tokens
    'sample_num': 1,  # Generate this many samples
    'save_every': 1000,  # Write a checkpoint every N steps
}

output_dir = "./output"

pipeline = create_pipeline(pipeline_name=os.path.basename(__file__),
                           pipeline_root=output_dir,
                           model_name=model_name,
                           train_config=train_config,
                           mongo_colnames=mongo_colnames,
                           mongo_ip=mongo_ip,
                           enable_cache=True,
                           mlflow_tracking_url=mlflow_tracking_url)

BeamDagRunner().run(pipeline)
Example #20
    def testTaxiPipelineBeam(self):
        # Runs the pipeline and record to self._recorded_output_dir
        record_taxi_pipeline = taxi_pipeline_beam._create_pipeline(  # pylint:disable=protected-access
            pipeline_name=self._pipeline_name,
            data_root=self._data_root,
            module_file=self._module_file,
            serving_model_dir=self._serving_model_dir,
            pipeline_root=self._pipeline_root,
            metadata_path=self._recorded_mlmd_path,
            beam_pipeline_args=[])
        BeamDagRunner().run(record_taxi_pipeline)
        pipeline_recorder_utils.record_pipeline(
            output_dir=self._recorded_output_dir,
            metadata_db_uri=self._recorded_mlmd_path,
            host=None,
            port=None,
            pipeline_name=self._pipeline_name,
            run_id=None)

        # Run pipeline with stub executors.
        taxi_pipeline = taxi_pipeline_beam._create_pipeline(  # pylint:disable=protected-access
            pipeline_name=self._pipeline_name,
            data_root=self._data_root,
            module_file=self._module_file,
            serving_model_dir=self._serving_model_dir,
            pipeline_root=self._pipeline_root,
            metadata_path=self._metadata_path,
            beam_pipeline_args=[])

        model_resolver_id = 'ResolverNode.latest_blessed_model_resolver'
        stubbed_component_ids = [
            component.id for component in taxi_pipeline.components
            if component.id != model_resolver_id
        ]

        stub_launcher = stub_component_launcher.get_stub_launcher_class(
            test_data_dir=self._recorded_output_dir,
            stubbed_component_ids=stubbed_component_ids,
            stubbed_component_map={})
        stub_pipeline_config = pipeline_config.PipelineConfig(
            supported_launcher_classes=[
                stub_launcher,
            ])
        BeamDagRunner(config=stub_pipeline_config).run(taxi_pipeline)

        self.assertTrue(tf.io.gfile.exists(self._metadata_path))

        metadata_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)

        # Verify that recorded files are successfully copied to the output uris.
        with metadata.Metadata(metadata_config) as m:
            artifacts = m.store.get_artifacts()
            artifact_count = len(artifacts)
            executions = m.store.get_executions()
            execution_count = len(executions)
            # Artifact count is greater by 3: Evaluator (blessing and
            # evaluation), Trainer (model and model_run) and Transform
            # (examples, graph, cache) produce extra artifacts, while the
            # Resolver doesn't generate a new artifact.
            self.assertEqual(artifact_count, execution_count + 3)
            self.assertLen(taxi_pipeline.components, execution_count)

            for execution in executions:
                component_id = execution.properties[
                    metadata._EXECUTION_TYPE_KEY_COMPONENT_ID].string_value  # pylint: disable=protected-access
                if component_id == 'ResolverNode.latest_blessed_model_resolver':
                    continue
                eid = [execution.id]
                events = m.store.get_events_by_execution_ids(eid)
                output_events = [
                    x for x in events
                    if x.type == metadata_store_pb2.Event.OUTPUT
                ]
                for event in output_events:
                    steps = event.path.steps
                    self.assertTrue(steps[0].HasField('key'))
                    name = steps[0].key
                    artifacts = m.store.get_artifacts_by_id(
                        [event.artifact_id])
                    for idx, artifact in enumerate(artifacts):
                        self.assertDirectoryEqual(
                            artifact.uri,
                            os.path.join(self._recorded_output_dir,
                                         component_id, name, str(idx)))
Example #21
                            base_directory=serving_model_dir)))

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[
            example_gen, statistics_gen, user_schema_importer, schema_gen,
            example_validator, transform, trainer, model_resolver, evaluator,
            pusher
        ],
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
        beam_pipeline_args=beam_pipeline_args)


# To run this pipeline from the python CLI:
#   $python taxi_pipeline_beam.py
if __name__ == '__main__':
    absl.logging.set_verbosity(absl.logging.INFO)

    BeamDagRunner().run(
        _create_pipeline(pipeline_name=_pipeline_name,
                         pipeline_root=_pipeline_root,
                         data_root=_data_root,
                         user_schema_path=_user_schema_path,
                         module_file=_module_file,
                         serving_model_dir=_serving_model_dir,
                         metadata_path=_metadata_path,
                         beam_pipeline_args=_beam_pipeline_args))
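# user_schema_importer (listed in the components above) is created earlier in
# this module. A hedged sketch of the usual pattern for importing a
# hand-curated schema with TFX's ImporterNode (names are illustrative):
#
#     from tfx.components import ImporterNode
#     from tfx.types import standard_artifacts
#
#     user_schema_importer = ImporterNode(
#         instance_name='import_user_schema',
#         source_uri=user_schema_path,
#         artifact_type=standard_artifacts.Schema)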
Example #22
    def testStubbedImdbPipelineBeam(self):
        # Runs the pipeline and record to self._recorded_output_dir
        stub_component_launcher.StubComponentLauncher.initialize(
            test_data_dir=self._recorded_output_dir, test_component_ids=[])

        stub_pipeline_config = pipeline_config.PipelineConfig(
            supported_launcher_classes=[
                stub_component_launcher.StubComponentLauncher,
            ])
        BeamDagRunner(config=stub_pipeline_config).run(self.imdb_pipeline)

        self.assertTrue(fileio.exists(self._metadata_path))

        metadata_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)

        # Verify that recorded files are successfully copied to the output uris.
        with metadata.Metadata(metadata_config) as m:
            for execution in m.store.get_executions():
                component_id = execution.properties[
                    metadata._EXECUTION_TYPE_KEY_COMPONENT_ID].string_value  # pylint: disable=protected-access
                if component_id.startswith('ResolverNode'):
                    continue
                eid = [execution.id]
                events = m.store.get_events_by_execution_ids(eid)
                output_events = [
                    x for x in events
                    if x.type == metadata_store_pb2.Event.OUTPUT
                ]
                for event in output_events:
                    steps = event.path.steps
                    assert steps[0].HasField('key')
                    name = steps[0].key
                    artifacts = m.store.get_artifacts_by_id(
                        [event.artifact_id])
                    for idx, artifact in enumerate(artifacts):
                        self.assertDirectoryEqual(
                            artifact.uri,
                            os.path.join(self._recorded_output_dir,
                                         component_id, name, str(idx)))

        # Calls verifier for pipeline output artifacts, excluding the resolver node.
        BeamDagRunner().run(self.imdb_pipeline)
        pipeline_outputs = executor_verifier_utils.get_pipeline_outputs(
            self.imdb_pipeline.metadata_connection_config,
            self.imdb_pipeline.pipeline_info)

        verifier_map = {
            'model': self._verify_model,
            'model_run': self._verify_model,
            'examples': self._verify_examples,
            'schema': self._verify_schema,
            'anomalies': self._verify_anomalies,
            'evaluation': self._verify_evaluation
        }

        # List of components to verify. ResolverNode is ignored because it
        # doesn't have an executor.
        verify_component_ids = [
            component.id for component in self.imdb_pipeline.components
            if not component.id.startswith('ResolverNode')
        ]

        for component_id in verify_component_ids:
            for key, artifact_dict in pipeline_outputs[component_id].items():
                for idx, artifact in artifact_dict.items():
                    logging.info('Verifying %s', component_id)
                    recorded_uri = os.path.join(self._recorded_output_dir,
                                                component_id, key, str(idx))
                    verifier_map.get(key, self._verify_file_path)(artifact.uri,
                                                                  recorded_uri)
Example #23
        model_resolver,
        evaluator,
        pusher,
    ]
    if enable_tuning:
        components.append(tuner)

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=components,
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
        beam_pipeline_args=beam_pipeline_args)


# To run this pipeline from the python CLI:
#   $python iris_pipeline_native_keras.py
if __name__ == '__main__':
    absl.logging.set_verbosity(absl.logging.INFO)
    BeamDagRunner().run(
        _create_pipeline(pipeline_name=_pipeline_name,
                         pipeline_root=_pipeline_root,
                         data_root=_data_root,
                         module_file=_module_file,
                         serving_model_dir=_serving_model_dir,
                         metadata_path=_metadata_path,
                         enable_tuning=True,
                         beam_pipeline_args=_beam_pipeline_args))
Example #24
    def testStubbedTaxiPipelineBeam(self):
        pipeline_ir = compiler.Compiler().compile(self.taxi_pipeline)

        logging.info('Replacing with test_data_dir:%s',
                     self._recorded_output_dir)
        pipeline_mock.replace_executor_with_stub(pipeline_ir,
                                                 self._recorded_output_dir, [])

        BeamDagRunner().run_with_ir(pipeline_ir)

        self.assertTrue(fileio.exists(self._metadata_path))

        metadata_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)

        # Verify that recorded files are successfully copied to the output uris.
        with metadata.Metadata(metadata_config) as m:
            artifacts = m.store.get_artifacts()
            artifact_count = len(artifacts)
            executions = m.store.get_executions()
            execution_count = len(executions)
            # Artifact count is greater by 7: Evaluator (blessing and
            # evaluation), Trainer (model and model_run) and Transform
            # (examples, graph, cache, pre_transform_statistics,
            # pre_transform_schema, post_transform_statistics,
            # post_transform_schema, post_transform_anomalies) produce extra
            # artifacts, while the Resolver doesn't generate a new artifact.
            self.assertEqual(artifact_count, execution_count + 7)
            self.assertLen(self.taxi_pipeline.components, execution_count)

            for execution in executions:
                component_id = pipeline_recorder_utils.get_component_id_from_execution(
                    m, execution)
                if component_id.startswith('Resolver'):
                    continue
                eid = [execution.id]
                events = m.store.get_events_by_execution_ids(eid)
                output_events = [
                    x for x in events
                    if x.type == metadata_store_pb2.Event.OUTPUT
                ]
                for event in output_events:
                    steps = event.path.steps
                    self.assertTrue(steps[0].HasField('key'))
                    name = steps[0].key
                    artifacts = m.store.get_artifacts_by_id(
                        [event.artifact_id])
                    for idx, artifact in enumerate(artifacts):
                        self.assertDirectoryEqual(
                            artifact.uri,
                            os.path.join(self._recorded_output_dir,
                                         component_id, name, str(idx)))

        # Calls verifier for pipeline output artifacts, excluding the resolver node.
        BeamDagRunner().run(self.taxi_pipeline)
        pipeline_outputs = executor_verifier_utils.get_pipeline_outputs(
            self.taxi_pipeline.metadata_connection_config, self._pipeline_name)

        verifier_map = {
            'model': self._verify_model,
            'model_run': self._verify_model,
            'examples': self._verify_examples,
            'schema': self._verify_schema,
            'anomalies': self._verify_anomalies,
            'evaluation': self._verify_evaluation,
            # A subdirectory of updated_analyzer_cache has changing name.
            'updated_analyzer_cache': self._veryify_root_dir,
        }

        # List of components to verify. Resolver is ignored because it
        # doesn't have an executor.
        verify_component_ids = [
            component.id for component in self.taxi_pipeline.components
            if not component.id.startswith('Resolver')
        ]

        for component_id in verify_component_ids:
            logging.info('Verifying %s', component_id)
            for key, artifact_dict in pipeline_outputs[component_id].items():
                for idx, artifact in artifact_dict.items():
                    recorded_uri = os.path.join(self._recorded_output_dir,
                                                component_id, key, str(idx))
                    verifier_map.get(key, self._verify_file_path)(artifact.uri,
                                                                  recorded_uri)
Example #25
    example_gen = CsvExampleGen(input_base=examples)

    # Computes statistics over data for visualization and example validation.
    statistics_gen = StatisticsGen(input_data=example_gen.outputs['examples'])

    # Generates schema based on statistics files.
    infer_schema = SchemaGen(stats=statistics_gen.outputs['output'])

    # Performs anomaly detection based on statistics and data schema.
    validate_stats = ExampleValidator(stats=statistics_gen.outputs['output'],
                                      schema=infer_schema.outputs['output'])

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[example_gen, statistics_gen, infer_schema, validate_stats],
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
        additional_pipeline_args={},
    )


if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    BeamDagRunner().run(
        _create_pipeline(pipeline_name=_pipeline_name,
                         pipeline_root=_pipeline_root,
                         data_root=_data_root,
                         metadata_path=_metadata_path))
Example #26
    def testCustomComponentSchemaE2EFTFXPipeline(self):
        pipeline_def = get_pipeline(self._initial_pipeline_def)
        pipeline_def = pipeline_def.with_sqlite_ml_metadata().cache()

        expected_execution_count = 5  # 4 components + schema importer
        self.assertLen(pipeline_def.components.values(),
                       expected_execution_count)

        self.assertTrue('example_gen' in pipeline_def.components)
        self.assertIsInstance(
            pipeline_def.components['example_gen'], tfx.components.CsvExampleGen)

        self.assertTrue('statistics_gen' in pipeline_def.components)
        self.assertIsInstance(
            pipeline_def.components['statistics_gen'], tfx.components.StatisticsGen)

        self.assertTrue('schema_gen' in pipeline_def.components)
        self.assertIsInstance(
            pipeline_def.components['schema_gen'], tfx.components.SchemaGen)

        self.assertTrue('user_schema_importer' in pipeline_def.components)
        self.assertIsInstance(
            pipeline_def.components['user_schema_importer'], tfx.components.ImporterNode)

        self.assertTrue('schema_printer' in pipeline_def.components)
        self.assertIsInstance(
            pipeline_def.components['schema_printer'], tfx.components.base.base_component.BaseComponent)

        pipeline = pipeline_def.build()

        BeamDagRunner().run(pipeline)

        artifact_root = os.path.join(
            self._bucket_dir, pipeline_def.pipeline_name)
        self.assertTrue(
            os.path.exists(os.path.join(artifact_root, 'metadata.db')))
        metadata_config = pipeline_def.metadata_connection_config
        with metadata.Metadata(metadata_config) as m:
            artifact_count = len(m.store.get_artifacts())
            execution_count = len(m.store.get_executions())
            self.assertLessEqual(artifact_count, execution_count)
            self.assertEqual(expected_execution_count, execution_count)

        self.assertPipelineExecution()

        BeamDagRunner().run(pipeline)

        # All executions are cached
        with metadata.Metadata(metadata_config) as m:
            artifact_count = len(m.store.get_artifacts())
            self.assertEqual(artifact_count, len(m.store.get_artifacts()))
            self.assertEqual(expected_execution_count * 2,
                             len(m.store.get_executions()))

        BeamDagRunner().run(pipeline)

        # Asserts cache execution.
        with metadata.Metadata(metadata_config) as m:
            # Artifact count is unchanged.
            self.assertEqual(artifact_count, len(m.store.get_artifacts()))
            self.assertEqual(expected_execution_count * 3,
                             len(m.store.get_executions()))
Example #27
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
        beam_pipeline_args=beam_pipeline_args)


# To run this pipeline from the python CLI:
#   $python taxi_pipeline_beam.py
if __name__ == '__main__':
    absl.logging.set_verbosity(absl.logging.INFO)

    parser = argparse_flags.ArgumentParser()
    parser.add_argument(
        '--runner',
        type=str,
        default='DirectRunner',
        choices=['DirectRunner', 'FlinkRunner', 'SparkRunner'],
        help='The Beam runner to execute Beam-powered components. '
        'For FlinkRunner or SparkRunner, first run setup/setup_beam_on_flink.sh '
        'or setup/setup_beam_on_spark.sh, respectively.')
    parsed_args, _ = parser.parse_known_args(sys.argv)

    BeamDagRunner().run(
        _create_pipeline(pipeline_name=_pipeline_name,
                         pipeline_root=_pipeline_root,
                         data_root=_data_root,
                         module_file=_module_file,
                         serving_model_dir=_serving_model_dir,
                         metadata_path=_metadata_path,
                         beam_pipeline_args=_beam_pipeline_args_by_runner[
                             parsed_args.runner]))
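# _beam_pipeline_args_by_runner (used above) maps each --runner choice to the
# Beam flags it needs. A hedged sketch of its shape; the DirectRunner flags are
# the ones shown elsewhere on this page, while the Flink/Spark entries are
# placeholders that depend on your cluster setup:
#
#     _beam_pipeline_args_by_runner = {
#         'DirectRunner': [
#             '--direct_running_mode=multi_processing',
#             '--direct_num_workers=0',
#         ],
#         'FlinkRunner': ['--runner=FlinkRunner'],
#         'SparkRunner': ['--runner=SparkRunner'],
#     }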
Example #28
    def testTaxiPipelineNativeKeras(self):
        BeamDagRunner().run(
            taxi_pipeline_native_keras._create_pipeline(
                pipeline_name=self._pipeline_name,
                data_root=self._data_root,
                module_file=self._module_file,
                serving_model_dir=self._serving_model_dir,
                pipeline_root=self._pipeline_root,
                metadata_path=self._metadata_path,
                beam_pipeline_args=[]))

        self.assertTrue(fileio.exists(self._serving_model_dir))
        self.assertTrue(fileio.exists(self._metadata_path))
        expected_execution_count = 9  # 8 components + 1 resolver
        metadata_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)
        with metadata.Metadata(metadata_config) as m:
            artifact_count = len(m.store.get_artifacts())
            execution_count = len(m.store.get_executions())
            self.assertGreaterEqual(artifact_count, execution_count)
            self.assertEqual(expected_execution_count, execution_count)

        self.assertPipelineExecution()

        # Runs pipeline the second time.
        BeamDagRunner().run(
            taxi_pipeline_native_keras._create_pipeline(
                pipeline_name=self._pipeline_name,
                data_root=self._data_root,
                module_file=self._module_file,
                serving_model_dir=self._serving_model_dir,
                pipeline_root=self._pipeline_root,
                metadata_path=self._metadata_path,
                beam_pipeline_args=[]))

        # All executions but Evaluator and Pusher are cached.
        # Note that Resolver will always execute.
        with metadata.Metadata(metadata_config) as m:
            # Artifact count is increased by 3 caused by Evaluator and Pusher.
            self.assertLen(m.store.get_artifacts(), artifact_count + 3)
            artifact_count = len(m.store.get_artifacts())
            self.assertLen(m.store.get_executions(),
                           expected_execution_count * 2)

        # Runs pipeline the third time.
        BeamDagRunner().run(
            taxi_pipeline_native_keras._create_pipeline(
                pipeline_name=self._pipeline_name,
                data_root=self._data_root,
                module_file=self._module_file,
                serving_model_dir=self._serving_model_dir,
                pipeline_root=self._pipeline_root,
                metadata_path=self._metadata_path,
                beam_pipeline_args=[]))

        # Asserts cache execution.
        with metadata.Metadata(metadata_config) as m:
            # Artifact count is unchanged.
            self.assertLen(m.store.get_artifacts(), artifact_count)
            self.assertLen(m.store.get_executions(),
                           expected_execution_count * 3)
Example #29
                    push_destination=pusher_pb2.PushDestination(
                        filesystem=pusher_pb2.PushDestination.Filesystem(
                            base_directory=serving_model_dir)))

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[
            example_gen, statistics_gen, schema_gen, example_validator,
            transform, trainer, evaluator, model_validator, pusher
        ],
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
    )


# To run this pipeline from the python CLI:
#   $python taxi_pipeline_presto.py
if __name__ == '__main__':
    absl.logging.set_verbosity(absl.logging.INFO)

    BeamDagRunner().run(
        _create_pipeline(pipeline_name=_pipeline_name,
                         pipeline_root=_pipeline_root,
                         presto_config=_presto_config,
                         query=_query,
                         module_file=_module_file,
                         serving_model_dir=_serving_model_dir,
                         metadata_path=_metadata_path))
Example #30
        # Note that direct_num_workers != 1 will enable multi-process for TFX,
        # we hide the FnApiRunner[1] setting from user, but this is subject to
        # change if Beam offers pure flag setup.
        # [1]https://issues.apache.org/jira/browse/BEAM-3645
        beam_pipeline_args=['--direct_num_workers=%s' % direct_num_workers],
        additional_pipeline_args={},
    )


# To run this pipeline from the python CLI:
#   $python taxi_pipeline_beam.py
if __name__ == '__main__':
    absl.logging.set_verbosity(absl.logging.INFO)

    try:
        parallelism = multiprocessing.cpu_count()
    except NotImplementedError:
        absl.logging.info(
            'Use single process as multiprocessing.cpu_count is not supported.'
        )
        parallelism = 1

    BeamDagRunner().run(
        _create_pipeline(pipeline_name=_pipeline_name,
                         pipeline_root=_pipeline_root,
                         data_root=_data_root,
                         module_file=_module_file,
                         serving_model_dir=_serving_model_dir,
                         metadata_path=_metadata_path,
                         direct_num_workers=parallelism))