Пример #1
0
def unflatten_weights_and_state(
    flat_weights, flat_state, weights_and_state_signature, weights_only=False):
  """Un-flatten weights and state given their signatures.

  Args:
    flat_weights: Flattened (list-form) weights to restore into a tree.
    flat_state: Flattened (list-form) state to restore into a tree.
    weights_and_state_signature: Pair of (weights tree, state tree) signatures
        describing the target structures.
    weights_only: If True, skip state reconstruction and return None for it.

  Returns:
    Pair `(weights, state)`; `state` is None when `weights_only` is True.
  """
  weights_tree, state_tree = weights_and_state_signature
  weights, _ = fastmath.tree_unflatten(
      flat_weights, weights_tree,
      copy_from_tree=[EMPTY_WEIGHTS, GET_WEIGHTS_FROM_CACHE])
  if weights_only:
    return weights, None
  state, _ = fastmath.tree_unflatten(
      flat_state, state_tree,
      copy_from_tree=[EMPTY_STATE, GET_STATE_FROM_CACHE])
  return weights, state
Пример #2
0
  def tree_update(self, step, grad_tree, weight_tree, slots, opt_params):
    """Assembles node-local weight and slot updates for the full layer tree.

    Args:
      step: Current step number in the training process.
      grad_tree: Gradients for the entire model, in a tree that matches the
          model's layer structure.
      weight_tree: Current weights for the entire model, in a tree that matches
          the model's layer structure.
      slots: Optimizer slots.
      opt_params: Optimizer hyperparameters (e.g. learning rate, momentum).

    Returns:
      Tuple `(weights, slots, metrics)`, where `weights` are the
      optimizer-updated weights for the whole model (in a tree matching the
      model's layer structure), `slots` are the updated optimizer slot values,
      and `metrics` is a dict with the L2 norms of the gradients
      (`'gradients_l2'`) and weights (`'weights_l2'`).
    """
    grads_flat = fastmath.tree_flatten(grad_tree)
    grads_norm = self._l2_norm(grads_flat)
    if self._clip_grad_norm is not None:
      max_norm = self._clip_grad_norm
      # Global-norm clipping: if the overall gradient norm exceeds max_norm,
      # scale every leaf down uniformly so the global norm equals max_norm;
      # otherwise leave the gradients unchanged.
      grads_flat = [jnp.where(grads_norm < max_norm,  # pylint: disable=g-complex-comprehension
                              g,
                              g * (max_norm / grads_norm))
                    for g in grads_flat]
    weights_flat = fastmath.tree_flatten(weight_tree)
    weights_norm = self._l2_norm(weights_flat)
    # Apply the per-leaf optimizer update; each call yields (new_weight, slot).
    updated_pairs = [
        self._update_and_check(step, grad, weight, slot, opt_params)
        for (grad, weight, slot) in zip(grads_flat, weights_flat, slots)
    ]
    # Transpose the list of pairs into (all new weights, all new slots) and
    # store the updated slots on the optimizer.
    new_weights_flat, self.slots = map(list, zip(*updated_pairs))
    new_weights, _ = fastmath.tree_unflatten(new_weights_flat, weight_tree)
    metrics = {'gradients_l2': grads_norm, 'weights_l2': weights_norm}
    return new_weights, self.slots, metrics