Example 1
import os
import pickle

import torch
from allennlp.commands.elmo import ElmoEmbedder
from allennlp.data.tokenizers import WordTokenizer

# PRETRAINED_ELMO_OPTION_URL, PRETRAINED_ELMO_WEIGHT_URL, DEFAULT_FILE_PREFIX and
# embed_elements_path are assumed to be defined elsewhere in the module.


class ElmoEncoder:
    default_max_layer = 2

    def __init__(self):
        self.model = ElmoEmbedder(PRETRAINED_ELMO_OPTION_URL,
                                  PRETRAINED_ELMO_WEIGHT_URL)
        self.pretrained_name = 'elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights'
        self.tokenizer = WordTokenizer()

    def encode(self, data, file_prefix=DEFAULT_FILE_PREFIX):
        emb_pkl_path = embed_elements_path(file_prefix, self.pretrained_name)

        if os.path.isfile(emb_pkl_path):
            with open(emb_pkl_path, 'rb') as f:
                entities, relations = pickle.load(f)
                return entities, relations

        ent_embs = []
        rel_embs = []

        for triple in data.triples:
            sent_tokens = self.tokenizer.batch_tokenize(triple)
            sbj_p, rel_p, obj_p = [len(tokens) for tokens in sent_tokens]
            tokens = [token.text for token in sum(sent_tokens, [])]
            hid_states = self.model.embed_sentence(tokens)
            hid_states = torch.from_numpy(hid_states)

            # Mean-pool the layer-wise hidden states over each span:
            # subject and object go to ent_embs, the relation to rel_embs.
            sbj_states = hid_states[:, :sbj_p, :]
            ent_embs.append(torch.mean(sbj_states, dim=1))

            rel_end = sbj_p + rel_p
            rel_states = hid_states[:, sbj_p:rel_end, :]
            rel_embs.append(torch.mean(rel_states, dim=1))

            obj_end = rel_end + obj_p
            obj_states = hid_states[:, rel_end:obj_end, :]
            ent_embs.append(torch.mean(obj_states, dim=1))

        ent_embs = torch.stack(ent_embs, dim=0)
        rel_embs = torch.stack(rel_embs, dim=0)

        with open(emb_pkl_path, 'wb') as f:
            pickle.dump((ent_embs, rel_embs), f,
                        protocol=pickle.HIGHEST_PROTOCOL)

        return ent_embs, rel_embs
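
Usage sketch for the encoder above (hypothetical: it assumes the module-level
constants and paths are configured and that kb_data.triples yields
(subject, relation, object) string triples, as the loop in encode implies):

encoder = ElmoEncoder()
ent_embs, rel_embs = encoder.encode(kb_data)  # kb_data: placeholder dataset object
# One mean-pooled embedding per subject/object goes into ent_embs,
# one per relation into rel_embs, stacked along dim 0.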
Example 2
import logging
from typing import List, Tuple

from allennlp.data.tokenizers import WordTokenizer

# The injected sentence_splitter is assumed to provide batch_split_sentences,
# e.g. AllenNLP's SpacySentenceSplitter.


class ProcessStory:
    def __init__(self, sentence_splitter):
        self._sentence_splitter = sentence_splitter
        self._word_tokenizer = WordTokenizer()

    def __call__(self, lines: List[str], story_ids: List[int]) -> Tuple[List[int], List[dict], List[dict]]:

        story_metrics = []
        story_sentences_to_save = []

        story_sentences_split = self._sentence_splitter.batch_split_sentences(lines)

        for sentences, story_id in zip(story_sentences_split, story_ids):
            try:

                tokenized_sentences = self._word_tokenizer.batch_tokenize(sentences)

                total_story_tokens = 0

                for i, sentence in enumerate(tokenized_sentences):
                    start_span = total_story_tokens
                    sentence_len = len(sentence)
                    total_story_tokens += sentence_len
                    end_span = total_story_tokens

                    text = " ".join([s.text for s in sentence])

                    story_sentences_to_save.append(
                        dict(sentence_num=i, story_id=story_id, text=text,
                             sentence_len=sentence_len, start_span=start_span, end_span=end_span))

                story_metrics.append(dict(sentence_num=len(sentences), tokens_num=total_story_tokens, id=story_id))

            except Exception as e:
                logging.error(e)

        return story_ids, story_sentences_to_save, story_metrics
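
Usage sketch (hypothetical; it assumes AllenNLP's SpacySentenceSplitter as the
injected splitter and one raw text string per story):

from allennlp.data.tokenizers.sentence_splitter import SpacySentenceSplitter

process_story = ProcessStory(SpacySentenceSplitter())
story_ids, sentences, metrics = process_story(
    ["It was raining. The dog barked.", "A short story."], story_ids=[1, 2])
# sentences: flat list of per-sentence dicts with token spans;
# metrics: one summary dict per story.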
Example 3
import asyncio
import copy
import itertools
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional, Union

import dataset
import more_itertools
from allennlp.data.tokenizers import WordTokenizer
from allennlp.models import Model

# engine_kwargs, logger and the CoreferenceProcessor callable are assumed to be
# defined elsewhere in the module.


async def save_coreferences(coreference_model: Model, dataset_db: str,
                            cuda_device: Optional[Union[List[int], int]] = None,
                            save_batch_size: int = 25, sentence_chunks: int = 200):
    with dataset.connect(dataset_db, engine_kwargs=engine_kwargs) as db:

        coref_table = db.create_table('coreference')
        coref_table.create_column('story_id', db.types.bigint)
        coref_table.create_column('coref_id', db.types.integer)
        coref_table.create_column('start_span', db.types.integer)
        coref_table.create_column('end_span', db.types.integer)
        coref_table.create_column('mention_text', db.types.string)
        coref_table.create_column('context_text', db.types.string)
        coref_table.create_index(['story_id'])
        coref_table.create_index(['start_span'])
        coref_table.create_index(['end_span'])

        gpu_max_workers = 1

        if isinstance(cuda_device, (list, tuple)):
            gpu_max_workers = len(cuda_device)
            gpus = cuda_device
        else:
            gpus = [cuda_device]

        word_tokenizer = WordTokenizer()

        loop = asyncio.get_event_loop()

        with ThreadPoolExecutor(max_workers=gpu_max_workers) as executor:

            processors = []
            for gpu in gpus:
                processors.append(CoreferenceProcessor(coreference_model, dataset_db, cuda_device=gpu))
            processors_cycle = itertools.cycle(processors)

            tasks = []
            # Order by shortest to longest so possible failures are at the end.
            for story in db['story'].find(order_by=['sentence_num', 'id']):

                sentence_list = [s["text"] for s in db["sentence"].find(story_id=story["id"], order_by='id')]
                sentence_tokens = word_tokenizer.batch_tokenize(sentence_list)

                for sentence_chunk in more_itertools.chunked(sentence_tokens, n=sentence_chunks):
                    sentence_chunk_flat = list(more_itertools.flatten(sentence_chunk))

                    if len(sentence_chunk_flat) < 10:
                        continue

                    sentence_chunk_text = [t.text for t in sentence_chunk_flat]

                    tasks.append(loop.run_in_executor(executor, next(processors_cycle), sentence_chunk_text, story["id"]))

                    if len(tasks) == save_batch_size:
                        results = await asyncio.gather(*tasks)

                        for coref_to_save in results:
                            try:

                                db["coreference"].insert_many(copy.deepcopy(coref_to_save))

                            except Exception as e:
                                logging.error(e)

                        tasks = []

            results = await asyncio.gather(*tasks)

            for coref_to_save in results:
                try:

                    db["coreference"].insert_many(coref_to_save)

                except Exception as e:
                    logging.error(e)

            logger.info("Coreferences saved")
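
Hypothetical invocation sketch (assumes a trained AllenNLP coreference Model
and a database that already holds the story and sentence tables queried above):

# coref_model = ...  # load or train an AllenNLP coreference Model first
asyncio.get_event_loop().run_until_complete(
    save_coreferences(coref_model, "sqlite:///stories.db", cuda_device=[0, 1]))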
Example 4
import itertools
import math
import random
import string
from typing import List

import numpy
from nltk.corpus import stopwords

# PorterStemmer, ps, th, cache_for_th, flatten, replace_w_morphy and
# remove_duplicate_tok are assumed to be defined/imported elsewhere in the module.


class DocumentOracleDerivation(object):
    def __init__(self,
                 min_combination_num: int = 3,
                 max_combination_num: int = 5,
                 rm_stop_word: bool = True,
                 synonyms: bool = True,
                 stem: bool = False,
                 tokenization: bool = True,
                 beam_sz: int = 5,
                 candidate_percent: float = 1.0):
        self.min_combination_num = min_combination_num
        self.max_combination_num = max_combination_num
        self.rm_stop_word = rm_stop_word
        self.stem = stem
        self.tokenization = tokenization
        self.beam_sz = beam_sz
        self.candidate_percent = candidate_percent
        if self.stem:
            self.stemmer = PorterStemmer().stem_word
        else:
            self.stemmer = lambda x: x
        self.synonyms = synonyms
        if self.tokenization:
            from allennlp.data.tokenizers.word_tokenizer import WordTokenizer
            self.tokenizer = WordTokenizer()
        if self.rm_stop_word:
            self.stop_words = list(set(stopwords.words('english'))) + [x for x in string.punctuation] + ['``', '\'\'']
        else:
            self.stop_words = []

    def get_rouge_w_annotation_ready_to_use(self, gold_tokens: List[str],
                                            pred_tokens: List[str]):
        gold_lower = list(set([x.lower() for x in gold_tokens]))
        gold_wo_stop = [x for x in gold_lower if x not in self.stop_words]  # change of index
        gold_wo_stop = replace_w_morphy(gold_wo_stop)
        gold_stem = [ps.stem(x) for x in gold_wo_stop]

        pred_lower = list([x.lower() for x in pred_tokens])

        pred_lower = replace_w_morphy(pred_lower)
        pred_lower = remove_duplicate_tok(pred_lower)

        pred_stem = [ps.stem(x) for x in pred_lower]
        pred_stem = remove_duplicate_tok(pred_stem)
        size_of_gold = len(gold_stem)
        size_of_pred = len(pred_stem)

        gold_key, gold_value = [], []
        for idx, word in enumerate(gold_wo_stop):
            # for one gold word, we have a minigroup
            _tmp = []
            if word in pred_lower:
                _tmp.append(word)
            elif word in pred_stem:
                _tmp.append(word)
            elif gold_stem[idx] in pred_lower:
                _tmp.append(gold_stem[idx])
            elif gold_stem[idx] in pred_stem:
                _tmp.append(gold_stem[idx])

            # if word or stm word could match, we don't need to search syn
            if _tmp != []:
                _tmp = _tmp[0]
                gold_key.append(_tmp)
                gold_value.append(1)
            else:
                if word not in cache_for_th:
                    try:
                        cache_for_th[word] = flatten(th.Word(word).synonyms('all', relevance=[3]))
                    except Exception:
                        cache_for_th[word] = []

                if gold_stem[idx] not in cache_for_th:
                    try:
                        cache_for_th[gold_stem[idx]] = flatten(
                            th.Word(gold_stem[idx]).synonyms('all', relevance=[3]))
                    except Exception:
                        cache_for_th[gold_stem[idx]] = []
                syn = cache_for_th[word]
                syn_stem = cache_for_th[gold_stem[idx]]
                syn = list(set(syn + syn_stem))
                # print(syn)
                l_syn = len(syn)
                if l_syn != 0:
                    gold_key += syn
                    gold_value += [float(1 / l_syn)] * l_syn

        gold_tokens = [ps.stem(x) for x in gold_key]
        # pred_set = set(pred)
        # comp intersection
        vs = 0
        key_index = []
        for p_idx in range(len(pred_lower)):
            p_word = pred_lower[p_idx]
            p_stem_word = pred_stem[p_idx]

            if p_word in gold_key:
                idx = gold_key.index(p_word)
                v = gold_value[idx]
                vs += v
                key_index.append(p_idx)
            elif p_stem_word in gold_tokens:
                idx = gold_tokens.index(p_stem_word)
                v = gold_value[idx]
                vs += v
                key_index.append(p_idx)

        rouge_recall_1 = 0
        if size_of_gold != 0:
            rouge_recall_1 = vs / float(size_of_gold)
        rouge_pre_1 = 0
        if size_of_pred != 0:
            rouge_pre_1 = vs / float(size_of_pred)
        # print(rouge_recall_1, rouge_pre_1)
        # assert rouge_recall_1 <= 1
        # assert rouge_pre_1 <= 1
        if random.random() < 0.00001:
            print("Recall: {}\tPre: {}".format(rouge_recall_1, rouge_pre_1))
            print(pred_tokens)
        customed_recall = rouge_recall_1 + rouge_pre_1 * 0.01 - 0.01
        f1 = 0 if (rouge_recall_1 + rouge_pre_1 == 0) else 2 * (rouge_recall_1 * rouge_pre_1) / (
                rouge_recall_1 + rouge_pre_1)
        return customed_recall, f1, key_index
        # f1 = 0 if (rouge_recall_1 + rouge_pre_1 == 0) else 2 * (rouge_recall_1 * rouge_pre_1) / (
        #         rouge_recall_1 + rouge_pre_1)
        # f1 = rouge_recall_1 * 5 + rouge_pre_1

    def comp_num_seg_out_of_p_sent_beam(self, _filtered_doc_list,
                                        num_sent_in_combination,
                                        target_ref_sum_list,
                                        map_from_new_to_ori_idx) -> dict:
        beam: List[dict] = []
        if len(_filtered_doc_list) < num_sent_in_combination:
            return {"nlabel": num_sent_in_combination,
                    "data": {},
                    "best": None
                    }

        combs = list(range(0, len(_filtered_doc_list)))
        # _num_edu seq_len
        cur_beam = {
            "in": [],
            "todo": combs,
            "val": 0
        }
        beam.append(cur_beam)
        for t in range(num_sent_in_combination):
            dict_pattern = {}
            # compute top beam_sz for every beam
            global_board = []
            for b in beam:
                already_in_beam = b['in']
                todo = b['todo']

                leaderboard = {}
                for to_add in todo:
                    after_add = already_in_beam + [to_add]
                    candidate_doc_list = list(itertools.chain.from_iterable([_filtered_doc_list[i] for i in after_add]))
                    # average_f_score = self.get_approximate_rouge(target_ref_sum_list, candidate_doc_list)
                    _, average_f_score, _ = self.get_rouge_w_annotation_ready_to_use(gold_tokens=target_ref_sum_list,
                                                                                     pred_tokens=candidate_doc_list)
                    leaderboard[to_add] = average_f_score
                sorted_beam = [(k, leaderboard[k]) for k in sorted(leaderboard, key=leaderboard.get, reverse=True)]

                for it in sorted_beam:
                    new_in = already_in_beam + [it[0]]
                    new_in.sort()
                    str_new_in = [str(x) for x in new_in]
                    if '_'.join(str_new_in) in dict_pattern:
                        continue
                    else:
                        dict_pattern['_'.join(str_new_in)] = True
                    new_list = todo.copy()
                    new_list.remove(it[0])
                    _beam = {
                        "in": new_in,
                        "todo": new_list,
                        "val": it[1]
                    }
                    global_board.append(_beam)
            # merge and get the top beam_sz among all

            sorted_global_board = sorted(global_board, key=lambda x: x["val"], reverse=True)

            _cnt = 0
            check_dict = []
            beam_waitlist = []
            for it in sorted_global_board:
                str_in = sorted(it['in'])
                str_in = [str(x) for x in str_in]
                _tmp_key = '_'.join(str_in)
                if _tmp_key in check_dict:
                    continue
                else:
                    beam_waitlist.append(it)
                    check_dict.append(_tmp_key)
                _cnt += 1
                if _cnt >= self.beam_sz:
                    break
            beam = beam_waitlist
        # if len(beam) < 2:
        #     print(len(_filtered_doc_list))
        #     print(_num_edu)
        # Write oracle to a string like: 0.4 0.3 0.4
        _comb_bag = {}
        for it in beam:
            n_comb = it['in']
            n_comb.sort()
            n_comb_original = [map_from_new_to_ori_idx[a] for a in n_comb]
            n_comb_original.sort()  # json label
            n_comb_original = [int(x) for x in n_comb_original]
            candidate_doc_list = list(itertools.chain.from_iterable([_filtered_doc_list[i] for i in n_comb]))
            # f1 = self.get_approximate_rouge(target_ref_sum_list, candidate_doc_list)
            _, f1, _ = self.get_rouge_w_annotation_ready_to_use(target_ref_sum_list, candidate_doc_list)

            # f_avg = (f1 + f2 + fl) / 3
            _comb_bag[f1] = {"label": n_comb_original,
                             "R1": f1,
                             "nlabel": num_sent_in_combination}
        # print(len(_comb_bag))
        if len(_comb_bag) == 0:
            return {"nlabel": num_sent_in_combination,
                    "data": {},
                    "best": None
                    }
        else:
            best_key = sorted(_comb_bag.keys(), reverse=True)[0]
            rt_dict = {"nlabel": num_sent_in_combination,
                       "data": _comb_bag,
                       "best": _comb_bag[best_key]
                       }
            return rt_dict

    def derive_doc_oracle(self, doc_list: List[str],
                          ref_sum: str,
                          prefix_summary: str = ""
                          ):
        processed_doc_list, processed_ref_sum_str, processed_prefix_sum_str = [], '', ''
        if self.tokenization:
            token_doc_list = self.tokenizer.batch_tokenize(doc_list)
            for doc in token_doc_list:
                processed_doc_list.append([word.text for word in doc])
            processed_ref_sum_list = [w.text for w in self.tokenizer.tokenize(ref_sum)]
            processed_prefix_sum_list = [w.text for w in self.tokenizer.tokenize(prefix_summary)]
        else:
            processed_doc_list = [d.split(" ") for d in doc_list]
            processed_ref_sum_list = ref_sum.split(" ")
            processed_prefix_sum_list = prefix_summary.split(" ")
        processed_doc_list = [[x.lower() for x in sent] for sent in processed_doc_list]
        processed_ref_sum_list = [x.lower() for x in processed_ref_sum_list]
        processed_prefix_sum_list = [x.lower() for x in processed_prefix_sum_list]
        if self.rm_stop_word:
            processed_doc_list = [[x for x in sent if x not in self.stop_words] for sent in processed_doc_list]
            processed_ref_sum_list = [x for x in processed_ref_sum_list if x not in self.stop_words]
            processed_prefix_sum_list = [x for x in processed_prefix_sum_list if x not in self.stop_words]

        target_ref_sum_list = [x for x in processed_ref_sum_list if x not in processed_prefix_sum_list]

        # preprocessing finished
        filtered_doc_list, map_from_new_to_ori_idx = self.pre_prune(processed_doc_list, target_ref_sum_list)

        combination_data_dict = {}

        for num_sent_in_combination in range(self.min_combination_num, self.max_combination_num):
            combination_data = self.comp_num_seg_out_of_p_sent_beam(_filtered_doc_list=filtered_doc_list,
                                                                    num_sent_in_combination=num_sent_in_combination,
                                                                    target_ref_sum_list=target_ref_sum_list,
                                                                    map_from_new_to_ori_idx=map_from_new_to_ori_idx)
            combination_data_dict = {**combination_data_dict, **combination_data['data']}
            combination_data_dict[num_sent_in_combination] = combination_data
        return combination_data_dict

    def pre_prune(self, list_of_doc: List[List[str]],
                  ref_sum: List[str]
                  ):
        keep_candidate_num = math.ceil(len(list_of_doc) * self.candidate_percent)
        # f_score_list = [self.get_approximate_rouge(ref_sum, x) for x in list_of_doc]
        f_score_list = [self.get_rouge_w_annotation_ready_to_use(ref_sum, x)[1] for x in list_of_doc]
        top_p_sent_idx = numpy.argsort(f_score_list)[-keep_candidate_num:]

        map_from_new_to_ori_idx = []
        # filter
        filtered_doc_list = []
        for i in range(len(top_p_sent_idx)):
            filtered_doc_list.append(list_of_doc[top_p_sent_idx[i]])
            map_from_new_to_ori_idx.append(top_p_sent_idx[i])
        return filtered_doc_list, map_from_new_to_ori_idx
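
Usage sketch for this variant (hypothetical inputs; the synonym/morphy helpers
listed above must be available in the module):

oracle = DocumentOracleDerivation(min_combination_num=1, max_combination_num=3)
doc_sents = ["The cat sat on the red mat.",
             "It rained all afternoon.",
             "The mat was bright red."]
combos = oracle.derive_doc_oracle(doc_sents, ref_sum="A cat sat on a red mat.")
# combos collects the beam-searched sentence-index labels (and their ROUGE-based
# scores) for each combination size between min and max.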
Example 5
import collections
import itertools
import math
import re
import string
from itertools import islice
from typing import List

import numpy
from nltk.corpus import stopwords

# PorterStemmer, cal_rouge and _get_ngrams are assumed to be defined/imported
# elsewhere in the module.


class DocumentOracleDerivation(object):
    def __init__(self,
                 mixed_combination: bool,
                 min_combination_num: int = 1,
                 max_combination_num: int = 8,
                 rm_stop_word: bool = True,
                 stem: bool = False,
                 morphy: bool = False,
                 tokenization: bool = True,
                 beam_sz: int = 5,
                 prune_candidate_percent: float = 0.4):
        self.mixed_combination = mixed_combination
        self.min_combination_num = min_combination_num
        self.max_combination_num = max_combination_num
        self.rm_stop_word = rm_stop_word
        self.stem = stem
        self.tokenization = tokenization
        self.beam_sz = beam_sz
        self.prune_candidate_percent = prune_candidate_percent
        if self.stem:
            self.stemmer = PorterStemmer().stem_word
        else:
            self.stemmer = lambda x: x

        self.morphy = morphy

        if self.tokenization:
            from allennlp.data.tokenizers.word_tokenizer import WordTokenizer
            self.tokenizer = WordTokenizer()
        if self.rm_stop_word:
            self.stop_words = list(set(stopwords.words('english'))) + [
                x for x in string.punctuation
            ] + ['``', '\'\'']
        else:
            self.stop_words = []

    def derive_doc_oracle(
        self,
        doc_list: List[str],
        ref_sum: str,
        prefix_summary: str = "",
    ):
        # return a dict where key=rouge-f1 and value= [0,0,0,1,0,1,0,...] same size as doc_list
        # processed_doc_list, processed_ref_sum_str, processed_prefix_sum_str = [], '', ''
        len_of_doc = len(doc_list)
        processed_doc_list = [self._rouge_clean(x) for x in doc_list]
        processed_ref_sum_str = self._rouge_clean(ref_sum)
        processed_prefix_sum_str = self._rouge_clean(prefix_summary)
        if self.tokenization:
            new_processed_doc_list = []
            token_doc_list = self.tokenizer.batch_tokenize(processed_doc_list)
            for doc in token_doc_list:
                new_processed_doc_list.append([word.text for word in doc])
            processed_doc_list = new_processed_doc_list
            processed_ref_sum_list = [
                w.text for w in self.tokenizer.tokenize(processed_ref_sum_str)
            ]
            processed_prefix_sum_list = [
                w.text
                for w in self.tokenizer.tokenize(processed_prefix_sum_str)
            ]
        else:
            processed_doc_list = [d.split(" ") for d in processed_doc_list]
            processed_ref_sum_list = processed_ref_sum_str.split(" ")
            processed_prefix_sum_list = processed_prefix_sum_str.split(" ")

        # must do lower
        processed_doc_list = [[x.lower() for x in sent]
                              for sent in processed_doc_list]
        processed_ref_sum_list = [x.lower() for x in processed_ref_sum_list]
        processed_prefix_sum_list = [
            x.lower() for x in processed_prefix_sum_list
        ]

        # if self.rm_stop_word:
        #     processed_doc_list = [[x for x in sent if x not in self.stop_words] for sent in processed_doc_list]
        #     processed_ref_sum_list = [x for x in processed_ref_sum_list if x not in self.stop_words]
        #     processed_prefix_sum_list = [x for x in processed_prefix_sum_list if x not in self.stop_words]

        target_ref_sum_list = [
            x for x in processed_ref_sum_list
            if x not in processed_prefix_sum_list
        ]

        # TODO
        f_score_list, score_matrix = self.iter_rouge(processed_doc_list,
                                                     target_ref_sum_list)

        # preprocessing finished
        filtered_doc_list, map_from_new_to_ori_idx = self.pre_prune(
            processed_doc_list, target_ref_sum_list)
        combination_data_dict = {}
        for num_sent_in_combination in range(self.min_combination_num,
                                             self.max_combination_num):
            combination_data = self.comp_num_seg_out_of_p_sent_beam(
                _filtered_doc_list=filtered_doc_list,
                num_sent_in_combination=num_sent_in_combination,
                target_ref_sum_list=target_ref_sum_list,
                map_from_new_to_ori_idx=map_from_new_to_ori_idx)
            if combination_data['best'] is None:
                break
            best_rouge_of_this_batch = combination_data['best']['R1']
            if len(combination_data_dict) >= self.beam_sz:
                rouge_in_bag = [
                    float(k) for k, v in combination_data_dict.items()
                ]
                if best_rouge_of_this_batch < min(rouge_in_bag):
                    break

            combination_data_dict = {
                **combination_data_dict,
                **combination_data['data']
            }
            combination_data_dict = collections.OrderedDict(
                sorted(combination_data_dict.items(), reverse=True))
            sliced = islice(combination_data_dict.items(), self.beam_sz)
            combination_data_dict = collections.OrderedDict(sliced)
            # combination_data_dict[num_sent_in_combination] = combination_data

        # prepare return data
        return_dict = {}
        for k, v in combination_data_dict.items():
            # tmp_list = [0 for _ in range(len_of_doc)]
            # for i in v['label']:
            #     tmp_list[i] = 1
            return_dict[k] = v['label']
        return return_dict

    def iter_rouge(self, list_of_doc, ref_sum):
        f_score_list = [
            self.get_rouge_ready_to_use(ref_sum, x) for x in list_of_doc
        ]
        # score_matrix_delta = [[0 for _ in range(len(list_of_doc))] for _ in range(len(list_of_doc))]
        score_matrix = [[0 for _ in range(len(list_of_doc))]
                        for _ in range(len(list_of_doc))]
        input = []
        for idx, x in enumerate(list_of_doc):
            for jdx, y in enumerate(list_of_doc):
                input.append((idx, jdx, ref_sum, x + y))
                s = self.get_rouge_ready_to_use(ref_sum, x + y)
                score_matrix[idx][jdx] = s
                # if f_score_list[idx] < 0.01:
                #
                #     score_matrix_delta[idx][jdx] = 0
                # else:
                #     score_matrix_delta[idx][jdx] = min(s / (f_score_list[idx] + 0.001), 2)
        # import numpy as np
        # np.set_printoptions(precision=2)
        # import seaborn as sns
        # sns.set()
        # f_score_list = np.asarray([f_score_list, f_score_list])
        # bx = sns.heatmap(f_score_list)
        # fig = bx.get_figure()
        # fig.savefig("individual_output.png")
        # print('-' * 30)
        # print(np.asarray(score_matrix))
        # score_matrix_delta = np.asarray(score_matrix_delta)
        # ax = sns.heatmap(score_matrix_delta)
        # fig = ax.get_figure()
        # fig.savefig("output.png")

        # ncpu=multiprocessing.cpu_count()
        # pool = multiprocessing.Pool(processes=ncpu)
        # results = pool.starmap(self.get_rouge_ready_to_use, input)
        # for r in results:
        #     score, idx,jdx = r
        #     score_matrix[idx][jdx] = score
        return f_score_list, score_matrix

    def comp_num_seg_out_of_p_sent_beam(self, _filtered_doc_list,
                                        num_sent_in_combination,
                                        target_ref_sum_list,
                                        map_from_new_to_ori_idx) -> dict:
        beam: List[dict] = []
        if len(_filtered_doc_list) < num_sent_in_combination:
            return {
                "nlabel": num_sent_in_combination,
                "data": {},
                "best": None
            }

        combs = list(range(0, len(_filtered_doc_list)))
        # _num_edu seq_len
        cur_beam = {"in": [], "todo": combs, "val": 0}
        beam.append(cur_beam)
        for t in range(num_sent_in_combination):
            dict_pattern = {}
            # compute top beam_sz for every beam
            global_board = []
            for b in beam:
                already_in_beam = b['in']
                todo = b['todo']

                leaderboard = {}
                for to_add in todo:
                    after_add = already_in_beam + [to_add]
                    candidate_doc_list = list(
                        itertools.chain.from_iterable(
                            [_filtered_doc_list[i] for i in after_add]))
                    # average_f_score = self.get_approximate_rouge(target_ref_sum_list, candidate_doc_list)
                    average_f_score = self.get_rouge_ready_to_use(
                        gold_tokens=target_ref_sum_list,
                        pred_tokens=candidate_doc_list)
                    leaderboard[to_add] = average_f_score
                sorted_beam = [(k, leaderboard[k]) for k in sorted(
                    leaderboard, key=leaderboard.get, reverse=True)]

                for it in sorted_beam:
                    new_in = already_in_beam + [it[0]]
                    new_in.sort()
                    str_new_in = [str(x) for x in new_in]
                    if '_'.join(str_new_in) in dict_pattern:
                        continue
                    else:
                        dict_pattern['_'.join(str_new_in)] = True
                    new_list = todo.copy()
                    new_list.remove(it[0])
                    _beam = {"in": new_in, "todo": new_list, "val": it[1]}
                    global_board.append(_beam)
            # merge and get the top beam_sz among all

            sorted_global_board = sorted(global_board,
                                         key=lambda x: x["val"],
                                         reverse=True)

            _cnt = 0
            check_dict = []
            beam_waitlist = []
            for it in sorted_global_board:
                str_in = sorted(it['in'])
                str_in = [str(x) for x in str_in]
                _tmp_key = '_'.join(str_in)
                if _tmp_key in check_dict:
                    continue
                else:
                    beam_waitlist.append(it)
                    check_dict.append(_tmp_key)
                _cnt += 1
                if _cnt >= self.beam_sz:
                    break
            beam = beam_waitlist
        # if len(beam) < 2:
        #     print(len(_filtered_doc_list))
        #     print(_num_edu)
        # Write oracle to a string like: 0.4 0.3 0.4
        _comb_bag = {}
        for it in beam:
            n_comb = it['in']
            n_comb.sort()
            n_comb_original = [map_from_new_to_ori_idx[a] for a in n_comb]
            n_comb_original.sort()  # json label
            n_comb_original = [int(x) for x in n_comb_original]
            candidate_doc_list = list(
                itertools.chain.from_iterable(
                    [_filtered_doc_list[i] for i in n_comb]))
            # f1 = self.get_approximate_rouge(target_ref_sum_list, candidate_doc_list)
            f1 = self.get_rouge_ready_to_use(target_ref_sum_list,
                                             candidate_doc_list)

            # f_avg = (f1 + f2 + fl) / 3
            _comb_bag[f1] = {
                "label": n_comb_original,
                "R1": f1,
                "nlabel": num_sent_in_combination
            }
        # print(len(_comb_bag))
        if len(_comb_bag) == 0:
            return {
                "nlabel": num_sent_in_combination,
                "data": {},
                "best": None
            }
        else:
            best_key = sorted(_comb_bag.keys(), reverse=True)[0]
            rt_dict = {
                "nlabel": num_sent_in_combination,
                "data": _comb_bag,
                "best": _comb_bag[best_key]
            }
            return rt_dict

    @staticmethod
    def _rouge_clean(s):
        return re.sub(r'[^a-zA-Z0-9 ]', '', s)

    def get_rouge_ready_to_use_w_index(self, gold_tokens: List[str],
                                       pred_tokens: List[str], idx, jdx):
        return self.get_rouge_ready_to_use(gold_tokens, pred_tokens), idx, jdx

    # No synomous standard version

    def get_rouge_ready_to_use(self, gold_tokens: List[str],
                               pred_tokens: List[str]):
        len_gold = len(gold_tokens)
        len_pred = len(pred_tokens)

        gold_bigram = _get_ngrams(2, gold_tokens)
        pred_bigram = _get_ngrams(2, pred_tokens)

        if self.rm_stop_word:
            gold_unigram = set(
                [x for x in gold_tokens if x not in self.stop_words])
            pred_unigram = set(
                [x for x in pred_tokens if x not in self.stop_words])
        else:
            gold_unigram = set(gold_tokens)
            pred_unigram = set(pred_tokens)

        rouge_1 = cal_rouge(pred_unigram, gold_unigram, len_pred,
                            len_gold)['f']
        rouge_2 = cal_rouge(pred_bigram, gold_bigram, len_pred, len_gold)['f']
        rouge_score = (rouge_1 + rouge_2) / 2
        return rouge_score

    def pre_prune(self, list_of_doc: List[List[str]], ref_sum: List[str]):
        keep_candidate_num = math.ceil(
            len(list_of_doc) * self.prune_candidate_percent)
        # f_score_list = [self.get_approximate_rouge(ref_sum, x) for x in list_of_doc]
        f_score_list = [
            self.get_rouge_ready_to_use(ref_sum, x) for x in list_of_doc
        ]
        top_p_sent_idx = numpy.argsort(f_score_list)[-keep_candidate_num:]

        map_from_new_to_ori_idx = []
        # filter
        filtered_doc_list = []
        for i in range(len(top_p_sent_idx)):
            filtered_doc_list.append(list_of_doc[top_p_sent_idx[i]])
            map_from_new_to_ori_idx.append(top_p_sent_idx[i])
        return filtered_doc_list, map_from_new_to_ori_idx
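
Usage sketch for this revised version (hypothetical inputs; cal_rouge and
_get_ngrams must be available in the module, as noted above):

oracle = DocumentOracleDerivation(mixed_combination=False,
                                  min_combination_num=1,
                                  max_combination_num=3)
doc_sents = ["The cat sat on the red mat.",
             "It rained all afternoon.",
             "The mat was bright red."]
labels_by_rouge = oracle.derive_doc_oracle(doc_sents, "A cat sat on a red mat.")
# Keys are candidate ROUGE scores; values are the selected sentence indices
# (the extractive oracle labels) in original document order.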