def unflatten_weights_and_state(
    flat_weights, flat_state, weights_and_state_signature, weights_only=False):
  """Rebuilds the weights tree (and optionally the state tree) from flat lists.

  Args:
    flat_weights: Flat sequence of weight values to place into the weights
      tree.
    flat_state: Flat sequence of state values to place into the state tree.
      Not read when `weights_only` is true.
    weights_and_state_signature: Pair `(weights_tree, state_tree)` giving the
      tree structures to unflatten into.
    weights_only: If true, skip the state entirely and return `None` for it.

  Returns:
    Tuple `(weights, state)`; `state` is `None` when `weights_only` is true.
  """
  weights_signature, state_signature = weights_and_state_signature

  # The sentinel markers are handed to `tree_unflatten` as `copy_from_tree`.
  # NOTE(review): presumably such entries are copied from the signature tree
  # instead of consuming a flat value -- confirm in fastmath.tree_unflatten.
  restored_weights, _ = fastmath.tree_unflatten(
      flat_weights,
      weights_signature,
      copy_from_tree=[EMPTY_WEIGHTS, GET_WEIGHTS_FROM_CACHE])

  if weights_only:
    return restored_weights, None

  restored_state, _ = fastmath.tree_unflatten(
      flat_state,
      state_signature,
      copy_from_tree=[EMPTY_STATE, GET_STATE_FROM_CACHE])
  return restored_weights, restored_state
def tree_update(self, step, grad_tree, weight_tree, slots, opt_params):
  """Assembles node-local weight and slot updates for the full layer tree.

  Side effect: `self.slots` is overwritten with the updated slot list.

  Args:
    step: Current step number in the training process.
    grad_tree: Gradients for the entire model, in a tree that matches the
      model's layer structure.
    weight_tree: Current weights for the entire model, in a tree that matches
      the model's layer structure.
    slots: Optimizer slots, iterated in lockstep with the flattened gradients
      and weights.
    opt_params: Optimizer hyperparameters (e.g. learning rate, momentum).

  Returns:
    Tuple `(weights, slots, metrics)`, where `weights` are the
    optimizer-updated weights for the whole model (in a tree matching the
    model's layer structure), `slots` is the list of updated optimizer slot
    values (also stored on `self.slots`), and `metrics` is a dict with keys
    `'gradients_l2'` and `'weights_l2'` holding the values returned by
    `self._l2_norm` for the flattened gradients (measured before any
    clipping) and for the flattened pre-update weights.
  """
  grads_flat = fastmath.tree_flatten(grad_tree)
  # Norm is taken before clipping; this pre-clip value is what gets reported.
  grads_norm = self._l2_norm(grads_flat)

  max_norm = self._clip_grad_norm
  if max_norm is not None:
    # Rescale every gradient by the same factor when the overall norm is not
    # below the threshold; otherwise leave gradients untouched.
    clip_scale = max_norm / grads_norm
    grads_flat = [jnp.where(grads_norm < max_norm, g, g * clip_scale)
                  for g in grads_flat]

  weights_flat = fastmath.tree_flatten(weight_tree)
  weights_norm = self._l2_norm(weights_flat)

  updated_pairs = [
      self._update_and_check(step, grad, weight, slot, opt_params)
      for (grad, weight, slot) in zip(grads_flat, weights_flat, slots)
  ]

  if updated_pairs:
    new_weights_flat, new_slots = map(list, zip(*updated_pairs))
  else:
    # `zip()` with no arguments yields nothing, so the two-way unpacking
    # above would raise ValueError when there are no leaves to update.
    new_weights_flat, new_slots = [], []
  self.slots = new_slots

  new_weights, _ = fastmath.tree_unflatten(new_weights_flat, weight_tree)
  metrics = {'gradients_l2': grads_norm, 'weights_l2': weights_norm}
  return new_weights, self.slots, metrics