def write_represents_to_pkl(path, output_path, name='train'):
    data = Data(ori_texts=[], labels=[], if_train=False)
    with open(path, 'rb') as rbf:
        data.char_alphabet.instance2index = pickle.load(rbf)
        data.word_alphabet.instance2index = pickle.load(rbf)
        data.label_alphabet.instance2index = pickle.load(rbf)
        data.char_alphabet_size = pickle.load(rbf)
        data.word_alphabet_size = pickle.load(rbf)
        data.label_alphabet_size = pickle.load(rbf)
        data.label_alphabet.instances = pickle.load(rbf)
        data.train_texts = pickle.load(rbf)
        data.train_ids = pickle.load(rbf)
    data.fix_alphabet()
    model = TextMatchModel(data)
    model.load_state_dict(
        torch.load(model_dir, map_location=model.configs['map_location']))
    model.eval()
    model.to(model.configs['device'])
    train_texts, train_represents, train_label_ids = get_represents(
        data, model, name, model.configs)
    # Write out the representations.
    # with open(path, 'ab') as abf:
    #     pickle.dump(train_texts, abf)
    #     pickle.dump(train_represents, abf)
    #     pickle.dump(train_label_ids, abf)
    with open(output_path, 'wb') as wbf:
        pickle.dump(train_represents, wbf)
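# Example usage (a minimal sketch; `dset_path` and `represent_path` are the
# module-level paths used elsewhere in this file, and pairing them here is an
# assumption, not a verified call from the repo):
# write_represents_to_pkl(dset_path, represent_path, name='train')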
class CharWeight(object):
    def __init__(self, if_train=False):
        if if_train:
            self.data = Data(ori_texts=[], labels=[], if_train=False)
            with open(dset_path, 'rb') as rbf:
                self.data.char_alphabet.instance2index = pickle.load(rbf)
                _ = pickle.load(rbf)
                _ = pickle.load(rbf)
                _ = pickle.load(rbf)
                _ = pickle.load(rbf)
                _ = pickle.load(rbf)
                _ = pickle.load(rbf)
                self.data.train_texts = pickle.load(rbf)
            # Add an 'unk' entry to the vocabulary to handle OOV characters.
            if self.data.char_alphabet.get_instance(0) is None:
                self.data.char_alphabet.instance2index['unk'] = 0

    def train(self):
        train_texts = []
        for train_text in self.data.train_texts:
            intent = train_text[-1]
            # if intent:
            if intent in ['turn_on', 'turn_off']:
                text = ''
                words = train_text[0]
                for word in words:
                    if not self.data.specific_word(word):
                        text += word
                # Split into space-separated characters for char-level TF-IDF.
                text = ' '.join(text)
                train_texts.append(text)
        vectorizer = TfidfVectorizer(
            token_pattern=r'(?u)\b\w+\b',
            vocabulary=self.data.char_alphabet.instance2index)
        vectorizer.fit(train_texts)
        # Persist the fitted weight model.
        joblib.dump(vectorizer, model_path)

    @classmethod
    def load_model(cls, model_dir):
        vectorizer = joblib.load(model_dir)
        return vectorizer
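# Example usage (a minimal sketch; it assumes the module-level `dset_path`,
# `model_path`, and `char_weight_model_dir` are configured as elsewhere in
# this file):
# char_weight = CharWeight(if_train=True)
# char_weight.train()  # fits the char-level TF-IDF model and dumps it to model_path
# vectorizer = CharWeight.load_model(char_weight_model_dir)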
class TextMatch(object):
    def __init__(self, ip_port, if_write=False):
        """
        :param ip_port: faiss url
        :param if_write:
        """
        self.data = Data(ori_texts=[], labels=[], if_train=False)
        with open(dset_path, 'rb') as rbf:
            self.data.char_alphabet.instance2index = pickle.load(rbf)
            self.data.word_alphabet.instance2index = pickle.load(rbf)
            self.data.label_alphabet.instance2index = pickle.load(rbf)
            self.data.char_alphabet_size = pickle.load(rbf)
            self.data.word_alphabet_size = pickle.load(rbf)
            self.data.label_alphabet_size = pickle.load(rbf)
            self.data.label_alphabet.instances = pickle.load(rbf)
            self.data.train_texts = pickle.load(rbf)
        self.data.fix_alphabet()
        self.model = TextMatchModel(self.data)
        self.model.load_state_dict(
            torch.load(model_dir,
                       map_location=self.model.configs['map_location']))
        self.model.eval()
        self.model.to(self.model.configs['device'])
        self.train_represents = np.zeros(shape=())
        self.train_texts = self.data.train_texts
        # Load the synonym dictionary.
        with open(synonyms_path, 'r') as rf:
            self.synonyms = yaml.load(rf.read(), Loader=yaml.FullLoader)
        # Initialization for scene matching.
        self.scene_texts = []
        self.pretrain_char_embedding, self.char_emb_dim = build_pretrain_embedding(
            glove_emb_path, self.data.char_alphabet)
        self.char_embedding = nn.Embedding(self.data.char_alphabet_size,
                                           self.char_emb_dim)
        self.char_embedding.weight.data.copy_(
            torch.from_numpy(self.pretrain_char_embedding))
        if self.model.configs['gpu']:
            self.char_embedding = self.char_embedding.cuda()
        # if if_write:
        #     self.write_index()
        # Initialize the TF-IDF model.
        self.vectorizer = CharWeight.load_model(char_weight_model_dir)
        # faiss service
        self.channel = grpc.insecure_channel(ip_port)
        self.stub = FaissServerStub(self.channel)

    def close(self):
        self.channel.close()

    # def write_index(self):
    #     """
    #     Cosine similarity and L2 distance are equivalent on normalized vectors:
    #     https://www.zhihu.com/question/19640394
    #     :return:
    #     """
    #     import faiss
    #     with open(represent_path, 'rb') as rbf:
    #         self.train_represents = pickle.load(rbf)
    #     self.train_represents = np.array(self.train_represents).astype('float32')
    #     d = self.model.configs['num_output']
    #     nlist = self.data.label_alphabet_size - 1
    #     # L2 distance variant:
    #     # quantizer = faiss.IndexFlatL2(d)
    #     # index = faiss.IndexIVFFlat(quantizer, d, nlist)
    #     # Cosine (inner-product) variant:
    #     quantizer = faiss.IndexFlatIP(d)
    #     index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_INNER_PRODUCT)
    #     faiss.normalize_L2(self.train_represents)  # normalize
    #     # print(index.is_trained)
    #     index.train(self.train_represents)
    #     index.add_with_ids(self.train_represents, np.arange(self.train_represents.shape[0]))
    #     index.nprobe = 10
    #     faiss.write_index(index, index_path)

    def inference(self, text):
        """
        :param text:
        :return:
        """
        texts, ids = [], []
        seg_list = self.data.segment([text])[0]
        seg_list = self.synonyms_replace(seg_list)  # synonym replacement
        # print('text: %s, seg_list: %s' % (text, seg_list))
        if len(seg_list) == 0:
            return 1, None, None
        char_list, char_id_list, word_id_list = [], [], []
        for word in seg_list:
            word_id = self.data.word_alphabet.get_index(normalize_word(word))
            word_id_list.append(word_id)
            chars, char_ids = [], []
            if self.data.specific_word(word):
                chars.append(word)
                char_ids.append(
                    self.data.char_alphabet.get_index(normalize_word(word)))
            else:
                for char in word:
                    chars.append(char)
                    char_ids.append(
                        self.data.char_alphabet.get_index(
                            normalize_word(char)))
            char_list.append(chars)
            char_id_list.append(char_ids)
        texts.append([seg_list, char_list])
        ids.append([word_id_list, char_id_list])
        batch_word, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, mask = \
            predict_batchfy_classification_with_label(
                ids, self.model.configs['gpu'], if_train=False)
        pred_represent = self.model(batch_word, batch_wordlen, batch_char,
                                    batch_charlen, batch_charrecover, mask)
        pred_represent = pred_represent.cpu().data.numpy()
        # ori_pred_represent = pred_represent
        # faiss.normalize_L2(pred_represent)
        # Re-implement faiss.normalize_L2 with numpy.
        pred_represent = pred_represent / np.linalg.norm(pred_represent, ord=2)
        pred_represent = pred_represent.tolist()[0]
        faiss_start = datetime.datetime.now()
        D, I = self.search(self.stub, pred_represent)
        logger.info('Faiss search costs: %s' %
                    (datetime.datetime.now() - faiss_start).total_seconds())
        # With topn=1, D is a scalar and I is a single-element array, so the
        # truth tests below are well defined.
        if D > 0 and I > 0:
            max_id = I[0][0]
            max_score = D
            max_similar_text = self.train_texts[max_id]
            pred_text = ''.join(max_similar_text[0])
            pred_label = max_similar_text[-1]
            if pred_label == 'None':
                pred_label = None
            return max_score, pred_text, pred_label
        else:
            # If the faiss call failed, return a default score and label.
            return 0, None, None

    def synonyms_replace(self, seg_list):
        result_list = []
        for word in seg_list:
            if word in self.synonyms.keys():
                synonym = self.synonyms[word]
                result_list.append(synonym)
            else:
                result_list.append(word)
        return result_list

    def inference_for_scene_with_glove(self, text, text_list, label_list):
        # Preprocess scene_texts.
        scene_chars, scene_ids = self.data.read_scene_text_list(text_list)
        # Compute character weights.
        # s = datetime.datetime.now()
        sen_weights = self.cal_char_weight(scene_chars, scene_ids)
        # print('cal_char_weight costs: %s' % (datetime.datetime.now() - s).total_seconds())
        # Compute sentence representations under those weights.
        scene_represents = self.cal_scene_represents(scene_ids, sen_weights)
        # Process the current input text.
        chars, ids = [], []
        for char in text:
            chars.append(char)
            ids.append(self.data.char_alphabet.get_index(normalize_word(char)))
        if len(chars) == 0:
            return 1, None, None
        input_weights = self.cal_char_weight([chars], [ids])
        pred_represent = self.cal_scene_represents([ids], input_weights)
        max_score, pred_text, pred_label = self.cal_similarity(
            pred_represent, scene_represents, text_list, label_list)
        if pred_label == 'None':
            pred_label = None
        # Confidence, closest text, closest label.
        return max_score, pred_text, pred_label

    # Compute the weight of every character in scene_text (preserving order).
    def cal_char_weight(self, chars, ids):
        new_chars, new_ids, sen_weights = [], [], []
        alphabet_unknow_id = self.data.char_alphabet.get_index(
            self.data.char_alphabet.UNKNOWN)
        vectorizer_da_id = self.vectorizer.vocabulary_['打']
        vectorizer_kai_id = self.vectorizer.vocabulary_['开']
        for sen_char, sen_id in zip(chars, ids):
            new_char = ' '
            new_id = []
            for char, id in zip(sen_char, sen_id):
                # Replace OOV characters with unk.
                if id == alphabet_unknow_id:
                    new_char += ' unk'
                    new_id.append(id)
                # Map '场'/'模' -> '打' and '景'/'式' -> '开'.
                elif char in ['场', '模']:
                    new_char += ' 打'
                    new_id.append(vectorizer_da_id)
                elif char in ['景', '式']:
                    new_char += ' 开'
                    new_id.append(vectorizer_kai_id)
                else:
                    new_char = new_char + ' ' + char
                    new_id.append(id)
            new_char = new_char.strip()
            new_ids.append(new_id)
            new_chars.append(new_char)
        # Run the weight model only once for the whole batch.
        tf_idf_output = self.vectorizer.transform(new_chars)
        for i, sen_id in enumerate(new_ids):
            sen_weight = []
            for id in sen_id:
                if id == alphabet_unknow_id:
                    sen_weight.append(tf_idf_output[i, 0])
                else:
                    sen_weight.append(tf_idf_output[i, id])
            sen_weights.append(sen_weight)
        return sen_weights

    # Compute sentence representations from the char weights.
    def cal_scene_represents(self, scene_ids, sen_weights):
        scene_represents = []
        for char_id, sen_weight in zip(scene_ids, sen_weights):
            char_input = torch.tensor(char_id, dtype=torch.long).to(
                self.model.configs['device'])
            char_embedding = self.char_embedding(char_input)
            try:
                assert char_embedding.shape[0] == len(sen_weight)
            except AssertionError:
                logger.info('check length of sen_weight')
            else:
                new_char_embedding = []
                for ce, sw in zip(char_embedding, sen_weight):
                    new_char_embedding.append(ce.cpu().data.numpy() * sw)
                new_char_embedding = torch.tensor(
                    new_char_embedding,
                    dtype=torch.float).to(self.model.configs['device'])
                sentence_embedding = torch.mean(new_char_embedding, dim=0)
                sentence_embedding = sentence_embedding.cpu().data.numpy().tolist()
                scene_represents.append(sentence_embedding)
        return scene_represents

    def cal_similarity(self, pred_represent, train_represents, text_list,
                       label_list):
        score = cosine_similarity(pred_represent, train_represents)
        max_id, max_score = np.argmax(score, axis=-1)[0], np.max(score,
                                                                 axis=-1)[0]
        max_similar_text, max_similar_label = text_list[max_id], label_list[
            max_id]
        return max_score, max_similar_text, max_similar_label

    @staticmethod
    def search(stub, input_vector, topn=1, index_name='text_match'):
        float_vector = FloatVector()
        for i in input_vector:
            float_vector.fvec.append(i)
        search_request = SearchRequest(index_name=index_name,
                                       vector=float_vector,
                                       topn=topn)
        try:
            response = stub.search(search_request, timeout=0.1)
        # faiss service error:
        except Exception as exc:
            # format_exc records the full traceback in the log.
            logger.error("Response failed: {}".format(traceback.format_exc()))
            return -1, -1
        if response.success:
            _D = response.D
            D = []
            for v in _D.fmatrix:
                d = []
                for e in v.fvec:
                    d.append(e)
                D.append(d)
            D = np.array(D)
            _I = response.I
            I = []
            for v in _I.imatrix:
                i = []
                for e in v.ivec:
                    i.append(e)
                I.append(i)
            I = np.array(I)
            if isinstance(D, np.ndarray):
                D = D[0][0]
            return D, I
        else:
            # If the faiss call failed, return default values.
            logger.error("Faiss server failed, Return default value.")
            return -1, -1
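# Example usage (a minimal sketch; the gRPC address and the query utterance
# are hypothetical, chosen to fit the turn_on/turn_off domain above):
# text_match = TextMatch(ip_port='localhost:50051')
# score, similar_text, label = text_match.inference('打开空调')
# text_match.close()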
class SceneMatch(object):
    def __init__(self):
        self.data = Data(ori_texts=[], labels=[], if_train=False)
        with open(dset_path, 'rb') as rbf:
            self.data.char_alphabet.instance2index = pickle.load(rbf)
            self.data.word_alphabet.instance2index = pickle.load(rbf)
            self.data.label_alphabet.instance2index = pickle.load(rbf)
            self.data.char_alphabet_size = pickle.load(rbf)
            self.data.word_alphabet_size = pickle.load(rbf)
            self.data.label_alphabet_size = pickle.load(rbf)
            self.data.label_alphabet.instances = pickle.load(rbf)
        self.data.fix_alphabet()
        self.model = TextMatchModel(self.data)
        self.model.load_state_dict(
            torch.load(model_dir,
                       map_location=self.model.configs['map_location']))
        self.model.eval()
        self.model.to(self.model.configs['device'])
        # self.no_train_texts, self.no_train_represents, self.no_train_label_ids = get_represents(
        #     self.data, self.model, 'add', self.model.configs)

    def inference(self, text, text_list, label_list):
        texts, ids = self.data.read_scene_text_list(text_list, label_list)
        self.data.scene_texts, self.data.scene_ids = texts, ids
        self.scene_texts, scene_represents, scene_label_ids = get_represents(
            self.data, self.model, 'scene', self.model.configs)
        # Process the user's current input text.
        texts, ids = [], []
        seg_list = self.data.segment([text])[0]
        if len(seg_list) == 0:
            return None, None, None
        # print('seg_list: %s' % seg_list)
        char_list, char_id_list, word_id_list = [], [], []
        for word in seg_list:
            word_id = self.data.word_alphabet.get_index(normalize_word(word))
            word_id_list.append(word_id)
            chars, char_ids = [], []
            if self.data.specific_word(word):
                chars.append(word)
                char_ids.append(
                    self.data.char_alphabet.get_index(normalize_word(word)))
            else:
                for char in word:
                    chars.append(char)
                    char_ids.append(
                        self.data.char_alphabet.get_index(
                            normalize_word(char)))
            char_list.append(chars)
            char_id_list.append(char_ids)
        texts.append([seg_list, char_list])
        ids.append([word_id_list, char_id_list])
        batch_word, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, mask = \
            predict_batchfy_classification_with_label(
                ids, self.model.configs['gpu'], if_train=False)
        pred_represent = self.model(batch_word, batch_wordlen, batch_char,
                                    batch_charlen, batch_charrecover, mask)
        max_score, max_similar_text = self.cal_similarity(
            pred_represent, scene_represents)
        pred_text = ''.join(max_similar_text[0])
        pred_label = max_similar_text[-1]
        if pred_label == 'None':
            pred_label = None
        # Confidence, closest text, closest label.
        return max_score, pred_text, pred_label

    def cal_similarity(self, pred_represent, train_represents):
        pred_represent = pred_represent.cpu().data.numpy()
        score = cosine_similarity(pred_represent, train_represents)
        max_id, max_score = np.argmax(score, axis=-1)[0], np.max(score,
                                                                 axis=-1)[0]
        # max_similar_text, max_similar_label_id = self.train_texts[max_id], self.train_label_ids[max_id]
        max_similar_text = self.scene_texts[max_id]
        return max_score, max_similar_text
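# Example usage (a minimal sketch; the scene texts and labels are
# hypothetical, chosen to match the turn_on/turn_off domain used elsewhere):
# scene_match = SceneMatch()
# score, text, label = scene_match.inference(
#     '回家模式', text_list=['回家', '离家'], label_list=['turn_on', 'turn_off'])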
def write_represents_to_txt(path, output_path, name='train'):
    data = Data(ori_texts=[], labels=[], if_train=False)
    with open(path, 'rb') as rbf:
        data.char_alphabet.instance2index = pickle.load(rbf)
        data.word_alphabet.instance2index = pickle.load(rbf)
        data.label_alphabet.instance2index = pickle.load(rbf)
        data.char_alphabet_size = pickle.load(rbf)
        data.word_alphabet_size = pickle.load(rbf)
        data.label_alphabet_size = pickle.load(rbf)
        data.label_alphabet.instances = pickle.load(rbf)
        data.train_texts = pickle.load(rbf)
        data.train_ids = pickle.load(rbf)
    data.fix_alphabet()
    model = TextMatchModel(data)
    model.load_state_dict(
        torch.load(model_dir, map_location=model.configs['map_location']))
    model.eval()
    model.to(model.configs['device'])
    data.no_train_texts, data.no_train_ids = data.read_no_train(no_train_path)
    train_texts, train_represents, train_label_ids = get_represents(
        data, model, name, model.configs)
    if not os.path.exists(output_path + '/train_texts.txt'):
        with open(output_path + '/train_texts.txt', 'w') as wf:
            for item in train_texts:
                wf.write('%s\n' % item)
    with open(output_path + '/train_represents.txt', 'w') as wf:
        for item in train_represents:
            wf.write('%s\n' % item)
    with open(output_path + '/train_label_ids.txt', 'w') as wf:
        for item in train_label_ids:
            wf.write('%s\n' % item)
def train(self):
    data = Data.read_data()
    match_configs = data.match_configs
    model = TextMatchModel(data)
    model_configs = model.configs
    if model_configs['gpu']:
        model = model.cuda()
    batch_size = match_configs['batch_size']
    model_configs.update({'batch_size': batch_size})
    optimizer = optim.Adam(model.parameters(),
                           lr=model_configs['lr'],
                           weight_decay=model_configs['l2'])
    if model_configs['gpu']:
        model = model.cuda()
    best_dev = -10
    last_improved = 0
    logger.info('train start:%s', datetime.datetime.now())
    for idx in range(model_configs['epoch']):
        epoch_start = time.time()
        temp_start = epoch_start
        logger.info('Epoch: %s/%s' % (idx, model_configs['epoch']))
        optimizer = self.lr_decay(optimizer, idx, model_configs['lr_decay'],
                                  model_configs['lr'])
        sample_loss = 0
        total_loss = 0
        # right_token = 0
        # whole_token = 0
        logger.info("first input word _list: %s, %s" %
                    (data.train_texts[0][1], data.train_ids[0][1]))
        model.train()
        model.zero_grad()
        num_classes_per_batch = match_configs['num_classes_per_batch']  # 16
        num_sentences_per_class = batch_size // num_classes_per_batch  # 4
        assert type(num_sentences_per_class) == int
        start = datetime.datetime.now()
        instances, total_batch = self.get_instances(data,
                                                    num_classes_per_batch,
                                                    num_sentences_per_class)
        logger.info('get_instances costs: %s' %
                    (datetime.datetime.now() - start).total_seconds())
        logger.info('total_batch: %s' % total_batch)
        train_num = len(instances)
        for batch_id in range(total_batch):
            start = batch_id * batch_size
            end = (batch_id + 1) * batch_size  # `end` equals train_num on the last batch
            instance = instances[start:end]
            # print(start, end)
            if not instance:
                continue
            # instance -> (word, char, label)
            batch_word, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, mask, batch_label = \
                batchfy_classification_with_label(instance,
                                                  model_configs['gpu'],
                                                  if_train=True)
            loss, sen_represent = model(batch_word, batch_wordlen, batch_char,
                                        batch_charlen, batch_charrecover,
                                        mask, batch_label)
            sample_loss += loss.item()
            total_loss += loss.item()
            # Log progress every 10 batches.
            if end % (batch_size * 10) == 0:
                temp_time = time.time()
                temp_cost = temp_time - temp_start
                temp_start = temp_time
                logger.info("Instance: %s; Time: %.2fs; loss: %.4f" %
                            (end, temp_cost, sample_loss))
                if sample_loss > 1e4 or str(sample_loss) == 'nan':
                    raise ValueError("ERROR: LOSS EXPLOSION (>1e4) !")
                sample_loss = 0
            loss.backward()
            optimizer.step()
            model.zero_grad()
        temp_time = time.time()
        temp_cost = temp_time - temp_start
        logger.info("Instance: %s; Time: %.2fs; loss: %.4f" %
                    (end, temp_cost, sample_loss))
        epoch_finish = time.time()
        epoch_cost = epoch_finish - epoch_start
        logger.info(
            "Epoch: %s training finished. Time: %.2fs, speed: %.2fst/s, total loss: %s"
            % (idx, epoch_cost, train_num / epoch_cost, total_loss))
        if total_loss > 1e4 or str(total_loss) == 'nan':
            raise ValueError("ERROR: LOSS EXPLOSION (>1e4) !")
        writer.add_scalar('Train_loss', total_loss, idx)
        # Compute sentence representations of the training set under the current model.
        _, train_represents, train_labels = get_represents(
            data, model, 'train', model_configs)
        # Evaluate on dev.
        speed, acc, p, r, f, _, _ = evalute(data, model, 'dev', model_configs,
                                            train_represents, train_labels)
        dev_finish = time.time()
        dev_cost = dev_finish - epoch_finish
        current_score = acc
        writer.add_scalar('Dev_acc', current_score, idx)
        writer.add_scalar('Dev_f1', f, idx)
        logger.info(
            "Dev: time: %.2fs speed: %.2fst/s; acc: %.4f weighted_avg_f1: %.4f"
            % (dev_cost, speed, acc, f))
        if current_score > best_dev:
            logger.info("Exceed previous best acc score: %s" % best_dev)
            model_name = os.path.join(ROOT_PATH,
                                      model_configs['model_path'] + '.model')
            torch.save(model.state_dict(), model_name)
            best_dev = current_score
            last_improved = idx
        # Evaluate on test.
        speed, acc, p, r, f, _, _ = evalute(data, model, 'test',
                                            model_configs, train_represents,
                                            train_labels)
        test_finish = time.time()
        test_cost = test_finish - dev_finish
        writer.add_scalar('Test_acc', acc, idx)
        writer.add_scalar('Test_f1', f, idx)
        logger.info(
            "Test: time: %.2fs, speed: %.2fst/s; acc: %.4f weighted_avg_f1: %.4f"
            % (test_cost, speed, acc, f))
        # Early stopping.
        if idx - last_improved > model_configs['require_improvement']:
            logger.info('No optimization for %s epoch, auto-stopping' %
                        model_configs['require_improvement'])
            writer.close()
            # Write the representations of all training samples to represent.dset.
            main()
            break
    writer.close()
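# `self.lr_decay` is not shown in this section; a minimal sketch of the
# NCRF++-style per-epoch decay that this call signature suggests (an
# assumption, not the repo's verified implementation):
# def lr_decay(optimizer, epoch, decay_rate, init_lr):
#     lr = init_lr / (1 + decay_rate * epoch)
#     for param_group in optimizer.param_groups:
#         param_group['lr'] = lr
#     return optimizer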
def random_embedding(vocab_size, embedding_dim):
    pretrain_emb = np.empty([vocab_size, embedding_dim])
    scale = np.sqrt(3.0 / embedding_dim)
    for index in range(vocab_size):
        pretrain_emb[index, :] = np.random.uniform(-scale, scale,
                                                   [1, embedding_dim])
    return pretrain_emb


if __name__ == '__main__':
    # Demo for scene matching:
    dset_path = os.path.join(ROOT_PATH,
                             'models/text_match_v1/data/alphabet.dset')
    model_dir = os.path.join(ROOT_PATH,
                             'saved_models/text_match_v1/text_match_v1.model')
    data = Data(ori_texts=[], labels=[], if_train=False)
    with open(dset_path, 'rb') as rbf:
        data.char_alphabet.instance2index = pickle.load(rbf)
        data.word_alphabet.instance2index = pickle.load(rbf)
        data.label_alphabet.instance2index = pickle.load(rbf)
        data.char_alphabet_size = pickle.load(rbf)
        data.word_alphabet_size = pickle.load(rbf)
        data.label_alphabet_size = pickle.load(rbf)
        data.label_alphabet.instances = pickle.load(rbf)
        data.train_texts = pickle.load(rbf)
    data.fix_alphabet()
    model = TextMatchModel(data)
    model.load_state_dict(
        torch.load(model_dir, map_location=model.configs['map_location']))
    model.eval()
    model.to(model.configs['device'])
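    # A quick sanity check for random_embedding (a hypothetical addition, not
    # part of the original demo): the uniform range ±sqrt(3/dim) gives each
    # component variance 1/dim, so each row has unit expected squared norm.
    emb = random_embedding(vocab_size=100, embedding_dim=50)
    assert emb.shape == (100, 50)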