def test_event_handler(self):
    """Check that TensorboardEventHandler records aggregated stats as scalars.

    A FixedCountStat with a window of 2 receives the values 0..9, which fills
    five windows. The SUM and COUNT of each window must be readable back from
    the run directory through tensorboard's EventMultiplexer "scalars" plugin.
    """
    with self.create_summary_writer() as w:
        handle = register_event_handler(TensorboardEventHandler(w))
        try:
            s = FixedCountStat(
                "asdf",
                (Aggregation.SUM, Aggregation.COUNT),
                2,
            )
            for i in range(10):
                s.add(i)
            # 10 samples exactly fill five windows of 2, so nothing is pending.
            self.assertEqual(s.count, 0)
        finally:
            # Always detach the handler, even if something above fails, so it
            # does not keep receiving events after this writer is closed.
            unregister_event_handler(handle)

    # Read back only after the writer's context has exited (presumably this
    # is what guarantees the events are flushed to disk).
    mul = event_multiplexer.EventMultiplexer()
    mul.AddRunsFromDirectory(self.temp_dirs[-1].name)
    mul.Reload()
    scalar_dict = mul.PluginRunToTagToContent("scalars")
    raw_result = {
        tag: mul.Tensors(run, tag)
        for run, run_dict in scalar_dict.items()
        for tag in run_dict
    }
    scalars = {
        tag: [e.tensor_proto.float_val[0] for e in events]
        for tag, events in raw_result.items()
    }
    self.assertEqual(scalars, {
        "asdf.sum": [1, 5, 9, 13, 17],
        "asdf.count": [2, 2, 2, 2, 2],
    })
def test_interval_stat(self) -> None:
    """Check that an IntervalStat with a 1 ms interval emits an event.

    The test polls for up to 100 iterations of ~1 ms instead of using one
    fixed sleep, because sleep accuracy differs between platforms.
    """
    events = []

    def handler(event):
        events.append(event)

    handle = register_event_handler(handler)
    try:
        s = IntervalStat(
            "asdf",
            (Aggregation.SUM, Aggregation.COUNT),
            timedelta(milliseconds=1),
        )
        self.assertIsInstance(s, Stat)
        self.assertEqual(s.name, "asdf")

        s.add(2)
        for _ in range(100):
            # NOTE: different platforms sleep may be inaccurate so we loop
            # instead (i.e. win)
            time.sleep(1 / 1000)  # ms
            s.add(3)
            if events:
                break
        self.assertGreaterEqual(len(events), 1)
    finally:
        # Unregister even when an assertion fails, so the handler does not
        # leak into other tests.
        unregister_event_handler(handle)
def test_interval_stat(self) -> None:
    """Check that an IntervalStat with a 1 ms interval emits at least one event.

    A single fixed sleep is flaky on platforms where sleep is inaccurate
    (e.g. Windows), so the test polls: it sleeps ~1 ms and adds a sample,
    up to 100 times, and stops as soon as an event has been delivered.
    """
    # NOTE(review): another method named `test_interval_stat` appears earlier
    # in this chunk; if both live in the same class, only this later
    # definition runs. Confirm and rename one of them if so.
    events = []

    def handler(event):
        events.append(event)

    handle = register_event_handler(handler)
    try:
        s = IntervalStat(
            "asdf",
            (Aggregation.SUM, Aggregation.COUNT),
            timedelta(milliseconds=1),
        )
        s.add(2)
        for _ in range(100):
            time.sleep(0.001)
            s.add(3)
            if events:
                break
        self.assertEqual(s.name, "asdf")
        self.assertGreaterEqual(len(events), 1)
    finally:
        # Unregister even when an assertion fails, so the handler does not
        # leak into other tests.
        unregister_event_handler(handle)
def test_event_handler(self) -> None:
    """Verify event delivery to a registered handler, and its removal.

    Each call to log_event is delivered once to the registered handler.
    After the handler is unregistered, further events must not reach it.
    """
    received = []

    def _collect(event: Event) -> None:
        received.append(event)

    token = register_event_handler(_collect)
    sample = Event(
        name="torch.monitor.TestEvent",
        timestamp=datetime.now(),
        data={},
    )

    # First delivery: exactly the event we logged arrives.
    log_event(sample)
    self.assertEqual(len(received), 1)
    self.assertEqual(received[0], sample)

    # Logging again delivers a second copy.
    log_event(sample)
    self.assertEqual(len(received), 2)

    # Once unregistered, the handler receives nothing more.
    unregister_event_handler(token)
    log_event(sample)
    self.assertEqual(len(received), 2)