# Shared imports for the sampler variants collected below.
import copy
import os.path as osp
from typing import Optional

import torch
import torch.multiprocessing as mp
from torch import Tensor
from torch.multiprocessing import Process, Queue
from torch_scatter import scatter
from torch_sparse import SparseTensor
from tqdm import tqdm

from torch_geometric.data import Data, NeighborSampler


class RandomNodeSampler(torch.utils.data.DataLoader):
    r"""A data loader that randomly samples nodes within a graph and returns
    their induced subgraph.

    .. note::

        For an example of using :obj:`RandomNodeSampler`, see
        `examples/ogbn_proteins_deepgcn.py
        <https://github.com/rusty1s/pytorch_geometric/blob/master/examples/
        ogbn_proteins_deepgcn.py>`_.

    Args:
        data (torch_geometric.data.Data): The graph data object.
        num_parts (int): The number of partitions.
        shuffle (bool, optional): If set to :obj:`True`, the data is
            reshuffled at every epoch. (default: :obj:`False`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`num_workers`.
    """
    def __init__(self, data, num_parts: int, shuffle: bool = False, **kwargs):
        assert data.edge_index is not None

        self.N = N = data.num_nodes
        self.E = data.num_edges

        # Adjacency matrix whose values are the original edge IDs, so that
        # subgraph edges can be mapped back to edge-level attributes.
        self.adj = SparseTensor(
            row=data.edge_index[0], col=data.edge_index[1],
            value=torch.arange(self.E, device=data.edge_index.device),
            sparse_sizes=(N, N))

        self.data = copy.copy(data)
        self.data.edge_index = None

        super(RandomNodeSampler, self).__init__(
            self, batch_size=1,
            sampler=RandomIndexSampler(self.N, num_parts, shuffle),
            collate_fn=self.__collate__, **kwargs)

    def __getitem__(self, idx):
        return idx

    def __collate__(self, node_idx):
        node_idx = node_idx[0]

        data = self.data.__class__()
        data.num_nodes = node_idx.size(0)
        adj, _ = self.adj.saint_subgraph(node_idx)
        row, col, edge_idx = adj.coo()
        data.edge_index = torch.stack([row, col], dim=0)

        for key, item in self.data:
            if isinstance(item, Tensor) and item.size(0) == self.N:
                data[key] = item[node_idx]
            elif isinstance(item, Tensor) and item.size(0) == self.E:
                data[key] = item[edge_idx]
            else:
                data[key] = item

        return data
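
# `RandomIndexSampler` is referenced above but not shown in this collection.
# A minimal sketch of what it presumably looks like, assuming it assigns
# each node to one of `num_parts` random buckets and yields one bucket of
# node indices per step (re-drawn every epoch when `shuffle=True`):
class RandomIndexSampler(torch.utils.data.Sampler):
    def __init__(self, num_nodes: int, num_parts: int, shuffle: bool = False):
        self.N = num_nodes
        self.num_parts = num_parts
        self.shuffle = shuffle
        self.n_ids = self.get_node_indices()

    def get_node_indices(self):
        # Draw a random partition ID for every node ...
        n_id = torch.randint(self.num_parts, (self.N, ), dtype=torch.long)
        # ... and group the node indices by partition.
        return [(n_id == i).nonzero(as_tuple=False).view(-1)
                for i in range(self.num_parts)]

    def __iter__(self):
        if self.shuffle:
            self.n_ids = self.get_node_indices()
        return iter(self.n_ids)

    def __len__(self):
        return self.num_parts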
def _filter(data, node_idx):
    """Induce the subgraph of :obj:`data` on :obj:`node_idx`.

    Assumes :obj:`node_idx` is sorted.
    """
    new_data = Data()
    N = data.num_nodes
    E = data.edge_index.size(1)

    adj = SparseTensor(row=data.edge_index[0], col=data.edge_index[1],
                       value=data.edge_attr, sparse_sizes=(N, N))
    new_adj, edge_idx = adj.saint_subgraph(node_idx)
    row, col, value = new_adj.coo()

    for key, item in data:
        # Guard against non-tensor attributes (e.g., ints or strings).
        if isinstance(item, Tensor) and item.size(0) == N:
            new_data[key] = item[node_idx]
        elif isinstance(item, Tensor) and item.size(0) == E:
            new_data[key] = item[edge_idx]
        else:
            new_data[key] = item

    new_data.edge_index = torch.stack([row, col], dim=0)
    new_data.num_nodes = len(node_idx)
    new_data.edge_attr = value
    return new_data
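
# Usage sketch for `_filter` on a toy 4-node graph; every tensor here is
# made up for illustration.
edge_index = torch.tensor([[0, 1, 1, 2, 3],
                           [1, 0, 2, 3, 0]])
toy = Data(x=torch.randn(4, 16), edge_index=edge_index,
           edge_attr=torch.arange(5, dtype=torch.float))

sub = _filter(toy, torch.tensor([0, 1, 2]))  # keep nodes {0, 1, 2}
print(sub.num_nodes)   # 3
print(sub.edge_index)  # the 3 surviving edges, re-indexed to the subgraph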
# A second RandomNodeSampler variant, identical in behavior but documented
# with short :param:-style docs instead of the full docstring above.
class RandomNodeSampler(torch.utils.data.DataLoader):
    def __init__(self, data, num_parts: int, shuffle: bool = False, **kwargs):
        """Randomly partition the nodes of a graph and iterate over the
        induced subgraphs.

        :param data: the graph data object (torch_geometric.data.Data)
        :param num_parts: the number of partitions (batches per epoch)
        :param shuffle: if True, re-partition the nodes at every epoch
        """
        assert data.edge_index is not None

        self.N = N = data.num_nodes
        self.E = data.num_edges

        self.adj = SparseTensor(
            row=data.edge_index[0], col=data.edge_index[1],
            value=torch.arange(self.E, device=data.edge_index.device),
            sparse_sizes=(N, N))

        self.data = copy.copy(data)
        self.data.edge_index = None

        super(RandomNodeSampler, self).__init__(
            self, batch_size=1,
            sampler=RandomIndexSampler(self.N, num_parts, shuffle),
            collate_fn=self.__collate__, **kwargs)

    def __getitem__(self, idx):
        return idx

    def __collate__(self, node_idx):
        node_idx = node_idx[0]

        data = self.data.__class__()
        data.num_nodes = node_idx.size(0)
        adj, _ = self.adj.saint_subgraph(node_idx)
        row, col, edge_idx = adj.coo()
        data.edge_index = torch.stack([row, col], dim=0)

        for key, item in self.data:
            # Guard against non-tensor attributes (e.g., ints or strings).
            if isinstance(item, Tensor) and item.size(0) == self.N:
                data[key] = item[node_idx]
            elif isinstance(item, Tensor) and item.size(0) == self.E:
                data[key] = item[edge_idx]
            else:
                data[key] = item

        return data
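
# Minimal usage sketch for RandomNodeSampler; the graph here is a random
# placeholder.
edge_index = torch.randint(0, 100, (2, 400))
graph = Data(x=torch.randn(100, 32), edge_index=edge_index)

loader = RandomNodeSampler(graph, num_parts=4, shuffle=True)
for batch in loader:
    # Each batch is the subgraph induced on one random node partition.
    print(batch.num_nodes, batch.edge_index.size(1))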
# A standalone GraphSAINT sampler that manages its own worker processes
# (`Process`/`Queue` come from torch.multiprocessing, imported above).
class GraphSAINTSampler(object):
    r"""The GraphSAINT sampler base class from the `"GraphSAINT: Graph
    Sampling Based Inductive Learning Method"
    <https://arxiv.org/abs/1907.04931>`_ paper.
    Given a graph in a :obj:`data` object, this class samples nodes and
    constructs subgraphs that can be processed in a mini-batch fashion.
    Normalization coefficients for each mini-batch are given via
    :obj:`node_norm` and :obj:`edge_norm` data attributes.

    .. note::

        See :class:`torch_geometric.data.GraphSAINTNodeSampler`,
        :class:`torch_geometric.data.GraphSAINTEdgeSampler` and
        :class:`torch_geometric.data.GraphSAINTRandomWalkSampler` for
        currently supported samplers.
        For an example of using GraphSAINT sampling, see
        `examples/graph_saint.py
        <https://github.com/rusty1s/pytorch_geometric/blob/master/examples/
        graph_saint.py>`_.

    Args:
        data (torch_geometric.data.Data): The graph data object.
        batch_size (int): The approximate number of samples per batch to
            load.
        num_steps (int, optional): The number of iterations.
            (default: :obj:`1`)
        sample_coverage (int): How many samples per node should be used to
            compute normalization statistics. (default: :obj:`50`)
        save_dir (string, optional): If set, will save normalization
            statistics to the :obj:`save_dir` directory for faster re-use.
            (default: :obj:`None`)
        num_workers (int, optional): How many subprocesses to use for data
            sampling. :obj:`0` means that the data will be sampled in the
            main process. (default: :obj:`0`)
    """
    def __init__(self, data, batch_size, num_steps=1, sample_coverage=50,
                 save_dir=None, num_workers=0):
        assert data.edge_index is not None
        assert 'node_norm' not in data
        assert 'edge_norm' not in data

        self.N = N = data.num_nodes
        self.E = data.num_edges

        self.adj = SparseTensor(row=data.edge_index[0],
                                col=data.edge_index[1], value=data.edge_attr,
                                sparse_sizes=(N, N))

        self.data = copy.copy(data)
        self.data.edge_index = None
        self.data.edge_attr = None

        self.batch_size = batch_size
        self.num_steps = num_steps
        self.sample_coverage = sample_coverage
        self.num_workers = num_workers
        self.__count__ = 0

        if self.num_workers > 0:
            self.__sample_queue__ = Queue()
            self.__sample_workers__ = []
            for _ in range(self.num_workers):
                worker = Process(target=self.__put_sample__,
                                 args=(self.__sample_queue__, ))
                worker.daemon = True
                worker.start()
                self.__sample_workers__.append(worker)

        path = osp.join(save_dir or '', self.__filename__)
        if save_dir is not None and osp.exists(path):
            self.node_norm, self.edge_norm = torch.load(path)
        else:
            self.node_norm, self.edge_norm = self.__compute_norm__()
            if save_dir is not None:
                torch.save((self.node_norm, self.edge_norm), path)

        if self.num_workers > 0:
            self.__data_queue__ = Queue()
            self.__data_workers__ = []
            for _ in range(self.num_workers):
                worker = Process(target=self.__put_data__,
                                 args=(self.__data_queue__, ))
                worker.daemon = True
                worker.start()
                self.__data_workers__.append(worker)

    @property
    def __filename__(self):
        return f'{self.__class__.__name__.lower()}_{self.sample_coverage}.pt'

    def __sample_nodes__(self, num_examples):
        raise NotImplementedError

    def __sample__(self, num_examples):
        node_samples = self.__sample_nodes__(num_examples)

        samples = []
        for node_idx in node_samples:
            node_idx = node_idx.unique()
            adj, edge_idx = self.adj.saint_subgraph(node_idx)
            samples.append((node_idx, edge_idx, adj))

        return samples

    def __compute_norm__(self):
        node_count = torch.zeros(self.N, dtype=torch.float)
        edge_count = torch.zeros(self.E, dtype=torch.float)

        num_samples = total_sampled_nodes = 0
        pbar = tqdm(total=self.N * self.sample_coverage)
        pbar.set_description('Compute GraphSAINT normalization')

        # Draw samples (in chunks of 200) until every node has been seen
        # `sample_coverage` times on average.
        while total_sampled_nodes < self.N * self.sample_coverage:
            num_sampled_nodes = 0
            if self.num_workers > 0:
                for _ in range(200):
                    node_idx, edge_idx, _ = self.__sample_queue__.get()
                    node_count[node_idx] += 1
                    edge_count[edge_idx] += 1
                    num_sampled_nodes += node_idx.size(0)
            else:
                samples = self.__sample__(200)
                for node_idx, edge_idx, _ in samples:
                    node_count[node_idx] += 1
                    edge_count[edge_idx] += 1
                    num_sampled_nodes += node_idx.size(0)
            pbar.update(num_sampled_nodes)
            total_sampled_nodes += num_sampled_nodes
            num_samples += 200
        pbar.close()

        row, col, _ = self.adj.coo()

        edge_norm = (node_count[col] / edge_count).clamp_(0, 1e4)
        edge_norm[torch.isnan(edge_norm)] = 0.1

        node_count[node_count == 0] = 0.1
        node_norm = num_samples / node_count / self.N

        return node_norm, edge_norm

    def __get_data_from_sample__(self, sample):
        node_idx, edge_idx, adj = sample

        data = self.data.__class__()
        data.num_nodes = node_idx.size(0)
        row, col, value = adj.coo()
        data.edge_index = torch.stack([row, col], dim=0)
        data.edge_attr = value

        for key, item in self.data:
            # Guard against non-tensor attributes (e.g., ints or strings).
            if isinstance(item, Tensor) and item.size(0) == self.N:
                data[key] = item[node_idx]
            elif isinstance(item, Tensor) and item.size(0) == self.E:
                data[key] = item[edge_idx]
            else:
                data[key] = item

        data.node_norm = self.node_norm[node_idx]
        data.edge_norm = self.edge_norm[edge_idx]

        return data

    def __put_sample__(self, queue):
        while True:
            sample = self.__sample__(1)[0]
            queue.put(sample)

    def __put_data__(self, queue):
        while True:
            sample = self.__sample_queue__.get()
            data = self.__get_data_from_sample__(sample)
            queue.put(data)

    def __next__(self):
        if self.__count__ < len(self):
            self.__count__ += 1
            if self.num_workers > 0:
                data = self.__data_queue__.get()
            else:
                sample = self.__sample__(1)[0]
                data = self.__get_data_from_sample__(sample)
            return data
        else:
            raise StopIteration

    def __len__(self):
        return self.num_steps

    def __iter__(self):
        self.__count__ = 0
        return self
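
# `__sample_nodes__` is the only abstract piece of the base class above. A
# sketch of a uniform node sampler under this interface; the class name
# mirrors the samplers referenced in the docstring, not their exact code.
class GraphSAINTNodeSampler(GraphSAINTSampler):
    def __sample_nodes__(self, num_examples):
        # One LongTensor of root nodes per requested sample.
        return [torch.randint(0, self.N, (self.batch_size, ),
                              dtype=torch.long)
                for _ in range(num_examples)]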
class GraphSAINTSampler(torch.utils.data.DataLoader):
    r"""The GraphSAINT sampler base class from the `"GraphSAINT: Graph
    Sampling Based Inductive Learning Method"
    <https://arxiv.org/abs/1907.04931>`_ paper.
    Given a graph in a :obj:`data` object, this class samples nodes and
    constructs subgraphs that can be processed in a mini-batch fashion.
    Normalization coefficients for each mini-batch are given via
    :obj:`node_norm` and :obj:`edge_norm` data attributes.

    .. note::

        See :class:`torch_geometric.data.GraphSAINTNodeSampler`,
        :class:`torch_geometric.data.GraphSAINTEdgeSampler` and
        :class:`torch_geometric.data.GraphSAINTRandomWalkSampler` for
        currently supported samplers.
        For an example of using GraphSAINT sampling, see
        `examples/graph_saint.py
        <https://github.com/rusty1s/pytorch_geometric/blob/master/examples/
        graph_saint.py>`_.

    Args:
        data (torch_geometric.data.Data): The graph data object.
        batch_size (int): The approximate number of samples per batch.
        num_steps (int, optional): The number of iterations per epoch.
            (default: :obj:`1`)
        sample_coverage (int): How many samples per node should be used to
            compute normalization statistics. (default: :obj:`0`)
        save_dir (string, optional): If set, will save normalization
            statistics to the :obj:`save_dir` directory for faster re-use.
            (default: :obj:`None`)
        log (bool, optional): If set to :obj:`False`, will not log any
            pre-processing progress. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`
            or :obj:`num_workers`.
    """
    def __init__(self, data, batch_size: int, num_steps: int = 1,
                 sample_coverage: int = 0, save_dir: Optional[str] = None,
                 log: bool = True, **kwargs):

        assert data.edge_index is not None
        assert 'node_norm' not in data
        assert 'edge_norm' not in data

        self.num_steps = num_steps
        self.__batch_size__ = batch_size
        self.sample_coverage = sample_coverage
        self.log = log

        self.N = N = data.num_nodes
        self.E = data.num_edges

        self.adj = SparseTensor(
            row=data.edge_index[0], col=data.edge_index[1],
            value=torch.arange(self.E, device=data.edge_index.device),
            sparse_sizes=(N, N))

        self.data = copy.copy(data)
        self.data.edge_index = None

        super(GraphSAINTSampler,
              self).__init__(self, batch_size=1, collate_fn=self.__collate__,
                             **kwargs)

        if self.sample_coverage > 0:
            path = osp.join(save_dir or '', self.__filename__)
            if save_dir is not None and osp.exists(path):  # pragma: no cover
                self.node_norm, self.edge_norm = torch.load(path)
            else:
                self.node_norm, self.edge_norm = self.__compute_norm__()
                if save_dir is not None:  # pragma: no cover
                    torch.save((self.node_norm, self.edge_norm), path)

    @property
    def __filename__(self):
        return f'{self.__class__.__name__.lower()}_{self.sample_coverage}.pt'

    def __len__(self):
        return self.num_steps

    def __sample_nodes__(self, batch_size):
        raise NotImplementedError

    def __getitem__(self, idx):
        node_idx = self.__sample_nodes__(self.__batch_size__).unique()
        adj, _ = self.adj.saint_subgraph(node_idx)
        return node_idx, adj

    def __collate__(self, data_list):
        assert len(data_list) == 1
        node_idx, adj = data_list[0]

        data = self.data.__class__()
        data.num_nodes = node_idx.size(0)
        row, col, edge_idx = adj.coo()
        data.edge_index = torch.stack([row, col], dim=0)

        for key, item in self.data:
            if isinstance(item, torch.Tensor) and item.size(0) == self.N:
                data[key] = item[node_idx]
            elif isinstance(item, torch.Tensor) and item.size(0) == self.E:
                data[key] = item[edge_idx]
            else:
                data[key] = item

        if self.sample_coverage > 0:
            data.node_norm = self.node_norm[node_idx]
            data.edge_norm = self.edge_norm[edge_idx]

        return data

    def __compute_norm__(self):
        node_count = torch.zeros(self.N, dtype=torch.float)
        edge_count = torch.zeros(self.E, dtype=torch.float)

        loader = torch.utils.data.DataLoader(self, batch_size=200,
                                             collate_fn=lambda x: x,
                                             num_workers=self.num_workers)

        if self.log:  # pragma: no cover
            pbar = tqdm(total=self.N * self.sample_coverage)
            pbar.set_description('Compute GraphSAINT normalization')

        num_samples = total_sampled_nodes = 0
        while total_sampled_nodes < self.N * self.sample_coverage:
            for data in loader:
                for node_idx, adj in data:
                    edge_idx = adj.storage.value()
                    node_count[node_idx] += 1
                    edge_count[edge_idx] += 1
                    total_sampled_nodes += node_idx.size(0)

                    if self.log:  # pragma: no cover
                        pbar.update(node_idx.size(0))
            num_samples += self.num_steps

        if self.log:  # pragma: no cover
            pbar.close()

        row, _, edge_idx = self.adj.coo()
        t = torch.empty_like(edge_count).scatter_(0, edge_idx,
                                                  node_count[row])
        edge_norm = (t / edge_count).clamp_(0, 1e4)
        edge_norm[torch.isnan(edge_norm)] = 0.1

        node_count[node_count == 0] = 0.1
        node_norm = num_samples / node_count / self.N

        return node_norm, edge_norm
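
# Sketch of a training step that consumes the normalization coefficients,
# following the GraphSAINT paper: `node_norm` de-biases the per-node loss
# and `edge_norm` weights message aggregation. `model`, `optimizer` and
# `loader` (a subclass instance with sample_coverage > 0) are placeholders.
import torch.nn.functional as F

for batch in loader:
    optimizer.zero_grad()
    out = model(batch.x, batch.edge_index, batch.edge_norm)  # edge weights
    loss = F.nll_loss(out, batch.y, reduction='none')
    loss = (loss * batch.node_norm).sum()  # unbiased mini-batch estimate
    loss.backward()
    optimizer.step()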
class Pruner(object):
    r"""Iteratively prunes training edges, either adaptively (by the loss
    difference across an edge) or uniformly at random."""
    def __init__(self, edge_index, split_idx, prune_set='train', ratio=0.9):
        self.N = N = int(edge_index.max() + 1)
        self.E = edge_index.size(1)
        self.edge_index = edge_index
        self.adj = SparseTensor(row=edge_index[0], col=edge_index[1],
                                value=torch.arange(edge_index.size(1)),
                                sparse_sizes=(N, N), is_sorted=False)

        if prune_set == 'train':
            self.train_idx = split_idx['train']
        else:
            self.train_idx = torch.cat(
                [split_idx['train'], split_idx['valid']])

        # Edges whose endpoints both lie in the pruning set:
        subadj, _ = self.adj.saint_subgraph(self.train_idx)
        _, _, e_idx = subadj.coo()
        self.train_e_idx = e_idx.squeeze().long()
        self.train_edge_index = self.edge_index[:, self.train_e_idx]
        self.rest_e_idx = torch.LongTensor(
            list(set(range(self.E)) - set(self.train_e_idx.tolist())))

        self.ratio = ratio
        self.loss = torch.zeros(N)
        self.times = 0

    def prune(self, naive=False):
        i = self.times

        # Adaptive criterion: keep the `ratio` fraction of edges with the
        # smallest loss difference between their endpoints.
        diff_loss = torch.abs(self.loss[self.train_edge_index[0]] -
                              self.loss[self.train_edge_index[1]])
        _, mask1 = torch.topk(diff_loss, int(len(diff_loss) * self.ratio),
                              largest=False)

        # Naive criterion: keep a uniformly random subset of edges.
        newE = self.train_edge_index.size(1)
        mask2 = torch.randperm(newE)[:int(newE * self.ratio)]

        torch.save(self.train_edge_index,
                   f'./savept/pre_prune_edges_{self.ratio ** i:.4f}.pt')
        torch.save(
            mask1,
            f'./savept/smart_mask_prune_edges_{self.ratio ** i:.4f}.pt')
        torch.save(
            mask2,
            f'./savept/naive_mask_prune_edges_{self.ratio ** i:.4f}.pt')
        torch.save(self.loss, f'./savept/loss_{self.ratio ** i:.4f}.pt')
        self.times += 1

        mask = mask2 if naive else mask1

        self.train_e_idx = self.train_e_idx[mask]
        self.train_edge_index = self.train_edge_index[:, mask]
        self.edge_index = self.edge_index[:, torch.cat(
            [self.train_e_idx, self.rest_e_idx])]

        # Re-index: pruned train edges come first, untouched edges after.
        self.train_e_idx = torch.arange(self.train_e_idx.size(0))
        self.rest_e_idx = torch.arange(
            self.train_e_idx.size(0),
            self.train_e_idx.size(0) + self.rest_e_idx.size(0))
        self.E = self.edge_index.size(1)
        print(f'trainE: {self.train_e_idx.size(0)}, '
              f'restE: {self.rest_e_idx.size(0)}, totalE: {self.E}')

    def update_loss(self, loss):
        self.loss = loss
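
# Sketch of the intended cadence for `Pruner`: accumulate per-node losses
# during training, then prune between epochs. `data`, `split_idx` and
# `train_one_epoch` are placeholders; note that `prune` writes checkpoints
# under ./savept/, which must already exist.
pruner = Pruner(data.edge_index, split_idx, prune_set='train', ratio=0.9)
for epoch in range(1, 11):
    per_node_loss = train_one_epoch()        # placeholder per-node loss
    pruner.update_loss(per_node_loss)
    if epoch % 2 == 0:                       # prune every other epoch
        pruner.prune(naive=False)            # adaptive criterion
        data.edge_index = pruner.edge_index  # adopt the pruned edge set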
class MySAINTSampler(object):
    r"""A new random-walk sampler for GraphSAINT that samples initial nodes
    by iterating over node permutations. The benefit is that we can leverage
    this sampler for subgraph-based inference.

    Args:
        data (torch_geometric.data.Data): The graph data object.
        batch_size (int): The number of initial nodes (walk roots) per
            batch.
        sample_type (str, optional): The sampling strategy, either
            :obj:`"node"` or :obj:`"random_walk"`.
            (default: :obj:`"random_walk"`)
        walk_length (int, optional): The length of each random walk.
            (default: :obj:`2`)
        sample_coverage (int): How many samples per node should be used to
            compute normalization statistics. (default: :obj:`50`)
        save_dir (string, optional): If set, will save normalization
            statistics to the :obj:`save_dir` directory for faster re-use.
            (default: :obj:`None`)
        log (bool, optional): If set to :obj:`False`, will not log any
            progress. (default: :obj:`True`)
    """
    def __init__(self, data, batch_size, sample_type='random_walk',
                 walk_length=2, sample_coverage=50, save_dir=None, log=True):
        assert data.edge_index is not None
        assert 'node_norm' not in data
        assert 'edge_norm' not in data
        assert sample_type in ('node', 'random_walk')

        self.N = N = data.num_nodes
        self.E = data.num_edges

        self.adj = SparseTensor(row=data.edge_index[0],
                                col=data.edge_index[1], value=data.edge_attr,
                                sparse_sizes=(N, N))

        self.data = copy.copy(data)
        self.data.edge_index = None
        self.data.edge_attr = None

        self.sample_type = sample_type
        self.batch_size = batch_size
        self.walk_length = walk_length
        self.sample_coverage = sample_coverage
        self.log = log

        path = osp.join(save_dir or '', self.__filename__)
        if save_dir is not None and osp.exists(path):  # pragma: no cover
            self.node_norm, self.edge_norm = torch.load(path)
        else:
            self.node_norm, self.edge_norm = self.__compute_norm__()
            if save_dir is not None:  # pragma: no cover
                torch.save((self.node_norm, self.edge_norm), path)

    @property
    def __filename__(self):
        return f'{self.__class__.__name__.lower()}_{self.sample_coverage}.pt'

    def __sample_nodes__(self):
        """Sample initial nodes by iterating over a random permutation of
        all nodes."""
        tmp_map = torch.arange(self.N, dtype=torch.long)
        all_n_id = torch.randperm(self.N, dtype=torch.long)
        node_samples = []

        for s_id in range(0, self.N, self.batch_size):
            init_n_id = all_n_id[s_id:s_id + self.batch_size]  # [batch_size]

            if self.sample_type == 'random_walk':
                # [batch_size, walk_length + 1]
                n_id = self.adj.random_walk(init_n_id, self.walk_length)
                n_id = n_id.flatten().unique()  # [num_nodes_in_subgraph]
                # Map the initial nodes to their positions in the subgraph:
                tmp_map[n_id] = torch.arange(n_id.size(0), dtype=torch.long)
                res_n_id = tmp_map[init_n_id]
            elif self.sample_type == 'node':
                n_id = init_n_id
                res_n_id = torch.arange(n_id.size(0), dtype=torch.long)
            else:
                raise ValueError(
                    'Unsupported sample type {}'.format(self.sample_type))

            node_samples.append((n_id, res_n_id))

        return node_samples

    def __sample__(self, num_epochs):
        samples = []
        for _ in range(num_epochs):
            node_samples = self.__sample_nodes__()
            for n_id, res_n_id in node_samples:
                adj, e_id = self.adj.saint_subgraph(n_id)
                samples.append((n_id, e_id, adj, res_n_id))

        return samples

    def __compute_norm__(self):
        node_count = torch.zeros(self.N, dtype=torch.float)
        edge_count = torch.zeros(self.E, dtype=torch.float)

        if self.log:
            pbar = tqdm(total=self.sample_coverage)
            pbar.set_description('GraphSAINT Normalization')

        num_samples = len(self) * self.sample_coverage
        for _ in range(self.sample_coverage):
            samples = self.__sample__(1)
            for n_id, e_id, _, _ in samples:
                node_count[n_id] += 1
                edge_count[e_id] += 1
            if self.log:
                pbar.update(1)
        if self.log:
            pbar.close()

        row, col, _ = self.adj.coo()
        edge_norm = (node_count[col] / edge_count).clamp_(0, 1e4)
        edge_norm[torch.isnan(edge_norm)] = 0.1

        node_count[node_count == 0] = 0.1
        node_norm = num_samples / (node_count * self.N)

        return node_norm, edge_norm

    def __get_data_from_sample__(self, sample):
        n_id, e_id, adj, res_n_id = sample

        data = self.data.__class__()
        data.num_nodes = n_id.size(0)
        row, col, value = adj.coo()
        data.edge_index = torch.stack([row, col], dim=0)
        data.edge_attr = value

        for key, item in self.data:
            # Guard against non-tensor attributes (e.g., ints or strings).
            if isinstance(item, Tensor) and item.size(0) == self.N:
                data[key] = item[n_id]
            elif isinstance(item, Tensor) and item.size(0) == self.E:
                data[key] = item[e_id]
            else:
                data[key] = item

        data.node_norm = self.node_norm[n_id]
        data.edge_norm = self.edge_norm[e_id]
        data.n_id = n_id
        data.res_n_id = res_n_id

        return data

    def __len__(self):
        return (self.N + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        for sample in self.__sample__(1):
            data = self.__get_data_from_sample__(sample)
            yield data
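
# Inference sketch for MySAINTSampler: because the initial nodes sweep a
# full random permutation, reading predictions only at `res_n_id` covers
# every node exactly once per epoch. `data`, `model` and `num_classes` are
# placeholders.
loader = MySAINTSampler(data, batch_size=1024, sample_type='random_walk')
preds = torch.zeros(data.num_nodes, num_classes)
with torch.no_grad():
    for batch in loader:
        out = model(batch.x, batch.edge_index)
        # `batch.n_id[batch.res_n_id]` maps the initial nodes back to their
        # original IDs.
        preds[batch.n_id[batch.res_n_id]] = out[batch.res_n_id]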
# The same DataLoader-based GraphSAINTSampler as above, without docstrings
# or pragma markers; kept as a separate variant.
class GraphSAINTSampler(torch.utils.data.DataLoader):
    def __init__(self, data, batch_size: int, num_steps: int = 1,
                 sample_coverage: int = 0, save_dir: Optional[str] = None,
                 log: bool = True, **kwargs):
        assert data.edge_index is not None
        assert 'node_norm' not in data
        assert 'edge_norm' not in data

        self.num_steps = num_steps
        self.__batch_size__ = batch_size
        self.sample_coverage = sample_coverage
        self.log = log

        self.N = N = data.num_nodes
        self.E = data.num_edges

        self.adj = SparseTensor(
            row=data.edge_index[0], col=data.edge_index[1],
            value=torch.arange(self.E, device=data.edge_index.device),
            sparse_sizes=(N, N))

        self.data = copy.copy(data)
        self.data.edge_index = None

        super(GraphSAINTSampler,
              self).__init__(self, batch_size=1, collate_fn=self.__collate__,
                             **kwargs)

        if self.sample_coverage > 0:
            path = osp.join(save_dir or '', self.__filename__)
            if save_dir is not None and osp.exists(path):
                self.node_norm, self.edge_norm = torch.load(path)
            else:
                self.node_norm, self.edge_norm = self.__compute_norm__()
                if save_dir is not None:
                    torch.save((self.node_norm, self.edge_norm), path)

    @property
    def __filename__(self):
        return f'{self.__class__.__name__.lower()}_{self.sample_coverage}.pt'

    def __len__(self):
        return self.num_steps

    def __sample_nodes__(self, batch_size):
        raise NotImplementedError

    def __getitem__(self, idx):
        node_idx = self.__sample_nodes__(self.__batch_size__).unique()
        adj, _ = self.adj.saint_subgraph(node_idx)
        return node_idx, adj

    def __collate__(self, data_list):
        assert len(data_list) == 1
        node_idx, adj = data_list[0]

        data = self.data.__class__()
        data.num_nodes = node_idx.size(0)
        row, col, edge_idx = adj.coo()
        data.edge_index = torch.stack([row, col], dim=0)

        for key, item in self.data:
            if isinstance(item, torch.Tensor) and item.size(0) == self.N:
                data[key] = item[node_idx]
            elif isinstance(item, torch.Tensor) and item.size(0) == self.E:
                data[key] = item[edge_idx]
            else:
                data[key] = item

        if self.sample_coverage > 0:
            data.node_norm = self.node_norm[node_idx]
            data.edge_norm = self.edge_norm[edge_idx]

        return data

    def __compute_norm__(self):
        node_count = torch.zeros(self.N, dtype=torch.float)
        edge_count = torch.zeros(self.E, dtype=torch.float)

        loader = torch.utils.data.DataLoader(self, batch_size=200,
                                             collate_fn=lambda x: x,
                                             num_workers=self.num_workers)

        if self.log:
            pbar = tqdm(total=self.N * self.sample_coverage)
            pbar.set_description('Compute GraphSAINT normalization')

        num_samples = total_sampled_nodes = 0
        while total_sampled_nodes < self.N * self.sample_coverage:
            for data in loader:
                for node_idx, adj in data:
                    edge_idx = adj.storage.value()
                    node_count[node_idx] += 1
                    edge_count[edge_idx] += 1
                    total_sampled_nodes += node_idx.size(0)

                    if self.log:
                        pbar.update(node_idx.size(0))
            # One full pass over the loader yields `num_steps` samples
            # (originally `num_samples += 200`, which over-counts).
            num_samples += self.num_steps

        if self.log:
            pbar.close()

        row, _, edge_idx = self.adj.coo()
        t = torch.empty_like(edge_count).scatter_(0, edge_idx,
                                                  node_count[row])
        edge_norm = (t / edge_count).clamp_(0, 1e4)
        edge_norm[torch.isnan(edge_norm)] = 0.1

        node_count[node_count == 0] = 0.1
        node_norm = num_samples / node_count / self.N

        return node_norm, edge_norm
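
# The DataLoader-based variant only needs `__sample_nodes__` overridden. A
# sketch of a random-walk subclass using `SparseTensor.random_walk` (the
# same primitive MySAINTSampler uses above); `walk_length` is an added
# argument, and the flattened walks are de-duplicated by the caller's
# `.unique()` call.
class GraphSAINTRandomWalkSampler(GraphSAINTSampler):
    def __init__(self, data, batch_size, walk_length, **kwargs):
        self.walk_length = walk_length
        super().__init__(data, batch_size, **kwargs)

    def __sample_nodes__(self, batch_size):
        start = torch.randint(0, self.N, (batch_size, ), dtype=torch.long)
        node_idx = self.adj.random_walk(start, self.walk_length)
        return node_idx.view(-1)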
# A RandomNodeSampler variant extended with loss-based edge pruning.
class RandomNodeSampler(torch.utils.data.DataLoader):
    r"""A data loader that randomly samples nodes within a graph and returns
    their induced subgraph, with optional edge pruning.

    .. note::

        For an example of using :obj:`RandomNodeSampler`, see
        `examples/ogbn_proteins_deepgcn.py
        <https://github.com/rusty1s/pytorch_geometric/blob/master/examples/
        ogbn_proteins_deepgcn.py>`_.

    Args:
        data (torch_geometric.data.Data): The graph data object.
        num_parts (int): The number of partitions.
        shuffle (bool, optional): If set to :obj:`True`, the data is
            reshuffled at every epoch. (default: :obj:`False`)
        split_idx (dict, optional): Node index splits (e.g., OGB's
            :obj:`train`/:obj:`valid`/:obj:`test`). Required when
            :obj:`prune` is set. (default: :obj:`None`)
        num_edges (int, optional): If set, overrides :obj:`data.num_edges`.
            (default: :obj:`None`)
        prune (bool, optional): If set to :obj:`True`, sets up the
            bookkeeping needed by :meth:`prune`. (default: :obj:`False`)
        prune_set (str, optional): If :obj:`"train"`, prunes among edges
            whose endpoints both lie in the training set; otherwise the
            validation nodes are included as well. (default: :obj:`"train"`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`num_workers`.
    """
    def __init__(self, data, num_parts: int, shuffle: bool = False,
                 split_idx=None, num_edges=None, prune=False,
                 prune_set='train', **kwargs):
        assert data.edge_index is not None

        self.N = N = data.num_nodes
        self.E = data.num_edges if num_edges is None else num_edges
        self.split_idx = split_idx

        self.adj = SparseTensor(
            row=data.edge_index[0], col=data.edge_index[1],
            value=torch.arange(self.E, device=data.edge_index.device),
            sparse_sizes=(N, N))

        # Unlike the variants above, `edge_index` is kept on `self.data` and
        # therefore has to be skipped in `__collate__`.
        self.data = copy.copy(data)

        super(RandomNodeSampler, self).__init__(
            self, batch_size=1,
            sampler=RandomIndexSampler(self.N, num_parts, shuffle),
            collate_fn=self.__collate__, **kwargs)

        if prune:
            self.edge_index = data.edge_index
            if prune_set == 'train':
                self.train_idx = self.split_idx['train']
            else:
                self.train_idx = torch.cat(
                    [self.split_idx['train'], self.split_idx['valid']])

            # Edges whose endpoints both lie in the pruning set:
            subadj, _ = self.adj.saint_subgraph(self.train_idx)
            _, _, e_idx = subadj.coo()
            self.train_e_idx = e_idx.squeeze().long()
            self.train_edge_index = self.edge_index[:, self.train_e_idx]
            self.rest_e_idx = torch.LongTensor(
                list(set(range(self.E)) - set(self.train_e_idx.tolist())))
            self.times = 0

    def __getitem__(self, idx):
        return idx

    def prune(self, loss, ratio=0, method='ada', globe=False, savept=False):
        # Note: `savept` is currently unused.
        if not globe:
            # Prune only among edges inside the training subgraph.
            if method == 'ada':
                # Adaptive: keep the `ratio` fraction of edges with the
                # smallest endpoint loss difference.
                diff_loss = torch.abs(loss[self.train_edge_index[0]] -
                                      loss[self.train_edge_index[1]])
                _, mask = torch.topk(diff_loss, int(len(diff_loss) * ratio),
                                     largest=False)
            elif method == 'random':
                newE = self.train_edge_index.size(1)
                mask = torch.randperm(newE)[:int(newE * ratio)]
            elif method == 'naive':
                # Degree-based: cap each source node's degree at
                # `ratio * max_degree` by dropping random excess edges.
                degrees = scatter(torch.ones(self.train_edge_index.size(1)),
                                  self.train_edge_index[0])
                thold = int(degrees.max() * ratio)
                prunemask = (degrees > thold).nonzero()
                taway = []
                for pru in prunemask:
                    p_eid = (self.train_edge_index[0] ==
                             pru.squeeze()).nonzero().squeeze()
                    taway.append(p_eid[torch.randperm(p_eid.size(0))][thold:])
                mask = torch.ones(self.train_edge_index.size(1),
                                  dtype=torch.bool)
                if len(taway) > 0:
                    mask[torch.cat(taway)] = False

            self.train_e_idx = self.train_e_idx[mask]
            self.train_edge_index = self.train_edge_index[:, mask]
            keep = torch.cat([self.train_e_idx, self.rest_e_idx])
            self.edge_index = self.edge_index[:, keep]
            self.data.edge_attr = self.data.edge_attr[keep]
            self.data.edge_index = self.data.edge_index[:, keep]

            # Re-index: pruned train edges first, untouched edges after.
            self.train_e_idx = torch.arange(self.train_e_idx.size(0))
            self.rest_e_idx = torch.arange(
                self.train_e_idx.size(0),
                self.train_e_idx.size(0) + self.rest_e_idx.size(0))
            self.E = self.edge_index.size(1)
            self.adj = SparseTensor(
                row=self.edge_index[0], col=self.edge_index[1],
                value=torch.arange(self.E, device=self.edge_index.device),
                sparse_sizes=(self.N, self.N))
        else:
            # Prune over the full edge set.
            if method == 'random':
                mask = torch.randperm(self.E)[:int(self.E * ratio)]
            elif method == 'naive':
                degrees = scatter(torch.ones(self.data.edge_index.size(1)),
                                  self.data.edge_index[0])
                thold = int(degrees.max() * ratio)
                prunemask = (degrees > thold).nonzero().squeeze()
                taway = []
                for pru in prunemask:
                    p_eid = (self.data.edge_index[0] ==
                             pru).nonzero().squeeze()
                    taway.append(p_eid[torch.randperm(p_eid.size(0))][thold:])
                mask = torch.ones(self.data.edge_index.size(1),
                                  dtype=torch.bool)
                if len(taway) > 0:
                    mask[torch.cat(taway)] = False

            self.data.edge_index = self.data.edge_index[:, mask]
            self.data.edge_attr = self.data.edge_attr[mask]
            self.E = self.data.edge_index.size(1)
            self.adj = SparseTensor(
                row=self.data.edge_index[0], col=self.data.edge_index[1],
                value=torch.arange(self.E,
                                   device=self.data.edge_index.device),
                sparse_sizes=(self.N, self.N))

    def __collate__(self, node_idx):
        node_idx = node_idx[0]

        data = self.data.__class__()
        data.num_nodes = node_idx.size(0)
        data.n_id = node_idx
        adj, _ = self.adj.saint_subgraph(node_idx)
        row, col, edge_idx = adj.coo()
        data.edge_index = torch.stack([row, col], dim=0)

        for key, item in self.data:
            if key == 'edge_index':
                continue  # rebuilt above
            if isinstance(item, Tensor) and item.size(0) == self.N:
                data[key] = item[node_idx]
            elif isinstance(item, Tensor) and item.size(0) == self.E:
                data[key] = item[edge_idx]
            else:
                data[key] = item

        return data
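
# Usage sketch for the pruning loader; `data`, `split_idx` (OGB split
# format) and `train_one_epoch` are placeholders, the latter returning a
# per-node training loss.
loader = RandomNodeSampler(data, num_parts=10, shuffle=True,
                           split_idx=split_idx, prune=True,
                           prune_set='train')
for epoch in range(1, 21):
    per_node_loss = train_one_epoch(loader)  # placeholder
    if epoch % 5 == 0:
        # Keep the 90% of train edges with the smallest endpoint loss
        # difference.
        loader.prune(per_node_loss, ratio=0.9, method='ada')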
# A GraphSAINTSampler variant extended with loss-based edge pruning;
# normalization statistics are disabled here.
class GraphSAINTSampler(torch.utils.data.DataLoader):
    r"""The GraphSAINT sampler base class from the `"GraphSAINT: Graph
    Sampling Based Inductive Learning Method"
    <https://arxiv.org/abs/1907.04931>`_ paper.
    Given a graph in a :obj:`data` object, this class samples nodes and
    constructs subgraphs that can be processed in a mini-batch fashion.
    Normalization coefficients for each mini-batch are given via
    :obj:`node_norm` and :obj:`edge_norm` data attributes.

    .. note::

        See :class:`torch_geometric.data.GraphSAINTNodeSampler`,
        :class:`torch_geometric.data.GraphSAINTEdgeSampler` and
        :class:`torch_geometric.data.GraphSAINTRandomWalkSampler` for
        currently supported samplers.
        For an example of using GraphSAINT sampling, see
        `examples/graph_saint.py
        <https://github.com/rusty1s/pytorch_geometric/blob/master/examples/
        graph_saint.py>`_.

    Args:
        data (torch_geometric.data.Data): The graph data object.
        batch_size (int): The approximate number of samples per batch.
        num_steps (int, optional): The number of iterations per epoch.
            (default: :obj:`1`)
        sample_coverage (int): How many samples per node should be used to
            compute normalization statistics. (default: :obj:`0`)
        save_dir (string, optional): If set, will save normalization
            statistics to the :obj:`save_dir` directory for faster re-use.
            (default: :obj:`None`)
        log (bool, optional): If set to :obj:`False`, will not log any
            pre-processing progress. (default: :obj:`True`)
        prune (bool, optional): If set to :obj:`True`, sets up the
            bookkeeping needed by :meth:`prune`. (default: :obj:`False`)
        prune_set (str, optional): If :obj:`"train"`, prunes among edges
            whose endpoints both lie in the training set; otherwise the
            validation nodes are included as well. (default: :obj:`"train"`)
        prune_type (str, optional): Currently unused.
            (default: :obj:`"adaptive"`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`
            or :obj:`num_workers`.
    """
    def __init__(self, data, batch_size: int, num_steps: int = 1,
                 sample_coverage: int = 0, save_dir: Optional[str] = None,
                 log: bool = True, prune: bool = False, prune_set='train',
                 prune_type='adaptive', **kwargs):

        assert data.edge_index is not None
        assert 'node_norm' not in data
        assert 'edge_norm' not in data

        self.num_steps = num_steps
        self.__batch_size__ = batch_size
        self.sample_coverage = sample_coverage
        self.log = log

        self.N = N = data.num_nodes
        self.E = data.num_edges

        self.adj = SparseTensor(
            row=data.edge_index[0], col=data.edge_index[1],
            value=torch.arange(self.E, device=data.edge_index.device),
            sparse_sizes=(N, N))

        # `edge_index` is kept on `self.data` in this variant and skipped in
        # `__collate__`.
        self.data = copy.copy(data)

        super(GraphSAINTSampler,
              self).__init__(self, batch_size=1, collate_fn=self.__collate__,
                             **kwargs)

        if prune:
            if prune_set == 'train':
                self.train_idx = self.data.train_mask.nonzero(
                    as_tuple=False).squeeze()
            else:
                self.train_idx = (self.data.train_mask +
                                  self.data.valid_mask).nonzero(
                                      as_tuple=False).squeeze()

            # Edges whose endpoints both lie in the pruning set:
            subadj, _ = self.adj.saint_subgraph(self.train_idx)
            _, _, e_idx = subadj.coo()
            self.train_e_idx = e_idx.squeeze().long()
            self.train_edge_index = self.data.edge_index[:, self.train_e_idx]
            self.rest_e_idx = torch.LongTensor(
                list(set(range(self.E)) - set(self.train_e_idx.tolist())))

        # Normalization statistics are disabled in this pruning variant:
        # if self.sample_coverage > 0:
        #     path = osp.join(save_dir or '', self.__filename__)
        #     if save_dir is not None and osp.exists(path):
        #         self.node_norm, self.edge_norm = torch.load(path)
        #     else:
        #         self.node_norm, self.edge_norm = self.__compute_norm__()
        #         if save_dir is not None:
        #             torch.save((self.node_norm, self.edge_norm), path)

    @property
    def __filename__(self):
        return f'{self.__class__.__name__.lower()}_{self.sample_coverage}.pt'

    def __len__(self):
        return self.num_steps

    def __sample_nodes__(self, batch_size):
        raise NotImplementedError

    def __getitem__(self, idx):
        node_idx = self.__sample_nodes__(self.__batch_size__).unique()
        adj, _ = self.adj.saint_subgraph(node_idx)
        return node_idx, adj

    def prune(self, loss, ratio=0, method='ada', glob=False):
        # Note: only the local (`glob=False`) branch is implemented here.
        if not glob:
            if method == 'ada':
                # Adaptive: keep the `ratio` fraction of train edges with
                # the smallest endpoint loss difference.
                diff_loss = torch.abs(loss[self.train_edge_index[0]] -
                                      loss[self.train_edge_index[1]])
                _, mask = torch.topk(diff_loss, int(len(diff_loss) * ratio),
                                     largest=False)
            elif method == 'random':
                newE = self.train_edge_index.size(1)
                mask = torch.randperm(newE)[:int(newE * ratio)]
            elif method == 'naive':
                # Degree-based: cap each source node's degree at
                # `ratio * max_degree` by dropping random excess edges.
                degrees = scatter(torch.ones(self.train_edge_index.size(1)),
                                  self.train_edge_index[0])
                thold = int(degrees.max() * ratio)
                prunemask = (degrees > thold).nonzero()
                taway = []
                for pru in prunemask:
                    p_eid = (self.train_edge_index[0] ==
                             pru).nonzero().squeeze()
                    taway.append(p_eid[torch.randperm(p_eid.size(0))][thold:])
                mask = torch.ones(self.train_edge_index.size(1),
                                  dtype=torch.bool)
                if len(taway) > 0:
                    mask[torch.cat(taway)] = False

            self.train_e_idx = self.train_e_idx[mask]
            self.train_edge_index = self.train_edge_index[:, mask]
            keep = torch.cat([self.train_e_idx, self.rest_e_idx])
            self.data.edge_attr = self.data.edge_attr[keep]
            self.data.edge_index = self.data.edge_index[:, keep]

            # Re-index: pruned train edges first, untouched edges after.
            self.train_e_idx = torch.arange(self.train_e_idx.size(0))
            self.rest_e_idx = torch.arange(
                self.train_e_idx.size(0),
                self.train_e_idx.size(0) + self.rest_e_idx.size(0))
            self.E = self.data.num_edges
            self.adj = SparseTensor(
                row=self.data.edge_index[0], col=self.data.edge_index[1],
                value=torch.arange(self.E,
                                   device=self.data.edge_index.device),
                sparse_sizes=(self.N, self.N))

    def __collate__(self, data_list):
        assert len(data_list) == 1
        node_idx, adj = data_list[0]

        data = self.data.__class__()
        data.num_nodes = node_idx.size(0)
        data.node_idx = node_idx
        row, col, edge_idx = adj.coo()
        data.edge_index = torch.stack([row, col], dim=0)

        for key, item in self.data:
            if key == 'edge_index':
                continue  # rebuilt above
            if isinstance(item, Tensor) and item.size(0) == self.N:
                data[key] = item[node_idx]
            elif isinstance(item, Tensor) and item.size(0) == self.E:
                data[key] = item[edge_idx]
            else:
                data[key] = item

        # Normalization coefficients are disabled in this variant:
        # if self.sample_coverage > 0:
        #     data.node_norm = self.node_norm[node_idx]
        #     data.edge_norm = self.edge_norm[edge_idx]

        return data
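
# Tiny numeric demo of the 'ada' criterion shared by both pruning loaders:
# it is just a top-k over endpoint loss differences. All values below are
# made up.
loss = torch.tensor([0.10, 0.12, 0.90, 0.20])
train_edge_index = torch.tensor([[0, 0, 2, 1],
                                 [1, 2, 3, 3]])
diff = torch.abs(loss[train_edge_index[0]] - loss[train_edge_index[1]])
# diff = [0.02, 0.80, 0.70, 0.08]
_, keep = torch.topk(diff, int(diff.numel() * 0.5), largest=False)
print(keep)  # tensor([0, 3]) -> edges (0, 1) and (1, 3) survive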
class MyNeighborSampler(NeighborSampler):
    # Extends the PyG 1.x `NeighborSampler` (imported above) with optional
    # negative sampling. `neighbor_sampler` is the low-level sampling op
    # (shipped with torch_cluster in PyG 1.x); `grouper` and
    # `fetch_and_generate` are project helpers not shown in this collection
    # (a sketch of `negative_sampling_numpy` follows this class).
    def __init__(self, data, size, num_hops, batch_size=1, shuffle=False,
                 drop_last=False, bipartite=True, add_self_loops=False,
                 flow='source_to_target', use_negative_sampling=False,
                 neg_sample_ratio=None):
        self.N = N = data.num_nodes
        self.E = data.num_edges
        self.adj = SparseTensor(
            row=data.edge_index[0], col=data.edge_index[1],
            value=torch.arange(self.E, device=data.edge_index.device),
            sparse_sizes=(N, N))

        if use_negative_sampling:
            assert neg_sample_ratio is not None
            assert neg_sample_ratio > 0.0

        self.use_negative_sampling = use_negative_sampling
        self.num_proc = mp.cpu_count()
        self.neg_sample_ratio = neg_sample_ratio

        super().__init__(data, size, num_hops, batch_size, shuffle, drop_last,
                         bipartite, add_self_loops, flow)

    def __produce_subgraph__(self, b_id):
        r"""Produces a :obj:`Data` object holding the subgraph data for a
        given mini-batch :obj:`b_id`."""

        n_ids = [b_id]
        e_ids = []
        edge_indices = []

        for l in range(self.num_hops):
            e_id = neighbor_sampler(n_ids[-1], self.cumdeg, self.size[l])
            n_id = self.edge_index_j.index_select(0, e_id)
            n_id = n_id.unique(sorted=False)
            n_ids.append(n_id)
            e_ids.append(self.e_assoc.index_select(0, e_id))
            edge_index = self.data.edge_index.index_select(1, e_ids[-1])
            edge_indices.append(edge_index)

        n_id = torch.unique(torch.cat(n_ids, dim=0), sorted=False)
        self.tmp[n_id] = torch.arange(n_id.size(0))
        e_id = torch.cat(e_ids, dim=0)
        edge_index = self.tmp[torch.cat(edge_indices, dim=1)]

        # De-duplicate edges via a linearized index; floor division (`//`)
        # recovers the row, fixing the deprecated `/` of the original.
        num_nodes = n_id.size(0)
        idx = edge_index[0] * num_nodes + edge_index[1]
        idx, inv = idx.unique(sorted=False, return_inverse=True)
        edge_index = torch.stack([idx // num_nodes, idx % num_nodes], dim=0)
        e_id = e_id.new_zeros(edge_index.size(1)).scatter_(0, inv, e_id)

        # n_id: original IDs of nodes in the whole subgraph.
        # b_id: original IDs of nodes in the training graph.
        # sub_b_id: sampled IDs of nodes in the training graph.

        # Get the full subgraph for negative sampling.
        # It will be deleted in `__call__`.
        if self.use_negative_sampling:
            adj, _ = self.adj.saint_subgraph(n_id)
            row, col, edge_idx = adj.coo()
            full_edge_index = torch.stack([row, col], dim=0)
        else:
            full_edge_index = None

        return Data(edge_index=edge_index, e_id=e_id, n_id=n_id, b_id=b_id,
                    sub_b_id=self.tmp[b_id], full_edge_index=full_edge_index,
                    num_nodes=num_nodes)

    def __call__(self, subset=None):
        r"""Returns a generator of :obj:`DataFlow` that iterates over the
        nodes in :obj:`subset` in a mini-batch fashion.

        Args:
            subset (LongTensor or BoolTensor, optional): The initial nodes to
                propagate messages to. If set to :obj:`None`, will iterate
                over all nodes in the graph. (default: :obj:`None`)
        """
        if self.bipartite:
            produce = self.__produce_bipartite_data_flow__
        else:
            produce = self.__produce_subgraph__

        if not self.use_negative_sampling:
            for n_id in self.__get_batches__(subset):
                yield produce(n_id)
        else:
            yield from self.__call_with_negatives__(subset, produce)

    def __call_with_negatives__(self, subset, produce):
        # Produce subgraphs in groups of `num_proc` and sample negative
        # edges for the whole group in parallel worker processes.
        for n_id_group in grouper(self.__get_batches__(subset),
                                  self.num_proc):
            produced_data_list = [produce(n_id) for n_id in n_id_group]
            ns_generator = fetch_and_generate(
                iters=[(p_data.full_edge_index.numpy(),
                        p_data.n_id.size(0),
                        int(self.neg_sample_ratio *
                            p_data.edge_index.size(1)))
                       for p_data in produced_data_list],
                func=negative_sampling_numpy,
                num_proc=self.num_proc,
            )
            for p_data, ns in zip(produced_data_list, ns_generator):
                del p_data.full_edge_index
                p_data.neg_edge_index = torch.as_tensor(ns).long()
                yield p_data
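
# `grouper`, `fetch_and_generate` and `negative_sampling_numpy` are
# project-local helpers that are not part of this collection. A plausible
# sketch of the last one, assuming it takes a [2, E] numpy edge array, the
# node count, and the number of negatives, and returns a [2, num_neg] array
# of sampled non-edges (rejection sampling; purely illustrative):
import numpy as np

def negative_sampling_numpy(edge_index, num_nodes, num_neg_samples):
    # Linearize existing edges for O(1) membership tests.
    existing = set(edge_index[0] * num_nodes + edge_index[1])
    neg = []
    while len(neg) < num_neg_samples:
        src = np.random.randint(num_nodes, size=num_neg_samples)
        dst = np.random.randint(num_nodes, size=num_neg_samples)
        for a, b in zip(src, dst):
            if a != b and a * num_nodes + b not in existing and \
                    len(neg) < num_neg_samples:
                neg.append((a, b))
    return np.asarray(neg).T  # shape [2, num_neg_samples]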