def test_repeated_module(self) -> None:
    """
    Tests that repeated calls to the same submodule correctly aggregate
    results to that submodule.
    """
    model = RepeatedNet()
    inputs = (torch.randn((1, *model.input_size)),)
    op_handles = {
        "aten::addmm": addmm_flop_jit,
    }  # type: Dict[str, Handle]
    analyzer = JitModelAnalysis(model=model, inputs=inputs, op_handles=op_handles)

    # Expected per-submodule totals: per-call flops multiplied by
    # fc1_num / fc2_num (presumably the number of times each submodule is
    # invoked in RepeatedNet.forward — confirm against the model definition).
    fc1_count = model.fc1_num * model.fc1_flops
    fc2_count = model.fc2_num * model.fc2_flops
    total_count = fc1_count + fc2_count
    # by_operator keys on the bare op name, without the "aten::" prefix.
    fc1_per_operator = Counter({"addmm": fc1_count})

    self.assertEqual(analyzer.total("fc1"), fc1_count)
    self.assertEqual(analyzer.total("fc2"), fc2_count)
    # The "" module name is expected to report the whole-model total here.
    self.assertEqual(analyzer.total(""), total_count)
    self.assertEqual(analyzer.by_operator("fc1"), fc1_per_operator)

    # Tests no uncalled mods
    self.assertEqual(analyzer.uncalled_modules(), set())
def test_unused_module(self) -> None:
    """
    Tests that unused modules return 0 count for operator sums and an
    empty Counter() for per-operator results. Also tests that unused
    modules are reported by .uncalled_modules(), but that modules that
    simply have zero flops (like ReLU) are not.
    """
    model = UnusedNet()
    inputs = (torch.randn((1, *model.input_size)),)
    op_handles = {
        "aten::addmm": addmm_flop_jit,
    }  # type: Dict[str, Handle]
    analyzer = JitModelAnalysis(model=model, inputs=inputs, op_handles=op_handles)

    # A module that is never called contributes nothing.
    unused_count = 0
    unused_per_operator = Counter()
    # Only fc1 and fc2 contribute to the whole-model count.
    model_count = model.fc1_flops + model.fc2_flops

    self.assertEqual(analyzer.total("unused"), unused_count)
    self.assertEqual(analyzer.by_operator("unused"), unused_per_operator)
    self.assertEqual(analyzer.total(""), model_count)

    # The unused mod is recognized as never called; requiring the result to
    # be exactly {"unused"} also checks that zero-flop modules that do run
    # are not reported.
    self.assertEqual(analyzer.uncalled_modules(), {"unused"})