class Train(object):
    """Run a PSPNet forward pass over training batches and report class predictions.

    This variant builds the network and a session but defines no loss or
    optimizer: ``train`` only runs inference on batches from ``Data``,
    periodically saves a checkpoint, and prints true vs. predicted classes.
    """

    def __init__(self, batch_size, last_pool_size, input_size, log_dir, data_root_path, train_list,
                 data_path, annotation_path, class_path, model_name="model.ckpt", is_test=False):
        # Checkpoint-related settings
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Data-related settings
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21
        self.num_segment = 1

        # Model-related settings: input_size must be larger than 8 * last_pool_size
        self.ratio = 8
        self.last_pool_size = last_pool_size
        self.filter_number = 32

        # Data reader
        self.data_reader = Data(data_root_path=data_root_path, data_list=train_list,
                                data_path=data_path, annotation_path=annotation_path,
                                class_path=class_path, batch_size=self.batch_size,
                                image_size=self.input_size, is_test=is_test)

        # Network
        (self.image_placeholder, self.raw_output_segment, self.raw_output_classes,
         self.pred_segment, self.pred_classes) = self.build_net()

        # Session and saver
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

    def build_net(self):
        """Build the PSPNet graph.

        Returns:
            (image_placeholder, raw_output_segment, raw_output_classes,
             pred_segment, pred_classes) where the image placeholder has
            4 channels, ``pred_segment`` is ``raw_output_segment > 0.5`` as
            int32 and ``pred_classes`` is the argmax of the class output.
        """
        image_placeholder = tf.placeholder(dtype=tf.float32,
                                           shape=(None, self.input_size[0], self.input_size[1], 4))

        net = PSPNet({'data': image_placeholder}, is_training=True, num_classes=self.num_classes,
                     num_segment=self.num_segment, last_pool_size=self.last_pool_size,
                     filter_number=self.filter_number)
        raw_output_segment = net.layers['conv6_n']
        raw_output_classes = net.layers['class_attention_fc']

        # Predictions: threshold the segment output at 0.5; argmax over class outputs
        pred_segment = tf.cast(tf.greater(raw_output_segment, 0.5), tf.int32)
        pred_classes = tf.cast(tf.argmax(raw_output_classes, axis=-1), tf.int32)
        return image_placeholder, raw_output_segment, raw_output_classes, pred_segment, pred_classes

    def train(self, save_pred_freq, begin_step=0, end_step=5):
        """Run inference for steps ``[begin_step, end_step)`` and print class predictions.

        Args:
            save_pred_freq: save a checkpoint every ``save_pred_freq`` steps.
            begin_step: first step index (also used as the checkpoint step).
            end_step: exclusive upper bound of the step range; defaults to
                the previously hard-coded value of 5.
        """
        # Restore weights from log_dir (behaviour defined by Tools.restore_if_y)
        Tools.restore_if_y(self.sess, self.log_dir)

        for step in range(begin_step, end_step):
            final_batch_data, _, final_batch_class, _, _ = self.data_reader.next_batch_train()

            # Only the class predictions are reported, so only they are fetched
            pred_classes_r = self.sess.run(self.pred_classes,
                                           feed_dict={self.image_placeholder: final_batch_data})

            if step % save_pred_freq == 0:
                self.saver.save(self.sess, self.checkpoint_path, global_step=step)
                Tools.print_info('The checkpoint has been created.')

            Tools.print_info('step {:d} {} {}'.format(step, list(final_batch_class), list(pred_classes_r)))
class Train(object):
    """Train BAISNet jointly on segmentation, attention and image-level classification.

    The network's ``segments`` are split into two groups (see ``cal_loss``).
    The first ``len(segments) - attention_module_num`` outputs have
    ``num_segment`` channels and are trained with softmax cross-entropy
    against the segmentation label. The last ``attention_module_num`` outputs
    have 2 channels (background, attention). Their attention channel is
    trained with a weighted sigmoid cross-entropy against the attention label.
    """

    def __init__(self, batch_size, last_pool_size, input_size, log_dir, data_root_path, train_list,
                 data_path, annotation_path, class_path, model_name="model.ckpt", is_test=False):
        # Checkpoint-related settings
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Data-related settings
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21
        self.num_segment = 4  # decoder channels: other objects, attention, border, background
        self.segment_attention = 1  # index of the attention channel when there are 4 decoder channels
        # number of attention-module outputs decoded to 2 channels (background, attention)
        self.attention_module_num = 2

        # Model-related settings: input_size must be larger than 8 * last_pool_size
        self.ratio = 8
        self.last_pool_size = last_pool_size
        self.filter_number = 32

        # Training-related settings
        self.learning_rate = 5e-3
        self.num_steps = 500001
        self.print_step = 5 if is_test else 25

        # Data reader
        self.data_reader = Data(data_root_path=data_root_path, data_list=train_list,
                                data_path=data_path, annotation_path=annotation_path,
                                class_path=class_path, batch_size=self.batch_size,
                                image_size=self.input_size, is_test=is_test, has_255=True)

        # Inputs: 4-channel image; labels at 1/ratio of the input resolution
        self.image_placeholder = tf.placeholder(
            dtype=tf.float32, shape=(None, self.input_size[0], self.input_size[1], 4))
        label_shape = (None, self.input_size[0] // self.ratio, self.input_size[1] // self.ratio, 1)
        self.label_segment_placeholder = tf.placeholder(dtype=tf.int32, shape=label_shape)
        self.label_attention_placeholder = tf.placeholder(dtype=tf.int32, shape=label_shape)
        self.label_classes_placeholder = tf.placeholder(dtype=tf.int32, shape=(None,))

        # Network
        self.net = BAISNet(self.image_placeholder, is_training=True, num_classes=self.num_classes,
                           num_segment=self.num_segment, segment_attention=self.segment_attention,
                           last_pool_size=self.last_pool_size, filter_number=self.filter_number,
                           attention_module_num=self.attention_module_num)
        self.segments, self.attentions, self.classes = self.net.build()
        self.final_segment_logit = self.segments[0]
        self.final_class_logit = self.classes[0]

        # Predictions
        self.pred_segment = tf.cast(
            tf.expand_dims(tf.argmax(self.final_segment_logit, axis=-1), axis=-1), tf.int32)
        self.pred_classes = tf.cast(tf.argmax(self.final_class_logit, axis=-1), tf.int32)

        # Loss: labels are resized (nearest neighbour) to the final segment output's size
        self.label_batch = tf.image.resize_nearest_neighbor(
            self.label_segment_placeholder, tf.stack(self.final_segment_logit.get_shape()[1:3]))
        self.label_attention_batch = tf.image.resize_nearest_neighbor(
            self.label_attention_placeholder, tf.stack(self.final_segment_logit.get_shape()[1:3]))
        (self.loss, self.loss_segment_all, self.loss_class_all,
         self.loss_segments, self.loss_classes) = self.cal_loss(
            self.segments, self.classes, self.label_batch, self.label_attention_batch,
            self.label_classes_placeholder, self.num_segment,
            attention_module_num=self.attention_module_num)

        # Batch accuracy. The segmentation prediction is compared with the label
        # resized to the prediction's resolution, matching what the loss uses.
        self.accuracy_segment = tcm.accuracy(self.pred_segment, self.label_batch)
        self.accuracy_classes = tcm.accuracy(self.pred_classes, self.label_classes_placeholder)

        with tf.name_scope("train"):
            # Polynomial learning-rate decay: lr * (1 - step / num_steps) ** 0.9
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            self.learning_rate = tf.scalar_mul(
                tf.constant(self.learning_rate), tf.pow((1 - self.step_ph / self.num_steps), 0.9))
            self.train_op = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss)

            # Optional op that updates only variables whose name contains 'attention'
            # (this also covers 'class_attention')
            attention_trainable = [v for v in tf.trainable_variables() if 'attention' in v.name]
            print(len(attention_trainable))
            self.train_attention_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss, var_list=attention_trainable)

        # Summaries (part 1): graph-side summary ops
        with tf.name_scope("loss"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.scalar("loss_segment", self.loss_segment_all)
            tf.summary.scalar("loss_class", self.loss_class_all)
            for loss_segment_index, loss_segment in enumerate(self.loss_segments):
                tf.summary.scalar("loss_segment_{}".format(loss_segment_index), loss_segment)
            for loss_class_index, loss_class in enumerate(self.loss_classes):
                tf.summary.scalar("loss_class_{}".format(loss_class_index), loss_class)

        with tf.name_scope("accuracy"):
            tf.summary.scalar("accuracy_segment", self.accuracy_segment)
            tf.summary.scalar("accuracy_classes", self.accuracy_classes)

        with tf.name_scope("label"):
            # Input channels 0-2 are shown as the image and channel 3 as the mask
            split = tf.split(self.image_placeholder, num_or_size_splits=4, axis=3)
            tf.summary.image("0-mask", split[3])
            tf.summary.image("1-image", tf.concat(split[0:3], axis=3))
            tf.summary.image("2-label", tf.cast(self.label_segment_placeholder * 85, dtype=tf.uint8))
            tf.summary.image("3-attention",
                             tf.cast(self.label_attention_placeholder * 255, dtype=tf.uint8))

        with tf.name_scope("predict"):
            tf.summary.image("predict", tf.cast(self.pred_segment * 85, dtype=tf.uint8))

        with tf.name_scope("attention"):
            for attention_index, attention in enumerate(self.attentions):
                tf.summary.image("{}-attention".format(attention_index), attention)

        with tf.name_scope("sigmoid"):
            # Raw per-channel segment maps (no activation applied here). The channel
            # layout per output follows the same split as cal_loss.
            num_multi_channel = len(self.segments) - self.attention_module_num
            for segment_index, segment in enumerate(self.segments):
                if segment_index < num_multi_channel:
                    split = tf.split(segment, num_or_size_splits=self.num_segment, axis=3)
                    tf.summary.image("{}-other".format(segment_index), split[0])
                    tf.summary.image("{}-attention".format(segment_index), split[1])
                    tf.summary.image("{}-border".format(segment_index), split[2])
                    tf.summary.image("{}-background".format(segment_index), split[-1])
                else:
                    split = tf.split(segment, num_or_size_splits=2, axis=3)
                    tf.summary.image("{}-background".format(segment_index), split[0])
                    tf.summary.image("{}-attention".format(segment_index), split[1])

        self.summary_op = tf.summary.merge_all()

        # Session and saver
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

        # Summaries (part 2): writer
        self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)

    @staticmethod
    def cal_loss(segments, classes, label_segment, label_attention, label_classes, num_segment,
                 attention_module_num):
        """Build the combined segmentation / attention / classification loss.

        Args:
            segments: list of segmentation logits. The first
                ``len(segments) - attention_module_num`` have ``num_segment``
                channels. The last ``attention_module_num`` have 2 channels,
                and only the second (attention) channel is scored.
            classes: list of class logits, each scored against ``label_classes``.
            label_segment: integer segmentation label, flattened here to one value per element.
            label_attention: integer attention label, flattened here and cast to float32 targets.
            label_classes: image-level class ids.
            num_segment: channel count of the multi-channel segmentation outputs.
            attention_module_num: number of trailing 2-channel outputs.

        Returns:
            (loss, loss_segment_all, loss_class_all, loss_segments, loss_classes) where
            ``loss = loss_segment_all + 0.1 * loss_class_all`` and the ``*_all`` terms
            are means over the per-output losses.
        """
        label_segment = tf.reshape(label_segment, [-1])
        label_attention = tf.reshape(label_attention, [-1])
        num_multi_channel = len(segments) - attention_module_num

        loss_segments = []
        for segment_index, segment in enumerate(segments):
            if segment_index < num_multi_channel:
                now_loss_segment = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=label_segment, logits=tf.reshape(segment, [-1, num_segment])))
            else:
                # Sigmoid cross-entropy on the attention channel only, positives weighted
                # by 3; the resulting mean is scaled by 2
                attention_logit = tf.split(segment, num_or_size_splits=2, axis=3)[1]
                now_loss_segment = tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(
                    targets=tf.cast(label_attention, dtype=tf.float32),
                    logits=tf.reshape(attention_logit, [-1]), pos_weight=3)) * 2
            loss_segments.append(now_loss_segment)

        loss_classes = [
            tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=label_classes, logits=class_one))
            for class_one in classes
        ]

        loss_segment_all = tf.add_n(loss_segments) / len(loss_segments)
        loss_class_all = tf.add_n(loss_classes) / len(loss_classes)
        # Total loss
        loss = loss_segment_all + 0.1 * loss_class_all
        return loss, loss_segment_all, loss_class_all, loss_segments, loss_classes

    def train(self, save_pred_freq, begin_step=0):
        """Run the training loop from ``begin_step`` to ``self.num_steps``.

        A summary is written every ``self.print_step`` steps and a checkpoint
        every ``save_pred_freq`` steps; a progress line is printed every step.
        """
        # Restore weights from log_dir (behaviour defined by Tools.restore_if_y)
        Tools.restore_if_y(self.sess, self.log_dir)

        train_op = self.train_op  # use self.train_attention_op to update only attention variables
        base_fetches = [self.accuracy_segment, self.accuracy_classes, train_op, self.learning_rate,
                        self.loss_segment_all, self.loss_class_all, self.loss, self.pred_classes]

        for step in range(begin_step, self.num_steps):
            start_time = time.time()
            (final_batch_data, final_batch_ann, final_batch_ann_attention, final_batch_class,
             _, _) = self.data_reader.next_batch_train()
            feed_dict = {
                self.step_ph: step,
                self.image_placeholder: final_batch_data,
                self.label_segment_placeholder: final_batch_ann,
                self.label_attention_placeholder: final_batch_ann_attention,
                self.label_classes_placeholder: final_batch_class,
            }

            write_summary = step % self.print_step == 0
            fetches = (base_fetches + [self.summary_op]) if write_summary else base_fetches
            results = self.sess.run(fetches, feed_dict=feed_dict)
            (accuracy_segment_r, accuracy_classes_r, _, learning_rate_r, loss_segment_r,
             loss_classes_r, loss_r, pred_classes_r) = results[:len(base_fetches)]
            if write_summary:
                # Summaries (part 3): write this step's merged summary
                self.summary_writer.add_summary(results[-1], global_step=step)

            if step % save_pred_freq == 0:
                self.saver.save(self.sess, self.checkpoint_path, global_step=step)
                Tools.print_info('The checkpoint has been created.')

            duration = time.time() - start_time
            Tools.print_info(
                'step {:d} loss={:.3f} seg={:.3f} class={:.3f} acc={:.3f} acc_class={:.3f}'
                ' lr={:.6f} ({:.3f} s/step) {} {}'.format(
                    step, loss_r, loss_segment_r, loss_classes_r, accuracy_segment_r,
                    accuracy_classes_r, learning_rate_r, duration,
                    list(final_batch_class), list(pred_classes_r)))
class Train(object):
    """Train BAISNet for ``num_classes``-way segmentation with per-output supervision.

    Every tensor in ``segments`` is supervised against the segmentation label
    resized (nearest neighbour) to that tensor's spatial size. The total loss
    is the mean of the per-output losses.
    """

    def __init__(self, batch_size, input_size, log_dir, data_root_path, train_list, data_path,
                 annotation_path, class_path, model_name="model.ckpt", pretrain=None, is_test=False):
        # Checkpoint-related settings
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)
        self.pretrain = pretrain

        # Data-related settings
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21

        # Training-related settings
        self.learning_rate = 5e-3
        self.num_steps = 100001
        self.print_step = 10 if is_test else 100
        self.cal_step = 100 if is_test else 1000  # window length for running loss averages

        # Data reader
        self.data_reader = Data(data_root_path=data_root_path, data_list=train_list,
                                data_path=data_path, annotation_path=annotation_path,
                                class_path=class_path, batch_size=self.batch_size,
                                image_size=self.input_size, is_test=is_test)

        # Inputs
        self.image_placeholder = tf.placeholder(
            tf.float32, shape=(None, self.input_size[0], self.input_size[1], 3))
        self.label_seg_placeholder = tf.placeholder(
            tf.int32, shape=(None, self.input_size[0], self.input_size[1], 1))

        # Network
        self.net = BAISNet(self.image_placeholder, True, num_classes=self.num_classes)
        self.segments, self.features = self.net.build()

        # Loss
        self.loss, self.loss_segment_all, self.loss_segments = self.cal_loss(
            self.segments, self.label_seg_placeholder)

        with tf.name_scope("train"):
            # Polynomial learning-rate decay: lr * (1 - step / num_steps) ** 0.8
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            self.learning_rate = tf.scalar_mul(
                tf.constant(self.learning_rate), tf.pow((1 - self.step_ph / self.num_steps), 0.8))
            self.train_op = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss)

            # Optional op that updates only variables whose name contains 'segment_side'
            segment_side_trainable = [v for v in tf.trainable_variables() if 'segment_side' in v.name]
            print(len(segment_side_trainable))
            self.train_segment_side_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss, var_list=segment_side_trainable)

        # Summaries (part 1): graph-side summary ops
        with tf.name_scope("loss"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.scalar("loss_segment", self.loss_segment_all)
            for loss_segment_index, loss_segment in enumerate(self.loss_segments):
                tf.summary.scalar("loss_segment_{}".format(loss_segment_index), loss_segment)

        with tf.name_scope("label"):
            tf.summary.image("1-image", self.image_placeholder)
            tf.summary.image(
                "2-segment",
                tf.cast(self.label_seg_placeholder * 255 // self.num_classes, dtype=tf.uint8))

        with tf.name_scope("result"):
            # Show up to 3 images of the batch per segment output
            num_shown = min(3, self.batch_size)
            for segment_index, segment in enumerate(self.segments):
                # Scale class ids to [0, 255] in int32 and cast afterwards; doing the
                # multiplication in uint8 (as before) wraps around for class ids >= 2.
                class_map = tf.cast(tf.argmax(segment, axis=-1), dtype=tf.int32)
                class_image = tf.cast(class_map * 255 // self.num_classes, dtype=tf.uint8)
                for i in range(num_shown):
                    # Slicing (instead of tf.split by batch_size) also works when the
                    # fed batch is smaller than batch_size
                    tf.summary.image("predict-{}-{}".format(segment_index, i),
                                     tf.expand_dims(class_image[i:i + 1], axis=-1))

        self.summary_op = tf.summary.merge_all()

        # Session and saver
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

        # Summaries (part 2): writer
        self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)

    def cal_loss(self, segments, label_segment):
        """Build the segmentation loss averaged over all outputs in ``segments``.

        For each output the label is resized (nearest neighbour) to the output's
        spatial size, one-hot encoded with ``self.num_classes`` classes, and
        scored with ``weighted_cross_entropy_with_logits`` (pos_weight=1).
        Label values outside ``[0, num_classes)`` one-hot encode to all zeros.

        Returns:
            (loss, loss_segment_all, loss_segments); ``loss`` equals ``loss_segment_all``.
        """
        loss_segments = []
        for segment in segments:
            now_label_segment = tf.image.resize_nearest_neighbor(
                label_segment, tf.stack(segment.get_shape()[1:3]))
            flat_labels = tf.reshape(now_label_segment, [-1])
            now_loss_segment = tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(
                targets=tf.one_hot(flat_labels, depth=self.num_classes),
                logits=tf.reshape(segment, [-1, self.num_classes]), pos_weight=1))
            loss_segments.append(now_loss_segment)

        loss_segment_all = tf.add_n(loss_segments) / len(loss_segments)
        # Total loss
        loss = loss_segment_all
        return loss, loss_segment_all, loss_segments

    def train(self, save_pred_freq, begin_step=0):
        """Run the training loop from ``begin_step`` to ``self.num_steps``.

        Losses are averaged over windows of ``self.cal_step`` steps: ``avg_loss``
        is the mean so far in the current window and ``pre_avg_loss`` the mean of
        the previous window. The average counts only steps actually run, so it
        stays correct when resuming from an unaligned ``begin_step``.
        """
        # Restore weights (and/or pretrain weights) via Tools.restore_if_y
        Tools.restore_if_y(self.sess, self.log_dir, pretrain=self.pretrain)

        total_loss = 0.0
        window_steps = 0  # number of losses accumulated in total_loss
        pre_avg_loss = 0.0

        train_op = self.train_op  # use self.train_segment_side_op to update only segment_side vars
        base_fetches = [train_op, self.learning_rate, self.loss_segment_all, self.loss]

        for step in range(begin_step, self.num_steps):
            start_time = time.time()
            batch_data, batch_segment = self.data_reader.next_batch_train()
            feed_dict = {
                self.step_ph: step,
                self.image_placeholder: batch_data,
                self.label_seg_placeholder: batch_segment,
            }

            write_summary = step % self.print_step == 0
            fetches = (base_fetches + [self.summary_op]) if write_summary else base_fetches
            results = self.sess.run(fetches, feed_dict=feed_dict)
            _, learning_rate_r, loss_segment_r, loss_r = results[:len(base_fetches)]
            if write_summary:
                # Summaries (part 3): write this step's merged summary
                self.summary_writer.add_summary(results[-1], global_step=step)

            if step % save_pred_freq == 0:
                self.saver.save(self.sess, self.checkpoint_path, global_step=step)
                Tools.print_info('The checkpoint has been created.')

            duration = time.time() - start_time

            if step % self.cal_step == 0:
                # Close the current window and start a new one with this step's loss
                pre_avg_loss = total_loss / window_steps if window_steps else 0.0
                total_loss = loss_r
                window_steps = 1
            else:
                total_loss += loss_r
                window_steps += 1

            if step % (self.cal_step // 10) == 0:
                Tools.print_info(
                    'step {:d} pre_avg_loss={:.3f} avg_loss={:.3f} loss={:.3f} seg={:.3f} lr={:.6f} '
                    '({:.3f} s/step)'.format(step, pre_avg_loss, total_loss / window_steps, loss_r,
                                             loss_segment_r, learning_rate_r, duration))

            if write_summary:
                Tools.print_info("")
class Train(object):
    """Train PSPNet on multi-channel segmentation plus image-level classification.

    The segmentation output has 4 channels when ``has_255`` is True and 3 otherwise.
    The summaries label these channels as other, attention, border (if present)
    and background.
    """

    def __init__(self, batch_size, last_pool_size, input_size, log_dir, data_root_path, train_list,
                 data_path, annotation_path, class_path, model_name="model.ckpt", is_test=False):
        # Checkpoint-related settings
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Data-related settings
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21
        self.has_255 = True  # whether a border channel is predicted
        self.num_segment = 4 if self.has_255 else 3

        # Model-related settings: input_size must be larger than 8 * last_pool_size
        self.ratio = 8
        self.last_pool_size = last_pool_size
        self.filter_number = 32

        # Training-related settings
        self.learning_rate = 5e-3
        self.num_steps = 500001
        self.print_step = 1 if is_test else 25

        # Data reader
        self.data_reader = Data(data_root_path=data_root_path, data_list=train_list,
                                data_path=data_path, annotation_path=annotation_path,
                                class_path=class_path, batch_size=self.batch_size,
                                image_size=self.input_size, is_test=is_test, has_255=self.has_255)

        # Network (build_net reads self.learning_rate as a float and returns the
        # learning-rate tensor that replaces it)
        (self.image_placeholder, self.label_segment_placeholder, self.label_classes_placeholder,
         self.raw_output_segment, self.raw_output_classes, self.pred_segment, self.pred_classes,
         self.loss_segment, self.loss_classes, self.loss, self.accuracy_segment,
         self.accuracy_classes, self.step_ph, self.train_op, self.train_classes_op,
         self.learning_rate) = self.build_net()

        # Summaries (part 1): graph-side summary ops
        label_scale = 85 if self.has_255 else 127  # spread label ids over [0, 255] for display
        tf.summary.scalar("loss", self.loss)
        tf.summary.scalar("loss_segment", self.loss_segment)
        tf.summary.scalar("loss_classes", self.loss_classes)
        tf.summary.scalar("accuracy_segment", self.accuracy_segment)
        tf.summary.scalar("accuracy_classes", self.accuracy_classes)

        # Input channels 0-2 are shown as the image and channel 3 as the mask
        split = tf.split(self.image_placeholder, num_or_size_splits=4, axis=3)
        tf.summary.image("0-mask", split[3])
        tf.summary.image("1-image", tf.concat(split[0: 3], axis=3))
        tf.summary.image("2-label", tf.cast(self.label_segment_placeholder * label_scale, dtype=tf.uint8))

        split = tf.split(self.raw_output_segment, num_or_size_splits=self.num_segment, axis=3)
        tf.summary.image("3-attention", split[1])
        tf.summary.image("4-other class", split[0])
        tf.summary.image("5-background", split[-1])
        if self.has_255:
            tf.summary.image("5-border", split[2])
        tf.summary.image("6-pred_segment", tf.cast(self.pred_segment * label_scale, dtype=tf.uint8))
        self.summary_op = tf.summary.merge_all()

        # Session and saver
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

        # Summaries (part 2): writer
        self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)

    def build_net(self):
        """Build placeholders, PSPNet, predictions, losses, accuracies and train ops.

        Returns:
            (image_placeholder, label_segment_placeholder, label_classes_placeholder,
             raw_output_segment, raw_output_classes, pred_segment, pred_classes,
             loss_segment, loss_classes, loss, accuracy_segment, accuracy_classes,
             step_ph, train_op, train_classes_op, learning_rate)
        """
        # Inputs: 4-channel image; segmentation label at 1/ratio of the input resolution
        image_placeholder = tf.placeholder(dtype=tf.float32,
                                           shape=(None, self.input_size[0], self.input_size[1], 4))
        label_segment_placeholder = tf.placeholder(
            dtype=tf.int32,
            shape=(None, self.input_size[0] // self.ratio, self.input_size[1] // self.ratio, 1))
        label_classes_placeholder = tf.placeholder(dtype=tf.int32, shape=(None,))

        # Network
        net = PSPNet({'data': image_placeholder}, is_training=True, num_classes=self.num_classes,
                     num_segment=self.num_segment, last_pool_size=self.last_pool_size,
                     filter_number=self.filter_number)
        raw_output_segment = net.layers['conv6_n_4']
        raw_output_classes = net.layers['class_attention_fc']

        # Predictions
        prediction = tf.reshape(raw_output_segment, [-1, self.num_segment])
        pred_segment = tf.cast(tf.expand_dims(tf.argmax(raw_output_segment, axis=-1), axis=-1), tf.int32)
        pred_classes = tf.cast(tf.argmax(raw_output_classes, axis=-1), tf.int32)

        # Label resized (nearest neighbour) to the segmentation output's spatial size;
        # the 4-D copy is used for accuracy, the flattened one for the loss
        label_resized = tf.image.resize_nearest_neighbor(
            label_segment_placeholder, tf.stack(raw_output_segment.get_shape()[1:3]))
        label_batch = tf.reshape(label_resized, [-1])

        # Batch accuracy; prediction and label are compared at the same resolution
        accuracy_segment = tcm.accuracy(pred_segment, label_resized)
        accuracy_classes = tcm.accuracy(pred_classes, label_classes_placeholder)

        # Segmentation loss
        loss_segment = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=label_batch, logits=prediction))
        # Classification loss
        loss_classes = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=label_classes_placeholder, logits=raw_output_classes))
        # Total loss
        loss = tf.add_n([loss_segment, 0.1 * loss_classes])

        # Polynomial learning-rate decay: lr * (1 - step / num_steps) ** 0.9
        step_ph = tf.placeholder(dtype=tf.float32, shape=())
        learning_rate = tf.scalar_mul(tf.constant(self.learning_rate),
                                      tf.pow((1 - step_ph / self.num_steps), 0.9))
        train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

        # Optional op that updates only the 'class_attention' variables
        classes_trainable = [v for v in tf.trainable_variables() if 'class_attention' in v.name]
        train_classes_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(
            loss, var_list=classes_trainable)

        return (image_placeholder, label_segment_placeholder, label_classes_placeholder,
                raw_output_segment, raw_output_classes, pred_segment, pred_classes, loss_segment,
                loss_classes, loss, accuracy_segment, accuracy_classes, step_ph, train_op,
                train_classes_op, learning_rate)

    def train(self, save_pred_freq, begin_step=0):
        """Run the training loop from ``begin_step`` to ``self.num_steps``.

        A summary is written every ``self.print_step`` steps and a checkpoint
        every ``save_pred_freq`` steps; a progress line is printed every step.
        """
        # Restore weights from log_dir (behaviour defined by Tools.restore_if_y)
        Tools.restore_if_y(self.sess, self.log_dir)

        train_op = self.train_op  # use self.train_classes_op to update only the classifier
        base_fetches = [self.accuracy_segment, self.accuracy_classes, train_op, self.learning_rate,
                        self.loss_segment, self.loss_classes, self.loss, self.pred_classes]

        for step in range(begin_step, self.num_steps):
            start_time = time.time()
            final_batch_data, final_batch_ann, final_batch_class, _, _ = \
                self.data_reader.next_batch_train()
            feed_dict = {
                self.step_ph: step,
                self.image_placeholder: final_batch_data,
                self.label_segment_placeholder: final_batch_ann,
                self.label_classes_placeholder: final_batch_class,
            }

            write_summary = step % self.print_step == 0
            fetches = (base_fetches + [self.summary_op]) if write_summary else base_fetches
            results = self.sess.run(fetches, feed_dict=feed_dict)
            (accuracy_segment_r, accuracy_classes_r, _, learning_rate_r, loss_segment_r,
             loss_classes_r, loss_r, pred_classes_r) = results[:len(base_fetches)]
            if write_summary:
                # Summaries (part 3): write this step's merged summary
                self.summary_writer.add_summary(results[-1], global_step=step)

            if step % save_pred_freq == 0:
                self.saver.save(self.sess, self.checkpoint_path, global_step=step)
                Tools.print_info('The checkpoint has been created.')

            duration = time.time() - start_time
            Tools.print_info(
                'step {:d} loss={:.3f} seg={:.3f} class={:.3f} acc={:.3f} acc_class={:.3f}'
                ' lr={:.6f} ({:.3f} s/step) {} {}'.format(
                    step, loss_r, loss_segment_r, loss_classes_r, accuracy_segment_r,
                    accuracy_classes_r, learning_rate_r, duration,
                    list(final_batch_class), list(pred_classes_r)))
class Train(object):
    """Train BAISNet (image + mask input) on attention maps and image-level classes.

    The optimized loss is the mean attention loss plus the mean class loss.
    The segment outputs are visualised but not part of the loss. Note that
    ``loss_segment_all`` / ``loss_segments`` hold the *attention* losses
    returned by ``cal_loss``.
    """

    def __init__(self, batch_size, input_size, log_dir, data_root_path, train_list, data_path,
                 annotation_path, class_path, model_name="model.ckpt", pretrain=None, is_test=False):
        # Checkpoint-related settings
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)
        self.pretrain = pretrain

        # Data-related settings
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21

        # Training-related settings
        self.learning_rate = 5e-4
        self.num_steps = 500001
        self.print_step = 10 if is_test else 100
        self.cal_step = 100 if is_test else 1000  # window length for running averages

        # Data reader
        self.data_reader = Data(data_root_path=data_root_path, data_list=train_list,
                                data_path=data_path, annotation_path=annotation_path,
                                class_path=class_path, batch_size=self.batch_size,
                                image_size=self.input_size, is_test=is_test)

        # Inputs
        height, width = self.input_size[0], self.input_size[1]
        self.image_placeholder = tf.placeholder(tf.float32, shape=(None, height, width, 3))
        self.mask_placeholder = tf.placeholder(tf.float32, shape=(None, height, width, 1))
        self.label_seg_placeholder = tf.placeholder(tf.int32, shape=(None, height, width, 1))
        self.label_cls_placeholder = tf.placeholder(tf.int32, shape=(None,))

        # Network
        self.net = BAISNet(self.image_placeholder, self.mask_placeholder, True,
                           num_classes=self.num_classes)
        self.segments, self.attentions, self.classes = self.net.build()

        # Loss (the "segment" names below carry the attention losses)
        (self.loss, self.loss_segment_all, self.loss_class_all,
         self.loss_segments, self.loss_classes) = self.cal_loss(
            self.segments, self.attentions, self.classes,
            self.label_seg_placeholder, self.label_cls_placeholder)

        # Batch classification accuracy, from the first class output
        self.pred_classes = tf.cast(tf.argmax(self.classes[0], axis=-1), tf.int32)
        self.accuracy_classes = tcm.accuracy(self.pred_classes, self.label_cls_placeholder)

        with tf.name_scope("train"):
            # Polynomial learning-rate decay: lr * (1 - step / num_steps) ** 0.9
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            self.learning_rate = tf.scalar_mul(
                tf.constant(self.learning_rate), tf.pow((1 - self.step_ph / self.num_steps), 0.9))
            self.train_op = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss)

            # Optional op that updates only variables whose name contains 'attention'
            # (this also covers 'class_attention')
            attention_trainable = [v for v in tf.trainable_variables() if 'attention' in v.name]
            print(len(attention_trainable))
            self.train_attention_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss, var_list=attention_trainable)

        # Summaries (part 1): graph-side summary ops
        with tf.name_scope("loss"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.scalar("loss_segment", self.loss_segment_all)
            tf.summary.scalar("loss_class", self.loss_class_all)
            for loss_segment_index, loss_segment in enumerate(self.loss_segments):
                tf.summary.scalar("loss_segment_{}".format(loss_segment_index), loss_segment)
            for loss_class_index, loss_class in enumerate(self.loss_classes):
                tf.summary.scalar("loss_class_{}".format(loss_class_index), loss_class)

        with tf.name_scope("accuracy"):
            tf.summary.scalar("accuracy_classes", self.accuracy_classes)

        with tf.name_scope("label"):
            tf.summary.image("1-image", self.image_placeholder)
            tf.summary.image("2-segment", tf.cast(self.label_seg_placeholder * 255, dtype=tf.uint8))
            tf.summary.image("3-mask", tf.cast(self.mask_placeholder * 255, dtype=tf.uint8))

        with tf.name_scope("result"):
            # Sigmoid of channel 1 of each 2-channel segment / attention output
            for segment_index, segment in enumerate(self.segments):
                segment = tf.split(segment, num_or_size_splits=2, axis=3)
                tf.summary.image("predict-{}-1".format(segment_index), tf.nn.sigmoid(segment[1]))
            for attention_index, attention in enumerate(self.attentions):
                attention = tf.split(attention, num_or_size_splits=2, axis=3)
                tf.summary.image("attention-{}-1".format(attention_index), tf.nn.sigmoid(attention[1]))

        self.summary_op = tf.summary.merge_all()

        # Session and saver
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

        # Summaries (part 2): writer
        self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)

    @staticmethod
    def cal_loss(segments, attentions, classes, label_segment, label_classes):
        """Build the attention + classification loss.

        Args:
            segments: segment outputs; accepted for interface compatibility but
                not used in the loss.
            attentions: list of 2-channel attention logits. For each one the label
                is resized (nearest neighbour) to its spatial size, one-hot encoded
                with depth 2, and scored with a weighted sigmoid cross-entropy
                (pos_weight=3).
            classes: list of class logits, each scored against ``label_classes``.
            label_segment: integer label used as the attention target.
            label_classes: image-level class ids.

        Returns:
            (loss, loss_attention_all, loss_class_all, loss_attentions, loss_classes)
            with ``loss = loss_attention_all + loss_class_all``.
        """
        loss_attentions = []
        for attention in attentions:
            now_label_segment = tf.image.resize_nearest_neighbor(
                label_segment, tf.stack(attention.get_shape()[1:3]))
            now_loss_attention = tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(
                targets=tf.one_hot(tf.reshape(now_label_segment, [-1]), depth=2),
                logits=tf.reshape(attention, [-1, 2]), pos_weight=3))
            loss_attentions.append(now_loss_attention)

        loss_classes = [
            tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=label_classes, logits=class_one))
            for class_one in classes
        ]

        loss_attention_all = tf.add_n(loss_attentions) / len(loss_attentions)
        loss_class_all = tf.add_n(loss_classes) / len(loss_classes)
        # Total loss
        loss = loss_attention_all + loss_class_all
        return loss, loss_attention_all, loss_class_all, loss_attentions, loss_classes

    def train(self, save_pred_freq, begin_step=0):
        """Run the training loop from ``begin_step`` to ``self.num_steps``.

        Loss and class accuracy are averaged over windows of ``self.cal_step``
        steps. ``avg_*`` is the mean so far in the current window and
        ``pre_avg_*`` the mean of the previous window. Averages count only steps
        actually run, so they stay correct when resuming from an unaligned
        ``begin_step``.
        """
        # Restore weights (and/or pretrain weights) via Tools.restore_if_y
        Tools.restore_if_y(self.sess, self.log_dir, pretrain=self.pretrain)

        total_loss = 0.0
        pre_avg_loss = 0.0
        total_acc = 0.0
        pre_avg_acc = 0.0
        window_steps = 0  # number of steps accumulated in total_loss / total_acc

        train_op = self.train_op  # use self.train_attention_op to update only attention variables
        base_fetches = [self.accuracy_classes, train_op, self.learning_rate, self.loss_segment_all,
                        self.loss_class_all, self.loss, self.pred_classes]

        for step in range(begin_step, self.num_steps):
            start_time = time.time()
            batch_data, batch_mask, batch_attention, batch_class = self.data_reader.next_batch_train()
            feed_dict = {
                self.step_ph: step,
                self.image_placeholder: batch_data,
                self.mask_placeholder: batch_mask,
                self.label_seg_placeholder: batch_attention,
                self.label_cls_placeholder: batch_class,
            }

            write_summary = step % self.print_step == 0
            fetches = (base_fetches + [self.summary_op]) if write_summary else base_fetches
            results = self.sess.run(fetches, feed_dict=feed_dict)
            (accuracy_classes_r, _, learning_rate_r, loss_segment_r, loss_classes_r, loss_r,
             pred_classes_r) = results[:len(base_fetches)]
            if write_summary:
                # Summaries (part 3): write this step's merged summary
                self.summary_writer.add_summary(results[-1], global_step=step)

            if step % save_pred_freq == 0:
                self.saver.save(self.sess, self.checkpoint_path, global_step=step)
                Tools.print_info('The checkpoint has been created.')

            duration = time.time() - start_time

            if step % self.cal_step == 0:
                # Close the current window and start a new one with this step's values
                pre_avg_loss = total_loss / window_steps if window_steps else 0.0
                pre_avg_acc = total_acc / window_steps if window_steps else 0.0
                total_loss = loss_r
                total_acc = accuracy_classes_r
                window_steps = 1
            else:
                total_loss += loss_r
                total_acc += accuracy_classes_r
                window_steps += 1

            if step % (self.cal_step // 10) == 0:
                Tools.print_info(
                    'step {:d} pre_avg_loss={:.3f} avg_loss={:.3f} pre_avg_acc={:.3f} avg_acc={:.3f} '
                    'loss={:.3f} seg={:.3f} class={:.3f} acc_class={:.3f} '
                    'lr={:.6f} ({:.3f} s/step) {} {}'.format(
                        step, pre_avg_loss, total_loss / window_steps, pre_avg_acc,
                        total_acc / window_steps, loss_r, loss_segment_r, loss_classes_r,
                        accuracy_classes_r, learning_rate_r, duration,
                        list(batch_class), list(pred_classes_r)))

            if write_summary:
                Tools.print_info("")