def build_graph_batch(self, y, y_p_crops_list, edge_index):
    """
    Assemble one graph per image and merge them into a single batched graph.

    For every image the node matrix is the previous crops followed by the
    image's own spatial positions (one node per (h, w) location, C features
    each).

    :param y: features of the current input images, shape (N, C, H, W)
    :param y_p_crops_list: length-N sequence of previous-crop tensors, each of
        shape (n_crops_i, C) so that it can be stacked on top of the pixel
        nodes
    :param edge_index: length-N sequence of tensors of shape (2, n_edges_i),
        indexed locally within each graph
    :return:
        graph: torch_geometric Batch built from the N per-image graphs
        center_inds: 1-D tensor with the positions, inside the batched node
            matrix, of the nodes that come from ``y`` (all images concatenated)
    """
    _, channels, _, _ = y.shape

    graphs = []
    center_chunks = []
    node_offset = 0
    for image_feats, prev_crops, edges in zip(y, y_p_crops_list, edge_index):
        # (C, H, W) -> (H*W, C): one row per spatial location.
        pixel_nodes = image_feats.reshape(channels, -1).T
        num_prev = len(prev_crops)

        nodes = torch.cat((prev_crops, pixel_nodes.contiguous()), dim=0)
        graphs.append(gData(x=nodes, edge_index=edges))

        # Pixel nodes sit after this graph's crops, shifted by all nodes of
        # the graphs that precede it in the batch.
        center_chunks.append(
            node_offset + num_prev + torch.arange(len(pixel_nodes)))
        node_offset += len(nodes)

    return Batch.from_data_list(graphs), torch.cat(center_chunks)
def __getitem__(self, idx):
    """
    Return scans ``idx`` (target) and ``idx + 1`` (source) with their
    relative pose.

    The relative pose is ``inv(target_pose) @ source_pose``. It is flattened
    to a 6-vector: the translation taken from the last column (without its
    final entry) followed by the rotation as "xyz" Euler angles in radians.

    :param idx: index of the target scan; ``idx + 1`` must also be valid for
        ``self.dataset``
    :return: ``PairData(source_graph, target_graph, pose_vector)`` where each
        graph uses column 2 of the cloud as feature ``x`` and columns 0-2 as
        ``pos``; ``self.transform`` (if set) is applied to both graphs
    """
    target_cloud = self.dataset.get_velo(idx)
    target_pose = self.dataset.poses[idx]
    source_cloud = self.dataset.get_velo(idx + 1)
    source_pose = self.dataset.poses[idx + 1]

    # NOTE(review): poses are assumed to be 4x4 homogeneous matrices - confirm.
    pose = np.dot(np.linalg.inv(target_pose), source_pose)

    # `Rotation.from_dcm` was renamed to `from_matrix` in SciPy 1.4 and
    # removed in SciPy 1.6; prefer the new name, fall back for old releases.
    to_rotation = getattr(Rotation, "from_matrix", None) or Rotation.from_dcm
    euler_xyz = to_rotation(pose[:3, :3]).as_euler("xyz", degrees=False)

    # Last column minus its final entry is the translation part.
    translation = pose[:-1, -1]
    pose_vect = np.concatenate([translation, euler_xyz])

    s_data = gData(x=torch.from_numpy(source_cloud[:, 2:3]),
                   pos=torch.from_numpy(source_cloud[:, :3]))
    t_data = gData(x=torch.from_numpy(target_cloud[:, 2:3]),
                   pos=torch.from_numpy(target_cloud[:, :3]))

    if self.transform is not None:
        s_data = self.transform(s_data)
        t_data = self.transform(t_data)

    return PairData(s_data, t_data, torch.from_numpy(pose_vect))
def main():
    """
    Demo: run one ``NNConv`` forward pass over a batch of graphs that all
    share the same topology but carry different random node/edge features.
    """
    batch_size = 16
    num_nodes = 4
    num_in_node_features = 16
    num_out_node_features = 64
    num_in_edge_features = 4

    # Single topology reused by every graph in the batch: 8 directed edges
    # over 4 nodes, stored as (source row, target row).
    edge_index = torch.tensor(
        [[0, 1, 2, 0, 3, 2, 3, 0],
         [1, 0, 0, 2, 2, 3, 0, 3]], dtype=torch.long)
    num_edges = edge_index.size(1)

    # Node features: (batch_size, num_nodes, num_in_node_features) = (16, 4, 16)
    batch_x = torch.randn((batch_size, num_nodes, num_in_node_features),
                          dtype=torch.float)

    # Edge features: (batch_size, num_edges, num_in_edge_features) = (16, 8, 4)
    batch_edge_attr = torch.randn(
        (batch_size, num_edges, num_in_edge_features), dtype=torch.float)

    # Wrap each (node features, edge features) pair together with the shared
    # edge_index, then collate them into one `torch_geometric.data.Batch`.
    graphs = [
        gData(x=x, edge_index=edge_index, edge_attr=edge_attr)
        for x, edge_attr in zip(batch_x, batch_edge_attr)
    ]
    batch = gBatch.from_data_list(graphs)
    # After collation:
    #   batch.x          -- (batch_size * num_nodes, 16) = (64, 16)
    #   batch.edge_index -- (2, batch_size * num_edges)  = (2, 128)
    #   batch.edge_attr  -- (batch_size * num_edges, 4)  = (128, 4)

    # NNConv needs a network mapping each edge's features to a flattened
    # (in_channels x out_channels) weight matrix.
    edge_network = tnn.Sequential(
        tnn.Linear(num_in_edge_features, 25),
        tnn.ReLU(),
        tnn.Linear(25, num_in_node_features * num_out_node_features))
    gconv = gnn.NNConv(in_channels=num_in_node_features,
                       out_channels=num_out_node_features,
                       nn=edge_network,
                       aggr='mean')

    # Forward pass; the output has one row of num_out_node_features values
    # per node in the batch, i.e. shape (64, 64).
    gconv(x=batch.x,
          edge_index=batch.edge_index,
          edge_attr=batch.edge_attr)
def build_graph_sample(self, streams, lengths, gt=None):
    """
    Turn a set of streamlines into one graph whose nodes are the points and
    whose edges link consecutive points of the same streamline.

    :param streams: numpy array with all points stacked along axis 0,
        streamline after streamline (assumed: ``streams.shape[0]`` equals
        ``lengths.sum()`` - confirm against the loader)
    :param lengths: numpy integer array, number of points of each streamline
    :param gt: accepted for signature compatibility; not used here
    :return: graph with ``x`` and ``pos`` (both the points tensor),
        ``lengths``, ``bvec`` (streamline id of every point), ``edge_index``
        of shape (2, 2 * E) holding every link in both directions, and
        ``edge_attr`` of ones with shape (2 * E, 1)
    """
    lengths = torch.from_numpy(lengths).long()
    # Streamline id per point, e.g. lengths [2, 3] -> [0, 0, 1, 1, 1].
    batch_vec = torch.arange(len(lengths)).repeat_interleave(lengths)
    # Index of the first point of every streamline except the first one.
    slices = lengths.cumsum(dim=0)[:-1]

    streams = torch.from_numpy(streams)
    n_points = streams.shape[0]
    graph_sample = gData(x=streams,
                         lengths=lengths,
                         bvec=batch_vec,
                         pos=streams)

    # A point i is linked to i + 1 unless i is the last point of its
    # streamline. `setdiff1d` returns sorted indices, so sources and targets
    # are paired explicitly instead of relying on set iteration order.
    src = np.setdiff1d(np.arange(n_points - 1, dtype=np.int64),
                       slices.numpy() - 1)
    dst = src + 1
    edges = torch.from_numpy(
        np.stack([np.concatenate([src, dst]),
                  np.concatenate([dst, src])]))
    graph_sample['edge_index'] = edges

    num_edges = graph_sample.num_edges
    graph_sample['edge_attr'] = torch.ones(num_edges, 1)
    return graph_sample
def build_graph_sample(self, streams, lengths, gt=None):
    """
    Turn a set of streamlines into one graph whose nodes are the points.

    :param streams: numpy array with all points stacked along axis 0,
        streamline after streamline (assumed: ``streams.shape[0]`` equals
        ``lengths.sum()`` - confirm against the loader)
    :param lengths: numpy integer array, number of points of each streamline
    :param gt: optional labels; stored unchanged as ``y`` when not None
    :return: graph with ``x`` and ``pos`` (both the points tensor),
        ``lengths`` and ``bvec`` (streamline id of every point). When
        ``self.return_edges`` is truthy it also has ``edge_index`` of shape
        (2, 2 * E), linking consecutive points of a streamline in both
        directions, and ``edge_attr`` of ones with shape (2 * E, 1). When
        ``self.distance`` is truthy the graph is passed through it.
    """
    lengths = torch.from_numpy(lengths).long()
    # Streamline id per point, e.g. lengths [2, 3] -> [0, 0, 1, 1, 1].
    batch_vec = torch.arange(len(lengths)).repeat_interleave(lengths)
    # Index of the first point of every streamline except the first one.
    slices = lengths.cumsum(dim=0)[:-1]

    streams = torch.from_numpy(streams)
    n_points = streams.shape[0]
    graph_sample = gData(x=streams,
                         lengths=lengths,
                         bvec=batch_vec,
                         pos=streams)

    if self.return_edges:
        # A point i is linked to i + 1 unless i is the last point of its
        # streamline. `setdiff1d` returns sorted indices, so sources and
        # targets are paired explicitly instead of relying on set
        # iteration order.
        src = np.setdiff1d(np.arange(n_points - 1, dtype=np.int64),
                           slices.numpy() - 1)
        dst = src + 1
        edges = torch.from_numpy(
            np.stack([np.concatenate([src, dst]),
                      np.concatenate([dst, src])]))
        graph_sample['edge_index'] = edges

        num_edges = graph_sample.num_edges
        graph_sample['edge_attr'] = torch.ones(num_edges, 1)

    # NOTE(review): applied regardless of `return_edges`; confirm the
    # transform tolerates a graph without `edge_index`.
    if self.distance:
        graph_sample = self.distance(graph_sample)

    if gt is not None:
        graph_sample['y'] = gt
    return graph_sample
def getitem(self, idx):
    """
    Load the (optionally sub-sampled) streamlines of subject ``idx`` and
    return them as a graph sample.

    :param idx: position of the subject in ``self.subjects``
    :return: dict with
        ``points``: graph whose nodes are the streamline points (``x``),
            with ``lengths``, ``bvec`` (streamline id of every point),
            ``edge_index`` when ``self.return_edges`` is truthy and ``y``
            when ``self.with_gt`` is truthy
        ``obj_idxs``: indices of the selected streamlines
        ``obj_full_size``: number of streamlines in the tractogram header
        ``name``: tractogram file name without directory and extension
        plus whatever else ``self.transform`` leaves in the sample
        (e.g. ``gt``)
    """
    sub = self.subjects[idx]
    sub_dir = os.path.join(self.root_dir, 'sub-%s' % sub)
    T_file = os.path.join(sub_dir, 'sub-%s_var-HCP_full_tract.trk' % sub)
    label_file = os.path.join(sub_dir, 'sub-%s_var-HCP_labels.pkl' % sub)

    T = nib.streamlines.load(T_file, lazy_load=True)
    n_streamlines = T.header['nb_streamlines']

    # NOTE(security): pickle can execute arbitrary code while loading; only
    # point this at trusted label files.
    with open(label_file, 'rb') as f:
        gt = pickle.load(f)
    if isinstance(gt, list):
        gt = np.array(gt)

    if self.repeat_sampling is not None:
        # Streamlines are drawn without replacement across calls; refill the
        # pool once every streamline of this subject has been used.
        if len(self.remaining[idx]) == 0:
            self.remaining[idx] = set(np.arange(n_streamlines))
        pool = list(self.remaining[idx])
        sample = {'points': np.array(pool)}
        if self.with_gt:
            sample['gt'] = gt[pool]
    else:
        sample = {'points': np.arange(n_streamlines), 'gt': gt}

    if self.transform:
        sample = self.transform(sample)

    if self.repeat_sampling is not None:
        # Whatever the transform kept has now been consumed.
        self.remaining[idx] -= set(sample['points'])

    sample['obj_idxs'] = sample['points'].copy()
    sample['obj_full_size'] = n_streamlines
    sample['name'] = os.path.splitext(os.path.basename(T_file))[0]

    selected = sample['points'].tolist()
    n = len(sample['points'])

    uniform_size = False
    if uniform_size:
        streams, l_max = load_selected_streamlines_uniform_size(
            T_file, selected)
        # reshape returns a new array, so the result has to be kept.
        streams = streams.reshape(n, l_max, -1)
        sample['points'] = torch.from_numpy(streams)
    else:
        streams, lengths = load_selected_streamlines(T_file, selected)

        lengths = torch.from_numpy(lengths).long()
        # Streamline id per point, e.g. lengths [2, 3] -> [0, 0, 1, 1, 1].
        batch_vec = torch.arange(len(lengths)).repeat_interleave(lengths)
        # Index of the first point of every streamline except the first one.
        slices = lengths.cumsum(dim=0)[:-1]

        streams = torch.from_numpy(streams)
        n_points = streams.shape[0]
        graph_sample = gData(x=streams, lengths=lengths, bvec=batch_vec)

        if self.return_edges:
            # A point i is linked to i + 1 unless i is the last point of its
            # streamline. The boundaries are converted to NumPy first: a
            # Python set of 0-d tensors hashes by identity, so it would
            # never match the integer indices and no boundary would be
            # removed.
            src = np.setdiff1d(np.arange(n_points - 1, dtype=np.int64),
                               slices.numpy() - 1)
            dst = src + 1
            edges = torch.from_numpy(
                np.stack([np.concatenate([src, dst]),
                          np.concatenate([dst, src])]))
            graph_sample['edge_index'] = edges

        if self.with_gt:
            graph_sample['y'] = torch.from_numpy(sample['gt'])
        sample['points'] = graph_sample

    # NOTE(review): debug output kept as-is; consider moving it to logging.
    print(len(sample['points']))
    return sample
def build_graph_sample(self, streams, lengths, gt=None):
    """
    Turn a set of streamlines into one graph whose nodes are the points.

    Edges, when requested, only connect each point to its immediate
    neighbours along the same streamline.

    :param streams: numpy array with all points stacked along axis 0,
        streamline after streamline (assumed: ``streams.shape[0]`` equals
        ``lengths.sum()`` - confirm against the loader)
    :param lengths: numpy integer array, number of points of each streamline
    :param gt: optional labels; stored unchanged as ``y`` when not None
    :return: graph with ``x`` and ``pos`` (both the points tensor),
        ``lengths`` and ``bvec`` (streamline id of every point). When
        ``self.return_edges`` is truthy it also has ``edge_index`` of shape
        (2, 2 * E), holding every link in both directions, and ``edge_attr``
        of ones with shape (2 * E, 1). When ``self.distance`` is truthy the
        graph is passed through it.
    """
    lengths = torch.from_numpy(lengths).long()
    # Streamline id per point, e.g. lengths [2, 3] -> [0, 0, 1, 1, 1].
    batch_vec = torch.arange(len(lengths)).repeat_interleave(lengths)
    # Index of the first point of every streamline except the first one.
    slices = lengths.cumsum(dim=0)[:-1]

    streams = torch.from_numpy(streams)
    n_points = streams.shape[0]
    graph_sample = gData(x=streams,
                         lengths=lengths,
                         bvec=batch_vec,
                         pos=streams)

    if self.return_edges:
        # A point i is linked to i + 1 unless i is the last point of its
        # streamline. `setdiff1d` returns sorted indices, so sources and
        # targets are paired explicitly instead of relying on set
        # iteration order.
        src = np.setdiff1d(np.arange(n_points - 1, dtype=np.int64),
                           slices.numpy() - 1)
        dst = src + 1
        edges = torch.from_numpy(
            np.stack([np.concatenate([src, dst]),
                      np.concatenate([dst, src])]))
        graph_sample['edge_index'] = edges

        num_edges = graph_sample.num_edges
        graph_sample['edge_attr'] = torch.ones(num_edges, 1)

    # NOTE(review): applied regardless of `return_edges`; confirm the
    # transform tolerates a graph without `edge_index`.
    if self.distance:
        graph_sample = self.distance(graph_sample)

    if gt is not None:
        graph_sample['y'] = gt
    return graph_sample