Example #1
def testTaxiPipelineCheckDagConstruction(self):
    airflow_config = {
        'schedule_interval': None,
        'start_date': datetime.datetime(2019, 1, 1),
    }
    logical_pipeline = taxi_pipeline_simple._create_pipeline(
        pipeline_name='Test',
        pipeline_root=self._test_dir,
        data_root=self._test_dir,
        module_file=self._test_dir,
        serving_model_dir=self._test_dir,
        metadata_path=self._test_dir)
    self.assertEqual(9, len(logical_pipeline.components))
    pipeline = AirflowDAGRunner(airflow_config).run(logical_pipeline)
    self.assertIsInstance(pipeline, models.DAG)
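
For context, this test method relies on a handful of imports that the excerpt does not show. The paths below follow the TFX and Airflow layout these examples appear to target (roughly TFX 0.13/0.14), so treat them as assumptions rather than verbatim source:

import datetime

from airflow import models

from tfx.examples.chicago_taxi_pipeline import taxi_pipeline_simple
from tfx.orchestration.airflow.airflow_dag_runner import AirflowDAGRunner
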
Example #2
  # Uses TFMA to compute evaluation statistics over features of a model.
  model_analyzer = Evaluator(
      examples=example_gen.outputs.examples,
      model_exports=trainer.outputs.output,
      feature_slicing_spec=evaluator_pb2.FeatureSlicingSpec(specs=[
          evaluator_pb2.SingleSlicingSpec(
              column_for_slicing=['trip_start_hour'])
      ]))

  # Performs quality validation of a candidate model (compared to a baseline).
  model_validator = ModelValidator(
      examples=example_gen.outputs.examples, model=trainer.outputs.output)

  # Checks whether the model passed the validation steps and pushes the model
  # to a file destination if check passed.
  pusher = Pusher(
      model_export=trainer.outputs.output,
      model_blessing=model_validator.outputs.blessing,
      push_destination=pusher_pb2.PushDestination(
          filesystem=pusher_pb2.PushDestination.Filesystem(
              base_directory=_serving_model_dir)))

  return pipeline.Pipeline(
      pipeline_name='chicago_taxi_simple',
      pipeline_root=_pipeline_root,
      components=[
          example_gen, statistics_gen, infer_schema, validate_stats, transform,
          trainer, model_analyzer, model_validator, pusher
      ],
      enable_cache=True,
      metadata_db_root=_metadata_db_root,
      additional_pipeline_args={'logger_args': logger_overrides},
  )


airflow_pipeline = AirflowDAGRunner(_airflow_config).run(_create_pipeline())
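
The names `_airflow_config`, `_pipeline_root`, `_metadata_db_root`, `_serving_model_dir`, and `logger_overrides` are defined earlier in the source module and are not part of this excerpt. Hypothetical stand-ins might look like this (illustrative values only, not the source's):

import datetime
import logging
import os

_pipeline_root = os.path.join(os.environ['HOME'], 'tfx', 'pipelines')
_metadata_db_root = os.path.join(_pipeline_root, 'metadata')
_serving_model_dir = os.path.join(_pipeline_root, 'serving_model')
logger_overrides = {'log_root': '/var/tmp/tfx/logs', 'log_level': logging.INFO}
_airflow_config = {
    'schedule_interval': None,
    'start_date': datetime.datetime(2019, 1, 1),
}
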
Example #3
    model_validator = ModelValidator(examples=example_gen.outputs.examples,
                                     model=trainer.outputs.output)

    # Checks whether the model passed the validation steps and pushes the model
    # to a file destination if check passed.
    pusher = Pusher(model_export=trainer.outputs.output,
                    model_blessing=model_validator.outputs.blessing,
                    push_destination=pusher_pb2.PushDestination(
                        filesystem=pusher_pb2.PushDestination.Filesystem(
                            base_directory=serving_model_dir)))

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[
            example_gen, statistics_gen, infer_schema, validate_stats,
            transform, trainer, model_analyzer, model_validator, pusher
        ],
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path))


airflow_pipeline = AirflowDAGRunner(_airflow_config).run(
    _create_pipeline(pipeline_name=_pipeline_name,
                     pipeline_root=_pipeline_root,
                     data_root=_data_root,
                     module_file=_module_file,
                     serving_model_dir=_serving_model_dir,
                     metadata_path=_metadata_path))
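
Unlike Example #2 above (and Example #4 below), this variant wires metadata through `metadata.sqlite_metadata_connection_config(...)` rather than the older `metadata_db_root` argument. A sketch of the imports it presumably needs, assuming standard TFX module paths:

from tfx.components import ModelValidator, Pusher
from tfx.orchestration import metadata, pipeline
from tfx.orchestration.airflow.airflow_dag_runner import AirflowDAGRunner
from tfx.proto import pusher_pb2
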
Example #4
    # Compares the new model against a baseline; both models are evaluated on a dataset
    model_validator = ModelValidator(examples=example_gen.outputs.examples,
                                     model=trainer.outputs.output)

    # Pushes a blessed model to a deployment target (tfserving)
    pusher = Pusher(model_export=trainer.outputs.output,
                    model_blessing=model_validator.outputs.blessing,
                    push_destination=pusher_pb2.PushDestination(
                        filesystem=pusher_pb2.PushDestination.Filesystem(
                            base_directory=SERVING_DIR)))

    return pipeline.Pipeline(pipeline_name=PIPELINE_NAME,
                             pipeline_root=DAGS_DIR,
                             components=[
                                 example_gen, statistics_gen, infer_schema,
                                 validate_stats, transform, trainer,
                                 model_analyzer, model_validator, pusher
                             ],
                             enable_cache=True,
                             metadata_db_root=METADATA_DIR,
                             additional_pipeline_args={
                                 'logger_args': {
                                     'log_root': LOGS_DIR,
                                     'log_level': logging.INFO
                                 }
                             })


airflow_pipeline = AirflowDAGRunner(AIRFLOW_CONFIG).run(create_pipeline())
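
As in Example #2, the uppercase constants (`PIPELINE_NAME`, `DAGS_DIR`, `SERVING_DIR`, `METADATA_DIR`, `LOGS_DIR`, `AIRFLOW_CONFIG`) come from earlier in the source module; this example also uses the older `metadata_db_root` and `logger_args` options. A hypothetical sketch of those constants (illustrative only):

import datetime
import logging
import os

PIPELINE_NAME = 'chicago_taxi_simple'
DAGS_DIR = os.path.join(os.environ['HOME'], 'airflow', 'dags')
SERVING_DIR = os.path.join(DAGS_DIR, 'serving_model')
METADATA_DIR = os.path.join(DAGS_DIR, 'metadata')
LOGS_DIR = '/var/tmp/tfx/logs'
AIRFLOW_CONFIG = {
    'schedule_interval': None,
    'start_date': datetime.datetime(2019, 1, 1),
}
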
Example #5
INPUT_PATH = Variable.get(
    "tf_sample_model.input_path",
    'gs://renault-ml-tf-sample-model-dev/datasets/input/tf-records')

TF_SERVING_MODEL_BASEDIR = Variable.get(
    "tf_sample_model.serving_model_basedir",
    'gs://renault-ml-tf-sample-model-dev/datasets/saved_models')

_AIRFLOW_CONFIG = {
    'schedule_interval': SCHEDULE_INTERVAL,
    'start_date': START_DATE,
}

# Logging overrides
_LOGGER_OVERRIDES = {'log_root': LOGS_ROOT, 'log_level': LOGS_LEVEL}

TFX_PIPELINE = create_pipeline(
    pipeline_name=PIPELINE_NAME,
    pipeline_root=PIPELINE_ROOT,
    input_path=INPUT_PATH,
    tf_transform_file=TF_TRANSFORM_FILE,
    tf_trainer_file=TF_TRAINER_FILE,
    serving_model_basedir=TF_SERVING_MODEL_BASEDIR,
    metadata_db_root=METADATA_DB,
    metadata_connection_config=metadata_connection_config,
    enable_cache=True,
    additional_pipeline_args={'logger_args': _LOGGER_OVERRIDES})

AIRFLOW_DAG = AirflowDAGRunner(config=_AIRFLOW_CONFIG).run(
    pipeline=TFX_PIPELINE)
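
`Variable.get` reads from Airflow's variable store in its metadata database, falling back to the second (default) argument when the key is unset, so the same DAG file can run in environments where the variables were never configured. The remaining constants (`SCHEDULE_INTERVAL`, `START_DATE`, `PIPELINE_NAME`, `PIPELINE_ROOT`, and so on) are defined earlier in the source module; hypothetical stand-ins:

import datetime

from airflow.models import Variable

# Hypothetical values; the source module defines its own.
SCHEDULE_INTERVAL = Variable.get("tf_sample_model.schedule_interval", "@daily")
START_DATE = datetime.datetime(2019, 1, 1)

Note that in every example the result of `AirflowDAGRunner(...).run(...)` is bound to a module-level name; Airflow's scheduler only picks up DAG objects that are visible at the top level of a file in its dags folder.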