def test_call_fork_in_jit_with_profiling(self):
    # Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut on a jit
    # future from within a script function with torch.jit.fork
    with torch.autograd.profiler.profile() as prof:
        with torch.autograd.profiler.record_function("foo") as rf:
            # The result of the call is never inspected by this test, so it
            # is not bound to a local; only the recorded profiling event is checked.
            call_fork_with_profiling(rf.handle)
    events = prof.function_events
    # Look up the event recorded under the "foo" key and check its name.
    function_event = get_function_event(events, "foo")
    self.assertEqual(function_event.name, "foo")
def test_record_function_jit_end_callbacks_with_fork(self):
    # Checks that, in python eager mode, rf._call_end_callbacks_on_future can
    # be invoked on a jit future obtained from torch.jit._fork.
    sleep_interval = 1
    profiler = torch.autograd.profiler
    with profiler.profile() as prof:
        with profiler.record_function("foo") as rf:
            forked = torch.jit._fork(sleep, sleep_interval)
            rf._call_end_callbacks_on_future(forked)
        # Block on the forked work before the profiler context exits.
        forked.wait()
    foo_event = get_function_event(prof.function_events, "foo")
    self.assertEqual(foo_event.name, "foo")
    # Callback timing check: the event's cpu_time, scaled by 1e-6
    # (presumably microseconds -> seconds; confirm against the profiler
    # docs), must be at least as long as the sleep itself.
    self.assertGreaterEqual(foo_event.cpu_time * 1e-6, sleep_interval)
def test_call_rpc_with_profiling(self):
    # Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut on a jit
    # future from within a script function that calls rpc_async
    if self.rank != 0:
        # Only rank 0 runs the profiled call and the assertions below.
        return
    # Computed once; used both to build the profiling key and in the final check.
    target_name = torch.jit._qualified_name(one_arg)
    with torch.autograd.profiler.profile() as prof:
        prof_key = _build_rpc_profiling_key(
            RPCExecMode.ASYNC,
            target_name,
            "worker0",
            "worker1",
        )
        with torch.autograd.profiler.record_function(prof_key) as rf:
            # The return value is not inspected; only the recorded event matters.
            call_rpc_with_profiling(rf.handle, "worker1")
    # TODO: Can't get a reliable time for this profiling event since
    # it's hard to estimate the execution time on the remote end for non-UDFs.
    # This can be resolved by https://github.com/pytorch/pytorch/issues/36272.
    # After that, this test should be modified to validate the function time.
    events = prof.function_events
    function_event = get_function_event(events, prof_key)
    # assertIn reports both operands on failure, unlike assertTrue(a in b).
    self.assertIn(target_name, function_event.name)