def __init__(self, batch_size, last_pool_size, input_size, log_dir,
             data_root_path, train_list, data_path, annotation_path,
             class_path, model_name="model.ckpt", is_test=False):
        """Create the data reader, build the network and open a session.

        Args:
            batch_size: Batch size forwarded to the ``Data`` reader.
            last_pool_size: Stored for the network builder. The original
                note requires ``input_size`` to exceed 8x this value.
            input_size: Image size forwarded to ``Data`` as ``image_size``.
            log_dir: Checkpoint directory; passed through ``Tools.new_dir``
                and the returned value is kept as ``self.log_dir``.
            data_root_path, train_list, data_path, annotation_path,
            class_path: Forwarded unchanged to ``Data``.
            model_name: Checkpoint file name joined onto ``self.log_dir``.
            is_test: Forwarded unchanged to ``Data``.
        """
        # Checkpoint bookkeeping.
        checkpoint_dir = Tools.new_dir(log_dir)
        self.log_dir = checkpoint_dir
        self.model_name = model_name
        self.checkpoint_path = os.path.join(checkpoint_dir, model_name)

        # Data-related settings.
        self.batch_size, self.input_size = batch_size, input_size
        self.num_classes, self.num_segment = 21, 1

        # Model-related settings (input_size must exceed ratio * last_pool_size).
        self.last_pool_size = last_pool_size
        self.ratio, self.filter_number = 8, 32

        # Data reader.
        reader_options = dict(
            data_root_path=data_root_path,
            data_list=train_list,
            data_path=data_path,
            annotation_path=annotation_path,
            class_path=class_path,
            batch_size=self.batch_size,
            image_size=self.input_size,
            is_test=is_test,
        )
        self.data_reader = Data(**reader_options)

        # Network: build_net() is expected to hand back exactly five tensors.
        graph_tensors = self.build_net()
        (self.image_placeholder, self.raw_output_segment,
         self.raw_output_classes, self.pred_segment,
         self.pred_classes) = graph_tensors

        # Session (GPU memory grows on demand) and saver.
        gpu_options = tf.GPUOptions(allow_growth=True)
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                    max_to_keep=10)
Exemplo n.º 2
0
 def inference(self, image_path, image_index):
     """Run one image through the graph, log its summary and print the result.

     Args:
         image_path: Path handed to ``Data.load_data``.
         image_index: Used as the ``global_step`` of the written summary.
     """
     # Load the image and add a leading batch axis.
     loaded = Data.load_data(image_path=image_path, input_size=self.input_size)
     batch = np.expand_dims(loaded, axis=0)

     fetches = [self.features[-1], self.summary_op]
     result, summary_now = self.sess.run(
         fetches, feed_dict={self.image_placeholder: batch})

     self.summary_writer.add_summary(summary_now, global_step=image_index)
     print(result)
Exemplo n.º 3
0
 def inference(self, image_path, image_index, save_path=None):
     """Predict the segmentation of one image, then show it or save it.

     Args:
         image_path: Path handed to ``Data.load_data``; its base name
             (without extension) names the saved ``.bmp`` file.
         image_index: Used as the ``global_step`` of the written summary.
         save_path: Output directory. When ``None`` the image is displayed
             instead of being written to disk.
     """
     # Load the image and add a leading batch axis.
     loaded = Data.load_data(image_path=image_path, input_size=self.input_size)
     batch = np.expand_dims(loaded, axis=0)

     fetches = [self.pred_segment, self.summary_op]
     pred_segment_r, summary_now = self.sess.run(
         fetches, feed_dict={self.image_placeholder: batch})
     self.summary_writer.add_summary(summary_now, global_step=image_index)

     # Scale the squeezed prediction by 255 and wrap it as an 8-bit image.
     scaled = np.squeeze(pred_segment_r) * 255
     s_image = Image.fromarray(np.asarray(scaled, dtype=np.uint8))

     if save_path is not None:
         Tools.new_dir(save_path)
         gray = s_image.convert("L")
         stem = os.path.splitext(os.path.basename(image_path))[0]
         gray.save("{}/{}.bmp".format(save_path, stem))
         return
     s_image.show()
    def run(self, result_filename, image_filename, where=None, annotation_filename=None, ann_index=0):
        """Run PSPNet on one image and save the input, masks and predictions.

        Six images are written into ``self.save_dir``, each named
        ``result_filename`` plus a suffix: ``data.png``, ``pred.png``,
        ``pred_raw.png``, ``pred_sigmoid.png``, ``mask.bmp`` and ``ann.bmp``.

        Args:
            result_filename: Prefix of every output file name.
            image_filename: Image forwarded to ``Data.load_image``.
            where: Forwarded unchanged to ``Data.load_image``.
            annotation_filename: Forwarded unchanged to ``Data.load_image``.
            ann_index: Forwarded unchanged to ``Data.load_image``.
        """
        # Load the image data.
        final_batch_data, data_raw, gaussian_mask, _ann_data, ann_mask = Data.load_image(
            image_filename, where=where, annotation_filename=annotation_filename,
            ann_index=ann_index, image_size=self.input_size)

        # Network.
        img_placeholder = tf.placeholder(dtype=tf.float32, shape=(None, self.input_size[0], self.input_size[1], 4))
        net = PSPNet({'data': img_placeholder}, is_training=True, num_classes=1, last_pool_size=self.last_pool_size,
                     filter_number=32)

        # Outputs / predictions.
        raw_output_op = net.layers["conv6_n"]
        sigmoid_output_op = tf.sigmoid(raw_output_op)

        # Start the session, load the model and run. The session is local to
        # this call, so it is closed as soon as the outputs are fetched
        # (previously it was never released).
        sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
        try:
            sess.run(tf.global_variables_initializer())
            Tools.restore_if_y(sess, self.log_dir)
            raw_output, sigmoid_output = sess.run([raw_output_op, sigmoid_output_op],
                                                  feed_dict={img_placeholder: final_batch_data})
        finally:
            sess.close()

        def save_gray(values, suffix):
            # Squeeze, cast to uint8, save under the given suffix, then log.
            Image.fromarray(np.asarray(np.squeeze(values), dtype=np.uint8)).save(
                os.path.join(self.save_dir, result_filename + suffix))
            Tools.print_info('over : result save in {}'.format(os.path.join(self.save_dir, result_filename)))

        # Save. Each argument is evaluated right before its own save, in the
        # same order as before.
        save_gray(data_raw, "data.png")
        save_gray(sigmoid_output[0] * 255, "pred.png")
        # NOTE(review): the raw (pre-sigmoid) output is thresholded at 0.5 here
        # - confirm that this is intended rather than a threshold at 0.
        save_gray(np.greater(raw_output[0], 0.5) * 255, "pred_raw.png")
        save_gray(np.greater(sigmoid_output[0], 0.5) * 255, "pred_sigmoid.png")
        save_gray(gaussian_mask * 255, "mask.bmp")
        save_gray(ann_mask * 255, "ann.bmp")
    def __init__(self,
                 batch_size,
                 last_pool_size,
                 input_size,
                 log_dir,
                 data_root_path,
                 train_list,
                 data_path,
                 annotation_path,
                 class_path,
                 model_name="model.ckpt",
                 is_test=False):
        """Build the BAISNet training graph, its summaries, session and saver.

        Args:
            batch_size: Batch size forwarded to the ``Data`` reader.
            last_pool_size: Forwarded to ``BAISNet``. The original note
                requires ``input_size`` to exceed 8x this value.
            input_size: Indexable pair; ``[0]`` and ``[1]`` give the two
                spatial sizes of the image placeholder.
            log_dir: Checkpoint/summary directory; passed through
                ``Tools.new_dir`` and the returned value is kept.
            data_root_path, train_list, data_path, annotation_path,
            class_path: Forwarded unchanged to ``Data``.
            model_name: Checkpoint file name joined onto ``self.log_dir``.
            is_test: Forwarded to ``Data``; also makes ``print_step`` 5
                instead of 25.
        """
        # Parameters related to saving the model.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Parameters related to the data.
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21
        self.num_segment = 4  # decoder channels: other objects, attention, border, background
        self.segment_attention = 1  # index of the attention channel when 4 channels are decoded
        self.attention_module_num = 2  # number of attention modules decoding 2 channels (background, attention)

        # Parameters related to the model: input_size must exceed 8 * last_pool_size.
        self.ratio = 8
        self.last_pool_size = last_pool_size
        self.filter_number = 32

        # Parameters related to training.
        self.learning_rate = 5e-3
        self.num_steps = 500001
        self.print_step = 5 if is_test else 25

        # Data reader.
        self.data_reader = Data(data_root_path=data_root_path,
                                data_list=train_list,
                                data_path=data_path,
                                annotation_path=annotation_path,
                                class_path=class_path,
                                batch_size=self.batch_size,
                                image_size=self.input_size,
                                is_test=is_test,
                                has_255=True)

        # Input placeholders. The image has 4 channels; labels are fed at
        # 1/ratio of the input resolution.
        self.image_placeholder = tf.placeholder(dtype=tf.float32,
                                                shape=(None,
                                                       self.input_size[0],
                                                       self.input_size[1], 4))
        self.label_segment_placeholder = tf.placeholder(
            dtype=tf.int32,
            shape=(None, self.input_size[0] // self.ratio,
                   self.input_size[1] // self.ratio, 1))
        self.label_attention_placeholder = tf.placeholder(
            dtype=tf.int32,
            shape=(None, self.input_size[0] // self.ratio,
                   self.input_size[1] // self.ratio, 1))
        self.label_classes_placeholder = tf.placeholder(dtype=tf.int32,
                                                        shape=(None, ))

        # Network.
        self.net = BAISNet(self.image_placeholder,
                           is_training=True,
                           num_classes=self.num_classes,
                           num_segment=self.num_segment,
                           segment_attention=self.segment_attention,
                           last_pool_size=self.last_pool_size,
                           filter_number=self.filter_number,
                           attention_module_num=self.attention_module_num)

        self.segments, self.attentions, self.classes = self.net.build()
        # The first entry of each list is used as the final prediction.
        self.final_segment_logit = self.segments[0]
        self.final_class_logit = self.classes[0]

        # Predictions.
        self.pred_segment = tf.cast(
            tf.expand_dims(tf.argmax(self.final_segment_logit, axis=-1),
                           axis=-1), tf.int32)
        self.pred_classes = tf.cast(tf.argmax(self.final_class_logit, axis=-1),
                                    tf.int32)

        # Loss: labels are resized to the spatial size of the final logits.
        self.label_batch = tf.image.resize_nearest_neighbor(
            self.label_segment_placeholder,
            tf.stack(self.final_segment_logit.get_shape()[1:3]))
        self.label_attention_batch = tf.image.resize_nearest_neighbor(
            self.label_attention_placeholder,
            tf.stack(self.final_segment_logit.get_shape()[1:3]))
        self.loss, self.loss_segment_all, self.loss_class_all, self.loss_segments, self.loss_classes = self.cal_loss(
            self.segments,
            self.classes,
            self.label_batch,
            self.label_attention_batch,
            self.label_classes_placeholder,
            self.num_segment,
            attention_module_num=self.attention_module_num)

        # Accuracy on the current batch.
        # NOTE(review): tcm is presumably tf.contrib.metrics - confirm at the import site.
        self.accuracy_segment = tcm.accuracy(self.pred_segment,
                                             self.label_segment_placeholder)
        self.accuracy_classes = tcm.accuracy(self.pred_classes,
                                             self.label_classes_placeholder)

        with tf.name_scope("train"):
            # Learning-rate schedule: base_lr * (1 - step / num_steps) ** 0.9.
            # self.learning_rate is rebound from the float base rate to the
            # scheduled tensor.
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            self.learning_rate = tf.scalar_mul(
                tf.constant(self.learning_rate),
                tf.pow((1 - self.step_ph / self.num_steps), 0.9))
            self.train_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss)

            # Separate op that only updates variables whose name contains
            # 'attention' (this also covers names containing 'class_attention').
            attention_trainable = [
                v for v in tf.trainable_variables() if 'attention' in v.name
            ]
            print(len(attention_trainable))
            self.train_attention_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss,
                                             var_list=attention_trainable)

        # summary 1
        with tf.name_scope("loss"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.scalar("loss_segment", self.loss_segment_all)
            tf.summary.scalar("loss_class", self.loss_class_all)
            for loss_segment_index, loss_segment in enumerate(
                    self.loss_segments):
                tf.summary.scalar("loss_segment_{}".format(loss_segment_index),
                                  loss_segment)
            for loss_class_index, loss_class in enumerate(self.loss_classes):
                tf.summary.scalar("loss_class_{}".format(loss_class_index),
                                  loss_class)

        with tf.name_scope("accuracy"):
            tf.summary.scalar("accuracy_segment", self.accuracy_segment)
            tf.summary.scalar("accuracy_classes", self.accuracy_classes)

        with tf.name_scope("label"):
            # Channels 0-2 are logged as the image and channel 3 as the mask.
            split = tf.split(self.image_placeholder,
                             num_or_size_splits=4,
                             axis=3)
            tf.summary.image("0-mask", split[3])
            tf.summary.image("1-image", tf.concat(split[0:3], axis=3))
            tf.summary.image(
                "2-label",
                tf.cast(self.label_segment_placeholder * 85, dtype=tf.uint8))
            tf.summary.image(
                "3-attention",
                tf.cast(self.label_attention_placeholder * 255,
                        dtype=tf.uint8))

        with tf.name_scope("predict"):
            tf.summary.image("predict",
                             tf.cast(self.pred_segment * 85, dtype=tf.uint8))

        with tf.name_scope("attention"):
            for attention_index, attention in enumerate(self.attentions):
                tf.summary.image("{}-attention".format(attention_index),
                                 attention)

        with tf.name_scope("sigmoid"):
            # cal_loss treats the LAST attention_module_num outputs as the
            # 2-channel ones, so the same boundary is used here. The former
            # test (index < attention_module_num) agrees with it only when
            # len(segments) == 2 * attention_module_num.
            full_decoder_count = len(self.segments) - self.attention_module_num
            for segment_index, segment in enumerate(self.segments):
                if segment_index < full_decoder_count:
                    split = tf.split(segment,
                                     num_or_size_splits=self.num_segment,
                                     axis=3)
                    tf.summary.image("{}-other".format(segment_index),
                                     split[0])
                    tf.summary.image("{}-attention".format(segment_index),
                                     split[1])
                    tf.summary.image("{}-border".format(segment_index),
                                     split[2])
                    tf.summary.image("{}-background".format(segment_index),
                                     split[-1])
                else:
                    split = tf.split(segment, num_or_size_splits=2, axis=3)
                    tf.summary.image("{}-background".format(segment_index),
                                     split[0])
                    tf.summary.image("{}-attention".format(segment_index),
                                     split[1])

        self.summary_op = tf.summary.merge_all()

        # Session (GPU memory grows on demand) and saver.
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                    max_to_keep=10)

        # summary 2
        self.summary_writer = tf.summary.FileWriter(self.log_dir,
                                                    self.sess.graph)
class Train(object):
    """Trainer for BAISNet: builds the graph, the summaries and the train loop."""

    def __init__(self,
                 batch_size,
                 last_pool_size,
                 input_size,
                 log_dir,
                 data_root_path,
                 train_list,
                 data_path,
                 annotation_path,
                 class_path,
                 model_name="model.ckpt",
                 is_test=False):
        """Build the BAISNet training graph, its summaries, session and saver.

        Args:
            batch_size: Batch size forwarded to the ``Data`` reader.
            last_pool_size: Forwarded to ``BAISNet``. The original note
                requires ``input_size`` to exceed 8x this value.
            input_size: Indexable pair; ``[0]`` and ``[1]`` give the two
                spatial sizes of the image placeholder.
            log_dir: Checkpoint/summary directory; passed through
                ``Tools.new_dir`` and the returned value is kept.
            data_root_path, train_list, data_path, annotation_path,
            class_path: Forwarded unchanged to ``Data``.
            model_name: Checkpoint file name joined onto ``self.log_dir``.
            is_test: Forwarded to ``Data``; also makes ``print_step`` 5
                instead of 25.
        """
        # Parameters related to saving the model.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Parameters related to the data.
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21
        self.num_segment = 4  # decoder channels: other objects, attention, border, background
        self.segment_attention = 1  # index of the attention channel when 4 channels are decoded
        self.attention_module_num = 2  # number of attention modules decoding 2 channels (background, attention)

        # Parameters related to the model: input_size must exceed 8 * last_pool_size.
        self.ratio = 8
        self.last_pool_size = last_pool_size
        self.filter_number = 32

        # Parameters related to training.
        self.learning_rate = 5e-3
        self.num_steps = 500001
        self.print_step = 5 if is_test else 25

        # Data reader.
        self.data_reader = Data(data_root_path=data_root_path,
                                data_list=train_list,
                                data_path=data_path,
                                annotation_path=annotation_path,
                                class_path=class_path,
                                batch_size=self.batch_size,
                                image_size=self.input_size,
                                is_test=is_test,
                                has_255=True)

        # Input placeholders. The image has 4 channels; labels are fed at
        # 1/ratio of the input resolution.
        self.image_placeholder = tf.placeholder(dtype=tf.float32,
                                                shape=(None,
                                                       self.input_size[0],
                                                       self.input_size[1], 4))
        self.label_segment_placeholder = tf.placeholder(
            dtype=tf.int32,
            shape=(None, self.input_size[0] // self.ratio,
                   self.input_size[1] // self.ratio, 1))
        self.label_attention_placeholder = tf.placeholder(
            dtype=tf.int32,
            shape=(None, self.input_size[0] // self.ratio,
                   self.input_size[1] // self.ratio, 1))
        self.label_classes_placeholder = tf.placeholder(dtype=tf.int32,
                                                        shape=(None, ))

        # Network.
        self.net = BAISNet(self.image_placeholder,
                           is_training=True,
                           num_classes=self.num_classes,
                           num_segment=self.num_segment,
                           segment_attention=self.segment_attention,
                           last_pool_size=self.last_pool_size,
                           filter_number=self.filter_number,
                           attention_module_num=self.attention_module_num)

        self.segments, self.attentions, self.classes = self.net.build()
        # The first entry of each list is used as the final prediction.
        self.final_segment_logit = self.segments[0]
        self.final_class_logit = self.classes[0]

        # Predictions.
        self.pred_segment = tf.cast(
            tf.expand_dims(tf.argmax(self.final_segment_logit, axis=-1),
                           axis=-1), tf.int32)
        self.pred_classes = tf.cast(tf.argmax(self.final_class_logit, axis=-1),
                                    tf.int32)

        # Loss: labels are resized to the spatial size of the final logits.
        self.label_batch = tf.image.resize_nearest_neighbor(
            self.label_segment_placeholder,
            tf.stack(self.final_segment_logit.get_shape()[1:3]))
        self.label_attention_batch = tf.image.resize_nearest_neighbor(
            self.label_attention_placeholder,
            tf.stack(self.final_segment_logit.get_shape()[1:3]))
        self.loss, self.loss_segment_all, self.loss_class_all, self.loss_segments, self.loss_classes = self.cal_loss(
            self.segments,
            self.classes,
            self.label_batch,
            self.label_attention_batch,
            self.label_classes_placeholder,
            self.num_segment,
            attention_module_num=self.attention_module_num)

        # Accuracy on the current batch.
        # NOTE(review): tcm is presumably tf.contrib.metrics - confirm at the import site.
        self.accuracy_segment = tcm.accuracy(self.pred_segment,
                                             self.label_segment_placeholder)
        self.accuracy_classes = tcm.accuracy(self.pred_classes,
                                             self.label_classes_placeholder)

        with tf.name_scope("train"):
            # Learning-rate schedule: base_lr * (1 - step / num_steps) ** 0.9.
            # self.learning_rate is rebound from the float base rate to the
            # scheduled tensor.
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            self.learning_rate = tf.scalar_mul(
                tf.constant(self.learning_rate),
                tf.pow((1 - self.step_ph / self.num_steps), 0.9))
            self.train_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss)

            # Separate op that only updates variables whose name contains
            # 'attention' (this also covers names containing 'class_attention').
            attention_trainable = [
                v for v in tf.trainable_variables() if 'attention' in v.name
            ]
            print(len(attention_trainable))
            self.train_attention_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss,
                                             var_list=attention_trainable)

        # summary 1
        with tf.name_scope("loss"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.scalar("loss_segment", self.loss_segment_all)
            tf.summary.scalar("loss_class", self.loss_class_all)
            for loss_segment_index, loss_segment in enumerate(
                    self.loss_segments):
                tf.summary.scalar("loss_segment_{}".format(loss_segment_index),
                                  loss_segment)
            for loss_class_index, loss_class in enumerate(self.loss_classes):
                tf.summary.scalar("loss_class_{}".format(loss_class_index),
                                  loss_class)

        with tf.name_scope("accuracy"):
            tf.summary.scalar("accuracy_segment", self.accuracy_segment)
            tf.summary.scalar("accuracy_classes", self.accuracy_classes)

        with tf.name_scope("label"):
            # Channels 0-2 are logged as the image and channel 3 as the mask.
            split = tf.split(self.image_placeholder,
                             num_or_size_splits=4,
                             axis=3)
            tf.summary.image("0-mask", split[3])
            tf.summary.image("1-image", tf.concat(split[0:3], axis=3))
            tf.summary.image(
                "2-label",
                tf.cast(self.label_segment_placeholder * 85, dtype=tf.uint8))
            tf.summary.image(
                "3-attention",
                tf.cast(self.label_attention_placeholder * 255,
                        dtype=tf.uint8))

        with tf.name_scope("predict"):
            tf.summary.image("predict",
                             tf.cast(self.pred_segment * 85, dtype=tf.uint8))

        with tf.name_scope("attention"):
            for attention_index, attention in enumerate(self.attentions):
                tf.summary.image("{}-attention".format(attention_index),
                                 attention)

        with tf.name_scope("sigmoid"):
            # cal_loss treats the LAST attention_module_num outputs as the
            # 2-channel ones, so the same boundary is used here. The former
            # test (index < attention_module_num) agrees with it only when
            # len(segments) == 2 * attention_module_num.
            full_decoder_count = len(self.segments) - self.attention_module_num
            for segment_index, segment in enumerate(self.segments):
                if segment_index < full_decoder_count:
                    split = tf.split(segment,
                                     num_or_size_splits=self.num_segment,
                                     axis=3)
                    tf.summary.image("{}-other".format(segment_index),
                                     split[0])
                    tf.summary.image("{}-attention".format(segment_index),
                                     split[1])
                    tf.summary.image("{}-border".format(segment_index),
                                     split[2])
                    tf.summary.image("{}-background".format(segment_index),
                                     split[-1])
                else:
                    split = tf.split(segment, num_or_size_splits=2, axis=3)
                    tf.summary.image("{}-background".format(segment_index),
                                     split[0])
                    tf.summary.image("{}-attention".format(segment_index),
                                     split[1])

        self.summary_op = tf.summary.merge_all()

        # Session (GPU memory grows on demand) and saver.
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                    max_to_keep=10)

        # summary 2
        self.summary_writer = tf.summary.FileWriter(self.log_dir,
                                                    self.sess.graph)

    @staticmethod
    def cal_loss(segments, classes, label_segment, label_attention,
                 label_classes, num_segment, attention_module_num):
        """Combine the per-output segmentation and classification losses.

        The first ``len(segments) - attention_module_num`` segment outputs
        are scored with sparse softmax cross-entropy over ``num_segment``
        channels against ``label_segment``. The remaining outputs are split
        into 2 channels and only the second one is scored against
        ``label_attention`` with weighted sigmoid cross-entropy
        (``pos_weight=3``), scaled by 2.

        Args:
            segments: Sequence of segmentation logit tensors.
            classes: Sequence of classification logit tensors.
            label_segment: Segmentation labels; flattened here.
            label_attention: Attention labels; flattened here.
            label_classes: Class labels shared by every classification output.
            num_segment: Channel count of the full segmentation outputs.
            attention_module_num: How many trailing entries of ``segments``
                are 2-channel attention outputs.

        Returns:
            Tuple ``(loss, loss_segment_all, loss_class_all, loss_segments,
            loss_classes)`` where ``loss`` is
            ``loss_segment_all + 0.1 * loss_class_all`` and the two ``*_all``
            values are the means of the corresponding lists.
        """
        label_segment = tf.reshape(label_segment, [
            -1,
        ])
        label_attention = tf.reshape(label_attention, [
            -1,
        ])

        loss_segments = []
        for segment_index, segment in enumerate(segments):
            if segment_index < len(segments) - attention_module_num:
                now_loss_segment = tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits(
                        labels=label_segment,
                        logits=tf.reshape(segment, [-1, num_segment])))
                loss_segments.append(now_loss_segment)
            else:
                # Keep only the second of the two channels as the attention logit.
                segment = tf.split(segment, num_or_size_splits=2, axis=3)[1]
                now_loss_segment = tf.reduce_mean(
                    tf.nn.weighted_cross_entropy_with_logits(
                        targets=tf.cast(label_attention, dtype=tf.float32),
                        logits=tf.reshape(segment, [
                            -1,
                        ]),
                        pos_weight=3)) * 2
                loss_segments.append(now_loss_segment)

        loss_classes = []
        for class_one in classes:
            loss_classes.append(
                tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits(
                        labels=label_classes, logits=class_one)))

        loss_segment_all = tf.add_n(loss_segments) / len(loss_segments)
        loss_class_all = tf.add_n(loss_classes) / len(loss_classes)
        # Total loss.
        loss = loss_segment_all + 0.1 * loss_class_all
        return loss, loss_segment_all, loss_class_all, loss_segments, loss_classes

    def train(self, save_pred_freq, begin_step=0):
        """Run the training loop from ``begin_step`` up to ``self.num_steps``.

        Every step runs ``self.train_op`` on the next training batch and logs
        one progress line. Every ``self.print_step`` steps the merged summary
        is also evaluated and written. Every ``save_pred_freq`` steps a
        checkpoint is saved.

        Args:
            save_pred_freq: Checkpoint interval in steps (must be non-zero).
            begin_step: First step index; it also feeds the learning-rate
                schedule through ``self.step_ph``.
        """
        # Load the model (delegated to Tools.restore_if_y).
        Tools.restore_if_y(self.sess, self.log_dir)

        # Use self.train_attention_op here instead to update only the
        # attention variables.
        train_op = self.train_op

        # Only the values that are actually reported are fetched.
        fetches = [
            self.accuracy_segment, self.accuracy_classes, train_op,
            self.learning_rate, self.loss_segment_all, self.loss_class_all,
            self.loss, self.pred_classes
        ]

        for step in range(begin_step, self.num_steps):
            start_time = time.time()

            (final_batch_data, final_batch_ann, final_batch_ann_attention,
             final_batch_class, _batch_data,
             _batch_mask) = self.data_reader.next_batch_train()

            feed_dict = {
                self.step_ph: step,
                self.image_placeholder: final_batch_data,
                self.label_segment_placeholder: final_batch_ann,
                self.label_attention_placeholder: final_batch_ann_attention,
                self.label_classes_placeholder: final_batch_class
            }

            # summary 3: the merged summary is evaluated in the same run call
            # as the train op, but only every print_step steps.
            write_summary = step % self.print_step == 0
            step_fetches = (fetches + [self.summary_op]
                            if write_summary else fetches)
            results = self.sess.run(step_fetches, feed_dict=feed_dict)

            (accuracy_segment_r, accuracy_classes_r, _, learning_rate_r,
             loss_segment_r, loss_classes_r, loss_r,
             pred_classes_r) = results[:len(fetches)]
            if write_summary:
                self.summary_writer.add_summary(results[-1], global_step=step)

            if step % save_pred_freq == 0:
                self.saver.save(self.sess,
                                self.checkpoint_path,
                                global_step=step)
                Tools.print_info('The checkpoint has been created.')

            duration = time.time() - start_time

            Tools.print_info(
                'step {:d} loss={:.3f} seg={:.3f} class={:.3f} acc={:.3f} acc_class={:.3f}'
                ' lr={:.6f} ({:.3f} s/step) {} {}'.format(
                    step, loss_r, loss_segment_r, loss_classes_r,
                    accuracy_segment_r, accuracy_classes_r, learning_rate_r,
                    duration, list(final_batch_class), list(pred_classes_r)))
    def __init__(self, batch_size, last_pool_size, input_size, log_dir,
                 data_root_path, train_list, data_path, annotation_path, class_path,
                 model_name="model.ckpt", is_test=False):
        """Create the data reader, the training graph, summaries and session.

        Args:
            batch_size: Batch size forwarded to the ``Data`` reader.
            last_pool_size: Stored for the network builder. The original
                note requires ``input_size`` to exceed 8x this value.
            input_size: Image size forwarded to ``Data`` as ``image_size``.
            log_dir: Checkpoint/summary directory; passed through
                ``Tools.new_dir`` and the returned value is kept.
            data_root_path, train_list, data_path, annotation_path,
            class_path: Forwarded unchanged to ``Data``.
            model_name: Checkpoint file name joined onto ``self.log_dir``.
            is_test: Forwarded to ``Data``; also makes ``print_step`` 1
                instead of 25.
        """
        # Checkpoint bookkeeping.
        checkpoint_dir = Tools.new_dir(log_dir)
        self.log_dir = checkpoint_dir
        self.model_name = model_name
        self.checkpoint_path = os.path.join(checkpoint_dir, model_name)

        # Data-related settings.
        self.batch_size, self.input_size = batch_size, input_size
        self.num_classes = 21
        self.has_255 = True  # whether the border is predicted as its own channel
        if self.has_255:
            self.num_segment = 4
        else:
            self.num_segment = 3

        # Model-related settings (input_size must exceed ratio * last_pool_size).
        self.last_pool_size = last_pool_size
        self.ratio, self.filter_number = 8, 32

        # Training-related settings.
        self.learning_rate = 5e-3
        self.num_steps = 500001
        self.print_step = 1 if is_test else 25

        # Data reader.
        reader_options = dict(
            data_root_path=data_root_path,
            data_list=train_list,
            data_path=data_path,
            annotation_path=annotation_path,
            class_path=class_path,
            batch_size=self.batch_size,
            image_size=self.input_size,
            is_test=is_test,
            has_255=self.has_255,
        )
        self.data_reader = Data(**reader_options)

        # Network: build_net() is expected to hand back exactly sixteen
        # values. self.learning_rate is rebound to the last of them.
        graph_parts = self.build_net()
        (self.image_placeholder, self.label_segment_placeholder, self.label_classes_placeholder,
         self.raw_output_segment, self.raw_output_classes, self.pred_segment, self.pred_classes,
         self.loss_segment, self.loss_classes, self.loss, self.accuracy_segment, self.accuracy_classes,
         self.step_ph, self.train_op, self.train_classes_op, self.learning_rate) = graph_parts

        # summary 1: scalars.
        scalar_summaries = (
            ("loss", self.loss),
            ("loss_segment", self.loss_segment),
            ("loss_classes", self.loss_classes),
            ("accuracy_segment", self.accuracy_segment),
            ("accuracy_classes", self.accuracy_classes),
        )
        for tag, tensor in scalar_summaries:
            tf.summary.scalar(tag, tensor)

        # Input image: channels 0-2 are logged as the image, channel 3 as the mask.
        image_channels = tf.split(self.image_placeholder, num_or_size_splits=4, axis=3)
        tf.summary.image("0-mask", image_channels[3])
        tf.summary.image("1-image", tf.concat(image_channels[:3], axis=3))

        # Label ids are multiplied by this factor before the uint8 cast.
        label_scale = 85 if self.has_255 else 127
        tf.summary.image("2-label", tf.cast(self.label_segment_placeholder * label_scale, dtype=tf.uint8))

        # Raw segmentation output, one image per channel.
        output_channels = tf.split(self.raw_output_segment, num_or_size_splits=self.num_segment, axis=3)
        tf.summary.image("3-attention", output_channels[1])
        tf.summary.image("4-other class", output_channels[0])
        tf.summary.image("5-background", output_channels[-1])
        if self.has_255:
            tf.summary.image("5-border", output_channels[2])
        tf.summary.image("6-pred_segment", tf.cast(self.pred_segment * label_scale, dtype=tf.uint8))

        self.summary_op = tf.summary.merge_all()

        # Session (GPU memory grows on demand) and saver.
        gpu_options = tf.GPUOptions(allow_growth=True)
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

        # summary 2
        self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
class Train(object):

    def __init__(self, batch_size, last_pool_size, input_size, log_dir,
                 data_root_path, train_list, data_path, annotation_path, class_path,
                 model_name="model.ckpt", is_test=False):
        """Set up the data reader, graph, summaries, session and saver.

        Args:
            batch_size: batch size handed to the ``Data`` reader.
            last_pool_size: forwarded to the network built in ``build_net``.
            input_size: indexable pair; ``[0]`` and ``[1]`` size the input placeholder.
            log_dir: directory for checkpoints and summaries (passed through
                ``Tools.new_dir``).
            data_root_path, train_list, data_path, annotation_path, class_path:
                forwarded unchanged to ``Data``.
            model_name: checkpoint file name joined onto ``log_dir``.
            is_test: forwarded to ``Data``; also switches ``print_step`` to 1.
        """
        # Parameters related to saving the model.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Parameters related to the data.
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21
        self.has_255 = True  # whether the border is predicted as well
        if self.has_255:
            self.num_segment = 4
        else:
            self.num_segment = 3

        # Parameters related to the model: input_size must be larger than
        # 8 times last_pool_size.
        self.ratio = 8
        self.last_pool_size = last_pool_size
        self.filter_number = 32

        # Parameters related to training. ``learning_rate`` starts as a float
        # and is replaced below by the tensor returned from ``build_net``.
        self.learning_rate = 5e-3
        self.num_steps = 500001
        self.print_step = 1 if is_test else 25

        # Data reader.
        self.data_reader = Data(
            data_root_path=data_root_path, data_list=train_list,
            data_path=data_path, annotation_path=annotation_path, class_path=class_path,
            batch_size=self.batch_size, image_size=self.input_size,
            is_test=is_test, has_255=self.has_255)

        # Network.
        graph_handles = self.build_net()
        (self.image_placeholder, self.label_segment_placeholder, self.label_classes_placeholder,
         self.raw_output_segment, self.raw_output_classes, self.pred_segment, self.pred_classes,
         self.loss_segment, self.loss_classes, self.loss, self.accuracy_segment, self.accuracy_classes,
         self.step_ph, self.train_op, self.train_classes_op, self.learning_rate) = graph_handles

        # Summary 1: scalar curves, registered in a fixed order.
        scalar_summaries = (
            ("loss", self.loss),
            ("loss_segment", self.loss_segment),
            ("loss_classes", self.loss_classes),
            ("accuracy_segment", self.accuracy_segment),
            ("accuracy_classes", self.accuracy_classes),
        )
        for tag, tensor in scalar_summaries:
            tf.summary.scalar(tag, tensor)

        # Factor that spreads the small integer label ids over the uint8 range.
        gray_step = 85 if self.has_255 else 127

        # The 4-channel input is logged as channels 0-2 ("1-image") and channel 3 ("0-mask").
        input_channels = tf.split(self.image_placeholder, num_or_size_splits=4, axis=3)
        tf.summary.image("0-mask", input_channels[3])
        tf.summary.image("1-image", tf.concat(input_channels[0: 3], axis=3))
        tf.summary.image("2-label", tf.cast(self.label_segment_placeholder * gray_step, dtype=tf.uint8))

        # One image per channel of the raw segmentation output.
        output_channels = tf.split(self.raw_output_segment, num_or_size_splits=self.num_segment, axis=3)
        tf.summary.image("3-attention", output_channels[1])
        tf.summary.image("4-other class", output_channels[0])
        tf.summary.image("5-background", output_channels[-1])
        if self.has_255:
            tf.summary.image("5-border", output_channels[2])
        tf.summary.image("6-pred_segment", tf.cast(self.pred_segment * gray_step, dtype=tf.uint8))

        self.summary_op = tf.summary.merge_all()

        # Session and saver.
        session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
        self.sess = tf.Session(config=session_config)
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

        # Summary 2: writer that also records the graph.
        self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)

    def build_net(self):
        """Create the placeholders, PSPNet graph, losses, metrics and train ops.

        Reads ``self.learning_rate`` as the base learning rate, so it must be
        called while that attribute still holds a plain number.

        Returns:
            A 16-tuple: (image_placeholder, label_segment_placeholder,
            label_classes_placeholder, raw_output_segment, raw_output_classes,
            pred_segment, pred_classes, loss_segment, loss_classes, loss,
            accuracy_segment, accuracy_classes, step_ph, train_op,
            train_classes_op, learning_rate).
        """
        height, width = self.input_size[0], self.input_size[1]

        # Inputs: 4-channel image, down-scaled segmentation label, class label.
        image_placeholder = tf.placeholder(dtype=tf.float32, shape=(None, height, width, 4))
        label_segment_placeholder = tf.placeholder(
            dtype=tf.int32, shape=(None, height // self.ratio, width // self.ratio, 1))
        label_classes_placeholder = tf.placeholder(dtype=tf.int32, shape=(None,))

        # Network.
        net = PSPNet({'data': image_placeholder}, is_training=True, num_classes=self.num_classes,
                     num_segment=self.num_segment, last_pool_size=self.last_pool_size,
                     filter_number=self.filter_number)
        raw_output_segment = net.layers['conv6_n_4']
        raw_output_classes = net.layers['class_attention_fc']

        # Predictions: per-pixel logits flattened for the loss, argmax for the outputs.
        flat_segment_logits = tf.reshape(raw_output_segment, [-1, self.num_segment])
        segment_ids = tf.argmax(raw_output_segment, axis=-1)
        pred_segment = tf.cast(tf.expand_dims(segment_ids, axis=-1), tf.int32)
        pred_classes = tf.cast(tf.argmax(raw_output_classes, axis=-1), tf.int32)

        # Labels resized to the spatial size of the segmentation output, then flattened.
        output_hw = tf.stack(raw_output_segment.get_shape()[1:3])
        resized_labels = tf.image.resize_nearest_neighbor(label_segment_placeholder, output_hw)
        flat_labels = tf.reshape(resized_labels, [-1, ])

        # Accuracy of the current batch.
        accuracy_segment = tcm.accuracy(pred_segment, label_segment_placeholder)
        accuracy_classes = tcm.accuracy(pred_classes, label_classes_placeholder)

        # Segmentation loss.
        segment_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=flat_labels,
                                                                      logits=flat_segment_logits)
        loss_segment = tf.reduce_mean(segment_xent)

        # Classification loss.
        classes_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label_classes_placeholder,
                                                                      logits=raw_output_classes)
        loss_classes = tf.reduce_mean(classes_xent)

        # Total loss: the classification term is weighted by 0.1.
        loss = tf.add_n([loss_segment, 0.1 * loss_classes])

        # Learning-rate schedule: base_lr * (1 - step / num_steps) ** 0.9.
        step_ph = tf.placeholder(dtype=tf.float32, shape=())
        base_lr = tf.constant(self.learning_rate)
        decay = tf.pow((1 - step_ph / self.num_steps), 0.9)
        learning_rate = tf.scalar_mul(base_lr, decay)
        train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

        # Separate op that only updates variables whose name contains 'class_attention'.
        classes_trainable = [var for var in tf.trainable_variables() if 'class_attention' in var.name]
        classes_optimizer = tf.train.GradientDescentOptimizer(learning_rate)
        train_classes_op = classes_optimizer.minimize(loss, var_list=classes_trainable)

        return (image_placeholder, label_segment_placeholder, label_classes_placeholder,
                raw_output_segment, raw_output_classes, pred_segment, pred_classes,
                loss_segment, loss_classes, loss, accuracy_segment, accuracy_classes,
                step_ph, train_op, train_classes_op, learning_rate)

    def train(self, save_pred_freq, begin_step=0):
        """Run the training loop from ``begin_step`` up to ``self.num_steps``.

        Every step runs ``self.train_op`` on one batch from the data reader and
        prints the losses and accuracies. Summaries are written whenever
        ``step % self.print_step == 0`` and a checkpoint is saved whenever
        ``step % save_pred_freq == 0``.

        Args:
            save_pred_freq: checkpoint period in steps (must be non-zero).
            begin_step: first step index; the step is also fed to the
                learning-rate schedule through ``self.step_ph``.
        """
        # Load an existing model if Tools.restore_if_y finds one.
        Tools.restore_if_y(self.sess, self.log_dir)

        # Use self.train_classes_op here instead to run the alternative train op.
        train_op = self.train_op

        # Only tensors that are actually reported are fetched; the raw
        # segmentation/class outputs were previously copied back every step
        # without ever being read.
        fetches = [self.accuracy_segment, self.accuracy_classes,
                   train_op, self.learning_rate,
                   self.loss_segment, self.loss_classes, self.loss,
                   self.pred_classes]
        fetch_count = len(fetches)

        for step in range(begin_step, self.num_steps):
            start_time = time.time()

            # The last two values returned by the reader are not needed for training.
            final_batch_data, final_batch_ann, final_batch_class, _, _ = \
                self.data_reader.next_batch_train()

            feed_dict = {self.step_ph: step, self.image_placeholder: final_batch_data,
                         self.label_segment_placeholder: final_batch_ann,
                         self.label_classes_placeholder: final_batch_class}

            # summary 3: the merged summary is evaluated in the same run call
            # as the training step, but only on reporting steps.
            write_summary = step % self.print_step == 0
            run_fetches = fetches + [self.summary_op] if write_summary else fetches
            results = self.sess.run(run_fetches, feed_dict=feed_dict)

            (accuracy_segment_r, accuracy_classes_r, _, learning_rate_r,
             loss_segment_r, loss_classes_r, loss_r, pred_classes_r) = results[:fetch_count]
            if write_summary:
                self.summary_writer.add_summary(results[fetch_count], global_step=step)

            if step % save_pred_freq == 0:
                self.saver.save(self.sess, self.checkpoint_path, global_step=step)
                Tools.print_info('The checkpoint has been created.')

            duration = time.time() - start_time

            Tools.print_info(
                'step {:d} loss={:.3f} seg={:.3f} class={:.3f} acc={:.3f} acc_class={:.3f}'
                ' lr={:.6f} ({:.3f} s/step) {} {}'.format(
                    step, loss_r, loss_segment_r, loss_classes_r, accuracy_segment_r, accuracy_classes_r,
                    learning_rate_r, duration, list(final_batch_class), list(pred_classes_r)))

    pass
    def __init__(self,
                 batch_size,
                 input_size,
                 log_dir,
                 data_root_path,
                 train_list,
                 data_path,
                 annotation_path,
                 class_path,
                 model_name="model.ckpt",
                 pretrain=None,
                 is_test=False):
        """Set up the data reader, BAISNet graph, summaries, session and saver.

        Args:
            batch_size: batch size for the ``Data`` reader; also the number of
                pieces the predictions are split into for image summaries, so
                fed batches are expected to have exactly this size.
            input_size: indexable pair; ``[0]`` and ``[1]`` size the placeholders.
            log_dir: directory for checkpoints and summaries (passed through
                ``Tools.new_dir``).
            data_root_path, train_list, data_path, annotation_path, class_path:
                forwarded unchanged to ``Data``.
            model_name: checkpoint file name joined onto ``log_dir``.
            pretrain: stored on ``self.pretrain``; not used in this method.
            is_test: forwarded to ``Data``; also shortens ``print_step`` and
                ``cal_step``.
        """
        # Parameters related to saving the model.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)
        self.pretrain = pretrain

        # Parameters related to the data.
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21

        # Parameters related to training. ``learning_rate`` starts as a float
        # and is replaced by the scheduled tensor further down.
        self.learning_rate = 5e-3
        self.num_steps = 100001
        self.print_step = 10 if is_test else 100
        self.cal_step = 100 if is_test else 1000

        # Data reader.
        self.data_reader = Data(data_root_path=data_root_path,
                                data_list=train_list,
                                data_path=data_path,
                                annotation_path=annotation_path,
                                class_path=class_path,
                                batch_size=self.batch_size,
                                image_size=self.input_size,
                                is_test=is_test)

        # Input placeholders.
        self.image_placeholder = tf.placeholder(tf.float32,
                                                shape=(None,
                                                       self.input_size[0],
                                                       self.input_size[1], 3))
        self.label_seg_placeholder = tf.placeholder(
            tf.int32, shape=(None, self.input_size[0], self.input_size[1], 1))

        # Network.
        self.net = BAISNet(self.image_placeholder,
                           True,
                           num_classes=self.num_classes)
        self.segments, self.features = self.net.build()

        # Loss: cal_loss is unpacked as (total, combined segment loss, per-output losses).
        self.loss, self.loss_segment_all, self.loss_segments = self.cal_loss(
            self.segments, self.label_seg_placeholder)

        with tf.name_scope("train"):
            # Learning-rate schedule: base_lr * (1 - step / num_steps) ** 0.8.
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            self.learning_rate = tf.scalar_mul(
                tf.constant(self.learning_rate),
                tf.pow((1 - self.step_ph / self.num_steps), 0.8))
            self.train_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss)

            # Separate op that only updates variables whose name contains 'segment_side'.
            segment_side_trainable = [
                v for v in tf.trainable_variables() if 'segment_side' in v.name
            ]
            print(len(segment_side_trainable))
            self.train_segment_side_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss,
                                             var_list=segment_side_trainable)

        # Summary 1: scalar curves.
        with tf.name_scope("loss"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.scalar("loss_segment", self.loss_segment_all)
            for loss_segment_index, loss_segment in enumerate(
                    self.loss_segments):
                tf.summary.scalar("loss_segment_{}".format(loss_segment_index),
                                  loss_segment)

        with tf.name_scope("label"):
            tf.summary.image("1-image", self.image_placeholder)
            tf.summary.image(
                "2-segment",
                tf.cast(self.label_seg_placeholder * 255 // self.num_classes,
                        dtype=tf.uint8))

        with tf.name_scope("result"):
            # At most three samples per output are logged.
            shown = min(3, self.batch_size)
            for segment_index, segment in enumerate(self.segments):
                # Keep the class ids in int32 while scaling: multiplying uint8
                # ids by 255 wraps modulo 256 and produces wrong gray levels.
                # Cast to uint8 only after the scaled value fits into 0..255,
                # exactly as done for the label image above.
                class_ids = tf.cast(tf.argmax(segment, axis=-1), dtype=tf.int32)
                per_sample = tf.split(class_ids,
                                      num_or_size_splits=self.batch_size,
                                      axis=0)
                for i in range(shown):
                    gray = per_sample[i] * 255 // self.num_classes
                    tf.summary.image(
                        "predict-{}-{}".format(segment_index, i),
                        tf.cast(tf.expand_dims(gray, axis=-1), dtype=tf.uint8))

        self.summary_op = tf.summary.merge_all()

        # Session and saver.
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                    max_to_keep=10)

        # Summary 2: writer that also records the graph.
        self.summary_writer = tf.summary.FileWriter(self.log_dir,
                                                    self.sess.graph)
class Train(object):

    def __init__(self, batch_size, input_size, log_dir, data_root_path, train_list, data_path,
                 annotation_path, class_path, model_name="model.ckpt", pretrain=None, is_test=False):
        """Set up the data reader, BAISNet graph, summaries, session and saver.

        Args:
            batch_size: batch size handed to the ``Data`` reader.
            input_size: indexable pair; ``[0]`` and ``[1]`` size the placeholders.
            log_dir: directory for checkpoints and summaries (passed through
                ``Tools.new_dir``).
            data_root_path, train_list, data_path, annotation_path, class_path:
                forwarded unchanged to ``Data``.
            model_name: checkpoint file name joined onto ``log_dir``.
            pretrain: stored on ``self.pretrain`` for later use by ``train``.
            is_test: forwarded to ``Data``; also shortens ``print_step`` and
                ``cal_step``.
        """
        # Parameters related to saving the model.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)
        self.pretrain = pretrain

        # Parameters related to the data.
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21

        # Parameters related to training. ``learning_rate`` starts as a float
        # and is replaced by the scheduled tensor further down.
        self.learning_rate = 5e-4
        self.num_steps = 500001
        if is_test:
            self.print_step, self.cal_step = 10, 100
        else:
            self.print_step, self.cal_step = 100, 1000

        # Data reader.
        self.data_reader = Data(
            data_root_path=data_root_path, data_list=train_list,
            data_path=data_path, annotation_path=annotation_path, class_path=class_path,
            batch_size=self.batch_size, image_size=self.input_size, is_test=is_test)

        # Input placeholders: image, mask, segmentation label and class label.
        height, width = self.input_size[0], self.input_size[1]
        self.image_placeholder = tf.placeholder(tf.float32, shape=(None, height, width, 3))
        self.mask_placeholder = tf.placeholder(tf.float32, shape=(None, height, width, 1))
        self.label_seg_placeholder = tf.placeholder(tf.int32, shape=(None, height, width, 1))
        self.label_cls_placeholder = tf.placeholder(tf.int32, shape=(None,))

        # Network.
        self.net = BAISNet(self.image_placeholder, self.mask_placeholder, True, num_classes=self.num_classes)
        self.segments, self.attentions, self.classes = self.net.build()

        # Losses.
        losses = self.cal_loss(self.segments, self.attentions, self.classes,
                               self.label_seg_placeholder, self.label_cls_placeholder)
        (self.loss, self.loss_segment_all, self.loss_class_all,
         self.loss_segments, self.loss_classes) = losses

        # Accuracy of the current batch, measured on the first classification output.
        first_class_logits = self.classes[0]
        self.pred_classes = tf.cast(tf.argmax(first_class_logits, axis=-1), tf.int32)
        self.accuracy_classes = tcm.accuracy(self.pred_classes, self.label_cls_placeholder)

        with tf.name_scope("train"):
            # Learning-rate schedule: base_lr * (1 - step / num_steps) ** 0.9.
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            base_lr = tf.constant(self.learning_rate)
            decay = tf.pow((1 - self.step_ph / self.num_steps), 0.9)
            self.learning_rate = tf.scalar_mul(base_lr, decay)
            self.train_op = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss)

            # Separate op that only updates the attention-related variables.
            name_keys = ('attention', "class_attention")
            attention_trainable = [var for var in tf.trainable_variables()
                                   if any(key in var.name for key in name_keys)]
            print(len(attention_trainable))
            attention_optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
            self.train_attention_op = attention_optimizer.minimize(self.loss, var_list=attention_trainable)

        # Summary 1: scalar curves.
        with tf.name_scope("loss"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.scalar("loss_segment", self.loss_segment_all)
            tf.summary.scalar("loss_class", self.loss_class_all)
            for index, one_loss in enumerate(self.loss_segments):
                tf.summary.scalar("loss_segment_{}".format(index), one_loss)
            for index, one_loss in enumerate(self.loss_classes):
                tf.summary.scalar("loss_class_{}".format(index), one_loss)

        with tf.name_scope("accuracy"):
            tf.summary.scalar("accuracy_classes", self.accuracy_classes)

        with tf.name_scope("label"):
            tf.summary.image("1-image", self.image_placeholder)
            tf.summary.image("2-segment", tf.cast(self.label_seg_placeholder * 255, dtype=tf.uint8))
            tf.summary.image("3-mask", tf.cast(self.mask_placeholder * 255, dtype=tf.uint8))

        with tf.name_scope("result"):
            # For every segment / attention output only the second of its two
            # channels is logged, after a sigmoid.
            for index, segment in enumerate(self.segments):
                halves = tf.split(segment, num_or_size_splits=2, axis=3)
                tf.summary.image("predict-{}-1".format(index), tf.nn.sigmoid(halves[1]))
            for index, attention in enumerate(self.attentions):
                halves = tf.split(attention, num_or_size_splits=2, axis=3)
                tf.summary.image("attention-{}-1".format(index), tf.nn.sigmoid(halves[1]))

        self.summary_op = tf.summary.merge_all()

        # Session and saver.
        session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
        self.sess = tf.Session(config=session_config)
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

        # Summary 2: writer that also records the graph.
        self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)

    @staticmethod
    def cal_loss(segments, attentions, classes, label_segment, label_classes):
        """Build the total loss from the attention and classification outputs.

        Args:
            segments: accepted for interface compatibility; not used here.
            attentions: sequence of attention logit tensors; each is reshaped
                to ``[-1, 2]`` and scored against the resized segment label.
            classes: iterable of classification logit tensors.
            label_segment: segmentation label tensor, resized with nearest
                neighbour to each attention output's spatial size.
            label_classes: sparse class labels for the classification loss.

        Returns:
            ``(loss, loss_attention_all, loss_class_all, loss_attentions,
            loss_classes)`` where ``loss`` is the sum of the mean attention
            loss and the mean classification loss.
        """

        def attention_loss(attention):
            # Weighted sigmoid cross-entropy against the one-hot (depth 2)
            # label; positive targets are weighted by 3.
            resized_label = tf.image.resize_nearest_neighbor(
                label_segment, tf.stack(attention.get_shape()[1:3]))
            weighted_xent = tf.nn.weighted_cross_entropy_with_logits(
                targets=tf.one_hot(tf.reshape(resized_label, [-1, ]), depth=2),
                logits=tf.reshape(attention, [-1, 2]), pos_weight=3)
            return tf.reduce_mean(weighted_xent)

        def class_loss(class_logits):
            # Plain sparse softmax cross-entropy for one classification output.
            return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=label_classes, logits=class_logits))

        loss_attentions = [attention_loss(attention) for attention in attentions]
        loss_classes = [class_loss(class_one) for class_one in classes]

        # Average over the individual outputs.
        loss_attention_all = tf.add_n(loss_attentions) / len(loss_attentions)
        loss_class_all = tf.add_n(loss_classes) / len(loss_classes)

        # Total loss.
        loss = loss_attention_all + loss_class_all
        return loss, loss_attention_all, loss_class_all, loss_attentions, loss_classes

    def train(self, save_pred_freq, begin_step=0):
        """Run the training loop from ``begin_step`` up to ``self.num_steps``.

        Every step runs ``self.train_op`` on one batch. Summaries are written
        whenever ``step % self.print_step == 0``, a checkpoint is saved
        whenever ``step % save_pred_freq == 0``, and a progress line with the
        running averages is printed every ``self.cal_step // 10`` steps.

        Running averages are kept per window; a new window starts at every
        step that is a multiple of ``self.cal_step``. Averages are divided by
        the number of steps actually accumulated, so they stay correct when
        ``begin_step`` is not a multiple of ``self.cal_step``.

        Args:
            save_pred_freq: checkpoint period in steps (must be non-zero).
            begin_step: first step index; the step is also fed to the
                learning-rate schedule through ``self.step_ph``.
        """
        # Load an existing / pretrained model via Tools.restore_if_y.
        Tools.restore_if_y(self.sess, self.log_dir, pretrain=self.pretrain)

        # Use self.train_attention_op here instead to update only the
        # attention-related variables.
        train_op = self.train_op

        fetches = [self.accuracy_classes, train_op, self.learning_rate,
                   self.loss_segment_all, self.loss_class_all, self.loss,
                   self.pred_classes]
        fetch_count = len(fetches)

        total_loss = 0.0
        total_acc = 0.0
        window_steps = 0  # number of steps accumulated in the current window
        pre_avg_loss = 0.0
        pre_avg_acc = 0.0
        for step in range(begin_step, self.num_steps):
            start_time = time.time()

            batch_data, batch_mask, batch_attention, batch_class = self.data_reader.next_batch_train()

            feed_dict = {self.step_ph: step, self.image_placeholder: batch_data,
                         self.mask_placeholder: batch_mask,
                         self.label_seg_placeholder: batch_attention,
                         self.label_cls_placeholder: batch_class}

            # summary 3: the merged summary is evaluated in the same run call
            # as the training step, but only on reporting steps.
            write_summary = step % self.print_step == 0
            run_fetches = fetches + [self.summary_op] if write_summary else fetches
            results = self.sess.run(run_fetches, feed_dict=feed_dict)

            (accuracy_classes_r, _, learning_rate_r,
             loss_segment_r, loss_classes_r, loss_r, pred_classes_r) = results[:fetch_count]
            if write_summary:
                self.summary_writer.add_summary(results[fetch_count], global_step=step)

            if step % save_pred_freq == 0:
                self.saver.save(self.sess, self.checkpoint_path, global_step=step)
                Tools.print_info('The checkpoint has been created.')

            duration = time.time() - start_time

            if step % self.cal_step == 0:
                # Close the previous window (if it saw any step) and start a new one.
                if window_steps:
                    pre_avg_loss = total_loss / window_steps
                    pre_avg_acc = total_acc / window_steps
                total_loss = 0.0
                total_acc = 0.0
                window_steps = 0
            total_loss += loss_r
            total_acc += accuracy_classes_r
            window_steps += 1

            if step % (self.cal_step // 10) == 0:
                Tools.print_info(
                    'step {:d} pre_avg_loss={:.3f} avg_loss={:.3f} pre_avg_acc={:.3f} avg_acc={:.3f} '
                    'loss={:.3f} seg={:.3f} class={:.3f} acc_class={:.3f} '
                    'lr={:.6f} ({:.3f} s/step) {} {}'.format(
                        step, pre_avg_loss, total_loss / window_steps, pre_avg_acc, total_acc / window_steps,
                        loss_r, loss_segment_r, loss_classes_r, accuracy_classes_r,
                        learning_rate_r, duration, list(batch_class), list(pred_classes_r)))
            if write_summary:
                Tools.print_info("")

    pass
    def __init__(self, batch_size, last_pool_size, input_size, log_dir,
                 data_root_path, train_list, data_path, annotation_path,
                 class_path, model_name="model.ckpt", is_test=False):
        """Set up the data reader, graph, summaries, session and saver.

        Args:
            batch_size: batch size handed to the ``Data`` reader.
            last_pool_size: stored on ``self.last_pool_size``.
            input_size: stored on ``self.input_size`` and passed to ``Data``
                as ``image_size``.
            log_dir: directory for checkpoints and summaries (passed through
                ``Tools.new_dir``).
            data_root_path, train_list, data_path, annotation_path, class_path:
                forwarded unchanged to ``Data``.
            model_name: checkpoint file name joined onto ``log_dir``.
            is_test: forwarded to ``Data``.
        """
        # Parameters related to saving the model.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Parameters related to the data.
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21
        self.num_segment = 1

        # Parameters related to the model: input_size must be larger than
        # 8 times last_pool_size.
        self.ratio = 8
        self.last_pool_size = last_pool_size
        self.filter_number = 32

        # Parameters related to training. ``learning_rate`` is overwritten
        # below by the value returned from ``build_net``.
        self.learning_rate = 5e-3
        self.num_steps = 500001

        # Data reader.
        self.data_reader = Data(
            data_root_path=data_root_path, data_list=train_list,
            data_path=data_path, annotation_path=annotation_path,
            class_path=class_path, batch_size=self.batch_size,
            image_size=self.input_size, is_test=is_test)

        # Network.
        graph_handles = self.build_net()
        (self.image_placeholder, self.label_segment_placeholder,
         self.label_classes_placeholder, self.raw_output_segment,
         self.raw_output_classes, self.pred_segment, self.pred_classes,
         self.loss_segment, self.loss_classes, self.loss, self.accuracy_0,
         self.accuracy_1, self.accuracy_classes, self.step_ph, self.train_op,
         self.train_classes_op, self.learning_rate) = graph_handles

        # Summary 1: scalar curves, registered in a fixed order.
        scalar_summaries = (
            ("loss", self.loss),
            ("loss_segment", self.loss_segment),
            ("loss_classes", self.loss_classes),
            ("accuracy_0", self.accuracy_0),
            ("accuracy_1", self.accuracy_1),
            ("accuracy_classes", self.accuracy_classes),
        )
        for tag, tensor in scalar_summaries:
            tf.summary.scalar(tag, tensor)

        # The 4-channel input is logged as channels 0-2 ("image") and channel 3 ("mask").
        input_channels = tf.split(self.image_placeholder, num_or_size_splits=4, axis=3)
        tf.summary.image("image", tf.concat(input_channels[0:3], axis=3))
        tf.summary.image("mask", input_channels[3])
        tf.summary.image("label", self.label_segment_placeholder)
        tf.summary.image("raw_output_segment", self.raw_output_segment)
        pred_as_float = tf.cast(self.pred_segment, tf.float32)
        tf.summary.image("pred_segment", pred_as_float)
        self.summary_op = tf.summary.merge_all()

        # Session and saver.
        session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
        self.sess = tf.Session(config=session_config)
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

        # Summary 2: writer that also records the graph.
        self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
    def run(self, image_filename_or_data, mask_color, opacity):
        """Interactively segment the object the user clicks on.

        Shows the image, waits for a mouse click, runs the network for the
        clicked point, prints the predicted class and redraws the image with
        the predicted object tinted by ``mask_color``. This repeats until no
        click is returned or an error occurs.

        Args:
            image_filename_or_data: path of an image file, or the image pixels
                as a list / ``np.ndarray`` with a channel axis (channels 0-2
                are read).
            mask_color: indexable with three entries, one tint value per channel.
            opacity: blend weight of ``mask_color`` inside the predicted
                object; the original pixel gets weight ``1 - opacity``.
        """
        plt.ion()
        plt.axis('off')

        if isinstance(image_filename_or_data, str):
            image_data = np.array(Image.open(image_filename_or_data))
        elif isinstance(image_filename_or_data, (list, np.ndarray)):
            # Lists are converted so that ``.shape`` and ``[:, :, c]`` work below.
            image_data = np.asarray(image_filename_or_data)
        else:
            print("image_filename_or_data is error")
            return

        plt.imshow(image_data)
        plt.title('Click one point of the object that you interested')

        image_height = len(image_data)
        image_width = len(image_data[0])

        try:
            while True:
                clicks = plt.ginput(1, timeout=0)
                if len(clicks) == 0:
                    # No point was returned: leave the loop.
                    break

                # ``np.int`` no longer exists in current NumPy; the builtin is equivalent.
                object_point = np.array(clicks).astype(int)[0]
                # Map the clicked (x, y) pixel to (row, col) in network input coordinates.
                where = [int(self.input_size[0] * object_point[1] / image_height),
                         int(self.input_size[1] * object_point[0] / image_width)]
                print("point=[{},{}] where=[{},{}]".format(object_point[0], object_point[1], where[0], where[1]))

                final_batch_data, _, _ = Data.load_image(image_data, where=where,
                                                         image_size=self.input_size)

                print("begin to run ...")

                # Run the network.
                predict_output_r, pred_classes_r = self.sess.run([self.predict_output, self.pred_classes],
                                                                 feed_dict={self.img_placeholder: final_batch_data})

                print("end run")

                # Predicted class.
                class_name = CategoryNames[pred_classes_r[0]]
                print("the class is {}({})".format(pred_classes_r[0], class_name))

                # Segmentation: binary mask of pixels predicted as label 1,
                # resized back to the original image size.
                segment = np.squeeze(np.asarray(np.where(predict_output_r[0] == 1, 1, 0), dtype=np.uint8))
                segment = np.asarray(Image.fromarray(segment).resize((image_width, image_height)))

                # Start from a copy of the image so that any channel beyond
                # the first three (e.g. alpha) keeps defined values.
                image_mask = image_data.astype(np.float64)
                for channel in range(3):
                    pixels = image_data[:, :, channel]
                    image_mask[:, :, channel] = (1 - segment) * pixels + segment * (
                        opacity * mask_color[channel] + (1 - opacity) * pixels)

                plt.clf()  # clear image
                plt.text(image_width // 2 - 10, -6, class_name, fontsize=15)
                plt.imshow(image_mask.astype(np.uint8))

                print("")

        except Exception as error:
            # Best effort: the session ends on any failure, but the cause is
            # reported instead of being silently dropped.
            print("run stopped: {!r}".format(error))

        print("..................")
        print("...... close .....")
        print("..................")
Exemplo n.º 13
0
    pass


if __name__ == '__main__':

    # Switches: ``is_win`` enables the Windows paths below, ``is_voc`` picks
    # the VOC reader instead of the COCO reader. With ``is_win`` False
    # nothing in this block is executed.
    is_win = False
    is_voc = False

    if is_win:
        if not is_voc:
            data_reader = COCOData(
                data_root_path="C:\\ALISURE\\DataModel\\Data\\COCO",
                annotation_path="annotations_trainval2014\\annotations",
                data_type="val2014",
                batch_size=3,
                image_size=[720, 720])
        else:
            data_reader = Data(
                data_root_path="C:\\ALISURE\\DataModel\\Data\\VOCtrainval_11-May-2012\\VOCdevkit\\VOC2012\\",
                data_list="ImageSets\\Segmentation\\train.txt",
                data_path="JPEGImages\\",
                annotation_path="SegmentationObject\\",
                class_path="SegmentationClass\\",
                batch_size=3,
                image_size=[720, 720],
                is_test=False)

        # Train on the selected reader, saving a checkpoint every 2 steps.
        trainer = Train(log_dir="./model/coco/first", data=data_reader, is_test=True)
        trainer.train(save_pred_freq=2, begin_step=0)
    def run(self,
            result_filename,
            image_filename,
            where=None,
            annotation_filename=None,
            ann_index=0):
        """Run the network on one image and save the outputs as image files.

        The input is loaded through ``Data.load_image``, a ``PSPNet`` is built
        on a 4-channel placeholder, weights are restored from ``self.log_dir``
        via ``Tools.restore_if_y`` and every result is written below
        ``self.save_dir`` with ``result_filename`` as file-name prefix
        (``data.png``, ``pred.png``, ``pred_0.png`` .. ``pred_3.png``,
        ``mask.bmp`` and, when an annotation is given, ``ann.bmp``).

        Args:
            result_filename: Prefix prepended to every output file name.
            image_filename: Forwarded to ``Data.load_image``.
            where: Forwarded to ``Data.load_image``.
            annotation_filename: Forwarded to ``Data.load_image``. When truthy
                the loader is expected to return two extra values (annotation
                data and annotation mask) and the mask is saved as well.
            ann_index: Forwarded to ``Data.load_image``.
        """
        # Number of segment output channels requested from the network; it is
        # also used to split the sigmoid output and to scale the argmax image.
        num_segment = 4

        def output_path(suffix):
            # All outputs share the same directory and file-name prefix.
            return os.path.join(self.save_dir, result_filename + suffix)

        # Load the image data (the call is identical with or without an
        # annotation; only the number of returned values differs).
        loaded = Data.load_image(image_filename,
                                 where=where,
                                 annotation_filename=annotation_filename,
                                 ann_index=ann_index,
                                 image_size=self.input_size)
        if annotation_filename:
            final_batch_data, data_raw, gaussian_mask, _ann_data, ann_mask = loaded
        else:
            final_batch_data, data_raw, gaussian_mask = loaded
            ann_mask = None

        # Network
        # NOTE(review): the net is built with is_training=True although this
        # method only runs a forward pass - confirm this is intended.
        img_placeholder = tf.placeholder(dtype=tf.float32,
                                         shape=(None, self.input_size[0],
                                                self.input_size[1], 4))
        net = PSPNet({'data': img_placeholder},
                     is_training=True,
                     num_classes=21,
                     last_pool_size=self.last_pool_size,
                     filter_number=32,
                     num_segment=num_segment)

        # Outputs / predictions
        raw_output_op = net.layers["conv6_n_4"]
        sigmoid_output_op = tf.sigmoid(raw_output_op)
        predict_output_op = tf.argmax(sigmoid_output_op, axis=-1)

        raw_output_classes = net.layers['class_attention_fc']
        pred_classes = tf.cast(tf.argmax(raw_output_classes, axis=-1),
                               tf.int32)

        # Start the session, load the model and run the forward pass. The
        # context manager closes the session afterwards so the resources it
        # owns are released instead of leaking on every call.
        session_config = tf.ConfigProto(
            gpu_options=tf.GPUOptions(allow_growth=True))
        with tf.Session(config=session_config) as sess:
            sess.run(tf.global_variables_initializer())
            Tools.restore_if_y(sess, self.log_dir)

            (_raw_output, sigmoid_output, predict_output_r,
             raw_output_classes_r, pred_classes_r) = sess.run(
                 [
                     raw_output_op, sigmoid_output_op, predict_output_op,
                     raw_output_classes, pred_classes
                 ],
                 feed_dict={img_placeholder: final_batch_data})

        # Report and save
        print("{} {} {}".format(pred_classes_r[0],
                                CategoryNames[pred_classes_r[0]],
                                raw_output_classes_r))
        print("result in {}".format(output_path("")))
        Image.fromarray(np.asarray(np.squeeze(data_raw), dtype=np.uint8)).save(
            output_path("data.png"))

        # One grey-scale image per segment channel of the first batch item.
        output_result = np.squeeze(
            np.split(np.asarray(sigmoid_output[0] * 255, dtype=np.uint8),
                     axis=-1,
                     indices_or_sections=num_segment))

        # Arg-max channel index, scaled by 255 // num_segment for display.
        Image.fromarray(
            np.squeeze(
                np.asarray(predict_output_r[0] * 255 // num_segment,
                           dtype=np.uint8))).save(output_path("pred.png"))
        for channel_index in range(num_segment):
            Image.fromarray(output_result[channel_index]).save(
                output_path("pred_{}.png".format(channel_index)))

        Image.fromarray(
            np.asarray(np.squeeze(gaussian_mask * 255), dtype=np.uint8)).save(
                output_path("mask.bmp"))

        if annotation_filename:
            Image.fromarray(
                np.asarray(np.squeeze(ann_mask * 255), dtype=np.uint8)).save(
                    output_path("ann.bmp"))
class Train(object):
    """Builds a PSPNet graph and runs a few forward passes over training batches.

    The constructor creates the data reader, the network, a session and a
    saver; ``train`` restores weights, feeds batches through the network,
    periodically writes checkpoints and prints label vs. predicted classes.
    No optimizer or loss is created by this class.
    """

    def __init__(self,
                 batch_size,
                 last_pool_size,
                 input_size,
                 log_dir,
                 data_root_path,
                 train_list,
                 data_path,
                 annotation_path,
                 class_path,
                 model_name="model.ckpt",
                 is_test=False):
        """Create the data reader, the network graph, the session and the saver.

        Args:
            batch_size: Batch size handed to the ``Data`` reader.
            last_pool_size: Forwarded to ``PSPNet``.
            input_size: Indexable whose first two items give the two spatial
                dimensions of the input placeholder.
            log_dir: Directory passed through ``Tools.new_dir``; checkpoints
                are written below it.
            data_root_path, train_list, data_path, annotation_path, class_path:
                Forwarded to the ``Data`` reader.
            model_name: Checkpoint file name inside ``log_dir``.
            is_test: Forwarded to the ``Data`` reader.
        """
        # Model-saving parameters
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Data parameters
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21
        self.num_segment = 1

        # Model parameters: input_size must be larger than 8 times last_pool_size
        self.ratio = 8
        self.last_pool_size = last_pool_size
        self.filter_number = 32

        # Data reader
        self.data_reader = Data(
            data_root_path=data_root_path,
            data_list=train_list,
            data_path=data_path,
            annotation_path=annotation_path,
            class_path=class_path,
            batch_size=self.batch_size,
            image_size=self.input_size,
            is_test=is_test,
        )

        # Network
        (self.image_placeholder,
         self.raw_output_segment,
         self.raw_output_classes,
         self.pred_segment,
         self.pred_classes) = self.build_net()

        # Session and saver
        session_config = tf.ConfigProto(
            gpu_options=tf.GPUOptions(allow_growth=True))
        self.sess = tf.Session(config=session_config)
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                    max_to_keep=10)

    def build_net(self):
        """Build the PSPNet graph.

        Returns:
            A 5-tuple ``(image_placeholder, raw_output_segment,
            raw_output_classes, pred_segment, pred_classes)``.
        """
        # Input: a 4-channel image batch of unspecified batch size
        height, width = self.input_size[0], self.input_size[1]
        inputs = tf.placeholder(dtype=tf.float32,
                                shape=(None, height, width, 4))

        # Network
        network = PSPNet({'data': inputs},
                         is_training=True,
                         num_classes=self.num_classes,
                         num_segment=self.num_segment,
                         last_pool_size=self.last_pool_size,
                         filter_number=self.filter_number)
        segment_raw = network.layers['conv6_n']
        classes_raw = network.layers['class_attention_fc']

        # Predictions: segment output thresholded at 0.5, class by arg-max
        segment_pred = tf.cast(tf.greater(segment_raw, 0.5), tf.int32)
        classes_pred = tf.cast(tf.argmax(classes_raw, axis=-1), tf.int32)

        return inputs, segment_raw, classes_raw, segment_pred, classes_pred

    def train(self, save_pred_freq, begin_step=0):
        """Restore the model and run forward passes for steps ``begin_step``..4.

        Args:
            save_pred_freq: A checkpoint is saved whenever
                ``step % save_pred_freq == 0``.
            begin_step: First step index of the loop.
        """
        # Load the model
        Tools.restore_if_y(self.sess, self.log_dir)

        fetches = [
            self.raw_output_segment,
            self.pred_segment,
            self.raw_output_classes,
            self.pred_classes,
        ]

        for step in range(begin_step, 5):
            # Only the network input and the class labels of the batch are used.
            network_input, _, class_labels, _, _ = \
                self.data_reader.next_batch_train()

            _, _, _, predicted_classes = self.sess.run(
                fetches, feed_dict={self.image_placeholder: network_input})

            if step % save_pred_freq == 0:
                self.saver.save(self.sess,
                                self.checkpoint_path,
                                global_step=step)
                Tools.print_info('The checkpoint has been created.')

            message = 'step {:d} {} {}'.format(step,
                                               list(class_labels),
                                               list(predicted_classes))
            Tools.print_info(message)
class Train(object):
    """Builds a BAISNet segmentation graph and runs its training loop.

    The constructor wires up the data reader, placeholders, network, loss,
    optimizers, TensorBoard summaries, session and saver. ``train`` restores
    weights and then iterates gradient-descent steps, writing summaries and
    checkpoints and printing running loss statistics.
    """

    def __init__(self,
                 batch_size,
                 input_size,
                 log_dir,
                 data_root_path,
                 train_list,
                 data_path,
                 annotation_path,
                 class_path,
                 model_name="model.ckpt",
                 pretrain=None,
                 is_test=False):
        """Create the complete training graph and its session.

        Args:
            batch_size: Batch size for the data reader; also the number of
                pieces the predictions are split into for image summaries.
            input_size: Indexable whose first two items give the spatial
                dimensions of the placeholders.
            log_dir: Directory passed through ``Tools.new_dir``; checkpoints
                and summaries are written there.
            data_root_path, train_list, data_path, annotation_path, class_path:
                Forwarded to the ``Data`` reader.
            model_name: Checkpoint file name inside ``log_dir``.
            pretrain: Forwarded to ``Tools.restore_if_y`` when training starts.
            is_test: Forwarded to the reader; also selects the shorter
                print/statistics intervals.
        """
        # Parameters related to saving the model
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)
        self.pretrain = pretrain

        # Parameters related to the data
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21

        # Parameters related to training
        self.learning_rate = 5e-3
        self.num_steps = 100001
        self.print_step = 10 if is_test else 100
        self.cal_step = 100 if is_test else 1000

        # Data reader
        self.data_reader = Data(data_root_path=data_root_path,
                                data_list=train_list,
                                data_path=data_path,
                                annotation_path=annotation_path,
                                class_path=class_path,
                                batch_size=self.batch_size,
                                image_size=self.input_size,
                                is_test=is_test)

        # Input placeholders
        self.image_placeholder = tf.placeholder(tf.float32,
                                                shape=(None,
                                                       self.input_size[0],
                                                       self.input_size[1], 3))
        self.label_seg_placeholder = tf.placeholder(
            tf.int32, shape=(None, self.input_size[0], self.input_size[1], 1))

        # Network
        self.net = BAISNet(self.image_placeholder,
                           True,
                           num_classes=self.num_classes)
        self.segments, self.features = self.net.build()

        # Loss
        self.loss, self.loss_segment_all, self.loss_segments = self.cal_loss(
            self.segments, self.label_seg_placeholder)

        with tf.name_scope("train"):
            # Learning-rate schedule: the base rate scaled by
            # (1 - step / num_steps) ** 0.8, with the step fed at run time.
            # From here on self.learning_rate is a tensor, not a float.
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            self.learning_rate = tf.scalar_mul(
                tf.constant(self.learning_rate),
                tf.pow((1 - self.step_ph / self.num_steps), 0.8))
            self.train_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss)

            # Separate op that only trains the variables whose name contains
            # 'segment_side'.
            segment_side_trainable = [
                v for v in tf.trainable_variables() if 'segment_side' in v.name
            ]
            print(len(segment_side_trainable))
            self.train_segment_side_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss,
                                             var_list=segment_side_trainable)

        # summary 1
        with tf.name_scope("loss"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.scalar("loss_segment", self.loss_segment_all)
            for loss_segment_index, loss_segment in enumerate(
                    self.loss_segments):
                tf.summary.scalar("loss_segment_{}".format(loss_segment_index),
                                  loss_segment)

        with tf.name_scope("label"):
            tf.summary.image("1-image", self.image_placeholder)
            tf.summary.image(
                "2-segment",
                tf.cast(self.label_seg_placeholder * 255 // self.num_classes,
                        dtype=tf.uint8))

        with tf.name_scope("result"):
            # Predicted segment maps, at most three images per output.
            shown = min(3, self.batch_size)
            for segment_index, segment in enumerate(self.segments):
                # Keep the class ids in int32 while scaling: multiplying a
                # uint8 tensor by 255 wraps around for ids >= 2. The cast to
                # uint8 happens only after the division, exactly like the
                # label summary above.
                class_ids = tf.cast(tf.argmax(segment, axis=-1),
                                    dtype=tf.int32)
                per_image = tf.split(class_ids,
                                     num_or_size_splits=self.batch_size,
                                     axis=0)
                for i in range(shown):
                    scaled = per_image[i] * 255 // self.num_classes
                    tf.summary.image(
                        "predict-{}-{}".format(segment_index, i),
                        tf.cast(tf.expand_dims(scaled, axis=-1),
                                dtype=tf.uint8))

        self.summary_op = tf.summary.merge_all()

        # Session and saver
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                    max_to_keep=10)

        # summary 2
        self.summary_writer = tf.summary.FileWriter(self.log_dir,
                                                    self.sess.graph)

    def cal_loss(self, segments, label_segment):
        """Compute the segmentation loss over all network outputs.

        Each output gets its own loss: the label map is resized with
        nearest-neighbour interpolation to the output's spatial size, one-hot
        encoded and compared with the output through
        ``tf.nn.weighted_cross_entropy_with_logits`` (``pos_weight=1``).

        Args:
            segments: Iterable of network output tensors.
            label_segment: Integer label tensor.

        Returns:
            ``(loss, loss_segment_all, loss_segments)`` where
            ``loss_segment_all`` is the mean of the per-output losses,
            ``loss`` is the same tensor and ``loss_segments`` is the list of
            per-output losses.
        """
        loss_segments = []
        for segment in segments:
            now_label_segment = tf.image.resize_nearest_neighbor(
                label_segment, tf.stack(segment.get_shape()[1:3]))

            now_loss_segment = tf.reduce_mean(
                tf.nn.weighted_cross_entropy_with_logits(
                    targets=tf.one_hot(tf.reshape(now_label_segment, [-1]),
                                       depth=self.num_classes),
                    logits=tf.reshape(segment, [-1, self.num_classes]),
                    pos_weight=1))

            loss_segments.append(now_loss_segment)

        loss_segment_all = tf.add_n(loss_segments) / len(loss_segments)
        # Total loss
        loss = loss_segment_all
        return loss, loss_segment_all, loss_segments

    def train(self, save_pred_freq, begin_step=0):
        """Restore the model and run the training loop.

        Args:
            save_pred_freq: A checkpoint is written whenever
                ``step % save_pred_freq == 0``.
            begin_step: First step of the loop, which runs up to
                ``self.num_steps``.
        """
        # Load the model
        Tools.restore_if_y(self.sess, self.log_dir, pretrain=self.pretrain)

        # Running loss statistics for the current window of cal_step steps.
        # The number of accumulated losses is tracked explicitly so that the
        # averages stay correct when begin_step is not a multiple of cal_step.
        total_loss = 0.0
        total_loss_step = 0
        pre_avg_loss = 0.0
        for step in range(begin_step, self.num_steps):
            start_time = time.time()

            batch_data, batch_segment = self.data_reader.next_batch_train()

            # self.train_segment_side_op can be used here instead to train
            # only the 'segment_side' variables.
            train_op = self.train_op

            fetches = [
                train_op, self.learning_rate, self.loss_segment_all, self.loss
            ]
            # summary 3: additionally evaluate the merged summaries on
            # print steps.
            write_summary = step % self.print_step == 0
            if write_summary:
                fetches.append(self.summary_op)

            results = self.sess.run(
                fetches,
                feed_dict={
                    self.step_ph: step,
                    self.image_placeholder: batch_data,
                    self.label_seg_placeholder: batch_segment
                })
            learning_rate_r, loss_segment_r, loss_r = results[1:4]
            if write_summary:
                self.summary_writer.add_summary(results[4], global_step=step)

            if step % save_pred_freq == 0:
                self.saver.save(self.sess,
                                self.checkpoint_path,
                                global_step=step)
                Tools.print_info('The checkpoint has been created.')

            duration = time.time() - start_time

            if step % self.cal_step == 0:
                # A window is complete: remember its average, start a new one.
                if total_loss_step:
                    pre_avg_loss = total_loss / total_loss_step
                total_loss = loss_r
                total_loss_step = 1
            else:
                total_loss += loss_r
                total_loss_step += 1

            if step % (self.cal_step // 10) == 0:
                Tools.print_info(
                    'step {:d} pre_avg_loss={:.3f} avg_loss={:.3f} loss={:.3f} seg={:.3f} lr={:.6f} '
                    '({:.3f} s/step)'.format(step, pre_avg_loss,
                                             total_loss / total_loss_step,
                                             loss_r, loss_segment_r,
                                             learning_rate_r, duration))
            if step % self.print_step == 0:
                Tools.print_info("")
    def __init__(self,
                 batch_size,
                 input_size,
                 log_dir,
                 data_root_path,
                 train_list,
                 data_path,
                 annotation_path,
                 class_path,
                 model_name="model.ckpt",
                 pretrain=None,
                 is_test=False):
        """Create the data reader, the BAISNet graph, summaries and session.

        Args:
            batch_size: Batch size handed to the ``Data`` reader.
            input_size: Indexable whose first two items give the spatial
                dimensions of the image, mask and label placeholders.
            log_dir: Directory passed through ``Tools.new_dir``; checkpoints
                and summaries are written there.
            data_root_path, train_list, data_path, annotation_path, class_path:
                Forwarded to the ``Data`` reader.
            model_name: Checkpoint file name inside ``log_dir``.
            pretrain: Stored on the instance as ``self.pretrain``.
            is_test: Forwarded to the reader; also selects the shorter
                print/statistics intervals.
        """
        # Parameters related to saving the model
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)
        self.pretrain = pretrain

        # Parameters related to the data
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21

        # Parameters related to training
        self.learning_rate = 5e-4
        self.num_steps = 500001
        if is_test:
            self.print_step = 10
            self.cal_step = 100
        else:
            self.print_step = 100
            self.cal_step = 1000

        # Data reader
        self.data_reader = Data(
            data_root_path=data_root_path,
            data_list=train_list,
            data_path=data_path,
            annotation_path=annotation_path,
            class_path=class_path,
            batch_size=self.batch_size,
            image_size=self.input_size,
            is_test=is_test,
        )

        # Input placeholders: image, mask, segmentation label, class label
        height, width = self.input_size[0], self.input_size[1]
        self.image_placeholder = tf.placeholder(
            tf.float32, shape=(None, height, width, 3))
        self.mask_placeholder = tf.placeholder(
            tf.float32, shape=(None, height, width, 1))
        self.label_seg_placeholder = tf.placeholder(
            tf.int32, shape=(None, height, width, 1))
        self.label_cls_placeholder = tf.placeholder(tf.int32, shape=(None,))

        # Network
        self.net = BAISNet(self.image_placeholder,
                           self.mask_placeholder,
                           True,
                           num_classes=self.num_classes)
        self.segments, self.attentions, self.classes = self.net.build()

        # Loss
        (self.loss,
         self.loss_segment_all,
         self.loss_class_all,
         self.loss_segments,
         self.loss_classes) = self.cal_loss(self.segments,
                                            self.attentions,
                                            self.classes,
                                            self.label_seg_placeholder,
                                            self.label_cls_placeholder)

        # Accuracy of the current batch, based on the first classes output
        first_classes = self.classes[0]
        self.pred_classes = tf.cast(tf.argmax(first_classes, axis=-1),
                                    tf.int32)
        self.accuracy_classes = tcm.accuracy(self.pred_classes,
                                             self.label_cls_placeholder)

        with tf.name_scope("train"):
            # Learning-rate schedule: the base rate scaled by
            # (1 - step / num_steps) ** 0.9, with the step fed at run time.
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            decay = tf.pow((1 - self.step_ph / self.num_steps), 0.9)
            self.learning_rate = tf.scalar_mul(
                tf.constant(self.learning_rate), decay)
            self.train_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss)

            # Separate op that only trains the attention variables
            attention_tags = ('attention', "class_attention")
            attention_trainable = [
                variable for variable in tf.trainable_variables()
                if any(tag in variable.name for tag in attention_tags)
            ]
            print(len(attention_trainable))
            attention_optimizer = tf.train.GradientDescentOptimizer(
                self.learning_rate)
            self.train_attention_op = attention_optimizer.minimize(
                self.loss, var_list=attention_trainable)

        # summary 1
        with tf.name_scope("loss"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.scalar("loss_segment", self.loss_segment_all)
            tf.summary.scalar("loss_class", self.loss_class_all)
            for index, segment_loss in enumerate(self.loss_segments):
                tf.summary.scalar("loss_segment_{}".format(index),
                                  segment_loss)
            for index, class_loss in enumerate(self.loss_classes):
                tf.summary.scalar("loss_class_{}".format(index), class_loss)

        with tf.name_scope("accuracy"):
            tf.summary.scalar("accuracy_classes", self.accuracy_classes)

        with tf.name_scope("label"):
            tf.summary.image("1-image", self.image_placeholder)
            tf.summary.image(
                "2-segment",
                tf.cast(self.label_seg_placeholder * 255, dtype=tf.uint8))
            tf.summary.image(
                "3-mask",
                tf.cast(self.mask_placeholder * 255, dtype=tf.uint8))

        with tf.name_scope("result"):
            # Segments: only the second half of the channel split is shown
            for index, segment_output in enumerate(self.segments):
                segment_halves = tf.split(segment_output,
                                          num_or_size_splits=2,
                                          axis=3)
                tf.summary.image("predict-{}-1".format(index),
                                 tf.nn.sigmoid(segment_halves[1]))
            # Attentions: same treatment as the segments
            for index, attention_output in enumerate(self.attentions):
                attention_halves = tf.split(attention_output,
                                            num_or_size_splits=2,
                                            axis=3)
                tf.summary.image("attention-{}-1".format(index),
                                 tf.nn.sigmoid(attention_halves[1]))

        self.summary_op = tf.summary.merge_all()

        # Session and saver
        session_config = tf.ConfigProto(
            gpu_options=tf.GPUOptions(allow_growth=True))
        self.sess = tf.Session(config=session_config)
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                    max_to_keep=10)

        # summary 2
        self.summary_writer = tf.summary.FileWriter(self.log_dir,
                                                    self.sess.graph)