# Example #1
    def test_funnel_transformer_lm_autoregressive_property(self):
        """Logits for position i may depend only on tokens at positions <= i."""
        input_shape = (1, 12)
        d_model = 8
        vocab_size = 26
        rng_fixed = jax.random.PRNGKey(0)
        rng_alt = jax.random.PRNGKey(1)

        def _forward_logits(model, tokens):
            # Always (re-)initialize with the same rng so every forward
            # pass runs with identical weights.
            model.init(shapes.signature(tokens),
                       rng=rng_fixed,
                       use_cache=False)
            logits, *_ = model(tokens, rng=rng_fixed)
            return logits

        def _assert_autoregressive(model):
            with fastmath.use_backend(fastmath.Backend.JAX):
                seq_a = jax.random.randint(rng_fixed, input_shape, 0,
                                           vocab_size)
                logits_a = _forward_logits(model, seq_a)

                seq_b = jax.random.randint(rng_alt, input_shape, 0,
                                           vocab_size)

                for prefix in range(input_shape[1]):
                    # Sequence agreeing with seq_a on the first `prefix`
                    # tokens and differing afterwards.
                    mixed = np.concatenate(
                        (seq_a[:, :prefix], seq_b[:, prefix:]), axis=1)

                    logits_mixed = _forward_logits(model, mixed)
                    self.assertEqual(logits_mixed.shape[0], input_shape[1])
                    # Matching prefixes must yield matching logits.
                    np.testing.assert_array_almost_equal(
                        logits_a[:prefix + 1], logits_mixed[:prefix + 1])

        model_chunked = ft.RelformerLM(
            vocab_size,
            shorten_factor=3,
            n_rel_layers=2,
            vanilla_layers=(1, 1),
            d_model=d_model,
            d_ff=4 * d_model,
            n_heads=2,
            vanilla_attn_type=tl.SelfAttention,
            rel_chunk_len=2,
            vanilla_chunk_len=4,
        )
        _assert_autoregressive(model_chunked)

        model_without_chunks = ft.RelformerLM(
            vocab_size,
            shorten_factor=3,
            n_rel_layers=2,
            vanilla_layers=(1, 1),
            d_model=d_model,
            d_ff=4 * d_model,
            n_heads=2,
            vanilla_attn_type=tl.SelfAttention,
            rel_chunk_len=None,
            vanilla_chunk_len=None,
        )
        _assert_autoregressive(model_without_chunks)
# Example #2
    def test_funnel_transformer_lm_predict_eval_equal(self):
        """Token-by-token predict mode must reproduce full-sequence eval mode."""
        d_model = 8
        vocab_size = 4
        batch_size = 1
        n_len_eval = 21
        attention_type = tl.SelfAttention

        shorten_factor = 3
        n_rel_layers = 1
        vanilla_layers = (1, 1)
        n_heads = 2

        eval_funnel = ft.RelformerLM(vocab_size,
                                     shorten_factor=shorten_factor,
                                     n_rel_layers=n_rel_layers,
                                     vanilla_layers=vanilla_layers,
                                     d_model=d_model,
                                     d_ff=d_model,
                                     n_heads=n_heads,
                                     vanilla_attn_type=attention_type,
                                     mode='eval')

        input_funnel = jax.random.randint(key=jax.random.PRNGKey(0),
                                          minval=0,
                                          maxval=vocab_size,
                                          shape=(batch_size,
                                                 n_len_eval)).astype(np.int32)
        # Same init rng for both models so their weights are identical.
        _, _ = eval_funnel.init(shapes.signature(input_funnel),
                                rng=jax.random.PRNGKey(0),
                                use_cache=False)
        y_eval = eval_funnel(input_funnel)
        self.assertEqual(y_eval.shape, (batch_size, n_len_eval, vocab_size))

        if attention_type == tl.SelfAttention:
            # Predict-mode SelfAttention needs a chunk length covering the
            # whole sequence so its cache matches the eval-mode pass.
            gin.bind_parameter('trax.layers.SelfAttention.chunk_len',
                               n_len_eval)

        predict_funnel = ft.RelformerLM(vocab_size,
                                        shorten_factor=shorten_factor,
                                        n_rel_layers=n_rel_layers,
                                        vanilla_layers=vanilla_layers,
                                        d_model=d_model,
                                        d_ff=d_model,
                                        n_heads=n_heads,
                                        vanilla_attn_type=attention_type,
                                        mode='predict')

        # Shift right by one: predict mode consumes the previous token to
        # emit logits for the current position, starting from token 0.
        input_funnel = np.concatenate(
            [np.zeros((batch_size, 1)).astype(np.int32), input_funnel], axis=1)
        input_funnel = input_funnel[:, :-1]
        _, _ = predict_funnel.init(shapes.signature(input_funnel[:, 0:1]),
                                   rng=jax.random.PRNGKey(0),
                                   use_cache=False)

        for i in range(n_len_eval):
            y = predict_funnel(input_funnel[:, i:i + 1])
            np.testing.assert_array_almost_equal(y,
                                                 y_eval[:, i:i + 1, :],
                                                 decimal=5)
        # Undo the gin binding so it does not leak into other tests; the
        # other predict-mode tests in this file already clear it.
        gin.clear_config()
# Example #3
    def test_funnel_transformer_lm_forward_shape(self):
        """Forward pass preserves (batch, length) and appends a vocab axis."""
        d_model = 8
        vocab_size = 7
        tokens = np.ones((3, 6)).astype(np.int32)

        relformer = ft.RelformerLM(vocab_size,
                                   shorten_factor=3,
                                   n_rel_layers=1,
                                   vanilla_layers=(1, 1),
                                   d_model=d_model,
                                   d_ff=d_model,
                                   n_heads=2,
                                   vanilla_attn_type=tl.SelfAttention)
        relformer.init(shapes.signature(tokens))
        logits = relformer(tokens)
        self.assertEqual(logits.shape, (3, 6, vocab_size))

        # Multi-stage funnel with two shortening stages and no extra layers.
        funnel = ft.FunnelTransformerLM(vocab_size,
                                        shorten_factors=(3, 2),
                                        n_funnel_blocks=(0, 0),
                                        vanilla_layers=(0, 0),
                                        d_model=d_model,
                                        d_ff=d_model,
                                        n_heads=2)
        funnel.init(shapes.signature(tokens))
        logits = funnel(tokens)
        self.assertEqual(logits.shape, (3, 6, vocab_size))
# Example #4
    def test_funnel_transformer_lm_forward_shape(self):
        """Chunked and unchunked RelformerLM both preserve (batch, length)."""
        d_model = 16
        vocab_size = 7
        length = 48
        batch_size = 3
        tokens = np.ones((batch_size, length)).astype(np.int32)

        chunked = ft.RelformerLM(vocab_size,
                                 shorten_factor=3,
                                 n_rel_layers=3,
                                 vanilla_layers=(1, 1),
                                 d_model=d_model,
                                 d_ff=d_model,
                                 n_heads=2,
                                 vanilla_attn_type=tl.SelfAttention,
                                 rel_chunk_len=4,
                                 vanilla_chunk_len=2,
                                 max_len=48)
        chunked.init(shapes.signature(tokens))
        self.assertEqual(chunked(tokens).shape,
                         (batch_size, length, vocab_size))

        unchunked = ft.RelformerLM(vocab_size,
                                   shorten_factor=3,
                                   n_rel_layers=3,
                                   vanilla_layers=(1, 1),
                                   d_model=d_model,
                                   d_ff=d_model,
                                   n_heads=2,
                                   vanilla_attn_type=tl.SelfAttention,
                                   max_len=48)
        unchunked.init(shapes.signature(tokens))
        self.assertEqual(unchunked(tokens).shape,
                         (batch_size, length, vocab_size))
# Example #5
 def test_autoregressive_sample_relformerlm(self):
     """Sampling from a predict-mode RelformerLM yields (batch, max_length)."""
     batch_size = 4
     max_length = 5
     lm = ft.RelformerLM(10,
                         d_model=8,
                         d_ff=16,
                         n_rel_layers=1,
                         vanilla_layers=(1, 1),
                         shorten_factor=3,
                         n_heads=2,
                         mode='predict')
     lm.init(shapes.ShapeDtype((batch_size, 1), dtype=np.int32))
     # eos_id=-1 never matches a sampled token, so decoding always runs
     # for the full max_length steps.
     samples = decoding.autoregressive_sample(lm,
                                              batch_size=batch_size,
                                              eos_id=-1,
                                              max_length=max_length,
                                              accelerate=False)
     self.assertEqual(samples.shape, (batch_size, max_length))
# Example #6
    def test_funnel_transformer_lm_forward_shape_eval(self):
        """Eval-mode forward keeps (batch, length) and appends a vocab axis."""
        d_model = 8
        vocab_size = 7
        batch_size = 1
        tokens = np.zeros((batch_size, 6)).astype(np.int32)
        model = ft.RelformerLM(vocab_size,
                               shorten_factor=3,
                               n_rel_layers=1,
                               vanilla_layers=(1, 1),
                               d_model=d_model,
                               d_ff=d_model,
                               n_heads=2,
                               vanilla_attn_type=tl.SelfAttention,
                               mode='eval')

        model.init(shapes.signature(tokens))
        logits = model(tokens)
        self.assertEqual(logits.shape, (batch_size, 6, vocab_size))
# Example #7
    def test_funnel_transformer_lm_forward_shape_predict(self):
        """Predict mode emits one step of logits per single-token call."""
        d_model = 8
        vocab_size = 7
        batch_size = 1
        token = np.ones((batch_size, 1)).astype(np.int32)
        # SelfAttention requires an explicit chunk_len in predict mode.
        gin.bind_parameter('trax.layers.SelfAttention.chunk_len', 20)
        model = ft.RelformerLM(vocab_size,
                               shorten_factor=3,
                               n_rel_layers=1,
                               vanilla_layers=(1, 1),
                               d_model=d_model,
                               d_ff=d_model,
                               n_heads=2,
                               vanilla_attn_type=tl.SelfAttention,
                               mode='predict')

        model.init(shapes.signature(token))

        # Repeated single-token calls each produce one step of logits.
        for _ in range(5):
            step_logits = model(token)
            self.assertEqual(step_logits.shape, (batch_size, 1, vocab_size))
        # Reset gin so the binding does not leak into other tests.
        gin.clear_config()
# Example #8
    def test_funnel_transformer_lm_forward_shape_predict(self):
        """Chunked predict mode emits one step of logits per call."""
        d_model = 8
        vocab_size = 4
        batch_size = 1
        n_len_eval = 42
        attention_type = tl.SelfAttention

        shorten_factor = 3
        n_rel_layers = 2
        vanilla_layers = (1, 1)
        n_heads = 2

        rel_chunk_len, vanilla_chunk_len = 2, 6

        token = np.ones((batch_size, 1)).astype(np.int32)
        # SelfAttention requires an explicit chunk_len in predict mode.
        gin.bind_parameter('trax.layers.SelfAttention.chunk_len', 20)

        model = ft.RelformerLM(vocab_size,
                               shorten_factor=shorten_factor,
                               n_rel_layers=n_rel_layers,
                               vanilla_layers=vanilla_layers,
                               d_model=d_model,
                               d_ff=d_model,
                               n_heads=n_heads,
                               vanilla_attn_type=attention_type,
                               rel_chunk_len=rel_chunk_len,
                               vanilla_chunk_len=vanilla_chunk_len,
                               max_len=n_len_eval,
                               mode='predict')

        model.init(shapes.signature(token))

        # Repeated single-token calls each produce one step of logits.
        for _ in range(5):
            step_logits = model(token)
            self.assertEqual(step_logits.shape,
                             (batch_size, 1, vocab_size))
        # Reset gin so the binding does not leak into other tests.
        gin.clear_config()