Example #1
0
 def forward(self, init_d, ref_feat, ref_cam, srcs_feat, srcs_cam, s_scale):
     """Refine a depth map with one per-pixel Gauss-Newton step.

     Each source feature map is warped into the reference view through the
     homography computed from ``init_d``. The feature residuals and their
     Jacobians w.r.t. depth are stacked over all views and channels, and
     the scalar normal equation is solved per pixel:

         delta = -(J^T r) / (J^T J + 1e-9)

     The Jacobian follows the chain rule

          d Fi     d p2     d p3
     J = ------ * ------ * ------
          d p2     d p3     d d
          c*2      2*3      3*1

     with ``Fi`` the warped source feature, ``p2`` the warped pixel
     coordinate, ``p3`` the warped homogeneous coordinate and ``d`` depth.

     Args:
         init_d: Initial depth. It must hold ``n * h * w`` values and
             broadcast against ``(n, 1, h, w)``. It is detached wherever
             it feeds the homography and the Jacobian.
         ref_feat: Reference feature map of shape ``(n, c, h, w)``.
         ref_cam: Reference camera tensor; ``cam[:, 0, :3, :3]`` is read
             as the rotation and ``cam[:, 1, :3, :3]`` as the intrinsics.
         srcs_feat: Iterable of source feature maps (assumed to share the
             shape of ``ref_feat`` -- the reshapes below rely on it).
         srcs_cam: Iterable of source cameras, same layout as ``ref_cam``.
         s_scale: Spatial scale; every camera is passed through
             ``scale_camera(cam, 1 / s_scale)``.

     Returns:
         ``init_d + delta`` where ``delta`` has shape ``(n, 1, h, w)``.

     Raises:
         NanError: If the computed update contains NaN.
     """
     n, c, h, w = ref_feat.size()
     # Build helper tensors on the inputs' device instead of the current
     # default CUDA device, so inputs on another GPU do not cause a mismatch.
     device = ref_feat.device
     # Unpacking (rather than ``[ref_cam] + srcs_cam``) accepts any iterable
     # of source cameras, not only a list.
     ref_cam_scaled, *srcs_cam_scaled = [
         scale_camera(cam, 1 / s_scale) for cam in (ref_cam, *srcs_cam)
     ]

     # Loop invariants: detached depth, unit depth interval, 2x2 identity
     # and the inverted reference rotation / intrinsics.
     depth = init_d.detach()
     unit_interval = torch.ones(n, 1, 1, 1, dtype=torch.float32,
                                device=device)
     eye2 = torch.eye(2, dtype=torch.float32,
                      device=device).view(1, 1, 1, 2, 2)
     K0_inv = ref_cam_scaled[:, 1, :3, :3].reshape(-1, 1, 1, 3,
                                                   3).inverse()
     R0_inv = ref_cam_scaled[:, 0, :3, :3].reshape(-1, 1, 1, 3,
                                                   3).inverse()

     J_list = []
     r_list = []
     for src_feat, src_cam_scaled in zip(srcs_feat, srcs_cam_scaled):
         # A single depth plane per pixel, located at the current estimate.
         H = get_homographies(ref_cam_scaled, src_cam_scaled, 1, depth,
                              unit_interval).squeeze(1)  # nhw33
         # copied from homography.py
         # The warped coordinates are constants for autograd; the analytic
         # Jacobian below is used instead of back-propagating through them.
         with torch.no_grad():
             pixel_grids = get_pixel_grids(*src_feat.size()[-2:]).unsqueeze(
                 0)  # 1hw31
             warped_homo_coord = (H @ pixel_grids).squeeze(-1)  # nhw3
             warped_coord = warped_homo_coord[..., :2] / (
                 warped_homo_coord[..., 2:3] + 1e-9)  # nhw2
         warped = interpolate(src_feat, warped_coord)
         residual = (warped - ref_feat).permute(0, 2, 3,
                                                1).unsqueeze(-1)  # nhwc1

         # d Fi / d p2: image-space feature gradients sampled at the warped
         # locations.
         src_grad = self.sobel_conv(src_feat)  # n c*2 hw
         src_grad_warped = interpolate(src_grad, warped_coord).permute(
             0, 2, 3, 1).reshape(n, h, w, c, 2)  # nhwc2

         # d p2 / d p3: Jacobian of the perspective division.
         z = warped_homo_coord[..., -1].view(n, h, w, 1, 1) + 1e-9
         d3to2_1 = eye2 / z  # nhw22
         # NOTE(review): d(x/z)/dz is -x/z**2, but this column is +p2/z.
         # Left unchanged to keep numerical behaviour -- confirm the sign.
         d3to2_2 = warped_coord.unsqueeze(-1) / z  # nhw21
         d3to2 = torch.cat([d3to2_1, d3to2_2], dim=-1)  # nhw23

         # d p3 / d d, expressed through the homography itself:
         # (Ki Ri R0^-1 K0^-1 - H) p / d.
         Ki = src_cam_scaled[:, 1, :3, :3].reshape(-1, 1, 1, 3, 3)  # n1133
         Ri = src_cam_scaled[:, 0, :3, :3].reshape(-1, 1, 1, 3, 3)
         dptod = (Ki @ Ri @ R0_inv @ K0_inv -
                  H) @ pixel_grids / depth.view(n, h, w, 1, 1)  # nhw31

         Ji = src_grad_warped @ d3to2 @ dptod  # nhwc1
         r_list.append(residual)
         J_list.append(Ji)

     # Stack all views along the channel axis: nh w (views*c) 1.
     J, r = [torch.cat(l, dim=-2) for l in [J_list, r_list]]
     # J^T J is a scalar per pixel, so the solve reduces to a division.
     delta = (-(J.transpose(-1, -2) @ r) /
              (J.transpose(-1, -2) @ J + 1e-9)).reshape(n, 1, h, w)
     if torch.isnan(delta).any():
         raise NanError
     # NOTE: the step is applied unclamped.
     refined_d = init_d + delta
     return refined_d
Example #2
0
    def build_cost_maps(self, ref, ref_cam, source, source_cam, depth_num,
                        depth_start, depth_interval, scale):
        """Pair the reference map with the warped source at every depth.

        Both cameras are rescaled with ``scale_camera(cam, 1 / scale)`` and
        handed to ``get_homographies`` together with the depth sampling
        parameters. For each depth index the source map is warped with the
        matching homography and concatenated with ``ref`` along ``dim=1``.

        Returns:
            A list of ``depth_num`` tensors, one per depth index, each being
            ``torch.cat([ref, warped_source], dim=1)``.
        """
        inv_scale = 1 / scale
        ref_cam_scaled = scale_camera(ref_cam, inv_scale)
        source_cam_scaled = scale_camera(source_cam, inv_scale)
        homographies = get_homographies(ref_cam_scaled, source_cam_scaled,
                                        depth_num, depth_start,
                                        depth_interval)

        def cost_map_at(plane):
            # Warp the source with this plane's homography, then stack it
            # behind the reference along the channel axis.
            warped = homography_warping(source, homographies[:, plane, ...])
            return torch.cat([ref, warped], dim=1)

        return [cost_map_at(plane) for plane in range(depth_num)]
Example #3
0
 def build_cost_volume(self, ref, ref_cam, src, src_cam, depth_num,
                       depth_start, depth_interval, s_scale, d_scale):
     """Warp ``src`` onto every sampled depth plane in one batched call.

     The depth axis is subsampled by ``d_scale``: ``depth_num // d_scale``
     planes are used, spaced ``depth_interval * d_scale`` apart. Cameras are
     rescaled with ``scale_camera(cam, 1 / s_scale)``. ``ref`` is accepted
     but not read here.

     Returns:
         The warped source arranged as ``(n, c, d, h, w)``.
     """
     plane_count = depth_num // d_scale
     feat_shape = src.size()[1:]  # chw

     inv_scale = 1 / s_scale
     ref_cam_scaled = scale_camera(ref_cam, inv_scale)
     src_cam_scaled = scale_camera(src_cam, inv_scale)
     homographies = get_homographies(ref_cam_scaled, src_cam_scaled,
                                     plane_count, depth_start,
                                     depth_interval * d_scale)  # ndhw33

     # Duplicate the source once per plane and fold the plane axis into the
     # batch axis so that a single warping call handles all planes.
     tiled_src = src.unsqueeze(1).repeat(1, plane_count, 1, 1, 1)
     flat_src = tiled_src.view(-1, *feat_shape)  # n*d chw
     flat_h = homographies.view(-1, *homographies.size()[2:])
     flat_warped = homography_warping(flat_src, flat_h)  # n*d chw

     # Restore the plane axis, then move it behind the channels.
     per_plane = flat_warped.view(-1, plane_count, *feat_shape)  # ndchw
     warped_src = per_plane.transpose(1, 2)  # ncdhw
     return warped_src