Example #1
def validate(m, heatmap_to_coord, batch_size=20):
    det_dataset = builder.build_dataset(cfg.DATASET.TEST,
                                        preset_cfg=cfg.DATA_PRESET,
                                        train=False,
                                        opt=opt)
    eval_joints = det_dataset.EVAL_JOINTS

    det_loader = torch.utils.data.DataLoader(det_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=20,
                                             drop_last=False)
    kpt_json = []
    m.eval()

    for inps, crop_bboxes, bboxes, img_ids, scores, imghts, imgwds in tqdm(
            det_loader, dynamic_ncols=True):
        if isinstance(inps, list):
            inps = [inp.cuda() for inp in inps]
        else:
            inps = inps.cuda()
        output = m(inps)
        if opt.flip_test:
            if isinstance(inps, list):
                inps_flip = [flip(inp) for inp in inps]
            else:
                inps_flip = flip(inps)
            output_flip = flip_heatmap(m(inps_flip),
                                       det_dataset.joint_pairs,
                                       shift=True)
            output = (output + output_flip) / 2

        pred = output.cpu().data.numpy()
        assert pred.ndim == 4
        pred = pred[:, eval_joints, :, :]

        for i in range(output.shape[0]):
            bbox = crop_bboxes[i].tolist()
            # pred was already restricted to eval_joints above, so avoid
            # applying EVAL_JOINTS a second time
            pose_coords, pose_scores = heatmap_to_coord(pred[i], bbox)

            keypoints = np.concatenate((pose_coords, pose_scores), axis=1)
            keypoints = keypoints.reshape(-1).tolist()

            data = dict()
            data['bbox'] = bboxes[i, 0].tolist()
            data['image_id'] = int(img_ids[i])
            data['score'] = float(scores[i] + np.mean(pose_scores) +
                                  np.max(pose_scores))
            data['category_id'] = 1
            data['keypoints'] = keypoints

            kpt_json.append(data)

    with open('./exp/json/validate_rcnn_kpt.json', 'w') as fid:
        json.dump(kpt_json, fid)
    res = evaluate_mAP('./exp/json/validate_rcnn_kpt.json',
                       ann_type='keypoints')
    return res['AP']
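
All of the snippets below lean on the same flip-test idiom: run the model on a horizontally mirrored batch, un-mirror the predicted heatmaps while swapping left/right joint channels, then fuse the two predictions. A minimal sketch of what flip and flip_heatmap do, assuming NCHW tensors and a list of (left, right) joint index pairs; this is an approximation of the AlphaPose helpers, not their exact implementation:

import torch

def flip(x):
    # Mirror an NCHW batch along the width axis.
    return torch.flip(x, dims=[3])

def flip_heatmap(heatmap, joint_pairs, shift=False):
    # Un-mirror heatmaps predicted from flipped inputs: flip back along
    # the width axis, then swap each (left, right) joint channel pair.
    heatmap = torch.flip(heatmap, dims=[3])
    perm = list(range(heatmap.size(1)))
    for left, right in joint_pairs:
        perm[left], perm[right] = perm[right], perm[left]
    heatmap = heatmap[:, perm]
    if shift:
        # Compensate for the one-pixel offset that flipping introduces.
        heatmap[:, :, :, 1:] = heatmap[:, :, :, :-1].clone()
    return heatmap
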
Example #2
def validate_gt(pose_model, cfg, heatmap_to_coord, batch_size=20):
    gt_val_dataset = builder.build_dataset(cfg.DATASET.VAL, preset_cfg=cfg.DATA_PRESET, train=False)
    eval_joints = gt_val_dataset.EVAL_JOINTS

    gt_val_loader = torch.utils.data.DataLoader(
        gt_val_dataset, batch_size=batch_size, shuffle=False, num_workers=20, drop_last=False)
    kpt_json = []
    # m.eval()

    norm_type = cfg.LOSS.get('NORM_TYPE', None)
    hm_size = cfg.DATA_PRESET.HEATMAP_SIZE

    for inps, labels, label_masks, img_ids, bboxes in tqdm(gt_val_loader, dynamic_ncols=True):
        if isinstance(inps, list):
            inps = [inp.cuda() for inp in inps]
        else:
            inps = inps.cuda()
        # run inference with the static-shape FastPose TensorRT engine
        output = pose_model.detect_context(inps.cpu().numpy())
        # the output must be moved onto the CUDA device, otherwise the results are wrong
        output = torch.from_numpy(output).to(opt.device)
        if opt.flip_test:
            if isinstance(inps, list):
                inps_flip = [flip(inp).cuda() for inp in inps]
            else:
                inps_flip = flip(inps).cuda()
            inps_flip = pose_model.detect_context(inps_flip.cpu().numpy())
            inps_flip = torch.from_numpy(inps_flip).to(opt.device)
            # output_flip = flip_heatmap(m(inps_flip), gt_val_dataset.joint_pairs, shift=True)
            output_flip = flip_heatmap(inps_flip, gt_val_dataset.joint_pairs, shift=True)
            pred_flip = output_flip[:, eval_joints, :, :]
        else:
            output_flip = None
            pred_flip = None

        pred = output
        assert pred.dim() == 4
        pred = pred[:, eval_joints, :, :]
        for i in range(bboxes.shape[0]):
            bbox = bboxes[i].tolist()
            pose_coords, pose_scores = heatmap_to_coord(
                pred[i], bbox,
                hms_flip=pred_flip[i] if pred_flip is not None else None,
                hm_shape=hm_size, norm_type=norm_type)

            keypoints = np.concatenate((pose_coords, pose_scores), axis=1)
            keypoints = keypoints.reshape(-1).tolist()

            data = dict()
            data['bbox'] = bboxes[i].tolist()
            data['image_id'] = int(img_ids[i])
            data['score'] = float(np.mean(pose_scores) + np.max(pose_scores))
            data['category_id'] = 1
            data['keypoints'] = keypoints

            kpt_json.append(data)

    with open('./exp/json/trt/validate_gt_kpt.json', 'w') as fid:
        json.dump(kpt_json, fid)
    res = evaluate_mAP('./exp/json/trt/validate_gt_kpt.json', ann_type='keypoints',
                       ann_file=os.path.join(cfg.DATASET.VAL.ROOT, cfg.DATASET.VAL.ANN))
    return res
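
In Examples #2 and #3 the flipped heatmaps are not averaged at the call site; they are handed to heatmap_to_coord via hms_flip and fused inside the decoder. A sketch of that step under the usual convention (average, then per-joint argmax); the real AlphaPose decoders also do sub-pixel refinement and honor norm_type, which is omitted here:

import numpy as np

def decode_with_flip(hm, hms_flip=None):
    # Fuse the original and flip-test heatmaps, then decode each joint
    # with a per-channel argmax; scores are the peak activations.
    if hms_flip is not None:
        hm = (hm + hms_flip) / 2
    num_joints, h, w = hm.shape
    flat = hm.reshape(num_joints, -1)
    idx = flat.argmax(axis=1)
    coords = np.stack([idx % w, idx // w], axis=1).astype(np.float32)
    scores = flat.max(axis=1, keepdims=True)
    return coords, scores
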
Example #3
def validate(m, heatmap_to_coord, batch_size=20):
    det_dataset = builder.build_dataset(cfg.DATASET.TEST, preset_cfg=cfg.DATA_PRESET, detector_cfg=cfg.DETECTOR, train=False, opt=opt)
    eval_joints = det_dataset.EVAL_JOINTS

    det_loader = torch.utils.data.DataLoader(
        det_dataset, batch_size=batch_size, shuffle=False, num_workers=20, drop_last=False)
    kpt_json = []
    m.eval()

    norm_type = cfg.LOSS.get('NORM_TYPE', None)
    hm_size = cfg.DATA_PRESET.HEATMAP_SIZE

    for inps, crop_bboxes, bboxes, img_ids, scores, imghts, imgwds in tqdm(det_loader, dynamic_ncols=True):
        if isinstance(inps, list):
            inps = [inp.cuda() for inp in inps]
        else:
            inps = inps.cuda()
        output = m(inps)
        if opt.flip_test:
            if isinstance(inps, list):
                inps_flip = [flip(inp).cuda() for inp in inps]
            else:
                inps_flip = flip(inps).cuda()
            output_flip = flip_heatmap(m(inps_flip), det_dataset.joint_pairs, shift=True)
            pred_flip = output_flip[:, eval_joints, :, :]
        else:
            output_flip = None
            pred_flip = None

        pred = output
        assert pred.dim() == 4
        pred = pred[:, eval_joints, :, :]

        for i in range(output.shape[0]):
            bbox = crop_bboxes[i].tolist()
            pose_coords, pose_scores = heatmap_to_coord(
                pred[i], bbox,
                hms_flip=pred_flip[i] if pred_flip is not None else None,
                hm_shape=hm_size, norm_type=norm_type)

            keypoints = np.concatenate((pose_coords, pose_scores), axis=1)
            keypoints = keypoints.reshape(-1).tolist()

            data = dict()
            data['bbox'] = bboxes[i, 0].tolist()
            data['image_id'] = int(img_ids[i])
            data['area'] = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
            # data['score'] = float(scores[i] + np.mean(pose_scores) + np.max(pose_scores))
            data['score'] = float(scores[i])
            data['category_id'] = 1
            data['keypoints'] = keypoints

            kpt_json.append(data)

    kpt_json = oks_pose_nms(kpt_json)

    with open('./exp/json/validate_rcnn_kpt.json', 'w') as fid:
        json.dump(kpt_json, fid)
    res = evaluate_mAP('./exp/json/validate_rcnn_kpt.json', ann_type='keypoints')
    return res['AP']
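
Every validate example builds kpt_json entries in the COCO keypoint results format (image_id, category_id, a flattened x, y, score triple per joint, plus a detection score), so evaluate_mAP is presumably a thin wrapper around pycocotools. A minimal equivalent, assuming a standard person_keypoints annotation file:

from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

def evaluate_map(res_file, ann_file):
    # Score a COCO-format keypoint results file against the ground truth.
    coco_gt = COCO(ann_file)
    coco_dt = coco_gt.loadRes(res_file)
    coco_eval = COCOeval(coco_gt, coco_dt, iouType='keypoints')
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    return {'AP': coco_eval.stats[0]}  # AP @ OKS=0.50:0.95
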
Example #4
File: demo.py  Project: IMBINGO95/FairMOT
                inps = inps.to(args.device)
                datalen = inps.size(0)
                leftover = 0
                if (datalen) % batchSize:
                    leftover = 1
                num_batches = datalen // batchSize + leftover
                hm = []
                for j in range(num_batches):
                    inps_j = inps[j * batchSize:min((j + 1) *
                                                    batchSize, datalen)]
                    if args.flip:
                        inps_j = torch.cat((inps_j, flip(inps_j)))
                    hm_j = pose_model(inps_j)
                    if args.flip:
                        hm_j_flip = flip_heatmap(hm_j[int(len(hm_j) / 2):],
                                                 det_loader.joint_pairs,
                                                 shift=True)
                        hm_j = (hm_j[0:int(len(hm_j) / 2)] + hm_j_flip) / 2
                    hm.append(hm_j)
                hm = torch.cat(hm)

                vis_img, keypoints = heat_to_map(orig_img, hm, cropped_boxes)

                cv2.imwrite(os.path.join(outputpath, img_name), vis_img)
                data.append({'img_name': img_name, 'keypoints': keypoints})

                if args.profile:
                    ckpt_time, pose_time = getTime(ckpt_time)
                    runtime_profile['pt'].append(pose_time)
                # hm_ = hm_.cpu()
                # print('cropped_boxes',cropped_boxes)
Example #5
 inps = inps.to(args.device)
 datalen = inps.size(0)
 leftover = 0
 if (datalen) % batchSize:
     leftover = 1
 num_batches = datalen // batchSize + leftover
 hm = []
 for j in range(num_batches):
     inps_j = inps[j * batchSize:min((j + 1) *
                                     batchSize, datalen)]
     if args.flip:
         inps_j = torch.cat((inps_j, flip(inps_j)))
     hm_j = pose_model(inps_j)
     if args.flip:
         hm_j_flip = flip_heatmap(hm_j[int(len(hm_j) / 2):],
                                  pose_dataset.joint_pairs,
                                  shift=True)
         hm_j = (hm_j[0:int(len(hm_j) / 2)] + hm_j_flip) / 2
     hm.append(hm_j)
 hm = torch.cat(hm)
 if args.profile:
     ckpt_time, pose_time = getTime(ckpt_time)
     runtime_profile['pt'].append(pose_time)
 if args.pose_track:
     boxes, scores, ids, hm, cropped_boxes = track(
         tracker, args, orig_img, inps, boxes, hm,
         cropped_boxes, im_name, scores)
 hm = hm.cpu()
 writer.save(boxes, scores, ids, hm, cropped_boxes, orig_img,
             im_name)
 if args.profile:
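
Examples #4 and #5 split a variable number of person crops into fixed-size mini-batches, and with flip testing each mini-batch is doubled by concatenating its mirror image, which is why Example #9 halves the pose batch size (batchSize = int(batchSize / 2)) before entering the loop. The chunking arithmetic in isolation, as a sketch:

import math

def batch_slices(datalen, batch_size):
    # Equivalent to datalen // batch_size full batches plus one
    # leftover batch when the division is not exact.
    num_batches = math.ceil(datalen / batch_size)
    return [(j * batch_size, min((j + 1) * batch_size, datalen))
            for j in range(num_batches)]

# batch_slices(45, 20) -> [(0, 20), (20, 40), (40, 45)]
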
Example #6
    def process(self, im_name, image):
        # Init data writer
        self.writer = DataWriter(self.cfg, self.args)

        runtime_profile = {'dt': [], 'pt': [], 'pn': []}
        pose = None
        try:
            start_time = getTime()
            with torch.no_grad():
                (inps, orig_img, im_name, boxes, scores, ids,
                 cropped_boxes) = self.det_loader.process(im_name,
                                                          image).read()
                if orig_img is None:
                    raise Exception("no image is given")
                if boxes is None or boxes.nelement() == 0:
                    if self.args.profile:
                        ckpt_time, det_time = getTime(start_time)
                        runtime_profile['dt'].append(det_time)
                    self.writer.save(None, None, None, None, None, orig_img,
                                     im_name)
                    if self.args.profile:
                        ckpt_time, pose_time = getTime(ckpt_time)
                        runtime_profile['pt'].append(pose_time)
                    pose = self.writer.start()
                    if self.args.profile:
                        ckpt_time, post_time = getTime(ckpt_time)
                        runtime_profile['pn'].append(post_time)
                else:
                    if self.args.profile:
                        ckpt_time, det_time = getTime(start_time)
                        runtime_profile['dt'].append(det_time)
                    # Pose Estimation
                    inps = inps.to(self.args.device)
                    if self.args.flip:
                        inps = torch.cat((inps, flip(inps)))
                    hm = self.pose_model(inps)
                    if self.args.flip:
                        hm_flip = flip_heatmap(hm[int(len(hm) / 2):],
                                               self.pose_dataset.joint_pairs,
                                               shift=True)
                        hm = (hm[0:int(len(hm) / 2)] + hm_flip) / 2
                    if self.args.profile:
                        ckpt_time, pose_time = getTime(ckpt_time)
                        runtime_profile['pt'].append(pose_time)
                    hm = hm.cpu()
                    self.writer.save(boxes, scores, ids, hm, cropped_boxes,
                                     orig_img, im_name)
                    pose = self.writer.start()
                    if self.args.profile:
                        ckpt_time, post_time = getTime(ckpt_time)
                        runtime_profile['pn'].append(post_time)

            if self.args.profile:
                print(
                    'det time: {dt:.4f} | pose time: {pt:.4f} | post processing: {pn:.4f}'
                    .format(dt=np.mean(runtime_profile['dt']),
                            pt=np.mean(runtime_profile['pt']),
                            pn=np.mean(runtime_profile['pn'])))
            print('===========================> Finish Model Running.')
        except Exception as e:
            print(repr(e))
            print(
                'An error as above occurs when processing the images, please check it'
            )
            pass
        except KeyboardInterrupt:
            print('===========================> Finish Model Running.')

        return pose
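
The profiling hooks in Examples #4 through #7 and #9 rely on a small getTime helper. A sketch consistent with its call sites (no argument returns a timestamp; a timestamp argument returns a (new_timestamp, elapsed) pair); this is an assumption about the AlphaPose utility, not its verbatim source:

import time

def getTime(since=None):
    # No argument: return a timestamp, as in start_time = getTime().
    # With a timestamp: return (now, elapsed), as in
    # ckpt_time, det_time = getTime(start_time).
    now = time.time()
    if since is None:
        return now
    return now, now - since
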
Example #7
                datalen = inps.size(0)
                leftover = 0
                if (datalen) % batchSize:
                    leftover = 1
                num_batches = datalen // batchSize + leftover
                hm = []
                for j in range(num_batches):
                    inps_j = inps[j * batchSize:min((j + 1) *
                                                    batchSize, datalen)]
                    if args.flip:
                        inps_j = torch.cat((inps_j, flip(inps_j)))
                    hm_j = pose_model(inps_j)
                    if args.flip:
                        hm_j_flip = flip_heatmap(
                            hm_j[int(len(hm_j) / 2):],
                            [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12],
                             [13, 14], [15, 16]],
                            shift=True)
                        hm_j = (hm_j[0:int(len(hm_j) / 2)] + hm_j_flip) / 2
                    hm.append(hm_j)
                hm = torch.cat(hm)
                if args.profile:
                    ckpt_time, pose_time = getTime(ckpt_time)
                    runtime_profile['pt'].append(pose_time)
                hm = hm.cpu()
                writer.save(boxes, scores, ids, hm, cropped_boxes, orig_img,
                            os.path.basename(im_name))

                if args.profile:
                    ckpt_time, post_time = getTime(ckpt_time)
                    runtime_profile['pn'].append(post_time)
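
Example #7 inlines the joint pairs instead of reading them from a dataset object. The hard-coded indices are the left/right pairs of the standard 17-keypoint COCO layout (index 0, the nose, has no mirror partner):

# (1, 2) eyes, (3, 4) ears, (5, 6) shoulders, (7, 8) elbows,
# (9, 10) wrists, (11, 12) hips, (13, 14) knees, (15, 16) ankles
COCO_JOINT_PAIRS = [[1, 2], [3, 4], [5, 6], [7, 8],
                    [9, 10], [11, 12], [13, 14], [15, 16]]
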
Example #8
def validate_gt(cfg,
                batchsize=1,
                engine_file_path=None,
                m=None,
                heatmap_to_coord=None):

    gt_val_dataset = builder.build_dataset(cfg.DATASET.VAL,
                                           preset_cfg=cfg.DATA_PRESET,
                                           train=False)
    eval_joints = gt_val_dataset.EVAL_JOINTS

    gt_val_loader = torch.utils.data.DataLoader(gt_val_dataset,
                                                batch_size=batchsize,
                                                shuffle=False,
                                                num_workers=20,
                                                drop_last=True)
    kpt_json = []
    if m:
        m.eval()

    norm_type = cfg.LOSS.get('NORM_TYPE', None)
    hm_size = cfg.DATA_PRESET.HEATMAP_SIZE
    average_time = 0
    pytorch_all_infer_time = 0
    trt_all_infer_time = 0
    data_num = 0

    for inps, labels, label_masks, img_ids, bboxes in tqdm(gt_val_loader,
                                                           dynamic_ncols=True):
        data_num += 1
        if engine_file_path:

            # hm_data = []
            inps = inps.numpy()
            np.copyto(inputs[0].host, inps.ravel())

            trt_infer_before_time = count_time()
            trt_outputs = trt_common.do_inference(context,
                                                  bindings=bindings,
                                                  inputs=inputs,
                                                  outputs=outputs,
                                                  stream=stream)
            trt_infer_after_time = count_time()

            trt_all_infer_time += trt_infer_after_time - trt_infer_before_time

            pred = trt_outputs[0].reshape(-1, 17, 64, 48)
        else:
            if isinstance(inps, list):
                inps = [inp.cuda() for inp in inps]
            else:
                inps = inps.cuda()

            pytorch_infer_before_time = count_time()
            output = m(inps)
            pytorch_infer_after_time = count_time()
            pytorch_all_infer_time += pytorch_infer_after_time - pytorch_infer_before_time

            if opt.flip_test:
                if isinstance(inps, list):
                    inps_flip = [flip(inp).cuda() for inp in inps]
                else:
                    inps_flip = flip(inps).cuda()
                output_flip = flip_heatmap(m(inps_flip),
                                           gt_val_dataset.joint_pairs,
                                           shift=True)
                # note: pred_flip is computed here but never passed to
                # heatmap_to_coord in the loop below
                pred_flip = output_flip[:, eval_joints, :, :]
            else:
                output_flip = None

            pred = output
            assert pred.dim() == 4
            pred = pred[:, eval_joints, :, :]

        for i in range(pred.shape[0]):  # post-processing
            bbox = bboxes[i].tolist()
            if engine_file_path:
                pose_coords, pose_scores = heatmap_to_coord_simple(
                    pred[i], bbox, hm_shape=hm_size, norm_type=norm_type)
            else:
                pose_coords, pose_scores = heatmap_to_coord(
                    pred[i], bbox, hm_shape=hm_size, norm_type=norm_type)

            keypoints = np.concatenate((pose_coords, pose_scores), axis=1)
            keypoints = keypoints.reshape(-1).tolist()

            data = dict()
            data['bbox'] = bboxes[i].tolist()
            data['image_id'] = int(img_ids[i])
            data['score'] = float(np.mean(pose_scores) + np.max(pose_scores))
            data['category_id'] = 1
            data['keypoints'] = keypoints

            kpt_json.append(data)
    if engine_file_path:
        average_time = float(trt_all_infer_time / data_num)
    else:
        average_time = (pytorch_all_infer_time / data_num)

    print("average_time:", average_time)
    res_file = r"data/coco/res.json"
    with open(res_file, 'w') as F:
        json.dump(kpt_json, F)
    res = evaluate_mAP(res_file,
                       ann_type='keypoints',
                       ann_file=os.path.join(cfg.DATASET.VAL.ROOT,
                                             cfg.DATASET.VAL.ANN))
    return res, average_time
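
Example #8 assumes the TensorRT objects (context, bindings, inputs, outputs, stream) were created elsewhere, presumably with the common.py helper shipped with the TensorRT Python samples and vendored here as trt_common. A sketch of that missing setup; helper names vary across TensorRT versions, so treat it as an assumption:

import tensorrt as trt
import trt_common  # the TensorRT samples' common.py, vendored by this project

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def load_engine(engine_file_path):
    # Deserialize a prebuilt engine, create an execution context, and
    # allocate host/device buffers for every binding.
    with open(engine_file_path, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())
    context = engine.create_execution_context()
    inputs, outputs, bindings, stream = trt_common.allocate_buffers(engine)
    return engine, context, inputs, outputs, bindings, stream
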
Example #9
    def start(self):
        parser = argparse.ArgumentParser(description='AlphaPose Demo')
        parser.add_argument(
            '--cfg',
            type=str,
            required=False,
            help='experiment configure file name',
            default=
            "./AlphaPose/configs/coco/resnet/256x192_res50_lr1e-3_1x.yaml")
        parser.add_argument(
            '--checkpoint',
            type=str,
            required=False,
            help='checkpoint file name',
            default="./AlphaPose/pretrained_models/fast_res50_256x192.pth")
        parser.add_argument('--sp',
                            default=False,
                            action='store_true',
                            help='Use single process for pytorch')
        parser.add_argument('--detector',
                            dest='detector',
                            help='detector name',
                            default="yolo")
        parser.add_argument('--detfile',
                            dest='detfile',
                            help='detection result file',
                            default="")
        parser.add_argument('--indir',
                            dest='inputpath',
                            help='image-directory',
                            default="./media/img")
        parser.add_argument('--list',
                            dest='inputlist',
                            help='image-list',
                            default="")
        parser.add_argument('--image',
                            dest='inputimg',
                            help='image-name',
                            default="")
        parser.add_argument('--outdir',
                            dest='outputpath',
                            help='output-directory',
                            default="./output")
        parser.add_argument('--save_img',
                            default=True,
                            action='store_true',
                            help='save result as image')
        parser.add_argument('--vis',
                            default=False,
                            action='store_true',
                            help='visualize image')
        parser.add_argument('--showbox',
                            default=False,
                            action='store_true',
                            help='visualize human bbox')
        parser.add_argument('--profile',
                            default=False,
                            action='store_true',
                            help='add speed profiling at screen output')
        parser.add_argument(
            '--format',
            type=str,
            help=
            'save in the format of cmu or coco or openpose, option: coco/cmu/open',
            default="open")
        parser.add_argument('--min_box_area',
                            type=int,
                            default=0,
                            help='min box area to filter out')
        parser.add_argument('--detbatch',
                            type=int,
                            default=1,
                            help='detection batch size PER GPU')
        parser.add_argument('--posebatch',
                            type=int,
                            default=30,
                            help='pose estimation maximum batch size PER GPU')
        parser.add_argument(
            '--eval',
            dest='eval',
            default=False,
            action='store_true',
            help=
            'save the result json as coco format, using image index(int) instead of image name(str)'
        )
        parser.add_argument(
            '--gpus',
            type=str,
            dest='gpus',
            default="0",
            help=
            'choose which cuda device to use by index and input comma to use multi gpus, e.g. 0,1,2,3. (input -1 for cpu only)'
        )
        parser.add_argument(
            '--qsize',
            type=int,
            dest='qsize',
            default=1024,
            help=
            'the length of result buffer, where reducing it will lower requirement of cpu memory'
        )
        parser.add_argument('--flip',
                            default=False,
                            action='store_true',
                            help='enable flip testing')
        parser.add_argument('--debug',
                            default=False,
                            action='store_true',
                            help='print detail information')
        """----------------------------- Video options -----------------------------"""
        parser.add_argument('--video',
                            dest='video',
                            help='video-name',
                            default="")
        parser.add_argument('--webcam',
                            dest='webcam',
                            type=int,
                            help='webcam number',
                            default=-1)
        parser.add_argument('--save_video',
                            dest='save_video',
                            help='whether to save rendered video',
                            default=False,
                            action='store_true')
        parser.add_argument('--vis_fast',
                            dest='vis_fast',
                            help='use fast rendering',
                            action='store_true',
                            default=False)
        """----------------------------- Tracking options -----------------------------"""
        parser.add_argument('--pose_flow',
                            dest='pose_flow',
                            help='track humans in video with PoseFlow',
                            action='store_true',
                            default=False)
        parser.add_argument('--pose_track',
                            dest='pose_track',
                            help='track humans in video with reid',
                            action='store_true',
                            default=True)

        args = parser.parse_args()
        cfg = update_config(args.cfg)

        if platform.system() == 'Windows':
            args.sp = True

        args.gpus = ([int(i) for i in args.gpus.split(',')]
                     if torch.cuda.device_count() >= 1 else [-1])
        args.device = torch.device(
            "cuda:" + str(args.gpus[0]) if args.gpus[0] >= 0 else "cpu")
        args.detbatch = args.detbatch * len(args.gpus)
        args.posebatch = args.posebatch * len(args.gpus)
        args.tracking = args.pose_track or args.pose_flow or args.detector == 'tracker'

        if not args.sp:
            torch.multiprocessing.set_start_method('forkserver', force=True)
            torch.multiprocessing.set_sharing_strategy('file_system')

        def check_input():
            # for wecam
            if args.webcam != -1:
                args.detbatch = 1
                return 'webcam', int(args.webcam)

            # for video
            if len(args.video):
                if os.path.isfile(args.video):
                    videofile = args.video
                    return 'video', videofile
                else:
                    raise IOError(
                        'Error: --video must refer to a video file, not directory.'
                    )

            # for detection results
            if len(args.detfile):
                if os.path.isfile(args.detfile):
                    detfile = args.detfile
                    return 'detfile', detfile
                else:
                    raise IOError(
                        'Error: --detfile must refer to a detection json file, not directory.'
                    )

            # for images
            if len(args.inputpath) or len(args.inputlist) or len(
                    args.inputimg):
                inputpath = args.inputpath
                inputlist = args.inputlist
                inputimg = args.inputimg

                if len(inputlist):
                    im_names = open(inputlist, 'r').readlines()
                elif len(inputpath) and inputpath != '/':
                    for root, dirs, files in os.walk(inputpath):
                        im_names = files
                    im_names = natsort.natsorted(im_names)
                elif len(inputimg):
                    args.inputpath = os.path.split(inputimg)[0]
                    im_names = [os.path.split(inputimg)[1]]

                return 'image', im_names

            else:
                raise NotImplementedError

        def print_finish_info():
            print('===========================> Finish Model Running.')
            if (args.save_img or args.save_video) and not args.vis_fast:
                print(
                    '===========================> Rendering remaining images in the queue...'
                )
                print(
                    '===========================> If this step takes too long, you can enable the --vis_fast flag to use fast rendering (real-time).'
                )

        def loop():
            n = 0
            while True:
                yield n
                n += 1

        # dirList = os.listdir(args.inputpath)
        # inDir = args.inputpath
        # outDir = args.outputpath
        # for i in dirList :
        mode, input_source = check_input()
        if not os.path.exists(args.outputpath):
            os.makedirs(args.outputpath)

        # Load detection loader
        if mode == 'webcam':
            det_loader = WebCamDetectionLoader(input_source,
                                               get_detector(args), cfg, args)
            det_worker = det_loader.start()
        elif mode == 'detfile':
            det_loader = FileDetectionLoader(input_source, cfg, args)
            det_worker = det_loader.start()
        else:
            det_loader = DetectionLoader(input_source,
                                         get_detector(args),
                                         cfg,
                                         args,
                                         batchSize=args.detbatch,
                                         mode=mode,
                                         queueSize=args.qsize)
            det_worker = det_loader.start()

        # Load pose model
        pose_model = builder.build_sppe(cfg.MODEL, preset_cfg=cfg.DATA_PRESET)

        print(f'Loading pose model from {args.checkpoint}...')
        pose_model.load_state_dict(
            torch.load(args.checkpoint, map_location=args.device))
        pose_dataset = builder.retrieve_dataset(cfg.DATASET.TRAIN)
        if args.pose_track:
            tracker = Tracker(tcfg, args)
        if len(args.gpus) > 1:
            pose_model = torch.nn.DataParallel(pose_model,
                                               device_ids=args.gpus).to(
                                                   args.device)
        else:
            pose_model.to(args.device)
        pose_model.eval()

        runtime_profile = {'dt': [], 'pt': [], 'pn': []}

        # Init data writer
        queueSize = 2 if mode == 'webcam' else args.qsize
        if args.save_video and mode != 'image':
            from alphapose.utils.writer import DEFAULT_VIDEO_SAVE_OPT as video_save_opt
            if mode == 'video':
                video_save_opt['savepath'] = os.path.join(
                    args.outputpath,
                    'AlphaPose_' + os.path.basename(input_source))
            else:
                video_save_opt['savepath'] = os.path.join(
                    args.outputpath,
                    'AlphaPose_webcam' + str(input_source) + '.mp4')
            video_save_opt.update(det_loader.videoinfo)
            writer = DataWriter(cfg,
                                args,
                                save_video=True,
                                video_save_opt=video_save_opt,
                                queueSize=queueSize).start()
        else:
            writer = DataWriter(cfg,
                                args,
                                save_video=False,
                                queueSize=queueSize).start()

        if mode == 'webcam':
            print('Starting webcam demo, press Ctrl + C to terminate...')
            sys.stdout.flush()
            im_names_desc = tqdm(loop())
        else:
            data_len = det_loader.length
            im_names_desc = tqdm(range(data_len), dynamic_ncols=True)

        batchSize = args.posebatch
        if args.flip:
            batchSize = int(batchSize / 2)
        try:
            self.percentage[2] = 'Analyzing joint information'
            for i in range(len(im_names_desc)):
                start_time = getTime()
                # print(start_time)
                self.percentage[0] += 1
                #
                with torch.no_grad():
                    (inps, orig_img, im_name, boxes, scores, ids,
                     cropped_boxes) = det_loader.read()
                    if orig_img is None:
                        break
                    if boxes is None or boxes.nelement() == 0:
                        writer.save(None, None, None, None, None, orig_img,
                                    im_name)
                        continue
                    if args.profile:
                        ckpt_time, det_time = getTime(start_time)
                        runtime_profile['dt'].append(det_time)
                    # Pose Estimation
                    inps = inps.to(args.device)
                    datalen = inps.size(0)
                    leftover = 0
                    if (datalen) % batchSize:
                        leftover = 1
                    num_batches = datalen // batchSize + leftover
                    hm = []
                    for j in range(num_batches):
                        inps_j = inps[j * batchSize:min((j + 1) *
                                                        batchSize, datalen)]
                        if args.flip:
                            inps_j = torch.cat((inps_j, flip(inps_j)))
                        hm_j = pose_model(inps_j)
                        if args.flip:
                            hm_j_flip = flip_heatmap(hm_j[int(len(hm_j) / 2):],
                                                     pose_dataset.joint_pairs,
                                                     shift=True)
                            hm_j = (hm_j[0:int(len(hm_j) / 2)] + hm_j_flip) / 2
                        hm.append(hm_j)
                    hm = torch.cat(hm)
                    if args.profile:
                        ckpt_time, pose_time = getTime(ckpt_time)
                        runtime_profile['pt'].append(pose_time)
                    if args.pose_track:
                        boxes, scores, ids, hm, cropped_boxes = track(
                            tracker, args, orig_img, inps, boxes, hm,
                            cropped_boxes, im_name, scores)
                    hm = hm.cpu()
                    writer.save(boxes, scores, ids, hm, cropped_boxes,
                                orig_img, im_name)
                    if args.profile:
                        ckpt_time, post_time = getTime(ckpt_time)
                        runtime_profile['pn'].append(post_time)

                if args.profile:
                    # TQDM
                    im_names_desc.set_description(
                        'det time: {dt:.4f}  | pose time: {pt:.4f} | post processing: {pn:.4f}'
                        .format(dt=np.mean(runtime_profile['dt']),
                                pt=np.mean(runtime_profile['pt']),
                                pn=np.mean(runtime_profile['pn'])))
            print_finish_info()
            print("마무리 작업중...")
            while (writer.running()):
                time.sleep(1)
                print('===========================> Rendering remaining ' +
                      str(writer.count()) + ' images in the queue...')
            writer.stop()
            det_loader.stop()
            print("작업종료")
        except Exception as e:
            print(repr(e))
            print(
                'An error as above occurs when processing the images, please check it'
            )
            pass
        except KeyboardInterrupt:
            print_finish_info()
            # Thread won't be killed when press Ctrl+C
            if args.sp:
                det_loader.terminate()
                while (writer.running()):
                    time.sleep(1)
                    print('===========================> Rendering remaining ' +
                          str(writer.count()) + ' images in the queue...')
                writer.stop()
            else:
                # subprocesses are killed, manually clear queues

                det_loader.terminate()
                writer.terminate()
                writer.clear_queues()
                det_loader.clear_queues()
Example #10
def validate(m, heatmap_to_coord, batch_size=20):
    det_dataset = builder.build_dataset(cfg.DATASET.TEST,
                                        preset_cfg=cfg.DATA_PRESET,
                                        train=False,
                                        opt=opt)
    eval_joints = det_dataset.EVAL_JOINTS

    det_loader = torch.utils.data.DataLoader(det_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=20,
                                             drop_last=False)
    kpt_json = []
    m.eval()

    norm_type = cfg.LOSS.get('NORM_TYPE', None)
    hm_size = cfg.DATA_PRESET.HEATMAP_SIZE
    combined_loss = (cfg.LOSS.get('TYPE') == 'Combined')

    halpe = cfg.DATA_PRESET.NUM_JOINTS in (133, 136)

    for inps, crop_bboxes, bboxes, img_ids, scores, imghts, imgwds in tqdm(
            det_loader, dynamic_ncols=True):
        if isinstance(inps, list):
            inps = [inp.cuda() for inp in inps]
        else:
            inps = inps.cuda()
        output = m(inps)
        if opt.flip_test:
            if isinstance(inps, list):
                inps_flip = [flip(inp).cuda() for inp in inps]
            else:
                inps_flip = flip(inps).cuda()
            output_flip = flip_heatmap(m(inps_flip),
                                       det_dataset.joint_pairs,
                                       shift=True)
            pred_flip = output_flip[:, eval_joints, :, :]
        else:
            output_flip = None
            pred_flip = None

        pred = output
        assert pred.dim() == 4
        pred = pred[:, eval_joints, :, :]

        # channel split for the Combined loss: the leading channels are
        # body/foot joints, the remainder face/hand (e.g. Halpe-136 =
        # 26 body/foot + 68 face + 42 hand keypoints, hence 110)
        if output.size()[1] == 68:
            face_hand_num = 42
        else:
            face_hand_num = 110

        for i in range(output.shape[0]):
            bbox = crop_bboxes[i].tolist()
            if combined_loss:
                pose_coords_body_foot, pose_scores_body_foot = heatmap_to_coord[0](
                    pred[i][det_dataset.EVAL_JOINTS[:-face_hand_num]],
                    bbox,
                    hm_shape=hm_size,
                    norm_type=norm_type,
                    hms_flip=pred_flip[i][det_dataset.EVAL_JOINTS[:-face_hand_num]]
                    if pred_flip is not None else None)
                pose_coords_face_hand, pose_scores_face_hand = heatmap_to_coord[1](
                    pred[i][det_dataset.EVAL_JOINTS[-face_hand_num:]],
                    bbox,
                    hm_shape=hm_size,
                    norm_type=norm_type,
                    hms_flip=pred_flip[i][det_dataset.EVAL_JOINTS[-face_hand_num:]]
                    if pred_flip is not None else None)
                pose_coords = np.concatenate(
                    (pose_coords_body_foot, pose_coords_face_hand), axis=0)
                pose_scores = np.concatenate(
                    (pose_scores_body_foot, pose_scores_face_hand), axis=0)
            else:
                pose_coords, pose_scores = heatmap_to_coord(
                    pred[i][det_dataset.EVAL_JOINTS],
                    bbox,
                    hm_shape=hm_size,
                    norm_type=norm_type,
                    hms_flip=pred_flip[i][det_dataset.EVAL_JOINTS]
                    if pred_flip is not None else None)

            keypoints = np.concatenate((pose_coords, pose_scores), axis=1)
            keypoints = keypoints.reshape(-1).tolist()

            data = dict()
            data['bbox'] = bboxes[i, 0].tolist()
            data['image_id'] = int(img_ids[i])
            data['area'] = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
            data['score'] = float(scores[i] + np.mean(pose_scores) +
                                  1.25 * np.max(pose_scores))
            # data['score'] = float(scores[i])
            data['category_id'] = 1
            data['keypoints'] = keypoints

            kpt_json.append(data)

    if opt.ppose_nms:
        from alphapose.utils.pPose_nms import ppose_nms_validate_preprocess, pose_nms, write_json
        final_result = []
        tmp_data = ppose_nms_validate_preprocess(kpt_json)
        for key in tmp_data:
            boxes, scores, ids, preds_img, preds_scores = tmp_data[key]
            boxes, scores, ids, preds_img, preds_scores, pick_ids = \
                pose_nms(boxes, scores, ids, preds_img, preds_scores, 0,
                         cfg.LOSS.get('TYPE') == 'MSELoss')

            _result = []
            for k in range(len(scores)):
                _result.append({
                    'keypoints': preds_img[k],
                    'kp_score': preds_scores[k],
                    'proposal_score': torch.mean(preds_scores[k]) + scores[k] +
                    1.25 * max(preds_scores[k]),
                    'idx': ids[k],
                    'box': [boxes[k][0], boxes[k][1],
                            boxes[k][2] - boxes[k][0],
                            boxes[k][3] - boxes[k][1]]
                })
            im_name = str(key).zfill(12) + '.jpg'
            result = {'imgname': im_name, 'result': _result}
            final_result.append(result)

        write_json(final_result,
                   './exp/json/',
                   form='coco',
                   for_eval=True,
                   outputfile='validate_rcnn_kpt.json')
    else:
        if opt.oks_nms:
            from alphapose.utils.pPose_nms import oks_pose_nms
            kpt_json = oks_pose_nms(kpt_json)

        with open('./exp/json/validate_rcnn_kpt.json', 'w') as fid:
            json.dump(kpt_json, fid)

    sysout = sys.stdout
    res = evaluate_mAP('./exp/json/validate_rcnn_kpt.json',
                       ann_type='keypoints',
                       ann_file=os.path.join(cfg.DATASET.TEST.ROOT,
                                             cfg.DATASET.TEST.ANN),
                       halpe=halpe)
    sys.stdout = sysout
    return res
Example #11
def validate_gt(m, cfg, heatmap_to_coord, batch_size=20):
    gt_val_dataset = builder.build_dataset(cfg.DATASET.VAL,
                                           preset_cfg=cfg.DATA_PRESET,
                                           train=False)
    eval_joints = gt_val_dataset.EVAL_JOINTS

    gt_val_loader = torch.utils.data.DataLoader(gt_val_dataset,
                                                batch_size=batch_size,
                                                shuffle=False,
                                                num_workers=20,
                                                drop_last=False)
    kpt_json = []
    m.eval()

    norm_type = cfg.LOSS.get('NORM_TYPE', None)
    hm_size = cfg.DATA_PRESET.HEATMAP_SIZE
    combined_loss = (cfg.LOSS.get('TYPE') == 'Combined')

    halpe = cfg.DATA_PRESET.NUM_JOINTS in (133, 136)

    for inps, labels, label_masks, img_ids, bboxes in tqdm(gt_val_loader,
                                                           dynamic_ncols=True):
        if isinstance(inps, list):
            inps = [inp.cuda() for inp in inps]
        else:
            inps = inps.cuda()
        output = m(inps)
        if opt.flip_test:
            if isinstance(inps, list):
                inps_flip = [flip(inp).cuda() for inp in inps]
            else:
                inps_flip = flip(inps).cuda()
            output_flip = flip_heatmap(m(inps_flip),
                                       gt_val_dataset.joint_pairs,
                                       shift=True)
            pred_flip = output_flip[:, eval_joints, :, :]
        else:
            output_flip = None
            pred_flip = None

        pred = output
        assert pred.dim() == 4
        pred = pred[:, eval_joints, :, :]

        if output.size()[1] == 68:
            face_hand_num = 42
        else:
            face_hand_num = 110

        for i in range(output.shape[0]):
            bbox = bboxes[i].tolist()
            if combined_loss:
                pose_coords_body_foot, pose_scores_body_foot = heatmap_to_coord[0](
                    pred[i][gt_val_dataset.EVAL_JOINTS[:-face_hand_num]],
                    bbox,
                    hm_shape=hm_size,
                    norm_type=norm_type,
                    hms_flip=pred_flip[i][gt_val_dataset.EVAL_JOINTS[:-face_hand_num]]
                    if pred_flip is not None else None)
                pose_coords_face_hand, pose_scores_face_hand = heatmap_to_coord[1](
                    pred[i][gt_val_dataset.EVAL_JOINTS[-face_hand_num:]],
                    bbox,
                    hm_shape=hm_size,
                    norm_type=norm_type,
                    hms_flip=pred_flip[i][gt_val_dataset.EVAL_JOINTS[-face_hand_num:]]
                    if pred_flip is not None else None)
                pose_coords = np.concatenate(
                    (pose_coords_body_foot, pose_coords_face_hand), axis=0)
                pose_scores = np.concatenate(
                    (pose_scores_body_foot, pose_scores_face_hand), axis=0)
            else:
                pose_coords, pose_scores = heatmap_to_coord(
                    pred[i][gt_val_dataset.EVAL_JOINTS],
                    bbox,
                    hm_shape=hm_size,
                    norm_type=norm_type,
                    hms_flip=pred_flip[i][gt_val_dataset.EVAL_JOINTS]
                    if pred_flip is not None else None)

            keypoints = np.concatenate((pose_coords, pose_scores), axis=1)
            keypoints = keypoints.reshape(-1).tolist()

            data = dict()
            data['bbox'] = bboxes[i].tolist()
            data['image_id'] = int(img_ids[i])
            data['score'] = float(
                np.mean(pose_scores) + 1.25 * np.max(pose_scores))
            data['category_id'] = 1
            data['keypoints'] = keypoints

            kpt_json.append(data)

    sysout = sys.stdout
    with open('./exp/json/validate_gt_kpt.json', 'w') as fid:
        json.dump(kpt_json, fid)
    res = evaluate_mAP('./exp/json/validate_gt_kpt.json',
                       ann_type='keypoints',
                       ann_file=os.path.join(cfg.DATASET.VAL.ROOT,
                                             cfg.DATASET.VAL.ANN),
                       halpe=halpe)
    sys.stdout = sysout
    return res
Example #12
    def predict(self, image, img_name):
        args = self.args
        # Load detection loader
        det_loader = DetectionLoader(self.input_source, [img_name], [image],
                                     get_detector(args),
                                     self.cfg,
                                     args,
                                     batchSize=args.detbatch,
                                     mode=self.mode).start()

        # Init data writer
        queueSize = args.qsize
        self.writer = DataWriter(self.cfg,
                                 args,
                                 save_video=False,
                                 queueSize=queueSize).start()

        runtime_profile = {'dt': [], 'pt': [], 'pn': []}

        data_len = det_loader.length
        im_names_desc = tqdm(range(data_len), dynamic_ncols=True)

        batchSize = args.posebatch
        if args.flip:
            batchSize = int(batchSize / 2)

        try:
            for i in im_names_desc:
                start_time = getTime()
                with torch.no_grad():
                    (inps, orig_img, im_name, boxes, scores, ids,
                     cropped_boxes) = det_loader.read()
                    if orig_img is None:
                        break
                    if boxes is None or boxes.nelement() == 0:
                        self.writer.save(None, None, None, None, None,
                                         orig_img, os.path.basename(im_name))
                        continue
                    if args.profile:
                        ckpt_time, det_time = getTime(start_time)
                        runtime_profile['dt'].append(det_time)
                    # Pose Estimation
                    inps = inps.to(args.device)
                    datalen = inps.size(0)
                    leftover = 0
                    if (datalen) % batchSize:
                        leftover = 1
                    num_batches = datalen // batchSize + leftover
                    hm = []
                    for j in range(num_batches):
                        inps_j = inps[j * batchSize:min((j + 1) *
                                                        batchSize, datalen)]
                        if args.flip:
                            inps_j = torch.cat((inps_j, flip(inps_j)))
                        hm_j = self.pose_model(inps_j)
                        if args.flip:
                            hm_j_flip = flip_heatmap(hm_j[int(len(hm_j) / 2):],
                                                     det_loader.joint_pairs,
                                                     shift=True)
                            hm_j = (hm_j[0:int(len(hm_j) / 2)] + hm_j_flip) / 2
                        hm.append(hm_j)
                    hm = torch.cat(hm)
                    if args.profile:
                        ckpt_time, pose_time = getTime(ckpt_time)
                        runtime_profile['pt'].append(pose_time)
                    hm = hm.cpu()
                    self.writer.save(boxes, scores, ids, hm, cropped_boxes,
                                     orig_img, os.path.basename(im_name))

                    if args.profile:
                        ckpt_time, post_time = getTime(ckpt_time)
                        runtime_profile['pn'].append(post_time)

                if args.profile:
                    # TQDM
                    im_names_desc.set_description(
                        'det time: {dt:.4f} | pose time: {pt:.4f} | post processing: {pn:.4f}'
                        .format(dt=np.mean(runtime_profile['dt']),
                                pt=np.mean(runtime_profile['pt']),
                                pn=np.mean(runtime_profile['pn'])))
            while (self.writer.running()):
                time.sleep(1)
                print('===========================> Rendering remaining ' +
                      str(self.writer.count()) + ' images in the queue...')
            self.writer.stop()
            det_loader.stop()
        except KeyboardInterrupt:
            self.print_finish_info(args)
            # Thread won't be killed when press Ctrl+C
            if args.sp:
                det_loader.terminate()
                while (self.writer.running()):
                    time.sleep(1)
                    print('===========================> Rendering remaining ' +
                          str(self.writer.count()) + ' images in the queue...')
                self.writer.stop()
            else:
                # subprocesses are killed, manually clear queues
                self.writer.commit()
                self.writer.clear_queues()
                # det_loader.clear_queues()
        final_result = self.writer.results()
        return write_json(final_result,
                          args.outputpath,
                          form=args.format,
                          for_eval=args.eval)