def test_data_parallel_root_scope(self) -> None:
    """
    Tests that a DataParallel-wrapped linear layer reports a total of
    1000 flops under both the "caller" and the "owner" ancestor modes.
    """
    # A test case discussed in D32227000
    wrapped = nn.DataParallel(nn.Linear(10, 10))
    for ancestor_mode in ("caller", "owner"):
        sample = (torch.randn(10, 10),)
        analysis = FlopCountAnalysis(wrapped, sample)
        analysis.ancestor_mode(ancestor_mode)
        self.assertEqual(analysis.total(), 1000)
def test_flop_counter_class(self) -> None:
    """
    Tests FlopCountAnalysis.by_module() on a ThreeNet against hand-computed
    per-module counts.
    """
    batch_size, input_dim, conv_dim = 4, 2, 5
    spatial_dim, linear_dim = 10, 3

    x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim)
    net = ThreeNet(input_dim, conv_dim, linear_dim)

    # Expected counts for each named submodule.
    conv_flops = batch_size * conv_dim * input_dim * spatial_dim * spatial_dim
    linear1_flops = batch_size * conv_dim * linear_dim
    linear2_flops = batch_size * linear_dim * 1

    analysis = FlopCountAnalysis(net, (x,))
    expected = Counter(
        {
            # The root scope "" holds the sum over all submodules.
            "": conv_flops + (linear1_flops + linear2_flops),
            "conv": conv_flops,
            "linear1": linear1_flops,
            "linear2": linear2_flops,
            "pool": 0,
        }
    )
    self.assertEqual(analysis.by_module(), expected)
def test_autograd_function(self) -> None:
    """
    Tests support for a custom autograd function: a handle registered for
    "prim::PythonOp._CustomOp" determines the reported total.
    """

    class Mod(nn.Module):
        def forward(self, x):
            return _CustomOp.apply(x)

    def constant_handle(*args, **kwargs):
        # Ignores its arguments and always reports 42.
        return 42

    sample = (torch.rand(4, 5),)
    analysis = FlopCountAnalysis(Mod(), sample).set_op_handle(
        "prim::PythonOp._CustomOp", constant_handle
    )
    self.assertEqual(analysis.total(), 42)
def test_by_module_and_operator(self) -> None:
    """
    Tests that by_module_and_operator() on a NestedNet equals the
    reference structure stored on the model as `flops`.
    """
    net = NestedNet(lin_op=self.lin_op)
    sample = (torch.randn((1, *net.input_size)),)

    analysis = FlopCountAnalysis(model=net, inputs=sample)
    analysis.unsupported_ops_warnings(enabled=False)

    observed = analysis.by_module_and_operator()
    self.assertEqual(observed, net.flops)
def test_shared_module(self) -> None:
    """
    Tests how shared submodules, reachable under more than one name,
    are reported.
    """
    net = SharedModuleNet()
    sample = (torch.randn((1, *net.input_size)),)
    analysis = FlopCountAnalysis(model=net, inputs=sample)
    analysis = analysis.unsupported_ops_warnings(enabled=False)

    # Only the first name of a module becomes its canonical name, so
    # `submod2.submod` and `multiname2` do not appear as keys. Their
    # counts are reported under `submod1.submod` and `multiname1`.
    multiname_total = 2 * net.multiname_flops  # Called under 2 names
    shared_total = 2 * net.shared_flops  # Shared under 2 submodules
    expected = {
        "": multiname_total + shared_total,
        "submod1": net.shared_flops,
        "submod1.submod": shared_total,
        "submod2": net.shared_flops,
        "multiname1": multiname_total,
    }
    self.assertEqual(analysis.by_module(), expected)

    # Totals are also reachable through the alternative names.
    for alias, canonical in (
        ("submod2.submod", "submod1.submod"),
        ("multiname2", "multiname1"),
    ):
        self.assertEqual(analysis.total(alias), expected[canonical])

    # Every name, alias or not, resolves to the canonical one.
    for queried, canonical in (
        ("multiname2", "multiname1"),
        ("multiname1", "multiname1"),
        ("submod2.submod", "submod1.submod"),
        ("submod1.submod", "submod1.submod"),
    ):
        self.assertEqual(analysis.canonical_module_name(queried), canonical)

    # Tests no uncalled modules
    self.assertEqual(analysis.uncalled_modules(), set())
def test_scripted_function(self) -> None:
    """
    Scripted functions are not yet supported: calling one inside forward
    is expected to be listed in unsupported_ops() as "prim::CallFunction".
    """

    def func(x):
        return x @ x

    class Mod(nn.Module):
        def forward(self, x):
            scripted = torch.jit.script(func)
            return scripted(x * x)

    analysis = FlopCountAnalysis(Mod(), (torch.rand(5, 5),))
    # The total itself is not checked; it is computed before the
    # unsupported ops are queried.
    analysis.total()
    self.assertIn("prim::CallFunction", analysis.unsupported_ops())
def test_recursive_scope(self) -> None:
    """
    Tests that an op counts only once towards a module, even when that
    module appears in the op's scope more than once.
    """
    net = RecursiveScopeNet()
    sample = (torch.randn((1, *net.input_size)),)
    analysis = FlopCountAnalysis(net, sample)

    # The root total (no name given) and the "fc" total must both equal
    # the reference count.
    self.assertEqual(analysis.total(), net.flops)
    self.assertEqual(analysis.total("fc"), net.flops)

    # Tests no uncalled modules
    self.assertEqual(analysis.uncalled_modules(), set())
def test_by_module(self) -> None:
    """
    Tests that by_module() returns, for every module name, the sum of
    that module's per-operator reference counts.
    """
    net = NestedNet(lin_op=self.lin_op)
    sample = (torch.randn((1, *net.input_size)),)
    analysis = FlopCountAnalysis(model=net, inputs=sample)
    analysis.unsupported_ops_warnings(enabled=False)

    # Collapse the per-operator reference counts to one total per module.
    expected = {}
    for module_name, per_operator in net.flops.items():
        expected[module_name] = sum(per_operator.values())

    self.assertEqual(analysis.by_module(), expected)
def test_by_operator(self) -> None:
    """
    Tests that by_operator(name) returns the reference per-operator counts
    for every module name listed in the model's `flops`.
    """
    net = NestedNet(lin_op=self.lin_op)
    sample = (torch.randn((1, *net.input_size)),)
    analysis = FlopCountAnalysis(model=net, inputs=sample)
    analysis.unsupported_ops_warnings(enabled=False)

    # Using a string input
    for module_name, expected in net.flops.items():
        with self.subTest(name=module_name):
            self.assertEqual(analysis.by_operator(module_name), expected)
def test_total(self) -> None:
    """
    Tests that total(name) returns, for every module name listed in the
    model's `flops`, the sum of that module's per-operator reference
    counts. Only string names are exercised here.
    """
    model = NestedNet()
    inputs = (torch.randn((1, *model.input_size)),)

    analyzer = FlopCountAnalysis(model=model, inputs=inputs)
    # Use the same warning switch as the other tests in this file;
    # the previous `skipped_ops_warnings` spelling is not used anywhere
    # else here.
    analyzer.unsupported_ops_warnings(enabled=False)

    # Using a string input
    for name in model.flops:
        with self.subTest(name=name):
            gt_flops = sum(model.flops[name].values())
            self.assertEqual(analyzer.total(name), gt_flops)
def test_data_parallel(self) -> None:
    """
    Tests that a model wrapped in DataParallel still returns results
    labeled by the correct scopes.
    """
    model = NestedNet(lin_op=self.lin_op)
    inputs = (torch.randn((1, *model.input_size)),)

    # Find flops for wrapper: every reference scope is re-keyed under
    # the "module" prefix, and the new root scope "" reuses the counts of
    # the unwrapped model's root.
    flops = {
        "module" + ("." if name else "") + name: flop
        for name, flop in model.flops.items()
    }
    flops[""] = model.flops[""]

    model = torch.nn.DataParallel(model).cpu()
    analyzer = FlopCountAnalysis(model=model, inputs=inputs)
    analyzer.unsupported_ops_warnings(enabled=False)

    # Using a string input
    for name in flops:
        with self.subTest(name=name):
            gt_flops = sum(flops[name].values())
            self.assertEqual(analyzer.total(name), gt_flops)

    # Output as dictionary
    self.assertEqual(analyzer.by_module_and_operator(), flops)

    # Test no uncalled modules
    self.assertEqual(analyzer.uncalled_modules(), set())
def test_skip_uncalled_containers_warnings(self) -> None:
    """
    Tests that a container whose element is called, but which is never
    called itself, produces no "Module never called" warning.
    """

    # uncalled containers should not warn
    class A(nn.Module):
        def forward(self, x):
            return self.submod[0](x) + 1

    mod = A()
    mod.submod = nn.ModuleList([nn.Linear(3, 3)])  # pyre-ignore
    analysis = FlopCountAnalysis(model=mod, inputs=torch.rand(1, 3))
    analysis.unsupported_ops_warnings(enabled=False)

    root_logger = logging.getLogger()
    with self.assertLogs(root_logger, logging.WARN) as captured:
        # A dummy record is logged first so the capture is never empty.
        root_logger.warning("Dummy warning.")
        analysis.total()

    needle = "Module never called: submod"
    offending = [line for line in captured.output if needle in line]
    self.assertFalse(offending)
def test_non_forward_func_call(self) -> None:
    """
    Tests that counts are attributed correctly when a submodule is reached
    through a non-forward function, and that this intermediate submodule
    is reported as uncalled.
    """
    net = NonForwardNet()
    sample = (torch.randn((1, 10)),)
    analysis = FlopCountAnalysis(model=net, inputs=sample)

    inner_fc = net.submod.fc_flops
    # Expected totals per scope, checked in insertion order.
    expected_totals = {
        "submod": 0,
        "submod.fc": inner_fc,
        "": net.fc_flops + inner_fc,
    }
    for scope, expected in expected_totals.items():
        self.assertEqual(analysis.total(scope), expected)

    # The mod not directly called is registered as such
    self.assertEqual(analysis.uncalled_modules(), {"submod"})
def test_repeated_module(self) -> None:
    """
    Tests that repeated calls to the same submodule are correctly
    aggregated into that submodule's results.
    """
    net = RepeatedNet()
    sample = (torch.randn((1, *net.input_size)),)
    analysis = FlopCountAnalysis(model=net, inputs=sample)

    # Each submodule's expected total is its call count times its
    # per-call reference count.
    fc1_total = net.fc1_num * net.fc1_flops
    fc2_total = net.fc2_num * net.fc2_flops

    self.assertEqual(analysis.total("fc1"), fc1_total)
    self.assertEqual(analysis.total("fc2"), fc2_total)
    self.assertEqual(analysis.total(""), fc1_total + fc2_total)
    self.assertEqual(
        analysis.by_operator("fc1"),
        Counter({self.lin_op: fc1_total}),
    )

    # Tests no uncalled mods
    self.assertEqual(analysis.uncalled_modules(), set())
def test_unused_module(self) -> None:
    """
    Tests that an unused module yields a total of 0 and an empty Counter()
    for per-operator results. Also tests that unused modules are reported
    by .uncalled_modules(), while modules that simply have zero flops
    (like ReLU) are not.
    """
    net = UnusedNet()
    sample = (torch.randn((1, *net.input_size)),)
    analysis = FlopCountAnalysis(model=net, inputs=sample)

    self.assertEqual(analysis.total("unused"), 0)
    self.assertEqual(analysis.by_operator("unused"), Counter())
    self.assertEqual(analysis.total(""), net.fc1_flops + net.fc2_flops)

    # The unused mod is recognized as never called
    self.assertEqual(analysis.uncalled_modules(), {"unused"})
def test_disable_warnings(self) -> None:
    """
    Tests .tracer_warnings(...), .unsupported_ops_warnings(...) and
    .uncalled_modules_warnings(...)
    """
    net = TraceWarningNet()
    sample = (torch.randn((1, *net.input_size)),)
    analyzer = FlopCountAnalysis(model=net, inputs=sample)
    tracer_warning = torch.jit._trace.TracerWarning

    # Tracer warnings
    # mode="all": both warning categories are expected.
    analyzer.tracer_warnings(mode="all")
    analyzer._stats = None  # Manually clear cache so trace is rerun
    with self.assertWarns(tracer_warning):
        analyzer.total()
    analyzer._stats = None  # Manually clear cache so trace is rerun
    with self.assertWarns(RuntimeWarning):
        analyzer.total()

    # mode="none": neither category may be recorded.
    analyzer.tracer_warnings(mode="none")
    analyzer._stats = None  # Manually clear cache so trace is rerun
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        analyzer.total()
    seen_categories = [entry.category for entry in caught]
    self.assertFalse(tracer_warning in seen_categories)
    self.assertFalse(RuntimeWarning in seen_categories)

    # mode="no_tracer_warning": RuntimeWarning is still expected, but
    # TracerWarning may not be recorded.
    analyzer.tracer_warnings(mode="no_tracer_warning")
    analyzer._stats = None  # Manually clear cache so trace is rerun
    with self.assertWarns(RuntimeWarning):
        analyzer.total()
    analyzer._stats = None  # Manually clear cache so trace is rerun
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        analyzer.total()
    seen_categories = [entry.category for entry in caught]
    self.assertFalse(tracer_warning in seen_categories)

    # Unsupported ops and uncalled modules warnings
    root_logger = logging.getLogger()
    unsupported_msg = "Unsupported operator aten::add encountered 1 time(s)"
    uncalled_msg = "never called"
    uncalled_name = "fc1"  # fc2 is called by chance

    # Both switches off: neither message may be logged.
    analyzer.uncalled_modules_warnings(enabled=False)
    analyzer.unsupported_ops_warnings(enabled=False)
    analyzer._stats = None  # Manually clear cache so trace is rerun
    with self.assertLogs(root_logger, logging.WARN) as captured:
        # A dummy record is logged first so the capture is never empty.
        root_logger.warning("Dummy warning.")
        analyzer.total()
    self.assertFalse(any(unsupported_msg in line for line in captured.output))
    self.assertFalse(any(uncalled_msg in line for line in captured.output))

    # Both switches on: both messages, and the uncalled module's name,
    # must be logged.
    analyzer.unsupported_ops_warnings(enabled=True)
    analyzer.uncalled_modules_warnings(enabled=True)
    analyzer._stats = None  # Manually clear cache so trace is rerun
    with self.assertLogs(root_logger, logging.WARN) as captured:
        analyzer.total()
    self.assertTrue(any(unsupported_msg in line for line in captured.output))
    self.assertTrue(any(uncalled_msg in line for line in captured.output))
    self.assertTrue(any(uncalled_name in line for line in captured.output))
# Prints the total FLOPs that fvcore's FlopCountAnalysis reports for the
# network of a trained nnU-Net model, using one random input patch.
import sys
sys.path.append('..')  # must run before the project imports below
import argparse
from batchgenerators.utilities.file_and_folder_operations import *
from nnunet.training.model_restore import load_model_and_checkpoint_files
from fvcore.nn.flop_count import _DEFAULT_SUPPORTED_OPS, FlopCountAnalysis, flop_count
import numpy as np
import torch
import os

join = os.path.join

parser = argparse.ArgumentParser()
parser.add_argument(
    '-m',
    '--model',
    help="2d, 3d_lowres, 3d_fullres or 3d_cascade_fullres. Default: 3d_fullres",
    default="3d_fullres",
    required=False,
)
args = parser.parse_args()
model = args.model

model_path = join(
    './data/RESULTS_FOLDER/nnUNet',
    model,
    'Task000_FLARE21Baseline/nnUNetTrainerV2__nnUNetPlansv2.1',
)
trainer, params = load_model_and_checkpoint_files(
    model_path,
    folds='all',
    checkpoint_name='model_final_checkpoint',
)

pkl_file = join(model_path, "all/model_final_checkpoint.model.pkl")
info = load_pickle(pkl_file)

# '2d' and '3d_lowres' read the patch size from plan stage 0; every other
# configuration reads it from stage 1.
stage = 0 if model in ('2d', '3d_lowres') else 1
patch_size = np.append(np.array(1), info['plans']['plans_per_stage'][stage]['patch_size'])

# A second leading 1 is prepended, so the random input has two singleton
# dimensions ahead of the patch dimensions (presumably batch and channel;
# confirm against the network's expected input).
input_shape = tuple(np.append(np.array(1), patch_size))
inputs = (torch.randn(input_shape).cuda(),)

flops = FlopCountAnalysis(trainer.network, inputs)
print('Total FLOPs:', flops.total())