Example 1
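All of the snippets below are test methods excerpted from a Trax test class for reformer.Reformer2, so they are not self-contained: helpers such as self._lsh_self_attention_fn, self._timebin_self_attention_fn and self._assert_all_equal are defined on the enclosing test case and are not shown. As a minimal sketch, the snippets appear to assume module-level imports along these lines (base is assumed to be trax.layers.base, which exposes N_WEIGHTS_SHARDS):

import functools
import operator

import numpy as np

from trax import fastmath
from trax import layers as tl
from trax import optimizers
from trax import shapes
from trax.layers import base
from trax.models.reformer import reformer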
    def test_reformer2_doubling(self):
        vocab_size = 2
        max_len = 2

        model = reformer.Reformer2(
            vocab_size,
            d_model=8,
            d_ff=16,
            n_encoder_layers=1,
            n_decoder_layers=6,
            n_heads=2,
            dropout=0.05,
            max_len=max_len,
            pos_type=None,
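            # Decoder layers before layer 2 run at half width; d_model is doubled
            # again after layer 2 (the doubling behavior under test).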
            half_before_layer=2,
            double_after_layer=2,
            encoder_attention_type=tl.Attention,
            encoder_decoder_attention_type=tl.CausalAttention,
            mode='train',
        )

        x = [
            np.ones((1, max_len)).astype(np.int32),
            np.ones((1, max_len)).astype(np.int32)
        ]
        model.init(shapes.signature(x))

        logits, dec_toks = model(x)
        del dec_toks

        self.assertEqual(logits.shape, (1, max_len, vocab_size))
Example 2
    def test_reformer2_quick(self, backend):
        with fastmath.use_backend(backend):
            vocab_size = 2
            max_len = 2

            model = reformer.Reformer2(
                vocab_size,
                d_model=4,
                d_ff=4,
                n_encoder_layers=1,
                n_decoder_layers=1,
                n_heads=2,
                dropout=0.05,
                max_len=max_len,
                pos_type=None,
                ff_activation=tl.Relu,
                ff_use_sru=0,
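                # Apply the feed-forward block over the sequence in chunks of 2 positions.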
                ff_chunk_size=2,
                mode='train',
            )

            x = [
                np.ones((1, max_len)).astype(np.int32),
                np.ones((1, max_len)).astype(np.int32)
            ]
            model.init(shapes.signature(x))

            logits, dec_toks = model(x)
            del dec_toks

            self.assertEqual(logits.shape, (1, max_len, vocab_size))
Example 3
    def test_run_reversible_same_as_default_reformer2(self):
        """Runs the reversible trainer, check results are the same as default."""
        inputs_batch = np.arange(8).reshape((2, 4)) + 1
        targets_batch = 2 * inputs_batch
        labeled_batch = (inputs_batch, targets_batch,
                         np.ones_like(targets_batch))
        int_sig = shapes.ShapeDtype((2, 4), dtype=np.int32)
        input_sig = (int_sig, int_sig, int_sig)
        # We want to test rng propagation too, so the model includes dropout layers.
        model = reformer.Reformer2(20,
                                   d_model=8,
                                   d_ff=32,
                                   n_heads=1,
                                   dropout=0.0,
                                   n_encoder_layers=2,
                                   n_decoder_layers=2,
                                   ff_sparsity=(4, 8, 0.0, 1.0),
                                   pos_type=None,
                                   reversible_encoder=True)
        loss = tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss())
        optimizer_fn = optimizers.Adafactor
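        # Split [model, loss] into (standard, reversible) layer blocks plus a final
        # loss layer, with the loss computed in chunks of 4 to save memory.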
        blocks, loss_layer = optimizers.trainer.extract_reversible_blocks(
            [model, loss], loss_chunk_size=4)
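        # Wrap each block's standard layers in Serial so their weights can be read
        # out and compared with the baseline run below.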
        blocks_serial = [(tl.Serial(std), rev) for (std, rev) in blocks]
        model_with_loss = tl.Serial(model, loss)
        rng_init = fastmath.random.get_prng(12)
        model_with_loss.init(input_sig, rng=rng_init)

        # Make 3 steps with the original trainer.
        optimizer = optimizer_fn()
        optimizer.tree_init(model_with_loss.weights)
        trainer = optimizers.Trainer(model_with_loss, optimizer)
        rng_step1 = fastmath.random.get_prng(7)
        rng_step2 = fastmath.random.get_prng(8)
        rng_step3 = fastmath.random.get_prng(9)
        trainer.one_step(labeled_batch, rng_step1)
        trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02)
        trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03)
        first_weights = blocks_serial[0][0].weights
        first_rev_weights = blocks[0][1][0].weights
        loss_weights = loss_layer.weights

        # Now make 3 steps with reversible trainer.
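        # Re-initialize with the same rng so both trainers start from identical weights.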
        model_with_loss.init(input_sig, rng=rng_init)
        trainer = optimizers.ReversibleSerialTrainer(blocks, loss_layer,
                                                     optimizer_fn)
        trainer.one_step(labeled_batch, rng_step1)
        trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02)
        trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03)

        # Check that weights end up the same.
        self._assert_all_equal(loss_weights, loss_layer.weights)
        self._assert_all_equal(first_rev_weights, blocks[0][1][0].weights)
        self._assert_all_equal(first_weights, blocks_serial[0][0].weights)
Example 4
    def test_run_sharded_reformer2(self):
        """Runs Reformer2 with sharded weights (only on 2+-device systems)."""
        if fastmath.local_device_count() == 1:
            return
        base.N_WEIGHTS_SHARDS = fastmath.local_device_count()
        inputs_batch = np.arange(8).reshape((2, 4)) + 1
        targets_batch = 2 * inputs_batch
        labeled_batch = (inputs_batch, targets_batch,
                         np.ones_like(targets_batch))
        int_sig = shapes.ShapeDtype((2, 4), dtype=np.int32)
        input_sig = (int_sig, int_sig, int_sig)
        # We want to test rng propagation too, so the model includes dropout layers.
        model = reformer.Reformer2(
            20,
            d_model=8,
            d_ff=32,
            n_heads=1,
            dropout=0.0,
            n_encoder_layers=2,
            n_decoder_layers=2,
            ff_sparsity=(4, 8, 0.0, 1.0),
            encoder_attention_type=tl.Attention,
            encoder_decoder_attention_type=tl.CausalAttention,
            pos_type=None,
            reversible_encoder=True)
        loss = tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss())
        model_with_loss = tl.Serial(model, loss)
        rng_init = fastmath.random.get_prng(12)
        model_with_loss.init(input_sig, rng=rng_init)

        # Make a step with the trainer.
        optimizer = optimizers.Adafactor(0.01)
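        # Shard all weights across the local devices; the optimizer state is
        # initialized from a single shard.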
        split_w = fastmath.nested_map(
            lambda x: x[0],
            tl.shard(model_with_loss.weights, base.N_WEIGHTS_SHARDS))
        optimizer.tree_init(split_w)
        trainer = optimizers.Trainer(model_with_loss, optimizer)
        rng_step1 = fastmath.random.get_prng(7)
        trainer.one_step(labeled_batch, rng_step1)
        # Reset shards back to default.
        base.N_WEIGHTS_SHARDS = 1
Example 5
    def test_reformer2_quick(self, backend, encoder_attention_type, preembed):
        with fastmath.use_backend(backend):
            vocab_size = 2
            input_vocab_size = None if preembed else vocab_size
            output_vocab_size = vocab_size if preembed else None
            max_len = 2

            model = reformer.Reformer2(
                input_vocab_size,
                d_model=4,
                d_ff=4,
                n_encoder_layers=1,
                n_decoder_layers=1,
                n_heads=2,
                dropout=0.05,
                max_len=max_len,
                pos_type=None,
                ff_activation=tl.Relu,
                ff_use_sru=0,
                ff_chunk_size=2,
                mode='train',
                output_vocab_size=output_vocab_size,
                encoder_attention_type=encoder_attention_type,
            )

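            # With preembed, the encoder consumes pre-embedded float vectors plus a
            # boolean padding mask instead of token ids.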
            if preembed:
                model_inputs = [
                    np.ones((1, max_len, 3)).astype(np.float32),
                    np.ones((1, max_len)).astype(np.bool_)
                ]
            else:
                model_inputs = [np.ones((1, max_len)).astype(np.int32)]
            x = model_inputs + [np.ones((1, max_len)).astype(np.int32)]
            model.init(shapes.signature(x))

            logits, dec_toks = model(x)
            del dec_toks

            self.assertEqual(logits.shape, (1, max_len, vocab_size))
Example 6
    def test_reformer2_one_step(self):
        vocab_size = 32
        max_len = 256
        pos_axial = 16
        assert pos_axial * pos_axial == max_len

        chunk_len = 32

        # Choose n_buckets so that 2 * chunk_len * n_buckets == max_len.
        n_buckets = max_len // (2 * chunk_len)

        lsh_self_attention = functools.partial(self._lsh_self_attention_fn(),
                                               chunk_len=chunk_len,
                                               n_buckets=n_buckets)

        timebin_self_attention = self._timebin_self_attention_fn()

        model = reformer.Reformer2(
            vocab_size,
            d_model=32,
            d_ff=64,
            d_attention_key=64,
            d_attention_value=64,
            n_encoder_layers=2,
            n_decoder_layers=2,
            n_heads=2,
            dropout=0.05,
            max_len=max_len,
            encoder_attention_type=lsh_self_attention,
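            # A list makes the decoder layers cycle through these attention types.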
            encoder_decoder_attention_type=[
                timebin_self_attention, lsh_self_attention
            ],
            pos_axial_shape=(pos_axial, pos_axial),
            pos_d_axial_embs=(64, 192),
            ff_activation=tl.Relu,
            ff_use_sru=0,
            ff_chunk_size=64,
            ff_sparsity=8,
            mode='train',
        )

        x = [
            np.ones((1, max_len)).astype(np.int32),
            np.ones((1, max_len)).astype(np.int32)
        ]
        weights, state = model.init(shapes.signature(x))

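        # One jitted training step: run the model functionally via pure_fn, take the
        # mean of one logit channel as a mock loss, and apply a small SGD-style update.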
        @fastmath.jit
        def mock_training_step(x, weights, state, rng):
            def compute_mock_loss(weights):
                logits_and_dec_toks, new_state = model.pure_fn(
                    x, weights, state, rng)
                # This returns [logits, decoder tokens]
                logits = logits_and_dec_toks[0]
                loss = fastmath.numpy.mean(logits[..., 0])
                return loss, (new_state, logits)

            gradients, (new_state,
                        logits) = fastmath.grad(compute_mock_loss,
                                                has_aux=True)(weights)
            new_weights = fastmath.nested_map_multiarg(
                lambda w, g: w - 1e-4 * g, weights, gradients)
            return new_weights, new_state, logits

        weights, state, logits = mock_training_step(
            x, weights, state, fastmath.random.get_prng(0))

        self.assertEqual(logits.shape, (1, max_len, vocab_size))
Example 7
  def test_reformer2_one_step(self):
    d_model = 1024
    vocab_size = 14041
    max_len = 16384
    pos_axial = (128, 128)  # axial factors; should multiply to max_len
    pos_d_axial_embs = (512, 512)  # should sum to d_model

    assert operator.mul(*pos_axial) == max_len
    assert sum(pos_d_axial_embs) == d_model

    d_ff = 4096
    n_heads = 8
    d_attn = d_model // n_heads

    n_buckets = 128
    encoder_chunk_len = (2 * max_len) // n_buckets  # 256
    decoder_chunk_len = 2 * encoder_chunk_len       # 512
    encoder_n_chunks_after = 1                      # since it's not causal.

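    # Specialize a shared LSH attention: the encoder variant may attend one chunk
    # ahead (it is not causal), while the decoder variant stays strictly causal.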
    lsh_self_attention = functools.partial(self._lsh_self_attention_fn(),
                                           n_buckets=n_buckets)

    encoder_lsh_self_attention = functools.partial(
        lsh_self_attention, n_chunks_after=encoder_n_chunks_after,
        chunk_len=encoder_chunk_len)

    decoder_lsh_self_attention = functools.partial(
        lsh_self_attention, n_chunks_after=0,
        chunk_len=decoder_chunk_len)

    model = reformer.Reformer2(
        vocab_size,
        d_model=d_model,
        d_ff=d_ff,
        d_attention_key=d_attn,
        d_attention_value=d_attn,
        n_encoder_layers=1,
        n_decoder_layers=1,
        n_heads=n_heads,
        dropout=0.05,
        max_len=max_len,
        encoder_attention_type=encoder_lsh_self_attention,
        encoder_decoder_attention_type=decoder_lsh_self_attention,
        pos_axial_shape=pos_axial,
        pos_d_axial_embs=pos_d_axial_embs,
        ff_activation=tl.Relu,
        ff_use_sru=0,
        mode='train',
    )

    def random_sentence():
      return np.random.randint(low=1, high=vocab_size - 1, size=(1, max_len),
                               dtype=np.int32)

    x = [random_sentence(), random_sentence()]
    weights, state = model.init(shapes.signature(x))

    @fastmath.jit
    def mock_training_step(x, weights, state, rng):
      def compute_mock_loss(weights):
        logits_and_dec_toks, new_state = model.pure_fn(x, weights, state, rng)
        # This returns [logits, decoder tokens]
        logits = logits_and_dec_toks[0]
        loss = fastmath.numpy.mean(logits[..., 0])
        return loss, (new_state, logits)
      gradients, (new_state, logits) = fastmath.grad(
          compute_mock_loss, has_aux=True)(weights)
      new_weights = fastmath.nested_map_multiarg(
          lambda w, g: w - 1e-4 * g, weights, gradients)
      return new_weights, new_state, logits

    weights, state, logits = mock_training_step(
        x, weights, state, fastmath.random.get_prng(0))

    self.assertEqual(logits.shape, (1, max_len, vocab_size))