def sample_and_group_cuda(npoint,
                          k,
                          xyz,
                          points,
                          cat_xyz_feature=True,
                          fps_only=False):
    """
    Input:
        npoint:
        k:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, C, N]
    Return:
        new_xyz: sampled points position data, [B, 3, npoint]
        new_points: sampled points data, [B, C+C_xyz, npoint, k]
        grouped_xyz_norm: sampled relative points position data, [B, 3, npoint, k]
    """
    k = min(npoint, k)
    knn = KNN(k=k, transpose_mode=True)

    B, N, C_xyz = xyz.shape

    if npoint < N:
        # fps_idx = torch.arange(npoint).repeat(xyz.shape[0], 1).int().cuda() # DEBUG ONLY
        fps_idx = farthest_point_sample_cuda(xyz, npoint)  # [B, npoint]
        torch.cuda.empty_cache()
        new_xyz = index_points_cuda(xyz, fps_idx)  #[B, npoint, 3]
        new_points = index_points_cuda(points.transpose(1, 2), fps_idx)
    else:
        # npoint >= N: keep all points; identity indices keep the
        # fps_only return below well-defined
        new_xyz = xyz
        new_points = points.transpose(1, 2)
        fps_idx = torch.arange(N, device=xyz.device).repeat(B, 1).int()

    if fps_only:
        return new_xyz.transpose(1, 2), new_points.transpose(1, 2), fps_idx

    torch.cuda.empty_cache()
    _, idx = knn(xyz.contiguous(), new_xyz)  # B, npoint, k
    idx = idx.int()

    torch.cuda.empty_cache()
    grouped_xyz = grouping_operation_cuda(
        xyz.transpose(1, 2).contiguous(),
        idx).permute(0, 2, 3, 1)  # [B, npoint, k, C_xyz]
    torch.cuda.empty_cache()
    grouped_xyz_norm = grouped_xyz - new_xyz.view(
        B, min(npoint, N), 1, C_xyz)  # [B, npoint, k, 3]
    grouped_xyz_norm = grouped_xyz_norm.permute(
        0, 3, 1, 2).contiguous()  # [B, 3, npoint, k]
    torch.cuda.empty_cache()

    grouped_points = grouping_operation_cuda(points.contiguous(),
                                             idx)  #B, C, npoint, k

    if cat_xyz_feature:
        new_points = torch.cat([grouped_xyz_norm, grouped_points],
                               dim=1)  # [B, C+C_xyz, npoint, k]
    else:
        new_points = grouped_points  # [B, C+C_xyz, npoint, k]

    return new_xyz.transpose(1, 2), grouped_xyz_norm, new_points, idx
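For reference, a minimal usage sketch of the helper above; it assumes the pointnet2-style CUDA ops and KNN_CUDA are built and a GPU is available, and shapes follow the docstring.

import torch

B, N, C = 2, 1024, 64                  # batch size, input points, feature channels
xyz = torch.rand(B, N, 3).cuda()       # point coordinates, [B, N, 3]
points = torch.rand(B, C, N).cuda()    # point features,    [B, C, N]

new_xyz, grouped_xyz_norm, new_points, idx = sample_and_group_cuda(
    npoint=256, k=16, xyz=xyz, points=points)
print(new_xyz.shape)            # torch.Size([2, 3, 256])
print(grouped_xyz_norm.shape)   # torch.Size([2, 3, 256, 16])
print(new_points.shape)         # torch.Size([2, 67, 256, 16]), i.e. C + 3 channels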
Example #2
def sample_and_group_cuda(npoint, k, xyz, points, cat_xyz_feature=True):
    """
    Input:
        npoint:
        k:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, C, N]
    Return:
        new_xyz: sampled points position data, [B, 3, npoint]
        grouped_xyz_norm: sampled relative points position data, [B, 3, npoint, k]
        new_points: sampled points data, [B, C+C_xyz, npoint, k]
    """
    k = min(npoint, k)
    knn = KNN(k=k, transpose_mode=True)

    B, N, C_xyz = xyz.shape

    if npoint < N:
        fps_idx = farthest_point_sample_cuda(xyz, npoint)  # [B, npoint]
        torch.cuda.empty_cache()
        new_xyz = index_points_cuda(xyz, fps_idx)  #[B, npoint, 3]
    else:
        new_xyz = xyz

    torch.cuda.empty_cache()
    _, idx = knn(xyz.contiguous(), new_xyz)  # B, npoint, k
    idx = idx.int()

    torch.cuda.empty_cache()
    grouped_xyz = grouping_operation_cuda(
        xyz.transpose(1, 2).contiguous(),
        idx).permute(0, 2, 3, 1)  # [B, npoint, k, C_xyz]
    torch.cuda.empty_cache()
    # when using the mixed-trans, some last voxels may have fewer points,
    # so cap the number of centroids at N
    grouped_xyz_norm = grouped_xyz - new_xyz.view(
        -1, min(npoint, N), 1, C_xyz)  # [B, npoint, k, 3]
    grouped_xyz_norm = grouped_xyz_norm.permute(
        0, 3, 1, 2).contiguous()  # [B, 3, npoint, k]
    torch.cuda.empty_cache()

    grouped_points = grouping_operation_cuda(points.contiguous(),
                                             idx)  #B, C, npoint, k

    if cat_xyz_feature:
        new_points = torch.cat([grouped_xyz_norm, grouped_points],
                               dim=1)  # [B, C+C_xyz, npoint, k]
    else:
        new_points = grouped_points  # [B, C+C_xyz, npoint, k]

    return new_xyz.transpose(1, 2), grouped_xyz_norm, new_points
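A pure-PyTorch reference for farthest point sampling can help sanity-check farthest_point_sample_cuda on small inputs. This is a sketch (O(npoint * N) per batch, first centroid fixed at index 0), not the repository's implementation.

import torch

def farthest_point_sample_torch(xyz, npoint):
    # xyz: [B, N, 3] -> idx: [B, npoint]
    B, N, _ = xyz.shape
    idx = torch.zeros(B, npoint, dtype=torch.long, device=xyz.device)
    dist = torch.full((B, N), float('inf'), device=xyz.device)
    farthest = torch.zeros(B, dtype=torch.long, device=xyz.device)
    batch = torch.arange(B, device=xyz.device)
    for i in range(npoint):
        idx[:, i] = farthest
        centroid = xyz[batch, farthest].unsqueeze(1)            # [B, 1, 3]
        dist = torch.minimum(dist, ((xyz - centroid) ** 2).sum(-1))
        farthest = dist.argmax(-1)                              # farthest-so-far point
    return idx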
Example #3
def stem_knn(xyz, points, k):
    knn = KNN(k=k, transpose_mode=True)
    xyz = xyz.permute([0, 2, 1])
    _, idx = knn(xyz.contiguous(),
                 xyz)  # xyz: [bs, npoints, coord] idx: [bs, npoint, k]
    idx = idx.int()

    # take in [B, 3, N]
    grouped_xyz = grouping_operation_cuda(
        xyz.transpose(1, 2).contiguous(), idx)  # [bs, xyz, n_point, k]
    grouped_points = grouping_operation_cuda(points.contiguous(),
                                             idx)  # [B, C, npoint, k]

    return grouped_xyz, grouped_points
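For readers without the CUDA extensions, a CPU-friendly equivalent of the KNN-plus-grouping pair used in stem_knn is sketched below under the same layout conventions; the names are illustrative, not the original API.

import torch

def stem_knn_torch(xyz, points, k):
    # xyz: [B, 3, N], points: [B, C, N]
    xyz = xyz.permute(0, 2, 1)                                  # [B, N, 3]
    idx = torch.cdist(xyz, xyz).topk(k, dim=-1, largest=False).indices  # [B, N, k]
    b = torch.arange(xyz.shape[0], device=xyz.device).view(-1, 1, 1)
    grouped_xyz = xyz[b, idx].permute(0, 3, 1, 2)               # [B, 3, N, k]
    grouped_points = points.permute(0, 2, 1)[b, idx].permute(0, 3, 1, 2)  # [B, C, N, k]
    return grouped_xyz, grouped_points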
Example #4
    def forward(self, xyz_1, xyz_2, points_1, points_2):
        """
        Input:
            M < N
            xyz_1: input points position data, [B, 3, M]
            xyz_2: input points position data, [B, 3, N]
            points_1: input points data, [B, C, M]
            points_2: input points data, [B, C, N]

            interpolate features at xyz_2's positions from xyz_1's knn neighbors, weighted by inverse distance

        Return:
            new_xyz: points position data (xyz_2 passed through), [B, 3, N]
            new_points: interpolated point features, [B, D, N]
        """

        B, input_dim, M = list(points_1.size())
        B, output_dim, N = list(points_2.size())

        points_1 = self.linear_1(points_1)
        points_2 = self.linear_2(points_2)

        dists = square_distance(xyz_2.transpose(1, 2),
                                xyz_1.transpose(1, 2))  # [B, N, M]
        dists, idx = dists.sort(dim=-1)
        dists, idx = dists[:, :, :self.k], idx[:, :, :self.k]

        dist_recip = 1.0 / (dists + 1e-8)
        norm = torch.sum(dist_recip, dim=2, keepdim=True)
        weight = dist_recip / norm
        interpolated_points = torch.sum(
            grouping_operation_cuda(points_1, idx.int()) *
            weight.view(B, 1, N, self.k),
            dim=-1)  # [B, D, N]

        return xyz_2, (interpolated_points + points_2)
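The interpolation above weights each of the k nearest source features by inverse squared distance. A standalone pure-PyTorch sketch of the same computation (illustrative names; k=3 by default, as in three-NN upsampling):

import torch

def idw_interpolate(xyz_src, xyz_dst, feat_src, k=3, eps=1e-8):
    # xyz_src: [B, M, 3], xyz_dst: [B, N, 3], feat_src: [B, C, M] -> [B, C, N]
    sq_dists = torch.cdist(xyz_dst, xyz_src) ** 2            # [B, N, M]
    sq_dists, idx = sq_dists.topk(k, dim=-1, largest=False)  # k nearest sources
    weight = 1.0 / (sq_dists + eps)
    weight = weight / weight.sum(dim=-1, keepdim=True)       # [B, N, k]
    b = torch.arange(xyz_src.shape[0], device=feat_src.device).view(-1, 1, 1)
    neigh = feat_src.permute(0, 2, 1)[b, idx]                # [B, N, k, C]
    return (neigh * weight.unsqueeze(-1)).sum(dim=2).permute(0, 2, 1)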
Example #5
def sample_and_group_cuda(npoint, k, xyz, points):
    """
    Input:
        npoint:
        k:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, C, N]
    Return:
        new_xyz: sampled points position data, [B, 3, npoint]
        grouped_xyz_norm: sampled relative points position data, [B, 3, npoint, k]
        new_points: sampled points data, [B, C+C_xyz, npoint, k]
    """
    k = min(npoint, k)
    knn = KNN(k=k, transpose_mode=True)

    B, N, C_xyz = xyz.shape

    if npoint < N:
        fps_idx = farthest_point_sample_cuda(xyz, npoint)  # [B, npoint]
        torch.cuda.empty_cache()
        new_xyz = index_points_cuda(xyz, fps_idx)  #[B, npoint, 3]
    else:
        new_xyz = xyz

    torch.cuda.empty_cache()
    _, idx = knn(xyz.contiguous(), new_xyz)  # B, npoint, k
    idx = idx.int()

    torch.cuda.empty_cache()
    grouped_xyz = grouping_operation_cuda(
        xyz.transpose(1, 2).contiguous(),
        idx).permute(0, 2, 3, 1)  # [B, npoint, k, C_xyz]
    torch.cuda.empty_cache()
    grouped_xyz_norm = grouped_xyz - new_xyz.view(
        B, min(npoint, N), 1, C_xyz)  # [B, npoint, k, 3]
    grouped_xyz_norm = grouped_xyz_norm.permute(
        0, 3, 1, 2).contiguous()  # [B, 3, npoint, k]
    torch.cuda.empty_cache()

    grouped_points = grouping_operation_cuda(points.contiguous(),
                                             idx)  #B, C, npoint, k

    new_points = torch.cat([grouped_xyz_norm, grouped_points],
                           dim=1)  # [B, C+C_xyz, npoint, k]

    return new_xyz.transpose(1, 2), grouped_xyz_norm, new_points
Example #6
    def get_neighbor(self, ref: ME.SparseTensor, query: ME.SparseTensor):
        B_nq, _ = query.C.shape

        coord = query.C  # (N, 4)
        batch_info = coord[:, 0]
        coord, mask, idx_ = separate_batch(coord)  # (b, n, 3)
        b, n, _ = coord.shape

        if self.use_knn:
            _, idx = self.knn(coord.contiguous(), coord)
            grouped_coord = grouping_operation_cuda(
                coord.float().transpose(1, 2).contiguous(), idx.int())
            result_padded = grouped_coord.permute([0, 2, 3, 1])
        else:
            query_and_group_cuda = QueryAndGroup(radius=self.r,
                                                 nsample=self.k,
                                                 use_xyz=False)
            coord = coord.float()

            idxs = query_and_group_cuda(
                xyz=coord,
                new_xyz=coord,
                features=coord.transpose(1, 2).contiguous(),
            )  # idx: [bs, xyz, npoint, nsample]
            idxs = idxs.permute([0, 2, 3,
                                 1])  # idx: [bs, npoint, nsample, xyz]
            result_padded = idxs

        # unpad result (b, n, k, 3) -> (B_nq, k, 4) by applying mask
        result = torch.zeros([B_nq, self.k, 4],
                             dtype=torch.int32,
                             device=query.device)
        result[:, :,
               1:] = torch.gather(result_padded.reshape(-1, self.k, 3), 0,
                                  idx_.reshape(-1, 1, 1).repeat(1, self.k, 3))
        result[:, :, 0] = batch_info.unsqueeze(-1).repeat(1, self.k)

        return result, mask, idx_
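separate_batch is not shown in this snippet. Below is a plausible sketch of its contract, inferred from the gather above (an assumption, not the repository's implementation): split the MinkowskiEngine coordinate tensor C = [batch_idx, x, y, z] of shape (N, 4) into a zero-padded (b, n_max, 3) tensor, a validity mask, and the flat indices used to undo the padding.

import torch

def separate_batch_sketch(coord):
    # coord: (N, 4) with columns [batch_idx, x, y, z]
    batch = coord[:, 0].long()
    b = int(batch.max()) + 1
    n_max = int(torch.bincount(batch, minlength=b).max())
    out = coord.new_zeros(b, n_max, 3)
    mask = torch.zeros(b, n_max, dtype=torch.bool, device=coord.device)
    idx_ = torch.empty(coord.shape[0], dtype=torch.long, device=coord.device)
    for i in range(b):
        sel = torch.where(batch == i)[0]
        out[i, :len(sel)] = coord[sel, 1:]                  # strip the batch column
        mask[i, :len(sel)] = True
        idx_[sel] = i * n_max + torch.arange(len(sel), device=coord.device)
    return out, mask, idx_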
Example #7
def sample_and_group_cuda(npoint,
                          k,
                          xyz,
                          points,
                          instance=None,
                          instance_relation=None):
    """
    Input:
        npoint: number of sampled centroids (typically N/4)
        k:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, C, N]
        instance: input_instance, [B,N]
    Return:
        new_xyz: sampled points position data, [B, 3, npoint]
        grouped_xyz_norm: sampled relative points position data, [B, 3, npoint, k]
        new_points: sampled points data, [B, C+C_xyz, npoint, k]
        new_instance: sampled instance labels, [B, npoint]
    """
    k = min(npoint, k)
    knn = KNN(k=k, transpose_mode=True)

    B, N, C_xyz = xyz.shape

    if npoint < N:
        fps_idx = farthest_point_sample_cuda(xyz, npoint)  # [B, npoint]
        torch.cuda.empty_cache()
        new_xyz = index_points_cuda(xyz, fps_idx)  #[B, npoint, 3]
    else:
        new_xyz = xyz

    # unsqueeze to [B, N, 1], index with fps_idx, then squeeze back
    if instance is not None:
        new_instance = index_points_cuda(
            instance.unsqueeze(-1).float(), fps_idx).squeeze(-1)

    torch.cuda.empty_cache()
    _, idx = knn(xyz.contiguous(), new_xyz)  # B, npoint, k
    idx = idx.int()

    torch.cuda.empty_cache()
    grouped_xyz = grouping_operation_cuda(
        xyz.transpose(1, 2).contiguous(),
        idx).permute(0, 2, 3, 1)  # [B, npoint, k, C_xyz]
    torch.cuda.empty_cache()
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, npoint, 1,
                                                  C_xyz)  # [B, npoint, k, 3]
    grouped_xyz_norm = grouped_xyz_norm.permute(
        0, 3, 1, 2).contiguous()  # [B, 3, npoint, k]
    torch.cuda.empty_cache()

    grouped_points = grouping_operation_cuda(points.contiguous(),
                                             idx)  #B, C, npoint, k

    new_points = torch.cat([grouped_xyz_norm, grouped_points],
                           dim=1)  # [B, C+C_xyz, npoint, k]

    if instance is not None:
        return new_xyz.transpose(1,
                                 2), grouped_xyz_norm, new_points, new_instance
    else:
        return new_xyz.transpose(1, 2), grouped_xyz_norm, new_points
Example #8
    def forward(self, input_p, input_x, instance=None, instance_relation=None):
        '''
        input_p:  B, 3, npoint
        input_x: B, in_dim, npoint
        '''
        '''
        How to use the instance information:
        1. use it as guidance for the attention: mask out knn points with a different instance label
        2. directly pick random points with the same instance label as the attention receptive field
        3. attend to the instance center
        '''
        INSTANCE_SCHEME = 3

        B, in_dim, npoint = list(input_x.size())
        n_sample = self.n_sample
        k = min(n_sample, npoint)
        h = self.nhead

        res = input_x

        input_p = input_p.permute([0, 2, 1])
        ori_input_p = input_p

        if instance is not None and INSTANCE_SCHEME == 1:
            # knn more points for sampling
            knn_sample_more_ratio = 2
            enlarged_k = k * knn_sample_more_ratio
            self.knn = KNN(k=enlarged_k, transpose_mode=True)
        else:
            self.knn = KNN(k=k, transpose_mode=True)

        # NOTE: in the last block there may be fewer points (e.g. 4) than
        # n_sample (e.g. 16), yet knn would still return n_sample idxs;
        # hence k = min(n_sample, npoint) above.

        # DEBUG ONLY: using the input_x: feature space knn!
        # _, idx = self.knn(input_x.transpose(1,2), input_x.transpose(1,2))

        if instance is not None:

            if INSTANCE_SCHEME == 3:
                '''
                Ver3.0: use the current instance center as the knn center:
                compute each instance centroid, and mix the current coordinate
                with the instance centroid by weighting.
                ISSUES:
                    - all points of the same instance may share the same idxes, namely the N points closest to the centroid
                    - using a weighted mix of centroid and coordinates requires a full per-point KNN, which is slow
                '''
                ori_input_p = input_p.clone()
                # where = torch.where(instance[0] == 1)
                # instance_xyz = input_p[:,where,:].mean(dim=1)   # get [bs, 3] centroid for cur instance
                for i_bs in range(instance.shape[0]):
                    for v in torch.unique(instance[i_bs]):
                        tmp_idx = torch.where(instance[i_bs] == v)[0]
                        ins_center = input_p[:, tmp_idx, :].mean(
                            dim=1)  # the centroid of the current instance
                        # average cur point and the instance center
                        alpha = 0.999
                        input_p[:,
                                tmp_idx, :] = alpha * input_p[:, tmp_idx, :] + (
                                    1 - alpha) * ins_center.unsqueeze(
                                        1)  # [bs, n_cur_ins, 3] + [bs, 1, 3]

                _, idx = self.knn(ori_input_p.contiguous(), ori_input_p)
                _, idx2 = self.knn(ori_input_p.contiguous(), input_p)
                # fraction of neighbor idxs unchanged by the centroid mixing
                print((idx == idx2).int().sum() / idx.nelement())
            else:
                _, idx = self.knn(input_p.contiguous(), input_p)
        else:
            _, idx = self.knn(input_p.contiguous(), input_p)

        idx = idx.int()

        if INSTANCE_SCHEME == 1:
            '''
            Ver1.0 (naive version): mask the knn, with the instance label as an auxiliary filter.
            Directly check whether each knn-grouped point lies in the same
            instance as the query point, and mask it out if not.
            '''

            if instance is not None:
                # print('start processing the instance mask')
                masks = []
                for i_bs, idx_cur_bs in enumerate(idx):
                    # [4096, 16] cur_bs_idx
                    # [4096]: instance_label[i_bs]
                    mask = instance[i_bs][idx_cur_bs.long()]  # [4096, 2*k]
                    mask = mask - mask[:, 0].unsqueeze(
                        -1
                    )  # subtract the 1st column (the 1st of the k elements is the point itself)
                    mask = (mask == 0).int()  # acquire the 0-1 mask
                    masks.append(mask)
                masks = torch.stack(masks)
                print("mask ratio {:.4f}".format(
                    masks.sum() / masks.nelement()))  # >0.5 means ok
                '''
                Generate a bigger knn-idx and mask, then choose the first n_sample (16) elements:
                randomly sample other points from the remainder, and use the mask to fill in the zero entries.

                To get the first k idxes whose mask is not 0:
                since the mask values are all 0-1, argsort returns a vector,
                but we want smaller idxes at the front,
                so we assign the 0 elements a large value to push them to the back
                (with descending=True, the biggest idx with 1 would come first).
                '''

                inv_masks = (masks == 0).int()
                tmp_inds = torch.arange(masks.shape[2]).repeat(
                    masks.shape[0], masks.shape[1],
                    1).to(idx.device)  # generate the [1,2,...,enlarged_k] inds
                tmp_inds = tmp_inds * masks
                tmp_inds = tmp_inds + (
                    masks.shape[2] + 1
                ) * inv_masks  # fill the places of 0 with value bigger than the maximum value
                tmp_inds = torch.argsort(
                    tmp_inds
                )[:, :, :
                  k]  # after argsort, the former elements should be non-zero value with smaller idx
                idx = torch.gather(idx, -1, tmp_inds)
                idx = idx.int()

                # TODO: if the enlarged knn still lacks enough same-instance elements,
                # the argsort result falls back to the closest knn points rather than same-instance ones

        elif INSTANCE_SCHEME == 2:
            '''
            # Ver2.0: directly use the points of the same instance label as neighbors
            # random sample k points in the same instance
            '''
            if instance is not None:
                instance_relations = []
                for i_bs in range(instance.shape[0]):
                    instance_inds = [
                        torch.where(instance[i_bs] == v)[0]
                        for v in torch.unique(instance[i_bs])
                    ]  # torch.where returns a tuple, so use [0] to get the tensor
                    instance_relation = torch.full([instance[0].shape[0], k],
                                                   -1).to(instance.device)
                    for i, ins_id in enumerate(instance_inds):
                        # NOTE: PyTorch has no direct analogue of random.choice,
                        # so sample permutations via torch.multinomial instead
                        if len(ins_id) <= 5:  # skip small outlier clusters
                            continue
                        perms = torch.multinomial(
                            ins_id.repeat(len(ins_id), 1).float(),
                            num_samples=min(k, len(ins_id)),
                            replacement=False)
                        choices = ins_id[perms]
                        instance_relation[
                            instance_inds[i], :choices.shape[1]] = choices
                    instance_relation[:, 0] = torch.arange(
                        instance_relation.shape[0])
                    instance_relations.append(instance_relation)
                instance_relations = torch.stack(instance_relations)

                # print('replacing the instance_relation')
                instance_relation_nonzero_mask = (instance_relations >=
                                                  0).int()
                instance_relation_zero_mask = (instance_relations < 0).int()

                idx = idx * instance_relation_zero_mask + instance_relations * instance_relation_nonzero_mask
                idx = idx.int()

        # ===================== Deprecated Methods ===========================
        '''
        # Ver 2.3: failed version that receives an instance_relation;
        # it could not handle point downsampling.
        # The instance fed in here has the same size as the idxes.
        # If the number of points within the same instance group is less than k,
        # the instance_relation will contain -1s; we replace these -1s
        # with the original idx acquired by knn.
        if instance_relation is not None:
            print('replacing the instance_relation')
            # import ipdb; ipdb.set_trace()

            instance_relation = instance_relation[:,:,:k]
            instance_relation_nonzero_mask = (instance_relation>=0).int()
            instance_relation_zero_mask = (instance_relation<0).int()

            # idx = idx*instance_relation_zero_mask + instance_relation*instance_relation_nonzero_mask
            idx = instance_relation.int()

        '''
        '''
        Ver 2.2: hash-table based.
        First pack the instances into a dict (hash table),
        then reference points within the same instance to replace the knn points;
        if there are not enough points of the same instance, keep the knn idxs.
        '''
        '''
        # pack the instance into a dict for further reference
        if instance is not None:
            print('start creating instance dicts')
            instance_dicts = []
            for i_bs, instance_cur_bs in enumerate(instance):
                instance_dict = {}
                for ins_idx, ins in enumerate(instance_cur_bs):
                    if ins.item() in instance_dict.keys():
                        instance_dict[ins.item()].append(ins_idx)
                    else:
                        instance_dict[ins.item()] = [ins_idx]
                for ins_k in instance_dict.keys():
                    instance_dict[ins_k] = torch.tensor(instance_dict[ins_k]).to(instance.device)
                instance_dicts.append(instance_dict)

            l1 = []
            for i_bs in range(instance.shape[0]):
                l0 = []
                for i_point in range(instance.shape[1]):
                    tmp = torch.zeros([k])
                    instance_gathered  = instance_dicts[i_bs][instance[i_bs][i_point].item()][:k]
                    tmp[:len(instance_gathered)] = instance_gathered
                    # idx[i_bs][i_point][:len(instance_gathered)] = instance_gathered
                    l0.append(tmp)
                tmp1 = torch.stack(l0)
                l1.append(tmp1)
            new_idx = torch.stack(l1)
        '''
        '''
        Ver 2.1: naive for-loop replacement.
        # Too slow, needs improving:
        # first run knn, then mask the values that do not belong to the same instance
        instance_masks = []
        for i_batch, single_batch_instance in enumerate(instance):
            # single_batch_instance: [npoint]
            masks_cur_batch = []
            for i_point, gathered_points in enumerate(idx[i_batch]):
                # gathered_points: [k]
                points_with_same_instance = torch.where(single_batch_instance == single_batch_instance[i_point])[0]
            # check if the grouped idxes belong to the same instance
                cur_mask = torch.tensor([g.item() in points_with_same_instance for g in gathered_points])
                masks_cur_batch.append(cur_mask)
            masks_cur_batch = torch.stack(masks_cur_batch)
            instance_masks.append(masks_cur_batch)
        instance_masks = torch.stack(instance_masks)
        '''

        # ==========================================================================================

        grouped_input_p = grouping_operation_cuda(
            input_p.transpose(1, 2).contiguous(), idx)  # [bs, xyz, npoint, k]

        if self.pre_ln:
            input_x = self.ln_top(input_x.transpose(1, 2)).transpose(1, 2)

        input_x = self.linear_top(input_x)

        # TODO: apply the layer-norm; note the original layout is [bs, dim, npoint]
        if self.pre_ln:
            input_x = self.ln_attn(input_x.transpose(1, 2)).transpose(1, 2)

        # grouped_input_x = index_points(input_x.permute([0,2,1]), idx.long()).permute([0,3,1,2])
        # grouped_input_x = grouping_operation_cuda(input_x.contiguous(), idx)  # [bs, xyz, npoint, K]
        phi = self.phi(input_x)
        phi = phi[:, :, :, None].repeat(1, 1, 1, k)
        psi = grouping_operation_cuda(self.psi(input_x).contiguous(), idx)
        alpha = grouping_operation_cuda(self.alpha(input_x).contiguous(),
                                        idx)  # [bs, xyz, npoint, k]

        relative_xyz = input_p.permute([0, 2, 1])[:, :, :,
                                                  None] - grouped_input_p
        pos_encoding = self.delta(relative_xyz)  # [bs, dims, npoint, k]

        if self.use_vector_attn:
            # the attn_map: [vector_dim];
            # the alpha:    [out_dim]
            attn_map = F.softmax(self.gamma(phi - psi + pos_encoding),
                                 dim=-1)  # [B, Dim, N, k]
            # if instance is not None: # apply mask
            # attn_map = attn_map*(masks.unsqueeze(1))
            y = attn_map.repeat(1, self.out_dim // self.vector_dim, 1,
                                1) * (alpha + pos_encoding)
            y = y.sum(dim=-1)
        else:
            phi = phi.reshape(B, h, self.out_dim // h, npoint, k)
            psi = psi.reshape(B, h, self.out_dim // h, npoint, k)
            attn_map = F.softmax(
                (phi * psi).reshape(B, self.out_dim, npoint, k) + pos_encoding,
                dim=-1)
            y = attn_map * (alpha + pos_encoding)
            y = y.sum(dim=-1)

        if self.pre_ln:
            y = self.ln_down(y.transpose(1, 2)).transpose(1, 2)

        y = self.linear_down(y)

        return y + res, attn_map.detach().cpu().data
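A toy demonstration (not from the original source) of the argsort trick INSTANCE_SCHEME 1 uses above: from an enlarged neighbor list, stably pull the first k indices whose instance mask is 1 to the front.

import torch

B, N, big_k, k = 1, 4, 8, 3
idx = torch.arange(big_k).repeat(B, N, 1)       # stand-in for the enlarged knn idx
mask = (torch.rand(B, N, big_k) > 0.5).int()    # 1 = same instance as the query
mask[:, :, 0] = 1                               # the first neighbor is the point itself

pos = torch.arange(big_k).expand_as(mask)       # positions 0..big_k-1
key = pos * mask + (big_k + 1) * (1 - mask)     # masked-out slots get a large key
order = torch.argsort(key, dim=-1)[..., :k]     # valid, nearby positions come first
picked = torch.gather(idx, -1, order)           # [B, N, k] filtered neighbor idx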
Example #9
    def forward(self, input_p, input_x):
        '''
        input_p:  B, 3, npoint
        input_x: B, in_dim, npoint
        '''
        B, in_dim, npoint = list(input_x.size())  # npoint: number of input points
        n_sample = self.n_sample  # knn sample size for this block
        k = min(n_sample, npoint)  # neighbors per point in this layer
        if not self.use_vector_attn:
            h = self.nhead  # only used in non-vector attn

        input_p = input_p.permute([0, 2, 1])  # [B, npoint, 3]
        self.register_buffer('in_xyz_map', input_p)

        if self.fps_rate is not None:
            npoint = npoint // self.fps_rate
            fps_idx = farthest_point_sample_cuda(input_p, npoint)
            torch.cuda.empty_cache()
            input_p_fps = index_points_cuda(input_p, fps_idx)  # [B. M, 3]
            if self.SKIP_ALL:
                input_p_reduced = input_p_fps.transpose(1, 2)
                input_x_reduced = index_points_cuda(
                    self.tmp_linear(input_x).transpose(1, 2),
                    fps_idx).transpose(1, 2)
                return input_p_reduced, input_x_reduced
        else:
            input_p_fps = input_p
            input_x_fps = input_x

        res = input_x  # [B, dim, M]

        if self.USE_KNN:
            self.knn = KNN(k=k, transpose_mode=True)
            _, idx = self.knn(input_p.contiguous(), input_p_fps.contiguous())
            idx = idx.int()  # [bs, npoint, k]
        else:
            idx = query_ball_point_cuda(
                self.radius, k, input_p.contiguous(),
                input_p_fps.contiguous())  # [bs, npoint, k]

        grouped_input_p = grouping_operation_cuda(
            input_p.transpose(1, 2).contiguous(), idx)  # [bs, xyz, npoint, k]
        grouped_input_x = grouping_operation_cuda(
            input_x.contiguous(), idx)  # [bs, hidden_dim, npoint, k]

        self.register_buffer('neighbor_map', idx)

        # TODO: define a proper radius r for the ball-query variant, e.g.:
        # query_idx = query_ball_point_cuda(radius, k, coord, coord) # [bs, npoint, k]

        # linear_top is skipped for SKIP_ATTN (downsample-only) blocks
        if not self.SKIP_ATTN:
            input_x = self.linear_top(input_x)

        if self.SKIP_ATTN:
            # import ipdb; ipdb.set_trace()
            # out_dim should be the same with in_dim, since here contains no TD
            if self.POS_ENCODING:
                relative_xyz = input_p_fps.permute(
                    [0, 2, 1])[:, :, :, None] - grouped_input_p
                pos_encoding = self.delta(
                    relative_xyz)  # [bs, dims, npoint, k]
                if self.CAT_POS:
                    alpha = self.alpha(
                        torch.cat([grouped_input_x, relative_xyz], dim=1))
                else:  # use sum
                    alpha = self.alpha(grouped_input_x + pos_encoding)
            else:
                alpha = self.alpha(grouped_input_x)
                # alpha = grouping_operation_cuda(self.alpha(input_x).contiguous(), idx)

            y = alpha.max(dim=-1)[0]
            # y = alpha.sum(dim=-1)
            y = self.linear_down(y)

            if self.fps_rate is not None:
                input_p_reduced = input_p_fps.transpose(1, 2)
                # NOTE: no need to apply fps_idx here; y is already computed on the sampled points
                # input_x_reduced = index_points_cuda(y.transpose(1,2), fps_idx).transpose(1,2)  # [B, dim, M]
                input_x_reduced = y
                return input_p_reduced, input_x_reduced
            else:
                input_p_reduced = input_p_fps.transpose(1, 2)
                input_x_reduced = y + res
                return input_p_reduced, input_x_reduced

        # when downsampling the TRBlock
        # should use downsampled qkv here, so use input_x_fps
        # as for normal block, input_x and input_x_fps are the same
        if self.fps_rate is not None:
            input_x_fps = index_points_cuda(input_x.transpose(
                1, 2), fps_idx).transpose(
                    1, 2)  # it is only used for tr-like downsample block
            phi = self.phi(input_x_fps)
        else:
            phi = self.phi(input_x)
        phi = phi[:, :, :, None].repeat(1, 1, 1, k)
        psi = grouping_operation_cuda(self.psi(input_x).contiguous(), idx)
        self.skip_knn = True
        alpha = grouping_operation_cuda(self.alpha(input_x).contiguous(),
                                        idx)  # [bs, xyz, npoint, k]

        if self.POS_ENCODING:
            relative_xyz = input_p_fps.permute(
                [0, 2, 1])[:, :, :, None] - grouped_input_p
            pos_encoding = self.delta(relative_xyz)  # [bs, dims, npoint, k]

        if self.use_vector_attn:
            # the attn_map: [vector_dim];
            # the alpha:    [out_dim]
            if self.POS_ENCODING:
                # if V_POS_ONLY and QK_POS_ONLY are both false, pos_encoding is applied to both branches
                assert not (self.V_POS_ONLY and self.QK_POS_ONLY
                            )  # at most one of V_POS_ONLY / QK_POS_ONLY may be set
                if self.V_POS_ONLY:
                    attn_map = F.softmax(self.gamma(phi - psi), dim=-1)
                else:
                    attn_map = F.softmax(self.gamma(phi - psi + pos_encoding),
                                         dim=-1)
                if self.QK_POS_ONLY:
                    y = attn_map.repeat(1, self.out_dim // self.vector_dim, 1,
                                        1) * (alpha)
                else:
                    y = attn_map.repeat(1, self.out_dim // self.vector_dim, 1,
                                        1) * (alpha + pos_encoding)
            else:
                attn_map = F.softmax(self.gamma(phi - psi), dim=-1)
                y = attn_map.repeat(1, self.out_dim // self.vector_dim, 1,
                                    1) * (alpha)
            if self.MAX_POOL:
                y = y.max(dim=-1)[0]
            else:
                y = y.sum(dim=-1)
        else:
            assert self.POS_ENCODING  # scalar attention here requires the position encoding
            phi = phi.reshape(B, h, self.out_dim // h, npoint, k)
            psi = psi.reshape(B, h, self.out_dim // h, npoint, k)
            attn_map = F.softmax(
                (phi * psi).reshape(B, self.out_dim, npoint, k) + pos_encoding,
                dim=-1)
            y = attn_map * (alpha + pos_encoding)
            y = y.sum(dim=-1)

        self.register_buffer('attn_map', attn_map.mean(dim=1))

        y = self.linear_down(y)

        if self.fps_rate is not None:
            input_p_reduced = input_p_fps.transpose(1, 2)
            # input_x_reduced = index_points_cuda(y.transpose(1,2), fps_idx).transpose(1,2)  # [B, dim, M]
            input_x_reduced = y
            return input_p_reduced, input_x_reduced
        else:
            input_p_reduced = input_p_fps.transpose(1, 2)
            input_x_reduced = y + res
            return input_p_reduced, input_x_reduced
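A minimal sketch of the vector self-attention these blocks compute (Point Transformer style). Toy tensors stand in for the outputs of phi/psi/alpha and the position encoding delta, and the gamma mapping is omitted (treated as identity), so each neighbor gets one weight per channel.

import torch
import torch.nn.functional as F

B, D, N, k = 2, 32, 128, 16
phi = torch.rand(B, D, N, 1).expand(B, D, N, k)  # query features, broadcast over k
psi = torch.rand(B, D, N, k)                     # grouped key features
alpha = torch.rand(B, D, N, k)                   # grouped value features
pos = torch.rand(B, D, N, k)                     # relative position encoding

attn = F.softmax(phi - psi + pos, dim=-1)        # normalize over the k neighbors
y = (attn * (alpha + pos)).sum(dim=-1)           # [B, D, N]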
Example #10
    def forward(self, input_p, input_x):
        '''
        input_p:  B, 3, npoint
        input_x: B, in_dim, npoint
        '''

        B, in_dim, npoint = list(input_x.size())
        n_sample = self.n_sample
        k = min(n_sample, npoint)
        h = self.nhead

        res = input_x

        input_p = input_p.permute([0, 2, 1])

        # in the last block there may be fewer points than n_sample,
        # yet knn would still return n_sample idxs, so shrink k
        if npoint < self.n_sample:
            self.knn = KNN(k=npoint, transpose_mode=True)

        # DEBUG ONLY: using the input_x: feature space knn!
        # _, idx = self.knn(input_x.transpose(1,2), input_x.transpose(1,2))

        _, idx = self.knn(input_p.contiguous(), input_p)
        idx = idx.int()

        grouped_input_p = grouping_operation_cuda(
            input_p.transpose(1, 2).contiguous(), idx)  # [bs, xyz, npoint, k]

        if self.pre_ln:
            input_x = self.ln_top(input_x.transpose(1, 2)).transpose(1, 2)

        input_x = self.linear_top(input_x)

        # TODO: apply the layer-norm; note the original layout is [bs, dim, npoint]
        if self.pre_ln:
            input_x = self.ln_attn(input_x.transpose(1, 2)).transpose(1, 2)

        # grouped_input_x = index_points(input_x.permute([0,2,1]), idx.long()).permute([0,3,1,2])
        # grouped_input_x = grouping_operation_cuda(input_x.contiguous(), idx)  # [bs, xyz, npoint, K]
        phi = self.phi(input_x)
        phi = phi[:, :, :, None].repeat(1, 1, 1, k)
        psi = grouping_operation_cuda(self.psi(input_x).contiguous(), idx)
        alpha = grouping_operation_cuda(self.alpha(input_x).contiguous(),
                                        idx)  # [bs, xyz, npoint, k]

        relative_xyz = input_p.permute([0, 2, 1])[:, :, :,
                                                  None] - grouped_input_p
        pos_encoding = self.delta(relative_xyz)  # [bs, dims, npoint, k]

        if self.use_vector_attn:
            # the attn_map: [vector_dim];
            # the alpha:    [out_dim]
            attn_map = F.softmax(self.gamma(phi - psi + pos_encoding), dim=-1)
            y = attn_map.repeat(1, self.out_dim // self.vector_dim, 1,
                                1) * (alpha + pos_encoding)
            y = y.sum(dim=-1)
        else:
            phi = phi.reshape(B, h, self.out_dim // h, npoint, k)
            psi = psi.reshape(B, h, self.out_dim // h, npoint, k)
            attn_map = F.softmax(
                (phi * psi).reshape(B, self.out_dim, npoint, k) + pos_encoding,
                dim=-1)
            y = attn_map * (alpha + pos_encoding)
            y = y.sum(dim=-1)

        if self.pre_ln:
            y = self.ln_down(y.transpose(1, 2)).transpose(1, 2)

        y = self.linear_down(y)

        return y + res, attn_map.detach().cpu().data