Esempio n. 1
0
    def forward(self, x, sn, node, node_knn_I, is_train=False, epoch=None):
        '''
        Encode a point cloud into a single feature per batch element.

        The points are assigned to nodes, expressed relative to the mean of
        their assigned group, run through the first PointNet, max-pooled per
        node, optionally passed through the knn layer, run through the final
        PointNet and max-pooled over the node dimension.

        :param x: Bx3xN Tensor
        :param sn: Bx3xN Tensor
        :param node: Bx3xM FloatTensor
        :param node_knn_I: BxMxk_som LongTensor
        :param is_train: determine whether to add noise in KNNModule
            (not read inside this method)
        :param epoch: passed through to the PointNets and the knn layer
        :return: max of the final PointNet output over dim 2
        '''
        num_nodes = node.size()[2]
        k = self.opt.k

        # point-to-node assignment (the SOM optimisation itself is done by
        # the caller; here the nodes are only queried)
        mask, mask_row_max, min_idx = som.query_topk(
            node, x.data, num_nodes, k=k)  # BxkNxM, BxM, BxkN
        members_per_node = torch.sum(mask, dim=1)  # BxM
        mask = mask.unsqueeze(1)  # Bx1xkNxM
        mask_f = mask.float()

        # each point is repeated k times along the point axis, once per
        # assigned node
        x_stack = torch.cat([x] * k, dim=2)
        sn_stack = torch.cat([sn] * k, dim=2)

        # node positions are re-computed as the mean of the assigned points
        # instead of using the incoming node coordinates; 1e-5 keeps the
        # division defined for nodes that received no point
        masked_points = x_stack.data.unsqueeze(3) * mask_f  # BxCxkNxM
        som_node_cluster_mean = torch.sum(masked_points, dim=2) / (
            members_per_node.unsqueeze(1).float() + 1e-5)  # BxCxM

        # centre of every stacked point, selected through the mask
        centers = torch.sum(mask_f * som_node_cluster_mean.unsqueeze(2),
                            dim=3).detach()  # BxCxkN
        x_decentered = (x_stack - centers).detach()  # Bx3xkN
        x_augmented = torch.cat((x_decentered, sn_stack), dim=1)  # Bx6xkN

        # first PointNet, with or without the surface normals
        if self.opt.surface_normal == True:
            first_pn_in = x_augmented
        else:
            first_pn_in = x_decentered
        first_pn_out = self.first_pointnet(first_pn_in, epoch)

        # per-node arg-max computed by the CUDA extension; the index is
        # multiplied by mask_row_max, so wherever that flag is 0 the gather
        # reads position 0 (presumably these are nodes without points -
        # TODO confirm against som.query_topk)
        with torch.cuda.device(first_pn_out.get_device()):
            gather_index = index_max.forward_cuda(
                first_pn_out.detach(), min_idx.int(),
                num_nodes).detach().long()
        first_pn_out_masked_max = first_pn_out.gather(
            dim=2,
            index=gather_index * mask_row_max.unsqueeze(1).long())  # BxCxM

        if self.opt.som_k >= 2:
            # knn search on the nodes before the final PointNet
            knn_center_1, knn_feature_1 = self.knnlayer(
                som_node_cluster_mean, first_pn_out_masked_max, node_knn_I,
                self.opt.som_k, self.opt.som_k_type, epoch)
            final_pn_in = torch.cat((knn_center_1, knn_feature_1), dim=1)
        else:
            final_pn_in = torch.cat(
                (som_node_cluster_mean, first_pn_out_masked_max), dim=1)

        final_pn_out = self.final_pointnet(final_pn_in, epoch)  # Bx1024xM

        feature, _ = torch.max(final_pn_out, dim=2, keepdim=False)
        return feature
Esempio n. 2
0
    def forward(self, x, sn, node, node_knn_I, is_train=False, epoch=None):
        '''
        Encode a point cloud into a single feature per batch element, pooling
        over R rotated copies of the input.

        The points, normals and re-computed node positions are multiplied by
        every matrix in self.rotation_matrix_template, the rotated copies are
        folded into the batch dimension (B*R) for the PointNets and the knn
        layer, and the result is max-pooled over nodes and then over
        rotations.

        :param x: Bx3xN Tensor
        :param sn: Bx3xN Tensor
        :param node: Bx3xM FloatTensor
        :param node_knn_I: BxMxk_som LongTensor
        :param is_train: determine whether to add noise in KNNModule
            (not read inside this method)
        :param epoch: passed through to the PointNets and the knn layer
        :return: BxC feature, max over nodes and over rotations
        '''
        device = x.device
        k = self.opt.k
        R = self.opt.rot_equivariant_no
        M = node.size()[2]
        per_hierarchy = (
            self.opt.rot_equivariant_pooling_mode == 'per-hierarchy')

        def pool_over_rotations(feat):
            # BRxCxM -> max over the R rotations -> broadcast back to BRxCxM.
            # The channel count is read from the tensor itself so the reshape
            # is valid whatever the producing layer's output width is.
            channels = feat.size()[1]
            pooled, _ = torch.max(
                feat.contiguous().view(B, R, channels, M),
                dim=1, keepdim=True)  # Bx1xCxM
            return pooled.expand(B, R, channels, M).contiguous().view(
                B * R, channels, M)

        # point-to-node assignment
        mask, mask_row_max, min_idx = som.query_topk(
            node, x.data, M, k=k)  # BxkNxM, BxM, BxkN
        mask_row_sum = torch.sum(mask, dim=1)  # BxM
        mask = mask.unsqueeze(1)  # Bx1xkNxM
        mask_f = mask.float()

        # each point is repeated k times along the point axis
        x_stack = torch.cat([x] * k, dim=2)  # Bx3xkN
        sn_stack = torch.cat([sn] * k, dim=2)  # Bx3xkN
        B, kN = x_stack.size()[0], x_stack.size()[2]

        # re-compute the node positions as the mean of the assigned points;
        # 1e-5 keeps the division defined for nodes that received no point
        som_node_cluster_mean = torch.sum(
            x_stack.data.unsqueeze(3) * mask_f, dim=2) / (
                mask_row_sum.unsqueeze(1).float() + 1e-5)  # BxCxM

        # ====== rotate the pc, sn & som_node into R rotated versions ======
        rotation_matrix = self.rotation_matrix_template.to(device).expand(
            B, R, 3, 3).detach()  # 1xRx3x3 -> BxRx3x3

        x_stack_rot = torch.matmul(
            rotation_matrix,
            x_stack.unsqueeze(1).expand(B, R, 3, kN))  # BxRx3xkN
        sn_stack_rot = torch.matmul(
            rotation_matrix,
            sn_stack.unsqueeze(1).expand(B, R, 3, kN))  # BxRx3xkN
        som_node_rot = torch.matmul(
            rotation_matrix,
            som_node_cluster_mean.unsqueeze(1).expand(B, R, 3, M))  # BxRx3xM

        # the assignment is the same for every rotation, so the index
        # tensors are only repeated and folded into the batch dimension
        node_knn_I_rot = node_knn_I.unsqueeze(1).expand(
            B, R, M, self.opt.som_k).contiguous().view(
                B * R, M, self.opt.som_k)  # BRxMxsom_k
        min_idx_rot = min_idx.unsqueeze(1).expand(
            B, R, kN).contiguous().view(B * R, kN)  # BRxkN
        mask_row_max_rot = mask_row_max.unsqueeze(1).expand(
            B, R, M).contiguous().view(B * R, M).unsqueeze(1).long()  # BRx1xM

        # assign each point its rotated centre; the mask (Bx1x1xkNxM) is
        # broadcast over the rotation and channel axes instead of being
        # copied R times
        node_rot_expanded = som_node_rot.unsqueeze(3)  # BxRx3x1xM
        centers_rot = torch.sum(mask_f.unsqueeze(1) * node_rot_expanded,
                                dim=4).detach()  # BxRx3xkN

        x_decentered_rot = (x_stack_rot - centers_rot).detach()  # BxRx3xkN
        x_augmented_rot = torch.cat((x_decentered_rot, sn_stack_rot),
                                    dim=2)  # BxRx6xkN

        # go through the first PointNet
        if self.opt.surface_normal == True:
            first_pn_out_rot = self.first_pointnet(
                x_augmented_rot.contiguous().view(B * R, 6, kN), epoch)
        else:
            # without normals the input has 3 channels; reshaping it to 6
            # channels cannot succeed because the element count differs
            first_pn_out_rot = self.first_pointnet(
                x_decentered_rot.contiguous().view(B * R, 3, kN), epoch)
        C = first_pn_out_rot.size()[1]

        # per-node arg-max computed by the CUDA extension; the index is
        # multiplied by mask_row_max_rot, so wherever that flag is 0 the
        # gather reads position 0
        with torch.cuda.device(first_pn_out_rot.get_device()):
            first_gather_index_rot = index_max.forward_cuda(
                first_pn_out_rot.detach(), min_idx_rot.int(),
                M).detach().long()
        first_pn_out_masked_max_rot = first_pn_out_rot.gather(
            dim=2, index=first_gather_index_rot * mask_row_max_rot)  # BRxCxM

        # scatter the masked_max back to the kN points
        scattered_first_masked_max = torch.gather(
            first_pn_out_masked_max_rot,
            dim=2,
            index=min_idx_rot.unsqueeze(1).expand(B * R, C, kN))  # BRxCxkN
        first_pn_out_fusion = torch.cat(
            (first_pn_out_rot, scattered_first_masked_max), dim=1)  # BRx2CxkN
        second_pn_out = self.second_pointnet(first_pn_out_fusion, epoch)

        with torch.cuda.device(second_pn_out.get_device()):
            second_gather_index_rot = index_max.forward_cuda(
                second_pn_out.detach(), min_idx_rot.int(), M).detach().long()
        second_pn_out_masked_max_rot = second_pn_out.gather(
            dim=2, index=second_gather_index_rot * mask_row_max_rot)  # BRxCxM

        if per_hierarchy:
            second_pn_out_masked_max_rot = pool_over_rotations(
                second_pn_out_masked_max_rot)  # BRxCxM

        som_node_rot_flat = som_node_rot.contiguous().view(B * R, 3, M)
        if self.opt.som_k >= 2:
            # knn search on the nodes before the final PointNet
            knn_center_1_rot, knn_feature_1_rot = self.knnlayer(
                som_node_rot_flat, second_pn_out_masked_max_rot,
                node_knn_I_rot, self.opt.som_k, self.opt.som_k_type, epoch)
            if per_hierarchy:
                knn_feature_1_rot = pool_over_rotations(
                    knn_feature_1_rot)  # BRxC2xM
            final_pn_in = torch.cat((knn_center_1_rot, knn_feature_1_rot),
                                    dim=1)
        else:
            final_pn_in = torch.cat(
                (som_node_rot_flat, second_pn_out_masked_max_rot), dim=1)

        final_pn_out_rot = self.final_pointnet(final_pn_in,
                                               epoch)  # BRx1024xM
        final_pn_out_rot = final_pn_out_rot.contiguous().view(
            B, R, self.opt.feature_num, M)

        # max over the nodes, then over the rotations
        feature_rot, _ = torch.max(final_pn_out_rot, dim=3,
                                   keepdim=False)  # BxRxC
        feature, _ = torch.max(feature_rot, dim=1, keepdim=False)  # BxC

        return feature
Esempio n. 3
0
    def forward(self, x, sn, node, is_train=False, epoch=None):
        '''
        Predict one keypoint and one uncertainty value per node.

        :param x: Bx3xN Tensor
        :param sn: Bx3xN Tensor
        :param node: Bx3xM FloatTensor
        :param is_train: determine whether to add noise in KNNModule
            (not read inside this method)
        :param epoch: passed through to the PointNets and the knn layer
        :return: tuple (som_node_cluster_mean, keypoints, sigmas, descriptors);
            descriptors is always None in this method
        '''
        k = self.opt.k

        # point-to-node assignment
        mask, mask_row_max, min_idx = som.query_topk(
            node, x, node.size()[2], k=k)  # BxkNxM, BxM, BxkN
        members_per_node = torch.sum(mask, dim=1)  # BxM
        mask_f = mask.unsqueeze(1).float()  # Bx1xkNxM

        # each point is repeated k times along the point axis
        x_stack = x.repeat(1, 1, k)
        sn_stack = sn.repeat(1, 1, k)

        # node positions re-computed as the mean of the assigned points.
        # NOTE(review): .detach() applies to the denominator only, so the
        # numerator keeps whatever graph x carries - confirm this is intended
        denominator = (members_per_node.unsqueeze(1).float() + 1e-5).detach()
        masked_points = x_stack.unsqueeze(3) * mask_f  # BxCxkNxM
        som_node_cluster_mean = torch.sum(masked_points,
                                          dim=2) / denominator  # BxCxM

        B = x.size()[0]
        kN = x_stack.size()[2]
        M = som_node_cluster_mean.size()[2]

        # centre of every stacked point, selected through the mask
        centers = torch.sum(mask_f * som_node_cluster_mean.unsqueeze(2),
                            dim=3).detach()  # BxCxkN
        x_decentered = (x_stack - centers).detach()  # Bx3xkN
        x_augmented = torch.cat((x_decentered, sn_stack), dim=1)  # Bx6xkN

        # first PointNet, with or without the surface normals
        if self.opt.surface_normal_len >= 1:
            first_pn_in = x_augmented
        else:
            first_pn_in = x_decentered
        first_pn_out = self.first_pointnet(first_pn_in, epoch)

        # features of nodes whose mask_row_max flag is 0 are zeroed after
        # the gather
        node_flag = mask_row_max.unsqueeze(1).float()  # Bx1xM

        with torch.cuda.device(first_pn_out.get_device()):
            first_gather_index = index_max.forward_cuda_shared_mem(
                first_pn_out.detach(), min_idx.int(), M).detach().long()
        first_pn_out_masked_max = first_pn_out.gather(
            dim=2, index=first_gather_index) * node_flag  # BxCxM

        # scatter the per-node maximum back to the kN points and fuse it
        # with the per-point features
        point_to_node = min_idx.unsqueeze(1).expand(
            B, first_pn_out.size()[1], kN)
        scattered_first_masked_max = torch.gather(
            first_pn_out_masked_max, dim=2, index=point_to_node)  # BxCxkN
        first_pn_out_fusion = torch.cat(
            (first_pn_out, scattered_first_masked_max), dim=1)  # Bx2CxkN
        second_pn_out = self.second_pointnet(first_pn_out_fusion, epoch)

        with torch.cuda.device(second_pn_out.get_device()):
            second_gather_index = index_max.forward_cuda_shared_mem(
                second_pn_out, min_idx.int(), M).detach().long()
        second_pn_out_masked_max = second_pn_out.gather(
            dim=2, index=second_gather_index) * node_flag  # BxCxM

        # knn module, knn search on nodes
        knn_feature_1 = self.knnlayer_1(query=som_node_cluster_mean,
                                        database=som_node_cluster_mean,
                                        x=second_pn_out_masked_max,
                                        K=self.opt.node_knn_k_1,
                                        epoch=epoch)
        node_feature_aggregated = torch.cat(
            (second_pn_out_masked_max, knn_feature_1), dim=1)  # Bx(C1+C2)xM

        # per-node keypoint offset & uncertainty: channels 0..2 are read as
        # the offset, channel 3 as the raw sigma
        point_descriptor = self.mlp2(self.mlp1(node_feature_aggregated))
        keypoint_sigma = self.mlp3(point_descriptor)  # Bx(3+1)xM

        # keypoint = predicted offset + node coordinate
        keypoints = keypoint_sigma[:, 0:3, :] + som_node_cluster_mean  # Bx3xM
        # softplus, then shifted by the configured lower bound
        sigmas = self.softplus(
            keypoint_sigma[:, 3, :]) + self.opt.loss_sigma_lower_bound  # BxM

        descriptors = None

        return som_node_cluster_mean, keypoints, sigmas, descriptors