def test_vocab_load_and_save(self):
    token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
    sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
    c = OrderedDict(sorted_by_freq_tuples)
    v = Vocab(c, min_freq=3)

    expected_itos = ['<unk>', '<pad>', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
    expected_stoi = {x: index for index, x in enumerate(expected_itos)}
    self.assertEqual(v.get_itos(), expected_itos)
    self.assertEqual(dict(v.get_stoi()), expected_stoi)

    vocab_path = os.path.join(self.test_dir, 'vocab.pt')
    torch.save(v, vocab_path)
    loaded_v = torch.load(vocab_path)
    self.assertEqual(loaded_v.get_itos(), expected_itos)
    self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
def test_vocab_append_token(self):
    c = OrderedDict({'a': 2})
    v = Vocab(c)
    v.append_token('b')

    self.assertEqual(len(v), 4)
    self.assertEqual(v['b'], 3)
def test_vocab_lookup_token(self):
    token_to_freq = {'a': 2, 'b': 2, 'c': 2}
    sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
    c = OrderedDict(sorted_by_freq_tuples)
    v = Vocab(c, specials_first=False)

    self.assertEqual(v.lookup_token(0), 'a')
def test_errors(self):
    token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
    sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
    c = OrderedDict(sorted_by_freq_tuples)

    with self.assertRaises(ValueError):
        # Test proper error raised when setting unk token to None
        Vocab(c, unk_token=None)

    with self.assertRaises(RuntimeError):
        # Test proper error raised when inserting a token at an out-of-bounds index
        v = Vocab(c, min_freq=3)
        v.insert_token('new_token', 100)

    with self.assertRaises(RuntimeError):
        # Test proper error raised when looking up a token at an out-of-bounds index
        v = Vocab(c)
        v.lookup_token(100)
def test_vocab_lookup_indices(self):
    token_to_freq = {'a': 2, 'b': 2, 'c': 2}
    sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
    c = OrderedDict(sorted_by_freq_tuples)
    v = Vocab(c, specials_first=False)

    tokens = ['b', 'a', 'c']
    expected_indices = [1, 0, 2]
    self.assertEqual(v.lookup_indices(tokens), expected_indices)
def test_has_unk(self):
    c = OrderedDict({})
    v = Vocab(c)

    # check if unk is mapped to the first index
    self.assertEqual(v['not_in_it'], 0)
    self.assertEqual(v['<unk>'], 0)
def test_new_unk(self):
    c = OrderedDict({})
    v = Vocab(c, specials=('<new_unk>',), unk_token='<new_unk>')

    # check if new_unk is mapped to the first index
    self.assertEqual(v['<new_unk>'], 0)
    self.assertEqual(v['not_in_it'], 0)
def test_vocab_len(self):
    token_to_freq = {'a': 2, 'b': 2, 'c': 2}
    sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
    c = OrderedDict(sorted_by_freq_tuples)
    v = Vocab(c)

    self.assertEqual(len(v), 5)
def test_vocab_basic(self):
    token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
    sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
    c = OrderedDict(sorted_by_freq_tuples)
    v = Vocab(c, min_freq=3)

    expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world', '<unk>']
    expected_stoi = {x: index for index, x in enumerate(expected_itos)}
    self.assertEqual(v.get_itos(), expected_itos)
    self.assertEqual(dict(v.get_stoi()), expected_stoi)
def make_vocab(text, vocab_size):
    # count token frequencies across the tokenized corpus
    counts = collections.Counter()
    for tokens in text:
        for token in tokens:
            counts[token] += 1
    # give '<pad>' and '<unk>' counts above the current maximum so they
    # sort to the front of the frequency-ordered vocabulary
    _, max_count = counts.most_common(1)[0]
    counts['<pad>'] = max_count + 2
    counts['<unk>'] = max_count + 1
    vocab = Vocab(collections.OrderedDict(counts.most_common(vocab_size)))
    return vocab
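# A minimal usage sketch for make_vocab, assuming sentences have already been
# tokenized into lists of strings and a Vocab constructor that accepts a
# frequency-ordered OrderedDict (as in the snippets above). The toy corpus is
# hypothetical, not taken from the tests.
text = [['hello', 'world'], ['hello', 'there']]
vocab = make_vocab(text, vocab_size=100)
# '<pad>' and '<unk>' received the highest counts, so they lead the ordering
print(len(vocab))
print(vocab.lookup_indices(['hello', 'world']))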
def test_vocab_get_item(self):
    token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
    sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
    c = OrderedDict(sorted_by_freq_tuples)
    v = Vocab(c, min_freq=2)

    self.assertEqual(v['<unk>'], 0)
    self.assertEqual(v['a'], 1)
    self.assertEqual(v['b'], 2)
# imports assumed for this snippet (module paths per the torchtext release
# this experimental Vocab ships with)
from collections import OrderedDict
from typing import List

import torch.nn as nn
from torchtext.data.functional import load_sp_model
from torchtext.experimental.vocab import Vocab


class PretrainedSPVocab(nn.Module):
    r"""Vocab based on a pretrained sentencepiece model"""

    def __init__(self, spm_file):
        super(PretrainedSPVocab, self).__init__()
        self.sp_model = load_sp_model(spm_file)
        unk_id = self.sp_model.unk_id()
        unk_token = self.sp_model.IdToPiece(unk_id)
        vocab_list = [self.sp_model.IdToPiece(i) for i in range(self.sp_model.GetPieceSize())]
        self.vocab = Vocab(OrderedDict([(token, 1) for token in vocab_list]),
                           unk_token=unk_token)

    def forward(self, tokens: List[str]) -> List[int]:
        return self.vocab.lookup_indices(tokens)

    def insert_token(self, token: str, index: int) -> None:
        self.vocab.insert_token(token, index)
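# A minimal usage sketch for PretrainedSPVocab; 'spm_user.model' is a
# hypothetical path to a pretrained sentencepiece model file.
sp_vocab = PretrainedSPVocab('spm_user.model')
indices = sp_vocab(['hello', 'world'])  # nn.Module __call__ dispatches to forward()
sp_vocab.insert_token('<pad>', 0)       # e.g. reserve index 0 for padding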
def test_vocab_set_item(self):
    c = OrderedDict({'a': 2})

    # add item to end
    v = Vocab(c)
    v.insert_token('b', 3)

    self.assertEqual(v['<unk>'], 0)
    self.assertEqual(v['<pad>'], 1)
    self.assertEqual(v['a'], 2)
    self.assertEqual(v['b'], 3)

    # add item to middle
    v = Vocab(c, specials_first=False)
    v.insert_token('b', 0)

    self.assertEqual(v['b'], 0)
    self.assertEqual(v['a'], 1)
    self.assertEqual(v['<unk>'], 2)
    self.assertEqual(v['<pad>'], 3)
def test_vocab_specials_order(self):
    token_to_freq = {'a': 2, 'b': 2, 'c': 2}
    sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
    c = OrderedDict(sorted_by_freq_tuples)

    # add specials at the beginning of the vocabulary
    v = Vocab(c, specials=['<pad>', '<unk>'])
    expected_itos = ['<pad>', '<unk>', 'a', 'b', 'c']
    expected_stoi = {x: index for index, x in enumerate(expected_itos)}
    self.assertEqual(v.get_itos(), expected_itos)
    self.assertEqual(dict(v.get_stoi()), expected_stoi)

    # add specials at the end of the vocabulary
    v = Vocab(c, specials=['<pad>', '<unk>'], specials_first=False)
    expected_itos = ['a', 'b', 'c', '<pad>', '<unk>']
    expected_stoi = {x: index for index, x in enumerate(expected_itos)}
    self.assertEqual(v.get_itos(), expected_itos)
    self.assertEqual(dict(v.get_stoi()), expected_stoi)
def test_vocab_insert_token(self):
    c = OrderedDict({'<unk>': 2, 'a': 2})

    # add item to end
    v = Vocab(c)
    v.insert_token('b', 2)

    expected_itos = ['<unk>', 'a', 'b']
    expected_stoi = {x: index for index, x in enumerate(expected_itos)}
    self.assertEqual(v.get_itos(), expected_itos)
    self.assertEqual(dict(v.get_stoi()), expected_stoi)

    # add item to middle
    v = Vocab(c)
    v.insert_token('b', 0)

    expected_itos = ['b', '<unk>', 'a']
    expected_stoi = {x: index for index, x in enumerate(expected_itos)}
    self.assertEqual(v.get_itos(), expected_itos)
    self.assertEqual(dict(v.get_stoi()), expected_stoi)
def test_errors(self):
    token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
    sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
    c = OrderedDict(sorted_by_freq_tuples)

    with self.assertRaises(ValueError):
        # Test proper error raised when setting unk token to None
        Vocab(c, specials=['<unk>', '<bos>'], unk_token=None)

    with self.assertRaises(ValueError):
        # Test proper error raised when specials doesn't contain unk_token
        Vocab(c, specials=['<pad>', '<bos>'])

    with self.assertRaises(ValueError):
        # Test proper error raised when ordered_dict contains a special token
        updated_token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5,
                                 'freq_too_low': 2, '<pad>': 1}
        updated_sorted_by_freq_tuples = sorted(updated_token_to_freq.items(),
                                               key=lambda x: x[1], reverse=True)
        updated_c = OrderedDict(updated_sorted_by_freq_tuples)
        Vocab(updated_c, specials=['<unk>', '<pad>', '<bos>'])

    with self.assertRaises(RuntimeError):
        # Test proper error raised when inserting a token at an out-of-bounds index
        v = Vocab(c, min_freq=3)
        v.insert_token('new_token', 100)

    with self.assertRaises(RuntimeError):
        # Test proper error raised when looking up a token at an out-of-bounds index
        v = Vocab(c)
        v.lookup_token(100)
def build_fairseq_vocab(
    vocab_file: str,
    dictionary_class: Dictionary = Dictionary,
    special_token_replacements: Optional[Dict[str, str]] = None,
    unk_token: str = "<unk>",
    max_vocab: int = -1,
    min_count: int = -1,
    tokens_to_add: Optional[List[str]] = None,
):
    """Build a torchtext Vocab for models pre-trained using Fairseq modules.

    The dictionary class can take any Fairseq Dictionary class and is used
    to load the vocab file.
    """
    if not special_token_replacements:
        special_token_replacements = {
            "<pad>": "__PAD__",
            "<s>": "__BEGIN_OF_SENTENCE__",
            "</s>": "__END_OF_SENTENCE__",
            "<unk>": "__UNKNOWN__",
            "<mask>": "__MASK__",
        }
    unk_replacement = special_token_replacements.get(unk_token, unk_token)
    special_tokens_to_remove = list(special_token_replacements.keys())
    special_tokens_to_add = tuple(
        replacement
        for token, replacement in special_token_replacements.items()
        if token != unk_token
    )

    with open(vocab_file) as f:
        dictionary = dictionary_class.load(f)
        # finalize will sort the dict based on frequency, so only do this if
        # a min_count or max_vocab size is specified
        if min_count > 0 or max_vocab > 0:
            dictionary.finalize(threshold=min_count, nwords=max_vocab, padding_factor=1)
        if tokens_to_add:
            for token in tokens_to_add:
                dictionary.add_symbol(token)

        dictionary_items = list(zip(dictionary.symbols, dictionary.count))

        ordered_dict = OrderedDict()
        # add special tokens to the beginning of ordered_dict
        for s in special_tokens_to_add:
            ordered_dict[s] = 1

        # add all other tokens from dictionary_items
        for token, freq in dictionary_items:
            ordered_dict[token] = freq

        # remove special_tokens_to_remove from the dict
        for s in special_tokens_to_remove:
            if s in ordered_dict:
                del ordered_dict[s]

        return Vocab(ordered_dict, unk_token=unk_replacement)
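# A minimal usage sketch for build_fairseq_vocab; 'fairseq_dict.txt' is a
# hypothetical path to a Fairseq-format dictionary file (one "<token> <count>"
# pair per line), and the thresholds are illustrative.
vocab = build_fairseq_vocab(
    "fairseq_dict.txt",
    max_vocab=50000,
    min_count=5,
    tokens_to_add=["<mask>"],
)
print(vocab.lookup_indices(["hello", "world"]))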