def test_funnel_transformer_lm_autoregressive_property(self):
    """Logits at position i must depend only on input tokens up to i.

    Two inputs that share a prefix must produce identical logits over that
    prefix, for both the chunked and the unchunked RelformerLM.
    """
    input_shape = (1, 12)
    d_model = 8
    vocab_size = 26
    rng_a = jax.random.PRNGKey(0)
    rng_b = jax.random.PRNGKey(1)

    def _fresh_logits(model_to_init, inputs):
        # Re-initialize from scratch every call so no cached weights leak
        # between evaluations of different inputs.
        model_to_init.init(shapes.signature(inputs), rng=rng_a,
                           use_cache=False)
        logits, *_ = model_to_init(inputs, rng=rng_a)
        return logits

    def _check_autoregressive(model):
        with fastmath.use_backend(fastmath.Backend.JAX):
            base = jax.random.randint(rng_a, input_shape, 0, vocab_size)
            base_logits = _fresh_logits(model, base)
            other = jax.random.randint(rng_b, input_shape, 0, vocab_size)
            for prefix in range(input_shape[1]):
                # Same prefix as `base`, suffix taken from `other`.
                mixed = np.concatenate(
                    (base[:, :prefix], other[:, prefix:]), axis=1)
                mixed_logits = _fresh_logits(model, mixed)
                self.assertEqual(mixed_logits.shape[0], input_shape[1])
                np.testing.assert_array_almost_equal(
                    base_logits[:prefix + 1], mixed_logits[:prefix + 1])

    chunked = ft.RelformerLM(
        vocab_size,
        shorten_factor=3,
        n_rel_layers=2,
        vanilla_layers=(1, 1),
        d_model=d_model,
        d_ff=4 * d_model,
        n_heads=2,
        vanilla_attn_type=tl.SelfAttention,
        rel_chunk_len=2,
        vanilla_chunk_len=4,
    )
    _check_autoregressive(chunked)

    unchunked = ft.RelformerLM(
        vocab_size,
        shorten_factor=3,
        n_rel_layers=2,
        vanilla_layers=(1, 1),
        d_model=d_model,
        d_ff=4 * d_model,
        n_heads=2,
        vanilla_attn_type=tl.SelfAttention,
        rel_chunk_len=None,
        vanilla_chunk_len=None,
    )
    _check_autoregressive(unchunked)
def test_funnel_transformer_lm_predict_eval_equal(self):
    """'predict' mode (token-by-token, cached) must match one 'eval' pass.

    Fix: the original bound `trax.layers.SelfAttention.chunk_len` through
    gin but never cleared it, leaking the binding into later tests; the
    sibling predict-mode tests call gin.clear_config(), so do the same here.
    """
    d_model = 8
    vocab_size = 4
    batch_size = 1
    n_len_eval = 21
    attention_type = tl.SelfAttention
    shorten_factor = 3
    n_rel_layers = 1
    vanilla_layers = (1, 1)
    n_heads = 2

    eval_funnel = ft.RelformerLM(
        vocab_size, shorten_factor=shorten_factor, n_rel_layers=n_rel_layers,
        vanilla_layers=vanilla_layers, d_model=d_model, d_ff=d_model,
        n_heads=n_heads, vanilla_attn_type=attention_type, mode='eval')

    input_funnel = jax.random.randint(
        key=jax.random.PRNGKey(0), minval=0, maxval=vocab_size,
        shape=(batch_size, n_len_eval)).astype(np.int32)
    _, _ = eval_funnel.init(shapes.signature(input_funnel),
                            rng=jax.random.PRNGKey(0), use_cache=False)
    y_eval = eval_funnel(input_funnel)
    self.assertEqual(y_eval.shape, (batch_size, n_len_eval, vocab_size))

    if attention_type == tl.SelfAttention:
        # Predict mode needs an explicit chunk length for SelfAttention.
        gin.bind_parameter('trax.layers.SelfAttention.chunk_len', n_len_eval)

    predict_funnel = ft.RelformerLM(
        vocab_size, shorten_factor=shorten_factor, n_rel_layers=n_rel_layers,
        vanilla_layers=vanilla_layers, d_model=d_model, d_ff=d_model,
        n_heads=n_heads, vanilla_attn_type=attention_type, mode='predict')

    # Shift right: prepend a zero token, drop the last input token, so that
    # step i of predict consumes the token whose logits eval produced at i.
    input_funnel = np.concatenate(
        [np.zeros((batch_size, 1)).astype(np.int32), input_funnel], axis=1)
    input_funnel = input_funnel[:, :-1]
    _, _ = predict_funnel.init(shapes.signature(input_funnel[:, 0:1]),
                               rng=jax.random.PRNGKey(0), use_cache=False)

    for i in range(n_len_eval):
        y = predict_funnel(input_funnel[:, i:i + 1])
        np.testing.assert_array_almost_equal(y, y_eval[:, i:i + 1, :],
                                             decimal=5)

    # Fix: reset gin state so the chunk_len binding cannot affect later tests.
    gin.clear_config()
def test_funnel_transformer_lm_forward_shape(self):
    """Both LM variants should emit (batch, length, vocab) logits."""
    d_model = 8
    vocab_size = 7
    inputs = np.ones((3, 6)).astype(np.int32)

    relformer = ft.RelformerLM(
        vocab_size, shorten_factor=3, n_rel_layers=1, vanilla_layers=(1, 1),
        d_model=d_model, d_ff=d_model, n_heads=2,
        vanilla_attn_type=tl.SelfAttention)
    _, _ = relformer.init(shapes.signature(inputs))
    relformer_out = relformer(inputs)
    self.assertEqual(relformer_out.shape, (3, 6, vocab_size))

    multi_stage = ft.FunnelTransformerLM(
        vocab_size, shorten_factors=(3, 2), n_funnel_blocks=(0, 0),
        vanilla_layers=(0, 0), d_model=d_model, d_ff=d_model, n_heads=2)
    _, _ = multi_stage.init(shapes.signature(inputs))
    multi_stage_out = multi_stage(inputs)
    self.assertEqual(multi_stage_out.shape, (3, 6, vocab_size))
def test_relformer_lm_forward_shape_chunked(self):
    """Forward-shape test for RelformerLM with and without chunking.

    Fix: this method was previously also named
    test_funnel_transformer_lm_forward_shape, duplicating another test in
    the class; the later definition silently shadowed the earlier one so
    only one of the two ever ran. Renamed so both are collected.
    """
    d_model = 16
    vocab_size = 7
    length = 48
    batch_size = 3
    x = np.ones((batch_size, length)).astype(np.int32)

    model_chunked = ft.RelformerLM(
        vocab_size, shorten_factor=3, n_rel_layers=3, vanilla_layers=(1, 1),
        d_model=d_model, d_ff=d_model, n_heads=2,
        vanilla_attn_type=tl.SelfAttention, rel_chunk_len=4,
        vanilla_chunk_len=2, max_len=48)
    _, _ = model_chunked.init(shapes.signature(x))
    y = model_chunked(x)
    self.assertEqual(y.shape, (batch_size, length, vocab_size))

    model_without_chunks = ft.RelformerLM(
        vocab_size, shorten_factor=3, n_rel_layers=3, vanilla_layers=(1, 1),
        d_model=d_model, d_ff=d_model, n_heads=2,
        vanilla_attn_type=tl.SelfAttention, max_len=48)
    _, _ = model_without_chunks.init(shapes.signature(x))
    y = model_without_chunks(x)
    self.assertEqual(y.shape, (batch_size, length, vocab_size))
def test_autoregressive_sample_relformerlm(self):
    """autoregressive_sample on a RelformerLM yields (batch, max_length) ids."""
    batch_size = 4
    max_length = 5
    lm = ft.RelformerLM(10, d_model=8, d_ff=16, n_rel_layers=1,
                        vanilla_layers=(1, 1), shorten_factor=3, n_heads=2,
                        mode='predict')
    lm.init(shapes.ShapeDtype((batch_size, 1), dtype=np.int32))
    # eos_id=-1 never matches a sampled id, so decoding runs to max_length.
    samples = decoding.autoregressive_sample(
        lm, batch_size=batch_size, eos_id=-1, max_length=max_length,
        accelerate=False)
    self.assertEqual(samples.shape, (batch_size, max_length))
def test_funnel_transformer_lm_forward_shape_eval(self):
    """Eval-mode RelformerLM keeps the (batch, length, vocab) logits shape."""
    d_model = 8
    vocab_size = 7
    batch_size = 1
    tokens = np.zeros((batch_size, 6)).astype(np.int32)

    model = ft.RelformerLM(
        vocab_size, shorten_factor=3, n_rel_layers=1, vanilla_layers=(1, 1),
        d_model=d_model, d_ff=d_model, n_heads=2,
        vanilla_attn_type=tl.SelfAttention, mode='eval')
    _, _ = model.init(shapes.signature(tokens))
    logits = model(tokens)
    self.assertEqual(logits.shape, (batch_size, 6, vocab_size))
def test_funnel_transformer_lm_forward_shape_predict(self):
    """Predict mode consumes one token per call and emits one-step logits."""
    d_model = 8
    vocab_size = 7
    batch_size = 1
    step_input = np.ones((batch_size, 1)).astype(np.int32)

    # SelfAttention in predict mode requires a chunk length.
    gin.bind_parameter('trax.layers.SelfAttention.chunk_len', 20)
    model = ft.RelformerLM(
        vocab_size, shorten_factor=3, n_rel_layers=1, vanilla_layers=(1, 1),
        d_model=d_model, d_ff=d_model, n_heads=2,
        vanilla_attn_type=tl.SelfAttention, mode='predict')
    _, _ = model.init(shapes.signature(step_input))

    for _ in range(5):
        step_logits = model(step_input)
        self.assertEqual(step_logits.shape, (batch_size, 1, vocab_size))
    gin.clear_config()
def test_funnel_transformer_lm_forward_shape_predict_chunked(self):
    """Predict-mode forward-shape test with chunked relative attention.

    Fix: this method was previously also named
    test_funnel_transformer_lm_forward_shape_predict, duplicating another
    test in the class; the later definition silently shadowed the earlier
    one so only one of the two ever ran. Renamed so both are collected.
    """
    d_model = 8
    vocab_size = 4
    batch_size = 1
    n_len_eval = 42
    attention_type = tl.SelfAttention
    shorten_factor = 3
    n_rel_layers = 2
    vanilla_layers = (1, 1)
    n_heads = 2
    rel_chunk_len, vanilla_chunk_len = 2, 6

    x = np.ones((batch_size, 1)).astype(np.int32)
    # SelfAttention in predict mode requires a chunk length.
    gin.bind_parameter('trax.layers.SelfAttention.chunk_len', 20)

    predict_funnel = ft.RelformerLM(
        vocab_size, shorten_factor=shorten_factor, n_rel_layers=n_rel_layers,
        vanilla_layers=vanilla_layers, d_model=d_model, d_ff=d_model,
        n_heads=n_heads, vanilla_attn_type=attention_type,
        rel_chunk_len=rel_chunk_len, vanilla_chunk_len=vanilla_chunk_len,
        max_len=n_len_eval, mode='predict')
    _, _ = predict_funnel.init(shapes.signature(x))

    for _ in range(5):
        y = predict_funnel(x)
        self.assertEqual(y.shape, (batch_size, 1, vocab_size))
    gin.clear_config()