Code Example #1
def del_overlap():
    data_path = get_data_path()
    file_in_path = os.path.join(data_path, 'qa_train/gen_qa_data_1.0_v2')
    fout_path = os.path.join(data_path, 'qa_train/gen_qa_data_1.0_v3')
    fout = open(fout_path, 'wb')
    with open(file_in_path) as fin:
        for line in fin:
            ll = json.loads(line.decode('utf-8').strip())
            triple_list = ll['triples']
            name_list = set()
            for triple in triple_list:
                name, _, _ = triple.split('|||')
                name_list.add(name)
            name_list = list(name_list)
            if len(name_list) < 2:
                print >> fout, json.dumps(ll, encoding='utf-8', ensure_ascii=False)
            else:
                name_max = ''
                for name in name_list:
                    if len(name) > len(name_max):
                        if name_max not in name:
                            # the longer name does not contain the current
                            # candidate, so log the conflict for inspection
                            print(ll['question'])
                            print(ll['triples'])
                        name_max = name
                new_triple_list = []
                for triple in triple_list:
                    name, _, _ = triple.split('|||')
                    if name == name_max:
                        new_triple_list.append(triple)
                print >> fout, json.dumps({'origin_triple': ll['origin_triple'],
                                           'question': ll['question'],
                                           'triples': new_triple_list},
                                          encoding='utf-8', ensure_ascii=False)
Code Example #2
File: ana_data.py  Project: xxx-git/ch_qa
def find_diff_rel():
    data_path = get_data_path()
    file1 = os.path.join(data_path, 'qa_train/nlpcc_qa_data_v2')
    file2 = os.path.join(data_path, 'qa_train/gen_qa_data_1.0_v2')

    file_out_path = os.path.join(data_path, 'statistics/train_data.rel_dict')
    fout = open(file_out_path, 'wb')
    rel_dict = dict()
    for f in [file1, file2]:
        with open(f) as fin:
            for line in fin:
                ll = json.loads(line.decode('utf-8').strip())
                if 'triples' in ll:
                    triple_list = ll['triples']
                elif 'triple' in ll:
                    triple_list = ll['triple']
                else:
                    continue
                if len(triple_list) <= 1:
                    continue
                rel_set = set()
                for triple in triple_list:
                    rel = triple.split('|||')[1]
                    rel_set.add(rel)
                for r in rel_set:
                    rel_dict[r] = rel_dict.get(r, 0) + 1
    rel_dict = sorted(rel_dict.items(), key=lambda x: x[1], reverse=True)
    all_count = 0
    for rel, count in rel_dict:
        all_count += count
        print >> fout, '\t'.join((rel, str(count))).encode('utf-8')
    print(all_count)
Code Example #3
def gen_check_triple():
    data_path = get_data_path()
    triple_dict = read_rel_dict()
    file_path = os.path.join(data_path, 'check_triple_nlpcc.new')
    file_out_path = os.path.join(data_path, 'check_triple_nlpcc.select')
    fout = open(file_out_path, 'wb')
    with open(file_path, 'rb') as fin:
        for line in fin:
            ll = json.loads(line.decode('utf-8').strip())
            triple = ll['triple']
            sub_name, rel, _ = triple.split('|||')
            reli_list = list(set(ll['reliable']))
            if len(reli_list) > 0:
                print >> fout, json.dumps({'triple': triple, 'select_triple': reli_list},
                                          encoding='utf-8', ensure_ascii=False)
            else:
                possi_list = list(set(ll['possible']))
                possi_triple_list = []
                if triple not in triple_dict:
                    print(triple)
                    continue
                select_rel = triple_dict[triple]
                for possi in possi_list:
                    _, rel, _ = possi.split('|||')
                    if rel == select_rel:
                        possi_triple_list.append(possi)
                print >> fout, json.dumps({'triple': triple, 'select_triple': possi_triple_list},
                                          encoding='utf-8', ensure_ascii=False)
Code Example #4
def read_train_data(file1=None, file2=None):
    data_path = get_data_path()
    # fall back to the default training files when no paths are given
    if file1 is None:
        file1 = os.path.join(data_path, 'qa_train/nlpcc_qa_data_v2')
    if file2 is None:
        file2 = os.path.join(data_path, 'qa_train/gen_qa_data_1.0_v2')
    ques_triple_dict = {}
    for f in [file1, file2]:
        with open(f) as fin:
            for line in fin:
                ll = json.loads(line.decode('utf-8').strip())
                triple_list = []
                ques = ll['question']
                if 'triple' in ll:
                    triple_list = ll['triple']
                elif 'triples' in ll:
                    triple_list = ll['triples']
                ques_triple_dict[ques] = triple_list
    print('load succeeded')
    return ques_triple_dict
Code Example #5
def get_triple(out_file_name):
    data_path = get_data_path()
    fout_path = os.path.join(data_path, out_file_name)
    fout = open(fout_path, 'wb')
    score_dict = {'0.5': 0, '1.0': 0, '1.5': 0, '2.0': 0, '>2.0': 0}
    data_dir_path = os.path.join(data_path, 'gen_qa/cqa_triple_match')
    file_list = os.listdir(data_dir_path)
    file_path_list = []
    for f in file_list:
        file_path_list.append(os.path.join(data_dir_path, f))
    triple_set = set()
    for file in file_path_list:
        with open(file) as fin:
            for line in fin:
                line = line.decode('utf-8').strip().split('\t')
                score = float(line[-1])
                # bucket the match score into half-point histogram bins
                if score < 0.5:
                    score_dict['0.5'] += 1
                elif score < 1.0:
                    score_dict['1.0'] += 1
                elif score < 1.5:
                    score_dict['1.5'] += 1
                elif score < 2.0:
                    score_dict['2.0'] += 1
                else:
                    score_dict['>2.0'] += 1
                triple_set.add('|||'.join((line[2], line[3], line[4])))
    print(score_dict)
    for triple in triple_set:
        print >> fout, triple.encode('utf-8')
Code Example #6
def get_select_rel():
    data_path = get_data_path()
    file_path = os.path.join(data_path, 'check_triple_nlpcc.tmp')
    file_out_path = os.path.join(data_path, 'check_triple_nlpcc.tmp2')
    fout = open(file_out_path, 'wb')
    with open(file_path) as fin:
        for line in fin:
            line = json.loads(line.decode('utf-8').strip())
            origin_rel = line['triple'].split('|||')[1]
            rel = line['rel']
            possi_rel_list = line['possi_rel_list']
            select_rel = ''
            if rel != '':
                select_rel = rel
            elif len(possi_rel_list) == 1:
                select_rel = possi_rel_list[0]
            elif len(possi_rel_list) > 1:
                print(line['triple'])
                max_sim = 0
                for possi_rel in possi_rel_list:
                    if possi_rel == 'subname':
                        continue
                    sim = get_similarity(possi_rel, origin_rel)
                    if sim > max_sim:
                        max_sim = sim
                        select_rel = possi_rel
            print >> fout, json.dumps({'triple': line['triple'], 'rel': rel,
                                       'possi_rel_list': possi_rel_list,
                                       'select_rel': select_rel},
                                      encoding='utf-8', ensure_ascii=False)
Code Example #7
def del_eval_data():
    data_path = get_data_path()
    file1_in_path = os.path.join(data_path, 'qa_train/gen_qa_data_1.0_v3')
    file2_in_path = os.path.join(data_path, 'qa_train/nlpcc_qa_data_v3')
    file1_out_path = os.path.join(data_path, 'qa_train/gen_qa_data_1.0_v4')
    fout1 = open(file1_out_path, 'wb')
    file2_out_path = os.path.join(data_path, 'qa_train/nlpcc_qa_data_v4')
    fout2 = open(file2_out_path, 'wb')
    ques_set = get_ques_set()
    # drop any training question that also appears in the eval question set
    for fin_path, fout in [(file1_in_path, fout1), (file2_in_path, fout2)]:
        with open(fin_path) as fin:
            for line in fin:
                ll = json.loads(line.decode('utf-8').strip())
                if ll['question'] in ques_set:
                    continue
                print >> fout, json.dumps(ll,
                                          ensure_ascii=False,
                                          encoding='utf-8')
Code Example #8
File: ana_data.py  Project: xxx-git/ch_qa
def gen_train_data():
    file_list = ['gen_qa_data_1.0', 'nlpcc_qa_data']
    data_path = get_data_path()
    file_path = os.path.join(data_path, 'qa_train')

    file_out_path = os.path.join(file_path, 'all_qa_data')
    fout = open(file_out_path, 'wb')

    question_dict = read_ques()

    for f in file_list:
        file = os.path.join(file_path, f)
        with open(file) as fin:
            for line in fin:
                ll = json.loads(line.decode('utf-8').strip())
                question = ll['question'].strip()
                if question in question_dict:
                    sub_name = question_dict[question]
                    triple_list = []
                    for triple in ll['triple']:
                        entity_name = triple.split('|||')[0].strip()
                        if entity_name == sub_name:
                            triple_list.append(triple)
                    if triple_list:
                        print >> fout, json.dumps({'question': question,
                                                   'triples': triple_list},
                                                  encoding='utf-8',
                                                  ensure_ascii=False)
                else:
                    print >> fout, json.dumps({'question': question,
                                               'triples': ll['triple']},
                                              encoding='utf-8',
                                              ensure_ascii=False)
Code Example #9
def load_vectors():
    data_path = get_data_path()
    vector_path = os.path.join(data_path, 'vectors/vectors.bin')
    model = gensim.models.KeyedVectors.load_word2vec_format(vector_path,
                                                            binary=True)
    print('vector model loaded: %d words' % len(model))
    return model
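Code Example #6 calls get_similarity() to compare a candidate relation against the original one, but that helper is not included in these excerpts. A minimal sketch of what it might look like, assuming it averages the word vectors of each segmented relation string (using the project's seg_sentence from Code Example #19) and returns their cosine similarity:

import numpy as np

_vector_model = None

def get_similarity(rel_a, rel_b):
    # hypothetical reconstruction, not the project's actual implementation
    global _vector_model
    if _vector_model is None:
        _vector_model = load_vectors()

    def avg_vec(text):
        vecs = [_vector_model[w] for w in seg_sentence(text)
                if w in _vector_model]
        return np.mean(vecs, axis=0) if vecs else None

    va, vb = avg_vec(rel_a), avg_vec(rel_b)
    if va is None or vb is None:
        return 0.0
    return float(np.dot(va, vb) /
                 (np.linalg.norm(va) * np.linalg.norm(vb)))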
Code Example #10
def gen_check_rel():
    data_path = get_data_path()
    file_path = os.path.join(data_path, 'check_triple_nlpcc.new')
    file_out_path = os.path.join(data_path, 'check_triple_nlpcc.tmp')
    fout = open(file_out_path, 'wb')
    with open(file_path, 'rb') as fin:
        for line in fin:
            ll = json.loads(line.decode('utf-8').strip())
            triple = ll['triple']
            sub_name, rel, _ = triple.split('|||')
            reli_list = list(set(ll['reliable']))
            possi_list = list(set(ll['possible']))
            possi_rel_list = []
            if len(reli_list) > 0:
                reliable_rel = rel
                print >> fout, json.dumps({'triple': triple, 'rel': reliable_rel,
                                           'possi_rel_list': possi_rel_list},
                                          encoding='utf-8', ensure_ascii=False)
            elif len(possi_list) > 0:
                for item in possi_list:
                    item_rel = item.split('|||')[1]
                    possi_rel_list.append(item_rel)
                possi_rel_list = list(set(possi_rel_list))
                tmp_rel = ''
                if rel in possi_rel_list:
                    tmp_rel = rel
                print >> fout, json.dumps({'triple': triple, 'rel': tmp_rel,
                                           'possi_rel_list': possi_rel_list},
                                          encoding='utf-8', ensure_ascii=False)
Code Example #11
def get_ques_set():
    data_path = get_data_path()
    fin_path = os.path.join(data_path, 'qa_train/test_data')
    question_set = set()
    with open(fin_path) as fin:
        for line in fin:
            ll = json.loads(line.decode('utf-8').strip())
            question_set.add(ll['question'])
    return question_set
Code Example #12
def read_triple_dict():
    data_path = get_data_path()
    file_path = os.path.join(data_path, 'check_triple_nlpcc.select')
    triple_dict = {}
    with open(file_path, 'rb') as fin:
        for line in fin:
            line = json.loads(line.decode('utf-8').strip())
            select_triple_list = line['select_triple']
            triple_dict[line['triple']] = select_triple_list
    return triple_dict
Code Example #13
File: ana_data.py  Project: xxx-git/ch_qa
def read_ques():
    data_path = get_data_path()
    file_path = os.path.join(data_path, 'statistics/diff_subject')
    question_dict = {}
    with open(file_path) as fin:
        for line in fin:
            ll = json.loads(line.decode('utf-8').strip())
            if ll['question'] in question_dict:
                print(ll)
            else:
                question_dict[ll['question']] = ll['sub_name']
    return question_dict
Code Example #14
def read_count_dict():
    data_path = get_data_path()
    file_path = os.path.join(data_path, 'statistics/train_data.rel_dict')
    # expand each relation's count into consecutive integer indices so that a
    # uniform random index in [1, total] selects a relation with probability
    # proportional to its training-data frequency
    count_dict = {}
    idx = 1
    with open(file_path) as fin:
        for line in fin:
            rel, count = line.decode('utf-8').strip().split('\t')
            count = int(count)
            for i in range(count):
                count_dict[idx] = rel
                idx += 1
    return count_dict
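A short usage sketch (hypothetical, but matching how gen_eval_data in Code Example #18 consumes this table): a uniform draw over the indices 1..len(count_dict) selects a relation with probability proportional to its count.

import random

count_dict = read_count_dict()
# uniform over indices == frequency-weighted over relations
rel = count_dict[random.randint(1, len(count_dict))]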
Code Example #15
File: data_prepare.py  Project: xxx-git/ch_qa
def gen_seq2seq_train(file_name_list, dir_name, train=0.8):
    data_path = get_data_path()
    out_dir = os.path.join(data_path, dir_name)
    source_train_name = os.path.join(out_dir, 'question.train')
    target_train_name = os.path.join(out_dir, 'rel.train')
    source_dev_name = os.path.join(out_dir, 'question.dev')
    target_dev_name = os.path.join(out_dir, 'rel.dev')
    source_train_out = open(source_train_name, 'wb')
    target_train_out = open(target_train_name, 'wb')
    source_dev_out = open(source_dev_name, 'wb')
    target_dev_out = open(target_dev_name, 'wb')
    for file_name in file_name_list:
        data_file = os.path.join(data_path, 'qa_train', file_name)
        with open(data_file) as fin:
            for line in fin:
                ll = json.loads(line.decode('utf-8').strip())
                question = ll['question']
                triple_list = []
                if 'triple' in ll:
                    triple_list = ll['triple']
                elif 'triples' in ll:
                    triple_list = ll['triples']
                else:
                    print(ll)
                rel_set = set()
                for triple in triple_list:
                    s, p, o = triple.split('|||')
                    pattern = question_ner(question, s.strip())
                    pattern = pre_question(pattern)
                    if is_filter_question(pattern):
                        continue
                    if p == 'subname' and len(triple_list) > 1:
                        continue
                    if p in rel_set:
                        continue
                    rel_set.add(p)
                    num = random.random()
                    if num <= train:
                        print >> source_train_out, pattern.encode('utf-8')
                        print >> target_train_out, p.strip().encode('utf-8')
                    else:
                        print >> source_dev_out, pattern.encode('utf-8')
                        print >> target_dev_out, p.strip().encode('utf-8')
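question_ner() is called above but not defined in these excerpts. Code Example #20 builds the same kind of pattern by replacing the subject mention with '_', so a minimal version consistent with that (an assumption, not the project's actual code) would be:

def question_ner(question, sub_name):
    # hypothetical sketch: mask the subject mention to form a question pattern
    return question.replace(sub_name, '_')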
Code Example #16
def read_triple_dict():
    data_path = get_data_path()
    file_triple = os.path.join(data_path, 'check_triple_cqa.all')
    # file_triple = os.path.join(data_path, 'check_triple_nlpcc.all')
    triple_dict = dict()
    with open(file_triple) as fin:
        for line in fin:
            line = json.loads(line.decode('utf-8').strip())
            if not line['reliable'] and not line['possible']:
                continue
            if line['reliable']:
                triple_dict[line['triple']] = line['reliable']
            else:
                triple_dict[line['triple']] = line['possible']
    return triple_dict
Code Example #17
def get_eval_data():
    data_path = get_data_path()
    file_in_path = os.path.join(data_path, 'qa_train/test_data_tmp')
    file_out_path = os.path.join(data_path, 'qa_train/test_data')
    fout = open(file_out_path, 'wb')
    ques_triple_dict = read_train_data()
    with open(file_in_path) as fin:
        for line in fin:
            ll = json.loads(line.decode('utf-8').strip())
            ques = ll['question']
            if ques not in ques_triple_dict:
                print(ques)
                continue
            triple_list = ques_triple_dict[ques]
            print >> fout, json.dumps({'question': ques, 'triple': triple_list},
                                      ensure_ascii=False, encoding='utf-8')
Code Example #18
def gen_eval_data():
    num = 1000
    data_path = get_data_path()
    file1 = os.path.join(data_path, 'qa_train/nlpcc_qa_data_v2')
    file2 = os.path.join(data_path, 'qa_train/gen_qa_data_1.0_v2')
    file_out_path = os.path.join(data_path, 'qa_train/test_data_tmp')
    fout = open(file_out_path, 'wb')
    rel_ques_dict = read_train_data(file1, file2)
    count_dict = read_count_dict()
    for _ in range(num):
        # draw an index uniformly over the weighted table built by
        # read_count_dict so that relations are sampled by frequency
        idx = random.randint(1, len(count_dict))
        rel = count_dict[idx]
        ques_list = list(rel_ques_dict[rel])
        ques = random.choice(ques_list)
        print >> fout, json.dumps({'question': ques, 'relation': rel},
                                  encoding='utf-8', ensure_ascii=False)
Code Example #19
File: data_prepare.py  Project: xxx-git/ch_qa
def count_ques_len():
    file_name_list = ['source.train']
    data_path = get_data_path()
    len_count = dict()
    for file_name in file_name_list:
        data_file = os.path.join(os.path.join(data_path, 'seq2seq_v1'),
                                 file_name)
        with open(data_file) as fin:
            for line in fin:
                ll = line.decode('utf-8').strip()
                seg_list = seg_sentence(ll)
                seg_len = len(seg_list)
                len_count[seg_len] = len_count.get(seg_len, 0) + 1
    print(len_count)
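seg_sentence() is not shown in these excerpts. Assuming it wraps a standard Chinese word segmenter such as jieba, a minimal stand-in could be:

import jieba

def seg_sentence(sentence):
    # hypothetical sketch: segment the sentence and drop empty tokens
    return [tok for tok in jieba.lcut(sentence) if tok.strip()]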
Code Example #20
def rewrite_data():
    data_path = get_data_path()
    file_in_path = os.path.join(data_path, 'qa_train/test_data')
    fout1_path = os.path.join(data_path, 'seq2seq_v4/test_ques')
    fout2_path = os.path.join(data_path, 'seq2seq_v4/test_rel')
    fout1 = open(fout1_path, 'wb')
    fout2 = open(fout2_path, 'wb')
    with open(file_in_path) as fin:
        for line in fin:
            ll = json.loads(line.decode('utf-8').strip())
            triple_list = ll['triple']
            ques = ll['question']
            triple = triple_list[0]
            sub, rel, _ = triple.split('|||')
            ques_pattern = ques.replace(sub, '_')
            print >> fout1, ques_pattern.encode('utf-8')
            print >> fout2, rel.encode('utf-8')
Code Example #21
def count_rel():
    data_path = get_data_path()
    fin_path = os.path.join(data_path, 'qa_train/test_data')
    fout_path = os.path.join(data_path, 'statistics/test_rel')
    fout = open(fout_path, 'wb')
    rel_dict = {}
    with open(fin_path) as fin:
        for line in fin:
            ll = json.loads(line.decode('utf-8').strip())
            triple_list = ll['triple']
            for triple in triple_list:
                _, rel, _ = triple.split('|||')
                rel_dict[rel] = rel_dict.get(rel, 0) + 1
    rel_dict = sorted(rel_dict.iteritems(), key=lambda x: x[1], reverse=True)
    for rel, count in rel_dict:
        print >> fout, ('%s: %d' % (rel, count)).encode('utf-8')
Code Example #22
def gen_train_cqa(min_score):
    triple_dict = read_triple_dict()
    print('loaded triple dict with %d entries' % len(triple_dict))

    data_path = get_data_path()
    data_dir_path = os.path.join(data_path, 'gen_qa/cqa_triple_match')
    file_list = os.listdir(data_dir_path)
    file_path_list = []
    for f in file_list:
        file_path_list.append(os.path.join(data_dir_path, f))
    fout_path = os.path.join(data_path,
                             ('qa_train/gen_qa_data_%s_v2' % min_score))
    fout = open(fout_path, 'wb')

    count = 0
    q_set = set()
    for file in file_path_list:
        with open(file) as fin:
            for line in fin:
                line = line.decode('utf-8').strip().split('\t')
                score = float(line[-1])
                if score <= min_score:
                    continue
                question = pre_question(line[0].strip())
                if question in q_set:
                    continue
                q_set.add(question)
                count += 1
                triple = '|||'.join((line[2], line[3], line[4]))
                if triple not in triple_dict:
                    continue
                adict = {'question': question,
                         'origin_triple': triple,
                         'triples': triple_dict[triple]}
                print >> fout, json.dumps(adict,
                                          encoding='utf-8',
                                          ensure_ascii=False)
    print(count)
    print(len(q_set))
Code Example #23
def check_triple(triple_file_name, file_out_name):
    count = 0
    data_path = get_data_path()
    triple_file = os.path.join(data_path, triple_file_name)
    out_file = os.path.join(data_path, file_out_name)
    fout = open(out_file, 'wb')
    with open(triple_file) as fin:
        for line in fin:
            # sub_name, rel, obj_name = line.decode('utf-8').strip().split('\t')
            sub_name, rel, obj_name = line.decode('utf-8').strip().split('|||')
            sub_name = del_brackets(sub_name)

            # escape single quotes so the name is safe inside the query string
            sub_name = sub_name.replace('\'', '\\\'')
            rel_dict_list = search_node_neo4j(sub_name)
            if not rel_dict_list:
                continue
            adict = dict()
            adict['reliable'] = []
            adict['possible'] = []
            adict['triple'] = '|||'.join((sub_name, rel, obj_name))
            for rel_dict in rel_dict_list:
                for neo4j_rel, value in rel_dict.iteritems():
                    neo4j_rel = ensure_unicode(neo4j_rel)
                    if neo4j_rel == 'description' or neo4j_rel == 'taglist' or not value:
                        continue
                    value = ensure_unicode(value)
                    if rel == neo4j_rel and obj_name in value:
                        adict['reliable'].append('|||'.join(
                            (sub_name, neo4j_rel, value)))
                    elif rel == neo4j_rel:
                        adict['possible'].append('|||'.join(
                            (sub_name, neo4j_rel, value)))
                    elif obj_name in value or value in obj_name:
                        adict['possible'].append('|||'.join(
                            (sub_name, neo4j_rel, value)))
            if adict['reliable'] or adict['possible']:
                count += 1
            print >> fout, json.dumps(adict,
                                      encoding='utf-8',
                                      ensure_ascii=False)
    print(count)
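search_node_neo4j() is called above but not included in these excerpts. Judging from the quote escaping and the list of per-node property dicts it returns, it likely interpolates the subject name into a Cypher query. A rough sketch with py2neo (the connection details and the 'name' property are assumptions):

from py2neo import Graph

_graph = Graph()  # assumed local Neo4j instance

def search_node_neo4j(sub_name):
    # hypothetical sketch: fetch every node with the given name and return
    # its properties as a dict (relation name -> value)
    cypher = u"MATCH (n) WHERE n.name = '%s' RETURN n" % sub_name
    return [dict(record['n']) for record in _graph.run(cypher)]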
Code Example #24
def get_triple_nlpcc(out_file_name):
    triple_set = set()
    data_path = get_data_path()
    fin_path = os.path.join(data_path, 'nlpcc_qa/ch.qatriple_all')
    with open(fin_path) as fin:
        for line in fin:
            line = line.decode('utf-8').strip()
            if '|||' in line and len(line.split('|||')) == 3:
                triple = [part.strip() for part in line.split('|||')]
                triple_set.add('|||'.join(triple))
    print(len(triple_set))
    fout_path = os.path.join(data_path, out_file_name)
    fout = open(fout_path, 'wb')
    for triple in triple_set:
        print >> fout, triple.encode('utf-8')
Code Example #25
def gen_qa_nlpcc():
    triple_dict = read_triple_dict()
    print('loaded triple dict with %d entries' % len(triple_dict))
    data_path = get_data_path()
    file_path = os.path.join(data_path, 'nlpcc_qa/ch.qatriple_all')

    fout_path = os.path.join(data_path, 'qa_train/nlpcc_qa_data')
    fout = open(fout_path, 'wb')

    with open(file_path) as fin:
        count = 1
        while True:
            line = fin.readline()
            if not line:
                break
            line = line.decode('utf-8').strip()
            if line.startswith(str(count)):
                question = line.split('\t')[1].strip()
                triple_list = []
                triple_set = set()
                new_line = fin.readline().decode('utf-8').strip()
                while new_line.strip():
                    triple = [part.strip() for part in new_line.split('|||')]
                    triple[0] = del_brackets(triple[0])
                    triple_set.add('|||'.join(triple))
                    new_line = fin.readline().decode('utf-8').strip()
                for triple in triple_set:
                    if triple in triple_dict:
                        triple_list.extend(triple_dict[triple])
                if triple_list:
                    adict = dict()
                    adict['question'] = pre_question(question)
                    adict['triple'] = list(set(triple_list))
                    print >> fout, json.dumps(adict,
                                              encoding='utf-8',
                                              ensure_ascii=False)
                count += 1
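del_brackets() (also used in Code Example #23) is not shown either. Entity names in these sources often carry a parenthesized disambiguation suffix, so a plausible stand-in (an assumption) is:

import re

def del_brackets(name):
    # hypothetical sketch: strip a trailing bracketed disambiguation suffix,
    # covering both ASCII and full-width parentheses
    return re.sub(u'[(（][^)）]*[)）]$', u'', name).strip()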
Code Example #26
File: ana_data.py  Project: xxx-git/ch_qa
def find_diff():
    data_path = get_data_path()
    file_path = os.path.join(data_path, 'seq2seq_v1/res.dev')
    question_dict = {}
    with open(file_path) as fin:
        for line in fin:
            ll = line.decode('utf-8').strip().split('|||')
            question = ll[0].strip()
            target_rel = ll[-1].strip()
            question_dict.setdefault(question, set()).add(target_rel)
    file_out_path = os.path.join(data_path, 'statistics/same_question_rel')
    fout = open(file_out_path, 'wb')
    for ques, rel_set in question_dict.iteritems():
        print >> fout, json.dumps({'question': ques, 'rel_set': list(rel_set)},
                                  encoding='utf-8', ensure_ascii=False)
Code Example #27
def merge_rel(rel_dict):
    data_path = get_data_path()
    rel_file = os.path.join(data_path, 'kb', 'all_pro_map_labeled.json')
    with open(rel_file) as fin:
        rel_map_dict = json.load(fin, encoding='utf-8')
    print('load rel succeeded!')
    new_rel_dict = dict()
    for rel, count in rel_dict.iteritems():
        # map each relation to its canonical name (or itself) and accumulate
        # with get(), so that several relations merging into the same target
        # are summed instead of overwriting one another
        target = rel_map_dict.get(rel, rel)
        new_rel_dict[target] = new_rel_dict.get(target, 0) + count
    return new_rel_dict
Code Example #28
File: get_ques.py  Project: xxx-git/ch_qa
def read_train_data():
    data_path = get_data_path()
    file1 = os.path.join(data_path, 'qa_train/nlpcc_qa_data_v2')
    file2 = os.path.join(data_path, 'qa_train/gen_qa_data_1.0_v2')
    rel_ques_dict = {}
    for f in [file1, file2]:
        with open(f) as fin:
            for line in fin:
                ll = json.loads(line.decode('utf-8').strip())
                triple_list = []
                ques = ll['question']
                if 'triple' in ll:
                    triple_list = ll['triple']
                elif 'triples' in ll:
                    triple_list = ll['triples']
                for triple in triple_list:
                    rel = triple.split('|||')[1]
                    rel_ques_dict.setdefault(rel, set()).add(ques)
    return rel_ques_dict
Code Example #29
def count_rel2(file_in_name, file_out_name):
    data_path = get_data_path()
    file_in = os.path.join(data_path, 'qa_train', file_in_name)
    file_out = os.path.join(data_path, 'statistics', file_out_name)
    rel_dict = dict()
    rel_set = set()
    with open(file_in) as fin:
        for line in fin:
            line_dict = json.loads(line.decode('utf-8').strip())
            triple_list = line_dict['triple']
            for triple in triple_list:
                rel = triple.split('|||')[1]
                rel_set.add(rel)
                rel_dict[rel] = rel_dict.get(rel, 0) + 1
    rel_dict = sorted(rel_dict.items(), key=lambda d: d[1], reverse=True)
    print(len(rel_set))
    fout = open(file_out, 'wb')
    for rel, count in rel_dict:
        print >> fout, ('%s : %d' % (rel, count)).encode('utf-8')
Code Example #30
import numpy as np
from six.moves import xrange
import tensorflow as tf

import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
# Python 2-era workaround: make utf-8 the default codec for implicit
# str/unicode conversions
reload(sys)
sys.setdefaultencoding('utf-8')
from global_path import get_data_path
from config import LargeConfig, MediumConfig

import data_utils
import seq2seq_model

project_data_path = get_data_path()
data_path = os.path.join(project_data_path, 'seq2seq_v4')
train_dir = os.path.join(data_path, "ckpt")
if not os.path.exists(train_dir):
    os.makedirs(train_dir)

# We use a number of buckets and pad to the closest one for efficiency.
# See seq2seq_model.Seq2SeqModel for details of how they work.
# Here the source (question) is padded to 31 tokens and the target
# (relation) to 3.
# buckets for tensorflow 1.0:
# buckets = [(4, 3), (5, 3), (6, 3), (8, 3), (40, 3)]
# tensorflow 1.2+
buckets = [(31, 3)]
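For reference, this is how a bucket list like the one above is consumed in the TensorFlow seq2seq tutorial this script follows: each (question, relation) pair goes into the smallest bucket it fits and is padded to that bucket's sizes. A sketch (not part of the original file):

def choose_bucket(source_ids, target_ids):
    # hypothetical sketch: pick the smallest bucket whose (source, target)
    # limits fit the pair; fall back to the largest bucket otherwise
    for bucket_id, (source_size, target_size) in enumerate(buckets):
        if len(source_ids) < source_size and len(target_ids) < target_size:
            return bucket_id
    return len(buckets) - 1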