import pickle
import random
from collections import Counter

import numpy as np
from tqdm import trange
from sentence_transformers import SentenceTransformer

# get_trec_dat, get_paratext_dict, InputTRECCARExample and ClusterEvaluator are
# project-local helpers; import them from wherever they live in this repo.


def prepare_cluster_data_for_eval(art_qrels, top_qrels, paratext, do_filter, val_samples):
    page_paras, rev_para_top, _ = get_trec_dat(art_qrels, top_qrels, None)
    len_paras = np.array([len(page_paras[page]) for page in page_paras.keys()])
    print('mean paras: %.2f, std: %.2f, max paras: %.2f' % (np.mean(len_paras), np.std(len_paras),
                                                            np.max(len_paras)))
    ptext_dict = get_paratext_dict(paratext)
    top_cluster_data = []
    pages = list(page_paras.keys())
    skipped_pages = 0
    max_num_doc = max([len(page_paras[p]) for p in page_paras.keys()])
    for i in trange(len(pages)):
        page = pages[i]
        paras = page_paras[page]
        paratexts = [ptext_dict[p] for p in paras]
        top_sections = list(set([rev_para_top[p] for p in paras]))
        top_labels = [top_sections.index(rev_para_top[p]) for p in paras]
        query_text = ' '.join(page.split('enwiki:')[1].split('%20'))
        n = len(paras)
        if do_filter:
            # The page should have 20-200 paragraphs, at least 2 top-level sections,
            # and n/k of at least 2.5. Both checks run before padding so the dummy
            # -1 label cannot inflate the section count.
            k = len(top_sections)
            if n < 20 or n > 200 or k < 2 or n / k < 2.5:
                skipped_pages += 1
                continue
        # Pad (or truncate) every page to max_num_doc paragraphs; dummies get label -1.
        paras = paras[:max_num_doc] if n >= max_num_doc else paras + ['dummy'] * (max_num_doc - n)
        paratexts = paratexts[:max_num_doc] if n >= max_num_doc else paratexts + [''] * (max_num_doc - n)
        top_labels = top_labels[:max_num_doc] if n >= max_num_doc else top_labels + [-1] * (max_num_doc - n)
        top_cluster_data.append(InputTRECCARExample(qid=page, q_context=query_text, pids=paras,
                                                    texts=paratexts, label=np.array(top_labels)))
    if val_samples > 0:
        top_cluster_data = top_cluster_data[:val_samples]
    print('Pages skipped by filter: %d' % skipped_pages)
    print('Total data instances: %5d' % len(top_cluster_data))
    return top_cluster_data
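
# Usage sketch (file names below are placeholders, not actual repo paths): build
# the padded evaluation instances from a TREC CAR benchmark split. do_filter=True
# applies the 20-200 paragraph and k >= 2 constraints; val_samples=0 keeps every
# surviving page.
#
#   eval_data = prepare_cluster_data_for_eval('benchmarkY1test.art.qrels',
#                                             'benchmarkY1test.top.qrels',
#                                             'by1test_paratext.tsv',
#                                             do_filter=True, val_samples=0)
#   print(eval_data[0].qid, eval_data[0].label[:10])
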
def evaluate_treccar(model_path, test_art_qrels, test_top_qrels, test_hier_qrels, test_paratext, level):
    test_page_paras, test_rev_para_top, test_rev_para_hier = get_trec_dat(test_art_qrels, test_top_qrels,
                                                                          test_hier_qrels)
    test_len_paras = np.array([len(test_page_paras[page]) for page in test_page_paras.keys()])
    print('test mean paras: %.2f, std: %.2f, max paras: %.2f' % (np.mean(test_len_paras), np.std(test_len_paras),
                                                                 np.max(test_len_paras)))
    test_ptext_dict = get_paratext_dict(test_paratext)
    test_top_cluster_data = []
    test_hier_cluster_data = []
    max_num_doc_test = max([len(test_page_paras[p]) for p in test_page_paras.keys()])
    test_pages = list(test_page_paras.keys())
    for i in trange(len(test_pages)):
        page = test_pages[i]
        paras = test_page_paras[page]
        paratexts = [test_ptext_dict[p] for p in paras]
        top_sections = list(set([test_rev_para_top[p] for p in paras]))
        top_labels = [top_sections.index(test_rev_para_top[p]) for p in paras]
        hier_sections = list(set([test_rev_para_hier[p] for p in paras]))
        hier_labels = [hier_sections.index(test_rev_para_hier[p]) for p in paras]
        query_text = ' '.join(page.split('enwiki:')[1].split('%20'))
        n = len(paras)
        # Pad every page to the same length; dummy paragraphs get label -1.
        paras = paras + ['dummy'] * (max_num_doc_test - n)
        paratexts = paratexts + [''] * (max_num_doc_test - n)
        top_labels = top_labels + [-1] * (max_num_doc_test - n)
        hier_labels = hier_labels + [-1] * (max_num_doc_test - n)
        test_top_cluster_data.append(InputTRECCARExample(qid=page, q_context=query_text, pids=paras,
                                                         texts=paratexts, label=np.array(top_labels)))
        test_hier_cluster_data.append(InputTRECCARExample(qid=page, q_context=query_text, pids=paras,
                                                          texts=paratexts, label=np.array(hier_labels)))
    print('Test instances: %5d' % len(test_top_cluster_data))
    model = SentenceTransformer(model_path)
    if level == 'h':
        print('Evaluating hierarchical clusters')
        test_evaluator = ClusterEvaluator.from_input_examples(test_hier_cluster_data)
    else:
        print('Evaluating top-level clusters')
        test_evaluator = ClusterEvaluator.from_input_examples(test_top_cluster_data)
    model.evaluate(test_evaluator)
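
# Usage sketch (model directory and file paths are placeholders): evaluate a
# trained SentenceTransformer checkpoint against the hierarchical clusters
# (level='h') or the top-level clusters (any other value). ClusterEvaluator is
# this repo's evaluator; model.evaluate() simply forwards to it.
#
#   evaluate_treccar('saved_models/my_sbert_run',
#                    'by1test.art.qrels', 'by1test.top.qrels',
#                    'by1test.hier.qrels', 'by1test_paratext.tsv', level='h')
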
def save_squt_dataset(train_pages_file, art_qrels, top_qrels, paratext, val_samples, outdir):
    page_paras, rev_para_top, _ = get_trec_dat(art_qrels, top_qrels, None)
    ptext_dict = get_paratext_dict(paratext)
    train_cluster_data = []
    test_cluster_data = []
    pages = []
    with open(train_pages_file, 'r') as f:
        for l in f:
            pages.append(l.rstrip('\n'))
    for i in trange(len(pages)):
        page = pages[i]
        paras = page_paras[page]
        # Group the page's paragraphs by their top-level section.
        page_sec_para_dict = {}
        for p in paras:
            page_sec_para_dict.setdefault(rev_para_top[p], []).append(p)
        # Randomly assign half of each page's sections to the test split and the
        # other half to the train split.
        sections = list(set([rev_para_top[p] for p in paras]))
        random.shuffle(sections)
        test_sections, train_sections = sections[:len(sections) // 2], sections[len(sections) // 2:]
        train_paras = []
        test_paras = []
        for s in test_sections:
            test_paras += page_sec_para_dict[s]
        for s in train_sections:
            train_paras += page_sec_para_dict[s]
        test_labels = [sections.index(rev_para_top[p]) for p in test_paras]
        train_labels = [sections.index(rev_para_top[p]) for p in train_paras]
        test_paratexts = [ptext_dict[p] for p in test_paras]
        train_paratexts = [ptext_dict[p] for p in train_paras]
        query_text = ' '.join(page.split('enwiki:')[1].split('%20'))
        test_cluster_data.append(InputTRECCARExample(qid=page, q_context=query_text, pids=test_paras,
                                                     texts=test_paratexts, label=np.array(test_labels)))
        train_cluster_data.append(InputTRECCARExample(qid=page, q_context=query_text, pids=train_paras,
                                                      texts=train_paratexts, label=np.array(train_labels)))
    # Hold out val_samples pages from the shuffled test instances for validation.
    random.shuffle(test_cluster_data)
    val_cluster_data = test_cluster_data[:val_samples]
    test_cluster_data = test_cluster_data[val_samples:]
    with open(outdir + '/squt_treccar_train.pkl', 'wb') as f:
        pickle.dump(train_cluster_data, f)
    with open(outdir + '/squt_treccar_val.pkl', 'wb') as f:
        pickle.dump(val_cluster_data, f)
    with open(outdir + '/squt_treccar_test.pkl', 'wb') as f:
        pickle.dump(test_cluster_data, f)
    print('No. of data instances - Train: %5d, Val: %5d, Test: %5d' % (len(train_cluster_data),
                                                                       len(val_cluster_data),
                                                                       len(test_cluster_data)))
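
# Usage sketch (placeholder paths): write the three squt pickles for a pages
# list, holding out 50 pages for validation. Seeding random first makes the
# per-page section split reproducible (the seed call is an addition here, not
# part of the function).
#
#   random.seed(42)
#   save_squt_dataset('train-pages.txt', 'train.art.qrels', 'train.top.qrels',
#                     'train_paratext.tsv', val_samples=50, outdir='data')
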
def save_sbert_embeds(sbert_model_name, pages_path, art_qrels, paratext_file, outpath):
    sbert = SentenceTransformer(sbert_model_name)
    page_paras, _, _ = get_trec_dat(art_qrels, None, None)
    paratext_dict = get_paratext_dict(paratext_file)
    paras = []
    paratexts = []
    with open(pages_path, 'r') as f:
        for l in f:
            page = l.rstrip('\n')
            paras += page_paras[page]
            paratexts += [paratext_dict[p] for p in page_paras[page]]
    print(str(len(paratexts)) + ' paras to be encoded')
    para_embeddings = sbert.encode(paratexts, show_progress_bar=True)
    para_data = {'paraids': paras, 'paravecs': para_embeddings}
    with open(outpath, 'wb') as f:
        pickle.dump(para_data, f)
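
# Usage sketch ('all-MiniLM-L6-v2' is a public sentence-transformers checkpoint;
# the file paths are placeholders): precompute paragraph embeddings for the
# train pages and check that the pickle round-trips with one vector per
# paragraph id.
#
#   save_sbert_embeds('all-MiniLM-L6-v2', 'train-pages.txt', 'train.art.qrels',
#                     'train_paratext.tsv', 'train_para_embeds.pkl')
#   with open('train_para_embeds.pkl', 'rb') as f:
#       para_data = pickle.load(f)
#   assert len(para_data['paraids']) == len(para_data['paravecs'])
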
def trec_stats(art_qrels, top_qrels, hier_qrels, paratext_file, i):
    # i > 0 limits the run to the first i pages; i <= 0 processes every page.
    page_paras, rev_para_top, rev_para_hier = get_trec_dat(art_qrels, top_qrels, hier_qrels)
    paratext_dict = get_paratext_dict(paratext_file)
    print('Data loaded')
    arts = []
    stats = []
    c = 0
    for page in page_paras.keys():
        paras = page_paras[page]
        n = len(paras)
        # Per-section paragraph counts at the top and hierarchical levels.
        top_section_counts = list(Counter([rev_para_top[p] for p in paras]).values())
        hier_section_counts = list(Counter([rev_para_hier[p] for p in paras]).values())
        top_k = len(top_section_counts)
        mean_top_k = np.mean(top_section_counts)
        std_top_k = np.std(top_section_counts)
        min_top_k = min(top_section_counts)
        max_top_k = max(top_section_counts)
        hier_k = len(hier_section_counts)
        mean_hier_k = np.mean(hier_section_counts)
        std_hier_k = np.std(hier_section_counts)
        min_hier_k = min(hier_section_counts)
        max_hier_k = max(hier_section_counts)
        # Paragraph-length histogram in words: <10, 10-19, 20-29, 30-39, >=40.
        lens = [len(paratext_dict[p].split()) for p in paras]
        l10 = len([x for x in lens if x < 10])
        l20 = len([x for x in lens if 10 <= x < 20])
        l30 = len([x for x in lens if 20 <= x < 30])
        l40 = len([x for x in lens if 30 <= x < 40])
        l50 = len([x for x in lens if 40 <= x])
        arts.append(page)
        stats.append([n, top_k, mean_top_k, std_top_k, min_top_k, max_top_k,
                      hier_k, mean_hier_k, std_hier_k, min_hier_k, max_hier_k,
                      l10, l20, l30, l40, l50])
        c += 1
        # Progress report every 100 pages, counting down from the page limit
        # (or from the full collection size when no limit is set).
        if i > 0 and c % 100 == 0:
            print(str(i - c) + ' pages to go')
        elif i <= 0 and c % 100 == 0:
            print(str(len(page_paras) - c) + ' pages to go')
        if i > 0 and c >= i:
            break
    print('Article\tN\ttop_k\ttop_mean\ttop_std\ttop_min\ttop_max'
          '\thier_k\thier_mean\thier_std\thier_min\thier_max\tl10\tl20\tl30\tl40\tl50')
    for j, d in enumerate(stats[:100]):
        print(arts[j] + '\t' + '\t'.join([str(dd) for dd in d]))
    return arts, stats
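
# Usage sketch (placeholder paths): profile the first 500 pages of a split and
# dump the per-article stats to a TSV with the same columns the function prints.
#
#   arts, stats = trec_stats('train.art.qrels', 'train.top.qrels',
#                            'train.hier.qrels', 'train_paratext.tsv', 500)
#   with open('trec_stats.tsv', 'w') as f:
#       for art, row in zip(arts, stats):
#           f.write(art + '\t' + '\t'.join(str(x) for x in row) + '\n')
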
def prepare_cluster_data_train(pages_file, art_qrels, top_qrels, paratext):
    page_paras, rev_para_top, _ = get_trec_dat(art_qrels, top_qrels, None)
    ptext_dict = get_paratext_dict(paratext)
    top_cluster_data = []
    pages = []
    with open(pages_file, 'r') as f:
        for l in f:
            pages.append(l.rstrip('\n'))
    for i in trange(len(pages)):
        page = pages[i]
        paras = page_paras[page]
        paratexts = [ptext_dict[p] for p in paras]
        top_sections = list(set([rev_para_top[p] for p in paras]))
        if len(top_sections) < 2:
            # Skip pages with a single top-level section; they carry no cluster signal.
            continue
        top_labels = [top_sections.index(rev_para_top[p]) for p in paras]
        query_text = ' '.join(page.split('enwiki:')[1].split('%20'))
        top_cluster_data.append(InputTRECCARExample(qid=page, q_context=query_text, pids=paras,
                                                    texts=paratexts, label=np.array(top_labels)))
    print('Total data instances: %5d' % len(top_cluster_data))
    return top_cluster_data
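
# Usage sketch (placeholder paths): unlike the eval variant above, the train
# instances are left unpadded, so downstream batching has to handle variable
# page lengths.
#
#   train_data = prepare_cluster_data_train('train-pages.txt', 'train.art.qrels',
#                                           'train.top.qrels', 'train_paratext.tsv')
#   with open('treccar_train.pkl', 'wb') as f:
#       pickle.dump(train_data, f)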