コード例 #1
0
 def __init__(self, show_img=True):
     """Create the pose, detection, tracking and drawing helpers.

     Args:
         show_img: Stored as ``self.show_img``; this constructor does not
             use it for anything else.
     """
     # Processing stages, each built with its default arguments.
     self.pose_estimator = PoseEstimator()
     self.object_detector = ObjectDetectionYolo()
     self.object_tracker = ObjectTracker()
     # Drawing helpers.
     self.BBV = BBoxVisualizer()
     self.KPV = KeyPointVisualizer()
     self.IDV = IDVisualizer(with_bbox=False)
     # Image placeholders start out as two separate empty lists.
     self.img, self.img_black = [], []
     self.show_img = show_img
コード例 #2
0
ファイル: detect.py プロジェクト: CheungBH/MLForAction
 def __init__(self, path=config.video_path, show_img=True):
     """Build the processing helpers and open the input video.

     Args:
         path: Source handed to ``cv2.VideoCapture``; the default is the
             value ``config.video_path`` had when the method was defined.
         show_img: Stored as ``self.show_img``; not used further here.
     """
     self.pose_estimator = PoseEstimator()
     self.object_detector = ObjectDetectionYolo()
     # Drawing helpers.
     self.BBV = BBoxVisualizer()
     self.IDV = IDVisualizer()
     self.object_tracker = ObjectTracker()
     # Remember the source, then open a capture on it.
     self.video_path = path
     self.cap = cv2.VideoCapture(path)
     # Image placeholders start out as two separate empty lists.
     self.img, self.img_black = [], []
     self.show_img = show_img
     self.locator = Locator([1, 2])
コード例 #3
0
 def __init__(self, show_img=True):
     """Set up the detector, tracker, visualizers and CNN classifier.

     Args:
         show_img: Stored as ``self.show_img``; not used further here.
     """
     self.object_detector = ObjectDetectionYolo(
         cfg=opt.yolo_cfg,
         weight=opt.yolo_weight,
     )
     self.object_tracker = ObjectTracker()
     self.BBV = BBoxVisualizer()
     self.IDV = IDVisualizer(with_bbox=False)
     # Per-frame results begin empty.
     self.boxes, self.boxes_scores = tensor([]), tensor([])
     self.img_black, self.frame = np.array([]), np.array([])
     self.id2bbox = {}
     self.show_img = show_img
     self.CNN_model = CNNInference()
コード例 #4
0
 def __init__(self, video_path):
     """Open a video and load its per-frame annotation text files.

     Three text files are read; their names are the video path without
     its extension followed by ``_box.txt``, ``_kps.txt`` and
     ``_kps_score.txt``.  Each is stored as a list of lines with the
     trailing newline removed.

     Args:
         video_path: Path of the video file to open.

     Raises:
         OSError: If one of the annotation files cannot be opened.
     """
     import os  # local import: the file-level import block is not visible here

     def read_lines(suffix):
         # Strip only the newline: slicing off the last character would
         # truncate a final line that has no trailing newline.
         with open(stem + suffix, "r") as handle:
             return [line.rstrip("\n") for line in handle]

     self.cap = cv2.VideoCapture(video_path)
     self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
     self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))

     # splitext handles any extension length and bare file names; the old
     # split("/") / [:-4] construction sent "video.mp4" to "/video_box.txt".
     stem = os.path.splitext(video_path)[0]
     self.box_txt = read_lines("_box.txt")
     self.kps_txt = read_lines("_kps.txt")
     self.kps_score_txt = read_lines("_kps_score.txt")

     self.IDV = IDVisualizer(with_bbox=True)
     self.KPV = KeyPointVisualizer()
コード例 #5
0
 def __init__(self, show_img=True):
     """Set up the detector, tracker, pose estimator and visualizers.

     Args:
         show_img: Stored as ``self.show_img``; not used further here.
     """
     self.object_detector = ObjectDetectionYolo(
         cfg=yolo_cfg,
         weight=yolo_weight,
     )
     self.object_tracker = ObjectTracker()
     self.pose_estimator = PoseEstimator(
         pose_cfg=pose_cfg,
         pose_weight=pose_weight,
     )
     # Drawing helpers.
     self.BBV = BBoxVisualizer()
     self.KPV = KeyPointVisualizer()
     self.IDV = IDVisualizer(with_bbox=False)
     # Per-frame results begin empty.
     self.boxes, self.boxes_scores = tensor([]), tensor([])
     self.img_black, self.frame = np.array([]), np.array([])
     self.id2bbox, self.kps, self.kps_score = {}, {}, {}
     self.show_img = show_img
コード例 #6
0
 def __init__(self, resize_size, show_img=True):
     """Set up the detector, tracker, pose estimator, CNN and visualizers.

     Args:
         resize_size: Stored as ``self.resize_size``; not used further here.
         show_img: Stored as ``self.show_img``; not used further here.
     """
     self.object_detector = ObjectDetectionYolo(
         cfg=config.yolo_cfg,
         weight=config.yolo_weight,
     )
     self.object_tracker = ObjectTracker()
     self.pose_estimator = PoseEstimator(
         pose_cfg=config.pose_cfg,
         pose_weight=config.pose_weight,
     )
     # Drawing helpers.
     self.BBV = BBoxVisualizer()
     self.KPV = KeyPointVisualizer()
     self.IDV = IDVisualizer()
     # Per-frame results begin empty.
     self.boxes, self.boxes_scores = tensor([]), tensor([])
     self.frame = np.array([])
     self.id2bbox = {}
     self.CNN_model = CNNInference()
     self.kps, self.kps_score = {}, {}
     self.show_img, self.resize_size = show_img, resize_size
コード例 #7
0
 def __init__(self, model_name, video_path, label_path):
     """Load the predictor, the video, its annotations and its labels.

     Annotation files are read from ``<video dir>_txt/<video stem>`` plus
     ``_box.txt``, ``_kps.txt`` and ``_kps_score.txt``; each is stored as
     a list of lines with the trailing newline removed.  When the
     module-level ``write_video`` flag is truthy, a ``cv2.VideoWriter``
     is opened at ``<video dir>_<model stem>/<video name>``.

     Args:
         model_name: Path of the model; its name selects the predictor.
         video_path: Path of the video to test on.
         label_path: Path of the label file for this video.

     Raises:
         OSError: If the label file or an annotation file cannot be opened.
     """
     import os  # local import: the file-level import block is not visible here

     def read_lines(suffix):
         # Strip only the newline: slicing off the last character would
         # truncate a final line that has no trailing newline.
         with open(txt_stem + suffix, "r") as handle:
             return [line.rstrip("\n") for line in handle]

     # Normalise Windows separators so the "/"-based splitting below works.
     model_name = model_name.replace("\\", "/")
     video_path = video_path.replace("\\", "/")
     label_path = label_path.replace("\\", "/")
     video_dir, _, video_file = video_path.rpartition("/")

     self.tester = self.__get_tester(model_name)
     self.video_name = video_file
     self.cap = cv2.VideoCapture(video_path)
     self.IDV = IDVisualizer(with_bbox=True)
     self.KPV = KeyPointVisualizer()
     self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
     self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
     self.kps_dict = defaultdict(list)
     self.KPSP = KPSProcessor(self.height, self.width)
     self.label, self.test_id = self.__get_label(label_path)
     self.coord = []
     self.pred = defaultdict(str)
     self.pred_dict = defaultdict(list)
     self.res = defaultdict(bool)
     self.label_dict = defaultdict(bool)

     # splitext replaces the former "[:-4]" slice, which assumed a
     # three-character file extension.
     txt_stem = video_dir + "_txt/" + os.path.splitext(video_file)[0]
     self.box_txt = read_lines("_box.txt")
     self.kps_txt = read_lines("_kps.txt")
     self.kps_score_txt = read_lines("_kps_score.txt")

     if write_video:
         model_stem = os.path.splitext(model_name.rpartition("/")[2])[0]
         res_video = video_dir + "_" + model_stem + "/" + self.video_name
         self.out = cv2.VideoWriter(res_video,
                                    cv2.VideoWriter_fourcc(*'XVID'), 10,
                                    store_size)
コード例 #8
0
class VideoProcessor:
    """Replay a video with its stored boxes and keypoints drawn on each frame.

    The annotations come from three text files whose names are the video
    path without its extension plus ``_box.txt``, ``_kps.txt`` and
    ``_kps_score.txt``; one line is consumed per decoded frame.
    """

    def __init__(self, video_path):
        """Open the video and load the three annotation files.

        Args:
            video_path: Path of the video file to open.

        Raises:
            OSError: If one of the annotation files cannot be opened.
        """
        self.cap = cv2.VideoCapture(video_path)
        self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.box_txt = self._read_lines(
            self._annotation_path(video_path, "_box.txt"))
        self.kps_txt = self._read_lines(
            self._annotation_path(video_path, "_kps.txt"))
        self.kps_score_txt = self._read_lines(
            self._annotation_path(video_path, "_kps_score.txt"))
        self.IDV = IDVisualizer(with_bbox=True)
        self.KPV = KeyPointVisualizer()

    @staticmethod
    def _annotation_path(video_path, suffix):
        """Return ``video_path`` with its extension replaced by ``suffix``."""
        import os  # local import: the file-level import block is not visible here

        # splitext handles any extension length and bare file names; the old
        # split("/") / [:-4] construction sent "video.mp4" to "/video_box.txt".
        return os.path.splitext(video_path)[0] + suffix

    @staticmethod
    def _read_lines(path):
        """Return the lines of ``path`` without their trailing newline."""
        # Strip only the newline: slicing off the last character would
        # truncate a final line that has no trailing newline.
        with open(path, "r") as handle:
            return [line.rstrip("\n") for line in handle]

    def process_video(self):
        """Show every frame with its annotations until the video ends.

        Each frame is resized to the module-level ``frame_size``, the next
        line of every annotation list is popped and parsed, and the result
        is displayed in the "res" window for 100 ms.  The capture and all
        windows are released when no more frames can be read.

        Raises:
            IndexError: If an annotation list runs out before the video.
        """
        while True:
            ret, frame = self.cap.read()
            if not ret:
                self.cap.release()
                cv2.destroyAllWindows()
                break

            frame = cv2.resize(frame, frame_size)
            id2bbox = str2boxdict(self.box_txt.pop(0))
            id2kps = str2kpsdict(self.kps_txt.pop(0))
            id2kpsScore = str2kpsScoredict(self.kps_score_txt.pop(0))
            if id2bbox is not None:
                frame = self.IDV.plot_bbox_id(id2bbox, frame)
            if id2kps is not None:
                kps_tensor, score_tensor = self.KPV.kpsdic2tensor(
                    id2kps, id2kpsScore)
                frame = self.KPV.vis_ske(frame, kps_tensor, score_tensor)
                # The black-background render is not displayed; the call is
                # kept in case it has side effects -- TODO confirm and drop.
                self.KPV.vis_ske_black(frame, kps_tensor, score_tensor)

            cv2.imshow("res", frame)
            cv2.waitKey(100)
コード例 #9
0
ファイル: detect.py プロジェクト: CheungBH/MLForAction
class ImageProcessor(object):
    """Detect people in video frames, estimate poses, track and locate them."""

    def __init__(self, path=config.video_path, show_img=True):
        """Build the processing helpers and open the input video.

        Args:
            path: Source handed to ``cv2.VideoCapture``; the default is the
                value ``config.video_path`` had at definition time.
            show_img: When truthy, ``process`` displays its results.
        """
        self.pose_estimator = PoseEstimator()
        self.object_detector = ObjectDetectionYolo()
        self.BBV = BBoxVisualizer()
        self.IDV = IDVisualizer()
        self.object_tracker = ObjectTracker()
        self.video_path = path
        self.cap = cv2.VideoCapture(self.video_path)
        self.img = []
        self.img_black = []
        self.show_img = show_img
        self.locator = Locator([1, 2])

    def process(self):
        """Run the pipeline over the whole video, printing a frame counter.

        Every frame is resized to ``config.frame_size`` and passed through
        the detector.  When boxes are found the pose estimator runs, and
        when keypoints are found the tracker and locator run as well.  The
        capture and all windows are released once no frame can be read.
        """
        cnt = 0
        while True:
            ret, frame = self.cap.read()
            if not ret:
                self.cap.release()
                cv2.destroyAllWindows()
                break

            frame = cv2.resize(frame, config.frame_size)
            with torch.no_grad():
                inps, orig_img, boxes, scores, pt1, pt2 = \
                    self.object_detector.process(frame)
                if boxes is None:
                    # Nothing detected: show the plain frame in both windows.
                    self.img, self.img_black = frame, frame
                else:
                    key_points, self.img, self.img_black = \
                        self.pose_estimator.process_img(
                            inps, orig_img, boxes, scores, pt1, pt2)
                    if len(key_points) > 0:
                        id2ske_all, id2bbox_all = self.object_tracker.track(
                            boxes, key_points)
                        id2ske, id2bbox = self.locator.locate_user(
                            id2ske_all, id2bbox_all)
                        if self.show_img:
                            # Each overlay is drawn on its own copy of the frame.
                            cv2.imshow(
                                "id_bbox_all",
                                self.IDV.plot_bbox_id(
                                    id2bbox_all, copy.deepcopy(frame)))
                            cv2.imshow(
                                "id_ske_all",
                                self.IDV.plot_skeleton_id(
                                    id2ske_all, copy.deepcopy(frame)))
                            cv2.imshow(
                                "id_bbox_located",
                                self.IDV.plot_bbox_id(
                                    id2bbox, copy.deepcopy(frame)))
                            cv2.imshow(
                                "id_ske_located",
                                self.IDV.plot_skeleton_id(
                                    id2ske, copy.deepcopy(frame)))
                # Every branch ends by refreshing the result windows.
                if self.show_img:
                    self.__show_img()
            cnt += 1
            print(cnt)

    def __show_img(self):
        """Display ``self.img`` and ``self.img_black`` in fixed-position windows."""
        cv2.imshow("result", self.img)
        cv2.moveWindow("result", 1200, 90)
        cv2.imshow("result_black", self.img_black)
        cv2.moveWindow("result_black", 1200, 540)
        cv2.waitKey(1)

    def process_single_person(self, frame):
        """Return the first keypoint entry and both rendered images for a frame.

        Args:
            frame: Image to process; it is resized to ``config.frame_size``.

        Returns:
            Tuple ``(key_points[0], img, img_black)`` from the pose estimator.

        Raises:
            IndexError: Presumably when the estimator returns no keypoints
                (``key_points[0]`` is taken unguarded) -- verify.
        """
        frame = cv2.resize(frame, config.frame_size)
        with torch.no_grad():
            inps, orig_img, boxes, scores, pt1, pt2 = \
                self.object_detector.process(frame)
            key_points, img, img_black = self.pose_estimator.process_img(
                inps, orig_img, boxes, scores, pt1, pt2)
            return key_points[0], img, img_black

    def process_multiple_person(self, frame):
        """Return the located skeletons and both rendered images for a frame.

        Args:
            frame: Image to process; it is resized to ``config.frame_size``.

        Returns:
            Tuple ``(id2ske, img, img_black)``.  ``id2ske`` is the locator
            output, or an empty dict when no keypoints were found.
        """
        frame = cv2.resize(frame, config.frame_size)
        # Bound up front: previously a frame without keypoints reached the
        # return statement with ``id2ske`` undefined (UnboundLocalError).
        id2ske = {}
        with torch.no_grad():
            inps, orig_img, boxes, scores, pt1, pt2 = \
                self.object_detector.process(frame)
            key_points, img, img_black = self.pose_estimator.process_img(
                inps, orig_img, boxes, scores, pt1, pt2)
            if len(key_points) > 0:
                id2ske_all, id2bbox_all = self.object_tracker.track(
                    boxes, key_points)
                id2ske, id2bbox = self.locator.locate_user(
                    id2ske_all, id2bbox_all)
        return id2ske, img, img_black
コード例 #10
0
class ImgProcessor:
    """Detect people in a frame, estimate poses and track them by id."""

    def __init__(self, show_img=True):
        """Create the pose, detection, tracking and drawing helpers.

        Args:
            show_img: Stored as ``self.show_img``; not used by this class's
                visible methods.
        """
        self.pose_estimator = PoseEstimator()
        self.object_detector = ObjectDetectionYolo()
        self.object_tracker = ObjectTracker()
        self.BBV = BBoxVisualizer()
        self.KPV = KeyPointVisualizer()
        self.IDV = IDVisualizer(with_bbox=False)
        self.img = []
        self.img_black = []
        self.show_img = show_img

    def init_sort(self):
        """Re-initialise the object tracker."""
        self.object_tracker.init_tracker()

    def __process_kp(self, kps, idx):
        """Flatten the first two values of every keypoint into one list.

        Args:
            kps: Indexable sequence of keypoints, each with at least two values.
            idx: Key under which the flattened list is returned.

        Returns:
            ``{idx: [kps[0][0], kps[0][1], kps[1][0], ...]}``.
        """
        flat = [kps[i][axis] for i in range(len(kps)) for axis in range(2)]
        return {idx: flat}

    def process_img(self, frame):
        """Run detection, pose estimation and tracking on one frame.

        Args:
            frame: Image to process.

        Returns:
            Tuple ``(kps, img, img_black)``.  ``kps`` is the tracker's
            id-to-skeleton mapping, or the single flattened entry for
            ``config.track_idx`` when that is not ``"all"``; it is ``{}``
            when nothing was detected.  When no boxes are found, ``frame``
            is returned for both images.
        """
        img_black = cv2.imread('video/black.jpg')
        with torch.no_grad():
            inps, orig_img, boxes, scores, pt1, pt2 = \
                self.object_detector.process(frame)

            if boxes is None:
                return {}, frame, frame

            # NOTE(review): the estimator's own black image is discarded and
            # the blank image read above is returned instead -- confirm that
            # this is intended.
            key_points, img, _pose_black = self.pose_estimator.process_img(
                inps, orig_img, boxes, scores, pt1, pt2)
            if config.plot_bbox:
                img = self.BBV.visualize(boxes, frame)

            # ``key_points is not []`` was always True (identity test against
            # a fresh list); check for actual content instead.
            if key_points is None or len(key_points) == 0:
                return {}, img, img_black

            id2ske, id2bbox = self.object_tracker.track(boxes, key_points)
            if config.plot_id:
                img = self.IDV.plot_bbox_id(id2bbox, copy.deepcopy(img))

            if config.track_idx != "all":
                try:
                    kps = self.__process_kp(id2ske[config.track_idx],
                                            config.track_idx)
                except KeyError:
                    # The requested id is not tracked in this frame.
                    kps = {}
            else:
                kps = id2ske

            return kps, img, img_black
コード例 #11
0
class HumanDetection:
    """Detect and track people in a frame and classify them with a CNN."""

    def __init__(self, show_img=True):
        """Set up the detector, tracker, visualizers and CNN classifier.

        Args:
            show_img: Stored as ``self.show_img``; not used by this class's
                visible methods.
        """
        self.object_detector = ObjectDetectionYolo(
            cfg=opt.yolo_cfg,
            weight=opt.yolo_weight,
        )
        self.object_tracker = ObjectTracker()
        self.BBV = BBoxVisualizer()
        self.IDV = IDVisualizer(with_bbox=False)
        self.boxes, self.boxes_scores = tensor([]), tensor([])
        self.img_black, self.frame = np.array([]), np.array([])
        self.id2bbox = {}
        self.show_img = show_img
        self.CNN_model = CNNInference()

    def init_sort(self):
        """Re-initialise the object tracker."""
        self.object_tracker.init_tracker()

    def clear_res(self):
        """Reset all per-frame results to empty containers."""
        self.boxes, self.boxes_scores = tensor([]), tensor([])
        self.img_black, self.frame = np.array([]), np.array([])
        self.id2bbox = {}

    def visualize(self):
        """Draw boxes and ids on the stored frame as enabled in ``config``.

        Returns:
            Tuple ``(self.frame, self.img_black)``; ``img_black`` is freshly
            read from 'video/black.jpg'.
        """
        self.img_black = cv2.imread('video/black.jpg')
        draw_boxes = config.plot_bbox and self.boxes is not None
        if draw_boxes:
            self.frame = self.BBV.visualize(
                self.boxes, self.frame, self.boxes_scores)
        draw_ids = config.plot_id and self.id2bbox is not None
        if draw_ids:
            self.frame = self.IDV.plot_bbox_id(self.id2bbox, self.frame)
        return self.frame, self.img_black

    def process_img(self, frame, gray=False):
        """Detect and track objects in ``frame``.

        Args:
            frame: Image to process; stored as ``self.frame``.
            gray: When truthy, detection runs on ``gray3D`` of a deep copy
                of the frame instead of the frame itself.

        Returns:
            ``self.id2bbox``: the tracker output, or ``{}`` when the
            detector returned ``None``.
        """
        self.clear_res()
        self.frame = frame

        with torch.no_grad():
            detector_input = gray3D(copy.deepcopy(frame)) if gray else frame
            box_res = self.object_detector.process(detector_input)
            self.boxes, self.boxes_scores = \
                self.object_detector.cut_box_score(box_res)

            if box_res is not None:
                self.id2bbox = self.object_tracker.track(box_res)

        return self.id2bbox

    def classify_whole(self):
        """Classify ``self.img_black`` and print the predicted class name."""
        out = self.CNN_model.predict(self.img_black)
        scores = out[0].tolist()
        pred = opt.CNN_class[scores.index(max(scores))]
        print("The prediction is {}".format(pred))

    def classify(self, frame):
        """Classify each tracked box's crop and write the label onto ``frame``.

        Args:
            frame: Image the crops are taken from; it is drawn on in place.
        """
        frame_h, frame_w = frame.shape[0], frame.shape[1]
        for box in self.id2bbox.values():
            # Clamp the integer box corners to the frame bounds.
            x1 = max(int(box[0]), 0)
            y1 = max(int(box[1]), 0)
            x2 = min(int(box[2]), frame_w)
            y2 = min(int(box[3]), frame_h)
            crop = np.asarray(frame[y1:y2, x1:x2])
            out = self.CNN_model.predict(crop)
            scores = out[0].tolist()
            pred = opt.CNN_class[scores.index(max(scores))]
            print(pred)
            # Horizontal centre of the box, 50 px below its top edge.
            text_location = (int((box[0] + box[2]) / 2), int(box[1] + 50))
            cv2.putText(frame, pred, text_location,
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (100, 100, 255), 2)
コード例 #12
0
class HumanDetection:
    """Detect, track and pose-estimate people in a frame, with CNN classification."""

    def __init__(self, resize_size, show_img=True):
        """Set up the detector, tracker, pose estimator, CNN and visualizers.

        Args:
            resize_size: ``(width, height)`` used for the black canvas that
                ``visualize`` creates.
            show_img: Stored as ``self.show_img``; not used by this class's
                visible methods.
        """
        self.object_detector = ObjectDetectionYolo(cfg=config.yolo_cfg, weight=config.yolo_weight)
        self.object_tracker = ObjectTracker()
        self.pose_estimator = PoseEstimator(pose_cfg=config.pose_cfg, pose_weight=config.pose_weight)
        self.BBV = BBoxVisualizer()
        self.KPV = KeyPointVisualizer()
        self.IDV = IDVisualizer()
        self.boxes = tensor([])
        self.boxes_scores = tensor([])
        self.frame = np.array([])
        self.id2bbox = {}
        self.CNN_model = CNNInference()
        self.kps = {}
        self.kps_score = {}
        self.show_img = show_img
        self.resize_size = resize_size

    @staticmethod
    def _has_items(value):
        """Return True when ``value`` is a non-empty sized container."""
        return value is not None and len(value) > 0

    def init(self):
        """Re-initialise the object tracker."""
        self.object_tracker.init_tracker()

    def clear_res(self):
        """Reset all per-frame results to empty containers."""
        self.boxes = tensor([])
        self.boxes_scores = tensor([])
        self.frame = np.array([])
        self.id2bbox = {}
        self.kps = {}
        self.kps_score = {}

    def visualize(self):
        """Draw boxes, skeletons and ids as enabled in ``config``.

        Returns:
            Tuple ``(self.frame, img_black)`` where ``img_black`` is a new
            all-zero uint8 image of shape
            ``(resize_size[1], resize_size[0], 3)`` passed to the skeleton
            drawer.
        """
        img_black = np.full((self.resize_size[1], self.resize_size[0], 3), 0).astype(np.uint8)
        if config.plot_bbox and self.boxes is not None:
            self.BBV.visualize(self.boxes, self.frame)
        # ``self.kps is not []`` was always True (identity test against a
        # fresh list); only draw skeletons when keypoints actually exist.
        if config.plot_kps and self._has_items(self.kps):
            self.KPV.vis_ske(self.frame, self.kps, self.kps_score)
            self.KPV.vis_ske_black(img_black, self.kps, self.kps_score)
        if config.plot_id and self.id2bbox is not None:
            self.IDV.plot_bbox_id(self.id2bbox, self.frame)
            self.IDV.plot_skeleton_id(self.kps, self.frame)
        return self.frame, img_black

    def process_img(self, frame, gray=False):
        """Detect, track and pose-estimate the objects in ``frame``.

        Args:
            frame: Image to process; stored as ``self.frame``.
            gray: When truthy, detection runs on ``gray3D`` of a deep copy
                of the frame instead of the frame itself.

        Returns:
            Tuple ``(self.kps, self.id2bbox, self.kps_score)``; the entries
            stay ``{}`` when the corresponding stage produced nothing.
        """
        self.clear_res()
        self.frame = frame

        with torch.no_grad():
            if gray:
                gray_img = gray3D(copy.deepcopy(frame))
                box_res = self.object_detector.process(gray_img)
            else:
                box_res = self.object_detector.process(frame)
            self.boxes, self.boxes_scores = self.object_detector.cut_box_score(box_res)

            if box_res is not None:
                self.id2bbox = self.object_tracker.track(box_res)
                self.id2bbox = eliminate_nan(self.id2bbox)
                boxes = self.object_tracker.id_and_box(self.id2bbox)

                inps, pt1, pt2 = crop_bbox(frame, boxes)
                # Pose estimation only runs when cropping produced inputs.
                if inps is not None:
                    kps, kps_score, kps_id = self.pose_estimator.process_img(inps, boxes, pt1, pt2)
                    self.kps, self.kps_score = self.object_tracker.match_kps(kps_id, kps, kps_score)

        return self.kps, self.id2bbox, self.kps_score

    def classify_whole(self, pred_img, show_img):
        """Return ``CNN_model.classify_whole(pred_img, show_img)``."""
        pred = self.CNN_model.classify_whole(pred_img, show_img)
        return pred

    def classify(self, pred_img, show_img, id2bbox):
        """Classify the boxes in ``pred_img`` and draw the result on ``show_img``.

        Returns:
            The predictions returned by ``CNN_model.classify``.
        """
        preds = self.CNN_model.classify(pred_img, id2bbox)
        self.CNN_model.visualize(show_img, preds)
        return preds
コード例 #13
0
class HumanDetection:
    """Detect, track and pose-estimate people in a frame."""

    def __init__(self, show_img=True):
        """Set up the detector, tracker, pose estimator and visualizers.

        Args:
            show_img: Stored as ``self.show_img``; not used by this class's
                visible methods.
        """
        self.object_detector = ObjectDetectionYolo(cfg=yolo_cfg, weight=yolo_weight)
        self.object_tracker = ObjectTracker()
        self.pose_estimator = PoseEstimator(pose_cfg=pose_cfg, pose_weight=pose_weight)
        self.BBV = BBoxVisualizer()
        self.KPV = KeyPointVisualizer()
        self.IDV = IDVisualizer(with_bbox=False)
        self.boxes = tensor([])
        self.boxes_scores = tensor([])
        self.img_black = np.array([])
        self.frame = np.array([])
        self.id2bbox = {}
        self.kps = {}
        self.kps_score = {}
        self.show_img = show_img

    @staticmethod
    def _has_items(value):
        """Return True when ``value`` is a non-empty sized container."""
        return value is not None and len(value) > 0

    def init_sort(self):
        """Re-initialise the object tracker."""
        self.object_tracker.init_tracker()

    def clear_res(self):
        """Reset the per-frame results (``img_black`` is left untouched)."""
        self.boxes = tensor([])
        self.boxes_scores = tensor([])
        self.frame = np.array([])
        self.id2bbox = {}
        self.kps = {}
        self.kps_score = {}

    def visualize(self):
        """Draw boxes, skeletons and ids as enabled in ``config``.

        Returns:
            Tuple ``(self.frame, img_black)``.  ``img_black`` is the image
            read from 'video/black.jpg', replaced by the skeleton drawer's
            output when keypoints are plotted.
        """
        img_black = cv2.imread('video/black.jpg')
        if config.plot_bbox and self.boxes is not None:
            self.frame = self.BBV.visualize(self.boxes, self.frame, self.boxes_scores)
        # ``self.kps is not []`` was always True (identity test against a
        # fresh list); only draw skeletons when keypoints actually exist.
        if config.plot_kps and self._has_items(self.kps):
            self.frame = self.KPV.vis_ske(self.frame, self.kps, self.kps_score)
            img_black = self.KPV.vis_ske_black(self.frame, self.kps, self.kps_score)
        if config.plot_id and self.id2bbox is not None:
            self.frame = self.IDV.plot_bbox_id(self.id2bbox, self.frame)
        return self.frame, img_black

    def process_img(self, frame, gray=False):
        """Detect, track and pose-estimate the objects in ``frame``.

        Args:
            frame: Image to process; stored as ``self.frame``.
            gray: When truthy, detection runs on ``gray3D`` of a deep copy
                of the frame instead of the frame itself.

        Returns:
            Tuple ``(self.kps, self.id2bbox, self.kps_score)``; all three
            stay ``{}`` when the detector returned ``None``.
        """
        self.clear_res()
        self.frame = frame

        with torch.no_grad():
            if gray:
                gray_img = gray3D(copy.deepcopy(frame))
                box_res = self.object_detector.process(gray_img)
            else:
                box_res = self.object_detector.process(frame)
            self.boxes, self.boxes_scores = self.object_detector.cut_box_score(box_res)

            if box_res is not None:
                self.id2bbox = self.object_tracker.track(box_res)
                boxes = self.object_tracker.id_and_box(self.id2bbox)

                inps, pt1, pt2 = crop_bbox(frame, boxes)
                kps, kps_score, kps_id = self.pose_estimator.process_img(inps, boxes, pt1, pt2)
                self.kps, self.kps_score = self.object_tracker.match_kps(kps_id, kps, kps_score)

        return self.kps, self.id2bbox, self.kps_score
コード例 #14
0
class Tester:
    """Run a sequence predictor over stored keypoints and compare with labels."""

    def __init__(self, model_name, video_path, label_path):
        """Load the predictor, the video, its annotations and its labels.

        Annotation files are read from ``<video dir>_txt/<video stem>`` plus
        ``_box.txt``, ``_kps.txt`` and ``_kps_score.txt``.  When the
        module-level ``write_video`` flag is truthy, a ``cv2.VideoWriter``
        is opened at ``<video dir>_<model stem>/<video name>``.

        Args:
            model_name: Path of the model; its name selects the predictor.
            video_path: Path of the video to test on.
            label_path: Path of the label file for this video.

        Raises:
            OSError: If the label file or an annotation file cannot be opened.
        """
        import os  # local import: the file-level import block is not visible here

        # Normalise Windows separators so the "/"-based splitting below works.
        model_name = model_name.replace("\\", "/")
        video_path = video_path.replace("\\", "/")
        label_path = label_path.replace("\\", "/")
        video_dir, _, video_file = video_path.rpartition("/")

        self.tester = self.__get_tester(model_name)
        self.video_name = video_file
        self.cap = cv2.VideoCapture(video_path)
        self.IDV = IDVisualizer(with_bbox=True)
        self.KPV = KeyPointVisualizer()
        self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.kps_dict = defaultdict(list)
        self.KPSP = KPSProcessor(self.height, self.width)
        self.label, self.test_id = self.__get_label(label_path)
        self.coord = []
        self.pred = defaultdict(str)
        self.pred_dict = defaultdict(list)
        self.res = defaultdict(bool)
        self.label_dict = defaultdict(bool)

        # splitext replaces the former "[:-4]" slice, which assumed a
        # three-character file extension.
        txt_stem = video_dir + "_txt/" + os.path.splitext(video_file)[0]
        self.box_txt = self._read_lines(txt_stem + "_box.txt")
        self.kps_txt = self._read_lines(txt_stem + "_kps.txt")
        self.kps_score_txt = self._read_lines(txt_stem + "_kps_score.txt")

        if write_video:
            model_stem = os.path.splitext(model_name.rpartition("/")[2])[0]
            res_video = video_dir + "_" + model_stem + "/" + self.video_name
            self.out = cv2.VideoWriter(res_video,
                                       cv2.VideoWriter_fourcc(*'XVID'), 10,
                                       store_size)

    @staticmethod
    def _read_lines(path):
        """Return the lines of ``path`` without their trailing newline."""
        # Strip only the newline: slicing off the last character would
        # truncate a final line that has no trailing newline.
        with open(path, "r") as handle:
            return [line.rstrip("\n") for line in handle]

    def __get_label(self, path):
        """Parse a label file made of ``<id>:<label> <label> ...`` lines.

        Args:
            path: Path of the label file.

        Returns:
            Tuple ``(labels, ids)``: a ``defaultdict(list)`` mapping each id
            to its list of labels, and the ids in file order.

        Raises:
            ValueError: If a non-empty line does not contain exactly one ":".
        """
        labels, ids = defaultdict(list), []
        with open(path, "r") as lf:
            for raw in lf:
                # rstrip keeps the last character of a final line that has
                # no trailing newline; empty lines are skipped.
                line = raw.rstrip("\n")
                if not line:
                    continue
                idx, label = line.split(":")
                labels[idx] = label.split(" ")
                ids.append(idx)
        return labels, ids

    def __get_tester(self, model):
        """Pick the predictor whose architecture name appears in ``model``.

        Returns:
            The matching predictor, or ``None`` when no name matches or when
            a plain LSTM model is requested while ``lstm`` is falsy.
        """
        # The specific LSTM variants are tested before the generic "LSTM"
        # substring, which they also contain.
        if "ConvLSTM" in model:
            return ConvLSTMPredictor(model, len(cls))
        if "BiLSTM" in model:
            return BiLSTMPredictor(model, len(cls))
        if "ConvGRU" in model:
            return ConvGRUPredictor(model, len(cls))
        if 'LSTM' in model:
            if lstm:
                return LSTMPredictor(model)
            else:
                print("lstm is not usable")
        if "TCN" in model:
            return TCNPredictor(model, len(cls))

    def __detect_kps(self):
        """Predict for every id holding ``seq_length`` frames, then reset it."""
        refresh_idx = []
        for k, v in self.kps_dict.items():
            if len(v) == seq_length:
                pred = self.tester.predict(np.array(v).astype(np.float32))
                self.pred[k] = cls[pred]
                self.pred_dict[str(k)].append(cls[pred])
                refresh_idx.append(k)
        # Buffers are cleared after the loop so the dict is not rewritten
        # while it is being iterated.
        for idx in refresh_idx:
            self.kps_dict[idx] = []

    def __put_pred(self, img):
        """Write the latest prediction of every id onto ``img`` and return it."""
        for idx, (k, v) in enumerate(self.pred.items()):
            cv2.putText(img, "id{}: {}".format(k, v),
                        (30, int(40 * (idx + 1))),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
        return img

    def __put_cnt(self, img):
        """Write the buffered frame count of every id onto ``img`` and return it."""
        for idx, (k, v) in enumerate(self.kps_dict.items()):
            cv2.putText(img, "id{} cnt: {}".format(k, len(v)),
                        (300, int(30 * (idx + 2))), cv2.FONT_HERSHEY_SIMPLEX,
                        1, (0, 255, 0), 2)
        return img

    def __get_target_pred(self):
        """Return the predictions whose id is listed in ``self.test_id``."""
        return {
            key: value
            for key, value in self.pred_dict.items() if key in self.test_id
        }

    def __compare(self):
        """Fill ``self.res`` and ``self.label_dict`` from predictions and labels.

        Raises:
            AssertionError: If the predicted ids differ from the labelled
                ids, or an id has a different number of predictions and
                labels.
        """
        self.pred_dict = self.__get_target_pred()
        assert self.pred_dict.keys() == self.label.keys()
        for k in self.pred_dict.keys():
            label, pred = self.label[k], self.pred_dict[k]
            assert len(label) == len(pred)
            for idx, (l, p) in enumerate(zip(label, pred)):
                if l != "pass":
                    # NOTE(review): the frame range uses a fixed 30-frame
                    # window; presumably this should match seq_length.
                    sample_str = self.video_name[:-4] + \
                        "_id{}_frame{}-{}".format(
                            k, 30 * idx, 30 * (idx + 1) - 1)
                    self.res[sample_str] = l == p
                    self.label_dict[sample_str] = l

    def test(self):
        """Replay the video, predict per id and compare against the labels.

        Returns:
            ``self.res``: maps each sample name to whether its prediction
            equalled its label.
        """
        cnt = 0
        while True:
            cnt += 1
            ret, frame = self.cap.read()
            if not ret:
                self.cap.release()
                # ``self.out`` only exists when ``write_video`` is set;
                # releasing it unconditionally raised AttributeError.
                if write_video:
                    self.out.release()
                cv2.destroyAllWindows()
                break

            frame = cv2.resize(frame, config.size)
            id2bbox = str2boxdict(self.box_txt.pop(0))
            id2kps = str2kpsdict(self.kps_txt.pop(0))
            id2kpsScore = str2kpsScoredict(self.kps_score_txt.pop(0))
            if id2bbox is not None:
                frame = self.IDV.plot_bbox_id(id2bbox, frame)
            if id2kps is not None:
                kps_tensor, score_tensor = self.KPV.kpsdic2tensor(
                    id2kps, id2kpsScore)
                frame = self.KPV.vis_ske(frame, kps_tensor, score_tensor)
            if id2kps:
                for key, v in id2kps.items():
                    coord = self.KPSP.process_kp(v)
                    self.kps_dict[key].append(coord)
                self.__detect_kps()

            img = cv2.resize(frame, store_size)
            img = self.__put_pred(img)
            cv2.putText(img, "Frame cnt: {}".format(cnt), (300, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1)
            img = self.__put_cnt(img)
            cv2.imshow("res", img)
            cv2.waitKey(2)
            if write_video:
                self.out.write(img)
        self.__compare()
        return self.res