# Imports assumed for this excerpt (GluonNLP 0.x test suite with an MXNet backend);
# module paths may need adjusting to match the surrounding test file.
import warnings

import numpy as np
from mxnet.gluon.data import SimpleDataset
from mxnet.gluon.utils import download

import gluonnlp.data.transforms as t
from gluonnlp.data import count_tokens
from gluonnlp.vocab import BERTVocab


def test_bert_sentences_transform():
    text_a = u'is this jacksonville ?'
    text_b = u'no it is not'
    vocab_tokens = ['is', 'this', 'jack', '##son', '##ville', '?', 'no', 'it', 'is', 'not']

    bert_vocab = BERTVocab(count_tokens(vocab_tokens))
    tokenizer = t.BERTTokenizer(vocab=bert_vocab)
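    # BERTTokenizer applies greedy longest-match-first WordPiece splitting against
    # the vocab, so 'jacksonville' becomes 'jack', '##son', '##ville'.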

    # test BERTSentenceTransform
    bert_st = t.BERTSentenceTransform(tokenizer, 15, pad=True, pair=True)
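    # With pair=True the transform builds [CLS] text_a [SEP] text_b [SEP], pads to
    # max_seq_length=15, and returns (token_ids, valid_length, segment_ids).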
    token_ids, length, type_ids = bert_st((text_a, text_b))

    text_a_tokens = ['is', 'this', 'jack', '##son', '##ville', '?']
    text_b_tokens = ['no', 'it', 'is', 'not']
    text_a_ids = bert_vocab[text_a_tokens]
    text_b_ids = bert_vocab[text_b_tokens]

    cls_ids = bert_vocab[[bert_vocab.cls_token]]
    sep_ids = bert_vocab[[bert_vocab.sep_token]]
    pad_ids = bert_vocab[[bert_vocab.padding_token]]

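    # Expected sequence: [CLS] text_a [SEP] text_b [SEP], then [PAD] up to length 15.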
    concated_ids = cls_ids + text_a_ids + sep_ids + text_b_ids + sep_ids + pad_ids
    valid_token_ids = np.array([pad_ids[0]] * 15, dtype=np.int32)
    for i, x in enumerate(concated_ids):
        valid_token_ids[i] = x
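    # Segment (token type) ids: 0 for [CLS] + text_a + [SEP] and for padding,
    # 1 for text_b and its trailing [SEP].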
    valid_type_ids = np.zeros((15,), dtype=np.int32)
    start = len(text_a_tokens) + 2
    end = len(text_a_tokens) + 2 + len(text_b_tokens) + 1
    valid_type_ids[start:end] = 1

    assert all(token_ids == valid_token_ids)
    assert length == len(text_a_tokens) + len(text_b_tokens) + 3  # [CLS] + 2 * [SEP]
    assert all(type_ids == valid_type_ids)


def test_bert_sentencepiece_sentences_transform():
    url = 'http://repo.mxnet.io/gluon/dataset/vocab/test-682b5d15.bpe'
    with warnings.catch_warnings():
        # UserWarning: File test-682b5d15.bpe exists in file system so the downloaded file is deleted
        warnings.simplefilter("ignore")
        f = download(url, overwrite=True)
    bert_vocab = BERTVocab.from_sentencepiece(f)
    bert_tokenizer = t.BERTSPTokenizer(f, bert_vocab, lower=True)
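    # SentencePiece prefixes word-initial subwords with u'\u2581' ('▁'),
    # which is what is_first_subword checks for.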
    assert bert_tokenizer.is_first_subword(u'▁this')
    assert not bert_tokenizer.is_first_subword(u'this')
    max_len = 36
    data_train_raw = SimpleDataset(
        [[u'This is a very awesome, life-changing sentence.']])
    transform = t.BERTSentenceTransform(bert_tokenizer,
                                        max_len,
                                        pad=True,
                                        pair=False)
    try:
        data_train = data_train_raw.transform(transform)
    except ImportError:
        warnings.warn(
            "Sentencepiece not installed, skip test_bert_sentencepiece_sentences_transform()."
        )
        return
    processed = list(data_train)[0]

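    # Expected SentencePiece tokenization of the lowercased input sentence.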
    tokens = [
        u'▁this', u'▁is', u'▁a', u'▁very', u'▁a', u'w', u'es', u'om', u'e',
        u'▁', u',', u'▁life', u'▁', u'-', u'▁c', u'hang', u'ing', u'▁sentence',
        u'▁', u'.'
    ]
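    # Expected ids: [CLS] + subword ids + [SEP], padded with [PAD] to max_len.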
    token_ids = ([bert_vocab[bert_vocab.cls_token]]
                 + bert_tokenizer.convert_tokens_to_ids(tokens)
                 + [bert_vocab[bert_vocab.sep_token]])
    token_ids += [bert_vocab[bert_vocab.padding_token]] * (max_len - len(token_ids))

    # token ids
    assert all(processed[0] == np.array(token_ids, dtype='int32'))
    # sequence length
    assert processed[1].item() == len(tokens) + 2
    # segment id
    assert all(processed[2] == np.array([0] * max_len, dtype='int32'))