Example #1
0
 def slots(self):
     """Returns the slots of the per-layer optimizers followed by the loss one."""
     every_optimizer = [*self._optimizers, self._loss_opt]

     def _slots_of(optimizer):
         return optimizer.slots

     return fastmath.nested_map(_slots_of, every_optimizer)
Example #2
0
 def _collect_weights(self, layer):
     """Applies ``np.asarray`` to ``layer.weights`` and stores the result back."""
     converted = fastmath.nested_map(np.asarray, layer.weights)
     layer.weights = converted
Example #3
0
    def _unreplicate(self, x):
        """Returns ``x`` as is for one device, else element 0 of every leaf."""
        if self.n_devices == 1:
            return x
        return fastmath.nested_map(lambda leaf: leaf[0], x)
Example #4
0
    def trajectory_stream(self,
                          epochs=None,
                          max_slice_length=None,
                          sample_trajectories_uniformly=False,
                          margin=0):
        """Yield an endless stream of random trajectory slices.

        Each iteration picks an epoch, then a trajectory inside that epoch,
        then a start offset inside that trajectory.

        Args:
          epochs: a list of epochs to use; we use all epochs if None
          max_slice_length: maximum length of the slices of trajectories to
            return
          sample_trajectories_uniformly: whether to sample trajectories
            uniformly, or proportionally to the number of slices in each
            trajectory (default)
          margin: number of extra steps after "done" that should be included
            in slices, so that networks see the terminal states in the
            training data

        Yields:
          random trajectory slices sampled uniformly from all slices of length
          up to max_slice_length in all specified epochs
        """

        # TODO(lukaszkaiser): add option to sample from n last trajectories.
        def count_slices(traj):
            """Number of slices of length up to max_slice_length in `traj`."""
            if max_slice_length:
                # A trajectory [a, b, c, end_state] will have 2 slices of
                # length 2: [a, b] and [b, c], with margin=0; 3 with margin=1.
                return max(1, len(traj) + margin - max_slice_length)
            return 1

        while True:
            # Everything below is recomputed on every iteration, so epochs
            # recorded while the stream is being consumed are picked up.
            known_epochs = list(self._trajectories.keys())
            modulus = max(known_epochs) + 1
            requested = epochs or known_epochs
            # Taking the remainder makes -1 mean "last".
            wrapped = [ep % modulus for ep in requested]
            # Drop duplicates and epochs with no recorded trajectories.
            candidates = [
                ep for ep in set(wrapped) if self._trajectories[ep]
            ]

            # Pick an epoch with probability proportional to its slice count;
            # a single candidate needs no sampling.
            if len(candidates) == 1:
                chosen_epoch = candidates[0]
            else:
                # NOTE: Bottleneck. TODO(pkozakowski): Optimize.
                epoch_weights = [
                    sum(count_slices(t) for t in self._trajectories[ep])
                    for ep in candidates
                ]
                chosen_epoch = _sample_proportionally(candidates,
                                                      epoch_weights)
            trajectories = self._trajectories[chosen_epoch]

            # Pick a trajectory, either uniformly or by its slice count.
            if sample_trajectories_uniformly:
                trajectory_weights = [1] * len(trajectories)
            else:
                # NOTE: Bottleneck. TODO(pkozakowski): Optimize.
                trajectory_weights = [count_slices(t) for t in trajectories]
            picked = _sample_proportionally(trajectories, trajectory_weights)

            # Pick where the slice starts.
            start = np.random.randint(count_slices(picked))

            # Convert the whole trajectory to Numpy while adding the margin.
            # The result is cached, so this is not repeated for every sample.
            picked_np = picked.to_np(margin, self._timestep_to_np)

            # Without a maximum length the slice spans the full trajectory.
            length = max_slice_length or picked_np.observations.shape[0]
            stop = start + length
            yield fastmath.nested_map(lambda x: x[start:stop], picked_np)
Example #5
0
 def slots(self):
   """Maps every optimizer in ``self._optimizers`` to its ``slots``."""
   def _slots_of(optimizer):
     return optimizer.slots

   return fastmath.nested_map(_slots_of, self._optimizers)
Example #6
0
    def trajectory_batch_stream(self,
                                batch_size,
                                epochs=None,
                                max_slice_length=None,
                                min_slice_length=None,
                                margin=0,
                                include_final_state=False,
                                sample_trajectories_uniformly=False):
        """Yield batches of trajectory slices from the specified epochs.

        The stream consists of tuples of numpy arrays (tensors). Tensors of
        different lengths inside one batch are padded with 0.

        Args:
          batch_size: the size of the batches to return
          epochs: a list of epochs to use; we use all epochs if None
          max_slice_length: maximum length of the slices of trajectories to
            return
          min_slice_length: minimum length of the slices of trajectories to
            return
          margin: number of extra steps after "done" that should be included
            in slices, so that networks see the terminal states in the
            training data
          include_final_state: whether to include slices with the final state
            of the trajectory which may have no action and reward
          sample_trajectories_uniformly: whether to sample trajectories
            uniformly, or proportionally to the number of slices in each
            trajectory (default)

        Yields:
          batches of trajectory slices sampled uniformly from all slices of
          length at least min_slice_length and up to max_slice_length in all
          specified epochs
        """
        def pad(tensor_list):
            """Stacks `tensor_list`, zero-padding axis 0 when lengths differ."""
            # Substitute a zero tensor shaped like the first real one for
            # every None.
            present = [t for t in tensor_list if t is not None]
            assert present, 'All tensors to pad are None.'
            filler = np.zeros_like(present[0])
            filled = [filler if t is None else t for t in tensor_list]

            lengths = [t.shape[0] for t in filled]
            longest = max(lengths)
            if longest == min(lengths):  # No padding needed.
                return np.array(filled)

            # Pad up to the next power of two at or above the longest tensor.
            target_len = 2**int(np.ceil(np.log2(longest)))
            padded = [
                _zero_pad(t, (0, target_len - t.shape[0]), axis=0)
                for t in filled
            ]
            return np.array(padded)

        pending = []
        slice_stream = self.trajectory_stream(epochs,
                                              max_slice_length,
                                              include_final_state,
                                              sample_trajectories_uniformly,
                                              margin=margin)
        for t in slice_stream:
            # TODO(pkozakowski): Instead sample the trajectories out of those
            # with the minimum length.
            too_short = (min_slice_length is not None
                         and len(t) < min_slice_length)
            if too_short:
                continue

            pending.append(t)
            if len(pending) != batch_size:
                continue

            # TODO(pkozakowski): Unpack based on name instead of position in
            # the tuple (how?).
            converted = [
                item.to_np(self._timestep_to_np) for item in pending
            ]
            obs, act, dinp, rew, ret, done, mask = zip(*converted)
            # act, rew and ret will usually have the shape
            # [batch_size, trajectory_length-1], which we call [B, L-1].
            # Observations are more complex and will usually be [B, L] + S
            # where S is the shape of the observation space
            # (self.observation_space.shape).
            batch = TrajectoryNp(
                observations=obs,
                actions=act,
                dist_inputs=dinp,
                rewards=rew,
                dones=done,
                returns=ret,
                mask=mask,
            )
            # The recursion stops at level 1, so pad() receives lists of
            # arrays.
            yield fastmath.nested_map(pad, batch, level=1)
            pending = []
Example #7
0
def tensor_shapes_to_shape_dtypes(shapes, dtype):
    """Wraps each shape in `shapes` into a `ShapeDtype` carrying `dtype`.

    Every leaf must provide `as_list()`.
    """
    def _to_shape_dtype(shape):
        return shapes_lib.ShapeDtype(shape.as_list(), dtype)

    return math_lib.nested_map(_to_shape_dtype, shapes)
Example #8
0
 def _unreplicate(self, x):
   """Returns ``x`` untouched on one device, else element 0 of every leaf."""
   single_device = self._n_devices == 1
   return x if single_device else fastmath.nested_map(lambda leaf: leaf[0], x)
Example #9
0
 def _unreplicate(self, x):
   """Returns a single-device version of x, keeping only the first component."""
   if self._n_devices < 2:
     return x

   def _first_component(leaf):
     return leaf[0]

   return fastmath.nested_map(_first_component, x)
Example #10
0
def _average_multidevice_gradients(gradients):
    """Averages gradients over all the devices across different hosts.

    Both the gradients and the constant 1.0 are summed over the 'batch' axis;
    dividing the former by the latter gives the mean.
    """
    summed = fastmath.psum(gradients, 'batch')  # Sum over all devices.
    total_devices = fastmath.psum(jnp.array(1.0), 'batch')

    def _divide(grad):
        return grad / total_devices

    return fastmath.nested_map(_divide, summed)
Example #11
0
 def test_nested_map(self):
     """Checks that nested_map reaches leaves of dicts, tuples, lists, namedtuples."""
     nested_input = {'a': ([0, 1], 2), 'b': _TestNamedtuple(3)}
     expected = {'a': ([1, 2], 3), 'b': _TestNamedtuple(4)}
     actual = fastmath.nested_map(lambda leaf: leaf + 1, nested_input)
     self.assertEqual(actual, expected)
Example #12
0
 def _unreplicate(self, x):
     """Returns a single-device version of ``x`` (element 0 of every leaf)."""
     return (x if self._n_devices < 2
             else fastmath.nested_map(lambda leaf: leaf[0], x))
Example #13
0
  def __init__(self, blocks, loss_layer, optimizer_fn, n_devices=None,
               memoize_jit=True):
    """Creates a ReversibleSerialTrainer and the needed optimizers.

    The updates performed are equivalent to using the default Trainer on::

      tl.Serial(blocks + [loss_layer]).

    It is more memory-efficient though since weights are stored on CPU and only
    sent to accelerator layer-by-layer. Blocks are pairs consisting of a list
    of standard (arbitrary) layers and a list of reversible layers which help
    save memory thanks to being reversible.

    Args:
      blocks: A list of pairs of lists of standard and reversible layers.
      loss_layer: The final layer of the model; it can have trainable weights
        but should end with a loss: it is required to produce a scalar output.
      optimizer_fn: A function to create the optimizer, e.g., `optimizers.Adam`.
      n_devices: An optional integer, number of accelerator devices to use;
        by default, all available accelerators will be used.
      memoize_jit: Whether to memoize JITed functions; this significantly speeds
        up XLA compilation of larger models, but it uses `repr(layer)` as keys
        to memoize so it could fail if two layers with different functionality
        had the same string representation. We have not encountered such case
        yet so this is turned on by default, but consider turning it off or
        reviewing your model if you use custom layers and encounter a problem.
    """
    # The standard layers of each block are fused into one tl.Serial.
    self._blocks = [
        (tl.Serial(standard), reversible) for standard, reversible in blocks
    ]
    self._loss_layer = loss_layer
    self._optimizer_fn = optimizer_fn
    self._n_devices = n_devices or fastmath.device_count()
    # Per block: one fused standard layer plus its reversible layers; the
    # leading 1 accounts for the loss layer.
    self._n_layers = 1 + sum(len(revs) + 1 for _, revs in self._blocks)
    self._n_steps_per_log = 100  # Log layers and stats every 100 steps.
    if memoize_jit:
      self._jit_memory = {}
    else:
      self._jit_memory = None

    # Create accelerated versions of layers as pmaped/jited pure_fn.
    def _accelerate(layer):
      return self._pjit(layer.pure_fn, f'fwd {repr(layer)}')

    self._accelerated_layer_fns = fastmath.nested_map(
        _accelerate, self._blocks)

    # Create per-layer optimizers and replicate opt_params.
    def _make_optimizer(layer):
      optimizer = optimizer_fn()
      optimizer.tree_init(layer.weights)
      return optimizer

    def _replicated_params(optimizer):
      return self._replicate(optimizer.opt_params)

    self._optimizers = fastmath.nested_map(_make_optimizer, self._blocks)
    self._replicated_opt_params = fastmath.nested_map(
        _replicated_params, self._optimizers)

    self._loss_opt = _make_optimizer(loss_layer)
    self._replicated_loss_opt_params = self._replicate(
        self._loss_opt.opt_params)

    # Forward + backward + optimizer-update functions for all layers.
    # We call them in short FBO for "Forward + Backward + Optimizer update".
    # Reversible layers define a reverse_and_fbo function that also reverses.

    self._fbos = []
    layers_with_opts = zip(self._blocks, self._optimizers)
    for (std_layer, rev_layers), (std_opt, rev_opts) in layers_with_opts:
      std_fbo = _fbo_with_layer_and_opt(std_layer, std_opt, self._n_devices)
      # The reversible functions are jitted before the standard one.
      jit_rev_fbos = [
          self._pjit(
              _reverse_and_fbo_with_layer_and_opt(
                  layer, opt, self._n_devices),
              f'rev+bwd {repr(layer)}',
              donate_argnums=(1, 2))
          for layer, opt in zip(rev_layers, rev_opts)
      ]
      jit_std_fbo = self._pjit(
          std_fbo, f'bwd {repr(std_layer)}', donate_argnums=(1, 2))
      self._fbos.append((jit_std_fbo, jit_rev_fbos))

    self._loss_fbo = self._pjit(
        _fbo_with_layer_and_opt(
            self._loss_layer, self._loss_opt, self._n_devices, 'loss'),
        donate_argnums=(1, 2))
Example #14
0
 def _default_timestep_to_np(self, ts):
     """Default timestep conversion: applies ``np.array`` to every leaf of ``ts``."""
     as_numpy = fastmath.nested_map(np.array, ts)
     return as_numpy
Example #15
0
def read_values(variables):
    """Calls `read_value()` on every leaf of `variables`, keeping the nesting."""
    def _read(variable):
        return variable.read_value()

    return math_lib.nested_map(_read, variables)
Example #16
0
    def trajectory_batch_stream(self,
                                batch_size,
                                epochs=None,
                                max_slice_length=None,
                                min_slice_length=None,
                                margin=0,
                                sample_trajectories_uniformly=False):
        """Yield batches of trajectory slices from the specified epochs.

        The stream consists of tuples of numpy arrays (tensors). Tensors of
        different lengths inside one batch are padded with 0.

        Args:
          batch_size: the size of the batches to return
          epochs: a list of epochs to use; we use all epochs if None
          max_slice_length: maximum length of the slices of trajectories to
            return
          min_slice_length: minimum length of the slices of trajectories to
            return
          margin: number of extra steps after "done" that should be included
            in slices, so that networks see the terminal states in the
            training data
          sample_trajectories_uniformly: whether to sample trajectories
            uniformly, or proportionally to the number of slices in each
            trajectory (default)

        Yields:
          batches of trajectory slices sampled uniformly from all slices of
          length at least min_slice_length and up to max_slice_length in all
          specified epochs
        """
        def pad(tensor_list):
            """Stacks `tensor_list`, zero-padding axis 0 when lengths differ."""
            # Substitute a zero tensor shaped like the first real one for
            # every None.
            present = [t for t in tensor_list if t is not None]
            assert present, 'All tensors to pad are None.'
            filler = np.zeros_like(present[0])
            filled = [filler if t is None else t for t in tensor_list]

            lengths = [t.shape[0] for t in filled]
            target = max(lengths)
            # The requested minimum length acts as a floor for the target.
            if min_slice_length is not None:
                target = max(target, min_slice_length)
            if target == min(lengths):  # No padding needed.
                return np.array(filled)

            # Pad up to the next power of two at or above the target.
            padded_len = 2**int(np.ceil(np.log2(target)))
            padded = [
                _zero_pad(t, (0, padded_len - t.shape[0]), axis=0)
                for t in filled
            ]
            return np.array(padded)

        pending = []
        slice_stream = self.trajectory_stream(epochs,
                                              max_slice_length,
                                              sample_trajectories_uniformly,
                                              margin=margin)
        for trajectory_slice in slice_stream:
            pending.append(trajectory_slice)
            if len(pending) != batch_size:
                continue

            # Make a nested TimeStepBatch of lists out of a list of
            # TimeStepBatches.
            zipped = fastmath.nested_zip(pending)
            # Actions, rewards and returns in the trajectory slice have shape
            # [batch_size, trajectory_length], which we denote as [B, L].
            # Observations are more complex: [B, L] + S, where S is the shape
            # of the observation space (self.observation_space.shape).
            # The recursion stops at level 1, so pad() receives lists of
            # arrays.
            yield fastmath.nested_map(pad, zipped, level=1)
            pending = []
Example #17
0
def to_tensors(args):
    """Applies `tf.convert_to_tensor` to every leaf of `args`."""
    as_tensors = math_lib.nested_map(tf.convert_to_tensor, args)
    return as_tensors
Example #18
0
  def __init__(self, blocks, loss_layer, optimizer_fn, n_devices=None):
    """Creates a ReversibleSerialTrainer and the needed optimizers.

    The updates performed are equivalent to using the default Trainer on::

      tl.Serial(blocks + [loss_layer]).

    It is more memory-efficient though since weights are stored on CPU and only
    sent to accelerator layer-by-layer. Blocks are pairs consisting of a list
    of standard (arbitrary) layers and a list of reversible layers which help
    save memory thanks to being reversible.

    Args:
      blocks: A list of pairs of lists of standard and reversible layers.
      loss_layer: The final layer of the model; it can have trainable weights
        but should end with a loss: it is required to produce a scalar output.
      optimizer_fn: A function to create the optimizer, e.g., `optimizers.Adam`.
      n_devices: An optional integer, number of accelerator devices to use;
        by default, all available accelerators will be used.
    """
    # The standard layers of each block are fused into one tl.Serial.
    self._blocks = [
        (tl.Serial(standard), reversible) for standard, reversible in blocks
    ]
    self._loss_layer = loss_layer
    self._optimizer_fn = optimizer_fn
    self._n_devices = n_devices or fastmath.device_count()
    # Per block: one fused standard layer plus its reversible layers; the
    # leading 1 accounts for the loss layer.
    self._n_layers = 1 + sum(len(revs) + 1 for _, revs in self._blocks)

    # Create accelerated versions of layers as pmaped/jited pure_fn.
    def _accelerate(layer):
      return self._pjit(layer.pure_fn)

    self._accelerated_layer_fns = fastmath.nested_map(
        _accelerate, self._blocks)

    # Create per-layer optimizers and replicate opt_params.
    def _make_optimizer(layer):
      optimizer = optimizer_fn()
      optimizer.tree_init(layer.weights)
      return optimizer

    def _replicated_params(optimizer):
      return self._replicate(optimizer.opt_params)

    self._optimizers = fastmath.nested_map(_make_optimizer, self._blocks)
    self._replicated_opt_params = fastmath.nested_map(
        _replicated_params, self._optimizers)

    self._loss_opt = _make_optimizer(loss_layer)
    self._replicated_loss_opt_params = self._replicate(
        self._loss_opt.opt_params)

    # Forward + backward + optimizer-update functions for all layers.
    # We call them in short FBO for "Forward + Backward + Optimizer update".
    # Reversible layers define a reverse_and_fbo function that also reverses.

    self._fbos = []
    layers_with_opts = zip(self._blocks, self._optimizers)
    for (std_layer, rev_layers), (std_opt, rev_opts) in layers_with_opts:
      std_fbo = _fbo_with_layer_and_opt(std_layer, std_opt, self._n_devices)
      # The reversible functions are jitted before the standard one.
      jit_rev_fbos = [
          self._pjit(_reverse_and_fbo_with_layer_and_opt(
              layer, opt, self._n_devices))
          for layer, opt in zip(rev_layers, rev_opts)
      ]
      self._fbos.append((self._pjit(std_fbo), jit_rev_fbos))

    self._loss_fbo = self._pjit(_fbo_with_layer_and_opt(
        self._loss_layer, self._loss_opt, self._n_devices, 'loss'))
Example #19
0
def to_arrays(args):
    """Applies `jnp.asarray` to every leaf of `args`."""
    as_arrays = math_lib.nested_map(jnp.asarray, args)
    return as_arrays
Example #20
0
 def _unreplicate(self, x):
   """Passes a single-device version of ``x`` through ``tl.on_cpu``."""
   if self._n_devices == 1:
     single = x
   else:
     # With several devices only element 0 of every leaf is kept.
     single = fastmath.nested_map(lambda leaf: leaf[0], x)
   return tl.on_cpu(single)
Example #21
0
  def run(self, n_steps=1):
    """Runs this training loop for n steps.

    Optionally runs evals and saves checkpoints at specified points.

    Weights, state, slots and optimizer parameters are read from the model and
    the task's optimizer at the start, threaded through every step as local
    values, and written back (unreplicated) once the loop has finished.

    Args:
      n_steps: Stop training after completing n steps.
    """
    # Extract key values (weights, state, slots) and update them in each loop.
    weights = self._model_in_training.weights
    state = self._model_in_training.state
    slots = self._task.optimizer.slots
    opt_params = self._task.optimizer.opt_params

    # weights, state, slots need to be replicated if needed.
    weights, state, slots, opt_params = self._for_n_devices(
        (weights, state, slots, opt_params))

    with self._open_summary_writers() as (train_summary_writer,
                                          eval_summary_writer):
      # Accumulators cover the interval since the last eval; they are reset
      # at the end of the `should_eval` branch below.
      loss_acc, step_acc = 0.0, 0
      start_time = time.time()
      optimizer_metrics_acc = collections.defaultdict(float)
      for _ in range(n_steps):
        # The step counter is advanced before the step runs.
        self._step += 1
        loss, weights, state, slots, optimizer_metrics = self._run_one_step(
            weights, state, slots, opt_params)

        # optimizer_metrics and loss are replicated on self.n_devices, a few
        # metrics are replicated (ex: gradients_l2, weights_l2) - i.e. they are
        # the same across devices, whereas some (ex: loss) aren't because they
        # are different on different devices (due to different data).
        # Taking the average does the correct thing in both the cases.
        #
        # NOTE: Only the weights and gradients are synced across the hosts. This
        # implies the loss here is averaged from this hosts' devices and not
        # across all hosts.
        optimizer_metrics, loss = fastmath.nested_map(
            jnp.mean, (optimizer_metrics, loss))

        loss_acc += loss
        step_acc += 1
        for metric_name, value in optimizer_metrics.items():
          optimizer_metrics_acc[metric_name] += value

        # NOTE(review): reads `self.step` while the counter above is
        # `self._step` - presumably a property over it; confirm.
        should_checkpoint = self._checkpoint_at(self.step)
        should_eval = self._eval_at(self.step)
        # Unreplicate only when a checkpoint or an eval actually needs it.
        unr_weights, unr_state, unr_slots = None, None, None
        if should_checkpoint or should_eval:
          unr_weights, unr_state, unr_slots = self._unreplicate(
              (weights, state, slots))

        if should_checkpoint:
          self.save_checkpoint(unr_weights, unr_state, unr_slots)
        if should_eval:
          elapsed_time = time.time() - start_time
          self._model_in_training.weights = unr_weights
          self._model_in_training.state = unr_state
          self._eval_model.weights = self._model.weights
          self._log_training_progress(
              total_loss=loss_acc, n_steps=step_acc, elapsed_time=elapsed_time,
              optimizer_metrics=optimizer_metrics_acc,
              summary_writer=train_summary_writer)
          # NOTE(review): `run_evals` is given `weights`/`state` as carried by
          # the loop, not the `unr_*` copies - confirm that is what it expects.
          self.run_evals(weights, state, eval_summary_writer)
          # Start a fresh accumulation interval.
          loss_acc, step_acc = 0.0, 0
          start_time = time.time()
          optimizer_metrics_acc = collections.defaultdict(float)

    # Store the final values back into their respective objects, for testing
    # or other inspection/use.

    # We keep the standard model weights/state unreplicated and
    # `tl.Accelerate(model)` will carry the replicated weights/state.
    # TODO(afrozm): Try to use `tl.Accelerate(model)` everywhere in the Loop.
    self._model_in_training.weights = self._unreplicate(weights)
    self._model_in_training.state = self._unreplicate(state)
    self._task.optimizer.slots = self._unreplicate(slots)
    self._task.optimizer.opt_params = self._unreplicate(opt_params)
    self._eval_model.weights = self._model.weights
Example #22
0
  def testTrain(self, layer_id, rng_updater_id, batch_size, trax_has_weights,
                explicit_build, use_model):
    """Tests training (forward and backward pass) for AsKeras.

    Runs three manual gradient steps on the bare Trax layer and on its
    `AsKeras` wrapper side by side, asserting after every step that outputs,
    weights, state and rng agree.

    Args:
      layer_id: an integer, the index into `_LAYERS`.
      rng_updater_id: an integer, the index into `_RNG_UPDATERS`.
      batch_size: an integer or `None`, the value for the `batch_size` argument
        in `AsKeras.__init__`.
      trax_has_weights: bool, whether to make the trax layer contain weights at
        the time when `AsKeras.build` is called.
      explicit_build: bool, whether to explicitly call `AsKeras.build`.
      use_model: bool, whether to build a `tf.keras.Model` out of the
        `AsKeras` layer and use the model to do the training instead of
        the bare layer. If `True`, we will also test checkpointing and restoring
        using the model.
    """
    with trax.fastmath.use_backend("tensorflow-numpy"):
      make_trax_layer, input_shapes_no_batch, dtype, allow_none_batch = (
          _LAYERS[layer_id])
      # We make a fresh trax layer for each test case, so that different test
      # cases won't interfere with each other.
      trax_layer = make_trax_layer()
      if not allow_none_batch and batch_size is None:
        self.skipTest("This Trax layer can't handle None batch size.")
      rng_updater = _RNG_UPDATERS[rng_updater_id]
      # Prepend the batch dimension to every per-example input shape.
      input_shapes = math_lib.nested_map(
          lambda s: [batch_size] + s, input_shapes_no_batch)
      input_sig = trax2keras.tensor_shapes_to_shape_dtypes(input_shapes, dtype)
      initializer_rng = math_lib.random.get_prng(765)
      weights, state = trax_layer.init(input_sig, rng=initializer_rng)
      generator = tf.random.Generator.from_seed(567)
      def get_inputs():
        return dummy_inputs(generator, input_sig)
      if trax_has_weights:
        # Calling the layer once with explicit weights/state; per the
        # `trax_has_weights` argument this makes the layer hold weights
        # before `AsKeras.build` runs.
        trax_layer(to_arrays(get_inputs()), weights=weights, state=state)
      rng = math_lib.random.get_prng(1234)
      keras_layer = trax2keras.AsKeras(
          trax_layer, batch_size=batch_size, initializer_rng=initializer_rng,
          rng=rng, rng_updater=rng_updater)
      if explicit_build:
        keras_layer.build(input_shapes)
      if use_model:
        x = tf.keras.Input(shape=input_shapes_no_batch, dtype=dtype)
        y = keras_layer(x)
        keras_model = tf.keras.Model(inputs=x, outputs=y)
      lr = 0.1  # learning rate
      for _ in range(3):
        inputs = get_inputs()
        # Trax side: forward pass under a tape, then a manual weight update.
        with tf.GradientTape() as trax_tape:
          trax_tape.watch([x.data for x in tf.nest.flatten(weights)])
          trax_outputs, state = trax_layer.pure_fn(
              to_arrays(inputs), weights=weights, state=state, rng=rng)
        trax_grads = trax_tape.gradient(*to_tensors([trax_outputs, weights]))
        # `g` may be `tf.IndexedSlices`, so we need to `convert_to_tensor`
        # before multiplication.
        # NOTE: the update adds `lr * g` on both the Trax and the Keras side,
        # so the two are compared under the same rule.
        weights = tf.nest.map_structure(
            lambda w, g: w + jnp.asarray(lr * tf.convert_to_tensor(g), w.dtype),
            weights, trax_grads)
        rng = rng_updater(rng)
        # Keras side: the same inputs go through the wrapper (or the model).
        with tf.GradientTape() as keras_tape:
          if use_model:
            keras_outputs = keras_model(inputs)
          else:
            keras_outputs = keras_layer(inputs)
        # Unwrap a 1-tuple so it can be compared with the Trax outputs.
        if isinstance(keras_outputs, tuple) and len(keras_outputs) == 1:
          keras_outputs = keras_outputs[0]
        self.assertAllClose(to_tensors(trax_outputs), keras_outputs, atol=1e-5)
        keras_grads = keras_tape.gradient(keras_outputs,
                                          keras_layer.trainable_variables)
        tf.nest.map_structure(
            lambda v, g: v.assign_add(  # pylint: disable=g-long-lambda
                tf.cast(lr * tf.convert_to_tensor(g), v.dtype)),
            keras_layer.trainable_variables, keras_grads)
        # A looser absolute tolerance is used when `has_gpu()` is true.
        self.assertAllClose(
            to_tensors(weights), read_values(keras_layer._weights),
            rtol=2e-6, atol=3.5e-4 if has_gpu() else 1e-6)
        self.assertAllClose(to_tensors(state), read_values(keras_layer._state))
        self.assertAllClose(to_tensors(rng), read_values(keras_layer._rng))
      if use_model:
        # Save and reload the model, then check both give the same outputs.
        fname = os.path.join(self.get_temp_dir(), "checkpoint")
        keras_model.save(fname)
        loaded_model = tf.keras.models.load_model(fname)
        for _ in range(2):
          inputs = get_inputs()
          self.assertAllClose(keras_model(inputs), loaded_model(inputs))
Example #23
0
    def __init__(self,
                 blocks,
                 loss_layer,
                 optimizer_fn,
                 n_devices=None,
                 memoize_jit=True,
                 free_accelerators_on_step=False,
                 adasum=False):
        """Creates a ReversibleSerialTrainer and the needed optimizers.

    This trainer performs updates equivalent to using the default Trainer on::

      tl.Serial(blocks + [loss_layer]).

    It is more memory-efficient though since weights are stored on CPU and only
    sent to accelerator layer-by-layer. Blocks are pairs consisting of a list
    of standard (arbitrary) layers and a list of reversible layers which help
    save memory thanks to being reversible.

    Args:
      blocks: A list of pairs of lists of standard and reversible layers.
      loss_layer: The final layer of the model; it can have trainable weights
        but should end with a loss: it is required to produce a scalar output.
      optimizer_fn: A function to create the optimizer, e.g., `optimizers.Adam`.
      n_devices: An optional integer, number of accelerator devices to use;
        by default, all available accelerators will be used.
      memoize_jit: Whether to memoize JITed functions; this significantly speeds
        up XLA compilation of larger models, but it uses `repr(layer)` as keys
        to memoize so it could fail if two layers with different functionality
        had the same string representaion. We have not encountered such case
        yet so this is turned on by default, but consider turning it off or
        reviewing your model if you use custom layers and encounter a problem.
      free_accelerators_on_step: If true, frees memory on accelerators when
        starting a step. All layers and arguments must be on host for that,
        otherwise it can lead to failures. Can prevent memory fragmentation.
      adasum: if True, use adaptive summation to gather multi-device gradients.
    """
        self._blocks = [(tl.Serial(std), rev) for (std, rev) in blocks]
        self._loss_layer = loss_layer
        self._optimizer_fn = optimizer_fn
        self._n_devices = n_devices or fastmath.local_device_count()
        self._adasum = adasum
        self._n_layers = 1 + sum([len(revs) + 1 for (_, revs) in self._blocks])
        self._n_steps_per_log = 100  # Log layers and stats every 100 steps.
        self._n_async_layers = 1  # How many layers to run asynchronously.
        self._jit_memory = {} if memoize_jit else None
        self._do_free = free_accelerators_on_step
        self._jit_per_device_rngs = fastmath.jit(self._per_device_rngs,
                                                 backend='cpu')

        # Create accelerated versions of layers as pmaped/jited pure_fn.
        self._accelerated_layer_fns = fastmath.nested_map(
            lambda layer: self._pjit(layer.pure_fn, f'fwd {repr(layer)}'),
            self._blocks)

        # Create per-layer optimizers and replicate opt_params.
        def _make_optimizer(layer):
            opt = optimizer_fn()
            opt.tree_init(layer.weights)
            opt.slots = tl.on_cpu(opt.slots)
            return opt

        self._optimizers = fastmath.nested_map(_make_optimizer, self._blocks)
        self._replicated_opt_params = fastmath.nested_map(
            lambda opt: self._replicate_cpu(opt.opt_params), self._optimizers)

        self._loss_opt = _make_optimizer(loss_layer)
        self._replicated_loss_opt_params = self._replicate_cpu(
            self._loss_opt.opt_params)

        # Forward + backward + optimizer-update functions for all layers.
        # We call them in short FBO for "Forward + Backward + Optimizer update".
        # Reversible layers define a reverse_and_fbo function that also reverses.

        self._fbos = []
        for i, (std_layer, rev_layers) in enumerate(self._blocks):
            (std_opt, rev_opts) = self._optimizers[i]
            std_fbo = _fbo_with_layer_and_opt(std_layer,
                                              std_opt,
                                              self._n_devices,
                                              adasum=self._adasum)
            rev_and_fbos = []
            for layer, opt in zip(rev_layers, rev_opts):
                rev_and_fbo = _reverse_and_fbo_with_layer_and_opt(
                    layer, opt, self._n_devices, self._adasum)
                # The donated args are (outputs, weights, grads) and we can donate
                # them because weights and grads are immediately replaced and in
                # case of reversible layers, the outputs are never used again.
                rev_and_fbos.append(
                    self._pjit(rev_and_fbo,
                               f'rev+bwd {repr(layer)}',
                               donate_argnums=(0, 1, 2)))
            # In standard layers, the inputs cannot be donated as they may be used
            # as outputs for the reversible block below, but weights and grads can.
            jit_std_fbo = self._pjit(std_fbo,
                                     f'bwd {repr(std_layer)}',
                                     donate_argnums=(1, 2))
            self._fbos.append((jit_std_fbo, rev_and_fbos))

        loss_fbo = _fbo_with_layer_and_opt(self._loss_layer, self._loss_opt,
                                           self._n_devices, 'loss',
                                           self._adasum)
        self._loss_fbo = self._pjit(loss_fbo, donate_argnums=(1, 2))
Example #24
0
  def run(self, n_steps=1):
    """Advances this training loop by `n_steps` steps.

    After each step the loop saves a checkpoint and/or logs training progress
    and runs evals, whenever `_checkpoint_at` / `_eval_at` say so for the
    current step.

    Args:
      n_steps: Number of training steps to execute before returning.
    """
    bytes_per_mb = float(1024*1024)
    with self._open_summary_writers() as (train_writers, eval_writers):
      this_process = psutil.Process(os.getpid())
      # Running totals since the last progress log; reset after every eval.
      running_loss = 0.0
      steps_since_log = 0
      window_start = time.time()
      running_metrics = collections.defaultdict(float)
      for _ in range(n_steps):
        prev_task_index = self._which_task(self._step)
        self._step += 1
        task_index = self._which_task(self._step)
        step_loss, step_metrics = self._run_one_step(
            task_index, task_index != prev_task_index)

        # Both the loss and the optimizer metrics come back replicated on
        # self.n_devices. Some metrics (ex: gradients_l2, weights_l2) are
        # identical on every device, others (ex: loss) differ because each
        # device saw different data. Averaging is correct in both cases.
        #
        # NOTE: Only weights and gradients are synced across hosts, so the
        # loss here is the average over this host's devices, not all hosts.
        average = functools.partial(tl.mean_or_pmean, self._n_devices)
        step_metrics, step_loss = fastmath.nested_map(
            average, (step_metrics, step_loss))

        running_loss += step_loss
        steps_since_log += 1
        for name, metric_value in step_metrics.items():
          running_metrics[name] += metric_value

        if self._checkpoint_at(self.step):
          self.save_checkpoint()

        if self._eval_at(self.step):
          rss_mb = this_process.memory_info().rss / bytes_per_mb
          logging.info('cpu memory use (MB): %.2f', rss_mb)
          window_seconds = time.time() - window_start
          # Evals must see the freshest weights.
          self._eval_model.weights = self._model.weights
          self._log_training_progress(
              task=self._tasks[task_index],
              total_loss=running_loss,
              n_steps=steps_since_log,
              elapsed_time=window_seconds,
              optimizer_metrics=running_metrics,
              summary_writer=train_writers[task_index],
          )
          self.run_evals(eval_writers)
          # Start a new accumulation window.
          running_loss = 0.0
          steps_since_log = 0
          window_start = time.time()
          running_metrics = collections.defaultdict(float)

    # Copy the final weights into the eval model so that they can be inspected
    # (e.g. in tests) once training has finished.

    # The standard model weights/state stay unreplicated, while
    # `tl.Accelerate(model)` carries the replicated weights/state.
    # TODO(afrozm): Try to use `tl.Accelerate(model)` everywhere in the Loop.
    self._eval_model.weights = self._model.weights
Example #25
0
def clip_grads(grad_tree, max_norm):
    """Clip gradients stored as a pytree of arrays to maximum norm `max_norm`.

    Args:
      grad_tree: pytree (nested structure) of gradient arrays.
      max_norm: largest allowed value of `l2_norm(grad_tree)`.

    Returns:
      A structure shaped like `grad_tree`. Each leaf is returned unchanged
      when the norm is below `max_norm`; otherwise it is scaled by
      `max_norm / norm`.
    """
    norm = l2_norm(grad_tree)

    def normalize(g):
        return jnp.where(norm < max_norm, g, g * (max_norm / norm))

    # `fastmath.nested_map` expects the function first and the structure
    # second; passing them the other way round would try to call `grad_tree`.
    return fastmath.nested_map(normalize, grad_tree)