Example #1
  def tree_update(self, step, grad_tree, weight_tree, slots, opt_params):
    """Assembles node-local weight and slot updates for the full layer tree.

    Args:
      step: Current step number in the training process.
      grad_tree: Gradients for the entire model, in a tree that matches the
          model's layer structure.
      weight_tree: Current weights for the entire model, in a tree that matches
          the model's layer structure.
      slots: Optimizer slots.
      opt_params: Optimizer hyperparameters (e.g. learning rate, momentum).

    Returns:
      Tuple `(weights, slots, metrics)`, where `weights` are the
      optimizer-updated weights for the whole model (in a tree matching the
      model's layer structure), `slots` are the updated optimizer slot values,
      and `metrics` is a dict with the pre-update L2 norms of the gradients
      and weights.
    """
    grads_flat = fastmath.tree_flatten(grad_tree)
    grads_norm = self._l2_norm(grads_flat)
    if self._clip_grad_norm is not None:
      max_norm = self._clip_grad_norm
      grads_flat = [jnp.where(grads_norm < max_norm,  # pylint: disable=g-complex-comprehension
                              g,
                              g * (max_norm / grads_norm))
                    for g in grads_flat]
    weights_flat = fastmath.tree_flatten(weight_tree)
    weights_norm = self._l2_norm(weights_flat)
    updated_pairs = [
        self._update_and_check(step, grad, weight, slot, opt_params)
        for (grad, weight, slot) in zip(grads_flat, weights_flat, slots)
    ]
    new_weights_flat, self.slots = map(list, zip(*updated_pairs))
    new_weights, _ = fastmath.tree_unflatten(new_weights_flat, weight_tree)
    metrics = {'gradients_l2': grads_norm, 'weights_l2': weights_norm}
    return new_weights, self.slots, metrics
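The clipping step above scales every gradient by the same factor whenever the combined L2 norm exceeds the threshold. A minimal standalone sketch of just that step, using jax.numpy and a hypothetical max_norm, for illustration only:

import jax.numpy as jnp

grads = [jnp.array([3.0, 0.0]), jnp.array([0.0, 4.0])]
# Global norm over all gradients combined: sqrt(3^2 + 4^2) = 5.0.
global_norm = jnp.sqrt(sum(jnp.vdot(g, g) for g in grads))
max_norm = 1.0  # hypothetical clipping threshold
# Leave gradients untouched below the threshold; otherwise rescale them all
# by max_norm / global_norm so the combined norm becomes exactly max_norm.
clipped = [jnp.where(global_norm < max_norm, g, g * (max_norm / global_norm))
           for g in grads]
print(clipped)  # [[0.6, 0.], [0., 0.8]]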
Example #2
def flatten_weights_and_state(weights, state):
  """Flatten weights and state into lists, excluding empty and cached ones."""
  def _is_empty_weight(x):
    return (x is EMPTY_WEIGHTS or
            (isinstance(x, dict) and x == GET_WEIGHTS_FROM_CACHE))
  flat_weights = [w for w in fastmath.tree_flatten(weights)
                  if not _is_empty_weight(w)]
  def _is_empty_state(x):
    return (x is EMPTY_STATE or
            (isinstance(x, dict) and x == GET_STATE_FROM_CACHE))
  flat_state = [s for s in fastmath.tree_flatten(state)
                if not _is_empty_state(s)]
  return flat_weights, flat_state
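The same filter-while-flattening pattern can be reproduced standalone. The sketch below uses jax.tree_util in place of fastmath's tree helpers and a hypothetical EMPTY sentinel standing in for EMPTY_WEIGHTS:

import jax

EMPTY = object()  # hypothetical stand-in for the EMPTY_WEIGHTS sentinel

# Nested weight structure where one layer has no weights of its own.
weights = [{'w': 1.0, 'b': 2.0}, EMPTY, [3.0]]
flat = [w for w in jax.tree_util.tree_leaves(weights) if w is not EMPTY]
print(flat)  # [2.0, 1.0, 3.0] (dict leaves come out in sorted-key order)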
Example #3
 def parallel_generator():
     generators = []
     for f in fastmath.tree_flatten(fns):
         generators.append(f())
     while True:
         for generator in generators:
             yield next(generator)
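A standalone sketch of the same round-robin interleaving, with two hypothetical generator factories in place of a tree of fns:

import itertools

def counter_from(start):
  # Hypothetical generator factory; each call starts a fresh stream.
  return lambda: itertools.count(start)

fns = [counter_from(0), counter_from(100)]
generators = [f() for f in fns]

def interleave():
  while True:
    for generator in generators:
      yield next(generator)

stream = interleave()
print([next(stream) for _ in range(6)])  # [0, 100, 1, 101, 2, 102]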
Example #4
def _size_of_model(model):
  def _size(x):
    try:
      return x.size
    except Exception:  # pylint: disable=broad-except
      return 0
  sizes = fastmath.nested_map(_size, model.weights)
  total_size = sum(fastmath.tree_flatten(sizes))
  return total_size
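The weight-counting pattern can be checked standalone; this sketch uses jax.tree_util as a stand-in for fastmath's nested_map / tree_flatten, with hypothetical weight shapes:

import numpy as np
import jax

weights = {'layer1': {'w': np.zeros((10, 4)), 'b': np.zeros(4)},
           'layer2': np.zeros((4, 2))}
# Map each leaf to its element count, then sum over the flattened tree.
sizes = jax.tree_util.tree_map(lambda x: getattr(x, 'size', 0), weights)
total_size = sum(jax.tree_util.tree_leaves(sizes))
print(total_size)  # 10*4 + 4 + 4*2 = 52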
Example #5
 def _log_n_weights(self):
   """Logs the number of weights in the training model."""
   def _size(x):
     try:
       return x.size
     except Exception:  # pylint: disable=broad-except
       return 0
   sizes = fastmath.nested_map(_size, self._model.weights)
   total_size = sum(fastmath.tree_flatten(sizes))
   self._log_step('Total number of trainable weights: %d' % total_size)
Example #6
def l2_norm(tree):
  """Returns an L2 norm computed over all elements of all tensors in `tree`.

  Args:
    tree: Tree-structured collection of tensors, e.g., model weights matching
        the model's layer structure.

  Returns:
    A scalar value computed as if all the tensors in `tree` were combined
    and flattened into a single vector, and then the L2 norm of that vector
    was calculated.
  """
  leaves = fastmath.tree_flatten(tree)
  return jnp.sqrt(sum(jnp.vdot(x, x) for x in leaves))
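For illustration, the same computation with jax.tree_util standing in for fastmath.tree_flatten and hypothetical weight values:

import jax
import jax.numpy as jnp

def l2_norm_sketch(tree):
  # Same idea as above: flatten the tree, then take the norm of all elements.
  leaves = jax.tree_util.tree_leaves(tree)
  return jnp.sqrt(sum(jnp.vdot(x, x) for x in leaves))

weights = {'w': 3.0 * jnp.ones((2, 2)), 'b': jnp.array([4.0, 0.0, 0.0])}
print(l2_norm_sketch(weights))  # sqrt(4*9 + 16) = sqrt(52) ~= 7.211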
Example #7
  def tree_init(self, weight_tree):
    """Assembles node-local initializations into full-tree initialization.

    Args:
      weight_tree: Weights for an entire model, in a tree that matches the
          model's layer structure.

    Returns:
      Tuple `(slots, opt_params)`, where `slots` are the initialized optimizer
      slot values and `opt_params` are optimizer hyperparameters (e.g.,
      learning rate, momentum).
    """
    self._slots = tuple(
        self.init(weight) for weight in fastmath.tree_flatten(weight_tree))
    return (self._slots, self._init_opt_params)
Example #8
 def _free_accelerators(self, exceptions=(), keep_constants=True):
     """Deletes all live buffers from accelerator with no safety guarantees."""
     backend = jax.lib.xla_bridge.get_backend()
     live_buffers = backend.live_buffers()
     logging.info('Deleting %d live buffers.', len(live_buffers))
     exceptions_buffers = []
     for x in fastmath.tree_flatten(exceptions):
         if hasattr(x, 'device_buffer'):  # DeviceArray
             exceptions_buffers.append(x.device_buffer)
         if hasattr(x, 'device_buffers'):  # ShardedDeviceArray
             exceptions_buffers.extend(x.device_buffers)
     for b in live_buffers:
         should_delete = True
         for e in exceptions_buffers:
             if b is e:
                 should_delete = False
         if keep_constants and not b.shape:
             should_delete = False
         if should_delete:
             b.delete()
Example #9
 def composed_fns(generator=None):
     for f in fastmath.tree_flatten(fns):
         generator = f(generator)
     return generator
Example #10
def l2_norm(tree):
  """Compute the l2 norm of a pytree of arrays. Useful for weight decay."""
  leaves = fastmath.tree_flatten(tree)
  return jnp.sqrt(sum(jnp.vdot(x, x) for x in leaves))
Example #11
def Serial(*fns):  # pylint: disable=invalid-name
    """Creates an input pipeline by running all functions one after another."""
    generator = None
    for f in fastmath.tree_flatten(fns):
        generator = f(generator)
    return generator
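A standalone sketch of the same composition pattern, with two hypothetical pipeline stages; each stage takes the previous generator and returns a new one:

def source(generator):
  # First stage ignores its (None) input and produces raw items.
  del generator
  return iter(range(5))

def double(generator):
  # Later stages transform whatever the previous stage yields.
  return (2 * x for x in generator)

generator = None
for f in (source, double):
  generator = f(generator)
print(list(generator))  # [0, 2, 4, 6, 8]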