Example #1
    def test_reformer_rng_consistency(self):
        with math.use_backend('jax'):
            vocab_size = 16
            batch_size = 1
            input_sd = ShapeDtype((batch_size, 8), np.int32)
            input_signature = (input_sd, input_sd)
            model = reformer.ReformerLM(
                vocab_size,
                d_model=32,
                d_ff=64,
                d_attention_key=16,
                d_attention_value=16,
                n_layers=1,
                n_heads=2,
                max_len=16,
                n_chunks=2,
                n_attention_chunks=1,
                mode='train',
                attention_type=PoisonOnRNGMismatchAttention)

            rng = math.random.get_prng(0)
            weights, state = model.init(input_signature)

            def dummy_loss_fn(weights):
                inputs = (np.zeros(input_sd.shape, dtype=np.int32), ) * 2
                output = model(inputs, weights=weights, state=state, rng=rng)
                dummy_loss = math.numpy.sum(output[0])
                return dummy_loss

            grad_fn = math.grad(dummy_loss_fn)
            grads = grad_fn(weights)
            # PoisonOnRNGMismatchAttention uses NaNs to signal an rng mismatch.
            for grad in jax.tree_util.tree_leaves(grads):
                assert onp.all(onp.isfinite(grad))
Example #2
    def test_reformer_lm_memory(self):
        lsh_self_attention = functools.partial(
            tl.LSHSelfAttention,
            attention_dropout=0.0,
            chunk_len=64,
            n_buckets=[128, 128],
            n_chunks_after=0,
            n_chunks_before=1,
            n_hashes=1,
            n_parallel_heads=1,
            predict_drop_len=128,
            predict_mem_len=1024,
        )
        timebin_self_attention = functools.partial(
            tl.SelfAttention,
            attention_dropout=0.05,
            chunk_len=64,
            n_chunks_before=1,
            n_parallel_heads=1,
        )

        model = reformer.ReformerLM(
            vocab_size=256,
            d_model=256,
            d_ff=512,
            d_attention_key=64,
            d_attention_value=64,
            n_layers=6,
            n_heads=2,
            dropout=0.05,
            max_len=1048576,
            attention_type=[timebin_self_attention, lsh_self_attention],
            axial_pos_shape=(1024, 1024),
            d_axial_pos_embs=(64, 192),
            ff_activation=tl.Relu,
            ff_use_sru=0,
            ff_chunk_size=131072,
            mode='train',
        )
        x = np.ones((1, 1048576)).astype(np.int32)
        weights, state = model.init(shapes.signature(x))

        @jax.jit
        def mock_training_step(x, weights, state, rng):
            def compute_mock_loss(weights):
                logits, new_state = model.pure_fn(x, weights, state, rng)
                loss = jnp.mean(logits[..., 0])
                return loss, (new_state, logits)

            gradients, (new_state, logits) = jax.grad(compute_mock_loss,
                                                      has_aux=True)(weights)
            new_weights = fastmath.nested_map_multiarg(
                lambda w, g: w - 1e-4 * g, weights, gradients)
            return new_weights, new_state, logits

        weights, state, logits = mock_training_step(x, weights, state,
                                                    jax.random.PRNGKey(0))
        self.assertEqual(logits.shape, (1, 1048576, 256))
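
A quick sanity check on the axial position-embedding arguments used above. The values are self-consistent: the factors of axial_pos_shape multiply out to max_len and the entries of d_axial_pos_embs sum to d_model (an observation about these particular numbers, assuming the usual Reformer convention, not something the snippet itself asserts):

    # Hypothetical standalone check mirroring the arguments passed above.
    axial_pos_shape = (1024, 1024)
    d_axial_pos_embs = (64, 192)
    assert axial_pos_shape[0] * axial_pos_shape[1] == 1048576  # max_len
    assert sum(d_axial_pos_embs) == 256                        # d_model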
Example #3
  def test_reformer_lm_forward_shape(self):
    vocab_size = 16
    model = reformer.ReformerLM(
        vocab_size, d_model=32, d_ff=64, d_attention_key=16,
        d_attention_value=16, n_layers=1, n_heads=2, max_len=16)
    xs = [np.ones((1, 8)).astype(np.int32),
          np.ones((1, 8)).astype(np.int32)]
    _, _ = model.init(shapes.signature(xs))
    ys = model(xs)
    self.assertEqual([y.shape for y in ys], [(1, 8, 16), (1, 8)])
Example #4
  def test_reformer_lm_forward_shape(self):
    """Run the ReformerLM forward and check output shape."""
    vocab_size = 16
    input_sd = ShapeDtype((1, 8), np.int32)
    input_signature = (input_sd, input_sd)
    model = reformer.ReformerLM(vocab_size,
                                d_model=32,
                                d_ff=64,
                                d_attention_key=16,
                                d_attention_value=16,
                                n_layers=1,
                                n_heads=2,
                                max_len=16)
    final_shape = tl.check_shape_agreement(model, input_signature)
    self.assertEqual(((1, 8, 16), (1, 8)), final_shape)
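
Examples #3 and #4 build the (input, target) signature in two different ways: from concrete arrays via shapes.signature, and declared up front via ShapeDtype. A minimal sketch of the equivalence, assuming the same imports as the surrounding tests (numpy as np, the trax shapes module, and ShapeDtype):

    # Hedged sketch; either form can be passed to model.init(...).
    x = np.ones((1, 8)).astype(np.int32)
    sig_from_arrays = shapes.signature((x, x))          # derived from arrays
    sig_declared = (ShapeDtype((1, 8), np.int32),) * 2  # declared directly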
Example #5
    def test_reformer_lm_lsh(self):
        lsh_self_attention = self._lsh_self_attention_fn()
        timebin_self_attention = self._timebin_self_attention_fn()

        model = reformer.ReformerLM(
            vocab_size=256,
            d_model=256,
            d_ff=512,
            d_attention_key=64,
            d_attention_value=64,
            n_layers=2,
            n_heads=2,
            dropout=0.05,
            max_len=65536,
            attention_type=[timebin_self_attention, lsh_self_attention],
            pos_axial_shape=(256, 256),
            pos_d_axial_embs=(64, 192),
            ff_activation=tl.Relu,
            ff_use_sru=0,
            ff_chunk_size=8192,
            mode='train',
        )
        x = np.ones((1, 65536)).astype(np.int32)
        weights, state = model.init(shapes.signature(x))

        @fastmath.jit
        def mock_training_step(x, weights, state, rng):
            def compute_mock_loss(weights):
                logits, new_state = model.pure_fn(x, weights, state, rng)
                loss = fastmath.numpy.mean(logits[..., 0])
                return loss, (new_state, logits)

            gradients, (new_state,
                        logits) = fastmath.grad(compute_mock_loss,
                                                has_aux=True)(weights)
            new_weights = fastmath.nested_map_multiarg(
                lambda w, g: w - 1e-4 * g, weights, gradients)
            return new_weights, new_state, logits

        weights, state, logits = mock_training_step(
            x, weights, state, fastmath.random.get_prng(0))
        self.assertEqual(logits.shape, (1, 65536, 256))
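
Example #5 is the same jit/grad pattern as Example #2, expressed through trax's fastmath wrappers (fastmath.jit, fastmath.grad and fastmath.random.get_prng in place of jax.jit, jax.grad and jax.random.PRNGKey). As a hedged extension, not part of the original test, the returned weights and state could be fed back in to run a few more steps:

    # Sketch only: iterate the mock step, splitting the rng each time.
    rng = fastmath.random.get_prng(0)
    for _ in range(3):
        rng, step_rng = fastmath.random.split(rng, 2)
        weights, state, logits = mock_training_step(x, weights, state, step_rng)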
Example #6
  def test_reformer_lm_forward_shape_tf(self):
    with math.use_backend('tf'):
      vocab_size = 16
      timebin_attn = self._timebin_self_attention_fn(
          use_reference_code=True)
      model = reformer.ReformerLM(vocab_size,
                                  d_model=32,
                                  d_ff=64,
                                  d_attention_key=16,
                                  d_attention_value=16,
                                  n_layers=1,
                                  n_heads=2,
                                  max_len=64,
                                  attention_type=timebin_attn)
      xs = [
          np.ones((1, 64)).astype(np.int32),
          np.ones((1, 64)).astype(np.int32)
      ]
      _, _ = model.init(shapes.signature(xs))
      ys = model(xs)
      self.assertEqual([y.shape for y in ys], [(1, 64, 16), (1, 64)])
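
Examples #5 and #6 call helper methods _lsh_self_attention_fn() and _timebin_self_attention_fn() that are not included in these excerpts. A plausible definition, modeled on the functools.partial constructions in Example #2 (the exact chunk, bucket and dropout values here are assumptions; Example #6 additionally forwards a use_reference_code flag):

    def _lsh_self_attention_fn(self):
        # Sketch of the missing helper, mirroring Example #2's LSH partial.
        return functools.partial(
            tl.LSHSelfAttention,
            attention_dropout=0.0,
            chunk_len=64,
            n_buckets=[128, 128],
            n_chunks_after=0,
            n_chunks_before=1,
            n_hashes=1,
            n_parallel_heads=1,
            predict_drop_len=128,
            predict_mem_len=1024,
        )

    def _timebin_self_attention_fn(self, use_reference_code=False):
        # Sketch of the missing helper, mirroring Example #2's time-bin partial.
        return functools.partial(
            tl.SelfAttention,
            attention_dropout=0.05,
            chunk_len=64,
            n_chunks_before=1,
            n_parallel_heads=1,
            use_reference_code=use_reference_code,
        )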