def get_net(name):
    # Map model names to their constructors (all assumed imported above).
    factories = {
        'densenet121': densenet121,
        'densenet161': densenet161,
        'densenet169': densenet169,
        'googlenet': googlenet,
        'inception_v3': inception_v3,
        'mobilenet_v2': mobilenet_v2,
        'resnet18': resnet18,
        'resnet34': resnet34,
        'resnet50': resnet50,
        'resnet_orig': resnet_orig,
        'vgg11_bn': vgg11_bn,
        'vgg13_bn': vgg13_bn,
        'vgg16_bn': vgg16_bn,
        'vgg19_bn': vgg19_bn,
    }
    if name not in factories:
        # Exit with a non-zero status so callers can detect the failure
        # (the original exited with status 0, which signals success).
        sys.exit(f'{name} is not a valid model name')
    return factories[name]().to(device)
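
# A minimal usage sketch (assumes `device` and the model constructors are in scope):
#   net = get_net('resnet18')
#   logits = net(torch.randn(1, 3, 224, 224, device=device))
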
def main(rank, imgs, jpg_path, out_dir):
    img_scale = 224
    img_crop = 224

    cwd = os.getcwd()
    model_dir = 'models'
    densenet = models.densenet161(pretrained=True,
                                  model_dir=os.path.join(cwd, model_dir))
    densenet.cuda()
    densenet.eval()

    seed(123)  # make reproducible
    dir_fc = os.path.join(out_dir, opt.fc_dir)
    dir_att = os.path.join(out_dir, opt.att_dir)

    if not os.path.isdir(dir_fc):
        os.mkdir(dir_fc)
    if not os.path.isdir(dir_att):
        os.mkdir(dir_att)

    for i, img_name in enumerate(imgs):
        t0 = time.time()
        print(img_name)
        image_id = get_image_id(img_name)

        # Load the image; convert to RGB up front so Normalize always sees
        # 3 channels (a grayscale input would crash it otherwise).
        img = Image.open(os.path.join(jpg_path, img_name)).convert('RGB')  # typically (640, 480)
        img = transforms.Compose([
            transforms.Resize(img_scale),  # transforms.Scale was renamed Resize
            transforms.CenterCrop(img_crop),
            transforms.ToTensor(),
            normalize,
        ])(img)

        img = img.unsqueeze(0)
        # Variable(..., volatile=True) is deprecated; run inference under no_grad.
        with torch.no_grad():
            result = densenet(img.cuda())
        fc = result[1].squeeze()
        att = result[2].squeeze()
        att = torch.transpose(att, 0, 2)

        # save features: fc as .npy, attention maps as compressed .npz
        np.save(os.path.join(dir_fc, str(image_id)),
                fc.data.cpu().float().numpy())
        np.savez_compressed(os.path.join(dir_att, str(image_id)),
                            feat=att.data.cpu().float().numpy())

        print("{} {}  {}  time cost: {:.3f}".format(rank, i, img_name,
                                                    time.time() - t0))
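
# A usage sketch with hypothetical paths (assumes `opt`, `normalize`, `get_image_id`,
# and the tuple-returning densenet above are in scope):
#   imgs = sorted(os.listdir(jpg_path))
#   main(0, imgs, jpg_path, out_dir)
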
class densenet_cnn(nn.Module):
    def __init__(self, depth=50, numclass=40):
        super(densenet_cnn, self).__init__()

        self.backbone = densenet161(pretrained=True)
        #self.avgpool = nn.AdaptiveAvgPool2d((1,1))

        #self.classifier = nn.Sequential(
        #    nn.Linear(2048, 512),
        #    nn.ReLU(True),
        #    nn.Dropout(),
        #    nn.Linear(512,numclass)
        # )
        self.classifier = nn.Linear(2048, numclass)
        #self.atn_s3 = MultiHeadAttention(4,2048,512,512)
        self.atn_s4 = MultiHeadAttention(1, 2048, 512, 512)
Example #4
def define_model(is_resnet, is_densenet, is_senet):
    if is_resnet:
        original_model = resnet.resnet50(pretrained=True)
        Encoder = modules.E_resnet(original_model)
        model = net.model(Encoder, num_features=2048, block_channel=[256, 512, 1024, 2048])
    elif is_densenet:
        original_model = densenet.densenet161(pretrained=True)
        Encoder = modules.E_densenet(original_model)
        model = net.model(Encoder, num_features=2208, block_channel=[192, 384, 1056, 2208])
    elif is_senet:
        original_model = senet.senet154(pretrained='imagenet')
        Encoder = modules.E_senet(original_model)
        model = net.model(Encoder, num_features=2048, block_channel=[256, 512, 1024, 2048])
    else:
        # Without this guard, `model` would be undefined when no flag is set.
        raise ValueError('one of is_resnet, is_densenet, is_senet must be True')

    return model
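
# Usage sketch:
#   model = define_model(is_resnet=False, is_densenet=True, is_senet=False)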
Example #5
def define_model(target):
    model = None
    if target == 'resnet':  # `is` compares identity, not string equality
        original_model = resnet.resnet50(pretrained=True)
        encoder = modules.E_resnet(original_model)
        model = net.model(encoder,
                          num_features=2048,
                          block_channel=[256, 512, 1024, 2048])
    elif target == 'densenet':
        original_model = densenet.densenet161(pretrained=True)
        encoder = modules.E_densenet(original_model)
        model = net.model(encoder,
                          num_features=2208,
                          block_channel=[192, 384, 1056, 2208])
    elif target == 'senet':
        original_model = senet.senet154(pretrained='imagenet')
        encoder = modules.E_senet(original_model)
        model = net.model(encoder,
                          num_features=2048,
                          block_channel=[256, 512, 1024, 2048])
    return model
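
# Usage sketch (returns None for unrecognized targets):
#   model = define_model('densenet')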

# ImageNet channel statistics used for input normalization.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
stats = [mean, std]

cudnn.benchmark = True

if model_name == 'ResNet50':
    model = resnet.resnet50(pretrained=True)
    model = torch.nn.DataParallel(model)
    dim_descriptor = 2048

elif model_name == 'DenseNet161':
    model = densenet.densenet161(pretrained=True)
    model.features = torch.nn.DataParallel(model.features)
    dim_descriptor = 2208

else:
    raise ValueError(f'unknown model_name: {model_name}')

model.cuda()

print(model)

# CAM extraction (number of CAMs)

if aggregation_type == 'Offline':
    num_classes = 64
elif aggregation_type == 'Online':
    num_classes = 1000

# Number of images loaded into the net at once (more images: more memory, but faster)
import os

import numpy as np
import torch
from torch.autograd import Variable

import torchvision.transforms as transforms
from PIL import Image
import skimage.io
# from torchvision import transforms as trn

import densenet as models

cwd = os.getcwd()
model_dir = 'models'
densenet = models.densenet161(pretrained=True,
                              model_dir=os.path.join(cwd, model_dir))
densenet.eval()

img_scale = 224
img_crop = 224

# preprocess = transforms.Compose([
#     # trn.Scale(img_scale),
#     # trn.CenterCrop(img_crop),
#     # trn.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
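
# A working preprocessing pipeline assembled from the pieces above (a sketch;
# it mirrors the commented-out version, with Resize replacing the removed Scale):
preprocess = transforms.Compose([
    transforms.Resize(img_scale),
    transforms.CenterCrop(img_crop),
    transforms.ToTensor(),
    normalize,
])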
Example #8
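# Assumes helpers defined elsewhere in the original script: data_loader_, train,
# predict, and MEAN_VALUE (the per-channel means subtracted from inputs).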
def main():
    torch.manual_seed(23)
    # Band_num = 2
    # Tag_id = 4
    data_l = data_loader_(batch_size=32,
                          proportion=0.85,
                          shuffle=True,
                          data_add=2,
                          onehot=False,
                          data_size=224,
                          nb_classes=100)
    print(data_l.train_length)
    print(data_l.test_length)
    # print 'loading....'
    # trX = np.load('bddog/trX.npy')
    # trY = np.load('bddog/trY.npy')
    # print 'load train data'
    # trX = torch.from_numpy(trX).float()
    # trY = torch.from_numpy(trY).long()
    # teX = np.load('bddog/teX.npy').astype(np.float)
    # teY = np.load('bddog/teY.npy')
    # print 'load test data'
    # teX[:, 0, ...] -= MEAN_VALUE[0]
    # teX[:, 1, ...] -= MEAN_VALUE[1]
    # teX[:, 2, ...] -= MEAN_VALUE[2]
    # teX = torch.from_numpy(teX).float()
    # teY = torch.from_numpy(teY).long()
    # print 'numpy data to tensor'
    # n_examples = len(trX)
    # n_classes = 100
    # model = torch.load('models/resnet_model_pretrained_adam_2_2_SGD_1.pkl')
    model = densenet161(pretrained=True)
    print('===============================')
    print(model)
    # for param in model.parameters():
    #     param.requires_grad = False
    # model.classifier[-1] = nn.Linear(4096, 100)
    # n = model.classifier[-1].weight.size(1)
    # model.classifier[-1].weight.data.normal_(0, 0.01)
    # model.classifier[-1].bias.data.zero_()

    # VGG16 classifier layer
    # model.classifier = nn.Sequential(
    #     nn.Linear(512 * 7 * 7, 4096),
    #     nn.ReLU(inplace=True),
    #     nn.Dropout(),
    #     nn.Linear(4096, 4096),
    #     nn.ReLU(inplace=True),
    #     nn.Dropout(),
    #     nn.Linear(4096, 100),
    # )
    # count = 0
    # print '==============================='
    # for module in model.modules():
    #     print '**** %d' % count
    #     print(module)
    #     count+=1
    # print '==============================='
    # count= 0
    # model.classifier[6] = nn.Linear(4096, 100)
    # for m in model.classifier:
    #     if count == 6:
    #         m = nn.Linear(4096, 100)
    #         if isinstance(m, nn.Conv2d):
    #             n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    #             m.weight.data.normal_(0, math.sqrt(2. / n))
    #             if m.bias is not None:
    #                 m.bias.data.zero_()
    #         elif isinstance(m, nn.BatchNorm2d):
    #             m.weight.data.fill_(1)
    #             m.bias.data.zero_()
    #         elif isinstance(m, nn.Linear):
    #             n = m.weight.size(1)
    #             m.weight.data.normal_(0, 0.01)
    #             m.bias.data.zero_()
    #     count+=1
    # try:
    #     print model.classifier[0]
    # except Exception as e:
    #     print e

    # print '==============================='
    # for module in model.modules()[-7:]:
    #     print '****'
    #     print(module)
    # resnet50 FC层
    # model.group1 = nn.Sequential(
    #     OrderedDict([
    #         ('fc', nn.Linear(2048, 100))
    #     ])
    # )
    model.classifier = nn.Linear(2208, 100)
    # ignored_params = list(map(id, model.group2.parameters()))
    # base_params = filter(lambda p: id(p) not in ignored_params,
    #                      model.parameters())
    # print '==============================='
    # print model
    model = model.cuda()
    loss = torch.nn.CrossEntropyLoss()  # mean reduction (size_average is deprecated)
    loss = loss.cuda()
    # Optimize only part of the model:
    # optimizer = optim.SGD(model.group2.parameters(), lr=(1e-03), momentum=0.9,weight_decay=0.001)
    # optimizer = optim.Adam([{'params':model.layer4[2].parameters()},
    #                         {'params':model.group2.parameters()}
    #                         ],lr=(1e-04),eps=1e-08, betas=(0.9, 0.999), weight_decay=0.0005)
    # optimizer_a = optim.Adam([{'params':model.group2.parameters()}
    #                         ],lr=(1e-04))

    # optimizer = optim.Adam(model.group1.parameters(),lr=(1e-04))

    # optimizer.lr = (1e-04)
    # print optimizer.lr
    # print optimizer.momentum
    # for param_group in optimizer.param_groups:
    #     print param_group['lr']
    # Optimize the whole model:
    optimizer = optim.SGD(model.parameters(),
                          lr=(0.001),
                          momentum=0.9,
                          weight_decay=0.0005)
    batch_size = data_l.batch_size
    data_aug_num = data_l.data_add
    mini_batch_size = batch_size // data_aug_num  # integer division: these are counts
    epochs = 1000
    print('1')
    for e in range(epochs):
        cost = 0.0
        train_acc = 0.0
        if e == 4:
            for param_group in optimizer.param_groups:
                param_group['lr'] = param_group['lr'] * 0.3
        if e == 8:
            for param_group in optimizer.param_groups:
                param_group['lr'] = param_group['lr'] * 0.3

        num_batches_train = data_l.train_length // mini_batch_size
        print(num_batches_train)
        train_acc = 0.0
        cost = 0.0
        k = 1
        for k in range(num_batches_train + 1):
            batch_train_data_X, batch_train_data_Y = data_l.get_train_data()
            batch_train_data_X = batch_train_data_X.transpose(0, 3, 1, 2)
            batch_train_data_X[:, 0, ...] -= MEAN_VALUE[0]
            batch_train_data_X[:, 1, ...] -= MEAN_VALUE[1]
            batch_train_data_X[:, 2, ...] -= MEAN_VALUE[2]
            # print batch_train_data_X.shape
            # print batch_train_data_Y.shape
            # batch_train_data_X = preprocess_input(batch_train_data_X)
            torch_batch_train_data_X = torch.from_numpy(
                batch_train_data_X).float()
            torch_batch_train_data_Y = torch.from_numpy(
                batch_train_data_Y).long()
            cost_temp, acc_temp = train(model, loss, optimizer,
                                        torch_batch_train_data_X,
                                        torch_batch_train_data_Y)
            train_acc += acc_temp
            cost += cost_temp
            if (k + 1) % 10 == 0:
                print('now step train loss is : %f' % cost_temp)
                print('now step train acc is : %f' % acc_temp)
            if (k + 1) % 20 == 0:
                print('all average train loss is : %f' % (cost / (k + 1)))
                print('all average train acc is : %f' % (train_acc / (k + 1)))
            # if (k + 1) % 100 == 0:
            #     model.training = False
            #     acc = 0.0
            #     num_batches_test = data_l.test_length / batch_size
            #     for j in range(num_batches_test+1):
            #         teX, teY = data_l.get_test_data()
            #         teX = teX.transpose(0, 3, 1, 2)
            #         # teX[:, 0, ...] -= MEAN_VALUE[0]
            #         # teX[:, 1, ...] -= MEAN_VALUE[1]
            #         # teX[:, 2, ...] -= MEAN_VALUE[2]
            #         teX = preprocess_input(teX)
            #         teX = torch.from_numpy(teX).float()
            #         # teY = torch.from_numpy(teY).long()
            #         predY = predict(model, teX)
            #         # print predY.dtype
            #         # print teY[start:end]
            #         acc += 1. * np.mean(predY == teY)
            #         # print ('Epoch %d ,Step %d, acc = %.2f%%'%(e,k,100.*np.mean(predY==teY[start:end])))
            #     model.training = True
            #     print 'Epoch %d ,Step %d, all test acc is : %f' % (e, k, acc / num_batches_test)
            #     torch.save(model, 'models/inception_model_pretrained_%s_%s_%s_1.pkl' % ('SGD', str(e), str(k)))
        # model.training = False
        acc = 0.0
        num_batches_test = data_l.test_length // batch_size
        for j in range(num_batches_test + 1):
            teX, teY = data_l.get_test_data()
            teX = teX.transpose(0, 3, 1, 2)
            teX[:, 0, ...] -= MEAN_VALUE[0]
            teX[:, 1, ...] -= MEAN_VALUE[1]
            teX[:, 2, ...] -= MEAN_VALUE[2]
            # teX = preprocess_input(teX)
            teX = torch.from_numpy(teX).float()
            # teY = torch.from_numpy(teY).long()
            predY = predict(model, teX)
            # print predY.dtype
            # print teY[start:end]
            acc += 1. * np.mean(predY == teY)
            # print ('Epoch %d ,Step %d, acc = %.2f%%'%(e,k,100.*np.mean(predY==teY[start:end])))
        # model.training = True
        print('Epoch %d ,Step %d, all test acc is : %f' %
              (e, k, acc / num_batches_test))
        torch.save(
            model, 'models/densenet161_model_pretrained_%s_%s_%s_4.pkl' %
            ('SGD', str(e), str(k)))
    print('train over')
Example #9
def get_model(args):
    network = args.network

    def _grayscale_stem(model):
        # Replace the stem conv so the network accepts 1-channel (grayscale)
        # input; all other conv hyper-parameters are kept. Conv2d's `bias`
        # expects a bool, so test the existing bias for None.
        model.conv1 = torch.nn.Conv2d(in_channels=1,
                                      out_channels=model.conv1.out_channels,
                                      kernel_size=model.conv1.kernel_size,
                                      stride=model.conv1.stride,
                                      padding=model.conv1.padding,
                                      bias=model.conv1.bias is not None)
        return model

    if network == 'vgg11':
        model = vgg.vgg11(num_classes=args.class_num)
    elif network == 'vgg13':
        model = vgg.vgg13(num_classes=args.class_num)
    elif network == 'vgg16':
        model = vgg.vgg16(num_classes=args.class_num)
    elif network == 'vgg19':
        model = vgg.vgg19(num_classes=args.class_num)
    elif network == 'vgg11_bn':
        model = vgg.vgg11_bn(num_classes=args.class_num)
    elif network == 'vgg13_bn':
        model = vgg.vgg13_bn(num_classes=args.class_num)
    elif network == 'vgg16_bn':
        model = vgg.vgg16_bn(num_classes=args.class_num)
    elif network == 'vgg19_bn':
        model = vgg.vgg19_bn(num_classes=args.class_num)
    elif network == 'resnet18':
        model = _grayscale_stem(models.resnet18(num_classes=args.class_num))
    elif network == 'resnet34':
        model = _grayscale_stem(models.resnet34(num_classes=args.class_num))
    elif network == 'resnet50':
        model = _grayscale_stem(models.resnet50(num_classes=args.class_num))
    elif network == 'resnet101':
        model = _grayscale_stem(models.resnet101(num_classes=args.class_num))
    elif network == 'resnet152':
        model = _grayscale_stem(models.resnet152(num_classes=args.class_num))
    elif network == 'densenet121':
        model = densenet.densenet121(num_classes=args.class_num)
    elif network == 'densenet169':
        model = densenet.densenet169(num_classes=args.class_num)
    elif network == 'densenet161':
        model = densenet.densenet161(num_classes=args.class_num)
    elif network == 'densenet201':
        model = densenet.densenet201(num_classes=args.class_num)
    else:
        raise ValueError(f'unknown network: {network}')

    return model
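# Usage sketch (assumes an argparse-style namespace):
#   from argparse import Namespace
#   model = get_model(Namespace(network='resnet18', class_num=10))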
def get_model_dics(device, model_list=None):
    if model_list is None:
        # Note: 'incept_v4' matches the branch below that loads it.
        model_list = ['densenet121', 'densenet161', 'resnet50', 'resnet152',
                      'incept_v1', 'incept_v3', 'incept_v4', 'incept_resnet_v2',
                      'incept_v4_adv2', 'incept_resnet_v2_adv2',
                      'black_densenet161', 'black_resnet50', 'black_incept_v3',
                      'old_vgg', 'old_res', 'old_incept']
    models = {}
    for model in model_list:
        if model=='densenet121':
            models['densenet121'] = densenet121(num_classes=110)
            load_model(models['densenet121'],"../pre_weights/ep_38_densenet121_val_acc_0.6527.pth",device)
        if model=='densenet161':
            models['densenet161'] = densenet161(num_classes=110)
            load_model(models['densenet161'],"../pre_weights/ep_30_densenet161_val_acc_0.6990.pth",device)
        if model=='resnet50':
            models['resnet50'] = resnet50(num_classes=110)
            load_model(models['resnet50'],"../pre_weights/ep_41_resnet50_val_acc_0.6900.pth",device)
        if model=='incept_v3':
            models['incept_v3'] = inception_v3(num_classes=110)
            load_model(models['incept_v3'],"../pre_weights/ep_36_inception_v3_val_acc_0.6668.pth",device)
        if model=='incept_v1':
            models['incept_v1'] = googlenet(num_classes=110)
            load_model(models['incept_v1'],"../pre_weights/ep_33_googlenet_val_acc_0.7091.pth",device)
        # vgg16 = vgg16_bn(num_classes=110)
        # load_model(vgg16, "./pre_weights/ep_30_vgg16_bn_val_acc_0.7282.pth", device)
        if model=='incept_resnet_v2':
            models['incept_resnet_v2'] = InceptionResNetV2(num_classes=110)  
            load_model(models['incept_resnet_v2'], "../pre_weights/ep_17_InceptionResNetV2_ori_0.8320.pth",device)

        if model=='incept_v4':
            models['incept_v4'] = InceptionV4(num_classes=110)
            load_model(models['incept_v4'],"../pre_weights/ep_17_InceptionV4_ori_0.8171.pth",device)
        if model=='incept_resnet_v2_adv':
            models['incept_resnet_v2_adv'] = InceptionResNetV2(num_classes=110)  
            load_model(models['incept_resnet_v2_adv'], "../pre_weights/ep_22_InceptionResNetV2_val_acc_0.8214.pth",device)

        if model=='incept_v4_adv':
            models['incept_v4_adv'] = InceptionV4(num_classes=110)
            load_model(models['incept_v4_adv'],"../pre_weights/ep_24_InceptionV4_val_acc_0.6765.pth",device)
        if model=='incept_resnet_v2_adv2':
            models['incept_resnet_v2_adv2'] = InceptionResNetV2(num_classes=110)  
            #load_model(models['incept_resnet_v2_adv2'], "../test_weights/ep_29_InceptionResNetV2_adv2_0.8115.pth",device)
            load_model(models['incept_resnet_v2_adv2'], "../test_weights/ep_13_InceptionResNetV2_val_acc_0.8889.pth",device)

        if model=='incept_v4_adv2':
            models['incept_v4_adv2'] = InceptionV4(num_classes=110)
#            load_model(models['incept_v4_adv2'],"../test_weights/ep_32_InceptionV4_adv2_0.7579.pth",device)
            load_model(models['incept_v4_adv2'],"../test_weights/ep_50_InceptionV4_val_acc_0.8295.pth",device)

        if model=='resnet152':
            models['resnet152'] = resnet152(num_classes=110)
            load_model(models['resnet152'],"../pre_weights/ep_14_resnet152_ori_0.6956.pth",device)
        if model=='resnet152_adv':
            models['resnet152_adv'] = resnet152(num_classes=110)
            load_model(models['resnet152_adv'],"../pre_weights/ep_29_resnet152_adv_0.6939.pth",device)
        if model=='resnet152_adv2':
            models['resnet152_adv2'] = resnet152(num_classes=110)
            load_model(models['resnet152_adv2'],"../pre_weights/ep_31_resnet152_adv2_0.6931.pth",device)
        if model=='black_resnet50':
            models['black_resnet50'] = resnet50(num_classes=110)
            load_model(models['black_resnet50'],"../test_weights/ep_0_resnet50_val_acc_0.7063.pth",device)
        if model=='black_densenet161':
            models['black_densenet161'] = densenet161(num_classes=110)
            load_model(models['black_densenet161'],"../test_weights/ep_4_densenet161_val_acc_0.6892.pth",device)
        if model=='black_incept_v3':
            models['black_incept_v3']=inception_v3(num_classes=110)
            load_model(models['black_incept_v3'],"../test_weights/ep_28_inception_v3_val_acc_0.6680.pth",device)
        if model=='old_res':
            MainModel = imp.load_source('MainModel', "./models_old/tf_to_pytorch_resnet_v1_50.py")
            models['old_res'] = torch.load('./models_old/tf_to_pytorch_resnet_v1_50.pth').to(device)
        if model=='old_vgg':
            MainModel = imp.load_source('MainModel', "./models_old/tf_to_pytorch_vgg16.py")
            models[model] = torch.load('./models_old/tf_to_pytorch_vgg16.pth').to(device)
        if model=='old_incept':
            MainModel = imp.load_source('MainModel', "./models_old/tf_to_pytorch_inception_v1.py")
            models[model]  = torch.load('./models_old/tf_to_pytorch_inception_v1.pth').to(device)

    return models
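# Usage sketch (weight paths above are relative to the original repo layout):
#   models = get_model_dics(torch.device('cuda'), model_list=['densenet121'])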
Example #11
def main(logger, args):
    if not torch.cuda.is_available():
        raise Exception("need gpu to train network!")

    if args.seed is not None:
        random.seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True

    gpus = torch.cuda.device_count()
    logger.info(f'use {gpus} gpus')
    logger.info(f"args: {args}")

    cudnn.benchmark = True
    cudnn.enabled = True
    start_time = time.time()

    # dataset and dataloader
    logger.info('start loading data')
    train_loader = DataLoader(Config.train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers)
    val_loader = DataLoader(Config.val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.num_workers)
    logger.info('finish loading data')

    logger.info(f"creating model '{args.network}'")
    model = densenet161(pretrained=args.pretrained,
                        num_classes=args.num_classes)

    flops_input = torch.randn(1, 3, args.input_image_size,
                              args.input_image_size)
    flops, params = profile(model, inputs=(flops_input, ))
    flops, params = clever_format([flops, params], "%.3f")
    logger.info(f"model: '{args.network}', flops: {flops}, params: {params}")

    for name, param in model.named_parameters():
        logger.info(f"{name},{param.requires_grad}")

    model = model.cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones, gamma=0.1)

    if args.apex:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    model = nn.DataParallel(model)

    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            raise Exception(
                f"{args.resume} is not a file, please check it again")
        logger.info('start only evaluating')
        logger.info(f"start resuming model from {args.evaluate}")
        checkpoint = torch.load(args.evaluate,
                                map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model_state_dict'])
        acc1, acc5, throughput = validate(val_loader, model, args)
        logger.info(
            f"epoch {checkpoint['epoch']:0>3d}, top1 acc: {acc1:.2f}%, top5 acc: {acc5:.2f}%, throughput: {throughput:.2f}sample/s"
        )

        return

    start_epoch = 1
    # resume training
    if os.path.exists(args.resume):
        logger.info(f"start resuming model from {args.resume}")
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        start_epoch += checkpoint['epoch']
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        logger.info(
            f"finish resuming model from {args.resume}, epoch {checkpoint['epoch']}, "
            f"loss: {checkpoint['loss']:3f}, lr: {checkpoint['lr']:.6f}, "
            f"top1_acc: {checkpoint['acc1']}%")

    if not os.path.exists(args.checkpoints):
        os.makedirs(args.checkpoints)

    logger.info('start training')
    for epoch in range(start_epoch, args.epochs + 1):
        acc1, acc5, losses = train(train_loader, model, criterion, optimizer,
                                   scheduler, epoch, logger, args)
        logger.info(
            f"train: epoch {epoch:0>3d}, top1 acc: {acc1:.2f}%, top5 acc: {acc5:.2f}%, losses: {losses:.2f}"
        )

        acc1, acc5, throughput = validate(val_loader, model, args)
        logger.info(
            f"val: epoch {epoch:0>3d}, top1 acc: {acc1:.2f}%, top5 acc: {acc5:.2f}%, throughput: {throughput:.2f}sample/s"
        )

        # remember best prec@1 and save checkpoint
        torch.save(
            {
                'epoch': epoch,
                'acc1': acc1,
                'loss': losses,
                'lr': scheduler.get_last_lr()[0],  # get_lr() is deprecated for reading the LR
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
            }, os.path.join(args.checkpoints, 'latest.pth'))
        if epoch == args.epochs:
            torch.save(
                model.module.state_dict(),
                os.path.join(
                    args.checkpoints,
                    "{}-epoch{}-acc{}.pth".format(args.network, epoch, acc1)))

    training_time = (time.time() - start_time) / 3600
    logger.info(
        f"finish training, total training time: {training_time:.2f} hours")