optimizer.zero_grad() # print('term', term) # print('law', law) # print('accu', accu) pred1, pred2, pred3 = model(curr_fact) accu_loss = loss_function(pred1, accu) law_loss = loss_function(pred2, law) term_loss = loss_function(pred3, term) loss = accu_loss + law_loss + term_loss loss.backward() optimizer.step() total_loss += loss.item() truth1, truth2, truth3 = [], [], [] pred1_list, pred2_list, pred3_list = [], [], [] for valid_data in test_dataloader: model.eval() valid_fact, term, law, accu = to_cuda(valid_data) pred1, pred2, pred3 = model(valid_fact) pred1_list.append(torch.argmax(pred1, dim=-1).cpu().data.numpy()) pred2_list.append(torch.argmax(pred2, dim=-1).cpu().data.numpy()) pred3_list.append(torch.argmax(pred3, dim=-1).cpu().data.numpy()) truth1.append(accu.cpu().data.numpy()) truth2.append(law.cpu().data.numpy()) truth3.append(term.cpu().data.numpy()) print('task1_acc', accuracy_score(np.concatenate(truth1), np.concatenate(pred1_list))) print('task2_acc', accuracy_score(np.concatenate(truth2), np.concatenate(pred2_list))) print('task3_acc', accuracy_score(np.concatenate(truth3), np.concatenate(pred3_list))) print(
def _loss_has_plateaued(current_loss, previous_loss, threshold):
    """Return True when the relative loss change is strictly below *threshold*.

    ``previous_loss`` may be ``None`` (no earlier epoch recorded yet) or ``0``
    (relative change undefined); in both cases the answer is ``False`` instead
    of an exception.
    """
    if not previous_loss:
        return False
    return abs((current_loss - previous_loss) / previous_loss) < threshold


def main(miRNA_Disease_Association, disease_feature, disease_graph1,
         disease_graph2, disease_graph3, miRNA_feature, miRNA_graph1,
         miRNA_graph2, miRNA_graph3):
    """Train and evaluate ``MyModel`` with K-fold cross-validation.

    The association matrix is split into ``KFold`` train/test graphs. For each
    fold a fresh model is trained for at most ``EPOCH`` epochs, evaluated every
    ``EVAL_INTER`` epochs, and stopped early once the relative change of the
    training loss drops below ``STOP_THRESHOLD`` (only checked after
    ``TOLERANCE_EPOCH`` epochs). The AUC / AUPR of the final evaluation of each
    fold are collected, the fold-averaged TPR curve (interpolated on 100 FPR
    points) is saved to ``../Data/Result/mean_tpr.npy`` and the mean AUC / AUPR
    are printed.

    Args:
        miRNA_Disease_Association: association matrix handed to
            ``adjTrainTestSplit``.
        disease_feature, disease_graph1, disease_graph2, disease_graph3:
            disease-side inputs forwarded unchanged to ``MyModel``.
        miRNA_feature, miRNA_graph1, miRNA_graph2, miRNA_graph3:
            miRNA-side inputs forwarded unchanged to ``MyModel``.

    Returns:
        None. Results are printed and written to disk.
    """
    # Local import so the block does not depend on the file's import header.
    import os

    adjProcess = adjTrainTestSplit(miRNA_Disease_Association)
    graph_train_kfold, graph_test_kfold = adjProcess.split_graph(KFold, SEED)

    auc_kfold = []
    aupr_kfold = []
    # Accumulators used to build the fold-averaged ROC curve.
    mean_fpr = np.linspace(0, 1, 100)
    mean_tpr = np.zeros_like(mean_fpr)

    for i in range(KFold):
        print("Using {} th fold dataset.".format(i + 1))
        graph_train = graph_train_kfold[i]
        graph_test = graph_test_kfold[i]

        adj_target = torch.FloatTensor(graph_train)
        model = MyModel(disease_feature, disease_graph1, disease_graph2,
                        disease_graph3, miRNA_feature, miRNA_graph1,
                        miRNA_graph2, miRNA_graph3)
        model.cuda()
        obj = Myloss(adj_target.cuda())
        # NOTE(review): GAMA is used both as Adam's weight_decay and as an
        # argument of cal_loss below -- confirm this double use is intended.
        optimizer = torch.optim.Adam(model.parameters(), lr=LR, amsgrad=True,
                                     weight_decay=GAMA)
        evaluator = Evaluator(graph_train, graph_test)

        # Reset per fold so the plateau test never compares against the loss
        # of a previous fold (or an unbound name on the very first epoch).
        last_loss = None

        for j in range(EPOCH):
            model.train()
            optimizer.zero_grad()
            Y_hat, m_x0, m_x1, m_x2, m_x3, d_x0, d_x1, d_x2, d_x3 = model()
            loss = obj.cal_loss(Y_hat,
                                m_x0.cuda(), m_x1.cuda(), m_x2.cuda(),
                                m_x3.cuda(),
                                d_x0.cuda(), d_x1.cuda(), d_x2.cuda(),
                                d_x3.cuda(),
                                ALPHA, BETA, GAMA)
            loss.backward()
            optimizer.step()

            # Read the scalar once; every .item() call forces a device sync.
            loss_value = loss.item()

            need_early_stop_check = (
                j > TOLERANCE_EPOCH
                and _loss_has_plateaued(loss_value, last_loss, STOP_THRESHOLD)
            )
            is_last_epoch = j + 1 >= EPOCH

            if (j % EVAL_INTER == 0) or need_early_stop_check or is_last_epoch:
                t = time.time()
                model.eval()
                with torch.no_grad():
                    Y_hat, m_x0, m_x1, m_x2, m_x3, d_x0, d_x1, d_x2, d_x3 = model()
                    auc_test, aupr_test, fpr, tpr = evaluator.eval(Y_hat.cpu())

                print("Epoch:", '%04d' % (j + 1),
                      "train_loss=", "{:0>9.5f}".format(loss_value),
                      "test_auc=", "{:.5f}".format(auc_test),
                      "test_aupr=", "{:.5f}".format(aupr_test),
                      "time=", "{:.2f}".format(time.time() - t))

                if need_early_stop_check or is_last_epoch:
                    # Final evaluation of this fold: record its metrics and
                    # add its ROC curve to the running sum.
                    auc_kfold.append(auc_test)
                    aupr_kfold.append(aupr_test)
                    mean_tpr += np.interp(mean_fpr, fpr, tpr)
                    mean_tpr[0] = 0.0
                    if need_early_stop_check:
                        print("Early stopping...")
                    else:
                        print("Arrived at the last Epoch...")
                    break

            last_loss = loss_value
            torch.cuda.empty_cache()

    print("\nOptimization Finished!")

    # Each fold contributes exactly one curve, so the sum is averaged over
    # KFold; the curve end point is pinned to (1, 1).
    mean_tpr /= KFold
    mean_tpr[-1] = 1.0

    # Create the output directory up front so a missing folder cannot discard
    # the results of a complete cross-validation run.
    result_path = "../Data/Result/mean_tpr.npy"
    os.makedirs(os.path.dirname(result_path), exist_ok=True)
    np.save(result_path, mean_tpr)

    mean_auc = sum(auc_kfold) / len(auc_kfold)
    mean_aupr = sum(aupr_kfold) / len(aupr_kfold)
    print("mean_auc:{0:.3f},mean_aupr:{1:.3f}".format(mean_auc, mean_aupr))