Example #1
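Shape-only smoke test: a two-layer ReformerLM mixing time-bin and LSH self-attention is initialized with random weights, and only the shape of the sampled batch is checked.

All five snippets are methods of the same test class and depend on module-level imports plus a _TESTDATA checkpoint directory defined outside the excerpts. A minimal sketch of that shared context, assuming Trax's usual module layout (the _TESTDATA path in particular is a guess):

import functools
import os
import unittest

import numpy as np

from trax import layers as tl
from trax import models
from trax import shapes
from trax.supervised import decoding

# Assumed location of the copy-task checkpoints used in the quality tests.
_TESTDATA = os.path.join(os.path.dirname(__file__), 'testdata')

The _timebin_self_attention_fn and _lsh_self_attention_fn helpers are also defined on the test class but not shown; hypothetical stand-ins, built only from constructor arguments that appear in the later examples:

    def _timebin_self_attention_fn(self):
        # Hypothetical: time-bin ("chunked") self-attention is plain
        # tl.SelfAttention restricted to fixed-size chunks.
        return functools.partial(tl.SelfAttention, chunk_len=16)

    def _lsh_self_attention_fn(self):
        # Hypothetical: LSH self-attention with prediction-friendly
        # buffer sizes, mirroring the LSH examples below.
        return functools.partial(tl.LSHSelfAttention,
                                 chunk_len=16,
                                 n_hashes=2,
                                 n_buckets=[32, 32],
                                 predict_drop_len=128,
                                 predict_mem_len=128)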
    def test_autoregressive_sample_reformerlm(self):
        lsh_self_attention = self._lsh_self_attention_fn()
        timebin_self_attention = self._timebin_self_attention_fn()

        model = models.ReformerLM(
            vocab_size=256,
            d_model=256,
            d_ff=512,
            d_attention_key=128,
            d_attention_value=128,
            n_layers=2,
            n_heads=2,
            dropout=0.05,
            max_len=65536,
            # One attention type per layer: time-bin in layer 0, LSH in layer 1.
            attention_type=[timebin_self_attention, lsh_self_attention],
            # Axial factorization: 256 * 256 = max_len, and the two embedding
            # depths (128 + 128) sum to d_model.
            pos_axial_shape=(256, 256),
            pos_d_axial_embs=(128, 128),
            ff_activation=tl.Relu,
            ff_use_sru=0,
            mode='predict',
        )
        model.init(shapes.ShapeDtype((1, 1), dtype=np.int32))
        # eos_id=-1 never matches a sampled token, so decoding always runs
        # for the full max_length.
        s1 = decoding.autoregressive_sample(model,
                                            batch_size=1,
                                            eos_id=-1,
                                            max_length=10)
        self.assertEqual(s1.shape[0], 1)
        self.assertEqual(s1.shape[1], 10)
Example #2
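Quality test for time-bin attention: a small ReformerLM restored from the reformerlm_copy.pkl.gz checkpoint (trained on a sequence-copy task) must reproduce the 0-delimited payload exactly under greedy decoding.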
    def test_autoregressive_sample_reformerlm_quality(self):
        timebin_self_attention = self._timebin_self_attention_fn()
        pred_model = models.ReformerLM(
            d_model=64, d_ff=128, dropout=0.05, max_len=256, n_heads=2,
            attention_type=timebin_self_attention,
            n_layers=2, vocab_size=13, mode='predict')
        shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
        model_path = os.path.join(_TESTDATA, 'reformerlm_copy.pkl.gz')
        pred_model.init_from_file(model_path, weights_only=True,
                                  input_signature=(shape11, shape11))
        inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32)
        # temperature=0.0 is greedy (argmax) decoding, so the trained copy
        # model must reproduce the payload deterministically.
        s = decoding.autoregressive_sample(pred_model, inputs,
                                           max_length=6, temperature=0.0)
        self.assertEqual(str(s[0]), '[3 7 5 3 2 4]')
Example #3
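An older revision of the LSH quality test: the method raises unittest.SkipTest up front (pending JAX PR #4008), leaving its body unreachable, and it uses the earlier axial_pos_shape / d_axial_pos_embs argument names where the next example uses pos_type / pos_d_axial_embs.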
    def test_autoregressive_sample_reformerlm_lsh_quality(self):
        # After changes to some fastmath.custom_vjp functions (made so that we could
        # land JAX PR #4008), this test started failing, with an assertion error on
        # efficient_attention.py:1382 (q_len == 1).
        # TODO(mattjj,lukaszkaiser): revive this test after landing #4008
        raise unittest.SkipTest('temporarily skipping test so that we can '
                                'land https://github.com/google/jax/pull/4008')
        # pylint: disable=unreachable
        max_len = 32

        pred_model = models.ReformerLM(
            mode='predict',
            d_model=256,
            d_ff=512,
            dropout=0.05,
            max_len=2 * max_len,
            n_heads=4,
            n_layers=3,
            ff_use_sru=0,
            d_attention_key=64,
            d_attention_value=64,
            attention_type=functools.partial(tl.LSHSelfAttention,
                                             chunk_len=16,
                                             n_hashes=2,
                                             n_buckets=[32, 32],
                                             predict_drop_len=max_len,
                                             predict_mem_len=max_len,
                                             max_length_for_buckets=1024),
            vocab_size=13,
            axial_pos_shape='fixed-base',
            d_axial_pos_embs=None,
        )

        shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)

        model_path = os.path.join(_TESTDATA, 'reformerlm_copy_lsh_attn.pkl.gz')
        pred_model.init_from_file(model_path,
                                  weights_only=True,
                                  input_signature=shape11)

        # Copy-task prompt: 0 delimits the payload to copy.
        inputs = np.array([[0, 3, 7, 5, 3, 2, 0]], dtype=np.int32)
        inp_len = inputs.shape[1]
        s = decoding.autoregressive_sample(pred_model,
                                           inputs=inputs,
                                           eos_id=-1,
                                           max_length=inp_len - 2,
                                           temperature=0.0)

        np.testing.assert_equal(s[0], inputs[0, 1:inp_len - 1])
Example #4
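The live version of the same LSH quality test: a three-layer ReformerLM with LSH self-attention is restored from reformerlm_copy_lsh_attn.pkl.gz and must reproduce the copy-task payload under greedy decoding.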
    def test_autoregressive_sample_reformerlm_lsh_quality(self):
        max_len = 32

        pred_model = models.ReformerLM(
            mode='predict',
            d_model=256,
            d_ff=512,
            dropout=0.05,
            max_len=2 * max_len,
            n_heads=4,
            n_layers=3,
            ff_use_sru=0,
            d_attention_key=64,
            d_attention_value=64,
            attention_type=functools.partial(tl.LSHSelfAttention,
                                             chunk_len=16,
                                             n_hashes=2,
                                             n_buckets=[32, 32],
                                             predict_drop_len=max_len,
                                             predict_mem_len=max_len,
                                             max_length_for_buckets=1024),
            vocab_size=13,
            pos_type='fixed-base',
            pos_d_axial_embs=None,
        )

        shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)

        model_path = os.path.join(_TESTDATA, 'reformerlm_copy_lsh_attn.pkl.gz')
        pred_model.init_from_file(model_path,
                                  weights_only=True,
                                  input_signature=shape11)

        # Copy-task prompt; 0 marks the payload boundaries.
        inputs = np.array([[0, 3, 7, 5, 3, 2, 0]], dtype=np.int32)
        inp_len = inputs.shape[1]
        # Greedy decoding (temperature=0.0) with eos_id=-1, so exactly
        # inp_len - 2 tokens come back; they should equal the payload.
        s = decoding.autoregressive_sample(pred_model,
                                           inputs=inputs,
                                           eos_id=-1,
                                           max_length=inp_len - 2,
                                           temperature=0.0)

        np.testing.assert_equal(s[0], inputs[0, 1:inp_len - 1])
Example #5
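Shape-only variant of the LSH test: the model is freshly initialized rather than restored from a checkpoint, the prompt is padded out to max_len, and only the shape of the sampled batch is asserted.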
    def test_autoregressive_sample_reformerlm_lsh(self):
        max_len = 32

        pred_model = models.ReformerLM(
            mode='predict',
            d_model=256,
            d_ff=512,
            dropout=0.05,
            max_len=2 * max_len,
            n_heads=4,
            n_layers=3,
            ff_use_sru=0,
            d_attention_key=64,
            d_attention_value=64,
            attention_type=functools.partial(tl.LSHSelfAttention,
                                             chunk_len=16,
                                             n_hashes=2,
                                             n_buckets=[32, 32],
                                             predict_drop_len=max_len,
                                             predict_mem_len=max_len,
                                             max_length_for_buckets=1024),
            vocab_size=13,
            pos_type='fixed-base',
            pos_d_axial_embs=None,
        )

        shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
        pred_model.init(shape11)

        # Copy-task prompt; this test only checks the output shape.
        inputs = np.array([[0, 3, 7, 5, 3, 2, 0]], dtype=np.int32)
        # Pad the prompt out to max_len, matching the predict_mem_len the
        # LSH attention layers were configured with.
        inputs = np.pad(inputs, [(0, 0), (0, max_len - inputs.shape[1])],
                        mode='constant',
                        constant_values=0)
        s = decoding.autoregressive_sample(pred_model,
                                           inputs=inputs,
                                           eos_id=-1,
                                           max_length=10,
                                           temperature=0.0)

        self.assertEqual(s.shape[0], 1)
        self.assertEqual(s.shape[1], 10)
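To run any of these snippets standalone, wrap them in a unittest.TestCase together with the shared context sketched under Example #1; the class name here is illustrative:

class DecodingTest(unittest.TestCase):
    ...  # paste the helper functions and test methods from above here


if __name__ == '__main__':
    unittest.main()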