def create_and_check_xlnet_qa(
        self,
        config,
        input_ids_1,
        input_ids_2,
        input_ids_q,
        perm_mask,
        input_mask,
        target_mapping,
        segment_ids,
        lm_labels,
        sequence_labels,
        is_impossible_labels,
    ):
        model = TFXLNetForQuestionAnsweringSimple(config)

        inputs = {
            "input_ids": input_ids_1,
            "attention_mask": input_mask,
            "token_type_ids": segment_ids
        }
        result = model(inputs)

        self.parent.assertEqual(result.start_logits.shape,
                                (self.batch_size, self.seq_length))
        self.parent.assertEqual(result.end_logits.shape,
                                (self.batch_size, self.seq_length))
        self.parent.assertListEqual(
            [mem.shape for mem in result.mems],
            [(self.seq_length, self.batch_size, self.hidden_size)] *
            self.num_hidden_layers,
        )
        def create_and_check_xlnet_qa(self, config, input_ids_1, input_ids_2,
                                      input_ids_q, perm_mask, input_mask,
                                      target_mapping, segment_ids, lm_labels,
                                      sequence_labels, is_impossible_labels):
            model = TFXLNetForQuestionAnsweringSimple(config)

            inputs = {
                'input_ids': input_ids_1,
                'attention_mask': input_mask,
                'token_type_ids': segment_ids
            }
            start_logits, end_logits, mems = model(inputs)

            result = {
                "start_logits": start_logits.numpy(),
                "end_logits": end_logits.numpy(),
                "mems": [m.numpy() for m in mems],
            }

            self.parent.assertListEqual(list(result["start_logits"].shape),
                                        [self.batch_size, self.seq_length])
            self.parent.assertListEqual(list(result["end_logits"].shape),
                                        [self.batch_size, self.seq_length])
            self.parent.assertListEqual(
                list(list(mem.shape) for mem in result["mems"]),
                [[self.seq_length, self.batch_size, self.hidden_size]] *
                self.num_hidden_layers)