def main(args):
    """Train and evaluate GIN_Link on directed link prediction over 10 splits.

    Loads the dataset named by ``args.dataset`` (format ``<loader>/<subset>``),
    builds edge-label splits via ``generate_dataset_3class``, trains one
    GIN_Link model per split with early stopping on validation loss, and
    evaluates both the best-validation checkpoint and the latest checkpoint.

    Returns:
        numpy.ndarray of accuracies —
          * ``args.task != 2``: shape (10, 4):
            [val_acc, test_acc, val_acc_latest, test_acc_latest] per split.
          * ``args.task == 2``: shape (10, 4, 5): the four metric rows
            returned by ``link_prediction_evaluation`` (best val/test,
            latest val/test) per split.
    """
    date_time = datetime.now().strftime('%m-%d-%H:%M:%S')
    log_path = os.path.join(args.log_root, args.log_path, args.save_name,
                            date_time)

    load_func, subset = args.dataset.split('/')[0], args.dataset.split('/')[1]
    if load_func == 'WebKB':
        dataset = WebKB(root=args.data_path, name=subset)
    elif load_func == 'WikipediaNetwork':
        dataset = WikipediaNetwork(root=args.data_path, name=subset)
    elif load_func == 'WikiCS':
        dataset = WikiCS(root=args.data_path)
    elif load_func == 'cora_ml':
        dataset = load_citation_link(
            root='../dataset/data/tmp/cora_ml/cora_ml.npz')
    elif load_func == 'citeseer':
        dataset = load_citation_link(
            root='../dataset/data/tmp/citeseer_npz/citeseer_npz.npz')
    else:
        # telegram / synthetic datasets fall through to the generic loader
        dataset = load_syn(args.data_path + args.dataset, None)

    if not os.path.isdir(log_path):
        os.makedirs(log_path)

    data = dataset[0]
    edge_index = data.edge_index
    # Number of nodes: edge_index holds 0-based node ids.
    # (The original computed this twice; once is enough.)
    size = torch.max(edge_index).item() + 1

    # generate edge index dataset (10 fixed random splits)
    save_file = args.data_path + args.dataset + '/' + subset
    datasets = generate_dataset_3class(edge_index, size, save_file, splits=10,
                                       probs=args.split_prob, task=args.task,
                                       label_dim=args.num_class_link)

    if args.task != 2:
        results = np.zeros((10, 4))
    else:
        results = np.zeros((10, 4, 5))

    for i in range(10):
        log_str_full = ''
        edges = datasets[i]['graph']
        if args.to_undirected:
            edges = to_undirected(edges)

        ########################################
        # initialize model and load dataset
        ########################################
        # Node features are simply in/out degrees of the retained graph.
        x = in_out_degree(edges, size).to(device)
        edges = edges.long().to(device)
        model = GIN_Link(x.size(-1), args.num_class_link,
                         filter_num=args.num_filter,
                         dropout=args.dropout).to(device)
        opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2)

        y_train = torch.from_numpy(
            datasets[i]['train']['label']).long().to(device)
        y_val = torch.from_numpy(
            datasets[i]['validate']['label']).long().to(device)
        y_test = torch.from_numpy(
            datasets[i]['test']['label']).long().to(device)
        train_index = torch.from_numpy(
            datasets[i]['train']['pairs']).to(device)
        val_index = torch.from_numpy(
            datasets[i]['validate']['pairs']).to(device)
        test_index = torch.from_numpy(datasets[i]['test']['pairs']).to(device)

        #################################
        # Train/Validation/Test
        #################################
        best_test_err = 1000.0
        early_stopping = 0
        for epoch in range(args.epochs):
            start_time = time.time()
            # Stop once validation loss has not improved for 500 epochs.
            if early_stopping > 500:
                break
            ####################
            # Train
            ####################
            model.train()
            out = model(x, edges, train_index)
            train_loss = F.nll_loss(out, y_train)
            pred_label = out.max(dim=1)[1]
            train_acc = acc(pred_label, y_train)
            opt.zero_grad()
            train_loss.backward()
            opt.step()
            outstrtrain = 'Train loss: %.6f, acc: %.3f' % (
                train_loss.detach().item(), train_acc)
            ####################
            # Validation
            ####################
            # NOTE: despite the variable names, test_loss/test_acc here are
            # computed on the *validation* split and drive early stopping.
            model.eval()
            out = model(x, edges, val_index)
            test_loss = F.nll_loss(out, y_val)
            pred_label = out.max(dim=1)[1]
            test_acc = acc(pred_label, y_val)
            outstrval = ' Test loss: %.6f, acc: %.3f' % (
                test_loss.detach().item(), test_acc)
            duration = "--- %.4f seconds ---" % (time.time() - start_time)
            log_str = ("%d / %d epoch" % (
                epoch, args.epochs)) + outstrtrain + outstrval + duration
            log_str_full += log_str + '\n'
            ####################
            # Save weights
            ####################
            save_perform = test_loss.detach().item()
            if save_perform <= best_test_err:
                early_stopping = 0
                best_test_err = save_perform
                torch.save(model.state_dict(),
                           log_path + '/model' + str(i) + '.t7')
            else:
                early_stopping += 1

        write_log(vars(args), log_path)
        torch.save(model.state_dict(),
                   log_path + '/model_latest' + str(i) + '.t7')

        if args.task != 2:
            ####################
            # Testing
            ####################
            model.load_state_dict(
                torch.load(log_path + '/model' + str(i) + '.t7'))
            model.eval()
            # Only the first two logits are scored for the 2-class task.
            out = model(x, edges, val_index)[:, :2]
            pred_label = out.max(dim=1)[1]
            val_acc = acc(pred_label, y_val)
            out = model(x, edges, test_index)[:, :2]
            pred_label = out.max(dim=1)[1]
            test_acc = acc(pred_label, y_test)

            model.load_state_dict(
                torch.load(log_path + '/model_latest' + str(i) + '.t7'))
            model.eval()
            out = model(x, edges, val_index)[:, :2]
            pred_label = out.max(dim=1)[1]
            val_acc_latest = acc(pred_label, y_val)
            out = model(x, edges, test_index)[:, :2]
            pred_label = out.max(dim=1)[1]
            test_acc_latest = acc(pred_label, y_test)

            ####################
            # Save testing results
            ####################
            log_str1 = ('val_acc: {val_acc:.4f}, '
                        'test_acc: {test_acc:.4f}, ').format(
                            val_acc=val_acc, test_acc=test_acc)
            log_str_full += log_str1
            log_str2 = ('val_acc_latest: {val_acc_latest:.4f}, '
                        'test_acc_latest: {test_acc_latest:.4f}, ').format(
                            val_acc_latest=val_acc_latest,
                            test_acc_latest=test_acc_latest)
            log_str_full += log_str2 + '\n'
            print(log_str1 + log_str2)
            results[i] = [val_acc, test_acc, val_acc_latest, test_acc_latest]
        else:
            model.load_state_dict(
                torch.load(log_path + '/model' + str(i) + '.t7'))
            model.eval()
            out_val = model(x, edges, val_index)
            out_test = model(x, edges, test_index)
            [[val_acc_full, val_acc, val_auc, val_f1_micro, val_f1_macro],
             [test_acc_full, test_acc, test_auc, test_f1_micro, test_f1_macro]
             ] = link_prediction_evaluation(out_val, out_test, y_val, y_test)

            model.load_state_dict(
                torch.load(log_path + '/model_latest' + str(i) + '.t7'))
            model.eval()
            out_val = model(x, edges, val_index)
            out_test = model(x, edges, test_index)
            [[val_acc_full_latest, val_acc_latest, val_auc_latest,
              val_f1_micro_latest, val_f1_macro_latest],
             [test_acc_full_latest, test_acc_latest, test_auc_latest,
              test_f1_micro_latest, test_f1_macro_latest
              ]] = link_prediction_evaluation(out_val, out_test, y_val, y_test)

            ####################
            # Save testing results
            ####################
            # BUG FIX: the original passed validation metrics to the test_*
            # placeholders (and a test metric to val_f1_micro_latest), so the
            # printed/logged numbers were wrong; the `results` array was not
            # affected. Each placeholder now receives its own variable.
            log_str = (
                'val_acc_full:{val_acc_full:.4f}, val_acc: {val_acc:.4f}, Val_auc: {val_auc:.4f},'
                'val_f1_micro: {val_f1_micro:.4f}, val_f1_macro: {val_f1_macro:.4f}, '
                'test_acc_full:{test_acc_full:.4f}, test_acc: {test_acc:.4f}, '
                'test_f1_micro: {test_f1_micro:.4f}, test_f1_macro: {test_f1_macro:.4f}'
            ).format(val_acc_full=val_acc_full,
                     val_acc=val_acc,
                     val_auc=val_auc,
                     val_f1_micro=val_f1_micro,
                     val_f1_macro=val_f1_macro,
                     test_acc_full=test_acc_full,
                     test_acc=test_acc,
                     test_f1_micro=test_f1_micro,
                     test_f1_macro=test_f1_macro)
            log_str_full += log_str + '\n'
            print(log_str)
            log_str = (
                'val_acc_full_latest:{val_acc_full_latest:.4f}, val_acc_latest: {val_acc_latest:.4f}, Val_auc_latest: {val_auc_latest:.4f},'
                'val_f1_micro_latest: {val_f1_micro_latest:.4f}, val_f1_macro_latest: {val_f1_macro_latest:.4f},'
                'test_acc_full_latest:{test_acc_full_latest:.4f}, test_acc_latest: {test_acc_latest:.4f}, '
                'test_f1_micro_latest: {test_f1_micro_latest:.4f}, test_f1_macro_latest: {test_f1_macro_latest:.4f}'
            ).format(val_acc_full_latest=val_acc_full_latest,
                     val_acc_latest=val_acc_latest,
                     val_auc_latest=val_auc_latest,
                     val_f1_micro_latest=val_f1_micro_latest,
                     val_f1_macro_latest=val_f1_macro_latest,
                     test_acc_full_latest=test_acc_full_latest,
                     test_acc_latest=test_acc_latest,
                     test_f1_micro_latest=test_f1_micro_latest,
                     test_f1_macro_latest=test_f1_macro_latest)
            log_str_full += log_str + '\n'
            print(log_str)
            results[i] = [
                [val_acc_full, val_acc, val_auc, val_f1_micro, val_f1_macro],
                [test_acc_full, test_acc, test_auc, test_f1_micro,
                 test_f1_macro],
                [val_acc_full_latest, val_acc_latest, val_auc_latest,
                 val_f1_micro_latest, val_f1_macro_latest],
                [test_acc_full_latest, test_acc_latest, test_auc_latest,
                 test_f1_micro_latest, test_f1_macro_latest],
            ]

        with open(log_path + '/log' + str(i) + '.csv', 'w') as file:
            file.write(log_str_full)
            file.write('\n')
        torch.cuda.empty_cache()
    return results
def main(args):
    """Node-classification training loop for APPNP_Model.

    Loads the dataset named by ``args.dataset`` (``<loader>/<subset>``),
    trains one APPNP model per train-mask split with early stopping on
    validation loss, evaluates the best and latest checkpoints, and writes
    per-split logs and predictions under a timestamped directory.

    Returns:
        numpy.ndarray of shape (splits, 4):
        [val_acc, test_acc, val_acc_latest, test_acc_latest] per split.
    """
    # Optional deterministic seeding (0/negative disables it).
    if args.randomseed > 0:
        torch.manual_seed(args.randomseed)

    date_time = datetime.now().strftime('%m-%d-%H:%M:%S')
    log_path = os.path.join(args.log_root, args.log_path, args.save_name,
                            date_time)
    if os.path.isdir(log_path) == False:
        try:
            os.makedirs(log_path)
        except FileExistsError:
            print('Folder exists!')

    load_func, subset = args.dataset.split('/')[0], args.dataset.split('/')[1]
    if load_func == 'WebKB':
        load_func = WebKB
        dataset = load_func(root=args.data_path, name=subset)
    elif load_func == 'WikipediaNetwork':
        load_func = WikipediaNetwork
        dataset = load_func(root=args.data_path, name=subset)
    elif load_func == 'WikiCS':
        load_func = WikiCS
        dataset = load_func(root=args.data_path)
    elif load_func == 'cora_ml':
        dataset = citation_datasets(
            root='../dataset/data/tmp/cora_ml/cora_ml.npz')
    elif load_func == 'citeseer_npz':
        dataset = citation_datasets(
            root='../dataset/data/tmp/citeseer_npz/citeseer_npz.npz')
    else:
        # Fallback loader (synthetic/telegram-style datasets).
        dataset = load_syn(args.data_path + args.dataset, None)

    if os.path.isdir(log_path) == False:
        os.makedirs(log_path)

    data = dataset[0]
    # Ensure the attribute exists even for datasets without edge weights.
    if not data.__contains__('edge_weight'):
        data.edge_weight = None
    data.y = data.y.long()
    # Class count from the label range; .detach().numpy() must run on CPU,
    # hence computed before data.to(device).
    num_classes = (data.y.max() - data.y.min() + 1).detach().numpy()
    data = data.to(device)
    if data.edge_weight is not None:
        data.edge_weight = torch.FloatTensor(data.edge_weight).to(device)
    if args.to_undirected:
        data.edge_index, data.edge_weight = to_undirected(
            data.edge_index, data.edge_weight)

    # normalize label, the minimum should be 0 as class index
    splits = data.train_mask.shape[1]
    # Broadcast a single shared test mask across all splits if needed.
    if len(data.test_mask.shape) == 1:
        data.test_mask = data.test_mask.unsqueeze(1).repeat(1, splits)

    results = np.zeros((splits, 4))
    for split in range(splits):
        log_str_full = ''
        graphmodel = APPNP_Model(data.x.size(-1),
                                 num_classes,
                                 filter_num=args.num_filter,
                                 alpha=args.alpha,
                                 dropout=args.dropout,
                                 layer=args.layer).to(device)
        model = graphmodel  # nn.DataParallel(graphmodel)
        opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2)

        #################################
        # Train/Validation/Test
        #################################
        best_test_err = 1000.0
        early_stopping = 0
        for epoch in range(args.epochs):
            start_time = time.time()
            ####################
            # Train
            ####################
            train_loss, train_acc = 0.0, 0.0
            # for loop for batch loading
            model.train()
            out = model(data)
            train_loss = F.nll_loss(out[data.train_mask[:, split]],
                                    data.y[data.train_mask[:, split]])
            pred_label = out.max(dim=1)[1]
            train_acc = acc(pred_label, data.y, data.train_mask[:, split])
            opt.zero_grad()
            train_loss.backward()
            opt.step()
            outstrtrain = 'Train loss:, %.6f, acc:, %.3f,' % (
                train_loss.detach().item(), train_acc)
            #scheduler.step()
            ####################
            # Validation
            ####################
            # NOTE(review): test_loss/test_acc (and the "Test loss" label)
            # are computed on the *validation* mask; they drive early
            # stopping and checkpoint selection below.
            model.eval()
            test_loss, test_acc = 0.0, 0.0
            out = model(data)
            pred_label = out.max(dim=1)[1]
            test_loss = F.nll_loss(out[data.val_mask[:, split]],
                                   data.y[data.val_mask[:, split]])
            test_acc = acc(pred_label, data.y, data.val_mask[:, split])
            outstrval = ' Test loss:, %.6f, acc: ,%.3f,' % (
                test_loss.detach().item(), test_acc)
            duration = "---, %.4f, seconds ---" % (time.time() - start_time)
            log_str = ("%d, / ,%d, epoch," % (
                epoch, args.epochs)) + outstrtrain + outstrval + duration
            log_str_full += log_str + '\n'
            #print(log_str)
            ####################
            # Save weights
            ####################
            save_perform = test_loss.detach().item()
            if save_perform <= best_test_err:
                early_stopping = 0
                best_test_err = save_perform
                # Best-so-far checkpoint (by validation loss).
                torch.save(model.state_dict(),
                           log_path + '/model' + str(split) + '.t7')
            else:
                early_stopping += 1
            # Stop after 500 epochs without improvement, or at the last
            # epoch; either way snapshot the latest weights first.
            if early_stopping > 500 or epoch == (args.epochs - 1):
                torch.save(model.state_dict(),
                           log_path + '/model_latest' + str(split) + '.t7')
                break

        write_log(vars(args), log_path)

        ####################
        # Testing
        ####################
        model.load_state_dict(
            torch.load(log_path + '/model' + str(split) + '.t7'))
        model.eval()
        preds = model(data)
        pred_label = preds.max(dim=1)[1]
        np.save(log_path + '/pred' + str(split), pred_label.to('cpu'))
        # NOTE(review): despite the name, acc_train is measured on the
        # validation mask and is logged below as 'val_acc'.
        acc_train = acc(pred_label, data.y, data.val_mask[:, split])
        acc_test = acc(pred_label, data.y, data.test_mask[:, split])

        # Repeat evaluation with the latest (last-epoch) checkpoint.
        model.load_state_dict(
            torch.load(log_path + '/model_latest' + str(split) + '.t7'))
        model.eval()
        preds = model(data)
        pred_label = preds.max(dim=1)[1]
        np.save(log_path + '/pred_latest' + str(split), pred_label.to('cpu'))
        acc_train_latest = acc(pred_label, data.y, data.val_mask[:, split])
        acc_test_latest = acc(pred_label, data.y, data.test_mask[:, split])

        ####################
        # Save testing results
        ####################
        logstr = 'val_acc: ' + str(
            np.round(acc_train, 3)) + ' test_acc: ' + str(np.round(
                acc_test, 3)) + ' val_acc_latest: ' + str(
                    np.round(acc_train_latest, 3)) + ' test_acc_latest: ' + str(
                        np.round(acc_test_latest, 3))
        print(logstr)
        results[split] = [
            acc_train, acc_test, acc_train_latest, acc_test_latest
        ]
        log_str_full += logstr
        with open(log_path + '/log' + str(split) + '.csv', 'w') as file:
            file.write(log_str_full)
            file.write('\n')
        torch.cuda.empty_cache()
    return results
def main(args):
    """Node-classification training loop for ChebNet on sparse (magnetic)
    Laplacians.

    Builds (or loads from a cached pickle) the sparse Laplacian tensors for
    the dataset named by ``args.dataset``, trains one ChebNet per train-mask
    split with early stopping on validation loss, evaluates the best and
    latest checkpoints, and writes per-split logs and predictions.

    Returns:
        numpy.ndarray of shape (splits, 4):
        [val_acc, test_acc, val_acc_latest, test_acc_latest] per split.
    """
    # Optional deterministic seeding (0/negative disables it).
    if args.randomseed > 0:
        torch.manual_seed(args.randomseed)

    date_time = datetime.now().strftime('%m-%d-%H:%M:%S')
    log_path = os.path.join(args.log_root, args.log_path, args.save_name,
                            date_time)
    if os.path.isdir(log_path) == False:
        try:
            os.makedirs(log_path)
        except FileExistsError:
            print('Folder exists!')

    # Map the loader name to a loader callable; geometric_dataset_sparse
    # performs the actual loading below.
    load_func, subset = args.dataset.split('/')[0], args.dataset.split('/')[1]
    if load_func == 'WebKB':
        load_func = WebKB
    elif load_func == 'WikipediaNetwork':
        load_func = WikipediaNetwork
    elif load_func == 'WikiCS':
        load_func = WikiCS
    elif load_func == 'cora_ml':
        load_func = citation_datasets
    elif load_func == 'citeseer_npz':
        load_func = citation_datasets
    else:
        load_func = load_syn

    # Cache file for the precomputed sparse Laplacians (keyed by q and K).
    _file_ = args.data_path + args.dataset + '/data' + str(args.q) + '_' + str(
        args.K) + '_sparse.pk'
    if os.path.isfile(_file_):
        # Laplacians come from the pickle; features/labels/masks are still
        # fetched (load_only=True avoids recomputing/saving the Laplacians).
        data = pk.load(open(_file_, 'rb'))
        L = data['L']
        X, label, train_mask, val_mask, test_mask = geometric_dataset_sparse(
            args.q, args.K, root=args.data_path + args.dataset,
            subset=subset,
            dataset=load_func,
            load_only=True, save_pk=False)
    else:
        # First run: compute everything and persist the pickle cache.
        X, label, train_mask, val_mask, test_mask, L = geometric_dataset_sparse(
            args.q, args.K, root=args.data_path + args.dataset,
            subset=subset,
            dataset=load_func,
            load_only=False, save_pk=True)

    # normalize label, the minimum should be 0 as class index
    _label_ = label - np.amin(label)
    cluster_dim = np.amax(_label_) + 1

    # convert dense laplacian to sparse matrix
    # (real and imaginary parts are kept as separate sparse tensors)
    L_img = []
    L_real = []
    for i in range(len(L)):
        L_img.append(sparse_mx_to_torch_sparse_tensor(L[i].imag).to(device))
        L_real.append(sparse_mx_to_torch_sparse_tensor(L[i].real).to(device))

    # Labels get a leading batch dimension of 1.
    label = torch.from_numpy(_label_[np.newaxis]).to(device)
    X_img = torch.FloatTensor(X).to(device)
    X_real = torch.FloatTensor(X).to(device)
    criterion = nn.NLLLoss()

    splits = train_mask.shape[1]
    # Broadcast a single shared test mask across all splits if needed.
    if len(test_mask.shape) == 1:
        #data.test_mask = test_mask.unsqueeze(1).repeat(1, splits)
        test_mask = np.repeat(test_mask[:, np.newaxis], splits, 1)

    results = np.zeros((splits, 4))
    for split in range(splits):
        log_str_full = ''
        model = ChebNet(X_real.size(-1), L_real, L_img, K=args.K,
                        label_dim=cluster_dim, layer=args.layer,
                        activation=args.activation,
                        num_filter=args.num_filter,
                        dropout=args.dropout).to(device)

        opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2)
        best_test_acc = 0.0
        train_index = train_mask[:, split]
        val_index = val_mask[:, split]
        test_index = test_mask[:, split]

        #################################
        # Train/Validation/Test
        #################################
        best_test_err = 1000.0
        early_stopping = 0
        for epoch in range(args.epochs):
            start_time = time.time()
            ####################
            # Train
            ####################
            count, train_loss, train_acc = 0.0, 0.0, 0.0
            # for loop for batch loading
            count += np.sum(train_index)
            model.train()
            # preds has shape (1, classes, nodes); masks index the node axis.
            preds = model(X_real, X_img)
            train_loss = criterion(preds[:, :, train_index],
                                   label[:, train_index])
            pred_label = preds.max(dim=1)[1]
            train_acc = 1.0 * (
                (pred_label[:, train_index] ==
                 label[:, train_index])).sum().detach().item() / count
            opt.zero_grad()
            train_loss.backward()
            opt.step()
            outstrtrain = 'Train loss:, %.6f, acc:, %.3f,' % (
                train_loss.detach().item(), train_acc)
            #scheduler.step()
            ####################
            # Validation
            ####################
            # NOTE(review): test_loss/test_acc (and the "Test loss" label)
            # are computed on the *validation* mask; they drive early
            # stopping and checkpoint selection below.
            model.eval()
            count, test_loss, test_acc = 0.0, 0.0, 0.0
            # for loop for batch loading
            count += np.sum(val_index)
            preds = model(X_real, X_img)
            pred_label = preds.max(dim=1)[1]
            test_loss = criterion(preds[:, :, val_index], label[:, val_index])
            test_acc = 1.0 * (
                (pred_label[:, val_index] ==
                 label[:, val_index])).sum().detach().item() / count
            outstrval = ' Test loss:, %.6f, acc:, %.3f,' % (
                test_loss.detach().item(), test_acc)
            duration = "---, %.4f, seconds ---" % (time.time() - start_time)
            log_str = ("%d ,/, %d ,epoch," % (
                epoch, args.epochs)) + outstrtrain + outstrval + duration
            log_str_full += log_str + '\n'
            #print(log_str)
            ####################
            # Save weights
            ####################
            save_perform = test_loss.detach().item()
            if save_perform <= best_test_err:
                early_stopping = 0
                best_test_err = save_perform
                # Best-so-far checkpoint (by validation loss).
                torch.save(model.state_dict(),
                           log_path + '/model' + str(split) + '.t7')
            else:
                early_stopping += 1
            # Stop after 500 epochs without improvement, or at the last
            # epoch; either way snapshot the latest weights first.
            if early_stopping > 500 or epoch == (args.epochs - 1):
                torch.save(model.state_dict(),
                           log_path + '/model_latest' + str(split) + '.t7')
                break

        write_log(vars(args), log_path)

        ####################
        # Testing
        ####################
        model.load_state_dict(
            torch.load(log_path + '/model' + str(split) + '.t7'))
        model.eval()
        preds = model(X_real, X_img)
        pred_label = preds.max(dim=1)[1]
        np.save(log_path + '/pred' + str(split), pred_label.to('cpu'))
        # NOTE(review): despite the name, acc_train is measured on the
        # validation mask and is logged below as 'val_acc'.
        count = np.sum(val_index)
        acc_train = (1.0 * ((pred_label[:, val_index] ==
                             label[:, val_index])).sum().detach().item()) / count
        count = np.sum(test_index)
        acc_test = (1.0 * ((pred_label[:, test_index] ==
                            label[:, test_index])).sum().detach().item()) / count

        # Repeat evaluation with the latest (last-epoch) checkpoint.
        model.load_state_dict(
            torch.load(log_path + '/model_latest' + str(split) + '.t7'))
        model.eval()
        preds = model(X_real, X_img)
        pred_label = preds.max(dim=1)[1]
        np.save(log_path + '/pred_latest' + str(split), pred_label.to('cpu'))
        count = np.sum(val_index)
        acc_train_latest = (1.0 * (
            (pred_label[:, val_index] ==
             label[:, val_index])).sum().detach().item()) / count
        count = np.sum(test_index)
        acc_test_latest = (1.0 * (
            (pred_label[:, test_index] ==
             label[:, test_index])).sum().detach().item()) / count

        ####################
        # Save testing results
        ####################
        logstr = 'val_acc: ' + str(
            np.round(acc_train, 3)) + ' test_acc: ' + str(np.round(
                acc_test, 3)) + ' val_acc_latest: ' + str(
                    np.round(acc_train_latest, 3)) + ' test_acc_latest: ' + str(
                        np.round(acc_test_latest, 3))
        print(logstr)
        results[split] = [
            acc_train, acc_test, acc_train_latest, acc_test_latest
        ]
        log_str_full += logstr
        with open(log_path + '/log' + str(split) + '.csv', 'w') as file:
            file.write(log_str_full)
            file.write('\n')
        torch.cuda.empty_cache()
    return results