def prepare_gnn_training(self):
    if verbose:
        print("\n\n==>> Clustering the graph and preparing dataloader....")

    self.data = Data(x=self.x_data.float(),
                     edge_index=self.edge_index_data.long(),
                     edge_attr=self.edge_type_data, y=self.y_data)
    new_num_nodes, _ = self.data.x.shape

    self.data.train_mask = torch.FloatTensor(self.split_masks['train_mask'])
    self.data.val_mask = torch.FloatTensor(self.split_masks['val_mask'])
    self.data.representation_mask = torch.FloatTensor(self.split_masks['repr_mask'])
    self.data.node2id = torch.tensor(list(self.node2id.values()))
    # self.data.node_type = self.node_type

    if not self.config['full_graph']:
        if self.config['cluster']:
            cluster_data = ClusterData(self.data,
                                       num_parts=self.config['clusters'],
                                       recursive=False)
            self.loader = ClusterLoader(cluster_data,
                                        batch_size=self.config['batch_size'],
                                        shuffle=self.config['shuffle'],
                                        num_workers=0)
        elif self.config['saint'] == 'random_walk':
            self.loader = GraphSAINTRandomWalkSampler(self.data, batch_size=6000,
                                                      walk_length=2, num_steps=5,
                                                      sample_coverage=100,
                                                      num_workers=0)
        elif self.config['saint'] == 'node':
            self.loader = GraphSAINTNodeSampler(self.data, batch_size=6000,
                                                num_steps=5, sample_coverage=100,
                                                num_workers=0)
        elif self.config['saint'] == 'edge':
            self.loader = GraphSAINTEdgeSampler(self.data, batch_size=6000,
                                                num_steps=5, sample_coverage=100,
                                                num_workers=0)
    else:
        self.loader = None

    return self.loader, self.vocab_size, self.data
def __init__(self, graph, model, args, criterion=None):
    self.args = args
    self.graph = graph
    self.model = model.to(self.args.device)

    # Build the data loader on CPU.
    self.loader = GraphSAINTRandomWalkSampler(
        graph,
        batch_size=self.args.GraphSAINT['batch_size'],
        walk_length=self.args.GraphSAINT['walk_length'],
        num_steps=self.args.GraphSAINT['num_steps'])
    print('Data loader created for learner...')

    # Set the loss function. BCEWithLogitsLoss handles single or multiple
    # binary prediction tasks.
    if not criterion:
        criterion = nn.BCEWithLogitsLoss()
    self.criterion = criterion

    # Set the optimizer.
    self.optim = torch.optim.Adam(self.model.parameters(), lr=self.args.lr,
                                  weight_decay=self.args.w_decay)

    self.train_loss = []
    self.val_loss = []
    self.train_complete = False
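# The learner above only builds the sampler; a minimal sketch of an epoch
# loop that consumes it. Hypothetical, not from the source: assumes the model
# takes (x, edge_index) and returns raw logits matching `batch.y` in shape,
# as BCEWithLogitsLoss expects, and that each sampled subgraph carries `y`.
def train_epoch(self):
    self.model.train()
    epoch_loss = 0.0
    for batch in self.loader:  # each batch is one sampled subgraph (a Data object)
        batch = batch.to(self.args.device)  # sampling runs on CPU; move to device
        self.optim.zero_grad()
        logits = self.model(batch.x, batch.edge_index)
        loss = self.criterion(logits, batch.y)
        loss.backward()
        self.optim.step()
        epoch_loss += loss.item()
    self.train_loss.append(epoch_loss / len(self.loader))
    return self.train_loss[-1]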
def build_sampler(args, data, save_dir):
    if args.sampler == 'rw-my':
        msg = 'Use GraphSaint randomwalk sampler(mysaint sampler)'
        loader = MySAINTSampler(data, batch_size=args.batch_size,
                                sample_type='random_walk', walk_length=2,
                                sample_coverage=1000, save_dir=save_dir)
    elif args.sampler == 'node-my':
        msg = 'Use random node sampler(mysaint sampler)'
        loader = MySAINTSampler(data, sample_type='node',
                                batch_size=args.batch_size * 3, walk_length=2,
                                sample_coverage=1000, save_dir=save_dir)
    elif args.sampler == 'rw':
        msg = 'Use GraphSaint randomwalk sampler'
        loader = GraphSAINTRandomWalkSampler(data, batch_size=args.batch_size,
                                             walk_length=2, num_steps=5,
                                             sample_coverage=1000,
                                             save_dir=save_dir)
    elif args.sampler == 'node':
        msg = 'Use GraphSaint node sampler'
        loader = GraphSAINTNodeSampler(data, batch_size=args.batch_size * 3,
                                       num_steps=5, sample_coverage=1000,
                                       num_workers=0, save_dir=save_dir)
    elif args.sampler == 'edge':
        msg = 'Use GraphSaint edge sampler'
        loader = GraphSAINTEdgeSampler(data, batch_size=args.batch_size,
                                       num_steps=5, sample_coverage=1000,
                                       save_dir=save_dir, num_workers=0)
    elif args.sampler == 'cluster':
        msg = 'Use cluster sampler'
        cluster_data = ClusterData(data, num_parts=args.num_parts,
                                   save_dir=save_dir)
        loader = ClusterLoader(cluster_data, batch_size=20, shuffle=True,
                               num_workers=0)
    else:
        raise KeyError('Sampler type error')
    return loader, msg
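# Hypothetical usage sketch for build_sampler (the cache directory and the
# `args`/`data`/`model` objects are assumptions): `args` must provide
# .sampler and .batch_size, plus .num_parts for the 'cluster' branch.
loader, msg = build_sampler(args, data, save_dir='./saint_cache')
print(msg)
for subgraph in loader:  # each iteration yields one sampled subgraph
    out = model(subgraph.x, subgraph.edge_index)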
def test_graph_saint():
    adj = torch.tensor([
        [1, 1, 1, 0, 1, 0],
        [1, 1, 0, 1, 0, 1],
        [1, 0, 1, 0, 1, 0],
        [0, 1, 0, 1, 0, 1],
        [1, 0, 1, 0, 1, 0],
        [0, 1, 0, 1, 0, 1],
    ])

    edge_index = adj.nonzero().t()
    x = torch.Tensor([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])
    data = Data(edge_index=edge_index, x=x, num_nodes=6)

    torch.manual_seed(12345)
    loader = GraphSAINTNodeSampler(data, batch_size=2, num_steps=4,
                                   sample_coverage=10, log=False)
    for sample in loader:
        assert len(sample) == 4
        assert sample.num_nodes <= 2
        assert sample.num_edges <= 3 * 2
        assert sample.node_norm.numel() == sample.num_nodes
        assert sample.edge_norm.numel() == sample.num_edges

    torch.manual_seed(12345)
    loader = GraphSAINTEdgeSampler(data, batch_size=2, num_steps=4,
                                   sample_coverage=10, log=False)
    for sample in loader:
        assert len(sample) == 4
        assert sample.num_nodes <= 4
        assert sample.num_edges <= 3 * 4
        assert sample.node_norm.numel() == sample.num_nodes
        assert sample.edge_norm.numel() == sample.num_edges

    torch.manual_seed(12345)
    loader = GraphSAINTRandomWalkSampler(data, batch_size=2, walk_length=1,
                                         num_steps=4, sample_coverage=10,
                                         log=False)
    for sample in loader:
        assert len(sample) == 4
        assert sample.num_nodes <= 4
        assert sample.num_edges <= 3 * 4
        assert sample.node_norm.numel() == sample.num_nodes
        assert sample.edge_norm.numel() == sample.num_edges
def sample_subgraph(self, whole_data):  # `self` added: the body reads instance attributes
    data = whole_data.data
    loader = GraphSAINTRandomWalkSampler(
        data,
        batch_size=self.sample_batch_size,
        walk_length=self.sample_walk_length,
        num_steps=self.subgraphs,
        save_dir=whole_data.processed_dir,
    )
    results = []
    for data in loader:
        results.append(data)
    return results
def build_sampler(args, data, save_dir):
    if args.sampler == 'rw-my':
        msg = 'Use GraphSaint randomwalk sampler(mysaint sampler)'
        loader = MySAINTSampler(data, batch_size=args.batch_size,
                                sample_type='random_walk', walk_length=2,
                                sample_coverage=1000, save_dir=save_dir)
    elif args.sampler == 'node-my':
        msg = 'Use random node sampler(mysaint sampler)'
        loader = MySAINTSampler(data, sample_type='node',
                                batch_size=args.batch_size * 3, walk_length=2,
                                sample_coverage=1000, save_dir=save_dir)
    elif args.sampler == 'rw':
        msg = 'Use GraphSaint randomwalk sampler'
        loader = GraphSAINTRandomWalkSampler(data, batch_size=args.batch_size,
                                             walk_length=2, num_steps=5,
                                             sample_coverage=1000,
                                             save_dir=save_dir)
    elif args.sampler == 'node':
        msg = 'Use GraphSaint node sampler'
        loader = GraphSAINTNodeSampler(data, batch_size=args.batch_size * 3,
                                       num_steps=5, sample_coverage=1000,
                                       num_workers=0, save_dir=save_dir)
    elif args.sampler == 'edge':
        msg = 'Use GraphSaint edge sampler'
        loader = GraphSAINTEdgeSampler(data, batch_size=args.batch_size,
                                       num_steps=5, sample_coverage=1000,
                                       save_dir=save_dir, num_workers=0)
    # elif args.sampler == 'cluster':
    #     logger.info('Use cluster sampler')
    #     cluster_data = ClusterData(data, num_parts=args.num_parts,
    #                                save_dir=dataset.processed_dir)
    #     raise NotImplementedError('Cluster loader not implement yet')
    else:
        raise KeyError('Sampler type error')
    return loader, msg
homo_data = Data(edge_index=edge_index, edge_attr=edge_type,
                 node_type=node_type, local_node_idx=local_node_idx,
                 num_nodes=node_type.size(0))

homo_data.y = node_type.new_full((node_type.size(0), 1), -1)
homo_data.y[local2global['paper']] = data.y_dict['paper']

homo_data.train_mask = torch.zeros((node_type.size(0)), dtype=torch.bool)
homo_data.train_mask[local2global['paper'][split_idx['train']['paper']]] = True

print(homo_data)

train_loader = GraphSAINTRandomWalkSampler(homo_data,
                                           batch_size=args.batch_size,
                                           walk_length=args.num_layers,
                                           num_steps=args.num_steps,
                                           sample_coverage=0,
                                           save_dir=dataset.processed_dir)

# Map information to its canonical type.
x_dict = {}
for key, x in data.x_dict.items():
    x_dict[key2int[key]] = x

num_nodes_dict = {}
for key, N in data.num_nodes_dict.items():
    num_nodes_dict[key2int[key]] = N


class RGCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels, num_node_types,
def main():
    parser = argparse.ArgumentParser(description='OGBN-Products (GraphSAINT)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--inductive', action='store_true')
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--batch_size', type=int, default=20000)
    parser.add_argument('--walk_length', type=int, default=3)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--num_steps', type=int, default=30)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--eval_steps', type=int, default=2)
    parser.add_argument('--runs', type=int, default=10)
    args = parser.parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygNodePropPredDataset(
        name='ogbn-products',
        root='/srv/scratch/ogb/datasets/nodeproppred')
    split_idx = dataset.get_idx_split()
    data = dataset[0]

    # Convert split indices to boolean masks and add them to `data`.
    for key, idx in split_idx.items():
        mask = torch.zeros(data.num_nodes, dtype=torch.bool)
        mask[idx] = True
        data[f'{key}_mask'] = mask

    # We omit normalization factors here since those are only defined for the
    # inductive learning setup.
    sampler_data = data
    if args.inductive:
        sampler_data = to_inductive(data)

    loader = GraphSAINTRandomWalkSampler(sampler_data,
                                         batch_size=args.batch_size,
                                         walk_length=args.walk_length,
                                         num_steps=args.num_steps,
                                         sample_coverage=0,
                                         save_dir=dataset.processed_dir)

    model = SAGE(data.x.size(-1), args.hidden_channels, dataset.num_classes,
                 args.num_layers, args.dropout).to(device)

    subgraph_loader = NeighborSampler(data.edge_index, sizes=[-1],
                                      batch_size=4096, shuffle=False,
                                      num_workers=12)

    evaluator = Evaluator(name='ogbn-products')
    logger = Logger(args.runs, args)

    for run in range(args.runs):
        model.reset_parameters()
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        for epoch in range(1, 1 + args.epochs):
            loss = train(model, loader, optimizer, device)
            if epoch % args.log_steps == 0:
                print(f'Run: {run + 1:02d}, '
                      f'Epoch: {epoch:02d}, '
                      f'Loss: {loss:.4f}')

            if epoch > 9 and epoch % args.eval_steps == 0:
                result = test(model, data, evaluator, subgraph_loader, device)
                logger.add_result(run, result)
                train_acc, valid_acc, test_acc = result
                print(f'Run: {run + 1:02d}, '
                      f'Epoch: {epoch:02d}, '
                      f'Train: {100 * train_acc:.2f}%, '
                      f'Valid: {100 * valid_acc:.2f}% '
                      f'Test: {100 * test_acc:.2f}%')

        logger.print_statistics(run)
    logger.print_statistics()
def main():
    parser = argparse.ArgumentParser(description='OGBL-Citation2 (GraphSAINT)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--batch_size', type=int, default=16 * 1024)
    parser.add_argument('--walk_length', type=int, default=3)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--num_steps', type=int, default=100)
    parser.add_argument('--eval_steps', type=int, default=10)
    parser.add_argument('--runs', type=int, default=10)
    args = parser.parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygLinkPropPredDataset(name='ogbl-citation2')
    split_edge = dataset.get_edge_split()
    data = dataset[0]
    data.edge_index = to_undirected(data.edge_index, data.num_nodes)

    loader = GraphSAINTRandomWalkSampler(data, batch_size=args.batch_size,
                                         walk_length=args.walk_length,
                                         num_steps=args.num_steps,
                                         sample_coverage=0,
                                         save_dir=dataset.processed_dir)

    # We randomly pick some training samples that we want to evaluate on:
    torch.manual_seed(12345)
    idx = torch.randperm(split_edge['train']['source_node'].numel())[:86596]
    split_edge['eval_train'] = {
        'source_node': split_edge['train']['source_node'][idx],
        'target_node': split_edge['train']['target_node'][idx],
        'target_node_neg': split_edge['valid']['target_node_neg'],
    }

    model = GCN(data.x.size(-1), args.hidden_channels, args.hidden_channels,
                args.num_layers, args.dropout).to(device)
    predictor = LinkPredictor(args.hidden_channels, args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)

    evaluator = Evaluator(name='ogbl-citation2')
    logger = Logger(args.runs, args)

    run_idx = 0
    while run_idx < args.runs:
        model.reset_parameters()
        predictor.reset_parameters()
        optimizer = torch.optim.Adam(
            list(model.parameters()) + list(predictor.parameters()),
            lr=args.lr)

        run_success = True
        for epoch in range(1, 1 + args.epochs):
            loss = train(model, predictor, loader, optimizer, device)
            print(f'Run: {run_idx + 1:02d}, Epoch: {epoch:02d}, '
                  f'Loss: {loss:.4f}')
            if loss > 2.:
                run_success = False
                logger.reset(run_idx)
                print('Learning failed. Rerun...')
                break

            if epoch > 49 and epoch % args.eval_steps == 0:
                result = test(model, predictor, data, split_edge, evaluator,
                              batch_size=64 * 1024, device=device)
                logger.add_result(run_idx, result)
                train_mrr, valid_mrr, test_mrr = result
                print(f'Run: {run_idx + 1:02d}, '
                      f'Epoch: {epoch:02d}, '
                      f'Loss: {loss:.4f}, '
                      f'Train: {train_mrr:.4f}, '
                      f'Valid: {valid_mrr:.4f}, '
                      f'Test: {test_mrr:.4f}')

        print('GraphSAINT')
        if run_success:
            logger.print_statistics(run_idx)
            run_idx += 1

    print('GraphSAINT')
    logger.print_statistics()
def train(epoch, model, optimizer):
    global all_data, best_val_acc, best_embeddings, best_model, \
        curr_hyperparameters, best_hyperparameters

    # Save predictions
    total_loss = 0
    roc_val = []
    ap_val = []
    f1_val = []
    acc_val = []

    # Minibatches
    if config.MINIBATCH == "NeighborSampler":
        loader = NeighborSampler(all_data.edge_index,
                                 sizes=[curr_hyperparameters['nb_size']],
                                 batch_size=curr_hyperparameters['batch_size'],
                                 shuffle=True)
    elif config.MINIBATCH == "GraphSaint":
        all_data.num_classes = torch.tensor([2])
        loader = GraphSAINTRandomWalkSampler(
            all_data,
            batch_size=curr_hyperparameters['batch_size'],
            walk_length=curr_hyperparameters['walk_length'],
            num_steps=curr_hyperparameters['num_steps'])

    # Iterate through minibatches
    for data in loader:
        if config.MINIBATCH == "NeighborSampler":
            data = preprocess.set_data(data, all_data, config.MINIBATCH)

        curr_train_pos = data.edge_index[:, data.train_mask]
        curr_train_neg = negative_sampling(
            curr_train_pos, num_neg_samples=curr_train_pos.size(1) // 4)
        curr_train_total = torch.cat([curr_train_pos, curr_train_neg], dim=-1)
        data.y = torch.zeros(curr_train_total.size(1)).float()
        data.y[:curr_train_pos.size(1)] = 1.

        # Perform training
        data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        curr_dot_embed = utils.el_dot(out, curr_train_total)
        loss = utils.calc_loss_both(data, curr_dot_embed)
        if not torch.isnan(loss):
            total_loss += loss
            loss.backward()
            optimizer.step()

        curr_train_pos_mask = torch.zeros(curr_train_total.size(1)).bool()
        curr_train_pos_mask[:curr_train_pos.size(1)] = 1
        curr_train_neg_mask = (curr_train_pos_mask == 0)
        roc_score, ap_score, train_acc, train_f1 = utils.calc_roc_score(
            pred_all=curr_dot_embed.T[1], pos_edges=curr_train_pos_mask,
            neg_edges=curr_train_neg_mask)
        print(">>>>>>Train: (ROC) ", roc_score, " (AP) ", ap_score,
              " (ACC) ", train_acc, " (F1) ", train_f1)

        curr_val_pos = data.edge_index[:, data.val_mask]
        curr_val_neg = negative_sampling(
            curr_val_pos, num_neg_samples=curr_val_pos.size(1) // 4)
        curr_val_total = torch.cat([curr_val_pos, curr_val_neg], dim=-1)
        curr_val_pos_mask = torch.zeros(curr_val_total.size(1)).bool()
        curr_val_pos_mask[:curr_val_pos.size(1)] = 1
        curr_val_neg_mask = (curr_val_pos_mask == 0)
        val_dot_embed = utils.el_dot(out, curr_val_total)
        data.y = torch.zeros(curr_val_total.size(1)).float()
        data.y[:curr_val_pos.size(1)] = 1.
        roc_score, ap_score, val_acc, val_f1 = utils.calc_roc_score(
            pred_all=val_dot_embed.T[1], pos_edges=curr_val_pos_mask,
            neg_edges=curr_val_neg_mask)
        roc_val.append(roc_score)
        ap_val.append(ap_score)
        acc_val.append(val_acc)
        f1_val.append(val_f1)

    res = "\t".join(["Epoch: %04d" % (epoch + 1),
                     "train_loss = {:.5f}".format(total_loss),
                     "val_roc = {:.5f}".format(np.mean(roc_val)),
                     "val_ap = {:.5f}".format(np.mean(ap_val)),
                     "val_f1 = {:.5f}".format(np.mean(f1_val)),
                     "val_acc = {:.5f}".format(np.mean(acc_val))])
    print(res)
    log_f.write(res + "\n")

    # Save best model and parameters
    if best_val_acc <= np.mean(acc_val) + eps:
        best_val_acc = np.mean(acc_val)
        with open(str(config.DATASET_DIR / "best_model.pth"), 'wb') as f:
            torch.save(model.state_dict(), f)
        best_hyperparameters = curr_hyperparameters
        best_model = model

    return total_loss
def main():
    parser = argparse.ArgumentParser(description='OGBN-Products (GraphSAINT)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=5)
    parser.add_argument('--use_sage', action='store_true')
    parser.add_argument('--num_workers', type=int, default=0)
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--batch_size', type=int, default=8000)
    parser.add_argument('--walk_length', type=int, default=3)
    parser.add_argument('--sample_coverage', type=int, default=1000)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--num_steps', type=int, default=20)
    parser.add_argument('--epochs', type=int, default=150)
    parser.add_argument('--eval_steps', type=int, default=25)
    parser.add_argument('--runs', type=int, default=10)
    args = parser.parse_args()
    print(args)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygNodePropPredDataset(name='ogbn-products')
    split_idx = dataset.get_idx_split()
    data = dataset[0]

    # Convert split indices to boolean masks and add them to `data`.
    for key, idx in split_idx.items():
        mask = torch.zeros(data.num_nodes, dtype=torch.bool)
        mask[idx] = True
        data[f'{key}_mask'] = mask

    # Create "inductive" subgraph containing only train and validation nodes.
    ind_data = to_inductive(copy.copy(data))

    row, col = ind_data.edge_index
    ind_data.edge_attr = 1. / degree(col, ind_data.num_nodes)[col]

    loader = GraphSAINTRandomWalkSampler(ind_data, batch_size=args.batch_size,
                                         walk_length=args.walk_length,
                                         num_steps=args.num_steps,
                                         sample_coverage=args.sample_coverage,
                                         save_dir=dataset.processed_dir,
                                         num_workers=args.num_workers)

    model = SAGE(ind_data.x.size(-1), args.hidden_channels,
                 dataset.num_classes, args.num_layers,
                 args.dropout).to(device)

    evaluator = Evaluator(name='ogbn-products')
    logger = Logger(args.runs, args)

    for run in range(args.runs):
        model.reset_parameters()
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        for epoch in range(1, 1 + args.epochs):
            loss = train(model, loader, optimizer, device)
            if epoch % args.log_steps == 0:
                print(f'Run: {run + 1:02d}, '
                      f'Epoch: {epoch:02d}, '
                      f'Loss: {loss:.4f}')

            if epoch % args.eval_steps == 0:
                result = test(model, data, evaluator)
                logger.add_result(run, result)
                train_acc, valid_acc, test_acc = result
                print(f'Run: {run + 1:02d}, '
                      f'Epoch: {epoch:02d}, '
                      f'Train: {100 * train_acc:.2f}%, '
                      f'Valid: {100 * valid_acc:.2f}% '
                      f'Test: {100 * test_acc:.2f}%')

        logger.print_statistics(run)
    logger.print_statistics()
import os.path as osp  # needed for the dataset path below

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Flickr
from torch_geometric.data import GraphSAINTRandomWalkSampler
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import degree

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Flickr')
dataset = Flickr(path)
data = dataset[0]
row, col = data.edge_index
data.edge_attr = 1. / degree(col, data.num_nodes)[col]  # Norm by in-degree.

loader = GraphSAINTRandomWalkSampler(data, batch_size=6000, walk_length=2,
                                     num_steps=5, sample_coverage=1000,
                                     save_dir=dataset.processed_dir,
                                     num_workers=4)


class Net(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(Net, self).__init__()
        in_channels = dataset.num_node_features
        out_channels = dataset.num_classes
        self.conv1 = SAGEConv(in_channels, hidden_channels, concat=True)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels, concat=True)
        self.conv3 = SAGEConv(hidden_channels, hidden_channels, concat=True)
        self.lin = torch.nn.Linear(3 * hidden_channels, out_channels)

    def set_aggr(self, aggr):
        self.conv1.aggr = aggr
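# The excerpt stops at the model definition. In the full example the loss is
# de-biased with the sampler's importance estimates; below is a sketch of such
# a training step, assuming Net.forward accepts an optional edge_weight and
# returns log-probabilities, and that set_aggr covers conv2/conv3 as well.
# The hidden width and learning rate here are illustrative choices.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(hidden_channels=256).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


def train(use_normalization=True):
    model.train()
    # Sum aggregation plus edge weights yields the unbiased GraphSAINT
    # estimator; without normalization, fall back to plain mean aggregation.
    model.set_aggr('add' if use_normalization else 'mean')
    total_loss = total_examples = 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        if use_normalization:
            edge_weight = batch.edge_norm * batch.edge_attr
            out = model(batch.x, batch.edge_index, edge_weight)
            loss = F.nll_loss(out, batch.y, reduction='none')
            loss = (loss * batch.node_norm)[batch.train_mask].sum()
        else:
            out = model(batch.x, batch.edge_index)
            loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.num_nodes
        total_examples += batch.num_nodes
    return total_loss / total_examples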
def main():
    parser = argparse.ArgumentParser(description='OGBN-Products (GraphSAINT)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--inductive', action='store_true')
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--batch_size', type=int, default=20000)
    parser.add_argument('--walk_length', type=int, default=3)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--num_steps', type=int, default=30)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--runs', type=int, default=10)
    parser.add_argument('--step-size', type=float, default=8e-3)
    parser.add_argument('-m', type=int, default=3)
    parser.add_argument('--test-freq', type=int, default=2)
    parser.add_argument('--attack', type=str, default='flag')
    parser.add_argument('--amp', type=float, default=2)
    args = parser.parse_args()

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataset = PygNodePropPredDataset(name='ogbn-products')
    split_idx = dataset.get_idx_split()
    data = dataset[0]

    # Convert split indices to boolean masks and add them to `data`.
    for key, idx in split_idx.items():
        mask = torch.zeros(data.num_nodes, dtype=torch.bool)
        mask[idx] = True
        data[f'{key}_mask'] = mask

    # We omit normalization factors here since those are only defined for the
    # inductive learning setup.
    sampler_data = data
    if args.inductive:
        sampler_data = to_inductive(data)

    loader = GraphSAINTRandomWalkSampler(sampler_data,
                                         batch_size=args.batch_size,
                                         walk_length=args.walk_length,
                                         num_steps=args.num_steps,
                                         sample_coverage=0,
                                         save_dir=dataset.processed_dir)

    model = SAGE(data.x.size(-1), args.hidden_channels, dataset.num_classes,
                 args.num_layers, args.dropout).to(device)

    subgraph_loader = NeighborSampler(data.edge_index, sizes=[-1],
                                      batch_size=4096, shuffle=False,
                                      num_workers=12)

    evaluator = Evaluator(name='ogbn-products')

    vals, tests = [], []
    for run in range(args.runs):
        best_val, final_test = 0, 0
        model.reset_parameters()
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

        for epoch in range(1, args.epochs + 1):
            loss = train_flag(model, loader, optimizer, device, args)
            if epoch > args.epochs / 2 and epoch % args.test_freq == 0 \
                    or epoch == args.epochs:
                result = test(model, data, evaluator, subgraph_loader, device)
                train, val, tst = result
                if val > best_val:
                    best_val = val
                    final_test = tst

        print(f'Run{run} val:{best_val}, test:{final_test}')
        vals.append(best_val)
        tests.append(final_test)

    print('')
    print(f"Average val accuracy: {np.mean(vals)} ± {np.std(vals)}")
    print(f"Average test accuracy: {np.mean(tests)} ± {np.std(tests)}")
def test_graph_saint():
    adj = torch.tensor([
        [+1, +2, +3, +0, +4, +0],
        [+5, +6, +0, +7, +0, +8],
        [+9, +0, 10, +0, 11, +0],
        [+0, 12, +0, 13, +0, 14],
        [15, +0, 16, +0, 17, +0],
        [+0, 18, +0, 19, +0, 20],
    ])

    edge_index = adj.nonzero(as_tuple=False).t()
    edge_type = adj[edge_index[0], edge_index[1]]
    x = torch.Tensor([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])
    data = Data(edge_index=edge_index, x=x, edge_type=edge_type, num_nodes=6)

    torch.manual_seed(12345)
    loader = GraphSAINTNodeSampler(data, batch_size=3, num_steps=4,
                                   sample_coverage=10, log=False)

    sample = next(iter(loader))
    assert sample.x.tolist() == [[2, 2], [4, 4], [5, 5]]
    assert sample.edge_index.tolist() == [[0, 0, 1, 1, 2], [0, 1, 0, 1, 2]]
    assert sample.edge_type.tolist() == [10, 11, 16, 17, 20]

    assert len(loader) == 4
    for sample in loader:
        assert len(sample) == 5
        assert sample.num_nodes <= 3
        assert sample.num_edges <= 3 * 4
        assert sample.node_norm.numel() == sample.num_nodes
        assert sample.edge_norm.numel() == sample.num_edges

    torch.manual_seed(12345)
    loader = GraphSAINTEdgeSampler(data, batch_size=2, num_steps=4,
                                   sample_coverage=10, log=False)

    sample = next(iter(loader))
    assert sample.x.tolist() == [[0, 0], [2, 2], [3, 3]]
    assert sample.edge_index.tolist() == [[0, 0, 1, 1, 2], [0, 1, 0, 1, 2]]
    assert sample.edge_type.tolist() == [1, 3, 9, 10, 13]

    assert len(loader) == 4
    for sample in loader:
        assert len(sample) == 5
        assert sample.num_nodes <= 4
        assert sample.num_edges <= 4 * 4
        assert sample.node_norm.numel() == sample.num_nodes
        assert sample.edge_norm.numel() == sample.num_edges

    torch.manual_seed(12345)
    loader = GraphSAINTRandomWalkSampler(data, batch_size=2, walk_length=1,
                                         num_steps=4, sample_coverage=10,
                                         log=False)

    sample = next(iter(loader))
    assert sample.x.tolist() == [[1, 1], [2, 2], [4, 4]]
    assert sample.edge_index.tolist() == [[0, 1, 1, 2, 2], [0, 1, 2, 1, 2]]
    assert sample.edge_type.tolist() == [6, 10, 11, 16, 17]

    assert len(loader) == 4
    for sample in loader:
        assert len(sample) == 5
        assert sample.num_nodes <= 4
        assert sample.num_edges <= 4 * 4
        assert sample.node_norm.numel() == sample.num_nodes
        assert sample.edge_norm.numel() == sample.num_edges
splitted_idx = dataset.get_idx_split()
data = dataset[0]
data.n_id = torch.arange(data.num_nodes)
data.node_species = None
data.y = data.y.float()

# Initialize features of nodes by aggregating edge features.
row, col = data.edge_index

# Set split indices to masks.
for split in ['train', 'valid', 'test']:
    mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    mask[splitted_idx[split]] = True
    data[f'{split}_mask'] = mask

train_loader = GraphSAINTRandomWalkSampler(data,
                                           batch_size=args.batch_size,
                                           walk_length=args.walk_length,
                                           num_steps=args.num_steps,
                                           sample_coverage=0,
                                           save_dir=dataset.processed_dir)
test_loader = GraphSAINTRandomWalkSampler(data,
                                          batch_size=args.batch_size,
                                          walk_length=args.walk_length,
                                          num_steps=args.num_steps,
                                          sample_coverage=0,
                                          save_dir=dataset.processed_dir)
p_train_loader = GraphSAINTRandomWalkSampler(data,
                                             batch_size=args.batch_size * 2,
                                             walk_length=args.walk_length,
                                             num_steps=args.num_steps,
                                             sample_coverage=0,
                                             save_dir=dataset.processed_dir)
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default='0')
    parser.add_argument('--model', type=str, default='GraphSAINT')
    parser.add_argument('--dataset', type=str, default='Reddit')  # Reddit or Flickr
    parser.add_argument('--batch', type=int, default=2000)  # Reddit: 2000, Flickr: 6000
    parser.add_argument('--walk_length', type=int, default=4)  # Reddit: 4, Flickr: 2
    parser.add_argument('--sample_coverage', type=int, default=50)  # Reddit: 50, Flickr: 100
    parser.add_argument('--runs', type=int, default=10)
    parser.add_argument('--epochs', type=int, default=100)  # 100, 50
    parser.add_argument('--lr', type=float, default=0.01)  # 0.01, 0.001
    parser.add_argument('--weight_decay', type=float, default=0.0005)
    parser.add_argument('--hidden', type=int, default=256)  # 128, 256
    parser.add_argument('--dropout', type=float, default=0.1)  # 0.1, 0.2
    parser.add_argument('--use_normalization', action='store_true')
    parser.add_argument('--binarize', action='store_true')
    args = parser.parse_args()

    assert args.model in ['GraphSAINT']
    assert args.dataset in ['Flickr', 'Reddit']
    path = '/home/wangjunfu/dataset/graph/' + str(args.dataset)
    if args.dataset == 'Flickr':
        dataset = Flickr(path)
    else:
        dataset = Reddit(path)
    data = dataset[0]
    row, col = data.edge_index
    data.edge_weight = 1. / degree(col, data.num_nodes)[col]  # Norm by in-degree.

    loader = GraphSAINTRandomWalkSampler(data, batch_size=args.batch,
                                         walk_length=args.walk_length,
                                         num_steps=5,
                                         sample_coverage=args.sample_coverage,
                                         save_dir=dataset.processed_dir,
                                         num_workers=0)

    device = torch.device(
        f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')
    model = SAINT(data.num_node_features, args.hidden, dataset.num_classes,
                  args.dropout, args.binarize).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    val_f1s, test_f1s = [], []
    for run in range(1, args.runs + 1):
        best_val, best_test = 0, 0
        model.reset_parameters()
        start_time = time.time()
        for epoch in range(1, args.epochs + 1):
            loss = train(model, loader, optimizer, device,
                         args.use_normalization)
            accs = test(model, data, device, args.use_normalization)
            if accs[1] > best_val:
                best_val = accs[1]
                best_test = accs[2]
            if args.runs == 1:
                print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, '
                      f'Train: {accs[0]:.4f}, Val: {accs[1]:.4f}, '
                      f'Test: {accs[2]:.4f}')
        test_f1s.append(best_test)
        print("Run: {:d}, best val: {:.4f}, best test: {:.4f}, "
              "time cost: {:d}s".format(run, best_val, best_test,
                                        int(time.time() - start_time)))

    test_f1s = torch.tensor(test_f1s)
    print("{:.4f} ± {:.4f}".format(test_f1s.mean(), test_f1s.std()))
def test_graph_saint():
    adj = torch.tensor([
        [+1, +2, +3, +0, +4, +0],
        [+5, +6, +0, +7, +0, +8],
        [+9, +0, 10, +0, 11, +0],
        [+0, 12, +0, 13, +0, 14],
        [15, +0, 16, +0, 17, +0],
        [+0, 18, +0, 19, +0, 20],
    ])

    edge_index = adj.nonzero(as_tuple=False).t()
    edge_id = adj[edge_index[0], edge_index[1]]
    x = torch.Tensor([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])
    n_id = torch.arange(6)
    data = Data(edge_index=edge_index, x=x, n_id=n_id, edge_id=edge_id,
                num_nodes=6)

    loader = GraphSAINTNodeSampler(data, batch_size=3, num_steps=4,
                                   sample_coverage=10, log=False)
    assert len(loader) == 4
    for sample in loader:
        assert sample.num_nodes <= data.num_nodes
        assert sample.n_id.min() >= 0 and sample.n_id.max() < 6
        assert sample.num_nodes == sample.n_id.numel()
        assert sample.x.tolist() == x[sample.n_id].tolist()
        assert sample.edge_index.min() >= 0
        assert sample.edge_index.max() < sample.num_nodes
        assert sample.edge_id.min() >= 1 and sample.edge_id.max() <= 21
        assert sample.edge_id.numel() == sample.num_edges
        assert sample.node_norm.numel() == sample.num_nodes
        assert sample.edge_norm.numel() == sample.num_edges

    loader = GraphSAINTEdgeSampler(data, batch_size=2, num_steps=4,
                                   sample_coverage=10, log=False)
    assert len(loader) == 4
    for sample in loader:
        assert sample.num_nodes <= data.num_nodes
        assert sample.n_id.min() >= 0 and sample.n_id.max() < 6
        assert sample.num_nodes == sample.n_id.numel()
        assert sample.x.tolist() == x[sample.n_id].tolist()
        assert sample.edge_index.min() >= 0
        assert sample.edge_index.max() < sample.num_nodes
        assert sample.edge_id.min() >= 1 and sample.edge_id.max() <= 21
        assert sample.edge_id.numel() == sample.num_edges
        assert sample.node_norm.numel() == sample.num_nodes
        assert sample.edge_norm.numel() == sample.num_edges

    loader = GraphSAINTRandomWalkSampler(data, batch_size=2, walk_length=1,
                                         num_steps=4, sample_coverage=10,
                                         log=False)
    assert len(loader) == 4
    for sample in loader:
        assert sample.num_nodes <= data.num_nodes
        assert sample.n_id.min() >= 0 and sample.n_id.max() < 6
        assert sample.num_nodes == sample.n_id.numel()
        assert sample.x.tolist() == x[sample.n_id].tolist()
        assert sample.edge_index.min() >= 0
        assert sample.edge_index.max() < sample.num_nodes
        assert sample.edge_id.min() >= 1 and sample.edge_id.max() <= 21
        assert sample.edge_id.numel() == sample.num_edges
        assert sample.node_norm.numel() == sample.num_nodes
        assert sample.edge_norm.numel() == sample.num_edges