Exemple #1
0
def get_pose_estimation_prediction(cfg, model, image, vis_thre, transforms):
    """Run multi-scale pose inference on one image and keep confident poses.

    Args:
        cfg: Experiment config. Reads ``DATASET.INPUT_SIZE``,
            ``DATASET.NUM_JOINTS``, ``TEST.SCALE_FACTOR`` and
            ``TEST.FLIP_TEST``; it is also forwarded to the helpers.
        model: Network handed to ``get_multi_stage_outputs``.
        image: Input image array. Only ``image.shape[0]`` and
            ``image.shape[1]`` are read directly here (presumably an
            H x W x C array -- confirm against the caller).
        vis_thre: Score threshold; a pose is kept only when its score is
            strictly greater than this value.
        transforms: Callable applied to every resized image. Its result must
            support ``.unsqueeze(0).cuda()``.

    Returns:
        list: The entries returned by ``get_final_preds`` whose score exceeds
        ``vis_thre``. Empty when the parser yields no scores or when nothing
        passes the threshold.
    """
    # Reference size at scale 1.0; the fused heatmap is resampled to it.
    base_size, center, scale = get_multi_scale_size(
        image, cfg.DATASET.INPUT_SIZE, 1.0, 1.0
    )

    parser = HeatmapRegParser(cfg)

    with torch.no_grad():
        final_heatmaps = None
        final_kpts = None
        input_size = cfg.DATASET.INPUT_SIZE

        # Largest scale factor first.
        for factor in sorted(cfg.TEST.SCALE_FACTOR, reverse=True):
            # Demo mode has no annotations: pass empty joints and a zero mask.
            dummy_joints = np.zeros((0, cfg.DATASET.NUM_JOINTS, 3))
            dummy_mask = np.zeros((image.shape[0], image.shape[1]))
            # center/scale are rebound on every iteration, so the values used
            # by get_final_preds below come from the last processed factor.
            resized, _, _, center, scale = resize_align_multi_scale(
                image, dummy_joints, dummy_mask, input_size, factor, 1.0
            )

            batch = transforms(resized).unsqueeze(0).cuda()

            _, heatmaps, kpts = get_multi_stage_outputs(
                cfg, model, batch, cfg.TEST.FLIP_TEST
            )
            final_heatmaps, final_kpts = aggregate_results(
                cfg, final_heatmaps, final_kpts, heatmaps, kpts
            )

        # Average the aggregated heatmaps after resampling to the base size.
        heatmap_fuse = 0
        for heatmap in final_heatmaps:
            heatmap_fuse += up_interpolate(
                heatmap,
                size=(base_size[1], base_size[0]),
                mode='bilinear'
            )
        heatmap_fuse = heatmap_fuse / float(len(final_heatmaps))

        # Decode from the predicted keypoints only (use_heatmap=False).
        grouped, scores = parser.parse(
            final_heatmaps, final_kpts, heatmap_fuse[0], use_heatmap=False
        )
        if len(scores) == 0:
            return []

        fused_size = [heatmap_fuse.size(-1), heatmap_fuse.size(-2)]
        results = get_final_preds(grouped, center, scale, fused_size)

        # Index into results (rather than zip) so a length mismatch still
        # surfaces as an IndexError.
        return [
            results[i] for i, score in enumerate(scores) if score > vis_thre
        ]
Exemple #2
0
def aggregate_results(cfg, heatmap_sum, poses, heatmap, posemap, scale):
    """
    Fold the predictions of one scale into the multi-scale accumulators.

    The heatmap is resampled by ``INPUT_SIZE / OUTPUT_SIZE / scale`` and added
    to ``heatmap_sum`` with ``+=``. Pose proposals are read from ``posemap``
    at the indices that ``get_maximum_from_heatmap`` selects on the last
    heatmap channel, multiplied by the same factor, paired with their score
    and appended to ``poses``.

    Args:
        cfg: Config. ``DATASET.INPUT_SIZE`` and ``DATASET.OUTPUT_SIZE`` are
            used as scalars here; ``cfg`` is also passed to
            ``get_maximum_from_heatmap``.
        heatmap_sum (Tensor): Sum of the heatmaps (1, 1+num_joints, w, h)
        poses (List): Gather of the pose proposals
            [(num_people, num_joints, 3)]; extended by one entry per call.
        heatmap (Tensor): Heatmap at this scale (1, 1+num_joints, w, h)
        posemap (Tensor): Posemap at this scale (1, 2*num_joints, w, h)
        scale: Scale factor at which ``heatmap``/``posemap`` were produced.

    Returns:
        tuple: ``(heatmap_sum, poses)`` after accumulation.
    """
    # Factor that maps this scale's output grid back to the reference frame.
    upscale = (cfg.DATASET.INPUT_SIZE * 1.0 / cfg.DATASET.OUTPUT_SIZE) / scale

    first = heatmap[0]
    h = first.size(-1)
    w = first.size(-2)

    target_size = (int(upscale * w), int(upscale * h))
    heatmap_sum += up_interpolate(heatmap, size=target_size, mode='bilinear')

    # The last channel is treated as the center heatmap.
    pose_ind, ctr_score = get_maximum_from_heatmap(cfg, heatmap[0, -1:])

    # One row per grid location, each holding (num_joints, 2) values.
    flat_posemap = posemap[0].permute(1, 2, 0).view(h * w, -1, 2)
    pose = upscale * flat_posemap[pose_ind]

    # Broadcast each proposal's score to every joint as a third component.
    joints_per_pose = pose.shape[-2]
    ctr_score = ctr_score.unsqueeze(1).expand(-1, joints_per_pose).unsqueeze(2)
    poses.append(torch.cat([pose, ctr_score], dim=2))

    return heatmap_sum, poses
Exemple #3
0
def main():
    """Evaluate a pose model on the test set with multi-scale inference.

    Steps, in order:
      1. Parse CLI arguments, update and check the global ``cfg``, create the
         logger and output directory.
      2. Apply the cuDNN settings from ``cfg.CUDNN``.
      3. Build the network named by ``cfg.MODEL.NAME`` and load its weights
         from ``cfg.TEST.MODEL_FILE`` if set, otherwise from
         ``model_best.pth.tar`` in the output directory.
      4. For every test image (batch size must be 1), run all
         ``cfg.TEST.SCALE_FACTOR`` scales, aggregate the outputs and average
         the resampled heatmaps.
      5. Decode poses twice per image: with ``use_heatmap=False``
         ('regression') and with ``use_heatmap=True`` ('final'), optionally
         rescoring when ``cfg.RESCORE.USE`` is set.
      6. Pass both result sets to ``test_dataset.evaluate`` and log the
         returned name/value tables.

    Raises:
        AssertionError: If the data loader yields a batch whose size is not 1.
    """
    args = parse_args()
    update_config(cfg, args)
    check_config(cfg)

    logger, final_output_dir, _ = create_logger(
        cfg, args.cfg, 'valid'
    )

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    # NOTE(review): eval() on a config-supplied string executes whatever
    # expression cfg.MODEL.NAME contains. Kept as-is to preserve behavior;
    # consider an explicit attribute/registry lookup if configs can come
    # from untrusted sources.
    model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
        cfg, is_train=False
    )

    # An explicitly configured checkpoint wins over the run's best model.
    if cfg.TEST.MODEL_FILE:
        model_state_file = cfg.TEST.MODEL_FILE
    else:
        model_state_file = os.path.join(
            final_output_dir, 'model_best.pth.tar'
        )
    logger.info('=> loading model from {}'.format(model_state_file))
    model.load_state_dict(torch.load(model_state_file), strict=True)

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
    model.eval()

    data_loader, test_dataset = make_test_dataloader(cfg)

    # 'pose_hourglass' takes un-normalized tensors; every other model also
    # gets the mean/std normalization below.
    transform_steps = [torchvision.transforms.ToTensor()]
    if cfg.MODEL.NAME != 'pose_hourglass':
        transform_steps.append(
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        )
    transforms = torchvision.transforms.Compose(transform_steps)

    parser = HeatmapRegParser(cfg)

    # Results decoded from predicted keypoints only.
    all_reg_preds = []
    all_reg_scores = []

    # Results decoded from predicted keypoints and predicted heatmaps.
    all_preds = []
    all_scores = []

    # (use_heatmap flag, destination for predictions, destination for scores)
    decode_passes = (
        (False, all_reg_preds, all_reg_scores),
        (True, all_preds, all_scores),
    )

    pbar = tqdm(total=len(test_dataset)) if cfg.TEST.LOG_PROGRESS else None
    for images, joints, masks, _ in data_loader:
        assert 1 == images.size(0), 'Test batch size should be 1'

        image = images[0].cpu().numpy()
        joints = joints[0].cpu().numpy()
        mask = masks[0].cpu().numpy()
        # size at scale 1.0
        base_size, center, scale = get_multi_scale_size(
            image, cfg.DATASET.INPUT_SIZE, 1.0, 1.0
        )

        with torch.no_grad():
            final_heatmaps = None
            final_kpts = None
            input_size = cfg.DATASET.INPUT_SIZE

            # Largest scale factor first. center/scale are rebound on every
            # iteration, so get_final_preds sees the last factor's values.
            for s in sorted(cfg.TEST.SCALE_FACTOR, reverse=True):
                image_resized, _, _, center, scale = resize_align_multi_scale(
                    image, joints, mask, input_size, s, 1.0
                )

                image_resized = transforms(image_resized)
                image_resized = image_resized.unsqueeze(0).cuda()

                _, heatmaps, kpts = get_multi_stage_outputs(
                    cfg, model, image_resized, cfg.TEST.FLIP_TEST
                )
                final_heatmaps, final_kpts = aggregate_results(
                    cfg, final_heatmaps, final_kpts, heatmaps, kpts
                )

            # Average the aggregated heatmaps at the base resolution.
            heatmap_fuse = 0
            for heatmap in final_heatmaps:
                heatmap_fuse += up_interpolate(
                    heatmap,
                    size=(base_size[1], base_size[0]),
                    mode='bilinear'
                )
            heatmap_fuse = heatmap_fuse/float(len(final_heatmaps))

            # Both decoding modes share identical post-processing; an image
            # with no detections still contributes an empty entry so the
            # result lists stay aligned with the dataset order.
            for use_heatmap, preds_out, scores_out in decode_passes:
                grouped, scores = parser.parse(
                    final_heatmaps, final_kpts, heatmap_fuse[0],
                    use_heatmap=use_heatmap
                )
                if len(scores) == 0:
                    preds_out.append([])
                    scores_out.append([])
                    continue

                final_results = get_final_preds(
                    grouped, center, scale,
                    [heatmap_fuse.size(-1), heatmap_fuse.size(-2)]
                )
                if cfg.RESCORE.USE:
                    scores = rescore_valid(cfg, final_results, scores)

                preds_out.append(final_results)
                scores_out.append(scores)

        if pbar is not None:
            pbar.update()

    if pbar is not None:
        pbar.close()

    eval_sets = (
        ('regression', all_reg_preds, all_reg_scores),
        ('final', all_preds, all_scores),
    )
    for set_name, preds, scores in eval_sets:
        print('Testing '+set_name)
        name_values, _ = test_dataset.evaluate(
            cfg, preds, scores, final_output_dir, set_name
        )

        # evaluate() may return a single table or a list of tables.
        if not isinstance(name_values, list):
            name_values = [name_values]
        for name_value in name_values:
            _print_name_value(logger, name_value, cfg.MODEL.NAME)
def get_multi_stage_outputs(cfg, model, image, with_flip=False):
    """Forward ``image`` (and optionally its mirror) and collect the outputs.

    Args:
        cfg: Config. Reads ``DATASET.NUM_JOINTS``, ``DATASET.DATASET``,
            ``DATASET.WITH_CENTER``, ``DATASET.OUTPUT_SIZE``,
            ``LOSS.HEATMAP_MIDDLE_LOSS``, ``LOSS.WITH_HEATMAPS_LOSS`` and
            ``TEST.WITH_HEATMAPS``.
        model: Callable network. It must return three values when
            ``cfg.LOSS.HEATMAP_MIDDLE_LOSS`` is truthy and two otherwise; the
            first two are used as per-stage outputs and offsets.
        image: Input batch, indexed as a 4-D tensor below (presumably
            N x C x H x W -- confirm against the caller).
        with_flip: When true, also run the horizontally flipped input and
            add its un-flipped predictions to the results.

    Returns:
        tuple: ``(outputs, heatmaps, reg_kpts_list)`` where
        ``outputs`` is the list of per-stage outputs (with the flipped-back
        outputs appended when ``with_flip``), ``heatmaps`` holds the averaged
        heatmap of the plain pass followed by that of the flipped pass (each
        only if at least one stage was enabled), and ``reg_kpts_list`` holds
        the regressed keypoint maps reshaped to ``(1, 2*num_joints, h, w)``
        for the plain pass and, when ``with_flip``, the flipped pass.

    Raises:
        ValueError: If ``with_flip`` is set and ``cfg.DATASET.DATASET``
            contains neither ``'coco'`` nor ``'crowd_pose'``.
    """
    # NUM_JOINTS minus one; presumably the config count includes the center
    # channel -- TODO confirm.
    num_joints = cfg.DATASET.NUM_JOINTS - 1
    dataset = cfg.DATASET.DATASET
    heatmaps_avg = 0
    num_heatmaps = 0
    heatmaps = []
    reg_kpts_list = []

    # forward
    ##########################################################################
    if cfg.LOSS.HEATMAP_MIDDLE_LOSS:
        all_outputs, all_offsets, _ = model(image)
    else:
        all_outputs, all_offsets = model(image)
    ##########################################################################
    outputs = [get_one_stage_outputs(out) for out in all_outputs]
    # Only the last entry of the first offsets group is used for regression.
    offset = all_offsets[0][-1]
    h, w = offset.shape[2:]
    reg_kpts = get_reg_kpts(offset[0], num_joints)
    # Rearrange from one row per location to (1, 2*num_joints, h, w).
    reg_kpts = reg_kpts.contiguous().view(h * w, 2 * num_joints).permute(
        1, 0).contiguous().view(1, -1, h, w)
    reg_kpts_list.append(reg_kpts)

    if with_flip:
        # NOTE(review): flip_index_heat chosen here is reassigned in the
        # flipped-heatmap loop below before it is ever read; only
        # flip_index_offset from this branch is actually used.
        if 'coco' in dataset:
            flip_index_heat = FLIP_CONFIG['COCO_WITH_CENTER'] \
                if cfg.DATASET.WITH_CENTER else FLIP_CONFIG['COCO']
            flip_index_offset = FLIP_CONFIG['COCO']
        elif 'crowd_pose' in dataset:
            flip_index_heat = FLIP_CONFIG['CROWDPOSE_WITH_CENTER'] \
                if cfg.DATASET.WITH_CENTER else FLIP_CONFIG['CROWDPOSE']
            flip_index_offset = FLIP_CONFIG['CROWDPOSE']
        else:
            raise ValueError(
                'Please implement flip_index for new dataset: %s.' % dataset)

        new_image = torch.zeros_like(image)
        new_image_2x = torch.zeros_like(image)

        # Mirror along the last axis, then shift the mirrored image by 3
        # columns (new_image) and by 1 column (new_image_2x); the vacated
        # trailing columns stay zero.
        image = torch.flip(image, [3])
        new_image[:, :, :, :-3] = image[:, :, :, 3:]
        new_image_2x[:, :, :, :-1] = image[:, :, :, 1:]

        ##########################################################################
        if cfg.LOSS.HEATMAP_MIDDLE_LOSS:
            all_outputs_flip, all_offsets_flip, _ = model(new_image)
        else:
            all_outputs_flip, all_offsets_flip = model(new_image)
        ##########################################################################
        outputs_flip = [get_one_stage_outputs(all_outputs_flip[0])]
        # With more than one output size, the second stage is taken from a
        # separate forward pass on the 1-column-shifted image.
        if len(cfg.DATASET.OUTPUT_SIZE) > 1:
            ##########################################################################
            if cfg.LOSS.HEATMAP_MIDDLE_LOSS:
                all_outputs_flip, _, _ = model(new_image_2x)
            else:
                all_outputs_flip, _ = model(new_image_2x)
            ##########################################################################
            outputs_flip.append(get_one_stage_outputs(all_outputs_flip[1]))

        # Offsets come from the pass on new_image (the 3-column shift).
        offset_flip = all_offsets_flip[0][-1]
        reg_kpts_flip = get_reg_kpts(offset_flip[0], num_joints)
        # Reorder joints with the dataset's flip index, then mirror the
        # first coordinate component within the width w.
        reg_kpts_flip = reg_kpts_flip[:, flip_index_offset, :]
        reg_kpts_flip[:, :, 0] = w - reg_kpts_flip[:, :, 0] - 1
        reg_kpts_flip = reg_kpts_flip.contiguous().view(
            h * w, 2 * num_joints).permute(1,
                                           0).contiguous().view(1, -1, h, w)
        # Flip the map itself back so locations line up with the plain pass.
        reg_kpts_list.append(torch.flip(reg_kpts_flip, [3]))
    else:
        outputs_flip = None

    for i, output in enumerate(outputs):
        # Bring every stage except the last to the last stage's resolution.
        if len(outputs) > 1 and i != len(outputs) - 1:
            output = up_interpolate(output,
                                    size=(outputs[-1].size(2),
                                          outputs[-1].size(3)))

        c = output.shape[1]
        if cfg.LOSS.WITH_HEATMAPS_LOSS[i] and cfg.TEST.WITH_HEATMAPS[i]:
            num_heatmaps += 1
            # The first enabled stage replaces the initial 0 (creating a new
            # tensor); later stages add into the first c channels in place.
            if num_heatmaps > 1:
                heatmaps_avg[:, :c] += output
            else:
                heatmaps_avg += output

    if num_heatmaps > 0:
        # NOTE(review): c is whatever the last loop iteration left behind,
        # not necessarily the channel count of an enabled stage.
        heatmaps_avg[:, :c] /= num_heatmaps
        heatmaps.append(heatmaps_avg)

    if with_flip:
        # Restart the accumulation for the flipped pass.
        heatmaps_avg = 0
        num_heatmaps = 0
        for i in range(len(outputs_flip)):
            output = outputs_flip[i]
            if len(outputs_flip) > 1 and i != len(outputs_flip) - 1:
                output = up_interpolate(output,
                                        size=(outputs_flip[-1].size(2),
                                              outputs_flip[-1].size(3)))

            # Undo the horizontal flip; the returned outputs list also
            # receives this flipped-back tensor (channels not yet reordered).
            output = torch.flip(output, [3])
            outputs.append(output)
            c = output.shape[1]

            if cfg.LOSS.WITH_HEATMAPS_LOSS[i] and cfg.TEST.WITH_HEATMAPS[i]:
                num_heatmaps += 1
                # Pick the channel permutation by channel count: the
                # *_WITH_CENTER table when there are num_joints+1 channels.
                if 'coco' in dataset:
                    flip_index_heat = FLIP_CONFIG['COCO_WITH_CENTER'] \
                        if c == num_joints+1 else FLIP_CONFIG['COCO']
                elif 'crowd_pose' in dataset:
                    flip_index_heat = FLIP_CONFIG['CROWDPOSE_WITH_CENTER'] \
                        if c == num_joints+1 else FLIP_CONFIG['CROWDPOSE']
                else:
                    raise ValueError(
                        'Please implement flip_index for new dataset: %s.' %
                        dataset)

                if num_heatmaps > 1:
                    heatmaps_avg[:, :c] += output[:, flip_index_heat, :, :]
                else:
                    heatmaps_avg += \
                        output[:, flip_index_heat, :, :]
        if num_heatmaps > 0:
            heatmaps_avg[:, :c] /= num_heatmaps
            heatmaps.append(heatmaps_avg)

    return outputs, heatmaps, reg_kpts_list