Exemplo n.º 1
0
    def __init__(self, args=None, detect_entities=False):
        """Load a trained relation-classification model for inference.

        Args:
            args: Namespace of run options. When ``None``, the options saved
                in ``args.pkl`` (read via ``load_pickle``) are used instead.
                Attributes read here: ``model_no``, ``model_size`` and
                ``num_classes``; the namespace is also passed to
                ``load_state``.
            detect_entities: When True, load the spaCy ``en_core_web_lg``
                pipeline into ``self.nlp``; otherwise ``self.nlp`` is None.

        Raises:
            ValueError: If ``model_no`` is not 0 (BERT), 1 (ALBERT) or
                2 (BioBERT).
        """
        # Every option below is read from self.args, never from the raw
        # ``args`` parameter: with args=None the pickled options are the
        # only source, and ``args`` itself is None.
        if args is None:
            self.args = load_pickle("args.pkl")
        else:
            self.args = args
        self.cuda = torch.cuda.is_available()
        self.detect_entities = detect_entities

        if self.detect_entities:
            self.nlp = spacy.load("en_core_web_lg")
        else:
            self.nlp = None
        self.entities_of_interest = [
            "PERSON",
            "NORP",
            "FAC",
            "ORG",
            "GPE",
            "LOC",
            "PRODUCT",
            "EVENT",
            "WORK_OF_ART",
            "LAW",
            "LANGUAGE",
            "PER",
        ]

        logger.info("Loading tokenizer and model...")
        from .train_funcs import load_state

        model_no = self.args.model_no
        if model_no == 0:
            from model.bert import BertModel as Model

            model_name = "BERT"
            self.net = Model.from_pretrained(
                self.args.model_size,  # e.g. 'bert-base-uncased'
                force_download=False,
                model_size=self.args.model_size,
                task="classification",
                n_classes_=self.args.num_classes,
            )
        elif model_no == 1:
            from model.albert.albert import AlbertModel as Model

            # NOTE(review): ALBERT reuses the "BERT" tokenizer pickle name.
            # This is kept because it presumably matches the name the
            # tokenizer was saved under. It does mean BERT and ALBERT runs
            # share one cached tokenizer file; confirm this is intended.
            model_name = "BERT"
            self.net = Model.from_pretrained(
                self.args.model_size,  # e.g. 'albert-base-v2'
                force_download=False,
                model_size=self.args.model_size,
                task="classification",
                n_classes_=self.args.num_classes,
            )
        elif model_no == 2:  # BioBERT, loaded from local files
            from model.bert import BertModel, BertConfig

            model_name = "BioBERT"
            config = BertConfig.from_pretrained(
                "./additional_models/biobert_v1.1_pubmed/bert_config.json"
            )
            self.net = BertModel.from_pretrained(
                pretrained_model_name_or_path="./additional_models/biobert_v1.1_pubmed/biobert_v1.1_pubmed.bin",
                config=config,
                force_download=False,
                model_size="bert-base-uncased",
                task="classification",
                n_classes_=self.args.num_classes,
            )
        else:
            # Without this, self.net / model_name would be unbound below.
            raise ValueError(
                "Unsupported model_no %r; expected 0 (BERT), 1 (ALBERT) "
                "or 2 (BioBERT)." % (model_no,)
            )

        self.tokenizer = load_pickle("%s_tokenizer.pkl" % model_name)
        # Size the embedding matrix to the tokenizer's vocabulary. That
        # vocabulary presumably includes added marker tokens such as [E1]/[E2],
        # which are looked up below.
        self.net.resize_token_embeddings(len(self.tokenizer))
        if self.cuda:
            self.net.cuda()
        # Called for its effect on self.net (presumably restoring saved
        # weights). The returned training bookkeeping is not used here.
        _start_epoch, _best_pred, _amp_checkpoint = load_state(
            self.net, None, None, self.args, load_best=False
        )
        logger.info("Done!")

        self.e1_id = self.tokenizer.convert_tokens_to_ids("[E1]")
        self.e2_id = self.tokenizer.convert_tokens_to_ids("[E2]")
        self.pad_id = self.tokenizer.pad_token_id
        self.rm = load_pickle("relations.pkl")
Exemplo n.º 2
0
    def __init__(self, args=None):
        """Build the FewRel model, its tokenizer and the training dataloader.

        Args:
            args: Namespace of run options. When ``None``, the options saved
                in ``args.pkl`` (read via ``load_pickle``) are used instead.
                Attributes read here: ``model_no``, ``model_size`` and
                ``use_pretrained_blanks``. The namespace is also passed to
                ``load_dataloaders``.

        Raises:
            ValueError: If ``model_no`` is not 0 (BERT), 1 (ALBERT) or
                2 (BioBERT).
        """
        # Every option below is read from self.args, never from the raw
        # ``args`` parameter, which is None when the pickled options are used.
        if args is None:
            self.args = load_pickle("args.pkl")
        else:
            self.args = args
        self.cuda = torch.cuda.is_available()

        model_no = self.args.model_no
        if model_no == 0:
            from model.bert import BertModel as Model
            from model.bert_tokenizer import BertTokenizer as Tokenizer

            model = self.args.model_size  # 'bert-large-uncased' 'bert-base-uncased'
            model_name = "BERT"
            self.net = Model.from_pretrained(
                model,
                force_download=False,
                model_size=self.args.model_size,
                task="fewrel",
            )
        elif model_no == 1:
            from model.albert.albert import AlbertModel as Model
            from model.albert.albert_tokenizer import (
                AlbertTokenizer as Tokenizer,
            )

            model = self.args.model_size  # 'albert-base-v2'
            # NOTE(review): ALBERT shares the "BERT" tokenizer pickle name
            # with BERT, so switching model_no between 0 and 1 can reuse the
            # wrong cached tokenizer. It is kept so that existing pickles
            # keep loading; confirm this is intended.
            model_name = "BERT"
            self.net = Model.from_pretrained(
                model,
                force_download=False,
                model_size=self.args.model_size,
                task="fewrel",
            )
        elif model_no == 2:  # BioBERT, loaded from local files
            from model.bert import BertModel, BertConfig
            from model.bert_tokenizer import BertTokenizer as Tokenizer

            model = "bert-base-uncased"
            model_name = "BioBERT"
            config = BertConfig.from_pretrained(
                "./additional_models/biobert_v1.1_pubmed/bert_config.json"
            )
            self.net = BertModel.from_pretrained(
                pretrained_model_name_or_path="./additional_models/biobert_v1.1_pubmed/biobert_v1.1_pubmed.bin",
                config=config,
                force_download=False,
                model_size="bert-base-uncased",
                task="fewrel",
            )
        else:
            # Without this, self.net / model_name / Tokenizer would be unbound.
            raise ValueError(
                "Unsupported model_no %r; expected 0 (BERT), 1 (ALBERT) "
                "or 2 (BioBERT)." % (model_no,)
            )

        # Reuse a cached tokenizer when one exists. Otherwise build it, add the
        # entity-marker / blank tokens, and cache it for later runs.
        if os.path.isfile("./data/%s_tokenizer.pkl" % model_name):
            self.tokenizer = load_pickle("%s_tokenizer.pkl" % model_name)
            logger.info("Loaded tokenizer from saved file.")
        else:
            logger.info(
                "Saved tokenizer not found, initializing new tokenizer..."
            )
            if model_no == 2:
                self.tokenizer = Tokenizer(
                    vocab_file="./additional_models/biobert_v1.1_pubmed/vocab.txt",
                    do_lower_case=False,
                )
            else:
                self.tokenizer = Tokenizer.from_pretrained(
                    model, do_lower_case=False
                )
            self.tokenizer.add_tokens(
                ["[E1]", "[/E1]", "[E2]", "[/E2]", "[BLANK]"]
            )
            save_as_pickle("%s_tokenizer.pkl" % model_name, self.tokenizer)
            logger.info(
                "Saved %s tokenizer at ./data/%s_tokenizer.pkl"
                % (model_name, model_name)
            )

        # Grow the embedding matrix to cover the added tokens.
        self.net.resize_token_embeddings(len(self.tokenizer))
        self.pad_id = self.tokenizer.pad_token_id

        if self.cuda:
            self.net.cuda()

        if self.args.use_pretrained_blanks == 1:
            logger.info(
                "Loading model pre-trained on blanks at ./data/test_checkpoint_%d.pth.tar..."
                % self.args.model_no
            )
            checkpoint_path = (
                "./data/test_checkpoint_%d.pth.tar" % self.args.model_no
            )
            # Deserialize onto the CPU so a GPU-saved checkpoint also loads on
            # CPU-only hosts. load_state_dict copies each tensor onto the
            # device of the parameter it fills.
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
            model_dict = self.net.state_dict()
            # Keep only the checkpoint entries whose names exist in this model.
            # strict=False leaves parameters missing from the checkpoint at
            # their current values.
            pretrained_dict = {
                k: v
                for k, v in checkpoint["state_dict"].items()
                if k in model_dict
            }
            self.net.load_state_dict(pretrained_dict, strict=False)
            del checkpoint, pretrained_dict, model_dict

        logger.info("Loading Fewrel dataloaders...")
        self.train_loader, _, self.train_length, _ = load_dataloaders(self.args)