def calc_merges_difference_comparison_matrix2(
        path_to_metadata: str,
        fraction: str,
        n_merges: int,
        output_file: Optional[str] = None):
    """Build a matrix of merge-similarity rates between dataset chunks.

    Row i compares chunk i's vocabulary split with its own merges against
    the same vocabulary split with every other chunk's merges; the last
    column compares it against the chunk's raw (non-BPE) vocabulary.
    """
    result = []
    int_fraction = int(fraction)
    for i in range(100 // int_fraction):
        chunk_vocab = f'{i * int_fraction}'
        reassambled_file_original = os.path.join(
            path_to_metadata,
            f'{fraction}_{chunk_vocab}_{chunk_vocab}_{n_merges}_reassambled')
        original_vocab = read_dict_from_2_columns(reassambled_file_original,
                                                  read_to_dict=False)
        row = []
        for j in range(100 // int_fraction):
            chunk_merges = f'{j * int_fraction}'
            reassambled_file = os.path.join(
                path_to_metadata,
                f'{fraction}_{chunk_vocab}_{chunk_merges}_{n_merges}_reassambled'
            )
            vocab = read_dict_from_2_columns(reassambled_file,
                                             read_to_dict=False)
            row.append(get_merge_similiarity_rate(original_vocab, vocab))
        non_bpe_vocab = read_dict_from_2_columns(os.path.join(
            path_to_metadata, f'{fraction}_{chunk_vocab}_vocab'),
                                                 read_to_dict=False)
        row.append(get_merge_similiarity_rate(original_vocab, non_bpe_vocab))
        result.append(row)
    if output_file:
        output_matrix_to_csv(result, output_file)
    return result
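A minimal usage sketch follows; the path and parameter values are hypothetical, and the function assumes the *_reassambled files for every chunk pair have already been generated:

# Hypothetical call: compare 20%-chunks of the corpus after 5000 merges.
matrix = calc_merges_difference_comparison_matrix2(
    path_to_metadata='/data/datasets/mydataset/metadata',  # hypothetical path
    fraction='20',        # chunk size in percent, passed as a string
    n_merges=5000,
    output_file='merge_similarity.csv')  # optional CSV dump of the matrix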
Example #2
def init_splitting_config(dataset: str, prep_config: PrepConfig,
                          bpe_base_repr: Optional[str],
                          bpe_n_merges: Optional[int],
                          splitting_file: Optional[str]):
    global global_n_gramm_splitting_config
    global_n_gramm_splitting_config = NgramSplitConfig()
    if prep_config.get_param_value(PrepParam.SPLIT) in [4, 5, 6, 7, 8, 9]:  # BPE-based modes
        if not bpe_base_repr:
            bpe_base_repr = prep_config.get_base_bpe_prep_config()

        if prep_config.get_param_value(PrepParam.SPLIT) == 9:
            # Custom BPE mode: the number of merges must be given explicitly.
            if not bpe_n_merges:
                raise ValueError(
                    "--bpe-n-merges must be specified for repr **9**")
        else:
            # Preset merge counts for the non-custom BPE modes.
            bpe_n_merges_dict = {4: 5000, 5: 1000, 6: 10000, 7: 20000, 8: 0}
            bpe_n_merges = bpe_n_merges_dict[prep_config.get_param_value(
                PrepParam.SPLIT)]

        if "/" not in bpe_base_repr:
            bpe_base_dataset = dataset
        else:
            bpe_base_dataset, bpe_base_repr = bpe_base_repr.split("/")
        logger.info(f'Using bpe base dataset: {bpe_base_dataset}')
        logger.info(f'Using bpe base repr: {bpe_base_repr}')
        logger.info(f'Using bpe_n_merges: {bpe_n_merges}')
        path_to_merges_dir = os.path.join(DEFAULT_PARSED_DATASETS_DIR,
                                          bpe_base_dataset, METADATA_DIR,
                                          bpe_base_repr, BPE_DIR,
                                          str(bpe_n_merges))
        bpe_merges_file = os.path.join(path_to_merges_dir, 'merges.txt')
        bpe_merges_cache = os.path.join(path_to_merges_dir, 'merges_cache.txt')

        global_n_gramm_splitting_config.merges_cache = read_dict_from_2_columns(
            bpe_merges_cache, val_type=list)
        global_n_gramm_splitting_config.merges = read_merges(bpe_merges_file)
        global_n_gramm_splitting_config.set_splitting_type(
            NgramSplittingType.BPE)
    elif prep_config.get_param_value(PrepParam.SPLIT) == 3:
        if not splitting_file:
            raise ValueError("--splitting-file must be specified")

        splittings = read_dict_from_2_columns(splitting_file,
                                              val_type=list,
                                              delim='|')
        global_n_gramm_splitting_config.sc_splittings = splittings
        global_n_gramm_splitting_config.set_splitting_type(
            NgramSplittingType.NUMBERS_AND_CUSTOM)
    elif prep_config.get_param_value(PrepParam.SPLIT) == 2:
        global_n_gramm_splitting_config.set_splitting_type(
            NgramSplittingType.ONLY_NUMBERS)
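A hedged sketch of how this initializer might be called; the dataset name and bpe_base_repr value are hypothetical, and prep_config is assumed to be an already-constructed PrepConfig whose SPLIT parameter equals 9:

# Hypothetical call for the custom-BPE mode (SPLIT == 9).
init_splitting_config(
    dataset='mydataset',                 # hypothetical dataset name
    prep_config=prep_config,             # assumed: SPLIT param value is 9
    bpe_base_repr='base_dataset/repr1',  # '<dataset>/<repr>' form is also accepted
    bpe_n_merges=20000,                  # mandatory for SPLIT == 9
    splitting_file=None)                 # only consulted when SPLIT == 3
Example #3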
def split_vocab_using_merges(vocab_file: str, merges_file: str, n_merges: int,
                             output_file: str):
    """Re-split every word in a vocabulary file using the first n_merges merges."""
    vocab = read_dict_from_2_columns(vocab_file)
    merges = bpe_encode.read_merges(merges_file, n_merges)
    result = {}
    for word, freq in vocab.items():
        encoded_word = bpe_encode.encode_word(word, merges)
        result[" ".join(encoded_word)] = freq
    dump_dict_into_2_columns(result, output_file)
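A minimal sketch of applying an existing merge list to a vocabulary file; all file names are hypothetical and merely follow the naming pattern used by the comparison matrix above:

# Hypothetical call: re-split a chunk vocabulary with the first 5000 merges.
split_vocab_using_merges(
    vocab_file='metadata/20_0_vocab',       # two-column file: word and frequency
    merges_file='metadata/merges.txt',
    n_merges=5000,
    output_file='metadata/20_0_0_5000_reassambled')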
Example #4
def run(dataset: str, repr: str, n_merges: int, reset: bool, percent: float, start_from: float) -> None:
    bpe_dir_prefix = fractions_manager.get_percent_prefix(percent, start_from)
    bpe_dir_prefix = '' if bpe_dir_prefix == '100_' else bpe_dir_prefix

    base_dir = os.path.join(DEFAULT_PARSED_DATASETS_DIR, dataset, METADATA_DIR, repr)
    if reset:
        starting_from_scratch = True
        archive_existing_common_bpe_folder(base_dir)
    else:
        logger.info("Using existing merges...")
        most_recent_bpe_dir = get_most_recent_bpe_dir(base_dir, bpe_dir_prefix)
        if not most_recent_bpe_dir:
            logger.warning("Existing merges not found")
            starting_from_scratch = True
        else:
            all_vocab = read_dict_from_2_columns(
                os.path.join(most_recent_bpe_dir, REASSEMBLED_VOCAB_FILE_NAME))
            vocab, non_splitable_vocab = separate_non_splittable_vocab(all_vocab, from_reassambled=True)
            merges = read_list(os.path.join(most_recent_bpe_dir, MERGES_FILE_NAME))
            starting_from_scratch = False

    if starting_from_scratch:
        logger.info("Starting the encoding from scratch...")
        all_vocab = read_dict_from_2_columns(os.path.join(base_dir, f'{bpe_dir_prefix}{VOCAB_FILE_NAME}'))
        vocab, non_splitable_vocab = separate_non_splittable_vocab(all_vocab, from_reassambled=False)
        merges = []

    # Learn up to n_merges additional merge operations, always taking the
    # currently most frequent adjacent pair.
    pairs = get_stats(vocab)
    n_done_merges = len(merges)
    for i in range(n_merges):
        try:
            best, occurrences = pairs.pop_pair()
            print(f'Processing pair number {n_done_merges + i + 1} {best}')
            merges.append((best[0], best[1], str(occurrences)))
        except KeyError:  # no pairs left to merge
            break
        # E.g., merging the pair ('e', 's') turns the vocab entry
        # 'r e q u e s t' into 'r e q u es t'.
        vocab = merge_vocab(best, vocab, pairs)

    # Re-attach the entries that must never be split.
    for k, v in non_splitable_vocab.items():
        vocab[k] = v
    # Aggregate frequencies per subword across all (partially) split words.
    resulting_vocab = collections.defaultdict(int)
    for entry, frequency in vocab.items():
        for subword in entry.split(" "):
            resulting_vocab[subword] += frequency
    resulting_vocab_sorted = sorted(resulting_vocab.items(), key=lambda x: x[1], reverse=True)

    # Cache the final splitting of every word: full word -> list of subwords.
    merges_cache = {}
    for entry, frequency in vocab.items():
        subword_list = entry.split(' ')
        key = ''.join(subword_list)
        merges_cache[key] = subword_list

    new_bpe_dir = os.path.join(base_dir, f'{bpe_dir_prefix}{BPE_DIR}', str(len(merges)))
    if os.path.exists(new_bpe_dir):
        raise AssertionError(f'Dir {new_bpe_dir} already exists? Something went wrong. '
                             f'Check the contents of the {os.path.join(base_dir, BPE_DIR)} folder')
    os.makedirs(new_bpe_dir)

    dump_list(merges, os.path.join(new_bpe_dir, MERGES_FILE_NAME))
    dump_dict_into_2_columns(vocab, os.path.join(new_bpe_dir, REASSEMBLED_VOCAB_FILE_NAME))
    dump_dict_into_2_columns(merges_cache, os.path.join(new_bpe_dir, MERGES_CACHE_FILE_NAME), val_type=list)
    dump_dict_into_2_columns(resulting_vocab_sorted, os.path.join(new_bpe_dir, RESULTING_VOCAB_FILE_NAME))
    logger.info(f'BPE output files are saved into the {new_bpe_dir} folder')
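A hedged invocation sketch; all argument values are hypothetical:

# Hypothetical call: learn up to 1000 more merges over the full dataset,
# resuming from the most recent BPE directory if one exists.
run(dataset='mydataset', repr='my_repr', n_merges=1000,
    reset=False, percent=100.0, start_from=0.0)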