def forward(self, x, sn, node, is_train=False, epoch=None):
    """Regress one keypoint per node and run the instance branch.

    Points are assigned to nodes with ``som.query_topk``, decentered by the
    mean of their cluster, encoded by two PointNets (with a per-node max
    pooling after each), mixed across nodes by a kNN layer and finally
    mapped to per-node offsets that are added to the cluster means.

    :param x: Bx3xN Tensor
    :param sn: Bx3xN Tensor
    :param node: Bx3xM FloatTensor
    :param is_train: determine whether to add noise in KNNModule
        (accepted for interface compatibility; not read in this method)
    :param epoch: passed through to both PointNets, the kNN layer and the
        instance branch
    :return: tuple ``(nodes, coeffs, rot)``; ``nodes`` is Bx3xM, ``coeffs``
        and ``rot`` are whatever ``self.instance_branch`` returns
    """
    k = self.opt.k

    # -- Prepare input --
    # Assign every (stacked) point to nodes.
    assign_mask, node_has_points, nearest_idx = som.query_topk(
        node, x, node.size()[2], k=k)  # BxkNxM, BxM, BxkN
    points_per_node = torch.sum(assign_mask, dim=1)  # BxM
    assign_mask_f = assign_mask.unsqueeze(1).float()  # Bx1xkNxM

    # Stack the inputs k times so they line up with the kN assignments.
    x_stack = x.repeat(1, 1, k)
    sn_stack = sn.repeat(1, 1, k)

    # Mean of the points assigned to each node; the epsilon guards against
    # nodes without any assigned point.
    masked_points = x_stack.unsqueeze(3) * assign_mask_f  # BxCxkNxM
    denom = (points_per_node.unsqueeze(1).float() + 1e-5).detach()
    cluster_mean = torch.sum(masked_points, dim=2) / denom  # BxCxM

    B = x.size(0)
    kN = x_stack.size(2)
    M = cluster_mean.size(2)

    # Subtract from each point the mean of the cluster it belongs to.
    centers = torch.sum(
        assign_mask_f * cluster_mean.unsqueeze(2), dim=3).detach()  # BxCxkN
    x_decentered = (x_stack - centers).detach()  # Bx3xkN
    x_augmented = torch.cat((x_decentered, sn_stack), dim=1)  # Bx6xkN

    # Multiplying by this zeroes the pooled features of empty nodes.
    node_valid = node_has_points.unsqueeze(1).float()  # Bx1xM

    # -- Nodes branch --
    # First PointNet; surface normals are only fed in when configured.
    if self.opt.surface_normal_len >= 1:
        first_pn_in = x_augmented
    else:
        first_pn_in = x_decentered
    first_pn_out = self.first_pointnet(first_pn_in, epoch)

    # Per-node arg-max over the points of each cluster (CUDA kernel).
    with torch.cuda.device(first_pn_out.get_device()):
        first_argmax = index_max.forward_cuda_shared_mem(
            first_pn_out.detach(), nearest_idx.int(), M).detach().long()
    first_node_max = first_pn_out.gather(
        dim=2, index=first_argmax) * node_valid  # BxCxM

    # Scatter the per-node maxima back to the kN points and fuse them with
    # the per-point features.
    scatter_idx = nearest_idx.unsqueeze(1).expand(
        B, first_pn_out.size(1), kN)
    per_point_node_max = torch.gather(
        first_node_max, dim=2, index=scatter_idx)  # BxCxkN
    fused = torch.cat((first_pn_out, per_point_node_max), dim=1)  # Bx2CxkN

    # Second PointNet followed by the same per-node max pooling.
    second_pn_out = self.second_pointnet(fused, epoch)
    with torch.cuda.device(second_pn_out.get_device()):
        second_argmax = index_max.forward_cuda_shared_mem(
            second_pn_out, nearest_idx.int(), M).detach().long()
    second_node_max = second_pn_out.gather(
        dim=2, index=second_argmax) * node_valid  # BxCxM

    # kNN search on nodes (cluster means are both query and database).
    knn_feature_1 = self.knnlayer_1(
        query=cluster_mean,
        database=cluster_mean,
        x=second_node_max,
        K=self.opt.node_knn_k_1,
        epoch=epoch,
    )
    node_features = torch.cat(
        (second_node_max, knn_feature_1), dim=1)  # Bx(C1+C2)xM

    # MLP stack producing the per-node prediction.
    # NOTE(review): presumably Bx(3+1)xM - the result is added to the
    # per-node cluster means below; confirm against the mlp3 definition.
    keypoint_sigma = self.mlp3(self.mlp2(self.mlp1(node_features)))
    nodes = keypoint_sigma[:, 0:3, :] + cluster_mean  # Bx3xM

    # -- Pose and coefficients branch --
    # Works on the stacked, non-decentered points plus normals.
    coeffs, rot = self.instance_branch(
        torch.cat((x_stack, sn_stack), dim=1), epoch)

    return nodes, coeffs, rot
def forward(self, x, sn, node, is_train=False, epoch=None):
    """Detect one keypoint (with an uncertainty sigma) per node.

    Points are assigned to nodes with ``som.query_topk``, decentered by the
    mean of their cluster, encoded by two PointNets (with a per-node max
    pooling after each), mixed across nodes by a kNN layer and mapped by
    three MLPs to a per-node offset and sigma.

    :param x: Bx3xN Tensor
    :param sn: Bx3xN Tensor
    :param node: Bx3xM FloatTensor
    :param is_train: determine whether to add noise in KNNModule
        (accepted for interface compatibility; not read in this method)
    :param epoch: passed through to both PointNets and the kNN layer
    :return: tuple ``(som_node_cluster_mean, keypoints, sigmas,
        descriptors)``: the BxCxM cluster means, the Bx3xM keypoints, the
        BxM sigmas, and ``descriptors`` which is always ``None`` here
    """
    k = self.opt.k

    # Assign every (stacked) point to nodes.
    assign_mask, node_has_points, nearest_idx = som.query_topk(
        node, x, node.size()[2], k=k)  # BxkNxM, BxM, BxkN
    points_per_node = torch.sum(assign_mask, dim=1)  # BxM
    assign_mask_f = assign_mask.unsqueeze(1).float()  # Bx1xkNxM

    # Stack the inputs k times so they line up with the kN assignments.
    x_stack = x.repeat(1, 1, k)
    sn_stack = sn.repeat(1, 1, k)

    # Mean of the points assigned to each node; the epsilon guards against
    # nodes without any assigned point.
    masked_points = x_stack.unsqueeze(3) * assign_mask_f  # BxCxkNxM
    denom = (points_per_node.unsqueeze(1).float() + 1e-5).detach()
    som_node_cluster_mean = torch.sum(masked_points, dim=2) / denom  # BxCxM

    B = x.size(0)
    kN = x_stack.size(2)
    M = som_node_cluster_mean.size(2)

    # Subtract from each point the mean of the cluster it belongs to.
    centers = torch.sum(
        assign_mask_f * som_node_cluster_mean.unsqueeze(2),
        dim=3).detach()  # BxCxkN
    x_decentered = (x_stack - centers).detach()  # Bx3xkN
    x_augmented = torch.cat((x_decentered, sn_stack), dim=1)  # Bx6xkN

    # Multiplying by this zeroes the pooled features of empty nodes.
    node_valid = node_has_points.unsqueeze(1).float()  # Bx1xM

    # First PointNet; surface normals are only fed in when configured.
    if self.opt.surface_normal_len >= 1:
        first_pn_in = x_augmented
    else:
        first_pn_in = x_decentered
    first_pn_out = self.first_pointnet(first_pn_in, epoch)

    # Per-node arg-max over the points of each cluster (CUDA kernel).
    with torch.cuda.device(first_pn_out.get_device()):
        first_argmax = index_max.forward_cuda_shared_mem(
            first_pn_out.detach(), nearest_idx.int(), M).detach().long()
    first_node_max = first_pn_out.gather(
        dim=2, index=first_argmax) * node_valid  # BxCxM

    # Scatter the per-node maxima back to the kN points and fuse them with
    # the per-point features.
    scatter_idx = nearest_idx.unsqueeze(1).expand(
        B, first_pn_out.size(1), kN)
    per_point_node_max = torch.gather(
        first_node_max, dim=2, index=scatter_idx)  # BxCxkN
    fused = torch.cat((first_pn_out, per_point_node_max), dim=1)  # Bx2CxkN

    # Second PointNet followed by the same per-node max pooling.
    second_pn_out = self.second_pointnet(fused, epoch)
    with torch.cuda.device(second_pn_out.get_device()):
        second_argmax = index_max.forward_cuda_shared_mem(
            second_pn_out, nearest_idx.int(), M).detach().long()
    second_node_max = second_pn_out.gather(
        dim=2, index=second_argmax) * node_valid  # BxCxM

    # kNN search on nodes (cluster means are both query and database).
    knn_feature_1 = self.knnlayer_1(
        query=som_node_cluster_mean,
        database=som_node_cluster_mean,
        x=second_node_max,
        K=self.opt.node_knn_k_1,
        epoch=epoch,
    )
    node_features = torch.cat(
        (second_node_max, knn_feature_1), dim=1)  # Bx(C1+C2)xM

    # MLP stack producing the per-node keypoint offset and raw sigma.
    # NOTE(review): presumably Bx(3+1)xM - channels 0:3 are added to the
    # per-node cluster means below; confirm against the mlp3 definition.
    keypoint_sigma = self.mlp3(self.mlp2(self.mlp1(node_features)))

    # keypoint = predicted offset + node coordinate
    keypoints = keypoint_sigma[:, 0:3, :] + som_node_cluster_mean  # Bx3xM

    # Channel 3 goes through self.softplus and is shifted by the configured
    # lower bound to give sigma.
    raw_sigma = keypoint_sigma[:, 3, :]
    sigmas = self.softplus(raw_sigma) + self.opt.loss_sigma_lower_bound  # BxM

    descriptors = None

    return som_node_cluster_mean, keypoints, sigmas, descriptors
8).long() end_t = time.time() print('cpu multi thread time: %f' % (end_t - begin_t)) data_cuda = data.cuda() index_cuda = index.cuda() begin_t = time.time() for i in range(100): max_idx_cuda = index_max.forward_cuda(data_cuda, index_cuda, M).long() end_t = time.time() print('cuda cpp time, 100 times: %f' % (end_t - begin_t)) begin_t = time.time() for i in range(100): max_idx_cuda_shared_mem = index_max.forward_cuda_shared_mem( data_cuda, index_cuda, M).long() end_t = time.time() print('cuda cpp shared mem time, 100 times: %f' % (end_t - begin_t)) mask_max = operations.MaskedMax(M) begin_t = time.time() for i in range(100): max_idx_gt = mask_max.compute(data_cuda, index_cuda, None) end_t = time.time() print('cuda operations.py time, 100 times: %f' % (end_t - begin_t)) print(torch.max(max_idx_gt.cpu() - max_idx_single_cpu)) print(torch.min(max_idx_gt.cpu() - max_idx_single_cpu)) print(torch.max(max_idx_gt.cpu() - max_idx_multi_cpu)) print(torch.min(max_idx_gt.cpu() - max_idx_multi_cpu))
def forward(self, pc, intensity, sn, node_a, node_b):
    """Encode a point cloud into per-node and global features.

    Every point is assigned to its nearest ``node_a``, decentered by the
    mean of its cluster and, together with intensity and normals, encoded
    by two PointNets with a per-node max pooling after each. The resulting
    ``node_a`` features are interpolated onto ``node_b`` by a kNN layer and
    max-pooled into one global feature.

    :param pc: Bx3xN Tensor
    :param intensity: Bx1xN Tensor
    :param sn: Bx3xN Tensor
    :param node_a: Bx3xMa FloatTensor
    :param node_b: Bx3xMb FloatTensor
    :return: tuple ``(pc_centers, cluster_mean, min_k_idx, first_pn_out,
        second_pn_out, node_a_features, node_b_features, global_feature)``
    """
    B, N, Ma = pc.size(0), pc.size(2), node_a.size(2)

    # Distance from every point to every node_a.
    pc_expanded = pc.unsqueeze(3).expand(B, 3, N, Ma)
    node_a_expanded = node_a.unsqueeze(2).expand(B, 3, N, Ma)
    dist = torch.norm(
        pc_expanded - node_a_expanded, p=2, dim=1, keepdim=False)  # BxNxMa

    # k nearest nodes per point, closest first; column 0 is the assignment.
    min_k_idx = torch.topk(
        dist, k=self.opt.k_interp_point_a, dim=2, largest=False,
        sorted=True)[1]  # BxNxk0
    min_idx = min_k_idx[:, :, 0]  # BxN

    # mask[b, n, m] is set when point n is assigned to node m.
    node_ids = self.node_idx_1NMa.to(
        device=min_idx.device, dtype=torch.long).expand(B, N, Ma)
    mask = torch.eq(min_idx.unsqueeze(2).expand(B, N, Ma), node_ids)  # BxNxMa

    # BxMa, this indicates whether the node has nearby points.
    node_has_points = torch.max(mask, dim=1, keepdim=False)[0]
    node_valid = node_has_points.unsqueeze(1).to(dtype=torch.float)  # Bx1xMa
    mask_f = mask.unsqueeze(1).to(dtype=torch.float)  # Bx1xNxMa
    points_per_node = torch.sum(mask_f, dim=2, keepdim=False)  # Bx1xMa

    # Center of each cluster; the epsilon guards against empty nodes.
    masked_pc = pc.unsqueeze(3) * mask_f  # BxCxNxMa
    cluster_mean = torch.sum(masked_pc, dim=2) / (
        points_per_node + 1e-5).detach()  # BxCxMa

    # Look up the center of the cluster each point belongs to.
    center_idx = min_idx.unsqueeze(1).expand(B, 3, N)
    pc_centers = torch.gather(cluster_mean, dim=2, index=center_idx)  # Bx3xN
    pc_decentered = (pc - pc_centers).detach()  # Bx3xN

    # First PointNet on decentered coordinates + intensity + normals.
    first_pn_out = self.first_pointnet(
        torch.cat((pc_decentered, intensity, sn), dim=1))  # input is Bx7xN

    # Per-node arg-max over the points of each cluster (CUDA kernel).
    with torch.cuda.device(first_pn_out.get_device()):
        first_argmax = index_max.forward_cuda_shared_mem(
            first_pn_out.detach(), min_idx.int(), Ma).detach().long()
    first_node_max = first_pn_out.gather(
        dim=2, index=first_argmax) * node_valid  # BxCxMa

    # Scatter the per-node maxima back to the N points and fuse them with
    # the per-point features.
    scatter_idx = min_idx.unsqueeze(1).expand(B, first_pn_out.size(1), N)
    per_point_node_max = torch.gather(
        first_node_max, dim=2, index=scatter_idx)  # BxCxN
    fused = torch.cat((first_pn_out, per_point_node_max), dim=1)  # Bx2CxN

    # Second PointNet followed by the same per-node max pooling.
    second_pn_out = self.second_pointnet(fused)
    with torch.cuda.device(second_pn_out.get_device()):
        second_argmax = index_max.forward_cuda_shared_mem(
            second_pn_out, min_idx.int(), Ma).detach().long()
    node_a_features = second_pn_out.gather(
        dim=2, index=second_argmax) * node_valid  # BxCaxMa

    # kNN search on nodes: query node_b against the node_a cluster means.
    node_b_features = self.knnlayer(
        query=node_b,
        database=cluster_mean,
        database_features=node_a_features,
        K=self.opt.k_ab,
    )  # BxCbxM

    # Global feature: max over the last dimension of the final PointNet.
    final_pn_out = self.final_pointnet(
        torch.cat((node_b, node_b_features), dim=1))  # BxCgxN
    global_feature = torch.max(final_pn_out, dim=2, keepdim=True)[0]  # BxCgx1

    return (
        pc_centers,
        cluster_mean,
        min_k_idx,
        first_pn_out,
        second_pn_out,
        node_a_features,
        node_b_features,
        global_feature,
    )