def test_happy_path(self):
    """A clean run should store one 'Completed' history row plus the captured stdout log."""
    sys_obj = sample_system_object()
    with HistoryRecorder(system=sys_obj, est_path="test.py", db_path=self.db_path):
        print("Test Log Capture")
        print("Line 2")

    conn = connect(self.db_path)
    run_key = sys_obj.exp_id

    # --- history table: one row keyed by the experiment id ---
    with closing(conn.cursor()) as cur:
        cur.execute("SELECT * FROM history WHERE pk = (?)", [run_key])
        history_rows = cur.fetchall()
    with self.subTest("History Captured"):
        self.assertEqual(len(history_rows), 1)
    entry = history_rows[0]
    for label, column, wanted in (
        ("File name captured", 'file', 'test.py'),
        ("Status updated", 'status', 'Completed'),
    ):
        with self.subTest(label):
            self.assertEqual(entry[column], wanted)

    # --- logs table: one row holding every printed line ---
    with closing(conn.cursor()) as cur:
        cur.execute("SELECT * FROM logs WHERE fk = (?)", [run_key])
        log_rows = cur.fetchall()
    with self.subTest("Log Captured"):
        self.assertEqual(len(log_rows), 1)
    log_entry = log_rows[0]
    with self.subTest("Complete log captured"):
        self.assertListEqual(log_entry['log'].splitlines(), ["Test Log Capture", "Line 2"])
    conn.close()
def test_joint_update_lower_kgl(self):
    """Updating n_keep and n_keep_logs in one call should persist both values in the settings row."""
    update_settings(n_keep=400, n_keep_logs=300, db_path=self.db_path)
    # closing() guarantees the connection is released once the checks are done.
    with closing(connect(self.db_path)) as db:
        with closing(db.cursor()) as cursor:
            cursor.execute("SELECT * FROM settings WHERE pk = 0")
            results = cursor.fetchall()[0]
        with self.subTest("n_keep updated"):
            self.assertEqual(results['n_keep'], 400)
        with self.subTest("n_keep_logs updated"):
            self.assertEqual(results['n_keep_logs'], 300)
def test_decrease_logs(self):
    """Updating only n_keep_logs should change that value and leave n_keep at its default."""
    update_settings(n_keep_logs=42, db_path=self.db_path)
    # closing() guarantees the connection is released once the checks are done.
    with closing(connect(self.db_path)) as db:
        with closing(db.cursor()) as cursor:
            cursor.execute("SELECT * FROM settings WHERE pk = 0")
            results = cursor.fetchall()[0]
        with self.subTest("n_keep not updated"):
            self.assertEqual(results['n_keep'], DEFAULT_KEEP)
        with self.subTest("n_keep_logs updated"):
            self.assertEqual(results['n_keep_logs'], 42)
def test_make_schema(self):
    """A freshly connected database should contain exactly the expected set of tables."""
    # closing() releases the connection even if the set assertion fails.
    with closing(connect(":memory:")) as db:
        with closing(db.cursor()) as cursor:
            cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table'")
            results = cursor.fetchall()
        results = {result['name'] for result in results}
        # Each table is listed exactly once.
        expected = {
            "datasets",
            "errors",
            "features",
            "history",
            "logs",
            "network",
            "pipeline",
            "postprocess",
            "settings",
            "traces",
        }
        self.assertSetEqual(results, expected)
def test_restore_training_old_missing(self):
    """A restored run whose original history row was deleted should end up as one consolidated row.

    The first run crashes mid-training, and its history row is then removed by hand. A second
    run adopts the first run's attributes (faking the restore wizard) and completes normally.
    """
    first = sample_system_object()
    first_rec = HistoryRecorder(system=first, est_path="test.py", db_path=self.db_path)
    try:
        with first_rec:
            print("Test Log Capture")
            print("Line 2")
            raise RuntimeError("Training Died")
    except RuntimeError:
        pass  # the crash is intentional; only what got recorded matters

    conn = connect(self.db_path)
    # Simulate the old history row being missing before the restore happens.
    conn.execute("DELETE FROM history WHERE pk = (?)", [first.exp_id])
    conn.commit()

    second = sample_system_object()
    second_rec = HistoryRecorder(system=second, est_path="test.py", db_path=self.db_path)
    with second_rec:
        # Fake a restore wizard by copying every attribute of the crashed run.
        second.__dict__.update(first.__dict__)
        print("Line 3")
        print("Line 4")

    # --- history table: the two runs collapse into a single row ---
    with closing(conn.cursor()) as cur:
        cur.execute("SELECT * FROM history")
        history_rows = cur.fetchall()
    with self.subTest("History Captured and Consolidated"):
        self.assertEqual(len(history_rows), 1)
    entry = history_rows[0]
    expectations = (
        ("File name captured", 'file', 'test.py'),
        ("Status updated", 'status', 'Completed'),
        ("Correct PK", 'pk', first.exp_id),
        ("Restarts Incremented", 'n_restarts', 1),
    )
    for label, column, wanted in expectations:
        with self.subTest(label):
            self.assertEqual(entry[column], wanted)

    # --- logs table: only the restored run's output remains ---
    with closing(conn.cursor()) as cur:
        cur.execute("SELECT * FROM logs WHERE fk = (?)", [first.exp_id])
        log_rows = cur.fetchall()
    with self.subTest("Log Captured"):
        self.assertEqual(len(log_rows), 1)
    log_entry = log_rows[0]
    with self.subTest("Complete log captured"):
        self.assertListEqual(log_entry['log'].splitlines(), ["Line 3", "Line 4"])
    conn.close()
def test_initial_settings(self):
    """A fresh database should hold exactly one settings row (pk 0) populated with the defaults."""
    # closing() guarantees the in-memory connection is released once the checks are done.
    with closing(connect(":memory:")) as db:
        with closing(db.cursor()) as cursor:
            cursor.execute("SELECT * FROM settings")
            results = cursor.fetchall()
        with self.subTest("should be exactly one setting row"):
            self.assertEqual(len(results), 1)
        results = results[0]
        with self.subTest("settings should have pk = 0"):
            self.assertEqual(results['pk'], 0)
        with self.subTest("schema version should be 1"):
            self.assertEqual(results['schema_version'], 1)
        with self.subTest(f"n_keep should be {DEFAULT_KEEP}"):
            self.assertEqual(results['n_keep'], DEFAULT_KEEP)
        with self.subTest(f"n_keep_logs should be {DEFAULT_LOG_KEEP}"):
            self.assertEqual(results['n_keep_logs'], DEFAULT_LOG_KEEP)
def test_fk_update(self):
    """Changing a history row's pk should propagate to the matching network.fk value."""
    conn = connect(":memory:")
    for statement in (
        "INSERT INTO history (pk) VALUES (0)",
        "INSERT INTO network (fk) VALUES (0)",
    ):
        conn.execute(statement)
    conn.commit()

    with closing(conn.cursor()) as cur:
        with self.subTest("entry should be inserted correctly"):
            cur.execute("SELECT * FROM network")
            self.assertEqual(cur.fetchall()[0]['fk'], 0)
        with self.subTest("entry should cascade update"):
            # Same connection, so the uncommitted UPDATE is visible to the following SELECT.
            conn.execute("UPDATE history SET pk = 10 WHERE pk = 0")
            cur.execute("SELECT * FROM network")
            self.assertEqual(cur.fetchall()[0]['fk'], 10)
    conn.close()
def test_fk_delete(self):
    """Deleting a history row should also remove the network row that references it."""
    conn = connect(":memory:")
    for statement in (
        "INSERT INTO history (pk) VALUES (0)",
        "INSERT INTO network (fk) VALUES (0)",
    ):
        conn.execute(statement)
    conn.commit()

    with closing(conn.cursor()) as cur:

        def network_row_count():
            # Count of rows currently in the network table, read through the shared cursor.
            cur.execute("SELECT count(*) AS count FROM network")
            return cur.fetchall()[0]['count']

        with self.subTest("entry should be inserted correctly"):
            self.assertEqual(network_row_count(), 1)
        with self.subTest("entry should cascade delete"):
            conn.execute("DELETE FROM history WHERE pk = 0")
            self.assertEqual(network_row_count(), 0)
    conn.close()
def test_under_threshold(self):
    """delete() with n_keep equal to the current row count must leave every history row in place."""
    conn = connect(self.db_path)
    for pk in range(10):
        conn.execute("INSERT INTO history (pk, train_start) VALUES (?, ?)", [pk, datetime.now()])
    conn.commit()

    with closing(conn.cursor()) as cur:

        def history_row_count():
            # Count of rows currently in the history table, read through the shared cursor.
            cur.execute("SELECT count(*) AS count FROM history")
            return cur.fetchall()[0]['count']

        with self.subTest("entries should be inserted correctly"):
            self.assertEqual(history_row_count(), 10)
        delete(n_keep=10, db_path=self.db_path)
        with self.subTest("no entries should be deleted"):
            self.assertEqual(history_row_count(), 10)
    conn.close()
def test_error_raised(self):
    """A run that raises should be stored as 'Failed', with its log and exception type recorded."""
    sys_obj = sample_system_object()
    rec = HistoryRecorder(system=sys_obj, est_path="test.py", db_path=self.db_path)
    try:
        with rec:
            print("Test Log Capture")
            print("Line 2")
            raise RuntimeError("Training Died")
    except RuntimeError:
        pass  # the crash is intentional; only what got recorded matters

    conn = connect(self.db_path)
    run_key = sys_obj.exp_id

    # --- history table ---
    with closing(conn.cursor()) as cur:
        cur.execute("SELECT * FROM history WHERE pk = (?)", [run_key])
        history_rows = cur.fetchall()
    with self.subTest("History Captured"):
        self.assertEqual(len(history_rows), 1)
    entry = history_rows[0]
    for label, column, wanted in (
        ("File name captured", 'file', 'test.py'),
        ("Status updated", 'status', 'Failed'),
    ):
        with self.subTest(label):
            self.assertEqual(entry[column], wanted)

    # --- logs table: raw text including trailing newline ---
    with closing(conn.cursor()) as cur:
        cur.execute("SELECT * FROM logs WHERE fk = (?)", [run_key])
        log_rows = cur.fetchall()
    with self.subTest("Log Captured"):
        self.assertEqual(len(log_rows), 1)
    log_entry = log_rows[0]
    with self.subTest("Complete log captured"):
        self.assertEqual(log_entry['log'], "Test Log Capture\nLine 2\n")

    # --- errors table ---
    with closing(conn.cursor()) as cur:
        cur.execute("SELECT * FROM errors WHERE fk = (?)", [run_key])
        error_rows = cur.fetchall()
    with self.subTest("Error Captured"):
        self.assertEqual(len(error_rows), 1)
    error_entry = error_rows[0]
    with self.subTest("Error info captured"):
        self.assertEqual(error_entry['exc_type'], "RuntimeError")
    conn.close()
def test_auto_delete(self):
    """With n_keep=5, recording seven runs should retain only the five most recent ones."""
    update_settings(n_keep=5, db_path=self.db_path)
    for run_idx in range(7):
        run_system = sample_system_object()
        with HistoryRecorder(system=run_system, est_path=f"{run_idx}", db_path=self.db_path):
            print(f"Run {run_idx}")

    conn = connect(self.db_path)
    with closing(conn.cursor()) as cur:
        cur.execute("SELECT * FROM history")
        rows = cur.fetchall()
    with self.subTest("Ensure correct number retained"):
        self.assertEqual(len(rows), 5)
    with self.subTest("Ensure correct entries retained"):
        # Runs 0 and 1 are the oldest and should have been evicted.
        surviving = {row['file'] for row in rows}
        self.assertSetEqual(surviving, {f"{k}" for k in range(2, 7)})
    conn.close()
def test_over_threshold(self):
    """delete(n_keep=9) on ten history rows should remove exactly one row: the oldest (pk 0)."""
    from datetime import timedelta

    # Use strictly increasing train_start values so "oldest" is unambiguous. Back-to-back
    # datetime.now() calls may return identical values on coarse-resolution clocks, which
    # would make the "oldest entry" check below flaky.
    start = datetime.now()
    db = connect(self.db_path)
    for i in range(10):
        db.execute("INSERT INTO history (pk, train_start) VALUES (?, ?)",
                   [i, start + timedelta(seconds=i)])
    db.commit()
    with closing(db.cursor()) as cursor:
        with self.subTest("entries should be inserted correctly"):
            cursor.execute("SELECT count(*) AS count FROM history")
            results = cursor.fetchall()
            self.assertEqual(results[0]['count'], 10)
        delete(n_keep=9, db_path=self.db_path)
        with self.subTest("one entry should be deleted"):
            cursor.execute("SELECT count(*) AS count FROM history")
            results = cursor.fetchall()
            self.assertEqual(results[0]['count'], 9)
        with self.subTest("oldest entry should be deleted"):
            cursor.execute("SELECT pk FROM history")
            results = {result['pk'] for result in cursor.fetchall()}
            self.assertSetEqual(results, set(range(1, 10)))
    db.close()