Example #1
import os
import pickle

import torch
# fastNLP imports assumed for this example; exact module paths may differ across fastNLP versions
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.metrics import SpanFPreRecMetric
# _format_eval_results is assumed to be provided elsewhere in the project
# (fastNLP's Tester has an equivalent helper)


def dump_model_result(config, model):
    # load the pickled tag vocabulary and dev set produced during preprocessing
    tag_vocab = pickle.load(open(os.path.join(config.data_path, config.tag_vocab_name), 'rb'))
    metrics = SpanFPreRecMetric(tag_vocab, pred='pred', seq_len='seq_len', target='tag')
    dev_data = pickle.load(open(os.path.join(config.data_path, config.dev_name), "rb"))
    # iterate over the dev set in order, one batch at a time
    data_iterator = Batch(dev_data, config.batch_size, sampler=SequentialSampler(), as_numpy=False)
    model.cuda()

    eval_results = {}
    # expose the extra fields in the batches below; the remaining input/target flags are
    # assumed to have been set when the DataSet was pickled
    dev_data.set_input('tag')
    dev_data.set_target('seq_len')
    with torch.no_grad():
        for i, (batch_x, batch_y) in enumerate(data_iterator):
            print('batch', i)
            # move the input fields and the gold tags onto the GPU
            char = batch_x['char'].cuda()
            word = batch_x['word'].cuda()
            pos = batch_x['pos'].cuda()
            spo = batch_x['spo'].cuda()
            seq_len = batch_x['seq_len'].cuda()

            tag = batch_y['tag'].cuda()
            
            # run the model in inference mode and feed the batch to the span metric
            pred = model.predict(char, word, pos, spo, seq_len)
            metrics({'pred': pred['pred'].cuda(), 'seq_len': seq_len}, {'tag': tag})
        # aggregate precision/recall/F1 over all batches
        eval_result = metrics.get_metric()
        metric_name = metrics.__class__.__name__
        eval_results[metric_name] = eval_result

    print("[tester] \n{}".format(_format_eval_results(eval_results)))
Example #2
import torch
from openpyxl import Workbook
# fastNLP imports assumed for this example; exact module paths may differ across fastNLP versions
from fastNLP import Vocabulary
from fastNLP.core.metrics import SpanFPreRecMetric


class CommonSeqEvaluator(BaseSeqEvaluator):
    # BaseSeqEvaluator is defined elsewhere in the project

    def __init__(self, tag_vocab, config):
        super(CommonSeqEvaluator, self).__init__()
        self._config = config
        # rebuild a fastNLP Vocabulary from the torchtext-style tag vocab (stoi: tag -> index)
        self._vocab = Vocabulary()
        self._vocab.add_word_lst(tag_vocab.stoi.keys())
        self._evaluator = SpanFPreRecMetric(self._vocab, only_gross=False, f_type=config.evaluation.type)
        self._pad_index = tag_vocab.stoi['<pad>']

    def _change_type(self, pred, target):
        # pad every predicted sequence to the batch maximum with the pad index,
        # then convert the predictions to a tensor on the configured device
        seq_len = torch.tensor([len(text) for text in pred])
        max_len = int(max(seq_len))
        for text in pred:
            if len(text) < max_len:
                text.extend([self._pad_index] * (max_len - len(text)))
        pred = torch.tensor(pred).to(self._config.device)
        return pred, target, seq_len

    def evaluate(self, pred, target):
        # feed one batch of predictions and gold tags into the span metric
        pred, target, seq_len = self._change_type(pred, target)
        self._evaluator.evaluate(pred, target, seq_len)

    def _get_eval_result(self):
        # aggregate the metric over all batches evaluated so far
        eval_dict = self._evaluator.get_metric()
        # strip the scheme prefix ("B-"/"I-") from every tag; the first three vocabulary
        # entries are assumed to be special tokens and are skipped
        names = list(set([label[2:] for label in self._vocab.word2idx.keys()][3:]))
        if self._config.data.chip_relation.use_chip_relation and '其他' in names:
            # drop the generic "其他" ("other") label when evaluating CHIP relations
            names.remove('其他')
        head = ['label', '   precision', '   recall', '  F1_score']
        table = [head]
        # one row per label with its precision / recall / F1
        for name in names:
            ps = str(round(eval_dict['pre-' + name.lower()], 3))
            rs = str(round(eval_dict['rec-' + name.lower()], 3))
            f1s = str(round(eval_dict['f-' + name.lower()], 3))
            table.append([name, ps, rs, f1s])
        # overall averages (micro or macro, per config.evaluation.type)
        ps = str(round(eval_dict['pre'], 3))
        rs = str(round(eval_dict['rec'], 3))
        f1s = str(round(eval_dict['f'], 3))
        table.append(['{}_average'.format(self._config.evaluation.type), ps, rs, f1s])
        return eval_dict, table

    def get_eval_output(self):
        # public entry point: optionally print the table, always write it to disk
        # (saving the eval result is currently on by default)
        result, table = self._get_eval_result()
        if self._config.evaluation.is_display:
            self._print_table(table)
        self._write_csv(table)
        return result

    def _print_table(self, table):
        # pretty-print the table with right-aligned, fixed-width columns
        for row in table:
            for cell in row:
                print(cell.rjust(14), end=' ')
            print()

    def _write_csv(self, table):
        # despite the name, this writes the table to an .xlsx workbook via openpyxl
        wb = Workbook()
        ws = wb['Sheet']
        for line in range(1, len(table) + 1):
            for column in range(1, 5):
                # openpyxl cells are 1-indexed
                ws.cell(line, column, table[line - 1][column - 1])
        save_path = self._config.learn.dir.saved + '/eval_result.xlsx'
        wb.save(save_path)
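A hedged usage sketch for CommonSeqEvaluator. The tag_vocab stub and the nested config fields below are illustrative stand-ins for the project's real objects (a torchtext-style vocab exposing .stoi, plus config.device, config.evaluation.type / .is_display, config.data.chip_relation.use_chip_relation, and config.learn.dir.saved); the tag set, prediction indices, and output directory are invented for the example.

# illustrative stubs only; replace with the project's real vocab and config objects
from types import SimpleNamespace

import torch

tag_vocab = SimpleNamespace(stoi={'<pad>': 0, '<unk>': 1, 'O': 2, 'B-Drug': 3, 'I-Drug': 4})
config = SimpleNamespace(
    device='cpu',
    evaluation=SimpleNamespace(type='micro', is_display=True),
    data=SimpleNamespace(chip_relation=SimpleNamespace(use_chip_relation=False)),
    learn=SimpleNamespace(dir=SimpleNamespace(saved='.')),  # eval_result.xlsx is written here
)

evaluator = CommonSeqEvaluator(tag_vocab, config)
# pred: one list of tag indices per sentence; target: padded LongTensor of gold tag indices
# (indices assume the rebuilt Vocabulary keeps the same ordering as stoi: '<pad>'=0, '<unk>'=1, ...)
pred = [[2, 3, 4], [2, 2]]
target = torch.tensor([[2, 3, 4], [2, 2, 0]])
evaluator.evaluate(pred, target)
result = evaluator.get_eval_output()  # prints the table and writes eval_result.xlsx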