def __init__(self, pre_attention, attention, post_attention):
  self.pre_attention = tl.Serial([
      # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
      tl.Parallel([], tl.Dup()),
      tl.Swap(),
      tl.Parallel(pre_attention, [], []),
  ])
  assert hasattr(attention, 'forward_and_backward')
  self.attention = ApplyAttentionWrapper(attention)
  self.post_attention = tl.Parallel(post_attention, [], [])

  layers = [
      self.pre_attention,
      self.attention,
      self.post_attention,
      tl.Parallel(tl.Add(), []),
  ]
  super(ReversibleAttentionHalfResidual, self).__init__(layers)

  self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
  self.reverse_layers = [
      self.pre_attention,
      self.attention,
      self.post_attention,
      self.subtract_top,
  ]
def __init__(self, residual_layers):
  self.compute_residual = tl.Serial([
      # TODO(jonni): Rewrite without using Select.
      tl.Select(inputs=('x1_or_y1', 'x2'), output=('x2', 'x1_or_y1', 'x2')),
      tl.Parallel(residual_layers, [], []),
  ])

  layers = [self.compute_residual, tl.Add()]
  super(ReversibleHalfResidual, self).__init__(layers)

  self.subtract_top = tl.SubtractTop()
  self.reverse_layers = [self.compute_residual, self.subtract_top]
def EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                        mode):
  """Transformer encoder-decoder layer.

  The input is a triple (encoder, mask, decoder_input) where the mask is
  created from the original source to prevent attending to the padding part
  of the encoder.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a triple (encoder, mask, decoder_activations).
  """
  # Decoder self-attending to decoder.
  self_attention = tl.Residual(
      tl.LayerNorm(),
      tl.Branch(tl.Copy(), tl.CausalMask(axis=-2)),  # create mask
      tl.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                              dropout=dropout, mode=mode),
      tl.Select(0),  # drop mask
      tl.Dropout(rate=dropout, mode=mode))
  # Decoder attending to encoder.
  encoder_decoder_attention = tl.Serial(
      tl.Select(((2, 0, 0), 1)),  # ((dec, enc, enc), mask)
      tl.MultiHeadedAttentionQKV(  # ((q, k, v), mask) --> new, mask
          feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
      tl.Select(0),  # drop the mask
      tl.Dropout(rate=dropout, mode=mode),
  )
  return tl.Serial(
      tl.Parallel(tl.Copy(), tl.Copy(), self_attention),
      tl.Branch(tl.Copy(), encoder_decoder_attention),
      tl.UnnestBranches(),  # (encoder, mask, old_act, new_act)
      tl.Select((0, 1, (2, 3))),
      tl.Parallel(  # Residual after encoder-decoder attention.
          tl.Copy(), tl.Copy(), tl.Add()),
      tl.Parallel(  # Feed-forward on the third component (decoder).
          tl.Copy(), tl.Copy(),
          ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                              mode=mode)))
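# The combinator pipeline above can be hard to read, so here is a hedged,
# plain-numpy sketch (an assumption, not Trax code) of the data flow it
# expresses: decoder self-attention under a causal mask, decoder-to-encoder
# attention under the padding mask, then a feed-forward block, with the
# encoder activations and mask passed through unchanged. The helpers below
# are simplified stand-ins for the LayerNorm, attention, and feed-forward
# layers, not the real implementations.
import numpy as np

def _layer_norm(x):
  return (x - x.mean(-1, keepdims=True)) / (x.std(-1, keepdims=True) + 1e-6)

def _attend(q, k, v, mask):
  # Scaled dot-product attention with a boolean mask (True = may attend).
  scores = q @ k.swapaxes(-1, -2) / np.sqrt(q.shape[-1])
  scores = np.where(mask, scores, -1e9)
  weights = np.exp(scores - scores.max(-1, keepdims=True))
  weights /= weights.sum(-1, keepdims=True)
  return weights @ v

def encoder_decoder_layer_dataflow(encoder, mask, decoder_input):
  # Decoder self-attention (causal mask) plus residual, as in self_attention.
  causal = np.tril(np.ones((len(decoder_input), len(decoder_input)), bool))
  normed = _layer_norm(decoder_input)
  x = decoder_input + _attend(normed, normed, normed, causal)
  # Decoder attends to the encoder under the padding mask, then residual add.
  x = x + _attend(x, encoder, encoder, mask)
  # Stand-in for ResidualFeedForward (which adds its own residual).
  x = x + np.maximum(x, 0.0)
  # Encoder and mask pass through unchanged, matching the returned triple.
  return encoder, mask, x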
def __init__(self, residual_layers):
  self.compute_residual = tl.Serial([
      # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
      tl.Parallel([], tl.Dup()),
      tl.Swap(),
      tl.Parallel(residual_layers, [], []),
  ])

  layers = [self.compute_residual, tl.Parallel(tl.Add(), [])]
  super(ReversibleHalfResidual, self).__init__(layers)

  self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
  self.reverse_layers = [self.compute_residual, self.subtract_top]
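# A minimal numpy sketch (not Trax code) of the arithmetic the reversible
# half-residual relies on: the forward pass computes y1 = x1 + F(x2) while
# passing x2 through unchanged, so the reverse pass can recover
# x1 = y1 - F(x2) by recomputing the residual instead of storing activations.
# The function `f` below is an arbitrary stand-in for residual_layers.
import numpy as np

def half_residual_forward(x1, x2, f):
  # (x1, x2) -> (y1, x2), mirroring compute_residual followed by Add.
  return x1 + f(x2), x2

def half_residual_reverse(y1, x2, f):
  # (y1, x2) -> (x1, x2), mirroring compute_residual followed by SubtractTop.
  return y1 - f(x2), x2

f = np.tanh
x1, x2 = np.random.randn(4), np.random.randn(4)
y1, y2 = half_residual_forward(x1, x2, f)
x1_rec, _ = half_residual_reverse(y1, y2, f)
assert np.allclose(x1, x1_rec)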
def loss(mask_id=None, has_weights=False):
  """Observation cross-entropy plus reward L2 loss, as a scalar compatible
  with Trax masking."""
  return layers.Serial(
      # Swap from (pred-obs, pred-reward, target-obs, target-reward)
      # to (pred-obs, target-obs, pred-reward, target-reward).
      layers.Parallel([], layers.Swap()),
      # Cross-entropy loss for obs, L2 loss on reward.
      layers.Parallel(
          layers.CrossEntropyLossScalar(mask_id, has_weights),
          layers.L2LossScalar(mask_id, has_weights)),
      # Add both losses.
      layers.Add(),
      # Zero out in this test.
      layers.MulConstant(constant=0.0))
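# A hedged numpy sketch (not Trax code) of what this loss pipeline computes
# on a toy batch: cross-entropy on observation logits plus L2 on predicted
# rewards, summed, then zeroed out by the final MulConstant(0.0).
import numpy as np

def toy_loss(pred_obs_logits, pred_reward, target_obs, target_reward):
  # Cross-entropy over the observation vocabulary.
  shifted = pred_obs_logits - pred_obs_logits.max(-1, keepdims=True)
  log_probs = shifted - np.log(np.exp(shifted).sum(-1, keepdims=True))
  xent = -np.mean(np.take_along_axis(log_probs, target_obs[..., None], -1))
  # L2 loss on rewards.
  l2 = np.mean((pred_reward - target_reward) ** 2)
  # Add both losses, then zero out as in the test above.
  return (xent + l2) * 0.0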