Example #1
def test(use_cuda):
    # data
    transformations = get_transforms(input_size=args.image_size,
                                     test_size=args.image_size)
    test_set = data_gen.TestDataset(root=args.test_txt_path,
                                    transform=transformations['test'])
    test_loader = data.DataLoader(test_set,
                                  batch_size=args.batch_size,
                                  shuffle=False)
    # load model
    model = make_model(args)

    if args.model_path:
        # load the trained weights
        model.load_state_dict(torch.load(args.model_path))

    if use_cuda:
        model.cuda()

    # evaluate
    y_pred = []
    y_true = []
    img_paths = []
    with torch.no_grad():
        model.eval()  # switch to evaluation mode
        for (inputs, targets, paths) in tqdm(test_loader):
            y_true.extend(targets.detach().tolist())
            img_paths.extend(list(paths))
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
            # compute output
            outputs = model(inputs)  # shape: (batch_size, num_classes), e.g. (16, 2)
            # dim=1 applies softmax across the class dimension, i.e. one distribution per row
            # probability = torch.nn.functional.softmax(outputs, dim=1)[:, 1].tolist()
            # probability = [1 if prob >= 0.5 else 0 for prob in probability]
            # take the index of the largest logit as the predicted class
            predicted = torch.max(outputs, dim=1)[1].cpu().numpy().squeeze()
            y_pred.extend(predicted)
        print("y_pred=", y_pred)

        accuracy = metrics.accuracy_score(y_true, y_pred)
        print("accuracy=", accuracy)
        confusion_matrix = metrics.confusion_matrix(y_true, y_pred)
        print("confusion_matrix=", confusion_matrix)
        print(metrics.classification_report(y_true, y_pred))
        # fpr,tpr,thresholds = metrics.roc_curve(y_true,y_pred)
        print("roc-auc score=", metrics.roc_auc_score(y_true, y_pred))

        res_dict = {
            'img_path': img_paths,
            'label': y_true,
            'predict': y_pred,
        }
        df = pd.DataFrame(res_dict)
        df.to_csv(args.result_csv, index=False)
        print(f"write to {args.result_csv} succeed ")
Example #2
    def __init__(self, model_name, model_path):
        self.model_name = model_name
        self.model_path = model_path
        self.model = make_model(args)
        #self.model = models.__dict__['resnet50'](num_classes=54)
        self.use_cuda = False
        if torch.cuda.is_available():
            print('Using GPU for inference')
            self.use_cuda = True
            self.model = torch.nn.DataParallel(self.model).cuda()
            checkpoint = torch.load(self.model_path)
            self.model.load_state_dict(checkpoint['state_dict'])
        else:
            print('Using CPU for inference')
            checkpoint = torch.load(self.model_path, map_location='cpu')
            state_dict = OrderedDict()
            # The training script main.py saves five keys: 'epoch', 'arch',
            # 'state_dict', 'best_acc1' and 'optimizer'; only the value under
            # 'state_dict' holds the model parameters.
            # main.py wraps the model in torch.nn.DataParallel, so every key in
            # the saved state dict carries a 'module.' prefix; key[7:] below
            # strips that prefix.
            for key, value in checkpoint['state_dict'].items():
                tmp = key[7:]
                state_dict[tmp] = value
            self.model.load_state_dict(state_dict)

        self.model.eval()

        #self.idx_to_class = checkpoint['idx_to_class']
        #self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
        #                                 std=[0.229, 0.224, 0.225])

        #self.transforms = transforms.Compose([
        #    transforms.Resize(256),
        #   transforms.CenterCrop(224),
        #   transforms.ToTensor(),
        #    self.normalize
        #])
        self.transforms = get_test_transform([0.485, 0.456, 0.406],
                                             [0.229, 0.224, 0.225],
                                             224)


        self.label_id_name_dict = \
            {
                "0": "工艺品/仿唐三彩",
                "1": "工艺品/仿宋木叶盏",
                "2": "工艺品/布贴绣",
                "3": "工艺品/景泰蓝",
                "4": "工艺品/木马勺脸谱",
                "5": "工艺品/柳编",
                "6": "工艺品/葡萄花鸟纹银香囊",
                "7": "工艺品/西安剪纸",
                "8": "工艺品/陕历博唐妞系列",
                "9": "景点/关中书院",
                "10": "景点/兵马俑",
                "11": "景点/南五台",
                "12": "景点/大兴善寺",
                "13": "景点/大观楼",
                "14": "景点/大雁塔",
                "15": "景点/小雁塔",
                "16": "景点/未央宫城墙遗址",
                "17": "景点/水陆庵壁塑",
                "18": "景点/汉长安城遗址",
                "19": "景点/西安城墙",
                "20": "景点/钟楼",
                "21": "景点/长安华严寺",
                "22": "景点/阿房宫遗址",
                "23": "民俗/唢呐",
                "24": "民俗/皮影",
                "25": "特产/临潼火晶柿子",
                "26": "特产/山茱萸",
                "27": "特产/玉器",
                "28": "特产/阎良甜瓜",
                "29": "特产/陕北红小豆",
                "30": "特产/高陵冬枣",
                "31": "美食/八宝玫瑰镜糕",
                "32": "美食/凉皮",
                "33": "美食/凉鱼",
                "34": "美食/德懋恭水晶饼",
                "35": "美食/搅团",
                "36": "美食/枸杞炖银耳",
                "37": "美食/柿子饼",
                "38": "美食/浆水面",
                "39": "美食/灌汤包",
                "40": "美食/烧肘子",
                "41": "美食/石子饼",
                "42": "美食/神仙粉",
                "43": "美食/粉汤羊血",
                "44": "美食/羊肉泡馍",
                "45": "美食/肉夹馍",
                "46": "美食/荞面饸饹",
                "47": "美食/菠菜面",
                "48": "美食/蜂蜜凉粽子",
                "49": "美食/蜜饯张口酥饺",
                "50": "美食/西安油茶",
                "51": "美食/贵妃鸡翅",
                "52": "美食/醪糟",
                "53": "美食/金线油塔"
            }
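The constructor above only builds the model and the transforms; the prediction step is implied but not shown. A hedged sketch of how a single-image prediction could look with these attributes (the method name predict and the softmax call are assumptions, not taken from the source):

    def predict(self, img_path):
        from PIL import Image  # local import to keep this sketch self-contained
        img = Image.open(img_path).convert('RGB')
        x = self.transforms(img).unsqueeze(0)  # add the batch dimension
        if self.use_cuda:
            x = x.cuda()
        with torch.no_grad():
            logits = self.model(x)
            probs = torch.nn.functional.softmax(logits, dim=1)
            conf, idx = torch.max(probs, dim=1)
        # map the predicted class index to its human-readable label
        return self.label_id_name_dict[str(idx.item())], conf.item()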
Example #3
def main():
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.checkpoint):  # create the folder that will hold the trained checkpoints
        mkdir_p(args.checkpoint)

    # Data
    transform = get_transforms(input_size=args.image_size,
                               test_size=args.image_size,
                               backbone=None)
    print('==> Preparing dataset %s' % args.trainroot)
    trainset = datasets.ImageFolder(root=args.trainroot,
                                    transform=transform['val_train'])
    train_loader = data.DataLoader(trainset,
                                   batch_size=args.train_batch,
                                   shuffle=True,
                                   num_workers=args.workers,
                                   pin_memory=True)

    valset = datasets.ImageFolder(root=args.valroot,
                                  transform=transform['val_test'])
    val_loader = data.DataLoader(valset,
                                 batch_size=args.test_batch,
                                 shuffle=False,
                                 num_workers=args.workers,
                                 pin_memory=True)
    """
    # visualization sanity check: confirm the labels actually match the images
    inputs, classes = next(iter(train_loader))
    out = torchvision.utils.make_grid(inputs)
    class_names = ["cat", "dog"]
    visualize.imshow(out, title=[class_names[x] for x in classes])
    """

    # build the model
    model = make_model(args)
    if use_cuda:
        model = model.to(device)
    # print(model, dir(model))
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # define loss function (criterion) and optimizer
    if use_cuda:
        criterion = nn.CrossEntropyLoss().cuda()  # cuda version
    else:
        criterion = nn.CrossEntropyLoss().to(device)
    optimizer = get_optimizer(model, args)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='min',
                                                           factor=0.2,
                                                           patience=5,
                                                           verbose=False)

    # Resume
    title = 'ImageNet-' + args.arch
    if args.resume:  # resume training from the interrupted checkpoint
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint file found!'
        args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.module.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names([
            'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.',
            'Valid Acc.'
        ])

    if args.evaluate:
        print('\nEvaluation only')
        test_loss, test_acc = test(val_loader, model, criterion, start_epoch,
                                   use_cuda)
        print(' Test Loss:  %.8f, Test Acc:  %.2f' % (test_loss, test_acc))
        return

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.epochs, optimizer.param_groups[0]['lr']))

        train_loss, train_acc, train_5 = train(train_loader, model, criterion,
                                               optimizer, epoch, use_cuda)
        test_loss, test_acc, test_5 = test(val_loader, model, criterion, epoch,
                                           use_cuda)

        scheduler.step(test_loss)

        # append logger file
        logger.append(
            [optimizer.param_groups[0]['lr'], train_loss, test_loss, train_acc, test_acc])
        print(
            'train_loss:%f, val_loss:%f, train_acc:%f, train_5:%f, val_acc:%f, val_5:%f'
            % (train_loss, test_loss, train_acc, train_5, test_acc, test_5))

        # save model
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)

        if len(args.gpu_id) > 1:
            save_checkpoint(
                {
                    'fold': 0,
                    'epoch': epoch + 1,
                    'state_dict': model.module.state_dict(),
                    'train_acc': train_acc,
                    'acc': test_acc,
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                single=True,
                checkpoint=args.checkpoint)
        else:
            save_checkpoint(
                {
                    'fold': 0,
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'train_acc': train_acc,
                    'acc': test_acc,
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                single=True,
                checkpoint=args.checkpoint)

    logger.close()
    logger.plot()
    savefig(os.path.join(args.checkpoint, 'log.eps'))

    print('Best acc:')
    print(best_acc)
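save_checkpoint is called with single=True above, but its body is not shown anywhere in these examples. A common implementation, following the usual PyTorch ImageNet recipe, writes the state dict to disk and copies it to a separate "best" file when is_best is set; the sketch below is an assumption about what the helper does, including the filenames:

import os
import shutil
import torch

def save_checkpoint(state, is_best, single=True, checkpoint='checkpoint',
                    filename='checkpoint.pth.tar'):
    # save the latest state, then keep a separate copy of the best model so far
    # (`single` is accepted only to match the call sites in these examples)
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))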
Example #4
def main():
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # Data
    transform = get_transforms(input_size=args.image_size,
                               test_size=args.image_size,
                               backbone=None)

    print('==> Preparing dataset %s' % args.trainroot)
    trainset = dataset.Dataset(root=args.trainroot,
                               transform=transform['val_train'])
    train_loader = data.DataLoader(trainset,
                                   batch_size=args.train_batch,
                                   shuffle=True,
                                   num_workers=args.workers,
                                   pin_memory=True)

    valset = dataset.TestDataset(root=args.valroot,
                                 transform=transform['val_test'])
    val_loader = data.DataLoader(valset,
                                 batch_size=args.test_batch,
                                 shuffle=False,
                                 num_workers=args.workers,
                                 pin_memory=True)

    model = make_model(args)

    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

    cudnn.benchmark = True
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = get_optimizer(model, args)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='min',
                                                           factor=0.2,
                                                           patience=5,
                                                           verbose=False)

    # Resume
    title = 'ImageNet-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint file found!'
        args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.module.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names([
            'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.',
            'Valid Acc.'
        ])

    if args.evaluate:
        print('\nEvaluation only')
        test_loss, test_acc = test(val_loader, model, criterion, start_epoch,
                                   use_cuda)
        print(' Test Loss:  %.8f, Test Acc:  %.2f' % (test_loss, test_acc))
        return

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.epochs, optimizer.param_groups[0]['lr']))

        train_loss, train_acc, train_5 = train(train_loader, model, criterion,
                                               optimizer, epoch, use_cuda)
        test_loss, test_acc, test_5 = test(val_loader, model, criterion, epoch,
                                           use_cuda)

        scheduler.step(test_loss)

        # append logger file
        logger.append(
            [optimizer.param_groups[0]['lr'], train_loss, test_loss, train_acc, test_acc])
        print(
            'train_loss:%f, val_loss:%f, train_acc:%f, train_5:%f, val_acc:%f, val_5:%f'
            % (train_loss, test_loss, train_acc, train_5, test_acc, test_5))

        # save model
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)

        if len(args.gpu_id) > 1:
            save_checkpoint(
                {
                    'fold': 0,
                    'epoch': epoch + 1,
                    'state_dict': model.module.state_dict(),
                    'train_acc': train_acc,
                    'acc': test_acc,
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                single=True,
                checkpoint=args.checkpoint)
        else:
            save_checkpoint(
                {
                    'fold': 0,
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'train_acc': train_acc,
                    'acc': test_acc,
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                single=True,
                checkpoint=args.checkpoint)

    logger.close()
    logger.plot()
    savefig(os.path.join(args.checkpoint, 'log.eps'))

    print('Best acc:')
    print(best_acc)
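Several of these examples index the result of get_transforms with keys such as 'train', 'val', 'val_train', 'val_test' and 'test', so it presumably returns a dict of torchvision pipelines. A rough sketch under that assumption; the exact augmentations are guesses, only the key names come from the calling code:

from torchvision import transforms

def get_transforms(input_size=224, test_size=224, backbone=None):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_tf = transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    eval_tf = transforms.Compose([
        transforms.Resize(int(test_size * 1.14)),
        transforms.CenterCrop(test_size),
        transforms.ToTensor(),
        normalize,
    ])
    # expose the same pipelines under every key name used by the callers
    return {'train': train_tf, 'val': eval_tf,
            'val_train': train_tf, 'val_test': eval_tf, 'test': eval_tf}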
Example #5
def main():
    # Data
    TRAIN = args.trainroot
    VAL = args.valroot
    # TRAIN = '/content/train'
    # VAL = '/content/val'
    transform = get_transforms(input_size=args.image_size, test_size=args.image_size, backbone=None)

    print('==> Preparing dataset %s' % args.trainroot)
    # trainset = datasets.ImageFolder(root=TRAIN, transform=transform['train'])
    # valset = datasets.ImageFolder(root=VAL, transform=transform['val'])
    trainset = dataset.Dataset(root=args.trainroot, transform=transform['train'])
    valset = dataset.TestDataset(root=args.valroot, transform=transform['val'])

    train_loader = DataLoader(
        trainset, 
        batch_size=args.train_batch, 
        shuffle=True, 
        num_workers=args.workers, 
        pin_memory=True)
    
    val_loader = DataLoader(
        valset, 
        batch_size=args.test_batch, 
        shuffle=False, 
        num_workers=args.workers,
        pin_memory=True)

    # model initial
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = make_model(args)

    # TODO: merge this into a helper function
    for k, v in model.named_parameters():
        # print("{}: {}".format(k, v.requires_grad))
        if not k.startswith('layer4') and not k.startswith('fc'):
            # print(k)
            v.requires_grad = False
    # sys.exit(0)
    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    elif args.ngpu:
        model = torch.nn.DataParallel(model).cuda()

    model.to(device)

    cudnn.benchmark = True

    print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))  # total parameter count
    print('Trainable params: %.2fM' % (sum(p.numel() for p in model.parameters() if p.requires_grad) / 1000000.0))
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = get_optimizer(model, args)
    
    # reduce the learning rate when the validation loss stops improving
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=5, verbose=False)

    # Resume
    epochs = args.epochs
    start_epoch = args.start_epoch
    best_acc = 0
    title = 'log-' + args.arch
    if args.resume:
        # --resume checkpoint/checkpoint.pth.tar
        # load checkpoint
        print('Resuming from checkpoint...')
        assert os.path.isfile(args.resume), 'Error: no checkpoint file found!'
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        state_dict = checkpoint['state_dict']
        optim = checkpoint['optimizer']
        model.load_state_dict(state_dict, strict=False)
        optimizer.load_state_dict(optim)
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        # logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])
        logger.set_names(['LR', 'epoch', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.',])
    
    # Evaluation: confusion matrix, precision, recall, F1-score
    if args.evaluate and args.resume:
        print('\nEvaluate only')
        test_loss, test_acc, test_acc_5, predict_all, labels_all = test_model(val_loader, model, criterion, device, test=True)
        
        print('Test Loss: %.8f, Test top1: %.2f, top5: %.2f' % (test_loss, test_acc, test_acc_5))
        # confusion matrix and per-class report
        report = metrics.classification_report(labels_all, predict_all, target_names=class_list, digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        print('\n report ', report)
        print('\n confusion', confusion)
        with open(args.resume[:-3] + "txt", "w+") as f_obj:
            f_obj.write(report)
        # plot_Matrix(args.resume[:-3], confusion, class_list)
        return
    
    # model train and val
    for epoch in range(start_epoch, epochs + 1):
        print('[{}/{}] Training'.format(epoch, args.epochs))
        # train
        train_loss, train_acc, train_acc_5 = train_model(train_loader, model, criterion, optimizer, device)
        # val
        test_loss, test_acc, test_acc_5 = test_model(val_loader, model, criterion, device, test=None)

        scheduler.step(test_loss)

        lr_ = optimizer.param_groups[0]['lr']
        # log the key metrics for this epoch
        logger.append([lr_, int(epoch), train_loss, test_loss, train_acc, test_acc,])
        print('train_loss:%f, val_loss:%f, train_acc:%f,  val_acc:%f, train_acc_5:%f,  val_acc_5:%f' % (train_loss, test_loss, train_acc, test_acc, train_acc_5, test_acc_5))
        # save a checkpoint each epoch and track the best accuracy
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)
        if not args.ngpu:
            name = 'checkpoint_' + str(epoch) + '.pth.tar'
        else:
            name = 'ngpu_checkpoint_' + str(epoch) + '.pth.tar'
        save_checkpoint({
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'train_acc': train_acc,
            'test_acc': test_acc,
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict()

        }, is_best, checkpoint=args.checkpoint, filename=name)
        # logger.close()
        # logger.plot()
        # savefig(os.path.join(args.checkpoint, 'log.eps'))
        print('Best acc:')
        print(best_acc)
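plot_Matrix is referenced in the commented-out call after the confusion matrix is computed, but it is never defined in these examples. A minimal matplotlib sketch of what such a helper could look like, assuming it renders the matrix as a heatmap and saves it next to the checkpoint (the output filename is an assumption):

import numpy as np
import matplotlib.pyplot as plt

def plot_Matrix(out_prefix, confusion, class_list):
    # draw the confusion matrix as a heatmap and save it as a PNG
    fig, ax = plt.subplots(figsize=(10, 10))
    im = ax.imshow(confusion, cmap='Blues')
    fig.colorbar(im, ax=ax)
    ax.set_xticks(np.arange(len(class_list)))
    ax.set_yticks(np.arange(len(class_list)))
    ax.set_xticklabels(class_list, rotation=90)
    ax.set_yticklabels(class_list)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    fig.tight_layout()
    fig.savefig(out_prefix + 'confusion.png')
    plt.close(fig)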
Example #6
from args import args
from build_net import make_model

model = make_model(args)
# print(model)
# freeze everything except the last two blocks ('6', '7') and the classifier head ('fc')
for k, v in model.named_parameters():
    # print("{}: {}".format(k, v.requires_grad))
    if not k.startswith('7') and not k.startswith('6') and not k.startswith('fc'):
        # print(k)
        v.requires_grad = False
# sys.exit(0)
for k, v in model.named_parameters():
    # print(k)
    print("{}: {}".format(k, v.requires_grad))
Example #7



device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # select the device
# device = torch.device('cpu')  # force CPU

print('Pytorch garbage-classification Serving on {} ...'.format(device))
# num_classes = len(class_id2name)
model_name = 'resnext101_32x16d_wsl'
model_path = '/content/drive/Shareddrives/riverfjs/model/checkpoint/40_layer4_cbam_resize/best_checkpoint_8.pth.tar'  # args.resume # --resume checkpoint/garbage_resnext101_model_2_1111_4211.pth
model_path_4 = "/content/drive/Shareddrives/riverfjs/model/checkpoint/4_layer4_cbam_resize/best_checkpoint_9.pth.tar"
# print("model_name = ",model_name)
# print("model_path = ",model_path)

net4 = make_model(predict=True, modelname=model_name, num_classes=4)
net4.to(device)
# GCNet = GarbageClassifier(model_name, num_classes, ngpu=0, feature_extract=True)
# GCNet.model.to(device)  # move the model to the target device
# to run inference on CPU, pass map_location='cpu'
# state_dict = torch.load(model_path, map_location='cpu')['state_dict']  # state_dict=torch.load(model_path)
#modelState = torch.load(model_path, map_location='cpu')['state_dict']
"""from collections import OrderedDict

new_state_dict = OrderedDict()
for k, v in modelState.items():
    head = k[:7]
    if head == 'module.':
        name = k[7:]  # remove `module.`
    else:
        name = k
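The commented-out block is cut off before the checkpoint is actually applied. A hedged sketch of the remaining steps, stripping the optional 'module.' prefix and loading the weights into net4, following the same pattern as Example #2 (the use of model_path_4 and map_location='cpu' are assumptions):

from collections import OrderedDict

checkpoint = torch.load(model_path_4, map_location='cpu')
new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
    # drop the 'module.' prefix added by DataParallel, if present
    name = k[7:] if k.startswith('module.') else k
    new_state_dict[name] = v
net4.load_state_dict(new_state_dict)
net4.eval()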
Example #8
def main():
    global best_acc

    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)

    # data
    transformations = get_transforms(input_size=args.image_size,
                                     test_size=args.image_size)
    train_set = data_gen.Dataset(root=args.train_txt_path,
                                 transform=transformations['val_train'])
    train_loader = data.DataLoader(train_set,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=4)

    val_set = data_gen.ValDataset(root=args.val_txt_path,
                                  transform=transformations['val_test'])
    val_loader = data.DataLoader(val_set,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=4)

    # model
    model = make_model(args)
    if use_cuda:
        model.cuda()

    # define loss function and optimizer
    if use_cuda:
        criterion = nn.CrossEntropyLoss().cuda()
    else:
        criterion = nn.CrossEntropyLoss()

    optimizer = get_optimizer(model, args)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='min',
                                                           factor=0.2,
                                                           patience=5,
                                                           verbose=False)

    # load checkpoint
    start_epoch = args.start_epoch
    # if args.resume:
    #     print("===> Resuming from checkpoint")
    #     assert os.path.isfile(args.resume),'Error: no checkpoint directory found'
    #     args.checkpoint = os.path.dirname(args.resume)  # strip the filename, keep the directory
    #     checkpoint = torch.load(args.resume)
    #     best_acc = checkpoint['best_acc']
    #     start_epoch = checkpoint['epoch']
    #     model.module.load_state_dict(checkpoint['state_dict'])
    #     optimizer.load_state_dict(checkpoint['optimizer'])

    # train
    for epoch in range(start_epoch, args.epochs):
        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.epochs, optimizer.param_groups[0]['lr']))

        train_loss, train_acc = train(train_loader, model, criterion,
                                      optimizer, epoch, use_cuda)
        test_loss, val_acc = val(val_loader, model, criterion, epoch, use_cuda)

        scheduler.step(test_loss)

        print(
            f'train_loss:{train_loss}\t val_loss:{test_loss}\t train_acc:{train_acc} \t val_acc:{val_acc}'
        )

        # save_model
        is_best = val_acc >= best_acc
        best_acc = max(val_acc, best_acc)

        save_checkpoint(
            {
                'fold': 0,
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'train_acc': train_acc,
                'acc': val_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            single=True,
            checkpoint=args.checkpoint)

    print("best acc = ", best_acc)