def __init__(self, batch_size, last_pool_size, input_size, log_dir, data_root_path, train_list, data_path,
             annotation_path, class_path, model_name="model.ckpt", is_test=False):
    """Set up the data reader, build the network via self.build_net() and open a session.

    Args:
        batch_size: samples per batch, forwarded to Data.
        last_pool_size: size of the last pooling stage; per the author's note,
            input_size must be greater than 8 * last_pool_size.
        input_size: (height, width) of the network input.
        log_dir: checkpoint directory, created through Tools.new_dir.
        data_root_path, train_list, data_path, annotation_path, class_path:
            forwarded unchanged to Data.
        model_name: checkpoint file name inside log_dir.
        is_test: forwarded to Data.
    """
    # Where checkpoints are written.
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

    # Data-related settings.
    self.input_size = input_size
    self.batch_size = batch_size
    self.num_classes = 21
    self.num_segment = 1

    # Model settings: input_size must be greater than 8 * last_pool_size.
    self.ratio = 8
    self.last_pool_size = last_pool_size
    self.filter_number = 32

    # Data reader.
    self.data_reader = Data(data_root_path=data_root_path, data_list=train_list, data_path=data_path,
                            annotation_path=annotation_path, class_path=class_path,
                            batch_size=self.batch_size, image_size=self.input_size, is_test=is_test)

    # Network graph: build_net must return exactly these five handles.
    graph_handles = self.build_net()
    (self.image_placeholder, self.raw_output_segment, self.raw_output_classes,
     self.pred_segment, self.pred_classes) = graph_handles

    # Session (GPU memory allocated on demand), variable initialisation and saver.
    session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    self.sess = tf.Session(config=session_config)
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
def inference(self, image_path, image_index):
    """Feed one image through the network and print its last feature map.

    The merged summaries produced for this image are written to the summary
    writer under global step `image_index`.
    """
    single_image = Data.load_data(image_path=image_path, input_size=self.input_size)
    batch = np.expand_dims(single_image, axis=0)  # add a leading batch axis of size 1
    fetches = [self.features[-1], self.summary_op]
    last_feature, summary = self.sess.run(fetches, feed_dict={self.image_placeholder: batch})
    self.summary_writer.add_summary(summary, global_step=image_index)
    print(last_feature)
def inference(self, image_path, image_index, save_path=None):
    """Predict the segmentation mask of one image, then display or save it.

    Args:
        image_path: image file, loaded through Data.load_data.
        image_index: global step under which this image's summaries are written.
        save_path: when None the mask is shown with PIL's ``show()``; otherwise it is
            the directory (created via Tools.new_dir) where ``<image stem>.bmp`` is saved.
    """
    batch = np.expand_dims(Data.load_data(image_path=image_path, input_size=self.input_size), axis=0)
    prediction, summary = self.sess.run([self.pred_segment, self.summary_op],
                                        feed_dict={self.image_placeholder: batch})
    self.summary_writer.add_summary(summary, global_step=image_index)

    # Scale the predicted label map by 255 and cast to uint8.
    # NOTE(review): labels > 1 wrap around in uint8; this is only correct if
    # pred_segment is binary for this model — confirm.
    mask_pixels = np.asarray(np.squeeze(prediction) * 255, dtype=np.uint8)
    mask_image = Image.fromarray(mask_pixels)

    if save_path is not None:
        Tools.new_dir(save_path)
        stem = os.path.splitext(os.path.basename(image_path))[0]
        mask_image.convert("L").save("{}/{}.bmp".format(save_path, stem))
    else:
        mask_image.show()
def run(self, result_filename, image_filename, where=None, annotation_filename=None, ann_index=0):
    """Predict one image with a freshly built PSPNet and save diagnostic images.

    Files written to ``self.save_dir`` (each prefixed with ``result_filename``):
    ``data.png`` (raw input), ``pred.png`` (sigmoid * 255), ``pred_raw.png`` (logits > 0.5),
    ``pred_sigmoid.png`` (sigmoid > 0.5), ``mask.bmp`` (gaussian mask * 255) and
    ``ann.bmp`` (annotation mask * 255).

    Args:
        result_filename: prefix for every saved file.
        image_filename, where, annotation_filename, ann_index: forwarded to Data.load_image.
    """
    # Load image data.
    final_batch_data, data_raw, gaussian_mask, ann_data, ann_mask = Data.load_image(
        image_filename, where=where, annotation_filename=annotation_filename, ann_index=ann_index,
        image_size=self.input_size)

    # Network with a 4-channel input.
    img_placeholder = tf.placeholder(dtype=tf.float32, shape=(None, self.input_size[0], self.input_size[1], 4))
    # NOTE(review): is_training=True at inference time — if PSPNet uses it for batch norm,
    # batch statistics are used instead of moving averages; confirm this is intended.
    net = PSPNet({'data': img_placeholder}, is_training=True, num_classes=1,
                 last_pool_size=self.last_pool_size, filter_number=32)

    # Outputs / predictions.
    raw_output_op = net.layers["conv6_n"]
    sigmoid_output_op = tf.sigmoid(raw_output_op)

    # Start a session, restore the model and run it. The with-block closes the
    # session afterwards instead of leaking one session per call.
    with tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))) as sess:
        sess.run(tf.global_variables_initializer())
        Tools.restore_if_y(sess, self.log_dir)
        raw_output, sigmoid_output = sess.run([raw_output_op, sigmoid_output_op],
                                              feed_dict={img_placeholder: final_batch_data})

    def save_image(suffix, pixels):
        # Squeeze singleton axes, cast to uint8, save, and log the path actually written.
        out_path = os.path.join(self.save_dir, result_filename + suffix)
        Image.fromarray(np.asarray(np.squeeze(pixels), dtype=np.uint8)).save(out_path)
        Tools.print_info('over : result save in {}'.format(out_path))

    # Save, in the original order.
    save_image("data.png", data_raw)
    save_image("pred.png", sigmoid_output[0] * 255)
    # Note: this thresholds the raw logits (not probabilities) at 0.5.
    save_image("pred_raw.png", np.greater(raw_output[0], 0.5) * 255)
    save_image("pred_sigmoid.png", np.greater(sigmoid_output[0], 0.5) * 255)
    save_image("mask.bmp", gaussian_mask * 255)
    save_image("ann.bmp", ann_mask * 255)
def __init__(self, batch_size, last_pool_size, input_size, log_dir, data_root_path, train_list, data_path,
             annotation_path, class_path, model_name="model.ckpt", is_test=False):
    """Build the BAISNet training graph (losses, optimizers, summaries) and open a session.

    Args:
        batch_size: samples per batch, forwarded to Data.
        last_pool_size: last pooling size; input_size must be greater than 8 * last_pool_size.
        input_size: (height, width) of the 4-channel network input.
        log_dir: checkpoint / summary directory (created via Tools.new_dir).
        data_root_path, train_list, data_path, annotation_path, class_path: forwarded to Data.
        model_name: checkpoint file name inside log_dir.
        is_test: forwarded to Data; also sets print_step to 5 instead of 25.
    """
    # Checkpoint-related parameters.
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

    # Data-related parameters.
    self.input_size = input_size
    self.batch_size = batch_size
    self.num_classes = 21
    self.num_segment = 4  # decoder channels: other objects, attention, border, background
    self.segment_attention = 1  # index of the attention channel when the decoder has 4 channels
    self.attention_module_num = 2  # number of attention modules whose decoder has 2 channels (background, attention)

    # Model parameters: input_size must be greater than 8 * last_pool_size.
    self.ratio = 8
    self.last_pool_size = last_pool_size
    self.filter_number = 32

    # Training parameters.
    self.learning_rate = 5e-3
    self.num_steps = 500001
    self.print_step = 5 if is_test else 25

    # Data reader.
    self.data_reader = Data(data_root_path=data_root_path, data_list=train_list, data_path=data_path,
                            annotation_path=annotation_path, class_path=class_path,
                            batch_size=self.batch_size, image_size=self.input_size, is_test=is_test,
                            has_255=True)

    # Placeholders; labels are at 1/ratio of the input resolution.
    label_shape = (None, self.input_size[0] // self.ratio, self.input_size[1] // self.ratio, 1)
    self.image_placeholder = tf.placeholder(dtype=tf.float32,
                                            shape=(None, self.input_size[0], self.input_size[1], 4))
    self.label_segment_placeholder = tf.placeholder(dtype=tf.int32, shape=label_shape)
    self.label_attention_placeholder = tf.placeholder(dtype=tf.int32, shape=label_shape)
    self.label_classes_placeholder = tf.placeholder(dtype=tf.int32, shape=(None,))

    # Network.
    self.net = BAISNet(self.image_placeholder, is_training=True, num_classes=self.num_classes,
                       num_segment=self.num_segment, segment_attention=self.segment_attention,
                       last_pool_size=self.last_pool_size, filter_number=self.filter_number,
                       attention_module_num=self.attention_module_num)
    self.segments, self.attentions, self.classes = self.net.build()
    self.final_segment_logit = self.segments[0]
    self.final_class_logit = self.classes[0]

    # Predictions.
    self.pred_segment = tf.cast(tf.expand_dims(tf.argmax(self.final_segment_logit, axis=-1), axis=-1), tf.int32)
    self.pred_classes = tf.cast(tf.argmax(self.final_class_logit, axis=-1), tf.int32)

    # Loss: labels are resized (nearest neighbour) to the logit resolution.
    self.label_batch = tf.image.resize_nearest_neighbor(
        self.label_segment_placeholder, tf.stack(self.final_segment_logit.get_shape()[1:3]))
    self.label_attention_batch = tf.image.resize_nearest_neighbor(
        self.label_attention_placeholder, tf.stack(self.final_segment_logit.get_shape()[1:3]))
    (self.loss, self.loss_segment_all, self.loss_class_all,
     self.loss_segments, self.loss_classes) = self.cal_loss(
        self.segments, self.classes, self.label_batch, self.label_attention_batch,
        self.label_classes_placeholder, self.num_segment, attention_module_num=self.attention_module_num)

    # Accuracy on the current batch.
    self.accuracy_segment = tcm.accuracy(self.pred_segment, self.label_segment_placeholder)
    self.accuracy_classes = tcm.accuracy(self.pred_classes, self.label_classes_placeholder)

    with tf.name_scope("train"):
        # Poly learning-rate schedule: lr * (1 - step / num_steps) ** 0.9.
        self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
        self.learning_rate = tf.scalar_mul(tf.constant(self.learning_rate),
                                           tf.pow((1 - self.step_ph / self.num_steps), 0.9))
        self.train_op = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss)

        # Optimizer that updates only the attention variables. The substring
        # 'attention' also covers names containing 'class_attention'.
        attention_trainable = [v for v in tf.trainable_variables() if 'attention' in v.name]
        print(len(attention_trainable))
        self.train_attention_op = tf.train.GradientDescentOptimizer(
            self.learning_rate).minimize(self.loss, var_list=attention_trainable)

    # Summaries.
    with tf.name_scope("loss"):
        tf.summary.scalar("loss", self.loss)
        tf.summary.scalar("loss_segment", self.loss_segment_all)
        tf.summary.scalar("loss_class", self.loss_class_all)
        for loss_segment_index, loss_segment in enumerate(self.loss_segments):
            tf.summary.scalar("loss_segment_{}".format(loss_segment_index), loss_segment)
        for loss_class_index, loss_class in enumerate(self.loss_classes):
            tf.summary.scalar("loss_class_{}".format(loss_class_index), loss_class)

    with tf.name_scope("accuracy"):
        tf.summary.scalar("accuracy_segment", self.accuracy_segment)
        tf.summary.scalar("accuracy_classes", self.accuracy_classes)

    with tf.name_scope("label"):
        split = tf.split(self.image_placeholder, num_or_size_splits=4, axis=3)
        tf.summary.image("0-mask", split[3])
        tf.summary.image("1-image", tf.concat(split[0:3], axis=3))
        tf.summary.image("2-label", tf.cast(self.label_segment_placeholder * 85, dtype=tf.uint8))
        tf.summary.image("3-attention", tf.cast(self.label_attention_placeholder * 255, dtype=tf.uint8))

    with tf.name_scope("predict"):
        tf.summary.image("predict", tf.cast(self.pred_segment * 85, dtype=tf.uint8))

    with tf.name_scope("attention"):
        for attention_index, attention in enumerate(self.attentions):
            tf.summary.image("{}-attention".format(attention_index), attention)

    with tf.name_scope("sigmoid"):
        # The LAST attention_module_num segment outputs are the 2-channel
        # (background, attention) modules; the others have num_segment channels.
        # The previous test `segment_index < attention_module_num` only agreed with
        # this when len(segments) == 2 * attention_module_num.
        # NOTE(review): assumed to match the split used by cal_loss — confirm.
        num_full_segments = len(self.segments) - self.attention_module_num
        for segment_index, segment in enumerate(self.segments):
            if segment_index < num_full_segments:
                split = tf.split(segment, num_or_size_splits=self.num_segment, axis=3)
                tf.summary.image("{}-other".format(segment_index), split[0])
                tf.summary.image("{}-attention".format(segment_index), split[1])
                tf.summary.image("{}-border".format(segment_index), split[2])
                tf.summary.image("{}-background".format(segment_index), split[-1])
            else:
                split = tf.split(segment, num_or_size_splits=2, axis=3)
                tf.summary.image("{}-background".format(segment_index), split[0])
                tf.summary.image("{}-attention".format(segment_index), split[1])
    self.summary_op = tf.summary.merge_all()

    # Session and saver.
    self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
    # Summary writer.
    self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
class Train(object):
    """Trainer for BAISNet with per-module segmentation / attention losses and a classification head."""

    def __init__(self, batch_size, last_pool_size, input_size, log_dir, data_root_path, train_list, data_path,
                 annotation_path, class_path, model_name="model.ckpt", is_test=False):
        """Build the training graph (losses, optimizers, summaries) and open a session.

        Args:
            batch_size: samples per batch, forwarded to Data.
            last_pool_size: last pooling size; input_size must be greater than 8 * last_pool_size.
            input_size: (height, width) of the 4-channel network input.
            log_dir: checkpoint / summary directory (created via Tools.new_dir).
            data_root_path, train_list, data_path, annotation_path, class_path: forwarded to Data.
            model_name: checkpoint file name inside log_dir.
            is_test: forwarded to Data; also makes summaries more frequent (every 5 steps, not 25).
        """
        # Checkpoint-related parameters.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Data-related parameters.
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21
        self.num_segment = 4  # decoder channels: other objects, attention, border, background
        self.segment_attention = 1  # index of the attention channel when the decoder has 4 channels
        self.attention_module_num = 2  # number of attention modules whose decoder has 2 channels (background, attention)

        # Model parameters: input_size must be greater than 8 * last_pool_size.
        self.ratio = 8
        self.last_pool_size = last_pool_size
        self.filter_number = 32

        # Training parameters.
        self.learning_rate = 5e-3
        self.num_steps = 500001
        self.print_step = 5 if is_test else 25

        # Data reader.
        self.data_reader = Data(data_root_path=data_root_path, data_list=train_list, data_path=data_path,
                                annotation_path=annotation_path, class_path=class_path,
                                batch_size=self.batch_size, image_size=self.input_size, is_test=is_test,
                                has_255=True)

        # Placeholders; labels are at 1/ratio of the input resolution.
        label_shape = (None, self.input_size[0] // self.ratio, self.input_size[1] // self.ratio, 1)
        self.image_placeholder = tf.placeholder(dtype=tf.float32,
                                                shape=(None, self.input_size[0], self.input_size[1], 4))
        self.label_segment_placeholder = tf.placeholder(dtype=tf.int32, shape=label_shape)
        self.label_attention_placeholder = tf.placeholder(dtype=tf.int32, shape=label_shape)
        self.label_classes_placeholder = tf.placeholder(dtype=tf.int32, shape=(None,))

        # Network.
        self.net = BAISNet(self.image_placeholder, is_training=True, num_classes=self.num_classes,
                           num_segment=self.num_segment, segment_attention=self.segment_attention,
                           last_pool_size=self.last_pool_size, filter_number=self.filter_number,
                           attention_module_num=self.attention_module_num)
        self.segments, self.attentions, self.classes = self.net.build()
        self.final_segment_logit = self.segments[0]
        self.final_class_logit = self.classes[0]

        # Predictions.
        self.pred_segment = tf.cast(
            tf.expand_dims(tf.argmax(self.final_segment_logit, axis=-1), axis=-1), tf.int32)
        self.pred_classes = tf.cast(tf.argmax(self.final_class_logit, axis=-1), tf.int32)

        # Loss: labels are resized (nearest neighbour) to the logit resolution.
        self.label_batch = tf.image.resize_nearest_neighbor(
            self.label_segment_placeholder, tf.stack(self.final_segment_logit.get_shape()[1:3]))
        self.label_attention_batch = tf.image.resize_nearest_neighbor(
            self.label_attention_placeholder, tf.stack(self.final_segment_logit.get_shape()[1:3]))
        (self.loss, self.loss_segment_all, self.loss_class_all,
         self.loss_segments, self.loss_classes) = self.cal_loss(
            self.segments, self.classes, self.label_batch, self.label_attention_batch,
            self.label_classes_placeholder, self.num_segment, attention_module_num=self.attention_module_num)

        # Accuracy on the current batch.
        self.accuracy_segment = tcm.accuracy(self.pred_segment, self.label_segment_placeholder)
        self.accuracy_classes = tcm.accuracy(self.pred_classes, self.label_classes_placeholder)

        with tf.name_scope("train"):
            # Poly learning-rate schedule: lr * (1 - step / num_steps) ** 0.9.
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            self.learning_rate = tf.scalar_mul(tf.constant(self.learning_rate),
                                               tf.pow((1 - self.step_ph / self.num_steps), 0.9))
            self.train_op = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss)

            # Optimizer that updates only the attention variables. The substring
            # 'attention' also covers names containing 'class_attention'.
            attention_trainable = [v for v in tf.trainable_variables() if 'attention' in v.name]
            print(len(attention_trainable))
            self.train_attention_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss, var_list=attention_trainable)

        # Summaries.
        with tf.name_scope("loss"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.scalar("loss_segment", self.loss_segment_all)
            tf.summary.scalar("loss_class", self.loss_class_all)
            for loss_segment_index, loss_segment in enumerate(self.loss_segments):
                tf.summary.scalar("loss_segment_{}".format(loss_segment_index), loss_segment)
            for loss_class_index, loss_class in enumerate(self.loss_classes):
                tf.summary.scalar("loss_class_{}".format(loss_class_index), loss_class)

        with tf.name_scope("accuracy"):
            tf.summary.scalar("accuracy_segment", self.accuracy_segment)
            tf.summary.scalar("accuracy_classes", self.accuracy_classes)

        with tf.name_scope("label"):
            split = tf.split(self.image_placeholder, num_or_size_splits=4, axis=3)
            tf.summary.image("0-mask", split[3])
            tf.summary.image("1-image", tf.concat(split[0:3], axis=3))
            tf.summary.image("2-label", tf.cast(self.label_segment_placeholder * 85, dtype=tf.uint8))
            tf.summary.image("3-attention", tf.cast(self.label_attention_placeholder * 255, dtype=tf.uint8))

        with tf.name_scope("predict"):
            tf.summary.image("predict", tf.cast(self.pred_segment * 85, dtype=tf.uint8))

        with tf.name_scope("attention"):
            for attention_index, attention in enumerate(self.attentions):
                tf.summary.image("{}-attention".format(attention_index), attention)

        with tf.name_scope("sigmoid"):
            # Same split rule as cal_loss: the LAST attention_module_num segment outputs
            # are 2-channel (background, attention); the others have num_segment channels.
            # The previous test `segment_index < attention_module_num` only agreed with
            # cal_loss when len(segments) == 2 * attention_module_num.
            num_full_segments = len(self.segments) - self.attention_module_num
            for segment_index, segment in enumerate(self.segments):
                if segment_index < num_full_segments:
                    split = tf.split(segment, num_or_size_splits=self.num_segment, axis=3)
                    tf.summary.image("{}-other".format(segment_index), split[0])
                    tf.summary.image("{}-attention".format(segment_index), split[1])
                    tf.summary.image("{}-border".format(segment_index), split[2])
                    tf.summary.image("{}-background".format(segment_index), split[-1])
                else:
                    split = tf.split(segment, num_or_size_splits=2, axis=3)
                    tf.summary.image("{}-background".format(segment_index), split[0])
                    tf.summary.image("{}-attention".format(segment_index), split[1])
        self.summary_op = tf.summary.merge_all()

        # Session and saver.
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
        # Summary writer.
        self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)

    @staticmethod
    def cal_loss(segments, classes, label_segment, label_attention, label_classes, num_segment,
                 attention_module_num):
        """Compute the combined segmentation + classification loss.

        The first ``len(segments) - attention_module_num`` segment outputs use softmax
        cross-entropy over ``num_segment`` channels against ``label_segment``. The
        remaining outputs keep only channel 1 (attention) and use weighted sigmoid
        cross-entropy (pos_weight=3, result scaled by 2) against ``label_attention``.
        Every classification output uses softmax cross-entropy against ``label_classes``.

        Returns:
            (loss, loss_segment_all, loss_class_all, loss_segments, loss_classes), where
            loss = mean segment loss + 0.1 * mean class loss.
        """
        label_segment = tf.reshape(label_segment, [-1, ])
        label_attention = tf.reshape(label_attention, [-1, ])

        loss_segments = []
        for segment_index, segment in enumerate(segments):
            if segment_index < len(segments) - attention_module_num:
                now_loss_segment = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=label_segment, logits=tf.reshape(segment, [-1, num_segment])))
            else:
                # Keep only the attention channel (index 1) of the 2-channel output.
                segment = tf.split(segment, num_or_size_splits=2, axis=3)[1]
                now_loss_segment = tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(
                    targets=tf.cast(label_attention, dtype=tf.float32),
                    logits=tf.reshape(segment, [-1, ]), pos_weight=3)) * 2
            loss_segments.append(now_loss_segment)

        loss_classes = [
            tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label_classes, logits=class_one))
            for class_one in classes
        ]

        loss_segment_all = tf.add_n(loss_segments) / len(loss_segments)
        loss_class_all = tf.add_n(loss_classes) / len(loss_classes)
        # Total loss.
        loss = loss_segment_all + 0.1 * loss_class_all
        return loss, loss_segment_all, loss_class_all, loss_segments, loss_classes

    def train(self, save_pred_freq, begin_step=0):
        """Run training steps ``begin_step .. num_steps - 1``.

        Each step prints a progress line. Every ``print_step`` steps the merged
        summaries are also written. Every ``save_pred_freq`` steps a checkpoint
        is saved.
        """
        # Possibly restore weights from log_dir (decided by Tools.restore_if_y).
        Tools.restore_if_y(self.sess, self.log_dir)

        for step in range(begin_step, self.num_steps):
            start_time = time.time()

            (final_batch_data, final_batch_ann, final_batch_ann_attention, final_batch_class,
             batch_data, batch_mask) = self.data_reader.next_batch_train()

            # Use self.train_attention_op instead to train only the attention variables.
            train_op = self.train_op

            # Fetch only what the progress line needs (the large logit tensors were
            # previously copied back every step and never used). Add the summary op
            # on print steps.
            fetches = [self.accuracy_segment, self.accuracy_classes, train_op, self.learning_rate,
                       self.loss_segment_all, self.loss_class_all, self.loss, self.pred_classes]
            write_summary = step % self.print_step == 0
            if write_summary:
                fetches.append(self.summary_op)

            results = self.sess.run(fetches, feed_dict={
                self.step_ph: step,
                self.image_placeholder: final_batch_data,
                self.label_segment_placeholder: final_batch_ann,
                self.label_attention_placeholder: final_batch_ann_attention,
                self.label_classes_placeholder: final_batch_class,
            })
            (accuracy_segment_r, accuracy_classes_r, _, learning_rate_r,
             loss_segment_r, loss_classes_r, loss_r, pred_classes_r) = results[:8]
            if write_summary:
                self.summary_writer.add_summary(results[8], global_step=step)

            if step % save_pred_freq == 0:
                self.saver.save(self.sess, self.checkpoint_path, global_step=step)
                Tools.print_info('The checkpoint has been created.')

            duration = time.time() - start_time
            Tools.print_info(
                'step {:d} loss={:.3f} seg={:.3f} class={:.3f} acc={:.3f} acc_class={:.3f}'
                ' lr={:.6f} ({:.3f} s/step) {} {}'.format(
                    step, loss_r, loss_segment_r, loss_classes_r, accuracy_segment_r, accuracy_classes_r,
                    learning_rate_r, duration, list(final_batch_class), list(pred_classes_r)))
def __init__(self, batch_size, last_pool_size, input_size, log_dir, data_root_path, train_list, data_path,
             annotation_path, class_path, model_name="model.ckpt", is_test=False):
    """Build the PSPNet training graph through self.build_net(), add summaries and open a session.

    Args:
        batch_size: samples per batch, forwarded to Data.
        last_pool_size: last pooling size; input_size must be greater than 8 * last_pool_size.
        input_size: (height, width) of the 4-channel network input.
        log_dir: checkpoint / summary directory (created via Tools.new_dir).
        data_root_path, train_list, data_path, annotation_path, class_path: forwarded to Data.
        model_name: checkpoint file name inside log_dir.
        is_test: forwarded to Data; also sets print_step to 1 instead of 25.
    """
    # Checkpoint-related parameters.
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

    # Data-related parameters.
    self.input_size = input_size
    self.batch_size = batch_size
    self.num_classes = 21
    self.has_255 = True  # whether borders are predicted as an extra segment class
    self.num_segment = 4 if self.has_255 else 3

    # Model parameters: input_size must be greater than 8 * last_pool_size.
    self.ratio = 8
    self.last_pool_size = last_pool_size
    self.filter_number = 32

    # Training parameters.
    self.learning_rate = 5e-3
    self.num_steps = 500001
    self.print_step = 1 if is_test else 25

    # Data reader.
    self.data_reader = Data(data_root_path=data_root_path, data_list=train_list, data_path=data_path,
                            annotation_path=annotation_path, class_path=class_path,
                            batch_size=self.batch_size, image_size=self.input_size, is_test=is_test,
                            has_255=self.has_255)

    # Network: build_net returns these 16 handles in this order.
    graph_handles = self.build_net()
    (self.image_placeholder, self.label_segment_placeholder, self.label_classes_placeholder,
     self.raw_output_segment, self.raw_output_classes, self.pred_segment, self.pred_classes,
     self.loss_segment, self.loss_classes, self.loss, self.accuracy_segment, self.accuracy_classes,
     self.step_ph, self.train_op, self.train_classes_op, self.learning_rate) = graph_handles

    # Scalar summaries.
    scalar_summaries = (
        ("loss", self.loss),
        ("loss_segment", self.loss_segment),
        ("loss_classes", self.loss_classes),
        ("accuracy_segment", self.accuracy_segment),
        ("accuracy_classes", self.accuracy_classes),
    )
    for tag, tensor in scalar_summaries:
        tf.summary.scalar(tag, tensor)

    # Gray-level step that spreads the label ids over 0..255.
    gray_step = 85 if self.has_255 else 127

    # Image summaries: input mask / RGB, label, per-channel logits and prediction.
    image_channels = tf.split(self.image_placeholder, num_or_size_splits=4, axis=3)
    tf.summary.image("0-mask", image_channels[3])
    tf.summary.image("1-image", tf.concat(image_channels[0: 3], axis=3))
    tf.summary.image("2-label", tf.cast(self.label_segment_placeholder * gray_step, dtype=tf.uint8))
    segment_channels = tf.split(self.raw_output_segment, num_or_size_splits=self.num_segment, axis=3)
    tf.summary.image("3-attention", segment_channels[1])
    tf.summary.image("4-other class", segment_channels[0])
    tf.summary.image("5-background", segment_channels[-1])
    if self.has_255:
        tf.summary.image("5-border", segment_channels[2])
    tf.summary.image("6-pred_segment", tf.cast(self.pred_segment * gray_step, dtype=tf.uint8))
    self.summary_op = tf.summary.merge_all()

    # Session (GPU memory allocated on demand) and saver.
    session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    self.sess = tf.Session(config=session_config)
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
    # Summary writer.
    self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
class Train(object):
    """Trainer for PSPNet with a multi-class segmentation head and a classification head."""

    def __init__(self, batch_size, last_pool_size, input_size, log_dir, data_root_path, train_list, data_path,
                 annotation_path, class_path, model_name="model.ckpt", is_test=False):
        """Build the graph (see build_net), add summaries and open a session.

        Args:
            batch_size: samples per batch, forwarded to Data.
            last_pool_size: last pooling size; input_size must be greater than 8 * last_pool_size.
            input_size: (height, width) of the 4-channel network input.
            log_dir: checkpoint / summary directory (created via Tools.new_dir).
            data_root_path, train_list, data_path, annotation_path, class_path: forwarded to Data.
            model_name: checkpoint file name inside log_dir.
            is_test: forwarded to Data; also writes summaries every step instead of every 25.
        """
        # Checkpoint-related parameters.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Data-related parameters.
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21
        self.has_255 = True  # whether borders are predicted as an extra segment class
        self.num_segment = 4 if self.has_255 else 3

        # Model parameters: input_size must be greater than 8 * last_pool_size.
        self.ratio = 8
        self.last_pool_size = last_pool_size
        self.filter_number = 32

        # Training parameters.
        self.learning_rate = 5e-3
        self.num_steps = 500001
        self.print_step = 1 if is_test else 25

        # Data reader.
        self.data_reader = Data(data_root_path=data_root_path, data_list=train_list, data_path=data_path,
                                annotation_path=annotation_path, class_path=class_path,
                                batch_size=self.batch_size, image_size=self.input_size, is_test=is_test,
                                has_255=self.has_255)

        # Network.
        (self.image_placeholder, self.label_segment_placeholder, self.label_classes_placeholder,
         self.raw_output_segment, self.raw_output_classes, self.pred_segment, self.pred_classes,
         self.loss_segment, self.loss_classes, self.loss, self.accuracy_segment, self.accuracy_classes,
         self.step_ph, self.train_op, self.train_classes_op, self.learning_rate) = self.build_net()

        # Scalar summaries.
        tf.summary.scalar("loss", self.loss)
        tf.summary.scalar("loss_segment", self.loss_segment)
        tf.summary.scalar("loss_classes", self.loss_classes)
        tf.summary.scalar("accuracy_segment", self.accuracy_segment)
        tf.summary.scalar("accuracy_classes", self.accuracy_classes)

        # Image summaries. Label ids are multiplied by 85 (4 classes) or 127 (3 classes)
        # to spread them over 0..255.
        split = tf.split(self.image_placeholder, num_or_size_splits=4, axis=3)
        tf.summary.image("0-mask", split[3])
        tf.summary.image("1-image", tf.concat(split[0: 3], axis=3))
        tf.summary.image("2-label", tf.cast(self.label_segment_placeholder * (85 if self.has_255 else 127),
                                            dtype=tf.uint8))
        split = tf.split(self.raw_output_segment, num_or_size_splits=self.num_segment, axis=3)
        tf.summary.image("3-attention", split[1])
        tf.summary.image("4-other class", split[0])
        tf.summary.image("5-background", split[-1])
        if self.has_255:
            tf.summary.image("5-border", split[2])
        tf.summary.image("6-pred_segment", tf.cast(self.pred_segment * (85 if self.has_255 else 127),
                                                   dtype=tf.uint8))
        self.summary_op = tf.summary.merge_all()

        # Session and saver.
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
        # Summary writer.
        self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)

    def build_net(self):
        """Create placeholders, PSPNet, predictions, losses and optimizers.

        Returns:
            A 16-tuple: (image_placeholder, label_segment_placeholder, label_classes_placeholder,
            raw_output_segment, raw_output_classes, pred_segment, pred_classes, loss_segment,
            loss_classes, loss, accuracy_segment, accuracy_classes, step_ph, train_op,
            train_classes_op, learning_rate).
        """
        # Placeholders; labels are at 1/ratio of the input resolution.
        image_placeholder = tf.placeholder(dtype=tf.float32,
                                           shape=(None, self.input_size[0], self.input_size[1], 4))
        label_segment_placeholder = tf.placeholder(
            dtype=tf.int32,
            shape=(None, self.input_size[0] // self.ratio, self.input_size[1] // self.ratio, 1))
        label_classes_placeholder = tf.placeholder(dtype=tf.int32, shape=(None,))

        # Network.
        net = PSPNet({'data': image_placeholder}, is_training=True, num_classes=self.num_classes,
                     num_segment=self.num_segment, last_pool_size=self.last_pool_size,
                     filter_number=self.filter_number)
        raw_output_segment = net.layers['conv6_n_4']
        raw_output_classes = net.layers['class_attention_fc']

        # Predictions.
        prediction = tf.reshape(raw_output_segment, [-1, self.num_segment])
        pred_segment = tf.cast(tf.expand_dims(tf.argmax(raw_output_segment, axis=-1), axis=-1), tf.int32)
        pred_classes = tf.cast(tf.argmax(raw_output_classes, axis=-1), tf.int32)

        # Labels resized (nearest neighbour) to the logit resolution, then flattened.
        label_batch = tf.image.resize_nearest_neighbor(label_segment_placeholder,
                                                       tf.stack(raw_output_segment.get_shape()[1:3]))
        label_batch = tf.reshape(label_batch, [-1, ])

        # Accuracy on the current batch.
        accuracy_segment = tcm.accuracy(pred_segment, label_segment_placeholder)
        accuracy_classes = tcm.accuracy(pred_classes, label_classes_placeholder)

        # Segmentation loss.
        loss_segment = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label_batch, logits=prediction))
        # Classification loss.
        loss_classes = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label_classes_placeholder,
                                                           logits=raw_output_classes))
        # Total loss.
        loss = tf.add_n([loss_segment, 0.1 * loss_classes])

        # Poly learning-rate schedule: lr * (1 - step / num_steps) ** 0.9.
        step_ph = tf.placeholder(dtype=tf.float32, shape=())
        learning_rate = tf.scalar_mul(tf.constant(self.learning_rate), tf.pow((1 - step_ph / self.num_steps), 0.9))
        train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

        # Optimizer that updates only the classification head ('class_attention' variables).
        classes_trainable = [v for v in tf.trainable_variables() if 'class_attention' in v.name]
        train_classes_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, var_list=classes_trainable)

        return (image_placeholder, label_segment_placeholder, label_classes_placeholder, raw_output_segment,
                raw_output_classes, pred_segment, pred_classes, loss_segment, loss_classes, loss,
                accuracy_segment, accuracy_classes, step_ph, train_op, train_classes_op, learning_rate)

    def train(self, save_pred_freq, begin_step=0):
        """Run training steps ``begin_step .. num_steps - 1``.

        Each step prints a progress line. Every ``print_step`` steps the merged
        summaries are also written. Every ``save_pred_freq`` steps a checkpoint
        is saved.
        """
        # Possibly restore weights from log_dir (decided by Tools.restore_if_y).
        Tools.restore_if_y(self.sess, self.log_dir)

        for step in range(begin_step, self.num_steps):
            start_time = time.time()

            final_batch_data, final_batch_ann, final_batch_class, batch_data, batch_mask = \
                self.data_reader.next_batch_train()

            # Use self.train_classes_op instead to train only the classification head.
            train_op = self.train_op

            # Fetch only what the progress line needs (the large logit tensors were
            # previously copied back every step and never used). Add the summary op
            # on print steps.
            fetches = [self.accuracy_segment, self.accuracy_classes, train_op, self.learning_rate,
                       self.loss_segment, self.loss_classes, self.loss, self.pred_classes]
            write_summary = step % self.print_step == 0
            if write_summary:
                fetches.append(self.summary_op)

            results = self.sess.run(fetches, feed_dict={
                self.step_ph: step,
                self.image_placeholder: final_batch_data,
                self.label_segment_placeholder: final_batch_ann,
                self.label_classes_placeholder: final_batch_class,
            })
            (accuracy_segment_r, accuracy_classes_r, _, learning_rate_r,
             loss_segment_r, loss_classes_r, loss_r, pred_classes_r) = results[:8]
            if write_summary:
                self.summary_writer.add_summary(results[8], global_step=step)

            if step % save_pred_freq == 0:
                self.saver.save(self.sess, self.checkpoint_path, global_step=step)
                Tools.print_info('The checkpoint has been created.')

            duration = time.time() - start_time
            Tools.print_info(
                'step {:d} loss={:.3f} seg={:.3f} class={:.3f} acc={:.3f} acc_class={:.3f}'
                ' lr={:.6f} ({:.3f} s/step) {} {}'.format(
                    step, loss_r, loss_segment_r, loss_classes_r, accuracy_segment_r, accuracy_classes_r,
                    learning_rate_r, duration, list(final_batch_class), list(pred_classes_r)))
def __init__(self, batch_size, input_size, log_dir, data_root_path, train_list, data_path, annotation_path,
             class_path, model_name="model.ckpt", pretrain=None, is_test=False):
    """Build the BAISNet segmentation training graph, its summaries and a TF session.

    Args:
        batch_size: images per batch; at most ``min(3, batch_size)`` predictions per
            segment output are written as image summaries.
        input_size: (height, width) of the network input.
        log_dir: directory for checkpoints and summaries (created via ``Tools.new_dir``).
        data_root_path, train_list, data_path, annotation_path, class_path: forwarded to ``Data``.
        model_name: checkpoint file name inside ``log_dir``.
        pretrain: stored as ``self.pretrain``.
        is_test: selects smaller print/averaging intervals and is forwarded to ``Data``.
    """
    # Checkpoint-related settings
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)
    self.pretrain = pretrain

    # Data-related settings
    self.input_size = input_size
    self.batch_size = batch_size
    self.num_classes = 21

    # Training-related settings (self.learning_rate is replaced by the decayed tensor below)
    self.learning_rate = 5e-3
    self.num_steps = 100001
    self.print_step = 10 if is_test else 100
    self.cal_step = 100 if is_test else 1000

    # Data reader
    self.data_reader = Data(data_root_path=data_root_path, data_list=train_list, data_path=data_path,
                            annotation_path=annotation_path, class_path=class_path,
                            batch_size=self.batch_size, image_size=self.input_size, is_test=is_test)

    # Inputs
    self.image_placeholder = tf.placeholder(tf.float32, shape=(None, self.input_size[0], self.input_size[1], 3))
    self.label_seg_placeholder = tf.placeholder(
        tf.int32, shape=(None, self.input_size[0], self.input_size[1], 1))

    # Network
    self.net = BAISNet(self.image_placeholder, True, num_classes=self.num_classes)
    self.segments, self.features = self.net.build()

    # Loss
    self.loss, self.loss_segment_all, self.loss_segments = self.cal_loss(
        self.segments, self.label_seg_placeholder)

    with tf.name_scope("train"):
        # Learning-rate schedule: base_lr * (1 - step / num_steps) ** 0.8
        self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
        self.learning_rate = tf.scalar_mul(
            tf.constant(self.learning_rate), tf.pow((1 - self.step_ph / self.num_steps), 0.8))
        self.train_op = tf.train.GradientDescentOptimizer(
            self.learning_rate).minimize(self.loss)

        # Alternative op that only updates the 'segment_side' variables
        segment_side_trainable = [v for v in tf.trainable_variables() if 'segment_side' in v.name]
        print(len(segment_side_trainable))
        self.train_segment_side_op = tf.train.GradientDescentOptimizer(
            self.learning_rate).minimize(self.loss, var_list=segment_side_trainable)

    # summary 1
    with tf.name_scope("loss"):
        tf.summary.scalar("loss", self.loss)
        tf.summary.scalar("loss_segment", self.loss_segment_all)
        for loss_segment_index, loss_segment in enumerate(self.loss_segments):
            tf.summary.scalar("loss_segment_{}".format(loss_segment_index), loss_segment)

    with tf.name_scope("label"):
        tf.summary.image("1-image", self.image_placeholder)
        tf.summary.image(
            "2-segment", tf.cast(self.label_seg_placeholder * 255 // self.num_classes, dtype=tf.uint8))

    with tf.name_scope("result"):
        num_shown = min(3, self.batch_size)
        for segment_index, segment in enumerate(self.segments):
            # Scale class ids to [0, 255) while still int64, THEN cast: multiplying a
            # uint8 tensor by 255 wraps around for every class id >= 2.
            class_ids = tf.argmax(segment, axis=-1)
            segment_vis = tf.cast(class_ids * 255 // self.num_classes, dtype=tf.uint8)
            for i in range(num_shown):
                # Slicing (instead of tf.split into batch_size parts) also works when the
                # runtime batch is smaller than batch_size.
                tf.summary.image("predict-{}-{}".format(segment_index, i),
                                 tf.expand_dims(segment_vis[i:i + 1], axis=-1))

    self.summary_op = tf.summary.merge_all()

    # Session and saver
    self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

    # summary 2
    self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
class Train(object):
    """Train BAISNet's attention maps and classifier heads from an image plus a mask input.

    ``__init__`` builds the TF1 graph, summaries, session and saver; ``train`` runs
    SGD with a decaying learning rate, writes summaries and checkpoints periodically.
    """

    def __init__(self, batch_size, input_size, log_dir, data_root_path, train_list, data_path, annotation_path,
                 class_path, model_name="model.ckpt", pretrain=None, is_test=False):
        """Set up data, the BAISNet graph, losses, optimizers, summaries and the session.

        Args:
            batch_size: images per batch.
            input_size: (height, width) of the network input.
            log_dir: directory for checkpoints and summaries (created via ``Tools.new_dir``).
            data_root_path, train_list, data_path, annotation_path, class_path: forwarded to ``Data``.
            model_name: checkpoint file name inside ``log_dir``.
            pretrain: passed to ``Tools.restore_if_y`` in ``train``.
            is_test: selects smaller print/averaging intervals and is forwarded to ``Data``.
        """
        # Checkpoint-related settings
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)
        self.pretrain = pretrain

        # Data-related settings
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21

        # Training-related settings (self.learning_rate is replaced by the decayed tensor below)
        self.learning_rate = 5e-4
        self.num_steps = 500001
        self.print_step = 10 if is_test else 100
        self.cal_step = 100 if is_test else 1000

        # Data reader
        self.data_reader = Data(data_root_path=data_root_path, data_list=train_list, data_path=data_path,
                                annotation_path=annotation_path, class_path=class_path,
                                batch_size=self.batch_size, image_size=self.input_size, is_test=is_test)

        # Inputs: RGB image, 1-channel mask, per-pixel labels, per-image class labels
        self.image_placeholder = tf.placeholder(tf.float32, shape=(None, self.input_size[0], self.input_size[1], 3))
        self.mask_placeholder = tf.placeholder(tf.float32, shape=(None, self.input_size[0], self.input_size[1], 1))
        self.label_seg_placeholder = tf.placeholder(tf.int32,
                                                    shape=(None, self.input_size[0], self.input_size[1], 1))
        self.label_cls_placeholder = tf.placeholder(tf.int32, shape=(None,))

        # Network
        self.net = BAISNet(self.image_placeholder, self.mask_placeholder, True, num_classes=self.num_classes)
        self.segments, self.attentions, self.classes = self.net.build()

        # Loss. NOTE: cal_loss returns the *attention* losses in the slots unpacked into
        # self.loss_segment_all / self.loss_segments.
        (self.loss, self.loss_segment_all, self.loss_class_all,
         self.loss_segments, self.loss_classes) = self.cal_loss(
            self.segments, self.attentions, self.classes, self.label_seg_placeholder, self.label_cls_placeholder)

        # Batch accuracy of the first classifier head
        self.pred_classes = tf.cast(tf.argmax(self.classes[0], axis=-1), tf.int32)
        self.accuracy_classes = tcm.accuracy(self.pred_classes, self.label_cls_placeholder)

        with tf.name_scope("train"):
            # Learning-rate schedule: base_lr * (1 - step / num_steps) ** 0.9
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            self.learning_rate = tf.scalar_mul(tf.constant(self.learning_rate),
                                               tf.pow((1 - self.step_ph / self.num_steps), 0.9))
            self.train_op = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss)

            # Alternative op that only updates attention variables; names containing
            # "class_attention" also contain "attention", so one substring test covers both.
            attention_trainable = [v for v in tf.trainable_variables() if 'attention' in v.name]
            print(len(attention_trainable))
            self.train_attention_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss, var_list=attention_trainable)

        # summary 1
        with tf.name_scope("loss"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.scalar("loss_segment", self.loss_segment_all)
            tf.summary.scalar("loss_class", self.loss_class_all)
            for loss_segment_index, loss_segment in enumerate(self.loss_segments):
                tf.summary.scalar("loss_segment_{}".format(loss_segment_index), loss_segment)
            for loss_class_index, loss_class in enumerate(self.loss_classes):
                tf.summary.scalar("loss_class_{}".format(loss_class_index), loss_class)

        with tf.name_scope("accuracy"):
            tf.summary.scalar("accuracy_classes", self.accuracy_classes)

        with tf.name_scope("label"):
            tf.summary.image("1-image", self.image_placeholder)
            tf.summary.image("2-segment", tf.cast(self.label_seg_placeholder * 255, dtype=tf.uint8))
            tf.summary.image("3-mask", tf.cast(self.mask_placeholder * 255, dtype=tf.uint8))

        with tf.name_scope("result"):
            # Only the second of the two output channels is visualised (after sigmoid)
            for segment_index, segment in enumerate(self.segments):
                segment = tf.split(segment, num_or_size_splits=2, axis=3)
                tf.summary.image("predict-{}-1".format(segment_index), tf.nn.sigmoid(segment[1]))
            for attention_index, attention in enumerate(self.attentions):
                attention = tf.split(attention, num_or_size_splits=2, axis=3)
                tf.summary.image("attention-{}-1".format(attention_index), tf.nn.sigmoid(attention[1]))

        self.summary_op = tf.summary.merge_all()

        # Session and saver
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

        # summary 2
        self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)

    @staticmethod
    def cal_loss(segments, attentions, classes, label_segment, label_classes):
        """Compute attention and classification losses.

        ``segments`` is accepted for interface compatibility but is not used; only the
        attention maps and the class logits are supervised.

        Returns:
            (loss, loss_attention_all, loss_class_all, loss_attentions, loss_classes)
        """
        loss_attentions = []
        for attention in attentions:
            # Resize labels to this attention map's spatial size (nearest keeps label ids intact)
            now_label_segment = tf.image.resize_nearest_neighbor(label_segment, tf.stack(attention.get_shape()[1:3]))
            # Weighted sigmoid cross entropy over 2 one-hot channels; positives weighted x3
            now_loss_attention = tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(
                targets=tf.one_hot(tf.reshape(now_label_segment, [-1, ]), depth=2),
                logits=tf.reshape(attention, [-1, 2]), pos_weight=3))
            loss_attentions.append(now_loss_attention)

        loss_classes = [
            tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label_classes, logits=class_one))
            for class_one in classes
        ]

        loss_attention_all = tf.add_n(loss_attentions) / len(loss_attentions)
        loss_class_all = tf.add_n(loss_classes) / len(loss_classes)

        # Total loss
        loss = loss_attention_all + loss_class_all
        return loss, loss_attention_all, loss_class_all, loss_attentions, loss_classes

    def train(self, save_pred_freq, begin_step=0):
        """Run SGD for steps ``begin_step .. self.num_steps - 1``.

        Args:
            save_pred_freq: save a checkpoint whenever ``step % save_pred_freq == 0``.
            begin_step: first step index; also drives the learning-rate schedule.
        """
        # Restore from log_dir / pretrain via the project helper
        Tools.restore_if_y(self.sess, self.log_dir, pretrain=self.pretrain)

        # Running sums over the current window of up to cal_step steps, plus the
        # previous window's averages. window_count tracks how many steps were actually
        # accumulated, so averages stay correct when begin_step is not a multiple of cal_step.
        total_loss = 0.0
        pre_avg_loss = 0.0
        total_acc = 0.0
        pre_avg_acc = 0.0
        window_count = 0
        for step in range(begin_step, self.num_steps):
            start_time = time.time()

            batch_data, batch_mask, batch_attention, batch_class = self.data_reader.next_batch_train()

            # Switch to self.train_attention_op to update only the attention variables
            train_op = self.train_op
            feed_dict = {self.step_ph: step, self.image_placeholder: batch_data,
                         self.mask_placeholder: batch_mask, self.label_seg_placeholder: batch_attention,
                         self.label_cls_placeholder: batch_class}
            fetches = [self.accuracy_classes, train_op, self.learning_rate, self.loss_segment_all,
                       self.loss_class_all, self.loss, self.pred_classes]
            if step % self.print_step == 0:
                # summary 3
                results = self.sess.run(fetches + [self.summary_op], feed_dict=feed_dict)
                self.summary_writer.add_summary(results[-1], global_step=step)
                results = results[:-1]
            else:
                results = self.sess.run(fetches, feed_dict=feed_dict)
            (accuracy_classes_r, _, learning_rate_r, loss_segment_r,
             loss_classes_r, loss_r, pred_classes_r) = results

            if step % save_pred_freq == 0:
                self.saver.save(self.sess, self.checkpoint_path, global_step=step)
                Tools.print_info('The checkpoint has been created.')

            duration = time.time() - start_time

            if step % self.cal_step == 0:
                # Close the previous window (if it had any steps) and start a new one
                if window_count > 0:
                    pre_avg_loss = total_loss / window_count
                    pre_avg_acc = total_acc / window_count
                total_loss = loss_r
                total_acc = accuracy_classes_r
                window_count = 1
            else:
                total_loss += loss_r
                total_acc += accuracy_classes_r
                window_count += 1

            if step % (self.cal_step // 10) == 0:
                Tools.print_info(
                    'step {:d} pre_avg_loss={:.3f} avg_loss={:.3f} pre_avg_acc={:.3f} avg_acc={:.3f} '
                    'loss={:.3f} seg={:.3f} class={:.3f} acc_class={:.3f} '
                    'lr={:.6f} ({:.3f} s/step) {} {}'.format(
                        step, pre_avg_loss, total_loss / window_count, pre_avg_acc, total_acc / window_count,
                        loss_r, loss_segment_r, loss_classes_r, accuracy_classes_r, learning_rate_r,
                        duration, list(batch_class), list(pred_classes_r)))
            if step % self.print_step == 0:
                Tools.print_info("")
def __init__(self, batch_size, last_pool_size, input_size, log_dir, data_root_path, train_list, data_path,
             annotation_path, class_path, model_name="model.ckpt", is_test=False):
    """Configure the PSPNet segmentation/classification trainer and build its graph.

    Stores checkpoint, data, model and training settings, creates the ``Data``
    reader, unpacks the tensors returned by ``self.build_net()``, registers
    TensorBoard summaries and opens a session with initialised variables, a
    ``Saver`` (10 checkpoints kept) and a ``FileWriter``.

    Args:
        batch_size: images per batch.
        last_pool_size: pooling size of the last pyramid level; input_size must be
            larger than 8 times this value.
        input_size: (height, width) of the network input.
        log_dir: directory for checkpoints and summaries (created via ``Tools.new_dir``).
        data_root_path, train_list, data_path, annotation_path, class_path: forwarded to ``Data``.
        model_name: checkpoint file name inside ``log_dir``.
        is_test: forwarded to ``Data``.
    """
    # Where checkpoints and summaries are written
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

    # Data settings
    self.input_size, self.batch_size = input_size, batch_size
    self.num_classes, self.num_segment = 21, 1

    # Model settings: input_size must be larger than 8 * last_pool_size
    self.ratio, self.last_pool_size, self.filter_number = 8, last_pool_size, 32

    # Training settings
    self.learning_rate, self.num_steps = 5e-3, 500001

    # Data reader
    self.data_reader = Data(data_root_path=data_root_path, data_list=train_list, data_path=data_path,
                            annotation_path=annotation_path, class_path=class_path,
                            batch_size=self.batch_size, image_size=self.input_size, is_test=is_test)

    # Graph: build_net hands back inputs, outputs, losses, metrics and train ops
    net_outputs = self.build_net()
    (self.image_placeholder, self.label_segment_placeholder, self.label_classes_placeholder,
     self.raw_output_segment, self.raw_output_classes, self.pred_segment, self.pred_classes,
     self.loss_segment, self.loss_classes, self.loss, self.accuracy_0, self.accuracy_1,
     self.accuracy_classes, self.step_ph, self.train_op, self.train_classes_op,
     self.learning_rate) = net_outputs

    # Scalar summaries
    scalar_summaries = (
        ("loss", self.loss),
        ("loss_segment", self.loss_segment),
        ("loss_classes", self.loss_classes),
        ("accuracy_0", self.accuracy_0),
        ("accuracy_1", self.accuracy_1),
        ("accuracy_classes", self.accuracy_classes),
    )
    for tag, tensor in scalar_summaries:
        tf.summary.scalar(tag, tensor)

    # Image summaries: the 4-channel input is shown as channels 0-2 plus channel 3 ("mask")
    image_channels = tf.split(self.image_placeholder, num_or_size_splits=4, axis=3)
    tf.summary.image("image", tf.concat(image_channels[:3], axis=3))
    tf.summary.image("mask", image_channels[3])
    tf.summary.image("label", self.label_segment_placeholder)
    tf.summary.image("raw_output_segment", self.raw_output_segment)
    tf.summary.image("pred_segment", tf.cast(self.pred_segment, tf.float32))
    self.summary_op = tf.summary.merge_all()

    # Session (GPU memory grows on demand), variable init and saver
    gpu_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    self.sess = tf.Session(config=gpu_config)
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

    # Summary writer (also records the graph)
    self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
def run(self, image_filename_or_data, mask_color, opacity):
    """Interactively segment the object under a clicked point and overlay the result.

    Shows the image with matplotlib and repeatedly waits for one click. Each click
    is mapped to network-input coordinates and the network is run on it. The
    predicted segment is then alpha-blended over the image in ``mask_color``, with
    the predicted class name drawn above it. Any exception ends the loop and is
    reported; for example, an IndexError is raised when ``ginput`` returns no point.

    Args:
        image_filename_or_data: path to an image file, or image data as an ndarray
            or nested list (H x W x C with at least 3 channels).
        mask_color: overlay colour; its first three entries are used as R, G, B.
        opacity: overlay opacity in [0, 1]; 1 paints the segment fully in ``mask_color``.
    """
    plt.ion()
    plt.axis('off')

    if isinstance(image_filename_or_data, str):
        image_data = np.array(Image.open(image_filename_or_data))
    elif isinstance(image_filename_or_data, (list, np.ndarray)):
        # Convert lists so that .shape and channel slicing below work
        image_data = np.asarray(image_filename_or_data)
    else:
        print("image_filename_or_data is error")
        return

    plt.imshow(image_data)
    plt.title('Click one point of the object that you interested')

    try:
        image_height, image_width = image_data.shape[0], image_data.shape[1]
        while True:
            # ginput yields (x, y) click positions; builtin int replaces np.int (removed in NumPy 1.24)
            object_point = np.asarray(plt.ginput(1, timeout=0)).astype(int)[0]
            # Map the click to (row, col) in network-input coordinates
            where = [int(self.input_size[0] * object_point[1] / image_height),
                     int(self.input_size[1] * object_point[0] / image_width)]
            print("point=[{},{}] where=[{},{}]".format(object_point[0], object_point[1], where[0], where[1]))

            final_batch_data, data_raw, gaussian_mask = Data.load_image(image_data, where=where,
                                                                        image_size=self.input_size)

            print("begin to run ...")
            predict_output_r, pred_classes_r = self.sess.run([self.predict_output, self.pred_classes],
                                                             feed_dict={self.img_placeholder: final_batch_data})
            print("end run")

            # Class
            print("the class is {}({})".format(pred_classes_r[0], CategoryNames[pred_classes_r[0]]))

            # Segment: binary mask resized back to the image size; NEAREST keeps it binary
            segment = np.squeeze(np.asarray(np.where(predict_output_r[0] == 1, 1, 0), dtype=np.uint8))
            segment = np.asarray(Image.fromarray(segment).resize((image_width, image_height), Image.NEAREST))

            # Alpha-blend mask_color into the first three channels where segment == 1
            seg = segment[:, :, np.newaxis].astype(np.float64)
            rgb = image_data[:, :, :3].astype(np.float64)
            color = np.asarray(mask_color[:3], dtype=np.float64)
            image_mask = (1 - seg) * rgb + seg * (opacity * color + (1 - opacity) * rgb)

            plt.clf()  # clear the previous overlay
            plt.text(image_width // 2 - 10, -6, CategoryNames[pred_classes_r[0]], fontsize=15)
            plt.imshow(image_mask.astype(np.uint8))
            print("")
    except Exception as error:
        # Any failure ends the interactive session; report why instead of hiding it
        print("..................")
        print("...... close .....")
        print("..................")
        print("reason: {!r}".format(error))
# NOTE(review): stray module-level `pass` (probably the tail of a preceding class body); it has no effect here.
pass
# Script entry point: choose a dataset reader from hard-coded Windows paths and start training.
if __name__ == '__main__':
    # is_win selects the hard-coded local Windows data paths below; is_voc chooses VOC over COCO.
    is_win = False
    is_voc = False
    if is_win:
        if is_voc:
            # PASCAL VOC 2012 with SegmentationObject annotations
            data_reader = Data(
                data_root_path=
                "C:\\ALISURE\\DataModel\\Data\\VOCtrainval_11-May-2012\\VOCdevkit\\VOC2012\\",
                data_list="ImageSets\\Segmentation\\train.txt",
                data_path="JPEGImages\\",
                annotation_path="SegmentationObject\\",
                class_path="SegmentationClass\\",
                batch_size=3,
                image_size=[720, 720],
                is_test=False)
        else:
            # COCO val2014
            data_reader = COCOData(
                data_root_path="C:\\ALISURE\\DataModel\\Data\\COCO",
                annotation_path="annotations_trainval2014\\annotations",
                data_type="val2014",
                batch_size=3,
                image_size=[720, 720])
            pass
    # NOTE(review): `data_reader` is only assigned inside `if is_win:`; with is_win = False as set
    # above, this call raises NameError. TODO: add the non-Windows data paths.
    # NOTE(review): none of the Train constructors in this chunk accept a `data` keyword; presumably
    # a different Train definition is in scope for this script. Confirm.
    Train(log_dir="./model/coco/first", data=data_reader, is_test=True).train(save_pred_freq=2, begin_step=0)
def run(self, result_filename, image_filename, where=None, annotation_filename=None, ann_index=0):
    """Run PSPNet once on an image and save the input and predictions as image files.

    A private graph and session are created for the call and released when it
    returns. Weights are restored from ``self.log_dir`` via ``Tools.restore_if_y``.
    Files written to ``self.save_dir``, each prefixed with ``result_filename``:
    ``data.png`` (input), ``pred.png`` (argmax of the 4 segment channels),
    ``pred_0.png`` .. ``pred_3.png`` (per-channel sigmoid), ``mask.bmp`` (gaussian
    mask) and, when an annotation is given, ``ann.bmp``.

    Args:
        result_filename: prefix for all output file names.
        image_filename: image passed to ``Data.load_image``.
        where: click position forwarded to ``Data.load_image``.
        annotation_filename: optional annotation; when truthy, ``load_image`` is
            expected to return five values including the annotation mask.
        ann_index: forwarded to ``Data.load_image``.
    """
    # Read the image (and optionally its annotation)
    loaded = Data.load_image(image_filename, where=where, annotation_filename=annotation_filename,
                             ann_index=ann_index, image_size=self.input_size)
    if annotation_filename:
        final_batch_data, data_raw, gaussian_mask, _, ann_mask = loaded
    else:
        final_batch_data, data_raw, gaussian_mask = loaded

    # Build the network in its own graph so repeated calls do not stack duplicate
    # networks/variables onto the default graph.
    with tf.Graph().as_default():
        img_placeholder = tf.placeholder(dtype=tf.float32, shape=(None, self.input_size[0], self.input_size[1], 4))
        net = PSPNet({'data': img_placeholder}, is_training=True, num_classes=21,
                     last_pool_size=self.last_pool_size, filter_number=32, num_segment=4)

        # Outputs / predictions
        raw_output_op = net.layers["conv6_n_4"]
        sigmoid_output_op = tf.sigmoid(raw_output_op)
        predict_output_op = tf.argmax(sigmoid_output_op, axis=-1)
        raw_output_classes = net.layers['class_attention_fc']
        pred_classes = tf.cast(tf.argmax(raw_output_classes, axis=-1), tf.int32)

        # The session is closed when the with-block exits (it previously leaked)
        with tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))) as sess:
            sess.run(tf.global_variables_initializer())
            Tools.restore_if_y(sess, self.log_dir)

            sigmoid_output, predict_output_r, raw_output_classes_r, pred_classes_r = sess.run(
                [sigmoid_output_op, predict_output_op, raw_output_classes, pred_classes],
                feed_dict={img_placeholder: final_batch_data})

    # Save
    print("{} {} {}".format(pred_classes_r[0], CategoryNames[pred_classes_r[0]], raw_output_classes_r))
    print("result in {}".format(os.path.join(self.save_dir, result_filename)))
    Image.fromarray(np.asarray(np.squeeze(data_raw), dtype=np.uint8)).save(
        os.path.join(self.save_dir, result_filename + "data.png"))

    # Per-channel sigmoid maps as 8-bit images (assumes sigmoid_output[0] is H x W x 4)
    output_result = np.squeeze(
        np.split(np.asarray(sigmoid_output[0] * 255, dtype=np.uint8), axis=-1, indices_or_sections=4))
    # Argmax channel index 0..3 scaled by 255 // 4 steps for visibility
    Image.fromarray(np.squeeze(np.asarray(predict_output_r[0] * 255 // 4, dtype=np.uint8))).save(
        os.path.join(self.save_dir, result_filename + "pred.png"))
    for channel_index, channel_image in enumerate(output_result):
        Image.fromarray(channel_image).save(
            os.path.join(self.save_dir, result_filename + "pred_{}.png".format(channel_index)))

    Image.fromarray(np.asarray(np.squeeze(gaussian_mask * 255), dtype=np.uint8)).save(
        os.path.join(self.save_dir, result_filename + "mask.bmp"))
    if annotation_filename:
        Image.fromarray(np.asarray(np.squeeze(ann_mask * 255), dtype=np.uint8)).save(
            os.path.join(self.save_dir, result_filename + "ann.bmp"))
class Train(object):
    """Build PSPNet (4-channel input) and run a short forward-only loop over training batches.

    No optimisation is performed: ``build_net`` creates no train op, and ``train``
    only runs the forward pass, saves checkpoints and prints labels vs. predicted classes.
    """

    def __init__(self, batch_size, last_pool_size, input_size, log_dir, data_root_path, train_list, data_path,
                 annotation_path, class_path, model_name="model.ckpt", is_test=False):
        """Store settings, create the ``Data`` reader, build the network and open a session.

        Args:
            batch_size: images per batch.
            last_pool_size: pooling size of the last pyramid level; input_size must be
                larger than 8 times this value.
            input_size: (height, width) of the network input.
            log_dir: directory for checkpoints (created via ``Tools.new_dir``).
            data_root_path, train_list, data_path, annotation_path, class_path: forwarded to ``Data``.
            model_name: checkpoint file name inside ``log_dir``.
            is_test: forwarded to ``Data``.
        """
        # Checkpoint-related settings
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Data-related settings
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21
        self.num_segment = 1

        # Model settings: input_size must be larger than 8 * last_pool_size
        self.ratio = 8
        self.last_pool_size = last_pool_size
        self.filter_number = 32

        # Data reader
        self.data_reader = Data(data_root_path=data_root_path, data_list=train_list, data_path=data_path,
                                annotation_path=annotation_path, class_path=class_path,
                                batch_size=self.batch_size, image_size=self.input_size, is_test=is_test)

        # Network
        (self.image_placeholder, self.raw_output_segment, self.raw_output_classes,
         self.pred_segment, self.pred_classes) = self.build_net()

        # Session and saver
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

    def build_net(self):
        """Create the input placeholder and PSPNet and return its outputs.

        Returns:
            (image_placeholder, raw_output_segment, raw_output_classes, pred_segment, pred_classes)
        """
        # Input: 4 channels (image plus one extra channel; presumably the click mask, confirm)
        image_placeholder = tf.placeholder(dtype=tf.float32,
                                           shape=(None, self.input_size[0], self.input_size[1], 4))
        net = PSPNet({'data': image_placeholder}, is_training=True, num_classes=self.num_classes,
                     num_segment=self.num_segment, last_pool_size=self.last_pool_size,
                     filter_number=self.filter_number)

        raw_output_segment = net.layers['conv6_n']
        raw_output_classes = net.layers['class_attention_fc']

        # Predictions. NOTE(review): if conv6_n yields logits rather than probabilities,
        # a threshold of 0 (or sigmoid > 0.5) would match; confirm.
        pred_segment = tf.cast(tf.greater(raw_output_segment, 0.5), tf.int32)
        pred_classes = tf.cast(tf.argmax(raw_output_classes, axis=-1), tf.int32)
        return image_placeholder, raw_output_segment, raw_output_classes, pred_segment, pred_classes

    def train(self, save_pred_freq, begin_step=0, num_steps=5):
        """Run the forward pass for steps ``begin_step .. num_steps - 1`` and print predictions.

        Args:
            save_pred_freq: save a checkpoint whenever ``step % save_pred_freq == 0``.
            begin_step: first step index.
            num_steps: exclusive upper bound of the step range (default 5).
        """
        # Restore weights from log_dir via the project helper
        Tools.restore_if_y(self.sess, self.log_dir)

        for step in range(begin_step, num_steps):
            # Only the images and class labels are used here
            final_batch_data, _, final_batch_class, _, _ = self.data_reader.next_batch_train()

            pred_classes_r = self.sess.run(self.pred_classes,
                                           feed_dict={self.image_placeholder: final_batch_data})

            if step % save_pred_freq == 0:
                self.saver.save(self.sess, self.checkpoint_path, global_step=step)
                Tools.print_info('The checkpoint has been created.')

            Tools.print_info('step {:d} {} {}'.format(step, list(final_batch_class), list(pred_classes_r)))
class Train(object):
    """Train BAISNet's multi-scale segmentation outputs with SGD.

    ``__init__`` builds the TF1 graph, summaries, session and saver; ``cal_loss``
    defines the per-scale segmentation loss; ``train`` runs the optimisation loop.
    """

    def __init__(self, batch_size, input_size, log_dir, data_root_path, train_list, data_path, annotation_path,
                 class_path, model_name="model.ckpt", pretrain=None, is_test=False):
        """Build the segmentation training graph, its summaries and a TF session.

        Args:
            batch_size: images per batch; at most ``min(3, batch_size)`` predictions per
                segment output are written as image summaries.
            input_size: (height, width) of the network input.
            log_dir: directory for checkpoints and summaries (created via ``Tools.new_dir``).
            data_root_path, train_list, data_path, annotation_path, class_path: forwarded to ``Data``.
            model_name: checkpoint file name inside ``log_dir``.
            pretrain: passed to ``Tools.restore_if_y`` in ``train``.
            is_test: selects smaller print/averaging intervals and is forwarded to ``Data``.
        """
        # Checkpoint-related settings
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)
        self.pretrain = pretrain

        # Data-related settings
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21

        # Training-related settings (self.learning_rate is replaced by the decayed tensor below)
        self.learning_rate = 5e-3
        self.num_steps = 100001
        self.print_step = 10 if is_test else 100
        self.cal_step = 100 if is_test else 1000

        # Data reader
        self.data_reader = Data(data_root_path=data_root_path, data_list=train_list, data_path=data_path,
                                annotation_path=annotation_path, class_path=class_path,
                                batch_size=self.batch_size, image_size=self.input_size, is_test=is_test)

        # Inputs
        self.image_placeholder = tf.placeholder(tf.float32, shape=(None, self.input_size[0], self.input_size[1], 3))
        self.label_seg_placeholder = tf.placeholder(
            tf.int32, shape=(None, self.input_size[0], self.input_size[1], 1))

        # Network
        self.net = BAISNet(self.image_placeholder, True, num_classes=self.num_classes)
        self.segments, self.features = self.net.build()

        # Loss
        self.loss, self.loss_segment_all, self.loss_segments = self.cal_loss(
            self.segments, self.label_seg_placeholder)

        with tf.name_scope("train"):
            # Learning-rate schedule: base_lr * (1 - step / num_steps) ** 0.8
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            self.learning_rate = tf.scalar_mul(
                tf.constant(self.learning_rate), tf.pow((1 - self.step_ph / self.num_steps), 0.8))
            self.train_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss)

            # Alternative op that only updates the 'segment_side' variables
            segment_side_trainable = [v for v in tf.trainable_variables() if 'segment_side' in v.name]
            print(len(segment_side_trainable))
            self.train_segment_side_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss, var_list=segment_side_trainable)

        # summary 1
        with tf.name_scope("loss"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.scalar("loss_segment", self.loss_segment_all)
            for loss_segment_index, loss_segment in enumerate(self.loss_segments):
                tf.summary.scalar("loss_segment_{}".format(loss_segment_index), loss_segment)

        with tf.name_scope("label"):
            tf.summary.image("1-image", self.image_placeholder)
            tf.summary.image(
                "2-segment", tf.cast(self.label_seg_placeholder * 255 // self.num_classes, dtype=tf.uint8))

        with tf.name_scope("result"):
            num_shown = min(3, self.batch_size)
            for segment_index, segment in enumerate(self.segments):
                # Scale class ids to [0, 255) while still int64, THEN cast: multiplying a
                # uint8 tensor by 255 wraps around for every class id >= 2.
                class_ids = tf.argmax(segment, axis=-1)
                segment_vis = tf.cast(class_ids * 255 // self.num_classes, dtype=tf.uint8)
                for i in range(num_shown):
                    # Slicing also works when the runtime batch is smaller than batch_size
                    tf.summary.image("predict-{}-{}".format(segment_index, i),
                                     tf.expand_dims(segment_vis[i:i + 1], axis=-1))

        self.summary_op = tf.summary.merge_all()

        # Session and saver
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

        # summary 2
        self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)

    def cal_loss(self, segments, label_segment):
        """Average the per-scale segmentation losses.

        For every predicted segment map, the labels are resized (nearest neighbour) to
        its spatial size. Each map's loss is ``weighted_cross_entropy_with_logits``
        (pos_weight=1) on one-hot targets over ``num_classes`` channels, averaged over
        all elements.

        Returns:
            (loss, loss_segment_all, loss_segments)
        """
        loss_segments = []
        for segment in segments:
            now_label_segment = tf.image.resize_nearest_neighbor(
                label_segment, tf.stack(segment.get_shape()[1:3]))
            now_loss_segment = tf.reduce_mean(
                tf.nn.weighted_cross_entropy_with_logits(
                    targets=tf.one_hot(tf.reshape(now_label_segment, [-1, ]), depth=self.num_classes),
                    logits=tf.reshape(segment, [-1, self.num_classes]),
                    pos_weight=1))
            loss_segments.append(now_loss_segment)

        loss_segment_all = tf.add_n(loss_segments) / len(loss_segments)

        # Total loss
        loss = loss_segment_all
        return loss, loss_segment_all, loss_segments

    def train(self, save_pred_freq, begin_step=0):
        """Run SGD for steps ``begin_step .. self.num_steps - 1``.

        Args:
            save_pred_freq: save a checkpoint whenever ``step % save_pred_freq == 0``.
            begin_step: first step index; also drives the learning-rate schedule.
        """
        # Restore from log_dir / pretrain via the project helper
        Tools.restore_if_y(self.sess, self.log_dir, pretrain=self.pretrain)

        # Running loss over the current window of up to cal_step steps, plus the previous
        # window's average. window_count tracks how many steps were actually accumulated,
        # so averages stay correct when begin_step is not a multiple of cal_step.
        total_loss = 0.0
        pre_avg_loss = 0.0
        window_count = 0
        for step in range(begin_step, self.num_steps):
            start_time = time.time()

            batch_data, batch_segment = self.data_reader.next_batch_train()

            # Switch to self.train_segment_side_op to update only the segment_side variables
            train_op = self.train_op
            feed_dict = {self.step_ph: step, self.image_placeholder: batch_data,
                         self.label_seg_placeholder: batch_segment}
            fetches = [train_op, self.learning_rate, self.loss_segment_all, self.loss]
            if step % self.print_step == 0:
                # summary 3
                _, learning_rate_r, loss_segment_r, loss_r, summary_now = self.sess.run(
                    fetches + [self.summary_op], feed_dict=feed_dict)
                self.summary_writer.add_summary(summary_now, global_step=step)
            else:
                _, learning_rate_r, loss_segment_r, loss_r = self.sess.run(fetches, feed_dict=feed_dict)

            if step % save_pred_freq == 0:
                self.saver.save(self.sess, self.checkpoint_path, global_step=step)
                Tools.print_info('The checkpoint has been created.')

            duration = time.time() - start_time

            if step % self.cal_step == 0:
                # Close the previous window (if it had any steps) and start a new one
                if window_count > 0:
                    pre_avg_loss = total_loss / window_count
                total_loss = loss_r
                window_count = 1
            else:
                total_loss += loss_r
                window_count += 1

            if step % (self.cal_step // 10) == 0:
                Tools.print_info(
                    'step {:d} pre_avg_loss={:.3f} avg_loss={:.3f} loss={:.3f} seg={:.3f} lr={:.6f} '
                    '({:.3f} s/step)'.format(step, pre_avg_loss, total_loss / window_count, loss_r,
                                             loss_segment_r, learning_rate_r, duration))
            if step % self.print_step == 0:
                Tools.print_info("")
def __init__(self, batch_size, input_size, log_dir, data_root_path, train_list, data_path, annotation_path,
             class_path, model_name="model.ckpt", pretrain=None, is_test=False):
    """Set up data, the BAISNet attention/classification graph, summaries and the session.

    Args:
        batch_size: images per batch.
        input_size: (height, width) of the network input.
        log_dir: directory for checkpoints and summaries (created via ``Tools.new_dir``).
        data_root_path, train_list, data_path, annotation_path, class_path: forwarded to ``Data``.
        model_name: checkpoint file name inside ``log_dir``.
        pretrain: stored as ``self.pretrain``.
        is_test: selects smaller print/averaging intervals and is forwarded to ``Data``.
    """
    # --- checkpoint location ---
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)
    self.pretrain = pretrain

    # --- data settings ---
    self.input_size = input_size
    self.batch_size = batch_size
    self.num_classes = 21

    # --- optimisation settings (learning_rate becomes a decayed tensor below) ---
    self.learning_rate = 5e-4
    self.num_steps = 500001
    if is_test:
        self.print_step, self.cal_step = 10, 100
    else:
        self.print_step, self.cal_step = 100, 1000

    # --- data reader ---
    self.data_reader = Data(data_root_path=data_root_path, data_list=train_list, data_path=data_path,
                            annotation_path=annotation_path, class_path=class_path,
                            batch_size=self.batch_size, image_size=self.input_size, is_test=is_test)

    # --- inputs: RGB image, 1-channel mask, per-pixel labels, per-image class labels ---
    height, width = self.input_size[0], self.input_size[1]
    self.image_placeholder = tf.placeholder(tf.float32, shape=(None, height, width, 3))
    self.mask_placeholder = tf.placeholder(tf.float32, shape=(None, height, width, 1))
    self.label_seg_placeholder = tf.placeholder(tf.int32, shape=(None, height, width, 1))
    self.label_cls_placeholder = tf.placeholder(tf.int32, shape=(None,))

    # --- network ---
    self.net = BAISNet(self.image_placeholder, self.mask_placeholder, True, num_classes=self.num_classes)
    self.segments, self.attentions, self.classes = self.net.build()

    # --- losses ---
    loss_terms = self.cal_loss(self.segments, self.attentions, self.classes,
                               self.label_seg_placeholder, self.label_cls_placeholder)
    self.loss, self.loss_segment_all, self.loss_class_all, self.loss_segments, self.loss_classes = loss_terms

    # --- batch accuracy of the first classifier head ---
    class_logits = self.classes[0]
    self.pred_classes = tf.cast(tf.argmax(class_logits, axis=-1), tf.int32)
    self.accuracy_classes = tcm.accuracy(self.pred_classes, self.label_cls_placeholder)

    with tf.name_scope("train"):
        # lr = base_lr * (1 - step / num_steps) ** 0.9
        self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
        base_lr = tf.constant(self.learning_rate)
        decay = tf.pow((1 - self.step_ph / self.num_steps), 0.9)
        self.learning_rate = tf.scalar_mul(base_lr, decay)
        self.train_op = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss)

        # Alternative op restricted to attention variables ("class_attention" names
        # contain "attention", so they are selected as well)
        attention_trainable = [var for var in tf.trainable_variables() if 'attention' in var.name]
        print(len(attention_trainable))
        self.train_attention_op = tf.train.GradientDescentOptimizer(
            self.learning_rate).minimize(self.loss, var_list=attention_trainable)

    # --- scalar summaries ---
    with tf.name_scope("loss"):
        for tag, value in (("loss", self.loss),
                           ("loss_segment", self.loss_segment_all),
                           ("loss_class", self.loss_class_all)):
            tf.summary.scalar(tag, value)
        for index, value in enumerate(self.loss_segments):
            tf.summary.scalar("loss_segment_{}".format(index), value)
        for index, value in enumerate(self.loss_classes):
            tf.summary.scalar("loss_class_{}".format(index), value)

    with tf.name_scope("accuracy"):
        tf.summary.scalar("accuracy_classes", self.accuracy_classes)

    # --- image summaries ---
    with tf.name_scope("label"):
        tf.summary.image("1-image", self.image_placeholder)
        tf.summary.image("2-segment", tf.cast(self.label_seg_placeholder * 255, dtype=tf.uint8))
        tf.summary.image("3-mask", tf.cast(self.mask_placeholder * 255, dtype=tf.uint8))

    with tf.name_scope("result"):
        # Only the second of the two output channels is visualised (after sigmoid)
        for index, segment in enumerate(self.segments):
            channels = tf.split(segment, num_or_size_splits=2, axis=3)
            tf.summary.image("predict-{}-1".format(index), tf.nn.sigmoid(channels[1]))
        for index, attention in enumerate(self.attentions):
            channels = tf.split(attention, num_or_size_splits=2, axis=3)
            tf.summary.image("attention-{}-1".format(index), tf.nn.sigmoid(channels[1]))

    self.summary_op = tf.summary.merge_all()

    # --- session (GPU memory grows on demand), init and saver ---
    session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    self.sess = tf.Session(config=session_config)
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

    # --- summary writer (also records the graph) ---
    self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)