Example #1
0
def flatten_weights_and_state(weights, state):
    """Flatten weights and state into lists, excluding empty and cached ones."""
    def _is_real_weight(leaf):
        # Sentinel leaves are compared by identity, matching how they are set.
        return leaf is not EMPTY_WEIGHTS and leaf is not GET_WEIGHTS_FROM_CACHE

    def _is_real_state(leaf):
        return leaf is not EMPTY_STATE and leaf is not GET_STATE_FROM_CACHE

    flat_weights = list(filter(_is_real_weight, math.tree_flatten(weights)))
    flat_state = list(filter(_is_real_state, math.tree_flatten(state)))
    return flat_weights, flat_state
Example #2
0
 def tree_init(self, weight_tree):
     """Assembles node-local initializations into full-tree initialization."""
     # One optimizer slot per flattened weight leaf.
     leaves = math.tree_flatten(weight_tree)
     self._slots = list(map(self.init, leaves))
     return (self._slots, self._init_opt_params)
Example #3
0
 def tree_update(self, step, grad_tree, weight_tree, slots, opt_params):
     """Assembles node-local weight and slot updates for the full layer tree."""
     grads_flat = math.tree_flatten(grad_tree)
     grads_norm = self._l2_norm(grads_flat)
     clip = self._clip_grad_norm
     if clip is not None:
         # Rescale every gradient leaf when the global gradient norm exceeds
         # the cap; np.where keeps the computation branch-free per leaf.
         grads_flat = [np.where(grads_norm < clip, g, g * (clip / grads_norm))
                       for g in grads_flat]
     weights_flat = math.tree_flatten(weight_tree)
     weights_norm = self._l2_norm(weights_flat)
     per_leaf_results = [
         self._update_and_check(step, g, w, s, opt_params)
         for g, w, s in zip(grads_flat, weights_flat, slots)
     ]
     # Transpose [(weight, slot), ...] into (weights, slots).
     new_weights_flat, self.slots = zip(*per_leaf_results)
     new_weights, _ = math.tree_unflatten(new_weights_flat, weight_tree)
     metrics = {'gradients_l2': grads_norm, 'weights_l2': weights_norm}
     return new_weights, self.slots, metrics
Example #4
0
    def tree_update(self, step, grad_tree, weight_tree, slots, opt_params):
        """Assembles node-local weight and slot updates for the full layer tree.

    Args:
      step: Current step number in the training process.
      grad_tree: Gradients for the entire model, in a tree that matches the
          model's layer structure.
      weight_tree: Current weights for the entire model, in a tree that matches
          the model's layer structure.
      slots: Optimizer slots.
      opt_params: Optimizer hyperparameters (e.g. learning rate, momentum).

    Returns:
      Tuple `(weights, slots, metrics)`, where `weights` are the
      optimizer-updated weights for the whole model (in a tree matching the
      model's layer structure), `slots` are the updated optimizer slot values,
      and `metrics` is a dict with the l2 norms of the (pre-clipping)
      gradients and of the pre-update weights.
    """
        grads_flat = math.tree_flatten(grad_tree)
        grads_norm = self._l2_norm(grads_flat)
        if self._clip_grad_norm is not None:
            # Rescale every gradient leaf when the global gradient norm
            # exceeds the cap; np.where keeps this branch-free per leaf.
            max_norm = self._clip_grad_norm
            grads_flat = [
                np.where(
                    grads_norm < max_norm,  # pylint: disable=g-complex-comprehension
                    g,
                    g * (max_norm / grads_norm)) for g in grads_flat
            ]
        weights_flat = math.tree_flatten(weight_tree)
        weights_norm = self._l2_norm(weights_flat)
        updated_pairs = [
            self._update_and_check(step, grad, weight, slot, opt_params)
            for (grad, weight, slot) in zip(grads_flat, weights_flat, slots)
        ]
        # Transpose [(weight, slot), ...] into (weights, slots).
        new_weights_flat, self.slots = zip(*updated_pairs)
        new_weights, _ = math.tree_unflatten(new_weights_flat, weight_tree)
        metrics = {'gradients_l2': grads_norm, 'weights_l2': weights_norm}
        return new_weights, self.slots, metrics
Example #5
0
    def tree_init(self, weight_tree):
        """Assembles node-local initializations into full-tree initialization.

    Args:
      weight_tree: Weights for an entire model, in a tree that matches the
          model's layer structure.

    Returns:
      Tuple `(slots, opt_params)`, where `slots` are the initialized optimizer
      slot values and `opt_params` are optimizer hyperparameters (e.g.,
      learning rate, momentum).
    """
        # One optimizer slot per flattened weight leaf.
        leaves = math.tree_flatten(weight_tree)
        self._slots = list(map(self.init, leaves))
        return (self._slots, self._init_opt_params)
Example #6
0
def l2_norm(tree):
    """Compute the l2 norm of a pytree of arrays. Useful for weight decay."""
    # Sum of squared magnitudes over every leaf, then a single square root.
    squared = (np.vdot(leaf, leaf) for leaf in math.tree_flatten(tree))
    return np.sqrt(sum(squared))