def get_idf_keyword_score(problems: List[QueryDoc], get_idf) -> Iterable[Counter]:
    stemmer = CacheStemmer()
    ticker = TimeEstimator(len(problems))
    for p in problems:
        tokens = p.doc
        tf = Counter()
        reverse_map = {}  # Stemmed -> raw
        tokens = [t for t in tokens if t not in [".", ",", "!"]]
        for raw_t in tokens:
            stem_t = stemmer.stem(raw_t)
            reverse_map[stem_t] = raw_t
            tf[stem_t] += 1
        score_d = Counter()
        for term, cnt in tf.items():
            score = math.log(1 + cnt) * get_idf(term)
            assert type(score) == float
            score_d[term] = score
        score_d_surface_form: Counter = Counter(
            dict_key_map(lambda x: reverse_map[x], score_d))
        ticker.tick()
        yield score_d_surface_form
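# A minimal, self-contained sketch of the same log(1 + tf) * idf keyword scoring
# used above, without the project-specific CacheStemmer / QueryDoc / TimeEstimator
# helpers. The toy corpus and idf construction below are illustrative assumptions,
# not part of the original code.
import math
from collections import Counter
from typing import Dict, List


def toy_idf_table(docs: List[List[str]]) -> Dict[str, float]:
    # Plain idf = log(N / df) over a tiny in-memory corpus.
    n = len(docs)
    df = Counter()
    for doc in docs:
        df.update(set(doc))
    return {t: math.log(n / df[t]) for t in df}


def keyword_scores(doc_tokens: List[str], idf: Dict[str, float]) -> Counter:
    tf = Counter(t for t in doc_tokens if t not in [".", ",", "!"])
    return Counter({t: math.log(1 + c) * idf.get(t, 0.0) for t, c in tf.items()})


# Example usage (illustrative values):
# corpus = [["cheap", "flight", "deals"], ["flight", "safety", "report"]]
# idf = toy_idf_table(corpus)
# print(keyword_scores(["cheap", "cheap", "flight"], idf).most_common(2))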
def segment_per_doc_index(task_id):
    token_reader = get_token_reader()
    stemmer = CacheStemmer()
    stopword = load_stopwords()
    p = os.path.join(cpath.data_path, "adhoc", "robust_seg_info.pickle")
    seg_info = pickle.load(open(p, "rb"))

    def get_doc_posting_list(doc_id):
        doc_posting = defaultdict(list)
        # Retrieve and stem the document once and reuse it for every segment interval.
        tokens = token_reader.retrieve(doc_id)
        st_tokens = list([stemmer.stem(t) for t in tokens])
        for interval in seg_info[doc_id]:
            (loc, loc_ed), (_, _) = interval
            ct = Counter(st_tokens[loc:loc_ed])
            for term, cnt in ct.items():
                if term in stopword:
                    continue
                doc_posting[term].append((loc, cnt))
        return doc_posting

    doc_id_list = get_doc_task(task_id)
    ticker = TimeEstimator(len(doc_id_list))
    doc_posting_d = {}
    for doc_id in doc_id_list:
        doc_posting_d[doc_id] = get_doc_posting_list(doc_id)
        ticker.tick()
    save_path = os.path.join(cpath.data_path, "adhoc",
                             "per_doc_posting_{}.pickle".format(task_id))
    pickle.dump(doc_posting_d, open(save_path, "wb"))
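# The pickle written above maps doc_id -> {stemmed_term -> [(segment_start, tf), ...]},
# which is the structure PassageRanker.rank() later iterates as
# "for seg_loc, tf in per_doc_index[term]". A rough loading sketch follows; the
# directory and document id below are placeholders, not values from the original code.
import os
import pickle


def load_per_doc_posting(task_id, data_dir="data/adhoc"):  # hypothetical location
    path = os.path.join(data_dir, "per_doc_posting_{}.pickle".format(task_id))
    with open(path, "rb") as f:
        return pickle.load(f)


# Example (placeholder doc id in TREC Robust style):
# posting_d = load_per_doc_posting(0)
# for seg_loc, tf in posting_d["LA010189-0001"].get("flight", []):
#     print(seg_loc, tf)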
class HintRetriever:
    def __init__(self):
        self.stopword = load_stopwords()
        self.stemmer = CacheStemmer()
        vocab_file = os.path.join(cpath.data_path, "bert_voca.txt")
        self.tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file,
                                                    do_lower_case=True)
        tprint("Loading inv_index for robust")
        self.collection = RobustCollection()
        tprint("Done")
        self.num_candidate = 10

    def retrieve_hints(self, inst):
        target_tokens, sent_list, prev_tokens, next_tokens, mask_indice = inst
        basic_tokens, word_mask_indice = translate_mask2token_level(
            sent_list, target_tokens, mask_indice, self.tokenizer)
        # stemming
        visible_tokens = remove(basic_tokens, word_mask_indice)
        stemmed = list([self.stemmer.stem(t) for t in visible_tokens])
        query = self.filter_stopword(stemmed)
        res = self.collection.retrieve_docs(query)
        candi_list = left(res[1:1 + self.num_candidate])
        r = target_tokens, sent_list, prev_tokens, next_tokens, mask_indice, candi_list
        return r

    def query_gen(self, inst):
        target_tokens, sent_list, prev_tokens, next_tokens, mask_indice, doc_id = inst
        basic_tokens, word_mask_indice = translate_mask2token_level(
            sent_list, target_tokens, mask_indice, self.tokenizer)
        # stemming
        visible_tokens = remove(basic_tokens, word_mask_indice)
        stemmed = list([self.stemmer.stem(t) for t in visible_tokens])
        query = self.filter_stopword(stemmed)
        high_qt_stmed = self.collection.high_idf_q_terms(Counter(query))
        final_query = []
        for t in visible_tokens:
            if self.stemmer.stem(t) in high_qt_stmed:
                final_query.append(t)
        return final_query

    def filter_stopword(self, l):
        return list([t for t in l if t not in self.stopword])
def build_co_occur_from_pc_feature(data: Dict[str, List[List[str]]]) \
        -> List[Tuple[str, Counter]]:
    window_size = 10
    stemmer = CacheStemmer()
    r = []
    ticker = TimeEstimator(len(data))
    for cid, tokens_list in data.items():
        ticker.tick()
        counter = build_co_occurrence(tokens_list, window_size, stemmer)
        r.append((cid, counter))
    return r
def build_co_occur_from_pc_feature(
        data: Dict[str, List[ScoreParagraph]]) -> List[Tuple[str, Counter]]:
    window_size = 10
    stemmer = CacheStemmer()
    r = []
    ticker = TimeEstimator(len(data))
    for cid, para_list in data.items():
        ticker.tick()
        tokens_list: List[List[str]] = [e.paragraph.tokens for e in para_list]
        counter = build_co_occurrence(tokens_list, window_size, stemmer)
        r.append((cid, counter))
    return r
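# build_co_occurrence is defined elsewhere in the repo. As a rough, self-contained
# sketch of what windowed co-occurrence counting over stemmed tokens could look like
# (illustrative only; it assumes ordered pairs are counted within a fixed-size window
# to the right of each token, and that the stemmer exposes .stem()):
from collections import Counter
from typing import List


def co_occurrence_sketch(tokens_list: List[List[str]], window_size: int, stemmer) -> Counter:
    counter = Counter()
    for tokens in tokens_list:
        stemmed = [stemmer.stem(t) for t in tokens]
        for i, t in enumerate(stemmed):
            for u in stemmed[i + 1:i + 1 + window_size]:
                counter[(t, u)] += 1  # ordered pair within the window
    return counter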
def save_qdf_ex():
    ii_path = os.path.join(cpath.data_path, "adhoc", "robust_inv_index.pickle")
    inv_index = pickle.load(open(ii_path, "rb"))
    meta_path = os.path.join(cpath.data_path, "adhoc", "robust_meta.pickle")
    meta = pickle.load(open(meta_path, "rb"))
    stopwords = load_stopwords()
    stemmer = CacheStemmer()

    simple_posting = {}
    qdf_d = Counter()
    for term in inv_index:
        simple_posting[term] = set()
        for doc_id, _ in inv_index[term]:
            simple_posting[term].add(doc_id)

    for doc in meta:
        date, headline = meta[doc]
        tokens = nltk.tokenize.wordpunct_tokenize(headline)
        terms = set()
        for idx, t in enumerate(tokens):
            if t in stopwords:
                continue
            t_s = stemmer.stem(t)
            terms.add(t_s)
        for t in terms:
            # Headline terms that never appear in the inverted index carry no qdf;
            # guard against a KeyError on such terms.
            if t in simple_posting:
                simple_posting[t].add(doc)

    for term in inv_index:
        qdf = len(simple_posting[term])
        qdf_d[term] = qdf

    save_path = os.path.join(cpath.data_path, "adhoc", "robust_qdf_ex.pickle")
    pickle.dump(qdf_d, open(save_path, "wb"))
def explain_by_lime_idf(data: List[str], get_idf) -> List[List[Tuple[str, float]]]:
    stemmer = CacheStemmer()

    def split(t):
        return t.split()

    explainer = lime_text.LimeTextExplainer(split_expression=split, bow=True)

    def evaluate_score(problems: List[str]):
        # LIME's classifier_fn contract: one row of per-class scores per perturbed
        # text; the score being explained goes in column 1.
        scores = []
        for problem in problems:
            score = solve(problem)
            scores.append([0, score])
        return np.array(scores)

    def solve(problem: str):
        tokens = split(problem)
        if "[SEP]" not in tokens:
            return 0
        e: QueryDoc = parse_problem(tokens)
        q_terms = lmap(stemmer.stem, e.query)
        doc_terms = lmap(stemmer.stem, e.doc)
        tf = Counter(doc_terms)
        q_terms_set = set(q_terms)
        score = 0
        for term, cnt in tf.items():
            if term in q_terms_set:
                idf = get_idf(term)
                score += log(1 + cnt) * idf
        return score

    explains = []
    tick = TimeEstimator(len(data))
    for entry in data:
        assert type(entry) == str
        exp = explainer.explain_instance(entry, evaluate_score, num_features=512)
        l2 = exp.as_list()
        l2.sort(key=get_second, reverse=True)
        explains.append(l2)
        tick.tick()
    return explains
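# A minimal usage sketch of the LimeTextExplainer call pattern above, with a toy
# classifier_fn in place of the project-specific "[SEP]"-separated query/doc scoring.
# The scorer and sample text below are illustrative assumptions.
import numpy as np
from lime import lime_text


def toy_classifier_fn(texts):
    # One row per perturbed text, one column per "class"; LIME explains column 1.
    return np.array([[0.0, float(sum(len(t) > 4 for t in text.split()))]
                     for text in texts])


# explainer = lime_text.LimeTextExplainer(split_expression=lambda t: t.split(), bow=True)
# exp = explainer.explain_instance("cheap flights to boston", toy_classifier_fn,
#                                  num_features=10, num_samples=500)
# print(exp.as_list())  # [(token, weight), ...]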
def work(q_res_path, save_name):
    ranked_list_d = load_galago_ranked_list(q_res_path)
    window_size = 10
    stemmer = CacheStemmer()
    print(q_res_path)
    ticker = TimeEstimator(len(ranked_list_d))
    r = []
    for claim_id, ranked_list in ranked_list_d.items():
        ticker.tick()
        doc_ids = list([e.doc_id for e in ranked_list])
        counter = build_co_occurrence(get_tokens_form_doc_ids(doc_ids),
                                      window_size, stemmer)
        r.append((claim_id, counter))
    save_to_pickle(r, save_name)
def count_it(data: Dict[str, List[ScoreParagraph]]) -> List[Tuple[str, Counter]]:
    stemmer = CacheStemmer()
    r = []
    stopword = load_stopwords()

    def remove_stopwords(tokens: List[str]) -> List[str]:
        return list([t for t in tokens if t not in stopword])

    ticker = TimeEstimator(len(data))
    for cid, para_list in data.items():
        ticker.tick()
        tokens_list: List[List[str]] = [e.paragraph.tokens for e in para_list]
        list_tokens: List[List[str]] = lmap(stemmer.stem_list, tokens_list)
        list_tokens = lmap(remove_stopwords, list_tokens)
        all_cnt = Counter()
        for tokens in list_tokens:
            all_cnt.update(Counter(tokens))
        r.append((cid, all_cnt))
    return r
class ProblemMaker:
    def __init__(self):
        vocab_file = os.path.join(data_path, "bert_voca.txt")
        self.tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file,
                                                    do_lower_case=True)
        self.stemmer = CacheStemmer()
        self.stopword = load_stopwords()
        self.df = self.load_galgo_df_stat()

    def load_galgo_df_stat(self):
        return load_df(os.path.join(data_path, "enwiki", "tf_stat"))

    def generate_mask(self, inst, max_num_tokens, masked_lm_prob, short_seq_prob, rng):
        max_predictions_per_seq = 20
        cur_seg, prev_seg, next_seg = inst

        def get_seg_tokens(seg):
            if seg is None:
                return None
            title, content, st, ed = seg
            return flatten([self.tokenizer.tokenize(t) for t in content])

        title, content, st, ed = cur_seg
        prev_tokens = get_seg_tokens(prev_seg)
        target_tokens = get_seg_tokens(cur_seg)
        next_tokens = get_seg_tokens(next_seg)

        # Occasionally truncate the target segment and push the remainder into the
        # following segment (short-sequence sampling).
        if rng.random() < short_seq_prob and next_tokens is not None:
            target_seq_length = rng.randint(2, max_num_tokens)
            short_seg = target_tokens[:target_seq_length]
            remain_seg = target_tokens[target_seq_length:]
            next_tokens = (remain_seg + next_tokens)[:max_num_tokens]
            target_tokens = short_seg

        num_to_predict = min(
            max_predictions_per_seq,
            max(1, int(round(len(target_tokens) * masked_lm_prob))))
        cand_indice = list(range(0, len(target_tokens)))
        rng.shuffle(cand_indice)
        mask_indice = cand_indice[:num_to_predict]
        doc_id = "{}-{}-{}".format(title, st, ed)
        mask_inst = target_tokens, prev_seg, cur_seg, next_seg, prev_tokens, next_tokens, mask_indice, doc_id
        return mask_inst

    def filter_stopword(self, l):
        return list([t for t in l if t not in self.stopword])

    def generate_query(self, mask_inst):
        target_tokens, prev_seg, cur_seg, next_seg, prev_tokens, next_tokens, mask_indice, doc_id = mask_inst
        title, content, st, ed = cur_seg
        visible_tokens = get_visible(content, target_tokens, mask_indice, self.tokenizer)
        stemmed = list([self.stemmer.stem(t) for t in visible_tokens])
        query = clean_query(self.filter_stopword(stemmed))
        high_qt_stmed = self.high_idf_q_terms(Counter(query))
        query = []
        for t in visible_tokens:
            if self.stemmer.stem(t) in high_qt_stmed:
                query.append(t)
        if not query:
            print(title)
            print(content)
            print(visible_tokens)
            print(high_qt_stmed)
        return query

    def high_idf_q_terms(self, q_tf, n_limit=10):
        total_doc = 11503029 + 100
        high_qt = Counter()
        for term, qf in q_tf.items():
            qdf = self.df[term]
            w = BM25_3_q_weight(qf, qdf, total_doc)
            high_qt[term] = w
        return set(left(high_qt.most_common(n_limit)))
class Pipeline:
    def __init__(self):
        tprint("Pipeline Init")
        self.stemmer = CacheStemmer()
        vocab_file = os.path.join(cpath.data_path, "bert_voca.txt")
        self.tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file,
                                                    do_lower_case=True)
        self.iteration_dir = "/mnt/scratch/youngwookim/data/tlm_iter1"
        if not os.path.exists("/mnt/scratch/youngwookim/"):
            self.iteration_dir = "/mnt/nfs/work3/youngwookim/data/tlm_iter1"
        self.seg_max_seq = 256
        self.model_max_seq = 512
        self.rng = random.Random(0)
        self.masked_lm_prob = 0.15
        self.short_seq_prob = 0.1
        self.inst_per_job = 1000
        self.stopword = load_stopwords()
        self.pr = FeatureExtractor(self.seg_max_seq - 3)
        self.tf_record_maker = None
        self.code_tick = CodeTiming()
        tprint("Pipeline Init Done")

    def run_A(self, job_id):
        output = []
        queries = []
        ticker = TimeEstimator(self.inst_per_job)
        for i in range(self.inst_per_job):
            p, q = self.process_A()
            qid = job_id * self.inst_per_job + i
            output.append((p, qid))
            queries.append((qid, q))
            ticker.tick()
        self.save_query(job_id, queries)
        self.save_output_A(job_id, output)

    def run_B(self, job_id):
        if self.pr.doc_posting is None:
            self.pr.doc_posting = per_doc_posting_server.load_dict()
        output_A = self.load_output_A(job_id)
        candi_docs = self.load_candidate_docs(job_id)
        feature_str_list = []
        seg_candi_list = []
        ticker = TimeEstimator(self.inst_per_job)
        for i in range(self.inst_per_job):
            problem, qid = output_A[i]
            qid_str = str(qid)
            if qid_str in candi_docs:
                doc_candi = candi_docs[qid_str]
                seg_candi, features = self.process_B(problem, doc_candi)
                fstr = "\n".join([libsvm_str(qid, 0, f) for f in features])
                feature_str_list.append(fstr)
                seg_candi_list.append(seg_candi)
            else:
                # Append a string here so save_ltr's "\n".join() does not fail.
                feature_str_list.append("")
                seg_candi_list.append([])
            ticker.tick()
            if i % 100 == 3:
                self.code_tick.print()
        self.save("seg_candi_list", job_id, seg_candi_list)
        self.save_ltr(job_id, feature_str_list)

    def run_C(self, job_id):
        if self.tf_record_maker is None:
            self.tf_record_maker = TFRecordMaker(self.model_max_seq)
        output_A = self.load_output_A(job_id)
        seg_candi_list = self.load("seg_candi_list", job_id)
        ltr_result = self.load_ltr(job_id)
        uid_list = []
        tf_record_list = []
        for i in range(self.inst_per_job):
            problem, qid = output_A[i]
            insts_id = "{}_{}".format(job_id, i)
            r = self.process_C(problem, seg_candi_list[i], ltr_result[i])
            for idx, tf_record in r:
                tf_id = insts_id + "_{}".format(idx)
                uid = job_id * 1000 * 1000 + i * 10 + idx
                uid_list.append(uid)
                tf_record_list.append(tf_record)
        data = zip(tf_record_list, uid_list)
        self.write_tf_record(job_id, data)

    def inspect_seg(self, job_id):
        output_A = self.load_output_A(job_id)
        seg_candi_list = self.load("seg_candi_list", job_id)
        q_path = self.get_path("query", "g_query_{}.json".format(job_id))
        queries = json.load(open(q_path, "r"))["queries"]
        ltr_result = self.load_ltr(job_id)
        for i in range(self.inst_per_job):
            problem, qid = output_A[i]
            print(qid)
            scl = seg_candi_list[i]
            query = queries[i]["text"][len("#combine("):-1]
            print(query)
            q_terms = query.split()
            q_tf = stemmed_counter(q_terms, self.stemmer)
            print(list(q_tf.keys()))
            for doc_id, loc in scl[:30]:
                doc_tokens = self.pr.token_dump.get(doc_id)
                l = self.pr.get_seg_len_dict(doc_id)[loc]
                passage = doc_tokens[loc:loc + l]
                d_tf = stemmed_counter(passage, self.stemmer)
                arr = []
                for qt in q_tf:
                    arr.append(d_tf[qt])
                print(arr)

    def process_A(self):
        segment = self.sample_segment()
        problem = self.segment2problem(segment)
        query = self.problem2query(problem)
        return problem, query

    def process_B(self, problem, doc_candi):
        self.code_tick.tick_begin("get_seg_candidate")
        seg_candi = self.get_seg_candidate(doc_candi, problem)
        self.code_tick.tick_end("get_seg_candidate")
        target_tokens, sent_list, prev_tokens, next_tokens, mask_indice, doc_id = problem
        self.code_tick.tick_begin("get_feature_list")
        feature = self.pr.get_feature_list(doc_id, sent_list, target_tokens,
                                           mask_indice, seg_candi)
        self.code_tick.tick_end("get_feature_list")
        return seg_candi, feature

    def process_C(self, problem, seg_candi, scores):
        target_tokens, sent_list, prev_tokens, next_tokens, mask_indice, doc_id = problem

        def select_top(items, scores, n_top):
            # list.sort() returns None, so sort via sorted() before slicing.
            pairs = sorted(zip(items, scores), key=lambda x: x[1], reverse=True)
            return left(pairs[:n_top])

        e = {
            "target_tokens": target_tokens,
            "sent_list": sent_list,
            "prev_tokens": prev_tokens,
            "next_tokens": next_tokens,
            "mask_indice": mask_indice,
            "doc_id": doc_id,
            "passages": select_top(seg_candi, scores, 3)
        }
        insts = self.tf_record_maker.generate_ir_tfrecord(self, e)
        r = []
        for idx, p in enumerate(insts):
            tf_record = self.tf_record_maker.make_instance(problem)
            r.append((idx, tf_record))
        return r

    def write_tf_record(self, job_id, data):
        inst_list, uid_list = filter_instances(data)
        max_pred = 20
        data = zip(inst_list, uid_list)
        output_path = self.get_path("tf_record_pred", "{}".format(job_id))
        write_predict_instance(data, self.tokenizer, self.model_max_seq,
                               max_pred, [output_path])

    # Use robust_idf_mini
    def sample_segment(self):
        r = get_random_sent()
        s_id, doc_id, loc, g_id, sent = r
        doc_rows = get_doc_sent(doc_id)
        max_seq = self.seg_max_seq
        target_tokens, sent_list, prev_tokens, next_tokens = extend(
            doc_rows, sent, loc, self.tokenizer, max_seq)
        inst = target_tokens, sent_list, prev_tokens, next_tokens, doc_id
        return inst

    def segment2problem(self, segment):
        mask_inst = generate_mask(segment, self.seg_max_seq, self.masked_lm_prob,
                                  self.short_seq_prob, self.rng)
        return mask_inst

    def problem2query(self, problem):
        target_tokens, sent_list, prev_tokens, next_tokens, mask_indice, doc_id = problem
        basic_tokens, word_mask_indice = translate_mask2token_level(
            sent_list, target_tokens, mask_indice, self.tokenizer)
        # stemming
        visible_tokens = remove(basic_tokens, word_mask_indice)
        stemmed = list([self.stemmer.stem(t) for t in visible_tokens])
        query = self.filter_stopword(stemmed)
        high_qt_stmed = self.pr.high_idf_q_terms(Counter(query))
        final_query = []
        for t in visible_tokens:
            if self.stemmer.stem(t) in high_qt_stmed:
                final_query.append(t)
        return final_query

    def filter_res(self, query_res, doc_id):
        t_doc_id, rank, score = query_res[0]
        assert rank == 1
        if t_doc_id == doc_id:
            valid_res = query_res[1:]
        else:
            valid_res = query_res
            for i in range(len(query_res)):
                e_doc_id, e_rank, _ = query_res[i]
                if e_doc_id == t_doc_id:
                    valid_res = query_res[:i] + query_res[i + 1:]
                    break
        return valid_res

    def get_seg_candidate(self, doc_candi, problem):
        target_tokens, sent_list, prev_tokens, next_tokens, mask_indice, doc_id = problem
        valid_res = self.filter_res(doc_candi, doc_id)
        return self.pr.rank(doc_id, valid_res, target_tokens, sent_list,
                            mask_indice, top_k=100)

    def filter_stopword(self, l):
        return list([t for t in l if t not in self.stopword])

    def get_path(self, sub_dir_name, file_name):
        out_path = os.path.join(self.iteration_dir, sub_dir_name, file_name)
        dir_path = os.path.join(self.iteration_dir, sub_dir_name)
        if not os.path.exists(dir_path):
            os.mkdir(dir_path)
        return out_path

    def save_query(self, job_idx, queries):
        j_queries = []
        for qid, query in queries:
            query = clean_query(query)
            j_queries.append({
                "number": str(qid),
                "text": "#combine({})".format(" ".join(query))
            })
        data = {"queries": j_queries}
        out_path = self.get_path("query", "g_query_{}.json".format(job_idx))
        with open(out_path, "w") as fout:
            fout.write(json.dumps(data, indent=True))

    def save_pickle_at(self, obj, sub_dir_name, file_name):
        out_path = self.get_path(sub_dir_name, file_name)
        with open(out_path, "wb") as fout:
            pickle.dump(obj, fout)

    def load_pickle_at(self, sub_dir_name, file_name):
        out_path = self.get_path(sub_dir_name, file_name)
        with open(out_path, "rb") as fin:
            return pickle.load(fin)

    def save_output_A(self, job_idx, output):
        self.save_pickle_at(output, "output_A", "{}.pickle".format(job_idx))

    def load_output_A(self, job_idx):
        return self.load_pickle_at("output_A", "{}.pickle".format(job_idx))

    def save(self, dir_name, job_idx, output):
        self.save_pickle_at(output, dir_name, "{}.pickle".format(job_idx))

    def load(self, dir_name, job_idx):
        return self.load_pickle_at(dir_name, "{}.pickle".format(job_idx))

    def load_candidate_docs(self, job_id):
        out_path = self.get_path("q_res", "{}.txt".format(job_id))
        return load_galago_ranked_list(out_path)

    def save_ltr(self, job_idx, data):
        out_path = self.get_path("ltr", str(job_idx))
        with open(out_path, "w") as fout:
            fout.write("\n".join(data))

    def load_ltr(self, job_idx):
        return NotImplemented
class PassageRanker:
    def __init__(self, window_size):
        self.stemmer = CacheStemmer()
        self.window_size = window_size
        self.doc_posting = None
        self.stopword = load_stopwords()
        vocab_file = os.path.join(cpath.data_path, "bert_voca.txt")
        self.tokenizer = tokenization.FullTokenizer(
            vocab_file=vocab_file, do_lower_case=True)

        def load_pickle(name):
            p = os.path.join(cpath.data_path, "adhoc", name + ".pickle")
            return pickle.load(open(p, "rb"))

        self.doc_len_dict = load_pickle("doc_len")
        self.qdf = load_pickle("robust_qdf_ex")
        self.meta = load_pickle("robust_meta")
        self.head_tokens = load_pickle("robust_title_tokens")
        self.seg_info = load_pickle("robust_seg_info")
        self.not_found = set()
        self.total_doc_n = len(self.doc_len_dict)
        self.avdl = sum(self.doc_len_dict.values()) / len(self.doc_len_dict)
        tprint("Init PassageRanker")

    def get_seg_len_dict(self, doc_id):
        interval_list = self.seg_info[doc_id]
        l_d = {}
        for e in interval_list:
            (loc, loc_ed), (loc_sub, loc_sub_ed) = e
            l_d[loc] = loc_ed - loc
        return l_d

    def filter_stopword(self, l):
        return list([t for t in l if t not in self.stopword])

    def BM25(self, q_tf, tokens):
        dl = len(tokens)
        tf_d = stemmed_counter(tokens, self.stemmer)
        score = 0
        for term, qf in q_tf.items():
            tf = tf_d[term]
            qdf = self.qdf[term]
            total_doc = self.total_doc_n
            score += BM25_3(tf, qf, qdf, total_doc, dl, self.avdl)
        return score

    def BM25_seg(self, term, qf, tf, dl):
        # qf: query term frequency, tf: term frequency within the segment.
        qdf = self.qdf[term]
        total_doc = self.total_doc_n
        return BM25_3(tf, qf, qdf, total_doc, dl, self.avdl)

    def high_idf_q_terms(self, q_tf, n_limit=10):
        total_doc = self.total_doc_n
        high_qt = Counter()
        for term, qf in q_tf.items():
            qdf = self.qdf[term]
            w = BM25_3_q_weight(qf, qdf, total_doc)
            high_qt[term] = w
        return set(left(high_qt.most_common(n_limit)))

    def reform_query(self, tokens):
        stemmed = list([self.stemmer.stem(t) for t in tokens])
        query = self.filter_stopword(stemmed)
        q_tf = Counter(query)
        high_qt = self.high_idf_q_terms(q_tf)
        new_qf = Counter()
        for t in high_qt:
            new_qf[t] = q_tf[t]
        return new_qf

    def get_title_tokens(self, doc_id):
        return self.head_tokens[doc_id]

    def weight_doc(self, title, content):
        new_doc = []
        title_repeat = 3
        content_repeat = 1
        for t in title:
            new_doc += [t] * title_repeat
        for t in content:
            new_doc += [t] * content_repeat
        return new_doc

    def pick_diverse(self, candidate, top_k):
        doc_set = set()
        r = []
        for j in range(len(candidate)):
            score, window_subtokens, window_tokens, doc_id, loc = candidate[j]
            if doc_id not in doc_set:
                r.append(candidate[j])
                doc_set.add(doc_id)
            if len(r) == top_k:
                break
        return r

    def rank(self, doc_id, query_res, target_tokens, sent_list, mask_indice, top_k):
        if self.doc_posting is None:
            self.doc_posting = per_doc_posting_server.load_dict()
        title = self.get_title_tokens(doc_id)
        visible_tokens = get_visible(sent_list, target_tokens, mask_indice,
                                     self.tokenizer)
        source_doc_rep = self.weight_doc(title, visible_tokens)
        q_tf = self.reform_query(source_doc_rep)

        # Rank each window by BM25F
        stop_cut = 100
        segments = Counter()
        n_not_found = 0
        for cand_doc_id, rank, doc_score in query_res[:stop_cut]:
            doc_title = self.get_title_tokens(cand_doc_id)
            title_tf = stemmed_counter(doc_title, self.stemmer)
            per_doc_index = self.doc_posting.get(cand_doc_id)
            seg_len_d = self.get_seg_len_dict(cand_doc_id)
            if per_doc_index is None:
                self.not_found.add(cand_doc_id)
                n_not_found += 1
                continue
            assert cand_doc_id not in self.not_found

            for term, qf in q_tf.items():
                if term not in per_doc_index:
                    continue
                for seg_loc, tf in per_doc_index[term]:
                    rep_tf = tf + title_tf[term]
                    dl = seg_len_d[seg_loc] + len(doc_title)
                    segments[(cand_doc_id, seg_loc)] += self.BM25_seg(term, qf, rep_tf, dl)

        r = []
        for e, score in segments.most_common(top_k):
            if score == 0:
                print("Doc {} has 0 score".format(e[0]))
            r.append(e)
        if n_not_found > 0.9 * len(query_res):
            print("WARNING : {} of {} not found".format(n_not_found, len(query_res)))
        return r
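# BM25_3 and BM25_3_q_weight are project helpers not shown in this file. For
# reference, a standard Okapi BM25 term contribution has the shape sketched below
# (common k1/b/k3 defaults and a Lucene-style non-negative idf); this is an
# illustration, not necessarily the exact variant implemented by BM25_3.
import math


def bm25_term_score(tf, qf, df, total_doc, dl, avdl, k1=1.2, b=0.75, k3=1000):
    idf = math.log((total_doc - df + 0.5) / (df + 0.5) + 1)        # document-frequency weight
    tf_part = tf * (k1 + 1) / (tf + k1 * (1 - b + b * dl / avdl))  # length-normalized tf
    qf_part = qf * (k3 + 1) / (qf + k3)                            # query term frequency
    return idf * tf_part * qf_part


# Example: a term appearing 3 times in a 120-token segment, once in the query
# (illustrative collection statistics):
# print(bm25_term_score(tf=3, qf=1, df=5000, total_doc=500000, dl=120, avdl=300))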