Example #1
def get_candidate_sub(question_tokens, pred_tag):
    flag = False
    starts = []
    ends = []
    for i, tag in enumerate(pred_tag):
        if tag == 1 and not flag:
            starts.append(i)
            flag = True
        elif tag == 0 and flag:
            # close the span only if the next tag is also 0 (or we hit the
            # end); a single stray 0 inside a span is tolerated
            if (i + 1 < len(question_tokens)
                    and pred_tag[i + 1] == 0) or i + 1 == len(question_tokens):
                ends.append(i - 1)
                flag = False
    if flag:
        ends.append(len(question_tokens) - 1)

    sub_list = []
    shift = [0, -1, 1, -2, 2]  # try the exact span first, then widened/shifted spans
    pred_sub = []
    for left in shift:
        for right in shift:
            for i in range(len(starts)):
                if starts[i] + left < 0:
                    continue
                if ends[i] + 1 + right > len(question_tokens):
                    continue
                text = question_tokens[starts[i] + left:ends[i] + 1 + right]
                subject = virtuoso.str_query_id(' '.join(text))
                if left == 0 and right == 0:
                    pred_sub = subject
                sub_list.extend(subject)
            if sub_list:
                return pred_sub, sub_list
    return pred_sub, sub_list
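A minimal usage sketch (hypothetical input; virtuoso.str_query_id is assumed to return a possibly empty list of KB ids for an exact name or alias string):

tokens = ['who', 'wrote', 'the', 'great', 'gatsby']
tags = [0, 0, 0, 1, 1]   # entity-detection output: 1 marks subject tokens
pred_sub, sub_list = get_candidate_sub(tokens, tags)
# pred_sub: ids matching 'great gatsby' exactly; sub_list: ids from the first
# shifted window (e.g. 'the great gatsby') that returned any match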
Example #2
def reverse_link(question, subject):
    tokens = word_tokenize(question)
    text_attention_indices = get_text_attention_indices(question, subject)
    text_subject = None
    if not text_attention_indices:
        # the subject string itself did not match; try its KB aliases
        sub_ids = virtuoso.str_query_id(subject)
        sub_aliases = []
        for sub_id in sub_ids:
            sub_aliases.extend(virtuoso.id_query_alias(sub_id))
        for sub_a in sub_aliases:
            text_attention_indices = get_text_attention_indices(question, sub_a)
            if text_attention_indices:
                break
    if text_attention_indices is None and subject in alias_dict:
        # fall back to a hand-made alias dictionary, plus a naive plural form
        subs = [subject + 's']
        subject = alias_dict[subject]
        if isinstance(subject, list):
            subs.extend(subject)
        else:
            subs.append(subject)
        for sub in subs:
            text_attention_indices = get_text_attention_indices(question, sub)
            if text_attention_indices is not None:
                break

    if text_attention_indices:
        text_attention_indices = query_golden_subs(question, text_attention_indices)
        text_subject = " ".join(tokens[text_attention_indices[0]:text_attention_indices[-1] + 1])

    return text_subject, text_attention_indices, form_question_pattern(text_attention_indices, question.lower())
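A hedged usage sketch; get_text_attention_indices, alias_dict, and form_question_pattern are module-level helpers this excerpt does not show, so the pattern placeholder below is only illustrative:

text_subject, indices, pattern = reverse_link(
    'where was barack obama born', 'barack obama')
# text_subject -> 'barack obama', indices -> [2, 3]
# pattern is the question with the subject span masked out, in whatever
# convention form_question_pattern uses (e.g. 'where was <e> born')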
Example #3
def extract_web(val):
    objs = []
    objs_id = []
    for answer in val['answers']:
        objs.append(answer.lower())
        try:
            objs_id.append(virtuoso.str_query_id(answer))
        except Exception:
            print("HTTP 400:")
            print(val)

    sub = val['freebaseKey'].replace("_", " ").lower()
    if sub in err_dict:
        sub = err_dict[sub]

    rels = val["relPaths"]
    if rels:
        # keep only the highest-scored relation path
        rels = sorted(rels, key=lambda rel: rel[1], reverse=True)
        rels = rels[0][0]
    new_rels = []
    for rel in rels:
        new_rels.append("fb:" + rel.replace("\\", "."))
    rels = new_rels

    question = val["qText"].lower()

    text_subject, text_attention_indices, question_pattern = reverse_link(question, sub)
    if text_attention_indices is None or text_subject is None or question_pattern is None:
        print("'%s':'%s'" % (sub, question))
        if sub == "maya civilization":
            sub = 'mayans'
    return objs, sub, rels, question, text_subject, text_attention_indices, question_pattern
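The relation-path rewrite above turns a backslash-separated Freebase path into the dotted, fb:-prefixed form used throughout these examples. A worked example (the path and score are hypothetical):

# val["relPaths"] = [[['people\\person\\place_of_birth'], 0.9], ...]
# rels[0][0] keeps the top-scored path; each hop is then rewritten:
#   'people\\person\\place_of_birth' -> 'fb:people.person.place_of_birth'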
Example #4
def predict_subject_ids(paths, tokens):
    # single sentence
    pred_ids = []
    for tags in paths:
        # count contiguous spans of 1s: a span starts where a 1 follows a 0 (or BOS)
        n_subjects = sum(tags[i] == 1 and (i == 0 or tags[i - 1] == 0)
                         for i in range(len(tags)))
        if n_subjects == 1:
            subject_name = ' '.join(tokens[i] for i, tag in enumerate(tags) if tag == 1)
            pred_ids.extend(virtuoso.str_query_id(subject_name))
    return pred_ids
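A minimal sketch of the single-span case this function accepts (inputs are hypothetical model outputs):

tokens = ['who', 'directed', 'pulp', 'fiction']
paths = [[0, 0, 1, 1],   # one contiguous span of 1s -> queried
         [1, 0, 1, 0]]   # two spans -> skipped
ids = predict_subject_ids(paths, tokens)  # ids matching 'pulp fiction', if any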
Example #5
def query_golden_subs(question, text_attention_indices):
    tokens = word_tokenize(question)
    # widen the span by up to 2 tokens on either side until some id matches
    for i in [-2, -1, 0]:
        for j in [2, 1, 0]:
            s = text_attention_indices[0] + i
            e = text_attention_indices[-1] + j + 1
            if s < 0 or e > len(tokens):
                continue
            tmp_sub = ' '.join(tokens[s:e])
            tmp_sub_ids = virtuoso.str_query_id(tmp_sub)
            if tmp_sub_ids:
                return list(range(s, e))
    return text_attention_indices
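The loop order matters: (i, j) starts at the widest window (-2, +2) and shrinks toward the original span, so the widest window returning any id wins. A worked example:

# text_attention_indices = [3, 4] in a 10-token question:
#   windows tried: tokens[1:7], tokens[1:6], tokens[1:5],
#                  tokens[2:7], ..., down to tokens[3:5] (the original span)
#   the first window whose text matches some id becomes the new span;
#   otherwise the original indices are returned unchanged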
Example #6
def create_seq_ranking_data(batch_size, qa_data, word_vocab, rel_vocab):
    file_type = qa_data.split('.')[-2]
    seqs = []
    pos_rel = []
    neg_rel = []
    neg_rel_size = []
    batch_index = -1  # the index of sequence batches
    seq_index = 0  # sequence index within each batch
    pad_index = word_vocab.lookup(word_vocab.pad_token)

    data_list = pickle.load(open(qa_data, 'rb'))
    for data in data_list:
        tokens = data.question.split()
        # use the relations of every subject sharing this name as negative samples
        can_subs = virtuoso.str_query_id(data.text_subject)
        can_rels = []
        for sub in can_subs:
            can_rels.extend(virtuoso.id_query_out_rel(sub))
        can_rels = list(set(can_rels))  # deduplicate relations

        if seq_index % batch_size == 0:
            seq_index = 0
            batch_index += 1
            # note: the batch tensor is sized by the first question in the
            # batch, so later questions in the batch must not be longer
            seqs.append(
                torch.LongTensor(len(tokens), batch_size).fill_(pad_index))
            pos_rel.append(torch.LongTensor(batch_size).fill_(pad_index))
            neg_rel.append([])
            neg_rel_size.append([])
            print('batch: %d' % batch_index)

        seqs[batch_index][0:len(tokens), seq_index] = torch.LongTensor(
            word_vocab.convert_to_index(tokens))
        pos_rel[batch_index][seq_index] = rel_vocab.lookup(data.relation)
        neg_rel[batch_index].append(rel_vocab.convert_to_index(can_rels))
        neg_rel_size[batch_index].append(len(can_rels))
        seq_index += 1

    torch.save((seqs, pos_rel, neg_rel, neg_rel_size),
               'data/%s.relation_ranking.pt' % file_type)
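A note on the batch layout these tensors assume (shapes read off the code above):

# seqs[b]         : LongTensor(len(first question in batch), batch_size) -- time-major
# pos_rel[b]      : LongTensor(batch_size)          -- gold relation ids
# neg_rel[b]      : list of per-question id lists   -- candidate (negative) relations
# neg_rel_size[b] : list of ints, one per question  -- length of each candidate list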
Example #7
def query_golden_subs(data):
    golden_subs = []
    if data.text_subject:
        # extract fields needed
        relation = data.relation
        text_subject = data.text_subject

        # query candidate ids by the subject's surface text
        candi_sub_list = virtuoso.str_query_id(text_subject)

        # keep candidates whose outgoing relations contain the gold relation
        for candi_sub in candi_sub_list:
            candi_rel_list = virtuoso.id_query_out_rel(candi_sub)
            if relation in candi_rel_list:
                golden_subs.append(candi_sub)

    if len(golden_subs) == 0:
        golden_subs = [data.subject]

    return golden_subs
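A hedged usage sketch, with field values invented for illustration:

# data.text_subject = 'barack obama'
# data.relation     = 'fb:people.person.place_of_birth'
# data.subject      = 'fb:m.02mjmr'
golden = query_golden_subs(data)
# -> every candidate id for 'barack obama' whose outgoing relations include
#    the gold relation; falls back to [data.subject] if none qualifies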
Example #8
def query_candidate(data_list, pred_list, logName):
    log_file = open(logName, 'w')
    new_data_list = []

    NoneMatch = 0
    succ_match = 0
    data_index = 0
    for pred, data in zip(pred_list, data_list):
        # extract scores
        scores = np.array([int(float(score)) for score in pred.decode().strip().split()])

        # extract fields needed
        relation = data.relation
        subject = data.subject
        question = data.question
        text_attention_indices = data.text_attention_indices
        if not text_attention_indices:
            continue
        # increment data_index
        data_index += 1
        tokens = np.array(question.split())

        # crop the predicted subject span and query candidate ids by its text
        beg_idx, end_idx = beg_end_indices(scores, 0.2)
        tokens_crop = tokens[beg_idx:end_idx + 1]
        sub_text = ' '.join(tokens_crop)

        candi_sub_list = []
        candi_sub_list.extend(virtuoso.str_query_id(sub_text))
        if '' in candi_sub_list:
            candi_sub_list.remove('')

        if candi_sub_list:
            data.set_strict_flag(True)
        else:
            # exact match failed: fall back to a partial-string query
            data.set_strict_flag(False)
            candi_sub_list.extend(virtuoso.partial_str_query_id(sub_text))
            if '' in candi_sub_list:
                candi_sub_list.remove('')
        if not candi_sub_list:
            # still nothing: back off to shorter and shorter n-grams of the span
            for i in range(len(tokens_crop) - 1, 1, -1):
                for gram in generate_ngrams(tokens_crop, i):
                    idList = virtuoso.str_query_id(gram)
                    if '' in idList:
                        idList.remove('')
                    candi_sub_list.extend(idList)
                if candi_sub_list:
                    break

        candi_sub_list = list(set(candi_sub_list))  # [string, string, ...]

        # if a potential subject was found
        if len(candi_sub_list) > 0:
            # add candidates to data; `fb` and `type_dict` are module-level
            # globals in this project
            for candi_sub in candi_sub_list:
                candi_rel_list = virtuoso.id_query_out_rel(fb, candi_sub)
                candi_rel_list = list(set(candi_rel_list))
                if '' in candi_rel_list:
                    candi_rel_list.remove('')
                if len(candi_rel_list) > 0:
                    if type_dict:
                        candi_type_list = [type_dict[t] for t in virtuoso.id_query_type(candi_sub)
                                           if t in type_dict]
                        if len(candi_type_list) == 0:
                            candi_type_list.append(len(type_dict))
                        data.add_candidate(candi_sub, candi_rel_list, candi_type_list)
                    else:
                        data.add_candidate(candi_sub, candi_rel_list)

        if hasattr(data, 'cand_sub') and hasattr(data, 'cand_rel'):  # candidates were recalled
            # remove duplicate relations
            data.remove_duplicate()
        else:
            NoneMatch += 1

        data.anonymous_question = form_anonymous_quesion(question, beg_idx, end_idx)
        new_data_list.append(data)

        # logging information
        if subject in candi_sub_list:
            succ_match += 1

        if data_index % 100 == 0:
            print('{0} / {1}: {2} / {3} = {4}'.format(
                data_index, len(data_list), succ_match, data_index,
                succ_match / float(data_index)))

    log_file.write('{0} {1} {2}\n'.format(succ_match, data_index, NoneMatch))
    log_file.write('{0} / {1} = {2}\n'.format(succ_match, data_index,
                                              succ_match / float(data_index)))
    log_file.write('not matched number is {0}\n'.format(NoneMatch))

    log_file.close()
    return new_data_list
Example #9
def query_candidate(data_list, pred_list, pid=0):
    log_file = open('logs/log.%d.txt' % pid, 'w')
    new_data_list = []

    succ_match = 0
    data_index = 0
    for pred, data in zip(pred_list, data_list):
        # increment data_index
        data_index += 1

        # extract scores
        scores = [float(score) for score in pred.strip().split()]

        # extract fields needed
        relation = data.relation
        subject = data.subject
        question = data.question
        tokens = question.split()

        # query candidate ids, lowering the span threshold until some id matches
        candi_sub_list = []
        for threshold in np.arange(0.5, 0.0, -0.095):
            beg_idx, end_idx = beg_end_indices(scores, threshold)
            sub_text = ' '.join(tokens[beg_idx:end_idx + 1])
            candi_sub_list.extend(virtuoso.str_query_id(sub_text))
            if len(candi_sub_list) > 0:
                break

        # if a potential subject was found
        if len(candi_sub_list) > 0:
            # add candidates to data
            for candi_sub in candi_sub_list:
                candi_rel_list = virtuoso.id_query_out_rel(candi_sub)
                if len(candi_rel_list) > 0:
                    if type_dict:
                        candi_type_list = [
                            type_dict[t]
                            for t in virtuoso.id_query_type(candi_sub)
                            if t in type_dict
                        ]
                        if len(candi_type_list) == 0:
                            candi_type_list.append(len(type_dict))
                        data.add_candidate(candi_sub, candi_rel_list,
                                           candi_type_list)
                    else:
                        data.add_candidate(candi_sub, candi_rel_list)
            data.anonymous_question = form_anonymous_quesion(
                question, beg_idx, end_idx)

        if hasattr(data, 'cand_sub') and hasattr(data, 'cand_rel'):
            # remove duplicate relations
            data.remove_duplicate()

            # append to new_data_list
            new_data_list.append(data)

        # logging information
        if subject in candi_sub_list:
            succ_match += 1

        if data_index % 100 == 0:
            print('[%d] %d / %d' % (pid, data_index, len(data_list)),
                  file=sys.stderr)

    log_file.write('%d / %d = %f\n' % (succ_match, data_index,
                                       succ_match / float(data_index)))

    log_file.close()
    pickle.dump(new_data_list, open('temp.%d.cpickle' % pid, 'wb'))
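Both versions rely on a helper, beg_end_indices(scores, threshold), that these excerpts never define. A plausible reconstruction, offered purely as an assumption about its contract (first and last positions whose score clears the threshold):

def beg_end_indices(scores, threshold):
    # hypothetical sketch -- not taken from the original projects
    idx = [i for i, s in enumerate(scores) if s >= threshold]
    if not idx:
        return 0, len(scores) - 1  # fall back to the whole sequence
    return idx[0], idx[-1]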
Example #10
def create_seq_ranking_data(qa_data, word_vocab, rel_sep_vocab, rel_vocab,
                            save_path):
    seqs = []
    seq_len = []
    pos_rel1 = []
    pos_rel2 = []
    neg_rel1 = []
    neg_rel2 = []
    batch_index = -1  # the index of sequence batches
    seq_index = 0  # sequence index within each batch
    pad_index = word_vocab.lookup(word_vocab.pad_token)

    data_list = pickle.load(open(qa_data, 'rb'))

    def get_separated_rel_id(relation):
        # split a dotted relation into its type path and its final segment
        rel = relation.split('.')
        rel1 = '.'.join(rel[:-1])
        rel2 = rel[-1]
        rel1_id = rel_sep_vocab[0].lookup(rel1)
        rel2_id = rel_sep_vocab[1].lookup(rel2)
        return rel1_id, rel2_id

    for data in data_list:
        tokens = data.question_pattern.split()
        can_rels = []
        if hasattr(data, 'cand_sub') and data.subject in data.cand_sub:
            can_rels = data.cand_rel
        else:
            can_subs = virtuoso.str_query_id(data.text_subject)
            for sub in can_subs:
                can_rels.extend(virtuoso.id_query_out_rel(sub))
            can_rels = list(set(can_rels))
        if data.relation in can_rels:
            can_rels.remove(data.relation)
        # pad the negative pool with random relations until it reaches neg_size
        for i in range(len(can_rels), args.neg_size):
            tmp = random.randint(2, len(rel_vocab) - 1)
            while rel_vocab.index2word[tmp] in can_rels:
                tmp = random.randint(2, len(rel_vocab) - 1)
            can_rels.append(rel_vocab.index2word[tmp])
        can_rels = random.sample(can_rels, args.neg_size)

        if seq_index % args.batch_size == 0:
            seq_index = 0
            batch_index += 1
            seqs.append(
                torch.LongTensor(args.batch_size,
                                 len(tokens)).fill_(pad_index))
            seq_len.append(torch.LongTensor(args.batch_size).fill_(1))
            pos_rel1.append(torch.LongTensor(args.batch_size).fill_(pad_index))
            pos_rel2.append(torch.LongTensor(args.batch_size).fill_(pad_index))
            neg_rel1.append(
                torch.LongTensor(args.neg_size, args.batch_size).fill_(pad_index))
            neg_rel2.append(
                torch.LongTensor(args.neg_size, args.batch_size).fill_(pad_index))
            print('batch: %d' % batch_index)

        seqs[batch_index][seq_index, 0:len(tokens)] = torch.LongTensor(
            word_vocab.convert_to_index(tokens))
        seq_len[batch_index][seq_index] = len(tokens)

        pos1, pos2 = get_separated_rel_id(data.relation)
        pos_rel1[batch_index][seq_index] = pos1
        pos_rel2[batch_index][seq_index] = pos2

        for j, neg_rel in enumerate(can_rels):
            neg1, neg2 = get_separated_rel_id(neg_rel)
            if not neg1 or not neg2:
                continue  # leave the pad index in place for unknown relations
            neg_rel1[batch_index][j, seq_index] = neg1
            neg_rel2[batch_index][j, seq_index] = neg2

        seq_index += 1

    torch.save((seqs, seq_len, pos_rel1, pos_rel2, neg_rel1, neg_rel2),
               save_path)
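The prefix/suffix split in get_separated_rel_id keeps a relation's type path and final segment in separate vocabularies. For example:

# 'fb:people.person.place_of_birth'
#   rel1 -> 'fb:people.person'   (looked up in rel_sep_vocab[0])
#   rel2 -> 'place_of_birth'     (looked up in rel_sep_vocab[1])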
Example #11
File: zyk.py Project: zyksir/kbqa
import torch
import pickle
import sys
sys.path.append("../../code/tools")
import virtuoso

qa_data_path = '../../../kbqa/entity_detection/results/QAData.label.valid.pkl'
data_list = pickle.load(open(qa_data_path, "rb"))
for data in data_list:
    print(data.__dict__)
    data.cand_entities = virtuoso.str_query_id(data.text_subject)
    if not data.cand_entities:
        print(data.__dict__)
        break
Example #12
    for relation in old_relations:
        new_relations.append("fb:" + relation.replace("/", ".")[1:])
    return new_relations


split = "val"
dict_list = load_multi_data(
    split, ["main", "d-freebase-rp", "d-freebase", "d-freebase-mids"])
data_list = []
for qid, dic in dict_list.items():
    object_ids = []
    objects_text = []
    for answer in dic['answers']:
        objects_text.append(answer.lower())
        try:
            object_ids.append(virtuoso.str_query_id(answer))
        except Exception:
            print("HTTP 400:")
            print(dic["answers"])

    subject = dic['freebaseKey'].replace("_", " ").lower()
    if subject in err_dict:
        subject = err_dict[subject]

    question = dic["qText"].lower()
    text_subject, text_attention_indices, question_pattern = reverse_link(
        question, subject)

    subject_id = search_subject_id(dic["freebaseMids"], subject, text_subject)

    relations = dic["relPaths"]