Пример #1
0
def _accelerate(f, n_devices):
  """Returns an accelerated version of ``f`` running on ``n_devices``.

  Args:
    f: Function to compile.
    n_devices: Number of accelerator devices. 0 means "no accelerators".

  Returns:
    ``f`` JIT-compiled for the first CPU device when ``n_devices`` is 0,
    JIT-compiled for the default device when ``n_devices`` is 1, and
    ``pmap``-ed over a ``'batch'`` axis for any other value.
  """
  if n_devices not in (0, 1):
    return fastmath.pmap(f, axis_name='batch')
  if n_devices == 1:
    return fastmath.jit(f)
  # No accelerators: pin compilation and execution to the first CPU device.
  first_cpu = jax.devices('cpu')[0]
  return fastmath.jit(f, device=first_cpu)
Пример #2
0
  def __init__(self, model, task, eval_model=None, eval_task=None,
               output_dir=None, checkpoint_at=None, eval_at=None):
    """Configures a training `Loop`, including a random initialization.

    Args:
      model: Trax layer, representing the core model to be trained. Loss
          functions and eval functions (a.k.a. metrics) are considered to be
          outside the core model, taking core model output and data labels as
          their two inputs.
      task: TrainTask instance, which defines the training data, loss function,
          and optimizer to be used in this training loop.
      eval_model: Optional Trax layer, representing model used for evaluation,
        e.g., with dropout turned off. If None, the training model (model)
        will be used.
      eval_task: EvalTask instance or None. If None, don't do any evals.
      output_dir: Path telling where to save outputs (evals and checkpoints).
          Can be None if both `eval_task` and `checkpoint_at` are None.
      checkpoint_at: Function (integer --> boolean) telling, for step n, whether
          that step should have its checkpoint saved. If None, the default is
          periodic checkpointing at `task.n_steps_per_checkpoint`.
      eval_at: Function (integer --> boolean) that says, for training step n,
          whether that step should run evals. If None, run when checkpointing,
          i.e. follow the effective `checkpoint_at` schedule.
    """
    self._task = task
    self._model = model
    self._model_in_training = tl.Serial(model, task.loss_layer)
    self._eval_model = model if eval_model is None else eval_model
    self._eval_task = eval_task
    self._output_dir = os.path.expanduser(output_dir) if output_dir else None
    default_fn = _at_step_1_and_periodically_at(task.n_steps_per_checkpoint)
    self._checkpoint_at = checkpoint_at or default_fn
    if eval_task is None:
      self._eval_at = _never
    else:
      # As documented, evals default to running whenever checkpointing runs,
      # so a caller-supplied `checkpoint_at` also drives evals.
      self._eval_at = eval_at or self._checkpoint_at
    self._step = 0

    batch_signature = shapes.signature(task.sample_batch)
    self._batch_signature = batch_signature
    # Initialize the model and the optimizer; the return values (model
    # weights/state, optimizer slots/params) are not needed here since they're
    # available from the model and optimizer objects.
    self._model_in_training.init(batch_signature)
    task.optimizer.tree_init(self._model_in_training.weights)

    self._gradients_and_state_fn = (
        fastmath.jit(fastmath.grad(self._model_in_training.pure_fn,
                                   argnums=1,  # arg1 of pure_fn: weights
                                   has_aux=True)))  # return (gradients, state)

    if eval_task is not None:
      model_with_metrics = _model_with_metrics(self._eval_model, eval_task)
      self._eval_weights = model_with_metrics.weights[1]  # just the eval part
      self._eval_state = model_with_metrics.state[1]  # just the eval part
      self._metrics_fn = fastmath.jit(model_with_metrics.pure_fn)
Пример #3
0
def _jit_compute_loss_fn(predict_fn, loss_fn, n_devices, jit=True):
    """Returns a (JIT-compiled) function that computes the loss for one step.

    Args:
      predict_fn: Model passed through to ``loss_fn``.
      loss_fn: Called as ``loss_fn(opt_state[0], batch, predict_fn, state, rng)``
        and expected to return ``(loss_value, new_state)``.
      n_devices: Number of devices. With exactly 1 device the returned function
        is optionally JIT-compiled; otherwise it is ``pmap``-ed over a
        ``'batch'`` axis and the batch is reshaped per device before the call.
      jit: Whether to JIT-compile in the single-device case (unused otherwise).

    Returns:
      Function ``(opt_state, batch, state, rng) -> (loss, state, next_rng)``.
    """
    if n_devices != 1:
        @functools.partial(fastmath.pmap, axis_name='batch')
        def per_device_loss(opt_state, batch, state, rng):
            """Multi-device version of the single-device loss computation."""
            # Every tensor here is assumed to lead with an n_devices axis.
            loss_rng, carry_rng = jax_random.split(rng)
            loss_value, new_state = loss_fn(opt_state[0], batch, predict_fn,
                                            state, loss_rng)
            return loss_value, new_state, carry_rng

        def sharded_loss(opt_state, batch, state, rng):
            sharded_batch = _reshape_by_device(batch, n_devices)
            return per_device_loss(opt_state, sharded_batch, state, rng)

        return sharded_loss

    # TODO(lukaszkaiser): remove branch when not needed.
    def single_device_loss(opt_state, batch, state, rng):
        loss_rng, carry_rng = jax_random.split(rng[0])
        loss_value, new_state = loss_fn(opt_state[0], batch, predict_fn, state,
                                        loss_rng)
        return loss_value, new_state, [carry_rng]

    if not jit:
        return single_device_loss
    return fastmath.jit(single_device_loss)
Пример #4
0
 def _pjit(self, f, donate_argnums=()):
     """Compiles ``f`` for this object's devices.

     Uses ``jit`` when ``self._n_devices`` is exactly 1 and ``pmap`` over a
     ``'batch'`` axis otherwise; ``donate_argnums`` is forwarded in both cases.
     """
     if self._n_devices != 1:
         return fastmath.pmap(f,
                              axis_name='batch',
                              donate_argnums=donate_argnums)
     return fastmath.jit(f, donate_argnums=donate_argnums)
Пример #5
0
def _accelerate_update_fn(forward_and_backward_fn,
                          optimizer,
                          n_devices,
                          accelerate=True,
                          adasum=False):
    """Accelerates the given forward_and_backward_fn function.

    Args:
      forward_and_backward_fn: Called as ``(batch, weights, state, rng)`` and
        expected to return ``((loss, state), gradients)``.
      optimizer: Optimizer whose ``tree_update`` applies the gradients
        (called with ``store_slots=False``).
      n_devices: Number of local devices. Exactly 1 yields a single-device
        function; any other value must be greater than 1 and yields a
        ``pmap``-ed function.
      accelerate: Single-device case only: whether to JIT-compile.
      adasum: Multi-device case only: whether gradients are combined with
        adaptive summation by ``_average_multidevice_gradients``.

    Returns:
      Function ``(weights_and_slots, step, opt_params, batch, state, rng)``
      returning ``((weights, slots), state, stats)`` where ``stats['loss']``
      holds the loss.
    """

    def apply_optimizer(step, gradients, weights, slots, opt_params, loss):
        """Runs the optimizer update and records the loss in its stats."""
        new_weights, new_slots, stats = optimizer.tree_update(
            step, gradients, weights, slots, opt_params, store_slots=False)
        stats['loss'] = loss
        return new_weights, new_slots, stats

    if n_devices == 1:
        def single_device_update_fn(weights_and_slots, step, opt_params, batch,
                                    state, rng):
            step = jnp.array(step, dtype=jnp.int32)  # Needed in TFNP backend.
            weights, slots = weights_and_slots
            (loss, new_state), gradients = forward_and_backward_fn(
                batch, weights, state, rng)
            new_weights, new_slots, stats = apply_optimizer(
                step, gradients, weights, slots, opt_params, loss)
            return (new_weights, new_slots), new_state, stats

        if not accelerate:
            return single_device_update_fn
        # TODO(afrozm): Find out the status of buffer donation on GPUs, then do
        #  donate_argnums=(0,).
        return fastmath.jit(single_device_update_fn)

    # More than one device (core), i.e. all of TPU configurations etc.
    assert n_devices > 1, f'{n_devices} should be greater than 1.'

    @functools.partial(fastmath.pmap, axis_name='batch', donate_argnums=(0,))
    def per_device_update_fn(weights_and_slots, step, opt_params, batch,
                             state, rng):
        # Every argument carries a leading n_devices axis here.
        weights, slots = weights_and_slots
        (loss, new_state), gradients = forward_and_backward_fn(
            batch, weights, state, rng)
        gradients = _average_multidevice_gradients(gradients, adasum=adasum)
        new_weights, new_slots, stats = apply_optimizer(
            step, gradients, weights, slots, opt_params, loss)
        return (new_weights, new_slots), new_state, stats

    def multi_device_update_fn(weights_and_slots, step, opt_params, batch,
                               state, rng):
        # The scalar step must be replicated along the leading device axis.
        replicated_step = jnp.repeat(step, n_devices)
        return per_device_update_fn(weights_and_slots, replicated_step,
                                    opt_params, batch, state, rng)

    return multi_device_update_fn
Пример #6
0
def _accelerate_update_fn(forward_and_backward_fn,
                          optimizer,
                          n_devices,
                          accelerate=True):
    """Accelerate the given forward_and_backward_fn function.

    Args:
      forward_and_backward_fn: Called as ``(batch, weights, state, rng)`` and
        expected to return ``((loss, state), gradients)``.
      optimizer: Optimizer whose ``tree_update`` applies the gradients.
      n_devices: Number of devices on this host. Exactly 1 yields a
        single-device function; any other value must be greater than 1 and
        yields a ``pmap``-ed function that averages gradients over all
        devices of all hosts.
      accelerate: Single-device case only: whether to JIT-compile.

    Returns:
      Function ``(weights_and_slots, step, opt_params, batch, state, rng)``
      returning ``((weights, slots), state, stats)`` where ``stats['loss']``
      holds the loss.
    """

    def apply_optimizer(step, gradients, weights, slots, opt_params, loss):
        """Runs the optimizer update and records the loss in its stats."""
        new_weights, new_slots, stats = optimizer.tree_update(
            step, gradients, weights, slots, opt_params)
        stats['loss'] = loss
        return new_weights, new_slots, stats

    if n_devices == 1:
        def single_device_update_fn(weights_and_slots, step, opt_params, batch,
                                    state, rng):
            step = jnp.array(step, dtype=jnp.int32)  # Needed in TFNP backend.
            weights, slots = weights_and_slots
            (loss, new_state), gradients = forward_and_backward_fn(
                batch, weights, state, rng)
            new_weights, new_slots, stats = apply_optimizer(
                step, gradients, weights, slots, opt_params, loss)
            return (new_weights, new_slots), new_state, stats

        if not accelerate:
            return single_device_update_fn
        # TODO(afrozm): Find out the status of buffer donation on GPUs, then do
        #  donate_argnums=(0,).
        return fastmath.jit(single_device_update_fn)

    # More than one device (core), i.e. all of TPU configurations etc.
    assert n_devices > 1, f'{n_devices} should be greater than 1.'

    def mean_over_all_devices(gradients):
        """Averages gradients across every device in the 'batch' axis.

        ``n_devices`` only counts devices on *this* host, so the divisor is
        obtained with psum(1.0), which spans all hosts.
        """
        summed = fastmath.psum(gradients, 'batch')
        n_devices_total = fastmath.psum(jnp.array(1.0), 'batch')
        return jax.tree_util.tree_map(lambda g: g / n_devices_total, summed)

    @functools.partial(fastmath.pmap, axis_name='batch', donate_argnums=(0,))
    def per_device_update_fn(weights_and_slots, step, opt_params, batch,
                             state, rng):
        # Every argument carries a leading n_devices axis here.
        weights, slots = weights_and_slots
        (loss, new_state), gradients = forward_and_backward_fn(
            batch, weights, state, rng)
        gradients = mean_over_all_devices(gradients)
        new_weights, new_slots, stats = apply_optimizer(
            step, gradients, weights, slots, opt_params, loss)
        return (new_weights, new_slots), new_state, stats

    def multi_device_update_fn(weights_and_slots, step, opt_params, batch,
                               state, rng):
        # The scalar step must be replicated along the leading device axis.
        replicated_step = jnp.repeat(step, n_devices)
        return per_device_update_fn(weights_and_slots, replicated_step,
                                    opt_params, batch, state, rng)

    return multi_device_update_fn
Пример #7
0
def _jit_update_fn(predict_fn, loss_fn, optimizer, n_devices, jit=True):
    """Returns a (JIT-compiled) function that computes updates for one step.

    Args:
      predict_fn: Model layer; combined with ``loss_fn`` via ``tl.Serial``.
      loss_fn: Loss layer applied after ``predict_fn``.
      optimizer: Optimizer whose ``tree_update`` applies the gradients.
      n_devices: Number of devices on this host. Exactly 1 gives a
        single-device function (JIT-compiled when ``jit`` is True); any other
        value gives a ``pmap``-ed function that averages gradients over every
        device of every host.
      jit: Single-device case only: whether to JIT-compile.

    Returns:
      Function ``(weights_and_slots, i, opt_params, batch, state, rng)``
      returning ``((new_weights, new_slots), stats, state, next_rng)``.
    """
    model_and_loss = tl.Serial(predict_fn, loss_fn)

    def model_and_loss_call(weights, batch, state, rng):
        # Weights come first because gradients are taken wrt. argument 0.
        res = model_and_loss(batch, weights=weights, state=state, rng=rng)
        return res, model_and_loss.state

    # Gradient wrt. the weights; the model state rides along as aux output.
    grad_fn = fastmath.grad(model_and_loss_call, has_aux=True)

    if n_devices == 1:  # TODO(lukaszkaiser): remove branch when not needed.
        def single_update(weights_and_slots, i, opt_params, batch, state, rng):
            weights, slots = weights_and_slots
            step_rng, next_rng = jax_random.split(rng[0])
            grads, new_state = grad_fn(weights, batch, state, step_rng)
            new_weights, new_slots, stats = optimizer.tree_update(
                i, grads, weights, slots, opt_params)
            return (new_weights, new_slots), stats, new_state, [next_rng]

        # TODO(lukaszkaiser): donate_argnums=(0,) when XLA supports it on GPU
        return fastmath.jit(single_update) if jit else single_update

    def average_over_all_devices(grads):
        """Averages grads over every device of every host (e.g. a TPU pod).

        psum(1.0) is used instead of `n_devices` because `n_devices` only
        counts devices on this host, while psum spans all hosts.
        """
        summed = fastmath.psum(grads, 'batch')
        n_devices_total = fastmath.psum(np.array(1.0), 'batch')
        return jax.tree_util.tree_map(lambda g: g / n_devices_total, summed)

    @functools.partial(fastmath.pmap,
                       axis_name='batch')  # donate_argnums=(0,))
    def mapped_update(weights_and_slots, i, opt_params, batch, state, rng):
        """This is a multi-device version of the update function above."""
        # Every tensor is assumed to lead with an n_devices axis.
        weights, slots = weights_and_slots
        step_rng, next_rng = jax_random.split(rng)
        grads, new_state = grad_fn(weights, batch, state, step_rng)
        grads = average_over_all_devices(grads)
        new_weights, new_slots, stats = optimizer.tree_update(
            i, grads, weights, slots, opt_params)
        return (new_weights, new_slots), stats, new_state, next_rng

    def update(weights_and_slots, i, opt_params, batch, state, rng):
        # Replicate the step counter along the leading device axis.
        replicated_i = np.repeat(i, n_devices)
        return mapped_update(weights_and_slots, replicated_i, opt_params,
                             batch, state, rng)

    return update
Пример #8
0
 def _pjit(self, f, memory_key=None, donate_argnums=()):
   """JIT f if 1 device is available and pmap if more are available."""
   should_memoize = self._jit_memory is not None and memory_key is not None
   if (should_memoize and memory_key in self._jit_memory):
     logging.info('Found JITed function in memory for: %s', memory_key)
     return self._jit_memory[memory_key]
   if self._n_devices == 1:
     res = fastmath.jit(f, donate_argnums=donate_argnums)
   else:
     res = fastmath.pmap(f, axis_name='batch', donate_argnums=donate_argnums)
   if should_memoize:
     self._jit_memory[memory_key] = res
   return res
Пример #9
0
 def _pjit(self, f):
   """Compiles ``f``: ``jit`` for exactly one device, ``pmap`` otherwise.

   The ``pmap`` variant maps over an axis named ``'batch'``.
   """
   if self._n_devices != 1:
     return fastmath.pmap(f, axis_name='batch')
   return fastmath.jit(f)
Пример #10
0
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 checkpoints_at=None,
                 should_save_checkpoints=True,
                 should_write_summaries=True,
                 metrics=None,
                 checkpoint_highest=None,
                 checkpoint_lowest=None):
        """Sets up models, optimizer state, metrics and compiled step functions.

        Args:
          model: Callable taking ``mode=`` (``'train'`` or ``'eval'``) and
            returning a model layer.
          loss_fn: Loss layer appended after the training model.
          optimizer: Optimizer factory, called with ``learning_rate=``.
          lr_schedule: Learning-rate schedule; stored for later use.
          inputs: Object exposing ``example_shape_dtype`` (a pair of shapes and
            dtypes), or a callable returning such an object (e.g. via gin).
          output_dir: Directory for checkpoints and summaries. When falsy,
            neither checkpoints nor summaries are written. Passed to
            ``reset``.
          random_seed: Forwarded to ``training.init_host_and_devices``.
          n_devices: Forwarded to ``training.init_host_and_devices``.
          checkpoints_at: Steps at which to save checkpoints; defaults to
            ``[]``.
          should_save_checkpoints: Whether to save checkpoints; only honored on
            the chief host.
          should_write_summaries: Whether to write summaries.
          metrics: Dict of metric name -> metric layer; defaults to
            ``_DEFAULT_METRICS``.
          checkpoint_highest: Stored as-is; presumably a metric name whose
            highest value triggers a checkpoint -- TODO confirm in caller.
          checkpoint_lowest: Stored as-is; presumably the lowest-value
            counterpart of ``checkpoint_highest`` -- TODO confirm.
        """
        self._is_chief, _, self._n_devices, rng = (
            training.init_host_and_devices(n_devices, random_seed))
        self._should_save_checkpoints = should_save_checkpoints and self._is_chief
        self._checkpoints_at = checkpoints_at if checkpoints_at is not None else []
        self._should_write_summaries = should_write_summaries
        if not output_dir:
            self._should_save_checkpoints = False
            self._should_write_summaries = False
        self._checkpoint_highest = checkpoint_highest
        self._checkpoint_lowest = checkpoint_lowest
        self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
        # Inputs is either an Inputs instance or a function that returns it
        # (e.g., when passed through gin, in which case it is called here).
        self._inputs = inputs() if callable(inputs) else inputs
        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Setup the model.
        model_train = model(mode='train')
        model_predict_eval = model(mode='eval')
        self._model_with_loss = tl.Serial(model_train, loss_fn)

        # Setup state.
        rng, init_rng = jax_random.split(rng)
        self._rngs = np.stack(jax_random.split(rng, self._n_devices))
        # Build the input signature once; it initializes both the training
        # model (below) and the eval model/metrics. The local names avoid
        # shadowing the `shapes` module.
        example_shapes, example_dtypes = self._inputs.example_shape_dtype
        input_signature = tuple(
            ShapeDtype(s, d) for (s, d) in zip(example_shapes, example_dtypes))

        def new_opt_state_and_model_state(rng):
            """Returns optimizer and model states suitable for training a model."""
            weights, state = self._model_with_loss.init(input_signature,
                                                        rng=rng)
            (slots, opt_params) = opt.tree_init(weights)
            return (OptState(weights, slots, opt_params), state)

        if fastmath.is_backend(fastmath.Backend.JAX):
            # JIT parameter initialization to avoid memory fragmentation
            new_opt_state_and_model_state = (
                fastmath.jit(new_opt_state_and_model_state))
        self._new_opt_state_and_model_state = (
            lambda: new_opt_state_and_model_state(init_rng))

        # Arrange and initialize metrics layers.
        self._metrics = sorted(self._metrics_dict.keys())
        metrics_layers = [self._metrics_dict[m] for m in self._metrics]
        metrics_in_parallel = tl.Branch(*metrics_layers)
        metrics_in_parallel.rng = init_rng
        model_predict_eval.init(input_signature)
        self._input_signature = input_signature
        output_signature = model_predict_eval.output_signature(input_signature)
        m_weights, m_state = metrics_in_parallel.init(output_signature)
        self._metrics_weights = self._for_n_devices(m_weights)
        self._metrics_state = self._for_n_devices(m_state)

        # Jit model_predict and update so they're fast.
        self._jit_eval = _jit_predict_fn(model_predict_eval,
                                         metrics_in_parallel, self._n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             self._n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        self._lr_schedule = lr_schedule

        # Those fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._opt_state = None
        self._step = None
        self._model_state = None
        self.reset(output_dir)
Пример #11
0
    def __init__(self,
                 blocks,
                 loss_layer,
                 optimizer_fn,
                 n_devices=None,
                 memoize_jit=True,
                 free_accelerators_on_step=False,
                 adasum=False):
        """Creates a ReversibleSerialTrainer and the needed optimizers.

        This trainer performs updates equivalent to using the default Trainer
        on::

          tl.Serial(blocks + [loss_layer]).

        It is more memory-efficient though since weights are stored on CPU and
        only sent to accelerator layer-by-layer. Blocks are pairs consisting of
        a list of standard (arbitrary) layers and a list of reversible layers
        which help save memory thanks to being reversible.

        Args:
          blocks: A list of pairs of lists of standard and reversible layers.
          loss_layer: The final layer of the model; it can have trainable
            weights but should end with a loss: it is required to produce a
            scalar output.
          optimizer_fn: A function to create the optimizer, e.g.,
            `optimizers.Adam`.
          n_devices: An optional integer, number of accelerator devices to use;
            by default, all available accelerators will be used.
          memoize_jit: Whether to memoize JITed functions; this significantly
            speeds up XLA compilation of larger models, but it uses
            `repr(layer)` as keys to memoize so it could fail if two layers
            with different functionality had the same string representation.
            We have not encountered such case yet so this is turned on by
            default, but consider turning it off or reviewing your model if you
            use custom layers and encounter a problem.
          free_accelerators_on_step: If true, frees memory on accelerators when
            starting a step. All layers and arguments must be on host for that,
            otherwise it can lead to failures. Can prevent memory fragmentation.
          adasum: if True, use adaptive summation to gather multi-device
            gradients.
        """
        self._blocks = [(tl.Serial(std), rev) for (std, rev) in blocks]
        self._loss_layer = loss_layer
        self._optimizer_fn = optimizer_fn
        self._n_devices = n_devices or fastmath.local_device_count()
        self._adasum = adasum
        # One slot per standard layer and per reversible layer, plus the loss.
        self._n_layers = 1 + sum(
            len(rev_layers) + 1 for (_, rev_layers) in self._blocks)
        self._n_steps_per_log = 100  # Log layers and stats every 100 steps.
        self._n_async_layers = 1  # How many layers to run asynchronously.
        self._jit_memory = {} if memoize_jit else None
        self._do_free = free_accelerators_on_step
        self._jit_per_device_rngs = fastmath.jit(self._per_device_rngs,
                                                 backend='cpu')

        # Accelerated (pmap-ed / jit-ed) forward pure_fn for every layer.
        def accelerate_forward(layer):
            return self._pjit(layer.pure_fn, f'fwd {repr(layer)}')

        self._accelerated_layer_fns = fastmath.nested_map(accelerate_forward,
                                                          self._blocks)

        # Per-layer optimizers; slots live on CPU, opt_params are replicated.
        def _make_optimizer(layer):
            opt = optimizer_fn()
            opt.tree_init(layer.weights)
            opt.slots = tl.on_cpu(opt.slots)
            return opt

        self._optimizers = fastmath.nested_map(_make_optimizer, self._blocks)
        self._replicated_opt_params = fastmath.nested_map(
            lambda opt: self._replicate_cpu(opt.opt_params), self._optimizers)

        self._loss_opt = _make_optimizer(loss_layer)
        self._replicated_loss_opt_params = self._replicate_cpu(
            self._loss_opt.opt_params)

        # FBO = "Forward + Backward + Optimizer update" for every layer.
        # Reversible layers get a reverse_and_fbo function that also reverses.
        self._fbos = []
        for (std_layer, rev_layers), (std_opt, rev_opts) in zip(
                self._blocks, self._optimizers):
            std_fbo = _fbo_with_layer_and_opt(std_layer,
                                              std_opt,
                                              self._n_devices,
                                              adasum=self._adasum)
            # Donated args are (outputs, weights, grads): weights and grads
            # are immediately replaced, and outputs of reversible layers are
            # never used again.
            rev_and_fbos = [
                self._pjit(
                    _reverse_and_fbo_with_layer_and_opt(
                        layer, opt, self._n_devices, self._adasum),
                    f'rev+bwd {repr(layer)}',
                    donate_argnums=(0, 1, 2))
                for layer, opt in zip(rev_layers, rev_opts)
            ]
            # Standard-layer inputs may be reused as outputs for the
            # reversible block below, so only weights and grads are donated.
            jit_std_fbo = self._pjit(std_fbo,
                                     f'bwd {repr(std_layer)}',
                                     donate_argnums=(1, 2))
            self._fbos.append((jit_std_fbo, rev_and_fbos))

        loss_fbo = _fbo_with_layer_and_opt(self._loss_layer, self._loss_opt,
                                           self._n_devices, 'loss',
                                           self._adasum)
        self._loss_fbo = self._pjit(loss_fbo, donate_argnums=(1, 2))
Пример #12
0
    def __init__(self,
                 model,
                 task,
                 eval_model=None,
                 eval_task=None,
                 output_dir=None,
                 checkpoint_at=None,
                 eval_at=None):
        """Configures a training `Loop`, including a random initialization.

        Args:
          model: Trax layer, representing the core model to be trained. Loss
              functions and eval functions (a.k.a. metrics) are considered to
              be outside the core model, taking core model output and data
              labels as their two inputs.
          task: TrainTask instance, which defines the training data, loss
              function, and optimizer to be used in this training loop.
          eval_model: Optional Trax layer, representing model used for
              evaluation, e.g., with dropout turned off. If None, the training
              model (model) will be used.
          eval_task: EvalTask instance or None. If None, don't do any evals.
          output_dir: Path telling where to save outputs (evals and
              checkpoints). Created if missing. Can be None if both
              `eval_task` and `checkpoint_at` are None.
          checkpoint_at: Function (integer --> boolean) telling, for step n,
              whether that step should have its checkpoint saved. If None, the
              default is periodic checkpointing at
              `task.n_steps_per_checkpoint`.
          eval_at: Function (integer --> boolean) that says, for training step
              n, whether that step should run evals. If None, run when
              checkpointing, i.e. follow the effective `checkpoint_at`.
        """
        self._task = task
        self._model = model
        # Explicit None check, as documented, instead of relying on truthiness.
        self._eval_model = model if eval_model is None else eval_model
        default_at = _at_step_1_and_every_nth_step(
            self._task.n_steps_per_checkpoint)
        if output_dir is not None:
            self._output_dir = os.path.expanduser(output_dir)
            tf.io.gfile.makedirs(self._output_dir)
        else:
            self._output_dir = None

        # Prepare training components.
        self._step = 0
        self._checkpoint_at = checkpoint_at or default_at
        self._model_in_training = tl.Serial(self._model, self._task.loss_layer)
        self._batch_signature = shapes.signature(self._task.sample_batch)
        self._eval_model.init(self._batch_signature)
        self._model_in_training.init(self._batch_signature)
        self._task.optimizer.tree_init(self._model_in_training.weights)
        self._forward_and_backward_fn = (
            fastmath.jit(
                fastmath.value_and_grad(
                    self._model_in_training.pure_fn,
                    argnums=1,  # arg1 of pure_fn: weights
                    has_aux=True)))  # return (loss, state), gradients

        # Prepare eval components. `_eval_task` is always defined (None when
        # evals are disabled) so later attribute reads cannot fail.
        self._eval_task = eval_task
        if eval_task is None:
            self._eval_at = _never
        else:
            # As documented, evals default to the checkpointing schedule.
            self._eval_at = eval_at or self._checkpoint_at
            metric_name_lengths = [
                len(name) for name in eval_task.metric_names
            ]
            self._rjust_len = max([len(self._task.loss_layer.name)] +
                                  metric_name_lengths)
            model_with_metrics = _model_with_metrics(self._eval_model,
                                                     eval_task)
            self._eval_weights = model_with_metrics.weights[1]  # eval part
            self._eval_state = model_with_metrics.state[1]  # eval part
            self._metrics_fn = fastmath.jit(model_with_metrics.pure_fn)
            if self._output_dir is None:
                _log(
                    'Will not write evaluation metrics, because output_dir is None.'
                )
Пример #13
0
  def one_step(self, batch, rng, step=0, learning_rate=None):
    """Updates layers weights/state and optimizers slots by running one step.

    The step runs every block forward (standard layer, then its reversible
    layers), runs the loss layer forward+backward with its optimizer update,
    and then walks the blocks in reverse, running reversible layers backward
    and each standard layer's forward-backward-optimizer (FBO) function.

    Args:
      batch: Batch of data to use for optimization. When more than one device
        is used it is reshaped per device with `tl.reshape_by_device`.
      rng: Random number generator to use for running this step; it is split
        into one rng per layer (and per device when n_devices > 1).
      step: Which step of the training are we running (an integer; it is
        replicated per device when n_devices > 1).
      learning_rate: Learning rate to use instead of the default one. When
        given, it is written into the replicated opt_params of every layer
        and of the loss layer.

    Returns:
      Tuple (loss, stats) with new values from one step of training. `loss`
      is the 'loss' entry of the loss layer's stats; `stats` merges all
      optimizer statistics under keys prefixed with `layer{i}/`.
    """
    # Update the learning rate if needed.
    if learning_rate is not None:
      self._replicated_loss_opt_params['learning_rate'] = tl.for_n_devices(
          learning_rate, self._n_devices)
      for (std_op, rev_ops) in self._replicated_opt_params:
        std_op['learning_rate'] = tl.for_n_devices(
            learning_rate, self._n_devices)
        for op in rev_ops:
          op['learning_rate'] = tl.for_n_devices(
              learning_rate, self._n_devices)

    # Batch needs to be split across the local devices -- the difference
    # between _for_n_devices and _reshape_by_device is that the latter splits
    # the batch dim to batch // n_devices, vs _for_n_devices
    # broadcasts/replicates to n_devices dimension.
    # Keep the plain integer step for logging and forward passes; `step`
    # itself becomes a per-device array in the multi-device case.
    step_int = step
    if self._n_devices > 1:
      batch = tl.reshape_by_device(batch, self._n_devices)
      step = np.repeat(step, self._n_devices)

    # Create separate rng for each device and layer.
    if self._n_devices == 1:
      rngs = fastmath.random.split(rng, self._n_layers)
    else:
      # Splitting by device first to be identical with default trainer.
      # NOTE(review): a fresh `per_device_rngs` closure is created and wrapped
      # in `fastmath.jit` on every call; if `fastmath.jit` caches compilations
      # by function identity (as `jax.jit` does), this likely recompiles every
      # step -- consider caching the jitted function on the instance.
      def per_device_rngs(rng):  # A function to JIT to not fragment memory.
        per_device_rng = fastmath.random.split(rng, self._n_devices)
        per_device_rngs = [
            fastmath.random.split(r, self._n_layers) for r in per_device_rng]
        # Regroup per layer: entry i stacks layer i's rng from every device.
        rngs = [jnp.stack([r[i] for r in per_device_rngs])
                for i in range(self._n_layers)]
        return rngs
      # JIT the function and run it on CPU to avoid memory fragmentation.
      rngs = fastmath.jit(per_device_rngs, backend='cpu')(tl.on_cpu(rng))
    # Group rngs by layer blocks.
    # Each block consumes 1 rng for its standard layer plus one per
    # reversible layer; rngs[-1] is used for the loss layer further below
    # (presumably self._n_layers accounts for exactly this -- TODO confirm).
    rng_blocks, rng_i = [], 0
    for _, rev_layers in self._blocks:
      l = len(rev_layers)
      rng_blocks.append((rngs[rng_i], rngs[rng_i + 1: rng_i + l + 1]))
      rng_i += l + 1

    # Run the layers forward upto the loss layer.
    process = psutil.Process(os.getpid())
    logging.info('running step %d', step_int)
    if step_int % self._n_steps_per_log == 1:
      logging.info('run fwd: cpu memory use (MB): %.2f',
                   process.memory_info().rss / float(1024 * 1024))
    stack = batch
    # Per block: ((std_inputs, std_state), (rev_old_states, rev_new_states)),
    # saved for the backward pass.
    block_inputs_states = []
    for i, (std_layer, rev_layers) in enumerate(self._blocks):
      acc_std_layer_fn, acc_rev_layer_fns = self._accelerated_layer_fns[i]
      std_rng, rev_rngs = rng_blocks[i]
      # Run the standard layer.
      stack, std_inputs, std_state = self._run_forward_standard(
          stack, std_layer, acc_std_layer_fn, std_rng, step_int)

      # Run the reversible layers and collect old and new states.
      stack, rev_old_states, rev_new_states = self._run_forward_reversible(
          stack, rev_layers, acc_rev_layer_fns, rev_rngs, step_int)
      block_inputs_states.append(
          ((std_inputs, std_state), (rev_old_states, rev_new_states)))

    # Run the loss layer forward and backward with optimizer update.
    if step_int % self._n_steps_per_log == 1:
      logging.info('run loss: cpu memory use (MB): %.2f',
                   process.memory_info().rss / float(1024 * 1024))
    loss_state = self._replicate(self._loss_layer.state)
    loss_inputs = cb.inputs_from_stack(stack, self._loss_layer.n_in)
    # The loss layer has no incoming gradient, hence the leading None.
    loss_stats, grad_stack = self._run_backward_standard(
        None, step, self._loss_layer, loss_inputs,
        loss_state, self._loss_fbo, rngs[-1], self._loss_opt,
        self._replicated_loss_opt_params)
    stats = [loss_stats]

    # Run the layers backward and run optimizer updates.
    if step_int % self._n_steps_per_log == 1:
      logging.info('run bwd: cpu memory use (MB): %.2f',
                   process.memory_info().rss / float(1024 * 1024))
    # Iterate blocks from last to first.
    for i in range(len(self._blocks) - 1, -1, -1):
      std_layer, rev_layers = self._blocks[i]
      (std_inputs, std_state), (rev_old_states,
                                rev_new_states) = block_inputs_states[i]
      std_fbo, rev_fbos = self._fbos[i]
      std_opt, rev_opts = self._optimizers[i]
      std_rng, rev_rngs = rng_blocks[i]
      repl_std_opt_params, repl_rev_opts_params = self._replicated_opt_params[i]

      # Run reversible layers backward with optimizer update.
      stack, grad_stack, new_stats = self._run_backward_reversible(
          stack, grad_stack, step, rev_layers, rev_fbos, rev_old_states,
          rev_new_states, rev_rngs, rev_opts, repl_rev_opts_params)
      stats.extend(new_stats)

      # Run the standard layer forward-and-backward pass and optimizer update.
      std_layer_stats, grad_stack = self._run_backward_standard(
          grad_stack, step, std_layer, std_inputs, std_state, std_fbo, std_rng,
          std_opt, repl_std_opt_params)
      stack = cb.outputs_onto_stack(  # Put layer inputs on the stack.
          std_inputs, stack, std_layer.n_out)
      stats.append(std_layer_stats)

    # Join stats from different optimizers into one.
    # `stats` was collected loss-first, so after reversing, `layer0/` is the
    # first block's standard layer and the highest index is the loss layer
    # (assuming `_run_backward_reversible` returns its stats in backward
    # order -- TODO confirm).
    joint_stats = {}
    for i, stat in enumerate(reversed(stats)):
      for k, v in stat.items():
        joint_stats[f'layer{i}/' + k] = v
    return stats[0]['loss'], joint_stats
Пример #14
0
def _accelerate(f, n_devices):
    """JIT-compiled version of `f` running on `n_devices`.

    Exactly one device gives ``jit(f)``; any other device count gives
    ``pmap(f)`` mapped over an axis named ``'batch'``.
    """
    if n_devices != 1:
        return fastmath.pmap(f, axis_name='batch')
    return fastmath.jit(f)