def test_inherits_default_from_main_thread(self, default):
  """Checks worker threads start from `default` and keep their own changes.

  The main thread sets `check_jax_usage` to `default`. Two pool workers are
  interleaved with events:

  * `f` observes `default`, sets the flag to True and signals `e1`.
  * `g` waits for `e1`, still observes `default`, sets the flag to False and
    signals `e2`.
  * `f` waits for `e2` and still observes True.

  Afterwards the main thread must still observe `default`.

  Args:
    default: The value the main thread assigns before spawning the workers.
  """
  e1 = threading.Event()  # Signalled by `f` after setting its flag to True.
  e2 = threading.Event()  # Signalled by `g` after setting its flag to False.
  config.get_config().check_jax_usage = default

  def f():
    # Signal `e1` even if an assertion fails. Otherwise `g` would block
    # forever in `e1.wait()` and the test would hang instead of failing.
    try:
      self.assertEqual(config.get_config().check_jax_usage, default)
      config.get_config().check_jax_usage = True
    finally:
      e1.set()
    e2.wait()
    self.assertTrue(config.get_config().check_jax_usage)

  def g():
    e1.wait()
    # Signal `e2` even if an assertion fails, for the same reason as above.
    try:
      self.assertEqual(config.get_config().check_jax_usage, default)
      config.get_config().check_jax_usage = False
    finally:
      e2.set()
    self.assertFalse(config.get_config().check_jax_usage)

  with futures.ThreadPoolExecutor() as tpe:
    f1 = tpe.submit(g)
    f2 = tpe.submit(f)
    # `result()` re-raises any assertion error from the worker thread.
    f2.result()
    f1.result()

  self.assertEqual(config.get_config().check_jax_usage, default)
def test_set(self):
  """`config.set` updates the object previously returned by `get_config`."""
  live_config = config.get_config()
  live_config.check_jax_usage = False

  # Enable via `config.set` and observe it through the held reference.
  config.set(check_jax_usage=True)
  self.assertTrue(live_config.check_jax_usage)

  # Disable again the same way.
  config.set(check_jax_usage=False)
  self.assertFalse(live_config.check_jax_usage)
def __call__(cls: Type[T], *args, **kwargs) -> T:  # pylint: disable=no-self-argument
  """Constructs a module: `__new__`, wrapped `__init__`, repr, sanity check.

  Raises:
    ValueError: If after `__init__` the instance has no `module_name`
      attribute, i.e. the subclass did not call the super constructor.
  """
  # `__new__` is called on its own first so an un-initialized instance exists
  # and can still be referenced if `__init__` raises. This is what lets the
  # name_scope opened during `__init__` be closed even on an exception.
  # NOTE: pytype is disabled since (somewhat surprisingly) this method is
  # bound with the new class and not the metaclass.
  module = cls.__new__(cls, *args, **kwargs)  # pytype: disable=wrong-arg-types

  # Run the constructor through the same method wrapper used for methods.
  wrap_method("__init__", cls.__init__)(module, *args, **kwargs)

  # Use the auto-generated repr only if it is enabled globally and the class
  # has not opted out via `AUTO_REPR = False`.
  wants_auto_repr = (config.get_config().module_auto_repr
                     and getattr(module, "AUTO_REPR", True))
  module._auto_repr = (  # pylint: disable=protected-access
      utils.auto_repr(cls, *args, **kwargs) if wants_auto_repr
      else object.__repr__(module))

  # `module_name` is only assigned by the base-class constructor.
  if not hasattr(module, "module_name"):
    raise ValueError(
        "Constructing an hk.Module without calling the super constructor "
        "is not supported. Add the following as the first line in your "
        "__init__ method:\n\nsuper(%s, self).__init__()" % cls.__name__)
  return module
def test_check_jax_usage(self):
  """`config.check_jax_usage` toggles the flag; called bare it enables it."""
  live_config = config.get_config()

  # No argument: the test expects checking to be switched on.
  config.check_jax_usage()
  self.assertTrue(live_config.check_jax_usage)

  # Explicit values are applied as given.
  config.check_jax_usage(False)
  self.assertFalse(live_config.check_jax_usage)

  config.check_jax_usage(True)
  self.assertTrue(live_config.check_jax_usage)
def test_assign_with_error(self):
  """`config.assign` restores the previous value when its body raises."""
  cfg = config.get_config()
  cfg.check_jax_usage = False
  # Pin the exception to the one raised below. A bare `except ValueError`
  # would also swallow a ValueError raised by `config.assign` itself, which
  # would make this test pass without ever exercising the restore path.
  with self.assertRaisesRegex(ValueError, "expected"):
    with config.assign(check_jax_usage=True):
      self.assertTrue(cfg.check_jax_usage)
      # Raise an exception to test that config is reset on error.
      raise ValueError("expected")
  self.assertFalse(cfg.check_jax_usage)
def test_with_config(self):
  """`@config.with_config` applies overrides only while the function runs."""
  called = False

  @config.with_config(check_jax_usage=False)
  def f():
    nonlocal called
    called = True
    return config.get_config().check_jax_usage

  live_config = config.get_config()
  live_config.check_jax_usage = True

  # Inside the decorated call the override is visible...
  self.assertFalse(f())
  self.assertTrue(called)
  # ...and afterwards the outer value is back in effect.
  self.assertTrue(live_config.check_jax_usage)
def wrapped(self, *args, **kwargs):
  """Calls the original method with a group name set before and after.

  Pushes a `ModuleState` for this module onto the current frame and runs the
  method inside `_module_method_call`. The call goes through the interceptor
  chain and, depending on config, through `named_call` wrappers. Afterwards
  the module's name is registered as a submodule of every other module
  currently on the frame's module stack.

  Raises:
    ValueError: If called with no active Haiku frame (outside a transform).
  """
  # NOTE(review): `unbound_method` and `method_name` are free variables bound
  # by an enclosing scope not shown here (presumably `wrap_method`).
  if not base.frame_stack:
    raise ValueError(
        "All `hk.Module`s must be initialized inside an `hk.transform`.")

  # Submodules are associated with this method. We allow users to associate
  # submodules with a different method than the one being called via
  # `@name_like("other_method")`. Interceptors and custom getters are still
  # provided the actual method name (e.g. "submodule_method_name" is only used
  # for naming submodules).
  submodule_method_name = getattr(unbound_method, _CUSTOM_NAME, method_name)
  frame = base.current_frame()
  state = base.ModuleState(module=self, method_name=submodule_method_name)
  with frame.module(state), _module_method_call(self, method_name):
    # hk.Module enters the module name scope for all methods.
    # `module_name` is None while the constructor itself is running (it is
    # assigned by the base-class constructor), so both naming branches below
    # are skipped for `__init__`.
    module_name = getattr(self, "module_name", None)
    f = functools.partial(unbound_method, self)
    # Route the call through any registered interceptors with the real name.
    f = functools.partial(run_interceptors, f, method_name, self)
    if jax.config.jax_experimental_name_stack and module_name:
      # Name-stack path: name the call after the last "/"-separated component
      # of the module name, and additionally after the method itself for
      # anything other than `__call__`.
      local_module_name = module_name.split("/")[-1]
      f = jax.named_call(f, name=local_module_name)
      if method_name != "__call__":
        f = jax.named_call(f, name=method_name)
    elif module_name:
      # TODO(lenamartens): remove this branch once jax_experimental_name_stack
      # flag is removed.
      cfg = config.get_config()
      # Legacy path: only `__call__` is named, and only when profiler name
      # scopes are enabled in the config.
      if cfg.profiler_name_scopes and method_name == "__call__":
        local_module_name = module_name.split("/")[-1]
        f = stateful.named_call(f, name=local_module_name)

    out = f(*args, **kwargs)

    # Module names are set in the constructor. If `f` is the constructor then
    # its name will only be set **after** `f` has run. For methods other
    # than `__init__` we need the name before running in order to wrap their
    # execution with `named_call`.
    if module_name is None:
      module_name = getattr(self, "module_name", None)

    # Notify parent modules about our existence.
    if module_name is not None:
      for module_state in frame.module_stack:
        if module_state.module is not self:
          module_state.module._submodules.add(module_name)  # pylint: disable=protected-access
  return out
def assert_jax_usage(public_symbol_name: str):
  """Raises `JaxUsageError` if a Haiku API runs at an unexpected JAX trace.

  Does nothing unless `check_jax_usage` is enabled in the current config.
  Otherwise compares `JaxTraceLevel.current()` with the level returned by
  `current_frame().jax_trace_stack.peek()`.

  Args:
    public_symbol_name: Name of the public Haiku symbol being used, without
      the `hk.` prefix. It is interpolated into the error message.

  Raises:
    JaxUsageError: If the current trace level differs from the expected one.
  """
  if not config.get_config().check_jax_usage:
    return

  expected_level = current_frame().jax_trace_stack.peek()
  trace_level = JaxTraceLevel.current()
  if trace_level != expected_level:
    # The message must be ONE implicitly-concatenated string literal. A comma
    # between the pieces would pass two separate args to the exception, and
    # it would render as a tuple.
    raise JaxUsageError(
        "tl;dr - You need to use a Haiku overloaded transform (e.g. `hk.vmap`) "
        "or control flow operator (e.g. `hk.cond`) instead of the `jax.*` "
        "equivalent for untransformed functions using Haiku APIs."
        "\n\n"
        "Some APIs in JAX (e.g. `jit`, `vmap`, `cond`, `switch`) take "
        f"functions that are expected to be pure. `hk.{public_symbol_name}` "
        "has a side effect, and using it inside a function (without also using "
        "`hk.transform`) makes that function 'impure' (the function has a side "
        "effect)."
        "\n\n"
        "Haiku includes drop-in replacements for these JAX APIs (e.g. "
        "`hk.vmap`) that carefully turn your function into a pure function and "
        "then call the underlying JAX function.")
def test_attr_disable_auto_repr(self):
  """A module class opting out of auto-repr falls back to `object.__repr__`."""
  # Auto-repr is globally enabled, so the fallback must come from the class.
  self.assertTrue(config.get_config().module_auto_repr)
  rendered = str(NoAutoReprModule())
  self.assertRegex(rendered, "<.*.NoAutoReprModule object at .*>")
def test_assign(self):
  """`config.assign` overrides a value only for the duration of the `with`."""
  live_config = config.get_config()
  live_config.check_jax_usage = False

  with config.assign(check_jax_usage=True):
    # The override is visible through the previously fetched config object.
    self.assertTrue(live_config.check_jax_usage)

  # Leaving the context restores the prior value.
  self.assertFalse(live_config.check_jax_usage)
def f():
  """Records that it ran, then returns the active `check_jax_usage` value.

  NOTE(review): relies on `ran_f` from an enclosing test scope not shown here.
  """
  ran_f[0] = True
  active_config = config.get_config()
  return active_config.check_jax_usage
def g():
  """Worker that waits for `e1`, then flips its `check_jax_usage` to False.

  NOTE(review): relies on `self`, `default`, `e1` and `e2` from an enclosing
  test scope not shown here.
  """
  e1.wait()
  # Signal `e2` even if an assertion fails. Otherwise a worker blocked in
  # `e2.wait()` would hang forever instead of letting the failure surface.
  try:
    self.assertEqual(config.get_config().check_jax_usage, default)
    config.get_config().check_jax_usage = False
  finally:
    e2.set()
  self.assertFalse(config.get_config().check_jax_usage)
def f():
  """Worker that flips its `check_jax_usage` to True, then waits for `e2`.

  NOTE(review): relies on `self`, `default`, `e1` and `e2` from an enclosing
  test scope not shown here.
  """
  # Signal `e1` even if an assertion fails. Otherwise a worker blocked in
  # `e1.wait()` would hang forever instead of letting the failure surface.
  try:
    self.assertEqual(config.get_config().check_jax_usage, default)
    config.get_config().check_jax_usage = True
  finally:
    e1.set()
  e2.wait()
  self.assertTrue(config.get_config().check_jax_usage)
def __new__(cls, data):
  """Converts `data` with a converter chosen by the `restore_flatmap` config.

  Uses `to_immutable_dict` when `restore_flatmap` is set, otherwise
  `to_haiku_dict`; the converter's result is returned instead of a new `cls`.
  """
  convert = (to_immutable_dict if config.get_config().restore_flatmap
             else to_haiku_dict)
  return convert(data)