def bpe_token(self, token: str) -> List[str]:
    """Split a single word into BPE subword units.

    Greedily merges adjacent parts whose concatenation has the highest
    score in ``self.vocab`` until no mergeable pair remains. The last
    returned part always carries the end-of-word marker ``self.eow``.

    NOTE: this is intentionally written in the restricted TorchScript
    subset — dict ``in`` tests, ``break``/``continue``, and list
    subscript assignment are avoided on purpose (see inline comments).
    Do not "clean up" these constructs.

    Args:
        token: the word to tokenize (without the EOW marker attached).
            Assumes ``self.vocab`` maps subword -> non-negative priority
            score and ``self.eow`` is the end-of-word marker string.

    Returns:
        The list of subword strings; ``[self.eow]`` alone if the word
        consisted entirely of characters unknown to the vocab.
    """
    # If full token is in vocab, we're done.
    full_token = token + self.eow
    # `in` not implemented, this should be read `if full_token in self.vocab`
    if self.vocab.get(full_token) is not None:
        return [full_token]

    # Split word into parts, with the last part having EOW attached.
    # Any part (character or char + EOW) not in the vocab on its own
    # should be removed. EOW should always be attached to the last remaining
    # token.
    parts = utf8_chars(token)

    # Drop trailing characters until the last one can carry EOW, i.e.
    # `parts and parts[-1] + self.eow not in self.vocab`
    while len(parts) > 0 and self.vocab.get(parts[-1] + self.eow) is None:
        parts.pop()

    # The word consisted entirely of unknown characters
    if len(parts) == 0:
        return [self.eow]

    parts[-1] += self.eow

    # Remove any other obscure characters not in the vocab.
    # No easy way to iterate backwards or create descending ranges,
    # so using a while loop.
    i = 0
    while i < len(parts):
        # `parts[i] not in self.vocab`
        if self.vocab.get(parts[i]) is None:
            parts.pop(i)
        else:
            i += 1

    # We compare vocab dict scores to this value, so this is where we assume
    # vocab dict values are non-negative.
    NOT_IN_VOCAB = -1

    # break not implemented
    should_break = False

    # Keep going until no more part pairs are in the vocab.
    # In obscure cases this could also get down to a single token, eg. if
    # we filter out some character and rebuild up to a single token.
    while len(parts) > 1 and not should_break:
        # Create part pairs, join part pair with highest score in vocab.
        # In pure python, this could be implemented as
        #   max(range(len(parts) - 1),
        #       key=lambda i: self.vocab.get(parts[i] + parts[i+1], -1)))
        max_pair_index = 0
        max_pair_value = NOT_IN_VOCAB
        # We structure the vocabulary to not have ties, but they can come up anyway,
        # for instance in cases with repeated tokens or when passing in vocabs not
        # created with BPE.load_vocab. In the case of a tie between the value of
        # joined segments, they'll be joined prioritizing the first pair in the
        # token according to byte order, ie. left in LTR and right in RTL languages.
        # For instance, if the vocab contains "aa" but not "aaa", then
        # bpe_tokens("aaa") -> ["aa", "a"]. If the vocab contains "ab" and "bc"
        # mapped to the same priority, but not "abc", then
        # bpe_tokens("abc") -> ["ab", "c"].
        for pair_index in range(len(parts) - 1):
            joined = parts[pair_index] + parts[pair_index + 1]
            pair_value = self.vocab.get(joined, NOT_IN_VOCAB)
            # Strict `>` keeps the first (leftmost) pair on ties.
            if pair_value > max_pair_value:
                max_pair_value = pair_value
                max_pair_index = pair_index

        if max_pair_value == NOT_IN_VOCAB:
            # No pairs found in vocab, we're done!
            should_break = True
        else:
            # break, continue not supported; only run this block if we wouldn't
            # want to break out after the above step

            # Combine parts pair with highest priority in vocab.
            # len(parts) shrinks by 1 each iteration, so we should be bounded
            # as linear in token length.

            # Subscript assignment not implemented.
            p1, p2 = parts[max_pair_index:max_pair_index + 2]
            parts = parts[:max_pair_index] + [
                p1 + p2
            ] + parts[max_pair_index + 2:]

    return parts
def test_utf8_chars(self):
    """utf8_chars must split a string into the same characters as list()."""
    for sample in ["hello", "💩", "¯\\_(ツ)_/¯", "今日"]:
        self.assertEqual(list(sample), utf8_chars(sample))