コード例 #1
0
ファイル: segment2problem.py プロジェクト: clover3/Chair
 def __init__(self):
     """Set up the BERT tokenizer, stemmer, stopword list and df statistics.

     The vocabulary is read from ``data_path/bert_voca.txt``; the document
     frequency table comes from ``self.load_galgo_df_stat()``.
     """
     bert_vocab_path = os.path.join(data_path, "bert_voca.txt")
     self.tokenizer = tokenization.FullTokenizer(
         vocab_file=bert_vocab_path, do_lower_case=True)
     self.stemmer = CacheStemmer()
     self.stopword = load_stopwords()
     self.df = self.load_galgo_df_stat()
コード例 #2
0
ファイル: keyword_by_idf.py プロジェクト: clover3/Chair
def get_idf_keyword_score(problems: List[QueryDoc],
                          get_idf) -> Iterable[Counter]:
    """Yield, per problem, a keyword score for every distinct document term.

    Each document token, except the punctuation tokens ".", "," and "!", is
    stemmed. A stem's score is ``log(1 + tf) * get_idf(stem)``, where ``tf``
    is the stem's count in the document. Scores are reported under a surface
    (un-stemmed) form: the last raw token seen for that stem.

    Args:
        problems: items whose ``doc`` attribute is a sequence of raw tokens.
        get_idf: callable mapping a stemmed term to its idf weight.

    Yields:
        One ``Counter`` per problem, mapping surface token -> score.
    """
    punctuation = {".", ",", "!"}  # built once, not once per problem
    stemmer = CacheStemmer()
    ticker = TimeEstimator(len(problems))
    for p in problems:
        tf = Counter()
        reverse_map = {}  # stemmed -> raw; later occurrences overwrite earlier
        for raw_t in p.doc:
            if raw_t in punctuation:
                continue
            stem_t = stemmer.stem(raw_t)
            reverse_map[stem_t] = raw_t
            tf[stem_t] += 1

        score_d = Counter()
        for term, cnt in tf.items():
            score = math.log(1 + cnt) * get_idf(term)
            # isinstance rather than ``type(...) == float`` so float
            # subclasses (e.g. numpy.float64 from get_idf) are accepted.
            assert isinstance(score, float)
            score_d[term] = score

        score_d_surface_form: Counter = Counter(
            dict_key_map(lambda x: reverse_map[x], score_d))
        ticker.tick()
        yield score_d_surface_form
コード例 #3
0
ファイル: segment_ranker_1.py プロジェクト: clover3/Chair
    def __init__(self, window_size):
        """Load the statistics used to rank passages.

        Args:
            window_size: stored as ``self.window_size``.

        Besides the stemmer, stopwords and BERT tokenizer, this loads several
        pickles from ``cpath.data_path/adhoc``: doc_len, robust_qdf_ex,
        robust_meta, robust_title_tokens and robust_seg_info. From the
        document-length table it derives the document count and the average
        document length. ``self.doc_posting`` starts as ``None``.
        """
        self.stemmer = CacheStemmer()
        self.window_size = window_size
        self.doc_posting = None
        self.stopword = load_stopwords()

        vocab_file = os.path.join(cpath.data_path, "bert_voca.txt")
        self.tokenizer = tokenization.FullTokenizer(
            vocab_file=vocab_file, do_lower_case=True)

        def load_pickle(name):
            p = os.path.join(cpath.data_path, "adhoc", name + ".pickle")
            # Close the file deterministically instead of leaking the handle.
            with open(p, "rb") as f:
                return pickle.load(f)

        self.doc_len_dict = load_pickle("doc_len")
        self.qdf = load_pickle("robust_qdf_ex")
        self.meta = load_pickle("robust_meta")
        self.head_tokens = load_pickle("robust_title_tokens")
        self.seg_info = load_pickle("robust_seg_info")
        self.not_found = set()

        self.total_doc_n = len(self.doc_len_dict)
        self.avdl = sum(self.doc_len_dict.values()) / len(self.doc_len_dict)
        tprint("Init PassageRanker")
コード例 #4
0
 def __init__(self):
     """Prepare stopwords, stemmer, BERT tokenizer and the Robust collection.

     The collection is loaded eagerly between two progress messages, and the
     number of retrieval candidates is fixed at 10.
     """
     self.stopword = load_stopwords()
     self.stemmer = CacheStemmer()
     bert_vocab_path = os.path.join(cpath.data_path, "bert_voca.txt")
     self.tokenizer = tokenization.FullTokenizer(
         vocab_file=bert_vocab_path, do_lower_case=True)
     tprint("Loading inv_index for robust")
     self.collection = RobustCollection()
     tprint("Done")
     self.num_candidate = 10
コード例 #5
0
def segment_per_doc_index(task_id):
    """Build and save per-document, per-segment term postings for one task.

    The documents come from ``get_doc_task(task_id)``. For each one, every
    segment interval listed in ``robust_seg_info.pickle`` becomes postings of
    the form ``term -> [(segment_start, count), ...]``. Terms are stemmed, and
    stopwords are dropped after stemming. The mapping ``{doc_id: postings}``
    is pickled to ``adhoc/per_doc_posting_{task_id}.pickle``.

    Args:
        task_id: identifier passed to ``get_doc_task`` and used in the output
            file name.
    """
    token_reader = get_token_reader()
    stemmer = CacheStemmer()
    stopword = load_stopwords()

    p = os.path.join(cpath.data_path, "adhoc", "robust_seg_info.pickle")
    with open(p, "rb") as f:
        seg_info = pickle.load(f)

    def get_doc_posting_list(doc_id):
        doc_posting = defaultdict(list)
        intervals = seg_info[doc_id]
        if not intervals:
            return doc_posting
        # The document's tokens do not depend on the interval, so read and
        # stem them once rather than once per segment.
        tokens = token_reader.retrieve(doc_id)
        st_tokens = [stemmer.stem(t) for t in tokens]
        for (loc, loc_ed), (_, _) in intervals:
            ct = Counter(st_tokens[loc:loc_ed])
            for term, cnt in ct.items():
                if term in stopword:
                    continue
                doc_posting[term].append((loc, cnt))
        return doc_posting

    doc_id_list = get_doc_task(task_id)
    ticker = TimeEstimator(len(doc_id_list))
    doc_posting_d = {}
    for doc_id in doc_id_list:
        doc_posting_d[doc_id] = get_doc_posting_list(doc_id)
        ticker.tick()

    save_path = os.path.join(cpath.data_path, "adhoc",
                             "per_doc_posting_{}.pickle".format(task_id))
    # ``with`` guarantees the pickle is flushed and the handle closed.
    with open(save_path, "wb") as f:
        pickle.dump(doc_posting_d, f)
コード例 #6
0
class HintRetriever:
    """Build Robust04 retrieval queries from masked-LM problems.

    The unmasked tokens of a problem are stemmed and stopword-filtered to form
    a query against ``RobustCollection``.
    """

    def __init__(self):
        self.stopword = load_stopwords()
        self.stemmer = CacheStemmer()
        bert_vocab_path = os.path.join(cpath.data_path, "bert_voca.txt")
        self.tokenizer = tokenization.FullTokenizer(
            vocab_file=bert_vocab_path, do_lower_case=True)
        tprint("Loading inv_index for robust")
        self.collection = RobustCollection()
        tprint("Done")
        self.num_candidate = 10

    def _visible_tokens(self, sent_list, target_tokens, mask_indice):
        """Return the word-level tokens left after removing the masked ones."""
        words, word_masks = translate_mask2token_level(
            sent_list, target_tokens, mask_indice, self.tokenizer)
        return remove(words, word_masks)

    def retrieve_hints(self, inst):
        """Extend a 5-field problem tuple with candidate document ids.

        The first retrieved result is skipped. The ids of the next
        ``num_candidate`` results are appended as a sixth field.
        """
        target_tokens, sent_list, prev_tokens, next_tokens, mask_indice = inst
        visible = self._visible_tokens(sent_list, target_tokens, mask_indice)
        query = self.filter_stopword([self.stemmer.stem(t) for t in visible])
        ranked = self.collection.retrieve_docs(query)
        candidates = left(ranked[1:1 + self.num_candidate])
        return (target_tokens, sent_list, prev_tokens, next_tokens,
                mask_indice, candidates)

    def query_gen(self, inst):
        """Return visible surface tokens whose stems are high-idf query terms."""
        target_tokens, sent_list, _prev, _next, mask_indice, _doc_id = inst
        visible = self._visible_tokens(sent_list, target_tokens, mask_indice)
        query = self.filter_stopword([self.stemmer.stem(t) for t in visible])
        keep = self.collection.high_idf_q_terms(Counter(query))
        return [t for t in visible if self.stemmer.stem(t) in keep]

    def filter_stopword(self, l):
        """Return the items of ``l`` that are not stopwords, preserving order."""
        stop = self.stopword
        return [t for t in l if t not in stop]
コード例 #7
0
ファイル: pipeline.py プロジェクト: clover3/Chair
 def __init__(self):
     """Configure the pipeline: tokenizer, working directory, constants, helpers.

     The iteration directory is on /mnt/scratch when that root exists, and
     otherwise falls back to the NFS work directory. The TFRecord maker
     starts as ``None``.
     """
     tprint("Pipeline Init")
     self.stemmer = CacheStemmer()
     bert_vocab_path = os.path.join(cpath.data_path, "bert_voca.txt")
     self.tokenizer = tokenization.FullTokenizer(
         vocab_file=bert_vocab_path, do_lower_case=True)
     scratch_available = os.path.exists("/mnt/scratch/youngwookim/")
     self.iteration_dir = ("/mnt/scratch/youngwookim/data/tlm_iter1"
                           if scratch_available
                           else "/mnt/nfs/work3/youngwookim/data/tlm_iter1")
     # Sequence lengths and masking/sampling parameters.
     self.seg_max_seq = 256
     self.model_max_seq = 512
     self.rng = random.Random(0)
     self.masked_lm_prob = 0.15
     self.short_seq_prob = 0.1
     self.inst_per_job = 1000
     self.stopword = load_stopwords()
     self.pr = FeatureExtractor(self.seg_max_seq - 3)
     self.tf_record_maker = None
     self.code_tick = CodeTiming()
     tprint("Pipeline Init Done")
コード例 #8
0
def build_co_occur_from_pc_feature(data: Dict[str, List[List[str]]]) \
        -> List[Tuple[str, Counter]]:
    """Compute a windowed co-occurrence Counter for every claim.

    Args:
        data: claim id -> list of token lists.

    Returns:
        ``[(claim_id, counter), ...]`` in ``data`` iteration order. Each
        counter comes from ``build_co_occurrence`` with a window size of 10.
    """
    window = 10
    stemmer = CacheStemmer()
    pairs = []
    ticker = TimeEstimator(len(data))
    for claim_id, token_lists in data.items():
        ticker.tick()
        pairs.append(
            (claim_id, build_co_occurrence(token_lists, window, stemmer)))
    return pairs
コード例 #9
0
def build_co_occur_from_pc_feature(
        data: Dict[str, List[ScoreParagraph]]) -> List[Tuple[str, Counter]]:
    """Compute a windowed co-occurrence Counter per claim from paragraphs.

    Args:
        data: claim id -> list of scored paragraphs. Each item's
            ``paragraph.tokens`` is used as one token list.

    Returns:
        ``[(claim_id, counter), ...]`` in ``data`` iteration order. Each
        counter comes from ``build_co_occurrence`` with a window size of 10.
    """
    window = 10
    stemmer = CacheStemmer()
    pairs = []

    ticker = TimeEstimator(len(data))
    for claim_id, paragraphs in data.items():
        ticker.tick()
        token_lists: List[List[str]] = [sp.paragraph.tokens for sp in paragraphs]
        pairs.append(
            (claim_id, build_co_occurrence(token_lists, window, stemmer)))
    return pairs
コード例 #10
0
def save_qdf_ex():
    """Compute and save "extended" document frequencies for the Robust index.

    For each term of ``robust_inv_index.pickle``, the result counts the
    distinct documents that satisfy either condition:

    * the document contains the term according to the inverted index, or
    * the document's headline contains the term after tokenization and
      stemming. Headline stopwords are checked on the raw token.

    Headlines come from ``robust_meta.pickle``. The resulting ``Counter``
    (term -> count) is pickled to ``adhoc/robust_qdf_ex.pickle``.
    """
    adhoc_dir = os.path.join(cpath.data_path, "adhoc")
    with open(os.path.join(adhoc_dir, "robust_inv_index.pickle"), "rb") as f:
        inv_index = pickle.load(f)
    with open(os.path.join(adhoc_dir, "robust_meta.pickle"), "rb") as f:
        meta = pickle.load(f)
    stopwords = load_stopwords()
    stemmer = CacheStemmer()

    # term -> set of doc ids containing the term in the indexed text.
    simple_posting = {term: {doc_id for doc_id, _ in inv_index[term]}
                      for term in inv_index}

    for doc in meta:
        _, headline = meta[doc]
        tokens = nltk.tokenize.wordpunct_tokenize(headline)
        terms = {stemmer.stem(t) for t in tokens if t not in stopwords}
        for t in terms:
            # Headline terms missing from the inverted index used to raise
            # KeyError. They never appear in the output, so skip them.
            postings = simple_posting.get(t)
            if postings is not None:
                postings.add(doc)

    qdf_d = Counter()
    for term in inv_index:
        qdf_d[term] = len(simple_posting[term])

    with open(os.path.join(adhoc_dir, "robust_qdf_ex.pickle"), "wb") as f:
        pickle.dump(qdf_d, f)
コード例 #11
0
ファイル: idf_lime.py プロジェクト: clover3/Chair
def explain_by_lime_idf(data: List[str],
                        get_idf) -> List[List[Tuple[str, float]]]:
    """Explain an idf-weighted query/document match score with LIME.

    Each entry is a whitespace-tokenized problem string. When it contains a
    "[SEP]" token, ``parse_problem`` splits it into query and document. The
    explained score is ``sum(log(1 + tf) * get_idf(term))`` over the stemmed
    document terms that also occur in the stemmed query; it is 0 when
    "[SEP]" is absent.

    Args:
        data: problem strings to explain.
        get_idf: callable mapping a stemmed term to its idf weight.

    Returns:
        One list per input entry of ``(feature, weight)`` pairs from
        ``exp.as_list()``, sorted by weight in descending order.
    """
    stemmer = CacheStemmer()

    def split(t):
        return t.split()

    explainer = lime_text.LimeTextExplainer(split_expression=split, bow=True)

    def solve(problem: str):
        # Score of one (possibly LIME-perturbed) problem string.
        tokens = split(problem)
        if "[SEP]" not in tokens:
            return 0
        e: QueryDoc = parse_problem(tokens)
        q_terms = lmap(stemmer.stem, e.query)
        doc_terms = lmap(stemmer.stem, e.doc)
        tf = Counter(doc_terms)
        q_terms_set = set(q_terms)
        score = 0
        for term, cnt in tf.items():
            if term in q_terms_set:
                score += log(1 + cnt) * get_idf(term)
        return score

    def evaluate_score(problems: List[str]):
        # LIME classifier_fn: one row per sample. Column 1 holds the score
        # being explained; column 0 is a constant 0.
        return np.array([[0, solve(problem)] for problem in problems])

    explains = []
    tick = TimeEstimator(len(data))
    for entry in data:
        assert isinstance(entry, str)
        exp = explainer.explain_instance(entry,
                                         evaluate_score,
                                         num_features=512)
        l2 = exp.as_list()
        l2.sort(key=get_second, reverse=True)
        explains.append(l2)
        tick.tick()
    return explains
コード例 #12
0
def work(q_res_path, save_name):
    """Build co-occurrence counters from each claim's ranked documents and save them.

    Args:
        q_res_path: path of a Galago ranked-list file, loaded with
            ``load_galago_ranked_list``.
        save_name: name passed to ``save_to_pickle``.

    The saved object is ``[(claim_id, counter), ...]``. Each counter comes
    from ``build_co_occurrence`` over the tokens of the claim's ranked
    documents, with a window size of 10.
    """
    ranked_list_d = load_galago_ranked_list(q_res_path)
    window_size = 10
    stemmer = CacheStemmer()
    print(q_res_path)

    ticker = TimeEstimator(len(ranked_list_d))
    r = []
    for claim_id, ranked_list in ranked_list_d.items():
        ticker.tick()
        doc_ids = [e.doc_id for e in ranked_list]
        counter = build_co_occurrence(get_tokens_form_doc_ids(doc_ids),
                                      window_size, stemmer)
        r.append((claim_id, counter))

    save_to_pickle(r, save_name)
コード例 #13
0
def count_it(
        data: Dict[str, List[ScoreParagraph]]) -> List[Tuple[str, Counter]]:
    """Count stemmed, non-stopword terms over all paragraphs of each claim.

    Tokens are stemmed with ``stem_list`` before stopword filtering, so the
    filtering is applied to the stemmed forms.

    Returns:
        ``[(claim_id, Counter), ...]`` in ``data`` iteration order.
    """
    stemmer = CacheStemmer()
    results = []
    stopword = load_stopwords()

    ticker = TimeEstimator(len(data))
    for claim_id, paragraphs in data.items():
        ticker.tick()
        term_counts = Counter()
        for para in paragraphs:
            stemmed = stemmer.stem_list(para.paragraph.tokens)
            term_counts.update(t for t in stemmed if t not in stopword)
        results.append((claim_id, term_counts))
    return results
コード例 #14
0
ファイル: segment2problem.py プロジェクト: clover3/Chair
class ProblemMaker:
    """Create masked-LM problems from text segments and derive keyword queries."""

    def __init__(self):
        bert_vocab_path = os.path.join(data_path, "bert_voca.txt")
        self.tokenizer = tokenization.FullTokenizer(
            vocab_file=bert_vocab_path, do_lower_case=True)
        self.stemmer = CacheStemmer()
        self.stopword = load_stopwords()
        self.df = self.load_galgo_df_stat()

    def load_galgo_df_stat(self):
        """Load document-frequency statistics from ``data_path/enwiki/tf_stat``."""
        stat_path = os.path.join(data_path, "enwiki", "tf_stat")
        return load_df(stat_path)

    def generate_mask(self, inst, max_num_tokens, masked_lm_prob,
                      short_seq_prob, rng):
        """Choose the positions to mask in the current segment of ``inst``.

        Args:
            inst: ``(cur_seg, prev_seg, next_seg)``. Each segment is a
                ``(title, content, st, ed)`` tuple; prev/next may be ``None``.
            max_num_tokens: upper bound used when shortening the target.
            masked_lm_prob: fraction of target tokens to mask, clamped to at
                least 1 and at most 20 positions.
            short_seq_prob: probability of truncating the target and moving
                the remainder in front of the next segment's tokens.
            rng: random source. ``random()`` is always drawn first, then
                ``randint()`` only when truncating, then ``shuffle()``.

        Returns:
            ``(target_tokens, prev_seg, cur_seg, next_seg, prev_tokens,
            next_tokens, mask_indice, doc_id)``.
        """
        max_predictions_per_seq = 20
        cur_seg, prev_seg, next_seg = inst

        def tokenize_segment(segment):
            # Word-piece tokenize each content piece and concatenate them.
            if segment is None:
                return None
            _title, seg_content, _st, _ed = segment
            return flatten(
                [self.tokenizer.tokenize(piece) for piece in seg_content])

        title, _content, st, ed = cur_seg
        prev_tokens = tokenize_segment(prev_seg)
        target_tokens = tokenize_segment(cur_seg)
        next_tokens = tokenize_segment(next_seg)

        # Draw first, unconditionally, so the rng stream is consumed the same
        # way regardless of next_tokens.
        draw = rng.random()
        if draw < short_seq_prob and next_tokens is not None:
            cut = rng.randint(2, max_num_tokens)
            remainder = target_tokens[cut:]
            target_tokens = target_tokens[:cut]
            next_tokens = (remainder + next_tokens)[:max_num_tokens]

        wanted = max(1, int(round(len(target_tokens) * masked_lm_prob)))
        num_to_predict = min(max_predictions_per_seq, wanted)

        positions = list(range(len(target_tokens)))
        rng.shuffle(positions)
        mask_indice = positions[:num_to_predict]
        doc_id = "{}-{}-{}".format(title, st, ed)
        return (target_tokens, prev_seg, cur_seg, next_seg, prev_tokens,
                next_tokens, mask_indice, doc_id)

    def filter_stopword(self, l):
        """Return the items of ``l`` that are not stopwords, preserving order."""
        stop = self.stopword
        return [t for t in l if t not in stop]

    def generate_query(self, mask_inst):
        """Build a keyword query from the unmasked tokens of ``mask_inst``.

        Visible tokens are stemmed, stopword-filtered and passed through
        ``clean_query``. The stems with the highest BM25 query weight (up to
        10) are kept, and the visible surface tokens whose stem is among them
        are returned in order. Diagnostics are printed when the query comes
        out empty.
        """
        (target_tokens, prev_seg, cur_seg, next_seg, prev_tokens,
         next_tokens, mask_indice, doc_id) = mask_inst
        title, content, st, ed = cur_seg

        visible_tokens = get_visible(content, target_tokens, mask_indice,
                                     self.tokenizer)
        stems = [self.stemmer.stem(t) for t in visible_tokens]
        cleaned = clean_query(self.filter_stopword(stems))
        top_stems = self.high_idf_q_terms(Counter(cleaned))

        query = [t for t in visible_tokens
                 if self.stemmer.stem(t) in top_stems]
        if not query:
            print(title)
            print(content)
            print(visible_tokens)
            print(top_stems)
        return query

    def high_idf_q_terms(self, q_tf, n_limit=10):
        """Return the ``n_limit`` terms of ``q_tf`` with the largest BM25 query weight."""
        # Hard-coded document count passed to the weighting function.
        total_doc = 11503029 + 100
        weights = Counter(
            {term: BM25_3_q_weight(qf, self.df[term], total_doc)
             for term, qf in q_tf.items()})
        return set(left(weights.most_common(n_limit)))
コード例 #15
0
ファイル: pipeline.py プロジェクト: clover3/Chair
class Pipeline:
    """Three-stage job pipeline (A, B, C) keyed by integer job ids.

    * Stage A (``run_A``) samples segments, masks them, and saves the problems
      together with Galago queries.
    * Stage B (``run_B``) ranks candidate segments for each problem and saves
      the candidates plus LTR feature strings.
    * Stage C (``run_C``) builds TFRecord instances from the ranked
      candidates.

    All intermediate files live under ``self.iteration_dir``.
    """

    def __init__(self):
        tprint("Pipeline Init")
        self.stemmer = CacheStemmer()
        vocab_file = os.path.join(cpath.data_path, "bert_voca.txt")
        self.tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file,
                                                    do_lower_case=True)
        # Prefer local scratch space; fall back to NFS when it is absent.
        self.iteration_dir = "/mnt/scratch/youngwookim/data/tlm_iter1"
        if not os.path.exists("/mnt/scratch/youngwookim/"):
            self.iteration_dir = "/mnt/nfs/work3/youngwookim/data/tlm_iter1"
        self.seg_max_seq = 256
        self.model_max_seq = 512
        self.rng = random.Random(0)  # fixed seed for reproducible sampling
        self.masked_lm_prob = 0.15
        self.short_seq_prob = 0.1
        self.inst_per_job = 1000
        self.stopword = load_stopwords()
        self.pr = FeatureExtractor(self.seg_max_seq - 3)
        self.tf_record_maker = None  # created lazily in run_C
        self.code_tick = CodeTiming()
        tprint("Pipeline Init Done")

    def run_A(self, job_id):
        """Stage A: create ``inst_per_job`` problems and their queries.

        Query ids are ``job_id * inst_per_job + i``. Saves the queries with
        ``save_query`` and the ``(problem, qid)`` pairs with ``save_output_A``.
        """
        output = []
        queries = []
        ticker = TimeEstimator(self.inst_per_job)
        for i in range(self.inst_per_job):
            p, q = self.process_A()
            qid = job_id * self.inst_per_job + i

            output.append((p, qid))
            queries.append((qid, q))
            ticker.tick()

        self.save_query(job_id, queries)
        self.save_output_A(job_id, output)

    def run_B(self, job_id):
        """Stage B: rank candidate segments and build LTR features per problem.

        A problem whose qid has no retrieved documents gets an empty candidate
        list and an empty feature string.
        """
        if self.pr.doc_posting is None:
            self.pr.doc_posting = per_doc_posting_server.load_dict()

        output_A = self.load_output_A(job_id)
        candi_docs = self.load_candidate_docs(job_id)
        feature_str_list = []
        seg_candi_list = []
        ticker = TimeEstimator(self.inst_per_job)
        for i in range(self.inst_per_job):
            problem, qid = output_A[i]
            qid_str = str(qid)
            if qid_str in candi_docs:
                doc_candi = candi_docs[qid_str]
                seg_candi, features = self.process_B(problem, doc_candi)
                fstr = "\n".join([libsvm_str(qid, 0, f) for f in features])
                feature_str_list.append(fstr)
                seg_candi_list.append(seg_candi)
            else:
                # Must be a string: save_ltr joins these entries with "\n",
                # which raises TypeError on a list.
                feature_str_list.append("")
                seg_candi_list.append([])
            ticker.tick()

            if i % 100 == 3:
                self.code_tick.print()

        self.save("seg_candi_list", job_id, seg_candi_list)
        self.save_ltr(job_id, feature_str_list)

    def run_C(self, job_id):
        """Stage C: turn stage-B candidates and LTR scores into TFRecords.

        ``load_ltr`` is not implemented yet, so this stage currently raises
        ``NotImplementedError``.
        """
        if self.tf_record_maker is None:
            self.tf_record_maker = TFRecordMaker(self.model_max_seq)

        output_A = self.load_output_A(job_id)
        seg_candi_list = self.load("seg_candi_list", job_id)
        ltr_result = self.load_ltr(job_id)

        uid_list = []
        tf_record_list = []
        for i in range(self.inst_per_job):
            problem, qid = output_A[i]
            r = self.process_C(problem, seg_candi_list[i], ltr_result[i])
            for idx, tf_record in r:
                # NOTE(review): uids of different problems collide if idx >= 10.
                uid = job_id * 1000 * 1000 + i * 10 + idx
                uid_list.append(uid)
                tf_record_list.append(tf_record)

        data = zip(tf_record_list, uid_list)
        self.write_tf_record(job_id, data)

    def inspect_seg(self, job_id):
        """Print, per problem, each query stem's tf in its top-30 candidate segments."""
        output_A = self.load_output_A(job_id)
        seg_candi_list = self.load("seg_candi_list", job_id)
        q_path = self.get_path("query", "g_query_{}.json".format(job_id))
        with open(q_path, "r") as f:
            queries = json.load(f)["queries"]

        for i in range(self.inst_per_job):
            _, qid = output_A[i]
            print(qid)
            scl = seg_candi_list[i]
            # Strip the "#combine(" ... ")" wrapper written by save_query.
            query = queries[i]["text"][len("#combine("):-1]
            print(query)
            q_terms = query.split()
            q_tf = stemmed_counter(q_terms, self.stemmer)
            print(list(q_tf.keys()))
            for doc_id, loc in scl[:30]:
                doc_tokens = self.pr.token_dump.get(doc_id)
                l = self.pr.get_seg_len_dict(doc_id)[loc]
                passage = doc_tokens[loc:loc + l]
                d_tf = stemmed_counter(passage, self.stemmer)
                print([d_tf[qt] for qt in q_tf])

    def process_A(self):
        """Sample a segment, mask it, and derive a query. Returns ``(problem, query)``."""
        segment = self.sample_segment()
        problem = self.segment2problem(segment)
        query = self.problem2query(problem)
        return problem, query

    def process_B(self, problem, doc_candi):
        """Rank candidate segments for ``problem`` and compute their features."""
        self.code_tick.tick_begin("get_seg_candidate")
        seg_candi = self.get_seg_candidate(doc_candi, problem)
        self.code_tick.tick_end("get_seg_candidate")
        target_tokens, sent_list, prev_tokens, next_tokens, mask_indice, doc_id = problem
        self.code_tick.tick_begin("get_feature_list")
        feature = self.pr.get_feature_list(doc_id, sent_list, target_tokens,
                                           mask_indice, seg_candi)
        self.code_tick.tick_end("get_feature_list")
        return seg_candi, feature

    @staticmethod
    def _select_top(items, scores, n_top):
        """Return the ``n_top`` items with the highest scores, best first."""
        # sorted() returns the list; the old ``list(...).sort(...)[:n]``
        # sliced None and raised TypeError.
        ranked = sorted(zip(items, scores), key=lambda x: x[1], reverse=True)
        return [item for item, _ in ranked[:n_top]]

    def process_C(self, problem, seg_candi, scores):
        """Build ``(idx, tf_record)`` pairs for ``problem`` using its top-3 passages."""
        target_tokens, sent_list, prev_tokens, next_tokens, mask_indice, doc_id = problem

        e = {
            "target_tokens": target_tokens,
            "sent_list": sent_list,
            "prev_tokens": prev_tokens,
            "next_tokens": next_tokens,
            "mask_indice": mask_indice,
            "doc_id": doc_id,
            "passages": self._select_top(seg_candi, scores, 3)
        }
        # NOTE(review): ``self`` is passed explicitly; confirm this matches
        # TFRecordMaker.generate_ir_tfrecord's signature.
        insts = self.tf_record_maker.generate_ir_tfrecord(self, e)

        r = []
        for idx, p in enumerate(insts):
            # NOTE(review): ``p`` is unused and every record is built from
            # ``problem``; make_instance(p) may have been intended.
            tf_record = self.tf_record_maker.make_instance(problem)
            r.append((idx, tf_record))
        return r

    def write_tf_record(self, job_id, data):
        """Filter ``(instance, uid)`` pairs and write them as prediction TFRecords."""
        inst_list, uid_list = filter_instances(data)
        max_pred = 20
        data = zip(inst_list, uid_list)
        output_path = self.get_path("tf_record_pred", "{}".format(job_id))
        write_predict_instance(data, self.tokenizer, self.model_max_seq,
                               max_pred, [output_path])

    # Use robust_idf_mini
    def sample_segment(self):
        """Sample a random sentence and extend it into a segment around it."""
        r = get_random_sent()
        s_id, doc_id, loc, g_id, sent = r
        doc_rows = get_doc_sent(doc_id)
        max_seq = self.seg_max_seq
        target_tokens, sent_list, prev_tokens, next_tokens = extend(
            doc_rows, sent, loc, self.tokenizer, max_seq)
        inst = target_tokens, sent_list, prev_tokens, next_tokens, doc_id
        return inst

    def segment2problem(self, segment):
        """Turn a sampled segment into a masked problem via ``generate_mask``."""
        mask_inst = generate_mask(segment, self.seg_max_seq,
                                  self.masked_lm_prob, self.short_seq_prob,
                                  self.rng)
        return mask_inst

    def problem2query(self, problem):
        """Return the visible surface tokens whose stems are high-idf query terms."""
        target_tokens, sent_list, prev_tokens, next_tokens, mask_indice, doc_id = problem
        basic_tokens, word_mask_indice = translate_mask2token_level(
            sent_list, target_tokens, mask_indice, self.tokenizer)

        visible_tokens = remove(basic_tokens, word_mask_indice)
        stemmed = [self.stemmer.stem(t) for t in visible_tokens]
        query = self.filter_stopword(stemmed)
        high_qt_stmed = self.pr.high_idf_q_terms(Counter(query))

        final_query = []
        for t in visible_tokens:
            if self.stemmer.stem(t) in high_qt_stmed:
                final_query.append(t)
        return final_query

    def filter_res(self, query_res, doc_id):
        """Return ``query_res`` without the entry for the source document ``doc_id``.

        ``query_res`` is a ranked list of ``(doc_id, rank, score)`` triples
        whose first entry must have rank 1. Only the first entry matching
        ``doc_id`` is removed. If no entry matches, ``query_res`` is returned
        unchanged.
        """
        if not query_res:
            return query_res
        _, rank, _ = query_res[0]
        assert rank == 1
        # Compare against the *source* doc id. The old code compared against
        # the top-ranked id, so it always dropped rank 1.
        for i, (e_doc_id, _, _) in enumerate(query_res):
            if e_doc_id == doc_id:
                return query_res[:i] + query_res[i + 1:]
        return query_res

    def get_seg_candidate(self, doc_candi, problem):
        """Rank up to 100 segments from ``doc_candi``, excluding the source document."""
        target_tokens, sent_list, prev_tokens, next_tokens, mask_indice, doc_id = problem
        valid_res = self.filter_res(doc_candi, doc_id)
        return self.pr.rank(doc_id,
                            valid_res,
                            target_tokens,
                            sent_list,
                            mask_indice,
                            top_k=100)

    def filter_stopword(self, l):
        """Return the items of ``l`` that are not stopwords, preserving order."""
        return [t for t in l if t not in self.stopword]

    def get_path(self, sub_dir_name, file_name):
        """Return ``iteration_dir/sub_dir_name/file_name``, creating the sub-directory."""
        dir_path = os.path.join(self.iteration_dir, sub_dir_name)
        # makedirs(exist_ok=True) also creates a missing iteration_dir and
        # tolerates concurrent jobs creating the same directory.
        os.makedirs(dir_path, exist_ok=True)
        return os.path.join(dir_path, file_name)

    def save_query(self, job_idx, queries):
        """Write ``(qid, tokens)`` pairs as a Galago ``#combine`` query JSON file."""
        j_queries = []
        for qid, query in queries:
            query = clean_query(query)
            j_queries.append({
                "number": str(qid),
                "text": "#combine({})".format(" ".join(query))
            })

        data = {"queries": j_queries}
        out_path = self.get_path("query", "g_query_{}.json".format(job_idx))
        with open(out_path, "w") as fout:
            fout.write(json.dumps(data, indent=True))

    def save_pickle_at(self, obj, sub_dir_name, file_name):
        """Pickle ``obj`` to ``iteration_dir/sub_dir_name/file_name``."""
        out_path = self.get_path(sub_dir_name, file_name)
        # ``with`` guarantees the pickle is flushed and the handle closed.
        with open(out_path, "wb") as fout:
            pickle.dump(obj, fout)

    def load_pickle_at(self, sub_dir_name, file_name):
        """Unpickle ``iteration_dir/sub_dir_name/file_name`` (trusted local data)."""
        out_path = self.get_path(sub_dir_name, file_name)
        with open(out_path, "rb") as fin:
            return pickle.load(fin)

    def save_output_A(self, job_idx, output):
        self.save_pickle_at(output, "output_A", "{}.pickle".format(job_idx))

    def load_output_A(self, job_idx):
        return self.load_pickle_at("output_A", "{}.pickle".format(job_idx))

    def save(self, dir_name, job_idx, output):
        self.save_pickle_at(output, dir_name, "{}.pickle".format(job_idx))

    def load(self, dir_name, job_idx):
        return self.load_pickle_at(dir_name, "{}.pickle".format(job_idx))

    def load_candidate_docs(self, job_id):
        """Load the Galago ranked list stored at ``q_res/{job_id}.txt``."""
        out_path = self.get_path("q_res", "{}.txt".format(job_id))
        return load_galago_ranked_list(out_path)

    def save_ltr(self, job_idx, data):
        """Write the LTR feature strings of a job, one entry per line group."""
        out_path = self.get_path("ltr", str(job_idx))
        s = "\n".join(data)
        with open(out_path, "w") as fout:
            fout.write(s)

    def load_ltr(self, job_idx):
        # ``NotImplemented`` is meant for binary-operator dispatch; raise instead.
        raise NotImplementedError("load_ltr is not implemented yet")
コード例 #16
0
ファイル: segment_ranker_1.py プロジェクト: clover3/Chair
class PassageRanker:
    """Rank document segments for a masked problem with a BM25-style score.

    Collection statistics are loaded from pickles under
    ``cpath.data_path/adhoc``. Per-document postings are loaded lazily on the
    first call to ``rank``.
    """

    def __init__(self, window_size):
        self.stemmer = CacheStemmer()
        self.window_size = window_size
        self.doc_posting = None
        self.stopword = load_stopwords()

        vocab_file = os.path.join(cpath.data_path, "bert_voca.txt")
        self.tokenizer = tokenization.FullTokenizer(
            vocab_file=vocab_file, do_lower_case=True)

        def load_pickle(name):
            p = os.path.join(cpath.data_path, "adhoc", name + ".pickle")
            # Close the file deterministically instead of leaking the handle.
            with open(p, "rb") as f:
                return pickle.load(f)

        self.doc_len_dict = load_pickle("doc_len")
        self.qdf = load_pickle("robust_qdf_ex")
        self.meta = load_pickle("robust_meta")
        self.head_tokens = load_pickle("robust_title_tokens")
        self.seg_info = load_pickle("robust_seg_info")
        self.not_found = set()

        self.total_doc_n = len(self.doc_len_dict)
        self.avdl = sum(self.doc_len_dict.values()) / len(self.doc_len_dict)
        tprint("Init PassageRanker")

    def get_seg_len_dict(self, doc_id):
        """Map each segment start offset of ``doc_id`` to that segment's length."""
        return {loc: loc_ed - loc
                for (loc, loc_ed), (_, _) in self.seg_info[doc_id]}

    def filter_stopword(self, l):
        """Return the items of ``l`` that are not stopwords, preserving order."""
        return [t for t in l if t not in self.stopword]

    def BM25(self, q_tf, tokens):
        """Score ``tokens`` against query term frequencies ``q_tf`` with ``BM25_3``."""
        dl = len(tokens)
        tf_d = stemmed_counter(tokens, self.stemmer)
        score = 0
        for term, qf in q_tf.items():
            tf = tf_d[term]
            qdf = self.qdf[term]
            total_doc = self.total_doc_n
            score += BM25_3(tf, qf, qdf, total_doc, dl, self.avdl)
        return score

    def BM25_seg(self, term, q_tf, tf_d, dl):
        """Return ``BM25_3`` for one term.

        Despite their names, ``q_tf`` is the query frequency and ``tf_d`` the
        term frequency of this single term, as passed by ``rank``.
        """
        qdf = self.qdf[term]
        total_doc = self.total_doc_n
        return BM25_3(tf_d, q_tf, qdf, total_doc, dl, self.avdl)

    def high_idf_q_terms(self, q_tf, n_limit=10):
        """Return the ``n_limit`` query terms with the largest BM25 query weight."""
        total_doc = self.total_doc_n

        high_qt = Counter()
        for term, qf in q_tf.items():
            qdf = self.qdf[term]
            high_qt[term] = BM25_3_q_weight(qf, qdf, total_doc)

        return set(left(high_qt.most_common(n_limit)))

    def reform_query(self, tokens):
        """Stem and stopword-filter ``tokens``, keeping only the top-weight terms.

        Returns:
            Counter of stemmed term -> frequency, limited to those terms.
        """
        stemmed = [self.stemmer.stem(t) for t in tokens]
        query = self.filter_stopword(stemmed)
        q_tf = Counter(query)
        high_qt = self.high_idf_q_terms(q_tf)
        new_qf = Counter()
        for t in high_qt:
            new_qf[t] = q_tf[t]
        return new_qf

    def get_title_tokens(self, doc_id):
        return self.head_tokens[doc_id]

    def weight_doc(self, title, content):
        """Concatenate ``title`` (each token repeated 3x) with ``content`` (1x)."""
        new_doc = []
        title_repeat = 3
        content_repeat = 1
        for t in title:
            new_doc += [t] * title_repeat
        for t in content:
            new_doc += [t] * content_repeat
        return new_doc

    def pick_diverse(self, candidate, top_k):
        """Keep the first candidate per document, up to ``top_k`` candidates.

        Each candidate is ``(score, window_subtokens, window_tokens, doc_id, loc)``.
        """
        doc_set = set()
        r = []
        for entry in candidate:
            score, window_subtokens, window_tokens, doc_id, loc = entry
            if doc_id not in doc_set:
                r.append(entry)
                doc_set.add(doc_id)
                if len(r) == top_k:
                    break
        return r

    def rank(self, doc_id, query_res, target_tokens, sent_list, mask_indice, top_k):
        """Score segments of the retrieved documents against the source problem.

        The query is built from the source document's title (each token
        weighted x3) plus the problem's visible tokens, reduced to the top
        idf-weighted stems. For each of the first 100 retrieved documents,
        each segment containing a query term gains ``BM25_seg`` for that
        term. The tf used is the segment tf plus the candidate's title tf;
        the length used is the segment length plus the title length.

        Returns:
            Up to ``top_k`` ``(doc_id, segment_start)`` keys, highest score first.
        """
        if self.doc_posting is None:
            self.doc_posting = per_doc_posting_server.load_dict()
        title = self.get_title_tokens(doc_id)
        visible_tokens = get_visible(sent_list, target_tokens, mask_indice, self.tokenizer)

        source_doc_rep = self.weight_doc(title, visible_tokens)
        q_tf = self.reform_query(source_doc_rep)

        stop_cut = 100
        candidates = query_res[:stop_cut]
        segments = Counter()
        n_not_found = 0

        # A distinct loop name, so the source ``doc_id`` parameter is not shadowed.
        for cand_doc_id, _rank, _doc_score in candidates:
            per_doc_index = (self.doc_posting[cand_doc_id]
                             if cand_doc_id in self.doc_posting else None)
            if per_doc_index is None:
                self.not_found.add(cand_doc_id)
                n_not_found += 1
                continue

            assert cand_doc_id not in self.not_found

            # Only look these up for documents that are actually scored.
            doc_title = self.get_title_tokens(cand_doc_id)
            title_tf = stemmed_counter(doc_title, self.stemmer)
            seg_len_d = self.get_seg_len_dict(cand_doc_id)

            for term, qf in q_tf.items():
                if term not in per_doc_index:
                    continue
                for seg_loc, tf in per_doc_index[term]:
                    rep_tf = tf + title_tf[term]
                    dl = seg_len_d[seg_loc] + len(doc_title)
                    segments[(cand_doc_id, seg_loc)] += self.BM25_seg(term, qf, rep_tf, dl)

        r = []
        for key, score in segments.most_common(top_k):
            if score == 0:
                # Report the document this segment belongs to.
                print("Doc {} has 0 score".format(key[0]))
            r.append(key)
        # Compare against the documents actually examined (at most stop_cut).
        if n_not_found > 0.9 * len(candidates):
            print("WARNING : {} of {} not found".format(n_not_found, len(candidates)))
        return r