def __init__(self, args=None, detect_entities=False):
    if args is None:
        self.args = load_pickle("args.pkl")
    else:
        self.args = args
    self.cuda = torch.cuda.is_available()

    # Optional spaCy NER pipeline for proposing entity spans at inference time.
    self.detect_entities = detect_entities
    if self.detect_entities:
        self.nlp = spacy.load("en_core_web_lg")
    else:
        self.nlp = None
    self.entities_of_interest = [
        "PERSON", "NORP", "FAC", "ORG", "GPE", "LOC", "PRODUCT",
        "EVENT", "WORK_OF_ART", "LAW", "LANGUAGE", "PER",
    ]

    logger.info("Loading tokenizer and model...")
    from .train_funcs import load_state

    if self.args.model_no == 0:
        from model.bert import BertModel as Model

        model = self.args.model_size  # e.g. 'bert-base-uncased'
        model_name = "BERT"
        self.net = Model.from_pretrained(
            model,
            force_download=False,
            model_size=self.args.model_size,
            task="classification",
            n_classes_=self.args.num_classes,
        )
    elif self.args.model_no == 1:
        from model.albert.albert import AlbertModel as Model

        model = self.args.model_size  # e.g. 'albert-base-v2'
        model_name = "ALBERT"
        self.net = Model.from_pretrained(
            model,
            force_download=False,
            model_size=self.args.model_size,
            task="classification",
            n_classes_=self.args.num_classes,
        )
    elif self.args.model_no == 2:  # BioBERT
        from model.bert import BertModel, BertConfig

        model = "bert-base-uncased"
        model_name = "BioBERT"
        config = BertConfig.from_pretrained(
            "./additional_models/biobert_v1.1_pubmed/bert_config.json"
        )
        self.net = BertModel.from_pretrained(
            pretrained_model_name_or_path="./additional_models/biobert_v1.1_pubmed/biobert_v1.1_pubmed.bin",
            config=config,
            force_download=False,
            model_size="bert-base-uncased",
            task="classification",
            n_classes_=self.args.num_classes,
        )

    # The tokenizer (with the entity-marker special tokens already added)
    # is expected to have been pickled during training.
    self.tokenizer = load_pickle("%s_tokenizer.pkl" % model_name)
    self.net.resize_token_embeddings(len(self.tokenizer))
    if self.cuda:
        self.net.cuda()

    # Restore the fine-tuned weights; optimizer and scheduler are not needed
    # for inference, so None is passed for both.
    start_epoch, best_pred, amp_checkpoint = load_state(
        self.net, None, None, self.args, load_best=False
    )
    logger.info("Done!")

    self.e1_id = self.tokenizer.convert_tokens_to_ids("[E1]")
    self.e2_id = self.tokenizer.convert_tokens_to_ids("[E2]")
    self.pad_id = self.tokenizer.pad_token_id
    self.rm = load_pickle("relations.pkl")  # relation id <-> label mapper
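
# A minimal usage sketch for the inference initializer above. The wrapping
# class name (``infer_from_trained``) and its ``infer_sentence`` method are
# assumptions for illustration, not confirmed by this excerpt; it also assumes
# args.pkl, relations.pkl, the tokenizer pickle, and a trained checkpoint
# already exist under ./data/ from a completed training run:
#
#   inferer = infer_from_trained(args=None, detect_entities=True)
#   # Entity spans can be pre-marked with [E1]/[E2] ...
#   inferer.infer_sentence(
#       "The [E1]king[/E1] sat on the [E2]throne[/E2].", detect_entities=False
#   )
#   # ... or proposed automatically by the spaCy NER pipeline loaded above.
#   inferer.infer_sentence("Bill Gates founded Microsoft.", detect_entities=True)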
def __init__(self, args=None):
    if args is None:
        self.args = load_pickle("args.pkl")
    else:
        self.args = args
    self.cuda = torch.cuda.is_available()

    if self.args.model_no == 0:
        from model.bert import BertModel as Model
        from model.bert_tokenizer import BertTokenizer as Tokenizer

        model = self.args.model_size  # e.g. 'bert-base-uncased', 'bert-large-uncased'
        model_name = "BERT"
        self.net = Model.from_pretrained(
            model,
            force_download=False,
            model_size=self.args.model_size,
            task="fewrel",
        )
    elif self.args.model_no == 1:
        from model.albert.albert import AlbertModel as Model
        from model.albert.albert_tokenizer import AlbertTokenizer as Tokenizer

        model = self.args.model_size  # e.g. 'albert-base-v2'
        model_name = "ALBERT"
        self.net = Model.from_pretrained(
            model,
            force_download=False,
            model_size=self.args.model_size,
            task="fewrel",
        )
    elif self.args.model_no == 2:  # BioBERT
        from model.bert import BertModel, BertConfig
        from model.bert_tokenizer import BertTokenizer as Tokenizer

        model = "bert-base-uncased"
        model_name = "BioBERT"
        config = BertConfig.from_pretrained(
            "./additional_models/biobert_v1.1_pubmed/bert_config.json"
        )
        self.net = BertModel.from_pretrained(
            pretrained_model_name_or_path="./additional_models/biobert_v1.1_pubmed/biobert_v1.1_pubmed.bin",
            config=config,
            force_download=False,
            model_size="bert-base-uncased",
            task="fewrel",
        )

    # Reuse the pickled tokenizer from a previous run if present; otherwise
    # build one and register the entity-marker special tokens.
    if os.path.isfile("./data/%s_tokenizer.pkl" % model_name):
        self.tokenizer = load_pickle("%s_tokenizer.pkl" % model_name)
        logger.info("Loaded tokenizer from saved file.")
    else:
        logger.info("Saved tokenizer not found, initializing new tokenizer...")
        if self.args.model_no == 2:
            self.tokenizer = Tokenizer(
                vocab_file="./additional_models/biobert_v1.1_pubmed/vocab.txt",
                do_lower_case=False,
            )
        else:
            self.tokenizer = Tokenizer.from_pretrained(model, do_lower_case=False)
        self.tokenizer.add_tokens(["[E1]", "[/E1]", "[E2]", "[/E2]", "[BLANK]"])
        save_as_pickle("%s_tokenizer.pkl" % model_name, self.tokenizer)
        logger.info(
            "Saved %s tokenizer at ./data/%s_tokenizer.pkl"
            % (model_name, model_name)
        )

    self.net.resize_token_embeddings(len(self.tokenizer))
    self.pad_id = self.tokenizer.pad_token_id
    if self.cuda:
        self.net.cuda()

    if self.args.use_pretrained_blanks == 1:
        logger.info(
            "Loading model pre-trained on blanks at ./data/test_checkpoint_%d.pth.tar..."
            % self.args.model_no
        )
        checkpoint_path = "./data/test_checkpoint_%d.pth.tar" % self.args.model_no
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        checkpoint = torch.load(
            checkpoint_path, map_location=None if self.cuda else "cpu"
        )
        # Copy over only the weights whose names match the current model, so
        # task-specific heads absent from the checkpoint keep their fresh
        # initialization.
        model_dict = self.net.state_dict()
        pretrained_dict = {
            k: v
            for k, v in checkpoint["state_dict"].items()
            if k in model_dict.keys()
        }
        model_dict.update(pretrained_dict)
        self.net.load_state_dict(model_dict)
        del checkpoint, pretrained_dict, model_dict

    logger.info("Loading Fewrel dataloaders...")
    self.train_loader, _, self.train_length, _ = load_dataloaders(self.args)
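
# A minimal usage sketch for the FewRel initializer above. The wrapping class
# name (``FewRel``) is an assumption for illustration; it assumes ./data/args.pkl
# exists and, when args.use_pretrained_blanks == 1, that a matching-the-blanks
# pre-training checkpoint has been saved at ./data/test_checkpoint_<model_no>.pth.tar:
#
#   fewrel = FewRel(args=None)             # restores args.pkl, model, tokenizer
#   for batch in fewrel.train_loader:      # episodes from load_dataloaders(...)
#       ...                                # feed each FewRel episode to fewrel.net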