Example #1
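These examples only show the main() function of a PyTorch Geometric project for pre-training and fine-tuning graph neural networks; the project-level imports and the train/eval helpers are defined elsewhere. A plausible import header for Example #1 is sketched below; the local module names (loader, dataloader, splitters, model) are assumptions about the repository layout, not something the snippet itself confirms.

import argparse
import os

import numpy as np
import torch
import torch.optim as optim

# Project-local modules (module names are assumptions based on how the
# classes and functions are used in the example below).
from loader import BioDataset
from dataloader import DataLoaderFinetune
from splitters import random_split, species_split
from model import GNN_graphpred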
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description=
        'PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs',
                        type=int,
                        default=50,
                        help='number of epochs to train (default: 50)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate (default: 0.001)')
    parser.add_argument('--decay',
                        type=float,
                        default=0,
                        help='weight decay (default: 0)')
    parser.add_argument(
        '--num_layer',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim',
                        type=int,
                        default=300,
                        help='embedding dimensions (default: 300)')
    parser.add_argument('--dropout_ratio',
                        type=float,
                        default=0.5,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default="mean",
        help='graph level pooling (sum, mean, max, set2set, attention)')
    parser.add_argument(
        '--JK',
        type=str,
        default="last",
        help=
        'how the node features across layers are combined. last, sum, max or concat'
    )
    parser.add_argument('--model_file',
                        type=str,
                        default='',
                        help='filename to read the model (if there is any)')
    parser.add_argument('--filename',
                        type=str,
                        default='',
                        help='output filename')
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="Seed for splitting dataset.")
    parser.add_argument('--runseed',
                        type=int,
                        default=0,
                        help="Seed for running experiments.")
    parser.add_argument('--num_workers',
                        type=int,
                        default=0,
                        help='number of workers for dataset loading')
    parser.add_argument('--eval_train',
                        type=int,
                        default=0,
                        help='evaluating training or not')
    parser.add_argument('--split',
                        type=str,
                        default="species",
                        help='Random or species split')
    args = parser.parse_args()

    torch.manual_seed(args.runseed)
    np.random.seed(args.runseed)
    device = torch.device(
        "cuda:" +
        str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.runseed)

    root_supervised = 'dataset/supervised'

    dataset = BioDataset(root_supervised, data_type='supervised')

    print(dataset)

    # Quick dataset statistics: average number of nodes and edges per graph.
    node_num = 0
    edge_num = 0
    for d in dataset:
        node_num += d.x.size()[0]
        edge_num += d.edge_index.size()[1]
    print("avg nodes per graph:", node_num / len(dataset))
    print("avg edges per graph:", edge_num / len(dataset))

    if args.split == "random":
        print("random splitting")
        train_dataset, valid_dataset, test_dataset = random_split(
            dataset, seed=args.seed)
    elif args.split == "species":
        trainval_dataset, test_dataset = species_split(dataset)
        train_dataset, valid_dataset, _ = random_split(trainval_dataset,
                                                       seed=args.seed,
                                                       frac_train=0.85,
                                                       frac_valid=0.15,
                                                       frac_test=0)
        test_dataset_broad, test_dataset_none, _ = random_split(test_dataset,
                                                                seed=args.seed,
                                                                frac_train=0.5,
                                                                frac_valid=0.5,
                                                                frac_test=0)
        print("species splitting")
    else:
        raise ValueError("Unknown split name.")

    train_loader = DataLoaderFinetune(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers)
    val_loader = DataLoaderFinetune(valid_dataset,
                                    batch_size=10 * args.batch_size,
                                    shuffle=False,
                                    num_workers=args.num_workers)

    if args.split == "random":
        test_loader = DataLoaderFinetune(test_dataset,
                                         batch_size=10 * args.batch_size,
                                         shuffle=False,
                                         num_workers=args.num_workers)
    else:
        ### for species splitting
        test_easy_loader = DataLoaderFinetune(test_dataset_broad,
                                              batch_size=10 * args.batch_size,
                                              shuffle=False,
                                              num_workers=args.num_workers)
        test_hard_loader = DataLoaderFinetune(test_dataset_none,
                                              batch_size=10 * args.batch_size,
                                              shuffle=False,
                                              num_workers=args.num_workers)

    num_tasks = len(dataset[0].go_target_downstream)

    print(train_dataset[0])

    #set up model
    model = GNN_graphpred(args.num_layer,
                          args.emb_dim,
                          num_tasks,
                          JK=args.JK,
                          drop_ratio=args.dropout_ratio,
                          graph_pooling=args.graph_pooling,
                          gnn_type=args.gnn_type)

    if not args.model_file == "":
        model.from_pretrained(args.model_file)

    model.to(device)

    #set up optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.decay)

    train_acc_list = []
    val_acc_list = []

    ### for random splitting
    test_acc_list = []

    ### for species splitting
    test_acc_easy_list = []
    test_acc_hard_list = []

    if not args.filename == "":
        if os.path.exists(args.filename):
            print("Removing existing output file: " + args.filename)
            os.remove(args.filename)

    for epoch in range(1, args.epochs + 1):
        print("====epoch " + str(epoch))

        train(args, model, device, train_loader, optimizer)

        print("====Evaluation")
        if args.eval_train:
            train_acc = eval(args, model, device, train_loader)
        else:
            train_acc = 0
            print("ommitting training evaluation")
        val_acc = eval(args, model, device, val_loader)

        val_acc_list.append(np.mean(val_acc))
        train_acc_list.append(train_acc)

        if args.split == "random":
            test_acc = eval(args, model, device, test_loader)
            test_acc_list.append(test_acc)
        else:
            test_acc_easy = eval(args, model, device, test_easy_loader)
            test_acc_hard = eval(args, model, device, test_hard_loader)
            test_acc_easy_list.append(np.mean(test_acc_easy))
            test_acc_hard_list.append(np.mean(test_acc_hard))
            print(val_acc_list[-1])
            print(test_acc_easy_list[-1])
            print(test_acc_hard_list[-1])

        print("")

    # Report the species-split test scores at the epoch with the best validation score.
    best_epoch = np.array(val_acc_list).argmax()
    with open('result.log', 'a+') as f:
        f.write(str(args.runseed) + ' ' +
                str(np.array(test_acc_easy_list)[best_epoch]) + ' ' +
                str(np.array(test_acc_hard_list)[best_epoch]))
        f.write('\n')
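The loop above relies on train and eval helpers that are not shown. As a rough illustration, a per-task evaluation helper consistent with how its return value is used (averaged with np.mean, one score per task) might look like the sketch below; the model's forward signature and the target attribute are assumptions, not the project's actual code.

from sklearn.metrics import roc_auc_score
import numpy as np
import torch

def eval(args, model, device, loader):
    """Hypothetical sketch: per-task ROC-AUC over a data loader."""
    model.eval()
    y_true, y_scores = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            pred = model(batch)  # assumed forward signature
            y_true.append(batch.go_target_downstream.view(pred.shape).cpu())
            y_scores.append(pred.cpu())
    y_true = torch.cat(y_true, dim=0).numpy()
    y_scores = torch.cat(y_scores, dim=0).numpy()
    aucs = []
    for task in range(y_true.shape[1]):
        # Skip tasks whose labels are all one class in this split.
        if len(np.unique(y_true[:, task])) > 1:
            aucs.append(roc_auc_score(y_true[:, task], y_scores[:, task]))
    return np.array(aucs)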
Example #2
def main(args):
    torch.manual_seed(args.run_seed)
    np.random.seed(args.run_seed)
    device = torch.device(
        "cuda:" +
        str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.run_seed)

    # set up dataset
    if args.dataset == 'bio':
        from model import GraphPred
        root_supervised = '../data/bio/supervised'
        dataset = BioDataset(root_supervised, data_type='supervised')
        args.split = 'species'
        num_tasks = len(dataset[0].go_target_downstream)
    elif args.dataset == 'dblp':
        from model import GraphPred
        root_supervised = '../data/dblp/supervised'
        dataset = DblpDataset(root_supervised, data_type='supervised')
        args.split = 'random'
        num_tasks = len(dataset[0].go_target_downstream)
    elif args.dataset == 'chem':
        from model_chem import GraphPred
        dataset = MoleculeDataset("../data/chem/" + args.down_dataset,
                                  dataset=args.down_dataset)
        args.split = 'scaffold'
        # Bunch of classification tasks
        if args.down_dataset == "tox21":
            num_tasks = 12
        elif args.down_dataset == "hiv":
            num_tasks = 1
        elif args.down_dataset == "pcba":
            num_tasks = 128
        elif args.down_dataset == "muv":
            num_tasks = 17
        elif args.down_dataset == "bace":
            num_tasks = 1
        elif args.down_dataset == "bbbp":
            num_tasks = 1
        elif args.down_dataset == "toxcast":
            num_tasks = 617
        elif args.down_dataset == "sider":
            num_tasks = 27
        elif args.down_dataset == "clintox":
            num_tasks = 2
        else:
            raise ValueError("Invalid downstream dataset name.")
    else:
        raise ValueError("Unknown dataset name.")

    print(dataset)
    args.node_fea_dim = dataset[0].x.shape[1]
    args.edge_fea_dim = dataset[0].edge_attr.shape[1]
    print(args)

    if args.split == "random":
        print("random splitting")
        train_dataset, valid_dataset, test_dataset = random_split(
            dataset, seed=args.seed)
    elif args.split == "species":
        train_val_dataset, test_dataset = species_split(dataset)
        train_dataset, valid_dataset, _ = random_split(train_val_dataset,
                                                       seed=args.seed,
                                                       frac_train=0.85,
                                                       frac_valid=0.15,
                                                       frac_test=0)
        test_dataset_broad, test_dataset_none, _ = random_split(test_dataset,
                                                                seed=args.seed,
                                                                frac_train=0.5,
                                                                frac_valid=0.5,
                                                                frac_test=0)
        print("species splitting")
    elif args.split == "scaffold":
        smiles_list = pd.read_csv('../data/chem/' + args.down_dataset +
                                  '/processed/smiles.csv',
                                  header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = scaffold_split(
            dataset,
            smiles_list,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1)
        print("scaffold")
    else:
        raise ValueError("Unknown split name.")
    if args.dataset == 'chem':
        train_loader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers)
        val_loader = DataLoader(valid_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.num_workers)
        test_loader = DataLoader(test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers)
    else:
        train_loader = DataLoaderAE(train_dataset,
                                    batch_size=args.batch_size,
                                    shuffle=True,
                                    num_workers=args.num_workers)
        val_loader = DataLoaderAE(valid_dataset,
                                  batch_size=10 * args.batch_size,
                                  shuffle=False,
                                  num_workers=args.num_workers)
        if args.split == "species":
            # for species splitting
            test_easy_loader = DataLoaderAE(test_dataset_broad,
                                            batch_size=10 * args.batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers)
            test_hard_loader = DataLoaderAE(test_dataset_none,
                                            batch_size=10 * args.batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers)
        else:
            test_loader = DataLoaderAE(test_dataset,
                                       batch_size=10 * args.batch_size,
                                       shuffle=False,
                                       num_workers=args.num_workers)

    print(train_dataset[0])

    # set up model
    model = GraphPred(args,
                      args.emb_dim,
                      args.edge_fea_dim,
                      num_tasks,
                      drop_ratio=args.dropout_ratio,
                      gnn_type=args.gnn_type)

    if not args.pre_trained_model_file == "":
        model.from_pretrained(
            '../res/' + args.dataset + '/' + args.pre_trained_model_file,
            '../res/' + args.dataset + '/' + args.pool_trained_model_file,
            '../res/' + args.dataset + '/' + args.emb_trained_model_file)
    model.to(device)
    # set up optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.decay)

    train_acc_list = []
    val_acc_list = []
    # for random splitting
    test_acc_list = []
    # for species splitting
    test_acc_easy_list = []
    test_acc_hard_list = []
    if args.dataset == 'chem':
        os.makedirs("../res/" + args.dataset + '/' + args.down_dataset + '/' +
                    "finetune_seed" + str(args.run_seed),
                    exist_ok=True)
        fname = "../res/" + args.dataset + '/' + args.down_dataset + '/' + "finetune_seed" + str(
            args.run_seed) + "/" + args.result_file
        writer = SummaryWriter(fname)
    else:
        os.makedirs("../res/" + args.dataset + '/' + "finetune_seed" +
                    str(args.run_seed),
                    exist_ok=True)

    for epoch in range(1, args.epochs + 1):
        print("====epoch " + str(epoch))
        train_loss = train(args, model, device, train_loader, optimizer)
        print('train loss:', train_loss)
        train_acc = eval(args, model, device, train_loader)
        train_acc_list.append(train_acc)
        print('train auc:', train_acc.mean(0))
        val_acc = eval(args, model, device, val_loader)
        val_acc_list.append(val_acc)
        print('val auc:', val_acc.mean(0))
        if args.split == "species":
            test_acc_easy = eval(args, model, device, test_easy_loader)
            test_acc_hard = eval(args, model, device, test_hard_loader)
            test_acc_easy_list.append(test_acc_easy)
            test_acc_hard_list.append(test_acc_hard)
            print(test_acc_easy.mean(0))
            print(test_acc_hard.mean(0))
        else:
            test_acc = eval(args, model, device, test_loader)
            test_acc_list.append(test_acc)
            print(test_acc.mean(0))
        if args.dataset == 'chem' and not args.result_file == "":  # chem dataset
            writer.add_scalar('data/train auc', train_acc.mean(0), epoch)
            writer.add_scalar('data/val auc', val_acc.mean(0), epoch)
            writer.add_scalar('data/test auc', test_acc.mean(0), epoch)

        print("")

    if not args.result_file == "":
        if args.dataset == 'bio' or args.dataset == 'dblp':
            with open(
                    "../res/" + args.dataset + '/' + "finetune_seed" +
                    str(args.run_seed) + "/" + args.result_file, 'wb') as f:
                if args.split == "random":
                    pickle.dump(
                        {
                            "train": np.array(train_acc_list),
                            "val": np.array(val_acc_list),
                            "test": np.array(test_acc_list)
                        }, f)
                else:
                    pickle.dump(
                        {
                            "train": np.array(train_acc_list),
                            "val": np.array(val_acc_list),
                            "test_easy": np.array(test_acc_easy_list),
                            "test_hard": np.array(test_acc_hard_list)
                        }, f)

            print('saving model...')
            torch.save(
                model.gnn.state_dict(), "../res/" + args.dataset + '/' +
                "finetune_seed" + str(args.run_seed) + "/" + args.result_file +
                '_' + str(epoch) + "_finetuned_gnn.pth")
        else:
            writer.close()
            print('saving model...')
            torch.save(
                model.gnn.state_dict(),
                "../res/" + args.dataset + '/' + args.down_dataset + '/' +
                "finetune_seed" + str(args.run_seed) + "/" + args.result_file +
                '_' + str(epoch) + "_finetuned_gnn.pth")
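Since this variant pickles the per-epoch score arrays rather than appending a result.log line, a small follow-up sketch for the species-split case shows how the file could be read back to report the test scores at the best-validation epoch, mirroring the selection rule used in Example #1. The path is a placeholder, and the arrays are assumed to stack into an epochs-by-tasks matrix.

import pickle
import numpy as np

# Placeholder path; substitute the result_file actually written above.
with open('../res/bio/finetune_seed0/results.pkl', 'rb') as f:
    results = pickle.load(f)

best_epoch = results['val'].mean(axis=1).argmax()
print('best epoch:', best_epoch + 1)  # epochs are 1-indexed in the printouts above
print('test_easy @ best val:', results['test_easy'][best_epoch].mean())
print('test_hard @ best val:', results['test_hard'][best_epoch].mean())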
Example #3
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description=
        'PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate (default: 0.001)')
    parser.add_argument('--decay',
                        type=float,
                        default=0,
                        help='weight decay (default: 0)')
    parser.add_argument(
        '--num_layer',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim',
                        type=int,
                        default=300,
                        help='embedding dimensions (default: 300)')
    parser.add_argument('--dropout_ratio',
                        type=float,
                        default=0.2,
                        help='dropout ratio (default: 0.2)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default="mean",
        help='graph level pooling (sum, mean, max, set2set, attention)')
    parser.add_argument(
        '--JK',
        type=str,
        default="last",
        help=
        'how the node features across layers are combined. last, sum, max or concat'
    )
    parser.add_argument('--input_model_file',
                        type=str,
                        default='',
                        help='filename to read the model (if there is any)')
    parser.add_argument('--output_model_file',
                        type=str,
                        default='',
                        help='filename to output the pre-trained model')
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument('--num_workers',
                        type=int,
                        default=0,
                        help='number of workers for dataset loading')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="Seed for splitting dataset.")
    parser.add_argument('--split',
                        type=str,
                        default="species",
                        help='Random or species split')
    args = parser.parse_args()

    torch.manual_seed(0)
    np.random.seed(0)
    device = torch.device(
        "cuda:" +
        str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    root_supervised = 'dataset/supervised'

    dataset = BioDataset(root_supervised, data_type='supervised')

    if args.split == "random":
        print("random splitting")
        train_dataset, valid_dataset, test_dataset = random_split(
            dataset, seed=args.seed)
        print(train_dataset)
        print(valid_dataset)
        pretrain_dataset = combine_dataset(train_dataset, valid_dataset)
        print(pretrain_dataset)
    elif args.split == "species":
        print("species splitting")
        trainval_dataset, test_dataset = species_split(dataset)
        test_dataset_broad, test_dataset_none, _ = random_split(test_dataset,
                                                                seed=args.seed,
                                                                frac_train=0.5,
                                                                frac_valid=0.5,
                                                                frac_test=0)
        print(trainval_dataset)
        print(test_dataset_broad)
        pretrain_dataset = combine_dataset(trainval_dataset,
                                           test_dataset_broad)
        print(pretrain_dataset)
        #train_dataset, valid_dataset, _ = random_split(trainval_dataset, seed = args.seed, frac_train=0.85, frac_valid=0.15, frac_test=0)
    else:
        raise ValueError("Unknown split name.")

    # train_loader = DataLoader(pretrain_dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers)
    # (Note) Fixed a bug here: DataLoaderFinetune should be used so that center_node_idx is incremented correctly when batching.
    # The results in the paper were obtained with the original PyTorch Geometric dataloader, so results with the corrected dataloader may differ slightly.
    train_loader = DataLoaderFinetune(pretrain_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers)

    num_tasks = len(pretrain_dataset[0].go_target_pretrain)

    #set up model
    model = GNN_graphpred(args.num_layer,
                          args.emb_dim,
                          num_tasks,
                          JK=args.JK,
                          drop_ratio=args.dropout_ratio,
                          graph_pooling=args.graph_pooling,
                          gnn_type=args.gnn_type)
    if not args.input_model_file == "":
        model.from_pretrained(args.input_model_file + ".pth")

    model.to(device)

    #set up optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.decay)
    print(optimizer)

    for epoch in range(1, args.epochs + 1):
        print("====epoch " + str(epoch))

        train_loss = train(args, model, device, train_loader, optimizer)
        print('train loss:', train_loss)

    if not args.output_model_file == "":
        torch.save(model.gnn.state_dict(), args.output_model_file + ".pth")
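The pre-training loop above assumes a train helper that returns the epoch loss. A plausible minimal version, treating the go_target_pretrain labels as a multi-label classification target with a BCE-with-logits loss, is sketched below; again, the model's forward signature and the target attribute are assumptions rather than the project's actual implementation.

import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()

def train(args, model, device, loader, optimizer):
    """Hypothetical sketch: one epoch of supervised pre-training."""
    model.train()
    total_loss = 0.0
    for step, batch in enumerate(loader):
        batch = batch.to(device)
        pred = model(batch)  # assumed forward signature
        y = batch.go_target_pretrain.view(pred.shape).to(torch.float)
        loss = criterion(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / (step + 1)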