def test_get_file():
    _url = "https://raw.githubusercontent.com/shibing624/text2vec/master/LICENSE"
    # First download: extract into the user cache and check the expected line count.
    file_path = get_file(
        'LICENSE', _url, extract=True,
        cache_dir=text2vec.USER_DATA_DIR,
        cache_subdir='LICENSE',
        verbose=1
    )
    print("file_path:", file_path)
    num_lines = 201
    with open(file_path, 'rb') as f:
        assert len(f.readlines()) == num_lines

    # Second download: verify the cached file against its md5 checksum.
    file_hash = hash_file(file_path, algorithm='md5')
    file_path2 = get_file(
        'LICENSE', _url, extract=False,
        md5_hash=file_hash,
        cache_dir=text2vec.USER_DATA_DIR,
        cache_subdir='LICENSE',
        verbose=1
    )
    file_hash2 = hash_file(file_path2, algorithm='md5')
    assert file_hash == file_hash2

    # Clean up the cache directory created by this test.
    file_dir = text2vec.USER_DATA_DIR.joinpath('LICENSE')
    if os.path.exists(file_dir):
        shutil.rmtree(file_dir)
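# Hedged sketch (an assumption about behavior, not the library's actual code): hash_file
# above is used to verify download integrity; an md5 hex digest like the one it returns
# can be reproduced with the standard library as below. The helper name md5_of_file is
# hypothetical and only illustrative.
import hashlib

def md5_of_file(path, chunk_size=65535):
    """Compute the md5 hex digest of a file, reading it in chunks."""
    md5 = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()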
def _build_token2idx_from_bert(self):
    dict_path = os.path.join(self.model_folder, 'vocab.txt')
    if not os.path.exists(dict_path):
        # No local vocab found: fall back to downloading a pre-trained BERT checkpoint.
        model_name = self.model_key_map.get(self.model_folder, 'chinese_L-12_H-768_A-12')
        url = self.pre_trained_models.get(model_name)
        get_file(
            model_name + ".zip", url, extract=True,
            cache_dir=text2vec.USER_DATA_DIR,
            cache_subdir=text2vec.USER_DATA_DIR,
            verbose=1
        )
        self.model_folder = os.path.join(text2vec.USER_DATA_DIR, model_name)
        dict_path = os.path.join(self.model_folder, 'vocab.txt')
    logger.debug(f'load vocab.txt from {dict_path}')

    # Build the token-to-index mapping in vocab file order.
    token2idx = {}
    with codecs.open(dict_path, 'r', encoding='utf-8') as f:
        for line in f:
            token = line.strip()
            token2idx[token] = len(token2idx)

    self.bert_token2idx = token2idx
    self.tokenizer = keras_bert.Tokenizer(token2idx)
    self.processor.token2idx = self.bert_token2idx
    self.processor.idx2token = {value: key for key, value in token2idx.items()}
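# Hedged usage sketch (not part of the original module): once the method above has built
# self.tokenizer, encoding a sentence should look roughly like the toy example below.
# The tiny vocab and sample text are illustrative assumptions; keras-bert must be installed.
def _demo_bert_tokenizer():
    import keras_bert
    toy_vocab = {'[PAD]': 0, '[UNK]': 1, '[CLS]': 2, '[SEP]': 3, 'hello': 4, 'world': 5}
    tokenizer = keras_bert.Tokenizer(toy_vocab)
    # tokenize() wraps the text with [CLS]/[SEP]; encode() returns (token ids, segment ids).
    print(tokenizer.tokenize('hello world'))  # e.g. ['[CLS]', 'hello', 'world', '[SEP]']
    print(tokenizer.encode('hello world'))    # e.g. ([2, 4, 5, 3], [0, 0, 0, 0])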
def _build_token2idx_from_w2v(self):
    if not self.w2v_path or not os.path.exists(self.w2v_path):
        # Resolve a model name (or a missing path) to a downloadable pre-trained word2vec file.
        if self.w2v_path in self.model_key_map:
            self.w2v_path = self.model_key_map[self.w2v_path]
        model_dict = self.model_key_map.get(
            self.w2v_path, self.model_key_map['w2v-light-tencent-chinese'])
        tar_filename = model_dict.get('tar_filename')
        self.w2v_kwargs = {'binary': model_dict.get('binary')}
        url = model_dict.get('url')
        untar_filename = model_dict.get('untar_filename')
        self.w2v_path = os.path.join(text2vec.USER_DATA_DIR, untar_filename)
        if not os.path.exists(self.w2v_path):
            get_file(
                tar_filename, url, extract=True,
                cache_dir=text2vec.USER_DIR,
                cache_subdir=text2vec.USER_DATA_DIR,
                verbose=1
            )

    t0 = time.time()
    w2v = KeyedVectors.load_word2vec_format(self.w2v_path, **self.w2v_kwargs)
    # w2v.init_sims(replace=True)
    logger.debug('load w2v from %s, spend %s s' % (self.w2v_path, time.time() - t0))

    # Reserve the first four indices for special tokens, then append the word2vec vocabulary.
    token2idx = {
        self.processor.token_pad: 0,
        self.processor.token_unk: 1,
        self.processor.token_bos: 2,
        self.processor.token_eos: 3
    }
    for token in w2v.key_to_index:
        token2idx[token] = len(token2idx)

    # Embedding matrix: zero rows for special tokens, a random vector for <unk>,
    # and the pre-trained vectors for the remaining entries.
    vector_matrix = np.zeros((len(token2idx), w2v.vector_size))
    vector_matrix[1] = np.random.rand(w2v.vector_size)
    vector_matrix[4:] = w2v.vectors

    self.embedding_size = w2v.vector_size
    self.w2v_vector_matrix = vector_matrix
    self.w2v_token2idx = token2idx
    self.w2v_top_words = w2v.index_to_key[:50]
    self.w2v_model_loaded = True
    self.w2v = w2v

    self.processor.token2idx = self.w2v_token2idx
    self.processor.idx2token = {value: key for key, value in self.w2v_token2idx.items()}

    logger.debug('word count: {}'.format(len(self.w2v_vector_matrix)))
    logger.debug('emb size: {}'.format(self.embedding_size))
    logger.debug('top 50 words: {}'.format(self.w2v_top_words))
    logger.debug('filter stopwords: {}, count: {}'.format(
        sorted(list(self.stopwords))[:10], len(self.stopwords)))
    self.tokenizer = Tokenizer()
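# Hedged sketch (an assumption, not part of the original class): once token2idx and
# vector_matrix are built by the method above, a sentence embedding can be approximated
# by mean-pooling the matrix rows of its tokens, falling back to the <unk> row (index 1)
# for out-of-vocabulary tokens. The helper name and the toy data below are hypothetical.
import numpy as np

def average_sentence_vector(tokens, token2idx, vector_matrix, unk_index=1):
    """Mean-pool the embedding rows for the given tokens."""
    indices = [token2idx.get(token, unk_index) for token in tokens]
    return vector_matrix[indices].mean(axis=0)

# Toy demonstration with a random 4 x 8 matrix standing in for self.w2v_vector_matrix.
_toy_token2idx = {'<pad>': 0, '<unk>': 1, '我': 2, '喜欢': 3}
_toy_matrix = np.random.rand(4, 8)
print(average_sentence_vector(['我', '喜欢', '北京'], _toy_token2idx, _toy_matrix).shape)  # (8,)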
def test_bin_file():
    _url = 'https://www.borntowin.cn/mm/emb_models/sentence_w2v.bin'
    file_path = get_file(
        'sentence_w2v.bin', _url, extract=True,
        cache_dir=text2vec.USER_DIR,
        cache_subdir=text2vec.USER_DATA_DIR
    )
    print("file_path:", file_path)
    # Clean up the downloaded artifact; a plain .bin file must be removed with os.remove,
    # while shutil.rmtree only works on directories.
    if os.path.isfile(file_path):
        os.remove(file_path)
    elif os.path.isdir(file_path):
        shutil.rmtree(file_path)
def test_get_zip_file():
    _url = "https://raw.githubusercontent.com/pengming617/bert_textMatching/master/data/train.txt"
    file_path = get_file(
        'train.txt', _url, extract=True,
        cache_dir='./',
        cache_subdir='./',
        verbose=1
    )
    print("file_path:", file_path)
    # Clean up: train.txt is a plain file, so remove it rather than a directory tree.
    if os.path.isfile(file_path):
        os.remove(file_path)
    elif os.path.isdir(file_path):
        shutil.rmtree(file_path)