Example No. 1
    def test_errors(self):
        token_to_freq = {
            'hello': 4,
            'world': 3,
            'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5,
            'freq_too_low': 2
        }
        sorted_by_freq_tuples = sorted(token_to_freq.items(),
                                       key=lambda x: x[1],
                                       reverse=True)
        c = OrderedDict(sorted_by_freq_tuples)

        with self.assertRaises(ValueError):
            # Test proper error raised when setting unk token to None
            vocab(c, unk_token=None)

        with self.assertRaises(RuntimeError):
            # Test proper error raised when setting a token out of bounds
            v = vocab(c, min_freq=3)
            v.insert_token('new_token', 100)

        with self.assertRaises(RuntimeError):
            # Test proper error raised when looking up a token out of bounds
            v = vocab(c)
            v.lookup_token(100)
Example No. 2
    def test_errors_vocab_python(self):
        token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
        sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
        c = OrderedDict(sorted_by_freq_tuples)

        with self.assertRaises(ValueError):
            # Test proper error raised when setting unk token to None
            vocab(c, unk_token=None)
Example No. 3
    def test_has_unk(self):
        c = OrderedDict()
        v = vocab(c)

        # check if unk is mapped to the first index
        self.assertEqual(v['not_in_it'], 0)
        self.assertEqual(v['<unk>'], 0)
Example No. 4
    def test_vocab_jit(self):
        token_to_freq = {
            'hello': 4,
            'world': 3,
            'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5,
            'freq_too_low': 2
        }
        sorted_by_freq_tuples = sorted(token_to_freq.items(),
                                       key=lambda x: x[1],
                                       reverse=True)

        c = OrderedDict(sorted_by_freq_tuples)
        v = vocab(c, min_freq=3)
        jit_v = torch.jit.script(v)

        expected_itos = ['<unk>', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
        expected_stoi = {x: index for index, x in enumerate(expected_itos)}

        assert not v.is_jitable
        # Call the __prepare_scriptable__() func and convert the building block to the torchbind version.
        # We don't expect users to use the torchbind version in eager mode, but we still need a CI test here.
        assert v.__prepare_scriptable__().is_jitable

        self.assertEqual(jit_v.get_itos(), expected_itos)
        self.assertEqual(dict(jit_v.get_stoi()), expected_stoi)
Example No. 5
    def test_new_unk(self):
        c = OrderedDict()
        v = vocab(c, unk_token="<new_unk>")

        # check if new_unk is mapped to the first index
        self.assertEqual(v['<new_unk>'], 0)
        self.assertEqual(v['not_in_it'], 0)
Example No. 6
    def test_vocab_len(self):
        token_to_freq = {'a': 2, 'b': 2, 'c': 2}
        sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
        c = OrderedDict(sorted_by_freq_tuples)
        v = vocab(c)

        self.assertEqual(len(v), 4)
Example No. 7
    def test_vocab_append_token(self):
        c = OrderedDict({'a': 2})
        v = vocab(c)
        v.append_token('b')

        self.assertEqual(len(v), 3)
        self.assertEqual(v['b'], 2)
Example No. 8
    def test_vocab_load_and_save(self):
        token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
        sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)

        c = OrderedDict(sorted_by_freq_tuples)
        v = vocab(c, min_freq=3)

        expected_itos = ['<unk>', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
        expected_stoi = {x: index for index, x in enumerate(expected_itos)}

        self.assertEqual(v.get_itos(), expected_itos)
        self.assertEqual(dict(v.get_stoi()), expected_stoi)

        with self.subTest('pybind'):
            vocab_path = os.path.join(self.test_dir, 'vocab_pybind.pt')
            torch.save(v, vocab_path)
            loaded_v = torch.load(vocab_path)
            self.assertEqual(v.get_itos(), expected_itos)
            self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)

        with self.subTest('torchscript'):
            vocab_path = os.path.join(self.test_dir, 'vocab_torchscript.pt')
            # Call the __prepare_scriptable__() func and convert the building block to the torchbind version.
            # We don't expect users to use the torchbind version in eager mode, but we still need a CI test here.
            torch.save(v.__prepare_scriptable__(), vocab_path)
            loaded_v = torch.load(vocab_path)
            self.assertEqual(v.get_itos(), expected_itos)
            self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
Example No. 9
    def test_vocab_load_and_save(self):
        token_to_freq = {
            'hello': 4,
            'world': 3,
            'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5,
            'freq_too_low': 2
        }
        sorted_by_freq_tuples = sorted(token_to_freq.items(),
                                       key=lambda x: x[1],
                                       reverse=True)

        c = OrderedDict(sorted_by_freq_tuples)
        v = vocab(c, min_freq=3)

        expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world', '<unk>']
        expected_stoi = {x: index for index, x in enumerate(expected_itos)}

        self.assertEqual(v.get_itos(), expected_itos)
        self.assertEqual(dict(v.get_stoi()), expected_stoi)

        vocab_path = os.path.join(self.test_dir, 'vocab.pt')
        torch.save(v, vocab_path)
        loaded_v = torch.load(vocab_path)

        self.assertEqual(v.get_itos(), expected_itos)
        self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
Example No. 10
    def test_vocab_get_item(self):
        token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
        sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
        c = OrderedDict(sorted_by_freq_tuples)
        v = vocab(c, min_freq=2)

        self.assertEqual(v['<unk>'], 0)
        self.assertEqual(v['a'], 1)
        self.assertEqual(v['b'], 2)
Example No. 11
    def test_vocab_append_token(self):
        c = OrderedDict({'a': 2})
        v = vocab(c)
        v.append_token('b')

        expected_itos = ['<unk>', 'a', 'b']
        expected_stoi = {x: index for index, x in enumerate(expected_itos)}

        self.assertEqual(v.get_itos(), expected_itos)
        self.assertEqual(dict(v.get_stoi()), expected_stoi)
Example No. 12
    def test_vocab_membership(self):
        token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
        sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
        c = OrderedDict(sorted_by_freq_tuples)
        v = vocab(c, min_freq=2)

        self.assertTrue('<unk>' in v)
        self.assertTrue('a' in v)
        self.assertTrue('b' in v)
        self.assertFalse('c' in v)
Example No. 13
    def test_vocab_lookup_indices(self):
        token_to_freq = {'a': 2, 'b': 2, 'c': 2}
        sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
        c = OrderedDict(sorted_by_freq_tuples)
        v = vocab(c)

        tokens = ['b', 'a', 'c']
        expected_indices = [2, 1, 3]

        self.assertEqual(v.lookup_indices(tokens), expected_indices)
Example No. 14
def script_vocab(ordered_dict,
                 pad_token=None,
                 bos_token=None,
                 eos_token=None,
                 mask_token=None,
                 **kwargs):

    v = vocab(ordered_dict, **kwargs)
    return ScriptVocab(v.vocab, pad_token, bos_token, eos_token, mask_token,
                       **kwargs)
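
A minimal call sketch for the factory above, assuming ScriptVocab is the scriptable vocab class imported alongside this helper; the special-token strings and the frequency dict are illustrative:

    from collections import OrderedDict

    # Hypothetical token -> frequency mapping; any OrderedDict works here.
    ordered_dict = OrderedDict([('hello', 4), ('world', 3)])
    sv = script_vocab(ordered_dict,
                      pad_token='<pad>',
                      bos_token='<bos>',
                      eos_token='<eos>',
                      mask_token='<mask>')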
Example No. 15
def build_fairseq_vocab(
        vocab_file: str,
        dictionary_class: Dictionary = Dictionary,
        special_token_replacements: Dict[str, str] = None,
        unk_token: str = "<unk>",
        max_vocab: int = -1,
        min_count: int = -1,
        tokens_to_add: Optional[List[str]] = None,
):
    """Function builds a torchtext Vocab for models pre-trained using Fairseq
    modules.
    The dictionary class can take any Fairseq Dictionary class and is
    used to load the vocab file.
    """
    if not special_token_replacements:
        special_token_replacements = {
            "<pad>": "__PAD__",
            "<s>": "__BEGIN_OF_SENTENCE__",
            "</s>": "__END_OF_SENTENCE__",
            "<unk>": "__UNKNOWN__",
            "<mask>": "__MASK__",
        }
    # Compute these outside the `if` block so they are defined even when the
    # caller passes its own special_token_replacements, and iterate over
    # .items() to get (original, replacement) pairs rather than key characters.
    unk_replacement = special_token_replacements.get(unk_token, unk_token)
    special_tokens_to_remove = list(special_token_replacements.keys())
    special_tokens_to_add = tuple(
        replacement for original, replacement in special_token_replacements.items()
        if original != unk_token)

    with open(vocab_file) as f:
        dictionary = dictionary_class.load(f)
        # finalize will sort the dict based on frequency so only do this if
        # a min_count or max_vocab size is specified
        if min_count > 0 or max_vocab > 0:
            dictionary.finalize(threshold=min_count, nwords=max_vocab, padding_factor=1)
        if tokens_to_add:
            for token in tokens_to_add:
                dictionary.add_symbol(token)

        dictionary_items = list(zip(dictionary.symbols, dictionary.count))

        ordered_dict = OrderedDict()
        # add special tokens to beginning of ordered_dict
        for s in special_tokens_to_add:
            ordered_dict[s] = 1

        # add all other tokens from dictionary_items
        for token, freq in dictionary_items:
            ordered_dict[token] = freq

        # remove special_tokens_to_remove from dict
        for s in special_tokens_to_remove:
            if s in ordered_dict:
                del ordered_dict[s]

        return vocab(ordered_dict, unk_token=unk_replacement)
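
A minimal usage sketch for the helper above; 'dict.txt' and '<custom_token>' are placeholders, and the import path assumes the Dictionary class shipped with fairseq:

    from fairseq.data.dictionary import Dictionary

    # 'dict.txt' is a hypothetical fairseq-format dictionary file
    # (one "<token> <count>" pair per line).
    v = build_fairseq_vocab('dict.txt',
                            dictionary_class=Dictionary,
                            max_vocab=50000,
                            min_count=5,
                            tokens_to_add=['<custom_token>'])
    # Special tokens are replaced with their readable forms, e.g. <s> -> __BEGIN_OF_SENTENCE__.
    print(v.lookup_indices(['__BEGIN_OF_SENTENCE__', '<custom_token>']))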
Example No. 16
    def __init__(self, sp_model):
        super(PretrainedSPVocab, self).__init__()
        self.sp_model = sp_model
        unk_id = self.sp_model.unk_id()
        unk_token = self.sp_model.IdToPiece(unk_id)
        vocab_list = [
            self.sp_model.IdToPiece(i)
            for i in range(self.sp_model.GetPieceSize())
        ]
        self.vocab = vocab(OrderedDict([(token, 1) for token in vocab_list]),
                           unk_token=unk_token)
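
A short construction sketch for the snippet above, assuming PretrainedSPVocab is the module whose __init__ is shown here and that a trained SentencePiece model file is available ('m_user.model' is a placeholder path):

    import sentencepiece as spm

    sp_model = spm.SentencePieceProcessor()
    sp_model.Load('m_user.model')  # placeholder path to a pretrained SentencePiece model
    sp_vocab = PretrainedSPVocab(sp_model)

    # Every SentencePiece piece is mapped to an index by the wrapped torchtext vocab;
    # pieces not in the model fall back to its unk piece.
    print(sp_vocab.vocab(['▁hello', 'not_a_piece']))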
Example No. 17
    def test_vocab_basic(self):
        token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
        sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)

        c = OrderedDict(sorted_by_freq_tuples)
        v = vocab(c, min_freq=3)

        expected_itos = ['<unk>', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
        expected_stoi = {x: index for index, x in enumerate(expected_itos)}

        self.assertEqual(v.get_itos(), expected_itos)
        self.assertEqual(dict(v.get_stoi()), expected_stoi)
Example No. 18
    def test_vocab_insert_token(self):
        c = OrderedDict({'<unk>': 2, 'a': 2})

        # add item to end
        v = vocab(c)
        v.insert_token('b', 2)

        expected_itos = ['<unk>', 'a', 'b']
        expected_stoi = {x: index for index, x in enumerate(expected_itos)}

        self.assertEqual(v.get_itos(), expected_itos)
        self.assertEqual(dict(v.get_stoi()), expected_stoi)

        # add item to middle
        v = vocab(c)
        v.insert_token('b', 0)

        expected_itos = ['b', '<unk>', 'a']
        expected_stoi = {x: index for index, x in enumerate(expected_itos)}

        self.assertEqual(v.get_itos(), expected_itos)
        self.assertEqual(dict(v.get_stoi()), expected_stoi)
Example No. 19
    def test_vocab_forward(self):
        token_to_freq = {'a': 2, 'b': 2, 'c': 2}
        sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)

        c = OrderedDict(sorted_by_freq_tuples)
        v = vocab(c)
        jit_v = torch.jit.script(v)

        tokens = ['b', 'a', 'c']
        expected_indices = [2, 1, 3]

        self.assertEqual(v(tokens), expected_indices)
        self.assertEqual(jit_v(tokens), expected_indices)
Example No. 20
    def test_vocab_jit(self):
        token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
        sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)

        c = OrderedDict(sorted_by_freq_tuples)
        v = vocab(c, min_freq=3)
        jit_v = torch.jit.script(v.to_ivalue())

        expected_itos = ['<unk>', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
        expected_stoi = {x: index for index, x in enumerate(expected_itos)}

        assert not v.is_jitable
        assert v.to_ivalue().is_jitable

        self.assertEqual(jit_v.get_itos(), expected_itos)
        self.assertEqual(dict(jit_v.get_stoi()), expected_stoi)