Example #1
    def test_run_reversible_same_as_default_extended(self):
        """Runs the reversible trainer, check results are the same as default."""
        inputs_batch = np.arange(8).reshape((2, 4))
        targets_batch = 2 * inputs_batch
        labeled_batch = (inputs_batch, targets_batch,
                         np.ones_like(targets_batch))
        # We want to test rng propagation too, so adding some dropout layers.
        first_layer = tl.Serial(tl.Embedding(9, 4), tl.Dropout(0.5), tl.Dup())
        rev_layers1 = [
            tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.2)),
            tl.ReversibleSwap(),
            tl.ReversibleHalfResidual(tl.Dropout(0.5), tl.Dense(4)),
            tl.ReversibleSwap()
        ]
        mid_layer = tl.Serial(tl.Add(), tl.Dense(4), tl.Dup())
        rev_layers2 = [
            tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.3)),
            tl.ReversibleSwap()
        ]
        loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(19), tl.Dropout(0.3),
                               tl.LogSoftmax(), tl.CrossEntropyLoss())
        model = tl.Serial([first_layer] + rev_layers1 + [mid_layer] +
                          rev_layers2 + [loss_layer])
        rng_init = fastmath.random.get_prng(12)
        model.init(labeled_batch, rng=rng_init)
        optimizer_fn = optimizers.Adam  # to test slots

        # Make 3 steps with the original trainer.
        optimizer = optimizer_fn()
        optimizer.tree_init(model.weights)
        trainer = optimizers.Trainer(model, optimizer)
        rng_step1 = fastmath.random.get_prng(7)
        rng_step2 = fastmath.random.get_prng(8)
        rng_step3 = fastmath.random.get_prng(9)
        trainer.one_step(labeled_batch, rng_step1)
        trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02)
        trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03)
        first_layer_weights1 = first_layer.weights
        rev_layer12_weights1 = rev_layers1[2].weights
        mid_layer_weights1 = mid_layer.weights
        rev_layer20_weights1 = rev_layers2[0].weights
        loss_layer_weights1 = loss_layer.weights

        # Now make 3 steps with reversible trainer.
        model.init(labeled_batch, rng=rng_init)
        trainer = optimizers.ReversibleSerialTrainer(
            [(first_layer.sublayers, rev_layers1),
             (mid_layer.sublayers, rev_layers2)], loss_layer, optimizer_fn)
        trainer.one_step(labeled_batch, rng_step1)
        trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02)
        trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03)

        # Check that weights end up the same.
        self._assert_all_equal(loss_layer_weights1, loss_layer.weights)
        self._assert_all_equal(rev_layer20_weights1, rev_layers2[0].weights)
        self._assert_all_equal(mid_layer_weights1, mid_layer.weights)
        self._assert_all_equal(rev_layer12_weights1, rev_layers1[2].weights)
        self._assert_all_equal(first_layer_weights1, first_layer.weights)
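The test above relies on an _assert_all_equal helper that the excerpt does not show. A minimal sketch of such a helper, assuming weights are nested tuples/lists of arrays to be compared elementwise (the tolerance is an arbitrary choice):

    def _assert_all_equal(self, t1, t2):
        """Recursively checks that nested structures of arrays match."""
        if isinstance(t1, (list, tuple)):
            self.assertEqual(len(t1), len(t2))
            for x1, x2 in zip(t1, t2):
                self._assert_all_equal(x1, x2)
        else:
            # Leaves are arrays; compare elementwise with a small tolerance.
            np.testing.assert_allclose(np.asarray(t1), np.asarray(t2),
                                       rtol=1e-5)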
Example #2
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value,
                 n_heads, attention_type, dropout, ff_activation,
                 ff_dropout, ff_use_sru, ff_chunk_size, ff_sparsity,
                 attention_chunk_size, n_attention_layers=1,
                 n_feedforward_layers=1, center_layernorm=True,
                 use_bfloat16=False, mode='train'):
  """Reversible transformer decoder layer.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    n_attention_layers: how many residual causal attention layers should we
      have before the feed-forward block (default: 1, the standard block)
    n_feedforward_layers: how many FFNN layers should we have (default 1).
    center_layernorm: whether to use centering in LayerNorm (default) or
      to skip it, which is known as RMS normalization.
    use_bfloat16: whether to use bfloat16 for weights (default: False).
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  # pylint: disable=g-complex-comprehension
  attention_half_residuals = [
      [tl.ReversibleHalfResidual(
          tl.LayerNorm(center=center_layernorm),
          attention_layer=ct.ApplyAttentionLayer(
              attention_type, d_model, n_heads, d_attention_key,
              d_attention_value, True, False, dropout, dropout,
              attention_chunk_size, mode),
          name='ReversibleHalfResidualDecoderAttn'),
       tl.ReversibleSwap()
      ] for _ in range(n_attention_layers)]

  feed_forwards = [
      [tl.ReversibleHalfResidual(
          ct.FeedForwardWithOptions(
              d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,
              ff_chunk_size, ff_use_sru, ff_sparsity, center_layernorm,
              mode, use_bfloat16),
          name='ReversibleHalfResidualDecoderFF'),
       tl.ReversibleSwap()
      ] for _ in range(n_feedforward_layers)]
  # pylint: enable=g-complex-comprehension
  return attention_half_residuals + feed_forwards
Example #3
def EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation,
                        ff_dropout, mode, ff_use_sru=0, ff_chunk_size=0,
                        ff_sparsity=0):
  """Reversible transformer decoder layer.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate for feed-forward layer
    mode: str: 'train' or 'eval'
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity

  Returns:
    the layer.
  """
  enc_dec_attention = tl.EncDecAttention(
      n_heads=n_heads, d_qk=d_model//n_heads, d_v=d_model//n_heads,
      attention_dropout=dropout, output_dropout=dropout,
      mode=mode)
  enc_dec_attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(),
      attention_layer=enc_dec_attention,
  )

  causal_attention = tl.SelfAttention(
      n_heads=n_heads, d_qk=d_model//n_heads, d_v=d_model//n_heads,
      causal=True,
      attention_dropout=dropout, output_dropout=dropout,
      mode=mode)
  causal_attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(),
      attention_layer=causal_attention,
  )

  feed_forward = ct.FeedForwardWithOptions(
      d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,
      ff_chunk_size, ff_use_sru, ff_sparsity, mode)

  return [                             # vec_d1 vec_d2 vec_e masks
      causal_attention_half_residual,
      tl.ReversibleSwap(),
      enc_dec_attention_half_residual,
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]
Example #4
def EncoderBlock(d_model, d_ff, n_heads, attention_type, dropout, ff_activation,
                 ff_dropout, ff_use_sru=0, ff_chunk_size=0, ff_sparsity=0,
                 mode='train'):
  """Returns a list of layers that implements a Reformer encoder block.

  The input to the layer is a pair, (activations, mask), where the mask was
  created from the original source tokens to prevent attending to the padding
  part of the input.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    mode: str: 'train' or 'eval'

  Returns:
    A list of layers that maps (activations, mask) to (activations, mask).
  """
  if mode == 'predict':
    # Mode 'predict' means that the decoder should be run one token at a time.
    # The encoder only ever runs over full sequences, which is why it's switched
    # to 'eval' mode instead.
    mode = 'eval'

  attention = attention_type(
      n_heads=n_heads, d_qk=d_model//n_heads, d_v=d_model//n_heads,
      masked=True, causal=False,
      attention_dropout=dropout, output_dropout=dropout,
      mode=mode)
  attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(),
      attention_layer=attention,
  )

  feed_forward = FeedForwardWithOptions(
      d_model, d_ff, dropout, ff_activation, ff_dropout,
      ff_chunk_size, ff_use_sru, ff_sparsity, mode)

  return [
      attention_half_residual,
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]
Example #5
def EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation,
                        ff_dropout, mode):
    """Reversible transformer decoder layer.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate for feed-forward layer
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
    enc_dec_attention = tl.EncDecAttention(n_heads=n_heads,
                                           d_qk=d_model // n_heads,
                                           d_v=d_model // n_heads,
                                           attention_dropout=dropout,
                                           output_dropout=dropout,
                                           mode=mode)
    enc_dec_attention_half_residual = tl.ReversibleHalfResidual(
        tl.LayerNorm(),
        attention_layer=enc_dec_attention,
    )

    causal_attention = tl.SelfAttention(n_heads=n_heads,
                                        d_qk=d_model // n_heads,
                                        d_v=d_model // n_heads,
                                        causal=True,
                                        attention_dropout=dropout,
                                        output_dropout=dropout,
                                        mode=mode)
    causal_attention_half_residual = tl.ReversibleHalfResidual(
        tl.LayerNorm(),
        attention_layer=causal_attention,
    )

    feed_forward = FeedForward(d_model, d_ff, dropout, ff_activation,
                               ff_dropout, mode)

    return [  # vec_d1 vec_d2 vec_e masks
        causal_attention_half_residual,
        tl.ReversibleSwap(),
        enc_dec_attention_half_residual,
        tl.ReversibleSwap(),
        tl.ReversibleHalfResidual(feed_forward),
        tl.ReversibleSwap(),
    ]
Example #6
def _attention_half_residual():
  return [
      tl.ReversibleHalfResidual(tl.LayerNorm(center=center_layernorm),
                                attention_layer=_Attn(),
                                name='ReversibleHalfResidualDecoderAttn'),
      tl.ReversibleSwap()
  ]
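This fragment comes from inside a larger builder, so center_layernorm and the _Attn() factory are not shown. Purely as an illustration, hypothetical stand-ins modeled on the tl.SelfAttention calls in Examples #3 and #5 (all values below are assumptions, not library defaults) could be:

center_layernorm = True  # illustrative value, not taken from the source

def _Attn():
  # Hypothetical stand-in for the attention factory defined in the
  # enclosing builder; settings mirror other examples on this page.
  return tl.SelfAttention(n_heads=8, d_qk=64, d_v=64, causal=True,
                          attention_dropout=0.1, output_dropout=0.1,
                          mode='train')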
Example #7
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 attention_type, dropout, ff_activation, ff_use_sru,
                 ff_chunk_size, mode):
    """Reversible transformer decoder layer.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
    attention = attention_type(n_heads=n_heads,
                               d_qk=d_attention_key,
                               d_v=d_attention_value,
                               causal=True,
                               output_dropout=dropout,
                               mode=mode)
    attention_half_residual = tl.ReversibleHalfResidual(
        tl.LayerNorm(),
        attention_layer=attention,
    )

    if ff_use_sru:
        feed_forward = [tl.SRU(d_model) for _ in range(ff_use_sru)]
    else:
        feed_forward = [
            ChunkedFeedForward(d_model, d_ff, dropout, ff_activation, dropout,
                               ff_chunk_size, mode)
        ]

    return [
        attention_half_residual,
        tl.ReversibleSwap(),
        tl.ReversibleHalfResidual(feed_forward),
        tl.ReversibleSwap(),
    ]
Example #8
def _feed_forward():
    layers = [
        tl.ReversibleHalfResidual(_FF(),
                                  name='ReversibleHalfResidualEncoderFF')
    ]
    if use_two_swaps_per_block:
        layers.append(tl.ReversibleSwap())
    return layers
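Here too, _FF() and use_two_swaps_per_block belong to an enclosing scope that the excerpt omits. A hypothetical stand-in, using a plain feed-forward stack instead of the library's configurable FeedForwardWithOptions (dimensions are made up for illustration):

use_two_swaps_per_block = True  # illustrative; controls the trailing swap

def _FF():
    # Hypothetical stand-in for the feed-forward factory assumed above.
    return tl.Serial(tl.LayerNorm(), tl.Dense(2048), tl.Relu(), tl.Dense(512))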
Example #9
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 attention_type, dropout, ff_activation, ff_dropout,
                 ff_use_sru, ff_chunk_size, ff_sparsity, attention_chunk_size,
                 mode):
    """Reversible transformer decoder layer.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
    attention = ct.ApplyAttentionLayer(attention_type, d_model, n_heads,
                                       d_attention_key, d_attention_value,
                                       True, False, dropout, dropout,
                                       attention_chunk_size, mode)
    attention_half_residual = tl.ReversibleHalfResidual(
        tl.LayerNorm(),
        attention_layer=attention,
    )

    feed_forward = ct.FeedForwardWithOptions(d_model, d_ff, dropout, [-2],
                                             ff_activation, ff_dropout,
                                             ff_chunk_size, ff_use_sru,
                                             ff_sparsity, mode)

    return [
        attention_half_residual,
        tl.ReversibleSwap(),
        tl.ReversibleHalfResidual(feed_forward),
        tl.ReversibleSwap(),
    ]
Example #10
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value,
                 n_heads, attention_type, dropout, ff_activation,
                 ff_dropout, ff_use_sru, ff_chunk_size, ff_sparsity, mode):
  """Reversible transformer decoder layer.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  # TODO(lukaszkaiser): unify attention layers API and remove this branch
  try:
    attention = attention_type(
        n_heads=n_heads, d_qk=d_attention_key, d_v=d_attention_value,
        causal=True, output_dropout=dropout, mode=mode)
  except TypeError:  # No d_qk arguments in less advanced layers.
    attention = attention_type(d_model, n_heads=n_heads,
                               dropout=dropout, mode=mode)
  attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(),
      attention_layer=attention,
  )

  feed_forward = FeedForwardWithOptions(
      d_model, d_ff, dropout, ff_activation, ff_dropout,
      ff_chunk_size, ff_use_sru, ff_sparsity, mode)

  return [
      attention_half_residual,
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]
Example #11
 def test_run_reversible_slots(self):
   """Tests that slots can be read and assigned in reversible trainer."""
   layers = [tl.Dense(4), tl.Dup()]
   rev_layers = [tl.ReversibleHalfResidual(tl.Dense(4)),
                 tl.ReversibleSwap()]
   loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(4),
                          tl.LogSoftmax(), tl.CrossEntropyLoss())
   trainer = optimizers.ReversibleSerialTrainer(
       [(layers, rev_layers)], loss_layer, optimizers.Adam)
   slots = trainer.slots
   trainer.slots = slots
   self.assertEqual(slots, trainer.slots)
Example #12
    def test_run_reversible_large_weights(self):
        """Runs the reversible trainer with a lot of weights to test memory use."""
        # This test requires > 18GB RAM, only run on TPUs. It does pass on GPU
        # and CPU when you run it locally, but it's too big for unit-testing.
        ram_limited = True  # Set to False to run this test locally.
        if fastmath.global_device_count() == 1 and ram_limited:
            return

        # Create inputs and rngs.
        inputs_batch = np.arange(8).reshape((2, 4))
        targets_batch = inputs_batch
        labeled_batch = (inputs_batch, targets_batch,
                         np.ones_like(targets_batch))
        first_layer = tl.Serial(tl.Embedding(9, 16 * 1024), tl.Dup())
        rng_init = fastmath.random.get_prng(12)
        rng_step = fastmath.random.get_prng(13)

        # Initialize layers.
        first_layer.init(labeled_batch, rng=rng_init)
        n_layers = 18  # 18 layers each 16K x 16K = 256M weights ~= 1GB, 18GB ram
        rev_layers = []
        int_shape = shapes.ShapeDtype((2, 4), dtype=np.int32)
        shape = shapes.ShapeDtype((2, 4, 16 * 1024))
        sig = (shape, shape)
        for _ in range(n_layers):
            layer = tl.ReversibleHalfResidual(tl.Dense(16 * 1024))
            layer.init(sig, rng=rng_init)
            layer.weights = tl.on_cpu(
                layer.weights)  # store weights in cpu memory
            rev_layers.append(layer)
            rev_layers.append(tl.ReversibleSwap())
        loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9), tl.LogSoftmax(),
                               tl.CrossEntropyLoss())
        loss_layer.init((shape, shape, int_shape, int_shape))
        optimizer_fn = optimizers.Adafactor

        # Make a step with reversible trainer.
        trainer = optimizers.ReversibleSerialTrainer(
            [(first_layer, rev_layers)], loss_layer, optimizer_fn)
        loss, _ = trainer.one_step(labeled_batch, rng_step)
        self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.
        # Set to true to run again, e.g., for profiling.
        run_twice = False
        if run_twice:
            t = time.time()
            loss, _ = trainer.one_step(labeled_batch, rng_step)
            self.assertLess(float(loss.sum()),
                            10000.0)  # Just to get the loss.
            print('Took %.3f seconds to run, loss %s' %
                  (time.time() - t, loss))
Example #13
    def test_run_reversible_weights_transfer_xprof(self):
        """Runs the reversible trainer and profiles weight transfer stats."""
        run_this_test = False  # We only run this test manually.
        if not run_this_test or fastmath.global_device_count() == 1:  # TPU only
            return

        # Create inputs and rngs.
        inputs_batch = np.ones((1024, 128), dtype=np.int32)
        targets_batch = inputs_batch
        labeled_batch = (inputs_batch, targets_batch,
                         np.ones_like(targets_batch))
        first_layer = tl.Serial(tl.Embedding(4, 1024), tl.Dup())
        rng_init = fastmath.random.get_prng(12)
        rng_step = fastmath.random.get_prng(13)

        # Initialize layers.
        first_layer.init(labeled_batch, rng=rng_init)
        n_layers = 6
        rev_layers = []
        int_shape = shapes.ShapeDtype((1024, 128), dtype=np.int32)
        shape = shapes.ShapeDtype((1024, 128, 1024))
        sig = (shape, shape)
        for _ in range(n_layers):
            layer = tl.ReversibleHalfResidual(tl.Dense(1024))
            layer.init(sig, rng=rng_init)
            layer.weights = tl.on_cpu(
                layer.weights)  # store weights in cpu memory
            rev_layers.append(layer)
            rev_layers.append(tl.ReversibleSwap())
        loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9), tl.LogSoftmax(),
                               tl.CrossEntropyLoss())
        loss_layer.init((shape, shape, int_shape, int_shape))
        optimizer_fn = optimizers.SGD

        # Make a step with reversible trainer.
        trainer = optimizers.ReversibleSerialTrainer(
            [(first_layer, rev_layers)], loss_layer, optimizer_fn)
        loss, _ = trainer.one_step(labeled_batch, rng_step)
        self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.
        # We profile here.
        t = time.time()
        loss, _ = trainer.one_step(labeled_batch, rng_step)
        self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.
        print('Took %.3f seconds to run, loss %s' % (time.time() - t, loss))
Example #14
    def test_train_memory_efficient(self):
        """Trains a large network in a memory-efficient way."""
        # This test requires > 16GB RAM, only run on TPUs. It does pass on GPU
        # and CPU when you run it locally, but it's too big for unit-testing.
        ram_limited = True  # Set to False to run this test locally.
        if fastmath.device_count() == 1 and ram_limited:
            return

        # Create the model.
        n_layers = 16  # 16 layers each 16K x 16K = 256M weights ~= 1GB, 16GB ram
        model = tl.Serial(
            tl.Embedding(9, 16 * 1024),
            tl.Dup(),
            [[
                tl.ReversibleHalfResidual(tl.Dense(16 * 1024)),
                tl.ReversibleSwap()
            ] for _ in range(n_layers)],
            tl.Concatenate(),
            tl.Dense(9),
        )

        # Create inputs.
        inputs_batch = np.arange(8).reshape((2, 4))
        targets_batch = inputs_batch
        labeled_batch = (inputs_batch, targets_batch,
                         np.ones_like(targets_batch))

        def _data_gen():
            while True:
                yield labeled_batch

        # Run training.
        loss_layer = tl.WeightedCategoryCrossEntropy()
        task = training.TrainTask(_data_gen(), loss_layer,
                                  optimizers.Adafactor)
        eval_task = training.EvalTask(_data_gen(),
                                      [tl.WeightedCategoryCrossEntropy()])
        loop = training.Loop(model, [task],
                             eval_tasks=[eval_task],
                             eval_at=lambda step_n: step_n == 2,
                             use_memory_efficient_trainer=True)
        self.assertEqual(0, loop.step)
        loop.run(n_steps=2)
        self.assertEqual(2, loop.step)
Example #15
  def test_run_reversible_large_weights(self):
    """Runs the reversible trainer with a lot of weights to test memory use."""
    # This test requires > 20GB RAM, only run on TPUs. It does pass on GPU
    # and CPU when you run it locally, but it's too big for unit-testing.
    ram_limited = True  # Set to False to run this test locally.
    if fastmath.device_count() == 1 and ram_limited:
      return

    # Create inputs and rngs.
    inputs_batch = np.arange(8).reshape((2, 4))
    targets_batch = inputs_batch
    labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
    first_layer = tl.Serial(tl.Embedding(9, 16*1024), tl.Dup())
    rng_init = fastmath.random.get_prng(12)
    rng_step = fastmath.random.get_prng(13)

    # Initialize layers.
    first_layer.init(labeled_batch, rng=rng_init)
    n_layers = 20  # 20 layers each 16K x 16K = 256M weights ~= 1GB, 20GB ram
    rev_layers = []
    int_shape = shapes.ShapeDtype((2, 4), dtype=np.int32)
    shape = shapes.ShapeDtype((2, 4, 16*1024))
    sig = (shape, shape)
    for _ in range(n_layers):
      layer = tl.ReversibleHalfResidual(tl.Dense(16*1024))
      layer.init(sig, rng=rng_init)
      layer.weights = tl.on_cpu(layer.weights)  # store weights in cpu memory
      rev_layers.append(layer)
      rev_layers.append(tl.ReversibleSwap())
    loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9),
                           tl.LogSoftmax(), tl.CrossEntropyLoss())
    loss_layer.init((shape, shape, int_shape, int_shape))
    optimizer_fn = optimizers.Adafactor

    # Make a step with reversible trainer.
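    # Note: this call passes first_layer and rev_layers as separate arguments;
    # Examples #12 and #13 instead pass a single list of
    # (layers, reversible_layers) pairs, apparently a different API version.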
    trainer = optimizers.ReversibleSerialTrainer(
        first_layer, rev_layers, loss_layer, optimizer_fn)
    trainer.one_step(labeled_batch, rng_step)
Example #16
def EncoderBlock(d_model,
                 d_ff,
                 n_heads,
                 attention_type,
                 dropout,
                 ff_activation,
                 ff_dropout,
                 ff_use_sru=0,
                 ff_chunk_size=0,
                 ff_sparsity=0,
                 attention_chunk_size=0,
                 use_bfloat16=False,
                 mode='train'):
    """Returns a list of layers that implements a Reformer encoder block.

  The input to the layer is a pair, (activations, mask), where the mask was
  created from the original source tokens to prevent attending to the padding
  part of the input.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    use_bfloat16: whether to use bfloat16 for weights (default: False)
    mode: str: 'train' or 'eval'

  Returns:
    A list of layers that maps (activations, mask) to (activations, mask).
  """
    if mode == 'predict':
        # Mode 'predict' means that the decoder should be run one token at a time.
        # The encoder only ever runs over full sequences, which is why it's switched
        # to 'eval' mode instead.
        mode = 'eval'

    attention = ct.ApplyAttentionLayer(
        attention_type=attention_type,
        d_model=d_model,
        n_heads=n_heads,
        d_qk=d_model // n_heads,
        d_v=d_model // n_heads,
        masked=True,
        causal=False,
        attention_dropout=dropout,
        output_dropout=dropout,
        attention_chunk_size=attention_chunk_size,
        mode=mode)
    # TODO(lukaszkaiser): refactor efficient attention layers to unify the API
    # If we're using standard attention, we need to pass reshaped mask and not
    # return the mask to be compatible with the EfficientAttention API.
    if attention.n_out == 2:

        def reshape_mask(mask):
            return jnp.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))

        attention = tl.Serial(
            tl.Fn('ReshapeMask', lambda x, y: (x, reshape_mask(y)), n_out=2),
            attention, tl.Select([0], n_in=2))

    attention_half_residual = tl.ReversibleHalfResidual(
        tl.LayerNorm(),
        attention_layer=attention,
    )

    feed_forward = ct.FeedForwardWithOptions(d_model, d_ff, dropout, [-2],
                                             ff_activation, ff_dropout,
                                             ff_chunk_size, ff_use_sru,
                                             ff_sparsity, mode, use_bfloat16)

    return [
        attention_half_residual,
        tl.ReversibleSwap(),
        tl.ReversibleHalfResidual(feed_forward),
        tl.ReversibleSwap(),
    ]
Example #17
def _feed_forward():
    return [
        tl.ReversibleHalfResidual(_FF(),
                                  name='ReversibleHalfResidualDecoderFF'),
        tl.ReversibleSwap()
    ]
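As in Example #8, _FF() above is assumed to come from the enclosing builder. All of these fragments compose the same way Example #14 does: Dup() opens two reversible streams, ReversibleHalfResidual/ReversibleSwap pairs transform them, and Concatenate() merges them back. A minimal self-contained sketch of that pattern, with small, purely illustrative dimensions:

import numpy as np
from trax import layers as tl
from trax import shapes

d_model = 16
model = tl.Serial(
    tl.Embedding(9, d_model),   # token ids -> activations
    tl.Dup(),                   # one stream -> two reversible streams
    [[tl.ReversibleHalfResidual(tl.LayerNorm(), tl.Dense(d_model)),
      tl.ReversibleSwap()] for _ in range(2)],
    tl.Concatenate(),           # two streams -> one
    tl.Dense(9),
)
model.init(shapes.ShapeDtype((2, 4), dtype=np.int32))
logits = model(np.arange(8, dtype=np.int32).reshape((2, 4)))  # (2, 4, 9)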