Code example #1
File: utils.py Project: qingerVT/translate
def dummy_dictionary(dummy_tokens=3, additional_token_list=None):
    """First adds the amount of dummy_tokens that you specify, then
    finally the additional_token_list, which is a list of string token values"""
    d = pytorch_translate_dictionary.Dictionary()
    for i in range(dummy_tokens):
        token = f"token_{i}"
        d.add_symbol(token)
    if additional_token_list is not None:
        for token in additional_token_list:
            d.add_symbol(token)
    d.finalize(padding_factor=-1)
    return d
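A minimal usage sketch for the helper above; the pytorch_translate_dictionary alias is an assumption about how the surrounding utils.py imports the dictionary module, and the token values are placeholders.

# Assumed import alias; the excerpt above relies on it being defined elsewhere in utils.py.
from pytorch_translate import dictionary as pytorch_translate_dictionary

d = dummy_dictionary(dummy_tokens=3, additional_token_list=["hello", "world"])
# The result holds the built-in special symbols plus 3 dummy tokens and 2 additional tokens.
print(len(d), d.nspecial)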
Code example #2
File: utils.py Project: qingerVT/translate
def create_vocab_dictionaries():
    additional_special_tokens = (
        vocab_constants.MAX_SPECIAL_TOKENS -
        pytorch_translate_dictionary.Dictionary().nspecial)
    src_dict = dummy_dictionary(
        dummy_tokens=additional_special_tokens,
        additional_token_list=["a", "b", "c", "d", "e"],
    )
    tgt_dict = dummy_dictionary(
        dummy_tokens=additional_special_tokens,
        additional_token_list=["A", "B", "C", "D", "E"],
    )
    return src_dict, tgt_dict
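A hedged usage sketch; both imports below are assumptions about the pytorch_translate package layout that this excerpt relies on.

# Assumed imports; the function above expects both names to be in scope.
from pytorch_translate import dictionary as pytorch_translate_dictionary
from pytorch_translate import vocab_constants

src_dict, tgt_dict = create_vocab_dictionaries()
# Each dictionary fills up to MAX_SPECIAL_TOKENS reserved entries, then adds five real tokens.
print(len(src_dict), len(tgt_dict))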
Code example #3
def build_vocab_from_corpus(
    corpus_file: str,
    dialect: str,
    save_dir: str,
    max_vocab_size: int,
    tokens_with_penalty: Optional[str] = None,
):
    vocab_file = os.path.join(save_dir, f'dictionary-{dialect}.txt')
    d = pytorch_translate_dictionary.Dictionary()
    with open(corpus_file, 'r') as f:
        for line in f:
            tokens = line.split()
            for t in tokens:
                token_index = d.add_symbol(t)

    # Set indices to receive penalty
    if tokens_with_penalty:
        # Assume input tokens are unique
        lexicon = []
        with open(tokens_with_penalty, 'r', encoding='utf-8') as f:
            for line in f:
                tokens = line.strip().split()
                if len(tokens) == 1:
                    lexicon.append(tokens[0])

        for token, token_index in d.indices.items():
            if token in lexicon:
                d.lexicon_indices.add(token_index)

    d.finalize()
    d.save(vocab_file, threshold=0, nwords=max_vocab_size)
    print(f'Generated new vocab file saved at {vocab_file}.')
    if max_vocab_size < 0:
        print('No maximum vocab size enforced.')
    else:
        print(f'Maximum vocab size {max_vocab_size}')

    return vocab_file
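A self-contained invocation sketch for build_vocab_from_corpus: it writes a tiny throwaway corpus to a temporary directory and builds a vocab file from it. The dialect name and corpus contents are placeholders, and the function itself still needs the os/typing imports and the pytorch_translate_dictionary alias from its original module.

import os
import tempfile

with tempfile.TemporaryDirectory() as tmp_dir:
    corpus_path = os.path.join(tmp_dir, 'corpus.en')
    with open(corpus_path, 'w') as f:
        f.write('hello world\nhello again\n')

    vocab_file = build_vocab_from_corpus(
        corpus_file=corpus_path,
        dialect='en',
        save_dir=tmp_dir,
        max_vocab_size=-1,  # negative value: no vocab size cap
    )
    print(vocab_file)  # .../dictionary-en.txt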
Code example #4
def build_vocab_from_corpus(
    corpus_file: str,
    dialect: str,
    save_dir: str,
    max_vocab_size: int,
):
    vocab_file = os.path.join(save_dir, f'dictionary-{dialect}.txt')
    d = pytorch_translate_dictionary.Dictionary()
    with open(corpus_file, 'r') as f:
        for line in f:
            tokens = line.split()
            for t in tokens:
                token_index = d.add_symbol(t)

    d.finalize()
    d.save(vocab_file, threshold=0, nwords=max_vocab_size)
    print(f'Generated new vocab file saved at {vocab_file}.')
    if max_vocab_size < 0:
        print('No maximum vocab size enforced.')
    else:
        print(f'Maximum vocab size {max_vocab_size}')

    return vocab_file
Code example #5
File: data.py Project: planb-hakone/translate
def load_binarized_dataset(
    train_corpus: ParallelCorpusConfig,
    eval_corpus: ParallelCorpusConfig,
    train_split: str,
    eval_split: str,
    args: argparse.Namespace,
    use_char_source: bool = False,
) -> data.LanguageDatasets:
    if is_multilingual(args):  # Dummy dictionaries
        source_dict = pytorch_translate_dictionary.Dictionary()
        target_dict = pytorch_translate_dictionary.Dictionary()
    else:
        source_dict = pytorch_translate_dictionary.Dictionary.load(
            args.source_vocab_file)
        target_dict = pytorch_translate_dictionary.Dictionary.load(
            args.target_vocab_file)

    if use_char_source:
        char_source_dict = pytorch_translate_dictionary.Dictionary.load(
            args.char_source_vocab_file)
        # this attribute is used for CharSourceModel construction
        args.char_source_dict_size = len(char_source_dict)

    dataset = data.LanguageDatasets(
        src=train_corpus.source.dialect,
        dst=train_corpus.target.dialect,
        src_dict=source_dict,
        dst_dict=target_dict,
    )

    for split, corpus in [(train_split, train_corpus),
                          (eval_split, eval_corpus)]:
        if not os.path.exists(corpus.source.data_file):
            raise ValueError(
                f"{corpus.source.data_file} for {split} not found!")
        if not os.path.exists(corpus.target.data_file):
            raise ValueError(
                f"{corpus.target.data_file} for {split} not found!")

        dst_dataset = InMemoryNumpyDataset.create_from_file(
            corpus.target.data_file)
        weights_dataset = None
        if corpus.weights_file and os.path.exists(corpus.weights_file):
            weights_dataset = weighted_data.IndexedWeightsDataset(
                corpus.weights_file)
            assert len(dst_dataset) == len(weights_dataset)

        if use_char_source:
            src_dataset = char_data.InMemoryNumpyWordCharDataset.create_from_file(
                corpus.source.data_file)
            dataset.splits[split] = char_data.LanguagePairSourceCharDataset(
                src=src_dataset,
                dst=dst_dataset,
                pad_idx=source_dict.pad(),
                eos_idx=source_dict.eos(),
                weights=weights_dataset,
            )
        else:
            src_dataset = InMemoryNumpyDataset.create_from_file(
                corpus.source.data_file)
            dataset.splits[split] = weighted_data.WeightedLanguagePairDataset(
                src=src_dataset,
                dst=dst_dataset,
                pad_idx=source_dict.pad(),
                eos_idx=source_dict.eos(),
                weights=weights_dataset,
            )

    return dataset
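A short sketch tying this loader back to the earlier examples: the non-multilingual branch simply reloads vocab files like the ones written by build_vocab_from_corpus, using Dictionary.load. The import alias and file paths below are assumptions.

from pytorch_translate import dictionary as pytorch_translate_dictionary

# Hypothetical paths, e.g. files produced by build_vocab_from_corpus above.
source_vocab_file = 'dictionary-en.txt'
target_vocab_file = 'dictionary-de.txt'

source_dict = pytorch_translate_dictionary.Dictionary.load(source_vocab_file)
target_dict = pytorch_translate_dictionary.Dictionary.load(target_vocab_file)
print(len(source_dict), source_dict.pad(), source_dict.eos())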
Code example #6
    def test_base(self):
        d = dictionary.Dictionary()
        self.assertEqual(len(d), d.nspecial)
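A slightly extended check in the same style, assuming the unittest.TestCase context the original method lives in; it reuses add_symbol as shown in the earlier examples.

    def test_add_symbol(self):
        d = dictionary.Dictionary()
        d.add_symbol("token_0")
        # One new non-special symbol on top of the built-in specials.
        self.assertEqual(len(d), d.nspecial + 1)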