def train(train_sentences, trial_sentences, test_sentences, embeddings,
          epochs, batch_size, checkpoint_dir, learning_rate=0.005):
    """Train a CNN and checkpoint it whenever trial accuracy improves.

    Each epoch walks the training set in mini-batches, then evaluates on the
    trial set. A checkpoint is written to ``checkpoint_dir`` only when the
    trial accuracy beats the best value seen so far (initial baseline 0.5).

    :param train_sentences: training data; ``len(train_sentences[1])`` is used
        as the number of examples (presumably index 1 holds the labels --
        TODO confirm).
    :param trial_sentences: held-out data evaluated after every epoch, same
        layout as ``train_sentences``.
    :param test_sentences: not used by this function.
    :param embeddings: passed through to ``indices_to_vectors``.
    :param epochs: number of passes over the training data.
    :param batch_size: each training mini-batch holds ``batch_size // 2``
        examples (at least 1), as in the original implementation.
    :param checkpoint_dir: directory handed to ``model.save``.
    :param learning_rate: learning rate passed to ``model.train``.
    """
    import sys

    # Mini-batches hold half of batch_size examples (kept from the original).
    # Clamp to 1: a size of 0 would never advance `offset`, so the loop would
    # spin forever.
    half_batch = max(1, batch_size // 2)
    # Best trial accuracy so far. It must persist across epochs and be
    # updated on every save, otherwise any epoch above 0.5 would overwrite
    # the best checkpoint.
    highest_accuracy = 0.5
    with tf.Session() as sess:
        # NOTE(review): meaning of the constant 16 depends on CNN -- confirm.
        model = CNN(sess, 16)
        model.init()
        for e in range(epochs):
            print('Epoch: {}'.format(e))
            # training
            n_data = len(train_sentences[1])
            offset = 0
            while offset < n_data:
                size = min(n_data - offset, half_batch)
                idx = list(range(offset, offset + size))
                batch_x, batch_y = indices_to_vectors(
                    idx, train_sentences, embeddings)
                model.train(batch_x, batch_y, learning_rate)
                offset += size
            # testing
            idx = list(range(len(trial_sentences[1])))
            trial_x, trial_y = indices_to_vectors(
                idx, trial_sentences, embeddings)
            # Prefix without a newline (the old Python 2 `print 'x',` idiom),
            # written in a form valid on both Python 2 and 3.
            sys.stdout.write('Trial - ')
            accuracy = model.test(trial_x, trial_y)
            if accuracy > highest_accuracy:
                highest_accuracy = accuracy
                model.save(checkpoint_dir)
class MainProc(Proc): ''' @about 构造函数,只接受参数,不进行初始化 初始化在线程内完成 @param prop:程序配置模块 @return None ''' def __init__(self, prop, feed, msg_queue=None): sd = Shared() self.__logger = sd.getLogger() self.__prop = prop self.__feed = feed self.__msg_queue = msg_queue ''' @about 创建两个卷积网络对象 @param prop:程序配置对象 feed:数据提供对象 msg_queue:消息队列 @return None ''' def init(self): self.__isTrain = self.__prop.needTrain() self.__sess = tf.Session() self.__prop.updateAttr('session', self.__sess) self.__output_path = self.__prop.queryAttr('data_dir') self.__model_path = self.__prop.queryAttr('model_path') self.__ckpt_name = self.__prop.queryAttr('ckpt_name') self.__cnn_name = self.__prop.queryAttr('cnn_name') self.__batch_n = self.__prop.queryAttr('batch_n') self.__batch_h = self.__prop.queryAttr('batch_h') self.__batch_w = self.__prop.queryAttr('batch_w') self.__batch_c = self.__prop.queryAttr( 'sfeatures') + self.__prop.queryAttr('ifeatures') self.__cols = self.__prop.queryAttr('cols') # global photon cnn self.__global_fea = tf.placeholder( tf.float32, [self.__batch_n, self.__batch_h, self.__batch_w, self.__batch_c]) self.__global_img = tf.placeholder( tf.float32, [self.__batch_n, self.__batch_h, self.__batch_w, self.__cols]) self.__globalCNN = CNN(self.__prop, self.__cnn_name[0]) self.__globalCNN.build(self.__global_img, self.__global_fea) # self.__globalCNN.init() # caustic photon cnn self.__caustic_fea = tf.placeholder( tf.float32, [self.__batch_n, self.__batch_h, self.__batch_w, self.__batch_c]) self.__caustic_img = tf.placeholder( tf.float32, [self.__batch_n, self.__batch_h, self.__batch_w, self.__cols]) self.__causticCNN = CNN(self.__prop, self.__cnn_name[1]) self.__causticCNN.build(self.__caustic_img, self.__caustic_fea) # self.__causticCNN.init() # other configuration in train mode if self.__isTrain: self.__meta_name = self.__prop.queryAttr('meta_name') self.__loss_func = self.__prop.queryAttr('loss_func') self.__optimizer = self.__prop.queryAttr('optimizer') self.__max_round = 
self.__prop.queryAttr('max_round') self.__save_round = self.__prop.queryAttr('save_round') self.__learning_rate = self.__prop.queryAttr('learning_rate') ''' @about 线程调用入口 封装此模块主要流程 @param None @return None ''' def run(self): self.init() predict = self.preprocess(self.__prop) result = self.process(self.__feed, predict) self.postprocess(result) ''' @about 开启守护线程,执行此模块主要逻辑 @param None @return None ''' def start(self): wrk = Thread(target=self.run, args=(), name='Worker') wrk.setDaemon(True) wrk.start() ''' @about 向消息队列中发送数据 @param msg:[current status %,current loss] @return None ''' def sendMsg(self, msg): if platform.system() != 'Windows': return try: self.__msg_queue.put(msg) except: self.__logger.error('message queue is not specified.') ''' @about 优化器 @param loss:损失值 learning_rate:学习率 type:优化器类型 @return 优化器 ''' def __getOptimizer(self, loss, learning_rate, type): self.__logger.debug('building optimizer...') if type.lower() == 'adam': result = tf.train.AdamOptimizer(learning_rate).minimize(loss) elif type.lower() == 'gradient': result = tf.train.GradientDescentOptimizer(learning_rate).minimize( loss) else: raise NotImplementedError self.__logger.debug('optimizer built.') return result ''' @about 损失函数,根据指定误差计算方式计算误差 @param input:预测值 output:目标值 type:损失函数类型 @return 损失值 ''' def __getLoss(self, input, output, type='l1'): self.__logger.debug('computing loss...') assert (input.shape == output.shape) if type == 'l1': result = tf.reduce_mean(tf.abs(tf.subtract(input, output))) elif type == 'cross_entropy': # 此方式表现不佳 result = tf.reduce_mean(-tf.reduce_sum(output * tf.log(input))) else: raise NotImplementedError self.__logger.debug('loss computed.') return result ''' @param @about @return ''' def __updateMeta(self, filename): print(filename) try: obj = utils.readJson(filename) obj['iterations'] += self.__save_round except: obj = {} obj['name'] = filename obj['iterations'] = self.__save_round obj['active_func'] = self.__prop.queryAttr('active_func') obj['weights_shape'] = 
self.__prop.queryAttr('weights_shape') utils.writeJson(obj, filename) ''' @about 两个卷积网络进行滤波,然后合并结果 @param prop: input2: @return 预测值,Tensor ''' def preprocess(self, prop): self.__logger.debug('preprocessing...') # filter images separately glob = self.__globalCNN.process(Filter(prop)) caus = self.__causticCNN.process(Filter(prop)) assert (glob.shape == caus.shape) predict = glob + caus return predict ''' @about 主要处理流程,处理来自preprocess的predict @param feed: 数据填充模块 predict:预测值 @return 处理过的predict ''' def process(self, feed, predict): self.__logger.debug('processing...') ishape, tshape = feed.getBatchShapes() # model storage path ckpt_global = self.__model_path + self.__cnn_name[0] + '/' ckpt_caustic = self.__model_path + self.__cnn_name[1] + '/' # train when in train mode if self.__isTrain: self.__logger.debug('training...') truth = tf.placeholder(tf.float32, ishape) # loss = tf.reduce_mean(tf.abs(tf.subtract(predict, truth))) loss = self.__getLoss(predict, truth, self.__loss_func) # step = tf.train.AdamOptimizer(self.__learning_rate).minimize(loss) step = self.__getOptimizer(loss, self.__learning_rate, self.__optimizer) # 恢复与随机初始化两网络 self.__globalCNN.init(ckpt_global) self.__causticCNN.init(ckpt_caustic) # 用于绘制进度条 # bar = LineProgress(title='status', total=self.__max_round) # 训练 for i in range(self.__max_round): gi, gf, ci, cf, gt = feed.next_batch() self.__sess.run(step, feed_dict={ self.__caustic_img: ci, self.__global_img: gi, self.__caustic_fea: cf, self.__global_fea: gf, truth: gt }) xloss = self.__sess.run(loss, feed_dict={ self.__caustic_img: ci, self.__global_img: gi, self.__caustic_fea: cf, self.__global_fea: gf, truth: gt }) self.sendMsg([i / self.__max_round, xloss]) # xpred = self.__sess.run(predict, feed_dict={self.__caustic_img: ci, self.__global_img: gi, # self.__caustic_fea: cf, self.__global_fea: gf, # truth: gt}) # utils.displayImage(xpred) print('round:%d of %d,loss: %f...' 
% (i + 1, self.__max_round, xloss)) self.__logger.info('round:%d of %d,loss:%f...' % (i + 1, self.__max_round, xloss)) # print("status: {:.2f}%".format(float((i + 1) / self.__max_round)), end="\r") # bar.update((i + 1) / self.__max_round * 100) # 保存结果 if i % self.__save_round == (self.__save_round - 1): self.__globalCNN.save(ckpt_global, self.__ckpt_name, self.__save_round) self.__causticCNN.save(ckpt_caustic, self.__ckpt_name, self.__save_round) for cnn in self.__cnn_name: path = self.__model_path + '/' + cnn + '/' + self.__meta_name self.__updateMeta(path) # infer模式下直接输出 else: self.__logger.debug('inferring...') # 恢复与随机初始化两网络 self.__globalCNN.init(ckpt_global) self.__causticCNN.init(ckpt_caustic) gi, gf, ci, cf = feed.getInputdata() result = self.__sess.run(predict, feed_dict={ self.__caustic_img: ci, self.__global_img: gi, self.__caustic_fea: cf, self.__global_fea: gf }) return result return None ''' @about process之后执行 @param input:process的输出 @return None ''' def postprocess(self, input1): sd = Shared() self.__logger.debug('postprocessing...') if not self.__isTrain: # path of test data file save_path = self.__output_path utils.saveImage(input1, save_path + 'infer.png') # utils.displayImage(input) sd.setFlag('nowExit', True) print('done') '''