예제 #1
0
    def test_disable_warnings(self) -> None:
        """
        Tests .skipped_ops_warnings(...), .uncalled_modules_warnings(...)
        and .tracer_warnings(...)

        Before every check the analyzer's cached results (``_stats``) are
        cleared so that ``total()`` re-runs the trace and any warnings or
        log messages are emitted again under the current settings.
        """
        model = TraceWarningNet()
        inputs = (torch.randn((1, *model.input_size)), )
        op_handles = {
            "aten::addmm": addmm_flop_jit,
        }  # type: Dict[str, Handle]

        analyzer = JitModelAnalysis(model=model,
                                    inputs=inputs,
                                    op_handles=op_handles)

        # Tracer warnings

        # mode="all": both a TracerWarning and a RuntimeWarning are expected.
        analyzer.tracer_warnings(mode="all")
        analyzer._stats = None  # Manually clear cache so trace is rerun
        self.assertWarns(torch.jit._trace.TracerWarning, analyzer.total)
        analyzer._stats = None  # Manually clear cache so trace is rerun
        self.assertWarns(RuntimeWarning, analyzer.total)

        # mode="none": neither warning category may be emitted.
        analyzer.tracer_warnings(mode="none")
        analyzer._stats = None  # Manually clear cache so trace is rerun
        with warnings.catch_warnings(record=True) as w:
            # "always" so warnings already seen earlier in this test are not
            # deduplicated away by the default filter.
            warnings.simplefilter("always")
            _ = analyzer.total()
        warning_types = [s.category for s in w]
        self.assertNotIn(torch.jit._trace.TracerWarning, warning_types)
        self.assertNotIn(RuntimeWarning, warning_types)

        # mode="no_tracer_warning": RuntimeWarning still surfaces, but no
        # TracerWarning may be emitted.
        analyzer.tracer_warnings(mode="no_tracer_warning")
        analyzer._stats = None  # Manually clear cache so trace is rerun
        self.assertWarns(RuntimeWarning, analyzer.total)
        analyzer._stats = None  # Manually clear cache so trace is rerun
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            _ = analyzer.total()
        warning_types = [s.category for s in w]
        self.assertNotIn(torch.jit._trace.TracerWarning, warning_types)

        # Skipped ops and uncalled modules warnings

        logger = logging.getLogger()
        skipped_string = "Skipped operation aten::add 1 time(s)"
        uncalled_string = "Module never called: fc1"

        # Disabled: neither message may appear in the captured log output.
        analyzer.uncalled_modules_warnings(enabled=False)
        analyzer.skipped_ops_warnings(enabled=False)
        analyzer._stats = None  # Manually clear cache so trace is rerun
        with self.assertLogs(logger, logging.WARN) as cm:
            # assertLogs fails if nothing is logged at all, so emit one
            # unrelated record to keep the context manager satisfied.
            logger.warning("Dummy warning.")
            _ = analyzer.total()
        self.assertFalse(any(skipped_string in s for s in cm.output))
        self.assertFalse(any(uncalled_string in s for s in cm.output))

        # Enabled: both messages must appear in the captured log output.
        analyzer.skipped_ops_warnings(enabled=True)
        analyzer.uncalled_modules_warnings(enabled=True)
        analyzer._stats = None  # Manually clear cache so trace is rerun
        with self.assertLogs(logger, logging.WARN) as cm:
            _ = analyzer.total()
        self.assertTrue(any(skipped_string in s for s in cm.output))
        self.assertTrue(any(uncalled_string in s for s in cm.output))
예제 #2
0
    def test_copy(self) -> None:
        """
        Tests .copy(...)

        Checks that:
        - a plain copy produces the same per-module/per-operator results and
          carries over the warning settings;
        - changing a setting on the copy leaves the original unchanged;
        - a copy built with a new model and inputs analyzes the new pair
          while still carrying over the original's warning settings, and the
          original analyzer keeps its own results.
        """
        model = RepeatedNet()
        inputs = (torch.randn((1, *model.input_size)), )
        op_handles = {
            "aten::addmm": addmm_flop_jit,
        }  # type: Dict[str, Handle]

        analyzer = JitModelAnalysis(model=model,
                                    inputs=inputs,
                                    op_handles=op_handles)
        analyzer.skipped_ops_warnings(enabled=False)
        analyzer.tracer_warnings(mode="none")

        # Expected total for the original model: each fc layer's FLOPs times
        # the number of times that layer is used.
        repeated_net_flops = model.fc1_num * model.fc1_flops
        repeated_net_flops += model.fc2_num * model.fc2_flops

        analyzer_copy = analyzer.copy()

        # Outputs are the same
        self.assertEqual(
            analyzer.by_module_and_operator(),
            analyzer_copy.by_module_and_operator(),
        )

        # Settings match
        self.assertEqual(analyzer._enable_warn_skipped_ops,
                         analyzer_copy._enable_warn_skipped_ops)
        self.assertEqual(
            analyzer._enable_warn_uncalled_mods,
            analyzer_copy._enable_warn_uncalled_mods,
        )
        self.assertEqual(analyzer._warn_trace, analyzer_copy._warn_trace)

        # Changing copy does not change original
        analyzer_copy.skipped_ops_warnings(enabled=True)
        self.assertNotEqual(analyzer._enable_warn_skipped_ops,
                            analyzer_copy._enable_warn_skipped_ops)

        # Copy with new model and inputs
        new_model = NonForwardNet()
        bs = 5
        new_inputs = (torch.randn((bs, *new_model.input_size)), )
        analyzer_new = analyzer.copy(new_model=new_model,
                                     new_inputs=new_inputs)

        non_forward_flops = new_model.fc_flops + new_model.submod.fc_flops

        # Total is correct for new model and inputs (the expected total
        # scales with the batch size bs).
        self.assertEqual(analyzer_new.total(), non_forward_flops * bs)

        # Original is unaffected
        self.assertEqual(analyzer.total(), repeated_net_flops)

        # Settings match: every setting checked for the plain copy above must
        # also carry over when a new model/inputs are supplied.
        self.assertEqual(analyzer._enable_warn_skipped_ops,
                         analyzer_new._enable_warn_skipped_ops)
        self.assertEqual(
            analyzer._enable_warn_uncalled_mods,
            analyzer_new._enable_warn_uncalled_mods,
        )
        self.assertEqual(analyzer._warn_trace, analyzer_new._warn_trace)