def test_repeated_module(self) -> None:
    """
    Tests that repeated calls to the same submodule correctly aggregate
    results to that submodule.
    """
    model = RepeatedNet()
    inputs = (torch.randn((1, *model.input_size)), )
    op_handles = {
        "aten::addmm": addmm_flop_jit,
    }  # type: Dict[str, Handle]

    analyzer = JitModelAnalysis(model=model, inputs=inputs, op_handles=op_handles)

    # Expected per-submodule totals. fc*_num presumably counts how many times
    # the submodule is invoked and fc*_flops is its per-call cost.
    fc1_count = model.fc1_num * model.fc1_flops
    fc2_count = model.fc2_num * model.fc2_flops
    total_count = fc1_count + fc2_count
    fc1_per_operator = Counter({"addmm": fc1_count})

    self.assertEqual(analyzer.total("fc1"), fc1_count)
    self.assertEqual(analyzer.total("fc2"), fc2_count)
    self.assertEqual(analyzer.total(""), total_count)
    self.assertEqual(analyzer.by_operator("fc1"), fc1_per_operator)

    # Tests no uncalled mods
    self.assertEqual(analyzer.uncalled_modules(), set())
def test_non_forward_func_call(self) -> None:
    """
    Tests calls made through a submodule's non-forward method.

    The intermediate module (``submod``) whose non-forward method is used
    should report a count of 0. The ops it triggers are still counted under
    the inner ``submod.fc`` and under the root module (``""``). Also tests
    that the intermediate module is reported by ``.uncalled_modules()``.
    """
    model = NonForwardNet()
    inputs = (torch.randn((1, 10)), )
    op_handles = {
        "aten::addmm": addmm_flop_jit,
    }  # type: Dict[str, Handle]

    analyzer = JitModelAnalysis(model=model, inputs=inputs, op_handles=op_handles)

    submod_count = 0
    inner_fc_count = model.submod.fc_flops
    # The root total includes both the outer fc and the inner submod.fc.
    total_count = model.fc_flops + inner_fc_count

    self.assertEqual(analyzer.total("submod"), submod_count)
    self.assertEqual(analyzer.total("submod.fc"), inner_fc_count)
    self.assertEqual(analyzer.total(""), total_count)

    # The mod not directly called is registered as such
    self.assertEqual(analyzer.uncalled_modules(), {"submod"})
def test_unused_module(self) -> None:
    """
    Tests that unused modules return a count of 0 for operator sums and an
    empty Counter() for per-operator results. Also tests that unused modules
    are reported by .uncalled_modules(), but that modules that simply have
    zero flops (like ReLU) are not.
    """
    model = UnusedNet()
    inputs = (torch.randn((1, *model.input_size)), )
    op_handles = {
        "aten::addmm": addmm_flop_jit,
    }  # type: Dict[str, Handle]

    analyzer = JitModelAnalysis(model=model, inputs=inputs, op_handles=op_handles)

    unused_count = 0
    unused_per_operator = Counter()
    # Only fc1 and fc2 contribute to the root total.
    model_count = model.fc1_flops + model.fc2_flops

    self.assertEqual(analyzer.total("unused"), unused_count)
    self.assertEqual(analyzer.by_operator("unused"), unused_per_operator)
    self.assertEqual(analyzer.total(""), model_count)

    # The unused mod is recognized as never called
    self.assertEqual(analyzer.uncalled_modules(), {"unused"})
def test_shared_module(self) -> None:
    """
    Tests how shared submodules, reachable under more than one name, are
    counted and how their names are canonicalized.
    """
    model = SharedModuleNet()
    inputs = (torch.randn((1, *model.input_size)), )
    op_handles = {
        "aten::addmm": addmm_flop_jit,
    }  # type: Dict[str, Handle]

    analyzer = JitModelAnalysis(model=model, inputs=inputs, op_handles=op_handles)
    analyzer.skipped_ops_warnings(enabled=False)

    # Only the first name a module is found under becomes its canonical name,
    # so `submod2.submod` and `multiname2` get no entries of their own. Their
    # counts are reported under `submod1.submod` and `multiname1` instead.
    per_shared = model.shared_flops
    multiname_total = 2 * model.multiname_flops  # called under 2 names
    shared_total = 2 * per_shared  # shared by 2 parent submodules
    expected = {
        "": multiname_total + shared_total,
        "submod1": per_shared,
        "submod1.submod": shared_total,
        "submod2": per_shared,
        "multiname1": multiname_total,
    }
    self.assertEqual(analyzer.by_module(), expected)

    # Looking up a module by an alias gives the canonical name's total.
    self.assertEqual(analyzer.total("submod2.submod"), expected["submod1.submod"])
    self.assertEqual(analyzer.total("multiname2"), expected["multiname1"])

    # Aliases and canonical names alike resolve to the canonical name.
    alias_table = (
        ("multiname2", "multiname1"),
        ("multiname1", "multiname1"),
        ("submod2.submod", "submod1.submod"),
        ("submod1.submod", "submod1.submod"),
    )
    for queried_name, canonical in alias_table:
        self.assertEqual(analyzer.canonical_module_name(queried_name), canonical)

    # No module should be reported as uncalled.
    self.assertEqual(analyzer.uncalled_modules(), set())