def __init__(
  self,
  log_db: log_database.Database,
  checkpoint_ref: checkpoints.CheckpointReference,
  epoch_type: epoch.Type,
  outdir: pathlib.Path,
  export_batches_per_query: int = 128,
):
  """Constructor.

  Args:
    log_db: The database to export logs from.
    checkpoint_ref: A run ID and epoch number.
    epoch_type: The type of epoch to export graphs from.
    outdir: The directory to write results to.
    export_batches_per_query: The number of batches to read from the log
      database in a single query.

  Raises:
    ValueError: If the database contains no graphs for the given
      checkpoint and epoch type.
  """
  self.log_db = log_db
  self.logger = logger_lib.Logger(self.log_db)
  self.checkpoint = checkpoint_ref
  self.epoch_type = epoch_type
  self.analyzer = log_analysis.RunLogAnalyzer(
    self.log_db, self.checkpoint.run_id
  )
  self.export_batches_per_query = export_batches_per_query

  # Create the full output directory layout up-front so that later writes
  # cannot fail on a missing directory.
  self.outdir = outdir
  self.outdir.mkdir(parents=True, exist_ok=True)
  (self.outdir / "graph_stats").mkdir(exist_ok=True)
  (self.outdir / "in_graphs").mkdir(exist_ok=True)
  (self.outdir / "out_graphs").mkdir(exist_ok=True)

  # Count the total number of graphs to export across all batches. This is
  # used to size the progress bar, and to fail fast when there is nothing
  # to export.
  with self.log_db.Session() as session:
    num_graphs = self.FilterBatchesQuery(
      session.query(sql.func.sum(log_database.Batch.graph_count))
    ).scalar()
  if not num_graphs:
    raise ValueError("No graphs found!")

  super().__init__(
    name=f"export {self.checkpoint} graphs",
    unit="graphs",
    i=0,
    n=num_graphs,
  )
def logger(log_db: log_database.Database) -> logging.Logger:
  """Test fixture that constructs a buffered logger over the given database.

  The logger is opened as a context manager so that any buffered log
  records are flushed when the test completes.
  """
  with logging.Logger(log_db, max_buffer_length=128) as log:
    yield log