Example #1
def vis_filter(ref_depth, reproj_xyd, in_range, img_dist_thresh, depth_thresh, vthresh):
    n, v, _, h, w = reproj_xyd.size()
    xy = get_pixel_grids(h, w).permute(3,2,0,1).unsqueeze(1)[:,:,:2]  # 112hw
    dist_masks = (reproj_xyd[:,:,:2,:,:] - xy).norm(dim=2, keepdim=True) < img_dist_thresh  # nv1hw
    depth_masks = (ref_depth.unsqueeze(1) - reproj_xyd[:,:,2:,:,:]).abs() < (torch.max(ref_depth.unsqueeze(1), reproj_xyd[:,:,2:,:,:])*depth_thresh)  # nv1hw
    masks = in_range * dist_masks.to(ref_depth.dtype) * depth_masks.to(ref_depth.dtype)  # nv1hw
    mask = masks.sum(dim=1) >= (vthresh-1.1)  # n1hw
    return masks, mask
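
These snippets lean on helpers from the surrounding project (get_pixel_grids, idx_img2cam, and friends) that are not shown here. As a reading aid only, here is a hypothetical sketch of get_pixel_grids inferred from the hw31 shape comments (the real helper may differ, e.g. in the half-pixel offset or device placement), followed by a dummy call of vis_filter with made-up tensors.

import torch

def get_pixel_grids(height, width):
    # Assumed behaviour: homogeneous pixel coordinates (x, y, 1), shape hw31.
    ys = (torch.arange(height, dtype=torch.float32) + 0.5).view(-1, 1).expand(height, width)
    xs = (torch.arange(width, dtype=torch.float32) + 0.5).view(1, -1).expand(height, width)
    ones = torch.ones(height, width)
    return torch.stack([xs, ys, ones], dim=-1).unsqueeze(-1)  # hw31

# Dummy invocation of vis_filter above; values are random, only shapes matter.
ref_depth = torch.rand(1, 1, 32, 40) + 1.0   # n1hw
reproj_xyd = torch.rand(1, 3, 3, 32, 40)     # nv3hw, v = 3 source views
in_range = torch.ones(1, 3, 1, 32, 40)       # nv1hw
masks, mask = vis_filter(ref_depth, reproj_xyd, in_range,
                         img_dist_thresh=1, depth_thresh=0.01, vthresh=2)
print(masks.shape, mask.shape)  # torch.Size([1, 3, 1, 32, 40]) torch.Size([1, 1, 32, 40])
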
Example #2
def forward(self, init_d, ref_feat, ref_cam, srcs_feat, srcs_cam, s_scale):
     """
          d Fi     d p2     d p3
     J = ------ * ------ * ------
          d p2     d p3     d d
          c*2      2*3      3*1
     """
     n, c, h, w = ref_feat.size()
     ref_cam_scaled, *srcs_cam_scaled = [
         scale_camera(cam, 1 / s_scale) for cam in [ref_cam] + srcs_cam
     ]
     J_list = []
     r_list = []
     for src_feat, src_cam_scaled in zip(srcs_feat, srcs_cam_scaled):
         H = get_homographies(
             ref_cam_scaled, src_cam_scaled, 1, init_d.detach(),
             torch.ones(n, 1, 1, 1,
                        dtype=torch.float32).cuda()).squeeze(1)  # nhw33
         # copied from homography.py
         with torch.no_grad():
             pixel_grids = get_pixel_grids(*src_feat.size()[-2:]).unsqueeze(
                 0)  # 1hw31
             warped_homo_coord = (H @ pixel_grids).squeeze(-1)  # nhw3
             warped_coord = warped_homo_coord[..., :2] / (
                 warped_homo_coord[..., 2:3] + 1e-9)  # nhw2
         warped = interpolate(src_feat, warped_coord)
         residual = (warped - ref_feat).permute(0, 2, 3,
                                                1).unsqueeze(-1)  # nhwc1
         src_grad = self.sobel_conv(src_feat)  # n c*2 hw
         src_grad_warped = interpolate(src_grad, warped_coord).permute(
             0, 2, 3, 1).reshape(n, h, w, c, 2)  # nhwc2
         d3to2_1 = torch.eye(2, dtype=torch.float32).cuda().view(
             1, 1, 1, 2, 2) / (warped_homo_coord[..., -1].view(
                 n, h, w, 1, 1) + 1e-9)  # nhw22
         d3to2_2 = warped_coord.unsqueeze(-1) / (
             warped_homo_coord[..., -1].view(n, h, w, 1, 1) + 1e-9)  # nhw21
         d3to2 = torch.cat([d3to2_1, d3to2_2], dim=-1)  # nhw23
         Ki = src_cam_scaled[:, 1, :3, :3].reshape(-1, 1, 1, 3, 3)  # n1133
         K0 = ref_cam_scaled[:, 1, :3, :3].reshape(-1, 1, 1, 3, 3)
         Ri = src_cam_scaled[:, 0, :3, :3].reshape(-1, 1, 1, 3, 3)
         R0 = ref_cam_scaled[:, 0, :3, :3].reshape(-1, 1, 1, 3, 3)
         # dptod = Ki @ Ri @ R0.inverse() @ K0.inverse() @ pixel_grids  # nhw31
         dptod = (Ki @ Ri @ R0.inverse() @ K0.inverse() -
                  H) @ pixel_grids / init_d.detach().view(n, h, w, 1, 1)
         Ji = src_grad_warped @ d3to2 @ dptod  # nhwc1
         r_list.append(residual)
         J_list.append(Ji)
     J, r = [torch.cat(l, dim=-2) for l in [J_list, r_list]]
     delta = (-(J.transpose(-1, -2) @ r) /
              (J.transpose(-1, -2) @ J + 1e-9)).reshape(n, 1, h, w)
     if (delta != delta).any():
         raise NanError
     # delta = delta.clamp(-1, 1)
     # plt.imshow(delta[0,0,...].clone().cpu().data.numpy())
     # plt.show()
     refined_d = init_d + delta
     return refined_d
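
The last lines of this forward pass apply one Gauss-Newton step per pixel: the unknown is a single scalar depth, so J^T J is 1x1 and the normal equations reduce to a division. A self-contained sketch of just that update, on made-up Jacobian and residual stacks:

import torch

n, h, w, m = 1, 4, 4, 6          # m = c * num_src_views entries stacked per pixel
J = torch.randn(n, h, w, m, 1)   # stacked Jacobians d Fi / d d
r = torch.randn(n, h, w, m, 1)   # stacked residuals Fi(warped) - Fi(ref)

# Per-pixel scalar least-squares step: delta = -(J^T r) / (J^T J).
delta = (-(J.transpose(-1, -2) @ r) /
         (J.transpose(-1, -2) @ J + 1e-9)).reshape(n, 1, h, w)
print(delta.shape)  # torch.Size([1, 1, 4, 4])
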
Example #3
def project_img(src_img, dst_depth, src_cam, dst_cam, height=None, width=None):  # nchw, n1hw -> nchw, n1hw
    if height is None: height = src_img.size()[-2]
    if width is None: width = src_img.size()[-1]
    dst_idx_img_homo = get_pixel_grids(height, width).unsqueeze(0)  # nhw31
    dst_idx_cam_homo = idx_img2cam(dst_idx_img_homo, dst_depth, dst_cam)  # nhw41
    dst_idx_world_homo = idx_cam2world(dst_idx_cam_homo, dst_cam)  # nhw41
    dst2src_idx_cam_homo = idx_world2cam(dst_idx_world_homo, src_cam)  # nhw41
    dst2src_idx_img_homo = idx_cam2img(dst2src_idx_cam_homo, src_cam)  # nhw31
    warp_coord = dst2src_idx_img_homo[...,:2,0]  # nhw2
    warp_coord[..., 0] /= width
    warp_coord[..., 1] /= height
    warp_coord = (warp_coord*2-1).clamp(-1.1, 1.1)  # nhw2
    in_range = bin_op_reduce([-1<=warp_coord[...,0], warp_coord[...,0]<=1, -1<=warp_coord[...,1], warp_coord[...,1]<=1], torch.min).to(src_img.dtype).unsqueeze(1)  # n1hw
    warped_img = F.grid_sample(src_img, warp_coord, mode='bilinear', padding_mode='zeros', align_corners=False)
    return warped_img, in_range
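
The tail of project_img rescales the reprojected pixel coordinates into the [-1, 1] range expected by F.grid_sample. A standalone sketch of that normalization, using a dummy identity warp in place of the geometry helpers and assuming pixel-center coordinates (which may not match the project's exact convention):

import torch
import torch.nn.functional as F

n, c, h, w = 1, 3, 8, 8
src_img = torch.rand(n, c, h, w)

# Identity warp: pixel centers (x + 0.5, y + 0.5), normalized by width/height
# exactly as in project_img.
ys = (torch.arange(h, dtype=torch.float32) + 0.5).view(-1, 1).expand(h, w)
xs = (torch.arange(w, dtype=torch.float32) + 0.5).view(1, -1).expand(h, w)
warp_coord = torch.stack([xs / w, ys / h], dim=-1).unsqueeze(0)  # nhw2 in [0, 1]
warp_coord = (warp_coord * 2 - 1).clamp(-1.1, 1.1)               # nhw2 in [-1, 1]

warped = F.grid_sample(src_img, warp_coord, mode='bilinear',
                       padding_mode='zeros', align_corners=False)
print(torch.allclose(warped, src_img, atol=1e-5))  # True: the warp is the identity
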
Example #4
def get_reproj(ref_depth, srcs_depth, ref_cam, srcs_cam):  # n1hw, nv1hw -> nv3hw, nv1hw
    n, v, _, h, w = srcs_depth.size()
    srcs_depth_f = srcs_depth.view(n*v, 1, h, w)
    srcs_cam_f = srcs_cam.view(n*v, 2, 4, 4)
    ref_depth_r = ref_depth.unsqueeze(1).repeat(1,v,1,1,1).view(n*v, 1, h, w)
    ref_cam_r = ref_cam.unsqueeze(1).repeat(1,v,1,1,1).view(n*v, 2, 4, 4)
    idx_img = get_pixel_grids(h, w).unsqueeze(0)  # 1hw31

    srcs_idx_cam = idx_img2cam(idx_img, srcs_depth_f, srcs_cam_f)  # Nhw41
    srcs_idx_world = idx_cam2world(srcs_idx_cam, srcs_cam_f)  # Nhw41
    srcs2ref_idx_cam = idx_world2cam(srcs_idx_world, ref_cam_r)  # Nhw41
    srcs2ref_idx_img = idx_cam2img(srcs2ref_idx_cam, ref_cam_r)  # Nhw31
    srcs2ref_xyd = torch.cat([srcs2ref_idx_img[...,:2,0], srcs2ref_idx_cam[...,2:3,0]], dim=-1).permute(0,3,1,2)  # N3hw

    reproj_xyd_f, in_range_f = project_img(srcs2ref_xyd, ref_depth_r, srcs_cam_f, ref_cam_r)  # N3hw, N1hw
    reproj_xyd = reproj_xyd_f.view(n,v,3,h,w)
    in_range = in_range_f.view(n,v,1,h,w)
    return reproj_xyd, in_range
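
get_reproj handles all source views at once by folding the view dimension into the batch, repeating the reference data to match, and unfolding at the end. A toy sketch of that batching pattern, with a dummy per-pair operation standing in for the reprojection chain:

import torch

n, v, h, w = 2, 3, 4, 5
srcs_depth = torch.rand(n, v, 1, h, w)          # nv1hw
ref_depth = torch.rand(n, 1, h, w)              # n1hw

srcs_depth_f = srcs_depth.view(n * v, 1, h, w)  # fold v into the batch
ref_depth_r = ref_depth.unsqueeze(1).repeat(1, v, 1, 1, 1).view(n * v, 1, h, w)

diff_f = srcs_depth_f - ref_depth_r             # stand-in for the reprojection chain
diff = diff_f.view(n, v, 1, h, w)               # unfold back to nv1hw
print(diff.shape)  # torch.Size([2, 3, 1, 4, 5])
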
Example #5
def get_reproj(ref_depth, srcs_depth, ref_cam, srcs_cam):  # n1hw, nv1hw -> nv3hw, nv1hw
    n, v, _, h, w = srcs_depth.size()
    srcs_depth_f = srcs_depth.view(n*v, 1, h, w)
    srcs_valid_f = (srcs_depth_f > 1e-9).to(srcs_depth_f.dtype)
    srcs_cam_f = srcs_cam.view(n*v, 2, 4, 4)
    ref_depth_r = ref_depth.unsqueeze(1).repeat(1,v,1,1,1).view(n*v, 1, h, w)
    ref_cam_r = ref_cam.unsqueeze(1).repeat(1,v,1,1,1).view(n*v, 2, 4, 4)
    idx_img = get_pixel_grids(h, w).unsqueeze(0)  # 1hw31

    srcs_idx_cam = idx_img2cam(idx_img, srcs_depth_f, srcs_cam_f)  # Nhw41
    srcs_idx_world = idx_cam2world(srcs_idx_cam, srcs_cam_f)  # Nhw41
    srcs2ref_idx_cam = idx_world2cam(srcs_idx_world, ref_cam_r)  # Nhw41
    srcs2ref_idx_img = idx_cam2img(srcs2ref_idx_cam, ref_cam_r)  # Nhw31
    srcs2ref_xydv = torch.cat([srcs2ref_idx_img[...,:2,0], srcs2ref_idx_cam[...,2:3,0], srcs_valid_f.permute(0,2,3,1)], dim=-1).permute(0,3,1,2)  # N4hw

    reproj_xydv_f, in_range_f = project_img(srcs2ref_xydv, ref_depth_r, srcs_cam_f, ref_cam_r)  # N4hw, N1hw
    reproj_xyd = reproj_xydv_f.view(n,v,4,h,w)[:,:,:3]
    in_range = (in_range_f * reproj_xydv_f[:,3:]).view(n,v,1,h,w)
    return reproj_xyd, in_range
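
This variant differs from Example #4 only in threading a per-pixel validity flag (depth > 1e-9) through the warp as a fourth channel, so pixels coming from empty source depth are dropped from in_range. A stripped-down sketch of that trick, omitting the warp itself:

import torch

N, h, w = 2, 4, 5
depth = torch.rand(N, 1, h, w)
depth[0, :, 0, 0] = 0.0                   # one empty/invalid depth pixel
valid = (depth > 1e-9).to(depth.dtype)    # N1hw

xyd = torch.rand(N, 3, h, w)
xydv = torch.cat([xyd, valid], dim=1)     # N4hw: validity rides along as channel 3

# In the real code xydv is warped first; the (warped) channel 3 then gates in_range.
in_range = torch.ones(N, 1, h, w)
in_range = in_range * xydv[:, 3:]
print(in_range[0, 0, 0, 0].item())        # 0.0 at the invalid pixel
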
Example #6
        reproj_xyd_g, in_range_g = get_reproj(ref_depth_g, srcs_depth_g, ref_cam_g, srcs_cam_g)
        vis_masks_g, vis_mask_g = vis_filter(ref_depth_g, reproj_xyd_g, in_range_g, 1, 0.01, args.vthresh)

        update[id] = {
            'mask': vis_mask_g.cpu()
        }
        del ref_depth_g, ref_cam_g, srcs_depth_g, srcs_cam_g, reproj_xyd_g, in_range_g, vis_masks_g, vis_mask_g
    for i, id in enumerate(pair['id_list']):
        views[id]['mask'] = views[id]['mask'] & update[id]['mask']
        views[id]['depth'] *= views[id]['mask']

    pcds = {}
    for i, id in tqdm(enumerate(pair['id_list']), 'back proj', n_views):
        ref_depth_g, ref_cam_g = views[id]['depth'].cuda(), views[id]['cam'].cuda()

        idx_img_g = get_pixel_grids(*ref_depth_g.size()[-2:]).unsqueeze(0)
        idx_cam_g = idx_img2cam(idx_img_g, ref_depth_g, ref_cam_g)
        points_g = idx_cam2world(idx_cam_g, ref_cam_g)[...,:3,0]  # nhw3
        cam_center_g = (- ref_cam_g[:,0,:3,:3].transpose(-2,-1) @ ref_cam_g[:,0,:3,3:])[...,0]  # n3
        dir_vec_g = cam_center_g.reshape(-1,1,1,3) - points_g  # nhw3

        p_f = points_g.cpu()[ views[id]['mask'].squeeze(1) ]  # m3
        c_f = views[id]['image'].permute(0,2,3,1)[ views[id]['mask'].squeeze(1) ] / 255  # m3
        d_f = dir_vec_g.cpu()[ views[id]['mask'].squeeze(1) ]  # m3
        
        pcds[id] = {
            'points': p_f,
            'colors': c_f,
            'dirs': d_f,
        }
        del views[id]
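
The back-projection loop recovers each camera's center in world space from what appears to be a world-to-camera extrinsic, via -R^T t (the cam_center_g line above). A tiny sanity check of that identity on a made-up extrinsic:

import torch

extr = torch.eye(4).unsqueeze(0)                 # n44 world-to-camera [R|t], R = I
extr[:, :3, 3] = torch.tensor([1.0, 2.0, 3.0])   # made-up translation

R = extr[:, :3, :3]
t = extr[:, :3, 3:]
cam_center = (-R.transpose(-2, -1) @ t)[..., 0]  # n3
print(cam_center)  # tensor([[-1., -2., -3.]])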