Example #1
0
    def __init__(self, input_size, log_dir, model_name="model.ckpt"):
        """Build the inference graph for BAISNet and open a TensorFlow session.

        Args:
            input_size: indexable (height, width); images are fed as float32
                tensors of shape (batch, height, width, 3).
            log_dir: checkpoint directory, created via ``Tools.new_dir``.
            model_name: checkpoint file name inside ``log_dir``.
        """
        # Checkpoint bookkeeping
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, model_name)

        # Data parameters
        self.input_size = input_size
        self.num_classes = 21

        # Input: a batch of 3-channel images
        height, width = input_size[0], input_size[1]
        self.image_placeholder = tf.placeholder(tf.float32,
                                                shape=(None, height, width, 3))

        # Network (second positional argument False -- presumably is_training; TODO confirm)
        self.net = BAISNet(self.image_placeholder, False,
                           num_classes=self.num_classes)
        self.segments = self.net.build()
        first_logits = self.segments[0]
        # Per-pixel class index of the first network output
        self.pred_segment = tf.cast(tf.argmax(first_logits, axis=-1),
                                    dtype=tf.uint8)

        # Session with on-demand GPU memory growth, then a saver over all variables
        gpu_options = tf.GPUOptions(allow_growth=True)
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                    max_to_keep=10)
Example #2
0
class Inference(object):
    """Single-image segmentation inference with a BAISNet model (TensorFlow 1.x)."""

    def __init__(self, input_size, log_dir, model_name="model.ckpt"):
        """Build the inference graph and open a session.

        Args:
            input_size: indexable (height, width); images are fed as float32
                tensors of shape (batch, height, width, 3).
            log_dir: checkpoint directory, created via ``Tools.new_dir``.
            model_name: checkpoint file name inside ``log_dir``.
        """
        # Checkpoint-related parameters
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Data-related parameters
        self.input_size = input_size
        self.num_classes = 21

        # Input placeholder
        self.image_placeholder = tf.placeholder(tf.float32,
                                                shape=(None,
                                                       self.input_size[0],
                                                       self.input_size[1], 3))

        # Network (second positional argument False -- presumably is_training; TODO confirm)
        self.net = BAISNet(self.image_placeholder,
                           False,
                           num_classes=self.num_classes)
        self.segments = self.net.build()
        # Per-pixel class index of the first network output
        self.pred_segment = tf.cast(tf.argmax(self.segments[0], axis=-1),
                                    dtype=tf.uint8)

        # Session (GPU memory grows on demand) and saver
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                    max_to_keep=10)

    def load_model(self):
        """Restore weights from ``self.log_dir`` via ``Tools.restore_if_y``."""
        Tools.restore_if_y(self.sess, self.log_dir)

    @staticmethod
    def _segment_to_image_array(pred_segment, num_classes):
        """Convert a predicted label map into a displayable uint8 grayscale array.

        Class indices are spread over 0..255 (the highest class maps to 255).
        The scaling is done in int32: multiplying the uint8 prediction by 255
        directly wraps modulo 256 for every class index above 1. With two
        classes, class 1 still maps to 255.

        Args:
            pred_segment: integer label map; singleton dimensions are squeezed.
            num_classes: number of classes the labels range over.

        Returns:
            A uint8 numpy array with the squeezed shape of ``pred_segment``.
        """
        labels = np.squeeze(np.asarray(pred_segment)).astype(np.int32)
        scale = max(num_classes - 1, 1)
        return np.clip(labels * 255 // scale, 0, 255).astype(np.uint8)

    def inference(self, image_path, save_path=None):
        """Segment one image, then show it or save it as ``<save_path>/<stem>.bmp``.

        Args:
            image_path: image file, loaded with ``Data.load_data``.
            save_path: output directory (created if needed); when None the
                result is shown with ``Image.show`` instead of saved.
        """
        im_data = Data.load_data(image_path=image_path,
                                 input_size=self.input_size)
        im_data = np.expand_dims(im_data, axis=0)  # add batch dimension
        pred_segment_r = self.sess.run(
            self.pred_segment, feed_dict={self.image_placeholder: im_data})
        s_image = Image.fromarray(
            self._segment_to_image_array(pred_segment_r, self.num_classes))
        if save_path is None:
            s_image.show()
        else:
            Tools.new_dir(save_path)
            image_name = os.path.splitext(os.path.basename(image_path))[0]
            s_image.convert("L").save("{}/{}.bmp".format(save_path, image_name))
Example #3
0
    def __init__(self, input_size, summary_dir, log_dir, model_name="model.ckpt"):
        """Build the inference graph, its TensorBoard summaries, and a session.

        Args:
            input_size: indexable (height, width); images are fed as float32
                tensors of shape (batch, height, width, 3).
            summary_dir: directory passed to ``tf.summary.FileWriter``.
            log_dir: checkpoint directory, created via ``Tools.new_dir``.
            model_name: checkpoint file name inside ``log_dir``.
        """
        # Checkpoint-related parameters
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Data-related parameters
        self.input_size = input_size
        self.num_classes = 21

        # Input placeholder
        self.image_placeholder = tf.placeholder(
            tf.float32, shape=(None, self.input_size[0], self.input_size[1], 3))

        # Network (second positional argument False -- presumably is_training; TODO confirm)
        self.net = BAISNet(self.image_placeholder, False, num_classes=self.num_classes)
        self.segments, self.features = self.net.build()
        self.pred_segment = tf.cast(tf.argmax(self.segments[0], axis=-1), dtype=tf.uint8)

        # Map class indices onto 0..255 for display. Scale in int32: multiplying
        # the uint8 argmax by 255 wraps modulo 256 for every class index > 1.
        # With two classes, class 1 still maps to 255.
        label_scale = max(self.num_classes - 1, 1)
        with tf.name_scope("image"):
            tf.summary.image("input", self.image_placeholder)

            # One predicted label map per network output
            for segment_index, segment in enumerate(self.segments):
                labels = tf.cast(tf.argmax(segment, axis=-1), dtype=tf.int32)
                visual = tf.cast(labels * 255 // label_scale, dtype=tf.uint8)
                tf.summary.image("predict-{}".format(segment_index),
                                 tf.expand_dims(visual, axis=-1))

        # One image summary per channel of every feature tensor, grouped by key
        for key in list(self.features.keys()):
            with tf.name_scope(key):
                for feature_index, feature in enumerate(self.features[key]):
                    feature_split = tf.split(feature, num_or_size_splits=int(feature.shape[-1]), axis=-1)
                    for feature_one_index, feature_one in enumerate(feature_split):
                        tf.summary.image("{}-{}".format(feature_index, feature_one_index), feature_one)

        self.summary_op = tf.summary.merge_all()

        # Session (GPU memory grows on demand), saver and summary writer
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
        self.summary_writer = tf.summary.FileWriter(summary_dir, self.sess.graph)
    def __init__(self,
                 batch_size,
                 input_size,
                 log_dir,
                 data_root_path,
                 train_list,
                 data_path,
                 annotation_path,
                 class_path,
                 model_name="model.ckpt",
                 pretrain=None,
                 is_test=False):
        """Create the data reader, network, losses, optimisers, summaries and session.

        Args:
            batch_size: images per training batch; also the number of splits
                used for the per-image prediction summaries.
            input_size: indexable (height, width) of the network input.
            log_dir: directory for checkpoints and TensorBoard events,
                created via ``Tools.new_dir``.
            data_root_path, train_list, data_path, annotation_path,
            class_path: forwarded unchanged to ``Data``.
            model_name: checkpoint file name inside ``log_dir``.
            pretrain: stored as ``self.pretrain``.
            is_test: shortens the logging intervals and is forwarded to ``Data``.
        """
        # Checkpoint-related parameters
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)
        self.pretrain = pretrain

        # Data-related parameters
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21

        # Training hyper-parameters
        self.learning_rate = 5e-3
        self.num_steps = 100001
        self.print_step = 10 if is_test else 100
        self.cal_step = 100 if is_test else 1000

        # Data reader
        self.data_reader = Data(data_root_path=data_root_path,
                                data_list=train_list,
                                data_path=data_path,
                                annotation_path=annotation_path,
                                class_path=class_path,
                                batch_size=self.batch_size,
                                image_size=self.input_size,
                                is_test=is_test)

        # Inputs: 3-channel images and single-channel int32 label maps
        self.image_placeholder = tf.placeholder(tf.float32,
                                                shape=(None,
                                                       self.input_size[0],
                                                       self.input_size[1], 3))
        self.label_seg_placeholder = tf.placeholder(
            tf.int32, shape=(None, self.input_size[0], self.input_size[1], 1))

        # Network (second positional argument True -- presumably is_training; TODO confirm)
        self.net = BAISNet(self.image_placeholder,
                           True,
                           num_classes=self.num_classes)
        self.segments, self.features = self.net.build()

        # Loss
        self.loss, self.loss_segment_all, self.loss_segments = self.cal_loss(
            self.segments, self.label_seg_placeholder)

        with tf.name_scope("train"):
            # Polynomial decay: lr = base_lr * (1 - step / num_steps) ** 0.8.
            # NOTE: self.learning_rate is rebound from the float base rate to
            # this tensor from here on.
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            self.learning_rate = tf.scalar_mul(
                tf.constant(self.learning_rate),
                tf.pow((1 - self.step_ph / self.num_steps), 0.8))
            self.train_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss)

            # Auxiliary op that trains only the 'segment_side' variables
            segment_side_trainable = [
                v for v in tf.trainable_variables() if 'segment_side' in v.name
            ]
            print(len(segment_side_trainable))
            if segment_side_trainable:
                self.train_segment_side_op = tf.train.GradientDescentOptimizer(
                    self.learning_rate).minimize(self.loss,
                                                 var_list=segment_side_trainable)
            else:
                # minimize() raises ValueError on an empty var_list; an unused
                # auxiliary op must not abort graph construction.
                self.train_segment_side_op = None

        # Scalar loss summaries
        with tf.name_scope("loss"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.scalar("loss_segment", self.loss_segment_all)
            for loss_segment_index, loss_segment in enumerate(
                    self.loss_segments):
                tf.summary.scalar("loss_segment_{}".format(loss_segment_index),
                                  loss_segment)

        # Input image and ground-truth labels (scaled into 0..255)
        with tf.name_scope("label"):
            tf.summary.image("1-image", self.image_placeholder)
            tf.summary.image(
                "2-segment",
                tf.cast(self.label_seg_placeholder * 255 // self.num_classes,
                        dtype=tf.uint8))

        # Predicted label maps of up to the first three images of the batch
        with tf.name_scope("result"):
            for segment_index, segment in enumerate(self.segments):
                segment = tf.cast(tf.argmax(segment, axis=-1), dtype=tf.uint8)
                segment = tf.split(segment,
                                   num_or_size_splits=self.batch_size,
                                   axis=0)
                for i in range(min(3, self.batch_size)):
                    tf.summary.image(
                        "predict-{}-{}".format(segment_index, i),
                        tf.expand_dims(segment[i] * 255 // self.num_classes,
                                       axis=-1))

        self.summary_op = tf.summary.merge_all()

        # Session (GPU memory grows on demand) and saver
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                    max_to_keep=10)

        # Summary writer
        self.summary_writer = tf.summary.FileWriter(self.log_dir,
                                                    self.sess.graph)
class Train(object):
    """Build and run the BAISNet segmentation training graph (TensorFlow 1.x)."""

    def __init__(self,
                 batch_size,
                 input_size,
                 log_dir,
                 data_root_path,
                 train_list,
                 data_path,
                 annotation_path,
                 class_path,
                 model_name="model.ckpt",
                 pretrain=None,
                 is_test=False):
        """Create the data reader, network, losses, optimisers, summaries and session.

        Args:
            batch_size: images per training batch; also the number of splits
                used for the per-image prediction summaries.
            input_size: indexable (height, width) of the network input.
            log_dir: directory for checkpoints and TensorBoard events,
                created via ``Tools.new_dir``.
            data_root_path, train_list, data_path, annotation_path,
            class_path: forwarded unchanged to ``Data``.
            model_name: checkpoint file name inside ``log_dir``.
            pretrain: forwarded to ``Tools.restore_if_y`` when training starts.
            is_test: shortens the logging intervals and is forwarded to ``Data``.
        """
        # Checkpoint-related parameters
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)
        self.pretrain = pretrain

        # Data-related parameters
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21

        # Training hyper-parameters
        self.learning_rate = 5e-3
        self.num_steps = 100001
        self.print_step = 10 if is_test else 100
        self.cal_step = 100 if is_test else 1000

        # Data reader
        self.data_reader = Data(data_root_path=data_root_path,
                                data_list=train_list,
                                data_path=data_path,
                                annotation_path=annotation_path,
                                class_path=class_path,
                                batch_size=self.batch_size,
                                image_size=self.input_size,
                                is_test=is_test)

        # Inputs: 3-channel images and single-channel int32 label maps
        self.image_placeholder = tf.placeholder(tf.float32,
                                                shape=(None,
                                                       self.input_size[0],
                                                       self.input_size[1], 3))
        self.label_seg_placeholder = tf.placeholder(
            tf.int32, shape=(None, self.input_size[0], self.input_size[1], 1))

        # Network (second positional argument True -- presumably is_training; TODO confirm)
        self.net = BAISNet(self.image_placeholder,
                           True,
                           num_classes=self.num_classes)
        self.segments, self.features = self.net.build()

        # Loss
        self.loss, self.loss_segment_all, self.loss_segments = self.cal_loss(
            self.segments, self.label_seg_placeholder)

        with tf.name_scope("train"):
            # Polynomial decay: lr = base_lr * (1 - step / num_steps) ** 0.8.
            # NOTE: self.learning_rate is rebound from the float base rate to
            # this tensor; train() fetches the tensor.
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            self.learning_rate = tf.scalar_mul(
                tf.constant(self.learning_rate),
                tf.pow((1 - self.step_ph / self.num_steps), 0.8))
            self.train_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss)

            # Auxiliary op that trains only the 'segment_side' variables
            segment_side_trainable = [
                v for v in tf.trainable_variables() if 'segment_side' in v.name
            ]
            print(len(segment_side_trainable))
            if segment_side_trainable:
                self.train_segment_side_op = tf.train.GradientDescentOptimizer(
                    self.learning_rate).minimize(self.loss,
                                                 var_list=segment_side_trainable)
            else:
                # minimize() raises ValueError on an empty var_list; an unused
                # auxiliary op must not abort graph construction.
                self.train_segment_side_op = None

        # Scalar loss summaries
        with tf.name_scope("loss"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.scalar("loss_segment", self.loss_segment_all)
            for loss_segment_index, loss_segment in enumerate(
                    self.loss_segments):
                tf.summary.scalar("loss_segment_{}".format(loss_segment_index),
                                  loss_segment)

        # Input image and ground-truth labels (scaled into 0..255)
        with tf.name_scope("label"):
            tf.summary.image("1-image", self.image_placeholder)
            tf.summary.image(
                "2-segment",
                tf.cast(self.label_seg_placeholder * 255 // self.num_classes,
                        dtype=tf.uint8))

        # Predicted label maps of up to the first three images of the batch
        with tf.name_scope("result"):
            for segment_index, segment in enumerate(self.segments):
                segment = tf.cast(tf.argmax(segment, axis=-1), dtype=tf.uint8)
                segment = tf.split(segment,
                                   num_or_size_splits=self.batch_size,
                                   axis=0)
                for i in range(min(3, self.batch_size)):
                    tf.summary.image(
                        "predict-{}-{}".format(segment_index, i),
                        tf.expand_dims(segment[i] * 255 // self.num_classes,
                                       axis=-1))

        self.summary_op = tf.summary.merge_all()

        # Session (GPU memory grows on demand) and saver
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                    max_to_keep=10)

        # Summary writer
        self.summary_writer = tf.summary.FileWriter(self.log_dir,
                                                    self.sess.graph)

    def cal_loss(self, segments, label_segment):
        """Average the per-output segmentation losses.

        For each network output, the labels are resized (nearest neighbour) to
        that output's spatial size (dims 1:3). The loss is then the mean of
        ``tf.nn.weighted_cross_entropy_with_logits`` (an element-wise sigmoid
        cross-entropy, here with ``pos_weight=1``) against one-hot targets.

        Returns:
            ``(loss, loss_segment_all, loss_segments)`` where ``loss`` equals
            ``loss_segment_all``, the mean of the per-output losses in
            ``loss_segments``.
        """
        loss_segments = []
        for segment in segments:
            now_label_segment = tf.image.resize_nearest_neighbor(
                label_segment, tf.stack(segment.get_shape()[1:3]))

            now_loss_segment = tf.reduce_mean(
                tf.nn.weighted_cross_entropy_with_logits(
                    targets=tf.one_hot(tf.reshape(now_label_segment, [-1]),
                                       depth=self.num_classes),
                    logits=tf.reshape(segment, [-1, self.num_classes]),
                    pos_weight=1))

            loss_segments.append(now_loss_segment)

        loss_segment_all = tf.add_n(loss_segments) / len(loss_segments)
        # Total loss
        loss = loss_segment_all
        return loss, loss_segment_all, loss_segments

    @staticmethod
    def _roll_window(step, cal_step, value, window_sum, window_count, pre_avg):
        """Fold ``value`` into a running window that restarts every ``cal_step`` steps.

        At a window boundary (``step % cal_step == 0``), ``pre_avg`` becomes the
        mean of the window that just closed (0.0 if it held no samples) and a
        new window is started with ``value``. Otherwise ``value`` is added to the
        current window. The explicit counter keeps both averages correct when
        training resumes from a ``begin_step`` that is not a multiple of
        ``cal_step``.

        Returns:
            ``(window_sum, window_count, pre_avg)``; ``window_count >= 1``.
        """
        if step % cal_step == 0:
            pre_avg = window_sum / window_count if window_count else 0.0
            return value, 1, pre_avg
        return window_sum + value, window_count + 1, pre_avg

    def train(self, save_pred_freq, begin_step=0):
        """Run training for steps ``begin_step`` .. ``self.num_steps - 1``.

        Args:
            save_pred_freq: a checkpoint is saved whenever
                ``step % save_pred_freq == 0``.
            begin_step: first step. It is also fed to the learning-rate
                schedule, so resuming continues the decay.
        """
        # Load model
        Tools.restore_if_y(self.sess, self.log_dir, pretrain=self.pretrain)

        window_loss_sum = 0.0
        window_steps = 0
        pre_avg_loss = 0.0
        for step in range(begin_step, self.num_steps):
            start_time = time.time()

            batch_data, batch_segment = self.data_reader.next_batch_train()

            train_op = self.train_op
            # train_op = self.train_segment_side_op  (None if no segment_side variables)

            # The merged summary is fetched and written only on print steps
            write_summary = step % self.print_step == 0
            fetches = [train_op, self.learning_rate, self.loss_segment_all,
                       self.loss]
            if write_summary:
                fetches.append(self.summary_op)
            results = self.sess.run(
                fetches,
                feed_dict={
                    self.step_ph: step,
                    self.image_placeholder: batch_data,
                    self.label_seg_placeholder: batch_segment
                })
            _, learning_rate_r, loss_segment_r, loss_r = results[:4]
            if write_summary:
                self.summary_writer.add_summary(results[4], global_step=step)

            if step % save_pred_freq == 0:
                self.saver.save(self.sess,
                                self.checkpoint_path,
                                global_step=step)
                Tools.print_info('The checkpoint has been created.')

            duration = time.time() - start_time

            window_loss_sum, window_steps, pre_avg_loss = self._roll_window(
                step, self.cal_step, loss_r, window_loss_sum, window_steps,
                pre_avg_loss)

            if step % (self.cal_step // 10) == 0:
                Tools.print_info(
                    'step {:d} pre_avg_loss={:.3f} avg_loss={:.3f} loss={:.3f} seg={:.3f} lr={:.6f} '
                    '({:.3f} s/step)'.format(step, pre_avg_loss,
                                             window_loss_sum / window_steps,
                                             loss_r, loss_segment_r,
                                             learning_rate_r, duration))
            if step % self.print_step == 0:
                Tools.print_info("")
Example #6
0
class Inference(object):
    """BAISNet inference that also writes per-image TensorBoard summaries (TensorFlow 1.x)."""

    def __init__(self, input_size, summary_dir, log_dir, model_name="model.ckpt"):
        """Build the inference graph, its TensorBoard summaries, and a session.

        Args:
            input_size: indexable (height, width); images are fed as float32
                tensors of shape (batch, height, width, 3).
            summary_dir: directory passed to ``tf.summary.FileWriter``.
            log_dir: checkpoint directory, created via ``Tools.new_dir``.
            model_name: checkpoint file name inside ``log_dir``.
        """
        # Checkpoint-related parameters
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Data-related parameters
        self.input_size = input_size
        self.num_classes = 21

        # Input placeholder
        self.image_placeholder = tf.placeholder(
            tf.float32, shape=(None, self.input_size[0], self.input_size[1], 3))

        # Network (second positional argument False -- presumably is_training; TODO confirm)
        self.net = BAISNet(self.image_placeholder, False, num_classes=self.num_classes)
        self.segments, self.features = self.net.build()
        self.pred_segment = tf.cast(tf.argmax(self.segments[0], axis=-1), dtype=tf.uint8)

        # Map class indices onto 0..255 for display. Scale in int32: multiplying
        # the uint8 argmax by 255 wraps modulo 256 for every class index > 1.
        # With two classes, class 1 still maps to 255.
        label_scale = max(self.num_classes - 1, 1)
        with tf.name_scope("image"):
            tf.summary.image("input", self.image_placeholder)

            # One predicted label map per network output
            for segment_index, segment in enumerate(self.segments):
                labels = tf.cast(tf.argmax(segment, axis=-1), dtype=tf.int32)
                visual = tf.cast(labels * 255 // label_scale, dtype=tf.uint8)
                tf.summary.image("predict-{}".format(segment_index),
                                 tf.expand_dims(visual, axis=-1))

        # One image summary per channel of every feature tensor, grouped by key
        for key in list(self.features.keys()):
            with tf.name_scope(key):
                for feature_index, feature in enumerate(self.features[key]):
                    feature_split = tf.split(feature, num_or_size_splits=int(feature.shape[-1]), axis=-1)
                    for feature_one_index, feature_one in enumerate(feature_split):
                        tf.summary.image("{}-{}".format(feature_index, feature_one_index), feature_one)

        self.summary_op = tf.summary.merge_all()

        # Session (GPU memory grows on demand), saver and summary writer
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
        self.summary_writer = tf.summary.FileWriter(summary_dir, self.sess.graph)

    def load_model(self):
        """Restore weights from ``self.log_dir`` via ``Tools.restore_if_y``."""
        Tools.restore_if_y(self.sess, self.log_dir)

    @staticmethod
    def _segment_to_image_array(pred_segment, num_classes):
        """Convert a predicted label map into a displayable uint8 grayscale array.

        Class indices are spread over 0..255 (the highest class maps to 255).
        The scaling is done in int32: multiplying the uint8 prediction by 255
        directly wraps modulo 256 for every class index above 1. With two
        classes, class 1 still maps to 255.

        Args:
            pred_segment: integer label map; singleton dimensions are squeezed.
            num_classes: number of classes the labels range over.

        Returns:
            A uint8 numpy array with the squeezed shape of ``pred_segment``.
        """
        labels = np.squeeze(np.asarray(pred_segment)).astype(np.int32)
        scale = max(num_classes - 1, 1)
        return np.clip(labels * 255 // scale, 0, 255).astype(np.uint8)

    def inference(self, image_path, image_index, save_path=None):
        """Segment one image, log its summaries, then show or save the result.

        Args:
            image_path: image file, loaded with ``Data.load_data``.
            image_index: global step used for this image's summaries.
            save_path: output directory (created if needed) for
                ``<stem>.bmp``; when None the result is shown instead.
        """
        im_data = Data.load_data(image_path=image_path, input_size=self.input_size)
        im_data = np.expand_dims(im_data, axis=0)  # add batch dimension
        pred_segment_r, summary_now = self.sess.run([self.pred_segment, self.summary_op],
                                                    feed_dict={self.image_placeholder: im_data})
        self.summary_writer.add_summary(summary_now, global_step=image_index)
        s_image = Image.fromarray(
            self._segment_to_image_array(pred_segment_r, self.num_classes))
        if save_path is None:
            s_image.show()
        else:
            Tools.new_dir(save_path)
            image_name = os.path.splitext(os.path.basename(image_path))[0]
            s_image.convert("L").save("{}/{}.bmp".format(save_path, image_name))
class Train(object):
    """Build and run the BAISNet attention + classification training graph (TensorFlow 1.x)."""

    def __init__(self, batch_size, input_size, log_dir, data_root_path, train_list, data_path,
                 annotation_path, class_path, model_name="model.ckpt", pretrain=None, is_test=False):
        """Create the data reader, network, losses, optimisers, summaries and session.

        Args:
            batch_size: images per training batch.
            input_size: indexable (height, width) of the network input.
            log_dir: directory for checkpoints and TensorBoard events,
                created via ``Tools.new_dir``.
            data_root_path, train_list, data_path, annotation_path,
            class_path: forwarded unchanged to ``Data``.
            model_name: checkpoint file name inside ``log_dir``.
            pretrain: forwarded to ``Tools.restore_if_y`` when training starts.
            is_test: shortens the logging intervals and is forwarded to ``Data``.
        """
        # Checkpoint-related parameters
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)
        self.pretrain = pretrain

        # Data-related parameters
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21

        # Training hyper-parameters
        self.learning_rate = 5e-4
        self.num_steps = 500001
        self.print_step = 10 if is_test else 100
        self.cal_step = 100 if is_test else 1000

        # Data reader
        self.data_reader = Data(data_root_path=data_root_path, data_list=train_list,
                                data_path=data_path, annotation_path=annotation_path, class_path=class_path,
                                batch_size=self.batch_size, image_size=self.input_size, is_test=is_test)

        # Inputs: images, a float mask, int32 label maps and int32 class labels
        self.image_placeholder = tf.placeholder(tf.float32, shape=(None, self.input_size[0], self.input_size[1], 3))
        self.mask_placeholder = tf.placeholder(tf.float32, shape=(None, self.input_size[0], self.input_size[1], 1))
        self.label_seg_placeholder = tf.placeholder(tf.int32, shape=(None, self.input_size[0], self.input_size[1], 1))
        self.label_cls_placeholder = tf.placeholder(tf.int32, shape=(None,))

        # Network (third positional argument True -- presumably is_training; TODO confirm)
        self.net = BAISNet(self.image_placeholder, self.mask_placeholder, True, num_classes=self.num_classes)
        self.segments, self.attentions, self.classes = self.net.build()

        # Loss. NOTE: cal_loss returns the *attention* loss terms in the slots
        # assigned to loss_segment_all / loss_segments (segment loss is disabled).
        self.loss, self.loss_segment_all, self.loss_class_all, self.loss_segments, self.loss_classes = self.cal_loss(
            self.segments, self.attentions, self.classes, self.label_seg_placeholder, self.label_cls_placeholder)

        # Batch accuracy of the first classification output
        self.pred_classes = tf.cast(tf.argmax(self.classes[0], axis=-1), tf.int32)
        self.accuracy_classes = tcm.accuracy(self.pred_classes, self.label_cls_placeholder)

        with tf.name_scope("train"):
            # Polynomial decay: lr = base_lr * (1 - step / num_steps) ** 0.9.
            # NOTE: self.learning_rate is rebound from the float base rate to
            # this tensor; train() fetches the tensor.
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            self.learning_rate = tf.scalar_mul(tf.constant(self.learning_rate),
                                               tf.pow((1 - self.step_ph / self.num_steps), 0.9))
            self.train_op = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss)

            # Auxiliary op that trains only the attention variables
            # (names containing 'attention', which includes 'class_attention').
            attention_trainable = [v for v in tf.trainable_variables() if 'attention' in v.name]
            print(len(attention_trainable))
            if attention_trainable:
                self.train_attention_op = tf.train.GradientDescentOptimizer(
                    self.learning_rate).minimize(self.loss, var_list=attention_trainable)
            else:
                # minimize() raises ValueError on an empty var_list; an unused
                # auxiliary op must not abort graph construction.
                self.train_attention_op = None

        # Scalar loss summaries
        with tf.name_scope("loss"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.scalar("loss_segment", self.loss_segment_all)
            tf.summary.scalar("loss_class", self.loss_class_all)
            for loss_segment_index, loss_segment in enumerate(self.loss_segments):
                tf.summary.scalar("loss_segment_{}".format(loss_segment_index), loss_segment)
            for loss_class_index, loss_class in enumerate(self.loss_classes):
                tf.summary.scalar("loss_class_{}".format(loss_class_index), loss_class)

        with tf.name_scope("accuracy"):
            tf.summary.scalar("accuracy_classes", self.accuracy_classes)

        # Inputs and labels
        with tf.name_scope("label"):
            tf.summary.image("1-image", self.image_placeholder)
            tf.summary.image("2-segment", tf.cast(self.label_seg_placeholder * 255, dtype=tf.uint8))
            tf.summary.image("3-mask", tf.cast(self.mask_placeholder * 255, dtype=tf.uint8))

        # Sigmoid of the second channel of every segment / attention output
        with tf.name_scope("result"):
            for segment_index, segment in enumerate(self.segments):
                segment = tf.split(segment, num_or_size_splits=2, axis=3)
                tf.summary.image("predict-{}-1".format(segment_index), tf.nn.sigmoid(segment[1]))
            for attention_index, attention in enumerate(self.attentions):
                attention = tf.split(attention, num_or_size_splits=2, axis=3)
                tf.summary.image("attention-{}-1".format(attention_index), tf.nn.sigmoid(attention[1]))

        self.summary_op = tf.summary.merge_all()

        # Session (GPU memory grows on demand) and saver
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

        # Summary writer
        self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)

    @staticmethod
    def cal_loss(segments, attentions, classes, label_segment, label_classes):
        """Compute attention and classification losses.

        ``segments`` is accepted for interface compatibility but is not used:
        the segment-branch loss is currently disabled, so only attention and
        class losses are optimised.

        Each attention output is compared with the labels, resized (nearest
        neighbour) to its spatial size (dims 1:3). The comparison uses
        ``tf.nn.weighted_cross_entropy_with_logits`` against 2-class one-hot
        targets with ``pos_weight=3``. Each classification output uses
        sparse softmax cross-entropy.

        Returns:
            ``(loss, loss_attention_all, loss_class_all, loss_attentions,
            loss_classes)`` with ``loss = loss_attention_all + loss_class_all``.
        """
        loss_attentions = []
        for attention in attentions:
            now_label_segment = tf.image.resize_nearest_neighbor(label_segment, tf.stack(attention.get_shape()[1:3]))
            now_loss_attention = tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(
                targets=tf.one_hot(tf.reshape(now_label_segment, [-1]), depth=2),
                logits=tf.reshape(attention, [-1, 2]), pos_weight=3))
            loss_attentions.append(now_loss_attention)

        loss_classes = []
        for class_one in classes:
            loss_classes.append(tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label_classes, logits=class_one)))

        loss_attention_all = tf.add_n(loss_attentions) / len(loss_attentions)
        loss_class_all = tf.add_n(loss_classes) / len(loss_classes)
        # Total loss
        loss = loss_attention_all + loss_class_all
        return loss, loss_attention_all, loss_class_all, loss_attentions, loss_classes

    @staticmethod
    def _roll_window(step, cal_step, value, window_sum, window_count, pre_avg):
        """Fold ``value`` into a running window that restarts every ``cal_step`` steps.

        At a window boundary (``step % cal_step == 0``), ``pre_avg`` becomes the
        mean of the window that just closed (0.0 if it held no samples) and a
        new window is started with ``value``. Otherwise ``value`` is added to the
        current window. The explicit counter keeps both averages correct when
        training resumes from a ``begin_step`` that is not a multiple of
        ``cal_step``.

        Returns:
            ``(window_sum, window_count, pre_avg)``; ``window_count >= 1``.
        """
        if step % cal_step == 0:
            pre_avg = window_sum / window_count if window_count else 0.0
            return value, 1, pre_avg
        return window_sum + value, window_count + 1, pre_avg

    def train(self, save_pred_freq, begin_step=0):
        """Run training for steps ``begin_step`` .. ``self.num_steps - 1``.

        Args:
            save_pred_freq: a checkpoint is saved whenever
                ``step % save_pred_freq == 0``.
            begin_step: first step. It is also fed to the learning-rate
                schedule, so resuming continues the decay.
        """
        # Load model
        Tools.restore_if_y(self.sess, self.log_dir, pretrain=self.pretrain)

        loss_sum, loss_count, pre_avg_loss = 0.0, 0, 0.0
        acc_sum, acc_count, pre_avg_acc = 0.0, 0, 0.0
        for step in range(begin_step, self.num_steps):
            start_time = time.time()

            batch_data, batch_mask, batch_attention, batch_class = self.data_reader.next_batch_train()

            # train_op = self.train_attention_op  (None if no attention variables)
            train_op = self.train_op

            # The merged summary is fetched and written only on print steps
            write_summary = step % self.print_step == 0
            fetches = [self.accuracy_classes, train_op, self.learning_rate,
                       self.loss_segment_all, self.loss_class_all, self.loss,
                       self.pred_classes]
            if write_summary:
                fetches.append(self.summary_op)
            results = self.sess.run(
                fetches,
                feed_dict={self.step_ph: step, self.image_placeholder: batch_data,
                           self.mask_placeholder: batch_mask,
                           self.label_seg_placeholder: batch_attention,
                           self.label_cls_placeholder: batch_class})
            (accuracy_classes_r, _, learning_rate_r,
             loss_segment_r, loss_classes_r, loss_r, pred_classes_r) = results[:7]
            if write_summary:
                self.summary_writer.add_summary(results[7], global_step=step)

            if step % save_pred_freq == 0:
                self.saver.save(self.sess, self.checkpoint_path, global_step=step)
                Tools.print_info('The checkpoint has been created.')

            duration = time.time() - start_time

            loss_sum, loss_count, pre_avg_loss = self._roll_window(
                step, self.cal_step, loss_r, loss_sum, loss_count, pre_avg_loss)
            acc_sum, acc_count, pre_avg_acc = self._roll_window(
                step, self.cal_step, accuracy_classes_r, acc_sum, acc_count, pre_avg_acc)

            if step % (self.cal_step // 10) == 0:
                Tools.print_info(
                    'step {:d} pre_avg_loss={:.3f} avg_loss={:.3f} pre_avg_acc={:.3f} avg_acc={:.3f} '
                    'loss={:.3f} seg={:.3f} class={:.3f} acc_class={:.3f} '
                    'lr={:.6f} ({:.3f} s/step) {} {}'.format(
                        step, pre_avg_loss, loss_sum / loss_count, pre_avg_acc, acc_sum / acc_count,
                        loss_r, loss_segment_r, loss_classes_r, accuracy_classes_r,
                        learning_rate_r, duration, list(batch_class), list(pred_classes_r)))
            if step % self.print_step == 0:
                Tools.print_info("")
    def __init__(self, batch_size, input_size, log_dir, data_root_path, train_list, data_path,
                 annotation_path, class_path, model_name="model.ckpt", pretrain=None, is_test=False,
                 num_classes=21, base_learning_rate=5e-4, num_steps=500001):
        """Build the data reader, training graph, summaries, session and saver.

        Args:
            batch_size: samples per training batch, forwarded to ``Data``.
            input_size: indexable whose elements 0 and 1 give the spatial size
                (height, width presumably -- TODO confirm) of every image-shaped
                placeholder; also forwarded to ``Data`` as ``image_size``.
            log_dir: directory for checkpoints and TensorBoard summaries; passed
                through ``Tools.new_dir`` (presumably creates it if missing).
            data_root_path, train_list, data_path, annotation_path, class_path:
                forwarded unchanged to ``Data``.
            model_name: checkpoint file name joined onto ``log_dir``.
            pretrain: stored as ``self.pretrain``; not used by this constructor.
            is_test: when True, ``print_step``/``cal_step`` are 10/100 instead of
                100/1000; also forwarded to ``Data``.
            num_classes: number of classes passed to ``BAISNet`` (default 21, the
                previously hard-coded value).
            base_learning_rate: initial learning rate of the decay schedule.
            num_steps: total number of training steps; also the decay horizon.

        After construction ``self.learning_rate`` is the decayed learning-rate
        tensor (fed by ``self.step_ph``), not a Python float; the scalar base
        rate is available as ``self.base_learning_rate``.
        """
        # Checkpoint / model-saving parameters
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)
        self.pretrain = pretrain

        # Data parameters
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = num_classes

        # Training parameters. The scalar base rate lives in its own attribute so it
        # is not clobbered when ``self.learning_rate`` becomes the schedule tensor.
        self.base_learning_rate = base_learning_rate
        self.num_steps = num_steps
        self.print_step = 10 if is_test else 100
        self.cal_step = 100 if is_test else 1000

        # Data reader
        self.data_reader = Data(data_root_path=data_root_path, data_list=train_list,
                                data_path=data_path, annotation_path=annotation_path, class_path=class_path,
                                batch_size=self.batch_size, image_size=self.input_size, is_test=is_test)

        # Input placeholders: RGB image, 1-channel mask, 1-channel segmentation label,
        # and one integer class label per sample.
        self.image_placeholder = tf.placeholder(tf.float32, shape=(None, self.input_size[0], self.input_size[1], 3))
        self.mask_placeholder = tf.placeholder(tf.float32, shape=(None, self.input_size[0], self.input_size[1], 1))
        self.label_seg_placeholder = tf.placeholder(tf.int32, shape=(None, self.input_size[0], self.input_size[1], 1))
        self.label_cls_placeholder = tf.placeholder(tf.int32, shape=(None,))

        # Network (the third positional argument is True here; the inference class passes False)
        self.net = BAISNet(self.image_placeholder, self.mask_placeholder, True, num_classes=self.num_classes)
        self.segments, self.attentions, self.classes = self.net.build()

        # Losses: total, aggregated segment/class terms, and the per-output lists
        self.loss, self.loss_segment_all, self.loss_class_all, self.loss_segments, self.loss_classes = self.cal_loss(
            self.segments, self.attentions, self.classes, self.label_seg_placeholder, self.label_cls_placeholder)

        # Classification accuracy on the current batch, computed from the first class head
        self.pred_classes = tf.cast(tf.argmax(self.classes[0], axis=-1), tf.int32)
        self.accuracy_classes = tcm.accuracy(self.pred_classes, self.label_cls_placeholder)

        with tf.name_scope("train"):
            # Polynomial decay: lr = base * (1 - step / num_steps) ** 0.9, with the
            # current step fed through ``step_ph`` on every ``sess.run``.
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            self.learning_rate = tf.scalar_mul(tf.constant(self.base_learning_rate, dtype=tf.float32),
                                               tf.pow((1 - self.step_ph / self.num_steps), 0.9))
            self.train_op = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss)

            # Alternative op that only updates variables whose name contains
            # "attention" (this also covers the "class_attention" variables).
            # NOTE(review): minimize() raises ValueError if this list is empty.
            attention_trainable = [v for v in tf.trainable_variables() if "attention" in v.name]
            Tools.print_info("attention trainable variables: {}".format(len(attention_trainable)))
            self.train_attention_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss, var_list=attention_trainable)

        # Summaries, part 1: scalar losses
        with tf.name_scope("loss"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.scalar("loss_segment", self.loss_segment_all)
            tf.summary.scalar("loss_class", self.loss_class_all)
            for loss_segment_index, loss_segment in enumerate(self.loss_segments):
                tf.summary.scalar("loss_segment_{}".format(loss_segment_index), loss_segment)
            for loss_class_index, loss_class in enumerate(self.loss_classes):
                tf.summary.scalar("loss_class_{}".format(loss_class_index), loss_class)

        with tf.name_scope("accuracy"):
            tf.summary.scalar("accuracy_classes", self.accuracy_classes)

        # Input images and labels; labels/masks are scaled by 255 for display
        with tf.name_scope("label"):
            tf.summary.image("1-image", self.image_placeholder)
            tf.summary.image("2-segment", tf.cast(self.label_seg_placeholder * 255, dtype=tf.uint8))
            tf.summary.image("3-mask", tf.cast(self.mask_placeholder * 255, dtype=tf.uint8))

        # Network outputs: each output is split into 2 channels along axis 3, and only
        # channel 1 (after a sigmoid) is summarized.
        with tf.name_scope("result"):
            for segment_index, segment in enumerate(self.segments):
                segment_channels = tf.split(segment, num_or_size_splits=2, axis=3)
                tf.summary.image("predict-{}-1".format(segment_index), tf.nn.sigmoid(segment_channels[1]))
            for attention_index, attention in enumerate(self.attentions):
                attention_channels = tf.split(attention, num_or_size_splits=2, axis=3)
                tf.summary.image("attention-{}-1".format(attention_index), tf.nn.sigmoid(attention_channels[1]))

        self.summary_op = tf.summary.merge_all()

        # Session and saver. Variables are initialized only after the whole graph
        # (network + optimizers) exists, so every variable is covered.
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

        # Summaries, part 2: writer (also records the graph)
        self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)