def load_gin_dataset(args):
    # load the GIN benchmark dataset with self-loops and return the
    # (train, valid) loaders for the requested 10-fold split
    dataset = GINDataset(args.dataset, self_loop=True)
    return GraphDataLoader(dataset,
                           batch_size=args.batch_size,
                           collate_fn=collate,
                           seed=args.seed,
                           shuffle=True,
                           split_name='fold10',
                           fold_idx=args.fold_idx).train_valid_loader()
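The `collate` callable passed to GraphDataLoader above is defined in the example's own dataloader module and is not shown here. A minimal sketch of what it is assumed to do, following the usual DGL batching pattern (the Gluon example below would build an MXNet array of labels instead of a torch tensor):

import dgl
import torch

def collate(samples):
    # merge a list of (graph, label) pairs into one batched
    # (block-diagonal) graph plus a 1-D label tensor
    graphs, labels = map(list, zip(*samples))
    return dgl.batch(graphs), torch.tensor(labels)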
Example #2
def main(args):

    # set up seeds; the RNGs are fixed to 0 here, while args.seed
    # only controls the data-split shuffling below
    mx.random.seed(0)
    np.random.seed(seed=0)

    if args.device >= 0:
        args.device = mx.gpu(args.device)
    else:
        args.device = mx.cpu()

    dataset = GINDataset(args.dataset, not args.learn_eps)  # add self-loops when eps is not learned

    trainloader, validloader = GraphDataLoader(
        dataset,
        batch_size=args.batch_size,
        collate_fn=collate,
        seed=args.seed,
        shuffle=True,
        split_name='fold10',
        fold_idx=args.fold_idx).train_valid_loader()
    # or split_name='rand', split_ratio=0.7

    model = GIN(args.num_layers, args.num_mlp_layers, dataset.dim_nfeats,
                args.hidden_dim, dataset.gclasses, args.final_dropout,
                args.learn_eps, args.graph_pooling_type,
                args.neighbor_pooling_type)
    model.initialize(ctx=args.device)

    criterion = gluon.loss.SoftmaxCELoss()

    print(model.collect_params())
    lr_scheduler = mx.lr_scheduler.FactorScheduler(50, 0.5)  # halve the learning rate every 50 updates
    trainer = gluon.Trainer(model.collect_params(), 'adam',
                            {'lr_scheduler': lr_scheduler})

    # it's not cost-effective to handle the cursor and init at position 0
    # https://stackoverflow.com/a/23121189
    tbar = tqdm(range(args.epochs),
                unit="epoch",
                position=3,
                ncols=0,
                file=sys.stdout)
    vbar = tqdm(range(args.epochs),
                unit="epoch",
                position=4,
                ncols=0,
                file=sys.stdout)
    lrbar = tqdm(range(args.epochs),
                 unit="epoch",
                 position=5,
                 ncols=0,
                 file=sys.stdout)

    for epoch, _, _ in zip(tbar, vbar, lrbar):
        train(args, model, trainloader, trainer, criterion, epoch)

        train_loss, train_acc = eval_net(args, model, trainloader, criterion)
        tbar.set_description(
            'train set - average loss: {:.4f}, accuracy: {:.0f}%'.format(
                train_loss, 100. * train_acc))

        valid_loss, valid_acc = eval_net(args, model, validloader, criterion)
        vbar.set_description(
            'valid set - average loss: {:.4f}, accuracy: {:.0f}%'.format(
                valid_loss, 100. * valid_acc))

        if args.filename != "":
            with open(args.filename, 'a') as f:
                f.write('%s %s %s %s' %
                        (args.dataset, args.learn_eps,
                         args.neighbor_pooling_type, args.graph_pooling_type))
                f.write("\n")
                f.write("%f %f %f %f" %
                        (train_loss, train_acc, valid_loss, valid_acc))
                f.write("\n")

        lrbar.set_description("Learning eps with learn_eps={}: {}".format(
            args.learn_eps, [
                layer.eps.data(args.device).asscalar()
                for layer in model.ginlayers
            ]))

    tbar.close()
    vbar.close()
    lrbar.close()
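The train and eval_net helpers are defined elsewhere in the example. A minimal sketch of the Gluon training epoch that train is assumed to perform, following the standard autograd.record / Trainer pattern; the (batched_graph, labels) batch layout and the model(graph, features) forward are assumptions taken from the surrounding code:

from mxnet import autograd

def train(args, model, trainloader, trainer, criterion, epoch):
    # one epoch over the training loader; assumes each batch is a
    # (batched_graph, labels) pair as produced by collate above
    for graphs, labels in trainloader:
        labels = labels.as_in_context(args.device)
        feat = graphs.ndata['attr'].as_in_context(args.device)
        with autograd.record():
            loss = criterion(model(graphs, feat), labels).mean()
        loss.backward()
        trainer.step(batch_size=1)  # loss already averaged over the batch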
Example #3
def main(args):

    # set up seeds from args.seed
    torch.manual_seed(seed=args.seed)
    np.random.seed(seed=args.seed)

    is_cuda = not args.disable_cuda and torch.cuda.is_available()

    if is_cuda:
        args.device = torch.device("cuda:" + str(args.device))
        torch.cuda.manual_seed_all(seed=args.seed)
    else:
        args.device = torch.device("cpu")

    dataset = GINDataset(args.dataset, not args.learn_eps)  # add self-loops when eps is not learned

    trainloader, validloader = GraphDataLoader(
        dataset,
        batch_size=args.batch_size,
        device=args.device,
        collate_fn=collate,
        seed=args.seed,
        shuffle=True,
        split_name='fold10',
        fold_idx=args.fold_idx).train_valid_loader()
    # or split_name='rand', split_ratio=0.7

    model = GIN(args.num_layers, args.num_mlp_layers, dataset.dim_nfeats,
                args.hidden_dim, dataset.gclasses, args.final_dropout,
                args.learn_eps, args.graph_pooling_type,
                args.neighbor_pooling_type).to(args.device)

    criterion = nn.CrossEntropyLoss()  # default reduction is 'mean'
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)  # halve the lr every 50 epochs

    # it's not cost-effective to handle the cursor and init at position 0
    # https://stackoverflow.com/a/23121189
    tbar = tqdm(range(args.epochs),
                unit="epoch",
                position=3,
                ncols=0,
                file=sys.stdout)
    vbar = tqdm(range(args.epochs),
                unit="epoch",
                position=4,
                ncols=0,
                file=sys.stdout)
    lrbar = tqdm(range(args.epochs),
                 unit="epoch",
                 position=5,
                 ncols=0,
                 file=sys.stdout)

    for epoch, _, _ in zip(tbar, vbar, lrbar):

        train(args, model, trainloader, optimizer, criterion, epoch)
        scheduler.step()

        train_loss, train_acc = eval_net(args, model, trainloader, criterion)
        tbar.set_description(
            'train set - average loss: {:.4f}, accuracy: {:.0f}%'.format(
                train_loss, 100. * train_acc))

        valid_loss, valid_acc = eval_net(args, model, validloader, criterion)
        vbar.set_description(
            'valid set - average loss: {:.4f}, accuracy: {:.0f}%'.format(
                valid_loss, 100. * valid_acc))

        if args.filename != "":
            with open(args.filename, 'a') as f:
                f.write('%s %s %s %s' %
                        (args.dataset, args.learn_eps,
                         args.neighbor_pooling_type, args.graph_pooling_type))
                f.write("\n")
                f.write("%f %f %f %f" %
                        (train_loss, train_acc, valid_loss, valid_acc))
                f.write("\n")

        lrbar.set_description("Learning eps with learn_eps={}: {}".format(
            args.learn_eps,
            [layer.eps.data.item() for layer in model.ginlayers]))

    tbar.close()
    vbar.close()
    lrbar.close()
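eval_net is likewise defined elsewhere in the example. A minimal sketch of what it is assumed to compute, namely the mean loss and accuracy over a loader; the batch layout mirrors the train sketch above, and since this example's GraphDataLoader receives device=args.device, the batches are assumed to arrive already on the target device:

def eval_net(args, model, dataloader, criterion):
    # mean loss and accuracy over a loader; assumes (batched_graph,
    # labels) batches already placed on args.device
    model.eval()
    total, total_loss, total_correct = 0, 0.0, 0
    with torch.no_grad():
        for graphs, labels in dataloader:
            outputs = model(graphs, graphs.ndata['attr'])
            total += len(labels)
            total_loss += criterion(outputs, labels).item() * len(labels)
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
    model.train()
    return total_loss / total, total_correct / total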
Example #4
class Args:
    # hyperparameters for the cached-model run below (attributes defined
    # earlier in the class, e.g. dataset and num_layers, are truncated
    # in this excerpt)
    num_mlp_layers = 2
    hidden_dim = 64
    graph_pooling_type = "mean"
    neighbor_pooling_type = "sum"
    learn_eps = False
    seed = 0
    epochs = 100
    lr = 0.001
    final_dropout = 0.5
    disable_cuda = True
    device = 0


if __name__ == '__main__':
    args = Args()
    dataset = GINDataset(args.dataset, False)

    if not os.path.isfile("gin_model.p"):
        model = main(args, dataset)
    else:
        model = GIN(args.num_layers, args.num_mlp_layers, dataset.dim_nfeats,
                    args.hidden_dim, dataset.gclasses, args.final_dropout,
                    args.learn_eps, args.graph_pooling_type,
                    args.neighbor_pooling_type)
        model.load_state_dict(torch.load("gin_model.p"))
    model.eval()
    num_hops = args.num_layers - 1
    graph_idx = 178  # index into dataset.graphs, not a class label
    g = dataset.graphs[graph_idx]
    g.ndata[ExplainerTags.NODE_FEATURES] = g.ndata['attr'].float().to(
        torch.device("cpu"))