def test_TFXLMForQuestionAnsweringSimple(self):
     """Convert a config-built TF XLM QA-simple model to ONNX and verify it.

     The model is constructed from a default ``XLMConfig()`` (no pretrained
     weights are downloaded). The tokenizer comes from the file
     ``xlm_xlm-mlm-enfr-1024.pickle``, whose name suggests it corresponds to the
     ``xlm-mlm-enfr-1024`` checkpoint. Keras predictions and the ONNX-side
     inputs are handed to ``run_onnx_runtime``, whose truthy result is asserted.
     """
     from transformers import XLMConfig, TFXLMForQuestionAnsweringSimple
     # Reset Keras global state before building a new model.
     keras.backend.clear_session()
     # NOTE: ``_get_tokenzier`` (sic) is presumably the helper's actual name on
     # the test class; keep the spelling in sync with its definition.
     tokenizer = self._get_tokenzier('xlm_xlm-mlm-enfr-1024.pickle')
     _, keras_inputs, onnx_inputs = self._prepare_inputs(tokenizer)
     model = TFXLMForQuestionAnsweringSimple(XLMConfig())
     expected = model.predict(keras_inputs)
     onnx_model = keras2onnx.convert_keras(model, model.name)
     ok = run_onnx_runtime(onnx_model.graph.name, onnx_model, onnx_inputs, expected, self.model_files)
     self.assertTrue(ok)
        def create_and_check_xlm_qa(
            self,
            config,
            input_ids,
            token_type_ids,
            input_lengths,
            sequence_labels,
            token_labels,
            is_impossible_labels,
            input_mask,
        ):
            """Check that the simple XLM QA head emits per-token start/end logits.

            Only ``input_ids`` and ``input_lengths`` are fed to the model; the
            other arguments are accepted but not used by this check. Both logit
            arrays must have shape ``[batch_size, seq_length]``.
            """
            qa_model = TFXLMForQuestionAnsweringSimple(config)
            feed = {"input_ids": input_ids, "lengths": input_lengths}

            # This variant expects the call to yield exactly two tensors.
            start_logits, end_logits = qa_model(feed)
            # Materialise both as numpy arrays before any assertion runs.
            logits_arrays = (start_logits.numpy(), end_logits.numpy())

            expected_shape = [self.batch_size, self.seq_length]
            for logits in logits_arrays:
                self.parent.assertListEqual(list(logits.shape), expected_shape)
 def test_TFXLMForQuestionAnsweringSimple(self):
     """Convert the pretrained ``xlm-mlm-enfr-1024`` QA-simple model to ONNX.

     Both the tokenizer and the model weights are loaded with
     ``from_pretrained`` (which may need network access or a local cache).
     Keras predictions and the ONNX-side inputs are passed to
     ``run_onnx_runtime``, presumably as the reference outputs; its truthy
     result is asserted.
     """
     from transformers import XLMTokenizer, TFXLMForQuestionAnsweringSimple
     checkpoint = 'xlm-mlm-enfr-1024'
     tokenizer = XLMTokenizer.from_pretrained(checkpoint)
     _, keras_inputs, onnx_inputs = self._prepare_inputs(tokenizer)
     model = TFXLMForQuestionAnsweringSimple.from_pretrained(checkpoint)
     expected = model.predict(keras_inputs)
     onnx_model = keras2onnx.convert_keras(model, model.name)
     ok = run_onnx_runtime(
         onnx_model.graph.name,
         onnx_model,
         onnx_inputs,
         expected,
         self.model_files,
     )
     self.assertTrue(ok)
    def create_and_check_xlm_qa(
        self,
        config,
        input_ids,
        token_type_ids,
        input_lengths,
        sequence_labels,
        token_labels,
        is_impossible_labels,
        choice_labels,
        input_mask,
    ):
        """Check the simple XLM QA head's start/end logits have per-token shape.

        Only ``input_ids`` and ``input_lengths`` are fed to the model; the
        remaining arguments are accepted but unused here. The model output is
        read through its ``start_logits`` / ``end_logits`` attributes, each of
        which must have shape ``(batch_size, seq_length)``.
        """
        qa_model = TFXLMForQuestionAnsweringSimple(config)
        outputs = qa_model({"input_ids": input_ids, "lengths": input_lengths})

        expected_shape = (self.batch_size, self.seq_length)
        # Attributes are fetched lazily inside the loop, start before end.
        for field in ("start_logits", "end_logits"):
            self.parent.assertEqual(getattr(outputs, field).shape, expected_shape)