def show_image(path, det=None):
    """Run ``det.det_mode`` on every file found below each entry of *path*.

    Every entry of *path* is walked recursively with ``os.walk``; each file is
    loaded with ``cv2.imread`` and handed to the detector.

    Args:
        path: directory whose entries are walked (a trailing slash is optional).
        det: detector exposing ``det_mode(image)``. When omitted a new
            ``SsdDet()`` is created for this call instead of sharing one
            instance built at import time.
    """
    if det is None:
        det = SsdDet()
    for dir_name in os.listdir(path):
        # os.walk descends into nested folders, so each file must be joined
        # with the folder the walk is currently in, not the top-level entry.
        for root, _, files in os.walk(os.path.join(path, dir_name)):
            for filename in files:
                img_path = os.path.join(root, filename)
                image = cv2.imread(img_path)
                if image is None:
                    # cv2.imread returns None for unreadable / non-image files.
                    print('skip unreadable file: %s' % img_path)
                    continue
                det.det_mode(image)
Exemple #2
0
def gen_det_txt(net_dic, img_roots, det=None,
                img_dir='/home/remo/from_wdh/data/val2017/',
                predict_dir='/home/remo/from_wdh/data/predict/'):
    """Write detector output files for every ``.xml`` annotation in *img_roots*.

    For each ``*.xml`` file listed (in sorted order) directly inside each root,
    the image ``<img_dir><stem>.jpg`` is loaded and passed to ``det.det_txt``.
    Each key of the returned mapping names a sub-folder of *predict_dir*. The
    paths ``<predict_dir><key>/<stem>.txt`` and the matching values are handed
    to ``save_txt``.

    Args:
        net_dic: unused; kept for call compatibility.
        img_roots: directories scanned for ``.xml`` annotation files.
        det: detector exposing ``det_txt(img)`` that returns a dict. When
            omitted a new ``SsdDet()`` is created for this call.
        img_dir: folder holding the ``.jpg`` images (must end with '/').
        predict_dir: output root for the prediction files (must end with '/').
    """
    if det is None:
        det = SsdDet()
    for img_root in img_roots:
        for img_name in sorted(os.listdir(img_root)):
            if img_name.split(".")[-1] != 'xml':
                continue
            stem = img_name[:-4]  # strip ".xml"
            img_path = img_dir + stem + '.jpg'
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread returns None when the image is missing/unreadable.
                print('cannot read %s' % img_path)
                continue

            res = det.det_txt(img)
            # Build both lists from items() so each path stays paired with its value.
            txt_dirs = []
            txts = []
            for key, value in res.items():
                txt_dirs.append(predict_dir + key + '/' + stem + '.txt')
                txts.append(value)
            save_txt(txt_dirs, txts)
Exemple #3
0
def gen_det_txt(net_dic, img_roots, det=None):
    """Print, per sub-folder, how many ``.jpg`` images yielded no detection.

    For every sub-folder of every root, each ``*.jpg`` file is loaded and
    passed to ``det.det_txt``. A result ``!= 0`` counts as detected. The line
    ``nodet, <not detected>  total <jpg count>`` is printed per sub-folder.

    Args:
        net_dic: unused; kept for call compatibility.
        img_roots: top-level folders (e.g. "cat" / "dog"), each containing
            sub-folders of images.
        det: detector exposing ``det_txt(img)``. When omitted a new
            ``SsdDet()`` is created for this call.
    """
    if det is None:
        det = SsdDet()
    for img_root in img_roots:  # cat or dog
        print(img_root)
        sub_dirs = [img_root + '/' + name for name in os.listdir(img_root)]
        for sub_dir in sub_dirs:  # original note: "zishi" (presumably pose)
            print(sub_dir.split('/')[-1])
            total = 0
            detout = 0
            for img_name in glob.glob(sub_dir + '/*'):
                # Only .jpg files are run through the detector, so only they
                # may count toward the total; otherwise other files such as
                # .xml annotations inflate the "nodet" figure.
                if img_name.split(".")[-1] != 'jpg':
                    continue
                total += 1
                img = cv2.imread(img_name)
                # 2. ------------------- detection -------------------
                if det.det_txt(img) != 0:
                    detout += 1
            print('nodet, %d  total %d' % (total - detout, total))
Exemple #4
0

def gen_det_txt(net_dic, img_roots, det=None):
    """Print, per sub-folder, how many ``.jpg`` images yielded no detection.

    For every sub-folder of every root, each ``*.jpg`` file is loaded and
    passed to ``det.det_txt``. A result ``!= 0`` counts as detected. The line
    ``nodet, <not detected>  total <jpg count>`` is printed per sub-folder.

    Args:
        net_dic: unused; kept for call compatibility.
        img_roots: top-level folders (e.g. "cat" / "dog"), each containing
            sub-folders of images.
        det: detector exposing ``det_txt(img)``. When omitted a new
            ``SsdDet()`` is created for this call.
    """
    if det is None:
        det = SsdDet()
    for img_root in img_roots:  # cat or dog
        print(img_root)
        for sub_name in os.listdir(img_root):  # original note: "zishi" (presumably pose)
            sub_dir = img_root + '/' + sub_name
            print(sub_dir.split('/')[-1])
            # Only .jpg files are detected, so only they make up the total;
            # counting every globbed file would report e.g. .xml files as misses.
            jpgs = [p for p in glob.glob(sub_dir + '/*')
                    if p.split(".")[-1] == 'jpg']
            detout = 0
            for img_name in jpgs:
                # 2. ------------------- detection -------------------
                if det.det_txt(cv2.imread(img_name)) != 0:
                    detout += 1
            total = len(jpgs)
            print('nodet, %d  total %d' % (total - detout, total))


if __name__ == '__main__':
    # Script entry point: fetch the network description and the image roots,
    # initialise a single detector from that description, then hand all three
    # to gen_det_txt. The module-level names are kept as-is on purpose.
    net_dict_info, img_roots = det_models()
    ssd_det = SsdDet()
    ssd_det.det_init(net_dict_info)
    gen_det_txt(net_dict_info, img_roots, det=ssd_det)
def main(videos,
         img_roots,
         flag_video=False,
         flag_img=False,
         flag_cap=False,
         det=None):
    """Run the detector on videos, image folders, or a live camera.

    At most one of ``flag_video``, ``flag_img`` and ``flag_cap`` may be set.
    With none set, nothing is run.

    Args:
        videos: iterable of video entries; each entry is indexed with ``[0]``
            to obtain the video path.
        img_roots: directories whose direct children are scanned in image mode.
        flag_video: play each video with a trackbar while detecting.
        flag_img: detect on the ``.xml``-annotated / ``.jpg`` files of
            ``img_roots``.
        flag_cap: detect on live frames from camera index 1.
        det: detector exposing ``det_mode``, ``det_mode_and_save`` and
            ``img_one``. When omitted a new ``SsdDet()`` is created for this
            call instead of sharing one built at import time.

    Raises:
        AssertionError: if more than one mode flag is set.
    """
    assert flag_video + flag_img + flag_cap <= 1
    if det is None:
        det = SsdDet()
    if flag_video:
        _run_videos(videos, det)
    if flag_img:
        _run_images(img_roots, det)
    if flag_cap:
        _run_capture(det)


def _run_videos(videos, det):
    """Play each video with a trackbar, reporting any playback error as its end."""
    for video_name in videos:
        print(video_name)
        try:
            show_videos_with_trackbar(video_name, det)
        except Exception:
            # Best-effort, as before: an error ends this video and moves on to
            # the next one, but Ctrl-C / SystemExit are no longer swallowed.
            print("%s finished!" % video_name[0])
            cv2.destroyAllWindows()


def _run_images(img_roots, det):
    """Detect on the annotated (.xml) and plain (.jpg) files under each root."""
    for img_root in img_roots:
        for img_name in glob.glob(img_root + '/*'):
            ext = img_name.split(".")[-1]
            if ext == 'xml':
                # glob yields full paths; only the file name locates the
                # matching image inside the val2017 folder.
                stem = os.path.basename(img_name)[:-4]
                img_path = '/home/remo/from_wdh/data/val2017/' + stem + '.jpg'
                img = cv2.imread(img_path)
                if img is None:
                    print('cannot read %s' % img_path)
                    continue
                # 1. ------------------ augmentation ------------------
                img = aug_img(img)
                # 2. ------------------ detection ---------------------
                det.det_mode(img)
                flag_stop = show_img(det.img_one,
                                     img_name[:-4] + '.jpg',
                                     wait_time=0)
                if flag_stop:
                    break
            elif ext == 'jpg':
                print(img_name)
                try:
                    # 1. ------------------ augmentation ------------------
                    img = aug_img(cv2.imread(img_name))
                    # 2. ------------------ detection ---------------------
                    det.det_mode_and_save(img, img_name)
                except Exception as err:
                    # Best-effort batch: report the failure and keep going.
                    print('skip %s: %s' % (img_name, err))


def _run_capture(det, camera_index=1):
    """Detect on live camera frames until show_img signals a stop."""
    cap = cv2.VideoCapture(camera_index)
    try:
        while True:
            ret, frame_org = cap.read()
            if not ret:
                # NOTE(review): a camera that never delivers a frame makes
                # this loop spin forever; kept to preserve existing behavior.
                continue
            # 1. ------------------ augmentation ------------------
            frame_org = aug_img(frame_org)
            # 2. ------------------ detection ---------------------
            det.det_mode(frame_org)
            if show_img(det.img_one, wait_time=5):
                break
    finally:
        cap.release()
def show_videos_with_trackbar(video_name, det=SsdDet()):
    video = video_name[0]
    cv2.namedWindow("video", cv2.NORM_HAMMING2)
    cv2.resizeWindow("video", 1820, 980)
    cap = cv2.VideoCapture(video)
    frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))  # 获得总帧数
    # print cap.get(cv2.CAP_PROP_FPS) # 获得FPS
    loop_flag = 0
    pos = 0
    cv2.createTrackbar('time', 'video', 0, frames, nothing)  #设置滑动条
    cur_video = "start"
    while 1:
        if loop_flag == pos:  # 视频起始位置
            loop_flag = loop_flag + 1
            cv2.setTrackbarPos('time', 'video', loop_flag)
        else:
            # 设置视频播放位置
            pos = cv2.getTrackbarPos('time', 'video')
            loop_flag = pos
            cap.set(cv2.CAP_PROP_POS_FRAMES, pos)  # 设置当前帧所在位置
        ret, img = cap.read()

        #img = cv2.imread("/home/remo/Desktop/d381eb013ace326d408bda46c6fc36f.jpg")
        if ret == False:
            print("read error ")
            break
        cv2.imshow("raw", img)
        #------------------makeboreder操作------------------------------
        #img = copy_img(img,1)
        # 1. ------------------旋转操作------------------------------
        #img = cv2.transpose(img)
        #img = cv2.flip(img,1)
        # 1. ------------------增广操作------------------------------
        img = aug_img(img)
        #img = img[:,326:639,:]
        # 2. -------------------检测---------------------------------
        det.det_mode(img)
        # 3. --------------- 检测输出 拼接图片-------------------------
        # print(det.imgs_show_all[0].shape)
        # print(det.img_one.shape)
        cv2.imshow("video", det.img_one)
        key = cv2.waitKey(0)
        if key == ord('q') or loop_flag == frames:
            break
        if key == ord('o'):
            wrong_txt = open(
                '/home/remo/Desktop/remo_cat_dog/dog_cat_compare/wrong_video.txt',
                'a+')
            if (video == cur_video):
                wrong_txt.write(' ' + str(loop_flag))
            else:
                wrong_txt.write('\n')
                wrong_txt.write(video + ' ' + str(loop_flag))
            cur_video = video
            #wrong_txt.write(video+' '+str(loop_flag)+'\n')
            wrong_txt.close()
        if key == ord('s'):
            save_v_name = video.split("/")[-1]
            save_path = "%s/%s_%s.jpg" % (
                "/home/remo/Desktop/dog_cat/dog_cat_compare", save_v_name,
                str(loop_flag))
            cv2.imwrite(save_path, det.img_one)
    img_dirs = os.listdir(
        "/home/remo/Desktop/remo_source/Data_CatDog/OtherBackGround_Images/OtherBackGround_Images"
    )
    for dir in img_dirs:
        img_roots2.append(
            "/home/remo/Desktop/remo_source/Data_CatDog/OtherBackGround_Images/OtherBackGround_Images/"
            + dir)
    flag = True
    flag_aug_test = not flag

    flag_video = True
    flag_cap = False
    flag_img = False

    if flag:
        ssd_det = SsdDet()  #构造一个SsDet对象
        ssd_det.det_init(net_dict_info)  #调用初始化函数

        if flag_169:
            ssd_det.flag_169 = True
            ssd_det.flag_916 = False
        elif flag_916:
            ssd_det.flag_169 = False
            ssd_det.flag_916 = True
        main(videos, img_roots2, flag_video, flag_img, flag_cap, ssd_det)

    if flag_aug_test:
        root = "/home/remo/from_wdh/hand/hand_img/"
        list_img = os.listdir(root)
        for img in list_img: