    def train(self, save_pred_freq, begin_step=0):
        # Load the model
        Tools.restore_if_y(self.sess, self.log_dir)

        # NOTE: this variant only runs a few forward passes; no training op is fetched.
        for step in range(begin_step, 5):

            final_batch_data, final_batch_ann, final_batch_class, batch_data, batch_mask = \
                self.data_reader.next_batch_train()

            (raw_output_r, pred_segment_r, raw_output_classes_r,
             pred_classes_r) = self.sess.run(
                 [
                     self.raw_output_segment, self.pred_segment,
                     self.raw_output_classes, self.pred_classes
                 ],
                 feed_dict={self.image_placeholder: final_batch_data})

            if step % save_pred_freq == 0:
                self.saver.save(self.sess,
                                self.checkpoint_path,
                                global_step=step)
                Tools.print_info('The checkpoint has been created.')
                pass

            # Tools.print_info('step {:d} {} {} {}'.format(
            #     step, list(final_batch_class), list(pred_classes_r), list(raw_output_classes_r)))
            Tools.print_info('step {:d} {} {}'.format(step,
                                                      list(final_batch_class),
                                                      list(pred_classes_r)))

            pass

        pass
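
The periodic checkpointing in the loop above (saving every `save_pred_freq` steps with `global_step=step`) is the standard TF1 `tf.train.Saver` pattern. A minimal, self-contained sketch of that pattern, assuming TensorFlow 1.x; the variable, directory, and step counts are illustrative, not taken from the class above:

import os
import tensorflow as tf

save_pred_freq = 100                                   # illustrative frequency
os.makedirs("ckpt_demo", exist_ok=True)                # hypothetical checkpoint directory
checkpoint_path = os.path.join("ckpt_demo", "model.ckpt")

w = tf.Variable(tf.zeros([1]), name="w")               # dummy variable so the Saver has something to save
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(0, 500):
        # ... one training step would go here ...
        if step % save_pred_freq == 0:
            # global_step appends the step number to the checkpoint filename
            saver.save(sess, checkpoint_path, global_step=step)
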
    def train(self, save_pred_freq, begin_step=0):
        # Load the model
        Tools.restore_if_y(self.sess, self.log_dir)

        for step in range(begin_step, self.num_steps):
            start_time = time.time()

            final_batch_data, final_batch_ann, final_batch_class, batch_data, batch_mask = \
                self.data_reader.next_batch_train()

            # train_op = self.train_classes_op
            train_op = self.train_op

            if step % self.print_step == 0:
                # summary 3
                (accuracy_segment_r, accuracy_classes_r,
                 _, learning_rate_r,
                 loss_segment_r, loss_classes_r, loss_r,
                 raw_output_r, pred_segment_r, raw_output_classes_r, pred_classes_r,
                 summary_now) = self.sess.run(
                    [self.accuracy_segment, self.accuracy_classes,
                     train_op, self.learning_rate,
                     self.loss_segment, self.loss_classes, self.loss,
                     self.raw_output_segment, self.pred_segment, self.raw_output_classes, self.pred_classes,
                     self.summary_op],
                    feed_dict={self.step_ph: step, self.image_placeholder: final_batch_data,
                               self.label_segment_placeholder: final_batch_ann,
                               self.label_classes_placeholder: final_batch_class})
                self.summary_writer.add_summary(summary_now, global_step=step)
            else:
                (accuracy_segment_r, accuracy_classes_r,
                 _, learning_rate_r,
                 loss_segment_r, loss_classes_r, loss_r,
                 raw_output_r, pred_segment_r, raw_output_classes_r, pred_classes_r) = self.sess.run(
                    [self.accuracy_segment, self.accuracy_classes,
                     train_op, self.learning_rate,
                     self.loss_segment, self.loss_classes, self.loss,
                     self.raw_output_segment, self.pred_segment, self.raw_output_classes, self.pred_classes],
                    feed_dict={self.step_ph: step, self.image_placeholder: final_batch_data,
                               self.label_segment_placeholder: final_batch_ann,
                               self.label_classes_placeholder: final_batch_class})
                pass

            if step % save_pred_freq == 0:
                self.saver.save(self.sess, self.checkpoint_path, global_step=step)
                Tools.print_info('The checkpoint has been created.')
                pass

            duration = time.time() - start_time

            Tools.print_info(
                'step {:d} loss={:.3f} seg={:.3f} class={:.3f} acc={:.3f} acc_class={:.3f}'
                ' lr={:.6f} ({:.3f} s/step) {} {}'.format(
                    step, loss_r, loss_segment_r, loss_classes_r, accuracy_segment_r, accuracy_classes_r,
                    learning_rate_r, duration, list(final_batch_class), list(pred_classes_r)))

            pass

        pass
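
The `print_step` branch above fetches `self.summary_op` only every few steps and writes the result with `self.summary_writer`. A minimal sketch of the TF1 summary machinery this presumably relies on; the scalar, log directory, and loop values are illustrative assumptions:

import tensorflow as tf

loss = tf.placeholder(tf.float32, name="loss")      # stand-in for a real loss tensor
tf.summary.scalar("loss", loss)                     # register a scalar summary
summary_op = tf.summary.merge_all()                 # merged op, analogous to self.summary_op

with tf.Session() as sess:
    writer = tf.summary.FileWriter("./logs_demo", sess.graph)  # hypothetical log dir
    print_step = 10
    for step in range(100):
        fake_loss = 1.0 / (step + 1)                # illustrative value
        if step % print_step == 0:                  # fetch and write summaries only occasionally
            summary_now = sess.run(summary_op, feed_dict={loss: fake_loss})
            writer.add_summary(summary_now, global_step=step)
    writer.close()
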
    def train(self, save_pred_freq, begin_step=0):

        # Load the model
        Tools.restore_if_y(self.sess, self.log_dir, pretrain=self.pretrain)

        total_loss = 0.0
        pre_avg_loss = 0.0
        for step in range(begin_step, self.num_steps):
            start_time = time.time()

            batch_data, batch_segment = self.data_reader.next_batch_train()

            # train_op = self.train_attention_op
            train_op = self.train_op

            if step % self.print_step == 0:
                # summary 3
                _, learning_rate_r, loss_segment_r, loss_r, summary_now = self.sess.run(
                    [train_op, self.learning_rate, self.loss_segment_all, self.loss, self.summary_op],
                    feed_dict={self.step_ph: step, self.image_placeholder: batch_data,
                               self.label_seg_placeholder: batch_segment})
                self.summary_writer.add_summary(summary_now, global_step=step)
            else:
                _, learning_rate_r, loss_segment_r, loss_r = self.sess.run(
                    [train_op, self.learning_rate, self.loss_segment_all, self.loss],
                    feed_dict={self.step_ph: step, self.image_placeholder: batch_data,
                               self.label_seg_placeholder: batch_segment})
                pass

            if step % save_pred_freq == 0:
                self.saver.save(self.sess, self.checkpoint_path, global_step=step)
                Tools.print_info('The checkpoint has been created.')
                pass

            duration = time.time() - start_time

            if step % self.cal_step == 0:
                pre_avg_loss = total_loss / self.cal_step
                total_loss = loss_r
            else:
                total_loss += loss_r

            total_loss_step = ((step % self.cal_step) + 1)

            if step % (self.cal_step // 10) == 0:
                Tools.print_info('step {:d} pre_avg_loss={:.3f} avg_loss={:.3f} loss={:.3f} seg={:.3f} lr={:.6f} '
                                 '({:.3f} s/step)'.format(step, pre_avg_loss, total_loss / total_loss_step,
                                                          loss_r, loss_segment_r, learning_rate_r, duration))
                pass
            if step % self.print_step == 0:
                Tools.print_info("")

            pass

        pass
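
The `total_loss` / `pre_avg_loss` bookkeeping above keeps a running average over a window of `cal_step` steps and remembers the previous window's average. A stripped-down sketch of just that bookkeeping, with made-up loss values standing in for `sess.run` results:

# Windowed loss averaging, mirroring the pattern above (values are made up).
cal_step = 100
total_loss = 0.0
pre_avg_loss = 0.0

for step in range(0, 1000):
    loss_r = 1.0 / (step + 1)                 # stand-in for the loss returned by sess.run

    if step % cal_step == 0:
        # Window boundary: remember the finished window's average, start a new window.
        pre_avg_loss = total_loss / cal_step
        total_loss = loss_r
    else:
        total_loss += loss_r

    total_loss_step = (step % cal_step) + 1   # how many losses the current window holds
    avg_loss = total_loss / total_loss_step   # running average inside the current window
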
    def run(self, result_filename, image_filename, where=None, annotation_filename=None, ann_index=0):
        # Read the image data
        final_batch_data, data_raw, gaussian_mask, ann_data, ann_mask = Data.load_image(
            image_filename, where=where, annotation_filename=annotation_filename,
            ann_index=ann_index, image_size=self.input_size)

        # Network
        img_placeholder = tf.placeholder(dtype=tf.float32, shape=(None, self.input_size[0], self.input_size[1], 4))
        net = PSPNet({'data': img_placeholder}, is_training=True, num_classes=1, last_pool_size=self.last_pool_size,
                     filter_number=32)

        # Output / prediction
        raw_output_op = net.layers["conv6_n"]
        sigmoid_output_op = tf.sigmoid(raw_output_op)

        # Start the session / load the model
        sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
        sess.run(tf.global_variables_initializer())
        Tools.restore_if_y(sess, self.log_dir)

        # Run
        raw_output, sigmoid_output = sess.run([raw_output_op, sigmoid_output_op], feed_dict={img_placeholder: final_batch_data})

        # Save the results
        Image.fromarray(np.asarray(np.squeeze(data_raw), dtype=np.uint8)).save(
            os.path.join(self.save_dir, result_filename + "data.png"))
        Tools.print_info('over : result save in {}'.format(os.path.join(self.save_dir, result_filename)))
        Image.fromarray(np.asarray(np.squeeze(sigmoid_output[0] * 255), dtype=np.uint8)).save(
            os.path.join(self.save_dir, result_filename + "pred.png"))
        Tools.print_info('over : result save in {}'.format(os.path.join(self.save_dir, result_filename)))
        Image.fromarray(np.asarray(np.squeeze(np.greater(raw_output[0], 0.5) * 255), dtype=np.uint8)).save(
            os.path.join(self.save_dir, result_filename + "pred_raw.png"))
        Tools.print_info('over : result save in {}'.format(os.path.join(self.save_dir, result_filename)))
        Image.fromarray(np.asarray(np.squeeze(np.greater(sigmoid_output[0], 0.5) * 255), dtype=np.uint8)).save(
            os.path.join(self.save_dir, result_filename + "pred_sigmoid.png"))
        Tools.print_info('over : result save in {}'.format(os.path.join(self.save_dir, result_filename)))
        Image.fromarray(np.asarray(np.squeeze(gaussian_mask * 255), dtype=np.uint8)).save(
            os.path.join(self.save_dir, result_filename + "mask.bmp"))
        Tools.print_info('over : result save in {}'.format(os.path.join(self.save_dir, result_filename)))
        Image.fromarray(np.asarray(np.squeeze(ann_mask * 255), dtype=np.uint8)).save(
            os.path.join(self.save_dir, result_filename + "ann.bmp"))
        Tools.print_info('over : result save in {}'.format(os.path.join(self.save_dir, result_filename)))
        pass
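
The `run` method above converts several float arrays in [0, 1] to `uint8` images before saving them with PIL. A minimal, self-contained sketch of that conversion; random data stands in for the network output and the file names are hypothetical:

import numpy as np
from PIL import Image

# Stand-in for a (1, H, W, 1) sigmoid output in [0, 1].
sigmoid_output = np.random.rand(1, 64, 64, 1)

# Scale to [0, 255] and save the probabilities directly.
prob_img = np.asarray(np.squeeze(sigmoid_output[0] * 255), dtype=np.uint8)
Image.fromarray(prob_img).save("demo_pred.png")          # hypothetical output path

# Threshold at 0.5 to get a binary mask, then scale to {0, 255}.
mask_img = np.asarray(np.squeeze(np.greater(sigmoid_output[0], 0.5) * 255), dtype=np.uint8)
Image.fromarray(mask_img).save("demo_pred_binary.png")   # hypothetical output path
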
    def load_net(self):
        print("begin to build net and start session and load model ...")

        # Network
        img_placeholder = tf.placeholder(dtype=tf.float32, shape=(None, self.input_size[0], self.input_size[1], 4))
        net = PSPNet({'data': img_placeholder}, is_training=True, num_classes=21,
                     last_pool_size=self.last_pool_size, filter_number=32, num_segment=4)

        # Output / prediction
        raw_output_op = net.layers["conv6_n_4"]
        raw_output_op = tf.image.resize_bilinear(raw_output_op, size=self.input_size)
        predict_output_op = tf.argmax(tf.sigmoid(raw_output_op), axis=-1)

        raw_output_classes = net.layers['class_attention_fc']
        pred_classes = tf.cast(tf.argmax(raw_output_classes, axis=-1), tf.int32)

        # Start the session / load the model
        sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
        sess.run(tf.global_variables_initializer())
        Tools.restore_if_y(sess, self.log_dir)

        print("end build net and start session and load model ...")

        return sess, img_placeholder, predict_output_op, pred_classes
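
`load_net` returns a per-pixel `argmax` over the segment channels and an `argmax` over the class logits. A small numpy sketch of what those two reductions produce; the shapes are illustrative, and sigmoid is omitted because it is monotone and does not change the argmax:

import numpy as np

# Stand-ins for a (1, H, W, 4) segment output and a (1, 21) class-logit output.
raw_segment = np.random.rand(1, 8, 8, 4)
raw_classes = np.random.rand(1, 21)

# Per-pixel label in {0, 1, 2, 3}: the channel with the highest score wins.
pred_segment = np.argmax(raw_segment, axis=-1)                   # shape (1, 8, 8)

# One class index per image, analogous to pred_classes above.
pred_classes = np.argmax(raw_classes, axis=-1).astype(np.int32)  # shape (1,)

print(pred_segment.shape, pred_classes)
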
    def run(self,
            result_filename,
            image_filename,
            where=None,
            annotation_filename=None,
            ann_index=0):
        # Read the image data
        if annotation_filename:
            final_batch_data, data_raw, gaussian_mask, ann_data, ann_mask = Data.load_image(
                image_filename,
                where=where,
                annotation_filename=annotation_filename,
                ann_index=ann_index,
                image_size=self.input_size)
        else:
            final_batch_data, data_raw, gaussian_mask = Data.load_image(
                image_filename,
                where=where,
                annotation_filename=annotation_filename,
                ann_index=ann_index,
                image_size=self.input_size)

        # Network
        img_placeholder = tf.placeholder(dtype=tf.float32,
                                         shape=(None, self.input_size[0],
                                                self.input_size[1], 4))
        net = PSPNet({'data': img_placeholder},
                     is_training=True,
                     num_classes=21,
                     last_pool_size=self.last_pool_size,
                     filter_number=32,
                     num_segment=4)

        # Output / prediction
        raw_output_op = net.layers["conv6_n_4"]
        sigmoid_output_op = tf.sigmoid(raw_output_op)
        predict_output_op = tf.argmax(sigmoid_output_op, axis=-1)

        raw_output_classes = net.layers['class_attention_fc']
        pred_classes = tf.cast(tf.argmax(raw_output_classes, axis=-1),
                               tf.int32)

        # Start the session / load the model
        sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True)))
        sess.run(tf.global_variables_initializer())
        Tools.restore_if_y(sess, self.log_dir)

        # Run
        raw_output, sigmoid_output, predict_output_r, raw_output_classes_r, pred_classes_r = sess.run(
            [
                raw_output_op, sigmoid_output_op, predict_output_op,
                raw_output_classes, pred_classes
            ],
            feed_dict={img_placeholder: final_batch_data})

        # Save the results
        print("{} {} {}".format(pred_classes_r[0],
                                CategoryNames[pred_classes_r[0]],
                                raw_output_classes_r))
        print("result in {}".format(
            os.path.join(self.save_dir, result_filename)))
        Image.fromarray(np.asarray(np.squeeze(data_raw), dtype=np.uint8)).save(
            os.path.join(self.save_dir, result_filename + "data.png"))

        output_result = np.squeeze(
            np.split(np.asarray(sigmoid_output[0] * 255, dtype=np.uint8),
                     axis=-1,
                     indices_or_sections=4))

        Image.fromarray(
            np.squeeze(
                np.asarray(predict_output_r[0] * 255 // 4,
                           dtype=np.uint8))).save(
                               os.path.join(self.save_dir,
                                            result_filename + "pred.png"))
        Image.fromarray(output_result[0]).save(
            os.path.join(self.save_dir, result_filename + "pred_0.png"))
        Image.fromarray(output_result[1]).save(
            os.path.join(self.save_dir, result_filename + "pred_1.png"))
        Image.fromarray(output_result[2]).save(
            os.path.join(self.save_dir, result_filename + "pred_2.png"))
        Image.fromarray(output_result[3]).save(
            os.path.join(self.save_dir, result_filename + "pred_3.png"))

        Image.fromarray(
            np.asarray(np.squeeze(gaussian_mask * 255), dtype=np.uint8)).save(
                os.path.join(self.save_dir, result_filename + "mask.bmp"))

        if annotation_filename:
            Image.fromarray(
                np.asarray(np.squeeze(ann_mask * 255), dtype=np.uint8)).save(
                    os.path.join(self.save_dir, result_filename + "ann.bmp"))
            pass
        pass
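
The second `run` splits the four-channel sigmoid output into one image per segment channel. A compact sketch of just that `np.split` / `np.squeeze` step with dummy data and hypothetical output paths:

import numpy as np
from PIL import Image

# Stand-in for one (H, W, 4) sigmoid output already scaled to [0, 255].
sigmoid_output_0 = np.random.rand(32, 32, 4) * 255

# Split the last axis into 4 arrays of shape (32, 32, 1), then drop the singleton axis.
output_result = np.squeeze(
    np.split(np.asarray(sigmoid_output_0, dtype=np.uint8), indices_or_sections=4, axis=-1))
# output_result now has shape (4, 32, 32); each slice is one channel.

for i in range(4):
    Image.fromarray(output_result[i]).save("demo_pred_{}.png".format(i))
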
    def load_model(self, pretrain):
        # Load the model
        Tools.restore_if_y(self.sess, self.log_dir, pretrain=pretrain)
        pass
    def load_model(self):
        # Load the model
        Tools.restore_if_y(self.sess, self.log_dir)
        pass
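
`Tools.restore_if_y` itself is not shown in this listing; the following is only a plausible minimal sketch of the usual TF1 "restore the latest checkpoint if one exists" pattern it presumably wraps. The helper name is hypothetical, and the Saver must be built after the model's variables exist:

import tensorflow as tf

def restore_latest_checkpoint(sess, log_dir):
    """Restore the newest checkpoint found in log_dir, if any (hypothetical helper)."""
    ckpt = tf.train.latest_checkpoint(log_dir)
    if ckpt is not None:
        # The Saver picks up the variables already defined in the current graph.
        saver = tf.train.Saver()
        saver.restore(sess, ckpt)
        print("Restored model from {}".format(ckpt))
    else:
        print("No checkpoint found in {}, starting from scratch".format(log_dir))
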
    def train(self, save_pred_freq, begin_step=0):

        # Load the model
        Tools.restore_if_y(self.sess, self.log_dir, pretrain=self.pretrain)

        total_loss = 0.0
        pre_avg_loss = 0.0
        total_acc = 0.0
        pre_avg_acc = 0.0
        for step in range(begin_step, self.num_steps):
            start_time = time.time()

            batch_data, batch_mask, batch_attention, batch_class = self.data_reader.next_batch_train()

            # train_op = self.train_attention_op
            train_op = self.train_op

            if step % self.print_step == 0:
                # summary 3
                (accuracy_classes_r, _, learning_rate_r,
                 loss_segment_r, loss_classes_r, loss_r,
                 pred_classes_r, summary_now) = self.sess.run(
                    [self.accuracy_classes, train_op, self.learning_rate,
                     self.loss_segment_all, self.loss_class_all, self.loss,
                     self.pred_classes, self.summary_op],
                    feed_dict={self.step_ph: step, self.image_placeholder: batch_data,
                               self.mask_placeholder: batch_mask,
                               self.label_seg_placeholder: batch_attention,
                               self.label_cls_placeholder: batch_class})
                self.summary_writer.add_summary(summary_now, global_step=step)
            else:
                (accuracy_classes_r, _, learning_rate_r,
                 loss_segment_r, loss_classes_r, loss_r, pred_classes_r) = self.sess.run(
                    [self.accuracy_classes, train_op, self.learning_rate,
                     self.loss_segment_all, self.loss_class_all, self.loss, self.pred_classes],
                    feed_dict={self.step_ph: step, self.image_placeholder: batch_data,
                               self.mask_placeholder: batch_mask,
                               self.label_seg_placeholder: batch_attention,
                               self.label_cls_placeholder: batch_class})
                pass

            if step % save_pred_freq == 0:
                self.saver.save(self.sess, self.checkpoint_path, global_step=step)
                Tools.print_info('The checkpoint has been created.')
                pass

            duration = time.time() - start_time

            if step % self.cal_step == 0:
                pre_avg_loss = total_loss / self.cal_step
                pre_avg_acc = total_acc / self.cal_step
                total_loss = loss_r
                total_acc = accuracy_classes_r
            else:
                total_loss += loss_r
                total_acc += accuracy_classes_r

            total_loss_step = ((step % self.cal_step) + 1)

            if step % (self.cal_step // 10) == 0:
                Tools.print_info(
                    'step {:d} pre_avg_loss={:.3f} avg_loss={:.3f} pre_avg_acc={:.3f} avg_acc={:.3f} '
                    'loss={:.3f} seg={:.3f} class={:.3f} acc_class={:.3f} '
                    'lr={:.6f} ({:.3f} s/step) {} {}'.format(
                        step, pre_avg_loss, total_loss / total_loss_step, pre_avg_acc, total_acc / total_loss_step,
                        loss_r, loss_segment_r, loss_classes_r, accuracy_classes_r,
                        learning_rate_r, duration, list(batch_class), list(pred_classes_r)))
                pass
            if step % self.print_step == 0:
                Tools.print_info("")

            pass

        pass