Ejemplo n.º 1
0
    def test_data_parallel(self) -> None:
        """
        Tests that a model wrapped in DataParallel still returns results
        labeled by the correct scopes.

        DataParallel stores the wrapped model as its ``module`` child, so every
        scope name of the inner model is expected to gain a ``module`` prefix:
        the inner root "" becomes "module" and a child "x" becomes "module.x".
        The wrapper itself becomes the new root "" and is expected to report
        the same per-operator counts as the inner root.
        """
        model = NestedNet()
        inputs = (torch.randn((1, *model.input_size)), )

        # Expected per-operator flops for every scope of the wrapper. The
        # inner root's name is the empty string, so it maps to plain "module"
        # (no trailing dot).
        flops = {
            "module" + ("." if name else "") + name: flop
            for name, flop in model.flops.items()
        }
        # The DataParallel wrapper's own root scope covers all the work of
        # the wrapped model.
        flops[""] = model.flops[""]

        model = torch.nn.DataParallel(model)
        analyzer = FlopCountAnalysis(model=model, inputs=inputs)
        analyzer.skipped_ops_warnings(enabled=False)

        # Using a string input: total() for each scope must equal the sum of
        # that scope's per-operator reference counts.
        for name in flops:
            with self.subTest(name=name):
                gt_flops = sum(flops[name].values())
                self.assertEqual(analyzer.total(name), gt_flops)

        # Output as dictionary
        self.assertEqual(analyzer.by_module_and_operator(), flops)

        # Test no uncalled modules: wrapping must not leave any submodule
        # reported as never invoked.
        self.assertEqual(analyzer.uncalled_modules(), set())
Ejemplo n.º 2
0
    def test_by_module_and_operator(self) -> None:
        """
        Checks that JitModelAnalysis.by_module_and_operator() yields, for every
        scope, the same per-operator counts that NestedNet records as its
        reference values in ``flops``.
        """
        net = NestedNet()
        sample = torch.randn(1, *net.input_size)

        flop_analysis = FlopCountAnalysis(model=net, inputs=(sample,))
        flop_analysis.skipped_ops_warnings(enabled=False)

        expected = net.flops
        self.assertEqual(flop_analysis.by_module_and_operator(), expected)
Ejemplo n.º 3
0
    def test_by_operator(self) -> None:
        """
        Checks JitModelAnalysis.by_operator() when the module is identified by
        its string scope name, for every scope listed in NestedNet's reference
        ``flops`` mapping.
        """
        net = NestedNet()
        sample = torch.randn(1, *net.input_size)

        flop_analysis = FlopCountAnalysis(model=net, inputs=(sample,))
        flop_analysis.skipped_ops_warnings(enabled=False)

        # Each scope runs in its own subTest so a single mismatch does not
        # mask failures in the remaining scopes.
        for scope, expected in net.flops.items():
            with self.subTest(name=scope):
                self.assertEqual(flop_analysis.by_operator(scope), expected)
Ejemplo n.º 4
0
    def test_by_module(self) -> None:
        """
        Checks that JitModelAnalysis.by_module() reports, for each scope, a
        single total equal to the sum of that scope's per-operator reference
        counts from NestedNet.
        """
        net = NestedNet()
        sample = torch.randn(1, *net.input_size)

        flop_analysis = FlopCountAnalysis(model=net, inputs=(sample,))
        flop_analysis.skipped_ops_warnings(enabled=False)

        # Collapse the per-operator reference counts into one total per scope.
        expected = {}
        for scope, per_op in net.flops.items():
            expected[scope] = sum(per_op.values())

        self.assertEqual(flop_analysis.by_module(), expected)