def _l2_norm(self, flat_list):
    """Returns the aggregate L2 norm of a list of tensors.

    Args:
        flat_list: Iterable of tensors whose squared norms are added together.

    Returns:
        The square root of the summed per-tensor squared norms.
    """
    if math.backend_name() == 'jax':
        squared_norms = [np.vdot(tensor, tensor) for tensor in flat_list]
    else:
        # TODO(lukaszkaiser): add vdot to TF-numpy
        squared_norms = [np.sum(tensor * tensor) for tensor in flat_list]
    return np.sqrt(sum(squared_norms))
def tree_update(self, step, grad_tree, weight_tree, slots, opt_params):
    """Assembles node-local weight and slot updates for the full layer tree.

    When `self._clip_grad_norm` is set, the flattened gradients are first
    clipped by their aggregate L2 norm: they are left untouched while that
    norm is below the threshold and are otherwise rescaled by
    `threshold / norm`.

    Args:
        step: Passed through unchanged to `self._update_and_check`.
        grad_tree: Tree of gradients; flattened with `_tree_flatten`.
        weight_tree: Tree of weights; flattened with `_tree_flatten` and
            reused as the template when rebuilding the updated weights.
        slots: Iterable of per-leaf slots, zipped in lockstep with the
            flattened gradients and weights.
        opt_params: Passed through unchanged to `self._update_and_check`.

    Returns:
        A `(new_weights, slots)` pair: the updated weights rebuilt into the
        structure of `weight_tree`, and the updated slots, which are also
        stored on `self.slots`.
    """
    flat_grads = _tree_flatten(grad_tree)

    clip_threshold = self._clip_grad_norm
    if clip_threshold is not None:
        global_norm = np.sqrt(sum(np.vdot(g, g) for g in flat_grads))

        def _clip(grad):
            # Keep the gradient as is while the aggregate norm is under the
            # threshold; otherwise scale it down by threshold / norm.
            return np.where(
                global_norm < clip_threshold,
                grad,
                grad * (clip_threshold / global_norm),
            )

        flat_grads = [_clip(g) for g in flat_grads]

    flat_weights = _tree_flatten(weight_tree)

    # One (new_weight, new_slot) result per leaf, in flattening order.
    per_leaf_results = []
    for grad, weight, slot in zip(flat_grads, flat_weights, slots):
        per_leaf_results.append(
            self._update_and_check(step, grad, weight, slot, opt_params)
        )

    # Transpose the per-leaf pairs into a tuple of weights and a tuple of slots.
    new_weights_flat, self.slots = zip(*per_leaf_results)
    new_weights, _ = _tree_unflatten(new_weights_flat, weight_tree)
    return new_weights, self.slots
def l2_norm(tree):
    """Compute the l2 norm of a pytree of arrays. Useful for weight decay.

    Args:
        tree: Pytree of arrays; flattened with `math.tree_flatten`.

    Returns:
        The square root of the sum of `np.vdot(x, x)` over the flattened
        leaves.
    """
    flat_leaves = math.tree_flatten(tree)
    squared_norms = [np.vdot(leaf, leaf) for leaf in flat_leaves]
    total_squared = sum(squared_norms)
    return np.sqrt(total_squared)