def fast_cat_fms(reference_fm, target_fm, max_disp=192, start_disp=0, dilation=1, disp_sample=None):
    """Build a concatenation-based cost volume from two feature maps.

    The target features are warped towards the reference view by each
    disparity sample, reference features are masked wherever the warp
    sampled zero padding, and both volumes are concatenated channel-wise.

    Args:
        reference_fm: reference features, [B, C, H, W].
        target_fm: target features, [B, C, H, W].
        max_disp: size of the disparity search range.
        start_disp: first disparity of the search range.
        dilation: spacing between generated disparity samples.
        disp_sample: optional pre-computed disparity samples, [B, D, H, W];
            when provided, max_disp/start_disp/dilation are ignored.

    Returns:
        Cost volume of shape [B, 2C, D, H, W].
    """
    device = reference_fm.device
    B, C, H, W = reference_fm.shape

    if disp_sample is None:
        # Generate D disparity samples covering [start_disp, end_disp],
        # laid out as a [B, D, H, W] tensor on the same device.
        end_disp = start_disp + max_disp - 1
        D = (max_disp + dilation - 1) // dilation
        disp_sample = torch.linspace(start_disp, end_disp, D)
        disp_sample = disp_sample.view(1, D, 1, 1).expand(B, D, H, W).to(device).float()
    else:
        # Disparity samples were supplied by the caller; read D from them.
        D = disp_sample.shape[1]

    # Broadcast both feature maps over a new disparity dimension.
    cat_reference_fm = reference_fm.unsqueeze(2).expand(B, C, D, H, W)
    cat_target_fm = target_fm.unsqueeze(2).expand(B, C, D, H, W)

    # Shift the target features by the (negated) disparity samples.
    cat_target_fm = inverse_warp_3d(cat_target_fm, -disp_sample, padding_mode='zeros')

    # Zero out reference features where the warp produced zero padding,
    # so invalid positions contribute nothing to matching.
    cat_reference_fm = cat_reference_fm * (cat_target_fm > 0).type_as(cat_reference_fm)

    # [B, 2C, D, H, W]
    return torch.cat((cat_reference_fm, cat_target_fm), dim=1)
def fast_dif_fms(reference_fm, target_fm, max_disp=192, start_disp=0, dilation=1, disp_sample=None, normalize=False, p=1.0):
    """Build a difference-based cost volume from two feature maps.

    The target features are warped towards the reference view by each
    disparity sample, reference features are masked wherever the warp
    sampled zero padding, and the element-wise difference is returned
    (optionally collapsed to a channel-wise p-norm).

    Args:
        reference_fm: reference features, [B, C, H, W].
        target_fm: target features, [B, C, H, W].
        max_disp: size of the disparity search range.
        start_disp: first disparity of the search range.
        dilation: spacing between generated disparity samples.
        disp_sample: optional pre-computed disparity samples, [B, D, H, W];
            when provided, max_disp/start_disp/dilation are ignored.
        normalize: if True, reduce the channel dim with a p-norm.
        p: order of the norm used when ``normalize`` is True.

    Returns:
        [B, C, D, H, W] difference volume, or [B, D, H, W] when normalized.
    """
    device = reference_fm.device
    B, C, H, W = reference_fm.shape

    if disp_sample is None:
        # Generate D disparity samples covering [start_disp, end_disp],
        # laid out as a [B, D, H, W] tensor on the same device.
        end_disp = start_disp + max_disp - 1
        D = (max_disp + dilation - 1) // dilation
        disp_sample = torch.linspace(start_disp, end_disp, D)
        disp_sample = disp_sample.view(1, D, 1, 1).expand(B, D, H, W).to(device).float()
    else:
        # Disparity samples were supplied by the caller; read D from them.
        D = disp_sample.shape[1]

    # Broadcast both feature maps over a new disparity dimension.
    dif_reference_fm = reference_fm.unsqueeze(2).expand(B, C, D, H, W)
    dif_target_fm = target_fm.unsqueeze(2).expand(B, C, D, H, W)

    # Shift the target features by the (negated) disparity samples.
    dif_target_fm = inverse_warp_3d(dif_target_fm, -disp_sample, padding_mode='zeros')

    # Zero out reference features where the warp produced zero padding,
    # so invalid positions contribute nothing to matching.
    dif_reference_fm = dif_reference_fm * (dif_target_fm > 0).type_as(dif_reference_fm)

    # [B, C, D, H, W]
    dif_fm = dif_reference_fm - dif_target_fm

    if normalize:
        # Collapse channels with a p-norm -> [B, D, H, W].
        dif_fm = torch.norm(dif_fm, p=p, dim=1, keepdim=False)

    return dif_fm
def forward(self, left, right, disparity_samples, disparity_sample_noise):
    """Refine disparity samples by soft-selecting among propagated neighbours.

    Each of the ``disparity_sample_number`` hypotheses comes with
    ``self.propagation_filter_size`` neighbour candidates; matching scores
    between the left features and the warped right features weight a
    softmax that blends each group of candidates into a single sample.

    Args:
        left: left-view features, [B, C, H, W].
        right: right-view features, [B, C, H, W].
        disparity_samples: candidate disparities,
            [B, disparity_sample_number * propagation_filter_size, H, W].
        disparity_sample_noise: per-candidate noise, same shape as
            ``disparity_samples``.

    Returns:
        Tuple of (disparity_samples, disparity_sample_noise), each
        [B, disparity_sample_number, H, W].
    """
    B, C, H, W = left.shape
    # D = disparity_sample_number * propagation_filter_size
    D = disparity_samples.shape[1]

    # Expand both views over the candidate dimension and warp the right
    # features by the disparity samples so matching positions line up.
    left = left.unsqueeze(2).expand(B, C, D, H, W)
    right = right.unsqueeze(2).expand(B, C, D, H, W)
    warped_right = inverse_warp_3d(right, -disparity_samples)

    # Matching score per candidate: channel-averaged inner product,
    # scaled by the learned/configured temperature.
    cost_volume = torch.mean(left * warped_right, dim=1) * self.temperature

    filt = self.propagation_filter_size
    sample_number = D // filt

    def _group_by_filter(x):
        # [B, D, H, W] -> [B, filt, sample_number, H, W]: put the
        # neighbour (filter) axis first so softmax runs over it.
        return x.view(B, sample_number, filt, H, W).permute([0, 2, 1, 3, 4])

    cost_volume = _group_by_filter(cost_volume)
    disparity_samples = _group_by_filter(disparity_samples)
    disparity_sample_noise = _group_by_filter(disparity_sample_noise)

    # Soft-select the most probable neighbour within each group.
    prob_volume = F.softmax(cost_volume, dim=1)

    # Expectation over neighbours -> [B, sample_number, H, W].
    disparity_samples = torch.sum(prob_volume * disparity_samples, dim=1)
    disparity_sample_noise = torch.sum(prob_volume * disparity_sample_noise, dim=1)

    return disparity_samples, disparity_sample_noise