Example #1
    def _compute_weights(self, X):
        #X = check_array(X)
        #self.nbrs_ = NearestNeighbors(n_neighbors=self.n_neighbors,
        #                              algorithm=self.neighbors_algorithm)
        #self.nbrs_.fit(X)
        points_number, *para = X.size()
        knn = KNN(k=self.n_neighbors,
                  transpose_mode=True)  #points,neighbours..
        dist, col = knn(X.unsqueeze(0), X.unsqueeze(0))
        #knn_pc = torch.cat([X[0,col[0,:,i],:].unsqueeze(-1) for i in range(self.n_neighbours)],dim=-1).permute(0,2,1)
        W = torch.zeros([points_number, points_number],
                        dtype=X.dtype,
                        device=X.device)
        # mark the k nearest neighbours of every point in the adjacency matrix
        for i in range(self.n_neighbors):
            W[torch.arange(points_number, device=X.device), col[0, :, i]] = 1

        if self.weight == 'adjacency':
            pass  # keep the binary connectivity weights
            #W = kneighbors_graph(self.nbrs_, self.n_neighbors,
            #                     mode='connectivity', include_self=True)
        elif self.weight == 'heat':
            #W = kneighbors_graph(self.nbrs_, self.n_neighbors,
            #                     mode='distance', include_self=True)
            W = torch.exp(-W**2 / self.weight_width**2)
        else:
            raise ValueError("Unrecognized Weight")

        # symmetrize the matrix
        # TODO: make this more efficient & keep sparse output
        #W = W.toarray()
        p = torch.cat([W.unsqueeze(-1), W.t().unsqueeze(-1)], dim=-1)
        W = torch.max(p, dim=-1, keepdim=False)[0]
        #print(W.size())
        return W
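Note that the heat branch above applies the kernel to a binary adjacency matrix, so every edge gets the same weight. A minimal sketch of distance-weighted heat-kernel edges, reusing `dist` and `col` from the KNN call above and assuming `weight_width` is the kernel bandwidth (an illustration, not the repository's code):

# hypothetical variant: weight each edge by exp(-d^2 / sigma^2) using the knn distances
rows = torch.arange(points_number, device=X.device)
W_heat = torch.zeros_like(W)
for i in range(self.n_neighbors):
    W_heat[rows, col[0, :, i]] = torch.exp(-dist[0, :, i] ** 2 / self.weight_width ** 2)
W_heat = torch.max(W_heat, W_heat.t())  # symmetrize, as in the code above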
Example #2
def knn(x, k: int):
    """
    inputs:
    - x: b x num_dims x npoints (point cloud)
    - k: int (the number of neighbors)
    outputs:
    - idx: b x npoints x k (neighbor indices)
    """
    # x : (batch_size, feature_dim, num_points)
    # Retrieve nearest neighbor indices

    if torch.cuda.is_available():
        from knn_cuda import KNN

        ref = x.transpose(2, 1).contiguous()  # (batch_size, num_points, feature_dim)
        query = ref
        _, idx = KNN(k=k, transpose_mode=True)(ref, query)

    else:
        inner = -2 * torch.matmul(x.transpose(2, 1), x)
        xx = torch.sum(x ** 2, dim=1, keepdim=True)
        pairwise_distance = -xx - inner - xx.transpose(2, 1)
        idx = pairwise_distance.topk(k=k, dim=-1)[1]  # (batch_size, num_points, k)

    return idx
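A minimal usage sketch of the fallback branch, with a made-up channel-first point cloud; on a machine with CUDA available the function dispatches to knn_cuda instead:

import torch

x = torch.rand(2, 3, 128)   # hypothetical batch: 2 clouds, 3-D coords, 128 points (B x dims x N)
idx = knn(x, k=16)          # neighbor indices, shape (2, 128, 16)
print(idx.shape)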
Example #3
 def __init__(self,
              in_channels,
              coord_channels,
              out_channels,
              K,
              mlp=[32, 64],
              point_pooling=False,
              attentive_aggregation=False):
     super(ContinuousConvolution, self).__init__()
     self.K = K
     self.C_in = in_channels
     self.C_coord = coord_channels
     self.C_out = out_channels
     self.C_hid = mlp
     self.num_output_features = self.C_out
     self.knn = KNN(k=self.K, transpose_mode=True)
     self.kernel = MLP_conv1x1(channels=[self.C_in + self.C_coord] +
                               self.C_hid + [self.C_out])
     self.point_pooling = point_pooling
     if self.point_pooling:
         self.num_output_features += self.C_in
     self.attentive_aggregation = attentive_aggregation
     if self.attentive_aggregation:
         self.aggr_mlp = MLP_conv1x1(channels=[self.K, 1])
         self.num_output_features += self.C_out
Example #4
    def forward(self, coords, features):
        r"""
            Forward pass

            Parameters
            ----------
            coords: torch.Tensor, shape (B, N, 3)
                coordinates of the point cloud
            features: torch.Tensor, shape (B, d_in, N, 1)
                features of the point cloud

            Returns
            -------
            torch.Tensor, shape (B, 2*d_out, N, 1)
        """
        knn1 = KNN(k=1, transpose_mode=True)
        knn_output = knn1(coords.cuda().contiguous(),
                          coords.cuda().contiguous())
        x = self.mlp1(features)

        x = self.lse1(coords, x, knn_output)
        x = self.pool1(x)

        x = self.lse2(coords, x, knn_output)
        x = self.pool2(x)

        return self.lrelu(self.mlp2(x) + self.shortcut(features))
Example #5
    def _project_and_match(self, x: torch.Tensor,
                           simp: torch.Tensor) -> (torch.Tensor, torch.Tensor):
        match = None
        proj = None

        # Projected points
        if self.training:
            if not self.skip_projection:
                proj = self.project(point_cloud=x, query_cloud=simp)
            else:
                proj = simp

        # Matched points
        else:  # Inference
            # Retrieve nearest neighbor indices
            _, idx = KNN(1, transpose_mode=False)(x.contiguous(),
                                                  simp.contiguous())
            """Notice that we detach the tensors and do computations in numpy,
            and then convert back to Tensors.
            This should have no effect as the network is in eval() mode
            and should require no gradients.
            """

            # Convert to numpy arrays in B x N x 3 format. we assume 'bcn' format.
            x = x.permute(0, 2, 1).cpu().detach().numpy()

            idx = idx.cpu().detach().numpy()
            idx = np.squeeze(idx, axis=1)

            z = sputils.nn_matching(x,
                                    idx,
                                    self.num_out_points,
                                    complete_fps=self.complete_fps)

            # Matched points are in B x N x 3 format.
            match = torch.tensor(z, dtype=torch.float32).cuda()

        # Change to output shapes
        if self.output_shape == "bnc":
            simp = simp.permute(0, 2, 1)
            if proj is not None:
                proj = proj.permute(0, 2, 1)
        elif self.output_shape == "bcn" and match is not None:
            match = match.permute(0, 2, 1)
            match = match.contiguous()

        # Assert contiguous tensors
        simp = simp.contiguous()
        if proj is not None:
            proj = proj.contiguous()
            out = proj
        else:
            match = match.contiguous()
            out = match

        if self.debug:
            return simp, proj, match
        else:
            return simp, out
Example #6
def sample_and_group_cuda(npoint,
                          k,
                          xyz,
                          points,
                          cat_xyz_feature=True,
                          fps_only=False):
    """
    Input:
        npoint:
        k:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, C, N]
    Return:
        new_xyz: sampled points position data, [B, 3, npoint]
        new_points: sampled points data, [B, C+C_xyz, npoint, k]
        grouped_xyz_norm: sampled relative points position data, [B, 3, npoint, k]
    """
    k = min(npoint, k)
    knn = KNN(k=k, transpose_mode=True)

    B, N, C_xyz = xyz.shape

    if npoint < N:
        # fps_idx = torch.arange(npoint).repeat(xyz.shape[0], 1).int().cuda() # DEBUG ONLY
        fps_idx = farthest_point_sample_cuda(xyz, npoint)  # [B, npoint]
        torch.cuda.empty_cache()
        new_xyz = index_points_cuda(xyz, fps_idx)  #[B, npoint, 3]
        new_points = index_points_cuda(points.transpose(1, 2), fps_idx)
    else:
        new_xyz = xyz
        new_points = points.transpose(1, 2)  # keep all points when no sampling is needed
        fps_idx = None

    if fps_only:
        return new_xyz.transpose(1, 2), new_points.transpose(1, 2), fps_idx

    torch.cuda.empty_cache()
    _, idx = knn(xyz.contiguous(), new_xyz)  # B, npoint, k
    idx = idx.int()

    torch.cuda.empty_cache()
    grouped_xyz = grouping_operation_cuda(
        xyz.transpose(1, 2).contiguous(),
        idx).permute(0, 2, 3, 1)  # [B, npoint, k, C_xyz]
    torch.cuda.empty_cache()
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, npoint, 1,
                                                  C_xyz)  # [B, npoint, k, 3]
    grouped_xyz_norm = grouped_xyz_norm.permute(
        0, 3, 1, 2).contiguous()  # [B, 3, npoint, k]
    torch.cuda.empty_cache()

    grouped_points = grouping_operation_cuda(points.contiguous(),
                                             idx)  #B, C, npoint, k

    if cat_xyz_feature:
        new_points = torch.cat([grouped_xyz_norm, grouped_points],
                               dim=1)  # [B, C+C_xyz, npoint, k]
    else:
        new_points = grouped_points  # [B, C+C_xyz, npoint, k]

    return new_xyz.transpose(1, 2), grouped_xyz_norm, new_points, idx
Example #7
 def __init__(self, r, k, knn=False):
     self.r = r
     self.k = k
     if knn:
         self.use_knn = True
         self.knn = KNN(k=k, transpose_mode=True)
     else:
         self.use_knn = False
Example #8
def sample_and_group_cuda(npoint, k, xyz, points, cat_xyz_feature=True):
    """
    Input:
        npoint:
        k:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, C, N]
    Return:
        new_xyz: sampled points position data, [B, 3, npoint]
        new_points: sampled points data, [B, C+C_xyz, npoint, k]
        grouped_xyz_norm: sampled relative points position data, [B, 3, npoint, k]
    """
    k = min(npoint, k)
    knn = KNN(k=k, transpose_mode=True)

    B, N, C_xyz = xyz.shape

    if npoint < N:
        fps_idx = farthest_point_sample_cuda(xyz, npoint)  # [B, npoint]
        torch.cuda.empty_cache()
        new_xyz = index_points_cuda(xyz, fps_idx)  #[B, npoint, 3]
    else:
        new_xyz = xyz

    torch.cuda.empty_cache()
    _, idx = knn(xyz.contiguous(), new_xyz)  # B, npoint, k
    idx = idx.int()

    torch.cuda.empty_cache()
    grouped_xyz = grouping_operation_cuda(
        xyz.transpose(1, 2).contiguous(),
        idx).permute(0, 2, 3, 1)  # [B, npoint, k, C_xyz]
    torch.cuda.empty_cache()
    try:
        # DEBUG: when using the mixed-trans, some last voxels may have less points
        grouped_xyz_norm = grouped_xyz - new_xyz.view(-1, min(
            npoint, N), 1, C_xyz)  # [B, npoint, k, 3]
    except RuntimeError:
        import ipdb
        ipdb.set_trace()
    grouped_xyz_norm = grouped_xyz_norm.permute(
        0, 3, 1, 2).contiguous()  # [B, 3, npoint, k]
    torch.cuda.empty_cache()

    grouped_points = grouping_operation_cuda(points.contiguous(),
                                             idx)  #B, C, npoint, k

    if cat_xyz_feature:
        new_points = torch.cat([grouped_xyz_norm, grouped_points],
                               dim=1)  # [B, C+C_xyz, npoint, k]
    else:
        new_points = grouped_points  # [B, C+C_xyz, npoint, k]

    return new_xyz.transpose(1, 2), grouped_xyz_norm, new_points
Example #9
def stem_knn(xyz, points, k):
    knn = KNN(k=k, transpose_mode=True)
    xyz = xyz.permute([0, 2, 1])
    _, idx = knn(xyz.contiguous(),
                 xyz)  # xyz: [bs, npoints, coord] idx: [bs, npoint, k]
    idx = idx.int()

    # take in [B, 3, N]
    grouped_xyz = grouping_operation_cuda(
        xyz.transpose(1, 2).contiguous(), idx)  # [bs, xyz, n_point, k]
    grouped_points = grouping_operation_cuda(points.contiguous(),
                                             idx)  # B, C, npoint, k

    return grouped_xyz, grouped_points
Example #10
def sample_and_group_cuda(npoint, k, xyz, points):
    """
    Input:
        npoint:
        k:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, C, N]
    Return:
        new_xyz: sampled points position data, [B, 3, npoint]
        grouped_xyz_norm: sampled relative points position data, [B, 3, npoint, k]
        new_points: sampled points data, [B, C+C_xyz, npoint, k]
    """
    k = min(npoint, k)
    knn = KNN(k=k, transpose_mode=True)

    B, N, C_xyz = xyz.shape

    if npoint < N:
        fps_idx = farthest_point_sample_cuda(xyz, npoint)  # [B, npoint]
        torch.cuda.empty_cache()
        new_xyz = index_points_cuda(xyz, fps_idx)  #[B, npoint, 3]
    else:
        new_xyz = xyz

    torch.cuda.empty_cache()
    _, idx = knn(xyz.contiguous(), new_xyz)  # B, npoint, k
    idx = idx.int()

    torch.cuda.empty_cache()
    grouped_xyz = grouping_operation_cuda(
        xyz.transpose(1, 2).contiguous(),
        idx).permute(0, 2, 3, 1)  # [B, npoint, k, C_xyz]
    #print(grouped_xyz.size())
    torch.cuda.empty_cache()
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, npoint, 1,
                                                  C_xyz)  # [B, npoint, k, 3]
    grouped_xyz_norm = grouped_xyz_norm.permute(
        0, 3, 1, 2).contiguous()  # [B, 3, npoint, k]
    torch.cuda.empty_cache()

    grouped_points = grouping_operation_cuda(points.contiguous(),
                                             idx)  #B, C, npoint, k

    new_points = torch.cat([grouped_xyz_norm, grouped_points],
                           dim=1)  # [B, C+C_xyz, npoint, k]

    return new_xyz.transpose(1, 2), grouped_xyz_norm, new_points
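A hedged usage sketch of the call above; the tensor sizes are made up, and it assumes the farthest_point_sample_cuda / grouping_operation_cuda extensions are built and the inputs live on the GPU:

xyz = torch.rand(2, 1024, 3).cuda()       # [B, N, 3] point positions
feats = torch.rand(2, 32, 1024).cuda()    # [B, C, N] point features
new_xyz, grouped_xyz_norm, new_points = sample_and_group_cuda(256, 16, xyz, feats)
# new_xyz: [2, 3, 256], grouped_xyz_norm: [2, 3, 256, 16], new_points: [2, 32+3, 256, 16]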
Example #11
    def forward(self, input_p, input_x, instance=None, instance_relation=None):
        '''
        input_p:  B, 3, npoint
        input_x: B, in_dim, npoint
        '''
        '''
        how to use the instance information:
        1. use it as guidance of the attention, mask the knns points with different instance label
        2. directly random choose points of same instance label as attention receptive field
        3. attend to the instance center
        '''
        INSTANCE_SCHEME = 3

        B, in_dim, npoint = list(input_x.size())
        n_sample = self.n_sample
        k = min(n_sample, npoint)
        h = self.nhead

        res = input_x

        input_p = input_p.permute([0, 2, 1])
        ori_input_p = input_p

        if instance is not None and INSTANCE_SCHEME == 1:
            # knn more points for sampling
            knn_sample_more_ratio = 2
            enlarged_k = k * knn_sample_more_ratio
            self.knn = KNN(k=enlarged_k, transpose_mode=True)
        else:
            self.knn = KNN(k=k, transpose_mode=True)

        # DEBUG: error here is that in the last block only 4-points;
        # however the knn still gives 16 idxs
        # so when npoint is smaller than k (n_sample)
        # if npoint < self.n_sample:
        # self.knn = KNN(k=npoint, transpose_mode=True)
        # else:
        # self.knn = KNN(k=n_sample, transpose_mode=True)
        # pass # regular case

        # DEBUG ONLY: using the input_x: feature space knn!
        # _, idx = self.knn(input_x.transpose(1,2), input_x.transpose(1,2))

        if instance is not None:

            if INSTANCE_SCHEME == 3:
                '''
                Ver3.0: use cur instance center as knn center,
                calc the instance center, and weighting the cur-idx and the instance center idx
                ERROR:
                    - all points of the same instance will have the same idxes? and all are the closest N points to the centroid
                    - if using a weighted center and coord, however, we need to do N-point KNN, which will be slow...
                '''
                ori_input_p = input_p.clone()
                # where = torch.where(instance[0] == 1)
                # instance_xyz = input_p[:,where,:].mean(dim=1)   # get [bs, 3] centroid for cur instance
                for i_bs in range(instance.shape[0]):
                    for v in torch.unique(instance[i_bs]):
                        tmp_idx = torch.where(instance[i_bs] == v)[0]
                        ins_center = input_p[:, tmp_idx, :].mean(
                            dim=1)  # the centroids for each intsance
                        # average cur point and the instance center
                        alpha = 0.999
                        input_p[:,
                                tmp_idx, :] = alpha * input_p[:, tmp_idx, :] + (
                                    1 - alpha) * ins_center.unsqueeze(
                                        1)  # [bs, n_cur_ins, 3] + [bs, 1, 3]

                _, idx = self.knn(ori_input_p.contiguous(), ori_input_p)
                _, idx2 = self.knn(ori_input_p.contiguous(), input_p)
                print((idx == idx2).int().sum() / idx.nelement())
            else:
                _, idx = self.knn(input_p.contiguous(), input_p)
        else:
            _, idx = self.knn(input_p.contiguous(), input_p)

        idx = idx.int()

        if INSTANCE_SCHEME == 1:
            '''
            Ver1.0(Naive Version): mask the knn(instance label as auxiliary filter)
            older version of the instance mask
            directly check if each knn-grouped point lies within the same point set
            then mask if not in
            '''

            if instance is not None:
                # print('start processing the instance mask')
                masks = []
                for i_bs, idx_cur_bs in enumerate(idx):
                    # [4096, 16] cur_bs_idx
                    # [4096]: instance_label[i_bs]
                    mask = instance[i_bs][idx_cur_bs.long()]  # [4096, 2*k]
                    mask = mask - mask[:, 0].unsqueeze(
                        -1
                    )  # get the 1st column(the 1st element in k-elements is itself)
                    mask = (mask == 0).int()  # acquire the 0-1 mask
                    masks.append(mask)
                masks = torch.stack(masks)
                print("mask ratio {:.4f}".format(
                    masks.sum() / masks.nelement()))  # >0.5 means ok
                '''
                generate bigger knn-idx and mask, then choose the 1st n_sample(16) elements
                random sample other points from the latter, and use mask to fill into the 0 ones

                get the 1st k idxes that are not 0 in mask
                since the mask values are all 0-1, argsort will return a vector
                however, we want smaller idxes in the front
                so we give the 0 elements a large value to make them appear last
                if descend=True is used, the biggest idx with 1 will come first
                '''

                inv_masks = (masks == 0).int()
                tmp_inds = torch.arange(masks.shape[2]).repeat(
                    masks.shape[0], masks.shape[1],
                    1).to(idx.device)  # generate the [1,2,...,enlarged_k] inds
                tmp_inds = tmp_inds * masks
                tmp_inds = tmp_inds + (
                    masks.shape[2] + 1
                ) * inv_masks  # fill the places of 0 with value bigger than the maximum value
                tmp_inds = torch.argsort(
                    tmp_inds
                )[:, :, :
                  k]  # after argsort, the former elements should be non-zero value with smaller idx
                idx = torch.gather(idx, -1, tmp_inds)
                idx = idx.int()

                # TODO: if nk still does not contain enough elements, the argsort will contain the closet knn result while not instance

        elif INSTANCE_SCHEME == 2:
            '''
            # Ver2.0: directly use the points of the same instance label as neighbors
            # random sample k points in the same instance
            '''
            if instance is not None:
                instance_relations = []
                for i_bs in range(instance.shape[0]):
                    instance_inds = [
                        torch.where(instance[i_bs] == v)[0]
                        for v in torch.unique(instance[i_bs])
                    ]  # torch.where returns a tuple, so use [0] to get the tensor
                    instance_relation = torch.full([instance[0].shape[0], k],
                                                   -1).to(instance.device)
                    for i, ins_id in enumerate(instance_inds):
                        # TODO: pytorch has no built-in equivalent of random.choice
                        if len(ins_id
                               ) <= 5:  # for small outlier points, skip em
                            continue
                        try:
                            perms = torch.multinomial(
                                ins_id.repeat(len(ins_id), 1).float(),
                                num_samples=min(k, len(ins_id)),
                                replacement=False)
                        except RuntimeError:
                            import ipdb
                            ipdb.set_trace()
                        choices = ins_id[perms]
                        instance_relation[
                            instance_inds[i], :choices.shape[1]] = choices
                    instance_relation[:, 0] = torch.arange(
                        instance_relation.shape[0])
                    instance_relations.append(instance_relation)
                instance_relations = torch.stack(instance_relations)

                # print('replacing the instance_relation')
                instance_relation_nonzero_mask = (instance_relations >=
                                                  0).int()
                instance_relation_zero_mask = (instance_relations < 0).int()

                idx = idx * instance_relation_zero_mask + instance_relations * instance_relation_nonzero_mask
                idx = idx.int()

        # ===================== Deprecated Methods ===========================1
        '''
        # Ver 2.3: failed version of receiving an instance_relation,
        # however, point downsample could not be handled
        # the instance feed in here is of the same size as the idxes
        # if the num within the same instance group as less than k
        # then the instance_relation will contain -1, we will replace these -1s
        # with the original idx acquired by knn
        if instance_relation is not None:
            print('replacing the instance_relation')
            # import ipdb; ipdb.set_trace()

            instance_relation = instance_relation[:,:,:k]
            instance_relation_nonzero_mask = (instance_relation>=0).int()
            instance_relation_zero_mask = (instance_relation<0).int()

            # idx = idx*instance_relation_zero_mask + instance_relation*instance_relation_nonzero_mask
            idx = instance_relation.int()

        '''
        '''
        Ver 2.2: Hash Table-based
        1st pack the instance into dict(hash table)
        then ref the points within the same scope to replace the knn points
        if there are not enough points of the same instance, keep the knn idxs
        '''
        '''
        # pack the instacne into dict for further reference
        if instance is not None:
            print('start creating instance dicts')
            instance_dicts = []
            for i_bs, instance_cur_bs in enumerate(instance):
                instance_dict = {}
                for ins_idx, ins in enumerate(instance_cur_bs):
                    if ins.item() in instance_dict.keys():
                        instance_dict[ins.item()].append(ins_idx)
                    else:
                        instance_dict[ins.item()] = [ins_idx]
                for ins_k in instance_dict.keys():
                    instance_dict[ins_k] = torch.tensor(instance_dict[ins_k]).to(instance.device)
                instance_dicts.append(instance_dict)

            l1 = []
            for i_bs in range(instance.shape[0]):
                l0 = []
                for i_point in range(instance.shape[1]):
                    tmp = torch.zeros([k])
                    instance_gathered  = instance_dicts[i_bs][instance[i_bs][i_point].item()][:k]
                    tmp[:len(instance_gathered)] = instance_gathered
                    # idx[i_bs][i_point][:len(instance_gathered)] = instance_gathered
                    l0.append(tmp)
                tmp1 = torch.stack(l0)
                l1.append(tmp1)
            new_idx = torch.stack(l1)
        '''
        '''
        Ver: 2.1: Naive Version of for-loop replacement 
        # Too slow version, needs improving
        # 1st use knn then use mask the value belongs not to the same instance
        instance_masks = []
        for i_batch, single_batch_instance in enumerate(instance):
            # single_batch_instance: [npoint]
            masks_cur_batch = []
            for i_point, gathered_points in enumerate(idx[i_batch]):
                # gathered_points: [k]
                points_with_same_instance = torch.where(single_batch_instance == single_batch_instance[i_point])[0]
            # check if the grouped idxes belong to the same instance
                cur_mask = torch.tensor([g.item() in points_with_same_instance for g in gathered_points])
                masks_cur_batch.append(cur_mask)
            masks_cur_batch = torch.stack(masks_cur_batch)
            instance_masks.append(masks_cur_batch)
        instance_masks = torch.stack(instance_masks)
        '''

        # ==========================================================================================

        grouped_input_p = grouping_operation_cuda(
            input_p.transpose(1, 2).contiguous(), idx)  # [bs, xyz, npoint, k]

        if self.pre_ln:
            input_x = self.ln_top(input_x.transpose(1, 2)).transpose(1, 2)

        input_x = self.linear_top(input_x)

        # TODO: apply the layer-norm
        # however the original is [bs, dim, npoint]
        if self.pre_ln:
            input_x = self.ln_attn(input_x.transpose(1, 2)).transpose(1, 2)

        # grouped_input_x = index_points(input_x.permute([0,2,1]), idx.long()).permute([0,3,1,2])
        # grouped_input_x = grouping_operation_cuda(input_x.contiguous(), idx)  # [bs, xyz, npoint, K]
        phi = self.phi(input_x)
        phi = phi[:, :, :, None].repeat(1, 1, 1, k)
        psi = grouping_operation_cuda(self.psi(input_x).contiguous(), idx)
        alpha = grouping_operation_cuda(self.alpha(input_x).contiguous(),
                                        idx)  # [bs, xyz, npoint, k]

        relative_xyz = input_p.permute([0, 2, 1])[:, :, :,
                                                  None] - grouped_input_p
        pos_encoding = self.delta(relative_xyz)  # [bs, dims, npoint, k]

        if self.use_vector_attn:
            # the attn_map: [vector_dim];
            # the alpha:    [out_dim]
            attn_map = F.softmax(self.gamma(phi - psi + pos_encoding),
                                 dim=-1)  # [B, Dim, N, k]
            # if instance is not None: # apply mask
            # attn_map = attn_map*(masks.unsqueeze(1))
            y = attn_map.repeat(1, self.out_dim // self.vector_dim, 1,
                                1) * (alpha + pos_encoding)
            y = y.sum(dim=-1)
        else:
            phi = phi.reshape(B, h, self.out_dim // h, npoint, k)
            psi = psi.reshape(B, h, self.out_dim // h, npoint, k)
            attn_map = F.softmax(
                (phi * psi).reshape(B, self.out_dim, npoint, k) + pos_encoding,
                dim=-1)
            y = attn_map * (alpha + pos_encoding)
            y = y.sum(dim=-1)

        if self.pre_ln:
            y = self.ln_down(y.transpose(1, 2)).transpose(1, 2)

        y = self.linear_down(y)

        return y + res, attn_map.detach().cpu().data
Example #12
from dataset import *

import my_utils

my_utils.plant_seeds(randomized_seed=False)
from sklearn.neighbors import NearestNeighbors
from ply import *
import os

os.environ['CUDA_VISIBLE_DEVICES'] = "0"
from model_err_nfeaPointnet2 import *
import json
import datetime
from LaplacianLoss import *
from knn_cuda import KNN
knn = KNN(k=1, transpose_mode=True)
import time
import argparse
import warnings
warnings.filterwarnings("ignore")
# =============PARAMETERS======================================== #
lambda_laplace = 0.005
lambda_ratio = 0.005

parser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=32, help='input batch size')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=48)
parser.add_argument('--nepoch', type=int, default=75, help='number of epochs to train for')
parser.add_argument('--model', type=str, default='', help='optional reload model path')
parser.add_argument('--env', type=str, default="3DCODED_unsupervised", help='visdom environment')
parser.add_argument('--laplace', type=int, default=1, help='regularize towards 0 curvature, or template curvature')
Example #13
    def forward(self, input):
        r"""
            Forward pass

            Parameters
            ----------
            input: torch.Tensor, shape (B, N, d_in)
                input points

            Returns
            -------
            torch.Tensor, shape (B, num_classes, N)
                segmentation scores for each point
        """
        N = input.size(1)
        d = self.decimation

        coords = input[..., :3].clone().cpu()
        x = self.fc_start(input).transpose(-2, -1).unsqueeze(-1)
        x = self.bn_start(x)  # shape (B, d, N, 1)

        decimation_ratio = 1

        # <<<<<<<<<< ENCODER
        x_stack = []

        permutation = torch.randperm(N)
        coords = coords[:, permutation]
        x = x[:, :, permutation]

        for lfa in self.encoder:
            # at iteration i, x.shape = (B, N//(d**i), d_in)
            startenc = time.time()

            x = lfa(coords[:, :N // decimation_ratio], x)
            x_stack.append(x.clone())
            decimation_ratio *= d
            x = x[:, :, :N // decimation_ratio]
            endenc = time.time()
            #print("each dec time is",endenc-startenc)

        # # >>>>>>>>>> ENCODER

        x = self.mlp(x)
        #print("encoder time is",end-start)
        # <<<<<<<<<< DECODER
        for mlp in self.decoder:
            knn = KNN(k=1, transpose_mode=True)
            _, neighbors = knn(
                coords[:, :N //
                       decimation_ratio].cuda().contiguous(),  # original set
                coords[:, :d * N //
                       decimation_ratio].cuda().contiguous()  # upsampled set
            )  # shape (B, N, 1)
            neighbors = neighbors.to(self.device)
            extended_neighbors = neighbors.unsqueeze(1).expand(
                -1, x.size(1), -1, 1)

            x_neighbors = torch.gather(x, -2, extended_neighbors)

            x = torch.cat((x_neighbors, x_stack.pop()), dim=1)

            x = mlp(x)

            decimation_ratio //= d

        # >>>>>>>>>> DECODER
        # inverse permutation
        x = x[:, :, torch.argsort(permutation)]
        scores = self.fc_end(x)
        return scores.squeeze(-1)
Example #14
 def __init__(self,k=16):
     super(get_edge_feature,self).__init__()
     self.KNN=KNN(k=k+1,transpose_mode=False)
     self.k=k
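Presumably k+1 neighbours are requested because, when a cloud is queried against itself, each point's nearest neighbour is the point itself. A sketch (an assumption, not necessarily this repo's forward pass) of how that self-match is usually dropped with transpose_mode=False, i.e. channel-first (B, C, N) inputs:

dist, idx = self.KNN(x.contiguous(), x.contiguous())   # assumed knn_cuda layout: (B, k+1, N) with transpose_mode=False
idx = idx[:, 1:, :]                                     # drop the self-match, keep k true neighbours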
Example #15
def create_knn_adj_mat(features,
                       k,
                       weighted=False,
                       n_jobs=None,
                       algorithm='auto',
                       threshold=None,
                       use_gpu=False):
    """
    Create a directed normalized adjacency matrix from input nodes based on k-nearest neighbours
    Parameters:
        features (numpy array of node features): array of size N*M (N nodes, M feature size)
        k (int): number of neighbours to find
        weighted (bool): set to True for weighted adjacency matrix (based on Euclidean distance)
        n_jobs (int): number of jobs to deploy if GPU is not used
        algorithm (str): Choose between auto, ball_tree, kd_tree or brute
        threshold (float): Cutoff value for the Euclidean distance
        use_gpu (bool): Indicates whether GPU is to be used for the KNN algorithm
    Returns:
        (coo matrix): adjacency matrix as a sparse coo matrix
    """
    t_start = datetime.now()

    if use_gpu:

        features_extra_dim = np.expand_dims(features, axis=2)

        knn = KNN(k=k, transpose_mode=True)

        # Find the k nearest neighbours and their distance
        dist, idx = knn(
            torch.from_numpy(features_extra_dim).cuda(),
            torch.from_numpy(features_extra_dim).clone().cuda())

        torch.cuda.empty_cache()

        del features_extra_dim

        idx = idx.cpu()
        dist = dist.cpu()

        # Clean up the indices and distances
        dist = dist.flatten()

        # Create tuples of indices where an edge exists
        rows = np.repeat(np.arange(features.shape[0]), k)
        columns = idx.flatten()
        non_zero_indices = tuple(np.stack((rows, columns)))

        del rows
        del columns
        del idx

        # Remove edges where the distance is higher than the threshold
        if threshold:
            indices_to_remove = dist > threshold
            indices_to_remove = np.where(indices_to_remove)
            non_zero_indices = tuple(
                np.delete(non_zero_indices, indices_to_remove, 1))
            dist = np.delete(dist, indices_to_remove[0], 0)

            del indices_to_remove

        if weighted:

            # Create zero matrix as the initial adjacency matrix
            adj_mat_weighted = np.zeros((features.shape[0], features.shape[0]),
                                        dtype=np.float32)

            # Fill in the adjacency matrix with node distances
            adj_mat_weighted[non_zero_indices] = dist

            non_zero_indices = np.nonzero(adj_mat_weighted)

            # Take reciprocal of non-zero elements to associate lower weight to higher distances
            adj_mat_weighted[
                non_zero_indices] = 1 / adj_mat_weighted[non_zero_indices]

            # Normalize rows
            coo_matrix = sp.coo_matrix(adj_mat_weighted)
            normalized_coo_matrix = normalize(coo_matrix)

            # DGL requires self loops
            normalized_coo_matrix = normalized_coo_matrix + sp.eye(
                normalized_coo_matrix.shape[0])

            t_end = datetime.now()
            logger.debug("it took {} to create the graph".format(t_end -
                                                                 t_start))

            return normalized_coo_matrix

        else:
            # Create zero matrix as the initial adjacency matrix
            adj_mat_binary = np.zeros((features.shape[0], features.shape[0]))

            # Create the binary adjacency matrix
            adj_mat_binary[non_zero_indices] = 1

            t_end = datetime.now()
            logger.debug("it took {} to create the graph".format(t_end -
                                                                 t_start))

            return sp.coo_matrix(adj_mat_binary)
    else:

        # initialize and fit nearest neighbour algorithm
        neigh = NearestNeighbors(n_neighbors=k,
                                 n_jobs=n_jobs,
                                 algorithm=algorithm)
        neigh.fit(features)

        # Obtain matrix with distance of k-nearest points
        adj_mat_weighted = np.array(
            sp.coo_matrix(neigh.kneighbors_graph(features, k,
                                                 mode='distance')).toarray())

        if threshold:
            indices_to_zero = adj_mat_weighted > threshold
            adj_mat_weighted[indices_to_zero] = 0

        non_zero_indices = np.nonzero(adj_mat_weighted)

        if weighted:
            # Take reciprocal of non-zero elements to associate lower weight to higher distances
            adj_mat_weighted[
                non_zero_indices] = 1 / adj_mat_weighted[non_zero_indices]

            # Normalize rows
            adj_mat_weighted = sp.coo_matrix(adj_mat_weighted)
            normalized_coo_matrix = normalize(adj_mat_weighted)

            # DGL requires self loops
            normalized_coo_matrix = normalized_coo_matrix + sp.eye(
                normalized_coo_matrix.shape[0])

            del adj_mat_weighted

            t_end = datetime.now()
            logger.debug("it took {} to create the graph".format(t_end -
                                                                 t_start))

            return normalized_coo_matrix

        # Obtain the binary adjacency matrix
        adj_mat_binary = adj_mat_weighted
        adj_mat_binary[non_zero_indices] = 1
        adj_mat_binary = adj_mat_binary + np.eye(adj_mat_binary.shape[0])

        del adj_mat_weighted

        t_end = datetime.now()
        logger.debug("it took {} to create the graph".format(t_end - t_start))

        return sp.coo_matrix(adj_mat_binary)
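A short usage sketch for the CPU branch (use_gpu=False) of create_knn_adj_mat; the feature matrix, k, and n_jobs here are made-up values for illustration:

import numpy as np

features = np.random.rand(100, 16).astype(np.float32)   # hypothetical: 100 nodes, 16-d features
adj = create_knn_adj_mat(features, k=5, weighted=True, n_jobs=2)
print(adj.shape)                                         # (100, 100) sparse matrix with self-loops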
Example #16
 def knn(ref, query, n, k) -> torch.Tensor:
       ref = ref.view(1,n,3)
       query = query.view(1,n,3)
       d, I = KNN(k, transpose_mode=True)(ref=ref, query=query)
       return d.view(n,k), I.view(n*k)
Example #17
 def knn_grad(ref, query, n, k) -> torch.Tensor: #NOTE output tensor shape [n,k,3]
     ref = ref.view(1,n,3)
     query = query.view(1,n,3)
     d, I = KNN(k=k, transpose_mode=True)(ref=ref, query=query)
     diff = query.view(n,1,3) - ref[0, I.view(-1),:].view(n,k,3) #shape [n,k,3]
     return diff.view(n,k,3), I
Example #18
 def _knn_point(self, nsample, ref, query):
     knn = KNN(k=nsample, transpose_mode=True)
     dist, indx = knn(ref, query)
     return dist, indx
Example #19
 def __init__(self, radius=1.0):
     super(Loss, self).__init__()
     self.radius = radius
     self.knn_uniform = KNN(k=2, transpose_mode=True)
     self.knn_repulsion = KNN(k=20, transpose_mode=True)
Example #20
 def __init__(self, sharpen=1.0):
     super().__init__()
     self.k = 50
     self.knn = KNN(self.k, transpose_mode=True)
     self.sharpen = sharpen
Example #21
    def forward(self, input_p, input_x):
        '''
        input_p:  B, 3, npoint
        input_x: B, in_dim, npoint
        '''

        B, in_dim, npoint = list(input_x.size())
        n_sample = self.n_sample
        k = min(n_sample, npoint)
        h = self.nhead

        res = input_x

        input_p = input_p.permute([0, 2, 1])

        # DEBUG: error here is that in the last block only 4-points;
        # however the knn still gives 16 idxs
        if npoint < self.n_sample:
            self.knn = KNN(k=npoint, transpose_mode=True)

        # DEBUG ONLY: using the input_x: feature space knn!
        # _, idx = self.knn(input_x.transpose(1,2), input_x.transpose(1,2))

        _, idx = self.knn(input_p.contiguous(), input_p)
        idx = idx.int()

        grouped_input_p = grouping_operation_cuda(
            input_p.transpose(1, 2).contiguous(), idx)  # [bs, xyz, npoint, k]

        if self.pre_ln:
            input_x = self.ln_top(input_x.transpose(1, 2)).transpose(1, 2)

        input_x = self.linear_top(input_x)

        # TODO: apply the layer-norm
        # however the original is [bs, dim, npoint]
        if self.pre_ln:
            input_x = self.ln_attn(input_x.transpose(1, 2)).transpose(1, 2)

        # grouped_input_x = index_points(input_x.permute([0,2,1]), idx.long()).permute([0,3,1,2])
        # grouped_input_x = grouping_operation_cuda(input_x.contiguous(), idx)  # [bs, xyz, npoint, K]
        phi = self.phi(input_x)
        phi = phi[:, :, :, None].repeat(1, 1, 1, k)
        psi = grouping_operation_cuda(self.psi(input_x).contiguous(), idx)
        alpha = grouping_operation_cuda(self.alpha(input_x).contiguous(),
                                        idx)  # [bs, xyz, npoint, k]

        relative_xyz = input_p.permute([0, 2, 1])[:, :, :,
                                                  None] - grouped_input_p
        pos_encoding = self.delta(relative_xyz)  # [bs, dims, npoint, k]

        if self.use_vector_attn:
            # the attn_map: [vector_dim];
            # the alpha:    [out_dim]
            attn_map = F.softmax(self.gamma(phi - psi + pos_encoding), dim=-1)
            y = attn_map.repeat(1, self.out_dim // self.vector_dim, 1,
                                1) * (alpha + pos_encoding)
            y = y.sum(dim=-1)
        else:
            phi = phi.reshape(B, h, self.out_dim // h, npoint, k)
            psi = psi.reshape(B, h, self.out_dim // h, npoint, k)
            attn_map = F.softmax(
                (phi * psi).reshape(B, self.out_dim, npoint, k) + pos_encoding,
                dim=-1)
            y = attn_map * (alpha + pos_encoding)
            y = y.sum(dim=-1)

        if self.pre_ln:
            y = self.ln_down(y.transpose(1, 2)).transpose(1, 2)

        y = self.linear_down(y)

        return y + res, attn_map.detach().cpu().data
Example #22
    def __init__(self, in_dim, is_firstlayer=False, n_sample=16):
        super().__init__()
        '''
        Point Transformer Layer

        in_dim: feature dimension of the input feature x
        out_dim: feature dimension of the Point Transformer Layer(currently same with hidden-dim)
        [?] - not sure how to set hidden. the paper only gives the out
        '''

        self.in_dim = in_dim
        self.is_firstlayer = is_firstlayer

        # TODO: set the hidden/vector/out_dims
        self.hidden_dim = in_dim
        # self.out_dim = min(4*in_dim, 512)
        self.out_dim = in_dim
        self.vector_dim = self.out_dim
        self.n_sample = n_sample

        # whether use BN or LN or None
        # 0 - None
        # 1 - BN
        # 2 - LN

        self.use_bn = 1
        # use transformer-like preLN before the attn & ff layer
        self.pre_ln = False

        # whether to use the vector att or the original attention
        self.use_vector_attn = True
        self.nhead = 4

        self.linear_top = nn.Sequential(
            nn.Conv1d(in_dim, self.hidden_dim, 1),
            # TransposeLayerNorm(self.hidden_dim),
            nn.BatchNorm1d(self.hidden_dim) if self.use_bn else nn.Identity())
        self.linear_down = nn.Sequential(
            nn.Conv1d(self.out_dim, self.in_dim, 1),
            # TransposeLayerNorm(self.in_dim),
            nn.BatchNorm1d(self.in_dim) if self.use_bn else nn.Identity())

        self.phi = nn.Sequential(
            nn.Conv1d(self.hidden_dim, self.out_dim, 1),
            # nn.BatchNorm1d(self.out_dim) if self.use_bn else nn.Identity()
        )
        self.psi = nn.Sequential(
            nn.Conv1d(self.hidden_dim, self.out_dim, 1),
            # nn.BatchNorm1d(self.out_dim) if self.use_bn else nn.Identity()
        )
        self.alpha = nn.Sequential(
            nn.Conv1d(self.hidden_dim, self.out_dim, 1),
            # nn.BatchNorm1d(self.out_dim) if self.use_bn else nn.Identity()
        )

        self.gamma = nn.Sequential(
            nn.Conv2d(self.out_dim, self.hidden_dim, 1),
            # TransposeLayerNorm(self.hidden_dim),
            nn.BatchNorm2d(self.hidden_dim) if self.use_bn else nn.Identity(),
            nn.ReLU(),
            nn.Conv2d(self.hidden_dim, self.vector_dim, 1),
            # TransposeLayerNorm(self.vector_dim),
            nn.BatchNorm2d(self.vector_dim) if self.use_bn else nn.Identity())

        self.delta = nn.Sequential(
            nn.Conv2d(3, self.hidden_dim, 1),
            # TransposeLayerNorm(self.hidden_dim),
            nn.BatchNorm2d(self.hidden_dim) if self.use_bn else nn.Identity(),
            nn.ReLU(),
            nn.Conv2d(self.hidden_dim, self.out_dim, 1),
            nn.BatchNorm2d(self.out_dim) if self.use_bn else nn.Identity()
            # TransposeLayerNorm(self.out_dim),
        )

        if self.pre_ln:
            self.ln_top = nn.LayerNorm(self.in_dim)
            self.ln_attn = nn.LayerNorm(self.hidden_dim)
            self.ln_down = nn.LayerNorm(self.out_dim)

        self.knn = KNN(k=n_sample, transpose_mode=True)
Example #23
    def forward(self, input_p, input_x):
        '''
        input_p:  B, 3, npoint
        input_x: B, in_dim, npoint
        '''
        B, in_dim, npoint = list(input_x.size())  # npoint: the input point-num
        n_sample = self.n_sample  # the knn-sample num cur block
        k = min(n_sample, npoint)  # denoting the num_point cur layer
        if not self.use_vector_attn:
            h = self.nhead  # only used in non-vector attn

        input_p = input_p.permute([0, 2, 1])  # [B, npoint, 3]
        self.register_buffer('in_xyz_map', input_p)

        if self.fps_rate is not None:
            npoint = npoint // self.fps_rate
            fps_idx = farthest_point_sample_cuda(input_p, npoint)
            torch.cuda.empty_cache()
            input_p_fps = index_points_cuda(input_p, fps_idx)  # [B, M, 3]
            if self.SKIP_ALL:
                input_p_reduced = input_p_fps.transpose(1, 2)
                input_x_reduced = index_points_cuda(
                    self.tmp_linear(input_x).transpose(1, 2),
                    fps_idx).transpose(1, 2)
                return input_p_reduced, input_x_reduced
        else:
            input_p_fps = input_p
            input_x_fps = input_x

        res = input_x  # [B, dim, M]

        if self.USE_KNN:
            self.knn = KNN(k=k, transpose_mode=True)
            _, idx = self.knn(input_p.contiguous(), input_p_fps.contiguous())
            idx = idx.int()  # [bs, npoint, k]
        else:
            idx = query_ball_point_cuda(
                self.radius, k, input_p.contiguous(),
                input_p_fps.contiguous())  # [bs, npoint, k]

        grouped_input_p = grouping_operation_cuda(
            input_p.transpose(1, 2).contiguous(), idx)  # [bs, xyz, npoint, k]
        grouped_input_x = grouping_operation_cuda(
            input_x.contiguous(), idx)  # [bs, hidden_dim, npoint, k]

        self.register_buffer('neighbor_map', idx)

        # TODO: define proper r for em
        # query_idx = query_ball_point_cuda(radius, k, coord, coord) # [bs, npoint, k]
        # self.knn = KNN(k=k, transpose_mode=True)
        # _, knn_idx = self.knn(input_p.contiguous(), input_p)
        # import ipdb; ipdb.set_trace()

        if self.fps_rate is not None:
            if self.SKIP_ATTN:
                pass  # only apply linear-top for ds blocks
            else:
                input_x = self.linear_top(input_x)
        else:
            if self.SKIP_ATTN:
                pass  # only apply linear-top for ds blocks
            else:
                input_x = self.linear_top(input_x)
        # input_x = self.linear_top(input_x)

        if self.SKIP_ATTN:
            # import ipdb; ipdb.set_trace()
            # out_dim should be the same with in_dim, since here contains no TD
            if self.POS_ENCODING:
                relative_xyz = input_p_fps.permute(
                    [0, 2, 1])[:, :, :, None] - grouped_input_p
                pos_encoding = self.delta(
                    relative_xyz)  # [bs, dims, npoint, k]
                if self.CAT_POS:
                    alpha = self.alpha(
                        torch.cat([grouped_input_x, relative_xyz], dim=1))
                else:  # use sum
                    alpha = self.alpha(grouped_input_x + pos_encoding)
            else:
                alpha = self.alpha(grouped_input_x)
                # alpha = grouping_operation_cuda(self.alpha(input_x).contiguous(), idx)

            y = alpha.max(dim=-1)[0]
            # y = alpha.sum(dim=-1)
            y = self.linear_down(y)

            if self.fps_rate is not None:
                input_p_reduced = input_p_fps.transpose(1, 2)
                # WRONG!: no need to apply fps_idx here
                # input_x_reduced = index_points_cuda(y.transpose(1,2), fps_idx).transpose(1,2)  # [B, dim, M]
                input_x_reduced = y
                return input_p_reduced, input_x_reduced
            else:
                input_p_reduced = input_p_fps.transpose(1, 2)
                input_x_reduced = y + res
                return input_p_reduced, input_x_reduced

        # when downsampling the TRBlock
        # should use downsampled qkv here, so use input_x_fps
        # as for normal block, input_x and input_x_fps are the same
        if self.fps_rate is not None:
            input_x_fps = index_points_cuda(input_x.transpose(
                1, 2), fps_idx).transpose(
                    1, 2)  # it is only used for tr-like downsample block
            phi = self.phi(input_x_fps)
        else:
            phi = self.phi(input_x)
        phi = phi[:, :, :, None].repeat(1, 1, 1, k)
        psi = grouping_operation_cuda(self.psi(input_x).contiguous(), idx)
        self.skip_knn = True
        alpha = grouping_operation_cuda(self.alpha(input_x).contiguous(),
                                        idx)  # [bs, xyz, npoint, k]

        if self.POS_ENCODING:
            relative_xyz = input_p_fps.permute(
                [0, 2, 1])[:, :, :, None] - grouped_input_p
            pos_encoding = self.delta(relative_xyz)  # [bs, dims, npoint, k]

        if self.use_vector_attn:
            # the attn_map: [vector_dim];
            # the alpha:    [out_dim]
            if self.POS_ENCODING:
                # if V_POS and QK_POS is both false, then apply all pos_encoding
                assert (
                    self.V_POS_ONLY and self.QK_POS_ONLY
                ) is False  # only one of the V_ONLY and QK_ONLY should be applied
                if self.V_POS_ONLY:
                    attn_map = F.softmax(self.gamma(phi - psi), dim=-1)
                else:
                    attn_map = F.softmax(self.gamma(phi - psi + pos_encoding),
                                         dim=-1)
                if self.QK_POS_ONLY:
                    y = attn_map.repeat(1, self.out_dim // self.vector_dim, 1,
                                        1) * (alpha)
                else:
                    y = attn_map.repeat(1, self.out_dim // self.vector_dim, 1,
                                        1) * (alpha + pos_encoding)
            else:
                attn_map = F.softmax(self.gamma(phi - psi), dim=-1)
                y = attn_map.repeat(1, self.out_dim // self.vector_dim, 1,
                                    1) * (alpha)
            if self.MAX_POOL:
                y = y.max(dim=-1)[0]
            else:
                y = y.sum(dim=-1)
        else:
            assert self.POS_ENCODING == True
            phi = phi.reshape(B, h, self.out_dim // h, npoint, k)
            psi = psi.reshape(B, h, self.out_dim // h, npoint, k)
            attn_map = F.softmax(
                (phi * psi).reshape(B, self.out_dim, npoint, k) + pos_encoding,
                dim=-1)
            y = attn_map * (alpha + pos_encoding)
            y = y.sum(dim=-1)

        self.register_buffer('attn_map', attn_map.mean(dim=1))

        y = self.linear_down(y)

        if self.fps_rate is not None:
            input_p_reduced = input_p_fps.transpose(1, 2)
            # input_x_reduced = index_points_cuda(y.transpose(1,2), fps_idx).transpose(1,2)  # [B, dim, M]
            input_x_reduced = y
            return input_p_reduced, input_x_reduced
        else:
            input_p_reduced = input_p_fps.transpose(1, 2)
            input_x_reduced = y + res
            return input_p_reduced, input_x_reduced
Example #24
    def forward(self, x: torch.Tensor):
        # x shape should be B x 3 x N
        if self.input_shape == "bnc":
            x = x.permute(0, 2, 1)

        if x.shape[1] != 3:
            raise RuntimeError(
                "shape of x must be of [Batch x 3 x NumInPoints]")

        y = F.relu(self.bn1(self.conv1(x)))
        y = F.relu(self.bn2(self.conv2(y)))
        y = F.relu(self.bn3(self.conv3(y)))
        y = F.relu(self.bn4(self.conv4(y)))
        y = F.relu(self.bn5(self.conv5(y)))  # Batch x 128 x NumInPoints

        # Max pooling for global feature vector:
        y = torch.max(y, 2)[0]  # Batch x 128

        y = F.relu(self.bn_fc1(self.fc1(y)))
        y = F.relu(self.bn_fc2(self.fc2(y)))
        y = F.relu(self.bn_fc3(self.fc3(y)))
        y = self.fc4(y)

        y = y.view(-1, 3, self.num_out_points)

        # Simplified points
        simp = y
        match = None
        proj = None

        # Projected points
        if self.training:
            if not self.skip_projection:
                proj = self.project(point_cloud=x, query_cloud=y)
            else:
                proj = simp

        # Matched points
        else:  # Inference
            # Retrieve nearest neighbor indices
            _, idx = KNN(1, transpose_mode=False)(x.contiguous(),
                                                  y.contiguous())
            """Notice that we detach the tensors and do computations in numpy,
            and then convert back to Tensors.
            This should have no effect as the network is in eval() mode
            and should require no gradients.
            """

            # Convert to numpy arrays in B x N x 3 format. we assume 'bcn' format.
            x = x.permute(0, 2, 1).cpu().detach().numpy()
            y = y.permute(0, 2, 1).cpu().detach().numpy()

            idx = idx.cpu().detach().numpy()
            idx = np.squeeze(idx, axis=1)

            z = sputils.nn_matching(x,
                                    idx,
                                    self.num_out_points,
                                    complete_fps=self.complete_fps)

            # Matched points are in B x N x 3 format.
            match = torch.tensor(z, dtype=torch.float32).cuda()

        # Change to output shapes
        if self.output_shape == "bnc":
            simp = simp.permute(0, 2, 1)
            if proj is not None:
                proj = proj.permute(0, 2, 1)
        elif self.output_shape == "bcn" and match is not None:
            match = match.permute(0, 2, 1)
            match = match.contiguous()

        # Assert contiguous tensors
        simp = simp.contiguous()
        if proj is not None:
            proj = proj.contiguous()
        if match is not None:
            match = match.contiguous()

        out = proj if self.training else match

        return simp, out
Example #25
def sample_and_group_cuda(npoint,
                          k,
                          xyz,
                          points,
                          instance=None,
                          instance_relation=None):
    """
    Input:
        npoint: seems 1/4 of N
        k:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, C, N]
        instance: input_instance, [B,N]
    Return:
        new_xyz: sampled points position data, [B, 3, npoint]
        new_points: sampled points data, [B, C+C_xyz, npoint, k]
        grouped_xyz_norm: sampled relative points position data, [B, 3, npoint, k]
        new_instance, [B, npoint]
    """
    k = min(npoint, k)
    knn = KNN(k=k, transpose_mode=True)

    B, N, C_xyz = xyz.shape

    if npoint < N:
        fps_idx = farthest_point_sample_cuda(xyz, npoint)  # [B, npoint]
        torch.cuda.empty_cache()
        new_xyz = index_points_cuda(xyz, fps_idx)  #[B, npoint, 3]
    else:
        new_xyz = xyz

    # unsqueeze to [B,N,1] then apply indexing
    if instance is not None:
        new_instance = index_points_cuda(
            instance.unsqueeze(-1).float(), fps_idx).squeeze(-1)
    else:
        pass

    torch.cuda.empty_cache()
    _, idx = knn(xyz.contiguous(), new_xyz)  # B, npoint, k
    idx = idx.int()

    torch.cuda.empty_cache()
    grouped_xyz = grouping_operation_cuda(
        xyz.transpose(1, 2).contiguous(),
        idx).permute(0, 2, 3, 1)  # [B, npoint, k, C_xyz]
    #print(grouped_xyz.size())
    torch.cuda.empty_cache()
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, npoint, 1,
                                                  C_xyz)  # [B, npoint, k, 3]
    grouped_xyz_norm = grouped_xyz_norm.permute(
        0, 3, 1, 2).contiguous()  # [B, 3, npoint, k]
    torch.cuda.empty_cache()

    grouped_points = grouping_operation_cuda(points.contiguous(),
                                             idx)  #B, C, npoint, k

    new_points = torch.cat([grouped_xyz_norm, grouped_points],
                           dim=1)  # [B, C+C_xyz, npoint, k]

    if instance is not None:
        return new_xyz.transpose(1,
                                 2), grouped_xyz_norm, new_points, new_instance
    else:
        return new_xyz.transpose(1, 2), grouped_xyz_norm, new_points
Example #26
def knn_point(group_size, point_cloud, query_cloud):
    knn_obj = KNN(k=group_size, transpose_mode=False)
    dist, idx = knn_obj(point_cloud, query_cloud)
    return dist, idx
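A brief usage sketch; with transpose_mode=False the inputs are assumed to be channel-first (B, 3, N) tensors on the GPU, and group_size neighbours from point_cloud are gathered for every query point:

point_cloud = torch.rand(4, 3, 2048).cuda()           # reference cloud, hypothetical sizes
query_cloud = torch.rand(4, 3, 512).cuda()            # query cloud
dist, idx = knn_point(32, point_cloud, query_cloud)   # assumed output layout: (B, k, N_query)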
Example #27
 def __init__(self, num):
     super(RepulsionLoss, self).__init__()
     self.k = num
     self.knn_repulsion = KNN(k=self.k, transpose_mode=True)
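For context, a hedged sketch of how such a repulsion term is commonly computed from the k nearest neighbours of an upsampled cloud; the forward signature, the radius h, and the clamp-style penalty are assumptions for illustration, not this repository's exact formula:

 def forward(self, pred, h=0.03):
     # pred: (B, N, 3) upsampled points; knn of the cloud against itself includes a zero self-distance
     dist, _ = self.knn_repulsion(pred.contiguous(), pred.contiguous())   # (B, N, k)
     dist = dist[:, :, 1:]                            # drop the self-match
     loss = torch.clamp(h - dist, min=0.0) ** 2       # penalize neighbours closer than h
     return loss.mean()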