コード例 #1
0
    def __init__(self, log_dir, save_dir):
        """Prepare output directories and the fixed network input geometry.

        Args:
            log_dir: directory passed through ``Tools.new_dir`` (presumably
                created if missing) and kept as ``self.log_dir``.
            save_dir: directory passed through ``Tools.new_dir`` and kept as
                ``self.save_dir``.
        """
        # save_dir is handled before log_dir, as before.
        self.save_dir = Tools.new_dir(save_dir)
        self.log_dir = Tools.new_dir(log_dir)

        # The square input side is 8x the last pooling size (50 * 8 = 400).
        pool_size = 50
        self.last_pool_size = pool_size
        side = pool_size * 8
        self.input_size = [side, side]
コード例 #2
0
    def build(self):
        """Build the shared encoder and the four-stage attention decoder.

        Stages run deepest first (4, 3, 2, 1). Every stage except the deepest
        adds an encoder skip block to the previous stage's decoder output
        before decoding it.

        Returns:
            (segments, features):
              segments -- the ``self._segment`` side output of stages 4, 3, 2, 1,
                          followed by the final ``num_classes``-channel conv output.
              features -- dict with "block" (the four encoder blocks) and
                          "segment" (the decoder output of every stage).
        """
        # Shared feature extractor (common to all heads).
        block1, block2, block3, block4 = self._feature(self.input_data)
        blocks = [block1, block2, block3, block4]
        # Trailing notes are the original per-block shape hints (spatial size,
        # channels); they presumably depend on the input size.
        block4_shape = Tools.get_shape(block4)  # 45, 512
        block3_shape = Tools.get_shape(block3)  # 90, 512
        block2_shape = Tools.get_shape(block2)  # 180, 256
        block1_shape = Tools.get_shape(block1)  # 360, 128

        # One row per stage:
        # (variable scope, decoder name, side-output name, skip block, add-op name,
        #  input shape, output shape). The deepest stage has no skip connection.
        stages = (
            ("attention_4", "4", "segment_side_4", None, None,
             block4_shape, block3_shape),
            ("attention_3", "3", "segment_side_3", block3, "attention_3_add",
             block3_shape, block2_shape),
            ("attention_2", "2", "segment_side_2", block2, "attention_2_net_output_relu",
             block2_shape, block1_shape),
            ("attention_1", "1", "segment_side_1", block1, "attention_1_concat",
             block1_shape, block1_shape),
        )

        segments = []
        segments_output = []
        net_output = None
        for scope, decoder_name, side_name, skip, add_name, in_shape, out_shape in stages:
            if skip is None:
                stage_input = block4
            else:
                stage_input = Net.add([skip, net_output], name=add_name)
            with tf.variable_scope(name_or_scope=scope):
                net_output = self._decoder(stage_input, in_shape, out_shape, name=decoder_name)
                segments.append(self._segment(stage_input, in_shape, out_shape, name=side_name))
            segments_output.append(net_output)

        # Final classifier conv (args presumably: 3x3 kernel, num_classes outputs,
        # stride 1), without ReLU.
        net_output = Net.conv(net_output, 3, 3, self.num_classes, 1, 1,
                              biased=True, relu=False, name='attention_0_21')
        segments.append(net_output)

        features = {"block": blocks, "segment": segments_output}
        return segments, features
コード例 #3
0
 def inference(self, image_path, image_index, save_path=None):
     """Segment a single image, log its summary, then show or save the mask.

     Args:
         image_path: image file handed to ``Data.load_data``.
         image_index: global step used for the TensorBoard summary.
         save_path: when None the mask is displayed with ``Image.show``;
             otherwise it is saved as ``<save_path>/<image stem>.bmp``.
     """
     single = Data.load_data(image_path=image_path, input_size=self.input_size)
     batch = np.expand_dims(single, axis=0)  # leading batch axis of size 1
     prediction, summary = self.sess.run([self.pred_segment, self.summary_op],
                                         feed_dict={self.image_placeholder: batch})
     self.summary_writer.add_summary(summary, global_step=image_index)

     # NOTE(review): if pred_segment is uint8 with class ids above 1, the
     # `* 255` below wraps around -- confirm this scaling is intended.
     pixels = np.asarray(np.squeeze(prediction) * 255, dtype=np.uint8)
     mask_image = Image.fromarray(pixels)
     if save_path is not None:
         Tools.new_dir(save_path)
         stem = os.path.splitext(os.path.basename(image_path))[0]
         mask_image.convert("L").save("{}/{}.bmp".format(save_path, stem))
     else:
         mask_image.show()
コード例 #4
0
    def __init__(self, input_size, log_dir, model_name="model.ckpt"):
        """Build a BAISNet prediction graph and open a session for it.

        Args:
            input_size: (height, width) of the input images.
            log_dir: checkpoint directory, passed through ``Tools.new_dir``.
            model_name: checkpoint file name inside ``log_dir``.
        """
        # Checkpoint location.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Data settings.
        self.input_size = input_size
        self.num_classes = 21

        # Input: batch of 3-channel float images of size input_size.
        height, width = self.input_size[0], self.input_size[1]
        self.image_placeholder = tf.placeholder(tf.float32, shape=(None, height, width, 3))

        # Network; the positional False is presumably BAISNet's training flag.
        self.net = BAISNet(self.image_placeholder, False, num_classes=self.num_classes)
        self.segments = self.net.build()
        # Per-pixel argmax of the first segment output, as uint8.
        first_logits = self.segments[0]
        self.pred_segment = tf.cast(tf.argmax(first_logits, axis=-1), dtype=tf.uint8)

        # Session (GPU memory grown on demand), variable init and saver.
        gpu_options = tf.GPUOptions(allow_growth=True)
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
コード例 #5
0
    def train(self, save_pred_freq, begin_step=0):
        """Train from ``begin_step`` up to ``self.num_steps``.

        A merged summary is written every ``print_step`` steps and a checkpoint
        every ``save_pred_freq`` steps. Losses are averaged over windows of
        ``cal_step`` steps; progress is logged every ``cal_step // 10`` steps
        (at least every step).

        Args:
            save_pred_freq: checkpoint interval in steps.
            begin_step: first step to run (e.g. when resuming).
        """
        # Restore a checkpoint / pretrained weights if available.
        Tools.restore_if_y(self.sess, self.log_dir, pretrain=self.pretrain)

        # Swap in self.train_attention_op here to train only the attention variables.
        train_op = self.train_op
        # max(1, ...) avoids `step % 0` (ZeroDivisionError) when cal_step < 10.
        log_interval = max(1, self.cal_step // 10)

        total_loss = 0.0    # sum of losses in the current cal_step window
        window_steps = 0    # number of losses actually accumulated in total_loss
        pre_avg_loss = 0.0  # average loss of the previous window
        for step in range(begin_step, self.num_steps):
            start_time = time.time()

            batch_data, batch_segment = self.data_reader.next_batch_train()
            feed_dict = {self.step_ph: step, self.image_placeholder: batch_data,
                         self.label_seg_placeholder: batch_segment}

            write_summary = step % self.print_step == 0
            fetches = [train_op, self.learning_rate, self.loss_segment_all, self.loss]
            if write_summary:
                fetches.append(self.summary_op)
            results = self.sess.run(fetches, feed_dict=feed_dict)
            _, learning_rate_r, loss_segment_r, loss_r = results[:4]
            if write_summary:
                self.summary_writer.add_summary(results[4], global_step=step)

            if step % save_pred_freq == 0:
                self.saver.save(self.sess, self.checkpoint_path, global_step=step)
                Tools.print_info('The checkpoint has been created.')

            duration = time.time() - start_time

            # Count the losses really accumulated instead of deriving the count
            # from `step % cal_step`, which is wrong when resuming from a
            # begin_step that is not a multiple of cal_step. For aligned starts
            # the results are identical to the previous formulas.
            if step % self.cal_step == 0:
                pre_avg_loss = total_loss / window_steps if window_steps else 0.0
                total_loss = loss_r
                window_steps = 1
            else:
                total_loss += loss_r
                window_steps += 1

            if step % log_interval == 0:
                Tools.print_info('step {:d} pre_avg_loss={:.3f} avg_loss={:.3f} loss={:.3f} seg={:.3f} lr={:.6f} '
                                 '({:.3f} s/step)'.format(step, pre_avg_loss, total_loss / window_steps,
                                                          loss_r, loss_segment_r, learning_rate_r, duration))
            if write_summary:
                Tools.print_info("")
コード例 #6
0
    def train(self, save_pred_freq, begin_step=0, end_step=5):
        """Evaluate the network on training batches and print class predictions.

        No training op is fetched here: each step only evaluates the network
        outputs, saves a checkpoint every ``save_pred_freq`` steps and prints
        the ground-truth classes next to the predicted ones.

        Args:
            save_pred_freq: checkpoint interval in steps.
            begin_step: first step (inclusive).
            end_step: last step (exclusive); defaults to the previously
                hard-coded bound of 5.
        """
        # Restore a checkpoint if available.
        Tools.restore_if_y(self.sess, self.log_dir)

        for step in range(begin_step, end_step):
            # Only the images and class labels are used by this loop.
            final_batch_data, _, final_batch_class, _, _ = \
                self.data_reader.next_batch_train()

            # All four outputs are still evaluated; only pred_classes is reported.
            (_raw_segment, _pred_segment, _raw_classes,
             pred_classes_r) = self.sess.run(
                 [
                     self.raw_output_segment, self.pred_segment,
                     self.raw_output_classes, self.pred_classes
                 ],
                 feed_dict={self.image_placeholder: final_batch_data})

            if step % save_pred_freq == 0:
                self.saver.save(self.sess,
                                self.checkpoint_path,
                                global_step=step)
                Tools.print_info('The checkpoint has been created.')

            Tools.print_info('step {:d} {} {}'.format(step,
                                                      list(final_batch_class),
                                                      list(pred_classes_r)))
コード例 #7
0
    def __init__(self,
                 input_size,
                 summary_dir,
                 log_dir,
                 model_name="model.ckpt"):
        """Build the feature extractor plus per-channel image summaries.

        Args:
            input_size: (height, width) of the input images.
            summary_dir: directory for the TensorBoard ``FileWriter``.
            log_dir: checkpoint directory, passed through ``Tools.new_dir``.
            model_name: checkpoint file name inside ``log_dir``.
        """
        # Checkpoint location.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Data settings.
        self.input_size = input_size
        self.num_classes = 21

        # Input: batch of 3-channel float images.
        image_shape = (None, self.input_size[0], self.input_size[1], 3)
        self.image_placeholder = tf.placeholder(tf.float32, shape=image_shape)

        # Feature maps returned by self._feature.
        self.features = self._feature(self.image_placeholder)

        with tf.name_scope("image"):
            tf.summary.image("input", self.image_placeholder)

        # One single-channel image summary per channel of every feature map
        # except the last; tags are "<map index>-<channel index>".
        with tf.name_scope("block"):
            for map_idx, feature_map in enumerate(self.features[:-1]):
                channel_count = int(feature_map.shape[-1])
                channels = tf.split(feature_map, num_or_size_splits=channel_count, axis=-1)
                for chan_idx, channel in enumerate(channels):
                    tf.summary.image("{}-{}".format(map_idx, chan_idx), channel)

        self.summary_op = tf.summary.merge_all()

        # Session, initialization, saver and summary writer.
        session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
        self.sess = tf.Session(config=session_config)
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
        self.summary_writer = tf.summary.FileWriter(summary_dir, self.sess.graph)
コード例 #8
0
    def train(self, save_pred_freq, begin_step=0):
        """Jointly train segmentation and classification from ``begin_step``.

        Each step runs ``train_op`` and fetches accuracies, losses, the
        learning rate and the predicted classes. A merged summary is written
        every ``print_step`` steps and a checkpoint every ``save_pred_freq``
        steps.

        Args:
            save_pred_freq: checkpoint interval in steps.
            begin_step: first step to run (e.g. when resuming).
        """
        # Restore a checkpoint if one exists.
        Tools.restore_if_y(self.sess, self.log_dir)

        # Swap in self.train_classes_op here to train only the classification part.
        train_op = self.train_op
        # Only reported values are fetched. raw_output_segment, pred_segment and
        # raw_output_classes used to be fetched as well but were never used,
        # which only cost device-to-host copies every step.
        fetches = [self.accuracy_segment, self.accuracy_classes,
                   train_op, self.learning_rate,
                   self.loss_segment, self.loss_classes, self.loss,
                   self.pred_classes]

        for step in range(begin_step, self.num_steps):
            start_time = time.time()

            # The last two reader outputs are not used by this loop.
            final_batch_data, final_batch_ann, final_batch_class, _, _ = \
                self.data_reader.next_batch_train()
            feed_dict = {self.step_ph: step, self.image_placeholder: final_batch_data,
                         self.label_segment_placeholder: final_batch_ann,
                         self.label_classes_placeholder: final_batch_class}

            write_summary = step % self.print_step == 0
            step_fetches = (fetches + [self.summary_op]) if write_summary else fetches
            results = self.sess.run(step_fetches, feed_dict=feed_dict)
            (accuracy_segment_r, accuracy_classes_r, _, learning_rate_r,
             loss_segment_r, loss_classes_r, loss_r, pred_classes_r) = results[:len(fetches)]
            if write_summary:
                self.summary_writer.add_summary(results[-1], global_step=step)

            if step % save_pred_freq == 0:
                self.saver.save(self.sess, self.checkpoint_path, global_step=step)
                Tools.print_info('The checkpoint has been created.')

            duration = time.time() - start_time

            Tools.print_info(
                'step {:d} loss={:.3f} seg={:.3f} class={:.3f} acc={:.3f} acc_class={:.3f}'
                ' lr={:.6f} ({:.3f} s/step) {} {}'.format(
                    step, loss_r, loss_segment_r, loss_classes_r, accuracy_segment_r, accuracy_classes_r,
                    learning_rate_r, duration, list(final_batch_class), list(pred_classes_r)))
コード例 #9
0
    def __init__(self,
                 batch_size,
                 last_pool_size,
                 input_size,
                 log_dir,
                 data_root_path,
                 train_list,
                 data_path,
                 annotation_path,
                 class_path,
                 model_name="model.ckpt",
                 is_test=False):
        """Set up the data reader, build the network and open a session.

        Args:
            batch_size: samples per batch, forwarded to ``Data``.
            last_pool_size: stored as ``self.last_pool_size``.
            input_size: (height, width) of the inputs, forwarded to ``Data``.
            log_dir: checkpoint directory, passed through ``Tools.new_dir``.
            data_root_path, train_list, data_path, annotation_path, class_path:
                forwarded to the ``Data`` reader.
            model_name: checkpoint file name inside ``log_dir``.
            is_test: forwarded to ``Data``.
        """
        # Checkpoint location.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Data settings.
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21
        self.num_segment = 1

        # Model settings: input_size must be larger than 8x last_pool_size.
        self.ratio = 8
        self.last_pool_size = last_pool_size
        self.filter_number = 32

        # Training-data reader.
        self.data_reader = Data(data_root_path=data_root_path,
                                data_list=train_list,
                                data_path=data_path,
                                annotation_path=annotation_path,
                                class_path=class_path,
                                batch_size=self.batch_size,
                                image_size=self.input_size,
                                is_test=is_test)

        # Network: build_net returns the placeholder followed by four output tensors.
        net_outputs = self.build_net()
        (self.image_placeholder, self.raw_output_segment, self.raw_output_classes,
         self.pred_segment, self.pred_classes) = net_outputs

        # Session (GPU memory grown on demand), variable init and saver.
        session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
        self.sess = tf.Session(config=session_config)
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
コード例 #10
0
    def build(self):
        """Build the encoder and four-stage decoder ending in a 2-channel conv.

        Returns:
            A one-element list holding the output of the final 'attention_0'
            conv layer (2 output channels).
        """
        # Shared feature extractor.
        block1, block2, block3, block4 = self._feature(self.input_data)
        # Trailing notes are the original shape hints (spatial size, channels).
        shape4 = Tools.get_shape(block4)  # 45, 512
        shape3 = Tools.get_shape(block3)  # 90, 512
        shape2 = Tools.get_shape(block2)  # 180, 256
        shape1 = Tools.get_shape(block1)  # 360, 128

        # Deepest stage: decode block4 on its own.
        with tf.variable_scope(name_or_scope="attention_4"):
            decoded = self._decoder(block4, shape4, shape3, name="4")

        # Remaining stages: add the encoder skip block to the previous decoder
        # output, then decode inside the stage's variable scope.
        skip_stages = (
            (block3, "attention_3_add", "attention_3", "3", shape3, shape2),
            (block2, "attention_2_net_output_relu", "attention_2", "2", shape2, shape1),
            (block1, "attention_1_concat", "attention_1", "1", shape1, shape1),
        )
        for skip, add_name, scope, decoder_name, in_shape, out_shape in skip_stages:
            fused = Net.add([skip, decoded], name=add_name)
            with tf.variable_scope(name_or_scope=scope):
                decoded = self._decoder(fused, in_shape, out_shape, name=decoder_name)

        logits = Net.conv(decoded, 3, 3, 2, 1, 1, biased=True, relu=False, name='attention_0')
        return [logits]
コード例 #11
0
    def load_net(self):
        """Build the PSPNet graph, start a session and restore its weights.

        Returns:
            (sess, img_placeholder, predict_output_op, pred_classes):
              img_placeholder   -- float32 input of shape (None, H, W, 4)
              predict_output_op -- per-pixel argmax over the channels of the
                                   'conv6_n_4' output, resized to input_size
              pred_classes      -- int32 argmax of 'class_attention_fc'
        """
        print("begin to build net and start session and load model ...")

        # Network.
        img_placeholder = tf.placeholder(dtype=tf.float32, shape=(None, self.input_size[0], self.input_size[1], 4))
        # NOTE(review): is_training=True is passed although this builds the net
        # for prediction -- confirm how PSPNet uses this flag.
        net = PSPNet({'data': img_placeholder}, is_training=True, num_classes=21,
                     last_pool_size=self.last_pool_size, filter_number=32, num_segment=4)

        # Outputs / predictions.
        raw_output_op = net.layers["conv6_n_4"]
        raw_output_op = tf.image.resize_bilinear(raw_output_op, size=self.input_size)
        # Argmax directly on the logits. Sigmoid is monotonic, so in exact math it
        # cannot change the argmax, but in float32 it saturates to exactly 1.0 for
        # large logits, turning distinct logits into ties that argmax resolves to
        # the lowest channel index.
        predict_output_op = tf.argmax(raw_output_op, axis=-1)

        raw_output_classes = net.layers['class_attention_fc']
        pred_classes = tf.cast(tf.argmax(raw_output_classes, axis=-1), tf.int32)

        # Start the session and restore the model if a checkpoint exists.
        sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
        sess.run(tf.global_variables_initializer())
        Tools.restore_if_y(sess, self.log_dir)

        print("end build net and start session and load model ...")

        return sess, img_placeholder, predict_output_op, pred_classes
コード例 #12
0
    def __init__(self, input_size, summary_dir, log_dir, model_name="model.ckpt"):
        """Build BAISNet with image summaries of its predictions and features.

        Args:
            input_size: (height, width) of the input images.
            summary_dir: directory for the TensorBoard ``FileWriter``.
            log_dir: checkpoint directory, passed through ``Tools.new_dir``.
            model_name: checkpoint file name inside ``log_dir``.
        """
        # Checkpoint location.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Data settings.
        self.input_size = input_size
        self.num_classes = 21

        # Input: batch of 3-channel float images.
        rows, cols = self.input_size[0], self.input_size[1]
        self.image_placeholder = tf.placeholder(tf.float32, shape=(None, rows, cols, 3))

        # Network; the positional False is presumably BAISNet's training flag.
        self.net = BAISNet(self.image_placeholder, False, num_classes=self.num_classes)
        self.segments, self.features = self.net.build()
        # NOTE(review): the prediction is taken from segments[0]; if build()
        # appends its final classifier output last, confirm index 0 is intended.
        self.pred_segment = tf.cast(tf.argmax(self.segments[0], axis=-1), dtype=tf.uint8)

        with tf.name_scope("image"):
            tf.summary.image("input", self.image_placeholder)
            # Argmax class map of every segment output, scaled by 255 for display.
            for idx, logits in enumerate(self.segments):
                class_map = tf.cast(tf.argmax(logits, axis=-1), dtype=tf.uint8)
                tf.summary.image("predict-{}".format(idx), tf.expand_dims(class_map * 255, axis=-1))

        # For every feature group (dict key): one image summary per channel of
        # each map, tagged "<map index>-<channel index>".
        for group in self.features:
            with tf.name_scope(group):
                for map_idx, feature_map in enumerate(self.features[group]):
                    channel_count = int(feature_map.shape[-1])
                    channels = tf.split(feature_map, num_or_size_splits=channel_count, axis=-1)
                    for chan_idx, channel in enumerate(channels):
                        tf.summary.image("{}-{}".format(map_idx, chan_idx), channel)

        self.summary_op = tf.summary.merge_all()

        # Session, initialization, saver and summary writer.
        gpu_options = tf.GPUOptions(allow_growth=True)
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
        self.summary_writer = tf.summary.FileWriter(summary_dir, self.sess.graph)
コード例 #13
0
    def __init__(self,
                 batch_size,
                 last_pool_size,
                 input_size,
                 log_dir,
                 data_root_path,
                 train_list,
                 data_path,
                 annotation_path,
                 class_path,
                 model_name="model.ckpt",
                 is_test=False):
        """Build the BAISNet training graph: inputs, losses, optimizers and summaries.

        Args:
            batch_size: samples per batch, forwarded to ``Data``.
            last_pool_size: forwarded to BAISNet as ``last_pool_size``.
            input_size: (height, width) of the input images.
            log_dir: checkpoint and summary directory, passed through
                ``Tools.new_dir``.
            data_root_path, train_list, data_path, annotation_path, class_path:
                forwarded to the ``Data`` reader.
            model_name: checkpoint file name inside ``log_dir``.
            is_test: forwarded to ``Data``; also lowers ``print_step`` to 5.
        """
        # Checkpoint-related settings.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Data-related settings.
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21
        self.num_segment = 4  # decoder channels: other objects, attention, border, background
        self.segment_attention = 1  # index of the attention channel when the decoder has 4 channels
        # Number of attention modules whose decoder has 2 channels (background, attention).
        # NOTE(review): the "sigmoid" summaries below split segments with index
        # < attention_module_num into num_segment channels -- confirm against
        # this description.
        self.attention_module_num = 2

        # Model settings: input_size must be larger than 8x last_pool_size.
        self.ratio = 8
        self.last_pool_size = last_pool_size
        self.filter_number = 32

        # Training settings (learning_rate is the base rate; it is replaced by
        # the decayed learning-rate tensor in the "train" scope below).
        self.learning_rate = 5e-3
        self.num_steps = 500001
        self.print_step = 5 if is_test else 25

        # Data reader.
        self.data_reader = Data(data_root_path=data_root_path,
                                data_list=train_list,
                                data_path=data_path,
                                annotation_path=annotation_path,
                                class_path=class_path,
                                batch_size=self.batch_size,
                                image_size=self.input_size,
                                is_test=is_test,
                                has_255=True)

        # Inputs: 4-channel images, segmentation and attention labels at
        # 1/ratio of the input resolution, and one class id per image.
        self.image_placeholder = tf.placeholder(dtype=tf.float32,
                                                shape=(None,
                                                       self.input_size[0],
                                                       self.input_size[1], 4))
        self.label_segment_placeholder = tf.placeholder(
            dtype=tf.int32,
            shape=(None, self.input_size[0] // self.ratio,
                   self.input_size[1] // self.ratio, 1))
        self.label_attention_placeholder = tf.placeholder(
            dtype=tf.int32,
            shape=(None, self.input_size[0] // self.ratio,
                   self.input_size[1] // self.ratio, 1))
        self.label_classes_placeholder = tf.placeholder(dtype=tf.int32,
                                                        shape=(None, ))

        # Network.
        self.net = BAISNet(self.image_placeholder,
                           is_training=True,
                           num_classes=self.num_classes,
                           num_segment=self.num_segment,
                           segment_attention=self.segment_attention,
                           last_pool_size=self.last_pool_size,
                           filter_number=self.filter_number,
                           attention_module_num=self.attention_module_num)

        self.segments, self.attentions, self.classes = self.net.build()
        self.final_segment_logit = self.segments[0]
        self.final_class_logit = self.classes[0]

        # Predictions.
        self.pred_segment = tf.cast(
            tf.expand_dims(tf.argmax(self.final_segment_logit, axis=-1),
                           axis=-1), tf.int32)
        self.pred_classes = tf.cast(tf.argmax(self.final_class_logit, axis=-1),
                                    tf.int32)

        # Loss: labels are resized (nearest neighbour) to the logits' spatial size.
        self.label_batch = tf.image.resize_nearest_neighbor(
            self.label_segment_placeholder,
            tf.stack(self.final_segment_logit.get_shape()[1:3]))
        self.label_attention_batch = tf.image.resize_nearest_neighbor(
            self.label_attention_placeholder,
            tf.stack(self.final_segment_logit.get_shape()[1:3]))
        self.loss, self.loss_segment_all, self.loss_class_all, self.loss_segments, self.loss_classes = self.cal_loss(
            self.segments,
            self.classes,
            self.label_batch,
            self.label_attention_batch,
            self.label_classes_placeholder,
            self.num_segment,
            attention_module_num=self.attention_module_num)

        # Batch accuracy. pred_segment has the logits' spatial size, so compare
        # it with the resized label_batch (as the loss does) rather than the raw
        # placeholder, whose size only matches when the logits are exactly
        # input_size / ratio. When the sizes already match, the nearest-neighbour
        # resize leaves the labels unchanged.
        # NOTE(review): the reader is built with has_255=True; pixels labelled
        # 255 are counted here as ordinary mismatches -- confirm intended.
        self.accuracy_segment = tcm.accuracy(self.pred_segment,
                                             self.label_batch)
        self.accuracy_classes = tcm.accuracy(self.pred_classes,
                                             self.label_classes_placeholder)

        with tf.name_scope("train"):
            # Learning-rate schedule: base * (1 - step / num_steps) ** 0.9.
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            self.learning_rate = tf.scalar_mul(
                tf.constant(self.learning_rate),
                tf.pow((1 - self.step_ph / self.num_steps), 0.9))
            self.train_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss)

            # Separate op that trains only the variables whose name contains "attention".
            attention_trainable = [
                v for v in tf.trainable_variables()
                if 'attention' in v.name or "class_attention" in v.name
            ]
            print(len(attention_trainable))
            self.train_attention_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss,
                                             var_list=attention_trainable)

        # Summaries: losses.
        with tf.name_scope("loss"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.scalar("loss_segment", self.loss_segment_all)
            tf.summary.scalar("loss_class", self.loss_class_all)
            for loss_segment_index, loss_segment in enumerate(
                    self.loss_segments):
                tf.summary.scalar("loss_segment_{}".format(loss_segment_index),
                                  loss_segment)
            for loss_class_index, loss_class in enumerate(self.loss_classes):
                tf.summary.scalar("loss_class_{}".format(loss_class_index),
                                  loss_class)

        with tf.name_scope("accuracy"):
            tf.summary.scalar("accuracy_segment", self.accuracy_segment)
            tf.summary.scalar("accuracy_classes", self.accuracy_classes)

        # Inputs and labels; the image's 4th channel is shown as "0-mask".
        with tf.name_scope("label"):
            split = tf.split(self.image_placeholder,
                             num_or_size_splits=4,
                             axis=3)
            tf.summary.image("0-mask", split[3])
            tf.summary.image("1-image", tf.concat(split[0:3], axis=3))
            tf.summary.image(
                "2-label",
                tf.cast(self.label_segment_placeholder * 85, dtype=tf.uint8))
            tf.summary.image(
                "3-attention",
                tf.cast(self.label_attention_placeholder * 255,
                        dtype=tf.uint8))

        with tf.name_scope("predict"):
            tf.summary.image("predict",
                             tf.cast(self.pred_segment * 85, dtype=tf.uint8))

        with tf.name_scope("attention"):
            for attention_index, attention in enumerate(self.attentions):
                tf.summary.image("{}-attention".format(attention_index),
                                 attention)

        # Per-channel views of every segment output.
        with tf.name_scope("sigmoid"):
            for segment_index, segment in enumerate(self.segments):
                if segment_index < self.attention_module_num:
                    split = tf.split(segment,
                                     num_or_size_splits=self.num_segment,
                                     axis=3)
                    tf.summary.image("{}-other".format(segment_index),
                                     split[0])
                    tf.summary.image("{}-attention".format(segment_index),
                                     split[1])
                    tf.summary.image("{}-border".format(segment_index),
                                     split[2])
                    tf.summary.image("{}-background".format(segment_index),
                                     split[-1])
                else:
                    split = tf.split(segment, num_or_size_splits=2, axis=3)
                    tf.summary.image("{}-background".format(segment_index),
                                     split[0])
                    tf.summary.image("{}-attention".format(segment_index),
                                     split[1])

        self.summary_op = tf.summary.merge_all()

        # Session and saver.
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                    max_to_keep=10)

        # Summary writer (writes into log_dir).
        self.summary_writer = tf.summary.FileWriter(self.log_dir,
                                                    self.sess.graph)
コード例 #14
0
    def run(self, result_filename, image_filename, where=None, annotation_filename=None, ann_index=0):
        """Segment one image with PSPNet and save inputs and outputs as images.

        Files are written into ``self.save_dir`` as ``result_filename`` plus a
        suffix: data.png, pred.png, pred_raw.png, pred_sigmoid.png, mask.bmp
        and ann.bmp.

        Args:
            result_filename: prefix for every saved file.
            image_filename: image passed to ``Data.load_image``.
            where, annotation_filename, ann_index: forwarded to ``Data.load_image``.
        """
        # Load the image (the annotation data itself is not used here).
        final_batch_data, data_raw, gaussian_mask, _, ann_mask = Data.load_image(
            image_filename, where=where, annotation_filename=annotation_filename,
            ann_index=ann_index, image_size=self.input_size)

        # Build in a private graph and close the session when done, so repeated
        # calls neither leak sessions nor add duplicate PSPNet copies (and
        # clashing variables) to the default graph.
        graph = tf.Graph()
        with graph.as_default():
            img_placeholder = tf.placeholder(dtype=tf.float32, shape=(None, self.input_size[0], self.input_size[1], 4))
            net = PSPNet({'data': img_placeholder}, is_training=True, num_classes=1,
                         last_pool_size=self.last_pool_size, filter_number=32)

            # Outputs / predictions.
            raw_output_op = net.layers["conv6_n"]
            sigmoid_output_op = tf.sigmoid(raw_output_op)

            config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
            with tf.Session(config=config, graph=graph) as sess:
                sess.run(tf.global_variables_initializer())
                Tools.restore_if_y(sess, self.log_dir)
                raw_output, sigmoid_output = sess.run([raw_output_op, sigmoid_output_op],
                                                      feed_dict={img_placeholder: final_batch_data})

        log_message = 'over : result save in {}'.format(os.path.join(self.save_dir, result_filename))

        def save(array, suffix):
            # Write `array` as an 8-bit image and log; called in the original order.
            Image.fromarray(np.asarray(array, dtype=np.uint8)).save(
                os.path.join(self.save_dir, result_filename + suffix))
            Tools.print_info(log_message)

        save(np.squeeze(data_raw), "data.png")
        save(np.squeeze(sigmoid_output[0] * 255), "pred.png")
        # NOTE(review): the raw logits are thresholded at 0.5 here (sigmoid > 0.5
        # corresponds to logit > 0) -- confirm this comparison is intended.
        save(np.squeeze(np.greater(raw_output[0], 0.5) * 255), "pred_raw.png")
        save(np.squeeze(np.greater(sigmoid_output[0], 0.5) * 255), "pred_sigmoid.png")
        save(np.squeeze(gaussian_mask * 255), "mask.bmp")
        save(np.squeeze(ann_mask * 255), "ann.bmp")
コード例 #15
0
    def train(self, save_pred_freq, begin_step=0):
        """Run the training loop from ``begin_step`` up to ``self.num_steps``.

        Weights are first restored via ``Tools.restore_if_y`` (using
        ``self.pretrain``). Each step fetches a batch from
        ``self.data_reader.next_batch_train()``, runs ``self.train_op`` and
        periodically writes summaries, checkpoints and a progress line.

        Args:
            save_pred_freq: a checkpoint is saved whenever ``step`` is a
                multiple of this value.
            begin_step: index of the first step; it is fed to ``self.step_ph``
                (presumably the learning-rate schedule's step — confirm in the
                graph construction code).
        """
        # Load model weights.
        Tools.restore_if_y(self.sess, self.log_dir, pretrain=self.pretrain)

        # Running sums over the current window of `self.cal_step` steps.
        # `window_count` is the number of steps actually summed, so averages
        # stay correct when training resumes in the middle of a window
        # (i.e. begin_step is not a multiple of cal_step).
        total_loss = 0.0
        total_acc = 0.0
        window_count = 0
        pre_avg_loss = 0.0
        pre_avg_acc = 0.0

        # `cal_step // 10` is 0 for cal_step < 10, which would make `%` fail.
        log_interval = max(self.cal_step // 10, 1)

        # To train only the attention layers, use self.train_attention_op here.
        train_op = self.train_op
        fetches = [self.accuracy_classes, train_op, self.learning_rate,
                   self.loss_segment_all, self.loss_class_all, self.loss,
                   self.pred_classes]

        for step in range(begin_step, self.num_steps):
            start_time = time.time()

            batch_data, batch_mask, batch_attention, batch_class = self.data_reader.next_batch_train()
            feed_dict = {self.step_ph: step, self.image_placeholder: batch_data,
                         self.mask_placeholder: batch_mask,
                         self.label_seg_placeholder: batch_attention,
                         self.label_cls_placeholder: batch_class}

            # The summary op is only evaluated on print steps.
            write_summary = step % self.print_step == 0
            step_fetches = (fetches + [self.summary_op]) if write_summary else fetches
            results = self.sess.run(step_fetches, feed_dict=feed_dict)
            (accuracy_classes_r, _, learning_rate_r,
             loss_segment_r, loss_classes_r, loss_r, pred_classes_r) = results[:len(fetches)]
            if write_summary:
                self.summary_writer.add_summary(results[-1], global_step=step)

            if step % save_pred_freq == 0:
                self.saver.save(self.sess, self.checkpoint_path, global_step=step)
                Tools.print_info('The checkpoint has been created.')

            duration = time.time() - start_time

            if step % self.cal_step == 0:
                # Close the previous window (if any) and start a new one.
                if window_count:
                    pre_avg_loss = total_loss / window_count
                    pre_avg_acc = total_acc / window_count
                total_loss = loss_r
                total_acc = accuracy_classes_r
                window_count = 1
            else:
                total_loss += loss_r
                total_acc += accuracy_classes_r
                window_count += 1

            if step % log_interval == 0:
                Tools.print_info(
                    'step {:d} pre_avg_loss={:.3f} avg_loss={:.3f} pre_avg_acc={:.3f} avg_acc={:.3f} '
                    'loss={:.3f} seg={:.3f} class={:.3f} acc_class={:.3f} '
                    'lr={:.6f} ({:.3f} s/step) {} {}'.format(
                        step, pre_avg_loss, total_loss / window_count, pre_avg_acc, total_acc / window_count,
                        loss_r, loss_segment_r, loss_classes_r, accuracy_classes_r,
                        learning_rate_r, duration, list(batch_class), list(pred_classes_r)))
            if step % self.print_step == 0:
                Tools.print_info("")
コード例 #16
0
    def __init__(self, batch_size, last_pool_size, input_size, log_dir,
                 data_root_path, train_list, data_path, annotation_path, class_path,
                 model_name="model.ckpt", is_test=False):
        """Create the data reader, training graph, summaries, session and saver.

        Args:
            batch_size: samples per training batch.
            last_pool_size: pool size of the last pooling stage; the input size
                must be larger than ``self.ratio`` (8) times this value.
            input_size: image size passed to the data reader and network.
            log_dir: checkpoint / TensorBoard directory (created via
                ``Tools.new_dir``).
            data_root_path, train_list, data_path, annotation_path, class_path:
                forwarded to ``Data``.
            model_name: checkpoint file name inside ``log_dir``.
            is_test: forwarded to ``Data``; also sets ``print_step`` to 1
                instead of 25.
        """
        # Checkpoint-related settings.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Data-related settings.
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21
        self.has_255 = True  # whether the boundary label is predicted as an extra segment
        self.num_segment = 4 if self.has_255 else 3

        # Model settings: input_size must exceed ratio (8) x last_pool_size.
        self.ratio = 8
        self.last_pool_size = last_pool_size
        self.filter_number = 32

        # Training settings; learning_rate is overwritten by build_net's result below.
        self.learning_rate = 5e-3
        self.num_steps = 500001
        self.print_step = 1 if is_test else 25

        self.data_reader = Data(data_root_path=data_root_path, data_list=train_list,
                                data_path=data_path, annotation_path=annotation_path, class_path=class_path,
                                batch_size=self.batch_size, image_size=self.input_size,
                                is_test=is_test, has_255=self.has_255)

        (self.image_placeholder, self.label_segment_placeholder, self.label_classes_placeholder,
         self.raw_output_segment, self.raw_output_classes, self.pred_segment, self.pred_classes,
         self.loss_segment, self.loss_classes, self.loss, self.accuracy_segment, self.accuracy_classes,
         self.step_ph, self.train_op, self.train_classes_op, self.learning_rate) = self.build_net()

        # Scalar summaries.
        scalar_summaries = (
            ("loss", self.loss),
            ("loss_segment", self.loss_segment),
            ("loss_classes", self.loss_classes),
            ("accuracy_segment", self.accuracy_segment),
            ("accuracy_classes", self.accuracy_classes),
        )
        for tag, tensor in scalar_summaries:
            tf.summary.scalar(tag, tensor)

        # Gray step that maps segment ids onto the 0-255 range.
        label_gray_step = 85 if self.has_255 else 127

        # Input summaries: channel 3 of the input is the mask, channels 0-2 the image.
        input_channels = tf.split(self.image_placeholder, num_or_size_splits=4, axis=3)
        tf.summary.image("0-mask", input_channels[3])
        tf.summary.image("1-image", tf.concat(input_channels[0:3], axis=3))
        tf.summary.image("2-label", tf.cast(self.label_segment_placeholder * label_gray_step, dtype=tf.uint8))

        # Per-channel views of the raw segment output.
        output_channels = tf.split(self.raw_output_segment, num_or_size_splits=self.num_segment, axis=3)
        channel_views = [("3-attention", 1), ("4-other class", 0), ("5-background", -1)]
        if self.has_255:
            channel_views.append(("5-border", 2))
        for tag, channel in channel_views:
            tf.summary.image(tag, output_channels[channel])
        tf.summary.image("6-pred_segment", tf.cast(self.pred_segment * label_gray_step, dtype=tf.uint8))

        self.summary_op = tf.summary.merge_all()

        # Session and saver.
        session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
        self.sess = tf.Session(config=session_config)
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

        # Summary writer.
        self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
コード例 #17
0
    def build_old(self):
        """Build the legacy four-stage attention decoder on top of ``_feature``.

        Stages run from the coarsest encoder block (block4) to the finest
        (block1). Every stage after the first adds the previous stage's decoder
        output to its encoder block (``Net.add``). Each stage then calls
        ``self._segment`` (side output) and ``self._decoder`` (next stage's
        input), each inside ``attention_<n>`` variable scope.

        Returns:
            A tuple ``(segments, features)``:
            - ``segments``: the four ``_segment`` side outputs followed by the
              final ``attention_0`` conv output.
            - ``features``: dict of intermediate tensors. The keys are
              ``block``, ``segment`` (decoder outputs), ``temp``, ``add``,
              ``adds_in_block`` and ``adds_in_2``.
        """
        # Shared feature extractor.
        block1, block2, block3, block4 = self._feature(self.input_data)
        blocks = [block1, block2, block3, block4]
        block4_shape = Tools.get_shape(block4)  # 45, 512
        block3_shape = Tools.get_shape(block3)  # 90, 512
        block2_shape = Tools.get_shape(block2)  # 180, 256
        block1_shape = Tools.get_shape(block1)  # 360, 128

        adds = []             # fused (block + previous decoder output) tensors
        adds_in_block = []    # encoder operand of each fusion
        adds_in_2 = []        # decoder operand of each fusion
        temps = []            # second return value of each _segment call
        segments = []         # side outputs, later followed by the final conv
        segments_output = []  # decoder output of each stage

        # (scope suffix, encoder block, input shape, output shape).
        # Order and names must stay as-is: they define the variable-scope
        # names (and therefore checkpoint variable names).
        stages = (
            ("4", block4, block4_shape, block3_shape),
            ("3", block3, block3_shape, block2_shape),
            ("2", block2, block2_shape, block1_shape),
            ("1", block1, block1_shape, block1_shape),
        )

        net_output = None
        for stage_index, (suffix, block, in_shape, out_shape) in enumerate(stages):
            with tf.variable_scope(name_or_scope="attention_" + suffix):
                if stage_index == 0:
                    # The coarsest stage has no earlier decoder output to fuse.
                    stage_input = block
                else:
                    stage_input = Net.add([block, net_output], name="add")
                    adds_in_block.append(block)
                    adds_in_2.append(net_output)
                    adds.append(stage_input)

                net_segment_output, temp = self._segment(stage_input,
                                                         in_shape,
                                                         out_shape,
                                                         name="segment_side_" + suffix)
                temps.append(temp)
                segments.append(net_segment_output)

                net_output = self._decoder(stage_input,
                                           in_shape,
                                           out_shape,
                                           name=suffix)
                segments_output.append(net_output)

        # Final projection of the last decoder output.
        # Presumably a 3x3 conv with 2 output channels, stride 1 — confirm
        # against Net.conv's signature.
        net_output = Net.conv(net_output,
                              3,
                              3,
                              2,
                              1,
                              1,
                              biased=True,
                              relu=False,
                              name='attention_0')
        segments.append(net_output)

        features = {
            "block": blocks,
            "segment": segments_output,
            "temp": temps,
            "add": adds,
            "adds_in_block": adds_in_block,
            "adds_in_2": adds_in_2
        }
        return segments, features
コード例 #18
0
 def load_model(self, pretrain):
     """Restore weights by delegating to ``Tools.restore_if_y``.

     The call receives ``self.sess``, ``self.log_dir`` and ``pretrain``.
     Returns None.
     """
     session, checkpoint_dir = self.sess, self.log_dir
     Tools.restore_if_y(session, checkpoint_dir, pretrain=pretrain)
コード例 #19
0
    def __init__(self,
                 batch_size,
                 last_pool_size,
                 input_size,
                 log_dir,
                 data_root_path,
                 train_list,
                 data_path,
                 annotation_path,
                 class_path,
                 model_name="model.ckpt",
                 is_test=False):
        """Create the data reader, single-segment training graph, summaries and session.

        Args:
            batch_size: samples per training batch.
            last_pool_size: pool size of the last pooling stage; the input size
                must be larger than ``self.ratio`` (8) times this value.
            input_size: image size passed to the data reader.
            log_dir: checkpoint / TensorBoard directory (created via
                ``Tools.new_dir``).
            data_root_path, train_list, data_path, annotation_path, class_path:
                forwarded to ``Data``.
            model_name: checkpoint file name inside ``log_dir``.
            is_test: forwarded to ``Data``.
        """
        # Checkpoint-related settings.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Data-related settings.
        self.input_size, self.batch_size = input_size, batch_size
        self.num_classes, self.num_segment = 21, 1

        # Model settings: input_size must exceed ratio (8) x last_pool_size.
        self.ratio, self.last_pool_size, self.filter_number = 8, last_pool_size, 32

        # Training settings; learning_rate is overwritten by build_net's result below.
        self.learning_rate, self.num_steps = 5e-3, 500001

        self.data_reader = Data(data_root_path=data_root_path,
                                data_list=train_list,
                                data_path=data_path,
                                annotation_path=annotation_path,
                                class_path=class_path,
                                batch_size=self.batch_size,
                                image_size=self.input_size,
                                is_test=is_test)

        (self.image_placeholder, self.label_segment_placeholder,
         self.label_classes_placeholder, self.raw_output_segment,
         self.raw_output_classes, self.pred_segment, self.pred_classes,
         self.loss_segment, self.loss_classes, self.loss, self.accuracy_0,
         self.accuracy_1, self.accuracy_classes, self.step_ph, self.train_op,
         self.train_classes_op, self.learning_rate) = self.build_net()

        # Scalar summaries.
        for tag, tensor in (("loss", self.loss),
                            ("loss_segment", self.loss_segment),
                            ("loss_classes", self.loss_classes),
                            ("accuracy_0", self.accuracy_0),
                            ("accuracy_1", self.accuracy_1),
                            ("accuracy_classes", self.accuracy_classes)):
            tf.summary.scalar(tag, tensor)

        # Image summaries: input channels 0-2 are the image, channel 3 the mask.
        input_channels = tf.split(self.image_placeholder, num_or_size_splits=4, axis=3)
        tf.summary.image("image", tf.concat(input_channels[0:3], axis=3))
        tf.summary.image("mask", input_channels[3])
        tf.summary.image("label", self.label_segment_placeholder)
        tf.summary.image("raw_output_segment", self.raw_output_segment)
        tf.summary.image("pred_segment", tf.cast(self.pred_segment, tf.float32))
        self.summary_op = tf.summary.merge_all()

        # Session and saver.
        session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
        self.sess = tf.Session(config=session_config)
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

        # Summary writer.
        self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
コード例 #20
0
    def __init__(self, batch_size, input_size, log_dir, data_root_path, train_list, data_path,
                 annotation_path, class_path, model_name="model.ckpt", pretrain=None, is_test=False):
        """Build the BAISNet segmentation + classification graph, summaries and session.

        Args:
            batch_size: samples per training batch.
            input_size: (height, width) of the placeholders and data reader.
            log_dir: checkpoint / TensorBoard directory (created via
                ``Tools.new_dir``).
            data_root_path, train_list, data_path, annotation_path, class_path:
                forwarded to ``Data``.
            model_name: checkpoint file name inside ``log_dir``.
            pretrain: stored as ``self.pretrain``.
            is_test: forwarded to ``Data``. It also shortens ``print_step``
                (10 vs 100) and ``cal_step`` (100 vs 1000).
        """
        # Checkpoint-related settings.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)
        self.pretrain = pretrain

        # Data-related settings.
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21

        # Training settings; learning_rate is replaced by a tensor further below.
        self.learning_rate = 5e-4
        self.num_steps = 500001
        if is_test:
            self.print_step, self.cal_step = 10, 100
        else:
            self.print_step, self.cal_step = 100, 1000

        self.data_reader = Data(data_root_path=data_root_path, data_list=train_list,
                                data_path=data_path, annotation_path=annotation_path, class_path=class_path,
                                batch_size=self.batch_size, image_size=self.input_size, is_test=is_test)

        # Placeholders (created in this order: image, mask, segment label, class label).
        height, width = self.input_size[0], self.input_size[1]
        self.image_placeholder = tf.placeholder(tf.float32, shape=(None, height, width, 3))
        self.mask_placeholder = tf.placeholder(tf.float32, shape=(None, height, width, 1))
        self.label_seg_placeholder = tf.placeholder(tf.int32, shape=(None, height, width, 1))
        self.label_cls_placeholder = tf.placeholder(tf.int32, shape=(None,))

        # Network.
        self.net = BAISNet(self.image_placeholder, self.mask_placeholder, True, num_classes=self.num_classes)
        self.segments, self.attentions, self.classes = self.net.build()

        # Losses.
        (self.loss, self.loss_segment_all, self.loss_class_all,
         self.loss_segments, self.loss_classes) = self.cal_loss(
            self.segments, self.attentions, self.classes, self.label_seg_placeholder, self.label_cls_placeholder)

        # Batch classification accuracy.
        self.pred_classes = tf.cast(tf.argmax(self.classes[0], axis=-1), tf.int32)
        self.accuracy_classes = tcm.accuracy(self.pred_classes, self.label_cls_placeholder)

        with tf.name_scope("train"):
            # Polynomial decay: lr * (1 - step / num_steps) ** 0.9
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            decay = tf.pow((1 - self.step_ph / self.num_steps), 0.9)
            self.learning_rate = tf.scalar_mul(tf.constant(self.learning_rate), decay)
            self.train_op = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss)

            # Optimizer restricted to variables whose name contains "attention"
            # (this also covers every "class_attention" variable).
            attention_trainable = [var for var in tf.trainable_variables() if 'attention' in var.name]
            print(len(attention_trainable))
            self.train_attention_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss, var_list=attention_trainable)

        with tf.name_scope("loss"):
            loss_scalars = [("loss", self.loss),
                            ("loss_segment", self.loss_segment_all),
                            ("loss_class", self.loss_class_all)]
            loss_scalars += [("loss_segment_{}".format(i), t) for i, t in enumerate(self.loss_segments)]
            loss_scalars += [("loss_class_{}".format(i), t) for i, t in enumerate(self.loss_classes)]
            for tag, tensor in loss_scalars:
                tf.summary.scalar(tag, tensor)

        with tf.name_scope("accuracy"):
            tf.summary.scalar("accuracy_classes", self.accuracy_classes)

        with tf.name_scope("label"):
            tf.summary.image("1-image", self.image_placeholder)
            tf.summary.image("2-segment", tf.cast(self.label_seg_placeholder * 255, dtype=tf.uint8))
            tf.summary.image("3-mask", tf.cast(self.mask_placeholder * 255, dtype=tf.uint8))

        with tf.name_scope("result"):
            # Sigmoid of channel 1 of every 2-channel segment / attention output.
            for prefix, outputs in (("predict", self.segments), ("attention", self.attentions)):
                for index, output in enumerate(outputs):
                    second_channel = tf.split(output, num_or_size_splits=2, axis=3)[1]
                    tf.summary.image("{}-{}-1".format(prefix, index), tf.nn.sigmoid(second_channel))

        self.summary_op = tf.summary.merge_all()

        # Session and saver.
        session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
        self.sess = tf.Session(config=session_config)
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

        # Summary writer.
        self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
コード例 #21
0
    def __init__(self, log_dir, data, model_name="model.ckpt", is_test=False):
        """Build the training graph, summaries, session and saver around ``data``.

        Args:
            log_dir: checkpoint / TensorBoard directory (created via
                ``Tools.new_dir``).
            data: data reader. Several settings are read from it:
                ``batch_size``, ``num_classes``, ``image_size``, ``ratio``,
                ``num_segment`` and ``attention_class``.
            model_name: checkpoint file name inside ``log_dir``.
            is_test: when true, ``print_step`` is 1 instead of 25.
        """
        # Settings taken from the data reader.
        self.data_reader = data
        self.batch_size = self.data_reader.batch_size
        self.num_classes = self.data_reader.num_classes
        self.input_size = self.data_reader.image_size
        self.ratio = self.data_reader.ratio
        self.num_segment = self.data_reader.num_segment
        self.attention_class = self.data_reader.attention_class

        # Checkpoint-related settings.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)

        # Model settings: input_size must exceed ratio x last_pool_size.
        self.last_pool_size = self.input_size[0] // self.ratio
        self.filter_number = 32
        # Overwritten by the last value returned by build_net below.
        self.learning_rate = 5e-3
        self.num_steps = 1000001
        self.print_step = 1 if is_test else 25

        # Network.
        (self.image_placeholder, self.label_segment_placeholder,
         self.label_classes_placeholder, self.raw_output_segment,
         self.raw_output_classes, self.pred_segment, self.pred_classes,
         self.loss_segment, self.loss_classes, self.loss,
         self.accuracy_segment, self.accuracy_classes, self.step_ph,
         self.train_op, self.train_classes_op,
         self.learning_rate) = self.build_net()

        # Scalar summaries.
        tf.summary.scalar("loss", self.loss)
        tf.summary.scalar("loss_segment", self.loss_segment)
        tf.summary.scalar("loss_classes", self.loss_classes)
        tf.summary.scalar("accuracy_segment", self.accuracy_segment)
        tf.summary.scalar("accuracy_classes", self.accuracy_classes)

        # Gray-level step that spreads segment ids over 0-255 for display.
        # Guarded so a single-segment reader (num_segment == 1) does not
        # raise ZeroDivisionError.
        segment_gray_step = 255 // max(self.num_segment - 1, 1)

        # Input channels 0-2 are the image, channel 3 the mask.
        split = tf.split(self.image_placeholder, num_or_size_splits=4, axis=3)
        tf.summary.image("0-mask", split[3])
        tf.summary.image("1-image", tf.concat(split[0:3], axis=3))
        tf.summary.image(
            "2-label",
            tf.cast(self.label_segment_placeholder * segment_gray_step,
                    dtype=tf.uint8))

        split = tf.split(self.raw_output_segment,
                         num_or_size_splits=self.num_segment,
                         axis=3)
        tf.summary.image("3-attention", split[self.attention_class])
        for num_segment in range(self.num_segment):
            tf.summary.image("4-segment-output-{}".format(num_segment),
                             split[num_segment])
        tf.summary.image(
            "5-pred_segment",
            tf.cast(self.pred_segment * segment_gray_step, dtype=tf.uint8))

        self.summary_op = tf.summary.merge_all()

        # Session and saver.
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                    max_to_keep=10)

        # Summary writer.
        self.summary_writer = tf.summary.FileWriter(self.log_dir,
                                                    self.sess.graph)
コード例 #22
0
    def run(self,
            result_filename,
            image_filename,
            where=None,
            annotation_filename=None,
            ann_index=0):
        """Run single-image inference with PSPNet and save visualisations.

        The predicted class index, its ``CategoryNames`` entry and the raw
        class scores are printed. The following files are written to
        ``self.save_dir`` with ``result_filename`` as prefix:
        ``data.png``, ``pred.png`` (arg-max segment map), ``pred_0.png`` ..
        ``pred_3.png`` (one per sigmoid channel) and ``mask.bmp``. If
        ``annotation_filename`` is given, ``ann.bmp`` is written as well.

        Args:
            result_filename: prefix for the output file names.
            image_filename, where, annotation_filename, ann_index: forwarded
                to ``Data.load_image``.
        """
        # Load the network input (and the annotation mask when requested).
        if annotation_filename:
            final_batch_data, data_raw, gaussian_mask, ann_data, ann_mask = Data.load_image(
                image_filename,
                where=where,
                annotation_filename=annotation_filename,
                ann_index=ann_index,
                image_size=self.input_size)
        else:
            final_batch_data, data_raw, gaussian_mask = Data.load_image(
                image_filename,
                where=where,
                annotation_filename=annotation_filename,
                ann_index=ann_index,
                image_size=self.input_size)

        num_segment = 4  # segment channels produced by the network

        # The graph is built in a fresh tf.Graph and the session is closed
        # afterwards. Without this, each call leaks a Session and rebuilds
        # PSPNet in the shared default graph, which may clash with variables
        # from an earlier call.
        with tf.Graph().as_default():
            img_placeholder = tf.placeholder(dtype=tf.float32,
                                             shape=(None, self.input_size[0],
                                                    self.input_size[1], 4))
            net = PSPNet({'data': img_placeholder},
                         is_training=True,
                         num_classes=21,
                         last_pool_size=self.last_pool_size,
                         filter_number=32,
                         num_segment=num_segment)

            # Segment outputs.
            raw_output_op = net.layers["conv6_n_4"]
            sigmoid_output_op = tf.sigmoid(raw_output_op)
            predict_output_op = tf.argmax(sigmoid_output_op, axis=-1)

            # Classification outputs.
            raw_output_classes = net.layers['class_attention_fc']
            pred_classes = tf.cast(tf.argmax(raw_output_classes, axis=-1),
                                   tf.int32)

            with tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
                    allow_growth=True))) as sess:
                sess.run(tf.global_variables_initializer())
                Tools.restore_if_y(sess, self.log_dir)

                sigmoid_output, predict_output_r, raw_output_classes_r, pred_classes_r = sess.run(
                    [sigmoid_output_op, predict_output_op,
                     raw_output_classes, pred_classes],
                    feed_dict={img_placeholder: final_batch_data})

        # Report and save results.
        print("{} {} {}".format(pred_classes_r[0],
                                CategoryNames[pred_classes_r[0]],
                                raw_output_classes_r))
        print("result in {}".format(
            os.path.join(self.save_dir, result_filename)))
        Image.fromarray(np.asarray(np.squeeze(data_raw), dtype=np.uint8)).save(
            os.path.join(self.save_dir, result_filename + "data.png"))

        # Arg-max segment map scaled into the gray range.
        Image.fromarray(
            np.squeeze(
                np.asarray(predict_output_r[0] * 255 // num_segment,
                           dtype=np.uint8))).save(
                               os.path.join(self.save_dir,
                                            result_filename + "pred.png"))

        # One gray image per sigmoid channel: pred_0.png, pred_1.png, ...
        channel_images = np.squeeze(
            np.split(np.asarray(sigmoid_output[0] * 255, dtype=np.uint8),
                     axis=-1,
                     indices_or_sections=num_segment))
        for channel_index, channel_image in enumerate(channel_images):
            Image.fromarray(channel_image).save(
                os.path.join(self.save_dir,
                             result_filename + "pred_{}.png".format(channel_index)))

        Image.fromarray(
            np.asarray(np.squeeze(gaussian_mask * 255), dtype=np.uint8)).save(
                os.path.join(self.save_dir, result_filename + "mask.bmp"))

        if annotation_filename:
            Image.fromarray(
                np.asarray(np.squeeze(ann_mask * 255), dtype=np.uint8)).save(
                    os.path.join(self.save_dir, result_filename + "ann.bmp"))
コード例 #23
0
    def __init__(self,
                 batch_size,
                 input_size,
                 log_dir,
                 data_root_path,
                 train_list,
                 data_path,
                 annotation_path,
                 class_path,
                 model_name="model.ckpt",
                 pretrain=None,
                 is_test=False):
        """Build the BAISNet segmentation-only training graph, summaries and session.

        Args:
            batch_size: samples per training batch.
            input_size: (height, width) of the placeholders and data reader.
            log_dir: checkpoint / TensorBoard directory (created via
                ``Tools.new_dir``).
            data_root_path, train_list, data_path, annotation_path, class_path:
                forwarded to ``Data``.
            model_name: checkpoint file name inside ``log_dir``.
            pretrain: stored as ``self.pretrain``.
            is_test: forwarded to ``Data``. It also shortens ``print_step``
                (10 vs 100) and ``cal_step`` (100 vs 1000).
        """
        # Checkpoint-related settings.
        self.log_dir = Tools.new_dir(log_dir)
        self.model_name = model_name
        self.checkpoint_path = os.path.join(self.log_dir, self.model_name)
        self.pretrain = pretrain

        # Data-related settings.
        self.input_size = input_size
        self.batch_size = batch_size
        self.num_classes = 21

        # Training settings; learning_rate is replaced by a tensor further below.
        self.learning_rate = 5e-3
        self.num_steps = 100001
        self.print_step = 10 if is_test else 100
        self.cal_step = 100 if is_test else 1000

        self.data_reader = Data(data_root_path=data_root_path,
                                data_list=train_list,
                                data_path=data_path,
                                annotation_path=annotation_path,
                                class_path=class_path,
                                batch_size=self.batch_size,
                                image_size=self.input_size,
                                is_test=is_test)

        # Placeholders.
        self.image_placeholder = tf.placeholder(tf.float32,
                                                shape=(None,
                                                       self.input_size[0],
                                                       self.input_size[1], 3))
        self.label_seg_placeholder = tf.placeholder(
            tf.int32, shape=(None, self.input_size[0], self.input_size[1], 1))

        # Network.
        self.net = BAISNet(self.image_placeholder,
                           True,
                           num_classes=self.num_classes)
        self.segments, self.features = self.net.build()

        # Losses.
        self.loss, self.loss_segment_all, self.loss_segments = self.cal_loss(
            self.segments, self.label_seg_placeholder)

        with tf.name_scope("train"):
            # Polynomial decay: lr * (1 - step / num_steps) ** 0.8
            self.step_ph = tf.placeholder(dtype=tf.float32, shape=())
            self.learning_rate = tf.scalar_mul(
                tf.constant(self.learning_rate),
                tf.pow((1 - self.step_ph / self.num_steps), 0.8))
            self.train_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss)

            # Optimizer restricted to the "segment_side" variables.
            segment_side_trainable = [
                v for v in tf.trainable_variables() if 'segment_side' in v.name
            ]
            print(len(segment_side_trainable))
            self.train_segment_side_op = tf.train.GradientDescentOptimizer(
                self.learning_rate).minimize(self.loss,
                                             var_list=segment_side_trainable)

        with tf.name_scope("loss"):
            tf.summary.scalar("loss", self.loss)
            tf.summary.scalar("loss_segment", self.loss_segment_all)
            for loss_segment_index, loss_segment in enumerate(
                    self.loss_segments):
                tf.summary.scalar("loss_segment_{}".format(loss_segment_index),
                                  loss_segment)

        with tf.name_scope("label"):
            tf.summary.image("1-image", self.image_placeholder)
            # NOTE(review): labels above 21 (e.g. a 255 "ignore" value, if the
            # reader emits one) would wrap in the uint8 cast — confirm.
            tf.summary.image(
                "2-segment",
                tf.cast(self.label_seg_placeholder * 255 // self.num_classes,
                        dtype=tf.uint8))

        with tf.name_scope("result"):
            # Show the arg-max class map of at most three images per output.
            num_preview = min(3, self.batch_size)
            for segment_index, segment in enumerate(self.segments):
                # Scale the int64 argmax *before* casting to uint8. The
                # original cast first, so `uint8 * 255` wrapped around and
                # almost every non-zero class rendered as the same gray.
                class_map = tf.cast(
                    tf.argmax(segment, axis=-1) * 255 // self.num_classes,
                    dtype=tf.uint8)
                class_map = tf.expand_dims(class_map, axis=-1)
                for i in range(num_preview):
                    # Slicing (rather than tf.split into batch_size pieces)
                    # keeps the summary op valid for a smaller final batch.
                    tf.summary.image(
                        "predict-{}-{}".format(segment_index, i),
                        class_map[i:i + 1])

        self.summary_op = tf.summary.merge_all()

        # Session and saver.
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True)))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                    max_to_keep=10)

        # Summary writer.
        self.summary_writer = tf.summary.FileWriter(self.log_dir,
                                                    self.sess.graph)
コード例 #24
0
 def load_model(self):
     """Restore weights by delegating to ``Tools.restore_if_y``.

     The call receives ``self.sess`` and ``self.log_dir``. Returns None.
     """
     session, checkpoint_dir = self.sess, self.log_dir
     Tools.restore_if_y(session, checkpoint_dir)