Example #1
0
    def test_reformer2_predict_equals_eval(self):
        """Check that a tiny Reformer2 produces the same output in eval and predict modes."""
        with fastmath.use_backend(fastmath.Backend.JAX):
            vocab = 16
            width = 8
            n_batch, seq_len = 2, 5

            build_model = functools.partial(
                reformer.Reformer2,
                vocab,
                d_model=width,
                d_ff=16,
                n_encoder_layers=1,
                n_decoder_layers=1,
                n_heads=2,
                dropout=0.0,
                max_len=seq_len * 2,
                pos_type=None,
                n_decoder_attention_layers=1,
                encoder_attention_type=tl.Attention,
                encoder_decoder_attention_type=tl.CausalAttention,
            )

            # Sample inputs from [1, vocab): token id 0 means padding, and
            # predict mode doesn't support padding.
            src = np.random.randint(1, vocab, size=(n_batch, seq_len))
            tgt = np.zeros((n_batch, seq_len), dtype=np.int32)

            # TODO(jaszczur): check why init_tokens > 1 fails nondeterministically
            test_utils.test_eval_equals_predict((src, tgt), build_model, 1, -1)
Example #2
0
    def _test_sparse_fast_inference(self, length):
        """Helper: eval and predict modes must agree for a sparse ConfigurableTransformer.

        Args:
            length: sequence length of the (input, output) token batches.
        """
        with fastmath.use_backend(fastmath.Backend.JAX):
            vocab = 16
            width = 4
            n_batch = 2

            # Sparse multiplicative-conv causal attention for the decoder side.
            attn_type = functools.partial(
                tl.MultiplicativeConvCausalAttention,
                sparsity=2,
                length_kernel_size=1,
            )

            build_model = functools.partial(
                ct.ConfigurableTransformer,
                input_vocab_size=vocab,
                d_model=width,
                d_ff=8,
                n_encoder_layers=2,
                n_decoder_layers=2,
                n_heads=2,
                loss_sparsity=2,
                ff_sparsity=2,
                encoder_decoder_attention_type=attn_type,
                ff_use_sru=(1, 4),
            )

            src = np.random.randint(vocab, size=(n_batch, length))
            tgt = np.zeros((n_batch, length), dtype=np.int32)

            test_utils.test_eval_equals_predict((src, tgt),
                                                build_model,
                                                seq_tensor=1)
Example #3
0
  def test_terraformer_predict_equals_eval(self):
    """Check that a tiny ConfigurableTerraformer matches between eval and predict."""
    with fastmath.use_backend(fastmath.Backend.JAX):
      vocab = 16
      width = 8
      n_batch, seq_len = 1, 5

      build_model = functools.partial(
          terraformer.ConfigurableTerraformer,
          vocab,
          d_model=width,
          d_ff=16,
          n_encoder_layers=1,
          n_decoder_layers=1,
          n_heads=2,
          ff_use_sru=(1, 8),  # ? is SRU working?
          dropout=0.0,
          max_len=(seq_len + 7) * 2,
          pos_type=None,
          reversible_encoder=True,
          n_decoder_attention_layers=1,
          encoder_attention_type=tl.Attention,
          encoder_decoder_attention_type=tl.CausalAttention,
      )

      # Sample inputs from [1, vocab): token id 0 means padding, and predict
      # mode doesn't support padding. Then zero the tail to exercise padding
      # on the encoder side only.
      src = np.random.randint(1, vocab, size=(n_batch, seq_len))
      src[:, -2:] = 0
      tgt = np.zeros((n_batch, seq_len), dtype=np.int32)

      test_utils.test_eval_equals_predict(
          (src, tgt), build_model, seq_axis=1, seq_tensor=-1, init_tokens=1)
Example #4
0
    def test_predict_equals_eval(self):
        """CausalAttention must give identical outputs in eval and predict modes."""
        feature_dim = 32
        n_steps = 10
        # One all-ones activation batch of shape (batch, time, features).
        activations = np.ones((1, n_steps, feature_dim), dtype=np.float32)

        build_layer = functools.partial(
            tl.CausalAttention,
            d_feature=feature_dim,
            n_heads=4,
        )

        test_utils.test_eval_equals_predict(activations, build_layer)
Example #5
0
    def _test_fast_inference(self, length):
        """Helper: eval and predict modes must agree for a tiny ConfigurableTransformerLM.

        Args:
            length: sequence length of the zero-token input batch.
        """
        with fastmath.use_backend(fastmath.Backend.JAX):
            build_model = functools.partial(
                ct.ConfigurableTransformerLM,
                vocab_size=16,
                d_model=4,
                d_ff=8,
                n_layers=2,
                n_heads=2,
            )
            n_batch = 2
            tokens = np.zeros((n_batch, length), dtype=np.int32)

            test_utils.test_eval_equals_predict(tokens, build_model)