def forward(self, x, sn, node, node_knn_I, is_train=False, epoch=None):
    '''
    Encode a point cloud into a single global feature vector.

    Points are soft-assigned to SOM nodes (top-k), de-centered by their
    recomputed cluster means, pushed through the first PointNet, masked-max
    pooled per node, optionally mixed across neighboring SOM nodes by the
    KNN layer, and finally max-pooled over nodes.

    :param x: Bx3xN Tensor, input point coordinates
    :param sn: Bx3xN Tensor, per-point surface normals
    :param node: Bx3xM FloatTensor, SOM node coordinates
    :param node_knn_I: BxMxk_som LongTensor, indices of each node's k_som SOM neighbors
    :param is_train: determine whether to add noise in KNNModule
    :param epoch: forwarded to the sub-networks (presumably for schedules) — TODO confirm
    :return: BxC feature Tensor (C = self.opt.feature_num per the Bx1024xM comments below)
    '''
    # optimize the som, access the Tensor's tensor, the optimize function should not modify the tensor
    # self.som_builder.optimize(x.data)
    # self.som_builder.node.resize_(node.size()).copy_(node)

    # Assign each point to its top-k nearest SOM nodes; x.data so the
    # assignment itself carries no gradient.
    mask, mask_row_max, min_idx = som.query_topk(
        node, x.data, node.size()[2],
        k=self.opt.k)  # BxkNxnode_num, Bxnode_num, BxkN
    mask_row_sum = torch.sum(mask, dim=1)  # Bxnode_num — points per node
    mask = mask.unsqueeze(1)  # Bx1xkNxnode_num

    # Stack x/sn k times along the point dimension so each of the k
    # assignments of a point gets its own copy (kN points total).
    x_list, sn_list = [], []
    for i in range(self.opt.k):
        x_list.append(x)
        sn_list.append(sn)
    x_stack = torch.cat(tuple(x_list), dim=2)  # Bx3xkN
    sn_stack = torch.cat(tuple(sn_list), dim=2)  # Bx3xkN

    # Re-compute each node's center as the mean of its assigned points,
    # instead of using som.node directly. 1e-5 guards empty nodes.
    x_stack_data_unsqueeze = x_stack.data.unsqueeze(3)  # BxCxkNx1
    x_stack_data_masked = x_stack_data_unsqueeze * mask.float()  # BxCxkNxnode_num
    cluster_mean = torch.sum(x_stack_data_masked, dim=2) / (
        mask_row_sum.unsqueeze(1).float() + 1e-5)  # BxCxnode_num
    som_node_cluster_mean = cluster_mean

    # Assign each point its center coordinate (sum over the one-hot mask).
    node_expanded = som_node_cluster_mean.unsqueeze(
        2)  # BxCx1xnode_num, som.node is BxCxnode_num
    centers = torch.sum(mask.float() * node_expanded,
                        dim=3).detach()  # BxCxkN

    # De-center: local coordinates relative to the owning node. detach()
    # keeps the centering step out of the autograd graph.
    x_decentered = (x_stack - centers).detach()  # Bx3xkN
    x_augmented = torch.cat((x_decentered, sn_stack), dim=1)  # Bx6xkN

    # go through the first PointNet
    if self.opt.surface_normal == True:
        first_pn_out = self.first_pointnet(x_augmented, epoch)
    else:
        first_pn_out = self.first_pointnet(x_decentered, epoch)

    # Per-node max pooling via a custom CUDA argmax kernel.
    # gather_index = self.masked_max.compute(first_pn_out, min_idx, mask).detach()
    M = node.size()[2]
    with torch.cuda.device(first_pn_out.get_device()):
        gather_index = index_max.forward_cuda(first_pn_out.detach(),
                                              min_idx.int(),
                                              M).detach().long()
    # NOTE(review): multiplying the INDEX by mask_row_max zeroes the index
    # for empty nodes, so they gather the feature at position 0 rather than
    # a zero feature — confirm this is intended (a sibling forward() in this
    # file multiplies the gathered OUTPUT instead).
    first_pn_out_masked_max = first_pn_out.gather(
        dim=2,
        index=gather_index * mask_row_max.unsqueeze(1).long())  # BxCxM

    if self.opt.som_k >= 2:
        # second pointnet, knn search on SOM nodes: ----------------------------------
        knn_center_1, knn_feature_1 = self.knnlayer(
            som_node_cluster_mean, first_pn_out_masked_max, node_knn_I,
            self.opt.som_k, self.opt.som_k_type, epoch)

        # final pointnet --------------------------------------------------------------
        final_pn_out = self.final_pointnet(
            torch.cat((knn_center_1, knn_feature_1), dim=1),
            epoch)  # Bx1024xM
    else:
        # final pointnet --------------------------------------------------------------
        final_pn_out = self.final_pointnet(
            torch.cat((som_node_cluster_mean, first_pn_out_masked_max),
                      dim=1),
            epoch)  # Bx1024xM

    # Global max pooling over the M nodes -> one feature vector per cloud.
    feature, _ = torch.max(final_pn_out, dim=2, keepdim=False)

    return feature
def forward(self, x, sn, node, node_knn_I, is_train=False, epoch=None):
    '''
    Rotation-equivariant encoding of a point cloud into a global feature.

    The pipeline mirrors the plain encoder, but the stacked points, normals
    and SOM nodes are replicated into R rotated copies (R =
    self.opt.rot_equivariant_no), processed as a B*R batch, and max-pooled
    over the rotation dimension at the end (and optionally per hierarchy
    level when rot_equivariant_pooling_mode == 'per-hierarchy').

    :param x: Bx3xN Tensor, input point coordinates
    :param sn: Bx3xN Tensor, per-point surface normals
    :param node: Bx3xM FloatTensor, SOM node coordinates
    :param node_knn_I: BxMxk_som LongTensor, indices of each node's k_som SOM neighbors
    :param is_train: determine whether to add noise in KNNModule
    :param epoch: forwarded to the sub-networks (presumably for schedules) — TODO confirm
    :return: BxC feature Tensor (C = self.opt.feature_num)
    '''
    device = x.device
    # optimize the som, access the Tensor's tensor, the optimize function should not modify the tensor
    # self.som_builder.optimize(x.data)
    # self.som_builder.node.resize_(node.size()).copy_(node)

    # Assign each point to its top-k nearest SOM nodes (no gradient through
    # the assignment — x.data is used).
    mask, mask_row_max, min_idx = som.query_topk(
        node, x.data, node.size()[2],
        k=self.opt.k)  # BxkNxnode_num, Bxnode_num, BxkN
    mask_row_sum = torch.sum(mask, dim=1)  # Bxnode_num
    mask = mask.unsqueeze(1)  # Bx1xkNxnode_num

    # Stack x/sn k times so each of a point's k assignments has its own copy.
    x_list, sn_list = [], []
    for i in range(self.opt.k):
        x_list.append(x)
        sn_list.append(sn)
    x_stack = torch.cat(tuple(x_list), dim=2)  # Bx3xkN
    sn_stack = torch.cat(tuple(sn_list), dim=2)  # Bx3xkN

    # Re-compute node centers as masked means of assigned points; 1e-5
    # guards division by zero for empty nodes.
    x_stack_data_unsqueeze = x_stack.data.unsqueeze(3)  # BxCxkNx1
    x_stack_data_masked = x_stack_data_unsqueeze * mask.float()  # BxCxkNxnode_num
    cluster_mean = torch.sum(x_stack_data_masked, dim=2) / (
        mask_row_sum.unsqueeze(1).float() + 1e-5)  # BxCxnode_num
    som_node_cluster_mean = cluster_mean

    # ====== rotate the pc, sn & som_node into R number of rotated versions ======
    B, R, N, kN, M = x_stack.size()[0], \
        self.opt.rot_equivariant_no, \
        x.size()[2], x_stack.size()[2], \
        node.size()[2]
    rotation_matrix = self.rotation_matrix_template.to(device).expand(
        B, R, 3, 3).detach()  # 1xRx3x3 -> BxRx3x3
    x_stack_rot = torch.matmul(
        rotation_matrix,
        x_stack.unsqueeze(1).expand(B, R, 3, kN))  # BxRx3x3 * BxRx3xkN -> BxRx3xkN
    sn_stack_rot = torch.matmul(
        rotation_matrix,
        sn_stack.unsqueeze(1).expand(B, R, 3, kN))  # BxRx3xkN
    som_node_rot = torch.matmul(
        rotation_matrix,
        som_node_cluster_mean.unsqueeze(1).expand(B, R, 3, M))  # BxRx3xM
    node_knn_I_rot = node_knn_I.unsqueeze(1).expand(
        B, R, M, self.opt.som_k).contiguous()  # BxRxMxsom_k
    mask_rot = mask.unsqueeze(1).expand(B, R, 1, kN, M).contiguous()
    min_idx_rot = min_idx.unsqueeze(1).expand(B, R, kN).contiguous()
    mask_row_max_rot = mask_row_max.unsqueeze(1).expand(B, R, M).contiguous()
    # ====== rotate the pc, sn & som_node into R number of rotated versions ======

    # Assign each (rotated) point its center coordinate.
    node_rot_expanded = som_node_rot.unsqueeze(
        3)  # BxRx3x1xM, som_node_rot is BxRx3xM
    # mask: Bx1xkNxM -> BxRx1xkNxM, centers_rot: BxRx3xkN
    centers_rot = torch.sum(mask_rot.float() * node_rot_expanded,
                            dim=4).detach()  # BxRx3xkN

    x_decentered_rot = (x_stack_rot - centers_rot).detach()  # BxRx3xkN
    x_augmented_rot = torch.cat((x_decentered_rot, sn_stack_rot),
                                dim=2)  # BxRx6xkN

    # go through the first PointNet, folding rotations into the batch dim
    if self.opt.surface_normal == True:
        first_pn_out_rot = self.first_pointnet(
            x_augmented_rot.contiguous().view(B * R, 6, kN).contiguous(),
            epoch)
    else:
        # BUG FIX: x_decentered_rot is BxRx3xkN (3 channels, no normals),
        # so it must be viewed as (B*R, 3, kN). The previous view(B*R, 6, kN)
        # had the wrong element count and raised a RuntimeError whenever
        # opt.surface_normal was False.
        first_pn_out_rot = self.first_pointnet(
            x_decentered_rot.contiguous().view(B * R, 3, kN).contiguous(),
            epoch)
    C = first_pn_out_rot.size()[1]

    # Reshape the assignment tensors to the folded B*R batch.
    min_idx_rot = min_idx_rot.contiguous().view(
        B * R, kN).contiguous()  # BxRxkN -> BRxkN
    mask_rot = mask_rot.contiguous().view(
        B * R, 1, kN, M).contiguous()  # BxRx1xkNxM -> BRx1xkNxM
    mask_row_max_rot = mask_row_max_rot.contiguous().view(
        B * R, M).contiguous().unsqueeze(1).long()  # BRx1xM

    # Per-node masked max pooling via the custom CUDA argmax kernel.
    # first_gather_index_rot = self.masked_max.compute(first_pn_out_rot,
    #                                                  min_idx_rot,
    #                                                  mask_rot).detach()
    with torch.cuda.device(first_pn_out_rot.get_device()):
        first_gather_index_rot = index_max.forward_cuda(
            first_pn_out_rot.detach(), min_idx_rot.int(),
            M).detach().long()
    first_pn_out_masked_max_rot = first_pn_out_rot.gather(
        dim=2, index=first_gather_index_rot * mask_row_max_rot)  # BRxCxM

    # scatter the masked_max back to the kN points
    scattered_first_masked_max = torch.gather(
        first_pn_out_masked_max_rot,
        dim=2,
        index=min_idx_rot.unsqueeze(1).expand(B * R,
                                              first_pn_out_rot.size()[1],
                                              kN))  # BRxCxkN
    first_pn_out_fusion = torch.cat(
        (first_pn_out_rot, scattered_first_masked_max), dim=1)  # BRx2CxkN
    second_pn_out = self.second_pointnet(first_pn_out_fusion, epoch)

    # second_gather_index_rot = self.masked_max.compute(second_pn_out,
    #                                                   min_idx_rot,
    #                                                   mask_rot).detach()  # BRxCxM
    with torch.cuda.device(second_pn_out.get_device()):
        second_gather_index_rot = index_max.forward_cuda(
            second_pn_out.detach(), min_idx_rot.int(), M).detach().long()
    second_pn_out_masked_max_rot = second_pn_out.gather(
        dim=2, index=second_gather_index_rot * mask_row_max_rot)  # BRxCxM

    if self.opt.rot_equivariant_pooling_mode == 'per-hierarchy':
        # Pool node features over rotations at this hierarchy level, then
        # broadcast the pooled result back to all R rotated branches.
        second_pn_out_masked_max_rot = second_pn_out_masked_max_rot.contiguous(
        ).view(B, R, C, M).contiguous()  # BxRxCxM
        second_pn_out_masked_max_rot, _ = torch.max(
            second_pn_out_masked_max_rot, dim=1,
            keepdim=True)  # BxRxCxM -> Bx1xCxM
        second_pn_out_masked_max_rot = second_pn_out_masked_max_rot.expand(
            B, R, C, M).contiguous()  # Bx1xCxM -> BxRxCxM
        second_pn_out_masked_max_rot = second_pn_out_masked_max_rot.contiguous(
        ).view(B * R, C, M).contiguous()  # BRxCxM

    if self.opt.som_k >= 2:
        # second pointnet, knn search on SOM nodes: ----------------------------------
        knn_center_1_rot, knn_feature_1_rot = self.knnlayer(
            som_node_rot.contiguous().view(B * R, 3, M).contiguous(),
            second_pn_out_masked_max_rot,
            node_knn_I_rot.contiguous().view(B * R, M,
                                             self.opt.som_k).contiguous(),
            self.opt.som_k, self.opt.som_k_type, epoch)
        C2 = knn_feature_1_rot.size()[1]

        # final pointnet --------------------------------------------------------------
        if self.opt.rot_equivariant_pooling_mode == 'per-hierarchy':
            knn_feature_1_rot = knn_feature_1_rot.contiguous().view(
                B, R, C2, M).contiguous()  # B*RxC2xM -> BxRxC2xM
            knn_feature_1_rot, _ = torch.max(knn_feature_1_rot,
                                             dim=1,
                                             keepdim=True)  # Bx1xC2xM
            knn_feature_1_rot = knn_feature_1_rot.expand(
                B, R, C2, M).contiguous()  # Bx1xC2xM -> BxRxC2xM
            knn_feature_1_rot = knn_feature_1_rot.contiguous().view(
                B * R, C2, M).contiguous()
        final_pn_out_rot = self.final_pointnet(
            torch.cat((knn_center_1_rot, knn_feature_1_rot), dim=1),
            epoch)  # BRx1024xM
    else:
        # final pointnet --------------------------------------------------------------
        final_pn_out_rot = self.final_pointnet(
            torch.cat((som_node_rot.contiguous().view(
                B * R, 3, M).contiguous(), second_pn_out_masked_max_rot),
                      dim=1),
            epoch)  # BRx1024xM

    # final_pn_out_rot: BRx1024xM -> max over nodes, then over rotations.
    final_pn_out_rot = final_pn_out_rot.contiguous().view(
        B, R, self.opt.feature_num, M).contiguous()
    feature_rot, _ = torch.max(final_pn_out_rot, dim=3,
                               keepdim=False)  # BxRxC
    feature, _ = torch.max(feature_rot, dim=1, keepdim=False)  # BxC

    # # debug using vanilla pointnet
    # pn_out = self.pn(x)  # BxCxN
    # feature, _ = torch.max(pn_out, dim=2, keepdim=False)

    return feature
def forward(self, x, sn, node, is_train=False, epoch=None):
    '''
    Detect per-node keypoints with uncertainty.

    Points are soft-assigned to SOM nodes, de-centered, encoded by two
    PointNets with per-node masked max pooling, mixed across nearby nodes
    by a KNN layer, and finally mapped by MLPs to a 3D keypoint offset and
    a positive uncertainty sigma for each node.

    :param x: Bx3xN Tensor, input point coordinates
    :param sn: Bx3xN Tensor, per-point surface normals
    :param node: Bx3xM FloatTensor, SOM node coordinates
    :param is_train: determine whether to add noise in KNNModule
    :param epoch: forwarded to the sub-networks (presumably for schedules) — TODO confirm
    :return: (som_node_cluster_mean Bx3xM, keypoints Bx3xM, sigmas BxM, descriptors None)
    '''
    # modify the x according to the nodes, minus the center
    mask, mask_row_max, min_idx = som.query_topk(
        node, x, node.size()[2],
        k=self.opt.k)  # BxkNxnode_num, Bxnode_num, BxkN
    mask_row_sum = torch.sum(mask, dim=1)  # Bxnode_num — points per node
    mask = mask.unsqueeze(1)  # Bx1xkNxnode_num

    # if necessary, stack the x: k copies so each of a point's k
    # assignments has its own entry (kN points total)
    x_stack = x.repeat(1, 1, self.opt.k)
    sn_stack = sn.repeat(1, 1, self.opt.k)

    # Re-compute node centers as the masked mean of assigned points;
    # 1e-5 guards empty nodes. Note the denominator is detached, so the
    # mean keeps gradient only through the numerator.
    x_stack_data_unsqueeze = x_stack.unsqueeze(3)  # BxCxkNx1
    x_stack_data_masked = x_stack_data_unsqueeze * mask.float()  # BxCxkNxnode_num
    cluster_mean = torch.sum(x_stack_data_masked, dim=2) / (
        mask_row_sum.unsqueeze(1).float() + 1e-5).detach()  # BxCxnode_num
    # cluster_mean = node
    som_node_cluster_mean = cluster_mean

    B, N, kN, M = x.size()[0], x.size()[2], x_stack.size(
    )[2], som_node_cluster_mean.size()[2]

    # assign each point with a center
    node_expanded = som_node_cluster_mean.unsqueeze(
        2)  # BxCx1xnode_num, som.node is BxCxnode_num
    centers = torch.sum(mask.float() * node_expanded,
                        dim=3).detach()  # BxCxkN

    # De-center: local coordinates relative to the owning node (detached).
    x_decentered = (x_stack - centers).detach()  # Bx3xkN
    x_augmented = torch.cat((x_decentered, sn_stack), dim=1)  # Bx6xkN

    # go through the first PointNet
    if self.opt.surface_normal_len >= 1:
        first_pn_out = self.first_pointnet(x_augmented, epoch)
    else:
        first_pn_out = self.first_pointnet(x_decentered, epoch)

    # Per-node masked max pooling via a custom CUDA argmax kernel;
    # mask_row_max zeroes the pooled feature of empty nodes.
    # first_gather_index = self.masked_max.compute(first_pn_out, min_idx, mask).detach()  # BxCxM
    with torch.cuda.device(first_pn_out.get_device()):
        first_gather_index = index_max.forward_cuda_shared_mem(
            first_pn_out.detach(), min_idx.int(), M).detach().long()
    first_pn_out_masked_max = first_pn_out.gather(
        dim=2, index=first_gather_index) * mask_row_max.unsqueeze(
            1).float()  # BxCxM

    # scatter the masked_max back to the kN points
    scattered_first_masked_max = torch.gather(
        first_pn_out_masked_max,
        dim=2,
        index=min_idx.unsqueeze(1).expand(B,
                                          first_pn_out.size()[1],
                                          kN))  # BxCxkN
    first_pn_out_fusion = torch.cat(
        (first_pn_out, scattered_first_masked_max), dim=1)  # Bx2CxkN

    second_pn_out = self.second_pointnet(first_pn_out_fusion, epoch)

    # second_gather_index = self.masked_max.compute(second_pn_out, min_idx, mask).detach()  # BxCxM
    # NOTE(review): second_pn_out is passed WITHOUT .detach() here, unlike
    # first_pn_out above — confirm this asymmetry is intentional.
    with torch.cuda.device(second_pn_out.get_device()):
        second_gather_index = index_max.forward_cuda_shared_mem(
            second_pn_out, min_idx.int(), M).detach().long()
    second_pn_out_masked_max = second_pn_out.gather(
        dim=2, index=second_gather_index) * mask_row_max.unsqueeze(
            1).float()  # BxCxM

    # knn module, knn search on nodes: ----------------------------------
    knn_feature_1 = self.knnlayer_1(query=som_node_cluster_mean,
                                    database=som_node_cluster_mean,
                                    x=second_pn_out_masked_max,
                                    K=self.opt.node_knn_k_1,
                                    epoch=epoch)
    node_feature_aggregated = torch.cat(
        (second_pn_out_masked_max, knn_feature_1), dim=1)  # Bx(C1+C2)xM

    # go through another network to calculate the per-node keypoint & uncertainty sigma
    y = self.mlp1(node_feature_aggregated)
    point_descriptor = self.mlp2(y)
    keypoint_sigma = self.mlp3(point_descriptor)  # Bx(3+1)xM — 3 offset channels + 1 sigma channel per node

    # keypoint = predicted offset + node coordinate
    keypoints = keypoint_sigma[:, 0:3, :] + som_node_cluster_mean  # Bx3xM

    # make sure sigma>=0 via softplus, with a configured lower bound
    # sigmas = torch.pow(keypoint_sigma[:, 3, :], exponent=2) + self.opt.loss_sigma_lower_bound  # BxM
    sigmas = self.softplus(
        keypoint_sigma[:, 3, :]) + self.opt.loss_sigma_lower_bound  # BxM
    descriptors = None

    # debug
    # print(keypoints)
    # print(sigmas)

    return som_node_cluster_mean, keypoints, sigmas, descriptors