Example #1
    def test_rpc_torchscript_record_function(self):
        # Tests that TorchScript functions can be profiled with
        # record_function(...) over RPC.
        REMOTE_OP_STR = "#remote_op: "
        if self.rank == 0:
            dst_rank = (self.rank + 1) % self.world_size
            dst_worker_name = worker_name(dst_rank)
            block_scope = "foo"
            with torch.autograd.profiler.profile() as prof:
                call_rpc_torchscript_with_record_function(
                    dst_worker_name, block_scope)

            # Need to call key_averages() below to populate CPU children.
            prof.key_averages()
            function_events = prof.function_events
            expected_key = (_build_rpc_profiling_key(
                RPCExecMode.ASYNC_JIT,
                torch._jit_internal._qualified_name(
                    script_add_ones_with_record_function),
                worker_name(self.rank),
                dst_worker_name,
            ) + REMOTE_OP_STR + block_scope)
            remote_record_function_event = [
                evt for evt in function_events if evt.name == expected_key
            ][0]
            self.assertTrue(block_scope in remote_record_function_event.name)
            remote_children = remote_record_function_event.cpu_children
            self.assertTrue("aten::add" in child.name
                            for child in remote_children)
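The helpers this example relies on are defined elsewhere in rpc_test.py. A minimal sketch of what they plausibly look like (the names come from the test itself; the exact signatures, tensor values, and imports are assumptions):

import torch
import torch.distributed.rpc as rpc
from torch.autograd.profiler import record_function

@torch.jit.script
def script_add_ones_with_record_function(x, block: str):
    # Everything inside this scope is recorded under `block`, so the
    # caller sees an "aten::add" CPU child under the remote event.
    with record_function(block):
        return torch.add(x, torch.ones(1))

@torch.jit.script
def call_rpc_torchscript_with_record_function(dst_worker_name: str, block: str):
    # Issue the RPC from TorchScript and wait on it, so the caller-side
    # profile covers the full remote execution.
    fut = rpc.rpc_async(
        dst_worker_name,
        script_add_ones_with_record_function,
        (torch.tensor(1), block),
    )
    return fut.wait()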
Example #2
    def test_record_function_on_caller_rpc_async(self):
        if self.rank == 0:
            dst_rank = (self.rank + 1) % self.world_size
            dst_worker_name = worker_name(dst_rank)
            block_scope = "foo"
            with torch.autograd.profiler.profile() as prof:
                # Runs 2 rpc_async calls within JIT under record_function.
                record_function_on_caller_rpc_async(dst_worker_name,
                                                    block_scope)

            # Ensure record_function event is profiled.
            function_events = prof.function_events
            record_function_scope_event = [
                event for event in function_events if event.name == block_scope
            ]
            self.assertEqual(1, len(record_function_scope_event))
            record_function_scope_event = record_function_scope_event[0]
            # Ensure RPC future is profiled.
            expected_key = _build_rpc_profiling_key(
                RPCExecMode.ASYNC_JIT,
                torch._jit_internal._qualified_name(script_add_ones),
                worker_name(self.rank),
                dst_worker_name,
            )
            jit_rpc_events = [
                event for event in function_events
                if event.name == expected_key
            ]
            self.assertEqual(2, len(jit_rpc_events))
            # Validate that the record_function scope time is greater than both
            # of the individual RPC async call times. It is not necessarily
            # greater than their sum, because the two calls can execute in parallel.
            for jit_rpc_event in jit_rpc_events:
                self.assertTrue(record_function_scope_event.cpu_time_total >
                                jit_rpc_event.cpu_time_total)
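A minimal sketch of the scripted helper this test calls, assuming the same rpc/record_function imports as above. The two-RPC structure follows from the test's assertion that exactly two jit_rpc_events are profiled; the details are assumptions:

@torch.jit.script
def script_add_ones(x):
    return torch.add(x, torch.ones(1))

@torch.jit.script
def record_function_on_caller_rpc_async(dst_worker_name: str, block: str):
    t = torch.ones(1)
    with record_function(block):
        # Two rpc_async calls launched inside one record_function scope;
        # they may overlap, which is why the scope's CPU time must exceed
        # each call individually but not necessarily their sum.
        fut1 = rpc.rpc_async(dst_worker_name, script_add_ones, (t,))
        fut2 = rpc.rpc_async(dst_worker_name, script_add_ones, (t,))
        res = fut1.wait() + fut2.wait()
    return res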
Example #3
    def test_rpc_async_jit_profiled(self):
        # Tests that rpc_async calls made from within a TorchScript function are
        # profiled.
        if self.rank == 0:
            dst_rank = (self.rank + 1) % self.world_size
            dst_worker_name = worker_name(dst_rank)
            args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
            kwargs = {}
            with torch.autograd.profiler.profile() as prof:
                script_rpc_async_call(dst_worker_name, args, kwargs)

            # Ensure rpc_async call is profiled
            function_events = prof.function_events
            qual_name = torch._jit_internal._qualified_name(
                two_args_two_kwargs)
            rpc_async_jit_event = [
                event for event in function_events
                if qual_name in event.name and event.node_id == self.rank
            ]
            self.assertEqual(len(rpc_async_jit_event), 1)
            rpc_async_jit_event = rpc_async_jit_event[0]
            profiled_name = _build_rpc_profiling_key(
                RPCExecMode.ASYNC_JIT,
                qual_name,
                worker_name(self.rank),
                dst_worker_name,
            )
            self.assertEqual(profiled_name, rpc_async_jit_event.name)
            remote_events = [
                event for event in function_events if event.is_remote
            ]
            # All remote events should have taken place on dst_rank
            remote_event_node_ids = {
                remote_event.node_id
                for remote_event in remote_events
            }
            self.assertEqual(remote_event_node_ids, {dst_rank})
            # script_rpc_async_call invokes add operator
            # so we should see this as a remote event.
            remote_add = [
                remote_event for remote_event in remote_events
                if "aten::add" in remote_event.name
            ][0]
            remote_add_profiled_name = f"{profiled_name}#remote_op: aten::add"
            self.assertEqual(remote_add.name, remote_add_profiled_name)
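A sketch of the scripted call path this test assumes. The argument names and default kwarg values are guesses; the test only requires that the function accepts two args plus kwargs and runs add ops remotely:

from typing import Dict, Tuple

@torch.jit.script
def two_args_two_kwargs(
    first_arg,
    second_arg,
    first_kwarg=torch.tensor([3, 3]),
    second_kwarg=torch.tensor([4, 4]),
):
    # Each + lowers to aten::add, which the caller expects to see as a
    # remote event attributed to dst_rank.
    return first_arg + second_arg + first_kwarg + second_kwarg

@torch.jit.script
def script_rpc_async_call(
    dst_worker_name: str,
    args: Tuple[torch.Tensor, torch.Tensor],
    kwargs: Dict[str, torch.Tensor],
):
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
    return fut.wait()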
Example #4
File: rpc_test.py  Project: naoyam/pytorch
    def test_call_rpc_with_profiling(self):
        # Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut
        # on a jit future from within a script function that calls rpc_async.
        if self.rank == 0:
            with torch.autograd.profiler.profile() as prof:
                prof_key = _build_rpc_profiling_key(
                    RPCExecMode.ASYNC,
                    torch.jit._qualified_name(one_arg),
                    "worker0",
                    "worker1",
                )
                with torch.autograd.profiler.record_function(prof_key) as rf:
                    ret = call_rpc_with_profiling(rf.handle, "worker1")
            # TODO: Can't get a reliable time for this profiling event since
            # it's hard to estimate the execution time on the remote end for non-UDFs.
            # This can be resolved by https://github.com/pytorch/pytorch/issues/36272.
            # After that, this test should be modified to validate the function time.
            events = prof.function_events
            function_event = get_function_event(events, prof_key)
            self.assertTrue(
                torch.jit._qualified_name(one_arg) in function_event.name)
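A sketch of the helper and its attached profiling callback, based on the test's own comment naming torch.ops.profiler._call_end_callbacks_on_jit_fut and on rf.handle being passed in; the handle type and one_arg's body are assumptions:

@torch.jit.script
def one_arg(value):
    return value + 1

@torch.jit.script
def call_rpc_with_profiling(handle: torch.Tensor, dst_worker_name: str):
    # `handle` is the record_function handle (rf.handle) from the caller.
    # Registering end callbacks on the jit future means the profiling
    # scope is closed only once the RPC actually completes.
    fut = rpc.rpc_async(dst_worker_name, one_arg, (torch.tensor(1),))
    torch.ops.profiler._call_end_callbacks_on_jit_fut(handle, fut)
    return fut.wait()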