Example No. 1
def policy_and_value_net(n_actions, n_controls, vocab_size, bottom_layers_fn,
                         two_towers):
    """A policy and value net function."""

    # Layers.

    # Two heads act on the bottom layers' output: one computes action
    # probabilities, the other computes the value function.
    # NOTE: We use LogSoftmax instead of Softmax for numerical stability.

    @tl.layer()
    def FlattenControlsIntoTime(x, **unused_kwargs):  # pylint: disable=invalid-name
        """Splits logits for actions in different controls and flattens controls."""
        # `np` here is the trax backend's numpy (jax.numpy in the source module).
        return np.reshape(x, (x.shape[0], -1, n_actions))

    if vocab_size is None:
        # In continuous policies every element of the output sequence corresponds to
        # an observation.
        n_preds_per_input = n_controls
        kwargs = {}
    else:
        # In discrete policies every element of the output sequence corresponds to
        # a symbol in the discrete representation, and each control takes 1 symbol.
        n_preds_per_input = 1
        kwargs = {"vocab_size": vocab_size}

    if two_towers:
        layers = [
            tl.Dup(),
            tl.Parallel(
                [
                    bottom_layers_fn(**kwargs),
                    tl.Dense(n_preds_per_input * n_actions),
                    FlattenControlsIntoTime(),  # pylint: disable=no-value-for-parameter
                    tl.LogSoftmax()
                ],
                [
                    bottom_layers_fn(**kwargs),
                    tl.Dense(n_preds_per_input),
                    tl.Flatten()
                ],
            )
        ]
    else:
        layers = [
            bottom_layers_fn(**kwargs),
            tl.Dup(),
            tl.Parallel(
                [
                    tl.Dense(n_preds_per_input * n_actions),
                    FlattenControlsIntoTime(),  # pylint: disable=no-value-for-parameter
                    tl.LogSoftmax()
                ],
                [tl.Dense(n_preds_per_input),
                 tl.Flatten()],
            )
        ]
    return tl.Model(layers)
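
A minimal usage sketch (not from the original source; the lambda bottom
network and all argument values are illustrative assumptions):

net = policy_and_value_net(
    n_actions=4,       # size of each control's action space
    n_controls=1,      # one control per observation
    vocab_size=None,   # continuous input: bottom_layers_fn gets no kwargs
    bottom_layers_fn=lambda: [tl.Dense(64), tl.Relu()],
    two_towers=True,   # separate bottom networks for policy and value heads
)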
Example No. 2
def TransformerDecoder(vocab_size=None,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       d_attention_key=None,
                       d_attention_value=None,
                       attention_type=tl.DotProductCausalAttention,
                       dropout=0.1,
                       share_qk=False,
                       max_len=2048,
                       mode='train'):
    """Returns a Transformer decoder model.

  The input to the model is either continuous or discrete, controlled by
  vocab_size. The model does not shift the input to the right, i.e. the output
  for timestep t is based on inputs up to and including timestep t.

  Args:
    vocab_size: int or None: vocab size if running on discrete input, None
        otherwise.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    d_attention_key: int: depth of key vector for each attention head
        (default is d_model // n_heads)
    d_attention_value: int: depth of value vector for each attention head
        (default is d_model // n_heads)
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    share_qk: bool, whether to share queries and keys in decoder attention
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    A Transformer decoder as a layer that maps from a continuous or discrete
    tensor to a continuous tensor.
  """
    if vocab_size is None:
        input_layer = tl.Dense
    else:
        input_layer = functools.partial(tl.Embedding, vocab_size=vocab_size)
    return tl.Model(  # vecs
        input_layer(d_model),  # vecs
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
        [
            DecoderBlock(  # pylint: disable=g-complex-comprehension
                d_model, d_ff, n_heads, d_attention_key, d_attention_value,
                attention_type, dropout, share_qk, i, mode)
            for i in range(n_layers)
        ],  # vecs
        tl.LayerNorm(),  # vecs
    )
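
A hedged construction sketch (argument values are illustrative): the same
constructor covers both input regimes.

lm_body = TransformerDecoder(vocab_size=256, n_layers=2)       # token ids in
vec_decoder = TransformerDecoder(vocab_size=None, n_layers=2)  # float vectors in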
Example No. 3
def FrameStackMLP(n_frames=4,
                  hidden_sizes=(64, ),
                  output_size=64,
                  mode='train'):
    """MLP operating on a fixed number of last frames."""
    del mode

    return tl.Model(
        FrameStack(n_frames=n_frames),
        [[tl.Dense(d_hidden), tl.Relu()] for d_hidden in hidden_sizes],
        tl.Dense(output_size),
    )
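
For illustration (values are assumptions, not from the source): with
hidden_sizes=(64, 32) the comprehension expands to two Dense+Relu pairs, so
the model reads FrameStack -> Dense(64) -> Relu -> Dense(32) -> Relu ->
Dense(output_size).

frame_mlp = FrameStackMLP(n_frames=4, hidden_sizes=(64, 32), output_size=64)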
Example No. 4
def Resnet50(d_hidden=64, n_output_classes=1001, mode='train'):
    """ResNet.

  Args:
    d_hidden: Dimensionality of the first hidden layer (multiplied later).
    n_output_classes: Number of distinct output classes.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    The list of layers comprising a ResNet model with the given parameters.
  """
    return tl.Model(
        tl.ToFloat(),
        tl.Conv(d_hidden, (7, 7), (2, 2), 'SAME'),
        tl.BatchNorm(mode=mode),
        tl.Relu(),
        tl.MaxPool(pool_size=(3, 3), strides=(2, 2)),
        ConvBlock(3, [d_hidden, d_hidden, 4 * d_hidden], (1, 1), mode=mode),
        IdentityBlock(3, [d_hidden, d_hidden, 4 * d_hidden], mode=mode),
        IdentityBlock(3, [d_hidden, d_hidden, 4 * d_hidden], mode=mode),
        ConvBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden], (2, 2),
                  mode=mode),
        IdentityBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden],
                      mode=mode),
        ConvBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden], (2, 2),
                  mode=mode),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden],
                      mode=mode),
        ConvBlock(3, [8 * d_hidden, 8 * d_hidden, 32 * d_hidden], (2, 2),
                  mode=mode),
        IdentityBlock(3, [8 * d_hidden, 8 * d_hidden, 32 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [8 * d_hidden, 8 * d_hidden, 32 * d_hidden],
                      mode=mode),
        tl.AvgPool(pool_size=(7, 7)),
        tl.Flatten(),
        tl.Dense(n_output_classes),
        tl.LogSoftmax(),
    )
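
Sanity check on the depth, reasoning from the code above: the stem convolution
plus (3 + 4 + 6 + 3) = 16 bottleneck blocks of 3 convolutions each plus the
final Dense gives 1 + 48 + 1 = 50 weight layers, i.e. ResNet-50. A hedged
construction sketch with the default ImageNet-style head:

resnet = Resnet50(d_hidden=64, n_output_classes=1001, mode='train')
# Maps float-convertible image batches to log-probabilities over classes.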
Example No. 5
def MLP(n_hidden_layers=2,
        d_hidden=512,
        activation_fn=tl.Relu,
        n_output_classes=10,
        mode="train"):
    """A multi-layer feedforward (perceptron) network."""
    del mode

    return tl.Model(
        tl.Flatten(),
        [[tl.Dense(d_hidden), activation_fn()]
         for _ in range(n_hidden_layers)],
        tl.Dense(n_output_classes),
        tl.LogSoftmax(),
    )
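
With the defaults, the comprehension expands so the model reads Flatten ->
Dense(512) -> Relu -> Dense(512) -> Relu -> Dense(10) -> LogSoftmax. A minimal
sketch (argument values illustrative):

mlp = MLP(n_hidden_layers=2, d_hidden=512, n_output_classes=10)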
Example No. 6
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       max_len=2048,
                       mode='train'):
    """Returns a Transformer encoder model.

  The input to the model is a tensor of tokens.

  Args:
    vocab_size: int: vocab size
    n_classes: how many classes on output
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of encoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    A Transformer model as a layer that maps from a tensor of tokens to
    activations over a set of output classes.
  """
    embedder = [
        tl.Embedding(d_model, vocab_size),
        tl.Dropout(rate=dropout, name='emb_dropout', mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]
    return tl.Model([  #      tokens
        tl.Dup(),  # toks toks
        tl.Parallel(embedder, tl.PaddingMask()),  # vecs mask
        [
            EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode)
            for i in range(n_layers)
        ],  # vecs mask
        tl.Parallel([], tl.Drop()),  # ____  0
        tl.LayerNorm(),  # vecs
        tl.Mean(axis=1),  # Average over length.  # vecs
        tl.Dense(n_classes),  # vecs
        tl.LogSoftmax(),  # vecs
    ])
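
A hedged usage sketch (argument values are illustrative): the model maps a
(batch, length) tensor of token ids to (batch, n_classes) log-probabilities,
with tl.Mean(axis=1) pooling over the length dimension before the classifier
head.

classifier = TransformerEncoder(vocab_size=32000, n_classes=10, n_layers=2)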
Example No. 7
def NeuralGPU(d_feature=96, steps=16, vocab_size=2, mode='train'):
    """Implementation of Neural GPU: https://arxiv.org/abs/1702.08727.

  Args:
    d_feature: Number of memory channels (dimensionality of feature embedding).
    steps: Number of depthwise recurrence steps.
    vocab_size: Vocabulary size.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    A NeuralGPU model.
  """
    del mode

    core = ConvDiagonalGRU(units=d_feature)
    return tl.Model(
        tl.Embedding(d_feature=d_feature, vocab_size=vocab_size),
        [core] * steps,
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    )
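
A note on the recurrence (our reading of the code, not a comment from the
source): [core] * steps repeats the same ConvDiagonalGRU instance, so the GRU
weights are shared across all steps applications, which is what makes this a
depthwise recurrence rather than a 16-layer feed-forward stack. Construction
sketch with the defaults:

ngpu = NeuralGPU(d_feature=96, steps=16, vocab_size=2)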
Example No. 8
def AtariCnn(n_frames=4, hidden_sizes=(32, 32), output_size=128, mode='train'):
    """An Atari CNN."""
    del mode

    # TODO(jonni): Include link to paper?
    # Input shape: (B, T, H, W, C)
    # Output shape: (B, T, output_size)
    return tl.Model(
        tl.ToFloat(),
        tl.Div(divisor=255.0),

        # Set up n_frames successive game frames, concatenated on the last axis.
        FrameStack(n_frames=n_frames),  # (B, T, H, W, 4C)
        tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'),
        tl.Relu(),
        tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'),
        tl.Relu(),
        tl.Flatten(n_axes_to_keep=2),  # B, T and rest.
        tl.Dense(output_size),
        tl.Relu(),
    )
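
A shape walk-through under an assumed (B, T, 84, 84, 3) input (the 84x84
frame size is illustrative): FrameStack concatenates 4 frames into
(B, T, 84, 84, 12); the two stride-2 SAME convolutions give (B, T, 42, 42, 32)
and then (B, T, 21, 21, 32); Flatten(n_axes_to_keep=2) merges everything after
(B, T); the final Dense yields (B, T, 128).

cnn = AtariCnn(n_frames=4, hidden_sizes=(32, 32), output_size=128)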
Example No. 9
def WideResnet(n_blocks=3,
               widen_factor=1,
               n_output_classes=10,
               bn_momentum=0.9,
               mode='train'):
    """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    n_blocks: int, number of blocks in a group; total layers = 6*n_blocks + 4.
    widen_factor: int, widening factor of each group. k=1 is vanilla resnet.
    n_output_classes: int, number of distinct output classes.
    bn_momentum: float, momentum in BatchNorm.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    The list of layers comprising a WideResnet model with the given parameters.
  """
    return tl.Model(
        tl.ToFloat(),
        tl.Conv(16, (3, 3), padding='SAME'),
        WideResnetGroup(n_blocks,
                        16 * widen_factor,
                        bn_momentum=bn_momentum,
                        mode=mode),
        WideResnetGroup(n_blocks,
                        32 * widen_factor, (2, 2),
                        bn_momentum=bn_momentum,
                        mode=mode),
        WideResnetGroup(n_blocks,
                        64 * widen_factor, (2, 2),
                        bn_momentum=bn_momentum,
                        mode=mode),
        tl.BatchNorm(momentum=bn_momentum, mode=mode),
        tl.Relu(),
        tl.AvgPool(pool_size=(8, 8)),
        tl.Flatten(),
        tl.Dense(n_output_classes),
        tl.LogSoftmax(),
    )
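
Depth check from the docstring's formula: the default n_blocks=3 gives
6*3 + 4 = 22 layers, so n_blocks=3 with widen_factor=10 would correspond to
WRN-22-10 in the paper's naming. A hedged construction sketch:

wrn = WideResnet(n_blocks=3, widen_factor=10, n_output_classes=10)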
Example No. 10
def ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               n_chunks=0,
               n_attention_chunks=1,
               attention_type=tl.DotProductCausalAttention,
               share_qk=False,
               mode='train'):
  """Reversible transformer language model (only uses a decoder, no encoder).

  Args:
    vocab_size: int: vocab size
    d_model: int:  depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_chunks: int: number of chunks (must match input pipeline)
    n_attention_chunks: int: number of chunks for attention
    attention_type: class: attention class to use, such as DotProductAttention.
    share_qk: bool, whether to share queries and keys.
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  if n_chunks == 0:
    n_chunks = 1
    concatenate_input_chunks = []
    concatenate_output_chunks = tl.Concatenate(n_items=n_chunks, axis=-2)
  else:
    concatenate_input_chunks = tl.Concatenate(n_items=n_chunks)
    concatenate_output_chunks = []

  positional_embedder = [
      tl.Embedding(d_model, vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      tl.PositionalEncoding(max_len=max_len),
  ]
  return tl.Model(
      concatenate_input_chunks,
      tl.ShiftRight(),
      positional_embedder,
      tl.Dup(),
      tl.ReversibleSerial([
          # pylint: disable=g-complex-comprehension
          DecoderBlock(d_model, d_ff,
                       d_attention_key, d_attention_value, n_heads,
                       n_attention_chunks, attention_type,
                       dropout, share_qk, mode)
          for _ in range(n_layers)
      ] + [
          SplitForOutput(n_sections=n_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
      ]),
      Map([
          # TODO(kitaev): Test whether dropout should go before or after the
          # LayerNorm, and whether dropout broadcasting is needed here.
          tl.LayerNorm(),
          BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
          tl.Dense(vocab_size),
          tl.LogSoftmax(),
      ], n_sections=n_chunks),
      concatenate_output_chunks,
  )
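
A hedged construction sketch (argument values illustrative). Note the
n_chunks contract from the docstring and the code above: with n_chunks=0 the
model consumes and emits a single tensor; otherwise it expects a tuple of
n_chunks tensors from the input pipeline, concatenates them on the way in, and
splits the output back into chunks.

reformer = ReformerLM(vocab_size=320, n_layers=3, share_qk=True, mode='train')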
Example No. 11
def model_fn(mode="train"):
    # `n_classes` is a free variable bound in the enclosing scope of the
    # original snippet.
    return layers.Model(
        layers.Dropout(mode=mode, rate=0.1),
        layers.BatchNorm(mode=mode),
        models.MLP(d_hidden=16, n_output_classes=n_classes, mode=mode))
Example No. 12
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
    """Returns a Transformer model.

  This model expects an input pair: source (encoder-side) tokens and target
  (decoder-side) tokens, matching the data-flow comments in the body below.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of encoder/decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    A Transformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
    in_embed = [  # tokens
        tl.Embedding(d_model, input_vocab_size),  # vecs
        tl.Dropout(rate=dropout, mode=mode),  # vecs
        tl.PositionalEncoding(max_len=max_len),  # vecs
    ]

    if output_vocab_size is None:
        output_vocab_size = input_vocab_size
        out_embed = in_embed
    else:
        out_embed = [  # tokens
            tl.Embedding(d_model, output_vocab_size),  # vecs
            tl.Dropout(rate=dropout, mode=mode),  # vecs
            tl.PositionalEncoding(max_len=max_len),  # vecs
        ]

    encoder_stack = (  # masks vectors --> masks vectors
        [
            EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode)
            for i in range(n_layers)
        ])

    encoder_decoder_stack = (  # vecs_d masks vecs_e --> vecs_d masks vecs_e
        [
            EncoderDecoder(d_model, d_ff, n_heads, dropout, i, mode)
            for i in range(n_layers)
        ])

    # Input: encoder_side_tokens, decoder_side_tokens
    return tl.Model(  # tokens_e tokens_d
        tl.Parallel([], tl.Dup()),  # toks_e toks_d toks_d (for loss)
        tl.Swap(),  # toks_d toks_e ....

        # Encode.
        tl.Parallel(  # toks_d        toks_e
            [],
            [
                tl.Dup(),  # ______ toks_e toks_e
                tl.Parallel(in_embed, tl.PaddingMask()),  # ______ vecs_e masks
                encoder_stack,  # ______ vecs_e masks
                tl.LayerNorm(),  # ______ vecs_e .....
                tl.Swap()
            ]),  # ______ masks  vecs_e

        # Decode.                                  #        toks_d masks vecs_e
        tl.ShiftRight(),  #        toks_d ..... ......
        out_embed,  #        vecs_d ..... ......
        tl.Dup(),  # vecs_d vecs_d ..... ......
        tl.Parallel([], tl.EncoderDecoderMask()),  # ______    masks     ......
        encoder_decoder_stack,  # vecs_d    masks     vecs_e
        tl.Parallel([], tl.Drop(), tl.Drop()),  # vecs_d
        tl.LayerNorm(),  # vecs_d
        tl.Dense(output_vocab_size),  # vecs_d
        tl.LogSoftmax(),  # vecs_d
    )
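
A hedged construction sketch (vocab sizes are illustrative). The Dup/Swap
plumbing at the top keeps a copy of the decoder-side tokens on the stack for
the loss.

translator = Transformer(input_vocab_size=32000, n_layers=2)
# output_vocab_size=None reuses the source embedding for the target side.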
Example No. 13
def TransformerLM(vocab_size,
                  d_model=512,
                  d_ff=2048,
                  n_layers=6,
                  n_heads=8,
                  d_attention_key=None,
                  d_attention_value=None,
                  attention_type=tl.DotProductCausalAttention,
                  dropout=0.1,
                  share_qk=False,
                  max_len=2048,
                  n_chunks=0,
                  mode='train'):
    """Returns a Transformer language model.

  The input to the model is a tensor of tokens. (This model uses only the
  decoder part of the overall Transformer.)

  Args:
    vocab_size: int: vocab size
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    d_attention_key: int: depth of key vector for each attention head
        (default is d_model // n_heads)
    d_attention_value: int: depth of value vector for each attention head
        (default is d_model // n_heads)
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    share_qk: bool, whether to share queries and keys in decoder attention
    max_len: int: maximum symbol length for positional encoding
    n_chunks: int: number of chunks (must match input pipeline)
    mode: str: 'train', 'eval', or 'predict'; predict mode is for fast inference

  Returns:
    A Transformer language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.
  """
    if n_chunks == 0:
        concatenate_chunks = split_chunks = []
    else:
        concatenate_chunks = tl.Concatenate(n_items=n_chunks)
        split_chunks = tl.Split(n_sections=n_chunks, axis=-2)

    embedder = [
        tl.Embedding(d_model, vocab_size),
        tl.Dropout(rate=dropout, name='embedding', mode=mode),
        tl.PositionalEncoding(max_len=max_len, mode=mode),
    ]

    return tl.Model(  # tokens (or chunked tuple of tokens)
        concatenate_chunks,  # tokens
        tl.ShiftRight(mode=mode),  # toks
        embedder,  # vecs
        [
            DecoderBlock(  # pylint: disable=g-complex-comprehension
                d_model, d_ff, n_heads, d_attention_key, d_attention_value,
                attention_type, dropout, share_qk, i, mode)
            for i in range(n_layers)
        ],  # vecs
        tl.LayerNorm(),  # vecs
        tl.Dense(vocab_size),  # vecs
        tl.LogSoftmax(),  # vecs
        split_chunks,  # vecs (or chunked tuple of vecs)
    )
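
A hedged usage sketch (argument values illustrative). Per the docstring,
mode='predict' enables fast autoregressive inference, and n_chunks must match
the input pipeline's chunking.

lm = TransformerLM(vocab_size=32000, n_layers=2, mode='train')
# Maps (batch, length) token ids to (batch, length, vocab_size)
# log-probabilities.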