    def test_multiple_sequences(self):
        tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
        model = FlaxRobertaModel.from_pretrained("roberta-base")

        sequences = [
            "this is an example sentence", "this is another", "and a third one"
        ]
        encodings = tokenizer(sequences,
                              return_tensors=TensorType.JAX,
                              padding=True,
                              truncation=True)

        @jax.jit
        def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
            return model(input_ids, attention_mask, token_type_ids)

        with self.subTest("JIT Disabled"):
            with jax.disable_jit():
                tokens, pooled = model_jitted(**encodings)
                self.assertEqual(tokens.shape, (3, 7, 768))
                self.assertEqual(pooled.shape, (3, 768))

        with self.subTest("JIT Enabled"):
            jitted_tokens, jitted_pooled = model_jitted(**encodings)

            self.assertEqual(jitted_tokens.shape, (3, 7, 768))
            self.assertEqual(jitted_pooled.shape, (3, 768))

    def test_from_pytorch(self):
        with torch.no_grad():
            with self.subTest("roberta-base"):
                tokenizer = RobertaTokenizerFast.from_pretrained(
                    "roberta-base")
                fx_model = FlaxRobertaModel.from_pretrained("roberta-base")
                pt_model = RobertaModel.from_pretrained("roberta-base")

                # Check for simple input
                pt_inputs = tokenizer.encode_plus(
                    "This is a simple input",
                    return_tensors=TensorType.PYTORCH)
                fx_inputs = tokenizer.encode_plus(
                    "This is a simple input", return_tensors=TensorType.JAX)
                pt_outputs = pt_model(**pt_inputs)
                fx_outputs = fx_model(**fx_inputs)

                self.assertEqual(
                    len(fx_outputs), len(pt_outputs),
                    "Output lengths differ between Flax and PyTorch")

                for fx_output, pt_output in zip(fx_outputs,
                                                pt_outputs.to_tuple()):
                    self.assert_almost_equals(fx_output, pt_output.numpy(),
                                              5e-3)
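
The comparison loop above relies on an assert_almost_equals helper defined elsewhere on the test class. A minimal sketch of such a helper, assuming numpy is imported as np (only the call signature comes from the snippet above, the implementation is assumed):

    def assert_almost_equals(self, a, b, tol):
        # Largest element-wise absolute difference between the two arrays
        diff = np.abs(np.asarray(a) - np.asarray(b)).max()
        self.assertLessEqual(diff, tol, f"Arrays differ by {diff}, which exceeds the tolerance {tol}.")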
Example #3
    def test_roberta_jax_jit(self):
        for model_name in ["roberta-base", "roberta-large"]:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = FlaxRobertaModel.from_pretrained(model_name)
            tokens = tokenizer("Do you support jax jitted function?",
                               return_tensors=TensorType.JAX)

            @jax.jit
            def eval(**kwargs):
                return model(**kwargs)

            eval(**tokens).block_until_ready()
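
JAX dispatches work asynchronously, so the test calls block_until_ready() to make sure the jitted forward pass has actually finished before the test returns. A standalone sketch of the same idiom on a plain array (the array and its shape are illustrative only):

import jax.numpy as jnp

x = jnp.ones((3, 768))
y = (x * 2).block_until_ready()  # returns only once the computation has completed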
def test_multiple_sentences(jit):
    # `jit` selects between running the forward pass under jax.jit or with jit disabled
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    model = FlaxRobertaModel.from_pretrained("roberta-base")

    sentences = ["this is an example sentence", "this is another", "and a third one"]
    encodings = tokenizer(sentences, return_tensors=TensorType.JAX, padding=True, truncation=True)

    @jax.jit
    def model_jitted(input_ids, attention_mask):
        return model(input_ids, attention_mask)

    if jit == "disable_jit":
        with jax.disable_jit():
            tokens, pooled = model_jitted(**encodings)
    else:
        tokens, pooled = model_jitted(**encodings)

    assert tokens.shape == (3, 7, 768)
    assert pooled.shape == (3, 768)
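
The jit argument of test_multiple_sentences is supplied by the test runner. A hypothetical pytest parametrization that exercises both branches (decorator values assumed, not taken from the source) could look like:

import pytest

@pytest.mark.parametrize("jit", ["jit", "disable_jit"])
def test_multiple_sentences(jit):
    ...  # body as defined above

All of these snippets additionally assume module-level imports along the lines of import jax, import torch, and from transformers import AutoTokenizer, FlaxRobertaModel, RobertaModel, RobertaTokenizerFast, TensorType; exact import paths may vary with the transformers version.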