Example #1
  def encoder(embedded_source, source_mask):
    """Transformer encoder stack.

    Args:
      embedded_source: staxlayer variable: embedded source sequences
      source_mask: staxlayer variable: self-attention mask

    Returns:
      Staxlayer variable that outputs encoded source.
    """
    encoder_layer = stax.serial(
        # input attends to self
        stax.residual(stax.LayerNorm(),
                      stax.FanOut(4),
                      stax.parallel(stax.Identity,  # query
                                    stax.Identity,  # key
                                    stax.Identity,  # value
                                    source_mask),  # attention mask
                      multi_attention,
                      stax.Dropout(keep_rate, mode=mode)),
        # feed-forward
        stax.residual(stax.LayerNorm(),
                      feed_forward,
                      stax.Dropout(keep_rate, mode=mode))
    )
    return stax.serial(
        embedded_source,
        stax.repeat(encoder_layer, num_layers),
        stax.LayerNorm(),
    )
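Note that encoder is a nested helper: multi_attention, feed_forward, keep_rate, mode, and num_layers are free variables taken from the enclosing model builder's scope; Example #10 below shows how such layers are typically constructed.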
Example #2
def DecoderLayer(feature_depth,
                 feedforward_depth,
                 num_heads,
                 dropout,
                 mode):
  """Transformer decoder layer.

  Args:
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    init and apply.
  """
  return stax.serial(
      stax.residual(  # Self-attention block.
          stax.LayerNorm(),
          stax.FanOut(4),
          stax.parallel(stax.Identity,  # query
                        stax.Identity,  # key
                        stax.Identity,  # value
                        stax.CausalMask(axis=-2)),  # attention mask
          stax.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                                    dropout=dropout, mode=mode),
          stax.Dropout(dropout, mode=mode)
      ),
      ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode=mode)
  )
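ResidualFeedForward is not defined in these snippets; judging from its use here and in Example #3, it presumably bundles the LayerNorm, feed-forward, and Dropout residual block that Examples #1 and #10 spell out explicitly. Per the docstring, DecoderLayer itself returns a stax-style (init, apply) pair, so it can be stacked with stax.repeat just like the layers in the other examples.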
Example #3
  def Encoder(source, source_mask):
    """Transformer encoder stack.

    Args:
      source: staxlayer variable: raw source sequences
      source_mask: staxlayer variable: self-attention mask

    Returns:
      Staxlayer variable that outputs encoded source.
    """
    encoder_layer = stax.serial(
        # input attends to self
        stax.residual(stax.LayerNorm(),
                      stax.FanOut(4),
                      stax.parallel(stax.Identity,  # query
                                    stax.Identity,  # key
                                    stax.Identity,  # value
                                    source_mask),  # attention mask
                      multi_attention,
                      stax.Dropout(dropout, mode=mode)),
        # feed-forward
        ResidualFeedForward(
            feature_depth, feedforward_depth, dropout, mode=mode),
    )
    return stax.serial(
        source,
        source_embedding_layer,
        stax.repeat(encoder_layer, num_layers),
        stax.LayerNorm(),
    )
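Compared with Example #1, this variant takes the raw source sequence and applies source_embedding_layer inside the stack, and it replaces the explicit LayerNorm/feed-forward/Dropout residual with the ResidualFeedForward helper; the embedding layer and the attention and feed-forward settings again come from the enclosing scope.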
Example #4
def _build_combinator_tree(input_treespec, in_vars):
    """Build a trivial Staxlayer that takes a complicated tree of inputs."""
    parallel_args = []
    for e in input_treespec:
        if isinstance(e, int):
            parallel_args.append(in_vars[e])
        elif isinstance(e, tuple):
            parallel_args.append(_build_combinator_tree(e, in_vars))
    return stax.serial(stax.parallel(*parallel_args), stax.FanInSum)
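As a concrete illustration, derived directly from the recursion above (with a, b, and c standing for three staxlayer input variables), a treespec of (0, (1, 2)) expands into nested parallel/FanInSum combinators:

# _build_combinator_tree((0, (1, 2)), (a, b, c)) builds the equivalent of:
stax.serial(
    stax.parallel(
        a,                                                 # leaf: in_vars[0]
        stax.serial(stax.parallel(b, c), stax.FanInSum)),  # subtree (1, 2)
    stax.FanInSum)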
Example #5
    def decoder(memory, target, target_mask, memory_mask):
        """Transformer decoder stack.

        Args:
          memory: staxlayer variable: encoded source sequences
          target: staxlayer variable: raw target sequences
          target_mask: staxlayer variable: self-attention mask
          memory_mask: staxlayer variable: memory attention mask

        Returns:
          Staxlayer variable that outputs the decoded target.
        """
        decoder_layer = stax.serial(
            # target attends to self
            stax.residual(
                stax.LayerNorm(),
                stax.FanOut(4),
                stax.parallel(
                    stax.Identity,  # query
                    stax.Identity,  # key
                    stax.Identity,  # value
                    target_mask),  # attention mask
                multi_attention,
                stax.Dropout(keep_rate, mode=mode)),
            # target attends to encoded source
            stax.residual(
                stax.LayerNorm(),
                stax.FanOut(4),
                stax.parallel(
                    stax.Identity,  # query
                    memory,  # key
                    memory,  # value
                    memory_mask),  # attention mask
                multi_attention,
                stax.Dropout(keep_rate, mode=mode)),
            # feed-forward
            stax.residual(stax.LayerNorm(), feed_forward,
                          stax.Dropout(keep_rate, mode=mode)))
        return stax.serial(
            target,
            target_embedding_layer,
            stax.repeat(decoder_layer, num_layers),
            stax.LayerNorm(),
        )
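The second residual block is the cross-attention step: the target activations supply the query while memory (the encoded source) supplies both key and value, masked by memory_mask. As in Examples #1 and #3, multi_attention, feed_forward, target_embedding_layer, keep_rate, mode, and num_layers are free variables from the enclosing model builder.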
Example #6
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    """WideResnet convolutional block."""
    main = stax.serial(stax.BatchNorm(), stax.Relu,
                       stax.Conv(channels, (3, 3), strides, padding='SAME'),
                       stax.BatchNorm(), stax.Relu,
                       stax.Conv(channels, (3, 3), padding='SAME'))
    shortcut = stax.Identity if not channel_mismatch else stax.Conv(
        channels, (3, 3), strides, padding='SAME')
    return stax.serial(stax.FanOut(2), stax.parallel(main, shortcut),
                       stax.FanInSum)
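The FanOut(2) / parallel(main, shortcut) / FanInSum wiring is the residual connection written out explicitly; it is presumably the same pattern that the stax.residual combinator used in the Transformer examples above produces internally.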
Example #7
def ConvBlock(kernel_size, filters, strides):
    """ResNet convolutional striding block."""
    ks = kernel_size
    filters1, filters2, filters3 = filters
    main = stax.serial(stax.Conv(filters1, (1, 1),
                                 strides), stax.BatchNorm(), stax.Relu,
                       stax.Conv(filters2, (ks, ks), padding='SAME'),
                       stax.BatchNorm(), stax.Relu,
                       stax.Conv(filters3, (1, 1)), stax.BatchNorm())
    shortcut = stax.serial(stax.Conv(filters3, (1, 1), strides),
                           stax.BatchNorm())
    return stax.serial(stax.FanOut(2), stax.parallel(main, shortcut),
                       stax.FanInSum, stax.Relu)
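Both branches apply the stride (the main path on its first 1x1 convolution, the shortcut on its only one) and both end with filters3 channels, so the two outputs line up for FanInSum before the final Relu.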
Example #8
def IdentityBlock(kernel_size, filters):
    """ResNet identical size block."""
    ks = kernel_size
    filters1, filters2 = filters

    def MakeMain(input_shape):
        # the number of output channels depends on the number of input channels
        return stax.serial(stax.Conv(filters1, (1, 1)), stax.BatchNorm(),
                           stax.Relu,
                           stax.Conv(filters2, (ks, ks), padding='SAME'),
                           stax.BatchNorm(), stax.Relu,
                           stax.Conv(input_shape[3], (1, 1)), stax.BatchNorm())

    main = stax.shape_dependent(MakeMain)
    return stax.serial(stax.FanOut(2), stax.parallel(main, stax.Identity),
                       stax.FanInSum, stax.Relu)
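Because the block adds its output back to the unchanged identity shortcut, the last convolution must reproduce the input's channel count; stax.shape_dependent defers building the main path until the input shape is known so that input_shape[3] can be used there.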
Example #9
def lambda_fun3(x, y, z, w, v):
    input_tree = _build_combinator_tree(tree_spec, (x, y, z))
    return stax.serial(input_tree, stax.FanOut(3),
                       stax.parallel(w, v, stax.Identity),
                       stax.FanInSum)
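Here tree_spec is a free variable from the enclosing test scope, and w and v are expected to be staxlayer values: the combinator tree's single output is fanned out three ways, passed through w, v, and an Identity branch in parallel, then summed.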
Example #10
def TransformerLM(vocab_size,  # pylint: disable=invalid-name
                  mode='train',
                  num_layers=6,
                  feature_depth=512,
                  feedforward_depth=2048,
                  num_heads=8,
                  dropout=0.1,
                  max_len=2048):
  """Transformer language model (only uses the decoder part of Transformer).

  Args:
    vocab_size: int: vocab size
    mode: str: 'train' or 'eval'
    num_layers: int: number of encoder/decoder layers
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding

  Returns:
    init and apply.
  """
  keep_rate = 1.0 - dropout
  # Multi-headed Attention and Feed-forward layers
  multi_attention = stax.MultiHeadedAttention(
      feature_depth, num_heads=num_heads, dropout=keep_rate, mode=mode)

  feed_forward = stax.serial(
      stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
      stax.Relu,
      stax.Dropout(keep_rate, mode=mode),
      stax.Dense(feature_depth, W_init=stax.xavier_uniform())
  )

  # Single decoder layer
  decoder_layer = stax.serial(
      # target attends to self
      stax.residual(stax.LayerNorm(),
                    stax.FanOut(4),
                    stax.parallel(stax.Identity,  # query
                                  stax.Identity,  # key
                                  stax.Identity,  # value
                                  stax.CausalMask(axis=-2)),  # attention mask
                    multi_attention,
                    stax.Dropout(keep_rate, mode=mode)),
      # feed-forward
      stax.residual(stax.LayerNorm(),
                    feed_forward,
                    stax.Dropout(keep_rate, mode=mode))
  )

  return stax.serial(
      stax.ShiftRight(),
      stax.Embedding(feature_depth, vocab_size),
      stax.Dropout(keep_rate, mode=mode),
      stax.PositionalEncoding(feature_depth, max_len=max_len),
      stax.repeat(decoder_layer, num_layers),
      stax.LayerNorm(),
      stax.Dense(vocab_size, W_init=stax.xavier_uniform()),
      stax.LogSoftmax
  )
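A minimal usage sketch follows; it assumes the jax.experimental.stax-style (init_fun, apply_fun) pair that the "init and apply" docstring suggests, and the vocab size, shapes, and keyword names below are illustrative assumptions rather than part of the snippet above.

import numpy as np
from jax import random

# Assumed example configuration; real vocab size and shapes depend on the task.
init_fun, apply_fun = TransformerLM(vocab_size=32000, mode='train')
rng = random.PRNGKey(0)
# Assumed init signature: (rng, input_shape) -> (output_shape, params).
output_shape, params = init_fun(rng, (8, 256))
# Assumed apply signature: (params, inputs, rng=...) -> log-probabilities of
# shape (batch, length, vocab_size), as implied by the final LogSoftmax.
token_ids = np.zeros((8, 256), dtype=np.int32)
log_probs = apply_fun(params, token_ids, rng=rng)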