예제 #1
0
def log_gaussian_pdf(x, mu, sigma):  # pylint: disable=invalid-name
    """Return the log-density of x under a full-covariance Gaussian.

    Evaluates log N(x | mu, sigma) as
    -0.5 * (d * log(2 * pi) + log|det(sigma)| + (x - mu)^T sigma^-1 (x - mu)),
    where d is the size of the last axis of ``mu``.
    """
    residual = x - mu
    log_norm = mu.shape[-1] * jnp.log(2 * jnp.pi)
    # Only the log-magnitude of the determinant is used; its sign is dropped.
    _, log_abs_det = jnp.linalg.slogdet(sigma)
    # sigma^-1 (x - mu), laid out as a column vector for the product below.
    whitened = jnp.expand_dims(jnp.linalg.solve(sigma, residual), axis=-1)
    residual_row = jnp.expand_dims(residual, axis=-2)
    # (1 x d) @ (d x 1) gives a 1 x 1 quadratic form; drop both unit axes.
    quad_form = jnp.squeeze(jnp.matmul(residual_row, whitened), axis=(-2, -1))
    return -0.5 * (log_norm + log_abs_det + quad_form)
예제 #2
0
def log_gaussian_diag_pdf(x, mu, diag_sigma):  # pylint: disable=invalid-name
    """Compute log N(x | mu, diag(diag_sigma)).

    Args:
        x: points at which to evaluate the density; the last axis is reduced.
        mu: mean, broadcastable against ``x``; its last axis gives the
            dimensionality d used in the normalization constant.
        diag_sigma: diagonal entries of the covariance matrix, broadcastable
            against ``x``.

    Returns:
        -0.5 * (d * log(2 * pi) + sum(log(diag_sigma))
                + sum((x - mu)**2 / diag_sigma)), reduced over the last axis.
    """
    residual = x - mu
    log_norm = mu.shape[-1] * jnp.log(2 * jnp.pi)
    # For a diagonal matrix the log-determinant is the sum of the log-diagonal.
    log_det = jnp.sum(jnp.log(diag_sigma), axis=-1)
    # Quadratic form (x - mu)^T diag(diag_sigma)^-1 (x - mu). The whole
    # residual must be divided by the variances, not just ``mu``.
    quad_form = jnp.sum(residual * residual / diag_sigma, axis=-1)
    return -0.5 * (log_norm + log_det + quad_form)
예제 #3
0
def actor_loss(actions,
               advantage_weights,
               log_probab_actions_new,
               state=None):
  """Advantage-weighted negative log-likelihood of the taken actions.

  Args:
    actions: integer indices of the actions that were taken.
    advantage_weights: per-example weights multiplied into the log-probs.
    log_probab_actions_new: log-probabilities over actions; every singleton
      axis is squeezed away before indexing.
    state: passed through unchanged.

  Returns:
    A (loss, state) pair, where loss is the negated mean of the weighted
    log-probabilities of the selected actions.
  """
  squeezed_lp = np.squeeze(log_probab_actions_new)
  n_rows = len(squeezed_lp)
  # Pair row i with actions[i]; the row index carries an extra leading axis
  # that the second squeeze removes again.
  row_index = np.arange(n_rows)[np.newaxis, :]
  selected = np.squeeze(squeezed_lp[row_index, actions])
  weighted = selected * advantage_weights
  loss = -1.0 * np.mean(weighted)
  return loss, state
예제 #4
0
    def _run_value_model(self, observations, dist_inputs):
        """Evaluate the value network, optionally on sampled actions.

        When ``dist_inputs`` is None, zeros shaped from the first two
        observation axes plus ``self._policy_dist.n_inputs`` are used instead.

        If ``self._q_value`` is set, ``self._q_value_n_samples`` actions are
        sampled per distribution input and fed to the network together with
        the observations; otherwise the network only sees the observations.

        Returns:
            A tuple (values, actions, log_probs). ``values`` is scaled by
            ``self._value_network_scale`` with its last axis squeezed out;
            ``actions`` and ``log_probs`` are None unless ``self._q_value``.
        """
        if dist_inputs is None:
            zeros_shape = observations.shape[:2] + (
                self._policy_dist.n_inputs, )
            dist_inputs = jnp.zeros(zeros_shape)

        actions, log_probs = None, None
        inputs = (observations, )
        if self._q_value:
            tiled_shape = (self._q_value_n_samples, ) + dist_inputs.shape
            tiled = jnp.broadcast_to(dist_inputs, tiled_shape)
            # Swapping the n_samples and batch_size axes, so the input is split
            # between accelerators along the batch_size axis.
            dist_inputs = jnp.swapaxes(tiled, 0, 1)
            actions = self._policy_dist.sample(dist_inputs)
            log_probs = self._policy_dist.log_prob(dist_inputs, actions)
            # Insert a unit axis after the batch axis of the observations.
            obs_shape = [observations.shape[0], 1] + list(
                observations.shape[1:])
            inputs = (jnp.reshape(observations, obs_shape), actions)

        n_devices = math.device_count()
        eval_weights = tl.for_n_devices(self._value_eval_model.weights,
                                        n_devices)
        eval_state = tl.for_n_devices(self._value_eval_model.state, n_devices)
        eval_rng = self._value_eval_model.rng
        values, _ = self._value_eval_jit(inputs, eval_weights, eval_state,
                                         eval_rng)
        values *= self._value_network_scale
        # Remove the singleton depth dim.
        values = jnp.squeeze(values, axis=-1)
        return (values, actions, log_probs)
예제 #5
0
 def f(preds, values, returns, actions, mask):
   """Sum of the AWR policy loss and the weighted L2 value loss.

   The advantage is ``returns - values`` with the gradient stopped through
   ``values``, so the policy term does not back-propagate into the value
   estimate; the value term uses ``values`` directly.
   """
   baseline = stop_gradient(values)
   advantages = jnp.squeeze(returns - baseline, axis=-1)
   log_probs = self._policy_dist.log_prob(preds, actions)
   awr_layer = actor_critic.AWRLoss(beta=self._beta, w_max=self._w_max)
   # The third slot of the loss input is filled with zeros.
   awr_inputs = (log_probs, advantages, jnp.zeros_like(log_probs), mask)
   policy_loss = awr_layer(awr_inputs)
   value_error = returns - values
   value_loss = jnp.mean(value_error**2) * self._value_loss_coeff
   return policy_loss + value_loss
예제 #6
0
 def AWRJointLoss(x, **unused_kwargs):  # pylint: disable=invalid-name
   """Sum of the AWR policy loss and the weighted L2 value loss.

   ``x`` is unpacked as (preds, values, returns, actions, mask). The
   advantage is ``returns - values`` with its last axis squeezed out.
   """
   preds, values, returns, actions, mask = x
   value_error = returns - values
   advantages = jnp.squeeze(value_error, axis=-1)
   log_probs = self._policy_dist.log_prob(preds, actions)
   awr_layer = actor_critic.AWRLoss(beta=self._beta, w_max=self._w_max)
   # The third slot of the loss input is filled with zeros.
   awr_inputs = (log_probs, advantages, jnp.zeros_like(log_probs), mask)
   policy_loss = awr_layer(awr_inputs)
   value_loss = jnp.mean((returns - values)**2) * self._value_loss_coeff
   return policy_loss + value_loss
예제 #7
0
def actor_loss(actions, advantage_weights, log_probab_actions_new, state=None):
    """Advantage-weighted negative log-likelihood of the taken actions.

    Args:
        actions: integer array of shape (AB, #C) with the chosen action per
            control, AB being the actor batch.
        advantage_weights: array of shape (AB,), broadcast over the controls.
        log_probab_actions_new: array of shape (AB, 1, #C, #A).
        state: passed through unchanged.

    Returns:
        A (loss, state) pair.
    """
    # Drop the singleton axis: (AB, 1, #C, #A) -> (AB, #C, #A).
    per_action_lp = jnp.squeeze(log_probab_actions_new, axis=1)
    batch_size, n_controls = actions.shape
    batch_index = jnp.arange(batch_size)[:, None]
    control_index = jnp.arange(n_controls)[None, :]
    # Pick, for every (batch, control) pair, the log-prob of the taken action.
    chosen_lp = per_action_lp[batch_index, control_index, actions]

    # TODO(afrozm): Clarify this.
    #   chosen_lp is shaped (AB, #C), however advantage_weights are (AB,),
    #   hence the extra trailing axis on the weights.
    weighted = chosen_lp * advantage_weights[:, None]
    return -1.0 * jnp.mean(weighted), state
예제 #8
0
def dataset_to_stream(dataset, input_name):
    """Takes a tf.Dataset and creates a numpy stream of ready batches.

    Args:
        dataset: dataset whose examples are (features, targets) pairs, with
            ``features`` supporting ``in`` and item lookup by key.
        input_name: key of the feature to use as the model input.

    Yields:
        (inputs, targets), or (inputs, targets, mask) when the features
        contain a 'mask' entry. uint8 inputs/targets are widened to int32,
        and a trailing unit axis on multi-dimensional targets is removed.
    """

    def _uint8_to_int32(value):
        # `isinstance(value, np.uint8)` only matches NumPy scalars; batches
        # arrive as arrays, so the dtype has to be inspected instead. Values
        # without a dtype are returned untouched.
        if getattr(value, 'dtype', None) == np.uint8:
            return value.astype(np.int32)
        return value

    for example in math.dataset_as_numpy(dataset):
        features = example[0]
        inp, out = features[input_name], example[1]
        mask = features['mask'] if 'mask' in features else None
        # All input-pipeline processing should be on CPU.
        with tf.device('cpu:0'):
            # Some accelerators don't handle uint8 well, cast to int.
            inp = _uint8_to_int32(inp)
            out = _uint8_to_int32(out)
            if len(out.shape) > 1 and out.shape[-1] == 1:
                out = np.squeeze(out, axis=-1)
        yield (inp, out) if mask is None else (inp, out, mask)
예제 #9
0
    def _run_value_model(self, observations, dist_inputs):
        """Evaluate the value network, optionally on sampled or enumerated actions.

        When ``dist_inputs`` is None, zeros shaped from the first two
        observation axes plus ``self._policy_dist.n_inputs`` are used instead.

        If ``self._q_value`` is set, the network is fed (observations, actions),
        where the actions are either sampled from the policy distribution or,
        when ``self._sample_all_discrete_actions`` is set, enumerated as
        ``0 .. self._vocab_size - 1``. Otherwise only the observations are fed.

        Returns:
            A tuple (values, actions, log_probs). ``values`` is scaled by
            ``self._value_network_scale`` with its last axis squeezed out;
            ``actions`` and ``log_probs`` are None unless ``self._q_value``.
        """
        if dist_inputs is None:
            dist_inputs = jnp.zeros(observations.shape[:2] +
                                    (self._policy_dist.n_inputs, ))

        actions = None
        if self._q_value:
            if self._sample_all_discrete_actions:
                # Since we want to sample all actions, start by creating their list.
                act = np.arange(self._vocab_size)
                # Now act is a vector [0, ..., vocab_size-1], but we'll need to tile it.
                # Add extra dimensions so it's the same dimensionality as dist_inputs.
                # This must use dist_inputs' shape from *before* the broadcast below.
                act = jnp.reshape(act,
                                  [-1] + [1] * (len(dist_inputs.shape) - 1))
                # Now act is [vocab_size, 1, ..., 1], dimensionality of dist_inputs.
            # Prepend an n_samples axis to dist_inputs.
            dist_inputs = jnp.broadcast_to(
                dist_inputs, (self._q_value_n_samples, ) + dist_inputs.shape)
            if self._sample_all_discrete_actions:
                # Tile the action ids over every axis of dist_inputs but the last.
                # NOTE(review): the leading axes here are vocab_size (act) and
                # n_samples (zeros), so this broadcast appears to require them to
                # be equal, or one of them to be 1 -- confirm with the caller.
                actions = act + jnp.zeros(dist_inputs.shape[:-1],
                                          dtype=jnp.int32)
                actions = jnp.swapaxes(actions, 0, 1)
            # Swapping the n_samples and batch_size axes, so the input is split
            # between accelerators along the batch_size axis.
            dist_inputs = jnp.swapaxes(dist_inputs, 0, 1)
            if not self._sample_all_discrete_actions:
                actions = self._policy_dist.sample(dist_inputs)
            log_probs = self._policy_dist.log_prob(dist_inputs, actions)
            obs = observations
            # Insert a unit axis after the first axis of the observations.
            obs = jnp.reshape(obs, [obs.shape[0], 1] + list(obs.shape[1:]))
            inputs = (obs, actions)
        else:
            log_probs = None
            inputs = (observations, )

        n_devices = math.device_count()
        weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
        state = tl.for_n_devices(self._value_eval_model.state, n_devices)
        rng = self._value_eval_model.rng
        # Only the first element of the jitted call's result is used.
        values, _ = self._value_eval_jit(inputs, weights, state, rng)
        values *= self._value_network_scale
        values = jnp.squeeze(values,
                             axis=-1)  # Remove the singleton depth dim.
        return (values, actions, log_probs)
예제 #10
0
def BERT(d_model=768,
         vocab_size=30522,
         max_len=512,
         type_vocab_size=2,
         n_heads=12,
         d_ff=3072,
         n_layers=12,
         head=None,
         init_checkpoint=None,
         mode='eval',
        ):
  """Builds a BERT model (the defaults are those of bert-base-uncased).

  Args:
    d_model: embedding depth; also the width of the pooler's dense layer.
    vocab_size: size of the word-embedding table.
    max_len: maximum length handed to the positional encoding.
    type_vocab_size: size of the type-embedding table.
    n_heads: number of attention heads; each head has depth
      ``d_model // n_heads``.
    d_ff: width of the hidden feed-forward layer in each encoder block.
    n_layers: number of encoder blocks.
    head: optional zero-argument callable; its result is appended after the
      BERT body.
    init_checkpoint: checkpoint forwarded to ``PretrainedBERT``; it is ignored
      (replaced by None) unless ``mode == 'train'``.
    mode: forwarded to the positional encoding and the attention layers.

  Returns:
    The assembled model, wrapped together with ``head()`` when a head is given.
  """
  ln_eps = 1e-12
  head_dim = d_model // n_heads

  def _layer_norm():
    return tl.LayerNorm(epsilon=ln_eps)

  def _encoder_block():
    # Layers of one encoder block: masked self-attention and a feed-forward
    # stack, each wrapped in a residual connection followed by a layer norm.
    self_attention = tl.SelfAttention(
        n_heads=n_heads, d_qk=head_dim, d_v=head_dim,
        bias=True, masked=True, mode=mode)
    ffn = [tl.Dense(d_ff), tl.Gelu(), tl.Dense(d_model)]
    return [
        tl.Select([0, 1, 1]),  # Save a copy of the mask
        tl.Residual(self_attention, AddBias()),  # pylint: disable=no-value-for-parameter
        _layer_norm(),
        tl.Residual(*ffn),
        _layer_norm(),
    ]

  # Embedding stage.
  token_emb = tl.Embedding(d_model, vocab_size)
  segment_emb = tl.Embedding(d_model, type_vocab_size)
  position_emb = tl.PositionalEncoding(max_len, mode=mode)
  layers = [
      tl.Select([0, 1, 0], n_in=3),  # Drops 'idx' input.
      tl.Parallel(
          token_emb,
          segment_emb,
          [tl.PaddingMask(),
           tl.Fn('Squeeze', lambda x: np.squeeze(x, (1, 2)), n_out=1)],
      ),
      tl.Add(),
      position_emb,
      _layer_norm(),
  ]

  # Encoder stage.
  for _ in range(n_layers):
    layers.extend(_encoder_block())
  layers.append(tl.Select([0], n_in=2))  # Drop the mask

  # Pooler stage: emits the first position alongside the full sequence.
  layers.extend([
      tl.Fn('', lambda x: (x[:, 0, :], x), n_out=2),
      tl.Dense(d_model),
      tl.Tanh(),
  ])

  if mode != 'train':
    init_checkpoint = None
  bert = PretrainedBERT(layers, init_checkpoint=init_checkpoint)

  if head is None:
    return bert
  return tl.Serial(bert, head())
예제 #11
0
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source (encoder-side tokens), target
  (decoder-side tokens).

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Note: calling this function replaces ``jax.api._check_inexact_input_vjp``
  with a no-op, which affects the whole process.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train', 'eval' or 'predict'. In 'predict' mode the encoder
      side is built in 'eval' mode and wrapped in tl.Cache.

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  # Returns the layers for embedding, broadcasted dropout and positional
  # encoding. The `mode` parameter shadows the outer `mode` on purpose, so the
  # encoder and decoder sides can be built in different modes.
  def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
    # TODO(kitaev): axial positional encoding is better for very long sequences.
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
    return [
        tl.Embedding(d_model, vocab_size),
        BroadcastedDropout(rate=dropout, mode=mode),
        positional_encoding,
    ]

  # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
  # position embeddings between the encoder and decoder if output_vocab_size is
  # None. This isn't supported here because (a) Trax shares weights by sharing
  # layer instances, but we need two separate instances to have mode == 'eval'
  # for the encoder but mode == 'predict' for the decoder; and (b) tl.Cache does
  # not work if its sublayers participate in any weight sharing.

  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why it's switched
  # to 'eval' mode instead.
  in_encoder = PositionalEncoder(
      input_vocab_size, mode='eval' if mode == 'predict' else mode)
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
  out_encoder = PositionalEncoder(output_vocab_size, mode)

  # Note: the encoder blocks receive the caller's `mode` unchanged, unlike
  # in_encoder above.
  encoder_blocks = [
      EncoderBlock(
          d_model, d_ff, n_heads, dropout, ff_activation, ff_dropout, mode)
      for _ in range(n_encoder_layers)]

  # Embed, duplicate the activations for the reversible stack, then average
  # the two streams back into one and normalize.
  encoder = tl.Serial([
      in_encoder,
      tl.Dup(),
      tl.ReversibleSerial(encoder_blocks),
      tl.Fn(lambda x, y: (x+y)/2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  encoder_decoder_blocks = [
      EncoderDecoderBlock(
          d_model, d_ff, n_heads, dropout, ff_activation, ff_dropout, mode)
      for _ in range(n_decoder_layers)]

  # Assemble and return the model.
  # The trailing comments track the data stack after each layer.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),                 # tok_e tok_d tok_d
      tl.Branch([], [                       # tok_e mask  tok_d .....
          tl.PaddingMask(),
          tl.Fn(lambda x: np.squeeze(x, (1, 2)), n_out=1)]),

      # Encode.
      encoder,                              # vec_e  mask tok_d .....

      # Decode.
      tl.Select([2, 0, 1]),                 # tok_d vec_e mask .....
      tl.ShiftRight(mode=mode),             # tok_d vec_e mask .....
      out_encoder,                          # vec_d vec_e mask .....
      tl.Dup(),                             # vec_d1 vec_d2 vec_e mask .....
      tl.ReversibleSerial(encoder_decoder_blocks),
      tl.Fn(lambda x, y: (x+y)/2.0),        # vec_d vec_e mask .....
      tl.LayerNorm(),                       # vec_d vec_e mask .....

      # Map to output vocab.
      tl.Select([0], n_in=3),               # vec_d .....
      tl.Dense(output_vocab_size),          # vec_d .....
      tl.LogSoftmax(),                      # vec_d .....
  )
예제 #12
0
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source (encoder-side tokens), target
  (decoder-side tokens).

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Note: calling this function replaces ``jax.api._check_inexact_input_vjp``
  with a no-op, which affects the whole process.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab, and the decoder
      side reuses the encoder side's embedding layer instances.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    ff_activation: the non-linearity in feed-forward layer
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  # Returns the layers for embedding, dropout and positional encoding; `mode`
  # and the hyperparameters are taken from the enclosing scope.
  def PositionalEncoder(vocab_size):  # tokens --> vectors
    # TODO(kitaev): axial positional encoding is better for very long sequences.
    # TODO(kitaev): dropout=0.0 for tl.PositionalEncoding matches trax
    # Transformer, but may not be the right option in general.
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=0.0, mode=mode)
    return [
        tl.Embedding(d_model, vocab_size),
        # TODO(kitaev): BroadcastedDropout?
        tl.Dropout(rate=dropout, mode=mode),
        positional_encoding,
    ]

  in_encoder = PositionalEncoder(input_vocab_size)
  # With no separate target vocab, the very same layer objects are used on
  # both sides (presumably to share their weights -- confirm).
  out_encoder = (in_encoder if output_vocab_size is None
                 else PositionalEncoder(output_vocab_size))
  # Resolved only after out_encoder is chosen, since the choice above tests
  # output_vocab_size for None.
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size

  encoder_blocks = [
      EncoderBlock(
          d_model, d_ff, n_heads, dropout, ff_activation, mode)
      for _ in range(n_encoder_layers)]

  encoder_decoder_blocks = [
      EncoderDecoderBlock(
          d_model, d_ff, n_heads, dropout, ff_activation, mode)
      for _ in range(n_decoder_layers)]

  # Assemble and return the model.
  # The trailing comments track the data stack after each layer.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),               # tok_e tok_d tok_d

      # Encode.
      tl.Branch(
          in_encoder, [tl.PaddingMask(),
                       tl.Fn(lambda x: np.squeeze(x, (1, 2)), n_out=1)]
          ),                                # vec_e  mask  tok_d .....
      tl.Dup(),                             # vec_e1 vec_e2 mask tok_d .....
      tl.ReversibleSerial(encoder_blocks),  # vec_e1 vec_e2 mask tok_d .....
      # The two sets of activations need to be reduced to one, in this case by
      # averaging them. Note that ReformerLM concatenates instead. Various
      # options (concat, average, add, keep only one, etc.) seem to perform
      # similarly. We don't concatenate here because we want exact parameter
      # parity with the standard Transformer.
      tl.Fn(lambda x, y: (x+y)/2.0),        # vec_e  mask tok_d .....
      tl.LayerNorm(),                       # vec_e  mask tok_d .....

      # Decode.
      tl.Select([2, 0, 1]),                 # tok_d vec_e mask .....
      tl.ShiftRight(),                      # tok_d vec_e mask .....
      out_encoder,                          # vec_d vec_e mask .....
      tl.Dup(),                             # vec_d1 vec_d2 vec_e mask .....
      tl.ReversibleSerial(encoder_decoder_blocks),
      tl.Fn(lambda x, y: (x+y)/2.0),        # vec_d vec_e mask .....
      tl.LayerNorm(),                       # vec_d vec_e mask .....

      # Map to output vocab.
      tl.Select([0], n_in=3),               # vec_d .....
      tl.Dense(output_vocab_size),          # vec_d .....
      tl.LogSoftmax(),                      # vec_d .....
  )
예제 #13
0
파일: reformer.py 프로젝트: yangcaot/trax
def ReformerNoEncDecAttention(input_vocab_size,
                              output_vocab_size=None,
                              d_model=512,
                              d_ff=2048,
                              d_attention_key=64,
                              d_attention_value=64,
                              n_encoder_layers=6,
                              n_decoder_layers=6,
                              n_heads=8,
                              dropout=0.1,
                              max_len=2048,
                              encoder_attention_type=tl.SelfAttention,
                              encoder_decoder_attention_type=tl.SelfAttention,
                              axial_pos_shape=(),
                              d_axial_pos_embs=None,
                              ff_activation=tl.Relu,
                              ff_use_sru=0,
                              ff_chunk_size=0,
                              ff_dropout=None,
                              mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  Instead of using encoder-decoder attention, the encoder output and the
  embedded (right-shifted) target are concatenated and run through a single
  stack of decoder blocks; the encoder part is stripped off again before the
  output projection.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Note: calling this function replaces ``jax.api._check_inexact_input_vjp``
  with a no-op, which affects the whole process.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head (passed
      to the decoder blocks)
    d_attention_value: int: depth of value vector for each attention head
      (passed to the decoder blocks)
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use, such as SelfAttention
    encoder_decoder_attention_type: class: attention class to use, such as
      SelfAttention. May also be a tuple or list of classes, which is cycled
      through over the decoder layers; n_decoder_layers must then be a
      multiple of its length.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T. Only the encoder
      blocks receive it.
    mode: str: 'train', 'eval' or 'predict'. In 'predict' mode the encoder
      side is built in 'eval' mode and wrapped in tl.Cache.

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  # Returns the layers for embedding, broadcasted dropout and positional
  # encoding. The `mode` parameter shadows the outer `mode` on purpose, so the
  # encoder and decoder sides can be built in different modes.
  def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
    if not axial_pos_shape:
      positional_encoding = tl.PositionalEncoding(
          max_len=max_len, dropout=dropout, mode=mode)
    else:
      assert d_axial_pos_embs is not None
      positional_encoding = tl.AxialPositionalEncoding(
          shape=axial_pos_shape, d_embs=d_axial_pos_embs,
          dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
          dropout=dropout, mode=mode)

    return [
        tl.Embedding(d_model, vocab_size),
        BroadcastedDropout(rate=dropout, mode=mode),
        positional_encoding,
    ]

  # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
  # position embeddings between the encoder and decoder if output_vocab_size is
  # None. This isn't supported here because (a) Trax shares weights by sharing
  # layer instances, but we need two separate instances to have mode == 'eval'
  # for the encoder but mode == 'predict' for the decoder; and (b) tl.Cache does
  # not work if its sublayers participate in any weight sharing.

  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why it's switched
  # to 'eval' mode instead.
  in_encoder = PositionalEncoder(
      input_vocab_size, mode='eval' if mode == 'predict' else mode)
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
  out_encoder = PositionalEncoder(output_vocab_size, mode)

  # Note: the encoder blocks receive the caller's `mode` unchanged, unlike
  # in_encoder above.
  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(
          d_model, d_ff, n_heads, encoder_attention_type, dropout,
          ff_activation, ff_dropout, mode)
      for _ in range(n_encoder_layers)]
  # pylint: enable=g-complex-comprehension

  # Embed, duplicate the activations for the reversible stack, then average
  # the two streams back into one and normalize.
  encoder = tl.Serial([                # tok_e mask_e tok_e tok_d tok_d
      in_encoder,                      # vec_e mask_e tok_e tok_d tok_d
      tl.Dup(),                        # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
      tl.ReversibleSerial(encoder_blocks),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  decoder_blocks = []

  # Normalize the attention type to a sequence so it can be cycled through
  # layer by layer; a single class is used for every layer.
  if isinstance(encoder_decoder_attention_type, (tuple, list)):
    assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
  else:
    encoder_decoder_attention_type = [encoder_decoder_attention_type]
  for layer_idx in range(n_decoder_layers):
    layer_attention_type = encoder_decoder_attention_type[
        layer_idx % len(encoder_decoder_attention_type)]
    # Note: ff_dropout is not passed to the decoder blocks.
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        attention_type=layer_attention_type,
        dropout=dropout,
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        mode=mode)
    decoder_blocks.append(decoder_block)

  # Assemble and return the model.
  # The trailing comments track the data stack after each layer.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 1, 1]),                  # tok_e tok_e tok_d tok_d
      tl.Branch([], [tl.PaddingMask(),
                     tl.Fn('Squeeze',
                           lambda x: np.squeeze(x, (1, 2)), n_out=1)]),
      #                                         # tok_e mask_e tok_e tok_d tok_d

      # Encode.
      encoder,                                  # vec_e mask_e tok_e tok_d tok_d

      # Decode.
      tl.Select([3, 0, 1, 2]),                 #  tok_d vec_e mask_e tok_e tok_d
      tl.ShiftRight(mode=mode),                # stok_d vec_e mask_e tok_e tok_d
      tl.Branch(
          [],
          _MaskOfRightShiftedArray()
      ),                                # stok_d mask_d vec_e mask_e tok_e tok_d
      out_encoder,                      # svec_d mask_d vec_e mask_e tok_e tok_d

      # Concat encoder and decoder, given their masks.
      tl.Select([2, 0, 3, 1]),          # vec_e svec_d mask_e mask_d tok_e tok_d
      _ConcatWithPadding(),                        # vec_ed tok_e tok_d

      # Run (encoder and) decoder blocks.
      tl.Dup(),                                    # vec_ed1 vec_ed2 tok_e tok_d
      tl.ReversibleSerial(decoder_blocks),         # vec_ed1 vec_ed2 tok_e tok_d
      tl.Fn('XYAvg',
            lambda x, y: (x + y) / 2.0),           # vec_ed tok_e tok_d
      tl.LayerNorm(),                              # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),                     # vec_ed tok_e tok_d tok_d
      _StripFromConcatenateWithPadding(),          # vec_d tok_d

      # Map to output vocab.
      tl.Dense(output_vocab_size),                 # vec_d tok_d
      tl.LogSoftmax(),                             # vec_d tok_d
  )