def train(args):
    """Train a face-recognition backbone with a classification margin head.

    Trains on CASIA-WebFace, periodically saves backbone and margin
    checkpoints, evaluates 10-fold verification accuracy on LFW, AgeDB-30
    and CFP-FP, and plots loss/accuracy curves through a ``Visualizer``.

    Args:
        args: parsed command-line namespace. Fields read here: ``gpus``,
            ``save_dir``, ``model_pre``, ``backbone``, ``train_root``,
            ``train_file_list``, ``batch_size``, ``lfw_test_root``,
            ``lfw_file_list``, ``agedb_test_root``, ``agedb_file_list``,
            ``cfpfp_test_root``, ``cfpfp_file_list``, ``feature_dim``,
            ``margin_type``, ``scale_size``, ``resume``, ``net_path``,
            ``margin_path``, ``total_epoch``, ``save_freq``, ``test_freq``.

    Raises:
        NameError: if the timestamped checkpoint directory already exists.
        SystemExit: if ``args.backbone`` or ``args.margin_type`` is not
            supported by this script.
    """
    import sys

    # gpu init
    multi_gpus = len(args.gpus.split(',')) > 1
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # log init
    save_dir = os.path.join(
        args.save_dir,
        args.model_pre + args.backbone.upper() + '_' +
        datetime.now().strftime('%Y%m%d_%H%M%S'))
    if os.path.exists(save_dir):
        raise NameError('model dir exists!')
    os.makedirs(save_dir)
    logger = init_log(save_dir)
    _print = logger.info

    # dataset loader
    transform = transforms.Compose([
        transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                             std=(0.5, 0.5, 0.5)),  # range [0.0, 1.0] -> [-1.0,1.0]
    ])
    # training dataset
    trainset = CASIAWebFace(args.train_root, args.train_file_list, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
                                              shuffle=True, num_workers=8,
                                              drop_last=False)
    # test datasets
    lfwdataset = LFW(args.lfw_test_root, args.lfw_file_list, transform=transform)
    lfwloader = torch.utils.data.DataLoader(lfwdataset, batch_size=128,
                                            shuffle=False, num_workers=4,
                                            drop_last=False)
    agedbdataset = AgeDB30(args.agedb_test_root, args.agedb_file_list, transform=transform)
    agedbloader = torch.utils.data.DataLoader(agedbdataset, batch_size=128,
                                              shuffle=False, num_workers=4,
                                              drop_last=False)
    cfpfpdataset = CFP_FP(args.cfpfp_test_root, args.cfpfp_file_list, transform=transform)
    cfpfploader = torch.utils.data.DataLoader(cfpfpdataset, batch_size=128,
                                              shuffle=False, num_workers=4,
                                              drop_last=False)

    # define backbone and margin layer
    if args.backbone == 'MobileFace':
        net = MobileFaceNet()
    elif args.backbone == 'Res50':
        net = ResNet50()
    elif args.backbone == 'Res101':
        net = ResNet101()
    elif args.backbone == 'Res50_IR':
        net = SEResNet_IR(50, feature_dim=args.feature_dim, mode='ir')
    elif args.backbone == 'SERes50_IR':
        net = SEResNet_IR(50, feature_dim=args.feature_dim, mode='se_ir')
    elif args.backbone == 'SphereNet':
        net = SphereNet(num_layers=64, feature_dim=args.feature_dim)
    else:
        # Stop now: continuing would crash later with `net` unbound.
        print(args.backbone, ' is not available!')
        sys.exit(-1)

    if args.margin_type == 'ArcFace':
        margin = ArcMarginProduct(args.feature_dim, trainset.class_nums, s=args.scale_size)
    elif args.margin_type == 'InnerProduct':
        margin = InnerProduct(args.feature_dim, trainset.class_nums)
    else:
        # Includes 'CosFace' and 'SphereFace', which have no implementation in
        # this script; previously they fell through with `margin` unbound.
        print(args.margin_type, 'is not available!')
        sys.exit(-1)

    if args.resume:
        print('resume the model parameters from: ', args.net_path, args.margin_path)
        # The modules are still on the CPU here (they are moved to `device`
        # below), so map checkpoints to CPU; this also lets GPU-saved
        # checkpoints load on CPU-only machines.
        net.load_state_dict(
            torch.load(args.net_path, map_location='cpu')['net_state_dict'])
        margin.load_state_dict(
            torch.load(args.margin_path, map_location='cpu')['net_state_dict'])

    # define optimizers for different layer
    criterion_classi = torch.nn.CrossEntropyLoss().to(device)
    optimizer_classi = optim.SGD([
        {'params': net.parameters(), 'weight_decay': 5e-4},
        {'params': margin.parameters(), 'weight_decay': 5e-4},
    ], lr=0.1, momentum=0.9, nesterov=True)
    scheduler_classi = lr_scheduler.MultiStepLR(optimizer_classi, milestones=[20, 35, 45],
                                                gamma=0.1)

    if multi_gpus:
        net = DataParallel(net).to(device)
        margin = DataParallel(margin).to(device)
    else:
        net = net.to(device)
        margin = margin.to(device)

    best_lfw_acc = 0.0
    best_lfw_iters = 0
    best_agedb30_acc = 0.0
    best_agedb30_iters = 0
    best_cfp_fp_acc = 0.0
    best_cfp_fp_iters = 0
    total_iters = 0
    vis = Visualizer(env='softmax_train')
    for epoch in range(1, args.total_epoch + 1):
        # Advance the LR schedule only after a full epoch of optimizer steps.
        # Stepping before the first optimizer.step() (PyTorch >= 1.1) moves
        # every milestone one epoch earlier and emits a warning.
        if epoch > 1:
            scheduler_classi.step()
        # train model
        _print('Train Epoch: {}/{} ...'.format(epoch, args.total_epoch))
        net.train()

        since = time.time()
        for data in trainloader:
            img, label = data[0].to(device), data[1].to(device)
            feature = net(img)
            output = margin(feature)
            loss_classi = criterion_classi(output, label)
            total_loss = loss_classi

            optimizer_classi.zero_grad()
            total_loss.backward()
            optimizer_classi.step()

            total_iters += 1
            # print train information
            if total_iters % 100 == 0:
                # current training accuracy on this batch
                _, predict = torch.max(output.data, 1)
                total = label.size(0)
                # Compare as tensors: np.array() cannot convert CUDA tensors.
                correct = (predict == label).sum().item()
                time_cur = (time.time() - since) / 100
                since = time.time()
                vis.plot_curves({'train loss': loss_classi.item()}, iters=total_iters,
                                title='train loss', xlabel='iters', ylabel='train loss')
                vis.plot_curves({'train accuracy': correct / total}, iters=total_iters,
                                title='train accuracy', xlabel='iters',
                                ylabel='train accuracy')
                print(
                    "Iters: {:0>6d}/[{:0>2d}], loss_classi: {:.4f}, train_accuracy: {:.4f}, time: {:.2f} s/iter, learning rate: {}"
                    .format(total_iters, epoch, loss_classi.item(), correct / total,
                            time_cur, scheduler_classi.get_last_lr()[0]))

            # save model
            if total_iters % args.save_freq == 0:
                msg = 'Saving checkpoint: {}'.format(total_iters)
                _print(msg)
                # unwrap DataParallel so checkpoints load into a bare module
                if multi_gpus:
                    net_state_dict = net.module.state_dict()
                    margin_state_dict = margin.module.state_dict()
                else:
                    net_state_dict = net.state_dict()
                    margin_state_dict = margin.state_dict()
                if not os.path.exists(save_dir):
                    os.mkdir(save_dir)
                torch.save(
                    {'iters': total_iters, 'net_state_dict': net_state_dict},
                    os.path.join(save_dir, 'Iter_%06d_net.ckpt' % total_iters))
                torch.save(
                    {'iters': total_iters, 'net_state_dict': margin_state_dict},
                    os.path.join(save_dir, 'Iter_%06d_margin.ckpt' % total_iters))

            # test accuracy
            if total_iters % args.test_freq == 0:
                net.eval()

                # test model on lfw
                getFeatureFromTorch('./result/cur_lfw_result.mat', net, device,
                                    lfwdataset, lfwloader)
                lfw_accs = evaluation_10_fold('./result/cur_lfw_result.mat')
                lfw_mean = np.mean(lfw_accs)
                _print('LFW Ave Accuracy: {:.4f}'.format(lfw_mean * 100))
                if best_lfw_acc < lfw_mean * 100:
                    best_lfw_acc = lfw_mean * 100
                    best_lfw_iters = total_iters

                # test model on AgeDB30
                getFeatureFromTorch('./result/cur_agedb30_result.mat', net, device,
                                    agedbdataset, agedbloader)
                age_accs = evaluation_10_fold('./result/cur_agedb30_result.mat')
                age_mean = np.mean(age_accs)
                _print('AgeDB-30 Ave Accuracy: {:.4f}'.format(age_mean * 100))
                if best_agedb30_acc < age_mean * 100:
                    best_agedb30_acc = age_mean * 100
                    best_agedb30_iters = total_iters

                # test model on CFP-FP
                getFeatureFromTorch('./result/cur_cfpfp_result.mat', net, device,
                                    cfpfpdataset, cfpfploader)
                cfp_accs = evaluation_10_fold('./result/cur_cfpfp_result.mat')
                cfp_mean = np.mean(cfp_accs)
                _print('CFP-FP Ave Accuracy: {:.4f}'.format(cfp_mean * 100))
                if best_cfp_fp_acc < cfp_mean * 100:
                    best_cfp_fp_acc = cfp_mean * 100
                    best_cfp_fp_iters = total_iters

                _print(
                    'Current Best Accuracy: LFW: {:.4f} in iters: {}, AgeDB-30: {:.4f} in iters: {} and CFP-FP: {:.4f} in iters: {}'
                    .format(best_lfw_acc, best_lfw_iters, best_agedb30_acc,
                            best_agedb30_iters, best_cfp_fp_acc, best_cfp_fp_iters))
                # the plot uses fractions in [0, 1], not percentages
                vis.plot_curves({'lfw': lfw_mean, 'agedb-30': age_mean, 'cfp-fp': cfp_mean},
                                iters=total_iters, title='test accuracy',
                                xlabel='iters', ylabel='test accuracy')
                net.train()

    _print(
        'Finally Best Accuracy: LFW: {:.4f} in iters: {}, AgeDB-30: {:.4f} in iters: {} and CFP-FP: {:.4f} in iters: {}'
        .format(best_lfw_acc, best_lfw_iters, best_agedb30_acc, best_agedb30_iters,
                best_cfp_fp_acc, best_cfp_fp_iters))
    print('finishing training')
def train(args):
    """Train a backbone with a margin head plus a weighted agent-center loss.

    The total loss is ``cross_entropy + weight_center * agent_center_loss``.
    The center loss has its own SGD optimizer and LR schedule. Checkpoints
    for the backbone and margin are saved every ``args.save_freq``
    iterations, and LFW 10-fold accuracy is evaluated every
    ``args.test_freq`` iterations. If ``args.plot`` is set, each epoch's
    training features are collected and passed to ``plot_features``.

    Args:
        args: parsed command-line namespace. Fields read here: ``gpus``,
            ``save_dir``, ``model_pre``, ``backbone``, ``train_root``,
            ``train_file_list``, ``batch_size``, ``lfw_test_root``,
            ``lfw_file_list``, ``feature_dim``, ``margin_type``,
            ``scale_size``, ``resume``, ``net_path``, ``margin_path``,
            ``weight_center``, ``plot``, ``total_epoch``, ``save_freq``,
            ``test_freq``.

    Raises:
        NameError: if the timestamped checkpoint directory already exists.
        SystemExit: if ``args.backbone`` or ``args.margin_type`` is not
            supported by this script.
    """
    import sys

    # gpu init
    multi_gpus = len(args.gpus.split(',')) > 1
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # log init
    save_dir = os.path.join(
        args.save_dir,
        args.model_pre + args.backbone.upper() + '_' +
        datetime.now().strftime('%Y%m%d_%H%M%S'))
    if os.path.exists(save_dir):
        raise NameError('model dir exists!')
    os.makedirs(save_dir)
    logger = init_log(save_dir)
    _print = logger.info

    # dataset loader
    transform = transforms.Compose([
        transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                             std=(0.5, 0.5, 0.5)),  # range [0.0, 1.0] -> [-1.0,1.0]
    ])
    # training dataset
    trainset = CASIAWebFace(args.train_root, args.train_file_list, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
                                              shuffle=True, num_workers=8,
                                              drop_last=False)
    # test dataset
    lfwdataset = LFW(args.lfw_test_root, args.lfw_file_list, transform=transform)
    lfwloader = torch.utils.data.DataLoader(lfwdataset, batch_size=128,
                                            shuffle=False, num_workers=4,
                                            drop_last=False)

    # define backbone and margin layer
    if args.backbone == 'MobileFace':
        net = MobileFaceNet(feature_dim=args.feature_dim)
    elif args.backbone == 'Res50':
        net = ResNet50()
    elif args.backbone == 'Res101':
        net = ResNet101()
    elif args.backbone == 'Res50_IR':
        net = SEResNet_IR(50, feature_dim=args.feature_dim, mode='ir')
    elif args.backbone == 'SERes50_IR':
        net = SEResNet_IR(50, feature_dim=args.feature_dim, mode='se_ir')
    elif args.backbone == 'SphereNet':
        net = SphereNet(num_layers=64, feature_dim=args.feature_dim)
    else:
        # Stop now: continuing would crash later with `net` unbound.
        print(args.backbone, ' is not available!')
        sys.exit(-1)

    if args.margin_type == 'ArcFace':
        margin = ArcMarginProduct(args.feature_dim, trainset.class_nums, s=args.scale_size)
    elif args.margin_type == 'InnerProduct':
        margin = InnerProduct(args.feature_dim, trainset.class_nums)
    else:
        # Includes 'CosFace' and 'SphereFace', which have no implementation in
        # this script; previously they fell through with `margin` unbound.
        print(args.margin_type, 'is not available!')
        sys.exit(-1)

    if args.resume:
        print('resume the model parameters from: ', args.net_path, args.margin_path)
        # Modules are still on the CPU here; map checkpoints to CPU so
        # GPU-saved checkpoints also load on CPU-only machines.
        net.load_state_dict(
            torch.load(args.net_path, map_location='cpu')['net_state_dict'])
        margin.load_state_dict(
            torch.load(args.margin_path, map_location='cpu')['net_state_dict'])

    # define optimizers for different layers
    criterion_classi = torch.nn.CrossEntropyLoss().to(device)
    optimizer_classi = optim.SGD([
        {'params': net.parameters(), 'weight_decay': 5e-4},
        {'params': margin.parameters(), 'weight_decay': 5e-4},
    ], lr=0.1, momentum=0.9, nesterov=True)
    scheduler_classi = lr_scheduler.MultiStepLR(optimizer_classi, milestones=[35, 60, 85],
                                                gamma=0.1)

    # NOTE: the center-loss state is neither restored on resume nor written
    # to checkpoints below, so its centers restart on every run.
    criterion_center = AgentCenterLoss(trainset.class_nums, args.feature_dim,
                                       args.scale_size).to(device)
    optimizer_center = optim.SGD(criterion_center.parameters(), lr=0.5)
    scheduler_center = lr_scheduler.MultiStepLR(optimizer_center, milestones=[35, 60, 85],
                                                gamma=0.1)

    if multi_gpus:
        net = DataParallel(net).to(device)
        margin = DataParallel(margin).to(device)
    else:
        net = net.to(device)
        margin = margin.to(device)

    best_lfw_acc = 0.0
    best_lfw_iters = 0
    total_iters = 0
    for epoch in range(1, args.total_epoch + 1):
        # Advance the LR schedules only after a full epoch of optimizer steps.
        # Stepping before the first optimizer.step() (PyTorch >= 1.1) moves
        # every milestone one epoch earlier and emits a warning.
        if epoch > 1:
            scheduler_classi.step()
            scheduler_center.step()
        # train model
        _print('Train Epoch: {}/{} ...'.format(epoch, args.total_epoch))
        net.train()
        if args.plot:
            all_features, all_labels = [], []

        since = time.time()
        for data in trainloader:
            img, label = data[0].to(device), data[1].to(device)
            feature = net(img)
            output = margin(feature)
            loss_classi = criterion_classi(output, label)
            loss_center = criterion_center(feature, label)
            total_loss = loss_classi + loss_center * args.weight_center

            optimizer_classi.zero_grad()
            optimizer_center.zero_grad()
            total_loss.backward()
            optimizer_classi.step()
            # NOTE: the centers' gradients carry the args.weight_center factor
            # from total_loss. To make center learning independent of that
            # weight, multiply each criterion_center parameter's grad by
            # 1 / args.weight_center before this step.
            optimizer_center.step()

            total_iters += 1
            if args.plot:
                all_features.append(feature.data.cpu().numpy())
                all_labels.append(label.data.cpu().numpy())

            # print train information
            if total_iters % 10 == 0:
                # current training accuracy on this batch
                _, predict = torch.max(output.data, 1)
                total = label.size(0)
                correct = (predict == label).sum().item()
                time_cur = (time.time() - since) / 10
                since = time.time()
                print(
                    "Iters: {:0>6d}/[{:0>2d}], loss_classi: {:.4f}, loss_center: {:.4f}, train_accuracy: {:.4f}, time: {:.2f} s/iter, learning rate: {}"
                    .format(total_iters, epoch, loss_classi.item(), loss_center.item(),
                            correct / total, time_cur, scheduler_classi.get_last_lr()[0]))

            # save model
            if total_iters % args.save_freq == 0:
                msg = 'Saving checkpoint: {}'.format(total_iters)
                _print(msg)
                # unwrap DataParallel so checkpoints load into a bare module
                if multi_gpus:
                    net_state_dict = net.module.state_dict()
                    margin_state_dict = margin.module.state_dict()
                else:
                    net_state_dict = net.state_dict()
                    margin_state_dict = margin.state_dict()
                if not os.path.exists(save_dir):
                    os.mkdir(save_dir)
                torch.save(
                    {'iters': total_iters, 'net_state_dict': net_state_dict},
                    os.path.join(save_dir, 'Iter_%06d_net.ckpt' % total_iters))
                torch.save(
                    {'iters': total_iters, 'net_state_dict': margin_state_dict},
                    os.path.join(save_dir, 'Iter_%06d_margin.ckpt' % total_iters))

            # test accuracy
            if total_iters % args.test_freq == 0:
                # test model on lfw
                net.eval()
                getFeatureFromTorch('./result/cur_lfw_result.mat', net, device,
                                    lfwdataset, lfwloader)
                lfw_accs = evaluation_10_fold('./result/cur_lfw_result.mat')
                lfw_acc = np.mean(lfw_accs) * 100
                _print('LFW Ave Accuracy: {:.4f}'.format(lfw_acc))
                if best_lfw_acc < lfw_acc:
                    best_lfw_acc = lfw_acc
                    best_lfw_iters = total_iters
                net.train()

        # per-epoch feature visualization
        if args.plot:
            all_features = np.concatenate(all_features, 0)
            all_labels = np.concatenate(all_labels, 0)
            plot_features(all_features, all_labels, trainset.class_nums, epoch, save_dir)

    _print('Finally Best Accuracy: LFW: {:.4f} in iters: {}'.format(
        best_lfw_acc, best_lfw_iters))
    print('finishing training')
def train(args):
    """Train a backbone with an ArcFace margin, excluding PReLU from weight decay.

    Backbone parameters (except PReLU slopes) and the margin weight use a
    weight decay of 5e-4. PReLU parameters use no weight decay. Backbone
    checkpoints are saved every ``args.save_freq`` iterations. LFW, AgeDB-30
    and CFP-FP 10-fold accuracy is evaluated every ``args.test_freq``
    iterations.

    Args:
        args: parsed command-line namespace. Fields read here: ``gpus``,
            ``save_dir``, ``model_pre``, ``backbone``, ``train_root``,
            ``train_file_list``, ``batch_size``, ``lfw_test_root``,
            ``lfw_file_list``, ``agedb_test_root``, ``agedb_file_list``,
            ``cfpfp_test_root``, ``cfpfp_file_list``, ``feature_dim``,
            ``margin_type`` (lower-case, e.g. ``'arcface'``), ``scale_size``,
            ``resume`` (checkpoint path or falsy), ``total_epoch``,
            ``save_freq``, ``test_freq``.

    Raises:
        NameError: if the timestamped checkpoint directory already exists.
        SystemExit: if ``args.backbone`` or ``args.margin_type`` is not
            supported by this script.
    """
    import sys

    # gpu init
    multi_gpus = len(args.gpus.split(',')) > 1
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # log init
    save_dir = os.path.join(
        args.save_dir,
        args.model_pre + args.backbone.upper() + '_' +
        datetime.now().strftime('%Y%m%d_%H%M%S'))
    if os.path.exists(save_dir):
        raise NameError('model dir exists!')
    os.makedirs(save_dir)
    logger = init_log(save_dir)
    _print = logger.info

    # dataset loader
    transform = transforms.Compose([
        transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                             std=(0.5, 0.5, 0.5)),  # range [0.0, 1.0] -> [-1.0,1.0]
    ])
    # training dataset
    trainset = CASIAWebFace(args.train_root, args.train_file_list, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
                                              shuffle=True, num_workers=12,
                                              drop_last=False)
    # test datasets
    lfwdataset = LFW(args.lfw_test_root, args.lfw_file_list, transform=transform)
    lfwloader = torch.utils.data.DataLoader(lfwdataset, batch_size=128,
                                            shuffle=False, num_workers=4,
                                            drop_last=False)
    agedbdataset = AgeDB30(args.agedb_test_root, args.agedb_file_list, transform=transform)
    agedbloader = torch.utils.data.DataLoader(agedbdataset, batch_size=128,
                                              shuffle=False, num_workers=4,
                                              drop_last=False)
    cfpfpdataset = CFP_FP(args.cfpfp_test_root, args.cfpfp_file_list, transform=transform)
    cfpfploader = torch.utils.data.DataLoader(cfpfpdataset, batch_size=128,
                                              shuffle=False, num_workers=4,
                                              drop_last=False)

    # define backbone and margin layer
    if args.backbone == 'MobileFace':
        net = MobileFaceNet()
    elif args.backbone == 'Res50':  # `==`, not `is`: str identity is not guaranteed
        net = ResNet50()
    elif args.backbone == 'Res101':
        net = ResNet101()
    elif args.backbone == 'Res50_IR':
        net = SEResNet_IR(50, feature_dim=args.feature_dim, mode='ir')
    elif args.backbone == 'SERes50_IR':
        net = SEResNet_IR(50, feature_dim=args.feature_dim, mode='se_ir')
    else:
        # Stop now: continuing would crash later with `net` unbound.
        print(args.backbone, ' is not available!')
        sys.exit(-1)

    if args.margin_type == 'arcface':
        margin = ArcMarginProduct(args.feature_dim, trainset.class_nums, s=args.scale_size)
    else:
        # Includes 'cosface' and 'sphereface', which have no implementation in
        # this script; previously they fell through with `margin` unbound.
        print(args.margin_type, 'is not available!')
        sys.exit(-1)

    if args.resume:
        print('resume the model parameters from: ', args.resume)
        # Modules are still on the CPU here; map to CPU so GPU-saved
        # checkpoints also load on CPU-only machines.
        ckpt = torch.load(args.resume, map_location='cpu')
        net.load_state_dict(ckpt['net_state_dict'])

    # define optimizers for different layer
    # PReLU slopes get no weight decay; every other backbone parameter does.
    # (margin.weight lives outside `net`, so it never needs excluding here.)
    prelu_params = []
    for m in net.modules():
        if isinstance(m, nn.PReLU):
            prelu_params += m.parameters()
    prelu_param_ids = {id(p) for p in prelu_params}
    base_params = [p for p in net.parameters() if id(p) not in prelu_param_ids]

    optimizer_ft = optim.SGD([
        {'params': base_params, 'weight_decay': 5e-4},
        {'params': margin.weight, 'weight_decay': 5e-4},
        {'params': prelu_params, 'weight_decay': 0.0},
    ], lr=0.1, momentum=0.9, nesterov=True)
    exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer_ft, milestones=[10, 18, 25],
                                                gamma=0.1)

    if multi_gpus:
        net = DataParallel(net).to(device)
        margin = DataParallel(margin).to(device)
    else:
        net = net.to(device)
        margin = margin.to(device)
    criterion = torch.nn.CrossEntropyLoss().to(device)

    best_lfw_acc = 0.0
    best_lfw_iters = 0
    best_agedb30_acc = 0.0
    best_agedb30_iters = 0
    best_cfp_fp_acc = 0.0
    best_cfp_fp_iters = 0
    total_iters = 0
    for epoch in range(1, args.total_epoch + 1):
        # Advance the LR schedule only after a full epoch of optimizer steps.
        # Stepping before the first optimizer.step() (PyTorch >= 1.1) moves
        # every milestone one epoch earlier and emits a warning.
        if epoch > 1:
            exp_lr_scheduler.step()
        # train model
        _print('Train Epoch: {}/{} ...'.format(epoch, args.total_epoch))
        net.train()

        since = time.time()
        for data in trainloader:
            img, label = data[0].to(device), data[1].to(device)
            optimizer_ft.zero_grad()

            raw_logits = net(img)
            # the margin head takes the labels (ArcFace needs the target class)
            output = margin(raw_logits, label)
            total_loss = criterion(output, label)
            total_loss.backward()
            optimizer_ft.step()

            total_iters += 1
            # print train information
            if total_iters % 100 == 0:
                time_cur = (time.time() - since) / 100
                since = time.time()
                print(
                    "Iters: {:0>6d}/[{:0>2d}], loss: {:.4f}, time: {:.2f} s/iter, learning rate: {}"
                    .format(total_iters, epoch, total_loss.item(), time_cur,
                            exp_lr_scheduler.get_last_lr()[0]))

            # save model (backbone only)
            if total_iters % args.save_freq == 0:
                msg = 'Saving checkpoint: {}'.format(total_iters)
                _print(msg)
                # unwrap DataParallel so checkpoints load into a bare module
                if multi_gpus:
                    net_state_dict = net.module.state_dict()
                else:
                    net_state_dict = net.state_dict()
                if not os.path.exists(save_dir):
                    os.mkdir(save_dir)
                torch.save(
                    {'iters': total_iters, 'net_state_dict': net_state_dict},
                    os.path.join(save_dir, 'Iter_%06d.ckpt' % total_iters))

            # test accuracy
            if total_iters % args.test_freq == 0:
                net.eval()

                # test model on lfw
                getFeatureFromTorch('./result/cur_lfw_result.mat', net, device,
                                    lfwdataset, lfwloader)
                accs = evaluation_10_fold('./result/cur_lfw_result.mat')
                acc = np.mean(accs) * 100
                _print('LFW Ave Accuracy: {:.4f}'.format(acc))
                if best_lfw_acc < acc:
                    best_lfw_acc = acc
                    best_lfw_iters = total_iters

                # test model on AgeDB30
                getFeatureFromTorch('./result/cur_agedb30_result.mat', net, device,
                                    agedbdataset, agedbloader)
                accs = evaluation_10_fold('./result/cur_agedb30_result.mat')
                acc = np.mean(accs) * 100
                _print('AgeDB-30 Ave Accuracy: {:.4f}'.format(acc))
                if best_agedb30_acc < acc:
                    best_agedb30_acc = acc
                    best_agedb30_iters = total_iters

                # test model on CFP-FP
                getFeatureFromTorch('./result/cur_cfpfp_result.mat', net, device,
                                    cfpfpdataset, cfpfploader)
                accs = evaluation_10_fold('./result/cur_cfpfp_result.mat')
                acc = np.mean(accs) * 100
                _print('CFP-FP Ave Accuracy: {:.4f}'.format(acc))
                if best_cfp_fp_acc < acc:
                    best_cfp_fp_acc = acc
                    best_cfp_fp_iters = total_iters

                _print(
                    'Current Best Accuracy: LFW: {:.4f} in iters: {}, AgeDB-30: {:.4f} in iters: {} and CFP-FP: {:.4f} in iters: {}'
                    .format(best_lfw_acc, best_lfw_iters, best_agedb30_acc,
                            best_agedb30_iters, best_cfp_fp_acc, best_cfp_fp_iters))
                net.train()

    _print(
        'Finally Best Accuracy: LFW: {:.4f} in iters: {}, AgeDB-30: {:.4f} in iters: {} and CFP-FP: {:.4f} in iters: {}'
        .format(best_lfw_acc, best_lfw_iters, best_agedb30_acc, best_agedb30_iters,
                best_cfp_fp_acc, best_cfp_fp_iters))
    print('finishing training')
def train(args):
    """Train a face-recognition backbone (optionally on gray-scale input).

    Builds the backbone and reports its FLOPs for a 112x112 input. It then
    trains with the selected margin head on CASIA-WebFace, saves backbone
    and margin checkpoints every ``args.save_freq`` iterations, and
    evaluates LFW 10-fold accuracy every ``args.test_freq`` iterations.

    Args:
        args: parsed command-line namespace. Fields read here: ``gpus``,
            ``save_dir``, ``model_pre``, ``backbone``, ``use_gray``,
            ``train_root``, ``train_file_list``, ``batch_size``,
            ``lfw_test_root``, ``lfw_file_list``, ``feature_dim``,
            ``margin_type``, ``scale_size``, ``resume``, ``net_path``,
            ``margin_path``, ``total_epoch``, ``save_freq``, ``test_freq``.

    Raises:
        NameError: if the timestamped checkpoint directory already exists.
        SystemExit: if ``args.backbone`` or ``args.margin_type`` is not
            supported by this script.
    """
    import sys

    # gpu init
    multi_gpus = len(args.gpus.split(',')) > 1
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # log init
    save_dir = os.path.join(
        args.save_dir,
        args.model_pre + args.backbone.upper() + '_' +
        datetime.now().strftime('%Y%m%d_%H%M%S'))
    if os.path.exists(save_dir):
        raise NameError('model dir exists!')
    os.makedirs(save_dir)
    logger = init_log(save_dir)
    _print = logger.info

    # dataset loader
    if not args.use_gray:
        transform = transforms.Compose([
            transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
            transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                 std=(0.5, 0.5, 0.5)),  # range [0.0, 1.0] -> [-1.0,1.0]
        ])
    else:
        # gray-scale input is only scaled to [0.0, 1.0], not re-centered
        transform = transforms.Compose([
            transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
        ])
    # training dataset
    trainset = CASIAWebFace(args.train_root, args.train_file_list, transform=transform,
                            use_gray=args.use_gray)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
                                              shuffle=True, num_workers=8,
                                              drop_last=False)
    # test dataset
    lfwdataset = LFW(args.lfw_test_root, args.lfw_file_list, transform=transform,
                     use_gray=args.use_gray)
    lfwloader = torch.utils.data.DataLoader(lfwdataset, batch_size=128,
                                            shuffle=False, num_workers=4,
                                            drop_last=False)

    # define backbone and margin layer
    # NOTE(review): only SmallVGG and calc_flops receive in_channels; with
    # use_gray=True the other backbones presumably still expect 3 channels.
    in_channels = 1 if args.use_gray else 3
    if args.backbone == 'MobileFace':
        net = MobileFaceNet()
    elif args.backbone == 'Res50_IR':
        net = CBAMResNet(50, feature_dim=args.feature_dim, mode='ir')
    elif args.backbone == 'SERes50_IR':
        net = CBAMResNet(50, feature_dim=args.feature_dim, mode='ir_se')
    elif args.backbone == 'Res100_IR':
        net = CBAMResNet(100, feature_dim=args.feature_dim, mode='ir')
    elif args.backbone == 'SERes100_IR':
        net = CBAMResNet(100, feature_dim=args.feature_dim, mode='ir_se')
    elif args.backbone == 'Attention_56':
        net = ResidualAttentionNet_56(feature_dim=args.feature_dim)
    elif args.backbone == 'Attention_92':
        net = ResidualAttentionNet_92(feature_dim=args.feature_dim)
    elif args.backbone == 'SmallVGG':
        net = SmallVGG(in_channels, args.feature_dim, alpha=0.5)
    else:
        print(args.backbone, ' is not available!')
        sys.exit(-1)  # same as exit(-1), without relying on the `site` builtin

    calc_flops(net, in_channels, 112, 112)

    if args.margin_type == 'ArcFace':
        margin = ArcMarginProduct(args.feature_dim, trainset.class_nums, s=args.scale_size)
    elif args.margin_type == 'MultiMargin':
        margin = MultiMarginProduct(args.feature_dim, trainset.class_nums, s=args.scale_size)
    elif args.margin_type == 'CosFace':
        margin = CosineMarginProduct(args.feature_dim, trainset.class_nums,
                                     s=args.scale_size)
    elif args.margin_type == 'Softmax':
        margin = InnerProduct(args.feature_dim, trainset.class_nums)
    else:
        # Includes 'SphereFace', which has no implementation in this script;
        # previously it fell through with `margin` unbound.
        print(args.margin_type, 'is not available!')
        sys.exit(-1)

    if args.resume:
        print('resume the model parameters from: ', args.net_path, args.margin_path)
        # Modules are still on the CPU here; map to CPU so GPU-saved
        # checkpoints also load on CPU-only machines.
        net.load_state_dict(
            torch.load(args.net_path, map_location='cpu')['net_state_dict'])
        margin.load_state_dict(
            torch.load(args.margin_path, map_location='cpu')['net_state_dict'])

    # define optimizers for different layer
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer_ft = optim.SGD([
        {'params': net.parameters(), 'weight_decay': 5e-4},
        {'params': margin.parameters(), 'weight_decay': 5e-4},
    ], lr=0.1, momentum=0.9, nesterov=True)
    exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer_ft, milestones=[6, 11, 16],
                                                gamma=0.1)

    if multi_gpus:
        net = DataParallel(net).to(device)
        margin = DataParallel(margin).to(device)
    else:
        net = net.to(device)
        margin = margin.to(device)

    best_lfw_acc = 0.0
    best_lfw_iters = 0
    total_iters = 0
    for epoch in range(1, args.total_epoch + 1):
        # step the schedule only after a full epoch of optimizer steps
        if epoch > 1:
            exp_lr_scheduler.step()
        # train model
        _print('Train Epoch: {}/{} ...'.format(epoch, args.total_epoch))
        net.train()

        since = time.time()
        for data in trainloader:
            img, label = data[0].to(device), data[1].to(device)
            optimizer_ft.zero_grad()

            raw_logits = net(img)
            output = margin(raw_logits, label)
            total_loss = criterion(output, label)
            total_loss.backward()
            optimizer_ft.step()

            total_iters += 1
            # print train information
            if total_iters % 100 == 0:
                # current training accuracy on this batch
                _, predict = torch.max(output.data, 1)
                total = label.size(0)
                correct = (predict == label).sum().item()
                time_cur = (time.time() - since) / 100
                since = time.time()
                _print(
                    "Iters: {:0>6d}/[{:0>2d}], loss: {:.4f}, train_accuracy: {:.4f}, time: {:.2f} s/iter, learning rate: {}"
                    .format(total_iters, epoch, total_loss.item(), correct / total,
                            time_cur, exp_lr_scheduler.get_last_lr()[0]))

            # save model
            if total_iters % args.save_freq == 0:
                msg = 'Saving checkpoint: {}'.format(total_iters)
                _print(msg)
                # unwrap DataParallel so checkpoints load into a bare module
                if multi_gpus:
                    net_state_dict = net.module.state_dict()
                    margin_state_dict = margin.module.state_dict()
                else:
                    net_state_dict = net.state_dict()
                    margin_state_dict = margin.state_dict()
                if not os.path.exists(save_dir):
                    os.mkdir(save_dir)
                torch.save(
                    {'iters': total_iters, 'net_state_dict': net_state_dict},
                    os.path.join(save_dir, 'Iter_%06d_net.ckpt' % total_iters))
                torch.save(
                    {'iters': total_iters, 'net_state_dict': margin_state_dict},
                    os.path.join(save_dir, 'Iter_%06d_margin.ckpt' % total_iters))

            # test accuracy
            if total_iters % args.test_freq == 0:
                # test model on lfw
                net.eval()
                getFeatureFromTorch('./result/cur_lfw_result.mat', net, device,
                                    lfwdataset, lfwloader)
                lfw_accs = evaluation_10_fold('./result/cur_lfw_result.mat')
                lfw_acc = np.mean(lfw_accs) * 100
                _print('LFW Ave Accuracy: {:.4f}'.format(lfw_acc))
                # `<=` keeps the most recent iteration among equal bests
                if best_lfw_acc <= lfw_acc:
                    best_lfw_acc = lfw_acc
                    best_lfw_iters = total_iters
                net.train()

    _print('Finally Best Accuracy: LFW: {:.4f} in iters: {}'.format(
        best_lfw_acc, best_lfw_iters))
    print('finishing training')