コード例 #1
0
    def test_unsupported_ops(self) -> None:
        """
        Verifies that unsupported operators are recorded per module.

        Only addmm and linear handles are registered here, so the
        convolutions and the add/mul/pow ops of NestedNet are expected to be
        reported as unsupported. Each module's tally covers the ops it ran
        itself plus those of its descendants.
        """

        model = NestedNet(lin_op=self.lin_op)
        batch = torch.randn((1, *model.input_size))

        analyzer = JitModelAnalysis(model=model, inputs=(batch, ))
        analyzer = analyzer.set_op_handle(
            "aten::addmm",
            addmm_flop_jit,
            "aten::linear",
            linear_flop_jit,
        )
        analyzer.total()

        # Inner submodule: its own add/mul plus its children's ops.
        inner_conv = Counter({"aten::_convolution": 1})
        inner_fc = Counter()
        inner_total = (Counter({
            "aten::add": 1,
            "aten::mul": 1
        }) + inner_fc + inner_conv)

        # Outer (root) module: its own pow plus everything beneath it.
        outer_conv = Counter({"aten::_convolution": 1})
        outer_fc = Counter()
        outer_total = (Counter({"aten::pow": 1}) + outer_conv + outer_fc +
                       inner_total)

        expected_by_module = {
            "": outer_total,
            "conv": outer_conv,
            "fc": outer_fc,
            "submod": inner_total,
            "submod.conv": inner_conv,
            "submod.fc": inner_fc,
        }

        # Query each module by its string name.
        for module_name, expected in expected_by_module.items():
            with self.subTest(name=module_name):
                self.assertEqual(analyzer.unsupported_ops(module_name),
                                 expected)
コード例 #2
0
    def test_skipped_ops(self) -> None:
        """
        Verifies that skipped operations are tallied per module.

        With only an addmm handle registered, the convolutions and the
        add/mul/pow ops of NestedNet are expected to be skipped. Each
        module's tally covers the ops it ran itself plus those of its
        descendants.
        """

        model = NestedNet()
        handles = {
            "aten::addmm": addmm_flop_jit,
        }  # type: Dict[str, Handle]
        analyzer = JitModelAnalysis(
            model=model,
            inputs=(torch.randn((1, *model.input_size)), ),
            op_handles=handles,
        )

        # Inner submodule: its own add/mul plus its children's ops.
        inner_conv = Counter({"aten::_convolution": 1})
        inner_fc = Counter()
        inner_total = (Counter({
            "aten::add": 1,
            "aten::mul": 1
        }) + inner_fc + inner_conv)

        # Outer (root) module: its own pow plus everything beneath it.
        outer_conv = Counter({"aten::_convolution": 1})
        outer_fc = Counter()
        outer_total = (Counter({"aten::pow": 1}) + outer_conv + outer_fc +
                       inner_total)

        expected_by_module = {
            "": outer_total,
            "conv": outer_conv,
            "fc": outer_fc,
            "submod": inner_total,
            "submod.conv": inner_conv,
            "submod.fc": inner_fc,
        }

        # Query each module by its string name.
        for module_name, expected in expected_by_module.items():
            with self.subTest(name=module_name):
                self.assertEqual(analyzer.skipped_ops(module_name), expected)
コード例 #3
0
    def test_non_forward_func_call(self) -> None:
        """
        Tests that calls to a submodule's non-forward function attribute
        the resulting counts to the calling module. Also tests that the
        intermediate module, which is never called directly, is reported
        by .uncalled_modules().
        """

        model = NonForwardNet()
        # Derive the input shape from the model, like the other tests do,
        # rather than hard-coding the feature dimension.
        inputs = (torch.randn((1, *model.input_size)), )
        op_handles = {
            "aten::addmm": addmm_flop_jit,
        }  # type: Dict[str, Handle]

        analyzer = JitModelAnalysis(model=model,
                                    inputs=inputs,
                                    op_handles=op_handles)

        # `submod` itself receives nothing; the fc it owns still gets its
        # own count, and the root accumulates both fc layers.
        submod_count = 0
        inner_fc_count = model.submod.fc_flops
        total_count = model.fc_flops + inner_fc_count

        self.assertEqual(analyzer.total("submod"), submod_count)
        self.assertEqual(analyzer.total("submod.fc"), inner_fc_count)
        self.assertEqual(analyzer.total(""), total_count)

        # The module that is never called directly is registered as such.
        self.assertEqual(analyzer.uncalled_modules(), {"submod"})
コード例 #4
0
    def test_repeated_module(self) -> None:
        """
        Tests that repeated calls to the same submodule correctly aggregate
        results to that submodule.
        """

        model = RepeatedNet()
        inputs = (torch.randn((1, *model.input_size)), )
        op_handles = {
            "aten::addmm": addmm_flop_jit,
        }  # type: Dict[str, Handle]

        analyzer = JitModelAnalysis(model=model,
                                    inputs=inputs,
                                    op_handles=op_handles)

        # fc1_num / fc2_num are presumably the number of times each layer is
        # invoked; the expected count is calls x per-call flops.
        fc1_count = model.fc1_num * model.fc1_flops
        fc2_count = model.fc2_num * model.fc2_flops
        total_count = fc1_count + fc2_count
        # Per-operator results are keyed without the "aten::" prefix.
        fc1_per_operator = Counter({"addmm": fc1_count})

        self.assertEqual(analyzer.total("fc1"), fc1_count)
        self.assertEqual(analyzer.total("fc2"), fc2_count)
        self.assertEqual(analyzer.total(""), total_count)
        self.assertEqual(analyzer.by_operator("fc1"), fc1_per_operator)

        # Tests no uncalled mods
        self.assertEqual(analyzer.uncalled_modules(), set())
コード例 #5
0
    def test_unused_module(self) -> None:
        """
        Tests that unused modules return 0 count for operator sums and
        an empty Counter() for per-operator results. Also tests that
        unused modules are reported by .uncalled_modules(), but that
        modules that simply have zero flops (like ReLU) are not.
        """

        model = UnusedNet()
        inputs = (torch.randn((1, *model.input_size)), )
        op_handles = {
            "aten::addmm": addmm_flop_jit,
        }  # type: Dict[str, Handle]

        analyzer = JitModelAnalysis(model=model,
                                    inputs=inputs,
                                    op_handles=op_handles)

        unused_count = 0
        unused_per_operator = Counter()
        model_count = model.fc1_flops + model.fc2_flops

        self.assertEqual(analyzer.total("unused"), unused_count)
        self.assertEqual(analyzer.by_operator("unused"), unused_per_operator)
        self.assertEqual(analyzer.total(""), model_count)

        # Only the unused module is recognized as never called; zero-flop
        # modules that do run are not reported.
        self.assertEqual(analyzer.uncalled_modules(), {"unused"})
コード例 #6
0
    def test_copy(self) -> None:
        """
        Tests .copy(...).

        A copy must produce the same results and the same warning settings
        as the original, must be independent of it (changing the copy does
        not change the original), and can be retargeted to a new model and
        inputs through new_model / new_inputs.
        """

        model = RepeatedNet()
        inputs = (torch.randn((1, *model.input_size)), )

        analyzer = (JitModelAnalysis(model=model, inputs=inputs).set_op_handle(
            "aten::addmm",
            addmm_flop_jit,
            "aten::linear",
            linear_flop_jit,
        ).unsupported_ops_warnings(enabled=False).tracer_warnings(mode="none"))

        repeated_net_flops = model.fc1_num * model.fc1_flops
        repeated_net_flops += model.fc2_num * model.fc2_flops

        analyzer_copy = analyzer.copy()

        # Outputs are the same
        self.assertEqual(
            analyzer.by_module_and_operator(),
            analyzer_copy.by_module_and_operator(),
        )

        # Settings match
        self.assertEqual(
            analyzer._enable_warn_unsupported_ops,
            analyzer_copy._enable_warn_unsupported_ops,
        )
        self.assertEqual(
            analyzer._enable_warn_uncalled_mods,
            analyzer_copy._enable_warn_uncalled_mods,
        )
        self.assertEqual(analyzer._warn_trace, analyzer_copy._warn_trace)

        # Changing copy does not change original
        analyzer_copy.unsupported_ops_warnings(enabled=True)
        self.assertNotEqual(
            analyzer._enable_warn_unsupported_ops,
            analyzer_copy._enable_warn_unsupported_ops,
        )

        # Copy with new model and inputs
        new_model = NonForwardNet()
        bs = 5
        new_inputs = (torch.randn((bs, *new_model.input_size)), )
        analyzer_new = analyzer.copy(new_model=new_model,
                                     new_inputs=new_inputs)

        non_forward_flops = new_model.fc_flops + new_model.submod.fc_flops

        # Total is correct for new model and inputs
        self.assertEqual(analyzer_new.total(), non_forward_flops * bs)

        # Original is unaffected
        self.assertEqual(analyzer.total(), repeated_net_flops)

        # Settings match: check the same set of settings as for the plain
        # copy above, including the uncalled-modules flag.
        self.assertEqual(
            analyzer._enable_warn_unsupported_ops,
            analyzer_new._enable_warn_unsupported_ops,
        )
        self.assertEqual(
            analyzer._enable_warn_uncalled_mods,
            analyzer_new._enable_warn_uncalled_mods,
        )
        self.assertEqual(analyzer._warn_trace, analyzer_new._warn_trace)
コード例 #7
0
    def test_changing_handles(self) -> None:
        """
        Exercises .set_op_handle() and .clear_op_handles().

        Adding or overriding a handle after results have been cached must be
        reflected by later queries, and clearing every handle must leave each
        module with an empty per-operator count.
        """
        model = NestedNet(lin_op=self.lin_op)
        initial_handles = {
            "aten::addmm": addmm_flop_jit,
            "aten::linear": linear_flop_jit,
        }  # type: Dict[str, Handle]

        analyzer = JitModelAnalysis(
            model=model,
            inputs=(torch.randn((1, *model.input_size)), ),
        ).set_op_handle(**initial_handles)
        analyzer.unsupported_ops_warnings(enabled=False)

        # Populate the result cache before any handle is modified.
        _ = analyzer.total("")

        # Registering an additional handle: convolutions now count as well.
        analyzer.set_op_handle("aten::_convolution", conv_flop_jit)
        self.assertEqual(analyzer.by_module_and_operator(), model.flops)

        # Overriding an existing handle with a stub returning a constant.
        dummy_name = "dummy_op"
        dummy_out = 1000

        def dummy_handle(inputs: List[Any],
                         outputs: List[Any]) -> typing.Counter[str]:
            return Counter({dummy_name: dummy_out})

        analyzer.set_op_handle(f"aten::{self.lin_op}", dummy_handle)

        # Same as model.flops, except the linear op is replaced by the stub.
        expected = {
            mod_name: Counter({
                op: flop
                for op, flop in op_counts.items() if op != self.lin_op
            })
            for mod_name, op_counts in model.flops.items()
        }
        # The root sees the stub twice (both fc layers); every module that
        # owns exactly one fc sees it once.
        stub_multiplicity = (("", 2), ("fc", 1), ("submod", 1),
                             ("submod.fc", 1))
        for mod_name, times in stub_multiplicity:
            expected[mod_name][dummy_name] = times * dummy_out

        self.assertEqual(analyzer.by_module_and_operator(), expected)

        # With every handle removed nothing is counted anywhere.
        analyzer.clear_op_handles()
        self.assertEqual(
            analyzer.by_module_and_operator(),
            {mod_name: Counter() for mod_name in model.flops},
        )
コード例 #8
0
    def test_disable_warnings(self) -> None:
        """
        Tests .tracer_warnings(...), .skipped_ops_warnings(...) and
        .uncalled_modules_warnings(...).

        Each tracer-warning mode is checked for which of TracerWarning and
        RuntimeWarning it lets through. The skipped-op and uncalled-module
        toggles are checked for whether their messages reach the root logger.
        """
        model = TraceWarningNet()
        inputs = (torch.randn((1, *model.input_size)), )
        op_handles = {
            "aten::addmm": addmm_flop_jit,
        }  # type: Dict[str, Handle]

        analyzer = JitModelAnalysis(model=model,
                                    inputs=inputs,
                                    op_handles=op_handles)

        # Public alias of the tracer warning class (avoids the private
        # torch.jit._trace module path).
        tracer_warning = torch.jit.TracerWarning

        # Tracer warnings. Results are cached, so ``_stats`` is cleared
        # manually before each query to force the trace to rerun.

        # mode="all": both TracerWarning and RuntimeWarning are emitted.
        analyzer.tracer_warnings(mode="all")
        analyzer._stats = None
        self.assertWarns(tracer_warning, analyzer.total)
        analyzer._stats = None
        self.assertWarns(RuntimeWarning, analyzer.total)

        # mode="none": neither category is emitted.
        analyzer.tracer_warnings(mode="none")
        analyzer._stats = None
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            _ = analyzer.total()
        warning_types = [s.category for s in w]
        self.assertNotIn(tracer_warning, warning_types)
        self.assertNotIn(RuntimeWarning, warning_types)

        # mode="no_tracer_warning": RuntimeWarning still surfaces, while
        # TracerWarning is suppressed.
        analyzer.tracer_warnings(mode="no_tracer_warning")
        analyzer._stats = None
        self.assertWarns(RuntimeWarning, analyzer.total)
        analyzer._stats = None
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            _ = analyzer.total()
        warning_types = [s.category for s in w]
        self.assertNotIn(tracer_warning, warning_types)

        # Skipped ops and uncalled modules warnings

        logger = logging.getLogger()
        skipped_string = "Skipped operation aten::add 1 time(s)"
        uncalled_string = "Module never called: fc1"

        analyzer.uncalled_modules_warnings(enabled=False)
        analyzer.skipped_ops_warnings(enabled=False)
        analyzer._stats = None  # Manually clear cache so trace is rerun
        with self.assertLogs(logger, logging.WARN) as cm:
            # assertLogs fails when nothing is logged at all, so emit a
            # placeholder record to let the "disabled" case pass.
            logger.warning("Dummy warning.")
            _ = analyzer.total()
        self.assertFalse(any(skipped_string in s for s in cm.output))
        self.assertFalse(any(uncalled_string in s for s in cm.output))

        analyzer.skipped_ops_warnings(enabled=True)
        analyzer.uncalled_modules_warnings(enabled=True)
        analyzer._stats = None  # Manually clear cache so trace is rerun
        with self.assertLogs(logger, logging.WARN) as cm:
            _ = analyzer.total()
        self.assertTrue(any(skipped_string in s for s in cm.output))
        self.assertTrue(any(uncalled_string in s for s in cm.output))
コード例 #9
0
    def test_shared_module(self) -> None:
        """
        Checks the accounting for submodules reachable under several names.

        Only the first name registered for a module is canonical. Counts for
        the aliases `submod2.submod` and `multiname2` are folded into
        `submod1.submod` and `multiname1`, yet the aliases can still be used
        for lookups and for resolving the canonical name.
        """

        model = SharedModuleNet()
        handles = {
            "aten::addmm": addmm_flop_jit,
        }  # type: Dict[str, Handle]
        analyzer = JitModelAnalysis(
            model=model,
            inputs=(torch.randn((1, *model.input_size)), ),
            op_handles=handles,
        )
        analyzer.skipped_ops_warnings(enabled=False)

        per_name = model.multiname_flops
        per_share = model.shared_flops
        # The multi-named module is called under 2 names and the shared
        # submodule sits under 2 parents, so each accrues twice its count.
        expected = {
            "": 2 * per_name + 2 * per_share,
            "submod1": per_share,
            "submod1.submod": 2 * per_share,
            "submod2": per_share,
            "multiname1": 2 * per_name,
        }

        self.assertEqual(analyzer.by_module(), expected)

        # An alias resolves to the counts of its canonical module.
        for alias, canonical in (("submod2.submod", "submod1.submod"),
                                 ("multiname2", "multiname1")):
            self.assertEqual(analyzer.total(alias), expected[canonical])

        # Aliases map to the canonical name; canonical names map to
        # themselves.
        name_pairs = (
            ("multiname2", "multiname1"),
            ("multiname1", "multiname1"),
            ("submod2.submod", "submod1.submod"),
            ("submod1.submod", "submod1.submod"),
        )
        for queried, canonical in name_pairs:
            self.assertEqual(analyzer.canonical_module_name(queried),
                             canonical)

        # Every module is reached at least once.
        self.assertEqual(analyzer.uncalled_modules(), set())