def test_bert_sentences_transform():
    """BERTSentenceTransform on a sentence pair: verify token ids, valid
    length, and segment (type) ids, with right-padding to a fixed length."""
    text_a = u'is this jacksonville ?'
    text_b = u'no it is not'
    vocab_tokens = ['is', 'this', 'jack', '##son', '##ville', '?', 'no',
                    'it', 'is', 'not']

    bert_vocab = BERTVocab(count_tokens(vocab_tokens))
    tokenizer = t.BERTTokenizer(vocab=bert_vocab)

    # test BERTSentenceTransform
    max_len = 15  # single source of truth for the padded sequence length
    bert_st = t.BERTSentenceTransform(tokenizer, max_len, pad=True, pair=True)
    token_ids, length, type_ids = bert_st((text_a, text_b))

    # Expected wordpiece tokenization of each sentence.
    text_a_tokens = ['is', 'this', 'jack', '##son', '##ville', '?']
    text_b_tokens = ['no', 'it', 'is', 'not']
    text_a_ids = bert_vocab[text_a_tokens]
    text_b_ids = bert_vocab[text_b_tokens]

    cls_ids = bert_vocab[[bert_vocab.cls_token]]
    sep_ids = bert_vocab[[bert_vocab.sep_token]]
    pad_ids = bert_vocab[[bert_vocab.padding_token]]

    # Expected layout: [CLS] text_a [SEP] text_b [SEP], right-padded to max_len.
    concatenated_ids = cls_ids + text_a_ids + sep_ids + text_b_ids + sep_ids
    valid_token_ids = np.array(
        concatenated_ids + pad_ids * (max_len - len(concatenated_ids)),
        dtype=np.int32)

    # Segment 1 covers text_b plus its trailing [SEP]; everything else
    # (including the padding) is segment 0.
    valid_type_ids = np.zeros((max_len,), dtype=np.int32)
    start = len(text_a_tokens) + 2
    end = start + len(text_b_tokens) + 1
    valid_type_ids[start:end] = 1

    assert all(token_ids == valid_token_ids)
    # Valid length = [CLS] + text_a + [SEP] + text_b + [SEP]. (The original
    # compared against len(vocab_tokens) + 3, which only holds because the
    # vocab list happens to equal the concatenated subword tokens.)
    assert length == len(text_a_tokens) + len(text_b_tokens) + 3
    assert all(type_ids == valid_type_ids)
def test_bert_sentencepiece_sentences_transform():
    """BERTSentenceTransform with a sentencepiece-based BERTSPTokenizer on a
    single (unpaired) sentence: verify token ids, valid length, segment ids."""
    url = 'http://repo.mxnet.io/gluon/dataset/vocab/test-682b5d15.bpe'
    with warnings.catch_warnings():
        # UserWarning: File test-682b5d15.bpe exists in file system so the downloaded file is deleted
        warnings.simplefilter("ignore")
        sp_model = download(url, overwrite=True)
    bert_vocab = BERTVocab.from_sentencepiece(sp_model)
    bert_tokenizer = t.BERTSPTokenizer(sp_model, bert_vocab, lower=True)

    # In sentencepiece output, '▁' marks a token that starts a new word.
    assert bert_tokenizer.is_first_subword(u'▁this')
    assert not bert_tokenizer.is_first_subword(u'this')

    max_len = 36
    data_train_raw = SimpleDataset(
        [[u'This is a very awesome, life-changing sentence.']])
    transform = t.BERTSentenceTransform(bert_tokenizer, max_len, pad=True,
                                        pair=False)
    try:
        data_train = data_train_raw.transform(transform)
    except ImportError:
        warnings.warn(
            "Sentencepiece not installed, skip test_bert_sentencepiece_sentences_transform()."
        )
        return
    processed = list(data_train)[0]

    # Expected sentencepiece tokenization of the lowercased input sentence.
    tokens = [
        u'▁this', u'▁is', u'▁a', u'▁very', u'▁a', u'w', u'es', u'om', u'e',
        u'▁', u',', u'▁life', u'▁', u'-', u'▁c', u'hang', u'ing',
        u'▁sentence', u'▁', u'.'
    ]
    cls_id = bert_vocab[bert_vocab.cls_token]
    sep_id = bert_vocab[bert_vocab.sep_token]
    pad_id = bert_vocab[bert_vocab.padding_token]
    expected_ids = ([cls_id]
                    + bert_tokenizer.convert_tokens_to_ids(tokens)
                    + [sep_id])
    expected_ids.extend([pad_id] * (max_len - len(expected_ids)))

    # token ids
    assert all(processed[0] == np.array(expected_ids, dtype='int32'))
    # sequence length: the tokens plus [CLS] and [SEP]
    assert processed[1].item() == len(tokens) + 2
    # segment id: a single sentence is entirely segment 0
    assert all(processed[2] == np.array([0] * max_len, dtype='int32'))