def __init__(self, log_dir, save_dir):
        """Set up output directories and the fixed network input size.

        Args:
            save_dir: output directory; passed through ``Tools.new_dir``
                (presumably creates it if missing) and stored on ``self``.
            log_dir: log directory; handled the same way as ``save_dir``.
        """
        self.save_dir = Tools.new_dir(save_dir)
        self.log_dir = Tools.new_dir(log_dir)

        pool_size = 50
        self.last_pool_size = pool_size
        # The square input side is 8x the last pooling size.
        edge = pool_size * 8
        self.input_size = [edge, edge]
예제 #2
0
 def inference(self, image_path, image_index, save_path=None):
     """Segment one image, log its summary, then show or save the predicted mask.

     Args:
         image_path: image file, loaded with ``Data.load_data`` at ``self.input_size``.
         image_index: used as ``global_step`` when writing the summary.
         save_path: when None the mask is shown on screen; otherwise it is
             written to ``<save_path>/<image basename>.bmp``.
     """
     batch = np.expand_dims(
         Data.load_data(image_path=image_path, input_size=self.input_size), axis=0)
     fetched_segment, fetched_summary = self.sess.run(
         [self.pred_segment, self.summary_op],
         feed_dict={self.image_placeholder: batch})
     self.summary_writer.add_summary(fetched_summary, global_step=image_index)

     # NOTE(review): scaling by 255 assumes a 0/1 prediction; larger class ids
     # wrap around when converted to uint8.
     scaled = np.squeeze(fetched_segment) * 255
     mask_image = Image.fromarray(np.asarray(scaled, dtype=np.uint8))
     if save_path is not None:
         Tools.new_dir(save_path)
         stem = os.path.splitext(os.path.basename(image_path))[0]
         mask_image.convert("L").save("{}/{}.bmp".format(save_path, stem))
     else:
         mask_image.show()
예제 #3
0
    def __init__(self, input_size, log_dir, model_name="model.ckpt"):
        """Build the BAISNet prediction graph and open a session.

        Args:
            input_size: (height, width) of the 3-channel network input.
            log_dir: checkpoint directory; passed through ``Tools.new_dir``.
            model_name: checkpoint file name inside ``log_dir``.
        """
        # Checkpoint settings.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, model_name)

        # Data settings.
        self.input_size = input_size
        self.num_classes = 21

        height, width = self.input_size[0], self.input_size[1]
        self.image_placeholder = tf.placeholder(tf.float32,
                                                shape=(None, height, width, 3))

        # Network; the positional False is presumably is_training (TODO confirm).
        self.net = BAISNet(self.image_placeholder, False,
                           num_classes=self.num_classes)
        self.segments = self.net.build()
        first_logits = self.segments[0]
        # Per-pixel class id of the first segment output.
        self.pred_segment = tf.cast(tf.argmax(first_logits, axis=-1),
                                    dtype=tf.uint8)

        # Session (GPU memory grows on demand) and a saver keeping 10 checkpoints.
        gpu_options = tf.GPUOptions(allow_growth=True)
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
    def __init__(self, input_size, summary_dir, log_dir, model_name="model.ckpt"):
        """Build a feature-visualisation graph with one image summary per channel.

        Args:
            input_size: (height, width) of the 3-channel network input.
            summary_dir: directory for the TensorBoard ``FileWriter``.
            log_dir: checkpoint directory; passed through ``Tools.new_dir``.
            model_name: checkpoint file name inside ``log_dir``.
        """
        # Checkpoint settings.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, model_name)

        # Data settings.
        self.input_size = input_size
        self.num_classes = 21

        height, width = self.input_size[0], self.input_size[1]
        self.image_placeholder = tf.placeholder(tf.float32,
                                                shape=(None, height, width, 3))

        # Feature maps come from the _feature method defined elsewhere on the class.
        self.features = self._feature(self.image_placeholder)

        with tf.name_scope("image"):
            tf.summary.image("input", self.image_placeholder)

        # Every feature map except the last is split into single channels,
        # each logged as its own image.
        with tf.name_scope("block"):
            for map_index, feature_map in enumerate(self.features[:-1]):
                channel_count = int(feature_map.shape[-1])
                channels = tf.split(feature_map,
                                    num_or_size_splits=channel_count,
                                    axis=-1)
                for channel_index, channel in enumerate(channels):
                    tf.summary.image("{}-{}".format(map_index, channel_index),
                                     channel)

        self.summary_op = tf.summary.merge_all()

        # Session, saver and summary writer.
        gpu_options = tf.GPUOptions(allow_growth=True)
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
        self.summary_writer = tf.summary.FileWriter(summary_dir, self.sess.graph)
    def __init__(self, batch_size, last_pool_size, input_size, log_dir,
                 data_root_path, train_list, data_path, annotation_path,
                 class_path, model_name="model.ckpt", is_test=False):
        """Create the data reader, build the graph via ``build_net`` and open a session.

        Args:
            batch_size: batch size forwarded to ``Data``.
            last_pool_size: stored on ``self``; input_size must be larger than
                8x this value.
            input_size: image size forwarded to ``Data``.
            log_dir: checkpoint directory; passed through ``Tools.new_dir``.
            data_root_path, train_list, data_path, annotation_path, class_path:
                forwarded to ``Data``.
            model_name: checkpoint file name inside ``log_dir``.
            is_test: forwarded to ``Data``.
        """
        # Checkpoint settings.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, model_name)

        # Data settings.
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21
        self.num_segment = 1

        # Model settings: input_size must be larger than 8x last_pool_size.
        self.ratio = 8
        self.last_pool_size = last_pool_size
        self.filter_number = 32

        self.data_reader = Data(data_root_path=data_root_path,
                                data_list=train_list,
                                data_path=data_path,
                                annotation_path=annotation_path,
                                class_path=class_path,
                                batch_size=self.batch_size,
                                image_size=self.input_size,
                                is_test=is_test)

        (self.image_placeholder, self.raw_output_segment,
         self.raw_output_classes, self.pred_segment,
         self.pred_classes) = self.build_net()

        # Session and a saver keeping the 10 most recent checkpoints.
        gpu_options = tf.GPUOptions(allow_growth=True)
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
예제 #6
0
    def __init__(self, input_size, summary_dir, log_dir, model_name="model.ckpt"):
        """Build the prediction graph plus image summaries of outputs and features.

        Args:
            input_size: (height, width) of the 3-channel network input.
            summary_dir: directory for the TensorBoard ``FileWriter``.
            log_dir: checkpoint directory; passed through ``Tools.new_dir``.
            model_name: checkpoint file name inside ``log_dir``.
        """
        # Checkpoint settings.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, model_name)

        # Data settings.
        self.input_size = input_size
        self.num_classes = 21

        height, width = self.input_size[0], self.input_size[1]
        self.image_placeholder = tf.placeholder(tf.float32,
                                                shape=(None, height, width, 3))

        # Network: build() yields segment outputs and a mapping of feature lists.
        self.net = BAISNet(self.image_placeholder, False, num_classes=self.num_classes)
        self.segments, self.features = self.net.build()
        self.pred_segment = tf.cast(tf.argmax(self.segments[0], axis=-1),
                                    dtype=tf.uint8)

        with tf.name_scope("image"):
            tf.summary.image("input", self.image_placeholder)
            # One prediction image per segment output.
            # NOTE(review): `* 255` on uint8 wraps for class ids > 1.
            for output_index, logits in enumerate(self.segments):
                class_map = tf.cast(tf.argmax(logits, axis=-1), dtype=tf.uint8)
                tf.summary.image("predict-{}".format(output_index),
                                 tf.expand_dims(class_map * 255, axis=-1))

        # Each feature group gets its own name scope; every channel of every
        # feature map in the group is logged as a separate image.
        feature_groups = self.features
        for scope_name in list(feature_groups.keys()):
            with tf.name_scope(scope_name):
                for map_index, feature_map in enumerate(feature_groups[scope_name]):
                    channels = tf.split(feature_map,
                                        num_or_size_splits=int(feature_map.shape[-1]),
                                        axis=-1)
                    for channel_index, channel in enumerate(channels):
                        tf.summary.image("{}-{}".format(map_index, channel_index),
                                         channel)

        self.summary_op = tf.summary.merge_all()

        # Session, saver and summary writer.
        gpu_options = tf.GPUOptions(allow_growth=True)
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
        self.summary_writer = tf.summary.FileWriter(summary_dir, self.sess.graph)
    def __init__(self,
                 batch_size,
                 last_pool_size,
                 input_size,
                 log_dir,
                 data_root_path,
                 train_list,
                 data_path,
                 annotation_path,
                 class_path,
                 model_name="model.ckpt",
                 is_test=False):
        """Build the attention-based BAISNet training graph, its summaries and a session.

        Args:
            batch_size: batch size forwarded to ``Data``.
            last_pool_size: forwarded to ``BAISNet``; input_size is expected to be
                larger than 8x this value.
            input_size: (height, width) of the 4-channel network input.
            log_dir: checkpoint and summary directory; passed through ``Tools.new_dir``.
            data_root_path, train_list, data_path, annotation_path, class_path:
                forwarded to ``Data``.
            model_name: checkpoint file name inside ``log_dir``.
            is_test: forwarded to ``Data``; also lowers ``print_step`` from 25 to 5.
        """

        # Parameters related to saving the model (checkpoints).
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Data-related parameters.
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21
        self.num_segment = 4  # decoder channels: other objects, attention, border, background
        self.segment_attention = 1  # index of the attention channel when there are 4 decoder channels
        self.attention_module_num = 2  # number of attention modules whose decoder has 2 channels (background, attention)

        # Model parameters: input_size must be larger than 8x last_pool_size.
        self.ratio = 8
        self.last_pool_size = last_pool_size
        self.filter_number = 32

        # Training parameters.
        self.learning_rate = 5e-3
        self.num_steps = 500001
        self.print_step = 5 if is_test else 25

        # Data reader.
        self.data_reader = Data(data_root_path=data_root_path,
                                data_list=train_list,
                                data_path=data_path,
                                annotation_path=annotation_path,
                                class_path=class_path,
                                batch_size=self.batch_size,
                                image_size=self.input_size,
                                is_test=is_test,
                                has_255=True)

        # Input placeholders. Images have 4 channels (summarised below as 3 image
        # channels plus a mask); segment/attention labels are at 1/ratio resolution.
        self.image_placeholder = tf.placeholder(dtype=tf.float32,
                                                shape=(None,
                                                       self.input_size[0],
                                                       self.input_size[1], 4))
        self.label_segment_placeholder = tf.placeholder(
            dtype=tf.int32,
            shape=(None, self.input_size[0] // self.ratio,
                   self.input_size[1] // self.ratio, 1))
        self.label_attention_placeholder = tf.placeholder(
            dtype=tf.int32,
            shape=(None, self.input_size[0] // self.ratio,
                   self.input_size[1] // self.ratio, 1))
        self.label_classes_placeholder = tf.placeholder(dtype=tf.int32,
                                                        shape=(None, ))

        # Network.
        self.net = BAISNet(self.image_placeholder,
                           is_training=True,
                           num_classes=self.num_classes,
                           num_segment=self.num_segment,
                           segment_attention=self.segment_attention,
                           last_pool_size=self.last_pool_size,
                           filter_number=self.filter_number,
                           attention_module_num=self.attention_module_num)

        # build() yields sequences of segment logits, attention maps and class
        # logits; element 0 of each is used as the final output.
        self.segments, self.attentions, self.classes = self.net.build()
        self.final_segment_logit = self.segments[0]
        self.final_class_logit = self.classes[0]

        # Predictions
        self.pred_segment = tf.cast(
            tf.expand_dims(tf.argmax(self.final_segment_logit, axis=-1),
                           axis=-1), tf.int32)
        self.pred_classes = tf.cast(tf.argmax(self.final_class_logit, axis=-1),
                                    tf.int32)

        # loss
        # Segment and attention labels are resized (nearest neighbour) to the
        # spatial size of the final segment logits before computing the loss.
        self.label_batch = tf.image.resize_nearest_neighbor(
            self.label_segment_placeholder,
            tf.stack(self.final_segment_logit.get_shape()[1:3]))
        self.label_attention_batch = tf.image.resize_nearest_neighbor(
            self.label_attention_placeholder,
            tf.stack(self.final_segment_logit.get_shape()[1:3]))
        self.loss, self.loss_segment_all, self.loss_class_all, self.loss_segments, self.loss_classes = self.cal_loss(
            self.segments,
            self.classes,
            self.label_batch,
            self.label_attention_batch,
            self.label_classes_placeholder,
            self.num_segment,
            attention_module_num=self.attention_module_num)

        # Accuracy on the current batch.
        # NOTE(review): accuracy_segment uses the un-resized label placeholder while
        # the loss uses the resized label_batch; presumably this relies on the final
        # logits already being at 1/ratio resolution — confirm.
        self.accuracy_segment = tcm.accuracy(self.pred_segment,
                                             self.label_segment_placeholder)
        self.accuracy_classes = tcm.accuracy(self.pred_classes,
                                             self.label_classes_placeholder)

        with tf.name_scope("train"):
            # Learning-rate schedule: lr = 5e-3 * (1 - step / num_steps) ** 0.9,
            # with the current step fed through step_ph. This replaces the float
            # self.learning_rate with a tensor.
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            self.learning_rate = tf.scalar_mul(
                tf.constant(self.learning_rate),
                tf.pow((1 - self.step_ph / self.num_steps), 0.9))
            self.train_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss)

            # Separate op that trains only variables whose name contains
            # 'attention' (the 'class_attention' test is already covered by it).
            attention_trainable = [
                v for v in tf.trainable_variables()
                if 'attention' in v.name or "class_attention" in v.name
            ]
            print(len(attention_trainable))
            self.train_attention_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss,
                                             var_list=attention_trainable)
            pass

        # summary 1
        with tf.name_scope("loss"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.scalar("loss_segment", self.loss_segment_all)
            tf.summary.scalar("loss_class", self.loss_class_all)
            for loss_segment_index, loss_segment in enumerate(
                    self.loss_segments):
                tf.summary.scalar("loss_segment_{}".format(loss_segment_index),
                                  loss_segment)
            for loss_class_index, loss_class in enumerate(self.loss_classes):
                tf.summary.scalar("loss_class_{}".format(loss_class_index),
                                  loss_class)
            pass

        with tf.name_scope("accuracy"):
            tf.summary.scalar("accuracy_segment", self.accuracy_segment)
            tf.summary.scalar("accuracy_classes", self.accuracy_classes)
            pass

        with tf.name_scope("label"):
            # Channel 3 of the input is shown as the mask, channels 0-2 as the image.
            split = tf.split(self.image_placeholder,
                             num_or_size_splits=4,
                             axis=3)
            tf.summary.image("0-mask", split[3])
            tf.summary.image("1-image", tf.concat(split[0:3], axis=3))
            # x85 spreads label values over grey levels; presumably labels lie in
            # 0..3 (one per decoder channel) so the maximum maps to 255.
            tf.summary.image(
                "2-label",
                tf.cast(self.label_segment_placeholder * 85, dtype=tf.uint8))
            tf.summary.image(
                "3-attention",
                tf.cast(self.label_attention_placeholder * 255,
                        dtype=tf.uint8))
            pass

        with tf.name_scope("predict"):
            tf.summary.image("predict",
                             tf.cast(self.pred_segment * 85, dtype=tf.uint8))
            pass

        with tf.name_scope("attention"):
            # attention
            for attention_index, attention in enumerate(self.attentions):
                tf.summary.image("{}-attention".format(attention_index),
                                 attention)
                pass
            pass

        with tf.name_scope("sigmoid"):
            # Per-channel images of every segment output: the first
            # attention_module_num outputs are split into num_segment channels
            # (other, attention, border, background); the rest into 2 channels
            # (background, attention).
            for segment_index, segment in enumerate(self.segments):
                if segment_index < self.attention_module_num:
                    split = tf.split(segment,
                                     num_or_size_splits=self.num_segment,
                                     axis=3)
                    tf.summary.image("{}-other".format(segment_index),
                                     split[0])
                    tf.summary.image("{}-attention".format(segment_index),
                                     split[1])
                    tf.summary.image("{}-border".format(segment_index),
                                     split[2])
                    tf.summary.image("{}-background".format(segment_index),
                                     split[-1])
                else:
                    split = tf.split(segment, num_or_size_splits=2, axis=3)
                    tf.summary.image("{}-background".format(segment_index),
                                     split[0])
                    tf.summary.image("{}-attention".format(segment_index),
                                     split[1])
                    pass
                pass
            pass

        self.summary_op = tf.summary.merge_all()

        # Session (GPU memory grows on demand) and a saver keeping 10 checkpoints.
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                    max_to_keep=10)

        # summary 2
        self.summary_writer = tf.summary.FileWriter(self.log_dir,
                                                    self.sess.graph)
        pass
    def __init__(self, batch_size, last_pool_size, input_size, log_dir,
                 data_root_path, train_list, data_path, annotation_path, class_path,
                 model_name="model.ckpt", is_test=False):
        """Build the boundary-aware training graph via ``build_net``, plus summaries.

        Args:
            batch_size: batch size forwarded to ``Data``.
            last_pool_size: stored on ``self``; input_size must be larger than
                8x this value.
            input_size: image size forwarded to ``Data``.
            log_dir: checkpoint and summary directory; passed through ``Tools.new_dir``.
            data_root_path, train_list, data_path, annotation_path, class_path:
                forwarded to ``Data``.
            model_name: checkpoint file name inside ``log_dir``.
            is_test: forwarded to ``Data``; also lowers ``print_step`` from 25 to 1.
        """
        # Checkpoint settings.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, model_name)

        # Data settings. has_255 decides whether the border is predicted, which
        # adds one decoder channel.
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21
        self.has_255 = True
        self.num_segment = 4 if self.has_255 else 3

        # Model settings: input_size must be larger than 8x last_pool_size.
        self.ratio = 8
        self.last_pool_size = last_pool_size
        self.filter_number = 32

        # Training settings.
        self.learning_rate = 5e-3
        self.num_steps = 500001
        self.print_step = 1 if is_test else 25

        self.data_reader = Data(data_root_path=data_root_path,
                                data_list=train_list,
                                data_path=data_path,
                                annotation_path=annotation_path,
                                class_path=class_path,
                                batch_size=self.batch_size,
                                image_size=self.input_size,
                                is_test=is_test,
                                has_255=self.has_255)

        # Graph tensors and ops; note that build_net also overwrites self.learning_rate.
        (self.image_placeholder, self.label_segment_placeholder,
         self.label_classes_placeholder, self.raw_output_segment,
         self.raw_output_classes, self.pred_segment, self.pred_classes,
         self.loss_segment, self.loss_classes, self.loss,
         self.accuracy_segment, self.accuracy_classes, self.step_ph,
         self.train_op, self.train_classes_op,
         self.learning_rate) = self.build_net()

        # Scalar summaries, created in a fixed order.
        scalar_summaries = (
            ("loss", self.loss),
            ("loss_segment", self.loss_segment),
            ("loss_classes", self.loss_classes),
            ("accuracy_segment", self.accuracy_segment),
            ("accuracy_classes", self.accuracy_classes),
        )
        for tag, tensor in scalar_summaries:
            tf.summary.scalar(tag, tensor)

        # Grey-level multiplier for label / prediction images.
        grey_step = 85 if self.has_255 else 127

        # Input channel 3 is shown as the mask, channels 0-2 as the image.
        input_parts = tf.split(self.image_placeholder, num_or_size_splits=4, axis=3)
        tf.summary.image("0-mask", input_parts[3])
        tf.summary.image("1-image", tf.concat(input_parts[0: 3], axis=3))
        tf.summary.image("2-label",
                         tf.cast(self.label_segment_placeholder * grey_step,
                                 dtype=tf.uint8))

        output_parts = tf.split(self.raw_output_segment,
                                num_or_size_splits=self.num_segment, axis=3)
        tf.summary.image("3-attention", output_parts[1])
        tf.summary.image("4-other class", output_parts[0])
        tf.summary.image("5-background", output_parts[-1])
        if self.has_255:
            tf.summary.image("5-border", output_parts[2])
        tf.summary.image("6-pred_segment",
                         tf.cast(self.pred_segment * grey_step, dtype=tf.uint8))

        self.summary_op = tf.summary.merge_all()

        # Session, saver and summary writer.
        gpu_options = tf.GPUOptions(allow_growth=True)
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
        self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
    def __init__(self,
                 batch_size,
                 input_size,
                 log_dir,
                 data_root_path,
                 train_list,
                 data_path,
                 annotation_path,
                 class_path,
                 model_name="model.ckpt",
                 pretrain=None,
                 is_test=False):
        """Build the multi-output segmentation training graph, summaries and a session.

        Args:
            batch_size: batch size forwarded to ``Data``; also the number of
                per-image splits used for the prediction summaries.
            input_size: (height, width) of the 3-channel input and of the labels.
            log_dir: checkpoint and summary directory; passed through ``Tools.new_dir``.
            data_root_path, train_list, data_path, annotation_path, class_path:
                forwarded to ``Data``.
            model_name: checkpoint file name inside ``log_dir``.
            pretrain: stored on ``self.pretrain``; not used in this constructor.
            is_test: forwarded to ``Data``; also shortens ``print_step``/``cal_step``.
        """
        # Checkpoint-related parameters.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)
        self.pretrain = pretrain

        # Data-related parameters.
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21

        # Training parameters.
        self.learning_rate = 5e-3
        self.num_steps = 100001
        self.print_step = 10 if is_test else 100
        self.cal_step = 100 if is_test else 1000

        # Data reader.
        self.data_reader = Data(data_root_path=data_root_path,
                                data_list=train_list,
                                data_path=data_path,
                                annotation_path=annotation_path,
                                class_path=class_path,
                                batch_size=self.batch_size,
                                image_size=self.input_size,
                                is_test=is_test)

        # Input placeholders: 3-channel images and full-resolution integer labels.
        self.image_placeholder = tf.placeholder(tf.float32,
                                                shape=(None,
                                                       self.input_size[0],
                                                       self.input_size[1], 3))
        self.label_seg_placeholder = tf.placeholder(
            tf.int32, shape=(None, self.input_size[0], self.input_size[1], 1))

        # Network.
        self.net = BAISNet(self.image_placeholder,
                           True,
                           num_classes=self.num_classes)
        self.segments, self.features = self.net.build()

        # Loss.
        self.loss, self.loss_segment_all, self.loss_segments = self.cal_loss(
            self.segments, self.label_seg_placeholder)

        with tf.name_scope("train"):
            # Learning-rate schedule: lr = 5e-3 * (1 - step / num_steps) ** 0.8,
            # with the current step fed through step_ph.
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            self.learning_rate = tf.scalar_mul(
                tf.constant(self.learning_rate),
                tf.pow((1 - self.step_ph / self.num_steps), 0.8))
            self.train_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss)

            # Separate op that trains only variables whose name contains 'segment_side'.
            segment_side_trainable = [
                v for v in tf.trainable_variables() if 'segment_side' in v.name
            ]
            print(len(segment_side_trainable))
            self.train_segment_side_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss,
                                             var_list=segment_side_trainable)

        # summary 1
        with tf.name_scope("loss"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.scalar("loss_segment", self.loss_segment_all)
            for loss_segment_index, loss_segment in enumerate(
                    self.loss_segments):
                tf.summary.scalar("loss_segment_{}".format(loss_segment_index),
                                  loss_segment)

        with tf.name_scope("label"):
            tf.summary.image("1-image", self.image_placeholder)
            # Class ids 0..num_classes-1 are scaled in int32 onto grey levels 0..255.
            tf.summary.image(
                "2-segment",
                tf.cast(self.label_seg_placeholder * 255 // self.num_classes,
                        dtype=tf.uint8))

        with tf.name_scope("result"):
            # Show at most 3 images of the batch per segment output.
            shown_images = min(3, self.batch_size)
            for segment_index, segment in enumerate(self.segments):
                # Scale in int32, exactly like the "2-segment" label image above.
                # The previous code multiplied a uint8 tensor by 255, which wraps
                # around (class k >= 1 became (256 - k) // 21, i.e. ~11-12 grey),
                # so every foreground class looked the same and nearly black.
                class_ids = tf.cast(tf.argmax(segment, axis=-1), dtype=tf.int32)
                grey = tf.cast(class_ids * 255 // self.num_classes, dtype=tf.uint8)
                # Requires the fed batch size to equal self.batch_size.
                per_image = tf.split(grey,
                                     num_or_size_splits=self.batch_size,
                                     axis=0)
                for i in range(shown_images):
                    tf.summary.image(
                        "predict-{}-{}".format(segment_index, i),
                        tf.expand_dims(per_image[i], axis=-1))

        self.summary_op = tf.summary.merge_all()

        # Session (GPU memory grows on demand) and a saver keeping 10 checkpoints.
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                    max_to_keep=10)

        # summary 2
        self.summary_writer = tf.summary.FileWriter(self.log_dir,
                                                    self.sess.graph)
    def __init__(self, batch_size, last_pool_size, input_size, log_dir,
                 data_root_path, train_list, data_path, annotation_path,
                 class_path, model_name="model.ckpt", is_test=False):
        """Build the single-channel segmentation graph via ``build_net``, plus summaries.

        Args:
            batch_size: batch size forwarded to ``Data``.
            last_pool_size: stored on ``self``; input_size must be larger than
                8x this value.
            input_size: image size forwarded to ``Data``.
            log_dir: checkpoint and summary directory; passed through ``Tools.new_dir``.
            data_root_path, train_list, data_path, annotation_path, class_path:
                forwarded to ``Data``.
            model_name: checkpoint file name inside ``log_dir``.
            is_test: forwarded to ``Data``.
        """
        # Checkpoint settings.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, model_name)

        # Data settings.
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21
        self.num_segment = 1

        # Model settings: input_size must be larger than 8x last_pool_size.
        self.ratio = 8
        self.last_pool_size = last_pool_size
        self.filter_number = 32

        # Training settings (build_net below overwrites self.learning_rate).
        self.learning_rate = 5e-3
        self.num_steps = 500001

        self.data_reader = Data(data_root_path=data_root_path,
                                data_list=train_list,
                                data_path=data_path,
                                annotation_path=annotation_path,
                                class_path=class_path,
                                batch_size=self.batch_size,
                                image_size=self.input_size,
                                is_test=is_test)

        (self.image_placeholder, self.label_segment_placeholder,
         self.label_classes_placeholder, self.raw_output_segment,
         self.raw_output_classes, self.pred_segment, self.pred_classes,
         self.loss_segment, self.loss_classes, self.loss, self.accuracy_0,
         self.accuracy_1, self.accuracy_classes, self.step_ph, self.train_op,
         self.train_classes_op, self.learning_rate) = self.build_net()

        # Scalar summaries, created in a fixed order.
        scalar_summaries = (
            ("loss", self.loss),
            ("loss_segment", self.loss_segment),
            ("loss_classes", self.loss_classes),
            ("accuracy_0", self.accuracy_0),
            ("accuracy_1", self.accuracy_1),
            ("accuracy_classes", self.accuracy_classes),
        )
        for tag, tensor in scalar_summaries:
            tf.summary.scalar(tag, tensor)

        # Channels 0-2 of the input are shown as "image", channel 3 as "mask".
        input_parts = tf.split(self.image_placeholder, num_or_size_splits=4, axis=3)
        tf.summary.image("image", tf.concat(input_parts[0:3], axis=3))
        tf.summary.image("mask", input_parts[3])
        # NOTE(review): tf.summary.image accepts only uint8/float tensors; confirm
        # the label placeholder returned by build_net has such a dtype.
        tf.summary.image("label", self.label_segment_placeholder)
        tf.summary.image("raw_output_segment", self.raw_output_segment)
        tf.summary.image("pred_segment", tf.cast(self.pred_segment, tf.float32))
        self.summary_op = tf.summary.merge_all()

        # Session, saver and summary writer.
        gpu_options = tf.GPUOptions(allow_growth=True)
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
        self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
예제 #11
0
    def __init__(self, log_dir, data, model_name="model.ckpt", is_test=False):
        """Build the training graph around an existing data reader.

        Args:
            log_dir: checkpoint and summary directory; passed through ``Tools.new_dir``.
            data: data reader; batch_size, num_classes, image_size, ratio,
                num_segment and attention_class are read from it.
            model_name: checkpoint file name inside ``log_dir``.
            is_test: lowers ``print_step`` from 25 to 1.
        """
        # Settings taken from the data reader.
        self.data_reader = data
        self.batch_size = self.data_reader.batch_size
        self.num_classes = self.data_reader.num_classes
        self.input_size = self.data_reader.image_size
        self.ratio = self.data_reader.ratio
        self.num_segment = self.data_reader.num_segment
        self.attention_class = self.data_reader.attention_class

        # Checkpoint-related parameters.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Model parameters; last_pool_size is derived as input height // ratio.
        self.last_pool_size = self.input_size[0] // self.ratio
        self.filter_number = 32
        self.learning_rate = 5e-3
        self.num_steps = 1000001
        self.print_step = 1 if is_test else 25

        # Network (build_net also overwrites self.learning_rate).
        (self.image_placeholder, self.label_segment_placeholder,
         self.label_classes_placeholder, self.raw_output_segment,
         self.raw_output_classes, self.pred_segment, self.pred_classes,
         self.loss_segment, self.loss_classes, self.loss,
         self.accuracy_segment, self.accuracy_classes, self.step_ph,
         self.train_op, self.train_classes_op,
         self.learning_rate) = self.build_net()

        # summary 1
        tf.summary.scalar("loss", self.loss)
        tf.summary.scalar("loss_segment", self.loss_segment)
        tf.summary.scalar("loss_classes", self.loss_classes)
        tf.summary.scalar("accuracy_segment", self.accuracy_segment)
        tf.summary.scalar("accuracy_classes", self.accuracy_classes)

        # Grey-level step that maps segment ids 0..num_segment-1 onto 0..255.
        # Guarded so num_segment == 1 no longer raises ZeroDivisionError.
        grey_step = 255 // max(self.num_segment - 1, 1)

        # Input channel 3 is shown as the mask, channels 0-2 as the image.
        split = tf.split(self.image_placeholder, num_or_size_splits=4, axis=3)
        tf.summary.image("0-mask", split[3])
        tf.summary.image("1-image", tf.concat(split[0:3], axis=3))
        tf.summary.image(
            "2-label",
            tf.cast(self.label_segment_placeholder * grey_step, dtype=tf.uint8))

        # One image per decoder channel, plus the attention channel on its own.
        split = tf.split(self.raw_output_segment,
                         num_or_size_splits=self.num_segment,
                         axis=3)
        tf.summary.image("3-attention", split[self.attention_class])
        for segment_index in range(self.num_segment):
            tf.summary.image("4-segment-output-{}".format(segment_index),
                             split[segment_index])
        tf.summary.image(
            "5-pred_segment",
            tf.cast(self.pred_segment * grey_step, dtype=tf.uint8))

        self.summary_op = tf.summary.merge_all()

        # Session (GPU memory grows on demand) and a saver keeping 10 checkpoints.
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                    max_to_keep=10)

        # summary 2
        self.summary_writer = tf.summary.FileWriter(self.log_dir,
                                                    self.sess.graph)
    def __init__(self, batch_size, input_size, log_dir, data_root_path, train_list, data_path,
                 annotation_path, class_path, model_name="model.ckpt", pretrain=None, is_test=False):
        """Build the BAISNet training graph, its summaries, the session and the saver.

        Args:
            batch_size: batch size, stored on ``self`` and forwarded to ``Data``.
            input_size: spatial size of the inputs; ``input_size[0]`` and
                ``input_size[1]`` become dims 1 and 2 of every image-like placeholder.
            log_dir: directory (created via ``Tools.new_dir``) that receives the
                checkpoints and the TensorBoard summaries.
            data_root_path, train_list, data_path, annotation_path, class_path:
                forwarded unchanged to the ``Data`` reader.
            model_name: checkpoint file name, joined onto ``log_dir``.
            pretrain: stored as ``self.pretrain``; not used inside this constructor.
            is_test: shortens the print / evaluation intervals and is forwarded to ``Data``.

        Raises:
            ValueError: if no trainable variable has "attention" in its name, so the
                attention-only train op would have nothing to optimize.
        """
        # Checkpoint-related parameters
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)
        self.pretrain = pretrain

        # Data-related parameters
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21

        # Training-related parameters.
        # NOTE: self.learning_rate starts as the base float and is replaced by the
        # decayed learning-rate tensor inside the "train" name scope below.
        self.learning_rate = 5e-4
        self.num_steps = 500001
        self.print_step = 10 if is_test else 100
        self.cal_step = 100 if is_test else 1000

        # Data reader
        self.data_reader = Data(data_root_path=data_root_path, data_list=train_list,
                                data_path=data_path, annotation_path=annotation_path, class_path=class_path,
                                batch_size=self.batch_size, image_size=self.input_size, is_test=is_test)

        # Input placeholders: RGB image, single-channel mask, per-pixel segment label, per-image class label
        self.image_placeholder = tf.placeholder(tf.float32, shape=(None, self.input_size[0], self.input_size[1], 3))
        self.mask_placeholder = tf.placeholder(tf.float32, shape=(None, self.input_size[0], self.input_size[1], 1))
        self.label_seg_placeholder = tf.placeholder(tf.int32, shape=(None, self.input_size[0], self.input_size[1], 1))
        self.label_cls_placeholder = tf.placeholder(tf.int32, shape=(None,))

        # Network (built in training mode)
        self.net = BAISNet(self.image_placeholder, self.mask_placeholder, True, num_classes=self.num_classes)
        self.segments, self.attentions, self.classes = self.net.build()

        # Loss
        self.loss, self.loss_segment_all, self.loss_class_all, self.loss_segments, self.loss_classes = self.cal_loss(
            self.segments, self.attentions, self.classes, self.label_seg_placeholder, self.label_cls_placeholder)

        # Classification accuracy on the current batch, computed from the first classification head
        self.pred_classes = tf.cast(tf.argmax(self.classes[0], axis=-1), tf.int32)
        self.accuracy_classes = tcm.accuracy(self.pred_classes, self.label_cls_placeholder)

        with tf.name_scope("train"):
            # Polynomial learning-rate decay: lr = base * (1 - step / num_steps) ** 0.9,
            # where the current step is fed through step_ph. The base of the power is
            # clamped at 0 so a step beyond num_steps gives lr = 0 instead of NaN
            # (a negative number raised to 0.9 is NaN).
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            decay = tf.pow(tf.maximum(1 - self.step_ph / self.num_steps, 0.0), 0.9)
            self.learning_rate = tf.scalar_mul(tf.constant(self.learning_rate), decay)
            self.train_op = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss)

            # Separate op that trains only the attention-related variables.
            # Names containing "class_attention" also contain "attention", so a single test suffices.
            attention_trainable = [v for v in tf.trainable_variables() if "attention" in v.name]
            print(len(attention_trainable))
            if not attention_trainable:
                raise ValueError("No trainable variable name contains 'attention'; "
                                 "cannot build the attention-only train op.")
            self.train_attention_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss, var_list=attention_trainable)

        # Summaries, part 1: scalar losses
        with tf.name_scope("loss"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.scalar("loss_segment", self.loss_segment_all)
            tf.summary.scalar("loss_class", self.loss_class_all)
            for loss_segment_index, loss_segment in enumerate(self.loss_segments):
                tf.summary.scalar("loss_segment_{}".format(loss_segment_index), loss_segment)
            for loss_class_index, loss_class in enumerate(self.loss_classes):
                tf.summary.scalar("loss_class_{}".format(loss_class_index), loss_class)

        with tf.name_scope("accuracy"):
            tf.summary.scalar("accuracy_classes", self.accuracy_classes)

        # Inputs and labels; labels and mask are scaled by 255 for display as uint8 images
        with tf.name_scope("label"):
            tf.summary.image("1-image", self.image_placeholder)
            tf.summary.image("2-segment", tf.cast(self.label_seg_placeholder * 255, dtype=tf.uint8))
            tf.summary.image("3-mask", tf.cast(self.mask_placeholder * 255, dtype=tf.uint8))

        with tf.name_scope("result"):
            # Each segment / attention output is split into two channels along axis 3;
            # only the sigmoid of channel 1 is logged.
            for segment_index, segment in enumerate(self.segments):
                segment = tf.split(segment, num_or_size_splits=2, axis=3)
                tf.summary.image("predict-{}-1".format(segment_index), tf.nn.sigmoid(segment[1]))
            for attention_index, attention in enumerate(self.attentions):
                attention = tf.split(attention, num_or_size_splits=2, axis=3)
                tf.summary.image("attention-{}-1".format(attention_index), tf.nn.sigmoid(attention[1]))

        self.summary_op = tf.summary.merge_all()

        # Session and saver
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

        # Summaries, part 2: the writer (also records the graph)
        self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)