def test_data_parallel(self) -> None:
    """
    Tests that a model wrapped in DataParallel still returns results
    labeled by the correct scopes.

    Every scope of the wrapped model is expected to be reported under a
    ``module`` prefix (its root scope ``""`` becomes ``"module"``, and a
    child scope ``"x"`` becomes ``"module.x"``). The wrapper's own root
    scope ``""`` is expected to carry the same counts as the wrapped
    model's root.
    """
    model = NestedNet()
    inputs = (torch.randn((1, *model.input_size)),)

    # Expected per-operator counts, re-keyed to the scope names seen
    # through the DataParallel wrapper.
    flops = {
        "module" + ("." if name else "") + name: flop
        for name, flop in model.flops.items()
    }
    # The wrapper's root scope covers the whole wrapped model.
    flops[""] = model.flops[""]

    model = torch.nn.DataParallel(model)
    analyzer = FlopCountAnalysis(model=model, inputs=inputs)
    analyzer.skipped_ops_warnings(enabled=False)

    # Using a string input: the total for each scope is the sum of its
    # per-operator counts.
    for name, counts in flops.items():
        with self.subTest(name=name):
            self.assertEqual(analyzer.total(name), sum(counts.values()))

    # Output as dictionary
    self.assertEqual(analyzer.by_module_and_operator(), flops)

    # Test no uncalled modules
    self.assertEqual(analyzer.uncalled_modules(), set())
def test_by_module_and_operator(self) -> None:
    """
    Checks that JitModelAnalysis.by_module_and_operator() yields exactly the
    nested {module name: {operator: count}} mapping that NestedNet declares
    in its ``flops`` attribute.
    """
    net = NestedNet()
    sample = torch.randn((1, *net.input_size))
    counter = FlopCountAnalysis(model=net, inputs=(sample,))
    counter.skipped_ops_warnings(enabled=False)

    expected = net.flops
    self.assertEqual(counter.by_module_and_operator(), expected)
def test_by_operator(self) -> None:
    """
    Tests that JitModelAnalysis.by_operator(module_name) returns the correct
    per-operator counts when each module scope is addressed by its string
    name.
    """
    model = NestedNet()
    inputs = (torch.randn((1, *model.input_size)),)
    analyzer = FlopCountAnalysis(model=model, inputs=inputs)
    analyzer.skipped_ops_warnings(enabled=False)

    # Query every module scope NestedNet declares, using its string name.
    for name, expected in model.flops.items():
        with self.subTest(name=name):
            self.assertEqual(analyzer.by_operator(name), expected)
def test_by_module(self) -> None:
    """
    Checks that JitModelAnalysis.by_module() maps each module name to the
    total count summed over all of that module's operators.
    """
    net = NestedNet()
    sample = torch.randn((1, *net.input_size))
    counter = FlopCountAnalysis(model=net, inputs=(sample,))
    counter.skipped_ops_warnings(enabled=False)

    expected_totals = {}
    for module_name, per_op in net.flops.items():
        expected_totals[module_name] = sum(per_op.values())
    self.assertEqual(counter.by_module(), expected_totals)