# Example #1
# 0
def fast_cat_fms(reference_fm, target_fm, max_disp=192, start_disp=0, dilation=1, disp_sample=None):
    """Build a concatenation-style cost volume from a stereo feature pair.

    The target feature map is warped toward the reference view once per
    disparity hypothesis, then reference and warped-target features are
    stacked along the channel axis.

    Args:
        reference_fm: reference (left) features, [B, C, H, W].
        target_fm: target (right) features, [B, C, H, W].
        max_disp: number of disparities covered when samples are generated here.
        start_disp: first disparity hypothesis.
        dilation: stride between generated disparity hypotheses.
        disp_sample: optional pre-computed disparity samples, [B, D, H, W];
            when given, max_disp/start_disp/dilation are ignored.

    Returns:
        Cost volume of shape [B, 2C, D, H, W].
    """
    device = reference_fm.device
    B, C, H, W = reference_fm.shape

    if disp_sample is not None:
        # Caller supplied the hypotheses directly; read off their count.
        D = disp_sample.shape[1]
    else:
        # Number of hypotheses after dilation (ceil division).
        D = (max_disp + dilation - 1) // dilation
        last_disp = start_disp + max_disp - 1

        # NOTE(review): with dilation > 1, linspace yields a fractional step
        # of (max_disp - 1) / (D - 1) rather than exactly `dilation` — confirm
        # this matches the intended sampling before relying on integer disparities.
        disp_sample = (
            torch.linspace(start_disp, last_disp, D)
            .view(1, D, 1, 1)
            .expand(B, D, H, W)
            .to(device)
            .float()
        )

    # Broadcast both feature maps over the disparity axis: [B, C, D, H, W].
    ref_volume = reference_fm.unsqueeze(2).expand(B, C, D, H, W)
    tgt_volume = target_fm.unsqueeze(2).expand(B, C, D, H, W)

    # Shift target features per disparity hypothesis (zero padding outside).
    tgt_volume = inverse_warp_3d(tgt_volume, -disp_sample, padding_mode='zeros')

    # Zero the reference wherever the warped target is empty so both halves agree.
    ref_volume = ref_volume * (tgt_volume > 0).type_as(ref_volume)

    # Stack along channels: [B, 2C, D, H, W].
    return torch.cat((ref_volume, tgt_volume), dim=1)
def fast_dif_fms(reference_fm, target_fm, max_disp=192, start_disp=0, dilation=1, disp_sample=None,
                 normalize=False, p=1.0,):
    """Build a difference-style cost volume from a stereo feature pair.

    Warps the target features to the reference view per disparity hypothesis
    and subtracts them from the reference features.

    Args:
        reference_fm: reference (left) features, [B, C, H, W].
        target_fm: target (right) features, [B, C, H, W].
        max_disp: number of disparities covered when samples are generated here.
        start_disp: first disparity hypothesis.
        dilation: stride between generated disparity hypotheses.
        disp_sample: optional pre-computed disparity samples, [B, D, H, W];
            when given, max_disp/start_disp/dilation are ignored.
        normalize: if True, collapse the channel axis with a p-norm.
        p: order of the norm used when `normalize` is True.

    Returns:
        [B, C, D, H, W] difference volume, or [B, D, H, W] when normalized.
    """
    device = reference_fm.device
    B, C, H, W = reference_fm.shape

    if disp_sample is not None:
        # Caller supplied the hypotheses directly; read off their count.
        D = disp_sample.shape[1]
    else:
        # Number of hypotheses after dilation (ceil division).
        D = (max_disp + dilation - 1) // dilation
        last_disp = start_disp + max_disp - 1

        # NOTE(review): with dilation > 1, linspace yields a fractional step
        # of (max_disp - 1) / (D - 1) rather than exactly `dilation` — confirm
        # this matches the intended sampling before relying on integer disparities.
        disp_sample = (
            torch.linspace(start_disp, last_disp, D)
            .view(1, D, 1, 1)
            .expand(B, D, H, W)
            .to(device)
            .float()
        )

    # Broadcast both feature maps over the disparity axis: [B, C, D, H, W].
    ref_volume = reference_fm.unsqueeze(2).expand(B, C, D, H, W)
    tgt_volume = target_fm.unsqueeze(2).expand(B, C, D, H, W)

    # Shift target features per disparity hypothesis (zero padding outside).
    tgt_volume = inverse_warp_3d(tgt_volume, -disp_sample, padding_mode='zeros')

    # Zero the reference wherever the warped target is empty so both halves agree.
    ref_volume = ref_volume * (tgt_volume > 0).type_as(ref_volume)

    # Per-channel difference volume: [B, C, D, H, W].
    dif_fm = ref_volume - tgt_volume

    if normalize:
        # Collapse channels with a p-norm -> [B, D, H, W].
        dif_fm = torch.norm(dif_fm, p=p, dim=1, keepdim=False)

    return dif_fm
    def forward(self, left, right, disparity_samples, disparity_sample_noise):
        """Refine disparity samples by soft-picking among propagation neighbours.

        Matching scores between reference features and target features warped
        to each candidate disparity are softmaxed over the propagation-filter
        axis; the samples (and their noise) are then blended by those weights.

        Args:
            left: reference features, [B, C, H, W].
            right: target features, [B, C, H, W].
            disparity_samples: candidates, [B, D, H, W] where
                D = disparity_sample_number * propagation_filter_size.
            disparity_sample_noise: per-candidate noise, same shape.

        Returns:
            Tuple of (disparity_samples, disparity_sample_noise), each
            [B, disparity_sample_number, H, W].
        """
        B, C, H, W = left.shape
        # D = disparity_sample_number * propagation_filter_size
        D = disparity_samples.shape[1]
        pfs = self.propagation_filter_size

        def regroup(volume):
            # [B, D, H, W] -> [B, pfs, D // pfs, H, W]: expose the
            # propagation-filter axis as dim 1 so softmax can run over it.
            return volume.view(B, D // pfs, pfs, H, W).permute([0, 2, 1, 3, 4])

        # Warp target features to every disparity candidate: [B, C, D, H, W].
        left_volume = left.unsqueeze(2).expand(B, C, D, H, W)
        right_volume = right.unsqueeze(2).expand(B, C, D, H, W)
        warped_right = inverse_warp_3d(right_volume, -disparity_samples)

        # Matching score = temperature-scaled inner product over channels.
        raw_cost = torch.mean(left_volume * warped_right, dim=1) * self.temperature

        cost_volume = regroup(raw_cost)
        samples = regroup(disparity_samples)
        noise = regroup(disparity_sample_noise)

        # Soft-select the most plausible neighbour per candidate:
        # weights over the propagation-filter axis,
        # [B, pfs, disparity_sample_number, H, W].
        prob_volume = F.softmax(cost_volume, dim=1)

        # Expected disparity / noise under those weights:
        # [B, disparity_sample_number, H, W].
        disparity_samples = torch.sum(prob_volume * samples, dim=1)
        disparity_sample_noise = torch.sum(prob_volume * noise, dim=1)

        return disparity_samples, disparity_sample_noise