def test_gated_graph_conv():
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    row, col = edge_index
    value = torch.rand(row.size(0))
    adj2 = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
    adj1 = adj2.set_value(None)

    conv = GatedGraphConv(32, num_layers=3)
    assert conv.__repr__() == 'GatedGraphConv(32, num_layers=3)'
    out1 = conv(x, edge_index)
    assert out1.size() == (4, 32)
    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)
    out2 = conv(x, edge_index, value)
    assert out2.size() == (4, 32)
    assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)

    t = '(Tensor, Tensor, OptTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit(x, edge_index).tolist() == out1.tolist()
    assert jit(x, edge_index, value).tolist() == out2.tolist()

    t = '(Tensor, SparseTensor, OptTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert torch.allclose(jit(x, adj1.t()), out1, atol=1e-6)
    assert torch.allclose(jit(x, adj2.t()), out2, atol=1e-6)
def test_lg_conv():
    x = torch.randn(4, 8)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    value = torch.rand(row.size(0))
    adj2 = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
    adj1 = adj2.set_value(None)

    conv = LGConv()
    assert str(conv) == 'LGConv()'
    out1 = conv(x, edge_index)
    assert out1.size() == (4, 8)
    assert torch.allclose(conv(x, adj1.t()), out1)
    out2 = conv(x, edge_index, value)
    assert out2.size() == (4, 8)
    assert torch.allclose(conv(x, adj2.t()), out2)

    t = '(Tensor, Tensor, OptTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert torch.allclose(jit(x, edge_index), out1)
    assert torch.allclose(jit(x, edge_index, value), out2)

    t = '(Tensor, SparseTensor, OptTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert torch.allclose(jit(x, adj1.t()), out1)
    assert torch.allclose(jit(x, adj2.t()), out2)
def test_gcn_conv():
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    row, col = edge_index
    value = torch.rand(row.size(0))
    adj2 = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
    adj1 = adj2.set_value(None)

    conv = GCNConv(16, 32)
    assert conv.__repr__() == 'GCNConv(16, 32)'
    out1 = conv(x, edge_index)
    assert out1.size() == (4, 32)
    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)
    out2 = conv(x, edge_index, value)
    assert out2.size() == (4, 32)
    assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)

    if is_full_test():
        t = '(Tensor, Tensor, OptTensor) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert jit(x, edge_index).tolist() == out1.tolist()
        assert jit(x, edge_index, value).tolist() == out2.tolist()

        t = '(Tensor, SparseTensor, OptTensor) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert torch.allclose(jit(x, adj1.t()), out1, atol=1e-6)
        assert torch.allclose(jit(x, adj2.t()), out2, atol=1e-6)

    conv.cached = True
    conv(x, edge_index)
    assert conv(x, edge_index).tolist() == out1.tolist()
    conv(x, adj1.t())
    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)
def test_dna_conv():
    x = torch.randn((4, 3, 32))
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    row, col = edge_index
    value = torch.rand(row.size(0))
    adj2 = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
    adj1 = adj2.set_value(None)

    conv = DNAConv(32, heads=4, groups=8, dropout=0.0)
    assert conv.__repr__() == 'DNAConv(32, heads=4, groups=8)'
    out1 = conv(x, edge_index)
    assert out1.size() == (4, 32)
    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)
    out2 = conv(x, edge_index, value)
    assert out2.size() == (4, 32)
    assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)

    t = '(Tensor, Tensor, OptTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit(x, edge_index).tolist() == out1.tolist()
    assert jit(x, edge_index, value).tolist() == out2.tolist()

    t = '(Tensor, SparseTensor, OptTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert torch.allclose(jit(x, adj1.t()), out1, atol=1e-6)
    assert torch.allclose(jit(x, adj2.t()), out2, atol=1e-6)

    conv.cached = True
    conv(x, edge_index)
    assert conv(x, edge_index).tolist() == out1.tolist()
    conv(x, adj1.t())
    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)
def absv(src: SparseTensor) -> SparseTensor:
    """Return a copy of :obj:`src` whose values are replaced by their
    absolute values.

    The input and output SparseTensors share the memory of the
    :obj:`row`, :obj:`col` and :obj:`rowptr` fields; only :obj:`value`
    differs.

    Parameters
    ----------
    src : SparseTensor

    Returns
    -------
    SparseTensor
    """
    val = src.storage.value()
    assert val is not None
    abs_val = val.abs()
    return src.set_value(abs_val, layout="csr")
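A minimal usage sketch of the helper above (not from the source; the concrete row/col/value tensors are made up for illustration). It shows that the returned tensor holds the absolute values while, per the docstring, the sparsity structure of the input is reused:

import torch
from torch_sparse import SparseTensor

row = torch.tensor([0, 0, 1, 2])
col = torch.tensor([1, 2, 0, 0])
value = torch.tensor([-1.0, 2.0, -3.0, 4.0])
src = SparseTensor(row=row, col=col, value=value, sparse_sizes=(3, 3))

out = absv(src)
# The values are replaced by their absolute values:
assert torch.equal(out.storage.value(), value.abs())
# Per the docstring, `out` reuses the row/col/rowptr storage of `src`;
# only the value tensor is new.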
def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
    adj_t_2 = adj_t
    if len(self.aggregators) > 1 and 'symnorm' in self.aggregators:
        # Drop the symmetric normalization weights for all other aggregators:
        adj_t_2 = adj_t.set_value(None)

    outs = []
    for aggr in self.aggregators:
        if aggr == 'symnorm':
            out = matmul(adj_t, x, reduce='sum')
        elif aggr in ['var', 'std']:
            mean = matmul(adj_t_2, x, reduce='mean')
            mean_sq = matmul(adj_t_2, x * x, reduce='mean')
            out = mean_sq - mean * mean
            if aggr == 'std':
                out = torch.sqrt(out.relu_() + 1e-5)
        else:
            out = matmul(adj_t_2, x, reduce=aggr)
        outs.append(out)

    return torch.stack(outs, dim=1) if len(outs) > 1 else outs[0]
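A small standalone check (not part of the source; the toy graph and tensors are made up) of the variance identity used in the 'var'/'std' branch above, Var[x] = E[x^2] - (E[x])^2, computed with mean-aggregation over neighbors:

import torch
from torch_sparse import SparseTensor, matmul

# Node 0 has neighbors 1 and 2 (rows of adj_t index the target nodes).
adj_t = SparseTensor(row=torch.tensor([0, 0]), col=torch.tensor([1, 2]),
                     sparse_sizes=(3, 3))
x = torch.tensor([[0.0], [1.0], [3.0]])

mean = matmul(adj_t, x, reduce='mean')
mean_sq = matmul(adj_t, x * x, reduce='mean')
var = mean_sq - mean * mean
# Neighbors of node 0 carry features 1 and 3 -> mean 2, population variance 1.
assert torch.allclose(var[0], torch.tensor([1.0]))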
def message_and_aggregate(self, adj_t: SparseTensor,
                          x: OptPairTensor) -> Tensor:
    adj_t = adj_t.set_value(None, layout=None)
    return matmul(adj_t, x[0], reduce=self.aggr)
def main():
    parser = argparse.ArgumentParser(description='OGBL-COLLAB (GNN)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--use_sage', action='store_true')
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--batch_size', type=int, default=64 * 1024)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--eval_steps', type=int, default=1)
    parser.add_argument('--runs', type=int, default=10)
    args = parser.parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name='ogbl-collab')
    split_edge = dataset.get_edge_split()
    data = dataset[0]

    x = data.x.to(device)

    edge_index = data.edge_index.to(device)
    weight = data.edge_weight.view(-1).to(torch.float).to(device)
    adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=weight)
    adj = adj.sum(dim=1).pow(-1).view(-1, 1) * adj

    if args.use_sage:
        model = SAGE(x.size(-1), args.hidden_channels, args.hidden_channels,
                     args.num_layers, args.dropout).to(device)
    else:
        model = GCN(x.size(-1), args.hidden_channels, args.hidden_channels,
                    args.num_layers, args.dropout).to(device)

        # Pre-compute GCN normalization.
        adj = adj.set_value(None)
        adj = adj.set_diag()
        deg = adj.sum(dim=1)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        adj = deg_inv_sqrt.view(-1, 1) * adj * deg_inv_sqrt.view(1, -1)

    predictor = LinkPredictor(args.hidden_channels, args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)

    evaluator = Evaluator(name='ogbl-ppa')
    loggers = {
        'Hits@10': Logger(args.runs, args),
        'Hits@50': Logger(args.runs, args),
        'Hits@100': Logger(args.runs, args),
    }

    for run in range(args.runs):
        model.reset_parameters()
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(
            list(model.parameters()) + list(predictor.parameters()),
            lr=args.lr)

        for epoch in range(1, 1 + args.epochs):
            loss = train(model, predictor, x, adj, split_edge, optimizer,
                         args.batch_size)

            if epoch % args.eval_steps == 0:
                results = test(model, predictor, x, adj, split_edge,
                               evaluator, args.batch_size)
                for key, result in results.items():
                    loggers[key].add_result(run, result)

                if epoch % args.log_steps == 0:
                    for key, result in results.items():
                        train_hits, valid_hits, test_hits = result
                        print(key)
                        print(f'Run: {run + 1:02d}, '
                              f'Epoch: {epoch:02d}, '
                              f'Loss: {loss:.4f}, '
                              f'Train: {100 * train_hits:.2f}%, '
                              f'Valid: {100 * valid_hits:.2f}%, '
                              f'Test: {100 * test_hits:.2f}%')
                    print('---')

        for key in loggers.keys():
            print(key)
            loggers[key].print_statistics(run)

    for key in loggers.keys():
        print(key)
        loggers[key].print_statistics()
def main():
    parser = argparse.ArgumentParser(description='OGBL-Citation (GNN)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--use_sage', action='store_true')
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0)
    parser.add_argument('--batch_size', type=int, default=64 * 1024)
    parser.add_argument('--lr', type=float, default=0.0005)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--eval_steps', type=int, default=1)
    parser.add_argument('--runs', type=int, default=10)
    args = parser.parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name='ogbl-citation')
    split_edge = dataset.get_edge_split()
    data = dataset[0]

    # We randomly pick some training samples that we want to evaluate on:
    torch.manual_seed(12345)
    idx = torch.randperm(split_edge['train']['source_node'].numel())[:86596]
    split_edge['eval_train'] = {
        'source_node': split_edge['train']['source_node'][idx],
        'target_node': split_edge['train']['target_node'][idx],
        'target_node_neg': split_edge['valid']['target_node_neg'],
    }

    x = data.x.to(device)

    edge_index = data.edge_index.to(device)
    edge_index = to_undirected(edge_index, data.num_nodes)
    adj = SparseTensor(row=edge_index[0], col=edge_index[1])

    if args.use_sage:
        model = SAGE(x.size(-1), args.hidden_channels, args.hidden_channels,
                     args.num_layers, args.dropout).to(device)
    else:
        model = GCN(x.size(-1), args.hidden_channels, args.hidden_channels,
                    args.num_layers, args.dropout).to(device)

        # Pre-compute GCN normalization.
        adj = adj.set_value(None)
        adj = adj.set_diag()
        deg = adj.sum(dim=1)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        adj = deg_inv_sqrt.view(-1, 1) * adj * deg_inv_sqrt.view(1, -1)

    predictor = LinkPredictor(args.hidden_channels, args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)

    evaluator = Evaluator(name='ogbl-citation')
    logger = Logger(args.runs, args)

    for run in range(args.runs):
        model.reset_parameters()
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(
            list(model.parameters()) + list(predictor.parameters()),
            lr=args.lr)

        for epoch in range(1, 1 + args.epochs):
            loss = train(model, predictor, x, adj, split_edge, optimizer,
                         args.batch_size)
            print(f'Run: {run + 1:02d}, Epoch: {epoch:02d}, Loss: {loss:.4f}')

            if epoch % args.eval_steps == 0:
                result = test(model, predictor, x, adj, split_edge, evaluator,
                              args.batch_size)
                logger.add_result(run, result)

                if epoch % args.log_steps == 0:
                    train_mrr, valid_mrr, test_mrr = result
                    print(f'Run: {run + 1:02d}, '
                          f'Epoch: {epoch:02d}, '
                          f'Loss: {loss:.4f}, '
                          f'Train: {train_mrr:.4f}, '
                          f'Valid: {valid_mrr:.4f}, '
                          f'Test: {test_mrr:.4f}')

        logger.print_statistics(run)

    logger.print_statistics()
def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
    adj_t = adj_t.set_value(None)
    return matmul(adj_t, x, reduce=self.aggr)
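A brief illustrative sketch (not from the source; the toy adjacency and feature tensors are made up) of what adj_t.set_value(None) does in the message_and_aggregate snippets above: with the values dropped, matmul treats every edge as weight one, so the aggregation is unweighted:

import torch
from torch_sparse import SparseTensor, matmul

row = torch.tensor([0, 0, 1])
col = torch.tensor([1, 2, 2])
value = torch.tensor([0.5, 2.0, 3.0])
adj_t = SparseTensor(row=row, col=col, value=value, sparse_sizes=(3, 3))
x = torch.ones(3, 4)

weighted = matmul(adj_t, x, reduce='sum')                     # uses edge weights
unweighted = matmul(adj_t.set_value(None), x, reduce='sum')   # each edge counts as 1

# Node 0 has two outgoing entries with weights 0.5 and 2.0:
assert torch.allclose(weighted[0], torch.full((4,), 2.5))
assert torch.allclose(unweighted[0], torch.full((4,), 2.0))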