def MultiHeadedAttentionPosition(positions, d_feature, n_heads=8, dropout=0.0,
                                 mode='train'):
    """Transformer-style multi-headed attention that also carries positions.

    The input stream is copied three times (query, key, value). Each copy gets
    its own Dense projection. The query projection is wrapped in
    `ApplyAndQueryPositions` with one `SumLearnedPick(positions)` per head. The
    key and value projections are wrapped in `PreservePosition`.

    Args:
      positions: random vectors for positions, forwarded to `SumLearnedPick`.
      d_feature: int: depth of embedding.
      n_heads: int: number of attention heads.
      dropout: float: dropout rate passed to the pure attention layer.
      mode: str: 'train' or 'eval'.

    Returns:
      A `tl.Serial` layer implementing the attention block.
    """
    qkv_projections = tl.Parallel(
        ApplyAndQueryPositions(
            tl.Dense(d_feature),
            pos=[SumLearnedPick(positions) for _ in range(n_heads)]),
        PreservePosition(tl.Dense(d_feature)),
        PreservePosition(tl.Dense(d_feature)),
    )
    per_head_split = tl.Parallel(
        CopyHeadsPos(h=n_heads),
        MixHeadsPos(h=n_heads),
        MixHeadsPos(h=n_heads),
    )
    pure_attention = tl.PureMultiHeadedAttention(
        d_feature=d_feature, n_heads=n_heads, dropout=dropout, mode=mode)
    return tl.Serial(
        tl.Dup(),
        tl.Dup(),
        qkv_projections,
        per_head_split,
        pure_attention,
        tl.Parallel([], tl.Drop()),  # Drop the mask.
        CombineHeadsPos(h=n_heads),
        PreservePosition(tl.Dense(d_feature)),
    )
def policy_and_value_net(n_actions, bottom_layers_fn, two_towers):
    """Builds a network with a policy head and a value head.

    The policy head emits log-probabilities over `n_actions` actions. It uses
    LogSoftmax rather than Softmax for numerical stability. The value head is a
    width-1 Dense layer.

    Args:
      n_actions: int: number of discrete actions.
      bottom_layers_fn: callable returning the trunk layers placed below the
        heads. It is called once per tower.
      two_towers: bool: if True, policy and value each get their own trunk
        (bottom_layers_fn is called twice). Otherwise one trunk is shared.

    Returns:
      A `tl.Model` whose outputs are (policy log-probs, value).
    """
    if not two_towers:
        shared_trunk = bottom_layers_fn()
        heads = tl.Parallel(
            [tl.Dense(n_actions), tl.LogSoftmax()],
            [tl.Dense(1)],
        )
        return tl.Model([shared_trunk, tl.Dup(), heads])

    separate_towers = tl.Parallel(
        [bottom_layers_fn(), tl.Dense(n_actions), tl.LogSoftmax()],
        [bottom_layers_fn(), tl.Dense(1)],
    )
    return tl.Model([tl.Dup(), separate_towers])
def policy_and_value_net(rng_key,
                         batch_observations_shape,
                         observations_dtype,
                         n_actions,
                         bottom_layers_fn=(),
                         two_towers=True):
    """Builds and initializes a network with policy and value heads.

    The policy head emits log-probabilities over `n_actions` actions. It uses
    LogSoftmax rather than Softmax for numerical stability. The value head is a
    width-1 Dense layer.

    Args:
      rng_key: random key passed to `net.initialize`.
      batch_observations_shape: input shape passed to `net.initialize`.
      observations_dtype: input dtype passed to `net.initialize`.
      n_actions: int: number of discrete actions.
      bottom_layers_fn: callable returning the trunk layers placed below the
        heads. It is called once per tower. An empty value, including the
        default `()`, means "no trunk layers".
      two_towers: bool: if True, policy and value each get their own trunk.
        Otherwise one trunk is shared.

    Returns:
      A triple (params, state, net). `net` is the `tl.Model`, and
      `params`/`state` are the values returned by `net.initialize`.

    Raises:
      TypeError: if `bottom_layers_fn` is neither callable nor empty.
    """
    # The default `()` is not callable, so calling it unconditionally always
    # raised TypeError. Accept an empty value as "no trunk layers" and reject
    # anything else that cannot be called.
    if bottom_layers_fn and not callable(bottom_layers_fn):
        raise TypeError('bottom_layers_fn must be callable or empty, got %r'
                        % (bottom_layers_fn,))

    def make_bottom_layers():
        # An empty list contributes no layers to the surrounding Serial.
        return bottom_layers_fn() if callable(bottom_layers_fn) else []

    if two_towers:
        layers = [
            tl.Dup(),
            tl.Parallel(
                [make_bottom_layers(), tl.Dense(n_actions), tl.LogSoftmax()],
                [make_bottom_layers(), tl.Dense(1)],
            )
        ]
    else:
        layers = [
            make_bottom_layers(),
            tl.Dup(),
            tl.Parallel(
                [tl.Dense(n_actions), tl.LogSoftmax()],
                [tl.Dense(1)],
            )
        ]
    net = tl.Model(layers)
    params, state = net.initialize(batch_observations_shape, observations_dtype,
                                   rng_key)
    return params, state, net
def __init__(self, pre_attention, attention, post_attention):
    """Builds a reversible half-residual block around an attention layer.

    Forward layers run in this order:
      1. `self.pre_attention`
      2. the wrapped attention
      3. `self.post_attention`
      4. `tl.Add` on the top two stack elements.
    The reverse layers (`self.reverse_layers`) reuse the same first three
    stages but finish with `tl.SubtractTop` instead of `tl.Add`.

    Args:
      pre_attention: layer(s) applied to the copied second stack element
        before attention.
      attention: attention layer; must define `forward_and_backward`.
      post_attention: layer(s) applied to the attention output.

    Raises:
      TypeError: if `attention` has no `forward_and_backward` attribute.
    """
    # Stack: (x1_or_y1, x2) -> (x2, x1_or_y1, x2). pre_attention then runs on
    # the top copy of x2 while the other two elements pass through.
    self.pre_attention = tl.Serial([
        tl.Parallel([], tl.Dup()),
        tl.Swap(),
        tl.Parallel(pre_attention, [], []),
    ])
    # Checked explicitly rather than with `assert`, so the check still runs
    # under `python -O`.
    if not hasattr(attention, 'forward_and_backward'):
        raise TypeError('attention must define forward_and_backward; got %s'
                        % type(attention).__name__)
    self.attention = ApplyAttentionWrapper(attention)
    self.post_attention = tl.Parallel(post_attention, [], [])

    layers = [
        self.pre_attention,
        self.attention,
        self.post_attention,
        tl.Parallel(tl.Add(), []),
    ]
    super(ReversibleAttentionHalfResidual, self).__init__(layers)

    self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
    self.reverse_layers = [
        self.pre_attention,
        self.attention,
        self.post_attention,
        self.subtract_top,
    ]
def DecoderBlock(d_feature, d_feedforward, n_heads, dropout, mode):
    """Returns the layers of one Transformer decoder block.

    The block maps an activation tensor to an activation tensor. It has two
    residual sub-blocks: self-attention under a causal mask, then a
    feed-forward network.

    Args:
      d_feature: int: depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      A list of two `tl.Residual` layers.
    """
    causal_self_attention = [
        tl.LayerNorm(),                           # vec
        tl.Dup(),                                 # vec vec
        tl.Parallel([], tl.CausalMask(axis=-2)),  # vec mask
        tl.MultiHeadedAttention(
            d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
        tl.Parallel([], tl.Drop()),               # vec
        tl.Dropout(rate=dropout, mode=mode),      # vec
    ]
    ff_block = [FeedForward(d_feature, d_feedforward, dropout, mode=mode)]
    return [tl.Residual(causal_self_attention), tl.Residual(ff_block)]
def EncoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
    """Transformer encoder layer.

    The input is a pair (embedded source, mask). The mask is created from the
    original source so that attention ignores the padding part of the input.

    Args:
      feature_depth: int: depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      The layer, which returns a pair (activations, mask).
    """
    attention_residual = tl.Residual(
        tl.Parallel(tl.LayerNorm(), tl.Copy()),
        tl.MultiHeadedAttention(
            feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
        tl.Parallel(tl.Dropout(rate=dropout, mode=mode), tl.Copy()),
    )
    feed_forward_and_mask_fix = tl.Parallel(
        ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                            mode=mode),
        # The residual above also adds the mask to itself; halve it to undo that.
        tl.Div(divisor=2.0),
    )
    return tl.Serial(attention_residual, feed_forward_and_mask_fix)
def lower(layer):
    """Wraps `layer` so it runs below the inputs, targets and maybe weights.

    The leading stack elements pass through unchanged via empty branches. These
    are the inputs and targets, plus the loss weights when `self._has_weights`
    is set. `layer` is applied to what lies beneath them.
    """
    if not self._has_weights:
        # Pass inputs and targets through; apply `layer` underneath them.
        return layers.Parallel([], [], layer)
    # Pass inputs, targets and loss weights through; apply `layer` underneath.
    return layers.Parallel([], [], [], layer)
def policy_and_value_net(n_actions, n_controls, vocab_size, bottom_layers_fn,
                         two_towers):
    """Builds a policy-and-value network for multi-control action spaces.

    The policy head produces log-probabilities (LogSoftmax, for numerical
    stability). Its logits are reshaped to (batch, -1, n_actions). The value
    head is Dense followed by Flatten.

    Args:
      n_actions: int: number of actions per control.
      n_controls: int: number of controls.
      vocab_size: int or None: None for continuous policies; otherwise the
        size of the discrete observation vocabulary, forwarded to
        `bottom_layers_fn` as a `vocab_size` keyword.
      bottom_layers_fn: callable returning the trunk layers.
      two_towers: bool: if True, policy and value each get their own trunk.

    Returns:
      A `tl.Model` whose outputs are (policy log-probs, values).
    """
    @tl.layer()
    def FlattenControlsIntoTime(x, **unused_kwargs):  # pylint: disable=invalid-name
        """Splits logits for actions in different controls and flattens controls."""
        return np.reshape(x, (x.shape[0], -1, n_actions))

    if vocab_size is None:
        # Continuous policies: every element of the output sequence corresponds
        # to an observation and predicts all controls.
        n_preds_per_input, kwargs = n_controls, {}
    else:
        # Discrete policies: every element of the output sequence corresponds
        # to one symbol, and each control takes 1 symbol.
        n_preds_per_input, kwargs = 1, {"vocab_size": vocab_size}

    def policy_head():
        return [
            tl.Dense(n_preds_per_input * n_actions),
            FlattenControlsIntoTime(),  # pylint: disable=no-value-for-parameter
            tl.LogSoftmax(),
        ]

    def value_head():
        return [tl.Dense(n_preds_per_input), tl.Flatten()]

    if two_towers:
        layers = [
            tl.Dup(),
            tl.Parallel(
                [bottom_layers_fn(**kwargs)] + policy_head(),
                [bottom_layers_fn(**kwargs)] + value_head(),
            ),
        ]
    else:
        layers = [
            bottom_layers_fn(**kwargs),
            tl.Dup(),
            tl.Parallel(policy_head(), value_head()),
        ]
    return tl.Model(layers)
def Transformer(vocab_size, d_feature=512, d_feedforward=2048, n_layers=6,
                n_heads=8, dropout=0.1, max_len=2048, mode='train'):
    """Transformer encoder-decoder model.

    The model expects a pair (source, target) as input.

    Args:
      vocab_size: int: vocab size (shared source and target).
      d_feature: int: depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_layers: int: number of encoder/decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      the Transformer model.
    """
    # The same embedder layer objects appear in both the encoder and the
    # decoder input path. Presumably this shares the embedding between source
    # and target (the vocab is shared).
    positional_embedder = [
        tl.Embedding(d_feature, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]
    encoder_blocks = [
        EncoderBlock(d_feature, d_feedforward, n_heads, dropout, mode)
        for _ in range(n_layers)
    ]
    encoder = [
        tl.Branch(positional_embedder, tl.PaddingMask()),
        encoder_blocks,
        tl.LayerNorm(),
    ]
    decoder_blocks = [
        EncoderDecoder(d_feature, d_feedforward, n_heads, dropout, mode)
        for _ in range(n_layers)
    ]
    return tl.Model(
        tl.Parallel([], tl.ShiftRight()),
        tl.Parallel(encoder, positional_embedder),
        tl.Select(inputs=(('encoder', 'mask'), 'decoder'),
                  output=('decoder', ('mask', 'decoder'), 'encoder')),
        # (encoder_mask, decoder_input) -> encoder-decoder mask
        tl.Parallel([], tl.EncoderDecoderMask(), []),
        decoder_blocks,
        tl.Select(0),  # Drop mask and encoder.
        tl.LayerNorm(),
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    )
def EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                        mode):
    """Transformer encoder-decoder layer.

    The input is a triple (encoder, mask, decoder_input). The mask is created
    from the original source so attention ignores the encoder's padding.

    Args:
      feature_depth: int: depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      the layer, returning a triple (encoder, mask, decoder_activations).
    """
    # Decoder self-attending to decoder, under a causal mask.
    self_attention = layers.Residual(
        layers.LayerNorm(),
        layers.Branch(),
        layers.Parallel(
            layers.Identity(),  # activation for (q, k, v)
            layers.CausalMask(axis=-2)),  # attention mask
        layers.MultiHeadedAttention(
            feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
        layers.Dropout(rate=dropout, mode=mode))

    # Decoder attending to encoder.
    encoder_decoder_attention = layers.Serial(
        layers.Reorder(output=((2, 0, 0), 1)),  # ((dec, enc, enc), mask)
        layers.MultiHeadedAttentionQKV(  # ((q, k, v), mask) --> new v
            feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
        layers.Dropout(rate=dropout, mode=mode),
    )

    def on_third(layer):
        # Leaves encoder and mask untouched; applies `layer` to the third item.
        return layers.Parallel(layers.Identity(), layers.Identity(), layer)

    return layers.Serial(
        on_third(self_attention),
        layers.Branch(),
        layers.Parallel(layers.Identity(), encoder_decoder_attention),
        layers.UnnestBranches(),  # (encoder, mask, old_act, new_act)
        layers.Reorder(output=(0, 1, (2, 3))),
        # Residual after encoder-decoder attention.
        on_third(layers.SumBranches()),
        # Feed-forward on the third component (decoder).
        on_third(ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                                     mode=mode)),
    )
def Transformer(vocab_size, feature_depth=512, feedforward_depth=2048,
                num_layers=6, num_heads=8, dropout=0.1, max_len=2048,
                mode='train'):
    """Transformer encoder-decoder model.

    The model expects a pair (source, target) as input.

    Args:
      vocab_size: int: vocab size (shared source and target).
      feature_depth: int: depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_layers: int: number of encoder/decoder layers
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      the Transformer model.
    """
    # A single `embedding` object is used for both source and target.
    embedding = layers.Serial(
        layers.Embedding(feature_depth, vocab_size),
        layers.Dropout(rate=dropout, mode=mode),
        layers.PositionalEncoding(max_len=max_len))
    encoder_layers = [
        EncoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode)
        for _ in range(num_layers)
    ]
    encoder = layers.Serial(
        layers.Branch(),  # Branch input to create embedding and mask.
        layers.Parallel(embedding, layers.PaddingMask()),
        layers.Serial(*encoder_layers),
        layers.Parallel(layers.LayerNorm(), layers.Identity()))
    decoder_layers = [
        EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads,
                            dropout, mode)
        for _ in range(num_layers)
    ]
    return layers.Serial(
        layers.Parallel(layers.Identity(), layers.ShiftRight()),
        layers.Parallel(encoder, embedding),
        layers.UnnestBranches(),  # (encoder, encoder_mask, decoder_input)
        layers.Reorder(output=(0, (1, 2), 2)),
        # (encoder_mask, decoder_input) -> encoder-decoder mask
        layers.Parallel(
            layers.Identity(), layers.EncoderDecoderMask(), layers.Identity()),
        layers.Serial(*decoder_layers),
        layers.ThirdBranch(),
        layers.LayerNorm(),
        layers.Dense(vocab_size),
        layers.LogSoftmax())
def EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                        mode):
    """Transformer encoder-decoder layer.

    The input is a triple (encoder, mask, decoder_input). The mask is created
    from the original source so attention ignores the encoder's padding.

    Args:
      feature_depth: int: depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      the layer, returning a triple (encoder, mask, decoder_activations).
    """
    # Decoder self-attending to decoder.
    self_attention = tl.Residual(
        tl.LayerNorm(),
        tl.Branch(tl.NoOp(), tl.CausalMask(axis=-2)),  # create mask
        tl.MultiHeadedAttention(
            feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
        tl.Select(0),  # drop mask
        tl.Dropout(rate=dropout, mode=mode))

    # Decoder attending to encoder.
    encoder_decoder_attention = tl.Serial(
        tl.Select((2, 0, 0, 1)),  # (dec, enc, enc, mask)
        tl.MultiHeadedAttentionQKV(  # (q, k, v, mask) --> new, mask
            feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
        tl.Select(0),  # drop the mask
        tl.Dropout(rate=dropout, mode=mode),
    )

    def on_decoder(layer):
        # Leaves encoder and mask untouched; applies `layer` to the decoder.
        return tl.Parallel(tl.NoOp(), tl.NoOp(), layer)

    return tl.Serial(
        on_decoder(self_attention),
        tl.Branch(tl.NoOp(), encoder_decoder_attention),
        tl.Select(inputs=(('encoder', 'mask', 'old_act'), 'new_act'),
                  output=('encoder', 'mask', ('old_act', 'new_act'))),
        # Residual after encoder-decoder attention.
        on_decoder(tl.Add()),
        # Feed-forward on the third component (decoder).
        on_decoder(ResidualFeedForward(feature_depth, feedforward_depth,
                                       dropout, mode=mode)),
    )
def __init__(self, residual_layers):
    """Wraps `residual_layers` as a reversible half-residual block.

    Forward: the residual of the second stack element is computed and combined
    into the first element with `tl.Add`. Reverse (`self.reverse_layers`): the
    same residual is recomputed and combined with `tl.SubtractTop` instead.

    Args:
      residual_layers: layer(s) computing the residual branch.
    """
    self.compute_residual = tl.Serial([
        tl.Parallel([], tl.Dup()),              # (x1_or_y1, x2, x2)
        tl.Swap(),                              # (x2, x1_or_y1, x2)
        tl.Parallel(residual_layers, [], []),   # residual on the top copy
    ])
    forward_layers = [self.compute_residual, tl.Parallel(tl.Add(), [])]
    super(ReversibleHalfResidual, self).__init__(forward_layers)

    self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
    self.reverse_layers = [self.compute_residual, self.subtract_top]
def loss(mask_id=None, has_weights=False):
    """Cross-entropy loss as scalar compatible with Trax masking.

    Expects (pred-obs, pred-reward, target-obs, target-reward) on the stack.
    It sums a cross-entropy loss on observations with an L2 loss on rewards,
    then multiplies the result by zero. This zeroing is specific to this test
    helper.
    """
    # (pred-obs, pred-reward, target-obs, target-reward)
    #   -> (pred-obs, target-obs, pred-reward, target-reward)
    pair_predictions_with_targets = layers.Parallel([], layers.Swap())
    # Cross-entropy loss for obs, L2 loss on reward.
    per_part_losses = layers.Parallel(
        layers.CrossEntropyLossScalar(mask_id, has_weights),
        layers.L2LossScalar(mask_id, has_weights))
    return layers.Serial(
        pair_predictions_with_targets,
        per_part_losses,
        layers.Add(),                      # Add both losses.
        layers.MulConstant(constant=0.0),  # Zero out in this test.
    )
def Transformer(vocab_size, feature_depth=512, feedforward_depth=2048,
                num_layers=6, num_heads=8, dropout=0.1, max_len=2048,
                mode='train'):
    """Transformer encoder-decoder model.

    The model expects a pair (source, target) as input.

    Args:
      vocab_size: int: vocab size (shared source and target).
      feature_depth: int: depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_layers: int: number of encoder/decoder layers
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      the Transformer model.
    """
    # A single `embedding` object is used for both source and target.
    embedding = tl.Serial(
        tl.Embedding(feature_depth, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len))
    encoder_layers = [
        EncoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode)
        for _ in range(num_layers)
    ]
    encoder = tl.Serial(
        tl.Branch(embedding, tl.PaddingMask()),
        tl.Serial(*encoder_layers),
        tl.Parallel(tl.LayerNorm(), tl.NoOp()))
    decoder_layers = [
        EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads,
                            dropout, mode)
        for _ in range(num_layers)
    ]
    return tl.Serial(
        tl.Parallel(tl.NoOp(), tl.ShiftRight()),
        tl.Parallel(encoder, embedding),
        tl.Select(inputs=(('encoder', 'mask'), 'decoder'),
                  output=('encoder', ('mask', 'decoder'), 'decoder')),
        # (encoder_mask, decoder_input) -> encoder-decoder mask
        tl.Parallel(tl.NoOp(), tl.EncoderDecoderMask(), tl.NoOp()),
        tl.Serial(*decoder_layers),
        tl.Select(2),  # Drop encoder and mask.
        tl.LayerNorm(),
        tl.Dense(vocab_size),
        tl.LogSoftmax())
def model(mode):
    """Toy model mapping (observation stack, action) to (new obs, reward).

    `mode` is accepted for interface compatibility and ignored. `n_actions`
    and `obs_shape` are read from outside this function.
    """
    del mode
    observation_and_action = layers.Parallel(
        layers.Flatten(),  # Observation stack.
        layers.Embedding(d_feature=1, vocab_size=n_actions),  # Action.
    )
    prediction_heads = layers.Parallel(
        layers.Dense(n_units=obs_shape[1]),  # New observation.
        None,  # Reward.
    )
    return layers.Serial(
        observation_and_action,
        layers.Concatenate(),
        layers.Dense(n_units=1),
        layers.Dup(),
        prediction_heads,
    )
def EncoderDecoder(d_feature, d_feedforward, n_heads, dropout, mode):
    """Transformer encoder-decoder layer.

    The input is a triple (decoder_input, mask, encoder). The mask is created
    from the original source so attention ignores the encoder's padding.

    Args:
      d_feature: int: depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      A list of three residual layers. Together they map
      (decoder_activations, mask, encoder) to a triple of the same form.
    """
    decoder_self_attention = [
        # stack: vecs_d masks vecs_e
        tl.LayerNorm(),
        tl.Dup(),                                 # vecs_d vecs_d masks vecs_e
        tl.Parallel([], tl.CausalMask(axis=-2)),  # vecs_d causal masks vecs_e
        tl.MultiHeadedAttention(
            d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
        tl.Parallel([], tl.Drop()),               # vecs_d masks vecs_e
        tl.Dropout(rate=dropout, mode=mode),      # vecs_d masks vecs_e
    ]
    decoder_to_encoder_attention = [
        # stack: vecs_d masks vecs_e
        tl.Parallel([], [], tl.Dup()),  # vecs_d masks vecs_e vecs_e
        tl.Parallel([], tl.Swap()),     # vecs_d vecs_e masks vecs_e
        tl.Parallel([], tl.Dup()),      # vecs_d vecs_e vecs_e masks vecs_e
        tl.MultiHeadedAttentionQKV(     # (q k v masks ...) --> (vecs_d masks ...)
            d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
        tl.Dropout(rate=dropout, mode=mode),  # vecs_d masks vecs_e
    ]
    feed_forward = [FeedForward(d_feature, d_feedforward, dropout, mode=mode)]
    # Each residual keeps the stack shape: vecs_d masks vecs_e.
    return [
        tl.Residual(decoder_self_attention),
        tl.Residual(decoder_to_encoder_attention),
        tl.Residual(feed_forward),
    ]
def Encoder(source, source_mask):
    """Transformer encoder stack.

    Several names used here are not defined in this function: multi_attention,
    dropout, mode, feature_depth, feedforward_depth, source_embedding_layer and
    num_layers.

    Args:
      source: layer variable: raw source sequences
      source_mask: layer variable: self-attention mask

    Returns:
      Layer variable that outputs encoded source.
    """
    attention_inputs = layers.Parallel(
        layers.Identity(),  # query
        layers.Identity(),  # key
        layers.Identity(),  # value
        source_mask)        # attention mask
    # input attends to self
    self_attention = layers.Residual(
        layers.LayerNorm(),
        layers.Branch(size=4),
        attention_inputs,
        multi_attention,
        layers.Dropout(dropout, mode=mode))
    encoder_layer = layers.Serial(
        self_attention,
        # feed-forward
        ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                            mode=mode),
    )
    return layers.Serial(
        source,
        source_embedding_layer,
        layers.repeat(encoder_layer, num_layers),
        layers.LayerNorm(),
    )
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 n_attention_chunks, attention_type, dropout, mode):
    """Reversible transformer decoder layer.

    Args:
      d_model: int: depth of embedding
      d_ff: int: depth of feed-forward layer
      d_attention_key: int: depth of key vector for each attention head
      d_attention_value: int: depth of value vector for each attention head
      n_heads: int: number of attention heads
      n_attention_chunks: int: number of chunks for attention
      attention_type: class: attention class to use, such as DotProductAttention.
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      A list: attention half-residual, swap, feed-forward half-residual, swap.
    """
    def heads_of_depth(d_head):
        # A fresh ComputeAttentionHeads layer on each call.
        return [tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_head)]

    pre_attention = [
        Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(),
        tl.Dup(),
        tl.Dup(),
        tl.Parallel(
            heads_of_depth(d_attention_key),    # queries
            heads_of_depth(d_attention_key),    # keys
            heads_of_depth(d_attention_value),  # values
        ),
    ]
    attention = attention_type(mode=mode)
    # ReversibleAttentionHalfResidual requires that post_attention be linear in
    # its input (so the backward pass can be computed without knowing the input).
    post_attention = [
        tl.ComputeAttentionOutput(n_heads=n_heads, d_model=d_model),
        Unchunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
    ]
    feed_forward = [FeedForward(d_model, d_ff, dropout, mode=mode)]
    return [
        ReversibleAttentionHalfResidual(pre_attention, attention,
                                        post_attention),
        tl.ReversibleSwap(),
        ReversibleHalfResidual(feed_forward),
        tl.ReversibleSwap(),
    ]
def DecoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
    """Transformer decoder layer.

    Causal self-attention (as a residual) followed by a residual feed-forward.

    Args:
      feature_depth: int: depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      the layer.
    """
    activation_and_mask = layers.Parallel(
        layers.Identity(),           # activation for (q, k, v)
        layers.CausalMask(axis=-2))  # attention mask
    self_attention_block = layers.Residual(
        layers.LayerNorm(),
        layers.Branch(),
        activation_and_mask,
        layers.MultiHeadedAttention(
            feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
        layers.Dropout(rate=dropout, mode=mode))
    return layers.Serial(
        self_attention_block,
        ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                            mode=mode))
def SumLearnedPick(positions):
    """Get a pair (vec, pos) and pick new pos.

    Builds five `LearnedQP` branches over the rows of `positions`:
      1. a default one;
      2. a successor table (row k -> row k+1);
      3. a predecessor table (row k+1 -> row k);
      4. a binary addition table ((i, j) -> i + j);
      5. a binary clamped-subtraction table ((i, j) -> max(i - j, 0)).
    The binary tables range over i, j < len(positions) // 2. The branch outputs
    are combined with `SoftmaxBranches`.
    """
    half = int(positions.shape[0]) // 2

    def joined_rows(a, b):
        return np.concatenate([positions[a, :], positions[b, :]])

    successor_keys, successor_values = positions[:-1, :], positions[1:, :]
    predecessor_keys, predecessor_values = positions[1:, :], positions[:-1, :]

    # Addition table, enumerated with i as the outer index.
    add_index_pairs = [(i, j) for i in range(half) for j in range(half)]
    add_keys = np.array([joined_rows(i, j) for i, j in add_index_pairs])
    add_values = np.array([positions[i + j, :] for i, j in add_index_pairs])

    # Clamped-subtraction table, enumerated with j as the outer index.
    # TODO(lukaszkaiser): try this below (with l = half):
    #   "for j in range(i) for i in range(2*l)"
    sub_index_pairs = [(i, j) for j in range(half) for i in range(half)]
    sub_keys = np.array([joined_rows(i, j) for i, j in sub_index_pairs])
    sub_values = np.array(
        [positions[max(i - j, 0), :] for i, j in sub_index_pairs])

    lookups = tl.Parallel(
        LearnedQP(),
        LearnedQP(keys=successor_keys, values=successor_values),
        LearnedQP(keys=predecessor_keys, values=predecessor_values),
        LearnedQP(keys=add_keys, values=add_values, binary=True),
        LearnedQP(keys=sub_keys, values=sub_values, binary=True),
    )
    return tl.Serial(
        tl.Dup(), tl.Dup(), tl.Dup(), tl.Dup(),
        lookups,
        Unnest(),
        SoftmaxBranches(n_branches=5))
def AtariCnn(hidden_sizes=(32, 32), output_size=128, mode='train'):
    """An Atari CNN.

    Input shape: (B, T, H, W, C). Output shape: (B, T, output_size).
    `mode` is ignored.
    """
    del mode
    # TODO(jonni): Include link to paper?
    first_channels = hidden_sizes[0]
    second_channels = hidden_sizes[1]
    return tl.Model(
        # Convert to float and divide by 255 (presumably 0-255 pixel values).
        tl.ToFloat(),
        tl.Div(divisor=255.0),
        # Set up 4 successive game frames, concatenated on the last axis.
        tl.Dup(),
        tl.Dup(),
        tl.Dup(),
        tl.Parallel(None, _shift_right(1), _shift_right(2), _shift_right(3)),
        tl.Concatenate(n_items=4, axis=-1),  # (B, T, H, W, 4C)
        tl.Conv(first_channels, (5, 5), (2, 2), 'SAME'),
        tl.Relu(),
        tl.Conv(second_channels, (5, 5), (2, 2), 'SAME'),
        tl.Relu(),
        tl.Flatten(n_axes_to_keep=2),  # B, T and rest.
        tl.Dense(output_size),
        tl.Relu(),
    )
def DecoderLayer(positions, d_feature, d_feedforward, n_heads, dropout, mode):
    """Transformer decoder layer with position-aware attention.

    Args:
      positions: random vectors for positions
      d_feature: int: depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      A list: a residual self-attention block, then a residual feed-forward.
    """
    position_attention = MultiHeadedAttentionPosition(
        positions, d_feature, n_heads=n_heads, dropout=dropout, mode=mode)
    self_attention_block = tl.Residual(
        PreservePosition(tl.LayerNorm()),
        tl.Dup(),
        tl.Parallel(
            [],                      # activation for (q, k, v)
            tl.CausalMask(axis=-2)),  # attention mask
        position_attention,
        PreservePosition(tl.Dropout(rate=dropout, mode=mode)))
    return [
        self_attention_block,
        ResidualFeedForward(d_feature, d_feedforward, dropout, mode=mode),
    ]
def DecoderBlock(d_feature, d_feedforward, n_heads, n_attention_chunks,
                 attention_loop_stride, dropout, mode):
    """Reversible transformer decoder layer.

    Args:
      d_feature: int: depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_heads: int: number of attention heads
      n_attention_chunks: int: number of chunks for attention
      attention_loop_stride: int: number of query elements to compute attention
        for in parallel. Set to 0 to disable memory-efficient attention.
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      A list: attention half-residual, swap, feed-forward half-residual, swap.
    """
    def project_then_split():
        # New Dense/SplitHeads objects on every call (one set each for q, k, v).
        return [tl.Dense(d_feature), SplitHeads(n_heads=n_heads)]  # pylint: disable=no-value-for-parameter

    pre_attention = [
        Chunk(sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(),
        tl.Dup(),
        tl.Dup(),
        tl.Parallel(project_then_split(), project_then_split(),
                    project_then_split()),
    ]

    # TODO(kitaev): add dropout
    if attention_loop_stride < 1:
        # Use the standard implementation if no loop_stride is provided.
        attention = DotProductAttention(dropout=None, mode=mode)
    else:
        attention = MemoryEfficientDotProductAttention(
            loop_stride=attention_loop_stride, dropout=None, mode=mode)

    # ReversibleAttentionHalfResidual requires that post_attention be linear in
    # its input (so the backward pass can be computed without knowing the input).
    post_attention = [
        JoinHeads(),  # pylint: disable=no-value-for-parameter
        tl.Dense(d_feature),
        Unchunk(sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
    ]
    feed_forward = [FeedForward(d_feature, d_feedforward, dropout, mode=mode)]
    return [
        ReversibleAttentionHalfResidual(pre_attention, attention,
                                        post_attention),
        ReversibleSwap(),
        ReversibleHalfResidual(feed_forward),
        ReversibleSwap(),
    ]
def TransformerRevnetLM(vocab_size, d_feature=512, d_feedforward=2048,
                        d_attention_key=64, d_attention_value=64, n_layers=6,
                        n_heads=8, dropout=0.1, max_len=2048, n_chunks=32,
                        n_attention_chunks=8, attention_loop_stride=0,
                        mode='train'):
    """Reversible transformer language model (only uses a decoder, no encoder).

    Args:
      vocab_size: int: vocab size
      d_feature: int: depth of *each half* of the two-part features
      d_feedforward: int: depth of feed-forward layer
      d_attention_key: int: depth of key vector for each attention head
      d_attention_value: int: depth of value vector for each attention head
      n_layers: int: number of decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      n_chunks: int: number of chunks (must match input pipeline)
      n_attention_chunks: int: number of chunks for attention
      attention_loop_stride: int: number of query elements to compute attention
        for in parallel. Set to 0 to disable memory-efficient attention.
      mode: str: 'train' or 'eval'

    Returns:
      the layer.
    """
    positional_embedder = [
        tl.Embedding(d_feature, vocab_size),
        # TODO(kitaev): add dropout
        tl.PositionalEncoding(max_len=max_len),
    ]
    # NOTE(review): DecoderBlock is called with 9 positional arguments. The
    # attention_loop_stride variant of DecoderBlock takes 7. The
    # attention_type variant would receive attention_loop_stride as its
    # attention class. Confirm which DecoderBlock this is meant to pair with.
    decoder_blocks = [
        DecoderBlock(d_feature, d_feedforward, d_attention_key,
                     d_attention_value, n_heads, n_attention_chunks,
                     attention_loop_stride, dropout, mode)
        for _ in range(n_layers)
    ]
    per_chunk_output = [tl.Dense(vocab_size), tl.LogSoftmax()]
    return tl.Model(
        tl.Concatenate(n_items=n_chunks),
        tl.ShiftRight(),
        positional_embedder,
        tl.Dup(),
        ReversibleSerial(decoder_blocks),
        tl.Parallel(tl.LayerNorm(), tl.LayerNorm()),
        tl.Concatenate(),
        Split(sections=n_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
        Map(per_chunk_output, sections=n_chunks),
    )
def Decoder(memory, target, target_mask, memory_mask):
    """Transformer decoder stack.

    Several names used here are not defined in this function: multi_attention,
    dropout, mode, feature_depth, feedforward_depth, target_embedding_layer and
    num_layers.

    Args:
      memory: layer variable: encoded source sequences
      target: layer variable: raw target sequences
      target_mask: layer variable: self-attention mask
      memory_mask: layer variable: memory attention mask

    Returns:
      Layer variable that outputs the decoded target.
    """
    def attention_sublayer(key, value, mask):
        # Residual attention. The query is the layer-normalized decoder stream;
        # the key, value and mask branches are the given layers.
        return layers.Residual(
            layers.LayerNorm(),
            layers.Branch(size=4),
            layers.Parallel(
                layers.Identity(),  # query
                key,                # key
                value,              # value
                mask),              # attention mask
            multi_attention,
            layers.Dropout(dropout, mode=mode))

    decoder_layer = layers.Serial(
        # target attends to self
        attention_sublayer(layers.Identity(), layers.Identity(), target_mask),
        # target attends to encoded source
        attention_sublayer(memory, memory, memory_mask),
        # feed-forward
        ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                            mode=mode))
    return layers.Serial(
        target,
        target_embedding_layer,
        layers.repeat(decoder_layer, num_layers),
        layers.LayerNorm(),
    )
def TransformerEncoder(vocab_size, n_classes=10, d_model=512, d_ff=2048,
                       n_layers=6, n_heads=8, dropout=0.1, max_len=2048,
                       mode='train'):
    """Returns a Transformer encoder classification model.

    The input is a tensor of tokens. The encoded vectors are averaged over
    axis 1 (length) and mapped to log-probabilities over `n_classes`.

    Args:
      vocab_size: int: vocab size
      n_classes: how many classes on output
      d_model: int: depth of embedding
      d_ff: int: depth of feed-forward layer
      n_layers: int: number of encoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      A Transformer model as a layer that maps from a tensor of tokens to
      activations over a set of output classes.
    """
    embedder = [
        tl.Embedding(d_model, vocab_size),
        tl.Dropout(rate=dropout, name='emb_dropout', mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]
    encoder_blocks = [
        EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode)
        for i in range(n_layers)
    ]
    return tl.Model([                               # toks
        tl.Dup(),                                   # toks toks
        tl.Parallel(embedder, tl.PaddingMask()),    # vecs mask
        encoder_blocks,                             # vecs mask
        tl.Parallel([], tl.Drop()),                 # vecs
        tl.LayerNorm(),                             # vecs
        tl.Mean(axis=1),  # Average on length.      # vecs
        tl.Dense(n_classes),                        # logits
        tl.LogSoftmax(),                            # log-probs
    ])
def policy_and_value_net(n_actions, n_controls, bottom_layers_fn, two_towers):
    """Builds a policy-and-value network for multi-control action spaces.

    The policy head produces n_controls * n_actions logits. These are reshaped
    to (batch, -1, n_actions) and passed through LogSoftmax (used instead of
    Softmax for numerical stability). The value head is Dense(n_controls)
    followed by Flatten.

    Args:
      n_actions: int: number of actions per control.
      n_controls: int: number of controls.
      bottom_layers_fn: callable returning the trunk layers.
      two_towers: bool: if True, policy and value each get their own trunk.

    Returns:
      A `tl.Model` whose outputs are (policy log-probs, values).
    """
    @tl.layer()
    def FlattenControlsIntoTime(x, **unused_kwargs):  # pylint: disable=invalid-name
        """Splits logits for actions in different controls and flattens controls."""
        return np.reshape(x, (x.shape[0], -1, n_actions))

    n_logits = n_controls * n_actions

    def policy_head():
        return [
            tl.Dense(n_logits),
            FlattenControlsIntoTime(),  # pylint: disable=no-value-for-parameter
            tl.LogSoftmax(),
        ]

    def value_head():
        return [tl.Dense(n_controls), tl.Flatten()]

    if not two_towers:
        return tl.Model([
            bottom_layers_fn(),
            tl.Dup(),
            tl.Parallel(policy_head(), value_head()),
        ])
    return tl.Model([
        tl.Dup(),
        tl.Parallel(
            [bottom_layers_fn()] + policy_head(),
            [bottom_layers_fn()] + value_head(),
        ),
    ])
def TransformerRevnetLM(vocab_size, d_model=512, d_ff=2048,
                        d_attention_key=64, d_attention_value=64, n_layers=6,
                        n_heads=8, dropout=0.1, max_len=2048, n_chunks=32,
                        n_attention_chunks=8,
                        attention_type=DotProductAttention, mode='train'):
    """Reversible transformer language model (only uses a decoder, no encoder).

    Args:
      vocab_size: int: vocab size
      d_model: int: depth of *each half* of the two-part features
      d_ff: int: depth of feed-forward layer
      d_attention_key: int: depth of key vector for each attention head
      d_attention_value: int: depth of value vector for each attention head
      n_layers: int: number of decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      n_chunks: int: number of chunks (must match input pipeline)
      n_attention_chunks: int: number of chunks for attention
      attention_type: class: attention class to use, such as DotProductAttention.
      mode: str: 'train' or 'eval'

    Returns:
      the layer.
    """
    positional_embedder = [
        tl.Embedding(d_model, vocab_size),
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
        tl.PositionalEncoding(max_len=max_len),
    ]
    decoder_blocks = [
        DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value,
                     n_heads, n_attention_chunks, attention_type, dropout,
                     mode)
        for _ in range(n_layers)
    ]
    # Applied to each of the n_chunks output sections.
    per_chunk_output = [tl.Dense(vocab_size), tl.LogSoftmax()]
    return tl.Model(
        tl.Concatenate(n_items=n_chunks),
        tl.ShiftRight(),
        positional_embedder,
        tl.Dup(),  # The two copies form the two halves of the features.
        tl.ReversibleSerial(decoder_blocks),
        tl.Parallel(tl.LayerNorm(), tl.LayerNorm()),
        tl.Concatenate(),
        Split(n_sections=n_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
        Map(per_chunk_output, n_sections=n_chunks),
    )
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    """WideResnet convolutional block.

    Computes main(x) + shortcut(x). `main` applies (BatchNorm, Relu, 3x3 Conv)
    twice; only the first conv uses `strides`. `shortcut` is the identity, or a
    strided 3x3 Conv when `channel_mismatch` is True.
    """
    main = layers.Serial(
        layers.BatchNorm(),
        layers.Relu(),
        layers.Conv(channels, (3, 3), strides, padding='SAME'),
        layers.BatchNorm(),
        layers.Relu(),
        layers.Conv(channels, (3, 3), padding='SAME'))
    if channel_mismatch:
        shortcut = layers.Conv(channels, (3, 3), strides, padding='SAME')
    else:
        shortcut = layers.Identity()
    return layers.Serial(
        layers.Branch(),
        layers.Parallel(main, shortcut),
        layers.SumBranches())