def forward(self, pc_down, pc_up, feat_down, feat_up):
    '''Propagate features from a sparse point set onto a denser one.

    Each dense point receives the inverse-distance weighted average of the
    features of its (up to) 3 nearest sparse points. The interpolated
    features are concatenated with the skip features ``feat_up`` (when
    given) and passed through ``self.mlp``.

    Args:
        pc_down : B x 3 x N_small coordinates of the sparse (source) points.
        pc_up : B x 3 x N_large coordinates of the dense (target) points.
        feat_down : B x C1 x N_small features living on the sparse points.
        feat_up : B x C2 x N_large skip features on the dense points, or
            None when there are no skip features.

    Returns:
        B x mlp[-1] x N_large propagated features.
    '''
    # Never request more neighbours than there are source points.
    k = min(3, pc_down.size(2))
    idx, dist = md_utils._knn_indices(feat=pc_down, k=k, centroid=pc_up, dist=True)  # B x N_large x k
    # Inverse-distance weights, normalised to sum to 1 over the k neighbours
    # (1e-8 guards against division by zero for coincident points).
    dist_recip = 1.0 / (dist + 1e-8)
    weight = dist_recip / torch.sum(dist_recip, dim=2, keepdim=True)  # B x N_large x k
    grouped_feat = md_utils._indices_group(feat_down, idx)  # B x C1 x N_large x k
    # Weighted SUM over neighbours. The weights are normalised, so this is the
    # interpolated value; a max over weighted features is not an interpolation.
    interpolated_feats = torch.sum(grouped_feat * weight.unsqueeze(1), dim=-1)  # B x C1 x N_large
    if feat_up is not None:
        interpolated_feats = torch.cat([interpolated_feats, feat_up], dim=1)  # B x C1+C2 x N_large
    interpolated_feats = self.mlp(interpolated_feats.unsqueeze(-1))
    return interpolated_feats.squeeze(-1)  # B x out_channel x N_large
def forward(self, pc):
    '''Encode a batch of point clouds into a global descriptor.

    Args:
        pc : B x N x C point cloud, where C must equal ``self.inchannel``
            (e.g. 6 with normals, 3 for xyz only). The first three
            channels are treated as xyz.

    Returns:
        B x (2 * total SpiderConv channels) tensor: the two largest
        responses per channel over all points (480 -> 960 per the original
        comments).
    '''
    assert pc.size()[2] == self.inchannel, 'illegal input pc size:{}'.format(pc.size())
    batch, num_pts, _ = pc.size()
    k = self.K_knn

    points = pc.transpose(1, 2)  # B x C x N
    xyz = points[:, :3, :]  # B x 3 x N

    # Shared neighbourhood: KNN on xyz and neighbour offsets relative to each centre.
    knn_idx = md_utils._knn_indices(xyz, k=k)
    neighbours = md_utils._indices_group(xyz, knn_idx)  # B x 3 x N x k
    centres = xyz.unsqueeze(-1).expand(batch, 3, num_pts, k)
    rel_xyz = neighbours - centres

    # Four stacked SpiderConvs; keep every intermediate output for the skip concat.
    stage_outputs = []
    current = points
    for layer in (self.spiderconv1, self.spiderconv2, self.spiderconv3, self.spiderconv4):
        current = layer(current, knn_idx, rel_xyz)
        stage_outputs.append(current)

    stacked = torch.cat(stage_outputs, dim=1)  # B x sum(C_i) x N
    top_two = stacked.topk(2, dim=2)[0]  # B x sum(C_i) x 2
    return top_two.view(batch, -1)
def forward(self, pts, fts):
    '''Run one X-Conv (PointCNN) layer.

    Args:
        pts : B x 3 x N input coordinates.
        fts : B x C_in x N input features, or None in the first layer.

    Returns:
        (qrs, features): qrs is the B x 3 x P set of representative points.
        features is B x C_out x P; when ``self.with_global`` is set, the
        ``self.conv_global`` output on qrs is concatenated in front along
        the channel axis.

    NOTE(review): X_transform0 is documented as returning B x K*K x P x 1.
    X_transform1/2 are assumed to return either that layout or
    B x K x P x K -- confirm against __init__. Checkpoints trained with the
    previous reshape, which mixed values across points, are not numerically
    compatible with this one.
    '''
    assert pts.size()[1] == 3, 'illegal pointcloud size:{}'.format(pts.size())
    B, _, N = pts.size()
    K, D = self.K, self.D

    if self.P == -1:
        # P == -1 means "keep every input point" (input layer, or a layer
        # that does not down-sample). Use a local P instead of assigning
        # self.P = N. That assignment sent every later call into the sampling
        # branch below and pinned the layer to the first batch's point count.
        P = N
        qrs = pts
    else:
        P = self.P
        sample_indices = self.sample(pts.permute(0, 2, 1).contiguous())
        sample_indices = sample_indices.unsqueeze(1).expand(B, 3, P)
        qrs = torch.gather(pts, 2, sample_indices)  # B x 3 x P

    def to_kpk(t):
        # Rearrange per-point K*K values (B x K*K x P x 1) into B x K x P x K.
        # A bare .view(B, K, P, K) on that layout interleaves values from
        # different points, so split the K*K axis first and then move the
        # point axis into place: out[b, i, p, j] = t[b, i*K + j, p].
        if t.size(1) == K * K and t.size(-1) == 1:
            t = t.reshape(B, K, K, P).permute(0, 1, 3, 2)
        return t.contiguous().view(B, K, P, K)

    def to_point_matrices(t):
        # Build one K x K matrix per representative point: (B*P) x K x K,
        # with matrix[i, j] = t[b, i*K + j, p] (same convention as to_kpk).
        if t.size(1) == K * K and t.size(-1) == 1:
            t = t.reshape(B, K * K, P).permute(0, 2, 1)  # B x P x K*K
        else:  # B x K x P x K
            t = t.permute(0, 2, 1, 3)  # B x P x K x K
        return t.contiguous().view(B * P, K, K)

    # Dilated KNN: take the K*D nearest neighbours and keep every D-th one.
    indices_dilated = md_utils._knn_indices(pts, k=K * D, centroid=qrs)  # B x P x K*D
    indices = indices_dilated[:, :, ::D]  # B x P x K
    nn_pts = md_utils._indices_group(pts, indices)  # B x 3 x P x K
    nn_pts_local = nn_pts - qrs.unsqueeze(-1).expand_as(nn_pts)  # B x 3 x P x K

    nn_fts_from_pts = self.mlp_delta(nn_pts_local)  # B x C_delta x P x K
    if fts is None:  # first layer: only the lifted coordinates are available
        nn_fts_input = nn_fts_from_pts
    else:
        nn_fts_input = md_utils._indices_group(fts, indices)
        nn_fts_input = torch.cat([nn_fts_from_pts, nn_fts_input], dim=1)  # B x C_delta+C_in x P x K

    # Learn a K x K transform per representative point from its local coordinates.
    X = to_kpk(self.X_transform0(nn_pts_local))  # B x K x P x K
    X = to_kpk(self.X_transform1(X))  # B x K x P x K
    X = to_point_matrices(self.X_transform2(X))  # B*P x K x K

    fts_X = torch.bmm(nn_fts_input.permute(0, 2, 1, 3).contiguous().view(B * P, -1, K), X)  # B*P x C x K
    fts_X = fts_X.view(B, P, -1, K).permute(0, 2, 1, 3)  # B x C_delta+C_in x P x K
    fts_conv = self.conv(fts_X).squeeze(-1)  # B x C_out x P
    if self.with_global:
        fts_global = self.conv_global(qrs.unsqueeze(-1)).squeeze(-1)  # B x C_out//4 x P
        return qrs, torch.cat([fts_global, fts_conv], dim=1)
    return qrs, fts_conv  # B x C_out x P
def forward(self, pc, feat):
    '''Down-sample with FPS and aggregate multi-scale ball-query neighbourhoods.

    Args:
        pc : B x 3 x N input coordinates.
        feat : B x C x N per-point features, or None in the first layer.

    Returns:
        pc_sample : B x 3 x npoint coordinates of the sampled centroids.
        cat_feat : B x sum(mlp[-1]) x npoint, one max-pooled block per scale,
            concatenated along channels.
    '''
    batch, _, _ = pc.size()

    # Pick the centroids and gather their coordinates.
    fps_idx = self.fps(pc.transpose(1, 2).contiguous())  # B x npoint
    gather_idx = fps_idx.unsqueeze(1).expand(batch, 3, self.npoint)
    pc_sample = torch.gather(pc, 2, gather_idx)  # B x 3 x npoint

    query_src = pc.contiguous()
    query_centres = pc_sample.contiguous()
    centre_col = pc_sample.unsqueeze(-1)  # B x 3 x npoint x 1

    per_scale = []
    for scale, mlp in enumerate(self.mlp_layers):
        nbr_idx, _ = self.query_ball_point[scale](query_src, query_centres)
        local_xyz = md_utils._indices_group(pc, nbr_idx)  # B x 3 x npoint x nsample
        local_xyz = local_xyz - centre_col.expand_as(local_xyz)
        if feat is None:
            # First layer: only relative coordinates are available.
            grouped = local_xyz.contiguous()
        else:
            nbr_feat = md_utils._indices_group(feat, nbr_idx)  # B x C x npoint x nsample
            grouped = torch.cat([local_xyz, nbr_feat], dim=1)  # B x C+3 x npoint x nsample
        per_scale.append(mlp(grouped).max(dim=-1)[0])  # B x C_out x npoint

    return pc_sample, torch.cat(per_scale, dim=1)
def forward(self, feat, idx, group_pc):
    '''Apply one SpiderConv: a Taylor-polynomial filter on neighbour offsets.

    Args:
        feat : B x in_channel x N input features.
        idx : B x N x k neighbour indices; k must equal ``self.K_knn``.
        group_pc : B x 3 x N x k neighbour coordinates relative to each centre.

    Returns:
        B x out_channel x N features (``self.conv2`` output with its trailing
        size-1 axis squeezed).
    '''
    batch, _, num_pts = feat.size()
    _, _, k = idx.size()
    assert k == self.K_knn, 'illegal k'

    neighbour_feat = md_utils._indices_group(feat, idx)  # B x in_channel x N x k

    # Monomials of the relative offsets up to degree 3 (no constant term),
    # each B x 1 x N x k. The multiplication grouping matches the reference
    # implementation exactly.
    x, y, z = (group_pc[:, axis, :, :].unsqueeze(1) for axis in range(3))
    squares = [v ** 2 for v in (x, y, z)]
    cubes = [sq * v for sq, v in zip(squares, (x, y, z))]
    xy, xz, yz = x * y, x * z, y * z
    third_order = [x * xy, x * xz, y * yz, y * xy, z * xz, z * yz, xy * z]
    basis = torch.cat([x, y, z, *squares, *cubes, xy, xz, yz, *third_order], dim=1)  # B x 19 x N x k

    taylor = self.conv1(basis).unsqueeze(1)  # B x 1 x taylor_channel x N x k
    neighbour_feat = neighbour_feat.unsqueeze(2)  # B x in_channel x 1 x N x k
    weighted = (neighbour_feat * taylor).view(
        batch, self.in_channel * self.taylor_channel, num_pts, k)
    return self.conv2(weighted).squeeze(-1)  # B x out_channel x N
group_feat = torch.mul(group_feat, taylor).view( B, self.in_channel * self.taylor_channel, N, k) group_feat = self.conv2(group_feat) # B x out_channel x N x 1 group_feat = group_feat.squeeze(-1) return group_feat if __name__ == '__main__': in_channel = 3 out_channel = 6 taylor_channel = 9 k = 3 batch_size = 3 num_points = 10 model = _BaseSpiderConv(in_channel, out_channel, taylor_channel, batch_size, num_points, k) pc = torch.randn(batch_size, 3, num_points) feat = torch.randn(batch_size, in_channel, num_points) idx = md_utils._knn_indices(pc, k) group_pc = md_utils._indices_group(pc, idx) pc = pc.unsqueeze(-1).expand(batch_size, 3, num_points, k) group_pc = group_pc - pc output = model(feat, idx, group_pc) print(output.size())