def preprocess_scene(x_2d, scene_rgb, scene_depth, scene_K, scene_Tcw,
                     scene_ori_rgb):
    rand_R = torch.eye(3).unsqueeze(0)
    N, L, C, H, W = scene_rgb.shape
    # scene_rgb = scene_rgb.view(N, L, C, H, W)
    scene_depth = scene_depth.view(N * L, 1, H, W)
    scene_K = scene_K.view(N * L, 3, 3)
    scene_Tcw = scene_Tcw.view(N * L, 3, 4)

    # generate 3D world position of scene
    d = scene_depth.view(N * L, H * W, 1)  # dim (N*L, H*W, 1)
    X_3d = batched_pi_inv(scene_K, x_2d, d)  # dim (N*L, H*W, 3)
    Rwc, twc = batched_inv_pose(
        R=scene_Tcw[:, :3, :3],
        t=scene_Tcw[:, :3, 3].squeeze(-1))  # dim (N*L, 3, 3), (N*L, 3, 1)
    X_world = batched_transpose(Rwc.cuda(), twc.cuda(),
                                X_3d)  # dim (N*L, H*W, 3)
    X_world = X_world.view(N, L * H * W, 3)  # dim (N, L*H*W, 3)
    scene_center = torch.mean(X_world, dim=1)  # dim (N, 3)
    X_world -= scene_center.view(N, 1, 3)
    X_world = batched_transpose(
        rand_R.cuda().expand(N, 3, 3),
        torch.zeros(1, 3, 1).cuda().expand(N, 3, 1),
        X_world)  # dim (N, L*H*W, 3), data augmentation
    X_world = X_world.view(N, L, H, W,
                           3).permute(0, 1, 4, 2,
                                      3).contiguous()  # dim (N, L, 3, H, W)
    scene_input = torch.cat((scene_rgb, X_world), dim=2)

    return scene_input.cuda(), scene_ori_rgb.cuda(), X_world.cuda(), \
           torch.gt(scene_depth, 1e-5).cuda().view(N, L, H, W), scene_center.cuda(), rand_R.expand(N, 3, 3).cuda()
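# Minimal smoke-test sketch (hypothetical demo helper, stand-in tensors, CUDA
# assumed). x_2d is the flattened (N*L, H*W, 2) pixel grid produced by the
# x_2d_coords_torch helper used throughout these examples; all other inputs are
# random stand-ins with the shapes preprocess_scene expects.
def _demo_preprocess_scene(N=1, L=5, C=3, H=48, W=64):
    scene_rgb = torch.rand(N, L, C, H, W).cuda()
    scene_ori_rgb = torch.rand(N, L, C, 4 * H, 4 * W).cuda()
    scene_depth = (torch.rand(N, L, 1, H, W) + 0.5).cuda()   # strictly positive depth
    scene_K = torch.eye(3).expand(N, L, 3, 3).contiguous().cuda()
    scene_Tcw = torch.eye(4)[:3, :].expand(N, L, 3, 4).contiguous().cuda()
    x_2d = x_2d_coords_torch(N * L, H, W).view(N * L, H * W, 2).cuda()
    scene_input, ori_rgb, X_world, valid_mask, center, R_aug = preprocess_scene(
        x_2d, scene_rgb, scene_depth, scene_K, scene_Tcw, scene_ori_rgb)
    # RGB channels are concatenated with the normalized scene coordinates
    assert scene_input.shape == (N, L, C + 3, H, W)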
def wrapping(I_a, I_b, d_a, K, R, t):
    """
    Wrap image by providing depth, rotation and translation
    :param I_a:
    :param I_b:
    :param d_a:
    :param K:
    :param R:
    :param t:
    :return:
    """
    import banet_track.ba_module as module

    H, W, C = I_a.shape
    # I_a = torch.from_numpy(I_a.transpose((2, 0, 1))).cuda().view((1, C, H, W))
    I_b = torch.from_numpy(I_b.transpose((2, 0, 1))).cuda().view((1, C, H, W))
    d_a = torch.from_numpy(d_a).cuda().view((1, H * W))
    K = torch.from_numpy(K).cuda().view(1, 3, 3)
    R = torch.from_numpy(R).cuda().view(1, 3, 3)
    t = torch.from_numpy(t).cuda().view(1, 3)

    x_a = module.x_2d_coords_torch(1, H, W).view(1, H * W,
                                                 2).cuda()  # dim: (N, H*W, 2)
    X_a_3d = module.batched_pi_inv(K, x_a, d_a.view((1, H * W, 1)))
    X_b_3d = module.batched_transpose(R, t, X_a_3d)
    x_b_2d, _ = module.batched_pi(K, X_b_3d)
    x_b_2d_out = x_b_2d.cpu().numpy()
    x_b_2d = module.batched_x_2d_normalize(H, W,
                                           x_b_2d).view(1, H, W,
                                                        2)  # (N, H, W, 2)
    wrap_img_b = module.batched_interp2d(I_b, x_b_2d)

    return wrap_img_b.cpu().numpy().transpose((0, 2, 3, 1)).reshape(
        (H, W, C)), x_b_2d_out.reshape(H, W, 2)
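# Hedged usage sketch (stand-in arrays, CUDA assumed, numpy imported as np as
# elsewhere in these examples): given two float32 RGB frames, the depth of
# frame A, shared intrinsics K and the relative pose (R, t) from frame A to
# frame B, `wrapping` resamples I_b onto frame A's pixel grid. An identity pose
# and unit depth make this a pure resampling smoke test.
I_a_demo = np.random.rand(240, 320, 3).astype(np.float32)
I_b_demo = np.random.rand(240, 320, 3).astype(np.float32)
d_a_demo = np.ones((240, 320), dtype=np.float32)
K_demo = np.array([[300.0, 0.0, 160.0],
                   [0.0, 300.0, 120.0],
                   [0.0, 0.0, 1.0]], dtype=np.float32)
warped_b, corr_2d = wrapping(I_a_demo, I_b_demo, d_a_demo, K_demo,
                             np.eye(3, dtype=np.float32),
                             np.zeros(3, dtype=np.float32))  # (H, W, C), (H, W, 2)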
def preprocess(x_2d, scene_rgb, scene_depth, scene_K, scene_Tcw, scene_ori_rgb, scene_center=None):

    N, L, C, H, W = scene_rgb.shape
    _, _, _, ori_H, ori_W = scene_ori_rgb.shape

    scene_rgb = scene_rgb.view(N * L, C, H, W)
    scene_depth = scene_depth.view(N * L, 1, H, W)
    scene_K = scene_K.view(N * L, 3, 3)
    scene_Tcw = scene_Tcw.view(N * L, 3, 4)

    # generate 3D world position of scene
    d = scene_depth.view(N * L, H * W, 1)                                                     # dim (N*L, H*W, 1)
    X_3d = batched_pi_inv(scene_K, x_2d, d)                                                   # dim (N*L, H*W, 3)
    Rwc, twc = batched_inv_pose(R=scene_Tcw[:, :3, :3],
                                t=scene_Tcw[:, :3, 3].squeeze(-1))                            # dim (N*L, 3, 3), (N*L, 3, 1)
    X_world = batched_transpose(Rwc.cuda(), twc.cuda(), X_3d)                                 # dim (N*L, H*W, 3)
    X_world = X_world.view(N, L * H * W, 3)                                                   # dim (N, L*H*W, 3)
    if scene_center is None:
        scene_center = torch.mean(X_world, dim=1)                                             # dim (N, 3)
    X_world -= scene_center.view(N, 1, 3)

    X_world = X_world.view(N * L, H, W, 3).permute(0, 3, 1, 2).contiguous()                   # dim (N * L, 3, H, W)
    scene_input = torch.cat((scene_rgb, X_world), dim=1)

    return scene_input.cuda(), \
           scene_ori_rgb.view((N*L, 3, ori_H, ori_W)).cuda(), \
           X_world.cuda(), \
           torch.gt(scene_depth, 1e-5).cuda().view(N * L, H, W), \
           scene_center.cuda()
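# Sketch of the calling convention (an assumption, not from the source): unlike
# preprocess_scene above, this variant flattens the batch to N*L, applies no
# random augmentation rotation, and accepts an optional precomputed scene_center
# so that a second call can be expressed in the same normalized frame, e.g.
#
#   inp_a, ori_a, Xw_a, mask_a, center = preprocess(x_2d, rgb_a, d_a, K_a, Tcw_a, ori_a)
#   inp_b, ori_b, Xw_b, mask_b, _ = preprocess(x_2d, rgb_b, d_b, K_b, Tcw_b, ori_b,
#                                              scene_center=center)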
def compute_pose_lm_pnp(gt_Tcws, query_X_w, rand_R, scene_center, query_K,
                        pnp_x_2d, repro_thres):
    N, _, H, W = query_X_w.shape

    # recover original scene coordinates (undo the augmentation rotation and
    # re-add the scene center)
    query_X_3d_w = recover_original_scene_coordinates(query_X_w, rand_R,
                                                      scene_center)
    query_X_3d_w = query_X_3d_w.view(N, H, W,
                                     3).squeeze(0).detach().cpu().numpy()

    # run Ransac PnP
    lm_pnp_pose_vec, inlier_map = lm_pnp.compute_lm_pnp(
        pnp_x_2d, query_X_3d_w, query_K, repro_thres, 128, 100)
    R_res, _ = cv2.Rodrigues(lm_pnp_pose_vec[:3])
    lm_pnp_pose = np.eye(4, dtype=np.float32)
    lm_pnp_pose[:3, :3] = R_res
    lm_pnp_pose[:3, 3] = lm_pnp_pose_vec[3:].ravel()

    # measure accuracy
    gt_pose = gt_Tcws.squeeze(0).detach().cpu().numpy()

    R_acc = rel_rot_angle(lm_pnp_pose, gt_pose)
    t_acc = rel_distance(lm_pnp_pose, gt_pose)

    return R_acc, t_acc, lm_pnp_pose, inlier_map
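# The pixel grid handed to the PnP routines (here and in
# compute_pose_pnp_from_valid_pixels below) is assumed to be an (H, W, 2) numpy
# array of (x, y) pixel coordinates, matching the pnp_x_2d[y, x, :] indexing
# used there. A minimal way to build it, with stand-in dimensions (numpy
# imported as np, as elsewhere in these examples):
pnp_H, pnp_W = 48, 64
xx, yy = np.meshgrid(np.arange(pnp_W, dtype=np.float32),
                     np.arange(pnp_H, dtype=np.float32))
pnp_x_2d = np.stack([xx, yy], axis=-1)  # (H, W, 2)
# compute_pose_lm_pnp(gt_Tcws, query_X_w, rand_R, scene_center, query_K,
#                     pnp_x_2d, repro_thres) then returns the rotation and
# translation errors against gt_Tcws, the recovered 4x4 pose and the inlier
# map; lm_pnp is the repo's compiled LM-PnP extension, so the call only runs
# where that extension is built.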
    def verify_features(self, I_a, d_a, K, I_b, se3_gt, x, y, title):
        """
        Extract feature pyramids f_a, f_b of I_a and I_b
        Wrap f_b to f_a
        Compute distances of a pixel in f_a with the neighbors of its corresponding pixels in f_b
        :param I_a: Image of frame A, dim: (N, C, H, W)
        :param d_a: Depth of frame A, dim: (N, 1, H, W)
        :param K: intrinsic matrix at level 0: dim: (N, 3, 3)
        :param I_b: Image of frame B, dim: (N, C, H, W)
        :param se3_gt: Groundtruth of se3, dim: (N, 6)
        :return:
        """
        import banet_track.ba_debug as debug

        (N, C, H, W) = I_a.shape
        I_a.requires_grad_()
        I_b.requires_grad_()

        # Concatenate I_a and I_b along the batch dimension
        I = torch.cat([I_a, I_b], dim=0)

        # Aggregate pyramid features
        aggr_pyramid = self.aggregate_pyramid_features(self.backbone_net.forward(I))
        aggr_pyramid_f_a = [f[:N, :, :, :] for f in aggr_pyramid]
        aggr_pyramid_f_b = [f[N:, :, :, :] for f in aggr_pyramid]

        for level in [2, 1, 0]:
            (level_H, level_W) = self.level_dim_hw[level]

            # Resize and Rescale the depth and the intrinsic matrix
            rescale_ratio = 1.0 / math.pow(2, level)
            level_K = rescale_ratio * K.detach()  # dim: (N, 3, 3)
            level_d_a = F.interpolate(d_a, scale_factor=rescale_ratio).detach()  # dim: (N, 1, H, W)

            # Cache several variables:
            R, t = se3_exp(se3_gt)
            x_a_2d = self.x_valid_2d[level]  # dim: (N, H*W, 2)
            X_a_3d = batched_pi_inv(level_K, x_a_2d,
                                    level_d_a.view((N, level_H * level_W, 1)))
            X_b_3d = batched_transpose(R, t, X_a_3d)
            x_b_2d, _ = batched_pi(level_K, X_b_3d)
            x_b_2d = module.batched_x_2d_normalize(float(level_H), float(level_W), x_b_2d).view(N, level_H, level_W, 2)  # (N, H, W, 2)

            # Warp the feature map of frame B into frame A
            level_aggr_pyramid_f_b_wrap = batched_interp2d(aggr_pyramid_f_b[level], x_b_2d)
            level_x = int(x * rescale_ratio)
            level_y = int(y * rescale_ratio)
            left = level_x - debug.similar_window_offset
            left = left if left >= 0 else 0
            right = level_x + debug.similar_window_offset
            up = level_y - debug.similar_window_offset
            up = up if up >= 0 else 0
            down = level_y + debug.similar_window_offset
            batch_distance = torch.norm(aggr_pyramid_f_a[level][:, :, up:down, left:right] -     # (N, window_H, window_W)
                                        level_aggr_pyramid_f_b_wrap[:, :, level_y:level_y+1, level_x:level_x+1], 2, 1)
            show_multiple_img([{'img': I_a[0].detach().cpu().numpy().transpose(1, 2, 0), 'title': 'I_a'},
                               {'img': I_b[0].detach().cpu().numpy().transpose(1, 2, 0), 'title': 'I_b'},
                               {'img': batch_distance[0].detach().cpu().numpy(), 'title': 'feature distance', 'cmap':'gray'}],
                              title=title, num_cols=3)
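    # Hedged call sketch: `self` is the tracking network that owns backbone_net,
    # aggregate_pyramid_features, level_dim_hw and x_valid_2d (its constructor is
    # not shown in these examples). With a constructed network `net`, one would
    # inspect the feature distance around pixel (x, y) of frame A roughly as:
    #
    #   net.verify_features(I_a.cuda(), d_a.cuda(), K.cuda(), I_b.cuda(),
    #                       se3_gt.cuda(), x=120, y=80, title='feature check')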
def compute_pose_pnp_from_valid_pixels(gt_Tcws, query_X_w, rand_R,
                                       scene_center, query_K, valid_pix_idx,
                                       pnp_x_2d, repro_thres):
    N, _, H, W = query_X_w.shape

    # recover original scene coordinates (undo the augmentation rotation and
    # re-add the scene center)
    query_X_3d_w = recover_original_scene_coordinates(query_X_w, rand_R,
                                                      scene_center)
    query_X_3d_w = query_X_3d_w.view(N, H, W,
                                     3).squeeze(0).detach().cpu().numpy()

    # select valid pixels with input index
    x, y = valid_pix_idx
    x_2d_valid = pnp_x_2d[y, x, :]
    query_X_3d_valid = query_X_3d_w[y, x, :]
    selected_pixels = query_X_3d_valid.shape[0]

    query_X_3d_valid = query_X_3d_valid.reshape(1, selected_pixels, 3)
    x_2d_valid = x_2d_valid.reshape(1, selected_pixels, 2)

    # run Ransac PnP
    dist = np.zeros(4)
    k = query_K.squeeze(0).detach().cpu().numpy()
    retval, R_res, t_res, ransc_inlier = cv2.solvePnPRansac(
        query_X_3d_valid,
        x_2d_valid,
        k,
        dist,
        reprojectionError=repro_thres,
    )
    #     print(retval)
    #     _, R_res, t_res = cv2.solvePnP(query_X_3d_valid, x_2d_valid, k, dist)#, flags=cv2.SOLVEPNP_EPNP)

    R_res, _ = cv2.Rodrigues(R_res)
    pnp_pose = np.eye(4, dtype=np.float32)
    pnp_pose[:3, :3] = R_res
    pnp_pose[:3, 3] = t_res.ravel()

    # measure accuracy
    gt_pose = gt_Tcws.squeeze(0).detach().cpu().numpy()

    R_acc = rel_rot_angle(pnp_pose, gt_pose)
    t_acc = rel_distance(pnp_pose, gt_pose)

    #     ransc_inlier = None
    return R_acc, t_acc, pnp_pose, ransc_inlier
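# Sketch (an assumption, reusing the stand-in pnp_x_2d grid built above):
# valid_pix_idx is an (x, y) pair of index arrays selecting the pixels fed to
# cv2.solvePnPRansac, e.g. pixels with valid depth or high confidence. Note that
# np.where on an (H, W) boolean mask returns (rows, cols) = (y, x), so the pair
# must be swapped:
#
#   y_idx, x_idx = np.where(valid_mask)   # valid_mask: hypothetical (H, W) bool array
#   R_acc, t_acc, pose, inliers = compute_pose_pnp_from_valid_pixels(
#       gt_Tcws, query_X_w, rand_R, scene_center, query_K,
#       valid_pix_idx=(x_idx, y_idx), pnp_x_2d=pnp_x_2d, repro_thres=2.0)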
def recover_original_scene_coordinates(query_X_w, rand_R, scene_center):
    N, _, H, W = query_X_w.shape

    # recover original scene coordinates
    query_X_3d_w = query_X_w.permute(0, 2, 3, 1).view(N, -1, 3)
    rand_R_t = torch.transpose(rand_R, 1, 2).to(query_X_3d_w.device)
    query_X_3d_w = batched_transpose(rand_R_t,
                                     torch.zeros(N, 3).to(query_X_3d_w.device),
                                     query_X_3d_w)
    query_X_3d_w += scene_center.view(N, 1, 3)
    query_X_3d_w = query_X_3d_w.view(N, H, W, 3)

    return query_X_3d_w
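# Round-trip sanity sketch (hypothetical helper, stand-in data, CPU-only):
# recovering should undo the augmentation used in the preprocess routines above,
# i.e. centering on the scene mean followed by a rotation with rand_R. The math
# module is assumed imported, as elsewhere in these examples.
def _demo_recover_scene_coordinates():
    N, H, W = 1, 4, 4
    X = torch.rand(N, H * W, 3)                               # original world coordinates
    center = X.mean(dim=1)                                    # (N, 3)
    c, s = math.cos(0.7), math.sin(0.7)
    rand_R = torch.tensor([[c, -s, 0.0],
                           [s, c, 0.0],
                           [0.0, 0.0, 1.0]]).unsqueeze(0)     # an arbitrary rotation
    # apply the augmentation: center, then rotate
    X_aug = batched_transpose(rand_R, torch.zeros(N, 3),
                              X - center.view(N, 1, 3))
    X_aug = X_aug.permute(0, 2, 1).contiguous().view(N, 3, H, W)  # network layout
    # undo it and compare against the original coordinates
    X_rec = recover_original_scene_coordinates(X_aug, rand_R, center)
    assert torch.allclose(X_rec.view(N, -1, 3), X, atol=1e-5)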
def preprocess(sample_dict, pre_x2d, out_dim, rescale_dist=0.0):
    rand_angle = np.random.random_sample() * 2.0 * np.pi
    rand_R = quaternion_matrix(quaternion_about_axis(rand_angle, (0.0, 1.0, 0.0)))[:3, :3]
    rand_R = torch.FloatTensor(rand_R).unsqueeze(0)

    scene_rgb = sample_dict['frames_img'][:, :5, ...].cuda()
    scene_depth = sample_dict['frames_depth'][:, :5, ...].cuda()
    scene_K = sample_dict['frames_K'][:, :5, ...].cuda()
    scene_Tcw = sample_dict['frames_Tcw'][:, :5, ...]
    scene_ori_rgb = sample_dict['frames_ori_img'][:, :5, ...].cuda()
    scene_neg_tags = sample_dict['frames_neg_tags'][:, :5, ...].cuda()

    N, L, C, H, W = scene_rgb.shape
    # scene_rgb = scene_rgb.view(N, L, C, H, W)
    scene_depth = scene_depth.view(N * L, 1, H, W)
    scene_K = scene_K.view(N * L, 3, 3)
    scene_Tcw = scene_Tcw.view(N * L, 3, 4)

    # generate 3D world position of scene
    d = scene_depth.view(N * L, H * W, 1)  # dim (N*L, H*W, 1)
    X_3d = batched_pi_inv(scene_K, pre_x2d, d)  # dim (N*L, H*W, 3)
    Rwc, twc = batched_inv_pose(R=scene_Tcw[:, :3, :3],
                                t=scene_Tcw[:, :3, 3].squeeze(-1))  # dim (N*L, 3, 3), (N*L, 3, 1)
    X_world = batched_transpose(Rwc.cuda(), twc.cuda(), X_3d)  # dim (N*L, H*W, 3)
    X_world = X_world.contiguous().view(N, L * H * W, 3)        # dim (N, L*H*W, 3)
    scene_center = torch.mean(X_world, dim=1)  # dim (N, 3)
    X_world -= scene_center.view(N, 1, 3)
    X_world = batched_transpose(rand_R.cuda().expand(N, 3, 3),
                                torch.zeros(1, 3, 1).cuda().expand(N, 3, 1),
                                X_world)  # dim (N, L*H*W, 3), data augmentation
    X_world = X_world.view(N, L, H, W, 3).permute(0, 1, 4, 2, 3).contiguous()  # dim (N, L, 3, H, W)

    # query image:
    query_img = sample_dict['img']
    query_ori_img = sample_dict['ori_img']

    # compute multiscale ground truth query_X_worlds & valid_masks
    query_X_worlds = []
    valid_masks = []
    out_H, out_W = out_dim
    query_depth = sample_dict['depth'].cuda()
    ori_query_depth = query_depth.clone()
    N, C, H, W = query_depth.shape
    for i in range(4):
        query_depth_patch = F.unfold(
            query_depth,
            kernel_size=(H // out_H, W // out_W),
            stride=(H // out_H, W // out_W)
        ).view(N, -1, out_H, out_W)
        mask = torch.gt(query_depth_patch, 1e-5)
        count = torch.sum(mask.float(), dim=1)
        query_depth_down = torch.sum(query_depth_patch * mask.float(), dim=1) / \
                           torch.where(torch.le(count, 1e-5),
                                       torch.full(count.shape, 1e6).to(count.device),
                                       count)  # (N, 1, out_H, out_W)
        query_Tcw = sample_dict['Tcw']
        query_K = sample_dict['K'].clone().cuda()
        query_K[:, 0, 0] *= out_W / W
        query_K[:, 0, 2] *= out_W / W
        query_K[:, 1, 1] *= out_H / H
        query_K[:, 1, 2] *= out_H / H
        query_d = query_depth_down.view(N, out_H * out_W, 1)  # dim (N, H*W, 1)
        out_x_2d = x_2d_coords_torch(N, out_H, out_W).cuda().view(N, -1, 2)
        query_X_3d = batched_pi_inv(query_K, out_x_2d, query_d)  # dim (N, H*W, 3)
        query_Rwc, query_twc = batched_inv_pose(R=query_Tcw[:, :3, :3],
                                                t=query_Tcw[:, :3, 3].squeeze(-1))  # dim (N, 3, 3), (N, 3)
        query_X_world = batched_transpose(query_Rwc.cuda(), query_twc.cuda(), query_X_3d)  # dim (N, H*W, 3)
        query_X_world -= scene_center.view(N, 1, 3)
        query_X_world = batched_transpose(rand_R.cuda().expand(N, 3, 3),
                                          torch.zeros(1, 3, 1).cuda().expand(N, 3, 1),
                                          query_X_world)  # dim (N, H*W, 3), data augmentation
        query_X_world = query_X_world.permute(0, 2, 1).view(N, 3, out_H, out_W).contiguous()  # dim (N, 3, H, W)
        query_X_worlds.append(query_X_world.cuda())

        valid_masks.append(torch.gt(query_depth_down, 1e-5).cuda().view(N, out_H, out_W))

        if i == 3:
            query_X_worlds.append(query_X_world.cuda())
            valid_masks.append(torch.gt(query_depth_down, 1e-5).cuda().view(N, out_H, out_W))

        out_H //= 2
        out_W //= 2

    # compute norm_query_Tcw for normalized scene coordinate
    query_twc = query_twc.cuda() - scene_center.view(N, 3, 1)
    norm_query_Twc = torch.cat([query_Rwc.cuda(), query_twc], dim=-1)  # dim (N, 3, 4)
    norm_query_Twc = torch.bmm(rand_R.cuda().expand(N, 3, 3), norm_query_Twc)  # dim (N, 3, 4)
    query_Rcw, query_tcw = batched_inv_pose(R=norm_query_Twc[:, :3, :3],
                                            t=norm_query_Twc[:, :3, 3].squeeze(-1))  # dim (N, 3, 3), (N, 3)
    norm_query_Tcw = torch.cat([query_Rcw, query_tcw.view(N, 3, 1)], dim=-1)  # dim (N, 3, 4)

    # compute down sampled query K
    out_H, out_W = out_dim
    query_K = sample_dict['K'].clone().cuda()
    query_K[:, 0, 0] *= out_W / W
    query_K[:, 0, 2] *= out_W / W
    query_K[:, 1, 1] *= out_H / H
    query_K[:, 1, 2] *= out_H / H

    if rescale_dist > 0:
        query_X_worlds, X_world, rescale_factor = rescale_scene_coords(query_X_worlds, X_world, scene_neg_tags, rescale_dist)
    else:
        rescale_factor = torch.ones(N)
    scene_input = torch.cat((scene_rgb, X_world), dim=2)

    return scene_input.cuda(), query_img.cuda(), query_X_worlds[::-1], valid_masks[::-1], \
           scene_ori_rgb.cuda(), query_ori_img.cuda(), X_world.cuda(), \
           torch.gt(scene_depth, 1e-5).cuda().view(N, L, H, W), norm_query_Tcw, query_K, scene_neg_tags, rescale_factor.cuda()
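# Expected sample_dict layout, inferred from the code above (stated as an
# assumption rather than an authoritative spec):
#   'frames_img'      (N, >=5, C, H, W)    scene RGB frames
#   'frames_depth'    (N, >=5, 1, H, W)    scene depth maps
#   'frames_K'        (N, >=5, 3, 3)       per-frame intrinsics
#   'frames_Tcw'      (N, >=5, 3, 4)       per-frame world-to-camera poses
#   'frames_ori_img'  (N, >=5, 3, ...)     full-resolution scene RGB
#   'frames_neg_tags' per-frame negative tags, consumed by rescale_scene_coords
#   'img', 'ori_img', 'depth', 'Tcw', 'K'  the query-frame counterparts
# pre_x2d is the flattened (N*L, H*W, 2) pixel grid of the scene frames, out_dim
# is the finest (out_H, out_W) of the four ground-truth pyramid levels (the
# returned lists are coarsest-first), and rescale_dist > 0 enables the optional
# scene-coordinate rescaling branch.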
    def valid(self, I_a, d_a, sel_a_indices, K, I_b, se3_gt, epoch):
        """
        Pre cache the variable for prediction
        :param I_a: Image of frame A, dim: (N, C, H, W)
        :param d_a: Depth of frame A, dim: (N, 1, H, W)
        :param sel_a_indices: (N, 3, M)
        :param K: intrinsic matrix at level 0: dim: (N, 3, 3)
        :param I_b: Image of frame B, dim: (N, C, H, W)
        :param se3_gt: ground truth Pose
        """

        (N, C, H, W) = I_a.shape
        I_a = I_a.detach()
        I_b = I_b.detach()

        # Ground-truth pose
        R_gt, t_gt = se3_exp(se3_gt)

        # Concatenate I_a and I_b along the batch dimension
        I = torch.cat([I_a, I_b], dim=0)

        # Aggregate pyramid features
        aggr_pyramid = self.aggregate_pyramid_features(
            self.backbone_net.forward(I))
        aggr_pyramid_f_a = [f[:N, :, :, :] for f in aggr_pyramid]
        aggr_pyramid_f_b = [f[N:, :, :, :] for f in aggr_pyramid]

        # Init a se(3) vector and mark requires_grad = True
        # alpha = torch.tensor([1e-4, 1e-4, 1e-4, 0.0, 0.0, 0.0]).repeat(N).view((N, 6))      # dim: (N, 6)
        # factor = 0.3
        # alpha = module.gen_random_alpha(se3_gt, rot_angle_rfactor=1.25, trans_vec_rfactor=0.16).view((N, 6)).cuda()
        # alpha.requires_grad_()
        T = torch.eye(4).view(1, 4, 4).repeat(N, 1, 1).detach()
        init_T = T

        pred_SE3_list = []  # (num_level: low_res to high_res, num_iter_per_level)
        gt_f_pair_list = []
        lambda_weight = []
        flow_list = []
        for level in [2, 1, 0]:

            pred_SE3_list.append([])
            lambda_weight.append([])
            flow_list.append([])
            (level_H, level_W) = self.level_dim_hw[level]

            M = sel_a_indices.shape[2]  # number of selected pts

            # Features on current level
            f_a = aggr_pyramid_f_a[level]
            f_b = aggr_pyramid_f_b[level]
            f_b_grad = batched_gradient(f_b)  # dim: (N, 2*C, H, W)

            # Resize and Rescale the depth and the intrinsic matrix
            rescale_ratio = 1.0 / math.pow(2, level)
            level_K = rescale_ratio * K.detach()  # dim: (N, 3, 3)
            level_d_a = F.interpolate(
                d_a, scale_factor=rescale_ratio).detach()  # dim: (N, 1, H, W)
            sel_a_idx = sel_a_indices[:,
                                      level, :].view(N,
                                                     M).detach()  # dim: (N, M)

            # Cache several variables:
            x_a_2d = self.x_train_2d[level]  # dim: (N, H*W, 2)
            X_a_3d = batched_pi_inv(level_K, x_a_2d,
                                    level_d_a.view((N, level_H * level_W, 1)))
            X_a_3d_sel = batched_index_select(X_a_3d, 1,
                                              sel_a_idx)  # dim: (N, M, 3)
            """ Ground-truth correspondence for Regularizer
            """
            f_C = f_a.shape[1]
            X_b_3d_gt = batched_transpose(R_gt, t_gt, X_a_3d)
            x_b_2d_gt, _ = batched_pi(level_K, X_b_3d_gt)
            x_b_2d_gt = module.batched_x_2d_normalize(float(level_H),
                                                      float(level_W),
                                                      x_b_2d_gt).view(
                                                          N, level_H, level_W,
                                                          2)  # (N, H, W, 2)
            gt_f_wrap_b = batched_interp2d(f_b, x_b_2d_gt)
            f_a_select = batched_index_select(
                f_a.view(N, f_C, level_H * level_W), 2, sel_a_idx)
            gt_f_wrap_b_select = batched_index_select(
                gt_f_wrap_b.view(N, f_C, level_H * level_W), 2, sel_a_idx)
            gt_f_pair_list.append((f_a_select, gt_f_wrap_b_select))

            # Run 6 Levenberg-Marquardt iterations on this level
            for itr in range(0, 6):
                T, r, delta_norm, lamb, flow = module.dm_levenberg_marquardt_itr(
                    T, X_a_3d, X_a_3d_sel, f_a, sel_a_idx, level_K, f_b,
                    f_b_grad, self.lambda_prediction, level)
                pred_SE3_list[-1].append(T)
                flow_list[-1].append((flow, x_b_2d_gt.detach()))

        return pred_SE3_list, gt_f_pair_list, init_T.detach(), flow_list
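    # Note on the return values (a reading of the code above, not authoritative):
    # pred_SE3_list is a list over pyramid levels in coarse-to-fine order [2, 1, 0],
    # each entry holding the pose estimate T after every Levenberg-Marquardt
    # iteration on that level; gt_f_pair_list pairs the selected frame-A features
    # with their ground-truth-warped frame-B counterparts for the regularizer;
    # init_T is the identity initialization; flow_list stores, per iteration, the
    # predicted flow together with the ground-truth correspondences.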
def compute_pose(I_a,
                 d_a,
                 sel_a_idx,
                 I_b,
                 K,
                 alpha,
                 T_gt,
                 opt_max_itr=100,
                 opt_eps=1e-5):

    # Debug assert
    assert sel_a_idx.dtype == torch.int64
    assert I_a.dtype == torch.float32
    assert I_b.dtype == torch.float32
    assert d_a.dtype == torch.float32

    # Dimension
    N, C, H, W = I_a.shape

    # Pre-processing
    # sel_a_idx = select_gradient_pixels(I_a, d_a, threshold=50.0)[: 2000]
    M = sel_a_idx.shape[1]
    I_b_grad = batched_gradient(
        I_b
    )  # dim: (N, 2*C, H, W), (N, 0:C, H, W) = dI/dx, (N, C:2C, H, W) = dI/dy

    assert H == d_a.shape[2]
    assert W == d_a.shape[3]

    # se(3) vector init
    lambda_w = 0.2 * torch.ones(N, 6)
    d_a = d_a.view((N, H * W, 1))

    # Points' 3D Position at Frame a
    x_a_2d = x_2d_coords_torch(N, H, W).view(N, H * W, 2)
    X_a_3d = batched_pi_inv(K, x_a_2d, d_a)
    X_a_3d_sel = batched_index_select(X_a_3d, 1, sel_a_idx)

    # ground-truth warp
    alpha_gt = torch.tensor([0.5, 0.5, 0.5, 0.0, 0.0, 0.0]).repeat(N).view(
        (N, 6))
    R_gt, _ = se3_exp(alpha_gt)
    I = torch.eye(3).view(1, 3, 3).expand(N, 3, 3).cuda()
    zeros = torch.zeros_like(T_gt[:, :, 3]).cuda()
    random_t = torch.zeros_like(T_gt[:, :, 3]).normal_(std=0.001)
    #print('random_t:', random_t)
    X_b_3d_gt = batched_transpose(R_gt, zeros, X_a_3d_sel)
    x_b_2d_gt, _ = batched_pi(K, X_b_3d_gt)

    for itr in range(0, opt_max_itr):

        R, t = se3_exp(alpha)

        X_b_3d = batched_transpose(R, t, X_a_3d_sel)
        x_b_2d, _ = batched_pi(K, X_b_3d)

        # Residual error
        e = (x_b_2d_gt - x_b_2d).view(N, M * 2)  # (N, H*W*2)

        # Compute the Jacobian matrix
        # Jacobian of the camera pose: delta_u / delta_alpha
        J = -J_camera_pose(X_a_3d_sel, K).view(N, M * 2, 6)  # (N*M, 2, 6)

        # x_b_2d = batched_x_2d_normalize(H, W, x_b_2d).view(N, H, W, 2)                              # (N, H, W, 2)
        #
        # # Wrap the image
        # I_b_wrap = batched_interp2d(I_b, x_b_2d)
        #
        # # Residual error
        # e = (I_a - I_b_wrap).view(N, C, H*W)                                                        # (N, C, H, W)
        # e = batched_index_select(e, 2, sel_a_idx)                                                   # (N, C, M)
        # e = e.transpose(1, 2).contiguous().view(N, M*C)                                             # (N, M, C)
        #
        # # Compute Jacobin Mat.
        # # Jacobi of Camera Pose: delta_u / delta_alpha
        # du_d_alpha = J_camera_pose(X_a_3d_sel, K).view(N * M, 2, 6)                                 # (N*M, 2, 6)
        #
        # # Jacobi of Image gradient: delta_I_b / delta_u
        # dI_du = batched_interp2d(I_b_grad, x_b_2d)                                                  # (N, 2*C, H, W)
        # dI_du = batched_index_select(dI_du.view(N, 2*C, H*W), 2, sel_a_idx)                         # (N, 2*C, M)
        # dI_du = torch.transpose(dI_du, 1, 2).contiguous().view(N * M, 2, C)                                      # (N*M, 2, C)
        # dI_du = torch.transpose(dI_du, 1, 2)                                                        # (N*M, C, 2)
        #
        # # J = -dI_b/du * du/d_alpha
        # J = -torch.bmm(dI_du, du_d_alpha).view(N, C*M, 6)

        # Compute the update parameters
        delta, delta_norm = gauss_newtown_update(J, e)  # (N, 6), (N, 1)
        max_norm = torch.max(delta_norm).item()
        if max_norm < opt_eps:
            print('break')
            break

        r_norm = torch.sum(e * e, dim=1) / M  # mean squared reprojection residual per point
        print('Itr:', itr, 'r_norm=', torch.sqrt(r_norm), "update_norm=",
              max_norm)

        # Apply the update to the se(3) vector
        alpha = alpha + delta

    return R, t
K = torch.from_numpy(K_set).cuda().view((N, 3, 3))
T_gt = torch.from_numpy(T_gt_set).cuda().view((N, 3, 4))

alpha = torch.tensor([1e-4, 1e-4, 1e-4, 0.0, 0.0, 0.0]).repeat(N).view((N, 6))
R_, t_ = compute_pose(I_a, d_a, sel_a_indices, I_b, K, alpha, T_gt)
print(R_, t_)
print(T_gt[:, :, :3], T_gt[:, :, 3])

d_a = d_a.view((N, H * W, 1))

# Points' 3D Position at Frame a
x_a_2d = x_2d_coords_torch(N, H, W).view(N, H * W, 2)
X_a_3d = batched_pi_inv(K, x_a_2d, d_a)

# ground-truth warp
X_b_3d_gt = batched_transpose(T_gt[:, :, :3], T_gt[:, :, 3], X_a_3d)
x_b_2d_gt, _ = batched_pi(K, X_b_3d_gt)

#
# """ Test on 235 to 237
# """
# frame = frames.frames[128]
# # Current frame
# file_name = frame['file_name']
# img = cv2.imread(os.path.join(img_dir, file_name + '.jpg')).astype(np.float32)
# (H, W, C) = img.shape
# img = img.transpose((2, 0, 1)) / 255.0
# depth = load_depth_from_pgm(os.path.join(depth_dir, file_name + '.pgm'))
# depth[depth < 1e-5] = 1e-5
# pose = frame['extrinsic_Tcw']
# K = K_from_frame(frame)
def preprocess_query(query_img,
                     query_depth,
                     query_ori_img,
                     query_Tcw,
                     ori_query_K,
                     scene_center,
                     rand_R,
                     out_dim=(48, 64)):
    # compute multiscale ground truth query_X_worlds & valid_masks
    query_X_worlds = []
    valid_masks = []
    out_H, out_W = out_dim
    ori_query_depth = query_depth.clone()
    N, C, H, W = query_depth.shape
    for i in range(5):
        query_depth_patch = F.unfold(query_depth,
                                     kernel_size=(H // out_H, W // out_W),
                                     stride=(H // out_H, W // out_W)).view(
                                         N, -1, out_H, out_W)
        mask = torch.gt(query_depth_patch, 1e-5)
        count = torch.sum(mask.float(), dim=1)
        query_depth_down = torch.sum(query_depth_patch * mask.float(), dim=1) /\
                           torch.where(torch.le(count, 1e-5),
                                       torch.full(count.shape, 1e6).cuda(),
                                       count)  # (N, 1, out_H, out_W)
        query_K = ori_query_K.clone().cuda()
        query_K[:, 0, 0] *= out_W / W
        query_K[:, 0, 2] *= out_W / W
        query_K[:, 1, 1] *= out_H / H
        query_K[:, 1, 2] *= out_H / H
        query_d = query_depth_down.view(N, out_H * out_W, 1)  # dim (N, H*W, 1)
        out_x_2d = x_2d_coords_torch(N, out_H, out_W).cuda().view(N, -1, 2)
        query_X_3d = batched_pi_inv(query_K, out_x_2d,
                                    query_d)  # dim (N, H*W, 3)
        query_Rwc, query_twc = batched_inv_pose(
            R=query_Tcw[:, :3, :3],
            t=query_Tcw[:, :3, 3].squeeze(-1))  # dim (N, 3, 3), (N, 3)
        query_X_world = batched_transpose(query_Rwc.cuda(), query_twc.cuda(),
                                          query_X_3d)  # dim (N, H*W, 3)
        query_X_world -= scene_center.view(N, 1, 3)
        query_X_world = batched_transpose(
            rand_R.cuda().expand(N, 3, 3),
            torch.zeros(1, 3, 1).cuda().expand(N, 3, 1),
            query_X_world)  # dim (N, H*W, 3), data augmentation
        query_X_world = query_X_world.permute(0, 2, 1).view(
            N, 3, out_H, out_W).contiguous()  # dim (N, 3, H, W)
        query_X_worlds.append(query_X_world.cuda())

        valid_masks.append(
            torch.gt(query_depth_down, 1e-5).cuda().view(N, out_H, out_W))

        #         if i == 3:
        #             query_X_worlds.append(query_X_world.cuda())
        #             valid_masks.append(torch.gt(query_depth_down, 1e-5).cuda().view(N, out_H, out_W))

        out_H //= 2
        out_W //= 2

    # compute norm_query_Tcw for normalized scene coordinate
    query_twc = query_twc.cuda() - scene_center.view(N, 3, 1)
    norm_query_Twc = torch.cat([query_Rwc.cuda(), query_twc],
                               dim=-1)  # dim (N, 3, 4)
    norm_query_Twc = torch.bmm(rand_R.cuda().expand(N, 3, 3),
                               norm_query_Twc)  # dim (N, 3, 4)
    query_Rcw, query_tcw = batched_inv_pose(
        R=norm_query_Twc[:, :3, :3],
        t=norm_query_Twc[:, :3, 3].squeeze(-1))  # dim (N, 3, 3), (N, 3)
    norm_query_Tcw = torch.cat([query_Rcw, query_tcw.view(N, 3, 1)],
                               dim=-1)  # dim (N, 3, 4)

    # compute down sampled query K
    out_H, out_W = out_dim
    query_K = ori_query_K.clone().cuda()
    query_K[:, 0, 0] *= out_W / W
    query_K[:, 0, 2] *= out_W / W
    query_K[:, 1, 1] *= out_H / H
    query_K[:, 1, 2] *= out_H / H

    return query_img.cuda(), query_X_worlds[::-1], valid_masks[::-1], query_ori_img.cuda(), \
           scene_center.cuda(), query_Tcw.cuda(), query_K.cuda(), rand_R.expand(N, 3, 3).cuda()
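# Hedged usage sketch (hypothetical demo helper, stand-in tensors, CUDA
# assumed): preprocess_query is meant to be paired with preprocess_scene above,
# reusing the scene_center and rand_R it returns so that the query's
# ground-truth coordinates live in the same normalized, augmented frame.
# out_dim is the finest of the five pyramid levels; the returned lists are
# ordered coarsest-first.
def _demo_preprocess_query():
    H, W = 192, 256
    query_img = torch.rand(1, 3, 48, 64)
    query_ori_img = torch.rand(1, 3, H, W)
    query_depth = torch.rand(1, 1, H, W).cuda() + 0.5         # strictly positive depth
    query_Tcw = torch.eye(4)[:3, :].unsqueeze(0)
    ori_query_K = torch.tensor([[[250.0, 0.0, 128.0],
                                 [0.0, 250.0, 96.0],
                                 [0.0, 0.0, 1.0]]])
    out = preprocess_query(query_img, query_depth, query_ori_img, query_Tcw,
                           ori_query_K, scene_center=torch.zeros(1, 3).cuda(),
                           rand_R=torch.eye(3).unsqueeze(0), out_dim=(48, 64))
    query_img_c, query_X_worlds, valid_masks = out[0], out[1], out[2]
    # five supervision levels, coarsest-first
    assert len(query_X_worlds) == 5 and len(valid_masks) == 5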