def create_and_check_bert_for_question_answering(
     self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
 ):
     """Check the output shapes of ``TFBertForQuestionAnswering``.

     Builds the model from ``config``, calls it on a dict holding
     ``input_ids``, ``attention_mask`` (taken from ``input_mask``) and
     ``token_type_ids``, and asserts through ``self.parent`` that the start
     and end logits both have shape ``[self.batch_size, self.seq_length]``.

     ``sequence_labels``, ``token_labels`` and ``choice_labels`` are accepted
     but not used here.
     """
     qa_model = TFBertForQuestionAnswering(config=config)
     features = dict(input_ids=input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)

     # The model call is unpacked as a (start, end) pair of logits.
     start_tensor, end_tensor = qa_model(features)
     # Convert both outputs to NumPy before any assertion runs.
     start_array = start_tensor.numpy()
     end_array = end_tensor.numpy()

     expected_shape = [self.batch_size, self.seq_length]
     self.parent.assertListEqual(list(start_array.shape), expected_shape)
     self.parent.assertListEqual(list(end_array.shape), expected_shape)
    def create_and_check_bert_for_question_answering(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check the output shapes of ``TFBertForQuestionAnswering``.

        Builds the model from ``config``, calls it on a dict holding
        ``input_ids``, ``attention_mask`` (taken from ``input_mask``) and
        ``token_type_ids``, and asserts through ``self.parent`` that the
        ``start_logits`` and ``end_logits`` attributes of the returned object
        both have shape ``(self.batch_size, self.seq_length)``.

        ``sequence_labels``, ``token_labels`` and ``choice_labels`` are
        accepted but not used here.
        """
        qa_model = TFBertForQuestionAnswering(config=config)
        features = dict(input_ids=input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        outputs = qa_model(features)

        expected_shape = (self.batch_size, self.seq_length)
        # Same check for both heads, start first and then end.
        for head in ("start_logits", "end_logits"):
            self.parent.assertEqual(getattr(outputs, head).shape, expected_shape)
Exemplo n.º 3
0
# -*- coding: utf-8 -*-
"""
Traverse the text files and question files used as input for QA-model testing.
@author: Vince
"""
from transformers import BertTokenizer
from transformers.modeling_tf_bert import TFBertForQuestionAnswering
import tensorflow as tf
import os
import os.path
import jellyfish
# Model parameters: the tokenizer and the question-answering model are both
# loaded from the same pretrained checkpoint.
_SQUAD_CHECKPOINT = "bert-large-uncased-whole-word-masking-finetuned-squad"
tokenizer = BertTokenizer.from_pretrained(_SQUAD_CHECKPOINT)
model = TFBertForQuestionAnswering.from_pretrained(_SQUAD_CHECKPOINT)


# Directories of the test datasets traversed by this script.
# NOTE: these are hard-coded absolute paths -- MAKE SURE TO CHANGE THEM
# before running.
#
# Alternative locations, kept for reference:
#   dataset_dir  = "C:/Users/Admin/Desktop/Acronyms-and-Abbreviation-Expansion/dataset/datagen_p_output"
#   answer_dir   = "C:/Users/Admin/Desktop/Acronyms-and-Abbreviation-Expansion/dataset/answer_output/"
#   question_dir = "C:/Users/Admin/Desktop/Acronyms-and-Abbreviation-Expansion/dataset/question_output/"

# Common prefix shared by the three directories below.
_dataset_root = "E:\\PythonProjects\\acronym-dataset\\"
dataset_dir = _dataset_root + "smalldata_output\\"
answer_dir = _dataset_root + "answer_output\\"
question_dir = _dataset_root + "question_output\\"

# Answer/question file names, initialised to empty strings.
answer_file = question_file = ""

#directory = '/dataset/datagen_output'