コード例 #1
0
  def test_taxi_pipeline_beam(self):
    """Runs the taxi Beam pipeline twice and checks the metadata store.

    After the first run the serving model directory and the metadata path
    must exist, and the store must hold 9 executions and at least as many
    artifacts as executions. The second run is expected to be served from
    cache: the artifact count stays the same while the execution count
    grows to 18.
    """
    expected_executions = 9

    def run_pipeline():
      # The runner is created before the pipeline is built, and a fresh
      # pipeline object is constructed for every run.
      runner = BeamRunner()
      runner.run(
          taxi_pipeline_beam._create_pipeline(
              pipeline_name=self._pipeline_name,
              data_root=self._data_root,
              module_file=self._module_file,
              serving_model_dir=self._serving_model_dir,
              pipeline_root=self._pipeline_root,
              metadata_path=self._metadata_path))

    # First run.
    run_pipeline()

    for expected_path in (self._serving_model_dir, self._metadata_path):
      self.assertTrue(tf.gfile.Exists(expected_path))

    connection_config = metadata.sqlite_metadata_connection_config(
        self._metadata_path)
    with metadata.Metadata(connection_config) as store_handle:
      first_run_artifacts = len(store_handle.store.get_artifacts())
      first_run_executions = len(store_handle.store.get_executions())
      self.assertGreaterEqual(first_run_artifacts, first_run_executions)
      self.assertEqual(expected_executions, first_run_executions)

    self.assertPipelineExecution()

    # Second run against the same pipeline root and metadata path.
    run_pipeline()

    # Verify that the second run was cached.
    with metadata.Metadata(connection_config) as store_handle:
      # No new artifacts were recorded.
      self.assertEqual(first_run_artifacts,
                       len(store_handle.store.get_artifacts()))
      # One additional (cached) execution per execution of the first run.
      self.assertEqual(2 * expected_executions,
                       len(store_handle.store.get_executions()))

    self.assertPipelineExecution()
コード例 #2
0
ファイル: taxi_pipeline_slack.py プロジェクト: jay90099/tfx
        model_blessing=model_validator.outputs['blessing'],
        slack_token=_slack_token,
        slack_channel_id=_slack_channel_id,
        timeout_sec=3600,
    )

    # Checks whether the model passed the validation steps and pushes the model
    # to a file destination if check passed.
    pusher = Pusher(model=trainer.outputs['model'],
                    model_blessing=slack_validator.outputs['slack_blessing'],
                    push_destination=pusher_pb2.PushDestination(
                        filesystem=pusher_pb2.PushDestination.Filesystem(
                            base_directory=_serving_model_dir)))

    return pipeline.Pipeline(
        pipeline_name=_pipeline_name,
        pipeline_root=_pipeline_root,
        components=[
            example_gen, statistics_gen, schema_gen, example_validator,
            transform, trainer, evaluator, model_validator, slack_validator,
            pusher
        ],
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            _metadata_db_root),
    )


if __name__ == '__main__':
    # Script entry point: create the Beam runner, then build the pipeline
    # (with no arguments) and execute it.
    runner = BeamRunner()
    runner.run(_create_pipeline())
コード例 #3
0
                    model_blessing=model_validator.outputs['blessing'],
                    push_destination=pusher_pb2.PushDestination(
                        filesystem=pusher_pb2.PushDestination.Filesystem(
                            base_directory=serving_model_dir)))

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[
            example_gen, statistics_gen, infer_schema, validate_stats,
            transform, trainer, model_analyzer, model_validator, pusher
        ],
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            metadata_path),
        additional_pipeline_args={},
    )


# Script entry point. Run from the command line with:
#   $ python taxi_pipeline_beam.py
if __name__ == '__main__':
    # Emit INFO-level TensorFlow logs while the pipeline runs.
    tf.logging.set_verbosity(tf.logging.INFO)

    runner = BeamRunner()
    # Pipeline settings come from the module-level constants.
    pipeline_kwargs = dict(
        pipeline_name=_pipeline_name,
        data_root=_data_root,
        module_file=_module_file,
        serving_model_dir=_serving_model_dir,
        pipeline_root=_pipeline_root,
        metadata_path=_metadata_path,
    )
    runner.run(_create_pipeline(**pipeline_kwargs))