def forward(self, center_coords, coords, features):
    """Aggregate dilated kNN neighborhoods around each center point.

    Parameters
    ----------
    center_coords : torch.tensor
        Coordinates of the center points; presumably (B, 3, N) -- TODO confirm.
    coords : torch.tensor
        Coordinates of all candidate neighbor points.
    features : torch.tensor or None
        Per-point features aligned with ``coords``. Only read when
        ``self.point_feature_size != 0``.

    Returns
    -------
    res : torch.tensor
        Per-center features. When ``self.with_global`` is set, the output of
        ``self.linear1(center_coords)`` is concatenated in front of the local
        features along dim 1.
    """
    # Dilated kNN: query k * dilation neighbors, then keep every
    # dilation-th one along the neighbor axis (dim 2).
    knn_indexes, _ = k_nearest_neighbors(center_coords, coords, self.k * self.dilation)
    knn_indexes = knn_indexes[:, :, ::self.dilation]

    # Neighbor coordinates expressed relative to their center point.
    knn_coords = index2points(coords, knn_indexes)
    knn_local_coords = localize(center_coords, knn_coords)  # (B,C,N,k)

    # Lift local coordinates into feature space; optionally append the
    # neighbors' own features along the channel dim.
    feature_d = self.mlp_d(knn_local_coords)
    if self.point_feature_size == 0:
        feature_a = feature_d
    else:
        knn_features = index2points(features, knn_indexes)
        feature_a = torch.cat([feature_d, knn_features], dim=1)  # [B, C+add_C, N, k]

    if self.use_x_transformation:
        trans = self.x_trans(feature_a)
        fx = self.transform(feature_a, trans)
    else:
        fx = feature_a

    fx = self.conv1(fx)
    # squeeze(dim=3) only drops dim 3 if it has size 1; presumably conv1
    # collapses the neighbor axis -- TODO confirm.
    fx = torch.squeeze(fx, dim=3).contiguous()

    if self.with_global:
        fts_global = self.linear1(center_coords)
        # Concatenate global and local features along the channel dim.
        res = torch.cat([fts_global, fx], dim=1)
    else:
        res = fx
    return res
def forward(self, x, coords):
    """Downsample with furthest point sampling, then max-pool kNN features.

    Parameters
    ----------
    x : torch.tensor
        Per-point features aligned with ``coords``.
    coords : torch.tensor
        Point coordinates.

    Returns
    -------
    tuple
        ``(y, fps_coords)``: features pooled over each sampled point's
        ``self.k`` neighbors, and the coordinates of the sampled points.
    """
    sampled_idx = furthest_point_sampling(coords, self.num_samples)
    sampled_coords = index2points(coords, sampled_idx)

    # Neighbors of every sampled point, searched among all points.
    neighbor_idx, _ = k_nearest_neighbors(sampled_coords, coords, self.k)
    neighbor_feats = self.mlp(index2points(x, neighbor_idx))

    # Local max pooling over the last (neighbor) axis.
    pooled = neighbor_feats.max(dim=-1).values
    return pooled, sampled_coords
def subsampling(self, coords, num_samples):
    """Randomly pick ``num_samples`` points, sharing one index set across the batch.

    Parameters
    ----------
    coords : torch.tensor
        Point coordinates unpacked as (B, C, N).
    num_samples : int
        Number of points to keep.

    Returns
    -------
    torch.tensor
        Coordinates gathered at the sampled indices. The same indices are used
        for every batch element (they are repeated B times).
    """
    batch_size, _, num_points = coords.shape

    # NOTE(review): random_sampling presumably returns num_samples indices
    # in [0, N) -- TODO confirm.
    picked = random_sampling(num_points, num_samples)
    idx = torch.tensor(picked, device=coords.device)
    idx = idx.view(1, num_samples).contiguous()
    idx = idx.repeat(batch_size, 1).contiguous()

    return index2points(coords, idx)
def get_graph_feature(x, k=20, memory_saving=False):
    """Build edge features ``[neighbor - center, center]`` from a kNN graph on ``x``.

    Parameters
    ----------
    x : torch.tensor
        Features unpacked as (B, C, N); neighbors are searched in this space.
    k : int
        Number of neighbors per point.
    memory_saving : bool
        Currently unused by this function.

    Returns
    -------
    torch.tensor
        (B, 2*C, N, k) tensor: neighbor offsets concatenated with the
        repeated center features along dim 1.
    """
    batch, channels, points = x.shape
    neighbor_idx, _ = k_nearest_neighbors(x, x, k)
    neighbors = index2points(x, neighbor_idx)

    # Tile each center feature across its k neighbor slots.
    center = x.view(batch, channels, points, 1).repeat(1, 1, 1, k)
    return torch.cat((neighbors - center, center), dim=1)
def forward(self, x, coords):
    """Max-pool kNN features per point and also return relative neighbor coordinates.

    Parameters
    ----------
    x : torch.tensor
        Per-point features aligned with ``coords``.
    coords : torch.tensor
        Point coordinates; neighbors are searched among these same points.

    Returns
    -------
    tuple
        ``(y, knn_mlp_x, local_coords)``: pooled features, the un-pooled MLP
        output per neighbor, and ``localize(coords, knn_coords) * -1``.
    """
    neighbor_idx, _ = k_nearest_neighbors(coords, coords, self.k)

    # Neighbor positions relative to their center (sign flipped from localize).
    neighbor_coords = index2points(coords, neighbor_idx)
    local_coords = localize(coords, neighbor_coords) * -1

    neighbor_mlp = self.mlp(index2points(x, neighbor_idx))

    # Local max pooling over the last (neighbor) axis.
    pooled = neighbor_mlp.max(dim=-1).values
    return pooled, neighbor_mlp, local_coords
def forward(self, x, coords):
    """Max-pool MLP-transformed kNN features at every input point (no downsampling).

    Parameters
    ----------
    x : torch.tensor
        Per-point features aligned with ``coords``.
    coords : torch.tensor
        Point coordinates; neighbors are searched among these same points.

    Returns
    -------
    tuple
        ``(y, coords)``: pooled per-point features and the unchanged ``coords``.
    """
    neighbor_idx, _ = k_nearest_neighbors(coords, coords, self.k)
    transformed = self.mlp(index2points(x, neighbor_idx))
    pooled = transformed.max(dim=-1).values
    return pooled, coords
def group_layer(coords, center_coords, num_samples, radius, points=None):
    """
    Group layer in PointNet++: gather ball-query neighborhoods around centers.

    Parameters
    ----------
    coords : torch.tensor [B, 3, N]
        xyz tensor
    center_coords : torch.tensor [B, 3, N']
        xyz tensor of ball query centers
    num_samples : int
        maximum number of samples for ball query
    radius : float
        radius of ball query
    points : torch.tensor [B, C, N]
        Optional per-point features concatenated to the return value.

    Return
    ------
    new_points : torch.tensor [B, 3, N', num_samples] or [B, 3+C, N', num_samples]
        Neighbor coordinates relative to their center; if points is not None,
        the grouped features are appended along dim 1 ([B, 3+C, N', num_samples]).
    """
    idx = ball_query(center_coords, coords, radius, num_samples).type(torch.long)

    # Offsets of each grouped point from its ball-query center.
    grouped_coords = index2points(coords, idx)
    offsets = grouped_coords - torch.unsqueeze(center_coords, 3)

    if points is None:
        return offsets

    grouped_points = index2points(points, idx)
    # note: PointNetSetAbstractionMsg is different order of concatenation.
    # https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/2d08fa40635cc5eafd14d19d18e3dc646171910d/models/pointnet_util.py#L253
    return torch.cat([offsets, grouped_points], dim=1)
def forward(self, features, coords):
    """
    Vector-attention aggregation over each point's kNN neighborhood.

    Parameters
    ----------
    features: torch.tensor (B, C, N)
    coords: torch.tensor (B, 3, N)

    Returns
    -------
    torch.tensor (B, C, N)
        Sum over neighbors of ``rho * (alpha(x_j) + delta)``.
    """
    neighbor_idx, _ = k_nearest_neighbors(coords, coords, self.k)  # (B, N, k)
    delta = self.pe_delta(coords, neighbor_idx)  # (B, C, N, k)

    # One linear projection, split into the three pointwise branches.
    phi, psi, alpha = torch.chunk(self.input_linear(features), chunks=3, dim=1)

    # Attention weights from phi(center) vs psi(neighbors), plus position term.
    psi_neighbors = index2points(psi, neighbor_idx)  # (B, C, N, k)
    gamma_in = localize(phi, psi_neighbors) * -1 + delta
    rho = self.normalization_rho(self.mlp_gamma(gamma_in))

    # Values: alpha(x_j) + delta, weighted elementwise and summed over k.
    alpha_neighbors = index2points(alpha, neighbor_idx)  # (B, C, N, k)
    weighted = rho * (alpha_neighbors + delta)
    return weighted.sum(dim=-1)
def forward(self, coords, knn_indices):
    """
    Encode relative neighbor positions with ``self.mlp_theta``.

    Parameters
    ----------
    coords: torch.tensor (B, C, N)
    knn_indices: neighbor indices into ``coords``, as consumed by index2points.

    Returns
    -------
    Output of ``self.mlp_theta`` applied to ``localize(coords, knn_coords) * -1``.
    """
    neighbor_coords = index2points(coords, knn_indices)
    relative = localize(coords, neighbor_coords) * -1
    return self.mlp_theta(relative)
def forward(self, f_sem, f_ins):
    """Fuse semantic and instance branches.

    Parameters
    ----------
    f_sem : torch.tensor
        Semantic-branch features.
    f_ins : torch.tensor
        Instance-branch features.

    Returns
    -------
    tuple
        ``(p_sem, e_ins)``. ``e_ins`` is the instance embedding of
        ``f_ins + adaptation(f_sem)``. ``p_sem`` is the semantic prediction
        from ``f_sem`` max-pooled over each point's kNN in embedding space.
    """
    # Instance embedding from semantic-augmented instance features.
    e_ins = self.ins_emb_fc(f_ins + self.adaptation(f_sem))

    # Neighbors are found in the learned embedding space, not xyz space.
    neighbor_idx, _ = py_k_nearest_neighbors(e_ins, e_ins, self.k, memory_saving=True)
    neighbor_sem = index2points(f_sem, neighbor_idx)
    f_isem = neighbor_sem.max(dim=3).values

    p_sem = self.sem_pred_fc(f_isem)
    return p_sem, e_ins
def forward(self, xyz1, xyz2, points1, points2):
    """
    Propagate features from a sparse point set (xyz2) onto a dense one (xyz1).

    Each point of ``xyz1`` receives an inverse-distance-weighted average of the
    ``points2`` features at its 3 nearest ``xyz2`` points. The result is
    optionally concatenated with ``points1`` and passed through ``self.mlp``.

    Parameters
    ----------
    xyz1: torch.tensor (B, C, N)
        xyz of all (dense) points -- the interpolation targets.
    xyz2: torch.tensor (B, C, S)
        xyz of center (sampled) points -- the interpolation sources.
    points1: torch.tensor (B, D1, N) or None
        features of all points (skip connection); concatenated if given.
    points2: torch.tensor (B, D2, S)
        features of center points, to be interpolated onto xyz1.

    Returns
    -------
    new_points: output of ``self.mlp`` on the (optionally concatenated)
        interpolated features.

    Note
    ----
    xyz1 is expected to have more points than xyz2 (N > S).
    """
    B, _, N = xyz1.shape
    _, _, S = xyz2.shape

    if S == 1:
        # Single source point: broadcast its features to every target point.
        interpolated_points = points2.repeat(1, 1, N)
    else:
        # 3 nearest xyz2 points for each xyz1 point.
        # NOTE(review): dists may be squared distances depending on
        # k_nearest_neighbors -- TODO confirm.
        idxs, dists = k_nearest_neighbors(xyz1, xyz2, 3)
        # Inverse-distance weights normalised to sum to 1 over the 3
        # neighbors; 1e-8 guards against division by zero.
        dist_recip = 1.0 / (dists + 1e-8)
        norm = torch.sum(dist_recip, dim=2, keepdim=True)
        weight = dist_recip / norm
        a = index2points(points2, idxs)
        w = weight.view(B, 1, N, 3)
        interpolated_points = torch.sum(a * w, dim=3)

    if points1 is not None:
        new_points = torch.cat([points1, interpolated_points], dim=1)
    else:
        new_points = interpolated_points

    new_points = self.mlp(new_points)
    return new_points
def sampling_layer(coords, num_samples):
    """
    Sampling layer in PointNet++: keep ``num_samples`` points chosen by FPS.

    Parameters
    ----------
    coords : torch.tensor [B, 3, N]
        xyz tensor
    num_samples : int
        number of samples for furthest point sample

    Return
    ------
    sampled_coords : torch.tensor [B, 3, num_samples]
        xyz gathered at the furthest-point-sample indices
    """
    # Indices are cast to int64 before being used for gathering.
    indices = furthest_point_sampling(coords, num_samples).type(torch.long)
    return index2points(coords, indices)
def forward(self, xyz1, xyz2, points1, points2):
    """
    Interpolate center-point features back onto all points, then apply an MLP.

    Parameters
    ----------
    xyz1: xyz of all points (B, C, N) -- interpolation targets
    xyz2: xyz of center points (B, C, S) -- interpolation sources
    points1: features of all points, or None
    points2: features of center points

    Returns
    -------
    Output of ``self.mlp`` on the interpolated features, with ``points1``
    concatenated in front along dim 1 when it is not None.
    """
    batch, _, num_dense = xyz1.shape
    _, _, num_sparse = xyz2.shape

    if num_sparse == 1:
        # Only one source point: copy its features to every target.
        interpolated = points2.repeat(1, 1, num_dense)
    else:
        nn_idx, nn_dist = py_k_nearest_neighbors(xyz1, xyz2, 3, memory_saving=True)
        # Inverse-distance weighting over the 3 neighbors, normalised to 1.
        inv_dist = 1.0 / (nn_dist + 1e-8)
        weights = inv_dist / inv_dist.sum(dim=2, keepdim=True)
        neighbor_feats = index2points(points2, nn_idx)
        interpolated = (neighbor_feats * weights.view(batch, 1, num_dense, 3)).sum(dim=3)

    if points1 is None:
        merged = interpolated
    else:
        merged = torch.cat([points1, interpolated], dim=1)
    return self.mlp(merged)
# gt_outs = other.index2points(point_clouds, center_idxs) # t = timecheck(t, "index2points:") # acc = outs == gt_outs # print(False in (acc)) # # print(acc) # # print(torch.sum(acc)) # # print(outs.shape) # # print(outs) # exit() k = 1 for data in loader: point_clouds, sem_label, ins_label = data point_clouds = point_clouds[:, :, :3].transpose(1, 2).to(device) center_idxs = furthest_point_sampling(point_clouds, 1024) center_points = other.index2points(point_clouds, center_idxs) print(point_clouds.shape, center_points.shape) knn_idxs, _ = k_nearest_neighbors(center_points, point_clouds, k) t = timecheck() outs = other.gather(point_clouds, knn_idxs) t = timecheck(t, "gather:") gt_outs = other.index2points(point_clouds, knn_idxs) t = timecheck(t, "index2points:") acc = outs == gt_outs print(False in (acc)) # print(acc) # print(torch.sum(acc)) # print(outs.shape) # print(outs) exit()