def test_errors(self):
    token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
    sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
    c = OrderedDict(sorted_by_freq_tuples)

    with self.assertRaises(ValueError):
        # Test proper error raised when setting unk token to None
        Vocab(c, unk_token=None)

    with self.assertRaises(RuntimeError):
        # Test proper error raised when setting a token out of bounds
        v = Vocab(c, min_freq=3)
        v.insert_token('new_token', 100)

    with self.assertRaises(RuntimeError):
        # Test proper error raised when looking up a token out of bounds
        v = Vocab(c)
        v.lookup_token(100)
def test_errors(self):
    token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
    sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
    c = OrderedDict(sorted_by_freq_tuples)

    with self.assertRaises(ValueError):
        # Test proper error raised when setting unk token to None
        Vocab(c, specials=['<unk>', '<bos>'], unk_token=None)

    with self.assertRaises(ValueError):
        # Test proper error raised when specials token doesn't contain unk_token
        Vocab(c, specials=['<pad>', '<bos>'])

    with self.assertRaises(ValueError):
        # Test proper error raised when ordered_dict contains a special token
        updated_token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2, '<pad>': 1}
        updated_sorted_by_freq_tuples = sorted(updated_token_to_freq.items(),
                                               key=lambda x: x[1], reverse=True)
        updated_c = OrderedDict(updated_sorted_by_freq_tuples)
        Vocab(updated_c, specials=['<unk>', '<pad>', '<bos>'])

    with self.assertRaises(RuntimeError):
        # Test proper error raised when setting a token out of bounds
        v = Vocab(c, min_freq=3)
        v.insert_token('new_token', 100)

    with self.assertRaises(RuntimeError):
        # Test proper error raised when looking up a token out of bounds
        v = Vocab(c)
        v.lookup_token(100)
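# For contrast with the error cases above, a small sketch (not part of the original
# tests) of a construction that is expected to succeed; the indices shown follow the
# specials-first ordering exercised in test_vocab_set_item below and are illustrative,
# not asserted here.
# v = Vocab(c, specials=['<unk>', '<pad>', '<bos>'])
# v['<unk>'] -> 0, v['<pad>'] -> 1, v['<bos>'] -> 2, remaining tokens by frequency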
def test_vocab_set_item(self):
    c = OrderedDict({'a': 2})

    # add item to the end
    v = Vocab(c)
    v.insert_token('b', 3)
    self.assertEqual(v['<unk>'], 0)
    self.assertEqual(v['<pad>'], 1)
    self.assertEqual(v['a'], 2)
    self.assertEqual(v['b'], 3)

    # add item to the beginning (specials are appended last with specials_first=False)
    v = Vocab(c, specials_first=False)
    v.insert_token('b', 0)
    self.assertEqual(v['b'], 0)
    self.assertEqual(v['a'], 1)
    self.assertEqual(v['<unk>'], 2)
    self.assertEqual(v['<pad>'], 3)
class PretrainedSPVocab(nn.Module):
    r"""Vocab based on a pretrained sentencepiece model."""

    def __init__(self, spm_file):
        super(PretrainedSPVocab, self).__init__()
        self.sp_model = load_sp_model(spm_file)
        unk_id = self.sp_model.unk_id()
        unk_token = self.sp_model.IdToPiece(unk_id)
        vocab_list = [self.sp_model.IdToPiece(i) for i in range(self.sp_model.GetPieceSize())]
        self.vocab = Vocab(OrderedDict([(token, 1) for token in vocab_list]), unk_token=unk_token)

    def forward(self, tokens: List[str]) -> List[int]:
        return self.vocab.lookup_indices(tokens)

    def insert_token(self, token: str, index: int) -> None:
        self.vocab.insert_token(token, index)
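# A minimal usage sketch for PretrainedSPVocab (not part of the original code);
# the model filename below is a placeholder assumption, not a file shipped with
# the library. Out-of-vocabulary tokens map to the sentencepiece unk piece.
# sp_vocab = PretrainedSPVocab('spm_user.model')
# token_ids = sp_vocab(['hello', 'world'])  # nn.Module.__call__ dispatches to forward()
# sp_vocab.insert_token('<pad>', 0)         # e.g. reserve index 0 for a padding token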
def test_vocab_insert_token(self):
    c = OrderedDict({'<unk>': 2, 'a': 2})

    # add item to the end
    v = Vocab(c)
    v.insert_token('b', 2)
    expected_itos = ['<unk>', 'a', 'b']
    expected_stoi = {x: index for index, x in enumerate(expected_itos)}
    self.assertEqual(v.get_itos(), expected_itos)
    self.assertEqual(dict(v.get_stoi()), expected_stoi)

    # add item to the beginning
    v = Vocab(c)
    v.insert_token('b', 0)
    expected_itos = ['b', '<unk>', 'a']
    expected_stoi = {x: index for index, x in enumerate(expected_itos)}
    self.assertEqual(v.get_itos(), expected_itos)
    self.assertEqual(dict(v.get_stoi()), expected_stoi)