Ejemplo n.º 1
0
 def _l2_norm(self, flat_list):
     """Computes one L2 norm taken jointly over every tensor in `flat_list`.

     Args:
         flat_list: iterable of tensors.

     Returns:
         Square root of the sum, over all tensors, of each tensor's dot
         product with itself.
     """
     if math.backend_name() == 'jax':
         squared_sums = (np.vdot(t, t) for t in flat_list)
     else:  # TODO(lukaszkaiser): add vdot to TF-numpy
         # Without vdot, square elementwise and reduce with a plain sum.
         squared_sums = (np.sum(t * t) for t in flat_list)
     return np.sqrt(sum(squared_sums))
Ejemplo n.º 2
0
 def tree_update(self, step, grad_tree, weight_tree, slots, opt_params):
     """Runs the per-node update over a whole tree of gradients and weights.

     The gradient tree is flattened and, when `self._clip_grad_norm` is set,
     every gradient is rescaled so that the joint L2 norm of all gradients
     does not exceed that value. Each (gradient, weight, slot) triple is then
     handed to `self._update_and_check`.

     Args:
         step: forwarded unchanged to `_update_and_check`.
         grad_tree: tree of gradients; flattened with `_tree_flatten`.
         weight_tree: tree of weights; flattened with `_tree_flatten` and
             also used as the template when un-flattening the new weights.
         slots: per-leaf slots, iterated in lockstep with the flattened
             gradients and weights.
         opt_params: forwarded unchanged to `_update_and_check`.

     Returns:
         A pair `(new_weights, new_slots)`. `new_weights` is the first
         element returned by `_tree_unflatten`; `new_slots` is also stored
         on `self.slots`.
     """
     flat_grads = _tree_flatten(grad_tree)
     clip_threshold = self._clip_grad_norm
     if clip_threshold is not None:
         # One norm across all gradient leaves, not one per leaf.
         total_norm = np.sqrt(sum(np.vdot(g, g) for g in flat_grads))
         clipped = []
         for g in flat_grads:
             # Keep gradients as they are while under the threshold;
             # otherwise scale them down by threshold / norm.
             rescaled = g * (clip_threshold / total_norm)
             clipped.append(np.where(total_norm < clip_threshold, g, rescaled))
         flat_grads = clipped
     flat_weights = _tree_flatten(weight_tree)
     results = []
     for g, w, s in zip(flat_grads, flat_weights, slots):
         results.append(self._update_and_check(step, g, w, s, opt_params))
     # Transpose the list of (weight, slot) pairs into two parallel tuples.
     new_weights_flat, self.slots = zip(*results)
     new_weights, _unused_rest = _tree_unflatten(new_weights_flat, weight_tree)
     return new_weights, self.slots
Ejemplo n.º 3
0
def l2_norm(tree):
    """Returns the L2 norm taken jointly over all array leaves of `tree`.

    Handy as a building block for weight decay.

    Args:
        tree: pytree of arrays; flattened with `math.tree_flatten`.

    Returns:
        Square root of the sum, over all leaves, of each leaf's dot product
        with itself.
    """
    squared_total = sum(
        np.vdot(leaf, leaf) for leaf in math.tree_flatten(tree)
    )
    return np.sqrt(squared_total)