Example #1
0
    def create_and_check_for_question_answering(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Build a TFLongformerForQuestionAnswering model and check its logits shapes.

        Sets ``return_dict = True`` on the passed-in ``config``. This mutates the caller's
        object. The result's ``start_logits`` and ``end_logits`` are then read as attributes.
        ``sequence_labels`` is supplied as both the start and the end positions.
        ``token_labels`` and ``choice_labels`` are accepted to keep a uniform signature
        but are not used here.
        """
        config.return_dict = True
        qa_model = TFLongformerForQuestionAnswering(config=config)

        # The same span labels serve as both start and end targets.
        position_labels = {"start_positions": sequence_labels, "end_positions": sequence_labels}
        outputs = qa_model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            **position_labels,
        )

        expected_shape = [self.batch_size, self.seq_length]
        for logits in (outputs.start_logits, outputs.end_logits):
            self.parent.assertListEqual(shape_list(logits), expected_shape)
 def create_and_check_longformer_for_question_answering(
     self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
 ):
     """Build a TFLongformerForQuestionAnswering model and check its logits shapes.

     ``sequence_labels`` is passed as both ``start_positions`` and ``end_positions``.
     The model output is read positionally as ``(loss, start_logits, end_logits, ...)``.
     Both logits tensors are asserted to have shape ``[batch_size, seq_length]``.
     ``token_labels`` and ``choice_labels`` are accepted to keep a uniform signature
     but are not used here.
     """
     model = TFLongformerForQuestionAnswering(config=config)
     outputs = model(
         input_ids,
         attention_mask=input_mask,
         token_type_ids=token_type_ids,
         start_positions=sequence_labels,
         end_positions=sequence_labels,
     )
     # Index positionally rather than tuple-unpacking the output.
     # NOTE(review): if the config enables return_dict, the model presumably returns a
     # dict-like output object. Unpacking that would iterate its keys, not its tensors.
     # Integer indexing works for both a plain tuple and such an object; confirm against
     # the transformers version in use. Index 0 is the loss, which is not checked here.
     start_logits = outputs[1]
     end_logits = outputs[2]

     expected_shape = [self.batch_size, self.seq_length]
     self.parent.assertListEqual(shape_list(start_logits), expected_shape)
     self.parent.assertListEqual(shape_list(end_logits), expected_shape)