Ejemplo n.º 1
0
def train(yolov3_trainer):
    """Train the YOLO v3 model on the training set configured in FLAGS.

    The dataset is built with ``FileUtil.get_dataset`` from
    ``FLAGS.train_label_path`` / ``FLAGS.train_set_dir``, using the first two
    entries of ``FLAGS.input_image_size``, ``FLAGS.batch_size`` and
    ``FLAGS.is_augment`` (``is_test`` is always False here).

    :param yolov3_trainer: model/trainer object; it is called as
        ``yolov3_trainer.train(dataset, None)``
    """
    label_path = FLAGS.train_label_path
    logging.info('加载训练数据集:%s', label_path)
    dataset = FileUtil.get_dataset(
        label_path,
        FLAGS.train_set_dir,
        image_size=FLAGS.input_image_size[0:2],
        batch_size=FLAGS.batch_size,
        is_augment=FLAGS.is_augment,
        is_test=False,
    )
    # The second positional argument is passed as None (no second dataset).
    yolov3_trainer.train(dataset, None)
    logging.info('训练完毕!')
Ejemplo n.º 2
0
def test(yolov3_trainer, yolov3_decoder, save_path=None):
    """
    Run the YOLO v3 model over the test set and optionally save visualizations.

    Evaluates ``ceil(FLAGS.val_set_size / FLAGS.batch_size)`` batches. For each
    image it:
      1. filters boxes from the three decoded heads with ``FLAGS.confidence_thresh``;
      2. applies NMS with ``FLAGS.nms_thresh``;
      3. rescales the boxes to the network input size;
      4. when ``save_path`` is given, writes an annotated image there.

    :param yolov3_trainer: YOLO v3 detection model; ``predict(images)`` is called per batch
    :param yolov3_decoder: decoder for the model output; ``decode(predictions)`` must
        return three ``(_, predicts, boxes)`` triples (unpacked as the head_8 /
        head_16 / head_32 outputs)
    :param save_path: directory in which annotated result images are saved;
        ``None`` disables saving
    """
    logging.info('加载测试数据集:%s', FLAGS.test_label_path)
    test_set = FileUtil.get_dataset(FLAGS.test_label_path,
                                    FLAGS.test_set_dir,
                                    image_size=FLAGS.input_image_size[0:2],
                                    batch_size=FLAGS.batch_size,
                                    is_augment=False,
                                    is_test=True)
    # Number of batches to evaluate.
    # NOTE(review): sized from FLAGS.val_set_size rather than a test-set size
    # flag -- confirm this is intended.
    total_test = int(np.ceil(FLAGS.val_set_size / FLAGS.batch_size))
    # Network input size as [W, H, W, H]: [1::-1] swaps the first two entries
    # of input_image_size, then the pair is repeated twice.
    input_box_size = np.tile(FLAGS.input_image_size[1::-1], [2])
    # images are float32 TensorFlow tensors already scaled to [0, 1].
    for batch_counter, (images, labels, image_paths) in enumerate(test_set):
        # batch_counter is 0-based: indices 0..total_test-1 are exactly
        # total_test batches (a `>` check would evaluate one extra batch).
        if batch_counter >= total_test:
            break
        images = np.array(images)
        labels = np.array(labels)
        image_paths = np.array(image_paths)
        predictions = yolov3_trainer.predict(images)
        [(_, head_8_predicts, head_8_predicts_boxes),
         (_, head_16_predicts, head_16_predicts_boxes),
         (_, head_32_predicts, head_32_predicts_boxes)
         ] = yolov3_decoder.decode(predictions)
        per_image = zip(images, labels, image_paths,
                        np.array(head_8_predicts), np.array(head_8_predicts_boxes),
                        np.array(head_16_predicts), np.array(head_16_predicts_boxes),
                        np.array(head_32_predicts), np.array(head_32_predicts_boxes))
        # `label` is currently unused; it is kept for the AP/mAP TODO below.
        for (image, label, image_path,
             head_8_prediction, head_8_boxes,
             head_16_prediction, head_16_boxes,
             head_32_prediction, head_32_boxes) in per_image:
            # (k, 8) boxes, normalized scale -> network input scale:
            # [(left top right bottom iou prob class score) ... ]
            high_score_boxes = YOLOv3PostProcessor.filter_boxes(
                head_8_prediction, head_8_boxes, head_16_prediction,
                head_16_boxes, head_32_prediction, head_32_boxes,
                FLAGS.confidence_thresh)
            nms_boxes = YOLOv3PostProcessor.apply_nms(high_score_boxes,
                                                      FLAGS.nms_thresh)
            in_boxes = YOLOv3PostProcessor.resize_boxes(
                nms_boxes, target_size=input_box_size)
            if save_path is not None:
                file_name = os.path.basename(image_path)
                # Dataset paths are expected to arrive as bytes (np.bytes_).
                # Decode them, but also accept str, on which
                # str(x, 'utf-8') would raise TypeError.
                if isinstance(file_name, bytes):
                    file_name = file_name.decode('utf-8')
                YOLOv3PostProcessor.visualize(image,
                                              in_boxes,
                                              src_box_size=input_box_size,
                                              image_path=os.path.join(save_path, file_name))
            # TODO: compute AP / mAP from the predictions, e.g. with the
            # open-source Cartucho/mAP tool (https://github.com/Cartucho/mAP).
Ejemplo n.º 3
0
 def prepare_data(self, label_file_path, image_root_dir, is_augment=False, is_test=False):
     """
     Build the dataset used for training or evaluation.

     The dataset is created by ``FileUtil.get_dataset`` with:
       - ``image_size``: ``self.input_shape[0:2]``;
       - ``num_labels``: ``len(self.output_shapes)``;
       - ``batch_size``: ``self.mini_batch``.

     NOTE(review): an older docstring said an initializable iterator is returned
     and must be initialized with ``sess.run(iterator.initializer)``. This method
     creates no iterator itself; it returns the dataset object unchanged.

     :param label_file_path: path to the label file (format defined by FileUtil.get_dataset)
     :param image_root_dir: root directory of the image files
     :param is_augment: whether to apply data augmentation
     :param is_test: whether this is the test phase
     :return: the object returned by FileUtil.get_dataset (documented as a tf.data.Dataset)
     """
     logging.info('加载数据集:%s', label_file_path)
     loader_options = dict(
         image_size=self.input_shape[0:2],
         num_labels=len(self.output_shapes),
         batch_size=self.mini_batch,
         is_augment=is_augment,
         is_test=is_test,
     )
     return FileUtil.get_dataset(label_file_path, image_root_dir, **loader_options)