def forward(self, images, proj_matrices=torch.rand((1,4,3,4)), orig_img_size=[640,480]):
        # images: b x v x 3 x 256(H) x 256(W)
        # proj_matrices: b x v x 3 x 4
        # orig_img_size: [W,H]
        device = images.device
        batch_size, n_views = images.shape[:2]

        # reshape n_views dimension to batch dimension
        images = images.view(-1, *images.shape[2:]) # b*v x 3 x H x W

        # forward backbone and integral
        if self.use_alg_confidences:
            # alg_confidences: (b*N_views) x N_joints
            heatmaps, _, alg_confidences, _ = self.backbone(images)
            alg_confidences = alg_confidences.view(batch_size, n_views, *alg_confidences.shape[1:]) # b x N_views x N_joints
            # norm confidences
            alg_confidences = alg_confidences / alg_confidences.sum(dim=1, keepdim=True) + 1e-5 # for numerical stability
        else:
            heatmaps, _, _, _ = self.backbone(images)
            alg_confidences = None

        keypoints_2d = get_final_preds(heatmaps, use_softmax=self.heatmap_softmax) # b*v x 21 x 2

        # reshape back
        heatmaps = heatmaps.view(batch_size, n_views, *heatmaps.shape[1:])
        keypoints_2d = keypoints_2d.view(batch_size, n_views, *keypoints_2d.shape[1:]) # b x v x 21 x 2
        
        # upscale keypoints_2d so that it is located in the original image
        heatmap_size = heatmaps.shape[-1]
        keypoints_2d[:, :, :, 0] = keypoints_2d[:, :, :, 0] * (orig_img_size[0] / heatmap_size) # u
        keypoints_2d[:, :, :, 1] = keypoints_2d[:, :, :, 1] * (orig_img_size[1] / heatmap_size) # v
        # keypoints_2d_transformed = torch.zeros_like(keypoints_2d)
        # keypoints_2d_transformed[:, :, :, 0] = keypoints_2d[:, :, :, 0] * (images.shape[-1] / heatmaps.shape[-1]) # u
        # keypoints_2d_transformed[:, :, :, 1] = keypoints_2d[:, :, :, 1] * (images.shape[-2] / heatmaps.shape[-2]) # v
        # keypoints_2d = keypoints_2d_transformed

        # triangulate
        try:
            if self.use_alg_confidences:
                keypoints_3d = multiview.triangulate_batch_of_points(
                    proj_matrices, keypoints_2d,
                    confidences_batch=alg_confidences
                )
            else:
                keypoints_3d = torch.cat(
                    [DLT_sii_pytorch(keypoints_2d[:,:,k], proj_matrices).unsqueeze(1) for k in range(keypoints_2d.shape[2])],
                    dim=1
                ) # b x 21 x 3

        except RuntimeError as e:
            print("Error: ", e)

            print("confidences =", confidences_batch_pred)
            print("proj_matrices = ", proj_matrices)
            print("keypoints_2d_batch_pred =", keypoints_2d_batch_pred)
            exit()

        return keypoints_3d, keypoints_2d, heatmaps, alg_confidences
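# `DLT_sii_pytorch` is called above but not defined in this snippet. A minimal,
# self-contained sketch of batched linear (DLT) triangulation via SVD, assuming the
# call signature used above (2D points of shape b x v x 2 plus projection matrices
# of shape b x v x 3 x 4, returning b x 3 world coordinates):
import torch

def DLT_sii_sketch(points_2d, proj_matrices):
    # Each view contributes two rows of the homogeneous system A @ X = 0:
    #   u * P[2] - P[0]   and   v * P[2] - P[1]
    u = points_2d[..., 0:1]                                        # b x v x 1
    v = points_2d[..., 1:2]                                        # b x v x 1
    rows_u = u * proj_matrices[:, :, 2] - proj_matrices[:, :, 0]   # b x v x 4
    rows_v = v * proj_matrices[:, :, 2] - proj_matrices[:, :, 1]   # b x v x 4
    A = torch.cat([rows_u, rows_v], dim=1)                         # b x 2v x 4
    # The solution is the right singular vector with the smallest singular value.
    _, _, Vh = torch.linalg.svd(A)
    X = Vh[:, -1]                                                  # b x 4, homogeneous
    return X[:, :3] / X[:, 3:4]                                    # b x 3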
    def forward(self, images, proj_matrices):
        batch_size, n_views = images.shape[:2]
        orig_width, orig_height = self.orig_img_size
        # reshape n_views dimension to batch dimension
        images = images.view(-1, *images.shape[2:])

        # forward backbone and integrate
        heatmaps, _, _, _ = self.backbone(images) # 4 x 21 x 64 x 64
        
        # calculate shapes
        image_shape = tuple(images.shape[2:])  # (H, W) of the cropped input images
        n_joints, heatmap_shape = heatmaps.shape[1], tuple(heatmaps.shape[2:])

        keypoints_2d = get_final_preds(heatmaps, self.heatmap_softmax).view(batch_size, n_views, -1, 2)

        # reshape back
        images = images.view(batch_size, n_views, *images.shape[1:])
        heatmaps = heatmaps.view(batch_size, n_views, *heatmaps.shape[1:])

        # upscale keypoints_2d, because image shape != heatmap shape
        keypoints_2d_transformed = torch.zeros_like(keypoints_2d).float()
        keypoints_2d_transformed[:, :, :, 0] = keypoints_2d[:, :, :, 0] * (orig_width / heatmap_shape[1])
        keypoints_2d_transformed[:, :, :, 1] = keypoints_2d[:, :, :, 1] * (orig_height / heatmap_shape[0])
        keypoints_2d = keypoints_2d_transformed

        # triangulate (cpu)
        keypoints_2d_np = keypoints_2d.detach().cpu().numpy()
        proj_matricies_np = proj_matrices.detach().cpu().numpy()

        # alternative triangulations for comparison (21 x 3)
        #keypoint_3d_in_base_camera_test = np.stack([DLT(keypoints_2d_np[b], proj_matricies_np[b]) for b in range(batch_size)])
        #keypoint_3d_in_base_camera_test = np.stack([multiview.triangulate_point_from_multiple_views_linear(proj_matrices[0], keypoints_2d_np[0,:,k]) for k in range(21)])
        #print(keypoint_3d_in_base_camera_test)

        keypoints_3d = np.zeros((batch_size, n_joints, 3))
        confidences = np.zeros((batch_size, n_views, n_joints))  # plug
        for batch_i in range(batch_size):
            for joint_i in range(n_joints):
                current_proj_matricies = proj_matricies_np[batch_i]
                points = keypoints_2d_np[batch_i, :, joint_i]
                keypoint_3d, _ = triangulate_ransac(current_proj_matricies, points, reprojection_error_epsilon=25, direct_optimization=self.direct_optimization)
                keypoints_3d[batch_i, joint_i] = keypoint_3d

        keypoints_3d = torch.from_numpy(keypoints_3d).type(torch.float).to(images.device)
        confidences = torch.from_numpy(confidences).type(torch.float).to(images.device)

        return keypoints_3d, keypoints_2d, heatmaps, confidences
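# `triangulate_ransac` is not shown here. Its inner linear step (triangulating a
# single joint from all views, as in the commented-out
# multiview.triangulate_point_from_multiple_views_linear call above) can be
# sketched in NumPy as follows; a RANSAC wrapper would run this on sampled view
# subsets and keep the hypothesis with the most inliers under the reprojection
# error threshold:
import numpy as np

def triangulate_point_linear_sketch(proj_matrices, points_2d):
    # proj_matrices: v x 3 x 4, points_2d: v x 2  ->  world point of shape (3,)
    rows = []
    for P, (u, v) in zip(proj_matrices, points_2d):
        rows.append(u * P[2] - P[0])
        rows.append(v * P[2] - P[1])
    A = np.stack(rows)                 # 2v x 4
    _, _, Vt = np.linalg.svd(A)
    X = Vt[-1]                         # homogeneous solution
    return X[:3] / X[3]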
Example #3
def predict_one_img(image, show=False, img_path=None):
    trans = build_transforms(cfg, is_train=False)
    temp_joints = [np.ones((21, 3))]
    orig_img = image.copy()
    resized_image = cv2.cvtColor(
        cv2.resize(image, tuple(cfg.MODEL.IMAGE_SIZE)), cv2.COLOR_RGB2BGR)
    I, _ = trans(resized_image, temp_joints)
    I = I.unsqueeze(0).to(device) if args.gpu != 'cpu' else I.unsqueeze(0)
    model.eval()
    with torch.no_grad():
        start_time = time.time()
        output, _ = model(I)  # output size: 1 x 21 x 64(H) x 64(W)
        # print('Inference time: {:.4f} s'.format(time.time()-start_time))
        kps_pred_np = get_final_preds(
            output,
            use_softmax=cfg.MODEL.HEATMAP_SOFTMAX).cpu().numpy().squeeze()
    return kps_pred_np
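# `get_final_preds` is used throughout these snippets but not defined in them. A
# minimal sketch, assuming it converts b x J x H x W heatmaps into b x J x 2 (u, v)
# coordinates in heatmap pixels, using either a hard argmax or a soft-argmax
# (spatial expectation) when use_softmax=True:
import torch
import torch.nn.functional as F

def get_final_preds_sketch(heatmaps, use_softmax=False):
    b, j, h, w = heatmaps.shape
    flat = heatmaps.view(b, j, -1)
    if use_softmax:
        probs = F.softmax(flat, dim=-1).view(b, j, h, w)
        us = torch.arange(w, dtype=probs.dtype, device=probs.device)
        vs = torch.arange(h, dtype=probs.dtype, device=probs.device)
        u = (probs.sum(dim=2) * us).sum(dim=-1)   # expectation over columns
        v = (probs.sum(dim=3) * vs).sum(dim=-1)   # expectation over rows
    else:
        idx = flat.argmax(dim=-1)
        u = (idx % w).float()
        v = (idx // w).float()
    return torch.stack([u, v], dim=-1)            # b x J x 2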
    def forward(self, x):
        # x: b x f x 3 x H x W
        n_batches, n_frames = x.shape[0:2]
        heatmaps_pred, trainable_temp = self.backbone(x.view(
            -1, *x.shape[-3:]))  # b*f x 21 x 64 x 64
        n_joints = heatmaps_pred.shape[1]
        pose2d_pred = get_final_preds(heatmaps_pred,
                                      use_softmax=self.use_softmax).view(
                                          n_batches, n_frames, n_joints,
                                          2)  # b x f x 21 x 2

        x = pose2d_pred.permute(0, 3, 1, 2)
        ### now x is [batch_size, 2 channels, receptive frames, joint_num], following image data
        x = self.Spatial_forward_features(x)  # b x N_frames x (N_joints * 32)
        x = self.forward_features(x)  # b x 1 x (N_joints * 32)
        pose2d_pred_refined = self.head(x)  # b x 1 x (N_joints * 2)

        return pose2d_pred_refined.view(n_batches, n_joints,
                                        -1), heatmaps_pred, trainable_temp
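# A small shape walk-through of the temporal forward above (purely illustrative; the
# model internals are not reproduced here). It only shows the layout change described
# in the comment: the 2D poses become a 2-channel "image" over frames and joints.
import torch

b, f, j = 2, 7, 21
pose2d_pred = torch.rand(b, f, j, 2)     # b x f x 21 x 2, as from get_final_preds
x = pose2d_pred.permute(0, 3, 1, 2)      # b x 2 x f x 21: [batch, channels, frames, joints]
print(x.shape)                           # torch.Size([2, 2, 7, 21])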
Example #5
    def forward(self, images, proj_matrices, batch=None, keypoints_3d=None):
        # images: b x N_views x 3 x H x W
        # proj_matricies_batch (K*H): b x N_views x 3 x 4
        device = images.device
        batch_size, n_views = images.shape[:2]

        # reshape for backbone forward
        images = images.view(-1, *images.shape[2:])

        # forward backbone
        # heatmaps: (b*N_views) x 21 x 64 x 64
        # features: (b*N_views) x 480 x 64 x 64
        heatmaps, _ = self.backbone(images)

        # find the middle finger root position
        base_idx = 9
        base_points_2d = get_final_preds(heatmaps, use_softmax=self.heatmap_softmax).view(batch_size, n_views, heatmaps.shape[1], 2) # batch_size x N_views x 21 x 2
        base_points = DLT_sii_pytorch(base_points_2d[:, :, base_idx], proj_matrices) # b x 3
        
        for b in range(batch_size):
            grid_center = base_points[b]
            limb_length = self.hand.compute_limb_length()
def main():
    args = parse_args()

    update_config(cfg, args)
    cfg.defrost()
    cfg.freeze()

    record_prefix = './eval2D_results_'
    if args.is_vis:
        result_dir = record_prefix + cfg.EXP_NAME
        mse2d_lst = np.loadtxt(os.path.join(result_dir,
                                            'mse2d_each_joint.txt'))
        PCK2d_lst = np.loadtxt(os.path.join(result_dir, 'PCK2d.txt'))

        plot_performance(PCK2d_lst[1, :], PCK2d_lst[0, :], mse2d_lst)
        exit()

    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model_path = args.model_path
    is_vis = args.is_vis

    # FP16 SETTING
    if cfg.FP16.ENABLED:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if cfg.FP16.STATIC_LOSS_SCALE != 1.0:
        if not cfg.FP16.ENABLED:
            print(
                "Warning:  if --fp16 is not used, static_loss_scale will be ignored."
            )

    model = eval(cfg.MODEL.NAME + '.get_pose_net')(cfg, is_train=False)

    # # calculate GFLOPS
    # dump_input = torch.rand(
    #     (5, 3, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[0])
    # )

    # print(get_model_summary(model, dump_input, verbose=cfg.VERBOSE))

    # ops, params = get_model_complexity_info(
    #    model, (3, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[0]),
    #    as_strings=True, print_per_layer_stat=True, verbose=True)
    # input()

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    if cfg.MODEL.SYNC_BN and not args.distributed:
        print(
            'Warning: Sync BatchNorm is only supported in distributed training.'
        )

    if args.gpu != -1:
        device = torch.device('cuda:' + str(args.gpu))
        torch.cuda.set_device(args.gpu)
    else:
        device = torch.device('cpu')
    # load model state
    if model_path:
        print("Loading model:", model_path)
        ckpt = torch.load(model_path)  #, map_location='cpu')
        if 'state_dict' not in ckpt.keys():
            state_dict = ckpt
        else:
            state_dict = ckpt['state_dict']
            print('Model epoch {}'.format(ckpt['epoch']))

        for key in list(state_dict.keys()):
            new_key = key.replace("module.", "")
            state_dict[new_key] = state_dict.pop(key)

        model.load_state_dict(state_dict, strict=True)

    model.to(device)

    # calculate GFLOPS
    dump_input = torch.rand(
        (1, 3, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[0])).to(device)

    print(get_model_summary(model, dump_input, verbose=cfg.VERBOSE))

    model.eval()

    # inference_dataset = eval('dataset.{}'.format(cfg.DATASET.TEST_DATASET[0].replace('_kpt','')))(
    #     cfg.DATA_DIR,
    #     cfg.DATASET.TEST_SET,
    #     transform=transform
    # )
    inference_dataset = eval('dataset.{}'.format(
        cfg.DATASET.TEST_DATASET[0].replace('_kpt', '')))(
            cfg.DATA_DIR,
            cfg.DATASET.TEST_SET,
            transforms=build_transforms(cfg, is_train=False))

    batch_size = args.batch_size
    data_loader = torch.utils.data.DataLoader(
        inference_dataset,
        batch_size=batch_size,  #48
        shuffle=False,
        num_workers=min(8, batch_size),  #8
        pin_memory=False)

    print('\nEvaluation loader information:\n' + str(data_loader.dataset))
    n_joints = cfg.DATASET.NUM_JOINTS
    th2d_lst = np.array([i for i in range(1, 50)])
    PCK2d_lst = np.zeros((len(th2d_lst), ))
    mse2d_lst = np.zeros((n_joints, ))
    visibility_lst = np.zeros((n_joints, ))

    print('Start evaluating... [Batch size: {}]\n'.format(
        data_loader.batch_size))
    with torch.no_grad():
        pose2d_mse_loss = JointsMSELoss().to(device)
        infer_time = [0, 0]
        start_time = time.time()
        for i, ret in enumerate(data_loader):
            # pose2d_gt: b x 21 x 2 is [u,v] 0<=u<64, 0<=v<64 (heatmap size)
            # visibility: b x 21 vis=0/1
            imgs = ret['imgs']
            pose2d_gt = ret['pose2d']  # b [x v] x 21 x 2
            visibility = ret['visibility']  # b [x v] x 21 x 1

            s1 = time.time()
            if 'CPM' == cfg.MODEL.NAME:
                pose2d_gt = pose2d_gt.view(-1, *pose2d_gt.shape[-2:])
                heatmap_lst = model(
                    imgs.to(device), ret['centermaps'].to(device)
                )  # 6 groups of heatmaps, each of which has size (1,22,32,32)
                heatmaps = heatmap_lst[-1][:, 1:]
                pose2d_pred = data_loader.dataset.get_kpts(heatmaps)
                hm_size = heatmap_lst[-1].shape[-1]  # 32
            else:
                if cfg.MODEL.NAME == 'pose_hrnet_transformer':
                    # imgs: b(1) x (4*seq_len) x 3 x 256 x 256
                    n_batches, seq_len = imgs.shape[0], imgs.shape[1] // 4
                    idx_lst = torch.tensor([4 * i for i in range(seq_len)])
                    imgs = torch.stack([
                        imgs[b, idx_lst + cam_idx] for b in range(n_batches)
                        for cam_idx in range(4)
                    ])  # (b*4) x seq_len x 3 x 256 x 256

                    pose2d_pred, heatmaps_pred, _ = model(
                        imgs.cuda(device))  # (b*4) x 21 x 2
                    pose2d_gt = pose2d_gt[:, 4 * (seq_len // 2):4 * (
                        seq_len // 2 + 1)].contiguous().view(
                            -1, *pose2d_pred.shape[-2:])  # (b*4) x 21 x 2
                    visibility = visibility[:, 4 * (seq_len // 2):4 * (
                        seq_len // 2 + 1)].contiguous().view(
                            -1, *visibility.shape[-2:])  # (b*4) x 21

                else:
                    if 'Aggr' in cfg.MODEL.NAME:
                        # imgs: b x (4*5) x 3 x 256 x 256
                        n_batches, seq_len = imgs.shape[0], len(
                            cfg.DATASET.SEQ_IDX)
                        true_batch_size = imgs.shape[1] // seq_len
                        pose2d_gt = torch.cat([
                            pose2d_gt[b, true_batch_size *
                                      (seq_len // 2):true_batch_size *
                                      (seq_len // 2 + 1)]
                            for b in range(n_batches)
                        ],
                                              dim=0)

                        visibility = torch.cat([
                            visibility[b, true_batch_size *
                                       (seq_len // 2):true_batch_size *
                                       (seq_len // 2 + 1)]
                            for b in range(n_batches)
                        ],
                                               dim=0)

                        imgs = torch.cat([
                            imgs[b, true_batch_size * j:true_batch_size *
                                 (j + 1)] for j in range(seq_len)
                            for b in range(n_batches)
                        ],
                                         dim=0)  # (b*4*5) x 3 x 256 x 256

                        heatmaps_pred, _ = model(imgs.to(device))
                    else:
                        pose2d_gt = pose2d_gt.view(-1, *pose2d_gt.shape[-2:])
                        heatmaps_pred, _ = model(
                            imgs.to(device))  # b x 21 x 64 x 64

                    pose2d_pred = get_final_preds(
                        heatmaps_pred, cfg.MODEL.HEATMAP_SOFTMAX)  # b x 21 x 2

                hm_size = heatmaps_pred.shape[-1]  # 64

            if i > 20:
                infer_time[0] += 1
                infer_time[1] += time.time() - s1

            # rescale to the original image before DLT

            if 'RHD' in cfg.DATASET.TEST_DATASET[0]:
                crop_size, corner = ret['crop_size'], ret['corner']
                crop_size, corner = crop_size.view(-1, 1, 1), corner.unsqueeze(
                    1)  # b x 1 x 1; b x 2 x 1
                pose2d_pred = pose2d_pred.cpu() * crop_size / hm_size + corner
                pose2d_gt = pose2d_gt * crop_size / hm_size + corner
            else:
                orig_width, orig_height = data_loader.dataset.orig_img_size
                pose2d_pred[:, :, 0] *= orig_width / hm_size
                pose2d_pred[:, :, 1] *= orig_height / hm_size
                pose2d_gt[:, :, 0] *= orig_width / hm_size
                pose2d_gt[:, :, 1] *= orig_height / hm_size

                # for k in range(21):
                #     print(pose2d_gt[0,k].tolist(), pose2d_pred[0,k].tolist())
                # input()
            # 2D errors
            pose2d_pred, pose2d_gt, visibility = pose2d_pred.cpu().numpy(
            ), pose2d_gt.numpy(), visibility.squeeze(2).numpy()

            # import matplotlib.pyplot as plt
            # imgs = cv2.resize(imgs[0].permute(1,2,0).cpu().numpy(), tuple(data_loader.dataset.orig_img_size))
            # for k in range(21):
            #     print(pose2d_gt[0,k],pose2d_pred[0,k],visibility[0,k])
            # for k in range(0,21,5):
            #     fig = plt.figure()
            #     ax1 = fig.add_subplot(131)
            #     ax2 = fig.add_subplot(132)
            #     ax3 = fig.add_subplot(133)
            #     ax1.imshow(cv2.cvtColor(imgs / imgs.max(), cv2.COLOR_BGR2RGB))
            #     plot_hand(ax1, pose2d_gt[0,:,0:2], order='uv')
            #     ax2.imshow(cv2.cvtColor(imgs / imgs.max(), cv2.COLOR_BGR2RGB))
            #     plot_hand(ax2, pose2d_pred[0,:,0:2], order='uv')
            #     ax3.imshow(heatmaps_pred[0,k].cpu().numpy())
            #     plt.show()
            mse_each_joint = np.linalg.norm(pose2d_pred - pose2d_gt,
                                            axis=2) * visibility  # b x 21

            mse2d_lst += mse_each_joint.sum(axis=0)
            visibility_lst += visibility.sum(axis=0)

            for th_idx in range(len(th2d_lst)):
                PCK2d_lst[th_idx] += np.sum(
                    (mse_each_joint < th2d_lst[th_idx]) * visibility)

            period = 10
            if i % (len(data_loader) // period) == 0:
                print("[Evaluation]{}% finished.".format(
                    period * i // (len(data_loader) // period)))
            #if i == 10:break
        print('Evaluation spent {:.2f} s\tfps: {:.1f} {:.4f}'.format(
            time.time() - start_time, infer_time[0] / infer_time[1],
            infer_time[1] / infer_time[0]))

        mse2d_lst /= visibility_lst
        PCK2d_lst /= visibility_lst.sum()

        result_dir = record_prefix + cfg.EXP_NAME
        if not os.path.exists(result_dir):
            os.mkdir(result_dir)

        mse_file, pck_file = os.path.join(
            result_dir,
            'mse2d_each_joint.txt'), os.path.join(result_dir, 'PCK2d.txt')
        print('Saving results to ' + mse_file)
        print('Saving results to ' + pck_file)
        np.savetxt(mse_file, mse2d_lst, fmt='%.4f')
        np.savetxt(pck_file, np.stack((th2d_lst, PCK2d_lst)))

        plot_performance(PCK2d_lst, th2d_lst, mse2d_lst)
    def predict_one_img(image, show=False, img_path=None):
        trans = build_transforms(cfg, is_train=False)
        temp_joints = [np.ones((21,3))]
        orig_img = image.copy()
        resized_image = cv2.cvtColor(cv2.resize(image, tuple(cfg.MODEL.IMAGE_SIZE)), cv2.COLOR_RGB2BGR)
        I, _ = trans(resized_image, temp_joints)
        I = I.unsqueeze(0).to(device) if args.gpu != 'cpu' else I.unsqueeze(0)

        model.eval()
        with torch.no_grad():
            start_time = time.time()
            output, _ = model(I) # output size: 1 x 21 x 64(H) x 64(W)
            print('Inference time: {:.4f} s'.format(time.time()-start_time))
            kps_pred_np = get_final_preds(output, use_softmax=cfg.MODEL.HEATMAP_SOFTMAX).cpu().numpy().squeeze()
            #kps_pred_np = kornia.spatial_soft_argmax2d(output, normalized_coordinates=False) # 1 x 21 x 2
        #kps_pred_np =  kps_pred_np[0] * np.array([256 / cfg.MODEL.HEATMAP_SIZE[0], 256 / cfg.MODEL.HEATMAP_SIZE[0]])
        

        #kps_pred_np[:,0] += 25
        if True:            
            all_flag = False
            kps_pred_np *=  cfg.MODEL.IMAGE_SIZE[0] / cfg.MODEL.HEATMAP_SIZE[0]
            heatmap_all = np.zeros(tuple(output.shape[2:]))
            heatmap_lst = []

            fig = plt.figure()           
            ax1 = fig.add_subplot(1,2,1)
            ax1.imshow(resized_image)

            for kp in range(0,21): 
                heatmap = output[0][kp].cpu().numpy()
                heatmap_lst.append(heatmap)
                heatmap_all += heatmap

            if not all_flag:         
                ax2 = fig.add_subplot(1,2,2)
                heatmap_cat = np.vstack((np.hstack(heatmap_lst[0:7]), np.hstack(heatmap_lst[7:14]), np.hstack(heatmap_lst[14:21])))
                print(heatmap_cat.shape)
                #hm = 255 * output[0][kp] / hms[0][kp].sum()
                ax1.scatter(kps_pred_np[kp][0], kps_pred_np[kp][1], linewidths=10)
                ax2.imshow(heatmap_cat)

                plt.title(kps_pred_np[kp].tolist())
                plt.show()

            if all_flag:
                
                ax2 = fig.add_subplot(1,2,2)
                ax1.imshow(resized_image)
                ax2.imshow(heatmap_all / heatmap_all.max())
                plt.show()
        else:
            kps_pred_np[:,0] *= orig_img.shape[1] / cfg.MODEL.HEATMAP_SIZE[0]
            kps_pred_np[:,1] *= orig_img.shape[0] / cfg.MODEL.HEATMAP_SIZE[0]
            kps_pred_np[:,0] += 0
            fig = plt.figure()
            fig.set_tight_layout(True)
            plt.imshow(cv2.cvtColor(orig_img, cv2.COLOR_BGR2RGB))

            plt.plot([kps_pred_np[0][0], kps_pred_np[legend_dict['thumb palm']][0]], [kps_pred_np[0][1], kps_pred_np[legend_dict['thumb palm']][1]], c='r', marker='.')
            plt.plot(kps_pred_np[1:5,0], kps_pred_np[1:5,1], c='r', marker='.', label='Thumb')
            plt.plot([kps_pred_np[0][0], kps_pred_np[legend_dict['index palm']][0]], [kps_pred_np[0][1], kps_pred_np[legend_dict['index palm']][1]], c='g', marker='.')
            plt.plot(kps_pred_np[5:9,0], kps_pred_np[5:9,1], c='g', marker='.', label='Index')
            plt.plot([kps_pred_np[0][0], kps_pred_np[legend_dict['middle palm']][0]], [kps_pred_np[0][1], kps_pred_np[legend_dict['middle palm']][1]], c='b', marker='.')
            plt.plot(kps_pred_np[9:13,0], kps_pred_np[9:13,1], c='b', marker='.', label='Middle')
            plt.plot([kps_pred_np[0][0], kps_pred_np[legend_dict['ring palm']][0]], [kps_pred_np[0][1], kps_pred_np[legend_dict['ring palm']][1]], c='m', marker='.')
            plt.plot(kps_pred_np[13:17,0], kps_pred_np[13:17,1], c='m', marker='.', label='Ring')
            plt.plot([kps_pred_np[0][0], kps_pred_np[legend_dict['pinky palm']][0]], [kps_pred_np[0][1], kps_pred_np[legend_dict['pinky palm']][1]], c='y', marker='.')
            plt.plot(kps_pred_np[17:21,0], kps_pred_np[17:21,1], c='y', marker='.', label='Pinky')
            plt.title('Prediction')
            if img_path:
                plt.title(img_path)
            plt.axis('off')
            plt.legend(bbox_to_anchor=(1.04, 1), loc="upper right", ncol=1, mode="expand", borderaxespad=0.)


        fig.canvas.draw()
        # Get the RGBA buffer from the figure
        buf = fig.canvas.buffer_rgba()

        if show:
            plt.show()
            print(kps_pred_np)


        return np.asarray(buf), kps_pred_np
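    # Hypothetical usage of predict_one_img above (the image path and variable names
    # here are assumptions, not part of the original script):
    #   img = cv2.imread('hand.jpg')                      # BGR image from disk
    #   canvas_rgba, kps = predict_one_img(img, show=True, img_path='hand.jpg')
    #   canvas_rgba is the rendered matplotlib canvas (H x W x 4), kps is 21 x 2.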
def main():
    args = parse_args()
    update_config(cfg, args)
    cfg.defrost()
    cfg.freeze()

    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model_path = args.model_path
    is_vis = args.is_vis

    gpus = ','.join([str(i) for i in cfg.GPUS])
    gpu_ids = eval('[' + gpus + ']')

    if cfg.FP16.ENABLED:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if cfg.FP16.STATIC_LOSS_SCALE != 1.0:
        if not cfg.FP16.ENABLED:
            print(
                "Warning:  if --fp16 is not used, static_loss_scale will be ignored."
            )

    # model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
    #     cfg, is_train=True
    # )

    if 'pose_hrnet' in cfg.MODEL.NAME:
        model = {
            "pose_hrnet": pose_hrnet.get_pose_net,
            "pose_hrnet_softmax": pose_hrnet_softmax.get_pose_net
        }[cfg.MODEL.NAME](cfg, is_train=True)
    else:
        model = {
            "ransac": RANSACTriangulationNet,
            "alg": AlgebraicTriangulationNet,
            "vol": VolumetricTriangulationNet,
            "vol_CPM": VolumetricTriangulationNet_CPM,
            "FTL": FTLMultiviewNet
        }[cfg.MODEL.NAME](cfg, is_train=False)

    # load model state
    if model_path:
        print("Loading model:", model_path)
        ckpt = torch.load(model_path,
                          map_location='cpu' if args.gpu == -1 else 'cuda:0')
        if 'state_dict' not in ckpt.keys():
            state_dict = ckpt
        else:
            state_dict = ckpt['state_dict']
            print('Model epoch {}'.format(ckpt['epoch']))

        for key in list(state_dict.keys()):
            new_key = key.replace("module.", "")
            state_dict[new_key] = state_dict.pop(key)

        model.load_state_dict(state_dict, strict=True)

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    if cfg.MODEL.SYNC_BN and not args.distributed:
        print(
            'Warning: Sync BatchNorm is only supported in distributed training.'
        )

    device = torch.device('cuda:' + str(args.gpu) if args.gpu != -1 else 'cpu')

    model.to(device)

    model.eval()

    # image transformer
    transform = build_transforms(cfg, is_train=False)

    inference_dataset = eval('dataset.' + cfg.DATASET.TEST_DATASET[0])(
        cfg, cfg.DATASET.TEST_SET, transform=transform)

    data_loader = torch.utils.data.DataLoader(inference_dataset,
                                              batch_size=1,
                                              shuffle=True,
                                              num_workers=0,
                                              pin_memory=False)

    print('\nValidation loader information:\n' + str(data_loader.dataset))

    with torch.no_grad():
        pose2d_mse_loss = JointsMSELoss().to(
            device) if args.gpu != -1 else JointsMSELoss()
        pose3d_mse_loss = Joints3DMSELoss().to(
            device) if args.gpu != -1 else Joints3DMSELoss()
        orig_width, orig_height = inference_dataset.orig_img_size
        heatmap_size = cfg.MODEL.HEATMAP_SIZE
        count = 4
        for i, ret in enumerate(data_loader):
            # orig_imgs: 1 x 4 x 480 x 640 x 3
            # imgs: 1 x 4 x 3 x H x W
            # pose2d_gt (bounded in 64 x 64): 1 x 4 x 21 x 2
            # pose3d_gt: 1 x 21 x 3
            # visibility: 1 x 4 x 21
            # extrinsic matrix: 1 x 4 x 3 x 4
            # intrinsic matrix: 1 x 3 x 3
            if not (i % 67 == 0): continue

            imgs = ret['imgs'].to(device)
            orig_imgs = ret['orig_imgs']
            pose2d_gt, pose3d_gt, visibility = ret['pose2d'], ret[
                'pose3d'], ret['visibility']
            extrinsic_matrices, intrinsic_matrices = ret[
                'extrinsic_matrices'], ret['intrinsic_matrix']
            # sometimes intrinsic_matrix has a shape of 3 x 3 instead of b x 3 x 3
            intrinsic_matrix = intrinsic_matrices[0] if len(
                intrinsic_matrices.shape) == 3 else intrinsic_matrices

            start_time = time.time()
            if 'pose_hrnet' in cfg.MODEL.NAME:
                pose3d_gt = pose3d_gt.to(device)

                heatmaps, _ = model(imgs[0])  # N_views x 21 x 64 x 64
                pose2d_pred = get_final_preds(heatmaps,
                                              cfg)  # N_views x 21 x 2
                proj_matrices = (intrinsic_matrix @ extrinsic_matrices).to(
                    device)  # b x v x 3 x 4

                # rescale to the original image before DLT
                pose2d_pred[:, :, 0:1] *= orig_width / heatmap_size[0]
                pose2d_pred[:, :, 1:2] *= orig_height / heatmap_size[1]
                # 3D world coordinate 1 x 21 x 3
                pose3d_pred = DLT_pytorch(pose2d_pred,
                                          proj_matrices.squeeze()).unsqueeze(0)

            elif 'alg' == cfg.MODEL.NAME or 'ransac' == cfg.MODEL.NAME:
                # the predicted 2D poses have been rescaled inside the triangulation model
                # pose2d_pred: 1 x N_views x 21 x 2
                # pose3d_pred: 1 x 21 x 3
                proj_matrices = (intrinsic_matrix @ extrinsic_matrices
                                 )  # b x v x 3 x 4

                pose3d_pred,\
                pose2d_pred,\
                heatmaps,\
                confidences_pred = model(imgs, proj_matrices.to(device))

            elif "vol" in cfg.MODEL.NAME:
                intrinsic_matrix = update_after_resize(
                    intrinsic_matrix, (orig_height, orig_width),
                    tuple(heatmap_size))
                proj_matrices = (intrinsic_matrix @ extrinsic_matrices).to(
                    device)  # b x v x 3 x 4

                # pose3d_pred (torch.tensor) b x 21 x 3
                # pose2d_pred (torch.tensor) b x v x 21 x 2 NOTE: the estimated 2D poses are located in the heatmap size 64(W) x 64(H)
                # heatmaps_pred (torch.tensor) b x v x 21 x 64 x 64
                # volumes_pred (torch.tensor)
                # confidences_pred (torch.tensor)
                # cuboids_pred (list)
                # coord_volumes_pred (torch.tensor)
                # base_points_pred (torch.tensor) b x v x 1 x 2
                if cfg.MODEL.BACKBONE_NAME == 'CPM_volumetric':
                    centermaps = ret['centermaps'].to(device)

                    pose3d_pred,\
                    pose2d_pred,\
                    heatmaps_pred,\
                    volumes_pred,\
                    confidences_pred,\
                    coord_volumes_pred,\
                    base_points_pred\
                        = model(imgs, centermaps, proj_matrices)
                else:
                    pose3d_pred,\
                    pose2d_pred,\
                    heatmaps,\
                    volumes_pred,\
                    confidences_pred,\
                    coord_volumes_pred,\
                    base_points_pred\
                        = model(imgs, proj_matrices)

                pose2d_pred[:, :, :, 0:1] *= orig_width / heatmap_size[0]
                pose2d_pred[:, :, :, 1:2] *= orig_height / heatmap_size[1]

            elif 'FTL' == cfg.MODEL.NAME:
                # pose2d_pred: 1 x 4 x 21 x 2
                # pose3d_pred: 1 x 21 x 3
                heatmaps, pose2d_pred, pose3d_pred = model(
                    imgs.to(device), extrinsic_matrices.to(device),
                    intrinsic_matrix.to(device))

                print(pose2d_pred)
                pose2d_pred = torch.cat((pose2d_pred[:, :, :, 0:1] * 640 / 64,
                                         pose2d_pred[:, :, :, 1:2] * 480 / 64),
                                        dim=-1)

            # N_views x 21 x 2
            end_time = time.time()
            print('3D pose inference time {:.1f} ms'.format(
                1000 * (end_time - start_time)))
            pose3d_EPE = pose3d_mse_loss(pose3d_pred[:, 1:],
                                         pose3d_gt[:, 1:].to(device)).item()
            print('Pose3d MSE: {:.4f}\n'.format(pose3d_EPE))

            # if pose3d_EPE > 35:
            #     input()
            #     continue
            # 2D errors
            pose2d_gt[:, :, :, 0] *= orig_width / heatmap_size[0]
            pose2d_gt[:, :, :, 1] *= orig_height / heatmap_size[1]

            # for k in range(21):
            #     print(pose2d_gt[0,k].tolist(), pose2d_pred[0,k].tolist())
            # input()

            visualize(args=args,
                      imgs=np.squeeze(orig_imgs[0].numpy()),
                      pose2d_gt=np.squeeze(pose2d_gt.cpu().numpy()),
                      pose2d_pred=np.squeeze(pose2d_pred.cpu().numpy()),
                      pose3d_gt=np.squeeze(pose3d_gt.cpu().numpy()),
                      pose3d_pred=np.squeeze(pose3d_pred.cpu().numpy()))
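# All branches above compose projection matrices as K @ [R|t] before triangulation. A
# minimal self-contained check of that composition (the values are arbitrary and purely
# illustrative, and the random [R|t] here is not a valid rigid transform):
import torch

K = torch.tensor([[600.0, 0.0, 320.0],
                  [0.0, 600.0, 240.0],
                  [0.0, 0.0, 1.0]])          # 3 x 3 intrinsics
Rt = torch.rand(4, 3, 4)                     # 4 views of [R|t]
P = K @ Rt                                   # 4 x 3 x 4 projection matrices

X_world = torch.tensor([1.0, 2.0, 5.0, 1.0]) # homogeneous world point
x_img = P @ X_world                          # 4 x 3 homogeneous image points
uv = x_img[:, :2] / x_img[:, 2:]             # pixel coordinates in each view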
def main():
    args = parse_args()

    update_config(cfg, args)
    cfg.defrost()
    cfg.freeze()

    if args.is_vis:
        result_dir = prefix + cfg.EXP_NAME
        mse2d_lst = np.loadtxt(os.path.join(result_dir,
                                            'mse2d_each_joint.txt'))
        mse3d_lst = np.loadtxt(os.path.join(result_dir,
                                            'mse3d_each_joint.txt'))
        PCK2d_lst = np.loadtxt(os.path.join(result_dir, 'PCK2d.txt'))
        PCK3d_lst = np.loadtxt(os.path.join(result_dir, 'PCK3d.txt'))

        plot_performance(PCK2d_lst[1, :], PCK2d_lst[0, :], PCK3d_lst[1, :],
                         PCK3d_lst[0, :], mse2d_lst, mse3d_lst)
        exit()

    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model_path = args.model_path
    is_vis = args.is_vis

    gpus = ','.join([str(i) for i in cfg.GPUS])
    gpu_ids = eval('[' + gpus + ']')

    if cfg.FP16.ENABLED:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if cfg.FP16.STATIC_LOSS_SCALE != 1.0:
        if not cfg.FP16.ENABLED:
            print(
                "Warning:  if --fp16 is not used, static_loss_scale will be ignored."
            )

    if 'pose_hrnet' in cfg.MODEL.NAME:
        model = {
            "pose_hrnet": pose_hrnet.get_pose_net,
            "pose_hrnet_softmax": pose_hrnet_softmax.get_pose_net
        }[cfg.MODEL.NAME](cfg, is_train=True)
    else:
        model = {
            "ransac": RANSACTriangulationNet,
            "alg": AlgebraicTriangulationNet,
            "vol": VolumetricTriangulationNet,
            "vol_CPM": VolumetricTriangulationNet_CPM,
            "FTL": FTLMultiviewNet
        }[cfg.MODEL.NAME](cfg, is_train=False)

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    if cfg.MODEL.SYNC_BN and not args.distributed:
        print(
            'Warning: Sync BatchNorm is only supported in distributed training.'
        )

    # load model state
    if model_path:
        print("Loading model:", model_path)
        ckpt = torch.load(model_path,
                          map_location='cpu' if args.gpu == -1 else 'cuda:0')
        if 'state_dict' not in ckpt.keys():
            state_dict = ckpt
        else:
            state_dict = ckpt['state_dict']
            print('Model epoch {}'.format(ckpt['epoch']))

        for key in list(state_dict.keys()):
            new_key = key.replace("module.", "")
            state_dict[new_key] = state_dict.pop(key)

        model.load_state_dict(state_dict, strict=False)

    device = torch.device('cuda:' + str(args.gpu) if args.gpu != -1 else 'cpu')

    model.to(device)

    model.eval()

    # image transformer
    transform = build_transforms(cfg, is_train=False)

    inference_dataset = eval('dataset.' + cfg.DATASET.DATASET[0])(
        cfg, cfg.DATASET.TEST_SET, transform=transform)
    inference_dataset.n_views = eval(args.views)
    batch_size = args.batch_size
    if platform.system() == 'Linux':  # for linux
        data_loader = torch.utils.data.DataLoader(inference_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  num_workers=8,
                                                  pin_memory=False)
    else:  # for windows
        batch_size = 1
        data_loader = torch.utils.data.DataLoader(inference_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  num_workers=0,
                                                  pin_memory=False)

    print('\nEvaluation loader information:\n' + str(data_loader.dataset))
    print('Evaluation batch size: {}\n'.format(batch_size))

    th2d_lst = np.array([i for i in range(1, 50)])
    PCK2d_lst = np.zeros((len(th2d_lst), ))
    mse2d_lst = np.zeros((21, ))
    th3d_lst = np.array([i for i in range(1, 51)])
    PCK3d_lst = np.zeros((len(th3d_lst), ))
    mse3d_lst = np.zeros((21, ))
    visibility_lst = np.zeros((21, ))
    with torch.no_grad():
        start_time = time.time()
        pose2d_mse_loss = JointsMSELoss().cuda(
            args.gpu) if args.gpu != -1 else JointsMSELoss()
        pose3d_mse_loss = Joints3DMSELoss().cuda(
            args.gpu) if args.gpu != -1 else Joints3DMSELoss()

        infer_time = [0, 0]
        start_time = time.time()
        n_valid = 0
        model.orig_img_size = inference_dataset.orig_img_size
        orig_width, orig_height = model.orig_img_size
        heatmap_size = cfg.MODEL.HEATMAP_SIZE

        for i, ret in enumerate(data_loader):
            # ori_imgs: b x 4 x 480 x 640 x 3
            # imgs: b x 4 x 3 x H x W
            # pose2d_gt: b x 4 x 21 x 2 (have not been transformed)
            # pose3d_gt: b x 21 x 3
            # visibility: b x 4 x 21
            # extrinsic matrix: b x 4 x 3 x 4
            # intrinsic matrix: b x 3 x 3
            # if i < count: continue
            imgs = ret['imgs'].to(device)
            orig_imgs = ret['orig_imgs']
            pose2d_gt, pose3d_gt, visibility = ret['pose2d'], ret[
                'pose3d'], ret['visibility']
            extrinsic_matrices, intrinsic_matrices = ret[
                'extrinsic_matrices'], ret['intrinsic_matrix']
            # sometimes intrinsic_matrix has a shape of 3 x 3 instead of b x 3 x 3
            intrinsic_matrix = intrinsic_matrices[0] if len(
                intrinsic_matrices.shape) == 3 else intrinsic_matrices

            batch_size = orig_imgs.shape[0]
            n_joints = pose2d_gt.shape[2]
            pose2d_gt = pose2d_gt.view(
                -1, *pose2d_gt.shape[2:]).numpy()  # b*v x 21 x 2
            pose3d_gt = pose3d_gt.numpy()  # b x 21 x 3
            visibility = visibility.view(
                -1, visibility.shape[2]).numpy()  # b*v x 21

            if 'pose_hrnet' in cfg.MODEL.NAME:
                s1 = time.time()
                heatmaps, _ = model(imgs.view(
                    -1, *imgs.shape[2:]))  # b*v x 21 x 64 x 64
                pose2d_pred = get_final_preds(heatmaps, cfg).view(
                    batch_size, -1, n_joints, 2
                )  # b x v x 21 x 2 NOTE: the estimated 2D poses are located in the heatmap size 64(W) x 64(H)
                proj_matrices = (intrinsic_matrix @ extrinsic_matrices).to(
                    device)  # b x v x 3 x 4
                # rescale to the original image before DLT
                pose2d_pred[:, :, :, 0:1] *= orig_width / heatmap_size[0]
                pose2d_pred[:, :, :, 1:2] *= orig_height / heatmap_size[1]

                # 3D world coordinate 1 x 21 x 3
                pose3d_pred = torch.cat([
                    DLT_sii_pytorch(pose2d_pred[:, :, k],
                                    proj_matrices).unsqueeze(1)
                    for k in range(n_joints)
                ],
                                        dim=1)  # b x 21 x 3

                if i > 20:
                    infer_time[0] += 1
                    infer_time[1] += time.time() - s1
                    #print('FPS {:.1f}'.format(infer_time[0]/infer_time[1]))

            elif 'alg' == cfg.MODEL.NAME or 'ransac' == cfg.MODEL.NAME:
                s1 = time.time()
                # pose2d_pred: b x N_views x 21 x 2
                # NOTE: the estimated 2D poses are located in the original image of size 640(W) x 480(H)]
                # pose3d_pred: b x 21 x 3 [world coord]
                proj_matrices = (intrinsic_matrix @ extrinsic_matrices).to(
                    device)  # b x v x 3 x 4
                pose3d_pred,\
                pose2d_pred,\
                heatmaps,\
                confidences_pred = model(imgs.to(device), proj_matrices.to(device))
                if i > 20:
                    infer_time[0] += 1
                    infer_time[1] += time.time() - s1

            elif "vol" in cfg.MODEL.NAME:
                intrinsic_matrix = update_after_resize(
                    intrinsic_matrix, (orig_height, orig_width),
                    tuple(heatmap_size))
                proj_matrices = (intrinsic_matrix @ extrinsic_matrices).to(
                    device)  # b x v x 3 x 4
                s1 = time.time()

                # pose3d_pred (torch.tensor) b x 21 x 3
                # pose2d_pred (torch.tensor) b x v x 21 x 2 NOTE: the estimated 2D poses are located in the heatmap size 64(W) x 64(H)
                # heatmaps_pred (torch.tensor) b x v x 21 x 64 x 64
                # volumes_pred (torch.tensor)
                # confidences_pred (torch.tensor)
                # cuboids_pred (list)
                # coord_volumes_pred (torch.tensor)
                # base_points_pred (torch.tensor) b x v x 1 x 2
                if cfg.MODEL.BACKBONE_NAME == 'CPM_volumetric':
                    centermaps = ret['centermaps'].to(device)
                    heatmaps_gt = ret['heatmaps']

                    pose3d_pred,\
                    pose2d_pred,\
                    heatmaps_pred,\
                    volumes_pred,\
                    confidences_pred,\
                    coord_volumes_pred,\
                    base_points_pred\
                        = model(imgs, centermaps, proj_matrices)
                else:
                    pose3d_pred,\
                    pose2d_pred,\
                    heatmaps,\
                    volumes_pred,\
                    confidences_pred,\
                    coord_volumes_pred,\
                    base_points_pred\
                        = model(imgs, proj_matrices)

                if i > 20:
                    infer_time[0] += 1
                    infer_time[1] += time.time() - s1

                pose2d_pred[:, :, :, 0:1] *= orig_width / heatmap_size[0]
                pose2d_pred[:, :, :, 1:2] *= orig_height / heatmap_size[1]

            # 2D errors
            pose2d_gt[:, :, 0] *= orig_width / heatmap_size[0]
            pose2d_gt[:, :, 1] *= orig_height / heatmap_size[1]

            pose2d_pred = pose2d_pred.view(-1, n_joints,
                                           2).cpu().numpy()  # b*v x 21 x 2
            # for k in range(21):
            #     print(pose2d_gt[0, k].tolist(), pose2d_pred[0, k].tolist())
            # input()
            mse_each_joint = np.linalg.norm(pose2d_pred - pose2d_gt,
                                            axis=2) * visibility  # b*v x 21
            mse2d_lst += mse_each_joint.sum(axis=0)
            visibility_lst += visibility.sum(axis=0)

            for th_idx in range(len(th2d_lst)):
                PCK2d_lst[th_idx] += np.sum(
                    (mse_each_joint < th2d_lst[th_idx]) * visibility)

            # 3D errors
            # for k in range(21):
            #     print(pose3d_gt[0, k].tolist(), pose3d_pred[0, k].tolist())
            # input()
            visibility = visibility.reshape(
                (batch_size, -1, n_joints))  # b x v x 21
            for b in range(batch_size):
                # print(np.sum(visibility[b]), visibility[b].size)
                if np.sum(visibility[b]) >= visibility[b].size * 0.65:
                    n_valid += 1
                    mse_each_joint = np.linalg.norm(
                        pose3d_pred[b].cpu().numpy() - pose3d_gt[b],
                        axis=1)  # 21
                    mse3d_lst += mse_each_joint

                    for th_idx in range(len(th3d_lst)):
                        PCK3d_lst[th_idx] += np.sum(
                            mse_each_joint < th3d_lst[th_idx])

            if i % (len(data_loader) // 5) == 0:
                print("[Evaluation]{}% finished.".format(
                    20 * i // (len(data_loader) // 5)))
            #if i == 10:break
        print('Evaluation spent {:.2f} s\tFPS: {:.1f}'.format(
            time.time() - start_time, infer_time[0] / infer_time[1]))

        mse2d_lst /= visibility_lst
        PCK2d_lst /= visibility_lst.sum()
        mse3d_lst /= n_valid
        PCK3d_lst /= (n_valid * 21)
        plot_performance(PCK2d_lst, th2d_lst, PCK3d_lst, th3d_lst, mse2d_lst,
                         mse3d_lst)

        result_dir = prefix + cfg.EXP_NAME
        if not os.path.exists(result_dir):
            os.mkdir(result_dir)

        np.savetxt(os.path.join(result_dir, 'mse2d_each_joint.txt'),
                   mse2d_lst,
                   fmt='%.4f')
        np.savetxt(os.path.join(result_dir, 'mse3d_each_joint.txt'),
                   mse3d_lst,
                   fmt='%.4f')
        np.savetxt(os.path.join(result_dir, 'PCK2d.txt'),
                   np.stack((th2d_lst, PCK2d_lst)))
        np.savetxt(os.path.join(result_dir, 'PCK3d.txt'),
                   np.stack((th3d_lst, PCK3d_lst)))
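# The is_vis branch at the top of this main() reloads exactly these files; a small
# sketch of doing the same by hand (the results directory name is an assumption):
import numpy as np

pck2d = np.loadtxt('./eval_results_EXP/PCK2d.txt')            # row 0: thresholds, row 1: PCK
mse2d = np.loadtxt('./eval_results_EXP/mse2d_each_joint.txt') # mean pixel error per joint
thresholds, pck_values = pck2d[0], pck2d[1]
print('PCK2D@20px: {:.3f}'.format(pck_values[thresholds == 20][0]))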
Example #10
    def forward(self, images, extrinsic_matrices=torch.rand((1,4,3,4)), intrinsic_matrices=torch.rand((1,3,3))):
        # images: b x v x 3 x H x W
        # extrinsic_mat (w2c): b x v x 3 x 4
        # intrinsic_mat (replicas of the same matrix): b x 3 x 3
        device = images.device
        batch_size, n_views = images.shape[:2]
        intrinsic_matrix = intrinsic_matrices[0]

        # reshape n_views dimension to batch dimension
        images = images.view(-1, *images.shape[2:])

        # 1. Encoder: HRNet_w48 + final_layer, output b x (32+64+128+256=480) x 18 x 18
        heatmaps, inter_feat = self.backbone(images) # inter_feat: b*v x 480 x 64 x 64
        feature_maps = self.encoder_head(inter_feat) # inter_feat: b*v x 240 x 18 x 18

        # 2. Reshape: (the last dimension stands for a homogeneous 2D image coord)
        reshaped_features = feature_maps.view(batch_size, n_views, feature_maps.shape[1], -1, 3) # b x v x 240 x 108 x 3

        # 3. FTL
        R_T = extrinsic_matrices[:,:,:,0:-1].transpose(2,3) # b x v x 3 x 3
        t_T = extrinsic_matrices[:,:,:,-1:].transpose(2,3) # b x v x 1 x 3
        intrinsic_matrix_T = intrinsic_matrix.T # 3 x 3

        canonical_feature_lst = []
        for v in range(n_views):
            # pose2d -> 3D cam coord
            canonical_feature = torch.matmul(reshaped_features[:,v], torch.inverse(intrinsic_matrix_T)) # b x 240 x 108 x 3
            # 3D cam coord -> 3D world coord.
            canonical_feature = torch.matmul(canonical_feature - t_T[:,v:v+1], torch.inverse(R_T[:,v:v+1]))  # b x 240 x 108 x 3
            # 4. Reshape
            canonical_feature_lst.append(canonical_feature.view(batch_size, *feature_maps.shape[1:])) # b x 240 x 18 x 18

        # 6. Concatenate the tensors of the v views into b x 240*v x 18 x 18
        canonical_feature_all_views = torch.cat(canonical_feature_lst, dim=1) # b x 240*v x 18 x 18

        # 7. Two 1x1 convolution layers
        fused_features = self.fuse_after_FTL(canonical_feature_all_views).view(batch_size, *reshaped_features.shape[2:]) # b x 240 x 108 x 3

        # 8. FTL redistribution to each view: b x 240 x 18 x 18
        features_each_view_lst = []
        for v in range(n_views):
            features_each_view = torch.matmul(fused_features, R_T[:,v:v+1]) + t_T[:,v:v+1] # b x 240 x 108 x 3
            features_each_view = torch.matmul(features_each_view, intrinsic_matrix_T) # b x 240 x 108 x 3
            features_each_view_lst.append(features_each_view.view(batch_size, *feature_maps.shape[1:])) # b x 240 x 18 x 18

        features_all_views = torch.cat(features_each_view_lst, dim=0) # b*v x 240 x 18 x 18

        # 9. 1x1 convolution
        expanded_features = self.channel_expansion(features_all_views) # b*v x 480 x 18 x 18

        # 10.   Decoder:
        #       transposed conv nn.ConvTranspose2d(16, 16, 3, stride=1, padding=1): b*v x 480 x 18 x 18
        #       transposed conv nn.ConvTranspose2d(16, 16, 3, stride=2): b*v x 480 x 18 x 18
        #       transposed conv nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1): b*v x 480 x 18 x 18
        #       3x3 conv layer: b x 21 x 64 x 64
        decoded_features = self.decoder(expanded_features) # b*v x 480 x 64 x 64

        # 11.   1x1 conv layer outputs the heatmaps
        heatmaps = self.final_layer(decoded_features) # b*v x n_joints x 64 x 64

        # Apply 2D softmax to generate heatmaps
        heatmaps_flattened = heatmaps.view(heatmaps.shape[0], heatmaps.shape[1], -1) # b*v x n_joints x 64*64
        heatmaps_softmax = F.softmax(heatmaps_flattened, dim=2)
        heatmaps_pred = heatmaps_softmax.view(heatmaps.shape) # b*v x n_joints x 64 x 64
        
        pose2d_pred = get_final_preds(heatmaps_pred, use_softmax=True).view(batch_size, n_views, -1, 2) # b x v x n_joints x 2

        proj_matrices = torch.matmul(intrinsic_matrix, extrinsic_matrices) # b x v x 3 x 4

        pose3d_pred = torch.cat([DLT_sii_pytorch(pose2d_pred[:,:,k], proj_matrices).unsqueeze(1) for k in range(pose2d_pred.shape[2])], dim=1)

        return heatmaps_pred, pose2d_pred, pose3d_pred
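# The FTL steps above move homogeneous feature triplets image -> camera -> world and
# back using a row-vector convention (x_row @ M.T). A small self-contained round-trip
# check of those two transforms with an arbitrary (but invertible) K, R and t:
import torch

K = torch.tensor([[600.0, 0.0, 9.0],
                  [0.0, 600.0, 9.0],
                  [0.0, 0.0, 1.0]])
R = torch.linalg.qr(torch.rand(3, 3)).Q        # orthogonal 3 x 3 matrix
t = torch.rand(1, 3)                           # translation as a row vector

feats = torch.rand(240, 108, 3)                # feature triplets, shaped as in the forward pass
cam = feats @ torch.inverse(K.T)               # image plane -> camera coords (undo intrinsics)
world = (cam - t) @ torch.inverse(R.T)         # camera coords -> canonical world coords
back = (world @ R.T + t) @ K.T                 # world -> camera -> image plane again
print(torch.allclose(back, feats, atol=1e-3))  # True: the FTL transform is invertible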
    def forward(self, images, centermaps=torch.rand(1,4,1,256,256), proj_matrices=torch.rand((1,4,3,4)), batch=None, keypoints_3d=None):
        # images: b x N_views x 3 x H x W
        # proj_matricies_batch (K*H): b x N_views x 3 x 4
        device = images.device
        batch_size, n_views = images.shape[:2]

        # reshape for backbone forward
        images = images.view(-1, *images.shape[2:])
        centermaps = centermaps.view(-1, *centermaps.shape[2:]) # b*v x 1 x 256 x 256
        # forward backbone
        # heatmaps: (b*N_views) x 22 x 64 x 64
        # features: (b*N_views) x 55 x 64 x 64
        # vol_confidences: b x 32
        _,_,_,_,_, heatmaps, features, vol_confidences = self.backbone(images, centermaps)
        # find the middle finger root position
        base_idx = 9
        # v2: 3D EPE 10.7802
        pose2d_pred = get_final_preds(heatmaps[:,1:], use_softmax=self.heatmap_softmax).view(batch_size, n_views, heatmaps.shape[1] - 1, 2) # batch_size x N_views x 21 x 2
        base_points = torch.cat([DLT_pytorch(pose2d_pred[b,:,base_idx:base_idx+1], proj_matrices[b]) for b in range(batch_size)], dim=0)

        # V3: 
        # pose2d_pred_temp = get_final_preds(heatmaps[:,1:], use_softmax=self.heatmap_softmax).view(batch_size, n_views, heatmaps.shape[1] - 1, 2) # batch_size x N_views x 21 x 2
        # pose2d_pred = torch.zeros(pose2d_pred_temp.shape, dtype=pose2d_pred_temp.dtype, device=pose2d_pred_temp.device)
        # orig_width, orig_height = self.orig_img_size
        # pose2d_pred[:,:,:,0] = pose2d_pred_temp[:,:,:,0] * orig_width / 64
        # pose2d_pred[:,:,:,1] = pose2d_pred_temp[:,:,:,1] * orig_height / 64
        # base_points = DLT_sii_pytorch(pose2d_pred[:,:,base_idx], proj_matrices) # b x 3

        # reshape back
        images = images.view(batch_size, n_views, *images.shape[1:])
        heatmaps = heatmaps.view(batch_size, n_views, *heatmaps.shape[1:])
        features = features.view(batch_size, n_views, *features.shape[1:])

        if vol_confidences is not None:
            vol_confidences = vol_confidences.view(batch_size, n_views, *vol_confidences.shape[1:])

        # calculate shapes
        image_shape, heatmap_shape = tuple(images.shape[3:]), tuple(heatmaps.shape[3:])

        # norm vol confidences
        if self.volume_aggregation_method == 'conf_norm':
            vol_confidences = vol_confidences / vol_confidences.sum(dim=1, keepdim=True)

        # camera intrinsics already changed in function3D.py

        # new_cameras = deepcopy(batch['cameras'])
        # for view_i in range(n_views):
        #     for batch_i in range(batch_size):
        #         new_cameras[view_i][batch_i].update_after_resize(image_shape, heatmap_shape)

        #proj_matrices = torch.stack([torch.stack([torch.from_numpy(camera.projection) for camera in camera_batch], dim=0) for camera_batch in new_cameras], dim=0).transpose(1, 0)  # shape (batch_size, n_views, 3, 4)
        #proj_matrices = proj_matrices.float().to(device)

        # build coord volumes
        cuboids = []
        coord_volumes = torch.zeros(batch_size, self.volume_size, self.volume_size, self.volume_size, 3, device=device)
        for batch_i in range(batch_size):
            # if self.use_precalculated_pelvis:
            # if self.use_gt_middleroot:
            #     keypoints_3d = batch['keypoints_3d'][batch_i]
            # else:
            base_point = base_points[batch_i]

            # build cuboid
            sides = torch.tensor([self.cuboid_side, self.cuboid_side, self.cuboid_side], device=base_points.device) # 2500 x 2500 x 2500 (mm)
            position = base_point - sides / 2
            cuboid = volumetric.Cuboid3D(position, sides)

            cuboids.append(cuboid)

            # build coord volume, dividing the cubic length (2500mm) into 63 segments. NOTE: meshgrid returns a tuple
            xxx, yyy, zzz = torch.meshgrid(torch.arange(self.volume_size, device=device), torch.arange(self.volume_size, device=device), torch.arange(self.volume_size, device=device))
            grid = torch.stack([xxx, yyy, zzz], dim=-1).type(torch.float) # 64 x 64 x 64 x 3
            grid = grid.reshape((-1, 3))

            grid_coord = torch.zeros_like(grid)
            # self.volume_size - 1 because we fill the cube with the global coords of each voxel's center
            # the elements of the grid are bounded in [0, 63]
            grid_coord[:, 0] = position[0] + (sides[0] / (self.volume_size - 1)) * grid[:, 0]
            grid_coord[:, 1] = position[1] + (sides[1] / (self.volume_size - 1)) * grid[:, 1]
            grid_coord[:, 2] = position[2] + (sides[2] / (self.volume_size - 1)) * grid[:, 2]

            coord_volume = grid_coord.reshape(self.volume_size, self.volume_size, self.volume_size, 3)

            # random rotation
            if self.training:
                theta = np.random.uniform(0.0, 2 * np.pi)
            else:
                theta = 0.0

            axis = [0, 1, 0]  # y axis

            # rotate
            coord_volume = coord_volume - base_point
            coord_volume = volumetric.rotate_coord_volume(coord_volume, theta, axis)
            coord_volume = coord_volume + base_point

            # transfer
            # if self.transfer_cmu_to_human36m:  # different world coordinates
            #     coord_volume = coord_volume.permute(0, 2, 1, 3)
            #     inv_idx = torch.arange(coord_volume.shape[1] - 1, -1, -1).long().to(device)
            #     coord_volume = coord_volume.index_select(1, inv_idx)

            coord_volumes[batch_i] = coord_volume # batch_size x 64 x 64 x 64 x 3

        # process features before unprojecting
        features = features.view(-1, *features.shape[2:])
        features = self.process_features(features) # 32 output channels
        features = features.view(batch_size, n_views, *features.shape[1:])

        # lift to volume: b x 32 x 64 x 64 x 64
        volumes = op.unproject_heatmaps(features, proj_matrices, coord_volumes, volume_aggregation_method=self.volume_aggregation_method, vol_confidences=vol_confidences)

        # integral 3d
        volumes = self.volume_net(volumes)
        vol_keypoints_3d, volumes = op.integrate_tensor_3d_with_coordinates(volumes * self.volume_multiplier, coord_volumes, softmax=self.volume_softmax)

        return vol_keypoints_3d, pose2d_pred, heatmaps, volumes, vol_confidences, coord_volumes, base_points
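# `op.integrate_tensor_3d_with_coordinates` is not included in this snippet. A minimal
# sketch of that last step, assuming it is a 3D soft-argmax: softmax the processed
# volume over its voxels and take the expectation of the voxel-centre coordinates
# stored in coord_volumes:
import torch
import torch.nn.functional as F

def integrate_volume_sketch(volumes, coord_volumes):
    # volumes: b x J x D x D x D, coord_volumes: b x D x D x D x 3  ->  b x J x 3
    b, j = volumes.shape[:2]
    probs = F.softmax(volumes.view(b, j, -1), dim=-1)    # b x J x D^3 voxel weights
    coords = coord_volumes.view(b, 1, -1, 3)             # b x 1 x D^3 x 3
    return (probs.unsqueeze(-1) * coords).sum(dim=2)     # expected 3D joint positions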