def test_voc():
    """Evaluate a trained SSD-300 checkpoint on the VOC2007 test split.

    The checkpoint path, VOC root, output folder, CUDA flag and visual
    threshold all come from the module-level ``args``.
    """
    # One extra class slot is reserved for the background.
    class_count = len(VOC_CLASSES) + 1
    model = build_ssd('test', 300, class_count)
    # Load on the CPU so a GPU-saved checkpoint also works without CUDA.
    state = torch.load(args.trained_model, map_location='cpu')
    model.load_state_dict(state)
    model.eval()
    print('Finished loading model!')

    # No dataset-level transform: test_net receives its own BaseTransform.
    voc_test = VOCDetection(args.voc_root, [('2007', 'test')], None,
                            VOCAnnotationTransform())
    if args.cuda:
        model = model.cuda()
        cudnn.benchmark = True

    transform = BaseTransform(model.size, (104, 117, 123))
    test_net(args.save_folder, model, args.cuda, voc_test, transform,
             thresh=args.visual_threshold)
Example #2
0
def detection_video(pipeline, weight):  # detection on a RealSense stream
    """Detect garbage / bins in a RealSense colour stream and track the target.

    The loop is a small state machine driven by the module-level ``init_num``:

    * ``0``: run SSD on every third frame and pick, among the confident
      "cup"/"battery"/"bottle"/"orange"/"paper" boxes that are larger than
      5000 px and whose centre is not near the top/left edge, the one whose
      centre has the largest x + y. Its distance/angle are reported through
      ``D415_Depth`` / ``Angle``, its label is stored in the global
      ``class_name``, a MOSSE tracker is initialised on it and the state
      switches to ``1``.
    * ``1``: update the MOSSE tracker only and report distance/angle of the
      tracked box.
    * ``2``: run SSD on every third frame looking for the bin whose colour
      label matches ``class_name`` and report its distance/angle. Nothing in
      this function enters this state; presumably another part of the
      program sets ``init_num = 2``. TODO confirm.

    Args:
        pipeline: RealSense pipeline whose ``wait_for_frames()`` yields
            framesets with depth and colour frames. It is not stopped here.
        weight: path of the SSD weights passed to ``net.load_weights``.

    Press ``q`` in the "result" window to stop.
    """
    global image_size, depth_image, init_num, class_name
    flag = 0
    net = build_ssd('test', 300, num_classes)
    net.eval()
    net.load_weights(weight)  # load the trained parameters
    # Send inputs to the device that holds the weights so the model and its
    # input can never end up on different devices.
    device = next(net.parameters()).device
    init_num = 0
    colors = cfg.COLORS
    color = colors[0]  # tracker box colour until a detection picks one
    garbage_labels = ("cup", "battery", "bottle", "orange", "paper")
    # Bin colour label that belongs to each kind of garbage (used in state 2).
    bin_for_garbage = {"cup": "brown", "battery": "red", "bottle": "black",
                       "orange": "green", "paper": "black"}

    def run_ssd(frame):
        """Run SSD on a BGR frame; return (annotatable copy, detections, scale)."""
        # Resize to 300x300, subtract the BGR channel means, flip BGR->RGB,
        # then HWC -> 1xCxHxW.
        blob = cv2.resize(frame, (300, 300)).astype(np.float32)
        blob -= (104, 117, 123)
        blob = torch.from_numpy(blob[:, :, ::-1].copy())
        blob = blob.permute(2, 0, 1).unsqueeze(0).to(device)
        with torch.no_grad():  # inference only: don't build an autograd graph
            detections = net(blob).data
        # Boxes come out in relative coordinates; scale = (w, h, w, h).
        scale = torch.Tensor(frame.shape[1::-1]).repeat(2)
        return frame.copy(), detections, scale

    def top_boxes(detections, scale, frame, threshold=0.45):
        """Yield (label, score, box) for each class whose top detection >= threshold.

        ``box`` is (x1, y1, x2, y2) in pixels, clipped to ``frame``.
        """
        height, width = frame.shape[:2]
        for i in range(detections.size(1)):
            score = detections[0, i, 0, 0]
            if score >= threshold:
                pt = (detections[0, i, 0, 1:] * scale).cpu().numpy()
                pt[0] = max(pt[0], 0)
                pt[1] = max(pt[1], 0)
                pt[2] = min(pt[2], width)   # x is bounded by the width
                pt[3] = min(pt[3], height)  # y is bounded by the height
                yield labels[i - 1], score, pt

    def draw_labelled_box(canvas, pt, text, box_color):
        """Draw a box and its label (in the inverse colour) on ``canvas``."""
        x1, y1 = int(pt[0]), int(pt[1])
        cv2.rectangle(canvas, (x1, y1), (int(pt[2]), int(pt[3])), box_color, 4)
        cv2.putText(canvas, text, (x1 + 4, y1), cv2.FONT_HERSHEY_COMPLEX, 1,
                    (255 - box_color[0], 255 - box_color[1],
                     255 - box_color[2]), 2)

    while True:
        frames = pipeline.wait_for_frames()  # next frameset
        depth_frame = frames.get_depth_frame()
        color_frame = frames.get_color_frame()
        if not depth_frame or not color_frame:
            continue
        # depth_image is a global, presumably read by D415_Depth. TODO confirm.
        depth_image = np.asanyarray(depth_frame.get_data())
        color_image = np.asanyarray(color_frame.get_data())

        if init_num == 0:  # search for a garbage object
            flag += 1
            if flag % 3 != 0:  # every third frame only: the Jetson Nano is too slow for more
                continue

            t0 = time.time()
            rgb_image, detections, scale = run_ssd(color_image)

            idx_obj = -1
            center_point = [0, 0]
            best_box = None
            for label_name, score, pt in top_boxes(detections, scale,
                                                   color_image):
                idx_obj += 1
                cx = (pt[0] + pt[2]) / 2
                cy = (pt[1] + pt[3]) / 2
                area = (pt[2] - pt[0]) * (pt[3] - pt[1])
                # Keep the garbage box with the largest cx + cy, ignoring
                # boxes near the top/left edges and small ones.
                if (label_name in garbage_labels and cx > 100 and cy > 140
                        and cx + cy > center_point[0] + center_point[1]
                        and area > 5000):
                    center_point = [cx, cy]
                    best_box = [pt[0], pt[1], pt[2], pt[3]]
                    class_name = label_name  # remembered for the bin search

                color = colors[idx_obj % len(colors)]
                draw_labelled_box(rgb_image, pt,
                                  '%s %.2f' % (label_name, score), color)

            if best_box is not None:
                distance = D415_Depth(best_box)
                Angle(center_point, distance)
                track_roi = (best_box[0], best_box[1],
                             abs(best_box[2] - best_box[0]),
                             abs(best_box[3] - best_box[1]))
                print("track_roi:", track_roi)
                tracker_rgb = cv2.TrackerMOSSE_create()  # reset the tracker
                # Initialise on the clean frame: rgb_image already carries the
                # drawn boxes and labels, which would pollute the template.
                tracker_rgb.init(color_image, track_roi)
                # Switch to tracking only once a tracker actually exists.
                init_num = 1

        elif init_num == 1:  # follow the object with the tracker only
            t0 = time.time()
            rgb_image = color_image.copy()
            success, box = tracker_rgb.update(rgb_image)
            if success:
                x, y, w, h = [int(v) for v in box]
                tracked_box = [x, y, x + w, y + h]
                center_point = [x + w / 2, y + h / 2]
                distance = D415_Depth(tracked_box)
                Angle(center_point, distance)
                # Pass two corners: a single 4-tuple is read as (x, y, w, h).
                cv2.rectangle(rgb_image, (x, y), (x + w, y + h), color, 4)

        elif init_num == 2:  # look for the bin that matches class_name
            flag += 1
            if flag % 3 != 0:  # every third frame only
                continue

            t0 = time.time()
            rgb_image, detections, scale = run_ssd(color_image)

            idx_obj = -1  # never incremented here: every box gets one colour
            center_point = [0, 0]
            best_box = None
            for label_name, score, pt in top_boxes(detections, scale,
                                                   color_image):
                area = (pt[2] - pt[0]) * (pt[3] - pt[1])
                if area > 5000 and bin_for_garbage.get(class_name) == label_name:
                    best_box = [pt[0], pt[1], pt[2], pt[3]]
                    center_point = [(pt[0] + pt[2]) / 2,
                                    (pt[1] + pt[3]) / 2]
                    print(center_point)

                color = colors[idx_obj % len(colors)]
                draw_labelled_box(rgb_image, pt,
                                  '%s %.2f' % (label_name, score), color)

            if best_box is not None:
                distance = D415_Depth(best_box)
                Angle(center_point, distance)

        # Guard against a zero interval on very fast (tracking) frames.
        elapsed = max(time.time() - t0, 1e-6)
        cv2.putText(rgb_image, "FPS: %.2f" % (1 / elapsed), (5, 30),
                    cv2.FONT_HERSHEY_COMPLEX, 1.2, (255, 255, 255), 2)
        cv2.imshow("result", rgb_image)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            # There is no VideoCapture in this variant; close the window and
            # leave the loop. Stopping the pipeline is left to the caller.
            cv2.destroyAllWindows()
            break
def train():
    """Train the MobileNet-backbone SSD configured by the global ``args``.

    Builds the COCO or VOC dataset, the network, an SGD optimiser and the
    MultiBox loss, then iterates from ``args.start_iter`` to
    ``cfg['max_iter']``:

    * the learning rate is decayed at every step in ``cfg['lr_steps']``;
    * the loss is printed every 10 iterations (and plotted if ``args.visdom``);
    * ``weights/ssd_mobilenetv2/mobilenetv2_<iter>.pth`` is saved every 5000
      iterations and ``mobilenetv2_final.pth`` at the end.
    """
    if args.dataset == 'COCO':
        if args.dataset_root == VOC_ROOT:
            if not os.path.exists(COCO_ROOT):
                parser.error('Must specify dataset_root if specifying dataset')
            print("WARNING: Using default COCO dataset_root because " +
                  "--dataset_root was not specified.")
            args.dataset_root = COCO_ROOT
        cfg = coco
        dataset = COCODetection(root=args.dataset_root,
                                transform=SSDAugmentation(
                                    cfg['min_dim'], MEANS))
    elif args.dataset == 'VOC':
        if args.dataset_root == COCO_ROOT:
            parser.error('Must specify dataset if specifying dataset_root')
        cfg = voc
        dataset = VOCDetection(root=args.dataset_root,
                               transform=SSDAugmentation(
                                   cfg['min_dim'], MEANS))
    else:
        # Without this, `cfg` and `dataset` would be unbound below.
        parser.error('Unsupported dataset: %s (expected COCO or VOC)'
                     % args.dataset)

    ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
    net = ssd_net

    if args.cuda:
        net = torch.nn.DataParallel(ssd_net)
        cudnn.benchmark = True

    if args.resume:
        print('Resuming training, loading {}...'.format(args.resume))
        ssd_net.load_weights(args.resume)
    else:
        # Wrap the backbone temporarily so its parameter names match the
        # checkpoint (presumably saved from a DataParallel model, i.e. with
        # "module." prefixes. TODO confirm).
        ssd_net.mobilenet = nn.DataParallel(ssd_net.mobilenet)
        # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only
        # machine; load_state_dict copies onto the parameters' own device.
        mobile_weights = torch.load(args.save_folder + args.basenet,
                                    map_location='cpu')
        print('Loading base network...')
        ssd_net.mobilenet.load_state_dict(mobile_weights['state_dict'])
        ssd_net.mobilenet = ssd_net.mobilenet.module  # drop the wrapper again
        # The pretrained backbone is deliberately NOT passed through
        # weights_init: that would overwrite the weights just loaded.

    if args.cuda:
        net = net.cuda()

    if not args.resume:
        print('Initializing weights...')
        # initialize newly added layers' weights with xavier method
        ssd_net.extras.apply(weights_init)

    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    criterion = MultiBoxLoss(cfg['num_classes'], 0.5, True, 0, True, 3, 0.5,
                             False, args.cuda)

    net.train()
    # Loss sums over the current epoch (used for the visdom epoch plot).
    loc_loss = 0
    conf_loss = 0
    epoch = 0
    print('Loading the dataset...')

    # At least 1 so the modulo/division below cannot fail on a tiny dataset.
    epoch_size = max(1, len(dataset) // args.batch_size)
    print('Training SSD on:', dataset.name)
    print('Using the specified args:')
    print(args)

    step_index = 0

    if args.visdom:
        vis_title = 'SSD.PyTorch on ' + dataset.name
        vis_legend = ['Loc Loss', 'Conf Loss', 'Total Loss']
        iter_plot = create_vis_plot('Iteration', 'Loss', vis_title, vis_legend)
        epoch_plot = create_vis_plot('Epoch', 'Loss', vis_title, vis_legend)

    data_loader = data.DataLoader(dataset,
                                  args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True,
                                  collate_fn=detection_collate,
                                  pin_memory=True,
                                  worker_init_fn=worker_init_fn)
    batch_iterator = iter(data_loader)

    t0 = time.time()
    for iteration in range(args.start_iter, cfg['max_iter']):
        if iteration != 0 and iteration % epoch_size == 0:
            # Count epochs even without visdom: the progress line prints it.
            epoch += 1
            if args.visdom:
                update_vis_plot(epoch, loc_loss, conf_loss, epoch_plot, None,
                                'append', epoch_size)
            loc_loss = 0
            conf_loss = 0

        if iteration in cfg['lr_steps']:
            step_index += 1
            adjust_learning_rate(optimizer, args.gamma, step_index)

        # Load the next batch, restarting the loader after each full pass.
        try:
            images, targets = next(batch_iterator)
        except StopIteration:
            batch_iterator = iter(data_loader)
            images, targets = next(batch_iterator)

        if args.cuda:
            images = images.cuda()
            targets = [ann.cuda() for ann in targets]

        # Forward, loss, backprop.
        optimizer.zero_grad()
        out = net(images)
        loss_l, loss_c = criterion(out, targets)
        loss = loss_l + loss_c
        loss.backward()
        optimizer.step()

        loc_loss += loss_l.item()
        conf_loss += loss_c.item()

        if iteration % 10 == 0:
            print('[%3d/%d] iter ' % (epoch + args.start_iter // epoch_size,
                                      int(cfg['max_iter']) // epoch_size) +
                  repr(iteration) + ' || Loss: %.4f ||' % (loss.item()),
                  end=' ')
            print('timer: %.4f sec.' % (time.time() - t0))
            t0 = time.time()
        if args.visdom:
            update_vis_plot(iteration, loss_l.item(), loss_c.item(), iter_plot,
                            epoch_plot, 'append')

        if iteration != 0 and iteration % 5000 == 0:
            print('Saving state, iter:', iteration)
            torch.save(
                ssd_net.state_dict(), 'weights/ssd_mobilenetv2/mobilenetv2_' +
                repr(iteration) + '.pth')
    torch.save(ssd_net.state_dict(),
               'weights/ssd_mobilenetv2/mobilenetv2_final' + '.pth')
Example #4
0
def detection_video(path, weight):  # detection (+ tracking) on a video file
    """Run SSD detection, and optionally MOSSE tracking, on a video file.

    State lives in module-level globals:

    * ``init_num == 0``: every third frame goes through SSD. Detections with
      score >= 0.95 and area > 500 px are drawn, and a MOSSE tracker (global
      ``tracker_rgb``) is (re)initialised on the best box.
    * ``init_num == 1``: the tracker is updated instead. After 10 s (measured
      from ``t4``) the state falls back to 0.

    Nothing in this function switches to state 1 (that assignment is
    commented out in the source history), so presumably that happens
    elsewhere. TODO confirm.

    Args:
        path: video path (anything ``cv2.VideoCapture`` accepts).
        weight: SSD weights file passed to ``net.load_weights``.

    Press ``q`` in the "result" window to stop early.
    """
    global image_size, tracker_rgb, init_num
    flag = 0
    net = build_ssd('test', 300, num_classes)
    net.eval()
    net.load_weights(weight)  # load the trained parameters
    # Send inputs to the device that holds the weights.
    device = next(net.parameters()).device

    cap = cv2.VideoCapture(path)
    init_num = 0
    t4 = time.time()  # start of the current 10 s tracking window
    colors = cfg.COLORS
    color = colors[0]  # tracker box colour until a detection picks one

    try:
        while cap.isOpened():
            ret, image = cap.read()
            # Check before either state touches the frame; the tracking
            # state would otherwise crash on the None at end of video.
            if not ret:
                print("video is over!")
                break

            if init_num == 0:  # detection state
                flag += 1
                if flag % 3 != 0:  # every third frame only: the Jetson Nano is too slow for more
                    continue

                t0 = time.time()
                # SSD input: 300x300, BGR means subtracted, BGR->RGB, 1xCxHxW.
                blob = cv2.resize(image, (300, 300)).astype(np.float32)
                blob -= (104, 117, 123)
                blob = torch.from_numpy(blob[:, :, ::-1].copy()).permute(2, 0, 1)
                blob = blob.unsqueeze(0).to(device)
                with torch.no_grad():  # inference only
                    detections = net(blob).data

                height, width = image.shape[:2]
                # Boxes come out relative; scale = (w, h, w, h).
                scale = torch.Tensor([width, height, width, height])
                rgb_image = image.copy()  # annotated copy for display

                idx_obj = -1
                center_point = [0, 0]
                best_box = None
                for i in range(detections.size(1)):
                    score = detections[0, i, 0, 0]  # top detection of class i
                    if not score >= 0.95:
                        continue
                    idx_obj += 1
                    label_name = labels[i - 1]
                    display_txt = '%s %.2f' % (label_name, score)
                    pt = (detections[0, i, 0, 1:] * scale).cpu().numpy()

                    # Clip x to the frame width and y to the frame height.
                    pt[0] = max(pt[0], 0)
                    pt[1] = max(pt[1], 0)
                    pt[2] = min(pt[2], width)
                    pt[3] = min(pt[3], height)

                    area = abs(pt[2] - pt[0]) * abs(pt[3] - pt[1])
                    if not area > 500:
                        print("error", (pt[2] - pt[0]) * (pt[3] - pt[1]))
                        continue
                    print((pt[2] - pt[0]) * (pt[3] - pt[1]))

                    cx = (pt[0] + pt[2]) / 2
                    cy = (pt[1] + pt[3]) / 2
                    # Best box: centre away from the top/left edges and with
                    # the largest cx + cy seen so far in this frame.
                    if (cx > 100 and cy > 140
                            and cx + cy > center_point[0] + center_point[1]):
                        center_point = [cx, cy]
                        best_box = [pt[0], pt[1], pt[2], pt[3]]

                    color = colors[idx_obj % len(colors)]
                    cv2.rectangle(rgb_image, (int(pt[0]), int(pt[1])),
                                  (int(pt[2]), int(pt[3])), color, 4)
                    cv2.putText(rgb_image, display_txt,
                                (int(pt[0]) + 4, int(pt[1])),
                                cv2.FONT_HERSHEY_COMPLEX, 1,
                                (255 - color[0], 255 - color[1],
                                 255 - color[2]), 2)

                if best_box is not None:
                    x1, y1, x2, y2 = best_box
                    track_roi = (x1, y1, abs(x2 - x1), abs(y2 - y1))
                    print("track_roi:", track_roi)
                    try:
                        tracker_rgb = cv2.TrackerMOSSE_create()  # reset
                        # Initialise on the clean frame, not the annotated copy.
                        tracker_rgb.init(image, track_roi)
                    except Exception as err:
                        # Best effort (e.g. OpenCV builds where
                        # TrackerMOSSE_create is missing), but report it
                        # instead of silently ignoring it.
                        print("tracker init failed:", err)

            elif init_num == 1:  # tracking state
                t0 = time.time()
                rgb_image = image.copy()
                success, box = tracker_rgb.update(rgb_image)
                if time.time() - t4 > 10:  # back to detection after 10 s
                    init_num = 0
                    t4 = time.time()
                if success:
                    x, y, w, h = [int(v) for v in box]
                    # Pass two corners: a single 4-tuple is read as (x, y, w, h).
                    cv2.rectangle(rgb_image, (x, y), (x + w, y + h), color, 4)

            elapsed = max(time.time() - t0, 1e-6)  # avoid division by zero
            cv2.putText(rgb_image, "FPS: %.2f" % (1 / elapsed), (5, 30),
                        cv2.FONT_HERSHEY_COMPLEX, 1.2, (255, 255, 255), 2)
            cv2.imshow("result", rgb_image)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Free the capture on every exit path (end of video, 'q', errors).
        cap.release()
        cv2.destroyAllWindows()
def detection_image(path, weight):
    """Detect objects in one image file and write an annotated copy.

    Every detection scoring >= 0.45 is drawn as a coloured box with a filled
    label banner and an 'x' near the box centre. The result is written to
    ``path`` with "test_images" replaced by "out_images"; that path is also
    printed. If ``path`` does not contain "test_images", the input file
    itself is overwritten.

    Args:
        path: image file to read with ``cv2.imread``.
        weight: SSD weights file passed to ``net.load_weights``.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``path``.
    """
    global image_size

    net = build_ssd('test', 300, num_classes)
    net.eval()
    net.load_weights(weight)
    # Send inputs to the device that holds the weights.
    device = next(net.parameters()).device

    image = cv2.imread(path, cv2.IMREAD_COLOR)
    if image is None:  # imread reports failure by returning None
        raise FileNotFoundError('cannot read image: %s' % path)

    # SSD input: 300x300, BGR channel means subtracted, BGR->RGB, 1xCxHxW.
    resize_image = cv2.resize(image, (300, 300)).astype(np.float32)
    resize_image -= (104, 117, 123)
    resize_image = resize_image[:, :, ::-1].copy()
    input_image = torch.from_numpy(resize_image).permute(2, 0, 1).unsqueeze(0)
    input_image = input_image.to(device)

    with torch.no_grad():  # inference only: no autograd graph needed
        detections = net(input_image).data

    colors = cfg.COLORS
    # Boxes come out in relative coordinates; scale = (w, h, w, h).
    scale = torch.Tensor(image.shape[1::-1]).repeat(2)
    canvas = image.copy()  # annotate a copy, keep the input frame untouched
    height, width = canvas.shape[:2]

    idx_obj = -1
    per_class = detections.size(2)  # detection slots per class
    for i in range(detections.size(1)):
        j = 0
        # Walk class i's detections (presumably sorted by score) while they
        # pass the threshold; the bound prevents an IndexError when every
        # slot passes.
        while j < per_class and detections[0, i, j, 0] >= 0.45:
            idx_obj += 1

            score = detections[0, i, j, 0]
            label_name = labels[i - 1]
            display_txt = '%s %.2f' % (label_name, score)
            pt = (detections[0, i, j, 1:] * scale).cpu().numpy()
            j += 1

            # Clip x to the width and y to the height.
            pt[0] = max(pt[0], 0)
            pt[1] = max(pt[1], 0)
            pt[2] = min(pt[2], width)
            pt[3] = min(pt[3], height)

            color = colors[idx_obj % len(colors)]
            textsize = cv2.getTextSize(display_txt, cv2.FONT_HERSHEY_COMPLEX,
                                       1, 2)[0]

            x1, y1 = int(pt[0]), int(pt[1])
            text_x = x1
            text_y = y1
            if y1 - textsize[1] < 0:
                # No room above the box: put the label banner inside it.
                text_y = y1 + textsize[1] + 2
                cv2.rectangle(canvas, (x1, y1),
                              (x1 + textsize[0] + 8, y1 + textsize[1] + 10),
                              (color[0], color[1], color[2], 125), -1)
            else:
                text_y -= 6
                cv2.rectangle(canvas, (x1 - 2, y1 - textsize[1] - 10),
                              (x1 + textsize[0] + 8, y1),
                              (color[0], color[1], color[2], 125), -1)

            cv2.rectangle(canvas, (x1, y1), (int(pt[2]), int(pt[3])), color, 4)
            cv2.putText(canvas, display_txt, (text_x + 4, text_y),
                        cv2.FONT_HERSHEY_COMPLEX, 1,
                        (255 - color[0], 255 - color[1], 255 - color[2]), 2)
            # Mark (approximately) the box centre.
            cv2.putText(
                canvas, 'x',
                (int((pt[2] + pt[0]) // 2 - 5), int((pt[3] + pt[1]) // 2)),
                cv2.FONT_HERSHEY_COMPLEX, 1, color)

    out_path = path.replace("test_images", "out_images")
    print(out_path)
    cv2.imwrite(out_path, canvas)
def detection_video(path, weight):
    """Run SSD on every third frame of a video and save an annotated AVI.

    Output goes to ``output_videos/out_<name>.avi`` (MP42 codec, the input's
    fps and frame size), where ``<name>`` is the part of the file name of
    ``path`` before its first dot. Detections scoring >= 0.45 are drawn with
    a filled label banner, and the processing FPS is overlaid. Pressing ``q``
    stops early.

    Args:
        path: video path; "/" is used as the directory separator when
            deriving the output name.
        weight: SSD weights file passed to ``net.load_weights``.
    """
    global image_size

    flag = 0
    net = build_ssd('test', 300, num_classes)
    net.eval()
    net.load_weights(weight)
    # Send inputs to the device that holds the weights.
    device = next(net.parameters()).device

    cap = cv2.VideoCapture(path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))  # (width, height)
    fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', '2')
    outVideo = cv2.VideoWriter(
        'output_videos/out_%s.avi' % (path.split("/")[-1].split(".")[0]),
        fourcc, fps, size)
    cv2.namedWindow("result", 0)
    colors = cfg.COLORS

    while cap.isOpened():
        ret, image = cap.read()
        flag += 1

        if not ret:
            print("video is over!")
            # Release before leaving so the AVI container is finalised.
            outVideo.release()
            cap.release()
            break
        if flag % 3 != 0:  # process every third frame only
            continue

        t0 = time.time()
        # SSD input: 300x300, BGR means subtracted, BGR->RGB, 1xCxHxW.
        blob = cv2.resize(image, (300, 300)).astype(np.float32)
        blob -= (104, 117, 123)
        blob = torch.from_numpy(blob[:, :, ::-1].copy()).permute(2, 0, 1)
        blob = blob.unsqueeze(0).to(device)
        with torch.no_grad():  # inference only
            detections = net(blob).data

        height, width = image.shape[:2]
        # Boxes come out relative; scale = (w, h, w, h).
        scale = torch.Tensor([width, height, width, height])
        frame = image.copy()  # annotated copy that gets written out

        idx_obj = -1
        per_class = detections.size(2)  # detection slots per class
        for i in range(detections.size(1)):
            j = 0
            # The bound prevents an IndexError when every slot passes.
            while j < per_class and detections[0, i, j, 0] >= 0.45:
                idx_obj += 1

                score = detections[0, i, j, 0]
                label_name = labels[i - 1]
                display_txt = '%s %.2f' % (label_name, score)
                pt = (detections[0, i, j, 1:] * scale).cpu().numpy()
                j += 1

                # Clip x to the frame width and y to the frame height.
                pt[0] = max(pt[0], 0)
                pt[1] = max(pt[1], 0)
                pt[2] = min(pt[2], width)
                pt[3] = min(pt[3], height)

                color = colors[idx_obj % len(colors)]
                textsize = cv2.getTextSize(display_txt,
                                           cv2.FONT_HERSHEY_COMPLEX, 1, 2)[0]

                x1, y1 = int(pt[0]), int(pt[1])
                text_x = x1
                text_y = y1
                if y1 - textsize[1] < 0:
                    # No room above the box: put the label banner inside it.
                    text_y = y1 + textsize[1] + 2
                    cv2.rectangle(frame, (x1, y1),
                                  (x1 + textsize[0] + 8,
                                   y1 + textsize[1] + 10),
                                  (color[0], color[1], color[2], 125), -1)
                else:
                    text_y -= 6
                    cv2.rectangle(frame, (x1 - 2, y1 - textsize[1] - 10),
                                  (x1 + textsize[0] + 8, y1),
                                  (color[0], color[1], color[2], 125), -1)

                cv2.rectangle(frame, (x1, y1), (int(pt[2]), int(pt[3])),
                              color, 4)
                cv2.putText(frame, display_txt, (text_x + 4, text_y),
                            cv2.FONT_HERSHEY_COMPLEX, 1,
                            (255 - color[0], 255 - color[1], 255 - color[2]),
                            2)

        elapsed = max(time.time() - t0, 1e-6)  # avoid division by zero
        cv2.putText(frame, "FPS: %.2f" % (1 / elapsed), (5, 30),
                    cv2.FONT_HERSHEY_COMPLEX, 1.2, (255, 255, 255), 2)

        outVideo.write(frame)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            outVideo.release()
            cap.release()
            cv2.destroyAllWindows()
            break
        pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)

    print('Evaluating detections')
    evaluate_detections(all_boxes, output_dir, dataset)


def evaluate_detections(box_list, output_dir, dataset):
    """Score a set of detections with the VOC evaluation helpers.

    First ``box_list`` is handed to ``write_voc_results_file`` together with
    ``dataset``; afterwards ``do_python_eval`` is run on ``output_dir``.
    Both helpers are defined elsewhere; nothing is returned.
    """
    # The result files have to be written before the evaluator runs.
    write_voc_results_file(box_list, dataset)
    do_python_eval(output_dir)


if __name__ == '__main__':
    # Evaluate a trained SSD-300 checkpoint on the VOC2007 `set_type` split.
    num_classes = len(labelmap) + 1                      # +1 for background
    net = build_ssd('test', 300, num_classes)            # initialize SSD
    # Load on the CPU so a checkpoint saved from the GPU also loads on a
    # CPU-only machine; the net is moved to the GPU below when args.cuda.
    net.load_state_dict(torch.load(args.trained_model, map_location='cpu'))
    net.eval()
    print('Finished loading model!')
    # load data
    dataset = VOCDetection(args.voc_root, [('2007', set_type)],
                           BaseTransform(300, dataset_mean),
                           VOCAnnotationTransform())
    if args.cuda:
        net = net.cuda()
        cudnn.benchmark = True
    # evaluation
    test_net(args.save_folder, net, args.cuda, dataset,
             BaseTransform(net.size, dataset_mean), args.top_k, 300,
             thresh=args.confidence_threshold)