Пример #1
0
    def test_non_forward_func_call(self) -> None:
        """
        Checks flop attribution when a submodule's non-forward method is called.

        With ``ancestor_mode("caller")`` the intermediate module ``submod`` is
        credited with zero flops and is reported by ``uncalled_modules()``.
        With ``ancestor_mode("owner")`` the flops of ``submod.fc`` also roll up
        into ``submod`` and no module is reported as uncalled. The root total
        is identical in both modes.
        """
        net = NonForwardNet()
        example_inputs = (torch.randn((1, 10)),)

        # Attribution by call stack.
        by_caller = FlopCountAnalysis(model=net, inputs=example_inputs)
        by_caller = by_caller.ancestor_mode("caller")

        nested_fc = net.submod.fc_flops
        grand_total = net.fc_flops + nested_fc

        self.assertEqual(by_caller.total("submod"), 0)
        self.assertEqual(by_caller.total("submod.fc"), nested_fc)
        self.assertEqual(by_caller.total(""), grand_total)
        # submod is never invoked directly, so it is flagged as uncalled.
        self.assertEqual(by_caller.uncalled_modules(), {"submod"})

        # Attribution by module ownership: submod inherits its child's flops.
        by_owner = FlopCountAnalysis(model=net, inputs=example_inputs)
        by_owner = by_owner.ancestor_mode("owner")
        self.assertEqual(by_owner.total("submod"), nested_fc)
        self.assertEqual(by_owner.total("submod.fc"), nested_fc)
        self.assertEqual(by_owner.total(""), grand_total)
        self.assertEqual(by_owner.uncalled_modules(), set())
Пример #2
0
    def test_data_parallel(self) -> None:
        """
        Tests that a model wrapped in DataParallel still returns results
        labeled by the correct scopes.

        ``torch.nn.DataParallel`` holds the wrapped model under the attribute
        ``module``. Every scope name of the unwrapped model is therefore
        expected with a ``"module."`` prefix, and the unwrapped root ``""``
        becomes ``"module"``. The wrapper's own root ``""`` is expected to
        carry the same per-operator counts as the unwrapped root.
        """
        model = NestedNet(lin_op=self.lin_op)
        inputs = (torch.randn((1, *model.input_size)),)

        # Expected per-operator flops for each scope of the wrapper:
        # "" -> "module", "a.b" -> "module.a.b".
        flops = {
            "module" + ("." if name else "") + name: flop
            for name, flop in model.flops.items()
        }
        # The DataParallel root accounts for everything the inner model does.
        flops[""] = model.flops[""]

        model = torch.nn.DataParallel(model).cpu()
        analyzer = FlopCountAnalysis(model=model, inputs=inputs)
        analyzer.unsupported_ops_warnings(enabled=False)

        # Query each scope by its string name.
        for name in flops:
            with self.subTest(name=name):
                gt_flops = sum(flops[name].values())
                self.assertEqual(analyzer.total(name), gt_flops)

        # Output as dictionary
        self.assertEqual(analyzer.by_module_and_operator(), flops)

        # Test no uncalled modules
        self.assertEqual(analyzer.uncalled_modules(), set())
Пример #3
0
    def test_shared_module(self) -> None:
        """
        Checks flop accounting for submodules reachable under several names.

        Only the first name of a module is canonical. ``submod2.submod`` and
        ``multiname2`` therefore get no entries of their own in
        ``by_module()``; their counts are reported under ``submod1.submod``
        and ``multiname1``. Lookups through an alias resolve to the canonical
        entry.
        """
        net = SharedModuleNet()
        example_inputs = (torch.randn((1, *net.input_size)),)

        analysis = FlopCountAnalysis(
            model=net, inputs=example_inputs
        ).unsupported_ops_warnings(enabled=False)

        # One module is called under two names, and one module is shared by
        # two parents. Each canonical entry accumulates both uses.
        multiname_total = 2 * net.multiname_flops
        shared_total = 2 * net.shared_flops
        expected_by_module = {
            "": multiname_total + shared_total,
            "submod1": net.shared_flops,
            "submod1.submod": shared_total,
            "submod2": net.shared_flops,
            "multiname1": multiname_total,
        }
        self.assertEqual(analysis.by_module(), expected_by_module)

        # Totals requested through an alias come from the canonical entry.
        for alias, canonical in (
            ("submod2.submod", "submod1.submod"),
            ("multiname2", "multiname1"),
        ):
            self.assertEqual(analysis.total(alias), expected_by_module[canonical])

        # Aliases map to the canonical name; canonical names map to themselves.
        for queried, canonical in (
            ("multiname2", "multiname1"),
            ("multiname1", "multiname1"),
            ("submod2.submod", "submod1.submod"),
            ("submod1.submod", "submod1.submod"),
        ):
            self.assertEqual(analysis.canonical_module_name(queried), canonical)

        # Every module was executed at least once.
        self.assertEqual(analysis.uncalled_modules(), set())
Пример #4
0
    def test_recursive_scope(self) -> None:
        """
        Checks that an operator is counted only once per module.

        An op that lies within the same module's scope more than once must not
        be double-counted. Both ``total()`` and ``total("fc")`` must equal
        ``model.flops`` exactly.
        """
        net = RecursiveScopeNet()
        example_inputs = (torch.randn((1, *net.input_size)),)
        analysis = FlopCountAnalysis(model=net, inputs=example_inputs)

        expected = net.flops
        self.assertEqual(analysis.total(), expected)
        self.assertEqual(analysis.total("fc"), expected)

        # No module is left unexecuted.
        self.assertEqual(analysis.uncalled_modules(), set())
Пример #5
0
    def test_repeated_module(self) -> None:
        """
        Checks that repeated calls to one submodule are aggregated under it.

        ``fc1`` is expected to contribute ``fc1_num`` times its per-call
        flops, and ``fc2`` ``fc2_num`` times. The root total is the sum of
        both, and the per-operator breakdown for ``fc1`` attributes everything
        to the linear op.
        """
        net = RepeatedNet()
        example_inputs = (torch.randn((1, *net.input_size)),)
        analysis = FlopCountAnalysis(model=net, inputs=example_inputs)

        per_module = {
            "fc1": net.fc1_num * net.fc1_flops,
            "fc2": net.fc2_num * net.fc2_flops,
        }
        per_module[""] = per_module["fc1"] + per_module["fc2"]
        fc1_by_op = Counter({self.lin_op: per_module["fc1"]})

        for scope, expected in per_module.items():
            self.assertEqual(analysis.total(scope), expected)
        self.assertEqual(analysis.by_operator("fc1"), fc1_by_op)

        # Every module was executed at least once.
        self.assertEqual(analysis.uncalled_modules(), set())
Пример #6
0
    def test_unused_module(self) -> None:
        """
        Checks how a submodule that is never executed is reported.

        An unused module yields a total of 0 and an empty ``Counter`` from
        ``by_operator()``, and it is listed by ``uncalled_modules()``. Modules
        that run but contribute zero flops (such as ReLU) must not be listed
        there.
        """
        net = UnusedNet()
        example_inputs = (torch.randn((1, *net.input_size)),)
        analysis = FlopCountAnalysis(model=net, inputs=example_inputs)

        self.assertEqual(analysis.total("unused"), 0)
        self.assertEqual(analysis.by_operator("unused"), Counter())
        self.assertEqual(analysis.total(""), net.fc1_flops + net.fc2_flops)

        # Only "unused" is flagged; zero-flop modules that ran are not.
        self.assertEqual(analysis.uncalled_modules(), {"unused"})