def train_GRAM(seqFile='seqFile.txt', labelFile='labelFile.txt', treeFile='tree.txt',
               embFile='embFile.txt', outFile='out.txt', inputDimSize=100,
               numAncestors=100, embDimSize=100, hiddenDimSize=200,
               attentionDimSize=200, max_epochs=100, L2=0., numClass=26679,
               batchSize=100, dropoutRate=0.5, logEps=1e-8, verbose=True,
               ignore_level=0):
    options = locals().copy()

    # leavesList and ancestorsList carry the ancestor (category) information
    # of every disease code, one entry per level of the ontology tree.
    leavesList = []
    ancestorsList = []
    for i in range(5, 0, -1):
        leaves, ancestors = build_tree(treeFile + '.level' + str(i) + '.pk')
        leavesList.append(leaves)
        ancestorsList.append(ancestors)

    print('Building the model ... ')
    gram = GRAM(inputDimSize, numAncestors, embDimSize, hiddenDimSize,
                attentionDimSize, numClass, dropoutRate, embFile)
    # if torch.cuda.device_count() > 1:
    #     print("Let's use", torch.cuda.device_count(), "GPUs!")
    #     # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
    #     gram = nn.DataParallel(gram)
    gram.to(device)
    # gram.train()
    print(list(gram.state_dict()))

    loss_fn = CrossEntropy()
    loss_fn.to(device)

    print('Constructing the optimizer ... ')
    optimizer = torch.optim.Adadelta(gram.parameters(), lr=1, weight_decay=L2)

    print('Loading data ... ')
    trainSet, validSet, testSet = load_data(seqFile, labelFile,
                                            test_ratio=0.15, valid_ratio=0.1)
    print('Data length:', len(trainSet[0]))
    n_batches = int(np.ceil(float(len(trainSet[0])) / float(batchSize)))
    val_batches = int(np.ceil(float(len(validSet[0])) / float(batchSize)))
    test_batches = int(np.ceil(float(len(testSet[0])) / float(batchSize)))

    print('Optimization start !!')
    # set up the TensorBoard writers
    loss_writer = SummaryWriter('{}/{}'.format(outFile + 'TbLog', 'Loss'))
    acc_writer = SummaryWriter('{}/{}'.format(outFile + 'TbLog', 'Acc'))
    # test_writer = SummaryWriter('{}/{}'.format(outFile+'TbLog', 'Test'))

    logFile = outFile + '.log'
    bestTrainCost = 0.0
    bestValidCost = 100000.0
    bestTestCost = 0.0
    bestTrainAcc = 0.0
    bestValidAcc = 0.0
    bestTestAcc = 0.0
    epochDuration = 0.0
    bestEpoch = 0
    # set the random seed for the batch shuffling; `seed` is assumed to be
    # defined at module level
    random.seed(seed)
    # with torchsnooper.snoop():
    for epoch in range(max_epochs):
        iteration = 0
        cost_vec = []
        acc_vec = []
        startTime = time.time()
        gram.train()
        for index in random.sample(range(n_batches), n_batches):
            optimizer.zero_grad()
            batchX = trainSet[0][index * batchSize:(index + 1) * batchSize]
            batchY = trainSet[1][index * batchSize:(index + 1) * batchSize]
            x, y, mask, lengths = padMatrix(batchX, batchY, options)
            x = torch.from_numpy(x).to(device).float()
            mask = torch.from_numpy(mask).to(device).float()
            # print('x,', x.size())
            y_hat = gram(x, mask, leavesList, ancestorsList)
            # print('y_hat', y_hat.size())
            y = torch.from_numpy(y).float().to(device)
            # print('y', y.size())
            lengths = torch.from_numpy(lengths).float().to(device)
            # print(y.size(), y_hat.size())
            loss, acc = loss_fn(y_hat, y, lengths)
            loss.backward()
            optimizer.step()
            if iteration % 100 == 0 and verbose:
                buf = 'Epoch:%d, Iteration:%d/%d, Train_Cost:%f, Train_Acc:%f' % (
                    epoch, iteration, n_batches, loss, acc)
                print(buf)
            cost_vec.append(loss.item())
            acc_vec.append(acc)
            iteration += 1
        duration_optimize = time.time() - startTime

        gram.eval()
        cost = np.mean(cost_vec)
        acc = np.mean(acc_vec)
        startTime = time.time()
        with torch.no_grad():
            # calculate the loss and acc on the validation dataset
            cost_vec = []
            acc_vec = []
            for index in range(val_batches):
                validX = validSet[0][index * batchSize:(index + 1) * batchSize]
                validY = validSet[1][index * batchSize:(index + 1) * batchSize]
                val_x, val_y, mask, lengths = padMatrix(validX, validY, options)
                val_x = torch.from_numpy(val_x).float().to(device)
                mask = torch.from_numpy(mask).float().to(device)
                val_y_hat = gram(val_x, mask, leavesList, ancestorsList)
                val_y = torch.from_numpy(val_y).float().to(device)
                lengths = torch.from_numpy(lengths).float().to(device)
                valid_cost, valid_acc = loss_fn(val_y_hat, val_y, lengths)
                cost_vec.append(valid_cost.item())
                acc_vec.append(valid_acc)
            valid_cost = np.mean(cost_vec)
            valid_acc = np.mean(acc_vec)

            # calculate the loss and acc on the test dataset
            cost_vec = []
            acc_vec = []
            for index in range(test_batches):
                testX = testSet[0][index * batchSize:(index + 1) * batchSize]
                testY = testSet[1][index * batchSize:(index + 1) * batchSize]
                test_x, test_y, mask, lengths = padMatrix(testX, testY, options)
                test_x = torch.from_numpy(test_x).float().to(device)
                mask = torch.from_numpy(mask).float().to(device)
                test_y_hat = gram(test_x, mask, leavesList, ancestorsList)
                test_y = torch.from_numpy(test_y).float().to(device)
                lengths = torch.from_numpy(lengths).float().to(device)
                test_cost, test_acc = loss_fn(test_y_hat, test_y, lengths)
                cost_vec.append(test_cost.item())
                acc_vec.append(test_acc)
            test_cost = np.mean(cost_vec)
            test_acc = np.mean(acc_vec)

            # record the loss and acc
            loss_writer.add_scalar('Train Loss', cost, epoch)
            loss_writer.add_scalar('Test Loss', test_cost, epoch)
            loss_writer.add_scalar('Valid Loss', valid_cost, epoch)
            acc_writer.add_scalar('Train Acc', acc, epoch)
            acc_writer.add_scalar('Test Acc', test_acc, epoch)
            acc_writer.add_scalar('Valid Acc', valid_acc, epoch)

        # print the loss
        duration_metric = time.time() - startTime
        # accumulate per-epoch time so the final Avg_Duration report is meaningful
        epochDuration += duration_optimize + duration_metric
        buf = 'Epoch:%d, Train_Cost:%f, Valid_Cost:%f, Test_Cost:%f' % (
            epoch, cost, valid_cost, test_cost)
        print(buf)
        print2file(buf, logFile)
        buf = 'Train_Acc:%f, Valid_Acc:%f, Test_Acc:%f' % (acc, valid_acc, test_acc)
        print(buf)
        print2file(buf, logFile)
        buf = 'Optimize_Duration:%f, Metric_Duration:%f' % (duration_optimize,
                                                            duration_metric)
        print(buf)
        print2file(buf, logFile)

        # save the best model (lowest validation cost so far)
        if valid_cost < bestValidCost:
            bestValidCost = valid_cost
            bestTestCost = test_cost
            bestTrainCost = cost
            bestEpoch = epoch
            bestTrainAcc = acc
            bestValidAcc = valid_acc
            bestTestAcc = test_acc
            torch.save(gram.state_dict(), outFile + f'.{epoch}')

    buf = 'Best Epoch:%d, Avg_Duration:%f, Train_Cost:%f, Valid_Cost:%f, Test_Cost:%f' % (
        bestEpoch, epochDuration / max_epochs, bestTrainCost, bestValidCost, bestTestCost)
    print(buf)
    print2file(buf, logFile)
    buf = 'Train_Acc:%f, Valid_Acc:%f, Test_Acc:%f' % (
        bestTrainAcc, bestValidAcc, bestTestAcc)
    print(buf)
    print2file(buf, logFile)
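
# Note (assumption, not part of the original script): the SummaryWriter calls
# above write TensorBoard event files under '<outFile>TbLog/Loss' and
# '<outFile>TbLog/Acc'. Assuming, for example, outFile='results/gram', the
# training curves can be inspected with the standard TensorBoard CLI:
#     tensorboard --logdir results/gramTbLog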
def test_whole_data(seqFile='seqFile.txt', labelFile='labelFile.txt', treeFile='tree.txt',
                    embFile='embFile.txt', outFile='out.txt', inputDimSize=100,
                    numAncestors=100, embDimSize=100, hiddenDimSize=200,
                    attentionDimSize=200, max_epochs=100, L2=0., numClass=26679,
                    batchSize=100, dropoutRate=0.5, logEps=1e-8, verbose=True,
                    ignore_level=0):
    options = locals().copy()

    # get the best model through the log
    # with open(outFile+'.log') as f:
    #     line = f.readlines()[-2]
    #     best_epoch = line.split(',')[0].split(':')[1]
    #     print('Best parameters occur epoch:', best_epoch)

    leavesList = []
    ancestorsList = []
    for i in range(5, 0, -1):
        leaves, ancestors = build_tree(treeFile + '.level' + str(i) + '.pk')
        leavesList.append(leaves)
        ancestorsList.append(ancestors)

    print('Loading the model ... ')
    # create the model
    gram = GRAM(inputDimSize, numAncestors, embDimSize, hiddenDimSize,
                attentionDimSize, numClass, dropoutRate, '').to(device)
    # read the best parameters (here embFile is used to pass the checkpoint path)
    # gram.load_state_dict(torch.load(outFile + '.' + best_epoch))
    gram.load_state_dict(torch.load(embFile))
    # disable dropout for evaluation
    gram.eval()

    loss_fn = CrossEntropy()
    loss_fn.to(device)

    print('Loading the data ... ')
    dataset, _, _ = load_data(seqFile, labelFile, test_ratio=0, valid_ratio=0)
    typeFile = labelFile.split('.seqs')[0] + '.types'
    types = pickle.load(open(typeFile, 'rb'))
    rTypes = dict([(v, u) for u, v in types.items()])
    print('Data length:', len(dataset[0]))
    n_batches = int(np.ceil(float(len(dataset[0])) / float(batchSize)))

    print('Calculating the result ...')
    cost_vec = []
    acc_vec = []
    num_for_each_disease = defaultdict(float)
    TP_for_each_disease = defaultdict(float)
    rank_for_each_disease = defaultdict(float)
    # no gradients are needed for evaluation
    with torch.no_grad():
        for index in range(n_batches):
            batchX = dataset[0][index * batchSize:(index + 1) * batchSize]
            batchY = dataset[1][index * batchSize:(index + 1) * batchSize]
            x, y, mask, lengths = padMatrix(batchX, batchY, options)
            x = torch.from_numpy(x).to(device).float()
            mask = torch.from_numpy(mask).to(device).float()
            y_hat = gram(x, mask, leavesList, ancestorsList)
            y = torch.from_numpy(y).float().to(device)
            lengths = torch.from_numpy(lengths).float().to(device)
            loss, acc = loss_fn(y_hat, y, lengths)
            cost_vec.append(loss.item())
            acc_vec.append(acc)

            # calculate the accuracy for each disease
            y_sorted, indices = torch.sort(y_hat, dim=2, descending=True)
            # indices = indices[:, :, :20]
            for i, j, k in torch.nonzero(y, as_tuple=False):
                k = k.item()
                num_for_each_disease[k] += 1
                # find the predicted rank of disease k
                m = torch.nonzero(indices[i][j] == k, as_tuple=False).view(-1).item()
                # count a hit for top-20 accuracy
                if m < 20:
                    TP_for_each_disease[k] += 1
                rank_for_each_disease[k] += (m + 1)
    cost = np.mean(cost_vec)
    acc = np.mean(acc_vec)
    print('Whole data average loss:%f, average accuracy@20:%f' % (cost, acc))

    print('Recording the accuracy for each disease ...')
    acc_out_file = outFile + '_all_acc.txt'
    # sort the diseases by occurrence count
    num_for_each_disease = OrderedDict(
        sorted(num_for_each_disease.items(), key=lambda item: item[1], reverse=True))
    for disease in num_for_each_disease.keys():
        d_acc = TP_for_each_disease[disease] / num_for_each_disease[disease]
        avg_rank = rank_for_each_disease[disease] / num_for_each_disease[disease]
        buf = 'TypeNum:%d, icd_code:%s, Count:%d, avg_rank:%f, Accuracy:%f' % \
              (disease, rTypes[disease], num_for_each_disease[disease], avg_rank, d_acc)
        print2file(buf, acc_out_file)
    print('Done!')
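

# Usage sketch (assumption, not part of the original script): a minimal launcher
# showing how train_GRAM and test_whole_data might be invoked together. Every
# file path, dimension, and the checkpoint name below is a placeholder for
# whatever the preprocessing and a finished training run actually produce.
if __name__ == '__main__':
    # hypothetical training call; seq/label/tree/embedding files and sizes
    # must match the preprocessed data
    train_GRAM(seqFile='data/visits.seqs', labelFile='data/visits.labels',
               treeFile='data/visits', embFile='data/visits.emb',
               outFile='results/gram', inputDimSize=4894, numAncestors=728,
               embDimSize=128, hiddenDimSize=128, attentionDimSize=128,
               numClass=283, max_epochs=100, batchSize=100, dropoutRate=0.5)
    # hypothetical evaluation call; embFile here carries the checkpoint that
    # train_GRAM saved as outFile + '.<best_epoch>'
    test_whole_data(seqFile='data/visits.seqs', labelFile='data/visits.labels',
                    treeFile='data/visits', embFile='results/gram.42',
                    outFile='results/gram', inputDimSize=4894, numAncestors=728,
                    embDimSize=128, hiddenDimSize=128, attentionDimSize=128,
                    numClass=283, batchSize=100, dropoutRate=0.5)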