def __init__(self, log_dir, save_dir):
    """Prepare the output directories and fix the network input size.

    Args:
        log_dir: directory handed to ``Tools.new_dir``; its return value is
            stored as ``self.log_dir``.
        save_dir: directory handed to ``Tools.new_dir``; its return value is
            stored as ``self.save_dir``.
    """
    self.save_dir = Tools.new_dir(save_dir)
    self.log_dir = Tools.new_dir(log_dir)

    # The input is square, eight times the last pooling size on each side.
    self.last_pool_size = 50
    side = self.last_pool_size * 8
    self.input_size = [side, side]
def inference(self, image_path, image_index, save_path=None):
    """Run the network on one image, log its summary, then show or save it.

    Args:
        image_path: path of the image, loaded through ``Data.load_data``
            with ``self.input_size``.
        image_index: value used as ``global_step`` for the written summary.
        save_path: when ``None`` the result is displayed with
            ``Image.show``; otherwise the directory is created through
            ``Tools.new_dir`` and the result is saved there as
            ``<image base name>.bmp``.
    """
    loaded = Data.load_data(image_path=image_path, input_size=self.input_size)
    # The placeholder expects a batch, so add a leading axis of size one.
    batch = np.expand_dims(loaded, axis=0)

    fetches = [self.pred_segment, self.summary_op]
    pred_segment_r, summary_now = self.sess.run(
        fetches, feed_dict={self.image_placeholder: batch})
    self.summary_writer.add_summary(summary_now, global_step=image_index)

    # NOTE(review): if the prediction is uint8, `* 255` wraps modulo 256 for
    # class ids above 1 -- confirm this is the intended visualisation.
    pixels = np.asarray(np.squeeze(pred_segment_r) * 255, dtype=np.uint8)
    s_image = Image.fromarray(pixels)

    if save_path is None:
        s_image.show()
        return

    Tools.new_dir(save_path)
    stem = os.path.splitext(os.path.basename(image_path))[0]
    s_image.convert("L").save("{}/{}.bmp".format(save_path, stem))
def __init__(self, input_size, log_dir, model_name="model.ckpt"):
    """Build the inference graph, then create the session and the saver.

    Args:
        input_size: indexable pair; elements 0 and 1 give the two spatial
            dimensions of the image placeholder.
        log_dir: directory handed to ``Tools.new_dir``; the checkpoint path
            is ``log_dir`` joined with ``model_name``.
        model_name: checkpoint file name.
    """
    # Checkpoint location.
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

    # Data parameters.
    self.input_size = input_size
    self.num_classes = 21

    # Input: float32 batches with 3 channels.
    image_shape = (None, self.input_size[0], self.input_size[1], 3)
    self.image_placeholder = tf.placeholder(tf.float32, shape=image_shape)

    # Network (second positional argument is False here); the prediction is
    # the per-pixel argmax of the first output.
    self.net = BAISNet(self.image_placeholder, False,
                       num_classes=self.num_classes)
    self.segments = self.net.build()
    class_ids = tf.argmax(self.segments[0], axis=-1)
    self.pred_segment = tf.cast(class_ids, dtype=tf.uint8)

    # Session and saver.
    gpu_options = tf.GPUOptions(allow_growth=True)
    self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                max_to_keep=10)
def __init__(self, input_size, summary_dir, log_dir, model_name="model.ckpt"):
    """Build a feature-visualisation graph with image summaries.

    Args:
        input_size: indexable pair; elements 0 and 1 give the two spatial
            dimensions of the image placeholder.
        summary_dir: directory given to ``tf.summary.FileWriter``.
        log_dir: directory handed to ``Tools.new_dir``; the checkpoint path
            is ``log_dir`` joined with ``model_name``.
        model_name: checkpoint file name.
    """
    # Checkpoint location.
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

    # Data parameters.
    self.input_size = input_size
    self.num_classes = 21

    # Input: float32 batches with 3 channels.
    image_shape = (None, self.input_size[0], self.input_size[1], 3)
    self.image_placeholder = tf.placeholder(tf.float32, shape=image_shape)

    # Network features.
    self.features = self._feature(self.image_placeholder)

    with tf.name_scope("image"):
        tf.summary.image("input", self.image_placeholder)

    # Every channel of every feature except the last one gets its own
    # single-channel image summary.
    with tf.name_scope("block"):
        for feature_index, feature in enumerate(self.features[:-1]):
            channel_count = int(feature.shape[-1])
            channels = tf.split(feature, num_or_size_splits=channel_count,
                                axis=-1)
            for channel_index, channel in enumerate(channels):
                tag = "{}-{}".format(feature_index, channel_index)
                tf.summary.image(tag, channel)

    self.summary_op = tf.summary.merge_all()

    # Session and saver.
    gpu_options = tf.GPUOptions(allow_growth=True)
    self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                max_to_keep=10)
    self.summary_writer = tf.summary.FileWriter(summary_dir, self.sess.graph)
def __init__(self, batch_size, last_pool_size, input_size, log_dir,
             data_root_path, train_list, data_path, annotation_path,
             class_path, model_name="model.ckpt", is_test=False):
    """Set up the data reader, the network from ``build_net`` and the session.

    Args:
        batch_size: stored and forwarded to ``Data``.
        last_pool_size: stored as ``self.last_pool_size``.
        input_size: stored and forwarded to ``Data`` as ``image_size``.
        log_dir: directory handed to ``Tools.new_dir``; the checkpoint path
            is ``log_dir`` joined with ``model_name``.
        data_root_path, train_list, data_path, annotation_path, class_path:
            forwarded to ``Data`` (``train_list`` as ``data_list``).
        model_name: checkpoint file name.
        is_test: forwarded to ``Data``.
    """
    # Checkpoint location.
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

    # Data parameters.
    self.input_size = input_size
    self.batch_size = batch_size
    self.num_classes = 21
    self.num_segment = 1

    # Model parameters: input_size must be larger than 8x last_pool_size.
    self.ratio = 8
    self.last_pool_size = last_pool_size
    self.filter_number = 32

    # Data reader.
    self.data_reader = Data(
        data_root_path=data_root_path,
        data_list=train_list,
        data_path=data_path,
        annotation_path=annotation_path,
        class_path=class_path,
        batch_size=self.batch_size,
        image_size=self.input_size,
        is_test=is_test,
    )

    # Network.
    (self.image_placeholder,
     self.raw_output_segment,
     self.raw_output_classes,
     self.pred_segment,
     self.pred_classes) = self.build_net()

    # Session and saver.
    gpu_options = tf.GPUOptions(allow_growth=True)
    self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                max_to_keep=10)
def __init__(self, input_size, summary_dir, log_dir, model_name="model.ckpt"):
    """Build the BAISNet inference graph with prediction and feature summaries.

    Args:
        input_size: indexable pair; elements 0 and 1 give the two spatial
            dimensions of the image placeholder.
        summary_dir: directory given to ``tf.summary.FileWriter``.
        log_dir: directory handed to ``Tools.new_dir``; the checkpoint path
            is ``log_dir`` joined with ``model_name``.
        model_name: checkpoint file name.
    """
    # Checkpoint location.
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

    # Data parameters.
    self.input_size = input_size
    self.num_classes = 21

    # Input: float32 batches with 3 channels.
    image_shape = (None, self.input_size[0], self.input_size[1], 3)
    self.image_placeholder = tf.placeholder(tf.float32, shape=image_shape)

    # Network; the prediction is the per-pixel argmax of the first output.
    self.net = BAISNet(self.image_placeholder, False,
                       num_classes=self.num_classes)
    self.segments, self.features = self.net.build()
    self.pred_segment = tf.cast(tf.argmax(self.segments[0], axis=-1),
                                dtype=tf.uint8)

    with tf.name_scope("image"):
        tf.summary.image("input", self.image_placeholder)
        # One prediction image per segmentation output.
        # NOTE(review): class_map is uint8, so `* 255` wraps modulo 256 for
        # class ids above 1 -- confirm this is the intended visualisation.
        for segment_index, segment_logit in enumerate(self.segments):
            class_map = tf.cast(tf.argmax(segment_logit, axis=-1),
                                dtype=tf.uint8)
            tf.summary.image("predict-{}".format(segment_index),
                             tf.expand_dims(class_map * 255, axis=-1))

    # Every channel of every feature gets its own image summary, grouped in
    # a name scope named after the feature key.
    for scope_name in self.features.keys():
        with tf.name_scope(scope_name):
            for feature_index, feature in enumerate(self.features[scope_name]):
                channels = tf.split(
                    feature,
                    num_or_size_splits=int(feature.shape[-1]),
                    axis=-1)
                for channel_index, channel in enumerate(channels):
                    tag = "{}-{}".format(feature_index, channel_index)
                    tf.summary.image(tag, channel)

    self.summary_op = tf.summary.merge_all()

    # Session and saver.
    gpu_options = tf.GPUOptions(allow_growth=True)
    self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                max_to_keep=10)
    self.summary_writer = tf.summary.FileWriter(summary_dir, self.sess.graph)
def __init__(self, batch_size, last_pool_size, input_size, log_dir,
             data_root_path, train_list, data_path, annotation_path,
             class_path, model_name="model.ckpt", is_test=False):
    """Build the attention training graph, its summaries and the session.

    Args:
        batch_size: stored and forwarded to ``Data``.
        last_pool_size: stored and forwarded to ``BAISNet``.
        input_size: indexable pair; elements 0 and 1 size the placeholders.
        log_dir: directory handed to ``Tools.new_dir``; used for the
            checkpoint path and for the summary writer.
        data_root_path, train_list, data_path, annotation_path, class_path:
            forwarded to ``Data`` (``train_list`` as ``data_list``).
        model_name: checkpoint file name.
        is_test: forwarded to ``Data``; also selects ``print_step``
            (5 when truthy, otherwise 25).
    """
    # Checkpoint location.
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

    # Data parameters.
    self.input_size = input_size
    self.batch_size = batch_size
    self.num_classes = 21
    # Decoded channels: other objects, attention, border, background.
    self.num_segment = 4
    # Position of the attention channel when four channels are decoded.
    self.segment_attention = 1
    # Number of attention modules that decode two channels
    # (background, attention).
    self.attention_module_num = 2

    # Model parameters: input_size must be larger than 8x last_pool_size.
    self.ratio = 8
    self.last_pool_size = last_pool_size
    self.filter_number = 32

    # Training parameters.
    self.learning_rate = 5e-3
    self.num_steps = 500001
    if is_test:
        self.print_step = 5
    else:
        self.print_step = 25

    # Data reader.
    self.data_reader = Data(
        data_root_path=data_root_path,
        data_list=train_list,
        data_path=data_path,
        annotation_path=annotation_path,
        class_path=class_path,
        batch_size=self.batch_size,
        image_size=self.input_size,
        is_test=is_test,
        has_255=True,
    )

    # Placeholders: a 4-channel input plus label maps at 1/ratio resolution.
    self.image_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=(None, self.input_size[0], self.input_size[1], 4))
    label_shape = (None,
                   self.input_size[0] // self.ratio,
                   self.input_size[1] // self.ratio,
                   1)
    self.label_segment_placeholder = tf.placeholder(dtype=tf.int32,
                                                    shape=label_shape)
    self.label_attention_placeholder = tf.placeholder(dtype=tf.int32,
                                                      shape=label_shape)
    self.label_classes_placeholder = tf.placeholder(dtype=tf.int32,
                                                    shape=(None, ))

    # Network.
    self.net = BAISNet(
        self.image_placeholder,
        is_training=True,
        num_classes=self.num_classes,
        num_segment=self.num_segment,
        segment_attention=self.segment_attention,
        last_pool_size=self.last_pool_size,
        filter_number=self.filter_number,
        attention_module_num=self.attention_module_num,
    )
    self.segments, self.attentions, self.classes = self.net.build()
    self.final_segment_logit = self.segments[0]
    self.final_class_logit = self.classes[0]

    # Predictions.
    segment_ids = tf.argmax(self.final_segment_logit, axis=-1)
    self.pred_segment = tf.cast(tf.expand_dims(segment_ids, axis=-1),
                                tf.int32)
    self.pred_classes = tf.cast(tf.argmax(self.final_class_logit, axis=-1),
                                tf.int32)

    # Loss: both label maps are resized to the spatial size of the final
    # segment logit before being passed to cal_loss.
    logit_hw = self.final_segment_logit.get_shape()[1:3]
    self.label_batch = tf.image.resize_nearest_neighbor(
        self.label_segment_placeholder, tf.stack(logit_hw))
    self.label_attention_batch = tf.image.resize_nearest_neighbor(
        self.label_attention_placeholder, tf.stack(logit_hw))
    (self.loss,
     self.loss_segment_all,
     self.loss_class_all,
     self.loss_segments,
     self.loss_classes) = self.cal_loss(
         self.segments, self.classes, self.label_batch,
         self.label_attention_batch, self.label_classes_placeholder,
         self.num_segment, attention_module_num=self.attention_module_num)

    # Accuracy of the current batch.
    self.accuracy_segment = tcm.accuracy(self.pred_segment,
                                         self.label_segment_placeholder)
    self.accuracy_classes = tcm.accuracy(self.pred_classes,
                                         self.label_classes_placeholder)

    with tf.name_scope("train"):
        # Learning-rate schedule: base_lr * (1 - step / num_steps) ** 0.9.
        # self.learning_rate is rebound from the float to the tensor.
        self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
        base_lr = tf.constant(self.learning_rate)
        decay = tf.pow((1 - self.step_ph / self.num_steps), 0.9)
        self.learning_rate = tf.scalar_mul(base_lr, decay)
        self.train_op = tf.train.GradientDescentOptimizer(
            self.learning_rate).minimize(self.loss)

        # Separate op that only updates the attention variables. Any name
        # containing "class_attention" also contains "attention", so a
        # single substring test selects the same variables.
        attention_trainable = [
            v for v in tf.trainable_variables() if 'attention' in v.name
        ]
        print(len(attention_trainable))
        self.train_attention_op = tf.train.GradientDescentOptimizer(
            self.learning_rate).minimize(self.loss,
                                         var_list=attention_trainable)

    # Summaries.
    with tf.name_scope("loss"):
        tf.summary.scalar("loss", self.loss)
        tf.summary.scalar("loss_segment", self.loss_segment_all)
        tf.summary.scalar("loss_class", self.loss_class_all)
        for index, one_loss in enumerate(self.loss_segments):
            tf.summary.scalar("loss_segment_{}".format(index), one_loss)
        for index, one_loss in enumerate(self.loss_classes):
            tf.summary.scalar("loss_class_{}".format(index), one_loss)

    with tf.name_scope("accuracy"):
        tf.summary.scalar("accuracy_segment", self.accuracy_segment)
        tf.summary.scalar("accuracy_classes", self.accuracy_classes)

    with tf.name_scope("label"):
        # The input's fourth channel is logged as the mask, the first three
        # as the image.
        input_channels = tf.split(self.image_placeholder,
                                  num_or_size_splits=4, axis=3)
        tf.summary.image("0-mask", input_channels[3])
        tf.summary.image("1-image", tf.concat(input_channels[0:3], axis=3))
        tf.summary.image(
            "2-label",
            tf.cast(self.label_segment_placeholder * 85, dtype=tf.uint8))
        tf.summary.image(
            "3-attention",
            tf.cast(self.label_attention_placeholder * 255, dtype=tf.uint8))

    with tf.name_scope("predict"):
        tf.summary.image("predict",
                         tf.cast(self.pred_segment * 85, dtype=tf.uint8))

    with tf.name_scope("attention"):
        for attention_index, attention in enumerate(self.attentions):
            tf.summary.image("{}-attention".format(attention_index),
                             attention)

    with tf.name_scope("sigmoid"):
        for segment_index, segment in enumerate(self.segments):
            if segment_index < self.attention_module_num:
                parts = tf.split(segment,
                                 num_or_size_splits=self.num_segment,
                                 axis=3)
                tagged = [
                    ("{}-other", parts[0]),
                    ("{}-attention", parts[1]),
                    ("{}-border", parts[2]),
                    ("{}-background", parts[-1]),
                ]
            else:
                parts = tf.split(segment, num_or_size_splits=2, axis=3)
                tagged = [
                    ("{}-background", parts[0]),
                    ("{}-attention", parts[1]),
                ]
            for template, channel in tagged:
                tf.summary.image(template.format(segment_index), channel)

    self.summary_op = tf.summary.merge_all()

    # Session and saver.
    gpu_options = tf.GPUOptions(allow_growth=True)
    self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                max_to_keep=10)

    # Summary writer.
    self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
def __init__(self, batch_size, last_pool_size, input_size, log_dir,
             data_root_path, train_list, data_path, annotation_path,
             class_path, model_name="model.ckpt", is_test=False):
    """Set up data, the graph from ``build_net``, summaries and the session.

    Args:
        batch_size: stored and forwarded to ``Data``.
        last_pool_size: stored as ``self.last_pool_size``.
        input_size: stored and forwarded to ``Data`` as ``image_size``.
        log_dir: directory handed to ``Tools.new_dir``; used for the
            checkpoint path and for the summary writer.
        data_root_path, train_list, data_path, annotation_path, class_path:
            forwarded to ``Data`` (``train_list`` as ``data_list``).
        model_name: checkpoint file name.
        is_test: forwarded to ``Data``; also selects ``print_step``
            (1 when truthy, otherwise 25).
    """
    # Checkpoint location.
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

    # Data parameters.
    self.input_size = input_size
    self.batch_size = batch_size
    self.num_classes = 21
    self.has_255 = True  # whether the border is predicted
    if self.has_255:
        self.num_segment = 4
    else:
        self.num_segment = 3

    # Model parameters: input_size must be larger than 8x last_pool_size.
    self.ratio = 8
    self.last_pool_size = last_pool_size
    self.filter_number = 32

    # Training parameters.
    self.learning_rate = 5e-3
    self.num_steps = 500001
    self.print_step = 1 if is_test else 25

    # Data reader.
    self.data_reader = Data(
        data_root_path=data_root_path,
        data_list=train_list,
        data_path=data_path,
        annotation_path=annotation_path,
        class_path=class_path,
        batch_size=self.batch_size,
        image_size=self.input_size,
        is_test=is_test,
        has_255=self.has_255,
    )

    # Network: build_net supplies every tensor and op; self.learning_rate is
    # rebound to the value it returns.
    (self.image_placeholder,
     self.label_segment_placeholder,
     self.label_classes_placeholder,
     self.raw_output_segment,
     self.raw_output_classes,
     self.pred_segment,
     self.pred_classes,
     self.loss_segment,
     self.loss_classes,
     self.loss,
     self.accuracy_segment,
     self.accuracy_classes,
     self.step_ph,
     self.train_op,
     self.train_classes_op,
     self.learning_rate) = self.build_net()

    # Scalar summaries.
    tf.summary.scalar("loss", self.loss)
    tf.summary.scalar("loss_segment", self.loss_segment)
    tf.summary.scalar("loss_classes", self.loss_classes)
    tf.summary.scalar("accuracy_segment", self.accuracy_segment)
    tf.summary.scalar("accuracy_classes", self.accuracy_classes)

    # Image summaries: the input's fourth channel is logged as the mask,
    # the first three as the image.
    input_channels = tf.split(self.image_placeholder,
                              num_or_size_splits=4, axis=3)
    tf.summary.image("0-mask", input_channels[3])
    tf.summary.image("1-image", tf.concat(input_channels[0: 3], axis=3))

    # Grey-level step used to spread the label ids over 0..255.
    gray_step = 85 if self.has_255 else 127
    tf.summary.image(
        "2-label",
        tf.cast(self.label_segment_placeholder * gray_step, dtype=tf.uint8))

    output_channels = tf.split(self.raw_output_segment,
                               num_or_size_splits=self.num_segment, axis=3)
    tf.summary.image("3-attention", output_channels[1])
    tf.summary.image("4-other class", output_channels[0])
    tf.summary.image("5-background", output_channels[-1])
    if self.has_255:
        tf.summary.image("5-border", output_channels[2])
    tf.summary.image(
        "6-pred_segment",
        tf.cast(self.pred_segment * gray_step, dtype=tf.uint8))

    self.summary_op = tf.summary.merge_all()

    # Session and saver.
    gpu_options = tf.GPUOptions(allow_growth=True)
    self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                max_to_keep=10)

    # Summary writer.
    self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
def __init__(self, batch_size, input_size, log_dir, data_root_path,
             train_list, data_path, annotation_path, class_path,
             model_name="model.ckpt", pretrain=None, is_test=False):
    """Build the segmentation training graph, its summaries and the session.

    Args:
        batch_size: stored and forwarded to ``Data``; also used to split the
            predictions into per-image summaries.
        input_size: indexable pair; elements 0 and 1 size the placeholders.
        log_dir: directory handed to ``Tools.new_dir``; used for the
            checkpoint path and for the summary writer.
        data_root_path, train_list, data_path, annotation_path, class_path:
            forwarded to ``Data`` (``train_list`` as ``data_list``).
        model_name: checkpoint file name.
        pretrain: stored as ``self.pretrain``; not used in this method.
        is_test: forwarded to ``Data``; also selects ``print_step``
            (10 or 100) and ``cal_step`` (100 or 1000).
    """
    # Checkpoint location.
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)
    self.pretrain = pretrain

    # Data parameters.
    self.input_size = input_size
    self.batch_size = batch_size
    self.num_classes = 21

    # Training parameters.
    self.learning_rate = 5e-3
    self.num_steps = 100001
    self.print_step = 10 if is_test else 100
    self.cal_step = 100 if is_test else 1000

    # Data reader.
    self.data_reader = Data(
        data_root_path=data_root_path,
        data_list=train_list,
        data_path=data_path,
        annotation_path=annotation_path,
        class_path=class_path,
        batch_size=self.batch_size,
        image_size=self.input_size,
        is_test=is_test,
    )

    # Placeholders.
    self.image_placeholder = tf.placeholder(
        tf.float32,
        shape=(None, self.input_size[0], self.input_size[1], 3))
    self.label_seg_placeholder = tf.placeholder(
        tf.int32,
        shape=(None, self.input_size[0], self.input_size[1], 1))

    # Network.
    self.net = BAISNet(self.image_placeholder, True,
                       num_classes=self.num_classes)
    self.segments, self.features = self.net.build()

    # Loss.
    self.loss, self.loss_segment_all, self.loss_segments = self.cal_loss(
        self.segments, self.label_seg_placeholder)

    with tf.name_scope("train"):
        # Learning-rate schedule: base_lr * (1 - step / num_steps) ** 0.8.
        # self.learning_rate is rebound from the float to the tensor.
        self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
        self.learning_rate = tf.scalar_mul(
            tf.constant(self.learning_rate),
            tf.pow((1 - self.step_ph / self.num_steps), 0.8))
        self.train_op = tf.train.GradientDescentOptimizer(
            self.learning_rate).minimize(self.loss)

        # Separate op that only updates the segment_side variables.
        segment_side_trainable = [
            v for v in tf.trainable_variables() if 'segment_side' in v.name
        ]
        print(len(segment_side_trainable))
        self.train_segment_side_op = tf.train.GradientDescentOptimizer(
            self.learning_rate).minimize(self.loss,
                                         var_list=segment_side_trainable)

    # Summaries.
    with tf.name_scope("loss"):
        tf.summary.scalar("loss", self.loss)
        tf.summary.scalar("loss_segment", self.loss_segment_all)
        for loss_segment_index, loss_segment in enumerate(self.loss_segments):
            tf.summary.scalar("loss_segment_{}".format(loss_segment_index),
                              loss_segment)

    with tf.name_scope("label"):
        tf.summary.image("1-image", self.image_placeholder)
        tf.summary.image(
            "2-segment",
            tf.cast(self.label_seg_placeholder * 255 // self.num_classes,
                    dtype=tf.uint8))

    with tf.name_scope("result"):
        # At most three images of the batch are logged per output.
        shown = min(3, self.batch_size)
        for segment_index, segment in enumerate(self.segments):
            # Scale the class ids in int32 and only then narrow to uint8,
            # exactly as the label summary does. Multiplying a uint8 tensor
            # by 255 wraps modulo 256 and would collapse almost every
            # non-zero class onto the same grey level.
            class_map = tf.cast(tf.argmax(segment, axis=-1), dtype=tf.int32)
            gray = tf.cast(class_map * 255 // self.num_classes,
                           dtype=tf.uint8)
            per_image = tf.split(gray, num_or_size_splits=self.batch_size,
                                 axis=0)
            for i in range(shown):
                tf.summary.image(
                    "predict-{}-{}".format(segment_index, i),
                    tf.expand_dims(per_image[i], axis=-1))

    self.summary_op = tf.summary.merge_all()

    # Session and saver.
    self.sess = tf.Session(config=tf.ConfigProto(
        gpu_options=tf.GPUOptions(allow_growth=True)))
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                max_to_keep=10)

    # Summary writer.
    self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
def __init__(self, batch_size, last_pool_size, input_size, log_dir,
             data_root_path, train_list, data_path, annotation_path,
             class_path, model_name="model.ckpt", is_test=False):
    """Set up data, the graph from ``build_net``, summaries and the session.

    Args:
        batch_size: stored and forwarded to ``Data``.
        last_pool_size: stored as ``self.last_pool_size``.
        input_size: stored and forwarded to ``Data`` as ``image_size``.
        log_dir: directory handed to ``Tools.new_dir``; used for the
            checkpoint path and for the summary writer.
        data_root_path, train_list, data_path, annotation_path, class_path:
            forwarded to ``Data`` (``train_list`` as ``data_list``).
        model_name: checkpoint file name.
        is_test: forwarded to ``Data``.
    """
    # Checkpoint location.
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

    # Data parameters.
    self.input_size = input_size
    self.batch_size = batch_size
    self.num_classes = 21
    self.num_segment = 1

    # Model parameters: input_size must be larger than 8x last_pool_size.
    self.ratio = 8
    self.last_pool_size = last_pool_size
    self.filter_number = 32

    # Training parameters.
    self.learning_rate = 5e-3
    self.num_steps = 500001

    # Data reader.
    self.data_reader = Data(
        data_root_path=data_root_path,
        data_list=train_list,
        data_path=data_path,
        annotation_path=annotation_path,
        class_path=class_path,
        batch_size=self.batch_size,
        image_size=self.input_size,
        is_test=is_test,
    )

    # Network: build_net supplies every tensor and op; self.learning_rate is
    # rebound to the value it returns.
    (self.image_placeholder,
     self.label_segment_placeholder,
     self.label_classes_placeholder,
     self.raw_output_segment,
     self.raw_output_classes,
     self.pred_segment,
     self.pred_classes,
     self.loss_segment,
     self.loss_classes,
     self.loss,
     self.accuracy_0,
     self.accuracy_1,
     self.accuracy_classes,
     self.step_ph,
     self.train_op,
     self.train_classes_op,
     self.learning_rate) = self.build_net()

    # Scalar summaries.
    tf.summary.scalar("loss", self.loss)
    tf.summary.scalar("loss_segment", self.loss_segment)
    tf.summary.scalar("loss_classes", self.loss_classes)
    tf.summary.scalar("accuracy_0", self.accuracy_0)
    tf.summary.scalar("accuracy_1", self.accuracy_1)
    tf.summary.scalar("accuracy_classes", self.accuracy_classes)

    # Image summaries: the input's first three channels are logged as the
    # image, the fourth as the mask.
    input_channels = tf.split(self.image_placeholder,
                              num_or_size_splits=4, axis=3)
    tf.summary.image("image", tf.concat(input_channels[0:3], axis=3))
    tf.summary.image("mask", input_channels[3])
    tf.summary.image("label", self.label_segment_placeholder)
    tf.summary.image("raw_output_segment", self.raw_output_segment)
    tf.summary.image("pred_segment", tf.cast(self.pred_segment, tf.float32))

    self.summary_op = tf.summary.merge_all()

    # Session and saver.
    gpu_options = tf.GPUOptions(allow_growth=True)
    self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                max_to_keep=10)

    # Summary writer.
    self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
def __init__(self, log_dir, data, model_name="model.ckpt", is_test=False):
    """Configure from a data reader, then build graph, summaries and session.

    Args:
        log_dir: directory handed to ``Tools.new_dir``; used for the
            checkpoint path and for the summary writer.
        data: data reader; its ``batch_size``, ``num_classes``,
            ``image_size``, ``ratio``, ``num_segment`` and
            ``attention_class`` attributes are copied onto this object.
        model_name: checkpoint file name.
        is_test: selects ``print_step`` (1 when truthy, otherwise 25).
    """
    # Data reader and the settings derived from it.
    self.data_reader = data
    reader = self.data_reader
    self.batch_size = reader.batch_size
    self.num_classes = reader.num_classes
    self.input_size = reader.image_size
    self.ratio = reader.ratio
    self.num_segment = reader.num_segment
    self.attention_class = reader.attention_class

    # Checkpoint location.
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

    # Model parameters: input_size must be larger than 8x last_pool_size.
    self.last_pool_size = self.input_size[0] // self.ratio
    self.filter_number = 32

    # Training parameters.
    self.learning_rate = 5e-3
    self.num_steps = 1000001
    if is_test:
        self.print_step = 1
    else:
        self.print_step = 25

    # Network: build_net supplies every tensor and op; self.learning_rate is
    # rebound to the value it returns.
    (self.image_placeholder,
     self.label_segment_placeholder,
     self.label_classes_placeholder,
     self.raw_output_segment,
     self.raw_output_classes,
     self.pred_segment,
     self.pred_classes,
     self.loss_segment,
     self.loss_classes,
     self.loss,
     self.accuracy_segment,
     self.accuracy_classes,
     self.step_ph,
     self.train_op,
     self.train_classes_op,
     self.learning_rate) = self.build_net()

    # Scalar summaries.
    tf.summary.scalar("loss", self.loss)
    tf.summary.scalar("loss_segment", self.loss_segment)
    tf.summary.scalar("loss_classes", self.loss_classes)
    tf.summary.scalar("accuracy_segment", self.accuracy_segment)
    tf.summary.scalar("accuracy_classes", self.accuracy_classes)

    # Image summaries: the input's fourth channel is logged as the mask,
    # the first three as the image.
    input_channels = tf.split(self.image_placeholder,
                              num_or_size_splits=4, axis=3)
    tf.summary.image("0-mask", input_channels[3])
    tf.summary.image("1-image", tf.concat(input_channels[0:3], axis=3))

    # Grey-level step that spreads ids 0..num_segment-1 over 0..255.
    gray_step = 255 // (self.num_segment - 1)
    tf.summary.image(
        "2-label",
        tf.cast(self.label_segment_placeholder * gray_step, dtype=tf.uint8))

    output_channels = tf.split(self.raw_output_segment,
                               num_or_size_splits=self.num_segment, axis=3)
    tf.summary.image("3-attention", output_channels[self.attention_class])
    for channel_index in range(self.num_segment):
        tf.summary.image("4-segment-output-{}".format(channel_index),
                         output_channels[channel_index])
    tf.summary.image(
        "5-pred_segment",
        tf.cast(self.pred_segment * (255 // (self.num_segment - 1)),
                dtype=tf.uint8))

    self.summary_op = tf.summary.merge_all()

    # Session and saver.
    gpu_options = tf.GPUOptions(allow_growth=True)
    self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                max_to_keep=10)

    # Summary writer.
    self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
def __init__(self, batch_size, input_size, log_dir, data_root_path,
             train_list, data_path, annotation_path, class_path,
             model_name="model.ckpt", pretrain=None, is_test=False):
    """Build the mask/attention training graph, summaries and the session.

    Args:
        batch_size: stored and forwarded to ``Data``.
        input_size: indexable pair; elements 0 and 1 size the placeholders.
        log_dir: directory handed to ``Tools.new_dir``; used for the
            checkpoint path and for the summary writer.
        data_root_path, train_list, data_path, annotation_path, class_path:
            forwarded to ``Data`` (``train_list`` as ``data_list``).
        model_name: checkpoint file name.
        pretrain: stored as ``self.pretrain``; not used in this method.
        is_test: forwarded to ``Data``; also selects ``print_step``
            (10 or 100) and ``cal_step`` (100 or 1000).
    """
    # Checkpoint location.
    self.log_dir = Tools.new_dir(log_dir)
    self.model_name = model_name
    self.checkpoint_path = os.path.join(self.log_dir, self.model_name)
    self.pretrain = pretrain

    # Data parameters.
    self.input_size = input_size
    self.batch_size = batch_size
    self.num_classes = 21

    # Training parameters.
    self.learning_rate = 5e-4
    self.num_steps = 500001
    if is_test:
        self.print_step = 10
        self.cal_step = 100
    else:
        self.print_step = 100
        self.cal_step = 1000

    # Data reader.
    self.data_reader = Data(
        data_root_path=data_root_path,
        data_list=train_list,
        data_path=data_path,
        annotation_path=annotation_path,
        class_path=class_path,
        batch_size=self.batch_size,
        image_size=self.input_size,
        is_test=is_test,
    )

    # Placeholders: image, mask, segmentation label and class label.
    self.image_placeholder = tf.placeholder(
        tf.float32,
        shape=(None, self.input_size[0], self.input_size[1], 3))
    self.mask_placeholder = tf.placeholder(
        tf.float32,
        shape=(None, self.input_size[0], self.input_size[1], 1))
    self.label_seg_placeholder = tf.placeholder(
        tf.int32,
        shape=(None, self.input_size[0], self.input_size[1], 1))
    self.label_cls_placeholder = tf.placeholder(tf.int32, shape=(None,))

    # Network.
    self.net = BAISNet(self.image_placeholder, self.mask_placeholder, True,
                       num_classes=self.num_classes)
    self.segments, self.attentions, self.classes = self.net.build()

    # Loss.
    (self.loss,
     self.loss_segment_all,
     self.loss_class_all,
     self.loss_segments,
     self.loss_classes) = self.cal_loss(
         self.segments, self.attentions, self.classes,
         self.label_seg_placeholder, self.label_cls_placeholder)

    # Classification accuracy of the current batch.
    class_ids = tf.argmax(self.classes[0], axis=-1)
    self.pred_classes = tf.cast(class_ids, tf.int32)
    self.accuracy_classes = tcm.accuracy(self.pred_classes,
                                         self.label_cls_placeholder)

    with tf.name_scope("train"):
        # Learning-rate schedule: base_lr * (1 - step / num_steps) ** 0.9.
        # self.learning_rate is rebound from the float to the tensor.
        self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
        base_lr = tf.constant(self.learning_rate)
        decay = tf.pow((1 - self.step_ph / self.num_steps), 0.9)
        self.learning_rate = tf.scalar_mul(base_lr, decay)
        self.train_op = tf.train.GradientDescentOptimizer(
            self.learning_rate).minimize(self.loss)

        # Separate op that only updates the attention variables. Any name
        # containing "class_attention" also contains "attention", so a
        # single substring test selects the same variables.
        attention_trainable = [
            v for v in tf.trainable_variables() if 'attention' in v.name
        ]
        print(len(attention_trainable))
        self.train_attention_op = tf.train.GradientDescentOptimizer(
            self.learning_rate).minimize(self.loss,
                                         var_list=attention_trainable)

    # Summaries.
    with tf.name_scope("loss"):
        tf.summary.scalar("loss", self.loss)
        tf.summary.scalar("loss_segment", self.loss_segment_all)
        tf.summary.scalar("loss_class", self.loss_class_all)
        for index, one_loss in enumerate(self.loss_segments):
            tf.summary.scalar("loss_segment_{}".format(index), one_loss)
        for index, one_loss in enumerate(self.loss_classes):
            tf.summary.scalar("loss_class_{}".format(index), one_loss)

    with tf.name_scope("accuracy"):
        tf.summary.scalar("accuracy_classes", self.accuracy_classes)

    with tf.name_scope("label"):
        tf.summary.image("1-image", self.image_placeholder)
        tf.summary.image(
            "2-segment",
            tf.cast(self.label_seg_placeholder * 255, dtype=tf.uint8))
        tf.summary.image(
            "3-mask",
            tf.cast(self.mask_placeholder * 255, dtype=tf.uint8))

    with tf.name_scope("result"):
        # Only the second of the two channels is logged, after a sigmoid,
        # for both the segment and the attention outputs.
        for segment_index, segment in enumerate(self.segments):
            halves = tf.split(segment, num_or_size_splits=2, axis=3)
            tf.summary.image("predict-{}-1".format(segment_index),
                             tf.nn.sigmoid(halves[1]))
        for attention_index, attention in enumerate(self.attentions):
            halves = tf.split(attention, num_or_size_splits=2, axis=3)
            tf.summary.image("attention-{}-1".format(attention_index),
                             tf.nn.sigmoid(halves[1]))

    self.summary_op = tf.summary.merge_all()

    # Session and saver.
    gpu_options = tf.GPUOptions(allow_growth=True)
    self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                max_to_keep=10)

    # Summary writer.
    self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)