def forward(self, data):
    """
    Predict edge weights on the primary/secondary bipartite cluster graph.

    Inputs:
        data[0]: dbscan data tensor; its ``is_cuda`` flag decides where the
                 empty fallback tensor is placed.
        data[1]: primary data, passed to ``assign_primaries``.
    Returns:
        The edge tensor ``e`` produced by ``self.edge_predictor``, or an
        empty tensor (requires_grad=True) when no cluster survives the
        Compton filter.
    """
    def _empty():
        # Placeholder output. Tensor.cuda() is not in-place, so the result
        # must be rebound for the tensor to actually land on the GPU.
        e = torch.tensor([], requires_grad=True)
        if data[0].is_cuda:
            e = e.cuda()
        return e

    # Form clusters from the dbscan data.
    clusts = form_clusters_new(data[0])

    # Remove Compton-like clusters. Check both the raw selection and the
    # filtered result: if `selection` is a boolean mask, its length is the
    # number of clusters rather than the number kept, so an all-False mask
    # would otherwise slip through with zero clusters left.
    selection = filter_compton(clusts)
    if not len(selection):
        return _empty()
    clusts = clusts[selection]
    if not len(clusts):
        return _empty()

    # Form the primary/secondary bipartite graph.
    primaries = assign_primaries(data[1], clusts, data[0])
    batch = get_cluster_batch(data[0], clusts)
    edge_index = primary_bipartite_incidence(batch, primaries, cuda=True)

    # Vertex and edge features (the helpers are called with cuda=True).
    x = cluster_vtx_features(data[0], clusts, cuda=True)
    e = cluster_edge_features(data[0], clusts, edge_index, cuda=True)

    # Apply the three attention layers in sequence.
    x = self.attn1(x, edge_index)
    x = self.attn2(x, edge_index)
    x = self.attn3(x, edge_index)

    # Run the edge predictor; only its edge output is returned.
    xbatch = torch.tensor(batch).cuda()
    _, e, _ = self.edge_predictor(x, edge_index, e, u=None, batch=xbatch)
    return e
def forward(self, data):
    """
    Run node prediction on a complete graph built over the event's clusters.

    Input:
        data[0]: (Nx5) cluster tensor with rows (x, y, z, batch_id, cluster_id)
    Output:
        dict: every entry returned by ``self.node_predictor``, plus
        'clust_ids', 'batch_ids' and 'edge_index', each wrapped in a list.
        Returns ``self.default_return(device)`` instead when the Compton
        filter leaves no cluster or the graph has no edges.
    """
    labels = data[0]
    dev = labels.device

    # Index sets of the points belonging to each EM cluster.
    clusters = form_clusters_new(labels)

    # Optionally drop clusters that fail the Compton (size) threshold.
    if self.remove_compton:
        keep = np.where(filter_compton(clusters, self.compton_thresh))[0]
        if len(keep) == 0:
            return self.default_return(dev)
        clusters = clusters[keep]

    # Per-cluster ids and batch ids.
    cluster_ids = get_cluster_label(labels, clusters)
    cluster_batches = get_cluster_batch(labels, clusters)

    # Complete graph built from the cluster batch ids
    # (TODO: support other graph structures).
    graph = complete_graph(cluster_batches, device=dev)
    if graph.shape[0] == 0:
        return self.default_return(dev)

    node_feats = cluster_vtx_features(labels, clusters, device=dev)
    edge_feats = cluster_edge_features(labels, clusters, graph, device=dev)
    batch_tensor = torch.tensor(cluster_batches).to(dev)

    prediction = self.node_predictor(node_feats, graph, edge_feats, batch_tensor)
    extras = {
        'clust_ids': [torch.tensor(cluster_ids)],
        'batch_ids': [torch.tensor(cluster_batches)],
        'edge_index': [graph],
    }
    return {**prediction, **extras}
def forward(self, data):
    """
    Predict edge weights on a complete graph of the event's clusters.

    Inputs:
        data[0]: dbscan data tensor; its ``device`` is used for the graph,
                 the features and any empty fallback output.
    Output:
        dict: the value returned by ``self.edge_predictor``. When there is
        nothing to predict, ``{'edge_pred': [e]}`` where ``e`` is an empty
        tensor (requires_grad=True) on the input device.
    """
    device = data[0].device

    def _empty():
        # Create the placeholder directly on `device`: Tensor.to() is not
        # in-place, so calling it without rebinding leaves the tensor on CPU.
        e = torch.tensor([], requires_grad=True, device=device)
        return {'edge_pred': [e]}

    # Form clusters from the dbscan data.
    clusts = form_clusters_new(data[0])

    # Optionally remove Compton-like clusters. Check both the raw selection
    # and the filtered result: if `selection` is a boolean mask, its length
    # is the number of clusters, not the number kept.
    if self.remove_compton:
        selection = filter_compton(clusts, self.compton_thresh)
        if not len(selection):
            return _empty()
        clusts = clusts[selection]
        if not len(clusts):
            return _empty()

    # Build the graph. Assumes complete_graph returns a tensor whose first
    # dimension is 0 when there are no edges -- TODO confirm.
    batch = get_cluster_batch(data[0], clusts)
    edge_index = complete_graph(batch, device=device)
    if not edge_index.shape[0]:
        return _empty()

    # Vertex and edge features.
    x = cluster_vtx_features(data[0], clusts, device=device)
    e = cluster_edge_features(data[0], clusts, edge_index, device=device)

    xbatch = torch.tensor(batch).to(device)
    return self.edge_predictor(x, edge_index, e, xbatch)
def forward(self, data):
    """
    Iteratively group clusters with primaries through edge prediction.

    Input data:
        data[0]: dbscan data tensor.
        data[1]: primary data, passed to ``assign_primaries``.
    Output data: dictionary with the following keys, each value wrapped in a
    one-element list:
        edges     : list of edge_index tensors used for edge prediction,
                    one per iteration
        edge_pred : list of edge prediction tensors, one per iteration
        matched   : tensor with the group of each cluster (identified by
                    primary index), -1 if the cluster was never matched
        counter   : one-element tensor with the number of iterations taken
    If no cluster survives the Compton filter, the same keys are returned
    with empty lists, an empty ``matched`` tensor and ``counter`` = 0.
    """
    def _pack(edges, edge_pred, matched_np, n_iter):
        # Convert results to tensors (moved to GPU if the input is on GPU)
        # and wrap them in the output dictionary layout.
        matched_t = torch.tensor(matched_np)
        counter_t = torch.tensor([n_iter])
        if data[0].is_cuda:
            matched_t = matched_t.cuda()
            counter_t = counter_t.cuda()
        return {
            'edges': [edges],
            'edge_pred': [edge_pred],
            'matched': [matched_t],
            'counter': [counter_t],
        }

    # Form clusters from the dbscan data.
    clusts = form_clusters_new(data[0])

    # Optionally remove Compton-like clusters. Check both the raw selection
    # and the filtered result: if `selection` is a boolean mask, its length
    # is the number of clusters, not the number kept.
    if self.remove_compton:
        selection = filter_compton(clusts, self.compton_thresh)
        if not len(selection):
            return _pack([], [], np.repeat(-1, 0), 0)
        clusts = clusts[selection]
        if not len(clusts):
            return _pack([], [], np.repeat(-1, 0), 0)

    batch = get_cluster_batch(data[0], clusts)
    xbatch = torch.tensor(batch).cuda()

    primaries = assign_primaries(data[1], clusts, data[0], max_dist=self.pmd)

    # Track who is matched: -1 means unmatched; primaries start in their own
    # group.
    matched = np.repeat(-1, len(clusts))
    matched[primaries] = primaries

    edges = []
    edge_pred = []
    counter = 0
    found_match = True
    # Vertex features depend only on data[0] and clusts, which never change
    # inside the loop; compute them once, on first use.
    x = None

    # Continue until either:
    #   1. everything is matched,
    #   2. we have reached the maximum number of iterations, or
    #   3. the last iteration did not find any match.
    while (-1 in matched) and (counter < self.maxiter) and found_match:
        counter += 1

        assigned = np.where(matched > -1)[0]
        others = np.where(matched == -1)[0]
        edge_index = primary_bipartite_incidence(batch, assigned, cuda=True)

        # Check that there are edges to predict. Batch norm also fails on a
        # single edge, so stop in that case. This iteration produced no
        # prediction, so it is not counted.
        if edge_index.shape[1] < 2:
            counter -= 1
            break

        if x is None:
            x = cluster_vtx_features(data[0], clusts, cuda=True)
        e = cluster_edge_features(data[0], clusts, edge_index, cuda=True)

        out = self.edge_predictor(x, edge_index, e, xbatch)
        pred = out[0][0]  # predictions for this edge set
        edge_pred.append(pred)
        edges.append(edge_index)

        # Per-edge score passed to assign_clusters: column 1 minus column 0.
        matched, found_match = self.assign_clusters(
            edge_index, pred[:, 1] - pred[:, 0], others, matched, self.thresh)

    return _pack(edges, edge_pred, matched, counter)