Example #1
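This test exercises the Vocab API end to end: construction from a collections.Counter or a plain token list, special-token handling, token/index round-trips, JSON serialization, and save/load through a file. The imports below are inferred from the calls in the test body; the Vocab import path is an assumption, since the listing does not show it.

import collections
import os
import random
import uuid

import numpy as np
import pytest

# Assumption: Vocab is gluonnlp's vocabulary class; adjust the import to
# wherever the Vocab under test is actually defined.
from gluonnlp.data import Vocab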
def test_vocab():
    def check_same_vocab(vocab1, vocab2):
        # Two vocabularies match when their token lists and their
        # special-token mappings (and derived attributes) are identical.
        assert vocab1.all_tokens == vocab2.all_tokens
        assert len(vocab1._special_token_kv) == len(vocab2._special_token_kv)
        for k, v in vocab1._special_token_kv.items():
            assert v == vocab2._special_token_kv[k]
            assert getattr(vocab1, k) == getattr(vocab2, k)

    def check_consistency(vocab):
        # Verify the token<->index bijection, the unknown-token fallback,
        # special-token attribute wiring, and serialization round-trips.
        for i, token in enumerate(vocab.all_tokens):
            assert vocab[token] == i
        if hasattr(vocab, 'unk_token'):
            # An arbitrary out-of-vocabulary string maps to the unknown
            # token's index.
            assert vocab['some1234123dasf'] == vocab[vocab.unk_token]
        assert len(vocab) == len(vocab.all_tokens)
        if len(vocab.all_tokens) > 0:
            # Spot-check 20 random indices in both directions of the
            # token/index mapping, with list and numpy array inputs.
            random_idx = [
                random.randint(0, len(vocab.all_tokens) - 1)
                for _ in range(20)
            ]
            assert vocab.to_tokens(random_idx) == [
                vocab.all_tokens[i] for i in random_idx
            ]
            assert vocab.to_tokens(np.array(random_idx)) == [
                vocab.all_tokens[i] for i in random_idx
            ]
            random_tokens = vocab.to_tokens(random_idx)
            assert vocab[random_tokens] == random_idx
            if vocab.has_unk:
                # An appended out-of-vocabulary token maps to unk_id.
                assert (vocab[random_tokens + ['213412hadhfk']]
                        == random_idx + [vocab.unk_id])
            for k, v in vocab.special_tokens_kv.items():
                # Each 'xxx_token' key pairs with an 'xxx_id' property.
                idx_property = k[:-6] + '_id'
                assert getattr(vocab, idx_property) == vocab[v]

        # Test serialize/deserialize round-trip through JSON
        json_str = vocab.to_json()
        new_vocab = Vocab.from_json(json_str)
        check_same_vocab(new_vocab, vocab)
        # Test save/load round-trip through a file; draw random file names
        # until one does not collide with an existing file.
        while True:
            fname = '{}.json'.format(uuid.uuid4())
            if os.path.exists(fname):
                continue
            vocab.save(path=fname)
            new_vocab = Vocab.load(fname)
            check_same_vocab(new_vocab, vocab)
            os.remove(fname)
            break

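    # Construction from a Counter: max_size=2 keeps the two most frequent
    # tokens ('d' occurs 4 times, 'a' 3 times), plus the default '<unk>'.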
    words = ['a', 'a', 'b', 'd', 'c', 'b', 'a', 'c', 'd', 'd', 'd']
    random.shuffle(words)
    counter = collections.Counter(words)
    vocab = Vocab(counter, max_size=2, min_freq=None)
    check_consistency(vocab)
    assert vocab.all_tokens == ['d', 'a', '<unk>']
    # Test overriding the default unknown token
    vocab = Vocab(tokens=counter,
                  max_size=2,
                  min_freq=None,
                  unk_token='<unk2>')
    check_consistency(vocab)
    assert vocab.all_tokens == ['d', 'a', '<unk2>']

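    # Every built-in special token at once: regular tokens come first, then
    # the specials, with unk_token leading and the remaining special-token
    # keys in alphabetical order.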
    vocab = Vocab(tokens=counter,
                  max_size=None,
                  min_freq=None,
                  pad_token=Vocab.PAD_TOKEN,
                  eos_token=Vocab.EOS_TOKEN,
                  bos_token=Vocab.BOS_TOKEN,
                  cls_token=Vocab.CLS_TOKEN,
                  sep_token=Vocab.SEP_TOKEN,
                  mask_token=Vocab.MASK_TOKEN)
    check_consistency(vocab)
    assert vocab.unk_token == Vocab.UNK_TOKEN
    assert vocab.pad_token == Vocab.PAD_TOKEN
    assert vocab.eos_token == Vocab.EOS_TOKEN
    assert vocab.bos_token == Vocab.BOS_TOKEN
    assert vocab.cls_token == Vocab.CLS_TOKEN
    assert vocab.sep_token == Vocab.SEP_TOKEN
    assert vocab.mask_token == Vocab.MASK_TOKEN
    assert vocab.special_token_keys == [
        'unk_token', 'bos_token', 'cls_token', 'eos_token', 'mask_token',
        'pad_token', 'sep_token'
    ]
    assert vocab.special_tokens == [
        '<unk>', '<bos>', '<cls>', '<eos>', '<mask>', '<pad>', '<sep>'
    ]
    assert vocab.all_tokens == [
        'd', 'a', 'c', 'b', '<unk>', '<bos>', '<cls>', '<eos>', '<mask>',
        '<pad>', '<sep>'
    ]

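    # Requesting only a subset of special tokens adds just those
    # (the default '<unk>' is always present).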
    vocab = Vocab(counter,
                  bos_token=Vocab.BOS_TOKEN,
                  eos_token=Vocab.EOS_TOKEN,
                  pad_token=Vocab.PAD_TOKEN)
    check_consistency(vocab)
    assert vocab.all_tokens == [
        'd', 'a', 'c', 'b', '<unk>', '<bos>', '<eos>', '<pad>'
    ]

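    # Arbitrary '*_token' keyword arguments (here other2_token and
    # other3_token) define user-named special tokens, and built-in ones
    # can be given custom surface forms such as '<mask2>'.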
    vocab = Vocab(counter,
                  max_size=None,
                  min_freq=None,
                  pad_token=Vocab.PAD_TOKEN,
                  eos_token=Vocab.EOS_TOKEN,
                  bos_token=Vocab.BOS_TOKEN,
                  mask_token='<mask2>',
                  other3_token='<other3>',
                  other2_token='<other2>')
    check_consistency(vocab)
    assert vocab.all_tokens == [
        'd', 'a', 'c', 'b', '<unk>', '<bos>', '<eos>', '<mask2>', '<other2>',
        '<other3>', '<pad>'
    ]
    assert vocab.mask_token == '<mask2>'
    assert vocab.other2_token == '<other2>'
    assert vocab.other3_token == '<other3>'
    assert vocab.special_token_keys == [
        'unk_token', 'bos_token', 'eos_token', 'mask_token', 'other2_token',
        'other3_token', 'pad_token'
    ]
    assert vocab.special_tokens == [
        '<unk>', '<bos>', '<eos>', '<mask2>', '<other2>', '<other3>', '<pad>'
    ]

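    # A min_freq threshold no token reaches, with unk_token=None to
    # suppress the default '<unk>', yields an empty vocabulary.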
    vocab = Vocab(counter, max_size=1, min_freq=10000, unk_token=None)
    check_consistency(vocab)
    assert vocab.all_tokens == []

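    # With no data tokens (an empty list, or no tokens argument at all),
    # the vocabulary holds just the special tokens.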
    vocab = Vocab([],
                  pad_token=Vocab.PAD_TOKEN,
                  eos_token=Vocab.EOS_TOKEN,
                  bos_token=Vocab.BOS_TOKEN,
                  mask_token='<mask2>')
    check_consistency(vocab)
    assert vocab.all_tokens == ['<unk>', '<bos>', '<eos>', '<mask2>', '<pad>']
    vocab = Vocab(pad_token=Vocab.PAD_TOKEN,
                  eos_token=Vocab.EOS_TOKEN,
                  bos_token=Vocab.BOS_TOKEN,
                  mask_token='<mask2>')
    check_consistency(vocab)
    assert vocab.all_tokens == ['<unk>', '<bos>', '<eos>', '<mask2>', '<pad>']

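    # Special tokens already present in the token list keep their original
    # positions; only the missing default '<unk>' is appended, and the key
    # order follows token position.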
    vocab = Vocab(['<unk2>', '<pad>', '<bos>', '<eos>', '<mask>', 'a'],
                  pad_token=Vocab.PAD_TOKEN,
                  eos_token=Vocab.EOS_TOKEN,
                  bos_token=Vocab.BOS_TOKEN,
                  mask_token='<mask>')
    check_consistency(vocab)
    assert vocab.all_tokens == [
        '<unk2>', '<pad>', '<bos>', '<eos>', '<mask>', 'a', '<unk>'
    ]
    assert vocab.special_tokens == [
        '<pad>', '<bos>', '<eos>', '<mask>', '<unk>'
    ]
    assert vocab.special_token_keys == [
        'pad_token', 'bos_token', 'eos_token', 'mask_token', 'unk_token'
    ]

    # Check errors
    with pytest.raises(ValueError):
        # Duplicate tokens in the input list are rejected.
        vocab = Vocab(['a', 'a', 'a'])
    with pytest.raises(ValueError):
        # Two special tokens sharing the same surface form are rejected.
        vocab = Vocab(['a', 'b', 'c'],
                      mask_token='<mask>',
                      another_mask_token='<mask>')
    vocab = Vocab(['a', 'b', 'c'])
    check_consistency(vocab)

    # Check that non-ASCII tokens (emoji) survive a save/load round-trip
    all_tokens = ['<unk>', '😁']
    vocab = Vocab(all_tokens, unk_token='<unk>')
    vocab_file = str(uuid.uuid4()) + '.vocab'
    vocab.save(vocab_file)
    vocab = Vocab.load(vocab_file)
    assert vocab.all_tokens == all_tokens
    os.remove(vocab_file)
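
A minimal usage sketch distilled from the calls the test exercises, under the same Vocab import assumption as above; the sample data is illustrative only:

# Minimal usage sketch based on the calls exercised by the test above;
# assumes the same Vocab class and import. Sample data is illustrative.
counter = collections.Counter(['hello', 'world', 'hello'])
vocab = Vocab(counter)                       # 'hello', 'world', '<unk>'
ids = vocab[['hello', 'world']]              # tokens -> indices
tokens = vocab.to_tokens(ids)                # indices -> tokens
restored = Vocab.from_json(vocab.to_json())  # JSON round-trip
assert restored.all_tokens == vocab.all_tokens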