示例#1
0
    def evaluate(self,
                 word_idxs_list_valid,
                 labels_list_valid,
                 test_texts,
                 terms_true_list,
                 task,
                 opinions_ture_list=None):
        """Evaluate aspect (and optionally opinion) term extraction.

        Each sentence is labelled with ``self.predict_batch``. Label ids 1/2 are
        decoded into aspect terms and 3/4 into opinion terms via
        ``self.get_terms_from_label_list``. Predicted terms are matched against
        the gold terms with ``utils.count_hit``. The accumulated counts are
        turned into precision/recall/F1 with ``utils.prf1``.

        Args:
            word_idxs_list_valid: per-sentence model inputs (word indices).
            labels_list_valid: gold label sequences. They are only zipped in
                lockstep with the other inputs (bounding the iteration) and are
                not otherwise used.
            test_texts: per-sentence texts passed to the term decoder.
            terms_true_list: per-sentence gold aspect terms.
            task: forwarded unchanged to ``self.predict_batch``.
            opinions_ture_list: optional per-sentence gold opinion terms. The
                misspelled name is kept for caller compatibility.

        Returns:
            ``(aspect_p, aspect_r, aspect_f1, opinion_p, opinion_r, opinion_f1)``.
            The opinion scores are ``0`` when ``opinions_ture_list`` is None.
        """
        aspect_true_cnt, aspect_sys_cnt, aspect_hit_cnt = 0, 0, 0
        opinion_true_cnt, opinion_sys_cnt, opinion_hit_cnt = 0, 0, 0
        # Indices of sentences whose gold aspects were all found. Collected as
        # a debugging aid; not returned.
        correct_sent_idxs = list()
        for sent_idx, (word_idxs, _labels, text, terms_true) in enumerate(
                zip(word_idxs_list_valid, labels_list_valid, test_texts,
                    terms_true_list)):
            labels_pred, _ = self.predict_batch([word_idxs], task)
            labels_pred = labels_pred[0]

            aspect_terms_sys = self.get_terms_from_label_list(
                labels_pred, text, 1, 2)

            new_hit_cnt = utils.count_hit(terms_true, aspect_terms_sys)
            aspect_true_cnt += len(terms_true)
            aspect_sys_cnt += len(aspect_terms_sys)
            aspect_hit_cnt += new_hit_cnt
            # Compare against this sentence's gold count. The running total
            # (aspect_true_cnt) is wrong here after the first sentence.
            if new_hit_cnt == len(terms_true):
                correct_sent_idxs.append(sent_idx)

            if opinions_ture_list is None:
                continue

            opinion_terms_sys = self.get_terms_from_label_list(
                labels_pred, text, 3, 4)
            opinion_terms_true = opinions_ture_list[sent_idx]

            new_hit_cnt = utils.count_hit(opinion_terms_true,
                                          opinion_terms_sys)
            opinion_hit_cnt += new_hit_cnt
            opinion_true_cnt += len(opinion_terms_true)
            opinion_sys_cnt += len(opinion_terms_sys)

        aspect_p, aspect_r, aspect_f1 = utils.prf1(aspect_true_cnt,
                                                   aspect_sys_cnt,
                                                   aspect_hit_cnt)
        if opinions_ture_list is None:
            return aspect_p, aspect_r, aspect_f1, 0, 0, 0

        opinion_p, opinion_r, opinion_f1 = utils.prf1(opinion_true_cnt,
                                                      opinion_sys_cnt,
                                                      opinion_hit_cnt)
        return aspect_p, aspect_r, aspect_f1, opinion_p, opinion_r, opinion_f1
示例#2
0
    def evaluate(self,
                 word_embed_seqs,
                 test_texts,
                 terms_true_list,
                 opinions_ture_list=None,
                 dst_aspects_file=None,
                 dst_opinions_file=None):
        """Evaluate aspect/opinion extraction on precomputed word embeddings.

        Each sentence's embedding sequence is labelled with
        ``self.predict_batch``. Label ids 1/2 are decoded into aspect terms and
        3/4 into opinion terms. Scores are computed from ``utils.count_hit``
        counts via ``utils.prf1``. The predicted terms are collected per
        sentence and, when a destination path is given, written with
        ``utils.write_terms_list``.

        Args:
            word_embed_seqs: per-sentence model inputs.
            test_texts: per-sentence texts passed to the term decoder.
            terms_true_list: per-sentence gold aspect terms.
            opinions_ture_list: optional per-sentence gold opinion terms. The
                misspelled name is kept for caller compatibility.
            dst_aspects_file: optional path for the predicted aspect terms.
            dst_opinions_file: optional path for the predicted opinion terms.
                Opinions are only predicted when ``opinions_ture_list`` is
                given; otherwise an empty list is written.

        Returns:
            ``(aspect_p, aspect_r, aspect_f1, opinion_p, opinion_r, opinion_f1)``.
            The opinion scores are ``0`` when ``opinions_ture_list`` is None.
        """
        aspect_true_cnt, aspect_sys_cnt, aspect_hit_cnt = 0, 0, 0
        opinion_true_cnt, opinion_sys_cnt, opinion_hit_cnt = 0, 0, 0
        # Indices of sentences whose gold aspects were all found (debug aid).
        correct_sent_idxs = list()
        aspect_terms_sys_list, opinion_terms_sys_list = list(), list()
        for sent_idx, (word_embed_seq, text, terms_true) in enumerate(
                zip(word_embed_seqs, test_texts, terms_true_list)):
            labels_pred, _ = self.predict_batch([word_embed_seq])
            labels_pred = labels_pred[0]

            aspect_terms_sys = self.get_terms_from_label_list(
                labels_pred, text, 1, 2)
            # Previously never appended, so dst_aspects_file got an empty list.
            aspect_terms_sys_list.append(aspect_terms_sys)

            new_hit_cnt = utils.count_hit(terms_true, aspect_terms_sys)
            aspect_true_cnt += len(terms_true)
            aspect_sys_cnt += len(aspect_terms_sys)
            aspect_hit_cnt += new_hit_cnt
            # Per-sentence gold count, not the running total.
            if new_hit_cnt == len(terms_true):
                correct_sent_idxs.append(sent_idx)

            if opinions_ture_list is None:
                continue

            opinion_terms_sys = self.get_terms_from_label_list(
                labels_pred, text, 3, 4)
            opinion_terms_sys_list.append(opinion_terms_sys)
            opinion_terms_true = opinions_ture_list[sent_idx]

            new_hit_cnt = utils.count_hit(opinion_terms_true,
                                          opinion_terms_sys)
            opinion_hit_cnt += new_hit_cnt
            opinion_true_cnt += len(opinion_terms_true)
            opinion_sys_cnt += len(opinion_terms_sys)

        if dst_aspects_file is not None:
            utils.write_terms_list(aspect_terms_sys_list, dst_aspects_file)
        if dst_opinions_file is not None:
            utils.write_terms_list(opinion_terms_sys_list, dst_opinions_file)

        aspect_p, aspect_r, aspect_f1 = utils.prf1(aspect_true_cnt,
                                                   aspect_sys_cnt,
                                                   aspect_hit_cnt)
        if opinions_ture_list is None:
            return aspect_p, aspect_r, aspect_f1, 0, 0, 0

        opinion_p, opinion_r, opinion_f1 = utils.prf1(opinion_true_cnt,
                                                      opinion_sys_cnt,
                                                      opinion_hit_cnt)
        return aspect_p, aspect_r, aspect_f1, opinion_p, opinion_r, opinion_f1
示例#3
0
    def evaluate(self, texts, word_idxs_list_valid, word_span_seqs, tok_texts, terms_true_list, task,
                 opinions_ture_list=None, dst_aspects_result_file=None, dst_opinion_result_file=None,
                 save_result=False):
        """Evaluate aspect (and optionally opinion) extraction and optionally save predictions.

        Each sentence is labelled with ``self.predict_batch``. Aspect terms come
        from label ids 1/2 in one of two ways:
        ``utils.get_terms_from_label_list`` on the tokenized text when
        ``word_span_seqs`` is None, otherwise ``utils.recover_terms`` on the
        raw text and spans, lower-cased. Opinion terms (ids 3/4) are always
        decoded from the tokenized text. Scores come from ``utils.prf1``.

        When ``save_result`` is true, predictions are written with
        ``datautils.write_terms_list``:
        - the aspect file is written whenever a path is given;
        - the opinion file is written only when ``opinions_ture_list`` is
          given, since opinions are not predicted otherwise.

        Returns:
            ``(aspect_p, aspect_r, aspect_f1, opinion_p, opinion_r, opinion_f1)``.
            The opinion scores are ``0`` when ``opinions_ture_list`` is None.
        """
        aspect_true_cnt, aspect_sys_cnt, aspect_hit_cnt = 0, 0, 0
        opinion_true_cnt, opinion_sys_cnt, opinion_hit_cnt = 0, 0, 0
        # Indices of sentences whose gold aspects were all found (debug aid).
        correct_sent_idxs = list()
        aspect_terms_sys_list, opinion_terms_sys_list = list(), list()
        for sent_idx, (word_idxs, tok_text, terms_true) in enumerate(zip(
                word_idxs_list_valid, tok_texts, terms_true_list)):
            labels_pred, _ = self.predict_batch([word_idxs], task)
            labels_pred = labels_pred[0]

            if word_span_seqs is None:
                aspect_terms_sys = utils.get_terms_from_label_list(labels_pred, tok_text, 1, 2)
            else:
                aspect_terms_sys = utils.recover_terms(texts[sent_idx], word_span_seqs[sent_idx], labels_pred, 1, 2)
                aspect_terms_sys = [t.lower() for t in aspect_terms_sys]
            aspect_terms_sys_list.append(aspect_terms_sys)

            new_hit_cnt = utils.count_hit(terms_true, aspect_terms_sys)
            aspect_true_cnt += len(terms_true)
            aspect_sys_cnt += len(aspect_terms_sys)
            aspect_hit_cnt += new_hit_cnt
            # Per-sentence gold count, not the running total.
            if new_hit_cnt == len(terms_true):
                correct_sent_idxs.append(sent_idx)

            if opinions_ture_list is None:
                continue

            opinion_terms_sys = utils.get_terms_from_label_list(labels_pred, tok_text, 3, 4)
            opinion_terms_sys_list.append(opinion_terms_sys)
            opinion_terms_true = opinions_ture_list[sent_idx]

            new_hit_cnt = utils.count_hit(opinion_terms_true, opinion_terms_sys)
            opinion_hit_cnt += new_hit_cnt
            opinion_true_cnt += len(opinion_terms_true)
            opinion_sys_cnt += len(opinion_terms_sys)

        # Save before the aspect-only early return below. Previously the
        # aspect results were silently dropped when no opinion gold was given.
        if dst_aspects_result_file is not None and save_result:
            datautils.write_terms_list(aspect_terms_sys_list, dst_aspects_result_file)
            logging.info('write aspects to {}'.format(dst_aspects_result_file))
        if (dst_opinion_result_file is not None and save_result
                and opinions_ture_list is not None):
            datautils.write_terms_list(opinion_terms_sys_list, dst_opinion_result_file)
            logging.info('write opinions to {}'.format(dst_opinion_result_file))

        aspect_p, aspect_r, aspect_f1 = utils.prf1(aspect_true_cnt, aspect_sys_cnt, aspect_hit_cnt)
        if opinions_ture_list is None:
            return aspect_p, aspect_r, aspect_f1, 0, 0, 0

        opinion_p, opinion_r, opinion_f1 = utils.prf1(opinion_true_cnt, opinion_sys_cnt, opinion_hit_cnt)
        return aspect_p, aspect_r, aspect_f1, opinion_p, opinion_r, opinion_f1
示例#4
0
def check_unseen_terms(
        sents_file='d:/data/aspect/semeval14/laptops/laptops_test_sents.json',
        lstmcrf_aspects_file='d:/data/aspect/semeval14/lstmcrf-aspects.txt',
        nrdj_aspects_file='d:/data/aspect/semeval14/nrdj-opinions-malt.txt',
        rule_aspects_file='d:/data/aspect/semeval14/laptops/laptops-test-aspect-rule-result.txt'):
    """Print diagnostics comparing saved term predictions with gold opinion terms.

    Loads the gold sentences and three prediction files with
    ``utils.load_json_objs`` and walks them in lockstep. ``zip`` stops at the
    shortest input. The gold terms are each sentence's lower-cased
    ``'opinions'`` (empty if absent). For every gold term missing from the
    NRDJ predictions, the term and its sentence ``'text'`` are printed. At the
    end, the totals and the P/R/F1 (``utils.prf1``) of the NRDJ and LSTM-CRF
    predictions are printed.

    Despite the function name, no filtering against training-set terms is
    performed.

    All paths default to the values previously hard-coded here, so
    ``check_unseen_terms()`` behaves as before.

    Args:
        sents_file: gold sentence objects.
        lstmcrf_aspects_file: LSTM-CRF predictions, one entry per sentence.
        nrdj_aspects_file: NRDJ predictions, one entry per sentence.
            NOTE(review): this default points at an *opinions* result file
            although the name says aspects. The gold terms are also opinions,
            so this may be deliberate -- confirm.
        rule_aspects_file: rule-based predictions. They are loaded and
            iterated in lockstep (bounding the zip) but not scored.
    """
    sents = utils.load_json_objs(sents_file)
    lc_aspects_list = utils.load_json_objs(lstmcrf_aspects_file)
    nrdj_aspects_list = utils.load_json_objs(nrdj_aspects_file)
    rule_aspects_list = utils.load_json_objs(rule_aspects_file)
    n_true, n_nrdj, n_hit = 0, 0, 0
    n_lc, n_lc_hit = 0, 0
    for sent, lc_aspects, nrdj_aspects, _rule_aspects in zip(
            sents, lc_aspects_list, nrdj_aspects_list, rule_aspects_list):
        terms = [t.lower() for t in sent.get('opinions', list())]
        n_true += len(terms)
        n_nrdj += len(nrdj_aspects)
        n_hit += utils.count_hit(terms, nrdj_aspects)
        # Report every gold term the NRDJ predictions missed.
        for t in terms:
            if t not in nrdj_aspects:
                print(t)
                print(sent['text'])

        n_lc += len(lc_aspects)
        n_lc_hit += utils.count_hit(terms, lc_aspects)

    print(n_true, n_nrdj)
    p, r, f1 = utils.prf1(n_true, n_nrdj, n_hit)
    print(p, r, f1)
    p, r, f1 = utils.prf1(n_true, n_lc, n_lc_hit)
    print(p, r, f1)
示例#5
0
def __evaluate(terms_sys_list, terms_true_list, dep_tags_list, pos_tags_list, sent_texts):
    """Score predicted terms against gold terms, print P/R/F1, return exact-match indices.

    Only the first four sequences are zipped together, so iteration stops at
    the shortest of them. The dependency/POS tags and ``sent_texts`` are not
    used in the computation. A sentence counts as an exact match when its
    ``utils.count_hit`` value equals both its gold and its predicted term
    count. The printed scores use a 1e-8 smoothing term in each denominator.

    Returns:
        List of indices of exactly-matched sentences.
    """
    exact_idxs = []
    hit_total = true_total = sys_total = 0
    rows = zip(terms_sys_list, terms_true_list, dep_tags_list, pos_tags_list)
    for idx, (sys_terms, gold_terms, _dep, _pos) in enumerate(rows):
        true_total += len(gold_terms)
        sys_total += len(sys_terms)
        hits = utils.count_hit(gold_terms, sys_terms)
        if hits == len(gold_terms) == len(sys_terms):
            exact_idxs.append(idx)
        hit_total += hits

    print('hit={}, true={}, sys={}'.format(hit_total, true_total, sys_total))
    precision = hit_total / (sys_total + 1e-8)
    recall = hit_total / (true_total + 1e-8)
    f1 = 2 * precision * recall / (precision + recall + 1e-8)
    print(precision, recall, f1)
    return exact_idxs
示例#6
0
def evaluate_ao_extraction(true_labels_list,
                           pred_labels_list,
                           test_texts,
                           aspects_true_list,
                           opinions_ture_list=None,
                           error_file=None):
    """Score aspect/opinion terms decoded from predicted label sequences.

    Label ids: when aspect gold is given, aspects are 1/2 and opinions 3/4.
    When ``aspects_true_list`` is None, the mapping is swapped so opinions are
    read from 1/2. Terms are decoded with ``get_terms_from_label_list``,
    matched with ``count_hit`` and scored with ``prf1``.

    Args:
        true_labels_list: gold label sequences. They are only zipped in
            lockstep (bounding the iteration) and not otherwise used.
        pred_labels_list: predicted label sequences.
        test_texts: per-sentence texts passed to the term decoder.
        aspects_true_list: per-sentence gold aspect terms, or None to skip
            aspect scoring.
        opinions_ture_list: per-sentence gold opinion terms, or None to skip
            opinion scoring. The misspelled name is kept for compatibility.
        error_file: optional path. For every sentence where some gold opinion
            was missed, the text, gold terms and predicted terms are written
            there.

    Returns:
        ``(aspect_p, aspect_r, aspect_f1, opinion_p, opinion_r, opinion_f1)``.
        The scores of a skipped category are ``0``.
    """
    aspect_true_cnt, aspect_sys_cnt, aspect_hit_cnt = 0, 0, 0
    opinion_true_cnt, opinion_sys_cnt, opinion_hit_cnt = 0, 0, 0
    error_sents, error_terms_true, error_terms_sys = list(), list(), list()
    # Indices of sentences whose gold aspects were all found (debug aid).
    correct_sent_idxs = list()
    if aspects_true_list is not None:
        aspect_label_beg, aspect_label_in, opinion_label_beg, opinion_label_in = 1, 2, 3, 4
    else:
        aspect_label_beg, aspect_label_in, opinion_label_beg, opinion_label_in = 3, 4, 1, 2
    for sent_idx, (_true_labels, pred_labels, text) in enumerate(
            zip(true_labels_list, pred_labels_list, test_texts)):
        if aspects_true_list is not None:
            aspects_true = aspects_true_list[sent_idx]
            aspect_terms_sys = get_terms_from_label_list(
                pred_labels, text, aspect_label_beg, aspect_label_in)

            new_hit_cnt = count_hit(aspects_true, aspect_terms_sys)
            aspect_true_cnt += len(aspects_true)
            aspect_sys_cnt += len(aspect_terms_sys)
            aspect_hit_cnt += new_hit_cnt
            # Per-sentence gold count, not the running total.
            if new_hit_cnt == len(aspects_true):
                correct_sent_idxs.append(sent_idx)

        if opinions_ture_list is None:
            continue

        opinion_terms_sys = get_terms_from_label_list(pred_labels, text,
                                                      opinion_label_beg,
                                                      opinion_label_in)
        opinion_terms_true = opinions_ture_list[sent_idx]

        new_hit_cnt = count_hit(opinion_terms_true, opinion_terms_sys)
        opinion_hit_cnt += new_hit_cnt
        opinion_true_cnt += len(opinion_terms_true)
        opinion_sys_cnt += len(opinion_terms_sys)

        if new_hit_cnt < len(opinion_terms_true):
            error_sents.append(text)
            error_terms_true.append(opinion_terms_true)
            error_terms_sys.append(opinion_terms_sys)

    if error_file is not None:
        with open(error_file, 'w', encoding='utf-8') as fout:
            for sent, terms_true, terms_sys in zip(error_sents,
                                                   error_terms_true,
                                                   error_terms_sys):
                fout.write('{}\n{}\n{}\n\n'.format(sent, terms_true,
                                                   terms_sys))
        logging.info('error written to {}'.format(error_file))

    aspect_p, aspect_r, aspect_f1, opinion_p, opinion_r, opinion_f1 = 0, 0, 0, 0, 0, 0
    if aspects_true_list is not None:
        aspect_p, aspect_r, aspect_f1 = prf1(aspect_true_cnt, aspect_sys_cnt,
                                             aspect_hit_cnt)

    if opinions_ture_list is not None:
        opinion_p, opinion_r, opinion_f1 = prf1(opinion_true_cnt,
                                                opinion_sys_cnt,
                                                opinion_hit_cnt)
    return aspect_p, aspect_r, aspect_f1, opinion_p, opinion_r, opinion_f1
示例#7
0
    def evaluate(self,
                 texts,
                 word_idxs_list_valid,
                 word_span_seqs,
                 tok_texts,
                 terms_true_list,
                 task,
                 opinions_ture_list=None,
                 dst_aspects_result_file=None,
                 dst_opinion_result_file=None,
                 save_result=False):
        """Evaluate aspect (and optionally opinion) extraction and optionally save predictions.

        Each sentence is labelled with ``self.predict_batch``. Aspect terms come
        from label ids 1/2 in one of two ways:
        ``self.get_terms_from_label_list`` on the tokenized text when
        ``word_span_seqs`` is None, otherwise ``utils.recover_terms`` on the
        raw text and spans, lower-cased. Opinion terms (ids 3/4) are always
        decoded from the tokenized text. Scores come from ``utils.prf1``.

        When ``save_result`` is true and a destination path is given, the
        per-sentence predictions are written with ``utils.write_terms_list``.
        The opinion list is empty when ``opinions_ture_list`` is None.

        Returns:
            ``(aspect_p, aspect_r, aspect_f1, opinion_p, opinion_r, opinion_f1)``.
            The opinion scores are ``0`` when ``opinions_ture_list`` is None.
        """
        aspect_true_cnt, aspect_sys_cnt, aspect_hit_cnt = 0, 0, 0
        opinion_true_cnt, opinion_sys_cnt, opinion_hit_cnt = 0, 0, 0
        # Indices of sentences whose gold aspects were all found (debug aid).
        correct_sent_idxs = list()
        aspect_terms_sys_list, opinion_terms_sys_list = list(), list()
        for sent_idx, (word_idxs, tok_text, terms_true) in enumerate(
                zip(word_idxs_list_valid, tok_texts, terms_true_list)):
            labels_pred, _ = self.predict_batch([word_idxs], task)
            labels_pred = labels_pred[0]

            if word_span_seqs is None:
                aspect_terms_sys = self.get_terms_from_label_list(
                    labels_pred, tok_text, 1, 2)
            else:
                aspect_terms_sys = utils.recover_terms(
                    texts[sent_idx], word_span_seqs[sent_idx], labels_pred, 1,
                    2)
                aspect_terms_sys = [t.lower() for t in aspect_terms_sys]
            aspect_terms_sys_list.append(aspect_terms_sys)

            new_hit_cnt = utils.count_hit(terms_true, aspect_terms_sys)
            aspect_true_cnt += len(terms_true)
            aspect_sys_cnt += len(aspect_terms_sys)
            aspect_hit_cnt += new_hit_cnt
            # Per-sentence gold count, not the running total.
            if new_hit_cnt == len(terms_true):
                correct_sent_idxs.append(sent_idx)

            if opinions_ture_list is None:
                continue

            opinion_terms_sys = self.get_terms_from_label_list(
                labels_pred, tok_text, 3, 4)
            opinion_terms_sys_list.append(opinion_terms_sys)
            opinion_terms_true = opinions_ture_list[sent_idx]

            new_hit_cnt = utils.count_hit(opinion_terms_true,
                                          opinion_terms_sys)
            opinion_hit_cnt += new_hit_cnt
            opinion_true_cnt += len(opinion_terms_true)
            opinion_sys_cnt += len(opinion_terms_sys)

        if dst_aspects_result_file is not None and save_result:
            utils.write_terms_list(aspect_terms_sys_list,
                                   dst_aspects_result_file)
            logging.info('write aspects to {}'.format(dst_aspects_result_file))
        if dst_opinion_result_file is not None and save_result:
            utils.write_terms_list(opinion_terms_sys_list,
                                   dst_opinion_result_file)
            logging.info(
                'write opinions to {}'.format(dst_opinion_result_file))

        aspect_p, aspect_r, aspect_f1 = utils.prf1(aspect_true_cnt,
                                                   aspect_sys_cnt,
                                                   aspect_hit_cnt)
        if opinions_ture_list is None:
            return aspect_p, aspect_r, aspect_f1, 0, 0, 0

        opinion_p, opinion_r, opinion_f1 = utils.prf1(opinion_true_cnt,
                                                      opinion_sys_cnt,
                                                      opinion_hit_cnt)
        return aspect_p, aspect_r, aspect_f1, opinion_p, opinion_r, opinion_f1