Example No. 1
def main():
    dataset = MoleculeDataset(
        root="/raid/home/public/dataset_ContextPred_0219/" + "repurposing")

    dataset = MoleculeDataset(
        root="/raid/home/public/dataset_ContextPred_0219/" + "repurposing",
        transform=ONEHOT_ENCODING(dataset=dataset),
    )

    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        num_workers=4,
    )

    model = GNN_graphpred(
        num_layer=5,
        node_feat_dim=154,
        edge_feat_dim=2,
        emb_dim=256,
        num_tasks=1,
        JK="last",
        drop_ratio=0.5,
        graph_pooling="mean",
        gnn_type="gine",
        use_embedding=0,
    )

    model.load_state_dict(torch.load("tuned_model/jak3/90.pth"))
    model.eval()
    ids = []   # renamed from `id` to avoid shadowing the built-in
    cid = []
    score = []
    fields = ['id', 'cid', 'score']
    for step, batch in enumerate(tqdm(loader, desc="Iteration")):

        with torch.no_grad():
            pred = model(batch.x, batch.edge_index, batch.edge_attr,
                         batch.batch)
            ids.append(batch.id)
            cid.append(batch.cid)
            score.append(pred)

    results = {'id': ids, 'cid': cid, 'score': score}   # renamed from `dict` to avoid shadowing the built-in
    df = pd.DataFrame(results, columns=fields)

    df.to_csv('jak3_score_90.csv')
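
This example is shown without its imports. The sketch below lists what it appears to rely on and how the per-batch tensors could be converted to plain values before writing the CSV; the module paths for the repository-specific classes (MoleculeDataset, ONEHOT_ENCODING, GNN_graphpred) are assumptions, not confirmed.

# Minimal sketch of the assumed imports (repo-specific paths are guesses):
import torch
import pandas as pd
from tqdm import tqdm
# from torch_geometric.data import DataLoader              # assumed PyG DataLoader
# from loader import MoleculeDataset, ONEHOT_ENCODING      # assumed repo modules
# from model import GNN_graphpred                          # assumed repo module

# Appending raw tensors makes each CSV cell a tensor repr; converting to plain
# Python values first keeps the output readable (batch_size is 1 here), e.g.:
#     ids.append(batch.id.item())
#     score.append(pred.squeeze().item())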
Example No. 2
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='input batch size for training (default: 256)')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='learning rate (default: 0.001)')
    parser.add_argument('--decay', type=float, default=0,
                        help='weight decay (default: 0)')
    parser.add_argument('--num_layer', type=int, default=5,
                        help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim', type=int, default=300,
                        help='embedding dimensions (default: 300)')
    parser.add_argument('--dropout_ratio', type=float, default=0,
                        help='dropout ratio (default: 0)')
    parser.add_argument('--JK', type=str, default="last",
                        help='how the node features across layers are combined. last, sum, max or concat')
    parser.add_argument('--dataset', type=str, default = 'zinc_standard_agent', help='root directory of dataset. For now, only classification.')
    parser.add_argument('--output_model_file', type = str, default = '', help='filename to output the pre-trained model')
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument('--num_workers', type=int, default = 8, help='number of workers for dataset loading')
    args = parser.parse_args()


    torch.manual_seed(0)
    np.random.seed(0)
    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    #set up dataset
    dataset = MoleculeDataset("dataset/" + args.dataset, dataset=args.dataset, transform = NegativeEdge())

    print(dataset[0])

    loader = DataLoaderAE(dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers)

    #set up model
    model = GNN(args.num_layer, args.emb_dim, JK = args.JK, drop_ratio = args.dropout_ratio, gnn_type = args.gnn_type)
    
    model.to(device)

    #set up optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)   
    print(optimizer)

    for epoch in range(1, args.epochs+1):
        print("====epoch " + str(epoch))
    
        train_acc, train_loss = train(args, model, device, loader, optimizer)

        print(train_acc)
        print(train_loss)

    if not args.output_model_file == "":
        torch.save(model.state_dict(), args.output_model_file + ".pth")
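
Example No. 2 pre-trains with a NegativeEdge transform and DataLoaderAE, i.e. an edge-reconstruction objective. As a rough illustration only (the repository's actual transform is not shown here), a negative-edge sampler could look like the sketch below; the attribute name negative_edge_index is an assumption.

import torch

def sample_negative_edges(data, num_neg=None):
    # Sample node pairs that are NOT connected, so a training step can contrast
    # them against the observed (positive) edges of the molecular graph.
    num_nodes = data.num_nodes
    num_neg = num_neg if num_neg is not None else data.edge_index.size(1) // 2
    existing = set(map(tuple, data.edge_index.t().tolist()))
    neg, tries, max_tries = [], 0, 20 * max(num_neg, 1)
    while len(neg) < num_neg and tries < max_tries:
        tries += 1
        i, j = torch.randint(0, num_nodes, (2,)).tolist()
        if i != j and (i, j) not in existing:
            neg.append((i, j))
    data.negative_edge_index = torch.tensor(neg, dtype=torch.long).t()
    return data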
Example No. 3
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description=
        'PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='input batch size for training (default: 256)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate (default: 0.001)')
    parser.add_argument('--decay',
                        type=float,
                        default=0,
                        help='weight decay (default: 0)')
    parser.add_argument(
        '--num_layer',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim',
                        type=int,
                        default=300,
                        help='embedding dimensions (default: 300)')
    parser.add_argument('--dropout_ratio',
                        type=float,
                        default=0,
                        help='dropout ratio (default: 0)')
    parser.add_argument('--mask_rate',
                        type=float,
                        default=0.15,
                        help='mask rate (default: 0.15)')
    parser.add_argument(
        '--mask_edge',
        type=int,
        default=0,
        help='whether to mask edges or not together with atoms')
    parser.add_argument(
        '--JK',
        type=str,
        default="last",
        help=
        'how the node features are combined across layers. last, sum, max or concat'
    )
    parser.add_argument('--dataset',
                        type=str,
                        default='zinc_standard_agent',
                        help='root directory of dataset for pretraining')
    parser.add_argument('--output_model_file',
                        type=str,
                        default='',
                        help='filename to output the model')
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help="Seed for splitting dataset.")
    parser.add_argument('--num_workers',
                        type=int,
                        default=8,
                        help='number of workers for dataset loading')
    args = parser.parse_args()

    torch.manual_seed(0)
    np.random.seed(0)
    device = torch.device(
        "cuda:" +
        str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    print("num layer: %d mask rate: %f mask edge: %d" %
          (args.num_layer, args.mask_rate, args.mask_edge))

    #set up dataset and transform function.
    dataset = MoleculeDataset("dataset/" + args.dataset,
                              dataset=args.dataset,
                              transform=MaskAtom(num_atom_type=119,
                                                 num_edge_type=5,
                                                 mask_rate=args.mask_rate,
                                                 mask_edge=args.mask_edge))

    loader = DataLoaderMasking(dataset,
                               batch_size=args.batch_size,
                               shuffle=True,
                               num_workers=args.num_workers)

    #set up models, one for pre-training and one for context embeddings
    model = GNN(args.num_layer,
                args.emb_dim,
                JK=args.JK,
                drop_ratio=args.dropout_ratio,
                gnn_type=args.gnn_type).to(device)
    linear_pred_atoms = torch.nn.Linear(args.emb_dim, 119).to(device)
    linear_pred_bonds = torch.nn.Linear(args.emb_dim, 4).to(device)

    model_list = [model, linear_pred_atoms, linear_pred_bonds]

    #set up optimizers
    optimizer_model = optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.decay)
    optimizer_linear_pred_atoms = optim.Adam(linear_pred_atoms.parameters(),
                                             lr=args.lr,
                                             weight_decay=args.decay)
    optimizer_linear_pred_bonds = optim.Adam(linear_pred_bonds.parameters(),
                                             lr=args.lr,
                                             weight_decay=args.decay)

    optimizer_list = [
        optimizer_model, optimizer_linear_pred_atoms,
        optimizer_linear_pred_bonds
    ]

    for epoch in range(1, args.epochs + 1):
        print("====epoch " + str(epoch))

        train_loss, train_acc_atom, train_acc_bond = train(
            args, model_list, loader, optimizer_list, device)
        print(train_loss, train_acc_atom, train_acc_bond)

    if not args.output_model_file == "":
        torch.save(model.state_dict(), args.output_model_file + ".pth")
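
The training loop above relies on a train() function that is not shown. A heavily simplified sketch of the masked-atom part of that objective, assuming the MaskAtom transform exposes masked_atom_indices and mask_node_label on each batch (attribute names are assumptions):

import torch.nn.functional as F

def masked_atom_loss(model, linear_pred_atoms, batch):
    # Encode all nodes, pick out the masked ones, and classify their original
    # atom type; a masked-edge term would follow the same pattern.
    node_rep = model(batch.x, batch.edge_index, batch.edge_attr)
    pred = linear_pred_atoms(node_rep[batch.masked_atom_indices])
    target = batch.mask_node_label[:, 0]
    return F.cross_entropy(pred, target)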
Example No. 4
        idx_list.append(idx)
    train_idx, val_idx = idx_list[fold_idx]

    train_dataset = dataset[torch.tensor(train_idx)]
    valid_dataset = dataset[torch.tensor(val_idx)]

    return train_dataset, valid_dataset


if __name__ == "__main__":
    from loader import MoleculeDataset
    from rdkit import Chem
    import pandas as pd

    # test scaffold_split
    dataset = MoleculeDataset('dataset/tox21', dataset='tox21')
    smiles_list = pd.read_csv('dataset/tox21/processed/smiles.csv',
                              header=None)[0].tolist()

    train_dataset, valid_dataset, test_dataset = scaffold_split(dataset,
                                                                smiles_list,
                                                                task_idx=None,
                                                                null_value=0,
                                                                frac_train=0.8,
                                                                frac_valid=0.1,
                                                                frac_test=0.1)
    # train_dataset, valid_dataset, test_dataset = random_scaffold_split(dataset, smiles_list, task_idx=None, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1, seed = 0)
    unique_ids = set(train_dataset.data.id.tolist() +
                     valid_dataset.data.id.tolist() +
                     test_dataset.data.id.tolist())
    assert len(unique_ids) == len(dataset)  # check that we did not have any missing or overlapping examples
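
scaffold_split itself is not reproduced here; it groups molecules by their Bemis-Murcko scaffold. For reference, the grouping key can be computed directly with RDKit (illustrative only, not the repository's exact code):

from rdkit.Chem.Scaffolds import MurckoScaffold

def scaffold_of(smiles, include_chirality=False):
    # Bemis-Murcko scaffold: the ring systems plus the linkers connecting them.
    return MurckoScaffold.MurckoScaffoldSmiles(smiles=smiles,
                                               includeChirality=include_chirality)

# e.g. scaffold_of('CC(=O)Oc1ccccc1C(=O)O') returns 'c1ccccc1'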
Example No. 5
        return data

    def __repr__(self):
        reprs = "{}(num_atom_features={}, num_edge_type={}, mask_rate={}, mask_edge={})"
        return reprs.format(
            self.__class__.__name__,
            self.num_atom_features,
            self.num_edge_type,
            self.mask_rate,
            self.mask_edge,
        )


if __name__ == "__main__":
    transform = NegativeEdge()
    dataset = MoleculeDataset("dataset/tox21", dataset="tox21")
    transform(dataset[0])
    """
    # TODO(Bowen): more unit tests
    # test ExtractSubstructureContextPair

    smiles = 'C#Cc1c(O)c(Cl)cc(/C=C/N)c1S'
    m = AllChem.MolFromSmiles(smiles)
    data = mol_to_graph_data_obj_simple(m)
    root_idx = 13

    # 0 hops: no substructure or context. We just test the absence of x attr
    transform = ExtractSubstructureContextPair(0, 0, 0)
    transform(data, root_idx)
    assert not hasattr(data, 'x_substruct')
    assert not hasattr(data, 'x_context')
Example No. 6
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description=
        'PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate (default: 0.001)')
    parser.add_argument(
        '--lr_scale',
        type=float,
        default=1,
        help=
        'relative learning rate for the feature extraction layer (default: 1)')
    parser.add_argument('--decay',
                        type=float,
                        default=0,
                        help='weight decay (default: 0)')
    parser.add_argument(
        '--num_layer',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim',
                        type=int,
                        default=300,
                        help='embedding dimensions (default: 300)')
    parser.add_argument('--dropout_ratio',
                        type=float,
                        default=0.5,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default="mean",
        help='graph level pooling (sum, mean, max, set2set, attention)')
    parser.add_argument(
        '--JK',
        type=str,
        default="last",
        help=
        'how the node features across layers are combined. last, sum, max or concat'
    )
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument(
        '--dataset',
        type=str,
        default='tox21',
        help='root directory of dataset. For now, only classification.')
    parser.add_argument('--input_model_file',
                        type=str,
                        default='',
                        help='filename to read the model (if there is any)')
    parser.add_argument('--filename',
                        type=str,
                        default='',
                        help='output filename')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="Seed for splitting the dataset.")
    parser.add_argument(
        '--runseed',
        type=int,
        default=0,
        help="Seed for minibatch selection, random initialization.")
    parser.add_argument('--split',
                        type=str,
                        default="scaffold",
                        help="random or scaffold or random_scaffold")
    parser.add_argument('--eval_train',
                        type=int,
                        default=0,
                        help='evaluating training or not')
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='number of workers for dataset loading')
    args = parser.parse_args()

    torch.manual_seed(args.runseed)
    np.random.seed(args.runseed)
    device = torch.device(
        "cuda:" +
        str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.runseed)

    #Bunch of classification tasks
    if args.dataset == "tox21":
        num_tasks = 12
    elif args.dataset == "hiv":
        num_tasks = 1
    elif args.dataset == "pcba":
        num_tasks = 128
    elif args.dataset == "muv":
        num_tasks = 17
    elif args.dataset == "bace":
        num_tasks = 1
    elif args.dataset == "bbbp":
        num_tasks = 1
    elif args.dataset == "toxcast":
        num_tasks = 617
    elif args.dataset == "sider":
        num_tasks = 27
    elif args.dataset == "clintox":
        num_tasks = 2
    else:
        raise ValueError("Invalid dataset name.")

    #set up dataset
    dataset = MoleculeDataset("dataset/" + args.dataset, dataset=args.dataset)

    print(dataset)

    if args.split == "scaffold":
        smiles_list = pd.read_csv('dataset/' + args.dataset +
                                  '/processed/smiles.csv',
                                  header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = scaffold_split(
            dataset,
            smiles_list,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1)
        print("scaffold")
    elif args.split == "random":
        train_dataset, valid_dataset, test_dataset = random_split(
            dataset,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1,
            seed=args.seed)
        print("random")
    elif args.split == "random_scaffold":
        smiles_list = pd.read_csv('dataset/' + args.dataset +
                                  '/processed/smiles.csv',
                                  header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = random_scaffold_split(
            dataset,
            smiles_list,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1,
            seed=args.seed)
        print("random scaffold")
    else:
        raise ValueError("Invalid split option.")

    print(train_dataset[0])

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers)
    val_loader = DataLoader(valid_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.num_workers)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers)

    #set up model
    model = GNN_graphpred(args.num_layer,
                          args.emb_dim,
                          num_tasks,
                          JK=args.JK,
                          drop_ratio=args.dropout_ratio,
                          graph_pooling=args.graph_pooling,
                          gnn_type=args.gnn_type)
    if not args.input_model_file == "":
        model.from_pretrained(args.input_model_file)

    model.to(device)

    #set up optimizer
    #different learning rate for different part of GNN
    model_param_group = []
    model_param_group.append({"params": model.gnn.parameters()})
    if args.graph_pooling == "attention":
        model_param_group.append({
            "params": model.pool.parameters(),
            "lr": args.lr * args.lr_scale
        })
    model_param_group.append({
        "params": model.graph_pred_linear.parameters(),
        "lr": args.lr * args.lr_scale
    })
    optimizer = optim.Adam(model_param_group,
                           lr=args.lr,
                           weight_decay=args.decay)
    print(optimizer)

    train_acc_list = []
    val_acc_list = []
    test_acc_list = []

    if not args.filename == "":
        fname = 'runs/finetune_cls_runseed' + str(
            args.runseed) + '/' + args.filename
        #delete the directory if there exists one
        if os.path.exists(fname):
            shutil.rmtree(fname)
            print("removed the existing file.")
        writer = SummaryWriter(fname)

    for epoch in range(1, args.epochs + 1):
        print("====epoch " + str(epoch))

        train(args, model, device, train_loader, optimizer)

        print("====Evaluation")
        if args.eval_train:
            train_acc = eval(args, model, device, train_loader)
        else:
            print("omit the training accuracy computation")
            train_acc = 0
        val_acc = eval(args, model, device, val_loader)
        test_acc = eval(args, model, device, test_loader)

        print("train: %f val: %f test: %f" % (train_acc, val_acc, test_acc))

        val_acc_list.append(val_acc)
        test_acc_list.append(test_acc)
        train_acc_list.append(train_acc)

        if not args.filename == "":
            writer.add_scalar('data/train auc', train_acc, epoch)
            writer.add_scalar('data/val auc', val_acc, epoch)
            writer.add_scalar('data/test auc', test_acc, epoch)

        print("")

    if not args.filename == "":
        writer.close()
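
The eval() used above is not shown; the scalars are logged as AUC, so it presumably returns a ROC-AUC averaged over tasks. A plausible sketch, assuming binary labels and sklearn (the real code's handling of missing labels may differ):

import numpy as np
import torch
from sklearn.metrics import roc_auc_score

def eval_auc(model, device, loader):
    model.eval()
    y_true, y_score = [], []
    for batch in loader:
        batch = batch.to(device)
        with torch.no_grad():
            pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        y_true.append(batch.y.view(pred.shape).cpu())
        y_score.append(pred.cpu())
    y_true = torch.cat(y_true, dim=0).numpy()
    y_score = torch.cat(y_score, dim=0).numpy()
    # Average ROC-AUC over tasks that contain more than one class.
    aucs = [roc_auc_score(y_true[:, t], y_score[:, t])
            for t in range(y_true.shape[1])
            if np.unique(y_true[:, t]).size > 1]
    return float(np.mean(aucs))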
Example No. 7
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description=
        'PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0001,
                        help='learning rate (default: 0.0001)')
    parser.add_argument('--lr_decay',
                        type=float,
                        default=0.995,
                        help='learning rate decay (default: 0.995)')
    parser.add_argument(
        '--lr_scale',
        type=float,
        default=1,
        help=
        'relative learning rate for the feature extraction layer (default: 1)')
    parser.add_argument('--decay',
                        type=float,
                        default=0,
                        help='weight decay (default: 0)')
    parser.add_argument('--loss_type', type=str, default="bce")
    parser.add_argument(
        '--num_layer',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim',
                        type=int,
                        default=768,
                        help='embedding dimensions (default: 768)')
    parser.add_argument('--heads',
                        type=int,
                        default=12,
                        help='number of attention heads (default: 12)')
    parser.add_argument('--num_message_passing',
                        type=int,
                        default=3,
                        help='message passing steps (default: 3)')
    parser.add_argument('--dropout_ratio',
                        type=float,
                        default=0.5,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default="collection",
        help=
        'graph level pooling (collection, sum, mean, max, set2set, attention)')
    parser.add_argument(
        '--JK',
        type=str,
        default="last",
        help=
        'how the node features across layers are combined. last, sum, max or concat'
    )
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument(
        '--dataset',
        type=str,
        default='twosides',
        help='root directory of dataset. For now, only classification.')
    parser.add_argument('--input_model_file',
                        type=str,
                        default='pretrained_model/MolGNet.pt',
                        help='filename to read the model (if there is any)')
    parser.add_argument('--filename',
                        type=str,
                        default='',
                        help='output filename')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="Seed for splitting the dataset.")
    parser.add_argument(
        '--runseed',
        type=int,
        default=0,
        help="Seed for minibatch selection, random initialization.")
    parser.add_argument('--split',
                        type=str,
                        default="scaffold",
                        help="random or scaffold or random_scaffold")
    parser.add_argument('--eval_train',
                        type=int,
                        default=0,
                        help='evaluating training or not')
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='number of workers for dataset loading')
    parser.add_argument('--iters',
                        type=int,
                        default=1,
                        help='number of run seeds')
    parser.add_argument('--log_file', type=str, default=None)
    parser.add_argument('--log_freq', type=int, default=0)
    parser.add_argument('--KFold',
                        type=int,
                        default=5,
                        help='number of folds for cross validation')
    parser.add_argument('--fold', type=int, default=0)
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
        "of training.")
    parser.add_argument('--cpu', default=False, action="store_true")
    args = parser.parse_args()

    device = torch.device("cuda:0") if torch.cuda.is_available(
    ) and not args.cpu else torch.device("cpu")
    args.seed = args.seed
    args.runseed = args.runseed
    torch.manual_seed(args.runseed)
    np.random.seed(args.runseed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.runseed)

    exp_path = 'runs/{}/'.format(args.dataset)
    if not os.path.exists(exp_path):
        os.makedirs(exp_path)

    num_tasks = 1
    # set up dataset
    dataset = MoleculeDataset("data/downstream/" + args.dataset,
                              dataset=args.dataset,
                              transform=None)
    all_result = []
    for fold in range(args.KFold):

        train_dataset, test_dataset = cv_random_split(dataset, fold, 5)
        train_loader = DataLoaderMasking(train_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.num_workers)
        test_loader = DataLoaderMasking(test_dataset,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        num_workers=args.num_workers)
        # set up model
        model = MolGT_graphpred(args.num_layer,
                                args.emb_dim,
                                args.heads,
                                args.num_message_passing,
                                num_tasks,
                                drop_ratio=args.dropout_ratio,
                                graph_pooling=args.graph_pooling)
        if not args.input_model_file == "":
            model.from_pretrained(args.input_model_file)
            print('pretrained model loaded')
        else:
            print('No pretrain')

        model.to(device)
        #
        # # set up optimizer
        # # different learning rate for different part of GNN
        # #old
        # model_param_group = []
        # model_param_group.append({"params": model.gnn.parameters()})
        # if args.graph_pooling == "attention":
        #     model_param_group.append({"params": model.pool.parameters(), "lr": args.lr * args.lr_scale})
        # model_param_group.append({"params": model.graph_pred_linear.parameters(), "lr": args.lr * args.lr_scale})
        # optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
        # print(optimizer)
        #
        # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_decay)
        # #before
        model_param_group = list(model.gnn.named_parameters())
        if args.graph_pooling == "attention":
            model_param_group += list(model.pool.named_parameters())
        model_param_group += list(model.graph_pred_linear.named_parameters())
        # optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)

        param_optimizer = [
            n for n in model_param_group if 'pooler' not in n[0]
        ]
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        num_train_optimization_steps = int(
            len(train_dataset) / args.batch_size) * args.epochs
        print(num_train_optimization_steps)
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.lr,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

        if args.loss_type == 'bce':
            criterion = nn.BCEWithLogitsLoss()
        elif args.loss_type == 'softmax':
            criterion = nn.CrossEntropyLoss()
        else:
            criterion = FocalLoss(gamma=2, alpha=0.25)

        best_result = []
        acc = 0

        for epoch in range(1, args.epochs + 1):
            print("====epoch " + str(epoch))

            train(args, model, device, train_loader, optimizer, criterion)

            print("====Evaluation")
            if args.eval_train:
                train_acc = eval(args, model, device, train_loader)
            else:
                print("omit the training accuracy computation")
                train_acc = 0
            result = eval(args, model, device, test_loader)
            print(result)
            if result[1] > acc:
                acc = result[1]
                best_result = result
                torch.save(model.state_dict(),
                           exp_path + "Fold_{}.pkl".format(fold))
                print("save network for epoch:", epoch, acc)
            print('test metrics: acc,f1,precision,recall:', result)
        all_result.append(best_result)
        with open(exp_path + "log.txt", "a+") as f:
            f.write('{}, fold {}, acc, f1, precision, recall: {}'.format(
                args.dataset, fold, best_result))
            f.write('\n')

    ave_result = np.mean(np.array(all_result), 0)
    with open(exp_path + "log.txt", "a+") as f:
        f.write('{}, Average--acc, f1, precision, recall: {}'.format(
            args.dataset, ave_result))
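
The FocalLoss fallback used above is not defined in this example. A minimal binary focal loss consistent with the FocalLoss(gamma=2, alpha=0.25) call might look like the following sketch; the repository's own class may differ:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0, alpha=0.25):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, logits, targets):
        # Standard binary focal loss: down-weight well-classified examples.
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p_t = torch.exp(-bce)  # probability assigned to the true class
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        return (alpha_t * (1 - p_t) ** self.gamma * bce).mean()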
Example No. 8
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='learning rate (default: 0.001)')
    parser.add_argument('--decay', type=float, default=0,
                        help='weight decay (default: 0)')
    parser.add_argument('--num_layer', type=int, default=5,
                        help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim', type=int, default=300,
                        help='embedding dimensions (default: 300)')
    parser.add_argument('--dropout_ratio', type=float, default=0.2,
                        help='dropout ratio (default: 0.2)')
    parser.add_argument('--graph_pooling', type=str, default="mean",
                        help='graph level pooling (sum, mean, max, set2set, attention)')
    parser.add_argument('--JK', type=str, default="last",
                        help='how the node features across layers are combined. last, sum, max or concat')
    parser.add_argument('--dataset', type=str, default = 'chembl_filtered', help='root directory of dataset. For now, only classification.')
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument('--input_model_file', type=str, default = '', help='filename to read the model (if there is any)')
    parser.add_argument('--output_model_file', type = str, default = '', help='filename to output the pre-trained model')
    parser.add_argument('--num_workers', type=int, default = 8, help='number of workers for dataset loading')
    args = parser.parse_args()


    torch.manual_seed(0)
    np.random.seed(0)
    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    #Bunch of classification tasks
    if args.dataset == "chembl_filtered":
        num_tasks = 1310
    else:
        raise ValueError("Invalid dataset name.")

    #set up dataset
    dataset = MoleculeDataset("dataset/" + args.dataset, dataset=args.dataset)

    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers)

    #set up model
    model = GNN_graphpred(args.num_layer, args.emb_dim, num_tasks, JK = args.JK, drop_ratio = args.dropout_ratio, graph_pooling = args.graph_pooling, gnn_type = args.gnn_type)
    if not args.input_model_file == "":
        model.from_pretrained(args.input_model_file + ".pth")
    
    model.to(device)

    #set up optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)  
    print(optimizer)


    for epoch in range(1, args.epochs+1):
        print("====epoch " + str(epoch))
    
        train(args, model, device, loader, optimizer)

    if not args.output_model_file == "":
        torch.save(model.gnn.state_dict(), args.output_model_file + ".pth")
Example No. 9
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='learning rate (default: 0.001)')
    parser.add_argument('--lr_scale', type=float, default=1,
                        help='relative learning rate for the feature extraction layer (default: 1)')
    parser.add_argument('--decay', type=float, default=0,
                        help='weight decay (default: 0)')
    parser.add_argument('--num_layer', type=int, default=3,
                        help='number of GNN message passing layers (default: 3).')
    parser.add_argument('--emb_dim', type=int, default=512,
                        help='embedding dimensions (default: 512)')
    parser.add_argument('--dropout_ratio', type=float, default=0.5,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument('--graph_pooling', type=str, default="mean",
                        help='graph level pooling (sum, mean, max, set2set, attention)')
    parser.add_argument('--JK', type=str, default="last",
                        help='how the node features across layers are combined. last, sum, max or concat')
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument('--dataset', type=str, default = 'tox21', help='root directory of dataset. For now, only classification.')
    parser.add_argument('--output_model_file', type=str, default = '', help='filename to output the model')
    parser.add_argument('--filename', type=str, default = '', help='output filename')
    parser.add_argument('--seed', type=int, default=42, help = "Seed for splitting the dataset.")
    parser.add_argument('--runseed', type=int, default=0, help = "Seed for minibatch selection, random initialization.")
    parser.add_argument('--split', type = str, default="scaffold", help = "random or scaffold or random_scaffold")
    parser.add_argument('--eval_train', type=int, default = 0, help='evaluating training or not')
    parser.add_argument('--num_workers', type=int, default = 4, help='number of workers for dataset loading')
    parser.add_argument('--aug1', type=str, default = 'drop_node', help='augmentation1')
    parser.add_argument('--aug2', type=str, default = 'drop_node', help='augmentation2')
    parser.add_argument('--aug_ratio1', type=float, default = 0.2, help='aug ratio1')
    parser.add_argument('--aug_ratio2', type=float, default = 0.2, help='aug ratio2')
    parser.add_argument('--method', type=str, default = 'local', help='method: local, global')
    parser.add_argument('--lamb', type=float, default = 0.0, help='hyper para of global-structure loss')
    parser.add_argument('--n_nb', type=int, default = 0, help='number of neighbors for global-structure loss')
    parser.add_argument('--global_mode', type=str, default = 'sup', help='global mode: sup or cl')
    args = parser.parse_args()


    torch.manual_seed(args.runseed)
    np.random.seed(args.runseed)
    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.runseed)

    #Bunch of classification tasks
    if args.dataset == "tox21":
        num_tasks = 12
    elif args.dataset == "hiv":
        num_tasks = 1
    elif args.dataset == "pcba":
        num_tasks = 128
    elif args.dataset == "muv":
        num_tasks = 17
    elif args.dataset == "bace":
        num_tasks = 1
    elif args.dataset == "bbbp":
        num_tasks = 1
    elif args.dataset == "toxcast":
        num_tasks = 617
    elif args.dataset == "sider":
        num_tasks = 27
    elif args.dataset == "clintox":
        num_tasks = 2
    elif args.dataset == 'esol':
        num_tasks = 1
    elif args.dataset == 'mutag':
        num_tasks = 1
    elif args.dataset == 'dti':
        num_tasks = 0
    else:
        raise ValueError("Invalid dataset name.")

    #set up dataset
    dataset = MoleculeDataset("dataset/" + args.dataset, dataset=args.dataset)
    print(dataset)
    print(dataset.data)

    if args.method == 'local':
        save_dir = 'results/' + args.dataset + '/pretrain_local/'
    elif args.method == 'global':
        save_dir = 'results/' + args.dataset + '/pretrain_global/nb_' + str(args.n_nb) + '/' 
        if args.dataset == 'hiv':
            sim_matrix = np.zeros([len(dataset.original_smiles), len(dataset.original_smiles)])
            sim_matrix_nb = np.zeros([len(dataset.original_smiles), len(dataset.original_smiles)])
        else:
            if args.global_mode == 'cl':
                with open('results/'+args.dataset+'/sim_matrix_nb_'+str(args.n_nb)+'.pkl', 'rb') as f:
                    df = pkl.load(f)
                    sim_matrix_nb = df[0]
                sim_matrix_nb = torch.from_numpy(sim_matrix_nb).to(device)
                print('sim_matrix_nb loaded with size: ', sim_matrix_nb.size())
                sim_matrix = None
            elif args.global_mode == 'sup':
                with open('results/'+args.dataset+'/sim_matrix.pkl', 'rb') as f:
                    df = pkl.load(f)
                    sim_matrix = df[0]
                sim_matrix = torch.from_numpy(sim_matrix).to(device)
                print('sim_matrix loaded with size: ', sim_matrix.size())
                sim_matrix_nb = None

    else:
        raise ValueError('Invalid method: %s' % args.method)

    if not os.path.exists(save_dir):
        os.system('mkdir -p %s' % save_dir)

    model_str = args.dataset + '_aug1_' + args.aug1 + '_' + str(args.aug_ratio1) + '_aug2_' + args.aug2 + '_' + str(args.aug_ratio2) + '_lamb_' + str(args.lamb) + '_do_' + str(args.dropout_ratio) + '_seed_' + str(args.runseed)

    txtfile=save_dir + model_str + ".txt"
    nowTime=datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
    if os.path.exists(txtfile):
        os.system('mv %s %s' % (txtfile, txtfile + ".bak-%s" % nowTime))  # rename the existing file to avoid a collision
    
    #set up model
    model = GNN_graphCL(args.num_layer, args.emb_dim, num_tasks, JK = args.JK, drop_ratio = args.dropout_ratio, graph_pooling = args.graph_pooling, gnn_type = args.gnn_type)
    model.to(device)

    #set up optimizer
    #different learning rate for different part of GNN
    model_param_group = []
    model_param_group.append({"params": model.gnn.parameters()})
    if args.graph_pooling == "attention":
        model_param_group.append({"params": model.pool.parameters(), "lr":args.lr*args.lr_scale})
    model_param_group.append({"params": model.graph_pred_linear.parameters(), "lr":args.lr*args.lr_scale})
    optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
    print(optimizer)

    with open(txtfile, "a") as myfile:
        myfile.write('epoch: train_loss\n')

    rules = json.load(open('isostere_transformations_new.json'))
    with open('results/'+ args.dataset + '/rule_indicator_new.pkl', 'rb') as f:
        d = pkl.load(f)
        rule_indicator = d[0]
    print('rule indicator shape: ', rule_indicator.shape)
    for epoch in range(1, args.epochs+1):
        print("====epoch " + str(epoch))
        if args.method == 'local':
            train_loss, _ = train_base(model, optimizer, dataset, device, args.batch_size, args.aug1, args.aug_ratio1, args.aug2, args.aug_ratio2)
        elif args.method == 'global':
            train_loss, _ = train_global(model, optimizer, dataset, device, args.batch_size, args.aug1, args.aug_ratio1, args.aug2, args.aug_ratio2, 
                sim_matrix, sim_matrix_nb, args.lamb, mode=args.global_mode)
        else:
            raise ValueError('Invalid method: %s' % args.method)
        
        print("train: %f" %(train_loss))

        with open(txtfile, "a") as myfile:
            myfile.write(str(int(epoch)) + ': ' + str(train_loss) + "\n")

    if not args.output_model_file == "":
        torch.save(model.gnn.state_dict(), save_dir + args.output_model_file + model_str + ".pth")
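
train_base and train_global are not shown; this script follows a GraphCL-style contrastive setup, where two augmented views of each graph are pulled together and other graphs in the batch are pushed apart. A simplified NT-Xent loss in that spirit (the repository's exact loss, temperature and projection head may differ):

import torch
import torch.nn.functional as F

def nt_xent(z1, z2, temperature=0.1):
    # z1, z2: [n, d] graph embeddings of two augmented views of the same batch.
    z1, z2 = F.normalize(z1, dim=1), F.normalize(z2, dim=1)
    sim = torch.mm(z1, z2.t()) / temperature              # cross-view similarities
    labels = torch.arange(z1.size(0), device=z1.device)   # positives on the diagonal
    return 0.5 * (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels))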
Example No. 10
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description=
        'PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='input batch size for training (default: 256)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate (default: 0.001)')
    parser.add_argument('--decay',
                        type=float,
                        default=0,
                        help='weight decay (default: 0)')
    parser.add_argument(
        '--num_layer',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--csize',
                        type=int,
                        default=3,
                        help='context size (default: 3).')
    parser.add_argument('--emb_dim',
                        type=int,
                        default=300,
                        help='embedding dimensions (default: 300)')
    parser.add_argument('--dropout_ratio',
                        type=float,
                        default=0,
                        help='dropout ratio (default: 0)')
    parser.add_argument(
        '--neg_samples',
        type=int,
        default=1,
        help='number of negative contexts per positive context (default: 1)')
    parser.add_argument(
        '--JK',
        type=str,
        default="last",
        help=
        'how the node features are combined across layers. last, sum, max or concat'
    )
    parser.add_argument('--context_pooling',
                        type=str,
                        default="mean",
                        help='how the contexts are pooled (sum, mean, or max)')
    parser.add_argument('--mode',
                        type=str,
                        default="cbow",
                        help="cbow or skipgram")
    parser.add_argument('--dataset',
                        type=str,
                        default='zinc_standard_agent',
                        help='root directory of dataset for pretraining')
    parser.add_argument('--output_model_file',
                        type=str,
                        default='',
                        help='filename to output the model')
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help="Seed for splitting dataset.")
    parser.add_argument('--num_workers',
                        type=int,
                        default=8,
                        help='number of workers for dataset loading')
    args = parser.parse_args()

    torch.manual_seed(0)
    np.random.seed(0)
    device = torch.device(
        "cuda:" +
        str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    l1 = args.num_layer - 1
    l2 = l1 + args.csize

    print(args.mode)
    print("num layer: %d l1: %d l2: %d" % (args.num_layer, l1, l2))

    #set up dataset and transform function.
    dataset = MoleculeDataset("dataset/" + args.dataset,
                              dataset=args.dataset,
                              transform=ExtractSubstructureContextPair(
                                  args.num_layer, l1, l2))
    loader = DataLoaderSubstructContext(dataset,
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=args.num_workers)

    #set up models, one for pre-training and one for context embeddings
    model_substruct = GNN(args.num_layer,
                          args.emb_dim,
                          JK=args.JK,
                          drop_ratio=args.dropout_ratio,
                          gnn_type=args.gnn_type).to(device)
    model_context = GNN(int(l2 - l1),
                        args.emb_dim,
                        JK=args.JK,
                        drop_ratio=args.dropout_ratio,
                        gnn_type=args.gnn_type).to(device)

    #set up optimizer for the two GNNs
    optimizer_substruct = optim.Adam(model_substruct.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.decay)
    optimizer_context = optim.Adam(model_context.parameters(),
                                   lr=args.lr,
                                   weight_decay=args.decay)

    for epoch in range(1, args.epochs + 1):
        print("====epoch " + str(epoch))

        train_loss, train_acc = train(args, model_substruct, model_context,
                                      loader, optimizer_substruct,
                                      optimizer_context, device)
        print(train_loss, train_acc)

    if not args.output_model_file == "":
        torch.save(model_substruct.state_dict(),
                   args.output_model_file + ".pth")
Example No. 11
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description=
        "PyTorch implementation of pre-training of graph neural networks")
    parser.add_argument("--device",
                        type=int,
                        default=0,
                        help="which gpu to use if any (default: 0)")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=256,
        help="input batch size for training (default: 256)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=100,
        help="number of epochs to train (default: 100)",
    )
    parser.add_argument("--lr",
                        type=float,
                        default=0.001,
                        help="learning rate (default: 0.001)")
    parser.add_argument("--decay",
                        type=float,
                        default=0,
                        help="weight decay (default: 0)")
    parser.add_argument(
        "--num_layer",
        type=int,
        default=5,
        help="number of GNN message passing layers (default: 5).",
    )
    parser.add_argument("--csize",
                        type=int,
                        default=3,
                        help="context size (default: 3).")
    parser.add_argument("--emb_dim",
                        type=int,
                        default=64,
                        help="embedding dimensions (default: 300)")
    parser.add_argument(
        "--node_feat_dim",
        type=int,
        default=32,
        help="dimension of the node features.",
    )
    parser.add_argument("--edge_feat_dim",
                        type=int,
                        default=2,
                        help="dimension ofo the edge features.")

    parser.add_argument("--dropout_ratio",
                        type=float,
                        default=0,
                        help="dropout ratio (default: 0)")
    parser.add_argument(
        "--neg_samples",
        type=int,
        default=1,
        help="number of negative contexts per positive context (default: 1)",
    )
    parser.add_argument(
        "--JK",
        type=str,
        default="last",
        help="how the node features are combined across layers."
        "last, sum, max or concat",
    )
    parser.add_argument(
        "--context_pooling",
        type=str,
        default="mean",
        help="how the contexts are pooled (sum, mean, or max)",
    )
    parser.add_argument("--mode",
                        type=str,
                        default="cbow",
                        help="cbow or skipgram")
    parser.add_argument(
        "--dataset",
        type=str,
        default="dataset/zinc_standard_agent",
        help="root directory of dataset for pretraining",
    )
    parser.add_argument("--output_model_file",
                        type=str,
                        default="trained_model/context_onehot_mlp",
                        help="filename to output the model")
    parser.add_argument("--gnn_type", type=str, default="gine")
    parser.add_argument("--seed",
                        type=int,
                        default=0,
                        help="Seed for splitting dataset.")
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="number of workers for dataset loading",
    )
    args = parser.parse_args()

    print("show all arguments configuration...")
    print(args)

    torch.manual_seed(0)
    np.random.seed(0)
    device = (torch.device("cuda:" + str(args.device))
              if torch.cuda.is_available() else torch.device("cpu"))
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    l1 = args.num_layer - 1
    l2 = l1 + args.csize

    print(args.mode)
    print("num layer: %d l1: %d l2: %d" % (args.num_layer, l1, l2))

    # set up dataset and transform function.
    dataset_og = MoleculeDataset(root=args.dataset,
                                 dataset=os.path.basename(args.dataset))

    dataset = MoleculeDataset(
        root=args.dataset,
        dataset=os.path.basename(args.dataset),
        transform=ONEHOT_ContextPair(dataset=dataset_og,
                                     k=args.num_layer,
                                     l1=l1,
                                     l2=l2),
    )
    loader = DataLoaderSubstructContext(dataset,
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=args.num_workers)

    # set up models, one for pre-training and one for context embeddings
    model_substruct = GNN_MLP(
        args.num_layer,
        args.node_feat_dim,
        args.edge_feat_dim,
        args.emb_dim,
        JK=args.JK,
        drop_ratio=args.dropout_ratio,
        gnn_type=args.gnn_type,
    ).to(device)
    model_context = GNN_MLP(
        int(l2 - l1),
        args.node_feat_dim,
        args.edge_feat_dim,
        args.emb_dim,
        JK=args.JK,
        drop_ratio=args.dropout_ratio,
        gnn_type=args.gnn_type,
    ).to(device)

    # set up optimizer for the two GNNs
    optimizer_substruct = optim.Adam(model_substruct.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.decay)
    optimizer_context = optim.Adam(model_context.parameters(),
                                   lr=args.lr,
                                   weight_decay=args.decay)

    for epoch in range(1, args.epochs + 1):
        print("====epoch " + str(epoch))

        train_loss, train_acc = train(
            args,
            model_substruct,
            model_context,
            loader,
            optimizer_substruct,
            optimizer_context,
            device,
        )
        print(train_loss, train_acc)

    if not args.output_model_file == "":
        torch.save(model_substruct.state_dict(),
                   args.output_model_file + ".pth")
Example No. 12
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description=
        'PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate (default: 0.001)')
    parser.add_argument(
        '--lr_scale',
        type=float,
        default=1,
        help=
        'relative learning rate for the feature extraction layer (default: 1)')
    parser.add_argument('--decay',
                        type=float,
                        default=0,
                        help='weight decay (default: 0)')
    parser.add_argument(
        '--num_layer',
        type=int,
        default=3,
        help='number of GNN message passing layers (default: 3).')
    parser.add_argument('--emb_dim',
                        type=int,
                        default=512,
                        help='embedding dimensions (default: 512)')
    parser.add_argument('--dropout_ratio',
                        type=float,
                        default=0.5,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default="mean",
        help='graph level pooling (sum, mean, max, set2set, attention)')
    parser.add_argument(
        '--JK',
        type=str,
        default="last",
        help=
        'how the node features across layers are combined. last, sum, max or concat'
    )
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument(
        '--dataset',
        type=str,
        default='esol',
        help='root directory of dataset. For now, only classification.')
    parser.add_argument('--input_model_file',
                        type=str,
                        default='',
                        help='filename to read the model (if there is any)')
    parser.add_argument('--filename',
                        type=str,
                        default='',
                        help='output filename')
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help="Seed for splitting the dataset.")
    parser.add_argument(
        '--runseed',
        type=int,
        default=1,
        help="Seed for minibatch selection, random initialization.")
    parser.add_argument('--split',
                        type=str,
                        default="random",
                        help="random or scaffold or random_scaffold")
    parser.add_argument('--eval_train',
                        type=int,
                        default=1,
                        help='evaluating training or not')
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='number of workers for dataset loading')
    parser.add_argument('--aug1',
                        type=str,
                        default='dropN_random',
                        help='augmentation1')
    parser.add_argument('--aug2',
                        type=str,
                        default='dropN_random',
                        help='augmentation2')
    parser.add_argument('--aug_ratio1',
                        type=float,
                        default=0.0,
                        help='aug ratio1')
    parser.add_argument('--aug_ratio2',
                        type=float,
                        default=0.0,
                        help='aug ratio2')
    parser.add_argument('--dataset_load',
                        type=str,
                        default='esol',
                        help='load pretrain model from which dataset.')
    parser.add_argument('--protocol',
                        type=str,
                        default='linear',
                        help='downstream protocol, linear, nonlinear')
    parser.add_argument(
        '--semi_ratio',
        type=float,
        default=1.0,
        help='proportion of labels in semi-supervised settings')
    parser.add_argument('--pretrain_method',
                        type=str,
                        default='local',
                        help='pretrain_method: local, global')
    parser.add_argument('--lamb',
                        type=float,
                        default=0.0,
                        help='hyper para of global-structure loss')
    parser.add_argument('--n_nb',
                        type=int,
                        default=0,
                        help='number of neighbors for global-structure loss')
    args = parser.parse_args()

    torch.manual_seed(args.runseed)
    np.random.seed(args.runseed)
    device = torch.device(
        "cuda:" +
        str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.runseed)

    if args.dataset in [
            'tox21', 'hiv', 'pcba', 'muv', 'bace', 'bbbp', 'toxcast', 'sider',
            'clintox', 'mutag'
    ]:
        task_type = 'cls'
    else:
        task_type = 'reg'

    #Bunch of classification tasks
    if args.dataset == "tox21":
        num_tasks = 12
    elif args.dataset == "hiv":
        num_tasks = 1
    elif args.dataset == "pcba":
        num_tasks = 128
    elif args.dataset == "muv":
        num_tasks = 17
    elif args.dataset == "bace":
        num_tasks = 1
    elif args.dataset == "bbbp":
        num_tasks = 1
    elif args.dataset == "toxcast":
        num_tasks = 617
    elif args.dataset == "sider":
        num_tasks = 27
    elif args.dataset == "clintox":
        num_tasks = 2
    elif args.dataset == 'esol':
        num_tasks = 1
    elif args.dataset == 'freesolv':
        num_tasks = 1
    elif args.dataset == 'mutag':
        num_tasks = 1
    else:
        raise ValueError("Invalid dataset name.")

    #set up dataset
    dataset = MoleculeDataset("dataset/" + args.dataset, dataset=args.dataset)

    print('The whole dataset:', dataset)

    if args.split == "scaffold":
        smiles_list = pd.read_csv('dataset/' + args.dataset +
                                  '/processed/smiles.csv',
                                  header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = scaffold_split(
            dataset,
            smiles_list,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1)
        print("scaffold")
    elif args.split == "random":
        train_dataset, valid_dataset, test_dataset = random_split(
            dataset,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1,
            seed=args.seed)
        print("random")
    elif args.split == "random_scaffold":
        smiles_list = pd.read_csv('dataset/' + args.dataset +
                                  '/processed/smiles.csv',
                                  header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = random_scaffold_split(
            dataset,
            smiles_list,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1,
            seed=args.seed)
        print("random scaffold")
    else:
        raise ValueError("Invalid split option.")

    # semi-supervised settings
    if args.semi_ratio != 1.0:
        n_total, n_sample = len(train_dataset), int(
            len(train_dataset) * args.semi_ratio)
        print('sample {:.2f} of the labels ({:d} molecules) for '
              'semi-supervised training!'.format(args.semi_ratio, n_sample))
        all_idx = list(range(n_total))
        random.seed(0)
        idx_semi = random.sample(all_idx, n_sample)
        train_dataset = train_dataset[torch.tensor(
            idx_semi)]  #int(len(train_dataset)*args.semi_ratio)
        print('new train dataset size:', len(train_dataset))
    else:
        print('finetune using all data!')

    if args.dataset == 'freesolv':
        train_loader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  drop_last=True)
    else:
        train_loader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers)
    val_loader = DataLoader(valid_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.num_workers)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers)

    if args.pretrain_method == 'local':
        load_dir = 'results/' + args.dataset + '/pretrain_local/'
        save_dir = 'results/' + args.dataset + '/finetune_local/'
    elif args.pretrain_method == 'global':
        load_dir = 'results/' + args.dataset + '/pretrain_global/nb_' + str(
            args.n_nb) + '/'
        save_dir = 'results/' + args.dataset + '/finetune_global/nb_' + str(
            args.n_nb) + '/'
    else:
        raise ValueError('Invalid pretrain_method.')

    if not os.path.exists(save_dir):
        os.system('mkdir -p %s' % save_dir)

    if not args.input_model_file == "":
        input_model_str = args.dataset_load + '_aug1_' + args.aug1 + '_' + str(
            args.aug_ratio1) + '_aug2_' + args.aug2 + '_' + str(
                args.aug_ratio2) + '_lamb_' + str(args.lamb) + '_do_' + str(
                    args.dropout_ratio) + '_seed_' + str(args.runseed)
        output_model_str = args.dataset + '_semi_' + str(
            args.semi_ratio
        ) + '_protocol_' + args.protocol + '_aug1_' + args.aug1 + '_' + str(
            args.aug_ratio1) + '_aug2_' + args.aug2 + '_' + str(
                args.aug_ratio2) + '_lamb_' + str(args.lamb) + '_do_' + str(
                    args.dropout_ratio) + '_seed_' + str(
                        args.runseed) + '_' + str(args.seed)
    else:
        output_model_str = 'scratch_' + args.dataset + '_semi_' + str(
            args.semi_ratio) + '_protocol_' + args.protocol + '_do_' + str(
                args.dropout_ratio) + '_seed_' + str(args.runseed) + '_' + str(
                    args.seed)

    txtfile = save_dir + output_model_str + ".txt"
    nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
    if os.path.exists(txtfile):
        os.system('mv %s %s' %
                  (txtfile, txtfile +
                   ".bak-%s" % nowTime))  # rename existing file to avoid collision

    #set up model
    model = GNN_graphpred(args.num_layer,
                          args.emb_dim,
                          num_tasks,
                          JK=args.JK,
                          drop_ratio=args.dropout_ratio,
                          graph_pooling=args.graph_pooling,
                          gnn_type=args.gnn_type)
    if not args.input_model_file == "":
        model.from_pretrained(load_dir + args.input_model_file +
                              input_model_str + '.pth')
        print('successfully load pretrained model!')
    else:
        print('No pretrain! train from scratch!')

    model.to(device)

    #set up optimizer
    #different learning rate for different part of GNN
    model_param_group = []
    model_param_group.append({"params": model.gnn.parameters()})
    if args.graph_pooling == "attention":
        model_param_group.append({
            "params": model.pool.parameters(),
            "lr": args.lr * args.lr_scale
        })
    model_param_group.append({
        "params": model.graph_pred_linear.parameters(),
        "lr": args.lr * args.lr_scale
    })
    optimizer = optim.Adam(model_param_group,
                           lr=args.lr,
                           weight_decay=args.decay)
    print(optimizer)

    # if linear protocol, fix GNN layers
    if args.protocol == 'linear':
        print("linear protocol, only train the top layer!")
        for name, param in model.named_parameters():
            if 'pred_linear' not in name:
                param.requires_grad = False
    elif args.protocol == 'nonlinear':
        print("finetune protocol, train all the layers!")
    else:
        print("invalid protocol!")

    # all task info summary
    print('=========task summary=========')
    print('Dataset: ', args.dataset)
    if args.semi_ratio == 1.0:
        print('fully-supervised {:.2f}'.format(args.semi_ratio))
    else:
        print('semi-supervised {:.2f}'.format(args.semi_ratio))
    if args.input_model_file == '':
        print('scratch or finetune: scratch')
        print('loaded model from: - ')
    else:
        print('scratch or finetune: finetune')
        print('loaded model from: ', args.dataset_load)
        print('global_mode: n_nb = ', args.n_nb)
    print('Protocol: ', args.protocol)
    print('task type:', task_type)
    print('=========task summary=========')

    # training based on task type
    if task_type == 'cls':
        with open(txtfile, "a") as myfile:
            myfile.write('epoch: train_auc val_auc test_auc\n')
        wait = 0
        best_auc = 0
        patience = 10
        for epoch in range(1, args.epochs + 1):
            print("====epoch " + str(epoch))

            train_cls(args, model, device, train_loader, optimizer)

            print("====Evaluation")
            if args.eval_train:
                train_auc = eval_cls(args, model, device, train_loader)
            else:
                print("omit the training accuracy computation")
                train_auc = 0
            val_auc = eval_cls(args, model, device, val_loader)
            test_auc = eval_cls(args, model, device, test_loader)

            with open(txtfile, "a") as myfile:
                myfile.write(
                    str(int(epoch)) + ': ' + str(train_auc) + ' ' +
                    str(val_auc) + ' ' + str(test_auc) + "\n")

            print("train: %f val: %f test: %f" %
                  (train_auc, val_auc, test_auc))

            # Early stopping
            if np.greater(val_auc, best_auc):  # change for train loss
                best_auc = val_auc
                wait = 0
            else:
                wait += 1
                if wait >= patience:
                    print(
                        'Early stop at Epoch: {:d} with final val auc: {:.4f}'.
                        format(epoch, val_auc))
                    break

    elif task_type == 'reg':
        with open(txtfile, "a") as myfile:
            myfile.write(
                'epoch: train_mse train_cor val_mse val_cor test_mse test_cor\n'
            )
        for epoch in range(1, args.epochs + 1):
            print("====epoch " + str(epoch))

            train(args, model, device, train_loader, optimizer)

            print("====Evaluation")
            if args.eval_train:
                train_mse, train_cor = eval_reg(args, model, device,
                                                train_loader)
            else:
                print("omit the training accuracy computation")
                train_mse, train_cor = 0, 0
            val_mse, val_cor = eval_reg(args, model, device, val_loader)
            test_mse, test_cor = eval_reg(args, model, device, test_loader)

            with open(txtfile, "a") as myfile:
                myfile.write(
                    str(int(epoch)) + ': ' + str(train_mse) + ' ' +
                    str(train_cor) + ' ' + str(val_mse) + ' ' + str(val_cor) +
                    ' ' + str(test_mse) + ' ' + str(test_cor) + "\n")

            print("train: %f val: %f test: %f" %
                  (train_mse, val_mse, test_mse))
            print("train: %f val: %f test: %f" %
                  (train_cor, val_cor, test_cor))
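
eval_reg is not shown in this snippet; judging from the log header ('train_mse train_cor ...'), it reports a mean-squared error and a correlation per split. A minimal, hedged sketch of just those two metrics (the real helper also runs the model over a DataLoader):

import numpy as np

def reg_metrics(y_true, y_pred):
    # Hedged sketch of the (mse, cor) pair logged by the regression branch.
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    mse = float(np.mean((y_true - y_pred) ** 2))
    cor = float(np.corrcoef(y_true, y_pred)[0, 1])  # Pearson correlation
    return mse, cor

# toy check
print(reg_metrics([0.0, 1.0, 2.0], [0.1, 0.9, 2.2]))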
Exemplo n.º 13
0
    def test(self, support_grads):
        accs = []
        old_params = parameters_to_vector(self.graph_model.parameters())
        for task in range(self.num_test_tasks):
            print(self.num_tasks - task)
            dataset = MoleculeDataset("Original_datasets/" + self.dataset +
                                      "/new/" + str(self.num_tasks - task),
                                      dataset=self.dataset)
            support_dataset, query_dataset = sample_test_datasets(
                dataset, self.dataset, self.num_tasks - task - 1, self.n_way,
                self.m_support, self.k_query)
            support_loader = DataLoader(support_dataset,
                                        batch_size=self.batch_size,
                                        shuffle=False,
                                        num_workers=1)
            query_loader = DataLoader(query_dataset,
                                      batch_size=self.batch_size,
                                      shuffle=False,
                                      num_workers=1)

            device = torch.device("cuda:" +
                                  str(self.device)) if torch.cuda.is_available(
                                  ) else torch.device("cpu")

            self.graph_model.eval()

            for k in range(0, self.update_step_test):
                loss = torch.tensor([0.0]).to(device)
                for step, batch in enumerate(
                        tqdm(support_loader, desc="Iteration")):
                    batch = batch.to(device)

                    pred, node_emb = self.graph_model(batch.x,
                                                      batch.edge_index,
                                                      batch.edge_attr,
                                                      batch.batch)
                    y = batch.y.view(pred.shape).to(torch.float64)

                    loss += torch.sum(self.criterion(pred.double(),
                                                     y)) / pred.size()[0]

                    if self.add_selfsupervise:
                        positive_score = torch.sum(
                            node_emb[batch.edge_index[0, ::2]] *
                            node_emb[batch.edge_index[1, ::2]],
                            dim=1)

                        negative_edge_index = self.build_negative_edges(batch)
                        negative_score = torch.sum(
                            node_emb[negative_edge_index[0]] *
                            node_emb[negative_edge_index[1]],
                            dim=1)

                        self_loss = torch.sum(
                            self.self_criterion(
                                positive_score, torch.ones_like(
                                    positive_score)) + self.self_criterion(
                                        negative_score,
                                        torch.zeros_like(negative_score))
                        ) / negative_edge_index[0].size()[0]

                        loss += (self.add_weight * self_loss)

                    if self.add_masking:
                        mask_num = random.sample(range(0,
                                                       node_emb.size()[0]),
                                                 self.batch_size)
                        pred_emb = self.masking_linear(node_emb[mask_num])
                        loss += (self.add_weight * self.masking_criterion(
                            pred_emb.double(), batch.x[mask_num, 0]))

                    print(loss)

                new_grad, new_params = self.update_params(
                    loss, update_lr=self.update_lr)

                # if self.add_similarity:
                #     new_params = self.update_similarity_params(new_grad, support_grads)

                vector_to_parameters(new_params, self.graph_model.parameters())

            y_true = []
            y_scores = []
            for step, batch in enumerate(tqdm(query_loader, desc="Iteration")):
                batch = batch.to(device)

                pred, node_emb = self.graph_model(batch.x, batch.edge_index,
                                                  batch.edge_attr, batch.batch)
                # print(pred)
                pred = torch.sigmoid(pred)
                pred = torch.where(pred > 0.5, torch.ones_like(pred), pred)
                pred = torch.where(pred <= 0.5, torch.zeros_like(pred), pred)
                y_scores.append(pred)
                y_true.append(batch.y.view(pred.shape))

            y_true = torch.cat(y_true, dim=0).cpu().detach().numpy()
            y_scores = torch.cat(y_scores, dim=0).cpu().detach().numpy()

            roc_list = []
            roc_list.append(roc_auc_score(y_true, y_scores))
            acc = sum(roc_list) / len(roc_list)
            accs.append(acc)

            vector_to_parameters(old_params, self.graph_model.parameters())

        return accs
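
build_negative_edges is defined on the class and not shown here; the self-supervised loss above only needs a 2 x E index tensor of node pairs that are not real edges. A hedged, illustrative sketch of one way such pairs could be sampled:

import torch

def sample_negative_edges(edge_index, num_nodes, num_samples):
    # Hedged sketch: draw random node pairs and keep those that are not
    # already edges (the class method may use a different strategy).
    existing = set(map(tuple, edge_index.t().tolist()))
    src, dst = [], []
    while len(src) < num_samples:
        i = int(torch.randint(num_nodes, (1,)))
        j = int(torch.randint(num_nodes, (1,)))
        if i != j and (i, j) not in existing and (j, i) not in existing:
            src.append(i)
            dst.append(j)
    return torch.tensor([src, dst], dtype=torch.long)

# toy usage on a 4-node path graph
ei = torch.tensor([[0, 1, 2], [1, 2, 3]])
print(sample_negative_edges(ei, num_nodes=4, num_samples=3))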
Exemplo n.º 14
0
    def forward(self, epoch):
        support_loaders = []
        query_loaders = []

        device = torch.device("cuda:" +
                              str(self.device)) if torch.cuda.is_available(
                              ) else torch.device("cpu")
        self.graph_model.train()

        # tasks_list = random.sample(range(0,self.num_train_tasks), self.batch_size)

        for task in range(self.num_train_tasks):
            # for task in tasks_list:
            dataset = MoleculeDataset("Original_datasets/" + self.dataset +
                                      "/new/" + str(task + 1),
                                      dataset=self.dataset)
            support_dataset, query_dataset = sample_datasets(
                dataset, self.dataset, task, self.n_way, self.m_support,
                self.k_query)
            support_loader = DataLoader(support_dataset,
                                        batch_size=self.batch_size,
                                        shuffle=False,
                                        num_workers=1)
            query_loader = DataLoader(query_dataset,
                                      batch_size=self.batch_size,
                                      shuffle=False,
                                      num_workers=1)
            support_loaders.append(support_loader)
            query_loaders.append(query_loader)

        for k in range(0, self.update_step):
            # print(self.fi)
            old_params = parameters_to_vector(self.graph_model.parameters())

            losses_q = torch.tensor([0.0]).to(device)

            # support_params = []
            # support_grads = torch.Tensor(self.num_train_tasks, parameters_to_vector(self.graph_model.parameters()).size()[0]).to(device)

            for task in range(self.num_train_tasks):

                losses_s = torch.tensor([0.0]).to(device)
                if self.add_similarity or self.interact:
                    one_task_emb = torch.zeros(self.emb_dim).to(device)

                for step, batch in enumerate(
                        tqdm(support_loaders[task], desc="Iteration")):
                    batch = batch.to(device)

                    pred, node_emb = self.graph_model(batch.x,
                                                      batch.edge_index,
                                                      batch.edge_attr,
                                                      batch.batch)
                    y = batch.y.view(pred.shape).to(torch.float64)

                    loss = torch.sum(self.criterion(pred.double(),
                                                    y)) / pred.size()[0]

                    if self.add_selfsupervise:
                        positive_score = torch.sum(
                            node_emb[batch.edge_index[0, ::2]] *
                            node_emb[batch.edge_index[1, ::2]],
                            dim=1)

                        negative_edge_index = self.build_negative_edges(batch)
                        negative_score = torch.sum(
                            node_emb[negative_edge_index[0]] *
                            node_emb[negative_edge_index[1]],
                            dim=1)

                        self_loss = torch.sum(
                            self.self_criterion(
                                positive_score, torch.ones_like(
                                    positive_score)) + self.self_criterion(
                                        negative_score,
                                        torch.zeros_like(negative_score))
                        ) / negative_edge_index[0].size()[0]

                        loss += (self.add_weight * self_loss)

                    if self.add_masking:
                        mask_num = random.sample(range(0,
                                                       node_emb.size()[0]),
                                                 self.batch_size)
                        pred_emb = self.masking_linear(node_emb[mask_num])
                        loss += (self.add_weight * self.masking_criterion(
                            pred_emb.double(), batch.x[mask_num, 0]))

                    if self.add_similarity or self.interact:
                        one_task_emb = torch.div(
                            (one_task_emb + torch.mean(node_emb, 0)), 2.0)

                    losses_s += loss

                if self.add_similarity or self.interact:
                    if task == 0:
                        tasks_emb = []
                    tasks_emb.append(one_task_emb)

                new_grad, new_params = self.update_params(
                    losses_s, update_lr=self.update_lr)

                vector_to_parameters(new_params, self.graph_model.parameters())

                this_loss_q = torch.tensor([0.0]).to(device)
                for step, batch in enumerate(
                        tqdm(query_loaders[task], desc="Iteration")):
                    batch = batch.to(device)

                    pred, node_emb = self.graph_model(batch.x,
                                                      batch.edge_index,
                                                      batch.edge_attr,
                                                      batch.batch)
                    y = batch.y.view(pred.shape).to(torch.float64)

                    loss_q = torch.sum(self.criterion(pred.double(),
                                                      y)) / pred.size()[0]

                    if self.add_selfsupervise:
                        positive_score = torch.sum(
                            node_emb[batch.edge_index[0, ::2]] *
                            node_emb[batch.edge_index[1, ::2]],
                            dim=1)

                        negative_edge_index = self.build_negative_edges(batch)
                        negative_score = torch.sum(
                            node_emb[negative_edge_index[0]] *
                            node_emb[negative_edge_index[1]],
                            dim=1)

                        self_loss = torch.sum(
                            self.self_criterion(
                                positive_score, torch.ones_like(
                                    positive_score)) + self.self_criterion(
                                        negative_score,
                                        torch.zeros_like(negative_score))
                        ) / negative_edge_index[0].size()[0]

                        loss_q += (self.add_weight * self_loss)

                    if self.add_masking:
                        mask_num = random.sample(range(0,
                                                       node_emb.size()[0]),
                                                 self.batch_size)
                        pred_emb = self.masking_linear(node_emb[mask_num])
                        # accumulate into the query loss, mirroring loss_q above
                        loss_q += (self.add_weight * self.masking_criterion(
                            pred_emb.double(), batch.x[mask_num, 0]))

                    this_loss_q += loss_q

                if task == 0:
                    losses_q = this_loss_q
                else:
                    losses_q = torch.cat((losses_q, this_loss_q), 0)

                vector_to_parameters(old_params, self.graph_model.parameters())

            if self.add_similarity:
                for t_index, one_task_e in enumerate(tasks_emb):
                    if t_index == 0:
                        tasks_emb_new = one_task_e
                    else:
                        tasks_emb_new = torch.cat((tasks_emb_new, one_task_e),
                                                  0)

                tasks_emb_new = torch.reshape(
                    tasks_emb_new, (self.num_train_tasks, self.emb_dim))
                tasks_emb_new = tasks_emb_new.detach()

                tasks_weight = self.Attention(tasks_emb_new)
                losses_q = torch.sum(tasks_weight * losses_q)

            elif self.interact:
                for t_index, one_task_e in enumerate(tasks_emb):
                    if t_index == 0:
                        tasks_emb_new = one_task_e
                    else:
                        tasks_emb_new = torch.cat((tasks_emb_new, one_task_e),
                                                  0)

                tasks_emb_new = tasks_emb_new.detach()
                represent_emb = self.Interact_attention(tasks_emb_new)
                represent_emb = F.normalize(represent_emb, p=2, dim=0)

                tasks_emb_new = torch.reshape(
                    tasks_emb_new, (self.num_train_tasks, self.emb_dim))
                tasks_emb_new = F.normalize(tasks_emb_new, p=2, dim=1)

                tasks_weight = torch.mm(
                    tasks_emb_new,
                    torch.reshape(represent_emb, (self.emb_dim, 1)))
                print(tasks_weight)
                print(self.softmax(tasks_weight))
                print(losses_q)

                # tasks_emb_new = tasks_emb_new * torch.reshape(represent_emb_m, (self.batch_size, self.emb_dim))

                losses_q = torch.sum(
                    losses_q *
                    torch.transpose(self.softmax(tasks_weight), 1, 0))
                print(losses_q)

            else:
                losses_q = torch.sum(losses_q)

            loss_q = losses_q / self.num_train_tasks
            self.optimizer.zero_grad()
            loss_q.backward()
            self.optimizer.step()

        return []
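
update_params is a class method not shown here; it appears to perform the MAML-style inner-loop step: differentiate the support loss with respect to the current parameters and return the flattened gradient together with a flattened, updated parameter vector (which is then written back with vector_to_parameters). A hedged sketch of that idea on a toy model:

import torch
from torch.nn.utils import parameters_to_vector

def manual_inner_update(model, loss, update_lr):
    # Hedged sketch of a first-order inner-loop update: plain SGD on the
    # support loss, bypassing the outer optimizer's state.
    params = list(model.parameters())
    grads = torch.autograd.grad(loss, params)
    grad_vec = parameters_to_vector(grads)
    new_params = parameters_to_vector(params) - update_lr * grad_vec
    return grad_vec, new_params

# toy usage with a tiny linear model
lin = torch.nn.Linear(3, 1)
support_loss = lin(torch.randn(5, 3)).pow(2).mean()
grad_vec, new_params = manual_inner_update(lin, support_loss, update_lr=0.01)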
Exemplo n.º 15
0
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description=
        'PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate (default: 0.001)')
    parser.add_argument('--lr_decay',
                        type=float,
                        default=0.995,
                        help='learning rate decay (default: 0.995)')
    parser.add_argument(
        '--lr_scale',
        type=float,
        default=1,
        help=
        'relative learning rate for the feature extraction layer (default: 1)')
    parser.add_argument('--decay',
                        type=float,
                        default=0,
                        help='weight decay (default: 0)')
    parser.add_argument(
        '--num_layer',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim',
                        type=int,
                        default=768,
                        help='embedding dimensions (default: 768)')
    parser.add_argument('--dropout_ratio',
                        type=float,
                        default=0.5,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default="mean",
        help='graph level pooling (sum, mean, max, set2set, attention)')
    parser.add_argument('--heads',
                        type=int,
                        default=12,
                        help='multi heads (default: 12)')
    parser.add_argument('--num_message_passing',
                        type=int,
                        default=3,
                        help='message passing steps (default: 3)')
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument(
        '--dataset',
        type=str,
        default='tox21',
        help='root directory of dataset. For now, only classification.')
    parser.add_argument('--input_model_file',
                        type=str,
                        default='pretrained_model/MolGNet.pt',
                        help='filename to read the model (if there is any)')
    parser.add_argument('--filename',
                        type=str,
                        default='',
                        help='output filename')
    parser.add_argument('--seed',
                        type=int,
                        default=177,
                        help="Seed for splitting the dataset.")
    parser.add_argument(
        '--runseed',
        type=int,
        default=0,
        help="Seed for minibatch selection, random initialization.")
    parser.add_argument('--split',
                        type=str,
                        default="scaffold",
                        help="random or scaffold or random_scaffold")
    parser.add_argument('--eval_train',
                        type=int,
                        default=0,
                        help='evaluating training or not')
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='number of workers for dataset loading')
    parser.add_argument('--iters',
                        type=int,
                        default=10,
                        help='number of run seeds')
    parser.add_argument('--processed_file', type=str, default=None)
    parser.add_argument('--raw_file', type=str, default=None)
    parser.add_argument('--cpu', default=False, action="store_true")
    parser.add_argument('--exp', type=str, default='', help='output filename')
    parser.add_argument('--data_dir', type=str, default="")
    args = parser.parse_args()

    device = torch.device("cuda:0") if torch.cuda.is_available(
    ) and not args.cpu else torch.device("cpu")

    if args.dataset == "freesolv":
        # args.seed =219
        # args.runseed = 142
        args.batch_size = 32
        args.lr = 0.0001
        args.lr_decay = 0.99
        args.dropout_ratio = 0
        args.graph_pooling = 'mean'
        args.data_dir = 'data/downstream/'

    elif args.dataset == "esol":
        args.batch_size = 32
        args.lr = 0.001
        args.lr_decay = 0.995
        args.dropout_ratio = 0.5
        args.graph_pooling = 'set2set'
        args.data_dir = 'data/downstream/'

    elif args.dataset == "lipophilicity":
        args.batch_size = 32
        args.lr = 0.0001
        args.lr_decay = 0.99
        args.dropout_ratio = 0
        args.graph_pooling = 'set2set'

    for i in range(args.iters):
        seed = args.seed + i
        runseed = args.runseed
        torch.manual_seed(runseed)
        np.random.seed(runseed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(runseed)
        #Bunch of classification tasks
        num_tasks = 1
        transform = Compose([
            Self_loop(),
            Add_seg_id(),
            Add_collection_node(num_atom_type=119, bidirection=False)
        ])
        dataset = MoleculeDataset(args.data_dir + args.dataset,
                                  dataset=args.dataset,
                                  transform=transform)

        smiles_list = pd.read_csv(args.data_dir + args.dataset +
                                  '/processed/smiles.csv')['smiles'].tolist()
        train_dataset, valid_dataset, test_dataset = random_scaffold_split(
            dataset,
            smiles_list,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1,
            seed=seed)

        train_loader = DataLoaderMasking(train_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.num_workers)
        val_loader = DataLoaderMasking(valid_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=False,
                                       num_workers=args.num_workers)
        test_loader = DataLoaderMasking(test_dataset,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        num_workers=args.num_workers)

        # set up model
        model = MolGT_graphpred(args.num_layer,
                                args.emb_dim,
                                args.heads,
                                args.num_message_passing,
                                num_tasks,
                                drop_ratio=args.dropout_ratio,
                                graph_pooling=args.graph_pooling)
        if not args.input_model_file == "":
            model.from_pretrained(args.input_model_file)
            print('Pretrained model loaded')

        model.to(device)

        # set up optimizer
        # different learning rate for different part of GNN
        model_param_group = []
        model_param_group.append({"params": model.gnn.parameters()})
        if args.graph_pooling == "attention":
            model_param_group.append({
                "params": model.pool.parameters(),
                "lr": args.lr * args.lr_scale
            })
        model_param_group.append({
            "params": model.graph_pred_linear.parameters(),
            "lr": args.lr * args.lr_scale
        })
        optimizer = optim.Adam(model_param_group,
                               lr=args.lr,
                               weight_decay=args.decay)
        print(optimizer)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                           gamma=args.lr_decay)

        train_acc_list = []
        val_acc_list = []
        test_acc_list = []

        exp_path = '{}/{}_seed{}/'.format(args.exp, args.dataset, seed)
        if not os.path.exists(exp_path):
            os.makedirs(exp_path)
        best_rmse = float('inf')
        for epoch in range(1, args.epochs + 1):
            print("====epoch " + str(epoch))
            train(args, model, device, train_loader, optimizer)
            scheduler.step()
            print("====Evaluation")
            train_loss, train_rmse = eval(args, model, device, train_loader)
            val_loss, val_acc = eval(args, model, device, val_loader)
            test_loss, test_acc = eval(args, model, device, test_loader)
            print("RMSE: train: %f val: %f test: %f" %
                  (train_loss, val_acc, test_acc))
            if val_acc <= best_rmse:
                best_rmse = val_acc
                torch.save(model.state_dict(),
                           exp_path + "model_seed{}.pkl".format(args.seed))
                print('saved')

            val_acc_list.append(val_acc)
            test_acc_list.append(test_acc)
            train_acc_list.append(train_rmse)

        df = pd.DataFrame({
            'train': train_acc_list,
            'valid': val_acc_list,
            'test': test_acc_list
        })
        df.to_csv(exp_path + '{}_seed{}.csv'.format(args.dataset, seed))

        best_epoch = np.argmin(val_acc_list)  # lower validation RMSE is better
        test_acc_at_best_val = test_acc_list[best_epoch]
        print("The test RMSE at best valid (epoch {}) is {} at seed {}".format(
            best_epoch, test_acc_at_best_val, args.runseed))
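
For reference, ExponentialLR multiplies the learning rate by lr_decay once per scheduler.step() call, i.e. once per epoch here. An illustrative calculation with the esol defaults above (lr=0.001, lr_decay=0.995):

lr0, gamma = 0.001, 0.995
print([round(lr0 * gamma ** t, 6) for t in (0, 10, 50, 100)])
# -> [0.001, 0.000951, 0.000778, 0.000606]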
Exemplo n.º 16
0

    torch.manual_seed(0)
    np.random.seed(0)
    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    #Bunch of classification tasks
    if args.dataset == "chembl_filtered":
        num_tasks = 1310
    else:
        raise ValueError("Invalid dataset name.")

    #set up dataset
    dataset = MoleculeDataset("/raid/home/public/dataset_ContextPred_0219/" + args.dataset, dataset=args.dataset)

    

    #set up model
    model = GNN_graphpred(args.num_layer, args.emb_dim, num_tasks, JK = args.JK, drop_ratio = args.dropout_ratio, graph_pooling = args.graph_pooling, gnn_type = args.gnn_type)
    if args.use_original == 0:
        model = GNN_graphpred(
        args.num_layer,
        args.node_feat_dim,
        args.edge_feat_dim,
        args.emb_dim,
        num_tasks,
        JK=args.JK,
        drop_ratio=args.dropout_ratio,
        graph_pooling=args.graph_pooling,
Exemplo n.º 17
0
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description=
        "PyTorch implementation of pre-training of graph neural networks")
    parser.add_argument("--device",
                        type=int,
                        default=0,
                        help="which gpu to use if any (default: 0)")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="input batch size for training (default: 32)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=100,
        help="number of epochs to train (default: 100)",
    )
    parser.add_argument("--lr",
                        type=float,
                        default=0.001,
                        help="learning rate (default: 0.001)")
    parser.add_argument(
        "--lr_scale",
        type=float,
        default=1,
        help=
        "relative learning rate for the feature extraction layer (default: 1)",
    )
    parser.add_argument("--decay",
                        type=float,
                        default=0,
                        help="weight decay (default: 0)")
    parser.add_argument(
        "--num_layer",
        type=int,
        default=5,
        help="number of GNN message passing layers (default: 5).",
    )

    parser.add_argument(
        "--node_feat_dim",
        type=int,
        default=154,
        help="dimension of the node features.",
    )
    parser.add_argument("--edge_feat_dim",
                        type=int,
                        default=2,
                        help="dimension ofo the edge features.")

    parser.add_argument("--emb_dim",
                        type=int,
                        default=256,
                        help="embedding dimensions (default: 300)")
    parser.add_argument("--dropout_ratio",
                        type=float,
                        default=0.5,
                        help="dropout ratio (default: 0.5)")
    parser.add_argument(
        "--graph_pooling",
        type=str,
        default="mean",
        help="graph level pooling (sum, mean, max, set2set, attention)",
    )
    parser.add_argument(
        "--JK",
        type=str,
        default="last",
        help=
        "how the node features across layers are combined. last, sum, max or concat",
    )
    parser.add_argument("--gnn_type", type=str, default="gine")
    parser.add_argument(
        "--dataset",
        type=str,
        default="bbbp",
        help="root directory of dataset. For now, only classification.",
    )
    parser.add_argument(
        "--input_model_file",
        type=str,
        default="",
        help="filename to read the model (if there is any)",
    )
    parser.add_argument("--filename",
                        type=str,
                        default="",
                        help="output filename")
    parser.add_argument("--seed",
                        type=int,
                        default=42,
                        help="Seed for splitting the dataset.")
    parser.add_argument(
        "--runseed",
        type=int,
        default=0,
        help="Seed for minibatch selection, random initialization.",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="scaffold",
        help="random or scaffold or random_scaffold",
    )
    parser.add_argument("--eval_train",
                        type=int,
                        default=0,
                        help="evaluating training or not")
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="number of workers for dataset loading",
    )
    parser.add_argument("--use_original",
                        type=int,
                        default=0,
                        help="run benchmark experiment or not")
    #parser.add_argument('--output_model_file', type = str, default = 'finetuned_model/amu', help='filename to output the finetuned model')

    args = parser.parse_args()

    torch.manual_seed(args.runseed)
    np.random.seed(args.runseed)
    device = (torch.device("cuda:" + str(args.device))
              if torch.cuda.is_available() else torch.device("cpu"))
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.runseed)

    # Bunch of classification tasks
    if args.dataset == "tox21":
        num_tasks = 12
    elif args.dataset == "hiv":
        num_tasks = 1
    elif args.dataset == "pcba":
        num_tasks = 128
    elif args.dataset == "muv":
        num_tasks = 17
    elif args.dataset == "bace":
        num_tasks = 1
    elif args.dataset == "bbbp":
        num_tasks = 1
    elif args.dataset == "toxcast":
        num_tasks = 617
    elif args.dataset == "sider":
        num_tasks = 27
    elif args.dataset == "clintox":
        num_tasks = 2
    elif args.dataset in ["jak1", "jak2", "jak3", "amu", "ellinger", "mpro"]:
        num_tasks = 1
    else:
        raise ValueError("Invalid dataset name.")

    # set up dataset
    #    dataset = MoleculeDataset("contextPred/chem/dataset/" + args.dataset, dataset=args.dataset)
    dataset = MoleculeDataset(
        root="/raid/home/public/dataset_ContextPred_0219/" + args.dataset)
    if args.use_original == 0:
        dataset = MoleculeDataset(
            root="/raid/home/public/dataset_ContextPred_0219/" + args.dataset,
            transform=ONEHOT_ENCODING(dataset=dataset),
        )

    print(dataset)

    if args.split == "scaffold":
        smiles_list = pd.read_csv(
            "/raid/home/public/dataset_ContextPred_0219/" + args.dataset +
            "/processed/smiles.csv",
            header=None,
        )[0].tolist()
        train_dataset, valid_dataset, test_dataset = scaffold_split(
            dataset,
            smiles_list,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1,
        )
        print("scaffold")
    elif args.split == "oversample":
        train_dataset, valid_dataset, test_dataset = oversample_split(
            dataset,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1,
            seed=args.seed,
        )
        print("oversample")
    elif args.split == "random":
        train_dataset, valid_dataset, test_dataset = random_split(
            dataset,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1,
            seed=args.seed,
        )
        print("random")
    elif args.split == "random_scaffold":
        smiles_list = pd.read_csv(
            "/raid/home/public/dataset_ContextPred_0219/" + args.dataset +
            "/processed/smiles.csv",
            header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = random_scaffold_split(
            dataset,
            smiles_list,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1,
            seed=args.seed,
        )
        print("random scaffold")
    else:
        raise ValueError("Invalid split option.")

    print(train_dataset[0])

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )
    val_loader = DataLoader(
        valid_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
    )

    # set up model
    model = GNN_graphpred(
        args.num_layer,
        args.node_feat_dim,
        args.edge_feat_dim,
        args.emb_dim,
        num_tasks,
        JK=args.JK,
        drop_ratio=args.dropout_ratio,
        graph_pooling=args.graph_pooling,
        gnn_type=args.gnn_type,
        use_embedding=args.use_original,
    )
    if not args.input_model_file == "":
        model.from_pretrained(args.input_model_file + ".pth")

    model.to(device)

    # set up optimizer
    # different learning rate for different part of GNN
    model_param_group = []
    model_param_group.append({"params": model.gnn.parameters()})
    if args.graph_pooling == "attention":
        model_param_group.append({
            "params": model.pool.parameters(),
            "lr": args.lr * args.lr_scale
        })
    model_param_group.append({
        "params": model.graph_pred_linear.parameters(),
        "lr": args.lr * args.lr_scale
    })
    optimizer = optim.Adam(model_param_group,
                           lr=args.lr,
                           weight_decay=args.decay)
    print(optimizer)

    train_roc_list = []
    train_acc_list = []
    train_f1_list = []
    train_ap_list = []
    val_roc_list = []
    val_acc_list = []
    val_f1_list = []
    val_ap_list = []
    test_roc_list = []
    test_acc_list = []
    test_f1_list = []
    test_ap_list = []

    if not args.filename == "":
        fname = ("/raid/home/yoyowu/Weihua_b/BASE_TFlogs/" +
                 str(args.runseed) + "/" + args.filename)
        # delete the directory if there exists one
        if os.path.exists(fname):
            shutil.rmtree(fname)
            print("removed the existing file.")
        writer = SummaryWriter(fname)

    for epoch in range(1, args.epochs + 1):
        print("====epoch " + str(epoch))

        train(args, model, device, train_loader, optimizer)
        #if not args.output_model_file == "":
        #   torch.save(model.state_dict(), args.output_model_file + str(epoch)+ ".pth")

        print("====Evaluation")
        if args.eval_train:
            train_roc, train_acc, train_f1, train_ap, train_num_positive_true, train_num_positive_scores = eval(
                args, model, device, train_loader)

        else:
            print("omit the training accuracy computation")
            train_roc = 0
            train_acc = 0
            train_f1 = 0
            train_ap = 0
        val_roc, val_acc, val_f1, val_ap, val_num_positive_true, val_num_positive_scores = eval(
            args, model, device, val_loader)
        test_roc, test_acc, test_f1, test_ap, test_num_positive_true, test_num_positive_scores = eval(
            args, model, device, test_loader)
        #with open('debug_ellinger.txt', "a") as f:
        #   f.write("====epoch " + str(epoch) +" \n training:  positive true count {} , positive scores count {} \n".format(train_num_positive_true,train_num_positive_scores))
        #  f.write("val:  positive true count {} , positive scores count {} \n".format(val_num_positive_true,val_num_positive_scores))
        # f.write("test:  positive true count {} , positive scores count {} \n".format(test_num_positive_true,test_num_positive_scores))
        #f.write("\n")

        print("train: %f val: %f test auc: %f " %
              (train_roc, val_roc, test_roc))
        val_roc_list.append(val_roc)
        val_f1_list.append(val_f1)
        val_acc_list.append(val_acc)
        val_ap_list.append(val_ap)
        test_acc_list.append(test_acc)
        test_roc_list.append(test_roc)
        test_f1_list.append(test_f1)
        test_ap_list.append(test_ap)
        train_acc_list.append(train_acc)
        train_roc_list.append(train_roc)
        train_f1_list.append(train_f1)
        train_ap_list.append(train_ap)

        if not args.filename == "":
            writer.add_scalar("data/train roc", train_roc, epoch)
            writer.add_scalar("data/train acc", train_acc, epoch)
            writer.add_scalar("data/train f1", train_f1, epoch)
            writer.add_scalar("data/train ap", train_ap, epoch)

            writer.add_scalar("data/val roc", val_roc, epoch)
            writer.add_scalar("data/val acc", val_acc, epoch)
            writer.add_scalar("data/val f1", val_f1, epoch)
            writer.add_scalar("data/val ap", val_ap, epoch)

            writer.add_scalar("data/test roc", test_roc, epoch)
            writer.add_scalar("data/test acc", test_acc, epoch)
            writer.add_scalar("data/test f1", test_f1, epoch)
            writer.add_scalar("data/test ap", test_ap, epoch)

        print("")

    if not args.filename == "":
        writer.close()
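
The eval helper called in this loop is not shown; the four values it tracks per split look like standard binary-classification metrics. A minimal, hedged sketch for a single task with sigmoid scores (the real helper also iterates over the DataLoader and returns the positive-count diagnostics):

import numpy as np
from sklearn.metrics import (accuracy_score, average_precision_score,
                             f1_score, roc_auc_score)

def cls_metrics(y_true, y_score, threshold=0.5):
    # Hedged sketch of the (roc, acc, f1, ap) tuple logged above.
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    y_pred = (y_score >= threshold).astype(int)
    return (roc_auc_score(y_true, y_score),
            accuracy_score(y_true, y_pred),
            f1_score(y_true, y_pred),
            average_precision_score(y_true, y_score))

# toy usage
print(cls_metrics([0, 1, 1, 0], [0.2, 0.8, 0.6, 0.4]))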
Exemplo n.º 18
0
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description=
        'PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='input batch size for training (default: 16)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0001,
                        help='learning rate (default: 0.0001)')
    parser.add_argument('--lr_decay',
                        type=float,
                        default=0.995,
                        help='learning rate decay (default: 0.995)')
    parser.add_argument(
        '--lr_scale',
        type=float,
        default=1,
        help=
        'relative learning rate for the feature extraction layer (default: 1)')
    parser.add_argument('--decay',
                        type=float,
                        default=0,
                        help='weight decay (default: 0)')
    parser.add_argument('--loss_type', type=str, default="bce")
    parser.add_argument(
        '--num_layer',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim',
                        type=int,
                        default=768,
                        help='embedding dimensions (default: 768)')
    parser.add_argument('--heads',
                        type=int,
                        default=12,
                        help='multi heads (default: 12)')
    parser.add_argument('--num_message_passing',
                        type=int,
                        default=3,
                        help='message passing steps (default: 3)')
    parser.add_argument('--dropout_ratio',
                        type=float,
                        default=0.5,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default="collection",
        help=
        'graph level pooling (collection,sum, mean, max, set2set, attention)')
    parser.add_argument(
        '--JK',
        type=str,
        default="last",
        help=
        'how the node features across layers are combined. last, sum, max or concat'
    )
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument(
        '--dataset',
        type=str,
        default='biosnap',
        help='root directory of dataset. For now, only classification.')
    parser.add_argument('--input_model_file',
                        type=str,
                        default='pretrained_model/MolGNet.pt',
                        help='filename to read the model (if there is any)')
    parser.add_argument('--filename',
                        type=str,
                        default='',
                        help='output filename')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="Seed for splitting the dataset.")
    parser.add_argument(
        '--runseed',
        type=int,
        default=0,
        help="Seed for minibatch selection, random initialization.")
    parser.add_argument('--split',
                        type=str,
                        default="scaffold",
                        help="random or scaffold or random_scaffold")
    parser.add_argument('--eval_train',
                        type=int,
                        default=0,
                        help='evaluating training or not')
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='number of workers for dataset loading')
    parser.add_argument('--iters',
                        type=int,
                        default=1,
                        help='number of run seeds')
    parser.add_argument('--save', type=str, default=None)
    parser.add_argument('--log_freq', type=int, default=0)
    parser.add_argument('--KFold',
                        type=int,
                        default=5,
                        help='number of folds for cross validation')
    parser.add_argument('--fold', type=int, default=0)
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
        "of training.")
    parser.add_argument('--cpu', default=False, action="store_true")
    args = parser.parse_args()

    device = torch.device("cuda:0") if torch.cuda.is_available(
    ) and not args.cpu else torch.device("cpu")
    args.seed = args.seed
    args.runseed = args.runseed
    torch.manual_seed(args.runseed)
    np.random.seed(args.runseed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.runseed)

    exp_path = 'runs/{}/'.format(args.dataset)
    if not os.path.exists(exp_path):
        os.makedirs(exp_path)

    num_tasks = 1
    trn_val_dataset = MoleculeDataset("data/downstream/" + args.dataset,
                                      dataset=args.dataset,
                                      train=True,
                                      transform=None)
    test_dataset = MoleculeDataset("data/downstream/" + args.dataset + '/',
                                   dataset=args.dataset,
                                   train=False,
                                   transform=None)

    # Use the --KFold setting for the number of CV splits; --fold selects the held-out split.
    kf = KFold(n_splits=args.KFold, shuffle=True, random_state=3)
    # labels = [data.y.item() for data in trn_val_dataset]
    idx_list = []
    for idx in kf.split(np.zeros(len(trn_val_dataset))):
        idx_list.append(idx)
    train_idx, val_idx = idx_list[args.fold]

    # idx_list = next(kf.split(np.zeros(len(labels)), labels))
    # train_idx, val_idx = idx_list[0],idx_list[1]
    train_dataset = trn_val_dataset[torch.tensor(train_idx)]
    valid_dataset = trn_val_dataset[torch.tensor(val_idx)]

    train_loader = DataLoaderMasking(train_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)
    val_loader = DataLoaderMasking(valid_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=args.num_workers)
    test_loader = DataLoaderMasking(test_dataset,
                                    batch_size=args.batch_size,
                                    shuffle=False,
                                    num_workers=args.num_workers)
    # set up model
    model = MolGT_graphpred(args.num_layer,
                            args.emb_dim,
                            args.heads,
                            args.num_message_passing,
                            num_tasks,
                            drop_ratio=args.dropout_ratio,
                            graph_pooling=args.graph_pooling)
    if args.input_model_file != "":
        model.from_pretrained(args.input_model_file)
        print('Pretrained model loaded')
    else:
        print('No pretrained model; training from scratch')

    total = sum(param.nelement() for param in model.gnn.parameters())
    print("Number of GNN parameters: %.2fM" % (total / 1e6))

    model.to(device)

    model_param_group = list(model.gnn.named_parameters())
    if args.graph_pooling == "attention":
        model_param_group += list(model.pool.named_parameters())
    model_param_group += list(model.graph_pred_linear.named_parameters())
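    # BERT-style parameter grouping: biases and LayerNorm parameters are exempted
    # from weight decay in the optimizer groups built below.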

    param_optimizer = [n for n in model_param_group if 'pooler' not in n[0]]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    num_train_optimization_steps = int(
        len(train_dataset) / args.batch_size) * args.epochs
    print("Total optimization steps:", num_train_optimization_steps)
    optimizer = optim.Adam(optimizer_grouped_parameters,
                           lr=args.lr,
                           weight_decay=args.decay)
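    # ExponentialLR decays the learning rate by a factor of args.lr_decay once per
    # epoch (scheduler.step() is called at the end of each training epoch below).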
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                       gamma=args.lr_decay)

    if args.loss_type == 'bce':
        criterion = nn.BCEWithLogitsLoss()
    elif args.loss_type == 'softmax':
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = FocalLoss(gamma=2, alpha=0.25)

    best_result = []
    acc = 0
    for epoch in range(1, args.epochs + 1):
        print("====epoch " + str(epoch))
        train(args, model, device, train_loader, optimizer, criterion)
        scheduler.step()
        print("====Evaluation")
        if args.eval_train:
            train_acc = eval(args, model, device, train_loader)
        else:
            print("omit the training accuracy computation")
            train_acc = 0
        result = eval(args, model, device, val_loader)
        print('validation: ', result)
        test_result = eval(args, model, device, test_loader)
        print('test: ', test_result)
        if result[0] > acc:
            acc = result[0]
            best_result = test_result
            torch.save(
                model.state_dict(), exp_path +
                "model_seed{}_fold_{}.pkl".format(args.seed, args.fold))
            print("save network for epoch:", epoch, acc)
    with open(exp_path + "log.txt", "a+") as f:
        log = 'Test metrics (auc, prc, f1): {} at seed {}'.format(
            best_result, args.seed)
        print(log)
        f.write(log + '\n')
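The FocalLoss used as the fallback criterion above is imported from the project's utilities and is not reproduced in this listing. A minimal sketch of a binary focal loss on logits, assuming the standard Lin et al. formulation with the gamma=2, alpha=0.25 values passed above (the actual MPG implementation may differ):

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Sketch of a binary focal loss on raw logits (hypothetical, not the MPG code)."""

    def __init__(self, gamma=2.0, alpha=0.25, reduction="mean"):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, logits, targets):
        targets = targets.float()
        # Unreduced BCE so each example can be re-weighted individually.
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p = torch.sigmoid(logits)
        # p_t: probability the model assigns to the true class of each example.
        p_t = p * targets + (1 - p) * (1 - targets)
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        loss = alpha_t * (1.0 - p_t) ** self.gamma * bce
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss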
Example No. 19
0
File: eval.py Project: pyli0628/MPG
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description=
        'PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0001,
                        help='learning rate (default: 0.0001)')
    parser.add_argument('--lr_decay',
                        type=float,
                        default=0.995,
                        help='learning rate decay (default: 0.995)')
    parser.add_argument(
        '--lr_scale',
        type=float,
        default=1,
        help=
        'relative learning rate for the feature extraction layer (default: 1)')
    parser.add_argument('--decay',
                        type=float,
                        default=0,
                        help='weight decay (default: 0)')
    parser.add_argument('--loss_type', type=str, default="bce")
    parser.add_argument(
        '--num_layer',
        type=int,
        default=5,
        help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim',
                        type=int,
                        default=768,
                        help='embedding dimensions (default: 768)')
    parser.add_argument('--heads',
                        type=int,
                        default=12,
                        help='number of attention heads (default: 12)')
    parser.add_argument('--num_message_passing',
                        type=int,
                        default=3,
                        help='message passing steps (default: 3)')
    parser.add_argument('--dropout_ratio',
                        type=float,
                        default=0.5,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument(
        '--graph_pooling',
        type=str,
        default="collection",
        help='graph-level pooling (collection, sum, mean, max, set2set, attention)')
    parser.add_argument(
        '--JK',
        type=str,
        default="last",
        help=
        'how the node features across layers are combined. last, sum, max or concat'
    )
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument(
        '--dataset',
        type=str,
        default='twosides',
        help='dataset name, loaded from data/downstream/<dataset> (classification only)')
    parser.add_argument('--model_dir',
                        type=str,
                        default='pretrained_model/MolGNet.pt',
                        help='directory containing the fine-tuned per-dataset checkpoints')
    parser.add_argument('--filename',
                        type=str,
                        default='',
                        help='output filename')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="Seed for splitting the dataset.")
    parser.add_argument(
        '--runseed',
        type=int,
        default=0,
        help="Seed for minibatch selection, random initialization.")
    parser.add_argument('--split',
                        type=str,
                        default="scaffold",
                        help="random or scaffold or random_scaffold")
    parser.add_argument('--eval_train',
                        type=int,
                        default=0,
                        help='evaluating training or not')
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='number of workers for dataset loading')
    parser.add_argument('--iters',
                        type=int,
                        default=1,
                        help='number of run seeds')
    parser.add_argument('--log_file', type=str, default=None)
    parser.add_argument('--log_freq', type=int, default=0)
    parser.add_argument('--KFold',
                        type=int,
                        default=5,
                        help='number of folds for cross validation')
    parser.add_argument('--fold', type=int, default=0)
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
        "of training.")
    parser.add_argument('--cpu', default=False, action="store_true")
    args = parser.parse_args()

    device = torch.device("cuda:0") if torch.cuda.is_available(
    ) and not args.cpu else torch.device("cpu")
    args.seed = args.seed
    args.runseed = args.runseed
    torch.manual_seed(args.runseed)
    np.random.seed(args.runseed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.runseed)

    args.model_dir = os.path.join(args.model_dir, args.dataset)

    num_tasks = 1
    dataset = MoleculeDataset("data/downstream/" + args.dataset,
                              dataset=args.dataset,
                              transform=None)
    if args.dataset == 'twosides':
        print('Run {}-fold cross validation'.format(args.KFold))
        for fold in range(args.KFold):
            train_dataset, test_dataset = cv_random_split(dataset, fold, args.KFold)
            test_loader = DataLoaderMasking(test_dataset,
                                            batch_size=args.batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers)
            # set up model
            model = MolGT_graphpred(args.num_layer,
                                    args.emb_dim,
                                    args.heads,
                                    args.num_message_passing,
                                    num_tasks,
                                    drop_ratio=args.dropout_ratio,
                                    graph_pooling=args.graph_pooling)

            model.load_state_dict(
                torch.load(
                    os.path.join(args.model_dir, 'Fold_{}.pkl'.format(fold))))
            model.to(device)

            acc, f1, prec, rec = eval_twosides(model, device, test_loader)
            print('Dataset:{}, Fold:{}, precision:{}, recall:{}, F1:{}'.format(
                args.dataset, fold, prec, rec, f1))

    if args.dataset == 'biosnap':
        print('Run three random splits')
        for i in range(3):
            seed = args.seed + i
            train_dataset, valid_dataset, test_dataset = random_split(
                dataset,
                null_value=0,
                frac_train=0.7,
                frac_valid=0.1,
                frac_test=0.2,
                seed=seed)
            test_loader = DataLoaderMasking(test_dataset,
                                            batch_size=args.batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers)
            model = MolGT_graphpred(args.num_layer,
                                    args.emb_dim,
                                    args.heads,
                                    args.num_message_passing,
                                    num_tasks,
                                    drop_ratio=args.dropout_ratio,
                                    graph_pooling=args.graph_pooling)
            model.load_state_dict(
                torch.load(
                    os.path.join(args.model_dir,
                                 'model_seed{}.pkl'.format(seed))))
            model.to(device)
            auc, prc, f1 = eval_biosnap(model, device, test_loader)
            print('Dataset:{}, Split:{}, AUC:{}, PRC:{}, F1:{}'.format(
                args.dataset, i, auc, prc, f1))
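The helpers eval_twosides and eval_biosnap come from the project's evaluation utilities and are not shown in this listing. As a rough, hypothetical sketch only: a biosnap-style evaluation could collect sigmoid scores over the test loader and compute AUC, PRC (average precision), and F1 with scikit-learn, assuming binary labels in batch.y and a forward signature mirroring the GNN_graphpred call from the first example; the real MPG code may differ.

import torch
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score

def eval_biosnap(model, device, loader):
    # Hypothetical sketch, not the MPG implementation.
    model.eval()
    y_true, y_score = [], []
    for batch in loader:
        batch = batch.to(device)
        with torch.no_grad():
            # Forward signature assumed to mirror the GNN_graphpred call in the
            # first example; MolGT_graphpred may expect different arguments.
            logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        y_true.append(batch.y.view(-1).cpu())
        y_score.append(torch.sigmoid(logits).view(-1).cpu())
    y_true = torch.cat(y_true).numpy()
    y_score = torch.cat(y_score).numpy()
    auc = roc_auc_score(y_true, y_score)
    prc = average_precision_score(y_true, y_score)
    f1 = f1_score(y_true, (y_score > 0.5).astype(int))
    return auc, prc, f1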