예제 #1
0
    def test_reformer2_deterministic_eval(self):
        """Check that a tiny Reformer2 model evaluates deterministically.

        Under the JAX backend, builds a Reformer2 factory (zero encoder
        layers, one decoder layer, dropout 0.0). It passes the factory,
        together with a random-token input and an all-zeros int32 target, to
        ``test_utils.test_eval_is_deterministic``.
        """
        with fastmath.use_backend(fastmath.Backend.JAX):
            vocab_size = 16
            batch_size = 2
            length = 5
            shape = (batch_size, length)

            # Deliberately tiny hyper-parameters so the test stays cheap.
            model_kwargs = dict(
                d_model=4,
                d_ff=16,
                n_encoder_layers=0,
                n_decoder_layers=1,
                n_heads=2,
                dropout=0.0,
                max_len=2 * length,
                pos_type=None,
                encoder_attention_type=tl.Attention,
                encoder_decoder_attention_type=tl.CausalAttention,
            )
            model_fn = functools.partial(
                reformer.Reformer2, vocab_size, **model_kwargs
            )

            # NOTE(review): the global NumPy RNG is not seeded here, so the
            # token ids differ between runs; the same array is reused for
            # every evaluation inside the helper call, presumably.
            tokens = np.random.randint(vocab_size, size=shape)
            targets = np.zeros(shape, dtype=np.int32)

            test_utils.test_eval_is_deterministic((tokens, targets), model_fn)
예제 #2
0
    def test_deterministic_eval(self):
        """Check eval determinism of ``tl.CausalAttention``.

        Feeds one all-ones float32 batch of shape (1, seq_len, d_model) and a
        CausalAttention factory to ``test_utils.test_eval_is_deterministic``.
        """
        d_model = 32
        seq_len = 3
        ones_input = np.ones((1, seq_len, d_model), dtype=np.float32)

        model_fn = functools.partial(
            tl.CausalAttention, d_feature=d_model, n_heads=4
        )

        test_utils.test_eval_is_deterministic(ones_input, model_fn)