def testClose_canBeCalledMultipleTimes(self):
  """Calling close() twice on the same RunReader must not raise."""
  rowid = db.RUN_ROWID.create(1, 1)
  log_path = self._save_records('events.out.tfevents.0.localhost', [])
  with self.connect_db() as conn:
    with self.EventLog(log_path) as event_log:
      # Constructed without `with` so close() can be driven explicitly.
      reader = loader.RunReader(rowid, 'doodle')
      reader.add_event_log(conn, event_log)
      for _ in range(2):
        reader.close()
def testRestartProgram_resumesThings(self):
  """A fresh RunReader resumes after progress saved by an earlier reader."""
  rowid = db.RUN_ROWID.create(1, 1)
  first = tf.Event(step=123)
  second = tf.Event(step=456)
  log_path = self._save_records(
      'events.out.tfevents.1.localhost',
      [first.SerializeToString(), second.SerializeToString()])
  with self.connect_db() as conn:
    # First "run": consume one event, then record progress via `conn`.
    with self.EventLog(log_path) as event_log, \
        loader.RunReader(rowid, 'doodle') as reader:
      reader.add_event_log(conn, event_log)
      self.assertEqual(first, reader.get_next_event())
      reader.save_progress(conn)
    # Second "run": a brand-new log and reader over the same file should
    # continue with the event after the one already consumed.
    with self.EventLog(log_path) as event_log, \
        loader.RunReader(rowid, 'doodle') as reader:
      reader.add_event_log(conn, event_log)
      self.assertEqual(second, reader.get_next_event())
def testReadOneEvent(self):
  """A log holding a single event yields that event, then None."""
  rowid = db.RUN_ROWID.create(1, 1)
  expected = tf.Event(step=123)
  log_path = self._save_records('events.out.tfevents.0.localhost',
                                [expected.SerializeToString()])
  with self.connect_db() as conn, \
      self.EventLog(log_path) as event_log, \
      loader.RunReader(rowid, 'doodle') as reader:
    reader.add_event_log(conn, event_log)
    self.assertEqual(expected, reader.get_next_event())
    # The log is exhausted after its only record.
    self.assertIsNone(reader.get_next_event())
def testMarkWithShrinkingBatchSize_raisesValueError(self):
  """mark() raises ValueError when fewer events were read since reset().

  The reader first consumes both events after a mark. It then rewinds with
  reset(), consumes only one event, and marks again. That final mark() must
  raise a ValueError whose message matches 'monotonic'.
  """
  rowid = db.RUN_ROWID.create(1, 1)
  first = tf.Event(step=123)
  second = tf.Event(step=456)
  first_path = self._save_records('events.out.tfevents.1.localhost',
                                  [first.SerializeToString()])
  second_path = self._save_records('events.out.tfevents.2.localhost',
                                   [second.SerializeToString()])
  with self.connect_db() as conn, \
      self.EventLog(first_path) as first_log, \
      self.EventLog(second_path) as second_log, \
      loader.RunReader(rowid, 'doodle') as reader:
    for event_log in (first_log, second_log):
      reader.add_event_log(conn, event_log)
    reader.mark()
    # Full pass: one event from each file, then exhaustion.
    for expected in (first, second):
      self.assertEqual(expected, reader.get_next_event())
    self.assertIsNone(reader.get_next_event())
    # Rewind, but read a shorter batch this time.
    reader.reset()
    self.assertEqual(first, reader.get_next_event())
    with six.assertRaisesRegex(self, ValueError, r'monotonic'):
      reader.mark()
def testMarkReset_acrossFiles(self):
  """reset() rewinds to the last mark even when events span several files."""
  rowid = db.RUN_ROWID.create(1, 1)
  first = tf.Event(step=123)
  second = tf.Event(step=456)
  first_path = self._save_records('events.out.tfevents.1.localhost',
                                  [first.SerializeToString()])
  second_path = self._save_records('events.out.tfevents.2.localhost',
                                   [second.SerializeToString()])
  with self.connect_db() as conn, \
      self.EventLog(first_path) as first_log, \
      self.EventLog(second_path) as second_log, \
      loader.RunReader(rowid, 'doodle') as reader:
    reader.add_event_log(conn, first_log)
    reader.add_event_log(conn, second_log)

    def expect_both_then_end():
      # Both events, in file order, followed by exhaustion.
      self.assertEqual(first, reader.get_next_event())
      self.assertEqual(second, reader.get_next_event())
      self.assertIsNone(reader.get_next_event())

    reader.mark()
    expect_both_then_end()
    reader.reset()
    expect_both_then_end()
    # After re-reading through the end, this mark() must not raise.
    reader.mark()
def testNoEventLogs_returnsNone(self):
  """A RunReader with no event logs attached yields None immediately."""
  with loader.RunReader(db.RUN_ROWID.create(1, 1), 'doodle') as reader:
    self.assertIsNone(reader.get_next_event())
def testFields(self):
  """The constructor's name and rowid are exposed as `name` and `rowid`."""
  expected_rowid = db.RUN_ROWID.create(1, 1)
  expected_name = 'doodle'
  with loader.RunReader(expected_rowid, expected_name) as reader:
    self.assertEqual(expected_name, reader.name)
    self.assertEqual(expected_rowid, reader.rowid)
def testEqualAndSortsByRowId(self):
  """RunReaders that share a name sort in ascending rowid order.

  The readers are opened with `with` so they are closed when the test ends.
  This matches the other RunReader tests; the original version never closed
  them.
  """
  with loader.RunReader(db.RUN_ROWID.create(1, 1), 'doodle') as a, \
      loader.RunReader(db.RUN_ROWID.create(1, 2), 'doodle') as b, \
      loader.RunReader(db.RUN_ROWID.create(2, 1), 'doodle') as c:
    self.assertEqual([a, b, c], sorted([c, b, a]))
def testBadRowId_throwsValueError(self):
  """The RunReader constructor rejects a rowid of 0 with ValueError."""
  self.assertRaises(ValueError, loader.RunReader, 0, 'doodle')