def test_check_jax_usage(self):
    """Toggling ``check_jax_usage`` is observed on the object from ``get_config``.

    The config object is fetched once up front; every later assertion reads
    its ``check_jax_usage`` attribute after a call to
    ``config.check_jax_usage``.
    """
    cfg = config.get_config()

    # The test expects a bare call (no argument) to leave the check enabled.
    config.check_jax_usage()
    self.assertTrue(cfg.check_jax_usage)

    # Disable, then re-enable, checking the attribute after each switch.
    for enabled in (False, True):
        config.check_jax_usage(enabled)
        if enabled:
            self.assertTrue(cfg.check_jax_usage)
        else:
            self.assertFalse(cfg.check_jax_usage)
def setUp(self):
    """Turn on JAX-usage checking for the test and remember what it replaced."""
    super().setUp()
    # Keep the value returned by check_jax_usage(True) so it can be handed
    # back to check_jax_usage once the test is over.
    previous_setting = config.check_jax_usage(True)
    self._prev_check_jax_usage = previous_setting
def tearDown(self):
    """Restore the ``check_jax_usage`` setting saved in ``setUp``.

    The test's own global state is undone first (mirroring ``setUp``, which
    initializes the parent before changing the setting), and the parent
    teardown is run in ``finally`` so it always executes. Doing the restore
    first also guarantees the global flag is reset even if the parent
    ``tearDown`` raises, so the setting cannot leak into later tests.
    """
    try:
        config.check_jax_usage(self._prev_check_jax_usage)
    finally:
        super().tearDown()
def wrapper(*args, **kwargs):
    """Call ``f`` with JAX-usage checking forced on, then restore the old value.

    The value returned by ``config.check_jax_usage(True)`` is handed back to
    ``config.check_jax_usage`` afterwards, whether ``f`` returns or raises.
    """
    previous_setting = config.check_jax_usage(True)
    try:
        result = f(*args, **kwargs)
    finally:
        # Runs on both normal return and exception, so the setting never leaks.
        config.check_jax_usage(previous_setting)
    return result