def get_adj(row, col, N, asymm_norm=False, set_diag=True, remove_diag=False):
    """Build an N x N adjacency from COO indices, adjust the diagonal,
    degree-normalize it, and return it as a scipy CSR matrix.

    Args:
        row, col: 1-D index tensors of the edge endpoints (COO layout).
        N: number of nodes; the matrix is N x N.
        asymm_norm: if True apply row normalization D^-1 A; otherwise the
            symmetric form D^-1/2 A D^-1/2.
        set_diag: force self-loops on the diagonal (takes priority).
        remove_diag: strip diagonal entries (only if ``set_diag`` is False).

    Returns:
        ``scipy.sparse.csr_matrix`` with the normalized adjacency.
    """
    adj = SparseTensor(row=row, col=col, sparse_sizes=(N, N))

    # Diagonal handling: set wins over remove; otherwise leave as-is.
    if set_diag:
        print('... setting diagonal entries')
        adj = adj.set_diag()
    elif remove_diag:
        print('... removing diagonal entries')
        adj = adj.remove_diag()
    else:
        print('... keeping diag elements as they are')

    degrees = adj.sum(dim=1).to(torch.float)
    if asymm_norm:
        print('... performing asymmetric normalization')
        inv_deg = degrees.pow(-1.0)
        inv_deg[inv_deg == float('inf')] = 0  # isolated nodes -> zero row
        adj = inv_deg.view(-1, 1) * adj
    else:
        print('... performing symmetric normalization')
        inv_sqrt_deg = degrees.pow(-0.5)
        inv_sqrt_deg[inv_sqrt_deg == float('inf')] = 0  # guard isolated nodes
        adj = inv_sqrt_deg.view(-1, 1) * adj * inv_sqrt_deg.view(1, -1)

    return adj.to_scipy(layout='csr')
def degree_matrix(adj: SparseTensor, indeg=True):
    """Return the diagonal degree matrix of ``adj`` as a SparseTensor.

    Args:
        adj: sparse adjacency matrix.
        indeg: if True use in-degrees (sums over dim 0); otherwise
            out-degrees (sums over dim 1).
    """
    n = adj.size(-1)
    dev = adj.device()
    if indeg:
        deg = adj.sum(0)
    else:
        deg = adj.sum(1)
    diag_idx = torch.arange(n, device=dev)
    values = torch.as_tensor(deg, device=dev)
    # Diagonal COO indices are generated in order, so is_sorted=True is safe.
    return SparseTensor(row=diag_idx, col=diag_idx, value=values,
                        sparse_sizes=(n, n), is_sorted=True)
def in_degree(adj: SparseTensor, bunch=None):
    """In-degree (column sum) of every node, or of the nodes in ``bunch``.

    For a small ``bunch`` (at most 20% of the nodes) the degrees are read
    column-by-column from the CSC layout, avoiding a full reduction over
    all non-zeros.
    """
    if bunch is None:
        return adj.sum(0)

    num_nodes = adj.size(0)
    # Large selection: one full reduction plus fancy-indexing is cheaper.
    if len(bunch) > int(0.2 * num_nodes):
        return adj.sum(0)[bunch]

    colptr, _, values = adj.csc()
    out = values.new_zeros(len(bunch))
    for pos, node in enumerate(bunch):
        out[pos] = values[colptr[node]:colptr[node + 1]].sum()
    return out
def test(model, data, evaluator):
    """Evaluate a trained GraphSAGE model full-batch on CPU via numpy.

    Copies each layer's weights into plain arrays, runs a numpy-based
    ``SAGEInference`` forward pass over the row-normalized adjacency, and
    reports accuracy on the train/valid/test node masks.
    """
    print('Evaluating full-batch GNN on CPU...')

    # Extract each SAGEConv's parameters as (W_rel^T, b_rel, W_root^T).
    weights = []
    for conv in model.convs:
        weights.append((
            conv.lin_rel.weight.t().cpu().detach().numpy(),
            conv.lin_rel.bias.cpu().detach().numpy(),
            conv.lin_root.weight.t().cpu().detach().numpy(),
        ))
    model = SAGEInference(weights)

    x = data.x.numpy()
    adj = SparseTensor(row=data.edge_index[0], col=data.edge_index[1])
    # Row-normalize: D^-1 A (mean aggregation over neighbors).
    adj = adj.sum(dim=1).pow(-1).view(-1, 1) * adj
    adj = adj.to_scipy(layout='csr')

    y_true = data.y
    y_pred = torch.from_numpy(model(x, adj)).argmax(dim=-1, keepdim=True)

    def _masked_acc(mask):
        # Accuracy restricted to the nodes selected by ``mask``.
        return evaluator.eval({
            'y_true': y_true[mask],
            'y_pred': y_pred[mask],
        })['acc']

    return (_masked_acc(data.train_mask), _masked_acc(data.valid_mask),
            _masked_acc(data.test_mask))
def laplace(adj: SparseTensor, lap_type=None):
    """Build the graph Laplacian of ``adj`` as a SparseTensor.

    ``lap_type`` selects the variant:
      * None or "sym": symmetric normalized  L = I - D^-1/2 A D^-1/2
      * "rw":          random-walk           L = I - D^-1 A
      * "comb":        combinatorial         L = D - A

    Raises:
        TypeError: for any other ``lap_type`` value.
    """
    M, N = adj.sizes()
    assert M == N  # Laplacian only defined for square adjacency matrices.
    row, col, val = adj.clone().coo()
    # Unweighted graphs store no values; treat every edge weight as 1.
    val = col.new_ones(col.shape, dtype=adj.dtype()) if val is None else val
    deg = adj.sum(0)
    loop_index = torch.arange(N, device=adj.device()).unsqueeze_(0)
    if lap_type in (None, "sym"):
        deg05 = deg.pow(-0.5)
        deg05[deg05 == float("inf")] = 0  # isolated nodes: 0 instead of inf
        wgt = deg05[row] * val * deg05[col]
        # Negate off-diagonal entries, then append N ones for the diagonal.
        wgt = torch.cat([-wgt.unsqueeze_(0), val.new_ones(1, N)], 1).squeeze_()
    elif lap_type == "rw":
        deg_inv = 1.0 / deg
        deg_inv[deg_inv == float("inf")] = 0  # guard isolated nodes
        wgt = deg_inv[row] * val
        wgt = torch.cat([-wgt.unsqueeze_(0), val.new_ones(1, N)], 1).squeeze_()
    elif lap_type == "comb":
        # Combinatorial: -A off the diagonal, node degrees on the diagonal.
        wgt = torch.cat([-val.unsqueeze_(0), deg.unsqueeze_(0)], 1).squeeze_()
    else:
        raise TypeError("Invalid laplace type: {}".format(lap_type))
    # Append the N diagonal (self-loop) positions to the COO index lists.
    row = torch.cat([row.unsqueeze_(0), loop_index], 1).squeeze_()
    col = torch.cat([col.unsqueeze_(0), loop_index], 1).squeeze_()
    lap = SparseTensor(row=row, col=col, value=wgt, sparse_sizes=(M, N))
    return lap
def preprocess(data, preprocess="diffusion", num_propagations=10, p=None, alpha=None, use_cache=True, post_fix=""):
    """Compute (or load from cache) a propagation-based node embedding.

    Args:
        data: graph data object exposing ``num_nodes``, ``edge_index``, ``x``.
        preprocess: one of "community", "spectral", "sgc" or "diffusion".
        num_propagations: number of propagation steps for sgc/diffusion.
        p, alpha: extra hyper-parameters forwarded to ``diffusion``.
        use_cache: if True, try ``embeddings/{preprocess}{post_fix}.pt``
            before recomputing.
        post_fix: suffix appended to the cache file name.

    Returns:
        The embedding produced by the selected method.

    Raises:
        ValueError: if ``preprocess`` names an unknown method (previously
            this fell through to a NameError on ``result``).
    """
    cache_path = f'embeddings/{preprocess}{post_fix}.pt'
    if use_cache:
        try:
            x = torch.load(cache_path)
            print('Using cache')
            return x
        # Was a bare ``except:`` — narrowed so SystemExit/KeyboardInterrupt
        # are no longer swallowed while any load failure still regenerates.
        except Exception:
            print(f'{cache_path} not found or not enough iterations! Regenerating it now')
            # Creates a new (empty) file so later runs see the path exists.
            with open(cache_path, 'w') as fp:
                pass

    if preprocess == "community":
        return community(data, post_fix)
    if preprocess == "spectral":
        return spectral(data, post_fix)

    print('Computing adj...')
    N = data.num_nodes
    data.edge_index = to_undirected(data.edge_index, data.num_nodes)
    row, col = data.edge_index
    # Symmetrically normalized adjacency with self-loops: D^-1/2 (A+I) D^-1/2.
    adj = SparseTensor(row=row, col=col, sparse_sizes=(N, N))
    adj = adj.set_diag()
    deg = adj.sum(dim=1).to(torch.float)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0  # guard isolated nodes
    adj = deg_inv_sqrt.view(-1, 1) * adj * deg_inv_sqrt.view(1, -1)
    adj = adj.to_scipy(layout='csr')

    print(f'Start {preprocess} processing')
    if preprocess == "sgc":
        result = sgc(data.x.numpy(), adj, num_propagations)
    # if preprocess == "lp":
    #     result = lp(adj, data.y.data, num_propagations, p = p, alpha = alpha, preprocess = preprocess)
    elif preprocess == "diffusion":
        result = diffusion(data.x.numpy(), adj, num_propagations, p=p, alpha=alpha)
    else:
        raise ValueError(f'Unknown preprocess method: {preprocess}')

    torch.save(result, cache_path)
    return result
def get_adj(row, col, N, asymm_norm=False, set_diag=True, remove_diag=False):
    """Build and degree-normalize an N x N SparseTensor adjacency.

    ``asymm_norm`` selects row normalization D^-1 A; otherwise the
    symmetric form D^-1/2 A D^-1/2 is applied. ``set_diag`` adds
    self-loops and takes priority over ``remove_diag``.
    """
    adj = SparseTensor(row=row, col=col, sparse_sizes=(N, N))
    if set_diag:
        adj = adj.set_diag()
    elif remove_diag:
        adj = adj.remove_diag()

    degrees = adj.sum(dim=1).to(torch.float)
    if asymm_norm:
        scale = degrees.pow(-1.0)
        scale[scale == float('inf')] = 0  # isolated nodes -> zero row
        normalized = scale.view(-1, 1) * adj
    else:
        scale = degrees.pow(-0.5)
        scale[scale == float('inf')] = 0  # guard isolated nodes
        normalized = scale.view(-1, 1) * adj * scale.view(1, -1)
    return normalized
def process_adj(data):
    """Symmetrize ``data.edge_index`` in place and build its adjacency.

    Returns:
        (adj, deg_inv_sqrt): the un-normalized SparseTensor adjacency and
        the per-node D^-0.5 factors (0 for isolated nodes), so callers can
        apply symmetric normalization themselves.
    """
    num_nodes = data.num_nodes
    # NOTE: mutates data.edge_index (makes the graph undirected).
    data.edge_index = to_undirected(data.edge_index, data.num_nodes)
    src, dst = data.edge_index
    adj = SparseTensor(row=src, col=dst, sparse_sizes=(num_nodes, num_nodes))
    inv_sqrt_deg = adj.sum(dim=1).to(torch.float).pow(-0.5)
    inv_sqrt_deg[inv_sqrt_deg == float('inf')] = 0
    return adj, inv_sqrt_deg
def gen_normalized_adjs(dataset):
    """Return three degree-normalized variants of the adjacency matrix.

    DAD = D^-1/2 A D^-1/2 (symmetric), DA = D^-1 A (row-normalized),
    AD = A D^-1 (column-normalized).
    """
    src, dst = dataset.graph['edge_index']
    num_nodes = dataset.graph['num_nodes']
    adj = SparseTensor(row=src, col=dst, sparse_sizes=(num_nodes, num_nodes))
    deg = adj.sum(dim=1).to(torch.float)
    d_isqrt = deg.pow(-0.5)
    d_isqrt[d_isqrt == float('inf')] = 0  # isolated nodes contribute zeros
    row_scale = d_isqrt.view(-1, 1)
    col_scale = d_isqrt.view(1, -1)
    DAD = row_scale * adj * col_scale
    DA = row_scale * row_scale * adj  # two D^-1/2 row factors == D^-1 A
    AD = adj * col_scale * col_scale  # two D^-1/2 col factors == A D^-1
    return DAD, DA, AD
def test(model, predictor, data, split_edge, evaluator, batch_size, device):
    """Evaluate link-prediction MRR on the eval_train/valid/test splits.

    The GNN forward pass is re-run full-batch on CPU through a numpy
    ``GCNInference`` copy of the trained weights; only the ``predictor``
    scoring runs on ``device``.
    """
    predictor.eval()

    print('Evaluating full-batch GNN on CPU...')

    # Copy each GCN layer's (weight, bias) into plain numpy arrays.
    weights = [(conv.weight.cpu().detach().numpy(),
                conv.bias.cpu().detach().numpy()) for conv in model.convs]
    model = GCNInference(weights)

    x = data.x.numpy()
    # Symmetrically normalized adjacency with self-loops, as scipy CSR.
    adj = SparseTensor(row=data.edge_index[0], col=data.edge_index[1])
    adj = adj.set_diag()
    deg = adj.sum(dim=1)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0  # guard isolated nodes
    adj = deg_inv_sqrt.view(-1, 1) * adj * deg_inv_sqrt.view(1, -1)
    adj = adj.to_scipy(layout='csr')

    # Node embeddings from the numpy forward pass, moved back to the device.
    h = torch.from_numpy(model(x, adj)).to(device)

    def test_split(split):
        # MRR of each positive edge against its 1000 sampled negatives.
        source = split_edge[split]['source_node'].to(device)
        target = split_edge[split]['target_node'].to(device)
        target_neg = split_edge[split]['target_node_neg'].to(device)

        pos_preds = []
        for perm in DataLoader(range(source.size(0)), batch_size):
            src, dst = source[perm], target[perm]
            pos_preds += [predictor(h[src], h[dst]).squeeze().cpu()]
        pos_pred = torch.cat(pos_preds, dim=0)

        neg_preds = []
        # Repeat each source 1000x so it pairs with all its negative targets.
        source = source.view(-1, 1).repeat(1, 1000).view(-1)
        target_neg = target_neg.view(-1)
        for perm in DataLoader(range(source.size(0)), batch_size):
            src, dst_neg = source[perm], target_neg[perm]
            neg_preds += [predictor(h[src], h[dst_neg]).squeeze().cpu()]
        neg_pred = torch.cat(neg_preds, dim=0).view(-1, 1000)

        return evaluator.eval({
            'y_pred_pos': pos_pred,
            'y_pred_neg': neg_pred,
        })['mrr_list'].mean().item()

    train_mrr = test_split('eval_train')
    valid_mrr = test_split('valid')
    test_mrr = test_split('test')

    return train_mrr, valid_mrr, test_mrr
def __call__(self, data):
    """Precompute propagated feature matrices (SIGN-style).

    Stores ``adj_t^i @ x`` into ``data[f'x{i}']`` for i = 1..num_hops,
    where ``adj_t`` is the symmetrically normalized transposed adjacency.
    """
    assert data.edge_index is not None
    row, col = data.edge_index
    n = data.num_nodes
    # Transposed adjacency: swap row/col when building the SparseTensor.
    adj_t = SparseTensor(row=col, col=row, sparse_sizes=(n, n))
    norm = adj_t.sum(dim=1).to(torch.float).pow(-0.5)
    norm[norm == float('inf')] = 0  # isolated nodes
    adj_t = norm.view(-1, 1) * adj_t * norm.view(1, -1)

    assert data.x is not None
    feat = data.x
    for hop in range(1, self.num_hops + 1):
        feat = adj_t @ feat
        data[f'x{hop}'] = feat
    return data
def main():
    """SGC pre-processing for OGBN-papers100M.

    Loads the dataset, optionally swaps raw features for precomputed node
    embeddings, randomly drops edges, symmetrically normalizes the
    adjacency, propagates features ``num_propagations`` times, and saves
    the propagated embeddings plus split info to ``args.output_path``.
    """
    parser = argparse.ArgumentParser(description='OGBN-papers100M (MLP)')
    parser.add_argument('--data_root_dir', type=str, default='../../dataset')
    parser.add_argument('--num_propagations', type=int, default=3)
    parser.add_argument('--dropedge_rate', type=float, default=0.4)
    parser.add_argument('--node_emb_path', type=str, default=None)
    parser.add_argument('--output_path', type=str, required=True)
    args = parser.parse_args()

    # SGC pre-processing ######################################################
    dataset = PygNodePropPredDataset(name='ogbn-papers100M',
                                     root=args.data_root_dir)
    split_idx = dataset.get_idx_split()
    data = dataset[0]

    x = None
    if args.node_emb_path:
        # Use externally computed node embeddings instead of raw features.
        x = np.load(args.node_emb_path)
    else:
        x = data.x.numpy()
    N = data.num_nodes

    print('Making the graph undirected.')
    ### Randomly drop some edges to save computation
    data.edge_index, _ = dropout_adj(data.edge_index, p=args.dropedge_rate,
                                     num_nodes=data.num_nodes)
    data.edge_index = to_undirected(data.edge_index, data.num_nodes)
    print(data)
    row, col = data.edge_index

    print('Computing adj...')
    # Symmetrically normalized adjacency with self-loops: D^-1/2 (A+I) D^-1/2.
    adj = SparseTensor(row=row, col=col, sparse_sizes=(N, N))
    adj = adj.set_diag()
    deg = adj.sum(dim=1).to(torch.float)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0  # guard isolated nodes
    adj = deg_inv_sqrt.view(-1, 1) * adj * deg_inv_sqrt.view(1, -1)
    adj = adj.to_scipy(layout='csr')

    train_idx, valid_idx, test_idx = split_idx['train'], split_idx[
        'valid'], split_idx['test']
    all_idx = torch.cat([train_idx, valid_idx, test_idx])
    # Positions of each split within the concatenated ``all_idx`` ordering.
    mapped_train_idx = torch.arange(len(train_idx))
    mapped_valid_idx = torch.arange(len(train_idx),
                                    len(train_idx) + len(valid_idx))
    mapped_test_idx = torch.arange(
        len(train_idx) + len(valid_idx),
        len(train_idx) + len(valid_idx) + len(test_idx))

    sgc_dict = {}
    sgc_dict['label'] = data.y.data[all_idx].to(torch.long)
    sgc_dict['split_idx'] = {
        'train': mapped_train_idx,
        'valid': mapped_valid_idx,
        'test': mapped_test_idx
    }

    print('Start SGC processing')
    # SGC: repeated feature smoothing x <- A_hat @ x, no nonlinearities.
    for _ in tqdm(range(args.num_propagations)):
        x = adj @ x
    sgc_dict['sgc_embedding'] = torch.from_numpy(x[all_idx]).to(torch.float)
    torch.save(sgc_dict, args.output_path)
def main():
    """Train and evaluate a full-batch GNN link predictor on OGBL-Citation."""
    parser = argparse.ArgumentParser(description='OGBL-Citation (GNN)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--use_sage', action='store_true')
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0)
    parser.add_argument('--batch_size', type=int, default=64 * 1024)
    parser.add_argument('--lr', type=float, default=0.0005)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--eval_steps', type=int, default=1)
    parser.add_argument('--runs', type=int, default=10)
    args = parser.parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name='ogbl-citation')
    split_edge = dataset.get_edge_split()
    data = dataset[0]

    # We randomly pick some training samples that we want to evaluate on:
    torch.manual_seed(12345)  # fixed seed -> same eval_train subset each run
    idx = torch.randperm(split_edge['train']['source_node'].numel())[:86596]
    split_edge['eval_train'] = {
        'source_node': split_edge['train']['source_node'][idx],
        'target_node': split_edge['train']['target_node'][idx],
        'target_node_neg': split_edge['valid']['target_node_neg'],
    }

    x = data.x.to(device)
    edge_index = data.edge_index.to(device)
    edge_index = to_undirected(edge_index, data.num_nodes)
    adj = SparseTensor(row=edge_index[0], col=edge_index[1])

    if args.use_sage:
        model = SAGE(x.size(-1), args.hidden_channels, args.hidden_channels,
                     args.num_layers, args.dropout).to(device)
    else:
        model = GCN(x.size(-1), args.hidden_channels, args.hidden_channels,
                    args.num_layers, args.dropout).to(device)

        # Pre-compute GCN normalization.
        adj = adj.set_value(None)  # drop edge values; treat edges as binary
        adj = adj.set_diag()
        deg = adj.sum(dim=1)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0  # isolated nodes
        adj = deg_inv_sqrt.view(-1, 1) * adj * deg_inv_sqrt.view(1, -1)

    predictor = LinkPredictor(args.hidden_channels, args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)

    evaluator = Evaluator(name='ogbl-citation')
    logger = Logger(args.runs, args)

    for run in range(args.runs):
        model.reset_parameters()
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(
            list(model.parameters()) + list(predictor.parameters()),
            lr=args.lr)

        for epoch in range(1, 1 + args.epochs):
            loss = train(model, predictor, x, adj, split_edge, optimizer,
                         args.batch_size)
            print(f'Run: {run + 1:02d}, Epoch: {epoch:02d}, Loss: {loss:.4f}')

            if epoch % args.eval_steps == 0:
                result = test(model, predictor, x, adj, split_edge, evaluator,
                              args.batch_size)
                logger.add_result(run, result)

                if epoch % args.log_steps == 0:
                    train_mrr, valid_mrr, test_mrr = result
                    print(f'Run: {run + 1:02d}, '
                          f'Epoch: {epoch:02d}, '
                          f'Loss: {loss:.4f}, '
                          f'Train: {train_mrr:.4f}, '
                          f'Valid: {valid_mrr:.4f}, '
                          f'Test: {test_mrr:.4f}')

        logger.print_statistics(run)
    logger.print_statistics()
def main():
    """Train and evaluate a full-batch GNN link predictor on OGBL-DDI.

    The dataset has no node features, so a learnable embedding table
    serves as the GNN input and is re-initialized every run.
    """
    parser = argparse.ArgumentParser(description='OGBL-DDI (Full-Batch)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--use_sage', action='store_true')
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--batch_size', type=int, default=64 * 1024)
    parser.add_argument('--lr', type=float, default=0.005)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--eval_steps', type=int, default=5)
    parser.add_argument('--runs', type=int, default=10)
    args = parser.parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name='ogbl-ddi')
    split_edge = dataset.get_edge_split()
    data = dataset[0]

    # We randomly pick some training samples that we want to evaluate on:
    torch.manual_seed(12345)  # fixed seed -> reproducible eval_train subset
    idx = torch.randperm(split_edge['train']['edge'].size(0))
    idx = idx[:split_edge['valid']['edge'].size(0)]
    split_edge['eval_train'] = {'edge': split_edge['train']['edge'][idx]}

    edge_index = data.edge_index
    adj = SparseTensor(row=edge_index[0], col=edge_index[1]).to(device)

    if args.use_sage:
        model = SAGE(args.hidden_channels, args.hidden_channels,
                     args.hidden_channels, args.num_layers,
                     args.dropout).to(device)
    else:
        model = GCN(args.hidden_channels, args.hidden_channels,
                    args.hidden_channels, args.num_layers,
                    args.dropout).to(device)

        # Pre-compute GCN normalization.
        adj = adj.set_diag()
        deg = adj.sum(dim=1)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0  # isolated nodes
        adj = deg_inv_sqrt.view(-1, 1) * adj * deg_inv_sqrt.view(1, -1)

    # Learnable input features; re-initialized at the start of every run.
    emb = torch.nn.Embedding(data.num_nodes, args.hidden_channels).to(device)
    predictor = LinkPredictor(args.hidden_channels, args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)

    evaluator = Evaluator(name='ogbl-ddi')
    loggers = {
        'Hits@10': Logger(args.runs, args),
        'Hits@20': Logger(args.runs, args),
        'Hits@30': Logger(args.runs, args),
    }

    for run in range(args.runs):
        torch.nn.init.xavier_uniform_(emb.weight)
        model.reset_parameters()
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(
            list(model.parameters()) + list(emb.parameters()) +
            list(predictor.parameters()), lr=args.lr)

        for epoch in range(1, 1 + args.epochs):
            loss = train(model, predictor, emb.weight, adj, data.edge_index,
                         split_edge, optimizer, args.batch_size)

            if epoch % args.eval_steps == 0:
                results = test(model, predictor, emb.weight, adj, split_edge,
                               evaluator, args.batch_size)
                for key, result in results.items():
                    loggers[key].add_result(run, result)

                if epoch % args.log_steps == 0:
                    for key, result in results.items():
                        train_hits, valid_hits, test_hits = result
                        print(key)
                        print(f'Run: {run + 1:02d}, '
                              f'Epoch: {epoch:02d}, '
                              f'Loss: {loss:.4f}, '
                              f'Train: {100 * train_hits:.2f}%, '
                              f'Valid: {100 * valid_hits:.2f}%, '
                              f'Test: {100 * test_hits:.2f}%')
                    print('---')

        for key in loggers.keys():
            print(key)
            loggers[key].print_statistics(run)

    for key in loggers.keys():
        print(key)
        loggers[key].print_statistics()
def osglm(A: SparseTensor, lc: Optional[int] = None, vtx_color: Optional[VertexColor] = None):
    r"""
    The oversampled bipartite graph approximation method proposed in [1]_

    Parameters
    ----------
    A: SparseTensor
        The adjacency matrix of the graph.
    lc: int
        The ordinal of color marking the boundary such that all nodes with a
        smaller color ordinal are grouped into the low-pass channel while
        those with a larger color ordinal are in the high-pass channel.
    vtx_color: iter
        The graph coloring result

    Returns
    -------
    bptG: list of lil_matrix
        The oversampled graph(with additional nodes)
    beta : np.ndarray
    append_nodes: np.ndarray
        The indices of those appended nodes
    vtx_color: np.ndarray
        The node colors

    References
    ----------
    .. [1] Akie Sakiyama, et al, "Oversampled Graph Laplacian Matrix for
           Graph Filter Banks", IEEE trans on SP, 2016.
    """
    if vtx_color is None:
        from thgsp.alg import dsatur
        vtx_color = dsatur(A)
    vtx_color = np.asarray(vtx_color)
    n_color = max(vtx_color) + 1
    if lc is None:
        lc = n_color // 2  # default boundary: split the colors in half
    assert 1 <= lc < n_color

    A = A.to_scipy(layout="csr").tolil()
    # the foundation bipartite graph Gb
    Gb = lil_matrix(A.shape, dtype=A.dtype)
    N = A.shape[-1]
    bt = np.in1d(vtx_color, range(lc))  # True -> low-pass (L) channel
    idx_s1 = np.nonzero(bt)[0]  # L
    idx_s2 = np.nonzero(~bt)[0]  # H
    mask = bipartite_mask(bt)  # the desired edges
    # Move the cross-channel edges into Gb; A keeps only the rest.
    Gb[mask] = A[mask]
    A[mask] = 0

    eye_mask = eye(N, N, dtype=bool)
    A[eye_mask] = 1
    # add vertical edges
    degree = A.sum(0).getA1()  # 2D np.matrix -> 1D np.array
    append_nodes = (degree != 0).nonzero()[0]

    Nos = len(append_nodes) + N  # oversampled size
    bptG = [lil_matrix((Nos, Nos), dtype=A.dtype)]  # the expanded graph
    # Wire the appended copies to the original nodes (both directions),
    # and embed the bipartite core in the top-left block.
    bptG[0][:N, N:] = A[:, append_nodes]
    bptG[0][:N, :N] = Gb
    bptG[0][N:, :N] = A[append_nodes, :]

    beta = np.zeros((Nos, 1), dtype=bool)
    beta[idx_s1, 0] = 1
    # appended nodes corresponding to idx_s2 are assigned to
    # the L channel of oversampled graph with idx_s1
    _, node_ordinal_append, _ = np.intersect1d(append_nodes, idx_s2,
                                               return_indices=True)
    beta[N + node_ordinal_append, 0] = 1
    return bptG, beta, append_nodes, vtx_color
# NOTE(review): this chunk starts mid-script — ``parser`` is created earlier,
# outside the visible region, and the script continues past this point.
parser.add_argument('--eval_steps', type=int, default=5)
parser.add_argument('--runs', type=int, default=10)
args = parser.parse_args()
print(args)

device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
device = torch.device(device)

dataset = PygNodePropPredDataset(name='ogbn-proteins')
split_idx = dataset.get_idx_split()
data = dataset[0]

edge_index = data.edge_index.to(device)
adj = SparseTensor(row=edge_index[0], col=edge_index[1]).set_diag()
# Keep the un-normalized (A + I) — presumably reused later; the remainder
# of the script is outside this view.
adj_0 = adj
deg = adj.sum(dim=1).to(torch.float)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0  # guard isolated nodes
# Symmetric normalization: D^-1/2 (A+I) D^-1/2.
adj = deg_inv_sqrt.view(-1, 1) * adj * deg_inv_sqrt.view(1, -1)

class l_GCN(torch.nn.Module):
    # Minimal single-layer GCN module: one weight matrix plus bias.
    # (Further methods, e.g. forward/reset, lie outside this view.)
    def __init__(self, in_channels, out_channels):
        super(l_GCN, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = Parameter(torch.Tensor(in_channels, out_channels))
        self.bias = Parameter(torch.Tensor(out_channels))
def main():
    """Train and evaluate a full-batch GNN link predictor on OGBL-PPA."""
    parser = argparse.ArgumentParser(description='OGBL-PPA (Full-Batch)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--use_node_embedding', action='store_true')
    parser.add_argument('--use_sage', action='store_true')
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--batch_size', type=int, default=64 * 1024)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--eval_steps', type=int, default=1)
    parser.add_argument('--runs', type=int, default=10)
    args = parser.parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name='ogbl-ppa')
    data = dataset[0]
    splitted_edge = dataset.get_edge_split()

    if args.use_node_embedding:
        # Concatenate precomputed embeddings onto the raw features.
        # NOTE(review): assumes 'embedding.pt' exists in the working dir.
        x = data.x.to(torch.float)
        x = torch.cat([x, torch.load('embedding.pt')], dim=-1)
        x = x.to(device)
    else:
        x = data.x.to(torch.float).to(device)

    edge_index = data.edge_index.to(device)
    adj = SparseTensor(row=edge_index[0], col=edge_index[1])

    if args.use_sage:
        model = SAGE(x.size(-1), args.hidden_channels, args.hidden_channels,
                     args.num_layers, args.dropout).to(device)
    else:
        model = GCN(x.size(-1), args.hidden_channels, args.hidden_channels,
                    args.num_layers, args.dropout).to(device)

        # Pre-compute GCN normalization.
        adj = adj.set_diag()
        deg = adj.sum(dim=1).to(torch.float)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0  # isolated nodes
        adj = deg_inv_sqrt.view(-1, 1) * adj * deg_inv_sqrt.view(1, -1)

    predictor = LinkPredictor(args.hidden_channels, args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)

    evaluator = Evaluator(name='ogbl-ppa')
    loggers = {
        'Hits@10': Logger(args.runs, args),
        'Hits@50': Logger(args.runs, args),
        'Hits@100': Logger(args.runs, args),
    }

    for run in range(args.runs):
        model.reset_parameters()
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(
            list(model.parameters()) + list(predictor.parameters()),
            lr=args.lr)

        for epoch in range(1, 1 + args.epochs):
            loss = train(model, predictor, x, adj, splitted_edge, optimizer,
                         args.batch_size)

            if epoch % args.eval_steps == 0:
                results = test(model, predictor, x, adj, splitted_edge,
                               evaluator, args.batch_size)
                for key, result in results.items():
                    loggers[key].add_result(run, result)

                if epoch % args.log_steps == 0:
                    for key, result in results.items():
                        train_hits, valid_hits, test_hits = result
                        print(key)
                        print(f'Run: {run + 1:02d}, '
                              f'Epoch: {epoch:02d}, '
                              f'Loss: {loss:.4f}, '
                              f'Train: {100 * train_hits:.2f}%, '
                              f'Valid: {100 * valid_hits:.2f}%, '
                              f'Test: {100 * test_hits:.2f}%')

        for key in loggers.keys():
            print(key)
            loggers[key].print_statistics(run)

    for key in loggers.keys():
        print(key)
        loggers[key].print_statistics()
def main():
    """Train and evaluate a full-batch GNN node classifier on OGBN-Arxiv."""
    parser = argparse.ArgumentParser(description='OGBN-Arxiv (Full-Batch)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--use_sage', action='store_true')
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--runs', type=int, default=5)
    args = parser.parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygNodePropPredDataset(name='ogbn-arxiv')
    split_idx = dataset.get_idx_split()
    data = dataset[0]

    x = data.x.to(device)
    y_true = data.y.to(device)
    train_idx = split_idx['train'].to(device)

    edge_index = data.edge_index.to(device)
    edge_index = to_undirected(edge_index, data.num_nodes)
    adj = SparseTensor(row=edge_index[0], col=edge_index[1])

    if args.use_sage:
        model = SAGE(data.x.size(-1), args.hidden_channels,
                     dataset.num_classes, args.num_layers,
                     args.dropout).to(device)
    else:
        model = GCN(data.x.size(-1), args.hidden_channels,
                    dataset.num_classes, args.num_layers,
                    args.dropout).to(device)

        # Pre-compute GCN normalization.
        adj = adj.set_diag()
        deg = adj.sum(dim=1).to(torch.float)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0  # isolated nodes
        adj = deg_inv_sqrt.view(-1, 1) * adj * deg_inv_sqrt.view(1, -1)

    evaluator = Evaluator(name='ogbn-arxiv')
    logger = Logger(args.runs, args)

    for run in range(args.runs):
        model.reset_parameters()
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        for epoch in range(1, 1 + args.epochs):
            loss = train(model, x, adj, y_true, train_idx, optimizer)
            # Evaluated every epoch (no eval_steps flag in this script).
            result = test(model, x, adj, y_true, split_idx, evaluator)
            logger.add_result(run, result)

            if epoch % args.log_steps == 0:
                train_acc, valid_acc, test_acc = result
                print(f'Run: {run + 1:02d}, '
                      f'Epoch: {epoch:02d}, '
                      f'Loss: {loss:.4f}, '
                      f'Train: {100 * train_acc:.2f}%, '
                      f'Valid: {100 * valid_acc:.2f}% '
                      f'Test: {100 * test_acc:.2f}%')
        logger.print_statistics(run)
    logger.print_statistics()