def test_disable_warnings(self) -> None:
    """
    Tests .skipped_ops_warnings(...), .uncalled_modules_warnings(...) and
    .tracer_warnings(...).

    Before each check the analyzer's cached stats are cleared so that
    .total() re-runs the trace and any warnings / log records are emitted
    again.
    """
    model = TraceWarningNet()
    inputs = (torch.randn((1, *model.input_size)),)
    op_handles = {
        "aten::addmm": addmm_flop_jit,
    }  # type: Dict[str, Handle]
    analyzer = JitModelAnalysis(model=model, inputs=inputs, op_handles=op_handles)

    def clear_cache() -> None:
        # Manually clear the cached stats so the next .total() reruns the trace.
        analyzer._stats = None

    def rerun_and_record_warnings() -> list:
        # Rerun the trace while recording every warning ("always" filter, so
        # nothing is suppressed by earlier emissions) and return the warning
        # categories that were raised.
        clear_cache()
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            analyzer.total()
        return [record.category for record in caught]

    # Tracer warnings
    # mode="all": both TracerWarning and RuntimeWarning are expected.
    analyzer.tracer_warnings(mode="all")
    clear_cache()
    self.assertWarns(torch.jit._trace.TracerWarning, analyzer.total)
    clear_cache()
    self.assertWarns(RuntimeWarning, analyzer.total)

    # mode="none": neither warning category may escape.
    analyzer.tracer_warnings(mode="none")
    categories = rerun_and_record_warnings()
    self.assertNotIn(torch.jit._trace.TracerWarning, categories)
    self.assertNotIn(RuntimeWarning, categories)

    # mode="no_tracer_warning": RuntimeWarning still surfaces, TracerWarning
    # does not.
    analyzer.tracer_warnings(mode="no_tracer_warning")
    clear_cache()
    self.assertWarns(RuntimeWarning, analyzer.total)
    categories = rerun_and_record_warnings()
    self.assertNotIn(torch.jit._trace.TracerWarning, categories)

    # Skipped ops and uncalled modules warnings
    logger = logging.getLogger()
    skipped_string = "Skipped operation aten::add 1 time(s)"
    uncalled_string = "Module never called: fc1"

    # Both log warnings disabled: neither message may appear.
    analyzer.uncalled_modules_warnings(enabled=False)
    analyzer.skipped_ops_warnings(enabled=False)
    clear_cache()
    with self.assertLogs(logger, logging.WARNING) as cm:
        # assertLogs fails when nothing is logged, so emit a dummy record to
        # make the "analyzer logged nothing" case checkable.
        logger.warning("Dummy warning.")
        analyzer.total()
    self.assertFalse(any(skipped_string in s for s in cm.output))
    self.assertFalse(any(uncalled_string in s for s in cm.output))

    # Both log warnings enabled: both messages must appear.
    analyzer.skipped_ops_warnings(enabled=True)
    analyzer.uncalled_modules_warnings(enabled=True)
    clear_cache()
    with self.assertLogs(logger, logging.WARNING) as cm:
        analyzer.total()
    self.assertTrue(any(skipped_string in s for s in cm.output))
    self.assertTrue(any(uncalled_string in s for s in cm.output))
def test_copy(self) -> None:
    """
    Tests .copy(...).

    A copy must reproduce the original's outputs and warning settings and
    must be independent of the original. A copy re-targeted at a new model
    and new inputs must compute the new model's total while keeping the same
    settings.
    """
    model = RepeatedNet()
    inputs = (torch.randn((1, *model.input_size)),)
    op_handles = {
        "aten::addmm": addmm_flop_jit,
    }  # type: Dict[str, Handle]
    analyzer = JitModelAnalysis(model=model, inputs=inputs, op_handles=op_handles)
    analyzer.skipped_ops_warnings(enabled=False)
    analyzer.tracer_warnings(mode="none")

    repeated_net_flops = model.fc1_num * model.fc1_flops
    repeated_net_flops += model.fc2_num * model.fc2_flops

    def assert_settings_match(first, second) -> None:
        # Every warning-related setting must be carried over by .copy(...).
        self.assertEqual(
            first._enable_warn_skipped_ops, second._enable_warn_skipped_ops
        )
        self.assertEqual(
            first._enable_warn_uncalled_mods, second._enable_warn_uncalled_mods
        )
        self.assertEqual(first._warn_trace, second._warn_trace)

    analyzer_copy = analyzer.copy()

    # Outputs are the same
    self.assertEqual(
        analyzer.by_module_and_operator(),
        analyzer_copy.by_module_and_operator(),
    )
    # Settings match
    assert_settings_match(analyzer, analyzer_copy)

    # Changing copy does not change original
    analyzer_copy.skipped_ops_warnings(enabled=True)
    self.assertNotEqual(
        analyzer._enable_warn_skipped_ops, analyzer_copy._enable_warn_skipped_ops
    )

    # Copy with new model and inputs
    new_model = NonForwardNet()
    bs = 5
    new_inputs = (torch.randn((bs, *new_model.input_size)),)
    analyzer_new = analyzer.copy(new_model=new_model, new_inputs=new_inputs)

    non_forward_flops = new_model.fc_flops + new_model.submod.fc_flops

    # Total is correct for new model and inputs
    self.assertEqual(analyzer_new.total(), non_forward_flops * bs)
    # Original is unaffected
    self.assertEqual(analyzer.total(), repeated_net_flops)
    # Settings match: the same three settings as for the plain copy above.
    assert_settings_match(analyzer, analyzer_new)