import torch
import torch.nn as nn
import torch.nn.functional as F


def vis_filter(ref_depth, reproj_xyd, in_range, img_dist_thresh, depth_thresh, vthresh):
    n, v, _, h, w = reproj_xyd.size()
    xy = get_pixel_grids(h, w).permute(3, 2, 0, 1).unsqueeze(1)[:, :, :2]  # 112hw
    # A source view "sees" a pixel if its reprojection lands close enough in image space...
    dist_masks = (reproj_xyd[:, :, :2, :, :] - xy).norm(dim=2, keepdim=True) < img_dist_thresh  # nv1hw
    # ...and its reprojected depth agrees with the reference depth up to a relative threshold.
    depth_masks = (ref_depth.unsqueeze(1) - reproj_xyd[:, :, 2:, :, :]).abs() \
        < (torch.max(ref_depth.unsqueeze(1), reproj_xyd[:, :, 2:, :, :]) * depth_thresh)  # nv1hw
    masks = bin_op_reduce([in_range, dist_masks.to(ref_depth.dtype), depth_masks.to(ref_depth.dtype)], torch.min)  # nv1hw
    # Keep pixels seen by at least vthresh source views (-1.1 guards against float round-off).
    mask = masks.sum(dim=1) >= (vthresh - 1.1)  # n1hw
    return masks, mask
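A minimal shape-level smoke test for `vis_filter`, with hypothetical stand-ins for the repo's `bin_op_reduce` and `get_pixel_grids` helpers (the real helpers may differ, e.g. in device placement or half-pixel offset); the input values here are synthetic:

```python
import torch
from functools import reduce

def bin_op_reduce(lst, func):          # assumed helper: fold a list with a binary op
    return reduce(func, lst)

def get_pixel_grids(height, width):    # assumed helper: homogeneous pixel grid, hw31
    y, x = torch.meshgrid(torch.arange(height, dtype=torch.float32),
                          torch.arange(width, dtype=torch.float32), indexing='ij')
    return torch.stack([x + 0.5, y + 0.5, torch.ones_like(x)], dim=-1).unsqueeze(-1)

n, v, h, w = 1, 3, 8, 10
ref_depth = torch.rand(n, 1, h, w) * 10 + 1    # n1hw reference depths
reproj_xyd = torch.rand(n, v, 3, h, w) * 10    # nv3hw reprojected (x, y, depth)
in_range = torch.ones(n, v, 1, h, w)           # nv1hw reprojection-in-bounds flags

masks, mask = vis_filter(ref_depth, reproj_xyd, in_range,
                         img_dist_thresh=1, depth_thresh=0.01, vthresh=2)
print(masks.shape, mask.shape)  # (1, 3, 1, 8, 10), (1, 1, 8, 10)
```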
def project_img(src_img, dst_depth, src_cam, dst_cam, height=None, width=None):  # nchw, n1hw -> nchw, n1hw
    if height is None:
        height = src_img.size()[-2]
    if width is None:
        width = src_img.size()[-1]
    dst_idx_img_homo = get_pixel_grids(height, width).unsqueeze(0)  # nhw31
    dst_idx_cam_homo = idx_img2cam(dst_idx_img_homo, dst_depth, dst_cam)  # nhw41
    dst_idx_world_homo = idx_cam2world(dst_idx_cam_homo, dst_cam)  # nhw41
    dst2src_idx_cam_homo = idx_world2cam(dst_idx_world_homo, src_cam)  # nhw41
    dst2src_idx_img_homo = idx_cam2img(dst2src_idx_cam_homo, src_cam)  # nhw31
    warp_coord = dst2src_idx_img_homo[..., :2, 0]  # nhw2
    warp_coord[..., 0] /= width
    warp_coord[..., 1] /= height
    warp_coord = (warp_coord * 2 - 1).clamp(-1.1, 1.1)  # nhw2
    in_range = bin_op_reduce([-1 <= warp_coord[..., 0], warp_coord[..., 0] <= 1,
                              -1 <= warp_coord[..., 1], warp_coord[..., 1] <= 1],
                             torch.min).to(src_img.dtype).unsqueeze(1)  # n1hw
    warped_img = F.grid_sample(src_img, warp_coord, mode='bilinear', padding_mode='zeros', align_corners=False)
    return warped_img, in_range
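A hypothetical usage sketch for `project_img`. It assumes the repo's `idx_*` helpers are in scope and the camera layout used elsewhere in this file: `cam` is n244, with `cam[:, 0]` the 4x4 extrinsic [R|t] and `cam[:, 1, :3, :3]` the intrinsic K (the focal length and principal point below are made up):

```python
import torch

n, h, w = 1, 16, 20
src_img = torch.rand(n, 3, h, w)           # source view image, nchw
dst_depth = torch.full((n, 1, h, w), 5.0)  # destination view depth, n1hw

cam = torch.zeros(n, 2, 4, 4)
cam[:, 0] = torch.eye(4)                   # extrinsic: camera at world origin
cam[:, 1, :3, :3] = torch.tensor([[20.0, 0.0, w / 2],
                                  [0.0, 20.0, h / 2],
                                  [0.0, 0.0, 1.0]])

# With src_cam == dst_cam the warp should be (near-)identity:
warped, in_range = project_img(src_img, dst_depth, cam, cam)
assert warped.shape == src_img.shape and in_range.shape == (n, 1, h, w)
```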
def forward(self, outputs, gt, masks, ref_cam, max_d, occ_guide=False, mode='soft'):  # MVS
    outputs, refined_depth = outputs
    depth_start = ref_cam[:, 1:2, 3:4, 0:1]  # n111
    depth_interval = ref_cam[:, 1:2, 3:4, 1:2]  # n111
    depth_end = depth_start + (max_d - 2) * depth_interval  # strict range
    masks = [masks[:, i, ...] for i in range(masks.size()[1])]

    stage_losses = []
    stats = []
    for est_depth, pair_results in outputs:
        gt_downsized = F.interpolate(gt, size=(est_depth.size()[2], est_depth.size()[3]),
                                     mode='bilinear', align_corners=False)
        masks_downsized = [
            F.interpolate(mask, size=(est_depth.size()[2], est_depth.size()[3]), mode='nearest')
            for mask in masks
        ]
        in_range = torch.min((gt_downsized >= depth_start), (gt_downsized <= depth_end))
        masks_valid = [torch.min((mask > 50), in_range) for mask in masks_downsized]  # mask and in_range
        masks_overlap = [torch.min((mask > 200), in_range) for mask in masks_downsized]
        union_overlap = bin_op_reduce(masks_overlap, torch.max)  # A(B+C)=AB+AC
        valid = union_overlap if occ_guide else in_range

        # The per-pair estimates may live at a different resolution than est_depth;
        # only recompute the ground truth and masks at that resolution when needed.
        same_size = (est_depth.size()[2] == pair_results[0][0].size()[2]
                     and est_depth.size()[3] == pair_results[0][0].size()[3])
        gt_interm = F.interpolate(
            gt, size=(pair_results[0][0].size()[2], pair_results[0][0].size()[3]),
            mode='bilinear', align_corners=False) if not same_size else gt_downsized
        masks_interm = [
            F.interpolate(mask, size=(pair_results[0][0].size()[2], pair_results[0][0].size()[3]),
                          mode='nearest')
            for mask in masks
        ] if not same_size else masks_downsized
        in_range_interm = torch.min(
            (gt_interm >= depth_start), (gt_interm <= depth_end)) if not same_size else in_range
        masks_valid_interm = [
            torch.min((mask > 50), in_range_interm) for mask in masks_interm
        ] if not same_size else masks_valid  # mask and in_range
        masks_overlap_interm = [
            torch.min((mask > 200), in_range_interm) for mask in masks_interm
        ] if not same_size else masks_overlap
        union_overlap_interm = bin_op_reduce(
            masks_overlap_interm, torch.max) if not same_size else union_overlap  # A(B+C)=AB+AC
        valid_interm = (union_overlap_interm if occ_guide else in_range_interm) if not same_size else valid

        abs_err = (est_depth - gt_downsized).abs()
        abs_err_scaled = abs_err / depth_interval
        pair_abs_err = [(est - gt_interm).abs()
                        for est in [est for est, (uncert, occ) in pair_results]]
        pair_abs_err_scaled = [err / depth_interval for err in pair_abs_err]

        l1 = abs_err_scaled[valid].mean()

        # ===== pair l1 =====
        if occ_guide:
            pair_l1_losses = [err[mask_overlap].mean()
                              for err, mask_overlap in zip(pair_abs_err_scaled, masks_overlap_interm)]
        else:
            pair_l1_losses = [err[in_range_interm].mean() for err in pair_abs_err_scaled]
        pair_l1_loss = sum(pair_l1_losses) / len(pair_l1_losses)

        # ===== uncert =====
        if mode in ['soft', 'hard', 'uwta']:
            if occ_guide:
                uncert_losses = [
                    (err[mask_valid] * (-uncert[mask_valid]).exp() + uncert[mask_valid]).mean()
                    for err, (est, (uncert, occ)), mask_valid, mask_overlap
                    in zip(pair_abs_err_scaled, pair_results, masks_valid_interm, masks_overlap_interm)
                ]
            else:
                uncert_losses = [
                    (err[in_range_interm] * (-uncert[in_range_interm]).exp() + uncert[in_range_interm]).mean()
                    for err, (est, (uncert, occ)) in zip(pair_abs_err_scaled, pair_results)
                ]
            uncert_loss = sum(uncert_losses) / len(uncert_losses)

        # ===== logistic =====
        if occ_guide and mode in ['soft', 'hard', 'uwta']:
            logistic_losses = [
                nn.SoftMarginLoss()(occ[mask_valid], -mask_overlap[mask_valid].to(gt.dtype) * 2 + 1)
                for (est, (uncert, occ)), mask_valid, mask_overlap
                in zip(pair_results, masks_valid_interm, masks_overlap_interm)
            ]
            logistic_loss = sum(logistic_losses) / len(logistic_losses)

        less1 = (abs_err_scaled[valid] < 1.).to(gt.dtype).mean()
        less3 = (abs_err_scaled[valid] < 3.).to(gt.dtype).mean()

        pair_loss = pair_l1_loss
        if mode in ['soft', 'hard', 'uwta']:
            pair_loss = pair_loss + uncert_loss
            if occ_guide:
                pair_loss = pair_loss + logistic_loss
        loss = l1 + pair_loss
        stage_losses.append(loss)
        stats.append((l1, less1, less3))

    # Refinement metrics reuse gt_downsized and valid from the last stage iteration.
    abs_err = (refined_depth - gt_downsized).abs()
    abs_err_scaled = abs_err / depth_interval
    l1 = abs_err_scaled[valid].mean()
    less1 = (abs_err_scaled[valid] < 1.).to(gt.dtype).mean()
    less3 = (abs_err_scaled[valid] < 3.).to(gt.dtype).mean()

    loss = stage_losses[0] * 0.5 + stage_losses[1] * 1.0 + stage_losses[2] * 2.0  # + l1*2.0
    return loss, pair_loss, less1, less3, l1, stats, abs_err_scaled, valid
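For reference, the per-view uncertainty term in `uncert_losses` above is the usual learned-aleatoric form |e| * exp(-s) + s, where s is the predicted log-uncertainty. A minimal sketch isolating that term (`uncert_term` is a hypothetical helper, not part of the repo):

```python
import torch

def uncert_term(abs_err_scaled, log_uncert, valid_mask):
    """|e| * exp(-s) + s over valid pixels: errors are down-weighted where the
    network predicts high uncertainty, while the +s term keeps s from growing
    without bound."""
    e = abs_err_scaled[valid_mask]
    s = log_uncert[valid_mask]
    return (e * (-s).exp() + s).mean()

# e.g. uncert_term(err, uncert, mask_valid) reproduces one element of uncert_losses.
```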
import numpy as np
import tqdm
from matplotlib import pyplot as plt

views = {}
pbar = tqdm.tqdm(loader, dynamic_ncols=True)
for sample_np in pbar:
    if sample_np.get('skip') is not None and np.any(sample_np['skip']):
        continue
    sample = {attr: torch.from_numpy(sample_np[attr]).float().cuda()
              for attr in sample_np if attr not in ['skip', 'id']}

    prob_mask = prob_filter(sample['ref_probs'], pthresh)
    reproj_xyd, in_range = get_reproj(
        *[sample[attr] for attr in ['ref_depth', 'srcs_depth', 'ref_cam', 'srcs_cam']])
    vis_masks, vis_mask = vis_filter(sample['ref_depth'], reproj_xyd, in_range, 1, 0.01, args.vthresh)
    ref_depth_ave = ave_fusion(sample['ref_depth'], reproj_xyd, vis_masks)
    mask = bin_op_reduce([prob_mask, vis_mask], torch.min)

    if args.show_result:
        subplot_map([
            [sample['ref_depth'][0, 0].cpu().data.numpy(),
             ref_depth_ave[0, 0].cpu().data.numpy(),
             (ref_depth_ave * mask)[0, 0].cpu().data.numpy()],
            [prob_mask[0, 0].cpu().data.numpy(),
             vis_mask[0, 0].cpu().data.numpy(),
             mask[0, 0].cpu().data.numpy()],
        ])
        plt.show()

    # Back-project the fused depth map to world-space points.
    idx_img = get_pixel_grids(*ref_depth_ave.size()[-2:]).unsqueeze(0)
    idx_cam = idx_img2cam(idx_img, ref_depth_ave, sample['ref_cam'])
    points = idx_cam2world(idx_cam, sample['ref_cam'])[..., :3, 0].permute(0, 3, 1, 2)
    # Camera center in world coordinates: C = -R^T t from the extrinsic [R|t].
    cam_center_np = (-sample['ref_cam'][:, 0, :3, :3].transpose(-2, -1)
                     @ sample['ref_cam'][:, 0, :3, 3:])[..., 0].cpu().numpy()  # n3
    points_np = points.cpu().data.numpy()
    mask_np = mask.cpu().data.numpy()
    for i in range(points_np.shape[0]):