def sample_and_group_cuda(npoint, k, xyz, points, cat_xyz_feature=True, fps_only=False):
    """
    Input:
        npoint: number of centroids sampled by FPS
        k: number of neighbors grouped per centroid
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, C, N]
    Return:
        new_xyz: sampled points position data, [B, 3, npoint]
        grouped_xyz_norm: sampled relative points position data, [B, 3, npoint, k]
        new_points: sampled points data, [B, C+C_xyz, npoint, k]
        idx: grouped neighbor indices, [B, npoint, k]
    """
    k = min(npoint, k)
    knn = KNN(k=k, transpose_mode=True)

    B, N, C_xyz = xyz.shape

    if npoint < N:
        fps_idx = farthest_point_sample_cuda(xyz, npoint)  # [B, npoint]
        torch.cuda.empty_cache()
        new_xyz = index_points_cuda(xyz, fps_idx)  # [B, npoint, 3]
        new_points = index_points_cuda(points.transpose(1, 2), fps_idx)
    else:
        # no downsampling: keep every point as a centroid
        fps_idx = torch.arange(N, device=xyz.device).repeat(B, 1).int()
        new_xyz = xyz
        new_points = points.transpose(1, 2)

    if fps_only:
        return new_xyz.transpose(1, 2), new_points.transpose(1, 2), fps_idx

    torch.cuda.empty_cache()
    _, idx = knn(xyz.contiguous(), new_xyz)  # [B, npoint, k]
    idx = idx.int()
    torch.cuda.empty_cache()

    grouped_xyz = grouping_operation_cuda(
        xyz.transpose(1, 2).contiguous(), idx).permute(0, 2, 3, 1)  # [B, npoint, k, C_xyz]
    torch.cuda.empty_cache()

    # -1 covers the npoint >= N case, where all N points are kept
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, -1, 1, C_xyz)  # [B, npoint, k, 3]
    grouped_xyz_norm = grouped_xyz_norm.permute(0, 3, 1, 2).contiguous()  # [B, 3, npoint, k]
    torch.cuda.empty_cache()

    grouped_points = grouping_operation_cuda(points.contiguous(), idx)  # [B, C, npoint, k]

    if cat_xyz_feature:
        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=1)  # [B, C+C_xyz, npoint, k]
    else:
        new_points = grouped_points  # [B, C, npoint, k]

    return new_xyz.transpose(1, 2), grouped_xyz_norm, new_points, idx
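# A minimal usage sketch (not part of the original code): assumes the CUDA
# extensions above are built and a GPU is available; tensor names and sizes
# are illustrative only.
def _demo_sample_and_group():
    xyz = torch.rand(2, 4096, 3).cuda()      # [B, N, 3] point coordinates
    feats = torch.rand(2, 32, 4096).cuda()   # [B, C, N] point features
    new_xyz, grouped_xyz_norm, new_points, idx = sample_and_group_cuda(
        npoint=1024, k=16, xyz=xyz, points=feats)
    # new_xyz: [2, 3, 1024], grouped_xyz_norm: [2, 3, 1024, 16],
    # new_points: [2, 35, 1024, 16] (3 relative-xyz channels + 32 feature channels)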
def sample_and_group_cuda(npoint, k, xyz, points, cat_xyz_feature=True):
    """
    Input:
        npoint: number of centroids sampled by FPS
        k: number of neighbors grouped per centroid
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, C, N]
    Return:
        new_xyz: sampled points position data, [B, 3, npoint]
        grouped_xyz_norm: sampled relative points position data, [B, 3, npoint, k]
        new_points: sampled points data, [B, C+C_xyz, npoint, k]
    """
    k = min(npoint, k)
    knn = KNN(k=k, transpose_mode=True)

    B, N, C_xyz = xyz.shape

    if npoint < N:
        fps_idx = farthest_point_sample_cuda(xyz, npoint)  # [B, npoint]
        torch.cuda.empty_cache()
        new_xyz = index_points_cuda(xyz, fps_idx)  # [B, npoint, 3]
    else:
        new_xyz = xyz

    torch.cuda.empty_cache()
    _, idx = knn(xyz.contiguous(), new_xyz)  # [B, npoint, k]
    idx = idx.int()
    torch.cuda.empty_cache()

    grouped_xyz = grouping_operation_cuda(
        xyz.transpose(1, 2).contiguous(), idx).permute(0, 2, 3, 1)  # [B, npoint, k, C_xyz]
    torch.cuda.empty_cache()

    # with mixed-trans inputs the last voxels may hold fewer than `npoint`
    # points, so the number of centroids is min(npoint, N)
    grouped_xyz_norm = grouped_xyz - new_xyz.view(-1, min(npoint, N), 1, C_xyz)  # [B, npoint, k, 3]
    grouped_xyz_norm = grouped_xyz_norm.permute(0, 3, 1, 2).contiguous()  # [B, 3, npoint, k]
    torch.cuda.empty_cache()

    grouped_points = grouping_operation_cuda(points.contiguous(), idx)  # [B, C, npoint, k]

    if cat_xyz_feature:
        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=1)  # [B, C+C_xyz, npoint, k]
    else:
        new_points = grouped_points  # [B, C, npoint, k]

    return new_xyz.transpose(1, 2), grouped_xyz_norm, new_points
def stem_knn(xyz, points, k):
    """Group k nearest neighbors for every point (no downsampling).

    xyz: [B, 3, N]; points: [B, C, N].
    """
    knn = KNN(k=k, transpose_mode=True)
    xyz = xyz.permute([0, 2, 1])  # [B, N, 3]
    _, idx = knn(xyz.contiguous(), xyz)  # idx: [B, N, k]
    idx = idx.int()

    # grouping_operation_cuda takes channel-first inputs, i.e. [B, 3, N]
    grouped_xyz = grouping_operation_cuda(
        xyz.transpose(1, 2).contiguous(), idx)  # [B, 3, N, k]
    grouped_points = grouping_operation_cuda(points.contiguous(), idx)  # [B, C, N, k]

    return grouped_xyz, grouped_points
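# A pure-PyTorch stand-in for KNN(k, transpose_mode=True) as used above
# (a sketch for CPU shape checks only; the real module comes from the
# KNN_CUDA package and runs on the GPU):
def knn_reference(ref, query, k):
    # ref: [B, N, 3], query: [B, M, 3]
    dist = torch.cdist(query, ref)                    # [B, M, N] pairwise distances
    dist, idx = dist.topk(k, dim=-1, largest=False)   # k smallest per query
    return dist, idx                                  # both [B, M, k], like transpose_mode=True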
def forward(self, xyz_1, xyz_2, points_1, points_2):
    """
    Interpolate the coarse features points_1 onto the dense set xyz_2 using
    the k nearest neighbors, weighted by inverse distance (M < N).
    Input:
        xyz_1: input points position data, [B, 3, M]
        xyz_2: input points position data, [B, 3, N]
        points_1: input points data, [B, C, M]
        points_2: input points data, [B, C, N]
    Return:
        new_xyz: points position data, [B, 3, N]
        new_points_concat: interpolated points feature data, [B, D', N]
    """
    B, input_dim, M = list(points_1.size())
    B, output_dim, N = list(points_2.size())

    points_1 = self.linear_1(points_1)
    points_2 = self.linear_2(points_2)

    dists = square_distance(xyz_2.transpose(1, 2), xyz_1.transpose(1, 2))  # [B, N, M]
    dists, idx = dists.sort(dim=-1)
    dists, idx = dists[:, :, :self.k], idx[:, :, :self.k]  # k nearest coarse points

    # inverse-distance weights, normalized over the k neighbors
    dist_recip = 1.0 / (dists + 1e-8)
    norm = torch.sum(dist_recip, dim=2, keepdim=True)
    weight = dist_recip / norm

    interpolated_points = torch.sum(
        grouping_operation_cuda(points_1, idx.int()) * weight.view(B, 1, N, self.k),
        dim=-1)  # [B, D', N]

    return xyz_2, (interpolated_points + points_2)
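# A minimal pure-PyTorch reference for the inverse-distance interpolation
# above (a sketch for CPU shape checks under stated assumptions:
# grouping_operation_cuda is replaced by torch.gather and the linear
# layers are omitted; the function name is illustrative):
def idw_interpolate_reference(xyz_coarse, xyz_dense, feats_coarse, k=3, eps=1e-8):
    # xyz_coarse: [B, M, 3]; xyz_dense: [B, N, 3]; feats_coarse: [B, C, M]
    dists = torch.cdist(xyz_dense, xyz_coarse) ** 2       # [B, N, M] squared distances
    dists, idx = dists.topk(k, dim=-1, largest=False)     # k nearest coarse points, [B, N, k]
    weight = 1.0 / (dists + eps)
    weight = weight / weight.sum(dim=-1, keepdim=True)    # normalized IDW weights, [B, N, k]
    B, C, M = feats_coarse.shape
    N = idx.shape[1]
    idx_exp = idx.unsqueeze(1).expand(B, C, N, k)         # [B, C, N, k]
    grouped = torch.gather(
        feats_coarse.unsqueeze(2).expand(B, C, N, M), 3, idx_exp)  # neighbor features
    return (grouped * weight.unsqueeze(1)).sum(dim=-1)    # [B, C, N]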
def sample_and_group_cuda(npoint, k, xyz, points):
    """
    Input:
        npoint: number of centroids sampled by FPS
        k: number of neighbors grouped per centroid
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, C, N]
    Return:
        new_xyz: sampled points position data, [B, 3, npoint]
        grouped_xyz_norm: sampled relative points position data, [B, 3, npoint, k]
        new_points: sampled points data, [B, C+C_xyz, npoint, k]
    """
    k = min(npoint, k)
    knn = KNN(k=k, transpose_mode=True)

    B, N, C_xyz = xyz.shape

    if npoint < N:
        fps_idx = farthest_point_sample_cuda(xyz, npoint)  # [B, npoint]
        torch.cuda.empty_cache()
        new_xyz = index_points_cuda(xyz, fps_idx)  # [B, npoint, 3]
    else:
        new_xyz = xyz

    torch.cuda.empty_cache()
    _, idx = knn(xyz.contiguous(), new_xyz)  # [B, npoint, k]
    idx = idx.int()
    torch.cuda.empty_cache()

    grouped_xyz = grouping_operation_cuda(
        xyz.transpose(1, 2).contiguous(), idx).permute(0, 2, 3, 1)  # [B, npoint, k, C_xyz]
    torch.cuda.empty_cache()

    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, npoint, 1, C_xyz)  # [B, npoint, k, 3]
    grouped_xyz_norm = grouped_xyz_norm.permute(0, 3, 1, 2).contiguous()  # [B, 3, npoint, k]
    torch.cuda.empty_cache()

    grouped_points = grouping_operation_cuda(points.contiguous(), idx)  # [B, C, npoint, k]
    new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=1)  # [B, C+C_xyz, npoint, k]

    return new_xyz.transpose(1, 2), grouped_xyz_norm, new_points
def get_neighbor(self, ref: ME.SparseTensor, query: ME.SparseTensor):
    B_nq, _ = query.C.shape
    coord = query.C  # (N, 4): batch index + xyz
    batch_info = coord[:, 0]
    coord, mask, idx_ = separate_batch(coord)  # (b, n, 3)
    b, n, _ = coord.shape

    if self.use_knn:
        _, idx = self.knn(coord.contiguous(), coord)
        grouped_coord = grouping_operation_cuda(
            coord.float().transpose(1, 2).contiguous(), idx.int())
        result_padded = grouped_coord.permute([0, 2, 3, 1])  # (b, n, k, 3)
    else:
        query_and_group_cuda = QueryAndGroup(radius=self.r,
                                             nsample=self.k,
                                             use_xyz=False)
        coord = coord.float()
        idxs = query_and_group_cuda(
            xyz=coord,
            new_xyz=coord,
            features=coord.transpose(1, 2).contiguous(),
        )  # (b, xyz, npoint, nsample)
        result_padded = idxs.permute([0, 2, 3, 1])  # (b, npoint, nsample, xyz)

    # un-pad (b, n, k, 3) -> (B_nq, k, 4) by applying the mask
    result = torch.zeros([B_nq, self.k, 4],
                         dtype=torch.int32,
                         device=query.device)
    result[:, :, 1:] = torch.gather(
        result_padded.reshape(-1, self.k, 3), 0,
        idx_.reshape(-1, 1, 1).repeat(1, self.k, 3))
    result[:, :, 0] = batch_info.unsqueeze(-1).repeat(1, self.k)

    return result, mask, idx_
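# `separate_batch` is defined elsewhere in the repo; the sketch below is an
# assumption about the contract get_neighbor appears to rely on, not the
# actual helper: flat MinkowskiEngine coordinates (N, 4), sorted by batch
# index, become zero-padded per-batch coords (b, n_max, 3) plus a flat index
# idx_ mapping each of the N rows into the (b * n_max) padded layout.
def separate_batch_sketch(coord):
    # coord: int tensor (N, 4), column 0 is the batch index (assumed sorted)
    batch = coord[:, 0].long()
    b = int(batch.max().item()) + 1
    counts = torch.bincount(batch, minlength=b)           # points per batch element
    n_max = int(counts.max().item())
    out = torch.zeros(b, n_max, 3, dtype=coord.dtype, device=coord.device)
    mask = torch.zeros(b, n_max, dtype=torch.bool, device=coord.device)
    pos = torch.cat([torch.arange(int(c), device=coord.device) for c in counts])
    out[batch, pos] = coord[:, 1:]                        # scatter xyz into padded slots
    mask[batch, pos] = True
    idx_ = batch * n_max + pos                            # flat row index used by the gather above
    return out, mask, idx_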
def sample_and_group_cuda(npoint, k, xyz, points, instance=None, instance_relation=None):
    """
    Input:
        npoint: number of centroids sampled by FPS (typically N/4)
        k: number of neighbors grouped per centroid
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, C, N]
        instance: input instance labels, [B, N]
    Return:
        new_xyz: sampled points position data, [B, 3, npoint]
        grouped_xyz_norm: sampled relative points position data, [B, 3, npoint, k]
        new_points: sampled points data, [B, C+C_xyz, npoint, k]
        new_instance: sampled instance labels, [B, npoint]
    """
    k = min(npoint, k)
    knn = KNN(k=k, transpose_mode=True)

    B, N, C_xyz = xyz.shape

    if npoint < N:
        fps_idx = farthest_point_sample_cuda(xyz, npoint)  # [B, npoint]
        torch.cuda.empty_cache()
        new_xyz = index_points_cuda(xyz, fps_idx)  # [B, npoint, 3]
        if instance is not None:
            # unsqueeze to [B, N, 1] so index_points_cuda can gather, then squeeze back
            new_instance = index_points_cuda(
                instance.unsqueeze(-1).float(), fps_idx).squeeze(-1)
    else:
        new_xyz = xyz
        new_instance = instance

    torch.cuda.empty_cache()
    _, idx = knn(xyz.contiguous(), new_xyz)  # [B, npoint, k]
    idx = idx.int()
    torch.cuda.empty_cache()

    grouped_xyz = grouping_operation_cuda(
        xyz.transpose(1, 2).contiguous(), idx).permute(0, 2, 3, 1)  # [B, npoint, k, C_xyz]
    torch.cuda.empty_cache()

    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, npoint, 1, C_xyz)  # [B, npoint, k, 3]
    grouped_xyz_norm = grouped_xyz_norm.permute(0, 3, 1, 2).contiguous()  # [B, 3, npoint, k]
    torch.cuda.empty_cache()

    grouped_points = grouping_operation_cuda(points.contiguous(), idx)  # [B, C, npoint, k]
    new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=1)  # [B, C+C_xyz, npoint, k]

    if instance is not None:
        return new_xyz.transpose(1, 2), grouped_xyz_norm, new_points, new_instance
    else:
        return new_xyz.transpose(1, 2), grouped_xyz_norm, new_points
def forward(self, input_p, input_x, instance=None, instance_relation=None):
    '''
    input_p: B, 3, npoint
    input_x: B, in_dim, npoint

    How the instance information can be used:
        1. as guidance for the attention: mask kNN points whose instance label differs
        2. directly choose random points with the same instance label as the receptive field
        3. attend to the instance center
    '''
    INSTANCE_SCHEME = 3

    B, in_dim, npoint = list(input_x.size())
    n_sample = self.n_sample
    # the last block may hold fewer points than n_sample while the kNN
    # would still return n_sample indices, so clamp k to npoint
    k = min(n_sample, npoint)
    h = self.nhead
    res = input_x

    input_p = input_p.permute([0, 2, 1])
    ori_input_p = input_p

    if instance is not None and INSTANCE_SCHEME == 1:
        # kNN over more points, so the instance mask can drop some of them later
        knn_sample_more_ratio = 2
        enlarged_k = k * knn_sample_more_ratio
        self.knn = KNN(k=enlarged_k, transpose_mode=True)
    else:
        self.knn = KNN(k=k, transpose_mode=True)

    if instance is not None and INSTANCE_SCHEME == 3:
        '''
        Ver3.0: use the instance center as the kNN center: compute each
        instance centroid and blend every coordinate with its centroid.
        Caveats:
            - all points of one instance may share the same idxs (the points closest to the centroid)
            - with blended coordinates a full per-point kNN is needed, which is slow
        '''
        ori_input_p = input_p.clone()
        for i_bs in range(instance.shape[0]):
            for v in torch.unique(instance[i_bs]):
                tmp_idx = torch.where(instance[i_bs] == v)[0]
                ins_center = input_p[:, tmp_idx, :].mean(dim=1)  # centroid of this instance
                # blend the current point with its instance center
                blend = 0.999
                input_p[:, tmp_idx, :] = blend * input_p[:, tmp_idx, :] + \
                    (1 - blend) * ins_center.unsqueeze(1)  # [bs, n_cur_ins, 3] + [bs, 1, 3]
        _, idx = self.knn(ori_input_p.contiguous(), ori_input_p)
        _, idx2 = self.knn(ori_input_p.contiguous(), input_p)
        print((idx == idx2).int().sum() / idx.nelement())  # agreement ratio of the two queries
    else:
        _, idx = self.knn(input_p.contiguous(), input_p)
    idx = idx.int()

    if INSTANCE_SCHEME == 1:
        '''
        Ver1.0 (naive): use the instance label as an auxiliary filter:
        check whether each kNN-grouped point lies in the same instance as
        the query point, and mask it out if not.
        '''
        if instance is not None:
            masks = []
            for i_bs, idx_cur_bs in enumerate(idx):  # idx_cur_bs: [npoint, enlarged_k]
                mask = instance[i_bs][idx_cur_bs.long()]  # [npoint, enlarged_k]
                # the 1st column is the query point itself, so same-instance entries become 0
                mask = mask - mask[:, 0].unsqueeze(-1)
                mask = (mask == 0).int()  # acquire the 0-1 mask
                masks.append(mask)
            masks = torch.stack(masks)
            print("mask ratio {:.4f}".format(masks.sum() / masks.nelement()))  # >0.5 means ok
            '''
            Take the first k in-instance elements of the enlarged kNN idx.
            Since the mask values are 0/1 and argsort should put smaller
            indices first, masked-out slots get a value larger than the
            maximum so they land at the end.
            '''
            inv_masks = (masks == 0).int()
            tmp_inds = torch.arange(masks.shape[2]).repeat(
                masks.shape[0], masks.shape[1], 1).to(idx.device)  # [1, 2, ..., enlarged_k]
            tmp_inds = tmp_inds * masks
            tmp_inds = tmp_inds + (masks.shape[2] + 1) * inv_masks  # push masked slots past the max
            tmp_inds = torch.argsort(tmp_inds)[:, :, :k]  # in-instance neighbors come first
            idx = torch.gather(idx, -1, tmp_inds)
            idx = idx.int()
            # TODO: if the enlarged kNN still lacks enough in-instance
            # points, the argsort falls back to the closest kNN result
    elif INSTANCE_SCHEME == 2:
        '''
        Ver2.0: directly use points with the same instance label as
        neighbors: randomly sample k points within each instance.
        '''
        if instance is not None:
            instance_relations = []
            for i_bs in range(instance.shape[0]):
                # torch.where returns a tuple, so [0] extracts the tensor
                instance_inds = [
                    torch.where(instance[i_bs] == v)[0]
                    for v in torch.unique(instance[i_bs])
                ]
                instance_relation = torch.full([instance[0].shape[0], k], -1).to(instance.device)
                for i, ins_id in enumerate(instance_inds):
                    if len(ins_id) <= 5:  # skip small outlier point sets
                        continue
                    # torch.multinomial stands in for a batched random choice without replacement
                    perms = torch.multinomial(
                        ins_id.repeat(len(ins_id), 1).float(),
                        num_samples=min(k, len(ins_id)),
                        replacement=False)
                    choices = ins_id[perms]
                    instance_relation[instance_inds[i], :choices.shape[1]] = choices
                instance_relation[:, 0] = torch.arange(instance_relation.shape[0])
                instance_relations.append(instance_relation)
            instance_relations = torch.stack(instance_relations)
            # keep the kNN idx where no same-instance neighbor was found
            instance_relation_nonzero_mask = (instance_relations >= 0).int()
            instance_relation_zero_mask = (instance_relations < 0).int()
            idx = idx * instance_relation_zero_mask + instance_relations * instance_relation_nonzero_mask
            idx = idx.int()

    # ===================== Deprecated Methods =====================
    '''
    Ver2.3: failed version that received a precomputed instance_relation;
    point downsampling could not be handled. instance_relation has the
    same size as idx; where an instance group has fewer than k members it
    contains -1, which should be replaced by the original kNN idx.

    if instance_relation is not None:
        instance_relation = instance_relation[:, :, :k]
        instance_relation_nonzero_mask = (instance_relation >= 0).int()
        instance_relation_zero_mask = (instance_relation < 0).int()
        # idx = idx*instance_relation_zero_mask + instance_relation*instance_relation_nonzero_mask
        idx = instance_relation.int()
    '''
    '''
    Ver2.2: hash-table based. First pack the instances into a dict, then
    reference the points within the same instance to replace the kNN
    points; if an instance has too few points, keep the kNN idxs.

    if instance is not None:
        instance_dicts = []
        for i_bs, instance_cur_bs in enumerate(instance):
            instance_dict = {}
            for ins_idx, ins in enumerate(instance_cur_bs):
                if ins.item() in instance_dict.keys():
                    instance_dict[ins.item()].append(ins_idx)
                else:
                    instance_dict[ins.item()] = [ins_idx]
            for ins_k in instance_dict.keys():
                instance_dict[ins_k] = torch.tensor(instance_dict[ins_k]).to(instance.device)
            instance_dicts.append(instance_dict)

        l1 = []
        for i_bs in range(instance.shape[0]):
            l0 = []
            for i_point in range(instance.shape[1]):
                tmp = torch.zeros([k])
                instance_gathered = instance_dicts[i_bs][instance[i_bs][i_point].item()][:k]
                tmp[:len(instance_gathered)] = instance_gathered
                l0.append(tmp)
            tmp1 = torch.stack(l0)
            l1.append(tmp1)
        new_idx = torch.stack(l1)
    '''
    '''
    Ver2.1: naive for-loop replacement, far too slow. First run kNN, then
    mask the values that do not belong to the same instance.

    instance_masks = []
    for i_batch, single_batch_instance in enumerate(instance):
        # single_batch_instance: [npoint]
        masks_cur_batch = []
        for i_point, gathered_points in enumerate(idx[i_batch]):
            # gathered_points: [k]
            points_with_same_instance = torch.where(
                single_batch_instance == single_batch_instance[i_point])[0]
            # check whether the grouped idxs lie within the same instance
            cur_mask = torch.tensor(
                [g.item() in points_with_same_instance for g in gathered_points])
            masks_cur_batch.append(cur_mask)
        masks_cur_batch = torch.stack(masks_cur_batch)
        instance_masks.append(masks_cur_batch)
    instance_masks = torch.stack(instance_masks)
    '''
    # ==============================================================

    grouped_input_p = grouping_operation_cuda(
        input_p.transpose(1, 2).contiguous(), idx)  # [bs, xyz, npoint, k]

    if self.pre_ln:  # pre layer-norm on [bs, dim, npoint] inputs
        input_x = self.ln_top(input_x.transpose(1, 2)).transpose(1, 2)

    input_x = self.linear_top(input_x)

    if self.pre_ln:
        input_x = self.ln_attn(input_x.transpose(1, 2)).transpose(1, 2)

    phi = self.phi(input_x)
    phi = phi[:, :, :, None].repeat(1, 1, 1, k)
    psi = grouping_operation_cuda(self.psi(input_x).contiguous(), idx)
    alpha = grouping_operation_cuda(self.alpha(input_x).contiguous(), idx)  # [bs, xyz, npoint, k]

    relative_xyz = input_p.permute([0, 2, 1])[:, :, :, None] - grouped_input_p
    pos_encoding = self.delta(relative_xyz)  # [bs, dims, npoint, k]

    if self.use_vector_attn:
        # attn_map has vector_dim channels; alpha has out_dim channels
        attn_map = F.softmax(self.gamma(phi - psi + pos_encoding), dim=-1)  # [B, dim, N, k]
        # if instance is not None:  # apply the instance mask (scheme 1)
        #     attn_map = attn_map * masks.unsqueeze(1)
        y = attn_map.repeat(1, self.out_dim // self.vector_dim, 1, 1) * (alpha + pos_encoding)
        y = y.sum(dim=-1)
    else:
        phi = phi.reshape(B, h, self.out_dim // h, npoint, k)
        psi = psi.reshape(B, h, self.out_dim // h, npoint, k)
        attn_map = F.softmax(
            (phi * psi).reshape(B, self.out_dim, npoint, k) + pos_encoding, dim=-1)
        y = attn_map * (alpha + pos_encoding)
        y = y.sum(dim=-1)

    if self.pre_ln:
        y = self.ln_down(y.transpose(1, 2)).transpose(1, 2)

    y = self.linear_down(y)

    return y + res, attn_map.detach().cpu().data
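# A minimal pure-PyTorch sketch of the vector-attention path above, for
# reference only (assumptions: phi/psi/alpha are stand-in nn.Conv1d modules,
# gamma is a stand-in 1x1 nn.Conv2d, vector_dim == out_dim so no channel
# repeat is needed, and the positional encoding is omitted):
def vector_attention_reference(x, idx, phi, psi, alpha, gamma):
    # x: [B, C, N] features; idx: [B, N, k] neighbor indices (long)
    B, C, N = x.shape
    k = idx.shape[-1]

    def gather(t):  # gather neighbor features: [B, C', N] -> [B, C', N, k]
        flat = idx.reshape(B, 1, N * k).expand(B, t.shape[1], N * k)
        return t.gather(2, flat).reshape(B, t.shape[1], N, k)

    q = phi(x)[:, :, :, None].expand(-1, -1, -1, k)   # queries, broadcast over k
    kf = gather(psi(x))                               # neighbor keys
    v = gather(alpha(x))                              # neighbor values
    w = F.softmax(gamma(q - kf), dim=-1)              # per-channel attention over the k neighbors
    return (w * v).sum(dim=-1)                        # [B, C_out, N]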
def forward(self, input_p, input_x):
    '''
    input_p: B, 3, npoint
    input_x: B, in_dim, npoint
    '''
    B, in_dim, npoint = list(input_x.size())  # npoint: input point count
    n_sample = self.n_sample  # kNN sample count for this block
    k = min(n_sample, npoint)
    if not self.use_vector_attn:
        h = self.nhead  # heads are only used in non-vector attention

    input_p = input_p.permute([0, 2, 1])  # [B, npoint, 3]
    self.register_buffer('in_xyz_map', input_p)

    if self.fps_rate is not None:
        npoint = npoint // self.fps_rate
        fps_idx = farthest_point_sample_cuda(input_p, npoint)
        torch.cuda.empty_cache()
        input_p_fps = index_points_cuda(input_p, fps_idx)  # [B, M, 3]
        if self.SKIP_ALL:
            input_p_reduced = input_p_fps.transpose(1, 2)
            input_x_reduced = index_points_cuda(
                self.tmp_linear(input_x).transpose(1, 2), fps_idx).transpose(1, 2)
            return input_p_reduced, input_x_reduced
    else:
        input_p_fps = input_p
        input_x_fps = input_x

    res = input_x  # [B, dim, M]

    if self.USE_KNN:
        self.knn = KNN(k=k, transpose_mode=True)
        _, idx = self.knn(input_p.contiguous(), input_p_fps.contiguous())
        idx = idx.int()  # [bs, npoint, k]
    else:
        # TODO: define a proper radius for the ball query
        idx = query_ball_point_cuda(
            self.radius, k, input_p.contiguous(),
            input_p_fps.contiguous())  # [bs, npoint, k]

    grouped_input_p = grouping_operation_cuda(
        input_p.transpose(1, 2).contiguous(), idx)  # [bs, xyz, npoint, k]
    grouped_input_x = grouping_operation_cuda(
        input_x.contiguous(), idx)  # [bs, hidden_dim, npoint, k]
    self.register_buffer('neighbor_map', idx)

    if not self.SKIP_ATTN:
        # linear_top is skipped for the SKIP_ATTN downsample-only blocks
        input_x = self.linear_top(input_x)

    if self.SKIP_ATTN:
        # out_dim should match in_dim here, since this path contains no transition-down
        if self.POS_ENCODING:
            relative_xyz = input_p_fps.permute([0, 2, 1])[:, :, :, None] - grouped_input_p
            pos_encoding = self.delta(relative_xyz)  # [bs, dims, npoint, k]
            if self.CAT_POS:
                alpha = self.alpha(torch.cat([grouped_input_x, relative_xyz], dim=1))
            else:  # sum instead of concatenation
                alpha = self.alpha(grouped_input_x + pos_encoding)
        else:
            alpha = self.alpha(grouped_input_x)

        y = alpha.max(dim=-1)[0]
        y = self.linear_down(y)

        input_p_reduced = input_p_fps.transpose(1, 2)
        if self.fps_rate is not None:
            # no need to apply fps_idx here: y already lives on the sampled points
            input_x_reduced = y
        else:
            input_x_reduced = y + res
        return input_p_reduced, input_x_reduced

    # for a downsampling TRBlock the queries must come from the downsampled
    # points, hence input_x_fps; in a normal block input_x and input_x_fps
    # are the same tensor
    if self.fps_rate is not None:
        input_x_fps = index_points_cuda(
            input_x.transpose(1, 2), fps_idx).transpose(1, 2)
        phi = self.phi(input_x_fps)
    else:
        phi = self.phi(input_x)
    phi = phi[:, :, :, None].repeat(1, 1, 1, k)
    psi = grouping_operation_cuda(self.psi(input_x).contiguous(), idx)
    alpha = grouping_operation_cuda(self.alpha(input_x).contiguous(), idx)  # [bs, xyz, npoint, k]

    if self.POS_ENCODING:
        relative_xyz = input_p_fps.permute([0, 2, 1])[:, :, :, None] - grouped_input_p
        pos_encoding = self.delta(relative_xyz)  # [bs, dims, npoint, k]

    if self.use_vector_attn:
        # attn_map has vector_dim channels; alpha has out_dim channels
        if self.POS_ENCODING:
            # only one of V_POS_ONLY and QK_POS_ONLY may be set; if both are
            # False, the positional encoding is applied on both paths
            assert (self.V_POS_ONLY and self.QK_POS_ONLY) is False
            if self.V_POS_ONLY:
                attn_map = F.softmax(self.gamma(phi - psi), dim=-1)
            else:
                attn_map = F.softmax(self.gamma(phi - psi + pos_encoding), dim=-1)
            if self.QK_POS_ONLY:
                y = attn_map.repeat(1, self.out_dim // self.vector_dim, 1, 1) * alpha
            else:
                y = attn_map.repeat(1, self.out_dim // self.vector_dim, 1, 1) * (alpha + pos_encoding)
        else:
            attn_map = F.softmax(self.gamma(phi - psi), dim=-1)
            y = attn_map.repeat(1, self.out_dim // self.vector_dim, 1, 1) * alpha
        if self.MAX_POOL:
            y = y.max(dim=-1)[0]
        else:
            y = y.sum(dim=-1)
    else:
        assert self.POS_ENCODING is True
        phi = phi.reshape(B, h, self.out_dim // h, npoint, k)
        psi = psi.reshape(B, h, self.out_dim // h, npoint, k)
        attn_map = F.softmax(
            (phi * psi).reshape(B, self.out_dim, npoint, k) + pos_encoding, dim=-1)
        y = attn_map * (alpha + pos_encoding)
        y = y.sum(dim=-1)

    self.register_buffer('attn_map', attn_map.mean(dim=1))

    y = self.linear_down(y)

    input_p_reduced = input_p_fps.transpose(1, 2)
    if self.fps_rate is not None:
        input_x_reduced = y  # already on the sampled points, no fps_idx gather needed
    else:
        input_x_reduced = y + res
    return input_p_reduced, input_x_reduced
def forward(self, input_p, input_x):
    '''
    input_p: B, 3, npoint
    input_x: B, in_dim, npoint
    '''
    B, in_dim, npoint = list(input_x.size())
    n_sample = self.n_sample
    k = min(n_sample, npoint)
    h = self.nhead
    res = input_x

    input_p = input_p.permute([0, 2, 1])

    # the last block may hold fewer points than n_sample, yet the kNN would
    # still return n_sample indices, so shrink k to npoint there
    if npoint < self.n_sample:
        self.knn = KNN(k=npoint, transpose_mode=True)

    _, idx = self.knn(input_p.contiguous(), input_p)
    idx = idx.int()

    grouped_input_p = grouping_operation_cuda(
        input_p.transpose(1, 2).contiguous(), idx)  # [bs, xyz, npoint, k]

    if self.pre_ln:  # pre layer-norm on [bs, dim, npoint] inputs
        input_x = self.ln_top(input_x.transpose(1, 2)).transpose(1, 2)

    input_x = self.linear_top(input_x)

    if self.pre_ln:
        input_x = self.ln_attn(input_x.transpose(1, 2)).transpose(1, 2)

    phi = self.phi(input_x)
    phi = phi[:, :, :, None].repeat(1, 1, 1, k)
    psi = grouping_operation_cuda(self.psi(input_x).contiguous(), idx)
    alpha = grouping_operation_cuda(self.alpha(input_x).contiguous(), idx)  # [bs, xyz, npoint, k]

    relative_xyz = input_p.permute([0, 2, 1])[:, :, :, None] - grouped_input_p
    pos_encoding = self.delta(relative_xyz)  # [bs, dims, npoint, k]

    if self.use_vector_attn:
        # attn_map has vector_dim channels; alpha has out_dim channels
        attn_map = F.softmax(self.gamma(phi - psi + pos_encoding), dim=-1)
        y = attn_map.repeat(1, self.out_dim // self.vector_dim, 1, 1) * (alpha + pos_encoding)
        y = y.sum(dim=-1)
    else:
        phi = phi.reshape(B, h, self.out_dim // h, npoint, k)
        psi = psi.reshape(B, h, self.out_dim // h, npoint, k)
        attn_map = F.softmax(
            (phi * psi).reshape(B, self.out_dim, npoint, k) + pos_encoding, dim=-1)
        y = attn_map * (alpha + pos_encoding)
        y = y.sum(dim=-1)

    if self.pre_ln:
        y = self.ln_down(y.transpose(1, 2)).transpose(1, 2)

    y = self.linear_down(y)

    return y + res, attn_map.detach().cpu().data
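# Hedged usage sketch for the attention block above (assumes a constructed
# module `block` whose in_dim matches the feature tensor and that the CUDA
# ops are available; names and sizes are illustrative):
def _demo_point_transformer_block(block):
    p = torch.rand(2, 3, 1024).cuda()    # [B, 3, npoint] coordinates
    x = torch.rand(2, 64, 1024).cuda()   # [B, in_dim, npoint] features
    y, attn = block(p, x)                # y: [B, out_dim, npoint], residual added
    # attn is detached to CPU inside forward, handy for visualization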