Code example #1
0
File: test_face_dfdnet.py — Project: tamwaiban/BasicSR
        'experiments/pretrained_models/dlib/shape_predictor_5_face_landmarks-c4b1e980.dat'  # noqa: E501
    )
    # Path to dlib's 68-point face landmark predictor model file.
    parser.add_argument(
        '--landmark68_path',
        type=str,
        default=  # noqa: E251
        'experiments/pretrained_models/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat'  # noqa: E501
    )

    args = parser.parse_args()
    # Strip a trailing '/' so os.path.basename below yields the directory
    # name rather than an empty string.
    if args.test_path.endswith('/'):
        args.test_path = args.test_path[:-1]
    result_root = f'results/DFDNet/{os.path.basename(args.test_path)}'

    # Set up the DFDNet: build the network, load the pretrained weights
    # (always onto CPU first via map_location), and switch to eval mode.
    net = DFDNet(64, dict_path=args.dict_path).to(device)
    checkpoint = torch.load(args.model_path,
                            map_location=lambda storage, loc: storage)
    net.load_state_dict(checkpoint['params'])
    net.eval()

    # Output directories for each stage of the pipeline. Only the
    # inverse-affine directory is created here; presumably the others are
    # created later by the saving helpers — TODO confirm in the full script.
    save_crop_root = os.path.join(result_root, 'cropped_faces')
    save_inverse_affine_root = os.path.join(result_root, 'inverse_affine')
    os.makedirs(save_inverse_affine_root, exist_ok=True)
    save_restore_root = os.path.join(result_root, 'restored_faces')
    save_final_root = os.path.join(result_root, 'final_results')

    # Helper that detects/aligns faces and later pastes restored faces back
    # into the upscaled image (output faces are 512x512).
    face_helper = FaceRestorationHelper(args.upscale_factor,
                                        args.face_template_path,
                                        out_size=512)
Code example #2
0
            # Keys containing both 'upsample4' and 'body' were named 'Model'
            # in the original checkpoint layout.
            elif 'upsample4' in crt_k and 'body' in crt_k:
                ori_k = ori_k.replace('body', 'Model')

        else:
            # No mapping rule matched this key; warn so the conversion can
            # be inspected, and fall through with ori_k unchanged.
            print('unprocess key: ', crt_k)

        # Copy the original tensor into the current state dict, but only if
        # the shapes agree — a mismatch means the key mapping above produced
        # a wrong correspondence.
        if crt_net[crt_k].size() != ori_net[ori_k].size():
            raise ValueError('Wrong tensor size: \n'
                             f'crt_net: {crt_net[crt_k].size()}\n'
                             f'ori_net: {ori_net[ori_k].size()}')
        else:
            crt_net[crt_k] = ori_net[ori_k]

    # State dict with all keys filled from the original checkpoint.
    return crt_net


if __name__ == '__main__':
    # Convert the officially released DFDNet checkpoint to this repo's
    # parameter naming and save it in the legacy serialization format.
    official_weights = torch.load(
        'experiments/pretrained_models/DFDNet/DFDNet_official_original.pth')
    model = DFDNet(
        64,
        dict_path='experiments/pretrained_models/DFDNet/DFDNet_dict_512.pth')
    converted_params = convert_net(official_weights, model.state_dict())

    # _use_new_zipfile_serialization=False keeps the old (pre-1.6) format
    # so older torch versions can still load the file.
    torch.save(
        dict(params=converted_params),
        'experiments/pretrained_models/DFDNet/DFDNet_official.pth',
        _use_new_zipfile_serialization=False)