def get_uniform_loss(pcd, percentages=(0.004, 0.006, 0.008, 0.010, 0.012), radius=1.0):
    """Compute a multi-scale uniformity loss for a batch of point clouds.

    A fixed set of seed points (5% of the cloud, picked by furthest point
    sampling) is ball-queried at several scales.  At each scale the spacing
    between neighbouring points inside every ball is compared against the
    spacing expected for an evenly covered disk, and the squared relative
    deviation is averaged and weighted by ``(p * 100) ** 2``.

    Args:
        pcd: 3-D tensor whose size unpacks as (B, N, C).  The grouped points
            are reshaped to 3 values per point, so C is expected to be 3.
        percentages: Fractions of N, one per scale.  ``int(N * p)`` is used
            as the ball-query sample count and must be non-zero, otherwise
            the disk-area division raises ``ZeroDivisionError``.
        radius: Base radius; each query radius is ``sqrt(p * radius)``.

    Returns:
        The sum of the per-scale terms divided by ``len(percentages)``.
    """
    _, num_points, _ = pcd.size()
    npoint = int(num_points * 0.05)

    # The seed points and the channel-first copy of the cloud depend only on
    # `pcd`, so they are computed once instead of once per scale.
    pcd_flipped = pcd.transpose(1, 2).contiguous()
    seed_idx = pn2.furthest_point_sample(pcd, npoint)
    new_xyz = pn2.gather_operation(pcd_flipped, seed_idx).transpose(1, 2).contiguous()

    loss = 0
    for p in percentages:
        nsample = int(num_points * p)
        r = math.sqrt(p * radius)
        # Area share of one point when `nsample` points evenly cover the disk.
        disk_area = math.pi * (radius ** 2) * p / nsample
        expect_len = math.sqrt(disk_area)

        idx = pn2.ball_query(r, nsample, pcd, new_xyz)
        grouped_pcd = pn2.grouping_operation(pcd_flipped, idx)
        # Flatten batch and seed axes so every ball is its own small cloud.
        grouped_pcd = grouped_pcd.permute(0, 2, 3, 1).contiguous().view(-1, nsample, 3)

        # NOTE(review): presumably knn_point returns negated squared distances
        # with the self-match in column 0 -- confirm against knn_point.
        var, _ = knn_point(2, grouped_pcd, grouped_pcd)
        uniform_dis = -var[:, :, 1:]
        uniform_dis = torch.sqrt(torch.abs(uniform_dis + 1e-8))
        uniform_dis = torch.mean(uniform_dis, dim=-1)
        uniform_dis = (uniform_dis - expect_len) ** 2 / (expect_len + 1e-8)

        loss = loss + torch.mean(uniform_dis) * math.pow(p * 100, 2)
    return loss / len(percentages)
def edge_preserve_sampling(feature_input, point_input, num_samples, k=10):
    """Downsample points by FPS and pool features around each kept point.

    For every sampled point, the feature of the point itself is concatenated
    (along the channel axis) with the channel-wise maximum over the features
    of its nearest neighbours in the full cloud.

    Args:
        feature_input: Feature tensor indexed as (batch, channels, points).
        point_input: Point coordinates; sampled centres come out as
            (B, M, 3) per the gather/transpose below.
        num_samples: Number of points to keep (M).
        k: Neighbourhood size, capped at the number of input points.

    Returns:
        Tuple ``(net, p_idx, pn_idx, point_output)``: the concatenated
        features, the FPS indices, the neighbour indices (B, M, pk), and
        the sampled coordinates (B, M, 3).
    """
    shape = feature_input.shape
    batch_size = shape[0]
    feature_size = shape[1]
    num_points = shape[2]

    # Pick the centres with furthest point sampling.
    p_idx = pn2.furthest_point_sample(point_input, num_samples)
    points_channel_first = point_input.transpose(1, 2).contiguous()
    sampled_channel_first = pn2.gather_operation(points_channel_first, p_idx)
    point_output = sampled_channel_first.transpose(1, 2).contiguous()  # B M 3

    # Neighbour lookup; indices carry no gradient and must be int32.
    pk = int(min(k, num_points))
    _, pn_idx = knn_point(pk, point_input, point_output)
    pn_idx = pn_idx.detach().int()  # B M pk

    # Gather neighbour features through a flattened index, then restore the
    # (sample, neighbour) axes and max-pool over the neighbours.
    flat_idx = pn_idx.view(batch_size, num_samples * pk)
    gathered = pn2.gather_operation(feature_input, flat_idx)
    gathered = gathered.view(batch_size, feature_size, num_samples, pk)
    neighbor_feature, _ = gathered.max(dim=3)

    # Features of the centres themselves.
    center_feature = pn2.grouping_operation(feature_input, p_idx.unsqueeze(2))
    center_feature = center_feature.view(batch_size, -1, num_samples)

    net = torch.cat([center_feature, neighbor_feature], dim=1)
    return net, p_idx, pn_idx, point_output
def forward(self, P1: torch.Tensor, P2: torch.Tensor, X1: torch.Tensor, S2: torch.Tensor) -> (torch.Tensor):
    r"""Aggregate the states of P2's points around every point of P1.

    Parameters
    ----------
    P1: (B, N, 3)
        Query point coordinates.
    P2: (B, N, 3)
        Coordinates of the points whose states are gathered.
    X1: (B, C, N)
        Features attached to P1; only used when ``self.in_channels != 0``.
    S2: (B, C, N)
        States attached to P2.

    Returns
    -------
    S1: (B, C, N)
        Max-pooled output of ``self.fc`` for every point of P1.
    """
    # Neighbours of each P1 point among P2: (B, npoint, nsample)
    neighbor_idx = pointnet2_utils.ball_query(self.radius, self.nsamples, P2, P1)

    # Gather neighbour coordinates and states.
    grouped_xyz = pointnet2_utils.grouping_operation(
        P2.transpose(1, 2).contiguous(), neighbor_idx)  # (B, 3, npoint, nsample)
    grouped_state = pointnet2_utils.grouping_operation(
        S2, neighbor_idx)  # (B, C, npoint, nsample)

    # Calculate displacements from each query point to its neighbours.
    query_xyz = P1.transpose(1, 2).contiguous().unsqueeze(3)  # (B, 3, npoint, 1)
    offsets = grouped_xyz - query_xyz  # (B, 3, npoint, nsample)

    # Channel-wise concatenation: states, (optional) query features, offsets.
    pieces = [grouped_state]
    if self.in_channels != 0:
        query_feat = X1.unsqueeze(3)  # (B, C, npoint, 1)
        pieces.append(query_feat.repeat(1, 1, 1, self.nsamples))
    pieces.append(offsets)
    correlation = torch.cat(pieces, dim=1)

    # Fully-connected layer (the only parameters), then max over neighbours.
    response = self.fc(correlation)
    pooled, _ = torch.max(response, dim=-1)
    return pooled
def get_repulsion_loss(self, pred):
    """Compute a repulsion loss that penalises points crowding their neighbours.

    For each point, the distances to its ``self.nn_size - 1`` nearest
    neighbours (the first kNN hit is dropped) are turned into the term
    ``(self.radius - dist) * exp(-dist**2 / self.h**2)`` and averaged.

    Args:
        pred: Predicted points; the in-line shape notes of this function
            treat it as (B, N, 3) before the transpose.

    Returns:
        Scalar tensor with the mean repulsion term.
    """
    _, idx = knn_point(self.nn_size, pred, pred, transpose_mode=True)
    idx = idx[:, :, 1:].to(torch.int32)  # remove first one
    idx = idx.contiguous()  # B, N, nn

    pred = pred.transpose(1, 2).contiguous()  # B, 3, N
    grouped_points = grouping_operation(pred, idx)  # (B, 3, N), (B, N, nn) => (B, 3, N, nn)

    # Offsets from every point to each of its neighbours.
    grouped_points = grouped_points - pred.unsqueeze(-1)
    dist2 = torch.sum(grouped_points ** 2, dim=1)
    # Floor the squared distance so sqrt stays differentiable at zero.
    # clamp works on whatever device/dtype `dist2` lives on, unlike building
    # a fresh tensor with `.cuda()`, which pins the default GPU.
    dist2 = torch.clamp(dist2, min=self.eps)
    dist = torch.sqrt(dist2)
    weight = torch.exp(-dist2 / self.h ** 2)

    # The weight multiplies the whole (radius - dist) term here; the PU-Net
    # formulation instead averages `radius - dist * weight`.
    uniform_loss = torch.mean((self.radius - dist) * weight)
    return uniform_loss
def get_repulsion_loss(pred, nsample=20, radius=0.07):
    """Compute a repulsion loss over the closest neighbours of each point.

    Each point is grouped with its ``nsample`` nearest neighbours; the four
    smallest non-self squared distances (top-5 minus the first entry) are
    weighted by ``exp(-d**2 / h**2)`` with ``h = 0.03`` and the mean of
    ``radius - dist * weight`` is returned.

    Args:
        pred: Predicted points, (batch_size, npoint, 3).
        nsample: Number of neighbours requested from ``knn``; must be at
            least 5 because a top-5 selection follows.
        radius: Constant offset in the loss term.

    Returns:
        Scalar tensor with the mean repulsion term.
    """
    pred_flipped = pred.transpose(1, 2).contiguous()
    idx = knn(pred_flipped, nsample).int()
    grouped_pred = pn2.grouping_operation(pred_flipped, idx)  # (B, C, npoint, nsample)
    # Offsets from every point to each of its neighbours.
    grouped_pred = grouped_pred - pred_flipped.unsqueeze(-1)

    h = 0.03
    dist_square = torch.sum(grouped_pred ** 2, dim=1)
    # top-k of the negated values selects the 5 smallest squared distances.
    neg_smallest, _ = torch.topk(-dist_square, 5)
    dist_square = -neg_smallest[:, :, 1:]  # remove the first one
    # Floor at 1e-12 so sqrt stays differentiable at zero.  clamp keeps the
    # device and dtype of the input instead of forcing a float32 tensor onto
    # the default CUDA device.
    dist_square = torch.clamp(dist_square, min=1e-12)
    dist = torch.sqrt(dist_square)
    weight = torch.exp(-dist_square / h ** 2)
    uniform_loss = torch.mean(radius - dist * weight)
    return uniform_loss
def forward(self, xyzs: torch.Tensor, features: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
    """Apply the spatio-temporal convolution to a point cloud sequence.

    Frames are padded along time, anchor points are subsampled per anchor
    frame with furthest point sampling, neighbours from every frame inside
    the temporal window are encoded and pooled spatially, and the per-frame
    results are concatenated and passed through ``self.temporal``.

    Args:
        xyzs: torch.Tensor
            (B, L, N, 3) tensor of sequence of the xyz coordinates
        features: torch.Tensor
            (B, L, C, N) tensor of sequence of the features; only read when
            ``self.in_planes != 0``.

    Returns:
        Tuple ``(new_xyzs, new_features)``: the anchor coordinates and the
        outputs of ``self.temporal``, each stacked over anchor frames on
        dim 1.  Each anchor frame holds ``N // self.spatial_stride`` points.
    """
    nframes = xyzs.size(1)  # L
    npoints = xyzs.size(2)  # N

    if self.temporal_kernel_size > 1 and self.temporal_stride > 1:
        assert ((nframes + sum(self.temporal_padding) - self.temporal_kernel_size) % self.temporal_stride == 0), "PSTConv: Temporal parameter error!"

    use_features = self.in_planes != 0
    pad_front = self.temporal_padding[0]
    pad_back = self.temporal_padding[1]

    # Split the sequence into a list of per-frame tensors.
    xyzs = [frame.squeeze(1).contiguous() for frame in torch.split(xyzs, 1, dim=1)]
    if use_features:
        features = [frame.squeeze(1).contiguous() for frame in torch.split(features, 1, dim=1)]

    # Temporal padding.
    if self.padding_mode == "zeros":
        # zeros_like inherits device and dtype from the data; the integer
        # from Tensor.get_device() is -1 on CPU and is not a valid device.
        xyz_padding = torch.zeros_like(xyzs[0])
        xyzs = [xyz_padding] * pad_front + xyzs + [xyz_padding] * pad_back
        if use_features:
            feature_padding = torch.zeros_like(features[0])
            features = [feature_padding] * pad_front + features + [feature_padding] * pad_back
    else:  # "replicate"
        xyzs = [xyzs[0]] * pad_front + xyzs + [xyzs[-1]] * pad_back
        if use_features:
            features = [features[0]] * pad_front + features + [features[-1]] * pad_back

    new_xyzs = []
    new_features = []
    # Temporal anchor frames.
    for t in range(self.temporal_radius, len(xyzs) - self.temporal_radius, self.temporal_stride):
        # Spatial anchor point subsampling by FPS.
        anchor_idx = pointnet2_utils.furthest_point_sample(
            xyzs[t], npoints // self.spatial_stride)  # (B, N//spatial_stride)
        anchor_xyz_flipped = pointnet2_utils.gather_operation(
            xyzs[t].transpose(1, 2).contiguous(), anchor_idx)  # (B, 3, N//spatial_stride)
        anchor_xyz_expanded = torch.unsqueeze(anchor_xyz_flipped, 3)  # (B, 3, N//spatial_stride, 1)
        anchor_xyz = anchor_xyz_flipped.transpose(1, 2).contiguous()  # (B, N//spatial_stride, 3)

        # Spatial convolution over every frame in the temporal window.
        window_features = []
        for i in range(t - self.temporal_radius, t + self.temporal_radius + 1):
            neighbor_xyz = xyzs[i]
            idx = pointnet2_utils.ball_query(self.r, self.k, neighbor_xyz, anchor_xyz)

            neighbor_xyz_flipped = neighbor_xyz.transpose(1, 2).contiguous()  # (B, 3, N)
            neighbor_xyz_grouped = pointnet2_utils.grouping_operation(
                neighbor_xyz_flipped, idx)  # (B, 3, N//spatial_stride, k)
            displacement = neighbor_xyz_grouped - anchor_xyz_expanded  # (B, 3, N//spatial_stride, k)
            displacement = self.spatial_conv_d(displacement)  # (B, mid_planes, N//spatial_stride, k)

            if use_features:
                neighbor_feature_grouped = pointnet2_utils.grouping_operation(
                    features[i], idx)  # (B, in_planes, N//spatial_stride, k)
                feature = self.spatial_conv_f(neighbor_feature_grouped)  # (B, mid_planes, N//spatial_stride, k)
                if self.spatial_aggregation == "addition":
                    spatial_feature = feature + displacement
                else:
                    spatial_feature = feature * displacement
            else:
                spatial_feature = displacement

            # Pool over the neighbour axis: (B, mid_planes, N//spatial_stride)
            if self.spatial_pooling == 'max':
                spatial_feature, _ = torch.max(spatial_feature, dim=-1)
            elif self.spatial_pooling == 'sum':
                spatial_feature = torch.sum(spatial_feature, dim=-1)
            else:
                spatial_feature = torch.mean(spatial_feature, dim=-1)

            window_features.append(spatial_feature)

        # (B, temporal_kernel_size*mid_planes, N//spatial_stride)
        spatial_features = torch.cat(window_features, dim=1)

        # Batch norm and relu.
        # NOTE(review): ReLU is applied whether or not batch norm is
        # configured -- confirm this matches the intended layout.
        if self.batch_norm:
            spatial_features = self.batch_norm(spatial_features)
        spatial_features = self.relu(spatial_features)

        # Temporal convolution.
        spatio_temporal_feature = self.temporal(spatial_features)

        new_xyzs.append(anchor_xyz)
        new_features.append(spatio_temporal_feature)

    new_xyzs = torch.stack(new_xyzs, dim=1)
    new_features = torch.stack(new_features, dim=1)
    return new_xyzs, new_features