Exemple #1
0
def MultiHeadedAttentionPosition(positions,
                                 d_feature,
                                 n_heads=8,
                                 dropout=0.0,
                                 mode='train'):
    """Transformer-style multi-headed attention with position-aware branches.

    Args:
      positions: forwarded unchanged to `SumLearnedPick`, once per head.
      d_feature: int: size given to every `tl.Dense` projection.
      n_heads: int: number of attention heads.
      dropout: float: dropout rate forwarded to the attention layer.
      mode: str: forwarded to the attention layer.

    Returns:
      A `tl.Serial` combinator chaining the stages built below.
    """
    # Two Dups feed the three-branch projection that follows.
    dup_first = tl.Dup()
    dup_second = tl.Dup()

    # First branch: dense projection plus one learned position pick per head.
    position_query = ApplyAndQueryPositions(
        tl.Dense(d_feature),
        pos=[SumLearnedPick(positions) for _ in range(n_heads)])
    project = tl.Parallel(
        position_query,
        PreservePosition(tl.Dense(d_feature)),
        PreservePosition(tl.Dense(d_feature)),
    )

    per_head = tl.Parallel(
        CopyHeadsPos(h=n_heads),
        MixHeadsPos(h=n_heads),
        MixHeadsPos(h=n_heads),
    )
    attend = tl.PureMultiHeadedAttention(d_feature=d_feature,
                                         n_heads=n_heads,
                                         dropout=dropout,
                                         mode=mode)
    drop_mask = tl.Parallel([], tl.Drop())  # Drop the mask.
    merge_heads = CombineHeadsPos(h=n_heads)
    output_projection = PreservePosition(tl.Dense(d_feature))

    return tl.Serial(
        dup_first,
        dup_second,
        project,
        per_head,
        attend,
        drop_mask,
        merge_heads,
        output_projection,
    )
Exemple #2
0
def policy_and_value_net(n_actions, bottom_layers_fn, two_towers):
  """Builds a model with one policy head and one value head.

  Args:
    n_actions: int: width of the policy head's dense layer.
    bottom_layers_fn: zero-argument callable returning the bottom layers; it is
      called once per tower when `two_towers` is set, otherwise once.
    two_towers: bool: whether each head gets its own copy of the bottom layers.

  Returns:
    A `tl.Model` over the assembled layer list.
  """

  # One head computes action log-probabilities, the other the value function.
  # NOTE: LogSoftmax is used instead of Softmax for numerical stability.
  def policy_head():
    return [tl.Dense(n_actions), tl.LogSoftmax()]

  def value_head():
    return [tl.Dense(1)]

  if two_towers:
    split = tl.Dup()
    policy_tower = [bottom_layers_fn()] + policy_head()
    value_tower = [bottom_layers_fn()] + value_head()
    return tl.Model([split, tl.Parallel(policy_tower, value_tower)])

  shared_bottom = bottom_layers_fn()
  split = tl.Dup()
  heads = tl.Parallel(policy_head(), value_head())
  return tl.Model([shared_bottom, split, heads])
Exemple #3
0
def policy_and_value_net(rng_key,
                         batch_observations_shape,
                         observations_dtype,
                         n_actions,
                         bottom_layers_fn=(),
                         two_towers=True):
  """A policy and value net function.

  Args:
    rng_key: random key passed to `net.initialize`.
    batch_observations_shape: input shape passed to `net.initialize`.
    observations_dtype: input dtype passed to `net.initialize`.
    n_actions: int: width of the policy head's dense layer.
    bottom_layers_fn: either a zero-argument callable returning the bottom
      layers (called once per tower), or a sequence of layers used as-is. The
      default `()` means "no bottom layers".
    two_towers: bool: whether each head gets its own bottom layers.

  Returns:
    A tuple `(params, state, net)` where `params` and `state` come from
    `net.initialize` and `net` is the `tl.Model`.
  """

  def bottom():
    # The default `()` is not callable, so calling it unconditionally (as the
    # code used to) raised TypeError. A non-callable value is used directly as
    # the layer sequence instead.
    # NOTE(review): assumes the combinators accept an empty sequence as "no
    # layers", like the `[]` placeholders used with tl.Parallel -- confirm.
    # A non-callable, non-empty sequence is shared by both towers.
    if callable(bottom_layers_fn):
      return bottom_layers_fn()
    return bottom_layers_fn

  # Layers.

  # Now, with the current logits, one head computes action probabilities and the
  # other computes the value function.
  # NOTE: The LogSoftmax instead of the Softmax because of numerical stability.

  if two_towers:
    layers = [
        tl.Dup(),
        tl.Parallel(
            [bottom(), tl.Dense(n_actions), tl.LogSoftmax()],
            [bottom(), tl.Dense(1)],
        )
    ]
  else:
    layers = [
        bottom(),
        tl.Dup(),
        tl.Parallel(
            [tl.Dense(n_actions), tl.LogSoftmax()],
            [tl.Dense(1)],
        )
    ]
  net = tl.Model(layers)
  params, state = net.initialize(batch_observations_shape, observations_dtype,
                                 rng_key)
  return params, state, net
Exemple #4
0
    def __init__(self, pre_attention, attention, post_attention):
        """Wires the three attention stages into forward and reverse lists.

        Args:
          pre_attention: layers applied to the top element before attention.
          attention: object that must expose `forward_and_backward`.
          post_attention: layers applied to the top element after attention.
        """
        # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
        duplicate_second = tl.Parallel([], tl.Dup())
        bring_copy_to_top = tl.Swap()
        run_pre_attention = tl.Parallel(pre_attention, [], [])
        self.pre_attention = tl.Serial(
            [duplicate_second, bring_copy_to_top, run_pre_attention])

        assert hasattr(attention, 'forward_and_backward')
        self.attention = ApplyAttentionWrapper(attention)
        self.post_attention = tl.Parallel(post_attention, [], [])

        # The forward and reverse passes share these stages and differ only in
        # the final combinator (Add vs. SubtractTop).
        shared_stages = [
            self.pre_attention,
            self.attention,
            self.post_attention,
        ]
        forward_stages = shared_stages + [tl.Parallel(tl.Add(), [])]
        super(ReversibleAttentionHalfResidual, self).__init__(forward_stages)

        self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
        self.reverse_layers = shared_stages + [self.subtract_top]
def DecoderBlock(d_feature, d_feedforward, n_heads, dropout, mode):
    """Builds the two residual sub-blocks of a Transformer decoder block.

    The input to the layer sequence is an activation tensor.

    Args:
      d_feature: int:  depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      A list of two `tl.Residual` layers (self-attention, then feed-forward)
      mapping an activation tensor to an activation tensor.
    """
    normalize = tl.LayerNorm()  # vec
    duplicate = tl.Dup()  # vec vec
    make_causal_mask = tl.Parallel([], tl.CausalMask(axis=-2))  # vec mask
    attend = tl.MultiHeadedAttention(d_feature,
                                     n_heads=n_heads,
                                     dropout=dropout,
                                     mode=mode)
    discard_mask = tl.Parallel([], tl.Drop())  # vec
    regularize = tl.Dropout(rate=dropout, mode=mode)  # vec
    attention_steps = [
        normalize,
        duplicate,
        make_causal_mask,
        attend,
        discard_mask,
        regularize,
    ]

    feed_forward_steps = [
        FeedForward(d_feature, d_feedforward, dropout, mode=mode),
    ]

    attention_residual = tl.Residual(attention_steps)
    feed_forward_residual = tl.Residual(feed_forward_steps)
    return [attention_residual, feed_forward_residual]
def EncoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
    """Transformer encoder layer.

    The input to the encoder is a pair (embedded source, mask) where
    the mask is created from the original source to prevent attending
    to the padding part of the input.

    Args:
      feature_depth: int:  depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      the layer, returning a pair (activations, mask).
    """
    # Attention block: every stage carries the mask alongside the activations.
    normalize = tl.Parallel(tl.LayerNorm(), tl.Copy())
    attend = tl.MultiHeadedAttention(feature_depth,
                                     num_heads=num_heads,
                                     dropout=dropout,
                                     mode=mode)
    regularize = tl.Parallel(tl.Dropout(rate=dropout, mode=mode), tl.Copy())
    attention_block = tl.Residual(normalize, attend, regularize)

    feed_forward = ResidualFeedForward(feature_depth,
                                       feedforward_depth,
                                       dropout,
                                       mode=mode)
    # The residual above added the mask to itself, so divide it by two.
    restore_mask = tl.Div(divisor=2.0)

    return tl.Serial(attention_block, tl.Parallel(feed_forward, restore_mask))
Exemple #7
0
 def lower(layer):
     """Apply layer below the current inputs, targets, and possibly weights."""
     # Pass-through slots: inputs and targets, plus loss weights when present.
     n_passthrough = 3 if self._has_weights else 2
     passthrough = [[] for _ in range(n_passthrough)]
     return layers.Parallel(*(passthrough + [layer]))
def policy_and_value_net(n_actions, n_controls, vocab_size, bottom_layers_fn,
                         two_towers):
    """Builds a policy-and-value model over one or two towers.

    Args:
      n_actions: int: size of the last axis of the policy output.
      n_controls: int: predictions per input when `vocab_size` is None.
      vocab_size: None for continuous policies; otherwise forwarded to
        `bottom_layers_fn` as the `vocab_size` keyword.
      bottom_layers_fn: callable returning the bottom layers.
      two_towers: bool: whether each head gets its own bottom layers.

    Returns:
      A `tl.Model` over the assembled layer list.
    """

    # One head computes action log-probabilities, the other the value function.
    # NOTE: LogSoftmax is used instead of Softmax for numerical stability.

    @tl.layer()
    def FlattenControlsIntoTime(x, **unused_kwargs):  # pylint: disable=invalid-name
        """Splits logits for actions in different controls and flattens controls."""
        batch = x.shape[0]
        return np.reshape(x, (batch, -1, n_actions))

    discrete = vocab_size is not None
    # Continuous: every element of the output sequence corresponds to an
    # observation. Discrete: every element corresponds to a symbol of the
    # discrete representation, and each control takes 1 symbol.
    n_preds_per_input = 1 if discrete else n_controls
    kwargs = {"vocab_size": vocab_size} if discrete else {}

    def policy_head():
        return [
            tl.Dense(n_preds_per_input * n_actions),
            FlattenControlsIntoTime(),  # pylint: disable=no-value-for-parameter
            tl.LogSoftmax(),
        ]

    def value_head():
        return [tl.Dense(n_preds_per_input), tl.Flatten()]

    if two_towers:
        split = tl.Dup()
        policy_tower = [bottom_layers_fn(**kwargs)] + policy_head()
        value_tower = [bottom_layers_fn(**kwargs)] + value_head()
        return tl.Model([split, tl.Parallel(policy_tower, value_tower)])

    shared_bottom = bottom_layers_fn(**kwargs)
    split = tl.Dup()
    heads = tl.Parallel(policy_head(), value_head())
    return tl.Model([shared_bottom, split, heads])
def Transformer(vocab_size,
                d_feature=512,
                d_feedforward=2048,
                n_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
    """Assembles the full encoder-decoder Transformer.

    This model expects on input a pair (source, target).

    Args:
      vocab_size: int: vocab size (shared source and target).
      d_feature: int:  depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_layers: int: number of encoder/decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      the Transformer model.
    """
    # The same embedding list is used on the encoder and the decoder side.
    embed_tokens = [
        tl.Embedding(d_feature, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

    embed_and_mask = tl.Branch(embed_tokens, tl.PaddingMask())
    encoder_blocks = [
        EncoderBlock(d_feature, d_feedforward, n_heads, dropout, mode)
        for _ in range(n_layers)
    ]
    encoder_side = [embed_and_mask, encoder_blocks, tl.LayerNorm()]

    pipeline = [
        tl.Parallel([], tl.ShiftRight()),
        tl.Parallel(encoder_side, embed_tokens),
        tl.Select(inputs=(('encoder', 'mask'), 'decoder'),
                  output=('decoder', ('mask', 'decoder'), 'encoder')),
        # (encoder_mask, decoder_input) -> encoder-decoder mask
        tl.Parallel([], tl.EncoderDecoderMask(), []),
        [
            EncoderDecoder(d_feature, d_feedforward, n_heads, dropout, mode)
            for _ in range(n_layers)
        ],
        tl.Select(0),  # Drop mask and encoder.
        tl.LayerNorm(),
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    ]
    return tl.Model(*pipeline)
def EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                        mode):
    """Transformer encoder-decoder layer.

    The input is a triple (encoder, mask, decoder_input) where
    the mask is created from the original source to prevent attending
    to the padding part of the encoder.

    Args:
      feature_depth: int:  depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      the layer, returning a triple (encoder, mask, decoder_activations).
    """
    attention_kwargs = dict(num_heads=num_heads, dropout=dropout, mode=mode)

    def passthrough(count):
        """Returns `count` fresh Identity layers."""
        return [layers.Identity() for _ in range(count)]

    # Decoder self-attending to decoder.
    decoder_self_attention = layers.Residual(
        layers.LayerNorm(),
        layers.Branch(),
        layers.Parallel(
            layers.Identity(),  # activation for (q, k, v)
            layers.CausalMask(axis=-2)),  # attention mask
        layers.MultiHeadedAttention(feature_depth, **attention_kwargs),
        layers.Dropout(rate=dropout, mode=mode))

    # Decoder attending to encoder.
    attend_to_encoder = layers.Serial(
        layers.Reorder(output=((2, 0, 0), 1)),  # ((dec, enc, enc), mask)
        # ((q, k, v), mask) --> new v
        layers.MultiHeadedAttentionQKV(feature_depth, **attention_kwargs),
        layers.Dropout(rate=dropout, mode=mode),
    )

    stages = [
        layers.Parallel(*(passthrough(2) + [decoder_self_attention])),
        layers.Branch(),
        layers.Parallel(layers.Identity(), attend_to_encoder),
        layers.UnnestBranches(),  # (encoder, mask, old_act, new_act)
        layers.Reorder(output=(0, 1, (2, 3))),
        # Residual after encoder-decoder attention.
        layers.Parallel(*(passthrough(2) + [layers.SumBranches()])),
        # Feed-forward on the third component (decoder).
        layers.Parallel(*(passthrough(2) + [
            ResidualFeedForward(feature_depth,
                                feedforward_depth,
                                dropout,
                                mode=mode)
        ])),
    ]
    return layers.Serial(*stages)
def Transformer(vocab_size,
                feature_depth=512,
                feedforward_depth=2048,
                num_layers=6,
                num_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
    """Assembles the full encoder-decoder Transformer.

    This model expects on input a pair (source, target).

    Args:
      vocab_size: int: vocab size (shared source and target).
      feature_depth: int:  depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_layers: int: number of encoder/decoder layers
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      the Transformer model.
    """
    # The same embedding layer is used on the encoder and the decoder side.
    shared_embedding = layers.Serial(
        layers.Embedding(feature_depth, vocab_size),
        layers.Dropout(rate=dropout, mode=mode),
        layers.PositionalEncoding(max_len=max_len))

    split_source = layers.Branch()  # Branch input to create embedding and mask.
    embed_and_mask = layers.Parallel(shared_embedding, layers.PaddingMask())
    encoder_layers = [
        EncoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                     mode) for _ in range(num_layers)
    ]
    encoder_side = layers.Serial(
        split_source,
        embed_and_mask,
        layers.Serial(*encoder_layers),
        layers.Parallel(layers.LayerNorm(), layers.Identity()))

    decoder_layers = [
        EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads,
                            dropout, mode) for _ in range(num_layers)
    ]

    pipeline = [
        layers.Parallel(layers.Identity(), layers.ShiftRight()),
        layers.Parallel(encoder_side, shared_embedding),
        layers.UnnestBranches(),  # (encoder, encoder_mask, decoder_input)
        layers.Reorder(output=(0, (1, 2), 2)),
        # (encoder_mask, decoder_input) -> encoder-decoder mask
        layers.Parallel(layers.Identity(), layers.EncoderDecoderMask(),
                        layers.Identity()),
        layers.Serial(*decoder_layers),
        layers.ThirdBranch(),
        layers.LayerNorm(),
        layers.Dense(vocab_size),
        layers.LogSoftmax(),
    ]
    return layers.Serial(*pipeline)
Exemple #12
0
def EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                        mode):
    """Transformer encoder-decoder layer.

    The input is a triple (encoder, mask, decoder_input) where
    the mask is created from the original source to prevent attending
    to the padding part of the encoder.

    Args:
      feature_depth: int:  depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      the layer, returning a triple (encoder, mask, decoder_activations).
    """
    attention_kwargs = dict(num_heads=num_heads, dropout=dropout, mode=mode)

    def passthrough(count):
        """Returns `count` fresh NoOp layers."""
        return [tl.NoOp() for _ in range(count)]

    # Decoder self-attending to decoder.
    decoder_self_attention = tl.Residual(
        tl.LayerNorm(),
        tl.Branch(tl.NoOp(), tl.CausalMask(axis=-2)),  # create mask
        tl.MultiHeadedAttention(feature_depth, **attention_kwargs),
        tl.Select(0),  # drop mask
        tl.Dropout(rate=dropout, mode=mode))

    # Decoder attending to encoder.
    attend_to_encoder = tl.Serial(
        tl.Select((2, 0, 0, 1)),  # (dec, enc, enc, mask)
        # (q, k, v, mask) --> new, mask
        tl.MultiHeadedAttentionQKV(feature_depth, **attention_kwargs),
        tl.Select(0),  # drop the mask
        tl.Dropout(rate=dropout, mode=mode),
    )

    stages = [
        tl.Parallel(*(passthrough(2) + [decoder_self_attention])),
        tl.Branch(tl.NoOp(), attend_to_encoder),
        tl.Select(inputs=(('encoder', 'mask', 'old_act'), 'new_act'),
                  output=('encoder', 'mask', ('old_act', 'new_act'))),
        # Residual after encoder-decoder attention.
        tl.Parallel(*(passthrough(2) + [tl.Add()])),
        # Feed-forward on the third component (decoder).
        tl.Parallel(*(passthrough(2) + [
            ResidualFeedForward(feature_depth,
                                feedforward_depth,
                                dropout,
                                mode=mode)
        ])),
    ]
    return tl.Serial(*stages)
Exemple #13
0
    def __init__(self, residual_layers):
        """Sets up the forward and reverse layer lists around `residual_layers`.

        Args:
          residual_layers: layers applied to the top element of the stack.
        """
        # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
        duplicate_second = tl.Parallel([], tl.Dup())
        bring_copy_to_top = tl.Swap()
        run_residual = tl.Parallel(residual_layers, [], [])
        self.compute_residual = tl.Serial(
            [duplicate_second, bring_copy_to_top, run_residual])

        # Forward pass adds the residual to the second element.
        forward_stages = [self.compute_residual, tl.Parallel(tl.Add(), [])]
        super(ReversibleHalfResidual, self).__init__(forward_stages)

        # Reverse pass recomputes the same residual and subtracts it instead.
        self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
        self.reverse_layers = [self.compute_residual, self.subtract_top]
 def loss(mask_id=None, has_weights=False):
     """Cross-entropy loss as scalar compatible with Trax masking."""
     # Swap from (pred-obs, pred-reward, target-obs, target-reward)
     # to (pred-obs, target-obs, pred-reward, target-reward).
     regroup = layers.Parallel([], layers.Swap())
     # Cross-entropy loss for obs, L2 loss on reward.
     observation_loss = layers.CrossEntropyLossScalar(mask_id, has_weights)
     reward_loss = layers.L2LossScalar(mask_id, has_weights)
     per_output_losses = layers.Parallel(observation_loss, reward_loss)
     # Add both losses.
     total = layers.Add()
     # Zero out in this test.
     zero_out = layers.MulConstant(constant=0.0)
     return layers.Serial(regroup, per_output_losses, total, zero_out)
Exemple #15
0
def Transformer(vocab_size,
                feature_depth=512,
                feedforward_depth=2048,
                num_layers=6,
                num_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
    """Assembles the full encoder-decoder Transformer.

    This model expects on input a pair (source, target).

    Args:
      vocab_size: int: vocab size (shared source and target).
      feature_depth: int:  depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_layers: int: number of encoder/decoder layers
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      the Transformer model.
    """
    # The same embedding layer is used on the encoder and the decoder side.
    shared_embedding = tl.Serial(
        tl.Embedding(feature_depth, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len))

    embed_and_mask = tl.Branch(shared_embedding, tl.PaddingMask())
    encoder_layers = [
        EncoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                     mode) for _ in range(num_layers)
    ]
    encoder_side = tl.Serial(
        embed_and_mask,
        tl.Serial(*encoder_layers),
        tl.Parallel(tl.LayerNorm(), tl.NoOp()))

    decoder_layers = [
        EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads,
                            dropout, mode) for _ in range(num_layers)
    ]

    pipeline = [
        tl.Parallel(tl.NoOp(), tl.ShiftRight()),
        tl.Parallel(encoder_side, shared_embedding),
        tl.Select(inputs=(('encoder', 'mask'), 'decoder'),
                  output=('encoder', ('mask', 'decoder'), 'decoder')),
        # (encoder_mask, decoder_input) -> encoder-decoder mask
        tl.Parallel(tl.NoOp(), tl.EncoderDecoderMask(), tl.NoOp()),
        tl.Serial(*decoder_layers),
        tl.Select(2),  # Drop encoder and mask.
        tl.LayerNorm(),
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    ]
    return tl.Serial(*pipeline)
 def model(mode):
   """Builds a small model with an observation output and a reward output."""
   del mode
   encode_inputs = layers.Parallel(
       layers.Flatten(),  # Observation stack.
       layers.Embedding(d_feature=1, vocab_size=n_actions),  # Action.
   )
   merge = layers.Concatenate()
   bottleneck = layers.Dense(n_units=1)
   split = layers.Dup()
   heads = layers.Parallel(
       layers.Dense(n_units=obs_shape[1]),  # New observation.
       None,  # Reward.
   )
   return layers.Serial(encode_inputs, merge, bottleneck, split, heads)
def EncoderDecoder(d_feature, d_feedforward, n_heads, dropout, mode):
    """Transformer encoder-decoder layer.

    The input is a triple (decoder_input, mask, encoder) where the mask is
    created from the original source to prevent attending to the padding part
    of the encoder.

    Args:
      d_feature: int:  depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      a list of three residual layers; together they return a triple
      (decoder_activations, mask, encoder).
    """
    attention_kwargs = dict(n_heads=n_heads, dropout=dropout, mode=mode)

    # Stack on entry: vecs_d pmask vecs_e
    self_attention_steps = [
        tl.LayerNorm(),  # vecs_d ..... ......
        tl.Dup(),  # vecs_d vecs_d ..... ......
        tl.Parallel([], tl.CausalMask(axis=-2)),  # ______ masks ..... ......
        tl.MultiHeadedAttention(d_feature, **attention_kwargs),
        tl.Parallel([], tl.Drop()),  # ______ 0 ..... ......
        tl.Dropout(rate=dropout, mode=mode),  # vecs_d ..... ......
    ]

    # Stack on entry: vecs_d masks vecs_e
    cross_attention_steps = [
        tl.Parallel([], [], tl.Dup()),  # ______ _____ vecs_e vecs_e
        tl.Parallel([], tl.Swap()),  # ______ vecs_e masks ......
        tl.Parallel([], tl.Dup()),  # ______ vecs_e vecs_e ..... ......
        # (q k v masks ... --> vecs_d masks ...)
        tl.MultiHeadedAttentionQKV(d_feature, **attention_kwargs),
        tl.Dropout(rate=dropout, mode=mode),  # vecs_d mask vecs_e
    ]

    feed_forward_steps = [
        FeedForward(d_feature, d_feedforward, dropout, mode=mode),
    ]

    # Each residual maps (vecs_d masks vecs_e) to (vecs_d masks vecs_e).
    sub_blocks = (self_attention_steps, cross_attention_steps,
                  feed_forward_steps)
    return [tl.Residual(steps) for steps in sub_blocks]
Exemple #18
0
    def Encoder(source, source_mask):
        """Transformer encoder stack.

        Args:
          source: layer variable: raw source sequences
          source_mask: layer variable: self-attention mask

        Returns:
          Layer variable that outputs encoded source.
        """
        # Input attends to self: four branches feed query, key, value and the
        # attention mask.
        attention_inputs = layers.Parallel(
            layers.Identity(),  # query
            layers.Identity(),  # key
            layers.Identity(),  # value
            source_mask)  # attention mask
        self_attention = layers.Residual(
            layers.LayerNorm(),
            layers.Branch(size=4),
            attention_inputs,
            multi_attention,
            layers.Dropout(dropout, mode=mode))
        feed_forward = ResidualFeedForward(feature_depth,
                                           feedforward_depth,
                                           dropout,
                                           mode=mode)
        encoder_layer = layers.Serial(self_attention, feed_forward)

        stacked_layers = layers.repeat(encoder_layer, num_layers)
        final_norm = layers.LayerNorm()
        return layers.Serial(source, source_embedding_layer, stacked_layers,
                             final_norm)
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 n_attention_chunks, attention_type, dropout, mode):
    """Reversible transformer decoder layer.

    Args:
      d_model: int:  depth of embedding
      d_ff: int: depth of feed-forward layer
      d_attention_key: int: depth of key vector for each attention head
      d_attention_value: int: depth of value vector for each attention head
      n_heads: int: number of attention heads
      n_attention_chunks: int: number of chunks for attention
      attention_type: class: attention class to use, such as DotProductAttention.
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      a list of four layers: the reversible attention half-residual, a swap,
      the reversible feed-forward half-residual, and a second swap.
    """
    chunk = Chunk(n_sections=n_attention_chunks)  # pylint: disable=no-value-for-parameter
    normalize = tl.LayerNorm()
    dup_first = tl.Dup()
    dup_second = tl.Dup()
    # One head projection per branch; the first two use the key depth and the
    # third uses the value depth.
    head_depths = (d_attention_key, d_attention_key, d_attention_value)
    head_branches = [
        [tl.ComputeAttentionHeads(n_heads=n_heads, d_head=depth)]
        for depth in head_depths
    ]
    pre_attention = [
        chunk,
        normalize,
        dup_first,
        dup_second,
        tl.Parallel(*head_branches),
    ]

    attention = attention_type(mode=mode)

    # ReversibleAttentionHalfResidual requires that post_attention be linear in
    # its input (so the backward pass can be computed without knowing the input)
    post_attention = [
        tl.ComputeAttentionOutput(n_heads=n_heads, d_model=d_model),
        Unchunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
    ]

    feed_forward = [FeedForward(d_model, d_ff, dropout, mode=mode)]

    attention_half = ReversibleAttentionHalfResidual(pre_attention, attention,
                                                     post_attention)
    swap_after_attention = tl.ReversibleSwap()
    feed_forward_half = ReversibleHalfResidual(feed_forward)
    swap_after_feed_forward = tl.ReversibleSwap()
    return [
        attention_half,
        swap_after_attention,
        feed_forward_half,
        swap_after_feed_forward,
    ]
def DecoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
    """Transformer decoder layer.

    Args:
      feature_depth: int:  depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      the layer: a self-attention residual followed by a residual feed-forward.
    """
    # Self-attention block.
    normalize = layers.LayerNorm()
    split = layers.Branch()
    activation_and_mask = layers.Parallel(
        layers.Identity(),  # activation for (q, k, v)
        layers.CausalMask(axis=-2))  # attention mask
    attend = layers.MultiHeadedAttention(feature_depth,
                                         num_heads=num_heads,
                                         dropout=dropout,
                                         mode=mode)
    regularize = layers.Dropout(rate=dropout, mode=mode)
    self_attention = layers.Residual(normalize, split, activation_and_mask,
                                     attend, regularize)

    feed_forward = ResidualFeedForward(feature_depth,
                                       feedforward_depth,
                                       dropout,
                                       mode=mode)
    return layers.Serial(self_attention, feed_forward)
Exemple #21
0
def SumLearnedPick(positions):
    """Get a pair (vec, pos) and pick new pos.

    Builds five `LearnedQP` branches: one with default arguments, two using
    row-shifted slices of `positions` as key/value tables, and two binary ones
    keyed on concatenated row pairs. The branches are combined with
    `SoftmaxBranches`.
    """
    # Shift-by-one tables: each row maps to its successor / predecessor row.
    successor_keys = positions[:-1, :]
    successor_values = positions[1:, :]
    predecessor_keys = positions[1:, :]
    predecessor_values = positions[:-1, :]

    half = int(positions.shape[0]) // 2
    # (i, j) index pairs; the two tables enumerate them in different orders.
    add_pairs = [(i, j) for i in range(half) for j in range(half)]
    # TODO(lukaszkaiser): try this below: "for j in range(i) for i in range(2*l)"
    sub_pairs = [(i, j) for j in range(half) for i in range(half)]

    def pair_keys(pairs):
        """Concatenates rows i and j of `positions` for each (i, j) pair."""
        return np.array([
            np.concatenate([positions[i, :], positions[j, :]])
            for i, j in pairs
        ])

    add_keys = pair_keys(add_pairs)
    add_values = np.array([positions[i + j, :] for i, j in add_pairs])
    sub_keys = pair_keys(sub_pairs)
    # Differences are clamped at row 0.
    sub_values = np.array(
        [positions[max(i - j, 0), :] for i, j in sub_pairs])

    copies = [tl.Dup() for _ in range(4)]
    candidates = tl.Parallel(
        LearnedQP(),
        LearnedQP(keys=successor_keys, values=successor_values),
        LearnedQP(keys=predecessor_keys, values=predecessor_values),
        LearnedQP(keys=add_keys, values=add_values, binary=True),
        LearnedQP(keys=sub_keys, values=sub_values, binary=True),
    )
    stages = copies + [candidates, Unnest(), SoftmaxBranches(n_branches=5)]
    return tl.Serial(*stages)
Exemple #22
0
def AtariCnn(hidden_sizes=(32, 32), output_size=128, mode='train'):
    """An Atari CNN.

    Args:
      hidden_sizes: sequence whose first two entries are the filter counts of
        the two convolutions.
      output_size: int: width of the final dense layer.
      mode: unused.

    Returns:
      A `tl.Model` mapping (B, T, H, W, C) input to (B, T, output_size).
    """
    del mode

    # TODO(jonni): Include link to paper?
    # Input shape: (B, T, H, W, C)
    # Output shape: (B, T, output_size)
    preprocess = [tl.ToFloat(), tl.Div(divisor=255.0)]

    # Set up 4 successive game frames, concatenated on the last axis.
    frame_copies = [tl.Dup() for _ in range(3)]
    shift_frames = tl.Parallel(None, *[_shift_right(k) for k in (1, 2, 3)])
    stack_frames = tl.Concatenate(n_items=4, axis=-1)  # (B, T, H, W, 4C)

    conv_stack = []
    for n_filters in (hidden_sizes[0], hidden_sizes[1]):
        conv_stack.append(tl.Conv(n_filters, (5, 5), (2, 2), 'SAME'))
        conv_stack.append(tl.Relu())

    head = [
        tl.Flatten(n_axes_to_keep=2),  # B, T and rest.
        tl.Dense(output_size),
        tl.Relu(),
    ]

    pipeline = (preprocess + frame_copies + [shift_frames, stack_frames] +
                conv_stack + head)
    return tl.Model(*pipeline)
Exemple #23
0
def DecoderLayer(positions, d_feature, d_feedforward, n_heads, dropout, mode):
    """Transformer decoder layer.

    Args:
      positions: random vectors for positions
      d_feature: int:  depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      the layer, as a list: a self-attention residual followed by a residual
      feed-forward.
    """
    # Self-attention block.
    normalize = PreservePosition(tl.LayerNorm())
    duplicate = tl.Dup()
    activation_and_mask = tl.Parallel(
        [],  # activation for (q, k, v)
        tl.CausalMask(axis=-2))  # attention mask
    attend = MultiHeadedAttentionPosition(positions,
                                          d_feature,
                                          n_heads=n_heads,
                                          dropout=dropout,
                                          mode=mode)
    regularize = PreservePosition(tl.Dropout(rate=dropout, mode=mode))
    self_attention = tl.Residual(normalize, duplicate, activation_and_mask,
                                 attend, regularize)

    feed_forward = ResidualFeedForward(d_feature,
                                       d_feedforward,
                                       dropout,
                                       mode=mode)
    return [self_attention, feed_forward]
def DecoderBlock(d_feature, d_feedforward, n_heads, n_attention_chunks,
                 attention_loop_stride, dropout, mode):
    """Builds one reversible Transformer decoder layer.

    Args:
        d_feature: int: depth of embedding
        d_feedforward: int: depth of feed-forward layer
        n_heads: int: number of attention heads
        n_attention_chunks: int: number of chunks for attention
        attention_loop_stride: int: number of query elements to compute
            attention for in parallel. Set to 0 to disable memory-efficient
            attention.
        dropout: float: dropout rate (how much to drop out); only forwarded to
            the feed-forward sub-layer, attention receives `dropout=None`.
        mode: str: 'train' or 'eval'

    Returns:
        A list of four layers: the reversible attention half-residual, a swap,
        the reversible feed-forward half-residual, and a second swap.
    """
    # Chunk, normalize, then make three copies of the activations.
    pre_attention = [
        Chunk(sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(),
        tl.Dup(),
        tl.Dup(),
    ]
    # Each copy gets its own Dense projection followed by a head split.
    projections = [
        [tl.Dense(d_feature), SplitHeads(n_heads=n_heads)]  # pylint: disable=no-value-for-parameter
        for _ in range(3)
    ]
    pre_attention.append(tl.Parallel(*projections))

    # TODO(kitaev): add dropout
    # Use the standard implementation if no loop_stride is provided.
    use_standard_attention = attention_loop_stride < 1
    attention = (
        DotProductAttention(dropout=None, mode=mode)
        if use_standard_attention else
        MemoryEfficientDotProductAttention(
            loop_stride=attention_loop_stride, dropout=None, mode=mode))

    # ReversibleAttentionHalfResidual requires that post_attention be linear in
    # its input (so the backward pass can be computed without knowing the input)
    post_attention = [
        JoinHeads(),  # pylint: disable=no-value-for-parameter
        tl.Dense(d_feature),
        Unchunk(sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
    ]

    feed_forward = [FeedForward(d_feature, d_feedforward, dropout, mode=mode)]

    attention_half = ReversibleAttentionHalfResidual(
        pre_attention, attention, post_attention)
    first_swap = ReversibleSwap()
    feed_forward_half = ReversibleHalfResidual(feed_forward)
    second_swap = ReversibleSwap()
    return [attention_half, first_swap, feed_forward_half, second_swap]
Exemple #25
0
def TransformerRevnetLM(vocab_size,
                        d_feature=512,
                        d_feedforward=2048,
                        d_attention_key=64,
                        d_attention_value=64,
                        n_layers=6,
                        n_heads=8,
                        dropout=0.1,
                        max_len=2048,
                        n_chunks=32,
                        n_attention_chunks=8,
                        attention_loop_stride=0,
                        mode='train'):
    """Builds a reversible Transformer language model (decoder only).

    Args:
        vocab_size: int: vocab size
        d_feature: int: depth of *each half* of the two-part features
        d_feedforward: int: depth of feed-forward layer
        d_attention_key: int: depth of key vector for each attention head
        d_attention_value: int: depth of value vector for each attention head
        n_layers: int: number of decoder layers
        n_heads: int: number of attention heads
        dropout: float: dropout rate (how much to drop out)
        max_len: int: maximum symbol length for positional encoding
        n_chunks: int: number of chunks (must match input pipeline)
        n_attention_chunks: int: number of chunks for attention
        attention_loop_stride: int: number of query elements to compute
            attention for in parallel. Set to 0 to disable memory-efficient
            attention.
        mode: str: 'train' or 'eval'

    Returns:
        A `tl.Model` for the language model.
    """
    # TODO(kitaev): add dropout
    token_embedding = tl.Embedding(d_feature, vocab_size)
    position_encoding = tl.PositionalEncoding(max_len=max_len)
    positional_embedder = [token_embedding, position_encoding]

    # NOTE: DecoderBlock is called with nine positional arguments here, so it
    # must be a variant that accepts the key/value depths.
    decoder_blocks = []
    for _ in range(n_layers):
        decoder_blocks.append(
            DecoderBlock(d_feature, d_feedforward, d_attention_key,
                         d_attention_value, n_heads, n_attention_chunks,
                         attention_loop_stride, dropout, mode))

    # Applied to every chunk by Map after the sequence is split again.
    per_chunk_head = [tl.Dense(vocab_size), tl.LogSoftmax()]

    return tl.Model(
        tl.Concatenate(n_items=n_chunks),
        tl.ShiftRight(),
        positional_embedder,
        tl.Dup(),
        ReversibleSerial(decoder_blocks),
        tl.Parallel(tl.LayerNorm(), tl.LayerNorm()),
        tl.Concatenate(),
        Split(sections=n_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
        Map(per_chunk_head, sections=n_chunks),
    )
Exemple #26
0
    def Decoder(memory, target, target_mask, memory_mask):
        """Builds the Transformer decoder stack.

        Relies on names from the enclosing scope: `multi_attention`,
        `dropout`, `mode`, `feature_depth`, `feedforward_depth`,
        `target_embedding_layer` and `num_layers`.

        Args:
            memory: layer variable: encoded source sequences
            target: layer variable: raw target sequences
            target_mask: layer variable: self-attention mask
            memory_mask: layer variable: memory attention mask

        Returns:
            A `layers.Serial` applying `target`, the target embedding, the
            repeated decoder layer and a final LayerNorm.
        """

        def attention_residual(key, value, mask):
            """Residual block: LayerNorm, 4-way branch, attention, dropout."""
            return layers.Residual(
                layers.LayerNorm(),
                layers.Branch(size=4),
                layers.Parallel(
                    layers.Identity(),  # query
                    key,
                    value,
                    mask),  # attention mask
                multi_attention,
                layers.Dropout(dropout, mode=mode))

        # target attends to self
        self_attention = attention_residual(
            layers.Identity(), layers.Identity(), target_mask)
        # target attends to encoded source
        memory_attention = attention_residual(memory, memory, memory_mask)
        # feed-forward
        feed_forward = ResidualFeedForward(
            feature_depth, feedforward_depth, dropout, mode=mode)

        decoder_layer = layers.Serial(
            self_attention, memory_attention, feed_forward)

        return layers.Serial(
            target,
            target_embedding_layer,
            layers.repeat(decoder_layer, num_layers),
            layers.LayerNorm(),
        )
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       max_len=2048,
                       mode='train'):
    """Builds a Transformer encoder classifier.

    The input to the model is a tensor of tokens.

    Args:
        vocab_size: int: vocab size
        n_classes: how many classes on output
        d_model: int: depth of embedding
        d_ff: int: depth of feed-forward layer
        n_layers: int: number of encoder/decoder layers
        n_heads: int: number of attention heads
        dropout: float: dropout rate (how much to drop out)
        max_len: int: maximum symbol length for positional encoding
        mode: str: 'train' or 'eval'

    Returns:
        A Transformer model as a layer that maps from a tensor of tokens to
        activations over a set of output classes.
    """
    # Token embedding, then dropout, then positional encoding.
    embedder = [
        tl.Embedding(d_model, vocab_size),
        tl.Dropout(rate=dropout, name='emb_dropout', mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

    # Stack annotations on the right show the data after each step,
    # starting from a single `tokens` input.
    stack = [
        tl.Dup(),  # toks toks
        tl.Parallel(embedder, tl.PaddingMask()),  # vecs mask
    ]

    # Each block also receives its index in the stack.
    encoder_blocks = []
    for layer_index in range(n_layers):
        encoder_blocks.append(
            EncoderBlock(d_model, d_ff, n_heads, dropout, layer_index, mode))
    stack.append(encoder_blocks)  # vecs mask

    stack.extend([
        tl.Parallel([], tl.Drop()),  # Drop the mask, keep vecs.
        tl.LayerNorm(),  # vecs
        tl.Mean(axis=1),  # Average on length.
        tl.Dense(n_classes),  # vecs
        tl.LogSoftmax(),  # vecs
    ])
    return tl.Model(stack)
Exemple #28
0
def policy_and_value_net(n_actions, n_controls, bottom_layers_fn, two_towers):
    """Builds a model with a policy head and a value head.

    The policy head ends in LogSoftmax rather than Softmax for numerical
    stability.

    Args:
        n_actions: size of the last axis of the policy output; the policy
            Dense layer has `n_controls * n_actions` units.
        n_controls: number of units of the value Dense layer.
        bottom_layers_fn: zero-argument callable returning the bottom layers.
            Called once when `two_towers` is falsy, twice otherwise.
        two_towers: if truthy, each head gets its own bottom layers;
            otherwise both heads share a single bottom.

    Returns:
        A `tl.Model` wrapping the assembled layer list.
    """

    @tl.layer()
    def FlattenControlsIntoTime(x, **unused_kwargs):  # pylint: disable=invalid-name
        """Reshapes `x` to (x.shape[0], -1, n_actions)."""
        leading_dim = x.shape[0]
        return np.reshape(x, (leading_dim, -1, n_actions))

    def policy_head():
        """Dense over all control/action logits, reshape, LogSoftmax."""
        return [
            tl.Dense(n_controls * n_actions),
            FlattenControlsIntoTime(),  # pylint: disable=no-value-for-parameter
            tl.LogSoftmax(),
        ]

    def value_head():
        """Dense with n_controls units followed by Flatten."""
        return [tl.Dense(n_controls), tl.Flatten()]

    if not two_towers:
        # Shared bottom: run it once, then fan out into the two heads.
        shared = [
            bottom_layers_fn(),
            tl.Dup(),
            tl.Parallel(policy_head(), value_head()),
        ]
        return tl.Model(shared)

    # Separate towers: duplicate the input, each copy gets its own bottom.
    fan_out = tl.Dup()
    policy_tower = [bottom_layers_fn()] + policy_head()
    value_tower = [bottom_layers_fn()] + value_head()
    return tl.Model([fan_out, tl.Parallel(policy_tower, value_tower)])
Exemple #29
0
def TransformerRevnetLM(vocab_size,
                        d_model=512,
                        d_ff=2048,
                        d_attention_key=64,
                        d_attention_value=64,
                        n_layers=6,
                        n_heads=8,
                        dropout=0.1,
                        max_len=2048,
                        n_chunks=32,
                        n_attention_chunks=8,
                        attention_type=DotProductAttention,
                        mode='train'):
    """Builds a reversible Transformer language model (decoder only).

    Args:
        vocab_size: int: vocab size
        d_model: int: depth of *each half* of the two-part features
        d_ff: int: depth of feed-forward layer
        d_attention_key: int: depth of key vector for each attention head
        d_attention_value: int: depth of value vector for each attention head
        n_layers: int: number of decoder layers
        n_heads: int: number of attention heads
        dropout: float: dropout rate (how much to drop out)
        max_len: int: maximum symbol length for positional encoding
        n_chunks: int: number of chunks (must match input pipeline)
        n_attention_chunks: int: number of chunks for attention
        attention_type: class: attention class to use, such as
            DotProductAttention.
        mode: str: 'train' or 'eval'

    Returns:
        A `tl.Model` for the language model.
    """
    # Token embedding, broadcasted dropout, then positional encoding.
    token_embedding = tl.Embedding(d_model, vocab_size)
    embedding_dropout = BroadcastedDropout(rate=dropout, mode=mode)  # pylint: disable=no-value-for-parameter
    position_encoding = tl.PositionalEncoding(max_len=max_len)
    positional_embedder = [
        token_embedding, embedding_dropout, position_encoding]

    # n_layers decoder blocks, all built with the same arguments.
    decoder_blocks = []
    for _ in range(n_layers):
        decoder_blocks.append(
            DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value,
                         n_heads, n_attention_chunks, attention_type, dropout,
                         mode))

    # Applied to every chunk by Map after the sequence is split again.
    per_chunk_head = [tl.Dense(vocab_size), tl.LogSoftmax()]

    return tl.Model(
        tl.Concatenate(n_items=n_chunks),
        tl.ShiftRight(),
        positional_embedder,
        tl.Dup(),
        tl.ReversibleSerial(decoder_blocks),
        tl.Parallel(tl.LayerNorm(), tl.LayerNorm()),
        tl.Concatenate(),
        Split(n_sections=n_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
        Map(per_chunk_head, n_sections=n_chunks),
    )
Exemple #30
0
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    """Builds a WideResnet convolutional block.

    Args:
        channels: first argument of every `layers.Conv` in the block.
        strides: strides passed to the first main-path convolution and, when
            present, to the shortcut convolution.
        channel_mismatch: if truthy, the shortcut is a 3x3 convolution instead
            of an identity.

    Returns:
        A `layers.Serial` that branches the input, runs the main path and the
        shortcut in parallel, and sums the two results.
    """
    kernel = (3, 3)

    # Main path: two BatchNorm -> Relu -> Conv stages; only the first conv
    # receives `strides`.
    main_stages = [
        layers.BatchNorm(),
        layers.Relu(),
        layers.Conv(channels, kernel, strides, padding='SAME'),
        layers.BatchNorm(),
        layers.Relu(),
        layers.Conv(channels, kernel, padding='SAME'),
    ]
    main = layers.Serial(*main_stages)

    if channel_mismatch:
        shortcut = layers.Conv(channels, kernel, strides, padding='SAME')
    else:
        shortcut = layers.Identity()

    return layers.Serial(
        layers.Branch(),
        layers.Parallel(main, shortcut),
        layers.SumBranches(),
    )