Ejemplo n.º 1
0
def main():
    """Score the "repurposing" dataset with a saved model and write a CSV.

    The dataset is loaded twice: once without a transform, and once with an
    ``ONEHOT_ENCODING`` transform that is built from the first instance. A
    ``GNN_graphpred`` model is restored from ``tuned_model/jak3/90.pth``, run
    on one graph at a time, and the collected ``id``, ``cid`` and ``score``
    values are written to ``jak3_score_90.csv``.
    """
    root = "/raid/home/public/dataset_ContextPred_0219/" + "repurposing"

    # The transform needs an already-loaded dataset, hence the two-step load.
    raw_dataset = MoleculeDataset(root=root)
    dataset = MoleculeDataset(
        root=root,
        transform=ONEHOT_ENCODING(dataset=raw_dataset),
    )

    # Inference gains nothing from shuffling; a fixed order keeps the CSV rows
    # reproducible between runs.
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=4,
    )

    model = GNN_graphpred(
        num_layer=5,
        node_feat_dim=154,
        edge_feat_dim=2,
        emb_dim=256,
        num_tasks=1,
        JK="last",
        drop_ratio=0.5,
        graph_pooling="mean",
        gnn_type="gine",
        use_embedding=0,
    )

    # The model is never moved off the CPU, so map the checkpoint there too;
    # without this a checkpoint saved from a GPU cannot be loaded on a
    # CPU-only machine.
    model.load_state_dict(
        torch.load("tuned_model/jak3/90.pth", map_location="cpu"))
    model.eval()

    ids = []
    cids = []
    scores = []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Iteration"):
            pred = model(batch.x, batch.edge_index, batch.edge_attr,
                         batch.batch)
            # NOTE(review): batch.id / batch.cid are stored exactly as the
            # loader yields them; their type is not visible here - confirm
            # they serialise to the CSV as intended.
            ids.append(batch.id)
            cids.append(batch.cid)
            # Store plain Python numbers instead of tensor objects so the CSV
            # holds numeric values rather than tensor reprs.
            # Assumes the model returns a torch tensor - TODO confirm.
            scores.append(pred.squeeze().tolist())

    df = pd.DataFrame({'id': ids, 'cid': cids, 'score': scores})
    df.to_csv('jak3_score_90.csv')
Ejemplo n.º 2
0
def main():
    """Command-line entry point that trains ``GNN_graphpred`` on a molecule dataset.

    Parses the training options, seeds torch and numpy with 0, builds the
    dataset, loader, model and Adam optimizer, calls ``train`` once per epoch
    and, when ``--output_model_file`` is given, saves ``model.gnn``'s state
    dict to ``<output_model_file>.pth``.

    Raises:
        ValueError: if ``--dataset`` is anything other than ``chembl_filtered``.
    """
    # (flag, type, default, help) - registered in this order.
    option_table = [
        ('--device', int, 0, 'which gpu to use if any (default: 0)'),
        ('--batch_size', int, 32,
         'input batch size for training (default: 32)'),
        ('--epochs', int, 100, 'number of epochs to train (default: 100)'),
        ('--lr', float, 0.001, 'learning rate (default: 0.001)'),
        ('--decay', float, 0, 'weight decay (default: 0)'),
        ('--num_layer', int, 5,
         'number of GNN message passing layers (default: 5).'),
        ('--emb_dim', int, 300, 'embedding dimensions (default: 300)'),
        ('--dropout_ratio', float, 0.2, 'dropout ratio (default: 0.2)'),
        ('--graph_pooling', str, "mean",
         'graph level pooling (sum, mean, max, set2set, attention)'),
        ('--JK', str, "last",
         'how the node features across layers are combined. last, sum, max or concat'
         ),
        ('--dataset', str, 'chembl_filtered',
         'root directory of dataset. For now, only classification.'),
        ('--gnn_type', str, "gin", None),
        ('--input_model_file', str, '',
         'filename to read the model (if there is any)'),
        ('--output_model_file', str, '',
         'filename to output the pre-trained model'),
        ('--num_workers', int, 8, 'number of workers for dataset loading'),
    ]
    cli = argparse.ArgumentParser(
        description=
        'PyTorch implementation of pre-training of graph neural networks')
    for flag, kind, fallback, note in option_table:
        cli.add_argument(flag, type=kind, default=fallback, help=note)
    args = cli.parse_args()

    # Fixed seeds for torch, numpy and (when present) every CUDA device.
    torch.manual_seed(0)
    np.random.seed(0)
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{args.device}")
        torch.cuda.manual_seed_all(0)
    else:
        device = torch.device("cpu")

    # Only one dataset is accepted; it fixes the number of output tasks.
    if args.dataset != "chembl_filtered":
        raise ValueError("Invalid dataset name.")
    num_tasks = 1310

    # Data pipeline.
    dataset = MoleculeDataset(f"dataset/{args.dataset}", dataset=args.dataset)
    loader = DataLoader(dataset,
                        batch_size=args.batch_size,
                        shuffle=True,
                        num_workers=args.num_workers)

    # Model, optionally initialised from an earlier checkpoint.
    model = GNN_graphpred(args.num_layer,
                          args.emb_dim,
                          num_tasks,
                          JK=args.JK,
                          drop_ratio=args.dropout_ratio,
                          graph_pooling=args.graph_pooling,
                          gnn_type=args.gnn_type)
    if args.input_model_file != "":
        model.from_pretrained(f"{args.input_model_file}.pth")
    model.to(device)

    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.decay)
    print(optimizer)

    last_epoch = args.epochs
    for epoch in range(1, last_epoch + 1):
        print(f"====epoch {epoch}")
        train(args, model, device, loader, optimizer)

    # Only the GNN part of the model is persisted.
    if args.output_model_file != "":
        torch.save(model.gnn.state_dict(), f"{args.output_model_file}.pth")
Ejemplo n.º 3
0
def main():
    """Fine-tune ``GNN_graphpred`` on the supervised ``BioDataset`` and log test scores.

    Steps: parse the command-line options, seed torch/numpy with
    ``--runseed``, load the dataset from ``dataset/supervised``, print the
    average ``x`` row count and ``edge_index`` column count per graph, split
    the data (``random`` or ``species``), then train and evaluate once per
    epoch. Finally one line is appended to ``result.log`` containing the run
    seed followed by the test score(s) of the epoch with the best mean
    validation score: easy and hard test sets for the species split, the
    single test set for the random split.

    Raises:
        ValueError: if ``--split`` is neither ``random`` nor ``species``.
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description=
        'PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=50,
                        help='number of epochs to train (default: 50)')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='learning rate (default: 0.001)')
    parser.add_argument('--decay', type=float, default=0,
                        help='weight decay (default: 0)')
    parser.add_argument(
        '--num_layer', type=int, default=5,
        help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim', type=int, default=300,
                        help='embedding dimensions (default: 300)')
    parser.add_argument('--dropout_ratio', type=float, default=0.5,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument(
        '--graph_pooling', type=str, default="mean",
        help='graph level pooling (sum, mean, max, set2set, attention)')
    parser.add_argument(
        '--JK', type=str, default="last",
        help=
        'how the node features across layers are combined. last, sum, max or concat'
    )
    parser.add_argument('--model_file', type=str, default='',
                        help='filename to read the model (if there is any)')
    parser.add_argument('--filename', type=str, default='',
                        help='output filename')
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument('--seed', type=int, default=42,
                        help="Seed for splitting dataset.")
    parser.add_argument('--runseed', type=int, default=0,
                        help="Seed for running experiments.")
    parser.add_argument('--num_workers', type=int, default=0,
                        help='number of workers for dataset loading')
    parser.add_argument('--eval_train', type=int, default=0,
                        help='evaluating training or not')
    parser.add_argument('--split', type=str, default="species",
                        help='Random or species split')
    args = parser.parse_args()

    torch.manual_seed(args.runseed)
    np.random.seed(args.runseed)
    device = torch.device(
        "cuda:" +
        str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.runseed)

    root_supervised = 'dataset/supervised'

    dataset = BioDataset(root_supervised, data_type='supervised')

    print(dataset)

    # Report the average per-graph sizes of `x` (rows) and `edge_index`
    # (columns) over the whole dataset.
    node_num = 0
    edge_num = 0
    for d in dataset:
        node_num += d.x.size()[0]
        edge_num += d.edge_index.size()[1]
    print(node_num / len(dataset))
    print(edge_num / len(dataset))
    # A leftover debugging `assert False` used to abort the run at this point,
    # which made everything below unreachable; it has been removed.

    if args.split == "random":
        print("random splitting")
        train_dataset, valid_dataset, test_dataset = random_split(
            dataset, seed=args.seed)
    elif args.split == "species":
        trainval_dataset, test_dataset = species_split(dataset)
        train_dataset, valid_dataset, _ = random_split(trainval_dataset,
                                                       seed=args.seed,
                                                       frac_train=0.85,
                                                       frac_valid=0.15,
                                                       frac_test=0)
        # The held-out test set is halved into the two sets that are later
        # evaluated as "easy" (broad) and "hard" (none).
        test_dataset_broad, test_dataset_none, _ = random_split(test_dataset,
                                                                seed=args.seed,
                                                                frac_train=0.5,
                                                                frac_valid=0.5,
                                                                frac_test=0)
        print("species splitting")
    else:
        raise ValueError("Unknown split name.")

    train_loader = DataLoaderFinetune(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers)
    val_loader = DataLoaderFinetune(valid_dataset,
                                    batch_size=10 * args.batch_size,
                                    shuffle=False,
                                    num_workers=args.num_workers)

    if args.split == "random":
        test_loader = DataLoaderFinetune(test_dataset,
                                         batch_size=10 * args.batch_size,
                                         shuffle=False,
                                         num_workers=args.num_workers)
    else:
        ### for species splitting
        test_easy_loader = DataLoaderFinetune(test_dataset_broad,
                                              batch_size=10 * args.batch_size,
                                              shuffle=False,
                                              num_workers=args.num_workers)
        test_hard_loader = DataLoaderFinetune(test_dataset_none,
                                              batch_size=10 * args.batch_size,
                                              shuffle=False,
                                              num_workers=args.num_workers)

    # One output per downstream target of the first graph.
    num_tasks = len(dataset[0].go_target_downstream)

    print(train_dataset[0])

    #set up model
    model = GNN_graphpred(args.num_layer,
                          args.emb_dim,
                          num_tasks,
                          JK=args.JK,
                          drop_ratio=args.dropout_ratio,
                          graph_pooling=args.graph_pooling,
                          gnn_type=args.gnn_type)

    if not args.model_file == "":
        model.from_pretrained(args.model_file)

    model.to(device)

    #set up optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.decay)

    train_acc_list = []
    val_acc_list = []

    ### for random splitting
    test_acc_list = []

    ### for species splitting
    test_acc_easy_list = []
    test_acc_hard_list = []

    # Start from a clean output file when one was requested.
    if not args.filename == "":
        if os.path.exists(args.filename):
            print("removed existing file!!")
            os.remove(args.filename)

    for epoch in range(1, args.epochs + 1):
        print("====epoch " + str(epoch))

        train(args, model, device, train_loader, optimizer)

        print("====Evaluation")
        if args.eval_train:
            train_acc = eval(args, model, device, train_loader)
        else:
            train_acc = 0
            print("ommitting training evaluation")
        val_acc = eval(args, model, device, val_loader)

        val_acc_list.append(np.mean(val_acc))
        train_acc_list.append(train_acc)

        if args.split == "random":
            test_acc = eval(args, model, device, test_loader)
            # Averaged like the validation score so the logged value is a
            # single number.
            test_acc_list.append(np.mean(test_acc))
        else:
            test_acc_easy = eval(args, model, device, test_easy_loader)
            test_acc_hard = eval(args, model, device, test_hard_loader)
            test_acc_easy_list.append(np.mean(test_acc_easy))
            test_acc_hard_list.append(np.mean(test_acc_hard))
            print(val_acc_list[-1])
            print(test_acc_easy_list[-1])
            print(test_acc_hard_list[-1])

        print("")

    if not val_acc_list:
        # No epoch ran (args.epochs < 1), so there is no best epoch to report.
        return

    # Model selection: report the test score(s) at the epoch whose mean
    # validation score is highest.
    best_epoch = int(np.argmax(val_acc_list))
    if args.split == "random":
        # The easy/hard lists are empty for the random split; indexing them
        # would raise IndexError, so the single test score is logged instead.
        best_scores = [test_acc_list[best_epoch]]
    else:
        best_scores = [
            test_acc_easy_list[best_epoch],
            test_acc_hard_list[best_epoch],
        ]

    with open('result.log', 'a+') as f:
        f.write(' '.join([str(args.runseed)] +
                         [str(score) for score in best_scores]))
        f.write('\n')
Ejemplo n.º 4
0
def main():
    """Command-line entry point that fine-tunes ``GNN_graphpred`` on a molecule dataset.

    Parses the options, seeds torch/numpy with ``--runseed``, loads
    ``dataset/<dataset>``, splits it according to ``--split``, builds the
    loaders, model and Adam optimizer, then trains and evaluates once per
    epoch. When ``--filename`` is non-empty the per-epoch values are also
    written through a ``SummaryWriter``.

    Raises:
        ValueError: for an unknown ``--dataset`` or ``--split`` value.
    """
    # (flag, type, default, help) - registered in this order.
    option_table = [
        ('--device', int, 0, 'which gpu to use if any (default: 0)'),
        ('--batch_size', int, 32,
         'input batch size for training (default: 32)'),
        ('--epochs', int, 100, 'number of epochs to train (default: 100)'),
        ('--lr', float, 0.001, 'learning rate (default: 0.001)'),
        ('--lr_scale', float, 1,
         'relative learning rate for the feature extraction layer (default: 1)'
         ),
        ('--decay', float, 0, 'weight decay (default: 0)'),
        ('--num_layer', int, 5,
         'number of GNN message passing layers (default: 5).'),
        ('--emb_dim', int, 300, 'embedding dimensions (default: 300)'),
        ('--dropout_ratio', float, 0.5, 'dropout ratio (default: 0.5)'),
        ('--graph_pooling', str, "mean",
         'graph level pooling (sum, mean, max, set2set, attention)'),
        ('--JK', str, "last",
         'how the node features across layers are combined. last, sum, max or concat'
         ),
        ('--gnn_type', str, "gin", None),
        ('--dataset', str, 'tox21',
         'root directory of dataset. For now, only classification.'),
        ('--input_model_file', str, '',
         'filename to read the model (if there is any)'),
        ('--filename', str, '', 'output filename'),
        ('--seed', int, 42, "Seed for splitting the dataset."),
        ('--runseed', int, 0,
         "Seed for minibatch selection, random initialization."),
        ('--split', str, "scaffold", "random or scaffold or random_scaffold"),
        ('--eval_train', int, 0, 'evaluating training or not'),
        ('--num_workers', int, 4, 'number of workers for dataset loading'),
    ]
    cli = argparse.ArgumentParser(
        description=
        'PyTorch implementation of pre-training of graph neural networks')
    for flag, kind, fallback, note in option_table:
        cli.add_argument(flag, type=kind, default=fallback, help=note)
    args = cli.parse_args()

    # Seed every random source with the run seed.
    torch.manual_seed(args.runseed)
    np.random.seed(args.runseed)
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{args.device}")
        torch.cuda.manual_seed_all(args.runseed)
    else:
        device = torch.device("cpu")

    # Number of classification tasks for each supported dataset.
    task_counts = {
        "tox21": 12,
        "hiv": 1,
        "pcba": 128,
        "muv": 17,
        "bace": 1,
        "bbbp": 1,
        "toxcast": 617,
        "sider": 27,
        "clintox": 2,
    }
    if args.dataset not in task_counts:
        raise ValueError("Invalid dataset name.")
    num_tasks = task_counts[args.dataset]

    dataset = MoleculeDataset(f"dataset/{args.dataset}", dataset=args.dataset)
    print(dataset)

    # Every split strategy shares the same 80/10/10 fractions.
    split_kwargs = {
        "null_value": 0,
        "frac_train": 0.8,
        "frac_valid": 0.1,
        "frac_test": 0.1,
    }
    smiles_csv = f"dataset/{args.dataset}/processed/smiles.csv"
    if args.split == "scaffold":
        smiles_list = pd.read_csv(smiles_csv, header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = scaffold_split(
            dataset, smiles_list, **split_kwargs)
        print("scaffold")
    elif args.split == "random":
        train_dataset, valid_dataset, test_dataset = random_split(
            dataset, seed=args.seed, **split_kwargs)
        print("random")
    elif args.split == "random_scaffold":
        smiles_list = pd.read_csv(smiles_csv, header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = random_scaffold_split(
            dataset, smiles_list, seed=args.seed, **split_kwargs)
        print("random scaffold")
    else:
        raise ValueError("Invalid split option.")

    print(train_dataset[0])

    def build_loader(subset, shuffle):
        # All three loaders differ only in their data and shuffle flag.
        return DataLoader(subset,
                          batch_size=args.batch_size,
                          shuffle=shuffle,
                          num_workers=args.num_workers)

    train_loader = build_loader(train_dataset, True)
    val_loader = build_loader(valid_dataset, False)
    test_loader = build_loader(test_dataset, False)

    # Model, optionally initialised from a pre-trained file.
    model = GNN_graphpred(args.num_layer,
                          args.emb_dim,
                          num_tasks,
                          JK=args.JK,
                          drop_ratio=args.dropout_ratio,
                          graph_pooling=args.graph_pooling,
                          gnn_type=args.gnn_type)
    if args.input_model_file != "":
        model.from_pretrained(args.input_model_file)
    model.to(device)

    # Parameter groups: the GNN uses the base learning rate, while the
    # (optional) attention pooling and the prediction head use lr * lr_scale.
    scaled_lr = args.lr * args.lr_scale
    model_param_group = [{"params": model.gnn.parameters()}]
    if args.graph_pooling == "attention":
        model_param_group.append({
            "params": model.pool.parameters(),
            "lr": scaled_lr
        })
    model_param_group.append({
        "params": model.graph_pred_linear.parameters(),
        "lr": scaled_lr
    })
    optimizer = optim.Adam(model_param_group,
                           lr=args.lr,
                           weight_decay=args.decay)
    print(optimizer)

    train_acc_list, val_acc_list, test_acc_list = [], [], []

    log_to_writer = args.filename != ""
    if log_to_writer:
        fname = f"runs/finetune_cls_runseed{args.runseed}/{args.filename}"
        # Remove the log directory of a previous run, if any.
        if os.path.exists(fname):
            shutil.rmtree(fname)
            print("removed the existing file.")
        writer = SummaryWriter(fname)

    for epoch in range(1, args.epochs + 1):
        print(f"====epoch {epoch}")

        train(args, model, device, train_loader, optimizer)

        print("====Evaluation")
        if not args.eval_train:
            print("omit the training accuracy computation")
            train_acc = 0
        else:
            train_acc = eval(args, model, device, train_loader)
        val_acc = eval(args, model, device, val_loader)
        test_acc = eval(args, model, device, test_loader)

        print("train: %f val: %f test: %f" % (train_acc, val_acc, test_acc))

        val_acc_list.append(val_acc)
        test_acc_list.append(test_acc)
        train_acc_list.append(train_acc)

        if log_to_writer:
            for tag, value in (('data/train auc', train_acc),
                               ('data/val auc', val_acc),
                               ('data/test auc', test_acc)):
                writer.add_scalar(tag, value, epoch)

        print("")

    if log_to_writer:
        writer.close()
Ejemplo n.º 5
0
def main():
    """Command-line entry point that fine-tunes ``GNN_graphpred`` and tracks four metrics.

    Parses the options, seeds torch/numpy with ``--runseed``, loads the
    dataset from ``/raid/home/public/dataset_ContextPred_0219/<dataset>``
    (with an ``ONEHOT_ENCODING`` transform unless ``--use_original`` is
    non-zero), splits it according to ``--split``, and then trains and
    evaluates once per epoch. ``eval`` is expected to return six values; the
    first four (roc, acc, f1, ap) are recorded and, when ``--filename`` is
    non-empty, written through a ``SummaryWriter``.

    Raises:
        ValueError: for an unknown ``--dataset`` or ``--split`` value.
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description=
        "PyTorch implementation of pre-training of graph neural networks")
    parser.add_argument("--device", type=int, default=0,
                        help="which gpu to use if any (default: 0)")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="input batch size for training (default: 32)")
    parser.add_argument("--epochs", type=int, default=100,
                        help="number of epochs to train (default: 100)")
    parser.add_argument("--lr", type=float, default=0.001,
                        help="learning rate (default: 0.001)")
    parser.add_argument(
        "--lr_scale", type=float, default=1,
        help=
        "relative learning rate for the feature extraction layer (default: 1)")
    parser.add_argument("--decay", type=float, default=0,
                        help="weight decay (default: 0)")
    parser.add_argument(
        "--num_layer", type=int, default=5,
        help="number of GNN message passing layers (default: 5).")
    parser.add_argument("--node_feat_dim", type=int, default=154,
                        help="dimension of the node features.")
    # Help text typo ("ofo") corrected.
    parser.add_argument("--edge_feat_dim", type=int, default=2,
                        help="dimension of the edge features.")
    # The help text previously advertised a default of 300 while the actual
    # default is 256.
    parser.add_argument("--emb_dim", type=int, default=256,
                        help="embedding dimensions (default: 256)")
    parser.add_argument("--dropout_ratio", type=float, default=0.5,
                        help="dropout ratio (default: 0.5)")
    parser.add_argument(
        "--graph_pooling", type=str, default="mean",
        help="graph level pooling (sum, mean, max, set2set, attention)")
    parser.add_argument(
        "--JK", type=str, default="last",
        help=
        "how the node features across layers are combined. last, sum, max or concat"
    )
    parser.add_argument("--gnn_type", type=str, default="gine")
    parser.add_argument(
        "--dataset", type=str, default="bbbp",
        help="root directory of dataset. For now, only classification.")
    parser.add_argument("--input_model_file", type=str, default="",
                        help="filename to read the model (if there is any)")
    parser.add_argument("--filename", type=str, default="",
                        help="output filename")
    parser.add_argument("--seed", type=int, default=42,
                        help="Seed for splitting the dataset.")
    parser.add_argument(
        "--runseed", type=int, default=0,
        help="Seed for minibatch selection, random initialization.")
    # "oversample" is handled below, so it is now listed in the help text.
    parser.add_argument(
        "--split", type=str, default="scaffold",
        help="random or scaffold or random_scaffold or oversample")
    parser.add_argument("--eval_train", type=int, default=0,
                        help="evaluating training or not")
    parser.add_argument("--num_workers", type=int, default=4,
                        help="number of workers for dataset loading")
    parser.add_argument("--use_original", type=int, default=0,
                        help="run benchmark experiment or not")

    args = parser.parse_args()

    torch.manual_seed(args.runseed)
    np.random.seed(args.runseed)
    device = (torch.device("cuda:" + str(args.device))
              if torch.cuda.is_available() else torch.device("cpu"))
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.runseed)

    # Number of classification tasks for each supported dataset.
    task_counts = {
        "tox21": 12,
        "hiv": 1,
        "pcba": 128,
        "muv": 17,
        "bace": 1,
        "bbbp": 1,
        "toxcast": 617,
        "sider": 27,
        "clintox": 2,
        "jak1": 1,
        "jak2": 1,
        "jak3": 1,
        "amu": 1,
        "ellinger": 1,
        "mpro": 1,
    }
    if args.dataset not in task_counts:
        raise ValueError("Invalid dataset name.")
    num_tasks = task_counts[args.dataset]

    # set up dataset
    dataset_root = "/raid/home/public/dataset_ContextPred_0219/" + args.dataset
    dataset = MoleculeDataset(root=dataset_root)
    if args.use_original == 0:
        # The one-hot transform is built from the untransformed dataset, so
        # the dataset is loaded a second time with the transform attached.
        dataset = MoleculeDataset(
            root=dataset_root,
            transform=ONEHOT_ENCODING(dataset=dataset),
        )

    print(dataset)

    # Every split strategy shares the same 80/10/10 fractions.
    split_kwargs = {
        "null_value": 0,
        "frac_train": 0.8,
        "frac_valid": 0.1,
        "frac_test": 0.1,
    }
    smiles_csv = dataset_root + "/processed/smiles.csv"
    if args.split == "scaffold":
        smiles_list = pd.read_csv(smiles_csv, header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = scaffold_split(
            dataset, smiles_list, **split_kwargs)
        print("scaffold")
    elif args.split == "oversample":
        train_dataset, valid_dataset, test_dataset = oversample_split(
            dataset, seed=args.seed, **split_kwargs)
        print("oversample")
    elif args.split == "random":
        train_dataset, valid_dataset, test_dataset = random_split(
            dataset, seed=args.seed, **split_kwargs)
        print("random")
    elif args.split == "random_scaffold":
        smiles_list = pd.read_csv(smiles_csv, header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = random_scaffold_split(
            dataset, smiles_list, seed=args.seed, **split_kwargs)
        print("random scaffold")
    else:
        raise ValueError("Invalid split option.")

    print(train_dataset[0])

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )
    val_loader = DataLoader(
        valid_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )

    # set up model
    model = GNN_graphpred(
        args.num_layer,
        args.node_feat_dim,
        args.edge_feat_dim,
        args.emb_dim,
        num_tasks,
        JK=args.JK,
        drop_ratio=args.dropout_ratio,
        graph_pooling=args.graph_pooling,
        gnn_type=args.gnn_type,
        use_embedding=args.use_original,
    )
    if args.input_model_file != "":
        model.from_pretrained(args.input_model_file + ".pth")

    model.to(device)

    # set up optimizer
    # Parameter groups: the GNN uses the base learning rate, while the
    # (optional) attention pooling and the prediction head use lr * lr_scale.
    scaled_lr = args.lr * args.lr_scale
    model_param_group = [{"params": model.gnn.parameters()}]
    if args.graph_pooling == "attention":
        model_param_group.append({
            "params": model.pool.parameters(),
            "lr": scaled_lr
        })
    model_param_group.append({
        "params": model.graph_pred_linear.parameters(),
        "lr": scaled_lr
    })
    optimizer = optim.Adam(model_param_group,
                           lr=args.lr,
                           weight_decay=args.decay)
    print(optimizer)

    # history[split][metric] holds one value per finished epoch.
    metric_names = ("roc", "acc", "f1", "ap")
    history = {
        split_name: {metric: [] for metric in metric_names}
        for split_name in ("train", "val", "test")
    }

    writer = None
    if args.filename != "":
        fname = ("/raid/home/yoyowu/Weihua_b/BASE_TFlogs/" +
                 str(args.runseed) + "/" + args.filename)
        # delete the directory if there exists one
        if os.path.exists(fname):
            shutil.rmtree(fname)
            print("removed the existing file.")
        writer = SummaryWriter(fname)

    for epoch in range(1, args.epochs + 1):
        print("====epoch " + str(epoch))

        train(args, model, device, train_loader, optimizer)

        print("====Evaluation")
        # eval returns six values; the last two (positive counts) are unused.
        if args.eval_train:
            train_roc, train_acc, train_f1, train_ap, _, _ = eval(
                args, model, device, train_loader)
        else:
            print("omit the training accuracy computation")
            train_roc = 0
            train_acc = 0
            train_f1 = 0
            train_ap = 0
        val_roc, val_acc, val_f1, val_ap, _, _ = eval(args, model, device,
                                                      val_loader)
        test_roc, test_acc, test_f1, test_ap, _, _ = eval(
            args, model, device, test_loader)

        print("train: %f val: %f test auc: %f " %
              (train_roc, val_roc, test_roc))

        epoch_metrics = {
            "train": (train_roc, train_acc, train_f1, train_ap),
            "val": (val_roc, val_acc, val_f1, val_ap),
            "test": (test_roc, test_acc, test_f1, test_ap),
        }
        for split_name, values in epoch_metrics.items():
            for metric, value in zip(metric_names, values):
                history[split_name][metric].append(value)
                if writer is not None:
                    # Tags look like "data/train roc", "data/val f1", ...
                    writer.add_scalar("data/" + split_name + " " + metric,
                                      value, epoch)

        print("")

    if writer is not None:
        writer.close()
Ejemplo n.º 6
0
def main():
    """Run supervised pre-training of a GNN from the command line.

    Parses the training options, builds the pre-training set from the
    requested split of the supervised ``BioDataset``, calls ``train`` once
    per epoch and, when ``--output_model_file`` is non-empty, saves the state
    dict of ``model.gnn`` to ``<output_model_file>.pth``.

    Raises:
        ValueError: if ``--split`` is neither ``random`` nor ``species``.
    """
    # Training settings, declared as (flag, keyword arguments) pairs and
    # registered with the parser in exactly this order.
    option_table = [
        ('--device',
         dict(type=int, default=0,
              help='which gpu to use if any (default: 0)')),
        ('--batch_size',
         dict(type=int, default=32,
              help='input batch size for training (default: 32)')),
        ('--epochs',
         dict(type=int, default=100,
              help='number of epochs to train (default: 100)')),
        ('--lr',
         dict(type=float, default=0.001,
              help='learning rate (default: 0.001)')),
        ('--decay',
         dict(type=float, default=0, help='weight decay (default: 0)')),
        ('--num_layer',
         dict(type=int, default=5,
              help='number of GNN message passing layers (default: 5).')),
        ('--emb_dim',
         dict(type=int, default=300,
              help='embedding dimensions (default: 300)')),
        ('--dropout_ratio',
         dict(type=float, default=0.2, help='dropout ratio (default: 0.2)')),
        ('--graph_pooling',
         dict(type=str, default="mean",
              help='graph level pooling (sum, mean, max, set2set, attention)')),
        ('--JK',
         dict(type=str, default="last",
              help='how the node features across layers are combined. last, sum, max or concat')),
        ('--input_model_file',
         dict(type=str, default='',
              help='filename to read the model (if there is any)')),
        ('--output_model_file',
         dict(type=str, default='',
              help='filename to output the pre-trained model')),
        ('--gnn_type', dict(type=str, default="gin")),
        ('--num_workers',
         dict(type=int, default=0,
              help='number of workers for dataset loading')),
        ('--seed',
         dict(type=int, default=42, help="Seed for splitting dataset.")),
        ('--split',
         dict(type=str, default="species", help='Random or species split')),
    ]
    parser = argparse.ArgumentParser(
        description=
        'PyTorch implementation of pre-training of graph neural networks')
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # The RNG seeds are fixed to 0 here; --seed is only passed to the
    # dataset splitting below.
    torch.manual_seed(0)
    np.random.seed(0)
    if torch.cuda.is_available():
        device = torch.device("cuda:" + str(args.device))
        torch.cuda.manual_seed_all(0)
    else:
        device = torch.device("cpu")

    dataset = BioDataset('dataset/supervised', data_type='supervised')

    # Reject unknown split names up front, then handle the two valid ones.
    if args.split not in ("random", "species"):
        raise ValueError("Unknown split name.")

    if args.split == "random":
        print("random splitting")
        train_part, valid_part, _ = random_split(dataset, seed=args.seed)
        print(train_part)
        print(valid_part)
        # Pre-train on train + valid.
        pretrain_dataset = combine_dataset(train_part, valid_part)
        print(pretrain_dataset)
    else:
        print("species splitting")
        trainval_part, held_out = species_split(dataset)
        # Split the held-out part 50/50; only the first half is used below.
        broad_part, _, _ = random_split(held_out,
                                        seed=args.seed,
                                        frac_train=0.5,
                                        frac_valid=0.5,
                                        frac_test=0)
        print(trainval_part)
        print(broad_part)
        # Pre-train on trainval + the "broad" half of the held-out part.
        pretrain_dataset = combine_dataset(trainval_part, broad_part)
        print(pretrain_dataset)

    # (Note) DataLoaderFinetune is used instead of the plain PyTorch Geometric
    # DataLoader so that center_node_idx is incremented when batching.
    # The results in the paper were obtained with the original PyTorch
    # Geometric dataloader, so results with this loader may differ slightly.
    loader_options = dict(batch_size=args.batch_size,
                          shuffle=True,
                          num_workers=args.num_workers)
    train_loader = DataLoaderFinetune(pretrain_dataset, **loader_options)

    # One output per entry of the first sample's go_target_pretrain.
    num_tasks = len(pretrain_dataset[0].go_target_pretrain)

    # Set up the model.
    model = GNN_graphpred(args.num_layer,
                          args.emb_dim,
                          num_tasks,
                          JK=args.JK,
                          drop_ratio=args.dropout_ratio,
                          graph_pooling=args.graph_pooling,
                          gnn_type=args.gnn_type)
    if args.input_model_file != "":
        model.from_pretrained(args.input_model_file + ".pth")
    model.to(device)

    # Set up the optimizer.
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.decay)
    print(optimizer)

    last_epoch = args.epochs
    for epoch in range(1, last_epoch + 1):
        print("====epoch " + str(epoch))
        # The value returned by train() is not used.
        train(args, model, device, train_loader, optimizer)

    # Only the GNN part of the model is saved.
    if args.output_model_file != "":
        torch.save(model.gnn.state_dict(), args.output_model_file + ".pth")
Ejemplo n.º 7
0
def main():
    """Fine-tune (or train from scratch) a GNN on a molecule dataset.

    Parses the command-line options, loads and splits ``MoleculeDataset``,
    optionally sub-samples the training set (``--semi_ratio``), builds the
    model (loading pretrained weights when ``--input_model_file`` is
    non-empty), and then runs a classification or regression training loop
    depending on the dataset. One line of metrics per epoch is appended to a
    text file under ``results/<dataset>/finetune_<method>/``; an existing
    file of the same name is first renamed with a ``.bak-<timestamp>``
    suffix.

    Raises:
        ValueError: if ``--dataset``, ``--split`` or ``--pretrain_method``
            is not one of the supported values.
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description=
        'PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate (default: 0.001)')
    parser.add_argument(
        '--lr_scale',
        type=float,
        default=1,
        help=
        'relative learning rate for the feature extraction layer (default: 1)')
    parser.add_argument('--decay',
                        type=float,
                        default=0,
                        help='weight decay (default: 0)')
    parser.add_argument(
        '--num_layer',
        type=int,
        default=3,
        help='number of GNN message passing layers (default: 3).')
    # Help text states the actual default (512).
    parser.add_argument('--emb_dim',
                        type=int,
                        default=512,
                        help='embedding dimensions (default: 512)')
    parser.add_argument('--dropout_ratio',
                        type=float,
                        default=0.5,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default="mean",
        help='graph level pooling (sum, mean, max, set2set, attention)')
    parser.add_argument(
        '--JK',
        type=str,
        default="last",
        help=
        'how the node features across layers are combined. last, sum, max or concat'
    )
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument(
        '--dataset',
        type=str,
        default='esol',
        help='root directory of dataset. For now, only classification.')
    parser.add_argument('--input_model_file',
                        type=str,
                        default='',
                        help='filename to read the model (if there is any)')
    parser.add_argument('--filename',
                        type=str,
                        default='',
                        help='output filename')
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help="Seed for splitting the dataset.")
    parser.add_argument(
        '--runseed',
        type=int,
        default=1,
        help="Seed for minibatch selection, random initialization.")
    parser.add_argument('--split',
                        type=str,
                        default="random",
                        help="random or scaffold or random_scaffold")
    parser.add_argument('--eval_train',
                        type=int,
                        default=1,
                        help='evaluating training or not')
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='number of workers for dataset loading')
    parser.add_argument('--aug1',
                        type=str,
                        default='dropN_random',
                        help='augmentation1')
    parser.add_argument('--aug2',
                        type=str,
                        default='dropN_random',
                        help='augmentation2')
    parser.add_argument('--aug_ratio1',
                        type=float,
                        default=0.0,
                        help='aug ratio1')
    parser.add_argument('--aug_ratio2',
                        type=float,
                        default=0.0,
                        help='aug ratio2')
    parser.add_argument('--dataset_load',
                        type=str,
                        default='esol',
                        help='load pretrain model from which dataset.')
    parser.add_argument('--protocol',
                        type=str,
                        default='linear',
                        help='downstream protocol, linear, nonlinear')
    parser.add_argument(
        '--semi_ratio',
        type=float,
        default=1.0,
        help='proportion of labels in semi-supervised settings')
    parser.add_argument('--pretrain_method',
                        type=str,
                        default='local',
                        help='pretrain_method: local, global')
    parser.add_argument('--lamb',
                        type=float,
                        default=0.0,
                        help='hyper para of global-structure loss')
    parser.add_argument('--n_nb',
                        type=int,
                        default=0,
                        help='number of neighbors for  global-structure loss')
    args = parser.parse_args()

    torch.manual_seed(args.runseed)
    np.random.seed(args.runseed)
    device = torch.device(
        "cuda:" +
        str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.runseed)

    # Datasets in this tuple are treated as classification; every other
    # supported dataset is treated as regression.
    classification_datasets = ('tox21', 'hiv', 'pcba', 'muv', 'bace', 'bbbp',
                               'toxcast', 'sider', 'clintox', 'mutag')
    task_type = 'cls' if args.dataset in classification_datasets else 'reg'

    # Number of prediction tasks per supported dataset.
    num_tasks_by_dataset = {
        'tox21': 12,
        'hiv': 1,
        'pcba': 128,
        'muv': 17,
        'bace': 1,
        'bbbp': 1,
        'toxcast': 617,
        'sider': 27,
        'clintox': 2,
        'esol': 1,
        'freesolv': 1,
        'mutag': 1,
    }
    if args.dataset not in num_tasks_by_dataset:
        raise ValueError("Invalid dataset name.")
    num_tasks = num_tasks_by_dataset[args.dataset]

    #set up dataset
    dataset = MoleculeDataset("dataset/" + args.dataset, dataset=args.dataset)

    print('The whole dataset:', dataset)

    if args.split == "scaffold":
        smiles_list = pd.read_csv('dataset/' + args.dataset +
                                  '/processed/smiles.csv',
                                  header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = scaffold_split(
            dataset,
            smiles_list,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1)
        print("scaffold")
    elif args.split == "random":
        train_dataset, valid_dataset, test_dataset = random_split(
            dataset,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1,
            seed=args.seed)
        print("random")
    elif args.split == "random_scaffold":
        smiles_list = pd.read_csv('dataset/' + args.dataset +
                                  '/processed/smiles.csv',
                                  header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = random_scaffold_split(
            dataset,
            smiles_list,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1,
            seed=args.seed)
        print("random scaffold")
    else:
        raise ValueError("Invalid split option.")

    # semi-supervised settings: keep a fixed-seed random subset of the
    # training set when semi_ratio is not exactly 1.0
    if args.semi_ratio != 1.0:
        n_total, n_sample = len(train_dataset), int(
            len(train_dataset) * args.semi_ratio)
        print(
            'sample {:.2f} = {:d} labels for semi-supervised training!'.format(
                args.semi_ratio, n_sample))
        all_idx = list(range(n_total))
        # Fixed seed so the sampled subset does not depend on --runseed.
        random.seed(0)
        idx_semi = random.sample(all_idx, n_sample)
        train_dataset = train_dataset[torch.tensor(idx_semi)]
        print('new train dataset size:', len(train_dataset))
    else:
        print('finetune using all data!')

    # freesolv is the only dataset whose last incomplete training batch is
    # dropped.
    if args.dataset == 'freesolv':
        train_loader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  drop_last=True)
    else:
        train_loader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers)
    val_loader = DataLoader(valid_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.num_workers)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers)

    if args.pretrain_method == 'local':
        load_dir = 'results/' + args.dataset + '/pretrain_local/'
        save_dir = 'results/' + args.dataset + '/finetune_local/'
    elif args.pretrain_method == 'global':
        load_dir = 'results/' + args.dataset + '/pretrain_global/nb_' + str(
            args.n_nb) + '/'
        save_dir = 'results/' + args.dataset + '/finetune_global/nb_' + str(
            args.n_nb) + '/'
    else:
        # load_dir/save_dir would be undefined below, so fail here with a
        # clear error rather than an UnboundLocalError a few lines later.
        raise ValueError('Invalid pretrain_method: %s (expected local or '
                         'global)' % args.pretrain_method)

    # Create the output directory (and parents) without going through a
    # shell; a no-op when it already exists.
    os.makedirs(save_dir, exist_ok=True)

    if not args.input_model_file == "":
        input_model_str = args.dataset_load + '_aug1_' + args.aug1 + '_' + str(
            args.aug_ratio1) + '_aug2_' + args.aug2 + '_' + str(
                args.aug_ratio2) + '_lamb_' + str(args.lamb) + '_do_' + str(
                    args.dropout_ratio) + '_seed_' + str(args.runseed)
        output_model_str = args.dataset + '_semi_' + str(
            args.semi_ratio
        ) + '_protocol_' + args.protocol + '_aug1_' + args.aug1 + '_' + str(
            args.aug_ratio1) + '_aug2_' + args.aug2 + '_' + str(
                args.aug_ratio2) + '_lamb_' + str(args.lamb) + '_do_' + str(
                    args.dropout_ratio) + '_seed_' + str(
                        args.runseed) + '_' + str(args.seed)
    else:
        output_model_str = 'scratch_' + args.dataset + '_semi_' + str(
            args.semi_ratio) + '_protocol_' + args.protocol + '_do_' + str(
                args.dropout_ratio) + '_seed_' + str(args.runseed) + '_' + str(
                    args.seed)

    txtfile = save_dir + output_model_str + ".txt"
    nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
    if os.path.exists(txtfile):
        # Keep the previous log under a timestamped name instead of
        # appending to it. Done with os.replace rather than a shell `mv`
        # because the file name contains user-supplied option values.
        os.replace(txtfile, txtfile + ".bak-%s" % nowTime)

    #set up model
    model = GNN_graphpred(args.num_layer,
                          args.emb_dim,
                          num_tasks,
                          JK=args.JK,
                          drop_ratio=args.dropout_ratio,
                          graph_pooling=args.graph_pooling,
                          gnn_type=args.gnn_type)
    if not args.input_model_file == "":
        model.from_pretrained(load_dir + args.input_model_file +
                              input_model_str + '.pth')
        print('successfully load pretrained model!')
    else:
        print('No pretrain! train from scratch!')

    model.to(device)

    #set up optimizer
    #different learning rate for different part of GNN
    model_param_group = []
    model_param_group.append({"params": model.gnn.parameters()})
    if args.graph_pooling == "attention":
        model_param_group.append({
            "params": model.pool.parameters(),
            "lr": args.lr * args.lr_scale
        })
    model_param_group.append({
        "params": model.graph_pred_linear.parameters(),
        "lr": args.lr * args.lr_scale
    })
    optimizer = optim.Adam(model_param_group,
                           lr=args.lr,
                           weight_decay=args.decay)
    print(optimizer)

    # if linear protocol, fix GNN layers
    if args.protocol == 'linear':
        print("linear protocol, only train the top layer!")
        for name, param in model.named_parameters():
            if 'pred_linear' not in name:
                param.requires_grad = False
    elif args.protocol == 'nonlinear':
        print("finetune protocol, train all the layers!")
    else:
        # An unknown protocol is only reported; no parameter is frozen, so
        # training proceeds with every layer trainable.
        print("invalid protocol!")

    # all task info summary
    print('=========task summary=========')
    print('Dataset: ', args.dataset)
    if args.semi_ratio == 1.0:
        print('full-supervised {:.2f}'.format(args.semi_ratio))
    else:
        print('semi-supervised {:.2f}'.format(args.semi_ratio))
    if args.input_model_file == '':
        print('scratch or finetune: scratch')
        print('loaded model from: - ')
    else:
        print('scratch or finetune: finetune')
        print('loaded model from: ', args.dataset_load)
        print('global_mode: n_nb = ', args.n_nb)
    print('Protocol: ', args.protocol)
    print('task type:', task_type)
    print('=========task summary=========')

    # training based on task type
    if task_type == 'cls':
        with open(txtfile, "a") as myfile:
            myfile.write('epoch: train_auc val_auc test_auc\n')
        # Early-stopping state: stop after `patience` consecutive epochs
        # without an improvement of the validation AUC.
        wait = 0
        best_auc = 0
        patience = 10
        for epoch in range(1, args.epochs + 1):
            print("====epoch " + str(epoch))

            train_cls(args, model, device, train_loader, optimizer)

            print("====Evaluation")
            if args.eval_train:
                train_auc = eval_cls(args, model, device, train_loader)
            else:
                print("omit the training accuracy computation")
                train_auc = 0
            val_auc = eval_cls(args, model, device, val_loader)
            test_auc = eval_cls(args, model, device, test_loader)

            with open(txtfile, "a") as myfile:
                myfile.write(
                    str(int(epoch)) + ': ' + str(train_auc) + ' ' +
                    str(val_auc) + ' ' + str(test_auc) + "\n")

            print("train: %f val: %f test: %f" %
                  (train_auc, val_auc, test_auc))

            # Early stopping
            if np.greater(val_auc, best_auc):
                best_auc = val_auc
                wait = 0
            else:
                wait += 1
                if wait >= patience:
                    print(
                        'Early stop at Epoch: {:d} with final val auc: {:.4f}'.
                        format(epoch, val_auc))
                    break

    elif task_type == 'reg':
        with open(txtfile, "a") as myfile:
            myfile.write(
                'epoch: train_mse train_cor val_mse val_cor test_mse test_cor\n'
            )
        # The regression loop runs for all epochs (no early stopping).
        for epoch in range(1, args.epochs + 1):
            print("====epoch " + str(epoch))

            train(args, model, device, train_loader, optimizer)

            print("====Evaluation")
            if args.eval_train:
                train_mse, train_cor = eval_reg(args, model, device,
                                                train_loader)
            else:
                print("omit the training accuracy computation")
                train_mse, train_cor = 0, 0
            val_mse, val_cor = eval_reg(args, model, device, val_loader)
            test_mse, test_cor = eval_reg(args, model, device, test_loader)

            with open(txtfile, "a") as myfile:
                myfile.write(
                    str(int(epoch)) + ': ' + str(train_mse) + ' ' +
                    str(train_cor) + ' ' + str(val_mse) + ' ' + str(val_cor) +
                    ' ' + str(test_mse) + ' ' + str(test_cor) + "\n")

            print("train: %f val: %f test: %f" %
                  (train_mse, val_mse, test_mse))
            print("train: %f val: %f test: %f" %
                  (train_cor, val_cor, test_cor))
Ejemplo n.º 8
0
    def __init__(self, args):
        """Set up the model from the parsed options in ``args``.

        Copies the configuration values onto the instance, creates the
        single-output ``GNN_graphpred`` (loading pretrained weights when
        ``args.input_model_file`` is non-empty), creates the optional
        auxiliary modules selected by the ``add_*`` / ``interact`` flags, and
        builds an Adam optimizer over one parameter group per module.
        """
        super(Meta_model, self).__init__()

        # Plain configuration values, copied from args under the same name
        # and in this order.
        copied_fields = (
            'dataset', 'num_tasks', 'num_train_tasks', 'num_test_tasks',
            'n_way', 'm_support', 'k_query', 'gnn_type',
            'emb_dim',
            'device',
            'add_similarity', 'add_selfsupervise', 'add_masking',
            'add_weight', 'interact',
            'batch_size',
            'meta_lr', 'update_lr', 'update_step', 'update_step_test',
        )
        for field in copied_fields:
            setattr(self, field, getattr(args, field))

        self.criterion = nn.BCEWithLogitsLoss()

        # Graph-level predictor with a single output.
        self.graph_model = GNN_graphpred(args.num_layer,
                                         args.emb_dim,
                                         1,
                                         JK=args.JK,
                                         drop_ratio=args.dropout_ratio,
                                         graph_pooling=args.graph_pooling,
                                         gnn_type=args.gnn_type)
        if args.input_model_file != "":
            self.graph_model.from_pretrained(args.input_model_file)

        # Optional components, created only when their flag is set.
        if self.add_selfsupervise:
            self.self_criterion = nn.BCEWithLogitsLoss()

        if self.add_masking:
            self.masking_criterion = nn.CrossEntropyLoss()
            self.masking_linear = nn.Linear(self.emb_dim, 119)

        if self.add_similarity:
            self.Attention = attention(self.emb_dim)

        if self.interact:
            self.softmax = nn.Softmax(dim=0)
            self.Interact_attention = Interact_attention(
                self.emb_dim, self.num_train_tasks)

        # Optimizer parameter groups. The pooling (attention pooling only)
        # and prediction-head groups carry an explicit learning rate; the
        # remaining groups use the optimizer default (args.meta_lr).
        scaled_lr = args.lr * args.lr_scale
        param_groups = [{"params": self.graph_model.gnn.parameters()}]
        if args.graph_pooling == "attention":
            param_groups.append({
                "params": self.graph_model.pool.parameters(),
                "lr": scaled_lr,
            })
        param_groups.append({
            "params": self.graph_model.graph_pred_linear.parameters(),
            "lr": scaled_lr,
        })

        # Auxiliary modules join the optimizer only when they were created.
        optional_modules = (
            (self.add_masking, "masking_linear"),
            (self.add_similarity, "Attention"),
            (self.interact, "Interact_attention"),
        )
        for enabled, attr_name in optional_modules:
            if enabled:
                param_groups.append(
                    {"params": getattr(self, attr_name).parameters()})

        self.optimizer = optim.Adam(param_groups,
                                    lr=args.meta_lr,
                                    weight_decay=args.decay)
Ejemplo n.º 9
0
class Meta_model(nn.Module):
    """Meta-learning trainer wrapped around a ``GNN_graphpred`` network.

    For every task the graph model is adapted on the task's support set with
    a single gradient step (see :meth:`update_params`), evaluated on the
    task's query set, and then restored.  ``forward`` combines the per-task
    query losses (optionally re-weighted by an attention module) into one
    meta-loss and steps ``self.optimizer``; ``test`` adapts on held-out tasks
    and reports a ROC-AUC score per task.

    Gradients in :meth:`update_params` are taken without ``create_graph``,
    and the adapted weights are written through ``vector_to_parameters``, so
    no second-order terms flow through the adaptation step.
    """

    def __init__(self, args):
        """Build the graph model, optional auxiliary heads and the optimizer.

        Args:
            args: Namespace-like configuration object.  The attributes read
                here are ``dataset``, ``num_tasks``, ``num_train_tasks``,
                ``num_test_tasks``, ``n_way``, ``m_support``, ``k_query``,
                ``gnn_type``, ``emb_dim``, ``device``, ``add_similarity``,
                ``add_selfsupervise``, ``add_masking``, ``add_weight``,
                ``interact``, ``batch_size``, ``meta_lr``, ``update_lr``,
                ``update_step``, ``update_step_test``, ``num_layer``, ``JK``,
                ``dropout_ratio``, ``graph_pooling``, ``input_model_file``,
                ``lr``, ``lr_scale`` and ``decay``.
        """
        super(Meta_model, self).__init__()

        # Task / episode configuration.
        self.dataset = args.dataset
        self.num_tasks = args.num_tasks
        self.num_train_tasks = args.num_train_tasks
        self.num_test_tasks = args.num_test_tasks
        self.n_way = args.n_way
        self.m_support = args.m_support
        self.k_query = args.k_query
        self.gnn_type = args.gnn_type

        self.emb_dim = args.emb_dim

        # GPU index; only used when CUDA is available (see _get_device).
        self.device = args.device

        # Switches for the optional loss terms / task weighting schemes.
        self.add_similarity = args.add_similarity
        self.add_selfsupervise = args.add_selfsupervise
        self.add_masking = args.add_masking
        self.add_weight = args.add_weight
        self.interact = args.interact

        self.batch_size = args.batch_size

        self.meta_lr = args.meta_lr
        self.update_lr = args.update_lr
        self.update_step = args.update_step
        self.update_step_test = args.update_step_test

        self.criterion = nn.BCEWithLogitsLoss()

        self.graph_model = GNN_graphpred(args.num_layer,
                                         args.emb_dim,
                                         1,
                                         JK=args.JK,
                                         drop_ratio=args.dropout_ratio,
                                         graph_pooling=args.graph_pooling,
                                         gnn_type=args.gnn_type)
        if args.input_model_file != "":
            self.graph_model.from_pretrained(args.input_model_file)

        if self.add_selfsupervise:
            self.self_criterion = nn.BCEWithLogitsLoss()

        if self.add_masking:
            self.masking_criterion = nn.CrossEntropyLoss()
            # 119 output classes; the target is column 0 of ``batch.x``
            # (presumably the atom type -- confirm against the dataset).
            self.masking_linear = nn.Linear(self.emb_dim, 119)

        if self.add_similarity:
            self.Attention = attention(self.emb_dim)

        if self.interact:
            self.softmax = nn.Softmax(dim=0)
            self.Interact_attention = Interact_attention(
                self.emb_dim, self.num_train_tasks)

        # Parameter groups: the pooling and prediction heads get their own
        # learning rate (lr * lr_scale); every other group falls back to the
        # optimizer default (meta_lr).
        scaled_lr = args.lr * args.lr_scale
        model_param_group = [{"params": self.graph_model.gnn.parameters()}]
        if args.graph_pooling == "attention":
            model_param_group.append({
                "params": self.graph_model.pool.parameters(),
                "lr": scaled_lr,
            })
        model_param_group.append({
            "params": self.graph_model.graph_pred_linear.parameters(),
            "lr": scaled_lr,
        })

        if self.add_masking:
            model_param_group.append(
                {"params": self.masking_linear.parameters()})

        if self.add_similarity:
            model_param_group.append({"params": self.Attention.parameters()})

        if self.interact:
            model_param_group.append(
                {"params": self.Interact_attention.parameters()})

        self.optimizer = optim.Adam(model_param_group,
                                    lr=args.meta_lr,
                                    weight_decay=args.decay)

    def _get_device(self):
        """Return the CUDA device ``self.device`` if available, else CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda:" + str(self.device))
        return torch.device("cpu")

    def _batch_loss(self, batch):
        """Run the graph model on ``batch`` and return ``(loss, node_emb)``.

        The loss is the supervised BCE term plus, when enabled, the
        edge-reconstruction term (``add_selfsupervise``) and the masked
        node-attribute term (``add_masking``), each scaled by
        ``self.add_weight``.
        """
        pred, node_emb = self.graph_model(batch.x, batch.edge_index,
                                          batch.edge_attr, batch.batch)
        y = batch.y.view(pred.shape).to(torch.float64)

        # The criterion uses its default reduction; the additional division
        # by the batch size is kept so the loss scale matches earlier runs.
        loss = torch.sum(self.criterion(pred.double(), y)) / pred.size()[0]

        if self.add_selfsupervise:
            negative_edge_index = self.build_negative_edges(batch)
            num_negative = negative_edge_index.size()[1]
            # Skip the term when no negative pair could be sampled (batch
            # without bonds or without any non-bonded pair); it would
            # otherwise divide by zero.
            if num_negative > 0:
                # edge_index is sliced with a stride of 2, i.e. every second
                # column is used as a positive pair.
                positive_score = torch.sum(
                    node_emb[batch.edge_index[0, ::2]] *
                    node_emb[batch.edge_index[1, ::2]],
                    dim=1)
                negative_score = torch.sum(
                    node_emb[negative_edge_index[0]] *
                    node_emb[negative_edge_index[1]],
                    dim=1)
                self_loss = torch.sum(
                    self.self_criterion(positive_score,
                                        torch.ones_like(positive_score)) +
                    self.self_criterion(negative_score,
                                        torch.zeros_like(negative_score))
                ) / num_negative
                loss += self.add_weight * self_loss

        if self.add_masking:
            num_nodes = node_emb.size()[0]
            # Never ask for more nodes than the batch contains;
            # random.sample raises ValueError in that case.
            mask_num = random.sample(range(0, num_nodes),
                                     min(self.batch_size, num_nodes))
            pred_emb = self.masking_linear(node_emb[mask_num])
            loss += self.add_weight * self.masking_criterion(
                pred_emb.double(), batch.x[mask_num, 0])

        return loss, node_emb

    def update_params(self, loss, update_lr):
        """Compute one gradient-descent step for the graph model.

        The model itself is not modified; callers apply the result with
        ``vector_to_parameters``.

        Args:
            loss: Scalar (or one-element) tensor to differentiate.
            update_lr: Step size of the update.

        Returns:
            Tuple ``(grad_vector, new_param_vector)`` of flattened tensors,
            where ``new_param_vector = params - update_lr * grads``.
        """
        grads = torch.autograd.grad(loss, self.graph_model.parameters())
        grad_vector = parameters_to_vector(grads)
        param_vector = parameters_to_vector(self.graph_model.parameters())
        return grad_vector, param_vector - grad_vector * update_lr

    def build_negative_edges(self, batch):
        """Sample node pairs of ``batch`` that are not connected by an edge.

        Positive edges are read from every second column of
        ``batch.edge_index``.  A pair counts as connected regardless of the
        direction in which the edge is stored.  Candidates are all pairs
        ``(i, j)`` with ``i < j`` over the nodes of the whole batch.

        Args:
            batch: Object exposing ``edge_index`` (2 x E tensor) and ``x``
                (one row per node).

        Returns:
            ``torch.long`` tensor of shape ``(2, k)`` on the CPU, where ``k``
            is the number of positive edges, capped at the number of
            available non-edges (so ``k`` may be 0).
        """
        front_nodes = batch.edge_index[0, ::2].tolist()
        back_nodes = batch.edge_index[1, ::2].tolist()

        # Normalise to (low, high) so lookups are direction independent.
        connected = {(min(u, v), max(u, v))
                     for u, v in zip(front_nodes, back_nodes)}

        num_nodes = batch.x.size()[0]
        candidates = [(low, high) for low in range(num_nodes)
                      for high in range(low + 1, num_nodes)
                      if (low, high) not in connected]

        sample_size = min(len(front_nodes), len(candidates))
        sampled = random.sample(candidates, sample_size)

        # Explicit reshape keeps the (2, k) layout even when k == 0.
        return torch.tensor(sampled, dtype=torch.long).reshape(
            len(sampled), 2).t().contiguous()

    def forward(self, epoch):
        """Run ``self.update_step`` meta-updates over all training tasks.

        Args:
            epoch: Unused; kept for interface compatibility.

        Returns:
            An empty list.
        """
        device = self._get_device()
        self.graph_model.train()

        support_loaders = []
        query_loaders = []
        for task in range(self.num_train_tasks):
            dataset = MoleculeDataset("Original_datasets/" + self.dataset +
                                      "/new/" + str(task + 1),
                                      dataset=self.dataset)
            support_dataset, query_dataset = sample_datasets(
                dataset, self.dataset, task, self.n_way, self.m_support,
                self.k_query)
            support_loaders.append(
                DataLoader(support_dataset,
                           batch_size=self.batch_size,
                           shuffle=False,
                           num_workers=1))
            query_loaders.append(
                DataLoader(query_dataset,
                           batch_size=self.batch_size,
                           shuffle=False,
                           num_workers=1))

        track_task_emb = self.add_similarity or self.interact

        for _ in range(0, self.update_step):
            # Snapshot of the meta-parameters, restored after every task.
            old_params = parameters_to_vector(self.graph_model.parameters())

            task_losses = []
            tasks_emb = []

            for task in range(self.num_train_tasks):
                # ---- inner step on the support set ----
                losses_s = torch.tensor([0.0]).to(device)
                if track_task_emb:
                    # Sized by emb_dim so the reshape further down is valid
                    # for any embedding width.
                    one_task_emb = torch.zeros(self.emb_dim).to(device)

                for batch in tqdm(support_loaders[task], desc="Iteration"):
                    batch = batch.to(device)
                    loss, node_emb = self._batch_loss(batch)

                    if track_task_emb:
                        # Running average of the mean node embedding.
                        one_task_emb = torch.div(
                            (one_task_emb + torch.mean(node_emb, 0)), 2.0)

                    losses_s += loss

                if track_task_emb:
                    tasks_emb.append(one_task_emb)

                new_grad, new_params = self.update_params(
                    losses_s, update_lr=self.update_lr)
                vector_to_parameters(new_params, self.graph_model.parameters())

                # ---- query loss with the adapted weights ----
                this_loss_q = torch.tensor([0.0]).to(device)
                for batch in tqdm(query_loaders[task], desc="Iteration"):
                    batch = batch.to(device)
                    # All terms (including the masking term) are part of the
                    # query loss returned here.
                    loss_q, _ = self._batch_loss(batch)
                    this_loss_q += loss_q

                task_losses.append(this_loss_q)

                vector_to_parameters(old_params, self.graph_model.parameters())

            if task_losses:
                losses_q = torch.cat(task_losses, 0)
            else:
                losses_q = torch.tensor([0.0]).to(device)

            if self.add_similarity:
                tasks_emb_new = torch.cat(tasks_emb, 0)
                tasks_emb_new = torch.reshape(
                    tasks_emb_new, (self.num_train_tasks, self.emb_dim))
                tasks_emb_new = tasks_emb_new.detach()

                tasks_weight = self.Attention(tasks_emb_new)
                losses_q = torch.sum(tasks_weight * losses_q)

            elif self.interact:
                tasks_emb_new = torch.cat(tasks_emb, 0).detach()
                represent_emb = self.Interact_attention(tasks_emb_new)
                represent_emb = F.normalize(represent_emb, p=2, dim=0)

                tasks_emb_new = torch.reshape(
                    tasks_emb_new, (self.num_train_tasks, self.emb_dim))
                tasks_emb_new = F.normalize(tasks_emb_new, p=2, dim=1)

                # Product of each normalised task embedding with the
                # normalised representative embedding.
                tasks_weight = torch.mm(
                    tasks_emb_new,
                    torch.reshape(represent_emb, (self.emb_dim, 1)))
                task_probs = self.softmax(tasks_weight)

                # Diagnostic output, kept from the original implementation.
                print(tasks_weight)
                print(task_probs)
                print(losses_q)

                losses_q = torch.sum(losses_q *
                                     torch.transpose(task_probs, 1, 0))
                print(losses_q)

            else:
                losses_q = torch.sum(losses_q)

            loss_q = losses_q / self.num_train_tasks
            self.optimizer.zero_grad()
            loss_q.backward()
            self.optimizer.step()

        return []

    def test(self, support_grads):
        """Adapt on each test task's support set and score its query set.

        Test tasks are taken from the end of the task range
        (``num_tasks``, ``num_tasks - 1``, ...).  The model weights are
        restored after every task.

        Args:
            support_grads: Unused; kept for interface compatibility.

        Returns:
            List with one ROC-AUC value per test task.
        """
        accs = []
        device = self._get_device()
        old_params = parameters_to_vector(self.graph_model.parameters())

        for task in range(self.num_test_tasks):
            task_id = self.num_tasks - task
            print(task_id)
            dataset = MoleculeDataset("Original_datasets/" + self.dataset +
                                      "/new/" + str(task_id),
                                      dataset=self.dataset)
            support_dataset, query_dataset = sample_test_datasets(
                dataset, self.dataset, task_id - 1, self.n_way,
                self.m_support, self.k_query)
            support_loader = DataLoader(support_dataset,
                                        batch_size=self.batch_size,
                                        shuffle=False,
                                        num_workers=1)
            query_loader = DataLoader(query_dataset,
                                      batch_size=self.batch_size,
                                      shuffle=False,
                                      num_workers=1)

            self.graph_model.eval()

            for _ in range(0, self.update_step_test):
                loss = torch.tensor([0.0]).to(device)
                for batch in tqdm(support_loader, desc="Iteration"):
                    batch = batch.to(device)
                    batch_loss, _ = self._batch_loss(batch)
                    loss += batch_loss

                    print(loss)

                new_grad, new_params = self.update_params(
                    loss, update_lr=self.update_lr)
                vector_to_parameters(new_params, self.graph_model.parameters())

            y_true = []
            y_scores = []
            # Pure evaluation: no autograd graph is needed here.
            with torch.no_grad():
                for batch in tqdm(query_loader, desc="Iteration"):
                    batch = batch.to(device)

                    pred, node_emb = self.graph_model(batch.x,
                                                      batch.edge_index,
                                                      batch.edge_attr,
                                                      batch.batch)
                    pred = torch.sigmoid(pred)
                    # NOTE(review): scores are binarised at 0.5 before being
                    # passed to roc_auc_score, so the metric is computed on
                    # hard labels; kept as in the original implementation.
                    pred = torch.where(pred > 0.5, torch.ones_like(pred),
                                       pred)
                    pred = torch.where(pred <= 0.5, torch.zeros_like(pred),
                                       pred)
                    y_scores.append(pred)
                    y_true.append(batch.y.view(pred.shape))

            y_true = torch.cat(y_true, dim=0).cpu().detach().numpy()
            y_scores = torch.cat(y_scores, dim=0).cpu().detach().numpy()

            accs.append(roc_auc_score(y_true, y_scores))

            vector_to_parameters(old_params, self.graph_model.parameters())

        return accs