def tree_update(self, step, grad_tree, weight_tree, slots, opt_params):
  """Assembles node-local weight and slot updates for the full layer tree.

  Args:
    step: Current step number in the training process.
    grad_tree: Gradients for the entire model, in a tree that matches the
        model's layer structure.
    weight_tree: Current weights for the entire model, in a tree that matches
        the model's layer structure.
    slots: Optimizer slots.
    opt_params: Optimizer hyperparameters (e.g. learning rate, momentum).

  Returns:
    Tuple `(weights, slots, metrics)`, where `weights` are the
    optimizer-updated weights for the whole model (in a tree matching the
    model's layer structure), `slots` are the updated optimizer slot values,
    and `metrics` maps metric names ('gradients_l2', 'weights_l2') to their
    current values.
  """
  grads_flat = fastmath.tree_flatten(grad_tree)
  grads_norm = self._l2_norm(grads_flat)
  if self._clip_grad_norm is not None:
    max_norm = self._clip_grad_norm
    grads_flat = [jnp.where(grads_norm < max_norm,  # pylint: disable=g-complex-comprehension
                            g,
                            g * (max_norm / grads_norm))
                  for g in grads_flat]
  weights_flat = fastmath.tree_flatten(weight_tree)
  weights_norm = self._l2_norm(weights_flat)
  updated_pairs = [
      self._update_and_check(step, grad, weight, slot, opt_params)
      for (grad, weight, slot) in zip(grads_flat, weights_flat, slots)
  ]
  new_weights_flat, self.slots = map(list, zip(*updated_pairs))
  new_weights, _ = fastmath.tree_unflatten(new_weights_flat, weight_tree)
  metrics = {'gradients_l2': grads_norm, 'weights_l2': weights_norm}
  return new_weights, self.slots, metrics
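# Clipping sketch for the branch above (illustrative numbers, not library
# code): whenever the global gradient norm exceeds the cap, every leaf is
# scaled by max_norm / grads_norm, so the clipped tree's norm equals max_norm:
#
#   grads_flat = [jnp.array([3.0]), jnp.array([4.0])]  # global L2 norm = 5.0
#   max_norm = 1.0
#   clipped = [g * (max_norm / 5.0) for g in grads_flat]
#   # clipped == [[0.6], [0.8]]; combined L2 norm == 1.0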
def flatten_weights_and_state(weights, state):
  """Flatten weights and state into lists, excluding empty and cached ones."""
  def _is_empty_weight(x):
    return (x is EMPTY_WEIGHTS or
            (isinstance(x, dict) and x == GET_WEIGHTS_FROM_CACHE))
  flat_weights = [w for w in fastmath.tree_flatten(weights)
                  if not _is_empty_weight(w)]
  def _is_empty_state(x):
    return (x is EMPTY_STATE or
            (isinstance(x, dict) and x == GET_STATE_FROM_CACHE))
  flat_state = [s for s in fastmath.tree_flatten(state)
                if not _is_empty_state(s)]
  return flat_weights, flat_state
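# Usage sketch (hypothetical weight/state trees; assumes the EMPTY_WEIGHTS and
# EMPTY_STATE sentinels defined in this module survive flattening):
#
#   weights = [w_dense, EMPTY_WEIGHTS, [w_embed]]
#   state = [EMPTY_STATE, s_counter, [EMPTY_STATE]]
#   flat_weights, flat_state = flatten_weights_and_state(weights, state)
#   # flat_weights == [w_dense, w_embed]; flat_state == [s_counter]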
def parallel_generator():
  generators = []
  for f in fastmath.tree_flatten(fns):
    generators.append(f())
  # Round-robin: yield one example from each generator, repeatedly.
  while True:
    for generator in generators:
      yield next(generator)
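# Behavior sketch (hypothetical generator factories standing in for `fns`,
# which this closure captures from the enclosing scope; illustrative only):
#
#   def ones():
#     while True:
#       yield 1
#   def twos():
#     while True:
#       yield 2
#   # With fns = (ones, twos): parallel_generator() yields 1, 2, 1, 2, ...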
def _size_of_model(model):
  """Returns the total number of elements in the model's weights."""
  def _size(x):
    try:
      return x.size
    except Exception:  # pylint: disable=broad-except
      return 0
  sizes = fastmath.nested_map(_size, model.weights)
  total_size = sum(fastmath.tree_flatten(sizes))
  return total_size
def _log_n_weights(self):
  """Logs the number of weights in the training model."""
  def _size(x):
    try:
      return x.size
    except Exception:  # pylint: disable=broad-except
      return 0
  sizes = fastmath.nested_map(_size, self._model.weights)
  total_size = sum(fastmath.tree_flatten(sizes))
  self._log_step('Total number of trainable weights: %d' % total_size)
def l2_norm(tree):
  """Returns an L2 norm computed over all elements of all tensors in `tree`.

  Args:
    tree: Tree-structured collection of tensors, e.g., model weights matching
        the model's layer structure.

  Returns:
    A scalar value computed as if all the tensors in `tree` were combined
    and flattened into a single vector, and then the L2 norm of that vector
    was calculated.
  """
  leaves = fastmath.tree_flatten(tree)
  return jnp.sqrt(sum(jnp.vdot(x, x) for x in leaves))
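# Example (illustrative values): for
#   tree = {'a': jnp.ones((2, 2)), 'b': jnp.ones((3,))}
# l2_norm(tree) == jnp.sqrt(4.0 + 3.0) ~= 2.6458, i.e. the L2 norm of all
# 7 elements concatenated into a single vector.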
def tree_init(self, weight_tree):
  """Assembles node-local initializations into full-tree initialization.

  Args:
    weight_tree: Weights for an entire model, in a tree that matches the
        model's layer structure.

  Returns:
    Tuple `(slots, opt_params)`, where `slots` are the initialized optimizer
    slot values and `opt_params` are optimizer hyperparameters (e.g.,
    learning rate, momentum).
  """
  self._slots = tuple(self.init(weight)
                      for weight in fastmath.tree_flatten(weight_tree))
  return (self._slots, self._init_opt_params)
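# Usage sketch of the training-loop hand-off (hedged; `Adam` and `model` are
# stand-ins for whatever optimizer and model the caller actually uses):
#
#   opt = Adam(learning_rate=1e-3)
#   slots, opt_params = opt.tree_init(model.weights)
#   # `slots` has one entry per flattened weight tensor.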
def _free_accelerators(self, exceptions=(), keep_constants=True):
  """Deletes all live buffers from accelerator with no safety guarantees."""
  backend = jax.lib.xla_bridge.get_backend()
  live_buffers = backend.live_buffers()
  logging.info('Deleting %d live buffers.', len(live_buffers))
  exceptions_buffers = []
  for x in fastmath.tree_flatten(exceptions):
    if hasattr(x, 'device_buffer'):  # DeviceArray
      exceptions_buffers.append(x.device_buffer)
    if hasattr(x, 'device_buffers'):  # ShardedDeviceArray
      exceptions_buffers.extend(x.device_buffers)
  for b in live_buffers:
    should_delete = True
    for e in exceptions_buffers:
      if b is e:
        should_delete = False
    if keep_constants and not b.shape:
      should_delete = False
    if should_delete:
      b.delete()
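# Usage sketch (hedged; exact call sites vary): free everything on the
# accelerator except tensors still needed after the call, e.g. freshly
# reloaded weights and state:
#
#   self._free_accelerators(exceptions=(new_weights, new_state))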
def composed_fns(generator=None):
  for f in fastmath.tree_flatten(fns):
    generator = f(generator)
  return generator
def l2_norm(tree):
  """Compute the l2 norm of a pytree of arrays. Useful for weight decay."""
  leaves = fastmath.tree_flatten(tree)
  return jnp.sqrt(sum(jnp.vdot(x, x) for x in leaves))
def Serial(*fns):  # pylint: disable=invalid-name
  """Creates an input pipeline by running all functions one after another."""
  generator = None
  for f in fastmath.tree_flatten(fns):
    generator = f(generator)
  return generator
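# Usage sketch (hypothetical pipeline stages; each stage takes a generator and
# returns a new one, and the first stage ignores the initial `None` input;
# `raw_examples` and `tokenize` are assumed helpers, not library names):
#
#   pipeline = Serial(
#       lambda _: iter(raw_examples),           # source stage
#       lambda g: (tokenize(x) for x in g),     # map stage
#       lambda g: (x for x in g if len(x) < 512),  # filter stage
#   )
#   first_example = next(pipeline)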