def write_represents_to_txt(path, output_path, name='train'):
    """Compute train representations and dump them as text files.

    Restores alphabets and train data from the pickle at *path*, loads the
    trained TextMatchModel, runs ``get_represents`` and writes texts,
    representation vectors and label ids one item per line under
    *output_path*.
    """
    data = Data(ori_texts=[], labels=[], if_train=False)
    with open(path, 'rb') as fin:
        # Unpack in the exact order the .dset file was written.
        (data.char_alphabet.instance2index,
         data.word_alphabet.instance2index,
         data.label_alphabet.instance2index,
         data.char_alphabet_size,
         data.word_alphabet_size,
         data.label_alphabet_size,
         data.label_alphabet.instances,
         data.train_texts,
         data.train_ids) = (pickle.load(fin) for _ in range(9))
    data.fix_alphabet()

    model = TextMatchModel(data)
    state = torch.load(model_dir, map_location=model.configs['map_location'])
    model.load_state_dict(state)
    model.eval()
    model.to(model.configs['device'])

    data.no_train_texts, data.no_train_ids = data.read_no_train(no_train_path)
    texts, represents, label_ids = get_represents(data, model, name,
                                                  model.configs)

    def _write_lines(file_path, items):
        # One item per line via str() formatting.
        with open(file_path, 'w') as fout:
            fout.writelines('%s\n' % item for item in items)

    # Texts are written only if absent, while vectors/label ids are always
    # refreshed. NOTE(review): this asymmetry looks intentional (texts are
    # stable across runs) but could desync the files — confirm with callers.
    if not os.path.exists(output_path + '/train_texts.txt'):
        _write_lines(output_path + '/train_texts.txt', texts)
    _write_lines(output_path + '/train_represents.txt', represents)
    _write_lines(output_path + '/train_label_ids.txt', label_ids)
def write_represents_to_pkl(path, output_path, name='train'):
    """Compute train representations and pickle them to *output_path*.

    Restores alphabets and train data from the pickle at *path*, loads the
    trained TextMatchModel and dumps only the representation vectors
    returned by ``get_represents``.
    """
    data = Data(ori_texts=[], labels=[], if_train=False)
    with open(path, 'rb') as fin:
        # Unpack in the exact order the .dset file was written.
        (data.char_alphabet.instance2index,
         data.word_alphabet.instance2index,
         data.label_alphabet.instance2index,
         data.char_alphabet_size,
         data.word_alphabet_size,
         data.label_alphabet_size,
         data.label_alphabet.instances,
         data.train_texts,
         data.train_ids) = (pickle.load(fin) for _ in range(9))
    data.fix_alphabet()

    model = TextMatchModel(data)
    state = torch.load(model_dir, map_location=model.configs['map_location'])
    model.load_state_dict(state)
    model.eval()
    model.to(model.configs['device'])

    _texts, represents, _label_ids = get_represents(data, model, name,
                                                    model.configs)
    # Only the representation vectors are persisted; texts and label ids
    # can be recomputed from the dset file.
    with open(output_path, 'wb') as fout:
        pickle.dump(represents, fout)
def __init__(self, if_train=False):
    """Optionally load the char alphabet and training texts.

    When *if_train* is falsy this constructor does nothing (``self.data``
    is then never set). Otherwise it restores the char alphabet and train
    texts from the dataset pickle and ensures index 0 maps to an ``unk``
    entry as the OOV fallback.
    """
    if not if_train:
        return
    self.data = Data(ori_texts=[], labels=[], if_train=False)
    with open(dset_path, 'rb') as fin:
        self.data.char_alphabet.instance2index = pickle.load(fin)
        # Skip the six intermediate pickle records (word/label alphabets
        # and the three alphabet sizes) — only char alphabet + texts are
        # needed here.
        for _ in range(6):
            pickle.load(fin)
        self.data.train_texts = pickle.load(fin)
    # Add an 'unk' vocabulary entry for OOV handling if slot 0 is free.
    if self.data.char_alphabet.get_instance(0) is None:
        self.data.char_alphabet.instance2index['unk'] = 0
def __init__(self):
    """Restore alphabets from the dataset pickle and load the trained
    text-match model onto its configured device in eval mode."""
    self.data = Data(ori_texts=[], labels=[], if_train=False)
    with open(dset_path, 'rb') as fin:
        # Unpack in the exact order the .dset file was written.
        (self.data.char_alphabet.instance2index,
         self.data.word_alphabet.instance2index,
         self.data.label_alphabet.instance2index,
         self.data.char_alphabet_size,
         self.data.word_alphabet_size,
         self.data.label_alphabet_size,
         self.data.label_alphabet.instances) = (pickle.load(fin)
                                                for _ in range(7))
    self.data.fix_alphabet()

    self.model = TextMatchModel(self.data)
    state = torch.load(model_dir,
                       map_location=self.model.configs['map_location'])
    self.model.load_state_dict(state)
    self.model.eval()
    self.model.to(self.model.configs['device'])
def __init__(self, ip_port, if_write=False):
    """Initialise the matching service.

    Restores alphabets and the trained TextMatchModel from disk, loads the
    synonym dictionary, builds a char embedding table from pretrained
    vectors, initialises the tf-idf char-weight model and opens a gRPC
    channel to the faiss service.

    :param ip_port: faiss server address for the gRPC channel
    :param if_write: currently unused — the index-writing call below is
        commented out
    """
    self.data = Data(ori_texts=[], labels=[], if_train=False)
    with open(dset_path, 'rb') as rbf:
        # NOTE: loads must stay in exactly the order the .dset was written.
        self.data.char_alphabet.instance2index = pickle.load(rbf)
        self.data.word_alphabet.instance2index = pickle.load(rbf)
        self.data.label_alphabet.instance2index = pickle.load(rbf)
        self.data.char_alphabet_size = pickle.load(rbf)
        self.data.word_alphabet_size = pickle.load(rbf)
        self.data.label_alphabet_size = pickle.load(rbf)
        self.data.label_alphabet.instances = pickle.load(rbf)
        self.data.train_texts = pickle.load(rbf)
    self.data.fix_alphabet()
    self.model = TextMatchModel(self.data)
    self.model.load_state_dict(
        torch.load(model_dir, map_location=self.model.configs['map_location']))
    self.model.eval()
    self.model.to(self.model.configs['device'])
    # Placeholder 0-d array; presumably replaced with real representation
    # vectors elsewhere — TODO confirm against the write/index path.
    self.train_represents = np.zeros(shape=())
    self.train_texts = self.data.train_texts
    # Load the synonym dictionary.
    with open(synonyms_path, 'r') as rf:
        self.synonyms = yaml.load(rf.read(), Loader=yaml.FullLoader)
    # Scene-matching initialisation.
    self.scene_texts = []
    self.pretrain_char_embedding, self.char_emb_dim = build_pretrain_embedding(
        glove_emb_path, self.data.char_alphabet)
    self.char_embedding = nn.Embedding(self.data.char_alphabet_size,
                                       self.char_emb_dim)
    self.char_embedding.weight.data.copy_(
        torch.from_numpy(self.pretrain_char_embedding))
    if self.model.configs['gpu']:
        self.char_embedding = self.char_embedding.cuda()
    # if if_write:
    #     self.write_index()
    # tf-idf model initialisation.
    self.vectorizer = CharWeight.load_model(char_weight_model_dir)
    # faiss service: open the channel and keep the stub for RPC calls.
    self.channel = grpc.insecure_channel(ip_port)
    self.stub = FaissServerStub(self.channel)
def random_embedding(vocab_size, embedding_dim):
    """Build a randomly initialised embedding matrix.

    Every entry is drawn uniformly from ``[-scale, scale)`` with
    ``scale = sqrt(3 / embedding_dim)`` (the original row-by-row loop is
    replaced by a single vectorised draw — same distribution, one C call).

    :param vocab_size: number of rows (vocabulary entries)
    :param embedding_dim: embedding dimensionality (columns)
    :return: ``np.ndarray`` of shape ``(vocab_size, embedding_dim)``
    """
    scale = np.sqrt(3.0 / embedding_dim)
    return np.random.uniform(-scale, scale, (vocab_size, embedding_dim))


if __name__ == '__main__':
    # Scene-matching demo: restore alphabets and run the trained model.
    dset_path = os.path.join(ROOT_PATH,
                             'models/text_match_v1/data/alphabet.dset')
    model_dir = os.path.join(ROOT_PATH,
                             'saved_models/text_match_v1/text_match_v1.model')
    data = Data(ori_texts=[], labels=[], if_train=False)
    with open(dset_path, 'rb') as rbf:
        # Load order must match the order the .dset file was written in.
        data.char_alphabet.instance2index = pickle.load(rbf)
        data.word_alphabet.instance2index = pickle.load(rbf)
        data.label_alphabet.instance2index = pickle.load(rbf)
        data.char_alphabet_size = pickle.load(rbf)
        data.word_alphabet_size = pickle.load(rbf)
        data.label_alphabet_size = pickle.load(rbf)
        data.label_alphabet.instances = pickle.load(rbf)
        data.train_texts = pickle.load(rbf)
    data.fix_alphabet()
    model = TextMatchModel(data)
    model.load_state_dict(
        torch.load(model_dir, map_location=model.configs['map_location']))
    model.eval()
    model.to(model.configs['device'])