コード例 #1
0
    def test_event_handler(self):
        """Check that ``TensorboardEventHandler`` records ``Stat`` aggregates as scalars.

        A ``FixedCountStat`` with a window of 2 is fed the values 0..9. The SUM
        and COUNT series are then read back from the TensorBoard event files in
        the most recent temp dir and compared against the expected per-window
        values.
        """
        with self.create_summary_writer() as w:
            handle = register_event_handler(TensorboardEventHandler(w))
            # Unregister in ``finally`` so a failing assertion cannot leave the
            # handler registered (still holding ``w``, which the ``with`` closes)
            # and leak into later tests.
            try:
                s = FixedCountStat(
                    "asdf",
                    (Aggregation.SUM, Aggregation.COUNT),
                    2,
                )
                for i in range(10):
                    s.add(i)
                # 10 values is a whole number of windows of size 2; the test
                # expects the pending count to be 0 afterwards.
                self.assertEqual(s.count, 0)
            finally:
                unregister_event_handler(handle)

        # Read back everything the handler wrote under the last temp dir.
        mul = event_multiplexer.EventMultiplexer()
        mul.AddRunsFromDirectory(self.temp_dirs[-1].name)
        mul.Reload()
        scalar_dict = mul.PluginRunToTagToContent("scalars")
        raw_result = {
            tag: mul.Tensors(run, tag)
            for run, run_dict in scalar_dict.items() for tag in run_dict
        }
        # Each scalar event stores its value as the first float in the tensor proto.
        scalars = {
            tag: [e.tensor_proto.float_val[0] for e in events]
            for tag, events in raw_result.items()
        }
        # Windows (0,1), (2,3), ..., (8,9): sums 1, 5, 9, 13, 17; two values each.
        self.assertEqual(scalars, {
            "asdf.sum": [1, 5, 9, 13, 17],
            "asdf.count": [2, 2, 2, 2, 2],
        })
コード例 #2
0
    def test_interval_stat(self) -> None:
        """Check that an ``IntervalStat`` with a 1 ms interval eventually logs an event."""
        events = []

        def handler(event):
            events.append(event)

        handle = register_event_handler(handler)
        # Unregister in ``finally`` so a failing assertion cannot leave this
        # handler registered for subsequent tests.
        try:
            s = IntervalStat(
                "asdf",
                (Aggregation.SUM, Aggregation.COUNT),
                timedelta(milliseconds=1),
            )
            self.assertIsInstance(s, Stat)
            self.assertEqual(s.name, "asdf")

            s.add(2)
            for _ in range(100):
                # NOTE: sleep accuracy differs across platforms (e.g. Windows can
                # be much coarser than 1 ms), so poll up to 100 times instead of
                # trusting a single sleep to cover the interval.
                time.sleep(1 / 1000)  # 1 ms
                s.add(3)
                if events:
                    break
            self.assertGreaterEqual(len(events), 1)
        finally:
            unregister_event_handler(handle)
コード例 #3
0
    def test_interval_stat(self) -> None:
        """Check that an ``IntervalStat`` with a 1 ms interval eventually logs an event."""
        events = []

        def handler(event):
            events.append(event)

        handle = register_event_handler(handler)
        # Unregister in ``finally`` so a failing assertion cannot leave this
        # handler registered for subsequent tests.
        try:
            s = IntervalStat(
                "asdf",
                (Aggregation.SUM, Aggregation.COUNT),
                timedelta(milliseconds=1),
            )
            s.add(2)
            # A single short sleep is flaky where timer granularity is coarse
            # (e.g. Windows); poll in 1 ms steps, bounded at 100 attempts,
            # until the handler has seen an event.
            for _ in range(100):
                time.sleep(1 / 1000)  # 1 ms
                s.add(3)
                if events:
                    break
            self.assertEqual(s.name, "asdf")
            self.assertGreaterEqual(len(events), 1)
        finally:
            unregister_event_handler(handle)
コード例 #4
0
ファイル: test_monitor.py プロジェクト: nateanl/pytorch
    def test_event_handler(self) -> None:
        """Check that a registered handler receives logged events until unregistered.

        The same ``Event`` is logged twice while the handler is registered, and
        both deliveries must be observed. It is then logged once more after
        unregistering, and that delivery must not be observed.
        """
        events = []

        def handler(event: Event) -> None:
            events.append(event)

        handle = register_event_handler(handler)
        e = Event(
            name="torch.monitor.TestEvent",
            timestamp=datetime.now(),
            data={},
        )
        # Unregister in ``finally`` so a failing assertion cannot leave this
        # handler registered for subsequent tests.
        try:
            log_event(e)
            self.assertEqual(len(events), 1)
            self.assertEqual(events[0], e)
            log_event(e)
            self.assertEqual(len(events), 2)
        finally:
            unregister_event_handler(handle)

        # After unregistering, further events must not reach the handler.
        log_event(e)
        self.assertEqual(len(events), 2)