コード例 #1
0
class Train(object):
    """Forward-pass runner around a PSPNet graph.

    This variant builds no loss and no optimizer: ``train`` only pushes
    batches through the network, prints the predicted classes next to the
    batch labels, and periodically writes a checkpoint.
    """

    def __init__(self,
                 batch_size,
                 last_pool_size,
                 input_size,
                 log_dir,
                 data_root_path,
                 train_list,
                 data_path,
                 annotation_path,
                 class_path,
                 model_name="model.ckpt",
                 is_test=False):
        """Configure the run, create the data reader, graph, session and saver.

        Args:
            batch_size: Batch size handed to the ``Data`` reader.
            last_pool_size: Forwarded to ``PSPNet`` as ``last_pool_size``.
            input_size: Indexable pair; elements 0 and 1 give the two spatial
                dimensions of the image placeholder.
            log_dir: Directory passed through ``Tools.new_dir``; checkpoints
                are written underneath the returned path.
            data_root_path: Forwarded to ``Data``.
            train_list: Forwarded to ``Data`` as ``data_list``.
            data_path: Forwarded to ``Data``.
            annotation_path: Forwarded to ``Data``.
            class_path: Forwarded to ``Data``.
            model_name: File name of the checkpoint inside ``log_dir``.
            is_test: Forwarded to ``Data``.
        """
        # Settings related to saving the model.
        self.model_name = model_name
        self.log_dir = Tools.new_dir(log_dir)
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Settings related to the model. Per the original author's note,
        # input_size must be larger than 8x last_pool_size.
        self.filter_number = 32
        self.last_pool_size = last_pool_size
        self.ratio = 8

        # Settings related to the data.
        self.num_segment = 1
        self.num_classes = 21
        self.batch_size = batch_size
        self.input_size = input_size

        # Data reader.
        self.data_reader = Data(
            data_root_path=data_root_path,
            data_list=train_list,
            data_path=data_path,
            annotation_path=annotation_path,
            class_path=class_path,
            batch_size=self.batch_size,
            image_size=self.input_size,
            is_test=is_test,
        )

        # Network graph.
        graph_tensors = self.build_net()
        (self.image_placeholder,
         self.raw_output_segment,
         self.raw_output_classes,
         self.pred_segment,
         self.pred_classes) = graph_tensors

        # Session and saver.
        gpu_options = tf.GPUOptions(allow_growth=True)
        session_config = tf.ConfigProto(gpu_options=gpu_options)
        self.sess = tf.Session(config=session_config)
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                    max_to_keep=10)

    def build_net(self):
        """Create the input placeholder, the PSPNet graph and its predictions.

        Returns:
            A 5-tuple ``(image_placeholder, raw_output_segment,
            raw_output_classes, pred_segment, pred_classes)``. The segment
            prediction is the raw segment output thresholded at 0.5; the
            class prediction is the argmax over the last axis of the raw
            class output. Both predictions are cast to int32.
        """
        # Input: float32 images with four channels and a free batch dimension.
        height = self.input_size[0]
        width = self.input_size[1]
        images = tf.placeholder(dtype=tf.float32,
                                shape=(None, height, width, 4))

        # Network.
        network = PSPNet(
            {'data': images},
            is_training=True,
            num_classes=self.num_classes,
            num_segment=self.num_segment,
            last_pool_size=self.last_pool_size,
            filter_number=self.filter_number,
        )
        segment_output = network.layers['conv6_n']
        classes_output = network.layers['class_attention_fc']

        # Predictions.
        segment_mask = tf.cast(tf.greater(segment_output, 0.5), tf.int32)
        class_ids = tf.cast(tf.argmax(classes_output, axis=-1), tf.int32)

        return images, segment_output, classes_output, segment_mask, class_ids

    def train(self, save_pred_freq, begin_step=0):
        """Run forward passes for steps ``begin_step`` .. 4 and log predictions.

        Args:
            save_pred_freq: A checkpoint is written whenever
                ``step % save_pred_freq == 0``.
            begin_step: First step index of the loop (the loop ends before 5).
        """
        # Load the model (behaviour is defined by Tools.restore_if_y).
        Tools.restore_if_y(self.sess, self.log_dir)

        fetches = [
            self.raw_output_segment,
            self.pred_segment,
            self.raw_output_classes,
            self.pred_classes,
        ]

        for step in range(begin_step, 5):
            # The reader yields five items; only the images and the class
            # labels are consumed here.
            images, _, class_labels, _, _ = \
                self.data_reader.next_batch_train()

            outputs = self.sess.run(
                fetches, feed_dict={self.image_placeholder: images})
            _, _, _, predicted_classes = outputs

            if step % save_pred_freq == 0:
                self.saver.save(self.sess,
                                self.checkpoint_path,
                                global_step=step)
                Tools.print_info('The checkpoint has been created.')

            message = 'step {:d} {} {}'.format(step, list(class_labels),
                                               list(predicted_classes))
            Tools.print_info(message)
コード例 #2
0
class Train(object):
    """Build and train a BAISNet graph with segment, attention and class heads.

    The constructor creates the data reader, the placeholders, the network,
    the losses, two optimizer ops, all TensorBoard summaries, the session and
    the saver. ``train`` then runs the optimisation loop.
    """

    def __init__(self,
                 batch_size,
                 last_pool_size,
                 input_size,
                 log_dir,
                 data_root_path,
                 train_list,
                 data_path,
                 annotation_path,
                 class_path,
                 model_name="model.ckpt",
                 is_test=False):
        """Create the complete training graph and its session.

        Args:
            batch_size: Batch size handed to the ``Data`` reader.
            last_pool_size: Forwarded to ``BAISNet`` as ``last_pool_size``.
            input_size: Indexable pair; elements 0 and 1 give the two spatial
                dimensions of the image placeholder. The label placeholders
                use those dimensions floor-divided by ``self.ratio``.
            log_dir: Directory passed through ``Tools.new_dir``; checkpoints
                and summaries are written underneath the returned path.
            data_root_path: Forwarded to ``Data``.
            train_list: Forwarded to ``Data`` as ``data_list``.
            data_path: Forwarded to ``Data``.
            annotation_path: Forwarded to ``Data``.
            class_path: Forwarded to ``Data``.
            model_name: File name of the checkpoint inside ``log_dir``.
            is_test: Forwarded to ``Data``; also selects a summary interval
                of 5 steps instead of 25.
        """
        # Settings related to saving the model.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Settings related to the data.
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21
        # Decoder channels: other objects, attention, border, background.
        self.num_segment = 4
        # Position of the attention channel when the decoder has 4 channels.
        self.segment_attention = 1
        # Number of attention modules whose decoder has 2 channels
        # (background, attention).
        self.attention_module_num = 2

        # Settings related to the model. Per the original author's note,
        # input_size must be larger than 8x last_pool_size.
        self.ratio = 8
        self.last_pool_size = last_pool_size
        self.filter_number = 32

        # Settings related to training. NOTE: self.learning_rate starts as
        # the base rate and is replaced by the decayed-rate tensor below.
        self.learning_rate = 5e-3
        self.num_steps = 500001
        self.print_step = 5 if is_test else 25

        # Data reader.
        self.data_reader = Data(data_root_path=data_root_path,
                                data_list=train_list,
                                data_path=data_path,
                                annotation_path=annotation_path,
                                class_path=class_path,
                                batch_size=self.batch_size,
                                image_size=self.input_size,
                                is_test=is_test,
                                has_255=True)

        # Placeholders.
        self.image_placeholder = tf.placeholder(dtype=tf.float32,
                                                shape=(None,
                                                       self.input_size[0],
                                                       self.input_size[1], 4))
        self.label_segment_placeholder = tf.placeholder(
            dtype=tf.int32,
            shape=(None, self.input_size[0] // self.ratio,
                   self.input_size[1] // self.ratio, 1))
        self.label_attention_placeholder = tf.placeholder(
            dtype=tf.int32,
            shape=(None, self.input_size[0] // self.ratio,
                   self.input_size[1] // self.ratio, 1))
        self.label_classes_placeholder = tf.placeholder(dtype=tf.int32,
                                                        shape=(None, ))

        # Network.
        self.net = BAISNet(self.image_placeholder,
                           is_training=True,
                           num_classes=self.num_classes,
                           num_segment=self.num_segment,
                           segment_attention=self.segment_attention,
                           last_pool_size=self.last_pool_size,
                           filter_number=self.filter_number,
                           attention_module_num=self.attention_module_num)

        self.segments, self.attentions, self.classes = self.net.build()
        self.final_segment_logit = self.segments[0]
        self.final_class_logit = self.classes[0]

        # Predictions: argmax over the last axis of the first segment / class
        # output, cast to int32.
        self.pred_segment = tf.cast(
            tf.expand_dims(tf.argmax(self.final_segment_logit, axis=-1),
                           axis=-1), tf.int32)
        self.pred_classes = tf.cast(tf.argmax(self.final_class_logit, axis=-1),
                                    tf.int32)

        # Loss: labels are resized (nearest neighbour) to the spatial size of
        # the first segment output before the loss is computed.
        logit_size = tf.stack(self.final_segment_logit.get_shape()[1:3])
        self.label_batch = tf.image.resize_nearest_neighbor(
            self.label_segment_placeholder, logit_size)
        self.label_attention_batch = tf.image.resize_nearest_neighbor(
            self.label_attention_placeholder, logit_size)
        (self.loss, self.loss_segment_all, self.loss_class_all,
         self.loss_segments, self.loss_classes) = self.cal_loss(
             self.segments,
             self.classes,
             self.label_batch,
             self.label_attention_batch,
             self.label_classes_placeholder,
             self.num_segment,
             attention_module_num=self.attention_module_num)

        # Accuracy on the current batch.
        self.accuracy_segment = tcm.accuracy(self.pred_segment,
                                             self.label_segment_placeholder)
        self.accuracy_classes = tcm.accuracy(self.pred_classes,
                                             self.label_classes_placeholder)

        with tf.name_scope("train"):
            # Learning-rate schedule: base_rate * (1 - step / num_steps)^0.9.
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            self.learning_rate = tf.scalar_mul(
                tf.constant(self.learning_rate),
                tf.pow((1 - self.step_ph / self.num_steps), 0.9))
            self.train_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss)

            # Separate op that only updates the attention variables. Any
            # name containing "class_attention" also contains "attention",
            # so a single substring test selects the same variables.
            attention_trainable = [
                v for v in tf.trainable_variables() if 'attention' in v.name
            ]
            print(len(attention_trainable))
            self.train_attention_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss,
                                             var_list=attention_trainable)

        # Summaries, part 1: scalars and images.
        with tf.name_scope("loss"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.scalar("loss_segment", self.loss_segment_all)
            tf.summary.scalar("loss_class", self.loss_class_all)
            for loss_segment_index, loss_segment in enumerate(
                    self.loss_segments):
                tf.summary.scalar("loss_segment_{}".format(loss_segment_index),
                                  loss_segment)
            for loss_class_index, loss_class in enumerate(self.loss_classes):
                tf.summary.scalar("loss_class_{}".format(loss_class_index),
                                  loss_class)

        with tf.name_scope("accuracy"):
            tf.summary.scalar("accuracy_segment", self.accuracy_segment)
            tf.summary.scalar("accuracy_classes", self.accuracy_classes)

        with tf.name_scope("label"):
            # The input has 4 channels: the first three are logged as the
            # image, the fourth as the mask.
            split = tf.split(self.image_placeholder,
                             num_or_size_splits=4,
                             axis=3)
            tf.summary.image("0-mask", split[3])
            tf.summary.image("1-image", tf.concat(split[0:3], axis=3))
            tf.summary.image(
                "2-label",
                tf.cast(self.label_segment_placeholder * 85, dtype=tf.uint8))
            tf.summary.image(
                "3-attention",
                tf.cast(self.label_attention_placeholder * 255,
                        dtype=tf.uint8))

        with tf.name_scope("predict"):
            tf.summary.image("predict",
                             tf.cast(self.pred_segment * 85, dtype=tf.uint8))

        with tf.name_scope("attention"):
            for attention_index, attention in enumerate(self.attentions):
                tf.summary.image("{}-attention".format(attention_index),
                                 attention)

        # Use the same boundary as cal_loss: the last attention_module_num
        # outputs are treated as 2-channel, everything before them as
        # num_segment-channel. The previous test (index <
        # attention_module_num) only matched cal_loss when
        # len(segments) == 2 * attention_module_num.
        full_segment_count = len(self.segments) - self.attention_module_num
        with tf.name_scope("sigmoid"):
            for segment_index, segment in enumerate(self.segments):
                if segment_index < full_segment_count:
                    split = tf.split(segment,
                                     num_or_size_splits=self.num_segment,
                                     axis=3)
                    tf.summary.image("{}-other".format(segment_index),
                                     split[0])
                    tf.summary.image("{}-attention".format(segment_index),
                                     split[1])
                    tf.summary.image("{}-border".format(segment_index),
                                     split[2])
                    tf.summary.image("{}-background".format(segment_index),
                                     split[-1])
                else:
                    split = tf.split(segment, num_or_size_splits=2, axis=3)
                    tf.summary.image("{}-background".format(segment_index),
                                     split[0])
                    tf.summary.image("{}-attention".format(segment_index),
                                     split[1])

        self.summary_op = tf.summary.merge_all()

        # Session and saver.
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                    max_to_keep=10)

        # Summaries, part 2: the writer.
        self.summary_writer = tf.summary.FileWriter(self.log_dir,
                                                    self.sess.graph)

    @staticmethod
    def cal_loss(segments, classes, label_segment, label_attention,
                 label_classes, num_segment, attention_module_num):
        """Compute per-output and total losses.

        The last ``attention_module_num`` entries of ``segments`` are treated
        as 2-channel outputs: their second channel is scored against
        ``label_attention`` with a weighted sigmoid cross-entropy
        (``pos_weight=3``) and the mean is doubled. All earlier entries are
        scored against ``label_segment`` with a sparse softmax cross-entropy
        over ``num_segment`` channels. Every entry of ``classes`` is scored
        against ``label_classes`` with a sparse softmax cross-entropy.

        Args:
            segments: Sequence of segment logit tensors.
            classes: Sequence of class logit tensors.
            label_segment: Integer segment labels; flattened internally.
            label_attention: Integer attention labels; flattened internally.
            label_classes: Integer class labels.
            num_segment: Channel count of the non-attention segment outputs.
            attention_module_num: Number of trailing 2-channel outputs.

        Returns:
            ``(loss, loss_segment_all, loss_class_all, loss_segments,
            loss_classes)`` where ``loss_segment_all`` and ``loss_class_all``
            are the means of the two lists and
            ``loss = loss_segment_all + 0.1 * loss_class_all``.
        """
        label_segment = tf.reshape(label_segment, [-1])
        label_attention = tf.reshape(label_attention, [-1])

        full_segment_count = len(segments) - attention_module_num

        loss_segments = []
        for segment_index, segment in enumerate(segments):
            if segment_index < full_segment_count:
                now_loss_segment = tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits(
                        labels=label_segment,
                        logits=tf.reshape(segment, [-1, num_segment])))
            else:
                # Only the second of the two channels is scored.
                attention_logit = tf.split(segment,
                                           num_or_size_splits=2,
                                           axis=3)[1]
                now_loss_segment = tf.reduce_mean(
                    tf.nn.weighted_cross_entropy_with_logits(
                        targets=tf.cast(label_attention, dtype=tf.float32),
                        logits=tf.reshape(attention_logit, [-1]),
                        pos_weight=3)) * 2
            loss_segments.append(now_loss_segment)

        loss_classes = [
            tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=label_classes, logits=class_one))
            for class_one in classes
        ]

        loss_segment_all = tf.add_n(loss_segments) / len(loss_segments)
        loss_class_all = tf.add_n(loss_classes) / len(loss_classes)
        # Total loss.
        loss = loss_segment_all + 0.1 * loss_class_all
        return loss, loss_segment_all, loss_class_all, loss_segments, loss_classes

    def train(self, save_pred_freq, begin_step=0):
        """Run the optimisation loop from ``begin_step`` to ``num_steps - 1``.

        Every step runs ``self.train_op`` and logs losses, accuracies and the
        learning rate. Summaries are additionally written every
        ``self.print_step`` steps.

        Args:
            save_pred_freq: A checkpoint is written whenever
                ``step % save_pred_freq == 0``.
            begin_step: First step index; it is also fed to the learning-rate
                schedule.
        """
        # Load the model (behaviour is defined by Tools.restore_if_y).
        Tools.restore_if_y(self.sess, self.log_dir)

        # self.train_attention_op can be substituted here to update only the
        # attention variables.
        train_op = self.train_op

        # Only tensors that are actually reported are fetched; the raw
        # logits and the segment prediction were previously pulled back every
        # step without being used.
        step_fetches = [
            self.accuracy_segment, self.accuracy_classes, train_op,
            self.learning_rate, self.loss_segment_all, self.loss_class_all,
            self.loss, self.pred_classes
        ]
        summary_fetches = step_fetches + [self.summary_op]
        summary_index = len(step_fetches)

        for step in range(begin_step, self.num_steps):
            start_time = time.time()

            # The reader yields six items; the last two are not used here.
            (final_batch_data, final_batch_ann, final_batch_ann_attention,
             final_batch_class, _, _) = self.data_reader.next_batch_train()

            feed_dict = {
                self.step_ph: step,
                self.image_placeholder: final_batch_data,
                self.label_segment_placeholder: final_batch_ann,
                self.label_attention_placeholder: final_batch_ann_attention,
                self.label_classes_placeholder: final_batch_class
            }

            # Summaries, part 3: evaluated only on reporting steps.
            write_summary = step % self.print_step == 0
            results = self.sess.run(
                summary_fetches if write_summary else step_fetches,
                feed_dict=feed_dict)
            (accuracy_segment_r, accuracy_classes_r, _, learning_rate_r,
             loss_segment_r, loss_classes_r, loss_r,
             pred_classes_r) = results[:summary_index]
            if write_summary:
                self.summary_writer.add_summary(results[summary_index],
                                                global_step=step)

            if step % save_pred_freq == 0:
                self.saver.save(self.sess,
                                self.checkpoint_path,
                                global_step=step)
                Tools.print_info('The checkpoint has been created.')

            duration = time.time() - start_time

            Tools.print_info(
                'step {:d} loss={:.3f} seg={:.3f} class={:.3f} acc={:.3f} acc_class={:.3f}'
                ' lr={:.6f} ({:.3f} s/step) {} {}'.format(
                    step, loss_r, loss_segment_r, loss_classes_r,
                    accuracy_segment_r, accuracy_classes_r, learning_rate_r,
                    duration, list(final_batch_class), list(pred_classes_r)))
コード例 #3
0
class Train(object):
    """Build and train a BAISNet segmentation graph with several outputs.

    The constructor creates the data reader, the placeholders, the network,
    the loss, two optimizer ops, the TensorBoard summaries, the session and
    the saver. ``train`` then runs the optimisation loop.
    """

    def __init__(self,
                 batch_size,
                 input_size,
                 log_dir,
                 data_root_path,
                 train_list,
                 data_path,
                 annotation_path,
                 class_path,
                 model_name="model.ckpt",
                 pretrain=None,
                 is_test=False):
        """Create the complete training graph and its session.

        Args:
            batch_size: Batch size handed to the ``Data`` reader; also the
                number of pieces each prediction is split into for the image
                summaries.
            input_size: Indexable pair; elements 0 and 1 give the two spatial
                dimensions of the image and label placeholders.
            log_dir: Directory passed through ``Tools.new_dir``; checkpoints
                and summaries are written underneath the returned path.
            data_root_path: Forwarded to ``Data``.
            train_list: Forwarded to ``Data`` as ``data_list``.
            data_path: Forwarded to ``Data``.
            annotation_path: Forwarded to ``Data``.
            class_path: Forwarded to ``Data``.
            model_name: File name of the checkpoint inside ``log_dir``.
            pretrain: Stored and later passed to ``Tools.restore_if_y``.
            is_test: Forwarded to ``Data``; also selects shorter reporting
                intervals (10/100 steps instead of 100/1000).
        """
        # Settings related to saving the model.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)
        self.pretrain = pretrain

        # Settings related to the data.
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21

        # Settings related to training. NOTE: self.learning_rate starts as
        # the base rate and is replaced by the decayed-rate tensor below.
        self.learning_rate = 5e-3
        self.num_steps = 100001
        self.print_step = 10 if is_test else 100
        self.cal_step = 100 if is_test else 1000

        # Data reader.
        self.data_reader = Data(data_root_path=data_root_path,
                                data_list=train_list,
                                data_path=data_path,
                                annotation_path=annotation_path,
                                class_path=class_path,
                                batch_size=self.batch_size,
                                image_size=self.input_size,
                                is_test=is_test)

        # Placeholders.
        self.image_placeholder = tf.placeholder(tf.float32,
                                                shape=(None,
                                                       self.input_size[0],
                                                       self.input_size[1], 3))
        self.label_seg_placeholder = tf.placeholder(
            tf.int32, shape=(None, self.input_size[0], self.input_size[1], 1))

        # Network.
        self.net = BAISNet(self.image_placeholder,
                           True,
                           num_classes=self.num_classes)
        self.segments, self.features = self.net.build()

        # Loss.
        self.loss, self.loss_segment_all, self.loss_segments = self.cal_loss(
            self.segments, self.label_seg_placeholder)

        with tf.name_scope("train"):
            # Learning-rate schedule: base_rate * (1 - step / num_steps)^0.8.
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            self.learning_rate = tf.scalar_mul(
                tf.constant(self.learning_rate),
                tf.pow((1 - self.step_ph / self.num_steps), 0.8))
            self.train_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss)

            # Separate op that only updates the "segment_side" variables.
            segment_side_trainable = [
                v for v in tf.trainable_variables() if 'segment_side' in v.name
            ]
            print(len(segment_side_trainable))
            self.train_segment_side_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss,
                                             var_list=segment_side_trainable)

        # Summaries, part 1: scalars and images.
        with tf.name_scope("loss"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.scalar("loss_segment", self.loss_segment_all)
            for loss_segment_index, loss_segment in enumerate(
                    self.loss_segments):
                tf.summary.scalar("loss_segment_{}".format(loss_segment_index),
                                  loss_segment)

        with tf.name_scope("label"):
            tf.summary.image("1-image", self.image_placeholder)
            tf.summary.image(
                "2-segment",
                tf.cast(self.label_seg_placeholder * 255 // self.num_classes,
                        dtype=tf.uint8))

        with tf.name_scope("result"):
            # At most three samples of each output are logged.
            shown = min(3, self.batch_size)
            for segment_index, segment in enumerate(self.segments):
                # Keep the label map in int32 while scaling. Casting to
                # uint8 first made `* 255` wrap modulo 256, which collapsed
                # every non-zero class to nearly the same dark value. The
                # label summary above already scales before casting.
                segment = tf.cast(tf.argmax(segment, axis=-1), dtype=tf.int32)
                segment = tf.split(segment,
                                   num_or_size_splits=self.batch_size,
                                   axis=0)
                for i in range(shown):
                    scaled = segment[i] * 255 // self.num_classes
                    tf.summary.image(
                        "predict-{}-{}".format(segment_index, i),
                        tf.cast(tf.expand_dims(scaled, axis=-1),
                                dtype=tf.uint8))

        self.summary_op = tf.summary.merge_all()

        # Session and saver.
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                    max_to_keep=10)

        # Summaries, part 2: the writer.
        self.summary_writer = tf.summary.FileWriter(self.log_dir,
                                                    self.sess.graph)

    def cal_loss(self, segments, label_segment):
        """Compute the per-output and mean segmentation losses.

        For every output the labels are resized (nearest neighbour) to that
        output's spatial size, one-hot encoded with ``self.num_classes``
        channels and scored with
        ``tf.nn.weighted_cross_entropy_with_logits(pos_weight=1)``.

        Args:
            segments: Sequence of segment logit tensors.
            label_segment: Integer label tensor.

        Returns:
            ``(loss, loss_segment_all, loss_segments)``; ``loss`` and
            ``loss_segment_all`` are the same tensor, the mean of the list.
        """
        loss_segments = []
        for segment in segments:
            now_label_segment = tf.image.resize_nearest_neighbor(
                label_segment, tf.stack(segment.get_shape()[1:3]))
            targets = tf.one_hot(tf.reshape(now_label_segment, [-1]),
                                 depth=self.num_classes)
            now_loss_segment = tf.reduce_mean(
                tf.nn.weighted_cross_entropy_with_logits(
                    targets=targets,
                    logits=tf.reshape(segment, [-1, self.num_classes]),
                    pos_weight=1))
            loss_segments.append(now_loss_segment)

        loss_segment_all = tf.add_n(loss_segments) / len(loss_segments)
        # Total loss.
        loss = loss_segment_all
        return loss, loss_segment_all, loss_segments

    def train(self, save_pred_freq, begin_step=0):
        """Run the optimisation loop from ``begin_step`` to ``num_steps - 1``.

        A running average of the loss is kept per window of ``self.cal_step``
        steps and printed every ``self.cal_step // 10`` steps. Summaries are
        written every ``self.print_step`` steps.

        Args:
            save_pred_freq: A checkpoint is written whenever
                ``step % save_pred_freq == 0``.
            begin_step: First step index; it is also fed to the learning-rate
                schedule.
        """
        # Load the model (behaviour is defined by Tools.restore_if_y).
        Tools.restore_if_y(self.sess, self.log_dir, pretrain=self.pretrain)

        # self.train_segment_side_op can be substituted here to update only
        # the "segment_side" variables.
        train_op = self.train_op

        step_fetches = [
            train_op, self.learning_rate, self.loss_segment_all, self.loss
        ]
        summary_fetches = step_fetches + [self.summary_op]
        summary_index = len(step_fetches)

        total_loss = 0.0
        # Number of losses currently summed in total_loss. Counting them
        # explicitly keeps the averages right when begin_step is not a
        # multiple of cal_step; for aligned starts the values are unchanged.
        loss_count = 0
        pre_avg_loss = 0.0
        for step in range(begin_step, self.num_steps):
            start_time = time.time()

            batch_data, batch_segment = self.data_reader.next_batch_train()

            feed_dict = {
                self.step_ph: step,
                self.image_placeholder: batch_data,
                self.label_seg_placeholder: batch_segment
            }

            # Summaries, part 3: evaluated only on reporting steps.
            write_summary = step % self.print_step == 0
            results = self.sess.run(
                summary_fetches if write_summary else step_fetches,
                feed_dict=feed_dict)
            _, learning_rate_r, loss_segment_r, loss_r = \
                results[:summary_index]
            if write_summary:
                self.summary_writer.add_summary(results[summary_index],
                                                global_step=step)

            if step % save_pred_freq == 0:
                self.saver.save(self.sess,
                                self.checkpoint_path,
                                global_step=step)
                Tools.print_info('The checkpoint has been created.')

            duration = time.time() - start_time

            if step % self.cal_step == 0:
                # A new window starts: close the previous one first.
                pre_avg_loss = total_loss / loss_count if loss_count else 0.0
                total_loss = loss_r
                loss_count = 1
            else:
                total_loss += loss_r
                loss_count += 1

            if step % (self.cal_step // 10) == 0:
                Tools.print_info(
                    'step {:d} pre_avg_loss={:.3f} avg_loss={:.3f} loss={:.3f} seg={:.3f} lr={:.6f} '
                    '({:.3f} s/step)'.format(step, pre_avg_loss,
                                             total_loss / loss_count,
                                             loss_r, loss_segment_r,
                                             learning_rate_r, duration))
            if write_summary:
                Tools.print_info("")
コード例 #4
0
class Train(object):
    """Build and train the PSPNet-based segmentation + classification graph.

    The constructor wires up the data reader, the network, the TensorBoard
    summaries, the session and the checkpoint saver; ``train`` runs the
    optimisation loop.
    """

    def __init__(self, batch_size, last_pool_size, input_size, log_dir,
                 data_root_path, train_list, data_path, annotation_path, class_path,
                 model_name="model.ckpt", is_test=False):
        """Create the graph, summaries, session, saver and summary writer.

        Args:
            batch_size: Batch size forwarded to the ``Data`` reader.
            last_pool_size: Forwarded to ``PSPNet``. ``input_size`` must be
                larger than 8 times this value (constraint from the original
                author's note).
            input_size: Two-element sequence giving the spatial dimensions of
                the image placeholder.
            log_dir: Directory for checkpoints and summaries; it is passed
                through ``Tools.new_dir``.
            data_root_path: Forwarded unchanged to ``Data``.
            train_list: Forwarded to ``Data`` as ``data_list``.
            data_path: Forwarded unchanged to ``Data``.
            annotation_path: Forwarded unchanged to ``Data``.
            class_path: Forwarded unchanged to ``Data``.
            model_name: Checkpoint file name inside ``log_dir``.
            is_test: Forwarded to ``Data``; also makes summaries be written
                every step instead of every 25 steps.
        """
        # Parameters related to saving the model.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Parameters related to the data.
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21
        self.has_255 = True  # whether the boundary is predicted as well
        self.num_segment = 4 if self.has_255 else 3

        # Parameters related to the model: input_size must be larger than
        # 8 times last_pool_size.
        self.ratio = 8
        self.last_pool_size = last_pool_size
        self.filter_number = 32

        # Parameters related to training.
        # NOTE: ``learning_rate`` holds the base rate only until ``build_net``
        # returns; afterwards it is rebound to the decayed learning-rate tensor.
        self.learning_rate = 5e-3
        self.num_steps = 500001
        self.print_step = 1 if is_test else 25

        # Data reader.
        self.data_reader = Data(data_root_path=data_root_path, data_list=train_list,
                                data_path=data_path, annotation_path=annotation_path, class_path=class_path,
                                batch_size=self.batch_size, image_size=self.input_size,
                                is_test=is_test, has_255=self.has_255)
        # Network.
        (self.image_placeholder, self.label_segment_placeholder, self.label_classes_placeholder,
         self.raw_output_segment, self.raw_output_classes, self.pred_segment, self.pred_classes,
         self.loss_segment, self.loss_classes, self.loss, self.accuracy_segment, self.accuracy_classes,
         self.step_ph, self.train_op, self.train_classes_op, self.learning_rate) = self.build_net()

        # Summary 1: scalars.
        tf.summary.scalar("loss", self.loss)
        tf.summary.scalar("loss_segment", self.loss_segment)
        tf.summary.scalar("loss_classes", self.loss_classes)
        tf.summary.scalar("accuracy_segment", self.accuracy_segment)
        tf.summary.scalar("accuracy_classes", self.accuracy_classes)

        # Summary 1: images. The 4-channel input is shown as its last channel
        # ("0-mask") and its first three channels ("1-image").
        label_scale = 85 if self.has_255 else 127  # spreads label ids over the uint8 range
        input_channels = tf.split(self.image_placeholder, num_or_size_splits=4, axis=3)
        tf.summary.image("0-mask", input_channels[3])
        tf.summary.image("1-image", tf.concat(input_channels[0: 3], axis=3))
        tf.summary.image("2-label", tf.cast(self.label_segment_placeholder * label_scale, dtype=tf.uint8))
        segment_channels = tf.split(self.raw_output_segment, num_or_size_splits=self.num_segment, axis=3)
        tf.summary.image("3-attention", segment_channels[1])
        tf.summary.image("4-other class", segment_channels[0])
        tf.summary.image("5-background", segment_channels[-1])
        if self.has_255:
            tf.summary.image("5-border", segment_channels[2])
        tf.summary.image("6-pred_segment", tf.cast(self.pred_segment * label_scale, dtype=tf.uint8))

        self.summary_op = tf.summary.merge_all()

        # Session and saver.
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

        # Summary 2: writer.
        self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)

    def build_net(self):
        """Build placeholders, network, losses, metrics and train ops.

        Returns:
            A 16-tuple, in this order: image placeholder, segment-label
            placeholder, class-label placeholder, raw segment output, raw class
            output, predicted segment map, predicted classes, segment loss,
            class loss, total loss, segment accuracy, class accuracy, step
            placeholder, train op over all variables, train op restricted to
            the ``class_attention`` variables, decayed learning-rate tensor.
        """
        # Inputs. Segment labels are fed at 1/ratio of the input resolution.
        image_placeholder = tf.placeholder(dtype=tf.float32, shape=(None, self.input_size[0], self.input_size[1], 4))
        label_segment_placeholder = tf.placeholder(dtype=tf.int32, shape=(None, self.input_size[0] // self.ratio,
                                                                          self.input_size[1] // self.ratio, 1))
        label_classes_placeholder = tf.placeholder(dtype=tf.int32, shape=(None,))

        # Network.
        net = PSPNet({'data': image_placeholder}, is_training=True, num_classes=self.num_classes,
                     num_segment=self.num_segment, last_pool_size=self.last_pool_size, filter_number=self.filter_number)
        raw_output_segment = net.layers['conv6_n_4']
        raw_output_classes = net.layers['class_attention_fc']

        # Predictions: flattened logits for the loss, argmax maps for metrics.
        prediction = tf.reshape(raw_output_segment, [-1, self.num_segment])
        pred_segment = tf.cast(tf.expand_dims(tf.argmax(raw_output_segment, axis=-1), axis=-1), tf.int32)
        pred_classes = tf.cast(tf.argmax(raw_output_classes, axis=-1), tf.int32)

        # Labels: resized to the spatial size of the segment output, then flattened.
        label_batch = tf.image.resize_nearest_neighbor(label_segment_placeholder,
                                                       tf.stack(raw_output_segment.get_shape()[1:3]))
        label_batch = tf.reshape(label_batch, [-1, ])

        # Accuracy of the current batch.
        accuracy_segment = tcm.accuracy(pred_segment, label_segment_placeholder)
        accuracy_classes = tcm.accuracy(pred_classes, label_classes_placeholder)

        # Segmentation loss.
        loss_segment = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label_batch,
                                                                                     logits=prediction))

        # Classification loss.
        loss_classes = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label_classes_placeholder,
                                                                                     logits=raw_output_classes))
        # Total loss: the classification term is weighted by 0.1.
        loss = tf.add_n([loss_segment, 0.1 * loss_classes])

        # Learning-rate schedule: base_rate * (1 - step / num_steps) ** 0.9.
        step_ph = tf.placeholder(dtype=tf.float32, shape=())
        learning_rate = tf.scalar_mul(tf.constant(self.learning_rate), tf.pow((1 - step_ph / self.num_steps), 0.9))
        train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

        # Separate op that only updates the final classification variables.
        classes_trainable = [v for v in tf.trainable_variables() if 'class_attention' in v.name]
        train_classes_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, var_list=classes_trainable)

        return (image_placeholder, label_segment_placeholder, label_classes_placeholder,
                raw_output_segment, raw_output_classes, pred_segment, pred_classes,
                loss_segment, loss_classes, loss, accuracy_segment, accuracy_classes,
                step_ph, train_op, train_classes_op, learning_rate)

    def train(self, save_pred_freq, begin_step=0):
        """Run the training loop from ``begin_step`` up to ``self.num_steps``.

        Args:
            save_pred_freq: A checkpoint is written whenever the step index is
                a multiple of this value.
            begin_step: First step index; it also feeds the learning-rate decay.
        """
        # Load an existing model if there is one.
        Tools.restore_if_y(self.sess, self.log_dir)

        # Use self.train_classes_op instead to update only the classification variables.
        train_op = self.train_op

        # Only fetch what is actually reported; the raw network outputs are not
        # needed on the host and would be copied back on every step.
        base_fetches = [train_op, self.accuracy_segment, self.accuracy_classes, self.learning_rate,
                        self.loss_segment, self.loss_classes, self.loss, self.pred_classes]
        num_base = len(base_fetches)

        for step in range(begin_step, self.num_steps):
            start_time = time.time()

            # The reader yields five arrays; the last two are not used here.
            final_batch_data, final_batch_ann, final_batch_class, _, _ = \
                self.data_reader.next_batch_train()

            feed_dict = {self.step_ph: step, self.image_placeholder: final_batch_data,
                         self.label_segment_placeholder: final_batch_ann,
                         self.label_classes_placeholder: final_batch_class}

            # Summary 3: the merged summary is evaluated only on print steps.
            write_summary = step % self.print_step == 0
            fetches = (base_fetches + [self.summary_op]) if write_summary else base_fetches
            results = self.sess.run(fetches, feed_dict=feed_dict)

            (_, accuracy_segment_r, accuracy_classes_r, learning_rate_r,
             loss_segment_r, loss_classes_r, loss_r, pred_classes_r) = results[:num_base]
            if write_summary:
                self.summary_writer.add_summary(results[num_base], global_step=step)

            if step % save_pred_freq == 0:
                self.saver.save(self.sess, self.checkpoint_path, global_step=step)
                Tools.print_info('The checkpoint has been created.')

            duration = time.time() - start_time

            Tools.print_info(
                'step {:d} loss={:.3f} seg={:.3f} class={:.3f} acc={:.3f} acc_class={:.3f}'
                ' lr={:.6f} ({:.3f} s/step) {} {}'.format(
                    step, loss_r, loss_segment_r, loss_classes_r, accuracy_segment_r, accuracy_classes_r,
                    learning_rate_r, duration, list(final_batch_class), list(pred_classes_r)))
コード例 #5
0
class Train(object):
    """Build and train the BAISNet attention + classification graph.

    The constructor creates the data reader, placeholders, network, losses,
    train ops, TensorBoard summaries, session and saver; ``train`` runs the
    optimisation loop and reports windowed averages of loss and accuracy.
    """

    def __init__(self, batch_size, input_size, log_dir, data_root_path, train_list, data_path,
                 annotation_path, class_path, model_name="model.ckpt", pretrain=None, is_test=False):
        """Create the graph, summaries, session, saver and summary writer.

        Args:
            batch_size: Batch size forwarded to the ``Data`` reader.
            input_size: Two-element sequence giving the spatial dimensions of
                the image, mask and segment-label placeholders.
            log_dir: Directory for checkpoints and summaries; it is passed
                through ``Tools.new_dir``.
            data_root_path: Forwarded unchanged to ``Data``.
            train_list: Forwarded to ``Data`` as ``data_list``.
            data_path: Forwarded unchanged to ``Data``.
            annotation_path: Forwarded unchanged to ``Data``.
            class_path: Forwarded unchanged to ``Data``.
            model_name: Checkpoint file name inside ``log_dir``.
            pretrain: Forwarded to ``Tools.restore_if_y`` when training starts.
            is_test: Forwarded to ``Data``; also shortens the summary interval
                (10 instead of 100 steps) and the averaging window (100
                instead of 1000 steps).
        """
        # Parameters related to saving the model.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)
        self.pretrain = pretrain

        # Parameters related to the data.
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21

        # Parameters related to training.
        # NOTE: ``learning_rate`` holds the base rate only until the "train"
        # scope below, where it is rebound to the decayed learning-rate tensor.
        self.learning_rate = 5e-4
        self.num_steps = 500001
        self.print_step = 10 if is_test else 100
        self.cal_step = 100 if is_test else 1000

        # Data reader.
        self.data_reader = Data(data_root_path=data_root_path, data_list=train_list,
                                data_path=data_path, annotation_path=annotation_path, class_path=class_path,
                                batch_size=self.batch_size, image_size=self.input_size, is_test=is_test)

        # Inputs.
        self.image_placeholder = tf.placeholder(tf.float32, shape=(None, self.input_size[0], self.input_size[1], 3))
        self.mask_placeholder = tf.placeholder(tf.float32, shape=(None, self.input_size[0], self.input_size[1], 1))
        self.label_seg_placeholder = tf.placeholder(tf.int32, shape=(None, self.input_size[0], self.input_size[1], 1))
        self.label_cls_placeholder = tf.placeholder(tf.int32, shape=(None,))

        # Network.
        self.net = BAISNet(self.image_placeholder, self.mask_placeholder, True, num_classes=self.num_classes)
        self.segments, self.attentions, self.classes = self.net.build()

        # Losses. The attribute names say "segment", but ``cal_loss`` currently
        # returns the attention losses in those positions.
        self.loss, self.loss_segment_all, self.loss_class_all, self.loss_segments, self.loss_classes = self.cal_loss(
            self.segments, self.attentions, self.classes, self.label_seg_placeholder, self.label_cls_placeholder)

        # Accuracy of the current batch, measured on the first class output.
        self.pred_classes = tf.cast(tf.argmax(self.classes[0], axis=-1), tf.int32)
        self.accuracy_classes = tcm.accuracy(self.pred_classes, self.label_cls_placeholder)

        with tf.name_scope("train"):
            # Learning-rate schedule: base_rate * (1 - step / num_steps) ** 0.9.
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            self.learning_rate = tf.scalar_mul(tf.constant(self.learning_rate),
                                               tf.pow((1 - self.step_ph / self.num_steps), 0.9))
            self.train_op = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss)

            # Separate op that only updates the attention variables. Names
            # containing "class_attention" also contain "attention", so a
            # single substring test selects the same variables as before.
            attention_trainable = [v for v in tf.trainable_variables() if 'attention' in v.name]
            print(len(attention_trainable))
            self.train_attention_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss, var_list=attention_trainable)

        # Summary 1.
        with tf.name_scope("loss"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.scalar("loss_segment", self.loss_segment_all)
            tf.summary.scalar("loss_class", self.loss_class_all)
            for loss_segment_index, loss_segment in enumerate(self.loss_segments):
                tf.summary.scalar("loss_segment_{}".format(loss_segment_index), loss_segment)
            for loss_class_index, loss_class in enumerate(self.loss_classes):
                tf.summary.scalar("loss_class_{}".format(loss_class_index), loss_class)

        with tf.name_scope("accuracy"):
            tf.summary.scalar("accuracy_classes", self.accuracy_classes)

        with tf.name_scope("label"):
            tf.summary.image("1-image", self.image_placeholder)
            tf.summary.image("2-segment", tf.cast(self.label_seg_placeholder * 255, dtype=tf.uint8))
            tf.summary.image("3-mask", tf.cast(self.mask_placeholder * 255, dtype=tf.uint8))

        with tf.name_scope("result"):
            # Each output is split into its two channels; only channel 1 is
            # visualised (through a sigmoid).
            for segment_index, segment in enumerate(self.segments):
                segment = tf.split(segment, num_or_size_splits=2, axis=3)
                tf.summary.image("predict-{}-1".format(segment_index), tf.nn.sigmoid(segment[1]))
            for attention_index, attention in enumerate(self.attentions):
                attention = tf.split(attention, num_or_size_splits=2, axis=3)
                tf.summary.image("attention-{}-1".format(attention_index), tf.nn.sigmoid(attention[1]))

        self.summary_op = tf.summary.merge_all()

        # Session and saver.
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

        # Summary 2: writer.
        self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)

    @staticmethod
    def cal_loss(segments, attentions, classes, label_segment, label_classes):
        """Build the attention and classification losses.

        Args:
            segments: Segment outputs of the network. Currently unused (the
                per-segment loss is disabled); kept so the signature stays
                stable for callers.
            attentions: Iterable of 2-channel attention logit tensors.
            classes: Iterable of class logit tensors.
            label_segment: Integer label map; it is resized with nearest
                neighbour to each attention output's spatial size.
            label_classes: Integer class labels.

        Returns:
            Tuple ``(loss, loss_attention_all, loss_class_all, loss_attentions,
            loss_classes)`` where ``loss`` is the sum of the mean attention
            loss and the mean class loss.
        """
        loss_attentions = []
        for attention in attentions:
            now_label_segment = tf.image.resize_nearest_neighbor(label_segment, tf.stack(attention.get_shape()[1:3]))

            # Weighted sigmoid cross-entropy against one-hot labels; positive
            # targets are weighted by 3.
            now_loss_attention = tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(
                targets=tf.one_hot(tf.reshape(now_label_segment, [-1, ]), depth=2),
                logits=tf.reshape(attention, [-1, 2]), pos_weight=3))

            loss_attentions.append(now_loss_attention)

        loss_classes = [
            tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label_classes, logits=class_one))
            for class_one in classes
        ]

        loss_attention_all = tf.add_n(loss_attentions) / len(loss_attentions)
        loss_class_all = tf.add_n(loss_classes) / len(loss_classes)

        # Total loss.
        loss = loss_attention_all + loss_class_all
        return loss, loss_attention_all, loss_class_all, loss_attentions, loss_classes

    def train(self, save_pred_freq, begin_step=0):
        """Run the training loop from ``begin_step`` up to ``self.num_steps``.

        Loss and accuracy are averaged over windows that end at multiples of
        ``self.cal_step``. The number of steps actually accumulated is tracked
        explicitly, so the averages stay correct when ``begin_step`` is not a
        multiple of ``self.cal_step`` (e.g. when resuming a run).

        Args:
            save_pred_freq: A checkpoint is written whenever the step index is
                a multiple of this value.
            begin_step: First step index; it also feeds the learning-rate decay.
        """
        # Load an existing / pretrained model if there is one.
        Tools.restore_if_y(self.sess, self.log_dir, pretrain=self.pretrain)

        # Use self.train_attention_op instead to update only the attention variables.
        train_op = self.train_op
        base_fetches = [train_op, self.accuracy_classes, self.learning_rate,
                        self.loss_segment_all, self.loss_class_all, self.loss, self.pred_classes]
        num_base = len(base_fetches)

        # Progress is printed ten times per averaging window; the guard avoids
        # a modulo by zero when cal_step is smaller than 10.
        report_every = max(1, self.cal_step // 10)

        total_loss = 0.0
        total_acc = 0.0
        window_steps = 0  # steps accumulated in the current window
        pre_avg_loss = 0.0
        pre_avg_acc = 0.0
        for step in range(begin_step, self.num_steps):
            start_time = time.time()

            batch_data, batch_mask, batch_attention, batch_class = self.data_reader.next_batch_train()
            feed_dict = {self.step_ph: step, self.image_placeholder: batch_data,
                         self.mask_placeholder: batch_mask,
                         self.label_seg_placeholder: batch_attention,
                         self.label_cls_placeholder: batch_class}

            # Summary 3: the merged summary is evaluated only on print steps.
            write_summary = step % self.print_step == 0
            fetches = (base_fetches + [self.summary_op]) if write_summary else base_fetches
            results = self.sess.run(fetches, feed_dict=feed_dict)

            (_, accuracy_classes_r, learning_rate_r,
             loss_segment_r, loss_classes_r, loss_r, pred_classes_r) = results[:num_base]
            if write_summary:
                self.summary_writer.add_summary(results[num_base], global_step=step)

            if step % save_pred_freq == 0:
                self.saver.save(self.sess, self.checkpoint_path, global_step=step)
                Tools.print_info('The checkpoint has been created.')

            duration = time.time() - start_time

            if step % self.cal_step == 0:
                # Close the previous window (if any) and start a new one.
                if window_steps > 0:
                    pre_avg_loss = total_loss / window_steps
                    pre_avg_acc = total_acc / window_steps
                total_loss = 0.0
                total_acc = 0.0
                window_steps = 0

            total_loss += loss_r
            total_acc += accuracy_classes_r
            window_steps += 1

            if step % report_every == 0:
                Tools.print_info(
                    'step {:d} pre_avg_loss={:.3f} avg_loss={:.3f} pre_avg_acc={:.3f} avg_acc={:.3f} '
                    'loss={:.3f} seg={:.3f} class={:.3f} acc_class={:.3f} '
                    'lr={:.6f} ({:.3f} s/step) {} {}'.format(
                        step, pre_avg_loss, total_loss / window_steps, pre_avg_acc, total_acc / window_steps,
                        loss_r, loss_segment_r, loss_classes_r, accuracy_classes_r,
                        learning_rate_r, duration, list(batch_class), list(pred_classes_r)))
            if step % self.print_step == 0:
                Tools.print_info("")