Пример #1
0
    def test_vocab_lookup_indices(self):
        """lookup_indices returns one index per token, in the query's order.

        All three tokens share the same frequency; Python's sort is stable
        (also with ``reverse=True``), so the OrderedDict keeps the order
        a, b, c, and the expected indices follow that order.
        """
        freqs = {'a': 2, 'b': 2, 'c': 2}
        ordered = OrderedDict(
            sorted(freqs.items(), key=lambda item: item[1], reverse=True)
        )
        vocab = Vocab(ordered, specials_first=False)

        query = ['b', 'a', 'c']
        self.assertEqual(vocab.lookup_indices(query), [1, 0, 2])
Пример #2
0
class PretrainedSPVocab(nn.Module):
    r"""Vocab based on a pretrained sentencepiece model.

    Every piece of the loaded sentencepiece model is added to the vocab in
    sentencepiece id order, and the model's own unknown piece is used as the
    vocab's unk token.

    Args:
        spm_file: forwarded unchanged to ``load_sp_model``.
    """

    def __init__(self, spm_file):
        super().__init__()
        self.sp_model = load_sp_model(spm_file)
        unk_token = self.sp_model.IdToPiece(self.sp_model.unk_id())
        # Each piece gets frequency 1 and is inserted in id order, presumably
        # so that vocab indices line up with sentencepiece ids.
        # NOTE(review): confirm Vocab does not prepend/append extra special
        # tokens by default, which would shift indices away from sp ids.
        pieces = OrderedDict.fromkeys(
            (self.sp_model.IdToPiece(i)
             for i in range(self.sp_model.GetPieceSize())),
            1,
        )
        self.vocab = Vocab(pieces, unk_token=unk_token)

    def forward(self, tokens: List[str]) -> List[int]:
        """Return the vocab index of each token via ``self.vocab.lookup_indices``."""
        return self.vocab.lookup_indices(tokens)

    def insert_token(self, token: str, index: int) -> None:
        """Insert ``token`` at position ``index`` in the underlying vocab."""
        self.vocab.insert_token(token, index)