def forward(self, x, sn, node, is_train=False, epoch=None):
        """Predict per-node keypoints plus instance coefficients and rotation.

        Points are assigned to SOM nodes by ``som.query_topk``. They are then
        re-expressed relative to the mean of their node's cluster and encoded
        by two PointNets whose outputs are max-pooled per node. The pooled
        features are enriched with kNN features over the nodes and regressed
        to a keypoint offset per node.

        :param x: Bx3xN Tensor, input points
        :param sn: Bx3xN Tensor, surface normals
        :param node: Bx3xM FloatTensor, SOM nodes
        :param is_train: not referenced by this method; kept for interface
            compatibility
        :param epoch: forwarded to the PointNets, the kNN layer and the
            instance branch
        :return: tuple ``(nodes, coeffs, rot)``. ``nodes`` is Bx3xM (regressed
            offsets added to the per-node cluster means); ``coeffs`` and
            ``rot`` are the outputs of ``self.instance_branch``.
        """
        # -- Prepare input --
        k = self.opt.k
        # Shapes per the original annotations: mask BxkNxM, mask_row_max BxM,
        # min_idx BxkN (node index of each stacked point).
        mask, mask_row_max, min_idx = som.query_topk(node, x, node.size()[2],
                                                     k=k)
        mask_row_sum = torch.sum(mask, dim=1)  # BxM, points per node
        mask_float = mask.unsqueeze(1).float()  # Bx1xkNxM

        # Repeat points/normals k times so they line up with the kN mask rows.
        x_stack = x.repeat(1, 1, k)  # Bx3xkN
        sn_stack = sn.repeat(1, 1, k)  # Bx3xkN

        # Mean of the points assigned to each node; the epsilon guards against
        # division by zero for nodes that received no point.
        cluster_mean = (
            torch.sum(x_stack.unsqueeze(3) * mask_float, dim=2) /
            (mask_row_sum.unsqueeze(1).float() + 1e-5).detach())  # Bx3xM

        B, kN, M = x.size()[0], x_stack.size()[2], cluster_mean.size()[2]

        # Center of the cluster each stacked point belongs to.
        centers = torch.sum(mask_float * cluster_mean.unsqueeze(2),
                            dim=3).detach()  # Bx3xkN
        x_decentered = (x_stack - centers).detach()  # Bx3xkN

        # -- Nodes branch --
        # First PointNet, with surface normals appended when configured.
        if self.opt.surface_normal_len >= 1:
            x_augmented = torch.cat((x_decentered, sn_stack), dim=1)  # Bx6xkN
            first_pn_out = self.first_pointnet(x_augmented, epoch)
        else:
            first_pn_out = self.first_pointnet(x_decentered, epoch)

        # Per-node multiplier applied to pooled features (presumably 0 for
        # nodes with no assigned point - confirm against query_topk).
        node_valid = mask_row_max.unsqueeze(1).float()  # Bx1xM

        # Masked max-pool per node: the CUDA op only yields gather indices,
        # so gradients flow through the differentiable ``gather``.
        with torch.cuda.device(first_pn_out.get_device()):
            first_gather_index = index_max.forward_cuda_shared_mem(
                first_pn_out.detach(), min_idx.int(), M).detach().long()
        first_pn_out_masked_max = first_pn_out.gather(
            dim=2, index=first_gather_index) * node_valid  # BxCxM

        # Copy each node's pooled feature back to the kN points assigned to it.
        scattered_first_masked_max = torch.gather(
            first_pn_out_masked_max,
            dim=2,
            index=min_idx.unsqueeze(1).expand(B, first_pn_out.size()[1], kN),
        )  # BxCxkN
        first_pn_out_fusion = torch.cat(
            (first_pn_out, scattered_first_masked_max), dim=1)  # Bx2CxkN

        # Second PointNet followed by the same masked max-pool per node.
        second_pn_out = self.second_pointnet(first_pn_out_fusion, epoch)
        with torch.cuda.device(second_pn_out.get_device()):
            second_gather_index = index_max.forward_cuda_shared_mem(
                second_pn_out, min_idx.int(), M).detach().long()
        second_pn_out_masked_max = second_pn_out.gather(
            dim=2, index=second_gather_index) * node_valid  # BxCxM

        # kNN aggregation over the node cluster means.
        knn_feature_1 = self.knnlayer_1(
            query=cluster_mean,
            database=cluster_mean,
            x=second_pn_out_masked_max,
            K=self.opt.node_knn_k_1,
            epoch=epoch,
        )
        node_feature_aggregated = torch.cat(
            (second_pn_out_masked_max, knn_feature_1), dim=1)  # Bx(C1+C2)xM

        # Per-node regression; only the first 3 channels (keypoint offset
        # relative to the cluster mean) are used here.
        y = self.mlp1(node_feature_aggregated)
        point_descriptor = self.mlp2(y)
        keypoint_sigma = self.mlp3(point_descriptor)  # Bx(3+...)xM
        nodes = keypoint_sigma[:, 0:3, :] + cluster_mean  # Bx3xM

        # -- Pose and coefficients branch (raw, non-centered points) --
        x_init_augmented = torch.cat((x_stack, sn_stack), dim=1)  # Bx6xkN
        coeffs, rot = self.instance_branch(x_init_augmented, epoch)

        return nodes, coeffs, rot
Example #2
0
    def forward(self, x, sn, node, is_train=False, epoch=None):
        '''Predict one keypoint and one uncertainty sigma per SOM node.

        :param x: Bx3xN Tensor, input points
        :param sn: Bx3xN Tensor, surface normals
        :param node: Bx3xM FloatTensor, SOM nodes
        :param is_train: not referenced by this method; kept for interface
            compatibility
        :param epoch: forwarded to the PointNets and the kNN layer
        :return: tuple ``(cluster_mean, keypoints, sigmas, descriptors)``:
            cluster_mean Bx3xM, the mean of the points assigned to each node;
            keypoints Bx3xM, the regressed offsets added to cluster_mean;
            sigmas BxM, the softplus of regressed channel 3 plus
            ``self.opt.loss_sigma_lower_bound``;
            descriptors, which is always None.
        '''
        k = self.opt.k
        # Shapes per the original annotations: mask BxkNxM, mask_row_max BxM,
        # min_idx BxkN (node index of each stacked point).
        mask, mask_row_max, min_idx = som.query_topk(node, x, node.size()[2],
                                                     k=k)
        mask_row_sum = torch.sum(mask, dim=1)  # BxM, points per node
        mask_float = mask.unsqueeze(1).float()  # Bx1xkNxM

        # Repeat points/normals k times so they line up with the kN mask rows.
        x_stack = x.repeat(1, 1, k)  # Bx3xkN
        sn_stack = sn.repeat(1, 1, k)  # Bx3xkN

        # Mean of the points assigned to each node; the epsilon guards against
        # division by zero for nodes that received no point.
        cluster_mean = (
            torch.sum(x_stack.unsqueeze(3) * mask_float, dim=2) /
            (mask_row_sum.unsqueeze(1).float() + 1e-5).detach())  # Bx3xM

        B, kN, M = x.size()[0], x_stack.size()[2], cluster_mean.size()[2]

        # Center of the cluster each stacked point belongs to.
        centers = torch.sum(mask_float * cluster_mean.unsqueeze(2),
                            dim=3).detach()  # Bx3xkN
        x_decentered = (x_stack - centers).detach()  # Bx3xkN

        # First PointNet, with surface normals appended when configured.
        if self.opt.surface_normal_len >= 1:
            x_augmented = torch.cat((x_decentered, sn_stack), dim=1)  # Bx6xkN
            first_pn_out = self.first_pointnet(x_augmented, epoch)
        else:
            first_pn_out = self.first_pointnet(x_decentered, epoch)

        # Per-node multiplier applied to pooled features (presumably 0 for
        # nodes with no assigned point - confirm against query_topk).
        node_valid = mask_row_max.unsqueeze(1).float()  # Bx1xM

        # Masked max-pool per node: the CUDA op only yields gather indices,
        # so gradients flow through the differentiable ``gather``.
        with torch.cuda.device(first_pn_out.get_device()):
            first_gather_index = index_max.forward_cuda_shared_mem(
                first_pn_out.detach(), min_idx.int(), M).detach().long()
        first_pn_out_masked_max = first_pn_out.gather(
            dim=2, index=first_gather_index) * node_valid  # BxCxM

        # Copy each node's pooled feature back to the kN points assigned to it.
        scattered_first_masked_max = torch.gather(
            first_pn_out_masked_max,
            dim=2,
            index=min_idx.unsqueeze(1).expand(B, first_pn_out.size()[1], kN),
        )  # BxCxkN
        first_pn_out_fusion = torch.cat(
            (first_pn_out, scattered_first_masked_max), dim=1)  # Bx2CxkN

        # Second PointNet followed by the same masked max-pool per node.
        second_pn_out = self.second_pointnet(first_pn_out_fusion, epoch)
        with torch.cuda.device(second_pn_out.get_device()):
            second_gather_index = index_max.forward_cuda_shared_mem(
                second_pn_out, min_idx.int(), M).detach().long()
        second_pn_out_masked_max = second_pn_out.gather(
            dim=2, index=second_gather_index) * node_valid  # BxCxM

        # kNN aggregation over the node cluster means.
        knn_feature_1 = self.knnlayer_1(query=cluster_mean,
                                        database=cluster_mean,
                                        x=second_pn_out_masked_max,
                                        K=self.opt.node_knn_k_1,
                                        epoch=epoch)
        node_feature_aggregated = torch.cat(
            (second_pn_out_masked_max, knn_feature_1), dim=1)  # Bx(C1+C2)xM

        # Per-node regression: channels 0-2 are the keypoint offset relative
        # to the cluster mean, channel 3 is the raw (unconstrained) sigma.
        y = self.mlp1(node_feature_aggregated)
        point_descriptor = self.mlp2(y)
        keypoint_sigma = self.mlp3(point_descriptor)  # Bx(3+1+...)xM

        keypoints = keypoint_sigma[:, 0:3, :] + cluster_mean  # Bx3xM
        # Softplus keeps sigma positive; the configured lower bound keeps it
        # away from zero.
        sigmas = self.softplus(
            keypoint_sigma[:, 3, :]) + self.opt.loss_sigma_lower_bound  # BxM

        descriptors = None

        return cluster_mean, keypoints, sigmas, descriptors
Example #3
0
                                                           8).long()
    end_t = time.time()
    print('cpu multi thread time: %f' % (end_t - begin_t))

    data_cuda = data.cuda()
    index_cuda = index.cuda()

    begin_t = time.time()
    for i in range(100):
        max_idx_cuda = index_max.forward_cuda(data_cuda, index_cuda, M).long()
    end_t = time.time()
    print('cuda cpp time, 100 times: %f' % (end_t - begin_t))

    begin_t = time.time()
    for i in range(100):
        max_idx_cuda_shared_mem = index_max.forward_cuda_shared_mem(
            data_cuda, index_cuda, M).long()
    end_t = time.time()
    print('cuda cpp shared mem time, 100 times: %f' % (end_t - begin_t))

    mask_max = operations.MaskedMax(M)
    begin_t = time.time()
    for i in range(100):
        max_idx_gt = mask_max.compute(data_cuda, index_cuda, None)
    end_t = time.time()
    print('cuda operations.py time, 100 times: %f' % (end_t - begin_t))

    print(torch.max(max_idx_gt.cpu() - max_idx_single_cpu))
    print(torch.min(max_idx_gt.cpu() - max_idx_single_cpu))

    print(torch.max(max_idx_gt.cpu() - max_idx_multi_cpu))
    print(torch.min(max_idx_gt.cpu() - max_idx_multi_cpu))
Example #4
0
    def forward(self, pc, intensity, sn, node_a, node_b):
        '''Encode a point cloud onto node_a, propagate the features to node_b and pool them.

        :param pc: Bx3xN Tensor
        :param intensity: Bx1xN Tensor
        :param sn: Bx3xN Tensor
        :param node_a: Bx3xMa FloatTensor
        :param node_b: Bx3xMb FloatTensor
        :return: tuple of
            pc_centers: Bx3xN, the cluster mean of each point's nearest node_a.
            cluster_mean: Bx3xMa, the mean of the points assigned to each node_a.
            min_k_idx: BxNxk, the indices of the k nearest node_a per point,
                where k = ``self.opt.k_interp_point_a``.
            first_pn_out: the first PointNet output, over the N points.
            second_pn_out: the second PointNet output, over the N points.
            node_a_features: BxCaxMa, masked max-pool of second_pn_out per node_a.
            node_b_features: BxCbxMb, the ``self.knnlayer`` output for node_b.
            global_feature: BxCgx1, the max over node_b of ``final_pointnet``.
        '''
        B, N, Ma = pc.size(0), pc.size(2), node_a.size(2)

        # -- Assign each point to its nearest node_a --
        # Pairwise distances by broadcasting Bx3xNx1 against Bx3x1xMa.
        diff = torch.norm(pc.unsqueeze(3) - node_a.unsqueeze(2),
                          dim=1, p=2, keepdim=False)  # BxNxMa
        _, min_k_idx = torch.topk(diff,
                                  k=self.opt.k_interp_point_a,
                                  dim=2,
                                  largest=False,
                                  sorted=True)  # BxNxk
        min_idx = min_k_idx[:, :, 0]  # BxN, nearest node_a per point
        # Assignment mask comparing min_idx against self.node_idx_1NMa
        # (presumably holds 0..Ma-1 along its last dim, making this one-hot).
        mask = torch.eq(
            min_idx.unsqueeze(2).expand(B, N, Ma),
            self.node_idx_1NMa.to(device=min_idx.device,
                                  dtype=torch.long).expand(B, N, Ma))  # BxNxMa
        # BxMa, this indicates whether the node has nearby points
        mask_row_max, _ = torch.max(mask, dim=1, keepdim=False)
        mask_row_max_B1Ma_float = mask_row_max.unsqueeze(1).to(
            dtype=torch.float)  # Bx1xMa

        mask_B1NMa_float = mask.unsqueeze(1).to(dtype=torch.float)  # Bx1xNxMa
        mask_row_sum = torch.sum(mask_B1NMa_float, dim=2,
                                 keepdim=False)  # Bx1xMa

        # Cluster centers; the epsilon guards against empty clusters.
        pc_masked = pc.unsqueeze(3) * mask_B1NMa_float  # BxCxNxMa
        cluster_mean = torch.sum(
            pc_masked, dim=2) / (mask_row_sum + 1e-5).detach()  # BxCxMa

        # Center of each point's cluster, then the point relative to it.
        pc_centers = torch.gather(cluster_mean,
                                  index=min_idx.unsqueeze(1).expand(B, 3, N),
                                  dim=2)  # Bx3xN
        pc_decentered = (pc - pc_centers).detach()  # Bx3xN

        # -- First PointNet on decentered xyz + intensity + normals --
        pc_augmented = torch.cat((pc_decentered, intensity, sn),
                                 dim=1)  # Bx7xN
        first_pn_out = self.first_pointnet(pc_augmented)

        # Masked max-pool per node_a: the CUDA op only yields gather indices,
        # so gradients flow through the differentiable ``gather``.
        with torch.cuda.device(first_pn_out.get_device()):
            first_gather_index = index_max.forward_cuda_shared_mem(
                first_pn_out.detach(), min_idx.int(), Ma).detach().long()
        first_pn_out_masked_max = first_pn_out.gather(
            dim=2,
            index=first_gather_index) * mask_row_max_B1Ma_float  # BxCxMa

        # Copy each node's pooled feature back to the N points assigned to it.
        scattered_first_masked_max = torch.gather(
            first_pn_out_masked_max,
            dim=2,
            index=min_idx.unsqueeze(1).expand(B, first_pn_out.size(1),
                                              N))  # BxCxN
        first_pn_out_fusion = torch.cat(
            (first_pn_out, scattered_first_masked_max), dim=1)  # Bx2CxN

        # -- Second PointNet + masked max-pool per node_a --
        second_pn_out = self.second_pointnet(first_pn_out_fusion)
        with torch.cuda.device(second_pn_out.get_device()):
            second_gather_index = index_max.forward_cuda_shared_mem(
                second_pn_out, min_idx.int(), Ma).detach().long()
        node_a_features = second_pn_out.gather(
            dim=2,
            index=second_gather_index) * mask_row_max_B1Ma_float  # BxCaxMa

        # -- Propagate node_a features to node_b via kNN --
        node_b_features = self.knnlayer(
            query=node_b,
            database=cluster_mean,
            database_features=node_a_features,
            K=self.opt.k_ab)  # BxCbxMb

        # -- Global feature: max over node_b of the final PointNet output --
        # Presumably BxCgxMb (one column per node_b), assuming final_pointnet
        # preserves the point dimension.
        final_pn_out = self.final_pointnet(
            torch.cat((node_b, node_b_features), dim=1))
        global_feature, _ = torch.max(final_pn_out, dim=2,
                                      keepdim=True)  # BxCgx1

        return (pc_centers, cluster_mean, min_k_idx, first_pn_out,
                second_pn_out, node_a_features, node_b_features,
                global_feature)