예제 #1
0
def check_if_bpe_config_supported(bpe_config: BpeConfig):
    """Raise ``BpeConfigNotSupported`` if ``bpe_config`` uses an unsupported feature.

    The checks run in this order and the first failing one raises:

    - ``BpeParam.UNICODE == 'bytes'`` (byte-level BPE),
    - a truthy ``BpeParam.WORD_END`` (word-end characters),
    - ``BpeParam.CASE == 'prefix'`` (case encoded in a prefix).

    :param bpe_config: the BPE configuration to validate.
    :raises BpeConfigNotSupported: on the first unsupported setting found.
    """
    # Each predicate is a thunk so that a parameter is only read if every
    # earlier check passed.
    unsupported_features = (
        (lambda: bpe_config.get_param_value(BpeParam.UNICODE) == 'bytes',
         'Byte-BPE is not yet supported'),
        (lambda: bpe_config.get_param_value(BpeParam.WORD_END),
         'BPE with word-end characters are not yet supported'),
        (lambda: bpe_config.get_param_value(BpeParam.CASE) == 'prefix',
         'BPE with case encoded in prefix is not yet supported'),
    )
    for is_unsupported, message in unsupported_features:
        if is_unsupported():
            raise BpeConfigNotSupported(message)
예제 #2
0
def test_true_true_code_bytes(abspath_mock, bpe_learner_mock, dataset_mock):
    """``learn-bpe 1000 -p <path> --bytes --word-end`` forwards the expected configs.

    The mock arguments are presumably injected by patch decorators that are not
    visible here -- TODO confirm.
    """
    # Arrange: the dataset path resolves to the stub, and dataset_mock.create
    # hands back dataset_mock itself.
    abspath_mock.return_value = PATH_TO_DATASET_STUB
    dataset_mock.create = Mock(spec=dataset_mock, return_value=dataset_mock)
    cli_args = [
        'learn-bpe', '1000',
        '-p', PATH_TO_DATASET_STUB,
        '--bytes', '--word-end',
    ]

    # Act
    parse_and_run(cli_args)

    # Assert: both the dataset factory and the learner received the configs
    # implied by the flags.
    expected_prep_config = PrepConfig({
        PrepParam.EN_ONLY: 'u',
        PrepParam.COM: '0',
        PrepParam.STR: 'E',
        PrepParam.SPLIT: 'F',
        PrepParam.TABS_NEWLINES: 's',
        PrepParam.CASE: 'u'
    })
    expected_bpe_config = BpeConfig({
        BpeParam.CASE: 'yes',
        BpeParam.WORD_END: True,
        BpeParam.BASE: 'code',
        BpeParam.UNICODE: 'bytes',
    })
    dataset_mock.create.assert_called_with(
        PATH_TO_DATASET_STUB, expected_prep_config, None, None, expected_bpe_config)
    bpe_learner_mock.run.assert_called_with(dataset_mock, 1000, expected_bpe_config)
예제 #3
0
def test_run_bytes_bpe(mocked_dataset):
    """``run`` rejects a BPE config whose UNICODE parameter is ``'bytes'``."""
    bytes_bpe_config = BpeConfig({
        BpeParam.BASE: 'code',
        BpeParam.WORD_END: False,
        BpeParam.UNICODE: 'bytes',
        BpeParam.CASE: 'yes'
    })

    with pytest.raises(BpeConfigNotSupported):
        run(mocked_dataset, 1, bytes_bpe_config)
예제 #4
0
def get_base_vocab_dir(bpe_list_id: str) -> str:
    """Return the base-vocabulary directory for the BPE list ``bpe_list_id``.

    The base name of ``get_dataset_bpe_dir(bpe_list_id)`` is split at its first
    ``'_-_'``:

    - The part before the separator is kept as-is.
    - The remainder, including the leading ``'_-_'``, or ``''`` when there is no
      separator, is parsed with ``BpeConfig.from_suffix``.

    The result is ``USER_VOCAB_DIR/<prefix>_-_<prep config of that BPE config>``.

    :param bpe_list_id: id of the BPE list whose dataset dir is inspected.
    :return: path of the base vocabulary directory.
    :raises ValueError: if the base name does not fully match the expected
        format (with the pattern used, only a name containing a newline fails).
    """
    dir_name = os.path.basename(get_dataset_bpe_dir(bpe_list_id))
    # TODO: do not hard-code the date and dir format in general.
    dir_match = regex.fullmatch(r'(.*?)((?:_-_.*)?)', dir_name)
    if dir_match is None:
        raise ValueError(f'Invalid dir format: {dir_name}')

    # Both groups always participate in a successful match, so neither is None.
    name_prefix, bpe_suffix = dir_match.groups()
    base_prep_config = BpeConfig.from_suffix(bpe_suffix).to_prep_config()
    return os.path.join(USER_VOCAB_DIR, f'{name_prefix}_-_{base_prep_config}')
예제 #5
0
def test_all_custom(get_timestamp_mock, os_exists_mock):
    """``Dataset.create`` with every optional argument stores them and derives all paths.

    The mock arguments are presumably injected by patch decorators that are not
    visible here. The '01_01_01' stamp likely comes from get_timestamp_mock --
    TODO confirm.
    """
    prep_config = PrepConfig({
        PrepParam.EN_ONLY: 'u',
        PrepParam.COM: 'c',
        PrepParam.STR: '1',
        PrepParam.SPLIT: '0',
        PrepParam.TABS_NEWLINES: 's',
        PrepParam.CASE: 'u'
    })
    bpe_config = BpeConfig({
        BpeParam.CASE: 'yes',
        BpeParam.WORD_END: False,
        BpeParam.BASE: "code",
        BpeParam.UNICODE: "no",
    })
    custom_bpe_config = CustomBpeConfig("id", 1000, "/codes/file", "/cache/file")

    dataset = Dataset.create(
        PATH_TO_DATASET_STUB, prep_config, "c|java", custom_bpe_config, bpe_config,
        overriden_path_to_prep_dataset=OVERRIDDEN_PATH,
    )

    # The constructor arguments are kept; the extension string is split on '|'.
    assert PATH_TO_DATASET_STUB == dataset._path
    assert prep_config == dataset._prep_config
    assert ['c', 'java'] == dataset._normalized_extension_list
    assert custom_bpe_config == dataset._custom_bpe_config
    assert bpe_config == dataset._bpe_config
    assert '01_01_01' == dataset._dataset_last_modified

    # Every derived directory name below starts with this identifier.
    dataset_id = 'dataset_01_01_01_c_java'

    assert SubDataset(dataset, PATH_TO_DATASET_STUB, '') == dataset.original
    parsed_dir = os.path.join(PARSED_DATASETS_DIR, dataset_id)
    assert SubDataset(dataset, parsed_dir, '.parsed') == dataset.parsed
    prep_dir = os.path.join(OVERRIDDEN_PATH, f'{dataset_id}_-_uc10su_id-1000_-_prep')
    assert SubDataset(dataset, prep_dir, '.prep') == dataset.preprocessed

    assert os.path.join(USER_CONFIG_DIR, VOCAB_DIR,
                        f'{dataset_id}_-_U0EFsu') == dataset.base_bpe_vocab_path
    assert os.path.join(USER_CONFIG_DIR, BPE_DIR,
                        f'{dataset_id}_-_nounicode') == dataset.bpe_path
    assert os.path.join(USER_CACHE_DIR, 'file_lists',
                        dataset_id) == dataset.path_to_file_list_folder
    assert os.path.join(USER_CONFIG_DIR, VOCAB_DIR,
                        f'{dataset_id}_-_uc10su_id-1000') == dataset.vocab_path
예제 #6
0
파일: impl.py 프로젝트: mir-am/codeprep
def create_bpe_config_from_args(run_options: Dict[str, str]) -> BpeConfig:
    """Build the ``BpeConfig`` described by parsed command-line options.

    - UNICODE: ``'no'`` with ``--no-unicode``, otherwise ``'bytes'`` with
      ``--bytes``, otherwise ``'yes'``. ``--no-unicode`` wins if both are set.
    - WORD_END: the value of ``--word-end``, passed through unchanged.
    - BASE: ``'java'`` with ``--legacy``, otherwise ``'code'``.
    - CASE: always ``'yes'``.

    Flags are tested for truthiness. ``run_options`` is presumably a docopt
    result whose flag values are booleans -- TODO confirm against the caller.
    """
    # Chained conditional: '--bytes' is only looked up when '--no-unicode' is falsy.
    unicode_mode = ('no' if run_options['--no-unicode']
                    else 'bytes' if run_options['--bytes']
                    else 'yes')
    word_end = run_options["--word-end"]
    base = 'java' if run_options['--legacy'] else 'code'
    return BpeConfig({
        BpeParam.CASE: 'yes',
        BpeParam.WORD_END: word_end,
        BpeParam.BASE: base,
        BpeParam.UNICODE: unicode_mode,
    })
예제 #7
0
def create_new_id_from(path: str, bpe_config: BpeConfig, predefined_bpe_codes_id: Optional[str] = None) -> str:
    """Return an id for a set of custom BPE codes.

    If ``predefined_bpe_codes_id`` is non-empty, it is returned unchanged.

    Otherwise the base id is built from ``os.path.basename(path)`` and, if
    non-empty, ``bpe_config.to_suffix()``, joined with ``'_'``:

    - If the base id is not among the ids returned by
      ``_get_all_custom_bpe_codes_and_max_merges()``, it is returned as-is.
    - Otherwise the result is ``<base id>_<n>``. Here ``n`` is one more than the
      largest ``<digits>`` among existing ids of the exact form
      ``<base id>_<digits>``, or 1 if there are none.

    :param path: path whose base name starts the id.
    :param bpe_config: BPE config whose suffix, if any, is appended.
    :param predefined_bpe_codes_id: explicit id to use instead of generating one.
    :return: the chosen id.
    """
    if predefined_bpe_codes_id:
        return predefined_bpe_codes_id

    name_parts = [os.path.basename(path)]
    id_suffix = bpe_config.to_suffix()
    if id_suffix:
        name_parts.append(id_suffix)
    id_base = '_'.join(name_parts)

    existing_ids = _get_all_custom_bpe_codes_and_max_merges().keys()
    if id_base not in existing_ids:
        return id_base

    # id_base comes from a file-system name and may contain regex metacharacters
    # ('.', '+', '(' ...). Escape it so it is matched literally; otherwise e.g.
    # 'c++' would never match 'c++_1' and duplicate ids would be generated.
    # Compile once instead of per existing id.
    numbered_id_pattern = regex.compile(f'{regex.escape(id_base)}_([0-9]+)')

    def extract_number(full_id: str) -> int:
        # Numeric suffix of '<id_base>_<digits>', or 0 for any other id.
        m = numbered_id_pattern.fullmatch(full_id)
        return int(m[1]) if m else 0

    # id_base itself is in existing_ids at this point, so max() never sees an
    # empty sequence.
    new_number = max(extract_number(existing_id) for existing_id in existing_ids) + 1
    return f'{id_base}_{new_number}'