def test_data_parallel(self) -> None:
        """
        Checks that a model wrapped in ``torch.nn.DataParallel`` still reports
        results under the correct scope names: every scope moves under the
        "module" prefix, while the root scope keeps its empty name.
        """
        model = NestedNet(lin_op=self.lin_op)
        inputs = (torch.randn((1, *model.input_size)),)

        def _with_wrapper_prefix(mapping):
            # DataParallel nests the wrapped model under the child named
            # "module"; the wrapper itself is the new root scope "".
            prefixed = {
                "module" + ("." if key else "") + key: value
                for key, value in mapping.items()
            }
            prefixed[""] = mapping[""]
            return prefixed

        flops = _with_wrapper_prefix(model.flops)
        name_to_module = _with_wrapper_prefix(model.name_to_module)

        model = torch.nn.DataParallel(model).cpu()
        analyzer = FlopCountAnalysis(model=model, inputs=inputs)
        analyzer.unsupported_ops_warnings(enabled=False)

        # Query each scope by its (prefixed) string name
        for scope, per_op_counts in flops.items():
            with self.subTest(name=scope):
                self.assertEqual(analyzer.total(scope), sum(per_op_counts.values()))

        # Full dictionary output must match the prefixed ground truth
        self.assertEqual(analyzer.by_module_and_operator(), flops)

        # Every module should have been reached during the trace
        self.assertEqual(analyzer.uncalled_modules(), set())
    # NOTE(review): removed pasted scraping artifact ("Beispiel #2" / "0")
    # that broke the file's syntax between test methods.
    def test_disable_warnings(self) -> None:
        """
        Exercises ``.unsupported_ops_warnings(...)`` and ``.tracer_warnings(...)``,
        verifying that each mode emits or suppresses the expected warning and
        log-message categories.
        """
        model = TraceWarningNet()
        inputs = (torch.randn((1, *model.input_size)),)
        analyzer = FlopCountAnalysis(model=model, inputs=inputs)

        # --- Tracer warnings ---------------------------------------------
        # mode="all": both TracerWarning and RuntimeWarning surface.
        analyzer.tracer_warnings(mode="all")
        analyzer._stats = None  # Manually clear cache so trace is rerun
        self.assertWarns(torch.jit._trace.TracerWarning, analyzer.total)
        analyzer._stats = None  # Manually clear cache so trace is rerun
        self.assertWarns(RuntimeWarning, analyzer.total)

        # mode="none": neither warning category may appear.
        analyzer.tracer_warnings(mode="none")
        analyzer._stats = None  # Manually clear cache so trace is rerun
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            _ = analyzer.total()
            if caught:
                seen_categories = [record.category for record in caught]
                self.assertFalse(torch.jit._trace.TracerWarning in seen_categories)
                self.assertFalse(RuntimeWarning in seen_categories)

        # mode="no_tracer_warning": RuntimeWarning passes, TracerWarning does not.
        analyzer.tracer_warnings(mode="no_tracer_warning")
        analyzer._stats = None  # Manually clear cache so trace is rerun
        self.assertWarns(RuntimeWarning, analyzer.total)
        analyzer._stats = None  # Manually clear cache so trace is rerun
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            _ = analyzer.total()
            if caught:
                seen_categories = [record.category for record in caught]
                self.assertFalse(torch.jit._trace.TracerWarning in seen_categories)

        # --- Unsupported ops and uncalled modules warnings ----------------

        logger = logging.getLogger()
        skipped_msg = "Unsupported operator aten::add encountered 1 time(s)"
        uncalled_msg = "never called"
        uncalled_modules = "fc1"  # fc2 is called by chance

        # Both disabled: neither message may be logged.
        analyzer.uncalled_modules_warnings(enabled=False)
        analyzer.unsupported_ops_warnings(enabled=False)
        analyzer._stats = None  # Manually clear cache so trace is rerun
        with self.assertLogs(logger, logging.WARN) as captured:
            # Dummy record keeps assertLogs from failing on an empty capture.
            logger.warning("Dummy warning.")
            _ = analyzer.total()
        self.assertFalse(any(skipped_msg in line for line in captured.output))
        self.assertFalse(any(uncalled_msg in line for line in captured.output))

        # Both enabled: all messages must be logged.
        analyzer.unsupported_ops_warnings(enabled=True)
        analyzer.uncalled_modules_warnings(enabled=True)
        analyzer._stats = None  # Manually clear cache so trace is rerun
        with self.assertLogs(logger, logging.WARN) as captured:
            _ = analyzer.total()
        self.assertTrue(any(skipped_msg in line for line in captured.output))
        self.assertTrue(any(uncalled_msg in line for line in captured.output))
        self.assertTrue(any(uncalled_modules in line for line in captured.output))
    def test_by_module_and_operator(self) -> None:
        """
        Checks that ``JitModelAnalysis.by_module_and_operator()`` reproduces
        the model's known per-module, per-operator flop counts exactly.
        """
        net = NestedNet(lin_op=self.lin_op)
        analyzer = FlopCountAnalysis(
            model=net, inputs=(torch.randn((1, *net.input_size)),)
        )
        analyzer.unsupported_ops_warnings(enabled=False)

        self.assertEqual(analyzer.by_module_and_operator(), net.flops)
    # NOTE(review): removed pasted scraping artifact ("Beispiel #4" / "0")
    # that broke the file's syntax between test methods.
    def test_by_module(self) -> None:
        """
        Checks that ``JitModelAnalysis.by_module()`` returns, for every scope,
        the total count summed over that scope's operators.
        """
        net = NestedNet(lin_op=self.lin_op)
        analyzer = FlopCountAnalysis(
            model=net, inputs=(torch.randn((1, *net.input_size)),)
        )
        analyzer.unsupported_ops_warnings(enabled=False)

        # Collapse the per-operator ground truth into per-module totals.
        expected = {
            scope: sum(per_op.values()) for scope, per_op in net.flops.items()
        }

        self.assertEqual(analyzer.by_module(), expected)
    def test_by_operator(self) -> None:
        """
        Checks that ``JitModelAnalysis.by_operator(module)`` returns the
        correct per-operator counts when the module is given by name.
        """
        net = NestedNet(lin_op=self.lin_op)
        analyzer = FlopCountAnalysis(
            model=net, inputs=(torch.randn((1, *net.input_size)),)
        )
        analyzer.unsupported_ops_warnings(enabled=False)

        # Query each scope by its string name.
        for scope, expected_counts in net.flops.items():
            with self.subTest(name=scope):
                self.assertEqual(analyzer.by_operator(scope), expected_counts)
    def test_skip_uncalled_containers_warnings(self) -> None:
        """
        A container module (e.g. ``nn.ModuleList``) whose children ARE called
        should not itself trigger an "uncalled module" warning.
        """

        class WrapsModuleList(nn.Module):
            def forward(self, x):
                # Calls the child directly, so the ModuleList container
                # itself never receives a forward call.
                return self.submod[0](x) + 1

        net = WrapsModuleList()
        net.submod = nn.ModuleList([nn.Linear(3, 3)])  # pyre-ignore
        analyzer = FlopCountAnalysis(model=net, inputs=torch.rand(1, 3))
        analyzer.unsupported_ops_warnings(enabled=False)

        logger = logging.getLogger()
        with self.assertLogs(logger, logging.WARN) as captured:
            # Dummy record keeps assertLogs from failing on an empty capture.
            logger.warning("Dummy warning.")
            _ = analyzer.total()
        self.assertFalse(
            any("Module never called: submod" in line for line in captured.output)
        )