def testSessionLogStartMessageDiscardsExpiredEvents(self):
    """A SessionLog.START event purges events recorded after its step.

    This purge is preferred over the out-of-order step discard logic, but
    it only applies to event protos that carry the SessionLog enum, which
    was added to event.proto for file_version >= brain.Event:2.
    """
    gen = _EventGenerator(self)
    acc = ea.EventAccumulator(gen)

    # Declare the brain.Event:2 file version so the SessionLog-based
    # purge logic is eligible to run.
    gen.AddEvent(
        event_pb2.Event(wall_time=0, step=1, file_version="brain.Event:2")
    )

    # Record scalars on both sides of the upcoming restart step (201).
    # Insertion order of this dict matches the original call order.
    scalar_steps = {"s1": (100, 200, 300, 400), "s2": (202, 203)}
    for tag, steps in scalar_steps.items():
        for step in steps:
            gen.AddScalarTensor(tag, wall_time=1, step=step, value=20)

    # A restart at step 201: previously recorded events with later steps
    # are expected to be discarded.
    restart = event_pb2.SessionLog(status=event_pb2.SessionLog.START)
    gen.AddEvent(event_pb2.Event(wall_time=2, step=201, session_log=restart))
    acc.Reload()

    def steps_for(tag):
        return [event.step for event in acc.Tensors(tag)]

    self.assertEqual(steps_for("s1"), [100, 200])
    self.assertEqual(steps_for("s2"), [])
def testSessionLogSummaries(self):
    """Session-log events are reported per status by the inspector.

    Writes one SessionLog event per step and checks that the printable
    summary lists the exact steps of START and STOP events and the number
    of CHECKPOINT events.
    """
    # Previously this method was defined twice in a row; the second
    # definition silently replaced the first, so only one copy is kept.
    session_log_type = event_pb2.SessionLog
    statuses = [
        session_log_type.START,
        session_log_type.CHECKPOINT,
        session_log_type.CHECKPOINT,
        session_log_type.CHECKPOINT,
        session_log_type.STOP,
        session_log_type.START,
        session_log_type.STOP,
    ]
    # One event per step, in order: step i carries statuses[i].
    data = [
        {"session_log": session_log_type(status=status), "step": step}
        for step, status in enumerate(statuses)
    ]
    self._WriteScalarSummaries(data)

    units = efi.get_inspection_units(self.logdir)
    self.assertEqual(1, len(units))
    printable = efi.get_dict_to_print(units[0].field_to_obs)
    self.assertEqual(printable["sessionlog:start"]["steps"], [0, 5])
    self.assertEqual(printable["sessionlog:stop"]["steps"], [4, 6])
    self.assertEqual(printable["sessionlog:checkpoint"]["num_steps"], 3)