import pytest
import torch
from torch_sparse import SparseTensor
from torch_geometric.data.graph_store import EdgeLayout


def test_graph_store():
    # `MyGraphStore` is a minimal `GraphStore` implementation defined
    # alongside this test.
    graph_store = MyGraphStore()
    edge_index = torch.LongTensor([[0, 1], [1, 2]])
    adj = SparseTensor(row=edge_index[0], col=edge_index[1])

    def assert_equal_tensor_tuple(expected, actual):
        assert len(expected) == len(actual)
        for i in range(len(expected)):
            assert torch.equal(expected[i], actual[i])

    # We put all three tensor types: COO, CSR, and CSC, and we get them back
    # to confirm that `GraphStore` works as intended.
    coo = adj.coo()[:-1]
    csr = adj.csr()[:-1]
    csc = adj.csc()[-2::-1]  # (row, colptr)

    # Put:
    graph_store['edge', EdgeLayout.COO] = coo
    graph_store['edge', 'csr'] = csr
    graph_store['edge', 'csc'] = csc

    # Get:
    assert_equal_tensor_tuple(coo, graph_store['edge', 'coo'])
    assert_equal_tensor_tuple(csr, graph_store['edge', 'csr'])
    assert_equal_tensor_tuple(csc, graph_store['edge', 'csc'])

    # Get attrs:
    edge_attrs = graph_store.get_all_edge_attrs()
    assert len(edge_attrs) == 3

    with pytest.raises(KeyError):
        _ = graph_store['edge_2', 'coo']
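# A standalone sketch of the three sparse layouts exercised above, following
# torch_sparse's accessor conventions (the `adj` here mirrors the test graph):
import torch
from torch_sparse import SparseTensor

edge_index = torch.tensor([[0, 1], [1, 2]])
adj = SparseTensor(row=edge_index[0], col=edge_index[1])
row, col, _ = adj.coo()     # COO: one (row, col) pair per edge
rowptr, col, _ = adj.csr()  # CSR: compressed row pointers
colptr, row, _ = adj.csc()  # CSC: compressed column pointers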
def forward(self, x: Tensor, M: SparseTensor, batch: OptTensor = None):
    """"""
    if batch is None:
        batch = x.new_zeros(x.size(0), dtype=torch.long)

    row, col, edge_weight = M.coo()

    score1 = (x * self.p).sum(dim=-1)
    score2 = scatter_add(edge_weight, col, dim=0, dim_size=x.size(0))
    score = self.beta[0] * score1 + self.beta[1] * score2

    if self.min_score is None:
        score = self.nonlinearity(score)
    else:
        score = softmax(score, batch)

    perm = topk(score, self.ratio, batch, self.min_score)
    x = x[perm] * score[perm].view(-1, 1)
    x = self.multiplier * x if self.multiplier != 1 else x

    edge_index = torch.stack([col, row], dim=0)
    edge_index, edge_attr = filter_adj(edge_index, edge_weight, perm,
                                       num_nodes=score.size(0))

    return x, edge_index, edge_attr, batch[perm], perm, score[perm]
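# Hedged usage sketch for a pooling layer with the signature above, assuming
# it is PyG's PANPooling (where `PANConv` produces the MET matrix `M` that the
# pooling layer consumes); all sizes are illustrative:
import torch
from torch_geometric.nn import PANConv, PANPooling

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
conv = PANConv(16, 32, filter_size=2)
pool = PANPooling(32, ratio=0.5)
x, M = conv(x, edge_index)
x, edge_index, edge_attr, batch, perm, score = pool(x, M)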
def forward(self, x, edge_index, edge_weight=None, batch=None):
    """"""
    N = x.size(0)

    edge_index, edge_weight = add_remaining_self_loops(
        edge_index, edge_weight, fill_value=1, num_nodes=N)

    if batch is None:
        batch = edge_index.new_zeros(x.size(0))

    x = x.unsqueeze(-1) if x.dim() == 1 else x

    x_pool = x
    if self.GNN is not None:
        x_pool = self.gnn_intra_cluster(x=x, edge_index=edge_index,
                                        edge_weight=edge_weight)

    x_pool_j = x_pool[edge_index[0]]
    x_q = scatter(x_pool_j, edge_index[1], dim=0, reduce='max')
    x_q = self.lin(x_q)[edge_index[1]]

    score = self.att(torch.cat([x_q, x_pool_j], dim=-1)).view(-1)
    score = F.leaky_relu(score, self.negative_slope)
    score = softmax(score, edge_index[1], num_nodes=N)

    # Sample attention coefficients stochastically.
    score = F.dropout(score, p=self.dropout, training=self.training)

    v_j = x[edge_index[0]] * score.view(-1, 1)
    x = scatter(v_j, edge_index[1], dim=0, reduce='add')

    # Cluster selection.
    fitness = self.gnn_score(x, edge_index).sigmoid().view(-1)
    perm = topk(fitness, self.ratio, batch)
    x = x[perm] * fitness[perm].view(-1, 1)
    batch = batch[perm]

    # Graph coarsening.
    row, col = edge_index
    A = SparseTensor(row=row, col=col, value=edge_weight,
                     sparse_sizes=(N, N))
    S = SparseTensor(row=row, col=col, value=score, sparse_sizes=(N, N))
    S = S[:, perm]

    A = S.t() @ A @ S

    if self.add_self_loops:
        A = A.fill_diag(1.)
    else:
        A = A.remove_diag()

    row, col, edge_weight = A.coo()
    edge_index = torch.stack([row, col], dim=0)

    return x, edge_index, edge_weight, batch, perm
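# Hedged usage sketch, assuming the forward above belongs to PyG's ASAPooling
# (adaptive structure-aware pooling); the toy graph is illustrative:
import torch
from torch_geometric.nn import ASAPooling

pool = ASAPooling(in_channels=16, ratio=0.5)
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
x, edge_index, edge_weight, batch, perm = pool(x, edge_index)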
import torch
from torch_cluster import graclus_cluster
from torch_sparse import SparseTensor
from torch_geometric.nn.pool.pool import pool_edge


def graclus_coarsen(A: SparseTensor, level: int):
    row, col, wgt = A.coo()
    coarsen_cluster = []
    for _ in range(level):
        # Cluster nodes via Graclus matching and relabel clusters to
        # consecutive ids:
        cluster = graclus_cluster(row, col, wgt)
        _, cluster = cluster.unique(return_inverse=True)
        # Pool edges into the coarsened graph:
        (row, col), wgt = pool_edge(cluster, torch.stack([row, col]), wgt)
        coarsen_cluster.append(cluster.cpu().numpy())
    return row, col, wgt, coarsen_cluster
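# Hedged usage sketch of the coarsening helper above on a small toy graph
# (assumes `graclus_coarsen` and its imports are in scope; names illustrative):
row = torch.tensor([0, 1, 1, 2, 2, 3])
col = torch.tensor([1, 0, 2, 1, 3, 2])
wgt = torch.ones(row.size(0))
A = SparseTensor(row=row, col=col, value=wgt, sparse_sizes=(4, 4))
row, col, wgt, clusters = graclus_coarsen(A, level=2)
# `clusters[i]` maps the nodes at level `i` to their coarsened cluster ids.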
def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
    """"""
    if x.dim() > 1:
        assert (x.sum(dim=-1) == 1).sum() == x.size(0)
        x = x.argmax(dim=-1)  # one-hot -> integer.
    assert x.dtype == torch.long

    adj_t = edge_index
    if not isinstance(adj_t, SparseTensor):
        adj_t = SparseTensor(row=edge_index[1], col=edge_index[0],
                             sparse_sizes=(x.size(0), x.size(0)))

    out = []
    _, col, _ = adj_t.coo()
    deg = adj_t.storage.rowcount().tolist()
    for node, neighbors in zip(x.tolist(), x[col].split(deg)):
        idx = hash(tuple([node] + neighbors.sort()[0].tolist()))
        if idx not in self.hashmap:
            self.hashmap[idx] = len(self.hashmap)
        out.append(self.hashmap[idx])

    return torch.tensor(out, device=x.device)
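# Hedged usage sketch, assuming the forward above is PyG's WLConv: each call
# performs one step of 1-WL color refinement, hashing every node's color
# together with the sorted multiset of its neighbors' colors.
import torch
from torch_geometric.nn import WLConv

conv = WLConv()
x = torch.zeros(4, dtype=torch.long)  # all nodes start with color 0
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
for _ in range(3):
    x = conv(x, edge_index)  # refined node colors after each WL iteration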
def __dropout_adj__(self, sparse_adj: SparseTensor, dropout_adj_prob: float):
    # Number of nodes:
    N = sparse_adj.size(0)

    # Convert the sparse adjacency matrix into (edge_index, edge_attr) form:
    row, col, edge_attr = sparse_adj.coo()
    edge_index = torch.stack([row, col], dim=0)

    # Drop edges from the adjacency matrix (for better generalization):
    edge_index, edge_attr = dropout_adj(edge_index, edge_attr=edge_attr,
                                        p=dropout_adj_prob,
                                        force_undirected=True,
                                        training=self.training)

    # Since `dropout_adj` removes self-loops (due to
    # `force_undirected=True`), make sure to add them back again:
    edge_index, edge_attr = add_remaining_self_loops(
        edge_index, edge_weight=edge_attr, fill_value=0.00, num_nodes=N)

    # Convert (edge_index, edge_attr) back into a sparse adjacency matrix:
    sparse_adj = SparseTensor.from_edge_index(edge_index, edge_attr=edge_attr,
                                              sparse_sizes=(N, N))

    return sparse_adj
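# Hedged standalone sketch of the same round-trip as above (`self` replaced by
# an explicit `training=True` flag for illustration):
import torch
from torch_sparse import SparseTensor
from torch_geometric.utils import add_remaining_self_loops, dropout_adj

adj = SparseTensor.from_edge_index(torch.tensor([[0, 1], [1, 2]]),
                                   sparse_sizes=(3, 3))
row, col, _ = adj.coo()
edge_index, _ = dropout_adj(torch.stack([row, col], dim=0), p=0.5,
                            force_undirected=True, training=True)
edge_index, _ = add_remaining_self_loops(edge_index, num_nodes=3)
adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(3, 3))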
class GraphSAINTSampler(object):
    r"""The GraphSAINT sampler base class from the `"GraphSAINT: Graph
    Sampling Based Inductive Learning Method"
    <https://arxiv.org/abs/1907.04931>`_ paper.
    Given a graph in a :obj:`data` object, this class samples nodes and
    constructs subgraphs that can be processed in a mini-batch fashion.
    Normalization coefficients for each mini-batch are given via
    :obj:`node_norm` and :obj:`edge_norm` data attributes.

    .. note::

        See :class:`torch_geometric.data.GraphSAINTNodeSampler`,
        :class:`torch_geometric.data.GraphSAINTEdgeSampler` and
        :class:`torch_geometric.data.GraphSAINTRandomWalkSampler` for
        currently supported samplers.
        For an example of using GraphSAINT sampling, see
        `examples/graph_saint.py
        <https://github.com/rusty1s/pytorch_geometric/blob/master/examples/
        graph_saint.py>`_.

    Args:
        data (torch_geometric.data.Data): The graph data object.
        batch_size (int): The approximate number of samples per batch to
            load.
        num_steps (int, optional): The number of iterations.
            (default: :obj:`1`)
        sample_coverage (int): How many samples per node should be used to
            compute normalization statistics. (default: :obj:`50`)
        save_dir (string, optional): If set, will save normalization
            statistics to the :obj:`save_dir` directory for faster re-use.
            (default: :obj:`None`)
        num_workers (int, optional): How many subprocesses to use for data
            sampling. :obj:`0` means that the data will be sampled in the
            main process. (default: :obj:`0`)
    """
    def __init__(self, data, batch_size, num_steps=1, sample_coverage=50,
                 save_dir=None, num_workers=0):
        assert data.edge_index is not None
        assert 'node_norm' not in data
        assert 'edge_norm' not in data

        self.N = N = data.num_nodes
        self.E = data.num_edges

        self.adj = SparseTensor(row=data.edge_index[0],
                                col=data.edge_index[1], value=data.edge_attr,
                                sparse_sizes=(N, N))

        self.data = copy.copy(data)
        self.data.edge_index = None
        self.data.edge_attr = None

        self.batch_size = batch_size
        self.num_steps = num_steps
        self.sample_coverage = sample_coverage
        self.num_workers = num_workers
        self.__count__ = 0

        if self.num_workers > 0:
            self.__sample_queue__ = Queue()
            self.__sample_workers__ = []
            for _ in range(self.num_workers):
                worker = Process(target=self.__put_sample__,
                                 args=(self.__sample_queue__, ))
                worker.daemon = True
                worker.start()
                self.__sample_workers__.append(worker)

        path = osp.join(save_dir or '', self.__filename__)
        if save_dir is not None and osp.exists(path):
            self.node_norm, self.edge_norm = torch.load(path)
        else:
            self.node_norm, self.edge_norm = self.__compute_norm__()
            if save_dir is not None:
                torch.save((self.node_norm, self.edge_norm), path)

        if self.num_workers > 0:
            self.__data_queue__ = Queue()
            self.__data_workers__ = []
            for _ in range(self.num_workers):
                worker = Process(target=self.__put_data__,
                                 args=(self.__data_queue__, ))
                worker.daemon = True
                worker.start()
                self.__data_workers__.append(worker)

    @property
    def __filename__(self):
        return f'{self.__class__.__name__.lower()}_{self.sample_coverage}.pt'

    def __sample_nodes__(self, num_examples):
        raise NotImplementedError

    def __sample__(self, num_examples):
        node_samples = self.__sample_nodes__(num_examples)

        samples = []
        for node_idx in node_samples:
            node_idx = node_idx.unique()
            adj, edge_idx = self.adj.saint_subgraph(node_idx)
            samples.append((node_idx, edge_idx, adj))

        return samples

    def __compute_norm__(self):
        node_count = torch.zeros(self.N, dtype=torch.float)
        edge_count = torch.zeros(self.E, dtype=torch.float)

        num_samples = total_sampled_nodes = 0
        pbar = tqdm(total=self.N * self.sample_coverage)
        pbar.set_description('Compute GraphSAINT normalization')
        while total_sampled_nodes < self.N * self.sample_coverage:
            num_sampled_nodes = 0
            if self.num_workers > 0:
                for _ in range(200):
                    node_idx, edge_idx, _ = self.__sample_queue__.get()
                    node_count[node_idx] += 1
                    edge_count[edge_idx] += 1
                    num_sampled_nodes += node_idx.size(0)
            else:
                samples = self.__sample__(200)
                for node_idx, edge_idx, _ in samples:
                    node_count[node_idx] += 1
                    edge_count[edge_idx] += 1
                    num_sampled_nodes += node_idx.size(0)
            pbar.update(num_sampled_nodes)
            total_sampled_nodes += num_sampled_nodes
            num_samples += 200
        pbar.close()

        row, col, _ = self.adj.coo()

        edge_norm = (node_count[col] / edge_count).clamp_(0, 1e4)
        edge_norm[torch.isnan(edge_norm)] = 0.1

        node_count[node_count == 0] = 0.1
        node_norm = num_samples / node_count / self.N

        return node_norm, edge_norm

    def __get_data_from_sample__(self, sample):
        node_idx, edge_idx, adj = sample

        data = self.data.__class__()
        data.num_nodes = node_idx.size(0)
        row, col, value = adj.coo()
        data.edge_index = torch.stack([row, col], dim=0)
        data.edge_attr = value

        for key, item in self.data:
            if item.size(0) == self.N:
                data[key] = item[node_idx]
            elif item.size(0) == self.E:
                data[key] = item[edge_idx]
            else:
                data[key] = item

        data.node_norm = self.node_norm[node_idx]
        data.edge_norm = self.edge_norm[edge_idx]

        return data

    def __put_sample__(self, queue):
        while True:
            sample = self.__sample__(1)[0]
            queue.put(sample)

    def __put_data__(self, queue):
        while True:
            sample = self.__sample_queue__.get()
            data = self.__get_data_from_sample__(sample)
            queue.put(data)

    def __next__(self):
        if self.__count__ < len(self):
            self.__count__ += 1

            if self.num_workers > 0:
                data = self.__data_queue__.get()
            else:
                sample = self.__sample__(1)[0]
                data = self.__get_data_from_sample__(sample)

            return data
        else:
            raise StopIteration

    def __len__(self):
        return self.num_steps

    def __iter__(self):
        self.__count__ = 0
        return self
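# Hedged subclass sketch for the base class above: `__sample_nodes__` must
# return one tensor of (possibly repeated) node indices per requested sample.
# Uniform node sampling, with illustrative names:
class UniformNodeSampler(GraphSAINTSampler):
    def __sample_nodes__(self, num_examples):
        return [
            torch.randint(0, self.N, (self.batch_size, ), dtype=torch.long)
            for _ in range(num_examples)
        ]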
class GraphSAINTSampler(torch.utils.data.DataLoader):
    r"""The GraphSAINT sampler base class from the `"GraphSAINT: Graph
    Sampling Based Inductive Learning Method"
    <https://arxiv.org/abs/1907.04931>`_ paper.
    Given a graph in a :obj:`data` object, this class samples nodes and
    constructs subgraphs that can be processed in a mini-batch fashion.
    Normalization coefficients for each mini-batch are given via
    :obj:`node_norm` and :obj:`edge_norm` data attributes.

    .. note::

        See :class:`torch_geometric.data.GraphSAINTNodeSampler`,
        :class:`torch_geometric.data.GraphSAINTEdgeSampler` and
        :class:`torch_geometric.data.GraphSAINTRandomWalkSampler` for
        currently supported samplers.
        For an example of using GraphSAINT sampling, see
        `examples/graph_saint.py
        <https://github.com/rusty1s/pytorch_geometric/blob/master/examples/
        graph_saint.py>`_.

    Args:
        data (torch_geometric.data.Data): The graph data object.
        batch_size (int): The approximate number of samples per batch.
        num_steps (int, optional): The number of iterations per epoch.
            (default: :obj:`1`)
        sample_coverage (int): How many samples per node should be used to
            compute normalization statistics. (default: :obj:`0`)
        save_dir (string, optional): If set, will save normalization
            statistics to the :obj:`save_dir` directory for faster re-use.
            (default: :obj:`None`)
        log (bool, optional): If set to :obj:`False`, will not log any
            pre-processing progress. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`
            or :obj:`num_workers`.
    """
    def __init__(self, data, batch_size: int, num_steps: int = 1,
                 sample_coverage: int = 0, save_dir: Optional[str] = None,
                 log: bool = True, **kwargs):
        assert data.edge_index is not None
        assert 'node_norm' not in data
        assert 'edge_norm' not in data

        self.num_steps = num_steps
        self.__batch_size__ = batch_size
        self.sample_coverage = sample_coverage
        self.log = log

        self.N = N = data.num_nodes
        self.E = data.num_edges

        self.adj = SparseTensor(
            row=data.edge_index[0], col=data.edge_index[1],
            value=torch.arange(self.E, device=data.edge_index.device),
            sparse_sizes=(N, N))

        self.data = copy.copy(data)
        self.data.edge_index = None

        super(GraphSAINTSampler,
              self).__init__(self, batch_size=1, collate_fn=self.__collate__,
                             **kwargs)

        if self.sample_coverage > 0:
            path = osp.join(save_dir or '', self.__filename__)
            if save_dir is not None and osp.exists(path):  # pragma: no cover
                self.node_norm, self.edge_norm = torch.load(path)
            else:
                self.node_norm, self.edge_norm = self.__compute_norm__()
                if save_dir is not None:  # pragma: no cover
                    torch.save((self.node_norm, self.edge_norm), path)

    @property
    def __filename__(self):
        return f'{self.__class__.__name__.lower()}_{self.sample_coverage}.pt'

    def __len__(self):
        return self.num_steps

    def __sample_nodes__(self, batch_size):
        raise NotImplementedError

    def __getitem__(self, idx):
        node_idx = self.__sample_nodes__(self.__batch_size__).unique()
        adj, _ = self.adj.saint_subgraph(node_idx)
        return node_idx, adj

    def __collate__(self, data_list):
        assert len(data_list) == 1
        node_idx, adj = data_list[0]

        data = self.data.__class__()
        data.num_nodes = node_idx.size(0)
        row, col, edge_idx = adj.coo()
        data.edge_index = torch.stack([row, col], dim=0)

        for key, item in self.data:
            if isinstance(item, torch.Tensor) and item.size(0) == self.N:
                data[key] = item[node_idx]
            elif isinstance(item, torch.Tensor) and item.size(0) == self.E:
                data[key] = item[edge_idx]
            else:
                data[key] = item

        if self.sample_coverage > 0:
            data.node_norm = self.node_norm[node_idx]
            data.edge_norm = self.edge_norm[edge_idx]

        return data

    def __compute_norm__(self):
        node_count = torch.zeros(self.N, dtype=torch.float)
        edge_count = torch.zeros(self.E, dtype=torch.float)

        loader = torch.utils.data.DataLoader(self, batch_size=200,
                                             collate_fn=lambda x: x,
                                             num_workers=self.num_workers)

        if self.log:  # pragma: no cover
            pbar = tqdm(total=self.N * self.sample_coverage)
            pbar.set_description('Compute GraphSAINT normalization')

        num_samples = total_sampled_nodes = 0
        while total_sampled_nodes < self.N * self.sample_coverage:
            for data in loader:
                for node_idx, adj in data:
                    edge_idx = adj.storage.value()
                    node_count[node_idx] += 1
                    edge_count[edge_idx] += 1
                    total_sampled_nodes += node_idx.size(0)

                    if self.log:  # pragma: no cover
                        pbar.update(node_idx.size(0))
            num_samples += self.num_steps

        if self.log:  # pragma: no cover
            pbar.close()

        row, _, edge_idx = self.adj.coo()
        t = torch.empty_like(edge_count).scatter_(0, edge_idx,
                                                  node_count[row])
        edge_norm = (t / edge_count).clamp_(0, 1e4)
        edge_norm[torch.isnan(edge_norm)] = 0.1

        node_count[node_count == 0] = 0.1
        node_norm = num_samples / node_count / self.N

        return node_norm, edge_norm
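# Hedged subclass sketch for the DataLoader-based class above: here
# `__sample_nodes__` receives a batch size and returns a single index tensor.
# Uniform node sampling, with illustrative names:
class UniformSAINTNodeSampler(GraphSAINTSampler):
    def __sample_nodes__(self, batch_size):
        return torch.randint(0, self.N, (batch_size, ), dtype=torch.long)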
def sample_adj(self, adj: SparseTensor) -> SparseTensor:
    row, col, _ = adj.coo()
    deg = degree(row, num_nodes=adj.size(0))
    # Keep each edge with probability `max_sample / deg(row)`, so the
    # expected number of retained out-edges per node is at most `max_sample`:
    prob = (self.max_sample * (1. / deg))[row]
    mask = torch.rand_like(prob) < prob
    return adj.masked_select_nnz(mask, layout='coo')
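# Hedged standalone sketch of the degree-based edge sampling above, with
# `max_sample = 2` as an example value:
import torch
from torch_sparse import SparseTensor
from torch_geometric.utils import degree

adj = SparseTensor.from_edge_index(torch.tensor([[0, 0, 0, 1], [1, 2, 3, 2]]),
                                   sparse_sizes=(4, 4))
row, col, _ = adj.coo()
deg = degree(row, num_nodes=adj.size(0))
prob = (2 * (1. / deg))[row]
mask = torch.rand_like(prob) < prob
sub_adj = adj.masked_select_nnz(mask, layout='coo')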
class MySAINTSampler(object):
    r"""A new random-walk sampler for GraphSAINT that samples initial nodes
    by iterating over node permutations. The benefit is that we can leverage
    this sampler for subgraph-based inference.

    Args:
        data (torch_geometric.data.Data): The graph data object.
        batch_size (int): The number of walks to sample per batch.
        sample_type (string, optional): The sampling strategy, either
            :obj:`"node"` or :obj:`"random_walk"`.
            (default: :obj:`"random_walk"`)
        walk_length (int): The length of each random walk.
        sample_coverage (int): How many samples per node should be used to
            compute normalization statistics. (default: :obj:`50`)
        save_dir (string, optional): If set, will save normalization
            statistics to the :obj:`save_dir` directory for faster re-use.
            (default: :obj:`None`)
        log (bool, optional): If set to :obj:`False`, will not log any
            progress. (default: :obj:`True`)
    """
    def __init__(self, data, batch_size, sample_type='random_walk',
                 walk_length=2, sample_coverage=50, save_dir=None, log=True):
        assert data.edge_index is not None
        assert 'node_norm' not in data
        assert 'edge_norm' not in data
        assert sample_type in ('node', 'random_walk')

        self.N = N = data.num_nodes
        self.E = data.num_edges

        self.adj = SparseTensor(row=data.edge_index[0],
                                col=data.edge_index[1], value=data.edge_attr,
                                sparse_sizes=(N, N))

        self.data = copy.copy(data)
        self.data.edge_index = None
        self.data.edge_attr = None

        self.sample_type = sample_type
        self.batch_size = batch_size
        self.walk_length = walk_length
        self.sample_coverage = sample_coverage
        self.log = log

        path = osp.join(save_dir or '', self.__filename__)
        if save_dir is not None and osp.exists(path):  # pragma: no cover
            self.node_norm, self.edge_norm = torch.load(path)
        else:
            self.node_norm, self.edge_norm = self.__compute_norm__()
            if save_dir is not None:  # pragma: no cover
                torch.save((self.node_norm, self.edge_norm), path)

    @property
    def __filename__(self):
        return f'{self.__class__.__name__.lower()}_{self.sample_coverage}.pt'

    def __sample_nodes__(self):
        """Sample initial nodes by iterating over a random permutation of
        all nodes."""
        tmp_map = torch.arange(self.N, dtype=torch.long)
        all_n_id = torch.randperm(self.N, dtype=torch.long)
        node_samples = []
        for s_id in range(0, self.N, self.batch_size):
            init_n_id = all_n_id[s_id:s_id + self.batch_size]  # [batch_size]

            if self.sample_type == 'random_walk':
                # [batch_size, walk_length + 1]
                n_id = self.adj.random_walk(init_n_id, self.walk_length)
                n_id = n_id.flatten().unique()  # [num_nodes_in_subgraph]
                tmp_map[n_id] = torch.arange(n_id.size(0), dtype=torch.long)
                res_n_id = tmp_map[init_n_id]
            elif self.sample_type == 'node':
                n_id = init_n_id
                res_n_id = torch.arange(n_id.size(0), dtype=torch.long)
            else:
                raise ValueError(
                    'Unsupported sample type {}'.format(self.sample_type))

            node_samples.append((n_id, res_n_id))

        return node_samples

    def __sample__(self, num_epoches):
        samples = []
        for _ in range(num_epoches):
            node_samples = self.__sample_nodes__()
            for n_id, res_n_id in node_samples:
                adj, e_id = self.adj.saint_subgraph(n_id)
                samples.append((n_id, e_id, adj, res_n_id))

        return samples

    def __compute_norm__(self):
        node_count = torch.zeros(self.N, dtype=torch.float)
        edge_count = torch.zeros(self.E, dtype=torch.float)

        if self.log:
            pbar = tqdm(total=self.sample_coverage)
            pbar.set_description('GraphSAINT Normalization')

        num_samples = len(self) * self.sample_coverage
        for _ in range(self.sample_coverage):
            samples = self.__sample__(1)
            for n_id, e_id, _, _ in samples:
                node_count[n_id] += 1
                edge_count[e_id] += 1
            if self.log:
                pbar.update(1)

        row, col, _ = self.adj.coo()

        edge_norm = (node_count[col] / edge_count).clamp_(0, 1e4)
        edge_norm[torch.isnan(edge_norm)] = 0.1

        node_count[node_count == 0] = 0.1
        node_norm = num_samples / (node_count * self.N)

        return node_norm, edge_norm

    def __get_data_from_sample__(self, sample):
        n_id, e_id, adj, res_n_id = sample

        data = self.data.__class__()
        data.num_nodes = n_id.size(0)
        row, col, value = adj.coo()
        data.edge_index = torch.stack([row, col], dim=0)
        data.edge_attr = value

        for key, item in self.data:
            if item.size(0) == self.N:
                data[key] = item[n_id]
            elif item.size(0) == self.E:
                data[key] = item[e_id]
            else:
                data[key] = item

        data.node_norm = self.node_norm[n_id]
        data.edge_norm = self.edge_norm[e_id]
        data.n_id = n_id
        data.res_n_id = res_n_id

        return data

    def __len__(self):
        return (self.N + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        for sample in self.__sample__(1):
            data = self.__get_data_from_sample__(sample)
            yield data
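# Hedged usage sketch of the sampler above: one epoch iterates over all
# initial nodes exactly once, which is what makes it usable for
# subgraph-based inference as well (`data` and `model` are illustrative):
sampler = MySAINTSampler(data, batch_size=256, sample_type='random_walk',
                         walk_length=2, sample_coverage=50)
for sub_data in sampler:
    out = model(sub_data.x, sub_data.edge_index)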
class GraphSAINTSampler(torch.utils.data.DataLoader):
    def __init__(self, data, batch_size: int, num_steps: int = 1,
                 sample_coverage: int = 0, save_dir: Optional[str] = None,
                 log: bool = True, **kwargs):
        assert data.edge_index is not None
        assert 'node_norm' not in data
        assert 'edge_norm' not in data

        self.num_steps = num_steps
        self.__batch_size__ = batch_size
        self.sample_coverage = sample_coverage
        self.log = log

        self.N = N = data.num_nodes
        self.E = data.num_edges

        self.adj = SparseTensor(
            row=data.edge_index[0], col=data.edge_index[1],
            value=torch.arange(self.E, device=data.edge_index.device),
            sparse_sizes=(N, N))

        self.data = copy.copy(data)
        self.data.edge_index = None

        super(GraphSAINTSampler,
              self).__init__(self, batch_size=1, collate_fn=self.__collate__,
                             **kwargs)

        if self.sample_coverage > 0:
            path = osp.join(save_dir or '', self.__filename__)
            if save_dir is not None and osp.exists(path):  # pragma: no cover
                self.node_norm, self.edge_norm = torch.load(path)
            else:
                self.node_norm, self.edge_norm = self.__compute_norm__()
                if save_dir is not None:  # pragma: no cover
                    torch.save((self.node_norm, self.edge_norm), path)

    @property
    def __filename__(self):
        return f'{self.__class__.__name__.lower()}_{self.sample_coverage}.pt'

    def __len__(self):
        return self.num_steps

    def __sample_nodes__(self, batch_size):
        raise NotImplementedError

    def __getitem__(self, idx):
        node_idx = self.__sample_nodes__(self.__batch_size__).unique()
        adj, _ = self.adj.saint_subgraph(node_idx)
        return node_idx, adj

    def __collate__(self, data_list):
        assert len(data_list) == 1
        node_idx, adj = data_list[0]

        data = self.data.__class__()
        data.num_nodes = node_idx.size(0)
        row, col, edge_idx = adj.coo()
        data.edge_index = torch.stack([row, col], dim=0)

        for key, item in self.data:
            if isinstance(item, torch.Tensor) and item.size(0) == self.N:
                data[key] = item[node_idx]
            elif isinstance(item, torch.Tensor) and item.size(0) == self.E:
                data[key] = item[edge_idx]
            else:
                data[key] = item

        if self.sample_coverage > 0:
            data.node_norm = self.node_norm[node_idx]
            data.edge_norm = self.edge_norm[edge_idx]

        return data

    def __compute_norm__(self):
        node_count = torch.zeros(self.N, dtype=torch.float)
        edge_count = torch.zeros(self.E, dtype=torch.float)

        loader = torch.utils.data.DataLoader(self, batch_size=200,
                                             collate_fn=lambda x: x,
                                             num_workers=self.num_workers)

        if self.log:
            pbar = tqdm(total=self.N * self.sample_coverage)
            pbar.set_description('Compute GraphSAINT normalization')

        num_samples = total_sampled_nodes = 0
        while total_sampled_nodes < self.N * self.sample_coverage:
            for data in loader:
                for node_idx, adj in data:
                    edge_idx = adj.storage.value()
                    node_count[node_idx] += 1
                    edge_count[edge_idx] += 1
                    total_sampled_nodes += node_idx.size(0)

                    if self.log:
                        pbar.update(node_idx.size(0))
            # Each full pass over the loader yields `num_steps` samples (the
            # dataset length), not the loader batch size of 200:
            num_samples += self.num_steps

        if self.log:
            pbar.close()

        row, _, edge_idx = self.adj.coo()
        t = torch.empty_like(edge_count).scatter_(0, edge_idx,
                                                  node_count[row])
        edge_norm = (t / edge_count).clamp_(0, 1e4)
        edge_norm[torch.isnan(edge_norm)] = 0.1

        node_count[node_count == 0] = 0.1
        node_norm = num_samples / node_count / self.N

        return node_norm, edge_norm
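# Hedged training-loop sketch showing how the normalization coefficients
# computed above are meant to be consumed, in the spirit of
# examples/graph_saint.py (`model`, `loader`, and `train_mask` are
# illustrative): the per-node loss is reweighted by `node_norm` to obtain an
# unbiased full-graph loss estimate.
import torch.nn.functional as F

for data in loader:
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out, data.y, reduction='none')
    loss = (loss * data.node_norm)[data.train_mask].sum()
    loss.backward()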