Ejemplo n.º 1
0
    def generate(self):
        """Build the network, load its weights and assign one colour per class.

        Sets ``self.net`` (wrapped in ``nn.DataParallel`` and moved to the GPU
        when ``self.cuda`` is true) and ``self.colors`` (one ``(r, g, b)``
        integer tuple per entry of ``self.class_names``).
        """
        #------------------#
        #   Build the model
        #------------------#
        self.net = Retinanet(len(self.class_names), self.phi).eval()

        #----------------------------------------#
        #   Load the weights
        #----------------------------------------#
        print('Loading weights into state dict...')
        # Load onto the CPU first so checkpoints saved from a GPU also load
        # when self.cuda is False; the network is moved to the GPU below.
        state_dict = torch.load(self.model_path, map_location='cpu')
        self.net.load_state_dict(state_dict)
        if self.cuda:
            # NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has
            # not already been initialised in this process.
            os.environ["CUDA_VISIBLE_DEVICES"] = '0'
            self.net = nn.DataParallel(self.net)
            self.net = self.net.cuda()
        print('{} model, anchors, and classes loaded.'.format(self.model_path))

        #----------------------------------------#
        #   One colour per class: evenly spaced hues at full saturation/value
        #----------------------------------------#
        num_classes = len(self.class_names)
        hsv_tuples = [(x / num_classes, 1., 1.) for x in range(num_classes)]
        self.colors = [
            tuple(int(channel * 255) for channel in colorsys.hsv_to_rgb(*hsv))
            for hsv in hsv_tuples
        ]
Ejemplo n.º 2
0
class RetinaNet(object):
    """Inference wrapper around the ``Retinanet`` network.

    Configuration comes from ``_defaults``; any keyword argument passed to the
    constructor overrides the matching default. ``detect_image`` draws the
    detections onto a PIL image.
    """

    _defaults = {
        "model_path": 'model_data/retinanet_resnet50.pth',
        "classes_path": 'model_data/voc_classes.txt',
        "input_shape": [600, 600, 3],
        "confidence": 0.5,
        "iou": 0.3,
        "phi": 2,
        "cuda": True,
    }

    @classmethod
    def get_defaults(cls, n):
        """Return the default value of setting ``n``.

        For an unknown name an error *string* is returned (not raised).
        """
        if n in cls._defaults:
            return cls._defaults[n]
        return "Unrecognized attribute name '" + n + "'"

    #---------------------------------------------------#
    #   Initialise RetinaNet
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        """Apply defaults, then keyword overrides, then load classes and model.

        Keyword arguments (e.g. ``confidence=0.6, cuda=False``) override the
        entries of ``_defaults``; previously they were accepted but ignored.
        """
        self.__dict__.update(self._defaults)
        self.__dict__.update(kwargs)
        self.class_names = self._get_class()
        self.generate()

    #---------------------------------------------------#
    #   Read all class names
    #---------------------------------------------------#
    def _get_class(self):
        """Return the stripped lines of ``self.classes_path`` as class names."""
        classes_path = os.path.expanduser(self.classes_path)
        with open(classes_path) as f:
            class_names = f.readlines()
        return [c.strip() for c in class_names]

    #---------------------------------------------------#
    #   Build the model
    #---------------------------------------------------#
    def generate(self):
        """Build the network, load its weights and assign one colour per class."""
        #------------------#
        #   Build the model
        #------------------#
        self.net = Retinanet(len(self.class_names), self.phi).eval()

        #----------------------------------------#
        #   Load the weights
        #----------------------------------------#
        print('Loading weights into state dict...')
        # Load onto the CPU first so checkpoints saved from a GPU also load
        # when self.cuda is False; the network is moved to the GPU below.
        state_dict = torch.load(self.model_path, map_location='cpu')
        self.net.load_state_dict(state_dict)
        if self.cuda:
            # NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has
            # not already been initialised in this process.
            os.environ["CUDA_VISIBLE_DEVICES"] = '0'
            self.net = nn.DataParallel(self.net)
            self.net = self.net.cuda()
        print('{} model, anchors, and classes loaded.'.format(self.model_path))

        #----------------------------------------#
        #   One colour per class: evenly spaced hues
        #----------------------------------------#
        num_classes = len(self.class_names)
        hsv_tuples = [(x / num_classes, 1., 1.) for x in range(num_classes)]
        self.colors = [
            tuple(int(channel * 255) for channel in colorsys.hsv_to_rgb(*hsv))
            for hsv in hsv_tuples
        ]

    #---------------------------------------------------#
    #   Detect objects in an image
    #---------------------------------------------------#
    def detect_image(self, image):
        """Detect objects in a PIL ``image`` and draw them onto it in place.

        Returns the same image object: annotated when detections survive, or
        unchanged when non-max suppression yields nothing.
        """
        image_shape = np.array(np.shape(image)[0:2])
        #---------------------------------------------------------#
        #   Letterbox (pad with grey bars) so resizing keeps the aspect ratio
        #---------------------------------------------------------#
        crop_img = np.array(
            letterbox_image(image, [self.input_shape[1], self.input_shape[0]]))
        photo = np.array(crop_img, dtype=np.float32)
        photo = np.transpose(preprocess_input(photo), (2, 0, 1))

        with torch.no_grad():
            images = torch.from_numpy(np.asarray([photo]))
            if self.cuda:
                images = images.cuda()

            #---------------------------------------------------------#
            #   Forward pass
            #---------------------------------------------------------#
            _, regression, classification, anchors = self.net(images)

            #-----------------------------------------------------------#
            #   Decode the predictions and run non-max suppression
            #-----------------------------------------------------------#
            regression = decodebox(regression, anchors, images)
            detection = torch.cat([regression, classification], axis=-1)
            batch_detections = non_max_suppression(detection,
                                                   len(self.class_names),
                                                   conf_thres=self.confidence,
                                                   nms_thres=self.iou)
            #--------------------------------------#
            #   Nothing detected: return the original image. Only the errors
            #   produced by an empty/None result are caught (the former bare
            #   ``except:`` also hid genuine bugs and KeyboardInterrupt).
            #--------------------------------------#
            try:
                batch_detections = batch_detections[0].cpu().numpy()
            except (AttributeError, IndexError, TypeError):
                return image

            #-----------------------------------------------------------#
            #   Keep boxes whose score (column 4) exceeds self.confidence;
            #   the class id is read from the last column.
            #-----------------------------------------------------------#
            top_index = batch_detections[:, 4] > self.confidence
            top_conf = batch_detections[top_index, 4]
            top_label = np.array(batch_detections[top_index, -1], np.int32)
            top_bboxes = np.array(batch_detections[top_index, :4])
            top_xmin = np.expand_dims(top_bboxes[:, 0], -1)
            top_ymin = np.expand_dims(top_bboxes[:, 1], -1)
            top_xmax = np.expand_dims(top_bboxes[:, 2], -1)
            top_ymax = np.expand_dims(top_bboxes[:, 3], -1)

            #-----------------------------------------------------------#
            #   Undo the letterbox padding
            #-----------------------------------------------------------#
            boxes = retinanet_correct_boxes(
                top_ymin, top_xmin, top_ymax, top_xmax,
                np.array([self.input_shape[0], self.input_shape[1]]),
                image_shape)

        self._draw_detections(image, top_label, top_conf, boxes)
        return image

    @staticmethod
    def _text_size(draw, text, font):
        """Return the (width, height) of ``text``; works on old and new Pillow.

        ``ImageDraw.textsize`` was removed in Pillow 10, so fall back to the
        bottom-right corner of ``textbbox`` anchored at the origin.
        """
        if hasattr(draw, 'textsize'):
            return draw.textsize(text, font)
        _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
        return right, bottom

    def _draw_detections(self, image, top_label, top_conf, boxes):
        """Draw a rectangle plus "<class> <score>" label per detection in place."""
        image_h, image_w = np.shape(image)[0], np.shape(image)[1]
        # Clamp to 1 so very small images don't request a zero-size font.
        font = ImageFont.truetype(
            font='model_data/simhei.ttf',
            size=max(1, int(np.floor(3e-2 * image_w + 0.5))))
        thickness = max((image_h + image_w) // self.input_shape[0], 1)

        draw = ImageDraw.Draw(image)
        for i, c in enumerate(top_label):
            predicted_class = self.class_names[c]
            score = top_conf[i]
            # Index colours by class id directly (class_names.index was O(n)
            # and picked the wrong colour for duplicate names).
            color = self.colors[c]

            # Grow each box by 5 px, round to int, and clip to the image.
            top, left, bottom, right = boxes[i]
            top = top - 5
            left = left - 5
            bottom = bottom + 5
            right = right + 5
            top = max(0, np.floor(top + 0.5).astype('int32'))
            left = max(0, np.floor(left + 0.5).astype('int32'))
            bottom = min(image_h, np.floor(bottom + 0.5).astype('int32'))
            right = min(image_w, np.floor(right + 0.5).astype('int32'))

            label = '{} {:.2f}'.format(predicted_class, score)
            label_size = self._text_size(draw, label, font)
            print(label, top, left, bottom, right)

            # Put the label above the box, or just inside it if no room above.
            if top - label_size[1] >= 0:
                text_origin = np.array([left, top - label_size[1]])
            else:
                text_origin = np.array([left, top + 1])

            # Thick outline = several nested 1-px rectangles.
            for offset in range(thickness):
                draw.rectangle(
                    [left + offset, top + offset, right - offset, bottom - offset],
                    outline=color)
            draw.rectangle(
                [tuple(text_origin), tuple(text_origin + label_size)],
                fill=color)
            draw.text(tuple(text_origin), label, fill=(0, 0, 0), font=font)
        del draw
Ejemplo n.º 3
0
    #--------------------------------------------#
    #   Before training, make sure the txt file that
    #   classes_path points to lists the classes
    #   you actually want to detect (one per line).
    #--------------------------------------------#
    classes_path = 'model_data/ship_classes.txt'
    #--------------------------------------------#
    #   Read the class names and their count
    #--------------------------------------------#
    class_names = get_classes(classes_path)
    num_classes = len(class_names)

    #----------------------------------------------------#
    #   Build the Retinanet model. `phi` is defined earlier
    #   in this scope (outside this view).
    #   NOTE(review): meaning of the third argument (False)
    #   is not visible here — confirm in nets.retinanet.
    #----------------------------------------------------#
    model = Retinanet(num_classes, phi, False)

    #----------------------------------------------------#
    #   See the README for the weight file (Baidu Netdisk
    #   download). Uncomment the lines below to start from
    #   pretrained weights; only tensors whose shape matches
    #   the model's are copied.
    #----------------------------------------------------#
    # model_path = "model_data/retinanet_resnet50.pth"
    # print('Loading weights into state dict...')
    # model_dict = model.state_dict()
    # pretrained_dict = torch.load(model_path)
    # pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) ==  np.shape(v)}
    # model_dict.update(pretrained_dict)
    # model.load_state_dict(model_dict)
    # print('Finished!')

    # .train() switches the module to training mode and returns it.
    net = model.train()
Ejemplo n.º 4
0
    Cuda = True
    #--------------------------------------------#
    #   Input image size
    #--------------------------------------------#
    input_shape = (600, 600)
    annotation_path = '2007_train.txt'

    #--------------------------------------------#
    #   Edit this txt to train on your own classes
    #--------------------------------------------#
    classes_path = 'model_data/voc_classes.txt'
    class_names = get_classes(classes_path)
    num_classes = len(class_names)

    # Build the model; `phi` is defined earlier in this scope (outside this view).
    model = Retinanet(num_classes, phi, False)

    #-------------------------------------------#
    #   See the README for downloading the weights
    #-------------------------------------------#
    model_path = "model_data/retinanet_resnet50.pth"
    # Start from pretrained weights to speed up training. Only tensors whose
    # key exists in the model AND whose shape matches are copied, so a
    # checkpoint with extra keys or a different class count still loads
    # (previously an unknown key raised KeyError).
    print('Loading weights into state dict...')
    model_dict = model.state_dict()
    # map_location='cpu' lets checkpoints saved from a GPU load on any machine;
    # the model itself is still on the CPU at this point.
    pretrained_dict = torch.load(model_path, map_location='cpu')
    pretrained_dict = {k: v for k, v in pretrained_dict.items()
                       if k in model_dict and np.shape(model_dict[k]) == np.shape(v)}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    print('Finished!')

    # .train() switches the module to training mode and returns it.
    net = model.train()
Ejemplo n.º 5
0
import torch
from nets.retinanet import Retinanet
from nets.retinanet import Resnet

if __name__ == '__main__':
    # Build the network (80 classes, phi=2) and report its parameter count.
    # No forward pass is run, so no dummy input tensor is needed.
    model = Retinanet(80, 2)
    print('# generator parameters:', sum(param.numel() for param in model.parameters()))
    
Ejemplo n.º 6
0
#--------------------------------------------#
#   该部分代码只用于看网络结构,并非测试代码
#   map测试请看get_dr_txt.py、get_gt_txt.py
#   和get_map.py
#--------------------------------------------#
import torch
from nets.retinanet import Retinanet
from nets.retinanet import Resnet

if __name__ == '__main__':
    # Build the network (80 classes, phi=2) and report its parameter count.
    # No forward pass is run, so no dummy input tensor is needed.
    model = Retinanet(80, 2)
    print('# generator parameters:',
          sum(param.numel() for param in model.parameters()))