Example #1
    def __init__(self, pre_attention, attention, post_attention):
        self.pre_attention = tl.Serial([
            # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
            tl.Parallel([], tl.Dup()),
            tl.Swap(),
            tl.Parallel(pre_attention, [], []),
        ])
        # The attention layer must expose a custom forward_and_backward so its
        # activations can be recomputed during the backward pass instead of stored.
        assert hasattr(attention, 'forward_and_backward')
        self.attention = ApplyAttentionWrapper(attention)
        self.post_attention = tl.Parallel(post_attention, [], [])

        layers = [
            self.pre_attention,
            self.attention,
            self.post_attention,
            tl.Parallel(tl.Add(), []),
        ]
        super(ReversibleAttentionHalfResidual, self).__init__(layers)

        # To reverse the layer, recompute the residual branch from x2 and
        # subtract it from y1 instead of adding it.
        self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
        self.reverse_layers = [
            self.pre_attention,
            self.attention,
            self.post_attention,
            self.subtract_top,
        ]
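Below is a minimal numpy sketch (not Trax code) of the forward/reverse contract this half-residual implements; `f` is a hypothetical stand-in for the pre-attention/attention/post-attention chain.

```python
import numpy as np

def f(x):
    # Hypothetical stand-in for post_attention(attention(pre_attention(x))).
    return np.tanh(x)

def forward(x1, x2):
    # Forward: y1 = x1 + f(x2); x2 passes through unchanged.
    return x1 + f(x2), x2

def reverse(y1, y2):
    # Reverse: recompute f on the unchanged second input and subtract it.
    return y1 - f(y2), y2

x1 = np.random.randn(3, 4)
x2 = np.random.randn(3, 4)
y1, y2 = forward(x1, x2)
r1, r2 = reverse(y1, y2)
assert np.allclose(r1, x1) and np.allclose(r2, x2)
```

Because the second input is carried through untouched, the residual can always be recomputed on the way back, which is exactly what `reverse_layers` relies on.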
Example #2
  def __init__(self, residual_layers):
    self.compute_residual = tl.Serial([
        # TODO(jonni): Rewrite without using Select.
        tl.Select(inputs=('x1_or_y1', 'x2'), output=('x2', 'x1_or_y1', 'x2')),
        tl.Parallel(residual_layers, [], []),
    ])

    layers = [self.compute_residual, tl.Add()]
    super(ReversibleHalfResidual, self).__init__(layers)

    self.subtract_top = tl.SubtractTop()
    self.reverse_layers = [self.compute_residual, self.subtract_top]
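For context, here is a minimal numpy sketch (not Trax code) of how two such half-residuals plus a swap form a fully reversible block; the inversion is what `reverse_layers` supports. `f` and `g` are hypothetical stand-ins for the `residual_layers` sub-networks.

```python
import numpy as np

def rev_block_forward(x1, x2, f, g):
    # Two half-residuals with a swap in between (RevNet-style):
    # y1 = x1 + f(x2), then y2 = x2 + g(y1).
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2

def rev_block_reverse(y1, y2, f, g):
    # Invert without any stored activations:
    # x2 = y2 - g(y1), then x1 = y1 - f(x2).
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return x1, x2

f = np.tanh                  # hypothetical stand-ins for the two
g = lambda t: 0.5 * t        # residual sub-networks
x1 = np.random.randn(3, 4)
x2 = np.random.randn(3, 4)
y1, y2 = rev_block_forward(x1, x2, f, g)
r1, r2 = rev_block_reverse(y1, y2, f, g)
assert np.allclose(r1, x1) and np.allclose(r2, x2)
```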
Example #3
def EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                        mode):
    """Transformer encoder-decoder layer.

  The input is a triple pair (encoder, mask, decoder_input) where
  the mask is created from the original source to prevent attending
  to the padding part of the encoder.

  Args:
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a triple (encoder, mask, decoder_activations).
  """
    # Decoder self-attending to decoder.
    self_attention = tl.Residual(
        tl.LayerNorm(),
        tl.Branch(tl.Copy(), tl.CausalMask(axis=-2)),  # create mask
        tl.MultiHeadedAttention(feature_depth,
                                num_heads=num_heads,
                                dropout=dropout,
                                mode=mode),
        tl.Select(0),  # drop mask
        tl.Dropout(rate=dropout, mode=mode))
    # Decoder attending to encoder.
    encoder_decoder_attention = tl.Serial(
        tl.Select(((2, 0, 0), 1)),  # ((dec, enc, enc), mask)
        tl.MultiHeadedAttentionQKV(  # ((q, k, v), mask) --> new, mask
            feature_depth,
            num_heads=num_heads,
            dropout=dropout,
            mode=mode),
        tl.Select(0),  # drop the mask
        tl.Dropout(rate=dropout, mode=mode),
    )
    return tl.Serial(
        tl.Parallel(tl.Copy(), tl.Copy(), self_attention),
        tl.Branch(tl.Copy(), encoder_decoder_attention),
        tl.UnnestBranches(),  # (encoder, mask, old_act, new_act)
        tl.Select((0, 1, (2, 3))),
        tl.Parallel(  # Residual after encoder-decoder attention.
            tl.Copy(), tl.Copy(), tl.Add()),
        tl.Parallel(  # Feed-forward on the third component (decoder).
            tl.Copy(), tl.Copy(),
            ResidualFeedForward(feature_depth,
                                feedforward_depth,
                                dropout,
                                mode=mode)))
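To make the data flow concrete, here is a plain-Python sketch (no Trax layers) of the tuple plumbing above; `encoder_decoder_step`, `self_attention`, `enc_dec_attention`, and `feed_forward` are hypothetical stand-ins, not the real layers.

```python
import numpy as np

def encoder_decoder_step(encoder, mask, decoder,
                         self_attention, enc_dec_attention, feed_forward):
    # Parallel(Copy, Copy, self_attention): the decoder self-attends.
    decoder = self_attention(decoder)
    # Select(((2, 0, 0), 1)) + MultiHeadedAttentionQKV: queries come from the
    # decoder, keys and values from the encoder, masked by the source mask.
    new_decoder = enc_dec_attention(decoder, encoder, encoder, mask)
    # Residual after encoder-decoder attention.
    decoder = decoder + new_decoder
    # Feed-forward block on the decoder (ResidualFeedForward above).
    decoder = feed_forward(decoder)
    return encoder, mask, decoder

# Tiny shape-level demo with no-op stand-ins.
enc = np.random.randn(2, 5, 8)
dec = np.random.randn(2, 7, 8)
msk = np.ones((2, 1, 1, 5), dtype=bool)
out = encoder_decoder_step(
    enc, msk, dec,
    self_attention=np.tanh,
    enc_dec_attention=lambda q, k, v, m: np.zeros_like(q),
    feed_forward=lambda x: x)
assert out[2].shape == dec.shape
```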
Example #4
    def __init__(self, residual_layers):
        self.compute_residual = tl.Serial([
            # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
            tl.Parallel([], tl.Dup()),
            tl.Swap(),
            tl.Parallel(residual_layers, [], []),
        ])

        layers = [self.compute_residual, tl.Parallel(tl.Add(), [])]
        super(ReversibleHalfResidual, self).__init__(layers)

        self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
        self.reverse_layers = [self.compute_residual, self.subtract_top]
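The only change from Example #2 is how the stack shuffle (x1, x2) -> (x2, x1, x2) is expressed: `Parallel([], Dup())` followed by `Swap()` instead of a named `Select`. A pure-Python sketch of that shuffle, modelling the stack as a plain tuple rather than using real Trax layers:

```python
def stack_shuffle(x1, x2):
    # Parallel([], Dup()): leave x1 alone, duplicate x2 -> (x1, x2, x2).
    stack = (x1, x2, x2)
    # Swap(): exchange the two top-most stack items -> (x2, x1, x2).
    stack = (stack[1], stack[0], stack[2])
    return stack

assert stack_shuffle('x1', 'x2') == ('x2', 'x1', 'x2')
```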
Example #5
  def loss(mask_id=None, has_weights=False):
    """Cross-entropy loss as scalar compatible with Trax masking."""
    return layers.Serial(
        # Swap from (pred-obs, pred-reward, target-obs, target-reward)
        # to (pred-obs, target-obs, pred-reward, target-reward).
        layers.Parallel([], layers.Swap()),
        # Cross-entropy loss for obs, L2 loss on reward.
        layers.Parallel(
            layers.CrossEntropyLossScalar(mask_id, has_weights),
            layers.L2LossScalar(mask_id, has_weights)),
        # Add both losses.
        layers.Add(),
        # Zero out in this test.
        layers.MulConstant(constant=0.0))
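A simplified numpy sketch of what this combined loss computes; the two loss terms are stand-ins for `CrossEntropyLossScalar` and `L2LossScalar` (masking and weights are ignored), and `combined_loss` and its parameters are hypothetical names.

```python
import numpy as np

def combined_loss(pred_obs_logits, pred_reward, target_obs, target_reward,
                  constant=0.0):
    # Cross-entropy over observation symbols (stand-in for CrossEntropyLossScalar).
    log_probs = pred_obs_logits - np.log(
        np.sum(np.exp(pred_obs_logits), axis=-1, keepdims=True))
    ce = -np.mean(np.take_along_axis(log_probs, target_obs[..., None], axis=-1))
    # L2 loss on the predicted reward (stand-in for L2LossScalar).
    l2 = np.mean((pred_reward - target_reward) ** 2)
    # Add both losses, then multiply by the constant (0.0 zeroes it out in the test).
    return (ce + l2) * constant

logits = np.random.randn(4, 10, 32)              # (batch, time, obs vocab)
obs_targets = np.random.randint(0, 32, (4, 10))  # target observation symbols
pred_r = np.random.randn(4, 10)
target_r = np.random.randn(4, 10)
print(combined_loss(logits, pred_r, obs_targets, target_r, constant=0.0))  # -> 0.0
```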