def vis_filter(ref_depth, reproj_xyd, in_range, img_dist_thresh, depth_thresh, vthresh):
    """Build per-view and aggregated geometric-consistency masks.

    A reference pixel is consistent with a source view when all three hold:
    * its round-trip reprojected pixel position lies closer than
      ``img_dist_thresh`` to the pixel's own grid position;
    * the round-trip depth differs from ``ref_depth`` by less than
      ``depth_thresh`` times the larger of the two depths;
    * the corresponding ``in_range`` weight is non-zero.

    Args:
        ref_depth: reference depth, n1hw.
        reproj_xyd: round-trip (x, y, depth) per source view, nv3hw.
        in_range: per-view weights, nv1hw (multiplied into the masks).
        img_dist_thresh: pixel-distance threshold.
        depth_thresh: relative depth tolerance.
        vthresh: view threshold. The aggregated mask requires the per-view
            masks to sum to at least ``vthresh - 1.1``. With binary masks
            this means at least ``vthresh - 1`` source views. The 0.1 slack
            presumably guards against float rounding.

    Returns:
        (masks, mask): per-view masks, nv1hw, in ``ref_depth``'s dtype, and
        the aggregated boolean mask, n1hw.
    """
    _, _, _, h, w = reproj_xyd.size()
    out_dtype = ref_depth.dtype

    # get_pixel_grids(h, w) is hw31; reorder to 13hw, add a view axis, keep x/y.
    grid_xy = get_pixel_grids(h, w).permute(3, 2, 0, 1).unsqueeze(1)[:, :, :2]  # 112hw

    reproj_xy = reproj_xyd[:, :, :2]  # nv2hw
    reproj_d = reproj_xyd[:, :, 2:]  # nv1hw
    ref_d = ref_depth.unsqueeze(1)  # n11hw

    pixel_err = torch.norm(reproj_xy - grid_xy, dim=2, keepdim=True)  # nv1hw
    close_in_image = pixel_err < img_dist_thresh

    depth_tol = torch.max(ref_d, reproj_d) * depth_thresh
    close_in_depth = (ref_d - reproj_d).abs() < depth_tol  # nv1hw

    masks = in_range * close_in_image.to(out_dtype) * close_in_depth.to(out_dtype)  # nv1hw
    vote_count = masks.sum(dim=1)  # n1hw
    mask = vote_count >= vthresh - 1.1
    return masks, mask
def forward(self, init_d, ref_feat, ref_cam, srcs_feat, srcs_cam, s_scale):
    """Refine a depth map with one feature-metric Gauss-Newton step.

    For every source view, the source features are warped into the reference
    view through the depth-induced homography. The Jacobian of the warped
    features w.r.t. the reference depth is assembled by the chain rule:

        d Fi     d Fi     d p2     d p3
        ----  =  ----  *  ----  *  ----
        d d      d p2     d p3     d d
                 (c*2)    (2*3)    (3*1)

    ``p3`` is the homogeneous warped coordinate and ``p2`` its
    dehomogenised pixel position. Residuals and Jacobians of all views are
    stacked along the channel axis. The scalar per-pixel update is
    ``delta = -(J^T r) / (J^T J)``.

    Args:
        init_d: initial depth, reshaped like n1hw by the update below. It is
            detached when building the warp and the Jacobian.
        ref_feat: reference features, nchw.
        ref_cam: reference camera tensor.
        srcs_feat: list of source feature maps (assumed same h, w as
            ``ref_feat`` -- TODO confirm).
        srcs_cam: list of source camera tensors.
        s_scale: cameras are rescaled with ``scale_camera(cam, 1 / s_scale)``.

    Returns:
        ``init_d + delta``.

    Raises:
        NanError: if the computed update contains NaN.
    """
    n, c, h, w = ref_feat.size()
    device = ref_feat.device
    ref_cam_scaled, *srcs_cam_scaled = [
        scale_camera(cam, 1 / s_scale) for cam in [ref_cam] + srcs_cam
    ]
    depth = init_d.detach()

    # Loop-invariant tensors. They are created on the inputs' device rather
    # than on the current default CUDA device, which is what `.cuda()` would use.
    unit_scale = torch.ones(n, 1, 1, 1, dtype=torch.float32, device=device)
    eye2 = torch.eye(2, dtype=torch.float32, device=device).view(1, 1, 1, 2, 2)
    K0_inv = ref_cam_scaled[:, 1, :3, :3].reshape(-1, 1, 1, 3, 3).inverse()
    R0_inv = ref_cam_scaled[:, 0, :3, :3].reshape(-1, 1, 1, 3, 3).inverse()

    J_list = []
    r_list = []
    for src_feat, src_cam_scaled in zip(srcs_feat, srcs_cam_scaled):
        H = get_homographies(
            ref_cam_scaled, src_cam_scaled, 1, depth, unit_scale
        ).squeeze(1)  # nhw33

        # Sampling coordinates (same scheme as homography.py). They depend only
        # on the detached depth and the cameras, so no graph is needed.
        with torch.no_grad():
            pixel_grids = get_pixel_grids(*src_feat.size()[-2:]).unsqueeze(0)  # 1hw31
            warped_homo_coord = (H @ pixel_grids).squeeze(-1)  # nhw3
        warped_coord = warped_homo_coord[..., :2] / (warped_homo_coord[..., 2:3] + 1e-9)  # nhw2

        warped = interpolate(src_feat, warped_coord)
        residual = (warped - ref_feat).permute(0, 2, 3, 1).unsqueeze(-1)  # nhwc1

        # d Fi / d p2: Sobel gradients of the source features, sampled at p2.
        src_grad = self.sobel_conv(src_feat)  # n(c*2)hw
        src_grad_warped = interpolate(src_grad, warped_coord).permute(
            0, 2, 3, 1).reshape(n, h, w, c, 2)  # nhwc2

        # d p2 / d p3 for p2 = (x/z, y/z) is [[1/z, 0, -x/z^2], [0, 1/z, -y/z^2]].
        # The last column is therefore -p2 / z. It must be negative; the
        # previous code used +p2 / z.
        z = warped_homo_coord[..., -1].view(n, h, w, 1, 1) + 1e-9
        d3to2_1 = eye2 / z  # nhw22
        d3to2_2 = -warped_coord.unsqueeze(-1) / z  # nhw21
        d3to2 = torch.cat([d3to2_1, d3to2_2], dim=-1)  # nhw23

        # d p3 / d d. This assumes get_homographies builds a plane-induced
        # homography H(d) = A + B / d with A = Ki Ri R0^-1 K0^-1 (TODO confirm).
        # Then dH/dd = -B / d^2 = (A - H) / d.
        Ki = src_cam_scaled[:, 1, :3, :3].reshape(-1, 1, 1, 3, 3)  # n1133
        Ri = src_cam_scaled[:, 0, :3, :3].reshape(-1, 1, 1, 3, 3)
        dptod = (Ki @ Ri @ R0_inv @ K0_inv - H) @ pixel_grids / depth.view(n, h, w, 1, 1)  # nhw31

        Ji = src_grad_warped @ d3to2 @ dptod  # nhwc1
        r_list.append(residual)
        J_list.append(Ji)

    J = torch.cat(J_list, dim=-2)
    r = torch.cat(r_list, dim=-2)
    Jt = J.transpose(-1, -2)
    # Scalar Gauss-Newton step per pixel; 1e-9 guards against J^T J == 0.
    delta = (-(Jt @ r) / (Jt @ J + 1e-9)).reshape(n, 1, h, w)
    if torch.isnan(delta).any():
        raise NanError
    return init_d + delta
def project_img(src_img, dst_depth, src_cam, dst_cam, height=None, width=None):  # nchw, n1hw -> nchw, n1hw
    """Warp ``src_img`` into the destination view using the destination depth.

    Each destination pixel is back-projected with ``dst_depth`` / ``dst_cam``,
    moved into the source camera, and projected to source pixel coordinates.
    ``src_img`` is then sampled there with bilinear ``F.grid_sample`` (zero
    padding, ``align_corners=False``).

    Args:
        src_img: tensor to sample from, nchw.
        dst_depth: destination-view depth, n1hw.
        src_cam: camera of the view ``src_img`` belongs to.
        dst_cam: camera of the destination view.
        height, width: size of the destination pixel grid. They default to
            the spatial size of ``dst_depth``, because the grid is combined
            with ``dst_depth`` when back-projecting.

    Returns:
        (warped_img, in_range):
        * warped_img: the sampled tensor, shaped (n, c, height, width).
        * in_range: an n1hw tensor in ``src_img``'s dtype. It combines the
          four bounds checks on the normalised coordinates via
          ``bin_op_reduce(..., torch.min)``. Presumably it is 1 where the
          projection lands inside ``src_img``.
    """
    # Source size: the projected coordinates index src_img, so normalisation
    # must use its size even when the destination grid has a different one.
    src_h, src_w = src_img.size()[-2:]
    if height is None:
        height = dst_depth.size()[-2]
    if width is None:
        width = dst_depth.size()[-1]

    dst_idx_img_homo = get_pixel_grids(height, width).unsqueeze(0)  # 1hw31
    dst_idx_cam_homo = idx_img2cam(dst_idx_img_homo, dst_depth, dst_cam)  # nhw41
    dst_idx_world_homo = idx_cam2world(dst_idx_cam_homo, dst_cam)  # nhw41
    dst2src_idx_cam_homo = idx_world2cam(dst_idx_world_homo, src_cam)  # nhw41
    dst2src_idx_img_homo = idx_cam2img(dst2src_idx_cam_homo, src_cam)  # nhw31

    # Pixel coords -> grid_sample's [-1, 1] range. This is done out of place so
    # no view of an autograd intermediate is mutated.
    warp_coord = torch.stack(
        [dst2src_idx_img_homo[..., 0, 0] / src_w,
         dst2src_idx_img_homo[..., 1, 0] / src_h],
        dim=-1,
    )  # nhw2
    # Clamping keeps far-off projections bounded. Anything outside [-1, 1]
    # is flagged out of range just below.
    warp_coord = (warp_coord * 2 - 1).clamp(-1.1, 1.1)  # nhw2
    in_range = bin_op_reduce(
        [-1 <= warp_coord[..., 0], warp_coord[..., 0] <= 1,
         -1 <= warp_coord[..., 1], warp_coord[..., 1] <= 1],
        torch.min,
    ).to(src_img.dtype).unsqueeze(1)  # n1hw
    warped_img = F.grid_sample(
        src_img, warp_coord, mode='bilinear', padding_mode='zeros', align_corners=False
    )
    return warped_img, in_range
def get_reproj(ref_depth, srcs_depth, ref_cam, srcs_cam):  # n1hw, nv1hw -> nv3hw, nv1hw
    """Round-trip reprojection of every source depth map into the reference view.

    Steps:
    1. Each source pixel is lifted to 3-D with its own depth.
    2. The point is expressed in the reference camera and projected to a
       reference pixel.
    3. The resulting (x, y, z-in-reference-camera) triplet is sampled back at
       every reference pixel by ``project_img`` using ``ref_depth``.

    NOTE(review): a second ``get_reproj`` definition follows later in this
    file. If both live in the same module, that later one replaces this one
    at import time.

    Returns:
        reproj_xyd: nv3hw round-trip pixel coordinates and depth.
        in_range: nv1hw mask produced by ``project_img``.
    """
    n, v, _, h, w = srcs_depth.size()
    batch = n * v
    # Fold the view axis into the batch axis: one sample per (scene, source view).
    src_depth_b = srcs_depth.view(batch, 1, h, w)
    src_cam_b = srcs_cam.view(batch, 2, 4, 4)
    ref_depth_b = ref_depth.unsqueeze(1).repeat(1, v, 1, 1, 1).view(batch, 1, h, w)
    ref_cam_b = ref_cam.unsqueeze(1).repeat(1, v, 1, 1, 1).view(batch, 2, 4, 4)

    grid = get_pixel_grids(h, w).unsqueeze(0)  # 1hw31
    pts_src_cam = idx_img2cam(grid, src_depth_b, src_cam_b)  # Nhw41
    pts_world = idx_cam2world(pts_src_cam, src_cam_b)  # Nhw41
    pts_ref_cam = idx_world2cam(pts_world, ref_cam_b)  # Nhw41
    pix_ref = idx_cam2img(pts_ref_cam, ref_cam_b)  # Nhw31

    xy_in_ref = pix_ref[..., :2, 0]  # Nhw2
    z_in_ref = pts_ref_cam[..., 2:3, 0]  # Nhw1
    src_xyd = torch.cat([xy_in_ref, z_in_ref], dim=-1).permute(0, 3, 1, 2)  # N3hw

    sampled, in_range_b = project_img(src_xyd, ref_depth_b, src_cam_b, ref_cam_b)  # N3hw, N1hw
    return sampled.view(n, v, 3, h, w), in_range_b.view(n, v, 1, h, w)
def get_reproj(ref_depth, srcs_depth, ref_cam, srcs_cam):  # n1hw, nv1hw -> nv3hw, nv1hw
    """Round-trip reprojection of source depths into the reference view, with validity.

    Steps:
    1. Each source pixel is lifted to 3-D with its depth and projected into
       the reference camera.
    2. The (x, y, z-in-reference-camera) triplet and a source-validity flag
       (source depth > 1e-9) are packed as four channels.
    3. All four channels are sampled back at every reference pixel by
       ``project_img`` using ``ref_depth``.
    4. The sampled validity channel is multiplied into ``in_range``.

    NOTE(review): ``project_img`` samples bilinearly. The validity channel can
    therefore be fractional next to invalid source pixels, which makes
    ``in_range`` non-binary there. Confirm downstream thresholds expect that.

    Returns:
        reproj_xyd: nv3hw round-trip pixel coordinates and depth.
        in_range: nv1hw in-range mask weighted by sampled source validity.
    """
    n, v, _, h, w = srcs_depth.size()
    batch = n * v
    # Fold the view axis into the batch axis: one sample per (scene, source view).
    src_depth_b = srcs_depth.view(batch, 1, h, w)
    src_valid_b = (src_depth_b > 1e-9).to(src_depth_b.dtype)  # N1hw
    src_cam_b = srcs_cam.view(batch, 2, 4, 4)
    ref_depth_b = ref_depth.unsqueeze(1).repeat(1, v, 1, 1, 1).view(batch, 1, h, w)
    ref_cam_b = ref_cam.unsqueeze(1).repeat(1, v, 1, 1, 1).view(batch, 2, 4, 4)

    grid = get_pixel_grids(h, w).unsqueeze(0)  # 1hw31
    pts_src_cam = idx_img2cam(grid, src_depth_b, src_cam_b)  # Nhw41
    pts_world = idx_cam2world(pts_src_cam, src_cam_b)  # Nhw41
    pts_ref_cam = idx_world2cam(pts_world, ref_cam_b)  # Nhw41
    pix_ref = idx_cam2img(pts_ref_cam, ref_cam_b)  # Nhw31

    channels = [
        pix_ref[..., :2, 0],             # x, y in the reference image
        pts_ref_cam[..., 2:3, 0],        # z in the reference camera
        src_valid_b.permute(0, 2, 3, 1),  # source validity flag
    ]
    src_xydv = torch.cat(channels, dim=-1).permute(0, 3, 1, 2)  # N4hw

    sampled, in_range_b = project_img(src_xydv, ref_depth_b, src_cam_b, ref_cam_b)  # N4hw, N1hw
    reproj_xyd = sampled.view(n, v, 4, h, w)[:, :, :3]
    in_range = (in_range_b * sampled[:, 3:]).view(n, v, 1, h, w)
    return reproj_xyd, in_range
reproj_xyd_g, in_range_g = get_reproj(ref_depth_g, srcs_depth_g, ref_cam_g, srcs_cam_g) vis_masks_g, vis_mask_g = vis_filter(ref_depth_g, reproj_xyd_g, in_range_g, 1, 0.01, args.vthresh) update[id] = { 'mask': vis_mask_g.cpu() } del ref_depth_g, ref_cam_g, srcs_depth_g, srcs_cam_g, reproj_xyd_g, in_range_g, vis_masks_g, vis_mask_g for i, id in enumerate(pair['id_list']): views[id]['mask'] = views[id]['mask'] & update[id]['mask'] views[id]['depth'] *= views[id]['mask'] pcds = {} for i, id in tqdm(enumerate(pair['id_list']), 'back proj', n_views): ref_depth_g, ref_cam_g = views[id]['depth'].cuda(), views[id]['cam'].cuda() idx_img_g = get_pixel_grids(*ref_depth_g.size()[-2:]).unsqueeze(0) idx_cam_g = idx_img2cam(idx_img_g, ref_depth_g, ref_cam_g) points_g = idx_cam2world(idx_cam_g, ref_cam_g)[...,:3,0] # nhw3 cam_center_g = (- ref_cam_g[:,0,:3,:3].transpose(-2,-1) @ ref_cam_g[:,0,:3,3:])[...,0] # n3 dir_vec_g = cam_center_g.reshape(-1,1,1,3) - points_g # nhw3 p_f = points_g.cpu()[ views[id]['mask'].squeeze(1) ] # m3 c_f = views[id]['image'].permute(0,2,3,1)[ views[id]['mask'].squeeze(1) ] / 255 # m3 d_f = dir_vec_g.cpu()[ views[id]['mask'].squeeze(1) ] # m3 pcds[id] = { 'points': p_f, 'colors': c_f, 'dirs': d_f, } del views[id]