import numpy as np

# Imports assume the GluonNLP package layout these tests were written against.
from gluonnlp import Vocab
from gluonnlp.data import count_tokens
from gluonnlp.data import transforms as t
from gluonnlp.vocab import BERTVocab


def test_berttokenizer():
    # test WordpieceTokenizer
    vocab_tokens = ["want", "##want", "##ed", "wa", "un", "runn", "##ing"]

    vocab = Vocab(
        count_tokens(vocab_tokens),
        reserved_tokens=["[CLS]", "[SEP]"],
        unknown_token="[UNK]",
        padding_token=None,
        bos_token=None,
        eos_token=None)
    tokenizer = t.BERTTokenizer(vocab=vocab)

    assert tokenizer(u"unwanted running") == [
        "un", "##want", "##ed", "runn", "##ing"]
    # 'unwantedX' cannot be fully covered by vocabulary pieces, so the
    # whole word maps to the unknown token.
    assert tokenizer(u"unwantedX running") == ["[UNK]", "runn", "##ing"]

    assert tokenizer.is_first_subword('un')
    assert not tokenizer.is_first_subword('##want')

    # test BERTTokenizer
    vocab_tokens = ["[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un",
                    "runn", "##ing", ","]

    vocab = Vocab(
        count_tokens(vocab_tokens),
        reserved_tokens=["[CLS]", "[SEP]"],
        unknown_token="[UNK]",
        padding_token=None,
        bos_token=None,
        eos_token=None)
    tokenizer = t.BERTTokenizer(vocab=vocab)

    # Text is lower-cased and accents are stripped before wordpiece splitting.
    tokens = tokenizer(u"UNwant\u00E9d,running")
    assert tokens == ["un", "##want", "##ed", ",", "runn", "##ing"]
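
# A minimal sketch of the greedy longest-match-first WordPiece split that the
# assertions above exercise. This is an illustrative reimplementation, not
# GluonNLP's actual tokenizer: a word is consumed by repeatedly taking the
# longest vocabulary prefix (non-initial pieces carry a '##' marker); if some
# suffix cannot be matched, the whole word becomes the unknown token.
def _wordpiece_sketch(word, vocab_set, unknown_token='[UNK]'):
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Shrink the candidate span until it matches a vocabulary piece.
        while start < end:
            piece = word[start:end]
            if start > 0:
                piece = '##' + piece
            if piece in vocab_set:
                match = piece
                break
            end -= 1
        if match is None:
            # e.g. 'unwantedX': the trailing 'X' has no piece in the vocab.
            return [unknown_token]
        pieces.append(match)
        start = end
    return pieces

# _wordpiece_sketch('unwanted',
#                   {'want', '##want', '##ed', 'wa', 'un', 'runn', '##ing'})
# -> ['un', '##want', '##ed']
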
def test_bert_sentences_transform():
    text_a = u'is this jacksonville ?'
    text_b = u'no it is not'
    vocab_tokens = ['is', 'this', 'jack', '##son', '##ville', '?', 'no',
                    'it', 'is', 'not']

    bert_vocab = BERTVocab(count_tokens(vocab_tokens))
    tokenizer = t.BERTTokenizer(vocab=bert_vocab)

    # test BERTSentenceTransform
    bert_st = t.BERTSentenceTransform(tokenizer, 15, pad=True, pair=True)
    token_ids, length, type_ids = bert_st((text_a, text_b))

    text_a_tokens = ['is', 'this', 'jack', '##son', '##ville', '?']
    text_b_tokens = ['no', 'it', 'is', 'not']
    text_a_ids = bert_vocab[text_a_tokens]
    text_b_ids = bert_vocab[text_b_tokens]

    cls_ids = bert_vocab[[bert_vocab.cls_token]]
    sep_ids = bert_vocab[[bert_vocab.sep_token]]
    pad_ids = bert_vocab[[bert_vocab.padding_token]]

    # Expected layout: [CLS] text_a [SEP] text_b [SEP], padded to length 15.
    concatenated_ids = (cls_ids + text_a_ids + sep_ids + text_b_ids
                        + sep_ids + pad_ids)
    valid_token_ids = np.array([pad_ids[0]] * 15, dtype=np.int32)
    for i, x in enumerate(concatenated_ids):
        valid_token_ids[i] = x

    # Segment (type) ids: 0 over '[CLS] text_a [SEP]' and the padding,
    # 1 over 'text_b [SEP]'.
    valid_type_ids = np.zeros((15,), dtype=np.int32)
    start = len(text_a_tokens) + 2
    end = len(text_a_tokens) + 2 + len(text_b_tokens) + 1
    valid_type_ids[start:end] = 1

    assert all(token_ids == valid_token_ids)
    # Valid length covers both sentences plus [CLS] and two [SEP] tokens.
    assert length == len(text_a_tokens) + len(text_b_tokens) + 3
    assert all(type_ids == valid_type_ids)
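
# A minimal sketch of the pair layout checked above. This is a hypothetical
# helper mirroring the semantics the assertions assume for
# BERTSentenceTransform, not the library's implementation: tokens are arranged
# as '[CLS] tokens_a [SEP] tokens_b [SEP]' and padded to max_len; segment ids
# are 0 over '[CLS] tokens_a [SEP]' and the padding, 1 over 'tokens_b [SEP]'.
# The special-token defaults are the usual BERT conventions, assumed here.
def _pair_layout_sketch(tokens_a, tokens_b, max_len,
                        cls='[CLS]', sep='[SEP]', pad='[PAD]'):
    tokens = [cls] + tokens_a + [sep] + tokens_b + [sep]
    type_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    valid_length = len(tokens)
    # Pad both sequences out to the fixed length; padding keeps type id 0.
    tokens = tokens + [pad] * (max_len - valid_length)
    type_ids = type_ids + [0] * (max_len - valid_length)
    return tokens, valid_length, type_ids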