def __init__(self, log_dir, save_dir):
    """Prepare output directories and derive the fixed network input size.

    Args:
        log_dir: Directory for checkpoints/logs; passed through ``Tools.new_dir``.
        save_dir: Directory for saved results; passed through ``Tools.new_dir``.
    """
    # The value returned by Tools.new_dir is stored as the directory path
    # (presumably it creates the directory when missing -- confirm in Tools).
    self.save_dir = Tools.new_dir(save_dir)
    self.log_dir = Tools.new_dir(log_dir)

    self.last_pool_size = 50
    # Square input whose edge is 8x the last pooling size.
    edge = self.last_pool_size * 8
    self.input_size = [edge, edge]
def build(self):
    """Build the shared encoder followed by a four-stage attention decoder.

    Stage "4" decodes the deepest encoder block. Each later stage ("3", "2", "1")
    decodes the sum of the previous stage's decoder output and the next shallower
    encoder block. Every stage also emits a side segmentation output via
    ``self._segment``. The last decoder output is projected to
    ``self.num_classes`` channels by a 3x3 conv.

    Returns:
        segments: side outputs of stages 4, 3, 2, 1 followed by the final logits.
        features: ``{"block": encoder blocks, "segment": per-stage decoder outputs}``.
    """
    # Shared feature extraction.
    block1, block2, block3, block4 = self._feature(self.input_data)
    blocks = [block1, block2, block3, block4]

    # Example sizes from the original author's comments (depend on the input size).
    block4_shape = Tools.get_shape(block4)  # 45, 512
    block3_shape = Tools.get_shape(block3)  # 90, 512
    block2_shape = Tools.get_shape(block2)  # 180, 256
    block1_shape = Tools.get_shape(block1)  # 360, 128

    segments = []
    segments_output = []

    # (stage tag, input shape, output shape, encoder block to add afterwards, name of that add op)
    # The add op is created outside the stage's variable scope.
    stage_plan = [
        ("4", block4_shape, block3_shape, block3, 'attention_3_add'),
        ("3", block3_shape, block2_shape, block2, "attention_2_net_output_relu"),
        ("2", block2_shape, block1_shape, block1, "attention_1_concat"),
        ("1", block1_shape, block1_shape, None, None),
    ]

    stage_input = block4
    for tag, in_shape, out_shape, skip_block, add_name in stage_plan:
        with tf.variable_scope(name_or_scope="attention_" + tag):
            net_output = self._decoder(stage_input, in_shape, out_shape, name=tag)
            side_output = self._segment(stage_input, in_shape, out_shape, name="segment_side_" + tag)
            segments.append(side_output)
        segments_output.append(net_output)
        if skip_block is not None:
            # Fuse the decoder output into the shallower encoder block for the next stage.
            stage_input = Net.add([skip_block, net_output], name=add_name)

    final_logits = Net.conv(net_output, 3, 3, self.num_classes, 1, 1, biased=True, relu=False,
                            name='attention_0_21')
    segments.append(final_logits)

    features = {"block": blocks, "segment": segments_output}
    return segments, features
def inference(self, image_path, image_index, save_path=None):
    """Segment one image, record its summaries, then show or save the mask.

    Args:
        image_path: path of the image to load with ``Data.load_data``.
        image_index: used as ``global_step`` for the written summary.
        save_path: if ``None`` the mask is shown; otherwise it is saved as
            ``<save_path>/<image stem>.bmp`` (directory created via ``Tools.new_dir``).
    """
    image = Data.load_data(image_path=image_path, input_size=self.input_size)
    batch = np.expand_dims(image, axis=0)  # add a leading batch dimension

    prediction, summary = self.sess.run([self.pred_segment, self.summary_op],
                                        feed_dict={self.image_placeholder: batch})
    self.summary_writer.add_summary(summary, global_step=image_index)

    # NOTE(review): if pred_segment is uint8 (as in the __init__ variants that cast it),
    # `* 255` stays uint8 and wraps for class ids > 1 (k -> 256 - k). Confirm this is intended.
    mask = Image.fromarray(np.asarray(np.squeeze(prediction) * 255, dtype=np.uint8))

    if save_path is not None:
        Tools.new_dir(save_path)
        stem = os.path.splitext(os.path.basename(image_path))[0]
        mask.convert("L").save("{}/{}.bmp".format(save_path, stem))
    else:
        mask.show()
def __init__(self, input_size, log_dir, model_name="model.ckpt"):
    """Build a BAISNet prediction graph, open a session and create a saver.

    Args:
        input_size: (height, width) of the 3-channel input placeholder.
        log_dir: checkpoint directory, passed through ``Tools.new_dir``.
        model_name: checkpoint file name joined onto ``log_dir``.
    """
    # Checkpoint location.
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

    # Data settings.
    self.input_size = input_size
    self.num_classes = 21

    # Network. The second BAISNet argument is False (presumably is_training -- confirm).
    height, width = self.input_size[0], self.input_size[1]
    self.image_placeholder = tf.placeholder(tf.float32, shape=(None, height, width, 3))
    self.net = BAISNet(self.image_placeholder, False, num_classes=self.num_classes)
    self.segments = self.net.build()
    first_logits = self.segments[0]
    self.pred_segment = tf.cast(tf.argmax(first_logits, axis=-1), dtype=tf.uint8)

    # Session (GPU memory grows on demand) and saver keeping the 10 newest checkpoints.
    gpu_options = tf.GPUOptions(allow_growth=True)
    self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
def train(self, save_pred_freq, begin_step=0):
    """Run the segmentation training loop from ``begin_step`` up to ``self.num_steps``.

    Every ``self.print_step`` steps the merged summary is also fetched and written.
    Every ``save_pred_freq`` steps a checkpoint is saved. Losses are averaged over
    windows of ``self.cal_step`` steps: ``pre_avg_loss`` is the previous window's
    average and ``avg_loss`` the running average of the current window.

    Args:
        save_pred_freq: checkpoint frequency in steps.
        begin_step: first step index (also fed to the learning-rate schedule).
    """
    # Load existing weights (optionally from self.pretrain) via Tools.restore_if_y.
    Tools.restore_if_y(self.sess, self.log_dir, pretrain=self.pretrain)

    # Log roughly 10 times per averaging window. max(1, ...) avoids a
    # ZeroDivisionError in the modulus below when cal_step < 10.
    log_every = max(1, self.cal_step // 10)

    total_loss = 0.0
    pre_avg_loss = 0.0
    for step in range(begin_step, self.num_steps):
        start_time = time.time()
        batch_data, batch_segment = self.data_reader.next_batch_train()

        # train_op = self.train_attention_op  # alternative: train attention layers only
        train_op = self.train_op
        feed_dict = {self.step_ph: step, self.image_placeholder: batch_data,
                     self.label_seg_placeholder: batch_segment}
        fetches = [train_op, self.learning_rate, self.loss_segment_all, self.loss]
        write_summary = step % self.print_step == 0
        if write_summary:
            fetches.append(self.summary_op)

        results = self.sess.run(fetches, feed_dict=feed_dict)
        _, learning_rate_r, loss_segment_r, loss_r = results[:4]
        if write_summary:
            self.summary_writer.add_summary(results[4], global_step=step)

        if step % save_pred_freq == 0:
            self.saver.save(self.sess, self.checkpoint_path, global_step=step)
            Tools.print_info('The checkpoint has been created.')

        duration = time.time() - start_time

        # A new averaging window starts on every multiple of cal_step.
        if step % self.cal_step == 0:
            pre_avg_loss = total_loss / self.cal_step
            total_loss = loss_r
        else:
            total_loss += loss_r
        total_loss_step = (step % self.cal_step) + 1

        if step % log_every == 0:
            Tools.print_info('step {:d} pre_avg_loss={:.3f} avg_loss={:.3f} loss={:.3f} seg={:.3f} lr={:.6f} '
                             '({:.3f} s/step)'.format(step, pre_avg_loss, total_loss / total_loss_step,
                                                      loss_r, loss_segment_r, learning_rate_r, duration))
        if step % self.print_step == 0:
            Tools.print_info("")
def train(self, save_pred_freq, begin_step=0, end_step=5):
    """Evaluate class predictions on a few training batches (no optimizer step).

    Despite its name, this loop runs no train op. It fetches predictions,
    periodically saves a checkpoint, and prints the ground-truth vs predicted
    classes for each batch.

    Args:
        save_pred_freq: save a checkpoint every this many steps.
        begin_step: first step index.
        end_step: exclusive last step. Defaults to 5, the value that was
            previously hard-coded.
    """
    # Load existing weights via Tools.restore_if_y.
    Tools.restore_if_y(self.sess, self.log_dir)
    for step in range(begin_step, end_step):
        # Only the image batch and its class labels are used here.
        final_batch_data, _, final_batch_class, _, _ = self.data_reader.next_batch_train()
        pred_classes_r = self.sess.run(self.pred_classes,
                                       feed_dict={self.image_placeholder: final_batch_data})

        if step % save_pred_freq == 0:
            self.saver.save(self.sess, self.checkpoint_path, global_step=step)
            Tools.print_info('The checkpoint has been created.')

        Tools.print_info('step {:d} {} {}'.format(step, list(final_batch_class), list(pred_classes_r)))
def __init__(self, input_size, summary_dir, log_dir, model_name="model.ckpt"):
    """Build the feature extractor and log every feature channel as an image summary.

    Args:
        input_size: (height, width) of the 3-channel input placeholder.
        summary_dir: directory for the ``tf.summary.FileWriter``.
        log_dir: checkpoint directory, passed through ``Tools.new_dir``.
        model_name: checkpoint file name joined onto ``log_dir``.
    """
    # Checkpoint location.
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

    # Data settings.
    self.input_size = input_size
    self.num_classes = 21

    # Network: only the feature extractor is built.
    placeholder_shape = (None, self.input_size[0], self.input_size[1], 3)
    self.image_placeholder = tf.placeholder(tf.float32, shape=placeholder_shape)
    self.features = self._feature(self.image_placeholder)

    with tf.name_scope("image"):
        tf.summary.image("input", self.image_placeholder)

    # One single-channel image summary per channel of every feature map except the last.
    with tf.name_scope("block"):
        for block_idx, block in enumerate(self.features[:-1]):
            channels = tf.split(block, num_or_size_splits=int(block.shape[-1]), axis=-1)
            for channel_idx, channel in enumerate(channels):
                tf.summary.image("{}-{}".format(block_idx, channel_idx), channel)

    self.summary_op = tf.summary.merge_all()

    # Session, saver and summary writer.
    session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    self.sess = tf.Session(config=session_config)
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
    self.summary_writer = tf.summary.FileWriter(summary_dir, self.sess.graph)
def train(self, save_pred_freq, begin_step=0):
    """Joint segmentation/classification training loop.

    Runs ``self.train_op`` for steps ``begin_step .. self.num_steps - 1``. Every
    ``self.print_step`` steps the merged summary is also written. Every
    ``save_pred_freq`` steps a checkpoint is saved. Per-step metrics are printed
    on every step.

    Args:
        save_pred_freq: checkpoint frequency in steps.
        begin_step: first step index (also fed to the learning-rate schedule).
    """
    # Load existing weights via Tools.restore_if_y.
    Tools.restore_if_y(self.sess, self.log_dir)
    for step in range(begin_step, self.num_steps):
        start_time = time.time()
        final_batch_data, final_batch_ann, final_batch_class, batch_data, batch_mask = \
            self.data_reader.next_batch_train()

        # train_op = self.train_classes_op  # alternative: classification head only
        train_op = self.train_op

        fetches = [self.accuracy_segment, self.accuracy_classes, train_op, self.learning_rate,
                   self.loss_segment, self.loss_classes, self.loss, self.raw_output_segment,
                   self.pred_segment, self.raw_output_classes, self.pred_classes]
        with_summary = step % self.print_step == 0
        if with_summary:
            fetches.append(self.summary_op)

        outputs = self.sess.run(fetches, feed_dict={
            self.step_ph: step,
            self.image_placeholder: final_batch_data,
            self.label_segment_placeholder: final_batch_ann,
            self.label_classes_placeholder: final_batch_class,
        })
        (accuracy_segment_r, accuracy_classes_r, _, learning_rate_r, loss_segment_r, loss_classes_r,
         loss_r, _, _, _, pred_classes_r) = outputs[:11]
        if with_summary:
            self.summary_writer.add_summary(outputs[11], global_step=step)

        if step % save_pred_freq == 0:
            self.saver.save(self.sess, self.checkpoint_path, global_step=step)
            Tools.print_info('The checkpoint has been created.')

        duration = time.time() - start_time
        message = ('step {:d} loss={:.3f} seg={:.3f} class={:.3f} acc={:.3f} acc_class={:.3f}'
                   ' lr={:.6f} ({:.3f} s/step) {} {}')
        Tools.print_info(message.format(step, loss_r, loss_segment_r, loss_classes_r, accuracy_segment_r,
                                        accuracy_classes_r, learning_rate_r, duration,
                                        list(final_batch_class), list(pred_classes_r)))
def __init__(self, batch_size, last_pool_size, input_size, log_dir, data_root_path, train_list, data_path,
             annotation_path, class_path, model_name="model.ckpt", is_test=False):
    """Create the data reader, build the network via ``build_net`` and open a session.

    Args:
        batch_size: images per batch for the ``Data`` reader.
        last_pool_size: size of the last pooling layer (input must exceed 8x this).
        input_size: image size handed to ``Data``.
        log_dir: checkpoint directory, passed through ``Tools.new_dir``.
        data_root_path, train_list, data_path, annotation_path, class_path:
            forwarded unchanged to ``Data``.
        model_name: checkpoint file name joined onto ``log_dir``.
        is_test: forwarded to ``Data``.
    """
    # Checkpoint location.
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

    # Data settings.
    self.input_size = input_size
    self.batch_size = batch_size
    self.num_classes = 21
    self.num_segment = 1

    # Model settings: input_size must be larger than 8x last_pool_size.
    self.ratio = 8
    self.last_pool_size = last_pool_size
    self.filter_number = 32

    # Data reader.
    self.data_reader = Data(data_root_path=data_root_path, data_list=train_list, data_path=data_path,
                            annotation_path=annotation_path, class_path=class_path,
                            batch_size=self.batch_size, image_size=self.input_size, is_test=is_test)

    # Network.
    net_outputs = self.build_net()
    (self.image_placeholder, self.raw_output_segment, self.raw_output_classes,
     self.pred_segment, self.pred_classes) = net_outputs

    # Session and saver.
    config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    self.sess = tf.Session(config=config)
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
def build(self):
    """Build the encoder and a four-stage decoder ending in a 2-channel output.

    Stage "4" decodes the deepest encoder block. Each later stage decodes the
    sum of the previous decoder output and the next shallower encoder block.
    The add op is created outside the stage's variable scope.

    Returns:
        A one-element list holding the final 2-channel logits ('attention_0').
    """
    # Shared feature extraction.
    block1, block2, block3, block4 = self._feature(self.input_data)

    # Example sizes from the original author's comments (depend on the input size).
    block4_shape = Tools.get_shape(block4)  # 45, 512
    block3_shape = Tools.get_shape(block3)  # 90, 512
    block2_shape = Tools.get_shape(block2)  # 180, 256
    block1_shape = Tools.get_shape(block1)  # 360, 128

    segments = []

    # (stage tag, input shape, output shape, encoder block to add afterwards, add op name)
    plan = (
        ("4", block4_shape, block3_shape, block3, 'attention_3_add'),
        ("3", block3_shape, block2_shape, block2, "attention_2_net_output_relu"),
        ("2", block2_shape, block1_shape, block1, "attention_1_concat"),
        ("1", block1_shape, block1_shape, None, None),
    )

    current = block4
    for tag, in_shape, out_shape, skip, add_name in plan:
        with tf.variable_scope(name_or_scope="attention_" + tag):
            net_output = self._decoder(current, in_shape, out_shape, name=tag)
        if skip is not None:
            current = Net.add([skip, net_output], name=add_name)

    logits = Net.conv(net_output, 3, 3, 2, 1, 1, biased=True, relu=False, name='attention_0')
    segments.append(logits)
    return segments
def load_net(self):
    """Build a 4-channel-input PSPNet, open a session and restore weights.

    Returns:
        (sess, img_placeholder, predict_output_op, pred_classes), where
        ``predict_output_op`` is the per-pixel argmax of the sigmoid of the
        ``conv6_n_4`` layer resized to ``self.input_size``, and ``pred_classes``
        is the int32 argmax of ``class_attention_fc``.
    """
    print("begin to build net and start session and load model ...")

    # Network (4-channel input).
    input_shape = (None, self.input_size[0], self.input_size[1], 4)
    img_placeholder = tf.placeholder(dtype=tf.float32, shape=input_shape)
    # NOTE(review): is_training=True while loading for inference; confirm whether PSPNet
    # uses this flag to switch batch-norm statistics.
    net = PSPNet({'data': img_placeholder}, is_training=True, num_classes=21,
                 last_pool_size=self.last_pool_size, filter_number=32, num_segment=4)

    # Segmentation prediction at full input resolution.
    segment_logits = tf.image.resize_bilinear(net.layers["conv6_n_4"], size=self.input_size)
    predict_output_op = tf.argmax(tf.sigmoid(segment_logits), axis=-1)

    # Classification prediction.
    class_logits = net.layers['class_attention_fc']
    pred_classes = tf.cast(tf.argmax(class_logits, axis=-1), tf.int32)

    # Session, variable init, then restore via Tools.restore_if_y.
    sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
    sess.run(tf.global_variables_initializer())
    Tools.restore_if_y(sess, self.log_dir)

    print("end build net and start session and load model ...")
    return sess, img_placeholder, predict_output_op, pred_classes
def __init__(self, input_size, summary_dir, log_dir, model_name="model.ckpt"):
    """Build BAISNet and register image summaries for inputs, predictions and features.

    Args:
        input_size: (height, width) of the 3-channel input placeholder.
        summary_dir: directory for the ``tf.summary.FileWriter``.
        log_dir: checkpoint directory, passed through ``Tools.new_dir``.
        model_name: checkpoint file name joined onto ``log_dir``.
    """
    # Checkpoint location.
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

    # Data settings.
    self.input_size = input_size
    self.num_classes = 21

    # Network. The second BAISNet argument is False (presumably is_training -- confirm).
    self.image_placeholder = tf.placeholder(tf.float32,
                                            shape=(None, self.input_size[0], self.input_size[1], 3))
    self.net = BAISNet(self.image_placeholder, False, num_classes=self.num_classes)
    self.segments, self.features = self.net.build()
    self.pred_segment = tf.cast(tf.argmax(self.segments[0], axis=-1), dtype=tf.uint8)

    with tf.name_scope("image"):
        tf.summary.image("input", self.image_placeholder)
        # One argmax image per segmentation output.
        # NOTE(review): uint8 `* 255` wraps for class ids > 1 (k -> 256 - k).
        for idx, logits in enumerate(self.segments):
            label_map = tf.cast(tf.argmax(logits, axis=-1), dtype=tf.uint8)
            tf.summary.image("predict-{}".format(idx), tf.expand_dims(label_map * 255, axis=-1))

    # One single-channel image per channel of every feature tensor, grouped by dict key.
    for key, tensors in self.features.items():
        with tf.name_scope(key):
            for tensor_idx, tensor in enumerate(tensors):
                channels = tf.split(tensor, num_or_size_splits=int(tensor.shape[-1]), axis=-1)
                for channel_idx, channel in enumerate(channels):
                    tf.summary.image("{}-{}".format(tensor_idx, channel_idx), channel)

    self.summary_op = tf.summary.merge_all()

    # Session, saver and summary writer.
    self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
    self.summary_writer = tf.summary.FileWriter(summary_dir, self.sess.graph)
def __init__(self, batch_size, last_pool_size, input_size, log_dir, data_root_path, train_list, data_path,
             annotation_path, class_path, model_name="model.ckpt", is_test=False):
    """Build the BAISNet training graph (segmentation, attention and classification heads).

    Creates the data reader, placeholders, network, losses (via ``self.cal_loss``),
    accuracies, a poly-decay SGD train op (plus an attention-only variant),
    summaries, a session, a saver and a summary writer on ``log_dir``.

    Args:
        batch_size: images per batch for the ``Data`` reader.
        last_pool_size: size of the last pooling layer (input must exceed 8x this).
        input_size: (height, width) of the 4-channel input placeholder.
        log_dir: checkpoint/summary directory, passed through ``Tools.new_dir``.
        data_root_path, train_list, data_path, annotation_path, class_path:
            forwarded unchanged to ``Data``.
        model_name: checkpoint file name joined onto ``log_dir``.
        is_test: forwarded to ``Data``; also selects a shorter ``print_step``.
    """
    # Checkpoint location.
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

    # Data settings.
    self.input_size = input_size
    self.batch_size = batch_size
    self.num_classes = 21
    self.num_segment = 4  # decoder channels: other objects, attention, border, background
    self.segment_attention = 1  # index of the attention channel when there are 4 decoder channels
    # Original comment: number of attention modules whose decoder has 2 channels (background, attention).
    # NOTE(review): the "sigmoid" summaries below split the FIRST attention_module_num segments
    # into num_segment channels and the rest into 2 -- confirm which description is correct.
    self.attention_module_num = 2

    # Model settings: input_size must be larger than 8x last_pool_size.
    self.ratio = 8
    self.last_pool_size = last_pool_size
    self.filter_number = 32

    # Training settings (learning_rate is replaced by a decayed tensor below).
    self.learning_rate = 5e-3
    self.num_steps = 500001
    self.print_step = 5 if is_test else 25

    # Data reader.
    self.data_reader = Data(data_root_path=data_root_path, data_list=train_list, data_path=data_path,
                            annotation_path=annotation_path, class_path=class_path,
                            batch_size=self.batch_size, image_size=self.input_size, is_test=is_test,
                            has_255=True)

    # Placeholders: 4-channel image (RGB + mask), labels at 1/ratio resolution, class ids.
    self.image_placeholder = tf.placeholder(dtype=tf.float32,
                                            shape=(None, self.input_size[0], self.input_size[1], 4))
    self.label_segment_placeholder = tf.placeholder(
        dtype=tf.int32, shape=(None, self.input_size[0] // self.ratio, self.input_size[1] // self.ratio, 1))
    self.label_attention_placeholder = tf.placeholder(
        dtype=tf.int32, shape=(None, self.input_size[0] // self.ratio, self.input_size[1] // self.ratio, 1))
    self.label_classes_placeholder = tf.placeholder(dtype=tf.int32, shape=(None, ))

    # Network.
    self.net = BAISNet(self.image_placeholder, is_training=True, num_classes=self.num_classes,
                       num_segment=self.num_segment, segment_attention=self.segment_attention,
                       last_pool_size=self.last_pool_size, filter_number=self.filter_number,
                       attention_module_num=self.attention_module_num)
    self.segments, self.attentions, self.classes = self.net.build()
    self.final_segment_logit = self.segments[0]
    self.final_class_logit = self.classes[0]

    # Predictions.
    self.pred_segment = tf.cast(tf.expand_dims(tf.argmax(self.final_segment_logit, axis=-1), axis=-1), tf.int32)
    self.pred_classes = tf.cast(tf.argmax(self.final_class_logit, axis=-1), tf.int32)

    # Loss: labels are resized (nearest neighbour) to the logit's spatial size.
    self.label_batch = tf.image.resize_nearest_neighbor(
        self.label_segment_placeholder, tf.stack(self.final_segment_logit.get_shape()[1:3]))
    self.label_attention_batch = tf.image.resize_nearest_neighbor(
        self.label_attention_placeholder, tf.stack(self.final_segment_logit.get_shape()[1:3]))
    self.loss, self.loss_segment_all, self.loss_class_all, self.loss_segments, self.loss_classes = self.cal_loss(
        self.segments, self.classes, self.label_batch, self.label_attention_batch,
        self.label_classes_placeholder, self.num_segment, attention_module_num=self.attention_module_num)

    # Accuracy of the current batch. pred_segment has the logit's spatial size, so compare it with
    # label_batch (labels resized to that size). The raw placeholder is only correct when the logit
    # happens to be at input/ratio resolution; otherwise the shapes mismatch.
    self.accuracy_segment = tcm.accuracy(self.pred_segment, self.label_batch)
    self.accuracy_classes = tcm.accuracy(self.pred_classes, self.label_classes_placeholder)

    with tf.name_scope("train"):
        # Poly learning-rate decay: lr * (1 - step / num_steps) ** 0.9
        self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
        self.learning_rate = tf.scalar_mul(tf.constant(self.learning_rate),
                                           tf.pow((1 - self.step_ph / self.num_steps), 0.9))
        self.train_op = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss)

        # Variant that trains only variables whose name contains "attention".
        attention_trainable = [v for v in tf.trainable_variables()
                               if 'attention' in v.name or "class_attention" in v.name]
        print(len(attention_trainable))
        self.train_attention_op = tf.train.GradientDescentOptimizer(
            self.learning_rate).minimize(self.loss, var_list=attention_trainable)

    # Summary 1: scalars.
    with tf.name_scope("loss"):
        tf.summary.scalar("loss", self.loss)
        tf.summary.scalar("loss_segment", self.loss_segment_all)
        tf.summary.scalar("loss_class", self.loss_class_all)
        for loss_segment_index, loss_segment in enumerate(self.loss_segments):
            tf.summary.scalar("loss_segment_{}".format(loss_segment_index), loss_segment)
        for loss_class_index, loss_class in enumerate(self.loss_classes):
            tf.summary.scalar("loss_class_{}".format(loss_class_index), loss_class)

    with tf.name_scope("accuracy"):
        tf.summary.scalar("accuracy_segment", self.accuracy_segment)
        tf.summary.scalar("accuracy_classes", self.accuracy_classes)

    # Summary 1: images.
    with tf.name_scope("label"):
        split = tf.split(self.image_placeholder, num_or_size_splits=4, axis=3)
        tf.summary.image("0-mask", split[3])
        tf.summary.image("1-image", tf.concat(split[0:3], axis=3))
        tf.summary.image("2-label", tf.cast(self.label_segment_placeholder * 85, dtype=tf.uint8))
        tf.summary.image("3-attention", tf.cast(self.label_attention_placeholder * 255, dtype=tf.uint8))

    with tf.name_scope("predict"):
        tf.summary.image("predict", tf.cast(self.pred_segment * 85, dtype=tf.uint8))

    with tf.name_scope("attention"):
        for attention_index, attention in enumerate(self.attentions):
            tf.summary.image("{}-attention".format(attention_index), attention)

    # Per-channel views of the raw segment logits (the scope is named "sigmoid", but no sigmoid is applied).
    with tf.name_scope("sigmoid"):
        for segment_index, segment in enumerate(self.segments):
            if segment_index < self.attention_module_num:
                split = tf.split(segment, num_or_size_splits=self.num_segment, axis=3)
                tf.summary.image("{}-other".format(segment_index), split[0])
                tf.summary.image("{}-attention".format(segment_index), split[1])
                tf.summary.image("{}-border".format(segment_index), split[2])
                tf.summary.image("{}-background".format(segment_index), split[-1])
            else:
                split = tf.split(segment, num_or_size_splits=2, axis=3)
                tf.summary.image("{}-background".format(segment_index), split[0])
                tf.summary.image("{}-attention".format(segment_index), split[1])

    self.summary_op = tf.summary.merge_all()

    # Session and saver.
    self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

    # Summary 2: writer.
    self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
def run(self, result_filename, image_filename, where=None, annotation_filename=None, ann_index=0):
    """Run a single-output PSPNet on one image and save inputs and predictions as images.

    Writes, under ``self.save_dir`` and prefixed with ``result_filename``:
    ``data.png`` (raw input), ``pred.png`` (sigmoid * 255), ``pred_raw.png``
    (raw output > 0.5), ``pred_sigmoid.png`` (sigmoid > 0.5), ``mask.bmp``
    (gaussian mask) and ``ann.bmp`` (annotation mask).

    Args:
        result_filename: prefix for the output file names.
        image_filename, where, annotation_filename, ann_index: forwarded to ``Data.load_image``.
    """
    # Load image data.
    final_batch_data, data_raw, gaussian_mask, ann_data, ann_mask = Data.load_image(
        image_filename, where=where, annotation_filename=annotation_filename,
        ann_index=ann_index, image_size=self.input_size)

    # Network (4-channel input).
    img_placeholder = tf.placeholder(dtype=tf.float32, shape=(None, self.input_size[0], self.input_size[1], 4))
    # NOTE(review): is_training=True at inference time; confirm whether PSPNet uses this flag
    # to switch batch-norm statistics.
    net = PSPNet({'data': img_placeholder}, is_training=True, num_classes=1,
                 last_pool_size=self.last_pool_size, filter_number=32)

    # Outputs.
    raw_output_op = net.layers["conv6_n"]
    sigmoid_output_op = tf.sigmoid(raw_output_op)

    # Open the session as a context manager so it is closed (and its GPU memory
    # released) when inference is done. The original version never closed it.
    config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        Tools.restore_if_y(sess, self.log_dir)
        raw_output, sigmoid_output = sess.run([raw_output_op, sigmoid_output_op],
                                              feed_dict={img_placeholder: final_batch_data})

    result_path = os.path.join(self.save_dir, result_filename)

    def _save(array, suffix):
        """Save ``array`` as a uint8 image named ``result_filename + suffix`` and log it."""
        Image.fromarray(np.asarray(array, dtype=np.uint8)).save(
            os.path.join(self.save_dir, result_filename + suffix))
        Tools.print_info('over : result save in {}'.format(result_path))

    # Each array is computed right before it is saved, preserving the original
    # order: earlier files are written even if a later array fails.
    _save(np.squeeze(data_raw), "data.png")
    _save(np.squeeze(sigmoid_output[0] * 255), "pred.png")
    _save(np.squeeze(np.greater(raw_output[0], 0.5) * 255), "pred_raw.png")
    _save(np.squeeze(np.greater(sigmoid_output[0], 0.5) * 255), "pred_sigmoid.png")
    _save(np.squeeze(gaussian_mask * 255), "mask.bmp")
    _save(np.squeeze(ann_mask * 255), "ann.bmp")
def train(self, save_pred_freq, begin_step=0):
    """Joint attention-segmentation / classification training loop.

    Runs ``self.train_op`` for steps ``begin_step .. self.num_steps - 1``. Every
    ``self.print_step`` steps the merged summary is also written. Every
    ``save_pred_freq`` steps a checkpoint is saved. Loss and class accuracy are
    averaged over windows of ``self.cal_step`` steps (``pre_avg_*`` = previous
    window, ``avg_*`` = running average of the current window).

    Args:
        save_pred_freq: checkpoint frequency in steps.
        begin_step: first step index (also fed to the learning-rate schedule).
    """
    # Load existing weights (optionally from self.pretrain) via Tools.restore_if_y.
    Tools.restore_if_y(self.sess, self.log_dir, pretrain=self.pretrain)

    # Log roughly 10 times per averaging window. max(1, ...) avoids a
    # ZeroDivisionError in the modulus below when cal_step < 10.
    log_every = max(1, self.cal_step // 10)

    total_loss = 0.0
    pre_avg_loss = 0.0
    total_acc = 0.0
    pre_avg_acc = 0.0
    for step in range(begin_step, self.num_steps):
        start_time = time.time()
        batch_data, batch_mask, batch_attention, batch_class = self.data_reader.next_batch_train()

        # train_op = self.train_attention_op  # alternative: train attention layers only
        train_op = self.train_op
        feed_dict = {self.step_ph: step, self.image_placeholder: batch_data,
                     self.mask_placeholder: batch_mask, self.label_seg_placeholder: batch_attention,
                     self.label_cls_placeholder: batch_class}
        fetches = [self.accuracy_classes, train_op, self.learning_rate, self.loss_segment_all,
                   self.loss_class_all, self.loss, self.pred_classes]
        write_summary = step % self.print_step == 0
        if write_summary:
            fetches.append(self.summary_op)

        results = self.sess.run(fetches, feed_dict=feed_dict)
        (accuracy_classes_r, _, learning_rate_r, loss_segment_r, loss_classes_r,
         loss_r, pred_classes_r) = results[:7]
        if write_summary:
            self.summary_writer.add_summary(results[7], global_step=step)

        if step % save_pred_freq == 0:
            self.saver.save(self.sess, self.checkpoint_path, global_step=step)
            Tools.print_info('The checkpoint has been created.')

        duration = time.time() - start_time

        # A new averaging window starts on every multiple of cal_step.
        if step % self.cal_step == 0:
            pre_avg_loss = total_loss / self.cal_step
            pre_avg_acc = total_acc / self.cal_step
            total_loss = loss_r
            total_acc = accuracy_classes_r
        else:
            total_loss += loss_r
            total_acc += accuracy_classes_r
        total_loss_step = (step % self.cal_step) + 1

        if step % log_every == 0:
            Tools.print_info(
                'step {:d} pre_avg_loss={:.3f} avg_loss={:.3f} pre_avg_acc={:.3f} avg_acc={:.3f} '
                'loss={:.3f} seg={:.3f} class={:.3f} acc_class={:.3f} '
                'lr={:.6f} ({:.3f} s/step) {} {}'.format(
                    step, pre_avg_loss, total_loss / total_loss_step, pre_avg_acc,
                    total_acc / total_loss_step, loss_r, loss_segment_r, loss_classes_r,
                    accuracy_classes_r, learning_rate_r, duration, list(batch_class),
                    list(pred_classes_r)))
        if step % self.print_step == 0:
            Tools.print_info("")
def __init__(self, batch_size, last_pool_size, input_size, log_dir, data_root_path, train_list, data_path,
             annotation_path, class_path, model_name="model.ckpt", is_test=False):
    """Build the segmentation + classification training graph from ``build_net`` and its summaries.

    Args:
        batch_size: images per batch for the ``Data`` reader.
        last_pool_size: size of the last pooling layer (input must exceed 8x this).
        input_size: image size handed to ``Data`` / ``build_net``.
        log_dir: checkpoint/summary directory, passed through ``Tools.new_dir``.
        data_root_path, train_list, data_path, annotation_path, class_path:
            forwarded unchanged to ``Data``.
        model_name: checkpoint file name joined onto ``log_dir``.
        is_test: forwarded to ``Data``; also sets ``print_step`` to 1.
    """
    # Checkpoint location.
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

    # Data settings.
    self.input_size = input_size
    self.batch_size = batch_size
    self.num_classes = 21
    self.has_255 = True  # whether to predict the border class (also passed to Data as has_255)
    self.num_segment = 4 if self.has_255 else 3

    # Model settings: input_size must be larger than 8x last_pool_size.
    self.ratio = 8
    self.last_pool_size = last_pool_size
    self.filter_number = 32

    # Training settings.
    self.learning_rate = 5e-3
    self.num_steps = 500001
    self.print_step = 1 if is_test else 25

    # Data reader.
    self.data_reader = Data(data_root_path=data_root_path, data_list=train_list, data_path=data_path,
                            annotation_path=annotation_path, class_path=class_path,
                            batch_size=self.batch_size, image_size=self.input_size, is_test=is_test,
                            has_255=self.has_255)

    # Network.
    (self.image_placeholder, self.label_segment_placeholder, self.label_classes_placeholder,
     self.raw_output_segment, self.raw_output_classes, self.pred_segment, self.pred_classes,
     self.loss_segment, self.loss_classes, self.loss, self.accuracy_segment, self.accuracy_classes,
     self.step_ph, self.train_op, self.train_classes_op, self.learning_rate) = self.build_net()

    # Summary 1: scalars.
    scalar_summaries = (
        ("loss", self.loss),
        ("loss_segment", self.loss_segment),
        ("loss_classes", self.loss_classes),
        ("accuracy_segment", self.accuracy_segment),
        ("accuracy_classes", self.accuracy_classes),
    )
    for tag, tensor in scalar_summaries:
        tf.summary.scalar(tag, tensor)

    # Summary 1: images. Labels/predictions are scaled to spread class ids over 0..255.
    label_scale = 85 if self.has_255 else 127
    image_channels = tf.split(self.image_placeholder, num_or_size_splits=4, axis=3)
    tf.summary.image("0-mask", image_channels[3])
    tf.summary.image("1-image", tf.concat(image_channels[0: 3], axis=3))
    tf.summary.image("2-label", tf.cast(self.label_segment_placeholder * label_scale, dtype=tf.uint8))

    segment_channels = tf.split(self.raw_output_segment, num_or_size_splits=self.num_segment, axis=3)
    tf.summary.image("3-attention", segment_channels[1])
    tf.summary.image("4-other class", segment_channels[0])
    tf.summary.image("5-background", segment_channels[-1])
    if self.has_255:
        tf.summary.image("5-border", segment_channels[2])
    tf.summary.image("6-pred_segment", tf.cast(self.pred_segment * label_scale, dtype=tf.uint8))

    self.summary_op = tf.summary.merge_all()

    # Session and saver.
    session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    self.sess = tf.Session(config=session_config)
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

    # Summary 2: writer.
    self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
def build_old(self):
    """Legacy variant of ``build`` that also returns intermediate tensors for inspection.

    Stage "4" segments and decodes the deepest encoder block. Stages "3", "2"
    and "1" first add the previous decoder output to the shallower encoder block
    (an op named "add" inside the stage's variable scope), then segment and
    decode that sum. ``self._segment`` returns a pair here; its second element
    is collected in ``features["temp"]``.

    Returns:
        segments: side outputs of stages 4, 3, 2, 1 followed by the final
            2-channel logits ('attention_0').
        features: dict with "block" (encoder blocks), "segment" (decoder outputs),
            "temp", "add" (stage 3..1 sums) and the two addends of each sum,
            "adds_in_block" and "adds_in_2".
    """
    # Shared feature extraction.
    block1, block2, block3, block4 = self._feature(self.input_data)
    blocks = [block1, block2, block3, block4]

    # Example sizes from the original author's comments (depend on the input size).
    block4_shape = Tools.get_shape(block4)  # 45, 512
    block3_shape = Tools.get_shape(block3)  # 90, 512
    block2_shape = Tools.get_shape(block2)  # 180, 256
    block1_shape = Tools.get_shape(block1)  # 360, 128

    adds = []
    adds_in_block = []
    adds_in_2 = []
    temps = []
    segments = []
    segments_output = []

    # Deepest stage: no skip connection yet.
    with tf.variable_scope(name_or_scope="attention_4"):
        net_segment_output, temp = self._segment(block4, block4_shape, block3_shape, name="segment_side_4")
        temps.append(temp)
        segments.append(net_segment_output)
        net_output = self._decoder(block4, block4_shape, block3_shape, name="4")
        segments_output.append(net_output)

    # Remaining stages: fuse previous decoder output into the encoder block, then segment and decode.
    stages = (
        ("3", block3, block3_shape, block2_shape),
        ("2", block2, block2_shape, block1_shape),
        ("1", block1, block1_shape, block1_shape),
    )
    for tag, block, in_shape, out_shape in stages:
        with tf.variable_scope(name_or_scope="attention_" + tag):
            block_add = Net.add([block, net_output], name="add")
            adds_in_block.append(block)
            adds_in_2.append(net_output)
            adds.append(block_add)

            net_segment_output, temp = self._segment(block_add, in_shape, out_shape,
                                                     name="segment_side_" + tag)
            temps.append(temp)
            segments.append(net_segment_output)

            net_output = self._decoder(block_add, in_shape, out_shape, name=tag)
            segments_output.append(net_output)

    # Final 2-channel projection.
    net_output = Net.conv(net_output, 3, 3, 2, 1, 1, biased=True, relu=False, name='attention_0')
    segments.append(net_output)

    features = {
        "block": blocks,
        "segment": segments_output,
        "temp": temps,
        "add": adds,
        "adds_in_block": adds_in_block,
        "adds_in_2": adds_in_2
    }
    return segments, features
def load_model(self, pretrain):
    """Restore weights into ``self.sess`` from ``self.log_dir``.

    Args:
        pretrain: forwarded unchanged as ``pretrain=`` to ``Tools.restore_if_y``.
    """
    restore = Tools.restore_if_y
    restore(self.sess, self.log_dir, pretrain=pretrain)
def __init__(self, batch_size, last_pool_size, input_size, log_dir, data_root_path, train_list, data_path,
             annotation_path, class_path, model_name="model.ckpt", is_test=False):
    """Configure data, build the graph via ``self.build_net()``, and open a session.

    Args:
        batch_size: samples per batch, passed to the ``Data`` reader.
        last_pool_size: stored as ``self.last_pool_size`` (original note: the
            input size must be larger than 8x this value).
        input_size: image size, passed to ``Data`` as ``image_size``.
        log_dir: directory for checkpoints and summaries (created via ``Tools.new_dir``).
        data_root_path, train_list, data_path, annotation_path, class_path:
            forwarded to ``Data`` (``train_list`` as ``data_list``).
        model_name: checkpoint file name joined onto ``log_dir``.
        is_test: forwarded to ``Data``.
    """
    # Checkpoint location.
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

    # Data settings.
    self.input_size = input_size
    self.batch_size = batch_size
    self.num_classes = 21
    self.num_segment = 1

    # Model settings. Original note: input_size must be larger than 8x last_pool_size.
    self.ratio = 8
    self.last_pool_size = last_pool_size
    self.filter_number = 32

    # Training schedule. The float learning rate is replaced below by the
    # learning-rate tensor that build_net returns.
    self.learning_rate = 5e-3
    self.num_steps = 500001

    self.data_reader = Data(data_root_path=data_root_path, data_list=train_list, data_path=data_path,
                            annotation_path=annotation_path, class_path=class_path,
                            batch_size=self.batch_size, image_size=self.input_size, is_test=is_test)

    (self.image_placeholder, self.label_segment_placeholder, self.label_classes_placeholder,
     self.raw_output_segment, self.raw_output_classes, self.pred_segment, self.pred_classes,
     self.loss_segment, self.loss_classes, self.loss,
     self.accuracy_0, self.accuracy_1, self.accuracy_classes,
     self.step_ph, self.train_op, self.train_classes_op, self.learning_rate) = self.build_net()

    # Scalar summaries, registered in a fixed order.
    scalar_summaries = (
        ("loss", self.loss),
        ("loss_segment", self.loss_segment),
        ("loss_classes", self.loss_classes),
        ("accuracy_0", self.accuracy_0),
        ("accuracy_1", self.accuracy_1),
        ("accuracy_classes", self.accuracy_classes),
    )
    for tag, tensor in scalar_summaries:
        tf.summary.scalar(tag, tensor)

    # The 4-channel input is logged as channels 0-2 ("image") and channel 3 ("mask").
    channels = tf.split(self.image_placeholder, num_or_size_splits=4, axis=3)
    tf.summary.image("image", tf.concat(channels[0:3], axis=3))
    tf.summary.image("mask", channels[3])
    # NOTE(review): tf.summary.image accepts uint8/float inputs only; confirm the
    # dtype build_net gives label_segment_placeholder.
    tf.summary.image("label", self.label_segment_placeholder)
    tf.summary.image("raw_output_segment", self.raw_output_segment)
    tf.summary.image("pred_segment", tf.cast(self.pred_segment, tf.float32))
    self.summary_op = tf.summary.merge_all()

    # Session (GPU memory grows on demand) and saver.
    session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    self.sess = tf.Session(config=session_config)
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

    self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
def __init__(self, batch_size, input_size, log_dir, data_root_path, train_list, data_path, annotation_path,
             class_path, model_name="model.ckpt", pretrain=None, is_test=False):
    """Build the BAISNet segmentation + classification training graph and open a session.

    Args:
        batch_size: samples per batch, passed to the ``Data`` reader.
        input_size: (height, width) of the placeholders and of ``Data``'s ``image_size``.
        log_dir: directory for checkpoints and summaries (created via ``Tools.new_dir``).
        data_root_path, train_list, data_path, annotation_path, class_path:
            forwarded to ``Data`` (``train_list`` as ``data_list``).
        model_name: checkpoint file name joined onto ``log_dir``.
        pretrain: stored as ``self.pretrain``.
        is_test: forwarded to ``Data``; also selects smaller ``print_step`` / ``cal_step``.
    """
    # Checkpoint-related settings.
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)
    self.pretrain = pretrain

    # Data-related settings.
    self.input_size = input_size
    self.batch_size = batch_size
    self.num_classes = 21

    # Training settings.
    self.learning_rate = 5e-4
    self.num_steps = 500001
    self.print_step = 10 if is_test else 100
    self.cal_step = 100 if is_test else 1000

    # Data reader.
    self.data_reader = Data(data_root_path=data_root_path, data_list=train_list, data_path=data_path,
                            annotation_path=annotation_path, class_path=class_path,
                            batch_size=self.batch_size, image_size=self.input_size, is_test=is_test)

    # Inputs: 3-channel image, 1-channel mask, per-pixel segmentation label, per-image class label.
    self.image_placeholder = tf.placeholder(tf.float32, shape=(None, self.input_size[0], self.input_size[1], 3))
    self.mask_placeholder = tf.placeholder(tf.float32, shape=(None, self.input_size[0], self.input_size[1], 1))
    self.label_seg_placeholder = tf.placeholder(tf.int32, shape=(None, self.input_size[0], self.input_size[1], 1))
    self.label_cls_placeholder = tf.placeholder(tf.int32, shape=(None,))

    # Network.
    self.net = BAISNet(self.image_placeholder, self.mask_placeholder, True, num_classes=self.num_classes)
    self.segments, self.attentions, self.classes = self.net.build()

    # Loss.
    self.loss, self.loss_segment_all, self.loss_class_all, self.loss_segments, self.loss_classes = self.cal_loss(
        self.segments, self.attentions, self.classes, self.label_seg_placeholder, self.label_cls_placeholder)

    # Classification accuracy on the current batch, from the first classification output.
    self.pred_classes = tf.cast(tf.argmax(self.classes[0], axis=-1), tf.int32)
    self.accuracy_classes = tcm.accuracy(self.pred_classes, self.label_cls_placeholder)

    with tf.name_scope("train"):
        # Decay schedule: lr * (1 - step / num_steps) ** 0.9, with the step fed via step_ph.
        self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
        self.learning_rate = tf.scalar_mul(tf.constant(self.learning_rate),
                                           tf.pow((1 - self.step_ph / self.num_steps), 0.9))
        self.train_op = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss)

        # Op that trains only the attention variables. Any name containing
        # "class_attention" also contains "attention", so one substring test
        # selects exactly the same variables as the former two-clause filter.
        attention_trainable = [v for v in tf.trainable_variables() if 'attention' in v.name]
        print(len(attention_trainable))
        self.train_attention_op = tf.train.GradientDescentOptimizer(
            self.learning_rate).minimize(self.loss, var_list=attention_trainable)

    # Scalar summaries.
    with tf.name_scope("loss"):
        tf.summary.scalar("loss", self.loss)
        tf.summary.scalar("loss_segment", self.loss_segment_all)
        tf.summary.scalar("loss_class", self.loss_class_all)
        for loss_segment_index, loss_segment in enumerate(self.loss_segments):
            tf.summary.scalar("loss_segment_{}".format(loss_segment_index), loss_segment)
        for loss_class_index, loss_class in enumerate(self.loss_classes):
            tf.summary.scalar("loss_class_{}".format(loss_class_index), loss_class)

    with tf.name_scope("accuracy"):
        tf.summary.scalar("accuracy_classes", self.accuracy_classes)

    # Input/label images.
    with tf.name_scope("label"):
        tf.summary.image("1-image", self.image_placeholder)
        # NOTE(review): "* 255" only renders well for 0/1 labels; larger class ids
        # likely wrap when cast to uint8. Confirm the segmentation labels are binary.
        tf.summary.image("2-segment", tf.cast(self.label_seg_placeholder * 255, dtype=tf.uint8))
        tf.summary.image("3-mask", tf.cast(self.mask_placeholder * 255, dtype=tf.uint8))

    # Network outputs: each segment/attention tensor is split into two channels
    # along axis 3, and only channel 1 (after sigmoid) is logged.
    with tf.name_scope("result"):
        for segment_index, segment in enumerate(self.segments):
            segment = tf.split(segment, num_or_size_splits=2, axis=3)
            tf.summary.image("predict-{}-1".format(segment_index), tf.nn.sigmoid(segment[1]))
        for attention_index, attention in enumerate(self.attentions):
            attention = tf.split(attention, num_or_size_splits=2, axis=3)
            tf.summary.image("attention-{}-1".format(attention_index), tf.nn.sigmoid(attention[1]))
    self.summary_op = tf.summary.merge_all()

    # Session (GPU memory grows on demand) and saver.
    self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

    self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
def __init__(self, log_dir, data, model_name="model.ckpt", is_test=False):
    """Configure the trainer from a data reader, build the graph via ``self.build_net()``, and open a session.

    Args:
        log_dir: directory for checkpoints and summaries (created via ``Tools.new_dir``).
        data: data reader. Its ``batch_size``, ``num_classes``, ``image_size``,
            ``ratio``, ``num_segment`` and ``attention_class`` attributes
            configure this trainer.
        model_name: checkpoint file name joined onto ``log_dir``.
        is_test: when true, ``print_step`` is 1 instead of 25.
    """
    # Settings taken from the data reader.
    self.data_reader = data
    self.batch_size = self.data_reader.batch_size
    self.num_classes = self.data_reader.num_classes
    self.input_size = self.data_reader.image_size
    self.ratio = self.data_reader.ratio
    self.num_segment = self.data_reader.num_segment
    self.attention_class = self.data_reader.attention_class

    # Checkpoint-related settings.
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

    # Model settings. Original note: input_size must be larger than 8x last_pool_size.
    self.last_pool_size = self.input_size[0] // self.ratio
    self.filter_number = 32
    # Replaced below by the learning-rate tensor that build_net returns.
    self.learning_rate = 5e-3
    self.num_steps = 1000001
    self.print_step = 1 if is_test else 25

    (self.image_placeholder, self.label_segment_placeholder, self.label_classes_placeholder,
     self.raw_output_segment, self.raw_output_classes, self.pred_segment, self.pred_classes,
     self.loss_segment, self.loss_classes, self.loss, self.accuracy_segment, self.accuracy_classes,
     self.step_ph, self.train_op, self.train_classes_op, self.learning_rate) = self.build_net()

    # Scalar summaries.
    tf.summary.scalar("loss", self.loss)
    tf.summary.scalar("loss_segment", self.loss_segment)
    tf.summary.scalar("loss_classes", self.loss_classes)
    tf.summary.scalar("accuracy_segment", self.accuracy_segment)
    tf.summary.scalar("accuracy_classes", self.accuracy_classes)

    # The 4-channel input is logged as channel 3 ("0-mask") and channels 0-2 ("1-image").
    split = tf.split(self.image_placeholder, num_or_size_splits=4, axis=3)
    tf.summary.image("0-mask", split[3])
    tf.summary.image("1-image", tf.concat(split[0:3], axis=3))

    # Brightness step that spreads segment indices 0..num_segment-1 over 0..255.
    # The divisor is clamped to 1: with num_segment == 1 the former expression
    # 255 // (num_segment - 1) raised ZeroDivisionError.
    segment_scale = 255 // max(self.num_segment - 1, 1)
    tf.summary.image("2-label", tf.cast(self.label_segment_placeholder * segment_scale, dtype=tf.uint8))

    # Per-segment channels of the raw output; attention_class selects the attention channel.
    split = tf.split(self.raw_output_segment, num_or_size_splits=self.num_segment, axis=3)
    tf.summary.image("3-attention", split[self.attention_class])
    for segment_index in range(self.num_segment):
        tf.summary.image("4-segment-output-{}".format(segment_index), split[segment_index])
    tf.summary.image("5-pred_segment", tf.cast(self.pred_segment * segment_scale, dtype=tf.uint8))
    self.summary_op = tf.summary.merge_all()

    # Session (GPU memory grows on demand) and saver.
    self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

    self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
def run(self, result_filename, image_filename, where=None, annotation_filename=None, ann_index=0):
    """Run single-image PSPNet inference and write visualizations to ``self.save_dir``.

    Restores weights from ``self.log_dir`` (via ``Tools.restore_if_y``) and
    prints the predicted class id, its name from ``CategoryNames`` and the raw
    class output. It then writes these files, each prefixed with
    ``result_filename``: ``data.png``, ``pred.png``, ``pred_0.png`` ..
    ``pred_3.png`` and ``mask.bmp``, plus ``ann.bmp`` when
    ``annotation_filename`` is given.

    Args:
        result_filename: prefix of every output file under ``self.save_dir``.
        image_filename: image passed to ``Data.load_image``.
        where: forwarded to ``Data.load_image``.
        annotation_filename: optional annotation forwarded to ``Data.load_image``.
            When truthy, the loader returns two extra values (annotation data
            and mask).
        ann_index: forwarded to ``Data.load_image``.
    """
    num_segment = 4  # number of segmentation channels produced by the network below

    # Load the input batch; the loader returns two extra values when an annotation is given.
    loaded = Data.load_image(image_filename, where=where, annotation_filename=annotation_filename,
                             ann_index=ann_index, image_size=self.input_size)
    if annotation_filename:
        final_batch_data, data_raw, gaussian_mask, ann_data, ann_mask = loaded
    else:
        final_batch_data, data_raw, gaussian_mask = loaded

    # Network: 4-channel input placeholder.
    # NOTE(review): every call builds a new PSPNet in the current default graph;
    # calling run() twice in one graph may hit variable-name conflicts.
    img_placeholder = tf.placeholder(dtype=tf.float32, shape=(None, self.input_size[0], self.input_size[1], 4))
    net = PSPNet({'data': img_placeholder}, is_training=True, num_classes=21,
                 last_pool_size=self.last_pool_size, filter_number=32, num_segment=num_segment)

    # Outputs / predictions.
    raw_output_op = net.layers["conv6_n_4"]
    sigmoid_output_op = tf.sigmoid(raw_output_op)
    predict_output_op = tf.argmax(sigmoid_output_op, axis=-1)
    raw_output_classes = net.layers['class_attention_fc']
    pred_classes = tf.cast(tf.argmax(raw_output_classes, axis=-1), tf.int32)

    # Start a session, restore weights and run. The session is always closed
    # afterwards so its resources are released (previously it leaked).
    sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
    try:
        sess.run(tf.global_variables_initializer())
        Tools.restore_if_y(sess, self.log_dir)
        raw_output, sigmoid_output, predict_output_r, raw_output_classes_r, pred_classes_r = sess.run(
            [raw_output_op, sigmoid_output_op, predict_output_op, raw_output_classes, pred_classes],
            feed_dict={img_placeholder: final_batch_data})
    finally:
        sess.close()

    def out_path(suffix):
        # Output files live in save_dir and share the result_filename prefix.
        return os.path.join(self.save_dir, result_filename + suffix)

    # Report and save.
    print("{} {} {}".format(pred_classes_r[0], CategoryNames[pred_classes_r[0]], raw_output_classes_r))
    print("result in {}".format(os.path.join(self.save_dir, result_filename)))
    Image.fromarray(np.asarray(np.squeeze(data_raw), dtype=np.uint8)).save(out_path("data.png"))

    # Per-channel sigmoid maps scaled to 0..255.
    output_result = np.squeeze(
        np.split(np.asarray(sigmoid_output[0] * 255, dtype=np.uint8), axis=-1, indices_or_sections=num_segment))
    # NOTE(review): dividing by num_segment maps the top index (num_segment - 1)
    # to 191, not 255; other code in this file divides by (num_segment - 1).
    Image.fromarray(
        np.squeeze(np.asarray(predict_output_r[0] * 255 // num_segment, dtype=np.uint8))).save(out_path("pred.png"))
    for channel_index in range(num_segment):
        Image.fromarray(output_result[channel_index]).save(out_path("pred_{}.png".format(channel_index)))

    Image.fromarray(np.asarray(np.squeeze(gaussian_mask * 255), dtype=np.uint8)).save(out_path("mask.bmp"))
    if annotation_filename:
        Image.fromarray(np.asarray(np.squeeze(ann_mask * 255), dtype=np.uint8)).save(out_path("ann.bmp"))
def __init__(self, batch_size, input_size, log_dir, data_root_path, train_list, data_path, annotation_path,
             class_path, model_name="model.ckpt", pretrain=None, is_test=False):
    """Build the BAISNet semantic-segmentation training graph and open a session.

    Args:
        batch_size: samples per batch; also the number of pieces each prediction
            is split into for the image summaries.
        input_size: (height, width) of the placeholders and of ``Data``'s ``image_size``.
        log_dir: directory for checkpoints and summaries (created via ``Tools.new_dir``).
        data_root_path, train_list, data_path, annotation_path, class_path:
            forwarded to ``Data`` (``train_list`` as ``data_list``).
        model_name: checkpoint file name joined onto ``log_dir``.
        pretrain: stored as ``self.pretrain``.
        is_test: forwarded to ``Data``; also selects smaller ``print_step`` / ``cal_step``.
    """
    # Checkpoint-related settings.
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)
    self.pretrain = pretrain

    # Data-related settings.
    self.input_size = input_size
    self.batch_size = batch_size
    self.num_classes = 21

    # Training settings.
    self.learning_rate = 5e-3
    self.num_steps = 100001
    self.print_step = 10 if is_test else 100
    self.cal_step = 100 if is_test else 1000

    # Data reader.
    self.data_reader = Data(data_root_path=data_root_path, data_list=train_list, data_path=data_path,
                            annotation_path=annotation_path, class_path=class_path,
                            batch_size=self.batch_size, image_size=self.input_size, is_test=is_test)

    # Inputs: 3-channel image and per-pixel integer segmentation label.
    self.image_placeholder = tf.placeholder(tf.float32, shape=(None, self.input_size[0], self.input_size[1], 3))
    self.label_seg_placeholder = tf.placeholder(
        tf.int32, shape=(None, self.input_size[0], self.input_size[1], 1))

    # Network.
    self.net = BAISNet(self.image_placeholder, True, num_classes=self.num_classes)
    self.segments, self.features = self.net.build()

    # Loss.
    self.loss, self.loss_segment_all, self.loss_segments = self.cal_loss(
        self.segments, self.label_seg_placeholder)

    with tf.name_scope("train"):
        # Decay schedule: lr * (1 - step / num_steps) ** 0.8, with the step fed via step_ph.
        self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
        self.learning_rate = tf.scalar_mul(
            tf.constant(self.learning_rate), tf.pow((1 - self.step_ph / self.num_steps), 0.8))
        self.train_op = tf.train.GradientDescentOptimizer(
            self.learning_rate).minimize(self.loss)

        # Op that trains only the side segmentation heads ("segment_side" variables).
        segment_side_trainable = [v for v in tf.trainable_variables() if 'segment_side' in v.name]
        print(len(segment_side_trainable))
        self.train_segment_side_op = tf.train.GradientDescentOptimizer(
            self.learning_rate).minimize(self.loss, var_list=segment_side_trainable)

    # Scalar summaries.
    with tf.name_scope("loss"):
        tf.summary.scalar("loss", self.loss)
        tf.summary.scalar("loss_segment", self.loss_segment_all)
        for loss_segment_index, loss_segment in enumerate(self.loss_segments):
            tf.summary.scalar("loss_segment_{}".format(loss_segment_index), loss_segment)

    # Input/label images. Labels are scaled in int32 and cast to uint8 afterwards.
    with tf.name_scope("label"):
        tf.summary.image("1-image", self.image_placeholder)
        tf.summary.image(
            "2-segment", tf.cast(self.label_seg_placeholder * 255 // self.num_classes, dtype=tf.uint8))

    # Predicted class maps for up to the first 3 images of the batch.
    with tf.name_scope("result"):
        num_logged = min(3, self.batch_size)
        for segment_index, segment in enumerate(self.segments):
            # Scale class ids to 0..255 in argmax's int64 output and only THEN
            # cast to uint8. The former code cast first and multiplied the uint8
            # tensor by 255, which wraps modulo 256 for any class id >= 2
            # (e.g. classes 1 and 2 both rendered as 12).
            prediction = tf.argmax(segment, axis=-1)
            prediction = tf.cast(prediction * 255 // self.num_classes, dtype=tf.uint8)
            per_image = tf.split(prediction, num_or_size_splits=self.batch_size, axis=0)
            for i in range(num_logged):
                tf.summary.image("predict-{}-{}".format(segment_index, i),
                                 tf.expand_dims(per_image[i], axis=-1))
    self.summary_op = tf.summary.merge_all()

    # Session (GPU memory grows on demand) and saver.
    self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

    self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
def load_model(self):
    """Load the model into ``self.sess`` from ``self.log_dir`` via ``Tools.restore_if_y``."""
    session, checkpoint_dir = self.sess, self.log_dir
    Tools.restore_if_y(session, checkpoint_dir)