def lr_loss_per_level(self,
                      leftEstDisp,
                      rightEstDisp,
                      leftImage,
                      rightImage,
                      leftMask=None,
                      rightMask=None):
        """Compute the left-right reconstruction loss for one disparity level.

        Both images are resized (area interpolation) to the spatial size of the
        disparity maps. Each view is then reconstructed from the opposite view
        with ``inverse_warp`` and compared with the resized original through
        ``self.rms`` on the masked pixels plus ``self.ssim_weight`` times
        ``SSIM``.

        Args:
            leftEstDisp: estimated left disparity map; must be 4-D because its
                shape is unpacked as (N, C, H, W).
            rightEstDisp: estimated right disparity map, same shape as
                ``leftEstDisp``.
            leftImage: left image, resized to (H, W) before use.
            rightImage: right image, resized to (H, W) before use.
            leftMask: optional boolean mask selecting the left-view pixels that
                contribute to the loss; defaults to all pixels.
                NOTE(review): presumably expected at the resized resolution
                -- confirm against callers.
            rightMask: optional boolean mask for the right view; defaults to
                all pixels.

        Returns:
            The sum of the left-view and right-view loss terms.

        Raises:
            AssertionError: if the two disparity maps differ in shape.
        """
        assert leftEstDisp.shape == rightEstDisp.shape, \
            'The shape of left and right disparity map should be the same!'
        # Only the spatial size is needed; the unpacking also enforces 4-D input.
        _, _, H, W = leftEstDisp.shape
        leftImage = F.interpolate(leftImage, (H, W), mode='area')
        rightImage = F.interpolate(rightImage, (H, W), mode='area')

        # Rebuild each view from the opposite one using its own disparity.
        leftImage_fromWarp = inverse_warp(rightImage, -leftEstDisp)
        rightImage_fromWarp = inverse_warp(leftImage, rightEstDisp)

        if leftMask is None:
            leftMask = torch.ones_like(leftImage > 0)
        if rightMask is None:
            rightMask = torch.ones_like(rightImage > 0)

        # Accumulate out-of-place so the tensor returned by self.rms is never
        # modified after it has been produced.
        loss = self.rms(leftImage[leftMask], leftImage_fromWarp[leftMask])
        loss = loss + self.ssim_weight * SSIM(leftImage, leftImage_fromWarp,
                                              leftMask)

        loss = loss + self.rms(rightImage[rightMask],
                               rightImage_fromWarp[rightMask])
        # The right-view SSIM term must be restricted by the right-view mask.
        loss = loss + self.ssim_weight * SSIM(rightImage, rightImage_fromWarp,
                                              rightMask)

        return loss
Beispiel #2
0
def do_occlusion_evaluation(est_disp, ref_gt_disp, target_gt_disp, lb, ub):
    """
    Evaluate an estimated disparity map on occluded and non-occluded pixels.

    The target ground truth is warped with the negated reference ground truth
    through ``inverse_warp``. A position is flagged when the warped value
    differs from the reference ground truth by more than 1.0 or when its
    magnitude is below 1e-6; the flags are multiplied together along
    dimension 1 to form the occlusion mask. ``calc_error`` is run once on the
    inputs multiplied by that mask and once on the inputs multiplied by its
    complement.

    Args:
        est_disp: estimated disparity map
        ref_gt_disp: reference ground truth disparity map
        target_gt_disp: target ground truth disparity map
        lb: forwarded to ``calc_error`` as ``lb``
        ub: forwarded to ``calc_error`` as ``ub``

    NOTE(review): the reduction along dimension 1 suggests the maps are
    expected in [BatchSize, Channel, Height, Width] layout -- confirm against
    callers.

    Returns:
        dict: entries of the occluded evaluation under keys prefixed with
        'occ/' and entries of the non-occluded evaluation under keys prefixed
        with 'noc/'. An empty dict, after a warning, when any of the three
        maps is None.
    """
    required_maps = (
        (est_disp, 'Estimated disparity map is None, expected given'),
        (ref_gt_disp,
         'Reference ground truth disparity map is None, expected given'),
        (target_gt_disp,
         'Target ground truth disparity map is None, expected given'),
    )
    for candidate, complaint in required_maps:
        if candidate is None:
            warnings.warn(complaint)
            return {}

    # Work on detached CPU copies so the caller's tensors stay untouched.
    est_disp, ref_gt_disp, target_gt_disp = (
        item.clone().cpu() if torch.is_tensor(item) else item
        for item in (est_disp, ref_gt_disp, target_gt_disp))

    warped_gt = inverse_warp(target_gt_disp.clone(), -ref_gt_disp.clone())

    theta = 1.0
    eps = 1e-6
    inconsistent = torch.abs(warped_gt.clone() - ref_gt_disp.clone()) > theta
    vanished = torch.abs(warped_gt.clone()) < eps
    flagged = inconsistent | vanished
    occlusion = flagged.prod(dim=1, keepdim=True).type_as(ref_gt_disp)
    occlusion = occlusion.clamp(0, 1)

    def prefixed_errors(prefix, weight):
        # Zero out everything outside `weight` before measuring the error.
        measured = calc_error(est_disp.clone() * weight.clone(),
                              ref_gt_disp.clone() * weight.clone(),
                              lb=lb,
                              ub=ub)
        return {prefix + name: measured[name] for name in measured.keys()}

    error_dict = prefixed_errors('occ/', occlusion)

    not_occlusion = 1.0 - occlusion
    error_dict.update(prefixed_errors('noc/', not_occlusion))

    return error_dict
    def forward(self, disp, left, right):
        """Refine a disparity map with a residual predicted from image cues.

        The disparity is bilinearly upsampled to the spatial size of ``left``
        and multiplied by the width ratio between image and disparity. The
        right image is warped with the negated upsampled disparity through
        ``inverse_warp`` and the absolute difference to the left image is used
        as a warp error. Images, warped image, error and upsampled disparity
        are concatenated along dimension 1 and passed through
        ``self.conv_mix``, every block of ``self.residual_dilation_blocks`` and
        ``self.conv_res`` to obtain a residual that is added to the upsampled
        disparity.

        Args:
            disp: disparity map to refine; its last dimension is treated as
                the width. Presumably [BatchSize, 1, h, w] -- confirm against
                callers.
            left: left image, unpacked as [BatchSize, Channel, Height, Width].
            right: right image; presumably the same shape as ``left`` since
                its warped version is subtracted from ``left``.

        Returns:
            The refined disparity map at the spatial size of ``left``, passed
            through ReLU so that no value is negative.
        """
        _, _, H, W = left.shape

        # the scale of downsample
        scale = W / disp.shape[-1]

        # upsample disparity map to image size; the target size is the (H, W)
        # unpacked from the left image above
        up_disp = F.interpolate(disp,
                                size=(H, W),
                                mode='bilinear',
                                align_corners=True)
        # disparity values are measured in pixels, so they grow with the width
        up_disp = up_disp * scale

        # calculate warp error
        warp_right = inverse_warp(right, -up_disp)
        error = torch.abs(left - warp_right)

        # residual refinement
        # mix the info inside the disparity map, left image, right image and
        # warp error. The upsampled disparity is used so that every tensor in
        # the concatenation shares the image's spatial size; when `disp` is
        # already at image size the upsampling is expected to leave it as is.
        mix_feat = self.conv_mix(
            torch.cat((left, right, warp_right, error, up_disp), 1))

        for block in self.residual_dilation_blocks:
            mix_feat = block(mix_feat)

        # get residual disparity map
        res_disp = self.conv_res(mix_feat)

        # refine the upsampled disparity map
        refine_disp = res_disp + up_disp

        # promise all disparity value larger than or equal to 0
        refine_disp = F.relu(refine_disp, inplace=True)

        return refine_disp
Beispiel #4
0
    def get_per_level_not_occlusion(self, estLeftDisp, estRightDisp):
        """Derive non-occlusion masks from a left-right consistency check.

        Each disparity map is reconstructed from the opposite one with
        ``inverse_warp``. A position is marked as occluded when the
        reconstruction differs from the estimate by more than ``self.theta``
        or when the magnitude of the reconstruction is below ``self.eps``.
        The returned masks are the logical complement of those occlusion
        masks.

        Args:
            estLeftDisp: estimated left disparity map.
            estRightDisp: estimated right disparity map, same shape as
                ``estLeftDisp``.

        Returns:
            tuple: ``(leftNotOcclusion, rightNotOcclusion)``, each with the
            same dtype as the comparison result it was derived from.

        Raises:
            AssertionError: if the two disparity maps differ in shape.
        """
        assert estLeftDisp.shape == estRightDisp.shape
        leftDisp_fromWarp = inverse_warp(estRightDisp, -estLeftDisp)
        rightDisp_fromWarp = inverse_warp(estLeftDisp, estRightDisp)

        # left and right consistency check
        leftOcclusion = ((torch.abs(leftDisp_fromWarp - estLeftDisp) > self.theta) |
                         (torch.abs(leftDisp_fromWarp) < self.eps))
        rightOcclusion = ((torch.abs(rightDisp_fromWarp - estRightDisp) > self.theta) |
                          (torch.abs(rightDisp_fromWarp) < self.eps))

        # get not occlusion mask; `1 - mask` is rejected for boolean tensors,
        # so the complement is taken with an explicit logical negation
        leftNotOcclusion = torch.logical_not(leftOcclusion).type_as(leftOcclusion)
        rightNotOcclusion = torch.logical_not(rightOcclusion).type_as(rightOcclusion)

        return leftNotOcclusion, rightNotOcclusion
    def loss_per_level(self, estDisp, leftImage, rightImage, mask=None):
        """Reconstruction loss of the left view for one disparity level.

        Both images are resized (area interpolation) to the spatial size of
        ``estDisp``. The left view is rebuilt from the resized right image
        with ``inverse_warp`` and compared with the resized left image through
        ``self.rms`` on the masked pixels plus ``self.ssim_weight`` times
        ``SSIM``.

        Args:
            estDisp: estimated disparity map; must be 4-D because its shape
                is unpacked into four values.
            leftImage: left image, resized before use.
            rightImage: right image, resized before use.
            mask: optional mask selecting the pixels that contribute to the
                loss; defaults to all pixels.

        Returns:
            The combined loss value.
        """
        from dmb.modeling.stereo.losses.utils import SSIM

        # Only the spatial extent matters; unpacking still enforces 4-D input.
        _, _, height, width = estDisp.shape
        target_size = (height, width)
        left_resized = F.interpolate(leftImage, target_size, mode='area')
        right_resized = F.interpolate(rightImage, target_size, mode='area')

        left_rebuilt = inverse_warp(right_resized, -estDisp)

        # Without a user mask every pixel takes part in the loss.
        valid = torch.ones_like(left_resized > 0) if mask is None else mask

        loss = self.rms(left_resized[valid], left_rebuilt[valid])
        loss += self.ssim_weight * SSIM(left_resized, left_rebuilt, valid)
        return loss
def do_occlusion_evaluation(est_disp, ref_gt_disp, target_gt_disp, lb, ub):
    """
    Do occlusion evaluation.

    The target ground truth is warped with the negated reference ground truth
    through ``inverse_warp``. A position is flagged when the warped value
    differs from the reference ground truth by more than 1.0 or when its
    magnitude is below 1e-6; the flags are multiplied together along
    dimension 1 to form the occlusion mask. ``calc_error`` is run once on the
    inputs multiplied by that mask and once on the inputs multiplied by its
    complement.

    Args:
        est_disp: estimated disparity map, in [BatchSize, Channel, Height, Width] or
            [BatchSize, Height, Width] or [Height, Width] layout
        ref_gt_disp: reference(left) ground truth disparity map, in [BatchSize, Channel, Height, Width] or
            [BatchSize, Height, Width] or [Height, Width] layout
        target_gt_disp: target(right) ground truth disparity map, in [BatchSize, Channel, Height, Width] or
            [BatchSize, Height, Width] or [Height, Width] layout
        lb, (scalar): the lower bound of disparity you want to mask out
        ub, (scalar): the upper bound of disparity you want to mask out

    NOTE(review): the occlusion flags are reduced along dimension 1, so the
    3-D and 2-D layouts listed above rely on how ``inverse_warp`` reshapes its
    inputs -- confirm before passing anything but 4-D maps.

    Returns:
        dict: entries of the occluded evaluation under keys prefixed with
        'occ_' and entries of the non-occluded evaluation under keys prefixed
        with 'noc_'. An empty dict, after a warning, when any of the three
        maps is None.
    """
    checks = (
        (est_disp, 'Estimated disparity map is None, expected given'),
        (ref_gt_disp,
         'Reference ground truth disparity map is None, expected given'),
        (target_gt_disp,
         'Target ground truth disparity map is None, expected given'),
    )
    for supplied, notice in checks:
        if supplied is None:
            warnings.warn(notice)
            return {}

    # Work on detached CPU copies so the caller's tensors stay untouched.
    est_disp, ref_gt_disp, target_gt_disp = (
        entry.clone().cpu() if torch.is_tensor(entry) else entry
        for entry in (est_disp, ref_gt_disp, target_gt_disp))

    warp_result = inverse_warp(target_gt_disp.clone(), -ref_gt_disp.clone())

    theta = 1.0
    eps = 1e-6
    disagrees = torch.abs(warp_result.clone() - ref_gt_disp.clone()) > theta
    is_empty = torch.abs(warp_result.clone()) < eps
    occlusion = (disagrees | is_empty).prod(dim=1, keepdim=True)
    occlusion = occlusion.type_as(ref_gt_disp).clamp(0, 1)

    error_dict = {}

    occluded_scores = calc_error(est_disp.clone() * occlusion.clone(),
                                 ref_gt_disp.clone() * occlusion.clone(),
                                 lb=lb,
                                 ub=ub)
    for metric in occluded_scores.keys():
        error_dict['occ_' + metric] = occluded_scores[metric]

    not_occlusion = 1.0 - occlusion
    visible_scores = calc_error(est_disp.clone() * not_occlusion.clone(),
                                ref_gt_disp.clone() * not_occlusion.clone(),
                                lb=lb,
                                ub=ub)
    for metric in visible_scores.keys():
        error_dict['noc_' + metric] = visible_scores[metric]

    return error_dict