コード例 #1
0
 def __init__(self, args, char2idx, word2idx):
     """Keep the char/word vocabularies and the length limit taken from ``args``.

     When ``args.use_bert`` is truthy, a lower-casing ``Tokenizer`` is also
     loaded from ``<args.bert_model>/vocab.txt``; otherwise no ``tokenizer``
     attribute is created.
     """
     self.bert = args.use_bert
     self.char2idx, self.word2idx = char2idx, word2idx
     self.max_len = args.max_len
     if not self.bert:
         return
     vocab_path = args.bert_model + '/vocab.txt'
     self.tokenizer = Tokenizer(vocab_path, do_lower_case=True)
コード例 #2
0
 def __init__(self, args, token2idx_dict):
     """Keep the token vocabulary lookup and the length limit taken from ``args``.

     When ``args.use_bert`` is truthy, a ``Tokenizer`` is loaded from
     ``<args.bert_model>/vocab.txt``. Lower-casing is fixed to True here
     rather than read from ``args.do_lower_case``.
     """
     self.bert = args.use_bert
     self.token2idx_dict = token2idx_dict
     self.max_len = args.max_len
     if not self.bert:
         return
     vocab_path = args.bert_model + '/vocab.txt'
     self.tokenizer = Tokenizer(vocab_path, do_lower_case=True)
コード例 #3
0
    def __init__(self, args, data_loaders, examples, char_emb, spo_conf):
        """Set up tokenizer, device, model (resumed from a checkpoint), FGM and optimizer.

        Args:
            args: run configuration; reads ``bert_model``, ``device_id``,
                ``seed``, ``use_bert``, ``train_batch_size`` and ``epoch_num``
                here, plus ``output`` through ``resume``.
            data_loaders: pair ``(train_dataloader, dev_dataloader)``.
            examples: pair ``(train_examples, dev_examples)``.
            char_emb: embeddings handed to the non-BERT ``mpn.ERENet``.
            spo_conf: relation -> id mapping; stored as ``rel2id`` and
                inverted into ``id2rel``.
        """
        print('using ad')
        self.args = args
        vocab_file = args.bert_model + '/vocab.txt'
        self.tokenizer = Tokenizer(vocab_file, do_lower_case=True)
        has_cuda = torch.cuda.is_available()
        device_name = "cuda:{}".format(args.device_id) if has_cuda else "cpu"
        self.device = torch.device(device_name)
        self.n_gpu = torch.cuda.device_count()

        # Keep both directions of the relation vocabulary.
        self.rel2id = spo_conf
        self.id2rel = {rel_id: rel for rel, rel_id in spo_conf.items()}

        if self.n_gpu > 0:
            torch.cuda.manual_seed_all(args.seed)

        if not args.use_bert:
            self.model = mpn.ERENet(args, char_emb, spo_conf)
        else:
            self.model = bert_mpn.ERENet.from_pretrained(
                args.bert_model, classes_num=len(spo_conf))

        # Order matters: move to device, load weights, then (optionally) wrap.
        self.model.to(self.device)
        self.resume(args)
        logging.info('total gpu num is {}'.format(self.n_gpu))
        if self.n_gpu > 1:
            self.model = nn.DataParallel(self.model.cuda(), device_ids=[0, 1])

        self.adversarial_train = FGM(self.model)

        train_dataloader, dev_dataloader = data_loaders
        train_eval, dev_eval = examples
        self.eval_file_choice = dict(train=train_eval, dev=dev_eval)
        self.data_loader_choice = dict(train=train_dataloader,
                                       dev=dev_dataloader)

        # Estimated optimizer step count handed to set_optimizer.
        batches_per_epoch = int(len(train_eval) / args.train_batch_size) + 1
        self.optimizer = set_optimizer(
            args, self.model, train_steps=batches_per_epoch * args.epoch_num)
コード例 #4
0
import logging
import time

import numpy as np
import torch
from torch import nn
from tqdm import tqdm

import models.spo_net.multi_head_select as mhs
from config.spo_config_v1 import BAIDU_RELATION
from layers.decoders.selection import selection_decode
from utils.data_util import Tokenizer
from utils.optimizer_util import set_optimizer

logger = logging.getLogger(__name__)

# Shared tokenizer: SPO below tokenizes subject/object strings with it when
# building its ``spox`` tuple. NOTE(review): the vocab path is relative, so it
# presumably resolves against the current working directory.
tokenizer = Tokenizer('cpt/bert-base-chinese/vocab.txt', do_lower_case=True)


class SPO(tuple):
    """用来存三元组的类
    表现跟tuple基本一致,只是重写了 __hash__ 和 __eq__ 方法,
    使得在判断两个三元组是否等价时容错性更好。
    """
    def __init__(self, spo):
        self.spox = (
            tuple(tokenizer.tokenize(spo[0])),
            spo[1],
            tuple(tokenizer.tokenize(spo[2])),
        )

    def __hash__(self):
コード例 #5
0
class Trainer(object):
    """Train and evaluate the ``ERENet`` extraction model with FGM adversarial steps.

    Handles model construction, checkpoint resume, the training loop with
    dev-set early stopping, and decoding of model outputs into
    ``(subject, relation, object)`` triples for evaluation.
    """

    def __init__(self, args, data_loaders, examples, char_emb, spo_conf):
        """Build tokenizer, device, model, adversarial helper and optimizer.

        Args:
            args: run configuration; reads ``bert_model``, ``device_id``,
                ``seed``, ``use_bert``, ``train_batch_size`` and ``epoch_num``
                here, plus ``output`` through ``resume``.
            data_loaders: pair ``(train_dataloader, dev_dataloader)``.
            examples: pair ``(train_examples, dev_examples)``; evaluation
                indexes them by question id.
            char_emb: embeddings handed to the non-BERT ``mpn.ERENet``.
            spo_conf: relation -> id mapping; stored as ``rel2id`` and
                inverted into ``id2rel`` to name predicted relations.

        ``resume`` is called unconditionally, so ``args.output/pytorch_model.bin``
        must already exist or ``torch.load`` will raise.
        """
        print('using ad')
        self.args = args
        self.tokenizer = Tokenizer(args.bert_model + '/vocab.txt',
                                   do_lower_case=True)
        self.device = torch.device("cuda:{}".format(args.device_id) if torch.
                                   cuda.is_available() else "cpu")
        self.n_gpu = torch.cuda.device_count()

        self.id2rel = {item: key for key, item in spo_conf.items()}
        self.rel2id = spo_conf

        if self.n_gpu > 0:
            torch.cuda.manual_seed_all(args.seed)
        if args.use_bert:
            self.model = bert_mpn.ERENet.from_pretrained(
                args.bert_model, classes_num=len(spo_conf))
        else:
            self.model = mpn.ERENet(args, char_emb, spo_conf)

        self.model.to(self.device)
        self.resume(args)
        logging.info('total gpu num is {}'.format(self.n_gpu))
        if self.n_gpu > 1:
            # NOTE(review): device ids are hard-coded to the first two GPUs.
            self.model = nn.DataParallel(self.model.cuda(), device_ids=[0, 1])

        self.adversarial_train = FGM(self.model)

        train_dataloader, dev_dataloader = data_loaders
        train_eval, dev_eval = examples
        self.eval_file_choice = {
            "train": train_eval,
            "dev": dev_eval,
        }
        self.data_loader_choice = {
            "train": train_dataloader,
            "dev": dev_dataloader,
        }
        # Estimated total optimizer steps handed to set_optimizer.
        self.optimizer = set_optimizer(
            args,
            self.model,
            train_steps=(int(len(train_eval) / args.train_batch_size) + 1) *
            args.epoch_num)

    def train(self, args):
        """Train for ``args.epoch_num`` epochs with dev-F1 early stopping.

        The dev set is evaluated at the end of every epoch and, from epoch
        index 6 on, also every 500 steps. Each evaluation whose F1 ties or
        beats the best so far saves the model; training returns once
        ``args.patience_stop`` consecutive evaluations fail to do so.
        """
        best_f1 = 0.0
        patience_stop = 0
        self.model.train()
        step_gap = 20
        for epoch in range(int(args.epoch_num)):

            # Running loss since the last progress print.
            global_loss = 0.0
            steps_since_log = 0

            for step, batch in tqdm(enumerate(
                    self.data_loader_choice[u"train"]),
                                    mininterval=5,
                                    desc=u'training at epoch : %d ' % epoch,
                                    leave=False,
                                    file=sys.stdout):

                loss = self.forward(batch)
                # Accumulate every step so the printed value is a true average.
                global_loss += loss
                steps_since_log += 1

                if step % step_gap == 0:
                    current_loss = global_loss / steps_since_log
                    print(u"step {} / {} of epoch {}, train/loss: {}".format(
                        step, len(self.data_loader_choice["train"]), epoch,
                        current_loss))
                    global_loss = 0.0
                    steps_since_log = 0

                if step % 500 == 0 and epoch >= 6:
                    best_f1, patience_stop = self._eval_and_checkpoint(
                        args, best_f1, patience_stop)
                    if patience_stop >= args.patience_stop:
                        return

            best_f1, patience_stop = self._eval_and_checkpoint(
                args, best_f1, patience_stop)
            if patience_stop >= args.patience_stop:
                return

    def _eval_and_checkpoint(self, args, best_f1, patience_stop):
        """Evaluate on dev, save on a new best F1, and return updated (best_f1, patience)."""
        res_dev = self.eval_data_set("dev")
        if res_dev['f1'] >= best_f1:
            best_f1 = res_dev['f1']
            logging.info("** ** * Saving fine-tuned model ** ** * ")
            self._save_model(args)
            patience_stop = 0
        else:
            patience_stop += 1
        return best_f1, patience_stop

    def _save_model(self, args):
        """Write the (unwrapped) model state_dict to ``args.output/pytorch_model.bin``."""
        # Unwrap DataParallel so the checkpoint keys match the bare model.
        model_to_save = self.model.module if hasattr(
            self.model, 'module') else self.model
        output_model_file = args.output + "/pytorch_model.bin"
        torch.save(model_to_save.state_dict(), str(output_model_file))

    def resume(self, args):
        """Load model weights from ``args.output/pytorch_model.bin`` (mapped to CPU)."""
        resume_model_file = args.output + "/pytorch_model.bin"
        logging.info("=> loading checkpoint '{}'".format(resume_model_file))
        checkpoint = torch.load(resume_model_file, map_location='cpu')
        self.model.load_state_dict(checkpoint)

    def forward(self, batch, chosen=u'train', eval=False, answer_dict=None):
        """Run one training step, or one inference batch when ``eval`` is True.

        Training: clean backward pass, FGM ``attack``, adversarial backward
        pass, ``restore``, then an optimizer step; returns the clean loss as a
        float.

        Evaluation: ``batch`` is ``(p_ids, input_ids, segment_ids)``; decoded
        predictions are appended into ``answer_dict`` (which must be
        pre-seeded with every question id) and that dict is returned.
        """
        batch = tuple(t.to(self.device) for t in batch)
        if not eval:
            input_ids, segment_ids, token_type_ids, subject_ids, subject_labels, object_labels = batch
            loss = self.model(passage_ids=input_ids,
                              segment_ids=segment_ids,
                              token_type_ids=token_type_ids,
                              subject_ids=subject_ids,
                              subject_labels=subject_labels,
                              object_labels=object_labels)
            if self.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.

            loss.backward()
            # Adversarial pass: perturb, accumulate its gradients, then undo.
            self.adversarial_train.attack()
            loss_adv = self.model(passage_ids=input_ids,
                                  segment_ids=segment_ids,
                                  token_type_ids=token_type_ids,
                                  subject_ids=subject_ids,
                                  subject_labels=subject_labels,
                                  object_labels=object_labels)
            if self.n_gpu > 1:
                loss_adv = loss_adv.mean()
            loss_adv.backward()
            self.adversarial_train.restore()

            loss = loss.item()
            self.optimizer.step()
            self.optimizer.zero_grad()
            return loss
        else:
            p_ids, input_ids, segment_ids = batch
            eval_file = self.eval_file_choice[chosen]
            qids, subject_pred, po_pred = self.model(q_ids=p_ids,
                                                     passage_ids=input_ids,
                                                     segment_ids=segment_ids,
                                                     eval_file=eval_file,
                                                     is_eval=eval)
            ans_dict = self.convert_spo_contour(qids,
                                                subject_pred,
                                                po_pred,
                                                eval_file,
                                                answer_dict,
                                                use_bert=self.args.use_bert)
            return ans_dict

    def eval_data_set(self, chosen="dev"):
        """Predict on the ``chosen`` split, score it with ``evaluate`` and restore train mode."""
        self.model.eval()

        data_loader = self.data_loader_choice[chosen]
        eval_file = self.eval_file_choice[chosen]
        # One [subjects, triples] slot per example; filled by convert_spo_contour.
        answer_dict = {i: [[], []] for i in range(len(eval_file))}

        last_time = time.time()
        with torch.no_grad():
            for _, batch in tqdm(enumerate(data_loader),
                                 mininterval=5,
                                 leave=False,
                                 file=sys.stdout):
                self.forward(batch, chosen, eval=True, answer_dict=answer_dict)
        used_time = time.time() - last_time
        logging.info('chosen {} took : {} sec'.format(chosen, used_time))
        res = self.evaluate(eval_file, answer_dict, chosen)
        self.model.train()
        return res

    def show(self, chosen="dev"):
        """Predict on the ``chosen`` split and print bad cases.

        Uses ``self.badcase_analysis`` when a subclass provides it; otherwise
        falls back to ``evaluate_``, which prints each example whose predicted
        triple set differs from the gold set. Returns the analysis result.
        """
        self.model.eval()

        data_loader = self.data_loader_choice[chosen]
        eval_file = self.eval_file_choice[chosen]
        # convert_spo_contour only fills pre-registered ids, so seed them all.
        answer_dict = {i: [[], []] for i in range(len(eval_file))}
        with torch.no_grad():
            for _, batch in tqdm(enumerate(data_loader),
                                 mininterval=5,
                                 leave=False,
                                 file=sys.stdout):
                self.forward(batch, chosen, eval=True, answer_dict=answer_dict)
        analyse = getattr(self, 'badcase_analysis', None) or self.evaluate_
        return analyse(eval_file, answer_dict, chosen)

    @staticmethod
    def evaluate_(eval_file, answer_dict, chosen):
        """Score predictions by exact (set) match and print every mismatching example.

        Returns ``f1``/``recall``/``precision`` as percentages (0-100) for triples.
        """
        entity_em = 0
        entity_pred_num = 0
        entity_gold_num = 0

        triple_em = 0
        triple_pred_num = 0
        triple_gold_num = 0
        for key, value in answer_dict.items():
            triple_gold = eval_file[key].gold_answer
            entity_gold = eval_file[key].sub_entity_list

            entity_pred, triple_pred = value

            entity_em += len(set(entity_pred) & set(entity_gold))
            entity_pred_num += len(set(entity_pred))
            entity_gold_num += len(set(entity_gold))

            triple_em += len(set(triple_pred) & set(triple_gold))

            if set(triple_pred) != set(triple_gold):
                print(set(triple_pred))
                print(set(triple_gold))
                print('-' * 10)
            triple_pred_num += len(set(triple_pred))
            triple_gold_num += len(set(triple_gold))

        entity_precision = 100.0 * entity_em / entity_pred_num if entity_pred_num > 0 else 0.
        entity_recall = 100.0 * entity_em / entity_gold_num if entity_gold_num > 0 else 0.
        entity_f1 = 2 * entity_recall * entity_precision / (
            entity_recall + entity_precision) if (
                entity_recall + entity_precision) != 0 else 0.0

        precision = 100.0 * triple_em / triple_pred_num if triple_pred_num > 0 else 0.
        recall = 100.0 * triple_em / triple_gold_num if triple_gold_num > 0 else 0.
        f1 = 2 * recall * precision / (recall + precision) if (
            recall + precision) != 0 else 0.0
        print('============================================')
        print("{}/entity_em: {},\tentity_pred_num&entity_gold_num: {}\t{} ".
              format(chosen, entity_em, entity_pred_num, entity_gold_num))
        print("{}/entity_f1: {}, \tentity_precision: {},\tentity_recall: {} ".
              format(chosen, entity_f1, entity_precision, entity_recall))
        print('============================================')
        print("{}/em: {},\tpre&gold: {}\t{} ".format(chosen, triple_em,
                                                     triple_pred_num,
                                                     triple_gold_num))
        print("{}/f1: {}, \tPrecision: {},\tRecall: {} ".format(
            chosen, f1, precision, recall))
        return {'f1': f1, "recall": recall, "precision": precision}

    @staticmethod
    def evaluate(eval_file, answer_dict, chosen):
        """Score predictions, comparing triples through ``SPO`` wrappers.

        Returns ``f1``/``recall``/``precision`` as fractions (0-1) for triples;
        the printed values are scaled by 100.
        """
        entity_em = 0
        entity_pred_num = 0
        entity_gold_num = 0

        # Tiny epsilons keep the ratios below defined when a count is zero.
        X, Y, Z = 1e-10, 1e-10, 1e-10
        for key, value in answer_dict.items():
            triple_gold = eval_file[key].gold_answer
            entity_gold = eval_file[key].sub_entity_list

            entity_pred, triple_pred = value

            entity_em += len(set(entity_pred) & set(entity_gold))
            entity_pred_num += len(set(entity_pred))
            entity_gold_num += len(set(entity_gold))

            R = set([SPO(spo) for spo in triple_pred])
            T = set([SPO(spo) for spo in triple_gold])
            X += len(R & T)
            Y += len(R)
            Z += len(T)

        f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z

        entity_precision = 100.0 * entity_em / entity_pred_num if entity_pred_num > 0 else 0.
        entity_recall = 100.0 * entity_em / entity_gold_num if entity_gold_num > 0 else 0.
        entity_f1 = 2 * entity_recall * entity_precision / (
            entity_recall + entity_precision) if (
                entity_recall + entity_precision) != 0 else 0.0

        print('============================================')
        print("{}/entity_em: {},\tentity_pred_num&entity_gold_num: {}\t{} ".
              format(chosen, entity_em, entity_pred_num, entity_gold_num))
        print("{}/entity_f1: {}, \tentity_precision: {},\tentity_recall: {} ".
              format(chosen, entity_f1, entity_precision, entity_recall))
        print('============================================')
        print("{}/em: {},\tpre&gold: {}\t{} ".format(chosen, X, Y, Z))
        print("{}/f1: {}, \tPrecision: {},\tRecall: {} ".format(
            chosen, f1 * 100, precision * 100, recall * 100))
        return {'f1': f1, "recall": recall, "precision": precision}

    def convert_spo_contour(self,
                            qids,
                            subject_preds,
                            po_preds,
                            eval_file,
                            answer_dict,
                            use_bert=False):
        """Decode a batch of model outputs into strings and append them to ``answer_dict``.

        For each example (``qid == -1`` rows are skipped), ``po_pred`` is
        indexed as ``[position, relation, start/end]``. A span starts where
        the start score > 0.6 and ends at the first position at or after it
        whose end score > 0.5 for the same relation. Spans touching the first
        or last token position are ignored. ``answer_dict[qid][0]`` receives
        the decoded subject and ``answer_dict[qid][1]`` the
        ``(subject, relation_name, object)`` triples.

        ``use_bert`` is accepted for interface compatibility but unused.

        Returns:
            ``answer_dict`` (mutated in place).
        """
        for qid, subject, po_pred in zip(qids.data.cpu().numpy(),
                                         subject_preds.data.cpu().numpy(),
                                         po_preds.data.cpu().numpy()):
            if qid == -1:
                continue
            example = eval_file[qid.item()]
            tokens = example.bert_tokens
            token_ids = example.token_ids
            start = np.where(po_pred[:, :, 0] > 0.6)
            end = np.where(po_pred[:, :, 1] > 0.5)
            last_valid = len(tokens) - 2

            spoes = []
            for _start, predicate1 in zip(*start):
                if _start > last_valid or _start == 0:
                    continue
                # np.where yields positions in ascending order, so the first
                # match is the nearest end for this relation.
                for _end, predicate2 in zip(*end):
                    if _start <= _end <= last_valid and predicate1 == predicate2:
                        spoes.append((subject, predicate1, (_start, _end)))
                        break
            po_predict = []
            for s, p, o in spoes:
                po_predict.append(
                    (self.tokenizer.decode(token_ids[s[0]:s[1] + 1],
                                           tokens[s[0]:s[1] + 1]),
                     self.id2rel[p],
                     self.tokenizer.decode(token_ids[o[0]:o[1] + 1],
                                           tokens[o[0]:o[1] + 1])))

            if qid not in answer_dict:
                print('erro in answer_dict ')
            else:
                answer_dict[qid][0].append(
                    self.tokenizer.decode(token_ids[subject[0]:subject[1] + 1],
                                          tokens[subject[0]:subject[1] + 1]))
                answer_dict[qid][1].extend(po_predict)
        return answer_dict
コード例 #6
0
ファイル: train.py プロジェクト: XingXingXudong/EaR
import logging

import torch
from torch import nn
from tqdm import tqdm

import models.spo_net.multi_head_select as mhs
from config.spo_config_v1 import BAIDU_RELATION
from layers.decoders.selection import selection_decode
from utils.data_util import Tokenizer
from utils.optimizer_util import set_optimizer

logging.basicConfig(filename='train_multi_head_selection.log',
                    level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Shared tokenizer: SPO below tokenizes subject/object strings with it when
# building its ``spox`` tuple. NOTE(review): the vocab path is relative, so it
# presumably resolves against the current working directory.
tokenizer = Tokenizer('transformer_cpt/bert/vocab.txt', do_lower_case=True)


class SPO(tuple):
    """用来存三元组的类
    表现跟tuple基本一致,只是重写了 __hash__ 和 __eq__ 方法,
    使得在判断两个三元组是否等价时容错性更好。
    """
    def __init__(self, spo):
        self.spox = (
            tuple(tokenizer.tokenize(spo[0])),
            spo[1],
            tuple(tokenizer.tokenize(spo[2])),
        )

    def __hash__(self):