def flatten_weights_and_state(weights, state):
  """Flatten weights and state into lists, excluding empty and cached ones."""
  flat_weights = [w for w in math.tree_flatten(weights)
                  if not (w is EMPTY_WEIGHTS or w is GET_WEIGHTS_FROM_CACHE)]
  flat_state = [s for s in math.tree_flatten(state)
                if not (s is EMPTY_STATE or s is GET_STATE_FROM_CACHE)]
  return flat_weights, flat_state
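
# Illustrative sketch (not part of the library): shows the sentinel-filtering
# behavior of flatten_weights_and_state on a small tree. The sentinel objects
# and jax.tree_util.tree_leaves below are hypothetical stand-ins for the
# module's EMPTY_WEIGHTS / GET_WEIGHTS_FROM_CACHE markers and
# math.tree_flatten, so the snippet is self-contained.
def _example_flatten_filtering():
  import jax

  empty_weights = object()    # stand-in for EMPTY_WEIGHTS
  get_from_cache = object()   # stand-in for GET_WEIGHTS_FROM_CACHE
  weights = [1.0, empty_weights, [2.0, get_from_cache]]

  # Flatten the tree, then drop sentinel leaves by identity (`is`), exactly
  # as flatten_weights_and_state does for the weights argument.
  flat = [w for w in jax.tree_util.tree_leaves(weights)
          if not (w is empty_weights or w is get_from_cache)]
  assert flat == [1.0, 2.0]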
def tree_update(self, step, grad_tree, weight_tree, slots, opt_params):
  """Assembles node-local weight and slot updates for the full layer tree.

  Args:
    step: Current step number in the training process.
    grad_tree: Gradients for the entire model, in a tree that matches the
        model's layer structure.
    weight_tree: Current weights for the entire model, in a tree that matches
        the model's layer structure.
    slots: Optimizer slots.
    opt_params: Optimizer hyperparameters (e.g. learning rate, momentum).

  Returns:
    Tuple `(weights, slots, metrics)`, where `weights` are the
    optimizer-updated weights for the whole model (in a tree matching the
    model's layer structure), `slots` are the updated optimizer slot values,
    and `metrics` is a dict mapping `'gradients_l2'` and `'weights_l2'` to
    the L2 norms of the flattened gradients and weights.
  """
  grads_flat = math.tree_flatten(grad_tree)
  grads_norm = self._l2_norm(grads_flat)
  if self._clip_grad_norm is not None:
    max_norm = self._clip_grad_norm
    grads_flat = [np.where(grads_norm < max_norm,  # pylint: disable=g-complex-comprehension
                           g,
                           g * (max_norm / grads_norm))
                  for g in grads_flat]
  weights_flat = math.tree_flatten(weight_tree)
  weights_norm = self._l2_norm(weights_flat)
  updated_pairs = [
      self._update_and_check(step, grad, weight, slot, opt_params)
      for (grad, weight, slot) in zip(grads_flat, weights_flat, slots)
  ]
  new_weights_flat, self.slots = zip(*updated_pairs)
  new_weights, _ = math.tree_unflatten(new_weights_flat, weight_tree)
  metrics = {'gradients_l2': grads_norm, 'weights_l2': weights_norm}
  return new_weights, self.slots, metrics
def tree_init(self, weight_tree):
  """Assembles node-local initializations into full-tree initialization.

  Args:
    weight_tree: Weights for an entire model, in a tree that matches the
        model's layer structure.

  Returns:
    Tuple `(slots, opt_params)`, where `slots` are the initialized optimizer
    slot values and `opt_params` are optimizer hyperparameters (e.g.,
    learning rate, momentum).
  """
  self._slots = [self.init(weight)
                 for weight in math.tree_flatten(weight_tree)]
  return (
      self._slots,
      self._init_opt_params,
  )
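
# Usage sketch under stated assumptions, not a definitive recipe: drives
# tree_init/tree_update through a concrete optimizer. It assumes an Optimizer
# subclass such as trax.optimizers.SGD, which supplies the node-local
# init/update methods these tree methods assemble; the toy weight tree and
# learning rate are illustrative.
def _example_tree_init_and_update():
  import numpy as onp
  from trax import optimizers

  # A toy weight tree shaped like a two-layer model.
  weights = ((onp.zeros((2, 3)), onp.zeros(3)), onp.zeros(2))
  grads = ((onp.ones((2, 3)), onp.ones(3)), onp.ones(2))

  opt = optimizers.SGD(0.01)
  slots, opt_params = opt.tree_init(weights)  # one slot entry per leaf
  new_weights, slots, metrics = opt.tree_update(
      0, grads, weights, slots, opt_params)
  # `metrics` carries the pre-update L2 norms computed inside tree_update.
  print(metrics['gradients_l2'], metrics['weights_l2'])
  return new_weights, slots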
def l2_norm(tree):
  """Compute the l2 norm of a pytree of arrays. Useful for weight decay."""
  leaves = math.tree_flatten(tree)
  return np.sqrt(sum(np.vdot(x, x) for x in leaves))
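
# Quick numeric check of the l2_norm computation (illustrative): with leaves
# [3.0] and [4.0], the result is sqrt(3^2 + 4^2) == 5.0. Uses
# jax.tree_util.tree_leaves as a stand-in for math.tree_flatten so the
# snippet is self-contained.
def _example_l2_norm():
  import numpy as onp
  import jax

  tree = {'w': onp.array([3.0]), 'b': (onp.array([4.0]),)}
  leaves = jax.tree_util.tree_leaves(tree)
  norm = onp.sqrt(sum(onp.vdot(x, x) for x in leaves))
  assert onp.isclose(norm, 5.0)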