def test_by_operator(self) -> None:
    """
    Tests that ``by_operator(name)`` returns the correct per-operator
    counts when a submodule is addressed by its string name.

    ``model.flops`` maps each submodule name to the per-operator counts
    expected for it. Every entry is checked in its own subtest, so a
    mismatch for one submodule does not hide results for the others.
    """
    model = NestedNet(lin_op=self.lin_op)
    inputs = (torch.randn((1, *model.input_size)),)

    analyzer = FlopCountAnalysis(model=model, inputs=inputs)
    # Only the counts matter here; turn off unsupported-op warnings.
    analyzer.unsupported_ops_warnings(enabled=False)

    # Address each submodule by its string name.
    for name in model.flops:
        with self.subTest(name=name):
            self.assertEqual(analyzer.by_operator(name), model.flops[name])
def test_repeated_module(self) -> None:
    """
    Tests that repeated calls to the same submodule correctly aggregate
    their results to that submodule.

    Also checks that the root module (name ``""``) reports the sum over
    both submodules, and that no module is reported as uncalled.
    """
    model = RepeatedNet()
    inputs = (torch.randn((1, *model.input_size)),)

    analyzer = FlopCountAnalysis(model=model, inputs=inputs)

    # Expected totals: per-call flops scaled by the fc*_num factor
    # (presumably how many times each submodule is invoked in forward).
    fc1_count = model.fc1_num * model.fc1_flops
    fc2_count = model.fc2_num * model.fc2_flops
    total_count = fc1_count + fc2_count
    # All of fc1's flops are expected to be attributed to the linear op.
    fc1_per_operator = Counter({self.lin_op: fc1_count})

    self.assertEqual(analyzer.total("fc1"), fc1_count)
    self.assertEqual(analyzer.total("fc2"), fc2_count)
    self.assertEqual(analyzer.total(""), total_count)
    self.assertEqual(analyzer.by_operator("fc1"), fc1_per_operator)

    # Every submodule is called at least once, so none is reported uncalled.
    self.assertEqual(analyzer.uncalled_modules(), set())
def test_unused_module(self) -> None:
    """
    Tests that unused modules return a 0 count for operator sums and an
    empty Counter() for per-operator results.

    Also tests that unused modules are reported by .uncalled_modules(),
    but that modules that simply have zero flops (like ReLU) are not.
    """
    model = UnusedNet()
    inputs = (torch.randn((1, *model.input_size)),)

    analyzer = FlopCountAnalysis(model=model, inputs=inputs)

    unused_count = 0
    unused_per_operator = Counter()
    # Only fc1 and fc2 are expected to contribute to the whole-model total.
    model_count = model.fc1_flops + model.fc2_flops

    self.assertEqual(analyzer.total("unused"), unused_count)
    self.assertEqual(analyzer.by_operator("unused"), unused_per_operator)
    # The root module is addressed by the empty string.
    self.assertEqual(analyzer.total(""), model_count)

    # The unused mod is recognized as never called; zero-flop modules that
    # do run must not appear in this set.
    self.assertEqual(analyzer.uncalled_modules(), {"unused"})