Example #1
0
 def load_from_file(self, file_path):
     """Load the vocabulary index mappings from a pickle file.

     The object read by ``load_pickle`` is indexed with the keys
     ``'idx2word'`` and ``'word2idx'`` (a missing key raises), and each
     value is stored on the instance under the attribute of the same name.

     :param file_path: path of the pickle file to read.
     :return: None; ``self.idx2word`` and ``self.word2idx`` are set in place.
     """
     stored = load_pickle(input_file=file_path)
     # Assigned one after the other, idx2word first, exactly as before.
     for attr_name in ('idx2word', 'word2idx'):
         setattr(self, attr_name, stored[attr_name])
Example #2
0
def get_misspelled_data_dict(pkl_path, train_path):
    """Return the cached mapping of training sentences to misspelled variants.

    When ``pkl_path`` does not exist yet, the sentences of ``train_path`` are
    read, each is passed through ``get_misspelled_sentences``, and the
    resulting ``{sentence: augmented_sentences}`` dict is pickled to
    ``pkl_path``. The returned value is always what ``common.load_pickle``
    reads back from ``pkl_path``.

    :param pkl_path: cache location; must provide ``exists()`` (e.g. a Path).
    :param train_path: file passed to ``common.get_sentences_and_labels_from_txt``.
    :return: the unpickled sentence -> augmented-sentences dict.
    """
    if pkl_path.exists():
        return common.load_pickle(pkl_path)

    print(f"creating {pkl_path}")
    train_sentences, _ = common.get_sentences_and_labels_from_txt(train_path)
    # A repeated sentence is recomputed and its last result kept.
    augmented = {
        sent: get_misspelled_sentences(sent) for sent in tqdm(train_sentences)
    }
    common.save_pickle(pkl_path, augmented)
    return common.load_pickle(pkl_path)
Example #3
0
def get_eda_data_dict(pkl_path, train_path, n_aug, alpha):
    """Return the cached mapping of training sentences to EDA augmentations.

    When ``pkl_path`` does not exist yet, every sentence of ``train_path`` is
    augmented with ``eda(sentence, alpha=alpha, num_aug=n_aug)`` and the
    ``{sentence: eda_sentences}`` dict is pickled to ``pkl_path``. The result
    is always read back from ``pkl_path``.

    :param pkl_path: cache location; must provide ``exists()``.
    :param train_path: file passed to ``common.get_sentences_and_labels_from_txt``.
    :param n_aug: forwarded to ``eda`` as ``num_aug``.
    :param alpha: forwarded to ``eda`` as ``alpha``.
    :return: the unpickled sentence -> augmented-sentences dict.
    """
    if pkl_path.exists():
        return common.load_pickle(pkl_path)

    print(f"creating {pkl_path}")
    train_sentences, _ = common.get_sentences_and_labels_from_txt(train_path)
    augmented = {
        sent: eda(sent, alpha=alpha, num_aug=n_aug)
        for sent in tqdm(train_sentences)
    }
    common.save_pickle(pkl_path, augmented)
    return common.load_pickle(pkl_path)
Example #4
0
def get_encoding_dict(sentence_to_labels, original_file_path, aug_type, alpha):
    """Return the cached mapping of sentences to their encodings.

    The cache path comes from ``get_encodings_path(original_file_path,
    aug_type, alpha)``. If that file is missing, every key of
    ``sentence_to_labels`` is encoded with ``get_encoding`` and the
    ``{sentence: encoding}`` dict is pickled there. The result is always read
    back from the cache file.

    :param sentence_to_labels: dict whose keys are the sentences to encode.
    :param original_file_path: used only to derive the cache path.
    :param aug_type: used only to derive the cache path.
    :param alpha: used only to derive the cache path.
    :return: the unpickled sentence -> encoding dict.
    """
    encodings_path = get_encodings_path(original_file_path, aug_type, alpha)
    if encodings_path.exists():
        return common.load_pickle(encodings_path)

    print(f"creating {encodings_path}")
    # ``tokenizer`` and ``model`` are not parameters: they are resolved as
    # names from the enclosing module.
    string_to_encoding = {
        text: get_encoding(text, tokenizer, model)
        for text in tqdm(sentence_to_labels.keys())
    }
    common.save_pickle(encodings_path, string_to_encoding)
    return common.load_pickle(encodings_path)
Example #5
0
def get_switchout_data_dict(pkl_path, train_path, n_aug, alpha):
    """Return the cached mapping of training sentences to SwitchOut variants.

    When ``pkl_path`` does not exist yet, the vocabulary of the training
    sentences is collected once with ``load_all_words``. Each sentence is then
    augmented with ``get_switchout_sentences(sentence, n_aug, alpha,
    all_words)``, and the resulting dict is pickled to ``pkl_path``. The result
    is always read back from ``pkl_path``.

    :param pkl_path: cache location; must provide ``exists()``.
    :param train_path: file passed to ``common.get_sentences_and_labels_from_txt``.
    :param n_aug: forwarded to ``get_switchout_sentences``.
    :param alpha: forwarded to ``get_switchout_sentences``.
    :return: the unpickled sentence -> augmented-sentences dict.
    """
    if pkl_path.exists():
        return common.load_pickle(pkl_path)

    print(f"creating {pkl_path}")
    train_sentences, _ = common.get_sentences_and_labels_from_txt(train_path)
    vocabulary = load_all_words(train_sentences)
    augmented = {
        sent: get_switchout_sentences(sent, n_aug, alpha, vocabulary)
        for sent in tqdm(train_sentences)
    }
    common.save_pickle(pkl_path, augmented)
    return common.load_pickle(pkl_path)
Example #6
0
def get_rd_data_dict(pkl_path, train_path, n_aug, alpha):
    """Return the cached mapping of training sentences to RD augmentations.

    When ``pkl_path`` does not exist yet, ``get_rd_sentence(sentence, alpha)``
    is called ``n_aug`` times per training sentence. The resulting
    ``{sentence: [variant, ...]}`` dict is pickled to ``pkl_path``. The result
    is always read back from ``pkl_path``.

    :param pkl_path: cache location; must provide ``exists()``.
    :param train_path: file passed to ``common.get_sentences_and_labels_from_txt``.
    :param n_aug: number of variants generated per sentence.
    :param alpha: forwarded to ``get_rd_sentence``.
    :return: the unpickled sentence -> list-of-variants dict.
    """
    if pkl_path.exists():
        return common.load_pickle(pkl_path)

    print(f"creating {pkl_path}")
    train_sentences, _ = common.get_sentences_and_labels_from_txt(train_path)
    augmented = {
        sent: [get_rd_sentence(sent, alpha) for _ in range(n_aug)]
        for sent in tqdm(train_sentences)
    }
    common.save_pickle(pkl_path, augmented)
    return common.load_pickle(pkl_path)
Example #7
0
def load_sent_index_offset(data_type, cache_dir):
    """Load the pickled sentence offset list of one data type.

    :param data_type: sub-directory name under ``cache_dir``.
    :param cache_dir: root cache directory.
    :return: whatever ``load_pickle`` reads from
        ``join(cache_dir, data_type, "sent_index_list_offset.pkl")``.
    """
    offset_path = join(cache_dir, data_type, "sent_index_list_offset.pkl")
    return load_pickle(offset_path)
Example #8
0
def load_stop_ctkidx_list(cache_dir, stop_prop=1):
    """Return the leading ``stop_prop`` per-mille of the stored stop-token list.

    The list is read with ``load_pickle`` from ``stop_ctkidx_list_file_name``
    under ``cache_dir``. It is presumably the frequency-sorted concept-token
    index list saved by ``main`` (most frequent first) — verify if written
    elsewhere.

    :param cache_dir: directory holding the pickled list.
    :param stop_prop: proportion in units of 1/1000 of the list length
        (e.g. 1 -> first 0.1%); ints and floats are accepted.
    :return: the first ``int(stop_prop * len(list) / 1000)`` entries.
    """
    loaded_list = load_pickle(join(cache_dir, stop_ctkidx_list_file_name))
    # Multiply before dividing. The previous ``stop_prop / 1000 * n`` form
    # accumulated float error (570 / 1000 * 100 == 56.99999999999999), which
    # int() truncated to one element too few. With an integer stop_prop the
    # product is exact, and an integral quotient is represented exactly.
    need_num = int(stop_prop * len(loaded_list) / 1000)
    return loaded_list[:need_num]
Example #9
0
def main():
    """Build the ConceptNet stop-token and neighbor caches under ``cache_dir``.

    Two stages, each of which can be skipped from the command line:

    1. Unless ``--disable_stop_ctk``: count, per concept-token index, the
       entries of each sentence's ``_tk2spans`` over all non-"gen" sentences.
       The counts are cached in ``cn_ctkidx2freq.pkl``, and the tokens with a
       non-zero count are saved sorted by descending count (indices pickled,
       token strings written as a list file).
    2. Unless ``--disable_nb``: for every concept index, run a breadth-first
       walk of up to ``--k_hop`` hops over ``cididx2neighbor``. The neighbors
       found, grouped by depth, are written as one JSON line per concept; the
       ``tell()`` position of each line is pickled alongside.
    """
    # Local import so this edit does not require changing the file header.
    from collections import deque

    parser = argparse.ArgumentParser()
    parser.add_argument("--data_type_list", type=str, default="omcs,arc")
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--cache_dir", type=str, default=None)
    parser.add_argument("--k_hop", type=int, default=3)
    parser.add_argument("--max_num_nodes", type=int, default=1024)
    parser.add_argument("--disable_stop_ctk", action="store_true")
    parser.add_argument("--disable_nb", action="store_true")
    args = parser.parse_args()

    data_type_list = args.data_type_list.split(",")
    num_workers = args.num_workers
    cache_dir = args.cache_dir or index_sent_cache_dir
    k_hop = args.k_hop
    max_num_nodes = args.max_num_nodes
    disable_stop_ctk = args.disable_stop_ctk
    disable_nb = args.disable_nb
    # Keep only known data types, in this fixed order, so the concatenated
    # offset list below does not depend on the order given on the CLI.
    data_type_list = [_e for _e in ["gen", "omcs", "arc", "wikipedia"] if _e in data_type_list]
    ctk_list, cid_list, ctk2idx, cid2idx, cididx2ctkidx, ctkidx2cididxs = load_conceptnet()
    rel_list, rel2idx, cg, cididx2neighbor = load_conceptnet_graph(cid_list, cid2idx)

    # Concatenate the sentence offsets of all data types. Entries
    # part_idxs[i]:part_idxs[i + 1] belong to data_type_list[i].
    part_idxs = [0, ]
    sent_index_offset_list = []
    for _data_type in data_type_list:
        _offset_list = load_sent_index_offset(_data_type, cache_dir)
        sent_index_offset_list.extend(_offset_list)
        part_idxs.append(len(sent_index_offset_list))

    # read all sent
    if disable_stop_ctk:
        print("disable_stop_ctk!!!!!")
    else:
        print("reading all sent to count ctkidx2freq")
        ctkidx2freq_path = join(cache_dir, "cn_ctkidx2freq.pkl")
        if file_exists(ctkidx2freq_path):
            print("\tfound file, loading")
            ctkidx2freq = load_pickle(ctkidx2freq_path)
        else:
            print("\tnot found file, building")

            def _processor_ctkidx2freq(_sent_index_offset_list, _with_sent_index=False):
                # Returns a list with one count per concept-token index.
                # With _with_sent_index the items are (global sentence index,
                # offset) pairs, so a worker handling one chunk can still look
                # up each sentence's data type. Otherwise the position in the
                # list is the global sentence index.
                local_ctkidx2freq = [0 for _ in range(len(ctk_list))]

                if _with_sent_index:
                    _iterator = tqdm(_sent_index_offset_list)
                else:
                    _iterator = enumerate(tqdm(_sent_index_offset_list))

                for _idx_sent, _sent_index_offset in _iterator:
                    _data_type = get_data_type(_idx_sent, part_idxs, data_type_list)
                    if _data_type != "gen":
                        _sent_data = load_sent_from_shard(_sent_index_offset, cache_dir, _data_type)
                        _tk2spans = _sent_data[2]
                        for _tk in _tk2spans:
                            local_ctkidx2freq[ctk2idx[_tk]] += 1
                return local_ctkidx2freq

            if num_workers == 1:
                ctkidx2freq = _processor_ctkidx2freq(sent_index_offset_list)
            else:
                sent_index_offset_list_with_index = list((_idx, _e) for _idx, _e in enumerate(sent_index_offset_list))
                local_ctkidx2freq_list = multiprocessing_map(
                    _processor_ctkidx2freq, dict_args_list=[
                        {"_sent_index_offset_list": _d, "_with_sent_index": True}
                        for _d in split_to_lists(sent_index_offset_list_with_index, num_workers)
                    ], num_parallels=num_workers
                )
                # Element-wise sum of the per-worker count lists.
                ctkidx2freq = [sum(_ll[_ctkidx] for _ll in local_ctkidx2freq_list) for _ctkidx in range(len(ctk_list))]
            save_pickle(ctkidx2freq, ctkidx2freq_path)
        print("\tDone")

        # sorting
        print("Getting stop ctk")
        sorted_ctkidx_freq_pairs = sorted(
            [(_ctkidx, _freq) for _ctkidx, _freq in enumerate(ctkidx2freq) if _freq > 0],
            key=lambda _e: _e[1], reverse=True)
        # Take the index field directly. Unpacking zip(*pairs) into two names
        # raised ValueError when no token had a non-zero count.
        sorted_ctkidx_list = [_ctkidx for _ctkidx, _ in sorted_ctkidx_freq_pairs]
        save_pickle(sorted_ctkidx_list, join(cache_dir, stop_ctkidx_list_file_name))
        save_list_to_file([ctk_list[_ctkidx] for _ctkidx in sorted_ctkidx_list],
                          join(cache_dir, stop_ctk_list_file_name))
        print("\tDone")

    # find
    def _processor(_cididx_list):
        # For each start concept index, returns k_hop + 1 lists: depth 0 holds
        # the start node, and depth d holds the neighbors first reached at hop d.
        _local_res_list = []
        for _ct_cididx in tqdm(_cididx_list):
            _node_explored = {_ct_cididx}
            _node_save = [[_ct_cididx], ] + [[] for _ in range(k_hop)]
            # FIFO queue of (node, depth); deque gives O(1) popleft.
            _node_buffer = deque([(_ct_cididx, 0)])
            while _node_buffer:
                _node_cididx, _prev_depth = _node_buffer.popleft()
                if _prev_depth == k_hop:
                    continue
                _cur_depth = _prev_depth + 1
                _neighbors = cididx2neighbor[_node_cididx]
                # shuffle keys (presumably so that truncation at max_num_nodes
                # keeps a random subset rather than a fixed one)
                _nb_cididxs = list(_neighbors.keys())
                random.shuffle(_nb_cididxs)
                for _nb_cididx in _nb_cididxs:
                    _attr = _neighbors[_nb_cididx]
                    if _nb_cididx in _node_explored:
                        continue
                    _node_explored.add(_nb_cididx)
                    # Nodes reached through a redundant relation are still
                    # expanded further; they are just not saved.
                    _node_buffer.append((_nb_cididx, _cur_depth))
                    if rel_list[_attr["relation"]] not in REDUNDANT_RELATIONS:  # remove REDUNDANT_RELATIONS
                        _node_save[_cur_depth].append(_nb_cididx)
                        # NOTE(review): checked after appending, so up to
                        # max_num_nodes + 1 nodes (start node included) are
                        # saved — confirm whether that is intended.
                        if sum(len(_e) for _e in _node_save) > max_num_nodes:
                            _node_buffer.clear()
                            break

            _local_res_list.append(_node_save)
        return _local_res_list

    if disable_nb:
        print("disable_nb!!!!!")
    else:
        print("Getting neighbors")
        proc_buffer = []
        nb_offsets = []
        # One JSON line per concept index. nb_offsets[i] is the tell() position
        # where that line starts; for a text-mode file this is a cookie meant
        # for seek() on the same file. Order is preserved assuming
        # combine_from_lists(..., ordered=True) keeps input order.
        with open(join(cache_dir, neighbor_cididxs_file_name), "w", encoding="utf-8") as wfp_nb:
            for _cididx in tqdm(range(len(cid_list)), total=len(cid_list)):
                proc_buffer.append(_cididx)
                # Process in batches of num_workers * 10000 indices, plus the
                # final partial batch.
                if len(proc_buffer) == num_workers * 10000 or _cididx == (len(cid_list)-1):
                    if num_workers == 1:
                        _res_list = _processor(proc_buffer)
                    else:
                        _res_list = combine_from_lists(
                            multiprocessing_map(
                                _processor, dict_args_list=[
                                    {"_cididx_list": _d} for _d in split_to_lists(proc_buffer, num_parallels=num_workers)
                                ], num_parallels=num_workers
                            ), ordered=True
                        )
                    assert len(_res_list) == len(proc_buffer)
                    for _elem in _res_list:
                        nb_offsets.append(wfp_nb.tell())
                        # "\n", not os.linesep: text mode already translates
                        # "\n" to the platform line ending, so os.linesep
                        # produced "\r\r\n" on Windows.
                        wfp_nb.write(json.dumps(_elem) + "\n")
                    proc_buffer = []
        save_pickle(nb_offsets, join(cache_dir, neighbor_cididxs_offset_file_name))
        print("\tDone")
Example #10
0
def load_neighbor_cididxs_offsets(cache_dir):
    """Load the pickled line offsets of the neighbor-index JSON file.

    :param cache_dir: directory holding ``neighbor_cididxs_offset_file_name``.
    :return: whatever ``load_pickle`` reads from that file.
    """
    offsets_path = join(cache_dir, neighbor_cididxs_offset_file_name)
    return load_pickle(offsets_path)