def forward(self, x):
    """Encode a point cloud hierarchically, then decode per-point outputs.

    Encoder: four stages of ``conv`` + ``dense_conv`` whose output is
    concatenated with the stage input (dense skip). After each of the first
    three stages, ``edge_preserve_sampling`` produces a downsampled feature
    map and point set using ``self.hierarchy[i]`` and ``self.k`` (presumably
    the number of sampled points and the neighbourhood size -- confirm in
    ``edge_preserve_sampling``). A global feature (``gf_conv``, max over the
    last axis, then ``fc1``/``fc2``) is tiled ``self.hierarchy[2]`` times and
    concatenated onto the coarsest level.

    Decoder: features are carried back up one level at a time with
    ``three_nn_upsampling`` + ``pn2.three_interpolate``, concatenated with the
    matching encoder features, and the result of ``self.conv8`` is returned.

    :param x: 3-D tensor whose first three entries on dim 1 are used as xyz
        coordinates -- presumably (B, C, N); TODO confirm against callers
    :return: output of ``self.conv8`` on the concatenation of the level-1
        encoder features and the upsampled decoder features

    NOTE: the trailing channel counts (``# 24``, ``# 120`` ...) are the
    original author's annotations; they depend on layer configs not visible
    here.
    """
    # xyz coordinates of the input points, moved to (B, N, 3) layout.
    point_cloud1 = x[:, 0:3, :]
    point_cloud1 = point_cloud1.transpose(1, 2).contiguous()

    # ---- Encoder level 1 (input resolution) ----
    x0 = F.relu(self.conv1(x))  # 24
    x1 = F.relu(self.dense_conv1(x0))  # 24 + 24 * 3 = 96
    x1 = torch.cat((x1, x0), 1)  # 120
    x1d, _, _, point_cloud2 = edge_preserve_sampling(
        x1, point_cloud1, self.hierarchy[0], self.k)  # 240

    # ---- Encoder level 2 (point_cloud2) ----
    x2 = F.relu(self.conv2(x1d))  # 48
    x2 = F.relu(self.dense_conv2(x2))  # 48 + 24 * 3 = 120
    x2 = torch.cat((x2, x1d), 1)  # 120 + 240 = 360
    x2d, _, _, point_cloud3 = edge_preserve_sampling(
        x2, point_cloud2, self.hierarchy[1], self.k)  # 720

    # ---- Encoder level 3 (point_cloud3) ----
    x3 = F.relu(self.conv3(x2d))
    x3 = F.relu(self.dense_conv3(x3))
    x3 = torch.cat((x3, x2d), 1)
    x3d, _, _, point_cloud4 = edge_preserve_sampling(
        x3, point_cloud3, self.hierarchy[2], self.k)

    # ---- Encoder level 4 (point_cloud4, coarsest) ----
    x4 = F.relu(self.conv4(x3d))
    x4 = F.relu(self.dense_conv4(x4))
    x4 = torch.cat((x4, x3d), 1)

    # Global feature: max-pool over the last (point) axis, two FC layers,
    # then tiled hierarchy[2] times so it can be concatenated with x4.
    global_feat = self.gf_conv(x4)
    global_feat, _ = torch.max(global_feat, -1)
    global_feat = F.relu(self.fc1(global_feat))
    global_feat = F.relu(self.fc2(global_feat)).unsqueeze(2).repeat(
        1, 1, self.hierarchy[2])

    # ---- Decoder: level 4 -> level 3 ----
    x4 = torch.cat((global_feat, x4), 1)
    x4 = F.relu(self.conv5(x4))
    idx, weight = three_nn_upsampling(point_cloud3, point_cloud4)
    x4 = pn2.three_interpolate(x4, idx, weight)

    # ---- Decoder: level 3 -> level 2 (skip from encoder x3) ----
    x3 = torch.cat((x3, x4), 1)
    x3 = F.relu(self.conv6(x3))
    idx, weight = three_nn_upsampling(point_cloud2, point_cloud3)
    x3 = pn2.three_interpolate(x3, idx, weight)

    # ---- Decoder: level 2 -> level 1 (skip from encoder x2) ----
    x2 = torch.cat((x2, x3), 1)
    x2 = F.relu(self.conv7(x2))
    idx, weight = three_nn_upsampling(point_cloud1, point_cloud2)
    x2 = pn2.three_interpolate(x2, idx, weight)

    # Final head on level-1 encoder features + upsampled decoder features.
    x1 = torch.cat((x1, x2), 1)
    x1 = self.conv8(x1)
    return x1
def interpolate_func(inputs):
    """Run ``pointnet2_utils.three_interpolate`` with a fixed index/weight pattern.

    Two output points are produced for a single batch element:
    output 0 combines known points 0, 1, 2 with weight 1 each, and
    output 1 combines known points 1, 2, 3 with weight 2 each (the weights
    are deliberately not normalised).

    :param inputs: known-point features -- presumably (1, C, m) with m >= 4,
        since the fixed indices reach point 3 and the batch dimension is 1
    :return: result of ``three_interpolate(inputs, idx, weight)``
    """
    # Build the fixed tensors on the same device as ``inputs``. Using
    # ``.cuda()`` would always pick the current default GPU and break
    # whenever ``inputs`` lives on a different device.
    device = inputs.device
    idx = torch.tensor([[[0, 1, 2], [1, 2, 3]]], dtype=torch.int32, device=device)
    weight = torch.tensor(
        [[[1, 1, 1], [2, 2, 2]]], dtype=torch.float32, device=device)
    interpolated_feats = pointnet2_utils.three_interpolate(
        inputs, idx, weight)
    return interpolated_feats
def forward(self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor) -> torch.Tensor:
    """
    Propagate features from the known points onto the unknown points, then refine them with ``self.mlp``.

    :param unknown: (B, n, 3) tensor of the xyz positions of the unknown (target) points
    :param known: (B, m, 3) tensor of the xyz positions of the known (source) points,
        or None to broadcast ``known_feats`` along the point axis instead
    :param unknow_feats: (B, C1, n) features already attached to the unknown points
        (concatenated after the interpolated ones), or None
    :param known_feats: (B, C2, m) tensor of features to be propagated
    :return: new_features: (B, mlp[-1], n) tensor of features for the unknown points
    """
    if known is None:
        # No source coordinates: expand the feature map to n points.
        lead_dims = known_feats.shape[:2]
        interpolated_feats = known_feats.expand(*lead_dims, unknown.size(1))
    else:
        # Inverse-distance weighting over the 3 nearest known points;
        # the epsilon avoids division by zero for coincident points.
        dist, idx = pointnet2_utils.three_nn(unknown, known)
        inv_dist = 1.0 / (dist + 1e-8)
        weight = inv_dist / torch.sum(inv_dist, dim=2, keepdim=True)
        interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight)

    new_features = (
        interpolated_feats
        if unknow_feats is None
        else torch.cat([interpolated_feats, unknow_feats], dim=1)  # (B, C2 + C1, n)
    )
    # self.mlp is fed a (B, C, n, 1) tensor; the dummy trailing axis is dropped on return.
    return self.mlp(new_features.unsqueeze(-1)).squeeze(-1)
def forward(
    self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor
) -> torch.Tensor:
    r"""
    Propagate features from the known points onto the unknown points and
    refine them with ``self.mlp``.

    Each unknown point receives an inverse-distance-weighted average of the
    features of its 3 nearest known points. If ``known`` is ``None``,
    ``known_feats`` is instead broadcast along the point axis.

    Parameters
    ----------
    unknown : torch.Tensor
        (B, n, 3) tensor of the xyz positions of the unknown features
    known : torch.Tensor or None
        (B, m, 3) tensor of the xyz positions of the known features. If None,
        the last axis of ``known_feats`` must be expandable to n (size 1 or n).
    unknow_feats : torch.Tensor or None
        (B, C1, n) tensor of the features to be propagated to; concatenated
        after the interpolated features when given
    known_feats : torch.Tensor
        (B, C2, m) tensor of features to be propagated

    Returns
    -------
    new_features : torch.Tensor
        (B, mlp[-1], n) tensor of the features of the unknown features
    """
    if known is not None:
        dist, idx = pointnet2_utils.three_nn(unknown, known)
        # Epsilon prevents division by zero for coincident points.
        dist_recip = 1.0 / (dist + 1e-8)
        norm = torch.sum(dist_recip, dim=2, keepdim=True)
        weight = dist_recip / norm
        interpolated_feats = pointnet2_utils.three_interpolate(
            known_feats, idx, weight
        )
    else:
        # No known coordinates: broadcast known_feats to every unknown point.
        interpolated_feats = known_feats.expand(
            *known_feats.size()[0:2], unknown.size(1)
        )

    if unknow_feats is not None:
        new_features = torch.cat([interpolated_feats, unknow_feats], dim=1)  # (B, C2 + C1, n)
    else:
        new_features = interpolated_feats

    # Add a trailing axis for self.mlp (presumably 2-D convs -- confirm), then remove it.
    new_features = new_features.unsqueeze(-1)
    new_features = self.mlp(new_features)
    return new_features.squeeze(-1)
def forward(self, query_xyz: torch.Tensor, support_xyz: torch.Tensor, query_features: torch.Tensor, support_features: torch.Tensor) -> torch.Tensor:
    """
    Interpolate support features onto query points and refine them with ``self.mlp``.

    :param query_xyz: (B, n, 3) xyz positions of the points receiving features
    :param support_xyz: (B, m, 3) xyz positions of the points providing features,
        or None to broadcast ``support_features`` along the point axis instead
    :param query_features: (B, C1, n) features already on the query points
        (concatenated after the interpolated ones), or None
    :param support_features: (B, C2, m) features to be propagated
    :return: new_features: (B, mlp[-1], n) features for the query points
    """
    if support_xyz is None:
        # No support coordinates: expand the feature map to n points.
        target_shape = (*support_features.size()[0:2], query_xyz.size(1))
        interpolated_feats = support_features.expand(*target_shape)
    else:
        # Inverse-distance weighting over the k=3 nearest support points.
        nn_dist, nn_idx = pointnet2_utils.three_nn(query_xyz, support_xyz)  # (B, n, 3) each
        inv_dist = 1.0 / (nn_dist + 1e-8)
        nn_weight = inv_dist / torch.sum(inv_dist, dim=2, keepdim=True)
        interpolated_feats = pointnet2_utils.three_interpolate(
            support_features, nn_idx, nn_weight)  # (B, C2, n)

    if query_features is None:
        fused = interpolated_feats
    else:
        fused = torch.cat([interpolated_feats, query_features], dim=1)  # (B, C2 + C1, n)

    # self.mlp works on (B, C, n, 1) -> (B, mlp[-1], n, 1); drop the dummy axis afterwards.
    return self.mlp(fused.unsqueeze(-1)).squeeze(-1)
def forward(self, x_a: ME.SparseTensor, x_b: ME.SparseTensor):
    """
    Fuse the features of sparse tensor ``x_a`` into sparse tensor ``x_b``.

    If ``self.POINT_TR_LIKE``:
        ``self.linear_a(x_a.F)`` is interpolated onto the coordinates of
        ``x_b`` with inverse-distance weights over the 3 nearest ``x_a``
        points (per batch element), added to ``self.linear_b(x_b.F)``, and
        returned as a new SparseTensor on ``x_b``'s coordinate map / manager.
        Asserts ``x_b.F.shape[1] == self.out_dim``.
        Presumably ``x_a`` is the coarser level (the original docstring noted
        M < N) -- confirm against callers.
    Otherwise:
        if ``self.SUM_FEATURE``: returns ``self.conv_a(x_a) + self.conv_b(x_b)``;
        else ``x_a`` goes through ``conv``/``bn``/``relu``, is concatenated
        with ``x_b``, and passed through ``out_conv``/``out_bn``/``out_relu``.

    TODO: For POINT_TR_LIKE, add support for no x_b is fed, simply upsample the x_a

    Return:
        x: ME.SparseTensor
    """
    if self.POINT_TR_LIKE:
        dim = x_b.F.shape[1]
        assert dim == self.out_dim

        # Scatter x_a's projected features into a dense, zero-padded
        # [B, N_a, dim] buffer. separate_batch presumably splits the sparse
        # coordinates per batch element and returns the flat destination index
        # of each row (TODO confirm); mask_a is not used below.
        # NOTE(review): .cuda() targets the current default GPU, not x_a's device.
        x_ac, mask_a, idx_a = separate_batch(x_a.C)
        B = x_ac.shape[0]
        N_a = x_ac.shape[1]
        x_af = torch.zeros(B * N_a, dim).cuda()
        idx_a = idx_a.reshape(-1, 1).repeat(1, dim)
        x_af.scatter_(dim=0, index=idx_a, src=self.linear_a(x_a.F))
        x_af = x_af.reshape([B, N_a, dim])

        # Same dense layout for x_b's projected features ([B, N_b, dim]); mask_b is unused.
        x_bc, mask_b, idx_b = separate_batch(x_b.C)
        B = x_bc.shape[0]
        N_b = x_bc.shape[1]
        x_bf = torch.zeros(B * N_b, dim).cuda()
        idx_b = idx_b.reshape(-1, 1).repeat(1, dim)
        x_bf.scatter_(dim=0, index=idx_b, src=self.linear_b(x_b.F))
        x_bf = x_bf.reshape([B, N_b, dim])

        # For every x_b point, distances/indices of its 3 nearest x_a points.
        dists, idx = three_nn(x_bc.float(), x_ac.float())

        # Zero out weights of x_b rows whose 3 NN distances sum to 0
        # (presumably padding rows -- confirm).
        mask = (dists.sum(dim=-1) > 0).unsqueeze(-1).repeat(1, 1, 3)

        # Normalised inverse-distance weights; note epsilon here is 1e-1.
        dist_recip = 1.0 / (dists + 1e-1)
        norm = torch.sum(dist_recip, dim=2, keepdim=True)
        weight = dist_recip / norm
        weight = weight * mask  # mask the zeros part

        interpolated_points = three_interpolate(x_af.transpose(1, 2).contiguous(), idx, weight).transpose(1, 2)  # [B, N_b, dim]
        out = interpolated_points + x_bf

        # Gather back from the dense buffer to sparse row order using idx_b.
        # The result is paired with x_b's coordinate map below, so presumably
        # it has one row per x_b point (the original comment said x_a.F).
        out = torch.gather(out.reshape(B * N_b, dim), dim=0, index=idx_b)
        x = ME.SparseTensor(features=out, coordinate_map_key=x_b.coordinate_map_key, coordinate_manager=x_b.coordinate_manager)
    else:
        if self.SUM_FEATURE:
            x_a = self.conv_a(x_a)
            x_b = self.conv_b(x_b)
            x = x_a + x_b
        else:
            x_a = self.conv(x_a)
            x_a = self.bn(x_a)
            x_a = self.relu(x_a)
            # NOTE(review): `me` differs from the `ME` alias used above -- confirm it is imported.
            x = me.cat(x_a, x_b)
            # NOTE(review): recovered from a whitespace-collapsed paste; these
            # output layers are assumed to belong to the concatenation branch.
            x = self.out_conv(x)
            x = self.out_bn(x)
            x = self.out_relu(x)
    return x
def _edge_unpooling(self, features, src_pts, tgt_pts):
    """Interpolate ``features`` defined on ``src_pts`` onto ``tgt_pts``.

    Axis 2 of ``features`` is squeezed away (presumably a singleton axis)
    for ``pn2.three_interpolate`` and re-inserted on the result. Neighbour
    indices and weights come from ``three_nn_upsampling(tgt_pts, src_pts)``.
    """
    flat = features.squeeze(2)
    nn_idx, nn_weight = three_nn_upsampling(tgt_pts, src_pts)
    upsampled = pn2.three_interpolate(flat, nn_idx, nn_weight)
    return upsampled.unsqueeze(2)
def forward(self, xyzs: torch.Tensor, original_xyzs: torch.Tensor, features: torch.Tensor, original_features: torch.Tensor = None) -> "tuple[torch.Tensor, torch.Tensor]":
    r"""
    Transposed point spatio-temporal convolution: carry features from a
    convolved (L', N') sequence back onto the original (L, N) points.

    1. Each input frame's features go through ``self.temporal_conv`` and are
       split along dim 1 into chunks of ``self.mid_planes`` channels.
    2. For every output frame t1, every input frame t2 whose temporal window
       contains t1 contributes its xyz and the chunk at t1's offset inside
       that window (after optional ``batch_norm`` / ``activation``).
    3. Each original point in frame t1 gets a 3-NN inverse-distance-weighted
       interpolation of those gathered features, optionally concatenated
       with its skip feature, and is passed through ``self.spatial_conv``.

    Parameters
    ----------
    xyzs : torch.Tensor
        (B, L', N', 3) tensor of the xyz positions of the convolved features
    original_xyzs : torch.Tensor
        (B, L, N, 3) tensor of the xyz positions of the original points
    features : torch.Tensor
        (B, L', C', N') tensor of the features to be propagated
    original_features : torch.Tensor, optional
        (B, L, C, N) tensor of original point features for skip connection

    Returns
    -------
    new_xyzs : torch.Tensor
        ``original_xyzs`` returned unchanged, (B, L, N, 3)
    new_features : torch.Tensor
        (B, L, C", N) per-frame outputs of ``self.spatial_conv`` stacked on dim 1

    Raises
    ------
    AssertionError
        if ``temporal_kernel_size > 1`` and ``temporal_stride > 1`` and the
        frame counts are inconsistent with the temporal parameters.
    """
    L1 = original_xyzs.size(1)
    N1 = original_xyzs.size(2)  # not used below

    L2 = xyzs.size(1)
    N2 = xyzs.size(2)  # not used below

    if self.temporal_kernel_size > 1 and self.temporal_stride > 1:
        assert ((L2 - 1) * self.temporal_stride + sum(self.temporal_padding) + self.temporal_kernel_size == L1), "PSTConvTranspose: Temporal parameter error!"

    # Split the frame axis (dim 1) into per-frame lists.
    xyzs = torch.split(tensor=xyzs, split_size_or_sections=1, dim=1)
    xyzs = [torch.squeeze(input=xyz, dim=1).contiguous() for xyz in xyzs]

    features = torch.split(tensor=features, split_size_or_sections=1, dim=1)
    features = [torch.squeeze(input=feature, dim=1).contiguous() for feature in features]

    # Keep the unsplit tensor to return as-is.
    new_xyzs = original_xyzs

    original_xyzs = torch.split(tensor=original_xyzs, split_size_or_sections=1, dim=1)
    original_xyzs = [torch.squeeze(input=original_xyz, dim=1).contiguous() for original_xyz in original_xyzs]

    if original_features is not None:
        original_features = torch.split(tensor=original_features, split_size_or_sections=1, dim=1)
        original_features = [torch.squeeze(input=feature, dim=1).contiguous() for feature in original_features]

    # temporal transposed convolution
    # Each entry is a tuple of mid_planes-channel chunks, one per output offset.
    temporal_trans_features = []
    for feature in features:
        feature = self.temporal_conv(feature)
        feature = torch.split(tensor=feature, split_size_or_sections=self.mid_planes, dim=1)
        temporal_trans_features.append(feature)

    # temporal interpolation
    temporal_interpolated_xyzs = []
    temporal_interpolated_features = []

    # For each input frame t2 (1-based), its centre position in output-frame
    # coordinates and the (1-based) output frames its window covers.
    middles = []
    deltas = []
    for t2 in range(1, L2+1):
        middle = t2 + (t2-1)*(self.temporal_stride-1) + self.temporal_radius + self.temporal_padding[0]
        middles.append(middle)
        delta = range(middle - self.temporal_radius, middle + self.temporal_radius + self.temporal_padding[1] + 1)
        deltas.append(delta)

    for t1 in range(1, L1+1):
        seed_xyzs = []
        seed_features = []
        for t2 in range(L2):
            delta = deltas[t2]
            if t1 in delta:
                seed_xyzs.append(xyzs[t2])
                # t1 - (middle - radius) is t1's 0-based offset inside t2's window.
                seed_feature = temporal_trans_features[t2][t1-middles[t2]+self.temporal_radius]
                if self.batch_norm:
                    seed_feature = self.batch_norm(seed_feature)
                if self.activation:
                    seed_feature = self.activation(seed_feature)
                seed_features.append(seed_feature)
        # Pool points from all contributing frames (points on dim 1 for xyz, dim 2 for features).
        # NOTE(review): torch.cat raises if no window covers t1 -- relies on the parameter check above.
        seed_xyzs = torch.cat(seed_xyzs, dim=1)
        seed_features = torch.cat(seed_features, dim=2)
        temporal_interpolated_xyzs.append(seed_xyzs)
        temporal_interpolated_features.append(seed_features)

    # spatial interpolation
    new_features = []
    for t1 in range(L1):
        neighbor_xyz = temporal_interpolated_xyzs[t1]  # [B, N', 3]
        anchor_xyz = original_xyzs[t1]  # [B, N, 3]

        dist, idx = pointnet2_utils.three_nn(anchor_xyz, neighbor_xyz)

        # Normalised inverse-distance weights; epsilon avoids division by zero.
        dist_recip = 1.0 / (dist + 1e-8)
        norm = torch.sum(dist_recip, dim=2, keepdim=True)
        weight = dist_recip / norm

        interpolated_feats = pointnet2_utils.three_interpolate(temporal_interpolated_features[t1], idx, weight)

        if original_features is not None:
            new_feature = torch.cat([interpolated_feats, original_features[t1]], dim=1)
        else:
            new_feature = interpolated_feats

        new_feature = self.spatial_conv(new_feature)
        new_features.append(new_feature)

    new_features = torch.stack(tensors=new_features, dim=1)

    return new_xyzs, new_features