def train(args):
    """Train a face-recognition backbone with a margin head plus agent center loss.

    Builds the CASIA-WebFace training loader and the LFW test loader, the
    backbone selected by ``args.backbone`` and the margin layer selected by
    ``args.margin_type``, optionally restores both from checkpoints, then runs
    ``args.total_epoch`` epochs of SGD on
    ``cross_entropy + args.weight_center * center_loss``.

    Every 10 iterations training statistics are printed. Every
    ``args.save_freq`` iterations the backbone and margin weights are saved
    into a fresh timestamped directory under ``args.save_dir``. Every
    ``args.test_freq`` iterations the backbone is evaluated on LFW with
    10-fold verification.

    Args:
        args: parsed command-line namespace. Fields read here: ``gpus``,
            ``save_dir``, ``model_pre``, ``backbone``, ``train_root``,
            ``train_file_list``, ``batch_size``, ``lfw_test_root``,
            ``lfw_file_list``, ``feature_dim``, ``margin_type``,
            ``scale_size``, ``resume``, ``net_path``, ``margin_path``,
            ``weight_center``, ``total_epoch``, ``plot``, ``save_freq`` and
            ``test_freq``.

    Raises:
        NameError: if the timestamped output directory already exists.
        ValueError: if ``args.backbone`` or ``args.margin_type`` is unknown.
        NotImplementedError: for ``CosFace`` / ``SphereFace``, which are
            recognised names but have no implementation in this script.
    """
    # gpu init (CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is
    # first initialised)
    multi_gpus = len(args.gpus.split(',')) > 1
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # log init
    save_dir = os.path.join(
        args.save_dir, args.model_pre + args.backbone.upper() + '_' +
        datetime.now().strftime('%Y%m%d_%H%M%S'))
    if os.path.exists(save_dir):
        raise NameError('model dir exists!')
    os.makedirs(save_dir)
    logger = init_log(save_dir)
    _print = logger.info

    # dataset loader
    transform = transforms.Compose([
        transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                             std=(0.5, 0.5,
                                  0.5))  # range [0.0, 1.0] -> [-1.0,1.0]
    ])
    # training dataset
    trainset = CASIAWebFace(args.train_root,
                            args.train_file_list,
                            transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=8,
                                              drop_last=False)
    # test dataset
    lfwdataset = LFW(args.lfw_test_root,
                     args.lfw_file_list,
                     transform=transform)
    lfwloader = torch.utils.data.DataLoader(lfwdataset,
                                            batch_size=128,
                                            shuffle=False,
                                            num_workers=4,
                                            drop_last=False)

    # define backbone and margin layer; fail fast instead of leaving `net` /
    # `margin` undefined and crashing later with a NameError
    if args.backbone == 'MobileFace':
        net = MobileFaceNet(feature_dim=args.feature_dim)
    elif args.backbone == 'Res50':
        net = ResNet50()
    elif args.backbone == 'Res101':
        net = ResNet101()
    elif args.backbone == 'Res50_IR':
        net = SEResNet_IR(50, feature_dim=args.feature_dim, mode='ir')
    elif args.backbone == 'SERes50_IR':
        net = SEResNet_IR(50, feature_dim=args.feature_dim, mode='se_ir')
    elif args.backbone == 'SphereNet':
        net = SphereNet(num_layers=64, feature_dim=args.feature_dim)
    else:
        raise ValueError('{} is not available!'.format(args.backbone))

    if args.margin_type == 'ArcFace':
        margin = ArcMarginProduct(args.feature_dim,
                                  trainset.class_nums,
                                  s=args.scale_size)
    elif args.margin_type == 'InnerProduct':
        margin = InnerProduct(args.feature_dim, trainset.class_nums)
    elif args.margin_type in ('CosFace', 'SphereFace'):
        raise NotImplementedError(
            '{} is not implemented in this script!'.format(args.margin_type))
    else:
        raise ValueError('{} is not available!'.format(args.margin_type))

    if args.resume:
        print('resume the model parameters from: ', args.net_path,
              args.margin_path)
        # Load onto the CPU so checkpoints written on any GPU can be restored
        # regardless of which devices are currently visible; the modules are
        # moved to `device` further below.
        net.load_state_dict(
            torch.load(args.net_path, map_location='cpu')['net_state_dict'])
        margin.load_state_dict(
            torch.load(args.margin_path, map_location='cpu')['net_state_dict'])

    # define optimizers for different layers
    criterion_classi = torch.nn.CrossEntropyLoss().to(device)
    optimizer_classi = optim.SGD([{
        'params': net.parameters(),
        'weight_decay': 5e-4
    }, {
        'params': margin.parameters(),
        'weight_decay': 5e-4
    }],
                                 lr=0.1,
                                 momentum=0.9,
                                 nesterov=True)
    scheduler_classi = lr_scheduler.MultiStepLR(optimizer_classi,
                                                milestones=[35, 60, 85],
                                                gamma=0.1)

    # NOTE: criterion_center (which owns the learned centers) is not
    # checkpointed, so a resumed run starts from freshly initialised centers.
    criterion_center = AgentCenterLoss(trainset.class_nums, args.feature_dim,
                                       args.scale_size).to(device)
    optimizer_center = optim.SGD(criterion_center.parameters(), lr=0.5)
    scheduler_center = lr_scheduler.MultiStepLR(optimizer_center,
                                                milestones=[35, 60, 85],
                                                gamma=0.1)

    if multi_gpus:
        net = DataParallel(net).to(device)
        margin = DataParallel(margin).to(device)
    else:
        net = net.to(device)
        margin = margin.to(device)

    # getFeatureFromTorch writes its .mat output under ./result/
    os.makedirs('./result', exist_ok=True)

    best_lfw_acc = 0.0
    best_lfw_iters = 0
    total_iters = 0
    for epoch in range(1, args.total_epoch + 1):
        # train model
        _print('Train Epoch: {}/{} ...'.format(epoch, args.total_epoch))
        net.train()

        if args.plot:
            all_features, all_labels = [], []

        since = time.time()
        for data in trainloader:
            img, label = data[0].to(device), data[1].to(device)
            feature = net(img)
            output = margin(feature)
            loss_classi = criterion_classi(output, label)
            loss_center = criterion_center(feature, label)
            total_loss = loss_classi + loss_center * args.weight_center

            optimizer_classi.zero_grad()
            optimizer_center.zero_grad()
            total_loss.backward()
            optimizer_classi.step()
            # NOTE: the centers' gradients are scaled by args.weight_center
            # through total_loss. Rescaling each center grad by
            # 1 / args.weight_center before this step would make the centers'
            # learning independent of that weight; this script does not.
            optimizer_center.step()

            total_iters += 1
            if args.plot:
                all_features.append(feature.data.cpu().numpy())
                all_labels.append(label.data.cpu().numpy())

            # print train information
            if total_iters % 10 == 0:
                # current training accuracy (compared on-device)
                _, predict = torch.max(output.data, 1)
                total = label.size(0)
                correct = (predict == label.data).sum().item()
                time_cur = (time.time() - since) / 10
                since = time.time()
                print(
                    "Iters: {:0>6d}/[{:0>2d}], loss_classi: {:.4f}, loss_center: {:.4f}, train_accuracy: {:.4f}, time: {:.2f} s/iter, learning rate: {}"
                    .format(total_iters, epoch, loss_classi.item(),
                            loss_center.item(), correct / total, time_cur,
                            scheduler_classi.get_last_lr()[0]))

            # save model
            if total_iters % args.save_freq == 0:
                _print('Saving checkpoint: {}'.format(total_iters))
                # save the unwrapped modules so checkpoints load without
                # DataParallel's "module." key prefix
                if multi_gpus:
                    net_state_dict = net.module.state_dict()
                    margin_state_dict = margin.module.state_dict()
                else:
                    net_state_dict = net.state_dict()
                    margin_state_dict = margin.state_dict()

                os.makedirs(save_dir, exist_ok=True)
                torch.save(
                    {
                        'iters': total_iters,
                        'net_state_dict': net_state_dict
                    },
                    os.path.join(save_dir, 'Iter_%06d_net.ckpt' % total_iters))
                torch.save(
                    {
                        'iters': total_iters,
                        'net_state_dict': margin_state_dict
                    },
                    os.path.join(save_dir,
                                 'Iter_%06d_margin.ckpt' % total_iters))

            # test accuracy
            if total_iters % args.test_freq == 0:
                # test model on lfw
                net.eval()
                getFeatureFromTorch('./result/cur_lfw_result.mat', net, device,
                                    lfwdataset, lfwloader)
                lfw_acc = np.mean(
                    evaluation_10_fold('./result/cur_lfw_result.mat')) * 100
                _print('LFW Ave Accuracy: {:.4f}'.format(lfw_acc))
                if best_lfw_acc < lfw_acc:
                    best_lfw_acc = lfw_acc
                    best_lfw_iters = total_iters

                net.train()

        if args.plot:
            all_features = np.concatenate(all_features, 0)
            all_labels = np.concatenate(all_labels, 0)
            plot_features(all_features, all_labels, trainset.class_nums, epoch,
                          save_dir)

        # Step the LR schedulers once per finished epoch, after the optimizers
        # have stepped (the order PyTorch >= 1.1 requires); the LR therefore
        # drops after 35, 60 and 85 completed epochs.
        scheduler_classi.step()
        scheduler_center.step()

    _print('Finally Best Accuracy: LFW: {:.4f} in iters: {}'.format(
        best_lfw_acc, best_lfw_iters))
    print('finishing training')
Exemplo n.º 2
0
def train(args):
    """Train a face-recognition backbone with a margin/softmax head.

    Builds the CASIA-WebFace training loader and the LFW, AgeDB-30 and CFP-FP
    test loaders, the backbone selected by ``args.backbone`` and the margin
    layer selected by ``args.margin_type``, optionally restores both from
    checkpoints, then runs ``args.total_epoch`` epochs of SGD with
    cross-entropy loss.

    Every 100 iterations the training loss and accuracy are printed and
    plotted through a ``Visualizer``. Every ``args.save_freq`` iterations the
    backbone and margin weights are saved into a fresh timestamped directory
    under ``args.save_dir``. Every ``args.test_freq`` iterations the backbone
    is evaluated on all three test sets with 10-fold verification.

    Args:
        args: parsed command-line namespace. Fields read here: ``gpus``,
            ``save_dir``, ``model_pre``, ``backbone``, ``train_root``,
            ``train_file_list``, ``batch_size``, ``lfw_test_root``,
            ``lfw_file_list``, ``agedb_test_root``, ``agedb_file_list``,
            ``cfpfp_test_root``, ``cfpfp_file_list``, ``feature_dim``,
            ``margin_type``, ``scale_size``, ``resume``, ``net_path``,
            ``margin_path``, ``total_epoch``, ``save_freq`` and
            ``test_freq``.

    Raises:
        NameError: if the timestamped output directory already exists.
        ValueError: if ``args.backbone`` or ``args.margin_type`` is unknown.
        NotImplementedError: for ``CosFace`` / ``SphereFace``, which are
            recognised names but have no implementation in this script.
    """
    # gpu init (CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is
    # first initialised)
    multi_gpus = len(args.gpus.split(',')) > 1
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # log init
    save_dir = os.path.join(
        args.save_dir, args.model_pre + args.backbone.upper() + '_' +
        datetime.now().strftime('%Y%m%d_%H%M%S'))
    if os.path.exists(save_dir):
        raise NameError('model dir exists!')
    os.makedirs(save_dir)
    logger = init_log(save_dir)
    _print = logger.info

    # dataset loader
    transform = transforms.Compose([
        transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                             std=(0.5, 0.5,
                                  0.5))  # range [0.0, 1.0] -> [-1.0,1.0]
    ])
    # training dataset
    trainset = CASIAWebFace(args.train_root,
                            args.train_file_list,
                            transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=8,
                                              drop_last=False)
    # test datasets
    lfwdataset = LFW(args.lfw_test_root,
                     args.lfw_file_list,
                     transform=transform)
    lfwloader = torch.utils.data.DataLoader(lfwdataset,
                                            batch_size=128,
                                            shuffle=False,
                                            num_workers=4,
                                            drop_last=False)
    agedbdataset = AgeDB30(args.agedb_test_root,
                           args.agedb_file_list,
                           transform=transform)
    agedbloader = torch.utils.data.DataLoader(agedbdataset,
                                              batch_size=128,
                                              shuffle=False,
                                              num_workers=4,
                                              drop_last=False)
    cfpfpdataset = CFP_FP(args.cfpfp_test_root,
                          args.cfpfp_file_list,
                          transform=transform)
    cfpfploader = torch.utils.data.DataLoader(cfpfpdataset,
                                              batch_size=128,
                                              shuffle=False,
                                              num_workers=4,
                                              drop_last=False)

    # define backbone and margin layer; fail fast instead of leaving `net` /
    # `margin` undefined and crashing later with a NameError
    if args.backbone == 'MobileFace':
        # NOTE(review): unlike the other backbones, feature_dim is not passed
        # here; confirm MobileFaceNet's default matches args.feature_dim.
        net = MobileFaceNet()
    elif args.backbone == 'Res50':
        net = ResNet50()
    elif args.backbone == 'Res101':
        net = ResNet101()
    elif args.backbone == 'Res50_IR':
        net = SEResNet_IR(50, feature_dim=args.feature_dim, mode='ir')
    elif args.backbone == 'SERes50_IR':
        net = SEResNet_IR(50, feature_dim=args.feature_dim, mode='se_ir')
    elif args.backbone == 'SphereNet':
        net = SphereNet(num_layers=64, feature_dim=args.feature_dim)
    else:
        raise ValueError('{} is not available!'.format(args.backbone))

    if args.margin_type == 'ArcFace':
        margin = ArcMarginProduct(args.feature_dim,
                                  trainset.class_nums,
                                  s=args.scale_size)
    elif args.margin_type == 'InnerProduct':
        margin = InnerProduct(args.feature_dim, trainset.class_nums)
    elif args.margin_type in ('CosFace', 'SphereFace'):
        raise NotImplementedError(
            '{} is not implemented in this script!'.format(args.margin_type))
    else:
        raise ValueError('{} is not available!'.format(args.margin_type))

    if args.resume:
        print('resume the model parameters from: ', args.net_path,
              args.margin_path)
        # Load onto the CPU so checkpoints written on any GPU can be restored
        # regardless of which devices are currently visible; the modules are
        # moved to `device` further below.
        net.load_state_dict(
            torch.load(args.net_path, map_location='cpu')['net_state_dict'])
        margin.load_state_dict(
            torch.load(args.margin_path, map_location='cpu')['net_state_dict'])

    # define optimizers for different layers
    criterion_classi = torch.nn.CrossEntropyLoss().to(device)
    optimizer_classi = optim.SGD([{
        'params': net.parameters(),
        'weight_decay': 5e-4
    }, {
        'params': margin.parameters(),
        'weight_decay': 5e-4
    }],
                                 lr=0.1,
                                 momentum=0.9,
                                 nesterov=True)
    scheduler_classi = lr_scheduler.MultiStepLR(optimizer_classi,
                                                milestones=[20, 35, 45],
                                                gamma=0.1)

    if multi_gpus:
        net = DataParallel(net).to(device)
        margin = DataParallel(margin).to(device)
    else:
        net = net.to(device)
        margin = margin.to(device)

    def evaluate(model, result_path, dataset, loader):
        """Dump ``model`` features for ``dataset`` and return the mean 10-fold accuracy."""
        getFeatureFromTorch(result_path, model, device, dataset, loader)
        return np.mean(evaluation_10_fold(result_path))

    # getFeatureFromTorch writes its .mat output under ./result/
    os.makedirs('./result', exist_ok=True)

    best_lfw_acc = 0.0
    best_lfw_iters = 0
    best_agedb30_acc = 0.0
    best_agedb30_iters = 0
    best_cfp_fp_acc = 0.0
    best_cfp_fp_iters = 0
    total_iters = 0
    vis = Visualizer(env='softmax_train')
    for epoch in range(1, args.total_epoch + 1):
        # train model
        _print('Train Epoch: {}/{} ...'.format(epoch, args.total_epoch))
        net.train()

        since = time.time()
        for data in trainloader:
            img, label = data[0].to(device), data[1].to(device)
            feature = net(img)
            output = margin(feature)
            loss_classi = criterion_classi(output, label)
            total_loss = loss_classi

            optimizer_classi.zero_grad()
            total_loss.backward()
            optimizer_classi.step()

            total_iters += 1
            # print train information
            if total_iters % 100 == 0:
                # Current training accuracy. Compare on-device: the original
                # np.array(predict) raised TypeError for CUDA tensors.
                _, predict = torch.max(output.data, 1)
                total = label.size(0)
                correct = (predict == label.data).sum().item()
                time_cur = (time.time() - since) / 100
                since = time.time()
                vis.plot_curves({'train loss': loss_classi.item()},
                                iters=total_iters,
                                title='train loss',
                                xlabel='iters',
                                ylabel='train loss')
                vis.plot_curves({'train accuracy': correct / total},
                                iters=total_iters,
                                title='train accuracy',
                                xlabel='iters',
                                ylabel='train accuracy')
                print(
                    "Iters: {:0>6d}/[{:0>2d}], loss_classi: {:.4f}, train_accuracy: {:.4f}, time: {:.2f} s/iter, learning rate: {}"
                    .format(total_iters, epoch, loss_classi.item(),
                            correct / total, time_cur,
                            scheduler_classi.get_last_lr()[0]))

            # save model
            if total_iters % args.save_freq == 0:
                _print('Saving checkpoint: {}'.format(total_iters))
                # save the unwrapped modules so checkpoints load without
                # DataParallel's "module." key prefix
                if multi_gpus:
                    net_state_dict = net.module.state_dict()
                    margin_state_dict = margin.module.state_dict()
                else:
                    net_state_dict = net.state_dict()
                    margin_state_dict = margin.state_dict()

                os.makedirs(save_dir, exist_ok=True)
                torch.save(
                    {
                        'iters': total_iters,
                        'net_state_dict': net_state_dict
                    },
                    os.path.join(save_dir, 'Iter_%06d_net.ckpt' % total_iters))
                torch.save(
                    {
                        'iters': total_iters,
                        'net_state_dict': margin_state_dict
                    },
                    os.path.join(save_dir,
                                 'Iter_%06d_margin.ckpt' % total_iters))

            # test accuracy
            if total_iters % args.test_freq == 0:
                net.eval()
                # test model on lfw
                lfw_acc = evaluate(net, './result/cur_lfw_result.mat',
                                   lfwdataset, lfwloader)
                _print('LFW Ave Accuracy: {:.4f}'.format(lfw_acc * 100))
                if best_lfw_acc < lfw_acc * 100:
                    best_lfw_acc = lfw_acc * 100
                    best_lfw_iters = total_iters
                # test model on AgeDB30
                age_acc = evaluate(net, './result/cur_agedb30_result.mat',
                                   agedbdataset, agedbloader)
                _print('AgeDB-30 Ave Accuracy: {:.4f}'.format(age_acc * 100))
                if best_agedb30_acc < age_acc * 100:
                    best_agedb30_acc = age_acc * 100
                    best_agedb30_iters = total_iters
                # test model on CFP-FP
                cfp_acc = evaluate(net, './result/cur_cfpfp_result.mat',
                                   cfpfpdataset, cfpfploader)
                _print('CFP-FP Ave Accuracy: {:.4f}'.format(cfp_acc * 100))
                if best_cfp_fp_acc < cfp_acc * 100:
                    best_cfp_fp_acc = cfp_acc * 100
                    best_cfp_fp_iters = total_iters
                _print(
                    'Current Best Accuracy: LFW: {:.4f} in iters: {}, AgeDB-30: {:.4f} in iters: {} and CFP-FP: {:.4f} in iters: {}'
                    .format(best_lfw_acc, best_lfw_iters, best_agedb30_acc,
                            best_agedb30_iters, best_cfp_fp_acc,
                            best_cfp_fp_iters))
                vis.plot_curves(
                    {
                        'lfw': lfw_acc,
                        'agedb-30': age_acc,
                        'cfp-fp': cfp_acc
                    },
                    iters=total_iters,
                    title='test accuracy',
                    xlabel='iters',
                    ylabel='test accuracy')
                net.train()

        # Step the LR scheduler once per finished epoch, after the optimizer
        # has stepped (the order PyTorch >= 1.1 requires); the LR therefore
        # drops after 20, 35 and 45 completed epochs.
        scheduler_classi.step()

    _print(
        'Finally Best Accuracy: LFW: {:.4f} in iters: {}, AgeDB-30: {:.4f} in iters: {} and CFP-FP: {:.4f} in iters: {}'
        .format(best_lfw_acc, best_lfw_iters, best_agedb30_acc,
                best_agedb30_iters, best_cfp_fp_acc, best_cfp_fp_iters))
    print('finishing training')
Exemplo n.º 3
0
def train(args):
    """Train a face-recognition backbone with a margin-based classification head.

    Builds the CASIA-WebFace training loader and the LFW test loader (in RGB
    or grayscale, depending on ``args.use_gray``), the backbone selected by
    ``args.backbone`` and the margin layer selected by ``args.margin_type``.
    It reports the backbone's FLOPs for a 112x112 input, optionally restores
    the backbone and margin from checkpoints, then runs ``args.total_epoch``
    epochs of SGD with cross-entropy loss.

    Every 100 iterations training statistics are logged. Every
    ``args.save_freq`` iterations the backbone and margin weights are saved
    into a fresh timestamped directory under ``args.save_dir``. Every
    ``args.test_freq`` iterations the backbone is evaluated on LFW with
    10-fold verification.

    Args:
        args: parsed command-line namespace. Fields read here: ``gpus``,
            ``save_dir``, ``model_pre``, ``backbone``, ``use_gray``,
            ``train_root``, ``train_file_list``, ``batch_size``,
            ``lfw_test_root``, ``lfw_file_list``, ``feature_dim``,
            ``margin_type``, ``scale_size``, ``resume``, ``net_path``,
            ``margin_path``, ``total_epoch``, ``save_freq`` and
            ``test_freq``.

    Raises:
        NameError: if the timestamped output directory already exists.
        ValueError: if ``args.backbone`` or ``args.margin_type`` is unknown.
        NotImplementedError: for ``SphereFace``, which is a recognised name
            but has no implementation in this script.
    """
    # gpu init (CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is
    # first initialised)
    multi_gpus = len(args.gpus.split(',')) > 1
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # log init
    save_dir = os.path.join(
        args.save_dir, args.model_pre + args.backbone.upper() + '_' +
        datetime.now().strftime('%Y%m%d_%H%M%S'))
    if os.path.exists(save_dir):
        raise NameError('model dir exists!')
    os.makedirs(save_dir)
    logger = init_log(save_dir)
    _print = logger.info

    # dataset loader
    if not args.use_gray:
        transform = transforms.Compose([
            transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
            transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                 std=(0.5, 0.5,
                                      0.5))  # range [0.0, 1.0] -> [-1.0,1.0]
        ])
    else:
        # grayscale input is only scaled to [0.0, 1.0], not normalised
        transform = transforms.Compose([
            transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
        ])

    # training dataset
    trainset = CASIAWebFace(args.train_root,
                            args.train_file_list,
                            transform=transform,
                            use_gray=args.use_gray)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=8,
                                              drop_last=False)
    # test dataset
    lfwdataset = LFW(args.lfw_test_root,
                     args.lfw_file_list,
                     transform=transform,
                     use_gray=args.use_gray)
    lfwloader = torch.utils.data.DataLoader(lfwdataset,
                                            batch_size=128,
                                            shuffle=False,
                                            num_workers=4,
                                            drop_last=False)

    # define backbone and margin layer; fail fast instead of relying on the
    # site-injected `exit` builtin or leaving `margin` undefined
    in_channels = 1 if args.use_gray else 3
    if args.backbone == 'MobileFace':
        net = MobileFaceNet()
    elif args.backbone == 'Res50_IR':
        net = CBAMResNet(50, feature_dim=args.feature_dim, mode='ir')
    elif args.backbone == 'SERes50_IR':
        net = CBAMResNet(50, feature_dim=args.feature_dim, mode='ir_se')
    elif args.backbone == 'Res100_IR':
        net = CBAMResNet(100, feature_dim=args.feature_dim, mode='ir')
    elif args.backbone == 'SERes100_IR':
        net = CBAMResNet(100, feature_dim=args.feature_dim, mode='ir_se')
    elif args.backbone == 'Attention_56':
        net = ResidualAttentionNet_56(feature_dim=args.feature_dim)
    elif args.backbone == 'Attention_92':
        net = ResidualAttentionNet_92(feature_dim=args.feature_dim)
    elif args.backbone == 'SmallVGG':
        net = SmallVGG(in_channels, args.feature_dim, alpha=0.5)
    else:
        raise ValueError('{} is not available!'.format(args.backbone))
    calc_flops(net, in_channels, 112, 112)

    if args.margin_type == 'ArcFace':
        margin = ArcMarginProduct(args.feature_dim,
                                  trainset.class_nums,
                                  s=args.scale_size)
    elif args.margin_type == 'MultiMargin':
        margin = MultiMarginProduct(args.feature_dim,
                                    trainset.class_nums,
                                    s=args.scale_size)
    elif args.margin_type == 'CosFace':
        margin = CosineMarginProduct(args.feature_dim,
                                     trainset.class_nums,
                                     s=args.scale_size)
    elif args.margin_type == 'Softmax':
        margin = InnerProduct(args.feature_dim, trainset.class_nums)
    elif args.margin_type == 'SphereFace':
        raise NotImplementedError(
            '{} is not implemented in this script!'.format(args.margin_type))
    else:
        raise ValueError('{} is not available!'.format(args.margin_type))

    if args.resume:
        print('resume the model parameters from: ', args.net_path,
              args.margin_path)
        # Load onto the CPU so checkpoints written on any GPU can be restored
        # regardless of which devices are currently visible; the modules are
        # moved to `device` further below.
        net.load_state_dict(
            torch.load(args.net_path, map_location='cpu')['net_state_dict'])
        margin.load_state_dict(
            torch.load(args.margin_path, map_location='cpu')['net_state_dict'])

    # define optimizers for different layers
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer_ft = optim.SGD([{
        'params': net.parameters(),
        'weight_decay': 5e-4
    }, {
        'params': margin.parameters(),
        'weight_decay': 5e-4
    }],
                             lr=0.1,
                             momentum=0.9,
                             nesterov=True)
    exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer_ft,
                                                milestones=[6, 11, 16],
                                                gamma=0.1)

    if multi_gpus:
        net = DataParallel(net).to(device)
        margin = DataParallel(margin).to(device)
    else:
        net = net.to(device)
        margin = margin.to(device)

    # getFeatureFromTorch writes its .mat output under ./result/
    os.makedirs('./result', exist_ok=True)

    best_lfw_acc = 0.0
    best_lfw_iters = 0
    total_iters = 0
    for epoch in range(1, args.total_epoch + 1):
        # Skipping the step before the first epoch means the scheduler only
        # steps after optimizer.step() has run at least once, so the LR drops
        # after 6, 11 and 16 completed epochs.
        if epoch > 1:
            exp_lr_scheduler.step()
        # train model
        _print('Train Epoch: {}/{} ...'.format(epoch, args.total_epoch))
        net.train()

        since = time.time()
        for data in trainloader:
            img, label = data[0].to(device), data[1].to(device)
            optimizer_ft.zero_grad()

            # the backbone output is an embedding; the margin layer also
            # takes the labels
            feature = net(img)
            output = margin(feature, label)
            total_loss = criterion(output, label)
            total_loss.backward()
            optimizer_ft.step()

            total_iters += 1
            # print train information
            if total_iters % 100 == 0:
                # current training accuracy (compared on-device)
                _, predict = torch.max(output.data, 1)
                total = label.size(0)
                correct = (predict == label.data).sum().item()
                time_cur = (time.time() - since) / 100
                since = time.time()

                _print(
                    "Iters: {:0>6d}/[{:0>2d}], loss: {:.4f}, train_accuracy: {:.4f}, time: {:.2f} s/iter, learning rate: {}"
                    .format(total_iters, epoch, total_loss.item(),
                            correct / total, time_cur,
                            exp_lr_scheduler.get_last_lr()[0]))

            # save model
            if total_iters % args.save_freq == 0:
                _print('Saving checkpoint: {}'.format(total_iters))
                # save the unwrapped modules so checkpoints load without
                # DataParallel's "module." key prefix
                if multi_gpus:
                    net_state_dict = net.module.state_dict()
                    margin_state_dict = margin.module.state_dict()
                else:
                    net_state_dict = net.state_dict()
                    margin_state_dict = margin.state_dict()
                os.makedirs(save_dir, exist_ok=True)
                torch.save(
                    {
                        'iters': total_iters,
                        'net_state_dict': net_state_dict
                    },
                    os.path.join(save_dir, 'Iter_%06d_net.ckpt' % total_iters))
                torch.save(
                    {
                        'iters': total_iters,
                        'net_state_dict': margin_state_dict
                    },
                    os.path.join(save_dir,
                                 'Iter_%06d_margin.ckpt' % total_iters))

            # test accuracy
            if total_iters % args.test_freq == 0:
                # test model on lfw
                net.eval()
                getFeatureFromTorch('./result/cur_lfw_result.mat', net, device,
                                    lfwdataset, lfwloader)
                lfw_acc = np.mean(
                    evaluation_10_fold('./result/cur_lfw_result.mat')) * 100
                _print('LFW Ave Accuracy: {:.4f}'.format(lfw_acc))
                # `<=` keeps the most recent iteration among ties
                if best_lfw_acc <= lfw_acc:
                    best_lfw_acc = lfw_acc
                    best_lfw_iters = total_iters

                net.train()

    _print('Finally Best Accuracy: LFW: {:.4f} in iters: {}'.format(
        best_lfw_acc, best_lfw_iters))
    print('finishing training')