def test_rpc_torchscript_record_function(self):
    # Tests that torchscript functions can be profiled using with
    # record_function(...) over RPC.
    REMOTE_OP_STR = "#remote_op: "
    if self.rank == 0:
        dst_rank = (self.rank + 1) % self.world_size
        dst_worker_name = worker_name(dst_rank)
        block_scope = "foo"
        with _profile() as prof:
            call_rpc_torchscript_with_record_function(dst_worker_name, block_scope)
        # Need to call below to populate CPU children.
        prof.key_averages()
        function_events = prof.function_events
        expected_key = (
            _build_rpc_profiling_key(
                RPCExecMode.ASYNC_JIT,
                torch._jit_internal._qualified_name(
                    script_add_ones_with_record_function
                ),
                worker_name(self.rank),
                dst_worker_name,
            )
            + REMOTE_OP_STR
            + block_scope
        )
        remote_record_function_event = [
            evt for evt in function_events if evt.name == expected_key
        ][0]
        self.assertTrue(block_scope in remote_record_function_event.name)
        remote_children = remote_record_function_event.cpu_children
        # BUG FIX: the original passed a bare generator expression to
        # assertTrue, which is always truthy and so never actually checked
        # the children. Wrap it in any() so the assertion is meaningful.
        self.assertTrue(
            any("aten::add" in child.name for child in remote_children)
        )
def test_record_function_on_caller_rpc_async(self):
    # Profiles a JIT function that issues two rpc_async calls inside a
    # record_function scope, then validates that both the scope event and
    # the two RPC future events show up in the trace.
    if self.rank != 0:
        return
    dst_rank = (self.rank + 1) % self.world_size
    dst_worker_name = worker_name(dst_rank)
    block_scope = "foo"
    with _profile() as prof:
        # Runs 2 rpc_async calls within JIT under record_function.
        record_function_on_caller_rpc_async(dst_worker_name, block_scope)
    function_events = prof.function_events
    # Exactly one record_function scope event should be profiled.
    scope_events = [evt for evt in function_events if evt.name == block_scope]
    self.assertEqual(1, len(scope_events))
    scope_event = scope_events[0]
    # Both RPC futures should be profiled under the expected key.
    expected_key = _build_rpc_profiling_key(
        RPCExecMode.ASYNC_JIT,
        torch._jit_internal._qualified_name(script_add_ones),
        worker_name(self.rank),
        dst_worker_name,
    )
    jit_rpc_events = [evt for evt in function_events if evt.name == expected_key]
    self.assertEqual(2, len(jit_rpc_events))
    # The record_function scope time must exceed each individual RPC async
    # call time. It is not necessarily greater than their sum, because the
    # two RPCs can execute in parallel.
    for jit_rpc_event in jit_rpc_events:
        self.assertTrue(
            scope_event.cpu_time_total > jit_rpc_event.cpu_time_total
        )
def test_call_fork_in_jit_with_profiling(self):
    # Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut
    # on a jit future from within a script function using torch.jit.fork.
    with _profile() as prof:
        with torch.autograd.profiler.record_function("foo") as rf:
            call_fork_with_profiling(rf.handle)
    # The "foo" record_function scope must appear in the collected events.
    foo_event = get_function_event(prof.function_events, "foo")
    self.assertEqual("foo", foo_event.name)
def test_record_function_jit_end_callbacks_with_fork(self):
    # Ensures that rf._call_end_callbacks_on_future works on a jit future
    # created in python eager mode via torch.jit._fork: the "foo" scope
    # should remain open until the forked future completes.
    sleep_interval = 1
    with _profile() as prof:
        with torch.autograd.profiler.record_function("foo") as rf:
            fut = torch.jit._fork(sleep, sleep_interval)
            rf._call_end_callbacks_on_future(fut)
        fut.wait()
    sleep_event = get_function_event(prof.function_events, "foo")
    self.assertEqual("foo", sleep_event.name)
    # Validate the callbacks fired at the right time: the profiled CPU time
    # (converted from microseconds) must cover at least the sleep interval.
    self.assertGreaterAlmostEqual(sleep_event.cpu_time * 1e-6, sleep_interval)
def test_rpc_async_jit_profiled(self):
    # Tests that rpc_async calls made from within a TorchScript function
    # are profiled, locally and on the remote end.
    if self.rank != 0:
        return
    dst_rank = (self.rank + 1) % self.world_size
    dst_worker_name = worker_name(dst_rank)
    args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
    kwargs = {}
    with _profile() as prof:
        script_rpc_async_call(dst_worker_name, args, kwargs)
    function_events = prof.function_events
    # Exactly one local event should correspond to the rpc_async call.
    qual_name = torch._jit_internal._qualified_name(two_args_two_kwargs)
    matching = [
        evt
        for evt in function_events
        if qual_name in evt.name and evt.node_id == self.rank
    ]
    self.assertEqual(1, len(matching))
    rpc_async_jit_event = matching[0]
    profiled_name = _build_rpc_profiling_key(
        RPCExecMode.ASYNC_JIT,
        qual_name,
        worker_name(self.rank),
        dst_worker_name,
    )
    self.assertEqual(profiled_name, rpc_async_jit_event.name)
    # All remote events should have taken place on dst_rank.
    remote_events = [evt for evt in function_events if evt.is_remote]
    remote_event_node_ids = {evt.node_id for evt in remote_events}
    self.assertEqual({dst_rank}, remote_event_node_ids)
    # script_rpc_async_call invokes the add operator, so it should show up
    # as a remote event with the expected profiling key prefix.
    remote_add = next(
        evt for evt in remote_events if "aten::add" in evt.name
    )
    remote_add_profiled_name = f"{profiled_name}#remote_op: aten::add"
    self.assertEqual(remote_add_profiled_name, remote_add.name)
def test_call_rpc_with_profiling(self):
    # Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut
    # on a jit future from within a script function that calls rpc_async.
    if self.rank != 0:
        return
    with _profile() as prof:
        prof_key = _build_rpc_profiling_key(
            RPCExecMode.ASYNC,
            torch._jit_internal._qualified_name(one_arg),
            "worker0",
            "worker1",
        )
        with torch.autograd.profiler.record_function(prof_key) as rf:
            call_rpc_with_profiling(rf.record, "worker1")
    # TODO: Can't get a reliable time for this profiling event since
    # it's hard to estimate the execution time on the remote end for non-UDFs.
    # This can be resolved by https://github.com/pytorch/pytorch/issues/36272.
    # After that, this test should be modified to validate the function time.
    function_event = get_function_event(prof.function_events, prof_key)
    self.assertTrue(
        torch._jit_internal._qualified_name(one_arg) in function_event.name
    )