Exemple #1
0
 def _get_data(self, mode):
     """Check for, and download when needed, the files used by `mode`.

     Returns:
         dict: the validated version under ``"version"``, one directory per
         key of ``self.cnn_dailymail``, and the split file path under
         ``mode``.

     Raises:
         ValueError: if the configured version is not one of the supported
             values.
     """
     version = self.config.get("version", "3.0.0")
     if version not in ("1.0.0", "2.0.0", "3.0.0"):
         raise ValueError("Unsupported version: %s" % version)

     dl_paths = {"version": version}
     data_root = os.path.join(DATA_HOME, self.__class__.__name__)

     for name, meta in self.cnn_dailymail.items():
         corpus_dir = os.path.join(data_root, name)
         stories_dir = os.path.join(corpus_dir, "stories")
         if not os.path.exists(corpus_dir):
             get_path_from_url(meta["url"], data_root, meta["md5"])

         # Only endpoints reported by _get_unique_endpoints verify the
         # extracted story count and re-extract the archive on a mismatch.
         endpoints = _get_unique_endpoints(ParallelEnv().trainer_endpoints[:])
         if ParallelEnv().current_endpoint in endpoints:
             found = len(os.listdir(stories_dir))
             expected = meta["file_num"]
             if found != expected:
                 logger.warning(
                     "Number of %s stories is %d != %d, decompress again." %
                     (name, found, expected))
                 shutil.rmtree(stories_dir)
                 archive_path = os.path.join(data_root,
                                             os.path.basename(meta["url"]))
                 _decompress(archive_path)
         dl_paths[name] = corpus_dir

     split_file, split_url, split_md5 = self.SPLITS[mode]
     split_path = os.path.join(data_root, split_file)
     stale = not os.path.exists(split_path) or (
         split_md5 and not md5file(split_path) == split_md5)
     if stale:
         get_path_from_url(split_url, data_root, split_md5)
     dl_paths[mode] = split_path
     return dl_paths
Exemple #2
0
    def _get_data(self, mode, **kwargs):
        """Make sure the data for `mode` is available and return its path.

        For ``'train'`` the path is listed as a directory and its entry count
        is compared with ``len(ALL_LANGUAGES)``; on a mismatch the directory
        is removed and the downloaded archive is decompressed again. For any
        other mode the file is downloaded when it is missing or when its md5
        differs from the recorded hash.
        """
        data_root = os.path.join(DATA_HOME, self.__class__.__name__)
        filename, data_hash, url, zipfile_hash = self.SPLITS[mode]
        fullname = os.path.join(data_root, filename)

        if mode != 'train':
            stale = not os.path.exists(fullname) or (
                data_hash and not md5file(fullname) == data_hash)
            if stale:
                get_path_from_url(url, data_root, zipfile_hash)
            return fullname

        if not os.path.exists(fullname):
            get_path_from_url(url, data_root, zipfile_hash)

        # Only endpoints reported by _get_unique_endpoints run the file-count
        # check and the possible re-extraction.
        endpoints = _get_unique_endpoints(ParallelEnv().trainer_endpoints[:])
        if ParallelEnv().current_endpoint in endpoints:
            found = len(os.listdir(fullname))
            expected = len(ALL_LANGUAGES)
            if found != expected:
                logger.warning(
                    "Number of train files is %d != %d, decompress again."
                    % (found, expected))
                shutil.rmtree(fullname)
                archive_path = os.path.join(data_root, os.path.basename(url))
                _decompress(archive_path)
        return fullname
Exemple #3
0
    def test_uncompress_result(self):
        """Fetch the sample tar and zip archives and check the extracted files."""
        expected_layouts = [
            [
                "files/single_dir/file1", "files/single_dir/file2",
                "files/single_file.pdparams"
            ],
            ["single_dir/file1", "single_dir/file2"],
            ["single_file.pdparams"],
        ]
        # Each case pairs an extraction root with its archives; the archives
        # line up one-to-one with ``expected_layouts``.
        cases = [
            ("./test_tar", [
                "https://paddle-hapi.bj.bcebos.com/unittest/files.tar",
                "https://paddle-hapi.bj.bcebos.com/unittest/single_dir.tar",
                "https://paddle-hapi.bj.bcebos.com/unittest/single_file.tar",
            ]),
            ("./test_zip", [
                "https://paddle-hapi.bj.bcebos.com/unittest/files.zip",
                "https://paddle-hapi.bj.bcebos.com/unittest/single_dir.zip",
                "https://paddle-hapi.bj.bcebos.com/unittest/single_file.zip",
            ]),
        ]

        for extract_root, archive_urls in cases:
            for archive_url, layout in zip(archive_urls, expected_layouts):
                get_path_from_url(archive_url, root_dir=extract_root)
                present = [
                    os.path.exists(os.path.join(extract_root, relpath))
                    for relpath in layout
                ]
                self.assertTrue(all(present))
Exemple #4
0
 def _download_termtree(self, filename):
     """Return the local path of `filename`, downloading it when absent."""
     target_dir = os.path.join(MODEL_HOME, 'ernie-ctm')
     local_path = os.path.join(target_dir, filename)
     # The URL lookup runs before the existence check, so a filename missing
     # from URLS raises whether or not the file is already on disk.
     source_url = URLS[filename]
     if os.path.exists(local_path):
         return local_path
     get_path_from_url(source_url, target_dir)
     return local_path
Exemple #5
0
 def _get_data(self, root, mode, **kwargs):
     """Locate the split file for `mode` and store its path in ``self.full_path``.

     The file is looked up under `root` (user-expanded) when given, otherwise
     under the default dataset directory. If it is missing or fails the md5
     check, the 'train' or 'dev' data for the selected version is downloaded
     into a ``v1``/``v2`` sub-directory of the default directory and the
     returned path is used instead.
     """
     default_root = os.path.join(DATA_HOME, self.__class__.__name__)
     use_v2 = self.version_2_with_negative
     filename, data_hash = self.SPLITS['2.0' if use_v2 else '1.1'][mode]

     if root is None:
         fullname = os.path.join(default_root, filename)
     else:
         fullname = os.path.join(os.path.expanduser(root), filename)

     up_to_date = os.path.exists(fullname) and not (
         data_hash and not md5file(fullname) == data_hash)
     if not up_to_date:
         if root is not None:  # not specified, and no need to warn
             warnings.warn(
                 'md5 check failed for {}, download {} data to {}'.format(
                     filename, self.__class__.__name__, default_root))
         version_dir = os.path.join(default_root, 'v2' if use_v2 else 'v1')
         if mode == 'train':
             url = (self.TRAIN_DATA_URL_V2
                    if use_v2 else self.TRAIN_DATA_URL_V1)
             fullname = get_path_from_url(url, version_dir)
         elif mode == 'dev':
             url = self.DEV_DATA_URL_V2 if use_v2 else self.DEV_DATA_URL_V1
             fullname = get_path_from_url(url, version_dir)
     self.full_path = fullname
Exemple #6
0
    def _get_data(self, mode, **kwargs):
        """Return the source and target file paths for `mode`.

        The archive at ``self.URL`` is downloaded when any of the source,
        target or two vocabulary files is missing or fails its md5 check.
        """
        default_root = os.path.join(DATA_HOME, self.__class__.__name__)
        (src_filename, tgt_filename, src_data_hash,
         tgt_data_hash) = self.SPLITS[mode]
        src_fullname = os.path.join(default_root, src_filename)
        tgt_fullname = os.path.join(default_root, tgt_filename)

        bpe_vocab_info, sub_vocab_info = self.VOCAB_INFO
        bpe_vocab_filename, bpe_vocab_hash = bpe_vocab_info
        sub_vocab_filename, sub_vocab_hash = sub_vocab_info
        bpe_vocab_fullname = os.path.join(default_root, bpe_vocab_filename)
        sub_vocab_fullname = os.path.join(default_root, sub_vocab_filename)

        # Checked lazily and in this order; the first stale file stops the scan.
        tracked_files = (
            (src_fullname, src_data_hash),
            (tgt_fullname, tgt_data_hash),
            (bpe_vocab_fullname, bpe_vocab_hash),
            (sub_vocab_fullname, sub_vocab_hash),
        )
        needs_download = any(
            not os.path.exists(path) or
            (expected and not md5file(path) == expected)
            for path, expected in tracked_files)
        if needs_download:
            get_path_from_url(self.URL, default_root, self.MD5)

        return src_fullname, tgt_fullname
Exemple #7
0
 def _get_data(self, mode, **kwargs):
     """Return the path of the `mode` split, downloading the archive if needed."""
     dataset_dir = os.path.join(DATA_HOME, self.__class__.__name__)
     filename, data_hash = self.SPLITS[mode]
     fullname = os.path.join(dataset_dir, filename)

     # An existing file is accepted when no hash is recorded or it matches.
     if os.path.exists(fullname):
         if not data_hash or md5file(fullname) == data_hash:
             return fullname

     get_path_from_url(self.URL, dataset_dir, self.MD5)
     return fullname
Exemple #8
0
    def __init__(self,
                 embedding_name=EMBEDDING_NAME_LIST[0],
                 unknown_token=UNK_TOKEN,
                 unknown_token_vector=None,
                 extended_vocab_path=None,
                 trainable=True,
                 keep_extended_vocab_only=False):
        """Load a pretrained embedding table and initialise the layer with it.

        Args:
            embedding_name: Name of the embedding; ``<name>.npz`` is loaded
                from EMBEDDING_HOME and ``<name>.tar.gz`` is downloaded from
                EMBEDDING_URL_ROOT when that file is missing.
            unknown_token: Token registered as the vocabulary's unknown token.
            unknown_token_vector: Optional vector for the unknown token; when
                None a random normal vector (scale 0.02) is drawn.
            extended_vocab_path: Optional path handed to ``_extend_vocab``.
                When given, ``trainable`` is forced to True.
            trainable: Passed to ``set_trainable`` (see above for override).
            keep_extended_vocab_only: Forwarded to ``_extend_vocab``.
        """
        vector_path = osp.join(EMBEDDING_HOME, embedding_name + ".npz")
        if not osp.exists(vector_path):
            # The .npz file is not available locally: download the archive.
            url = EMBEDDING_URL_ROOT + "/" + embedding_name + ".tar.gz"
            get_path_from_url(url, EMBEDDING_HOME)

        logger.info("Loading token embedding...")
        vector_np = np.load(vector_path)
        # The second axis of the stored 'embedding' array is used as the
        # embedding dimension.
        self.embedding_dim = vector_np['embedding'].shape[1]
        self.unknown_token = unknown_token
        if unknown_token_vector is not None:
            unk_vector = np.array(unknown_token_vector).astype(
                paddle.get_default_dtype())
        else:
            unk_vector = np.random.normal(scale=0.02,
                                          size=self.embedding_dim).astype(
                                              paddle.get_default_dtype())
        # The padding token is represented by an all-zero vector.
        pad_vector = np.array([0] * self.embedding_dim).astype(
            paddle.get_default_dtype())
        if extended_vocab_path is not None:
            embedding_table = self._extend_vocab(extended_vocab_path,
                                                 vector_np, pad_vector,
                                                 unk_vector,
                                                 keep_extended_vocab_only)
            # An extended vocabulary always makes the embedding trainable,
            # overriding the caller's `trainable` argument.
            trainable = True
        else:
            embedding_table = self._init_without_extend_vocab(
                vector_np, pad_vector, unk_vector)

        # NOTE(review): self._word_to_idx is not assigned in this method;
        # presumably one of the two helpers above sets it - confirm.
        self.vocab = Vocab.from_dict(self._word_to_idx,
                                     unk_token=unknown_token,
                                     pad_token=PAD_TOKEN)
        self.num_embeddings = embedding_table.shape[0]
        # Initialise the parent layer, then overwrite its weight with the
        # loaded table.
        super(TokenEmbedding,
              self).__init__(self.num_embeddings,
                             self.embedding_dim,
                             padding_idx=self._word_to_idx[PAD_TOKEN])
        self.weight.set_value(embedding_table)
        self.set_trainable(trainable)
        logger.info("Finish loading embedding vector.")
        # The backslash continuations keep the indentation of the following
        # lines inside the logged string.
        s = "Token Embedding info:\
             \nUnknown index: {}\
             \nUnknown token: {}\
             \nPadding index: {}\
             \nPadding token: {}\
             \nShape :{}".format(self._word_to_idx[self.unknown_token],
                                 self.unknown_token,
                                 self._word_to_idx[PAD_TOKEN], PAD_TOKEN,
                                 self.weight.shape)
        logger.info(s)
Exemple #9
0
 def _get_data(self, mode, **kwargs):
     """Return the path of the `mode` split for the configured builder.

     The split's own URL is downloaded when the file is missing or its md5
     differs from the recorded hash.
     """
     builder_config = self.BUILDER_CONFIGS[self.name]
     cache_dir = os.path.join(DATA_HOME, self.__class__.__name__)
     filename, url, data_hash = builder_config['splits'][mode]
     fullname = os.path.join(cache_dir, filename)

     needs_fetch = True
     if os.path.exists(fullname):
         needs_fetch = bool(data_hash) and not md5file(fullname) == data_hash
     if needs_fetch:
         get_path_from_url(url, cache_dir, data_hash)
     return fullname
Exemple #10
0
def load_state_dict_from_url(url: str, path: str, md5: str = None):
    """
    Download a file from `url` into `path` and load it as a state dict.

    Args:
        url: Location of the file to download.
        path: Directory to download into; created when it does not exist.
        md5: Optional checksum forwarded to the downloader.

    Returns:
        The result of ``load_state_dict`` on ``<path>/<basename of url>``.
    """
    # exist_ok avoids the check-then-create race of a separate isdir() test
    # when several processes prepare the same directory at once.
    os.makedirs(path, exist_ok=True)

    download.get_path_from_url(url, path, md5)
    return load_state_dict(os.path.join(path, os.path.basename(url)))
Exemple #11
0
def main(arguments):
    """Parse the command line and download URL into the chosen directory."""
    data_dir_option = {
        'help': 'directory to save data to',
        'type': str,
        'default': './',
    }
    cli = argparse.ArgumentParser()
    cli.add_argument('-d', '--data_dir', **data_dir_option)
    options = cli.parse_args(arguments)
    get_path_from_url(URL, options.data_dir)
Exemple #12
0
def download_and_decompress(archives: List[Dict[str, str]], path: str):
    """
    Download archives and decompress them to a specific path.

    Args:
        archives: Dictionaries that must each contain the keys "url" and
            "md5".
        path: Directory handed to the downloader as the target location.

    Raises:
        AssertionError: if an archive lacks the "url" or "md5" key.
    """
    for archive in archives:
        # The message is an f-string on `archive`, so the offending keys are
        # actually reported instead of a literal placeholder.
        assert 'url' in archive and 'md5' in archive, \
            f'Dictionary keys of "url" and "md5" are required in the archive, but got: {list(archive.keys())}'

        logger.info(f'Downloading from: {archive["url"]}')
        download.get_path_from_url(archive['url'], path, archive['md5'])
Exemple #13
0
 def test_get_path_from_url(self):
     """Download every sample archive (tar and zip) into ``./test``."""
     sample_archives = (
         "https://paddle-hapi.bj.bcebos.com/unittest/files.tar",
         "https://paddle-hapi.bj.bcebos.com/unittest/files.zip",
         "https://paddle-hapi.bj.bcebos.com/unittest/single_dir.tar",
         "https://paddle-hapi.bj.bcebos.com/unittest/single_dir.zip",
         "https://paddle-hapi.bj.bcebos.com/unittest/single_file.tar",
         "https://paddle-hapi.bj.bcebos.com/unittest/single_file.zip",
     )
     download_root = './test'
     for archive_url in sample_archives:
         get_path_from_url(archive_url, root_dir=download_root)
Exemple #14
0
def download_data(url):
    """Download `url` into the hub data home and return a pass-through decorator.

    The download is skipped when a path named after the part of the URL's
    basename before its first dot already exists in the data home. The
    returned decorator hands back the decorated object unchanged.
    """
    archive_name = os.path.basename(url)
    stem, _, _ = archive_name.partition('.')
    expected_path = os.path.join(hubenv.DATA_HOME, stem)

    already_present = os.path.exists(expected_path)
    if not already_present:
        get_path_from_url(url, hubenv.DATA_HOME)

    def _wrapper(Dataset):
        # Identity decorator: the download above is the only effect.
        return Dataset

    return _wrapper
Exemple #15
0
 def _get_data(self, mode, **kwargs):
     """Return the paths of all files of the `mode` split.

     Each file is downloaded from its own URL when it is missing or fails
     its md5 check. Files live in a per-mode sub-directory.
     """
     split_dir = os.path.join(DATA_HOME, self.__class__.__name__, mode)
     resolved = []
     for filename, data_hash, file_url in self.SPLITS[mode]:
         local_path = os.path.join(split_dir, filename)
         intact = os.path.exists(local_path) and not (
             data_hash and not md5file(local_path) == data_hash)
         if not intact:
             get_path_from_url(file_url, split_dir)
         resolved.append(local_path)
     return resolved
Exemple #16
0
def download_and_decompress(archives: List[Dict[str, str]], path: str):
    """
    Download archives and decompress them to a specific path.

    Args:
        archives: Dictionaries that must each contain the keys "url" and
            "md5".
        path: Target directory; created when it does not exist.

    Raises:
        AssertionError: if an archive lacks the "url" or "md5" key.
    """
    # exist_ok avoids the check-then-create race of a separate isdir() test.
    os.makedirs(path, exist_ok=True)

    for archive in archives:
        # The message is an f-string on `archive`, so the offending keys are
        # actually reported instead of a literal placeholder.
        assert 'url' in archive and 'md5' in archive, \
            f'Dictionary keys of "url" and "md5" are required in the archive, but got: {list(archive.keys())}'

        download.get_path_from_url(archive['url'], path, archive['md5'])
Exemple #17
0
    def _get_data(self, mode, **kwargs):
        """Return the path of the `mode` split, fetching the archive if needed.

        The archive from the builder config is downloaded into DATA_HOME
        when the split file is missing or fails its md5 check.
        """
        builder_config = self.BUILDER_CONFIGS[self.name]
        dataset_dir = os.path.join(DATA_HOME, 'SE-ABSA16_PHNS')
        filename, data_hash, _unused_a, _unused_b = builder_config['splits'][
            mode]
        fullname = os.path.join(dataset_dir, filename)

        file_ok = os.path.exists(fullname) and not (
            data_hash and not md5file(fullname) == data_hash)
        if file_ok:
            return fullname

        get_path_from_url(builder_config['url'], DATA_HOME,
                          builder_config['md5'])
        return fullname
Exemple #18
0
def main(arguments):
    """Parse the command line and download the selected task's data."""
    option_specs = (
        (('-d', '--data_dir'), {
            'help': 'directory to save data to',
            'type': str,
            'default': 'data',
        }),
        (('-t', '--task'), {
            'help': 'tasks to download data for as a comma separated string',
            'type': str,
            'default': 'ptb',
        }),
    )
    cli = argparse.ArgumentParser()
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)
    options = cli.parse_args(arguments)
    get_path_from_url(URL[options.task], options.data_dir)
Exemple #19
0
 def _get_data(self, mode, **kwargs):
     """Return the ``aclImdb/<mode>`` directory, downloading the archive if absent."""
     dataset_root = os.path.join(DATA_HOME, self.__class__.__name__)
     split_dir = os.path.join(dataset_root, "aclImdb", mode)
     if os.path.exists(split_dir):
         return split_dir
     get_path_from_url(self.URL, dataset_root, self.MD5)
     return split_dir
    def __init__(self):
        """Download the weights next to this module and build the network in eval mode."""
        module_dir = osp.dirname(osp.realpath(__file__))
        weights_file = get_path_from_url(BISENET_WEIGHT_URL, module_dir)

        backbone = HRNet_W18()
        self.net = FCN(num_classes=2, backbone=backbone)
        self.net.set_state_dict(paddle.load(weights_file))
        self.net.eval()
Exemple #21
0
    def _download_data(cls, mode="train", root=None):
        """Download the dataset when any required file is missing or corrupt.

        Args:
            mode: Key into ``cls.SPLITS`` selecting the source/target files.
            root: Optional directory (user-expanded) in which to look for the
                files; the default dataset directory is used when None.

        Returns:
            The default directory when a download was triggered, otherwise
            `root` if it was given, else the default directory.
        """
        default_root = os.path.join(DATA_HOME, 'machine_translation',
                                    cls.__name__)
        src_filename, tgt_filename, src_data_hash, tgt_data_hash = cls.SPLITS[
            mode]

        filename_list = [
            src_filename, tgt_filename, cls.VOCAB_INFO[0], cls.VOCAB_INFO[1]
        ]
        data_hash_list = [
            src_data_hash, tgt_data_hash, cls.VOCAB_INFO[2], cls.VOCAB_INFO[3]
        ]
        base_dir = default_root if root is None else os.path.expanduser(root)

        for filename, data_hash in zip(filename_list, data_hash_list):
            fullname = os.path.join(base_dir, filename)
            if not os.path.exists(fullname) or (
                    data_hash and not md5file(fullname) == data_hash):
                if root is not None:  # not specified, and no need to warn
                    # Report the file that actually failed the check, not
                    # whichever name an earlier loop happened to leave behind.
                    warnings.warn(
                        'md5 check failed for {}, download {} data to {}'.
                        format(filename, cls.__name__, default_root))
                get_path_from_url(cls.URL, default_root, cls.MD5)
                return default_root
        return root if root is not None else default_root
Exemple #22
0
    def __init__(self,
                 input,
                 output,
                 weight_path=None,
                 colorization=False,
                 reference_dir=None,
                 mindim=360):
        """Store the run options and load the remastering network(s).

        When `weight_path` is None the weights are downloaded from
        DEEPREMASTER_WEIGHT_URL. The colorization network is only built when
        `colorization` is truthy.
        """
        self.input = input
        self.output = os.path.join(output, 'DeepRemaster')
        self.colorization = colorization
        self.reference_dir = reference_dir
        self.mindim = mindim

        if weight_path is None:
            weight_path = get_path_from_url(DEEPREMASTER_WEIGHT_URL, cur_path)

        # paddle.load is unpacked as a pair; only the first item is used.
        checkpoint, _ = paddle.load(weight_path)

        restorer = NetworkR()
        self.modelR = restorer
        restorer.load_dict(checkpoint['modelR'])
        restorer.eval()

        if not colorization:
            return
        colorizer = NetworkC()
        self.modelC = colorizer
        colorizer.load_dict(checkpoint['modelC'])
        colorizer.eval()
    def __init__(self, device="cpu"):
        """Set up the label mapper, download the weights and build the network."""
        # Integer-to-integer lookup for keys 0..18; several keys map to 0.
        self.mapper = {
            0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5,
            6: 0, 7: 11, 8: 12, 9: 0,
            10: 6, 11: 8, 12: 7, 13: 9, 14: 13,
            15: 0, 16: 0, 17: 10, 18: 0,
        }
        # NOTE: the mapper stays a plain dict; a paddle.to_tensor conversion
        # was left disabled here.
        module_dir = osp.dirname(osp.realpath(__file__))
        self.save_pth = get_path_from_url(BISENET_WEIGHT_URL, module_dir)

        self.net = BiSeNet(n_classes=19)

        normalize = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        self.transforms = T.Compose([
            normalize,
        ])
Exemple #24
0
    def __init__(self,
                 generator,
                 discriminator=None,
                 cycle_criterion=None,
                 idt_criterion=None,
                 gan_criterion=None,
                 l1_criterion=None,
                 l2_criterion=None,
                 pool_size=50,
                 direction='a2b',
                 lambda_a=10.,
                 lambda_b=10.,
                 is_train=True):
        """Initialize the PSGAN class.

        Parameters:
            generator: config passed to ``build_generator`` for ``netG``.
            discriminator: config passed to ``build_discriminator`` for both
                ``netD_A`` and ``netD_B`` (training only).
            cycle_criterion, idt_criterion, gan_criterion, l1_criterion,
            l2_criterion: optional configs; each one that is truthy is built
                with ``build_criterion`` and stored under the same name
                (training only).
            pool_size: size passed to both ``ImagePool`` buffers.
            direction: not used in this constructor.
            lambda_a, lambda_b: stored on the instance unchanged.
            is_train: when True, the VGG features, discriminators, image
                pools and criteria are set up as well.
        """
        super(MakeupModel, self).__init__()
        self.lambda_a = lambda_a
        self.lambda_b = lambda_b
        self.is_train = is_train
        # define networks (both Generators and discriminators)
        # The naming is different from those used in the paper.
        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.nets['netG'] = build_generator(generator)
        init_weights(self.nets['netG'], init_type='xavier', init_gain=1.0)

        if self.is_train:  # define discriminators
            # VGG16 is created without pretrained weights; the VGGFace
            # weights are downloaded next to this file and loaded below.
            vgg = vgg16(pretrained=False)
            self.vgg = vgg.features
            cur_path = os.path.abspath(os.path.dirname(__file__))
            vgg_weight_path = get_path_from_url(VGGFACE_WEIGHT_URL, cur_path)
            param = paddle.load(vgg_weight_path)
            vgg.load_dict(param)

            self.nets['netD_A'] = build_discriminator(discriminator)
            self.nets['netD_B'] = build_discriminator(discriminator)
            init_weights(self.nets['netD_A'],
                         init_type='xavier',
                         init_gain=1.0)
            init_weights(self.nets['netD_B'],
                         init_type='xavier',
                         init_gain=1.0)

            # create image buffer to store previously generated images
            self.fake_A_pool = ImagePool(pool_size)
            self.fake_B_pool = ImagePool(pool_size)

            # define loss functions
            # Each criterion attribute only exists when its config was given.
            if gan_criterion:
                self.gan_criterion = build_criterion(gan_criterion)
            if cycle_criterion:
                self.cycle_criterion = build_criterion(cycle_criterion)
            if idt_criterion:
                self.idt_criterion = build_criterion(idt_criterion)
            if l1_criterion:
                self.l1_criterion = build_criterion(l1_criterion)
            if l2_criterion:
                self.l2_criterion = build_criterion(l2_criterion)
Exemple #25
0
 def _get_data(self, root, mode):
     """Set ``self.data_path``, downloading the data when the path is absent.

     `root` is used as the data path when given; otherwise the default
     location is used. After a download the default location is always
     used.
     """
     cache_dir = os.path.join(DATA_HOME, 'lm')
     if root is None:
         self.data_path = os.path.join(cache_dir, self.DATA_PATH)
     else:
         self.data_path = root

     if os.path.exists(self.data_path):
         return
     get_path_from_url(self.DATA_URL, cache_dir)
     self.data_path = os.path.join(cache_dir, self.DATA_PATH)
Exemple #26
0
    def __init__(self, cfg):
        """Initialize the PSGAN class.

        Parameters:
            cfg -- config of the model. It is read through attribute access
                here: ``cfg.model.generator``, ``cfg.model.discriminator``,
                ``cfg.model.gan_mode``, ``cfg.dataset.train.pool_size`` and
                ``cfg.optimizer``.
        """
        super(MakeupModel, self).__init__(cfg)

        # define networks (both Generators and discriminators)
        # The naming is different from those used in the paper.
        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.nets['netG'] = build_generator(cfg.model.generator)
        init_weights(self.nets['netG'], init_type='xavier', init_gain=1.0)

        # NOTE(review): self.is_train is not assigned in this method;
        # presumably the parent constructor sets it - confirm.
        if self.is_train:  # define discriminators
            # VGG16 is created without pretrained weights; the VGGFace
            # weights are downloaded next to this file and loaded below.
            vgg = vgg16(pretrained=False)
            self.vgg = vgg.features
            cur_path = os.path.abspath(os.path.dirname(__file__))
            vgg_weight_path = get_path_from_url(VGGFACE_WEIGHT_URL, cur_path)
            param = paddle.load(vgg_weight_path)
            vgg.load_dict(param)

            self.nets['netD_A'] = build_discriminator(cfg.model.discriminator)
            self.nets['netD_B'] = build_discriminator(cfg.model.discriminator)
            init_weights(self.nets['netD_A'],
                         init_type='xavier',
                         init_gain=1.0)
            init_weights(self.nets['netD_B'],
                         init_type='xavier',
                         init_gain=1.0)

            self.fake_A_pool = ImagePool(
                cfg.dataset.train.pool_size
            )  # create image buffer to store previously generated images
            self.fake_B_pool = ImagePool(
                cfg.dataset.train.pool_size
            )  # create image buffer to store previously generated images
            # define loss functions
            self.criterionGAN = GANLoss(
                cfg.model.gan_mode)  #.to(self.device)  # define GAN loss.
            self.criterionCycle = paddle.nn.L1Loss()
            self.criterionIdt = paddle.nn.L1Loss()
            self.criterionL1 = paddle.nn.L1Loss()
            self.criterionL2 = paddle.nn.MSELoss()

            # The scheduler is built first because every optimizer below
            # receives self.lr_scheduler.
            self.build_lr_scheduler()
            # One optimizer per network, all sharing cfg.optimizer.
            self.optimizers['optimizer_G'] = build_optimizer(
                cfg.optimizer,
                self.lr_scheduler,
                parameter_list=self.nets['netG'].parameters())
            self.optimizers['optimizer_DA'] = build_optimizer(
                cfg.optimizer,
                self.lr_scheduler,
                parameter_list=self.nets['netD_A'].parameters())
            self.optimizers['optimizer_DB'] = build_optimizer(
                cfg.optimizer,
                self.lr_scheduler,
                parameter_list=self.nets['netD_B'].parameters())
Exemple #27
0
    def _get_data(self, root, mode, **kwargs):
        """Locate the `mode` split file and store its path in ``self.full_path``.

        The file is looked up under `root` (user-expanded) when given,
        otherwise under the default directory. When it is missing or fails
        the md5 check the data is downloaded into the default directory.
        """
        default_root = os.path.join(DATA_HOME, 'DuReader')
        filename, data_hash = self.SPLITS[mode]

        if root is None:
            search_dir = default_root
        else:
            search_dir = os.path.expanduser(root)
        fullname = os.path.join(search_dir, filename)

        intact = os.path.exists(fullname) and not (
            data_hash and not md5file(fullname) == data_hash)
        if not intact:
            if root is not None:  # not specified, and no need to warn
                message = 'md5 check failed for {}, download {} data to {}'.format(
                    filename, self.__class__.__name__, default_root)
                warnings.warn(message)
            get_path_from_url(self.DATA_URL, default_root)

        self.full_path = fullname
Exemple #28
0
 def _get_data(self, mode, **kwargs):
     """Check and download the dataset.

     Returns the source path for 'train', a ``(source, target)`` pair for
     'dev', and nothing for any other mode.
     """
     builder_config = self.BUILDER_CONFIGS[self.name]
     default_root = os.path.join(DATA_HOME, self.__class__.__name__)

     def _ensure(relative_dir):
         # Download the archive when the given path is absent; return the
         # full path either way.
         full_dir = os.path.join(default_root, relative_dir)
         if not os.path.exists(full_dir):
             get_path_from_url(builder_config['url'], default_root,
                               builder_config['md5'])
         return full_dir

     split_entry = builder_config['splits'][mode]
     source_full_dir = _ensure(split_entry[0])
     if mode == 'train':
         return source_full_dir
     if mode == 'dev':
         target_full_dir = _ensure(split_entry[1])
         return source_full_dir, target_full_dir
Exemple #29
0
    def __init__(self, input, output, batch_size=1, weight_path=None):
        """Store the paths, build the RRDBNet model and load its weights.

        When `weight_path` is None the weights are downloaded from
        REALSR_WEIGHT_URL.
        """
        self.input = input
        self.output = os.path.join(output, 'RealSR')

        network = RRDBNet(3, 3, 64, 23)
        self.model = network

        checkpoint_path = weight_path
        if checkpoint_path is None:
            checkpoint_path = get_path_from_url(REALSR_WEIGHT_URL, cur_path)

        # paddle.load is unpacked as a pair; only the first item is used.
        weights, _ = paddle.load(checkpoint_path)
        network.load_dict(weights)
        network.eval()
    def __init__(self,
                 output='output',
                 weight_path=None,
                 config=None,
                 relative=False,
                 adapt_scale=False,
                 find_best_frame=False,
                 best_frame=None):
        """Resolve the model config and weights, then load the checkpoints.

        Args:
            output: Stored unchanged as ``self.output``.
            weight_path: Weights handed to ``load_checkpoints``. A default
                file is downloaded only when both `config` and `weight_path`
                are None.
            config: A string passed to ``yaml.load``, a dict used as-is, or
                None to use the built-in default parameters.
            relative, adapt_scale, find_best_frame, best_frame: Stored
                unchanged on the instance.
        """
        if config is not None and isinstance(config, str):
            # NOTE(review): yaml.load is called without a Loader and on the
            # string itself (not on an opened file). Confirm the input is
            # trusted and that the installed PyYAML accepts this call form.
            self.cfg = yaml.load(config)
        elif isinstance(config, dict):
            self.cfg = config
        elif config is None:
            # Built-in default model parameters.
            self.cfg = {
                'model_params': {
                    'common_params': {
                        'num_kp': 10,
                        'num_channels': 3,
                        'estimate_jacobian': True
                    },
                    'kp_detector_params': {
                        'temperature': 0.1,
                        'block_expansion': 32,
                        'max_features': 1024,
                        'scale_factor': 0.25,
                        'num_blocks': 5
                    },
                    'generator_params': {
                        'block_expansion': 64,
                        'max_features': 512,
                        'num_down_blocks': 2,
                        'num_bottleneck_blocks': 6,
                        'estimate_occlusion_map': True,
                        'dense_motion_params': {
                            'block_expansion': 64,
                            'max_features': 1024,
                            'num_blocks': 5,
                            'scale_factor': 0.25
                        }
                    }
                }
            }
            # NOTE(review): this download is nested under `config is None`,
            # so a str/dict config with weight_path=None reaches
            # load_checkpoints with None - confirm this is intended.
            if weight_path is None:
                vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-cpk.pdparams'
                cur_path = os.path.abspath(os.path.dirname(__file__))
                weight_path = get_path_from_url(vox_cpk_weight_url, cur_path)

        # NOTE(review): a config of any other type leaves self.cfg unset, so
        # the load_checkpoints call below would fail on the attribute lookup.
        self.weight_path = weight_path
        self.output = output
        self.relative = relative
        self.adapt_scale = adapt_scale
        self.find_best_frame = find_best_frame
        self.best_frame = best_frame
        self.generator, self.kp_detector = self.load_checkpoints(
            self.cfg, self.weight_path)