Code example #1
import pickle
import random
import time
from collections import OrderedDict, defaultdict

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

# The imports above are an assumption about the surrounding project; GRAM,
# CrossEntropy, build_tree, load_data, padMatrix, print2file, and the
# module-level `device` and `seed` are expected to be defined elsewhere.


def train_GRAM(seqFile='seqFile.txt',
               labelFile='labelFile.txt',
               treeFile='tree.txt',
               embFile='embFile.txt',
               outFile='out.txt',
               inputDimSize=100,
               numAncestors=100,
               embDimSize=100,
               hiddenDimSize=200,
               attentionDimSize=200,
               max_epochs=100,
               L2=0.,
               numClass=26679,
               batchSize=100,
               dropoutRate=0.5,
               logEps=1e-8,
               verbose=True,
               ignore_level=0):
    options = locals().copy()
    # leavesList and ancestorsList encode each disease code's position in the
    # ontology (its category information)
    leavesList = []
    ancestorsList = []
    for i in range(5, 0, -1):
        leaves, ancestors = build_tree(treeFile + '.level' + str(i) + '.pk')
        leavesList.append(leaves)
        ancestorsList.append(ancestors)
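    # Each (leaves, ancestors) pair maps every leaf medical code to its
    # ancestors at one level of the ontology tree.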

    print('Building the model ... ')
    gram = GRAM(inputDimSize, numAncestors, embDimSize, hiddenDimSize,
                attentionDimSize, numClass, dropoutRate, embFile)
    # if torch.cuda.device_count() > 1:
    #     print("Let's use", torch.cuda.device_count(), "GPUs!")
    #     # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
    #     gram = nn.DataParallel(gram)
    gram.to(device)
    # gram.train()
    # print the names of the model's learnable parameters and buffers
    print(list(gram.state_dict()))
    loss_fn = CrossEntropy()
    loss_fn.to(device)

    print('Constructing the optimizer ... ')
    # Adadelta with its standard lr of 1.0; L2 is applied as weight decay
    optimizer = torch.optim.Adadelta(gram.parameters(), lr=1, weight_decay=L2)

    print('Loading data ... ')
    trainSet, validSet, testSet = load_data(seqFile,
                                            labelFile,
                                            test_ratio=0.15,
                                            valid_ratio=0.1)
    print('Data length:', len(trainSet[0]))
    n_batches = int(np.ceil(float(len(trainSet[0])) / float(batchSize)))
    val_batches = int(np.ceil(float(len(validSet[0])) / float(batchSize)))
    test_batches = int(np.ceil(float(len(testSet[0])) / float(batchSize)))

    print('Optimization start !!')
    # set up the TensorBoard writers for the loss and accuracy curves
    loss_writer = SummaryWriter('{}/{}'.format(outFile + 'TbLog', 'Loss'))
    acc_writer = SummaryWriter('{}/{}'.format(outFile + 'TbLog', 'Acc'))
    # test_writer = SummaryWriter('{}/{}'.format(outFile+'TbLog', 'Test'))

    logFile = outFile + '.log'
    bestTrainCost = 0.0
    bestValidCost = 100000.0
    bestTestCost = 0.0
    bestTrainAcc = 0.0
    bestValidAcc = 0.0
    bestTestAcc = 0.0
    epochDuration = 0.0
    bestEpoch = 0
    # seed the RNG so batch shuffling is reproducible (`seed` is assumed to
    # be defined at module level)
    random.seed(seed)
    # with torchsnooper.snoop():
    for epoch in range(max_epochs):
        iteration = 0
        cost_vec = []
        acc_vec = []
        startTime = time.time()
        gram.train()
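        # visit the mini-batches in a freshly shuffled order each epoch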
        for index in random.sample(range(n_batches), n_batches):
            optimizer.zero_grad()
            batchX = trainSet[0][index * batchSize:(index + 1) * batchSize]
            batchY = trainSet[1][index * batchSize:(index + 1) * batchSize]
            x, y, mask, lengths = padMatrix(batchX, batchY, options)
            x = torch.from_numpy(x).to(device).float()
            mask = torch.from_numpy(mask).to(device).float()
            # print('x,', x.size())
            y_hat = gram(x, mask, leavesList, ancestorsList)
            # print('y_hat', y_hat.size())
            y = torch.from_numpy(y).float().to(device)
            # print('y', y.size())
            lengths = torch.from_numpy(lengths).float().to(device)
            # print(y.size(), y_hat.size())
            loss, acc = loss_fn(y_hat, y, lengths)
            loss.backward()
            optimizer.step()
            if iteration % 100 == 0 and verbose:
                buf = 'Epoch:%d, Iteration:%d/%d, Train_Cost:%f, Train_Acc:%f' % (
                    epoch, iteration, n_batches, loss.item(), acc)
                print(buf)
            cost_vec.append(loss.item())
            acc_vec.append(acc)
            iteration += 1
        duration_optimize = time.time() - startTime
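        # evaluation mode: disable dropout for the validation/test passes below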
        gram.eval()
        cost = np.mean(cost_vec)
        acc = np.mean(acc_vec)
        startTime = time.time()
        with torch.no_grad():
            # compute loss and accuracy on the validation set
            cost_vec = []
            acc_vec = []
            for index in range(val_batches):
                validX = validSet[0][index * batchSize:(index + 1) * batchSize]
                validY = validSet[1][index * batchSize:(index + 1) * batchSize]
                val_x, val_y, mask, lengths = padMatrix(
                    validX, validY, options)
                val_x = torch.from_numpy(val_x).float().to(device)
                mask = torch.from_numpy(mask).float().to(device)
                val_y_hat = gram(val_x, mask, leavesList, ancestorsList)
                val_y = torch.from_numpy(val_y).float().to(device)
                lengths = torch.from_numpy(lengths).float().to(device)
                valid_cost, valid_acc = loss_fn(val_y_hat, val_y, lengths)
                cost_vec.append(valid_cost.item())
                acc_vec.append(valid_acc)
            valid_cost = np.mean(cost_vec)
            valid_acc = np.mean(acc_vec)

            # compute loss and accuracy on the test set
            cost_vec = []
            acc_vec = []
            for index in range(test_batches):
                testX = testSet[0][index * batchSize:(index + 1) * batchSize]
                testY = testSet[1][index * batchSize:(index + 1) * batchSize]
                test_x, test_y, mask, lengths = padMatrix(
                    testX, testY, options)
                test_x = torch.from_numpy(test_x).float().to(device)
                mask = torch.from_numpy(mask).float().to(device)
                test_y_hat = gram(test_x, mask, leavesList, ancestorsList)
                test_y = torch.from_numpy(test_y).float().to(device)
                lengths = torch.from_numpy(lengths).float().to(device)
                test_cost, test_acc = loss_fn(test_y_hat, test_y, lengths)
                cost_vec.append(test_cost.item())
                acc_vec.append(test_acc)
            test_cost = np.mean(cost_vec)
            test_acc = np.mean(acc_vec)
        # record the loss and acc
        loss_writer.add_scalar('Train Loss', cost, epoch)
        loss_writer.add_scalar('Test Loss', test_cost, epoch)
        loss_writer.add_scalar('Valid Loss', valid_cost, epoch)
        acc_writer.add_scalar('Train Acc', acc, epoch)
        acc_writer.add_scalar('Test Acc', test_acc, epoch)
        acc_writer.add_scalar('Valid Acc', valid_acc, epoch)

        # print the loss
        duration_metric = time.time() - startTime
        # accumulate per-epoch time so the final average is meaningful
        epochDuration += duration_optimize + duration_metric
        buf = 'Epoch:%d, Train_Cost:%f, Valid_Cost:%f, Test_Cost:%f' % (
            epoch, cost, valid_cost, test_cost)
        print(buf)
        print2file(buf, logFile)
        buf = 'Train_Acc:%f, Valid_Acc:%f, Test_Acc:%f' % (acc, valid_acc,
                                                           test_acc)
        print(buf)
        print2file(buf, logFile)
        buf = 'Optimize_Duration:%f, Metric_Duration:%f' % (duration_optimize,
                                                            duration_metric)
        print(buf)
        print2file(buf, logFile)

        # track the best-performing epoch by validation cost
        if valid_cost < bestValidCost:
            bestValidCost = valid_cost
            bestTestCost = test_cost
            bestTrainCost = cost
            bestEpoch = epoch
            bestTrainAcc = acc
            bestValidAcc = valid_acc
            bestTestAcc = test_acc

        # checkpoint the parameters after every epoch
        torch.save(gram.state_dict(), outFile + f'.{epoch}')

    buf = 'Best Epoch:%d, Avg_Duration:%f, Train_Cost:%f, Valid_Cost:%f, Test_Cost:%f' % (
        bestEpoch, epochDuration / max_epochs, bestTrainCost, bestValidCost,
        bestTestCost)
    print(buf)
    print2file(buf, logFile)
    buf = 'Train_Acc:%f, Valid_Acc:%f, Test_Acc:%f' % (
        bestTrainAcc, bestValidAcc, bestTestAcc)
    print(buf)
    print2file(buf, logFile)
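For reference, a minimal sketch of how train_GRAM might be invoked. Everything below is hypothetical: the file paths, dimensions, and epoch count are placeholders rather than values from the original project, and the treeFile prefix is only assumed to have matching '.level5.pk' through '.level1.pk' files next to it.

# Hypothetical usage sketch -- all paths and sizes below are placeholders.
if __name__ == '__main__':
    train_GRAM(seqFile='mimic.seqs',
               labelFile='mimic.labels',
               treeFile='mimic',       # expects mimic.level5.pk ... mimic.level1.pk
               embFile='mimic.emb',    # pre-trained embeddings consumed by GRAM
               outFile='gram_mimic',   # prefix for logs, TensorBoard runs, checkpoints
               inputDimSize=4894,      # number of leaf codes (placeholder)
               numAncestors=728,       # number of ancestor nodes (placeholder)
               max_epochs=30)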
Code example #2
# Assumes the same imports and module-level helpers as code example #1.
def test_whole_data(seqFile='seqFile.txt',
                    labelFile='labelFile.txt',
                    treeFile='tree.txt',
                    embFile='embFile.txt',
                    outFile='out.txt',
                    inputDimSize=100,
                    numAncestors=100,
                    embDimSize=100,
                    hiddenDimSize=200,
                    attentionDimSize=200,
                    max_epochs=100,
                    L2=0.,
                    numClass=26679,
                    batchSize=100,
                    dropoutRate=0.5,
                    logEps=1e-8,
                    verbose=True,
                    ignore_level=0):
    options = locals().copy()
    # get the best model through log
    # with open(outFile+'.log') as f:
    #     line = f.readlines()[-2]
    #     best_epoch = line.split(',')[0].split(':')[1]
    #     print('Best parameters occur epoch:', best_epoch)

    leavesList = []
    ancestorsList = []
    for i in range(5, 0, -1):
        leaves, ancestors = build_tree(treeFile + '.level' + str(i) + '.pk')
        leavesList.append(leaves)
        ancestorsList.append(ancestors)

    print('Loading the model ... ')
    # create the model
    gram = GRAM(inputDimSize, numAncestors, embDimSize, hiddenDimSize,
                attentionDimSize, numClass, dropoutRate, '').to(device)
    # read the best parameters
    # gram.load_state_dict(torch.load(outFile + '.' + best_epoch))
    gram.load_state_dict(torch.load(embFile))
    gram.eval()  # disable dropout; nn.Module defaults to training mode
    loss_fn = CrossEntropy()
    loss_fn.to(device)

    print('Loading the data ... ')
    dataset, _, _ = load_data(seqFile, labelFile, test_ratio=0, valid_ratio=0)
    typeFile = labelFile.split('.seqs')[0] + '.types'
    types = pickle.load(open(typeFile, 'rb'))
    # invert the mapping so integer indices can be translated back to ICD codes
    rTypes = {v: u for u, v in types.items()}
    print('Data length:', len(dataset[0]))
    n_batches = int(np.ceil(float(len(dataset[0])) / float(batchSize)))

    print('Calculating the result ...')
    cost_vec = []
    acc_vec = []
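    # per-disease counters: occurrences, top-20 hits, and summed ranks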
    num_for_each_disease = defaultdict(float)
    TP_for_each_disease = defaultdict(float)
    rank_for_each_disease = defaultdict(float)

    torch.set_grad_enabled(False)  # pure inference; gradients are not needed
    for index in range(n_batches):
        batchX = dataset[0][index * batchSize:(index + 1) * batchSize]
        batchY = dataset[1][index * batchSize:(index + 1) * batchSize]
        x, y, mask, lengths = padMatrix(batchX, batchY, options)
        x = torch.from_numpy(x).to(device).float()
        mask = torch.from_numpy(mask).to(device).float()
        y_hat = gram(x, mask, leavesList, ancestorsList)
        y = torch.from_numpy(y).float().to(device)
        lengths = torch.from_numpy(lengths).float().to(device)
        loss, acc = loss_fn(y_hat, y, lengths)
        cost_vec.append(loss.item())
        acc_vec.append(acc)

        # Calculating the accuracy for each disease
        # rank the predicted codes for each visit, highest score first
        _, indices = torch.sort(y_hat, dim=2, descending=True)
        # indices = indices[:, :, :20]
        for i, j, k in torch.nonzero(y, as_tuple=False):
            k = k.item()
            num_for_each_disease[k] += 1
            # find the rank of the true code k in the sorted predictions
            m = torch.nonzero(indices[i][j] == k,
                              as_tuple=False).view(-1).item()
            # count a hit if the true code ranks within the top 20
            if m < 20:
                TP_for_each_disease[k] += 1
            rank_for_each_disease[k] += (m + 1)

    cost = np.mean(cost_vec)
    acc = np.mean(acc_vec)
    print('Whole data average loss:%f, average accuracy@20:%f' % (cost, acc))

    print('Recording the accuracy for each disease ...')
    acc_out_file = outFile + '_all_acc.txt'
    # sort the disease by num
    num_for_each_disease = OrderedDict(
        sorted(num_for_each_disease.items(),
               key=lambda item: item[1],
               reverse=True))
    for disease in num_for_each_disease.keys():
        d_acc = TP_for_each_disease[disease] / num_for_each_disease[disease]
        avg_rank = rank_for_each_disease[disease] / num_for_each_disease[
            disease]
        buf = 'TypeNum:%d, icd_code:%s, Count:%d, avg_rank:%f, Accuracy:%f' % \
              (disease, rTypes[disease], num_for_each_disease[disease], avg_rank, d_acc)
        print2file(buf, acc_out_file)
    print('Done!')
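A similarly hedged sketch of calling test_whole_data. Note that despite its name, embFile must point at a state-dict checkpoint written by train_GRAM (one of the per-epoch files saved as outFile + '.<epoch>'); the checkpoint name below is a placeholder.

# Hypothetical usage sketch -- the checkpoint name is a placeholder.
if __name__ == '__main__':
    test_whole_data(seqFile='mimic.seqs',
                    labelFile='mimic.labels',
                    treeFile='mimic',
                    embFile='gram_mimic.17',   # checkpoint saved by train_GRAM (assumed)
                    outFile='gram_mimic_test',
                    inputDimSize=4894,         # must match training (placeholder)
                    numAncestors=728)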