def test_reformer2_predict_equals_eval(self):
    """Checks that Reformer2 gives the same output in eval and predict modes."""
    with fastmath.use_backend(fastmath.Backend.JAX):
        vocab_size = 16
        d_model = 8
        batch_size = 2
        length = 5

        # Small single-layer configuration so the test stays fast.
        model_fn = functools.partial(
            reformer.Reformer2,
            vocab_size,
            d_model=d_model,
            d_ff=16,
            n_encoder_layers=1,
            n_decoder_layers=1,
            n_heads=2,
            dropout=0.0,
            max_len=length * 2,
            pos_type=None,
            n_decoder_attention_layers=1,
            encoder_attention_type=tl.Attention,
            encoder_decoder_attention_type=tl.CausalAttention,
        )

        # Token id of 0 indicates padding; and predict mode doesn't support it.
        inputs = np.random.randint(1, vocab_size, size=(batch_size, length))
        targets = np.zeros((batch_size, length), dtype=np.int32)
        # TODO(jaszczur): check why init_tokens > 1 fails nondeterministically
        test_utils.test_eval_equals_predict((inputs, targets), model_fn, 1, -1)
def _test_sparse_fast_inference(self, length):
    """Checks eval/predict equivalence for a sparse ConfigurableTransformer.

    Args:
      length: sequence length of the random input/target batches.
    """
    with fastmath.use_backend(fastmath.Backend.JAX):
        vocab_size = 16
        d_model = 4
        batch_size = 2

        # Sparse multiplicative-conv causal attention on the decoder side.
        attention_fn = functools.partial(
            tl.MultiplicativeConvCausalAttention,
            sparsity=2,
            length_kernel_size=1,
        )
        model_fn = functools.partial(
            ct.ConfigurableTransformer,
            input_vocab_size=vocab_size,
            d_model=d_model,
            d_ff=8,
            n_encoder_layers=2,
            n_decoder_layers=2,
            n_heads=2,
            loss_sparsity=2,
            ff_sparsity=2,
            encoder_decoder_attention_type=attention_fn,
            ff_use_sru=(1, 4),
        )

        inputs = np.random.randint(vocab_size, size=(batch_size, length))
        targets = np.zeros((batch_size, length), dtype=np.int32)
        test_utils.test_eval_equals_predict((inputs, targets), model_fn, seq_tensor=1)
def test_terraformer_predict_equals_eval(self):
    """Checks that Terraformer gives the same output in eval and predict modes."""
    with fastmath.use_backend(fastmath.Backend.JAX):
        vocab_size = 16
        d_model = 8
        batch_size = 1
        length = 5

        model_fn = functools.partial(
            terraformer.ConfigurableTerraformer,
            vocab_size,
            d_model=d_model,
            d_ff=16,
            n_encoder_layers=1,
            n_decoder_layers=1,
            n_heads=2,
            ff_use_sru=(1, 8),  # ? is SRU working?
            dropout=0.0,
            max_len=(length + 7) * 2,
            pos_type=None,
            reversible_encoder=True,
            n_decoder_attention_layers=1,
            encoder_attention_type=tl.Attention,
            encoder_decoder_attention_type=tl.CausalAttention,
        )

        # Token id of 0 indicates padding; and predict mode doesn't support it.
        inputs = np.random.randint(1, vocab_size, size=(batch_size, length))
        # Zero out the tail to exercise padding handling on the encoder side.
        inputs[:, -2:] = 0
        targets = np.zeros((batch_size, length), dtype=np.int32)
        test_utils.test_eval_equals_predict(
            (inputs, targets), model_fn, seq_axis=1, seq_tensor=-1, init_tokens=1)
def test_predict_equals_eval(self):
    """Checks eval/predict equivalence for a bare CausalAttention layer."""
    d_model = 32
    seq_len = 10

    # A single all-ones activation batch: (batch, time, features).
    activations = np.ones((1, seq_len, d_model)).astype(np.float32)

    model_fn = functools.partial(
        tl.CausalAttention,
        d_feature=d_model,
        n_heads=4,
    )
    test_utils.test_eval_equals_predict(activations, model_fn)
def _test_fast_inference(self, length):
    """Checks eval/predict equivalence for a small ConfigurableTransformerLM.

    Args:
      length: sequence length of the all-zeros input batch.
    """
    with fastmath.use_backend(fastmath.Backend.JAX):
        model_fn = functools.partial(
            ct.ConfigurableTransformerLM,
            vocab_size=16,
            d_model=4,
            d_ff=8,
            n_layers=2,
            n_heads=2,
        )
        batch_size = 2
        inputs = np.zeros((batch_size, length), dtype=np.int32)
        test_utils.test_eval_equals_predict(inputs, model_fn)