Example #1
def vis_filter(ref_depth, reproj_xyd, in_range, img_dist_thresh, depth_thresh, vthresh):
    n, v, _, h, w = reproj_xyd.size()
    xy = get_pixel_grids(h, w).permute(3,2,0,1).unsqueeze(1)[:,:,:2]  # 112hw
    dist_masks = (reproj_xyd[:,:,:2,:,:] - xy).norm(dim=2, keepdim=True) < img_dist_thresh  # nv1hw
    depth_masks = (ref_depth.unsqueeze(1) - reproj_xyd[:,:,2:,:,:]).abs() < (torch.max(ref_depth.unsqueeze(1), reproj_xyd[:,:,2:,:,:])*depth_thresh)  # nv1hw
    masks = bin_op_reduce([in_range, dist_masks.to(ref_depth.dtype), depth_masks.to(ref_depth.dtype)], torch.min)  # nv1hw
    mask = masks.sum(dim=1) >= (vthresh-1.1)  # n1hw
    return masks, mask
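A minimal usage sketch for vis_filter (not from the source repository): the tensor shapes below are assumptions, and in practice reproj_xyd and in_range come from get_reproj as in Example #4; the 1 px / 1% thresholds mirror that call.

import torch

n, v, h, w = 1, 4, 128, 160                   # batch, source views, height, width (assumed)
ref_depth  = torch.rand(n, 1, h, w) * 100     # n1hw reference depth map
reproj_xyd = torch.rand(n, v, 3, h, w) * 100  # nv3hw reprojected (x, y, depth) per source view
in_range   = torch.ones(n, v, 1, h, w)        # nv1hw reprojection-validity mask

# 1 px reprojection threshold, 1% relative depth threshold; vthresh sets how many
# views must agree (the -1.1 slack above makes it effectively vthresh - 1)
masks, final_mask = vis_filter(ref_depth, reproj_xyd, in_range,
                               img_dist_thresh=1, depth_thresh=0.01, vthresh=3)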
Example #2
def project_img(src_img, dst_depth, src_cam, dst_cam, height=None, width=None):  # nchw, n1hw -> nchw, n1hw
    if height is None: height = src_img.size()[-2]
    if width is None: width = src_img.size()[-1]
    dst_idx_img_homo = get_pixel_grids(height, width).unsqueeze(0)  # nhw31
    dst_idx_cam_homo = idx_img2cam(dst_idx_img_homo, dst_depth, dst_cam)  # nhw41
    dst_idx_world_homo = idx_cam2world(dst_idx_cam_homo, dst_cam)  # nhw41
    dst2src_idx_cam_homo = idx_world2cam(dst_idx_world_homo, src_cam)  # nhw41
    dst2src_idx_img_homo = idx_cam2img(dst2src_idx_cam_homo, src_cam)  # nhw31
    warp_coord = dst2src_idx_img_homo[...,:2,0]  # nhw2
    warp_coord[..., 0] /= width
    warp_coord[..., 1] /= height
    warp_coord = (warp_coord*2-1).clamp(-1.1, 1.1)  # nhw2
    in_range = bin_op_reduce([-1<=warp_coord[...,0], warp_coord[...,0]<=1, -1<=warp_coord[...,1], warp_coord[...,1]<=1], torch.min).to(src_img.dtype).unsqueeze(1)  # n1hw
    warped_img = F.grid_sample(src_img, warp_coord, mode='bilinear', padding_mode='zeros', align_corners=False)
    return warped_img, in_range
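A minimal sketch of calling project_img on synthetic data; it assumes the n244 camera layout implied by the ref_cam indexing in Examples #3 and #4 (cam[:, 0] is the 4x4 extrinsic, cam[:, 1, :3, :3] the intrinsic) and relies on the idx_* helpers from the same module.

import torch

cam = torch.zeros(1, 2, 4, 4)                     # n244 camera (layout assumed, see above)
cam[0, 0] = torch.eye(4)                          # extrinsic: identity pose
cam[0, 1, :3, :3] = torch.tensor([[100.,   0., 80.],
                                  [  0., 100., 64.],
                                  [  0.,   0.,  1.]])  # simple pinhole intrinsic

src_img   = torch.rand(1, 3, 128, 160)            # nchw source image
dst_depth = torch.full((1, 1, 128, 160), 50.0)    # n1hw constant-depth plane in the destination view
warped_img, in_range = project_img(src_img, dst_depth, cam, cam)  # src and dst share a camera here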
Example #3
    def forward(self,
                outputs,
                gt,
                masks,
                ref_cam,
                max_d,
                occ_guide=False,
                mode='soft'):  #MVS
        outputs, refined_depth = outputs

        depth_start = ref_cam[:, 1:2, 3:4, 0:1]  # n111
        depth_interval = ref_cam[:, 1:2, 3:4, 1:2]  # n111
        depth_end = depth_start + (max_d - 2) * depth_interval  # strict range
        masks = [masks[:, i, ...] for i in range(masks.size()[1])]

        stage_losses = []
        stats = []
        for est_depth, pair_results in outputs:
            gt_downsized = F.interpolate(gt,
                                         size=(est_depth.size()[2],
                                               est_depth.size()[3]),
                                         mode='bilinear',
                                         align_corners=False)
            masks_downsized = [
                F.interpolate(mask,
                              size=(est_depth.size()[2], est_depth.size()[3]),
                              mode='nearest') for mask in masks
            ]
            in_range = torch.min((gt_downsized >= depth_start),
                                 (gt_downsized <= depth_end))
            masks_valid = [
                torch.min((mask > 50), in_range) for mask in masks_downsized
            ]  # mask and in_range
            masks_overlap = [
                torch.min((mask > 200), in_range) for mask in masks_downsized
            ]
            union_overlap = bin_op_reduce(masks_overlap,
                                          torch.max)  # A(B+C)=AB+AC
            valid = union_overlap if occ_guide else in_range

            same_size = (est_depth.size()[2] == pair_results[0][0].size()[2]
                         and est_depth.size()[3] == pair_results[0][0].size()[3])
            gt_interm = F.interpolate(
                gt,
                size=(pair_results[0][0].size()[2],
                      pair_results[0][0].size()[3]),
                mode='bilinear',
                align_corners=False) if not same_size else gt_downsized
            masks_interm = [
                F.interpolate(mask,
                              size=(pair_results[0][0].size()[2],
                                    pair_results[0][0].size()[3]),
                              mode='nearest') for mask in masks
            ] if not same_size else masks_downsized
            in_range_interm = torch.min(
                (gt_interm >= depth_start),
                (gt_interm <= depth_end)) if not same_size else in_range
            masks_valid_interm = [
                torch.min((mask > 50), in_range_interm)
                for mask in masks_interm
            ] if not same_size else masks_valid  # mask and in_range
            masks_overlap_interm = [
                torch.min((mask > 200), in_range_interm)
                for mask in masks_interm
            ] if not same_size else masks_overlap
            union_overlap_interm = bin_op_reduce(
                masks_overlap_interm,
                torch.max) if not same_size else union_overlap  # A(B+C)=AB+AC
            valid_interm = (union_overlap_interm if occ_guide else
                            in_range_interm) if not same_size else valid

            abs_err = (est_depth - gt_downsized).abs()
            abs_err_scaled = abs_err / depth_interval
            pair_abs_err = [
                (est - gt_interm).abs() for est, (uncert, occ) in pair_results
            ]
            pair_abs_err_scaled = [
                err / depth_interval for err in pair_abs_err
            ]

            l1 = abs_err_scaled[valid].mean()

            # ===== pair l1 =====
            if occ_guide:
                pair_l1_losses = [
                    err[mask_overlap].mean() for err, mask_overlap in zip(
                        pair_abs_err_scaled, masks_overlap_interm)
                ]
            else:
                pair_l1_losses = [
                    err[in_range_interm].mean() for err in pair_abs_err_scaled
                ]
            pair_l1_loss = sum(pair_l1_losses) / len(pair_l1_losses)

            # ===== uncert =====
            if mode in ['soft', 'hard', 'uwta']:
                if occ_guide:
                    uncert_losses = [
                        (err[mask_valid] * (-uncert[mask_valid]).exp() +
                         uncert[mask_valid]).mean()
                        for err, (est, (uncert, occ)), mask_valid, mask_overlap
                        in zip(pair_abs_err_scaled, pair_results,
                               masks_valid_interm, masks_overlap_interm)
                    ]
                else:
                    uncert_losses = [
                        (err[in_range_interm] * (-uncert[in_range_interm]).exp()
                         + uncert[in_range_interm]).mean()
                        for err, (est, (uncert, occ))
                        in zip(pair_abs_err_scaled, pair_results)
                    ]
                uncert_loss = sum(uncert_losses) / len(uncert_losses)

            # ===== logistic =====
            if occ_guide and mode in ['soft', 'hard', 'uwta']:
                logistic_losses = [
                    nn.SoftMarginLoss()(
                        occ[mask_valid],
                        -mask_overlap[mask_valid].to(gt.dtype) * 2 + 1)
                    for (est, (uncert, occ)), mask_valid, mask_overlap in zip(
                        pair_results, masks_valid_interm, masks_overlap_interm)
                ]
                logistic_loss = sum(logistic_losses) / len(logistic_losses)

            less1 = (abs_err_scaled[valid] < 1.).to(gt.dtype).mean()
            less3 = (abs_err_scaled[valid] < 3.).to(gt.dtype).mean()

            pair_loss = pair_l1_loss
            if mode in ['soft', 'hard', 'uwta']:
                pair_loss = pair_loss + uncert_loss
                if occ_guide:
                    pair_loss = pair_loss + logistic_loss
            loss = l1 + pair_loss
            stage_losses.append(loss)
            stats.append((l1, less1, less3))

        abs_err = (refined_depth - gt_downsized).abs()
        abs_err_scaled = abs_err / depth_interval
        l1 = abs_err_scaled[valid].mean()
        less1 = (abs_err_scaled[valid] < 1.).to(gt.dtype).mean()
        less3 = (abs_err_scaled[valid] < 3.).to(gt.dtype).mean()

        loss = stage_losses[0] * 0.5 + stage_losses[1] * 1.0 + stage_losses[2] * 2.0  # + l1*2.0

        return loss, pair_loss, less1, less3, l1, stats, abs_err_scaled, valid
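For reference, the nesting that forward unpacks can be summarized as below; this is inferred from the code above, and the shape notes are assumptions in the same spirit as the other shape comments.

# outputs = (
#     [  # one entry per stage; the three stages are weighted 0.5 / 1.0 / 2.0 in `loss`
#         (est_depth,                     # depth estimate of this stage (n1hw)
#          [(est, (uncert, occ)), ...]),  # pair_results: per-source-view depth, uncertainty, occlusion logit
#         ...
#     ],
#     refined_depth,                      # final refined depth, scored with the last stage's `valid` mask
# )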
Example #4
    views = {}

    pbar = tqdm.tqdm(loader, dynamic_ncols=True)
    for sample_np in pbar:
        if sample_np.get('skip') is not None and np.any(sample_np['skip']): continue
        sample = {attr: torch.from_numpy(sample_np[attr]).float().cuda() for attr in sample_np if attr not in ['skip', 'id']}

        prob_mask = prob_filter(sample['ref_probs'], pthresh)

        reproj_xyd, in_range = get_reproj(*[sample[attr] for attr in ['ref_depth', 'srcs_depth', 'ref_cam', 'srcs_cam']])
        vis_masks, vis_mask = vis_filter(sample['ref_depth'], reproj_xyd, in_range, 1, 0.01, args.vthresh)

        ref_depth_ave = ave_fusion(sample['ref_depth'], reproj_xyd, vis_masks)

        mask = bin_op_reduce([prob_mask, vis_mask], torch.min)

        if args.show_result:
            subplot_map([
                [sample['ref_depth'][0,0].cpu().data.numpy(), ref_depth_ave[0,0].cpu().data.numpy(), (ref_depth_ave*mask)[0,0].cpu().data.numpy()],
                [prob_mask[0,0].cpu().data.numpy(), vis_mask[0,0].cpu().data.numpy(), mask[0,0].cpu().data.numpy()]
            ])
            plt.show()
        
        idx_img = get_pixel_grids(*ref_depth_ave.size()[-2:]).unsqueeze(0)
        idx_cam = idx_img2cam(idx_img, ref_depth_ave, sample['ref_cam'])
        points = idx_cam2world(idx_cam, sample['ref_cam'])[...,:3,0].permute(0,3,1,2)
        cam_center_np = (- sample['ref_cam'][:,0,:3,:3].transpose(-2,-1) @ sample['ref_cam'][:,0,:3,3:])[...,0].cpu().numpy()  # n3
        points_np = points.cpu().data.numpy()
        mask_np = mask.cpu().data.numpy()
        for i in range(points_np.shape[0]):