def test_autoregressive_sample_reformerlm(self):
    """Samples from a freshly initialized ReformerLM and checks output shape."""
    lsh_attn = self._lsh_self_attention_fn()
    timebin_attn = self._timebin_self_attention_fn()
    # Mixed attention stack: time-binned attention in layer 0, LSH in layer 1.
    model = models.ReformerLM(
        vocab_size=256,
        d_model=256,
        d_ff=512,
        d_attention_key=128,
        d_attention_value=128,
        n_layers=2,
        n_heads=2,
        dropout=0.05,
        max_len=65536,
        attention_type=[timebin_attn, lsh_attn],
        pos_axial_shape=(256, 256),
        pos_d_axial_embs=(128, 128),
        ff_activation=tl.Relu,
        ff_use_sru=0,
        mode='predict',
    )
    model.init(shapes.ShapeDtype((1, 1), dtype=np.int32))
    # eos_id=-1 disables early stopping, so exactly max_length tokens come back.
    samples = decoding.autoregressive_sample(
        model, batch_size=1, eos_id=-1, max_length=10)
    self.assertEqual(samples.shape[0], 1)
    self.assertEqual(samples.shape[1], 10)
def test_autoregressive_sample_reformerlm_quality(self):
    """Greedy decoding from a pretrained copy-task checkpoint echoes the input."""
    timebin_attn = self._timebin_self_attention_fn()
    pred_model = models.ReformerLM(
        d_model=64,
        d_ff=128,
        dropout=0.05,
        max_len=256,
        n_heads=2,
        attention_type=timebin_attn,
        n_layers=2,
        vocab_size=13,
        mode='predict',
    )
    sig = shapes.ShapeDtype((1, 1), dtype=np.int32)
    checkpoint = os.path.join(_TESTDATA, 'reformerlm_copy.pkl.gz')
    pred_model.init_from_file(
        checkpoint, weights_only=True, input_signature=(sig, sig))
    prompt = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32)
    # temperature=0.0 makes sampling deterministic (argmax decoding).
    decoded = decoding.autoregressive_sample(
        pred_model, prompt, max_length=6, temperature=0.0)
    self.assertEqual(str(decoded[0]), '[3 7 5 3 2 4]')
def test_autoregressive_sample_reformerlm_lsh_quality_skipped(self):
    """Pretrained LSH-attention ReformerLM copy-task quality test (skipped).

    Fixes relative to the original:
    - Renamed: this method previously had the same name as the test defined
      after it, so Python silently replaced it in the class dict and it never
      ran (not even as a skip). A unique name makes the skip visible.
    - Uses ``pos_type``/``pos_d_axial_embs`` keyword names, matching every
      sibling test in this file; the old ``axial_pos_shape``/
      ``d_axial_pos_embs`` names would raise TypeError if the skip were lifted.
    """
    # After changes to some fastmath.custom_vjp functions (made so that we could
    # land JAX PR #4008), this test started failing, with an assertion error on
    # efficient_attention.py:1382 (q_len == 1).
    # TODO(mattjj,lukaszkaiser): revive this test after landing #4008
    raise unittest.SkipTest('temporarily skipping test so that we can '
                            'land https://github.com/google/jax/pull/4008')
    # pylint: disable=unreachable
    max_len = 32
    pred_model = models.ReformerLM(
        mode='predict',
        d_model=256,
        d_ff=512,
        dropout=0.05,
        max_len=2 * max_len,
        n_heads=4,
        n_layers=3,
        ff_use_sru=0,
        d_attention_key=64,
        d_attention_value=64,
        attention_type=functools.partial(tl.LSHSelfAttention,
                                         chunk_len=16,
                                         n_hashes=2,
                                         n_buckets=[32, 32],
                                         predict_drop_len=max_len,
                                         predict_mem_len=max_len,
                                         max_length_for_buckets=1024),
        vocab_size=13,
        pos_type='fixed-base',
        pos_d_axial_embs=None,
    )
    shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
    model_path = os.path.join(_TESTDATA, 'reformerlm_copy_lsh_attn.pkl.gz')
    pred_model.init_from_file(model_path, weights_only=True,
                              input_signature=shape11)
    inputs = np.array([[0, 3, 7, 5, 3, 2, 0]], dtype=np.int32)
    inp_len = inputs.shape[1]
    # Greedy decoding (temperature 0) should reproduce the interior tokens.
    s = decoding.autoregressive_sample(pred_model, inputs=inputs, eos_id=-1,
                                       max_length=inp_len - 2, temperature=0.0)
    np.testing.assert_equal(s[0], inputs[0, 1:inp_len - 1])
def test_autoregressive_sample_reformerlm_lsh_quality(self):
    """A pretrained LSH-attention ReformerLM greedily copies its input."""
    max_len = 32
    attn_fn = functools.partial(tl.LSHSelfAttention,
                                chunk_len=16,
                                n_hashes=2,
                                n_buckets=[32, 32],
                                predict_drop_len=max_len,
                                predict_mem_len=max_len,
                                max_length_for_buckets=1024)
    pred_model = models.ReformerLM(
        mode='predict',
        d_model=256,
        d_ff=512,
        dropout=0.05,
        max_len=2 * max_len,
        n_heads=4,
        n_layers=3,
        ff_use_sru=0,
        d_attention_key=64,
        d_attention_value=64,
        attention_type=attn_fn,
        vocab_size=13,
        pos_type='fixed-base',
        pos_d_axial_embs=None,
    )
    sig = shapes.ShapeDtype((1, 1), dtype=np.int32)
    checkpoint = os.path.join(_TESTDATA, 'reformerlm_copy_lsh_attn.pkl.gz')
    pred_model.init_from_file(checkpoint, weights_only=True,
                              input_signature=sig)
    prompt = np.array([[0, 3, 7, 5, 3, 2, 0]], dtype=np.int32)
    prompt_len = prompt.shape[1]
    # Deterministic (temperature 0) decoding should echo the interior tokens.
    decoded = decoding.autoregressive_sample(
        pred_model, inputs=prompt, eos_id=-1, max_length=prompt_len - 2,
        temperature=0.0)
    np.testing.assert_equal(decoded[0], prompt[0, 1:prompt_len - 1])
def test_autoregressive_sample_reformerlm_lsh(self):
    """Sampling from an untrained LSH-attention ReformerLM has the right shape."""
    max_len = 32
    attn_fn = functools.partial(tl.LSHSelfAttention,
                                chunk_len=16,
                                n_hashes=2,
                                n_buckets=[32, 32],
                                predict_drop_len=max_len,
                                predict_mem_len=max_len,
                                max_length_for_buckets=1024)
    pred_model = models.ReformerLM(
        mode='predict',
        d_model=256,
        d_ff=512,
        dropout=0.05,
        max_len=2 * max_len,
        n_heads=4,
        n_layers=3,
        ff_use_sru=0,
        d_attention_key=64,
        d_attention_value=64,
        attention_type=attn_fn,
        vocab_size=13,
        pos_type='fixed-base',
        pos_d_axial_embs=None,
    )
    pred_model.init(shapes.ShapeDtype((1, 1), dtype=np.int32))
    prompt = np.array([[0, 3, 7, 5, 3, 2, 0]], dtype=np.int32)
    # Right-pad the prompt with zeros up to max_len, as LSH attention expects.
    prompt = np.pad(prompt, [(0, 0), (0, max_len - prompt.shape[1])],
                    mode='constant', constant_values=0)
    # eos_id=-1 disables early stopping, so exactly max_length tokens come back.
    decoded = decoding.autoregressive_sample(
        pred_model, inputs=prompt, eos_id=-1, max_length=10, temperature=0.0)
    self.assertEqual(decoded.shape[0], 1)
    self.assertEqual(decoded.shape[1], 10)