def get_single_event(self, event_idx):
    """Build one nearest-neighbour graph per cluster of a single event.

    For every cluster stored in entry ``event_idx`` of ``self.ev_tree``, the
    cluster's cell IDs are mapped to positions through
    ``self.id_to_position``. A ``FixedRadiusNNGraph`` is then built over the
    cluster's cells. Each graph carries node data ``'x'`` (cell position)
    and ``'en'`` (cell energy, from ``cluster_cell_E``).

    Clusters with fewer than two cells are skipped. The truth energies read
    from ``cluster_ENG_CALIB_TOT`` are filtered with the same cluster
    indices, so ``truth_E[i]`` always belongs to ``gr[i]``.

    Args:
        event_idx: Index of the entry to read from ``self.ev_tree``.

    Returns:
        dict with keys:
          ``'gr'``: list of graphs, one per kept cluster;
          ``'truth_E'``: tensor of the matching ``cluster_ENG_CALIB_TOT``
          values.
        If the event yields no graph at all (no clusters, or every cluster
        has fewer than two cells), the placeholder
        ``{'gr': [dgl.rand_graph(2, 1)], 'truth_E': torch.tensor([-1.])}``
        is returned.
    """
    def read_branch(name):
        # Read the single entry `event_idx` of branch `name` as numpy.
        return self.ev_tree[name].array(
            entry_start=event_idx, entry_stop=event_idx + 1, library='np')[0]

    def placeholder():
        # Marker returned for events that produce no usable cluster graph.
        return {
            'gr': [dgl.rand_graph(2, 1)],
            'truth_E': torch.tensor([-1.])
        }

    # ------- building the cluster graph ---------- #
    cluster_cell_ID = read_branch('cluster_cell_ID')
    cluster_cell_E = read_branch('cluster_cell_E')

    n_clusters = len(cluster_cell_ID)
    if n_clusters == 0:
        return placeholder()

    graph_list = []
    kept_clusters = []  # indices of clusters that produced a graph
    # ---- loop over clusters ---- #
    for ic in range(n_clusters):
        cell_E = np.array(cluster_cell_E[ic])
        cell_idx = np.array(cluster_cell_ID[ic])
        cluster_cell_pos = torch.tensor(
            [self.id_to_position[x] for x in cell_idx])
        # Add a leading dimension of size 1: (1, n_cells, n_coords).
        cluster_cell_pos = torch.reshape(
            cluster_cell_pos,
            (1, cluster_cell_pos.shape[0], cluster_cell_pos.shape[1]))

        n_part = len(cluster_cell_pos[0])
        if n_part < 2:
            continue

        # Never request more neighbours than the cluster has cells.
        graph_frn = FixedRadiusNNGraph(
            radius=self.R, n_neighbor=min(n_part, self.n_neighbor))
        # Sample as many centroids as there are cells (presumably every
        # cell becomes a centroid — confirm against FarthestPointSampler).
        fps = FarthestPointSampler(n_part)
        centroids = fps(cluster_cell_pos)
        gr_frn = graph_frn(cluster_cell_pos, centroids)
        gr_frn.ndata['x'] = cluster_cell_pos[0]
        gr_frn.ndata['en'] = torch.tensor(cell_E)
        graph_list.append(gr_frn)
        kept_clusters.append(ic)

    if not graph_list:
        # Every cluster was too small; treat it like an event without
        # clusters rather than returning an empty graph list.
        return placeholder()

    # -------- #
    cluster_energy_truth = np.asarray(read_branch('cluster_ENG_CALIB_TOT'))
    # ---------------------------------------------------------------- #
    # Keep only the truth values of clusters that produced a graph so the
    # two lists stay aligned.
    return {
        'gr': graph_list,
        'truth_E': torch.tensor(cluster_energy_truth[kept_clusters])
    }
def __init__(self, npoints, batch_size, radius, mlp_sizes, n_neighbor=64,
             group_all=False):
    """Create the sampling, grouping, message and convolution stages.

    Args:
        npoints: Number of points passed to ``FarthestPointSampler``.
        batch_size: Stored on the module and forwarded to ``PointNetConv``.
        radius: Radius handed to ``FixedRadiusNNGraph``.
        mlp_sizes: Layer sizes handed to ``PointNetConv``.
        n_neighbor: Neighbour count for ``FixedRadiusNNGraph`` and
            ``RelativePositionMessage``.
        group_all: When true, only the convolution stage is created; the
            sampler, graph builder and message stages are omitted.
    """
    super(SAModule, self).__init__()
    self.group_all = group_all

    if not self.group_all:
        # Submodules are registered in this order: fps, frnn_graph,
        # message, then conv below.
        self.fps = FarthestPointSampler(npoints)
        self.frnn_graph = FixedRadiusNNGraph(radius, n_neighbor)
        self.message = RelativePositionMessage(n_neighbor)

    self.conv = PointNetConv(mlp_sizes, batch_size)
    self.batch_size = batch_size
def test_fps():
    """FarthestPointSampler returns in-range, distinct indices per batch."""
    N = 1000
    batch_size = 5
    sample_points = 10
    n_per_batch = N // batch_size
    x = th.tensor(np.random.uniform(size=(batch_size, n_per_batch, 3)))
    ctx = F.ctx()
    if F.gpu_ctx():
        x = x.to(ctx)
    fps = FarthestPointSampler(sample_points)
    res = fps(x)
    assert res.shape[0] == batch_size
    assert res.shape[1] == sample_points
    assert res.sum() > 0

    # Every sampled index must address a point of its own batch entry.
    res_cpu = res.cpu()
    assert int(res_cpu.min()) >= 0
    assert int(res_cpu.max()) < n_per_batch
    # With distinct random points, farthest-point sampling should never pick
    # the same point twice within one batch entry.
    for row in res_cpu.tolist():
        assert len(set(row)) == sample_points
def __init__(self, npoints, batch_size, radius_list, n_neighbor_list,
             mlp_sizes_list):
    """Create one graph/message/conv stage per entry of ``radius_list``.

    Args:
        npoints: Number of points passed to the shared
            ``FarthestPointSampler``.
        batch_size: Stored on the module and forwarded to each
            ``PointNetConv``.
        radius_list: Radius for each ``FixedRadiusNNGraph``.
        n_neighbor_list: Neighbour count for each ``FixedRadiusNNGraph`` /
            ``RelativePositionMessage`` pair.
        mlp_sizes_list: Layer sizes for each ``PointNetConv``.

    Raises:
        ValueError: If ``n_neighbor_list`` or ``mlp_sizes_list`` does not
            have the same length as ``radius_list``.
    """
    super(SAMSGModule, self).__init__()

    self.batch_size = batch_size
    self.group_size = len(radius_list)
    # The three lists describe the same groups element-wise; a mismatch
    # would otherwise fail mid-construction or silently drop entries.
    if (len(n_neighbor_list) != self.group_size
            or len(mlp_sizes_list) != self.group_size):
        raise ValueError(
            f"radius_list, n_neighbor_list and mlp_sizes_list must have the "
            f"same length (got {len(radius_list)}, {len(n_neighbor_list)}, "
            f"{len(mlp_sizes_list)})")

    self.fps = FarthestPointSampler(npoints)
    self.frnn_graph_list = nn.ModuleList()
    self.message_list = nn.ModuleList()
    self.conv_list = nn.ModuleList()
    for radius, n_neighbor, mlp_sizes in zip(
            radius_list, n_neighbor_list, mlp_sizes_list):
        self.frnn_graph_list.append(FixedRadiusNNGraph(radius, n_neighbor))
        self.message_list.append(RelativePositionMessage(n_neighbor))
        self.conv_list.append(PointNetConv(mlp_sizes, batch_size))