コード例 #1
0
def _parse_bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string as true, so
    ``--naive False`` would silently turn the flag ON. This parser keeps the
    ``--flag True`` spelling working while making ``--flag False`` mean False.

    Args:
        value: The raw string from the command line (a bool is passed through).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def _edge_index_with_self_loops(adj_t):
    """Build a ``[2, N + E]`` edge index: one self-loop per node, then the COO edges.

    Args:
        adj_t: Sparse adjacency exposing ``coo()`` (returning row, col, value)
            and ``sparse_size(dim)``.

    Returns:
        A tensor whose first N columns are ``(i, i)`` for every node ``i``,
        followed by the ``(row, col)`` pairs of *adj_t*, on the same device
        as the adjacency's indices.
    """
    row, col, _ = adj_t.coo()
    # Use the declared node count rather than ``row.max() + 1`` so that nodes
    # without any incident edge still receive a self-loop.
    num_nodes = adj_t.sparse_size(0)
    # Allocate on the indices' device instead of hard-coding ``.cuda()``,
    # which crashed whenever CUDA is unavailable.
    loops = torch.arange(num_nodes, device=row.device)
    row = torch.cat([loops, row], dim=0)
    col = torch.cat([loops, col], dim=0)
    return torch.stack([row, col], dim=0)


def _log_epoch(run, epoch, loss, result):
    """Log one epoch's loss and train/valid/test accuracy through ``logger``."""
    train_acc, valid_acc, test_acc = result
    logger.info(
        f'Run: {run + 1:02d}, Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {100 * train_acc:.2f}%, Valid: {100 * valid_acc:.2f}% Test: {100 * test_acc:.2f}%'
    )


def main():
    """Train a GCN (or SAGE) on ogbn-arxiv, then alternate edge pruning and retraining.

    Each run first trains for ``--epochs`` epochs without the pruner. It then
    performs ``--times`` rounds of ``Pruner.prune`` followed by
    ``--prune_epoch`` epochs of training/evaluation with the pruner attached.
    Per-epoch results are written through ``logger`` (a log file under
    ``log/`` is attached via ``logger.add``) and recorded in the ``Logger``
    tracker. The tracker's statistics are printed after the initial phase and
    after every pruning round.
    """
    parser = argparse.ArgumentParser(description='OGBN-Arxiv (GNN)')
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--use_sage', action='store_true')
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=201)
    parser.add_argument('--runs', type=int, default=1)
    parser.add_argument('--prune_set', type=str, default='train')
    parser.add_argument('--ratio', type=float, default=0.95)
    parser.add_argument('--times', type=int, default=20)
    parser.add_argument('--prune_epoch', type=int, default=301)
    # ``type=bool`` would make "--reset_param False" evaluate to True.
    parser.add_argument('--reset_param', type=_parse_bool, default=False)
    parser.add_argument('--naive', type=_parse_bool, default=False)
    parser.add_argument('--data_dir', type=str, default='./data/')
    args = parser.parse_args()

    log_name = f'log/arxivtest_{args.prune_set}_{args.ratio}_{args.epochs}_{args.prune_epoch}_{args.times}.log'
    logger.add(log_name)
    logger.info('logname: {}'.format(log_name))
    logger.info(args)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dataset = PygNodePropPredDataset(name='ogbn-arxiv',
                                     root=args.data_dir,
                                     transform=T.ToSparseTensor())

    data = dataset[0]
    data.adj_t = data.adj_t.to_symmetric()
    data = data.to(device)

    split_idx = dataset.get_idx_split()
    train_idx = split_idx['train'].to(device)

    model_cls = SAGE if args.use_sage else GCN
    model = model_cls(data.num_features, args.hidden_channels,
                      dataset.num_classes, args.num_layers,
                      args.dropout).to(device)

    evaluator = Evaluator(name='ogbn-arxiv')
    logger1 = Logger(args.runs, args)

    edge_index = _edge_index_with_self_loops(data.adj_t)
    data.edge_index = edge_index

    for run in range(args.runs):
        # Build a fresh pruner for every run. A single shared pruner would let
        # run k+1 keep pruning from run k's state, while the statistics below
        # are labelled ``ratio ** i`` as if each run started unpruned.
        # (Constructed before reset_parameters to keep run 0's call order.)
        pruner = Pruner(edge_index.cpu(),
                        split_idx,
                        prune_set=args.prune_set,
                        ratio=args.ratio)
        model.reset_parameters()
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

        # Initial phase: train/evaluate without the pruner.
        for epoch in range(1, 1 + args.epochs):
            loss = train(model, data, train_idx, optimizer)
            result = test(model, data, split_idx, evaluator)
            logger1.add_result(run, result)
            if epoch % args.log_steps == 0:
                _log_epoch(run, epoch, loss, result)

        logger1.print_statistics(ratio=1)
        logger1.flush()

        # Pruning rounds: prune, optionally re-initialise, then retrain.
        for i in range(1, args.times + 1):
            pruner.prune(naive=args.naive)
            if args.reset_param:
                # NOTE(review): the Adam optimizer (and its moment estimates)
                # is reused across this reset; confirm that is intended.
                model.reset_parameters()
            for epoch in range(1, 1 + args.prune_epoch):
                loss = train(model, data, train_idx, optimizer, pruner=pruner)
                result = test(model, data, split_idx, evaluator, pruner=pruner)
                logger1.add_result(run, result)
                if epoch % args.log_steps == 0:
                    _log_epoch(run, epoch, loss, result)

            logger1.print_statistics(ratio=args.ratio ** i)
            logger1.flush()
コード例 #2
0
def appen():
    """Snapshot the current best accuracies into their history lists.

    Reads the module-level running bests ``t_b``, ``val_b`` and ``te_b`` and
    appends each to ``train_best``, ``valid_best`` and ``test_best``
    respectively. Returns nothing.
    """
    pairs = (
        (train_best, t_b),
        (valid_best, val_b),
        (test_best, te_b),
    )
    for history, best_value in pairs:
        history.append(best_value)


# test()  # Uncomment to check that GPU inference succeeds before training.

# Driver: for each run, re-initialise the model and optimizer, train for
# ``args.epochs`` epochs while tracking the best accuracies, then enter the
# pruning rounds (``args.times`` of them).
for run in range(args.runs):
    model.reset_parameters()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    for epoch in range(1, 1 + args.epochs):
        loss = train(epoch)
        # Release cached CUDA memory between training and evaluation.
        torch.cuda.empty_cache()
        result = test()
        logger1.add_result(run, result)
        train_acc, valid_acc, test_acc = result
        logger.info(
            f'Run: {run + 1:02d}, Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {100 * train_acc:.2f}%, Valid: {100 * valid_acc:.2f}%, Test: {100 * test_acc:.2f}%'
        )
        # Running maxima of each accuracy. t_b / te_b / val_b must already
        # exist before the first epoch; they are not initialised here.
        # NOTE(review): they are also not reset at the start of each run, so
        # unless something (presumably zeros()) resets them in between, later
        # runs' bests include earlier runs' values — confirm.
        t_b = np.max([train_acc, t_b])
        te_b = np.max([test_acc, te_b])
        val_b = np.max([valid_acc, val_b])
    # Record this phase's best train/valid/test accuracies.
    appen()

    logger1.print_statistics(ratio=1)
    logger1.flush()

    for i in range(1, args.times + 1):
        # Prune through the loader using ``rec_loss`` and ``args.ratio`` with
        # the selected ``args.method`` (semantics defined elsewhere).
        train_loader.prune(rec_loss, args.ratio, method=args.method)
        # NOTE(review): presumably resets the running-best values before the
        # retraining that follows — confirm against its definition.
        zeros()