def forward(self, data):
    """Run the node-prediction GNN on the clusters found in ``data[0]``.

    Args:
        data: Sequence whose first element is the cluster tensor, one row
            per point laid out as (x, y, z, batch_id, cluster_id)
            (presumably shape (N, 5) — TODO confirm).

    Returns:
        dict: every entry returned by ``self.node_predictor`` (expected to
        include ``'node_pred'``), plus ``'clust_ids'``, ``'batch_ids'`` and
        ``'edge_index'``, each wrapped in a single-element list. If no
        cluster survives the Compton filter, or the graph has no edges,
        ``self.default_return(device)`` is returned instead.
    """
    cluster_label = data[0]
    device = cluster_label.device

    # Group the points into clusters.
    clusts = form_clusters_new(cluster_label)

    # Optionally keep only the clusters accepted by filter_compton
    # (threshold self.compton_thresh).
    if self.remove_compton:
        selection = np.where(filter_compton(clusts, self.compton_thresh))[0]
        if not len(selection):
            return self.default_return(device)
        clusts = clusts[selection]

    # Per-cluster label and batch id.
    clust_ids = get_cluster_label(cluster_label, clusts)
    batch_ids = get_cluster_batch(cluster_label, clusts)

    # Build a complete graph over the clusters.
    # TODO: support other graph structures.
    edge_index = complete_graph(batch_ids, device=device)
    # The edge index appears to be laid out as (2, n_edges) (the loss reads
    # the edge count from shape[1]), so shape[0] is always 2 and cannot
    # detect an empty graph. numel() is zero exactly when there are no
    # edges, whatever the layout.
    if edge_index.numel() == 0:
        return self.default_return(device)

    # Node and edge features.
    x = cluster_vtx_features(cluster_label, clusts, device=device)
    e = cluster_edge_features(cluster_label, clusts, edge_index, device=device)

    # Batch ids as a tensor on the same device as the features.
    xbatch = torch.tensor(batch_ids).to(device)

    out = self.node_predictor(x, edge_index, e, xbatch)

    return {
        **out,
        'clust_ids': [torch.tensor(clust_ids)],
        'batch_ids': [torch.tensor(batch_ids)],
        'edge_index': [edge_index],
    }
def forward(self, data):
    """Run the edge-prediction GNN on the clusters found in ``data[0]``.

    Args:
        data: Sequence whose first element is the DBSCAN cluster tensor.

    Returns:
        The output of ``self.edge_predictor`` (expected to contain
        ``'edge_pred'``). If no cluster survives the Compton filter, or the
        graph has no edges, returns ``{'edge_pred': [e]}`` where ``e`` is an
        empty tensor with ``requires_grad=True`` on the input's device.
    """
    device = data[0].device

    # Form the clusters that become graph nodes.
    clusts = form_clusters_new(data[0])

    if self.remove_compton:
        selection = filter_compton(clusts, self.compton_thresh)
        # Index first, then test the result. If filter_compton returns a
        # boolean mask, len(selection) is the total number of clusters and
        # never signals that every cluster was rejected. len(clusts) after
        # indexing is correct for both a mask and an index array.
        clusts = clusts[selection]
        if not len(clusts):
            # Tensor.to() is not in-place; create the tensor on the target
            # device directly.
            return {'edge_pred': [torch.tensor([], requires_grad=True, device=device)]}

    # Complete graph over the clusters.
    batch = get_cluster_batch(data[0], clusts)
    edge_index = complete_graph(batch, device=device)
    # numel() is zero exactly when there are no edges. shape[0] is the
    # constant 2 for a (2, n_edges) edge index.
    if edge_index.numel() == 0:
        return {'edge_pred': [torch.tensor([], requires_grad=True, device=device)]}

    # Node and edge direction features.
    x = cluster_vtx_dirs(data[0], clusts, device=device)
    e = cluster_edge_dirs(data[0], clusts, edge_index, device=device)

    xbatch = torch.tensor(batch).to(device)

    return self.edge_predictor(x, edge_index, e, xbatch)
def forward(self, data):
    """Run the legacy edge-conv GNN on the clusters found in ``data[0]``.

    Args:
        data: Sequence whose first element is the DBSCAN cluster tensor.

    Returns:
        ``F.log_softmax`` over dim 1 of the node output of
        ``self.predictor``, or an empty tensor (``requires_grad=True``, on
        the input's device) when no cluster survives the Compton filter or
        the graph has no edges.
    """
    device = data[0].device
    use_cuda = data[0].is_cuda

    # Form the clusters that become graph nodes.
    clusts = form_clusters_new(data[0])

    # Remove Compton-like clusters (open question from the original: should
    # we?). Index first, then check emptiness, so the test is correct whether
    # filter_compton returns a boolean mask or an index array.
    selection = filter_compton(clusts)
    clusts = clusts[selection]
    if not len(clusts):
        # Tensor.cuda() is not in-place; create the tensor on the input's
        # device directly.
        return torch.tensor([], requires_grad=True, device=device)

    # Complete graph over the clusters. Follow the input's device instead of
    # forcing CUDA, so CPU inputs do not yield GPU graph tensors.
    batch = get_cluster_batch(data[0], clusts)
    edge_index = complete_graph(batch, cuda=use_cuda)
    # numel() is zero exactly when there are no edges. len() of a
    # (2, n_edges) index is always 2.
    if edge_index.numel() == 0:
        return torch.tensor([], requires_grad=True, device=device)

    batch = torch.tensor(batch, device=device)

    # Node features (legacy feature extractor).
    x = cluster_vtx_features_old(data[0], clusts, cuda=use_cuda)

    # Three edge-convolution layers, then the predictor.
    x = self.econv1(x, edge_index)
    x = self.econv2(x, edge_index)
    x = self.econv3(x, edge_index)

    x, e, u = self.predictor(x, edge_index, edge_attr=None, u=None, batch=batch)

    return F.log_softmax(x, dim=1)
def forward(self, out, data0, data1):
    """Compute the edge-clustering loss and clustering metrics.

    Args:
        out: Model output. ``out[0][0]`` is read as the predicted edge
            weights. NOTE(review): an earlier docstring described a dict
            with an ``'edge_pred'`` key, but the ``[0]`` indexing implies a
            sequence here; confirm against the caller.
        data0: Sequence whose first element is the DBSCAN cluster tensor.
        data1: Sequence whose first element is the group-label tensor.

    Returns:
        dict with ``'loss'`` and ``'accuracy'``. On the normal path it also
        contains ``'ARI'``, ``'AMI'``, ``'SBD'``, ``'purity'``,
        ``'efficiency'`` and ``'edge_count'``; ``'accuracy'`` is the ARI.
        ``'accuracy'`` is 1.0 when every cluster is removed by the Compton
        filter and 0.0 when the graph has no edges.
    """
    edge_pred = out[0][0]
    data0 = data0[0]
    data1 = data1[0]
    device = data0.device

    # Rebuild the same clusters/graph the model saw, to derive true edges.
    clusts = form_clusters_new(data0)

    if self.remove_compton:
        selection = filter_compton(clusts, self.compton_thresh)
        # Index first, then test the result. If filter_compton returns a
        # boolean mask, len(selection) is the total number of clusters and
        # never signals that every cluster was rejected.
        clusts = clusts[selection]
        if not len(clusts):
            # NOTE(review): lossfn(edge_pred, edge_pred) appears to be a way
            # to obtain a loss tied to edge_pred; confirm lossfn accepts it.
            total_loss = self.lossfn(edge_pred, edge_pred)
            return {'accuracy': 1., 'loss': total_loss}

    # Complete graph over the clusters.
    batch = get_cluster_batch(data0, clusts)
    edge_index = complete_graph(batch, device=device)
    # numel() is zero exactly when there are no edges. shape[0] is the
    # constant 2 for a (2, n_edges) edge index, as implied by the edge count
    # being read from shape[1] below.
    if edge_index.numel() == 0:
        total_loss = self.lossfn(edge_pred, edge_pred)
        return {'accuracy': 0., 'loss': total_loss}

    # Group label of each cluster, taken directly from data1.
    group = get_cluster_label(data1, clusts)

    # True edge assignments, flattened to one entry per edge.
    edge_assn = edge_assignment(edge_index, batch, group, device=device, dtype=torch.long)
    edge_assn = edge_assn.view(-1)

    total_loss = self.lossfn(edge_pred, edge_assn)

    # Edge score = prediction index 1 minus index 0; assign_clusters_UF
    # groups the clusters using these scores with thresh=0.0.
    # NOTE(review): this indexes ROWS 0/1 of edge_pred. If edge_pred is
    # (n_edges, 2), as a cross-entropy-style lossfn would expect,
    # edge_pred[:, 1] - edge_pred[:, 0] may be intended; confirm.
    fe = edge_pred[1, :] - edge_pred[0, :]
    cs = assign_clusters_UF(edge_index, fe, len(clusts), thresh=0.0)

    ari, ami, sbd, pur, eff = DBSCAN_cluster_metrics2(cs, clusts, group)
    edge_ct = edge_index.shape[1]

    return {
        'ARI': ari,
        'AMI': ami,
        'SBD': sbd,
        'purity': pur,
        'efficiency': eff,
        'accuracy': ari,
        'loss': total_loss,
        'edge_count': edge_ct,
    }