Esempio n. 1
0
 def build_graph_batch(self, y, y_p_crops_list, edge_index):
     """Merge per-image crop nodes and pixel nodes into one batched graph.

     For image ``i`` the node list is ``y_p_crops_list[i]`` followed by one
     node per pixel of ``y[i]`` (H and W flattened row-major). The N graphs
     are combined with ``Batch.from_data_list``.

     :param y: features of the current input images, shape (N, C, H, W)
     :param y_p_crops_list: length-N sequence of previous-crop features; each
         entry is concatenated row-wise with the pixel features, so it must
         have C columns, i.e. shape (n_crops_i, C)
     :param edge_index: length-N sequence of (2, n_edges_i) edge tensors,
         local to each image's node list (``Batch.from_data_list`` shifts
         them to batch-level indices)
     :return:
         graph: torch_geometric.data.Batch object containing the batched graph
         center_inds: 1-D tensor with the batch-level node index of every
             pixel node, image after image
     """
     _, num_channels, _, _ = y.shape
     graphs = []
     pixel_index_chunks = []
     node_start = 0
     for image_feats, crops, edges in zip(y, y_p_crops_list, edge_index):
         # (C, H, W) -> (H*W, C): one node per pixel
         pixel_nodes = image_feats.reshape(num_channels, -1).T.contiguous()
         nodes = torch.cat((crops, pixel_nodes), dim=0)
         graphs.append(gData(x=nodes, edge_index=edges))
         # Pixel nodes come after this image's crop nodes, and this image's
         # nodes come after every node of the preceding images.
         first_pixel = node_start + len(crops)
         pixel_index_chunks.append(first_pixel + torch.arange(len(pixel_nodes)))
         node_start += len(nodes)
     return Batch.from_data_list(graphs), torch.cat(pixel_index_chunks)
Esempio n. 2
0
    def __getitem__(self, idx):
        """Return frames ``idx`` (target) and ``idx + 1`` (source) with their relative pose.

        The relative transform is ``inv(target_pose) @ source_pose``. It is
        encoded as ``[tx, ty, tz, rx, ry, rz]``: the translation column
        (without its last, homogeneous entry) followed by "xyz" Euler angles
        in radians.

        Each cloud becomes a gData with ``pos`` = columns 0-2 and ``x`` =
        column 2 (NOTE(review): presumably the z coordinate -- confirm).
        ``self.transform``, when not None, is applied to both clouds.

        NOTE(review): frame ``idx + 1`` is read, so ``__len__`` is presumably
        one less than the number of frames -- confirm.
        """
        target_cloud = self.dataset.get_velo(idx)
        target_pose = self.dataset.poses[idx]

        source_cloud = self.dataset.get_velo(idx + 1)
        source_pose = self.dataset.poses[idx + 1]

        pose = np.dot(np.linalg.inv(target_pose), source_pose)
        # Rotation.from_dcm was deprecated in SciPy 1.4 and removed in 1.6 in
        # favour of from_matrix; the old name is only used as a fallback.
        from_matrix = getattr(Rotation, "from_matrix", None) or Rotation.from_dcm
        rotation = list(
            from_matrix(pose[:3, :3]).as_euler("xyz", degrees=False))
        translation = list(pose[:, -1])
        # drop the last (homogeneous) entry of the translation column
        pose_vect = translation[:-1] + rotation

        s_data = gData(x=torch.from_numpy(source_cloud[:, 2:3]),
                       pos=torch.from_numpy(source_cloud[:, :3]))
        t_data = gData(x=torch.from_numpy(target_cloud[:, 2:3]),
                       pos=torch.from_numpy(target_cloud[:, :3]))

        if self.transform is not None:
            s_data = self.transform(s_data)
            t_data = self.transform(t_data)

        # pose_vect is 1-D, so the former ``.T`` was a no-op and is dropped
        return PairData(s_data, t_data, torch.from_numpy(np.array(pose_vect)))
Esempio n. 3
0
def main():
    """Run one NNConv forward pass over a synthetic batch of graphs.

    Every graph shares the same 4-node, 8-edge topology but has its own
    random node and edge features. The per-graph ``gData`` objects are merged
    with ``gBatch.from_data_list`` and fed through a single ``gnn.NNConv``
    layer, whose edge network maps each edge's features to a flattened
    (num_in_node_features x num_out_node_features) weight matrix.

    :return: NNConv output of shape
        (batch_size * num_nodes, num_out_node_features) = (64, 64)
    """
    batch_size = 16
    num_nodes = 4
    num_in_node_features = 16
    num_out_node_features = 64
    num_in_edge_features = 4

    # Topology shared by every graph in the batch: 8 directed edges, shape (2, 8)
    edge_index = torch.tensor(
        [[0, 1, 2, 0, 3, 2, 3, 0], [1, 0, 0, 2, 2, 3, 0, 3]], dtype=torch.long)

    # Node features -- shape torch.Size([16, 4, 16])
    batch_x = torch.randn((batch_size, num_nodes, num_in_node_features),
                          dtype=torch.float)

    # Edge features -- shape torch.Size([16, 8, 4])
    batch_edge_attr = torch.randn(
        (batch_size, edge_index.size(1), num_in_edge_features),
        dtype=torch.float)

    # Wrap each graph's node and edge features, together with the shared
    # edge_index, into a gData object and merge them into one batch.
    batch = gBatch.from_data_list([
        gData(x=node_feats, edge_index=edge_index, edge_attr=edge_feats)
        for node_feats, edge_feats in zip(batch_x, batch_edge_attr)
    ])

    # Thus,
    # batch.x          -- shape: torch.Size([64, 16])   (16 graphs * 4 nodes)
    # batch.edge_index -- shape: torch.Size([2, 128])   (16 graphs * 8 edges)
    # batch.edge_attr  -- shape: torch.Size([128, 4])

    # Edge network: edge features -> flattened per-edge weight matrix
    edge_nn = tnn.Sequential(
        tnn.Linear(num_in_edge_features, 25), tnn.ReLU(),
        tnn.Linear(25, num_in_node_features * num_out_node_features))
    gconv = gnn.NNConv(in_channels=num_in_node_features,
                       out_channels=num_out_node_features,
                       nn=edge_nn,
                       aggr='mean')

    # Forward pass -- y has shape torch.Size([64, 64])
    y = gconv(x=batch.x,
              edge_index=batch.edge_index,
              edge_attr=batch.edge_attr)
    return y
Esempio n. 4
0
 def build_graph_sample(self, streams, lengths, gt=None):
     """Build a graph whose nodes are streamline points chained per streamline.

     :param streams: numpy array with the points of all streamlines stacked
         row-wise (one row per point)
     :param lengths: numpy array with the number of points of each
         streamline; expected to sum to ``streams.shape[0]``
     :param gt: accepted for interface compatibility; unused in this version
     :return: gData with ``x`` and ``pos`` set to the points, ``lengths``,
         ``bvec`` (streamline id of each point), a bidirectional
         ``edge_index`` linking consecutive points of the same streamline and
         an all-ones ``edge_attr`` of shape (num_edges, 1)
     """
     lengths = torch.from_numpy(lengths).long()
     batch_vec = torch.arange(len(lengths)).repeat_interleave(lengths)
     streams = torch.from_numpy(streams)
     num_points = streams.shape[0]
     graph_sample = gData(x=streams,
                          lengths=lengths,
                          bvec=batch_vec,
                          pos=streams)
     # Mark the first point of every streamline after the first one. Sized
     # num_points + 1 so trailing empty streamlines (start == num_points)
     # stay in range.
     is_start = torch.zeros(num_points + 1, dtype=torch.bool)
     is_start[lengths.cumsum(dim=0)[:-1]] = True
     # Point i links to i + 1 unless i + 1 starts a new streamline. A
     # vectorized mask keeps src/dst sorted and aligned; the previous
     # set-based version relied on set iteration order, which Python does
     # not guarantee.
     src = torch.arange(max(num_points - 1, 0))[~is_start[1:num_points]]
     dst = src + 1
     edge_index = torch.stack([torch.cat([src, dst]), torch.cat([dst, src])])
     graph_sample['edge_index'] = edge_index
     graph_sample['edge_attr'] = torch.ones(edge_index.size(1), 1)
     return graph_sample
Esempio n. 5
0
    def build_graph_sample(self, streams, lengths, gt=None):
        """Build a graph whose nodes are streamline points chained per streamline.

        :param streams: numpy array with the points of all streamlines stacked
            row-wise (one row per point)
        :param lengths: numpy array with the number of points of each
            streamline; expected to sum to ``streams.shape[0]``
        :param gt: optional target, stored as ``y`` on the graph when not None
        :return: gData with ``x`` and ``pos`` set to the points, ``lengths``
            and ``bvec`` (streamline id of each point). When
            ``self.return_edges`` is truthy it also carries a bidirectional
            ``edge_index`` linking consecutive points of the same streamline
            and an all-ones ``edge_attr`` of shape (num_edges, 1). When
            ``self.distance`` is truthy it is applied to the graph before
            ``y`` is attached.
        """
        lengths = torch.from_numpy(lengths).long()
        batch_vec = torch.arange(len(lengths)).repeat_interleave(lengths)
        streams = torch.from_numpy(streams)
        num_points = streams.shape[0]
        graph_sample = gData(
            x=streams,
            lengths=lengths,
            bvec=batch_vec,
            pos=streams)
        if self.return_edges:
            # Mark the first point of every streamline after the first one.
            # Sized num_points + 1 so trailing empty streamlines
            # (start == num_points) stay in range.
            is_start = torch.zeros(num_points + 1, dtype=torch.bool)
            is_start[lengths.cumsum(dim=0)[:-1]] = True
            # Point i links to i + 1 unless i + 1 starts a new streamline. A
            # vectorized mask keeps src/dst sorted and aligned; the previous
            # set-based version relied on set iteration order, which Python
            # does not guarantee.
            src = torch.arange(max(num_points - 1, 0))[~is_start[1:num_points]]
            dst = src + 1
            edge_index = torch.stack(
                [torch.cat([src, dst]), torch.cat([dst, src])])
            graph_sample['edge_index'] = edge_index
            graph_sample['edge_attr'] = torch.ones(edge_index.size(1), 1)
        if self.distance:
            graph_sample = self.distance(graph_sample)
        if gt is not None:
            graph_sample['y'] = gt

        return graph_sample
    def getitem(self, idx):
        """Load the streamlines selected for subject ``idx`` as a point graph.

        Reads ``<root_dir>/sub-<id>/sub-<id>_var-HCP_full_tract.trk`` (only
        its header is needed for the streamline count) and the pickled labels
        ``sub-<id>_var-HCP_labels.pkl`` next to it.

        When ``self.repeat_sampling`` is not None, candidates come from the
        pool ``self.remaining[idx]``, which is refilled with every streamline
        once it is empty. The streamlines kept by ``self.transform`` are then
        removed from the pool. Otherwise every streamline is a candidate.

        :return: the sample dict, with
            'points': gData with ``x`` (stacked points), ``lengths`` and
                ``bvec`` (streamline id of each point); plus a bidirectional
                ``edge_index`` between consecutive points of each streamline
                when ``self.return_edges``, and ``y`` when ``self.with_gt``;
            'name': tractogram file name without its extension;
            'obj_idxs' / 'obj_full_size': only with repeat_sampling;
            plus whatever else the candidate dict / transform provided
            (e.g. 'gt').
        """
        sub = self.subjects[idx]
        sub_dir = os.path.join(self.root_dir, 'sub-%s' % sub)
        T_file = os.path.join(sub_dir, 'sub-%s_var-HCP_full_tract.trk' % (sub))
        label_file = os.path.join(sub_dir, 'sub-%s_var-HCP_labels.pkl' % (sub))
        T = nib.streamlines.load(T_file, lazy_load=True)
        n_streamlines = T.header['nb_streamlines']
        # NOTE(review): pickle.load can execute arbitrary code; root_dir must
        # only contain trusted data.
        with open(label_file, 'rb') as f:
            gt = pickle.load(f)
        if isinstance(gt, list):
            gt = np.array(gt)

        if self.repeat_sampling is not None:
            if len(self.remaining[idx]) == 0:
                self.remaining[idx] = set(np.arange(n_streamlines))
            candidates = list(self.remaining[idx])
            sample = {'points': np.array(candidates)}
            if self.with_gt:
                sample['gt'] = gt[candidates]
        else:
            sample = {'points': np.arange(n_streamlines), 'gt': gt}

        if self.transform:
            sample = self.transform(sample)

        if self.repeat_sampling is not None:
            # streamlines drawn now are not offered again until the pool empties
            self.remaining[idx] -= set(sample['points'])
            sample['obj_idxs'] = sample['points'].copy()
            sample['obj_full_size'] = n_streamlines

        sample['name'] = T_file.split('/')[-1].rsplit('.', 1)[0]

        streams, lengths = load_selected_streamlines(T_file,
                                                     sample['points'].tolist())

        ### create graph structure
        # repeat_interleave expects int64 repeats
        lengths = torch.from_numpy(lengths).long()
        batch_vec = torch.arange(len(lengths)).repeat_interleave(lengths)
        streams = torch.from_numpy(streams)
        num_points = streams.shape[0]
        graph_sample = gData(x=streams,
                             lengths=lengths,
                             bvec=batch_vec)
        if self.return_edges:
            # Mark the first point of every streamline after the first one.
            # The old code built sets of torch tensors, which hash by
            # identity, so no boundary was ever removed and consecutive
            # streamlines were wrongly linked.
            is_start = torch.zeros(num_points + 1, dtype=torch.bool)
            is_start[lengths.cumsum(dim=0)[:-1]] = True
            # Point i links to i + 1 unless i + 1 starts a new streamline.
            src = torch.arange(max(num_points - 1, 0))[~is_start[1:num_points]]
            dst = src + 1
            graph_sample['edge_index'] = torch.stack(
                [torch.cat([src, dst]), torch.cat([dst, src])])
        if self.with_gt:
            graph_sample['y'] = torch.from_numpy(sample['gt'])
        sample['points'] = graph_sample

        return sample
Esempio n. 7
0
    def build_graph_sample(self, streams, lengths, gt=None):
        """Build a graph whose nodes are streamline points chained per streamline.

        :param streams: numpy array with the points of all streamlines stacked
            row-wise (one row per point)
        :param lengths: numpy array with the number of points of each
            streamline; expected to sum to ``streams.shape[0]``
        :param gt: optional target, stored as ``y`` on the graph when not None
        :return: gData with ``x`` and ``pos`` set to the points, ``lengths``
            and ``bvec`` (streamline id of each point). When
            ``self.return_edges`` is truthy it also carries a bidirectional
            ``edge_index`` linking consecutive points of the same streamline
            and an all-ones ``edge_attr`` of shape (num_edges, 1). When
            ``self.distance`` is truthy it is applied to the graph before
            ``y`` is attached.
        """
        lengths = torch.from_numpy(lengths).long()
        batch_vec = torch.arange(len(lengths)).repeat_interleave(lengths)
        streams = torch.from_numpy(streams)
        num_points = streams.shape[0]
        graph_sample = gData(
            x=streams,
            lengths=lengths,
            bvec=batch_vec,
            pos=streams)
        if self.return_edges:
            # Mark the first point of every streamline after the first one.
            # Sized num_points + 1 so trailing empty streamlines
            # (start == num_points) stay in range.
            is_start = torch.zeros(num_points + 1, dtype=torch.bool)
            is_start[lengths.cumsum(dim=0)[:-1]] = True
            # Point i links to i + 1 unless i + 1 starts a new streamline. A
            # vectorized mask keeps src/dst sorted and aligned; the previous
            # set-based version relied on set iteration order, which Python
            # does not guarantee.
            src = torch.arange(max(num_points - 1, 0))[~is_start[1:num_points]]
            dst = src + 1
            edge_index = torch.stack(
                [torch.cat([src, dst]), torch.cat([dst, src])])
            graph_sample['edge_index'] = edge_index
            graph_sample['edge_attr'] = torch.ones(edge_index.size(1), 1)
        if self.distance:
            graph_sample = self.distance(graph_sample)
        if gt is not None:
            graph_sample['y'] = gt

        return graph_sample