Ejemplo n.º 1
0
    def test_vocab_jit(self):
        """TorchScript-compile a vocab (min_freq=3) and check itos/stoi survive.

        'freq_too_low' (freq 2) must be dropped; the rest are ordered by
        descending frequency.
        """
        token_to_freq = {
            'hello': 4,
            'world': 3,
            'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5,
            'freq_too_low': 2
        }
        sorted_by_freq_tuples = sorted(token_to_freq.items(),
                                       key=lambda x: x[1],
                                       reverse=True)

        c = OrderedDict(sorted_by_freq_tuples)
        v = vocab(c, min_freq=3)
        jit_v = torch.jit.script(v)

        expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
        expected_stoi = {x: index for index, x in enumerate(expected_itos)}

        # The eager-mode vocab is not jitable as-is.
        assert not v.is_jitable
        # Call the __prepare_scriptable__() func and convert the building block
        # to the torchbind version. Users are not expected to use the torchbind
        # version in eager mode, but a CI test is still needed here.
        assert v.__prepare_scriptable__().is_jitable

        self.assertEqual(jit_v.get_itos(), expected_itos)
        self.assertEqual(dict(jit_v.get_stoi()), expected_stoi)
Ejemplo n.º 2
0
    def test_vocab_load_and_save(self):
        """Round-trip a vocab through torch.save/torch.load (pybind and torchbind).

        Verifies itos/stoi and the default index on the *reloaded* object.
        """
        token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
        sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)

        c = OrderedDict(sorted_by_freq_tuples)
        v = vocab(c, min_freq=3)
        v.set_default_index(0)

        expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
        expected_stoi = {x: index for index, x in enumerate(expected_itos)}

        self.assertEqual(v.get_itos(), expected_itos)
        self.assertEqual(dict(v.get_stoi()), expected_stoi)

        with self.subTest('pybind'):
            vocab_path = os.path.join(self.test_dir, 'vocab_pybind.pt')
            torch.save(v, vocab_path)
            loaded_v = torch.load(vocab_path)
            # BUG FIX: assert on the loaded vocab, not the in-memory original —
            # otherwise serialization is never actually verified.
            self.assertEqual(loaded_v.get_itos(), expected_itos)
            self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
            # The default index must survive the round-trip too.
            self.assertEqual(loaded_v['not in vocab'], 0)

        with self.subTest('torchscript'):
            vocab_path = os.path.join(self.test_dir, 'vocab_torchscript.pt')
            # Call the __prepare_scriptable__() func and convert the building
            # block to the torchbind version. Users are not expected to use the
            # torchbind version in eager mode, but a CI test is still needed.
            torch.save(v.__prepare_scriptable__(), vocab_path)
            loaded_v = torch.load(vocab_path)
            self.assertEqual(loaded_v.get_itos(), expected_itos)
            self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
            self.assertEqual(loaded_v['not in vocab'], 0)
Ejemplo n.º 3
0
    def test_vocab_len(self):
        """len(vocab) equals the number of tokens it was built from."""
        freqs = {'a': 2, 'b': 2, 'c': 2}
        ordered = OrderedDict(sorted(freqs.items(), key=lambda kv: kv[1], reverse=True))
        v = vocab(ordered)
        self.assertEqual(len(v), 3)
Ejemplo n.º 4
0
 def __init__(self, sp_model):
     """Wrap a SentencePiece model, building a vocab over all its pieces.

     Args:
         sp_model: a loaded SentencePiece processor — assumed to provide
             unk_id(), IdToPiece() and GetPieceSize().
     """
     super().__init__()  # modern zero-arg super(); same MRO behavior as before
     self.sp_model = sp_model
     unk_id = self.sp_model.unk_id()
     unk_token = self.sp_model.IdToPiece(unk_id)
     # Frequency 1 is a placeholder: ordering comes from piece ids, not counts.
     vocab_list = [self.sp_model.IdToPiece(i) for i in range(self.sp_model.GetPieceSize())]
     self.vocab = vocab(OrderedDict([(token, 1) for token in vocab_list]), unk_token=unk_token)
Ejemplo n.º 5
0
def data_process(raw_text_iter):
    # Tokenize each raw string, map tokens to ids via the module-level
    # vocab/tokenizer, then concatenate every non-empty tensor into one.
    tensors = (torch.tensor(vocab(tokenizer(line)), dtype=torch.long)
               for line in raw_text_iter)
    # numel() gives the element count; empty texts are discarded before cat.
    non_empty = [t for t in tensors if t.numel() > 0]
    return torch.cat(tuple(non_empty))
Ejemplo n.º 6
0
 def test_default_index_jit(self):
     """A scripted vocab returns the default index for out-of-vocab lookups."""
     freqs = {'<unk>': 2, 'a': 2, 'b': 2}
     ordered = OrderedDict(sorted(freqs.items(), key=lambda kv: kv[1], reverse=True))
     v = vocab(ordered, min_freq=2)
     v.set_default_index(0)
     scripted = torch.jit.script(v)
     self.assertEqual(scripted['not in vocab'], 0)
Ejemplo n.º 7
0
    def test_vocab_lookup_token(self):
        """lookup_token maps index -> token; an out-of-range index raises."""
        freqs = {'a': 2, 'b': 2, 'c': 2}
        ordered = OrderedDict(sorted(freqs.items(), key=lambda kv: kv[1], reverse=True))
        v = vocab(ordered)

        self.assertEqual(v.lookup_token(1), 'b')
        with self.assertRaises(RuntimeError):
            v.lookup_token(100)
Ejemplo n.º 8
0
    def test_vocab_get_item(self):
        """__getitem__ returns each token's index in insertion order."""
        freqs = {'<unk>': 2, 'a': 2, 'b': 2}
        ordered = OrderedDict(sorted(freqs.items(), key=lambda kv: kv[1], reverse=True))
        v = vocab(ordered, min_freq=2)

        for expected_index, token in enumerate(['<unk>', 'a', 'b']):
            self.assertEqual(v[token], expected_index)
Ejemplo n.º 9
0
def script_vocab(ordered_dict,
                 pad_token=None,
                 bos_token=None,
                 eos_token=None,
                 mask_token=None,
                 **kwargs):
    """Build a ScriptVocab from an ordered token -> frequency mapping.

    Args:
        ordered_dict: mapping of token to frequency, in the desired order.
        pad_token, bos_token, eos_token, mask_token: optional special tokens
            passed positionally to ScriptVocab.
        **kwargs: extra options. NOTE(review): forwarded to BOTH the vocab
            factory and ScriptVocab — confirm both accept the same keywords.
    """
    v = vocab(ordered_dict, **kwargs)
    return ScriptVocab(v.vocab, pad_token, bos_token, eos_token, mask_token,
                       **kwargs)
Ejemplo n.º 10
0
    def test_vocab_lookup_indices(self):
        """lookup_indices maps a list of tokens to their indices in one call."""
        freqs = {'a': 2, 'b': 2, 'c': 2}
        ordered = OrderedDict(sorted(freqs.items(), key=lambda kv: kv[1], reverse=True))
        v = vocab(ordered)

        self.assertEqual(v.lookup_indices(['b', 'a', 'c']), [1, 0, 2])
Ejemplo n.º 11
0
    def test_vocab_membership(self):
        """`in` works on a vocab: built tokens are members, others are not."""
        token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
        sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
        c = OrderedDict(sorted_by_freq_tuples)
        v = vocab(c, min_freq=2)

        # assertIn/assertNotIn give clearer failure messages than
        # assertTrue/assertFalse on a bare `in` expression.
        self.assertIn('<unk>', v)
        self.assertIn('a', v)
        self.assertIn('b', v)
        self.assertNotIn('c', v)
Ejemplo n.º 12
0
def build_vocab(dataset, tokenizer, use_padding):
    """Build a vocab over the dataset's texts, with <unk> at index 0.

    When use_padding is true, a <pad> token is appended before <unk> is
    inserted. Out-of-vocab lookups fall back to index 0 (<unk>).
    """
    counter = Counter()
    for sample_idx in range(len(dataset)):
        # Each dataset item is (text, ...); only the text is tokenized.
        counter.update(tokenizer(dataset[sample_idx][0]))
    built = vocab(counter)
    if use_padding:
        built.append_token("<pad>")
    built.insert_token("<unk>", 0)
    built.set_default_index(0)
    return built
Ejemplo n.º 13
0
    def test_default_index(self):
        """Without a default index OOV lookups raise; with one they return it."""
        freqs = {'<unk>': 2, 'a': 2, 'b': 2}
        ordered = OrderedDict(sorted(freqs.items(), key=lambda kv: kv[1], reverse=True))
        v = vocab(ordered, min_freq=2)

        # No default set yet: get_default_index is None and OOV raises.
        self.assertIsNone(v.get_default_index())
        with self.assertRaises(RuntimeError):
            v['not in vocab']

        v.set_default_index(0)
        self.assertEqual(v['not in vocab'], 0)
Ejemplo n.º 14
0
    def test_vocab_basic(self):
        """min_freq filters low-frequency tokens; order follows frequency."""
        freqs = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
        ordered = OrderedDict(sorted(freqs.items(), key=lambda kv: kv[1], reverse=True))
        v = vocab(ordered, min_freq=3)

        # 'freq_too_low' (freq 2) is excluded by min_freq=3.
        expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
        expected_stoi = {tok: idx for idx, tok in enumerate(expected_itos)}

        self.assertEqual(v.get_itos(), expected_itos)
        self.assertEqual(dict(v.get_stoi()), expected_stoi)
Ejemplo n.º 15
0
    def test_vocab_insert_token(self):
        """insert_token places a token at the requested index (end or middle)."""
        base = OrderedDict({'<unk>': 2, 'a': 2})

        # (insert position, resulting itos) — end insertion, then middle.
        cases = (
            (2, ['<unk>', 'a', 'b']),
            (0, ['b', '<unk>', 'a']),
        )
        for insert_at, expected_itos in cases:
            v = vocab(base)
            v.insert_token('b', insert_at)
            expected_stoi = {tok: idx for idx, tok in enumerate(expected_itos)}
            self.assertEqual(v.get_itos(), expected_itos)
            self.assertEqual(dict(v.get_stoi()), expected_stoi)
Ejemplo n.º 16
0
    def test_vocab_forward(self):
        """Calling the vocab (forward) maps tokens to indices, eager and JIT."""
        freqs = {'a': 2, 'b': 2, 'c': 2}
        ordered = OrderedDict(sorted(freqs.items(), key=lambda kv: kv[1], reverse=True))
        v = vocab(ordered)
        scripted = torch.jit.script(v)

        tokens = ['b', 'a', 'c']
        expected = [1, 0, 2]
        # Eager and scripted modules must agree.
        self.assertEqual(v(tokens), expected)
        self.assertEqual(scripted(tokens), expected)
Ejemplo n.º 17
0
    def test_vocab_append_token(self):
        """append_token adds a new token at the end; duplicates raise."""
        v = vocab(OrderedDict({'a': 2}))
        v.append_token('b')

        expected_itos = ['a', 'b']
        expected_stoi = {tok: idx for idx, tok in enumerate(expected_itos)}
        self.assertEqual(v.get_itos(), expected_itos)
        self.assertEqual(dict(v.get_stoi()), expected_stoi)

        # Appending an existing token is an error.
        with self.assertRaises(RuntimeError):
            v.append_token('b')
Ejemplo n.º 18
0
def build_vocab(dataset, tokenizer):
    """Build a vocab with '<PAD>' at index 0 and '<UNK>' at index 1.

    Tokens are counted over every text in the dataset and ordered by
    descending frequency; OOV lookups default to the '<UNK>' index.
    """
    counter = Counter()
    for idx in range(len(dataset)):
        text, _label = dataset[idx]
        counter.update(tokenizer(text))

    ordered = OrderedDict(sorted(counter.items(), key=lambda kv: kv[1], reverse=True))
    v = vocab(ordered)
    v.insert_token('<PAD>', 0)
    v.insert_token('<UNK>', 1)
    v.set_default_index(v['<UNK>'])
    return v
Ejemplo n.º 19
0
    def test_reassign_token(self):
        """reassign_token moves an existing token to a new index in itos."""
        freqs = {'<unk>': 1, 'a': 2, 'b': 2}
        ordered = OrderedDict(sorted(freqs.items(), key=lambda kv: kv[1], reverse=True))
        v = vocab(ordered, min_freq=1)

        # Frequency sort puts '<unk>' (freq 1) last.
        self.assertEqual(v['<unk>'], 2)
        self.assertEqual(v['a'], 0)
        self.assertEqual(v['b'], 1)

        # Move '<unk>' to the front; the others shift right.
        v.reassign_token('<unk>', 0)
        self.assertEqual(v['<unk>'], 0)
        self.assertEqual(v['a'], 1)
        self.assertEqual(v['b'], 2)
        self.assertEqual(v.get_itos(), ['<unk>', 'a', 'b'])

        # Unknown token or out-of-range target index both raise.
        with self.assertRaises(RuntimeError):
            v.reassign_token('not in vocab', 0)
        with self.assertRaises(RuntimeError):
            v.reassign_token('<unk>', 3)