Example #1
File: test_vocab.py Project: zsn-life/text
    def test_vocab_load_and_save(self):
        token_to_freq = {
            'hello': 4,
            'world': 3,
            'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5,
            'freq_too_low': 2
        }
        sorted_by_freq_tuples = sorted(token_to_freq.items(),
                                       key=lambda x: x[1],
                                       reverse=True)

        c = OrderedDict(sorted_by_freq_tuples)
        v = Vocab(c, min_freq=3)

        expected_itos = ['<unk>', '<pad>', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
        expected_stoi = {x: index for index, x in enumerate(expected_itos)}

        self.assertEqual(v.get_itos(), expected_itos)
        self.assertEqual(dict(v.get_stoi()), expected_stoi)

        vocab_path = os.path.join(self.test_dir, 'vocab.pt')
        torch.save(v, vocab_path)
        loaded_v = torch.load(vocab_path)

        self.assertEqual(loaded_v.get_itos(), expected_itos)
        self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
Example #2
File: test_vocab.py Project: zsn-life/text
    def test_vocab_append_token(self):
        c = OrderedDict({'a': 2})
        v = Vocab(c)
        v.append_token('b')

        self.assertEqual(len(v), 4)
        self.assertEqual(v['b'], 3)
Example #3
File: test_vocab.py Project: zsn-life/text
    def test_vocab_lookup_token(self):
        token_to_freq = {'a': 2, 'b': 2, 'c': 2}
        sorted_by_freq_tuples = sorted(token_to_freq.items(),
                                       key=lambda x: x[1],
                                       reverse=True)
        c = OrderedDict(sorted_by_freq_tuples)
        v = Vocab(c, specials_first=False)

        self.assertEqual(v.lookup_token(0), 'a')
Example #4
    def test_errors(self):
        token_to_freq = {
            'hello': 4,
            'world': 3,
            'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5,
            'freq_too_low': 2
        }
        sorted_by_freq_tuples = sorted(token_to_freq.items(),
                                       key=lambda x: x[1],
                                       reverse=True)
        c = OrderedDict(sorted_by_freq_tuples)

        with self.assertRaises(ValueError):
            # Test proper error raised when setting unk token to None
            Vocab(c, unk_token=None)

        with self.assertRaises(RuntimeError):
            # Test proper error raised when setting a token out of bounds
            v = Vocab(c, min_freq=3)
            v.insert_token('new_token', 100)

        with self.assertRaises(RuntimeError):
            # Test proper error raised when looking up a token out of bounds
            v = Vocab(c)
            v.lookup_token(100)
Example #5
    def __init__(self, spm_file):
        super(PretrainedSPVocab, self).__init__()
        self.sp_model = load_sp_model(spm_file)
        unk_id = self.sp_model.unk_id()
        unk_token = self.sp_model.IdToPiece(unk_id)
        vocab_list = [
            self.sp_model.IdToPiece(i)
            for i in range(self.sp_model.GetPieceSize())
        ]
        self.vocab = Vocab(OrderedDict([(token, 1) for token in vocab_list]),
                           unk_token=unk_token)
Example #6
File: test_vocab.py Project: zsn-life/text
    def test_vocab_lookup_indices(self):
        token_to_freq = {'a': 2, 'b': 2, 'c': 2}
        sorted_by_freq_tuples = sorted(token_to_freq.items(),
                                       key=lambda x: x[1],
                                       reverse=True)
        c = OrderedDict(sorted_by_freq_tuples)
        v = Vocab(c, specials_first=False)

        tokens = ['b', 'a', 'c']
        expected_indices = [1, 0, 2]

        self.assertEqual(v.lookup_indices(tokens), expected_indices)
Example #7
File: test_vocab.py Project: zsn-life/text
    def test_has_unk(self):
        c = OrderedDict({})
        v = Vocab(c)

        # check if unk is mapped to the first index
        self.assertEqual(v['not_in_it'], 0)
        self.assertEqual(v['<unk>'], 0)
Example #8
File: test_vocab.py Project: zsn-life/text
    def test_new_unk(self):
        c = OrderedDict({})
        v = Vocab(c, specials=('<new_unk>', ), unk_token="<new_unk>")

        # check if new_unk is mapped to the first index
        self.assertEqual(v['<new_unk>'], 0)
        self.assertEqual(v['not_in_it'], 0)
Example #9
File: test_vocab.py Project: zsn-life/text
    def test_vocab_len(self):
        token_to_freq = {'a': 2, 'b': 2, 'c': 2}
        sorted_by_freq_tuples = sorted(token_to_freq.items(),
                                       key=lambda x: x[1],
                                       reverse=True)
        c = OrderedDict(sorted_by_freq_tuples)
        v = Vocab(c)

        self.assertEqual(len(v), 5)
Example #10
    def test_vocab_basic(self):
        token_to_freq = {
            'hello': 4,
            'world': 3,
            'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5,
            'freq_too_low': 2
        }
        sorted_by_freq_tuples = sorted(token_to_freq.items(),
                                       key=lambda x: x[1],
                                       reverse=True)

        c = OrderedDict(sorted_by_freq_tuples)
        v = Vocab(c, min_freq=3)

        expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world', '<unk>']
        expected_stoi = {x: index for index, x in enumerate(expected_itos)}

        self.assertEqual(v.get_itos(), expected_itos)
        self.assertEqual(dict(v.get_stoi()), expected_stoi)
Example #11
def make_vocab(text, vocab_size):
    counts = collections.Counter()
    for tokens in text:
        for token in tokens:
            counts[token] += 1
    # Give '<pad>' and '<unk>' counts above the highest observed frequency so
    # they sort to the first two indices of the vocabulary.
    _, max_count = counts.most_common(1)[0]
    counts['<pad>'] = max_count + 2
    counts['<unk>'] = max_count + 1
    vocab = Vocab(collections.OrderedDict(counts.most_common(vocab_size)))
    return vocab
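A minimal usage sketch for make_vocab; the toy corpus is invented, and it assumes a Vocab version that accepts '<pad>'/'<unk>' inside the ordered dict (as this project's snippet does):

import collections

text = [['the', 'cat', 'sat'], ['the', 'dog', 'sat']]  # list of token lists
vocab = make_vocab(text, vocab_size=10)
# '<pad>' and '<unk>' were given the highest counts, so they occupy the
# first two indices of the vocabulary.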
Example #12
    def test_vocab_get_item(self):
        token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
        sorted_by_freq_tuples = sorted(token_to_freq.items(),
                                       key=lambda x: x[1],
                                       reverse=True)
        c = OrderedDict(sorted_by_freq_tuples)
        v = Vocab(c, min_freq=2)

        self.assertEqual(v['<unk>'], 0)
        self.assertEqual(v['a'], 1)
        self.assertEqual(v['b'], 2)
Example #13
class PretrainedSPVocab(nn.Module):
    r"""Vocab based on a pretained sentencepiece model
    """
    def __init__(self, spm_file):
        super(PretrainedSPVocab, self).__init__()
        self.sp_model = load_sp_model(spm_file)
        unk_id = self.sp_model.unk_id()
        unk_token = self.sp_model.IdToPiece(unk_id)
        vocab_list = [
            self.sp_model.IdToPiece(i)
            for i in range(self.sp_model.GetPieceSize())
        ]
        self.vocab = Vocab(OrderedDict([(token, 1) for token in vocab_list]),
                           unk_token=unk_token)

    def forward(self, tokens: List[str]) -> List[int]:
        return self.vocab.lookup_indices(tokens)

    def insert_token(self, token: str, index: int) -> None:
        self.vocab.insert_token(token, index)
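A brief usage sketch for PretrainedSPVocab; the model filename and the sample pieces are hypothetical:

sp_vocab = PretrainedSPVocab('spm_user.model')  # path to a sentencepiece model file
indices = sp_vocab(['▁Hello', '▁world'])        # forward() delegates to lookup_indices
sp_vocab.insert_token('<pad>', 0)               # e.g. reserve index 0 for padding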
Example #14
File: test_vocab.py Project: zsn-life/text
    def test_vocab_set_item(self):
        c = OrderedDict({'a': 2})

        # add item to end
        v = Vocab(c)
        v.insert_token('b', 3)

        self.assertEqual(v['<unk>'], 0)
        self.assertEqual(v['<pad>'], 1)
        self.assertEqual(v['a'], 2)
        self.assertEqual(v['b'], 3)

        # add item to middle
        v = Vocab(c, specials_first=False)
        v.insert_token('b', 0)

        self.assertEqual(v['b'], 0)
        self.assertEqual(v['a'], 1)
        self.assertEqual(v['<unk>'], 2)
        self.assertEqual(v['<pad>'], 3)
Example #15
    def test_vocab_specials_order(self):
        token_to_freq = {'a': 2, 'b': 2, 'c': 2}
        sorted_by_freq_tuples = sorted(token_to_freq.items(),
                                       key=lambda x: x[1],
                                       reverse=True)
        c = OrderedDict(sorted_by_freq_tuples)

        # add specials into vocabulary at first
        v = Vocab(c, specials=['<pad>', '<unk>'])
        expected_itos = ['<pad>', '<unk>', 'a', 'b', 'c']
        expected_stoi = {x: index for index, x in enumerate(expected_itos)}

        self.assertEqual(v.get_itos(), expected_itos)
        self.assertEqual(dict(v.get_stoi()), expected_stoi)

        # add specials into vocabulary at last
        v = Vocab(c, specials=['<pad>', '<unk>'], specials_first=False)
        expected_itos = ['a', 'b', 'c', '<pad>', '<unk>']
        expected_stoi = {x: index for index, x in enumerate(expected_itos)}

        self.assertEqual(v.get_itos(), expected_itos)
        self.assertEqual(dict(v.get_stoi()), expected_stoi)
Example #16
    def test_vocab_insert_token(self):
        c = OrderedDict({'<unk>': 2, 'a': 2})

        # add item to end
        v = Vocab(c)
        v.insert_token('b', 2)

        expected_itos = ['<unk>', 'a', 'b']
        expected_stoi = {x: index for index, x in enumerate(expected_itos)}

        self.assertEqual(v.get_itos(), expected_itos)
        self.assertEqual(dict(v.get_stoi()), expected_stoi)

        # add item to middle
        v = Vocab(c)
        v.insert_token('b', 0)

        expected_itos = ['b', '<unk>', 'a']
        expected_stoi = {x: index for index, x in enumerate(expected_itos)}

        self.assertEqual(v.get_itos(), expected_itos)
        self.assertEqual(dict(v.get_stoi()), expected_stoi)
Example #17
File: test_vocab.py Project: zsn-life/text
    def test_errors(self):
        token_to_freq = {
            'hello': 4,
            'world': 3,
            'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5,
            'freq_too_low': 2
        }
        sorted_by_freq_tuples = sorted(token_to_freq.items(),
                                       key=lambda x: x[1],
                                       reverse=True)
        c = OrderedDict(sorted_by_freq_tuples)

        with self.assertRaises(ValueError):
            # Test proper error raised when setting unk token to None
            Vocab(c, specials=['<unk>', '<bos>'], unk_token=None)

        with self.assertRaises(ValueError):
            # Test proper error raised when specials token doesn't contain unk_token
            Vocab(c, specials=['<pad>', '<bos>'])

        with self.assertRaises(ValueError):
            # Test proper error raised when ordered_dict contains a special token
            updated_token_to_freq = {
                'hello': 4,
                'world': 3,
                'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5,
                'freq_too_low': 2,
                '<pad>': 1
            }
            updated_sorted_by_freq_tuples = sorted(
                updated_token_to_freq.items(),
                key=lambda x: x[1],
                reverse=True)
            updated_c = OrderedDict(updated_sorted_by_freq_tuples)
            Vocab(updated_c, specials=['<unk>', '<pad>', '<bos>'])

        with self.assertRaises(RuntimeError):
            # Test proper error raised when setting a token out of bounds
            v = Vocab(c, min_freq=3)
            v.insert_token('new_token', 100)

        with self.assertRaises(RuntimeError):
            # Test proper error raised when looking up a token out of bounds
            v = Vocab(c)
            v.lookup_token(100)
Example #18
def build_fairseq_vocab(
    vocab_file: str,
    dictionary_class: Dictionary = Dictionary,
    special_token_replacements: Dict[str, str] = None,
    unk_token: str = "<unk>",
    max_vocab: int = -1,
    min_count: int = -1,
    tokens_to_add: Optional[List[str]] = None,
):
    """Function builds a torchtext Vocab for models pre-trained using Fairseq
    modules.

    The dictionary class can take any Fairseq Dictionary class and is
    used to load the vocab file.

    """
    if not special_token_replacements:
        special_token_replacements = {
            "<pad>": "__PAD__",
            "<s>": "__BEGIN_OF_SENTENCE__",
            "</s>": "__END_OF_SENTENCE__",
            "<unk>": "__UNKNOWN__",
            "<mask>": "__MASK__",
        }
    # Compute these outside the `if` block so they are also defined when the
    # caller supplies its own replacement map.
    unk_replacement = special_token_replacements.get(unk_token, unk_token)
    special_tokens_to_remove = list(special_token_replacements)
    special_tokens_to_add = tuple(
        replacement for token, replacement in special_token_replacements.items()
        if token != unk_token)

    with open(vocab_file) as f:
        dictionary = dictionary_class.load(f)
        # finalize will sort the dict based on frequency so only do this if
        # a min_count or max_vocab size is specified
        if min_count > 0 or max_vocab > 0:
            dictionary.finalize(threshold=min_count,
                                nwords=max_vocab,
                                padding_factor=1)
        if tokens_to_add:
            for token in tokens_to_add:
                dictionary.add_symbol(token)

        dictionary_items = list(zip(dictionary.symbols, dictionary.count))

        ordered_dict = OrderedDict()
        # add special tokens to beginning of ordered_dict
        for s in special_tokens_to_add:
            ordered_dict[s] = 1

        # add all other tokens from dictionary_items
        for token, freq in dictionary_items:
            ordered_dict[token] = freq

        # remove special_tokens_to_remove from dict
        for s in special_tokens_to_remove:
            if s in ordered_dict:
                del ordered_dict[s]

        return Vocab(ordered_dict, unk_token=unk_replacement)
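A hedged example of calling build_fairseq_vocab; the dictionary filename and the pruning limits are placeholders, not values from the original code:

vocab = build_fairseq_vocab(
    'fairseq_dict.txt',  # hypothetical Fairseq dictionary file
    max_vocab=50000,     # keep at most the 50k most frequent entries
    min_count=5,         # drop tokens seen fewer than 5 times
)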