def from_pretrained(cls, pretrain_dir_or_url, force_download=False, **kwargs):
    """Build an instance of ``cls`` from a named release or a local dir.

    Args:
        pretrain_dir_or_url: a key of ``cls.resource_map`` (the archive is
            fetched and cached via ``_fetch_from_remote``) or a local
            directory path containing ``vocab.txt``.
        force_download: re-download even if a cached copy exists.
        **kwargs: forwarded to the ``cls`` constructor.

    Returns:
        A new ``cls`` built from the vocab found in the pretrain directory.

    Raises:
        ValueError: if the directory or its ``vocab.txt`` is missing.
    """
    if pretrain_dir_or_url in cls.resource_map:
        url = cls.resource_map[pretrain_dir_or_url]
        log.info('get pretrain dir from %s' % url)
        pretrain_dir = _fetch_from_remote(url, force_download=force_download)
    else:
        log.info('pretrain dir %s not in %s, read from local' %
                 (pretrain_dir_or_url, repr(cls.resource_map)))
        pretrain_dir = pretrain_dir_or_url
    if not os.path.exists(pretrain_dir):
        raise ValueError('pretrain dir not found: %s' % pretrain_dir)
    vocab_path = os.path.join(pretrain_dir, 'vocab.txt')
    if not os.path.exists(vocab_path):
        raise ValueError('no vocab file in pretrain dir: %s' % pretrain_dir)
    # Map token text (first TSV column) -> integer id (line number).
    # FIX: use a context manager with an explicit encoding — the original
    # leaked the file handle and relied on the platform default encoding
    # (the sibling loader in this file reads vocab.txt as utf8).
    with open(vocab_path, encoding='utf8') as f:
        vocab_dict = {
            line.strip().split('\t')[0]: i
            for i, line in enumerate(f)
        }
    t = cls(vocab_dict, **kwargs)
    return t
def from_pretrained(cls, pretrain_dir_or_url, force_download=False, **kwargs):
    """Resolve a pretrain directory and load its ERNIE config.

    Args:
        pretrain_dir_or_url: a key of ``cls.resource_map`` (fetched and
            cached via ``_fetch_from_remote``) or a local directory path.
        force_download: re-download even if a cached copy exists.
        **kwargs: merged into the returned config dict; the reserved key
            ``name`` is popped and excluded from the config.

    Returns:
        Tuple ``(cfg_dict, param_path)`` — the parsed ``ernie_config.json``
        (with ``kwargs`` overrides applied) and the path to the ``params``
        directory inside the pretrain dir.

    Raises:
        ValueError: if the pretrain dir or its config file is missing.
    """
    if pretrain_dir_or_url in cls.resource_map:
        url = cls.resource_map[pretrain_dir_or_url]
        log.info('get pretrain dir from %s' % url)
        pretrain_dir = _fetch_from_remote(url, force_download)
    else:
        log.info('pretrain dir %s not in %s, read from local' %
                 (pretrain_dir_or_url, repr(cls.resource_map)))
        pretrain_dir = pretrain_dir_or_url
    if not os.path.exists(pretrain_dir):
        raise ValueError('pretrain dir not found: %s' % pretrain_dir)
    param_path = os.path.join(pretrain_dir, 'params')
    config_path = os.path.join(pretrain_dir, 'ernie_config.json')
    if not os.path.exists(config_path):
        raise ValueError('config path not found: %s' % config_path)
    # 'name' is reserved for model construction; keep it out of the config.
    # (The original bound it to an unused local; only the pop matters.)
    kwargs.pop('name', None)
    # FIX: close the config file (original leaked the handle) and read it
    # with an explicit utf8 encoding; also removed the unused
    # ``state_dict_path`` local and the commented-out model construction.
    with open(config_path, encoding='utf8') as f:
        cfg_dict = dict(json.load(f), **kwargs)
    log.info('loading pretrained model from %s' % pretrain_dir)
    return cfg_dict, param_path
def from_pretrained(cls, pretrain_dir_or_url, force_download=False, **kwargs):
    """Build an instance of ``cls`` from a named release or a local dir.

    Args:
        pretrain_dir_or_url: a key of ``cls.resource_map`` (fetched and
            cached via ``_fetch_from_remote``) or a local directory path
            containing ``vocab.txt`` and the sentencepiece model.
        force_download: re-download even if a cached copy exists.
        **kwargs: forwarded to the ``cls`` constructor.

    Returns:
        A new ``cls`` built from the vocab and sentencepiece model paths.

    Raises:
        ValueError: if the directory or its ``vocab.txt`` is missing.
    """
    if pretrain_dir_or_url in cls.resource_map:
        url = cls.resource_map[pretrain_dir_or_url]
        log.info('get pretrain dir from %s' % url)
        # FIX: wrap in Path — if _fetch_from_remote returns a plain str,
        # the `pretrain_dir / 'vocab.txt'` below would raise TypeError.
        pretrain_dir = Path(_fetch_from_remote(url, force_download))
    else:
        log.info('pretrain dir %s not in %s, read from local' %
                 (pretrain_dir_or_url, repr(cls.resource_map)))
        pretrain_dir = Path(pretrain_dir_or_url)
    if not pretrain_dir.exists():
        raise ValueError('pretrain dir not found: %s' % pretrain_dir)
    vocab_path = pretrain_dir / 'vocab.txt'
    sp_model_path = pretrain_dir / 'subword/spm_cased_simp_sampled.model'
    if not vocab_path.exists():
        raise ValueError('no vocab file in pretrain dir: %s' % pretrain_dir)
    # Map token text (first TSV column) -> integer id (line number).
    # FIX: use a context manager so the vocab file handle is closed
    # (original left it open).
    with vocab_path.open(encoding='utf8') as f:
        vocab_dict = {
            line.strip().split('\t')[0]: i
            for i, line in enumerate(f)
        }
    t = cls(vocab_dict, sp_model_path, **kwargs)
    return t
def get_config(pretrain_dir_or_url):
    """Download (with caching) a named ERNIE release and return its config.

    Args:
        pretrain_dir_or_url: one of the known release names in the
            resource map below ('ernie-1.0', 'ernie-2.0-en',
            'ernie-2.0-large-en', 'ernie-tiny').

    Returns:
        The parsed ``ernie_config.json`` of the release.

    Raises:
        KeyError: if the name is not in the resource map.
        ValueError: if the fetched archive has no ``ernie_config.json``.
    """
    bce = 'https://ernie-github.cdn.bcebos.com/'
    resource_map = {
        'ernie-1.0': bce + 'model-ernie1.0.1.tar.gz',
        'ernie-2.0-en': bce + 'model-ernie2.0-en.1.tar.gz',
        'ernie-2.0-large-en': bce + 'model-ernie2.0-large-en.1.tar.gz',
        'ernie-tiny': bce + 'model-ernie_tiny.1.tar.gz',
    }
    url = resource_map[pretrain_dir_or_url]
    pretrain_dir = _fetch_from_remote(url, False)
    config_path = os.path.join(pretrain_dir, 'ernie_config.json')
    if not os.path.exists(config_path):
        raise ValueError('config path not found: %s' % config_path)
    # FIX: close the config file (original leaked the handle), read as utf8,
    # and drop the redundant dict(...) copy of json's own dict.
    with open(config_path, encoding='utf8') as f:
        cfg_dict = json.load(f)
    return cfg_dict
def from_pretrained(cls, pretrain_dir_or_url, force_download=False, **kwargs):
    """Build a ``cls`` model from a named release or local dir and load
    its pretrained weights.

    Args:
        pretrain_dir_or_url: a key of ``cls.resource_map`` (fetched and
            cached via ``_fetch_from_remote``) or a local directory path;
            an existing local path takes precedence over a map key.
        force_download: re-download even if a cached copy exists.
        **kwargs: merged into the config dict; the reserved key ``name``
            is popped and passed to the constructor as ``name=``.

    Returns:
        The constructed model with ``saved_weights.pdparams`` applied.

    Raises:
        ValueError: if the pretrain dir, its config, or its weight file
            is missing.
    """
    if (not Path(pretrain_dir_or_url).exists()
            and pretrain_dir_or_url in cls.resource_map):
        url = cls.resource_map[pretrain_dir_or_url]
        log.info('get pretrain dir from %s' % url)
        # FIX: wrap in Path — if _fetch_from_remote returns a plain str,
        # the `pretrain_dir / ...` joins below would raise TypeError.
        pretrain_dir = Path(_fetch_from_remote(url, force_download))
    else:
        log.info('pretrain dir %s not in %s, read from local' %
                 (pretrain_dir_or_url, repr(cls.resource_map)))
        pretrain_dir = Path(pretrain_dir_or_url)
    if not pretrain_dir.exists():
        raise ValueError('pretrain dir not found: %s' % pretrain_dir)
    state_dict_path = pretrain_dir / 'saved_weights.pdparams'
    config_path = pretrain_dir / 'ernie_config.json'
    if not config_path.exists():
        raise ValueError('config path not found: %s' % config_path)
    # Guard clause: fail before building the model if weights are absent
    # (same observable behavior as the original check further down).
    if not state_dict_path.exists():
        raise ValueError(
            'weight file not found in pretrain dir: %s' % pretrain_dir)
    name_prefix = kwargs.pop('name', None)
    # FIX: read_text() closes the file (original leaked the handle);
    # removed the commented-out legacy static-graph loading code.
    cfg_dict = dict(
        json.loads(config_path.read_text(encoding='utf8')), **kwargs)
    model = cls(cfg_dict, name=name_prefix)
    log.info('loading pretrained model from %s' % pretrain_dir)
    # NOTE(review): P.load presumably is paddle.load — confirm it accepts a
    # pathlib.Path; the original passed one as well.
    m = P.load(state_dict_path)
    for k, v in model.state_dict().items():
        if k not in m:
            log.warn('param:%s not set in pretrained model, skip' % k)
            m[k] = v  # FIXME: no need to do this in the future
    model.set_state_dict(m)
    return model
def get_config(pretrain_dir_or_url):
    """Return the ERNIE config dict from a known release name or local dir.

    Args:
        pretrain_dir_or_url: one of the known release names in the resource
            map below (fetched and cached via ``_fetch_from_remote``) or a
            local directory path; an existing local path takes precedence.

    Returns:
        The parsed ``ernie_config.json``.

    Raises:
        ValueError: if no ``ernie_config.json`` is found.
    """
    bce = 'https://ernie-github.cdn.bcebos.com/'
    resource_map = {
        'ernie-1.0': bce + 'model-ernie1.0.1.tar.gz',
        'ernie-2.0-en': bce + 'model-ernie2.0-en.1.tar.gz',
        'ernie-2.0-large-en': bce + 'model-ernie2.0-large-en.1.tar.gz',
        'ernie-tiny': bce + 'model-ernie_tiny.1.tar.gz',
    }
    if (not Path(pretrain_dir_or_url).exists()
            and str(pretrain_dir_or_url) in resource_map):
        url = resource_map[pretrain_dir_or_url]
        # FIX: wrap in Path — if _fetch_from_remote returns a plain str,
        # the pathlib join below would raise TypeError.
        pretrain_dir = Path(_fetch_from_remote(url, False))
    else:
        log.info('pretrain dir %s not in %s, read from local' %
                 (pretrain_dir_or_url, repr(resource_map)))
        pretrain_dir = Path(pretrain_dir_or_url)
    # FIX: use pathlib consistently (original mixed os.path with Path),
    # close the config file via read_text() (original leaked the handle),
    # and drop the redundant dict(...) copy of json's own dict.
    config_path = pretrain_dir / 'ernie_config.json'
    if not config_path.exists():
        raise ValueError('config path not found: %s' % config_path)
    cfg_dict = json.loads(config_path.read_text(encoding='utf8'))
    return cfg_dict