Пример #1
0
def loadModel(data_root, file_list, backbone_net, gpus='0', resume=None,
              feature_dim=None):
    """Build a backbone, load its checkpoint and prepare the AgeDB-30 loader.

    Args:
        data_root: AgeDB-30 image root, forwarded to ``AgeDB30``.
        file_list: AgeDB-30 list file, forwarded to ``AgeDB30``.
        backbone_net: ``'MobileFace'`` or one of the ``CBAMResNet_IR``
            variants ``'Res50_IR'``, ``'SERes50_IR'``, ``'CBAMRes50_IR'``,
            ``'Res100_IR'``, ``'SERes100_IR'``, ``'CBAMRes100_IR'``.
        gpus: comma-separated GPU ids exported as ``CUDA_VISIBLE_DEVICES``;
            more than one id wraps the network in ``DataParallel``.
        resume: path of a checkpoint whose ``'net_state_dict'`` entry holds
            the backbone weights. Required.
        feature_dim: embedding size for the ``CBAMResNet_IR`` variants. When
            ``None`` the module-level ``args.feature_dim`` is used, exactly as
            the original implementation did.

    Returns:
        ``(net, device, agedb_dataset, agedb_loader)`` with ``net`` in eval
        mode and already moved to ``device``.

    Raises:
        ValueError: if ``backbone_net`` is unknown or ``resume`` is missing.
    """
    # (depth, block mode) for every CBAMResNet_IR variant.
    cbam_variants = {
        'Res50_IR': (50, 'ir'),
        'SERes50_IR': (50, 'se_ir'),
        'CBAMRes50_IR': (50, 'cbam_ir'),
        'Res100_IR': (100, 'ir'),
        'SERes100_IR': (100, 'se_ir'),
        'CBAMRes100_IR': (100, 'cbam_ir'),
    }

    if backbone_net == 'MobileFace':
        net = mobilefacenet.MobileFaceNet()
    elif backbone_net in cbam_variants:
        depth, mode = cbam_variants[backbone_net]
        if feature_dim is None:
            # Legacy fallback: the script-level argparse namespace.
            feature_dim = args.feature_dim
        net = cbam.CBAMResNet_IR(depth, feature_dim=feature_dim, mode=mode)
    else:
        # Previously this printed `args.backbone` (the wrong variable) and
        # then failed further down on the unbound `net`.
        raise ValueError('{} is not available!'.format(backbone_net))

    if resume is None:
        raise ValueError('resume must be the path of a checkpoint to load')

    # gpu init
    multi_gpus = len(gpus.split(',')) > 1
    os.environ['CUDA_VISIBLE_DEVICES'] = gpus
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load on CPU first so GPU-saved checkpoints also work on CPU-only
    # hosts; the network is moved to `device` right after.
    checkpoint = torch.load(resume, map_location='cpu')
    net.load_state_dict(checkpoint['net_state_dict'])

    if multi_gpus:
        net = DataParallel(net).to(device)
    else:
        net = net.to(device)

    transform = transforms.Compose([
        transforms.ToTensor(),  # range [0, 255] -> [0.0, 1.0]
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                             std=(0.5, 0.5, 0.5)),  # -> [-1.0, 1.0]
    ])
    agedb_dataset = AgeDB30(data_root, file_list, transform=transform)
    agedb_loader = torch.utils.data.DataLoader(agedb_dataset,
                                               batch_size=128,
                                               shuffle=False,
                                               num_workers=2,
                                               drop_last=False)

    return net.eval(), device, agedb_dataset, agedb_loader
Пример #2
0
def train(args):
    """Train a backbone plus margin head on CASIA-WebFace with periodic tests.

    A fresh timestamped directory under ``args.save_dir`` receives the log
    and, every ``args.save_freq`` iterations, separate backbone and margin
    checkpoints. Every ``args.test_freq`` iterations the backbone is scored
    with ``evaluation_10_fold`` on LFW, AgeDB-30 and CFP-FP; the best mean
    accuracy per benchmark and the iteration it was reached are logged.

    Args:
        args: parsed command-line namespace. Fields read here: ``gpus``,
            ``save_dir``, ``model_pre``, ``backbone``, ``feature_dim``,
            ``margin_type``, ``scale_size``, ``batch_size``, ``total_epoch``,
            ``save_freq``, ``test_freq``, ``resume``, ``net_path``,
            ``margin_path`` and the ``*_root`` / ``*_file_list`` paths.

    Raises:
        NameError: if the timestamped save directory already exists.
        ValueError: if ``args.backbone`` or ``args.margin_type`` is unknown.
        NotImplementedError: for ``'CosFace'`` / ``'SphereFace'``, which are
            accepted names without an implementation here.
    """
    # gpu init
    multi_gpus = len(args.gpus.split(',')) > 1
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # log init
    save_dir = os.path.join(
        args.save_dir, args.model_pre + args.backbone.upper() + '_' +
        datetime.now().strftime('%Y%m%d_%H%M%S'))
    if os.path.exists(save_dir):
        raise NameError('model dir exists!')
    os.makedirs(save_dir)
    logging = init_log(save_dir)
    _print = logging.info

    # dataset loader
    transform = transforms.Compose([
        transforms.ToTensor(),  # range [0, 255] -> [0.0, 1.0]
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                             std=(0.5, 0.5, 0.5)),  # -> [-1.0, 1.0]
    ])
    # training dataset
    trainset = CASIAWebFace(args.train_root,
                            args.train_file_list,
                            transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=8,
                                              drop_last=False)

    def _test_loader(dataset):
        # Evaluation loaders keep a fixed order and every sample.
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=128,
                                           shuffle=False,
                                           num_workers=4,
                                           drop_last=False)

    # test datasets
    lfwdataset = LFW(args.lfw_test_root,
                     args.lfw_file_list,
                     transform=transform)
    lfwloader = _test_loader(lfwdataset)
    agedbdataset = AgeDB30(args.agedb_test_root,
                           args.agedb_file_list,
                           transform=transform)
    agedbloader = _test_loader(agedbdataset)
    cfpfpdataset = CFP_FP(args.cfpfp_test_root,
                          args.cfpfp_file_list,
                          transform=transform)
    cfpfploader = _test_loader(cfpfpdataset)

    # define backbone and margin layer
    if args.backbone == 'MobileFace':
        net = MobileFaceNet()
    elif args.backbone == 'Res50':
        net = ResNet50()
    elif args.backbone == 'Res101':
        net = ResNet101()
    elif args.backbone == 'Res50_IR':
        net = SEResNet_IR(50, feature_dim=args.feature_dim, mode='ir')
    elif args.backbone == 'SERes50_IR':
        net = SEResNet_IR(50, feature_dim=args.feature_dim, mode='se_ir')
    elif args.backbone == 'SphereNet':
        net = SphereNet(num_layers=64, feature_dim=args.feature_dim)
    else:
        # Fail fast instead of hitting an unbound `net` further down.
        raise ValueError('{} is not available!'.format(args.backbone))

    if args.margin_type == 'ArcFace':
        margin = ArcMarginProduct(args.feature_dim,
                                  trainset.class_nums,
                                  s=args.scale_size)
    elif args.margin_type == 'InnerProduct':
        margin = InnerProduct(args.feature_dim, trainset.class_nums)
    elif args.margin_type in ('CosFace', 'SphereFace'):
        # Recognised names with no implementation here; they used to fall
        # through with `margin` unbound.
        raise NotImplementedError('{} is not implemented yet!'.format(
            args.margin_type))
    else:
        raise ValueError('{} is not available!'.format(args.margin_type))

    if args.resume:
        print('resume the model parameters from: ', args.net_path,
              args.margin_path)
        # map_location='cpu' lets GPU-saved checkpoints load anywhere; the
        # modules are moved to `device` below.
        net.load_state_dict(
            torch.load(args.net_path, map_location='cpu')['net_state_dict'])
        margin.load_state_dict(
            torch.load(args.margin_path, map_location='cpu')['net_state_dict'])

    # define optimizers for different layer
    criterion_classi = torch.nn.CrossEntropyLoss().to(device)
    optimizer_classi = optim.SGD([
        {'params': net.parameters(), 'weight_decay': 5e-4},
        {'params': margin.parameters(), 'weight_decay': 5e-4},
    ], lr=0.1, momentum=0.9, nesterov=True)
    scheduler_classi = lr_scheduler.MultiStepLR(optimizer_classi,
                                                milestones=[20, 35, 45],
                                                gamma=0.1)

    if multi_gpus:
        net = DataParallel(net).to(device)
        margin = DataParallel(margin).to(device)
    else:
        net = net.to(device)
        margin = margin.to(device)

    # (log name, plot key, feature dump path, dataset, loader) per benchmark.
    benchmarks = [
        ('LFW', 'lfw', './result/cur_lfw_result.mat', lfwdataset, lfwloader),
        ('AgeDB-30', 'agedb-30', './result/cur_agedb30_result.mat',
         agedbdataset, agedbloader),
        ('CFP-FP', 'cfp-fp', './result/cur_cfpfp_result.mat', cfpfpdataset,
         cfpfploader),
    ]
    best_acc = {'lfw': 0.0, 'agedb-30': 0.0, 'cfp-fp': 0.0}  # percent
    best_iters = {'lfw': 0, 'agedb-30': 0, 'cfp-fp': 0}

    def _evaluate(mat_path, dataset, loader):
        # Dump the current backbone's features to `mat_path` and return the
        # mean 10-fold verification accuracy as a fraction.
        getFeatureFromTorch(mat_path, net, device, dataset, loader)
        return np.mean(evaluation_10_fold(mat_path))

    def _state_dict(model):
        # Unwrap DataParallel so checkpoints load into a bare module.
        return (model.module if multi_gpus else model).state_dict()

    def _best_summary(prefix):
        return ('{} Best Accuracy: LFW: {:.4f} in iters: {}, AgeDB-30: {:.4f} '
                'in iters: {} and CFP-FP: {:.4f} in iters: {}').format(
                    prefix, best_acc['lfw'], best_iters['lfw'],
                    best_acc['agedb-30'], best_iters['agedb-30'],
                    best_acc['cfp-fp'], best_iters['cfp-fp'])

    os.makedirs('./result', exist_ok=True)  # evaluation dumps .mat files here
    total_iters = 0
    vis = Visualizer(env='softmax_train')
    for epoch in range(1, args.total_epoch + 1):
        # train model
        _print('Train Epoch: {}/{} ...'.format(epoch, args.total_epoch))
        net.train()

        since = time.time()
        for data in trainloader:
            img, label = data[0].to(device), data[1].to(device)
            output = margin(net(img))
            loss_classi = criterion_classi(output, label)

            optimizer_classi.zero_grad()
            loss_classi.backward()
            optimizer_classi.step()

            total_iters += 1
            # print train information
            if total_iters % 100 == 0:
                # Current-batch accuracy, compared on-device: np.array()
                # cannot convert CUDA tensors.
                _, predict = torch.max(output.data, 1)
                train_acc = (predict == label).sum().item() / label.size(0)
                time_cur = (time.time() - since) / 100
                since = time.time()
                # The optimizer holds the LR actually in use.
                cur_lr = optimizer_classi.param_groups[0]['lr']
                vis.plot_curves({'train loss': loss_classi.item()},
                                iters=total_iters,
                                title='train loss',
                                xlabel='iters',
                                ylabel='train loss')
                vis.plot_curves({'train accuracy': train_acc},
                                iters=total_iters,
                                title='train accuracy',
                                xlabel='iters',
                                ylabel='train accuracy')
                print(
                    "Iters: {:0>6d}/[{:0>2d}], loss_classi: {:.4f}, train_accuracy: {:.4f}, time: {:.2f} s/iter, learning rate: {}"
                    .format(total_iters, epoch, loss_classi.item(), train_acc,
                            time_cur, cur_lr))

            # save model
            if total_iters % args.save_freq == 0:
                _print('Saving checkpoint: {}'.format(total_iters))
                os.makedirs(save_dir, exist_ok=True)
                torch.save(
                    {
                        'iters': total_iters,
                        'net_state_dict': _state_dict(net)
                    },
                    os.path.join(save_dir, 'Iter_%06d_net.ckpt' % total_iters))
                torch.save(
                    {
                        'iters': total_iters,
                        'net_state_dict': _state_dict(margin)
                    },
                    os.path.join(save_dir,
                                 'Iter_%06d_margin.ckpt' % total_iters))

            # test accuracy on LFW, AgeDB-30 and CFP-FP
            if total_iters % args.test_freq == 0:
                net.eval()
                mean_accs = {}
                for title, key, mat_path, dataset, loader in benchmarks:
                    mean_accs[key] = _evaluate(mat_path, dataset, loader)
                    acc = mean_accs[key] * 100
                    _print('{} Ave Accuracy: {:.4f}'.format(title, acc))
                    if best_acc[key] < acc:
                        best_acc[key] = acc
                        best_iters[key] = total_iters
                _print(_best_summary('Current'))
                vis.plot_curves(mean_accs,
                                iters=total_iters,
                                title='test accuracy',
                                xlabel='iters',
                                ylabel='test accuracy')
                net.train()

        # PyTorch >= 1.1 expects scheduler.step() after the epoch's
        # optimizer updates; stepping first shifted every milestone early.
        scheduler_classi.step()

    _print(_best_summary('Finally'))
    print('finishing training')
Пример #3
0
def train(args):
    """Train a backbone with an ArcFace head on VGG_FP, with resumable logs.

    Model/data settings come from the module-level ``config`` (``device``,
    ``batch_size``, ``milestones`` and dataset settings); run switches come
    from ``args``. Training loss/accuracy are appended to
    ``result/log_vis_train.txt`` every 200 iterations and benchmark
    accuracies to ``result/log_vis_test.txt`` on every test, so a resumed run
    can replay both into the visualizer.

    Args:
        args: parsed command-line namespace. Fields read here: ``gpus``,
            ``save_dir``, ``backbone``, ``has_test``, ``margin_type``,
            ``scale_size``, ``resume`` (falsy, or the iteration count to
            resume from), ``net_path``, ``margin_path``, ``total_epoch``,
            ``save_freq`` and ``test_freq``.

    Raises:
        ValueError: if ``args.backbone`` or ``args.margin_type`` is unknown.
        NotImplementedError: for ``'CosFace'`` / ``'SphereFace'``.
    """
    log_interval = 200  # iterations between training-statistics reports

    # gpu init
    multi_gpus = len(args.gpus.split(',')) > 1
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # log init: one directory per backbone and day, reused if it exists
    save_dir = os.path.join(
        args.save_dir,
        args.backbone.upper() + datetime.now().date().strftime('%Y%m%d'))
    os.makedirs(save_dir, exist_ok=True)
    logging = init_log(save_dir)
    _print = logging.info

    # define backbone
    if args.backbone == 'MobileFace':
        net = MobileFaceNet(512)
    elif args.backbone == 'MNasMobile':
        net = MnasNet(512)
    elif args.backbone == 'ProxyNas':
        net = ProxyNas(512)
    elif args.backbone == 'SERes50_IR':
        net = SE_IR(50, 0.6, 'ir_se')
    elif args.backbone == 'IR_50':
        net = SE_IR(50, 0.6, 'ir')
    else:
        # Fail fast instead of hitting an unbound `net` in summary().
        raise ValueError('{} is not available!'.format(args.backbone))
    net = net.to(config.device)
    summary(net, (3, 112, 112))

    # define transforms
    if args.backbone == 'ProxyNas':
        # Resize needs a (h, w) tuple: a second positional argument is the
        # interpolation mode, so Resize(112, 112) was wrong.
        train_transform = transforms.Compose([
            transforms.Resize((112, 112)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        test_transform = train_transform
    else:
        to_normalized_tensor = [
            transforms.ToTensor(),  # range [0, 255] -> [0.0, 1.0]
            transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                 std=(0.5, 0.5, 0.5)),  # -> [-1.0, 1.0]
        ]
        train_transform = transforms.Compose(
            [transforms.Resize((112, 112)),
             transforms.RandomHorizontalFlip()] + to_normalized_tensor)
        # No random flip at test time, so benchmark scores are repeatable.
        test_transform = transforms.Compose([transforms.Resize((112, 112))] +
                                            to_normalized_tensor)

    # training dataset
    trainset = VGG_FP(config=config, transform=train_transform)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=config.batch_size,
                                              shuffle=True,
                                              num_workers=8,
                                              drop_last=False)
    numclass = trainset.class_nums

    def _test_loader(dataset):
        # Evaluation loaders keep a fixed order and every sample.
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=config.batch_size,
                                           shuffle=False,
                                           num_workers=8,
                                           drop_last=False)

    # (log name, plot key, feature dump path, dataset, loader) per benchmark.
    benchmarks = []
    if args.has_test:
        lfwdataset = LFW(config=config, transform=test_transform)
        agedbdataset = AgeDB30(config=config, transform=test_transform)
        cfpfpdataset = CFP_FP(config=config, transform=test_transform)
        benchmarks = [
            ('LFW', 'lfw', './result/cur_lfw_result.mat', lfwdataset,
             _test_loader(lfwdataset)),
            ('AgeDB-30', 'agedb-30', './result/cur_agedb30_result.mat',
             agedbdataset, _test_loader(agedbdataset)),
            ('CFP-FP', 'cfp-fp', './result/cur_cfpfp_result.mat',
             cfpfpdataset, _test_loader(cfpfpdataset)),
        ]

    if args.margin_type == 'ArcFace':
        margin = ArcMarginProduct(512, numclass, s=args.scale_size)
    elif args.margin_type in ('CosFace', 'SphereFace'):
        # Recognised names with no implementation here; they used to fall
        # through with `margin` unbound.
        raise NotImplementedError('{} is not implemented yet!'.format(
            args.margin_type))
    else:
        raise ValueError('{} is not available!'.format(args.margin_type))
    if args.resume:
        print('resume the model parameters from: ', args.net_path,
              args.margin_path)
        # map_location='cpu' lets GPU-saved checkpoints load anywhere; the
        # modules are moved to `device` below.
        net.load_state_dict(
            torch.load(args.net_path, map_location='cpu')['net_state_dict'])
        margin.load_state_dict(
            torch.load(args.margin_path, map_location='cpu')['net_state_dict'])

    # define optimizers for different layer
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer_ft = optim.SGD([
        {'params': net.parameters(), 'weight_decay': 5e-4},
        {'params': margin.parameters(), 'weight_decay': 5e-4},
    ], lr=0.001, momentum=0.9, nesterov=True)
    exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer_ft,
                                                milestones=config.milestones,
                                                gamma=0.1)

    if multi_gpus:
        net = DataParallel(net).to(device)
        margin = DataParallel(margin).to(device)
    else:
        net = net.to(device)
        margin = margin.to(device)

    best_acc = {'lfw': 0.0, 'agedb-30': 0.0, 'cfp-fp': 0.0}  # percent
    best_iters = {'lfw': 0, 'agedb-30': 0, 'cfp-fp': 0}

    def _evaluate(mat_path, dataset, loader):
        # Dump the current backbone's features to `mat_path` and return the
        # mean 10-fold verification accuracy as a fraction.
        getFeatureFromTorch(mat_path, net, device, dataset, loader)
        return np.mean(evaluation_10_fold(mat_path))

    def _state_dict(model):
        # Unwrap DataParallel so checkpoints load into a bare module.
        return (model.module if multi_gpus else model).state_dict()

    def _best_summary(prefix):
        return ('{} Best Accuracy: LFW: {:.4f} in iters: {}, AgeDB-30: {:.4f} '
                'in iters: {} and CFP-FP: {:.4f} in iters: {}').format(
                    prefix, best_acc['lfw'], best_iters['lfw'],
                    best_acc['agedb-30'], best_iters['agedb-30'],
                    best_acc['cfp-fp'], best_iters['cfp-fp'])

    os.makedirs('result', exist_ok=True)  # visualizer logs and .mat dumps
    vis = Visualizer(env=args.backbone)
    total_iters = 1  # index of the iteration about to run (1-based)
    start_epoch = 0  # epochs already covered by a resumed checkpoint
    if args.resume:
        total_iters = args.resume
        # Derived after total_iters is restored; it used to be computed from
        # the initial value 1, so resuming never skipped an epoch.
        start_epoch = total_iters // len(trainloader)
        # Replay earlier curves (float() replaces np.float, removed in
        # NumPy 1.24).
        with open('result/log_vis_train.txt', 'r') as train_log:
            for line in train_log:
                nodes = line.split(':')
                vis.plot_curves({'softmax loss': float(nodes[1])},
                                iters=float(nodes[0]),
                                title='train loss',
                                xlabel='iters',
                                ylabel='train loss')
                vis.plot_curves({'train accuracy': float(nodes[2])},
                                iters=float(nodes[0]),
                                title='train accuracy',
                                xlabel='iters',
                                ylabel='train accuracy')
        with open('result/log_vis_test.txt', 'r') as test_log:
            for line in test_log:
                nodes = line.split(':')
                vis.plot_curves(
                    {
                        'lfw': float(nodes[1]),
                        'agedb-30': float(nodes[2]),
                        'cfp-fp': float(nodes[3])
                    },
                    iters=float(nodes[0]),
                    title='test accuracy',
                    xlabel='iters',
                    ylabel='test accuracy')

    for epoch in range(1, args.total_epoch + 1):
        if epoch <= start_epoch:
            # Finished before the checkpoint: only keep the LR schedule in sync.
            exp_lr_scheduler.step()
            continue
        # train model
        _print('Train Epoch: {}/{} ...'.format(epoch, args.total_epoch))
        net.train()
        # The logs are closed (and flushed) at the end of every epoch; they
        # used to be reopened each epoch and never closed.
        with open('result/log_vis_train.txt', 'a') as log_vis_train, \
                open('result/log_vis_test.txt', 'a') as log_vis_test:
            since = time.time()
            for data in trainloader:
                img, label = data[0].to(device), data[1].to(device)
                optimizer_ft.zero_grad()

                raw_logits = net(img)
                output = margin(raw_logits, label)
                total_loss = criterion(output, label)
                total_loss.backward()
                optimizer_ft.step()

                # print train information
                if total_iters % log_interval == 0:
                    # Current-batch accuracy, compared on-device: np.array()
                    # cannot convert CUDA tensors.
                    _, predict = torch.max(output.data, 1)
                    train_acc = (predict == label).sum().item() / label.size(0)
                    # Average over the whole reporting interval (the old code
                    # divided by 100 while reporting every 200 iterations).
                    time_cur = (time.time() - since) / log_interval
                    since = time.time()
                    cur_lr = optimizer_ft.param_groups[0]['lr']
                    vis.plot_curves({'softmax loss': total_loss.item()},
                                    iters=total_iters,
                                    title='train loss',
                                    xlabel='iters',
                                    ylabel='train loss')
                    vis.plot_curves({'train accuracy': train_acc},
                                    iters=total_iters,
                                    title='train accuracy',
                                    xlabel='iters',
                                    ylabel='train accuracy')
                    log_vis_train.write("%d:%f:%f\n" %
                                        (total_iters, total_loss.item(),
                                         train_acc))

                    print(
                        "Iters: {:0>6d}/[{:0>2d}], loss: {:.4f}, train_accuracy: {:.4f}, time: {:.2f} s/iter, learning rate: {}"
                        .format(total_iters, epoch, total_loss.item(),
                                train_acc, time_cur, cur_lr))

                # save model
                if total_iters % args.save_freq == 0:
                    _print('Saving checkpoint: {}'.format(total_iters))
                    os.makedirs(save_dir, exist_ok=True)
                    torch.save(
                        {
                            'iters': total_iters,
                            'net_state_dict': _state_dict(net)
                        },
                        os.path.join(save_dir,
                                     'Iter_%06d_net.ckpt' % total_iters))
                    torch.save(
                        {
                            'iters': total_iters,
                            'net_state_dict': _state_dict(margin)
                        },
                        os.path.join(save_dir,
                                     'Iter_%06d_margin.ckpt' % total_iters))

                # test accuracy on LFW, AgeDB-30 and CFP-FP
                if args.has_test and total_iters % args.test_freq == 0:
                    net.eval()
                    mean_accs = {}
                    for title, key, mat_path, dataset, loader in benchmarks:
                        mean_accs[key] = _evaluate(mat_path, dataset, loader)
                        acc = mean_accs[key] * 100
                        _print('{} Ave Accuracy: {:.4f}'.format(title, acc))
                        if best_acc[key] <= acc:
                            best_acc[key] = acc
                            best_iters[key] = total_iters
                    _print(_best_summary('Current'))
                    vis.plot_curves(mean_accs,
                                    iters=total_iters,
                                    title='test accuracy',
                                    xlabel='iters',
                                    ylabel='test accuracy')
                    # Column order iters:lfw:agedb-30:cfp-fp, matching the
                    # resume replay above (cfp used to be written before agedb).
                    log_vis_test.write(
                        '%d:%f:%f:%f\n' %
                        (total_iters, mean_accs['lfw'], mean_accs['agedb-30'],
                         mean_accs['cfp-fp']))
                    net.train()
                total_iters += 1

        # PyTorch >= 1.1 expects scheduler.step() after the epoch's
        # optimizer updates.
        exp_lr_scheduler.step()

    _print(_best_summary('Finally'))
    _print(
        'Finally Best Accuracy: LFW: {:.4f} in iters: {} and CFP-FP: {:.4f} in iters: {}'
        .format(best_acc['lfw'], best_iters['lfw'], best_acc['cfp-fp'],
                best_iters['cfp-fp']))
    print('finishing training')
Пример #4
0
def train(args):
    """Train a backbone with an ArcFace head; PReLU slopes get no weight decay.

    A fresh timestamped directory under ``args.save_dir`` receives the log
    and, every ``args.save_freq`` iterations, a backbone-only checkpoint.
    Every ``args.test_freq`` iterations the backbone is scored with
    ``evaluation_10_fold`` on LFW, AgeDB-30 and CFP-FP and the best mean
    accuracy per benchmark is tracked.

    Args:
        args: parsed command-line namespace. Fields read here: ``gpus``,
            ``save_dir``, ``model_pre``, ``backbone``, ``feature_dim``,
            ``margin_type``, ``scale_size``, ``batch_size``, ``total_epoch``,
            ``save_freq``, ``test_freq``, ``resume`` (checkpoint path) and the
            ``*_root`` / ``*_file_list`` paths.

    Raises:
        NameError: if the timestamped save directory already exists.
        ValueError: if ``args.backbone`` or ``args.margin_type`` is unknown.
        NotImplementedError: for ``'cosface'`` / ``'sphereface'``.
    """
    # gpu init
    multi_gpus = len(args.gpus.split(',')) > 1
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # log init
    save_dir = os.path.join(
        args.save_dir, args.model_pre + args.backbone.upper() + '_' +
        datetime.now().strftime('%Y%m%d_%H%M%S'))
    if os.path.exists(save_dir):
        raise NameError('model dir exists!')
    os.makedirs(save_dir)
    logging = init_log(save_dir)
    _print = logging.info

    # dataset loader
    transform = transforms.Compose([
        transforms.ToTensor(),  # range [0, 255] -> [0.0, 1.0]
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                             std=(0.5, 0.5, 0.5)),  # -> [-1.0, 1.0]
    ])
    # training dataset
    trainset = CASIAWebFace(args.train_root,
                            args.train_file_list,
                            transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=12,
                                              drop_last=False)

    def _test_loader(dataset):
        # Evaluation loaders keep a fixed order and every sample.
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=128,
                                           shuffle=False,
                                           num_workers=4,
                                           drop_last=False)

    # test datasets
    lfwdataset = LFW(args.lfw_test_root,
                     args.lfw_file_list,
                     transform=transform)
    lfwloader = _test_loader(lfwdataset)
    agedbdataset = AgeDB30(args.agedb_test_root,
                           args.agedb_file_list,
                           transform=transform)
    agedbloader = _test_loader(agedbdataset)
    cfpfpdataset = CFP_FP(args.cfpfp_test_root,
                          args.cfpfp_file_list,
                          transform=transform)
    cfpfploader = _test_loader(cfpfpdataset)

    # define backbone and margin layer
    if args.backbone == 'MobileFace':
        net = MobileFaceNet()
    elif args.backbone == 'Res50':  # was `is 'Res50'`: identity, not equality
        net = ResNet50()
    elif args.backbone == 'Res101':
        net = ResNet101()
    elif args.backbone == 'Res50_IR':
        net = SEResNet_IR(50, feature_dim=args.feature_dim, mode='ir')
    elif args.backbone == 'SERes50_IR':
        net = SEResNet_IR(50, feature_dim=args.feature_dim, mode='se_ir')
    else:
        # Fail fast instead of hitting an unbound `net` further down.
        raise ValueError('{} is not available!'.format(args.backbone))

    if args.margin_type == 'arcface':
        margin = ArcMarginProduct(args.feature_dim,
                                  trainset.class_nums,
                                  s=args.scale_size)
    elif args.margin_type in ('cosface', 'sphereface'):
        # Recognised names with no implementation here; they used to fall
        # through with `margin` unbound.
        raise NotImplementedError('{} is not implemented yet!'.format(
            args.margin_type))
    else:
        raise ValueError('{} is not available!'.format(args.margin_type))

    if args.resume:
        print('resume the model parameters from: ', args.resume)
        # Only the backbone is checkpointed below, so the margin restarts
        # from its fresh initialisation.
        ckpt = torch.load(args.resume, map_location='cpu')
        net.load_state_dict(ckpt['net_state_dict'])

    # define optimizers for different layer: PReLU slopes are trained
    # without weight decay, everything else of the backbone with 5e-4.
    # (The old code also "excluded" ids of margin.weight's rows, which are
    # throwaway views and never net parameters -- a no-op.)
    prelu_params = [
        param for module in net.modules() if isinstance(module, nn.PReLU)
        for param in module.parameters()
    ]
    prelu_ids = {id(param) for param in prelu_params}
    base_params = [p for p in net.parameters() if id(p) not in prelu_ids]

    optimizer_ft = optim.SGD([
        {'params': base_params, 'weight_decay': 5e-4},
        {'params': margin.weight, 'weight_decay': 5e-4},
        {'params': prelu_params, 'weight_decay': 0.0},
    ], lr=0.1, momentum=0.9, nesterov=True)

    exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer_ft,
                                                milestones=[10, 18, 25],
                                                gamma=0.1)

    if multi_gpus:
        net = DataParallel(net).to(device)
        margin = DataParallel(margin).to(device)
    else:
        net = net.to(device)
        margin = margin.to(device)
    criterion = torch.nn.CrossEntropyLoss().to(device)

    # (log name, key, feature dump path, dataset, loader) per benchmark.
    benchmarks = [
        ('LFW', 'lfw', './result/cur_lfw_result.mat', lfwdataset, lfwloader),
        ('AgeDB-30', 'agedb-30', './result/cur_agedb30_result.mat',
         agedbdataset, agedbloader),
        ('CFP-FP', 'cfp-fp', './result/cur_cfpfp_result.mat', cfpfpdataset,
         cfpfploader),
    ]
    best_acc = {'lfw': 0.0, 'agedb-30': 0.0, 'cfp-fp': 0.0}  # percent
    best_iters = {'lfw': 0, 'agedb-30': 0, 'cfp-fp': 0}

    def _evaluate(mat_path, dataset, loader):
        # Dump the current backbone's features to `mat_path` and return the
        # mean 10-fold verification accuracy as a fraction.
        getFeatureFromTorch(mat_path, net, device, dataset, loader)
        return np.mean(evaluation_10_fold(mat_path))

    def _best_summary(prefix):
        return ('{} Best Accuracy: LFW: {:.4f} in iters: {}, AgeDB-30: {:.4f} '
                'in iters: {} and CFP-FP: {:.4f} in iters: {}').format(
                    prefix, best_acc['lfw'], best_iters['lfw'],
                    best_acc['agedb-30'], best_iters['agedb-30'],
                    best_acc['cfp-fp'], best_iters['cfp-fp'])

    os.makedirs('./result', exist_ok=True)  # evaluation dumps .mat files here
    total_iters = 0
    for epoch in range(1, args.total_epoch + 1):
        # train model
        _print('Train Epoch: {}/{} ...'.format(epoch, args.total_epoch))
        net.train()

        since = time.time()
        for data in trainloader:
            img, label = data[0].to(device), data[1].to(device)
            optimizer_ft.zero_grad()

            raw_logits = net(img)
            output = margin(raw_logits, label)
            total_loss = criterion(output, label)
            total_loss.backward()
            optimizer_ft.step()

            total_iters += 1
            # print train information
            if total_iters % 100 == 0:
                time_cur = (time.time() - since) / 100
                since = time.time()
                # The optimizer holds the LR actually in use.
                cur_lr = optimizer_ft.param_groups[0]['lr']
                print(
                    "Iters: {:0>6d}/[{:0>2d}], loss: {:.4f}, time: {:.2f} s/iter, learning rate: {}"
                    .format(total_iters, epoch, total_loss.item(), time_cur,
                            cur_lr))

            # save model (backbone only)
            if total_iters % args.save_freq == 0:
                _print('Saving checkpoint: {}'.format(total_iters))
                net_state_dict = (net.module
                                  if multi_gpus else net).state_dict()
                os.makedirs(save_dir, exist_ok=True)
                torch.save(
                    {
                        'iters': total_iters,
                        'net_state_dict': net_state_dict
                    }, os.path.join(save_dir, 'Iter_%06d.ckpt' % total_iters))

            # test accuracy on LFW, AgeDB-30 and CFP-FP
            if total_iters % args.test_freq == 0:
                net.eval()
                for title, key, mat_path, dataset, loader in benchmarks:
                    acc = _evaluate(mat_path, dataset, loader) * 100
                    _print('{} Ave Accuracy: {:.4f}'.format(title, acc))
                    if best_acc[key] < acc:
                        best_acc[key] = acc
                        best_iters[key] = total_iters
                _print(_best_summary('Current'))

                net.train()

        # PyTorch >= 1.1 expects scheduler.step() after the epoch's
        # optimizer updates; stepping first shifted every milestone early.
        exp_lr_scheduler.step()

    _print(_best_summary('Finally'))
    print('finishing training')