def test_reformer2_deterministic_eval(self):
  """Eval-determinism check for a tiny decoder-only Reformer2.

  Under the JAX backend, builds a ``reformer.Reformer2`` constructor with no
  encoder layers, a single decoder layer and dropout disabled. It then passes
  that constructor, together with an (input, target) pair of integer token
  batches, to ``test_utils.test_eval_is_deterministic``.
  """
  with fastmath.use_backend(fastmath.Backend.JAX):
    vocab_size, d_model = 16, 4
    batch_size, length = 2, 5

    # Hyper-parameters for the model under test; dropout is zeroed so that
    # eval-mode runs have no stochastic component coming from the config.
    reformer2_config = dict(
        d_model=d_model,
        d_ff=16,
        n_encoder_layers=0,
        n_decoder_layers=1,
        n_heads=2,
        dropout=0.0,
        max_len=length * 2,
        pos_type=None,
        encoder_attention_type=tl.Attention,
        encoder_decoder_attention_type=tl.CausalAttention,
    )
    model_fn = functools.partial(
        reformer.Reformer2, vocab_size, **reformer2_config)

    token_shape = (batch_size, length)
    # Input tokens are drawn from [0, vocab_size); the target batch is all zeros.
    source_tokens = np.random.randint(vocab_size, size=token_shape)
    target_tokens = np.zeros(token_shape, dtype=np.int32)
    test_utils.test_eval_is_deterministic(
        (source_tokens, target_tokens), model_fn)
def test_deterministic_eval(self):
  """Eval-determinism check for a small ``tl.CausalAttention`` layer.

  Feeds a ones-filled float32 tensor of shape (1, seq_len, d_model) and a
  ``CausalAttention`` constructor (d_feature=d_model, 4 heads) to
  ``test_utils.test_eval_is_deterministic``.
  """
  d_model, seq_len = 32, 3
  # Build the float32 input directly rather than casting after creation.
  activations = np.ones((1, seq_len, d_model), dtype=np.float32)
  layer_fn = functools.partial(tl.CausalAttention, d_feature=d_model, n_heads=4)
  test_utils.test_eval_is_deterministic(activations, layer_fn)