Example #1
    def _init_host_and_devices(self, n_devices=None, random_seed=None):
        """Initializes host and device attributes for this trainer.

    Args:
      n_devices: Number of devices this trainer will use. If `None`, get the
          number from the backend.
      random_seed: Random seed as the starting point for all random numbers used
          by the trainer. If `None`, calculate one from system time and host id.

    Returns:
      is_chief: True if this trainer has special chief responsibilities.
      n_devices: The passed in value of n_devices or a computed default.
      random_seed: The passed in value of random_seed or a computed default.
    """
        if math.backend_name() == 'jax':
            host_id = jax.host_id()
            host_count = jax.host_count()
        else:
            host_id = 0
            host_count = 1
        is_chief = (host_id == 0)

        device_count = math.device_count()
        n_devices = n_devices or device_count
        # TODO(lukaszkaiser): remove this restriction when possible.
        if n_devices != device_count and math.backend_name() == 'jax':
            raise ValueError(
                'JAX cannot work yet with n_devices != all devices: '
                '%d != %d' % (n_devices, device_count))

        if random_seed is None and host_count > 1:
            random_seed = int(1e6 * (host_id + time.time())) % 2**32
        return is_chief, n_devices, init_random_number_generators(random_seed)
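A minimal standalone sketch (not part of the example above) of the fallback seed computation used when random_seed is None on a multi-host setup: the host id and wall-clock time are mixed and clamped to a 32-bit value. The helper name is hypothetical.

import time

def derive_host_seed(host_id):
    # Mix host id and wall-clock time, then reduce to a 32-bit range,
    # mirroring the expression in the trainer above.
    return int(1e6 * (host_id + time.time())) % 2**32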
Example #2
    def forward(self, x):
        """Executes this layer as part of a forward pass through the model.

    Args:
      x: Tensor of activations.

    Returns:
      Tensor of same shape and dtype as the input.
    """
        if self._mode != 'train':
            return x
        state, rng = self.state, self.rng
        rate = self._initial_rate
        if isinstance(state, dict) and self._name in state:
            rate = state[self._name]
        mask_shape = list(x.shape)
        for axis in self._shared_axes:
            mask_shape[axis] = 1
        if math.backend_name() == 'jax':
            keep_prob = jax.lax.tie_in(self.rng, 1.0 - rate)
        else:
            keep_prob = 1.0 - rate
        keep = math.random.bernoulli(rng, keep_prob, tuple(mask_shape))
        if math.backend_name() == 'jax':
            keep_prob = jax.lax.tie_in(keep, keep_prob)
        mask = keep.astype(x.dtype) / keep_prob
        return x * mask
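The dropout math above can be illustrated with a plain-NumPy sketch (a reference for illustration only, not the Trax layer itself): a Bernoulli keep mask is drawn with size-1 shared axes so it broadcasts, and kept activations are rescaled by 1/keep_prob.

import numpy as onp

def broadcasted_dropout_reference(x, rate, shared_axes, seed=0):
    # Mask shape has 1s on the shared axes so one mask value is broadcast
    # across each of those dimensions, as in the layer above.
    mask_shape = list(x.shape)
    for axis in shared_axes:
        mask_shape[axis] = 1
    keep_prob = 1.0 - rate
    keep = onp.random.default_rng(seed).random(tuple(mask_shape)) < keep_prob
    # Inverted dropout: rescale kept activations so the expectation is unchanged.
    return x * keep.astype(x.dtype) / keep_prob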
Example #3
  def forward_with_state(self, inputs, weights, state, rng):
    if self._mode != 'predict':
      x = inputs
      symbol_size = jnp.shape(x)[1]
      px = weights[:, :symbol_size, :]
      if self._dropout == 0:
        return (x + px, state)
      else:
        noise_shape = list(px.shape)
        for dim in self._dropout_broadcast_dims:
          noise_shape[dim] = 1
        keep_prob = 1.0 - self._dropout
        if math.backend_name() == 'jax':
          keep_prob = jax.lax.tie_in(x, jnp.full((), keep_prob, dtype=x.dtype))
        keep = math.random.bernoulli(rng, keep_prob, tuple(noise_shape))
        multiplier = keep.astype(x.dtype) / keep_prob
        return (x + px * multiplier, state)
    else:
      if self._dropout != 0:
        raise ValueError(f'In predict mode, but dropout rate '
                         f'({self._dropout}) is not zero.')

      # State in this class is only used for fast inference. In that case,
      # the model is called with consecutive elements position-by-position.
      # This positional encoding layer needs to store the index of the current
      # position then and increment it on each call -- that's how state is used
      # and updated below.
      if inputs.shape[1] == 1:
        return (inputs + jnp.expand_dims(weights[0, state, :], 1), state + 1)
      else:
        emb = []
        for i in range(inputs.shape[0]):
          emb.append(jax.lax.dynamic_slice_in_dim(
              weights[0], state[i], inputs.shape[1], axis=0))
        return inputs + jnp.stack(emb, 0), state + inputs.shape[1]
Example #4
    def save_state(self, keep, prefix='model'):
        """Save trainer state given a possibly replicated opt_state."""
        opt_state = self._opt_state
        if self.n_devices > 1:
            first_replica = lambda x: x[0]
            opt_state = OptState(*math.nested_map(first_replica, opt_state))
        # This line, while optional, allows JAX to transfer arrays from the device
        # to the host in parallel, which is particularly important for cloud TPU.
        if math.backend_name() == 'jax':
            opt_state = jax.device_get(opt_state)
        step, history, model_state = self._step, self._history, self._model_state
        output_dir = self._output_dir

        weights_file = os.path.join(output_dir, prefix + '.pkl.gz')

        # This dict will be stored as the model.
        trainer_state_dict = make_trainer_state_dict(step, opt_state, history,
                                                     model_state,
                                                     self._input_signature)
        self._save_state_dict(trainer_state_dict, weights_file)

        if keep:
            weights_file = os.path.join(output_dir,
                                        '{}_{}.pkl.gz'.format(prefix, step))
            self._save_state_dict(trainer_state_dict, weights_file)
Example #5
    def forward_and_backward(self,
                             inputs,
                             ct,
                             state,
                             new_state,
                             rng=None,
                             **kwargs):
        assert math.backend_name() == 'jax', (
            'JAX backend is required to use forward_and_backward.')

        if ct is not None and new_state is not tl.EMPTY_STATE:
            recovered_rng = new_state
            is_same = ((rng[0] == recovered_rng[0]) &
                       (rng[1] == recovered_rng[1]))
            is_same = is_same.astype(np.float32)
            # Divides by zero if rngs are not the same, which results in NaNs.
            inputs = (inputs[0] / is_same, inputs[1] / is_same,
                      inputs[2] / is_same)

        def _do_forward(x):  # pylint: disable=invalid-name
            res, _ = self.forward_with_state(x, state=state, rng=rng, **kwargs)
            return res

        output, vjpfun = jax.vjp(_do_forward, inputs)
        return output, vjpfun(ct)[0]
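The jax.vjp call above pairs a forward pass with its vector-Jacobian product function. A minimal self-contained sketch on a toy function (unrelated to the attention layer above):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

primal_out, vjp_fn = jax.vjp(f, jnp.ones(3))
cotangent = jnp.ones(3)          # plays the role of `ct` above
(grad_wrt_input,) = vjp_fn(cotangent)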
Example #6
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
    """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
    depth = np.shape(query)[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        if math.backend_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        # JAX's `full_like` already ties in -1e9 to dots.
        dots = np.where(mask, dots, np.full_like(dots, -1e9))
    # Softmax.
    dots = np.exp(dots - math.logsumexp(dots, axis=-1, keepdims=True))
    if dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = math.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
    out = np.matmul(dots, value)
    return out
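For reference, the same core computation without mask or dropout can be written in plain NumPy (a sketch for illustration only; the example above uses the backend-dispatched `np` from Trax):

import numpy as onp

def attention_reference(query, key, value):
    depth = query.shape[-1]
    dots = onp.matmul(query, onp.swapaxes(key, -1, -2)) / onp.sqrt(depth)
    # Numerically stable softmax over the last axis.
    dots = onp.exp(dots - dots.max(axis=-1, keepdims=True))
    dots = dots / dots.sum(axis=-1, keepdims=True)
    return onp.matmul(dots, value)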
Example #7
    def forward_with_state(self,
                           inputs,
                           weights=base.EMPTY_WEIGHTS,
                           state=base.EMPTY_STATE,
                           rng=None,
                           **kwargs):
        del weights
        q, k, v = inputs
        if self._mode in ('train', 'eval'):
            mask_size = q.shape[-2]
            # Not all backends define np.tril. However, using onp.tril is inefficient
            # in that it creates a large global constant. TODO(kitaev): try to find an
            # alternative that works across all backends.
            if math.backend_name() == 'jax':
                mask = np.tril(np.ones((1, mask_size, mask_size),
                                       dtype=onp.bool_),
                               k=0)
            else:
                mask = onp.tril(onp.ones((1, mask_size, mask_size),
                                         dtype=onp.bool_),
                                k=0)
        else:
            assert self._mode == 'predict'
            state = _fast_inference_update_state(inputs, state)
            (k, v, mask, _) = state

        res = DotProductAttention(q,
                                  k,
                                  v,
                                  mask,
                                  dropout=self._dropout,
                                  mode=self._mode,
                                  rng=rng)
        return res, state
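The causal mask built in the train/eval branch above is just a lower-triangular boolean array; a small standalone sketch:

import numpy as onp

mask_size = 5
# Position i may attend to positions <= i; shape (1, mask_size, mask_size)
# so it broadcasts over the leading batch/head dimension.
mask = onp.tril(onp.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)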
Example #8
    def forward_with_state(self, inputs, weights, state, rng):
        del weights
        q, k, v = inputs

        if self._mode == 'predict':
            state = _fast_inference_update_state(inputs, state)
            (k, v, mask, _) = state
        else:
            mask_size = q.shape[-2]
            # Not all backends define jnp.tril. However, using np.tril is inefficient
            # in that it creates a large global constant. TODO(kitaev): try to find an
            # alternative that works across all backends.
            if math.backend_name() == 'jax':
                mask = jnp.tril(jnp.ones((1, mask_size, mask_size),
                                         dtype=np.bool_),
                                k=0)
            else:
                mask = np.tril(np.ones((1, mask_size, mask_size),
                                       dtype=np.bool_),
                               k=0)

        res = DotProductAttention(q,
                                  k,
                                  v,
                                  mask,
                                  dropout=self._dropout,
                                  mode=self._mode,
                                  rng=rng)
        return res, state
Example #9
 def forward_with_state(self,
                        inputs,
                        weights=base.EMPTY_WEIGHTS,
                        state=base.EMPTY_STATE,
                        rng=None,
                        **kwargs):
     if self._mode in ('train', 'eval'):
         x = inputs
         symbol_size = np.shape(x)[1]
         px = weights[:, :symbol_size, :]
         if self._dropout == 0:
             return (x + px, state)
         else:
             noise_shape = list(px.shape)
             for dim in self._dropout_broadcast_dims:
                 noise_shape[dim] = 1
             keep_prob = 1.0 - self._dropout
             if math.backend_name() == 'jax':
                 keep_prob = jax.lax.tie_in(
                     x, np.full((), keep_prob, dtype=x.dtype))
             keep = math.random.bernoulli(rng, keep_prob,
                                          tuple(noise_shape))
             multiplier = keep.astype(x.dtype) / keep_prob
             return (x + px * multiplier, state)
     else:
         assert self._mode == 'predict'
         assert self._dropout == 0
         # State in this class is only used for fast inference. In that case,
         # the model is called with consecutive elements position-by-position.
         # This positional encoding layer needs to store the index of the current
         # position then and increment it on each call -- that's how state is used
         # and updated below.
         return (inputs + np.expand_dims(weights[0, state, :], 1),
                 state + 1)
Example #10
 def _l2_norm(self, flat_list):
     """Returns the aggregate L2 norm of a list of tensors."""
     if math.backend_name() == 'jax':
         norm = np.sqrt(sum(np.vdot(x, x) for x in flat_list))
     else:  # TODO(lukaszkaiser): add vdot to TF-numpy
         norm = np.sqrt(sum(np.sum(x * x) for x in flat_list))
     return norm
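Both branches compute the same aggregate norm; a quick plain-NumPy check (illustrative, not part of the example):

import numpy as onp

flat_list = [onp.arange(3.0), onp.ones((2, 2))]
norm_vdot = onp.sqrt(sum(onp.vdot(x, x) for x in flat_list))
norm_sum = onp.sqrt(sum(onp.sum(x * x) for x in flat_list))
assert onp.allclose(norm_vdot, norm_sum)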
Example #11
  def _do_custom_gradients(self, x, weights, state, rng):
    """Calls this layer for a forward pass, but with custom gradients."""
    assert math.backend_name() == 'jax', (
        'Custom gradients are only supported in JAX for now.')

    # See this link for how custom transformations are defined in JAX:
    # https://jax.readthedocs.io/en/latest/jax.html#jax.custom_transforms
    @jax.custom_transforms
    def _do_forward(y, weights):
      old_weights, old_state, old_rng = self._weights, self._state, self._rng
      res = self.forward(y, weights)
      s = self._state
      self._weights, self._state, self._rng = old_weights, old_state, old_rng
      return res, s

    # This is the custom gradient (vector-jacobian product in JAX) function.
    # For the exact specification of this custom transformation see this link:
    # https://jax.readthedocs.io/en/latest/jax.html#jax.defjvp_all
    def do_forward_vjp(y, weights):
      """Custom gradient (vjp) function."""
      old_weights, old_state, old_rng = self._weights, self._state, self._rng
      output = self.forward(y, weights)
      new_state = self._state
      self._weights, self._state, self._rng = old_weights, old_state, old_rng
      def vjpfun(grad):
        grad = grad[0]  # Ignore dummy gradient wrt state.
        res = self.backward(y, output, grad, weights, state, new_state, rng)
        return res
      return (output, new_state), vjpfun

    jax.defvjp_all(_do_forward, do_forward_vjp)
    output, state = _do_forward(x, weights)
    state = jax.lax.stop_gradient(state)
    return output, state
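The example relies on the older jax.custom_transforms / jax.defvjp_all API, which has since been removed from JAX. In current JAX the same idea is usually expressed with jax.custom_vjp; a minimal sketch on a toy function (the function and its clipping rule are purely illustrative, not the layer's actual backward rule):

import jax
import jax.numpy as jnp

@jax.custom_vjp
def clipped_square(x):
    return jnp.square(x)

def clipped_square_fwd(x):
    # Save the input as the residual for the backward pass.
    return clipped_square(x), x

def clipped_square_bwd(residual_x, grad_out):
    # Custom gradient: clip the usual 2*x factor (illustrative choice).
    return (grad_out * jnp.clip(2.0 * residual_x, -1.0, 1.0),)

clipped_square.defvjp(clipped_square_fwd, clipped_square_bwd)

grads = jax.grad(lambda x: clipped_square(x).sum())(jnp.arange(3.0))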
Example #12
    def forward_with_state(self,
                           inputs,
                           weights=base.EMPTY_WEIGHTS,
                           state=base.EMPTY_STATE,
                           rng=None,
                           **kwargs):
        embs = []
        for ax_emb in weights:
            ax_emb = np.broadcast_to(ax_emb, (inputs.shape[0], ) +
                                     self._shape + (ax_emb.shape[-1], ))
            embs.append(ax_emb)
        emb = np.concatenate(embs, -1)

        if self._mode == 'predict':
            assert self._dropout == 0.0
            emb = np.reshape(emb, (inputs.shape[0], -1, emb.shape[-1]))
            return inputs + emb[:, state, :][:, None, :], state + 1
        elif self._dropout == 0:
            return inputs + np.reshape(emb, inputs.shape), state
        else:
            noise_shape = list(emb.shape)
            for dim in self._dropout_broadcast_dims:
                noise_shape[dim] = 1
            keep_prob = 1.0 - self._dropout
            if math.backend_name() == 'jax':
                keep_prob = jax.lax.tie_in(
                    inputs, np.full((), keep_prob, dtype=inputs.dtype))
            keep = math.random.bernoulli(rng, keep_prob, tuple(noise_shape))
            multiplier = keep.astype(inputs.dtype) / keep_prob

            return inputs + np.reshape(emb * multiplier, inputs.shape), state
Example #13
def one_hot(x, n_categories, dtype=np.float32):  # pylint: disable=invalid-name
    """Makes a one-hot array (n+1 dims) from an int-categorical array (n dims)."""
    indices_less_than_n = np.arange(n_categories)
    if math.backend_name() == 'jax':
        # Work around a jax broadcasting issue.
        indices_less_than_n = jax.lax.tie_in(x, indices_less_than_n)
    return np.array(x[..., np.newaxis] == indices_less_than_n, dtype)
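The broadcasting trick above, shown with plain NumPy on a tiny input (illustration only):

import numpy as onp

x = onp.array([0, 2, 1])
indices_less_than_n = onp.arange(3)
one_hot = (x[..., onp.newaxis] == indices_less_than_n).astype(onp.float32)
# one_hot == [[1, 0, 0], [0, 0, 1], [0, 1, 0]]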
Example #14
 def f(x):
     if n_devices > 1 and math.backend_name() == 'jax':
         return _multi_device_put(x)
     elif n_devices > 1:
         return jnp.broadcast_to(x, (n_devices, ) + x.shape)
     else:
         return x
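The non-JAX replication branch simply adds a leading device axis by broadcasting; a plain-NumPy sketch:

import numpy as onp

n_devices = 4
x = onp.ones((2, 3))
replicated = onp.broadcast_to(x, (n_devices,) + x.shape)  # shape (4, 2, 3)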
Example #15
 def forward(self, x):
     """Dropout, with broadcasting to save memory."""
     if self._mode == 'train' and self._rate > 0.0:
         noise_shape = list(x.shape)
         for dim in self._broadcast_dims:
             noise_shape[dim] = 1
         if math.backend_name() == 'jax':
             keep_prob = jax.lax.tie_in(self.rng, 1.0 - self._rate)
         else:
             keep_prob = 1.0 - self._rate
         keep = random.bernoulli(self.rng, keep_prob, tuple(noise_shape))
         if math.backend_name() == 'jax':
             keep_prob = jax.lax.tie_in(keep, keep_prob)
         multiplier = keep.astype(x.dtype) / keep_prob
         return x * multiplier
     else:
         return x
Example #16
 def forward_and_backward(self, inputs, grad, state=base.EMPTY_STATE,
                          new_state=base.EMPTY_STATE, rng=None):
   del new_state
   assert math.backend_name() == 'jax', (
       'JAX backend is required to use forward_and_backward.')
   # Simultaneous forward pass and backprop through the attention mechanism.
   def _do_forward(x):  # pylint: disable=invalid-name
     res, _ = self.forward_with_state(x, state=state, rng=rng)
     return res
   output, vjpfun = jax.vjp(_do_forward, inputs)
   return output, vjpfun(grad)[0]
Example #17
def _jax_and_tf_configure_for_devices():
    if FLAGS.use_tpu:
        jax.config.update('jax_platform_name', 'tpu')
        jax.config.update('jax_xla_backend', FLAGS.jax_xla_backend)
        jax.config.update('jax_backend_target', FLAGS.jax_backend_target)
    if FLAGS.enable_eager_execution and math.backend_name() in ('numpy',
                                                                'jax'):
        # Numpy backend doesn't benefit from having the input pipeline run on GPU,
        # and jax backend has GPU memory contention if TF uses the GPU. Gin must be
        # set up first before determining the backend.
        tf.config.experimental.set_visible_devices([], 'GPU')
Example #18
def main(_):
    logging.set_verbosity(FLAGS.log_level)

    _tf_setup_from_flags()
    _gin_parse_configs()
    _jax_and_tf_configure_for_devices()

    output_dir = _output_dir_or_default()
    if FLAGS.use_tpu and math.backend_name() == 'tf':
        _train_using_tf(output_dir)
    else:
        trainer_lib.train(output_dir=output_dir)

    trainer_lib.log('Finished training.')
Example #19
def _fast_inference_update_state(inputs, state):
    """Updates state of a causal attention layer for fast inference."""
    assert math.backend_name() == 'jax', (
        'JAX backend is required to use the predict mode.')
    for x in inputs:
        assert x.shape[1] == 1, (
            'In predict mode the input sequence must be of length 1.')
    # Fast inference: run with only 1 query in each step, storing the sequence
    # of keys and values calculated so far in state.
    (_, new_k, new_v) = inputs
    (ks, vs, mask, index) = state
    ks = jax.ops.index_update(ks, jax.ops.index[:, index, :], new_k[:, 0, :])
    vs = jax.ops.index_update(vs, jax.ops.index[:, index, :], new_v[:, 0, :])
    mask = jax.ops.index_update(mask, jax.ops.index[:, :, index], 1)
    return (ks, vs, mask, index + 1)
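jax.ops.index_update has since been removed from JAX; the same functional update is now written with the .at[...].set(...) syntax. A minimal sketch with illustrative shapes:

import jax.numpy as jnp

ks = jnp.zeros((2, 8, 4))    # (batch, max_len, d_k), illustrative shapes
new_k = jnp.ones((2, 1, 4))
index = 3
# Equivalent of jax.ops.index_update(ks, jax.ops.index[:, index, :], new_k[:, 0, :])
ks = ks.at[:, index, :].set(new_k[:, 0, :])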
Example #20
 def policy(self, trajectory):
     """Chooses an action to play after a trajectory."""
     model = self._policy_collect_model
     model.weights = self._policy_trainer.model_weights
     tr_slice = trajectory[-self._max_slice_length:]
     trajectory_np = tr_slice.to_np(timestep_to_np=self.task.timestep_to_np)
     # Add batch dimension to trajectory_np and run the model.
     pred = model(trajectory_np.observations[None, ...], n_accelerators=1)
     # Pick element 0 from the batch (the only one), last (current) timestep.
     pred = pred[0, -1, :]
     sample = self._policy_dist.sample(pred)
     result = (sample, pred)
     if math.backend_name() == 'jax':
         result = math.nested_map(lambda x: x.copy(), result)
     return result
Example #21
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
    """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head `(queries, keys)`,
    - applies `mask` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head `values` vectors.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention weights.
    mask: Mask that distinguishes positions with real content vs. padding.
    dropout: Probabilistic rate for dropout applied to attention activations
        (based on query-key pairs) before dotting them with values.
    mode: Either 'train' or 'eval'. Dropout applies only in 'train' mode.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Per-head activations resulting from masked per-head attention-weighted
    sum of per-head values.
  """
    d_feature = queries.shape[-1]
    dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        if math.backend_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        # JAX's `full_like` already ties in -1e9 to dots.
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
    # Softmax.
    dots = jnp.exp(dots - math.logsumexp(dots, axis=-1, keepdims=True))
    if dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = math.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
    out = jnp.matmul(dots, values)
    return out
Example #22
    def forward_with_state(self,
                           inputs,
                           weights=base.EMPTY_WEIGHTS,
                           state=base.EMPTY_STATE,
                           rng=None,
                           **kwargs):
        embs = []
        for ax_emb in weights:
            ax_emb = np.broadcast_to(ax_emb, (inputs.shape[0], ) +
                                     self._shape + (ax_emb.shape[-1], ))
            embs.append(ax_emb)

        if self._mode == 'predict':
            assert self._dropout == 0.0
            emb = np.concatenate(embs, -1)
            emb = np.reshape(emb, (inputs.shape[0], -1, emb.shape[-1]))
            emb = jax.lax.dynamic_slice_in_dim(emb,
                                               state,
                                               inputs.shape[1],
                                               axis=1)
            return inputs + emb, state + inputs.shape[1]
        elif self._dropout == 0:
            # TODO(kitaev): concat-then-reshape (as is the case with dropout enabled)
            # leads to memory blow-up on TPU.
            # emb = np.concatenate(embs, -1)
            # return inputs + np.reshape(emb, inputs.shape), state
            return inputs + np.concatenate([
                np.reshape(emb, inputs.shape[:-1] + (emb.shape[-1], ))
                for emb in embs
            ], -1), state
        else:
            emb = np.concatenate(embs, -1)
            noise_shape = list(emb.shape)
            for dim in self._dropout_broadcast_dims:
                noise_shape[dim] = 1
            keep_prob = 1.0 - self._dropout
            if math.backend_name() == 'jax':
                keep_prob = jax.lax.tie_in(
                    inputs, np.full((), keep_prob, dtype=inputs.dtype))
            keep = math.random.bernoulli(rng, keep_prob, tuple(noise_shape))
            multiplier = keep.astype(inputs.dtype) / keep_prob

            return inputs + np.reshape(emb * multiplier, inputs.shape), state
Example #23
def _fast_inference_update_state(inputs, state):
  """Updates state of a causal attention layer for fast inference."""
  if math.backend_name() != 'jax':
    raise ValueError(f'JAX backend is required in predict mode, but found '
                     f'backend ({math.backend_name()}).')
  for x in inputs:
    if x.shape[1] != 1:
      raise ValueError(f'In predict mode, input sequence must have length 1, '
                       f'instead has length {x.shape[1]}.')
  # Fast inference: run with only 1 query in each step, storing the sequence
  # of keys and values calculated so far in state.
  (_, new_k, new_v) = inputs
  (ks, vs, mask, seq_indices) = state
  batch_indices = jnp.arange(ks.shape[0])
  ks = jax.ops.index_update(
      ks, jax.ops.index[batch_indices, seq_indices, :], new_k[:, 0, :])
  vs = jax.ops.index_update(
      vs, jax.ops.index[batch_indices, seq_indices, :], new_v[:, 0, :])
  mask = jax.ops.index_update(
      mask, jax.ops.index[batch_indices, :, seq_indices], 1)
  return (ks, vs, mask, seq_indices + 1)
Example #24
    def save_state(self, keep):
        """Save trainer state given a possibly replicated opt_state."""
        opt_state = self._opt_state
        if self.n_devices > 1:
            first_replica = lambda x: x[0]
            opt_state = OptState(*math.nested_map(first_replica, opt_state))
        # This line, while optional, allows JAX to transfer arrays from the device
        # to the host in parallel, which is particularly important for cloud TPU.
        if math.backend_name() == 'jax':
            opt_state = jax.device_get(opt_state)
        step, history, model_state = self._step, self._history, self._model_state
        output_dir = self._output_dir

        pkl_module = utils.get_pickle_module()
        weights_file = os.path.join(output_dir, 'model.pkl')
        with tf.io.gfile.GFile(weights_file, 'wb') as f:
            pkl_module.dump((tuple(opt_state), step, history, model_state), f)
        if keep:
            weights_file = os.path.join(output_dir,
                                        'model_{}.pkl'.format(step))
            with tf.io.gfile.GFile(weights_file, 'wb') as f:
                pkl_module.dump((tuple(opt_state), step, history, model_state),
                                f)
        log('Model saved to %s' % weights_file, stdout=False)
Example #25
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.CrossEntropyLoss,
          inputs=trax_inputs.inputs,
          optimizer=trax_opt.Adafactor,
          lr_schedule=lr.MultifactorSchedule,
          trainer_class=Trainer,
          steps=1000,
          checkpoints_at=None,
          eval_steps=10,
          eval_frequency=100,
          random_seed=None,
          save_graphs=True,
          save_backward_graph=False,
          has_weights=False,
          nontrainable_param_map=None,
          id_to_mask=None,
          metrics=None,
          checkpoint_highest=None,
          checkpoint_lowest=None,
          custom_train_fn=None):
    """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fn
      and apply_fn.
    loss_fn: callable with signature: weights, trax.inputs.Inputs, model, state,
      rng -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    trainer_class: The trainer class to use.
    steps: int, total number of training steps.
    checkpoints_at: list of integers. Save a checkpoint for each training step
      in the list.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    random_seed: the random seed to use; time/os dependent if None (default).
    save_graphs: bool, if True, save computation graph to file.
    save_backward_graph: bool, if True, save backward graph to file too.
    has_weights: bool, whether weights are included in the inputs.
    nontrainable_param_map: dict, mapping from model nontrainable parameter
      names to control names in PolicySchedule.
    id_to_mask: id to mask out (None by default).
    metrics: optionally override the default metrics dictionary.
    checkpoint_highest: save the checkpoint highest at this metric.
    checkpoint_lowest: save the checkpoint lowest at this metric.
    custom_train_fn: custom train function to call, entirely bypassing this one

  Returns:
    trax.TrainerState
  """
    if custom_train_fn is not None:
        return custom_train_fn(output_dir, model=model)

    n_devices = num_devices()
    # TODO(lukaszkaiser): remove has_weights and id_to_mask (configure loss).
    trainer = trainer_class(model,
                            loss_fn,
                            optimizer,
                            lr_schedule,
                            inputs,
                            output_dir,
                            random_seed=random_seed,
                            n_devices=n_devices,
                            checkpoints_at=checkpoints_at,
                            has_weights=has_weights,
                            nontrainable_param_map=nontrainable_param_map,
                            metrics=metrics,
                            id_to_mask=id_to_mask,
                            checkpoint_lowest=checkpoint_lowest,
                            checkpoint_highest=checkpoint_highest)

    epoch_steps = [steps]  # Only training if eval_frequency is 0 or None
    if eval_frequency and eval_steps > 0:
        epoch_steps = itertools.chain(
            [
                1,  # first epoch only 1 step
                eval_frequency - 1
            ],
            itertools.repeat(eval_frequency))
    trainer.log_step('Starting training using %d devices' % trainer.n_devices)
    trainer.print_n_weights()

    try:
        for epoch_steps in epochs(steps, trainer.step, epoch_steps):
            trainer.train_epoch(epoch_steps, eval_steps)

            # Update nontrainable parameters with new history
            trainer.update_nontrainable_params()

            # Bookkeeping we do at the first step
            if trainer.step == 1:
                # Save computation graph (single-device only for now)
                if (save_graphs and math.backend_name() == 'jax'):
                    trainer.save_computation_graphs(save_backward_graph)

                # Save Gin config
                trainer.save_gin()

        trainer.log_step('Training done')
    except Exception as e:
        raise e
    finally:
        trainer.close()
    return trainer.state
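The epoch schedule built above interleaves evaluation with training: a 1-step first epoch, then the remainder of the first eval period, then full eval_frequency-sized epochs. A small sketch of the resulting sequence:

import itertools

eval_frequency = 100
epoch_steps = itertools.chain([1, eval_frequency - 1],
                              itertools.repeat(eval_frequency))
print(list(itertools.islice(epoch_steps, 4)))  # [1, 99, 100, 100]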
Example #26
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 checkpoints_at=None,
                 should_save_checkpoints=True,
                 should_write_summaries=True,
                 metrics=None,
                 checkpoint_highest=None,
                 checkpoint_lowest=None):

        self._is_chief, self._n_devices, rng = (self._init_host_and_devices(
            n_devices, random_seed))
        self._should_save_checkpoints = should_save_checkpoints and self._is_chief
        self._checkpoints_at = checkpoints_at or []
        self._should_write_summaries = should_write_summaries
        if not output_dir:
            self._should_save_checkpoints = False
            self._should_write_summaries = False
        self._checkpoint_highest = checkpoint_highest
        self._checkpoint_lowest = checkpoint_lowest
        if metrics is not None:
            self._metrics_dict = metrics
        else:
            self._metrics_dict = _DEFAULT_METRICS
            self._metrics_dict['loss'] = loss_fn
        # Inputs is either an Inputs instance or a function that returns it.
        self._inputs = inputs
        if callable(inputs):  # If we pass a function, e.g., through gin, call it.
            self._inputs = inputs()
        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Setup the model.
        model_train = model(mode='train')
        model_predict_eval = model(mode='eval')
        self._model_with_loss = tl.Serial(model_train, loss_fn)

        # Setup state.
        rng, init_rng = jax_random.split(rng)
        self._rngs = np.stack(jax_random.split(rng, self._n_devices))
        shapes, dtypes = self._inputs.example_shape_dtype
        input_signature = tuple(
            ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))

        def new_opt_state_and_model_state(rng):
            """Returns optimizer and model states suitable for training a model."""
            weights, state = self._model_with_loss.init(input_signature,
                                                        rng=rng)
            (slots, opt_params) = opt.tree_init(weights)
            return (OptState(weights, slots, opt_params), state)

        if math.backend_name() == 'jax':
            # JIT parameter initialization to avoid memory fragmentation
            new_opt_state_and_model_state = math.jit(
                new_opt_state_and_model_state)
        self._new_opt_state_and_model_state = (
            lambda: new_opt_state_and_model_state(init_rng))

        # Arrange and initialize metrics layers.
        self._metrics = list(sorted(self._metrics_dict.keys()))
        metrics_layers = [self._metrics_dict[m] for m in self._metrics]
        metrics_in_parallel = tl.Branch(*metrics_layers)
        metrics_in_parallel.rng = init_rng
        example_signature = tuple(
            ShapeDtype(s, d)
            for (s, d) in zip(*self._inputs.example_shape_dtype))
        model_predict_eval.init(example_signature)
        self._input_signature = example_signature
        output_signature = model_predict_eval.output_signature(
            example_signature)
        m_weights, m_state = metrics_in_parallel.init(output_signature)
        self._metrics_weights = self._for_n_devices(m_weights)
        self._metrics_state = self._for_n_devices(m_state)

        # Jit model_predict and update so they're fast.
        self._jit_eval = _jit_predict_fn(model_predict_eval,
                                         metrics_in_parallel, self._n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             self._n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        # TODO(pkozakowski): "Learning rate schedules" are currently able to
        # control all optimizer parameters and model state, so let's rename them
        # accordingly.
        self._lr_schedule = lr_schedule

        # Those fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._lr_fn = None
        self._opt_state = None
        self._step = None
        self._model_state = None
        self.reset(output_dir)
Example #27
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train'):
    """Reversible transformer encoder-decoder model.

  This model expects an input pair: target, source.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a target, source pair to
    activations over a vocab set.
  """
    # The current API for custom gradients assumes that a layer must be
    # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
    # masks on the stack. This causes jax to error, even though the so-called
    # "gradient" wrt the masks is never actually computed.
    # TODO(kitaev): remove this hack.
    if math.backend_name() == 'jax':
        jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

    def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
        # TODO(kitaev): axial positional encoding is better for very long sequences.
        positional_encoding = tl.PositionalEncoding(max_len=max_len,
                                                    dropout=dropout,
                                                    mode=mode)
        return [
            tl.Embedding(d_model, vocab_size),
            BroadcastedDropout(rate=dropout, mode=mode),
            positional_encoding,
        ]

    # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
    # position embeddings between the encoder and decoder if output_vocab_size is
    # None. This isn't supported here because (a) Trax shares weights by sharing
    # layer instances, but we need two separate instances to have mode == 'eval'
    # for the encoder but mode == 'predict' for the decoder; and (b) tl.Cache does
    # not work if its sublayers participate in any weight sharing.

    # Mode 'predict' means that the decoder should be run one token at a time.
    # The encoder only ever runs over full sequences, which is why it's switched
    # to 'eval' mode instead.
    in_encoder = PositionalEncoder(input_vocab_size,
                                   mode='eval' if mode == 'predict' else mode)
    if output_vocab_size is None:
        output_vocab_size = input_vocab_size
    out_encoder = PositionalEncoder(output_vocab_size, mode)

    # pylint: disable=g-complex-comprehension
    encoder_blocks = [
        EncoderBlock(d_model, d_ff, n_heads, tl.SelfAttention, dropout,
                     ff_activation, ff_dropout, mode)
        for _ in range(n_encoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    encoder = tl.Serial([
        in_encoder,
        tl.Dup(),
        tl.ReversibleSerial(encoder_blocks),
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
        tl.LayerNorm(),
    ])
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    encoder_decoder_blocks = [
        EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation,
                            ff_dropout, mode) for _ in range(n_decoder_layers)
    ]

    # Assemble and return the model.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 1, 1]),  # tok_e tok_d tok_d
        tl.Branch([], [
            tl.PaddingMask(),
            tl.Fn('Squeeze', lambda x: np.squeeze(x, (1, 2)), n_out=1)
        ]),
        #                                     # tok_e mask  tok_d .....

        # Encode.
        encoder,  # vec_e  mask tok_d .....

        # Decode.
        tl.Select([2, 0, 1]),  # tok_d vec_e mask .....
        tl.ShiftRight(mode=mode),  # tok_d vec_e mask .....
        out_encoder,  # vec_d vec_e mask .....
        tl.Dup(),  # vec_d1 vec_d2 vec_e mask .....
        tl.ReversibleSerial(encoder_decoder_blocks),
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_d vec_e mask .....
        tl.LayerNorm(),  # vec_d vec_e mask .....

        # Map to output vocab.
        tl.Select([0], n_in=3),  # vec_d .....
        tl.Dense(output_vocab_size),  # vec_d .....
        tl.LogSoftmax(),  # vec_d .....
    )
Example #28
def ReformerNoEncDecAttention(input_vocab_size,
                              output_vocab_size=None,
                              d_model=512,
                              d_ff=2048,
                              d_attention_key=64,
                              d_attention_value=64,
                              n_encoder_layers=6,
                              n_decoder_layers=6,
                              n_heads=8,
                              dropout=0.1,
                              max_len=2048,
                              encoder_attention_type=tl.SelfAttention,
                              encoder_decoder_attention_type=tl.SelfAttention,
                              axial_pos_shape=(),
                              d_axial_pos_embs=None,
                              ff_activation=tl.Relu,
                              ff_use_sru=0,
                              ff_chunk_size=0,
                              ff_dropout=None,
                              mode='train'):
    """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use, such as SelfAttention
    encoder_decoder_attention_type: class: attention class to use, such as
      SelfAttention
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a target, source pair to
    activations over a vocab set.
  """
    # The current API for custom gradients assumes that a layer must be
    # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
    # masks on the stack. This causes jax to error, even though the so-called
    # "gradient" wrt the masks is never actually computed.
    # TODO(kitaev): remove this hack.
    if math.backend_name() == 'jax':
        jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

    def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
        if not axial_pos_shape:
            positional_encoding = tl.PositionalEncoding(max_len=max_len,
                                                        dropout=dropout,
                                                        mode=mode)
        else:
            assert d_axial_pos_embs is not None
            positional_encoding = tl.AxialPositionalEncoding(
                shape=axial_pos_shape,
                d_embs=d_axial_pos_embs,
                dropout_broadcast_dims=tuple(range(1,
                                                   len(axial_pos_shape) + 1)),
                dropout=dropout,
                mode=mode)

        return [
            tl.Embedding(d_model, vocab_size),
            BroadcastedDropout(rate=dropout, mode=mode),
            positional_encoding,
        ]

    # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
    # position embeddings between the encoder and decoder if output_vocab_size is
    # None. This isn't supported here because (a) Trax shares weights by sharing
    # layer instances, but we need two separate instances to have mode == 'eval'
    # for the encoder but mode == 'predict' for the decoder; and (b) tl.Cache does
    # not work if its sublayers participate in any weight sharing.

    # Mode 'predict' means that the decoder should be run one token at a time.
    # The encoder only ever runs over full sequences, which is why it's switched
    # to 'eval' mode instead.
    in_encoder = PositionalEncoder(input_vocab_size,
                                   mode='eval' if mode == 'predict' else mode)
    if output_vocab_size is None:
        output_vocab_size = input_vocab_size
    out_encoder = PositionalEncoder(output_vocab_size, mode)

    # pylint: disable=g-complex-comprehension
    encoder_blocks = [
        EncoderBlock(d_model, d_ff, n_heads, encoder_attention_type, dropout,
                     ff_activation, ff_dropout, mode)
        for _ in range(n_encoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    encoder = tl.Serial([  # tok_e mask_e tok_e tok_d tok_d
        in_encoder,  # vec_e mask_e tok_e tok_d tok_d
        tl.Dup(),  # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
        tl.ReversibleSerial(encoder_blocks),
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
        tl.LayerNorm(),
    ])
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    decoder_blocks = []

    if isinstance(encoder_decoder_attention_type, (tuple, list)):
        assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
    else:
        encoder_decoder_attention_type = [encoder_decoder_attention_type]
    for layer_idx in range(n_decoder_layers):
        layer_attention_type = encoder_decoder_attention_type[
            layer_idx % len(encoder_decoder_attention_type)]
        decoder_block = DecoderBlock(d_model,
                                     d_ff,
                                     d_attention_key,
                                     d_attention_value,
                                     n_heads,
                                     attention_type=layer_attention_type,
                                     dropout=dropout,
                                     ff_activation=ff_activation,
                                     ff_use_sru=ff_use_sru,
                                     ff_chunk_size=ff_chunk_size,
                                     mode=mode)
        decoder_blocks.append(decoder_block)

    # Assemble and return the model.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 0, 1, 1]),  # tok_e tok_e tok_d tok_d
        tl.Branch([], [
            tl.PaddingMask(),
            tl.Fn('Squeeze', lambda x: np.squeeze(x, (1, 2)), n_out=1)
        ]),
        #                                         # tok_e mask_e tok_e tok_d tok_d

        # Encode.
        encoder,  # vec_e mask_e tok_e tok_d tok_d

        # Decode.
        tl.Select([3, 0, 1, 2]),  #  tok_d vec_e mask_e tok_e tok_d
        tl.ShiftRight(mode=mode),  # stok_d vec_e mask_e tok_e tok_d
        tl.Branch([], _MaskOfRightShiftedArray()),  # stok_d mask_d vec_e mask_e tok_e tok_d
        out_encoder,  # svec_d mask_d vec_e mask_e tok_e tok_d

        # Concat encoder and decoder, given their masks.
        tl.Select([2, 0, 3, 1]),  # svec_d mask_d vec_e mask_e tok_e tok_d
        _ConcatWithPadding(),  # vec_ed tok_e tok_d

        # Run (encoder and) decoder blocks.
        tl.Dup(),  # vec_ed1 vec_ed2 tok_e tok_d
        tl.ReversibleSerial(decoder_blocks),  # vec_ed1 vec_ed2 tok_e tok_d
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_ed tok_e tok_d
        tl.LayerNorm(),  # vec_ed tok_e tok_d

        # Separate out the encoder part from the concatenated vector.
        tl.Select([0, 1, 2, 2]),  # vec_ed tok_e tok_d tok_d
        _StripFromConcatenateWithPadding(),  # vec_d tok_d

        # Map to output vocab.
        tl.Dense(output_vocab_size),  # vec_d tok_d
        tl.LogSoftmax(),  # vec_d tok_d
    )
Example #29
def _is_jit_init(value=None):
    if value is None:
        value = math.backend_name() == 'jax'
    return value