def forward(self, xyz, points):
    """Sample centroids, group their k nearest neighbours and aggregate
    relation-weighted features over each neighbourhood.

    Input:
        xyz: input points position data, [B, C, N]
        points: input points data, [B, D, N]
    Return:
        new_xyz: sampled points position data, [B, C, S]
        x: pooled neighbourhood features, [B, D', S]

    Shape annotations here and in the inline comments are carried over
    from the original code (C == 3, S == self.npoint); they are not
    validated at runtime.
    """
    # Unpacking also enforces that xyz is a rank-3 tensor.
    B, _, _ = xyz.shape
    xyz_t = xyz.permute(0, 2, 1).contiguous()  # [B, N, C]

    # Pick self.npoint centroids and gather their coordinates.
    fps_idx = pointutils.furthest_point_sample(xyz_t, self.npoint)  # [B, npoint]
    new_xyz = pointutils.gather_operation(xyz, fps_idx)  # [B, 3, npoint]
    new_xyz_t = new_xyz.permute(0, 2, 1).contiguous()

    # Neighbour indices of every centroid within the full cloud.
    _, idx = pointutils.knn(self.nsample, new_xyz_t, xyz_t)  # [B, npoint, nsample]
    neighbors = pointutils.grouping_operation(xyz, idx)  # [B, 3, npoint, nsample]

    # expand() yields a broadcast view instead of the copy repeat() makes;
    # the arithmetic and torch.cat below produce the same values.
    centers = new_xyz.view(B, -1, self.npoint, 1).expand(
        -1, -1, -1, self.nsample)  # [B, 3, npoint, nsample]

    pos_diff = centers - neighbors  # [B, 3, npoint, nsample]
    distances = torch.norm(pos_diff, p=2, dim=1,
                           keepdim=True)  # [B, 1, npoint, nsample]

    # Geometric relation vector between each centroid and its neighbour.
    h_xi_xj = torch.cat([distances, pos_diff, centers, neighbors],
                        dim=1)  # [B, 1+3+3+3, npoint, nsample]

    x = pointutils.grouping_operation(points, idx)  # [B, D, npoint, nsample]
    x = torch.cat([neighbors, x], dim=1)  # [B, D+3, npoint, nsample]

    # Map the relation vector to per-channel weights.
    h_xi_xj = self.mapping_func2(
        F.relu(self.bn_mapping(
            self.mapping_func1(h_xi_xj))))  # [B, c_in, npoint, nsample]

    if self.first_layer:
        x = F.relu(self.bn_xyz_raising(
            self.xyz_raising(x)))  # [B, c_in, npoint, nsample]

    # Element-wise weighting of the grouped features by the relation weights.
    x = F.relu(self.bn_rsconv(torch.mul(h_xi_xj, x)))  # [B, c_in, npoint, nsample]

    for conv, bn in zip(self.mlp_convs, self.mlp_bns):
        x = F.relu(bn(conv(x)))  # [B, c_out, npoint, nsample]

    # Max-pool over the neighbourhood dimension.
    x = torch.max(x, -1)[0]  # [B, c_out, npoint]
    return new_xyz, x
def forward(self, pos1, pos2, feature1, feature2):
    """Propagate features from pos2 (fewer points) onto pos1 (more points).

    Inputs (shapes as annotated by the original author, not checked here):
        pos1: (batch_size, 3, npoint1)
        pos2: (batch_size, 3, npoint2)
        feature1: (batch_size, channel1, npoint1) features for the pos1
            points (earlier layers, more points); may be None, in which
            case no skip concatenation is performed
        feature2: (batch_size, channel1, npoint2) features for the pos2
            points
    Output:
        Channel-first features for every pos1 point, i.e. the result of
        the mlp1 stack, max-pooled over neighbours, optionally
        concatenated with feature1, then passed through the mlp2 stack.

    TODO: Add support for skip links. Study how delta(XYZ) plays a role
    in feature updating.
    """
    dense_t = pos1.permute(0, 2, 1).contiguous()
    sparse_t = pos2.permute(0, 2, 1).contiguous()
    batch, _, n_dense = pos1.shape

    # For every pos1 point, look up its neighbours among the pos2 points.
    if not self.knn:
        nbr_idx = pointutils.ball_query(self.radius, self.nsample,
                                        sparse_t, dense_t)
    else:
        _, nbr_idx = pointutils.knn(self.nsample, dense_t,
                                    sparse_t)  # [B, N1, S]

    grouped_pos = pointutils.grouping_operation(pos2, nbr_idx)
    # Offset of each neighbour relative to its query point.
    offsets = grouped_pos - pos1.view(batch, -1, n_dense, 1)  # [B, 3, N1, S]
    grouped_feat = pointutils.grouping_operation(feature2, nbr_idx)

    out = torch.cat((grouped_feat, offsets), dim=1)  # [B, C1+3, N1, S]
    for layer in self.mlp1_convs:
        out = layer(out)

    # Max pooling over the neighbour dimension.
    out = torch.max(out, -1)[0]  # [B, mlp1[-1], N1]

    # Concatenate the feature from the earlier layer when one is supplied.
    if feature1 is not None:
        out = torch.cat((out, feature1),
                        dim=1)  # [B, mlp1[-1]+feat1_channel, N1]

    for layer in self.mlp2_convs:
        out = layer(out)
    return out
def forward(self, pos1, pos2, feature1, feature2):
    """Embed, for every pos1 point, the relation to its pos2 neighbours.

    Input (shapes as annotated by the original author, not checked here):
        pos1: (batch_size, 3, npoint)
        pos2: (batch_size, 3, npoint)
        feature1: (batch_size, channel, npoint)
        feature2: (batch_size, channel, npoint)
    Output:
        pos1: (batch_size, 3, npoint), returned unchanged
        feat1_new: (batch_size, mlp[-1], npoint)
    Raises:
        ValueError: if self.corr_func is anything other than 'concat',
            the only correlation function implemented here.
    """
    # Fail fast: previously an unsupported corr_func only surfaced later
    # as an UnboundLocalError on `feat_diff`.
    if self.corr_func != 'concat':
        raise ValueError(
            f"Unsupported corr_func {self.corr_func!r}; "
            "only 'concat' is implemented")

    pos1_t = pos1.permute(0, 2, 1).contiguous()
    pos2_t = pos2.permute(0, 2, 1).contiguous()
    B, N, _ = pos1_t.shape

    # For every pos1 point, look up its neighbours among the pos2 points.
    if self.knn:
        _, idx = pointutils.knn(self.nsample, pos1_t, pos2_t)  # [B, N, S]
    else:
        idx = pointutils.ball_query(self.radius, self.nsample, pos2_t,
                                    pos1_t)

    pos2_grouped = pointutils.grouping_operation(pos2, idx)  # [B, 3, N, S]
    pos_diff = pos2_grouped - pos1.view(B, -1, N, 1)  # [B, 3, N, S]
    feat2_grouped = pointutils.grouping_operation(feature2,
                                                  idx)  # [B, C, N, S]

    # Pair each neighbour feature with the query point's own feature.
    # expand() is a broadcast view; torch.cat materialises the result, so
    # the values match the former repeat()-based version.
    feat1_tiled = feature1.view(B, -1, N, 1).expand(-1, -1, -1, self.nsample)
    feat_diff = torch.cat([feat2_grouped, feat1_tiled],
                          dim=1)  # [B, 2*C, N, S]

    feat1_new = torch.cat([pos_diff, feat_diff], dim=1)  # [B, 2*C+3, N, S]
    for conv, bn in zip(self.mlp_convs, self.mlp_bns):
        feat1_new = F.relu(bn(conv(feat1_new)))

    # Max-pool over the neighbour dimension.
    feat1_new = torch.max(feat1_new, -1)[0]  # [B, mlp[-1], npoint]
    return pos1, feat1_new