def test_Run_test_only(
  disposable_log_db: log_database.Database,
  graph_db: graph_tuple_database.Database,
  k_fold: bool,
):
  """Test the run.Run() method."""
  log_db = disposable_log_db

  # Set the flags that determine the behaviour of Run().
  FLAGS.graph_db = flags_parsers.DatabaseFlag(
    graph_tuple_database.Database, graph_db.url, must_exist=True
  )
  FLAGS.log_db = flags_parsers.DatabaseFlag(
    log_database.Database, log_db.url, must_exist=True
  )
  FLAGS.test_only = True
  FLAGS.k_fold = k_fold

  run.Run(MockModel)

  # Test that k-fold produces multiple runs.
  assert log_db.run_count == (graph_db.split_count if k_fold else 1)

  run_ids = log_db.run_ids
  for run_id in run_ids:
    logs = log_analysis.RunLogAnalyzer(log_db=log_db, run_id=run_id)
    epochs = logs.tables["epochs"]

    # Check that we performed as many epochs as expected.
    assert 1 == len(epochs)

    test_count = len(epochs[epochs["test_accuracy"].notnull()])

    # Check that we produced a test result.
    assert test_count == 1

def test_RunLogAnalyser_best_epoch_num(
  populated_log_db: DatabaseAndRunIds, metric: str
):
  """Black-box test that GetBestEpochNum() works for each metric."""
  for run_id in populated_log_db.run_ids:
    run = log_analysis.RunLogAnalyzer(populated_log_db.db, run_id)
    try:
      assert run.GetBestEpochNum(metric=metric)
    except ValueError as e:
      # Some metrics raise an error when no epoch reached them. This is
      # fine, but check that the error is the one we expect.
      assert str(e) == f"No {run_id} epochs reached {metric}"

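# A convenience sketch built on the behaviour exercised above: return the
# best epoch number for a metric, or None when no epoch reached it. The
# helper name is ours; GetBestEpochNum() and the ValueError it raises come
# from the test above.
def TryGetBestEpochNum(run: log_analysis.RunLogAnalyzer, metric: str):
  try:
    return run.GetBestEpochNum(metric=metric)
  except ValueError:
    # No epoch reached this metric.
    return None
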
def test_Run(
  disposable_log_db: log_database.Database,
  graph_db: graph_tuple_database.Database,
  k_fold: bool,
  run_with_memory_profiler: bool,
  test_on: str,
  stop_at: List[str],
):
  """Test the run.Run() method."""
  log_db = disposable_log_db

  # Set the flags that determine the behaviour of Run().
  FLAGS.graph_db = flags_parsers.DatabaseFlag(
    graph_tuple_database.Database, graph_db.url, must_exist=True
  )
  FLAGS.log_db = flags_parsers.DatabaseFlag(
    log_database.Database, log_db.url, must_exist=True
  )
  FLAGS.epoch_count = 3
  FLAGS.k_fold = k_fold
  FLAGS.run_with_memory_profiler = run_with_memory_profiler
  FLAGS.test_on = test_on
  FLAGS.stop_at = stop_at

  run.Run(MockModel)

  # Test that k-fold produces multiple runs.
  assert log_db.run_count == (graph_db.split_count if k_fold else 1)

  run_ids = log_db.run_ids
  for run_id in run_ids:
    logs = log_analysis.RunLogAnalyzer(log_db=log_db, run_id=run_id)
    epochs = logs.tables["epochs"]

    # Check that we performed as many epochs as expected. We can't check the
    # exact value because of --stop_at options.
    assert 1 <= len(epochs) <= FLAGS.epoch_count

    test_count = len(epochs[epochs["test_accuracy"].notnull()])

    # Test that the number of test epochs matches the expected amount,
    # depending on the --test_on flag.
    if test_on == "none":
      assert test_count == 0
    elif test_on == "best":
      assert test_count == 1
    elif test_on == "improvement":
      assert test_count >= 1
    elif test_on == "improvement_and_last":
      assert test_count >= 1

def __init__(
  self,
  log_db: log_database.Database,
  checkpoint_ref: checkpoints.CheckpointReference,
  epoch_type: epoch.Type,
  outdir: pathlib.Path,
  export_batches_per_query: int = 128,
):
  """Constructor.

  Args:
    log_db: The database to export logs from.
    checkpoint_ref: A run ID and epoch number.
    epoch_type: The type of epoch to export graphs from.
    outdir: The directory to write results to.
    export_batches_per_query: The number of batches to read per query.
  """
  self.log_db = log_db
  self.logger = logger_lib.Logger(self.log_db)
  self.checkpoint = checkpoint_ref
  self.epoch_type = epoch_type
  self.analyzer = log_analysis.RunLogAnalyzer(
    self.log_db, self.checkpoint.run_id
  )
  self.export_batches_per_query = export_batches_per_query

  self.outdir = outdir
  self.outdir.mkdir(parents=True, exist_ok=True)
  (self.outdir / "graph_stats").mkdir(exist_ok=True)
  (self.outdir / "in_graphs").mkdir(exist_ok=True)
  (self.outdir / "out_graphs").mkdir(exist_ok=True)

  # Count the total number of graphs to export across all batches.
  with self.log_db.Session() as session:
    num_graphs = self.FilterBatchesQuery(
      session.query(sql.func.sum(log_database.Batch.graph_count))
    ).scalar()
  if not num_graphs:
    raise ValueError("No graphs found!")

  super(BatchDetailsExporter, self).__init__(
    name=f"export {self.checkpoint} graphs",
    unit="graphs",
    i=0,
    n=num_graphs,
  )

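# A minimal construction sketch for BatchDetailsExporter, assuming the
# surrounding module's imports are in scope. The CheckpointReference keyword
# arguments and the epoch.Type.TEST member are hypothetical, inferred from
# the constructor docstring above ("a run ID and epoch number"), not
# confirmed API; adjust to the real signatures before use.
def MakeTestEpochExporter(
  log_db: log_database.Database, a_run_id, outdir: pathlib.Path
):
  checkpoint_ref = checkpoints.CheckpointReference(  # Hypothetical signature.
    run_id=a_run_id, epoch_num=None
  )
  return BatchDetailsExporter(
    log_db=log_db,
    checkpoint_ref=checkpoint_ref,
    epoch_type=epoch.Type.TEST,  # Assumed enum member.
    outdir=outdir,
  )
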
def test_GetInputOutputGraphs(populated_log_db: DatabaseAndRunIds):
  """Test reconstructing graphs from a detailed batch."""
  # Select a random run to analyze.
  run_id = random.choice(populated_log_db.run_ids)
  run = log_analysis.RunLogAnalyzer(populated_log_db.db, run_id)

  with populated_log_db.db.Session() as session:
    # Select some random detailed batches to reconstruct the graphs of.
    detailed_batches = (
      session.query(log_database.Batch)
      .join(log_database.BatchDetails)
      .options(sql.orm.joinedload(log_database.Batch.details))
      .order_by(populated_log_db.db.Random())
      .limit(50)
      .all()
    )
    # Sanity check that there are detailed batches.
    assert detailed_batches

    for batch in detailed_batches:
      input_output_graphs = list(run.GetInputOutputGraphs(batch))
      # Check that the number of input_output_graphs matches the size of the
      # batch.
      assert len(input_output_graphs) == len(batch.graph_ids)

def test_RunLogAnalyser_smoke_tests(populated_log_db: DatabaseAndRunIds):
  """Black-box test that run log properties work."""
  for run_id in populated_log_db.run_ids:
    run = log_analysis.RunLogAnalyzer(populated_log_db.db, run_id)
    assert run.graph_db
    assert run.tables.keys() == {"parameters", "epochs", "runs", "tags"}

def test_RunLogAnalyser_empty_db(empty_log_db: log_database.Database):
  """Test that a run which does not exist cannot be analysed."""
  with test.Raises(ValueError):
    log_analysis.RunLogAnalyzer(
      empty_log_db, run_id.RunId.GenerateUnique("foo")
    )

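# A small guard sketch based on the behaviour tested above: constructing a
# RunLogAnalyzer for a run ID that is not in the database raises ValueError,
# so callers can check log_db.run_ids (used by the tests above) first. The
# helper name is ours.
def MaybeAnalyseRun(log_db: log_database.Database, some_run_id):
  if some_run_id not in log_db.run_ids:
    # The run has no logs to analyse.
    return None
  return log_analysis.RunLogAnalyzer(log_db, some_run_id)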