Example #1
0
  def forward(self, x):
    """Applies dropout to `x` in 'train' mode; returns `x` unchanged otherwise.

    Args:
      x: Tensor of activations.

    Returns:
      Tensor of same shape and dtype as the input.
    """
    if self._mode != 'train':
      return x
    state, rng = self.state, self.rng
    # A rate stored in the state dict under this layer's name takes precedence
    # over the initial rate.
    has_override = isinstance(state, dict) and self._name in state
    rate = state[self._name] if has_override else self._initial_rate
    # Shared axes get extent 1 in the mask, so a single draw is broadcast
    # along each of them when the mask is multiplied into `x`.
    mask_shape = list(x.shape)
    for axis in self._shared_axes:
      mask_shape[axis] = 1
    on_jax = fastmath.is_backend(fastmath.Backend.JAX)
    keep_prob = 1.0 - rate
    if on_jax:
      keep_prob = jax.lax.tie_in(self.rng, keep_prob)
    keep = fastmath.random.bernoulli(rng, keep_prob, tuple(mask_shape))
    if on_jax:
      keep_prob = jax.lax.tie_in(keep, keep_prob)
    # Dividing by keep_prob rescales the kept activations.
    return x * (keep.astype(x.dtype) / keep_prob)
Example #2
0
File: trainer.py Project: MLDL/trax
def _jax_and_tf_configure_for_devices():
  """Applies flag-driven JAX config and optionally hides GPUs from TF.

  Always enables JAX omnistaging. When `FLAGS.use_tpu` is set, points JAX at
  the TPU platform using the backend flags. When eager execution is enabled
  and the active backend is NumPy or JAX, TensorFlow's visible GPU list is
  emptied.
  """
  jax.config.enable_omnistaging()
  if FLAGS.use_tpu:
    jax.config.update('jax_platform_name', 'tpu')
    jax.config.update('jax_xla_backend', FLAGS.jax_xla_backend)
    jax.config.update('jax_backend_target', FLAGS.jax_backend_target)
  if not FLAGS.enable_eager_execution:
    return
  # Numpy backend doesn't benefit from having the input pipeline run on GPU,
  # and jax backend has GPU memory contention if TF uses the GPU. Gin must be
  # set up first before determining the backend.
  if fastmath.is_backend(Backend.NUMPY) or fastmath.is_backend(Backend.JAX):
    tf.config.experimental.set_visible_devices([], 'GPU')
Example #3
0
 def f(x):
   """Returns `x` prepared for `n_devices` devices.

   With a single device `x` is returned as is. With several devices, `x` is
   handed to `_multi_device_put` on the JAX backend and broadcast along a new
   leading axis of size `n_devices` on other backends.
   """
   if n_devices <= 1:
     return x
   if fastmath.is_backend(fastmath.Backend.JAX):
     return _multi_device_put(x)
   return jnp.broadcast_to(x, (n_devices,) + jnp.asarray(x).shape)
Example #4
0
def one_hot(x, n_categories, dtype=jnp.float32):  # pylint: disable=invalid-name
    """Makes a one-hot array (n+1 dims) from an int-categorical array (n dims).

    Args:
      x: Array of category indices.
      n_categories: Size of the new trailing one-hot axis.
      dtype: Dtype of the returned array.

    Returns:
      Array with one extra trailing axis of size `n_categories`, holding the
      result of comparing each entry of `x` with `0..n_categories-1`.
    """
    categories = jnp.arange(n_categories)
    if fastmath.is_backend(fastmath.Backend.JAX):
        # Work around a jax broadcasting issue.
        categories = jax.lax.tie_in(x, categories)
    matches = x[..., jnp.newaxis] == categories
    return jnp.array(matches, dtype)
Example #5
0
 def f(x):
   """Returns `x` prepared for `n_devices` devices.

   With several devices, `x` is replicated over `jax.local_devices()` on the
   JAX backend, or broadcast along a new leading axis of size `n_devices` on
   other backends. With a single device `x` is returned unchanged.
   """
   if n_devices > 1:
     if fastmath.is_backend(fastmath.Backend.JAX):
       return jax.device_put_replicated(x, jax.local_devices())
     return jnp.broadcast_to(x, (n_devices,) + jnp.asarray(x).shape)
   return x
Example #6
0
    def save_state(self, keep, prefix='model'):
        """Save trainer state given a possibly replicated opt_state.

        Args:
          keep: If truthy, additionally write a step-suffixed copy of the
            checkpoint next to the main one.
          prefix: File-name prefix for the checkpoint files.
        """
        opt_state = self._opt_state
        if self.n_devices > 1:
            # With several devices, keep only index 0 of every leaf.
            opt_state = OptState(
                *fastmath.nested_map(lambda leaf: leaf[0], opt_state))
        # This line, while optional, allows JAX to transfer arrays from the device
        # to the host in parallel, which is particularly important for cloud TPU.
        if fastmath.is_backend(fastmath.Backend.JAX):
            opt_state = jax.device_get(opt_state)

        step = self._step
        output_dir = self._output_dir

        # This dict will be stored as the model.
        trainer_state_dict = make_trainer_state_dict(
            step, opt_state, self._history, self._model_state,
            self._input_signature)

        # The main checkpoint is always written; the step-suffixed one only
        # when `keep` is set.
        targets = [os.path.join(output_dir, prefix + '.pkl.gz')]
        if keep:
            targets.append(
                os.path.join(output_dir, '{}_{}.pkl.gz'.format(prefix, step)))
        for weights_file in targets:
            self._save_state_dict(trainer_state_dict, weights_file)
Example #7
0
    def forward(self, inputs):
        """Returns attention-computed activations.

        In 'predict' mode the layer state is updated from `inputs` and the
        keys, values and mask are taken from it; otherwise a lower-triangular
        mask is built from the query length. In 'viz' mode the attention
        weights are stored as the layer state.

        Args:
          inputs: A (queries, keys, values) tuple.
        """
        q, k, v = inputs

        if self._mode == 'predict':
            self.state = _fast_inference_update_state(inputs, self.state)
            k, v, mask, _ = self.state
        else:
            side = q.shape[-2]
            # Not all backends define jnp.tril. However, using np.tril is inefficient
            # in that it creates a large global constant. TODO(kitaev): try to find an
            # alternative that works across all backends.
            xp = jnp if fastmath.is_backend(fastmath.Backend.JAX) else np
            mask = xp.tril(xp.ones((1, side, side), dtype=np.bool_), k=0)

        res, dots = DotProductAttention(
            q, k, v, mask,
            dropout=self._dropout, mode=self._mode, rng=self.rng)
        if self._mode == 'viz':
            self.state = dots
        return res
Example #8
0
 def _l2_norm(self, flat_list):
   """Returns the aggregate L2 norm of a list of tensors."""
   if fastmath.is_backend(fastmath.Backend.JAX):
     squared = (jnp.vdot(x, x) for x in flat_list)
   else:  # TODO(lukaszkaiser): add vdot to TF-numpy
     squared = (jnp.sum(x*x) for x in flat_list)
   # Square root of the summed per-tensor squared magnitudes.
   return jnp.sqrt(sum(squared))
Example #9
0
def _causal_mask(length):
  """Returns a boolean lower-triangular mask of shape (1, length, length)."""
  # Not all backends define jnp.tril. However, using np.tril is inefficient
  # in that it creates a large global constant. TODO(kitaev): try to find an
  # alternative that works across all backends.
  xp = jnp if fastmath.is_backend(fastmath.Backend.JAX) else np
  return xp.tril(xp.ones((1, length, length), dtype=np.bool_), k=0)
Example #10
0
def network_policy(
    collect_model,
    policy_distribution,
    loop,
    trajectory_np,
    head_index=0,
    temperature=1.0,
):
  """Policy function powered by a neural network.

  Used to implement Agent.policy() in policy-based agents.

  Args:
    collect_model: the model used for collecting trajectories
    policy_distribution: an instance of trax.rl.distributions.Distribution
    loop: trax.supervised.training.Loop used to train the policy network
    trajectory_np: an instance of trax.rl.task.TrajectoryNp
    head_index: index of the policy head in a multihead model.
    temperature: temperature used to sample from the policy (default=1.0)

  Returns:
    a pair (action, dist_inputs) where action is the action taken and
    dist_inputs is the parameters of the policy distribution, that will later
    be used for training.
  """
  model = collect_model
  if temperature != 1.0:
    # When evaluating (t != 1.0), use the evaluation model instead of the
    # collection model - some models accumulate normalization statistics
    # during data collection, and we don't want to do it in eval to avoid data
    # leakage.
    model = loop.eval_model
    model.state = collect_model.state
  # Copying weights from loop.model should work, because the raw model's
  # weights should be updated automatically during training, but it doesn't.
  # TODO(pkozakowski): Debug.
  accelerated = loop._trainer_per_task[0].accelerated_model_with_loss  # pylint: disable=protected-access
  model.weights = accelerated._unreplicate(accelerated.weights[0])  # pylint: disable=protected-access
  # Add batch dimension to the observations and run the model.
  observations = trajectory_np.observations
  pred = model(observations[None, ...])
  if isinstance(pred, (tuple, list)):
    # For multihead models, extract the policy head output.
    pred = pred[head_index]
  expected_shape = (1, observations.shape[0], policy_distribution.n_inputs)
  assert pred.shape == expected_shape
  # Pick element 0 from the batch (the only one), last (current) timestep.
  pred = pred[0, -1, :]
  sample = policy_distribution.sample(pred, temperature=temperature)
  result = (sample, pred)
  if fastmath.is_backend(fastmath.Backend.JAX):
    # The result is composed of mutable numpy arrays. We copy them to avoid
    # accidental modification.
    result = fastmath.nested_map(lambda arr: arr.copy(), result)
  return result
Example #11
0
def init_host_and_devices(n_devices=None, random_seed=None):
    """Initializes host and device attributes for this trainer.

    Args:
      n_devices: Number of devices this trainer will use. If `None` (or 0), get
        the number from the backend.
      random_seed: Random seed as the starting point for all random numbers used
        by the trainer. If `None` and there is more than one host, a seed is
        derived from system time and host id; otherwise the value is passed on
        as given.

    Returns:
      A tuple `(is_chief, host_count, n_devices, rng)`:
      is_chief: True if this host has id 0.
      host_count: Number of hosts in this computation.
      n_devices: The passed in value of n_devices or the backend device count.
      rng: Whatever `_init_random_number_generators` returns for the seed.

    Raises:
      ValueError: On the JAX backend, if `n_devices` differs from the backend
        device count.
    """
    on_jax = fastmath.is_backend(fastmath.Backend.JAX)
    if on_jax:
        host_id, host_count = jax.host_id(), jax.host_count()
    else:
        host_id, host_count = 0, 1
    is_chief = (host_id == 0)

    logging.info(
        'Initializing hosts and devices: host_id %d, host_count %d, '
        'is_chief %d', host_id, host_count, is_chief)

    device_count = fastmath.device_count()
    n_devices = n_devices or device_count
    # TODO(lukaszkaiser): remove this restriction when possible.
    if on_jax and n_devices != device_count:
        raise ValueError('JAX cannot work yet with n_devices != all devices: '
                         '%d != %d' % (n_devices, device_count))

    if random_seed is None and host_count > 1:
        # Mix the host id into the time-based seed.
        random_seed = int(1e6 * (host_id + time.time())) % 2**32
    rng = _init_random_number_generators(random_seed)
    return (is_chief, host_count, n_devices, rng)
Example #12
0
File: trainer.py Project: MLDL/trax
def main(_):
  """Sets up logging, TF, gin and devices from flags, then runs training."""
  logging.set_verbosity(FLAGS.log_level)

  _tf_setup_from_flags()
  _gin_parse_configs()
  _jax_and_tf_configure_for_devices()

  output_dir = _output_dir_or_default()
  # The TF training path is taken only on TPU with the TF-numpy backend.
  train_with_tf = FLAGS.use_tpu and fastmath.is_backend(Backend.TFNP)
  if not train_with_tf:
    trainer_lib.train(output_dir=output_dir)
  else:
    _train_using_tf(output_dir)

  trainer_lib.log('Finished training.')
Example #13
0
 def _to_bits(self, weights):
   """Converts a list of weights to bit-cast weights and their types.

   On non-JAX backends `weights` is returned untouched. On JAX, every
   bfloat16 weight becomes a `(uint16 bits, 'bfloat16')` pair and every other
   weight is passed through as is.
   """
   # This is currently needed to pickle bfloat16 arrays from JAX.
   # TODO(lukaszkaiser): remove once it is not needed (the following unit test
   #   checks it: training_test/test_restores_step_bfloat16).
   if not fastmath.is_backend(fastmath.Backend.JAX):
     return weights
   # Non-bfloat16 weights stay plain to be compatible with earlier checkpoints.
   return [
       (jax.lax.bitcast_convert_type(w, np.uint16), 'bfloat16')
       if w.dtype == jnp.bfloat16 else w
       for w in weights
   ]
Example #14
0
def main(_):
    """Configures the run from flags and launches training.

    Besides the logging / TF / gin / device setup, this optionally creates a
    JAX GPU cluster (when a chief IP is given and the backend is JAX) and
    optionally disables JIT before training starts.
    """
    logging.set_verbosity(FLAGS.log_level)

    _tf_setup_from_flags()
    _gin_parse_configs()
    _jax_and_tf_configure_for_devices()

    # Create a JAX GPU cluster if using JAX and given a chief IP.
    if fastmath.is_backend(Backend.JAX):
        if FLAGS.gpu_cluster_chief_ip:
            _make_jax_gpu_cluster(FLAGS.gpu_cluster_host_id,
                                  FLAGS.gpu_cluster_chief_ip,
                                  FLAGS.gpu_cluster_n_hosts,
                                  FLAGS.gpu_cluster_port)

    if FLAGS.disable_jit:
        fastmath.disable_jit()

    output_dir = _output_dir_or_default()
    # Default trainer unless running on TPU with the TF-numpy backend.
    if not FLAGS.use_tpu or not fastmath.is_backend(Backend.TFNP):
        trainer_lib.train(output_dir=output_dir)
    else:
        _train_using_tf(output_dir)

    trainer_lib.log('Finished training.')
Example #15
0
 def _from_bits(self, bits_and_types):
   """Converts a list of bit-cast weights and their types back to weights.

   This is the reverse of `_to_bits`: on JAX, tuple entries are treated as
   `(bits, 'bfloat16')` pairs and bit-cast back to bfloat16; every other entry
   is kept as is. On non-JAX backends the input is returned untouched.
   """
   if not fastmath.is_backend(fastmath.Backend.JAX):
     return bits_and_types
   weights = []
   for entry in bits_and_types:
     if not isinstance(entry, tuple):
       weights.append(entry)
       continue
     bits, dtype = entry
     assert dtype == 'bfloat16'
     weights.append(jax.lax.bitcast_convert_type(bits, jnp.bfloat16))
   return weights
Example #16
0
def _forward_and_or_backward(layer):
  """Builds a forward_and_or_backward function around `layer.pure_fn`."""
  # TODO(lukaszkaiser): remove these 2 lines once PR #4039 lands for JAX.
  if fastmath.is_backend(fastmath.Backend.JAX):
    jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def forward_and_or_backward(inputs, weights, state, rng, output_grad=None,
                              compute_output=True, update_state=True):
    """Performs batched forward and/or backward passes.

    Args:
      inputs: inputs to the attention layer
      weights: weights for the attention layer
      state: state of the attention layer
      rng: PRNG key for the layer (shared across all examples and heads)
      output_grad: gradient of the loss wrt the output of the layer, or None.
          This function performs the backward pass iff `output_grad` is not
          None.
      compute_output: bool: whether to return the output of the forward pass
          (for example, a pure backwards pass does not need to return the
          output).
      update_state: bool: whether to return an updated layer state.

    Returns:
      A tuple (output, new_state, inputs_grad, weights_grad).
      - output is not None iff compute_output is True
      - new_state is not None iff update_state is True
      - inputs_grad & weights_grad are not None iff output_grad is not None
    """
    # Close over state and rng so that only inputs and weights are
    # differentiated.
    def layer_fn(x, w):
      return layer.pure_fn(x, w, state, rng)

    # Vector-Jacobian product of the layer; the new state comes back as aux.
    output, vjp_fn, new_state = fastmath.vjp(
        layer_fn, inputs, weights, has_aux=True)
    if not compute_output:
      output = None
    if not update_state:
      new_state = None

    if output_grad is None:
      return (output, new_state, None, None)
    # The vjp function returns gradients with respect to inputs and weights.
    grads_inputs, grads_weights = vjp_fn(output_grad)
    return (output, new_state, grads_inputs, grads_weights)
  return forward_and_or_backward
Example #17
0
    def _l2_norm(self, flat_list):
        """Returns an L2-like norm of all elements of all tensors in `flat_list`.

        Args:
          flat_list: Collection of tensors as a flat list (rather than, e.g., a
            tree).

        Returns:
          A scalar value computed as if all the tensors in `flat_list` were
          joined and flattened into a single vector, and then the L2 norm of
          that vector was calculated.
        """
        # TODO(lukaszkaiser): add vdot to TF-numpy
        on_jax = fastmath.is_backend(fastmath.Backend.JAX)
        total = 0
        for tensor in flat_list:
            if on_jax:
                total = total + jnp.vdot(tensor, tensor)
            else:
                total = total + jnp.sum(tensor * tensor)
        return jnp.sqrt(total)
Example #18
0
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
  """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head `queries` and `keys`,
    - applies `mask` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head `values` vectors.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention weights.
    mask: Mask selecting the positions to attend to; positions where it is
        false get a score of -1e9 before the softmax. `None` disables masking.
    dropout: Probabilistic rate for dropout applied to attention strengths
        (based on query-key pairs) before applying them to values. `None` or
        0 disables dropout.
    mode: Mode string; dropout is applied only when it equals `'train'`.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    A pair `(out, dots)`, both cast to float32: `out` is the per-head
    activations resulting from the masked attention-weighted sum of per-head
    values, and `dots` is the attention weights that were used (after dropout,
    if any).

  Raises:
    ValueError: If `dropout` is not None and is >= 1.0.
  """
  # Validate before doing any computation. The None guard matters: comparing
  # None with a float raises TypeError, while None is an accepted way to say
  # "no dropout" (see the dropout branch below).
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  d_feature = queries.shape[-1]
  dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data dependency
    # on the input. Broadcasted copies of these use a lot of memory, so they
    # should be computed at runtime (rather than being global constants).
    if fastmath.is_backend(fastmath.Backend.JAX):
      mask = jax.lax.tie_in(dots, mask)
    # JAX's `full_like` already ties in -1e9 to dots.
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
  # Softmax.
  dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    # Rescale the kept weights by 1 / keep-probability; zero the dropped ones.
    dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
  out = jnp.matmul(dots, values)
  out = out.astype(jnp.float32)
  dots = dots.astype(jnp.float32)
  return out, dots
Example #19
0
 def policy(self, trajectory, temperature=1.0):
     """Chooses an action to play after a trajectory.

     Returns a `(sample, pred)` pair: the sampled action and the model output
     for the last timestep from which it was sampled.
     """
     collect_model = self._policy_collect_model
     if temperature == 1.0:
         model = collect_model
     else:
         # When evaluating (t != 1.0), don't collect stats.
         model = self._policy_eval_model
         model.state = collect_model.state
     model.replicate_weights(self._policy_trainer.model_weights)
     # Only the trailing `_max_slice_length` part of the trajectory is used.
     recent = trajectory[-self._max_slice_length:]
     trajectory_np = recent.to_np(timestep_to_np=self.task.timestep_to_np)
     # Add batch dimension to the observations and run the model.
     batched_observations = trajectory_np.observations[None, ...]
     # Pick element 0 from the batch (the only one), last (current) timestep.
     pred = model(batched_observations)[0, -1, :]
     sample = self._policy_dist.sample(pred, temperature=temperature)
     result = (sample, pred)
     if fastmath.is_backend(fastmath.Backend.JAX):
         result = fastmath.nested_map(lambda arr: arr.copy(), result)
     return result
Example #20
0
    def forward(self, inputs):
        """Returns the input activations, with added positional information.

        Raises:
          ValueError: In 'predict' mode, if the dropout rate is not zero.
        """
        if self._mode == 'predict':
            if self._dropout != 0:
                raise ValueError(f'In predict mode, but dropout rate '
                                 f'({self._dropout}) is not zero.')

            # State in this class is only used for fast inference. In that case,
            # the model is called with consecutive elements position-by-position.
            # This positional encoding layer needs to store the index of the current
            # position then and increment it on each call -- that's how state is used
            # and updated below.
            state = self.state
            if inputs.shape[1] == 1:
                self.state = state + 1
                return inputs + jnp.expand_dims(self.weights[0, state, :], 1)
            # Longer inputs: slice one window of weights per batch element,
            # each starting at that element's own position index.
            emb = [
                jax.lax.dynamic_slice_in_dim(self.weights[0],
                                             state[i],
                                             inputs.shape[1],
                                             axis=0)
                for i in range(inputs.shape[0])
            ]
            self.state = state + inputs.shape[1]
            return inputs + jnp.stack(emb, 0)

        x = inputs
        symbol_size = jnp.shape(x)[1]
        px = self.weights[:, :symbol_size, :]
        if self._dropout == 0:
            return x + px

        # Dropout on the positional weights; broadcast dims share one draw.
        noise_shape = list(px.shape)
        for dim in self._dropout_broadcast_dims:
            noise_shape[dim] = 1
        keep_prob = 1.0 - self._dropout
        if fastmath.is_backend(fastmath.Backend.JAX):
            keep_prob = jax.lax.tie_in(
                x, jnp.full((), keep_prob, dtype=x.dtype))
        keep = fastmath.random.bernoulli(self.rng, keep_prob,
                                         tuple(noise_shape))
        return x + px * (keep.astype(x.dtype) / keep_prob)
Example #21
0
    def forward(self, inputs):
        """Returns `inputs` plus the concatenated per-axis embeddings.

        Each weight is broadcast to `(batch,) + self._shape + (features,)` and
        the results are concatenated along the last axis. In 'predict' mode a
        window starting at the stored position is added and the position is
        advanced; otherwise the full embedding is added, with dropout applied
        to it when the dropout rate is non-zero.
        """
        rng, state = self.rng, self.state
        embs = [
            jnp.broadcast_to(
                ax_emb,
                (inputs.shape[0], ) + self._shape + (ax_emb.shape[-1], ))
            for ax_emb in self.weights
        ]

        if self._mode == 'predict':
            assert self._dropout == 0.0
            emb = jnp.concatenate(embs, -1)
            emb = jnp.reshape(emb, (inputs.shape[0], -1, emb.shape[-1]))
            emb = jax.lax.dynamic_slice_in_dim(emb,
                                               state,
                                               inputs.shape[1],
                                               axis=1)
            self.state = state + inputs.shape[1]
            return inputs + emb

        if self._dropout == 0:
            # TODO(kitaev): concat-then-reshape (as is the case with dropout
            # enabled) leads to memory blow-up on TPU, so each embedding is
            # reshaped first and the pieces are concatenated afterwards.
            reshaped = [
                jnp.reshape(emb, inputs.shape[:-1] + (emb.shape[-1], ))
                for emb in embs
            ]
            return inputs + jnp.concatenate(reshaped, -1)

        emb = jnp.concatenate(embs, -1)
        # Broadcast dims get extent 1 so they share a single dropout draw.
        noise_shape = list(emb.shape)
        for dim in self._dropout_broadcast_dims:
            noise_shape[dim] = 1
        keep_prob = 1.0 - self._dropout
        if fastmath.is_backend(fastmath.Backend.JAX):
            keep_prob = jax.lax.tie_in(
                inputs, jnp.full((), keep_prob, dtype=inputs.dtype))
        keep = fastmath.random.bernoulli(rng, keep_prob, tuple(noise_shape))
        multiplier = keep.astype(inputs.dtype) / keep_prob
        return inputs + jnp.reshape(emb * multiplier, inputs.shape)
Example #22
0
def _fast_inference_update_state(inputs, state):
    """Updates state of a causal attention layer for fast inference.

    The layer state stores tensors with cached values of keys and values,
    as well as the mask and an index. To make shapes static, keys and values
    in the state are long, and the index indicates where the new keys and
    values from inputs need to be written. Mask ensures that attention will
    only look at keys up to index.

    During update, new keys and values are written into the cached keys and
    values at the position given by index. The mask (which starts as all-0s)
    is set to 1 at the new key positions, and index is incremented by the
    length of the new keys.

    Args:
      inputs: a triple (new_queries, new_keys, new_values)
      state: layer state with (keys, values, mask, index)

    Returns:
      Updated state.

    Raises:
      ValueError: If the active backend is not JAX.
    """
    if not fastmath.is_backend(fastmath.Backend.JAX):
        raise ValueError(f'JAX backend is required in predict mode, but found '
                         f"backend ({fastmath.backend()['name']}).")

    # Fast inference: run step-by-step, storing the sequence
    # of keys and values calculated so far in state.
    _, new_k, new_v = inputs
    length = new_k.shape[1]
    ks, vs, mask, idx = state
    update = fastmath.dynamic_update_slice_in_dim
    # TODO(lukaszkaiser): benchmark speed and decide if using a separate code path
    # with index_update when length == 1 is worth it.
    # Keys and values are of shape [batch_size, length, d_kv].
    ks = update(ks, new_k, idx, axis=1)
    vs = update(vs, new_v, idx, axis=1)
    # Mask is of shape [batch_size, 1 (for heads), length].
    ones_for_new_keys = jnp.ones((mask.shape[0], mask.shape[1], length))
    mask = update(mask, ones_for_new_keys, idx, axis=2)
    return (ks, vs, mask, idx + length)
Example #23
0
    def __init__(self, blocks, loss_layer, optimizer_fn, n_devices=None):
        """Creates a ReversibleSerialTrainer and the needed optimizers.

    This trainer performs updates equivalent to using the default Trainer on::

      tl.Serial(blocks + [loss_layer]).

    It is more memory-efficient though since weights are stored on CPU and only
    sent to accelerator layer-by-layer. Blocks are pairs consisting of a list
    of standard (arbitrary) layers and a list of reversible layers which help
    save memory thanks to being reversible.

    Args:
      blocks: A list of pairs of lists of standard and reversible layers.
      loss_layer: The final layer of the model; it can have trainable weights
        but should end with a loss: it is required to produce a scalar output.
      optimizer_fn: A function to create the optimizer, e.g., `optimizers.Adam`.
      n_devices: An optional integer, number of accelerator devices to use;
        by default, all available accelerators will be used.
    """
        # TODO(lukaszkaiser): remove these 2 lines once PR #4039 lands for JAX.
        # NOTE: this replaces a private JAX check with a no-op for the whole
        # process, not just for this trainer.
        if fastmath.is_backend(fastmath.Backend.JAX):
            jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access
        # Each block's standard layers are wrapped into a single tl.Serial; the
        # reversible layers are kept as the list that was passed in.
        self._blocks = [(tl.Serial(std), rev) for (std, rev) in blocks]
        self._loss_layer = loss_layer
        self._optimizer_fn = optimizer_fn
        # Falls back to the backend device count when n_devices is None (or 0).
        self._n_devices = n_devices or fastmath.device_count()
        # Per block: one for the Serial of standard layers plus one per
        # reversible layer. The leading 1 presumably counts the loss layer --
        # TODO confirm.
        self._n_layers = 1 + sum([len(revs) + 1 for (_, revs) in self._blocks])

        # Create accelerated versions of layers as pmaped/jited pure_fn.
        self._accelerated_layer_fns = fastmath.nested_map(
            lambda layer: self._pjit(layer.pure_fn), self._blocks)

        # Create per-layer optimizers and replicate opt_params.
        def _make_optimizer(layer):
            # A fresh optimizer per layer, initialized on that layer's weights.
            opt = optimizer_fn()
            opt.tree_init(layer.weights)
            return opt

        # nested_map is applied to self._blocks, so the optimizers are indexed
        # the same way below: self._optimizers[i] unpacks to (std_opt, rev_opts).
        self._optimizers = fastmath.nested_map(_make_optimizer, self._blocks)
        self._replicated_opt_params = fastmath.nested_map(
            lambda opt: self._replicate(opt.opt_params), self._optimizers)

        # The loss layer gets its own optimizer, kept separately from the blocks.
        self._loss_opt = _make_optimizer(loss_layer)
        self._replicated_loss_opt_params = self._replicate(
            self._loss_opt.opt_params)

        # Forward + backward + optimizer-update functions for all layers.
        # We call them in short FBO for "Forward + Backward + Optimizer update".
        # Reversible layers define a reverse_and_fbo function that also reverses.

        # One entry per block: (accelerated std FBO, list of accelerated
        # reverse-and-FBO functions, one per reversible layer).
        self._fbos = []
        for i, (std_layer, rev_layers) in enumerate(self._blocks):
            (std_opt, rev_opts) = self._optimizers[i]
            std_fbo = _fbo_with_layer_and_opt(std_layer, std_opt,
                                              self._n_devices)
            rev_and_fbos = []
            for layer, opt in zip(rev_layers, rev_opts):
                rev_and_fbos.append(
                    self._pjit(
                        _reverse_and_fbo_with_layer_and_opt(
                            layer, opt, self._n_devices)))
            self._fbos.append((self._pjit(std_fbo), rev_and_fbos))

        # The loss layer's FBO is built with the extra 'loss' argument.
        loss_fbo = _fbo_with_layer_and_opt(self._loss_layer, self._loss_opt,
                                           self._n_devices, 'loss')
        self._loss_fbo = self._pjit(loss_fbo)
Example #24
0
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.WeightedCategoryCrossEntropy(),
          inputs=trax_inputs.batcher,
          optimizer=trax_opt.Adafactor,
          lr_schedule_fn=lr.multifactor,
          trainer_class=Trainer,
          steps=1000,
          checkpoints_at=None,
          permanent_checkpoints_at=None,
          eval_steps=10,
          eval_frequency=100,
          permanent_checkpoint_frequency=None,
          random_seed=None,
          save_graphs=True,
          metrics=None,
          checkpoint_highest=None,
          checkpoint_lowest=None,
          use_loop=True,
          loss_chunk_size=0,
          use_memory_efficient_trainer=False,
          adasum=False,
          init_checkpoint=None,
          callbacks=None,
          additional_train_tasks=None,
          additional_eval_tasks=None,
          additional_eval_streams=None):
  """Train the model on the inputs.

  Two code paths exist. With `use_loop=True` (default) a `training.Loop` is
  built and run; `trainer_class`, `save_graphs`, `checkpoint_highest` and
  `checkpoint_lowest` are not consulted on that path. With `use_loop=False`
  a `trainer_class` instance is driven epoch by epoch; the permanent
  checkpoint options, `loss_chunk_size`, `use_memory_efficient_trainer`,
  `adasum`, `callbacks` and the `additional_*` arguments are not consulted
  on that path.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: Callable accepting a `mode` keyword ('train' / 'eval') and returning
      the model to train.
    loss_fn: The loss layer; handed to `training.TrainTask` as `loss_layer`
      (loop path) or to `trainer_class` (trainer path).
    inputs: An inputs object, or a callable returning one (e.g. when configured
      through gin).
    optimizer: The optimizer class/factory. It is instantiated here unless
      `use_memory_efficient_trainer` is set, in which case it is passed on
      uninstantiated.
    lr_schedule_fn: A learning rate schedule function, that when called returns
      a function from step to learning rate (a float).
    trainer_class: The trainer class to use when `use_loop` is False.
    steps: int, total number of training steps.
    checkpoints_at: list of integers. Save a checkpoint for each training step
      in the list.
    permanent_checkpoints_at: list of integers. Save a permanent checkpoint for
      each training step in the list.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    permanent_checkpoint_frequency: int, how often to save permanent checkpoints
      (every permanent_checkpoint_frequency steps).
    random_seed: the random seed to use; time/os dependent if None (default).
    save_graphs: bool, if True, save computation graph to file.
    metrics: optionally override the default metrics dictionary.
    checkpoint_highest: save the checkpoint highest at this metric.
    checkpoint_lowest: save the checkpoint lowest at this metric.
    use_loop: whether to use training.Loop instead of Trainer.
    loss_chunk_size: int, if > 0 chunk loss into these sizes to save memory.
    use_memory_efficient_trainer: whether to use memory-efficient trainer.
    adasum: if True, use adaptive summation for multi-device gradients.
    init_checkpoint: a checkpoint for fine tuning.
    callbacks: a list of callbacks to call during training.
    additional_train_tasks: additional tasks which should be performed during
      training.
    additional_eval_tasks: additional tasks which should be performed during
      evaluation.
    additional_eval_streams: List[NamedStream], additional data streams that
      should be used during evaluation. Can be provided independently of
      additional_eval_tasks.

  Returns:
    trax.TrainerState or training.Loop if use_loop is True

  Raises:
    ValueError: if both `permanent_checkpoint_frequency` and
      `permanent_checkpoints_at` are set.
  """
  if (permanent_checkpoint_frequency is not None
      and permanent_checkpoints_at is not None):
    raise ValueError('Only one of ["permanent_checkpoint_frequency", '
                     '"permanent_checkpoints_at"] should be set.')
  if use_loop:
    n_devices = num_devices() or fastmath.local_device_count()

    # Prepare the training task.
    # Inputs is either an Inputs instance or a function that returns it.
    if callable(inputs):  # If we pass a function, e.g., through gin, call it.
      inputs = inputs()
    opt = optimizer if use_memory_efficient_trainer else optimizer()
    train_task = training.TrainTask(
        inputs.train_stream(n_devices),
        loss_layer=loss_fn,
        optimizer=opt,
        lr_schedule=lr_schedule_fn(),
        n_steps_per_checkpoint=eval_frequency,
        n_steps_per_permanent_checkpoint=permanent_checkpoint_frequency)

    # Accept any sequence of extra tasks (or None) without mutating the
    # caller's object.
    train_tasks = [train_task] + list(additional_train_tasks or [])

    # Prepare the evaluation. Local names are used instead of rebinding the
    # `metrics` parameter.
    metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
    metric_names, metric_layers = zip(*metrics_dict.items())
    eval_task = training.EvalTask(inputs.eval_stream(n_devices),
                                  metric_layers,
                                  metric_names=metric_names,
                                  n_eval_batches=eval_steps)

    # One extra eval task per additional named stream, reusing the same
    # metrics as the main eval task.
    eval_tasks_from_streams = [
        training.EvalTask(stream.stream,
                          metric_layers,
                          metric_names=metric_names,
                          n_eval_batches=eval_steps,
                          export_prefix=stream.name)
        for stream in (additional_eval_streams or [])
    ]
    eval_tasks = ([eval_task] + list(additional_eval_tasks or []) +
                  eval_tasks_from_streams)

    # Prepare the training loop.
    checkpoint_at = None
    if checkpoints_at is not None:
      checkpoint_at = lambda step: step in checkpoints_at
    permanent_checkpoint_at = None
    if permanent_checkpoints_at is not None:
      permanent_checkpoint_at = (lambda step: step in permanent_checkpoints_at)

    # Setup the model.
    model_train = model(mode='train')
    model_predict_eval = model(mode='eval')
    if init_checkpoint:
      model_train.init_from_file(init_checkpoint, weights_only=True)
      model_predict_eval.init_from_file(init_checkpoint, weights_only=True)
    loop = training.Loop(
        model_train, train_tasks,
        eval_model=model_predict_eval,
        eval_tasks=eval_tasks,
        output_dir=output_dir,
        checkpoint_at=checkpoint_at,
        permanent_checkpoint_at=permanent_checkpoint_at,
        n_devices=n_devices,
        loss_chunk_size=loss_chunk_size,
        use_memory_efficient_trainer=use_memory_efficient_trainer,
        adasum=adasum,
        random_seed=random_seed,
        callbacks=callbacks,
    )

    # `loop.step` may be non-zero already; only run the remaining steps.
    steps_to_go = steps - loop.step
    if steps_to_go <= 0:
      log('Stop training, already reached the total training steps %d' % steps)
      return loop

    # Train and return the loop.
    loop.run(steps_to_go)
    return loop

  n_devices = num_devices()
  trainer = trainer_class(model, loss_fn, optimizer, lr_schedule_fn(), inputs,
                          output_dir,
                          random_seed=random_seed,
                          n_devices=n_devices,
                          checkpoints_at=checkpoints_at,
                          metrics=metrics,
                          checkpoint_lowest=checkpoint_lowest,
                          checkpoint_highest=checkpoint_highest,
                          init_checkpoint=init_checkpoint)

  epoch_steps = [steps]  # Only training if eval_frequency is 0 or None
  # `eval_steps` may legitimately be None (eval disabled); comparing None with
  # an int raises TypeError on Python 3, so test truthiness first.
  if eval_frequency and eval_steps and eval_steps > 0:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  trainer.log_step('Starting training using %d devices' % trainer.n_devices)
  trainer.print_n_weights()

  # No `except` clause: errors propagate unchanged, `finally` still closes the
  # trainer.
  try:
    for n_epoch_steps in epochs(steps, trainer.step, epoch_steps):
      trainer.train_epoch(n_epoch_steps, eval_steps)

      # Bookkeeping we do at the first step
      if trainer.step == 1:
        # Save computation graph (single-device only for now)
        if save_graphs and fastmath.is_backend(fastmath.Backend.JAX):
          trainer.save_computation_graphs()

        # Save Gin config
        trainer.save_gin()

    trainer.log_step('Training done')
  finally:
    trainer.close()
  return trainer.state
Example #25
0
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train'):
    """Builds a reversible encoder-decoder Transformer.

    The assembled model takes two inputs: encoder-side (source) tokens first,
    then decoder-side (target) tokens. Encoder blocks use `tl.SelfAttention`;
    for the attention types in the Reformer paper, see ReformerLM.

    Args:
      input_vocab_size: int: vocab size of the source.
      output_vocab_size: int (optional): vocab size of the target. If None, the
        source vocab size is used for the target as well.
      d_model: int: depth of embedding
      d_ff: int: depth of feed-forward layer
      n_encoder_layers: int: number of encoder layers
      n_decoder_layers: int: number of decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      ff_activation: the non-linearity in feed-forward layer
      ff_dropout: float: (optional) separate dropout rate at feed-forward
        nonlinearity. This is called relu_dropout in T2T.
      mode: str: 'train', 'eval' or 'predict'. In 'predict' mode the encoder is
        built in 'eval' mode and wrapped in `tl.Cache`.

    Returns:
      A `tl.Serial` layer ending in `tl.Dense(output_vocab_size)` followed by
      `tl.LogSoftmax()`.
    """
    # The current API for custom gradients assumes that a layer must be
    # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
    # masks on the stack. This causes jax to error, even though the so-called
    # "gradient" wrt the masks is never actually computed.
    # TODO(kitaev): remove this hack.
    if fastmath.is_backend(fastmath.Backend.JAX):
        jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

    if output_vocab_size is None:
        output_vocab_size = input_vocab_size

    def _token_embedder(vocab_size, embed_mode):  # tokens --> vectors
        # TODO(kitaev): axial positional encoding is better for very long sequences.
        pos_enc = tl.PositionalEncoding(max_len=max_len,
                                        dropout=dropout,
                                        mode=embed_mode)
        embedding = tl.Embedding(vocab_size, d_model)
        embedding_dropout = tl.Dropout(rate=dropout,
                                       shared_axes=[-2],
                                       mode=embed_mode)
        return [embedding, embedding_dropout, pos_enc]

    def _mean_of_pair(x, y):
        # Merges the two residual streams of a reversible stack.
        return (x + y) / 2.0

    # 'predict' runs the decoder one token at a time, whereas the encoder
    # always sees whole sequences -- so the encoder side falls back to 'eval'.
    predicting = mode == 'predict'
    encoder_mode = 'eval' if predicting else mode
    in_encoder = _token_embedder(input_vocab_size, encoder_mode)
    out_encoder = _token_embedder(output_vocab_size, mode)

    encoder_blocks = []
    for _ in range(n_encoder_layers):
        encoder_blocks.append(
            EncoderBlock(d_model, d_ff, n_heads, tl.SelfAttention, dropout,
                         ff_activation, ff_dropout, mode))

    encoder_layers = [
        in_encoder,
        tl.Dup(),
        tl.ReversibleSerial(encoder_blocks),
        tl.Fn('XYAvg', _mean_of_pair),
        tl.LayerNorm(),
    ]
    encoder = tl.Serial(encoder_layers)
    if predicting:
        encoder = tl.Cache(encoder)

    encoder_decoder_blocks = []
    for _ in range(n_decoder_layers):
        encoder_decoder_blocks.append(
            EncoderDecoderBlock(d_model, d_ff, n_heads, dropout,
                                ff_activation, ff_dropout, mode))

    # Stack contents after each stage are noted on the right.
    # Input: encoder_side_tokens, decoder_side_tokens
    model_layers = [
        # Copy decoder tokens for use in loss.
        tl.Select([0, 1, 1]),  # tok_e tok_d tok_d
        tl.Branch([], [
            tl.PaddingMask(),
            tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1)
        ]),  # tok_e mask tok_d .....

        # Encode.
        encoder,  # vec_e mask tok_d .....

        # Decode.
        tl.Select([2, 0, 1]),  # tok_d vec_e mask .....
        tl.ShiftRight(mode=mode),  # tok_d vec_e mask .....
        out_encoder,  # vec_d vec_e mask .....
        tl.Dup(),  # vec_d1 vec_d2 vec_e mask .....
        tl.ReversibleSerial(encoder_decoder_blocks),
        tl.Fn('XYAvg', _mean_of_pair),  # vec_d vec_e mask .....
        tl.LayerNorm(),  # vec_d vec_e mask .....

        # Map to output vocab.
        tl.Select([0], n_in=3),  # vec_d .....
        tl.Dense(output_vocab_size),  # vec_d .....
        tl.LogSoftmax(),  # vec_d .....
    ]
    return tl.Serial(*model_layers)
Example #26
0
def Reformer2(input_vocab_size,
              output_vocab_size=None,
              d_model=512,
              d_ff=2048,
              d_attention_key=64,
              d_attention_value=64,
              n_encoder_layers=6,
              n_decoder_layers=6,
              n_heads=8,
              dropout=0.1,
              max_len=2048,
              encoder_attention_type=tl.SelfAttention,
              encoder_decoder_attention_type=tl.SelfAttention,
              axial_pos_shape='infinite',
              d_axial_pos_embs=None,
              ff_activation=tl.Relu,
              ff_use_sru=0,
              ff_chunk_size=0,
              ff_dropout=None,
              mode='train'):
    """Builds a reversible encoder-decoder Transformer with a joint decoder.

    The assembled model takes two inputs: source tokens first, then target
    tokens. Encoder output and embedded (right-shifted) target tokens are
    concatenated with `_ConcatWithPadding` and processed together by the
    decoder blocks; the encoder part is stripped again before the output
    projection.

    Args:
      input_vocab_size: int: vocab size of the source.
      output_vocab_size: int (optional): vocab size of the target. If None, the
        source vocab size is used for the target as well.
      d_model: int: depth of embedding
      d_ff: int: depth of feed-forward layer
      d_attention_key: int: depth of key vector for each attention head
      d_attention_value: int: depth of value vector for each attention head
      n_encoder_layers: int: number of encoder layers
      n_decoder_layers: int: number of decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      encoder_attention_type: class: attention class used by every encoder
        block, such as SelfAttention
      encoder_decoder_attention_type: class, or tuple/list of classes: attention
        class(es) for the decoder blocks. A sequence is cycled over the decoder
        layers and its length must divide n_decoder_layers.
      axial_pos_shape: forwarded to `PositionalEncoding` (input shape for the
        axial position encoding).
      d_axial_pos_embs: forwarded to `PositionalEncoding` (depth of position
        embedding for each axis).
      ff_activation: the non-linearity in feed-forward layer
      ff_use_sru: int; if > 0, we use this many SRU layers instead of
        feed-forward
      ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
      ff_dropout: float: (optional) separate dropout rate at feed-forward
        nonlinearity. This is called relu_dropout in T2T. Only the encoder
        blocks receive it here.
      mode: str: 'train', 'eval' or 'predict'. In 'predict' mode the encoder is
        built in 'eval' mode and wrapped in `tl.Cache`.

    Returns:
      A `tl.Serial` layer ending in `tl.Dense(output_vocab_size)` followed by
      `tl.LogSoftmax()`.
    """

    # assert d_model // n_heads == d_attention_key, \
    #     f'{d_model} // {n_heads} != {d_attention_key}'
    # assert d_model // n_heads == d_attention_value, \
    #     f'{d_model} // {n_heads} != {d_attention_value}'

    # The current API for custom gradients assumes that a layer must be
    # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
    # masks on the stack. This causes jax to error, even though the so-called
    # "gradient" wrt the masks is never actually computed.
    # TODO(kitaev): remove this hack.
    if fastmath.is_backend(fastmath.Backend.JAX):
        jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

    if output_vocab_size is None:
        output_vocab_size = input_vocab_size

    def _token_embedder(vocab_size, embed_mode):  # tokens --> vectors
        pos_enc = PositionalEncoding(embed_mode, dropout, max_len,
                                     axial_pos_shape, d_axial_pos_embs)
        embedding = tl.Embedding(vocab_size, d_model)
        embedding_dropout = tl.Dropout(rate=dropout,
                                       shared_axes=[-2],
                                       mode=embed_mode)
        return [embedding, embedding_dropout, pos_enc]

    def _mean_of_pair(x, y):
        # Merges the two residual streams of a reversible stack.
        return (x + y) / 2.0

    # 'predict' runs the decoder one token at a time, whereas the encoder
    # always sees whole sequences -- so the encoder side falls back to 'eval'.
    predicting = mode == 'predict'
    encoder_mode = 'eval' if predicting else mode
    in_encoder = _token_embedder(input_vocab_size, encoder_mode)
    out_encoder = _token_embedder(output_vocab_size, mode)

    encoder_blocks = []
    for _ in range(n_encoder_layers):
        encoder_blocks.append(
            EncoderBlock(d_model, d_ff, n_heads, encoder_attention_type,
                         dropout, ff_activation, ff_dropout, mode))

    encoder_layers = [  # tok_e mask_e tok_e tok_d tok_d
        in_encoder,  # vec_e mask_e tok_e tok_d tok_d
        tl.Dup(),  # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
        tl.ReversibleSerial(encoder_blocks),
        tl.Fn('XYAvg', _mean_of_pair),
        tl.LayerNorm(),
    ]
    encoder = tl.Serial(encoder_layers)
    if predicting:
        encoder = tl.Cache(encoder)

    # Normalize the decoder attention spec to a sequence that is cycled over
    # the decoder layers.
    attention_types = encoder_decoder_attention_type
    if not isinstance(attention_types, (tuple, list)):
        attention_types = [attention_types]
    else:
        assert n_decoder_layers % len(attention_types) == 0
    n_attention_types = len(attention_types)

    decoder_blocks = [
        DecoderBlock(d_model,
                     d_ff,
                     d_attention_key,
                     d_attention_value,
                     n_heads,
                     attention_type=attention_types[idx % n_attention_types],
                     dropout=dropout,
                     ff_activation=ff_activation,
                     ff_use_sru=ff_use_sru,
                     ff_chunk_size=ff_chunk_size,
                     mode=mode)
        for idx in range(n_decoder_layers)
    ]

    # Stack contents after each stage are noted on the right.
    # Input: encoder_side_tokens, decoder_side_tokens
    model_layers = [
        # Copy decoder tokens for use in loss.
        tl.Select([0, 0, 1, 1]),  # tok_e tok_e tok_d tok_d
        tl.Branch([], [
            tl.PaddingMask(),
            tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1)
        ]),  # tok_e mask_e tok_e tok_d tok_d

        # Encode.
        encoder,  # vec_e mask_e tok_e tok_d tok_d

        # Decode.
        tl.Select([3, 0, 1, 2]),  #  tok_d vec_e mask_e tok_e tok_d
        tl.ShiftRight(mode=mode),  # stok_d vec_e mask_e tok_e tok_d
        tl.Branch(
            [],
            _MaskOfRightShiftedArray()),  # stok_d mask_d vec_e mask_e tok_e tok_d
        out_encoder,  # svec_d mask_d vec_e mask_e tok_e tok_d

        # Concat encoder and decoder, given their masks.
        tl.Select([2, 0, 3, 1]),  # svec_d mask_d vec_e mask_e tok_e tok_d
        _ConcatWithPadding(),  # vec_ed tok_e tok_d

        # Run (encoder and) decoder blocks.
        tl.Dup(),  # vec_ed1 vec_ed2 tok_e tok_d
        tl.ReversibleSerial(decoder_blocks),  # vec_ed1 vec_ed2 tok_e tok_d
        tl.Fn('XYAvg', _mean_of_pair),  # vec_ed tok_e tok_d
        tl.LayerNorm(),  # vec_ed tok_e tok_d

        # Separate out the encoder part from the concatenated vector.
        tl.Select([0, 1, 2, 2]),  # vec_ed tok_e tok_d tok_d
        _StripFromConcatenateWithPadding(),  # vec_d tok_d

        # Map to output vocab.
        tl.Dense(output_vocab_size),  # vec_d tok_d
        tl.LogSoftmax(),  # vec_d tok_d
    ]
    return tl.Serial(*model_layers)
Example #27
0
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.Serial(tl.LogSoftmax(),
                            tl.CrossEntropyLoss(),
                            name='CrossEntropyLoss'),
          inputs=trax_inputs.batcher,
          optimizer=trax_opt.Adafactor,
          lr_schedule_fn=lr.multifactor,
          trainer_class=Trainer,
          steps=1000,
          checkpoints_at=None,
          eval_steps=10,
          eval_frequency=100,
          random_seed=None,
          save_graphs=True,
          metrics=None,
          checkpoint_highest=None,
          checkpoint_lowest=None,
          use_loop=True,
          loss_chunk_size=0,
          use_memory_efficient_trainer=False):
    """Train the model on the inputs.

    Two code paths exist. With `use_loop=True` (default) a `training.Loop` is
    built and run; `trainer_class`, `save_graphs`, `checkpoint_highest` and
    `checkpoint_lowest` are not consulted on that path. With `use_loop=False`
    a `trainer_class` instance is driven epoch by epoch; `loss_chunk_size` and
    `use_memory_efficient_trainer` are not consulted on that path.

    Args:
      output_dir: Directory where to put the logs and checkpoints.
      model: Callable accepting a `mode` keyword ('train' / 'eval') and
        returning the model to train.
      loss_fn: The loss layer; handed to `training.TrainTask` as `loss_layer`
        (loop path) or to `trainer_class` (trainer path).
      inputs: An inputs object, or a callable returning one (e.g. when
        configured through gin).
      optimizer: The optimizer class/factory. It is instantiated here unless
        `use_memory_efficient_trainer` is set, in which case it is passed on
        uninstantiated.
      lr_schedule_fn: A learning rate schedule function, that when called
        returns a function from step to learning rate (a float).
      trainer_class: The trainer class to use when `use_loop` is False.
      steps: int, total number of training steps.
      checkpoints_at: list of integers. Save a checkpoint for each training
        step in the list.
      eval_steps: int, num of steps per evaluation. If None or 0, eval
        disabled.
      eval_frequency: int, how often to run evaluation (every eval_frequency
        steps). If None or 0, eval disabled.
      random_seed: the random seed to use; time/os dependent if None (default).
      save_graphs: bool, if True, save computation graph to file.
      metrics: optionally override the default metrics dictionary.
      checkpoint_highest: save the checkpoint highest at this metric.
      checkpoint_lowest: save the checkpoint lowest at this metric.
      use_loop: whether to use training.Loop instead of Trainer.
      loss_chunk_size: int, if > 0 chunk loss into these sizes to save memory.
      use_memory_efficient_trainer: whether to use memory-efficient trainer.

    Returns:
      trax.TrainerState or training.Loop if use_loop is True
    """
    if use_loop:
        n_devices = num_devices() or fastmath.device_count()

        # Prepare the training task.
        # Inputs is either an Inputs instance or a function that returns it.
        if callable(inputs):  # E.g. a function passed through gin: call it.
            inputs = inputs()
        opt = optimizer if use_memory_efficient_trainer else optimizer()
        train_task = training.TrainTask(inputs.train_stream(n_devices),
                                        loss_layer=loss_fn,
                                        optimizer=opt,
                                        lr_schedule=lr_schedule_fn(),
                                        n_steps_per_checkpoint=eval_frequency)

        # Prepare the evaluation. Local names are used instead of rebinding
        # the `metrics` parameter.
        metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
        metric_names, metric_layers = zip(*metrics_dict.items())
        eval_task = training.EvalTask(inputs.eval_stream(n_devices),
                                      metric_layers,
                                      metric_names=metric_names,
                                      n_eval_batches=eval_steps)

        # Prepare the training loop.
        checkpoint_at = None
        if checkpoints_at is not None:
            checkpoint_at = lambda step: step in checkpoints_at
        loop = training.Loop(
            model(mode='train'), [train_task],
            eval_model=model(mode='eval'),
            eval_tasks=[eval_task],
            output_dir=output_dir,
            checkpoint_at=checkpoint_at,
            n_devices=n_devices,
            loss_chunk_size=loss_chunk_size,
            use_memory_efficient_trainer=use_memory_efficient_trainer,
            random_seed=random_seed)

        # `loop.step` may be non-zero already; only run the remaining steps.
        steps_to_go = steps - loop.step
        if steps_to_go <= 0:
            log('Stop training, already reached the total training steps %d' %
                steps)
            return loop

        # Train and return the loop.
        loop.run(steps_to_go)
        return loop

    n_devices = num_devices()
    trainer = trainer_class(model,
                            loss_fn,
                            optimizer,
                            lr_schedule_fn(),
                            inputs,
                            output_dir,
                            random_seed=random_seed,
                            n_devices=n_devices,
                            checkpoints_at=checkpoints_at,
                            metrics=metrics,
                            checkpoint_lowest=checkpoint_lowest,
                            checkpoint_highest=checkpoint_highest)

    epoch_steps = [steps]  # Only training if eval_frequency is 0 or None
    # `eval_steps` may legitimately be None (eval disabled); comparing None
    # with an int raises TypeError on Python 3, so test truthiness first.
    if eval_frequency and eval_steps and eval_steps > 0:
        epoch_steps = itertools.chain(
            [
                1,  # first epoch only 1 step
                eval_frequency - 1
            ],
            itertools.repeat(eval_frequency))
    trainer.log_step('Starting training using %d devices' % trainer.n_devices)
    trainer.print_n_weights()

    # No `except` clause: errors propagate unchanged, `finally` still closes
    # the trainer.
    try:
        for n_epoch_steps in epochs(steps, trainer.step, epoch_steps):
            trainer.train_epoch(n_epoch_steps, eval_steps)

            # Bookkeeping we do at the first step
            if trainer.step == 1:
                # Save computation graph (single-device only for now)
                if save_graphs and fastmath.is_backend(fastmath.Backend.JAX):
                    trainer.save_computation_graphs()

                # Save Gin config
                trainer.save_gin()

        trainer.log_step('Training done')
    finally:
        trainer.close()
    return trainer.state
Example #28
0
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 checkpoints_at=None,
                 should_save_checkpoints=True,
                 should_write_summaries=True,
                 metrics=None,
                 checkpoint_highest=None,
                 checkpoint_lowest=None):
        """Sets up models, optimizer, metrics and jitted functions for training.

        Args:
          model: Callable invoked as `model(mode='train')` and
            `model(mode='eval')` to build the training and evaluation models.
          loss_fn: Loss layer; appended to the training model with `tl.Serial`
            and also passed to `_jit_update_fn`.
          optimizer: Callable invoked as `optimizer(learning_rate=0.0)` to
            create the optimizer.
          lr_schedule: Learning rate schedule; stored as-is for later use.
          inputs: Inputs object exposing `example_shape_dtype`, or a callable
            returning one.
          output_dir: Passed to `reset()`. If falsy, checkpoint saving and
            summary writing are both disabled.
          random_seed: Forwarded to `training.init_host_and_devices`.
          n_devices: Forwarded to `training.init_host_and_devices`, which
            returns the device count actually stored on the instance.
          checkpoints_at: Stored as-is; defaults to an empty list when None.
          should_save_checkpoints: Honoured only on the chief host and only
            when `output_dir` is set.
          should_write_summaries: Honoured only when `output_dir` is set.
          metrics: Dict from metric name to metric layer; defaults to
            `_DEFAULT_METRICS` when None.
          checkpoint_highest: Stored as-is.
          checkpoint_lowest: Stored as-is.
        """

        # Host/device setup also yields the base rng used below.
        self._is_chief, _, self._n_devices, rng = (
            training.init_host_and_devices(n_devices, random_seed))
        # Only the chief host saves checkpoints.
        self._should_save_checkpoints = should_save_checkpoints and self._is_chief
        self._checkpoints_at = checkpoints_at if checkpoints_at is not None else []
        self._should_write_summaries = should_write_summaries
        # Without an output directory there is nowhere to write anything.
        if not output_dir:
            self._should_save_checkpoints = False
            self._should_write_summaries = False
        self._checkpoint_highest = checkpoint_highest
        self._checkpoint_lowest = checkpoint_lowest
        self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
        # Inputs is either an Inputs instance or a function that returns it.
        self._inputs = inputs
        if callable(
                inputs):  # If we pass a function, e.g., through gin, call it.
            self._inputs = inputs()
        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Setup the model.
        model_train = model(mode='train')
        model_predict_eval = model(mode='eval')
        self._model_with_loss = tl.Serial(model_train, loss_fn)

        # Setup state.
        # One rng is reserved for initialization; the other is split into one
        # rng per device.
        rng, init_rng = jax_random.split(rng)
        self._rngs = np.stack(jax_random.split(rng, self._n_devices))
        shapes, dtypes = self._inputs.example_shape_dtype
        input_signature = tuple(
            ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))

        def new_opt_state_and_model_state(rng):
            """Returns optimizer and model states suitable for training a model."""
            weights, state = self._model_with_loss.init(input_signature,
                                                        rng=rng)
            (slots, opt_params) = opt.tree_init(weights)
            return (OptState(weights, slots, opt_params), state)

        if fastmath.is_backend(fastmath.Backend.JAX):
            # JIT parameter initialization to avoid memory fragmentation
            new_opt_state_and_model_state = (
                fastmath.jit(new_opt_state_and_model_state))
        # The lambda looks up `new_opt_state_and_model_state` when called, so it
        # uses the jitted version whenever the rebinding above took place.
        self._new_opt_state_and_model_state = (
            lambda: new_opt_state_and_model_state(init_rng))

        # Arrange and initialize metrics layers.
        # Metric names are sorted, so the branch order is deterministic.
        self._metrics = list(sorted(self._metrics_dict.keys()))
        metrics_layers = [self._metrics_dict[m] for m in self._metrics]
        metrics_in_parallel = tl.Branch(*metrics_layers)
        metrics_in_parallel.rng = init_rng
        # Built from the same `example_shape_dtype` as `input_signature` above.
        example_signature = tuple(
            ShapeDtype(s, d)
            for (s, d) in zip(*self._inputs.example_shape_dtype))
        # The eval model is initialized before its output signature is queried.
        model_predict_eval.init(example_signature)
        self._input_signature = example_signature
        output_signature = model_predict_eval.output_signature(
            example_signature)
        # Metrics consume the eval model's outputs, so they are initialized on
        # its output signature.
        m_weights, m_state = metrics_in_parallel.init(output_signature)
        # NOTE(review): `_for_n_devices` presumably replicates values across
        # devices -- confirm against its definition.
        self._metrics_weights = self._for_n_devices(m_weights)
        self._metrics_state = self._for_n_devices(m_state)

        # Jit model_predict and update so they're fast.
        self._jit_eval = _jit_predict_fn(model_predict_eval,
                                         metrics_in_parallel, self._n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             self._n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        self._lr_schedule = lr_schedule

        # Those fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._opt_state = None
        self._step = None
        self._model_state = None
        # Must run last: it is called only after every attribute above is set.
        self.reset(output_dir)
Example #29
0
    def __init__(self,
                 first_layer,
                 reversible_layers,
                 loss_layer,
                 optimizer_fn,
                 n_devices=None):
        """Sets up per-layer optimizers and accelerated per-layer step functions.

        Updates performed by this trainer are equivalent to running the default
        Trainer on::

          tl.Serial([first_layer] + reversible_layers + [loss_layer])

        The difference is memory use: weights are stored on CPU and are sent
        to the accelerator only layer-by-layer. The first layer and the loss
        layer may be arbitrary layers, including a `tl.Serial` combination of
        layers. Only a single block of reversible layers is supported for now.

        Args:
          first_layer: The first layer of the model; it can be arbitrary.
          reversible_layers: A list of reversible layers that are executed
            after the first layer. Their activations are not kept in memory,
            and their weights are moved to CPU RAM after each layer to free
            accelerator memory.
          loss_layer: The final layer of the model; it can have trainable
            weights but should end with a loss: it is required to produce a
            scalar output.
          optimizer_fn: A function to create the optimizer, e.g.,
            `optimizers.Adam`. It is called once per layer, with no arguments.
          n_devices: An optional integer, number of accelerator devices to
            use; when not given, `fastmath.device_count()` is used.
        """
        # TODO(lukaszkaiser): remove these 2 lines once PR #4039 lands for JAX.
        if fastmath.is_backend(fastmath.Backend.JAX):
            jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

        self._first_layer = first_layer
        self._reversible_layers = reversible_layers
        self._loss_layer = loss_layer
        self._optimizer_fn = optimizer_fn
        self._n_devices = n_devices if n_devices else fastmath.device_count()

        # Accelerated (pmaped/jited) forward functions: one for the first
        # layer and one for each reversible layer.
        self._accelerated_first_layer_fn = self._pjit(first_layer.pure_fn)
        self._accelerated_reversible_layers_fns = [
            self._pjit(rev_layer.pure_fn) for rev_layer in reversible_layers
        ]

        # One optimizer per layer, in model order (first, reversible..., loss),
        # together with a replicated copy of each optimizer's opt_params.
        self._optimizers = []
        self._replicated_opt_params = []
        all_layers = [first_layer] + reversible_layers + [loss_layer]
        for model_layer in all_layers:
            layer_opt = optimizer_fn()
            layer_opt.tree_init(model_layer.weights)
            self._optimizers.append(layer_opt)
            self._replicated_opt_params.append(
                self._replicate(layer_opt.opt_params))

        # Below, "FBO" is short for "Forward + Backward + Optimizer update":
        # one function per layer kind that does all three in a single call.

        def first_fbo(inputs, weights, state, slots, opt_params, rng, step,
                      grads):
            """FBO of the first layer, given gradients w.r.t. its output."""

            def forward(x, w):
                # State and rng are closed over, so the vjp below is taken
                # with respect to inputs and weights only.
                return first_layer.pure_fn(x, w, state, rng)

            # The forward activations are not used here; only the pullback
            # function and the auxiliary output (the new state) are needed.
            _, pullback, new_state = fastmath.vjp(
                forward, inputs, weights, has_aux=True)

            # The pullback yields gradients for (inputs, weights); the input
            # gradients are discarded since nothing precedes the first layer.
            _, weight_grads = pullback(grads)

            # With several devices, average the gradients across them.
            if self._n_devices > 1:
                weight_grads = _average_multidevice_gradients(weight_grads)

            # The first layer's optimizer sits at index 0.
            first_optimizer = self._optimizers[0]
            new_weights, new_slots, stats = first_optimizer.tree_update(
                step, weight_grads, weights, slots, opt_params)
            return new_weights, new_state, new_slots, stats

        self._first_fbo = self._pjit(first_fbo)

        def loss_fbo(inputs, weights, state, slots, opt_params, rng, step):
            """FBO of the loss layer.

            Unlike the first-layer FBO, this takes no incoming gradients (the
            backward pass is seeded with 1.0), it also returns the gradients
            with respect to its inputs, and it records the loss in the stats.
            """

            def forward(x, w):
                # State and rng are closed over, so the vjp below is taken
                # with respect to inputs and weights only.
                return loss_layer.pure_fn(x, w, state, rng)

            loss, pullback, new_state = fastmath.vjp(
                forward, inputs, weights, has_aux=True)

            # The loss is a scalar and is the last output, so seed the
            # backward pass with a scalar one of the same dtype.
            seed = jnp.ones((), dtype=loss.dtype)
            grads_inputs, weight_grads = pullback(seed)

            # With several devices, average the gradients across them.
            if self._n_devices > 1:
                weight_grads = _average_multidevice_gradients(weight_grads)

            # The loss layer is last, hence so is its optimizer.
            loss_optimizer = self._optimizers[-1]
            new_weights, new_slots, stats = loss_optimizer.tree_update(
                step, weight_grads, weights, slots, opt_params)
            stats['loss'] = loss
            return new_weights, new_state, new_slots, grads_inputs, stats

        self._loss_fbo = self._pjit(loss_fbo)

        def make_reverse_and_fbo(layer, optimizer):
            """Returns the reverse-and-FBO function bound to one layer/optimizer."""

            def reverse_and_fbo(output, weights, state, new_state, slots,
                                opt_params, rng, step, grads):
                """Reverses the layer, then runs backward and optimizer update."""
                # `reverse_and_grad` gives back the layer inputs along with
                # the gradients with respect to inputs and weights.
                reversed_result = layer.reverse_and_grad(
                    output, grads, weights, state, new_state, rng=rng)
                inputs, (grads_inputs, weight_grads) = reversed_result

                # Layers without weights have nothing to optimize: hand back
                # weights and slots untouched, with empty stats.
                if not weights:
                    return weights, slots, inputs, grads_inputs, {}

                # With several devices, average the gradients across them.
                if self._n_devices > 1:
                    weight_grads = _average_multidevice_gradients(
                        weight_grads)

                new_weights, new_slots, stats = optimizer.tree_update(
                    step, weight_grads, weights, slots, opt_params)
                return new_weights, new_slots, inputs, grads_inputs, stats

            return reverse_and_fbo

        # Reversible layers pair up with the optimizers between the first
        # (index 0) and the last (loss) one.
        self._reverse_and_fbos = [
            self._pjit(make_reverse_and_fbo(rev_layer, rev_opt))
            for rev_layer, rev_opt in zip(reversible_layers,
                                          self._optimizers[1:-1])
        ]
Example #30
0
def on_cpu(x):
    """Places ``x`` on the first CPU device when the JAX backend is active.

    Args:
      x: The value to move; it is handed to ``jax.device_put`` as is.

    Returns:
      The result of ``jax.device_put`` targeting ``jax.devices('cpu')[0]`` on
      the JAX backend; on any other backend, ``x`` itself, unchanged.
    """
    if not fastmath.is_backend(fastmath.Backend.JAX):
        return x
    cpu_device = jax.devices('cpu')[0]
    return jax.device_put(x, cpu_device)