Exemplo n.º 1
0
  def test_terraformer_doubling(self):
    """Smoke-tests the half_before_layer / double_after_layer options.

    Builds a tiny ConfigurableTerraformer with both options set to 2 and six
    decoder layers, runs one forward pass on all-ones int32 inputs, and checks
    the shape of the returned logits.
    """
    vocab_size, max_len = 2, 2

    model = terraformer.ConfigurableTerraformer(
        vocab_size,
        mode='train',
        d_model=8,
        d_ff=16,
        n_heads=2,
        n_encoder_layers=1,
        n_decoder_layers=6,
        half_before_layer=2,
        double_after_layer=2,
        dropout=0.05,
        max_len=max_len,
        pos_type=None,
        encoder_attention_type=tl.Attention,
        encoder_decoder_attention_type=tl.CausalAttention,
    )

    # Two identical all-ones token arrays of shape (1, max_len)
    # (presumably encoder inputs and decoder targets).
    inputs = [np.ones((1, max_len)).astype(np.int32) for _ in range(2)]
    model.init(shapes.signature(inputs))

    # The model returns (logits, decoder tokens); only the logits are checked.
    logits, _ = model(inputs)
    self.assertEqual(logits.shape, (1, max_len, vocab_size))
Exemplo n.º 2
0
    def test_run_reversible_same_as_default_terraformer(self):
        """Checks that ReversibleSerialTrainer matches the default Trainer.

        The same reversible-encoder Terraformer is trained for three steps
        with each trainer, using identical init rng, step rngs and learning
        rates. The test then asserts equality of the loss-layer weights, the
        first reversible block's weights and the first standard block's
        weights.
        """
        inputs = np.arange(8).reshape((2, 4)) + 1
        targets = 2 * inputs
        labeled_batch = (inputs, targets, np.ones_like(targets))
        int_sig = shapes.ShapeDtype((2, 4), dtype=np.int32)
        input_sig = (int_sig,) * 3
        # Intended to exercise rng propagation through dropout layers.
        # NOTE(review): dropout is 0.0 below; confirm rngs are still threaded.
        model = terraformer.ConfigurableTerraformer(
            20,
            d_model=8,
            d_ff=32,
            n_heads=1,
            dropout=0.0,
            n_encoder_layers=2,
            n_decoder_layers=2,
            ff_sparsity=(4, 8, 0.0, 1.0),
            pos_type=None,
            reversible_encoder=True)
        loss = tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss())
        optimizer_fn = optimizers.Adafactor
        blocks, loss_layer = optimizers.trainer.extract_reversible_blocks(
            [model, loss], loss_chunk_size=4)
        blocks_serial = [(tl.Serial(std), rev) for std, rev in blocks]
        model_with_loss = tl.Serial(model, loss)
        rng_init = fastmath.random.get_prng(12)
        model_with_loss.init(input_sig, rng=rng_init)

        step_rngs = [fastmath.random.get_prng(seed) for seed in (7, 8, 9)]

        def run_three_steps(trainer):
            # First step uses the trainer's default learning rate; the
            # following two pass explicit rates.
            rng_a, rng_b, rng_c = step_rngs
            trainer.one_step(labeled_batch, rng_a)
            trainer.one_step(labeled_batch, rng_b, learning_rate=0.02)
            trainer.one_step(labeled_batch, rng_c, learning_rate=0.03)

        # Baseline: three steps with the standard Trainer.
        optimizer = optimizer_fn()
        optimizer.tree_init(model_with_loss.weights)
        run_three_steps(optimizers.Trainer(model_with_loss, optimizer))
        expected_std_weights = blocks_serial[0][0].weights
        expected_rev_weights = blocks[0][1][0].weights
        expected_loss_weights = loss_layer.weights

        # Re-initialize with the same rng, then repeat with the reversible
        # trainer.
        model_with_loss.init(input_sig, rng=rng_init)
        run_three_steps(
            optimizers.ReversibleSerialTrainer(blocks, loss_layer,
                                               optimizer_fn))

        # Both trainers must arrive at the same weights.
        self._assert_all_equal(expected_loss_weights, loss_layer.weights)
        self._assert_all_equal(expected_rev_weights, blocks[0][1][0].weights)
        self._assert_all_equal(expected_std_weights,
                               blocks_serial[0][0].weights)
Exemplo n.º 3
0
    def test_run_sharded_terraformer(self):
        """Runs Terraformer with sharded weights (only on 2+-device systems).

        Steps performed:
          1. Temporarily set the module-global ``base.N_WEIGHTS_SHARDS`` to
             the local device count.
          2. Initialize a small reversible-encoder Terraformer.
          3. Shard its weights and initialize Adafactor from the first shard.
          4. Take one training step.

        The previous shard count is restored in ``finally``, so a failure in
        this test cannot leave sharding enabled for later tests. On
        single-device machines the test is reported as skipped rather than
        silently passing.
        """
        n_devices = fastmath.local_device_count()
        if n_devices < 2:
            self.skipTest('Sharded weights need at least 2 local devices.')
        previous_n_shards = base.N_WEIGHTS_SHARDS
        # Set before the model is built, as in the original test, in case
        # layer construction reads this global.
        base.N_WEIGHTS_SHARDS = n_devices
        try:
            inputs_batch = np.arange(8).reshape((2, 4)) + 1
            targets_batch = 2 * inputs_batch
            labeled_batch = (inputs_batch, targets_batch,
                             np.ones_like(targets_batch))
            int_sig = shapes.ShapeDtype((2, 4), dtype=np.int32)
            input_sig = (int_sig, int_sig, int_sig)
            # Intended to exercise rng propagation through dropout layers.
            # NOTE(review): dropout is 0.0 below; confirm rngs are still
            # threaded.
            model = terraformer.ConfigurableTerraformer(
                20,
                d_model=8,
                d_ff=32,
                n_heads=1,
                dropout=0.0,
                n_encoder_layers=2,
                n_decoder_layers=2,
                ff_sparsity=(4, 8, 0.0, 1.0),
                encoder_attention_type=tl.Attention,
                encoder_decoder_attention_type=tl.CausalAttention,
                pos_type=None,
                reversible_encoder=True)
            loss = tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss())
            model_with_loss = tl.Serial(model, loss)
            rng_init = fastmath.random.get_prng(12)
            model_with_loss.init(input_sig, rng=rng_init)

            # Optimizer slots are initialized from the first shard (x[0]) of
            # every sharded weight array (presumably so slot shapes match a
            # single per-device shard).
            optimizer = optimizers.Adafactor(0.01)
            split_w = fastmath.nested_map(
                lambda x: x[0],
                tl.shard(model_with_loss.weights, base.N_WEIGHTS_SHARDS))
            optimizer.tree_init(split_w)
            trainer = optimizers.Trainer(model_with_loss, optimizer)
            rng_step1 = fastmath.random.get_prng(7)
            trainer.one_step(labeled_batch, rng_step1)
        finally:
            # Always restore the global shard count, even if the step fails.
            base.N_WEIGHTS_SHARDS = previous_n_shards
Exemplo n.º 4
0
    def test_terraformer_quick(self, backend, encoder_attention_type,
                               preembed):
        """Forward-pass smoke test for a tiny Terraformer.

        Args:
          backend: fastmath backend to run the test under.
          encoder_attention_type: attention layer type used by the encoder.
          preembed: controls the encoder inputs and vocab sizes.
            - If True, the encoder receives float32 inputs of shape
              (1, max_len, 3) plus a boolean array (presumably a mask), and
              only ``output_vocab_size`` is set.
            - Otherwise the encoder receives int32 tokens, and only the
              input vocab size is set.
        """
        with fastmath.use_backend(backend):
            vocab_size = 2
            input_vocab_size = None if preembed else vocab_size
            output_vocab_size = vocab_size if preembed else None
            max_len = 2

            model = terraformer.ConfigurableTerraformer(
                input_vocab_size,
                d_model=4,
                d_ff=4,
                n_encoder_layers=1,
                n_decoder_layers=1,
                n_heads=2,
                dropout=0.05,
                max_len=max_len,
                pos_type=None,
                ff_activation=tl.Relu,
                ff_use_sru=0,
                ff_chunk_size=2,
                mode='train',
                output_vocab_size=output_vocab_size,
                encoder_attention_type=encoder_attention_type,
            )

            if preembed:
                model_inputs = [
                    np.ones((1, max_len, 3)).astype(np.float32),
                    # `np.bool` was removed in NumPy 1.24; `np.bool_` is the
                    # NumPy boolean scalar type.
                    np.ones((1, max_len)).astype(np.bool_)
                ]
            else:
                model_inputs = [np.ones((1, max_len)).astype(np.int32)]
            # Decoder-side int32 tokens are appended in both cases.
            x = model_inputs + [np.ones((1, max_len)).astype(np.int32)]
            model.init(shapes.signature(x))

            logits, dec_toks = model(x)
            del dec_toks

            self.assertEqual(logits.shape, (1, max_len, vocab_size))
Exemplo n.º 5
0
    def test_terraformer_one_step(self):
        """Takes one jitted mock training step on a full-size LSH Terraformer.

        Setup:
          - 16384-token random int32 sequences.
          - Axial positional embeddings.
          - LSH self-attention in both the encoder and the encoder-decoder
            stack.

        One plain gradient-descent update is applied using a mock loss (the
        mean of the logit at vocab index 0), and the logits shape is checked.
        """
        d_model = 1024
        vocab_size = 14041
        max_len = 16384
        pos_axial = (128, 128)  # Must multiply to max_len.
        pos_d_axial_embs = (512, 512)  # Must sum to d_model.

        assert operator.mul(*pos_axial) == max_len
        assert sum(pos_d_axial_embs) == d_model

        d_ff = 4096
        n_heads = 8
        d_attn = d_model // n_heads

        n_buckets = 128
        encoder_chunk_len = (2 * max_len) // n_buckets  # 256
        decoder_chunk_len = 2 * encoder_chunk_len  # 512

        base_lsh = functools.partial(self._lsh_self_attention_fn(),
                                     n_buckets=n_buckets)
        # n_chunks_after=1 for the encoder since it's not causal.
        encoder_attention = functools.partial(
            base_lsh, n_chunks_after=1, chunk_len=encoder_chunk_len)
        decoder_attention = functools.partial(
            base_lsh, n_chunks_after=0, chunk_len=decoder_chunk_len)

        model = terraformer.ConfigurableTerraformer(
            vocab_size,
            d_model=d_model,
            d_ff=d_ff,
            d_attention_key=d_attn,
            d_attention_value=d_attn,
            n_encoder_layers=1,
            n_decoder_layers=1,
            n_heads=n_heads,
            dropout=0.05,
            max_len=max_len,
            encoder_attention_type=encoder_attention,
            encoder_decoder_attention_type=decoder_attention,
            pos_axial_shape=pos_axial,
            pos_d_axial_embs=pos_d_axial_embs,
            ff_activation=tl.Relu,
            ff_use_sru=0,
            mode='train',
        )

        # Two random token sequences with ids drawn from [1, vocab_size - 2].
        inputs = [
            np.random.randint(low=1, high=vocab_size - 1,
                              size=(1, max_len), dtype=np.int32)
            for _ in range(2)
        ]
        weights, state = model.init(shapes.signature(inputs))

        @fastmath.jit
        def mock_training_step(batch, params, model_state, rng):
            def mock_loss(params):
                outputs, next_state = model.pure_fn(
                    batch, params, model_state, rng)
                logits = outputs[0]  # outputs is [logits, decoder tokens].
                value = fastmath.numpy.mean(logits[..., 0])
                return value, (next_state, logits)

            grads, (next_state, logits) = fastmath.grad(
                mock_loss, has_aux=True)(params)
            stepped = fastmath.nested_map_multiarg(
                lambda w, g: w - 1e-4 * g, params, grads)
            return stepped, next_state, logits

        _, _, logits = mock_training_step(
            inputs, weights, state, fastmath.random.get_prng(0))

        self.assertEqual(logits.shape, (1, max_len, vocab_size))
Exemplo n.º 6
0
  def test_terraformer_one_step(self):
    """Takes one jitted mock training step with mixed attention types.

    Setup:
      - The encoder uses LSH self-attention.
      - The encoder-decoder attention type is the list [time-bin, LSH]
        (presumably cycled across decoder layers).

    One plain gradient-descent update is applied using a mock loss (the mean
    of the logit at vocab index 0), and the logits shape is checked.
    """
    vocab_size = 32
    max_len = 256
    pos_axial = 16
    assert pos_axial * pos_axial == max_len

    chunk_len = 32
    # Chosen so that 2 * chunk_len * n_buckets == max_len.
    n_buckets = max_len // (2 * chunk_len)

    lsh_attention = functools.partial(
        self._lsh_self_attention_fn(), chunk_len=chunk_len,
        n_buckets=n_buckets)
    timebin_attention = self._timebin_self_attention_fn()

    model = terraformer.ConfigurableTerraformer(
        vocab_size,
        mode='train',
        d_model=32,
        d_ff=64,
        d_attention_key=64,
        d_attention_value=64,
        n_heads=2,
        n_encoder_layers=2,
        n_decoder_layers=2,
        dropout=0.05,
        max_len=max_len,
        encoder_attention_type=lsh_attention,
        encoder_decoder_attention_type=[timebin_attention, lsh_attention],
        pos_axial_shape=(pos_axial, pos_axial),
        pos_d_axial_embs=(64, 192),
        ff_activation=tl.Relu,
        ff_use_sru=0,
        ff_chunk_size=64,
        ff_sparsity=8,
    )

    tokens = [np.ones((1, max_len), dtype=np.int32) for _ in range(2)]
    weights, state = model.init(shapes.signature(tokens))

    @fastmath.jit
    def mock_training_step(batch, params, model_state, rng):
      def mock_loss(params):
        outputs, next_state = model.pure_fn(batch, params, model_state, rng)
        logits = outputs[0]  # outputs is [logits, decoder tokens].
        return fastmath.numpy.mean(logits[..., 0]), (next_state, logits)

      grads, (next_state, logits) = fastmath.grad(
          mock_loss, has_aux=True)(params)
      stepped = fastmath.nested_map_multiarg(
          lambda w, g: w - 1e-4 * g, params, grads)
      return stepped, next_state, logits

    _, _, logits = mock_training_step(
        tokens, weights, state, fastmath.random.get_prng(0))

    self.assertEqual(logits.shape, (1, max_len, vocab_size))