def test_reformer_rng_consistency(self):
  with math.use_backend('jax'):
    vocab_size = 16
    batch_size = 1
    input_sd = ShapeDtype((batch_size, 8), np.int32)
    input_signature = (input_sd, input_sd)
    model = reformer.ReformerLM(
        vocab_size, d_model=32, d_ff=64, d_attention_key=16,
        d_attention_value=16, n_layers=1, n_heads=2, max_len=16,
        n_chunks=2, n_attention_chunks=1, mode='train',
        attention_type=PoisonOnRNGMismatchAttention)

    rng = math.random.get_prng(0)
    weights, state = model.init(input_signature)

    def dummy_loss_fn(weights):
      inputs = (np.zeros(input_sd.shape, dtype=np.int32),) * 2
      output = model(inputs, weights=weights, state=state, rng=rng)
      dummy_loss = math.numpy.sum(output[0])
      return dummy_loss

    grad_fn = math.grad(dummy_loss_fn)
    grads = grad_fn(weights)
    # PoisonOnRNGMismatchAttention uses NaNs to signal an rng mismatch.
    for grad in jax.tree_util.tree_leaves(grads):
      assert onp.all(onp.isfinite(grad))
def test_reformer_lm_memory(self):
  lsh_self_attention = functools.partial(
      tl.LSHSelfAttention,
      attention_dropout=0.0,
      chunk_len=64,
      n_buckets=[128, 128],
      n_chunks_after=0,
      n_chunks_before=1,
      n_hashes=1,
      n_parallel_heads=1,
      predict_drop_len=128,
      predict_mem_len=1024,
  )
  timebin_self_attention = functools.partial(
      tl.SelfAttention,
      attention_dropout=0.05,
      chunk_len=64,
      n_chunks_before=1,
      n_parallel_heads=1,
  )

  model = reformer.ReformerLM(
      vocab_size=256,
      d_model=256,
      d_ff=512,
      d_attention_key=64,
      d_attention_value=64,
      n_layers=6,
      n_heads=2,
      dropout=0.05,
      max_len=1048576,
      attention_type=[timebin_self_attention, lsh_self_attention],
      axial_pos_shape=(1024, 1024),
      d_axial_pos_embs=(64, 192),
      ff_activation=tl.Relu,
      ff_use_sru=0,
      ff_chunk_size=131072,
      mode='train',
  )
  x = np.ones((1, 1048576)).astype(np.int32)
  weights, state = model.init(shapes.signature(x))

  @jax.jit
  def mock_training_step(x, weights, state, rng):
    def compute_mock_loss(weights):
      logits, new_state = model.pure_fn(x, weights, state, rng)
      loss = jnp.mean(logits[..., 0])
      return loss, (new_state, logits)
    gradients, (new_state, logits) = jax.grad(
        compute_mock_loss, has_aux=True)(weights)
    new_weights = fastmath.nested_map_multiarg(
        lambda w, g: w - 1e-4 * g, weights, gradients)
    return new_weights, new_state, logits

  weights, state, logits = mock_training_step(
      x, weights, state, jax.random.PRNGKey(0))
  self.assertEqual(logits.shape, (1, 1048576, 256))
def test_reformer_lm_forward_shape(self):
  vocab_size = 16
  model = reformer.ReformerLM(
      vocab_size, d_model=32, d_ff=64, d_attention_key=16,
      d_attention_value=16, n_layers=1, n_heads=2, max_len=16)
  xs = [np.ones((1, 8)).astype(np.int32),
        np.ones((1, 8)).astype(np.int32)]
  _, _ = model.init(shapes.signature(xs))
  ys = model(xs)
  self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)])
def test_reformer_lm_forward_shape_agreement(self):
  """Runs the ReformerLM forward and checks output shape via check_shape_agreement."""
  # Renamed from test_reformer_lm_forward_shape to avoid shadowing the test
  # of the same name above.
  vocab_size = 16
  input_sd = ShapeDtype((1, 8), np.int32)
  input_signature = (input_sd, input_sd)
  model = reformer.ReformerLM(
      vocab_size, d_model=32, d_ff=64, d_attention_key=16,
      d_attention_value=16, n_layers=1, n_heads=2, max_len=16)
  final_shape = tl.check_shape_agreement(model, input_signature)
  self.assertEqual(((1, 8, 16), (1, 8)), final_shape)
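# The two tests below call self._lsh_self_attention_fn() and
# self._timebin_self_attention_fn(), which are not defined in this excerpt.
# A minimal sketch of what these helpers could look like, assuming they wrap
# the same functools.partial configurations used inline in
# test_reformer_lm_memory above; the hyperparameters (notably n_buckets) are
# illustrative assumptions, not the exact values used elsewhere in the suite.

def _lsh_self_attention_fn(self):
  # Hypothetical helper: LSH self-attention configured for chunked sequences.
  return functools.partial(
      tl.LSHSelfAttention,
      attention_dropout=0.0,
      chunk_len=64,
      n_buckets=[32, 32],  # assumed bucket split for a 65536-token sequence
      n_chunks_after=0,
      n_chunks_before=1,
      n_hashes=1,
      n_parallel_heads=1,
      predict_drop_len=128,
      predict_mem_len=1024,
  )

def _timebin_self_attention_fn(self, use_reference_code=False):
  # Hypothetical helper: time-binned (local) self-attention; the kwarg lets
  # test_reformer_lm_forward_shape_tf request the reference implementation.
  return functools.partial(
      tl.SelfAttention,
      attention_dropout=0.05,
      chunk_len=64,
      n_chunks_before=1,
      n_parallel_heads=1,
      use_reference_code=use_reference_code,
  )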
def test_reformer_lm_lsh(self):
  lsh_self_attention = self._lsh_self_attention_fn()
  timebin_self_attention = self._timebin_self_attention_fn()

  model = reformer.ReformerLM(
      vocab_size=256,
      d_model=256,
      d_ff=512,
      d_attention_key=64,
      d_attention_value=64,
      n_layers=2,
      n_heads=2,
      dropout=0.05,
      max_len=65536,
      attention_type=[timebin_self_attention, lsh_self_attention],
      pos_axial_shape=(256, 256),
      pos_d_axial_embs=(64, 192),
      ff_activation=tl.Relu,
      ff_use_sru=0,
      ff_chunk_size=8192,
      mode='train',
  )
  x = np.ones((1, 65536)).astype(np.int32)
  weights, state = model.init(shapes.signature(x))

  @fastmath.jit
  def mock_training_step(x, weights, state, rng):
    def compute_mock_loss(weights):
      logits, new_state = model.pure_fn(x, weights, state, rng)
      loss = fastmath.numpy.mean(logits[..., 0])
      return loss, (new_state, logits)
    gradients, (new_state, logits) = fastmath.grad(
        compute_mock_loss, has_aux=True)(weights)
    new_weights = fastmath.nested_map_multiarg(
        lambda w, g: w - 1e-4 * g, weights, gradients)
    return new_weights, new_state, logits

  weights, state, logits = mock_training_step(
      x, weights, state, fastmath.random.get_prng(0))
  self.assertEqual(logits.shape, (1, 65536, 256))
def test_reformer_lm_forward_shape_tf(self):
  with math.use_backend('tf'):
    vocab_size = 16
    timebin_attn = self._timebin_self_attention_fn(use_reference_code=True)
    model = reformer.ReformerLM(
        vocab_size, d_model=32, d_ff=64, d_attention_key=16,
        d_attention_value=16, n_layers=1, n_heads=2, max_len=64,
        attention_type=timebin_attn)
    xs = [np.ones((1, 64)).astype(np.int32),
          np.ones((1, 64)).astype(np.int32)]
    _, _ = model.init(shapes.signature(xs))
    ys = model(xs)
    self.assertEqual([y.shape for y in ys], [(1, 64, 16), (1, 64)])
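# Standard test-runner footer, assuming this module follows the absltest
# convention used by other Trax test files and imports
# `from absl.testing import absltest` at the top.

if __name__ == '__main__':
  absltest.main()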