示例#1
0
    def get_single_event(self, event_idx):
        """Build one point-cloud graph per cluster of a single event.

        Reads the ``cluster_cell_ID``, ``cluster_cell_E`` and
        ``cluster_ENG_CALIB_TOT`` branches of ``self.ev_tree`` for entry
        ``event_idx``. For every cluster with at least two cells, cell
        positions are looked up in ``self.id_to_position`` and a
        ``FixedRadiusNNGraph`` (radius ``self.R``, at most
        ``self.n_neighbor`` neighbours) is built over them. Node data
        ``'x'`` (cell positions) and ``'en'`` (cell energies) are attached.

        Args:
            event_idx: Index of the entry to read from ``self.ev_tree``.

        Returns:
            dict with keys
              ``'gr'``: list of graphs, one per kept cluster;
              ``'truth_E'``: tensor of ``cluster_ENG_CALIB_TOT`` values, one
              per graph in ``'gr'`` and in the same order (clusters skipped
              for having fewer than two cells are dropped from both).
            If no cluster yields a graph, a placeholder
            ``{'gr': [dgl.rand_graph(2, 1)], 'truth_E': tensor([-1.])}`` is
            returned instead.
        """

        def read_branch(name):
            # Read one entry of a branch as numpy and strip the outer
            # length-1 entry dimension.
            return self.ev_tree[name].array(
                entry_start=event_idx, entry_stop=event_idx + 1,
                library='np')[0]

        def placeholder():
            # Sentinel for events without usable clusters; presumably callers
            # filter on truth_E == -1 (NOTE(review): confirm with caller).
            return {
                'gr': [dgl.rand_graph(2, 1)],
                'truth_E': torch.tensor([-1.])
            }

        # ------- building the cluster graph ---------- #
        cluster_cell_ID = read_branch('cluster_cell_ID')
        cluster_cell_E = read_branch('cluster_cell_E')

        n_clusters = len(cluster_cell_ID)
        if n_clusters == 0:
            return placeholder()

        graph_list = []
        kept_clusters = []  # indices of clusters that produced a graph
        # ---- loop over clusters ---- #
        for ic in range(n_clusters):
            cell_idx = np.array(cluster_cell_ID[ic])
            n_part = len(cell_idx)

            # A graph needs at least two nodes. Checking before building the
            # position tensor also avoids a reshape failure on 0-cell clusters.
            if n_part < 2:
                continue

            cell_E = np.array(cluster_cell_E[ic])

            cluster_cell_pos = torch.tensor(
                [self.id_to_position[x] for x in cell_idx])
            # Add a leading batch dimension of size 1.
            cluster_cell_pos = torch.reshape(
                cluster_cell_pos,
                (1, cluster_cell_pos.shape[0], cluster_cell_pos.shape[1]))

            # Never request more neighbours than the cluster has cells.
            graph_frn = FixedRadiusNNGraph(
                radius=self.R, n_neighbor=min(n_part, self.n_neighbor))

            # Sample as many centroids as there are cells.
            fps = FarthestPointSampler(n_part)
            centroids = fps(cluster_cell_pos)

            gr_frn = graph_frn(cluster_cell_pos, centroids)
            gr_frn.ndata['x'] = cluster_cell_pos[0]
            gr_frn.ndata['en'] = torch.tensor(cell_E)

            graph_list.append(gr_frn)
            kept_clusters.append(ic)
        # -------- #

        if not graph_list:
            # Every cluster was too small to form a graph.
            return placeholder()

        cluster_energy_truth = np.asarray(
            read_branch('cluster_ENG_CALIB_TOT'))
        # ---------------------------------------------------------------- #
        # Keep truth energies aligned with graph_list by dropping the entries
        # of clusters skipped above.
        return {
            'gr': graph_list,
            'truth_E': torch.tensor(cluster_energy_truth[kept_clusters])
        }
 def __init__(self, npoints, batch_size, radius, mlp_sizes, n_neighbor=64,
              group_all=False):
     """Set up the set-abstraction layer.

     Args:
         npoints: Number of points passed to ``FarthestPointSampler``
             (only used when ``group_all`` is false).
         batch_size: Stored on the module and passed to ``PointNetConv``.
         radius: Radius for ``FixedRadiusNNGraph`` (only used when
             ``group_all`` is false).
         mlp_sizes: Layer sizes passed to ``PointNetConv``.
         n_neighbor: Neighbour count for the graph builder and
             ``RelativePositionMessage``.
         group_all: When true, no sampler or radius-graph submodule is
             created.
     """
     super().__init__()
     self.group_all = group_all
     # Sampling/grouping submodules exist only in the non-group-all mode;
     # registration order of submodules is kept as before.
     if not self.group_all:
         self.fps = FarthestPointSampler(npoints)
         self.frnn_graph = FixedRadiusNNGraph(radius, n_neighbor)
     self.message = RelativePositionMessage(n_neighbor)
     self.conv = PointNetConv(mlp_sizes, batch_size)
     self.batch_size = batch_size
示例#3
0
def test_fps():
    """FarthestPointSampler yields ``sample_points`` indices per batch."""
    total_points = 1000
    batch_size = 5
    sample_points = 10
    points_per_batch = total_points // batch_size
    coords = np.random.uniform(size=(batch_size, points_per_batch, 3))
    x = th.tensor(coords)
    ctx = F.ctx()
    if F.gpu_ctx():
        x = x.to(ctx)
    res = FarthestPointSampler(sample_points)(x)
    # Leading dims must be (batch, samples); indices are not all zero.
    assert tuple(res.shape[:2]) == (batch_size, sample_points)
    assert res.sum() > 0
    def __init__(self, npoints, batch_size, radius_list, n_neighbor_list, mlp_sizes_list):
        """Set up a multi-scale set-abstraction layer.

        One grouping branch is created per entry of ``radius_list``; branch
        ``i`` uses ``radius_list[i]``, ``n_neighbor_list[i]`` and
        ``mlp_sizes_list[i]`` (the latter two must be at least as long as
        ``radius_list``, otherwise IndexError is raised).

        Args:
            npoints: Number of points passed to the shared ``FarthestPointSampler``.
            batch_size: Stored on the module and passed to each ``PointNetConv``.
            radius_list: Radius per branch for ``FixedRadiusNNGraph``.
            n_neighbor_list: Neighbour count per branch.
            mlp_sizes_list: ``PointNetConv`` layer sizes per branch.
        """
        super().__init__()
        self.batch_size = batch_size
        self.group_size = len(radius_list)

        self.fps = FarthestPointSampler(npoints)
        self.frnn_graph_list = nn.ModuleList()
        self.message_list = nn.ModuleList()
        self.conv_list = nn.ModuleList()
        for i, radius in enumerate(radius_list):
            neighbors = n_neighbor_list[i]
            self.frnn_graph_list.append(FixedRadiusNNGraph(radius, neighbors))
            self.message_list.append(RelativePositionMessage(neighbors))
            self.conv_list.append(PointNetConv(mlp_sizes_list[i], batch_size))