Exemple #1
0
    def test_positional_encoder(self, pos_axial_shape):
        """Checks that PositionalEncoder returns an output shaped like its input.

        Args:
          pos_axial_shape: axial position shape supplied by the test
            parameterization; forwarded to PositionalEncoder unchanged.
        """
        # Original note: dim should divide FixedBasePositionalEncoding.n_digits.
        batch, length, dim = 2, 32, 8
        expected_shape = (batch, length, dim)
        vocab_size = 32
        inputs = np.random.randint(0, vocab_size - 1, expected_shape)
        # Per-axis embedding depths; these must add up to `dim` (4 + 4 == 8).
        axial_depths = (4, 4)

        encoder = ct.PositionalEncoder(
            'train',
            dropout=0.1,
            max_len=length,
            pos_axial_shape=pos_axial_shape,
            pos_d_axial_embs=axial_depths)
        # Initialize from the input signature; the two returned values are
        # unused, but unpacking keeps the 2-tuple expectation of the original.
        _, _ = encoder.init(shapes.signature(inputs))
        outputs = encoder(inputs)
        self.assertEqual(outputs.shape, expected_shape)
Exemple #2
0
def ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               attention_type=tl.SelfAttention,
               axial_pos_shape=(),
               d_axial_pos_embs=None,
               ff_activation=tl.FastGelu,
               ff_use_sru=0,
               ff_chunk_size=0,
               ff_sparsity=0,
               loss_sparsity_type='mult',
               loss_sparsity=0,
               loss_d_lowrank=0,
               loss_sparsity_prob=None,
               attention_chunk_size=0,
               mode='train'):
    """Builds a decoder-only reversible Transformer (Reformer) language model.

    The returned layer does the following, in order:
      1. Shifts its input right.
      2. Embeds it, then applies dropout and positional encoding.
      3. Duplicates the activations so two copies run through a reversible
         stack of decoder blocks.
      4. Concatenates the two halves, then applies LayerNorm and dropout.
      5. Projects to `vocab_size` outputs with a (possibly sparse) dense layer.

    Args:
      vocab_size: int: vocab size
      d_model: int: depth of *each half* of the two-part features
      d_ff: int: depth of feed-forward layer
      d_attention_key: int: depth of key vector for each attention head
      d_attention_value: int: depth of value vector for each attention head
      n_layers: int: number of decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      attention_type: class, or tuple/list of classes, used for attention. A
        sequence is cycled over the layers, and its length must divide
        n_layers.
      axial_pos_shape: tuple of ints: input shape to use for the axial position
        encoding. If unset, axial position encoding is disabled.
      d_axial_pos_embs: tuple of ints: depth of position embedding for each
        axis. Tuple length must match axial_pos_shape, and values must sum to
        d_model.
      ff_activation: the non-linearity in feed-forward layer
      ff_use_sru: int; if > 0, use this many SRU layers instead of feed-forward
      ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
      ff_sparsity: int; if > 0 use sparse feed-forward block with this sparsity
      loss_sparsity_type: str, type of sparsity to use in the loss layer. See
        SparseDenseWithOptions for options. None if no sparsity should be used.
      loss_sparsity: int, the sparsity for the loss layer (if used)
      loss_d_lowrank: int, the dimensions for the intermediate layer (if used)
      loss_sparsity_prob: float, the probability for the sparse version of the
        loss to be used. If None, only the sparse version is used.
      attention_chunk_size: int, if > 0 run attention chunked at this size
      mode: str: 'train', 'eval', or 'predict'

    Returns:
      A tl.Serial layer implementing the model.
    """
    encode_positions = ct.PositionalEncoder(mode, dropout, max_len,
                                            axial_pos_shape,
                                            d_axial_pos_embs)

    embed = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),  # pylint: disable=no-value-for-parameter
        encode_positions,
    ]

    # Normalize to a sequence of attention classes; a user-supplied sequence
    # must tile the layer stack exactly.
    if not isinstance(attention_type, (tuple, list)):
        attention_type = [attention_type]
    else:
        assert n_layers % len(attention_type) == 0
    n_types = len(attention_type)

    blocks = [
        DecoderBlock(d_model,
                     d_ff,
                     d_attention_key,
                     d_attention_value,
                     n_heads,
                     attention_type=attention_type[i % n_types],
                     dropout=dropout,
                     ff_activation=ff_activation,
                     ff_dropout=dropout,
                     ff_use_sru=ff_use_sru,
                     ff_chunk_size=ff_chunk_size,
                     ff_sparsity=ff_sparsity,
                     attention_chunk_size=attention_chunk_size,
                     mode=mode)
        for i in range(n_layers)
    ]

    output_projection = tl.SparseDenseWithOptions(
        vocab_size,
        d_input=d_model,
        sparsity_type=loss_sparsity_type,
        sparsity=loss_sparsity,
        d_lowrank=loss_d_lowrank,
        prob_sparse=loss_sparsity_prob,
        mode=mode)

    return tl.Serial(
        tl.ShiftRight(mode=mode),
        embed,
        tl.Dup(),  # Two copies feed the two halves of the reversible stack.
        tl.ReversibleSerial(blocks),
        tl.Concatenate(),
        # TODO(kitaev): Test whether dropout should go before or after the
        # LayerNorm, and whether dropout broadcasting is needed here.
        tl.LayerNorm(),
        tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),  # pylint: disable=no-value-for-parameter
        output_projection,
    )
Exemple #3
0
def ReformerShortenLM(vocab_size,
                      shorten_factor=1,
                      d_embedding=256,
                      d_model=512,
                      d_ff=2048,
                      d_attention_key=64,
                      d_attention_value=64,
                      n_layers=6,
                      n_heads=8,
                      dropout=0.1,
                      max_len=2048,
                      attention_type=tl.SelfAttention,
                      axial_pos_shape=(),
                      d_axial_pos_embs=None,
                      ff_activation=tl.FastGelu,
                      ff_use_sru=0,
                      ff_chunk_size=0,
                      ff_sparsity=0,
                      attention_chunk_size=0,
                      mode='train'):
    """Reversible transformer language model with shortening.

    When shorten_factor is F and processing an input of shape [batch, length],
    we embed the (shifted-right) input and then group each F elements (on
    length) into a single vector -- so that in the end we process a tensor of
    shape ::

        [batch, length // F, d_model]

    almost until the end -- at the end it's un-shortened and an SRU is applied.
    This reduces the length processed inside the main model body, effectively
    making the model faster but possibly slightly less accurate.

    The input length is expected to be a multiple of shorten_factor. The
    'Shorten' step reshapes [batch, length, d_embedding] to
    [batch, length // F, -1]. The prolonged result is later concatenated with
    the un-shortened embeddings, so the two lengths must agree.

    Args:
      vocab_size: int: vocab size
      shorten_factor: int >= 1: by how much to shorten, see above. A value of
        1 means no shortening, and no extra anti-leak shift is inserted.
      d_embedding: the depth of the embedding layer and final logits
      d_model: int: depth of *each half* of the two-part features
      d_ff: int: depth of feed-forward layer
      d_attention_key: int: depth of key vector for each attention head
      d_attention_value: int: depth of value vector for each attention head
      n_layers: int: number of decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      attention_type: class, or tuple/list of classes, used for attention. A
        sequence is cycled over the layers, and its length must divide
        n_layers.
      axial_pos_shape: tuple of ints: input shape to use for the axial position
        encoding. If unset, axial position encoding is disabled.
      d_axial_pos_embs: tuple of ints: depth of position embedding for each
        axis. Tuple length must match axial_pos_shape, and values must sum to
        d_embedding.
      ff_activation: the non-linearity in feed-forward layer
      ff_use_sru: int; if > 0, we use this many SRU layers instead of
        feed-forward
      ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
      ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
      attention_chunk_size: int, if > 0 run attention chunked at this size
      mode: str: 'train' or 'eval' ('predict' is not supported yet)

    Returns:
      the layer.

    Raises:
      ValueError: if shorten_factor is less than 1.
    """
    assert mode != 'predict'  # TODO(lukaszkaiser,kitaev): fast inference
    if shorten_factor < 1:
        # A factor of 0 would divide by zero in 'Shorten', and a negative one
        # would request a negative shift; fail early with a clear message.
        raise ValueError(
            f'shorten_factor must be >= 1, got {shorten_factor}')

    positional_encoding = ct.PositionalEncoder(mode, dropout, max_len,
                                               axial_pos_shape,
                                               d_axial_pos_embs)

    positional_embedder = [
        tl.Embedding(vocab_size, d_embedding),
        tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),  # pylint: disable=no-value-for-parameter
        positional_encoding,
    ]

    decoder_blocks = []

    if isinstance(attention_type, (tuple, list)):
        assert n_layers % len(attention_type) == 0
    else:
        attention_type = [attention_type]
    for layer_idx in range(n_layers):
        layer_attention_type = attention_type[layer_idx % len(attention_type)]
        decoder_block = DecoderBlock(d_model,
                                     d_ff,
                                     d_attention_key,
                                     d_attention_value,
                                     n_heads,
                                     attention_type=layer_attention_type,
                                     dropout=dropout,
                                     ff_activation=ff_activation,
                                     ff_dropout=dropout,
                                     ff_use_sru=ff_use_sru,
                                     ff_chunk_size=ff_chunk_size,
                                     ff_sparsity=ff_sparsity,
                                     attention_chunk_size=attention_chunk_size,
                                     mode=mode)
        decoder_blocks.append(decoder_block)

    # Before shortening, we need to pad by shorten factor so as not to leak
    # information into the future. To understand why, imagine shorten factor
    # of 2 and sequence of length 4, so ABCD. If we shift just by 1, then we
    # would have 0ABC, which gets grouped to [0A][BC] on input, which is
    # predicting ABCD as targets. The problem is that [0A] has access to A
    # and [BC] has access to C -- it will learn to copy it, peek into
    # the future. Shifting twice to [00][AB] solves the problem as the first
    # "big" symbol becomes all-0 and the rest is shifted enough.
    #
    # With shorten_factor == 1 no extra shift is needed, so the layer is
    # omitted rather than built with n_positions=0.
    # NOTE(review): trax's ShiftRight appears to return
    # `padded[:, :-n_positions]`, which is an empty slice when
    # n_positions == 0 -- confirm against the installed trax version.
    anti_leak_shift = []
    if shorten_factor > 1:
        anti_leak_shift.append(tl.ShiftRight(n_positions=shorten_factor - 1))

    # pylint: disable=g-long-lambda
    return tl.Serial(
        tl.ShiftRight(),
        positional_embedder,
        tl.Dup(),  # Stack has (x, x), the first will be shortened
        anti_leak_shift,
        tl.Fn(
            'Shorten',
            lambda x: jnp.reshape(  # Shorten -- move to depth.
                x, (x.shape[0], x.shape[1] // shorten_factor, -1)),
            n_out=1),
        tl.Dense(d_model),
        tl.Dup(),  # Stack has (short_x, short_x, x)
        tl.ReversibleSerial(decoder_blocks),
        tl.Select([0], n_in=2),
        tl.LayerNorm(),
        tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),  # pylint: disable=no-value-for-parameter
        tl.Dense(shorten_factor * d_embedding),
        tl.Fn(
            'ProlongBack',
            lambda x: jnp.reshape(  # Prolong back.
                x, (x.shape[0], x.shape[1] * shorten_factor, -1)),
            n_out=1),
        tl.Concatenate(),  # Concatenate with just the embeddings.
        tl.CausalConv(d_embedding),
        tl.Relu(),
        tl.SRU(d_embedding),  # One RNN layer for conditional dependence.
        tl.Dense(vocab_size),
    )