Esempio n. 1
0
    def test_repr(self):
        """Checks the `repr` of `Concatenate` for several constructor arguments."""
        # Each entry pairs constructor keyword arguments with the expected repr.
        cases = (
            ({}, 'Concatenate_in2'),
            ({'axis': 0}, 'Concatenate_axis0_in2'),
            ({'axis': 1}, 'Concatenate_axis1_in2'),
            ({'n_items': 3}, 'Concatenate_in3'),
        )
        for ctor_kwargs, expected_repr in cases:
            self.assertEqual(repr(tl.Concatenate(**ctor_kwargs)),
                             expected_repr)
Esempio n. 2
0
    def create_reformer_blocks(n_layers, dense=True):  # pylint: disable=invalid-name
        """Returns the layer list for one reversible stack of decoder blocks.

        Args:
          n_layers: Number of `DecoderBlock`s to build. With 0 the result is
              just a single `tl.LayerNorm`.
          dense: If true, a `tl.Dense(d_model)` follows the final layer norm;
              otherwise an empty list is placed in that slot.

        Returns:
          A list of layers: `Dup`, a `ReversibleSerial` over the decoder
          blocks, `Concatenate`, `LayerNorm`, and the optional `Dense`.
        """
        if n_layers == 0:
            return [tl.LayerNorm()]

        # The same per-head depth is passed twice to DecoderBlock
        # (presumably as key and value depth -- confirm against DecoderBlock).
        head_depth = d_model // n_heads

        blocks = []
        for _ in range(n_layers):
            blocks.append(
                DecoderBlock(
                    d_model,
                    d_ff,
                    head_depth,
                    head_depth,
                    n_heads,
                    vanilla_attn_type,
                    dropout,
                    ff_activation,
                    dropout,
                    ff_use_sru=0,
                    ff_chunk_size=0,
                    ff_sparsity=0,
                    attention_chunk_size=0,
                    mode=mode))

        stack = [
            tl.Dup(),
            tl.ReversibleSerial(blocks),
            tl.Concatenate(),
            tl.LayerNorm(),
        ]
        stack.append(tl.Dense(d_model) if dense else [])
        return stack
Esempio n. 3
0
def PoolLayer(pool_layer=tl.AvgPool,
              pool_size=(2,),
              strides=(2,),
              separate_cls=True):
  """Builds the pooling layer used by Funnel Transformer.

  Args:
    pool_layer: Pooling layer type used for downsampling; should be
        `tl.AvgPool` or `tl.MaxPool`.
    pool_size: Shape of the window that is reduced to a single vector value.
        For :math:`n`-dimensional layer inputs, `pool_size` must be a tuple of
        length :math:`n-2`.
    strides: Offsets from one window to its neighbors along each axis. If
        given, must be a tuple with the same length as `pool_size`. If None,
        offsets of 1 along each window axis, :math:`(1, ..., 1)`, are used.
    separate_cls: If `True`, the embedding of the first token (`cls` from the
        BERT paper) is excluded from pooling and re-attached afterwards.

  Returns:
    The bare `pool_layer(pool_size, strides)` when `separate_cls` is false;
    otherwise a `tl.Serial` that pools everything after the first position and
    concatenates the untouched first position back along axis 1.
  """
  if not separate_cls:
    return pool_layer(pool_size, strides)

  # Position 0 on axis 1 bypasses pooling; the remaining positions are pooled.
  keep_first = tl.Fn('select_cls_token', lambda x: x[:, :1, :])
  drop_first = tl.Fn('rest_tokens', lambda x: x[:, 1:, :])
  pooled_rest = tl.Serial(drop_first, pool_layer(pool_size, strides))

  return tl.Serial(
      tl.Branch(keep_first, pooled_rest),
      tl.Concatenate(axis=1),
  )
Esempio n. 4
0
 def MultiRNNCell():
     """Returns a multi-layer RNN cell built from `n_layers` `rnn_cell`s.

     The second input is split into `n_layers` items, the cells are chained
     with `tl.SerialWithSideOutputs`, and `n_layers` items are concatenated
     back into one on the way out.
     """
     split_state = tl.Parallel([], tl.Split(n_items=n_layers))
     stacked_cells = tl.SerialWithSideOutputs(
         [rnn_cell(n_units=d_model) for _ in range(n_layers)])
     merge_state = tl.Parallel([], tl.Concatenate(n_items=n_layers))
     return tl.Serial(split_state, stacked_cells, merge_state)
Esempio n. 5
0
def MultiplicativeConvCausalAttention(d_feature,
                                      n_heads=1,
                                      sparsity=None,
                                      length_kernel_size=3,
                                      dropout=0.0,
                                      max_inference_length=2048,
                                      mode='train'):
    """Returns a causally masked attention layer with sparse/conv Q, K, V.

    Like `CausalAttention`, this layer type represents one pass of multi-head
    self-attention with causal masking rather than padding-based masking.
    Instead of a Dense layer, Q/K/V are computed by combining a
    MultiplicativeSparseDense layer with LocallyConvDense layers.

    Args:
      d_feature: Depth/dimensionality of feature embedding.
      n_heads: Number of attention heads.
      sparsity: The sparsity of the layer; usually it should be equal to
          n_heads, which is also the value used when this is None.
      length_kernel_size: Size of convolution kernel on the length dimension.
      dropout: Probabilistic rate for internal dropout applied to attention
          activations (based on query-key pairs) before dotting them with
          values.
      max_inference_length: Maximum length for inference.
      mode: One of `'train'`, `'eval'`, or `'predict'`.
    """
    if sparsity is None:
        sparsity = n_heads
    d_module = d_feature // sparsity

    def _conv_then_split(kernel_size):
        # One projection branch: locally-convolved dense, then split heads.
        return [
            LocallyConvDense(sparsity,
                             d_module,
                             kernel_size=kernel_size,
                             length_kernel_size=length_kernel_size),
            tl.SplitIntoHeads(n_heads),
        ]

    duplicate_input = tl.Select([0, 0])  # duplicate activations
    shared_qk = MultiplicativeSparseDense(sparsity, d_feature,
                                          d_feature)  # shared q, k
    triplicate = tl.Select([0, 0, 0])  # use for q, k, v

    query_branch = _conv_then_split(3)
    key_branch = _conv_then_split(3)
    # The value branch first joins the permuted and the original activations.
    value_branch = [tl.Concatenate()] + _conv_then_split(1)

    return tl.Serial(
        duplicate_input,
        shared_qk,
        triplicate,
        tl.Parallel(query_branch, key_branch, value_branch),
        tl.DotProductCausalAttention(dropout=dropout,
                                     max_inference_length=max_inference_length,
                                     mode=mode),
        tl.MergeHeads(n_heads),
    )
Esempio n. 6
0
    def test_run_reversible_same_as_default_extended(self):
        """Checks the reversible trainer reproduces the default trainer.

        The model has two reversible sections separated by a standard middle
        layer. Three steps are taken with `optimizers.Trainer`, the model is
        re-initialized with the same rng, the same three steps are taken with
        `optimizers.ReversibleSerialTrainer`, and the weights of selected
        layers are compared.
        """
        inputs = np.arange(8).reshape((2, 4))
        targets = 2 * inputs
        batch = (inputs, targets, np.ones_like(targets))

        # Dropout layers are included so that rng propagation is tested too.
        first_layer = tl.Serial(tl.Embedding(9, 4), tl.Dropout(0.5), tl.Dup())
        rev_layers1 = [
            tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.2)),
            tl.ReversibleSwap(),
            tl.ReversibleHalfResidual(tl.Dropout(0.5), tl.Dense(4)),
            tl.ReversibleSwap(),
        ]
        mid_layer = tl.Serial(tl.Add(), tl.Dense(4), tl.Dup())
        rev_layers2 = [
            tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.3)),
            tl.ReversibleSwap(),
        ]
        loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(19),
                               tl.Dropout(0.3), tl.LogSoftmax(),
                               tl.CrossEntropyLoss())
        all_layers = ([first_layer] + rev_layers1 + [mid_layer] +
                      rev_layers2 + [loss_layer])
        model = tl.Serial(all_layers)
        init_rng = fastmath.random.get_prng(12)
        model.init(batch, rng=init_rng)
        optimizer_fn = optimizers.Adam  # to test slots

        # Layers whose weights are compared, in the order they are checked.
        tracked = [loss_layer, rev_layers2[0], mid_layer, rev_layers1[2],
                   first_layer]

        # Make 3 steps with the original trainer.
        optimizer = optimizer_fn()
        optimizer.tree_init(model.weights)
        trainer = optimizers.Trainer(model, optimizer)
        step_rngs = [fastmath.random.get_prng(seed) for seed in (7, 8, 9)]

        def run_three_steps(stepper):
            # The first step uses the default learning rate, the others set it.
            stepper.one_step(batch, step_rngs[0])
            for rng, rate in zip(step_rngs[1:], (0.02, 0.03)):
                stepper.one_step(batch, rng, learning_rate=rate)

        run_three_steps(trainer)
        reference_weights = [layer.weights for layer in tracked]

        # Now make 3 steps with reversible trainer.
        model.init(batch, rng=init_rng)
        blocks = [(first_layer.sublayers, rev_layers1),
                  (mid_layer.sublayers, rev_layers2)]
        trainer = optimizers.ReversibleSerialTrainer(blocks, loss_layer,
                                                     optimizer_fn)
        run_three_steps(trainer)

        # Check that weights end up the same.
        for expected, layer in zip(reference_weights, tracked):
            self._assert_all_equal(expected, layer.weights)
Esempio n. 7
0
 def test_with_defaults(self):
     """Checks `Concatenate()` joins two 2-D inputs along the last axis."""
     base = np.array([[1, 2, 3], [4, 5, 6]])
     inputs = [base, np.array([[10, 20, 30], [40, 50, 60]])]
     concat = tl.Concatenate()  # Default n_items=2, axis=-1
     joined = concat(inputs)
     expected = [[1, 2, 3, 10, 20, 30], [4, 5, 6, 40, 50, 60]]
     self.assertEqual(as_list(joined), expected)
Esempio n. 8
0
 def test_axis_1(self):
     """Checks `Concatenate(axis=1)` on a pair of 2-D inputs."""
     left = np.array([[1, 2, 3], [4, 5, 6]])
     right = np.array([[10, 20, 30], [40, 50, 60]])
     concat = tl.Concatenate(axis=1)
     joined = concat([left, right])
     # For 2-D inputs axis 1 is the last axis, so rows are extended.
     expected = [[1, 2, 3, 10, 20, 30], [4, 5, 6, 40, 50, 60]]
     self.assertEqual(as_list(joined), expected)
Esempio n. 9
0
  def create_reformer_blocks(  # pylint: disable=invalid-name
      n_layers,
      total_kv_pooling=1,
      layer_chunk_len=None,
      force_relative=False,
      dense=True):
    """Returns the layer list for one reversible stack of decoder blocks.

    Args:
      n_layers: Number of `DecoderBlock`s to build. With 0 the result is just
          a single `tl.LayerNorm`.
      total_kv_pooling: Forwarded to `RelativeAttentionWrapper` when relative
          attention is used.
      layer_chunk_len: If not None, relative attention is used with this chunk
          length, and odd-numbered layers get a chunk offset of half of it.
      force_relative: If true, relative attention is used even when
          `layer_chunk_len` is None.
      dense: If true, a `tl.Dense(d_model)` follows the final layer norm;
          otherwise an empty list is placed in that slot.

    Returns:
      A list of layers: `Dup`, a `ReversibleSerial` over the decoder blocks,
      `Concatenate`, `LayerNorm`, and the optional `Dense`.
    """
    if n_layers == 0:
      return [tl.LayerNorm()]

    def attn_type_for(layer_index):
      # Plain attention unless chunking or relative attention was requested.
      wants_relative = force_relative or layer_chunk_len is not None
      if not wants_relative:
        return vanilla_attn_type

      offset = None
      if layer_chunk_len is not None:
        # Alternate layers shift their chunks by half a chunk length.
        offset = (layer_index % 2) * (layer_chunk_len // 2)

      return functools.partial(
          RelativeAttentionWrapper,
          n_raw_tokens_generated=n_raw_tokens_generated,
          max_inference_length=max_len,
          total_kv_pooling=total_kv_pooling,
          chunk_len=layer_chunk_len,
          chunk_offset=offset)

    head_depth = d_model // n_heads

    blocks = [
        DecoderBlock(
            d_model,
            d_ff,
            head_depth,
            head_depth,
            n_heads,
            attn_type_for(layer_index),
            dropout,
            ff_activation,
            dropout,
            ff_use_sru=0,
            ff_chunk_size=0,
            ff_sparsity=0,
            attention_chunk_size=0,
            mode=mode)
        for layer_index in range(n_layers)
    ]

    stack = [
        tl.Dup(),
        tl.ReversibleSerial(blocks),
        tl.Concatenate(),
        tl.LayerNorm(),
    ]
    stack.append(tl.Dense(d_model) if dense else [])
    return stack
Esempio n. 10
0
 def test_n_items_is_not_default(self):
     """Checks `Concatenate(n_items=3)` joins three inputs on the last axis."""
     base = np.array([[1, 2, 3], [4, 5, 6]])
     # Three (2, 3) inputs: the base values scaled by 1, 10 and 100.
     inputs = [base * scale for scale in (1, 10, 100)]
     concat = tl.Concatenate(n_items=3)
     joined = concat(inputs)
     self.assertEqual(joined.shape, (2, 9))
     expected = [
         [1, 2, 3, 10, 20, 30, 100, 200, 300],
         [4, 5, 6, 40, 50, 60, 400, 500, 600],
     ]
     self.assertEqual(as_list(joined), expected)
Esempio n. 11
0
 def test_run_reversible_slots(self):
   """Checks that reversible-trainer slots survive a read/assign round trip."""
   std_layers = [tl.Dense(4), tl.Dup()]
   reversible = [tl.ReversibleHalfResidual(tl.Dense(4)), tl.ReversibleSwap()]
   loss = tl.Serial(
       tl.Concatenate(), tl.Dense(4), tl.LogSoftmax(), tl.CrossEntropyLoss())
   blocks = [(std_layers, reversible)]
   trainer = optimizers.ReversibleSerialTrainer(blocks, loss, optimizers.Adam)
   original_slots = trainer.slots
   trainer.slots = original_slots  # Assigning the slots back must be accepted.
   self.assertEqual(original_slots, trainer.slots)
Esempio n. 12
0
def _FrameStack(n_frames):
    """Returns layers that stack successive game frames on the last axis.

    Input shape: (B, T, ..., C).
    Output shape: (B, T, ..., C * n_frames).

    Args:
      n_frames: Number of frames to stack; must be at least 1.

    Returns:
      An empty list when `n_frames` is 1 (data flows through unchanged),
      otherwise a `Branch` of shifted copies followed by a `Concatenate`.
    """
    assert n_frames >= 1
    if n_frames != 1:
        # One copy of the input sequence per shift in [0, ..., n_frames - 1].
        shifted_copies = [_shift_right(shift) for shift in range(n_frames)]
        return [
            tl.Branch(*shifted_copies),
            # Join the copies along the channel dimension.
            tl.Concatenate(n_items=n_frames, axis=-1),
        ]
    return []  # No-op; just let the data flow through.
Esempio n. 13
0
    def test_run_reversible_large_weights(self):
        """Runs the reversible trainer with a lot of weights to test memory use.

        Builds an embedding layer followed by 18 reversible residual Dense
        layers whose weights are moved with `tl.on_cpu`, then takes one step
        with `ReversibleSerialTrainer` and checks the loss stays below a
        loose bound. On a single-device host the test is reported as skipped
        (rather than silently passing) unless `ram_limited` is set to False.
        """
        # This test requires > 18GB RAM, only run on TPUs. It does pass on GPU
        # and CPU when you run it locally, but it's too big for unit-testing.
        ram_limited = True  # Set to False to run this test locally.
        if fastmath.global_device_count() == 1 and ram_limited:
            # A bare `return` would count as a pass although nothing ran.
            self.skipTest('Requires > 18GB RAM; skipped on single-device hosts.')

        # Create inputs and rngs.
        inputs_batch = np.arange(8).reshape((2, 4))
        targets_batch = inputs_batch
        labeled_batch = (inputs_batch, targets_batch,
                         np.ones_like(targets_batch))
        first_layer = tl.Serial(tl.Embedding(9, 16 * 1024), tl.Dup())
        rng_init = fastmath.random.get_prng(12)
        rng_step = fastmath.random.get_prng(13)

        # Initialize layers.
        first_layer.init(labeled_batch, rng=rng_init)
        n_layers = 18  # 18 layers each 16K x 16K = 256M weights ~= 1GB, 18GB ram
        rev_layers = []
        int_shape = shapes.ShapeDtype((2, 4), dtype=np.int32)
        shape = shapes.ShapeDtype((2, 4, 16 * 1024))
        sig = (shape, shape)
        for _ in range(n_layers):
            layer = tl.ReversibleHalfResidual(tl.Dense(16 * 1024))
            layer.init(sig, rng=rng_init)
            layer.weights = tl.on_cpu(
                layer.weights)  # store weights in cpu memory
            rev_layers.append(layer)
            rev_layers.append(tl.ReversibleSwap())
        loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9), tl.LogSoftmax(),
                               tl.CrossEntropyLoss())
        loss_layer.init((shape, shape, int_shape, int_shape))
        optimizer_fn = optimizers.Adafactor

        # Make a step with reversible trainer.
        trainer = optimizers.ReversibleSerialTrainer(
            [(first_layer, rev_layers)], loss_layer, optimizer_fn)
        loss, _ = trainer.one_step(labeled_batch, rng_step)
        self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.
        # Set to true to run again, e.g., for profiling.
        run_twice = False
        if run_twice:
            t = time.time()
            loss, _ = trainer.one_step(labeled_batch, rng_step)
            self.assertLess(float(loss.sum()),
                            10000.0)  # Just to get the loss.
            print('Took %.3f seconds to run, loss %s' %
                  (time.time() - t, loss))
Esempio n. 14
0
def FrameStack(n_frames):
    """Returns layers that stack `n_frames` shifted copies on the last axis.

    Input shape: (B, T, ..., C).
    Output shape: (B, T, ..., C * n_frames).

    Args:
      n_frames: Number of frames to stack; must be at least 1.

    Returns:
      An empty tuple when `n_frames` is 1, otherwise a tuple of layers that
      duplicate the input, shift the copies, and concatenate them.
    """
    assert n_frames >= 1
    if n_frames == 1:
        return ()

    # n_frames - 1 duplications leave n_frames copies of the input sequence.
    duplicate = [tl.Dup()] * (n_frames - 1)
    # Copy number k is shifted to the right by k frames.
    shift_each = tl.Parallel(*[_shift_right(k) for k in range(n_frames)])
    # The shifted copies are joined along the channel dimension.
    join_channels = tl.Concatenate(n_items=n_frames, axis=-1)
    return (duplicate, shift_each, join_channels)
Esempio n. 15
0
 def model(mode):
     """Returns a small model that maps (observations, action) to two outputs.

     Args:
       mode: Ignored; the same layers are built regardless of its value.
     """
     del mode
     stages = [
         layers.Parallel(
             layers.Flatten(),  # Observation stack.
             layers.Embedding(d_feature=1, vocab_size=n_actions),  # Action.
         ),
         layers.Concatenate(),
         layers.Dense(n_units=1),
         layers.Dup(),
     ]
     # The duplicated activation feeds an observation head; the second copy
     # is passed on as-is.
     stages.append(
         layers.Parallel(
             layers.Dense(n_units=obs_shape[1]),  # New observation.
             None,  # Reward.
         ))
     return layers.Serial(*stages)
Esempio n. 16
0
    def test_run_reversible_weights_trainsfer_xprof(self):
        """Runs the reversible trainer and profiles weight transfer stats.

        Builds an embedding layer followed by 6 reversible residual Dense
        layers whose weights are moved with `tl.on_cpu`, takes one warm-up
        step and then times a second step. The test is reported as skipped
        (rather than silently passing) unless `run_this_test` is enabled and
        more than one device is available.
        """
        run_this_test = False  # We only run this test manually.
        if not run_this_test or fastmath.global_device_count() == 1:  # TPU only
            # A bare `return` would count as a pass although nothing ran.
            self.skipTest('Manual profiling test; runs only on multi-device hosts.')

        # Create inputs and rngs.
        inputs_batch = np.ones((1024, 128), dtype=np.int32)
        targets_batch = inputs_batch
        labeled_batch = (inputs_batch, targets_batch,
                         np.ones_like(targets_batch))
        first_layer = tl.Serial(tl.Embedding(4, 1024), tl.Dup())
        rng_init = fastmath.random.get_prng(12)
        rng_step = fastmath.random.get_prng(13)

        # Initialize layers.
        first_layer.init(labeled_batch, rng=rng_init)
        n_layers = 6
        rev_layers = []
        int_shape = shapes.ShapeDtype((1024, 128), dtype=np.int32)
        shape = shapes.ShapeDtype((1024, 128, 1024))
        sig = (shape, shape)
        for _ in range(n_layers):
            layer = tl.ReversibleHalfResidual(tl.Dense(1024))
            layer.init(sig, rng=rng_init)
            layer.weights = tl.on_cpu(
                layer.weights)  # store weights in cpu memory
            rev_layers.append(layer)
            rev_layers.append(tl.ReversibleSwap())
        loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9), tl.LogSoftmax(),
                               tl.CrossEntropyLoss())
        loss_layer.init((shape, shape, int_shape, int_shape))
        optimizer_fn = optimizers.SGD

        # Make a step with reversible trainer.
        trainer = optimizers.ReversibleSerialTrainer(
            [(first_layer, rev_layers)], loss_layer, optimizer_fn)
        loss, _ = trainer.one_step(labeled_batch, rng_step)
        self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.
        # We profile here.
        t = time.time()
        loss, _ = trainer.one_step(labeled_batch, rng_step)
        self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.
        print('Took %.3f seconds to run, loss %s' % (time.time() - t, loss))
Esempio n. 17
0
    def test_train_memory_efficient(self):
        """Trains a large network in a memory-efficient way.

        Builds an embedding followed by 16 reversible residual Dense layers
        and runs a `training.Loop` with `use_memory_efficient_trainer=True`
        for two steps, checking the loop's step counter before and after.
        On a single-device host the test is reported as skipped (rather than
        silently passing) unless `ram_limited` is set to False.
        """
        # This test requires > 16GB RAM, only run on TPUs. It does pass on GPU
        # and CPU when you run it locally, but it's too big for unit-testing.
        ram_limited = True  # Set to False to run this test locally.
        if fastmath.device_count() == 1 and ram_limited:
            # A bare `return` would count as a pass although nothing ran.
            self.skipTest('Requires > 16GB RAM; skipped on single-device hosts.')

        # Create the model.
        n_layers = 16  # 16 layers each 16K x 16K = 256M weights ~= 1GB, 16GB ram
        model = tl.Serial(
            tl.Embedding(9, 16 * 1024),
            tl.Dup(),
            [[
                tl.ReversibleHalfResidual(tl.Dense(16 * 1024)),
                tl.ReversibleSwap()
            ] for _ in range(n_layers)],
            tl.Concatenate(),
            tl.Dense(9),
        )

        # Create inputs.
        inputs_batch = np.arange(8).reshape((2, 4))
        targets_batch = inputs_batch
        labeled_batch = (inputs_batch, targets_batch,
                         np.ones_like(targets_batch))

        def _data_gen():
            # Endless stream that repeats the same labeled batch.
            while True:
                yield labeled_batch

        # Run training.
        loss_layer = tl.WeightedCategoryCrossEntropy()
        task = training.TrainTask(_data_gen(), loss_layer,
                                  optimizers.Adafactor)
        eval_task = training.EvalTask(_data_gen(),
                                      [tl.WeightedCategoryCrossEntropy()])
        loop = training.Loop(model, [task],
                             eval_tasks=[eval_task],
                             eval_at=lambda step_n: step_n == 2,
                             use_memory_efficient_trainer=True)
        self.assertEqual(0, loop.step)
        loop.run(n_steps=2)
        self.assertEqual(2, loop.step)
Esempio n. 18
0
  def test_run_reversible_large_weights(self):
    """Runs the reversible trainer with a lot of weights to test memory use.

    Builds an embedding layer followed by 20 reversible residual Dense layers
    whose weights are moved with `tl.on_cpu`, then takes one step with
    `ReversibleSerialTrainer`. No result is asserted; the step merely has to
    complete. On a single-device host the test is reported as skipped (rather
    than silently passing) unless `ram_limited` is set to False.
    """
    # This test requires > 20GB RAM, only run on TPUs. It does pass on GPU
    # and CPU when you run it locally, but it's too big for unit-testing.
    ram_limited = True  # Set to False to run this test locally.
    if fastmath.device_count() == 1 and ram_limited:
      # A bare `return` would count as a pass although nothing ran.
      self.skipTest('Requires > 20GB RAM; skipped on single-device hosts.')

    # Create inputs and rngs.
    inputs_batch = np.arange(8).reshape((2, 4))
    targets_batch = inputs_batch
    labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
    first_layer = tl.Serial(tl.Embedding(9, 16*1024), tl.Dup())
    rng_init = fastmath.random.get_prng(12)
    rng_step = fastmath.random.get_prng(13)

    # Initialize layers.
    first_layer.init(labeled_batch, rng=rng_init)
    n_layers = 20  # 20 layers each 16K x 16K = 256M weights ~= 1GB, 20GB ram
    rev_layers = []
    int_shape = shapes.ShapeDtype((2, 4), dtype=np.int32)
    shape = shapes.ShapeDtype((2, 4, 16*1024))
    sig = (shape, shape)
    for _ in range(n_layers):
      layer = tl.ReversibleHalfResidual(tl.Dense(16*1024))
      layer.init(sig, rng=rng_init)
      layer.weights = tl.on_cpu(layer.weights)  # store weights in cpu memory
      rev_layers.append(layer)
      rev_layers.append(tl.ReversibleSwap())
    loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9),
                           tl.LogSoftmax(), tl.CrossEntropyLoss())
    loss_layer.init((shape, shape, int_shape, int_shape))
    optimizer_fn = optimizers.Adafactor

    # Make a step with reversible trainer.
    trainer = optimizers.ReversibleSerialTrainer(
        first_layer, rev_layers, loss_layer, optimizer_fn)
    trainer.one_step(labeled_batch, rng_step)
Esempio n. 19
0
def ApplyAndQueryPositions(layer, pos):
    """Runs `layer` without positions and the `pos` layers on positions.

    This takes an embedding including position x = (emb, p), and
    outputs layer(emb).pos1(x, p).....layer(emb).posn(x, p)
    where pos=[pos1...posn].

    Args:
      layer: Layer to be executed without position information.
      pos: List of layers to be applied to positions.

    Returns:
      A `tl.Serial` combinator that performs this application.
    """
    n_heads = len(pos)
    stages = [
        tl.Dup(),  # (x, x)
        CutAtPosition(),  # (x_content, x_position, x)
        tl.Parallel([], tl.Swap()),  # (x_content, x, x_position)
    ]
    # After these the stack is x_content, (x, x_position) * n_heads.
    stages.append([tl.Parallel([], Dup2()) for _ in range(n_heads - 1)])
    stages.append(tl.Parallel(*([layer] + pos)))
    stages.append(tl.Concatenate(n_items=n_heads + 1))
    return tl.Serial(*stages)
Esempio n. 20
0
def FunnelTransformerLM(vocab_size,
                        d_model=512,
                        d_ff=2048,
                        vanilla_layers=(0, 1),
                        shorten_factors=(3,),
                        n_funnel_blocks=(6,),
                        n_heads=8,
                        dropout=0.1,
                        dropout_shared_axes=None,
                        mode='train',
                        ff_activation=tl.FastGelu):
  """Returns a Funnel-style Transformer language model.

  This model performs autoregressive language modeling:

    - input: rank 2 tensor representing a batch of text strings via token IDs
      plus padding markers; shape is (batch_size, sequence_length). The tensor
      elements are integers in `range(vocab_size)`, and `0` values mark padding
      positions.

    - output: rank 3 tensor representing a batch of log-probability
      distributions for each sequence position over possible token IDs;
      shape is (batch_size, sequence_length, `vocab_size`).

  This model uses only the decoder part of the overall Transformer.

  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each encoder
        block.
    vanilla_layers: (pre_layers, post_layers) tuple - number of full
        token-level Transformer decoder layers before and after shortening.
    shorten_factors: Tuple of arbitrary length giving the shortening factor of
        each pooling stage.
    n_funnel_blocks: Number of Transformer decoder blocks after each stage of
        pooling - tuple of the same length as `shorten_factors`.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within an encoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
        a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    mode: str: 'train' or 'eval'.
    ff_activation: Type of activation function at the end of each encoder
        block; must be an activation-type subclass of `Layer`.

  Returns:
    A Transformer language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.
  """
  assert mode != 'predict'  # For now, 'predict' mode is unsupported.
  assert len(n_funnel_blocks) == len(shorten_factors)

  token_encoder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
  ]

  # Both bias layers are passed to every relative-attention block below.
  context_bias_layer, location_bias_layer = _get_rel_att_inputs(
      d_model, n_heads)

  n_pre_layers, n_post_layers = vanilla_layers

  def decoder_stack(n_layers, total_pooling):
    # `n_layers` relative decoder blocks followed by a layer norm.
    stack = []
    for _ in range(n_layers):
      stack.append(
          _RelativeDecoderBlock(d_model, d_ff, n_heads, dropout,
                                dropout_shared_axes, mode, ff_activation,
                                context_bias_layer, location_bias_layer,
                                total_pooling))
    stack.append(tl.LayerNorm())
    return stack

  def resampling_block(total_pooling, shorten_factor, resampler_fn):
    # One funnel block that resamples the sequence with `resampler_fn`.
    return _FunnelRelativeDecoderBlock(
        d_model, d_ff, n_heads, dropout,
        dropout_shared_axes, mode,
        ff_activation,
        context_bias_layer=context_bias_layer,
        location_bias_layer=location_bias_layer,
        total_pooling=total_pooling,
        shorten_factor=shorten_factor,
        resampler_fn=resampler_fn)

  pooling_so_far = 1
  pre_decoder_blocks = decoder_stack(n_pre_layers, total_pooling=1)

  funnel_blocks = []
  for factor, n_blocks in zip(shorten_factors, n_funnel_blocks):
    # The downsampling block is built with the pooling accumulated so far;
    # the decoder blocks after it use the updated product.
    funnel_blocks.append(
        resampling_block(pooling_so_far, factor, _DownsamplerLM))
    pooling_so_far *= factor
    funnel_blocks.extend(decoder_stack(n_blocks, pooling_so_far))

  # A single upsampling block undoes the total accumulated shortening.
  upsampling_layer = resampling_block(pooling_so_far, pooling_so_far,
                                      _UpsamplerLM)

  conv_layer = tl.Serial(
      tl.CausalConv(d_model, pooling_so_far),
      ff_activation(),
  )

  post_decoder_blocks = decoder_stack(n_post_layers, total_pooling=1)

  # Assemble and return the model.
  return tl.Serial(              # tokens (or chunked tuple of tokens)
      tl.ShiftRight(mode=mode),  # toks
      token_encoder,             # vecs
      pre_decoder_blocks,        # vecs
      tl.Dup(),
      tl.ShiftRight(n_positions=pooling_so_far - 1),
      funnel_blocks,
      tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
      upsampling_layer,
      tl.LayerNorm(),
      tl.Concatenate(),
      conv_layer,
      post_decoder_blocks,
      tl.Dense(vocab_size),      # vecs
  )
Esempio n. 21
0
 def test_n_in_n_out(self):
     """Checks that a default `Concatenate` has two inputs and one output."""
     concat = tl.Concatenate()
     for attr_name, expected in (('n_in', 2), ('n_out', 1)):
         self.assertEqual(getattr(concat, attr_name), expected)
Esempio n. 22
0
def TransformerLM(vocab_size,
                  d_model=512,
                  d_ff=2048,
                  n_layers=6,
                  n_heads=8,
                  d_attention_key=None,
                  d_attention_value=None,
                  attention_type=tl.DotProductCausalAttention,
                  dropout=0.1,
                  share_qk=False,
                  max_len=2048,
                  n_chunks=0,
                  mode='train'):
    """Returns a Transformer language model.

    The input to the model is a tensor of tokens. (This model uses only the
    decoder part of the overall Transformer.)

    Args:
      vocab_size: int: vocab size
      d_model: int: depth of embedding
      d_ff: int: depth of feed-forward layer
      n_layers: int: number of encoder/decoder layers
      n_heads: int: number of attention heads
      d_attention_key: int: depth of key vector for each attention head
          (default is d_model // n_heads)
      d_attention_value: int: depth of value vector for each attention head
          (default is d_model // n_heads)
      attention_type: subclass of tl.BaseCausalAttention: attention class to
          use
      dropout: float: dropout rate (how much to drop out)
      share_qk: bool, whether to share queries and keys in decoder attention
      max_len: int: maximum symbol length for positional encoding
      n_chunks: int: number of chunks (must match input pipeline)
      mode: str: 'train', 'eval' or 'predict', predict mode is for fast
          inference

    Returns:
      A Transformer language model as a layer that maps from a tensor of
      tokens to activations over a vocab set.
    """
    if n_chunks != 0:
        concatenate_chunks = tl.Concatenate(n_items=n_chunks)
        split_chunks = tl.Split(n_sections=n_chunks, axis=-2)
    else:
        # Unchunked input: both ends of the pipeline get the same empty list.
        concatenate_chunks = split_chunks = []

    embedder = [
        tl.Embedding(d_model, vocab_size),
        tl.Dropout(rate=dropout, name='embedding', mode=mode),
        tl.PositionalEncoding(max_len=max_len, mode=mode),
    ]

    stages = [
        concatenate_chunks,  # tokens
        tl.ShiftRight(mode=mode),  # toks
        embedder,  # vecs
    ]

    # Each decoder block receives its own layer index.
    decoder_blocks = []
    for layer_index in range(n_layers):
        decoder_blocks.append(
            DecoderBlock(d_model, d_ff, n_heads, d_attention_key,
                         d_attention_value, attention_type, dropout, share_qk,
                         layer_index, mode))
    stages.append(decoder_blocks)  # vecs

    stages.extend([
        tl.LayerNorm(),  # vecs
        tl.Dense(vocab_size),  # vecs
        tl.LogSoftmax(),  # vecs
        split_chunks,  # vecs (or chunked tuple of vecs)
    ])
    return tl.Serial(*stages)  # tokens (or chunked tuple of tokens)
Esempio n. 23
0
def PreservePosition(layer):
    """Runs `layer` after `CutAtPosition` and concatenates two items back.

    Args:
      layer: Layer to execute without position information.

    Returns:
      A `tl.Serial` of `CutAtPosition`, `layer`, and a two-item `Concatenate`.
    """
    stages = (CutAtPosition(), layer, tl.Concatenate(n_items=2))
    return tl.Serial(*stages)
Esempio n. 24
0
def ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               attention_type=tl.SelfAttention,
               axial_pos_shape=(),
               d_axial_pos_embs=None,
               ff_activation=tl.FastGelu,
               ff_use_sru=0,
               ff_chunk_size=0,
               ff_sparsity=0,
               loss_sparsity_type='mult',
               loss_sparsity=0,
               loss_d_lowrank=0,
               loss_sparsity_prob=None,
               attention_chunk_size=0,
               mode='train'):
    """Builds a decoder-only reversible Transformer language model.

    The assembled stack is: right shift, token embedding + dropout +
    positional encoding, duplication into two streams, a reversible stack of
    `n_layers` decoder blocks, concatenation of the two streams, layer norm,
    dropout, and the output layer built by `tl.SparseDenseWithOptions`.

    Args:
      vocab_size: int, vocabulary size.
      d_model: int, depth of *each half* of the two-part features.
      d_ff: int, depth of the feed-forward layer.
      d_attention_key: int, depth of the key vector for each attention head.
      d_attention_value: int, depth of the value vector for each attention
        head.
      n_layers: int, number of decoder layers.
      n_heads: int, number of attention heads.
      dropout: float, dropout rate; used for the embedding, every decoder
        block (attention and feed-forward) and the final dropout.
      max_len: int, maximum symbol length for positional encoding.
      attention_type: an attention class such as SelfAttention, or a
        tuple/list of such classes that is cycled through layer by layer; in
        the latter case its length must divide `n_layers`.
      axial_pos_shape: tuple of ints, input shape to use for the axial
        position encoding. If unset, axial position encoding is disabled.
      d_axial_pos_embs: tuple of ints, depth of position embedding for each
        axis. Tuple length must match axial_pos_shape, and values must sum to
        d_model.
      ff_activation: the non-linearity in the feed-forward layer.
      ff_use_sru: int; if > 0, use this many SRU layers instead of
        feed-forward.
      ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks.
      ff_sparsity: int; if > 0, use a sparse feed-forward block with this
        sparsity.
      loss_sparsity_type: str, type of sparsity to use in the loss layer. See
        SparseDenseWithOptions for options. None if no sparsity should be
        used.
      loss_sparsity: int, the sparsity for the loss layer (if used).
      loss_d_lowrank: int, the dimensions for the intermediate layer (if
        used).
      loss_sparsity_prob: float, the probability for the sparse version of
        the loss to be used. If None, only the sparse version is used.
      attention_chunk_size: int; if > 0, run attention chunked at this size.
      mode: str, 'train', 'eval', or 'predict'.

    Returns:
      The model as a `tl.Serial` layer.
    """
    positional_encoding = ct.PositionalEncoder(
        mode, dropout, max_len, axial_pos_shape, d_axial_pos_embs)

    token_embedding = tl.Embedding(vocab_size, d_model)
    embedding_dropout = tl.Dropout(  # pylint: disable=no-value-for-parameter
        rate=dropout, shared_axes=[-2], mode=mode)
    positional_embedder = [
        token_embedding,
        embedding_dropout,
        positional_encoding,
    ]

    # A single attention class behaves like a one-element cycle; a sequence of
    # classes is cycled through and has to divide the layer count evenly.
    if not isinstance(attention_type, (tuple, list)):
        attention_type = [attention_type]
    else:
        assert n_layers % len(attention_type) == 0
    n_attention_types = len(attention_type)

    # Everything except the attention class is identical across blocks.
    shared_block_kwargs = dict(
        dropout=dropout,
        ff_activation=ff_activation,
        ff_dropout=dropout,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        ff_sparsity=ff_sparsity,
        attention_chunk_size=attention_chunk_size,
        mode=mode,
    )
    decoder_blocks = [
        # pylint: disable=g-complex-comprehension
        DecoderBlock(d_model,
                     d_ff,
                     d_attention_key,
                     d_attention_value,
                     n_heads,
                     attention_type=attention_type[idx % n_attention_types],
                     **shared_block_kwargs)
        for idx in range(n_layers)
    ]

    dense_loss_layer = tl.SparseDenseWithOptions(
        vocab_size,
        d_input=d_model,
        sparsity_type=loss_sparsity_type,
        sparsity=loss_sparsity,
        d_lowrank=loss_d_lowrank,
        prob_sparse=loss_sparsity_prob,
        mode=mode)

    model_layers = [
        tl.ShiftRight(mode=mode),
        positional_embedder,
        tl.Dup(),
        tl.ReversibleSerial(decoder_blocks),
        tl.Concatenate(),
        # TODO(kitaev): Test whether dropout should go before or after the
        # LayerNorm, and whether dropout broadcasting is needed here.
        tl.LayerNorm(),
        tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),  # pylint: disable=no-value-for-parameter
        dense_loss_layer,
    ]
    return tl.Serial(*model_layers)
Esempio n. 25
0
def ReformerShortenLM(vocab_size,
                      shorten_factor=1,
                      d_embedding=256,
                      d_model=512,
                      d_ff=2048,
                      d_attention_key=64,
                      d_attention_value=64,
                      n_layers=6,
                      n_heads=8,
                      dropout=0.1,
                      max_len=2048,
                      n_attention_chunks=1,
                      attention_type=tl.DotProductCausalAttention,
                      share_qk=False,
                      axial_pos_shape=(),
                      d_axial_pos_embs=None,
                      ff_activation=tl.FastGelu,
                      ff_use_sru=0,
                      ff_chunk_size=0,
                      mode='train'):
  """Reversible Transformer language model that shortens its input.

  With shorten factor F and an input of shape [batch, length], the
  (shifted-right) input is embedded and every F consecutive positions are
  folded into one vector, so the reversible decoder stack works on a tensor of
  shape
    [batch, length // F, d_model]
  Afterwards the sequence is expanded back to full length, concatenated with
  the original embeddings, and passed through a causal convolution, a Relu, an
  SRU and the final Dense + LogSoftmax. Running the main body on the shorter
  sequence makes the model faster, possibly at a small cost in accuracy.

  Args:
    vocab_size: int, vocab size.
    shorten_factor: int, by how much to shorten the sequence (F above).
    d_embedding: int, depth of the embedding layer and of the final logits.
    d_model: int, depth of *each half* of the two-part features.
    d_ff: int, depth of the feed-forward layer.
    d_attention_key: int, depth of the key vector for each attention head.
    d_attention_value: int, depth of the value vector for each attention head.
    n_layers: int, number of decoder layers.
    n_heads: int, number of attention heads.
    dropout: float, dropout rate (how much to drop out).
    max_len: int, maximum symbol length for positional encoding.
    n_attention_chunks: int, number of chunks for attention.
    attention_type: an attention class such as DotProductAttention, or a
      tuple/list of such classes cycled through layer by layer; its length must
      then divide `n_layers`.
    share_qk: bool, whether to share queries and keys. Sharing is always turned
      on for layers whose attention class subclasses `tl.LSHCausalAttention`.
    axial_pos_shape: tuple of ints, input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints, depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, values must sum to d_embedding.
    ff_activation: the non-linearity in the feed-forward layer.
    ff_use_sru: int; if > 0, use this many SRU layers instead of feed-forward.
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks.
    mode: str, 'train' or 'eval' ('predict' is rejected by an assert).

  Returns:
    The model as a `tl.Serial` layer.
  """
  assert mode != 'predict'  # TODO(lukaszkaiser,kitaev): fast inference

  if axial_pos_shape:
    assert d_axial_pos_embs is not None
    broadcast_dims = tuple(range(1, len(axial_pos_shape) + 1))
    positional_encoding = tl.AxialPositionalEncoding(
        shape=axial_pos_shape,
        d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=broadcast_dims,
        dropout=dropout,
        mode=mode)
  else:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)

  positional_embedder = [
      tl.Embedding(d_embedding, vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      positional_encoding,
  ]

  # A single attention class behaves like a one-element cycle; a sequence of
  # classes is cycled through and has to divide the layer count evenly.
  if not isinstance(attention_type, (tuple, list)):
    attention_type = [attention_type]
  else:
    assert n_layers % len(attention_type) == 0
  n_attention_types = len(attention_type)

  def _make_block(layer_attention_type):
    """Builds one decoder block that uses the given attention class."""
    return DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        n_attention_chunks,
        attention_type=layer_attention_type,
        dropout=dropout,
        # share_qk is forced on for subclasses of tl.LSHCausalAttention.
        share_qk=(share_qk or issubclass(layer_attention_type,
                                         tl.LSHCausalAttention)),
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        mode=mode)

  decoder_blocks = [
      _make_block(attention_type[layer_idx % n_attention_types])
      for layer_idx in range(n_layers)
  ]

  # pylint: disable=g-long-lambda
  pipeline = [
      tl.ShiftRight(),
      positional_embedder,
      tl.Dup(),              # Stack has (x, x), the first will be shortened
      # Pad by the shorten factor before grouping, otherwise information leaks
      # from the future. Example: shorten factor 2, sequence ABCD. A single
      # shift gives 0ABC, grouped as [0A][BC], while the targets are ABCD --
      # so [0A] already sees A and [BC] already sees C, and the model can
      # simply copy them. With the extra shift the groups are [00][AB]: the
      # first "big" symbol is all-0 and the rest is shifted far enough.
      tl.ShiftRight(n_shifts=shorten_factor - 1),
      tl.Fn(lambda x: np.reshape(  # Shorten -- move to depth.
          x, (x.shape[0], x.shape[1] // shorten_factor, -1)), n_out=1),
      tl.Dense(d_model),
      tl.Dup(),  # Stack has (short_x, short_x, x)
      tl.ReversibleSerial(decoder_blocks),
      tl.Select([0], n_in=2),
      tl.LayerNorm(),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      tl.Dense(shorten_factor * d_embedding),
      tl.Fn(lambda x: np.reshape(  # Prolong back.
          x, (x.shape[0], x.shape[1] * shorten_factor, -1)), n_out=1),
      tl.Concatenate(),  # Concatenate with just the embeddings.
      tl.CausalConv(d_embedding),
      tl.Relu(),
      tl.SRU(d_embedding),  # One RNN layer for conditional dependence.
      tl.Dense(vocab_size),
      tl.LogSoftmax(),
  ]
  return tl.Serial(*pipeline)
Esempio n. 26
0
def ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               n_chunks=0,
               n_attention_chunks=1,
               attention_type=tl.DotProductCausalAttention,
               share_qk=False,
               axial_pos_shape=(),
               d_axial_pos_embs=None,
               ff_activation=tl.FastGelu,
               ff_use_sru=0,
               ff_chunk_size=0,
               mode='train'):
  """Builds a decoder-only reversible Transformer language model.

  The assembled stack is: optional concatenation of input chunks, right shift,
  token embedding + dropout + positional encoding, duplication into two
  streams, a reversible stack of decoder blocks ending in `SplitForOutput`,
  and a per-section head of LayerNorm, dropout, Dense and LogSoftmax.

  Args:
    vocab_size: int, vocab size.
    d_model: int, depth of *each half* of the two-part features.
    d_ff: int, depth of the feed-forward layer.
    d_attention_key: int, depth of the key vector for each attention head.
    d_attention_value: int, depth of the value vector for each attention head.
    n_layers: int, number of decoder layers.
    n_heads: int, number of attention heads.
    dropout: float, dropout rate (how much to drop out).
    max_len: int, maximum symbol length for positional encoding.
    n_chunks: int, number of chunks (must match input pipeline). With 0 the
      input is not concatenated and a single output section is used.
    n_attention_chunks: int, number of chunks for attention.
    attention_type: an attention class such as DotProductAttention, or a
      tuple/list of such classes cycled through layer by layer; its length must
      then divide `n_layers`.
    share_qk: bool, whether to share queries and keys. Sharing is always turned
      on for layers whose attention class subclasses `tl.LSHCausalAttention`.
    axial_pos_shape: tuple of ints, input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled. The strings
      'fixed-base', 'infinite', 'infinite-affine' and 'time-bin' select other
      positional encodings instead (marked as a HACK to be removed);
      'fixed-base' also halves the embedding depth.
    d_axial_pos_embs: tuple of ints, depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_activation: the non-linearity in the feed-forward layer.
    ff_use_sru: int; if > 0, use this many SRU layers instead of feed-forward.
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks.
    mode: str, 'train', 'eval', or 'predict'.

  Returns:
    The model as a `tl.Serial` layer.
  """
  if n_chunks != 0:
    concatenate_input_chunks = tl.Concatenate(n_items=n_chunks)
    n_sections = n_chunks
  else:
    # No chunking requested: nothing to concatenate, one output section.
    concatenate_input_chunks = []
    n_sections = 1

  embedding_depth = d_model
  if not axial_pos_shape:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
  elif axial_pos_shape == 'fixed-base':  # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.FixedBasePositionalEncoding(mode=mode)
    embedding_depth = d_model // 2
  elif axial_pos_shape == 'infinite':  # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.InfinitePositionalEncoding(affine=False)
  elif axial_pos_shape == 'infinite-affine':
    # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.InfinitePositionalEncoding()
  elif axial_pos_shape == 'time-bin':  # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.TimeBinPositionalEncoding()
  else:
    assert d_axial_pos_embs is not None
    broadcast_dims = tuple(range(1, len(axial_pos_shape) + 1))
    positional_encoding = tl.AxialPositionalEncoding(
        shape=axial_pos_shape,
        d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=broadcast_dims,
        dropout=dropout,
        mode=mode)

  positional_embedder = [
      tl.Embedding(embedding_depth, vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      positional_encoding,
  ]

  # A single attention class behaves like a one-element cycle; a sequence of
  # classes is cycled through and has to divide the layer count evenly.
  if not isinstance(attention_type, (tuple, list)):
    attention_type = [attention_type]
  else:
    assert n_layers % len(attention_type) == 0
  n_attention_types = len(attention_type)

  def _make_block(layer_attention_type):
    """Builds one decoder block that uses the given attention class."""
    return DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        n_attention_chunks,
        attention_type=layer_attention_type,
        dropout=dropout,
        # share_qk is forced on for subclasses of tl.LSHCausalAttention.
        share_qk=(share_qk or issubclass(layer_attention_type,
                                         tl.LSHCausalAttention)),
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        mode=mode)

  decoder_blocks = [
      _make_block(attention_type[layer_idx % n_attention_types])
      for layer_idx in range(n_layers)
  ]

  return tl.Serial(
      concatenate_input_chunks,
      tl.ShiftRight(mode=mode),
      positional_embedder,
      tl.Dup(),
      tl.ReversibleSerial(decoder_blocks + [
          SplitForOutput(n_sections=n_sections, axis=-2),  # pylint: disable=no-value-for-parameter
      ]),
      Map([
          # TODO(kitaev): Test whether dropout should go before or after the
          # LayerNorm, and whether dropout broadcasting is needed here.
          tl.LayerNorm(),
          BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
          tl.Dense(vocab_size),
          tl.LogSoftmax(),
      ], n_sections=n_sections),
  )
Esempio n. 27
0
def RelformerLM(vocab_size,
                d_model=512,
                d_ff=2048,
                vanilla_layers=(1, 1),
                shorten_factor=3,
                n_rel_layers=6,
                n_heads=8,
                dropout=0.1,
                dropout_shared_axes=None,
                vanilla_attn_type=tl.LSHSelfAttention,
                pos_type='fixed-base',
                max_len=3072,
                n_raw_tokens_generated=1,
                mode='train',
                ff_activation=tl.FastGelu):
    """Returns a decoder-only language model with a shortened middle section.

    The model is assembled from: token embedding and positional encoding,
    token-level reversible decoder blocks, a downsampling step by
    `shorten_factor`, `n_rel_layers` relative-attention decoder blocks, an
    upsampling step, a causal convolution, a second group of token-level
    reversible decoder blocks and a final `tl.Dense(vocab_size)`.

      - input: rank 2 tensor representing a batch of text strings via token
        IDs plus padding markers; shape is (batch_size, sequence_length). The
        tensor elements are integers in `range(vocab_size)`, and `0` values
        mark padding positions.

      - output: activations over the vocabulary, produced by the final
        `tl.Dense(vocab_size)` layer (no softmax is applied in this function).

    Args:
      vocab_size: Input vocabulary size -- each element of the input tensor
          should be an integer in `range(vocab_size)`. These integers
          typically represent token IDs from a vocabulary-based tokenizer.
      d_model: Final dimension of tensors at most points in the model,
          including the initial embedding output.
      d_ff: Size of special dense layer in the feed-forward part of each
          block.
      vanilla_layers: (pre_layers, post_layers) tuple - number of full
          token-level Transformer decoder layers before and after shortening.
          A count of 0 yields just a `tl.LayerNorm()` for that group.
      shorten_factor: by how much to shorten.
      n_rel_layers: number of Transformer blocks after the pooling. These
          blocks use relative attention.
      n_heads: Number of attention heads.
      dropout: Stochastic rate (probability) for dropping an activation value
          when applying dropout.
      dropout_shared_axes: Tensor axes on which to share a dropout mask.
          Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`)
          is a useful way to save memory and apply consistent masks to
          activation vectors at different sequence positions.
      vanilla_attn_type: class: attention class such as SelfAttention to use
          in the layers before and after shortening (vanilla layers).
      pos_type: string, the type of positional embeddings to use.
      max_len: int: maximum symbol length both for positional encoding and it
          is also the maximum length of the possible inference in 'predict'
          mode.
      n_raw_tokens_generated: int: number of tokens generated with every pass
          through model in 'predict' mode. Number of tokens should be smaller
          and divisible by the first shorten factor we are using in the model.
          It cannot be larger than one if we use vanilla layers because we
          would lose autoregressive property of the model.
      mode: str: 'train' or 'eval' or 'predict'.
      ff_activation: Type of activation function at the end of each block;
          must be an activation-type subclass of `Layer`. It is also
          instantiated once to follow the causal convolution.

    Returns:
      A Transformer language model as a layer that maps from a tensor of
      tokens to activations over a vocab set.
    """
    token_encoder = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
    ]

    positional_encoder = PositionalEncoder(mode, dropout, max_len, pos_type)

    n_pre_decoder_blocks, n_post_decoder_blocks = vanilla_layers

    def create_decoder_blocks(n_layers, total_pooling):  # pylint: disable=invalid-name
        """Returns `n_layers` relative-attention blocks plus a LayerNorm."""
        # The same pair of bias layers is handed to every block in the stack.
        context_bias, location_bias = _get_rel_att_inputs(d_model, n_heads)
        blocks = [
            # pylint: disable=g-complex-comprehension
            _RelativeDecoderBlock(d_model, d_ff, n_heads, dropout,
                                  dropout_shared_axes, mode, ff_activation,
                                  context_bias, location_bias,
                                  total_pooling, max_len)
            for _ in range(n_layers)
        ]
        blocks.append(tl.LayerNorm())
        return blocks

    def create_reformer_blocks(n_layers, dense=True):  # pylint: disable=invalid-name
        """Returns a reversible token-level stack, or a lone LayerNorm."""
        if n_layers == 0:
            return [tl.LayerNorm()]

        head_depth = d_model // n_heads
        blocks = [
            # pylint: disable=g-complex-comprehension
            DecoderBlock(d_model, d_ff, head_depth, head_depth, n_heads,
                         vanilla_attn_type, dropout, ff_activation, dropout,
                         ff_use_sru=0,
                         ff_chunk_size=0,
                         ff_sparsity=0,
                         attention_chunk_size=0,
                         mode=mode)
            for _ in range(n_layers)
        ]

        stack = [
            tl.Dup(),
            tl.ReversibleSerial(blocks),
            tl.Concatenate(),
            tl.LayerNorm(),
        ]
        # The last slot holds a Dense(d_model) when requested, otherwise an
        # empty list placeholder.
        stack.append(tl.Dense(d_model) if dense else [])
        return stack

    pre_decoder_blocks = create_reformer_blocks(n_pre_decoder_blocks,
                                                dense=True)
    relative_decoder_blocks = create_decoder_blocks(n_rel_layers,
                                                    shorten_factor)
    conv_layer = tl.Serial(tl.CausalConv(d_model, shorten_factor),
                           ff_activation())
    post_decoder_blocks = create_reformer_blocks(n_post_decoder_blocks,
                                                 dense=False)

    # Both cachers share these settings; only the conv one is sliding.
    cacher_kwargs = dict(
        total_kv_pooling=shorten_factor,
        n_raw_tokens_generated=n_raw_tokens_generated,
        max_inference_length=max_len,
        shift=shorten_factor - 1,
        mode=mode,
    )
    cacher = RelformerCacher(**cacher_kwargs)
    picker = RelformerPicker(total_kv_pooling=shorten_factor,
                             n_raw_tokens_generated=n_raw_tokens_generated,
                             mode=mode)
    cacher_conv = RelformerCacher(sliding=True, **cacher_kwargs)
    picker_conv = PickLastTokenInPredict(mode=mode)

    # Assemble and return the model.
    model_layers = [  # tokens (or chunked tuple of tokens)
        tl.ShiftRight(mode=mode),  # toks
        token_encoder,  # vecs
        positional_encoder,
        pre_decoder_blocks,  # vecs
        tl.Dup(),
        cacher,
        tl.ShiftRight(n_positions=shorten_factor - 1, mode=mode),
        _DownsamplerLM(shorten_factor, d_model),
        relative_decoder_blocks,
        tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
        _UpsamplerLM(shorten_factor, d_model),
        tl.LayerNorm(),
        picker,
        tl.Concatenate(),
        cacher_conv,
        conv_layer,
        picker_conv,
        post_decoder_blocks,
        tl.Dense(vocab_size),  # vecs
    ]
    return tl.Serial(*model_layers)
Esempio n. 28
0
# Sample input covering negative, zero and positive values.
x = np.array([-2, -1, 0, 1, 2])
print("-- Inputs --")
print("x :", x, "\n")

# Outputs
# NOTE(review): `relu` is not defined in this excerpt -- presumably a layer
# created in an earlier cell; confirm against the full notebook.
y = relu(x)
print("-- Outputs --")
print("y :", y)

# %% [markdown]
# ### Concatenate Layer
# Now let's check how to build a layer that takes 2 inputs. Notice the change in the expected inputs property from 1 to 2.

# %% tags=[]
# Create a concatenate trax layer
# No arguments are passed, so the layer's own defaults apply; the properties
# printed below report its name and its input/output counts.
concat = tl.Concatenate()
print("-- Properties --")
print("name :", concat.name)
print("expected inputs :", concat.n_in)
print("promised outputs :", concat.n_out, "\n")

# Inputs
x1 = np.array([-10, -20, -30])
# Second input is derived from the first by element-wise division by -10.
x2 = x1 / -10
print("-- Inputs --")
print("x1 :", x1)
print("x2 :", x2, "\n")

# Outputs
# Both inputs are handed to the layer together as one list. Note that `y` is
# rebound here, replacing the relu result computed above.
y = concat([x1, x2])
print("-- Outputs --")
Esempio n. 29
0
def ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               n_chunks=0,
               n_attention_chunks=1,
               attention_type=tl.DotProductCausalAttention,
               share_qk=False,
               mode='train'):
  """Builds a decoder-only reversible Transformer language model.

  The assembled stack is: optional concatenation of input chunks, right shift,
  token embedding + dropout + positional encoding, duplication into two
  streams, a reversible stack of `n_layers` decoder blocks ending in
  `SplitForOutput`, a per-section head of LayerNorm, dropout, Dense and
  LogSoftmax, and (only when `n_chunks` is 0) a final single-item
  concatenation along axis -2.

  Args:
    vocab_size: int, vocab size.
    d_model: int, depth of *each half* of the two-part features.
    d_ff: int, depth of the feed-forward layer.
    d_attention_key: int, depth of the key vector for each attention head.
    d_attention_value: int, depth of the value vector for each attention head.
    n_layers: int, number of decoder layers.
    n_heads: int, number of attention heads.
    dropout: float, dropout rate (how much to drop out).
    max_len: int, maximum symbol length for positional encoding.
    n_chunks: int, number of chunks (must match input pipeline). With 0 the
      input is not concatenated and a single output section is used.
    n_attention_chunks: int, number of chunks for attention.
    attention_type: class, attention class to use, such as
      DotProductAttention.
    share_qk: bool, whether to share queries and keys.
    mode: str, 'train' or 'eval'.

  Returns:
    The model as a `tl.Model` layer.
  """
  if n_chunks != 0:
    # Chunked input: merge the chunks on the way in, leave the output chunked.
    n_sections = n_chunks
    concatenate_input_chunks = tl.Concatenate(n_items=n_sections)
    concatenate_output_chunks = []
  else:
    # Unchunked input: one section, joined by a single-item concatenation.
    n_sections = 1
    concatenate_input_chunks = []
    concatenate_output_chunks = tl.Concatenate(n_items=n_sections, axis=-2)

  positional_embedder = [
      tl.Embedding(d_model, vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      tl.PositionalEncoding(max_len=max_len),
  ]

  return tl.Model(
      concatenate_input_chunks,
      tl.ShiftRight(),
      positional_embedder,
      tl.Dup(),
      tl.ReversibleSerial([
          # pylint: disable=g-complex-comprehension
          DecoderBlock(d_model, d_ff,
                       d_attention_key, d_attention_value, n_heads,
                       n_attention_chunks, attention_type,
                       dropout, share_qk, mode)
          for _ in range(n_layers)
      ] + [
          SplitForOutput(n_sections=n_sections, axis=-2),  # pylint: disable=no-value-for-parameter
      ]),
      Map([
          # TODO(kitaev): Test whether dropout should go before or after the
          # LayerNorm, and whether dropout broadcasting is needed here.
          tl.LayerNorm(),
          BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
          tl.Dense(vocab_size),
          tl.LogSoftmax(),
      ], n_sections=n_sections),
      concatenate_output_chunks,
  )