Beispiel #1
0
 def test_module_namescope_setting_unchanged(self, flag):
     """Checks the named-call setting still equals `flag` after `dot.to_dot` runs.

     Args:
       flag: Value passed to `module.profiler_name_scopes` before rendering.
     """
     saved_setting = module.modules_with_named_call
     try:
         module.profiler_name_scopes(enabled=flag)
         render = dot.to_dot(lambda x: x)
         render(jnp.ones((1, 1)))
         # Running the rendered function should leave the setting at `flag`.
         self.assertEqual(module.modules_with_named_call, flag)
     finally:
         # Put back whatever was active before this test started.
         module.profiler_name_scopes(enabled=saved_setting)
Beispiel #2
0
 def test_no_namescopes_inside_dot(self):
   """Checks that rendering a module with `dot.to_dot` never calls `stateful.named_call`.

   Profiler name scopes are enabled for the duration of the check. The test
   therefore asserts that `named_call` is not invoked while `to_dot` renders,
   even with scopes switched on.
   """
   add_mod = AddModule()
   previous = module.modules_with_named_call
   try:
     module.profiler_name_scopes(enabled=True)
     with mock.patch.object(stateful, "named_call") as named_call_mock:
       dot.to_dot(add_mod)(1, 1)
       named_call_mock.assert_not_called()
   finally:
     # Restore the setting that was active before the test.
     module.profiler_name_scopes(enabled=previous)