def forward(self, xyz, points):
    """Downsample, group nearest neighbours, and aggregate geometry-weighted features.

    Centroids are chosen with ``pointutils.furthest_point_sample``. Each centroid's
    ``self.nsample`` nearest input points are grouped. A learned mapping of the
    relative geometry (distance, offset, center, neighbour) is multiplied
    element-wise with the grouped features. The result goes through the MLP and is
    max-pooled over the neighbourhood.

    Input:
        xyz: input points position data, [B, C, N]
        points: input points data, [B, D, N]
    Return:
        new_xyz: sampled points position data, [B, C, S] with S = self.npoint
        x: sampled points feature data, [B, D', S]
    """
    B, _, _ = xyz.shape
    xyz_t = xyz.permute(0, 2, 1).contiguous()  # [B, N, C]

    # Pick self.npoint centroids and gather their coordinates.
    fps_idx = pointutils.furthest_point_sample(xyz_t, self.npoint)  # [B, npoint]
    new_xyz = pointutils.gather_operation(xyz, fps_idx)  # [B, 3, npoint]
    new_xyz_t = new_xyz.permute(0, 2, 1).contiguous()

    # k nearest input points for every centroid.
    _, idx = pointutils.knn(self.nsample, new_xyz_t, xyz_t)  # [B, npoint, nsample]
    neighbors = pointutils.grouping_operation(xyz, idx)  # [B, 3, npoint, nsample]
    # Broadcast view (no copy); torch.cat below materialises it as needed.
    centers = new_xyz.view(B, -1, self.npoint, 1).expand(
        -1, -1, -1, self.nsample)  # [B, 3, npoint, nsample]

    # Low-level geometric relation between each centroid and its neighbours.
    pos_diff = centers - neighbors  # [B, 3, npoint, nsample]
    distances = torch.norm(pos_diff, p=2, dim=1, keepdim=True)  # [B, 1, npoint, nsample]
    h_xi_xj = torch.cat([distances, pos_diff, centers, neighbors], dim=1)  # [B, 10, npoint, nsample]

    # Grouped features, prefixed with the neighbour coordinates.
    x = pointutils.grouping_operation(points, idx)  # [B, D, npoint, nsample]
    x = torch.cat([neighbors, x], dim=1)  # [B, D+3, npoint, nsample]

    h_xi_xj = self.mapping_func2(
        F.relu(self.bn_mapping(self.mapping_func1(h_xi_xj))))  # [B, c_in, npoint, nsample]
    if self.first_layer:
        # Lift the raw features to c_in channels so they match h_xi_xj.
        x = F.relu(self.bn_xyz_raising(self.xyz_raising(x)))  # [B, c_in, npoint, nsample]
    x = F.relu(self.bn_rsconv(torch.mul(h_xi_xj, x)))  # [B, c_in, npoint, nsample]

    for conv, bn in zip(self.mlp_convs, self.mlp_bns):
        x = F.relu(bn(conv(x)))  # [B, c_out, npoint, nsample]

    x = torch.max(x, -1)[0]  # max-pool over neighbours -> [B, c_out, npoint]
    return new_xyz, x
def forward(self, pos1, pos2, feature1, feature2):
    """Interpolate features from pos2 onto pos1 and refine them with the MLP.

    Each pos1 point receives an inverse-distance-weighted average of the
    features of its three nearest pos2 points. That result is optionally
    concatenated with ``feature1`` (skip connection) and passed through the
    conv/BN/ReLU stack.

    Input:
        pos1: positions of the target points, [B, C, N]
        pos2: positions of the source points, [B, C, S]
        feature1: features of the target points, [B, D1, N], or None
        feature2: features of the source points, [B, D2, S]
    Return:
        feat_new: upsampled point features, [B, D', N]
    """
    pos1_t = pos1.permute(0, 2, 1).contiguous()
    pos2_t = pos2.permute(0, 2, 1).contiguous()
    B, _, N = pos1.shape

    dists, idx = pointutils.three_nn(pos1_t, pos2_t)  # [B, N, 3]
    # Floor tiny distances to avoid division by zero for coincident points.
    # Done out of place so the op's output tensor is not mutated.
    dists = torch.clamp(dists, min=1e-10)
    weight = 1.0 / dists
    weight = weight / torch.sum(weight, -1, keepdim=True)  # weights sum to 1 per point

    grouped_feat = pointutils.grouping_operation(feature2, idx)  # [B, D2, N, 3]
    interpolated_feat = torch.sum(
        grouped_feat * weight.view(B, 1, N, 3), dim=-1)  # [B, D2, N]

    if feature1 is not None:
        feat_new = torch.cat([interpolated_feat, feature1], 1)
    else:
        feat_new = interpolated_feat

    for conv, bn in zip(self.mlp_convs, self.mlp_bns):
        feat_new = F.relu(bn(conv(feat_new)))
    return feat_new
def forward(self, pos1, pos2, feature1, feature2):
    """Propagate features from pos2 (fewer points) to pos1 (more points).

    For every pos1 point, a neighbourhood in pos2 is found either by kNN or by
    ball query, depending on ``self.knn``. The grouped pos2 features are
    concatenated with the relative offsets ``pos2 - pos1`` and passed through
    ``self.mlp1_convs``. The result is max-pooled over the neighbourhood,
    optionally concatenated with ``feature1``, and passed through
    ``self.mlp2_convs``.

    Inputs:
        pos1: (batch_size, 3, npoint1) positions of the dense point set
        pos2: (batch_size, 3, npoint2) positions of the sparse point set
        feature1: (batch_size, channel1, npoint1) features for pos1 points
            (earlier layers, more points), or None
        feature2: (batch_size, channel2, npoint2) features for pos2 points
    Output:
        feat_new: (batch_size, D', npoint1). D' is set by the configured
            mlp1/mlp2 convolutions (and feature1's channels if mlp2 is empty).

    TODO: Add support for skip links. Study how delta(XYZ) plays a role in feature updating.
    """
    pos1_t = pos1.permute(0, 2, 1).contiguous()
    pos2_t = pos2.permute(0, 2, 1).contiguous()
    B, _, N = pos1.shape

    # idx indexes pos2 for each pos1 query point.
    if self.knn:
        _, idx = pointutils.knn(self.nsample, pos1_t, pos2_t)  # [B, N1, S]
    else:
        idx = pointutils.ball_query(self.radius, self.nsample, pos2_t, pos1_t)

    pos2_grouped = pointutils.grouping_operation(pos2, idx)  # [B, 3, N1, S]
    pos_diff = pos2_grouped - pos1.view(B, -1, N, 1)  # [B, 3, N1, S]

    feat2_grouped = pointutils.grouping_operation(feature2, idx)  # [B, C2, N1, S]
    feat_new = torch.cat([feat2_grouped, pos_diff], dim=1)  # [B, C2+3, N1, S]
    for conv in self.mlp1_convs:
        feat_new = conv(feat_new)

    # Max-pool over the neighbourhood.
    feat_new = feat_new.max(-1)[0]  # [B, mlp1[-1], N1]

    # Concatenate features from the earlier (denser) layer.
    if feature1 is not None:
        feat_new = torch.cat([feat_new, feature1], dim=1)  # [B, mlp1[-1]+C1, N1]

    for conv in self.mlp2_convs:
        feat_new = conv(feat_new)

    return feat_new
def forward(self, pos1, pos2, feature1, feature2):
    """Fuse each pos1 point's feature with features of its neighbours in pos2.

    Input:
        pos1: (batch_size, 3, npoint) positions of the first point set
        pos2: (batch_size, 3, npoint) positions of the second point set
        feature1: (batch_size, channel, npoint) features of pos1 points
        feature2: (batch_size, channel, npoint) features of pos2 points
    Output:
        pos1: unchanged input positions, (batch_size, 3, npoint)
        feat1_new: (batch_size, mlp[-1], npoint)
    Raises:
        NotImplementedError: if ``self.corr_func`` is not ``'concat'``. Previously
            this case failed with an ``UnboundLocalError`` on ``feat_diff``.
    """
    # Only the 'concat' correlation is implemented; fail fast and clearly otherwise.
    if self.corr_func != 'concat':
        raise NotImplementedError(
            "corr_func=%r is not supported; only 'concat' is implemented"
            % (self.corr_func,))

    pos1_t = pos1.permute(0, 2, 1).contiguous()
    pos2_t = pos2.permute(0, 2, 1).contiguous()
    B, N, _ = pos1_t.shape

    # idx indexes pos2 for each pos1 query point.
    if self.knn:
        _, idx = pointutils.knn(self.nsample, pos1_t, pos2_t)  # [B, N, S]
    else:
        idx = pointutils.ball_query(self.radius, self.nsample, pos2_t, pos1_t)

    pos2_grouped = pointutils.grouping_operation(pos2, idx)  # [B, 3, N, S]
    pos_diff = pos2_grouped - pos1.view(B, -1, N, 1)  # [B, 3, N, S]
    feat2_grouped = pointutils.grouping_operation(feature2, idx)  # [B, C, N, S]

    # Pair every grouped pos2 feature with the pos1 feature (broadcast view, no copy).
    feat1_tiled = feature1.view(B, -1, N, 1).expand(-1, -1, -1, self.nsample)
    feat_diff = torch.cat([feat2_grouped, feat1_tiled], dim=1)  # [B, 2*C, N, S]

    feat1_new = torch.cat([pos_diff, feat_diff], dim=1)  # [B, 2*C+3, N, S]
    for conv, bn in zip(self.mlp_convs, self.mlp_bns):
        feat1_new = F.relu(bn(conv(feat1_new)))

    feat1_new = torch.max(feat1_new, -1)[0]  # max over neighbours -> [B, mlp[-1], N]
    return pos1, feat1_new
def group(p: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Gather the indexed neighbours of each query from a reference point cloud.

    Args:
        p: Reference point cloud, shape [batch_size, dim, num_point].
        idx: Neighbour indices, shape [batch_size, num_query, k].

    Returns:
        Grouped points, shape [batch_size, dim, num_query, k].
    """
    # The grouping op is given a contiguous version of the reference cloud.
    reference = p.contiguous()
    grouped = _PU.grouping_operation(reference, idx)
    return grouped
def forward(self, point_cloud):
    """Build edge features from each point and its nearest neighbours.

    Args:
        point_cloud: tensor of shape [batch_size, n_dims, n_points].

    Returns:
        edge_feature: [batch_size, 2 * n_dims, k', n_points]. It is the central
            point concatenated with (neighbour - central point) along dim 1.
        idx: neighbour indices with the first neighbour dropped,
            [batch_size, k', n_points].
    """
    # Per the original notes, idx comes back as [batch_size, k_knn, n_points].
    _, idx = self.KNN(point_cloud, point_cloud)
    # Drop the first neighbour, presumably the query point itself (TODO confirm
    # that self.KNN is configured to return k + 1 neighbours).
    idx = idx[:, 1:, :]
    neighbors = grouping_operation(point_cloud, idx.contiguous().int())  # [B, n_dims, k', n_points]

    # Tile the central point to match the neighbour count actually returned.
    num_neighbors = idx.shape[1]
    central = point_cloud.unsqueeze(2).expand(-1, -1, num_neighbors, -1)

    edge_feature = torch.cat([central, neighbors - central], dim=1)
    return edge_feature, idx
def get_uniform_loss(self, pcd, percentage=(0.004, 0.006, 0.008, 0.010, 0.012), radius=1.0):
    """Uniformity loss over local patches of a point cloud.

    Seed points (5% of N) are picked with ``pn2_utils.furthest_point_sample``.
    For each fraction ``p`` in ``percentage``, a ball query around every seed
    collects ``int(N * p)`` points. Distances from ``self.knn_uniform`` within
    each patch are compared with an expected spacing derived from ``radius`` and N.

    Args:
        pcd: point cloud, unpacked as [B, N, C].
        percentage: fractions of N used as the per-patch sample size. Each
            entry contributes one term, and the terms are averaged.
        radius: sets the ball radius (``sqrt(p * radius)``) and the expected
            spacing (``sqrt(pi * radius**2 / N)``).

    Returns:
        Scalar tensor: mean over ``percentage`` of the weighted patch losses.
    """
    N = pcd.shape[1]
    npoint = int(N * 0.05)  # number of seed points: 5% of the cloud

    # Build both layouts once instead of inside the loop.
    xyz = pcd.contiguous()  # [B, N, C]
    xyz_cf = pcd.permute(0, 2, 1).contiguous()  # [B, C, N], for gather/group

    # FPS gets the [B, N, C] layout, the same xyz layout passed to ball_query
    # below. It previously received the channel-first tensor.
    seed_idx = pn2_utils.furthest_point_sample(xyz, npoint)  # [B, npoint]
    seeds = pn2_utils.gather_operation(xyz_cf, seed_idx)  # [B, C, npoint]
    seeds_t = seeds.permute(0, 2, 1).contiguous()  # [B, npoint, C]

    # The expected spacing does not depend on p.
    disk_area = math.pi * (radius ** 2) / N
    expect_len = math.sqrt(disk_area)

    loss = 0
    for p in percentage:
        nsample = int(N * p)
        r = math.sqrt(p * radius)

        idx = pn2_utils.ball_query(r, nsample, xyz, seeds_t)  # [B, npoint, nsample]
        grouped_pcd = pn2_utils.grouping_operation(xyz_cf, idx)  # [B, C, npoint, nsample]
        grouped_pcd = grouped_pcd.permute(0, 2, 3, 1)  # [B, npoint, nsample, C]
        # Stack patches along the batch dim -> [npoint * B, nsample, C].
        grouped_pcd = torch.cat(torch.unbind(grouped_pcd, dim=1), dim=0)

        dist, _ = self.knn_uniform(grouped_pcd, grouped_pcd)
        # Drop column 0 (presumably each point's zero distance to itself).
        uniform_dist = dist[:, :, 1:]
        uniform_dist = torch.abs(uniform_dist + 1e-8)
        # NOTE(review): this averages over the nsample axis before squaring the
        # deviation. PU-GAN's reference reduces over the last (neighbour) axis
        # instead. Confirm which is intended before changing.
        uniform_dist = torch.mean(uniform_dist, dim=1)
        uniform_dist = (uniform_dist - expect_len) ** 2 / (expect_len + 1e-8)

        # Weight larger patches more heavily.
        mean_loss = torch.mean(uniform_dist) * math.pow(p * 100, 2)
        loss += mean_loss

    return loss / len(percentage)