acc = val.val_sesface(-1, 96, 112, 32, args.device, fnet, net, index=7) net.setVal(False) if acc > best_acc: best_acc = acc save_network_for_backup(args, net, optimizer, scheduler, epoch_id) # Save the final SR model save_network(args, net, epochs) def eval(): net = SeSface() fnet = SphereFace(type='teacher', pretrain=torch.load('../../pretrained/sface.pth')) fnet.to(args.device) checkpoint = torch.load(args.model_file) net.load_state_dict(checkpoint['net']) # 加载模型可学习参数 net.to(args.device) net.setVal(True) acc = val.val_sesface(-1, 96, 112, 32, args.device, fnet, net, index=16) acc = val.val_sesface(-1, 96, 112, 32, args.device, fnet, net, index=8) acc = val.val_sesface(-1, 96, 112, 32, args.device, fnet, net, index=4) args = train_args.get_args() if __name__ == '__main__': if args.type == 'train' or args.type == 'finetune': main() else: eval()
def _format_description(metrics):
    """Render the tqdm progress-bar description for one training step.

    Args:
        metrics: mapping of metric name -> scalar (float, int, numpy scalar).
            Entries are rendered in insertion order; ``None`` values are
            skipped (e.g. the feature loss when it is disabled).

    Returns:
        A string such as ``"loss: 0.123 lr: 1.000e-04 index: 3 "``.
        'index' is printed as an integer, 'lr' in scientific notation and
        every other metric with three decimals; each entry ends with a space.
    """
    parts = []
    for name, value in metrics.items():
        if value is None:
            continue
        value = float(value)
        if name == 'index':
            parts.append('{}: {:.0f} '.format(name, value))
        elif name == 'lr':
            parts.append('{}: {:.3e} '.format(name, value))
        else:
            parts.append('{}: {:.3f} '.format(name, value))
    return ''.join(parts)


def main():
    """Train the EDSR super-resolution network with pixel and feature losses.

    Each iteration draws a CelebA batch, picks a random level ``index`` in
    {2, 3, 4}, super-resolves ``inputs['down{2**index}']`` and supervises the
    output with:

    * an L1 pixel loss against ``inputs['down{2**(index-2)}']``;
    * when ``args.lamb_id > 0``, a feature loss ``1 - cos(fnet(sr), fnet(hr))``.
      Here ``sr`` is the SR output bilinearly resized to 112x96 and ``hr`` is
      ``inputs['down1']``; the feature extractor is the frozen
      ``FNet.sface`` loaded from ``../pretrained/sface.pth``.

    Presumably ``down{k}`` is the face downsampled by a factor of k; verify
    this against ``get_loader``.

    The learning rate halves every ``args.decay_step`` iterations (StepLR).
    The final SR weights are written with ``save_network`` after
    ``args.iterations`` iterations.
    """
    args = train_args.get_args()
    dataloader = get_loader(args, 'celeba')
    train_iter = iter(dataloader)

    ## Setup FNet (frozen; only used to extract features for the feature loss)
    fnet = getattr(FNet, 'sface')()
    fnet.load_state_dict(torch.load('../pretrained/sface.pth'))
    freeze(fnet)
    fnet.to(args.device)

    ## Setup SRNet
    srnet = SRNet.edsr()
    srnet.to(args.device)
    if len(args.gpu_ids) > 1:
        srnet = nn.DataParallel(srnet)

    optimizer = optim.Adam(srnet.parameters(), lr=args.lr, betas=(0.9, 0.999))
    scheduler = StepLR(optimizer, step_size=args.decay_step, gamma=0.5)
    criterion_pixel = nn.L1Loss()
    cosine_similarity = nn.CosineSimilarity()

    pbar = tqdm(range(1, args.iterations + 1), ncols=0)
    for _ in pbar:
        srnet.train()
        # Learning rate that this iteration's optimizer.step() will use.
        lr = optimizer.param_groups[0]['lr']

        # Training is iteration-based: restart the dataloader when it is exhausted.
        try:
            inputs = next(train_iter)
        except StopIteration:
            train_iter = iter(dataloader)
            inputs = next(train_iter)

        # Random level: input is down{2**index}, pixel target is down{2**(index-2)}.
        index = np.random.randint(2, 4 + 1)
        lr_face = inputs['down{}'.format(2 ** index)].to(args.device)
        mr_face = inputs['down{}'.format(2 ** (index - 2))].to(args.device)
        if index == 2:
            # 2 ** (index - 2) == 1, so mr_face already is inputs['down1'].
            hr_face = mr_face
        else:
            hr_face = inputs['down1'].to(args.device)

        sr_face = srnet(lr_face)
        loss_pixel = criterion_pixel(sr_face, mr_face.detach())
        loss = loss_pixel
        loss_feature = None

        # Feature loss
        if args.lamb_id > 0:
            sr_face_up = nn.functional.interpolate(
                sr_face, size=(112, 96), mode='bilinear', align_corners=False)
            sr_face_feature = fnet(tensor2SFTensor(sr_face_up))
            # Target features are constants; skip building an autograd graph.
            with torch.no_grad():
                hr_face_feature = fnet(tensor2SFTensor(hr_face))
            loss_feature = (1 - cosine_similarity(sr_face_feature, hr_face_feature)).mean()
            # Out-of-place add: `loss += ...` would mutate loss_pixel (same
            # tensor object) and make the logged pixel loss equal the total.
            loss = loss + args.lamb_id * loss_feature

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Since PyTorch 1.1 the scheduler must step after the optimizer,
        # otherwise the first value of the LR schedule is skipped.
        scheduler.step()

        # display
        pbar.set_description(desc=_format_description({
            'loss': loss.item(),
            'loss_pixel': loss_pixel.item(),
            'loss_feature': None if loss_feature is None else loss_feature.item(),
            'lr': lr,
            'index': index,
        }))

    # Save the final SR model
    save_network(args, srnet, args.iterations)