Example #1
def query_candidate(data_list, pid=0):
    log_file = open('logs/log.round13.%d.txt' % pid, 'w')
    new_data_list = []
    data_index = 0
    NoneMatch = 0
    maxRelLen = 0
    for data in data_list:
        # increment data_index
        data_index += 1
        # extract fields needed
        relation = data.relation
        subject  = data.subject
        question = data.question
        ANquestion = data.anonymous_question
        if len(question.split()) > 1 and ANquestion:
            # query name / alias by subject (id)
            candi_rel_list = []
            candi_rel_list.extend(virtuoso.id_query_out_rel(subject))
           
            candi_rel_list = list(set(candi_rel_list))  # [string, string, ...]
            if '' in candi_rel_list:
                candi_rel_list.remove('')
            data.add_candidate(subject, candi_rel_list)
            if relation in candi_rel_list:
                new_data_list.append(data)
                if len(candi_rel_list) > maxRelLen:
                    maxRelLen = len(candi_rel_list)
            else:
                NoneMatch += 1
                # log_file.write('%s\n' % question)
            
    print('not matched number is %d' % NoneMatch)
    print('maximum candidate relation number is %d' % maxRelLen)
    log_file.close()
    pickle.dump(new_data_list, open('temp.%d.cpickle' % pid, 'wb'))
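Note: the `virtuoso` module used throughout these examples is project-specific and is not shown on this page. As a rough point of reference only, a minimal sketch of what a helper like `id_query_out_rel` could look like against a Virtuoso SPARQL endpoint is given below; the endpoint URL, the `fb:` prefix handling, and the exact return format are assumptions.

# Hypothetical sketch only; the real virtuoso module used on this page may differ.
from SPARQLWrapper import SPARQLWrapper, JSON

FB_PREFIX = 'http://rdf.freebase.com/ns/'
ENDPOINT = 'http://localhost:8890/sparql'   # assumed local Virtuoso endpoint

def id_query_out_rel(subject_id):
    # Outgoing relations (predicates) of a Freebase MID such as 'fb:m.0f2y0'.
    sparql = SPARQLWrapper(ENDPOINT)
    sparql.setQuery('PREFIX fb: <%s> SELECT DISTINCT ?rel WHERE { %s ?rel ?obj . }'
                    % (FB_PREFIX, subject_id))
    sparql.setReturnFormat(JSON)
    bindings = sparql.query().convert()['results']['bindings']
    # The project code presumably maps full URIs back to the short 'fb:...' form.
    return [row['rel']['value'].replace(FB_PREFIX, 'fb:') for row in bindings]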
Example #2
def search_subject_id(concept_list, subject, relations):
    subject_id_list = []
    for concept in concept_list:
        # name = concept["concept"].lower()
        # name = name.replace(",", "").replace(".", "")
        # if name == subject or name == text_subject:
        #     return concept["mid"]
        mid = "fb:" + concept["mid"]
        if relations[0] in virtuoso.id_query_out_rel(mid):
            subject_id_list.append(mid)
    return subject_id_list
Example #3
def create_seq_ranking_data(batch_size, qa_data, word_vocab, rel_vocab):
    file_type = qa_data.split('.')[-2]
    #    log_file = open('data/%s.relation_ranking.txt' %file_type, 'w')
    seqs = []
    pos_rel = []
    neg_rel = []
    neg_rel_size = []
    batch_index = -1  # the index of sequence batches
    seq_index = 0  # sequence index within each batch
    pad_index = word_vocab.lookup(word_vocab.pad_token)

    data_list = pickle.load(open(qa_data, 'rb'))
    for data in data_list:
        tokens = data.question.split()
        # use the relations linked to every subject sharing the same name as negative samples
        can_subs = virtuoso.str_query_id(data.text_subject)
        can_rels = []
        for sub in can_subs:
            can_rels.extend(virtuoso.id_query_out_rel(sub))
        can_rels = list(set(can_rels))  # remove duplicate relations
        #        log_file.write('%s\t%s\t%s\n' %(data.question, data.relation, can_rels))

        if seq_index % batch_size == 0:
            seq_index = 0
            batch_index += 1
            seqs.append(
                torch.LongTensor(len(tokens), batch_size).fill_(pad_index))
            pos_rel.append(torch.LongTensor(batch_size).fill_(pad_index))
            neg_rel.append([])
            neg_rel_size.append([])
            print('batch: %d' % batch_index)

        seqs[batch_index][0:len(tokens), seq_index] = torch.LongTensor(
            word_vocab.convert_to_index(tokens))
        pos_rel[batch_index][seq_index] = rel_vocab.lookup(data.relation)
        neg_rel[batch_index].append(rel_vocab.convert_to_index(can_rels))
        neg_rel_size[batch_index].append(len(can_rels))
        seq_index += 1

    torch.save((seqs, pos_rel, neg_rel, neg_rel_size),
               'data/%s.relation_ranking.pt' % file_type)
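For orientation, the tensors saved by `create_seq_ranking_data` above can be reloaded roughly as follows (the `train` file type in the path is only an assumption):

# Illustrative only: reload the batches written by create_seq_ranking_data.
import torch

seqs, pos_rel, neg_rel, neg_rel_size = torch.load('data/train.relation_ranking.pt')
for batch_idx in range(len(seqs)):
    question_batch = seqs[batch_idx]        # LongTensor, shape (num_tokens, batch_size)
    positive_rels = pos_rel[batch_idx]      # LongTensor, shape (batch_size,)
    negative_rels = neg_rel[batch_idx]      # list of per-question lists of relation ids
    negative_sizes = neg_rel_size[batch_idx]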
Example #4
def query_golden_subs(data):
    golden_subs = []
    if data.text_subject:
        # extract fields needed
        relation     = data.relation
        subject      = data.subject
        text_subject = data.text_subject
        
        # query name / alias by subject (id)
        candi_sub_list = virtuoso.str_query_id(text_subject)

        # add candidates to data
        for candi_sub in candi_sub_list:
            candi_rel_list = virtuoso.id_query_out_rel(candi_sub)
            if relation in candi_rel_list:
                golden_subs.append(candi_sub)

    if len(golden_subs) == 0:
        golden_subs = [data.subject]

    return golden_subs
Example #5
def query_candidate(data_list, pred_list, logName):
    log_file = open(logName, 'w')
    new_data_list = []

    NoneMatch = 0
    succ_match = 0
    data_index = 0
    for pred, data in zip(pred_list, data_list):
        # extract scores
        scores = np.array([int(float(score)) for score in pred.decode().strip().split()])

        # extract fields needed
        relation = data.relation
        subject  = data.subject
        question = data.question
        text_attention_indices = data.text_attention_indices
        if not text_attention_indices:
            continue
        # increment data_index
        data_index += 1            
        # print([question])
        tokens   = np.array(question.split())

        # query name / alias by subject (id)
        candi_sub_list = []
        # for threshold in np.arange(0.5, 0.0, -0.095):
        #     beg_idx, end_idx = beg_end_indices(scores, threshold)
        #     sub_text = ' '.join(tokens[beg_idx:end_idx+1])
        #     candi_sub_list.extend(virtuoso.str_query_id(sub_text))
        #     if len(candi_sub_list) > 0:
        #         break

        beg_idx, end_idx = beg_end_indices(scores, 0.2)
        tokens_crop = tokens[beg_idx:end_idx+1]
        sub_text = ' '.join(tokens_crop)
        text_list = []
        # for i in [1,-1,2,-2]:
        #     if beg_idx-i>=0 and beg_idx-i<=end_idx:
        #         tokens_crop2=tokens[beg_idx-i:end_idx+1]
        #         text_list.append(' '.join(tokens_crop2))
        #     if end_idx+i<seq_len and end_idx+i>=beg_idx:
        #         tokens_crop2=tokens[beg_idx:end_idx+i+1]
        #         text_list.append(' '.join(tokens_crop2))
        #         if beg_idx-i>=0 and beg_idx-i<=end_idx+i:
        #             tokens_crop2=tokens[beg_idx-i:end_idx+i+1]
        #             text_list.append(' '.join(tokens_crop2))
        candi_sub_list.extend(virtuoso.str_query_id(sub_text))
        if '' in candi_sub_list:
            candi_sub_list.remove('')

            # if candi_sub_list==[]:
            #     for text in text_list:
            #         candi_sub_list.extend(virtuoso.str_query_id(text))
            #         if '' in candi_sub_list:
            #             candi_sub_list.remove('')
            #         if candi_sub_list!=[]:
            #             break
        if candi_sub_list:
            data.set_strict_flag(True)
        else:
            data.set_strict_flag(False)
            candi_sub_list.extend(virtuoso.partial_str_query_id(sub_text))
            if '' in candi_sub_list:
                candi_sub_list.remove('')
        if not candi_sub_list:
            for n in range(len(tokens_crop) - 1, 1, -1):
                tempList = generate_ngrams(tokens_crop, n)
                for ngram in tempList:
                    idList = virtuoso.str_query_id(ngram)
                    if '' in idList:
                        idList.remove('')
                    candi_sub_list.extend(idList)
                if candi_sub_list:
                    break

        candi_sub_list = list(set(candi_sub_list))  # [string, string, ...]
        # if '' in candi_sub_list:
        #     candi_sub_list.remove('')
        # using freebase suggest
        # if len(candi_sub_list) == 0:
        #     beg_idx, end_idx = beg_end_indices(scores, 0.2)
        #     sub_text = ' '.join(tokens[beg_idx:end_idx+1])
        #     sub_text = re.sub(r'\s(\w+)\s(n?\'[tsd])\s', r' \1\2 ', sub_text)
        #     suggest_subs = []
        #     for trial in range(3):
        #         try:
        #             suggest_subs = freebase.suggest_id(sub_text)
        #             print >> log_file, str(suggest_subs)
        #             break
        #         except:
        #             print >> sys.stderr, 'freebase suggest_id error: trial = %d, sub_text = %s' % (trial, sub_text)
        #     candi_sub_list.extend(suggest_subs)
            # if subject not in candi_sub_list:
            #     print >> log_file, '%s' % (str(question))

        # if a potential subject is found
        if len(candi_sub_list) > 0:
            # add candidates to data
            countarry = np.zeros(len(candi_sub_list))
            for idx, candi_sub in enumerate(candi_sub_list):
                candi_rel_list = virtuoso.id_query_out_rel(candi_sub)
                candi_rel_list = list(set(candi_rel_list))
                if '' in candi_rel_list:
                    candi_rel_list.remove('')
                if len(candi_rel_list) > 0:
                    if type_dict:
                        candi_type_list = [type_dict[t] for t in virtuoso.id_query_type(candi_sub) if t in type_dict]
                        if len(candi_type_list) == 0:
                            candi_type_list.append(len(type_dict))
                        data.add_candidate(candi_sub, candi_rel_list, candi_type_list)
                    else:
                        data.add_candidate(candi_sub, candi_rel_list)

                        # countarry[idx] = virtuoso.id_query_count(candi_sub)
                        # if '' in text:
                        #     text.remove('')
                        # if len(text) > 0:
                        #     data.add_candidate(candi_sub, candi_rel_list)
                        #     data.add_sub_text(text)
        # data.add_node_score(countarry)
            # make score mat
        if hasattr(data, 'cand_sub') and hasattr(data, 'cand_rel'):  ## keep entries that recalled candidates
            # remove duplicate relations
            data.remove_duplicate()
        else:
            NoneMatch += 1
       
        data.anonymous_question = form_anonymous_quesion(question, beg_idx, end_idx)
        # append to new_data_list
        new_data_list.append(data)
        # elif save_all:
        #     new_data_list.append(data)
                
        # logging information
        if subject in candi_sub_list:
            succ_match += 1

        if data_index % 100 == 0:
            print('{0} / {1}: {2} / {3} = {4}'.format(data_index, len(data_list), succ_match, data_index, succ_match / float(data_index)))
           
    log_file.write('{0} {1} {2}\n'.format(succ_match, data_index, NoneMatch))
    log_file.write('{0} / {1} = {2}\n'.format(succ_match, data_index, succ_match / float(data_index)))
    log_file.write('not matched number is {0}\n'.format(NoneMatch))

    log_file.close()
    return new_data_list
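`beg_end_indices` and `generate_ngrams` are helpers defined elsewhere in the project. Minimal sketches consistent with how they are used above (a thresholded token span, and contiguous n-grams of the cropped tokens) might look like this; the exact behavior of the real helpers is an assumption:

# Hypothetical sketches; the real project helpers may differ.
import numpy as np

def beg_end_indices(scores, threshold):
    # First and last token index whose labeling score exceeds the threshold.
    above = np.where(np.asarray(scores) > threshold)[0]
    if len(above) == 0:
        return 0, len(scores) - 1
    return above[0], above[-1]

def generate_ngrams(tokens, n):
    # All contiguous n-grams of the token sequence, joined back into strings.
    return [' '.join(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]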
Example #6
def predict(dataset=args.test_file, tp='test', save_qadata=args.save_qadata):
    # load QAdata
    qa_data_path = '../data/QAData.%s.pkl' % tp
    qa_data = pickle.load(open(qa_data_path,'rb'))

    # load batch data for predict
    data_loader = SeqLabelingLoader(dataset, args.batch_size)
    print('load %s data, batch_num: %d\tbatch_size: %d'
            %(tp, data_loader.batch_num, data_loader.batch_size))

    model.eval()

    n_correct = 0
    n_correct_sub = 0
    n_correct_extend = 0
    n_empty = 0
    n_cand_entity = 0
    linenum = 1
    qa_data_idx = 0

    new_qa_data = []

    gold_list = []
    pred_list = []
    compare_pred = []

    single_correct = 0
    total = 0
    EDdata = torch.load(dataset)

    batches = [EDdata[i * args.batch_size:(i + 1) * args.batch_size]
               for i in range(math.ceil(len(EDdata) / args.batch_size))]
    for data_batch_idx, data_batch in enumerate(batches):
        if data_batch_idx % 50 == 0:
            print(tp, data_batch_idx)
        seqs, labels, lengths = zip(*data_batch)
        total += len(lengths)
        # sort by length (descending)
        seqs, labels, lengths = get_batch_Tensor(seqs, labels, lengths)
        lengths, indices_len = torch.sort(lengths, descending=True)
        seqs = seqs[indices_len]
        labels = labels[indices_len]

        scores = model(seqs, lengths)
        mask = model.sequence_mask(lengths, lengths[0])

        # recover the original order
        _, indices_recover = torch.sort(indices_len)
        # from ipdb import set_trace
        # set_trace()
        scores = scores[indices_recover]
        lengths = lengths[indices_recover]
        mask = mask[indices_recover]
        labels = labels[indices_recover]

        paths_batch = model.get_path_topk(scores, mask, topk=args.topk)

        # verify the prediction
        for label, path_topk, length in zip(labels, paths_batch, lengths):
            # subjects_list = predict_subject_name(path_topk)
            # target_subject = predict_subject_name(label)
            for path in path_topk:
                if (path.data == label[:length].data).sum(0) == length:
                    single_correct += 1
        
        for i in range(len(lengths)):
            while qa_data_idx < len(qa_data) and not qa_data[qa_data_idx].text_subject:
                qa_data_idx += 1
            if qa_data_idx >= len(qa_data):
                break
            _qa_data = qa_data[qa_data_idx]
            tokens = _qa_data.question.split()
            # candidate subjects
            predict_sub = predict_subject_ids(paths_batch[i], tokens)
            assert _qa_data.num_text_token == lengths[i]
            if _qa_data.subject in predict_sub:
                n_correct_sub += 1
                #from ipdb import set_trace
                #set_trace()
                '''
                flag=False
                a,b=paths_batch[i].shape
                
                for paths in paths_batch[i]:
                    if (labels[i][:b]==paths).sum()==b:
                        flag=True
                        break
                
                if not flag:
                    print(labels[i][:lengths[i]])
                    print(paths_batch[i])
                    print(_qa_data.subject)
                    print(predict_sub)
                '''
            n_cand_entity += len(predict_sub)
            if not predict_sub:
                n_empty += 1
            qa_data_idx += 1
            if save_qadata:
                for sub in predict_sub:
                    rel = virtuoso.id_query_out_rel(sub)
                    _qa_data.add_candidate(sub, rel)
                if hasattr(_qa_data, 'cand_rel'):
                    _qa_data.remove_duplicate()
                new_qa_data.append((_qa_data, len(_qa_data.question_pattern.split())))


    print("Average size of candidate entities:%0.6f"%(n_cand_entity/total))
    print("%s\n----------------------------------\n"%(tp))
    name_accuracy=1.0*single_correct/total
    print("name accuracy\taccuracy:%0.6f\tcorrect:%d\ttotal:%d\n"%(name_accuracy,single_correct,total))
    
    id_accuracy=1.0* n_correct_sub/total
    print("id accuracy\taccuracy:%0.6f\tcorrect:%d\ttotal:%d\n"%(id_accuracy,n_correct_sub,total))

    print("subject not found:%0.6f\t%d"%(1.0*n_empty/total,n_empty))
    print("-"*80)
    
    if save_qadata:
        qadata_save_file = open(os.path.join(args.results_path, 'QAData.label.%s.pkl' % tp), 'wb')
        data_list = [data[0] for data in sorted(new_qa_data, key=lambda data: data[1], reverse=True)]
        pickle.dump(data_list, qadata_save_file)
        qadata_save_file.close()
Example #7
def query_candidate(data_list, pred_list, pid=0):
    log_file = open('logs/log.%d.txt' % (pid), 'w')
    new_data_list = []

    succ_match = 0
    data_index = 0
    for pred, data in zip(pred_list, data_list):
        # increment data_index
        data_index += 1

        # extract scores
        scores = [float(score) for score in pred.strip().split()]

        # extract fields needed
        relation = data.relation
        subject = data.subject
        question = data.question
        tokens = question.split()

        # query name / alias by subject (id)
        candi_sub_list = []
        for threshold in np.arange(0.5, 0.0, -0.095):
            beg_idx, end_idx = beg_end_indices(scores, threshold)
            sub_text = ' '.join(tokens[beg_idx:end_idx + 1])
            candi_sub_list.extend(virtuoso.str_query_id(sub_text))
            if len(candi_sub_list) > 0:
                break

        # # using freebase suggest
        # if len(candi_sub_list) == 0:
        #     beg_idx, end_idx = beg_end_indices(scores, 0.2)
        #     sub_text = ' '.join(tokens[beg_idx:end_idx+1])
        #     sub_text = re.sub(r'\s(\w+)\s(n?\'[tsd])\s', r' \1\2 ', sub_text)
        #     suggest_subs = []
        #     for trial in range(3):
        #         try:
        #             suggest_subs = freebase.suggest_id(sub_text)
        #             break
        #         except:
        #             print >> sys.stderr, 'freebase suggest_id error: trial = %d, sub_text = %s' % (trial, sub_text)
        #     candi_sub_list.extend(suggest_subs)
        #     if data.subject not in candi_sub_list:
        #         print >> log_file, '%s\t\t%s\t\t%s\t\t%d' % (sub_text, data.text_subject, fb2www(data.subject), len(candi_sub_list))

        # if a potential subject is found
        if len(candi_sub_list) > 0:
            # add candidates to data
            for candi_sub in candi_sub_list:
                candi_rel_list = virtuoso.id_query_out_rel(candi_sub)
                if len(candi_rel_list) > 0:
                    if type_dict:
                        candi_type_list = [
                            type_dict[t]
                            for t in virtuoso.id_query_type(candi_sub)
                            if t in type_dict
                        ]
                        if len(candi_type_list) == 0:
                            candi_type_list.append(len(type_dict))
                        data.add_candidate(candi_sub, candi_rel_list,
                                           candi_type_list)
                    else:
                        data.add_candidate(candi_sub, candi_rel_list)
            data.anonymous_question = form_anonymous_quesion(
                question, beg_idx, end_idx)

            # make score mat
        if hasattr(data, 'cand_sub') and hasattr(data, 'cand_rel'):
            # remove duplicate relations
            data.remove_duplicate()

            # append to new_data_list
            new_data_list.append(data)

        # logging information
        if subject in candi_sub_list:
            succ_match += 1

        if data_index % 100 == 0:
            print('[%d] %d / %d' % (pid, data_index, len(data_list)),
                  file=sys.stderr)

    print('%d / %d = %f ' % (succ_match, data_index + 1,
                             succ_match / float(data_index + 1)), file=log_file)

    log_file.close()
    pickle.dump(new_data_list, open('temp.%d.cpickle' % (pid), 'wb'))
Example #8
def predict(dataset=args.test_file, tp='test', save_qadata=args.save_qadata):
    # load QAdata
    qa_data_path = './data/QAData.%s.pkl' % tp
    qa_data = pickle.load(open(qa_data_path, 'rb'))

    # load batch data for predict
    data_loader = SeqLabelingLoader(dataset, args.gpu)
    print('load %s data, batch_num: %d\tbatch_size: %d' %
          (tp, data_loader.batch_num, data_loader.batch_size))

    model.eval()

    n_correct = 0
    n_correct_sub = 0
    n_correct_extend = 0
    n_empty = 0
    linenum = 1
    qa_data_idx = 0

    new_qa_data = []

    gold_list = []
    pred_list = []

    for data_batch_idx, data_batch in enumerate(
            data_loader.next_batch(shuffle=False)):
        if data_batch_idx % 50 == 0:
            print(tp, data_batch_idx)
        scores = model(data_batch)
        n_correct += ((torch.max(scores, 1)[1].view(
            data_batch[1].size()).data == data_batch[1].data).sum(
                dim=0) == data_batch[1].size()[0]).sum()

        index_tag = np.transpose(
            torch.max(scores,
                      1)[1].view(data_batch[1].size()).cpu().data.numpy())
        gold_tag = np.transpose(data_batch[1].cpu().data.numpy())
        index_question = np.transpose(data_batch[0].cpu().data.numpy())

        gold_list.append(np.transpose(data_batch[1].cpu().data.numpy()))
        pred_list.append(index_tag)

        for i in range(data_loader.batch_size):
            while qa_data_idx < len(
                    qa_data) and not qa_data[qa_data_idx].text_subject:
                qa_data_idx += 1
            if qa_data_idx >= len(qa_data):
                break
            _qa_data = qa_data[qa_data_idx]
            tokens = np.array(_qa_data.question.split())
            pred_text = ' '.join(tokens[np.where(index_tag[i][:len(tokens)])])
            _qa_data.pred_text = pred_text

            pred_sub, pred_sub_extend = get_candidate_sub(tokens, index_tag[i])
            if _qa_data.subject in pred_sub:
                n_correct_sub += 1
            if _qa_data.subject in pred_sub_extend:
                n_correct_extend += 1
            if not pred_sub_extend:
                n_empty += 1

            if save_qadata:
                for sub in pred_sub_extend:
                    rel = virtuoso.id_query_out_rel(sub)
                    _qa_data.add_candidate(sub, rel)
                if hasattr(_qa_data, 'cand_rel'):
                    _qa_data.remove_duplicate()

                # if _qa_data.subject not in pred_sub_extend:
                #     _qa_data.neg_rel = virtuoso.id_query_out_rel(_qa_data.subject)

                new_qa_data.append(
                    (_qa_data, len(_qa_data.question_pattern.split())))

            linenum += 1
            qa_data_idx += 1

    total = linenum - 1
    accuracy = 100. * n_correct / total
    print("%s\taccuracy: %8.6f\tcorrect: %d\ttotal: %d" %
          (tp, accuracy, n_correct, total))
    P, R, F = evaluation(gold_list, pred_list)
    print("Precision: {:10.6f}% Recall: {:10.6f}% F1 Score: {:10.6f}%".format(
        100. * P, 100. * R, 100. * F))

    sub_accuracy = 100. * n_correct_sub / total
    print('subject accuracy: %8.6f\tcorrect: %d\ttotal:%d' %
          (sub_accuracy, n_correct_sub, total))

    extend_accuracy = 100. * n_correct_extend / total
    print('extend accuracy: %8.6f\tcorrect: %d\ttotal:%d' %
          (extend_accuracy, n_correct_extend, total))

    print('subject not found: %8.6f\t%d' % (n_empty / total, n_empty))
    print("-" * 80)

    if save_qadata:
        qadata_save_file = open(
            os.path.join(args.results_path, 'QAData.label.%s.pkl' % (tp)),
            'wb')
        data_list = [
            data[0] for data in sorted(
                new_qa_data, key=lambda data: data[1], reverse=True)
        ]
        pickle.dump(data_list, qadata_save_file)
        qadata_save_file.close()
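`get_candidate_sub`, used in this example and in Example #10 below, is also project code that is not reproduced here. Judging only from the call sites (question tokens plus a 0/1 tag array in, a strict and an extended list of candidate MIDs out), one plausible shape using the same `virtuoso` helpers could be:

# Hypothetical sketch; the real get_candidate_sub may extend the span differently.
import numpy as np

def get_candidate_sub(tokens, tag):
    # Tokens covered by the predicted subject-mention tags.
    span = np.where(tag[:len(tokens)])[0]
    if len(span) == 0:
        return [], []
    sub_text = ' '.join(tokens[span[0]:span[-1] + 1])
    pred_sub = virtuoso.str_query_id(sub_text)
    # Fall back to a looser, partial-string lookup for the extended candidate list.
    pred_sub_extend = pred_sub if pred_sub else virtuoso.partial_str_query_id(sub_text)
    return pred_sub, pred_sub_extend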
Example #9
def create_seq_ranking_data(qa_data, word_vocab, rel_sep_vocab, rel_vocab,
                            save_path):
    seqs = []
    seq_len = []
    pos_rel1 = []
    pos_rel2 = []
    neg_rel1 = []
    neg_rel2 = []
    batch_index = -1  # the index of sequence batches
    seq_index = 0  # sequence index within each batch
    pad_index = word_vocab.lookup(word_vocab.pad_token)

    data_list = pickle.load(open(qa_data, 'rb'))

    def get_separated_rel_id(relation):
        rel = relation.split('.')
        rel1 = '.'.join(rel[:-1])
        rel2 = rel[-1]
        rel1_id = rel_sep_vocab[0].lookup(rel1)
        rel2_id = rel_sep_vocab[1].lookup(rel2)
        return rel1_id, rel2_id

    for data in data_list:
        tokens = data.question_pattern.split()
        can_rels = []
        if hasattr(data, 'cand_sub') and data.subject in data.cand_sub:
            can_rels = data.cand_rel
        else:
            can_subs = virtuoso.str_query_id(data.text_subject)
            for sub in can_subs:
                can_rels.extend(virtuoso.id_query_out_rel(sub))
            can_rels = list(set(can_rels))
        if data.relation in can_rels:
            can_rels.remove(data.relation)
        for i in range(len(can_rels), args.neg_size):
            tmp = random.randint(2, len(rel_vocab) - 1)
            while rel_vocab.index2word[tmp] in can_rels:
                tmp = random.randint(2, len(rel_vocab) - 1)
            can_rels.append(rel_vocab.index2word[tmp])
        can_rels = random.sample(can_rels, args.neg_size)

        if seq_index % args.batch_size == 0:
            seq_index = 0
            batch_index += 1
            seqs.append(
                torch.LongTensor(args.batch_size,
                                 len(tokens)).fill_(pad_index))
            seq_len.append(torch.LongTensor(args.batch_size).fill_(1))
            pos_rel1.append(torch.LongTensor(args.batch_size).fill_(pad_index))
            pos_rel2.append(torch.LongTensor(args.batch_size).fill_(pad_index))
            neg_rel1.append(torch.LongTensor(args.neg_size, args.batch_size))
            neg_rel2.append(torch.LongTensor(args.neg_size, args.batch_size))
            print('batch: %d' % batch_index)

        seqs[batch_index][seq_index, 0:len(tokens)] = torch.LongTensor(
            word_vocab.convert_to_index(tokens))
        seq_len[batch_index][seq_index] = len(tokens)

        pos1, pos2 = get_separated_rel_id(data.relation)
        pos_rel1[batch_index][seq_index] = pos1
        pos_rel2[batch_index][seq_index] = pos2

        for j, neg_rel in enumerate(can_rels):
            neg1, neg2 = get_separated_rel_id(neg_rel)
            if not neg1 or not neg2:
                continue
            neg_rel1[batch_index][j, seq_index] = neg1
            neg_rel2[batch_index][j, seq_index] = neg2

        seq_index += 1

    torch.save((seqs, seq_len, pos_rel1, pos_rel2, neg_rel1, neg_rel2),
               save_path)
Example #10
def predict(dataset=args.test_file,
            tp='test',
            write=args.write,
            save_qadata=args.save_qadata):
    # load QAdata
    qa_data_path = '../data/QAData.%s.pkl' % tp
    qa_data = pickle.load(open(qa_data_path, 'rb'))

    # load batch data for predict
    data_loader = SeqLabelingLoader(dataset, args.gpu)
    print('load %s data, batch_num: %d\tbatch_size: %d' %
          (tp, data_loader.batch_num, data_loader.batch_size))

    model.eval()

    n_correct = 0
    n_correct_sub = 0
    n_correct_extend = 0
    n_empty = 0
    linenum = 1
    qa_data_idx = 0

    if write:
        results_file = open(
            os.path.join(args.results_path, '%s-results.txt' % tp), 'w')
        results_file_sub = open(
            os.path.join(args.results_path, '%s-results-subject.txt' % tp),
            'w')

    new_qa_data = []

    gold_list = []
    pred_list = []

    for data_batch_idx, data_batch in enumerate(
            data_loader.next_batch(shuffle=False)):
        if data_batch_idx % 50 == 0:
            print(tp, data_batch_idx)
        scores = model(data_batch)
        # count how many predictions match the gold label sequence exactly
        n_correct += ((torch.max(scores, 1)[1].view(
            data_batch[1].size()).data == data_batch[1].data).sum(
                dim=0) == data_batch[1].size()[0]).sum()

        # predicted vs. gold labels; both are converted to the corresponding text later. Note both must be transposed.
        index_tag = np.transpose(
            torch.max(scores,
                      1)[1].view(data_batch[1].size()).cpu().data.numpy())
        gold_tag = np.transpose(data_batch[1].cpu().data.numpy())
        index_question = np.transpose(data_batch[0].cpu().data.numpy())

        gold_list.append(np.transpose(data_batch[1].cpu().data.numpy()))
        pred_list.append(index_tag)

        for i in range(data_loader.batch_size):
            # map back to the text in QAData, look up the MID in Freebase, and compute subject accuracy
            while qa_data_idx < len(
                    qa_data) and not qa_data[qa_data_idx].text_subject:
                qa_data_idx += 1  # the loader drops entries without text_subject, while qa_data keeps them all
            if qa_data_idx >= len(qa_data):  # the tail of the last batch is all <pad>; qa_data is exhausted by then
                break
            _qa_data = qa_data[qa_data_idx]
            tokens = np.array(_qa_data.question.split())
            # index_tag may be longer than the actual question because <pad> tokens are appended
            pred_text = ' '.join(tokens[np.where(index_tag[i][:len(tokens)])])

            # accuracy of the extended candidate-subject generation
            pred_sub, pred_sub_extend = get_candidate_sub(tokens, index_tag[i])
            if _qa_data.subject in pred_sub:
                n_correct_sub += 1
            if _qa_data.subject in pred_sub_extend:
                n_correct_extend += 1
            if not pred_sub_extend:
                n_empty += 1

            if write:
                if pred_sub == pred_sub_extend:
                    pred_sub = 'RRR'
                results_file_sub.write('%s-%d\t%s\t%s\t%s\t%s\t%s\t%s\n' %(tp, linenum, _qa_data.question, \
                                                            pred_sub, pred_sub_extend, _qa_data.subject, \
                                                            pred_text, _qa_data.text_subject))

                question_array = np.array(
                    word_vocab.convert_to_word(index_question[i]))
                pred_array = question_array[np.where(index_tag[i])]
                gold_array = question_array[np.where(gold_tag[i])]
                line_to_print = '%s-%d\t%s\t%s\t%s' %(tp, linenum, " ".join(question_array), \
                                                       " ".join(pred_array), " ".join(gold_array))
                results_file.write(line_to_print + "\n")

            if save_qadata:
                for sub in pred_sub_extend:
                    rel = virtuoso.id_query_out_rel(sub)
                    _qa_data.add_candidate(sub, rel)
                if hasattr(_qa_data, 'cand_rel'):
                    _qa_data.remove_duplicate()


                # if _qa_data.subject not in pred_sub_extend:
                #     _qa_data.neg_rel = virtuoso.id_query_out_rel(_qa_data.subject)

                new_qa_data.append(
                    (_qa_data, len(_qa_data.question_pattern.split())))

            linenum += 1
            qa_data_idx += 1

    total = linenum - 1
    accuracy = 100. * n_correct / total
    print("%s\taccuracy: %8.6f\tcorrect: %d\ttotal: %d" %
          (tp, accuracy, n_correct, total))
    P, R, F = evaluation(gold_list, pred_list)
    print("Precision: {:10.6f}% Recall: {:10.6f}% F1 Score: {:10.6f}%".format(
        100. * P, 100. * R, 100. * F))

    sub_accuracy = 100. * n_correct_sub / total
    print('subject accuracy: %8.6f\tcorrect: %d\ttotal:%d' %
          (sub_accuracy, n_correct_sub, total))

    extend_accuracy = 100. * n_correct_extend / total
    print('extend accuracy: %8.6f\tcorrect: %d\ttotal:%d' %
          (extend_accuracy, n_correct_extend, total))

    print('subject not found: %8.6f\t%d' % (n_empty / total, n_empty))
    print("-" * 80)

    if write:
        results_file.close()
        results_file_sub.close()
    if save_qadata:
        qadata_save_file = open(
            os.path.join(args.results_path, 'QAData.label.%s.pkl' % (tp)),
            'wb')
        data_list = [
            data[0] for data in sorted(
                new_qa_data, key=lambda data: data[1], reverse=True)
        ]
        pickle.dump(data_list, qadata_save_file)
        qadata_save_file.close()