示例#1
0
 def test_save_invalid_format(self):
     """Saving a trace to a path with the ``.invalid`` extension must raise."""
     viz = VizTracer()
     viz.start()
     # A trivial bit of work so the tracer has something recorded.
     _ = len([1, 2])
     viz.stop()
     bad_path = "test.invalid"
     with self.assertRaises(Exception):
         viz.save(bad_path)
示例#2
0
    def do_one_function(self, func):
        """Benchmark ``func`` under several tracers and print the overhead.

        ``func`` is called four times: once untraced (the baseline), once
        under ``VizTracer(verbose=0, vdb=True)``, once under
        ``VizTracer(verbose=0)`` and once under ``cProfile``. For the two
        VizTracer runs, the time spent in ``parse()`` and in producing JSON
        is measured as well. Four lines are printed, each made of entries
        formatted as ``seconds(ratio-to-baseline)[label]``.

        Args:
            func: zero-argument callable to benchmark. It is invoked four
                times, so any side effects it has happen four times.
        """
        # Baseline: the original, untraced speed.
        # NOTE(review): assumes Timer.get_time() returns the time elapsed
        # since the ``with`` block was entered, which is why every read is
        # done inside the block -- confirm against Timer's definition.
        with Timer() as t:
            func()
            origin = t.get_time()

        # With viztracer + c tracer + vdb
        tracer = VizTracer(verbose=0, vdb=True)
        tracer.start()
        with Timer() as t:
            func()
            instrumented_c_vdb = t.get_time()
        tracer.stop()
        # Post-processing is timed separately from the traced run itself.
        with Timer() as t:
            tracer.parse()
            instrumented_c_vdb_parse = t.get_time()
        # Writes a scratch file into the current working directory; it is
        # deleted right after the measurement.
        with Timer() as t:
            tracer.save(output_file="tmp.json")
            instrumented_c_vdb_json = t.get_time()
        os.remove("tmp.json")
        tracer.clear()

        # With viztracer + c tracer
        tracer = VizTracer(verbose=0)
        tracer.start()
        with Timer() as t:
            func()
            instrumented_c = t.get_time()
        tracer.stop()
        with Timer() as t:
            tracer.parse()
            instrumented_c_parse = t.get_time()
        # Only the generation call is timed; its return value is discarded.
        with Timer() as t:
            tracer.generate_json(allow_binary=True)
            instrumented_c_json = t.get_time()
        tracer.clear()

        # With cProfile. Only the wall time of the profiled run is
        # reported; the profile collected in ``pr`` is never printed.
        pr = cProfile.Profile()
        pr.enable()
        with Timer() as t:
            func()
            cprofile = t.get_time()
        pr.disable()

        def time_str(name, origin, instrumented):
            # Format "<seconds>(<ratio to baseline>)[<label>] ". Raises
            # ZeroDivisionError if the baseline ``origin`` is 0.
            return "{:.9f}({:.2f})[{}] ".format(instrumented,
                                                instrumented / origin, name)

        # One line for the baseline, one per VizTracer configuration
        # (run, parse, json), and one for cProfile.
        print(time_str("origin", origin, origin))
        print(
            time_str("c+vdb", origin, instrumented_c_vdb) +
            time_str("parse", origin, instrumented_c_vdb_parse) +
            time_str("json", origin, instrumented_c_vdb_json))
        print(
            time_str("c", origin, instrumented_c) +
            time_str("parse", origin, instrumented_c_parse) +
            time_str("json", origin, instrumented_c_json))
        print(time_str("cProfile", origin, cprofile))
示例#3
0
 def test_basic(self):
     """After an empty start/stop/save cycle, the plugin has counted 4
     events and its handler has been triggered."""
     plugin = MyPlugin()
     viz = VizTracer(plugins=[plugin])
     viz.start()
     viz.stop()
     viz.save()
     self.assertEqual(plugin.event_counter, 4)
     self.assertEqual(plugin.handler_triggered, True)
示例#4
0
 def test_save(self):
     """save() writes an HTML report to the relative path ./tmp/result.html.

     The ./tmp directory is removed in a ``finally`` clause, so a failing
     save() or assertion does not leave files behind for later tests.
     """
     out_path = "./tmp/result.html"
     tracer = VizTracer(tracer_entries=10)
     tracer.start()
     fib(5)
     tracer.stop()
     tracer.parse()
     try:
         tracer.save(out_path)
         self.assertTrue(os.path.exists(out_path))
     finally:
         # ignore_errors: if save() failed before ./tmp existed, a missing
         # directory must not replace the original failure with an error
         # from the cleanup itself.
         shutil.rmtree("./tmp", ignore_errors=True)
示例#5
0
 def test_double_parse(self):
     """Calling parse() a second time must not change what save() returns."""
     viz = VizTracer()
     viz.start()
     fib(10)
     viz.stop()
     viz.parse()
     first = viz.save()
     viz.parse()
     second = viz.save()
     self.assertEqual(first, second)
示例#6
0
 def test_buffer_wrap(self):
     """Fifteen attribute updates with ``tracer_entries=10`` leave exactly
     10 entries after parse()."""
     viz = VizTracer(tracer_entries=10)
     viz.start()
     watched = VizObject(viz, "my variable")
     for value in range(15):
         watched.hello = value
     viz.stop()
     n_entries = viz.parse()
     viz.save()
     self.assertEqual(n_entries, 10)
示例#7
0
 def test_inherit(self):
     """Setting three attributes on a ``Hello`` object and calling log()
     yields 2 entries.

     ``Hello`` is defined elsewhere -- presumably a VizObject subclass,
     given the test name; confirm against its definition.
     """
     viz = VizTracer()
     viz.start()
     obj = Hello(viz, "name")
     obj.b = 1
     obj.c = 2
     obj.d = 3
     obj.log()
     viz.stop()
     n_entries = viz.parse()
     viz.save()
     self.assertEqual(n_entries, 2)
示例#8
0
 def test_basic(self):
     """Trace two ``Hello`` instances and check that parse() reports 8
     entries. Only the first instance is restricted via
     set_viztracer_attributes(["a", "b"])."""
     viz = VizTracer()
     viz.start()
     obj = Hello(viz)
     obj.set_viztracer_attributes(["a", "b"])
     obj.change_val()
     obj.change_val2()
     # Rebinding the same name drops this function's reference to the
     # first instance while tracing is still active; the expected entry
     # count may depend on that, so keep the single reused name.
     obj = Hello(viz)
     obj.change_val()
     viz.stop()
     n_entries = viz.parse()
     viz.save()
     self.assertEqual(n_entries, 8)
示例#9
0
 def test_instant(self):
     """One instant event followed by three user calls gives 4 entries."""
     def s():
         return 0
     viz = VizTracer()
     viz.start()
     # add_instant is library code that the tracer ignores; calling it
     # first checks that it does not leave the tracer stuck in an
     # "ignoring" state for the user calls that follow.
     viz.add_instant("name", {"a": 1})
     s()
     s()
     s()
     viz.stop()
     n_entries = viz.parse()
     viz.save()
     self.assertEqual(n_entries, 4)
示例#10
0
 def test_trigger_on_change(self):
     """A VizObject created with ``trigger_on_change=False`` yields 2
     entries after four attribute assignments and one explicit log()."""
     viz = VizTracer()
     # Deliberately call stop() and cleanup() before the first start().
     viz.stop()
     viz.cleanup()
     viz.start()
     watched = VizObject(viz, "my variable", trigger_on_change=False)
     watched.hello = 1
     watched.b = 2
     watched.c = 3
     watched.lol = 4
     watched.log()
     viz.stop()
     n_entries = viz.parse()
     viz.save()
     del watched
     self.assertEqual(n_entries, 2)
示例#11
0
 def test_c_run_after_clear(self):
     """Two identical runs on one tracer give equal entry counts but
     different saved reports."""
     viz = VizTracer()

     def render_report():
         # Save the current trace into an in-memory buffer and return its
         # text; only ever called after viz.stop().
         with io.StringIO() as buf:
             viz.save(buf)
             return buf.getvalue()

     viz.start()
     fib(5)
     viz.stop()
     first_count = viz.parse()
     first_report = render_report()

     viz.start()
     fib(5)
     viz.stop()
     second_count = viz.parse()
     second_report = render_report()

     self.assertEqual(first_count, second_count)
     self.assertNotEqual(first_report, second_report)
示例#12
0
    def test_basic(self):
        """Trace four MyThread workers with ``max_stack_depth=4`` and expect
        more than 160 entries."""
        viz = VizTracer(max_stack_depth=4)
        viz.start()

        # Create all workers, start all of them, then wait for all of them.
        workers = [MyThread(), MyThread(), MyThread(), MyThread()]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        viz.stop()
        n_entries = viz.parse()
        viz.save("testres.html")
        self.assertGreater(n_entries, 160)
示例#13
0

def g(a, b):
    """Update two locals and return None; part of a tracing demo.

    ``h`` is not defined in this snippet and must come from elsewhere in
    the module. The updated ``a`` and ``b`` are discarded, so presumably
    they exist only to give the tracer something to record.
    """
    a += h(a)
    b += 3
    # raise Exception("lol")  # disabled; if re-enabled it propagates to the caller


def f(a, b):
    """Demo step: update a local, set ``ob.s``, then call ``g`` and ``h``.

    ``ob`` is a module-level global assigned after this definition; it is
    looked up when ``f`` runs, not when it is defined.
    """
    # wthell.wth()  # disabled
    a = a + 2
    ob.s = str(b)
    g(a + 1, b * 2)
    # wthell.wth()  # disabled
    h(36)


def t(a):
    """Demo entry point: call ``f`` twice with arguments derived from ``a``."""
    f(a + 1, a + 2)
    a += 3
    f(a + 5, 2)


# Trace a single call of t(3) and write the result to vdb_basic.json.
tracer = VizTracer()
counter = VizCounter(tracer, "a")  # not referenced by the visible code
ob = VizObject(tracer, "b")  # the global that f() assigns ``ob.s`` on
tracer.start()
t(3)
tracer.stop()
tracer.save("vdb_basic.json")
示例#14
0
        return 1
    time.sleep(0.0000001)
    return fib(n - 1) + fib(n - 2)


class MyThread(threading.Thread):
    """Thread whose run() calls fib(7) and discards the result."""

    def run(self):
        fib(7)


tracer = VizTracer(verbose=1)
tracer.start()

# Four identical workers; all are started before any is joined so they
# can run concurrently while being traced.
threads = [MyThread(), MyThread(), MyThread(), MyThread()]
thread1, thread2, thread3, thread4 = threads

for thread in threads:
    thread.start()

for thread in threads:
    thread.join()

tracer.stop()
tracer.save(output_file="vdb_multithread.json")
示例#15
0

if __name__ == "__main__":
    # NOTE(review): argparse is not used in this block itself; it is kept
    # at module scope because setup_args (defined elsewhere) may rely on
    # it -- verify before removing.
    import argparse

    # Read and parse arguments.
    function_to_run, profiling_type, log_name, sort_stats = setup_args()

    # Start the requested profiler. Any other profiling_type value runs the
    # workload without in-process profiling.
    if profiling_type == 'internal_cprof':
        import cProfile
        import pstats

        profiler = cProfile.Profile()
        profiler.enable()
    elif profiling_type == 'internal_viz':
        from viztracer import VizTracer

        tracer = VizTracer()
        tracer.start()

    # Initialize and run ProfilingPlay. The finally clause stops the
    # profiler/tracer and reports its results even if the run raises, so a
    # crash does not discard the profile collected up to that point; the
    # original exception still propagates afterwards.
    try:
        prof_play = ProfilingPlay(log_name=log_name)
        prof_play.run(func_to_run=function_to_run)
    finally:
        if profiling_type == 'internal_cprof':
            profiler.disable()
            stats = pstats.Stats(profiler).sort_stats(sort_stats)
            stats.print_stats()
        elif profiling_type == 'internal_viz':
            tracer.stop()
            tracer.save(f"{function_to_run}.html")
示例#16
0

def g(a, b):
    """Update two locals and return None; part of a tracing demo.

    ``h`` is not defined in this snippet and must come from elsewhere in
    the module. The updated ``a`` and ``b`` are discarded, so presumably
    they exist only to give the tracer something to record.
    """
    a += h(a)
    b += 3
    # raise Exception("lol")  # disabled; if re-enabled it propagates to the caller


def f(a, b):
    """Demo step: update a local, set ``ob.s``, then call ``g`` and ``h``.

    ``ob`` is a module-level global assigned after this definition; it is
    looked up when ``f`` runs, not when it is defined.
    """
    # wthell.wth()  # disabled
    a = a + 2
    ob.s = str(b)
    g(a + 1, b * 2)
    # wthell.wth()  # disabled
    h(36)


def t(a):
    """Demo entry point: call ``f`` twice with arguments derived from ``a``."""
    f(a + 1, a + 2)
    a += 3
    f(a + 5, 2)


# Trace a single call of t(3) and write the result to test.json.
tracer = VizTracer()
counter = VizCounter(tracer, "a")  # not referenced by the visible code
ob = VizObject(tracer, "b")  # the global that f() assigns ``ob.s`` on
tracer.start()
t(3)
tracer.stop()
tracer.save("test.json")
        exp_source, buffer_size=params.replay_size)

    optimizer = optim.Adam(net.parameters(), lr=params.learning_rate)

    def process_batch(engine, batch):
        # One optimisation step on ``batch`` using common.calc_loss_dqn.
        # Afterwards, epsilon_tracker is updated with the current iteration,
        # and the target net is synced every params.target_net_sync
        # iterations. Returns the loss value and current epsilon for
        # logging.
        optimizer.zero_grad()
        loss = common.calc_loss_dqn(
            batch, net, tgt_net.target_model,
            gamma=params.gamma, device=device)
        loss.backward()
        optimizer.step()

        step = engine.state.iteration
        epsilon_tracker.frame(step)
        if step % params.target_net_sync == 0:
            tgt_net.sync()

        metrics = {
            "loss": loss.item(),
            "epsilon": selector.epsilon,
        }
        return metrics

    engine = Engine(process_batch)
    common.setup_ignite(engine, params, exp_source, NAME)
    tracer.stop()
    tracer.save()
    engine.run(
        common.batch_generator(buffer, params.replay_initial,
                               params.batch_size))