def _str2bool(value):
    """Parse a CLI boolean so that '--flag False' is actually falsy.

    argparse's ``type=bool`` applies ``bool()`` to the raw string, so any
    non-empty string — including 'False' — becomes True.  This converter
    fixes that while staying backward-compatible with existing invocations
    such as ``--reset_param True``.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean, got {value!r}')


def main():
    """Train a GCN/GraphSAGE model on OGBN-Arxiv, then iteratively prune
    edges with ``Pruner`` and retrain, logging accuracy after each phase.
    """
    parser = argparse.ArgumentParser(description='OGBN-Arxiv (GNN)')
    # parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--use_sage', action='store_true')
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=201)
    parser.add_argument('--runs', type=int, default=1)
    parser.add_argument('--prune_set', type=str, default='train')
    parser.add_argument('--ratio', type=float, default=0.95)
    parser.add_argument('--times', type=int, default=20)
    parser.add_argument('--prune_epoch', type=int, default=301)
    # BUG FIX: these used type=bool, under which '--reset_param False'
    # silently parsed as True; _str2bool keeps the same CLI shape but
    # interprets the string correctly.
    parser.add_argument('--reset_param', type=_str2bool, default=False)
    parser.add_argument('--naive', type=_str2bool, default=False)
    parser.add_argument('--data_dir', type=str, default='./data/')
    args = parser.parse_args()

    log_name = f'log/arxivtest_{args.prune_set}_{args.ratio}_{args.epochs}_{args.prune_epoch}_{args.times}.log'
    logger.add(log_name)
    logger.info('logname: {}'.format(log_name))
    logger.info(args)

    # device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    # device = torch.device(device)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    dataset = PygNodePropPredDataset(name='ogbn-arxiv', root=args.data_dir,
                                     transform=T.ToSparseTensor())
    data = dataset[0]
    data.adj_t = data.adj_t.to_symmetric()
    data = data.to(device)

    split_idx = dataset.get_idx_split()
    train_idx = split_idx['train'].to(device)

    if args.use_sage:
        model = SAGE(data.num_features, args.hidden_channels,
                     dataset.num_classes, args.num_layers,
                     args.dropout).to(device)
    else:
        model = GCN(data.num_features, args.hidden_channels,
                    dataset.num_classes, args.num_layers,
                    args.dropout).to(device)

    evaluator = Evaluator(name='ogbn-arxiv')
    logger1 = Logger(args.runs, args)

    # Rebuild the edge list in COO form with one self-loop (i, i)
    # prepended for every node.
    row, col, val = data.adj_t.coo()
    N = int(row.max() + 1)
    # BUG FIX: was torch.arange(0, N).cuda(), which crashes on CPU-only
    # hosts even though `device` correctly falls back to 'cpu'; allocate
    # the self-loop indices on the same device as the existing indices.
    row = torch.cat([torch.arange(0, N, device=row.device), row], dim=0)
    col = torch.cat([torch.arange(0, N, device=col.device), col], dim=0)
    edge_index = torch.cat([row, col]).view(2, -1)
    data.edge_index = edge_index
    # print(data.edge_index)
    pruner = Pruner(edge_index.cpu(), split_idx,
                    prune_set=args.prune_set, ratio=args.ratio)

    for run in range(args.runs):
        model.reset_parameters()
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

        # Phase 1: train on the full (un-pruned) graph.
        for epoch in range(1, 1 + args.epochs):
            loss = train(model, data, train_idx, optimizer)
            result = test(model, data, split_idx, evaluator)
            logger1.add_result(run, result)
            if epoch % args.log_steps == 0:
                train_acc, valid_acc, test_acc = result
                logger.info(
                    f'Run: {run + 1:02d}, Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {100 * train_acc:.2f}%, Valid: {100 * valid_acc:.2f}% Test: {100 * test_acc:.2f}%'
                )
        logger1.print_statistics(ratio=1)
        logger1.flush()

        # Phase 2: repeatedly prune a fraction of edges, then retrain
        # (optionally from re-initialized parameters).
        for i in range(1, args.times + 1):
            pruner.prune(naive=args.naive)
            if args.reset_param:
                model.reset_parameters()
            for epoch in range(1, 1 + args.prune_epoch):
                loss = train(model, data, train_idx, optimizer, pruner=pruner)
                result = test(model, data, split_idx, evaluator, pruner=pruner)
                logger1.add_result(run, result)
                if epoch % args.log_steps == 0:
                    train_acc, valid_acc, test_acc = result
                    logger.info(
                        f'Run: {run + 1:02d}, Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {100 * train_acc:.2f}%, Valid: {100 * valid_acc:.2f}% Test: {100 * test_acc:.2f}%'
                    )
            # Remaining-edge ratio after i pruning rounds.
            logger1.print_statistics(ratio=args.ratio**i)
            logger1.flush()
# NOTE(review): this fragment appears to be the interior of a larger training
# routine — the free variables (train_best, valid_best, test_best, t_b, val_b,
# te_b, args, model, logger1, logger, train, test, train_loader, rec_loss,
# zeros) must come from an enclosing scope that is not visible here, and the
# nesting below is reconstructed from a collapsed source line.  TODO: confirm
# against the full file.

def appen():
    # Append the best-so-far accuracies to the running history lists;
    # reads t_b / val_b / te_b from the enclosing scope.
    train_best.append(t_b)
    valid_best.append(val_b)
    test_best.append(te_b)

#test() # Test if inference on GPU succeeds.
for run in range(args.runs):
    model.reset_parameters()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    for epoch in range(1, 1 + args.epochs):
        # NOTE(review): here train/test take different arguments than the
        # train(model, data, ...) used elsewhere — presumably a different
        # train/test pair is in scope; verify.
        loss = train(epoch)
        # Free cached GPU memory between the train and inference passes.
        torch.cuda.empty_cache()
        result = test()
        logger1.add_result(run, result)
        train_acc, valid_acc, test_acc = result
        logger.info(
            f'Run: {run + 1:02d}, Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {100 * train_acc:.2f}%, Valid: {100 * valid_acc:.2f}%, Test: {100 * test_acc:.2f}%'
        )
        # Track the best accuracy seen so far for each split.
        t_b = np.max([train_acc, t_b])
        te_b = np.max([test_acc, te_b])
        val_b = np.max([valid_acc, val_b])
    appen()
    logger1.print_statistics(ratio=1)
    logger1.flush()
    # NOTE(review): the loop below looks truncated — zeros() and the prune
    # call reference names with no visible definition; the source may be
    # cut off here.
    for i in range(1, args.times + 1):
        train_loader.prune(rec_loss, args.ratio, method=args.method)
        zeros()