コード例 #1
0
  def wrapper(*args, **kwargs):
    # Counts how many Haiku RNG keys `fun` will request by tracing it
    # abstractly (jax.eval_shape) against a snapshot of the current frame,
    # asks the current rng sequence to reserve that many keys up front, and
    # finally runs `fun` for real inside the live context.
    base.assert_context("optimize_rng_use")

    frame = base.current_frame()

    # Snapshot params; a falsy (e.g. empty) mapping is treated as "no params".
    snap_params = frame.params or None
    if snap_params is not None:
      snap_params = data_structures.to_haiku_dict(snap_params)

    # Snapshot state, taking each entry's initial value.
    snap_state = frame.state or None
    if snap_state is not None:
      snap_state = base.extract_state(snap_state, initial=True)

    # Snapshot the raw key behind the innermost rng sequence, if any.
    rng_seq = frame.rng_stack.peek()
    snap_rng = None if rng_seq is None else rng_seq.internal_state

    # Parameter names are kept as params/state/rng on purpose: user kwargs are
    # forwarded through this signature, so renaming them would change which
    # keyword names collide.
    def traced_fun(params, state, rng, *args, **kwargs):
      # Runs `fun` in a throwaway context built from the snapshot so the
      # abstract trace does not read from or write to the live frame.
      with base.new_context(params=params, state=state, rng=rng):
        return fun(*args, **kwargs)

    with count_hk_rngs_requested() as read_rng_count:
      jax.eval_shape(traced_fun, snap_params, snap_state, snap_rng,
                     *args, **kwargs)
    keys_needed = read_rng_count()

    if keys_needed:
      base.current_frame().rng_stack.peek().reserve(keys_needed)
    return fun(*args, **kwargs)
コード例 #2
0
ファイル: filtering.py プロジェクト: YAMWD/dm-haiku
def merge(*structures: T) -> T:
  """Combines several input structures into a single structure.

  >>> weights = {'linear': {'w': None}}
  >>> biases = {'linear': {'b': None}}
  >>> hk.data_structures.merge(weights, biases)
  {'linear': {'w': None, 'b': None}}

  If the same (module, name) path appears in more than one structure, the
  value from the structure passed last wins:

  >>> weights1 = {'linear': {'w': 1}}
  >>> weights2 = {'linear': {'w': 2}}
  >>> hk.data_structures.merge(weights1, weights2)
  {'linear': {'w': 2}}

  Note: the result is a newly built structure, not a view of the inputs.

  Args:
    *structures: One or more structures to merge, in priority order (later
      structures override earlier ones).

  Returns:
    A single structure containing an entry for every path found in any of the
    input structures.
  """
  merged = collections.defaultdict(dict)
  # Visit every entry of every structure in argument order; plain assignment
  # means later structures overwrite earlier ones at the same path.
  all_entries = (entry for struct in structures for entry in traverse(struct))
  for module_name, name, value in all_entries:
    merged[module_name][name] = value
  return data_structures.to_haiku_dict(merged)
コード例 #3
0
ファイル: filtering.py プロジェクト: YAMWD/dm-haiku
def filter(  # pylint: disable=redefined-builtin
    predicate: Callable[[str, str, jnp.ndarray], bool],
    structure: T,
) -> T:
  """Keeps only the entries of a structure for which a predicate holds.

  >>> params = {'linear': {'w': None, 'b': None}}
  >>> predicate = lambda module_name, name, value: name == 'w'
  >>> hk.data_structures.filter(predicate, params)
  {'linear': {'w': None}}

  Note: the result is a newly built structure, not a view of ``structure``.
  A module for which no entry is kept does not appear in the result.

  Args:
    predicate: Function deciding whether an entry is kept. It is called as
      ``predicate(module_name, name, value)`` with the module name, the name
      of the entry within that module's data bundle (e.g. a parameter name)
      and the entry's value; a truthy result keeps the entry.
    structure: Haiku params or state data structure to be filtered.

  Returns:
    The params or state entries of ``structure`` accepted by ``predicate``.
  """
  kept = collections.defaultdict(dict)
  for module_name, name, value in traverse(structure):
    if not predicate(module_name, name, value):
      continue
    kept[module_name][name] = value
  return data_structures.to_haiku_dict(kept)
コード例 #4
0
ファイル: filtering.py プロジェクト: YAMWD/dm-haiku
def map(  # pylint: disable=redefined-builtin
    fn: Callable[[str, str, InT], OutT],
    structure: Mapping[str, Mapping[str, InT]],
) -> Mapping[str, Mapping[str, OutT]]:
  """Applies a function to every entry of an input structure.

  >>> params = {'linear': {'w': 1.0, 'b': 2.0}}
  >>> fn = lambda module_name, name, value: 2 * value if name == 'w' else value
  >>> hk.data_structures.map(fn, params)
  {'linear': {'b': 2.0, 'w': 2.0}}

  Note: returns a new structure not a view.

  Args:
    fn: Function producing the new value for each entry. It is called as
      ``fn(module_name, name, value)`` with the name of the module, the name
      of the entry in that module's data bundle (e.g. a parameter name) and
      the corresponding value, and its return value is stored at the same
      (module, name) path in the output.
    structure: Haiku params or state data structure to be mapped.

  Returns:
    A structure with one entry for each (module, name) pair visited while
    traversing ``structure``, holding ``fn``'s result for that entry.
  """
  out = collections.defaultdict(dict)
  for module_name, name, value in traverse(structure):
    out[module_name][name] = fn(module_name, name, value)
  return data_structures.to_haiku_dict(out)
コード例 #5
0
ファイル: base.py プロジェクト: JacobARose/dm-haiku
def extract_state(state: MutableState, *, initial) -> State:
    """Projects mutable state onto one of its two recorded values.

    Every entry of ``state`` is expected to expose ``initial`` and ``current``
    attributes; the result keeps the module/name layout but stores only one of
    those values per entry.

    Args:
        state: Mutable state mapping module name -> {name: pair}.
        initial: If truthy, take each pair's ``initial`` value; otherwise
            take its ``current`` value.

    Returns:
        The selected values converted with ``data_structures.to_haiku_dict``.
    """
    def pick(pair):
        return pair.initial if initial else pair.current

    snapshot = {
        module_name: {key: pick(pair) for key, pair in bundle.items()}
        for module_name, bundle in state.items()
    }
    return data_structures.to_haiku_dict(snapshot)
コード例 #6
0
 def test_copy(self, clone):
   # Checks that `clone` yields a distinct but equal immutable dict.
   # NOTE(review): `clone` presumably comes from the test's parameterization
   # (e.g. copy.copy / copy.deepcopy) — confirm against the decorator.
   before = data_structures.to_immutable_dict(dict(a=dict(b=1, c=2)))
   after = clone(before)
   self.assertIsNot(before, after)
   self.assertEqual(before, after)
   self.assertEqual(after, {"a": {"b": 1, "c": 2}})
   before_dict = data_structures.to_haiku_dict(before)
   # Leaf-by-leaf equality. `jax.tree_map` is deprecated (and removed in newer
   # JAX releases); `jax.tree_util.tree_map` is the stable spelling.
   jax.tree_util.tree_map(self.assertEqual, before_dict, after)
コード例 #7
0
ファイル: base.py プロジェクト: JacobARose/dm-haiku
def new_context(
    *,
    params: Optional[Params] = None,
    state: Optional[State] = None,
    rng: Optional[Union[PRNGKey, int]] = None,
) -> HaikuContext:
    """Collects the results of hk.{get,set}_{parameter,state} calls.

    >>> with new_context(rng=jax.random.PRNGKey(42)) as ctx:
    ...   mod = hk.nets.MLP([300, 100, 10])
    ...   y1 = mod(jnp.ones([1, 1]))

    >>> assert len(jax.tree_util.tree_leaves(ctx.collect_params())) == 6

    >>> with ctx:
    ...   y2 = mod(jnp.ones([1, 1]))

    The same module instance in the same context will produce the same value:

    >>> assert (y1 == y2).all()

    Args:
        params: Optional parameter values to inject. When given, they are
            converted with ``data_structures.to_haiku_dict`` and the context
            is created with ``freeze_params=True``. When omitted, an empty
            ``defaultdict(dict)`` is used and ``freeze_params=False``.
        state: Optional state values to inject. Each injected value ``v`` is
            stored as ``StatePair(v, v)``. When omitted, an empty
            ``defaultdict(dict)`` is used.
        rng: Optional rng to inject. A PRNG key or integer seed is wrapped in
            a ``PRNGSequence``; an existing ``PRNGSequence`` is used as is.

    Returns:
        Context manager which closes over mutable Haiku internal state.
    """
    if params is None:
        params = collections.defaultdict(dict)
        freeze_params = False
    else:
        params = data_structures.to_haiku_dict(params)
        freeze_params = True

    if state is None:
        state = collections.defaultdict(dict)
    else:
        # Store each injected value as both the initial and current value.
        state = {
            m: {k: StatePair(v, v)
                for k, v in p.items()}
            for m, p in state.items()
        }
        state = collections.defaultdict(dict, state)

    # Accept an existing PRNGSequence without re-wrapping it.
    if rng is not None and not isinstance(rng, PRNGSequence):
        rng = PRNGSequence(rng)

    return HaikuContext(params, state, rng, freeze_params)
コード例 #8
0
def merge(
    *structures: Mapping[str, Mapping[str, Any]],
    check_duplicates: bool = False,
) -> Mapping[str, Mapping[str, Any]]:
  """Combines several input structures into a single structure.

  >>> weights = {'linear': {'w': None}}
  >>> biases = {'linear': {'b': None}}
  >>> hk.data_structures.merge(weights, biases)
  {'linear': {'w': None, 'b': None}}

  If the same (module, name) path appears in more than one structure, the
  value from the structure passed last wins:

  >>> weights1 = {'linear': {'w': 1}}
  >>> weights2 = {'linear': {'w': 2}}
  >>> hk.data_structures.merge(weights1, weights2)
  {'linear': {'w': 2}}

  Note: the result is a newly built structure, not a view of the inputs.

  Args:
    *structures: One or more structures to merge, in priority order.
    check_duplicates: If True, raise a ValueError when a path appears in more
      than one structure and the two values differ in shape or dtype. A value
      lacking a ``shape`` or ``dtype`` attribute counts as "not an array":
      two such values never conflict, but one conflicts with an array.

  Returns:
    A single structure with an entry for each path in the input structures.

  Raises:
    ValueError: If ``check_duplicates`` is True and a conflicting duplicate
      path is found.
  """
  def is_array_like(obj):
    return hasattr(obj, "shape") and hasattr(obj, "dtype")

  def signature(obj):
    # (shape, dtype) for arrays, None for anything else.
    return (obj.shape, obj.dtype) if is_array_like(obj) else None

  def describe(obj):
    return utils.format_array(obj) if is_array_like(obj) else repr(obj)

  merged = collections.defaultdict(dict)
  for struct in structures:
    for module_name, name, value in traverse(struct):
      bundle = merged[module_name]
      if check_duplicates and name in bundle:
        earlier = bundle[name]
        if signature(earlier) != signature(value):
          raise ValueError(
              "Duplicate array found with different shape/dtype for "
              f"{module_name}.{name}: {describe(earlier)} vs {describe(value)}.")
      bundle[name] = value
  return data_structures.to_haiku_dict(merged)
コード例 #9
0
ファイル: filtering.py プロジェクト: YAMWD/dm-haiku
def partition_n(
    fn: Callable[[str, str, jnp.ndarray], int],
    structure: T,
    n: int,
) -> Tuple[T, ...]:
  """Partitions a structure into `n` structures.

  For a given set of parameters, you can use :func:`partition_n` to split them
  into ``n`` groups. For example, to split your parameters/gradients by module
  name:

  >>> def partition_by_module(structure):
  ...   cnt = itertools.count()
  ...   d = collections.defaultdict(lambda: next(cnt))
  ...   fn = lambda m, n, v: d[m]
  ...   return hk.data_structures.partition_n(fn, structure, len(structure))

  >>> structure = {f'layer_{i}': {'w': None, 'b': None} for i in range(3)}
  >>> for substructure in partition_by_module(structure):
  ...   print(substructure)
  {'layer_0': {'b': None, 'w': None}}
  {'layer_1': {'b': None, 'w': None}}
  {'layer_2': {'b': None, 'w': None}}

  Args:
    fn: Callable returning which bucket in ``[0, n)`` the given element should
      be placed in. It is called as ``fn(module_name, name, value)``.
    structure: Haiku params or state data structure to be partitioned.
    n: The total number of buckets.

  Returns:
    A tuple of size ``n``, where each element will contain the values for which
    the function returned the current index.

  Raises:
    AssertionError: If ``fn`` returns an index outside ``[0, n)``. The check
      is also performed when Python runs with ``-O``.
  """
  out = [collections.defaultdict(dict) for _ in range(n)]
  for module_name, name, value in traverse(structure):
    i = fn(module_name, name, value)
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, where a negative index would silently select a bucket
    # counted from the end. AssertionError is kept for backward compatibility.
    if not 0 <= i < n:
      raise AssertionError(f"{i} must be in range [0, {n})")
    out[i][module_name][name] = value
  return tuple(data_structures.to_haiku_dict(o) for o in out)
コード例 #10
0
ファイル: transform.py プロジェクト: stjordanis/dm-haiku
 def apply_fn(params, state, *args, **kwargs):
     # State-aware apply signature for a stateless transform: the incoming
     # `state` is ignored and an empty state accompanies the output.
     del state
     result = f.apply(params, *args, **kwargs)
     return result, data_structures.to_haiku_dict({})
コード例 #11
0
ファイル: transform.py プロジェクト: stjordanis/dm-haiku
 def init_fn(*args, **kwargs):
     # Delegates to the stateless `f.init` and pairs its parameters with an
     # empty state.
     return f.init(*args, **kwargs), data_structures.to_haiku_dict({})
コード例 #12
0
ファイル: base.py プロジェクト: JacobARose/dm-haiku
 def collect_params(self) -> Params:
     """Returns this context's parameters converted to a Haiku dict."""
     collected = self.__params
     return data_structures.to_haiku_dict(collected)