Example #1
 def test_swap(self):
   layer = tl.Swap()
   xs = [np.array([1, 2, 3]),
         np.array([10, 20, 30])]
   ys = layer(xs)
   self.assertEqual(as_list(ys), [[10, 20, 30],
                                  [1, 2, 3]])
Example #2
  def __init__(self, pre_attention, attention, post_attention):
    self.pre_attention = tl.Serial(
        # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
        tl.Parallel([], tl.Dup()),
        tl.Swap(),
        tl.Parallel(pre_attention, [], []),
    )
    assert hasattr(attention, 'forward_and_backward')
    self.attention = ApplyAttentionWrapper(attention)
    self.post_attention = tl.Parallel(post_attention, [], [])

    layers = [
        self.pre_attention,
        self.attention,
        self.post_attention,
        tl.Parallel(tl.Add(), []),
    ]
    super(ReversibleAttentionHalfResidual, self).__init__(layers)

    self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
    self.reverse_layers = [
        self.pre_attention,
        self.attention,
        self.post_attention,
        self.subtract_top,
    ]
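The first two combinators in pre_attention only rearrange the stack. A minimal sketch of that shuffle in isolation, assuming (as Example #1 shows for a bare Swap) that weightless layers can be applied directly to inputs:

import numpy as np
from trax import layers as tl

shuffle = tl.Serial(tl.Parallel([], tl.Dup()), tl.Swap())
x1, x2 = np.array([1.0]), np.array([2.0])
# (x1_or_y1, x2) -> (x2, x1_or_y1, x2), matching the comment in pre_attention.
print(shuffle((x1, x2)))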
Example #3
    def __init__(self, residual_layers):
        self.compute_residual = tl.Serial(
            # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
            tl.Parallel([], tl.Dup()),
            tl.Swap(),
            tl.Parallel(residual_layers, [], []),
        )

        layers = [self.compute_residual, tl.Parallel(tl.Add(), [])]
        super(ReversibleHalfResidual, self).__init__(layers)

        self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
        self.reverse_layers = [self.compute_residual, self.subtract_top]
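The reversibility implemented by these layer lists comes down to simple arithmetic: the forward path adds the recomputed residual, the reverse path subtracts it. A NumPy-only illustration (not Trax code; f stands in for residual_layers):

import numpy as np

f = lambda x: 3.0 * x                  # stands in for residual_layers
x1, x2 = np.array([1.0, 2.0]), np.array([0.5, 0.5])
y1 = x1 + f(x2)                        # compute_residual, then tl.Add()
x1_recovered = y1 - f(x2)              # reverse: recompute residual, tl.SubtractTop()
assert np.allclose(x1, x1_recovered)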
Example #4
 def loss(mask_id=None, has_weights=False):
     """Cross-entropy loss as scalar compatible with Trax masking."""
     return layers.Serial(
         # Swap from (pred-obs, pred-reward, target-obs, target-reward)
         # to (pred-obs, target-obs, pred-reward, target-reward).
         layers.Parallel([], layers.Swap()),
         # Cross-entropy loss for obs, L2 loss on reward.
         layers.Parallel(
             layers.CrossEntropyLossScalar(mask_id, has_weights),
             layers.L2LossScalar(mask_id, has_weights)),
         # Add both losses.
         layers.Add(),
         # Zero out in this test.
         layers.MulConstant(constant=0.0))
Example #5
 def loss(id_to_mask=None, has_weights=False):
   """Cross-entropy loss as scalar compatible with Trax masking."""
   return layers.Serial(
       # Swap from (pred-obs, pred-reward, target-obs, target-reward)
       # to (pred-obs, target-obs, pred-reward, target-reward).
       layers.Parallel([], layers.Swap()),
       # Cross-entropy loss for obs, L2 loss on reward.
       layers.Parallel(layers.CrossEntropyLoss(id_to_mask, has_weights),
                       layers.L2Loss(id_to_mask, has_weights)),
       # Add both losses.
       layers.Add(),
       # Zero out in this test.
       layers.Fn(lambda x: x * 0.0),
   )
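The swap described in the comment is done by layers.Parallel([], layers.Swap()): it leaves the top stack item alone and swaps the next two, while anything deeper passes through untouched. A small sketch, under the same direct-call assumption as Example #1:

import numpy as np
from trax import layers as tl

reorder = tl.Parallel([], tl.Swap())
a, b, c = np.array([0]), np.array([1]), np.array([2])
# (a, b, c) -> (a, c, b); in the loss above, the fourth item (target-reward)
# sits below these three and stays in place.
print(reorder((a, b, c)))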
Example #6
def EncoderDecoder(d_model, d_ff, n_heads, dropout, layer_idx, mode,
                   ff_activation):
    """Transformer encoder-decoder layer.

  The input is a triple (decoder_input, mask, encoder) where the mask is
  created from the original source to prevent attending to the padding part
  of the encoder.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    layer_idx: which layer are we at (for bookkeeping)
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    the layer, returning a triple (decoder_activations, mask, encoder).
  """
    decoder_self_attention = [  #        vecs_d   pmask vecs_e
        tl.LayerNorm(),  #        vecs_d   ..... ......
        tl.BasicCausalAttention(d_model,
                                n_heads=n_heads,
                                dropout=dropout,
                                mode=mode),
        tl.Dropout(rate=dropout, mode=mode),  # vecs_d          ..... ......
    ]
    decoder_to_encoder_attention = [  # vecs_d        masks         vecs_e
        tl.LayerNorm(),  # vecs_d        masks         vecs_e
        tl.Parallel([], [], tl.Dup()),  # ______        _____  vecs_e vecs_e
        tl.Parallel([], tl.Swap()),  # ______        vecs_e masks  ......
        tl.Parallel([], tl.Dup()),  # ______ vecs_e vecs_e .....  ......
        tl.AttentionQKV(  # (q k v masks ... --> vecs_d masks ...)
            d_model,
            n_heads=n_heads,
            dropout=dropout,
            mode=mode),
        tl.Dropout(rate=dropout, mode=mode),  # vecs_d mask vecs_e
    ]
    feed_forward = [
        FeedForward(d_model, d_ff, dropout, layer_idx, mode, ff_activation),
    ]
    return tl.Serial(  # vecs_d masks vecs_e
        tl.Residual(decoder_self_attention),  # vecs_d masks vecs_e
        tl.Residual(decoder_to_encoder_attention),  # vecs_d masks vecs_e
        tl.Residual(feed_forward),  # vecs_d masks vecs_e
    )
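The three Parallel lines in decoder_to_encoder_attention only shuffle the stack so that AttentionQKV receives (q, k, v, masks, vecs_e). A check of that shuffle in isolation, again assuming weightless layers can be applied directly as in Example #1:

import numpy as np
from trax import layers as tl

rearrange = tl.Serial(
    tl.Parallel([], [], tl.Dup()),   # vecs_d masks  vecs_e vecs_e
    tl.Parallel([], tl.Swap()),      # vecs_d vecs_e masks  vecs_e
    tl.Parallel([], tl.Dup()),       # vecs_d vecs_e vecs_e masks vecs_e
)
vecs_d, masks, vecs_e = np.array([1]), np.array([2]), np.array([3])
print(rearrange((vecs_d, masks, vecs_e)))  # q, k, v, masks, vecs_e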
Example #7
 def loss():
     """Cross-entropy loss as scalar compatible with Trax masking."""
     ones = layers.Fn(lambda x: math.numpy.ones_like(x))  # pylint: disable=unnecessary-lambda
     return layers.Serial(
         # Swap from (pred-obs, pred-reward, target-obs, target-reward)
         # to (pred-obs, target-obs, pred-reward, target-reward).
         layers.Parallel([], layers.Swap()),
         # Duplicate target-obs and target-reward and make 1 to add weights.
         layers.Parallel([], layers.Branch([], ones)),
         layers.Parallel([], [], [], [], layers.Branch([], ones)),
         # Cross-entropy loss for obs, L2 loss on reward.
         layers.Parallel(layers.CrossEntropyLoss(), layers.L2Loss()),
         # Add both losses.
         layers.Add(),
         # Zero out in this test.
         layers.Fn(lambda x: x * 0.0),
     )
Example #8
  def __init__(self, residual_layers):
    self.compute_residual = tl.Serial(         # x1_or_y1, x2,           ...
        tl.Parallel([], tl.Dup()),             # x1_or_y1, x2, x2,       ...
        tl.Swap(),                             # x2, x1_or_y1, x2,       ...
        tl.Parallel([], [], residual_layers),  # x2, x1_or_y1, residual, ...
        tl.Select([2, 1, 0]),                  # residual, x1_or_y1, x2, ...
    )

    self.n_preserve = self.compute_residual.n_out - 2
    parallel_preserve = [[]] * self.n_preserve

    layers = [
        self.compute_residual,
        tl.Parallel(tl.Add(), *parallel_preserve)
    ]
    super(ReversibleHalfResidual, self).__init__(layers)

    self.subtract_top = tl.Parallel(tl.SubtractTop(), *parallel_preserve)
    self.reverse_layers = [self.compute_residual, self.subtract_top]
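tl.Select([2, 1, 0]) reverses the order of the top three stack items, which is what brings the residual to the top above. A quick check under the same direct-call assumption as Example #1:

import numpy as np
from trax import layers as tl

reverse_top3 = tl.Select([2, 1, 0])
a, b, c = np.array([1]), np.array([2]), np.array([3])
print(reverse_top3((a, b, c)))  # expected order: c, b, a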
Example #9
def ApplyAndQueryPositions(layer, pos):
    """Execute layer without position and pos-layers on positions.

  This takes an embedding including position x = (emb, p), and
  outputs layer(emb).pos1(x, p).....layer(emb).posn(x, p)
  where pos=[pos1...posn].

  Args:
    layer: layer to be executed without position information.
    pos: list of layers to be applied to positions.

  Returns:
    the result of this application.
  """
    n_heads = len(pos)
    return tl.Serial(
        tl.Dup(),  # (x, x)
        CutAtPosition(),  # (x_content, x_position, x)
        tl.Parallel([], tl.Swap()),  # (x_content, x, x_position)
        [tl.Parallel([], Dup2()) for _ in range(n_heads - 1)],
        # Now the stack is x_content, (x, x_position) * n_heads.
        tl.Parallel(*([layer] + pos)),
        tl.Concatenate(n_items=n_heads + 1))
Example #10
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_encoder_layers=6,
                n_decoder_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
    """Returns a Transformer model.

  This model expects an input pair: target, source.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    A Transformer model as a layer that maps from a target, source pair to
    activations over a vocab set.
  """
    in_embed = [  # tokens
        tl.Embedding(d_model, input_vocab_size),  # vecs
        tl.Dropout(rate=dropout, mode=mode),  # vecs
        tl.PositionalEncoding(max_len=max_len),  # vecs
    ]

    if output_vocab_size is None:
        output_vocab_size = input_vocab_size
        out_embed = in_embed
    else:
        out_embed = [  # tokens
            tl.Embedding(d_model, output_vocab_size),  # vecs
            tl.Dropout(rate=dropout, mode=mode),  # vecs
            tl.PositionalEncoding(max_len=max_len),  # vecs
        ]

    encoder_stack = (  # masks vectors --> masks vectors
        [
            EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode)
            for i in range(n_encoder_layers)
        ])

    encoder_decoder_stack = (  # vecs_d masks vecs_e --> vecs_d masks vecs_e
        [
            EncoderDecoder(d_model, d_ff, n_heads, dropout, i, mode)
            for i in range(n_decoder_layers)
        ])

    # Input: encoder_side_tokens, decoder_side_tokens
    return tl.Serial(  # tokens_e tokens_d
        tl.Parallel([], tl.Dup()),  # toks_e toks_d toks_d (for loss)
        tl.Swap(),  # toks_d toks_e ....

        # Encode.
        tl.Parallel(  # toks_d        toks_e
            [],
            [
                tl.Dup(),  # ______ toks_e toks_e
                tl.Parallel(in_embed, tl.PaddingMask()),  # ______ vecs_e masks
                encoder_stack,  # ______ vecs_e masks
                tl.LayerNorm(),  # ______ vecs_e .....
                tl.Swap()
            ]),  # ______ masks  vecs_e

        # Decode.                                  #        toks_d masks vecs_e
        tl.ShiftRight(),  #        toks_d ..... ......
        out_embed,  #        vecs_d ..... ......
        tl.Dup(),  # vecs_d vecs_d ..... ......
        tl.Parallel([], tl.EncoderDecoderMask()),  # ______    masks     ......
        encoder_decoder_stack,  # vecs_d    masks     vecs_e
        tl.Parallel([], tl.Drop(), tl.Drop()),  # vecs_d
        tl.LayerNorm(),  # vecs_d
        tl.Dense(output_vocab_size),  # vecs_d
        tl.LogSoftmax(),  # vecs_d
    )
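A hypothetical construction sketch with small sizes; it assumes the EncoderBlock and EncoderDecoder helpers from the same module as this snippet are importable:

tiny_transformer = Transformer(input_vocab_size=1000,
                               d_model=64,
                               d_ff=128,
                               n_encoder_layers=2,
                               n_decoder_layers=2,
                               n_heads=2,
                               mode='eval')
# tiny_transformer is a tl.Serial layer mapping the token pair described above
# to log-probabilities over output_vocab_size (equal to input_vocab_size here).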
Example #11
def LayerDropSkippingTransformerLM(vocab_size,
                                   d_model=512,
                                   d_ff=2048,
                                   n_layers=6,
                                   n_heads=8,
                                   dropout=0.1,
                                   max_len=2048,
                                   mode='train',
                                   ff_activation=tl.Relu,
                                   skip_fraction=0.4):
  """Returns a Skipping Transformer language model.

  The input to the model is a tensor of tokens. (This model uses only the
  decoder part of the overall Transformer.)

  Args:
    vocab_size: int: vocab size
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of encoder/decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference
    ff_activation: the non-linearity in feed-forward layer
    skip_fraction: fraction of times to skip some layers

  Returns:
    A Transformer language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.
  """
  embedder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, mode=mode),
      tl.PositionalEncoding(max_len=max_len, mode=mode),
  ]

  def ConditionedBlock(current_layer_num):
    return tl.Serial(
        # stack: embedding, n_layers_to_keep
        tl.Select([1, 0, 1]),  # n_layers_to_keep, embedding, n_layers_to_keep
        tl.Cond(
            # if n_layers_to_keep > current_layer_num
            LargerThan(float(current_layer_num)),
            # then: run block
            tl.Serial(transformer._DecoderBlock(  # pylint: disable=g-complex-comprehension,protected-access
                d_model, d_ff, n_heads, dropout, [], mode, ff_activation)),
            # else: run noop
            tl.Serial()
            )
        # stack: embedding, n_layers_to_keep
        )

  if mode == 'train':
    minimum_layers = 0.0
    maximum_layers = float(n_layers) / skip_fraction
  else:
    minimum_layers = maximum_layers = float(n_layers)

  return tl.Serial(
      tl.ShiftRight(mode=mode),
      embedder,
      # stack: embedding
      tl.RandomUniform(minimum_layers, maximum_layers, sync=True),
      # stack: n_layers_to_keep, embedding
      tl.Swap(),
      # stack: embedding, n_layers_to_keep
      [ConditionedBlock(i) for i in range(n_layers)],
      # stack: embedding, n_layers_to_keep
      tl.Select([0], n_in=2),  # stack: embedding
      tl.LayerNorm(),
      tl.Dense(vocab_size),
      tl.LogSoftmax(),
  )
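The skipping rule can be read off in plain Python. This stdlib-only sketch (not Trax code) mirrors how the single synced RandomUniform draw gates every ConditionedBlock:

import random

n_layers, skip_fraction = 6, 0.4
# One shared draw per training step, as with tl.RandomUniform(..., sync=True).
n_layers_to_keep = random.uniform(0.0, n_layers / skip_fraction)
# Block i runs iff n_layers_to_keep > i, mirroring LargerThan(float(i)) in tl.Cond.
runs = [n_layers_to_keep > i for i in range(n_layers)]
print(n_layers_to_keep, runs)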
Example #12
def EveryOtherLayerDropTransformerLM(vocab_size,
                                     d_model=512,
                                     d_ff=2048,
                                     n_layers=6,
                                     n_heads=8,
                                     dropout=0.1,
                                     max_len=2048,
                                     mode='train',
                                     ff_activation=tl.Relu,
                                     skip_mode='even',
                                     skip_fraction=0.5,
                                     eval_skip_fraction=0.0):
    """Returns an "EveryOther" LayerDrop Transformer language model.

  During each training step it either runs all layers, or skips a subset of
  layers. This subset is the same every time, and it is specified by
  "skip_mode".
  The input to the model is a tensor of tokens. (This model uses only the
  decoder part of the overall Transformer.)

  Args:
    vocab_size: int: vocab size
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of encoder/decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference
    ff_activation: the non-linearity in feed-forward layer
    skip_mode: which layers to skip when skipping: even/odd/1half/2half.
    skip_fraction: fraction of times to skip layers
    eval_skip_fraction: fraction of times to skip layers during eval

  Returns:
    A Transformer language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.
  """
    embedder = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len, mode=mode),
    ]

    if mode == 'train':
        pass
    else:
        skip_fraction = eval_skip_fraction

    skip_mode_funs = {  # which layers should be skipped?
        'even': (lambda num: num%2 == 0),  # 0th layer is even
        'odd': (lambda num: num%2 == 1),
        '1half': (lambda num: num < (n_layers/2)),
        '2half': (lambda num: num >= (n_layers/2)),
    }

    skip_mode_fun = skip_mode_funs[skip_mode]

    @assert_shape('...sd,->...sd,')
    def ConditionedBlock(current_layer_num):
        return tl.Serial(
            # stack: embedding, n_layers_to_keep
            tl.Select([1, 0, 1]),  # n_layers_to_keep, embedding, n_layers_to_keep
            tl.Cond(
                # if random() > skip_fraction OR layer not in skip_mode ...
                LargerThan(skip_fraction
                           if skip_mode_fun(current_layer_num) else 0.0),
                # then: run block
                tl.Serial(
                    transformer._DecoderBlock(  # pylint: disable=g-complex-comprehension,protected-access
                        d_model, d_ff, n_heads, dropout, [], mode,
                        ff_activation))
                # else: noop (implicit)
            )
            # stack: embedding, n_layers_to_keep
        )

    return tl.Serial(
        tl.ShiftRight(mode=mode),
        embedder,
        # stack: embedding
        tl.RandomUniform(0., 1., sync=True),
        # stack: n_layers_to_keep, embedding
        tl.Swap(),
        # stack: embedding, n_layers_to_keep
        [ConditionedBlock(i) for i in range(n_layers)],
        # stack: embedding, n_layers_to_keep
        tl.Select([0], n_in=2),  # stack: embedding
        tl.LayerNorm(),
        tl.Dense(vocab_size),
    )
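A stdlib-only sketch (not Trax code) of the per-step decision encoded above by RandomUniform, LargerThan and tl.Cond: layers outside the skip set always run, and layers in the skip set all run or all get skipped together, since every block compares the same shared draw:

import random

n_layers, skip_fraction, skip_mode = 6, 0.5, 'even'
in_skip_set = {
    'even': lambda n: n % 2 == 0,
    'odd': lambda n: n % 2 == 1,
    '1half': lambda n: n < n_layers / 2,
    '2half': lambda n: n >= n_layers / 2,
}[skip_mode]
draw = random.uniform(0.0, 1.0)  # one shared draw per step (sync=True)
runs = [draw > (skip_fraction if in_skip_set(i) else 0.0) for i in range(n_layers)]
print(draw, runs)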
Example #13
def Dup2():
    """Copy first 2 elements of the stack: (a, b, ...) -> (a, b, a, b, ...)."""
    return [  # Stack is (a, b, ...)
        tl.Parallel(tl.Dup(), tl.Dup()),  # Stack is (a, a, b, b, ...)
        tl.Parallel([], tl.Swap())  # Stack is (a, b, a, b, ...)
    ]
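A quick check of Dup2's stack effect, assuming (as in Example #1) that weightless combinators can be applied directly to inputs:

import numpy as np
from trax import layers as tl

layer = tl.Serial(Dup2())            # Serial flattens the layer list
a, b = np.array([1]), np.array([2])
print(layer((a, b)))                 # expected order: a, b, a, b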