示例#1
0
文件: utils.py 项目: felja633/RLLReg
def GARPointList(t1_list, t2_list, R1_list, R2_list, V, c, alpha):
    """Robust pairwise registration error over lists of point-set collections.

    For every scene, each ordered view pair (i, j) with i < j is evaluated:
    the points Vb[j] are mapped into view i's frame using both the estimated
    and the ground-truth relative poses, and the per-point distances between
    the two mappings are fed through the robust loss ``lossfun``.

    Args:
        t1_list, R1_list: estimated translations / rotations per scene.
        t2_list, R2_list: ground-truth translations / rotations per scene.
        V: point sets, one list of point clouds per scene.
        c: scale argument forwarded to ``lossfun``.
        alpha: robustness argument forwarded to ``lossfun``.

    Returns:
        TensorListList with one mean robust loss per view pair per scene.
    """
    errors = TensorListList()

    for t1, t2, R1, R2, Vb in zip(t1_list, t2_list, R1_list, R2_list, V):
        num_views = len(t1)
        pair_errors = TensorList()
        for i in range(num_views - 1):
            # transpose of view i's rotations, shared by all pairs (i, j)
            R1i_T = R1[i].permute(1, 0)
            R2i_T = R2[i].permute(1, 0)
            for j in range(i + 1, num_views):
                # estimated relative pose mapping view j into view i
                R_est = R1i_T @ R1[j]
                t_est = R1i_T @ t1[j] - R1i_T @ t1[i]

                # ground-truth relative pose mapping view j into view i
                R_gt = R2i_T @ R2[j]
                t_gt = R2i_T @ t2[j] - R2i_T @ t2[i]

                residual = (R_est @ Vb[j] + t_est) - (R_gt @ Vb[j] + t_gt)
                dist = (residual * residual).sum(dim=-2).sqrt()
                pair_errors.append(lossfun(dist, alpha=alpha, scale=c).mean())

        errors.append(pair_errors)

    return errors
示例#2
0
def empirical_estimate(points, num_neighbors):
    """Estimate per-point observation weights from local neighborhood shape.

    For every point, the ``num_neighbors`` nearest neighbors are gathered,
    centered, and the eigenvalues of the local covariance are computed; the
    weight is the geometric mean of the two largest eigenvalues, median
    filtered over each neighborhood.

    Args:
        points: nested list structure of point clouds (project TensorListList);
            each cloud is stored as (3, N) and transposed before stacking.
        num_neighbors: k used for the k-NN neighborhood.

    Returns:
        TensorListList of per-point weight tensors, one per point cloud.
    """
    ps, batch = points.permute(1, 0).cat_tensors()

    N = ps.shape[0]
    val = knn(ps.contiguous(),
              ps.contiguous(),
              batch_x=batch,
              batch_y=batch,
              k=num_neighbors)
    A = ps[val[1, :]].reshape(N, num_neighbors, 3)
    A = A - A.mean(dim=1, keepdim=True)
    Asqr = A.permute(0, 2, 1).bmm(A)
    # torch.Tensor.symeig was deprecated and removed in torch 2.0;
    # torch.linalg.eigvalsh returns eigenvalues in the same ascending order.
    sigma = torch.linalg.eigvalsh(Asqr.cpu())
    # geometric mean of the two largest eigenvalues of the local covariance
    w = (sigma[:, 2] * sigma[:, 1]).sqrt()
    val = val[1, :].reshape(N, num_neighbors)
    w, _ = torch.median(w[val].to(ps.device), dim=1, keepdim=True)

    weights = TensorListList()
    bi = 0
    for point_list in points:
        ww = TensorList()
        # one weight tensor per point cloud, selected by its batch index
        for _ in point_list:
            ww.append(w[batch == bi])
            bi = bi + 1

        weights.append(ww)

    return weights
示例#3
0
    def __init__(self, parameters, K, features, s, mu, repeat_list):
        """Initialize the feature model.

        Args:
            parameters: parameter container for the model.
            K: number of mixture components.
            features: per-scene feature lists used for random initialization.
            s: scale value; stored squared as ``self.s2``.
            mu: optional explicit component means; if non-empty they replace
                the random initialization.
            repeat_list: repeat counts used when wrapping ``mu``.
        """
        self.params = parameters
        # random initialization first; overwritten below when explicit
        # means were supplied
        self.mu = self.initialize_mu(K, features)
        if len(mu) > 0:
            self.mu = TensorListList(mu, repeat=repeat_list)

        self.K = K
        self.repeat_list = repeat_list
        self.s2 = s * s
        self.local_posterior = 1
示例#4
0
文件: utils.py 项目: felja633/RLLReg
def L2sqrList(R1_list, R2_list):
    """Squared Frobenius distance between estimated and ground-truth
    relative rotations for every view pair in each scene.

    Returns a TensorListList with one (1,)-shaped error per pair.
    """
    err = TensorListList()
    for R1, R2 in zip(R1_list, R2_list):
        num_views = len(R1)
        pair_err = TensorList()
        for i in range(num_views - 1):
            for j in range(i + 1, num_views):
                # relative rotation (j -> i): estimated minus ground truth
                delta = R1[i].permute(1, 0) @ R1[j] - R2[i].permute(
                    1, 0) @ R2[j]
                pair_err.append(
                    (delta * delta).sum(dim=-1).sum(dim=-1).unsqueeze(0))
        err.append(pair_err)

    return err
示例#5
0
def get_init_transformation_list(pcds, mean_init=True):
    """Build initial rigid transforms (identity R, optional mean-shift t).

    Args:
        pcds: nested point-cloud structure (TensorListList), clouds as (3, N).
        mean_init: if True, translations align each cloud's centroid with the
            first cloud's centroid; otherwise translations are zero.

    Returns:
        (init_R, init_t): identity rotations and the initial translations.
    """
    device = pcds[0][0].device
    if mean_init:
        # shift every cloud so its mean coincides with the first cloud's mean
        m_pcds = pcds.mean(dim=-1, keepdim=True)
        m_target = TensorList([m[0] for m in m_pcds])
        init_t = -(m_pcds - m_target)
    else:
        # zero translation per cloud; `_` avoids shadowing the outer index
        init_t = TensorListList([
            TensorList([torch.zeros(3, 1).to(device)
                        for _ in range(len(pcd_list))])
            for pcd_list in pcds
        ])

    init_R = TensorListList([
        TensorList([torch.eye(3, 3).to(device)
                    for _ in range(len(pcd_list))])
        for pcd_list in pcds
    ])
    return init_R, init_t
示例#6
0
文件: utils.py 项目: felja633/RLLReg
def L2sqrTransList(t1_list, t2_list, R1_list, R2_list):
    """Squared L2 distance between estimated and ground-truth relative
    translations for every view pair in each scene."""
    err = TensorListList()
    for t1, t2, R1, R2 in zip(t1_list, t2_list, R1_list, R2_list):
        num_views = len(t1)
        pair_err = TensorList()
        for i in range(num_views - 1):
            for j in range(i + 1, num_views):
                # estimated relative translation mapping view j into view i
                t_est = R1[i].permute(1, 0) @ t1[j] - R1[i].permute(
                    1, 0) @ t1[i]
                # ground-truth relative translation mapping view j into view i
                t_gt = R2[i].permute(1, 0) @ t2[j] - R2[i].permute(
                    1, 0) @ t2[i]

                delta = t_est - t_gt
                pair_err.append((delta * delta).sum(dim=-2))
        err.append(pair_err)
    return err
示例#7
0
    def cluster_features(self, features, num_clusters):
        """Quantize point features into one-hot cluster assignments.

        Runs KMeans on the CPU over the concatenated features of each scene,
        then splits the resulting one-hot labels back into per-cloud chunks.

        Args:
            features: per-scene lists of feature tensors (one per cloud),
                concatenated along dim 0 before clustering.
            num_clusters: number of KMeans clusters / one-hot channels.

        Returns:
            TensorListList of one-hot label tensors, permuted to
            (num_clusters, num_points) and cast to float.
        """
        feature_labels_LL = TensorListList()
        for f in features:
            feature_labels_L = TensorList()
            fcat = torch.cat([fi for fi in f])
            fcat = fcat.to("cpu").numpy()
            # fixed random_state keeps the clustering deterministic
            kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(fcat)
            labels = torch.from_numpy(kmeans.labels_)
            onehot = torch.nn.functional.one_hot(
                labels.long(), num_clusters).to(self.params.device)
            cnt = 0
            # split the concatenated labels back into per-cloud chunks
            for fi in f:
                feature_labels_L.append(onehot[cnt:cnt + fi.shape[0]])
                cnt += fi.shape[0]

            feature_labels_LL.append(feature_labels_L.permute(1, 0).float())

        return feature_labels_LL
示例#8
0
class MultinomialModel(BaseFeatureModel):
    """Multinomial feature-distribution model over K mixture components.

    Holds one (K, num_feature_clusters) distribution per scene, wrapped in a
    TensorListList repeated according to ``repeat_list``.
    """

    def __init__(self, parameters, K, features, repeat_list):
        self.params = parameters
        self.repeat_list = repeat_list
        self.initialize(K, features)

    def initialize(self, K, features):
        """Initialize per-component feature distributions.

        ``init_method == "dirichlet"`` samples each row from a Dirichlet whose
        concentration is the relative frequency of the feature clusters;
        "uniform" (and any unrecognized method) falls back to a uniform
        distribution.
        """
        init_method = self.params.get("init_method", "dirichlet")
        if init_method == "dirichlet":
            init_distr_list = TensorList()
            # relative frequency of each feature cluster over all scenes
            relative_distr = features.sum(dim=1).sum_list()
            relative_distr = relative_distr / relative_distr.sum()
            for r in relative_distr:
                dirichlet = torch.distributions.dirichlet.Dirichlet(r)
                init_distr_list.append(dirichlet.sample((K, )))
        else:
            # "uniform" and unknown methods: one uniform distribution per scene
            # (previously two identical branches; merged to remove duplication)
            init_distr_list = [
                torch.ones(K, self.params.num_feature_clusters).float() /
                self.params.num_feature_clusters
                for _ in range(len(features))
            ]

        self.distr = TensorListList(init_distr_list, repeat=self.repeat_list)

    def to(self, dev):
        """Move the distributions to device ``dev``."""
        self.distr = self.distr.to(dev)

    def posteriors(self, y):
        """Per-point component likelihoods given one-hot features ``y``."""
        p = y.permute(1, 0) @ self.distr.permute(1, 0)  # marginalization
        return p

    def maximize(self, ap, ow, y, den):
        """M-step: reweight and renormalize the distributions row-wise."""
        tmp = y / den.permute(1, 0)
        as_sum = (tmp @ (ow * ap)).sum_list()
        self.distr = self.distr * TensorListList(as_sum.permute(1, 0),
                                                 repeat=self.repeat_list)
        self.distr = self.distr / self.distr.sum(dim=1, keepdims=True)
示例#9
0
文件: utils.py 项目: felja633/RLLReg
def global_to_relative(R, t, ref_index):
    """Re-express global poses relative to the pose at ``ref_index``.

    For each scene, every (R_i, t_i) is mapped into the coordinate frame of
    the reference view, so the reference itself becomes the identity pose.
    """
    R_rel = TensorListList()
    t_rel = TensorListList()
    for Ri, ti in zip(R, t):
        ref_R_T = Ri[ref_index].permute(1, 0)
        ref_t = ti[ref_index]
        R_rel.append(ref_R_T @ Ri)
        t_rel.append(ref_R_T @ ti - ref_R_T @ ref_t)

    return R_rel, t_rel
示例#10
0
class VonMisesModelList(BaseFeatureModel):
    """Von Mises-Fisher-style feature model over lists of feature sets.

    Component means ``mu`` are unit vectors; posteriors are computed as
    exponentiated scaled inner products between features and means.
    """

    def __init__(self, parameters, K, features, s, mu, repeat_list):
        """Initialize with random unit means, optionally overridden by ``mu``.

        Args:
            parameters: parameter container for the model.
            K: number of mixture components.
            features: per-scene feature lists used for random initialization.
            s: scale value; stored squared as ``self.s2``.
            mu: optional explicit means; used when non-empty.
            repeat_list: repeat counts used when wrapping ``mu``.
        """
        self.params = parameters
        self.mu = self.initialize_mu(K, features)
        if len(mu) > 0:
            self.mu = TensorListList(mu, repeat=repeat_list)

        self.K = K
        self.repeat_list = repeat_list
        self.s2 = s * s
        self.local_posterior = 1

    def to(self, dev):
        """Move the component means to device ``dev``."""
        self.mu = self.mu.to(dev)

    def initialize_mu(self, K, features):
        """Draw K random unit-norm mean directions per feature set."""
        X = TensorList()
        for TV in features:
            Xi = np.random.randn(TV[0].shape[0], K).astype(np.float32)
            Xi = torch.from_numpy(Xi).to(TV[0].device)
            # normalize columns so every mean lies on the unit sphere
            Xi = Xi / torch.norm(Xi, dim=0, keepdim=True)
            X.append(Xi.permute(1, 0))

        return X

    def posteriors(self, y):
        """Unnormalized posteriors exp((y^T mu^T) / s2)."""
        log_p = y.permute(1, 0) @ TensorListList(self.mu.permute(1, 0),
                                                 repeat=self.repeat_list)
        p = log_p / self.s2
        return p.exp()

    def maximize(self, a, y, den):
        """M-step: means become renormalized responsibility-weighted sums.

        NOTE(review): ``den`` is unused here — presumably kept for interface
        parity with other feature models; confirm.
        """
        self.mu = ((y @ a).sum_list()).permute(1, 0)
        self.mu = self.mu / self.mu.norm(dim=-1, keepdim=True)
        return

    def detach(self):
        """Detach the means from the autograd graph."""
        self.mu = self.mu.detach()
示例#11
0
    def initialize(self, K, features):
        """Initialize per-component feature distributions.

        ``init_method == "dirichlet"`` samples each row from a Dirichlet whose
        concentration is the relative frequency of the feature clusters;
        "uniform" (and any unrecognized method) falls back to a uniform
        distribution.
        """
        init_method = self.params.get("init_method", "dirichlet")
        if init_method == "dirichlet":
            init_distr_list = TensorList()
            # relative frequency of each feature cluster over all scenes
            relative_distr = features.sum(dim=1).sum_list()
            relative_distr = relative_distr / relative_distr.sum()
            for r in relative_distr:
                dirichlet = torch.distributions.dirichlet.Dirichlet(r)
                init_distr_list.append(dirichlet.sample((K, )))
        else:
            # "uniform" and unknown methods: one uniform distribution per scene
            # (previously two identical branches; merged to remove duplication)
            init_distr_list = [
                torch.ones(K, self.params.num_feature_clusters).float() /
                self.params.num_feature_clusters
                for _ in range(len(features))
            ]

        self.distr = TensorListList(init_distr_list, repeat=self.repeat_list)
示例#12
0
def collate_tensorlist(batch):
    """Collate a batch of (data, info) dict pairs into two merged dicts.

    Values that are TensorList instances are gathered into TensorListList
    objects; all other values are gathered into plain Python lists.

    Args:
        batch: iterable of (data_dict, info_dict) pairs.

    Returns:
        (out, info): the two collated dictionaries.
    """
    def _merge(dst, src, tensorlist_keys):
        # Append each value of `src` into `dst`, creating the container on
        # first sight and recording which keys hold TensorLists.
        for k, v in src.items():
            if k not in dst:
                if isinstance(v, TensorList):
                    dst[k] = TensorList()
                    tensorlist_keys.append(k)
                else:
                    dst[k] = []
            dst[k].append(v)

    out = dict()
    info = dict()
    tensorlistkeysout = []
    tensorlistkeysinfo = []
    for cb, ib in batch:
        _merge(out, cb, tensorlistkeysout)
        _merge(info, ib, tensorlistkeysinfo)

    # promote collected TensorLists to TensorListList containers
    for k in tensorlistkeysout:
        out[k] = TensorListList(out[k])

    for k in tensorlistkeysinfo:
        info[k] = TensorListList(info[k])

    return out, info
示例#13
0
    def extract_features(self, x):
        """Extract per-cloud features (and optional attention) from a sparse
        feature extractor, then cluster the features.

        Args:
            x: dict with key 'coordinates' holding nested point-cloud lists.

        Returns:
            dict with keys 'features' (clustered one-hot labels), 'att',
            'coordinates_ds' (downsampled coords on the configured device),
            and 'indices' (per-cloud downsampling indices).
        """
        coords = x['coordinates']

        sinput, inds_list = self.create_sparse_tensors(coords)
        features_dict = self.feature_extractor(sinput)
        features = features_dict["features"]
        if "attention" in features_dict.keys():
            att = features_dict["attention"]

        batch_indices = list(features.coords_man.get_batch_indices())
        features_LL = TensorListList()
        att_LL = TensorListList()
        coords_ds_LL = TensorListList()
        inds_LL = []
        ind_cnt = 0
        for coords_b in coords:
            features_L = TensorList()
            att_L = TensorList()
            coords_ds_L = TensorList()
            inds_L = []
            for coords_s in coords_b:
                # select this cloud's rows; if the first column is not the
                # batch index, fall back to the last column
                mask = features.C[:, 0] == batch_indices[ind_cnt]
                if mask.int().sum() != inds_list[ind_cnt].shape[0]:
                    mask = features.C[:, -1] == batch_indices[ind_cnt]

                f = features.F[mask]
                assert f.shape[0] == inds_list[ind_cnt].shape[0]
                if "attention" in features_dict.keys():
                    a = att.F[mask]
                    assert a.shape[0] == inds_list[ind_cnt].shape[0]
                    att_L.append(a)

                features_L.append(f.permute(1, 0))

                coords_ds_L.append(coords_s[:, inds_list[ind_cnt]])
                inds_L.append(inds_list[ind_cnt])

                ind_cnt = ind_cnt + 1

            features_LL.append(features_L)
            att_LL.append(att_L)
            coords_ds_LL.append(coords_ds_L)
            inds_LL.append(inds_L)

        out = dict()
        out['features'] = self.cluster_features(
            features_LL,
            self.params.feature_distr_parameters.num_feature_clusters)
        out['att'] = att_LL
        out['coordinates_ds'] = coords_ds_LL.to(self.params.device)
        out['indices'] = inds_LL

        return out
示例#14
0
    def forward(self, x):
        """Extract features, cluster them, run registration, and time it.

        Args:
            x: dict with key 'coordinates' holding nested point-cloud lists.

        Returns:
            registration output dict augmented with 'time', 'features',
            'indices' and 'coordinates_ds'.
        """
        coords = x['coordinates'].clone()
        t_tot = time.time()

        sinput, inds_list = self.create_sparse_tensors(coords)
        features_dict = self.feature_extractor(sinput)
        features = features_dict["features"]
        if "attention" in features_dict.keys():
            att = features_dict["attention"]

        if torch.isnan(features.feats).any():
            print("nans in features!")

        batch_indices = list(features.coords_man.get_batch_indices())
        features_LL = TensorListList()
        att_LL = TensorListList()
        coords_ds_LL = TensorListList()
        inds_LL = []
        ind_cnt = 0
        for coords_b in coords:
            features_L = TensorList()
            att_L = TensorList()
            coords_ds_L = TensorList()
            inds_L = []
            for coords_s in coords_b:
                # select this cloud's rows; if the first column is not the
                # batch index, fall back to the last column
                mask = features.C[:, 0] == batch_indices[ind_cnt]
                if mask.int().sum() != inds_list[ind_cnt].shape[0]:
                    mask = features.C[:, -1] == batch_indices[ind_cnt]

                f = features.F[mask]
                assert f.shape[0] == inds_list[ind_cnt].shape[0]
                if "attention" in features_dict.keys():
                    a = att.F[mask]
                    assert a.shape[0] == inds_list[ind_cnt].shape[0]
                    att_L.append(a)

                # NOTE(review): features appended without the (1, 0) permute
                # used in extract_features elsewhere in this file — confirm
                # the intended layout
                features_L.append(f)

                coords_ds_L.append(coords_s[:, inds_list[ind_cnt]])
                inds_L.append(inds_list[ind_cnt])

                ind_cnt = ind_cnt + 1

            features_LL.append(features_L)
            att_LL.append(att_L)
            coords_ds_LL.append(coords_ds_L)
            inds_LL.append(inds_L)

        x = dict()
        x['features'] = self.cluster_features(
            features_LL,
            self.params.feature_distr_parameters.num_feature_clusters)
        x['att'] = att_LL
        x['coordinates_ds'] = coords_ds_LL.to(self.params.device)
        x['indices'] = inds_LL

        out = self.registration(x)

        tot_time = time.time() - t_tot
        print("tot time: %.1f ms" % (tot_time * 1000))

        out["time"] = tot_time
        out["features"] = features_LL
        out["indices"] = inds_LL
        out["coordinates_ds"] = coords_ds_LL

        return out
示例#15
0
    def register_point_sets(self, x):
        """Register pairs of point clouds with the DGR model.

        Each entry of x['coordinates'] must hold exactly two clouds stored as
        (3, N); the first (source) is registered to the second (target), and
        the target keeps the identity pose.

        Returns:
            dict with rotations 'Rs', translations 'ts', transformed points
            'Vs' and elapsed 'time'.
        """
        point_clouds = x["coordinates"].clone()
        t_tot = time.time()

        if isinstance(point_clouds, TensorListList):
            Rs = TensorListList()
            ts = TensorListList()
        else:
            # B, P, N, M = point_clouds.shape
            # assert P == 2
            Rs = []
            ts = []

        # the target always keeps the identity pose
        target_R = torch.eye(3, 3)
        target_t = torch.zeros(3, 1)
        for pcds in point_clouds:
            source = o3d.geometry.PointCloud()
            target = o3d.geometry.PointCloud()
            # clouds are (3, N); Open3D expects (N, 3), hence the transpose
            source.points = o3d.utility.Vector3dVector(pcds[0].cpu().numpy().T)
            target.points = o3d.utility.Vector3dVector(pcds[1].cpu().numpy().T)

            T = self.dgr.register(source, target)
            R = T[0:3, 0:3]
            t = T[0:3, 3:]

            if isinstance(point_clouds, TensorListList):
                Rs.append(TensorList([torch.from_numpy(R).float(), target_R]))
                ts.append(TensorList([torch.from_numpy(t).float(), target_t]))
            else:
                Rs.append(torch.stack([torch.from_numpy(R).float(), target_R]))
                ts.append(torch.stack([torch.from_numpy(t).float(), target_t]))

        tot_time = time.time() - t_tot
        if isinstance(point_clouds, TensorListList):
            return dict(Rs=Rs, ts=ts, Vs=Rs @ point_clouds + ts, time=tot_time)
        else:
            return dict(Rs=torch.stack(Rs),
                        ts=torch.stack(ts),
                        Vs=point_clouds,
                        time=tot_time)
示例#16
0
文件: icp.py 项目: felja633/RLLReg
    def register_point_sets(self, x):
        """Register pairs of point clouds with Open3D ICP.

        Each entry of x['coordinates'] must hold exactly two clouds stored as
        (3, N); the first (source) is registered to the second (target), and
        the target keeps the identity pose. Optionally initializes with the
        centroid offset, estimates normals for point-to-plane ICP, and voxel
        downsamples both clouds.

        Returns:
            dict with rotations 'Rs', translations 'ts', transformed points
            'Vs' and elapsed 'time'.
        """
        point_clouds = x["coordinates"].clone()
        t_tot = time.time()

        if isinstance(point_clouds, TensorListList):
            Rs = TensorListList()
            ts = TensorListList()
        else:
            #B, P, N, M = point_clouds.shape
            #assert P == 2
            Rs = []
            ts = []

        # the target always keeps the identity pose
        target_R = torch.eye(3, 3)
        target_t = torch.zeros(3, 1)
        for pcds in point_clouds:
            if self.params.get("mean_init", True):
                # initialize the translation from the centroid offset
                m_target = pcds[1].mean(dim=1)
                m_source = pcds[0].mean(dim=1)
                init_t = -(m_target - m_source).cpu().numpy()
            else:
                init_t = [0, 0, 0]

            trans_init = np.asarray([[1., 0., 0., init_t[0]],
                                     [0., 1., 0., init_t[1]],
                                     [0., 0., 1., init_t[2]],
                                     [0.0, 0.0, 0.0, 1.0]])

            source = o3d.geometry.PointCloud()
            target = o3d.geometry.PointCloud()
            # clouds are (3, N); Open3D expects (N, 3), hence the transpose
            source.points = o3d.utility.Vector3dVector(pcds[0].cpu().numpy().T)
            target.points = o3d.utility.Vector3dVector(pcds[1].cpu().numpy().T)

            voxel_size = self.params.voxel_size
            max_correspondence_distance = self.params.threshold
            radius_normal = float(self.params.radius_normal)

            if self.params.metric == "p2pl":
                # point-to-plane ICP requires normals on both clouds
                source.estimate_normals(
                    search_param=o3d.geometry.KDTreeSearchParamHybrid(
                        radius=radius_normal, max_nn=30))
                target.estimate_normals(
                    search_param=o3d.geometry.KDTreeSearchParamHybrid(
                        radius=radius_normal, max_nn=30))
                # o3d.geometry.estimate_normals(source)
                # o3d.geometry.estimate_normals(target)

            if self.params.get("downsample", True):
                source = source.voxel_down_sample(voxel_size=voxel_size)
                target = target.voxel_down_sample(voxel_size=voxel_size)

            reg_p2p = o3d.registration.registration_icp(
                source,
                target,
                max_correspondence_distance=max_correspondence_distance,
                estimation_method=self.metric,
                init=trans_init)

            # split the homogeneous 4x4 transform into R and t
            T = reg_p2p.transformation
            R = T[0:3, 0:3]
            t = T[0:3, 3:]

            if isinstance(point_clouds, TensorListList):
                Rs.append(TensorList([torch.from_numpy(R).float(), target_R]))
                ts.append(TensorList([torch.from_numpy(t).float(), target_t]))
            else:
                Rs.append(torch.stack([torch.from_numpy(R).float(), target_R]))
                ts.append(torch.stack([torch.from_numpy(t).float(), target_t]))

        tot_time = time.time() - t_tot
        print("ICP tot time: %.1f ms" % (tot_time * 1000))

        if isinstance(point_clouds, TensorListList):
            return dict(Rs=Rs, ts=ts, Vs=Rs @ point_clouds + ts, time=tot_time)
        else:
            return dict(Rs=torch.stack(Rs),
                        ts=torch.stack(ts),
                        Vs=point_clouds,
                        time=tot_time)
示例#17
0
    def forward(self, x_in):
        """Optionally extract features, preprocess, and run registration,
        timing each stage.

        When the feature model is 'none', feature extraction is skipped and
        only downsampled coordinates are passed to registration.

        Args:
            x_in: dict with key 'coordinates' holding nested point-cloud lists.

        Returns:
            registration output dict augmented with timing entries and
            'coordinates_ds'.
        """
        coords = x_in['coordinates'].clone()
        t_tot = time.time()
        sinput, inds_list = self.create_sparse_tensors(coords)
        resample_time = time.time() - t_tot
        print("resample time: %.1f ms" % ((resample_time) * 1000))
        if not self.params.feature_distr_parameters.model=='none':
            features_dict = self.feature_extractor(sinput)
            features = features_dict["features"]
            if "attention" in features_dict.keys():
                att = features_dict["attention"]

            extract_time=time.time()-t_tot
            print("extract time: %.1f ms" % ((extract_time) * 1000))
            batch_indices = list(features.coords_man.get_batch_indices())
        else:
            # no feature model: skip extraction entirely
            features_dict=None
            extract_time=0

        time_preprocess = time.time()

        features_LL = TensorListList()
        att_LL = TensorListList()
        coords_ds_LL = TensorListList()
        inds_LL = []
        ind_cnt = 0
        for coords_b in coords:
            features_L = TensorList()
            att_L = TensorList()
            coords_ds_L = TensorList()
            inds_L = []
            for coords_s in coords_b:
                if not features_dict is None:
                    mask = features.C[:, 0] == batch_indices[ind_cnt]

                    # hacky way of finding batch channel dimension
                    if mask.int().sum() != inds_list[ind_cnt].shape[0]:
                        mask = features.C[:, -1] == batch_indices[ind_cnt]

                    f = features.F[mask]
                    assert f.shape[0] == inds_list[ind_cnt].shape[0]
                    if "attention" in features_dict.keys():
                        a = att.F[mask]
                        assert a.shape[0] == inds_list[ind_cnt].shape[0]
                        att_L.append(a)

                    features_L.append(f.permute(1, 0))

                coords_ds_L.append(coords_s[:, inds_list[ind_cnt]])
                inds_L.append(inds_list[ind_cnt])

                ind_cnt = ind_cnt + 1

            features_LL.append(features_L)
            att_LL.append(att_L)
            coords_ds_LL.append(coords_ds_L)
            inds_LL.append(inds_L)

        x = dict()
        x['features'] = features_LL
        x['att'] = att_LL
        x['coordinates_ds'] = coords_ds_LL.to(self.params.device)
        x['indices'] = inds_LL
        x['coordinates'] = x_in['coordinates']

        print("preprocess time: %.1f ms" % ((time.time()-time_preprocess) * 1000))

        reg_time=time.time()
        out = self.registration(x)
        reg_time2 = time.time() - reg_time
        print("reg time: %.1f ms" % ((time.time() - reg_time) * 1000))
        tot_time = time.time() - t_tot
        print("tot time: %.1f ms" % (tot_time * 1000))

        out["time"] = tot_time
        out["reg_time"] = reg_time2
        out["extract_time"] =extract_time
        out["resample_time"] = resample_time
        out["coordinates_ds"] = coords_ds_LL

        return out
示例#18
0
def get_init_transformation_list_dgr(pcds, dgr_init_model):
    """Compute initial transforms for cloud pairs using a DGR model.

    Each entry of ``pcds`` must hold exactly two clouds stored as (3, N);
    the first is registered to the second, which keeps the identity pose.

    Args:
        pcds: nested point-cloud structure (TensorListList or plain list).
        dgr_init_model: model exposing ``register_point_sets(src, tgt)``
            returning (R, t); clouds are passed transposed to (N, 3).

    Returns:
        (Rs, ts): per-pair rotations and translations; containers match the
        type of ``pcds``.
    """
    if isinstance(pcds, TensorListList):
        Rs = TensorListList()
        ts = TensorListList()
    else:
        # B, P, N, M = point_clouds.shape
        # assert P == 2
        Rs = []
        ts = []

    # the second cloud of every pair keeps the identity pose
    target_R = torch.eye(3, 3).to(pcds[0][0].device)
    target_t = torch.zeros(3, 1).to(pcds[0][0].device)
    for pc in pcds:
        assert len(pc) == 2
        R,t=dgr_init_model.register_point_sets(pc[0].permute(1,0), pc[1].permute(1,0))

        if isinstance(pcds, TensorListList):
            Rs.append(TensorList([R, target_R]))
            ts.append(TensorList([t.unsqueeze(dim=1), target_t]))
        else:
            Rs.append(torch.stack([R, target_R]))
            ts.append(torch.stack([t.unsqueeze(dim=1), target_t]))

    return Rs, ts
示例#19
0
 def posteriors(self, y):
     """Unnormalized posteriors exp((y^T mu^T) / s2) for features ``y``."""
     log_p = y.permute(1, 0) @ TensorListList(self.mu.permute(1, 0),
                                              repeat=self.repeat_list)
     p = log_p / self.s2
     return p.exp()
示例#20
0
    def forward(self, x):
        """Preprocess each cloud (voxel downsample + handcrafted features),
        cluster the features, and run registration.

        Args:
            x: dict with key 'coordinates' holding nested point-cloud lists.

        Returns:
            registration output dict augmented with 'time', 'features',
            'indices' and 'coordinates_ds'.
        """
        coords = x['coordinates'].clone()
        t_tot = time.time()

        features_LL = TensorListList()
        att_LL = TensorListList()
        coords_ds_LL = TensorListList()
        inds_LL = []
        ind_cnt = 0
        for coords_b in coords:
            features_L = TensorList()
            att_L = TensorList()
            coords_ds_L = TensorList()
            # NOTE(review): att_L and inds_L are collected but never filled
            # in this variant — downstream consumers see empty containers
            inds_L = []
            for coords_s in coords_b:
                pcd_down, f = self.preprocess_point_cloud(
                    coords_s, self.params.voxel_size)

                features_L.append(f)
                coords_ds_L.append(pcd_down)

                ind_cnt = ind_cnt + 1

            features_LL.append(features_L)
            att_LL.append(att_L)
            coords_ds_LL.append(coords_ds_L)
            inds_LL.append(inds_L)

        x = dict()
        x['features'] = self.cluster_features(
            features_LL,
            self.params.feature_distr_parameters.num_feature_clusters)
        x['att'] = att_LL
        x['coordinates_ds'] = coords_ds_LL.to(self.params.device)
        x['indices'] = inds_LL

        out = self.registration(x)
        tot_time = time.time() - t_tot
        print("tot time: %.1f ms" % (tot_time * 1000))

        out["time"] = tot_time
        out["features"] = features_LL
        out["indices"] = inds_LL
        out["coordinates_ds"] = coords_ds_LL

        return out
示例#21
0
文件: fppsr.py 项目: felja633/RLLReg
    def register_point_sets(self, x):
        """Jointly register multiple point sets with an EM-style GMM fit.

        Alternates between computing responsibilities of points to cluster
        centers X (E-step) and re-estimating rigid poses (R, t), cluster
        positions X, precisions Q and the feature distributions (M-step).

        Args:
            x: dict with 'coordinates_ds' (point sets), 'features' and 'att'.

        Returns:
            dict with final poses 'Rs'/'ts', cluster centers 'X', per-iteration
            histories 'Riter'/'titer'/'Vs_iter', transformed points 'Vs' and
            observation weights 'ow'.
        """
        Vs = x["coordinates_ds"]
        features = x["features"]
        features_w = x["att"]
        repeat_list = [len(V) for V in Vs]
        init_R, init_t = get_init_transformation_list(
            Vs, self.params.get("mean_init", True))
        TVs = init_R @ Vs + init_t

        # initialize cluster centers X and precisions Q per scene
        X = TensorList()
        Q = TensorList()
        mu = TensorList()
        for TV, Fs in zip(TVs, features):
            if self.params.cluster_init == "box":
                Xi = get_randn_box_cluster_means_list(TV, self.params.K)
            else:
                Xi = get_randn_sphere_cluster_means_list(
                    TV, self.params.K,
                    self.params.get("cluster_mean_scale", 1.0))
            Q.append(
                get_scaled_cluster_precisions_list(
                    TV, Xi, self.params.cluster_precision_scale))
            X.append(Xi.T)

        feature_distr = feature_models.MultinomialModel(
            self.params.feature_distr_parameters,
            self.params.K,
            features,
            repeat_list=repeat_list)

        feature_distr.to(self.params.device)

        X = TensorListList(X, repeat=repeat_list)
        self.betas = get_default_beta(Q, self.params.gamma)

        # no-op rebinds kept from the original
        Vs = Vs
        TVs = TVs

        # Compute the observation weights
        if self.params.use_dare_weighting:
            observation_weights = empirical_estimate(Vs, self.params.ow_args)
            # clip outlier weights at ow_reg_factor times the per-cloud mean
            ow_reg_factor = 8.0
            ow_mean = observation_weights.mean(dim=0, keepdim=True)
            for idx in range(len(observation_weights)):
                for idxx in range(len(observation_weights[idx])):
                    observation_weights[idx][idxx][observation_weights[idx][idxx] > ow_reg_factor * ow_mean[idx][idxx]] \
                        = ow_reg_factor * ow_mean[idx][idxx]

        else:
            observation_weights = 1.0

        # squared distances between transformed points and cluster centers
        ds = TVs.permute(1, 0).sqe(X).permute(1, 0)

        if self.params.debug:
            self.visdom.register(
                dict(pcds=Vs[0].cpu(), X=X[0][0].cpu(), c=None),
                'point_clouds', 2, 'init')
            time.sleep(1)

        Rs = init_R.to(self.params.device)
        ts = init_t.to(self.params.device)

        self.betas = TensorListList(self.betas, repeat=repeat_list)
        QL = TensorListList(Q, repeat=repeat_list)
        Riter = TensorListList()
        titer = TensorListList()
        TVs_iter = TensorListList()
        for i in range(self.params.num_iters):
            Qt = QL.permute(1, 0)

            # E-step: unnormalized spatial responsibilities
            ap = (-0.5 * ds * QL).exp() * QL.pow(1.5)

            if i < 1000:
                pyz_feature = feature_distr.posteriors(features)
            else:
                pyz_feature = 1.0

            a = ap * pyz_feature

            ac_den = a.sum(dim=-1, keepdim=True) + self.betas
            a = a / ac_den  # normalize row-wise
            a = a * observation_weights

            # M-step: accumulate sufficient statistics for the rigid poses
            L = a.sum(dim=-2, keepdim=True).permute(1, 0)
            W = (Vs @ a) * QL

            b = L * Qt  # weights, b
            mW = W.sum(dim=-1, keepdim=True)
            mX = (b.permute(1, 0) @ X).permute(1, 0)
            z = L.permute(1, 0) @ Qt
            P = (W @ X).permute(1, 0) - mX @ mW.permute(1, 0) / z

            # Compute R and t
            svd_list_list = P.cpu().svd()
            Rs = TensorListList()
            for svd_list in svd_list_list:
                Rs_list = TensorList()
                for svd in svd_list:
                    uu, vv = svd.U, svd.V
                    vvt = vv.permute(1, 0)
                    detuvt = uu @ vvt
                    detuvt = detuvt.det()
                    # flip the last singular direction to guarantee det(R)=+1
                    S = torch.ones(1, 3)
                    S[:, -1] = detuvt
                    Rs_list.append((uu * S) @ vvt)

                Rs.append(Rs_list)

            Rs = Rs.to(self.params.device)
            Riter.append(Rs)
            ts = (mX - Rs @ mW) / z
            titer.append(ts)
            TVs = Rs @ Vs + ts

            TVs_iter.append(TVs.clone())
            if self.params.debug:
                self.visdom.register(
                    dict(pcds=TVs[0].cpu(), X=X[0][0].cpu(), c=None),
                    'point_clouds', 2, 'registration-iter')
                time.sleep(0.2)
            # Update X
            den = L.sum_list()

            # optionally keep cluster centers frozen for the first iterations
            if self.params.fix_cluster_pos_iter < i:
                X = (TVs @ a).permute(1, 0)
                X = TensorListList(X.sum_list() / den, repeat_list)

            # Update Q
            ds = TVs.permute(1, 0).sqe(X).permute(1, 0)

            wn = (a * ds).sum(dim=-2, keepdim=True).sum_list()
            Q = (3 * den /
                 (wn.permute(1, 0) + 3 * den * self.params.epsilon)).permute(
                     1, 0)
            QL = TensorListList(Q, repeat=repeat_list)

            feature_distr.maximize(ap=ap,
                                   ow=observation_weights,
                                   y=features,
                                   den=ac_den)

        out = dict(Rs=Rs,
                   ts=ts,
                   X=X,
                   Riter=Riter[:-1],
                   titer=titer[:-1],
                   Vs=TVs,
                   Vs_iter=TVs_iter[:-1],
                   ow=observation_weights)
        return out
示例#22
0
 def maximize(self, ap, ow, y, den):
     """M-step: reweight the feature distributions by the (observation-
     weighted) responsibilities and renormalize row-wise."""
     tmp = y / den.permute(1, 0)
     as_sum = (tmp @ (ow * ap)).sum_list()
     self.distr = self.distr * TensorListList(as_sum.permute(1, 0),
                                              repeat=self.repeat_list)
     self.distr = self.distr / self.distr.sum(dim=1, keepdims=True)
示例#23
0
    def register_point_sets(self, x):
        """Register cloud pairs via FPFH global registration, optionally
        refined by point-to-plane ICP.

        Each entry of x['coordinates'] must hold exactly two clouds stored as
        (3, N); the first (source) is registered to the second (target), and
        the target keeps the identity pose.

        Returns:
            dict with rotations 'Rs', translations 'ts' and points 'Vs'.
        """
        point_clouds = x["coordinates"].clone()

        if isinstance(point_clouds, TensorListList):
            Rs = TensorListList()
            ts = TensorListList()
        else:
            # B, P, N, M = point_clouds.shape
            # assert P == 2
            Rs = []
            ts = []

        # the target always keeps the identity pose
        target_R = torch.eye(3,3)
        target_t = torch.zeros(3,1)
        for pcds in point_clouds:
            source = o3d.geometry.PointCloud()
            target = o3d.geometry.PointCloud()
            # clouds are (3, N); Open3D expects (N, 3), hence the transpose
            source.points = o3d.utility.Vector3dVector(pcds[0].cpu().numpy().T)
            target.points = o3d.utility.Vector3dVector(pcds[1].cpu().numpy().T)
            source_down, target_down, source_fpfh, target_fpfh = \
                self.prepare_dataset(source, target, self.params.voxel_size)

            # coarse alignment from FPFH feature matching
            result_fast = self.execute_global_registration(source_down, target_down,
                                                           source_fpfh, target_fpfh,
                                                           self.params.voxel_size)

            T = result_fast.transformation
            R = T[0:3, 0:3]
            t = T[0:3, 3:]

            if self.params.get("refine", True):
                # refine the coarse result with ICP on the full clouds
                radius_normal = self.params.voxel_size * 2
                if self.params.metric == "p2pl":
                    source.estimate_normals(o3d.geometry.KDTreeSearchParamHybrid(radius=radius_normal, max_nn=30))
                    target.estimate_normals(o3d.geometry.KDTreeSearchParamHybrid(radius=radius_normal, max_nn=30))
                    #o3d.geometry.estimate_normals(source)
                    #o3d.geometry.estimate_normals(target)

                distance_threshold = self.params.voxel_size * 0.5
                reg_p2p = o3d.registration.registration_icp(
                        source, target, max_correspondence_distance=distance_threshold,
                        estimation_method=o3d.registration.TransformationEstimationPointToPlane(),
                    init=result_fast.transformation)

                T = reg_p2p.transformation
                R = T[0:3,0:3]
                t = T[0:3,3:]

            if isinstance(point_clouds, TensorListList):
                Rs.append(TensorList([torch.from_numpy(R).float(), target_R]))
                ts.append(TensorList([torch.from_numpy(t).float(), target_t]))
            else:
                Rs.append(torch.stack([torch.from_numpy(R).float(), target_R]))
                ts.append(torch.stack([torch.from_numpy(t).float(), target_t]))

        if isinstance(point_clouds, TensorListList):
            return dict(Rs=Rs, ts=ts, Vs=point_clouds)
        else:
            return dict(Rs=torch.stack(Rs), ts=torch.stack(ts), Vs=point_clouds)
示例#24
0
    def __call__(self, batch, epoch):
        """Compute the training loss and running success rate for one batch.

        Args:
            batch: (data, info) tuple — ``data`` holds "coordinates";
                ``info`` holds ground-truth poses "R_gt" and "t_gt".
            epoch: current epoch index; a change resets the running counters.

        Returns:
            (loss, success_rate) — loss stays 0 when no pair in the batch has
            enough ground-truth correspondences.
        """
        # Reset running statistics at the start of a new epoch.
        if self.epoch != epoch:
            self.epoch = epoch
            self.iter_cnt = 0
            self.num_success_acc = 0

        data, info = batch
        batch_size = len(data['coordinates'])
        device = self.model.params.device
        loss_iter = 0
        loss = 0
        final_loss = 0

        out_feat = self.model.extract_features(data)
        features = out_feat["features"]
        info["R_gt"] = info["R_gt"].to(device)
        info["t_gt"] = info["t_gt"].to(device)

        # Count every view pair (ind1, ind2) across all batch elements.
        num_pairs = 0
        for fb in features:
            M = len(fb)
            for ind1 in range(M - 1):
                for ind2 in range(ind1 + 1, M):
                    num_pairs = num_pairs + 1

        # check valid pairs wrt number of correspondences
        # invalid pairs are ignored in the computation of the loss
        valid_pairs_LL = []
        if self.min_corresponence_rate > 0.0:
            correspondences = []
            for coord_ds, r, t in zip(out_feat["coordinates_ds"], info["R_gt"],
                                      info["t_gt"]):
                correspondences.append(
                    utils.extract_correspondences_gpu(coord_ds, r, t))

            for fb, corr in zip(features, correspondences):
                db = corr["distances"].to(device)
                M = len(fb)
                cnt = 0
                valid_pairs_L = []
                for ind1 in range(M - 1):
                    for ind2 in range(ind1 + 1, M):
                        # A pair is valid when enough correspondences fall
                        # within the distance threshold self.th.
                        mask = db[cnt] < self.th
                        corrs_rate = mask.sum().float() / mask.shape[0]
                        valid_pairs_L.append(
                            corrs_rate > self.min_corresponence_rate)
                        cnt = cnt + 1

                valid_pairs_LL.append(sum(valid_pairs_L))

        else:
            # No correspondence filtering: every pair counts as valid.
            for fb in features:
                cnt = 0
                M = len(fb)
                for ind1 in range(M - 1):
                    for ind2 in range(ind1 + 1, M):
                        cnt = cnt + 1

                valid_pairs_LL.append(cnt)

        Rgt, tgt = info["R_gt"], info["t_gt"]

        # Number of batch elements with at least one valid pair (tensor when
        # the rate filter is active, plain int otherwise).
        num_samples = sum(v > 0.0 for v in valid_pairs_LL)
        if self.min_corresponence_rate > 0.0:
            self.iter_cnt += num_samples.item()
        else:
            self.iter_cnt += num_samples

        # only compute loss if there is at least one valid pair in the batch
        if sum(valid_pairs_LL) == 0:
            return loss, self.num_success_acc / (self.iter_cnt)

        # Drop batch elements without any valid pair before registration.
        if sum(valid_pairs_LL) < num_pairs:
            out_feat_filt = dict()
            for k in out_feat.keys():
                out_feat_filt[k] = TensorListList([
                    out_feat[k][i] for i in range(len(out_feat[k]))
                    if valid_pairs_LL[i]
                ])

            Rgt = TensorListList(
                [Rgt[i] for i in range(batch_size) if valid_pairs_LL[i]])
            tgt = TensorListList(
                [tgt[i] for i in range(batch_size) if valid_pairs_LL[i]])
            out_feat = out_feat_filt

        out_reg = self.model.register(out_feat)
        # FIX: PEP 8 / E714 — was `if not self.vis is None:`.
        if self.vis is not None:
            self.vis(out_reg, info, data)

        Rs, ts = out_reg["Rs"], out_reg["ts"]
        Riter, titer = out_reg["Riter"], out_reg["titer"]
        Rgt = Rgt.to(Rs[0][0].device)
        tgt = tgt.to(ts[0][0].device)

        rot_errs = utils.L2sqrList(Rs.detach(), Rgt).sqrt()
        trans_errs = utils.L2sqrTransList(ts.detach(), tgt, Rs.detach(),
                                          Rgt).sqrt()

        # check number of successful registrations (both rotation and
        # translation errors under their evaluation thresholds)
        num_success = 0
        valid_list = []
        for rs_err, ts_err in zip(rot_errs, trans_errs):
            for rerr, terr in zip(rs_err, ts_err):
                val = (rerr < self.eval_rot_err_thresh) * (
                    terr < self.eval_trans_err_thresh)
                valid_list.append(val.item())
                num_success = num_success + val.item()

        self.num_success_acc += num_success

        # compute registration error per iteration
        for Rit, tit, w in zip(
                Riter, titer, self.weight["Vs_iter"][self.compute_loss_iter:]):
            if w > 0:
                trans_errs2 = utils.GARPointList(tit,
                                                 tgt,
                                                 Rit,
                                                 Rgt,
                                                 V=out_feat["coordinates_ds"],
                                                 c=self.c,
                                                 alpha=self.alpha)
                for terr in trans_errs2:
                    for t in terr:
                        loss_iter = loss_iter + w * t

        # compute final registration error
        if self.weight["Vs"] > 0.0:
            trans_errs2 = utils.GARPointList(ts,
                                             tgt,
                                             Rs,
                                             Rgt,
                                             V=out_feat["coordinates_ds"],
                                             c=self.c,
                                             alpha=self.alpha)

            for terr in trans_errs2:
                for t in terr:
                    final_loss = final_loss + self.weight["Vs"] * t

        print("num valid: ", num_success, "num_success_acc rate: ",
              self.num_success_acc / (self.iter_cnt))

        # Average the accumulated loss over the valid batch elements.
        if sum(valid_pairs_LL):
            loss = loss + final_loss + loss_iter
            loss = loss / num_samples

        return loss, self.num_success_acc / (self.iter_cnt)