Beispiel #1
0
 def test_happy_path(self):
     """Record a run that exits cleanly and verify its history and log rows.

     Two lines are printed inside the recorder context. Afterwards the
     `history` row matching the system's exp_id must list file 'test.py'
     with status 'Completed', and the single matching `logs` row must
     contain both printed lines.
     """
     system = sample_system_object()
     recorder = HistoryRecorder(system=system,
                                est_path="test.py",
                                db_path=self.db_path)
     with recorder:
         print("Test Log Capture")
         print("Line 2")
     # closing() guarantees the connection is released even when one of the
     # un-guarded row lookups below raises (e.g. indexing an empty list).
     with closing(connect(self.db_path)) as db:
         with closing(db.cursor()) as cursor:
             cursor.execute("SELECT * FROM history WHERE pk = (?)",
                            [system.exp_id])
             results = cursor.fetchall()
         with self.subTest("History Captured"):
             self.assertEqual(len(results), 1)
         results = results[0]
         with self.subTest("File name captured"):
             self.assertEqual(results['file'], 'test.py')
         with self.subTest("Status updated"):
             self.assertEqual(results['status'], 'Completed')
         # Logs
         with closing(db.cursor()) as cursor:
             cursor.execute("SELECT * FROM logs WHERE fk = (?)",
                            [system.exp_id])
             results = cursor.fetchall()
         with self.subTest("Log Captured"):
             self.assertEqual(len(results), 1)
         results = results[0]
         with self.subTest("Complete log captured"):
             self.assertListEqual(results['log'].splitlines(),
                                  ["Test Log Capture", "Line 2"])
Beispiel #2
0
 def test_joint_update_lower_kgl(self):
     """Update n_keep and n_keep_logs together and verify both are stored.

     Calls update_settings with n_keep=400 and n_keep_logs=300, then reads
     the settings row with pk = 0 and checks both columns.
     """
     update_settings(n_keep=400, n_keep_logs=300, db_path=self.db_path)
     # The connection was previously never closed; closing() releases it
     # whether or not the assertions pass.
     with closing(connect(self.db_path)) as db:
         with closing(db.cursor()) as cursor:
             cursor.execute("SELECT * FROM settings WHERE pk = 0")
             results = cursor.fetchall()[0]
         with self.subTest("n_keep updated"):
             self.assertEqual(results['n_keep'], 400)
         with self.subTest("n_keep_logs updated"):
             self.assertEqual(results['n_keep_logs'], 300)
Beispiel #3
0
 def test_decrease_logs(self):
     """Update only n_keep_logs and verify n_keep keeps its default.

     Calls update_settings with n_keep_logs=42, then reads the settings row
     with pk = 0 and checks that n_keep still equals DEFAULT_KEEP while
     n_keep_logs is 42.
     """
     update_settings(n_keep_logs=42, db_path=self.db_path)
     # The connection was previously never closed; closing() releases it
     # whether or not the assertions pass.
     with closing(connect(self.db_path)) as db:
         with closing(db.cursor()) as cursor:
             cursor.execute("SELECT * FROM settings WHERE pk = 0")
             results = cursor.fetchall()[0]
         with self.subTest("n_keep not updated"):
             self.assertEqual(results['n_keep'], DEFAULT_KEEP)
         with self.subTest("n_keep_logs updated"):
             self.assertEqual(results['n_keep_logs'], 42)
Beispiel #4
0
 def test_make_schema(self):
     """Verify that a fresh in-memory connection exposes the expected tables.

     Lists every table name in sqlite_master and compares the set against
     the expected table names.
     """
     # Each table is listed once; the literal used to repeat "errors".
     expected = {
         "datasets", "errors", "features", "history", "pipeline", "network",
         "postprocess", "traces", "logs", "settings"
     }
     # closing() releases the connection even if the query itself raises.
     with closing(connect(":memory:")) as db:
         with closing(db.cursor()) as cursor:
             cursor.execute(
                 "SELECT name FROM sqlite_master WHERE type = 'table'")
             results = cursor.fetchall()
         results = {result['name'] for result in results}
         self.assertSetEqual(results, expected)
Beispiel #5
0
 def test_restore_training_old_missing(self):
     """Restart a failed run whose original history row has been deleted.

     A first recorder run raises RuntimeError and its history row is then
     deleted by hand. A second run copies the first system's attributes
     onto its own system while recording. Afterwards there must be exactly
     one history row, keyed by the first system's exp_id, with status
     'Completed' and n_restarts of 1, and the matching logs row must hold
     only the lines printed during the second run.
     """
     system1 = sample_system_object()
     recorder1 = HistoryRecorder(system=system1,
                                 est_path="test.py",
                                 db_path=self.db_path)
     try:
         with recorder1:
             print("Test Log Capture")
             print("Line 2")
             raise RuntimeError("Training Died")
     except RuntimeError:
         # Expected: the failure is part of the scenario being set up.
         pass
     # closing() guarantees the connection is released even when one of the
     # un-guarded row lookups below raises (e.g. indexing an empty list).
     with closing(connect(self.db_path)) as db:
         db.execute("DELETE FROM history WHERE pk = (?)", [system1.exp_id])
         db.commit()
         system2 = sample_system_object()
         recorder2 = HistoryRecorder(system=system2,
                                     est_path="test.py",
                                     db_path=self.db_path)
         with recorder2:
             # Fake a restore wizard
             system2.__dict__.update(system1.__dict__)
             print("Line 3")
             print("Line 4")
         with closing(db.cursor()) as cursor:
             cursor.execute("SELECT * FROM history")
             results = cursor.fetchall()
         with self.subTest("History Captured and Consolidated"):
             self.assertEqual(len(results), 1)
         results = results[0]
         with self.subTest("File name captured"):
             self.assertEqual(results['file'], 'test.py')
         with self.subTest("Status updated"):
             self.assertEqual(results['status'], 'Completed')
         with self.subTest("Correct PK"):
             self.assertEqual(results['pk'], system1.exp_id)
         with self.subTest("Restarts Incremented"):
             self.assertEqual(results['n_restarts'], 1)
         # Logs
         with closing(db.cursor()) as cursor:
             cursor.execute("SELECT * FROM logs WHERE fk = (?)",
                            [system1.exp_id])
             results = cursor.fetchall()
         with self.subTest("Log Captured"):
             self.assertEqual(len(results), 1)
         results = results[0]
         with self.subTest("Complete log captured"):
             self.assertListEqual(results['log'].splitlines(),
                                  ["Line 3", "Line 4"])
Beispiel #6
0
 def test_initial_settings(self):
     """Verify the settings row present on a fresh in-memory connection.

     Expects exactly one settings row with pk 0, schema_version 1, n_keep
     equal to DEFAULT_KEEP and n_keep_logs equal to DEFAULT_LOG_KEEP.
     """
     # The connection was previously never closed; closing() releases it
     # whether or not the checks below succeed.
     with closing(connect(":memory:")) as db:
         with closing(db.cursor()) as cursor:
             cursor.execute("SELECT * FROM settings")
             results = cursor.fetchall()
         with self.subTest("should be exactly one setting row"):
             self.assertEqual(len(results), 1)
         results = results[0]
         with self.subTest("settings should have pk = 0"):
             self.assertEqual(results['pk'], 0)
         with self.subTest("schema version should be 1"):
             self.assertEqual(results['schema_version'], 1)
         with self.subTest(f"n_keep should be {DEFAULT_KEEP}"):
             self.assertEqual(results['n_keep'], DEFAULT_KEEP)
         with self.subTest(f"n_keep_logs should be {DEFAULT_LOG_KEEP}"):
             self.assertEqual(results['n_keep_logs'], DEFAULT_LOG_KEEP)
Beispiel #7
0
 def test_fk_update(self):
     """Check that changing a history pk is reflected in the network fk.

     Inserts a history row (pk 0) and a network row (fk 0), confirms the
     network row, then rewrites the history pk to 10 and expects the
     network fk to read 10 as well.
     """
     conn = connect(":memory:")
     for statement in ("INSERT INTO history (pk) VALUES (0)",
                       "INSERT INTO network (fk) VALUES (0)"):
         conn.execute(statement)
     conn.commit()
     with closing(conn.cursor()) as cur:
         with self.subTest("entry should be inserted correctly"):
             rows = cur.execute("SELECT * FROM network").fetchall()
             self.assertEqual(rows[0]['fk'], 0)
         with self.subTest("entry should cascade update"):
             conn.execute("UPDATE history SET pk = 10 WHERE pk = 0")
             rows = cur.execute("SELECT * FROM network").fetchall()
             self.assertEqual(rows[0]['fk'], 10)
     conn.close()
Beispiel #8
0
 def test_fk_delete(self):
     """Check that deleting a history row removes its network row.

     Inserts a history row (pk 0) and a network row (fk 0), confirms the
     network table holds one row, then deletes the history row and expects
     the network table to be empty.
     """
     conn = connect(":memory:")
     for statement in ("INSERT INTO history (pk) VALUES (0)",
                       "INSERT INTO network (fk) VALUES (0)"):
         conn.execute(statement)
     conn.commit()
     with closing(conn.cursor()) as cur:
         with self.subTest("entry should be inserted correctly"):
             rows = cur.execute(
                 "SELECT count(*) AS count FROM network").fetchall()
             self.assertEqual(rows[0]['count'], 1)
         with self.subTest("entry should cascade delete"):
             conn.execute("DELETE FROM history WHERE pk = 0")
             rows = cur.execute(
                 "SELECT count(*) AS count FROM network").fetchall()
             self.assertEqual(rows[0]['count'], 0)
     conn.close()
Beispiel #9
0
 def test_under_threshold(self):
     """Check that delete() removes nothing when n_keep equals the row count.

     Inserts ten history rows (pk 0-9, train_start set at insertion time),
     calls delete with n_keep=10 and expects all ten rows to remain.
     """
     conn = connect(self.db_path)
     for pk in range(10):
         conn.execute("INSERT INTO history (pk, train_start) VALUES (?, ?)",
                      [pk, datetime.now()])
     conn.commit()
     with closing(conn.cursor()) as cur:
         with self.subTest("entries should be inserted correctly"):
             rows = cur.execute(
                 "SELECT count(*) AS count FROM history").fetchall()
             self.assertEqual(rows[0]['count'], 10)
         delete(n_keep=10, db_path=self.db_path)
         with self.subTest("no entries should be deleted"):
             rows = cur.execute(
                 "SELECT count(*) AS count FROM history").fetchall()
             self.assertEqual(rows[0]['count'], 10)
     conn.close()
Beispiel #10
0
 def test_error_raised(self):
     """Record a run that raises and verify history, log and error rows.

     A RuntimeError is raised inside the recorder context after two lines
     are printed. Afterwards the matching `history` row must list file
     'test.py' with status 'Failed', the single `logs` row must hold the
     two printed lines, and the single `errors` row must report exc_type
     'RuntimeError'.
     """
     system = sample_system_object()
     recorder = HistoryRecorder(system=system,
                                est_path="test.py",
                                db_path=self.db_path)
     try:
         with recorder:
             print("Test Log Capture")
             print("Line 2")
             raise RuntimeError("Training Died")
     except RuntimeError:
         # Expected: the failure is the scenario under test.
         pass
     # closing() guarantees the connection is released even when one of the
     # un-guarded row lookups below raises (e.g. indexing an empty list).
     with closing(connect(self.db_path)) as db:
         with closing(db.cursor()) as cursor:
             cursor.execute("SELECT * FROM history WHERE pk = (?)",
                            [system.exp_id])
             results = cursor.fetchall()
         with self.subTest("History Captured"):
             self.assertEqual(len(results), 1)
         results = results[0]
         with self.subTest("File name captured"):
             self.assertEqual(results['file'], 'test.py')
         with self.subTest("Status updated"):
             self.assertEqual(results['status'], 'Failed')
         # Logs
         with closing(db.cursor()) as cursor:
             cursor.execute("SELECT * FROM logs WHERE fk = (?)",
                            [system.exp_id])
             results = cursor.fetchall()
         with self.subTest("Log Captured"):
             self.assertEqual(len(results), 1)
         results = results[0]
         with self.subTest("Complete log captured"):
             self.assertEqual(results['log'], "Test Log Capture\nLine 2\n")
         # Error
         with closing(db.cursor()) as cursor:
             cursor.execute("SELECT * FROM errors WHERE fk = (?)",
                            [system.exp_id])
             results = cursor.fetchall()
         with self.subTest("Error Captured"):
             self.assertEqual(len(results), 1)
         results = results[0]
         with self.subTest("Error info captured"):
             self.assertEqual(results['exc_type'], "RuntimeError")
Beispiel #11
0
    def test_auto_delete(self):
        """Check that only the configured number of history rows is kept.

        Sets n_keep to 5, records seven runs whose est_path values are "0"
        through "6", then expects five history rows whose file values are
        "2" through "6".
        """
        update_settings(n_keep=5, db_path=self.db_path)

        for i in range(7):
            with HistoryRecorder(system=sample_system_object(),
                                 est_path=f"{i}",
                                 db_path=self.db_path):
                print(f"Run {i}")

        conn = connect(self.db_path)
        with closing(conn.cursor()) as cur:
            rows = cur.execute("SELECT * FROM history").fetchall()

        with self.subTest("Ensure correct number retained"):
            self.assertEqual(len(rows), 5)
        with self.subTest("Ensure correct entries retained"):
            kept_files = {row['file'] for row in rows}
            self.assertSetEqual(kept_files, {f"{i}" for i in range(2, 7)})
        conn.close()
Beispiel #12
0
 def test_over_threshold(self):
     """Check that delete() trims the history table down to n_keep rows.

     Inserts ten history rows (pk 0-9, train_start set at insertion time),
     calls delete with n_keep=9, and expects nine rows to remain with pk 0
     (the first one inserted) removed.
     """
     # closing() releases the connection even if an insert or query raises.
     with closing(connect(self.db_path)) as db:
         for i in range(10):
             db.execute(
                 "INSERT INTO history (pk, train_start) VALUES (?, ?)",
                 [i, datetime.now()])
         db.commit()
         with closing(db.cursor()) as cursor:
             with self.subTest("entries should be inserted correctly"):
                 cursor.execute("SELECT count(*) AS count FROM history")
                 results = cursor.fetchall()
                 self.assertEqual(results[0]['count'], 10)
             delete(n_keep=9, db_path=self.db_path)
             with self.subTest("one entry should be deleted"):
                 cursor.execute("SELECT count(*) AS count FROM history")
                 results = cursor.fetchall()
                 self.assertEqual(results[0]['count'], 9)
             with self.subTest("oldest entry should be deleted"):
                 cursor.execute("SELECT pk FROM history")
                 results = cursor.fetchall()
                 results = {result['pk'] for result in results}
                 # set(range(...)) replaces a comprehension that only
                 # copied the range element by element.
                 expected = set(range(1, 10))
                 self.assertSetEqual(results, expected)