Exemplo n.º 1
0
def main():
    """Fine-tune ``GNN_graphpred`` on the supervised bio dataset.

    Command-line options select the model, optimisation and split settings.
    The dataset is split either randomly (``--split random``) or by species
    (``--split species``); for the species split the held-out part is halved
    into ``test_dataset_broad`` / ``test_dataset_none``, evaluated through the
    "easy" / "hard" test loaders.  After training, the run seed and the test
    score(s) at the epoch with the highest mean validation score are appended
    to ``result.log``.

    Raises:
        ValueError: if ``--split`` is neither "random" nor "species", or the
            supervised dataset is empty.
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description=
        'PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs',
                        type=int,
                        default=50,
                        help='number of epochs to train (default: 50)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate (default: 0.001)')
    parser.add_argument('--decay',
                        type=float,
                        default=0,
                        help='weight decay (default: 0)')
    parser.add_argument(
        '--num_layer',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim',
                        type=int,
                        default=300,
                        help='embedding dimensions (default: 300)')
    parser.add_argument('--dropout_ratio',
                        type=float,
                        default=0.5,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default="mean",
        help='graph level pooling (sum, mean, max, set2set, attention)')
    parser.add_argument(
        '--JK',
        type=str,
        default="last",
        help=
        'how the node features across layers are combined. last, sum, max or concat'
    )
    parser.add_argument('--model_file',
                        type=str,
                        default='',
                        help='filename to read the model (if there is any)')
    parser.add_argument('--filename',
                        type=str,
                        default='',
                        help='output filename')
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="Seed for splitting dataset.")
    parser.add_argument('--runseed',
                        type=int,
                        default=0,
                        help="Seed for running experiments.")
    parser.add_argument('--num_workers',
                        type=int,
                        default=0,
                        help='number of workers for dataset loading')
    parser.add_argument('--eval_train',
                        type=int,
                        default=0,
                        help='evaluating training or not')
    parser.add_argument('--split',
                        type=str,
                        default="species",
                        help='Random or species split')
    args = parser.parse_args()

    torch.manual_seed(args.runseed)
    np.random.seed(args.runseed)
    device = torch.device(
        "cuda:" +
        str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.runseed)

    root_supervised = 'dataset/supervised'

    dataset = BioDataset(root_supervised, data_type='supervised')

    print(dataset)
    if len(dataset) == 0:
        raise ValueError("Supervised dataset is empty.")

    # Informational dataset statistics: average node and edge count per graph.
    # (Execution continues into training; no debugging stop here.)
    node_num = 0
    edge_num = 0
    for d in dataset:
        node_num += d.x.size()[0]
        edge_num += d.edge_index.size()[1]
    print(node_num / len(dataset))
    print(edge_num / len(dataset))

    if args.split == "random":
        print("random splitting")
        train_dataset, valid_dataset, test_dataset = random_split(
            dataset, seed=args.seed)
    elif args.split == "species":
        trainval_dataset, test_dataset = species_split(dataset)
        train_dataset, valid_dataset, _ = random_split(trainval_dataset,
                                                       seed=args.seed,
                                                       frac_train=0.85,
                                                       frac_valid=0.15,
                                                       frac_test=0)
        # Halve the held-out species into two equally sized test sets.
        test_dataset_broad, test_dataset_none, _ = random_split(test_dataset,
                                                                seed=args.seed,
                                                                frac_train=0.5,
                                                                frac_valid=0.5,
                                                                frac_test=0)
        print("species splitting")
    else:
        raise ValueError("Unknown split name.")

    train_loader = DataLoaderFinetune(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers)
    val_loader = DataLoaderFinetune(valid_dataset,
                                    batch_size=10 * args.batch_size,
                                    shuffle=False,
                                    num_workers=args.num_workers)

    if args.split == "random":
        test_loader = DataLoaderFinetune(test_dataset,
                                         batch_size=10 * args.batch_size,
                                         shuffle=False,
                                         num_workers=args.num_workers)
    else:
        ### for species splitting
        test_easy_loader = DataLoaderFinetune(test_dataset_broad,
                                              batch_size=10 * args.batch_size,
                                              shuffle=False,
                                              num_workers=args.num_workers)
        test_hard_loader = DataLoaderFinetune(test_dataset_none,
                                              batch_size=10 * args.batch_size,
                                              shuffle=False,
                                              num_workers=args.num_workers)

    num_tasks = len(dataset[0].go_target_downstream)

    print(train_dataset[0])

    #set up model
    model = GNN_graphpred(args.num_layer,
                          args.emb_dim,
                          num_tasks,
                          JK=args.JK,
                          drop_ratio=args.dropout_ratio,
                          graph_pooling=args.graph_pooling,
                          gnn_type=args.gnn_type)

    if not args.model_file == "":
        model.from_pretrained(args.model_file)

    model.to(device)

    #set up optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.decay)

    train_acc_list = []
    val_acc_list = []

    ### for random splitting
    test_acc_list = []

    ### for species splitting
    test_acc_easy_list = []
    test_acc_hard_list = []

    if not args.filename == "":
        if os.path.exists(args.filename):
            print("removed existing file!!")
            os.remove(args.filename)

    for epoch in range(1, args.epochs + 1):
        print("====epoch " + str(epoch))

        train(args, model, device, train_loader, optimizer)

        print("====Evaluation")
        if args.eval_train:
            train_acc = eval(args, model, device, train_loader)
        else:
            train_acc = 0
            print("ommitting training evaluation")
        val_acc = eval(args, model, device, val_loader)

        val_acc_list.append(np.mean(val_acc))
        train_acc_list.append(train_acc)

        if args.split == "random":
            test_acc = eval(args, model, device, test_loader)
            # Average over tasks, consistent with the validation scores.
            test_acc_list.append(np.mean(test_acc))
            print(val_acc_list[-1])
            print(test_acc_list[-1])
        else:
            test_acc_easy = eval(args, model, device, test_easy_loader)
            test_acc_hard = eval(args, model, device, test_hard_loader)
            test_acc_easy_list.append(np.mean(test_acc_easy))
            test_acc_hard_list.append(np.mean(test_acc_hard))
            print(val_acc_list[-1])
            print(test_acc_easy_list[-1])
            print(test_acc_hard_list[-1])

        print("")

    if not val_acc_list:
        # No epoch ran (e.g. --epochs 0): there is no best epoch to report.
        print("no epochs were run; nothing written to result.log")
        return

    # Model selection: report test scores from the best-validation epoch.
    best_epoch = int(np.argmax(val_acc_list))
    if args.split == "random":
        best_scores = [test_acc_list[best_epoch]]
    else:
        best_scores = [
            test_acc_easy_list[best_epoch], test_acc_hard_list[best_epoch]
        ]

    with open('result.log', 'a+') as f:
        f.write(' '.join([str(args.runseed)] +
                         [str(score) for score in best_scores]))
        f.write('\n')
Exemplo n.º 2
0
def main(args):
    """Fine-tune a (optionally pre-trained) ``GraphPred`` model downstream.

    Args:
        args: parsed options. ``args.dataset`` selects the data family and
            fixes the split: 'bio' -> species split, 'dblp' -> random split,
            'chem' -> scaffold split of ``args.down_dataset``. As a side
            effect this function sets ``args.split``, ``args.node_fea_dim``
            and ``args.edge_fea_dim``.

    Per-epoch train/val/test scores are printed. For 'chem' they are also
    logged to a ``SummaryWriter`` when ``args.result_file`` is non-empty.
    When ``args.result_file`` is non-empty, the fine-tuned GNN weights are
    saved under the result directory. For 'bio'/'dblp', a pickle of the score
    history is saved there as well.

    Raises:
        ValueError: unknown ``args.dataset``, ``args.down_dataset`` or split.
    """
    torch.manual_seed(args.run_seed)
    np.random.seed(args.run_seed)
    device = torch.device(
        "cuda:" +
        str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.run_seed)

    # Number of classification tasks per chem downstream dataset.
    chem_num_tasks = {
        "tox21": 12,
        "hiv": 1,
        "pcba": 128,
        "muv": 17,
        "bace": 1,
        "bbbp": 1,
        "toxcast": 617,
        "sider": 27,
        "clintox": 2,
    }

    # set up dataset
    if args.dataset == 'bio':
        from model import GraphPred
        root_supervised = '../data/bio/supervised'
        dataset = BioDataset(root_supervised, data_type='supervised')
        args.split = 'species'
        num_tasks = len(dataset[0].go_target_downstream)
    elif args.dataset == 'dblp':
        from model import GraphPred
        root_supervised = '../data/dblp/supervised'
        dataset = DblpDataset(root_supervised, data_type='supervised')
        args.split = 'random'
        num_tasks = len(dataset[0].go_target_downstream)
    elif args.dataset == 'chem':
        from model_chem import GraphPred
        # Validate the name before loading so a typo fails fast.
        if args.down_dataset not in chem_num_tasks:
            raise ValueError("Invalid dataset name.")
        dataset = MoleculeDataset("../data/chem/" + args.down_dataset,
                                  dataset=args.down_dataset)
        args.split = 'scaffold'
        num_tasks = chem_num_tasks[args.down_dataset]
    else:
        # Without this, `dataset` / `GraphPred` would be unbound below.
        raise ValueError("Unknown dataset: " + str(args.dataset))

    print(dataset)
    args.node_fea_dim = dataset[0].x.shape[1]
    args.edge_fea_dim = dataset[0].edge_attr.shape[1]
    print(args)

    if args.split == "random":
        print("random splitting")
        train_dataset, valid_dataset, test_dataset = random_split(
            dataset, seed=args.seed)
    elif args.split == "species":
        train_val_dataset, test_dataset = species_split(dataset)
        train_dataset, valid_dataset, _ = random_split(train_val_dataset,
                                                       seed=args.seed,
                                                       frac_train=0.85,
                                                       frac_valid=0.15,
                                                       frac_test=0)
        test_dataset_broad, test_dataset_none, _ = random_split(test_dataset,
                                                                seed=args.seed,
                                                                frac_train=0.5,
                                                                frac_valid=0.5,
                                                                frac_test=0)
        print("species splitting")
    elif args.split == "scaffold":
        smiles_list = pd.read_csv('../data/chem/' + args.down_dataset +
                                  '/processed/smiles.csv',
                                  header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = scaffold_split(
            dataset,
            smiles_list,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1)
        print("scaffold")
    else:
        raise ValueError("Unknown split name.")
    if args.dataset == 'chem':
        train_loader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers)
        val_loader = DataLoader(valid_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.num_workers)
        test_loader = DataLoader(test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers)
    else:
        train_loader = DataLoaderAE(train_dataset,
                                    batch_size=args.batch_size,
                                    shuffle=True,
                                    num_workers=args.num_workers)
        val_loader = DataLoaderAE(valid_dataset,
                                  batch_size=10 * args.batch_size,
                                  shuffle=False,
                                  num_workers=args.num_workers)
        if args.split == "species":
            # for species splitting
            test_easy_loader = DataLoaderAE(test_dataset_broad,
                                            batch_size=10 * args.batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers)
            test_hard_loader = DataLoaderAE(test_dataset_none,
                                            batch_size=10 * args.batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers)
        else:
            test_loader = DataLoaderAE(test_dataset,
                                       batch_size=10 * args.batch_size,
                                       shuffle=False,
                                       num_workers=args.num_workers)

    print(train_dataset[0])

    # set up model
    model = GraphPred(args,
                      args.emb_dim,
                      args.edge_fea_dim,
                      num_tasks,
                      drop_ratio=args.dropout_ratio,
                      gnn_type=args.gnn_type)

    if not args.pre_trained_model_file == "":
        model.from_pretrained(
            '../res/' + args.dataset + '/' + args.pre_trained_model_file,
            '../res/' + args.dataset + '/' + args.pool_trained_model_file,
            '../res/' + args.dataset + '/' + args.emb_trained_model_file)
    model.to(device)
    # set up optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.decay)

    train_acc_list = []
    val_acc_list = []
    # for random splitting
    test_acc_list = []
    # for species splitting
    test_acc_easy_list = []
    test_acc_hard_list = []

    # All outputs of this run (TensorBoard logs, score pickle, checkpoint).
    if args.dataset == 'chem':
        res_dir = ("../res/" + args.dataset + '/' + args.down_dataset + '/' +
                   "finetune_seed" + str(args.run_seed))
    else:
        res_dir = ("../res/" + args.dataset + '/' + "finetune_seed" +
                   str(args.run_seed))
    os.makedirs(res_dir, exist_ok=True)

    writer = None
    if args.dataset == 'chem':
        writer = SummaryWriter(res_dir + "/" + args.result_file)

    # Keeps the checkpoint name well-defined even if args.epochs < 1.
    epoch = 0
    try:
        for epoch in range(1, args.epochs + 1):
            print("====epoch " + str(epoch))
            train_loss = train(args, model, device, train_loader, optimizer)
            print('train loss:', train_loss)
            train_acc = eval(args, model, device, train_loader)
            train_acc_list.append(train_acc)
            print('train auc:', train_acc.mean(0))
            val_acc = eval(args, model, device, val_loader)
            val_acc_list.append(val_acc)
            print('val auc:', val_acc.mean(0))
            if args.split == "species":
                test_acc_easy = eval(args, model, device, test_easy_loader)
                test_acc_hard = eval(args, model, device, test_hard_loader)
                test_acc_easy_list.append(test_acc_easy)
                test_acc_hard_list.append(test_acc_hard)
                print(test_acc_easy.mean(0))
                print(test_acc_hard.mean(0))
            else:
                test_acc = eval(args, model, device, test_loader)
                test_acc_list.append(test_acc)
                print(test_acc.mean(0))
            if args.dataset == 'chem' and not args.result_file == "":
                writer.add_scalar('data/train auc', train_acc.mean(0), epoch)
                writer.add_scalar('data/val auc', val_acc.mean(0), epoch)
                writer.add_scalar('data/test auc', test_acc.mean(0), epoch)

            print("")
    finally:
        # Always release the writer, even when result_file is empty or
        # training raised.
        if writer is not None:
            writer.close()

    if not args.result_file == "":
        if args.dataset == 'bio' or args.dataset == 'dblp':
            if args.split == "random":
                history = {
                    "train": np.array(train_acc_list),
                    "val": np.array(val_acc_list),
                    "test": np.array(test_acc_list)
                }
            else:
                history = {
                    "train": np.array(train_acc_list),
                    "val": np.array(val_acc_list),
                    "test_easy": np.array(test_acc_easy_list),
                    "test_hard": np.array(test_acc_hard_list)
                }
            with open(res_dir + "/" + args.result_file, 'wb') as f:
                pickle.dump(history, f)

        print('saving model...')
        torch.save(
            model.gnn.state_dict(), res_dir + "/" + args.result_file + '_' +
            str(epoch) + "_finetuned_gnn.pth")
Exemplo n.º 3
0
def main():
    """Fine-tune ``GNN_graphpred`` on a molecular classification dataset.

    Command-line options select the dataset, split strategy ("scaffold",
    "oversample", "random" or "random_scaffold"), model and optimisation
    settings. The GNN body is trained at ``--lr``. The prediction head (and
    the attention pooling layer, when used) is trained at
    ``--lr * --lr_scale``. Per-epoch ROC / accuracy / F1 / AP for the train,
    validation and test sets are printed. When ``--filename`` is given, they
    are also logged to TensorBoard under ``<log_root>/<runseed>/<filename>``.

    Raises:
        ValueError: on an unknown ``--dataset`` or ``--split``.
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description=
        "PyTorch implementation of pre-training of graph neural networks")
    parser.add_argument("--device",
                        type=int,
                        default=0,
                        help="which gpu to use if any (default: 0)")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="input batch size for training (default: 32)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=100,
        help="number of epochs to train (default: 100)",
    )
    parser.add_argument("--lr",
                        type=float,
                        default=0.001,
                        help="learning rate (default: 0.001)")
    parser.add_argument(
        "--lr_scale",
        type=float,
        default=1,
        help=
        "relative learning rate for the feature extraction layer (default: 1)",
    )
    parser.add_argument("--decay",
                        type=float,
                        default=0,
                        help="weight decay (default: 0)")
    parser.add_argument(
        "--num_layer",
        type=int,
        default=5,
        help="number of GNN message passing layers (default: 5).",
    )

    parser.add_argument(
        "--node_feat_dim",
        type=int,
        default=154,
        help="dimension of the node features.",
    )
    parser.add_argument("--edge_feat_dim",
                        type=int,
                        default=2,
                        help="dimension of the edge features.")

    parser.add_argument("--emb_dim",
                        type=int,
                        default=256,
                        help="embedding dimensions (default: 256)")
    parser.add_argument("--dropout_ratio",
                        type=float,
                        default=0.5,
                        help="dropout ratio (default: 0.5)")
    parser.add_argument(
        "--graph_pooling",
        type=str,
        default="mean",
        help="graph level pooling (sum, mean, max, set2set, attention)",
    )
    parser.add_argument(
        "--JK",
        type=str,
        default="last",
        help=
        "how the node features across layers are combined. last, sum, max or concat",
    )
    parser.add_argument("--gnn_type", type=str, default="gine")
    parser.add_argument(
        "--dataset",
        type=str,
        default="bbbp",
        help="root directory of dataset. For now, only classification.",
    )
    parser.add_argument(
        "--input_model_file",
        type=str,
        default="",
        help="filename to read the model (if there is any)",
    )
    parser.add_argument("--filename",
                        type=str,
                        default="",
                        help="output filename")
    parser.add_argument("--seed",
                        type=int,
                        default=42,
                        help="Seed for splitting the dataset.")
    parser.add_argument(
        "--runseed",
        type=int,
        default=0,
        help="Seed for minibatch selection, random initialization.",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="scaffold",
        help="random or scaffold or random_scaffold",
    )
    parser.add_argument("--eval_train",
                        type=int,
                        default=0,
                        help="evaluating training or not")
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="number of workers for dataset loading",
    )
    parser.add_argument("--use_original",
                        type=int,
                        default=0,
                        help="run benchmark experiment or not")
    # Previously hard-coded locations, now overridable (same defaults).
    parser.add_argument(
        "--data_root",
        type=str,
        default="/raid/home/public/dataset_ContextPred_0219/",
        help="directory containing one sub-directory per dataset",
    )
    parser.add_argument(
        "--log_root",
        type=str,
        default="/raid/home/yoyowu/Weihua_b/BASE_TFlogs/",
        help="directory under which TensorBoard logs are written",
    )

    args = parser.parse_args()

    torch.manual_seed(args.runseed)
    np.random.seed(args.runseed)
    device = (torch.device("cuda:" + str(args.device))
              if torch.cuda.is_available() else torch.device("cpu"))
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.runseed)

    # Bunch of classification tasks: dataset name -> number of tasks.
    num_tasks_by_dataset = {
        "tox21": 12,
        "hiv": 1,
        "pcba": 128,
        "muv": 17,
        "bace": 1,
        "bbbp": 1,
        "toxcast": 617,
        "sider": 27,
        "clintox": 2,
        "jak1": 1,
        "jak2": 1,
        "jak3": 1,
        "amu": 1,
        "ellinger": 1,
        "mpro": 1,
    }
    if args.dataset not in num_tasks_by_dataset:
        raise ValueError("Invalid dataset name.")
    num_tasks = num_tasks_by_dataset[args.dataset]

    # set up dataset
    dataset_dir = os.path.join(args.data_root, args.dataset)
    smiles_path = os.path.join(dataset_dir, "processed", "smiles.csv")
    dataset = MoleculeDataset(root=dataset_dir)
    if args.use_original == 0:
        # ONEHOT_ENCODING is constructed from the plain (untransformed)
        # dataset loaded above, so the dataset is loaded a second time with
        # that transform attached.
        dataset = MoleculeDataset(
            root=dataset_dir,
            transform=ONEHOT_ENCODING(dataset=dataset),
        )

    print(dataset)

    if args.split == "scaffold":
        smiles_list = pd.read_csv(smiles_path, header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = scaffold_split(
            dataset,
            smiles_list,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1,
        )
        print("scaffold")
    elif args.split == "oversample":
        train_dataset, valid_dataset, test_dataset = oversample_split(
            dataset,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1,
            seed=args.seed,
        )
        print("oversample")
    elif args.split == "random":
        train_dataset, valid_dataset, test_dataset = random_split(
            dataset,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1,
            seed=args.seed,
        )
        print("random")
    elif args.split == "random_scaffold":
        smiles_list = pd.read_csv(smiles_path, header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = random_scaffold_split(
            dataset,
            smiles_list,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1,
            seed=args.seed,
        )
        print("random scaffold")
    else:
        raise ValueError("Invalid split option.")

    print(train_dataset[0])

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )
    val_loader = DataLoader(
        valid_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )

    # set up model
    model = GNN_graphpred(
        args.num_layer,
        args.node_feat_dim,
        args.edge_feat_dim,
        args.emb_dim,
        num_tasks,
        JK=args.JK,
        drop_ratio=args.dropout_ratio,
        graph_pooling=args.graph_pooling,
        gnn_type=args.gnn_type,
        use_embedding=args.use_original,
    )
    if not args.input_model_file == "":
        model.from_pretrained(args.input_model_file + ".pth")

    model.to(device)

    # set up optimizer
    # GNN body uses the base lr; pooling (attention) and prediction head use
    # lr * lr_scale.
    model_param_group = [{"params": model.gnn.parameters()}]
    if args.graph_pooling == "attention":
        model_param_group.append({
            "params": model.pool.parameters(),
            "lr": args.lr * args.lr_scale
        })
    model_param_group.append({
        "params": model.graph_pred_linear.parameters(),
        "lr": args.lr * args.lr_scale
    })
    optimizer = optim.Adam(model_param_group,
                           lr=args.lr,
                           weight_decay=args.decay)
    print(optimizer)

    train_roc_list = []
    train_acc_list = []
    train_f1_list = []
    train_ap_list = []
    val_roc_list = []
    val_acc_list = []
    val_f1_list = []
    val_ap_list = []
    test_roc_list = []
    test_acc_list = []
    test_f1_list = []
    test_ap_list = []

    writer = None
    if not args.filename == "":
        fname = os.path.join(args.log_root, str(args.runseed), args.filename)
        # delete the directory if there exists one
        if os.path.exists(fname):
            shutil.rmtree(fname)
            print("removed the existing file.")
        writer = SummaryWriter(fname)

    try:
        for epoch in range(1, args.epochs + 1):
            print("====epoch " + str(epoch))

            train(args, model, device, train_loader, optimizer)

            print("====Evaluation")
            # eval returns (roc, acc, f1, ap, num_positive_true,
            # num_positive_scores); the last two are not used here.
            if args.eval_train:
                train_roc, train_acc, train_f1, train_ap, _, _ = eval(
                    args, model, device, train_loader)
            else:
                print("omit the training accuracy computation")
                train_roc = 0
                train_acc = 0
                train_f1 = 0
                train_ap = 0
            val_roc, val_acc, val_f1, val_ap, _, _ = eval(
                args, model, device, val_loader)
            test_roc, test_acc, test_f1, test_ap, _, _ = eval(
                args, model, device, test_loader)

            print("train: %f val: %f test auc: %f " %
                  (train_roc, val_roc, test_roc))
            val_roc_list.append(val_roc)
            val_f1_list.append(val_f1)
            val_acc_list.append(val_acc)
            val_ap_list.append(val_ap)
            test_acc_list.append(test_acc)
            test_roc_list.append(test_roc)
            test_f1_list.append(test_f1)
            test_ap_list.append(test_ap)
            train_acc_list.append(train_acc)
            train_roc_list.append(train_roc)
            train_f1_list.append(train_f1)
            train_ap_list.append(train_ap)

            if writer is not None:
                writer.add_scalar("data/train roc", train_roc, epoch)
                writer.add_scalar("data/train acc", train_acc, epoch)
                writer.add_scalar("data/train f1", train_f1, epoch)
                writer.add_scalar("data/train ap", train_ap, epoch)

                writer.add_scalar("data/val roc", val_roc, epoch)
                writer.add_scalar("data/val acc", val_acc, epoch)
                writer.add_scalar("data/val f1", val_f1, epoch)
                writer.add_scalar("data/val ap", val_ap, epoch)

                writer.add_scalar("data/test roc", test_roc, epoch)
                writer.add_scalar("data/test acc", test_acc, epoch)
                writer.add_scalar("data/test f1", test_f1, epoch)
                writer.add_scalar("data/test ap", test_ap, epoch)

            print("")
    finally:
        # Flush/close the TensorBoard writer even if training fails.
        if writer is not None:
            writer.close()
Exemplo n.º 4
0
def main():
    """Run supervised pre-training of a GNN on the bio dataset.

    Parses command-line options and loads ``dataset/supervised``. It then
    assembles the pre-training set according to ``--split``:

    * ``random``: train + valid parts of a random split.
    * ``species``: the species train/val part plus one half of the held-out
      species split.

    It optimizes a ``GNN_graphpred`` model whose output size equals the
    length of ``go_target_pretrain`` on the first example. When
    ``--output_model_file`` is given, it saves only the encoder
    (``model.gnn``) weights to ``<output_model_file>.pth``.

    Raises:
        ValueError: if ``--split`` is neither ``"random"`` nor ``"species"``.
    """
    parser = argparse.ArgumentParser(
        description='PyTorch implementation of pre-training of graph neural networks')
    add = parser.add_argument
    add('--device', type=int, default=0,
        help='which gpu to use if any (default: 0)')
    add('--batch_size', type=int, default=32,
        help='input batch size for training (default: 32)')
    add('--epochs', type=int, default=100,
        help='number of epochs to train (default: 100)')
    add('--lr', type=float, default=0.001,
        help='learning rate (default: 0.001)')
    add('--decay', type=float, default=0,
        help='weight decay (default: 0)')
    add('--num_layer', type=int, default=5,
        help='number of GNN message passing layers (default: 5).')
    add('--emb_dim', type=int, default=300,
        help='embedding dimensions (default: 300)')
    add('--dropout_ratio', type=float, default=0.2,
        help='dropout ratio (default: 0.2)')
    add('--graph_pooling', type=str, default="mean",
        help='graph level pooling (sum, mean, max, set2set, attention)')
    add('--JK', type=str, default="last",
        help='how the node features across layers are combined. last, sum, max or concat')
    add('--input_model_file', type=str, default='',
        help='filename to read the model (if there is any)')
    add('--output_model_file', type=str, default='',
        help='filename to output the pre-trained model')
    add('--gnn_type', type=str, default="gin")
    add('--num_workers', type=int, default=0,
        help='number of workers for dataset loading')
    add('--seed', type=int, default=42, help="Seed for splitting dataset.")
    add('--split', type=str, default="species", help='Random or species split')
    args = parser.parse_args()

    # Model / minibatch randomness is pinned to 0; --seed only drives splitting.
    torch.manual_seed(0)
    np.random.seed(0)
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.manual_seed_all(0)
    device = torch.device("cuda:" + str(args.device)) if use_cuda else torch.device("cpu")

    dataset = BioDataset('dataset/supervised', data_type='supervised')

    if args.split == "random":
        print("random splitting")
        train_part, valid_part, _ = random_split(dataset, seed=args.seed)
        print(train_part)
        print(valid_part)
        pretrain_dataset = combine_dataset(train_part, valid_part)
    elif args.split == "species":
        print("species splitting")
        trainval_part, held_out = species_split(dataset)
        # Halve the held-out species split; only the first half is reused
        # for pre-training.
        broad_part, _, _ = random_split(held_out,
                                        seed=args.seed,
                                        frac_train=0.5,
                                        frac_valid=0.5,
                                        frac_test=0)
        print(trainval_part)
        print(broad_part)
        pretrain_dataset = combine_dataset(trainval_part, broad_part)
    else:
        raise ValueError("Unknown split name.")
    print(pretrain_dataset)

    # DataLoaderFinetune is used instead of a plain DataLoader so that
    # center_node_idx is offset correctly per batch. The paper's numbers were
    # produced with the original loader, so results may differ slightly.
    train_loader = DataLoaderFinetune(pretrain_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers)

    num_tasks = len(pretrain_dataset[0].go_target_pretrain)

    model = GNN_graphpred(args.num_layer, args.emb_dim, num_tasks,
                          JK=args.JK,
                          drop_ratio=args.dropout_ratio,
                          graph_pooling=args.graph_pooling,
                          gnn_type=args.gnn_type)
    if args.input_model_file:
        model.from_pretrained("{}.pth".format(args.input_model_file))
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)
    print(optimizer)

    for epoch in range(1, args.epochs + 1):
        print("====epoch " + str(epoch))
        train(args, model, device, train_loader, optimizer)

    if args.output_model_file:
        # Only the GNN encoder is persisted, not the prediction head.
        torch.save(model.gnn.state_dict(), "{}.pth".format(args.output_model_file))
Exemplo n.º 5
0
def main():
    """Fine-tune, or train from scratch, a ``GNN_graphpred`` on a molecule dataset.

    Steps:

    1. Parse command-line options and build train/valid/test splits
       (``random``, ``scaffold`` or ``random_scaffold``).
    2. Optionally subsample the training set (``--semi_ratio``).
    3. Optionally load pre-trained weights from
       ``results/<dataset>/pretrain_<method>/...``.
    4. Train with either the ``linear`` protocol (only ``pred_linear``
       parameters are trainable) or the ``nonlinear`` one (all layers).

    Per-epoch metrics are appended to
    ``results/<dataset>/finetune_<method>/...<output_model_str>.txt``. Any
    existing log at that path is first renamed to ``.bak-<timestamp>``.

    Raises:
        ValueError: on an unknown ``--dataset``, ``--split``,
            ``--pretrain_method`` or ``--protocol``.
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description=
        'PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate (default: 0.001)')
    parser.add_argument(
        '--lr_scale',
        type=float,
        default=1,
        help=
        'relative learning rate for the feature extraction layer (default: 1)')
    parser.add_argument('--decay',
                        type=float,
                        default=0,
                        help='weight decay (default: 0)')
    parser.add_argument(
        '--num_layer',
        type=int,
        default=3,
        help='number of GNN message passing layers (default: 3).')
    parser.add_argument('--emb_dim',
                        type=int,
                        default=512,
                        help='embedding dimensions (default: 512)')
    parser.add_argument('--dropout_ratio',
                        type=float,
                        default=0.5,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default="mean",
        help='graph level pooling (sum, mean, max, set2set, attention)')
    parser.add_argument(
        '--JK',
        type=str,
        default="last",
        help=
        'how the node features across layers are combined. last, sum, max or concat'
    )
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument(
        '--dataset',
        type=str,
        default='esol',
        help='root directory of dataset. For now, only classification.')
    parser.add_argument('--input_model_file',
                        type=str,
                        default='',
                        help='filename to read the model (if there is any)')
    parser.add_argument('--filename',
                        type=str,
                        default='',
                        help='output filename')
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help="Seed for splitting the dataset.")
    parser.add_argument(
        '--runseed',
        type=int,
        default=1,
        help="Seed for minibatch selection, random initialization.")
    parser.add_argument('--split',
                        type=str,
                        default="random",
                        help="random or scaffold or random_scaffold")
    parser.add_argument('--eval_train',
                        type=int,
                        default=1,
                        help='evaluating training or not')
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='number of workers for dataset loading')
    parser.add_argument('--aug1',
                        type=str,
                        default='dropN_random',
                        help='augmentation1')
    parser.add_argument('--aug2',
                        type=str,
                        default='dropN_random',
                        help='augmentation2')
    parser.add_argument('--aug_ratio1',
                        type=float,
                        default=0.0,
                        help='aug ratio1')
    parser.add_argument('--aug_ratio2',
                        type=float,
                        default=0.0,
                        help='aug ratio2')
    parser.add_argument('--dataset_load',
                        type=str,
                        default='esol',
                        help='load pretrain model from which dataset.')
    parser.add_argument('--protocol',
                        type=str,
                        default='linear',
                        help='downstream protocol, linear, nonlinear')
    parser.add_argument(
        '--semi_ratio',
        type=float,
        default=1.0,
        help='proportion of labels in semi-supervised settings')
    parser.add_argument('--pretrain_method',
                        type=str,
                        default='local',
                        help='pretrain_method: local, global')
    parser.add_argument('--lamb',
                        type=float,
                        default=0.0,
                        help='hyper para of global-structure loss')
    parser.add_argument('--n_nb',
                        type=int,
                        default=0,
                        help='number of neighbors for  global-structure loss')
    args = parser.parse_args()

    # Fail fast on options that are only consumed much later. Previously an
    # unknown pretrain_method printed a message and then crashed with a
    # NameError on save_dir. An unknown protocol silently trained all layers
    # after already renaming the previous log file.
    if args.pretrain_method not in ('local', 'global'):
        raise ValueError("Invalid pretrain_method: {}".format(
            args.pretrain_method))
    if args.protocol not in ('linear', 'nonlinear'):
        raise ValueError("Invalid protocol: {}".format(args.protocol))

    torch.manual_seed(args.runseed)
    np.random.seed(args.runseed)
    device = torch.device(
        "cuda:" +
        str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.runseed)

    # Number of prediction targets for each supported dataset, grouped by
    # task type (classification vs. regression).
    cls_num_tasks = {
        'tox21': 12,
        'hiv': 1,
        'pcba': 128,
        'muv': 17,
        'bace': 1,
        'bbbp': 1,
        'toxcast': 617,
        'sider': 27,
        'clintox': 2,
        'mutag': 1,
    }
    reg_num_tasks = {'esol': 1, 'freesolv': 1}
    if args.dataset in cls_num_tasks:
        task_type = 'cls'
        num_tasks = cls_num_tasks[args.dataset]
    elif args.dataset in reg_num_tasks:
        task_type = 'reg'
        num_tasks = reg_num_tasks[args.dataset]
    else:
        raise ValueError("Invalid dataset name.")

    #set up dataset
    dataset = MoleculeDataset("dataset/" + args.dataset, dataset=args.dataset)

    print('The whole dataset:', dataset)

    if args.split == "scaffold":
        smiles_list = pd.read_csv('dataset/' + args.dataset +
                                  '/processed/smiles.csv',
                                  header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = scaffold_split(
            dataset,
            smiles_list,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1)
        print("scaffold")
    elif args.split == "random":
        train_dataset, valid_dataset, test_dataset = random_split(
            dataset,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1,
            seed=args.seed)
        print("random")
    elif args.split == "random_scaffold":
        smiles_list = pd.read_csv('dataset/' + args.dataset +
                                  '/processed/smiles.csv',
                                  header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = random_scaffold_split(
            dataset,
            smiles_list,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1,
            seed=args.seed)
        print("random scaffold")
    else:
        raise ValueError("Invalid split option.")

    # semi-supervised settings: keep a fixed (seed 0) random subset of the
    # training examples, independent of --runseed.
    if args.semi_ratio != 1.0:
        n_total = len(train_dataset)
        n_sample = int(n_total * args.semi_ratio)
        print(
            'sample {:.2f} = {:d} labels for semi-supervised training!'.format(
                args.semi_ratio, n_sample))
        random.seed(0)
        idx_semi = random.sample(list(range(n_total)), n_sample)
        train_dataset = train_dataset[torch.tensor(idx_semi)]
        print('new train dataset size:', len(train_dataset))
    else:
        print('finetune using all data!')

    # For freesolv the last incomplete training batch is dropped.
    # NOTE(review): presumably to avoid a tiny final batch; confirm intent.
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              drop_last=(args.dataset == 'freesolv'))
    val_loader = DataLoader(valid_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.num_workers)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers)

    if args.pretrain_method == 'local':
        load_dir = 'results/' + args.dataset + '/pretrain_local/'
        save_dir = 'results/' + args.dataset + '/finetune_local/'
    else:  # 'global' (validated right after argument parsing)
        nb_suffix = 'nb_' + str(args.n_nb) + '/'
        load_dir = 'results/' + args.dataset + '/pretrain_global/' + nb_suffix
        save_dir = 'results/' + args.dataset + '/finetune_global/' + nb_suffix

    # Create the directory directly instead of shelling out to `mkdir -p`
    # with CLI-derived text.
    os.makedirs(save_dir, exist_ok=True)

    if args.input_model_file:
        input_model_str = args.dataset_load + '_aug1_' + args.aug1 + '_' + str(
            args.aug_ratio1) + '_aug2_' + args.aug2 + '_' + str(
                args.aug_ratio2) + '_lamb_' + str(args.lamb) + '_do_' + str(
                    args.dropout_ratio) + '_seed_' + str(args.runseed)
        output_model_str = args.dataset + '_semi_' + str(
            args.semi_ratio
        ) + '_protocol_' + args.protocol + '_aug1_' + args.aug1 + '_' + str(
            args.aug_ratio1) + '_aug2_' + args.aug2 + '_' + str(
                args.aug_ratio2) + '_lamb_' + str(args.lamb) + '_do_' + str(
                    args.dropout_ratio) + '_seed_' + str(
                        args.runseed) + '_' + str(args.seed)
    else:
        output_model_str = 'scratch_' + args.dataset + '_semi_' + str(
            args.semi_ratio) + '_protocol_' + args.protocol + '_do_' + str(
                args.dropout_ratio) + '_seed_' + str(args.runseed) + '_' + str(
                    args.seed)

    txtfile = save_dir + output_model_str + ".txt"
    now_time = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
    if os.path.exists(txtfile):
        # Keep the previous log aside rather than appending to it.
        # os.rename avoids building an `mv` shell command from CLI values.
        os.rename(txtfile, txtfile + ".bak-%s" % now_time)

    #set up model
    model = GNN_graphpred(args.num_layer,
                          args.emb_dim,
                          num_tasks,
                          JK=args.JK,
                          drop_ratio=args.dropout_ratio,
                          graph_pooling=args.graph_pooling,
                          gnn_type=args.gnn_type)
    if args.input_model_file:
        model.from_pretrained(load_dir + args.input_model_file +
                              input_model_str + '.pth')
        print('successfully load pretrained model!')
    else:
        print('No pretrain! train from scratch!')

    model.to(device)

    # Optimizer: the GNN uses the base lr. The attention pooling (if any)
    # and the prediction head use lr * lr_scale.
    model_param_group = [{"params": model.gnn.parameters()}]
    if args.graph_pooling == "attention":
        model_param_group.append({
            "params": model.pool.parameters(),
            "lr": args.lr * args.lr_scale
        })
    model_param_group.append({
        "params": model.graph_pred_linear.parameters(),
        "lr": args.lr * args.lr_scale
    })
    optimizer = optim.Adam(model_param_group,
                           lr=args.lr,
                           weight_decay=args.decay)
    print(optimizer)

    # Linear protocol: freeze everything except the prediction head.
    if args.protocol == 'linear':
        print("linear protocol, only train the top layer!")
        for name, param in model.named_parameters():
            if 'pred_linear' not in name:
                param.requires_grad = False
    else:
        print("finetune protocol, train all the layers!")

    # all task info summary
    print('=========task summary=========')
    print('Dataset: ', args.dataset)
    if args.semi_ratio == 1.0:
        print('full-supervised {:.2f}'.format(args.semi_ratio))
    else:
        print('semi-supervised {:.2f}'.format(args.semi_ratio))
    if args.input_model_file == '':
        print('scratch or finetune: scratch')
        print('loaded model from: - ')
    else:
        print('scratch or finetune: finetune')
        print('loaded model from: ', args.dataset_load)
        print('global_mode: n_nb = ', args.n_nb)
    print('Protocol: ', args.protocol)
    print('task type:', task_type)
    print('=========task summary=========')

    # training based on task type
    if task_type == 'cls':
        with open(txtfile, "a") as myfile:
            myfile.write('epoch: train_auc val_auc test_auc\n')
        wait = 0
        best_auc = 0
        patience = 10
        for epoch in range(1, args.epochs + 1):
            print("====epoch " + str(epoch))

            train_cls(args, model, device, train_loader, optimizer)

            print("====Evaluation")
            if args.eval_train:
                train_auc = eval_cls(args, model, device, train_loader)
            else:
                print("omit the training accuracy computation")
                train_auc = 0
            val_auc = eval_cls(args, model, device, val_loader)
            test_auc = eval_cls(args, model, device, test_loader)

            with open(txtfile, "a") as myfile:
                myfile.write(
                    str(int(epoch)) + ': ' + str(train_auc) + ' ' +
                    str(val_auc) + ' ' + str(test_auc) + "\n")

            print("train: %f val: %f test: %f" %
                  (train_auc, val_auc, test_auc))

            # Early stopping on validation AUC.
            if val_auc > best_auc:
                best_auc = val_auc
                wait = 0
            else:
                wait += 1
                if wait >= patience:
                    print(
                        'Early stop at Epoch: {:d} with final val auc: {:.4f}'.
                        format(epoch, val_auc))
                    break

    elif task_type == 'reg':
        with open(txtfile, "a") as myfile:
            myfile.write(
                'epoch: train_mse train_cor val_mse val_cor test_mse test_cor\n'
            )
        for epoch in range(1, args.epochs + 1):
            print("====epoch " + str(epoch))

            train(args, model, device, train_loader, optimizer)

            print("====Evaluation")
            if args.eval_train:
                train_mse, train_cor = eval_reg(args, model, device,
                                                train_loader)
            else:
                print("omit the training accuracy computation")
                train_mse, train_cor = 0, 0
            val_mse, val_cor = eval_reg(args, model, device, val_loader)
            test_mse, test_cor = eval_reg(args, model, device, test_loader)

            with open(txtfile, "a") as myfile:
                myfile.write(
                    str(int(epoch)) + ': ' + str(train_mse) + ' ' +
                    str(train_cor) + ' ' + str(val_mse) + ' ' + str(val_cor) +
                    ' ' + str(test_mse) + ' ' + str(test_cor) + "\n")

            print("train: %f val: %f test: %f" %
                  (train_mse, val_mse, test_mse))
            print("train: %f val: %f test: %f" %
                  (train_cor, val_cor, test_cor))
Exemplo n.º 6
0
def main():
    """Fine-tune a pre-trained MolGNet encoder (``MolGT_graphpred``) on a downstream dataset.

    Procedure:

    1. Split ``data/downstream/<dataset>`` randomly 70/10/20 using ``--seed``.
    2. Train for ``--epochs`` epochs with Adam and an exponential
       learning-rate decay (``--lr_decay``).
    3. Each time the first validation metric (``result[0]``) improves, save
       the checkpoint to ``runs/<dataset>/model_seed<seed>.pkl``.
    4. Append the test metrics recorded at that epoch to
       ``runs/<dataset>/log.txt``.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='input batch size for training (default: 16)')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.0001,
                        help='learning rate (default: 0.0001)')
    parser.add_argument('--lr_decay', type=float, default=0.995,
                        help='learning rate decay (default: 0.995)')
    parser.add_argument('--lr_scale', type=float, default=1,
                        help='relative learning rate for the feature extraction layer (default: 1)')
    parser.add_argument('--decay', type=float, default=0,
                        help='weight decay (default: 0)')
    parser.add_argument('--loss_type', type=str, default="bce")
    parser.add_argument('--num_layer', type=int, default=5,
                        help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim', type=int, default=768,
                        help='embedding dimensions (default: 768)')
    parser.add_argument('--heads', type=int, default=12,
                        help='multi heads (default: 12)')
    parser.add_argument('--num_message_passing', type=int, default=3,
                        help='message passing steps (default: 3)')
    parser.add_argument('--dropout_ratio', type=float, default=0.5,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument('--graph_pooling', type=str, default="collection",
                        help='graph level pooling (collection,sum, mean, max, set2set, attention)')
    parser.add_argument('--JK', type=str, default="last",
                        help='how the node features across layers are combined. last, sum, max or concat')
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument('--dataset', type=str, default='biosnap',
                        help='root directory of dataset. For now, only classification.')
    parser.add_argument('--input_model_file', type=str, default='pretrained_model/MolGNet.pt',
                        help='filename to read the model (if there is any)')
    parser.add_argument('--filename', type=str, default='', help='output filename')
    parser.add_argument('--seed', type=int, default=42, help="Seed for splitting the dataset.")
    parser.add_argument('--runseed', type=int, default=0, help="Seed for minibatch selection, random initialization.")
    parser.add_argument('--split', type=str, default="scaffold", help="random or scaffold or random_scaffold")
    parser.add_argument('--eval_train', type=int, default=0, help='evaluating training or not')
    parser.add_argument('--num_workers', type=int, default=4, help='number of workers for dataset loading')
    parser.add_argument('--iters', type=int, default=1, help='number of run seeds')
    parser.add_argument('--save', type=str, default=None)
    parser.add_argument('--log_freq', type=int, default=0)
    parser.add_argument('--KFold', type=int, default=5, help='number of folds for cross validation')
    parser.add_argument('--fold', type=int, default=0)
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
                             "of training.")
    parser.add_argument('--cpu', default=False, action="store_true")
    args = parser.parse_args()

    # Use the GPU selected by --device (previously always cuda:0).
    # --cpu forces CPU even when CUDA is available.
    if torch.cuda.is_available() and not args.cpu:
        device = torch.device("cuda:" + str(args.device))
    else:
        device = torch.device("cpu")
    torch.manual_seed(args.runseed)
    np.random.seed(args.runseed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.runseed)

    exp_path = 'runs/{}/'.format(args.dataset)
    os.makedirs(exp_path, exist_ok=True)

    num_tasks = 1
    dataset = MoleculeDataset("data/downstream/" + args.dataset, dataset=args.dataset, transform=None)

    # NOTE: --split is not consulted here; the split is always random 70/10/20.
    train_dataset, valid_dataset, test_dataset = random_split(dataset, null_value=0, frac_train=0.7, frac_valid=0.1,
                                                              frac_test=0.2, seed=args.seed)
    train_loader = DataLoaderMasking(train_dataset, batch_size=args.batch_size, shuffle=True,
                                     num_workers=args.num_workers)
    val_loader = DataLoaderMasking(valid_dataset, batch_size=args.batch_size, shuffle=False,
                                   num_workers=args.num_workers)
    test_loader = DataLoaderMasking(test_dataset, batch_size=args.batch_size, shuffle=False,
                                    num_workers=args.num_workers)
    # set up model
    model = MolGT_graphpred(args.num_layer, args.emb_dim, args.heads, args.num_message_passing, num_tasks,
                            drop_ratio=args.dropout_ratio, graph_pooling=args.graph_pooling)
    if args.input_model_file:
        model.from_pretrained(args.input_model_file)
        print('Pretrained model loaded')
    else:
        print('No pretrain')

    total = sum(param.nelement() for param in model.gnn.parameters())
    print("Number of GNN parameter: %.2fM" % (total / 1e6))

    model.to(device)

    named_params = list(model.gnn.named_parameters())
    if args.graph_pooling == "attention":
        named_params += list(model.pool.named_parameters())
    named_params += list(model.graph_pred_linear.named_parameters())

    # Parameters whose name contains 'pooler' are not handed to the optimizer,
    # so they are never updated. NOTE(review): confirm this is intended.
    param_optimizer = [(n, p) for n, p in named_params if 'pooler' not in n]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    # NOTE: each group sets its own 'weight_decay', and per-group options take
    # precedence over the constructor default below, so --decay has no effect.
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    # Only reported. No warmup schedule consumes it (--warmup_proportion is unused).
    num_train_optimization_steps = int(len(train_dataset) / args.batch_size) * args.epochs
    print(num_train_optimization_steps)
    optimizer = optim.Adam(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.decay)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_decay)

    if args.loss_type == 'bce':
        criterion = nn.BCEWithLogitsLoss()
    elif args.loss_type == 'softmax':
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = FocalLoss(gamma=2, alpha=0.25)

    best_result = []
    best_val = 0
    for epoch in range(1, args.epochs + 1):
        print("====epoch " + str(epoch))
        train(args, model, device, train_loader, optimizer, criterion)
        scheduler.step()
        print("====Evaluation")
        if args.eval_train:
            # Previously computed and then discarded; report it.
            train_result = eval(args, model, device, train_loader)
            print('train: ', train_result)
        else:
            print("omit the training accuracy computation")
        result = eval(args, model, device, val_loader)
        print('validation: ', result)
        test_result = eval(args, model, device, test_loader)
        print('test: ', test_result)
        # Model selection on the first validation metric (AUC, per the log
        # message below).
        if result[0] > best_val:
            best_val = result[0]
            best_result = test_result
            torch.save(model.state_dict(), exp_path + "model_seed{}.pkl".format(args.seed))
            print("save network for epoch:", epoch, best_val)
    with open(exp_path + "log.txt", "a+") as f:
        log = 'Test metrics: auc,prc,f1 is {}, at seed {}'.format(best_result, args.seed)
        print(log)
        f.write(log)
        f.write('\n')
Exemplo n.º 7
0
Arquivo: eval.py Projeto: pyli0628/MPG
def _build_model(args, num_tasks):
    """Construct an untrained ``MolGT_graphpred`` from the parsed CLI options.

    Args:
        args: Namespace produced by ``main``'s argument parser.
        num_tasks: Number of prediction targets passed to the model head.

    Returns:
        A freshly constructed ``MolGT_graphpred`` instance.
    """
    return MolGT_graphpred(args.num_layer,
                           args.emb_dim,
                           args.heads,
                           args.num_message_passing,
                           num_tasks,
                           drop_ratio=args.dropout_ratio,
                           graph_pooling=args.graph_pooling)


def _load_trained_model(args, num_tasks, checkpoint_path, device):
    """Build a model, restore its weights from ``checkpoint_path``, move it to ``device``.

    ``map_location=device`` lets a checkpoint saved on a GPU be restored on a
    CPU-only machine (or when ``--cpu`` is given). The model is switched to
    eval mode so dropout is disabled while computing test metrics.

    Returns:
        The loaded model, on ``device`` and in eval mode.
    """
    model = _build_model(args, num_tasks)
    # NOTE(review): torch.load unpickles the file -- only point --model_dir at
    # trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    # Idempotent; harmless if the eval helpers also call .eval().
    model.eval()
    return model


def main():
    """Evaluate fine-tuned MolGT checkpoints on a drug-drug interaction dataset.

    * ``twosides``: K-fold cross validation (``--KFold`` folds). Fold ``k`` is
      evaluated with the checkpoint ``<model_dir>/twosides/Fold_{k}.pkl``.
    * ``biosnap``: three random 70/10/20 splits seeded ``--seed``,
      ``--seed + 1`` and ``--seed + 2``. Each split is evaluated with the
      checkpoint ``<model_dir>/biosnap/model_seed{seed}.pkl``.

    Per-split test metrics are printed to stdout.

    Raises:
        ValueError: If ``--dataset`` is not one of the supported datasets.
    """
    # Evaluation settings
    parser = argparse.ArgumentParser(
        description=
        'PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0001,
                        help='learning rate (default: 0.0001)')
    parser.add_argument('--lr_decay',
                        type=float,
                        default=0.995,
                        help='learning rate decay (default: 0.995)')
    parser.add_argument(
        '--lr_scale',
        type=float,
        default=1,
        help=
        'relative learning rate for the feature extraction layer (default: 1)')
    parser.add_argument('--decay',
                        type=float,
                        default=0,
                        help='weight decay (default: 0)')
    parser.add_argument('--loss_type', type=str, default="bce")
    parser.add_argument(
        '--num_layer',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim',
                        type=int,
                        default=768,
                        help='embedding dimensions (default: 768)')
    parser.add_argument('--heads',
                        type=int,
                        default=12,
                        help='multi heads (default: 12)')
    parser.add_argument('--num_message_passing',
                        type=int,
                        default=3,
                        help='message passing steps (default: 3)')
    parser.add_argument('--dropout_ratio',
                        type=float,
                        default=0.5,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default="collection",
        help=
        'graph level pooling (collection,sum, mean, max, set2set, attention)')
    parser.add_argument(
        '--JK',
        type=str,
        default="last",
        help=
        'how the node features across layers are combined. last, sum, max or concat'
    )
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument(
        '--dataset',
        type=str,
        default='twosides',
        help='dataset name (twosides or biosnap); data is read from '
        'data/downstream/<dataset>')
    parser.add_argument('--model_dir',
                        type=str,
                        default='pretrained_model/MolGNet.pt',
                        help='root directory of the fine-tuned checkpoints; '
                        'the dataset name is appended to it')
    parser.add_argument('--filename',
                        type=str,
                        default='',
                        help='output filename')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="Seed for splitting the dataset.")
    parser.add_argument(
        '--runseed',
        type=int,
        default=0,
        help="Seed for minibatch selection, random initialization.")
    parser.add_argument('--split',
                        type=str,
                        default="scaffold",
                        help="random or scaffold or random_scaffold")
    parser.add_argument('--eval_train',
                        type=int,
                        default=0,
                        help='evaluating training or not')
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='number of workers for dataset loading')
    parser.add_argument('--iters',
                        type=int,
                        default=1,
                        help='number of run seeds')
    parser.add_argument('--log_file', type=str, default=None)
    parser.add_argument('--log_freq', type=int, default=0)
    parser.add_argument('--KFold',
                        type=int,
                        default=5,
                        help='number of folds for cross validation')
    parser.add_argument('--fold', type=int, default=0)
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
        "of training.")
    parser.add_argument('--cpu', default=False, action="store_true")
    args = parser.parse_args()

    supported_datasets = ('twosides', 'biosnap')
    # Fail fast instead of loading the dataset and then silently doing nothing.
    if args.dataset not in supported_datasets:
        raise ValueError('Unsupported dataset {!r}; expected one of {}'.format(
            args.dataset, ', '.join(supported_datasets)))

    # Honour --device (documented as "which gpu to use"); default 0 == cuda:0.
    use_cuda = torch.cuda.is_available() and not args.cpu
    device = (torch.device("cuda:{}".format(args.device))
              if use_cuda else torch.device("cpu"))

    torch.manual_seed(args.runseed)
    np.random.seed(args.runseed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.runseed)

    # NOTE(review): the default --model_dir looks like a file path
    # ('.../MolGNet.pt'); checkpoints are looked up in <model_dir>/<dataset>/.
    args.model_dir = os.path.join(args.model_dir, args.dataset)

    num_tasks = 1
    dataset = MoleculeDataset("data/downstream/" + args.dataset,
                              dataset=args.dataset,
                              transform=None)

    if args.dataset == 'twosides':
        print('Run {}-fold cross validation'.format(args.KFold))
        for fold in range(args.KFold):
            # Use the same fold count for splitting as for the loop so the
            # fold indices refer to a consistent partition of the data.
            _, test_dataset = cv_random_split(dataset, fold, args.KFold)
            test_loader = DataLoaderMasking(test_dataset,
                                            batch_size=args.batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers)
            model = _load_trained_model(
                args, num_tasks,
                os.path.join(args.model_dir, 'Fold_{}.pkl'.format(fold)),
                device)

            acc, f1, prec, rec = eval_twosides(model, device, test_loader)
            print('Dataset:{}, Fold:{}, precision:{}, recall:{}, F1:{}'.format(
                args.dataset, fold, prec, rec, f1))

    elif args.dataset == 'biosnap':
        print('Run three random split')
        for i in range(3):
            seed = args.seed + i
            _, _, test_dataset = random_split(dataset,
                                              null_value=0,
                                              frac_train=0.7,
                                              frac_valid=0.1,
                                              frac_test=0.2,
                                              seed=seed)
            test_loader = DataLoaderMasking(test_dataset,
                                            batch_size=args.batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers)
            model = _load_trained_model(
                args, num_tasks,
                os.path.join(args.model_dir, 'model_seed{}.pkl'.format(seed)),
                device)
            auc, prc, f1 = eval_biosnap(model, device, test_loader)
            print('Dataset:{}, Fold:{}, AUC:{}, PRC:{}, F1:{}'.format(
                args.dataset, i, auc, prc, f1))