def calc_merges_difference_comparison_matrix2(
        path_to_metadata: str, fraction: str, n_merges: int,
        output_file: Optional[str] = None):
    """Build a pairwise comparison matrix between reassembled BPE vocabs.

    The dataset is divided into ``100 // fraction`` chunks. Entry (i, j) of
    the matrix is the merge-similarity rate between chunk i's own reassembled
    vocab (merges learned on chunk i itself) and the vocab obtained by
    applying chunk j's merges to chunk i. An extra final column compares
    chunk i's reassembled vocab against its raw (non-BPE) vocab.

    If ``output_file`` is given, the matrix is also written out as CSV.
    Returns the matrix as a list of rows.
    """
    int_fraction = int(fraction)
    n_chunks = 100 // int_fraction
    matrix = []
    for i in range(n_chunks):
        chunk_vocab = f'{i * int_fraction}'
        # Diagonal file: chunk i's vocab split with chunk i's own merges.
        original_path = os.path.join(
            path_to_metadata,
            f'{fraction}_{chunk_vocab}_{chunk_vocab}_{n_merges}_reassambled')
        original_vocab = read_dict_from_2_columns(original_path,
                                                  read_to_dict=False)
        row = []
        for j in range(n_chunks):
            chunk_merges = f'{j * int_fraction}'
            reassembled_path = os.path.join(
                path_to_metadata,
                f'{fraction}_{chunk_vocab}_{chunk_merges}_{n_merges}_reassambled')
            candidate_vocab = read_dict_from_2_columns(reassembled_path,
                                                       read_to_dict=False)
            row.append(get_merge_similiarity_rate(original_vocab,
                                                  candidate_vocab))
        # Last column: similarity against the raw, un-merged vocab.
        non_bpe_vocab = read_dict_from_2_columns(
            os.path.join(path_to_metadata, f'{fraction}_{chunk_vocab}_vocab'),
            read_to_dict=False)
        row.append(get_merge_similiarity_rate(original_vocab, non_bpe_vocab))
        matrix.append(row)
    if output_file:
        output_matrix_to_csv(matrix, output_file)
    return matrix
def init_splitting_config(dataset: str, prep_config: PrepConfig,
                          bpe_base_repr: Optional[str],
                          bpe_n_merges: Optional[int],
                          splitting_file: Optional[str]):
    """Initialize the module-level ``global_n_gramm_splitting_config``.

    The splitting behavior is selected by the SPLIT parameter of
    ``prep_config``:

    * 4-9 — BPE splitting: resolves the base repr/dataset, loads the merges
      file and merges cache from the corresponding BPE metadata directory.
      For value 9 ``bpe_n_merges`` must be supplied explicitly; for 4-8 it is
      taken from a fixed table.
    * 3 — custom splittings loaded from ``splitting_file`` (required).
    * 2 — number-only splitting.

    Raises ValueError when a required argument for the chosen mode is missing.
    """
    global global_n_gramm_splitting_config
    global_n_gramm_splitting_config = NgramSplitConfig()
    split_value = prep_config.get_param_value(PrepParam.SPLIT)

    if split_value in (4, 5, 6, 7, 8, 9):
        if not bpe_base_repr:
            bpe_base_repr = prep_config.get_base_bpe_prep_config()

        if split_value == 9:
            if not bpe_n_merges:
                raise ValueError(
                    "--bpe-n-merges must be specified for repr **9**")
        else:
            # Fixed number of merges per SPLIT mode.
            bpe_n_merges = {4: 5000, 5: 1000, 6: 10000,
                            7: 20000, 8: 0}[split_value]

        # "dataset/repr" form overrides the current dataset.
        if bpe_base_repr.find("/") == -1:
            bpe_base_dataset = dataset
        else:
            bpe_base_dataset, bpe_base_repr = bpe_base_repr.split("/")

        logger.info(f'Using bpe base dataset: {bpe_base_dataset}')
        logger.info(f'Using bpe base repr: {bpe_base_repr}')
        logger.info(f'Using bpe_n_merges: {bpe_n_merges}')

        path_to_merges_dir = os.path.join(DEFAULT_PARSED_DATASETS_DIR,
                                          bpe_base_dataset, METADATA_DIR,
                                          bpe_base_repr, BPE_DIR,
                                          str(bpe_n_merges))
        global_n_gramm_splitting_config.merges_cache = read_dict_from_2_columns(
            os.path.join(path_to_merges_dir, 'merges_cache.txt'),
            val_type=list)
        global_n_gramm_splitting_config.merges = read_merges(
            os.path.join(path_to_merges_dir, 'merges.txt'))
        global_n_gramm_splitting_config.set_splitting_type(
            NgramSplittingType.BPE)
    elif split_value == 3:
        if not splitting_file:
            raise ValueError("--splitting-file must be specified")
        global_n_gramm_splitting_config.sc_splittings = read_dict_from_2_columns(
            splitting_file, val_type=list, delim='|')
        global_n_gramm_splitting_config.set_splitting_type(
            NgramSplittingType.NUMBERS_AND_CUSTOM)
    elif split_value == 2:
        global_n_gramm_splitting_config.set_splitting_type(
            NgramSplittingType.ONLY_NUMBERS)
def split_vocab_using_merges(vocab_file: str, merges_file: str, n_merges: int,
                             output_file: str):
    """Re-split every vocab word with the first ``n_merges`` BPE merges.

    Reads a (word -> frequency) vocab, BPE-encodes each word, and writes a
    mapping from the space-joined subwords to the original frequency into
    ``output_file``.
    """
    vocab = read_dict_from_2_columns(vocab_file)
    merges = bpe_encode.read_merges(merges_file, n_merges)
    encoded_vocab = {
        " ".join(bpe_encode.encode_word(word, merges)): freq
        for word, freq in vocab.items()
    }
    dump_dict_into_2_columns(encoded_vocab, output_file)
def run(dataset: str, repr: str, n_merges: int, reset: bool, percent: float,
        start_from: float) -> None:
    """Learn (or continue learning) BPE merges for a dataset representation.

    Loads the vocabulary — from scratch when ``reset`` is set or no previous
    merges exist, otherwise from the most recent BPE output directory — then
    performs up to ``n_merges`` additional merge iterations and dumps the
    merges list, the reassembled vocab, the merges cache and the aggregated
    subword vocab into a new per-merge-count output directory.

    Raises AssertionError if the target output directory already exists.
    """
    bpe_dir_prefix = fractions_manager.get_percent_prefix(percent, start_from)
    bpe_dir_prefix = '' if bpe_dir_prefix == '100_' else bpe_dir_prefix

    base_dir = os.path.join(DEFAULT_PARSED_DATASETS_DIR, dataset,
                            METADATA_DIR, repr)
    if reset:
        starting_from_scratch = True
        archive_existing_common_bpe_folder(base_dir)
    else:
        logger.info("Using existing merges...")
        most_recent_bpe_dir = get_most_recent_bpe_dir(base_dir, bpe_dir_prefix)
        if not most_recent_bpe_dir:
            logger.warning("Existing merges not found ")
            starting_from_scratch = True
        else:
            all_vocab = read_dict_from_2_columns(
                os.path.join(most_recent_bpe_dir, REASSEMBLED_VOCAB_FILE_NAME))
            vocab, non_splitable_vocab = separate_non_splittable_vocab(
                all_vocab, from_reassambled=True)
            merges = read_list(os.path.join(most_recent_bpe_dir,
                                            MERGES_FILE_NAME))
            starting_from_scratch = False

    if starting_from_scratch:
        logger.info("Starting the encoding from scratch...")
        all_vocab = read_dict_from_2_columns(
            os.path.join(base_dir, f'{bpe_dir_prefix}{VOCAB_FILE_NAME}'))
        vocab, non_splitable_vocab = separate_non_splittable_vocab(
            all_vocab, from_reassambled=False)
        merges = []

    pairs = get_stats(vocab)
    n_done_merges = len(merges)
    for i in range(n_merges):
        try:
            best, occurences = pairs.pop_pair()
            # FIX: progress went through a bare print(); use the module
            # logger for consistency with the rest of this file.
            logger.info(f'Processing pair number {n_done_merges + i+1} {best}')
            merges.append((best[0], best[1], str(occurences)))
        except KeyError:
            # No pairs left to merge — stop early.
            break
        vocab = merge_vocab(best, vocab, pairs)

    # Re-attach the entries that were excluded from merging.
    for k, v in non_splitable_vocab.items():
        vocab[k] = v

    # Aggregate subword frequencies over all (space-split) vocab entries.
    resulting_vocab = collections.defaultdict(int)
    for entry, frequency in vocab.items():
        for subword in entry.split(" "):
            resulting_vocab[subword] += frequency
    resulting_vocab_sorted = sorted(resulting_vocab.items(),
                                    key=lambda x: x[1], reverse=True)

    # word -> list of its subwords, so encoded words can be looked up fast.
    merges_cache = {}
    for entry, frequency in vocab.items():
        subword_list = entry.split(' ')
        key = ''.join(subword_list)
        merges_cache[key] = subword_list

    new_bpe_dir = os.path.join(base_dir, f'{bpe_dir_prefix}{BPE_DIR}',
                               str(len(merges)))
    if os.path.exists(new_bpe_dir):
        # FIX: the two adjacent f-strings rendered as "...wrong.Check..." —
        # add the missing separating space.
        raise AssertionError(
            f'Dir {new_bpe_dir} already exists? Something went wrong. '
            f'Check the contents of {os.path.join(base_dir, BPE_DIR)} folder')
    os.makedirs(new_bpe_dir)

    dump_list(merges, os.path.join(new_bpe_dir, MERGES_FILE_NAME))
    dump_dict_into_2_columns(
        vocab, os.path.join(new_bpe_dir, REASSEMBLED_VOCAB_FILE_NAME))
    dump_dict_into_2_columns(
        merges_cache, os.path.join(new_bpe_dir, MERGES_CACHE_FILE_NAME),
        val_type=list)
    dump_dict_into_2_columns(
        resulting_vocab_sorted,
        os.path.join(new_bpe_dir, RESULTING_VOCAB_FILE_NAME))
    logger.info(f'Bpe output files are saved into {new_bpe_dir} folder')