Code example #1
def EncoderBlock(d_model,
                 d_ff,
                 n_heads,
                 attention_type,
                 dropout,
                 ff_activation,
                 ff_dropout,
                 ff_use_sru=0,
                 ff_chunk_size=0,
                 ff_sparsity=0,
                 attention_chunk_size=0,
                 center_layernorm=True,
                 use_bfloat16=False,
                 use_two_swaps_per_block=True,
                 mode='train'):
    """Returns a list of layers that implements a Reformer encoder block.

  The input to the layer is a pair, (activations, mask), where the mask was
  created from the original source tokens to prevent attending to the padding
  part of the input.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    center_layernorm: whether to use centering in LayerNorm (default) or to
      skip it, which is known as RMS normalization.
    use_bfloat16: whether to use bfloat16 for weights (default: False)
    use_two_swaps_per_block: bool, if True use two reversible swaps in Encoder
      block, otherwise use only one swap.
    mode: str: 'train' or 'eval'

  Returns:
    A list of layers that maps (activations, mask) to (activations, mask).
  """
    if mode == 'predict':
        # Mode 'predict' means that the decoder should be run one token at a time.
        # The encoder only ever runs over full sequences, which is why it's switched
        # to 'eval' mode instead.
        mode = 'eval'

    attention = ct.ApplyAttentionLayer(
        attention_type=attention_type,
        d_model=d_model,
        n_heads=n_heads,
        d_qk=d_model // n_heads,
        d_v=d_model // n_heads,
        masked=True,
        causal=False,
        attention_dropout=dropout,
        output_dropout=dropout,
        attention_chunk_size=attention_chunk_size,
        mode=mode)
    # TODO(lukaszkaiser): refactor efficient attention layers to unify the API
    # If we're using standard attention, we need to pass reshaped mask and not
    # return the mask to be compatible with the EfficientAttention API.
    if attention.n_out == 2:

        def reshape_mask(mask):
            return jnp.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))

        attention = tl.Serial(
            tl.Fn('ReshapeMask', lambda x, y: (x, reshape_mask(y)), n_out=2),
            attention, tl.Select([0], n_in=2))

    attention_half_residual = tl.ReversibleHalfResidual(
        tl.LayerNorm(center=center_layernorm),
        attention_layer=attention,
        name='ReversibleHalfResidualEncoderAttn')

    feed_forward = ct.FeedForwardWithOptions(d_model, d_ff, dropout, [-2],
                                             ff_activation, ff_dropout,
                                             ff_chunk_size, ff_use_sru,
                                             ff_sparsity, center_layernorm,
                                             mode, use_bfloat16)

    encoder_block = [
        attention_half_residual,
        tl.ReversibleSwap(),
        tl.ReversibleHalfResidual(feed_forward,
                                  name='ReversibleHalfResidualEncoderFF'),
    ]
    if use_two_swaps_per_block:
        encoder_block.append(tl.ReversibleSwap())
    return encoder_block
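
The block above is assembled from trax's reversible building blocks. Below is a minimal sketch of the same two-stream pattern, using plain tl.Dense layers as stand-ins for the attention and feed-forward sublayers (the sizes and layer choices are illustrative assumptions, not taken from the source):

import numpy as np
from trax import layers as tl
from trax import shapes

# Dup splits the activations into two streams, each ReversibleHalfResidual
# updates one stream from the other, ReversibleSwap alternates which stream is
# updated, and Concatenate merges the streams back together.
block = tl.Serial(
    tl.Dup(),
    tl.ReversibleHalfResidual(tl.Dense(8)),   # Dense stands in for attention
    tl.ReversibleSwap(),
    tl.ReversibleHalfResidual(tl.Dense(8)),   # Dense stands in for feed-forward
    tl.ReversibleSwap(),
    tl.Concatenate(),
)
x = np.ones((2, 3, 8), dtype=np.float32)
block.init(shapes.signature(x))
y = block(x)
print(y.shape)  # (2, 3, 16): the two 8-dim streams concatenated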
Code example #2
def EncoderDecoderBlock(d_model,
                        d_ff,
                        n_heads,
                        dropout,
                        ff_activation,
                        ff_dropout,
                        mode,
                        ff_use_sru=0,
                        ff_chunk_size=0,
                        ff_sparsity=0):
    """Reversible transformer decoder layer.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate for feed-forward layer
    mode: str: 'train' or 'eval'
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity

  Returns:
    the layer.
  """
    enc_dec_attention = tl.EncDecAttention(n_heads=n_heads,
                                           d_qk=d_model // n_heads,
                                           d_v=d_model // n_heads,
                                           attention_dropout=dropout,
                                           output_dropout=dropout,
                                           mode=mode)
    enc_dec_attention_half_residual = tl.ReversibleHalfResidual(
        tl.LayerNorm(),
        attention_layer=enc_dec_attention,
    )

    causal_attention = tl.SelfAttention(n_heads=n_heads,
                                        d_qk=d_model // n_heads,
                                        d_v=d_model // n_heads,
                                        causal=True,
                                        attention_dropout=dropout,
                                        output_dropout=dropout,
                                        mode=mode)
    causal_attention_half_residual = tl.ReversibleHalfResidual(
        tl.LayerNorm(),
        attention_layer=causal_attention,
    )

    feed_forward = ct.FeedForwardWithOptions(d_model, d_ff, dropout, [-2],
                                             ff_activation, ff_dropout,
                                             ff_chunk_size, ff_use_sru,
                                             ff_sparsity, mode)

    return [  # vec_d1 vec_d2 vec_e masks
        causal_attention_half_residual,
        tl.ReversibleSwap(),
        enc_dec_attention_half_residual,
        tl.ReversibleSwap(),
        tl.ReversibleHalfResidual(feed_forward),
        tl.ReversibleSwap(),
    ]
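
Why the half residuals and swaps are arranged this way: each ReversibleHalfResidual updates only one of the two activation streams, so the block's inputs can be recomputed from its outputs during the backward pass instead of being stored. A plain-NumPy sketch of that coupling (illustrative only, not trax code):

import numpy as np

rng = np.random.default_rng(0)
W_f = rng.normal(size=(4, 4))
W_g = rng.normal(size=(4, 4))
F = lambda x: np.tanh(x @ W_f)   # stands in for an attention sublayer
G = lambda x: np.tanh(x @ W_g)   # stands in for the feed-forward sublayer

x1 = rng.normal(size=(2, 4))
x2 = rng.normal(size=(2, 4))
y1 = x1 + F(x2)                  # half residual, then swap
y2 = x2 + G(y1)                  # next half residual on the other stream

# Backward pass: the inputs are recovered from the outputs alone, so the
# forward activations need not be kept in memory.
x2_rec = y2 - G(y1)
x1_rec = y1 - F(x2_rec)
assert np.allclose(x1_rec, x1) and np.allclose(x2_rec, x2)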
Code example #3
def DecoderBlock(d_model,
                 d_ff,
                 d_attention_key,
                 d_attention_value,
                 n_heads,
                 attention_type,
                 dropout,
                 ff_activation,
                 ff_dropout,
                 ff_use_sru,
                 ff_chunk_size,
                 ff_sparsity,
                 attention_chunk_size,
                 n_attention_layers=1,
                 n_feedforward_layers=1,
                 center_layernorm=True,
                 use_bfloat16=False,
                 mode='train'):
    """Reversible transformer decoder layer.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    n_attention_layers: how many residual causal attention layers should we
      have before the feed-forward block (default: 1, the standard block)
    n_feedforward_layers: how many FFNN layers should we have (default 1).
    center_layernorm: whether to use centering in LayerNorm (default) or to
      skip it, which is known as RMS normalization.
    use_bfloat16: whether to use bfloat16 for weights (default: False).
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
    # pylint: disable=g-complex-comprehension
    attention_half_residuals = [[
        tl.ReversibleHalfResidual(tl.LayerNorm(center=center_layernorm),
                                  attention_layer=ct.ApplyAttentionLayer(
                                      attention_type, d_model, n_heads,
                                      d_attention_key, d_attention_value, True,
                                      False, dropout, dropout,
                                      attention_chunk_size, mode),
                                  name='ReversibleHalfResidualDecoderAttn'),
        tl.ReversibleSwap()
    ] for _ in range(n_attention_layers)]

    feed_forwards = [[
        tl.ReversibleHalfResidual(ct.FeedForwardWithOptions(
            d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,
            ff_chunk_size, ff_use_sru, ff_sparsity, center_layernorm, mode,
            use_bfloat16),
                                  name='ReversibleHalfResidualDecoderFF'),
        tl.ReversibleSwap()
    ] for _ in range(n_feedforward_layers)]
    # pylint: enable=g-complex-comprehension
    return attention_half_residuals + feed_forwards
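
Because tl.Serial flattens nested lists of layers, the list-of-lists returned above (one [half-residual, swap] pair per attention or feed-forward layer) drops straight into a model. A small sketch of that structure with stand-in tl.Dense residuals instead of the real attention (an assumption for illustration):

from trax import layers as tl

pairs = [[tl.ReversibleHalfResidual(tl.Dense(4)), tl.ReversibleSwap()]
         for _ in range(2)]        # e.g. n_attention_layers=2
stack = tl.Serial(pairs)           # Serial flattens the nested lists
print(len(stack.sublayers))        # 4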
Code example #4
 def test_reversible_swap(self):
     layer = tl.ReversibleSwap()
     xs = [np.array([1, 2]), np.array([10, 20])]
     ys = layer(xs)
     self.assertEqual(tl.to_list(ys), [[10, 20], [1, 2]])
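
ReversibleSwap is its own inverse, which is what makes it usable inside a reversible stack: applying it twice returns the original inputs. A quick sketch of that property (assuming trax is installed):

import numpy as np
from trax import layers as tl

swap = tl.ReversibleSwap()
xs = [np.array([1, 2]), np.array([10, 20])]
ys = swap(xs)                 # [[10, 20], [1, 2]], as in the test above
back = swap(ys)               # swapping again restores the original order
assert tl.to_list(back) == tl.to_list(xs)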
Code example #5
File: trainer_test.py  Project: yinxx/trax
    def test_run_reversible_same_as_default_extended(self):
        """Runs the reversible trainer, check results are the same as default."""
        inputs_batch = np.arange(8).reshape((2, 4))
        targets_batch = 2 * inputs_batch
        labeled_batch = (inputs_batch, targets_batch,
                         np.ones_like(targets_batch))
        # We want to test rng propagation too, so adding some dropout layers.
        first_layer = tl.Serial(tl.Embedding(9, 4), tl.Dropout(0.5), tl.Dup())
        rev_layers1 = [
            tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.2)),
            tl.ReversibleSwap(),
            tl.ReversibleHalfResidual(tl.Dropout(0.5), tl.Dense(4)),
            tl.ReversibleSwap()
        ]
        mid_layer = tl.Serial(tl.Add(), tl.Dense(4), tl.Dup())
        rev_layers2 = [
            tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.3)),
            tl.ReversibleSwap()
        ]
        loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(19), tl.Dropout(0.3),
                               tl.LogSoftmax(), tl.CrossEntropyLoss())
        model = tl.Serial([first_layer] + rev_layers1 + [mid_layer] +
                          rev_layers2 + [loss_layer])
        rng_init = fastmath.random.get_prng(12)
        model.init(labeled_batch, rng=rng_init)
        optimizer_fn = optimizers.Adam  # to test slots

        # Make 3 steps with the original trainer.
        optimizer = optimizer_fn()
        optimizer.tree_init(model.weights)
        trainer = optimizers.Trainer(model, optimizer)
        rng_step1 = fastmath.random.get_prng(7)
        rng_step2 = fastmath.random.get_prng(8)
        rng_step3 = fastmath.random.get_prng(9)
        trainer.one_step(labeled_batch, rng_step1)
        trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02)
        trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03)
        first_layer_weights1 = first_layer.weights
        rev_layer12_weights1 = rev_layers1[2].weights
        mid_layer_weights1 = mid_layer.weights
        rev_layer20_weights1 = rev_layers2[0].weights
        loss_layer_weights1 = loss_layer.weights

        # Now make 3 steps with reversible trainer.
        model.init(labeled_batch, rng=rng_init)
        # TODO(lukaszkaiser): this test seems to fail with memoize_jit, why?
        trainer = optimizers.ReversibleSerialTrainer(
            [(first_layer.sublayers, rev_layers1),
             (mid_layer.sublayers, rev_layers2)],
            loss_layer,
            optimizer_fn,
            memoize_jit=False)
        trainer.one_step(labeled_batch, rng_step1)
        trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02)
        trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03)

        # Check that weights end up the same.
        self._assert_all_equal(loss_layer_weights1, loss_layer.weights)
        self._assert_all_equal(rev_layer20_weights1, rev_layers2[0].weights)
        self._assert_all_equal(mid_layer_weights1, mid_layer.weights)
        self._assert_all_equal(rev_layer12_weights1, rev_layers1[2].weights)
        self._assert_all_equal(first_layer_weights1, first_layer.weights)
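
The weight comparison above only works because both runs see exactly the same random keys: the model is re-initialized with the same rng_init, each step reuses the same per-step key, and fastmath.random.get_prng is itself deterministic in its seed, so dropout masks are identical across the two trainers. A tiny sketch of that determinism:

import numpy as np
from trax import fastmath

key_a = fastmath.random.get_prng(7)
key_b = fastmath.random.get_prng(7)
print(np.array_equal(np.asarray(key_a), np.asarray(key_b)))  # True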
Code example #6
def EncoderBlock(d_model,
                 d_ff,
                 n_heads,
                 attention_type,
                 dropout,
                 ff_activation,
                 ff_dropout,
                 ff_use_sru=0,
                 ff_chunk_size=0,
                 ff_sparsity=0,
                 attention_chunk_size=0,
                 mode='train'):
    """Returns a list of layers that implements a Reformer encoder block.

  The input to the layer is a pair, (activations, mask), where the mask was
  created from the original source tokens to prevent attending to the padding
  part of the input.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    mode: str: 'train' or 'eval'

  Returns:
    A list of layers that maps (activations, mask) to (activations, mask).
  """
    if mode == 'predict':
        # Mode 'predict' means that the decoder should be run one token at a time.
        # The encoder only ever runs over full sequences, which is why it's switched
        # to 'eval' mode instead.
        mode = 'eval'

    attention = ct.ApplyAttentionLayer(
        attention_type=attention_type,
        d_model=d_model,
        n_heads=n_heads,
        d_qk=d_model // n_heads,
        d_v=d_model // n_heads,
        masked=True,
        causal=False,
        attention_dropout=dropout,
        output_dropout=dropout,
        attention_chunk_size=attention_chunk_size,
        mode=mode)
    attention_half_residual = tl.ReversibleHalfResidual(
        tl.LayerNorm(),
        attention_layer=attention,
    )

    feed_forward = ct.FeedForwardWithOptions(d_model, d_ff, dropout, [-2],
                                             ff_activation, ff_dropout,
                                             ff_chunk_size, ff_use_sru,
                                             ff_sparsity, mode)

    return [
        attention_half_residual,
        tl.ReversibleSwap(),
        tl.ReversibleHalfResidual(feed_forward),
        tl.ReversibleSwap(),
    ]
Code example #7
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value,
                 n_heads, n_attention_chunks, attention_type,
                 dropout, share_qk, ff_activation, ff_use_sru, ff_chunk_size,
                 mode):
  """Reversible transformer decoder layer.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    n_attention_chunks: int: number of chunks for attention
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    share_qk: bool, whether to share queries and keys
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  if not hasattr(attention_type, 'forward_unbatched'):
    if share_qk:
      pre_attention = [
          Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
          tl.LayerNorm(),
          tl.Dup(),
          tl.Parallel(
              tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
              tl.ComputeAttentionHeads(
                  n_heads=n_heads, d_head=d_attention_value),
          ),
          tl.Dup(),
      ]
    else:
      pre_attention = [
          Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
          tl.LayerNorm(),
          tl.Dup(), tl.Dup(),
          tl.Parallel(
              tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
              tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
              tl.ComputeAttentionHeads(
                  n_heads=n_heads, d_head=d_attention_value),
          ),
      ]

    attention = attention_type(mode=mode)

    # ReversibleAttentionHalfResidual requires that post_attention be linear in
    # its input (so the backward pass can be computed without knowing the input)
    post_attention = [
        tl.ComputeAttentionOutput(n_heads=n_heads, d_model=d_model),
        Unchunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
    ]

    attention_half_residual = ReversibleAttentionHalfResidual(
        pre_attention, attention, post_attention)
  else:
    attention = attention_type(
        n_heads=n_heads, d_qk=d_attention_key, d_v=d_attention_value,
        share_qk=share_qk, causal=True, output_dropout=dropout, mode=mode)
    attention_half_residual = ReversibleHalfResidualV2(
        tl.LayerNorm(),
        attention_layer=attention,
    )

  if ff_use_sru:
    feed_forward = [tl.SRU(d_model) for _ in range(ff_use_sru)]
  else:
    feed_forward = [ChunkedFeedForward(d_model, d_ff, dropout, ff_activation,
                                       dropout, ff_chunk_size, mode)]

  return [
      attention_half_residual,
      tl.ReversibleSwap(),
      ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]
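
The comment above about post_attention being linear in its input is what lets the backward pass run without the forward input: the Jacobian of a linear map P(a) = a @ W is just W, so propagating a gradient through it needs no saved activations. A plain-NumPy illustration of that point (not trax code):

import numpy as np

rng = np.random.default_rng(1)
W = rng.normal(size=(4, 3))           # a linear post_attention: P(a) = a @ W
grad_out = rng.normal(size=(2, 3))    # gradient flowing in from above
grad_in = grad_out @ W.T              # does not depend on the forward input a
print(grad_in.shape)                  # (2, 4)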
Code example #8
 def test_reversible_swap(self, backend):
     with fastmath.use_backend(backend):
         layer = tl.ReversibleSwap()
         xs = [np.array([1, 2]), np.array([10, 20])]
         ys = layer(xs)
         self.assertEqual(tl.to_list(ys), [[10, 20], [1, 2]])
Code example #9
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 n_attention_chunks, attention_type, dropout, share_qk,
                 ff_activation, mode):
    """Reversible transformer decoder layer.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    n_attention_chunks: int: number of chunks for attention
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    share_qk: bool, whether to share queries and keys
    ff_activation: the non-linearity in feed-forward layer
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
    if share_qk:
        pre_attention = [
            Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
            tl.LayerNorm(),
            tl.Dup(),
            tl.Parallel(
                tl.ComputeAttentionHeads(n_heads=n_heads,
                                         d_head=d_attention_key),
                tl.ComputeAttentionHeads(n_heads=n_heads,
                                         d_head=d_attention_value),
            ),
            tl.Dup(),
        ]
    else:
        pre_attention = [
            Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
            tl.LayerNorm(),
            tl.Dup(),
            tl.Dup(),
            tl.Parallel(
                tl.ComputeAttentionHeads(n_heads=n_heads,
                                         d_head=d_attention_key),
                tl.ComputeAttentionHeads(n_heads=n_heads,
                                         d_head=d_attention_key),
                tl.ComputeAttentionHeads(n_heads=n_heads,
                                         d_head=d_attention_value),
            ),
        ]

    attention = attention_type(mode=mode)

    # ReversibleAttentionHalfResidual requires that post_attention be linear in
    # its input (so the backward pass can be computed without knowing the input)
    post_attention = [
        tl.ComputeAttentionOutput(n_heads=n_heads, d_model=d_model),
        Unchunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
    ]

    feed_forward = [
        FeedForward(d_model, d_ff, dropout, ff_activation, mode=mode),
    ]
    return [
        ReversibleAttentionHalfResidual(pre_attention, attention,
                                        post_attention),
        tl.ReversibleSwap(),
        ReversibleHalfResidual(feed_forward),
        tl.ReversibleSwap(),
    ]
Code example #10
File: reformer.py  Project: ixxxxu/trax
 def _feed_forward():
     return [
         tl.ReversibleHalfResidual(_FF(),
                                   name='ReversibleHalfResidualDecoderFF'),
         tl.ReversibleSwap()
     ]
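
A sketch of how a helper like this is typically used: it is called once per decoder layer, and the resulting (half-residual, swap) pairs are collected into one flat list of reversible layers. _FF() is defined in the surrounding function in reformer.py, so a stand-in feed-forward (an assumption for illustration) is used here:

from trax import layers as tl

def _ff_standin():
    # Stand-in for _FF(), which in the original closes over the model's
    # feed-forward configuration (d_model, d_ff, dropout, activation, ...).
    return tl.Serial(tl.LayerNorm(), tl.Dense(64), tl.Relu(), tl.Dense(64))

decoder_ff = []
for _ in range(3):   # one (half-residual, swap) pair per decoder layer
    decoder_ff.extend([
        tl.ReversibleHalfResidual(_ff_standin(),
                                  name='ReversibleHalfResidualDecoderFF'),
        tl.ReversibleSwap(),
    ])
print(len(decoder_ff))  # 6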