def main():
    # Score the drug-repurposing library with the JAK3 model fine-tuned for 90 epochs.
    dataset = MoleculeDataset(root="/raid/home/public/dataset_ContextPred_0219/" + "repurposing")
    # Re-instantiate the dataset with the one-hot transform built from the raw dataset.
    dataset = MoleculeDataset(
        root="/raid/home/public/dataset_ContextPred_0219/" + "repurposing",
        transform=ONEHOT_ENCODING(dataset=dataset),
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)

    model = GNN_graphpred(
        num_layer=5,
        node_feat_dim=154,
        edge_feat_dim=2,
        emb_dim=256,
        num_tasks=1,
        JK="last",
        drop_ratio=0.5,
        graph_pooling="mean",
        gnn_type="gine",
        use_embedding=0,
    )
    model.load_state_dict(torch.load("tuned_model/jak3/90.pth"))
    model.eval()

    # Renamed from id/cid/score/dict to avoid shadowing Python built-ins.
    ids = []
    cids = []
    scores = []
    fields = ['id', 'cid', 'score']  # (unused)
    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        with torch.no_grad():
            pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        ids.append(batch.id)
        cids.append(batch.cid)
        scores.append(pred)

    df = pd.DataFrame({'id': ids, 'cid': cids, 'score': scores})
    df.to_csv('jak3_score_90.csv')
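# Each main() in this section appears to come from a separate script, so each would need
# its own imports and entry point. The block below is a minimal sketch of what the scoring
# script above assumes; the module paths (loader.MoleculeDataset, model.GNN_graphpred, the
# custom DataLoader) are guesses based on identifiers used in this section, not confirmed
# file names in the repository.
# import torch
# import pandas as pd
# from tqdm import tqdm
# from loader import MoleculeDataset, ONEHOT_ENCODING   # assumed module layout
# from dataloader import DataLoader                     # or torch_geometric's DataLoader
# from model import GNN_graphpred                       # assumed module layout
#
# if __name__ == "__main__":
#     main()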
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device', type=int, default=0, help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size', type=int, default=32, help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate (default: 0.001)')
    parser.add_argument('--decay', type=float, default=0, help='weight decay (default: 0)')
    parser.add_argument('--num_layer', type=int, default=5, help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim', type=int, default=300, help='embedding dimensions (default: 300)')
    parser.add_argument('--dropout_ratio', type=float, default=0.2, help='dropout ratio (default: 0.2)')
    parser.add_argument('--graph_pooling', type=str, default="mean", help='graph level pooling (sum, mean, max, set2set, attention)')
    parser.add_argument('--JK', type=str, default="last", help='how the node features across layers are combined. last, sum, max or concat')
    parser.add_argument('--dataset', type=str, default='chembl_filtered', help='root directory of dataset. For now, only classification.')
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument('--input_model_file', type=str, default='', help='filename to read the model (if there is any)')
    parser.add_argument('--output_model_file', type=str, default='', help='filename to output the pre-trained model')
    parser.add_argument('--num_workers', type=int, default=8, help='number of workers for dataset loading')
    args = parser.parse_args()

    torch.manual_seed(0)
    np.random.seed(0)
    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    # Bunch of classification tasks
    if args.dataset == "chembl_filtered":
        num_tasks = 1310
    else:
        raise ValueError("Invalid dataset name.")

    # set up dataset
    dataset = MoleculeDataset("dataset/" + args.dataset, dataset=args.dataset)
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

    # set up model
    model = GNN_graphpred(args.num_layer, args.emb_dim, num_tasks, JK=args.JK,
                          drop_ratio=args.dropout_ratio, graph_pooling=args.graph_pooling,
                          gnn_type=args.gnn_type)
    if not args.input_model_file == "":
        model.from_pretrained(args.input_model_file + ".pth")
    model.to(device)

    # set up optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)
    print(optimizer)

    for epoch in range(1, args.epochs + 1):
        print("====epoch " + str(epoch))
        train(args, model, device, loader, optimizer)

    if not args.output_model_file == "":
        torch.save(model.gnn.state_dict(), args.output_model_file + ".pth")
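# train() is called above but not defined in this section. The sketch below is one plausible
# implementation, assuming the usual convention in this codebase that batch.y holds +1/-1
# labels with 0 for missing entries, and that torch.nn (as nn), tqdm, and torch are imported;
# it is an illustration, not the repository's confirmed training loop.
def train(args, model, device, loader, optimizer):
    model.train()
    criterion = nn.BCEWithLogitsLoss(reduction="none")
    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        batch = batch.to(device)
        pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        y = batch.y.view(pred.shape).to(torch.float64)
        is_valid = y ** 2 > 0                              # mask out missing labels
        loss_mat = criterion(pred.double(), (y + 1) / 2)   # map {-1, 1} -> {0, 1}
        loss_mat = torch.where(is_valid, loss_mat, torch.zeros_like(loss_mat))
        optimizer.zero_grad()
        loss = torch.sum(loss_mat) / torch.sum(is_valid)
        loss.backward()
        optimizer.step()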
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device', type=int, default=0, help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size', type=int, default=32, help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=50, help='number of epochs to train (default: 50)')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate (default: 0.001)')
    parser.add_argument('--decay', type=float, default=0, help='weight decay (default: 0)')
    parser.add_argument('--num_layer', type=int, default=5, help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim', type=int, default=300, help='embedding dimensions (default: 300)')
    parser.add_argument('--dropout_ratio', type=float, default=0.5, help='dropout ratio (default: 0.5)')
    parser.add_argument('--graph_pooling', type=str, default="mean", help='graph level pooling (sum, mean, max, set2set, attention)')
    parser.add_argument('--JK', type=str, default="last", help='how the node features across layers are combined. last, sum, max or concat')
    parser.add_argument('--model_file', type=str, default='', help='filename to read the model (if there is any)')
    parser.add_argument('--filename', type=str, default='', help='output filename')
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument('--seed', type=int, default=42, help="Seed for splitting dataset.")
    parser.add_argument('--runseed', type=int, default=0, help="Seed for running experiments.")
    parser.add_argument('--num_workers', type=int, default=0, help='number of workers for dataset loading')
    parser.add_argument('--eval_train', type=int, default=0, help='evaluating training or not')
    parser.add_argument('--split', type=str, default="species", help='Random or species split')
    args = parser.parse_args()

    torch.manual_seed(args.runseed)
    np.random.seed(args.runseed)
    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.runseed)

    root_supervised = 'dataset/supervised'
    dataset = BioDataset(root_supervised, data_type='supervised')
    print(dataset)

    # Report the average number of nodes and edges per graph.
    # (The original code had an `assert False` after these prints, which stopped the
    # script here; it has been removed so that fine-tuning actually runs.)
    node_num = 0
    edge_num = 0
    for d in dataset:
        node_num += d.x.size()[0]
        edge_num += d.edge_index.size()[1]
    print(node_num / len(dataset))
    print(edge_num / len(dataset))

    if args.split == "random":
        print("random splitting")
        train_dataset, valid_dataset, test_dataset = random_split(dataset, seed=args.seed)
    elif args.split == "species":
        trainval_dataset, test_dataset = species_split(dataset)
        train_dataset, valid_dataset, _ = random_split(trainval_dataset, seed=args.seed,
                                                       frac_train=0.85, frac_valid=0.15, frac_test=0)
        test_dataset_broad, test_dataset_none, _ = random_split(test_dataset, seed=args.seed,
                                                                frac_train=0.5, frac_valid=0.5, frac_test=0)
        print("species splitting")
    else:
        raise ValueError("Unknown split name.")

    train_loader = DataLoaderFinetune(train_dataset, batch_size=args.batch_size,
                                      shuffle=True, num_workers=args.num_workers)
    val_loader = DataLoaderFinetune(valid_dataset, batch_size=10 * args.batch_size,
                                    shuffle=False, num_workers=args.num_workers)
    if args.split == "random":
        test_loader = DataLoaderFinetune(test_dataset, batch_size=10 * args.batch_size,
                                         shuffle=False, num_workers=args.num_workers)
    else:
        # for species splitting
        test_easy_loader = DataLoaderFinetune(test_dataset_broad, batch_size=10 * args.batch_size,
                                              shuffle=False, num_workers=args.num_workers)
        test_hard_loader = DataLoaderFinetune(test_dataset_none, batch_size=10 * args.batch_size,
                                              shuffle=False, num_workers=args.num_workers)

    num_tasks = len(dataset[0].go_target_downstream)
    print(train_dataset[0])

    # set up model
    model = GNN_graphpred(args.num_layer, args.emb_dim, num_tasks, JK=args.JK,
                          drop_ratio=args.dropout_ratio, graph_pooling=args.graph_pooling,
                          gnn_type=args.gnn_type)
    if not args.model_file == "":
        model.from_pretrained(args.model_file)
    model.to(device)

    # set up optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)

    train_acc_list = []
    val_acc_list = []
    # for random splitting
    test_acc_list = []
    # for species splitting
    test_acc_easy_list = []
    test_acc_hard_list = []

    if not args.filename == "":
        if os.path.exists(args.filename):
            print("removed existing file!!")
            os.remove(args.filename)

    for epoch in range(1, args.epochs + 1):
        print("====epoch " + str(epoch))
        train(args, model, device, train_loader, optimizer)

        print("====Evaluation")
        if args.eval_train:
            train_acc = eval(args, model, device, train_loader)
        else:
            train_acc = 0
            print("omitting training evaluation")
        val_acc = eval(args, model, device, val_loader)
        val_acc_list.append(np.mean(val_acc))
        train_acc_list.append(train_acc)

        if args.split == "random":
            test_acc = eval(args, model, device, test_loader)
            test_acc_list.append(test_acc)
        else:
            test_acc_easy = eval(args, model, device, test_easy_loader)
            test_acc_hard = eval(args, model, device, test_hard_loader)
            test_acc_easy_list.append(np.mean(test_acc_easy))
            test_acc_hard_list.append(np.mean(test_acc_hard))
            print(val_acc_list[-1])
            print(test_acc_easy_list[-1])
            print(test_acc_hard_list[-1])
        print("")

    with open('result.log', 'a+') as f:
        f.write(str(args.runseed) + ' '
                + str(np.array(test_acc_easy_list)[np.array(val_acc_list).argmax()]) + ' '
                + str(np.array(test_acc_hard_list)[np.array(val_acc_list).argmax()]))
        f.write('\n')
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device', type=int, default=0, help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size', type=int, default=32, help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate (default: 0.001)')
    parser.add_argument('--lr_scale', type=float, default=1, help='relative learning rate for the feature extraction layer (default: 1)')
    parser.add_argument('--decay', type=float, default=0, help='weight decay (default: 0)')
    parser.add_argument('--num_layer', type=int, default=5, help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim', type=int, default=300, help='embedding dimensions (default: 300)')
    parser.add_argument('--dropout_ratio', type=float, default=0.5, help='dropout ratio (default: 0.5)')
    parser.add_argument('--graph_pooling', type=str, default="mean", help='graph level pooling (sum, mean, max, set2set, attention)')
    parser.add_argument('--JK', type=str, default="last", help='how the node features across layers are combined. last, sum, max or concat')
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument('--dataset', type=str, default='tox21', help='root directory of dataset. For now, only classification.')
    parser.add_argument('--input_model_file', type=str, default='', help='filename to read the model (if there is any)')
    parser.add_argument('--filename', type=str, default='', help='output filename')
    parser.add_argument('--seed', type=int, default=42, help="Seed for splitting the dataset.")
    parser.add_argument('--runseed', type=int, default=0, help="Seed for minibatch selection, random initialization.")
    parser.add_argument('--split', type=str, default="scaffold", help="random or scaffold or random_scaffold")
    parser.add_argument('--eval_train', type=int, default=0, help='evaluating training or not')
    parser.add_argument('--num_workers', type=int, default=4, help='number of workers for dataset loading')
    args = parser.parse_args()

    torch.manual_seed(args.runseed)
    np.random.seed(args.runseed)
    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.runseed)

    # Bunch of classification tasks
    num_tasks_dict = {
        "tox21": 12, "hiv": 1, "pcba": 128, "muv": 17, "bace": 1,
        "bbbp": 1, "toxcast": 617, "sider": 27, "clintox": 2,
    }
    if args.dataset not in num_tasks_dict:
        raise ValueError("Invalid dataset name.")
    num_tasks = num_tasks_dict[args.dataset]

    # set up dataset
    dataset = MoleculeDataset("dataset/" + args.dataset, dataset=args.dataset)
    print(dataset)

    if args.split == "scaffold":
        smiles_list = pd.read_csv('dataset/' + args.dataset + '/processed/smiles.csv', header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = scaffold_split(
            dataset, smiles_list, null_value=0, frac_train=0.8, frac_valid=0.1, frac_test=0.1)
        print("scaffold")
    elif args.split == "random":
        train_dataset, valid_dataset, test_dataset = random_split(
            dataset, null_value=0, frac_train=0.8, frac_valid=0.1, frac_test=0.1, seed=args.seed)
        print("random")
    elif args.split == "random_scaffold":
        smiles_list = pd.read_csv('dataset/' + args.dataset + '/processed/smiles.csv', header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = random_scaffold_split(
            dataset, smiles_list, null_value=0, frac_train=0.8, frac_valid=0.1, frac_test=0.1, seed=args.seed)
        print("random scaffold")
    else:
        raise ValueError("Invalid split option.")

    print(train_dataset[0])

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    val_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    # set up model
    model = GNN_graphpred(args.num_layer, args.emb_dim, num_tasks, JK=args.JK,
                          drop_ratio=args.dropout_ratio, graph_pooling=args.graph_pooling,
                          gnn_type=args.gnn_type)
    if not args.input_model_file == "":
        model.from_pretrained(args.input_model_file)
    model.to(device)

    # set up optimizer
    # different learning rate for different part of GNN
    model_param_group = []
    model_param_group.append({"params": model.gnn.parameters()})
    if args.graph_pooling == "attention":
        model_param_group.append({"params": model.pool.parameters(), "lr": args.lr * args.lr_scale})
    model_param_group.append({"params": model.graph_pred_linear.parameters(), "lr": args.lr * args.lr_scale})
    optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
    print(optimizer)

    train_acc_list = []
    val_acc_list = []
    test_acc_list = []

    if not args.filename == "":
        fname = 'runs/finetune_cls_runseed' + str(args.runseed) + '/' + args.filename
        # delete the directory if there exists one
        if os.path.exists(fname):
            shutil.rmtree(fname)
            print("removed the existing file.")
        writer = SummaryWriter(fname)

    for epoch in range(1, args.epochs + 1):
        print("====epoch " + str(epoch))
        train(args, model, device, train_loader, optimizer)

        print("====Evaluation")
        if args.eval_train:
            train_acc = eval(args, model, device, train_loader)
        else:
            print("omit the training accuracy computation")
            train_acc = 0
        val_acc = eval(args, model, device, val_loader)
        test_acc = eval(args, model, device, test_loader)

        print("train: %f val: %f test: %f" % (train_acc, val_acc, test_acc))

        val_acc_list.append(val_acc)
        test_acc_list.append(test_acc)
        train_acc_list.append(train_acc)

        if not args.filename == "":
            writer.add_scalar('data/train auc', train_acc, epoch)
            writer.add_scalar('data/val auc', val_acc, epoch)
            writer.add_scalar('data/test auc', test_acc, epoch)
        print("")

    if not args.filename == "":
        writer.close()
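# eval() above is used as if it returns a single scalar (an AUC) per loader, but it is not
# defined in this section. The sketch below is one plausible implementation of a per-task
# ROC-AUC averaged over tasks that have both classes present, assuming the same +1/-1/0
# label convention as in training and that numpy (np) and sklearn.metrics.roc_auc_score are
# imported; it is an illustration, not the repository's confirmed code.
def eval(args, model, device, loader):
    model.eval()
    y_true, y_scores = [], []
    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        batch = batch.to(device)
        with torch.no_grad():
            pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        y_true.append(batch.y.view(pred.shape))
        y_scores.append(pred)
    y_true = torch.cat(y_true, dim=0).cpu().numpy()
    y_scores = torch.cat(y_scores, dim=0).cpu().numpy()

    roc_list = []
    for i in range(y_true.shape[1]):
        # AUC is only defined when both classes are present for task i.
        if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == -1) > 0:
            is_valid = y_true[:, i] ** 2 > 0
            roc_list.append(roc_auc_score((y_true[is_valid, i] + 1) / 2, y_scores[is_valid, i]))
    return sum(roc_list) / len(roc_list)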
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description="PyTorch implementation of pre-training of graph neural networks")
    parser.add_argument("--device", type=int, default=0, help="which gpu to use if any (default: 0)")
    parser.add_argument("--batch_size", type=int, default=32, help="input batch size for training (default: 32)")
    parser.add_argument("--epochs", type=int, default=100, help="number of epochs to train (default: 100)")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate (default: 0.001)")
    parser.add_argument("--lr_scale", type=float, default=1, help="relative learning rate for the feature extraction layer (default: 1)")
    parser.add_argument("--decay", type=float, default=0, help="weight decay (default: 0)")
    parser.add_argument("--num_layer", type=int, default=5, help="number of GNN message passing layers (default: 5).")
    parser.add_argument("--node_feat_dim", type=int, default=154, help="dimension of the node features.")
    parser.add_argument("--edge_feat_dim", type=int, default=2, help="dimension of the edge features.")
    parser.add_argument("--emb_dim", type=int, default=256, help="embedding dimensions (default: 256)")
    parser.add_argument("--dropout_ratio", type=float, default=0.5, help="dropout ratio (default: 0.5)")
    parser.add_argument("--graph_pooling", type=str, default="mean", help="graph level pooling (sum, mean, max, set2set, attention)")
    parser.add_argument("--JK", type=str, default="last", help="how the node features across layers are combined. last, sum, max or concat")
    parser.add_argument("--gnn_type", type=str, default="gine")
    parser.add_argument("--dataset", type=str, default="bbbp", help="root directory of dataset. For now, only classification.")
    parser.add_argument("--input_model_file", type=str, default="", help="filename to read the model (if there is any)")
    parser.add_argument("--filename", type=str, default="", help="output filename")
    parser.add_argument("--seed", type=int, default=42, help="Seed for splitting the dataset.")
    parser.add_argument("--runseed", type=int, default=0, help="Seed for minibatch selection, random initialization.")
    parser.add_argument("--split", type=str, default="scaffold", help="random or scaffold or random_scaffold")
    parser.add_argument("--eval_train", type=int, default=0, help="evaluating training or not")
    parser.add_argument("--num_workers", type=int, default=4, help="number of workers for dataset loading")
    parser.add_argument("--use_original", type=int, default=0, help="run benchmark experiment or not")
    # parser.add_argument('--output_model_file', type=str, default='finetuned_model/amu',
    #                     help='filename to output the finetuned model')
    args = parser.parse_args()

    torch.manual_seed(args.runseed)
    np.random.seed(args.runseed)
    device = (torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu"))
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.runseed)

    # Bunch of classification tasks
    num_tasks_dict = {
        "tox21": 12, "hiv": 1, "pcba": 128, "muv": 17, "bace": 1, "bbbp": 1,
        "toxcast": 617, "sider": 27, "clintox": 2,
        "jak1": 1, "jak2": 1, "jak3": 1, "amu": 1, "ellinger": 1, "mpro": 1,
    }
    if args.dataset not in num_tasks_dict:
        raise ValueError("Invalid dataset name.")
    num_tasks = num_tasks_dict[args.dataset]

    # set up dataset
    # dataset = MoleculeDataset("contextPred/chem/dataset/" + args.dataset, dataset=args.dataset)
    dataset = MoleculeDataset(root="/raid/home/public/dataset_ContextPred_0219/" + args.dataset)
    if args.use_original == 0:
        dataset = MoleculeDataset(
            root="/raid/home/public/dataset_ContextPred_0219/" + args.dataset,
            transform=ONEHOT_ENCODING(dataset=dataset),
        )
    print(dataset)

    if args.split == "scaffold":
        smiles_list = pd.read_csv("/raid/home/public/dataset_ContextPred_0219/" + args.dataset + "/processed/smiles.csv",
                                  header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = scaffold_split(
            dataset, smiles_list, null_value=0, frac_train=0.8, frac_valid=0.1, frac_test=0.1)
        print("scaffold")
    elif args.split == "oversample":
        train_dataset, valid_dataset, test_dataset = oversample_split(
            dataset, null_value=0, frac_train=0.8, frac_valid=0.1, frac_test=0.1, seed=args.seed)
        print("oversample")
    elif args.split == "random":
        train_dataset, valid_dataset, test_dataset = random_split(
            dataset, null_value=0, frac_train=0.8, frac_valid=0.1, frac_test=0.1, seed=args.seed)
        print("random")
    elif args.split == "random_scaffold":
        smiles_list = pd.read_csv("/raid/home/public/dataset_ContextPred_0219/" + args.dataset + "/processed/smiles.csv",
                                  header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = random_scaffold_split(
            dataset, smiles_list, null_value=0, frac_train=0.8, frac_valid=0.1, frac_test=0.1, seed=args.seed)
        print("random scaffold")
    else:
        raise ValueError("Invalid split option.")

    print(train_dataset[0])

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    val_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    # set up model
    model = GNN_graphpred(
        args.num_layer,
        args.node_feat_dim,
        args.edge_feat_dim,
        args.emb_dim,
        num_tasks,
        JK=args.JK,
        drop_ratio=args.dropout_ratio,
        graph_pooling=args.graph_pooling,
        gnn_type=args.gnn_type,
        use_embedding=args.use_original,
    )
    if not args.input_model_file == "":
        model.from_pretrained(args.input_model_file + ".pth")
    model.to(device)

    # set up optimizer
    # different learning rate for different part of GNN
    model_param_group = []
    model_param_group.append({"params": model.gnn.parameters()})
    if args.graph_pooling == "attention":
        model_param_group.append({"params": model.pool.parameters(), "lr": args.lr * args.lr_scale})
    model_param_group.append({"params": model.graph_pred_linear.parameters(), "lr": args.lr * args.lr_scale})
    optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
    print(optimizer)

    train_roc_list, train_acc_list, train_f1_list, train_ap_list = [], [], [], []
    val_roc_list, val_acc_list, val_f1_list, val_ap_list = [], [], [], []
    test_roc_list, test_acc_list, test_f1_list, test_ap_list = [], [], [], []

    if not args.filename == "":
        fname = "/raid/home/yoyowu/Weihua_b/BASE_TFlogs/" + str(args.runseed) + "/" + args.filename
        # delete the directory if there exists one
        if os.path.exists(fname):
            shutil.rmtree(fname)
            print("removed the existing file.")
        writer = SummaryWriter(fname)

    for epoch in range(1, args.epochs + 1):
        print("====epoch " + str(epoch))
        train(args, model, device, train_loader, optimizer)
        # if not args.output_model_file == "":
        #     torch.save(model.state_dict(), args.output_model_file + str(epoch) + ".pth")

        print("====Evaluation")
        if args.eval_train:
            (train_roc, train_acc, train_f1, train_ap,
             train_num_positive_true, train_num_positive_scores) = eval(args, model, device, train_loader)
        else:
            print("omit the training accuracy computation")
            train_roc = train_acc = train_f1 = train_ap = 0
        (val_roc, val_acc, val_f1, val_ap,
         val_num_positive_true, val_num_positive_scores) = eval(args, model, device, val_loader)
        (test_roc, test_acc, test_f1, test_ap,
         test_num_positive_true, test_num_positive_scores) = eval(args, model, device, test_loader)

        # with open('debug_ellinger.txt', "a") as f:
        #     f.write("====epoch " + str(epoch) + " \n training: positive true count {} , positive scores count {} \n".format(
        #         train_num_positive_true, train_num_positive_scores))
        #     f.write("val: positive true count {} , positive scores count {} \n".format(
        #         val_num_positive_true, val_num_positive_scores))
        #     f.write("test: positive true count {} , positive scores count {} \n".format(
        #         test_num_positive_true, test_num_positive_scores))
        #     f.write("\n")

        print("train: %f val: %f test auc: %f " % (train_roc, val_roc, test_roc))

        val_roc_list.append(val_roc)
        val_f1_list.append(val_f1)
        val_acc_list.append(val_acc)
        val_ap_list.append(val_ap)
        test_acc_list.append(test_acc)
        test_roc_list.append(test_roc)
        test_f1_list.append(test_f1)
        test_ap_list.append(test_ap)
        train_acc_list.append(train_acc)
        train_roc_list.append(train_roc)
        train_f1_list.append(train_f1)
        train_ap_list.append(train_ap)

        if not args.filename == "":
            writer.add_scalar("data/train roc", train_roc, epoch)
            writer.add_scalar("data/train acc", train_acc, epoch)
            writer.add_scalar("data/train f1", train_f1, epoch)
            writer.add_scalar("data/train ap", train_ap, epoch)
            writer.add_scalar("data/val roc", val_roc, epoch)
            writer.add_scalar("data/val acc", val_acc, epoch)
            writer.add_scalar("data/val f1", val_f1, epoch)
            writer.add_scalar("data/val ap", val_ap, epoch)
            writer.add_scalar("data/test roc", test_roc, epoch)
            writer.add_scalar("data/test acc", test_acc, epoch)
            writer.add_scalar("data/test f1", test_f1, epoch)
            writer.add_scalar("data/test ap", test_ap, epoch)
        print("")

    if not args.filename == "":
        writer.close()
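# The eval() used above returns six values: ROC-AUC, accuracy, F1, average precision, and
# the counts of positive labels and positive predictions. It is not defined in this section;
# the sketch below is one plausible single-task, 0/1-label implementation using
# sklearn.metrics (roc_auc_score, accuracy_score, f1_score, average_precision_score assumed
# imported), shown for illustration only.
def eval(args, model, device, loader):
    model.eval()
    y_true, y_scores = [], []
    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        batch = batch.to(device)
        with torch.no_grad():
            pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        y_true.append(batch.y.view(pred.shape))
        y_scores.append(torch.sigmoid(pred))
    y_true = torch.cat(y_true, dim=0).cpu().numpy().ravel()
    y_scores = torch.cat(y_scores, dim=0).cpu().numpy().ravel()
    y_pred = (y_scores > 0.5).astype(int)

    roc = roc_auc_score(y_true, y_scores)
    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    ap = average_precision_score(y_true, y_scores)
    num_positive_true = int((y_true == 1).sum())
    num_positive_scores = int((y_pred == 1).sum())
    return roc, acc, f1, ap, num_positive_true, num_positive_scores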
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device', type=int, default=0, help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size', type=int, default=32, help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate (default: 0.001)')
    parser.add_argument('--decay', type=float, default=0, help='weight decay (default: 0)')
    parser.add_argument('--num_layer', type=int, default=5, help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim', type=int, default=300, help='embedding dimensions (default: 300)')
    parser.add_argument('--dropout_ratio', type=float, default=0.2, help='dropout ratio (default: 0.2)')
    parser.add_argument('--graph_pooling', type=str, default="mean", help='graph level pooling (sum, mean, max, set2set, attention)')
    parser.add_argument('--JK', type=str, default="last", help='how the node features across layers are combined. last, sum, max or concat')
    parser.add_argument('--input_model_file', type=str, default='', help='filename to read the model (if there is any)')
    parser.add_argument('--output_model_file', type=str, default='', help='filename to output the pre-trained model')
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument('--num_workers', type=int, default=0, help='number of workers for dataset loading')
    parser.add_argument('--seed', type=int, default=42, help="Seed for splitting dataset.")
    parser.add_argument('--split', type=str, default="species", help='Random or species split')
    args = parser.parse_args()

    torch.manual_seed(0)
    np.random.seed(0)
    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    root_supervised = 'dataset/supervised'
    dataset = BioDataset(root_supervised, data_type='supervised')

    if args.split == "random":
        print("random splitting")
        train_dataset, valid_dataset, test_dataset = random_split(dataset, seed=args.seed)
        print(train_dataset)
        print(valid_dataset)
        pretrain_dataset = combine_dataset(train_dataset, valid_dataset)
        print(pretrain_dataset)
    elif args.split == "species":
        print("species splitting")
        trainval_dataset, test_dataset = species_split(dataset)
        test_dataset_broad, test_dataset_none, _ = random_split(test_dataset, seed=args.seed,
                                                                frac_train=0.5, frac_valid=0.5, frac_test=0)
        print(trainval_dataset)
        print(test_dataset_broad)
        pretrain_dataset = combine_dataset(trainval_dataset, test_dataset_broad)
        print(pretrain_dataset)
        # train_dataset, valid_dataset, _ = random_split(trainval_dataset, seed=args.seed,
        #                                                frac_train=0.85, frac_valid=0.15, frac_test=0)
    else:
        raise ValueError("Unknown split name.")

    # train_loader = DataLoader(pretrain_dataset, batch_size=args.batch_size, shuffle=True,
    #                           num_workers=args.num_workers)
    # (Note) Fixed a bug here: DataLoaderFinetune must be used so that center_node_idx is
    # incremented correctly when graphs are batched. The results in the paper were obtained
    # with the original PyTorch Geometric dataloader, so results with the corrected
    # dataloader may differ slightly.
    train_loader = DataLoaderFinetune(pretrain_dataset, batch_size=args.batch_size,
                                      shuffle=True, num_workers=args.num_workers)

    num_tasks = len(pretrain_dataset[0].go_target_pretrain)

    # set up model
    model = GNN_graphpred(args.num_layer, args.emb_dim, num_tasks, JK=args.JK,
                          drop_ratio=args.dropout_ratio, graph_pooling=args.graph_pooling,
                          gnn_type=args.gnn_type)
    if not args.input_model_file == "":
        model.from_pretrained(args.input_model_file + ".pth")
    model.to(device)

    # set up optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)
    print(optimizer)

    for epoch in range(1, args.epochs + 1):
        print("====epoch " + str(epoch))
        train_loss = train(args, model, device, train_loader, optimizer)

    if not args.output_model_file == "":
        torch.save(model.gnn.state_dict(), args.output_model_file + ".pth")
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device', type=int, default=0, help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size', type=int, default=32, help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate (default: 0.001)')
    parser.add_argument('--lr_scale', type=float, default=1, help='relative learning rate for the feature extraction layer (default: 1)')
    parser.add_argument('--decay', type=float, default=0, help='weight decay (default: 0)')
    parser.add_argument('--num_layer', type=int, default=3, help='number of GNN message passing layers (default: 3).')
    parser.add_argument('--emb_dim', type=int, default=512, help='embedding dimensions (default: 512)')
    parser.add_argument('--dropout_ratio', type=float, default=0.5, help='dropout ratio (default: 0.5)')
    parser.add_argument('--graph_pooling', type=str, default="mean", help='graph level pooling (sum, mean, max, set2set, attention)')
    parser.add_argument('--JK', type=str, default="last", help='how the node features across layers are combined. last, sum, max or concat')
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument('--dataset', type=str, default='esol', help='root directory of dataset. For now, only classification.')
    parser.add_argument('--input_model_file', type=str, default='', help='filename to read the model (if there is any)')
    parser.add_argument('--filename', type=str, default='', help='output filename')
    parser.add_argument('--seed', type=int, default=0, help="Seed for splitting the dataset.")
    parser.add_argument('--runseed', type=int, default=1, help="Seed for minibatch selection, random initialization.")
    parser.add_argument('--split', type=str, default="random", help="random or scaffold or random_scaffold")
    parser.add_argument('--eval_train', type=int, default=1, help='evaluating training or not')
    parser.add_argument('--num_workers', type=int, default=4, help='number of workers for dataset loading')
    parser.add_argument('--aug1', type=str, default='dropN_random', help='augmentation1')
    parser.add_argument('--aug2', type=str, default='dropN_random', help='augmentation2')
    parser.add_argument('--aug_ratio1', type=float, default=0.0, help='aug ratio1')
    parser.add_argument('--aug_ratio2', type=float, default=0.0, help='aug ratio2')
    parser.add_argument('--dataset_load', type=str, default='esol', help='load pretrain model from which dataset.')
    parser.add_argument('--protocol', type=str, default='linear', help='downstream protocol, linear, nonlinear')
    parser.add_argument('--semi_ratio', type=float, default=1.0, help='proportion of labels in semi-supervised settings')
    parser.add_argument('--pretrain_method', type=str, default='local', help='pretrain_method: local, global')
    parser.add_argument('--lamb', type=float, default=0.0, help='hyper para of global-structure loss')
    parser.add_argument('--n_nb', type=int, default=0, help='number of neighbors for global-structure loss')
    args = parser.parse_args()

    torch.manual_seed(args.runseed)
    np.random.seed(args.runseed)
    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.runseed)

    if args.dataset in ['tox21', 'hiv', 'pcba', 'muv', 'bace', 'bbbp', 'toxcast', 'sider', 'clintox', 'mutag']:
        task_type = 'cls'
    else:
        task_type = 'reg'

    # Bunch of classification (and regression) tasks
    num_tasks_dict = {
        "tox21": 12, "hiv": 1, "pcba": 128, "muv": 17, "bace": 1, "bbbp": 1,
        "toxcast": 617, "sider": 27, "clintox": 2, "esol": 1, "freesolv": 1, "mutag": 1,
    }
    if args.dataset not in num_tasks_dict:
        raise ValueError("Invalid dataset name.")
    num_tasks = num_tasks_dict[args.dataset]

    # set up dataset
    dataset = MoleculeDataset("dataset/" + args.dataset, dataset=args.dataset)
    print('The whole dataset:', dataset)

    if args.split == "scaffold":
        smiles_list = pd.read_csv('dataset/' + args.dataset + '/processed/smiles.csv', header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = scaffold_split(
            dataset, smiles_list, null_value=0, frac_train=0.8, frac_valid=0.1, frac_test=0.1)
        print("scaffold")
    elif args.split == "random":
        train_dataset, valid_dataset, test_dataset = random_split(
            dataset, null_value=0, frac_train=0.8, frac_valid=0.1, frac_test=0.1, seed=args.seed)
        print("random")
    elif args.split == "random_scaffold":
        smiles_list = pd.read_csv('dataset/' + args.dataset + '/processed/smiles.csv', header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = random_scaffold_split(
            dataset, smiles_list, null_value=0, frac_train=0.8, frac_valid=0.1, frac_test=0.1, seed=args.seed)
        print("random scaffold")
    else:
        raise ValueError("Invalid split option.")

    # semi-supervised settings
    if args.semi_ratio != 1.0:
        n_total, n_sample = len(train_dataset), int(len(train_dataset) * args.semi_ratio)
        print('sample {:.2f} = {:d} labels for semi-supervised training!'.format(args.semi_ratio, n_sample))
        all_idx = list(range(n_total))
        random.seed(0)
        idx_semi = random.sample(all_idx, n_sample)
        train_dataset = train_dataset[torch.tensor(idx_semi)]
        print('new train dataset size:', len(train_dataset))
    else:
        print('finetune using all data!')

    if args.dataset == 'freesolv':
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                  num_workers=args.num_workers, drop_last=True)
    else:
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                  num_workers=args.num_workers)
    val_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    if args.pretrain_method == 'local':
        load_dir = 'results/' + args.dataset + '/pretrain_local/'
        save_dir = 'results/' + args.dataset + '/finetune_local/'
    elif args.pretrain_method == 'global':
        load_dir = 'results/' + args.dataset + '/pretrain_global/nb_' + str(args.n_nb) + '/'
        save_dir = 'results/' + args.dataset + '/finetune_global/nb_' + str(args.n_nb) + '/'
    else:
        # the original only printed a warning here, which left load_dir/save_dir undefined
        raise ValueError("Invalid pretrain_method!")

    if not os.path.exists(save_dir):
        os.system('mkdir -p %s' % save_dir)

    if not args.input_model_file == "":
        input_model_str = (args.dataset_load + '_aug1_' + args.aug1 + '_' + str(args.aug_ratio1)
                           + '_aug2_' + args.aug2 + '_' + str(args.aug_ratio2)
                           + '_lamb_' + str(args.lamb) + '_do_' + str(args.dropout_ratio)
                           + '_seed_' + str(args.runseed))
        output_model_str = (args.dataset + '_semi_' + str(args.semi_ratio) + '_protocol_' + args.protocol
                            + '_aug1_' + args.aug1 + '_' + str(args.aug_ratio1)
                            + '_aug2_' + args.aug2 + '_' + str(args.aug_ratio2)
                            + '_lamb_' + str(args.lamb) + '_do_' + str(args.dropout_ratio)
                            + '_seed_' + str(args.runseed) + '_' + str(args.seed))
    else:
        output_model_str = ('scratch_' + args.dataset + '_semi_' + str(args.semi_ratio)
                            + '_protocol_' + args.protocol + '_do_' + str(args.dropout_ratio)
                            + '_seed_' + str(args.runseed) + '_' + str(args.seed))

    txtfile = save_dir + output_model_str + ".txt"
    nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
    if os.path.exists(txtfile):
        # back up the existing file to avoid a name collision
        os.system('mv %s %s' % (txtfile, txtfile + ".bak-%s" % nowTime))

    # set up model
    model = GNN_graphpred(args.num_layer, args.emb_dim, num_tasks, JK=args.JK,
                          drop_ratio=args.dropout_ratio, graph_pooling=args.graph_pooling,
                          gnn_type=args.gnn_type)
    if not args.input_model_file == "":
        model.from_pretrained(load_dir + args.input_model_file + input_model_str + '.pth')
        print('successfully load pretrained model!')
    else:
        print('No pretrain! train from scratch!')

    model.to(device)

    # set up optimizer
    # different learning rate for different part of GNN
    model_param_group = []
    model_param_group.append({"params": model.gnn.parameters()})
    if args.graph_pooling == "attention":
        model_param_group.append({"params": model.pool.parameters(), "lr": args.lr * args.lr_scale})
    model_param_group.append({"params": model.graph_pred_linear.parameters(), "lr": args.lr * args.lr_scale})
    optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
    print(optimizer)

    # if linear protocol, freeze the GNN layers and train only the prediction head
    if args.protocol == 'linear':
        print("linear protocol, only train the top layer!")
        for name, param in model.named_parameters():
            if not 'pred_linear' in name:
                param.requires_grad = False
    elif args.protocol == 'nonlinear':
        print("finetune protocol, train all the layers!")
    else:
        print("invalid protocol!")

    # all task info summary
    print('=========task summary=========')
    print('Dataset: ', args.dataset)
    if args.semi_ratio == 1.0:
        print('full-supervised {:.2f}'.format(args.semi_ratio))
    else:
        print('semi-supervised {:.2f}'.format(args.semi_ratio))
    if args.input_model_file == '':
        print('scratch or finetune: scratch')
        print('loaded model from: - ')
    else:
        print('scratch or finetune: finetune')
        print('loaded model from: ', args.dataset_load)
        print('global_mode: n_nb = ', args.n_nb)
    print('Protocol: ', args.protocol)
    print('task type:', task_type)
    print('=========task summary=========')

    # training, depending on the task type
    if task_type == 'cls':
        with open(txtfile, "a") as myfile:
            myfile.write('epoch: train_auc val_auc test_auc\n')
        wait = 0
        best_auc = 0
        patience = 10
        for epoch in range(1, args.epochs + 1):
            print("====epoch " + str(epoch))
            train_cls(args, model, device, train_loader, optimizer)
            print("====Evaluation")
            if args.eval_train:
                train_auc = eval_cls(args, model, device, train_loader)
            else:
                print("omit the training accuracy computation")
                train_auc = 0
            val_auc = eval_cls(args, model, device, val_loader)
            test_auc = eval_cls(args, model, device, test_loader)
            with open(txtfile, "a") as myfile:
                myfile.write(str(int(epoch)) + ': ' + str(train_auc) + ' ' + str(val_auc) + ' ' + str(test_auc) + "\n")
            print("train: %f val: %f test: %f" % (train_auc, val_auc, test_auc))
            # Early stopping
            if np.greater(val_auc, best_auc):  # change for train loss
                best_auc = val_auc
                wait = 0
            else:
                wait += 1
                if wait >= patience:
                    print('Early stop at Epoch: {:d} with final val auc: {:.4f}'.format(epoch, val_auc))
                    break
    elif task_type == 'reg':
        with open(txtfile, "a") as myfile:
            myfile.write('epoch: train_mse train_cor val_mse val_cor test_mse test_cor\n')
        for epoch in range(1, args.epochs + 1):
            print("====epoch " + str(epoch))
            train(args, model, device, train_loader, optimizer)
            print("====Evaluation")
            if args.eval_train:
                train_mse, train_cor = eval_reg(args, model, device, train_loader)
            else:
                print("omit the training accuracy computation")
                train_mse, train_cor = 0, 0
            val_mse, val_cor = eval_reg(args, model, device, val_loader)
            test_mse, test_cor = eval_reg(args, model, device, test_loader)
            with open(txtfile, "a") as myfile:
                myfile.write(str(int(epoch)) + ': ' + str(train_mse) + ' ' + str(train_cor) + ' '
                             + str(val_mse) + ' ' + str(val_cor) + ' '
                             + str(test_mse) + ' ' + str(test_cor) + "\n")
            print("train: %f val: %f test: %f" % (train_mse, val_mse, test_mse))
            print("train: %f val: %f test: %f" % (train_cor, val_cor, test_cor))
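# eval_reg() above is not defined in this section. The sketch below is one plausible
# implementation returning mean squared error and Pearson correlation between predictions
# and labels, assuming scipy.stats.pearsonr and numpy (np) are imported; it only illustrates
# the interface used above and is not the repository's confirmed code.
def eval_reg(args, model, device, loader):
    model.eval()
    y_true, y_pred = [], []
    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        batch = batch.to(device)
        with torch.no_grad():
            pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        y_true.append(batch.y.view(pred.shape))
        y_pred.append(pred)
    y_true = torch.cat(y_true, dim=0).cpu().numpy().ravel()
    y_pred = torch.cat(y_pred, dim=0).cpu().numpy().ravel()
    mse = float(np.mean((y_true - y_pred) ** 2))
    cor = float(pearsonr(y_true, y_pred)[0])
    return mse, cor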
class Meta_model(nn.Module):
    def __init__(self, args):
        super(Meta_model, self).__init__()

        self.dataset = args.dataset
        self.num_tasks = args.num_tasks
        self.num_train_tasks = args.num_train_tasks
        self.num_test_tasks = args.num_test_tasks
        self.n_way = args.n_way
        self.m_support = args.m_support
        self.k_query = args.k_query

        self.gnn_type = args.gnn_type
        self.emb_dim = args.emb_dim
        self.device = args.device

        self.add_similarity = args.add_similarity
        self.add_selfsupervise = args.add_selfsupervise
        self.add_masking = args.add_masking
        self.add_weight = args.add_weight
        self.interact = args.interact

        self.batch_size = args.batch_size
        self.meta_lr = args.meta_lr
        self.update_lr = args.update_lr
        self.update_step = args.update_step
        self.update_step_test = args.update_step_test

        self.criterion = nn.BCEWithLogitsLoss()

        self.graph_model = GNN_graphpred(args.num_layer, args.emb_dim, 1, JK=args.JK,
                                         drop_ratio=args.dropout_ratio,
                                         graph_pooling=args.graph_pooling,
                                         gnn_type=args.gnn_type)
        if not args.input_model_file == "":
            self.graph_model.from_pretrained(args.input_model_file)

        if self.add_selfsupervise:
            self.self_criterion = nn.BCEWithLogitsLoss()

        if self.add_masking:
            self.masking_criterion = nn.CrossEntropyLoss()
            self.masking_linear = nn.Linear(self.emb_dim, 119)

        if self.add_similarity:
            self.Attention = attention(self.emb_dim)

        if self.interact:
            self.softmax = nn.Softmax(dim=0)
            self.Interact_attention = Interact_attention(self.emb_dim, self.num_train_tasks)

        model_param_group = []
        model_param_group.append({"params": self.graph_model.gnn.parameters()})
        if args.graph_pooling == "attention":
            model_param_group.append({"params": self.graph_model.pool.parameters(),
                                      "lr": args.lr * args.lr_scale})
        model_param_group.append({"params": self.graph_model.graph_pred_linear.parameters(),
                                  "lr": args.lr * args.lr_scale})
        if self.add_masking:
            model_param_group.append({"params": self.masking_linear.parameters()})
        if self.add_similarity:
            model_param_group.append({"params": self.Attention.parameters()})
        if self.interact:
            model_param_group.append({"params": self.Interact_attention.parameters()})

        self.optimizer = optim.Adam(model_param_group, lr=args.meta_lr, weight_decay=args.decay)

        # for name, para in self.named_parameters():
        #     if para.requires_grad:
        #         print(name, para.data.shape)
        # raise TypeError

    def update_params(self, loss, update_lr):
        # One gradient step on the graph model, expressed on flattened parameter vectors.
        grads = torch.autograd.grad(loss, self.graph_model.parameters())
        return parameters_to_vector(grads), parameters_to_vector(
            self.graph_model.parameters()) - parameters_to_vector(grads) * update_lr

    def build_negative_edges(self, batch):
        # Sample non-edges as negatives for the edge-reconstruction self-supervised loss.
        font_list = batch.edge_index[0, ::2].tolist()
        back_list = batch.edge_index[1, ::2].tolist()

        all_edge = {}
        for count, front_e in enumerate(font_list):
            if front_e not in all_edge:
                all_edge[front_e] = [back_list[count]]
            else:
                all_edge[front_e].append(back_list[count])

        negative_edges = []
        for num in range(batch.x.size()[0]):
            if num in all_edge:
                for num_back in range(num, batch.x.size()[0]):
                    if num_back not in all_edge[num] and num != num_back:
                        negative_edges.append((num, num_back))
            else:
                for num_back in range(num, batch.x.size()[0]):
                    if num != num_back:
                        negative_edges.append((num, num_back))

        negative_edge_index = torch.tensor(
            np.array(random.sample(negative_edges, len(font_list))).T, dtype=torch.long)
        return negative_edge_index

    def forward(self, epoch):
        support_loaders = []
        query_loaders = []

        device = torch.device("cuda:" + str(self.device)) if torch.cuda.is_available() else torch.device("cpu")
        self.graph_model.train()

        # tasks_list = random.sample(range(0, self.num_train_tasks), self.batch_size)
        for task in range(self.num_train_tasks):
            # for task in tasks_list:
            dataset = MoleculeDataset("Original_datasets/" + self.dataset + "/new/" + str(task + 1),
                                      dataset=self.dataset)
            support_dataset, query_dataset = sample_datasets(dataset, self.dataset, task,
                                                             self.n_way, self.m_support, self.k_query)
            support_loader = DataLoader(support_dataset, batch_size=self.batch_size,
                                        shuffle=False, num_workers=1)
            query_loader = DataLoader(query_dataset, batch_size=self.batch_size,
                                      shuffle=False, num_workers=1)
            support_loaders.append(support_loader)
            query_loaders.append(query_loader)

        for k in range(0, self.update_step):
            # print(self.fi)
            old_params = parameters_to_vector(self.graph_model.parameters())

            losses_q = torch.tensor([0.0]).to(device)
            # support_params = []
            # support_grads = torch.Tensor(self.num_train_tasks,
            #     parameters_to_vector(self.graph_model.parameters()).size()[0]).to(device)

            for task in range(self.num_train_tasks):
                losses_s = torch.tensor([0.0]).to(device)
                if self.add_similarity or self.interact:
                    one_task_emb = torch.zeros(self.emb_dim).to(device)  # was hard-coded to 300

                # inner-loop adaptation on the support set
                for step, batch in enumerate(tqdm(support_loaders[task], desc="Iteration")):
                    batch = batch.to(device)
                    pred, node_emb = self.graph_model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
                    y = batch.y.view(pred.shape).to(torch.float64)
                    loss = torch.sum(self.criterion(pred.double(), y)) / pred.size()[0]

                    if self.add_selfsupervise:
                        positive_score = torch.sum(
                            node_emb[batch.edge_index[0, ::2]] * node_emb[batch.edge_index[1, ::2]], dim=1)
                        negative_edge_index = self.build_negative_edges(batch)
                        negative_score = torch.sum(
                            node_emb[negative_edge_index[0]] * node_emb[negative_edge_index[1]], dim=1)
                        self_loss = torch.sum(
                            self.self_criterion(positive_score, torch.ones_like(positive_score))
                            + self.self_criterion(negative_score, torch.zeros_like(negative_score))
                        ) / negative_edge_index[0].size()[0]
                        loss += (self.add_weight * self_loss)

                    if self.add_masking:
                        mask_num = random.sample(range(0, node_emb.size()[0]), self.batch_size)
                        pred_emb = self.masking_linear(node_emb[mask_num])
                        loss += (self.add_weight * self.masking_criterion(pred_emb.double(),
                                                                          batch.x[mask_num, 0]))

                    if self.add_similarity or self.interact:
                        one_task_emb = torch.div((one_task_emb + torch.mean(node_emb, 0)), 2.0)

                    losses_s += loss

                if self.add_similarity or self.interact:
                    if task == 0:
                        tasks_emb = []
                    tasks_emb.append(one_task_emb)

                new_grad, new_params = self.update_params(losses_s, update_lr=self.update_lr)
                vector_to_parameters(new_params, self.graph_model.parameters())

                # evaluate the adapted parameters on the query set
                this_loss_q = torch.tensor([0.0]).to(device)
                for step, batch in enumerate(tqdm(query_loaders[task], desc="Iteration")):
                    batch = batch.to(device)
                    pred, node_emb = self.graph_model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
                    y = batch.y.view(pred.shape).to(torch.float64)
                    loss_q = torch.sum(self.criterion(pred.double(), y)) / pred.size()[0]

                    if self.add_selfsupervise:
                        positive_score = torch.sum(
                            node_emb[batch.edge_index[0, ::2]] * node_emb[batch.edge_index[1, ::2]], dim=1)
                        negative_edge_index = self.build_negative_edges(batch)
                        negative_score = torch.sum(
                            node_emb[negative_edge_index[0]] * node_emb[negative_edge_index[1]], dim=1)
                        self_loss = torch.sum(
                            self.self_criterion(positive_score, torch.ones_like(positive_score))
                            + self.self_criterion(negative_score, torch.zeros_like(negative_score))
                        ) / negative_edge_index[0].size()[0]
                        loss_q += (self.add_weight * self_loss)

                    if self.add_masking:
                        mask_num = random.sample(range(0, node_emb.size()[0]), self.batch_size)
                        pred_emb = self.masking_linear(node_emb[mask_num])
                        # note: the original accumulated this term into `loss` (the support-set
                        # variable); it is added to `loss_q` here, which is what the query loss needs
                        loss_q += (self.add_weight * self.masking_criterion(pred_emb.double(),
                                                                            batch.x[mask_num, 0]))

                    this_loss_q += loss_q

                if task == 0:
                    losses_q = this_loss_q
                else:
                    losses_q = torch.cat((losses_q, this_loss_q), 0)

                # restore the meta parameters before adapting to the next task
                vector_to_parameters(old_params, self.graph_model.parameters())

            if self.add_similarity:
                for t_index, one_task_e in enumerate(tasks_emb):
                    if t_index == 0:
                        tasks_emb_new = one_task_e
                    else:
                        tasks_emb_new = torch.cat((tasks_emb_new, one_task_e), 0)
                tasks_emb_new = torch.reshape(tasks_emb_new, (self.num_train_tasks, self.emb_dim))
                tasks_emb_new = tasks_emb_new.detach()
                tasks_weight = self.Attention(tasks_emb_new)
                losses_q = torch.sum(tasks_weight * losses_q)
            elif self.interact:
                for t_index, one_task_e in enumerate(tasks_emb):
                    if t_index == 0:
                        tasks_emb_new = one_task_e
                    else:
                        tasks_emb_new = torch.cat((tasks_emb_new, one_task_e), 0)
                tasks_emb_new = tasks_emb_new.detach()
                represent_emb = self.Interact_attention(tasks_emb_new)
                represent_emb = F.normalize(represent_emb, p=2, dim=0)
                tasks_emb_new = torch.reshape(tasks_emb_new, (self.num_train_tasks, self.emb_dim))
                tasks_emb_new = F.normalize(tasks_emb_new, p=2, dim=1)
                tasks_weight = torch.mm(tasks_emb_new, torch.reshape(represent_emb, (self.emb_dim, 1)))
                print(tasks_weight)
                print(self.softmax(tasks_weight))
                print(losses_q)
                # tasks_emb_new = tasks_emb_new * torch.reshape(represent_emb_m, (self.batch_size, self.emb_dim))
                losses_q = torch.sum(losses_q * torch.transpose(self.softmax(tasks_weight), 1, 0))
                print(losses_q)
            else:
                losses_q = torch.sum(losses_q)

            loss_q = losses_q / self.num_train_tasks

            self.optimizer.zero_grad()
            loss_q.backward()
            self.optimizer.step()

        return []

    def test(self, support_grads):
        accs = []
        old_params = parameters_to_vector(self.graph_model.parameters())

        for task in range(self.num_test_tasks):
            print(self.num_tasks - task)
            dataset = MoleculeDataset("Original_datasets/" + self.dataset + "/new/" + str(self.num_tasks - task),
                                      dataset=self.dataset)
            support_dataset, query_dataset = sample_test_datasets(
                dataset, self.dataset, self.num_tasks - task - 1, self.n_way, self.m_support, self.k_query)
            support_loader = DataLoader(support_dataset, batch_size=self.batch_size,
                                        shuffle=False, num_workers=1)
            query_loader = DataLoader(query_dataset, batch_size=self.batch_size,
                                      shuffle=False, num_workers=1)

            device = torch.device("cuda:" + str(self.device)) if torch.cuda.is_available() else torch.device("cpu")
            self.graph_model.eval()

            # adapt to the support set of the test task
            for k in range(0, self.update_step_test):
                loss = torch.tensor([0.0]).to(device)
                for step, batch in enumerate(tqdm(support_loader, desc="Iteration")):
                    batch = batch.to(device)
                    pred, node_emb = self.graph_model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
                    y = batch.y.view(pred.shape).to(torch.float64)
                    loss += torch.sum(self.criterion(pred.double(), y)) / pred.size()[0]

                    if self.add_selfsupervise:
                        positive_score = torch.sum(
                            node_emb[batch.edge_index[0, ::2]] * node_emb[batch.edge_index[1, ::2]], dim=1)
                        negative_edge_index = self.build_negative_edges(batch)
                        negative_score = torch.sum(
                            node_emb[negative_edge_index[0]] * node_emb[negative_edge_index[1]], dim=1)
                        self_loss = torch.sum(
                            self.self_criterion(positive_score, torch.ones_like(positive_score))
                            + self.self_criterion(negative_score, torch.zeros_like(negative_score))
                        ) / negative_edge_index[0].size()[0]
                        loss += (self.add_weight * self_loss)

                    if self.add_masking:
                        mask_num = random.sample(range(0, node_emb.size()[0]), self.batch_size)
                        pred_emb = self.masking_linear(node_emb[mask_num])
                        loss += (self.add_weight * self.masking_criterion(pred_emb.double(),
                                                                          batch.x[mask_num, 0]))

                print(loss)
                new_grad, new_params = self.update_params(loss, update_lr=self.update_lr)
                # if self.add_similarity:
                #     new_params = self.update_similarity_params(new_grad, support_grads)
                vector_to_parameters(new_params, self.graph_model.parameters())

            # evaluate on the query set
            y_true = []
            y_scores = []
            for step, batch in enumerate(tqdm(query_loader, desc="Iteration")):
                batch = batch.to(device)
                pred, node_emb = self.graph_model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
                # print(pred)
                pred = torch.sigmoid(pred)
                pred = torch.where(pred > 0.5, torch.ones_like(pred), pred)
                pred = torch.where(pred <= 0.5, torch.zeros_like(pred), pred)
                y_scores.append(pred)
                y_true.append(batch.y.view(pred.shape))

            y_true = torch.cat(y_true, dim=0).cpu().detach().numpy()
            y_scores = torch.cat(y_scores, dim=0).cpu().detach().numpy()

            roc_list = []
            roc_list.append(roc_auc_score(y_true, y_scores))
            acc = sum(roc_list) / len(roc_list)
            accs.append(acc)

            # restore the meta parameters before the next test task
            vector_to_parameters(old_params, self.graph_model.parameters())

        return accs
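# A hypothetical driver for the meta-training loop above, shown only to illustrate how the
# class is meant to be used: forward() meta-trains over the training tasks (and returns an
# empty list, which test() accepts as its unused support_grads argument), while test()
# adapts to and evaluates each held-out task. The argument object and the epoch count are
# assumptions, not taken from this section.
def run_meta_training(args, epochs=50):
    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    model = Meta_model(args).to(device)
    mean_accs = []
    for epoch in range(1, epochs + 1):
        support_grads = model(epoch)          # meta-train over all training tasks
        accs = model.test(support_grads)      # adapt and evaluate on the held-out tasks
        mean_accs.append(sum(accs) / len(accs))
        print("epoch %d, mean test ROC-AUC %.4f" % (epoch, mean_accs[-1]))
    return mean_accs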