def forward(self, x: Tuple[UFloatTensor,   # (N, x, dims)
                           UFloatTensor]   # (N, x, C_in)
            ) -> Tuple[UFloatTensor,       # (N, P, dims)
                       UFloatTensor]:      # (N, P, C_out)
    """
    Given a point cloud and its corresponding features, return a new set of
    farthest-point-sampled representative points with features projected
    from the point cloud.

    :param x: (pts, fts) where
        - pts: Regional point cloud such that fts[:,p_idx,:] is the feature
          associated with pts[:,p_idx,:].
        - fts: Regional features such that fts[:,p_idx,:] is the feature
          associated with pts[:,p_idx,:].
    :return: (rep_pts, rep_pts_fts) where
        - rep_pts: self.P points chosen by farthest point sampling when
          0 < self.P < number of input points; otherwise all input points.
        - rep_pts_fts: the output of self.pointcnn((rep_pts, pts, fts)).
    """
    pts, fts = x
    if 0 < self.P < pts.size(1):
        # Select self.P representative indices by farthest point sampling
        # (deterministic coverage of the cloud, not random subsampling).
        fps_idx = farthest_point_sample(pts, self.P)  # [N, P]
        rep_pts = index_points(pts, fps_idx)  # [N, P, dims]
    else:
        # P is non-positive or not smaller than the point count:
        # all input points are representative points.
        rep_pts = pts
    # Channel count is whatever self.pointcnn produces (the return
    # annotation above labels it C_out).
    rep_pts_fts = self.pointcnn((rep_pts, pts, fts))
    return rep_pts, rep_pts_fts
def forward(self, xyz, points):
    """
    Multi-scale set abstraction.

    Samples self.npoint centroids with farthest point sampling. For every
    radius in self.radius_list it does the following:
      - gathers up to the matching self.nsample_list count of neighbours;
      - expresses their coordinates relative to the centroid (optionally
        prefixed by their features);
      - runs that scale's conv/bn/relu stack;
      - max-pools over the neighbour axis.
    The per-scale results are concatenated along the channel axis.

    Input:
        xyz: input points position data, [B, C, N]
        points: input points data, [B, D, N], or None
    Return:
        new_xyz: sampled points position data, [B, C, S]
        new_points_concat: sampled points feature data, [B, D', S] where D'
            is the sum of the per-scale output channels
    """
    xyz = xyz.permute(0, 2, 1)  # -> [B, N, C]
    if points is not None:
        points = points.permute(0, 2, 1)  # -> [B, N, D]
    B, N, C = xyz.shape
    S = self.npoint

    centroid_idx = farthest_point_sample(xyz, S)
    new_xyz = index_points(xyz, centroid_idx)  # [B, S, C]
    # Broadcastable centroid positions, used to make neighbour coordinates local.
    centroid_offsets = new_xyz.view(B, S, 1, C)

    per_scale_features = []
    for scale, radius in enumerate(self.radius_list):
        nsample = self.nsample_list[scale]
        neighbour_idx = query_ball_point(radius, nsample, xyz, new_xyz)
        local_xyz = index_points(xyz, neighbour_idx)
        local_xyz -= centroid_offsets
        if points is None:
            features = local_xyz
        else:
            neighbour_feats = index_points(points, neighbour_idx)
            features = torch.cat([neighbour_feats, local_xyz], dim=-1)

        features = features.permute(0, 3, 2, 1)  # [B, D, K, S]
        bn_stack = self.bn_blocks[scale]
        for layer, conv in enumerate(self.conv_blocks[scale]):
            features = F.relu(bn_stack[layer](conv(features)))

        pooled = torch.max(features, 2)[0]  # max over neighbours -> [B, D', S]
        per_scale_features.append(pooled)

    new_points_concat = torch.cat(per_scale_features, dim=1)
    return new_xyz.permute(0, 2, 1), new_points_concat
def forward(self, xyz1, xyz2, points1, points2):
    """
    Feature propagation (channels-first tensors).

    Features of the sampled set xyz2 are interpolated onto xyz1 by
    inverse-distance weighting over the (up to) 3 nearest sampled points.
    The optional skip features points1 are concatenated in front, and the
    result is passed through the conv/bn/relu layers in self.mlp_convs /
    self.mlp_bns.

    Input:
        xyz1: input points position data, [B, C, N]
        xyz2: sampled input points position data, [B, C, S]
        points1: input points data, [B, D1, N], or None
        points2: sampled points data, [B, D2, S]
    Return:
        new_points: upsampled points data, [B, D', N]. D' is the output
            channel count of the last MLP layer, or D1 + D2 (D2 if points1
            is None) when there are no MLP layers.
    """
    xyz1 = xyz1.permute(0, 2, 1)  # (B, N, C)
    xyz2 = xyz2.permute(0, 2, 1)  # (B, S, C)
    points2 = points2.permute(0, 2, 1)  # (B, S, D2)
    B, N, C = xyz1.shape
    _, S, _ = xyz2.shape

    if S == 1:
        # A single source point: every target point receives its features.
        interpolated_points = points2.repeat(1, N, 1)  # (B, N, D2)
    else:
        # Use up to 3 neighbours. With S == 2 only two exist, so a fixed 3
        # would make the view(B, N, 3, 1) below fail.
        k = min(3, S)
        dists = square_distance(xyz1, xyz2)  # (B, N, S)
        # k smallest distances, ascending; cheaper than sorting all S.
        dists, idx = torch.topk(dists, k, dim=-1, largest=False)  # (B, N, k), (B, N, k)
        dist_recip = 1.0 / (dists + 1e-8)  # eps guards coincident points
        norm = torch.sum(dist_recip, dim=2, keepdim=True)  # (B, N, 1)
        weight = dist_recip / norm  # (B, N, k), each row sums to 1
        interpolated_points = torch.sum(
            index_points(points2, idx) * weight.view(B, N, k, 1), dim=2
        )  # (B, N, D2)

    if points1 is not None:
        points1 = points1.permute(0, 2, 1)  # (B, N, D1)
        new_points = torch.cat([points1, interpolated_points], dim=-1)  # (B, N, D1+D2)
    else:
        new_points = interpolated_points  # (B, N, D2)

    new_points = new_points.permute(0, 2, 1)  # channels first for the conv layers
    for i, conv in enumerate(self.mlp_convs):
        bn = self.mlp_bns[i]
        new_points = F.relu(bn(conv(new_points)))
    return new_points
def forward(self, xyz1, xyz2, points1, points2):
    """
    Interpolate points with a densenet-style head (channels-last tensors).

    Features of the sampled set xyz2 are interpolated onto xyz1 by
    inverse-distance weighting over the (up to) 3 nearest sampled points.
    The optional points1 features are concatenated in front, and the result
    is passed through self.pnfpdensesnet.

    Input:
        xyz1: input points position data, [B, N, C]
        xyz2: sampled input points position data, [B, S, C]
        points1: input points data, [B, N, D], or None
        points2: input points data, [B, S, D']
    Return:
        new_points_dense: upsampled points data, [B, N, D''], where D'' is the
            channel count produced by self.pnfpdensesnet
    """
    B, N, C = xyz1.shape
    _, S, _ = xyz2.shape

    if S == 1:
        # A single source point: every target point receives its features.
        interpolated_points = points2.repeat(1, N, 1)  # (B, N, D')
    else:
        # Use up to 3 neighbours so that S == 2 is handled too. A fixed 3
        # would break the view below, and an assert would be stripped
        # under `python -O`.
        k = min(3, S)
        dists = square_distance(xyz1, xyz2)  # (B, N, S)
        # k smallest distances, ascending; cheaper than sorting all S.
        dists, idx = torch.topk(dists, k, dim=-1, largest=False)  # (B, N, k), (B, N, k)
        dists = dists.clamp(min=1e-10)  # floor guards coincident points
        weight = 1.0 / dists  # (B, N, k)
        weight = weight / torch.sum(weight, dim=-1, keepdim=True)  # rows sum to 1
        interpolated_points = torch.sum(
            index_points(points2, idx) * weight.view(B, N, k, 1), dim=2
        )  # (B, N, D')

    if points1 is not None:
        new_points = torch.cat([points1, interpolated_points], dim=-1)  # (B, N, D+D')
    else:
        new_points = interpolated_points  # (B, N, D')

    new_points = new_points.permute(0, 2, 1)  # (B, D+D', N) for the dense net
    new_points_dense = self.pnfpdensesnet(new_points)  # (B, D'', N)
    new_points_dense = new_points_dense.permute(0, 2, 1)  # (B, N, D'')
    return new_points_dense