Example #1
0
    def bpe_token(self, token: str) -> List[str]:
        """Break one word into sub-word pieces found in ``self.vocab``.

        The end-of-word marker ``self.eow`` is appended to the word. If the
        marked word is itself a vocab entry it is returned as a single piece.
        Otherwise the word is split into characters, characters missing from
        the vocab are dropped, and adjacent pieces are merged greedily: each
        round joins the adjacent pair whose concatenation has the highest
        vocab score, until no adjacent pair is in the vocab or only one piece
        is left.

        Args:
            token: The word to segment (without the end-of-word marker).

        Returns:
            The list of pieces. The last piece carries ``self.eow``. If no
            character of the word can carry the marker, ``[self.eow]`` is
            returned.

        Note:
            This body deliberately avoids ``in``, ``break``/``continue`` and
            slice assignment; per the original authors' notes these are not
            available in the environment this method has to run in.
        """
        # Fast path: the complete word (plus marker) is a vocab entry.
        # Read the test below as `whole_word in self.vocab`.
        whole_word = token + self.eow
        if self.vocab.get(whole_word) is not None:
            return [whole_word]

        # One entry per character; the marker will be glued to the last one.
        pieces = utf8_chars(token)

        # Discard trailing characters until one is found whose marked form
        # (char + EOW) exists in the vocab, so the marker always ends up on
        # the last surviving piece.
        while len(pieces) > 0 and self.vocab.get(pieces[-1] + self.eow) is None:
            pieces.pop()
        if len(pieces) == 0:
            # Nothing in the word could carry the marker.
            return [self.eow]
        pieces[-1] += self.eow

        # Drop every remaining piece that is not a vocab entry on its own.
        # The index only advances when nothing was removed, because a removal
        # shifts the next candidate into the current position.
        pos = 0
        while pos < len(pieces):
            if self.vocab.get(pieces[pos]) is not None:
                pos += 1
            else:
                pieces.pop(pos)

        # Sentinel score for "pair not in vocab". Comparing real scores
        # against it relies on vocab scores being non-negative.
        NO_SCORE = -1
        # Stands in for `break`.
        done = False

        # Each round performs at most one merge, shrinking the list by one, so
        # the number of rounds is bounded by the number of pieces. The list
        # can legitimately collapse to a single piece.
        while not done and len(pieces) > 1:
            # Equivalent in plain Python to
            #   max(range(len(pieces) - 1),
            #       key=lambda i: self.vocab.get(pieces[i] + pieces[i + 1], -1))
            #
            # Ties are possible (repeated characters, or vocabs whose scores
            # are not unique). The strict `>` below keeps the earliest pair in
            # byte order on a tie, i.e. the leftmost pair for LTR text.
            # Ignoring the EOW marker for brevity: with "aa" in the vocab but
            # not "aaa", "aaa" becomes ["aa", "a"]; with "ab" and "bc" scored
            # equally and "abc" absent, "abc" becomes ["ab", "c"].
            best_index = 0
            best_score = NO_SCORE
            for left in range(len(pieces) - 1):
                candidate = pieces[left] + pieces[left + 1]
                score = self.vocab.get(candidate, NO_SCORE)
                if score > best_score:
                    best_score = score
                    best_index = left

            if best_score != NO_SCORE:
                # Replace the winning pair with its concatenation by building
                # a new list rather than assigning into a slice.
                merged = pieces[best_index] + pieces[best_index + 1]
                pieces = (
                    pieces[:best_index] + [merged] + pieces[best_index + 2:]
                )
            else:
                # No adjacent pair is in the vocab: segmentation is final.
                done = True

        return pieces
 def test_utf8_chars(self):
     # utf8_chars must split a string into the same per-character pieces as
     # list(); samples cover ASCII, an emoji, mixed symbols and CJK text.
     samples = ["hello", "💩", "¯\\_(ツ)_/¯", "今日"]
     for text in samples:
         expected = list(text)
         self.assertEqual(expected, utf8_chars(text))