Example #1
def test_bert_sentences_transform():
    text_a = u'is this jacksonville ?'
    text_b = u'no it is not'
    vocab_tokens = ['is', 'this', 'jack', '##son', '##ville', '?', 'no', 'it', 'is', 'not']

    bert_vocab = BERTVocab(count_tokens(vocab_tokens))
    tokenizer = t.BERTTokenizer(vocab=bert_vocab)

    # test BERTSentenceTransform
    bert_st = t.BERTSentenceTransform(tokenizer, 15, pad=True, pair=True)
    token_ids, length, type_ids = bert_st((text_a, text_b))

    text_a_tokens = ['is', 'this', 'jack', '##son', '##ville', '?']
    text_b_tokens = ['no', 'it', 'is', 'not']
    text_a_ids = bert_vocab[text_a_tokens]
    text_b_ids = bert_vocab[text_b_tokens]

    cls_ids = bert_vocab[[bert_vocab.cls_token]]
    sep_ids = bert_vocab[[bert_vocab.sep_token]]
    pad_ids = bert_vocab[[bert_vocab.padding_token]]

    concated_ids = cls_ids + text_a_ids + sep_ids + text_b_ids + sep_ids + pad_ids
    valid_token_ids = np.array([pad_ids[0]] * 15, dtype=np.int32)
    for i, x in enumerate(concated_ids):
        valid_token_ids[i] = x
    valid_type_ids = np.zeros((15,), dtype=np.int32)
    start = len(text_a_tokens) + 2
    end = len(text_a_tokens) + 2 + len(text_b_tokens) + 1
    valid_type_ids[start:end] = 1

    assert all(token_ids == valid_token_ids)
    assert length == len(vocab_tokens) + 3
    assert all(type_ids == valid_type_ids)
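
These test snippets omit their imports; a minimal sketch of what they assume (GluonNLP 0.x / MXNet-era API; exact module paths may differ between versions):

# Hedged sketch of the imports assumed by the test snippets in these examples.
import numpy as np
from gluonnlp.data import count_tokens
from gluonnlp.data import transforms as t      # t.BERTTokenizer, t.BERTSentenceTransform
from gluonnlp.data.transforms import BERTTokenizer
from gluonnlp.vocab import BERTVocab
# Note: BERTDatasetTransform is not part of the core package; it ships with GluonNLP's
# BERT example scripts, so the next example assumes it has been imported from there.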

Example #2
def test_bert_dataset_transform():
    text_a = u'is this jacksonville ?'
    text_b = u'no it is not'
    label_cls = 0
    vocab_tokens = [
        'is', 'this', 'jack', '##son', '##ville', '?', 'no', 'it', 'is', 'not'
    ]

    bert_vocab = BERTVocab(count_tokens(vocab_tokens))
    tokenizer = BERTTokenizer(vocab=bert_vocab)

    # test BERTDatasetTransform for classification task
    bert_cls_dataset_t = BERTDatasetTransform(tokenizer,
                                              15,
                                              labels=[label_cls],
                                              pad=True,
                                              pair=True,
                                              label_dtype='int32')
    token_ids, length, type_ids, label_ids = bert_cls_dataset_t(
        (text_a, text_b, label_cls))

    text_a_tokens = ['is', 'this', 'jack', '##son', '##ville', '?']
    text_b_tokens = ['no', 'it', 'is', 'not']
    text_a_ids = bert_vocab[text_a_tokens]
    text_b_ids = bert_vocab[text_b_tokens]

    cls_ids = bert_vocab[[bert_vocab.cls_token]]
    sep_ids = bert_vocab[[bert_vocab.sep_token]]
    pad_ids = bert_vocab[[bert_vocab.padding_token]]

    concated_ids = cls_ids + text_a_ids + sep_ids + text_b_ids + sep_ids + pad_ids
    valid_token_ids = np.array([pad_ids[0]] * 15, dtype=np.int32)
    for i, x in enumerate(concated_ids):
        valid_token_ids[i] = x
    valid_type_ids = np.zeros((15, ), dtype=np.int32)
    start = len(text_a_tokens) + 2
    end = len(text_a_tokens) + 2 + len(text_b_tokens) + 1
    valid_type_ids[start:end] = 1

    assert all(token_ids == valid_token_ids)
    assert length == len(vocab_tokens) + 3
    assert all(type_ids == valid_type_ids)
    assert all(label_ids == np.array([label_cls], dtype=np.int32))

    # test BERTDatasetTransform for regression task
    label_reg = 0.2
    bert_reg_dataset_t = BERTDatasetTransform(tokenizer,
                                              15,
                                              pad=True,
                                              pair=True,
                                              label_dtype='float32')
    token_ids, length, type_ids, label_reg_val = bert_reg_dataset_t(
        (text_a, text_b, label_reg))
    assert all(token_ids == valid_token_ids)
    assert length == len(vocab_tokens) + 3
    assert all(type_ids == valid_type_ids)
    assert all(label_reg_val == np.array([label_reg], dtype=np.float32))
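
In practice the transform is applied lazily over a whole dataset rather than called on a single tuple. A minimal usage sketch, assuming the same GluonNLP/MXNet 0.x API and the objects defined in the test above:

# Hypothetical usage: map the classification transform over a small dataset.
from mxnet.gluon.data import SimpleDataset

raw = SimpleDataset([(text_a, text_b, label_cls)])
transformed = raw.transform(bert_cls_dataset_t)  # applied per sample, lazily
token_ids, length, type_ids, label_ids = list(transformed)[0]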
Example #3
def test_bert_sentencepiece_sentences_transform():
    url = 'http://repo.mxnet.io/gluon/dataset/vocab/test-682b5d15.bpe'
    with warnings.catch_warnings():
        # UserWarning: File test-682b5d15.bpe exists in file system so the downloaded file is deleted
        warnings.simplefilter("ignore")
        f = download(url, overwrite=True)
    bert_vocab = BERTVocab.from_sentencepiece(f)
    bert_tokenizer = t.BERTSPTokenizer(f, bert_vocab, lower=True)
    assert bert_tokenizer.is_first_subword(u'▁this')
    assert not bert_tokenizer.is_first_subword(u'this')
    max_len = 36
    data_train_raw = SimpleDataset(
        [[u'This is a very awesome, life-changing sentence.']])
    transform = t.BERTSentenceTransform(bert_tokenizer,
                                        max_len,
                                        pad=True,
                                        pair=False)
    try:
        data_train = data_train_raw.transform(transform)
    except ImportError:
        warnings.warn(
            "Sentencepiece not installed, skip test_bert_sentencepiece_sentences_transform()."
        )
        return
    processed = list(data_train)[0]

    tokens = [
        u'▁this', u'▁is', u'▁a', u'▁very', u'▁a', u'w', u'es', u'om', u'e',
        u'▁', u',', u'▁life', u'▁', u'-', u'▁c', u'hang', u'ing', u'▁sentence',
        u'▁', u'.'
    ]
    token_ids = ([bert_vocab[bert_vocab.cls_token]]
                 + bert_tokenizer.convert_tokens_to_ids(tokens)
                 + [bert_vocab[bert_vocab.sep_token]])
    token_ids += [bert_vocab[bert_vocab.padding_token]] * (max_len - len(token_ids))

    # token ids
    assert all(processed[0] == np.array(token_ids, dtype='int32'))
    # sequence length
    assert processed[1].item() == len(tokens) + 2
    # segment id
    assert all(processed[2] == np.array([0] * max_len, dtype='int32'))
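
The sentencepiece tokenizer can also be called directly on raw text, which is roughly what BERTSentenceTransform does before adding the special tokens; a brief sketch under the same API assumptions:

# Calling the tokenizer directly yields the subword pieces checked above.
pieces = bert_tokenizer(u'This is a very awesome, life-changing sentence.')
piece_ids = bert_tokenizer.convert_tokens_to_ids(pieces)
assert pieces[0] == u'▁this'  # lower=True, so the text is lowercased first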
Example #4
        ptr_vocab_path = ptr_dir / 'pytorch_model_skt_vocab.json'
        ptr_tokenizer_path = ptr_dir / 'pytorch_model_skt_tokenizer.model'

        if not ptr_bert_path.exists():
            urlretrieve('https://kobert.blob.core.windows.net/models/kobert/pytorch/pytorch_kobert_2439f391a6.params',
                        filename=ptr_bert_path)
            ptr_bert = torch.load(ptr_bert_path)
            ptr_bert = OrderedDict(('bert.' + k, v) for k, v in ptr_bert.items())
            torch.save(ptr_bert, ptr_bert_path)
        else:
            print('You already have pytorch_model_skt.bin!')

        if not ptr_vocab_path.exists():
            urlretrieve('https://kobert.blob.core.windows.net/models/kobert/vocab/kobertvocab_f38b8a4d6d.json',
                        filename=ptr_vocab_path)
            ptr_bert_vocab = BERTVocab.from_json(ptr_vocab_path.open(mode='rt').read())
            vocab = Vocab(ptr_bert_vocab.idx_to_token,
                          padding_token="[PAD]",
                          unknown_token="[UNK]",
                          bos_token=None,
                          eos_token=None,
                          reserved_tokens=["[CLS]", "[SEP]", "[MASK]"],
                          token_to_idx=ptr_bert_vocab.token_to_idx)

            # save vocab
            with open(ptr_vocab_path.with_suffix('.pkl'), mode="wb") as io:
                pickle.dump(vocab, io)
        else:
            print('You already have pytorch_model_skt_vocab.json!')

        if not ptr_tokenizer_path.exists():