Example #1
File: reformer.py  Project: xibelly/trax
def PositionalEncoding(mode,
                       dropout=None,
                       max_len=None,
                       axial_pos_shape=None,
                       d_axial_pos_embs=None):
    """Returns the positional encoding layer depending on the arguments."""
    if not axial_pos_shape:
        positional_encoding = tl.PositionalEncoding(max_len=max_len,
                                                    dropout=dropout,
                                                    mode=mode)
    elif axial_pos_shape == 'fixed-base':  # TODO(lukaszkaiser): remove this HACK
        positional_encoding = tl.FixedBasePositionalEncoding(mode=mode)
    elif axial_pos_shape == 'infinite':  # TODO(lukaszkaiser): remove this HACK
        positional_encoding = tl.InfinitePositionalEncoding(affine=False)
    elif axial_pos_shape == 'infinite-affine':
        # TODO(lukaszkaiser): remove this HACK
        positional_encoding = tl.InfinitePositionalEncoding()
    elif axial_pos_shape == 'time-bin':  # TODO(lukaszkaiser): remove this HACK
        positional_encoding = tl.TimeBinPositionalEncoding()
    else:
        assert d_axial_pos_embs is not None
        positional_encoding = tl.AxialPositionalEncoding(
            shape=axial_pos_shape,
            d_embs=d_axial_pos_embs,
            dropout_broadcast_dims=tuple(range(1,
                                               len(axial_pos_shape) + 1)),
            dropout=dropout,
            mode=mode)

    return positional_encoding
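A minimal usage sketch for the dispatcher above (not part of the original file; it assumes `trax.layers` is imported as `tl` and the function is in scope, and the shape/depth values are illustrative only):

# Axial position encoding for sequences of up to 64 * 32 = 2048 positions.
# The per-axis depths (64 + 192 = 256) are chosen to sum to the model
# dimension, as required by the docstrings of the later variants.
encoding = PositionalEncoding(
    mode='train',
    dropout=0.1,
    max_len=2048,
    axial_pos_shape=(64, 32),
    d_axial_pos_embs=(64, 192))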
Example #2
def PositionalEncoder(mode,
                      dropout=None,
                      max_len=None,
                      pos_type=None,
                      pos_axial_shape=None,
                      pos_d_axial_embs=None,
                      pos_start_from_zero_prob=1.0,
                      pos_max_offset_to_add=0,
                      use_bfloat16=False):
  """Returns the positional encoding layer depending on the arguments.

  Args:
    mode: If `'predict'`, use fast inference. If `'train'`, each encoder/decoder
      block will include dropout; else, it will pass all values through
      unaltered.
    dropout: Stochastic rate (probability) for dropping an activation
      value when applying dropout after the embedding block.
    max_len: Maximum symbol length for positional encoding.
    pos_type: string, the type of positional embeddings to use.
    pos_axial_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    pos_d_axial_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match pos_axial_shape, and values must sum to d_model.
    pos_start_from_zero_prob: how often to start from position 0 during
      training (if 1.0, we always start from position 0; if less, we
      randomize the starting position).
    pos_max_offset_to_add: maximum offset to add to positions during training
      when randomizing; this offset plus input length must still be less than
      max_len for all training examples.
    use_bfloat16: If `True`, use bfloat16 weights instead of the default
      float32; this can save memory but may (rarely) lead to numerical issues.

  Returns:
    A layer that will do the positional encoding.
  """
  if not pos_type:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, use_bfloat16=use_bfloat16,
        start_from_zero_prob=pos_start_from_zero_prob,
        max_offset_to_add=pos_max_offset_to_add, mode=mode)
  elif pos_type == 'sin-cos':
    positional_encoding = tl.SinCosPositionalEncoding(mode=mode)
  elif pos_type == 'fixed-base':
    positional_encoding = tl.FixedBasePositionalEncoding(mode=mode)
  elif pos_type == 'infinite':
    positional_encoding = tl.InfinitePositionalEncoding(affine=False)
  elif pos_type == 'infinite-affine':
    positional_encoding = tl.InfinitePositionalEncoding()
  elif pos_type == 'time-bin':
    positional_encoding = tl.TimeBinPositionalEncoding()
  elif pos_type == 'no':
    positional_encoding = tl.Serial()  # no positional encoding at all
  else:  # TODO(lukaszkaiser): name this type and check for the correct name
    assert pos_d_axial_embs is not None
    positional_encoding = tl.AxialPositionalEncoding(
        shape=pos_axial_shape, d_embs=pos_d_axial_embs,
        dropout_broadcast_dims=tuple(range(1, len(pos_axial_shape) + 1)),
        dropout=dropout, mode=mode)

  return positional_encoding
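A quick construction sketch for this variant (illustrative values; assumes the same `tl` import and trax version as the example above):

# Default trainable encoding with randomized starting positions in training:
# 90% of examples start at position 0, the rest get a random offset of up to
# 512 (offset + input length must stay below max_len).
enc_default = PositionalEncoder(
    mode='train', dropout=0.1, max_len=2048,
    pos_start_from_zero_prob=0.9,
    pos_max_offset_to_add=512)

# pos_type='no' disables positional encoding entirely: the function returns an
# empty tl.Serial(), i.e. an identity layer.
enc_none = PositionalEncoder(mode='train', pos_type='no')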
Example #3
def PositionalEncoder(mode,
                      dropout=None,
                      max_len=None,
                      axial_pos_shape=None,
                      d_axial_pos_embs=None,
                      use_bfloat16=False):
    """Returns the positional encoding layer depending on the arguments.

  Args:
    mode: If `'predict'`, use fast inference. If `'train'`, each encoder/decoder
      block will include dropout; else, it will pass all values through
      unaltered.
    dropout: Stochastic rate (probability) for dropping an activation
      value when applying dropout after the embedding block.
    max_len: Maximum symbol length for positional encoding.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    use_bfloat16: If `True`, use bfloat16 weights instead of the default
      float32; this can save memory but may (rarely) lead to numerical issues.

  Returns:
    A layer that will do the positional encoding.
  """

    if not axial_pos_shape:
        positional_encoding = tl.PositionalEncoding(max_len=max_len,
                                                    dropout=dropout,
                                                    mode=mode,
                                                    use_bfloat16=use_bfloat16)
    elif axial_pos_shape == 'sin-cos':  # TODO(lukaszkaiser): remove this HACK
        positional_encoding = tl.SinCosPositionalEncoding(mode=mode)
    elif axial_pos_shape == 'fixed-base':  # TODO(lukaszkaiser): remove this HACK
        positional_encoding = tl.FixedBasePositionalEncoding(mode=mode)
    elif axial_pos_shape == 'infinite':  # TODO(lukaszkaiser): remove this HACK
        positional_encoding = tl.InfinitePositionalEncoding(affine=False)
    elif axial_pos_shape == 'infinite-affine':
        # TODO(lukaszkaiser): remove this HACK
        positional_encoding = tl.InfinitePositionalEncoding()
    elif axial_pos_shape == 'time-bin':  # TODO(lukaszkaiser): remove this HACK
        positional_encoding = tl.TimeBinPositionalEncoding()
    else:
        assert d_axial_pos_embs is not None
        positional_encoding = tl.AxialPositionalEncoding(
            shape=axial_pos_shape,
            d_embs=d_axial_pos_embs,
            dropout_broadcast_dims=tuple(range(1,
                                               len(axial_pos_shape) + 1)),
            dropout=dropout,
            mode=mode)

    return positional_encoding
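The axial branch relies on the constraints stated in the docstring; a small standalone check of those constraints (hypothetical helper, not part of trax):

def _check_axial_args(axial_pos_shape, d_axial_pos_embs, d_model):
    # Tuple lengths must match and the per-axis depths must sum to d_model.
    assert len(axial_pos_shape) == len(d_axial_pos_embs)
    assert sum(d_axial_pos_embs) == d_model

_check_axial_args(axial_pos_shape=(128, 16),
                  d_axial_pos_embs=(192, 320),
                  d_model=512)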
Example #4
File: reformer.py  Project: ajiangcn/trax
  def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
    if not axial_pos_shape:
      positional_encoding = tl.PositionalEncoding(
          max_len=max_len, dropout=dropout, mode=mode)
    else:
      assert d_axial_pos_embs is not None
      positional_encoding = tl.AxialPositionalEncoding(
          shape=axial_pos_shape, d_embs=d_axial_pos_embs,
          dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
          dropout=dropout, mode=mode)

    return [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
        positional_encoding,
    ]
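The returned list is meant to be spliced into a larger tl.Serial model. A rough standalone sketch (the captured names axial_pos_shape, d_axial_pos_embs, max_len, dropout and d_model normally come from the enclosing model builder; here they are bound at module scope with illustrative values, assuming the function is lifted to top level):

d_model, dropout, max_len = 512, 0.1, 2048
axial_pos_shape, d_axial_pos_embs = (), None   # plain positional encoding

# The list of layers returned by the closure nests directly inside tl.Serial.
embedder = PositionalEncoder(vocab_size=32000, mode='train')
decoder_input = tl.Serial(tl.ShiftRight(), embedder)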
Example #5
def ReformerShortenLM(vocab_size,
                      shorten_factor=1,
                      d_embedding=256,
                      d_model=512,
                      d_ff=2048,
                      d_attention_key=64,
                      d_attention_value=64,
                      n_layers=6,
                      n_heads=8,
                      dropout=0.1,
                      max_len=2048,
                      n_attention_chunks=1,
                      attention_type=tl.DotProductCausalAttention,
                      share_qk=False,
                      axial_pos_shape=(),
                      d_axial_pos_embs=None,
                      ff_activation=tl.FastGelu,
                      ff_use_sru=0,
                      ff_chunk_size=0,
                      mode='train'):
  """Reversible transformer language model with shortening.

  When shorten_factor is F and the input has shape [batch, length], we embed
  the (shifted-right) input and then group every F consecutive elements along
  the length dimension into a single vector, so the main body of the model
  processes a tensor of shape
    [batch, length // F, d_model]
  At the end the sequence is un-shortened and an SRU is applied. This reduces
  the length processed inside the main model body, making the model faster but
  possibly slightly less accurate.

  Args:
    vocab_size: int: vocab size
    shorten_factor: by how much to shorten, see above
    d_embedding: the depth of the embedding layer and final logits
    d_model: int:  depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_attention_chunks: int: number of chunks for attention
    attention_type: class: attention class to use, such as DotProductAttention.
    share_qk: bool, whether to share queries and keys.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, values must sum to d_embedding.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  assert mode != 'predict'  # TODO(lukaszkaiser,kitaev): fast inference

  if not axial_pos_shape:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
  else:
    assert d_axial_pos_embs is not None
    positional_encoding = tl.AxialPositionalEncoding(
        shape=axial_pos_shape, d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
        dropout=dropout, mode=mode)

  positional_embedder = [
      tl.Embedding(d_embedding, vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      positional_encoding,
  ]

  decoder_blocks = []

  if isinstance(attention_type, (tuple, list)):
    assert n_layers % len(attention_type) == 0
  else:
    attention_type = [attention_type]
  for layer_idx in range(n_layers):
    layer_attention_type = attention_type[layer_idx % len(attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        n_attention_chunks,
        attention_type=layer_attention_type,
        dropout=dropout,
        share_qk=(share_qk or issubclass(layer_attention_type,
                                         tl.LSHCausalAttention)),
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        mode=mode)
    decoder_blocks.append(decoder_block)

  # pylint: disable=g-long-lambda
  return tl.Serial(
      tl.ShiftRight(),
      positional_embedder,
      tl.Dup(),              # Stack has (x, x), the first will be shortened
      # Before shortening, we need to pad by shorten factor so as not to leak
      # information into the future. To understand why, imagine shorten factor
      # of 2 and sequence of length 4, so ABCD. If we shift just by 1, then we
      # would have 0ABC, which gets grouped to [0A][BC] on input, which is
      # predicting ABCD as targets. The problem is that [0A] has access to A
      # and [BC] has access to C -- it will learn to copy it, peek into
      # the future. Shifting twice to [00][AB] solves the problem as the first
      # "big" symbol becomes all-0 and the rest is shifted enough.
      tl.ShiftRight(n_shifts=shorten_factor - 1),
      tl.Fn(lambda x: np.reshape(  # Shorten -- move to depth.
          x, (x.shape[0], x.shape[1] // shorten_factor, -1)), n_out=1),
      tl.Dense(d_model),
      tl.Dup(),  # Stack has (short_x, short_x, x)
      tl.ReversibleSerial(decoder_blocks),
      tl.Select([0], n_in=2),
      tl.LayerNorm(),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      tl.Dense(shorten_factor * d_embedding),
      tl.Fn(lambda x: np.reshape(  # Prolong back.
          x, (x.shape[0], x.shape[1] * shorten_factor, -1)), n_out=1),
      tl.Concatenate(),  # Concatenate with just the embeddings.
      tl.CausalConv(d_embedding),
      tl.Relu(),
      tl.SRU(d_embedding),  # One RNN layer for conditional dependence.
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  )
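A pure-NumPy illustration of the shortening trick described in the comments above (standalone sketch; plain NumPy is used here, whereas `np` inside the model is trax's backend numpy, and the real inputs are embedded vectors rather than token ids):

import numpy as onp

shorten_factor = 2
tokens = onp.array([[1, 2, 3, 4]])   # "ABCD": shape [batch=1, length=4]
# ShiftRight() plus ShiftRight(n_shifts=shorten_factor - 1) shift by the full
# shorten factor in total, turning ABCD into 00AB.
shifted = onp.pad(tokens, ((0, 0), (shorten_factor, 0)))[:, :-shorten_factor]
# shifted == [[0, 0, 1, 2]] -- the first grouped symbol is all zeros, so no
# future token leaks into its own prediction.
short = shifted.reshape(shifted.shape[0],
                        shifted.shape[1] // shorten_factor, -1)
# short.shape == (1, 2, 2): length halved, grouped values moved to depth.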
Example #6
def ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               n_chunks=0,
               n_attention_chunks=1,
               attention_type=tl.DotProductCausalAttention,
               share_qk=False,
               axial_pos_shape=(),
               d_axial_pos_embs=None,
               ff_activation=tl.FastGelu,
               ff_use_sru=0,
               ff_chunk_size=0,
               mode='train'):
  """Reversible transformer language model (only uses a decoder, no encoder).

  Args:
    vocab_size: int: vocab size
    d_model: int:  depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_chunks: int: number of chunks (must match input pipeline)
    n_attention_chunks: int: number of chunks for attention
    attention_type: class: attention class to use, such as DotProductAttention.
    share_qk: bool, whether to share queries and keys.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    mode: str: 'train', 'eval', or 'predict'

  Returns:
    the layer.
  """
  if n_chunks == 0:
    n_chunks = 1
    concatenate_input_chunks = []
  else:
    concatenate_input_chunks = tl.Concatenate(n_items=n_chunks)

  d_emb = d_model
  if not axial_pos_shape:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
  elif axial_pos_shape == 'fixed-base':  # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.FixedBasePositionalEncoding(mode=mode)
    d_emb //= 2
  elif axial_pos_shape == 'infinite':  # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.InfinitePositionalEncoding(affine=False)
  elif axial_pos_shape == 'infinite-affine':
    # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.InfinitePositionalEncoding()
  elif axial_pos_shape == 'time-bin':  # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.TimeBinPositionalEncoding()
  else:
    assert d_axial_pos_embs is not None
    positional_encoding = tl.AxialPositionalEncoding(
        shape=axial_pos_shape, d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
        dropout=dropout, mode=mode)

  positional_embedder = [
      tl.Embedding(d_emb, vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      positional_encoding,
  ]

  decoder_blocks = []

  if isinstance(attention_type, (tuple, list)):
    assert n_layers % len(attention_type) == 0
  else:
    attention_type = [attention_type]
  for layer_idx in range(n_layers):
    layer_attention_type = attention_type[layer_idx % len(attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        n_attention_chunks,
        attention_type=layer_attention_type,
        dropout=dropout,
        share_qk=(share_qk or issubclass(layer_attention_type,
                                         tl.LSHCausalAttention)),
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        mode=mode)
    decoder_blocks.append(decoder_block)

  return tl.Serial(
      concatenate_input_chunks,
      tl.ShiftRight(mode=mode),
      positional_embedder,
      tl.Dup(),
      tl.ReversibleSerial(decoder_blocks + [
          SplitForOutput(n_sections=n_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
      ]),
      Map([
          # TODO(kitaev): Test whether dropout should go before or after the
          # LayerNorm, and whether dropout broadcasting is needed here.
          tl.LayerNorm(),
          BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
          tl.Dense(vocab_size),
          tl.LogSoftmax(),
      ], n_sections=n_chunks),
  )
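A minimal construction sketch for this builder (illustrative hyperparameters; assumes the trax version these examples come from):

# The product of the axial shape should cover the maximum sequence length, and
# the per-axis embedding depths should sum to d_model.
model = ReformerLM(
    vocab_size=32000,
    d_model=512,
    max_len=2048,
    axial_pos_shape=(64, 32),     # 64 * 32 = 2048
    d_axial_pos_embs=(256, 256),  # 256 + 256 = d_model
    mode='train')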