Example #1
# Imports assumed to match the FewRel-master layout referenced by the
# checkpoint paths below; dotmap and transformers are third-party packages.
import gc
import math
import os

import spacy
import torch
from torch import optim
from dotmap import DotMap
from transformers import AdamW

from fewshot_re_kit.data_loader import get_loader_pair
from fewshot_re_kit.framework import FewShotREFramework
from fewshot_re_kit.sentence_encoder import BERTPAIRSentenceEncoder
from models.pair import Pair

def main():
    opt = {'train': 'train_wiki',
           'val': 'val_wiki',
           'test': 'val_semeval',
           'adv': None,
           'trainN': 10,
           'N': 5,
           'K': 1,
           'Q': 1,
           'batch_size': 1,
           'train_iter': 5,
           'val_iter': 1000,
           'test_iter': 10,
           'val_step': 2000,
           'model': 'pair',
           'encoder': 'bert',
           'max_length': 64,
           'lr': -1,  # -1 = pick a default below based on the encoder
           'weight_decay': 1e-5,
           'dropout': 0.0,
           'na_rate': 0,
           'optim': 'adam',
           'load_ckpt': './src/FewRel-master/Checkpoints/ckpt_5-Way-1-Shot_FewRel.pth',
           'save_ckpt': './src/FewRel-master/Checkpoints/post_ckpt_5-Way-1-Shot_FewRel.pth',
           'fp16': False,
           'only_test': False,
           'ckpt_name': 'Checkpoints/ckpt_5-Way-1-Shot_FewRel.pth',
           'pair': True,
           'pretrain_ckpt': '',
           'cat_entity_rep': False,
           'dot': False,
           'no_dropout': False,
           'mask_entity': False,
           'use_sgd_for_bert': False,
           'hidden_size': 768  # hidden size of bert-base; read by Pair() below
           }
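    # DotMap wraps the plain dict so options can be read attribute-style
    # (opt.trainN, opt.model, ...).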
    opt = DotMap(opt)
    trainN = opt.trainN
    N = opt.N
    K = opt.K
    Q = opt.Q
    batch_size = opt.batch_size
    model_name = opt.model
    encoder_name = opt.encoder
    max_length = opt.max_length
    
    print("{}-way-{}-shot Few-Shot Relation Classification".format(N, K))
    print("model: {}".format(model_name))
    print("encoder: {}".format(encoder_name))
    print("max_length: {}".format(max_length))

    pretrain_ckpt = opt.pretrain_ckpt or 'bert-base-uncased'
    sentence_encoder = BERTPAIRSentenceEncoder(
            pretrain_ckpt,
            max_length)
    
    train_data_loader = get_loader_pair(opt.train, sentence_encoder,
            N=trainN, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size, encoder_name=encoder_name)
    val_data_loader = get_loader_pair(opt.val, sentence_encoder,
            N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size, encoder_name=encoder_name)
    test_data_loader = get_loader_pair(opt.test, sentence_encoder,
            N=N, K=K, Q=Q, na_rate=opt.na_rate, batch_size=batch_size, encoder_name=encoder_name)
   
    if opt.optim == 'sgd':
        pytorch_optim = optim.SGD
    elif opt.optim == 'adam':
        pytorch_optim = optim.Adam
    elif opt.optim == 'adamw':
        pytorch_optim = AdamW
    else:
        raise NotImplementedError
    framework = FewShotREFramework(train_data_loader, val_data_loader, test_data_loader)
        
    prefix = '-'.join([model_name, encoder_name, opt.train, opt.val, str(N), str(K)])
    if opt.na_rate != 0:
        prefix += '-na{}'.format(opt.na_rate)
    if opt.dot:
        prefix += '-dot'
    if opt.cat_entity_rep:
        prefix += '-catentity'
    if len(opt.ckpt_name) > 0:
        prefix += '-' + opt.ckpt_name

    model = Pair(sentence_encoder, hidden_size=opt.hidden_size)
    
    if not os.path.exists('checkpoint'):
        os.mkdir('checkpoint')
    ckpt = 'checkpoint/{}.pth.tar'.format(prefix)
    if opt.save_ckpt:
        ckpt = opt.save_ckpt

    if torch.cuda.is_available():
        model.cuda()

    if not opt.only_test:
        if encoder_name in ['bert', 'roberta']:
            bert_optim = True
        else:
            bert_optim = False

        if opt.lr == -1:
            if bert_optim:
                opt.lr = 2e-5
            else:
                opt.lr = 1e-1

        framework.train(model, prefix, batch_size, trainN, N, K, Q,
                pytorch_optim=pytorch_optim, load_ckpt=opt.load_ckpt, save_ckpt=ckpt,
                na_rate=opt.na_rate, val_step=opt.val_step, fp16=opt.fp16, pair=opt.pair, 
                train_iter=opt.train_iter, val_iter=opt.val_iter, bert_optim=bert_optim, 
                learning_rate=opt.lr, use_sgd_for_bert=opt.use_sgd_for_bert)
    else:
        ckpt = opt.load_ckpt
        if ckpt is None:
            print("Warning: --load_ckpt is not specified. Will load Huggingface pre-trained checkpoint.")
            ckpt = 'none'

    acc = framework.eval(model, batch_size, N, K, Q, opt.test_iter, na_rate=opt.na_rate, ckpt=ckpt, pair=opt.pair)
    print("RESULT: %.2f" % (acc * 100))
class Detector:
    # Runs on the GPU when one is available and PyTorch was built with CUDA support.

    def __init__(self, chpt_path, max_length=128):
        """
            Initializer
        """
        self.checkpoint_path = chpt_path
        self.bert_pretrained_checkpoint = 'bert-base-uncased'
        self.max_length = max_length
        self.sentence_encoder = BERTPAIRSentenceEncoder(
            self.bert_pretrained_checkpoint, self.max_length)

        self.model = Pair(self.sentence_encoder, hidden_size=768)
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        self.model.eval()

        #         self.nlp_coref = spacy.load("en_core_web_sm")
        #         neuralcoref.add_to_pipe(self.nlp_coref)
        self.nlp_no_coref = spacy.load("en_core_web_sm")
        self.load_model()

    def __load_model_from_checkpoint__(self, ckpt):
        """
        ckpt: path to the checkpoint file
        return: the loaded checkpoint dict
        """
        if os.path.isfile(ckpt):
            # map_location lets a GPU-saved checkpoint load on a CPU-only machine
            checkpoint = torch.load(
                ckpt, map_location=None if torch.cuda.is_available() else 'cpu')
            print("Successfully loaded checkpoint '%s'" % ckpt)
            return checkpoint
        else:
            raise Exception("No checkpoint found at '%s'" % ckpt)

    def bert_tokenize(self, tokens, head_indices, tail_indices):
        word = self.sentence_encoder.tokenize(tokens, head_indices,
                                              tail_indices)
        return word

    def load_model(self):
        """
            Loads the model from the checkpoint
        """
        state_dict = self.__load_model_from_checkpoint__(
            self.checkpoint_path)['state_dict']
        own_state = self.model.state_dict()
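        # Copy only parameters whose names match the current model; any
        # extra entries in the checkpoint are skipped.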
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)

#     def spacy_tokenize_coref(self,sentence):
#         """
#             Tokenizes the sentence using spacy
#         """
#         return list(map(str, self.nlp_coref(sentence)))

    def spacy_tokenize_no_coref(self, sentence):
        """
            Tokenizes the sentence using spacy
        """
        try:
            return list(map(str, self.nlp_no_coref(sentence)))
        except TypeError as e:
            print("problem sentence: '{}'".format(sentence))
            raise e

#     def get_head_tail_pairs(self,sentence):
#         """
#             Gets pairs of heads and tails of named entities so that relation identification can be done on these.
#         """
#         acceptable_entity_types = ['PERSON', 'NORP', 'ORG', 'GPE', 'PRODUCT', 'EVENT', 'LAW', 'LOC', 'FAC']
#         doc = self.nlp_coref(sentence)
#         entity_info = [(X.text, X.label_) for X in doc.ents]
#         entity_info = set(map(lambda x:x[0], filter(lambda x:x[1] in acceptable_entity_types, entity_info)))

#         return combinations(entity_info, 2)

    def _get_indices_alt(self, tokens, tokenized_head, tokenized_tail):
        """
            Alternative implemention for getting the indices of the head and tail if exact matches cannot be done.
        """
        head_indices = None
        tail_indices = None
        for i in range(len(tokens)):
            if tokens[i] in tokenized_head[0] or tokenized_head[0] in tokens[i]:
                broke = False
                for k, j in zip(tokens[i:i + len(tokenized_head)],
                                tokenized_head):
                    if k not in j and j not in k:
                        broke = True
                        break
                if not broke:
                    head_indices = list(range(i, i + len(tokenized_head)))
                    break
        for i in range(len(tokens)):
            if tokens[i] in tokenized_tail[0] or tokenized_tail[0] in tokens[i]:
                broke = False
                for k, j in zip(tokens[i:i + len(tokenized_tail)],
                                tokenized_tail):
                    if k not in j and j not in k:
                        broke = True
                        break
                if not broke:
                    tail_indices = list(range(i, i + len(tokenized_tail)))
                    break
        return head_indices, tail_indices

    def _calculate_conf(self, logits, order, pred):
        # Softmax over the raw logits, expressed as a percentage; the final
        # logit is the NA ("no relation") class.
        exp = [math.exp(float(i)) for i in logits[0][0]]
        if pred == 'NA':
            return exp[-1] * 100 / sum(exp)
        return exp[order.index(pred)] * 100 / sum(exp)

    def run_detection_algorithm(self, query, relation_data):
        """
            Runs the algorithm/model on the given query using the given support data.
        """
        N = len(relation_data)
        K = len(relation_data[0]['examples'])
        Q = 1
        head = query['head']
        tail = query['tail']
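
        # Each support example is paired with the query as one BERT "pair"
        # input; the model then scores the query against every example.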
        fusion_set = {'word': [], 'mask': [], 'seg': []}
        tokens = self.spacy_tokenize_no_coref(query['sentence'])

        print("head: '{}' tail: '{}' sentence: '{}'".format(
            head, tail, query['sentence']))

        tokenized_head = self.spacy_tokenize_no_coref(head)
        tokenized_tail = self.spacy_tokenize_no_coref(tail)

        head_indices = None
        tail_indices = None
        for i in range(len(tokens)):
            if tokens[i] == tokenized_head[0] and tokens[
                    i:i + len(tokenized_head)] == tokenized_head:
                head_indices = list(range(i, i + len(tokenized_head)))
                break
        for i in range(len(tokens)):
            if tokens[i] == tokenized_tail[0] and tokens[
                    i:i + len(tokenized_tail)] == tokenized_tail:
                tail_indices = list(range(i, i + len(tokenized_tail)))
                break

        if head_indices is None or tail_indices is None:
            head_indices, tail_indices = self._get_indices_alt(
                tokens, tokenized_head, tokenized_tail)

        if head_indices is None or tail_indices is None:
            print(tokenized_head)
            print(tokenized_tail)
            print(tokens)
            raise ValueError(
                "Head/Tail indices error: head: {} \n tail: {} \n sentence: {}"
                .format(head, tail, query['sentence']))

        bert_query_tokens = self.bert_tokenize(tokens, head_indices,
                                               tail_indices)
        for relation in relation_data:
            for ex in relation['examples']:
                tokens = self.spacy_tokenize_no_coref(ex['sentence'])
                tokenized_head = self.spacy_tokenize_no_coref(
                    ex['head']
                )  # head/tail spelling and punctuation must match the coreference-resolved output exactly
                tokenized_tail = self.spacy_tokenize_no_coref(ex['tail'])

                head_indices = None
                tail_indices = None
                for i in range(len(tokens)):
                    if tokens[i] == tokenized_head[0] and tokens[
                            i:i + len(tokenized_head)] == tokenized_head:
                        head_indices = list(range(i, i + len(tokenized_head)))
                        break
                for i in range(len(tokens)):
                    if tokens[i] == tokenized_tail[0] and tokens[
                            i:i + len(tokenized_tail)] == tokenized_tail:
                        tail_indices = list(range(i, i + len(tokenized_tail)))
                        break
                if head_indices is None or tail_indices is None:
                    raise ValueError(
                        "Head/Tail indices error: head: {} \n tail: {} \n sentence: {}"
                        .format(ex['head'], ex['tail'], ex['sentence']))

                bert_relation_example_tokens = self.bert_tokenize(
                    tokens, head_indices, tail_indices)

                SEP = self.sentence_encoder.tokenizer.convert_tokens_to_ids(
                    ['[SEP]'])
                CLS = self.sentence_encoder.tokenizer.convert_tokens_to_ids(
                    ['[CLS]'])
                word_tensor = torch.zeros((self.max_length)).long()

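                # Full pair input: [CLS] support-example tokens [SEP] query tokens [SEP]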
                new_word = CLS + bert_relation_example_tokens + SEP + bert_query_tokens + SEP
                for i in range(min(self.max_length, len(new_word))):
                    word_tensor[i] = new_word[i]
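                # Attention mask: 1 over real tokens, 0 over padding.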
                mask_tensor = torch.zeros((self.max_length)).long()
                mask_tensor[:min(self.max_length, len(new_word))] = 1
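                # Segment ids: 0 for [CLS] plus the support example, 1 from
                # the first [SEP] onward (covers the query).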
                seg_tensor = torch.ones((self.max_length)).long()
                seg_tensor[:min(self.max_length,
                                len(bert_relation_example_tokens) + 1)] = 0
                fusion_set['word'].append(word_tensor)
                fusion_set['mask'].append(mask_tensor)
                fusion_set['seg'].append(seg_tensor)

        fusion_set['word'] = torch.stack(fusion_set['word'])
        fusion_set['seg'] = torch.stack(fusion_set['seg'])
        fusion_set['mask'] = torch.stack(fusion_set['mask'])

        if torch.cuda.is_available():
            fusion_set['word'] = fusion_set['word'].cuda()
            fusion_set['seg'] = fusion_set['seg'].cuda()
            fusion_set['mask'] = fusion_set['mask'].cuda()

        logits, pred = self.model(fusion_set, N, K, Q)
        gc.collect()
        order = list(r['name'] for r in relation_data)
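        # A predicted index beyond the given relations is the extra NA class.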
        pred_relation = relation_data[
            pred.item()]['name'] if pred.item() < len(relation_data) else 'NA'
        return {
            'sentence': query['sentence'],
            'head': head,
            'tail': tail,
            'pred_relation': pred_relation,
            'conf': int(self._calculate_conf(logits, order, pred_relation))
        }  # returns the sentence, head, tail, predicted relation name, and confidence (%)

    def print_result(self, sentence, head, tail, prediction):
        """
            Helper function to print the results to the stdout.
        """
        print('Sentence: \"{}\", head: \"{}\", tail: \"{}\", prediction: {}'.
              format(sentence, head, tail, prediction))
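
# A hypothetical usage sketch (not from the original source): the checkpoint
# path is the one referenced in main() above, and relation_data follows the
# structure run_detection_algorithm expects -- a list of relations, each with
# a 'name' and support 'examples' carrying 'sentence', 'head', and 'tail'.
def demo():
    detector = Detector('./src/FewRel-master/Checkpoints/ckpt_5-Way-1-Shot_FewRel.pth')
    relation_data = [
        {'name': 'founded_by',
         'examples': [{'sentence': 'Apple was founded by Steve Jobs.',
                       'head': 'Apple', 'tail': 'Steve Jobs'}]},
        {'name': 'capital_of',
         'examples': [{'sentence': 'Paris is the capital of France.',
                       'head': 'Paris', 'tail': 'France'}]},
    ]
    query = {'sentence': 'Microsoft was founded by Bill Gates.',
             'head': 'Microsoft', 'tail': 'Bill Gates'}
    result = detector.run_detection_algorithm(query, relation_data)
    detector.print_result(result['sentence'], result['head'], result['tail'],
                          result['pred_relation'])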