Exemplo n.º 1
0
class DarknetModel(object):
    """Object detector wrapping a Darknet network with fixed settings.

    The constructor builds the network from ``cfg/yolov3.cfg``, loads the
    weights stored in ``yolov3.pkl`` and switches the model to eval mode.
    Class names and colours are read from ``data/coco.names`` and
    ``data/pallete``. ``load_classes``, ``load_colors``, ``Darknet``,
    ``prep_image``, ``sift_results`` and ``write`` are defined elsewhere.
    """

    def __init__(self):
        """Load labels, colours and the network together with its weights.

        Raises:
            ValueError: if the input resolution is not a multiple of 32
                or is not greater than 32.
        """
        self.scales = "1,2,3"
        self.batch_size = 1
        self.confidence = 0.5
        self.nms_thesh = 0.4
        self.reso = 416
        self.CUDA = False
        self.num_classes = 80
        self.classes = load_classes('data/coco.names')
        self.colors = load_colors('data/pallete')
        self.model = Darknet('cfg/yolov3.cfg', self.reso)
        # Deserialize onto the CPU so a checkpoint saved from a GPU still
        # loads on a CPU-only host; the model is moved to the GPU below
        # when self.CUDA is set.
        state = torch.load('yolov3.pkl', map_location='cpu')
        self.model.load_state_dict(state)
        self.inp_dim = self.reso
        # Explicit raises instead of assert, which is stripped under -O.
        if self.inp_dim % 32 != 0:
            raise ValueError(
                'Input resolution must be a multiple of 32, got %r'
                % (self.inp_dim,))
        if self.inp_dim <= 32:
            raise ValueError(
                'Input resolution must be greater than 32, got %r'
                % (self.inp_dim,))
        if self.CUDA:
            self.model.cuda()
        self.model.eval()

    def predict(self, filename):
        """Run detection on one image file and return the processed image.

        Args:
            filename: path of the image, read with ``cv2.imread``.

        Returns:
            The ``orig_im`` value produced by ``prep_image`` after
            ``write`` has been called on it once per detection row
            (presumably drawing the boxes in place -- confirm against
            ``write``).

        Raises:
            IOError: if ``cv2.imread`` cannot read ``filename``.
        """
        image = cv2.imread(filename)
        if image is None:
            # cv2.imread reports a missing or undecodable file by
            # returning None rather than raising.
            raise IOError('Could not read image: %r' % (filename,))

        img, orig_im, _ = prep_image(image, self.inp_dim)
        if self.CUDA:
            img = img.cuda()

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            output = self.model(img)
        output = sift_results(output, self.confidence, self.num_classes,
                              nms=True, nms_conf=self.nms_thesh)

        # NOTE(review): sift_results is not visible here. If it signals
        # "no detections" with a non-tensor sentinel, return the image
        # untouched instead of failing on the indexing below -- confirm.
        if not torch.is_tensor(output):
            return orig_im

        # Columns 1-4 are clamped to the network input size and normalised
        # to [0, 1], then columns 1/3 are scaled by the image width and
        # columns 2/4 by the image height.
        output[:, 1:5] = torch.clamp(
            output[:, 1:5], 0.0, float(self.inp_dim)) / self.inp_dim
        output[:, [1, 3]] *= image.shape[1]
        output[:, [2, 4]] *= image.shape[0]

        for detection in output:
            write(detection, orig_im, self.classes, self.colors)
        return orig_im
Exemplo n.º 2
0
if __name__ == '__main__':
    # Script prologue: parse the command line, build the Darknet model,
    # load its weights and open the video source. arg_parse, load_classes,
    # load_colors and Darknet are defined elsewhere.
    num_classes = 80
    bbox_attrs = 5 + num_classes

    args = arg_parse()
    confidence = float(args.confidence)
    nms_thesh = float(args.nms_thresh)
    start = 0
    CUDA = torch.cuda.is_available()
    classes = load_classes('data/coco.names')
    colors = load_colors('data/pallete')

    model = Darknet(args.cfgfile, height=args.reso)
    # Deserialize onto the CPU so weights saved from a GPU still load when
    # CUDA is unavailable; the model is moved to the GPU below if present.
    model.load_state_dict(torch.load(args.weightsfile, map_location='cpu'))

    model.net_info["height"] = args.reso
    inp_dim = int(model.net_info["height"])

    # The resolution comes from the command line, so validate it with
    # explicit raises: assert statements are stripped under -O.
    if inp_dim % 32 != 0:
        raise ValueError(
            'Input resolution must be a multiple of 32, got %r' % (inp_dim,))
    if inp_dim <= 32:
        raise ValueError(
            'Input resolution must be greater than 32, got %r' % (inp_dim,))

    if CUDA:
        model.cuda()

    model.eval()

    cap = cv2.VideoCapture(args.video)

    if not cap.isOpened():
        raise IOError('Cannot capture source')