Example #1
0
def worker(inputs, results, gpu, detection_cfg, estimation_cfg):
    """Consume frames from ``inputs`` and publish pose results to ``results``.

    The estimator is created once per worker and cached in the module-level
    ``pose_estimators`` mapping, keyed by the first component of the current
    process's ``_identity`` tuple minus one.

    Args:
        inputs: Queue yielding ``(frame_index, image)`` pairs. An item whose
            image is ``None`` is the end signal and makes the worker return.
        results: Queue receiving each inference result after its
            ``'frame_index'`` entry has been set.
        gpu: Device passed to ``init_pose_estimator``.
        detection_cfg: Detection config passed to ``init_pose_estimator``.
        estimation_cfg: Estimation config passed to ``init_pose_estimator``.
    """
    global pose_estimators

    slot = current_process()._identity[0] - 1
    if slot not in pose_estimators:
        pose_estimators[slot] = init_pose_estimator(
            detection_cfg, estimation_cfg, device=gpu)

    while True:
        frame_index, frame = inputs.get()
        if frame is None:
            # End signal: stop consuming.
            return

        prediction = inference_pose_estimator(pose_estimators[slot], frame)
        prediction['frame_index'] = frame_index
        results.put(prediction)
Example #2
0
def worker(inputs, results, gpu, detection_cfg, estimation_cfg, render_image):
    """Drain ``inputs`` and publish one pose result per frame to ``results``.

    The estimator is created once per worker and cached in the module-level
    ``pose_estimators`` mapping, keyed by the first component of the current
    process's ``_identity`` tuple minus one. The worker returns as soon as
    the input queue is observed to be empty.

    Args:
        inputs: Queue of ``(frame_index, image)`` pairs, filled before the
            worker starts.
        results: Queue receiving each inference result after its
            ``'frame_index'`` entry has been set.
        gpu: Device passed to ``init_pose_estimator``.
        detection_cfg: Detection config; ``bbox_thre`` is read when rendering.
        estimation_cfg: Estimation config passed to ``init_pose_estimator``.
        render_image: When truthy, the rendered frame is stored in the result
            under ``'render_image'``.
    """
    # Local import so this block does not depend on the file's import header.
    from queue import Empty

    global pose_estimators

    worker_id = current_process()._identity[0] - 1
    if worker_id not in pose_estimators:
        pose_estimators[worker_id] = init_pose_estimator(detection_cfg,
                                                         estimation_cfg,
                                                         device=gpu)
    while not inputs.empty():
        try:
            idx, image = inputs.get_nowait()
        except Empty:
            # Another worker took the last item between empty() and
            # get_nowait(); nothing is left to do. Any other exception is a
            # real error and is allowed to propagate instead of being hidden.
            return

        res = inference_pose_estimator(pose_estimators[worker_id], image)
        res['frame_index'] = idx

        if render_image:
            res['render_image'] = render(image, res['joint_preds'],
                                         res['person_bbox'],
                                         detection_cfg.bbox_thre)
        results.put(res)
Example #3
0
def inference(detection_cfg,
              estimation_cfg,
              video_file,
              gpus=1,
              worker_per_gpu=1,
              save_dir=None):
    """Run pose estimation on every frame of a video.

    With ``gpus == 1`` and ``worker_per_gpu == 1`` the frames are processed
    in this process on device 0. Otherwise ``gpus * worker_per_gpu`` worker
    processes are started and fed through manager queues; worker ``i`` is
    given device ``i % gpus``.

    Args:
        detection_cfg: Detection config; ``checkpoint_file`` and
            ``bbox_thre`` are read here.
        estimation_cfg: Estimation config; ``checkpoint_file`` is read here.
        video_file: Path of the input video.
        gpus: Number of devices to spread the workers over.
        worker_per_gpu: Number of worker processes per device.
        save_dir: If not ``None``, every frame is rendered and the rendered
            video is written into this directory under the input's file name.

    Returns:
        The list of per-frame results, sorted by ``'frame_index'``.
    """
    video_frames = mmcv.VideoReader(video_file)
    num_frames = len(video_frames)
    need_render = save_dir is not None
    all_result = []
    print('\nPose estimation:')

    # case for single process
    if gpus == 1 and worker_per_gpu == 1:
        model = init_pose_estimator(detection_cfg, estimation_cfg, device=0)
        prog_bar = ProgressBar(num_frames)
        for i, image in enumerate(video_frames):
            res = inference_pose_estimator(model, image)
            res['frame_index'] = i
            if need_render:
                res['render_image'] = render(image, res['joint_preds'],
                                             res['person_bbox'],
                                             detection_cfg.bbox_thre)
            all_result.append(res)
            prog_bar.update()

    # case for multi-process
    else:
        cache_checkpoint(detection_cfg.checkpoint_file)
        cache_checkpoint(estimation_cfg.checkpoint_file)
        num_worker = gpus * worker_per_gpu

        # One manager serves both queues; each Manager() call would start
        # its own server process.
        manager = Manager()
        inputs = manager.Queue(num_frames)
        results = manager.Queue(num_frames)

        for i, image in enumerate(video_frames):
            inputs.put((i, image))

        procs = []
        for i in range(num_worker):
            p = Process(target=worker,
                        args=(inputs, results, i % gpus, detection_cfg,
                              estimation_cfg, need_render))
            procs.append(p)
            p.start()

        # The progress bar is created only once the first result has
        # arrived, i.e. after results.get() has returned for the first time.
        prog_bar = None
        for _ in range(num_frames):
            all_result.append(results.get())
            if prog_bar is None:
                prog_bar = ProgressBar(num_frames)
            prog_bar.update()

        for p in procs:
            p.join()

    # sort results
    all_result.sort(key=lambda x: x['frame_index'])

    # generate video
    if len(all_result) == num_frames and need_render:
        print('\n\nGenerate video:')
        # str.strip('/n') would remove the characters '/' and 'n' from both
        # ends and corrupt names such as 'data/run'. Strip whitespace and a
        # trailing separator instead, then keep the last path component.
        video_name = os.path.basename(video_file.strip().rstrip('/'))
        video_path = os.path.join(save_dir, video_name)
        vwriter = cv2.VideoWriter(video_path,
                                  mmcv.video.io.VideoWriter_fourcc(*('mp4v')),
                                  video_frames.fps, video_frames.resolution)
        prog_bar = ProgressBar(num_frames)
        try:
            for r in all_result:
                vwriter.write(r['render_image'])
                prog_bar.update()
        finally:
            # Release the writer even if writing a frame fails.
            vwriter.release()
        print('\nVideo was saved to {}'.format(video_path))

    return all_result
Example #4
0
def realtime_detect(detection_cfg,
                    estimation_cfg,
                    model_cfg,
                    dataset_cfg,
                    tracker_cfg,
                    video_dir,
                    category_annotation,
                    checkpoint,
                    batch_size=64,
                    gpus=1,
                    workers=4):
    """Estimate poses and classify the action frame by frame, showing each frame.

    For every frame with detections, the person annotations collected so far
    are passed through ``data_parse`` and the recognition model, the frame is
    rendered with the resulting label, and the image is shown in an OpenCV
    window.

    Args:
        detection_cfg: Detection config; ``bbox_thre`` is read when rendering.
        estimation_cfg: Estimation config passed to ``init_pose_estimator``.
        model_cfg: A single config dict, or a list of them that is wrapped in
            ``torch.nn.Sequential``.
        dataset_cfg: Dataset config; ``pipeline`` and
            ``data_source.num_track`` are passed to ``data_parse``.
        tracker_cfg: Not used by this function.
        video_dir: Directory joined with the hard-coded relative video path.
        category_annotation: Path of a JSON file with ``'annotations'`` and
            ``'categories'`` entries, or ``None``.
        checkpoint: Checkpoint loaded into the recognition model.
        batch_size: Not used by this function.
        gpus: Number of device ids handed to ``MMDataParallel``.
        workers: Not used by this function.
    """
    # Initialise the models.
    pose_estimators = init_pose_estimator(detection_cfg,
                                          estimation_cfg,
                                          device=0)
    if isinstance(model_cfg, list):
        model = [call_obj(**c) for c in model_cfg]
        model = torch.nn.Sequential(*model)
    else:
        model = call_obj(**model_cfg)
    load_checkpoint(model, checkpoint, map_location='cpu')
    model = MMDataParallel(model, device_ids=range(gpus)).cuda()
    model.eval()

    # Fetch the frames.
    video_file = 'train/clean/clean10.avi'
    reader = mmcv.VideoReader(os.path.join(video_dir, video_file))
    video_frames = reader[:10000]

    if category_annotation is None:
        video_categories = dict()
        # No class names available; bound explicitly so the labelling step
        # below cannot hit an undefined name.
        action_class = None
    else:
        with open(category_annotation) as f:
            json_file = json.load(f)
            video_categories = json_file['annotations']
            action_class = json_file['categories']

    # The category depends only on the video, so look it up once.
    category_id = video_categories[video_file][
        'category_id'] if video_file in video_categories else -1

    annotations = []
    num_keypoints = -1
    for i, image in enumerate(video_frames):
        res = inference_pose_estimator(pose_estimators, image)
        res['frame_index'] = i
        if not res['has_return']:
            continue
        num_person = len(res['joint_preds'])
        assert len(res['person_bbox']) == num_person

        for j in range(num_person):
            keypoints = [[p[0], p[1], round(s[0], 2)] for p, s in zip(
                res['joint_preds'][j].round().astype(int).tolist(),
                res['joint_scores'][j].tolist())]
            num_keypoints = len(keypoints)
            person_info = dict(
                person_bbox=res['person_bbox'][j].round().astype(int).tolist(),
                frame_index=res['frame_index'],
                id=j,
                person_id=None,
                keypoints=keypoints)
            annotations.append(person_info)

        # `annotations` keeps growing, so every frame is classified using
        # all people seen up to and including this frame.
        info = dict(video_name=video_file,
                    resolution=reader.resolution,
                    num_frame=len(video_frames),
                    num_keypoints=num_keypoints,
                    keypoint_channels=['x', 'y', 'score'],
                    version='1.0')
        video_info = dict(info=info,
                          category_id=category_id,
                          annotations=annotations)

        data_loader = data_parse(video_info, dataset_cfg.pipeline,
                                 dataset_cfg.data_source.num_track)
        data, _ = data_loader
        with torch.no_grad():
            data = torch.from_numpy(data)
            # Add a leading dimension that acts as the batch dimension.
            data = data.unsqueeze(0)
            data = data.float().to("cuda:0").detach()
            output = model(data).data.cpu().numpy()
        top1 = output.argmax()
        if output[:, top1] > 3:
            # NOTE(review): falling back to the class index when no category
            # names were loaded is a choice made here - confirm that render()
            # accepts such a label.
            label = (action_class[top1]
                     if action_class is not None else str(top1))
        else:
            label = 'unknow'
        print("reslt:", output)

        res['render_image'] = render(image, res['joint_preds'], label,
                                     res['person_bbox'],
                                     detection_cfg.bbox_thre)
        cv2.imshow('image', image)
        cv2.waitKey(10)
Example #5
0
def build(inputs,
          detection_cfg,
          estimation_cfg,
          tracker_cfg,
          video_dir,
          gpus=1,
          video_max_length=10000,
          category_annotation=None):
    """Extract pose annotations from every video found under ``video_dir``.

    One ``video_info`` dict (``info``, ``category_id``, ``annotations``) is
    put on ``inputs`` per video.

    Args:
        inputs: Object with a ``put`` method that receives each video's info.
        detection_cfg: Detection config; ``checkpoint_file`` is cached.
        estimation_cfg: Estimation config; ``checkpoint_file`` is cached.
        tracker_cfg: Must be ``None``.
        video_dir: Directory handed to ``get_all_file`` to collect videos.
        gpus: Not used by this function.
        video_max_length: Upper bound on the number of frames read per video.
        category_annotation: Optional path of a JSON file whose
            ``'annotations'`` entry maps video file names to dicts with a
            ``'category_id'``.

    Raises:
        NotImplementedError: If ``tracker_cfg`` is not ``None``.
    """
    print('data build start')
    cache_checkpoint(detection_cfg.checkpoint_file)
    cache_checkpoint(estimation_cfg.checkpoint_file)

    video_categories = dict()
    if category_annotation is not None:
        with open(category_annotation) as f:
            video_categories = json.load(f)['annotations']

    if tracker_cfg is not None:
        raise NotImplementedError

    pose_estimators = init_pose_estimator(detection_cfg,
                                          estimation_cfg,
                                          device=0)

    video_file_list = []
    get_all_file(video_dir, video_file_list)

    prog_bar = ProgressBar(len(video_file_list))
    for video_path in video_file_list:
        video_file = os.path.basename(video_path)
        reader = mmcv.VideoReader(video_path)
        video_frames = reader[:video_max_length]

        annotations = []
        num_keypoints = -1
        for frame_index, image in enumerate(video_frames):
            res = inference_pose_estimator(pose_estimators, image)
            res['frame_index'] = frame_index
            if not res['has_return']:
                # Frame produced no result; nothing to annotate.
                continue

            joint_preds = res['joint_preds']
            person_bboxes = res['person_bbox']
            num_person = len(joint_preds)
            assert len(person_bboxes) == num_person

            for person_idx in range(num_person):
                coords = joint_preds[person_idx].round().astype(int).tolist()
                scores = res['joint_scores'][person_idx].tolist()
                # One [x, y, score] triple per joint.
                keypoints = [[xy[0], xy[1], round(score[0], 2)]
                             for xy, score in zip(coords, scores)]
                num_keypoints = len(keypoints)
                bbox = person_bboxes[person_idx].round().astype(int).tolist()
                annotations.append({
                    'person_bbox': bbox,
                    'frame_index': res['frame_index'],
                    'id': person_idx,
                    'person_id': None,
                    'keypoints': keypoints,
                })

        annotations.sort(key=lambda x: x['frame_index'])

        # Videos missing from the annotation file get category -1.
        if video_file in video_categories:
            category_id = video_categories[video_file]['category_id']
        else:
            category_id = -1

        info = {
            'video_name': video_file,
            'resolution': reader.resolution,
            'num_frame': len(video_frames),
            'num_keypoints': num_keypoints,
            'keypoint_channels': ['x', 'y', 'score'],
            'version': '1.0',
        }
        video_info = {
            'info': info,
            'category_id': category_id,
            'annotations': annotations,
        }
        inputs.put(video_info)
        prog_bar.update()