Example #1
import numpy as np

import gluonnlp.data.transforms as t
from gluonnlp import Vocab
from gluonnlp.data import count_tokens
from gluonnlp.vocab import BERTVocab


def test_berttokenizer():

    # test WordpieceTokenizer
    vocab_tokens = ["want", "##want", "##ed", "wa", "un", "runn", "##ing"]
    vocab = Vocab(
        count_tokens(vocab_tokens),
        reserved_tokens=["[CLS]", "[SEP]"],
        unknown_token="[UNK]", padding_token=None, bos_token=None, eos_token=None)
    tokenizer = t.BERTTokenizer(vocab=vocab)

    assert tokenizer(u"unwanted running") == [
        "un", "##want", "##ed", "runn", "##ing"]
    assert tokenizer(u"unwantedX running") == ["[UNK]", "runn", "##ing"]
    assert tokenizer.is_first_subword('un')
    assert not tokenizer.is_first_subword('##want')

    # test BERTTokenizer
    vocab_tokens = ["[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
                    "##ing", ","]

    vocab = Vocab(
        count_tokens(vocab_tokens),
        reserved_tokens=["[CLS]", "[SEP]"],
        unknown_token="[UNK]", padding_token=None, bos_token=None, eos_token=None)
    tokenizer = t.BERTTokenizer(vocab=vocab)
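    # the default BERTTokenizer lower-cases, strips accents and splits punctuation,
    # so "UNwant\u00E9d,running" becomes "unwanted", ",", "running" before wordpiece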
    tokens = tokenizer(u"UNwant\u00E9d,running")
    assert tokens == ["un", "##want", "##ed", ",", "runn", "##ing"]
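
The wordpiece assertions in this example exercise greedy longest-match-first WordPiece splitting: the tokenizer repeatedly takes the longest prefix of the remaining word that is in the vocabulary, marking non-initial pieces with "##", and maps the whole word to "[UNK]" when no prefix matches. A minimal sketch of that procedure, independent of the gluonnlp implementation (wordpiece_split and the vocabulary set below are illustrative only):

def wordpiece_split(word, vocab, unk='[UNK]'):
    # greedy longest-match-first: peel off the longest known prefix at each step
    pieces, start = [], 0
    while start < len(word):
        end, cur = len(word), None
        while start < end:
            piece = word[start:end]
            if start > 0:
                piece = '##' + piece
            if piece in vocab:
                cur = piece
                break
            end -= 1
        if cur is None:
            return [unk]  # no prefix matched: the whole word becomes [UNK]
        pieces.append(cur)
        start = end
    return pieces

vocab = {'want', '##want', '##ed', 'wa', 'un', 'runn', '##ing'}
assert wordpiece_split('unwanted', vocab) == ['un', '##want', '##ed']
assert wordpiece_split('unwantedx', vocab) == ['[UNK]']
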
Example #2
def test_bert_sentences_transform():
    text_a = u'is this jacksonville ?'
    text_b = u'no it is not'
    vocab_tokens = ['is', 'this', 'jack', '##son', '##ville', '?', 'no', 'it', 'is', 'not']

    bert_vocab = BERTVocab(count_tokens(vocab_tokens))
    tokenizer = t.BERTTokenizer(vocab=bert_vocab)

    # test BERTSentenceTransform
    bert_st = t.BERTSentenceTransform(tokenizer, 15, pad=True, pair=True)
    token_ids, length, type_ids = bert_st((text_a, text_b))

    text_a_tokens = ['is', 'this', 'jack', '##son', '##ville', '?']
    text_b_tokens = ['no', 'it', 'is', 'not']
    text_a_ids = bert_vocab[text_a_tokens]
    text_b_ids = bert_vocab[text_b_tokens]

    cls_ids = bert_vocab[[bert_vocab.cls_token]]
    sep_ids = bert_vocab[[bert_vocab.sep_token]]
    pad_ids = bert_vocab[[bert_vocab.padding_token]]

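    # expected layout: [CLS] text_a [SEP] text_b [SEP], padded with [PAD] up to max_seq_length=15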
    concated_ids = cls_ids + text_a_ids + sep_ids + text_b_ids + sep_ids + pad_ids
    valid_token_ids = np.array([pad_ids[0]] * 15, dtype=np.int32)
    for i, x in enumerate(concated_ids):
        valid_token_ids[i] = x
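    # segment (type) ids: 0 over "[CLS] text_a [SEP]" and padding, 1 over "text_b [SEP]"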
    valid_type_ids = np.zeros((15,), dtype=np.int32)
    start = len(text_a_tokens) + 2
    end = len(text_a_tokens) + 2 + len(text_b_tokens) + 1
    valid_type_ids[start:end] = 1

    assert all(token_ids == valid_token_ids)
    assert length == len(vocab_tokens) + 3
    assert all(type_ids == valid_type_ids)
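
The expected arrays follow standard BERT sentence-pair packing: token ids are [CLS] + A + [SEP] + B + [SEP] padded with [PAD] to the fixed length, the valid length counts only the unpadded positions, and segment ids are 0 over the first segment and 1 over the second. A minimal sketch of that packing, independent of the gluonnlp API (build_pair_input and the literal id values are illustrative only):

def build_pair_input(cls_id, sep_id, pad_id, a_ids, b_ids, max_len):
    # [CLS] A [SEP] B [SEP], then pad both sequences to max_len
    ids = [cls_id] + a_ids + [sep_id] + b_ids + [sep_id]
    valid_length = len(ids)
    segment_ids = [0] * (len(a_ids) + 2) + [1] * (len(b_ids) + 1)
    ids += [pad_id] * (max_len - len(ids))
    segment_ids += [0] * (max_len - len(segment_ids))
    return ids, valid_length, segment_ids

ids, valid_length, segment_ids = build_pair_input(2, 3, 1, [4] * 6, [5] * 4, 15)
assert valid_length == 13
assert segment_ids[8:13] == [1] * 5 and segment_ids[13:] == [0, 0]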