Example #1
    def test_call_fork_in_jit_with_profiling(self):
        # Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut on a jit
        # future from within a script function with torch.jit.fork
        with torch.autograd.profiler.profile() as prof:
            with torch.autograd.profiler.record_function("foo") as rf:
                ret = call_fork_with_profiling(rf.handle)

        events = prof.function_events
        function_event = get_function_event(events, "foo")
        self.assertEqual(function_event.name, "foo")
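The call_fork_with_profiling helper is defined elsewhere in PyTorch's test suite and is not reproduced on this page. As a rough, hypothetical sketch of the pattern the comment describes (a fork inside a script function followed by torch.ops.profiler._call_end_callbacks_on_jit_fut), something like the following could be used; the type of the handle passed to the op has changed across PyTorch versions, so treat this as an illustration only, not the suite's actual helper:

# Hypothetical sketch in the spirit of call_fork_with_profiling; the test
# suite's own definition is not shown on this page. Assumes the
# record_function handle is a Tensor, which was true in older PyTorch
# releases; newer releases pass a custom profiler record class instead.
import torch
from torch import Tensor

@torch.jit.script
def one_arg_stub(value: Tensor) -> Tensor:
    # Stand-in for the forked workload.
    return value + 1

@torch.jit.script
def fork_with_profiling_sketch(handle: Tensor) -> Tensor:
    # Fork a scripted function, then attach the profiler's end callbacks to
    # the resulting jit future so the enclosing record_function range stays
    # open until the forked work finishes.
    fut = torch.jit._fork(one_arg_stub, torch.tensor(1))
    torch.ops.profiler._call_end_callbacks_on_jit_fut(handle, fut)
    return fut.wait()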
Example #2
    def test_record_function_jit_end_callbacks_with_fork(self):
        # Ensures that we can call rf._call_end_callbacks_on_future on a jit
        # future in python eager mode with torch.jit.fork
        sleep_interval = 1
        with torch.autograd.profiler.profile() as prof:
            with torch.autograd.profiler.record_function("foo") as rf:
                fut = torch.jit._fork(sleep, sleep_interval)
                rf._call_end_callbacks_on_future(fut)
            fut.wait()

        function_events = prof.function_events
        sleep_event = get_function_event(function_events, "foo")
        self.assertEqual(sleep_event.name, "foo")
        # Validate that callbacks were fired at the right time by checking the
        # profiling event cpu time
        self.assertGreaterEqual(sleep_event.cpu_time * 1e-6, sleep_interval)
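The sleep helper and get_function_event used above are also defined elsewhere in the test suite. A self-contained sketch of the same eager-mode pattern, with a stand-in sleep function (an assumption, not the suite's helper), would look roughly like this:

# Self-contained sketch of the eager-mode pattern above. The sleep helper
# is a stand-in; the test suite defines its own version.
import time
import torch

def sleep(seconds: int) -> int:
    time.sleep(seconds)
    return seconds

sleep_interval = 1
with torch.autograd.profiler.profile() as prof:
    with torch.autograd.profiler.record_function("foo") as rf:
        fut = torch.jit._fork(sleep, sleep_interval)
        # Keep the "foo" range open until the forked sleep finishes.
        rf._call_end_callbacks_on_future(fut)
    fut.wait()

# cpu_time is reported in microseconds, so convert before comparing.
foo_event = [evt for evt in prof.function_events if evt.name == "foo"][0]
assert foo_event.cpu_time * 1e-6 >= sleep_interval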
Example #3
    def test_call_rpc_with_profiling(self):
        # Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut on a jit
        # future from within a script function that calls rpc_async
        if self.rank == 0:
            with torch.autograd.profiler.profile() as prof:
                prof_key = _build_rpc_profiling_key(
                    RPCExecMode.ASYNC,
                    torch.jit._qualified_name(one_arg),
                    "worker0",
                    "worker1",
                )
                with torch.autograd.profiler.record_function(prof_key) as rf:
                    ret = call_rpc_with_profiling(rf.handle, "worker1")
            # TODO: Can't get a reliable time for this profiling event since
            # it's hard to estimate the execution time on the remote end for non-UDFs.
            # This can be resolved by https://github.com/pytorch/pytorch/issues/36272.
            # After that, this test should be modified to validate the function time.
            events = prof.function_events
            function_event = get_function_event(events, prof_key)
            self.assertTrue(torch.jit._qualified_name(one_arg) in function_event.name)
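This test relies on PyTorch's internal multi-worker RPC test fixtures (call_rpc_with_profiling, one_arg, _build_rpc_profiling_key, RPCExecMode), none of which are shown here. For orientation only, a hedged eager-mode sketch of the underlying idea, attaching record_function end callbacks to the future returned by rpc_async, might look like this, assuming init_rpc has already been called on the participating workers:

# Hedged sketch only: applies the end-callback pattern to an RPC future in
# eager mode. Assumes torch.distributed.rpc.init_rpc() has been run on
# "worker0"/"worker1"; the worker names and the torch.add payload are
# illustrative, not taken from the test above.
import torch
import torch.distributed.rpc as rpc

def profile_rpc_call():
    with torch.autograd.profiler.profile() as prof:
        with torch.autograd.profiler.record_function("rpc_foo") as rf:
            fut = rpc.rpc_async(
                "worker1", torch.add, args=(torch.ones(2), torch.ones(2))
            )
            # Keep the "rpc_foo" range open until the remote call completes.
            rf._call_end_callbacks_on_future(fut)
        fut.wait()
    return prof.function_events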