def test_terraformer_doubling(self):
  vocab_size = 2
  max_len = 2

  model = terraformer.ConfigurableTerraformer(
      vocab_size,
      d_model=8,
      d_ff=16,
      n_encoder_layers=1,
      n_decoder_layers=6,
      n_heads=2,
      dropout=0.05,
      max_len=max_len,
      pos_type=None,
      half_before_layer=2,
      double_after_layer=2,
      encoder_attention_type=tl.Attention,
      encoder_decoder_attention_type=tl.CausalAttention,
      mode='train',
  )

  x = [np.ones((1, max_len)).astype(np.int32),
       np.ones((1, max_len)).astype(np.int32)]
  model.init(shapes.signature(x))

  logits, dec_toks = model(x)
  del dec_toks
  self.assertEqual(logits.shape, (1, max_len, vocab_size))
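
# Illustration: the width schedule exercised by `half_before_layer` /
# `double_after_layer` above. This is a hedged sketch of assumed semantics
# (halve d_model before layer 2, double it from layer 2 on); the test itself
# only checks the output logits shape, so the exact boundary handling here is
# an assumption, not the library's internals.
def _sketch_width_schedule(self, d_model=8, n_layers=6,
                           half_before=2, double_after=2):
  widths = []
  for layer_idx in range(n_layers):
    if layer_idx < half_before:
      widths.append(d_model // 2)  # narrow early layers
    elif layer_idx >= double_after:
      widths.append(d_model * 2)   # widened late layers
    else:
      widths.append(d_model)
  return widths  # [4, 4, 16, 16, 16, 16] for the settings in the test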

def test_run_reversible_same_as_default_terraformer(self):
  """Runs the reversible trainer and checks results match the default."""
  inputs_batch = np.arange(8).reshape((2, 4)) + 1
  targets_batch = 2 * inputs_batch
  labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
  int_sig = shapes.ShapeDtype((2, 4), dtype=np.int32)
  input_sig = (int_sig, int_sig, int_sig)
  # We want to test rng propagation too, so the model includes dropout
  # layers; the rate is 0.0 so both trainers produce identical weights.
  model = terraformer.ConfigurableTerraformer(
      20, d_model=8, d_ff=32, n_heads=1, dropout=0.0,
      n_encoder_layers=2, n_decoder_layers=2, ff_sparsity=(4, 8, 0.0, 1.0),
      pos_type=None, reversible_encoder=True)
  loss = tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss())
  optimizer_fn = optimizers.Adafactor
  blocks, loss_layer = optimizers.trainer.extract_reversible_blocks(
      [model, loss], loss_chunk_size=4)
  blocks_serial = [(tl.Serial(std), rev) for (std, rev) in blocks]
  model_with_loss = tl.Serial(model, loss)
  rng_init = fastmath.random.get_prng(12)
  model_with_loss.init(input_sig, rng=rng_init)

  # Make 3 steps with the original trainer.
  optimizer = optimizer_fn()
  optimizer.tree_init(model_with_loss.weights)
  trainer = optimizers.Trainer(model_with_loss, optimizer)
  rng_step1 = fastmath.random.get_prng(7)
  rng_step2 = fastmath.random.get_prng(8)
  rng_step3 = fastmath.random.get_prng(9)
  trainer.one_step(labeled_batch, rng_step1)
  trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02)
  trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03)
  first_weights = blocks_serial[0][0].weights
  first_rev_weights = blocks[0][1][0].weights
  loss_weights = loss_layer.weights

  # Now make 3 steps with the reversible trainer.
  model_with_loss.init(input_sig, rng=rng_init)
  trainer = optimizers.ReversibleSerialTrainer(blocks, loss_layer,
                                               optimizer_fn)
  trainer.one_step(labeled_batch, rng_step1)
  trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02)
  trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03)

  # Check that weights end up the same.
  self._assert_all_equal(loss_weights, loss_layer.weights)
  self._assert_all_equal(first_rev_weights, blocks[0][1][0].weights)
  self._assert_all_equal(first_weights, blocks_serial[0][0].weights)
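
# Illustration: what `ReversibleSerialTrainer` consumes. A minimal sketch
# using only the `extract_reversible_blocks` call exercised above: each block
# pairs a list of standard (non-reversible) layers with the run of reversible
# layers that follows them, and the trainer recomputes activations from the
# reversible layers on the backward pass instead of storing them.
def _sketch_reversible_block_structure(self, model, loss):
  blocks, loss_layer = optimizers.trainer.extract_reversible_blocks(
      [model, loss], loss_chunk_size=4)
  # E.g. blocks[0][1][0] in the test above is the first reversible layer of
  # the first block.
  structure = [(len(std_layers), len(rev_layers))
               for std_layers, rev_layers in blocks]
  return structure, loss_layer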

def test_run_sharded_terraformer(self):
  """Runs Terraformer with sharded weights (only on 2+-device systems)."""
  if fastmath.local_device_count() == 1:
    return
  base.N_WEIGHTS_SHARDS = fastmath.local_device_count()
  inputs_batch = np.arange(8).reshape((2, 4)) + 1
  targets_batch = 2 * inputs_batch
  labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
  int_sig = shapes.ShapeDtype((2, 4), dtype=np.int32)
  input_sig = (int_sig, int_sig, int_sig)
  # We want to test rng propagation too, so the model includes dropout
  # layers; the rate is 0.0 so the sharded run stays deterministic.
  model = terraformer.ConfigurableTerraformer(
      20, d_model=8, d_ff=32, n_heads=1, dropout=0.0,
      n_encoder_layers=2, n_decoder_layers=2, ff_sparsity=(4, 8, 0.0, 1.0),
      encoder_attention_type=tl.Attention,
      encoder_decoder_attention_type=tl.CausalAttention,
      pos_type=None, reversible_encoder=True)
  loss = tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss())
  model_with_loss = tl.Serial(model, loss)
  rng_init = fastmath.random.get_prng(12)
  model_with_loss.init(input_sig, rng=rng_init)

  # Make a step with the trainer.
  optimizer = optimizers.Adafactor(0.01)
  split_w = fastmath.nested_map(
      lambda x: x[0],
      tl.shard(model_with_loss.weights, base.N_WEIGHTS_SHARDS))
  optimizer.tree_init(split_w)
  trainer = optimizers.Trainer(model_with_loss, optimizer)
  rng_step1 = fastmath.random.get_prng(7)
  trainer.one_step(labeled_batch, rng_step1)
  # Reset the number of shards back to the default.
  base.N_WEIGHTS_SHARDS = 1
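
# Illustration: a more defensive variant of the global shard toggling above.
# A minimal sketch using only names from the test; try/finally restores
# base.N_WEIGHTS_SHARDS even if the training step raises, so later tests are
# unaffected.
def _sketch_sharded_step_with_cleanup(self, trainer, labeled_batch, rng):
  old_shards = base.N_WEIGHTS_SHARDS
  base.N_WEIGHTS_SHARDS = fastmath.local_device_count()
  try:
    trainer.one_step(labeled_batch, rng)
  finally:
    base.N_WEIGHTS_SHARDS = old_shards  # Restore the default.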

def test_terraformer_quick(self, backend, encoder_attention_type, preembed):
  with fastmath.use_backend(backend):
    vocab_size = 2
    input_vocab_size = None if preembed else vocab_size
    output_vocab_size = vocab_size if preembed else None
    max_len = 2

    model = terraformer.ConfigurableTerraformer(
        input_vocab_size,
        d_model=4,
        d_ff=4,
        n_encoder_layers=1,
        n_decoder_layers=1,
        n_heads=2,
        dropout=0.05,
        max_len=max_len,
        pos_type=None,
        ff_activation=tl.Relu,
        ff_use_sru=0,
        ff_chunk_size=2,
        mode='train',
        output_vocab_size=output_vocab_size,
        encoder_attention_type=encoder_attention_type,
    )

    if preembed:
      # Pre-embedded inputs: a float feature vector per position plus a
      # boolean padding mask.
      model_inputs = [
          np.ones((1, max_len, 3)).astype(np.float32),
          np.ones((1, max_len)).astype(np.bool_)
      ]
    else:
      model_inputs = [np.ones((1, max_len)).astype(np.int32)]
    x = model_inputs + [np.ones((1, max_len)).astype(np.int32)]
    model.init(shapes.signature(x))

    logits, dec_toks = model(x)
    del dec_toks
    self.assertEqual(logits.shape, (1, max_len, vocab_size))
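
# Illustration: the two input signatures exercised by test_terraformer_quick.
# A minimal sketch; the shapes mirror the arrays built in the test above.
def _sketch_quick_input_signatures(self, max_len=2):
  tokens = shapes.ShapeDtype((1, max_len), dtype=np.int32)
  vectors = shapes.ShapeDtype((1, max_len, 3), dtype=np.float32)
  mask = shapes.ShapeDtype((1, max_len), dtype=np.bool_)
  preembed_sig = (vectors, mask, tokens)  # feature vectors + padding mask
  token_sig = (tokens, tokens)            # token ids for encoder and decoder
  return preembed_sig, token_sig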

def test_terraformer_one_step_long(self):
  d_model = 1024
  vocab_size = 14041
  max_len = 16384
  pos_axial = (128, 128)  # Should multiply to max_len.
  pos_d_axial_embs = (512, 512)  # Should sum to d_model.

  assert operator.mul(*pos_axial) == max_len
  assert sum(pos_d_axial_embs) == d_model

  d_ff = 4096
  n_heads = 8
  d_attn = d_model // n_heads

  n_buckets = 128
  encoder_chunk_len = (2 * max_len) // n_buckets  # 256
  decoder_chunk_len = 2 * encoder_chunk_len  # 512
  encoder_n_chunks_after = 1  # The encoder is not causal, so allow look-ahead.

  lsh_self_attention = functools.partial(self._lsh_self_attention_fn(),
                                         n_buckets=n_buckets)

  encoder_lsh_self_attention = functools.partial(
      lsh_self_attention,
      n_chunks_after=encoder_n_chunks_after,
      chunk_len=encoder_chunk_len)

  decoder_lsh_self_attention = functools.partial(
      lsh_self_attention,
      n_chunks_after=0,
      chunk_len=decoder_chunk_len)

  model = terraformer.ConfigurableTerraformer(
      vocab_size,
      d_model=d_model,
      d_ff=d_ff,
      d_attention_key=d_attn,
      d_attention_value=d_attn,
      n_encoder_layers=1,
      n_decoder_layers=1,
      n_heads=n_heads,
      dropout=0.05,
      max_len=max_len,
      encoder_attention_type=encoder_lsh_self_attention,
      encoder_decoder_attention_type=decoder_lsh_self_attention,
      pos_axial_shape=pos_axial,
      pos_d_axial_embs=pos_d_axial_embs,
      ff_activation=tl.Relu,
      ff_use_sru=0,
      mode='train',
  )

  def random_sentence():
    return np.random.randint(low=1, high=vocab_size - 1, size=(1, max_len),
                             dtype=np.int32)

  x = [random_sentence(), random_sentence()]
  weights, state = model.init(shapes.signature(x))

  @fastmath.jit
  def mock_training_step(x, weights, state, rng):
    def compute_mock_loss(weights):
      logits_and_dec_toks, new_state = model.pure_fn(x, weights, state, rng)
      # pure_fn returns [logits, decoder tokens]; only logits enter the loss.
      logits = logits_and_dec_toks[0]
      loss = fastmath.numpy.mean(logits[..., 0])
      return loss, (new_state, logits)
    gradients, (new_state, logits) = fastmath.grad(
        compute_mock_loss, has_aux=True)(weights)
    new_weights = fastmath.nested_map_multiarg(
        lambda w, g: w - 1e-4 * g, weights, gradients)
    return new_weights, new_state, logits

  weights, state, logits = mock_training_step(
      x, weights, state, fastmath.random.get_prng(0))
  self.assertEqual(logits.shape, (1, max_len, vocab_size))
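
# Illustration: the shape arithmetic behind pos_axial_shape and
# pos_d_axial_embs in the test above. The broadcast-and-concat layout is an
# assumption (a sketch of axial embeddings, not the library's internals); the
# test itself only asserts the product and sum constraints.
def _sketch_axial_embedding_shapes(self):
  rows, cols, d_row, d_col = 128, 128, 512, 512
  emb_rows = np.zeros((rows, 1, d_row))  # one embedding per row index
  emb_cols = np.zeros((1, cols, d_col))  # one embedding per column index
  full = np.concatenate(
      [np.broadcast_to(emb_rows, (rows, cols, d_row)),
       np.broadcast_to(emb_cols, (rows, cols, d_col))], axis=-1)
  # 128 * 128 positions at width 512 + 512, matching max_len and d_model.
  assert full.reshape(rows * cols, -1).shape == (16384, 1024)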

def test_terraformer_one_step(self):
  vocab_size = 32
  max_len = 256
  pos_axial = 16
  assert pos_axial * pos_axial == max_len

  chunk_len = 32
  # Since 2 * chunk_len * n_buckets should equal max_len.
  n_buckets = max_len // (2 * chunk_len)

  lsh_self_attention = functools.partial(self._lsh_self_attention_fn(),
                                         chunk_len=chunk_len,
                                         n_buckets=n_buckets)
  timebin_self_attention = self._timebin_self_attention_fn()

  model = terraformer.ConfigurableTerraformer(
      vocab_size,
      d_model=256,  # Must equal sum(pos_d_axial_embs) below.
      d_ff=64,
      d_attention_key=64,
      d_attention_value=64,
      n_encoder_layers=2,
      n_decoder_layers=2,
      n_heads=2,
      dropout=0.05,
      max_len=max_len,
      encoder_attention_type=lsh_self_attention,
      encoder_decoder_attention_type=[timebin_self_attention,
                                      lsh_self_attention],
      pos_axial_shape=(pos_axial, pos_axial),
      pos_d_axial_embs=(64, 192),
      ff_activation=tl.Relu,
      ff_use_sru=0,
      ff_chunk_size=64,
      ff_sparsity=8,
      mode='train',
  )

  x = [np.ones((1, max_len)).astype(np.int32),
       np.ones((1, max_len)).astype(np.int32)]
  weights, state = model.init(shapes.signature(x))

  @fastmath.jit
  def mock_training_step(x, weights, state, rng):
    def compute_mock_loss(weights):
      logits_and_dec_toks, new_state = model.pure_fn(x, weights, state, rng)
      # pure_fn returns [logits, decoder tokens]; only logits enter the loss.
      logits = logits_and_dec_toks[0]
      loss = fastmath.numpy.mean(logits[..., 0])
      return loss, (new_state, logits)
    gradients, (new_state, logits) = fastmath.grad(
        compute_mock_loss, has_aux=True)(weights)
    new_weights = fastmath.nested_map_multiarg(
        lambda w, g: w - 1e-4 * g, weights, gradients)
    return new_weights, new_state, logits

  weights, state, logits = mock_training_step(
      x, weights, state, fastmath.random.get_prng(0))
  self.assertEqual(logits.shape, (1, max_len, vocab_size))
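
# Illustration: the LSH chunking arithmetic behind the `n_buckets` comment in
# the test above. A minimal sketch of the stated size constraint only; actual
# bucket assignment happens via random hashing inside the LSH attention layer.
def _sketch_lsh_chunking(self):
  max_len, chunk_len = 256, 32
  n_buckets = max_len // (2 * chunk_len)  # 4 buckets
  # Each bucket averages max_len / n_buckets = 64 positions, i.e. exactly two
  # chunks of 32, which keeps per-bucket attention cost bounded.
  assert 2 * chunk_len * n_buckets == max_len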