Ejemplo n.º 1
0
    def forward(self,
                Gs,
                images,
                depths,
                intrinsics,
                graph=None,
                num_steps=12):
        """Refine the SE3 or Sim3 poses ``Gs`` over ``num_steps`` update rounds.

        Every round projects points along the graph edges, looks up
        correlation features at the projected pixels, asks ``self.update`` for
        a correction and a weight, and then runs three ``MoBA`` passes on the
        poses.

        Args:
            Gs: initial poses; detached at the start of every round.
            images: frames handed to ``self.extract_features``.
            depths: depth maps; clamped to [0.1, 1000.0], inverted, and
                subsampled with stride 8 (offset 3) on the last two axes.
            intrinsics: camera intrinsics; divided by 8.0.
            graph: frame graph consumed by ``graph_to_edge_list``. When
                ``None``, a two-frame graph with edges 0->1 and 1->0 is used.
            num_steps: number of update rounds.

        Returns:
            ``(Gs_list, residual_list)``: the poses after each round and the
            matching residuals (target minus reprojection, multiplied by the
            validity mask).
        """
        if graph is None:
            # Default: a single pair of frames connected in both directions.
            graph = OrderedDict([(0, [1]), (1, [0])])

        # Neither the keyframe indices nor the third edge list is read below;
        # the calls are kept exactly as they were.
        _keyframes = keyframe_indicies(graph)
        ii, jj, _kk = graph_to_edge_list(graph)

        # Inverse-depth parameterization on the subsampled grid.
        disps = 1.0 / depths.clamp(min=0.1, max=1000.0)[:, :, 3::8, 3::8]
        intrinsics = intrinsics / 8.0

        fmaps, net, inp = self.extract_features(images)
        corr_fn = CorrBlock(fmaps[:, ii], fmaps[:, jj], num_levels=4, radius=3)

        pose_history = []
        target_history = []
        residual_history = []

        for _round in range(num_steps):
            Gs = Gs.detach()

            projected, _ = pops.projective_transform(Gs, disps, intrinsics,
                                                     ii, jj)
            pixels, zinv_projected = projected.split([2, 1], dim=-1)

            # Difference between the sampled and the projected inverse depth.
            zinv_sampled = sample_depths(disps[:, jj], pixels)
            dz = (zinv_sampled - zinv_projected).clamp(-1.0, 1.0)

            net, delta, weight = self.update(net, inp, corr_fn(pixels), dz)
            target = projected + delta

            for _ in range(3):
                Gs = MoBA(target, weight, Gs, disps, intrinsics, ii, jj)

            reprojected, valid = pops.projective_transform(
                Gs, disps, intrinsics, ii, jj)

            pose_history.append(Gs)
            target_history.append(target)
            residual_history.append(valid * (target - reprojected))

        return pose_history, residual_history
Ejemplo n.º 2
0
    def moba(self, num_steps=5, is_init=False):
        """Run motion-only bundle adjustment on the factor-graph poses.

        For ``num_steps`` rounds the update operator proposes per-edge targets
        and weights, three ``MoBA`` passes refine the poses against them, and
        the reprojection residual is stored back on ``self.factors.residu``.
        The refined poses are then written into ``self.poses`` and, when a
        frontend is attached, pushed to it.

        Args:
            num_steps: number of update rounds.
            is_init: accepted but not read by this method.
        """
        ii, jj = self.factors.ii, self.factors.jj
        # Only poses up to the largest ``jj`` index are read and written back.
        n_active = jj.max() + 1

        with autocast(enabled=True):
            fmap_src = self.fmaps[:, ii % self.mem]
            fmap_dst = self.fmaps[:, jj % self.mem]
            corr_fn = CorrBlock(fmap_src, fmap_dst, num_levels=4, radius=3)

            poses = self.poses[:, :n_active]

            # 1.0 where the source disparity exceeds 0.01, else 0.0; appended
            # as an extra channel to both update-operator inputs.
            mask = (self.disps[:, ii] > 0.01).float()
            mask_channel = mask[:, :, None]

            with autocast(enabled=False):
                coords, _ = pops.projective_transform(
                    poses, self.disps, self.intrinsics, ii, jj)

            for _round in range(num_steps):
                corr = torch.cat([corr_fn(coords[..., :2]), mask_channel],
                                 dim=2)

                with autocast(enabled=False):
                    prev_residual = self.factors.residu.permute(0, 1, 4, 2, 3)
                    flow = torch.cat(
                        [prev_residual.clamp(-32.0, 32.0), mask_channel],
                        dim=2)

                hidden, delta, weight = self.update(
                    self.factors.hidden, self.factors.inputs, corr, flow)
                self.factors.hidden = hidden

                with autocast(enabled=False):
                    target = coords + delta
                    # The third channel gets zero weight in the solve.
                    weight[..., 2] = 0.0

                    for _ in range(3):
                        poses = MoBA(target, weight, poses, self.disps,
                                     self.intrinsics, ii, jj,
                                     self.fixed_poses)

                    coords, _ = pops.projective_transform(
                        poses, self.disps, self.intrinsics, ii, jj)
                    self.factors.residu = (target - coords)[..., :2]

        self.poses[:, :n_active] = poses

        # update visualization
        ixs = torch.cat([ii, jj], 0)
        if self.frontend is not None:
            for ix in ixs.cpu().numpy():
                pose = self.poses[:, ix].inv()[0].data
                self.frontend.update_pose(ix, pose)
Ejemplo n.º 3
0
    def forward(self, Gs, images, depths, intrinsics, graph=None, num_steps=12):
        """Refine the SE3 or Sim3 poses ``Gs`` over ``num_steps`` update rounds.

        Every round feeds correlation features and the previous reprojection
        residual to ``self.update``, adds the predicted correction to the
        current projection to form a target, and runs three ``MoBA`` passes on
        the poses.

        Args:
            Gs: initial poses; detached at the start of every round.
            images: frames handed to ``self.extract_features``.
            depths: depth maps; subsampled with stride 8 (offset 3) on the
                last two axes. Values above 0.1 are inverted; other values are
                kept as they are and masked out of the returned residuals.
            intrinsics: camera intrinsics; divided by 8.
            graph: frame graph consumed by ``graph_to_edge_list``. When
                ``None``, a two-frame graph with edges 0->1 and 1->0 is used.
            num_steps: number of update rounds.

        Returns:
            ``(Gs_list, residual_list)``: the poses after each round and the
            matching masked 2-channel residuals.
        """
        if graph is None:
            # ``graph`` defaults to None, but the helpers below need a real
            # graph; fall back to a single frame pair linked both ways.
            from collections import OrderedDict
            graph = OrderedDict()
            graph[0] = [1]
            graph[1] = [0]

        # The keyframe indices are not read in this method; the call is kept
        # because its side effects, if any, are not visible from here.
        keyframe_indicies(graph)
        ii, jj, _ = graph_to_edge_list(graph)

        depths = depths[:, :, 3::8, 3::8]
        intrinsics = intrinsics / 8
        mask = (depths > 0.1).float()
        disps = torch.where(depths > 0.1, 1.0 / depths, depths)

        fmaps, net, inp = self.extract_features(images)
        net, inp = net[:, ii], inp[:, ii]
        corr_fn = CorrBlock(fmaps[:, ii], fmaps[:, jj], num_levels=4, radius=3)

        coords, valid_mask = pops.projective_transform(Gs, disps, intrinsics, ii, jj)
        residual = torch.zeros_like(coords[..., :2])

        # Per-edge views of the mask do not change between rounds.
        mask_channel = mask[:, ii, None]
        mask_pixels = mask[:, ii].unsqueeze(-1)

        Gs_list, residual_list = [], []
        for _step in range(num_steps):
            Gs = Gs.detach()
            coords = coords.detach()
            residual = residual.detach()

            corr = corr_fn(coords[..., :2])
            flow = residual.permute(0, 1, 4, 2, 3).clamp(-32.0, 32.0)

            corr = torch.cat([corr, mask_channel], dim=2)
            flow = torch.cat([flow, mask_channel], dim=2)
            net, delta, weight = self.update(net, inp, corr, flow)

            target = coords + delta
            # The third channel gets zero weight in the solve.
            weight[..., 2] = 0.0

            for _ in range(3):
                Gs = MoBA(target, weight, Gs, disps, intrinsics, ii, jj)

            coords, valid_mask = pops.projective_transform(Gs, disps, intrinsics, ii, jj)
            residual = (target - coords)[..., :2]

            Gs_list.append(Gs)
            residual_list.append(valid_mask * mask_pixels * residual)

        return Gs_list, residual_list
Ejemplo n.º 4
0
def reproj_test(args, N=2):
    """Visual check that the projective transform maps points correctly.

    For every batch of the dataset, frame 1 is resampled at the coordinates
    obtained by projecting along the edge 0 -> 1, and the frames are displayed
    with ``show_image`` for manual inspection.

    Args:
        args: namespace whose ``datasets`` attribute selects the dataset.
        N: number of frames per sample requested from the dataset.
    """
    dataset = dataset_factory(args.datasets, n_frames=N)
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)

    for batch in loader:
        on_gpu = [tensor.to('cuda') for tensor in batch]
        images, poses, depths, intrinsics = on_gpu

        inv_poses = SE3(poses).inv()
        inv_depths = 1.0 / depths

        coords, _ = pops.projective_transform(inv_poses, inv_depths,
                                              intrinsics, [0], [1])
        warped = bilinear_sampler(images[:, [1]], coords[..., [0, 1]])

        reference = images[0, 0]
        neighbour = images[0, 1]

        # This pair should show the camera motion.
        show_image(reference)
        show_image(neighbour)

        # This pair should show the motion removed by reprojection / warping.
        show_image(reference)
        show_image(warped[0, 0])
Ejemplo n.º 5
0
def MoBA(target,
         weight,
         poses,
         disps,
         intrinsics,
         ii,
         jj,
         fixedp=1,
         lm=0.0001,
         ep=0.1):
    """Motion-only bundle adjustment: one solve-and-update step on the poses.

    Builds a block linear system from the weighted reprojection residuals of
    the edges ``(ii, jj)``, solves it with ``block_solve``, and applies the
    solution to every pose except the first ``fixedp`` via ``retr``.

    Args:
        target: desired projected coordinates for every edge.
        weight: per-residual weights; multiplied by the validity mask
            returned by the projection.
        poses: pose group object providing ``shape``, ``manifold_dim``,
            slicing and ``retr``.
        disps: disparities passed through to the projection.
        intrinsics: camera intrinsics passed through to the projection.
        ii: source pose index of every edge.
        jj: target pose index of every edge.
        fixedp: number of leading poses left unchanged.
        lm: forwarded to ``block_solve`` (presumably a damping term; confirm
            against ``block_solve``).
        ep: forwarded to ``block_solve`` (presumably a damping term; confirm
            against ``block_solve``).

    Returns:
        The poses with the first ``fixedp`` entries unchanged and the
        remaining entries updated.
    """

    B, M = poses.shape[:2]
    D = poses.manifold_dim
    N = ii.shape[0]

    ### 1: compute jacobians and residuals ###
    coords, valid, (Ji, Jj) = pops.projective_transform(poses,
                                                        disps,
                                                        intrinsics,
                                                        ii,
                                                        jj,
                                                        jacobian=True)

    # Flatten each edge's residuals and weights into column vectors.
    r = (target - coords).view(B, N, -1, 1)
    w = (valid * weight).view(B, N, -1, 1)

    ### 2: construct linear system ###
    Ji = Ji.view(B, N, -1, D)
    Jj = Jj.view(B, N, -1, D)
    # Weighted, transposed Jacobians; the weights are scaled by 0.001.
    wJiT = (.001 * w * Ji).transpose(2, 3)
    wJjT = (.001 * w * Jj).transpose(2, 3)

    # Per-edge D x D blocks of J^T W J for each pairing of the two poses.
    Hii = torch.matmul(wJiT, Ji)
    Hij = torch.matmul(wJiT, Jj)
    Hji = torch.matmul(wJjT, Ji)
    Hjj = torch.matmul(wJjT, Jj)

    # Per-edge right-hand sides J^T W r.
    vi = torch.matmul(wJiT, r).squeeze(-1)
    vj = torch.matmul(wJjT, r).squeeze(-1)

    # only optimize keyframe poses
    # Shift the indices so that the first free pose becomes index 0. Edges
    # touching a fixed pose get a negative index here; presumably the
    # safe_scatter_add_* helpers drop those (confirm against the helpers).
    M = M - fixedp
    ii = ii - fixedp
    jj = jj - fixedp

    # Accumulate the per-edge blocks into an M x M grid of D x D blocks.
    H = torch.zeros(B, M * M, D, D, device=target.device)
    safe_scatter_add_mat(H, Hii, ii, ii, B, M, D)
    safe_scatter_add_mat(H, Hij, ii, jj, B, M, D)
    safe_scatter_add_mat(H, Hji, jj, ii, B, M, D)
    safe_scatter_add_mat(H, Hjj, jj, jj, B, M, D)
    H = H.reshape(B, M, M, D, D)

    v = torch.zeros(B, M, D, device=target.device)
    safe_scatter_add_vec(v, vi, ii, B, M, D)
    safe_scatter_add_vec(v, vj, jj, B, M, D)

    ### 3: solve the system + apply retraction ###
    dx = block_solve(H, v, ep=ep, lm=lm)

    # Only the poses after the first ``fixedp`` receive the update.
    poses1, poses2 = poses[:, :fixedp], poses[:, fixedp:]
    poses2 = poses2.retr(dx)

    poses = lietorch.cat([poses1, poses2], dim=1)
    return poses
Ejemplo n.º 6
0
    def forward(self, poses, images, depths, intrinsics, num_steps=12):
        """Refine keyframe poses over ``num_steps`` rounds and return all poses.

        A ``KeyframeGraph`` selects the keyframes and the edges between them.
        Every round runs the update operator over the edges in chunks of 64,
        concatenates the predicted targets and weights, runs three ``MoBA``
        passes on the keyframe poses, and maps them back to per-frame poses
        with ``keyframe_graph.get_poses``. When a frontend is attached, the
        point clouds and poses are pushed to it.

        Args:
            poses: initial poses for all frames.
            images: input frames.
            depths: depth maps; after keyframe selection they are subsampled
                with stride 8 (offset 3) on the last two axes, and values
                above 0.1 are inverted.
            intrinsics: camera intrinsics; the keyframe intrinsics are
                divided by 8.
            num_steps: number of refinement rounds.

        Returns:
            The poses produced by ``keyframe_graph.get_poses`` after the last
            round, or the input ``poses`` unchanged when ``num_steps`` is 0.
        """
        keyframe_graph = KeyframeGraph(images, poses, depths, intrinsics)
        images, Gs, depths, intrinsics = keyframe_graph.get_keyframes()

        images = images.cuda()
        depths = depths.cuda()

        if self.frontend is not None:
            self.frontend.reset()
            for k, ix in enumerate(keyframe_graph.ixs):
                self.add_point_cloud(ix,
                                     images[:, k],
                                     Gs[:, k],
                                     depths[:, k],
                                     intrinsics[:, k],
                                     s=4)
            for k in range(poses.shape[1]):
                self.frontend.update_pose(k, poses[:, k].inv()[0].data)

        graph = keyframe_graph.get_graph()
        ii, jj, _ = graph_to_edge_list(graph)

        # images and depths are already on the GPU at this point.
        images = normalize_images(images)
        depths = depths[:, :, 3::8, 3::8]
        mask = (depths > 0.1).float()
        disps = torch.where(depths > 0.1, 1.0 / depths, depths)
        intrinsics = intrinsics / 8

        # Zeroes the weight of the third channel; built once rather than
        # once per chunk per round.
        xy_only = torch.as_tensor([1.0, 1.0, 0.0]).cuda()

        # Number of edges passed through the update operator at a time.
        chunk = 64
        num_edges = ii.shape[0]

        with autocast(True):

            fmaps, net, inp = self.extract_features(images)
            net = net[:, ii]

            # alternate corr implementation uses less memory but 4x slower
            corr_fn = AltCorrBlock(fmaps.float(), (ii, jj),
                                   num_levels=4,
                                   radius=3)

            with autocast(False):
                coords, _ = pops.projective_transform(
                    Gs, disps, intrinsics, ii, jj)
                residual = torch.zeros_like(coords[..., :2])

            for step in range(num_steps):
                print("Global refinement iteration #{}".format(step))
                targets_list = []
                weights_list = []

                for start in range(0, num_edges, chunk):
                    sl = slice(start, start + chunk)
                    ii1 = ii[sl]
                    jj1 = jj[sl]

                    corr1 = corr_fn(coords[:, sl, :, :, :2], ii1, jj1)
                    flow1 = residual[:, sl].permute(0, 1, 4, 2,
                                                    3).clamp(-32.0, 32.0)

                    corr1 = torch.cat([corr1, mask[:, ii1, None]], dim=2)
                    flow1 = torch.cat([flow1, mask[:, ii1, None]], dim=2)

                    net1, delta, weight = self.update(net[:, sl],
                                                      inp[:, ii1],
                                                      corr1, flow1)
                    net[:, sl] = net1

                    targets_list.append(coords[:, sl] + delta.float())
                    weights_list.append(weight.float() * xy_only)

                target = torch.cat(targets_list, 1)
                weight = torch.cat(weights_list, 1)

                with autocast(False):
                    for _ in range(3):
                        Gs = MoBA(target,
                                  weight,
                                  Gs,
                                  disps,
                                  intrinsics,
                                  ii,
                                  jj,
                                  lm=0.00001,
                                  ep=.01)

                    coords, _ = pops.projective_transform(
                        Gs, disps, intrinsics, ii, jj)
                    residual = (target - coords)[..., :2]

                poses = keyframe_graph.get_poses(Gs)
                if self.frontend is not None:
                    for k in range(poses.shape[1]):
                        self.frontend.update_pose(k, poses[:, k].inv()[0].data)

        return poses
Ejemplo n.º 7
0
 def transform_project(self, ii, jj, **kwargs):
     """Project along edges ``(ii, jj)`` using this object's stored state.

     Forwards ``self.poses``, ``self.disps`` and ``self.intrinsics`` together
     with the edge indices and any extra keyword arguments to
     ``pops.projective_transform`` and returns its result unchanged.
     """
     state = (self.poses, self.disps, self.intrinsics)
     return pops.projective_transform(*state, ii, jj, **kwargs)