import os
import sys
import math
import pickle
import random

import numpy as np
import torch

import virtuoso  # project-local wrapper around SPARQL queries to a Virtuoso endpoint

# NOTE: the functions below were collected from several scripts of the same
# repo, which is why some names (query_candidate, predict, ...) recur.


def query_candidate(data_list, pid=0):
    log_file = open('logs/log.round13.%d.txt' % (pid), 'w')
    new_data_list = []
    data_index = 0
    NoneMatch = 0
    maxRelLen = 0
    for data in data_list:
        # increment data_index
        data_index += 1
        # extract fields needed
        relation = data.relation
        subject = data.subject
        question = data.question
        ANquestion = data.anonymous_question
        if len(question.split()) > 1 and ANquestion:
            # query outgoing relations by subject (id)
            candi_rel_list = []
            candi_rel_list.extend(virtuoso.id_query_out_rel(subject))
            candi_rel_list = list(set(candi_rel_list))  # [string, string, ...]
            if '' in candi_rel_list:
                candi_rel_list.remove('')
            data.add_candidate(subject, candi_rel_list)
            if relation in candi_rel_list:
                new_data_list.append(data)
                if len(candi_rel_list) > maxRelLen:
                    maxRelLen = len(candi_rel_list)
            else:
                NoneMatch += 1
                # log_file.write('%s\n' % question)
    print('not matched number is %d' % (NoneMatch))
    print('maximum candidate relation number is %d' % (maxRelLen))
    log_file.close()
    pickle.dump(new_data_list, open('temp.%d.cpickle' % (pid), 'wb'))
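
# The `virtuoso` module itself is not part of this file. The stub below is a
# minimal sketch of the interface the functions here rely on, with signatures
# inferred from the call sites (an assumption, not the repo's implementation;
# the real module presumably runs SPARQL against a Freebase-loaded endpoint).
class _VirtuosoStub:
    def str_query_id(self, name_text):
        """Return Freebase MIDs whose name/alias exactly matches name_text."""
        return []

    def partial_str_query_id(self, name_text):
        """Return MIDs whose name/alias partially matches name_text."""
        return []

    def id_query_out_rel(self, mid):
        """Return the outgoing relation names of entity `mid`."""
        return []

    def id_query_type(self, mid):
        """Return the Freebase type strings of entity `mid`."""
        return []

    def id_query_count(self, mid):
        """Return a popularity count (e.g., number of facts) for `mid`."""
        return 0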
def search_subject_id(concept_list, subject, relations):
    subject_id_list = []
    for concept in concept_list:
        # name = concept["concept"].lower()
        # name = name.replace(",", "").replace(".", "")
        # if name == subject or name == text_subject:
        #     return concept["mid"]
        mid = "fb:" + concept["mid"]
        if relations[0] in virtuoso.id_query_out_rel(mid):
            subject_id_list.append(mid)
    return subject_id_list
def create_seq_ranking_data(batch_size, qa_data, word_vocab, rel_vocab):
    file_type = qa_data.split('.')[-2]
    # log_file = open('data/%s.relation_ranking.txt' % file_type, 'w')
    seqs = []
    pos_rel = []
    neg_rel = []
    neg_rel_size = []
    batch_index = -1  # the index of sequence batches
    seq_index = 0     # sequence index within each batch
    pad_index = word_vocab.lookup(word_vocab.pad_token)
    data_list = pickle.load(open(qa_data, 'rb'))
    for data in data_list:
        tokens = data.question.split()
        # use the relations of all subjects sharing the same name as negative samples
        can_subs = virtuoso.str_query_id(data.text_subject)
        can_rels = []
        for sub in can_subs:
            can_rels.extend(virtuoso.id_query_out_rel(sub))
        can_rels = list(set(can_rels))  # remove duplicate relations
        # log_file.write('%s\t%s\t%s\n' % (data.question, data.relation, can_rels))
        if seq_index % batch_size == 0:
            seq_index = 0
            batch_index += 1
            seqs.append(torch.LongTensor(len(tokens), batch_size).fill_(pad_index))
            pos_rel.append(torch.LongTensor(batch_size).fill_(pad_index))
            neg_rel.append([])
            neg_rel_size.append([])
            print('batch: %d' % batch_index)
        seqs[batch_index][0:len(tokens), seq_index] = torch.LongTensor(
            word_vocab.convert_to_index(tokens))
        pos_rel[batch_index][seq_index] = rel_vocab.lookup(data.relation)
        neg_rel[batch_index].append(rel_vocab.convert_to_index(can_rels))
        neg_rel_size[batch_index].append(len(can_rels))
        seq_index += 1
    torch.save((seqs, pos_rel, neg_rel, neg_rel_size),
               'data/%s.relation_ranking.pt' % file_type)
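
# A minimal sketch of how a consumer might iterate the ranking batches saved
# above. The file name follows the '%s.relation_ranking.pt' pattern with
# file_type='train'; treat the path and the shape comments as assumptions.
def iterate_ranking_batches_example():
    seqs, pos_rel, neg_rel, neg_rel_size = torch.load('data/train.relation_ranking.pt')
    for b in range(len(seqs)):
        batch_seqs = seqs[b]               # (max_len, batch_size) word ids, pad-filled
        batch_pos = pos_rel[b]             # (batch_size,) gold relation ids
        batch_negs = neg_rel[b]            # per-example lists of negative relation ids
        batch_neg_sizes = neg_rel_size[b]  # per-example negative pool sizes
        yield batch_seqs, batch_pos, batch_negs, batch_neg_sizes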
def query_golden_subs(data):
    golden_subs = []
    if data.text_subject:
        # extract fields needed
        relation = data.relation
        subject = data.subject
        text_subject = data.text_subject
        # query candidate ids by subject name / alias
        candi_sub_list = virtuoso.str_query_id(text_subject)
        # keep candidates whose outgoing relations contain the gold relation
        for candi_sub in candi_sub_list:
            candi_rel_list = virtuoso.id_query_out_rel(candi_sub)
            if relation in candi_rel_list:
                golden_subs.append(candi_sub)
    if len(golden_subs) == 0:
        golden_subs = [data.subject]
    return golden_subs
def query_candidate(data_list, pred_list, logName):
    log_file = open(logName, 'w')
    new_data_list = []
    NoneMatch = 0
    succ_match = 0
    data_index = 0
    for pred, data in zip(pred_list, data_list):
        # extract scores
        scores = np.array([int(float(score)) for score in pred.decode().strip().split()])
        # extract fields needed
        relation = data.relation
        subject = data.subject
        question = data.question
        text_attention_indices = data.text_attention_indices
        if not text_attention_indices:
            continue
        # increment data_index
        data_index += 1
        tokens = np.array(question.split())
        # query subject ids by the predicted mention text
        candi_sub_list = []
        # for threshold in np.arange(0.5, 0.0, -0.095):
        #     beg_idx, end_idx = beg_end_indices(scores, threshold)
        #     sub_text = ' '.join(tokens[beg_idx:end_idx + 1])
        #     candi_sub_list.extend(virtuoso.str_query_id(sub_text))
        #     if len(candi_sub_list) > 0:
        #         break
        beg_idx, end_idx = beg_end_indices(scores, 0.2)
        tokens_crop = tokens[beg_idx:end_idx + 1]
        sub_text = ' '.join(tokens_crop)
        text_list = []
        # (disabled) widen the crop window by a few tokens on either side:
        # for i in [1, -1, 2, -2]:
        #     if beg_idx - i >= 0 and beg_idx - i <= end_idx:
        #         text_list.append(' '.join(tokens[beg_idx - i:end_idx + 1]))
        #     if end_idx + i < seq_len and end_idx + i >= beg_idx:
        #         text_list.append(' '.join(tokens[beg_idx:end_idx + i + 1]))
        #     if beg_idx - i >= 0 and beg_idx - i <= end_idx + i:
        #         text_list.append(' '.join(tokens[beg_idx - i:end_idx + i + 1]))
        candi_sub_list.extend(virtuoso.str_query_id(sub_text))
        if '' in candi_sub_list:
            candi_sub_list.remove('')
        # if candi_sub_list == []:
        #     for text in text_list:
        #         candi_sub_list.extend(virtuoso.str_query_id(text))
        #         if '' in candi_sub_list:
        #             candi_sub_list.remove('')
        #         if candi_sub_list != []:
        #             break
        if candi_sub_list:
            data.set_strict_flag(True)
        else:
            data.set_strict_flag(False)
            # fall back to partial string match
            candi_sub_list.extend(virtuoso.partial_str_query_id(sub_text))
            if '' in candi_sub_list:
                candi_sub_list.remove('')
            if not candi_sub_list:
                # back off to ever shorter n-grams of the cropped mention
                for i in range(len(tokens_crop) - 1, 1, -1):
                    tempList = generate_ngrams(tokens_crop, i)
                    for x, y in enumerate(tempList):
                        idList = virtuoso.str_query_id(y)
                        if '' in idList:
                            idList.remove('')
                        candi_sub_list.extend(idList)
                    if candi_sub_list:
                        break
        candi_sub_list = list(set(candi_sub_list))  # [string, string, ...]
        # (disabled) fall back to the Freebase suggest API:
        # if len(candi_sub_list) == 0:
        #     sub_text = re.sub(r'\s(\w+)\s(n?\'[tsd])\s', r' \1\2 ', sub_text)
        #     suggest_subs = []
        #     for trial in range(3):
        #         try:
        #             suggest_subs = freebase.suggest_id(sub_text)
        #             log_file.write(str(suggest_subs) + '\n')
        #             break
        #         except:
        #             print('freebase suggest_id error: trial = %d, sub_text = %s'
        #                   % (trial, sub_text), file=sys.stderr)
        #     candi_sub_list.extend(suggest_subs)
        # if potential subjects found
        if len(candi_sub_list) > 0:
            # add candidates to data
            countarry = np.zeros(len(candi_sub_list))  # (unused) node-score experiment
            for idx, candi_sub in enumerate(candi_sub_list):
                candi_rel_list = virtuoso.id_query_out_rel(candi_sub)
                candi_rel_list = list(set(candi_rel_list))
                if '' in candi_rel_list:
                    candi_rel_list.remove('')
                if len(candi_rel_list) > 0:
                    if type_dict:
                        candi_type_list = [type_dict[t]
                                           for t in virtuoso.id_query_type(candi_sub)
                                           if t in type_dict]
                        if len(candi_type_list) == 0:
                            candi_type_list.append(len(type_dict))
                        data.add_candidate(candi_sub, candi_rel_list, candi_type_list)
                    else:
                        data.add_candidate(candi_sub, candi_rel_list)
                # countarry[idx] = virtuoso.id_query_count(candi_sub)
            # data.add_node_score(countarry)
            # make score mat
            if hasattr(data, 'cand_sub') and hasattr(data, 'cand_rel'):
                # keep only data that recalled candidates; remove duplicate relations
                data.remove_duplicate()
            else:
                NoneMatch += 1
            data.anonymous_question = form_anonymous_quesion(question, beg_idx, end_idx)
            # append to new_data_list
            new_data_list.append(data)
        # elif save_all:
        #     new_data_list.append(data)
        # logging information
        if subject in candi_sub_list:
            succ_match += 1
        if data_index % 100 == 0:
            print('{0} / {1}: {2} / {3} = {4}'.format(
                data_index, len(data_list), succ_match, data_index,
                succ_match / float(data_index)))
            log_file.write('{0} {1} {2}\n'.format(succ_match, data_index, NoneMatch))
    log_file.write('{0} / {1} = {2}\n'.format(succ_match, data_index,
                                              succ_match / float(data_index)))
    log_file.write('not matched number is {0}\n'.format(NoneMatch))
    log_file.close()
    return new_data_list
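
# `beg_end_indices` and `generate_ngrams` are used above but defined elsewhere
# in the repo. These are minimal sketches of plausible implementations,
# inferred from the call sites (assumptions, not the original code):
def beg_end_indices(scores, threshold):
    # span of tokens whose mention score clears the threshold;
    # fall back to the single highest-scoring token if none do
    idx = [i for i, s in enumerate(scores) if s > threshold]
    if not idx:
        best = max(range(len(scores)), key=lambda i: scores[i])
        return best, best
    return idx[0], idx[-1]


def generate_ngrams(tokens, n):
    # all contiguous n-grams of `tokens`, joined back into strings
    return [' '.join(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]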
def predict(dataset=args.test_file, tp='test', save_qadata=args.save_qadata):
    # load QAData
    qa_data_path = '../data/QAData.%s.pkl' % tp
    qa_data = pickle.load(open(qa_data_path, 'rb'))
    # load batch data for predict
    data_loader = SeqLabelingLoader(dataset, args.batch_size)
    print('load %s data, batch_num: %d\tbatch_size: %d'
          % (tp, data_loader.batch_num, data_loader.batch_size))
    model.eval()
    n_correct = 0
    n_correct_sub = 0
    n_correct_extend = 0
    n_empty = 0
    n_cand_entity = 0
    linenum = 1
    qa_data_idx = 0
    new_qa_data = []
    gold_list = []
    pred_list = []
    compare_pred = []
    single_correct = 0
    total = 0
    EDdata = torch.load(dataset)
    batches = [EDdata[i * args.batch_size:(i + 1) * args.batch_size]
               for i in range(math.ceil(len(EDdata) / args.batch_size))]
    for data_batch_idx, data_batch in enumerate(batches):
        if data_batch_idx % 50 == 0:
            print(tp, data_batch_idx)
        seqs, labels, lengths = zip(*data_batch)
        total += len(lengths)
        # sort by length (descending) for the packed RNN input
        seqs, labels, lengths = get_batch_Tensor(seqs, labels, lengths)
        lengths, indices_len = torch.sort(lengths, descending=True)
        seqs = seqs[indices_len]
        labels = labels[indices_len]
        scores = model(seqs, lengths)
        mask = model.sequence_mask(lengths, lengths[0])
        # recover the original order
        _, indices_recover = torch.sort(indices_len)
        scores = scores[indices_recover]
        lengths = lengths[indices_recover]
        mask = mask[indices_recover]
        labels = labels[indices_recover]
        paths_batch = model.get_path_topk(scores, mask, topk=args.topk)
        # verify the prediction: count examples where any top-k path matches the gold labels
        for label, path_topk, length in zip(labels, paths_batch, lengths):
            for path in path_topk:
                if (path.data == label[:length].data).sum(0) == length:
                    single_correct += 1
        for i in range(len(lengths)):
            while qa_data_idx < len(qa_data) and not qa_data[qa_data_idx].text_subject:
                qa_data_idx += 1
            if qa_data_idx >= len(qa_data):
                break
            _qa_data = qa_data[qa_data_idx]
            tokens = _qa_data.question.split()
            # subjects
            predict_sub = predict_subject_ids(paths_batch[i], tokens)
            assert _qa_data.num_text_token == lengths[i]
            if _qa_data.subject in predict_sub:
                n_correct_sub += 1
            n_cand_entity += len(predict_sub)
            if not predict_sub:
                n_empty += 1
            qa_data_idx += 1
            if save_qadata:
                for sub in predict_sub:
                    rel = virtuoso.id_query_out_rel(sub)
                    _qa_data.add_candidate(sub, rel)
                if hasattr(_qa_data, 'cand_rel'):
                    _qa_data.remove_duplicate()
                new_qa_data.append((_qa_data, len(_qa_data.question_pattern.split())))
    print("Average size of candidate entities: %0.6f" % (n_cand_entity / total))
    print("%s\n----------------------------------\n" % (tp))
    name_accuracy = 1.0 * single_correct / total
    print("name accuracy\taccuracy:%0.6f\tcorrect:%d\ttotal:%d\n"
          % (name_accuracy, single_correct, total))
    id_accuracy = 1.0 * n_correct_sub / total
    print("id accuracy\taccuracy:%0.6f\tcorrect:%d\ttotal:%d\n"
          % (id_accuracy, n_correct_sub, total))
    print("subject not found: %0.6f\t%d" % (1.0 * n_empty / total, n_empty))
    print("-" * 80)
    if save_qadata:
        qadata_save_path = open(os.path.join(args.results_path, 'QAData.label.%s.pkl' % (tp)), 'wb')
        data_list = [data[0] for data in sorted(new_qa_data, key=lambda data: data[1], reverse=True)]
        pickle.dump(data_list, qadata_save_path)
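
# `predict_subject_ids` is defined elsewhere in the repo. A hedged sketch of
# the behavior inferred from its call site above: decode each top-k label path
# into a mention span (assuming tag 1 marks subject tokens) and collect the
# MIDs returned by exact name lookup.
def predict_subject_ids(path_topk, tokens):
    subject_ids = []
    for path in path_topk:
        span = [tokens[j] for j, tag in enumerate(path.tolist())
                if tag == 1 and j < len(tokens)]
        if span:
            subject_ids.extend(virtuoso.str_query_id(' '.join(span)))
    return list(set(subject_ids))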
def query_candidate(data_list, pred_list, pid=0):
    log_file = open('logs/log.%d.txt' % (pid), 'w')
    new_data_list = []
    succ_match = 0
    data_index = 0
    for pred, data in zip(pred_list, data_list):
        # increment data_index
        data_index += 1
        # extract scores
        scores = [float(score) for score in pred.strip().split()]
        # extract fields needed
        relation = data.relation
        subject = data.subject
        question = data.question
        tokens = question.split()
        # query subject ids by name / alias, relaxing the threshold until something matches
        candi_sub_list = []
        for threshold in np.arange(0.5, 0.0, -0.095):
            beg_idx, end_idx = beg_end_indices(scores, threshold)
            sub_text = ' '.join(tokens[beg_idx:end_idx + 1])
            candi_sub_list.extend(virtuoso.str_query_id(sub_text))
            if len(candi_sub_list) > 0:
                break
        # (disabled) fall back to the Freebase suggest API:
        # if len(candi_sub_list) == 0:
        #     beg_idx, end_idx = beg_end_indices(scores, 0.2)
        #     sub_text = ' '.join(tokens[beg_idx:end_idx + 1])
        #     sub_text = re.sub(r'\s(\w+)\s(n?\'[tsd])\s', r' \1\2 ', sub_text)
        #     suggest_subs = []
        #     for trial in range(3):
        #         try:
        #             suggest_subs = freebase.suggest_id(sub_text)
        #             break
        #         except:
        #             print('freebase suggest_id error: trial = %d, sub_text = %s'
        #                   % (trial, sub_text), file=sys.stderr)
        #     candi_sub_list.extend(suggest_subs)
        # if data.subject not in candi_sub_list:
        #     log_file.write('%s\t\t%s\t\t%s\t\t%d\n' % (sub_text, data.text_subject,
        #                    fb2www(data.subject), len(candi_sub_list)))
        # if potential subjects found
        if len(candi_sub_list) > 0:
            # add candidates to data
            for candi_sub in candi_sub_list:
                candi_rel_list = virtuoso.id_query_out_rel(candi_sub)
                if len(candi_rel_list) > 0:
                    if type_dict:
                        candi_type_list = [type_dict[t]
                                           for t in virtuoso.id_query_type(candi_sub)
                                           if t in type_dict]
                        if len(candi_type_list) == 0:
                            candi_type_list.append(len(type_dict))
                        data.add_candidate(candi_sub, candi_rel_list, candi_type_list)
                    else:
                        data.add_candidate(candi_sub, candi_rel_list)
            data.anonymous_question = form_anonymous_quesion(question, beg_idx, end_idx)
            # make score mat
            if hasattr(data, 'cand_sub') and hasattr(data, 'cand_rel'):
                # remove duplicate relations
                data.remove_duplicate()
                # append to new_data_list
                new_data_list.append(data)
        # logging information
        if subject in candi_sub_list:
            succ_match += 1
        if data_index % 100 == 0:
            print('[%d] %d / %d' % (pid, data_index, len(data_list)), file=sys.stderr)
    log_file.write('%d / %d = %f\n' % (succ_match, data_index,
                                       succ_match / float(data_index)))
    log_file.close()
    pickle.dump(new_data_list, open('temp.%d.cpickle' % (pid), 'wb'))
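
# `form_anonymous_quesion` (spelling kept as in the repo) is called above but
# defined elsewhere. A plausible sketch inferred from its arguments: replace
# the predicted subject span with a placeholder token (the placeholder string
# is an assumption).
def form_anonymous_quesion(question, beg_idx, end_idx, placeholder='<e>'):
    tokens = question.split()
    return ' '.join(tokens[:beg_idx] + [placeholder] + tokens[end_idx + 1:])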
def predict(dataset=args.test_file, tp='test', save_qadata=args.save_qadata):
    # load QAData
    qa_data_path = './data/QAData.%s.pkl' % tp
    qa_data = pickle.load(open(qa_data_path, 'rb'))
    # load batch data for predict
    data_loader = SeqLabelingLoader(dataset, args.gpu)
    print('load %s data, batch_num: %d\tbatch_size: %d'
          % (tp, data_loader.batch_num, data_loader.batch_size))
    model.eval()
    n_correct = 0
    n_correct_sub = 0
    n_correct_extend = 0
    n_empty = 0
    linenum = 1
    qa_data_idx = 0
    new_qa_data = []
    gold_list = []
    pred_list = []
    for data_batch_idx, data_batch in enumerate(data_loader.next_batch(shuffle=False)):
        if data_batch_idx % 50 == 0:
            print(tp, data_batch_idx)
        scores = model(data_batch)
        # count sequences whose predicted labels match the gold labels exactly
        n_correct += ((torch.max(scores, 1)[1].view(data_batch[1].size()).data
                       == data_batch[1].data).sum(dim=0)
                      == data_batch[1].size()[0]).sum()
        index_tag = np.transpose(
            torch.max(scores, 1)[1].view(data_batch[1].size()).cpu().data.numpy())
        gold_tag = np.transpose(data_batch[1].cpu().data.numpy())
        index_question = np.transpose(data_batch[0].cpu().data.numpy())
        gold_list.append(gold_tag)
        pred_list.append(index_tag)
        for i in range(data_loader.batch_size):
            while qa_data_idx < len(qa_data) and not qa_data[qa_data_idx].text_subject:
                qa_data_idx += 1
            if qa_data_idx >= len(qa_data):
                break
            _qa_data = qa_data[qa_data_idx]
            tokens = np.array(_qa_data.question.split())
            pred_text = ' '.join(tokens[np.where(index_tag[i][:len(tokens)])])
            _qa_data.pred_text = pred_text
            pred_sub, pred_sub_extend = get_candidate_sub(tokens, index_tag[i])
            if _qa_data.subject in pred_sub:
                n_correct_sub += 1
            if _qa_data.subject in pred_sub_extend:
                n_correct_extend += 1
            if not pred_sub_extend:
                n_empty += 1
            if save_qadata:
                for sub in pred_sub_extend:
                    rel = virtuoso.id_query_out_rel(sub)
                    _qa_data.add_candidate(sub, rel)
                if hasattr(_qa_data, 'cand_rel'):
                    _qa_data.remove_duplicate()
                # if _qa_data.subject not in pred_sub_extend:
                #     _qa_data.neg_rel = virtuoso.id_query_out_rel(_qa_data.subject)
                new_qa_data.append((_qa_data, len(_qa_data.question_pattern.split())))
            linenum += 1
            qa_data_idx += 1
    total = linenum - 1
    accuracy = 100. * n_correct / total
    print("%s\taccuracy: %8.6f\tcorrect: %d\ttotal: %d" % (tp, accuracy, n_correct, total))
    P, R, F = evaluation(gold_list, pred_list)
    print("Precision: {:10.6f}% Recall: {:10.6f}% F1 Score: {:10.6f}%".format(
        100. * P, 100. * R, 100. * F))
    sub_accuracy = 100. * n_correct_sub / total
    print('subject accuracy: %8.6f\tcorrect: %d\ttotal: %d' % (sub_accuracy, n_correct_sub, total))
    extend_accuracy = 100. * n_correct_extend / total
    print('extend accuracy: %8.6f\tcorrect: %d\ttotal: %d' % (extend_accuracy, n_correct_extend, total))
    print('subject not found: %8.6f\t%d' % (n_empty / total, n_empty))
    print("-" * 80)
    if save_qadata:
        qadata_save_path = open(os.path.join(args.results_path, 'QAData.label.%s.pkl' % (tp)), 'wb')
        data_list = [data[0] for data in sorted(new_qa_data, key=lambda data: data[1], reverse=True)]
        pickle.dump(data_list, qadata_save_path)
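
# `get_candidate_sub` is defined elsewhere in the repo. A minimal sketch of
# the inferred contract: exact-match MIDs for the tagged span, plus an
# extended pool from partial matching when the exact lookup comes up empty
# (the extension strategy is an assumption).
def get_candidate_sub(tokens, tag_row):
    span = ' '.join(tokens[np.where(tag_row[:len(tokens)])])
    pred_sub = virtuoso.str_query_id(span)
    pred_sub_extend = pred_sub or virtuoso.partial_str_query_id(span)
    return pred_sub, pred_sub_extend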
def create_seq_ranking_data(qa_data, word_vocab, rel_sep_vocab, rel_vocab, save_path):
    seqs = []
    seq_len = []
    pos_rel1 = []
    pos_rel2 = []
    neg_rel1 = []
    neg_rel2 = []
    batch_index = -1  # the index of sequence batches
    seq_index = 0     # sequence index within each batch
    pad_index = word_vocab.lookup(word_vocab.pad_token)
    data_list = pickle.load(open(qa_data, 'rb'))

    def get_separated_rel_id(relation):
        # split "a.b.c" into the prefix "a.b" and the final segment "c"
        rel = relation.split('.')
        rel1 = '.'.join(rel[:-1])
        rel2 = rel[-1]
        rel1_id = rel_sep_vocab[0].lookup(rel1)
        rel2_id = rel_sep_vocab[1].lookup(rel2)
        return rel1_id, rel2_id

    for data in data_list:
        tokens = data.question_pattern.split()
        can_rels = []
        if hasattr(data, 'cand_sub') and data.subject in data.cand_sub:
            can_rels = data.cand_rel
        else:
            can_subs = virtuoso.str_query_id(data.text_subject)
            for sub in can_subs:
                can_rels.extend(virtuoso.id_query_out_rel(sub))
        can_rels = list(set(can_rels))
        if data.relation in can_rels:
            can_rels.remove(data.relation)
        # pad the negative pool with random relations up to neg_size
        for i in range(len(can_rels), args.neg_size):
            tmp = random.randint(2, len(rel_vocab) - 1)
            while rel_vocab.index2word[tmp] in can_rels:
                tmp = random.randint(2, len(rel_vocab) - 1)
            can_rels.append(rel_vocab.index2word[tmp])
        can_rels = random.sample(can_rels, args.neg_size)
        if seq_index % args.batch_size == 0:
            seq_index = 0
            batch_index += 1
            seqs.append(torch.LongTensor(args.batch_size, len(tokens)).fill_(pad_index))
            seq_len.append(torch.LongTensor(args.batch_size).fill_(1))
            pos_rel1.append(torch.LongTensor(args.batch_size).fill_(pad_index))
            pos_rel2.append(torch.LongTensor(args.batch_size).fill_(pad_index))
            neg_rel1.append(torch.LongTensor(args.neg_size, args.batch_size))
            neg_rel2.append(torch.LongTensor(args.neg_size, args.batch_size))
            print('batch: %d' % batch_index)
        seqs[batch_index][seq_index, 0:len(tokens)] = torch.LongTensor(
            word_vocab.convert_to_index(tokens))
        seq_len[batch_index][seq_index] = len(tokens)
        pos1, pos2 = get_separated_rel_id(data.relation)
        pos_rel1[batch_index][seq_index] = pos1
        pos_rel2[batch_index][seq_index] = pos2
        for j, neg_rel in enumerate(can_rels):
            neg1, neg2 = get_separated_rel_id(neg_rel)
            if not neg1 or not neg2:
                continue
            neg_rel1[batch_index][j, seq_index] = neg1
            neg_rel2[batch_index][j, seq_index] = neg2
        seq_index += 1
    torch.save((seqs, seq_len, pos_rel1, pos_rel2, neg_rel1, neg_rel2), save_path)
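
# A minimal sketch of how the separated-relation tensors saved above could
# drive a margin ranking loss; `rel_scorer` (a model scoring a question
# against a relation's two segments) and the margin value are assumptions,
# not code from this repo.
def ranking_loss_example(rel_scorer, save_path, margin=0.5):
    loss_fn = torch.nn.MarginRankingLoss(margin=margin)
    seqs, seq_len, pos_rel1, pos_rel2, neg_rel1, neg_rel2 = torch.load(save_path)
    losses = []
    for b in range(len(seqs)):
        pos_score = rel_scorer(seqs[b], seq_len[b], pos_rel1[b], pos_rel2[b])  # (batch,)
        neg_score = rel_scorer(seqs[b], seq_len[b], neg_rel1[b], neg_rel2[b])  # (neg_size, batch)
        target = torch.ones_like(neg_score)
        # push each positive score above every negative score by the margin
        losses.append(loss_fn(pos_score.expand_as(neg_score), neg_score, target))
    return torch.stack(losses).mean()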
def predict(dataset=args.test_file, tp='test', write=args.write, save_qadata=args.save_qadata):
    # load QAData
    qa_data_path = '../data/QAData.%s.pkl' % tp
    qa_data = pickle.load(open(qa_data_path, 'rb'))
    # load batch data for predict
    data_loader = SeqLabelingLoader(dataset, args.gpu)
    print('load %s data, batch_num: %d\tbatch_size: %d'
          % (tp, data_loader.batch_num, data_loader.batch_size))
    model.eval()
    n_correct = 0
    n_correct_sub = 0
    n_correct_extend = 0
    n_empty = 0
    linenum = 1
    qa_data_idx = 0
    if write:
        results_file = open(os.path.join(args.results_path, '%s-results.txt' % tp), 'w')
        results_file_sub = open(os.path.join(args.results_path, '%s-results-subject.txt' % tp), 'w')
    new_qa_data = []
    gold_list = []
    pred_list = []
    for data_batch_idx, data_batch in enumerate(data_loader.next_batch(shuffle=False)):
        if data_batch_idx % 50 == 0:
            print(tp, data_batch_idx)
        scores = model(data_batch)
        # count how many sequences match the gold seq_labels exactly
        n_correct += ((torch.max(scores, 1)[1].view(data_batch[1].size()).data
                       == data_batch[1].data).sum(dim=0)
                      == data_batch[1].size()[0]).sum()
        # predicted and gold labels, converted back to text below;
        # note that both need to be transposed
        index_tag = np.transpose(
            torch.max(scores, 1)[1].view(data_batch[1].size()).cpu().data.numpy())
        gold_tag = np.transpose(data_batch[1].cpu().data.numpy())
        index_question = np.transpose(data_batch[0].cpu().data.numpy())
        gold_list.append(gold_tag)
        pred_list.append(index_tag)
        for i in range(data_loader.batch_size):
            # map back to the text in QAData, look the MID up in Freebase,
            # and compute subject accuracy
            while qa_data_idx < len(qa_data) and not qa_data[qa_data_idx].text_subject:
                # the loader drops data without text_subject, while QAData is complete
                qa_data_idx += 1
            if qa_data_idx >= len(qa_data):
                # the tail of the last batch is <pad>-filled and qa_data is exhausted
                break
            _qa_data = qa_data[qa_data_idx]
            tokens = np.array(_qa_data.question.split())
            # index_tag may be longer than the actual question because of <pad> tokens
            pred_text = ' '.join(tokens[np.where(index_tag[i][:len(tokens)])])
            # accuracy of the extended candidate-subject generation
            pred_sub, pred_sub_extend = get_candidate_sub(tokens, index_tag[i])
            if _qa_data.subject in pred_sub:
                n_correct_sub += 1
            if _qa_data.subject in pred_sub_extend:
                n_correct_extend += 1
            if not pred_sub_extend:
                n_empty += 1
            if write:
                if pred_sub == pred_sub_extend:
                    pred_sub = 'RRR'
                results_file_sub.write('%s-%d\t%s\t%s\t%s\t%s\t%s\t%s\n'
                                       % (tp, linenum, _qa_data.question,
                                          pred_sub, pred_sub_extend, _qa_data.subject,
                                          pred_text, _qa_data.text_subject))
                question_array = np.array(word_vocab.convert_to_word(index_question[i]))
                pred_array = question_array[np.where(index_tag[i])]
                gold_array = question_array[np.where(gold_tag[i])]
                line_to_print = '%s-%d\t%s\t%s\t%s' % (tp, linenum, " ".join(question_array),
                                                       " ".join(pred_array), " ".join(gold_array))
                results_file.write(line_to_print + "\n")
            if save_qadata:
                for sub in pred_sub_extend:
                    rel = virtuoso.id_query_out_rel(sub)
                    _qa_data.add_candidate(sub, rel)
                if hasattr(_qa_data, 'cand_rel'):
                    _qa_data.remove_duplicate()
                # if _qa_data.subject not in pred_sub_extend:
                #     _qa_data.neg_rel = virtuoso.id_query_out_rel(_qa_data.subject)
                new_qa_data.append((_qa_data, len(_qa_data.question_pattern.split())))
            linenum += 1
            qa_data_idx += 1
    total = linenum - 1
    accuracy = 100. * n_correct / total
    print("%s\taccuracy: %8.6f\tcorrect: %d\ttotal: %d" % (tp, accuracy, n_correct, total))
    P, R, F = evaluation(gold_list, pred_list)
    print("Precision: {:10.6f}% Recall: {:10.6f}% F1 Score: {:10.6f}%".format(
        100. * P, 100. * R, 100. * F))
    sub_accuracy = 100. * n_correct_sub / total
    print('subject accuracy: %8.6f\tcorrect: %d\ttotal: %d' % (sub_accuracy, n_correct_sub, total))
    extend_accuracy = 100. * n_correct_extend / total
    print('extend accuracy: %8.6f\tcorrect: %d\ttotal: %d' % (extend_accuracy, n_correct_extend, total))
    print('subject not found: %8.6f\t%d' % (n_empty / total, n_empty))
    print("-" * 80)
    if write:
        results_file.close()
        results_file_sub.close()
    if save_qadata:
        qadata_save_path = open(os.path.join(args.results_path, 'QAData.label.%s.pkl' % (tp)), 'wb')
        data_list = [data[0] for data in sorted(new_qa_data, key=lambda data: data[1], reverse=True)]
        pickle.dump(data_list, qadata_save_path)
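
# `evaluation` is defined elsewhere in the repo. A hedged sketch of a
# token-level precision/recall/F1 over the collected gold and predicted tag
# batches (assumption: tag 1 marks subject-mention tokens):
def evaluation(gold_list, pred_list):
    tp = fp = fn = 0
    for gold_batch, pred_batch in zip(gold_list, pred_list):
        for gold, pred in zip(gold_batch, pred_batch):
            for g, p in zip(gold, pred):
                if p == 1 and g == 1:
                    tp += 1
                elif p == 1 and g == 0:
                    fp += 1
                elif p == 0 and g == 1:
                    fn += 1
    P = tp / (tp + fp) if tp + fp else 0.0
    R = tp / (tp + fn) if tp + fn else 0.0
    F = 2 * P * R / (P + R) if P + R else 0.0
    return P, R, F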