コード例 #1
0
def new_custom_context(
    *,
    params: Optional[Params] = None,
    state: Optional[State] = None,
    constants: Optional[State] = None,
    rng: Optional[Union[PRNGKey, int]] = None,
) -> CustomHaikuContext:
    """Builds a ``CustomHaikuContext`` from optional params, state, constants and rng.

    Args:
      params: Optional parameter values. Converted with ``to_immutable_dict``
        when given; otherwise an empty ``defaultdict(dict)`` is used.
      state: Optional state values. When given, every leaf ``v`` is wrapped as
        ``StatePair(v, v)``; otherwise an empty ``defaultdict(dict)`` is used.
      constants: Optional constant values, handled the same way as ``params``.
      rng: Optional key or seed. Wrapped in a ``PRNGSequence`` unless it is
        ``None`` or already a ``PRNGSequence``.

    Returns:
      The context created from ``(params, state, constants, rng)``.
    """
    if params is not None:
        params = data_structures.to_immutable_dict(params)
    else:
        params = collections.defaultdict(dict)

    if state is not None:
        paired = {}
        for module_name, bundle in state.items():
            # Each entry starts with identical first and second fields.
            paired[module_name] = {
                name: StatePair(value, value)
                for name, value in bundle.items()
            }
        state = paired
    else:
        state = collections.defaultdict(dict)

    if constants is not None:
        constants = data_structures.to_immutable_dict(constants)
    else:
        constants = collections.defaultdict(dict)

    needs_wrapping = rng is not None and not isinstance(rng, PRNGSequence)
    if needs_wrapping:
        rng = PRNGSequence(rng)

    return CustomHaikuContext(params, state, constants, rng)
コード例 #2
0
def decode_layer_weight(layers):
    """Decodes the ``b`` and ``w`` tensors of each layer, keyed by layer name.

    Args:
      layers: Iterable of objects exposing ``name``, ``b`` and ``w``.

    Returns:
      The result of ``to_immutable_dict`` applied to a mapping from each layer
      name to its decoded ``{"b": ..., "w": ...}`` entry.
    """
    decoded = {
        layer.name: to_immutable_dict({
            "b": tensor_float_decoder(layer.b),
            "w": tensor_float_decoder(layer.w),
        })
        for layer in layers
    }
    return to_immutable_dict(decoded)
コード例 #3
0
ファイル: lift.py プロジェクト: ibab/haiku
def unpack_from_dict(src, prefix):
    """Selects the entries of ``src`` whose key starts with ``prefix``.

    The prefix is removed from every selected key before the result is passed
    through ``to_immutable_dict``.
    """
    stripped = {
        key[len(prefix):]: value
        for key, value in src.items()
        if key.startswith(prefix)
    }
    return data_structures.to_immutable_dict(stripped)
コード例 #4
0
ファイル: random.py プロジェクト: stjordanis/dm-haiku
    def wrapper(*args, **kwargs):
        base.assert_context("optimize_rng_use")

        # Snapshot the current frame; empty params/state are passed as None.
        frame = base.current_frame()

        params = None
        if frame.params:
            params = data_structures.to_immutable_dict(frame.params)

        state = None
        if frame.state:
            state = base.extract_state(frame.state, initial=True)

        rng = frame.rng_stack.peek()
        if rng is not None:
            rng = rng.internal_state

        def pure_fun(params, state, rng, *args, **kwargs):
            # Runs `fun` inside a fresh context built from the snapshot.
            with base.new_context(params=params, state=state, rng=rng):
                return fun(*args, **kwargs)

        # Abstractly evaluate once, only to count how many rng keys are requested.
        with count_hk_rngs_requested() as rng_count_f:
            jax.eval_shape(pure_fun, params, state, rng, *args, **kwargs)
        requested = rng_count_f()

        # Reserve that many keys up front before the real call.
        if requested:
            base.current_frame().rng_stack.peek().reserve(requested)
        return fun(*args, **kwargs)
コード例 #5
0
ファイル: filtering.py プロジェクト: vballoli/dm-haiku
def filter(predicate: Predicate, structure: T) -> T:  # pylint: disable=redefined-builtin
    """Keeps only the entries of a two-level structure accepted by a predicate.

    >>> params = {'linear': {'w': None, 'b': None}}
    >>> predicate = lambda module_name, name, value: name == 'w'
    >>> hk.data_structures.filter(predicate, params)
    frozendict({'linear': frozendict({'w': None})})

    Note: returns a new structure not a view.

    Args:
      predicate: Callable invoked as ``predicate(module_name, name, value)``
        for every entry; entries for which it returns a truthy value are kept.
      structure: Haiku params or state data structure to be filtered.

    Returns:
      A new immutable structure containing only the accepted entries. Modules
      without any accepted entry do not appear in the result.
    """
    selected = collections.defaultdict(dict)

    for module_name, bundle in structure.items():
        kept = {
            name: value
            for name, value in bundle.items()
            if predicate(module_name, name, value)
        }
        # Only touch the module key when something matched, so that modules
        # without matches stay absent from the output.
        if kept:
            selected[module_name].update(kept)

    return data_structures.to_immutable_dict(selected)
コード例 #6
0
ファイル: base.py プロジェクト: pythseq/dm-haiku
    def apply_fn(
        params: Params,
        state: State,
        rng: Optional[Union[PRNGKey, PRNGSeed]],
        *args,
        **kwargs,
    ):
        """Runs the wrapped function with the given parameters, state and rng.

        Returns a pair ``(out, state)`` where ``state`` is extracted with
        ``initial=False`` after the call.
        """
        # TODO(tomhennigan) Remove support for `None` params (used in tests).
        params = data_structures.to_immutable_dict(params)

        paired = {}
        for module_name, bundle in state.items():
            paired[module_name] = {
                name: StatePair(value, value)
                for name, value in bundle.items()
            }
        state = paired

        if rng is not None:
            try:
                rng = PRNGSequence(rng)
            except Exception as e:
                # Point the caller at the argument slot the rng belongs in;
                # the slot depends on whether any state was supplied.
                position = "third" if state else "second"
                signature = ("apply(params, state, rng, *a, **k)"
                             if state else "apply(params, rng, *a, **k)")
                raise ValueError(
                    f"Apply must be called with an RNG as the {position} argument, "
                    f"the required signature is: `{signature}`") from e

        frame = Frame.create(params=params, state=state, rng=rng)
        with frame_stack(frame):
            out = f(*args, **kwargs)

        return out, _extract_state(state, initial=False)
コード例 #7
0
ファイル: filtering.py プロジェクト: yuripulier/dm-haiku
def map(  # pylint: disable=redefined-builtin
    fn: Callable[[str, str, InT], OutT],
    structure: Mapping[str, Mapping[str, InT]],
) -> Mapping[str, Mapping[str, OutT]]:
  """Applies a function to every entry of a two-level structure.

  >>> params = {'linear': {'w': 1.0, 'b': 2.0}}
  >>> fn = lambda module_name, name, value: 2 * value if name == 'w' else value
  >>> hk.data_structures.map(fn, params)
  FlatMapping({'linear': FlatMapping({'w': 2.0, 'b': 2.0})})

  Note: returns a new structure not a view.

  Args:
    fn: Callable invoked as ``fn(module_name, name, value)`` for every entry;
      its return value replaces ``value`` in the output.
    structure: Haiku params or state data structure to be mapped.

  Returns:
    A new immutable structure holding the value returned by ``fn`` for each
    entry. Modules whose bundle is empty do not appear in the result.
  """
  mapped = {}

  for module_name, bundle in structure.items():
    transformed = {
        name: fn(module_name, name, value) for name, value in bundle.items()
    }
    # Empty bundles are skipped so they stay absent from the output.
    if transformed:
      mapped[module_name] = transformed

  return data_structures.to_immutable_dict(mapped)
コード例 #8
0
ファイル: base.py プロジェクト: shyamalschandra/haiku
    def apply_fn(
        params: Params,
        state: State,
        rng: Optional[Union[PRNGKey, PRNGSeed]],
        *args,
        **kwargs,
    ):
        """Runs the wrapped function with the given parameters, state and rng.

        Returns a pair ``(out, state)`` where ``state`` is extracted with
        ``initial=False`` after the call.
        """
        # TODO(tomhennigan) Remove support for `None` params (used in tests).
        params = data_structures.to_immutable_dict(params)

        paired = {}
        for module_name, bundle in state.items():
            paired[module_name] = {
                name: StatePair(value, value)
                for name, value in bundle.items()
            }
        state = paired

        # A missing rng is passed through as None.
        sequence = None
        if rng is not None:
            sequence = PRNGSequence(rng)

        with frame_stack(
                Frame.create(params=params, state=state, rng=sequence)):
            out = f(*args, **kwargs)

        return out, _extract_state(state, initial=False)
コード例 #9
0
 def test_copy(self, clone):
     """Checks that ``clone`` returns a distinct but equal structure.

     Args:
       clone: Callable under test; receives the immutable structure and must
         return a copy of it.
     """
     before = data_structures.to_immutable_dict(dict(a=dict(b=1, c=2)))
     after = clone(before)
     self.assertIsNot(before, after)
     self.assertEqual(before, after)
     self.assertEqual(after, {"a": {"b": 1, "c": 2}})
     # `jax.tree_multimap` has been removed from JAX; `tree_map` accepts
     # several trees and compares them leaf by leaf in the same way.
     jax.tree_util.tree_map(self.assertEqual, before, after)
コード例 #10
0
ファイル: filtering.py プロジェクト: vballoli/dm-haiku
def merge(*structures: T) -> T:
    """Combines several two-level structures into a single one.

    >>> weights = {'linear': {'w': None}}
    >>> biases = {'linear': {'b': None}}
    >>> hk.data_structures.merge(weights, biases)
    frozendict({'linear': frozendict({'b': None, 'w': None})})

    When the same path occurs in more than one input, the value from the last
    structure containing it wins:

    >>> weights1 = {'linear': {'w': 1}}
    >>> weights2 = {'linear': {'w': 2}}
    >>> hk.data_structures.merge(weights1, weights2)
    frozendict({'linear': frozendict({'w': 2})})

    Note: returns a new structure not a view.

    Args:
      *structures: One or more structures to merge.

    Returns:
      A single structure with an entry for each path in the input structures.
    """
    # Flatten all inputs into (module, name, value) triples, in argument order,
    # so later structures overwrite earlier ones.
    entries = (
        (module_name, name, value)
        for structure in structures
        for module_name, bundle in structure.items()
        for name, value in bundle.items()
    )

    merged = collections.defaultdict(dict)
    for module_name, name, value in entries:
        merged[module_name][name] = value
    return data_structures.to_immutable_dict(merged)
コード例 #11
0
ファイル: base.py プロジェクト: pythseq/dm-haiku
def _extract_state(state: MutableState, *, initial) -> State:
    """Picks one field out of every state pair and freezes the result.

    Args:
      state: Two-level mapping whose leaves expose ``initial`` and ``current``.
      initial: When truthy the ``initial`` field is taken, otherwise
        ``current``.

    Returns:
      The selected values, converted with ``to_immutable_dict``.
    """
    extracted = {}
    for module_name, bundle in state.items():
        values = {}
        for name, pair in bundle.items():
            values[name] = pair.initial if initial else pair.current
        extracted[module_name] = values
    return data_structures.to_immutable_dict(extracted)
コード例 #12
0
def partition(
    predicate: Callable[[str, str, jnp.ndarray], bool],
    structure: T,
) -> Tuple[T, T]:
    """Splits a two-level structure into two according to a predicate.

    For a given set of parameters, you can use :func:`partition` to split them:

    >>> params = {'linear': {'w': None, 'b': None}}
    >>> predicate = lambda module_name, name, value: name == 'w'
    >>> weights, biases = hk.data_structures.partition(predicate, params)
    >>> weights
    frozendict({'linear': frozendict({'w': None})})
    >>> biases
    frozendict({'linear': frozendict({'b': None})})

    Note: returns new structures not a view.

    Args:
      predicate: Callable invoked as ``predicate(module_name, name, value)``
        for every entry; its truthiness decides which output receives the
        entry.
      structure: Haiku params or state data structure to be partitioned.

    Returns:
      A tuple ``(matching, rest)``: the first structure holds the entries for
      which the predicate was truthy, the second holds all other entries.
    """
    matching = collections.defaultdict(dict)
    rest = collections.defaultdict(dict)

    for module_name, bundle in structure.items():
        for name, value in bundle.items():
            if predicate(module_name, name, value):
                matching[module_name][name] = value
            else:
                rest[module_name][name] = value

    return (
        data_structures.to_immutable_dict(matching),
        data_structures.to_immutable_dict(rest),
    )
コード例 #13
0
def new_context(
    *,
    params: Optional[Params] = None,
    state: Optional[State] = None,
    rng: Optional[Union[PRNGKey, int]] = None,
) -> HaikuContext:
    """Creates a context that records hk.{get,set}_{parameter,state} calls.

    >>> with new_context(rng=jax.random.PRNGKey(42)) as ctx:
    ...   mod = hk.nets.MLP([300, 100, 10])
    ...   y1 = mod(jnp.ones([1, 1]))

    >>> assert len(jax.tree_leaves(ctx.collect_params())) == 6

    >>> with ctx:
    ...   y2 = mod(jnp.ones([1, 1]))

    The same module instance in the same context will produce the same value:

    >>> assert (y1 == y2).all()

    Args:
      params: Optional parameter values to inject.
      state: Optional state values to inject.
      rng: Optional rng to inject.

    Returns:
      Context manager which closes over mutable Haiku internal state.
    """
    if params is not None:
        params = data_structures.to_immutable_dict(params)
    else:
        params = collections.defaultdict(dict)

    if state is not None:
        # Wrap each value as StatePair(v, v) and keep the container a
        # defaultdict, as in the `state is None` case.
        paired = collections.defaultdict(dict)
        for module_name, bundle in state.items():
            paired[module_name] = {
                name: StatePair(value, value)
                for name, value in bundle.items()
            }
        state = paired
    else:
        state = collections.defaultdict(dict)

    needs_wrapping = rng is not None and not isinstance(rng, PRNGSequence)
    if needs_wrapping:
        rng = PRNGSequence(rng)

    return HaikuContext(params, state, rng)
コード例 #14
0
  def __call__(self, tree, update_stats=True):
    """Applies ``SpectralNorm`` to every value of a two-level tree.

    Entries whose ``"module/name"`` path matches ``self._ignore_regex`` are
    returned unchanged.

    Args:
      tree: Two-level mapping from module name to a mapping of values.
      update_stats: Forwarded to each ``SpectralNorm`` call.

    Returns:
      The processed tree, converted with ``to_immutable_dict``.
    """
    def normalize(path, value):
      ignored = self._ignore_regex and re.match(self._ignore_regex, path)
      if ignored:
        return value
      # Build a module name from the path by replacing "/" and "~".
      sn_name = path.replace("/", "__").replace("~", "_tilde")
      layer = SpectralNorm(self._eps, self._n_steps, name=sn_name)
      return layer(value, update_stats=update_stats)

    # We want to potentially replace params with Spectral Normalized versions.
    result = {}
    for module_name, param_dict in tree.items():
      bundle = {}
      for k, v in param_dict.items():
        bundle[k] = normalize("/".join([module_name, k]), v)
      result[module_name] = bundle
    return data_structures.to_immutable_dict(result)
コード例 #15
0
ファイル: moving_averages.py プロジェクト: vballoli/dm-haiku
  def __call__(self, tree, update_stats=True):
    """Applies ``ExponentialMovingAverage`` to every value of a two-level tree.

    Entries whose ``"module/name"`` path matches ``self._ignore_regex`` are
    returned unchanged.

    Args:
      tree: Two-level mapping from module name to a mapping of values.
      update_stats: Forwarded to each ``ExponentialMovingAverage`` call.

    Returns:
      The processed tree, converted with ``to_immutable_dict``.
    """
    def average(path, value):
      ignored = self._ignore_regex and re.match(self._ignore_regex, path)
      if ignored:
        return value
      # Build a module name from the path by replacing "/" and "~".
      ema_name = path.replace("/", "__").replace("~", "_tilde_")
      layer = ExponentialMovingAverage(
          self._decay, self._zero_debias, self._warmup_length, name=ema_name)
      return layer(value, update_stats=update_stats)

    # We want to potentially replace params with EMA'd versions.
    result = {}
    for module_name, param_dict in tree.items():
      bundle = {}
      for k, v in param_dict.items():
        bundle[k] = average("/".join([module_name, k]), v)
      result[module_name] = bundle
    return data_structures.to_immutable_dict(result)
コード例 #16
0
ファイル: base.py プロジェクト: shyamalschandra/haiku
    def init_fn(
        rng: Optional[Union[PRNGKey, PRNGSeed]],
        *args,
        **kwargs,
    ):
        """Runs the wrapped function once and returns the collected values.

        Returns a pair ``(params, state)``: the parameters converted with
        ``to_immutable_dict`` and the state extracted with ``initial=True``.
        """
        collected_params = collections.defaultdict(dict)
        collected_state = collections.defaultdict(dict)

        # A missing rng is passed through as None.
        sequence = None
        if rng is not None:
            sequence = PRNGSequence(rng)

        with frame_stack(
                Frame.create(params=collected_params,
                             state=collected_state,
                             rng=sequence)):
            f(*args, **kwargs)

        return (
            data_structures.to_immutable_dict(collected_params),
            _extract_state(collected_state, initial=True),
        )
コード例 #17
0
ファイル: filtering.py プロジェクト: facebbook/dm-haiku
def partition_n(
    fn: Callable[[str, str, jnp.ndarray], int],
    structure: T,
    n: int,
) -> Tuple[T, ...]:
  """Splits a structure into ``n`` structures chosen by a bucketing function.

  For a given set of parameters, you can use :func:`partition_n` to split them
  into ``n`` groups. For example, to split your parameters/gradients by module
  name:

  >>> def partition_by_module(structure):
  ...   cnt = itertools.count()
  ...   d = collections.defaultdict(lambda: next(cnt))
  ...   fn = lambda m, n, v: d[m]
  ...   return hk.data_structures.partition_n(fn, structure, len(structure))

  >>> structure = {f'layer_{i}': {'w': None, 'b': None} for i in range(3)}
  >>> for substructure in partition_by_module(structure):
  ...   print(substructure)
  FlatMapping({'layer_0': FlatMapping({'b': None, 'w': None})})
  FlatMapping({'layer_1': FlatMapping({'b': None, 'w': None})})
  FlatMapping({'layer_2': FlatMapping({'b': None, 'w': None})})

  Args:
    fn: Callable invoked as ``fn(module_name, name, value)``; must return the
      index in ``[0, n)`` of the bucket that receives the entry.
    structure: Haiku params or state data structure to be partitioned.
    n: The total number of buckets.

  Returns:
    A tuple of size ``n`` whose ``i``-th element holds the entries for which
    ``fn`` returned ``i``.
  """
  buckets = [collections.defaultdict(dict) for _ in range(n)]
  for module_name, name, value in traverse(structure):
    i = fn(module_name, name, value)
    assert 0 <= i < n, f"{i} must be in range [0, {n})"
    buckets[i][module_name][name] = value
  return tuple(map(data_structures.to_immutable_dict, buckets))
コード例 #18
0
ファイル: base.py プロジェクト: pythseq/dm-haiku
    def init_fn(
        rng: Optional[Union[PRNGKey, PRNGSeed]],
        *args,
        **kwargs,
    ):
        """Runs the wrapped function once and returns the collected values.

        Returns a pair ``(params, state)``: the parameters converted with
        ``to_immutable_dict`` and the state extracted with ``initial=True``.
        Raises ``ValueError`` when ``rng`` cannot be turned into a
        ``PRNGSequence``.
        """
        collected_params = collections.defaultdict(dict)
        collected_state = collections.defaultdict(dict)

        if rng is not None:
            try:
                rng = PRNGSequence(rng)
            except Exception as e:
                # Re-raise with a message that spells out the expected call.
                message = (
                    "Init must be called with an RNG as the first argument, the "
                    "required signature is: `init(rng, *a, **k)`")
                raise ValueError(message) from e

        frame = Frame.create(
            params=collected_params, state=collected_state, rng=rng)
        with frame_stack(frame):
            f(*args, **kwargs)

        return (
            data_structures.to_immutable_dict(collected_params),
            _extract_state(collected_state, initial=True),
        )
コード例 #19
0
ファイル: base.py プロジェクト: facebbook/dm-haiku
 def collect_params(self) -> Params:
   """Returns the private ``__params`` attribute via ``to_immutable_dict``."""
   collected = self.__params
   return data_structures.to_immutable_dict(collected)
コード例 #20
0
 def collect_constants(self) -> State:
     """Returns the private ``__constants`` attribute via ``to_immutable_dict``."""
     collected = self.__constants
     return data_structures.to_immutable_dict(collected)
コード例 #21
0
    def call(self,
             inputs: Mapping[str, jnp.ndarray],
             rng: jnp.ndarray = None,
             sample: Optional[bool] = False,
             no_scan: bool = False,
             accumulate: Iterable[str] = ["log_det", "aux_loss"],
             **kwargs) -> Mapping[str, jnp.ndarray]:
        """Applies the repeated layer ``self.n_repeats`` times with ``jax.lax.scan``.

        While ``Layer._is_initializing`` is truthy the call is delegated to
        ``self.call_no_scan``. Otherwise the parameters and state of the
        repeated layers are batched together and a single layer application is
        scanned over them, with ``inputs["x"]`` as the initial carry.

        Args:
          inputs: Mapping that must contain the key ``"x"``. All other entries
            are passed unchanged to every layer application.
          rng: Optional key. When not ``None`` it is split into
            ``self.n_repeats`` keys, one per scan step.
          sample: Forwarded to the layer. When it compares equal to ``True``
            the batched parameters and state are reversed along their leading
            axis before scanning, and the updated state is reversed back.
          no_scan: Not referenced in this method body.
            NOTE(review): confirm whether callers expect it to have an effect.
          accumulate: Names of per-step outputs that are summed over axis 0 of
            the stacked scan outputs and returned.
            NOTE(review): the default is a mutable list; it is only iterated
            here, never modified.
          **kwargs: Forwarded to the layer application.

        Returns:
          A dict with the final carry under ``"x"`` plus one entry for each
          name in ``accumulate`` that was present in the scan outputs. Other
          per-step outputs are not returned.
        """
        if Layer._is_initializing:
            return self.call_no_scan(inputs, rng, sample=sample, **kwargs)

        # Want to make sure that we're passing all inputs/outputs to the next layer
        # NOTE(review): `final_outputs` is not read again in this method.
        final_outputs = inputs.copy()

        # Need to get the functional apply fun
        with make_functional_modules([self.layer_create_fun()]) as ([apply_fun], \
                                                                    params, \
                                                                    (state, constants, rng_seq), \
                                                                    finalize):
            # Retrieve the hashes of the names of the parameters and states for the layer call
            param_hashes, state_hashes = get_constant(
                "param_state_name_hashes", None)

            # Batch together the parameters and state across the repeated layers
            scan_params = _batch_repeated_layers(params, param_hashes)
            scan_params = data_structures.to_immutable_dict(scan_params)
            scan_state = _batch_repeated_layers(state, state_hashes)

            # Reverse the order if we are sampling
            # NOTE(review): `jax.tree_map`, `jax.tree_flatten` and
            # `jax.tree_unflatten` are deprecated top-level aliases in recent
            # JAX releases - confirm against the pinned JAX version.
            if sample == True:
                scan_params = jax.tree_map(lambda x: x[::-1], scan_params)
                scan_state = jax.tree_map(lambda x: x[::-1], scan_state)

            # Pass other inputs we might have through the network
            shared_inputs = inputs.copy()
            del shared_inputs["x"]

            # Use a scan loop so that we only need to compile layer once!
            def scan_body(carry, scan_inputs):
                # One scan step: apply the layer to the carried "x" with this
                # step's slice of params, state and rng.
                x = carry
                params, state, rng = scan_inputs

                # Bundle the non-parameter state together
                bundled_state = (state, constants, rng_seq)

                # Make sure that we're passing all of the inputs (such as labels) to the layer
                inputs = shared_inputs.copy()
                inputs["x"] = x

                # Run the function
                outputs, bundled_state = apply_fun(params,
                                                   bundled_state,
                                                   inputs,
                                                   rng,
                                                   sample=sample,
                                                   **kwargs)

                # Retrieve the state because it might have changed
                state, _, _ = bundled_state

                # Return the stuff we need: the new carry, and the remaining
                # outputs plus the state as this step's stacked results.
                x = outputs["x"]
                del outputs["x"]
                return x, (outputs, state)

            # Run the scan function
            # One key per repeat, or a list of `None` placeholders when no rng
            # was supplied.
            rngs = random.split(
                rng,
                self.n_repeats) if rng is not None else [None] * self.n_repeats
            x, (batched_outputs, batched_updated_state) = jax.lax.scan(
                scan_body, inputs["x"], (scan_params, scan_state, rngs))

            # Reverse the updated state if we are sampling
            if sample == True:
                batched_updated_state = jax.tree_map(lambda x: x[::-1],
                                                     batched_updated_state)

            # Search through the outputs to find things we want to accumulate
            accumulated_outputs = {}
            for name in accumulate:
                if name in batched_outputs:
                    # Scan stacks per-step outputs along axis 0, so this sums
                    # over the scan steps.
                    accumulated_outputs[name] = batched_outputs[name].sum(
                        axis=0)
                    del batched_outputs[name]

            # Convert the output of the scan into the same state data structure that was passed in.
            # Two lookup tables between state keys and their `hash()` values.
            hash_map = {hash(k): k for k in state.keys()}
            rev_hash_map = {k: hash(k) for k in state.keys()}
            updated_state = state.copy()
            for base_layer_name, pytree in batched_updated_state.items():

                # Retrieve the names of each repeated layer
                layer_names = [
                    hash_map[k]
                    for k in state_hashes[rev_hash_map[base_layer_name]]
                ]

                # Split the batched parameters
                # Index `i` of every leaf gives the state for repeat `i`.
                leaves, treedef = jax.tree_flatten(
                    batched_updated_state[base_layer_name])
                split_states = [
                    jax.tree_unflatten(treedef, [l[i] for l in leaves])
                    for i in range(self.n_repeats)
                ]

                # Update the state dictionary
                updated_state.update(dict(zip(layer_names, split_states)))

            # Just in case
            updated_state = jax.lax.stop_gradient(updated_state)

            # Only state might be different
            bundled_state = (updated_state, constants, rng_seq)
            finalize(params, bundled_state)

        # Only "x" and the accumulated entries are returned.
        outputs = {"x": x}
        outputs.update(accumulated_outputs)

        return outputs
コード例 #22
0
 def test_getitem_nested_immutable(self):
     """Item assignment on a nested converted mapping must raise TypeError."""
     frozen = data_structures.to_immutable_dict({"a": {"b": "c"}})
     expected_message = "does not support item assignment"
     with self.assertRaisesRegex(TypeError, expected_message):
         frozen["a"]["b"] = "d"
コード例 #23
0
 def test_to_immutable_dict(self):
     """Conversion keeps equality and yields ``FlatMapping`` at both levels."""
     source = {"a": {"b": 1, "c": 2}}
     converted = data_structures.to_immutable_dict(source)
     self.assertEqual(source, converted)
     # Check the outer mapping first, then the nested one.
     for candidate in (converted, converted["a"]):
         self.assertEqual(type(candidate), FlatMapping)