def forward(self, init_d, ref_feat, ref_cam, srcs_feat, srcs_cam, s_scale):
    """Refine ``init_d`` with one least-squares (Gauss-Newton style) update.

    For every source view a per-pixel Jacobian of the warped feature with
    respect to depth is assembled by the chain rule

             d Fi     d p2     d p3
        J = ------ * ------ * ------
             d p2     d p3     d d
             c*2      2*3      3*1

    and stacked with the feature residual ``warped_src - ref``.  The
    per-pixel update is ``delta = -(J^T r) / (J^T J + 1e-9)``.

    Args:
        init_d: Initial depth; it is viewed as ``(n, h, w, 1, 1)`` and added
            to a ``(n, 1, h, w)`` update (presumably ``(n, 1, h, w)`` itself
            -- confirm against callers).
        ref_feat: Reference feature map of size ``(n, c, h, w)``.
        ref_cam: Reference camera tensor; after scaling it is indexed as
            ``[:, 0, :3, :3]`` (used as rotation) and ``[:, 1, :3, :3]``
            (used as intrinsics).
        srcs_feat: Iterable of source feature maps (assumed to share the
            spatial size of ``ref_feat``).
        srcs_cam: List of source camera tensors, laid out like ``ref_cam``.
        s_scale: Spatial scale; cameras are rescaled by ``1 / s_scale``.

    Returns:
        ``init_d + delta``.

    Raises:
        NanError: If the computed update contains NaN.
    """
    n, c, h, w = ref_feat.size()
    # Build constants on the input's device instead of forcing CUDA.
    device = init_d.device
    ref_cam_scaled, *srcs_cam_scaled = [
        scale_camera(cam, 1 / s_scale) for cam in [ref_cam] + srcs_cam
    ]
    depth = init_d.detach()

    # Terms that do not depend on the source view are computed once.
    K0_inv = ref_cam_scaled[:, 1, :3, :3].reshape(-1, 1, 1, 3, 3).inverse()
    R0_inv = ref_cam_scaled[:, 0, :3, :3].reshape(-1, 1, 1, 3, 3).inverse()
    unit_interval = torch.ones(n, 1, 1, 1, dtype=torch.float32, device=device)
    eye2 = torch.eye(2, dtype=torch.float32, device=device).view(1, 1, 1, 2, 2)

    J_list = []
    r_list = []
    for src_feat, src_cam_scaled in zip(srcs_feat, srcs_cam_scaled):
        H = get_homographies(
            ref_cam_scaled, src_cam_scaled, 1, depth, unit_interval
        ).squeeze(1)  # nhw33

        # copied from homography.py
        # NOTE(review): the no_grad scope is taken to cover only the
        # coordinate computation -- confirm against homography.py.
        with torch.no_grad():
            pixel_grids = get_pixel_grids(
                *src_feat.size()[-2:]).unsqueeze(0)  # 1hw31
            warped_homo_coord = (H @ pixel_grids).squeeze(-1)  # nhw3
            warped_coord = warped_homo_coord[..., :2] / (
                warped_homo_coord[..., 2:3] + 1e-9)  # nhw2
        warped = interpolate(src_feat, warped_coord)

        residual = (warped - ref_feat).permute(
            0, 2, 3, 1).unsqueeze(-1)  # nhwc1

        # Image-space feature gradients sampled at the warped locations.
        src_grad = self.sobel_conv(src_feat)  # n c*2 hw
        src_grad_warped = interpolate(src_grad, warped_coord).permute(
            0, 2, 3, 1).reshape(n, h, w, c, 2)  # nhwc2

        # d p2 / d p3: derivative of (x/z, y/z) w.r.t. (x, y, z).
        z = warped_homo_coord[..., -1].view(n, h, w, 1, 1) + 1e-9
        d3to2_1 = eye2 / z  # nhw22
        # NOTE(review): analytically d(x/z)/dz = -x/z**2, but this column is
        # positive. Kept as-is because trained weights may depend on it --
        # confirm whether the sign is intentional.
        d3to2_2 = warped_coord.unsqueeze(-1) / z  # nhw21
        d3to2 = torch.cat([d3to2_1, d3to2_2], dim=-1)  # nhw23

        Ki = src_cam_scaled[:, 1, :3, :3].reshape(-1, 1, 1, 3, 3)  # n1133
        Ri = src_cam_scaled[:, 0, :3, :3].reshape(-1, 1, 1, 3, 3)
        # d p3 / d d. Presumably H = A + B / d with A = Ki Ri R0^-1 K0^-1,
        # which gives d(H p)/dd = (A - H) p / d -- verify against
        # get_homographies.
        dptod = (Ki @ Ri @ R0_inv @ K0_inv - H) @ pixel_grids \
            / depth.view(n, h, w, 1, 1)  # nhw31

        Ji = src_grad_warped @ d3to2 @ dptod  # nhwc1
        r_list.append(residual)
        J_list.append(Ji)

    # Stack all views along the channel axis: one scalar unknown per pixel.
    J = torch.cat(J_list, dim=-2)
    r = torch.cat(r_list, dim=-2)
    Jt = J.transpose(-1, -2)
    delta = (-(Jt @ r) / (Jt @ J + 1e-9)).reshape(n, 1, h, w)
    if torch.isnan(delta).any():
        raise NanError
    refined_d = init_d + delta
    return refined_d
def build_cost_maps(self, ref, ref_cam, source, source_cam, depth_num,
                    depth_start, depth_interval, scale):
    """Pair ``ref`` with ``source`` warped at each of ``depth_num`` depths.

    Both cameras are rescaled by ``1 / scale`` before one homography per
    depth index is requested from ``get_homographies``.

    Returns:
        A list of ``depth_num`` tensors; entry ``i`` is ``ref`` concatenated
        along dim 1 with ``source`` warped by the ``i``-th homography.
    """
    inv_scale = 1 / scale
    scaled_ref_cam = scale_camera(ref_cam, inv_scale)
    scaled_source_cam = scale_camera(source_cam, inv_scale)

    homographies = get_homographies(
        scaled_ref_cam, scaled_source_cam,
        depth_num, depth_start, depth_interval,
    )

    return [
        torch.cat(
            [ref, homography_warping(source, homographies[:, idx, ...])],
            dim=1,
        )
        for idx in range(depth_num)
    ]
def build_cost_volume(self, ref, ref_cam, src, src_cam, depth_num,
                      depth_start, depth_interval, s_scale, d_scale):
    """Warp ``src`` onto ``depth_num // d_scale`` depth planes in one batch.

    Cameras are rescaled by ``1 / s_scale``; the depth planes start at
    ``depth_start`` and are spaced ``depth_interval * d_scale`` apart.
    ``ref`` is accepted for interface symmetry but is not read here.

    Returns:
        The warped source features with the depth axis moved behind the
        channel axis (``ncdhw`` per the shape notes below).
    """
    inv_scale = 1 / s_scale
    scaled_ref_cam = scale_camera(ref_cam, inv_scale)
    scaled_src_cam = scale_camera(src_cam, inv_scale)

    planes = depth_num // d_scale
    feat_dims = src.size()[1:]  # chw

    homographies = get_homographies(
        scaled_ref_cam, scaled_src_cam,
        planes, depth_start, depth_interval * d_scale,
    )  # ndhw33

    # Fold the depth axis into the batch axis so a single warp call
    # handles every (sample, depth) pair.
    tiled_src = src.unsqueeze(1).repeat(1, planes, 1, 1, 1)
    flat_src = tiled_src.view(-1, *feat_dims)  # n*d chw
    flat_homographies = homographies.view(-1, *homographies.size()[2:])
    flat_warped = homography_warping(flat_src, flat_homographies)  # n*d chw

    volume = flat_warped.view(-1, planes, *feat_dims)
    return volume.transpose(1, 2)  # ncdhw