Ejemplo n.º 1
0
  def __init__(
    self,
    log_db: log_database.Database,
    checkpoint_ref: checkpoints.CheckpointReference,
    epoch_type: epoch.Type,
    outdir: pathlib.Path,
    export_batches_per_query: int = 128,
  ):
    """Prepare to export the graphs of one checkpoint's batches.

    Builds the output directory tree, then counts the graphs that will be
    exported so the progress total is known before any work starts.

    Args:
      log_db: The log database that batches and graphs are read from.
      checkpoint_ref: The run ID and epoch number identifying the checkpoint.
      epoch_type: Which epoch type's batches to export graphs from.
      outdir: Root directory for results. It is created (with parents) if
        missing, together with "graph_stats", "in_graphs" and "out_graphs"
        subdirectories.
      export_batches_per_query: Stored as `self.export_batches_per_query`;
        presumably the number of batches fetched per database query during
        export (TODO confirm against the export loop).

    Raises:
      ValueError: If the batches selected by FilterBatchesQuery contain no
        graphs (the summed graph count is zero or None).
    """
    self.log_db = log_db
    self.logger = logger_lib.Logger(log_db)
    self.checkpoint = checkpoint_ref
    self.epoch_type = epoch_type
    self.analyzer = log_analysis.RunLogAnalyzer(log_db, checkpoint_ref.run_id)
    self.export_batches_per_query = export_batches_per_query

    # Output tree: the root first (parents allowed), then the fixed subdirs.
    self.outdir = outdir
    outdir.mkdir(parents=True, exist_ok=True)
    for subdir_name in ("graph_stats", "in_graphs", "out_graphs"):
      (outdir / subdir_name).mkdir(exist_ok=True)

    # Total number of graphs to export, summed across all selected batches.
    with log_db.Session() as session:
      summed_graph_count = sql.func.sum(log_database.Batch.graph_count)
      filtered_query = self.FilterBatchesQuery(
        session.query(summed_graph_count)
      )
      total_graphs = filtered_query.scalar()

    # `not` rejects both a zero total and a None result from the query.
    if not total_graphs:
      raise ValueError("No graphs found!")

    super(BatchDetailsExporter, self).__init__(
      name=f"export {self.checkpoint} graphs",
      unit="graphs",
      i=0,
      n=total_graphs,
    )
Ejemplo n.º 2
0
def logger(log_db: log_database.Database) -> logging.Logger:
  """Test fixture that provides a buffered logger bound to `log_db`.

  Enters `logging.Logger(log_db, max_buffer_length=128)` as a context
  manager and yields the object it returns. The `with` block exits when the
  generator is resumed after the yield, or when it is closed.

  NOTE(review): this is a generator function, so the return annotation would
  more accurately be `Iterator[logging.Logger]`. `logging` here appears to be
  a project logger module rather than the stdlib (the stdlib Logger takes no
  `max_buffer_length` and is not a context manager). Confirm both.
  """
  with logging.Logger(log_db, max_buffer_length=128) as fixture_logger:
    yield fixture_logger