def main(parser):
    global args, best_acc
    args = parser.parse_args()
    best_acc = 0

    #cnn
    model = models.resnet34(pretrained=False)
    # The loaded file must be a checkpoint saved by the training script
    # (a dict with a 'state_dict' key); the path below is a placeholder
    model_dict = torch.load('/your_dir/resnet34-333f7ec4.pth')
    #   To load your own model, switch to the two lines above.
    model.fc = nn.Linear(model.fc.in_features, 2)
    model.cuda()
    model = nn.DataParallel(model)
    model.load_state_dict(model_dict['state_dict'])

#    if args.weights==0.5:
#        criterion = nn.CrossEntropyLoss().cuda()
#    else:
#        w = torch.Tensor([1-args.weights, args.weights])
#        criterion = nn.CrossEntropyLoss(w).cuda()
#    optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
#
#    cudnn.benchmark = True

    #normalization
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    val_trans = transforms.Compose([transforms.ToTensor(), normalize])
    
    #load data
    val_dset = MILdataset(args.val_lib, 0, val_trans)
    val_loader = torch.utils.data.DataLoader(
            val_dset,
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=False)

    time_mark = time.strftime('%Y_%m_%d_', time.localtime(time.time()))
#    time_mark = '2020_03_06_'
    # Use the current time as the file-name identifier for saved outputs
        
    #open output file
    fconv = open(os.path.join(args.output, time_mark + 'Test_CNN_convergence_512.csv'), 'w')
    fconv.write('test,,\n')
    fconv.write('acc,recall,fnr')
    fconv.close()
    start_time = time.time()
    val_dset.setmode(1)
    val_probs = inference(0, val_loader, model, args.batch_size, 'test')  # epoch arg 0, matching inference(epoch, ...) used elsewhere
    v_topk = group_argtopk(np.array(val_dset.slideIDX), val_probs[:, 1], 1)
    v_pred = group_max(val_probs, v_topk, 1)

    metrics_meters = calc_accuracy(v_pred, val_dset.targets)
    str_logs = ['{} - {:.4}'.format(k, v) for k, v in metrics_meters.items()]
    s = ', '.join(str_logs)
    print('\tTest  metrics: ' + s)
    result = '\n' + str(metrics_meters['acc']) + ',' + str(metrics_meters['recall']) + ','\
                 + str(metrics_meters['fnr'])
    fconv = open(os.path.join(args.output, time_mark + 'Test_CNN_convergence_512.csv'), 'a')
    fconv.write(result)
    fconv.close()


    result_excel_origin(val_dset,v_pred,time_mark + '_test')
    np.save('output/numpy_save/test_infer_probs_' + time_mark + '.npy',val_probs)
        
    print('\tTest has been finished, needed %.2f sec.' % (time.time() - start_time))
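
# group_argtopk and group_max are imported from elsewhere in the repo. The
# sketches below are hypothetical reconstructions of their assumed semantics,
# based only on how they are called in these scripts; the underscore-prefixed
# names are illustrative, not the repo's real API.
def _group_argtopk_sketch(groups, data, k=1):
    # groups: slide index per tile; data: class-1 tile probabilities.
    # Returns the flat indices of each slide's k most probable tiles.
    groups, data = np.asarray(groups), np.asarray(data)
    order = np.lexsort((data, groups))   # sort by slide, then by probability
    g = groups[order]
    index = np.empty(len(g), 'bool')
    index[-k:] = True                    # last k tiles of the final slide
    index[:-k] = g[k:] != g[:-k]         # last k tiles of every other slide
    return list(order[index])


def _group_max_sketch(probs, topk, k):
    # probs: (n_tiles, 2) softmax outputs; topk: k consecutive indices per
    # slide. A slide is predicted positive if the max class-1 probability
    # among its top-k tiles exceeds 0.5.
    p = probs[np.asarray(topk), 1].reshape(-1, k)
    return (p.max(axis=1) > 0.5).astype(int)
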
def main(parser):
    global args, best_acc
    args = parser.parse_args()
    best_acc = 0

    #cnn
    model = models.resnet34(pretrained=False)
    #    model_path = '/your_dir/resnet34-333f7ec4.pth'
    #    model_dict = torch.load('output/2020_03_06_CNN_checkpoint_best_3.9.pth')
    #   To load your own model, switch to the two commented lines above.
    model.fc = nn.Linear(model.fc.in_features, 2)
    model.cuda()
    model = nn.DataParallel(model)
    #    model.load_state_dict(model_dict['state_dict'])

    if args.weights == 0.5:
        criterion = nn.CrossEntropyLoss().cuda()
    else:
        w = torch.Tensor([1 - args.weights, args.weights])
        criterion = nn.CrossEntropyLoss(w).cuda()
    optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)

    cudnn.benchmark = True

    #normalization
    normalize = transforms.Normalize(mean=[0.736, 0.58, 0.701],
                                     std=[0.126, 0.144, 0.113])
    train_trans = transforms.Compose([
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normalize
    ])
    val_trans = transforms.Compose([transforms.ToTensor(), normalize])

    #load data
    train_dset = MILdataset(args.train_lib, args.k, train_trans,
                            args.train_dir)
    train_loader = torch.utils.data.DataLoader(train_dset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=False)
    if args.val_lib:
        val_dset = MILdataset(args.val_lib, 0, val_trans, args.val_dir)
        val_loader = torch.utils.data.DataLoader(val_dset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=False)

    time_mark = time.strftime('%Y_%m_%d_', time.localtime(time.time()))
    #    time_mark = '2020_03_06_'
    # Use the current time as the file-name identifier for saved outputs

    #open output file
    fconv = open(
        os.path.join(args.output, time_mark + 'CNN_convergence_512.csv'), 'w')
    fconv.write(' ,Training,,,,Train_whole,,,Validation,,\n')
    fconv.write(
        'epoch,train_acc,train_recall,train_fnr,train_loss,true_acc,true_recall,true_fnr,acc,recall,fnr'
    )
    fconv.close()
    topk_list = []
    # Stores the top-k indices computed in each epoch
    early_stop_count = 0
    # Early-stopping counter; once the epoch passes 2/3 of nepochs, this
    # counter is checked to decide whether to stop training
    list_save_dir = os.path.join('output', 'topk_list')
    if not os.path.isdir(list_save_dir): os.makedirs(list_save_dir)

    best_metric_probs_inf_save = {
        'train_dset_slideIDX': train_dset.slideIDX,
        'train_dset_grid': train_dset.grid,
        'val_dset_slideIDX': val_dset.slideIDX,
        'val_dset_grid': val_dset.grid
    }
    # This dict stores the train_probs and val_probs of the best model for
    # later testing and feature extraction. The dset_slideIDX and dset_grid
    # lists are saved as well: when a slide in train_dset has fewer grid
    # patches than the top-k value k, its patches are re-sampled at random
    # with repetition, so reloading later could mismatch some grids and
    # probs. Saving these lists keeps the records consistent. val_dset
    # normally never hits this case; it is saved here only for symmetry.
    #loop through epochs
    for epoch in range(args.nepochs):
        if epoch >= args.nepochs * 2 / 3 and early_stop_count >= 3:
            print('Early stop at Epoch:' + str(epoch + 1))
            break
        start_time = time.time()
        #Train
        topk_exist_flag = False
        if os.path.exists(os.path.join(list_save_dir,
                                       time_mark + '.pkl')) and epoch == 0:
            with open(os.path.join(list_save_dir, time_mark + '.pkl'),
                      'rb') as fp:
                topk_list = pickle.load(fp)

            topk = topk_list[-1][0]
            # Recover the matching probs too; train_probs is used for the
            # slide-level metrics below and would otherwise be undefined
            train_probs = topk_list[-1][1]
            topk_exist_flag = True

        else:
            train_dset.setmode(1)
            train_probs = inference(epoch, train_loader, model,
                                    args.batch_size, 'train')
            topk = group_argtopk(np.array(train_dset.slideIDX),
                                 train_probs[:, 1], args.k)
            t_pred = group_max(train_probs, topk, args.k)
        repeat = 5
        if epoch >= 2 / 3 * args.nepochs:
            repeat = np.random.choice([3, 5])
            # Earlier epochs oversample with a fixed repeat factor of 5; in
            # the last third of training the repeat factor is chosen at random
            topk_last = topk_list[-1][0]
            if sum(np.not_equal(topk_last, topk)) < 0.01 * len(topk):
                early_stop_count += 1
        if not topk_exist_flag:
            topk_list.append((topk.copy(), train_probs.copy()))
        with open(os.path.join(list_save_dir, time_mark + '.pkl'), 'wb') as fp:
            pickle.dump(topk_list, fp)

        train_dset.maketraindata(topk, repeat)
        train_dset.shuffletraindata()
        train_dset.setmode(2)
        whole_acc, whole_recall, whole_fnr, whole_loss = train(
            epoch, train_loader, model, criterion, optimizer)
        print('\tTraining  Epoch: [{}/{}] Acc: {} Recall:{} Fnr:{} Loss: {}'.format(epoch+1, \
              args.nepochs, whole_acc,whole_recall,whole_fnr,whole_loss))

        topk = group_argtopk(np.array(train_dset.slideIDX), train_probs[:, 1],
                             1)
        t_pred = group_max(train_probs, topk, 1)
        metrics_meters = calc_accuracy(t_pred, train_dset.targets)
        result = '\n'+str(epoch+1) + ',' + str(whole_acc) + ',' +str(whole_recall)+ ',' +str(whole_fnr)+ ',' +str(whole_loss) \
                + ','+ str(metrics_meters['acc']) + ',' + str(metrics_meters['recall']) + ','\
                + str(metrics_meters['fnr'])

        val_dset.setmode(1)
        val_probs = inference(epoch, val_loader, model, args.batch_size, 'val')
        v_topk = group_argtopk(np.array(val_dset.slideIDX), val_probs[:, 1], 1)
        v_pred = group_max(val_probs, v_topk, 1)

        metrics_meters = calc_accuracy(v_pred, val_dset.targets)
        str_logs = [
            '{} - {:.4}'.format(k, v) for k, v in metrics_meters.items()
        ]
        s = ', '.join(str_logs)
        print('\tValidation  Epoch: [{}/{}]  '.format(epoch +
                                                      1, args.nepochs) + s)
        result = result + ','+ str(metrics_meters['acc']) + ',' + str(metrics_meters['recall']) + ','\
                 + str(metrics_meters['fnr'])
        fconv = open(
            os.path.join(args.output, time_mark + 'CNN_convergence_512.csv'),
            'a')
        fconv.write(result)
        fconv.close()
        #Save best model
        tmp_acc = (metrics_meters['acc'] + metrics_meters['recall']
                   ) / 2 - metrics_meters['fnr'] * args.weights
        if tmp_acc >= best_acc:
            best_acc = float(tmp_acc)  # cast; .copy() would fail on a plain float
            obj = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict()
            }
            torch.save(
                obj,
                os.path.join(args.output,
                             time_mark + 'CNN_checkpoint_best.pth'))

            if epoch > 0:
                result_excel_origin(train_dset, t_pred,
                                    time_mark + 'train_' + str(epoch + 1))
                result_excel_origin(val_dset, v_pred,
                                    time_mark + 'val_' + str(epoch + 1))
                #                np.save('output/numpy_save/' +time_mark + 'train_infer_probs_' + str(epoch+1) + '.npy',train_probs)
                #                np.save('output/numpy_save/' +time_mark + 'val_infer_probs_' + str(epoch+1) + '.npy',val_probs)
                best_metric_probs_inf_save['train_probs'] = train_probs.copy()
                best_metric_probs_inf_save['val_probs'] = val_probs.copy()

        print('\tEpoch %d has been finished, needed %.2f sec.' %
              (epoch + 1, time.time() - start_time))
    with open(os.path.join(list_save_dir, time_mark + '.pkl'), 'wb') as fp:
        pickle.dump(topk_list, fp)

    torch.save(best_metric_probs_inf_save,
               'output/numpy_save/final/best_metric_probs_inf.db')
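
# calc_accuracy is imported from elsewhere. A minimal sketch of the metric
# dict this training script expects from it (keys 'acc', 'recall', 'fnr' on
# binary labels); a hypothetical reconstruction, not the repo's actual code.
def _calc_accuracy_sketch(pred, real):
    pred, real = np.asarray(pred), np.asarray(real)
    pos = max(int((real == 1).sum()), 1)   # guard against zero positives
    tp = int(((pred == 1) & (real == 1)).sum())
    return {'acc': float((pred == real).mean()),  # overall accuracy
            'recall': tp / pos,                   # sensitivity on class 1
            'fnr': (pos - tp) / pos}              # false-negative rate
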
def main(parser):
    global args, best_acc
    args = parser.parse_args()
    best_acc = 0

    #cnn
    #    model  = densenet121_ibn_b(num_classes=2,pretrained = False)
    model = models.resnet18(pretrained=False)
    #    model_path = '/your_dir/resnet34-333f7ec4.pth'
    #    model_dict = torch.load('output/2020_03_06_CNN_checkpoint_best_3.9.pth')
    #   To load your own model, switch to the two commented lines above.
    model.fc = nn.Linear(model.fc.in_features, 2)
    model = nn.DataParallel(model.cuda())
    #    model.load_state_dict(model_dict['state_dict'])
    #    criterion = focal_loss(alpha=[1,args.weights/(1-args.weights)], gamma=2, num_classes = 2)

    if args.weights == 0.5:
        criterion = nn.CrossEntropyLoss().cuda()
    else:
        w = torch.Tensor([1, args.weights / (1 - args.weights)])
        criterion = nn.CrossEntropyLoss(w).cuda()
    optimizer = optim.Adam(model.parameters(), lr=1e-2, weight_decay=1e-4)

    cudnn.benchmark = True

    #normalization
    #    normalize = transforms.Normalize(mean=[0.736, 0.58, 0.701],std=[0.126, 0.144, 0.113])
    train_trans = transforms.Compose([
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])
    val_trans = transforms.Compose([transforms.ToTensor()])

    with open(args.select_lib, 'rb') as fp:
        target_train_slide, target_val_slide = pickle.load(fp)
    #load data
    train_dset = MILdataset(
        args.train_lib,
        0,
        train_trans,
        args.train_dir,
        target_train_slide,
    )
    train_loader = torch.utils.data.DataLoader(train_dset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=False)

    val_dset = MILdataset(args.train_lib, 0, val_trans, args.train_dir,
                          target_val_slide)
    val_loader = torch.utils.data.DataLoader(val_dset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=False)

    time_mark = time.strftime('%Y_%m_%d_', time.localtime(time.time()))
    #    time_mark = '2020_03_06_'
    # Use the current time as the file-name identifier for saved outputs

    #open output file
    fconv = open(
        os.path.join(args.output,
                     time_mark + 'CNN_convergence_224_resnet18_CE.csv'), 'w')
    fconv.write(' ,Train,,,,Validation,,,,Validation_whole,,\n')
    fconv.write('epoch,train_precision,train_recall,train_f1,train_loss,\
val_precision,val_recall,val_f1,val_loss,true_precision,true_recall,true_f1')
    fconv.close()
    early_stop_count = 0
    # Early-stopping counter; once the epoch passes 2/3 of nepochs, this
    # counter is checked to decide whether to stop training
    list_save_dir = os.path.join('output', 'topk_list', 'minmax')
    if not os.path.isdir(list_save_dir): os.makedirs(list_save_dir)

    #    best_metric_probs_inf_save = {'train_dset_slideIDX':train_dset.slideIDX,
    #                                  'train_dset_grid':train_dset.grid,
    #                                  'val_dset_slideIDX':val_dset.slideIDX,
    #                                  'val_dset_grid':val_dset.grid
    #                                  }
    # This dict stores the train_probs and val_probs of the best model for
    # later testing and feature extraction; the dset_slideIDX and dset_grid
    # lists would be saved as well because slides with fewer grid patches
    # than the top-k value k are re-sampled at random with repetition, which
    # could otherwise mismatch grids and probs on reload. val_dset normally
    # never hits this case; it would be saved only for symmetry.
    train_dset.maketraindata('train')
    val_dset.maketraindata('val')
    epoch_save = {}
    val_probs_save = {}
    #loop through epochs
    for epoch in range(args.nepochs):
        if epoch >= args.nepochs * 2 / 3 and early_stop_count >= 3:
            print('Early stop at Epoch:' + str(epoch + 1))
            break
        start_time = time.time()
        #Train
        train_dset.setmode(2)
        train_dset.shuffletraindata()
        train_whole_precision, train_whole_recall, train_whole_f1, train_whole_loss = train_predict(
            epoch, train_loader, model, criterion, optimizer, 'train', True, 2)
        print('\tTraining  Epoch: [{}/{}] Precision: {} Recall:{} F1score:{} Loss: {}'.format(epoch+1, \
              args.nepochs, train_whole_precision,train_whole_recall,train_whole_f1,train_whole_loss))

        result = '\n' + str(epoch + 1) + ',' + str(
            train_whole_precision) + ',' + str(train_whole_recall) + ',' + str(
                train_whole_f1) + ',' + str(train_whole_loss)
        #        eopch_save.update({epoch+1:copy.copy(model.state_dict())})
        #        torch.save(eopch_save, os.path.join(args.output, time_mark +'resnet18_ibn_b_checkpoint_224.pth'))

        val_dset.setmode(2)
        val_whole_precision, val_whole_recall, val_whole_f1, val_whole_loss, val_probs = train_predict(
            epoch, val_loader, model, criterion, optimizer, 'val')
        #        v_topk = group_argtopk(np.array(val_dset.slideIDX), val_probs[val_dset.label_mark], args.k)
        #        v_pred = group_max(val_probs,v_topk,args.k)
        v_pred = group_identify(val_dset.slideIDX, val_probs)

        metrics_meters = calc_accuracy(v_pred, val_dset.targets)
        str_logs = [
            '{} - {:.4}'.format(k, v) for k, v in metrics_meters.items()
        ]
        s = ', '.join(str_logs)
        print('\tValidation  Epoch: [{}/{}]  '.format(epoch +
                                                      1, args.nepochs) + s)
        result = result + ',' + str(val_whole_precision) + ',' +str(val_whole_recall) + ',' +str(val_whole_f1) + ',' \
                 + str(val_whole_loss) + ','+ str(metrics_meters['precision']) + ',' + str(metrics_meters['recall']) + ','\
                 + str(metrics_meters['f1score'])
        fconv = open(
            os.path.join(args.output,
                         time_mark + 'CNN_convergence_224_resnet18_CE.csv'),
            'a')
        fconv.write(result)
        fconv.close()
        #Save best model
        tmp_acc = 1 - val_whole_loss  #(metrics_meters['acc'] + metrics_meters['recall'])/2 #- metrics_meters['fnr']*args.weights
        if tmp_acc >= best_acc:
            best_acc = float(tmp_acc)  # cast; .copy() would fail on a plain float
            #            obj = {
            #                'epoch': epoch+1,
            #                'state_dict': model.state_dict(),
            #                'best_acc': best_acc,
            #                'optimizer' : optimizer.state_dict()
            #            }
            #            torch.save(obj, os.path.join(args.output, time_mark +'CNN_checkpoint_best.pth'))
            #            best_metric_probs_inf_save['train_probs'] = train_probs.copy()
            early_stop_count = 0
        else:
            early_stop_count += 1

        if epoch > 0:
            # deepcopy: a shallow copy would alias the live parameter tensors,
            # so every saved epoch would end up with identical weights
            epoch_save.update({epoch + 1: copy.deepcopy(model.state_dict())})
            val_probs_save.update({epoch + 1: val_probs})
            #            result_excel_origin(train_dset,t_pred,time_mark + 'train_' + str(epoch+1))
            result_excel_origin(val_dset, v_pred,
                                time_mark + 'val_' + str(epoch + 1))


#                np.save('output/numpy_save/' +time_mark + 'train_infer_probs_' + str(epoch+1) + '.npy',train_probs)
#                np.save('output/numpy_save/' +time_mark + 'val_infer_probs_' + str(epoch+1) + '.npy',val_probs)

        print('\tEpoch %d has been finished, needed %.2f sec.' %
              (epoch + 1, time.time() - start_time))
        #    with open(os.path.join(list_save_dir, time_mark + '.pkl'), 'wb') as fp:
        #        pickle.dump(topk_list, fp)

        torch.save(
            val_probs_save,
            os.path.join(args.output, 'numpy_save/final/minmax',
                         time_mark + 'resnet18_val_probs_224.db'))
        torch.save(
            epoch_save,
            os.path.join(args.output,
                         time_mark + 'resnet18_ibn_b_checkpoint_224.pth'))
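
# group_identify is imported from elsewhere. A minimal sketch of the assumed
# slide-level aggregation, judging from how it is called here (hypothetical):
# average each slide's class-1 tile probabilities and threshold at 0.5.
def _group_identify_sketch(slideIDX, probs, threshold=0.5):
    slideIDX = np.asarray(slideIDX)
    preds = []
    for idx in np.unique(slideIDX):                # one prediction per slide
        mean_p = probs[slideIDX == idx, 1].mean()  # mean class-1 probability
        preds.append(int(mean_p > threshold))
    return np.array(preds)
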
def main(parser):
    global args, best_acc
    args = parser.parse_args()
    best_acc = 0

    #cnn
    model = densenet121_ibn_b(num_classes=2, pretrained=False)
    #    model = models.resnet18(pretrained = False)
    #    model_path = '/your_dir/resnet34-333f7ec4.pth'
    model_dict = torch.load(
        'output/2020_05_13_densenet121_ibn_b_checkpoint_224.pth')
    #   To load your own model, switch to the two commented lines above.
    #    model.fc = nn.Linear(model.fc.in_features, 2)
    model = nn.DataParallel(model.cuda())
    model.load_state_dict(model_dict[20])  # weights saved under epoch key 20
    #    criterion = focal_loss(alpha=[1,args.weights/(1-args.weights)], gamma=2, num_classes = 2)

    if args.weights == 0.5:
        criterion = nn.CrossEntropyLoss().cuda()
    else:
        w = torch.Tensor([1, args.weights / (1 - args.weights)])
        criterion = nn.CrossEntropyLoss(w).cuda()
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

    cudnn.benchmark = True

    #    train_trans = transforms.Compose([transforms.RandomVerticalFlip(),
    #                                      transforms.RandomHorizontalFlip(),
    #                                      transforms.ToTensor()])
    val_trans = transforms.Compose([transforms.ToTensor()])

    #    with open(args.select_lib, 'rb') as fp:
    #        target_train_slide,target_val_slide = pickle.load(fp)
    #    #load data
    #    train_dset = MILdataset(args.train_lib, args.k,train_trans,args.train_dir,target_train_slide,)
    #    train_loader = torch.utils.data.DataLoader(
    #        train_dset,
    #        batch_size=args.batch_size, shuffle=False,
    #        num_workers=args.workers, pin_memory=False)
    batch_select = ['batch_3_SYSUCC', 'batch_5_SYSUCC']

    for i in range(len(batch_select)):
        start_time = time.time()  # per-batch timer for the elapsed-time print
        val_dset = MILdataset(args.train_lib, 0, val_trans, args.train_dir,
                              [batch_select[i]])
        val_loader = torch.utils.data.DataLoader(val_dset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=False)

        val_dset.setmode(1)
        val_whole_precision, val_whole_recall, val_whole_f1, val_whole_loss, val_probs = train_predict(
            0, val_loader, model, criterion, optimizer, 'val')
        #        v_topk = group_argtopk(np.array(val_dset.slideIDX), val_probs[val_dset.label_mark], args.k)
        #        v_pred = group_max(val_probs,v_topk,args.k)
        #    val_probs = np.load('output/numpy_save/' + batch_select[0] + '_val' + '.npy')
        v_pred = group_identify(val_dset.slideIDX, val_probs)

        metrics_meters = calc_accuracy(v_pred, val_dset.targets)
        fconv = open(
            os.path.join(args.output, batch_select[i] + '_metric.csv'), 'w')
        fconv.write(
            'sample_precision,sample_recall,sample_f1,slide_precision,slide_recall,slide_f1'
        )
        result = '\n' + str(val_whole_precision) + ',' +str(val_whole_recall) + ',' +str(val_whole_f1) + ',' \
                     + str(metrics_meters['precision']) + ',' + str(metrics_meters['recall']) + ','\
                     + str(metrics_meters['f1score'])
        fconv.write(result)
        fconv.close()

        result_excel_origin(val_dset, v_pred, batch_select[i] + '_val')
        np.save('output/numpy_save/' + batch_select[i] + '_val' + '.npy',
                val_probs)
        #                np.save('output/numpy_save/' +time_mark + 'val_infer_probs_' + str(epoch+1) + '.npy',val_probs)

        msi_pro = group_proba(val_dset.slideIDX, val_probs, 0.5)
        fpr, tpr, thresholds = roc_curve(val_dset.targets,
                                         msi_pro,
                                         pos_label=1)
        roc_auc = auc(fpr, tpr)

        # Plot the ROC curve
        plt.subplots(figsize=(7, 5.5))
        plt.plot(fpr,
                 tpr,
                 color='darkorange',
                 linewidth=2,
                 label='ROC curve (area = %0.3f)' % roc_auc)
        plt.plot([0, 1], [0, 1], color='navy', linewidth=2, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title(batch_select[i] + ' ROC Curve')
        plt.legend(loc="lower right")
        #    plt.show()
        plt.savefig(os.path.join('output', '0.5',
                                 f"{batch_select[i]}_ROC.png"))

        group_log(val_dset.slidenames, val_dset.slideIDX, val_dset.targets,
                  val_dset.label_mark, val_probs,
                  batch_select[i] + '_metric_info')

        print(batch_select[i] + '\t has been finished, needed %.2f sec.' %
              (time.time() - start_time))
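
# group_proba is imported from elsewhere; it must return one continuous score
# per slide, since its output feeds roc_curve above. A hypothetical sketch of
# one plausible reading of its (slideIDX, probs, threshold) arguments: the
# fraction of a slide's tiles whose class-1 probability exceeds the threshold.
def _group_proba_sketch(slideIDX, probs, threshold=0.5):
    slideIDX = np.asarray(slideIDX)
    scores = []
    for idx in np.unique(slideIDX):
        tile_p = probs[slideIDX == idx, 1]                 # one slide's class-1 probs
        scores.append(float((tile_p > threshold).mean()))  # positive-tile fraction
    return np.array(scores)
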
def main(parser):
    global args, best_acc
    args = parser.parse_args()
    best_acc = 0

    #cnn
    model = models.resnet34(num_classes=2, pretrained=False)
    #    model_path = '/your_dir/resnet34-333f7ec4.pth'
    model_dict = torch.load('output/2020_03_06_CNN_checkpoint_best_3.9.pth')
    #   To load your own model, switch to the two lines above.
    #    model.fc = nn.Linear(model.fc.in_features, 2)
    model.cuda()
    model = nn.DataParallel(model)
    model.load_state_dict(model_dict['state_dict'])

    #    get_feature_model = nn.Sequential(list(model.children())[0].layer4[-1]).cuda()
    #    get_feature_model = nn.Sequential(*list(list(model.children())[0].children())[:-2]).cuda()
    get_feature_model = nn.Sequential(
        *list(list(model.children())[0].children())[:-1], Flatten()).cuda()
    get_feature_model.eval()

    # ori_lstm args read as (input_size, hidden_size, num_layers,
    # bidirectional, num_classes) -- see the sketch at the end of the file
    lstm_model = ori_lstm(512, 156, 2, True, 2).cuda()

    if args.weights == 0.5:
        criterion = nn.CrossEntropyLoss().cuda()
    else:
        w = torch.Tensor([1 - args.weights, args.weights])
        criterion = nn.CrossEntropyLoss(w).cuda()
    # Optimize the LSTM's parameters; the CNN backbone only extracts features
    optimizer = optim.Adam(lstm_model.parameters(), lr=1e-4, weight_decay=1e-4)

    cudnn.benchmark = True

    #normalization
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    #    train_trans = transforms.Compose([transforms.RandomVerticalFlip(),
    #                                      transforms.RandomHorizontalFlip(),
    #                                      transforms.ToTensor(),
    #                                        normalize])
    # When training the LSTM on the already-extracted top-k features, complex
    # transforms are unnecessary; the most basic ones suffice.
    val_trans = transforms.Compose([transforms.ToTensor(), normalize])

    best_metric_probs_inf_save = torch.load(
        'output/numpy_save/final/best_metric_probs_inf.db')

    #load data
    train_dset = MILdataset(args.train_lib, args.k, val_trans,
                            best_metric_probs_inf_save['train_dset_grid'],
                            best_metric_probs_inf_save['train_dset_slideIDX'])
    train_loader = torch.utils.data.DataLoader(train_dset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=False)
    if args.val_lib:
        val_dset = MILdataset(args.val_lib, 0, val_trans,
                              best_metric_probs_inf_save['val_dset_grid'],
                              best_metric_probs_inf_save['val_dset_slideIDX'])
        val_loader = torch.utils.data.DataLoader(val_dset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=False)


#    summary(model,input_size=(3,224,224))
#    img = train_dset.slides[0].read_region(train_dset.grid[0],train_dset.level,(train_dset.patch_size[0],\
#                                                    train_dset.patch_size[0])).convert('RGB')
#    img = img.resize((224,224),Image.BILINEAR)
#    img_var = val_trans(img).unsqueeze(0)
#    feature = get_feature_model(img_var.cuda())

    time_mark = time.strftime('%Y_%m_%d_', time.localtime(time.time()))
    # Use the current time as the file-name identifier for saved outputs

    #open output file
    fconv = open(
        os.path.join(args.output, time_mark + 'LSTM_convergence_512.csv'), 'w')
    fconv.write(' ,Training,,,,Train_whole,,,Validation,,\n')
    fconv.write(
        'epoch,train_acc,train_recall,train_fnr,train_loss,true_acc,true_recall,true_fnr,acc,recall,fnr'
    )
    fconv.close()
    #    topk_list = []
    # Stores the top-k indices computed in each epoch
    early_stop_count = 0
    # Early-stopping counter; once the epoch passes 2/3 of nepochs, this
    # counter is checked to decide whether to stop training

    train_probs = best_metric_probs_inf_save['train_probs']
    topk = group_argtopk(np.array(train_dset.slideIDX), train_probs[:, 1],
                         args.k)
    tmp_topk = group_argtopk(np.array(train_dset.slideIDX), train_probs[:, 1],
                             1)

    val_probs = best_metric_probs_inf_save['val_probs']
    v_topk = group_argtopk(np.array(val_dset.slideIDX), val_probs[:, 1], 1)
    # The val set is extracted directly from each slide's top-1 patch

    val_dset.setmode(3)
    val_dset.settopk(v_topk, get_feature_model)

    #loop through epochs
    for epoch in range(args.nepochs):
        if epoch >= args.nepochs * 2 / 3 and early_stop_count >= 3:
            print('Early stop at Epoch:' + str(epoch + 1))
            break
        start_time = time.time()
        #Train

        train_dset.setmode(3)
        train_dset.settopk(topk, get_feature_model)
        whole_acc, whole_recall, whole_fnr, whole_loss = train(
            epoch, train_loader, lstm_model, criterion, optimizer)
        print('\tTraining  Epoch: [{}/{}] Acc: {} Recall:{} Fnr:{} Loss: {}'.format(epoch+1, \
              args.nepochs, whole_acc,whole_recall,whole_fnr,whole_loss))

        train_dset.settopk(tmp_topk, get_feature_model)
        tmp_train_probs = inference(epoch, train_loader, lstm_model,
                                    args.batch_size, 'train')
        metrics_meters = calc_accuracy(np.argmax(tmp_train_probs, axis=1),
                                       train_dset.targets)
        # Then compute train-set metrics from each slide's top-1 patch
        result = '\n'+str(epoch+1) + ',' + str(whole_acc) + ',' +str(whole_recall)+ ',' +str(whole_fnr)+ ',' +str(whole_loss) \
                + ','+ str(metrics_meters['acc']) + ',' + str(metrics_meters['recall']) + ','\
                + str(metrics_meters['fnr'])

        tmp_val_probs = inference(epoch, val_loader, lstm_model,
                                  args.batch_size, 'val')
        metrics_meters = calc_accuracy(np.argmax(tmp_val_probs, axis=1),
                                       val_dset.targets)
        # Compute val-set metrics
        str_logs = [
            '{} - {:.4}'.format(k, v) for k, v in metrics_meters.items()
        ]
        s = ', '.join(str_logs)
        print('\tValidation  Epoch: [{}/{}]  '.format(epoch +
                                                      1, args.nepochs) + s)
        result = result + ','+ str(metrics_meters['acc']) + ',' + str(metrics_meters['recall']) + ','\
                 + str(metrics_meters['fnr'])
        fconv = open(
            os.path.join(args.output, time_mark + 'LSTM_convergence_512.csv'),
            'a')
        fconv.write(result)
        fconv.close()
        #Save best model
        tmp_acc = (metrics_meters['acc'] + metrics_meters['recall']
                   ) / 2 - metrics_meters['fnr'] * args.weights
        if tmp_acc >= best_acc:
            best_acc = float(tmp_acc)  # cast; .copy() would fail on a plain float
            obj = {
                'epoch': epoch + 1,
                # Save the LSTM being trained, not the frozen CNN backbone
                'state_dict': lstm_model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict()
            }
            torch.save(
                obj,
                os.path.join(args.output,
                             time_mark + 'LSTM_checkpoint_best.pth'))

            if epoch > 0:
                result_excel_origin(train_dset,
                                    np.argmax(tmp_train_probs, axis=1),
                                    time_mark + 'lstm_train_' + str(epoch + 1))
                result_excel_origin(val_dset, np.argmax(tmp_val_probs, axis=1),
                                    time_mark + 'lstm_val_' + str(epoch + 1))

        else:
            early_stop_count += 1

        print('\tEpoch %d has been finished, needed %.2f sec.' %
              (epoch + 1, time.time() - start_time))
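
# Flatten and ori_lstm are imported from elsewhere. Minimal sketches of what
# they are assumed to be, matching their use above (hypothetical
# reconstructions): Flatten() turns pooled CNN feature maps into vectors, and
# ori_lstm(512, 156, 2, True, 2) reads as (input_size, hidden_size,
# num_layers, bidirectional, num_classes).
class _FlattenSketch(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)   # (N, C, 1, 1) -> (N, C)


class _OriLstmSketch(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, bidirectional,
                 num_classes):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                            batch_first=True, bidirectional=bidirectional)
        directions = 2 if bidirectional else 1
        self.fc = nn.Linear(hidden_size * directions, num_classes)

    def forward(self, x):
        # x: (batch, k, input_size) -- one slide's top-k tile features
        out, _ = self.lstm(x)
        return self.fc(out[:, -1, :])  # classify from the last step's output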