예제 #1
0
def test_uncased_enc() -> None:
    """Check ``enc`` output format when ``is_uncased=True``."""
    tokenizer = CharTknzr(is_uncased=True, max_vocab=-1, min_count=0)
    tokenizer.build_vocab(batch_txt=['a'])

    # Empty text yields only the `[bos] [eos]` pair.
    empty_ids = tokenizer.enc(max_seq_len=2, txt='')
    assert empty_ids == [BOS_TKID, EOS_TKID]

    # Both 'a' and 'A' are expected to encode to this single id.
    a_id = tokenizer.tk2id['a']

    # Exact fit: `[bos] a a [eos]`.
    exact_ids = tokenizer.enc(max_seq_len=4, txt='aA')
    assert exact_ids == [BOS_TKID, a_id, a_id, EOS_TKID]

    # One slot left over is filled with the padding id.
    padded_ids = tokenizer.enc(max_seq_len=5, txt='aA')
    assert padded_ids == [BOS_TKID, a_id, a_id, EOS_TKID, PAD_TKID]

    # One slot short: the trailing `[eos]` is cut off.
    truncated_ids = tokenizer.enc(max_seq_len=3, txt='aA')
    assert truncated_ids == [BOS_TKID, a_id, a_id]

    # Characters missing from the vocabulary encode to the unknown id.
    unk_ids = tokenizer.enc(max_seq_len=4, txt='bB')
    assert unk_ids == [BOS_TKID, UNK_TKID, UNK_TKID, EOS_TKID]

    # Unknown characters combined with padding.
    unk_padded_ids = tokenizer.enc(max_seq_len=5, txt='bB')
    assert unk_padded_ids == [
        BOS_TKID, UNK_TKID, UNK_TKID, EOS_TKID, PAD_TKID
    ]

    # Unknown characters combined with truncation.
    unk_truncated_ids = tokenizer.enc(max_seq_len=2, txt='bB')
    assert unk_truncated_ids == [BOS_TKID, UNK_TKID]
예제 #2
0
def test_no_limit_build() -> None:
    """Every token must enter the vocabulary when ``max_vocab == -1``."""
    tokenizer = CharTknzr(is_uncased=False, max_vocab=-1, min_count=0)

    # One single-character text per code point in U+4E00..U+9FFF (inclusive).
    cjk_chars = [chr(code_point) for code_point in range(0x4E00, 0x9FFF + 1)]
    normalized_chars = [tokenizer.norm(char) for char in cjk_chars]

    tokenizer.build_vocab(cjk_chars)

    # Distinct normalized characters plus 4 (the size of an empty-built
    # vocabulary, i.e. the special tokens).
    expected_size = len(set(normalized_chars)) + 4
    assert tokenizer.vocab_size == expected_size
    assert all(char in tokenizer.tk2id for char in normalized_chars)
예제 #3
0
def test_empty_build() -> None:
    """Building from an empty list must leave only the special tokens."""
    tokenizer = CharTknzr(is_uncased=False, max_vocab=-1, min_count=0)
    tokenizer.build_vocab([])

    special_tokens_only = {
        BOS_TK: BOS_TKID,
        EOS_TK: EOS_TKID,
        PAD_TK: PAD_TKID,
        UNK_TK: UNK_TKID,
    }
    assert tokenizer.vocab_size == 4
    assert tokenizer.tk2id == special_tokens_only
예제 #4
0
def test_case_insensitive() -> None:
    """``'a'`` and ``'A'`` must share one entry when ``is_uncased = True``."""
    tokenizer = CharTknzr(is_uncased=True, max_vocab=-1, min_count=0)
    tokenizer.build_vocab(['a', 'A'])

    # Id expected for the first token added after the special tokens.
    first_free_id = max(BOS_TKID, EOS_TKID, PAD_TKID, UNK_TKID) + 1
    expected_tk2id = {
        BOS_TK: BOS_TKID,
        EOS_TK: EOS_TKID,
        PAD_TK: PAD_TKID,
        UNK_TK: UNK_TKID,
        'a': first_free_id,
    }
    assert tokenizer.tk2id == expected_tk2id
예제 #5
0
def test_minimum_occurrence_counts() -> None:
    """Tokens occurring fewer than ``min_count`` times must be left out."""
    tokenizer = CharTknzr(is_uncased=False, max_vocab=-1, min_count=2)

    # 'c' occurs three times, 'b' twice and 'a' once, so with
    # ``min_count=2`` only 'c' and 'b' are expected in the vocabulary.
    tokenizer.build_vocab(['c', 'bc', 'abc'])

    first_free_id = max(BOS_TKID, EOS_TKID, PAD_TKID, UNK_TKID) + 1
    expected_tk2id = {
        BOS_TK: BOS_TKID,
        EOS_TK: EOS_TKID,
        PAD_TK: PAD_TKID,
        UNK_TK: UNK_TKID,
        'c': first_free_id,
        'b': first_free_id + 1,
    }
    assert tokenizer.tk2id == expected_tk2id
예제 #6
0
def test_sort_by_occurrence_counts() -> None:
    """More frequent tokens must receive smaller token ids."""
    tokenizer = CharTknzr(is_uncased=False, max_vocab=-1, min_count=0)

    # 'c' occurs three times, 'b' twice and 'a' once.
    tokenizer.build_vocab(['c', 'bc', 'abc'])

    first_free_id = max(BOS_TKID, EOS_TKID, PAD_TKID, UNK_TKID) + 1
    expected_tk2id = {
        BOS_TK: BOS_TKID,
        EOS_TK: EOS_TKID,
        PAD_TK: PAD_TKID,
        UNK_TK: UNK_TKID,
        'c': first_free_id,
        'b': first_free_id + 1,
        'a': first_free_id + 2,
    }
    assert tokenizer.tk2id == expected_tk2id
예제 #7
0
def test_continue_build() -> None:
    """Repeated ``build_vocab`` calls must extend the existing vocabulary."""
    tokenizer = CharTknzr(is_uncased=True, max_vocab=-1, min_count=0)

    # Three separate builds, each contributing one new character.
    for txt in ('a', 'b', 'c'):
        tokenizer.build_vocab([txt])

    first_free_id = max(BOS_TKID, EOS_TKID, PAD_TKID, UNK_TKID) + 1
    expected_tk2id = {
        BOS_TK: BOS_TKID,
        EOS_TK: EOS_TKID,
        PAD_TK: PAD_TKID,
        UNK_TK: UNK_TKID,
        'a': first_free_id,
        'b': first_free_id + 1,
        'c': first_free_id + 2,
    }
    assert tokenizer.tk2id == expected_tk2id
예제 #8
0
def test_dec() -> None:
    """Check that ``dec`` turns token ids back into text."""
    tokenizer = CharTknzr(is_uncased=False, max_vocab=-1, min_count=0)
    tokenizer.build_vocab(batch_txt=['A', 'a'])

    # An empty id list decodes to the empty string.
    assert tokenizer.dec(tkids=[]) == ''

    lower_id = tokenizer.tk2id['a']
    upper_id = tokenizer.tk2id['A']

    # With ``rm_sp_tks=False`` every special token shows up in the output.
    full_ids = [BOS_TKID, lower_id, UNK_TKID, upper_id, EOS_TKID, PAD_TKID]
    full_txt = tokenizer.dec(tkids=full_ids, rm_sp_tks=False)
    assert full_txt == f'{BOS_TK}a{UNK_TK}A{EOS_TK}{PAD_TK}'

    # With ``rm_sp_tks=True`` bos/eos/pad vanish but unknown tokens stay.
    mixed_ids = [
        BOS_TKID,
        lower_id,
        UNK_TKID,
        upper_id,
        UNK_TKID,
        EOS_TKID,
        PAD_TKID,
    ]
    stripped_txt = tokenizer.dec(tkids=mixed_ids, rm_sp_tks=True)
    assert stripped_txt == f'a{UNK_TK}A{UNK_TK}'

    # Ids beyond the largest known id decode to the unknown token.
    largest_id = max(tokenizer.tk2id.values())
    out_of_range_ids = [
        BOS_TKID,
        largest_id + 1,
        largest_id + 2,
        EOS_TKID,
        PAD_TKID,
    ]
    out_of_range_txt = tokenizer.dec(tkids=out_of_range_ids, rm_sp_tks=False)
    assert out_of_range_txt == f'{BOS_TK}{UNK_TK}{UNK_TK}{EOS_TK}{PAD_TK}'
예제 #9
0
def test_normalization() -> None:
    """Vocabulary keys must match the NFKC-normalized input tokens."""
    tokenizer = CharTknzr(is_uncased=False, max_vocab=-1, min_count=0)
    tokenizer.build_vocab(['0', '0é'])

    first_free_id = max(BOS_TKID, EOS_TKID, PAD_TKID, UNK_TKID) + 1
    expected_tk2id = {
        BOS_TK: BOS_TKID,
        EOS_TK: EOS_TKID,
        PAD_TK: PAD_TKID,
        UNK_TK: UNK_TKID,
    }
    # '0' occurs in both texts, 'é' in only one, so '0' is expected to get
    # the smaller id.
    expected_tk2id[unicodedata.normalize('NFKC', '0')] = first_free_id
    expected_tk2id[unicodedata.normalize('NFKC', 'é')] = first_free_id + 1

    assert tokenizer.tk2id == expected_tk2id
예제 #10
0
def test_limit_build() -> None:
    """Vocabulary size must equal ``max_vocab`` when more tokens are given."""
    vocab_limit = 10
    tokenizer = CharTknzr(is_uncased=False, max_vocab=vocab_limit, min_count=0)

    # One single-character text for each code point below 65536.
    many_chars = list(map(chr, range(65536)))
    tokenizer.build_vocab(many_chars)

    assert tokenizer.vocab_size == vocab_limit
예제 #11
0
def test_uncased_batch_enc() -> None:
    """Check ``batch_enc`` output format when ``is_uncased=True``."""
    tokenizer = CharTknzr(is_uncased=True, max_vocab=-1, min_count=0)
    tokenizer.build_vocab(batch_txt=['a'])

    # An empty batch encodes to an empty list.
    assert tokenizer.batch_enc(batch_txt=[], max_seq_len=2) == []

    # Both 'a' and 'A' are expected to encode to this single id.
    a_id = tokenizer.tk2id['a']

    # Exact fit: both texts encode to the same `[bos] a a [eos]` row.
    exact_row = [BOS_TKID, a_id, a_id, EOS_TKID]
    exact_batch = tokenizer.batch_enc(batch_txt=['aA', 'Aa'], max_seq_len=4)
    assert exact_batch == [exact_row, list(exact_row)]

    # Every row is padded or truncated to ``max_seq_len``.
    mixed_len_batch = tokenizer.batch_enc(
        batch_txt=['a', 'aA', 'aAA'],
        max_seq_len=4,
    )
    assert mixed_len_batch == [
        [BOS_TKID, a_id, EOS_TKID, PAD_TKID],
        [BOS_TKID, a_id, a_id, EOS_TKID],
        [BOS_TKID, a_id, a_id, a_id],
    ]

    # Characters missing from the vocabulary encode to the unknown id.
    unk_batch = tokenizer.batch_enc(
        batch_txt=['a', 'ab', 'abc'],
        max_seq_len=4,
    )
    assert unk_batch == [
        [BOS_TKID, a_id, EOS_TKID, PAD_TKID],
        [BOS_TKID, a_id, UNK_TKID, EOS_TKID],
        [BOS_TKID, a_id, UNK_TKID, UNK_TKID],
    ]