示例#1
0
 def load_model(self):
     """Create the embedding model on first use and store it on ``self.model``.

     Does nothing when ``self.model`` is already truthy. Otherwise the backend
     is chosen by ``self.embedding_type``:

     * ``EmbType.BERT`` -> ``BERTEmbedding(sequence_length=128)``
     * ``EmbType.W2V``  -> ``WordEmbedding()``

     The backend modules are imported inside the branch that needs them, so
     only the selected one is imported.

     Raises:
         ValueError: if ``self.embedding_type`` matches neither option.
     """
     if self.model:
         # A model is already attached; keep it.
         return

     kind = self.embedding_type

     if kind == EmbType.BERT:
         from text2vec.embeddings.bert_embedding import BERTEmbedding
         self.model = BERTEmbedding(sequence_length=128)
         return

     if kind == EmbType.W2V:
         from text2vec.embeddings.word_embedding import WordEmbedding
         self.model = WordEmbedding()
         return

     raise ValueError('set error embedding type.')
示例#2
0
 def load_model(self):
     """Create the configured embedding model unless one is already attached.

     When ``self.model`` is truthy this is a no-op. Otherwise
     ``self.embedding_type`` selects the backend, which is constructed from
     the settings stored on this instance:

     * ``EmbType.BERT`` -> ``BERTEmbedding`` built from ``bert_model_folder``,
       ``bert_layer_nums``, ``trainable``, ``sequence_length`` and
       ``processor``.
     * ``EmbType.W2V`` -> ``WordEmbedding`` built from ``w2v_path``,
       ``w2v_kwargs``, ``sequence_length``, ``processor`` and ``trainable``.

     Each backend module is imported only inside the branch that uses it.

     Raises:
         ValueError: if ``self.embedding_type`` matches neither option.
     """
     if self.model:
         # Already loaded; nothing to build.
         return

     selected = self.embedding_type
     if selected == EmbType.BERT:
         from text2vec.embeddings.bert_embedding import BERTEmbedding
         bert_options = dict(
             model_folder=self.bert_model_folder,
             layer_nums=self.bert_layer_nums,
             trainable=self.trainable,
             sequence_length=self.sequence_length,
             processor=self.processor,
         )
         self.model = BERTEmbedding(**bert_options)
     elif selected == EmbType.W2V:
         from text2vec.embeddings.word_embedding import WordEmbedding
         w2v_options = dict(
             w2v_path=self.w2v_path,
             w2v_kwargs=self.w2v_kwargs,
             sequence_length=self.sequence_length,
             processor=self.processor,
             trainable=self.trainable,
         )
         self.model = WordEmbedding(**w2v_options)
     else:
         raise ValueError('set error embedding type.')
 def setUpClass(cls):
     """Build one ``BERTEmbedding`` and store it on ``cls.embedding``.

     The embedding is constructed with ``sequence_length=SEQUENCE_LENGTH``;
     the import is done here rather than at module level.
     """
     from text2vec.embeddings.bert_embedding import BERTEmbedding

     shared_embedding = BERTEmbedding(sequence_length=SEQUENCE_LENGTH)
     cls.embedding = shared_embedding
示例#4
0
# -*- coding: utf-8 -*-
"""
@author:XuMing<*****@*****.**>
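@description: Demo script that embeds a tokenized sentence with BERTEmbedding and prints the processed token ids, the embedding and its shape.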
"""
from text2vec.embeddings.bert_embedding import BERTEmbedding

if __name__ == '__main__':
    # Demo: embed one tokenized sentence and inspect the processed token ids.
    embedder = BERTEmbedding()

    english_tokens = 'all work and no play makes'.split(' ')
    chinese_tokens = '你 好 啊'.split(' ')  # sample input, not used below
    embedded = embedder.embed([english_tokens], True)

    processed = embedder.process_x_dataset([['语', '言', '模', '型']])[0]

    # Reference ids for the four characters above, zero-padded to 12 entries.
    padded_length = 12
    reference_ids = [101, 6427, 6241, 3563, 1798, 102]
    target_index = reference_ids + [0] * (padded_length - len(reference_ids))

    # The two lists are printed side by side for manual comparison; the
    # equality check is intentionally not enforced here.
    print(list(processed[0]), list(target_index))
    print(processed)
    print(embedded)
    print(embedded.shape)