def test_data_parallel_root_scope(self) -> None:
    """
    Tests that a DataParallel-wrapped model reports its full flop count at
    the root scope under every ancestor mode.
    """
    # A test case discussed in D32227000
    model = nn.DataParallel(nn.Linear(10, 10))
    # nn.Linear(10, 10) on a (10, 10) input: 10 * 10 * 10 = 1000 flops.
    expected_flops = 1000
    for mode in ["caller", "owner"]:
        # subTest names the failing mode and still checks the remaining ones.
        with self.subTest(mode=mode):
            flop = FlopCountAnalysis(model, (torch.randn(10, 10), ))
            flop.ancestor_mode(mode)
            self.assertEqual(flop.total(), expected_flops)
Example #2
0
 def test_flop_counter_class(self) -> None:
     """
     Checks FlopCountAnalysis.by_module() on ThreeNet against flop counts
     computed by hand for each submodule.
     """
     batch_size, input_dim, conv_dim = 4, 2, 5
     spatial_dim, linear_dim = 10, 3
     x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim)
     net = ThreeNet(input_dim, conv_dim, linear_dim)

     # Hand-computed cost of each submodule.
     conv_flops = batch_size * conv_dim * input_dim * spatial_dim * spatial_dim
     linear1_flops = batch_size * conv_dim * linear_dim
     linear2_flops = batch_size * linear_dim * 1
     linear_total = linear1_flops + linear2_flops

     analysis = FlopCountAnalysis(net, (x,))
     expected = Counter({
         "": conv_flops + linear_total,
         "conv": conv_flops,
         "linear1": linear1_flops,
         "linear2": linear2_flops,
         "pool": 0,  # pooling contributes no counted flops
     })
     self.assertEqual(analysis.by_module(), expected)
Example #3
0
    def test_autograd_function(self):
        # A custom autograd function is counted through a user-registered
        # handle keyed by its "prim::PythonOp.<name>" operator name.

        class Mod(nn.Module):
            def forward(self, x):
                return _CustomOp.apply(x)

        def constant_handle(*args, **kwargs):
            # Ignore the op's inputs/outputs and report a fixed count.
            return 42

        flop = FlopCountAnalysis(Mod(), (torch.rand(4, 5), ))
        flop = flop.set_op_handle("prim::PythonOp._CustomOp", constant_handle)
        self.assertEqual(flop.total(), 42)
    def test_by_module_and_operator(self) -> None:
        """
        Verifies that JitModelAnalysis.by_module_and_operator() yields the
        nested {module name: {operator: count}} mapping that NestedNet
        exposes as its reference `flops` attribute.
        """
        net = NestedNet(lin_op=self.lin_op)
        sample = torch.randn((1, *net.input_size))

        analyzer = FlopCountAnalysis(model=net, inputs=(sample, ))
        analyzer.unsupported_ops_warnings(enabled=False)

        expected = net.flops
        self.assertEqual(analyzer.by_module_and_operator(), expected)
Example #5
0
    def test_shared_module(self) -> None:
        """
        Checks how shared submodules reachable under several names are
        reported: counts go to the first (canonical) name, and lookups by an
        alias resolve to that canonical entry.
        """
        net = SharedModuleNet()
        analyzer = FlopCountAnalysis(
            model=net,
            inputs=(torch.randn((1, *net.input_size)), ),
        ).unsupported_ops_warnings(enabled=False)

        # Only the first name a module is registered under becomes canonical,
        # so `submod2.submod` and `multiname2` get no entry of their own;
        # their counts are folded into `submod1.submod` and `multiname1`.
        per_shared = net.shared_flops
        multiname_total = 2 * net.multiname_flops  # invoked via two names
        shared_total = 2 * per_shared  # reused by two parent submodules
        expected = {
            "": multiname_total + shared_total,
            "submod1": per_shared,
            "submod1.submod": shared_total,
            "submod2": per_shared,
            "multiname1": multiname_total,
        }
        self.assertEqual(analyzer.by_module(), expected)

        # Alias lookups return the canonical module's totals.
        for alias, canonical in [
            ("submod2.submod", "submod1.submod"),
            ("multiname2", "multiname1"),
        ]:
            self.assertEqual(analyzer.total(alias), expected[canonical])

        # Canonical-name resolution for both aliases and canonical names.
        for name, canonical in [
            ("multiname2", "multiname1"),
            ("multiname1", "multiname1"),
            ("submod2.submod", "submod1.submod"),
            ("submod1.submod", "submod1.submod"),
        ]:
            self.assertEqual(analyzer.canonical_module_name(name), canonical)

        # No submodule is left uncalled.
        self.assertEqual(analyzer.uncalled_modules(), set())
Example #6
0
    def test_scripted_function(self):
        # Calls into a TorchScript-compiled function are not supported yet;
        # the analysis is expected to list the call as an unsupported op.

        def func(x):
            return x @ x

        class Mod(nn.Module):
            def forward(self, x):
                scripted = torch.jit.script(func)
                return scripted(x * x)

        analysis = FlopCountAnalysis(Mod(), (torch.rand(5, 5), ))
        analysis.total()  # run the analysis
        self.assertIn("prim::CallFunction", analysis.unsupported_ops())
    def test_recursive_scope(self) -> None:
        """
        Ensures an op is attributed to a module only once, even when that
        module appears several times in the op's scope.
        """
        net = RecursiveScopeNet()
        analyzer = FlopCountAnalysis(net, (torch.randn((1, *net.input_size)), ))

        expected = net.flops
        self.assertEqual(analyzer.total(), expected)
        self.assertEqual(analyzer.total("fc"), expected)

        # Every submodule was invoked.
        self.assertEqual(analyzer.uncalled_modules(), set())
Example #8
0
    def test_by_module(self) -> None:
        """
        Checks that JitModelAnalysis.by_module() maps every module name to
        the sum of that module's per-operator counts.
        """
        net = NestedNet(lin_op=self.lin_op)
        analyzer = FlopCountAnalysis(
            model=net, inputs=(torch.randn((1, *net.input_size)),)
        )
        analyzer.unsupported_ops_warnings(enabled=False)

        expected = {}
        for module_name, op_counts in net.flops.items():
            expected[module_name] = sum(op_counts.values())

        self.assertEqual(analyzer.by_module(), expected)
    def test_by_operator(self) -> None:
        """
        Tests that JitModelAnalysis.by_operator(module_name) returns the
        correct per-operator counts for every module of NestedNet, with each
        module addressed by its string name.
        """

        model = NestedNet(lin_op=self.lin_op)
        inputs = (torch.randn((1, *model.input_size)), )

        analyzer = FlopCountAnalysis(model=model, inputs=inputs)
        # Only the counts matter here; silence unsupported-operator warnings.
        analyzer.unsupported_ops_warnings(enabled=False)

        # Modules are addressed by string name only.
        for name, expected in model.flops.items():
            with self.subTest(name=name):
                self.assertEqual(analyzer.by_operator(name), expected)
    def test_total(self) -> None:
        """
        Tests that JitModelAnalysis.total(module_name) returns, for every
        module of NestedNet addressed by its string name, the sum of that
        module's per-operator counts.
        """

        model = NestedNet(lin_op=self.lin_op)
        inputs = (torch.randn((1, *model.input_size)), )

        analyzer = FlopCountAnalysis(model=model, inputs=inputs)
        # The rest of this suite silences unsupported-operator warnings via
        # `unsupported_ops_warnings`; use the same switch here.
        analyzer.unsupported_ops_warnings(enabled=False)

        for name, op_counts in model.flops.items():
            with self.subTest(name=name):
                self.assertEqual(analyzer.total(name), sum(op_counts.values()))
    def test_data_parallel(self) -> None:
        """
        Tests that a model wrapped in DataParallel still returns results
        labeled by the correct scopes.
        """
        model = NestedNet(lin_op=self.lin_op)
        inputs = (torch.randn((1, *model.input_size)), )

        # Expected counts for the wrapper: every scope of the inner model is
        # reported under the wrapper's "module" child ("" -> "module",
        # "a.b" -> "module.a.b"), and the wrapper's own root scope sees the
        # same counts as the inner model's root.
        flops = {
            "module" + ("." if name else "") + name: flop
            for name, flop in model.flops.items()
        }
        flops[""] = model.flops[""]

        model = torch.nn.DataParallel(model).cpu()
        analyzer = FlopCountAnalysis(model=model, inputs=inputs)
        analyzer.unsupported_ops_warnings(enabled=False)

        # Per-scope totals, addressed by string name.
        for name, op_counts in flops.items():
            with self.subTest(name=name):
                self.assertEqual(analyzer.total(name), sum(op_counts.values()))

        # Full per-module, per-operator breakdown.
        self.assertEqual(analyzer.by_module_and_operator(), flops)

        # Every module, including the wrapper, is called.
        self.assertEqual(analyzer.uncalled_modules(), set())
    def test_skip_uncalled_containers_warnings(self) -> None:
        # A container (ModuleList) that is only indexed, never called itself,
        # must not produce an "uncalled module" warning.

        class A(nn.Module):
            def forward(self, x):
                return self.submod[0](x) + 1

        mod = A()
        mod.submod = nn.ModuleList([nn.Linear(3, 3)])  # pyre-ignore
        analyzer = FlopCountAnalysis(model=mod, inputs=torch.rand(1, 3))
        analyzer.unsupported_ops_warnings(enabled=False)

        root_logger = logging.getLogger()
        with self.assertLogs(root_logger, logging.WARN) as captured:
            # assertLogs fails when nothing is logged, so emit one record.
            root_logger.warning("Dummy warning.")
            analyzer.total()
        uncalled_string = "Module never called: submod"
        hits = [line for line in captured.output if uncalled_string in line]
        self.assertFalse(hits)
Example #13
0
    def test_non_forward_func_call(self) -> None:
        """
        Tests that when a module calls a non-forward method of a submodule,
        the resulting counts are attributed to the calling module, and that
        the intermediate submodule is reported as never called.
        """
        model = NonForwardNet()
        analyzer = FlopCountAnalysis(model=model, inputs=(torch.randn((1, 10)),))

        inner_fc = model.submod.fc_flops
        expected_totals = {
            "submod": 0,
            "submod.fc": inner_fc,
            "": model.fc_flops + inner_fc,
        }
        for scope, count in expected_totals.items():
            self.assertEqual(analyzer.total(scope), count)

        # `submod` is only reached through a non-forward method.
        self.assertEqual(analyzer.uncalled_modules(), {"submod"})
    def test_repeated_module(self) -> None:
        """
        Tests that repeated calls to the same submodule aggregate their
        counts under that submodule.
        """
        model = RepeatedNet()
        analyzer = FlopCountAnalysis(
            model=model, inputs=(torch.randn((1, *model.input_size)), )
        )

        expected_fc1 = model.fc1_num * model.fc1_flops
        expected_fc2 = model.fc2_num * model.fc2_flops

        self.assertEqual(analyzer.total("fc1"), expected_fc1)
        self.assertEqual(analyzer.total("fc2"), expected_fc2)
        self.assertEqual(analyzer.total(""), expected_fc1 + expected_fc2)
        # All of fc1's cost is booked under the linear operator.
        self.assertEqual(
            analyzer.by_operator("fc1"), Counter({self.lin_op: expected_fc1})
        )

        # Every submodule was called.
        self.assertEqual(analyzer.uncalled_modules(), set())
    def test_unused_module(self) -> None:
        """
        Tests that unused modules return a 0 count for operator sums and an
        empty Counter() for per-operator results. Also tests that unused
        modules are reported by .uncalled_modules(), but that modules that
        simply have zero flops (like ReLU) are not.
        """
        model = UnusedNet()
        analyzer = FlopCountAnalysis(
            model=model, inputs=(torch.randn((1, *model.input_size)), )
        )

        self.assertEqual(analyzer.total("unused"), 0)
        self.assertEqual(analyzer.by_operator("unused"), Counter())
        self.assertEqual(analyzer.total(""), model.fc1_flops + model.fc2_flops)

        # Only the genuinely unused submodule is flagged as never called.
        self.assertEqual(analyzer.uncalled_modules(), {"unused"})
    def test_disable_warnings(self) -> None:
        """
        Exercises the warning switches: .tracer_warnings(mode=...),
        .unsupported_ops_warnings(...) and .uncalled_modules_warnings(...).
        """
        model = TraceWarningNet()
        analyzer = FlopCountAnalysis(
            model=model, inputs=(torch.randn((1, *model.input_size)), )
        )
        tracer_warning = torch.jit._trace.TracerWarning

        def clear_cache() -> None:
            # Dropping the cached stats forces the next query to re-trace.
            analyzer._stats = None

        # mode="all": tracer warnings and RuntimeWarnings both surface.
        analyzer.tracer_warnings(mode="all")
        clear_cache()
        self.assertWarns(tracer_warning, analyzer.total)
        clear_cache()
        self.assertWarns(RuntimeWarning, analyzer.total)

        # mode="none": neither category surfaces.
        analyzer.tracer_warnings(mode="none")
        clear_cache()
        with warnings.catch_warnings(record=True) as recorded:
            warnings.simplefilter("always")
            analyzer.total()
            categories = [entry.category for entry in recorded]
            self.assertNotIn(tracer_warning, categories)
            self.assertNotIn(RuntimeWarning, categories)

        # mode="no_tracer_warning": RuntimeWarnings surface, tracer
        # warnings do not.
        analyzer.tracer_warnings(mode="no_tracer_warning")
        clear_cache()
        self.assertWarns(RuntimeWarning, analyzer.total)
        clear_cache()
        with warnings.catch_warnings(record=True) as recorded:
            warnings.simplefilter("always")
            analyzer.total()
            categories = [entry.category for entry in recorded]
            self.assertNotIn(tracer_warning, categories)

        # Logged warnings for unsupported operators and uncalled modules.
        root_logger = logging.getLogger()
        unsupported_msg = "Unsupported operator aten::add encountered 1 time(s)"
        uncalled_msg = "never called"
        uncalled_name = "fc1"  # fc2 is called by chance

        # Both switches off: neither message is logged.
        analyzer.uncalled_modules_warnings(enabled=False)
        analyzer.unsupported_ops_warnings(enabled=False)
        clear_cache()
        with self.assertLogs(root_logger, logging.WARN) as captured:
            # assertLogs requires at least one record; emit a placeholder.
            root_logger.warning("Dummy warning.")
            analyzer.total()
        self.assertFalse(any(unsupported_msg in line for line in captured.output))
        self.assertFalse(any(uncalled_msg in line for line in captured.output))

        # Both switches on: both messages are logged, naming the module.
        analyzer.unsupported_ops_warnings(enabled=True)
        analyzer.uncalled_modules_warnings(enabled=True)
        clear_cache()
        with self.assertLogs(root_logger, logging.WARN) as captured:
            analyzer.total()
        for needle in (unsupported_msg, uncalled_msg, uncalled_name):
            self.assertTrue(any(needle in line for line in captured.output))
import sys
sys.path.append('..')
import argparse
from batchgenerators.utilities.file_and_folder_operations import *
from nnunet.training.model_restore import load_model_and_checkpoint_files
from fvcore.nn.flop_count import _DEFAULT_SUPPORTED_OPS, FlopCountAnalysis, flop_count
import numpy as np
import torch
import os
join = os.path.join

parser = argparse.ArgumentParser()
parser.add_argument(
    '-m', '--model',
    help="2d, 3d_lowres, 3d_fullres or 3d_cascade_fullres. Default: 3d_fullres",
    default="3d_fullres",
    required=False,
    # Reject unknown configurations up front instead of failing later on a
    # results folder that does not exist.
    choices=['2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres'],
)
args = parser.parse_args()
model = args.model

model_path = join('./data/RESULTS_FOLDER/nnUNet', model, 'Task000_FLARE21Baseline/nnUNetTrainerV2__nnUNetPlansv2.1')
trainer, params = load_model_and_checkpoint_files(model_path, folds='all', checkpoint_name='model_final_checkpoint')
pkl_file = join(model_path, "all/model_final_checkpoint.model.pkl")
info = load_pickle(pkl_file)

# Pick the plans stage: 2d and 3d_lowres use the first stage; full-resolution
# configurations use the last one. The original hard-coded index 1 failed
# whenever the plans contain a single stage (stages are indexed 0..n-1).
plans_per_stage = info['plans']['plans_per_stage']
if model in ('2d', '3d_lowres'):
    stage = 0
else:
    stage = len(plans_per_stage) - 1
patch_size = plans_per_stage[stage]['patch_size']

# Input laid out as (batch=1, channels=1, *patch_size).
# NOTE(review): a single input channel assumes a single-modality task such as
# FLARE21 CT -- confirm before reusing for multi-modality tasks.
input_shape = (1, 1, *(int(s) for s in patch_size))
inputs = (torch.randn(input_shape).cuda(),)
flops = FlopCountAnalysis(trainer.network, inputs)
print('Total FLOPs:', flops.total())