Example #1
    def forward(self, x):
        """Executes this layer as part of a forward pass through the model.

    Args:
      x: Tensor of activations.

    Returns:
      Tensor of same shape and dtype as the input.
    """
        if self._mode != 'train':
            return x
        state, rng = self.state, self.rng
        rate = self._initial_rate
        if isinstance(state, dict) and self._name in state:
            rate = state[self._name]
        mask_shape = list(x.shape)
        for axis in self._shared_axes:
            mask_shape[axis] = 1
        if fastmath.backend_name() == 'jax':
            keep_prob = jax.lax.tie_in(self.rng, 1.0 - rate)
        else:
            keep_prob = 1.0 - rate
        keep = fastmath.random.bernoulli(rng, keep_prob, tuple(mask_shape))
        if fastmath.backend_name() == 'jax':
            keep_prob = jax.lax.tie_in(keep, keep_prob)
        mask = keep.astype(x.dtype) / keep_prob
        return x * mask
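The scaling above is standard inverted dropout: surviving activations are divided by keep_prob so the output's expected value matches the input's. A minimal NumPy sketch of the same idea (standalone, not Trax code):

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4)).astype(np.float32)
rate = 0.25
keep_prob = 1.0 - rate
keep = rng.random(x.shape) < keep_prob  # Bernoulli(keep_prob) mask
y = x * keep / keep_prob                # scale survivors so E[y] == E[x]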
Example #2
  def _init_host_and_devices(self, n_devices=None, random_seed=None):
    """Initializes host and device attributes for this trainer.

    Args:
      n_devices: Number of devices this trainer will use. If `None`, get the
          number from the backend.
      random_seed: Random seed as the starting point for all random numbers used
          by the trainer. If `None`, calculate one from system time and host id.

    Returns:
      is_chief: True if this trainer has special chief responsibilities.
      n_devices: The passed-in value of n_devices or a computed default.
      rng: A random number generator, obtained by seeding all RNGs from
          random_seed (or from a computed default seed).
    """
    if fastmath.backend_name() == 'jax':
      host_id = jax.host_id()
      host_count = jax.host_count()
    else:
      host_id = 0
      host_count = 1
    is_chief = (host_id == 0)

    device_count = fastmath.device_count()
    n_devices = n_devices or device_count
    # TODO(lukaszkaiser): remove this restriction when possible.
    if n_devices != device_count and fastmath.backend_name() == 'jax':
      raise ValueError('JAX cannot work yet with n_devices != all devices: '
                       '%d != %d' % (n_devices, device_count))

    if random_seed is None and host_count > 1:
      random_seed = int(1e6 * (host_id + time.time())) % 2**32
    return is_chief, n_devices, init_random_number_generators(random_seed)
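The multi-host branch derives a seed by mixing the host id into the current time and folding the result into 32 bits, so distinct hosts get distinct, valid seeds; a standalone sketch of that derivation:

import time

host_id = 3  # hypothetical host index
random_seed = int(1e6 * (host_id + time.time())) % 2**32  # fits in a uint32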
Example #3
 def policy(self, trajectory, temperature=1):
   """Chooses an action to play after a trajectory."""
   tr_slice = trajectory[-self._max_slice_length:]
   trajectory_np = tr_slice.to_np(timestep_to_np=self.task.timestep_to_np)
   # Add batch dimension to trajectory_np and run the model.
   obs = trajectory_np.observations[None, ...]
   values = self._run_value_model(obs, use_eval_model=False)
   # We insist that values and observations have shape
   # (batch, length, ...), where length is the number of successive
   # observations on a given trajectory.
   assert values.shape[:1] == obs.shape[:1]
   # Select the last element of the batch and the value
   # corresponding to the last (current) observation.
   values = values[0, -1, :]
   # temperature == 0 is used elsewhere to trigger evaluation.
   if np.random.random_sample() < self._exploration_rate(self._epoch) and \
       temperature == 1:
     sample = np.array(self.task.action_space.sample())
   else:
     # This is our way of taking the argmax.
     sample = jnp.argmax(values)
   result = (sample, values)
   if fastmath.backend_name() == 'jax':
     result = fastmath.nested_map(lambda x: x.copy(), result)
   return result
Example #4
def mean_or_pmean(n_devices, x, axis=None):
  """jnp.mean or pmean.

  `x` is a distributed value. Directly calling jnp.mean on `x` means stacking
  x's components together to form a large array and then doing jnp.mean on
  it. In TF, stacking `x` will introduce D2H copy, so we use a collective
  (pmean) here instead of directly calling jnp.mean for TF.

  Args:
    n_devices: number of devices.
    x: a distributed array.
    axis: the axis to reduce. Can only be 0 or None.

  Returns:
    A local array.
  """
  if fastmath.backend_name() == 'tensorflow-numpy' and n_devices > 1:
    if axis not in (None, 0):
      raise ValueError('axis can only be None or 0')
    x = fastmath.pmap(fastmath.psum)(x)[0] / n_devices
    if axis is None:
      x = jnp.mean(x)
    return x
  else:
    return jnp.mean(x, axis=axis)
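For intuition, the collective path computes the mean as a per-shard sum divided by the device count. A NumPy sketch of that identity, with the device axis simulated as the leading axis (no real devices involved):

import numpy as np

n_devices = 4
x = np.arange(8.0).reshape(n_devices, 2)  # pretend shards of a distributed value
psum = x.sum(axis=0)                      # what the psum collective produces
assert np.allclose(psum / n_devices, x.mean(axis=0))  # pmean == psum / n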
Example #5
def one_hot(x, n_categories, dtype=jnp.float32):  # pylint: disable=invalid-name
    """Makes a one-hot array (n+1 dims) from an int-categorical array (n dims)."""
    indices_less_than_n = jnp.arange(n_categories)
    if fastmath.backend_name() == 'jax':
        # Work around a jax broadcasting issue.
        indices_less_than_n = jax.lax.tie_in(x, indices_less_than_n)
    return jnp.array(x[..., jnp.newaxis] == indices_less_than_n, dtype)
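A quick usage check, using plain NumPy to mirror the comparison-and-cast trick (output shown for a 3-category example):

import numpy as np

x = np.array([0, 2, 1])
one_hot_x = (x[..., None] == np.arange(3)).astype(np.float32)
# [[1., 0., 0.],
#  [0., 0., 1.],
#  [0., 1., 0.]]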
Example #6
 def _l2_norm(self, flat_list):
     """Returns the aggregate L2 norm of a list of tensors."""
     if fastmath.backend_name() == 'jax':
         norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in flat_list))
     else:  # TODO(lukaszkaiser): add vdot to TF-numpy
         norm = jnp.sqrt(sum(jnp.sum(x * x) for x in flat_list))
     return norm
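Both branches compute the same quantity for real-valued tensors, since vdot flattens its arguments and vdot(x, x) equals sum(x * x); a NumPy check of that equivalence:

import numpy as np

flat_list = [np.ones((2, 3)), np.arange(3.0)]
norm_vdot = np.sqrt(sum(np.vdot(x, x) for x in flat_list))
norm_sum = np.sqrt(sum(np.sum(x * x) for x in flat_list))
assert np.isclose(norm_vdot, norm_sum)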
Example #7
 def f(x):
     if n_devices > 1 and fastmath.backend_name() == 'jax':
         return _multi_device_put(x)
     elif n_devices > 1:
         return jnp.broadcast_to(x, (n_devices, ) + x.shape)
     else:
         return x
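The non-JAX multi-device branch replicates x by prepending a device axis, one identical copy per device; a NumPy sketch of that broadcast:

import numpy as np

n_devices = 2
x = np.arange(3.0)
replicated = np.broadcast_to(x, (n_devices,) + x.shape)
assert replicated.shape == (2, 3)
assert (replicated[0] == replicated[1]).all()  # same copy on each 'device'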
Example #8
    def save_state(self, keep, prefix='model'):
        """Save trainer state given a possibly replicated opt_state."""
        opt_state = self._opt_state
        if self.n_devices > 1:
            first_replica = lambda x: x[0]
            opt_state = OptState(
                *fastmath.nested_map(first_replica, opt_state))
        # Fetching the state with device_get, while optional, lets JAX
        # transfer arrays from the device to the host in parallel, which is
        # particularly important for Cloud TPU.
        if fastmath.backend_name() == 'jax':
            opt_state = jax.device_get(opt_state)
        step, history, model_state = self._step, self._history, self._model_state
        output_dir = self._output_dir

        weights_file = os.path.join(output_dir, prefix + '.pkl.gz')

        # This dict will be stored as the model.
        trainer_state_dict = make_trainer_state_dict(step, opt_state, history,
                                                     model_state,
                                                     self._input_signature)
        self._save_state_dict(trainer_state_dict, weights_file)

        if keep:
            weights_file = os.path.join(output_dir,
                                        '{}_{}.pkl.gz'.format(prefix, step))
            self._save_state_dict(trainer_state_dict, weights_file)
Example #9
    def forward(self, inputs):
        q, k, v = inputs

        if self._mode == 'predict':
            self.state = _fast_inference_update_state(inputs, self.state)
            (k, v, mask, _) = self.state
        else:
            mask_size = q.shape[-2]
            # Not all backends define jnp.tril. However, using np.tril is inefficient
            # in that it creates a large global constant. TODO(kitaev): try to find an
            # alternative that works across all backends.
            if fastmath.backend_name() == 'jax':
                mask = jnp.tril(jnp.ones((1, mask_size, mask_size),
                                         dtype=np.bool_),
                                k=0)
            else:
                mask = np.tril(np.ones((1, mask_size, mask_size),
                                       dtype=np.bool_),
                               k=0)

        res, dots = DotProductAttention(q,
                                        k,
                                        v,
                                        mask,
                                        dropout=self._dropout,
                                        mode=self._mode,
                                        rng=self.rng)
        if self._mode == 'viz':
            self.state = dots
        return res
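The tril call builds a lower-triangular causal mask, so position i may attend only to positions j <= i. For a length-4 sequence the mask looks like this:

import numpy as np

mask = np.tril(np.ones((1, 4, 4), dtype=np.bool_), k=0)
# mask[0]:
# [[ True, False, False, False],
#  [ True,  True, False, False],
#  [ True,  True,  True, False],
#  [ True,  True,  True,  True]]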
Example #10
def _jax_and_tf_configure_for_devices():  # pylint: disable=missing-function-docstring
    if FLAGS.use_tpu:
        jax.config.update('jax_platform_name', 'tpu')
        jax.config.update('jax_xla_backend', FLAGS.jax_xla_backend)
        jax.config.update('jax_backend_target', FLAGS.jax_backend_target)
    if (FLAGS.enable_eager_execution
            and fastmath.backend_name() in ('numpy', 'jax')):
        # The numpy backend doesn't benefit from having the input pipeline run
        # on GPU, and the jax backend has GPU memory contention if TF uses the
        # GPU. Gin must be set up before the backend is determined.
        tf.config.experimental.set_visible_devices([], 'GPU')
Example #11
def main(_):
    logging.set_verbosity(FLAGS.log_level)

    _tf_setup_from_flags()
    _gin_parse_configs()
    _jax_and_tf_configure_for_devices()

    output_dir = _output_dir_or_default()
    if FLAGS.use_tpu and fastmath.backend_name() == 'tf':
        _train_using_tf(output_dir)
    else:
        trainer_lib.train(output_dir=output_dir)

    trainer_lib.log('Finished training.')
Example #12
 def policy(self, trajectory, temperature=1.0):
   """Chooses an action to play after a trajectory."""
   model = self._policy_collect_model
   if temperature != 1.0:  # When evaluating (temperature != 1.0), don't collect stats.
     model = self._policy_eval_model
     model.state = self._policy_collect_model.state
   model.replicate_weights(self._policy_trainer.model_weights)
   tr_slice = trajectory[-self._max_slice_length:]
   trajectory_np = tr_slice.to_np(timestep_to_np=self.task.timestep_to_np)
   # Add batch dimension to trajectory_np and run the model.
   pred = model(trajectory_np.observations[None, ...])
   # Pick element 0 from the batch (the only one), last (current) timestep.
   pred = pred[0, -1, :]
   sample = self._policy_dist.sample(pred, temperature=temperature)
   result = (sample, pred)
   if fastmath.backend_name() == 'jax':
     result = fastmath.nested_map(lambda x: x.copy(), result)
   return result
Example #13
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
    """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head `(queries, keys)`,
    - applies `mask` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head `values` vectors.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention weights.
    mask: Mask that distinguishes positions with real content vs. padding.
    dropout: Probababilistic rate for dropout applied to attention activations
        (based on query-key pairs) before dotting them with values.
    mode: Either 'train' or eval'. Dropout applies only in 'train' mode.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Per-head activations resulting from masked per-head attention-weighted
    sum of per-head values.
  """
    d_feature = queries.shape[-1]
    dots = jnp.matmul(queries, jnp.swapaxes(keys, -1,
                                            -2)) / jnp.sqrt(d_feature)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        if fastmath.backend_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        # JAX's `full_like` already ties in -1e9 to dots.
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
    # Softmax.
    dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
    out = jnp.matmul(dots, values)
    return out, dots
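A shape-level usage sketch with hypothetical sizes (this assumes the Trax/JAX environment these snippets come from, including a JAX version that still provides jax.lax.tie_in):

import jax
import jax.numpy as jnp

batch_heads, length, d = 8, 8, 16  # hypothetical per-head batch and sizes
q = k = v = jnp.ones((batch_heads, length, d))
mask = jnp.tril(jnp.ones((1, length, length), dtype=jnp.bool_))
out, dots = DotProductAttention(q, k, v, mask, dropout=0.0, mode='eval',
                                rng=jax.random.PRNGKey(0))
# out: (8, 8, 16) activations; dots: (8, 8, 8) attention weights.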
Example #14
    def forward(self, inputs):
        if self._mode != 'predict':
            x = inputs
            symbol_size = jnp.shape(x)[1]
            px = self.weights[:, :symbol_size, :]
            if self._dropout == 0:
                return x + px
            else:
                noise_shape = list(px.shape)
                for dim in self._dropout_broadcast_dims:
                    noise_shape[dim] = 1
                keep_prob = 1.0 - self._dropout
                if fastmath.backend_name() == 'jax':
                    keep_prob = jax.lax.tie_in(
                        x, jnp.full((), keep_prob, dtype=x.dtype))
                keep = fastmath.random.bernoulli(self.rng, keep_prob,
                                                 tuple(noise_shape))
                multiplier = keep.astype(x.dtype) / keep_prob
                return x + px * multiplier
        else:
            if self._dropout != 0:
                raise ValueError(f'In predict mode, but dropout rate '
                                 f'({self._dropout}) is not zero.')

            # State in this class is only used for fast inference. In that case,
            # the model is called with consecutive elements position-by-position.
            # This positional encoding layer needs to store the index of the current
            # position then and increment it on each call -- that's how state is used
            # and updated below.
            state = self.state
            if inputs.shape[1] == 1:
                self.state = state + 1
                return inputs + jnp.expand_dims(self.weights[0, state, :], 1)
            else:
                emb = []
                for i in range(inputs.shape[0]):
                    emb.append(
                        jax.lax.dynamic_slice_in_dim(self.weights[0],
                                                     state[i],
                                                     inputs.shape[1],
                                                     axis=0))
                self.state = state + inputs.shape[1]
                return inputs + jnp.stack(emb, 0)
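For reference, jax.lax.dynamic_slice_in_dim(operand, start_index, slice_size, axis) takes a fixed-size slice starting at a possibly traced index, which is what lets the slice position depend on the layer's state; a tiny check:

import jax
import jax.numpy as jnp

x = jnp.arange(10)
assert (jax.lax.dynamic_slice_in_dim(x, 3, 4, axis=0) == jnp.arange(3, 7)).all()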
Example #15
def mean_or_pmean(n_devices, x, axis=None):
  """Computes the mean of a distributed value ``x``.

  Args:
    n_devices: Number of devices.
    x: Distributed array.
    axis: Axis along which to compute means; can only be ``0`` or ``None``.

  Returns:
    A local array.
  """
  if fastmath.backend_name() == 'tensorflow-numpy' and n_devices > 1:
    if axis not in (None, 0):
      raise ValueError('axis can only be None or 0')
    x = fastmath.pmap(fastmath.psum)(x)[0] / n_devices
    if axis is None:
      x = jnp.mean(x)
    return x
  else:
    return jnp.mean(x, axis=axis)
Example #16
def _fast_inference_update_state(inputs, state):
  """Updates state of a causal attention layer for fast inference."""
  if fastmath.backend_name() != 'jax':
    raise ValueError(f'JAX backend is required in predict mode, but found '
                     f'backend ({fastmath.backend_name()}).')
  for x in inputs:
    if x.shape[1] != 1:
      raise ValueError(f'In predict mode, input sequence must have length 1, '
                       f'instead has length {x.shape[1]}.')
  # Fast inference: run with only 1 query in each step, storing the sequence
  # of keys and values calculated so far in state.
  (_, new_k, new_v) = inputs
  (ks, vs, mask, seq_indices) = state
  batch_indices = jnp.arange(ks.shape[0])
  ks = jax.ops.index_update(
      ks, jax.ops.index[batch_indices, seq_indices, :], new_k[:, 0, :])
  vs = jax.ops.index_update(
      vs, jax.ops.index[batch_indices, seq_indices, :], new_v[:, 0, :])
  mask = jax.ops.index_update(
      mask, jax.ops.index[batch_indices, :, seq_indices], 1)
  return (ks, vs, mask, seq_indices + 1)
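jax.ops.index_update was the scatter-update API in the JAX versions this snippet targets; current JAX removed it in favor of the .at[...] property. A sketch of the equivalent key-cache update in modern JAX (hypothetical shapes):

import jax.numpy as jnp

ks = jnp.zeros((2, 5, 3))        # [batch, max_length, d_kv]
new_k = jnp.ones((2, 1, 3))      # one new key per batch element
batch_indices = jnp.arange(2)
seq_indices = jnp.array([0, 2])  # current write position per example
ks = ks.at[batch_indices, seq_indices, :].set(new_k[:, 0, :])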
Example #17
    def forward(self, inputs):
        rng, state = self.rng, self.state
        embs = []
        for ax_emb in self.weights:
            ax_emb = jnp.broadcast_to(ax_emb, (inputs.shape[0], ) +
                                      self._shape + (ax_emb.shape[-1], ))
            embs.append(ax_emb)

        if self._mode == 'predict':
            assert self._dropout == 0.0
            emb = jnp.concatenate(embs, -1)
            emb = jnp.reshape(emb, (inputs.shape[0], -1, emb.shape[-1]))
            emb = jax.lax.dynamic_slice_in_dim(emb,
                                               state,
                                               inputs.shape[1],
                                               axis=1)
            self.state = state + inputs.shape[1]
            return inputs + emb
        elif self._dropout == 0:
            # TODO(kitaev): concat-then-reshape (as is the case with dropout enabled)
            # leads to memory blow-up on TPU.
            # emb = jnp.concatenate(embs, -1)
            # return inputs + jnp.reshape(emb, inputs.shape), state
            return inputs + jnp.concatenate([
                jnp.reshape(emb, inputs.shape[:-1] + (emb.shape[-1], ))
                for emb in embs
            ], -1)
        else:
            emb = jnp.concatenate(embs, -1)
            noise_shape = list(emb.shape)
            for dim in self._dropout_broadcast_dims:
                noise_shape[dim] = 1
            keep_prob = 1.0 - self._dropout
            if fastmath.backend_name() == 'jax':
                keep_prob = jax.lax.tie_in(
                    inputs, jnp.full((), keep_prob, dtype=inputs.dtype))
            keep = fastmath.random.bernoulli(rng, keep_prob,
                                             tuple(noise_shape))
            multiplier = keep.astype(inputs.dtype) / keep_prob
            return inputs + jnp.reshape(emb * multiplier, inputs.shape)
Example #18
def _fast_inference_update_state(inputs, state):
    """Updates state of a causal attention layer for fast inference.

  The layer state stores tensors with cached values of keys and values,
  as well as the mask and an index. To make shapes static, keys and values
  in the state are long, and the index indicates where the new keys and values
  from inputs need to be appended. Mask ensures that attention will only look
  at keys upto index.

  During update, we append new_keys and new_values to keys and values at
  position given by index. We also update mask (which starts as all-0s) to
  be 1 at the new keys positions. And we increment index by length of new keys.

  Args:
    inputs: a triple (new_queries, new_keys, new_values)
    state: layer state with (keys, values, mask, index)

  Returns:
    Updated state.
  """
    if fastmath.backend_name() != 'jax':
        raise ValueError(f'JAX backend is required in predict mode, but found '
                         f'backend ({fastmath.backend_name()}).')

    # Fast inference: run step-by-step, storing the sequence
    # of keys and values calculated so far in state.
    (_, new_k, new_v) = inputs
    length = new_k.shape[1]
    (ks, vs, mask, idx) = state
    # TODO(lukaszkaiser): benchmark speed and decide if using a separate code path
    # with index_update when length == 1 is worth it.
    # Keys and values are of shape [batch_size, length, d_kv].
    ks = jax.lax.dynamic_update_slice_in_dim(ks, new_k, idx, axis=1)
    vs = jax.lax.dynamic_update_slice_in_dim(vs, new_v, idx, axis=1)
    # Mask is of shape [batch_size, 1 (for heads), length].
    new_mask = jnp.ones((mask.shape[0], mask.shape[1], length))
    mask = jax.lax.dynamic_update_slice_in_dim(mask, new_mask, idx, axis=2)
    return (ks, vs, mask, idx + length)
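jax.lax.dynamic_update_slice_in_dim writes a block into a larger array at a possibly traced offset, which is exactly the cache-append pattern above; a tiny standalone check:

import jax
import jax.numpy as jnp

cache = jnp.zeros((1, 6, 2))  # [batch, max_length, d_kv]
new = jnp.ones((1, 2, 2))     # two new key vectors
cache = jax.lax.dynamic_update_slice_in_dim(cache, new, 3, axis=1)
# Positions 3 and 4 along axis 1 are now ones; the rest remain zero.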
Example #19
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.CrossEntropyLoss(),
          inputs=trax_inputs.batcher,
          optimizer=trax_opt.Adafactor,
          lr_schedule=lr.multifactor(),
          trainer_class=Trainer,
          steps=1000,
          checkpoints_at=None,
          eval_steps=10,
          eval_frequency=100,
          random_seed=None,
          save_graphs=True,
          metrics=None,
          checkpoint_highest=None,
          checkpoint_lowest=None,
          custom_train_fn=None):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fn
      and apply_fn.
    loss_fn: callable with signature: weights, trax.inputs.Inputs, model, state,
      rng -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    trainer_class: The trainer class to use.
    steps: int, total number of training steps.
    checkpoints_at: list of integers. Save a checkpoint for each training step
      in the list.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    random_seed: the random seed to use; time/os dependent if None (default).
    save_graphs: bool, if True, save computation graph to file.
    metrics: optionally override the default metrics dictionary.
    checkpoint_highest: save the checkpoint highest at this metric.
    checkpoint_lowest: save the checkpoint lowest at this metric.
    custom_train_fn: custom train function to call, entirely bypassing this one.

  Returns:
    trax.TrainerState
  """
  if custom_train_fn is not None:
    return custom_train_fn(output_dir, model=model)

  n_devices = num_devices()
  trainer = trainer_class(model, loss_fn, optimizer, lr_schedule, inputs,
                          output_dir,
                          random_seed=random_seed,
                          n_devices=n_devices,
                          checkpoints_at=checkpoints_at,
                          metrics=metrics,
                          checkpoint_lowest=checkpoint_lowest,
                          checkpoint_highest=checkpoint_highest)

  epoch_steps = [steps]  # Train-only (no eval) if eval_frequency is 0 or None.
  if eval_frequency and eval_steps > 0:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  trainer.log_step('Starting training using %d devices' % trainer.n_devices)
  trainer.print_n_weights()

  try:
    for epoch_steps in epochs(steps, trainer.step, epoch_steps):
      trainer.train_epoch(epoch_steps, eval_steps)

      # Bookkeeping we do at the first step
      if trainer.step == 1:
        # Save computation graph (single-device only for now)
        if (save_graphs and fastmath.backend_name() == 'jax'):
          trainer.save_computation_graphs()

        # Save Gin config
        trainer.save_gin()

    trainer.log_step('Training done')
  finally:
    trainer.close()
  return trainer.state
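The epoch_steps chain yields 1 step for the first epoch, then eval_frequency - 1, then eval_frequency forever, so evaluations land on step 1 and on multiples of eval_frequency; a quick check:

import itertools

eval_frequency = 100
epoch_steps = itertools.chain([1, eval_frequency - 1],
                              itertools.repeat(eval_frequency))
print(list(itertools.islice(epoch_steps, 4)))  # [1, 99, 100, 100]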
Example #20
 def test_use_backend_str(self):
     with fastmath.use_backend('tensorflow-numpy'):
         self.assertEqual(fastmath.backend_name(), 'tensorflow-numpy')
Example #21
 def test_use_backend_enum(self):
     with fastmath.use_backend(fastmath.Backend.NUMPY):
         self.assertEqual(fastmath.backend_name(), 'numpy')
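As these two tests show, fastmath.use_backend accepts either a backend-name string or a fastmath.Backend enum value; typical usage wraps backend-sensitive code in the context manager:

from trax import fastmath

with fastmath.use_backend(fastmath.Backend.NUMPY):
    # Inside the block, fastmath ops dispatch to plain NumPy.
    assert fastmath.backend_name() == 'numpy'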
Example #22
def ReformerNoEncDecAttention(input_vocab_size,
                              output_vocab_size=None,
                              d_model=512,
                              d_ff=2048,
                              d_attention_key=64,
                              d_attention_value=64,
                              n_encoder_layers=6,
                              n_decoder_layers=6,
                              n_heads=8,
                              dropout=0.1,
                              max_len=2048,
                              encoder_attention_type=tl.SelfAttention,
                              encoder_decoder_attention_type=tl.SelfAttention,
                              axial_pos_shape=(),
                              d_axial_pos_embs=None,
                              ff_activation=tl.Relu,
                              ff_use_sru=0,
                              ff_chunk_size=0,
                              ff_dropout=None,
                              mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use, such as SelfAttention
    encoder_decoder_attention_type: class: attention class to use, such as
      SelfAttention
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  if fastmath.backend_name() == 'jax':
    jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
    if not axial_pos_shape:
      positional_encoding = tl.PositionalEncoding(
          max_len=max_len, dropout=dropout, mode=mode)
    else:
      assert d_axial_pos_embs is not None
      positional_encoding = tl.AxialPositionalEncoding(
          shape=axial_pos_shape, d_embs=d_axial_pos_embs,
          dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
          dropout=dropout, mode=mode)

    return [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
        positional_encoding,
    ]

  # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
  # position embeddings between the encoder and decoder if output_vocab_size is
  # None. This isn't supported here because (a) Trax shares weights by sharing
  # layer instances, but we need two separate instances to have mode == 'eval'
  # for the encoder but mode == 'predict' for the decoder; and (b) tl.Cache does
  # not work if its sublayers participate in any weight sharing.

  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why it's switched
  # to 'eval' mode instead.
  in_encoder = PositionalEncoder(
      input_vocab_size, mode='eval' if mode == 'predict' else mode)
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
  out_encoder = PositionalEncoder(output_vocab_size, mode)

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(
          d_model, d_ff, n_heads, encoder_attention_type, dropout,
          ff_activation, ff_dropout, mode)
      for _ in range(n_encoder_layers)]
  # pylint: enable=g-complex-comprehension

  encoder = tl.Serial([                # tok_e mask_e tok_e tok_d tok_d
      in_encoder,                      # vec_e mask_e tok_e tok_d tok_d
      tl.Dup(),                        # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
      tl.ReversibleSerial(encoder_blocks),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  decoder_blocks = []

  if isinstance(encoder_decoder_attention_type, (tuple, list)):
    assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
  else:
    encoder_decoder_attention_type = [encoder_decoder_attention_type]
  for layer_idx in range(n_decoder_layers):
    layer_attention_type = encoder_decoder_attention_type[
        layer_idx % len(encoder_decoder_attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        attention_type=layer_attention_type,
        dropout=dropout,
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        mode=mode)
    decoder_blocks.append(decoder_block)

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 1, 1]),                  # tok_e tok_e tok_d tok_d
      tl.Branch([], [tl.PaddingMask(),
                     tl.Fn('Squeeze',
                           lambda x: jnp.squeeze(x, (1, 2)), n_out=1)]),
      #                                         # tok_e mask_e tok_e tok_d tok_d

      # Encode.
      encoder,                                  # vec_e mask_e tok_e tok_d tok_d

      # Decode.
      tl.Select([3, 0, 1, 2]),                 #  tok_d vec_e mask_e tok_e tok_d
      tl.ShiftRight(mode=mode),                # stok_d vec_e mask_e tok_e tok_d
      tl.Branch(
          [],
          _MaskOfRightShiftedArray()
      ),                                # stok_d mask_d vec_e mask_e tok_e tok_d
      out_encoder,                      # svec_d mask_d vec_e mask_e tok_e tok_d

      # Concat encoder and decoder, given their masks.
      tl.Select([2, 0, 3, 1]),          # svec_d mask_d vec_e mask_e tok_e tok_d
      _ConcatWithPadding(),                        # vec_ed tok_e tok_d

      # Run (encoder and) decoder blocks.
      tl.Dup(),                                    # vec_ed1 vec_ed2 tok_e tok_d
      tl.ReversibleSerial(decoder_blocks),         # vec_ed1 vec_ed2 tok_e tok_d
      tl.Fn('XYAvg',
            lambda x, y: (x + y) / 2.0),           # vec_ed tok_e tok_d
      tl.LayerNorm(),                              # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),                     # vec_ed tok_e tok_d tok_d
      _StripFromConcatenateWithPadding(),          # vec_d tok_d

      # Map to output vocab.
      tl.Dense(output_vocab_size),                 # vec_d tok_d
      tl.LogSoftmax(),                             # vec_d tok_d
  )
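A construction sketch with hypothetical small hyperparameters (building the layer is cheap; weights are only allocated when init is called on an input signature):

model = ReformerNoEncDecAttention(
    input_vocab_size=256,
    d_model=64, d_ff=128,
    n_encoder_layers=2, n_decoder_layers=2,
    n_heads=2, max_len=128,
    mode='train')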
Example #23
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  if fastmath.backend_name() == 'jax':
    jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
    # TODO(kitaev): axial positional encoding is better for very long sequences.
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
    return [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
        positional_encoding,
    ]

  # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
  # position embeddings between the encoder and decoder if output_vocab_size is
  # None. This isn't supported here because (a) Trax shares weights by sharing
  # layer instances, but we need two separate instances to have mode == 'eval'
  # for the encoder but mode == 'predict' for the decoder; and (b) tl.Cache does
  # not work if its sublayers participate in any weight sharing.

  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why it's switched
  # to 'eval' mode instead.
  in_encoder = PositionalEncoder(
      input_vocab_size, mode='eval' if mode == 'predict' else mode)
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
  out_encoder = PositionalEncoder(output_vocab_size, mode)

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(
          d_model, d_ff, n_heads, tl.SelfAttention, dropout, ff_activation,
          ff_dropout, mode)
      for _ in range(n_encoder_layers)]
  # pylint: enable=g-complex-comprehension

  encoder = tl.Serial([
      in_encoder,
      tl.Dup(),
      tl.ReversibleSerial(encoder_blocks),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  encoder_decoder_blocks = [
      EncoderDecoderBlock(
          d_model, d_ff, n_heads, dropout, ff_activation, ff_dropout, mode)
      for _ in range(n_decoder_layers)]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),                 # tok_e tok_d tok_d
      tl.Branch([], [tl.PaddingMask(),
                     tl.Fn('Squeeze',
                           lambda x: jnp.squeeze(x, (1, 2)), n_out=1)]),
      #                                     # tok_e mask  tok_d .....

      # Encode.
      encoder,                              # vec_e  mask tok_d .....

      # Decode.
      tl.Select([2, 0, 1]),                 # tok_d vec_e mask .....
      tl.ShiftRight(mode=mode),             # tok_d vec_e mask .....
      out_encoder,                          # vec_d vec_e mask .....
      tl.Dup(),                             # vec_d1 vec_d2 vec_e mask .....
      tl.ReversibleSerial(encoder_decoder_blocks),
      tl.Fn('XYAvg',
            lambda x, y: (x + y) / 2.0),    # vec_d vec_e mask .....
      tl.LayerNorm(),                       # vec_d vec_e mask .....

      # Map to output vocab.
      tl.Select([0], n_in=3),               # vec_d .....
      tl.Dense(output_vocab_size),          # vec_d .....
      tl.LogSoftmax(),                      # vec_d .....
  )
Example #24
 def test_backend_can_be_set(self):
     self.assertEqual(fastmath.backend_name(), 'jax')
     fastmath.set_backend('tensorflow-numpy')
     self.assertEqual(fastmath.backend_name(), 'tensorflow-numpy')
     fastmath.set_backend(None)
     self.assertEqual(fastmath.backend_name(), 'jax')
Example #25
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 checkpoints_at=None,
                 should_save_checkpoints=True,
                 should_write_summaries=True,
                 metrics=None,
                 checkpoint_highest=None,
                 checkpoint_lowest=None):

        self._is_chief, _, self._n_devices, rng = (
            training.init_host_and_devices(n_devices, random_seed))
        self._should_save_checkpoints = should_save_checkpoints and self._is_chief
        self._checkpoints_at = checkpoints_at or []
        self._should_write_summaries = should_write_summaries
        if not output_dir:
            self._should_save_checkpoints = False
            self._should_write_summaries = False
        self._checkpoint_highest = checkpoint_highest
        self._checkpoint_lowest = checkpoint_lowest
        self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
        # Inputs is either an Inputs instance or a function that returns it.
        self._inputs = inputs
        if callable(inputs):  # If we pass a function, e.g., through gin, call it.
            self._inputs = inputs()
        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Setup the model.
        model_train = model(mode='train')
        model_predict_eval = model(mode='eval')
        self._model_with_loss = tl.Serial(model_train, loss_fn)

        # Setup state.
        rng, init_rng = jax_random.split(rng)
        self._rngs = np.stack(jax_random.split(rng, self._n_devices))
        shapes, dtypes = self._inputs.example_shape_dtype
        input_signature = tuple(
            ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))

        def new_opt_state_and_model_state(rng):
            """Returns optimizer and model states suitable for training a model."""
            weights, state = self._model_with_loss.init(input_signature,
                                                        rng=rng)
            (slots, opt_params) = opt.tree_init(weights)
            return (OptState(weights, slots, opt_params), state)

        if fastmath.backend_name() == 'jax':
            # JIT parameter initialization to avoid memory fragmentation
            new_opt_state_and_model_state = (
                fastmath.jit(new_opt_state_and_model_state))
        self._new_opt_state_and_model_state = (
            lambda: new_opt_state_and_model_state(init_rng))

        # Arrange and initialize metrics layers.
        self._metrics = list(sorted(self._metrics_dict.keys()))
        metrics_layers = [self._metrics_dict[m] for m in self._metrics]
        metrics_in_parallel = tl.Branch(*metrics_layers)
        metrics_in_parallel.rng = init_rng
        example_signature = tuple(
            ShapeDtype(s, d)
            for (s, d) in zip(*self._inputs.example_shape_dtype))
        model_predict_eval.init(example_signature)
        self._input_signature = example_signature
        output_signature = model_predict_eval.output_signature(
            example_signature)
        m_weights, m_state = metrics_in_parallel.init(output_signature)
        self._metrics_weights = self._for_n_devices(m_weights)
        self._metrics_state = self._for_n_devices(m_state)

        # Jit model_predict and update so they're fast.
        self._jit_eval = _jit_predict_fn(model_predict_eval,
                                         metrics_in_parallel, self._n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             self._n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        self._lr_schedule = lr_schedule

        # Those fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._opt_state = None
        self._step = None
        self._model_state = None
        self.reset(output_dir)
Example #26
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.CrossEntropyLoss(),
          inputs=trax_inputs.batcher,
          optimizer=trax_opt.Adafactor,
          lr_schedule_fn=lr.multifactor,
          trainer_class=Trainer,
          steps=1000,
          checkpoints_at=None,
          eval_steps=10,
          eval_frequency=100,
          random_seed=None,
          save_graphs=True,
          metrics=None,
          checkpoint_highest=None,
          checkpoint_lowest=None,
          use_loop=False):
    """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fn
      and apply_fn.
    loss_fn: callable with signature: weights, trax.inputs.Inputs, model, state,
      rng -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule_fn: A learning rate schedule function, that when called returns
      a function from step to learning rate (a float).
    trainer_class: The trainer class to use.
    steps: int, total number of training steps.
    checkpoints_at: list of integers. Save a checkpoint for each training step
      in the list.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    random_seed: the random seed to use; time/os dependent if None (default).
    save_graphs: bool, if True, save computation graph to file.
    metrics: optionally override the default metrics dictionary.
    checkpoint_highest: save the checkpoint highest at this metric.
    checkpoint_lowest: save the checkpoint lowest at this metric.
    use_loop: whether to use training.Loop instead of Trainer.

  Returns:
    trax.TrainerState or training.Loop if use_loop is True
  """
    if use_loop:
        n_devices = num_devices() or fastmath.device_count()

        # Prepare the training task.
        # Inputs is either an Inputs instance or a function that returns it.
        if callable(inputs):  # If we pass a function, e.g., through gin, call it.
            inputs = inputs()
        train_task = training.TrainTask(inputs.train_stream(n_devices),
                                        loss_layer=loss_fn,
                                        optimizer=optimizer(),
                                        lr_schedule=lr_schedule_fn(),
                                        n_steps_per_checkpoint=eval_frequency)

        # Prepare the evaluation.
        metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
        names, metrics = zip(*metrics_dict.items())
        eval_task = training.EvalTask(inputs.eval_stream(n_devices),
                                      metrics,
                                      metric_names=names,
                                      n_eval_batches=eval_steps)

        # Prepare the training loop.
        checkpoint_at = None
        if checkpoints_at is not None:
            checkpoint_at = lambda step: step in checkpoints_at
        loop = training.Loop(model(mode='train'), [train_task],
                             eval_model=model(mode='eval'),
                             eval_tasks=[eval_task],
                             output_dir=output_dir,
                             checkpoint_at=checkpoint_at,
                             n_devices=n_devices,
                             random_seed=random_seed)

        # Train and return the loop.
        loop.run(steps)
        return loop

    n_devices = num_devices()
    trainer = trainer_class(model,
                            loss_fn,
                            optimizer,
                            lr_schedule_fn(),
                            inputs,
                            output_dir,
                            random_seed=random_seed,
                            n_devices=n_devices,
                            checkpoints_at=checkpoints_at,
                            metrics=metrics,
                            checkpoint_lowest=checkpoint_lowest,
                            checkpoint_highest=checkpoint_highest)

    epoch_steps = [steps]  # Train-only (no eval) if eval_frequency is 0 or None.
    if eval_frequency and eval_steps > 0:
        epoch_steps = itertools.chain(
            [
                1,  # first epoch only 1 step
                eval_frequency - 1
            ],
            itertools.repeat(eval_frequency))
    trainer.log_step('Starting training using %d devices' % trainer.n_devices)
    trainer.print_n_weights()

    try:
        for epoch_steps in epochs(steps, trainer.step, epoch_steps):
            trainer.train_epoch(epoch_steps, eval_steps)

            # Bookkeeping we do at the first step
            if trainer.step == 1:
                # Save computation graph (single-device only for now)
                if (save_graphs and fastmath.backend_name() == 'jax'):
                    trainer.save_computation_graphs()

                # Save Gin config
                trainer.save_gin()

        trainer.log_step('Training done')
    finally:
        trainer.close()
    return trainer.state