Example #1

import logging
import unittest

from kashgari import macros as k  # assumption: PAD/BOS/EOS/UNK constants live in kashgari.macros
from kashgari.embeddings import BERTEmbedding

SEQUENCE_LENGTH = 100  # assumption: the original constant is defined outside this snippet

class BertEmbeddingsTest(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super(BertEmbeddingsTest, self).__init__(*args, **kwargs)
        self.embedding = BERTEmbedding('chinese_L-12_H-768_A-12',
                                       sequence_length=SEQUENCE_LENGTH)

    def test_build(self):
        self.assertGreater(self.embedding.embedding_size, 0)
        self.assertEqual(self.embedding.token2idx[k.PAD], 0)
        self.assertGreater(self.embedding.token2idx[k.BOS], 0)
        self.assertGreater(self.embedding.token2idx[k.EOS], 0)
        self.assertGreater(self.embedding.token2idx[k.UNK], 0)

    def test_tokenize(self):
        sentence = ['我', '想', '看', '电影', '%%##!$#%']
        tokens = self.embedding.tokenize(sentence)

        logging.info('tokenize test: {} -> {}'.format(sentence, tokens))
        token_list = self.embedding.tokenize([sentence])
        self.assertEqual(len(token_list[0]), len(sentence)+2)

    def test_embed(self):
        sentence = ['我', '想', '看', '电影', '%%##!$#%']
        embedded_sentence = self.embedding.embed(sentence)
        logging.info('embed test: {} -> {}'.format(sentence, embedded_sentence))
        self.assertEqual(embedded_sentence.shape, (SEQUENCE_LENGTH, self.embedding.embedding_size))

        embedded_sentences = self.embedding.embed([sentence])
        self.assertEqual(embedded_sentences.shape, (1, SEQUENCE_LENGTH, self.embedding.embedding_size))
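A minimal entry point makes the test file runnable on its own (a sketch; the original runner is not shown in this snippet):

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    unittest.main()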
Example #2

import numpy as np

from kashgari.embeddings import BERTEmbedding

# assumption: the snippet's opening line was cut off; `b` is the BERTEmbedding
# instance used throughout the rest of the example
b = BERTEmbedding(
    model_folder='/Users/xuming06/Codes/bert/data/chinese_L-12_H-768_A-12',
    sequence_length=12)

# from kashgari.corpus import SMP2018ECDTCorpus

# test_x, test_y = SMP2018ECDTCorpus.load_data('valid')

# b.analyze_corpus(test_x, test_y)
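# character-split Chinese phrases (Chinese BERT vocabularies are character level):
# 湖北 (Hubei), 纽约 (New York), 武汉 (Wuhan), 北京 (Beijing), 武汉地铁 (Wuhan Metro)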
data1 = '湖 北'.split(' ')
data3 = '纽 约'.split(' ')
data2 = '武 汉'.split(' ')
data4 = '武 汉'.split(' ')
data5 = '北 京'.split(' ')
data6 = '武 汉 地 铁'.split(' ')
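# data2 is held back as the query; data4 is an identical copy kept in the corpus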
sents = [data1, data3, data4, data5, data6]
doc_vecs = b.embed(sents, debug=True)
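# with sequence_length=12 and a BERT-base model, doc_vecs should have shape (5, 12, 768)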

tokens = b.process_x_dataset([['语', '言', '模', '型']])[0]
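# process_x_dataset encodes the characters into padded BERT input ids;
# 101 ([CLS]) and 102 ([SEP]) wrap the four characters, zero-padded to length 12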
target_index = [101, 6427, 6241, 3563, 1798, 102]
target_index = target_index + [0] * (12 - len(target_index))
assert list(tokens[0]) == list(target_index)
print(tokens)
print(doc_vecs)
print(doc_vecs.shape)
print(doc_vecs[0])
print(doc_vecs[0][0])

query_vec = b.embed([data2])[0]
query = '武 汉'
# compute normalized dot product as score
for i, sent in enumerate(sents):
    # the loop body was missing from the snippet; this is one plausible completion:
    # cosine similarity between the flattened query and sentence embeddings
    score = np.dot(query_vec.flatten(), doc_vecs[i].flatten()) / (
        np.linalg.norm(query_vec) * np.linalg.norm(doc_vecs[i]))
    print('> %.4f\t%s' % (score, ' '.join(sent)))
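Given these inputs, the identical phrase 武汉 (data4) should rank first, with 武汉地铁 likely next since it contains the query as a prefix, while unrelated phrases such as 纽约 should score lowest.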