def test_vocab_jit(self):
    """A scripted vocab must expose the same itos/stoi as the eager one."""
    frequencies = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
    ordered = OrderedDict(sorted(frequencies.items(), key=lambda kv: kv[1], reverse=True))
    v = vocab(ordered, min_freq=3)
    jit_v = torch.jit.script(v)

    expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
    expected_stoi = {token: idx for idx, token in enumerate(expected_itos)}

    assert not v.is_jitable
    # __prepare_scriptable__() converts the building block to the torchbind
    # version. Users are not expected to use the torchbind version in eager
    # mode, but we still need a CI test here.
    assert v.__prepare_scriptable__().is_jitable

    self.assertEqual(jit_v.get_itos(), expected_itos)
    self.assertEqual(dict(jit_v.get_stoi()), expected_stoi)
def test_vocab_load_and_save(self):
    """Saving and loading a vocab (pybind and torchscript forms) must
    round-trip the itos/stoi mappings and the default index.
    """
    token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
    sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
    c = OrderedDict(sorted_by_freq_tuples)
    v = vocab(c, min_freq=3)
    v.set_default_index(0)

    expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
    expected_stoi = {x: index for index, x in enumerate(expected_itos)}

    self.assertEqual(v.get_itos(), expected_itos)
    self.assertEqual(dict(v.get_stoi()), expected_stoi)

    with self.subTest('pybind'):
        vocab_path = os.path.join(self.test_dir, 'vocab_pybind.pt')
        torch.save(v, vocab_path)
        loaded_v = torch.load(vocab_path)
        # Bug fix: assert against the *loaded* object, not the original `v`,
        # otherwise the save/load round-trip is never actually verified.
        self.assertEqual(loaded_v.get_itos(), expected_itos)
        self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
        self.assertEqual(loaded_v['not in vocab'], 0)

    with self.subTest('torchscript'):
        vocab_path = os.path.join(self.test_dir, 'vocab_torchscript.pt')
        # Call the __prepare_scriptable__() func and convert the building
        # block to the torchbind version. Users are not expected to use the
        # torchbind version in eager mode, but we still need a CI test here.
        torch.save(v.__prepare_scriptable__(), vocab_path)
        loaded_v = torch.load(vocab_path)
        self.assertEqual(loaded_v.get_itos(), expected_itos)
        self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
        self.assertEqual(loaded_v['not in vocab'], 0)
def test_vocab_len(self):
    """len(vocab) equals the number of retained tokens."""
    frequencies = {'a': 2, 'b': 2, 'c': 2}
    ordered = OrderedDict(sorted(frequencies.items(), key=lambda kv: kv[1], reverse=True))
    v = vocab(ordered)
    self.assertEqual(len(v), 3)
def __init__(self, sp_model):
    """Wrap a SentencePiece model, building a vocab from its full piece list.

    The unk token is taken from the SentencePiece model itself so the two
    stay consistent.
    """
    super(PretrainedSPVocab, self).__init__()
    self.sp_model = sp_model
    unk_token = self.sp_model.IdToPiece(self.sp_model.unk_id())
    pieces = [self.sp_model.IdToPiece(i) for i in range(self.sp_model.GetPieceSize())]
    # Frequencies are irrelevant here; every piece gets a placeholder count of 1.
    self.vocab = vocab(OrderedDict((piece, 1) for piece in pieces), unk_token=unk_token)
def data_process(raw_text_iter):
    """Tokenize and numericalize each text item, then concatenate.

    Empty items (zero tokens) are dropped before concatenation.
    """
    tensors = (
        torch.tensor(vocab(tokenizer(text)), dtype=torch.long)
        for text in raw_text_iter
    )
    non_empty = [t for t in tensors if t.numel() > 0]
    return torch.cat(tuple(non_empty))
def test_default_index_jit(self):
    """The default index set in eager mode survives TorchScript scripting."""
    frequencies = {'<unk>': 2, 'a': 2, 'b': 2}
    ordered = OrderedDict(sorted(frequencies.items(), key=lambda kv: kv[1], reverse=True))
    v = vocab(ordered, min_freq=2)
    v.set_default_index(0)
    scripted = torch.jit.script(v)
    self.assertEqual(scripted['not in vocab'], 0)
def test_vocab_lookup_token(self):
    """lookup_token maps an index to its token; out-of-range raises."""
    frequencies = {'a': 2, 'b': 2, 'c': 2}
    ordered = OrderedDict(sorted(frequencies.items(), key=lambda kv: kv[1], reverse=True))
    v = vocab(ordered)

    self.assertEqual(v.lookup_token(1), 'b')
    with self.assertRaises(RuntimeError):
        v.lookup_token(100)
def test_vocab_get_item(self):
    """__getitem__ returns each token's index in insertion order."""
    frequencies = {'<unk>': 2, 'a': 2, 'b': 2}
    ordered = OrderedDict(sorted(frequencies.items(), key=lambda kv: kv[1], reverse=True))
    v = vocab(ordered, min_freq=2)

    for expected_index, token in enumerate(['<unk>', 'a', 'b']):
        self.assertEqual(v[token], expected_index)
def script_vocab(ordered_dict, pad_token=None, bos_token=None, eos_token=None, mask_token=None, **kwargs):
    """Build a ScriptVocab from an ordered token->frequency mapping.

    Extra keyword arguments are forwarded both to the underlying vocab
    factory and to the ScriptVocab constructor.
    """
    base = vocab(ordered_dict, **kwargs)
    return ScriptVocab(base.vocab, pad_token, bos_token, eos_token, mask_token, **kwargs)
def test_vocab_lookup_indices(self):
    """lookup_indices maps a list of tokens to their indices, in order."""
    frequencies = {'a': 2, 'b': 2, 'c': 2}
    ordered = OrderedDict(sorted(frequencies.items(), key=lambda kv: kv[1], reverse=True))
    v = vocab(ordered)

    self.assertEqual(v.lookup_indices(['b', 'a', 'c']), [1, 0, 2])
def test_vocab_membership(self):
    """`in` reports presence for retained tokens and absence otherwise."""
    frequencies = {'<unk>': 2, 'a': 2, 'b': 2}
    ordered = OrderedDict(sorted(frequencies.items(), key=lambda kv: kv[1], reverse=True))
    v = vocab(ordered, min_freq=2)

    for token in ('<unk>', 'a', 'b'):
        self.assertTrue(token in v)
    self.assertFalse('c' in v)
def build_vocab(dataset, tokenizer, use_padding):
    """Build a vocab over all tokenized texts in `dataset`.

    Optionally appends a '<pad>' token, then inserts '<unk>' at index 0
    and makes it the default for out-of-vocabulary lookups.
    """
    token_counts = Counter()
    for sample_index in range(len(dataset)):
        token_counts.update(tokenizer(dataset[sample_index][0]))

    built = vocab(token_counts)
    if use_padding:
        built.append_token("<pad>")
    built.insert_token("<unk>", 0)
    built.set_default_index(0)
    return built
def test_default_index(self):
    """Without a default index, OOV lookup raises; with one, it returns it."""
    frequencies = {'<unk>': 2, 'a': 2, 'b': 2}
    ordered = OrderedDict(sorted(frequencies.items(), key=lambda kv: kv[1], reverse=True))
    v = vocab(ordered, min_freq=2)

    self.assertTrue(v.get_default_index() is None)
    with self.assertRaises(RuntimeError):
        v['not in vocab']

    v.set_default_index(0)
    self.assertEqual(v['not in vocab'], 0)
def test_vocab_basic(self):
    """Tokens below min_freq are dropped; the rest keep frequency order."""
    frequencies = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
    ordered = OrderedDict(sorted(frequencies.items(), key=lambda kv: kv[1], reverse=True))
    v = vocab(ordered, min_freq=3)

    expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
    expected_stoi = {token: idx for idx, token in enumerate(expected_itos)}

    self.assertEqual(v.get_itos(), expected_itos)
    self.assertEqual(dict(v.get_stoi()), expected_stoi)
def test_vocab_insert_token(self):
    """insert_token places a new token at the given index, shifting others."""
    ordered = OrderedDict({'<unk>': 2, 'a': 2})

    # Insert at the end.
    v = vocab(ordered)
    v.insert_token('b', 2)
    expected_itos = ['<unk>', 'a', 'b']
    self.assertEqual(v.get_itos(), expected_itos)
    self.assertEqual(dict(v.get_stoi()), {token: idx for idx, token in enumerate(expected_itos)})

    # Insert at the front.
    v = vocab(ordered)
    v.insert_token('b', 0)
    expected_itos = ['b', '<unk>', 'a']
    self.assertEqual(v.get_itos(), expected_itos)
    self.assertEqual(dict(v.get_stoi()), {token: idx for idx, token in enumerate(expected_itos)})
def test_vocab_forward(self):
    """Calling the vocab (forward) looks up indices; works scripted too."""
    frequencies = {'a': 2, 'b': 2, 'c': 2}
    ordered = OrderedDict(sorted(frequencies.items(), key=lambda kv: kv[1], reverse=True))
    v = vocab(ordered)
    scripted = torch.jit.script(v)

    tokens = ['b', 'a', 'c']
    expected = [1, 0, 2]
    self.assertEqual(v(tokens), expected)
    self.assertEqual(scripted(tokens), expected)
def test_vocab_append_token(self):
    """append_token adds a new token at the end; duplicates raise."""
    v = vocab(OrderedDict({'a': 2}))
    v.append_token('b')

    expected_itos = ['a', 'b']
    self.assertEqual(v.get_itos(), expected_itos)
    self.assertEqual(dict(v.get_stoi()), {token: idx for idx, token in enumerate(expected_itos)})

    # A token must not already exist to be appended.
    with self.assertRaises(RuntimeError):
        v.append_token('b')
def build_vocab(dataset, tokenizer):
    """Build a vocab over the texts of a (text, label) dataset.

    Tokens are ordered by descending frequency; '<PAD>' is forced to index 0
    and '<UNK>' to index 1, with '<UNK>' as the default for OOV lookups.
    """
    token_counts = Counter()
    for idx in range(len(dataset)):
        text, _label = dataset[idx]
        token_counts.update(tokenizer(text))

    by_frequency = OrderedDict(sorted(token_counts.items(), key=lambda kv: kv[1], reverse=True))
    v = vocab(by_frequency)

    v.insert_token('<PAD>', 0)
    v.insert_token('<UNK>', 1)
    v.set_default_index(v['<UNK>'])
    return v
def test_reassign_token(self):
    """reassign_token moves an existing token to a new index; invalid
    tokens or out-of-range indices raise."""
    frequencies = {'<unk>': 1, 'a': 2, 'b': 2}
    ordered = OrderedDict(sorted(frequencies.items(), key=lambda kv: kv[1], reverse=True))
    v = vocab(ordered, min_freq=1)

    # '<unk>' has the lowest frequency, so it sorts last initially.
    self.assertEqual(v['<unk>'], 2)
    self.assertEqual(v['a'], 0)
    self.assertEqual(v['b'], 1)

    v.reassign_token('<unk>', 0)
    self.assertEqual(v['<unk>'], 0)
    self.assertEqual(v['a'], 1)
    self.assertEqual(v['b'], 2)
    self.assertEqual(v.get_itos(), ['<unk>', 'a', 'b'])

    with self.assertRaises(RuntimeError):
        v.reassign_token('not in vocab', 0)
    with self.assertRaises(RuntimeError):
        v.reassign_token('<unk>', 3)