Exemplo n.º 1
0
  def wrapper(*args, **kwargs):
    """Calls ``fun`` after reserving exactly the RNG keys it will request.

    ``fun`` is first evaluated abstractly with ``jax.eval_shape`` inside a new
    context built from a snapshot of the current frame's params, state and
    RNG, while ``count_hk_rngs_requested`` counts the RNG keys requested. If
    any were requested, that many keys are reserved on the current RNG
    sequence before ``fun`` is called for real.
    """
    base.assert_context("optimize_rng_use")

    # Snapshot the current frame; empty (falsy) params/state become None.
    current = base.current_frame()
    params = (data_structures.to_haiku_dict(current.params)
              if current.params else None)
    state = (base.extract_state(current.state, initial=True)
             if current.state else None)
    rng_seq = current.rng_stack.peek()
    rng = None if rng_seq is None else rng_seq.internal_state

    def pure_fun(params, state, rng, *args, **kwargs):
      # Re-enter a fresh context so the abstract trace sees the snapshot.
      with base.new_context(params=params, state=state, rng=rng):
        return fun(*args, **kwargs)

    with count_hk_rngs_requested() as read_count:
      jax.eval_shape(pure_fun, params, state, rng, *args, **kwargs)
    requested = read_count()

    if not requested:
      return fun(*args, **kwargs)
    base.current_frame().rng_stack.peek().reserve(requested)
    return fun(*args, **kwargs)
Exemplo n.º 2
0
    def wrapped(self, *args, **kwargs):
        """Calls the original method with a group name set before and after.

        Runs ``unbound_method`` on ``self`` inside the module's name scope,
        routed through ``run_interceptors`` (and ``stateful_named_call`` for
        ``__call__`` when ``modules_with_named_call`` is set). Afterwards every
        other module on the frame's module stack records this module's name
        in its ``_submodules`` set.

        Raises:
          ValueError: if there is no active frame stack, i.e. the call happens
            outside of an ``hk.transform``.
        """
        if not base.frame_stack:
            raise ValueError(
                "All `hk.Module`s must be initialized inside an `hk.transform`."
            )

        frame = base.current_frame()
        state = base.ModuleState(module=self, method_name=method_name)
        with frame.module(state), _module_method_call(self, method_name):
            # hk.Module enters the module name scope for all methods.
            module_name = getattr(self, "module_name", None)
            f = functools.partial(unbound_method, self)
            f = functools.partial(run_interceptors, f, method_name, self)
            # TODO(tomhennigan): With omnistaging primitives (like named call) will
            # stage out return values eagerly. For functions that produce non-Array
            # values (e.g. `def is_batched(self, x) -> bool`) a tracer will be
            # returned that might result in a concretization error. For now we only
            # enable named call on __call__ (covering 99% of the interesting usages)
            # with an assumption that __call__ is `f(*) -> Tree[Array]`. Longer term
            # we may want to split static and dynamic results in named call to support
            # other methods.
            if modules_with_named_call and module_name and method_name == "__call__":
                local_name = module_name.split("/")[-1]
                f = named_call.stateful_named_call(f, name=local_name)

            out = f(*args, **kwargs)

            # The module name is assigned in the constructor, so if `f` was
            # `__init__` the name only exists now. Re-read it so parents are
            # also told about freshly constructed submodules.
            if module_name is None:
                module_name = getattr(self, "module_name", None)

            # Notify parent modules about our existence. Our own entry is on the
            # stack too (pushed by `frame.module(state)`); a module is not its
            # own submodule, so skip it.
            if module_name is not None:
                for module_state in frame.module_stack:
                    if module_state.module is not self:
                        module_state.module._submodules.add(module_name)  # pylint: disable=protected-access
        return out
Exemplo n.º 3
0
 def __call__(self, *args, **kwargs):
     """Returns the ``(params, state)`` of the lifted inner transform.

     While ``hk.running_init()`` is true, the inner init function is run and
     its params/state are packed into the current frame under
     ``self._prefix_name``. Otherwise they are unpacked from the current frame
     (prefix ``"<prefix_name>/"``, or no prefix when the name is empty). The
     state is passed through ``base.extract_state(..., initial=False)``, and
     both results are converted with ``hk.data_structures.to_haiku_dict``.
     """
     current = base.current_frame()
     outer_params, outer_state = current.params, current.state
     if not hk.running_init():
         prefix = f"{self._prefix_name}/" if self._prefix_name else ""
         raw_params = unpack_from_dict(outer_params, prefix)
         raw_state = unpack_from_dict(outer_state, prefix)
         current_state = base.extract_state(raw_state, initial=False)
         return (hk.data_structures.to_haiku_dict(raw_params),
                 hk.data_structures.to_haiku_dict(current_state))

     inner_params, inner_state = self._init_fn(*args, **kwargs)
     # Lift the inner variables into this transform's outer dicts; the reuse
     # check is on unless the transform allows reuse.
     reuse_check = not self._allow_reuse
     pack_into_dict(inner_params, outer_params, self._prefix_name,
                    check_param_reuse=reuse_check)
     pack_into_dict(inner_state, outer_state, self._prefix_name,
                    state=True, check_param_reuse=reuse_check)
     return inner_params, inner_state
Exemplo n.º 4
0
def temporary_internal_state(state: InternalState):
    """Returns ``base.frame_stack(...)`` for a frame carrying ``state``.

    The current frame is evolved with ``state.params``, ``state.state`` and,
    when ``state.rng`` is not None, a fresh ``base.PRNGSequence`` built from
    it. Callers presumably use the result as a context manager; confirm
    against ``base.frame_stack``.
    """
    rng_seq = None if state.rng is None else base.PRNGSequence(state.rng)
    new_frame = base.current_frame().evolve(
        params=state.params, state=state.state, rng=rng_seq)
    return base.frame_stack(new_frame)
Exemplo n.º 5
0
    def wrapped(module, *args, **kwargs):
        """Calls the original method with a group name set before and after.

        Runs ``unbound_method`` on ``module`` inside the module's name scope,
        routed through ``run_interceptors`` and, when ``modules_with_named_call``
        is set and the module already has a name, ``stateful_named_call``.
        Afterwards every other module on the frame's module stack records
        this module's name in its ``_submodules`` set.

        Raises:
          ValueError: if there is no active frame stack, i.e. the call happens
            outside of an ``hk.transform``.
        """
        if not base.frame_stack:
            raise ValueError(
                "All `hk.Module`s must be initialized inside an `hk.transform`."
            )

        frame = base.current_frame()
        state = base.ModuleState(module=module, method_name=method_name)
        with frame.module(state), _module_method_call(module, method_name):
            # hk.Module enters the module name scope for all methods.
            module_name = getattr(module, "module_name", None)
            f = functools.partial(unbound_method, module)
            f = functools.partial(run_interceptors, f, method_name, module)
            if modules_with_named_call and module_name:
                local_name = module_name.split("/")[-1]
                f = named_call.stateful_named_call(f, name=local_name)

            out = f(*args, **kwargs)

            # The module name is assigned in the constructor, so if `f` was
            # `__init__` the name only exists now. Re-read it so parents are
            # also told about freshly constructed submodules.
            if module_name is None:
                module_name = getattr(module, "module_name", None)

            # Notify parent modules about our existence. Our own entry is on the
            # stack too (pushed by `frame.module(state)`); skip it.
            if module_name is not None:
                for module_state in frame.module_stack:
                    if module_state.module is not module:
                        module_state.module._submodules.add(module_name)  # pylint: disable=protected-access
        return out
Exemplo n.º 6
0
  def __init__(self, name: Optional[str] = None):
    """Initializes the current module with the given name.

    Subclasses should call this constructor before creating other modules or
    variables such that those modules are named correctly.

    Args:
      name: An optional string name for the class. Must be a valid Python
        identifier. If ``name`` is not provided, a non-None ``self.name``
        already present on the instance (e.g. set by a ``@dataclass``
        constructor) is used. Otherwise the class name of the current instance
        is converted to ``lower_snake_case`` and used instead.

    Raises:
      ValueError: if the chosen name is not a valid Python identifier.
    """
    if name is None:
      # A @dataclass-generated constructor may already have set `self.name`.
      preset = getattr(self, "name", None)
      name = (preset if preset is not None
              else utils.camel_to_snake(type(self).__name__))
    if not valid_identifier(name):
      raise ValueError(
          f"'{name}' is not a valid module name "
          "(must be a valid Python identifier)")
    self._submodules: Set[str] = set()
    self.module_name = unique_and_canonical_name(name)
    # The short name is the last component of the fully-qualified path.
    self.name = self.module_name.rsplit("/", 1)[-1]
    self._creation_frame_id = base.current_frame().frame_id
Exemplo n.º 7
0
def internal_state() -> InternalState:
    """Returns a snapshot of the current frame's params, state and RNG.

    Params and state are copied with ``copy_structure`` so that later
    mutation of the frame does not leak into the snapshot. The RNG is stored
    as the sequence's full ``internal_state`` rather than just the key from
    ``peek()``. That lets the snapshot round-trip through
    ``update_internal_state``, which hands it to ``replace_internal_state``.
    """
    frame = base.current_frame()
    rng = frame.rng_stack.peek()
    if rng is not None:
        rng = rng.internal_state
    return InternalState(params=copy_structure(frame.params),
                         state=copy_structure(frame.state),
                         rng=copy_structure(rng))
Exemplo n.º 8
0
def update_internal_state(state: InternalState):
    """Writes an ``InternalState`` back into the current frame.

    Params are merged only when the frame's params are not frozen and
    ``state.params`` is not None. State is always merged via
    ``update_recursive``. A non-None ``state.rng`` replaces the internal
    state of the RNG sequence on top of the frame's RNG stack.
    """
    frame = base.current_frame()
    params_writable = not frame.params_frozen
    if params_writable and state.params is not None:
        update_recursive(frame.params, state.params)
    update_recursive(frame.state, state.state)
    if state.rng is not None:
        frame.rng_stack.peek().replace_internal_state(state.rng)
Exemplo n.º 9
0
def reserve_up_to_full_rng_block():
    """If RNG is active in the current frame, reserve up to the default block."""
    # TODO(lenamartens): Fix needing full block reservation in stateful
    # control-flow by keeping track of current key with index and keeping a full
    # block in PRNGSequence at all time.
    rng_seq = base.current_frame().rng_stack.peek()
    # Test against None explicitly (as the other frame helpers do) rather than
    # relying on the truthiness of the sequence object. If it were ever falsy
    # (e.g. via a length of reserved keys), the reservation would be skipped
    # exactly when it is needed.
    if rng_seq is not None:
        rng_seq.reserve_up_to_full()
Exemplo n.º 10
0
def get_frame_data() -> FrameData:
  """Returns a ``FrameData`` snapshot of the current frame.

  Params, state, constants and the RNG sequence's ``internal_state`` (None
  when no RNG is active) are each copied with ``copy_structure``.
  """
  frame = current_frame()
  rng_seq = frame.rng_stack.peek()
  rng_state = rng_seq.internal_state if rng_seq is not None else None
  return FrameData(
      params=copy_structure(frame.params),
      state=copy_structure(frame.state),
      constants=copy_structure(frame.constants),
      rng=copy_structure(rng_state),
  )
Exemplo n.º 11
0
def internal_state(*, params=True) -> InternalState:
    """Returns a copied snapshot of the current frame's params, state and RNG.

    Args:
      params: When False, the snapshot's ``params`` is None instead of a copy
        of the frame's params.

    Returns:
      An ``InternalState`` whose fields were copied with ``copy_structure``.
      ``rng`` holds the RNG sequence's ``internal_state``, or None when no RNG
      is active.
    """
    frame = base.current_frame()
    rng_seq = frame.rng_stack.peek()
    rng_state = None if rng_seq is None else rng_seq.internal_state
    copied_params = copy_structure(frame.params) if params else None
    return InternalState(params=copied_params,
                         state=copy_structure(frame.state),
                         rng=copy_structure(rng_state))
Exemplo n.º 12
0
 def __call__(self, *args, **kwargs):
     """Returns the lifted inner params.

     During init the inner init function is run and its params are packed
     into the current frame's params under ``self.module_name``. Otherwise
     they are unpacked from the frame using the ``"<module_name>/"`` prefix.
     """
     outer_params = base.current_frame().params
     if not hk.running_init():
         return unpack_from_dict(outer_params, f"{self.module_name}/")
     inner_params = self._init_fn(*args, **kwargs)
     # Lift parameters into this transform's params_dict.
     pack_into_dict(inner_params, outer_params, self.module_name)
     return inner_params
Exemplo n.º 13
0
 def update(self, state: hk.State):
     """Updates Haiku's internal state to the given state.

     Each bundle name is prefixed with ``"<self._name>/"`` when ``self._name``
     is set. Every value becomes the current half of a ``base.StatePair``. The
     initial half is kept from the existing frame entry, or is the new value
     itself when no entry exists yet.
     """
     frame = base.current_frame()
     for module_name, values in state.items():
         bundle_key = (module_name if self._name is None
                       else f"{self._name}/{module_name}")
         for key, new_value in values.items():
             fallback = base.StatePair(new_value, new_value)
             initial = frame.state[bundle_key].get(key, fallback).initial
             frame.state[bundle_key][key] = base.StatePair(initial, new_value)
Exemplo n.º 14
0
def update_modified_frame_data(frame_data: FrameData):
  """Merges the entries of ``frame_data`` into the current frame.

  Params and constants are merged only while ``params_frozen()`` is false;
  state is always merged. All merges use ``update_recursive_skip_none``. A
  non-None ``frame_data.rng`` replaces the internal state of the RNG
  sequence on top of the frame's RNG stack.
  """
  frame = current_frame()
  # (destination, source, gated_on_params_frozen), applied in this order.
  merges = (
      (frame.params, frame_data.params, True),
      (frame.state, frame_data.state, False),
      (frame.constants, frame_data.constants, True),
  )
  for destination, source, gated in merges:
    if gated and params_frozen():
      continue
    update_recursive_skip_none(destination, source)
  if frame_data.rng is not None:
    frame.rng_stack.peek().replace_internal_state(frame_data.rng)
Exemplo n.º 15
0
def update_modified_frame_data_from_args(params, bundled_state):
  """Merges ``params`` and a ``(state, constants, rng)`` bundle into the frame.

  Args:
    params: Params merged only while ``params_frozen()`` is false.
    bundled_state: A 3-tuple ``(state, constants, rng)``. State is always
      merged; constants only while params are not frozen. A non-None ``rng``
      replaces the internal state of the RNG sequence on top of the frame's
      RNG stack.

  Raises:
    ValueError: if ``bundled_state`` does not unpack into exactly 3 items.
  """
  state, constants, rng = bundled_state
  frame = current_frame()
  if not params_frozen():
    update_recursive_skip_none(frame.params, params)
  update_recursive_skip_none(frame.state, state)
  if not params_frozen():
    update_recursive_skip_none(frame.constants, constants)
  if rng is not None:
    frame.rng_stack.peek().replace_internal_state(rng)
Exemplo n.º 16
0
Arquivo: lift.py Projeto: ibab/haiku
 def __call__(self, *args, **kwargs):
     """Returns the lifted inner params.

     Outside ``base.in_apply()`` (i.e. during init) the inner init function is
     run and its params are packed into the frame under ``self.module_name``.
     During apply they are unpacked using the ``"<module_name>/"`` prefix.
     """
     frame = base.current_frame()
     bundle_name = self.module_name
     if not base.in_apply():
         # Inside init: lift parameters into this transform's params_dict.
         params = self._init_fn(*args, **kwargs)
         pack_into_dict(params, frame.params, bundle_name)
         return params
     return unpack_from_dict(frame.params, bundle_name + "/")
Exemplo n.º 17
0
    def wrapped(self, *args, **kwargs):
        """Calls the original method with a group name set before and after.

        Runs ``unbound_method`` on ``self`` inside the module's name scope
        (``frame.module(state)`` plus ``_module_method_call``), routed through
        ``run_interceptors``. For ``__call__``, when ``modules_with_named_call``
        is set and the module has a name, it is also wrapped in
        ``stateful.named_call``. Afterwards every other module on the frame's
        module stack records this module's name in its ``_submodules`` set.

        Raises:
          ValueError: if there is no active frame stack, i.e. the call happens
            outside of an ``hk.transform``.
        """
        if not base.frame_stack:
            raise ValueError(
                "All `hk.Module`s must be initialized inside an `hk.transform`."
            )

        # Submodules are associated with this method. We allow users to associate
        # submodules with a different method than the one being called via
        # `@name_like("other_method")`. Interceptors and custom getters are still
        # provided the actual method name (e.g. "submodule_method_name" is only used
        # for naming submodules).
        submodule_method_name = getattr(unbound_method, _CUSTOM_NAME,
                                        method_name)

        frame = base.current_frame()
        state = base.ModuleState(module=self,
                                 method_name=submodule_method_name)
        with frame.module(state), _module_method_call(self, method_name):
            # hk.Module enters the module name scope for all methods.
            module_name = getattr(self, "module_name", None)
            # Bind `self`, then route the call through the interceptors, which
            # receive the real method name and the module.
            f = functools.partial(unbound_method, self)
            f = functools.partial(run_interceptors, f, method_name, self)
            # TODO(tomhennigan): With omnistaging primitives (like named call) will
            # stage out return values eagerly. For functions that produce non-Array
            # values (e.g. `def is_batched(self, x) -> bool`) a tracer will be
            # returned that might result in a concretization error. For now we only
            # enable named call on __call__ (covering 99% of the interesting usages)
            # with an assumption that __call__ is `f(*) -> Tree[Array]`. Longer term
            # we may want to split static and dynamic results in named call to support
            # other methods.
            if modules_with_named_call and module_name and method_name == "__call__":
                # Use only the last path component as the scope name.
                local_name = module_name.split("/")[-1]
                f = stateful.named_call(f, name=local_name)

            out = f(*args, **kwargs)

            # Module names are set in the constructor. If `f` is the constructor then
            # its name will only be set **after** `f` has run. For methods other
            # than `__init__` we need the name before running in order to wrap their
            # execution with `named_call`.
            if module_name is None:
                module_name = getattr(self, "module_name", None)

            # Notify parent modules about our existence.
            if module_name is not None:
                for module_state in frame.module_stack:
                    # Our own entry is on the stack too; skip it.
                    if module_state.module is not self:
                        module_state.module._submodules.add(module_name)  # pylint: disable=protected-access
        return out
Exemplo n.º 18
0
def temporary_frame_data(frame_data: FrameData):
  """Pushes a temporary copy of the frame_data.

  ``frame_data`` is copied with ``copy_structure`` and installed on an
  evolved copy of the current frame. A non-None ``rng`` is wrapped in a
  fresh ``PRNGSequence``. Params, state and constants must all be non-None
  (checked with ``assert``).

  Returns:
    ``frame_stack(...)`` for the evolved frame.
  """
  snapshot = copy_structure(frame_data)
  rng_seq = None if snapshot.rng is None else PRNGSequence(snapshot.rng)
  for part in (snapshot.params, snapshot.state, snapshot.constants):
    assert part is not None, "Must initialize module before this call"

  evolved = current_frame().evolve(
      params=snapshot.params,
      state=snapshot.state,
      constants=snapshot.constants,
      rng=rng_seq,
  )
  return frame_stack(evolved)
Exemplo n.º 19
0
  def wrapped(self, *args, **kwargs):
    """Calls the original method with a group name set before and after.

    Runs ``unbound_method`` on ``self`` inside the module's name scope
    (``frame.module(state)`` plus ``_module_method_call``), routed through
    ``run_interceptors``.

    If ``jax.config.jax_experimental_name_stack`` is set, the call is wrapped
    in ``jax.named_call`` with the module's local name. Methods other than
    ``__call__`` are additionally wrapped with the method name. Otherwise,
    ``__call__`` is wrapped in ``stateful.named_call`` when
    ``config.get_config().profiler_name_scopes`` is enabled.

    Afterwards every other module on the frame's module stack records this
    module's name in its ``_submodules`` set.

    Raises:
      ValueError: if there is no active frame stack, i.e. the call happens
        outside of an ``hk.transform``.
    """
    if not base.frame_stack:
      raise ValueError(
          "All `hk.Module`s must be initialized inside an `hk.transform`.")

    # Submodules are associated with this method. We allow users to associate
    # submodules with a different method than the one being called via
    # `@name_like("other_method")`. Interceptors and custom getters are still
    # provided the actual method name (e.g. "submodule_method_name" is only used
    # for naming submodules).
    submodule_method_name = getattr(unbound_method, _CUSTOM_NAME, method_name)

    frame = base.current_frame()
    state = base.ModuleState(module=self, method_name=submodule_method_name)
    with frame.module(state), _module_method_call(self, method_name):
      # hk.Module enters the module name scope for all methods.
      module_name = getattr(self, "module_name", None)
      # Bind `self`, then route the call through the interceptors, which
      # receive the real method name and the module.
      f = functools.partial(unbound_method, self)
      f = functools.partial(run_interceptors, f, method_name, self)
      if jax.config.jax_experimental_name_stack and module_name:
        # Scope names use only the last component of the module path.
        local_module_name = module_name.split("/")[-1]
        f = jax.named_call(f, name=local_module_name)
        if method_name != "__call__":
          f = jax.named_call(f, name=method_name)
      elif module_name:
        # TODO(lenamartens): remove this branch once jax_experimental_name_stack
        # flag is removed.
        cfg = config.get_config()
        if cfg.profiler_name_scopes and method_name == "__call__":
          local_module_name = module_name.split("/")[-1]
          f = stateful.named_call(f, name=local_module_name)

      out = f(*args, **kwargs)

      # Module names are set in the constructor. If `f` is the constructor then
      # its name will only be set **after** `f` has run. For methods other
      # than `__init__` we need the name before running in order to wrap their
      # execution with `named_call`.
      if module_name is None:
        module_name = getattr(self, "module_name", None)

      # Notify parent modules about our existence.
      if module_name is not None:
        for module_state in frame.module_stack:
          # Our own entry is on the stack too; skip it.
          if module_state.module is not self:
            module_state.module._submodules.add(module_name)  # pylint: disable=protected-access
    return out
Exemplo n.º 20
0
def temporary_internal_state(state: InternalState):
    """Pushes a temporary copy of the internal state.

    ``state`` is copied with ``copy_structure`` and installed on an evolved
    copy of the current frame. A non-None ``rng`` is wrapped in a fresh
    ``base.PRNGSequence``. When the copy's ``params`` or ``state`` is None, the
    corresponding (copied) value from ``internal_state()`` is used instead.

    Returns:
      ``base.frame_stack(...)`` for the evolved frame.
    """
    snapshot = copy_structure(state)
    rng = snapshot.rng
    if rng is not None:
        rng = base.PRNGSequence(rng)
    params = snapshot.params
    new_state = snapshot.state
    # Only snapshot the current frame when a fallback is actually needed:
    # internal_state() copies params, state and RNG, which is wasted work
    # when `state` already provides both.
    if params is None or new_state is None:
        current_state = internal_state()
        if params is None:
            params = current_state.params
        if new_state is None:
            new_state = current_state.state
    frame = base.current_frame()
    frame = frame.evolve(params=params, state=new_state, rng=rng)
    return base.frame_stack(frame)
Exemplo n.º 21
0
 def __call__(self, *args, **kwargs):
   """Returns the lifted ``(params, state)`` of the inner transform.

   The presence of ``_SENTINEL_NAME`` in this bundle's params marks that init
   already happened. In that case params and state are unpacked from the
   frame using the ``"<module_name>/"`` prefix. Otherwise the sentinel is
   created and the inner init function's params/state are packed into the
   frame under ``self.module_name``.
   """
   frame = base.current_frame()
   bundle_name = self.module_name
   already_initialized = _SENTINEL_NAME in frame.params[bundle_name]
   if not already_initialized:
     # Ensure sentinel is set for apply.
     base.get_parameter(_SENTINEL_NAME, (), init=jnp.zeros)
     # Lift parameters into this transform's params_dict.
     params, state = self._init_fn(*args, **kwargs)
     pack_into_dict(params, frame.params, bundle_name)
     pack_into_dict(state, frame.state, bundle_name)
     return params, state
   prefix = bundle_name + "/"
   return (unpack_from_dict(frame.params, prefix),
           unpack_from_dict(frame.state, prefix))
Exemplo n.º 22
0
def get_constant(name: str, value: Any, init=None, do_not_set=False):
    """Returns constant ``name`` of the current bundle, storing it if unset.

    Args:
      name: Key of the constant within the current bundle's constants.
      value: Value to store when the constant is unset (passed through
        ``init`` first when ``init`` is given).
      init: Optional callable applied to ``value`` before it is stored.
      do_not_set: If True and the constant is unset, return None without
        storing anything.

    Returns:
      The saved value, the newly stored value, or None (see ``do_not_set``).
    """
    constants = current_frame().constants[current_bundle_name()]
    saved_value = constants.get(name, None)
    # A stored None is treated as "unset" and gets overwritten.
    if saved_value is not None:
        return saved_value
    if do_not_set:
        return None
    if init is not None:
        value = init(value)
    constants[name] = value
    return value
Exemplo n.º 23
0
def unique_and_canonical_name(name: str) -> str:
  """Returns a canonical name for the given name.

  When constructed inside another module's method, the name is prefixed with
  that module's name. A ``~/`` marker is added for ``__init__``, and
  ``~<method>/`` for methods other than ``__call__``. An explicit trailing
  ``_<n>`` is honoured. Otherwise a per-context counter suffix is appended
  when the base name was already used.

  Raises:
    ValueError: if the resulting qualified name has already been used.
  """
  frame = base.current_frame()

  # If this module is being constructed from within another module's method,
  # prefix the name with that module's name (and method name if appropriate).
  if len(frame.module_stack) > 1:
    # -2 since we are inside the ctor and want to look at the caller state.
    module_state = frame.module_stack.peek(-2)

    # Make sure to include the method name if appropriate.
    method_name = module_state.method_name
    if method_name == "__init__":
      name = "~/" + name
    elif method_name != "__call__":
      name = "~" + method_name + "/" + name

    # Include the parent name.
    parent_module = module_state.module
    parent_name = base.safe_get_module_name(parent_module)
    name = parent_name + "/" + name

  # Test if the user has explicitly numbered this module. `maxsplit` is passed
  # by keyword because positional use is deprecated since Python 3.13.
  splits = re.split(r"_(\d+)$", name, maxsplit=3)
  if len(splits) > 1:
    name, n = splits[0], int(splits[1])
    explicit_n = True
  else:
    n = None
    explicit_n = False

  # Determine a unique name for this module within the current context.
  counters = frame.counter_stack.peek(-2)
  if n is not None:
    counters[name] = max(counters[name], n + 1)
  else:
    n = counters[name]
    counters[name] += 1
  qualified_name = f"{name}_{n}" if explicit_n or n else name

  # Final sanity check that this name has not been used before.
  used_names = frame.used_names_stack.peek(-2)
  if qualified_name in used_names:
    raise ValueError(f"Module name '{qualified_name}' is not unique.")
  used_names.add(qualified_name)

  return qualified_name
Exemplo n.º 24
0
  def params_dict(self) -> Mapping[base.ParamName, jnp.array]:
    """Returns parameters keyed by name for this module and submodules.

    A bundle is included when it is this module's own bundle, is nested
    under ``"<module_name>/"``, or is listed in ``self._submodules``. Keys are
    ``"<bundle>/<param>"``.

    Raises:
      ValueError: if called with no active frame stack (outside an
        ``hk.transform``).
    """
    if not base.frame_stack:
      raise ValueError(
          "`module.params_dict()` must be used as part of an `hk.transform`.")

    own_name = self.module_name
    nested_prefix = own_name + "/"

    def _owned(bundle_name):
      return (bundle_name == own_name
              or bundle_name.startswith(nested_prefix)
              or bundle_name in self._submodules)

    collected = {}
    for bundle_name, bundle in base.current_frame().params.items():
      if not _owned(bundle_name):
        continue
      for param_name, value in bundle.items():
        collected[bundle_name + "/" + param_name] = value
    return collected
Exemplo n.º 25
0
def params_or_state_dict(
    module_name: str,
    submodules: Set[str],
    which: str,
) -> Mapping[str, jnp.array]:
    """Returns module parameters or state for the given module or submodules.

    Args:
      module_name: Name of the module whose own and nested bundles to read.
      submodules: Extra bundle names to include.
      which: ``"params"`` or ``"state"`` (checked with ``assert``); selects
        the frame attribute to read. For ``"state"`` each entry's
        ``.current`` value is returned.

    Returns:
      A dict keyed by ``"<bundle>/<name>"``.
    """
    assert which in ("params", "state")
    is_state = which == "state"
    nested_prefix = module_name + "/"
    result = {}
    for bundle_name, bundle in getattr(base.current_frame(), which).items():
        owned = (bundle_name == module_name
                 or bundle_name.startswith(nested_prefix)
                 or bundle_name in submodules)
        if not owned:
            continue
        for key, entry in bundle.items():
            result[bundle_name + "/" + key] = (entry.current if is_state
                                               else entry)
    return result
Exemplo n.º 26
0
Arquivo: module.py Projeto: ibab/haiku
    def wrapped(module, *args, **kwargs):
        """Calls the original method with a group name set before and after.

        Runs ``unbound_method`` on ``module`` inside the module's name scope,
        then records this module's name in the ``_submodules`` set of every
        module on the frame's module stack.

        Raises:
          ValueError: if there is no active frame stack, i.e. the call happens
            outside of an ``hk.transform``.
        """
        if not base.frame_stack:
            raise ValueError(
                "All `hk.Module`s must be initialized inside an `hk.transform`."
            )

        frame = base.current_frame()
        call_state = base.ModuleState(module=module, method_name=method_name)
        with frame.module(call_state), _module_method_call(module, method_name):
            # Run the method inside the module's name scope.
            result = unbound_method(module, *args, **kwargs)

            # The name is read after the call so a name assigned by the method
            # itself (e.g. the constructor) is picked up.
            own_name = getattr(module, "module_name", None)
            if own_name is not None:
                for entry in frame.module_stack:
                    entry.module._submodules.add(own_name)  # pylint: disable=protected-access
        return result
Exemplo n.º 27
0
def get_call_stack() -> Sequence[ModuleDetails]:
    """Returns ``ModuleDetails`` for every entry on the current module stack.

    Entries appear in the iteration order of ``frame.module_stack``.
    """
    frame = base.current_frame()
    # A generator replaces `map` + `lambda` over a throwaway `list(...)`.
    return tuple(ModuleDetails.of(entry.module, entry.method_name)
                 for entry in frame.module_stack)
Exemplo n.º 28
0
 def wrapped(params, state, *args, **kwargs):
   """Runs ``module`` with the given params/state; returns ``(out, new_state)``.

   ``params`` and ``state`` are first passed to ``update_frame_data``. The
   frame's state after the call is returned as a ``copy_structure`` copy.
   """
   update_frame_data(params, state)
   result = module(*args, **kwargs)
   return result, copy_structure(current_frame().state)
Exemplo n.º 29
0
def get_bundled_state():
  """Returns ``(state, constants, rng)`` for the current frame.

  ``state`` and ``constants`` are the frame's own objects (not copies).
  ``rng`` is the top RNG sequence's ``internal_state``, or None when no RNG
  is active.
  """
  frame = current_frame()
  rng_seq = frame.rng_stack.peek()
  rng_state = rng_seq.internal_state if rng_seq is not None else None
  return frame.state, frame.constants, rng_state
Exemplo n.º 30
0
def update_frame_data(params, state):
  """Merges ``params`` and ``state`` into the current frame.

  Params are merged only while ``params_frozen()`` is false; state is always
  merged. Both go through ``update_recursive_skip_none`` (presumably skipping
  None entries — confirm against its definition).
  """
  frame = current_frame()
  if not params_frozen():
    update_recursive_skip_none(frame.params, params)
  update_recursive_skip_none(frame.state, state)