'experiments/pretrained_models/dlib/shape_predictor_5_face_landmarks-c4b1e980.dat' # noqa: E501 ) parser.add_argument( '--landmark68_path', type=str, default= # noqa: E251 'experiments/pretrained_models/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat' # noqa: E501 ) args = parser.parse_args() if args.test_path.endswith('/'): # solve when path ends with / args.test_path = args.test_path[:-1] result_root = f'results/DFDNet/{os.path.basename(args.test_path)}' # set up the DFDNet net = DFDNet(64, dict_path=args.dict_path).to(device) checkpoint = torch.load(args.model_path, map_location=lambda storage, loc: storage) net.load_state_dict(checkpoint['params']) net.eval() save_crop_root = os.path.join(result_root, 'cropped_faces') save_inverse_affine_root = os.path.join(result_root, 'inverse_affine') os.makedirs(save_inverse_affine_root, exist_ok=True) save_restore_root = os.path.join(result_root, 'restored_faces') save_final_root = os.path.join(result_root, 'final_results') face_helper = FaceRestorationHelper(args.upscale_factor, args.face_template_path, out_size=512)
elif 'upsample4' in crt_k and 'body' in crt_k: ori_k = ori_k.replace('body', 'Model') else: print('unprocess key: ', crt_k) # replace if crt_net[crt_k].size() != ori_net[ori_k].size(): raise ValueError('Wrong tensor size: \n' f'crt_net: {crt_net[crt_k].size()}\n' f'ori_net: {ori_net[ori_k].size()}') else: crt_net[crt_k] = ori_net[ori_k] return crt_net if __name__ == '__main__': ori_net = torch.load( 'experiments/pretrained_models/DFDNet/DFDNet_official_original.pth') dfd_net = DFDNet( 64, dict_path='experiments/pretrained_models/DFDNet/DFDNet_dict_512.pth') crt_net = dfd_net.state_dict() crt_net_params = convert_net(ori_net, crt_net) torch.save(dict(params=crt_net_params), 'experiments/pretrained_models/DFDNet/DFDNet_official.pth', _use_new_zipfile_serialization=False)