Exemple #1
0
    def test_inherits_default_from_main_thread(self, default):
        """Worker threads start from the main thread's value but write locally.

        `f` and `g` run concurrently in a thread pool. Both first observe the
        value set here on the main thread, then overwrite it. The events force
        the order f-writes -> g-reads/writes -> f-reads, so the test checks
        that neither worker sees the other's write. It also checks that the
        main thread still sees `default` afterwards.
        """
        # Bound every wait. If a worker fails an assertion before setting its
        # event, the other worker then fails instead of blocking forever and
        # hanging the whole test run.
        timeout_secs = 10

        e1 = threading.Event()
        e2 = threading.Event()

        config.get_config().check_jax_usage = default

        def f():
            self.assertEqual(config.get_config().check_jax_usage, default)
            config.get_config().check_jax_usage = True
            e1.set()
            self.assertTrue(e2.wait(timeout=timeout_secs),
                            "timed out waiting for g")
            self.assertTrue(config.get_config().check_jax_usage)

        def g():
            self.assertTrue(e1.wait(timeout=timeout_secs),
                            "timed out waiting for f")
            self.assertEqual(config.get_config().check_jax_usage, default)
            config.get_config().check_jax_usage = False
            e2.set()
            self.assertFalse(config.get_config().check_jax_usage)

        with futures.ThreadPoolExecutor() as tpe:
            f1 = tpe.submit(g)
            f2 = tpe.submit(f)
            f2.result()
            f1.result()

        self.assertEqual(config.get_config().check_jax_usage, default)
Exemple #2
0
 def test_set(self):
     """Checks that `config.set` is visible on an already-fetched config."""
     live_cfg = config.get_config()
     live_cfg.check_jax_usage = False

     # Turn the flag on through the module-level setter...
     config.set(check_jax_usage=True)
     self.assertTrue(live_cfg.check_jax_usage)

     # ...and back off again; the same object must reflect both changes.
     config.set(check_jax_usage=False)
     self.assertFalse(live_cfg.check_jax_usage)
Exemple #3
0
  def __call__(cls: Type[T], *args, **kwargs) -> T:  # pylint: disable=no-self-argument
    """Allocates a module, runs its wrapped `__init__`, then validates it.

    The instance comes from `__new__` first, so a reference to the
    un-initialized object exists even if `__init__` raises. Per the design
    note below, this is so the name_scope opened in `__init__` can be closed
    on error. That closing is not done in this method itself; presumably it
    happens inside `wrap_method` (TODO confirm).

    Raises:
      ValueError: If `__init__` never ran the `hk.Module` super constructor.
        This is detected by the absence of a `module_name` attribute.
    """
    # NOTE: pytype is disabled below because (somewhat surprisingly) this
    # method is bound with the new class and not the metaclass.
    instance = cls.__new__(cls, *args, **kwargs)  # pytype: disable=wrong-arg-types

    wrapped_init = wrap_method("__init__", cls.__init__)
    wrapped_init(instance, *args, **kwargs)

    # A falsy `AUTO_REPR` attribute opts out of the argument-based repr even
    # when the global `module_auto_repr` option is enabled.
    wants_auto_repr = (config.get_config().module_auto_repr
                       and getattr(instance, "AUTO_REPR", True))
    instance._auto_repr = (  # pylint: disable=protected-access
        utils.auto_repr(cls, *args, **kwargs) if wants_auto_repr
        else object.__repr__(instance))

    if hasattr(instance, "module_name"):
      return instance

    raise ValueError(
        "Constructing an hk.Module without calling the super constructor "
        "is not supported. Add the following as the first line in your "
        "__init__ method:\n\nsuper(%s, self).__init__()" % cls.__name__)
Exemple #4
0
 def test_check_jax_usage(self):
     """`config.check_jax_usage(...)` toggles the flag on the live config."""
     live = config.get_config()
     # Each case is (positional args for check_jax_usage, expected flag).
     # Calling it with no arguments is expected to enable the check.
     cases = (((), True), ((False,), False), ((True,), True))
     for call_args, enabled in cases:
         config.check_jax_usage(*call_args)
         if enabled:
             self.assertTrue(live.check_jax_usage)
         else:
             self.assertFalse(live.check_jax_usage)
Exemple #5
0
 def test_assign_with_error(self):
     """`config.assign` restores the previous value when its body raises.

     The exception must also propagate out of the context manager. Using
     `assertRaisesRegex` instead of `except ValueError: pass` makes the test
     fail if `config.assign` swallows the error.
     """
     cfg = config.get_config()
     cfg.check_jax_usage = False
     with self.assertRaisesRegex(ValueError, "expected"):
         with config.assign(check_jax_usage=True):
             self.assertTrue(cfg.check_jax_usage)
             # Raise an exception to test that config is reset on error.
             raise ValueError("expected")
     self.assertFalse(cfg.check_jax_usage)
Exemple #6
0
    def test_with_config(self):
        """`@config.with_config` applies overrides only while `f` runs."""
        called = False

        @config.with_config(check_jax_usage=False)
        def f():
            nonlocal called
            called = True
            return config.get_config().check_jax_usage

        outer_cfg = config.get_config()
        outer_cfg.check_jax_usage = True
        result = f()
        self.assertFalse(result)
        self.assertTrue(called)
        # The override must be undone once f has returned.
        self.assertTrue(outer_cfg.check_jax_usage)
Exemple #7
0
  def wrapped(self, *args, **kwargs):
    """Runs the wrapped module method inside Haiku's per-call module context.

    `unbound_method` and `method_name` are free variables from an enclosing
    scope that is not visible here. The method is called on `self` under
    `frame.module(...)` and `_module_method_call(...)`, routed through
    `run_interceptors`, and depending on config wrapped in `named_call`.
    Afterwards, this module's `module_name` is added to the `_submodules` of
    every *other* module on the current frame's module stack.

    Args:
      self: The module instance the method is bound to.
      *args: Positional arguments forwarded to the method.
      **kwargs: Keyword arguments forwarded to the method.

    Returns:
      Whatever the (possibly intercepted) method call returns.

    Raises:
      ValueError: If there is no active Haiku frame, i.e. the module is used
        outside of an `hk.transform`.
    """
    if not base.frame_stack:
      raise ValueError(
          "All `hk.Module`s must be initialized inside an `hk.transform`.")

    # Submodules are associated with this method. We allow users to associate
    # submodules with a different method than the one being called via
    # `@name_like("other_method")`. Interceptors and custom getters are still
    # provided the actual method name (e.g. "submodule_method_name" is only used
    # for naming submodules).
    submodule_method_name = getattr(unbound_method, _CUSTOM_NAME, method_name)

    frame = base.current_frame()
    state = base.ModuleState(module=self, method_name=submodule_method_name)
    # Both context managers stay active for the method call and for the
    # submodule registration below.
    with frame.module(state), _module_method_call(self, method_name):
      # hk.Module enters the module name scope for all methods.
      # module_name is None here while __init__ is still setting it up.
      module_name = getattr(self, "module_name", None)
      f = functools.partial(unbound_method, self)
      # Interceptors see the real method_name, not submodule_method_name.
      f = functools.partial(run_interceptors, f, method_name, self)
      if jax.config.jax_experimental_name_stack and module_name:
        # Label the call with the last path component of the module name.
        # For methods other than __call__, add an outer label with the method
        # name.
        local_module_name = module_name.split("/")[-1]
        f = jax.named_call(f, name=local_module_name)
        if method_name != "__call__":
          f = jax.named_call(f, name=method_name)
      elif module_name:
        # TODO(lenamartens): remove this branch once jax_experimental_name_stack
        # flag is removed.
        # Legacy path: only __call__ is labelled, and only when
        # profiler_name_scopes is enabled in the Haiku config.
        cfg = config.get_config()
        if cfg.profiler_name_scopes and method_name == "__call__":
          local_module_name = module_name.split("/")[-1]
          f = stateful.named_call(f, name=local_module_name)

      out = f(*args, **kwargs)

      # Module names are set in the constructor. If `f` is the constructor then
      # its name will only be set **after** `f` has run. For methods other
      # than `__init__` we need the name before running in order to wrap their
      # execution with `named_call`.
      if module_name is None:
        module_name = getattr(self, "module_name", None)

      # Notify parent modules about our existence.
      # Every enclosing module on the stack is notified, not just the direct
      # parent; `self` is skipped.
      if module_name is not None:
        for module_state in frame.module_stack:
          if module_state.module is not self:
            module_state.module._submodules.add(module_name)  # pylint: disable=protected-access
    return out
Exemple #8
0
def assert_jax_usage(public_symbol_name: str) -> None:
  """Raises if a Haiku API is used under a JAX transform Haiku did not set up.

  This is a no-op unless `config.get_config().check_jax_usage` is truthy.
  When enabled, it compares the current JAX trace level against the level
  on top of the current Haiku frame's `jax_trace_stack`.

  Args:
    public_symbol_name: Name of the public Haiku API being checked, without
      the `hk.` prefix. It is interpolated into the error message as
      `hk.<public_symbol_name>`.

  Raises:
    JaxUsageError: If the current trace level differs from the expected one.
  """
  if not config.get_config().check_jax_usage:
    return

  expected_level = current_frame().jax_trace_stack.peek()
  trace_level = JaxTraceLevel.current()
  if trace_level != expected_level:
    # One message built from adjacent (implicitly concatenated) literals.
    # There must be no commas between them; a comma would split the text
    # into several exception args and `str(err)` would show a tuple.
    raise JaxUsageError(
        "tl;dr - You need to use a Haiku overloaded transform (e.g. `hk.vmap`) "
        "or control flow operator (e.g. `hk.cond`) instead of the `jax.*` "
        "equivalent for untransformed functions using Haiku APIs."
        "\n\n"
        "Some APIs in JAX (e.g. `jit`, `vmap`, `cond`, `switch`) take "
        f"functions that are expected to be pure. `hk.{public_symbol_name}` "
        "has a side effect, and using it inside a function (without also using "
        "`hk.transform`) makes that function 'impure' (the function has a side "
        "effect)."
        "\n\n"
        "Haiku includes drop-in replacements for these JAX APIs (e.g. "
        "`hk.vmap`) that carefully turn your function into a pure function and "
        "then call the underlying JAX function.")
Exemple #9
0
 def test_attr_disable_auto_repr(self):
     """A module that opts out of auto repr keeps the default object repr."""
     # The global option is on, so the default repr must come from the
     # module's own opt-out (presumably `AUTO_REPR = False` on the class).
     self.assertTrue(config.get_config().module_auto_repr)
     rendered = str(NoAutoReprModule())
     self.assertRegex(rendered, "<.*.NoAutoReprModule object at .*>")
Exemple #10
0
 def test_assign(self):
     """`config.assign` overrides a value only inside its `with` block."""
     live = config.get_config()
     live.check_jax_usage = False
     with config.assign(check_jax_usage=True):
         self.assertTrue(live.check_jax_usage)
     # Leaving the context must restore the previous value.
     self.assertFalse(live.check_jax_usage)
Exemple #11
0
 def f():
     """Records that it ran and returns the active `check_jax_usage` value."""
     ran_f[0] = True
     active_cfg = config.get_config()
     return active_cfg.check_jax_usage
Exemple #12
0
 def g():
     """Second worker: waits on `e1`, checks the default, then sets False.

     Signals `e2` after its write, then checks that the flag it set is still
     False.
     """
     # Bounded wait: if the other worker fails before setting e1, fail here
     # instead of blocking forever and hanging the test.
     self.assertTrue(e1.wait(timeout=10), "timed out waiting for e1")
     self.assertEqual(config.get_config().check_jax_usage, default)
     config.get_config().check_jax_usage = False
     e2.set()
     self.assertFalse(config.get_config().check_jax_usage)
Exemple #13
0
 def f():
     """First worker: checks the default, sets True, then waits on `e2`.

     After `e2` is signalled, it checks that the flag it set is still True.
     """
     self.assertEqual(config.get_config().check_jax_usage, default)
     config.get_config().check_jax_usage = True
     e1.set()
     # Bounded wait: if the other worker fails before setting e2, fail here
     # instead of blocking forever and hanging the test.
     self.assertTrue(e2.wait(timeout=10), "timed out waiting for e2")
     self.assertTrue(config.get_config().check_jax_usage)
Exemple #14
0
 def __new__(cls, data):
   """Converts `data` with a helper chosen by the `restore_flatmap` option.

   Returns `to_immutable_dict(data)` when `restore_flatmap` is enabled in the
   global config, and `to_haiku_dict(data)` otherwise. `cls` is not used.
   """
   convert = (to_immutable_dict if config.get_config().restore_flatmap
              else to_haiku_dict)
   return convert(data)