def test_unsupported_ops(self) -> None:
    """
    Tests per-module recording of unsupported operations.

    Only the addmm and linear handles are registered, so the remaining ops
    hit during tracing must show up in ``.unsupported_ops(name)``. A parent
    module's counter equals its own unsupported ops plus its children's.
    """
    model = NestedNet(lin_op=self.lin_op)
    inputs = (torch.randn((1, *model.input_size)),)
    analyzer = JitModelAnalysis(model=model, inputs=inputs).set_op_handle(
        "aten::addmm",
        addmm_flop_jit,
        "aten::linear",
        linear_flop_jit,
    )
    analyzer.total()

    # Leaf modules: fc layers are fully handled, conv layers are not.
    inner_conv = Counter({"aten::_convolution": 1})
    inner_fc = Counter()
    outer_conv = Counter({"aten::_convolution": 1})
    outer_fc = Counter()

    # Container modules: their own ops plus everything from their children.
    inner = Counter({"aten::add": 1, "aten::mul": 1}) + inner_fc + inner_conv
    outer = Counter({"aten::pow": 1}) + outer_conv + outer_fc + inner

    expected = {
        "": outer,
        "conv": outer_conv,
        "fc": outer_fc,
        "submod": inner,
        "submod.conv": inner_conv,
        "submod.fc": inner_fc,
    }
    # Look each module up by its string name.
    for module_name, counts in expected.items():
        with self.subTest(name=module_name):
            self.assertEqual(analyzer.unsupported_ops(module_name), counts)
def test_skipped_ops(self) -> None:
    """
    Tests per-module recording of skipped operations.

    With only the addmm handle registered, every other op reached during
    tracing must appear in ``.skipped_ops(name)``. A parent module's counter
    equals its own skipped ops plus those of its children.
    """
    # NOTE(review): this test configures handles via the `op_handles=`
    # constructor argument and reads `.skipped_ops()`, whereas
    # test_unsupported_ops uses `.set_op_handle()` / `.unsupported_ops()`.
    # Confirm both APIs are still supported by JitModelAnalysis.
    model = NestedNet()
    inputs = (torch.randn((1, *model.input_size)),)
    handles = {
        "aten::addmm": addmm_flop_jit,
    }  # type: Dict[str, Handle]
    analyzer = JitModelAnalysis(model=model, inputs=inputs, op_handles=handles)

    # Leaf modules: fc layers are handled, conv layers are not.
    inner_conv = Counter({"aten::_convolution": 1})
    inner_fc = Counter()
    outer_conv = Counter({"aten::_convolution": 1})
    outer_fc = Counter()

    # Container modules aggregate their children's skipped ops.
    inner = Counter({"aten::add": 1, "aten::mul": 1}) + inner_fc + inner_conv
    outer = Counter({"aten::pow": 1}) + outer_conv + outer_fc + inner

    expected = {
        "": outer,
        "conv": outer_conv,
        "fc": outer_fc,
        "submod": inner,
        "submod.conv": inner_conv,
        "submod.fc": inner_fc,
    }
    # Look each module up by its string name.
    for module_name, counts in expected.items():
        with self.subTest(name=module_name):
            self.assertEqual(analyzer.skipped_ops(module_name), counts)
def test_non_forward_func_call(self) -> None:
    """
    Tests that calls to a submodule's non-forward function attribute the
    resulting counts to the calling module rather than to that submodule.

    Also tests that the intermediate submodule, whose ``forward`` is never
    called, is reported by ``.uncalled_modules()``.
    """
    model = NonForwardNet()
    # Derive the input shape from the model (as the other tests do) instead
    # of hard-coding it, so the test stays in sync with NonForwardNet.
    inputs = (torch.randn((1, *model.input_size)),)
    op_handles = {
        "aten::addmm": addmm_flop_jit,
    }  # type: Dict[str, Handle]
    analyzer = JitModelAnalysis(model=model, inputs=inputs, op_handles=op_handles)

    inner_fc_count = model.submod.fc_flops
    total_count = model.fc_flops + inner_fc_count

    # `submod` itself is not credited with the flops of its non-forward call.
    self.assertEqual(analyzer.total("submod"), 0)
    self.assertEqual(analyzer.total("submod.fc"), inner_fc_count)
    self.assertEqual(analyzer.total(""), total_count)

    # The mod not directly called is registered as such
    self.assertEqual(analyzer.uncalled_modules(), {"submod"})
def test_repeated_module(self) -> None:
    """
    Tests that repeated calls to the same submodule correctly aggregate
    results to that submodule.

    Also checks the per-operator breakdown of ``fc1`` and that no module is
    reported as uncalled.
    """
    model = RepeatedNet()
    inputs = (torch.randn((1, *model.input_size)),)
    op_handles = {
        "aten::addmm": addmm_flop_jit,
    }  # type: Dict[str, Handle]
    analyzer = JitModelAnalysis(model=model, inputs=inputs, op_handles=op_handles)

    # Expected counts scale with the number of repeated calls
    # (presumably what fc1_num / fc2_num represent in RepeatedNet).
    fc1_count = model.fc1_num * model.fc1_flops
    fc2_count = model.fc2_num * model.fc2_flops
    total_count = fc1_count + fc2_count
    fc1_per_operator = Counter({"addmm": fc1_count})

    self.assertEqual(analyzer.total("fc1"), fc1_count)
    self.assertEqual(analyzer.total("fc2"), fc2_count)
    self.assertEqual(analyzer.total(""), total_count)
    self.assertEqual(analyzer.by_operator("fc1"), fc1_per_operator)

    # Tests no uncalled mods
    self.assertEqual(analyzer.uncalled_modules(), set())
def test_unused_module(self) -> None:
    """
    Tests that unused modules return a 0 count for operator sums and an
    empty Counter() for per-operator results.

    Also tests that unused modules are reported by ``.uncalled_modules()``,
    but that modules that simply have zero flops (like ReLU) are not.
    """
    model = UnusedNet()
    inputs = (torch.randn((1, *model.input_size)),)
    op_handles = {
        "aten::addmm": addmm_flop_jit,
    }  # type: Dict[str, Handle]
    analyzer = JitModelAnalysis(model=model, inputs=inputs, op_handles=op_handles)

    unused_count = 0
    unused_per_operator = Counter()
    model_count = model.fc1_flops + model.fc2_flops

    self.assertEqual(analyzer.total("unused"), unused_count)
    self.assertEqual(analyzer.by_operator("unused"), unused_per_operator)
    self.assertEqual(analyzer.total(""), model_count)

    # The unused mod is recognized as never called
    self.assertEqual(analyzer.uncalled_modules(), {"unused"})
def test_copy(self) -> None:
    """
    Tests .copy(...).

    Checks that a plain copy reproduces the results and warning settings of
    the original, that changing the copy leaves the original untouched, and
    that a copy with a new model and inputs computes results for the new
    model while keeping the original's settings.
    """
    model = RepeatedNet()
    inputs = (torch.randn((1, *model.input_size)),)
    analyzer = (
        JitModelAnalysis(model=model, inputs=inputs)
        .set_op_handle(
            "aten::addmm",
            addmm_flop_jit,
            "aten::linear",
            linear_flop_jit,
        )
        .unsupported_ops_warnings(enabled=False)
        .tracer_warnings(mode="none")
    )

    repeated_net_flops = model.fc1_num * model.fc1_flops
    repeated_net_flops += model.fc2_num * model.fc2_flops

    analyzer_copy = analyzer.copy()

    # Outputs are the same
    self.assertEqual(
        analyzer.by_module_and_operator(),
        analyzer_copy.by_module_and_operator(),
    )

    # Settings match
    self.assertEqual(
        analyzer._enable_warn_unsupported_ops,
        analyzer_copy._enable_warn_unsupported_ops,
    )
    self.assertEqual(
        analyzer._enable_warn_uncalled_mods,
        analyzer_copy._enable_warn_uncalled_mods,
    )
    self.assertEqual(analyzer._warn_trace, analyzer_copy._warn_trace)

    # Changing copy does not change original
    analyzer_copy.unsupported_ops_warnings(enabled=True)
    self.assertNotEqual(
        analyzer._enable_warn_unsupported_ops,
        analyzer_copy._enable_warn_unsupported_ops,
    )

    # Copy with new model and inputs
    new_model = NonForwardNet()
    bs = 5
    new_inputs = (torch.randn((bs, *new_model.input_size)),)
    analyzer_new = analyzer.copy(new_model=new_model, new_inputs=new_inputs)

    non_forward_flops = new_model.fc_flops + new_model.submod.fc_flops

    # Total is correct for new model and inputs
    self.assertEqual(analyzer_new.total(), non_forward_flops * bs)

    # Original is unaffected
    self.assertEqual(analyzer.total(), repeated_net_flops)

    # Settings match -- check the same settings as for the plain copy above.
    self.assertEqual(
        analyzer._enable_warn_unsupported_ops,
        analyzer_new._enable_warn_unsupported_ops,
    )
    self.assertEqual(
        analyzer._enable_warn_uncalled_mods,
        analyzer_new._enable_warn_uncalled_mods,
    )
    self.assertEqual(analyzer._warn_trace, analyzer_new._warn_trace)
def test_changing_handles(self) -> None:
    """
    Tests .set_op_handle() and .clear_op_handles().

    After counts have been computed once, adding a handle, overriding an
    existing handle, and clearing all handles must each be reflected in the
    next call to ``.by_module_and_operator()``.
    """
    model = NestedNet(lin_op=self.lin_op)
    inputs = (torch.randn((1, *model.input_size)),)
    handles = {
        "aten::addmm": addmm_flop_jit,
        "aten::linear": linear_flop_jit,
    }  # type: Dict[str, Handle]
    analyzer = JitModelAnalysis(model=model, inputs=inputs).set_op_handle(**handles)
    analyzer.unsupported_ops_warnings(enabled=False)

    # Compute once so cached flop counts exist before the handles change.
    _ = analyzer.total("")

    # Adding a handle: convolutions are now counted as well.
    analyzer.set_op_handle("aten::_convolution", conv_flop_jit)
    self.assertEqual(analyzer.by_module_and_operator(), model.flops)

    # Overriding a handle with a stub that always reports a fixed count.
    def build_constant_handle(key: str, value: int) -> Handle:
        def constant_handle(inputs: List[Any], outputs: List[Any]) -> typing.Counter[str]:
            return Counter({key: value})

        return constant_handle

    stub_name = "dummy_op"
    stub_value = 1000
    analyzer.set_op_handle(
        "aten::{}".format(self.lin_op),
        build_constant_handle(stub_name, stub_value),
    )

    # Expected: original counts minus the overridden op, plus the stub's count.
    expected = {
        module_name: Counter(
            {op: flop for op, flop in per_op.items() if op != self.lin_op}
        )
        for module_name, per_op in model.flops.items()
    }
    expected[""][stub_name] = 2 * stub_value
    for module_name in ("fc", "submod", "submod.fc"):
        expected[module_name][stub_name] = stub_value
    self.assertEqual(analyzer.by_module_and_operator(), expected)

    # Clearing every handle leaves an empty counter for each module.
    analyzer.clear_op_handles()
    self.assertEqual(
        analyzer.by_module_and_operator(),
        {module_name: Counter() for module_name in model.flops},
    )
def test_disable_warnings(self) -> None:
    """
    Tests .skipped_ops_warnings(...), .uncalled_modules_warnings(...) and
    .tracer_warnings(...).

    Tracer settings are checked through Python warnings (TracerWarning and
    RuntimeWarning); skipped-op and uncalled-module settings are checked
    through the messages they log.
    """
    model = TraceWarningNet()
    inputs = (torch.randn((1, *model.input_size)),)
    op_handles = {
        "aten::addmm": addmm_flop_jit,
    }  # type: Dict[str, Handle]
    analyzer = JitModelAnalysis(model=model, inputs=inputs, op_handles=op_handles)

    # Tracer warnings
    analyzer.tracer_warnings(mode="all")
    analyzer._stats = None  # Manually clear cache so trace is rerun
    self.assertWarns(torch.jit._trace.TracerWarning, analyzer.total)
    analyzer._stats = None  # Manually clear cache so trace is rerun
    self.assertWarns(RuntimeWarning, analyzer.total)

    analyzer.tracer_warnings(mode="none")
    analyzer._stats = None  # Manually clear cache so trace is rerun
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        _ = analyzer.total()
    # An empty record trivially passes, so no emptiness guard is needed.
    warning_types = [s.category for s in w]
    self.assertNotIn(torch.jit._trace.TracerWarning, warning_types)
    self.assertNotIn(RuntimeWarning, warning_types)

    analyzer.tracer_warnings(mode="no_tracer_warning")
    analyzer._stats = None  # Manually clear cache so trace is rerun
    self.assertWarns(RuntimeWarning, analyzer.total)
    analyzer._stats = None  # Manually clear cache so trace is rerun
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        _ = analyzer.total()
    warning_types = [s.category for s in w]
    self.assertNotIn(torch.jit._trace.TracerWarning, warning_types)

    # Skipped ops and uncalled modules warnings
    logger = logging.getLogger()
    skipped_string = "Skipped operation aten::add 1 time(s)"
    uncalled_string = "Module never called: fc1"

    analyzer.uncalled_modules_warnings(enabled=False)
    analyzer.skipped_ops_warnings(enabled=False)
    analyzer._stats = None  # Manually clear cache so trace is rerun
    with self.assertLogs(logger, logging.WARN) as cm:
        # assertLogs fails if nothing is logged, so emit one known message.
        logger.warning("Dummy warning.")
        _ = analyzer.total()
    self.assertFalse(any(skipped_string in s for s in cm.output))
    self.assertFalse(any(uncalled_string in s for s in cm.output))

    analyzer.skipped_ops_warnings(enabled=True)
    analyzer.uncalled_modules_warnings(enabled=True)
    analyzer._stats = None  # Manually clear cache so trace is rerun
    with self.assertLogs(logger, logging.WARN) as cm:
        _ = analyzer.total()
    self.assertTrue(any(skipped_string in s for s in cm.output))
    self.assertTrue(any(uncalled_string in s for s in cm.output))
def test_shared_module(self) -> None:
    """
    Tests the behavior of shared submodules that may have multiple names.

    Only the first name under which a module is registered becomes its
    canonical name. Counts for every alias are reported under that canonical
    name, and looking a module up by an alias returns the same count.
    """
    model = SharedModuleNet()
    inputs = (torch.randn((1, *model.input_size)),)
    handles = {
        "aten::addmm": addmm_flop_jit,
    }  # type: Dict[str, Handle]
    analyzer = JitModelAnalysis(model=model, inputs=inputs, op_handles=handles)
    analyzer.skipped_ops_warnings(enabled=False)

    # `submod2.submod` and `multiname2` are aliases, so they are absent from
    # the by-module result; their counts are folded into `submod1.submod`
    # and `multiname1`.
    multiname_total = 2 * model.multiname_flops  # Called under 2 names
    shared_total = 2 * model.shared_flops  # Shared under 2 submodules
    expected = {
        "": multiname_total + shared_total,
        "submod1": model.shared_flops,
        "submod1.submod": shared_total,
        "submod2": model.shared_flops,
        "multiname1": multiname_total,
    }
    self.assertEqual(analyzer.by_module(), expected)

    # Looking up by an alias yields the canonical module's count.
    alias_pairs = (
        ("submod2.submod", "submod1.submod"),
        ("multiname2", "multiname1"),
    )
    for alias, canonical in alias_pairs:
        self.assertEqual(analyzer.total(alias), expected[canonical])

    # Every name resolves to the module's first-registered name.
    canonical_of = (
        ("multiname2", "multiname1"),
        ("multiname1", "multiname1"),
        ("submod2.submod", "submod1.submod"),
        ("submod1.submod", "submod1.submod"),
    )
    for queried, canonical in canonical_of:
        self.assertEqual(analyzer.canonical_module_name(queried), canonical)

    # Tests no uncalled modules
    self.assertEqual(analyzer.uncalled_modules(), set())