def test_vocab_lookup_indices(self):
    # Build a vocab from equal-frequency tokens (sorted by frequency, ties in
    # insertion order) and check a batched token -> index lookup.
    freqs = {'a': 2, 'b': 2, 'c': 2}
    ordered = OrderedDict(sorted(freqs.items(), key=lambda kv: kv[1], reverse=True))
    vocab = Vocab(ordered, specials_first=False)

    queries = ['b', 'a', 'c']
    # With specials appended last, 'a' -> 0, 'b' -> 1, 'c' -> 2.
    self.assertEqual(vocab.lookup_indices(queries), [1, 0, 2])
class PretrainedSPVocab(nn.Module):
    r"""Vocab backed by a pretrained sentencepiece model."""

    def __init__(self, spm_file):
        """Load the sentencepiece model from ``spm_file`` and wrap its pieces
        in a :class:`Vocab`, reusing the model's own unk token.

        Args:
            spm_file: path to a pretrained sentencepiece model file.
        """
        super().__init__()
        self.sp_model = load_sp_model(spm_file)
        # The model defines its own unknown piece; reuse it as the vocab's unk.
        unk_token = self.sp_model.IdToPiece(self.sp_model.unk_id())
        pieces = (
            self.sp_model.IdToPiece(idx)
            for idx in range(self.sp_model.GetPieceSize())
        )
        # Frequency is irrelevant here (every piece gets 1); only order matters.
        self.vocab = Vocab(OrderedDict((piece, 1) for piece in pieces),
                           unk_token=unk_token)

    def forward(self, tokens: List[str]) -> List[int]:
        """Map each token in ``tokens`` to its vocab index."""
        return self.vocab.lookup_indices(tokens)

    def insert_token(self, token: str, index: int) -> None:
        """Insert ``token`` into the underlying vocab at position ``index``."""
        self.vocab.insert_token(token, index)