Example #1
    def __init__(self,
                 args,
                 hidden_size=768,
                 num_classes=6,
                 discriminator_lambda=0.01,
                 checkpoint_path: str = None,
                 load_path: str = None):
        super(AdversarialModel, self).__init__()
        self.args = args
        # Load models
        if checkpoint_path:
            self.qa_model = DistilBertForQuestionAnswering.from_pretrained(
                checkpoint_path)
        else:
            self.qa_model = DistilBertForQuestionAnswering.from_pretrained(
                'distilbert-base-uncased')

        self.discriminator_model = DiscriminatorModel()

        # Set fields
        self.num_classes = num_classes
        self.discriminator_lambda = discriminator_lambda

        # Create output layer
        self.qa_outputs = nn.Linear(hidden_size, 2)
        self.qa_outputs.weight.data.normal_(mean=0.0, std=0.02)
        self.qa_outputs.bias.data.zero_()

        if load_path is not None:
            self.load_state_dict(
                torch.load(load_path,
                           map_location=lambda storage, loc: storage))
Example #2
def ensure_models():
    try:
        tokenizer = DistilBertTokenizer.from_pretrained(
            MODEL_PATH, return_token_type_ids=True)
        model = DistilBertForQuestionAnswering.from_pretrained(MODEL_PATH)
    except Exception:  # local copy missing or corrupt; re-download below
        if Path(MODEL_PATH).is_dir():
            rmtree(MODEL_PATH)
        makedirs(MODEL_PATH)
        tokenizer = DistilBertTokenizer.from_pretrained(
            'distilbert-base-uncased', return_token_type_ids=True)
        tokenizer.save_pretrained(MODEL_PATH)
        model = DistilBertForQuestionAnswering.from_pretrained(
            'distilbert-base-uncased-distilled-squad')
        model.save_pretrained(MODEL_PATH)
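
A minimal usage sketch (hypothetical; MODEL_PATH and the pathlib/shutil/os imports are assumed to live elsewhere in the module): call ensure_models() once at startup so that later loads can run offline from the local copy.

# Hypothetical usage; MODEL_PATH is an assumed module-level constant.
MODEL_PATH = 'models/distilbert-squad'
ensure_models()  # downloads and caches on the first run
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
model = DistilBertForQuestionAnswering.from_pretrained(MODEL_PATH)

Note that ensure_models builds the tokenizer and model but does not return them, so callers reload from MODEL_PATH afterwards.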
Example #3
    def __init__(self, threshold=0.75):
        '''
        threshold: quality control; discard the question if overlap < threshold.

        The question generation & answer extraction models can be downloaded at
        https://drive.google.com/uc?id=1vhsDOW9wUUO83IQasTPlkxb82yxmMH-V

        The question answering model can be downloaded at
        https://huggingface.co/distilbert-base-cased-distilled-squad
        (note: its tokenizer can't be downloaded directly).
        '''
        self._threshold = threshold
        # self._dir = model_dir

        self.que_model = T5ForConditionalGeneration.from_pretrained(
            't5_que_gen_model/t5_base_que_gen/')
        self.ans_model = T5ForConditionalGeneration.from_pretrained(
            't5_ans_gen_model/t5_base_ans_gen/')
        self.qa_model = DistilBertForQuestionAnswering.from_pretrained(
            'distilbert-base-cased-distilled-squad')

        self.que_tokenizer = T5Tokenizer.from_pretrained(
            't5_que_gen_model/t5_base_tok_que_gen/')
        self.ans_tokenizer = T5Tokenizer.from_pretrained(
            't5_ans_gen_model/t5_base_tok_ans_gen/')
        self.qa_tokenizer = DistilBertTokenizer.from_pretrained(
            "distilbert-base-cased-distilled-squad")

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.que_model = self.que_model.to(self.device)
        self.ans_model = self.ans_model.to(self.device)
        self.qa_model = self.qa_model.to(self.device)
Example #4
def answergen(context, question):

    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased',
                                                    return_token_type_ids=True)
    model = DistilBertForQuestionAnswering.from_pretrained(
        'distilbert-base-uncased-distilled-squad')

    encoding = tokenizer.encode_plus(question, context)

    input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]

    start_scores, end_scores = model(torch.tensor([input_ids]),
                                     attention_mask=torch.tensor([attention_mask]),
                                     return_dict=False)  # tuple output on transformers>=4

    ans_tokens = input_ids[torch.argmax(start_scores):torch.argmax(end_scores) + 1]
    answer_tokens = tokenizer.convert_ids_to_tokens(ans_tokens,
                                                    skip_special_tokens=True)

    # print ("\nQuestion ",question)
    #print ("\nAnswer Tokens: ")
    #print (answer_tokens)

    answer_tokens_to_string = tokenizer.convert_tokens_to_string(answer_tokens)
    #print ("\nAnswer : ",answer_tokens_to_string)
    return answer_tokens_to_string
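
A hypothetical usage example for answergen; the context and question here are illustrative only:

# Hypothetical usage of answergen from Example #4.
context = "The Eiffel Tower was completed in 1889 and stands in Paris."
question = "When was the Eiffel Tower completed?"
print(answergen(context, question))  # should print something like "1889"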
Example #5
    def __init__(self, context):
        super(SimpleAnswerExtraction, self).__init__(context)

        self.tokenizer = DistilBertTokenizer.from_pretrained(
            "distilbert-base-uncased-distilled-squad")
        self.model = DistilBertForQuestionAnswering.from_pretrained(
            "distilbert-base-uncased-distilled-squad")
Example #6
    def __init__(self,
                 num_classes=6,
                 hidden_size=768,
                 num_layers=3,
                 dropout=0.1,
                 dis_lambda=0.5,
                 concat=False,
                 anneal=False):
        super(DomainQA, self).__init__()

        self.distilbert = DistilBertForQuestionAnswering.from_pretrained(
            'distilbert-base-uncased')
        self.config = self.distilbert.config
        self.config.output_hidden_states = True
        self.config.output_attentions = True
        self.config.output_scores = True
        self.WEIGHTS_NAME = "DistillBert_DANN"

        if concat:
            input_size = 2 * hidden_size
        else:
            input_size = hidden_size

        self.discriminator = DomainDiscriminator(num_classes, input_size,
                                                 hidden_size, num_layers,
                                                 dropout)
        self.num_classes = num_classes
        self.dis_lambda = dis_lambda
        self.anneal = anneal
        self.concat = concat
        self.sep_id = 102  # token id of [SEP] in the (Distil)BERT uncased vocab
Example #7
 def create_and_check_distilbert_for_question_answering(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
     model = DistilBertForQuestionAnswering(config=config)
     model.eval()
     loss, start_logits, end_logits = model(input_ids, attention_mask=input_mask, start_positions=sequence_labels, end_positions=sequence_labels)
     result = {
         "loss": loss,
         "start_logits": start_logits,
         "end_logits": end_logits,
     }
     self.parent.assertListEqual(
         list(result["start_logits"].size()),
         [self.batch_size, self.seq_length])
     self.parent.assertListEqual(
         list(result["end_logits"].size()),
         [self.batch_size, self.seq_length])
     self.check_loss_output(result)
Example #8
 def model_load(self, path: str):
     config = DistilBertConfig.from_pretrained(path + "/config.json")
     tokenizer = DistilBertTokenizer.from_pretrained(
         path, do_lower_case=self.do_lower_case)
     model = DistilBertForQuestionAnswering.from_pretrained(path,
                                                            from_tf=False,
                                                            config=config)
     return model, tokenizer
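
A minimal calling sketch, assuming `loader` is an instance of the enclosing class and the path points at a directory holding config.json, the vocab files, and pytorch_model.bin:

# Hypothetical usage; `loader` and the model directory are assumptions.
model, tokenizer = loader.model_load('models/distilbert-squad')
model.eval()  # inference mode for question answering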
Example #9
 def __init__(self, *args, **kwargs):
     super().__init__(*args, **kwargs)
     self.model = DistilBertForQuestionAnswering.from_pretrained(
         self.model_dir)
     self.tokenizer = DistilBertTokenizer.from_pretrained(self.model_dir)
     self.device = torch.device(
         "cuda" if torch.cuda.is_available() else "cpu")
     self.model.to(self.device)
Example #10
 def __init__(self,
              model_name="distilbert-base-uncased-distilled-squad",
              device="cuda"):
     super().__init__()
     self.tokenizer = DistilBertTokenizer.from_pretrained(model_name)
     self.model = DistilBertForQuestionAnswering.from_pretrained(model_name)
     self.device = device
     self.model = self.model.to(self.device)
Example #11
 def __init__(self, hidden_size=768):
     super(DistilBertEncoder, self).__init__()
     self.distilbert = DistilBertForQuestionAnswering.from_pretrained(
         'distilbert-base-uncased')
     self.config = self.distilbert.config
     self.config.output_hidden_states = True
     self.config.output_attentions = True
     self.config.output_scores = True
     self.pooler = nn.Linear(hidden_size, hidden_size)
Example #12
def load_model():
    # model_path = '/Users/neelbhandari/Downloads/distilbert_weights'
    model_path = 'distilbert_weights'
    # load model
    model = DistilBertForQuestionAnswering.from_pretrained(model_path)
    model.to(device)
    model.eval()
    model.zero_grad()  # clear any stale gradients (useful before attribution)
    return model
    return model
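
The eval()/zero_grad() pair suggests the model feeds a gradient-based attribution step after inference. A hypothetical usage sketch, assuming `device` and a matching tokenizer are defined elsewhere:

# Hypothetical usage; `tokenizer` and `device` are assumed to be defined.
model = load_model()
inputs = tokenizer("Who wrote Hamlet?",
                   "Hamlet is a tragedy written by William Shakespeare.",
                   return_tensors='pt').to(device)
outputs = model(**inputs)
start = outputs.start_logits.argmax()
end = outputs.end_logits.argmax() + 1
answer = tokenizer.decode(inputs['input_ids'][0][start:end])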
Example #13
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.logger = set_logger(self.model_dir, verbose=True)
        self.logger.info('Distil Loading from checkpoint %s' % self.model_dir)

        self.model = DistilBertForQuestionAnswering.from_pretrained(self.model_dir)
        self.tokenizer = DistilBertTokenizer.from_pretrained(self.model_dir)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
Example #14
    def model_load(self, path):

        s3_model_url = 'https://distilbert-finetuned-model.s3.eu-west-2.amazonaws.com/pytorch_model.bin'
        path_to_model = download_model(s3_model_url, model_name="pytorch_model.bin")

        config = DistilBertConfig.from_pretrained(path + "/config.json")
        tokenizer = DistilBertTokenizer.from_pretrained(path, do_lower_case=self.do_lower_case)
        model = DistilBertForQuestionAnswering.from_pretrained(path_to_model, from_tf=False, config=config)

        return model, tokenizer
Example #15
    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing " + __name__)

        self.logger.info("started initializing tokenizer")
        self.tokenizer = DistilBertTokenizer.from_pretrained(
            'distilbert-base-uncased', return_token_type_ids=True)
        self.logger.info("started initializing model")
        self.model = DistilBertForQuestionAnswering.from_pretrained(
            'distilbert-base-uncased-distilled-squad')
        self.logger.info("Reader initialized successfully")
Example #16
 def __init__(self):
     MycroftSkill.__init__(self)
     # Use this instead of 'print' within a skill
     self.log.info("Loading Distilbert")
     # Initialize language generation model
     self.model = DistilBertForQuestionAnswering.from_pretrained(
         'distilbert-base-uncased-distilled-squad')
     # Initialise a tokenizer
     self.tokenizer = DistilBertTokenizer.from_pretrained(
         'distilbert-base-uncased-distilled-squad')
     self.log.info("Distilbert Loaded Successfully")
Example #17
def download_and_save_DistilBERT_model(name):
    """Download and save DistilBERT transformer model to MODELS_DIR"""
    print(f"Downloading: {name}")
    try:
        os.makedirs(MODELS_DIR + f"{name}")
    except FileExistsError as _e:
        pass
    model = DistilBertForQuestionAnswering.from_pretrained(f"{name}")
    tokenizer = DistilBertTokenizer.from_pretrained(f"{name}")
    model.save_pretrained(MODELS_DIR + f"{name}")
    tokenizer.save_pretrained(MODELS_DIR + f"{name}")
    return
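
A hypothetical invocation; note that the MODELS_DIR + name concatenation above only works if MODELS_DIR ends with a path separator (os.path.join would be more robust):

# Hypothetical usage; MODELS_DIR must end with '/' for the concatenation above.
MODELS_DIR = 'models/'
download_and_save_DistilBERT_model('distilbert-base-uncased-distilled-squad')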
Example #18
 def create_and_check_distilbert_for_question_answering(
     self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
 ):
     model = DistilBertForQuestionAnswering(config=config)
     model.to(torch_device)
     model.eval()
     result = model(
         input_ids, attention_mask=input_mask, start_positions=sequence_labels, end_positions=sequence_labels
     )
     self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
     self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
Example #19
    def getAnswer(self, question, questionContext):
        distilBertTokenizer = DistilBertTokenizer.from_pretrained(
            'distilbert-base-uncased', return_token_type_ids=True)
        distilBertForQuestionAnswering = DistilBertForQuestionAnswering.from_pretrained(
            'distilbert-base-uncased-distilled-squad')

        encodings = distilBertTokenizer.encode_plus(question, questionContext)

        inputIds, attentionMask = encodings["input_ids"], encodings["attention_mask"]

        scoresStart, scoresEnd = distilBertForQuestionAnswering(
            torch.tensor([inputIds]),
            attention_mask=torch.tensor([attentionMask]),
            return_dict=False)  # tuple output on transformers>=4

        tokens = inputIds[torch.argmax(scoresStart):torch.argmax(scoresEnd) + 1]
        answerTokens = distilBertTokenizer.convert_ids_to_tokens(
            tokens, skip_special_tokens=True)
        return distilBertTokenizer.convert_tokens_to_string(answerTokens)
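
A hypothetical call, with `reader` standing in for an instance of the enclosing class. Since getAnswer reloads both tokenizer and model on every call, caching them in __init__ (as in Example #5) would be considerably faster:

# Hypothetical usage; `reader` is an instance of the enclosing class.
answer = reader.getAnswer(
    "What does the constitution establish?",
    "The constitution establishes the Parliament as the state's law-making body.")
print(answer)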
Example #20
def distilbert_experiment(args, sketch_schedule, tokenizer, train_dataset):
    model_dir = 'drive/distilbert_uncased_output'
    tokenizer_dir = 'drive/distilbert_uncased_output'
    model = DistilBertForQuestionAnswering.from_pretrained(model_dir)
    tokenizer = DistilBertTokenizer.from_pretrained(tokenizer_dir)

    for i in range(len(model.distilbert.transformer.layer)):
        if sketch_schedule[i] < 0:
            continue
        output_m = sketch_schedule[i]
        model.distilbert.transformer.layer[i].attention.q_lin = SketchedLinear(
            model.distilbert.transformer.layer[i].attention.q_lin, m=256)
        model.distilbert.transformer.layer[i].attention.k_lin = SketchedLinear(
            model.distilbert.transformer.layer[i].attention.k_lin, m=256)
        model.distilbert.transformer.layer[i].attention.v_lin = SketchedLinear(
            model.distilbert.transformer.layer[i].attention.v_lin, m=256)

        model.distilbert.transformer.layer[i].ffn.lin2 = SketchedLinear(
            model.distilbert.transformer.layer[i].ffn.lin2, m=output_m)
    model = model.cuda()
    global_step, tr_loss = train(args, train_dataset, model, tokenizer)
    return evaluate(args, model, tokenizer)
Example #21
 def create_model(self, transformer="longformer"):
     if transformer == "distilbert":
         from transformers import DistilBertForQuestionAnswering
         self.model = DistilBertForQuestionAnswering.from_pretrained(
             "distilbert-base-uncased")
     elif transformer == "bert":
         from transformers import BertForQuestionAnswering
         self.model = BertForQuestionAnswering.from_pretrained(
             "bert-base-uncased")
     elif transformer == "roberta":
         from transformers import RobertaForQuestionAnswering
         self.model = RobertaForQuestionAnswering.from_pretrained(
             "roberta-base")
     elif transformer == "roberta_squad":
         from transformers import RobertaForQuestionAnswering
         self.model = RobertaForQuestionAnswering.from_pretrained(
             "deepset/roberta-base-squad2")
     elif transformer == "longformer":
         from transformers import LongformerForQuestionAnswering
         self.model = LongformerForQuestionAnswering.from_pretrained(
             "allenai/longformer-base-4096")
     elif transformer == "bart":
         from transformers import BartForQuestionAnswering
         self.model = BartForQuestionAnswering.from_pretrained(
             "facebook/bart-base")
     elif transformer == "electra":
         from transformers import ElectraForQuestionAnswering
         self.model = ElectraForQuestionAnswering.from_pretrained(
             "google/electra-small-discriminator")
     else:
         print("The model you chose is not available in this version. "
               "You can manually change the code or overwrite self.model.")
         print("Available choices: 'distilbert', 'bert', 'roberta', "
               "'roberta_squad', 'longformer', 'bart', 'electra'.")
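
A minimal usage sketch, with `qa` standing in for an instance of the enclosing class:

# Hypothetical usage; `qa` is an instance of the enclosing class.
qa.create_model(transformer="distilbert")  # sets qa.model to a DistilBERT QA head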
Example #22
    return contexts, questions, ids


device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
    torch.cuda.set_device(DEVICE_ID)  # use an unoccupied GPU
'''
load data
'''
val_contexts, val_questions, val_ids = read_squad('data/dev-v2.0.json')
'''
tokenizers and models
'''
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
model = DistilBertForQuestionAnswering.from_pretrained(
    'distilbert-base-uncased').to(device)
model.load_state_dict(
    torch.load(os.path.join('model_weights',
                            f'distilBERT_epoch_{NUM_EPOCH}.pt'),
               map_location=device))

model.eval()

res = dict()
with torch.no_grad():
    for i, (context, question,
            id) in tqdm(enumerate(zip(val_contexts, val_questions, val_ids))):
        encoding = tokenizer(context,
                             question,
                             return_tensors='pt',
                             truncation=True)
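        # Plausible continuation (not in the original snippet): run the model,
        # pick the best span, and record the decoded answer under the SQuAD id.
        encoding = {k: v.to(device) for k, v in encoding.items()}
        output = model(**encoding)
        start = output.start_logits.argmax()
        end = output.end_logits.argmax() + 1
        answer_ids = encoding['input_ids'][0][start:end]
        res[id] = tokenizer.decode(answer_ids, skip_special_tokens=True)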
Example #23
def main():

    # define parser and arguments
    args = get_train_test_args()

    util.set_seed(args.seed)

    if args.mixture_of_experts and args.do_eval:
        model = MoE(load_gate=True)
        experts = True
        model.gate.eval()

    elif args.mixture_of_experts and args.do_train:
        model = MoE(load_gate=False)
        experts = True

    else:
        model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")
        experts = False

    if args.reinit > 0:
        transformer_temp = getattr(model, 'distilbert')
        for layer in transformer_temp.transformer.layer[-args.reinit:]:
            for module in layer.modules():
                print(type(module))
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=transformer_temp.config.initializer_range)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()

    tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

    if args.do_train:
        if not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)
        args.save_dir = util.get_save_dir(args.save_dir, args.run_name)
        log = util.get_logger(args.save_dir, 'log_train')
        log.info(f'Args: {json.dumps(vars(args), indent=4, sort_keys=True)}')
        log.info("Preparing Training Data...")
        args.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        trainer = Trainer(args, log)
        train_dataset, _ = get_dataset(args, args.train_datasets, args.train_dir, tokenizer, 'train')
        log.info("Preparing Validation Data...")
        val_dataset, val_dict = get_dataset(args, args.train_datasets, args.val_dir, tokenizer, 'val')
        train_loader = DataLoader(train_dataset,
                                batch_size=args.batch_size,
                                sampler=RandomSampler(train_dataset))
        val_loader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                sampler=SequentialSampler(val_dataset))
        best_scores = trainer.train(model, train_loader, val_loader, val_dict, experts)
    if args.do_eval:
        args.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        split_name = 'test' if 'test' in args.eval_dir else 'validation'
        log = util.get_logger(args.save_dir, f'log_{split_name}')
        trainer = Trainer(args, log)
        if args.mixture_of_experts is False:
            checkpoint_path = os.path.join(args.save_dir, 'checkpoint')
            model = DistilBertForQuestionAnswering.from_pretrained(checkpoint_path)
            model.to(args.device)
        eval_dataset, eval_dict = get_dataset(args, args.eval_datasets, args.eval_dir, tokenizer, split_name)
        eval_loader = DataLoader(eval_dataset,
                                 batch_size=args.batch_size,
                                 sampler=SequentialSampler(eval_dataset))
    
        eval_preds, eval_scores = trainer.evaluate(model, eval_loader,
                                                   eval_dict, return_preds=True,
                                                   split=split_name, MoE=True)

        results_str = ', '.join(f'{k}: {v:05.2f}' for k, v in eval_scores.items())
        log.info(f'Eval {results_str}')
        # Write submission file
        sub_path = os.path.join(args.save_dir, split_name + '_' + args.sub_file)
        log.info(f'Writing submission file to {sub_path}...')
        with open(sub_path, 'w', newline='', encoding='utf-8') as csv_fh:
            csv_writer = csv.writer(csv_fh, delimiter=',')
            csv_writer.writerow(['Id', 'Predicted'])
            for uuid in sorted(eval_preds):
                csv_writer.writerow([uuid, eval_preds[uuid]])
Example #24
                v['abstract_full'] += p['text'] + '\n\n'
            
        # looks like in some cases the abstract can be straight up text so we can actually leave that alone
        if isinstance(abs_dirty, str):
            v['abstract_paragraphs'].append(abs_dirty)
            v['abstract_full'] += abs_dirty + '\n\n'            
            
    if print_current:
        if limit_print: print_current = False


import torch
from transformers import DistilBertForQuestionAnswering, DistilBertTokenizer
import numpy as np

model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased-distilled-squad')
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased-distilled-squad')

def makeBERTSQuADPrediction(model, document, question):
    input_ids = tokenizer.encode(question, document)
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    sep_index = input_ids.index(tokenizer.sep_token_id)
    num_seg_a = sep_index + 1
    num_seg_b = len(input_ids) - num_seg_a
    segment_ids = [0]*num_seg_a + [1]*num_seg_b
    assert len(segment_ids) == len(input_ids)
    n_ids = len(segment_ids)
    # TODO: adapted from an example; clarify what start/end positions mean here.
    # We only keep the first 512 tokens because that's DistilBERT's max length.
    start_scores, end_scores = model(torch.tensor([input_ids[:512]]),
                                     return_dict=False)  # tuple output on transformers>=4
    answer_start = torch.argmax(start_scores)
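    # Plausible completion (not in the original snippet): take the best span and
    # detokenize it, merging WordPiece continuation pieces back onto their words.
    answer_end = torch.argmax(end_scores)
    answer = tokens[answer_start]
    for i in range(answer_start + 1, answer_end + 1):
        if tokens[i].startswith('##'):
            answer += tokens[i][2:]
        else:
            answer += ' ' + tokens[i]
    return answer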
Example #25
 def __init__(self, *args, **kwargs):
     super().__init__(*args, **kwargs)
     self.model = DistilBertForQuestionAnswering.from_pretrained(
         self.model_dir)
     self.tokenizer = DistilBertTokenizer.from_pretrained(self.model_dir)
Example #26
dev_dataset = SquadDataset(dev_encodings)

from transformers import DistilBertTokenizerFast
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

train_encodings = tokenizer(train_contexts,
                            train_questions,
                            truncation=True,
                            padding=True)
dev_encodings = tokenizer(dev_contexts,
                          dev_questions,
                          truncation=True,
                          padding=True)

from transformers import DistilBertForQuestionAnswering
model = DistilBertForQuestionAnswering.from_pretrained(
    "distilbert-base-uncased")

from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir='./results',  # output directory
    num_train_epochs=1,  # total number of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=64,  # batch size for evaluation
    warmup_steps=100,  # number of warmup steps for learning rate scheduler
    weight_decay=0.01,  # strength of weight decay
    logging_dir='./logs',  # directory for storing logs
    logging_steps=10,
    load_best_model_at_end=True,
    save_strategy="steps",
    logging_strategy="steps",
Example #27
# argument parsing
app = Flask(__name__)
api = Api(app)
parser = reqparse.RequestParser()
parser.add_argument('question')

N_HITS = 10
# TODO: Analyse the hard-coded keywords and assess if anything needs to change here.
KEYWORDS = ''
# LUCENE_DATABASE_DIR = '/mnt/lucene-database'
LUCENE_DATABASE_PATH = 'lucene-index-covid-2020-04-10'

# Load these models locally - distilbert-base-uncased-distilled-squad
DISTILBERT_MODEL_PATH = 'distilbert-base-uncased-distilled-squad'
model = DistilBertForQuestionAnswering.from_pretrained(DISTILBERT_MODEL_PATH)
tokenizer = DistilBertTokenizer.from_pretrained(DISTILBERT_MODEL_PATH)

# document = "Victoria has a written constitution enacted in 1975, but based on the 1855 colonial constitution, passed by the United Kingdom Parliament as the Victoria Constitution Act 1855, which establishes the Parliament as the state's law-making body for matters coming under state responsibility. The Victorian Constitution can be amended by the Parliament of Victoria, except for certain 'entrenched' provisions that require either an absolute majority in both houses, a three-fifths majority in both houses, or the approval of the Victorian people in a referendum, depending on the provision."
# input_ids = tokenizer.encode('Why is this strange thing here?')
start_positions = torch.tensor([1])
end_positions = torch.tensor([3])
# start_scores, end_scores = model(torch.tensor([input_ids[:512]]))


def makeBERTSQuADPrediction(model, document, question):
    input_ids = tokenizer.encode(question, document)
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    sep_index = input_ids.index(tokenizer.sep_token_id)
    num_seg_a = sep_index + 1
    num_seg_b = len(input_ids) - num_seg_a
Example #28
def main():
    # define parser and arguments
    args = get_train_test_args()

    util.set_seed(args.seed)
    model = DistilBertForQuestionAnswering.from_pretrained(
        "distilbert-base-uncased")
    tokenizer = DistilBertTokenizerFast.from_pretrained(
        'distilbert-base-uncased')

    if args.do_train:
        if not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)
        if args.resume_training:
            checkpoint_path = os.path.join(args.save_dir, 'checkpoint')
            model = DistilBertForQuestionAnswering.from_pretrained(
                checkpoint_path)
            model.to(args.device)
        else:
            args.save_dir = util.get_save_dir(args.save_dir, args.run_name)
        log = util.get_logger(args.save_dir, 'log_train')
        log.info(f'Args: {json.dumps(vars(args), indent=4, sort_keys=True)}')
        log.info("Preparing Training Data...")
        args.device = torch.device(
            'cuda') if torch.cuda.is_available() else torch.device('cpu')
        trainer = Trainer(args, log)
        train_dataset, _ = get_dataset(args, args.train_datasets,
                                       args.train_dir, tokenizer, 'train',
                                       args.outdomain_data_repeat)
        log.info("Preparing Validation Data...")
        val_dataset, val_dict = get_dataset(args, args.train_datasets,
                                            args.val_dir, tokenizer, 'val',
                                            args.outdomain_data_repeat)
        train_loader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  sampler=RandomSampler(train_dataset))
        val_loader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                sampler=RandomSampler(val_dataset))
        best_scores = trainer.train(model, train_loader, val_loader, val_dict)
    if args.do_eval:
        args.device = torch.device(
            'cuda') if torch.cuda.is_available() else torch.device('cpu')
        split_name = 'test' if 'test' in args.eval_dir else 'validation'
        log = util.get_logger(args.save_dir, f'log_{split_name}')
        trainer = Trainer(args, log)
        checkpoint_path = os.path.join(args.save_dir, 'checkpoint')
        model = DistilBertForQuestionAnswering.from_pretrained(checkpoint_path)
        discriminator_input_size = 768
        if args.full_adv:
            discriminator_input_size = 384 * 768
        discriminator = DomainDiscriminator(
            input_size=discriminator_input_size)
        # discriminator.load_state_dict(torch.load(checkpoint_path + '/discriminator'))
        model.to(args.device)
        discriminator.to(args.device)
        eval_dataset, eval_dict = get_dataset(args, args.eval_datasets,
                                              args.eval_dir, tokenizer,
                                              split_name,
                                              args.outdomain_data_repeat)
        eval_loader = DataLoader(eval_dataset,
                                 batch_size=args.batch_size,
                                 sampler=SequentialSampler(eval_dataset))
        eval_preds, eval_scores = trainer.evaluate(model,
                                                   discriminator,
                                                   eval_loader,
                                                   eval_dict,
                                                   return_preds=True,
                                                   split=split_name)
        results_str = ', '.join(f'{k}: {v:05.2f}'
                                for k, v in eval_scores.items())
        log.info(f'Eval {results_str}')
        # Write submission file
        sub_path = os.path.join(args.save_dir,
                                split_name + '_' + args.sub_file)
        log.info(f'Writing submission file to {sub_path}...')
        with open(sub_path, 'w', newline='', encoding='utf-8') as csv_fh:
            csv_writer = csv.writer(csv_fh, delimiter=',')
            csv_writer.writerow(['Id', 'Predicted'])
            for uuid in sorted(eval_preds):
                csv_writer.writerow([uuid, eval_preds[uuid]])
Example #29
def main():
    # define parser and arguments
    args = get_train_test_args()

    util.set_seed(args.seed)

    #### Change Made By Xuran Wang: Comment out original lines #######

    # model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")
    # tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

    #### Change End #######

    #### Change Made By Xuran Wang: Add custom lines #######

    tokenizer = DistilBertTokenizerFast.from_pretrained(
        'distilbert-base-uncased')
    finetuned_model_path = 'save/baseline-01/'

    #### Change End #######

    if args.do_train:
        if not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)
        args.save_dir = util.get_save_dir(args.save_dir, args.run_name)
        log = util.get_logger(args.save_dir, 'log_train')
        log.info(f'Args: {json.dumps(vars(args), indent=4, sort_keys=True)}')
        log.info("Preparing Training Data...")

        #### Change Made By Xuran Wang: Add custom lines #######

        checkpoint_path = os.path.join(finetuned_model_path, 'checkpoint')
        model = DistilBertForQuestionAnswering.from_pretrained(checkpoint_path)

        #### Change End #######
        '''###'''
        # if args.reinit_pooler:
        #     encoder_temp = getattr(model, "distilbert")  # Equivalent to model.distilbert
        #     encoder_temp.pooler.dense.weight.data.normal_(mean=0.0, std=encoder_temp.config.initializer_range)
        #     encoder_temp.pooler.dense.bias.data.zero_()  # The change of encoder_temp would affect the model
        #     for p in encoder_temp.pooler.parameters():
        #         p.requires_grad = True

        if args.reinit_layers > 0:
            import torch.nn as nn
            from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention, FFN
            # model_distilbert = getattr(model, "distilbert")  # model.distilbert; change of model_distilbert affects model!
            # Reinitialization for the last few layers
            for layer in model.distilbert.transformer.layer[-args.reinit_layers:]:
                for module in layer.modules():
                    # print(module)
                    # Equivalent to the manual re-initialization sketched below
                    model.distilbert._init_weights(module)
                    # if isinstance(module, nn.modules.linear.Linear):  # Original form for nn.Linear
                    #     # model.config.initializer_range == model.distilbert.config.initializer_range => True
                    #     module.weight.data.normal_(mean=0.0, std=model.distilbert.config.initializer_range)
                    #     if module.bias is not None:
                    #         module.bias.data.zero_()
                    # elif isinstance(module, nn.modules.normalization.LayerNorm):
                    #     module.weight.data.fill_(1.0)
                    #     module.bias.data.zero_()
                    # elif isinstance(module, FFN):
                    #     for param in [module.lin1, module.lin2]:
                    #         param.weight.data.normal_(mean=0.0, std=model.distilbert.config.initializer_range)
                    #         if param.bias is not None:
                    #             param.bias.data.zero_()
                    # elif isinstance(module, MultiHeadSelfAttention):
                    #     for param in [module.q_lin, module.k_lin, module.v_lin, module.out_lin]:
                    #         param.data.weight.normal_(mean=0.0, std=model.distilbert.config.initializer_range)
                    #         if param.bias is not None:
                    #             param.bias.data.zero_()

        args.device = torch.device(
            'cuda') if torch.cuda.is_available() else torch.device('cpu')
        model.to(args.device)

        trainer = Trainer(args, log)

        #### Change Made By Xuran Wang: Add custom lines, comment out original line #######

        # train_dataset, _ = get_dataset(args, args.train_datasets, args.train_dir, tokenizer, 'train')

        train_dataset, _ = get_dataset_eda_revised(args, args.train_datasets,
                                                   args.train_dir, tokenizer,
                                                   'train', train_fraction)

        #### Change End #######

        log.info("Preparing Validation Data...")
        val_dataset, val_dict = get_dataset(args, args.train_datasets,
                                            args.val_dir, tokenizer, 'val')
        train_loader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  sampler=RandomSampler(train_dataset))
        val_loader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                sampler=SequentialSampler(val_dataset))
        best_scores = trainer.train(model, train_loader, val_loader, val_dict)
    if args.do_eval:
        args.device = torch.device(
            'cuda') if torch.cuda.is_available() else torch.device('cpu')
        split_name = 'test' if 'test' in args.eval_dir else 'validation'
        log = util.get_logger(args.save_dir, f'log_{split_name}')
        trainer = Trainer(args, log)
        checkpoint_path = os.path.join(args.save_dir, 'checkpoint')
        model = DistilBertForQuestionAnswering.from_pretrained(checkpoint_path)
        model.to(args.device)
        eval_dataset, eval_dict = get_dataset(args, args.eval_datasets,
                                              args.eval_dir, tokenizer,
                                              split_name)
        eval_loader = DataLoader(eval_dataset,
                                 batch_size=args.batch_size,
                                 sampler=SequentialSampler(eval_dataset))
        eval_preds, eval_scores = trainer.evaluate(model,
                                                   eval_loader,
                                                   eval_dict,
                                                   return_preds=True,
                                                   split=split_name)
        results_str = ', '.join(f'{k}: {v:05.2f}'
                                for k, v in eval_scores.items())
        log.info(f'Eval {results_str}')
        # Write submission file
        sub_path = os.path.join(args.save_dir,
                                split_name + '_' + args.sub_file)
        log.info(f'Writing submission file to {sub_path}...')
        with open(sub_path, 'w', newline='', encoding='utf-8') as csv_fh:
            csv_writer = csv.writer(csv_fh, delimiter=',')
            csv_writer.writerow(['Id', 'Predicted'])
            for uuid in sorted(eval_preds):
                csv_writer.writerow([uuid, eval_preds[uuid]])
Example #30
def main():
    """
    Main function
    """

    # Parse cmd line arguments
    args = nlp_parser.parse_arguments()

    source = ""
    subject = ""
    context = ""
    question = ""
    answer = ""

    # Setup the question, either from a specified SQuAD record
    # or from cmd line arguments.
    # If no question details are provided, a random
    # SQuAD example will be chosen.
    if args["question"] is not None:
        question = args["question"]
        if args["text"] is not None:
            source = args["text"]
            with open(source, "r") as text_file_handle:
                context = text_file_handle.read()

        else:
            print("No text provided, searching SQuAD dev-2.0 dataset")
            squad_data = nlp.import_squad_data()
            squad_records = squad_data.loc[squad_data["question"] == question]
            if squad_records.empty:
                sys.exit(
                    "Question not found in SQuAD data, please provide context using `--text`."
                )
            subject = squad_records["subject"].iloc[0]
            context = squad_records["context"].iloc[0]
            question = squad_records["question"].iloc[0]
            answer = squad_records["answer"]

    else:
        squad_data = nlp.import_squad_data()

        if args["squadid"] is not None:
            source = args["squadid"]
            squad_records = squad_data.loc[squad_data["id"] == source]
            i_record = 0
        else:
            if args["subject"] is not None:
                print(
                    "Picking a question at random on the subject: ",
                    args["subject"],
                )
                squad_records = squad_data.loc[
                    squad_data["subject"] == args["subject"]
                ]
            else:
                print(
                    "No SQuAD ID or question provided, picking one at random!"
                )
                squad_records = squad_data


        if squad_records.empty:
            sys.exit(
                "No questions found in SQuAD data, please provide valid ID or subject."
            )

        n_records = len(squad_records.index)
        i_record = random.randint(0, n_records - 1)
        source = squad_records["id"].iloc[i_record]
        subject = squad_records["subject"].iloc[i_record]
        context = squad_records["context"].iloc[i_record]
        question = squad_records["question"].iloc[i_record]
        answer = squad_records["answer"].iloc[i_record]

    # DistilBERT question answering using pre-trained model.
    token = DistilBertTokenizer.from_pretrained(
        "distilbert-base-uncased", return_token_type_ids=True
    )

    model = DistilBertForQuestionAnswering.from_pretrained(
        "distilbert-base-uncased-distilled-squad"
    )

    encoding = token.encode_plus(question, context)

    input_ids, attention_mask = (
        encoding["input_ids"],
        encoding["attention_mask"],
    )
    start_scores, end_scores = model(
        torch.tensor([input_ids]),
        attention_mask=torch.tensor([attention_mask]),
        return_dict=False,
    )

    answer_ids = input_ids[
        torch.argmax(start_scores) : torch.argmax(end_scores) + 1
    ]
    answer_tokens = token.convert_ids_to_tokens(
        answer_ids, skip_special_tokens=True
    )
    answer_tokens_to_string = token.convert_tokens_to_string(answer_tokens)

    # Display results
    print("\nDistilBERT question answering example.")
    print("======================================")
    print("Reading from: ", subject, source)
    print("\nContext: ", context)
    print("--")
    print("Question: ", question)
    print("Answer: ", answer_tokens_to_string)
    print("Reference Answers: ", answer)