示例#1
0
def TransformerRevnetLM(vocab_size,
                        d_feature=512,
                        d_feedforward=2048,
                        d_attention_key=64,
                        d_attention_value=64,
                        n_layers=6,
                        n_heads=8,
                        dropout=0.1,
                        max_len=2048,
                        n_chunks=32,
                        n_attention_chunks=8,
                        attention_loop_stride=0,
                        mode='train'):
    """Builds a decoder-only reversible Transformer language model.

    The chunked input is concatenated, shifted right and embedded, then
    duplicated into two streams that run through `n_layers` reversible
    decoder blocks. Both streams are layer-normalized and concatenated; the
    result is split into `n_chunks` sections along axis -2 and a dense
    projection plus log-softmax is mapped over those sections.

    Args:
      vocab_size: int: vocab size
      d_feature: int: depth of *each half* of the two-part features
      d_feedforward: int: depth of feed-forward layer
      d_attention_key: int: depth of key vector for each attention head
      d_attention_value: int: depth of value vector for each attention head
      n_layers: int: number of decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      n_chunks: int: number of chunks (must match input pipeline)
      n_attention_chunks: int: number of chunks for attention
      attention_loop_stride: int: number of query elements to compute
        attention for in parallel. Set to 0 to disable memory-efficient
        attention.
      mode: str: 'train' or 'eval'

    Returns:
      The assembled `tl.Model`.
    """
    # Token embedding followed by positional encoding.
    # TODO(kitaev): add dropout
    embed_tokens = [
        tl.Embedding(d_feature, vocab_size),
        tl.PositionalEncoding(max_len=max_len),
    ]

    # Input side: merge the chunks, shift, embed, then fork into two streams.
    stages = [
        tl.Concatenate(n_items=n_chunks),
        tl.ShiftRight(),
        embed_tokens,
        tl.Dup(),
    ]

    # pylint: disable=g-complex-comprehension
    decoder_blocks = [
        DecoderBlock(d_feature, d_feedforward, d_attention_key,
                     d_attention_value, n_heads, n_attention_chunks,
                     attention_loop_stride, dropout, mode)
        for _ in range(n_layers)
    ]
    stages.append(ReversibleSerial(decoder_blocks))

    # Output side: normalize each stream, merge them, and re-chunk.
    stages.append(tl.Parallel(tl.LayerNorm(), tl.LayerNorm()))
    stages.append(tl.Concatenate())
    stages.append(Split(sections=n_chunks, axis=-2))  # pylint: disable=no-value-for-parameter
    output_head = [tl.Dense(vocab_size), tl.LogSoftmax()]
    stages.append(Map(output_head, sections=n_chunks))

    return tl.Model(*stages)
示例#2
0
def TransformerRevnetLM(vocab_size,
                        d_model=512,
                        d_ff=2048,
                        d_attention_key=64,
                        d_attention_value=64,
                        n_layers=6,
                        n_heads=8,
                        dropout=0.1,
                        max_len=2048,
                        n_chunks=32,
                        n_attention_chunks=8,
                        attention_type=DotProductAttention,
                        mode='train'):
    """Builds a decoder-only reversible Transformer language model.

    The chunked input is concatenated, shifted right and embedded (with
    dropout and positional encoding), then duplicated into two streams that
    run through `n_layers` reversible decoder blocks. Both streams are
    layer-normalized and concatenated; the result is split into `n_chunks`
    sections along axis -2 and a dense projection plus log-softmax is mapped
    over those sections.

    Args:
      vocab_size: int: vocab size
      d_model: int: depth of *each half* of the two-part features
      d_ff: int: depth of feed-forward layer
      d_attention_key: int: depth of key vector for each attention head
      d_attention_value: int: depth of value vector for each attention head
      n_layers: int: number of decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      n_chunks: int: number of chunks (must match input pipeline)
      n_attention_chunks: int: number of chunks for attention
      attention_type: class: attention class to use, such as
        DotProductAttention.
      mode: str: 'train' or 'eval'

    Returns:
      The assembled `tl.Model`.
    """
    # Embedding -> dropout -> positional encoding.
    embed_tokens = [
        tl.Embedding(d_model, vocab_size),
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
        tl.PositionalEncoding(max_len=max_len),
    ]

    # Input side: merge the chunks, shift, embed, then fork into two streams.
    stages = [
        tl.Concatenate(n_items=n_chunks),
        tl.ShiftRight(),
        embed_tokens,
        tl.Dup(),
    ]

    # pylint: disable=g-complex-comprehension
    decoder_blocks = [
        DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value,
                     n_heads, n_attention_chunks, attention_type, dropout,
                     mode)
        for _ in range(n_layers)
    ]
    stages.append(tl.ReversibleSerial(decoder_blocks))

    # Output side: normalize each stream, merge them, and re-chunk.
    stages.append(tl.Parallel(tl.LayerNorm(), tl.LayerNorm()))
    stages.append(tl.Concatenate())
    stages.append(Split(n_sections=n_chunks, axis=-2))  # pylint: disable=no-value-for-parameter
    output_head = [tl.Dense(vocab_size), tl.LogSoftmax()]
    stages.append(Map(output_head, n_sections=n_chunks))

    return tl.Model(*stages)
示例#3
0
def TransformerRevnetLM(vocab_size,
                        d_feature=512,
                        d_feedforward=2048,
                        n_layers=6,
                        n_heads=8,
                        dropout=0.1,
                        max_len=2048,
                        n_chunks=32,
                        n_attention_chunks=8,
                        mode='train'):
  """Builds a decoder-only reversible Transformer language model.

  The input is concatenated, shifted right and embedded, then duplicated
  into two streams that run through `n_layers` reversible decoder blocks.
  Both streams are layer-normalized and concatenated; the result is split
  into `n_chunks` sections along axis -2 and a dense projection plus
  log-softmax is mapped over those sections.

  Args:
    vocab_size: int: vocab size
    d_feature: int: depth of *each half* of the two-part features
    d_feedforward: int: depth of feed-forward layer
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out); only forwarded to
      the decoder blocks, since embedding dropout is disabled here
    max_len: int: maximum symbol length for positional encoding
    n_chunks: int: number of chunks (must match input pipeline)
    n_attention_chunks: int: number of chunks for memory-efficient attention
    mode: str: 'train' or 'eval'

  Returns:
    The assembled `tl.Model`.
  """
  # TODO(kitaev): embedding dropout is disabled to save memory.
  embed_tokens = [
      tl.Embedding(d_feature, vocab_size),
      tl.PositionalEncoding(max_len=max_len),
  ]

  # Input side: merge the chunks, shift, embed, then fork into two streams.
  stages = [
      tl.Concatenate(),
      tl.ShiftRight(),
      embed_tokens,
      Duplicate(),  # pylint: disable=no-value-for-parameter
  ]

  decoder_blocks = [
      DecoderBlock(d_feature, d_feedforward, n_heads, n_attention_chunks,
                   dropout, mode)
      for _ in range(n_layers)
  ]
  stages.append(ReversibleSerial(decoder_blocks))

  # Output side: normalize each stream, merge them, and re-chunk.
  stages.append(tl.Parallel(tl.LayerNorm(), tl.LayerNorm()))
  stages.append(tl.Concatenate())
  stages.append(Split(sections=n_chunks, axis=-2))  # pylint: disable=no-value-for-parameter
  output_head = [tl.Dense(vocab_size), tl.LogSoftmax()]
  stages.append(Map(output_head))

  return tl.Model(*stages)
示例#4
0
def AtariCnn(hidden_sizes=(32, 32), output_size=128, mode='train'):
    """Builds a CNN over stacks of four successive Atari frames.

    Input shape: (B, T, H, W, C). Output shape: (B, T, output_size).

    Args:
      hidden_sizes: sequence whose first two entries are passed as the
        first argument of the first and second `tl.Conv` layers.
      output_size: number of units of the final dense layer.
      mode: unused; deleted immediately.

    Returns:
      The assembled `tl.Model`.
    """
    del mode

    # TODO(jonni): Include link to paper?
    kernel_size = (5, 5)
    strides = (2, 2)

    # Convert the input to float and divide by 255.
    stages = [tl.ToFloat(), tl.Div(divisor=255.0)]

    # Set up 4 successive game frames, concatenated on the last axis:
    # three Dups give four copies, and copies 1..3 are shifted by 1..3.
    stages.extend(tl.Dup() for _ in range(3))
    shifted_copies = [_shift_right(k) for k in (1, 2, 3)]
    stages.append(tl.Parallel(None, *shifted_copies))
    stages.append(tl.Concatenate(n_items=4, axis=-1))  # (B, T, H, W, 4C)

    # Two strided convolutions, each followed by a ReLU.
    stages.append(tl.Conv(hidden_sizes[0], kernel_size, strides, 'SAME'))
    stages.append(tl.Relu())
    stages.append(tl.Conv(hidden_sizes[1], kernel_size, strides, 'SAME'))
    stages.append(tl.Relu())

    # Keep B and T, flatten the rest, then project to output_size.
    stages.append(tl.Flatten(n_axes_to_keep=2))
    stages.append(tl.Dense(output_size))
    stages.append(tl.Relu())

    return tl.Model(*stages)
示例#5
0
def AtariCnn(hidden_sizes=(32, 32), output_size=128):
    """Builds a CNN over four progressively right-shifted input copies.

    Input shape: (B, T, H, W, C). Output shape: (B, T, output_size).

    Args:
      hidden_sizes: sequence whose first two entries are passed as the
        first argument of the first and second `tl.Conv` layers.
      output_size: number of units of the final dense layer.

    Returns:
      The assembled `tl.Serial` layer.
    """
    kernel_size = (5, 5)
    strides = (2, 2)

    stages = [tl.Div(divisor=255.0)]

    # Four branches of the input: unshifted, then shifted right by one,
    # two and three applications of ShiftRight respectively.
    unshifted = tl.NoOp()
    shifted_once = tl.ShiftRight()
    shifted_twice = tl.Serial(*[tl.ShiftRight() for _ in range(2)])
    shifted_thrice = tl.Serial(*[tl.ShiftRight() for _ in range(3)])
    stages.append(
        tl.Branch(unshifted, shifted_once, shifted_twice, shifted_thrice))

    # Concatenated on the last axis.
    stages.append(tl.Concatenate(axis=-1))  # (B, T, H, W, 4C)

    # Two strided convolutions wrapped in Rebatch, each followed by a ReLU.
    first_conv = tl.Conv(hidden_sizes[0], kernel_size, strides, 'SAME')
    stages.append(tl.Rebatch(first_conv, 2))
    stages.append(tl.Relu())
    second_conv = tl.Conv(hidden_sizes[1], kernel_size, strides, 'SAME')
    stages.append(tl.Rebatch(second_conv, 2))
    stages.append(tl.Relu())

    # Keep B and T, flatten the rest, then project to output_size.
    stages.append(tl.Flatten(num_axis_to_keep=2))
    stages.append(tl.Dense(output_size))
    stages.append(tl.Relu())

    # Eventually this is shaped (B, T, output_size)
    return tl.Serial(*stages)
示例#6
0
def PreservePosition(layer):
  """Runs `layer` without position while carrying the position alongside.

  Args:
    layer: layer placed between `CutAtPosition` and the final concatenation.

  Returns:
    A `tl.Serial` of `CutAtPosition()`, `layer` and a two-item
    `tl.Concatenate`.
  """
  steps = [
      CutAtPosition(),
      layer,
      tl.Concatenate(n_items=2),
  ]
  return tl.Serial(*steps)
示例#7
0
文件: atari_cnn.py 项目: tianhai123/-
def FrameStack(n_frames):
  """Returns layers that stack shifted copies of a frame sequence.

  Input shape: (B, T, ..., C). Output shape: (B, T, ..., C * n_frames).

  Args:
    n_frames: number of copies to stack; asserted to be at least 1.

  Returns:
    A tuple of (list of `tl.Dup` layers, `tl.Parallel` of the per-copy
    shifts, `tl.Concatenate` over the last axis).
  """
  assert n_frames >= 1
  # n_frames - 1 duplications leave n_frames copies of the input sequence.
  make_copies = [tl.Dup()] * (n_frames - 1)
  # Copy k is shifted to the right by k frames, for k in [0, n_frames - 1].
  per_copy_shift = [_shift_right(k) for k in range(n_frames)]
  shift_copies = tl.Parallel(*per_copy_shift)
  # Join the copies along the channel (last) dimension.
  merge_copies = tl.Concatenate(n_items=n_frames, axis=-1)
  return (make_copies, shift_copies, merge_copies)
示例#8
0
 def model(mode):
   """Builds a layer stack with observation and reward outputs.

   Reads `n_actions` and `obs_shape` from the enclosing scope.

   Args:
     mode: unused; deleted immediately.

   Returns:
     The assembled `layers.Serial`.
   """
   del mode
   # Flatten the observation stack and embed the action side by side.
   encode_inputs = layers.Parallel(
       layers.Flatten(),  # Observation stack.
       layers.Embedding(d_feature=1, vocab_size=n_actions),  # Action.
   )
   stages = [encode_inputs]
   stages.append(layers.Concatenate())
   stages.append(layers.Dense(n_units=1))
   # Fork the single unit into two outputs.
   stages.append(layers.Dup())
   stages.append(layers.Parallel(
       layers.Dense(n_units=obs_shape[1]),  # New observation.
       None,  # Reward.
   ))
   return layers.Serial(*stages)
def DecoderBlock(d_feature, d_feedforward, n_heads, n_attention_chunks,
                 dropout, mode):
    """Builds one reversible Transformer decoder layer.

    The layer is a `ReversibleResidual` pairing a chunked causal
    self-attention sub-layer with a feed-forward sub-layer.

    Args:
      d_feature: int: depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_heads: int: number of attention heads
      n_attention_chunks: int: number of chunks for memory-efficient
        attention
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      A one-element list holding the `ReversibleResidual` layer.
    """
    # Attention applied to a single chunk: normalize, build the causal mask,
    # attend, discard the mask, then apply dropout.
    create_mask = tl.Branch([], tl.CausalMask(axis=-2))
    attend = tl.MultiHeadedAttention(
        d_feature, n_heads=n_heads, dropout=dropout, mode=mode)
    drop_mask = tl.Select(0)
    attention_per_chunk = [
        tl.LayerNorm(),
        create_mask,
        attend,
        drop_mask,
        tl.Dropout(rate=dropout, mode=mode),
    ]

    # TODO(kitaev): Memory-efficient attention. This chunking is temporary.
    # Split along axis -2, attend within each chunk, and join the results.
    chunked_attention = [
        Split(sections=n_attention_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
        Map(attention_per_chunk),
        tl.Concatenate(axis=-2),
    ]

    feed_forward_sublayer = [
        FeedForward(d_feature, d_feedforward, dropout, mode=mode),
    ]

    reversible_layer = ReversibleResidual(
        [chunked_attention], [feed_forward_sublayer])
    return [reversible_layer]
示例#10
0
def ApplyAndQueryPositions(layer, pos):
    """Execute layer without position and pos-layers on positions.

    This takes an embedding including position x = (emb, p), and
    outputs layer(emb).pos1(x, p).....layer(emb).posn(x, p)
    where pos=[pos1...posn].

    Args:
      layer: layer to be executed without position information.
      pos: iterable of layers to be applied to positions. Any iterable is
        accepted (list, tuple, ...); it is materialized into a list once.

    Returns:
      the result of this application, as a `tl.Serial` layer.
    """
    # Materialize once so tuples and other iterables work as well as lists;
    # the previous `[layer] + pos` raised TypeError for anything but a list.
    pos = list(pos)
    n_heads = len(pos)
    return tl.Serial(
        tl.Dup(),  # (x, x)
        CutAtPosition(),  # (x_content, x_position, x)
        tl.Parallel([], tl.Swap()),  # (x_content, x, x_position)
        # One extra (x, x_position) pair per additional position layer.
        [tl.Parallel([], Dup2()) for _ in range(n_heads - 1)],
        # Now the stack is x_content, (x, x_position) * n_heads.
        tl.Parallel(layer, *pos),
        tl.Concatenate(n_items=n_heads + 1))
示例#11
0
def TransformerLM(vocab_size,
                  d_model=512,
                  d_ff=2048,
                  n_layers=6,
                  n_heads=8,
                  d_attention_key=None,
                  d_attention_value=None,
                  attention_type=tl.DotProductCausalAttention,
                  dropout=0.1,
                  share_qk=False,
                  max_len=2048,
                  n_chunks=0,
                  mode='train'):
  """Builds a Transformer language model (decoder part only).

  The model consumes a tensor of tokens, or a chunked tuple of tokens when
  `n_chunks` is non-zero, in which case the output is chunked the same way.

  Args:
    vocab_size: int: vocab size
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of decoder blocks to stack
    n_heads: int: number of attention heads
    d_attention_key: int: depth of key vector for each attention head
        (default is d_model // n_heads)
    d_attention_value: int: depth of value vector for each attention head
        (default is d_model // n_heads)
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    share_qk: bool, whether to share queries and keys in decoder attention
    max_len: int: maximum symbol length for positional encoding
    n_chunks: int: number of chunks (must match input pipeline); 0 disables
        chunking
    mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference

  Returns:
    A Transformer language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.
  """
  # With chunking enabled the input chunks are merged up front and the output
  # is split again at the end; otherwise both steps are empty placeholders.
  if n_chunks != 0:
    concatenate_chunks = tl.Concatenate(n_items=n_chunks)
    split_chunks = tl.Split(n_sections=n_chunks, axis=-2)
  else:
    concatenate_chunks = split_chunks = []

  # Embedding -> dropout -> positional encoding.
  token_embedder = [
      tl.Embedding(d_model, vocab_size),
      tl.Dropout(rate=dropout, name='embedding', mode=mode),
      tl.PositionalEncoding(max_len=max_len, mode=mode),
  ]

  stages = [
      concatenate_chunks,           # tokens
      tl.ShiftRight(mode=mode),     # toks
      token_embedder,               # vecs
  ]

  # Each block receives its index in the stack as the ninth argument.
  # pylint: disable=g-complex-comprehension
  decoder_blocks = [
      DecoderBlock(d_model, d_ff, n_heads, d_attention_key,
                   d_attention_value, attention_type, dropout, share_qk,
                   layer_idx, mode)
      for layer_idx in range(n_layers)
  ]
  stages.append(decoder_blocks)          # vecs

  stages.append(tl.LayerNorm())          # vecs
  stages.append(tl.Dense(vocab_size))    # vecs
  stages.append(tl.LogSoftmax())         # vecs
  stages.append(split_chunks)            # vecs (or chunked tuple of vecs)

  return tl.Model(*stages)