def testClose_canBeCalledMultipleTimes(self):
    """Closing a RunReader a second time (empty event log) must not raise."""
    run_id = db.RUN_ROWID.create(1, 1)
    log_path = self._save_records('events.out.tfevents.0.localhost', [])
    with self.connect_db() as conn:
        with self.EventLog(log_path) as event_log:
            # Not a `with` block on purpose: the test drives close() by hand
            # so that it can invoke it twice.
            reader = loader.RunReader(run_id, 'doodle')
            reader.add_event_log(conn, event_log)
            for _ in range(2):
                reader.close()
# Example #2
# 0
 def testRestartProgram_resumesThings(self):
     """A fresh RunReader picks up where a previous one saved progress.

     Simulates a program restart: the first reader consumes one of two
     events and calls save_progress() on the shared DB connection; a brand
     new reader over the same file must then yield the remaining event.
     """
     run_id = db.RUN_ROWID.create(1, 1)
     first = tf.Event(step=123)
     second = tf.Event(step=456)
     log_path = self._save_records(
         'events.out.tfevents.1.localhost',
         [first.SerializeToString(), second.SerializeToString()])

     def consume_one(conn, expected, checkpoint):
         # Opens new EventLog/RunReader objects, as a restarted program would.
         with self.EventLog(log_path) as event_log:
             with loader.RunReader(run_id, 'doodle') as reader:
                 reader.add_event_log(conn, event_log)
                 self.assertEqual(expected, reader.get_next_event())
                 if checkpoint:
                     reader.save_progress(conn)

     with self.connect_db() as conn:
         consume_one(conn, first, checkpoint=True)    # before the "restart"
         consume_one(conn, second, checkpoint=False)  # after the "restart"
 def testReadOneEvent(self):
     """A log holding a single event yields that event once, then None."""
     run_id = db.RUN_ROWID.create(1, 1)
     only_event = tf.Event(step=123)
     log_path = self._save_records('events.out.tfevents.0.localhost',
                                   [only_event.SerializeToString()])
     with self.connect_db() as conn, self.EventLog(log_path) as event_log:
         with loader.RunReader(run_id, 'doodle') as reader:
             reader.add_event_log(conn, event_log)
             self.assertEqual(only_event, reader.get_next_event())
             # Nothing left after the single record.
             self.assertIsNone(reader.get_next_event())
 def testMarkWithShrinkingBatchSize_raisesValueError(self):
     """mark() after reading a smaller batch than last time raises ValueError.

     After the first mark(), both events (one per file) are read. After
     reset(), only one event is read, and the next mark() must fail with an
     error matching 'monotonic'.
     """
     run_id = db.RUN_ROWID.create(1, 1)
     first = tf.Event(step=123)
     second = tf.Event(step=456)
     first_path = self._save_records('events.out.tfevents.1.localhost',
                                     [first.SerializeToString()])
     second_path = self._save_records('events.out.tfevents.2.localhost',
                                      [second.SerializeToString()])
     with self.connect_db() as conn:
         with self.EventLog(first_path) as first_log:
             with self.EventLog(second_path) as second_log:
                 with loader.RunReader(run_id, 'doodle') as reader:
                     for event_log in (first_log, second_log):
                         reader.add_event_log(conn, event_log)
                     reader.mark()
                     # First batch: every event, spanning both files.
                     for expected in (first, second):
                         self.assertEqual(expected, reader.get_next_event())
                     self.assertIsNone(reader.get_next_event())
                     # Rewind, then read a strictly smaller batch.
                     reader.reset()
                     self.assertEqual(first, reader.get_next_event())
                     with six.assertRaisesRegex(self, ValueError,
                                                r'monotonic'):
                         reader.mark()
 def testMarkReset_acrossFiles(self):
     """reset() rewinds to the last mark() even when reads span two files.

     Both events are read, reset() rewinds, the same two events are read
     again, and a subsequent mark() succeeds.
     """
     run_id = db.RUN_ROWID.create(1, 1)
     first = tf.Event(step=123)
     second = tf.Event(step=456)
     first_path = self._save_records('events.out.tfevents.1.localhost',
                                     [first.SerializeToString()])
     second_path = self._save_records('events.out.tfevents.2.localhost',
                                      [second.SerializeToString()])
     with self.connect_db() as conn:
         with self.EventLog(first_path) as first_log:
             with self.EventLog(second_path) as second_log:
                 with loader.RunReader(run_id, 'doodle') as reader:
                     for event_log in (first_log, second_log):
                         reader.add_event_log(conn, event_log)
                     reader.mark()
                     # Read everything twice, rewinding in between.
                     for rewind in (False, True):
                         if rewind:
                             reader.reset()
                         for expected in (first, second):
                             self.assertEqual(expected,
                                              reader.get_next_event())
                         self.assertIsNone(reader.get_next_event())
                     # Same-sized batch as before: marking must not raise.
                     reader.mark()
 def testNoEventLogs_returnsNone(self):
     """A RunReader with no event logs attached has no next event."""
     with loader.RunReader(db.RUN_ROWID.create(1, 1), 'doodle') as reader:
         self.assertIsNone(reader.get_next_event())
 def testFields(self):
     """RunReader exposes the rowid and name it was constructed with."""
     expected_rowid = db.RUN_ROWID.create(1, 1)
     with loader.RunReader(expected_rowid, 'doodle') as reader:
         self.assertEqual(expected_rowid, reader.rowid)
         self.assertEqual('doodle', reader.name)
 def testEqualAndSortsByRowId(self):
     """RunReaders compare equal by rowid and sort in ascending rowid order."""
     a = loader.RunReader(db.RUN_ROWID.create(1, 1), 'doodle')
     b = loader.RunReader(db.RUN_ROWID.create(1, 2), 'doodle')
     c = loader.RunReader(db.RUN_ROWID.create(2, 1), 'doodle')
     # The list comparison below only pairs up the *same* objects, and
     # built-in containers treat identical elements as equal without calling
     # __eq__. Equality therefore has to be checked on a distinct instance
     # that shares a's rowid.
     a_twin = loader.RunReader(db.RUN_ROWID.create(1, 1), 'doodle')
     self.assertEqual(a, a_twin)
     self.assertNotEqual(a, b)
     self.assertEqual([a, b, c], sorted([c, b, a]))
 def testBadRowId_throwsValueError(self):
     """Constructing a RunReader with rowid 0 raises ValueError."""
     self.assertRaises(ValueError, loader.RunReader, 0, 'doodle')