class TextMatch(object):
    """Match input text against known texts.

    Two strategies are provided:

    * ``inference``: encode the text with ``TextMatchModel`` and look up the
      nearest training text through a remote faiss gRPC service.
    * ``inference_for_scene_with_glove``: compare the text with a
      caller-supplied list of scene texts using tf-idf-weighted GloVe char
      embeddings and cosine similarity.
    """

    def __init__(self, ip_port, if_write=False):
        """
        :param ip_port: target address of the faiss gRPC service (passed to
            ``grpc.insecure_channel``).
        :param if_write: currently unused; the faiss index is not built here.
        """
        self.data = Data(ori_texts=[], labels=[], if_train=False)
        with open(dset_path, 'rb') as rbf:
            # pickle.load reads objects sequentially: this order must match
            # the order in which the dataset file was written.
            self.data.char_alphabet.instance2index = pickle.load(rbf)
            self.data.word_alphabet.instance2index = pickle.load(rbf)
            self.data.label_alphabet.instance2index = pickle.load(rbf)
            self.data.char_alphabet_size = pickle.load(rbf)
            self.data.word_alphabet_size = pickle.load(rbf)
            self.data.label_alphabet_size = pickle.load(rbf)
            self.data.label_alphabet.instances = pickle.load(rbf)
            self.data.train_texts = pickle.load(rbf)
        self.data.fix_alphabet()

        self.model = TextMatchModel(self.data)
        self.model.load_state_dict(
            torch.load(model_dir,
                       map_location=self.model.configs['map_location']))
        self.model.eval()
        self.model.to(self.model.configs['device'])

        # Kept for backward compatibility; not filled in by this class.
        self.train_represents = np.zeros(shape=())
        self.train_texts = self.data.train_texts

        # Load the synonym dictionary. YAML files are UTF-8 by spec, so do
        # not depend on the platform's default encoding.
        # NOTE(review): FullLoader can build some Python objects; only load
        # trusted files here.
        with open(synonyms_path, 'r', encoding='utf-8') as rf:
            self.synonyms = yaml.load(rf.read(), Loader=yaml.FullLoader)

        # Initialisation for scene matching: char embeddings from GloVe.
        self.scene_texts = []
        self.pretrain_char_embedding, self.char_emb_dim = build_pretrain_embedding(
            glove_emb_path, self.data.char_alphabet)
        self.char_embedding = nn.Embedding(self.data.char_alphabet_size,
                                           self.char_emb_dim)
        self.char_embedding.weight.data.copy_(
            torch.from_numpy(self.pretrain_char_embedding))
        if self.model.configs['gpu']:
            self.char_embedding = self.char_embedding.cuda()

        # tf-idf (char weight) model.
        self.vectorizer = CharWeight.load_model(char_weight_model_dir)

        # faiss service client.
        # The remote index was previously built by a (now removed, commented
        # out) write_index: IndexIVFFlat with METRIC_INNER_PRODUCT over
        # L2-normalised training represents, ids = row positions 0..n-1.
        self.channel = grpc.insecure_channel(ip_port)
        self.stub = FaissServerStub(self.channel)

    def close(self):
        """Close the gRPC channel to the faiss service."""
        self.channel.close()

    def _build_ids(self, seg_list):
        """Map segmented words to ``[word_ids, char_ids_per_word]``.

        Words recognised by ``data.specific_word`` are kept whole as a single
        char unit; other words are split into their characters.
        """
        word_id_list, char_id_list = [], []
        for word in seg_list:
            word_id_list.append(
                self.data.word_alphabet.get_index(normalize_word(word)))
            units = [word] if self.data.specific_word(word) else list(word)
            char_id_list.append([
                self.data.char_alphabet.get_index(normalize_word(unit))
                for unit in units
            ])
        return [word_id_list, char_id_list]

    @staticmethod
    def _l2_normalize(matrix):
        """Row-wise L2 normalisation (numpy equivalent of faiss.normalize_L2).

        Zero rows are left as zeros instead of producing NaNs.
        """
        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        norms[norms == 0] = 1
        return matrix / norms

    @staticmethod
    def _search_succeeded(D, I):
        """Return True when ``search`` produced a usable top hit.

        ``search`` returns ``(-1, -1)`` on failure. Faiss pads missing
        neighbours with id -1. Id 0 is a valid row index, so ids ``>= 0``
        are accepted. A non-positive top score is treated as "no match",
        as before.
        """
        if not isinstance(I, np.ndarray) or I.size == 0:
            return False
        return bool(I[0][0] >= 0 and D > 0)

    def inference(self, text):
        """Find the training text most similar to ``text`` via faiss.

        :param text: raw input text.
        :return: ``(score, matched_text, label)``. Returns ``(1, None, None)``
            when the text segments to nothing, and ``(0, None, None)`` when
            the faiss call fails or yields no usable hit. A stored label of
            ``'None'`` is returned as ``None``.
        """
        seg_list = self.data.segment([text])[0]
        seg_list = self.synonyms_replace(seg_list)  # synonym replacement
        if len(seg_list) == 0:
            return 1, None, None

        ids = [self._build_ids(seg_list)]
        batch_word, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, mask = \
            predict_batchfy_classification_with_label(
                ids, self.model.configs['gpu'], if_train=False)
        with torch.no_grad():
            pred_represent = self.model(batch_word, batch_wordlen, batch_char,
                                        batch_charlen, batch_charrecover, mask)
        # .numpy() only works on CPU tensors, so move the result off the GPU first.
        pred_represent = pred_represent.detach().cpu().numpy()
        # Normalise so the inner-product index yields cosine similarity.
        query = self._l2_normalize(pred_represent)[0].tolist()

        faiss_start = datetime.datetime.now()
        D, I = self.search(self.stub, query)
        logger.info('Faiss search costs: %s',
                    (datetime.datetime.now() - faiss_start).total_seconds())

        if not self._search_succeeded(D, I):
            # faiss call failed: return the default score and label
            return 0, None, None

        max_similar_text = self.train_texts[I[0][0]]
        pred_text = ''.join(max_similar_text[0])
        pred_label = max_similar_text[-1]
        if pred_label == 'None':
            pred_label = None
        return D, pred_text, pred_label

    def synonyms_replace(self, seg_list):
        """Replace each word with its dictionary synonym, if one exists.

        An empty synonym file (loaded as ``None``) leaves words unchanged.
        """
        synonyms = self.synonyms or {}
        return [synonyms.get(word, word) for word in seg_list]

    def inference_for_scene_with_glove(self, text, text_list, label_list):
        """Match ``text`` against ``text_list`` with weighted GloVe embeddings.

        :param text: raw input text.
        :param text_list: candidate scene texts.
        :param label_list: labels aligned with ``text_list``.
        :return: ``(confidence, closest_text, closest_label)``. Returns
            ``(1, None, None)`` for an empty ``text``. A label of ``'None'``
            is returned as ``None``.
        """
        # Preprocess the scene texts and build their weighted representations.
        scene_chars, scene_ids = self.data.read_scene_text_list(text_list)
        sen_weights = self.cal_char_weight(scene_chars, scene_ids)
        scene_represents = self.cal_scene_represents(scene_ids, sen_weights)

        # Same pipeline for the input text, character by character.
        chars = list(text)
        if len(chars) == 0:
            return 1, None, None
        ids = [self.data.char_alphabet.get_index(normalize_word(char))
               for char in chars]
        input_weights = self.cal_char_weight([chars], [ids])
        pred_represent = self.cal_scene_represents([ids], input_weights)

        max_score, pred_text, pred_label = self.cal_similarity(
            pred_represent, scene_represents, text_list, label_list)
        if pred_label == 'None':
            pred_label = None
        return max_score, pred_text, pred_label

    def cal_char_weight(self, chars, ids):
        """Compute a tf-idf weight for every char of every sentence, in order.

        Before weighting, OOV chars become the token ``'unk'``, which takes its
        weight from column 0. The chars '场' and '模' are mapped to '打', and
        '景' and '式' are mapped to '开'.

        :param chars: list of char lists, one per sentence.
        :param ids: matching list of char-alphabet id lists.
        :return: list of weight lists aligned with the inputs.
        """
        unk_id = self.data.char_alphabet.get_index(
            self.data.char_alphabet.UNKNOWN)
        da_id = self.vectorizer.vocabulary_['打']
        kai_id = self.vectorizer.vocabulary_['开']

        new_chars, new_ids = [], []
        for sen_char, sen_id in zip(chars, ids):
            tokens, token_ids = [], []
            for char, cid in zip(sen_char, sen_id):
                if cid == unk_id:
                    tokens.append('unk')
                    token_ids.append(cid)
                elif char in ('场', '模'):
                    tokens.append('打')
                    token_ids.append(da_id)
                elif char in ('景', '式'):
                    tokens.append('开')
                    token_ids.append(kai_id)
                else:
                    tokens.append(char)
                    token_ids.append(cid)
            new_chars.append(' '.join(tokens).strip())
            new_ids.append(token_ids)

        # Run the weight model only once for all sentences.
        tf_idf_output = self.vectorizer.transform(new_chars)
        # NOTE(review): non-OOV chars index tf-idf columns with char-alphabet
        # ids, which assumes the two vocabularies are aligned — confirm.
        return [
            [tf_idf_output[row, 0] if cid == unk_id else tf_idf_output[row, cid]
             for cid in token_ids]
            for row, token_ids in enumerate(new_ids)
        ]

    def cal_scene_represents(self, scene_ids, sen_weights):
        """Build one sentence vector per sentence.

        Each vector is the mean of that sentence's char embeddings, with every
        embedding scaled by its weight.

        :param scene_ids: list of char-id lists.
        :param sen_weights: matching list of weight lists.
        :return: one vector (list of floats) per input sentence, in input
            order. An empty sentence yields a zero vector.
        """
        device = self.model.configs['device']
        scene_represents = []
        for char_ids, sen_weight in zip(scene_ids, sen_weights):
            char_input = torch.tensor(char_ids, dtype=torch.long).to(device)
            with torch.no_grad():
                char_embedding = self.char_embedding(char_input)
                n = char_embedding.shape[0]
                if n != len(sen_weight):
                    # Never drop a sentence: that would misalign the vectors
                    # with text_list in cal_similarity. Use the common length.
                    logger.info('check length of sen_weight')
                    n = min(n, len(sen_weight))
                if n == 0:
                    sentence_embedding = torch.zeros(char_embedding.shape[-1])
                else:
                    weights = torch.tensor(sen_weight[:n], dtype=torch.float,
                                           device=char_embedding.device)
                    sentence_embedding = (
                        char_embedding[:n] * weights.unsqueeze(1)).mean(dim=0)
            scene_represents.append(sentence_embedding.cpu().tolist())
        return scene_represents

    def cal_similarity(self, pred_represent, train_represents, text_list,
                       label_list):
        """Return ``(max_score, text, label)`` of the most cosine-similar row
        of ``train_represents`` to the first row of ``pred_represent``."""
        score = cosine_similarity(pred_represent, train_represents)
        max_id = int(np.argmax(score, axis=-1)[0])
        max_score = np.max(score, axis=-1)[0]
        return max_score, text_list[max_id], label_list[max_id]

    @staticmethod
    def search(stub, input_vector, topn=1, index_name='text_match'):
        """Query the faiss gRPC service.

        :param stub: ``FaissServerStub``.
        :param input_vector: query vector (iterable of floats).
        :param topn: number of neighbours requested.
        :param index_name: name of the server-side index.
        :return: ``(D, I)``. ``D`` is the top score (``D[0][0]``) and ``I``
            is the 2-D numpy array of ids. Returns ``(-1, -1)`` on any failure
            or empty result.
        """
        float_vector = FloatVector()
        float_vector.fvec.extend(input_vector)
        search_request = SearchRequest(index_name=index_name,
                                       vector=float_vector, topn=topn)
        try:
            response = stub.search(search_request, timeout=0.1)
        except Exception:
            # gRPC boundary: log the traceback and fall back to defaults.
            logger.error("Respose failed: {}".format(traceback.format_exc()))
            return -1, -1

        if not response.success:
            # faiss call failed: return the default values
            logger.error("Faiss server failed, Return default value.")
            return -1, -1

        D = np.array([list(row.fvec) for row in response.D.fmatrix])
        I = np.array([list(row.ivec) for row in response.I.imatrix])
        if D.size == 0 or I.size == 0:
            logger.error("Faiss server failed, Return default value.")
            return -1, -1
        return D[0][0], I
class SceneMatch(object):
    """Match an input text against a caller-supplied list of scene texts.

    Both the scene texts and the input are encoded with ``TextMatchModel``
    and compared by cosine similarity.
    """

    def __init__(self):
        self.data = Data(ori_texts=[], labels=[], if_train=False)
        with open(dset_path, 'rb') as rbf:
            # pickle.load reads objects sequentially: this order must match
            # the order in which the dataset file was written.
            self.data.char_alphabet.instance2index = pickle.load(rbf)
            self.data.word_alphabet.instance2index = pickle.load(rbf)
            self.data.label_alphabet.instance2index = pickle.load(rbf)
            self.data.char_alphabet_size = pickle.load(rbf)
            self.data.word_alphabet_size = pickle.load(rbf)
            self.data.label_alphabet_size = pickle.load(rbf)
            self.data.label_alphabet.instances = pickle.load(rbf)
        self.data.fix_alphabet()

        self.model = TextMatchModel(self.data)
        self.model.load_state_dict(
            torch.load(model_dir,
                       map_location=self.model.configs['map_location']))
        self.model.eval()
        self.model.to(self.model.configs['device'])

    def _build_ids(self, seg_list):
        """Map segmented words to ``[word_ids, char_ids_per_word]``.

        Words recognised by ``data.specific_word`` are kept whole as a single
        char unit; other words are split into their characters.
        """
        word_id_list, char_id_list = [], []
        for word in seg_list:
            word_id_list.append(
                self.data.word_alphabet.get_index(normalize_word(word)))
            units = [word] if self.data.specific_word(word) else list(word)
            char_id_list.append([
                self.data.char_alphabet.get_index(normalize_word(unit))
                for unit in units
            ])
        return [word_id_list, char_id_list]

    def inference(self, text, text_list, label_list):
        """Find the scene text most similar to ``text``.

        :param text: raw user input.
        :param text_list: candidate scene texts.
        :param label_list: labels aligned with ``text_list``.
        :return: ``(confidence, closest_text, closest_label)``. Returns
            ``(None, None, None)`` when there are no candidates or the input
            segments to nothing. A label of ``'None'`` is returned as ``None``.
        """
        if not text_list:
            # Nothing to compare against; avoid failing deep in similarity code.
            return None, None, None

        # Encode the scene texts. This also updates self.data / self.scene_texts,
        # and cal_similarity reads self.scene_texts afterwards.
        scene_texts, scene_ids = self.data.read_scene_text_list(text_list,
                                                                label_list)
        self.data.scene_texts, self.data.scene_ids = scene_texts, scene_ids
        self.scene_texts, scene_represents, _ = get_represents(
            self.data, self.model, 'scene', self.model.configs)

        # Encode the current user input text.
        seg_list = self.data.segment([text])[0]
        if len(seg_list) == 0:
            return None, None, None
        ids = [self._build_ids(seg_list)]
        batch_word, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, mask = \
            predict_batchfy_classification_with_label(
                ids, self.model.configs['gpu'], if_train=False)
        with torch.no_grad():
            pred_represent = self.model(batch_word, batch_wordlen, batch_char,
                                        batch_charlen, batch_charrecover, mask)

        max_score, max_similar_text = self.cal_similarity(pred_represent,
                                                          scene_represents)
        pred_text = ''.join(max_similar_text[0])
        pred_label = max_similar_text[-1]
        if pred_label == 'None':
            pred_label = None
        return max_score, pred_text, pred_label

    def cal_similarity(self, pred_represent, train_represents):
        """Return ``(max_score, scene_text)`` of the most cosine-similar scene
        representation to the first row of ``pred_represent`` (a tensor)."""
        pred_represent = pred_represent.detach().cpu().numpy()
        score = cosine_similarity(pred_represent, train_represents)
        max_id = int(np.argmax(score, axis=-1)[0])
        max_score = np.max(score, axis=-1)[0]
        return max_score, self.scene_texts[max_id]