def test_multiple_sequences(self):
        """Run a padded batch of three sentences through FlaxBertModel.

        The same ``jax.jit``-wrapped forward function is exercised twice:
        once inside ``jax.disable_jit()`` (eager execution) and once compiled.
        Both runs must produce token outputs of shape (3, 7, 768) and pooled
        outputs of shape (3, 768).
        """
        tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
        model = FlaxBertModel.from_pretrained("bert-base-cased")

        sequences = [
            "this is an example sentence", "this is another", "and a third one"
        ]
        # padding=True pads every row to the longest sequence in the batch, so
        # all rows share one length (the assertions below expect 7).
        encodings = tokenizer(sequences,
                              return_tensors=TensorType.JAX,
                              padding=True,
                              truncation=True)

        @jax.jit
        def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
            # Pass the optional inputs by keyword: positional passing silently
            # depends on the parameter order of FlaxBertModel.__call__, and a
            # swapped attention_mask/token_type_ids would not change any of
            # the shapes asserted below.
            return model(input_ids,
                         attention_mask=attention_mask,
                         token_type_ids=token_type_ids)

        with self.subTest("JIT Disabled"):
            with jax.disable_jit():
                tokens, pooled = model_jitted(**encodings)
                self.assertEqual(tokens.shape, (3, 7, 768))
                self.assertEqual(pooled.shape, (3, 768))

        with self.subTest("JIT Enabled"):
            jitted_tokens, jitted_pooled = model_jitted(**encodings)

            self.assertEqual(jitted_tokens.shape, (3, 7, 768))
            self.assertEqual(jitted_pooled.shape, (3, 768))
예제 #2
0
    def test_bert_jax_jit(self):
        """Check that FlaxBertModel can be traced and executed under jax.jit.

        For each checkpoint, one sentence is tokenized and fed through a
        jitted wrapper around the model. The test then blocks until the
        device computation finishes, so execution errors surface here rather
        than being hidden by JAX's asynchronous dispatch.
        """
        for model_name in ["bert-base-cased", "bert-large-uncased"]:
            # subTest reports which checkpoint failed and lets the loop
            # continue with the remaining checkpoints.
            with self.subTest(model_name=model_name):
                tokenizer = AutoTokenizer.from_pretrained(model_name)
                model = FlaxBertModel.from_pretrained(model_name)
                tokens = tokenizer("Do you support jax jitted function?",
                                   return_tensors=TensorType.JAX)

                # Not named ``eval`` to avoid shadowing the builtin.
                @jax.jit
                def model_jitted(**kwargs):
                    return model(**kwargs)

                outputs = model_jitted(**tokens)
                # The model may return a tuple of arrays (e.g. token and
                # pooled outputs), and a tuple has no ``block_until_ready``.
                # Wait on every array leaf instead; this also works when the
                # output is a single array.
                for leaf in jax.tree_util.tree_leaves(outputs):
                    leaf.block_until_ready()
    def test_from_pytorch(self):
        """Compare Flax and PyTorch BERT outputs for the same checkpoint.

        Loads "bert-base-cased" in both frameworks and encodes one sentence
        for each. It then asserts that both models return the same number of
        outputs and that the outputs are element-wise close (tolerance 5e-3,
        checked via ``self.assert_almost_equals``).
        """
        checkpoint = "bert-base-cased"
        sentence = "This is a simple input"
        tolerance = 5e-3

        # The PyTorch side is only a reference forward pass; no gradients.
        with torch.no_grad(), self.subTest(checkpoint):
            tokenizer = BertTokenizerFast.from_pretrained(checkpoint)
            fx_model = FlaxBertModel.from_pretrained(checkpoint)
            pt_model = BertModel.from_pretrained(checkpoint)

            # Simple input: encode the same sentence once per framework.
            pt_inputs = tokenizer.encode_plus(
                sentence, return_tensors=TensorType.PYTORCH)
            fx_inputs = tokenizer.encode_plus(
                sentence, return_tensors=TensorType.JAX)

            reference = pt_model(**pt_inputs).to_tuple()
            candidate = fx_model(**fx_inputs)

            self.assertEqual(len(candidate), len(reference),
                             "Output lengths differ between Flax and PyTorch")

            for flax_array, torch_tensor in zip(candidate, reference):
                self.assert_almost_equals(flax_array, torch_tensor.numpy(),
                                          tolerance)
예제 #4
0
def test_multiple_sentences(jit):
    """Run a padded batch of three sentences through FlaxBertModel.

    Args:
        jit: If equal to ``"disable_jit"``, the jitted forward function runs
            inside ``jax.disable_jit()`` (eager execution). Any other value
            runs it compiled. Presumably supplied by a pytest fixture or
            parametrization defined elsewhere -- confirm.
    """
    tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
    model = FlaxBertModel.from_pretrained("bert-base-cased")

    sentences = [
        "this is an example sentence", "this is another", "and a third one"
    ]
    # padding=True gives every row the length of the longest sentence.
    encodings = tokenizer(sentences,
                          return_tensors=TensorType.JAX,
                          padding=True,
                          truncation=True)

    @jax.jit
    def model_jitted(input_ids, attention_mask, token_type_ids):
        # Keyword arguments keep the mapping independent of the parameter
        # order of FlaxBertModel.__call__; a positional swap of the mask and
        # type ids would not be caught by the shape checks below.
        return model(input_ids,
                     attention_mask=attention_mask,
                     token_type_ids=token_type_ids)

    if jit == "disable_jit":
        with jax.disable_jit():
            tokens, pooled = model_jitted(**encodings)
    else:
        tokens, pooled = model_jitted(**encodings)

    # Expected shapes for this three-sentence batch with bert-base-cased.
    assert tokens.shape == (3, 7, 768)
    assert pooled.shape == (3, 768)
예제 #5
0
 def test_model_from_pretrained(self):
     """Smoke-test: load a pretrained checkpoint and run one forward pass.

     Only the base FlaxBertModel is checked; covering every model class is
     unnecessary, and skipping them keeps the test fast.
     """
     checkpoint = "bert-base-cased"
     model = FlaxBertModel.from_pretrained(checkpoint)

     # A (1, 1) array of ones as dummy input (presumably read as token ids).
     dummy_inputs = np.ones((1, 1))
     self.assertIsNotNone(model(dummy_inputs))