def test_uncased_enc() -> None:
  """Encode text to token ids (case-insensitive)."""
  tknzr = CharTknzr(is_uncased=True, max_vocab=-1, min_count=0)
  tknzr.build_vocab(batch_txt=['a'])

  # Return `[bos] [eos]` when given empty input.
  assert tknzr.enc(max_seq_len=2, txt='') == [BOS_TKID, EOS_TKID]

  # Encoding format.
  assert tknzr.enc(max_seq_len=4, txt='aA') == [
    BOS_TKID,
    tknzr.tk2id['a'],
    tknzr.tk2id['a'],
    EOS_TKID,
  ]

  # Padding.
  assert tknzr.enc(max_seq_len=5, txt='aA') == [
    BOS_TKID,
    tknzr.tk2id['a'],
    tknzr.tk2id['a'],
    EOS_TKID,
    PAD_TKID,
  ]

  # Truncate.
  assert tknzr.enc(max_seq_len=3, txt='aA') == [
    BOS_TKID,
    tknzr.tk2id['a'],
    tknzr.tk2id['a'],
  ]

  # Unknown tokens.
  assert tknzr.enc(max_seq_len=4, txt='bB') == [BOS_TKID, UNK_TKID, UNK_TKID, EOS_TKID]

  # Unknown tokens with padding.
  assert tknzr.enc(max_seq_len=5, txt='bB') == [
    BOS_TKID,
    UNK_TKID,
    UNK_TKID,
    EOS_TKID,
    PAD_TKID,
  ]

  # Unknown tokens with truncation.
  assert tknzr.enc(max_seq_len=2, txt='bB') == [BOS_TKID, UNK_TKID]

def test_no_limit_build() -> None:
  """Include all tokens when ``max_vocab == -1``."""
  tknzr = CharTknzr(is_uncased=False, max_vocab=-1, min_count=0)
  CJK_txt = [chr(i) for i in range(ord('\u4e00'), ord('\u9fff') + 1)]
  norm_CJK_txt = [tknzr.norm(t) for t in CJK_txt]
  tknzr.build_vocab(CJK_txt)

  assert tknzr.vocab_size == len(set(norm_CJK_txt)) + 4
  assert all(map(lambda tk: tk in tknzr.tk2id, norm_CJK_txt))

def test_empty_build() -> None:
  """Build nothing when given an empty list."""
  tknzr = CharTknzr(is_uncased=False, max_vocab=-1, min_count=0)
  tknzr.build_vocab([])

  assert tknzr.vocab_size == 4
  assert tknzr.tk2id == {
    BOS_TK: BOS_TKID,
    EOS_TK: EOS_TKID,
    PAD_TK: PAD_TKID,
    UNK_TK: UNK_TKID,
  }

def test_case_insensitive() -> None:
  """Must be case-insensitive when ``is_uncased = True``."""
  tknzr = CharTknzr(is_uncased=True, max_vocab=-1, min_count=0)
  tknzr.build_vocab(['a', 'A'])

  assert tknzr.tk2id == {
    BOS_TK: BOS_TKID,
    EOS_TK: EOS_TKID,
    PAD_TK: PAD_TKID,
    UNK_TK: UNK_TKID,
    'a': max(BOS_TKID, EOS_TKID, PAD_TKID, UNK_TKID) + 1,
  }

def test_minimum_occurrence_counts() -> None:
  """Must satisfy minimum occurrence counts to include tokens in vocabulary."""
  tknzr = CharTknzr(is_uncased=False, max_vocab=-1, min_count=2)
  tknzr.build_vocab(['c', 'bc', 'abc'])

  assert tknzr.tk2id == {
    BOS_TK: BOS_TKID,
    EOS_TK: EOS_TKID,
    PAD_TK: PAD_TKID,
    UNK_TK: UNK_TKID,
    'c': max(BOS_TKID, EOS_TKID, PAD_TKID, UNK_TKID) + 1,
    'b': max(BOS_TKID, EOS_TKID, PAD_TKID, UNK_TKID) + 2,
  }

def test_sort_by_occurrence_counts() -> None:
  """Sort vocabulary by occurrence counts."""
  tknzr = CharTknzr(is_uncased=False, max_vocab=-1, min_count=0)
  tknzr.build_vocab(['c', 'bc', 'abc'])

  assert tknzr.tk2id == {
    BOS_TK: BOS_TKID,
    EOS_TK: EOS_TKID,
    PAD_TK: PAD_TKID,
    UNK_TK: UNK_TKID,
    'c': max(BOS_TKID, EOS_TKID, PAD_TKID, UNK_TKID) + 1,
    'b': max(BOS_TKID, EOS_TKID, PAD_TKID, UNK_TKID) + 2,
    'a': max(BOS_TKID, EOS_TKID, PAD_TKID, UNK_TKID) + 3,
  }

def test_continue_build() -> None:
  """Build vocabulary based on an existing vocabulary."""
  tknzr = CharTknzr(is_uncased=True, max_vocab=-1, min_count=0)
  tknzr.build_vocab(['a'])
  tknzr.build_vocab(['b'])
  tknzr.build_vocab(['c'])

  assert tknzr.tk2id == {
    BOS_TK: BOS_TKID,
    EOS_TK: EOS_TKID,
    PAD_TK: PAD_TKID,
    UNK_TK: UNK_TKID,
    'a': max(BOS_TKID, EOS_TKID, PAD_TKID, UNK_TKID) + 1,
    'b': max(BOS_TKID, EOS_TKID, PAD_TKID, UNK_TKID) + 2,
    'c': max(BOS_TKID, EOS_TKID, PAD_TKID, UNK_TKID) + 3,
  }

def test_dec() -> None:
  """Decode token ids to text."""
  tknzr = CharTknzr(is_uncased=False, max_vocab=-1, min_count=0)
  tknzr.build_vocab(batch_txt=['A', 'a'])

  # Return empty string when given empty list.
  assert tknzr.dec(tkids=[]) == ''

  # Decoding format.
  assert tknzr.dec(
    tkids=[
      BOS_TKID,
      tknzr.tk2id['a'],
      UNK_TKID,
      tknzr.tk2id['A'],
      EOS_TKID,
      PAD_TKID,
    ],
    rm_sp_tks=False,
  ) == f'{BOS_TK}a{UNK_TK}A{EOS_TK}{PAD_TK}'

  # Remove special tokens but not unknown tokens.
  assert tknzr.dec(
    tkids=[
      BOS_TKID,
      tknzr.tk2id['a'],
      UNK_TKID,
      tknzr.tk2id['A'],
      UNK_TKID,
      EOS_TKID,
      PAD_TKID,
    ],
    rm_sp_tks=True,
  ) == f'a{UNK_TK}A{UNK_TK}'

  # Convert unknown ids to unknown tokens.
  assert tknzr.dec(
    tkids=[
      BOS_TKID,
      max(tknzr.tk2id.values()) + 1,
      max(tknzr.tk2id.values()) + 2,
      EOS_TKID,
      PAD_TKID,
    ],
    rm_sp_tks=False,
  ) == f'{BOS_TK}{UNK_TK}{UNK_TK}{EOS_TK}{PAD_TK}'

def test_normalization() -> None:
  """Must normalize text first."""
  tknzr = CharTknzr(is_uncased=False, max_vocab=-1, min_count=0)
  tknzr.build_vocab(['0', '0é'])

  assert tknzr.tk2id == {
    BOS_TK: BOS_TKID,
    EOS_TK: EOS_TKID,
    PAD_TK: PAD_TKID,
    UNK_TK: UNK_TKID,
    unicodedata.normalize('NFKC', '0'): max(BOS_TKID, EOS_TKID, PAD_TKID, UNK_TKID) + 1,
    unicodedata.normalize('NFKC', 'é'): max(BOS_TKID, EOS_TKID, PAD_TKID, UNK_TKID) + 2,
  }

def test_limit_build() -> None:
  """Must have correct vocabulary size."""
  tknzr = CharTknzr(is_uncased=False, max_vocab=10, min_count=0)
  tknzr.build_vocab([chr(i) for i in range(65536)])

  assert tknzr.vocab_size == 10

def test_uncased_batch_enc() -> None:
  """Encode batch of text to batch of token ids (case-insensitive)."""
  tknzr = CharTknzr(is_uncased=True, max_vocab=-1, min_count=0)
  tknzr.build_vocab(batch_txt=['a'])

  # Return empty list when given empty list.
  assert tknzr.batch_enc(batch_txt=[], max_seq_len=2) == []

  # Batch encoding format.
  assert tknzr.batch_enc(batch_txt=['aA', 'Aa'], max_seq_len=4) == [
    [
      BOS_TKID,
      tknzr.tk2id['a'],
      tknzr.tk2id['a'],
      EOS_TKID,
    ],
    [
      BOS_TKID,
      tknzr.tk2id['a'],
      tknzr.tk2id['a'],
      EOS_TKID,
    ],
  ]

  # Truncate and pad to specified length.
  assert tknzr.batch_enc(batch_txt=['a', 'aA', 'aAA'], max_seq_len=4) == [
    [
      BOS_TKID,
      tknzr.tk2id['a'],
      EOS_TKID,
      PAD_TKID,
    ],
    [
      BOS_TKID,
      tknzr.tk2id['a'],
      tknzr.tk2id['a'],
      EOS_TKID,
    ],
    [
      BOS_TKID,
      tknzr.tk2id['a'],
      tknzr.tk2id['a'],
      tknzr.tk2id['a'],
    ],
  ]

  # Unknown tokens.
  assert tknzr.batch_enc(batch_txt=['a', 'ab', 'abc'], max_seq_len=4) == [
    [
      BOS_TKID,
      tknzr.tk2id['a'],
      EOS_TKID,
      PAD_TKID,
    ],
    [
      BOS_TKID,
      tknzr.tk2id['a'],
      UNK_TKID,
      EOS_TKID,
    ],
    [
      BOS_TKID,
      tknzr.tk2id['a'],
      UNK_TKID,
      UNK_TKID,
    ],
  ]