def test_module_namescope_setting_unchanged(self, flag):
    """Check that running ``dot.to_dot`` leaves the named-call setting alone.

    The test requests ``profiler_name_scopes(enabled=flag)``, traces and runs
    an identity function through ``dot.to_dot``, and then asserts that
    ``module.modules_with_named_call`` equals ``flag``.

    Args:
        flag: Value passed as ``enabled`` and expected to be read back
            afterwards. Presumably supplied by a parameterising decorator
            that is not visible here -- TODO confirm.
    """
    # Snapshot the module-level value so the test cannot leak its setting.
    saved_setting = module.modules_with_named_call
    try:
        module.profiler_name_scopes(enabled=flag)

        # Only the side effects of tracing matter; the result is discarded.
        traced_identity = dot.to_dot(lambda x: x)
        traced_identity(jnp.ones((1, 1)))

        self.assertEqual(module.modules_with_named_call, flag)
    finally:
        # Restore the snapshot even when the assertion above fails.
        module.profiler_name_scopes(enabled=saved_setting)
def test_no_namescopes_inside_dot(self):
    """Check that ``stateful.named_call`` is not invoked under ``dot.to_dot``.

    With ``profiler_name_scopes(enabled=True)`` requested, ``named_call`` on
    ``stateful`` is replaced by a mock while an ``AddModule`` instance is
    traced and run via ``dot.to_dot``; the mock must record no calls.
    """
    add_module = AddModule()
    # Snapshot the module-level value so the test cannot leak its setting.
    saved_setting = module.modules_with_named_call
    try:
        module.profiler_name_scopes(enabled=True)

        named_call_patch = mock.patch.object(stateful, "named_call")
        with named_call_patch as named_call_mock:
            # Only the side effects of tracing matter; the result is discarded.
            traced_module = dot.to_dot(add_module)
            traced_module(1, 1)
            named_call_mock.assert_not_called()
    finally:
        # Restore the snapshot even when the assertion above fails.
        module.profiler_name_scopes(enabled=saved_setting)