Example No. 1 (votes: 0)
 def from_pretrained(cls,
                     pretrain_dir_or_url,
                     force_download=False,
                     **kwargs):
     """Build a tokenizer instance from a pretrained vocab.

     Args:
         pretrain_dir_or_url: either a key of ``cls.resource_map`` (resolved
             to a remote URL and downloaded) or a local directory path.
         force_download: if True, re-download even when a cached copy exists.
         **kwargs: extra arguments forwarded to the class constructor.

     Returns:
         An instance of ``cls`` built from the ``vocab.txt`` found in the
         pretrain directory.

     Raises:
         ValueError: if the pretrain directory or its ``vocab.txt`` is
             missing.
     """
     if pretrain_dir_or_url in cls.resource_map:
         url = cls.resource_map[pretrain_dir_or_url]
         log.info('get pretrain dir from %s' % url)
         pretrain_dir = _fetch_from_remote(url,
                                           force_download=force_download)
     else:
         # Unknown key: treat the argument as a local directory path.
         log.info('pretrain dir %s not in %s, read from local' %
                  (pretrain_dir_or_url, repr(cls.resource_map)))
         pretrain_dir = pretrain_dir_or_url
     if not os.path.exists(pretrain_dir):
         raise ValueError('pretrain dir not found: %s' % pretrain_dir)
     vocab_path = os.path.join(pretrain_dir, 'vocab.txt')
     if not os.path.exists(vocab_path):
         raise ValueError('no vocab file in pretrain dir: %s' %
                          pretrain_dir)
     # Fix: the original `open(vocab_path).readlines()` leaked the file
     # handle and relied on the platform default encoding. Use a context
     # manager with an explicit encoding and iterate the file lazily.
     # Token is the first tab-separated field; its id is the line index.
     with open(vocab_path, encoding='utf-8') as f:
         vocab_dict = {
             line.strip().split('\t')[0]: i
             for i, line in enumerate(f)
         }
     t = cls(vocab_dict, **kwargs)
     return t
Example No. 2 (votes: 0)
    def from_pretrained(cls,
                        pretrain_dir_or_url,
                        force_download=False,
                        **kwargs):
        """Build a model instance from a pretrained checkpoint directory.

        Args:
            pretrain_dir_or_url: either a key of ``cls.resource_map``
                (resolved to a remote URL and downloaded) or a local
                directory path containing ``ernie_config.json`` and a
                ``saved_weights.pdparams`` state dict.
            force_download: if True, re-download even when a cached copy
                exists.
            **kwargs: overrides merged into the loaded config dict; the
                special key ``name`` is popped and passed to the constructor
                as the parameter-name prefix.

        Returns:
            An instance of ``cls`` with the pretrained weights loaded.

        Raises:
            ValueError: if the pretrain dir, its config file, or its weight
                file is missing.
        """
        if pretrain_dir_or_url in cls.resource_map:
            url = cls.resource_map[pretrain_dir_or_url]
            log.info('get pretrain dir from %s' % url)
            # Pass force_download by keyword, consistent with the tokenizer's
            # from_pretrained, and robust if the helper grows parameters.
            pretrain_dir = _fetch_from_remote(url,
                                              force_download=force_download)
        else:
            # Unknown key: treat the argument as a local directory path.
            log.info('pretrain dir %s not in %s, read from local' %
                     (pretrain_dir_or_url, repr(cls.resource_map)))
            pretrain_dir = pretrain_dir_or_url

        if not os.path.exists(pretrain_dir):
            raise ValueError('pretrain dir not found: %s' % pretrain_dir)
        state_dict_path = os.path.join(pretrain_dir, 'saved_weights')
        config_path = os.path.join(pretrain_dir, 'ernie_config.json')

        if not os.path.exists(config_path):
            raise ValueError('config path not found: %s' % config_path)
        name_prefix = kwargs.pop('name', None)
        # Fix: the original `open(config_path).read()` leaked the file
        # handle; use a context manager and json.load directly.
        # kwargs entries override values from the config file.
        with open(config_path, encoding='utf-8') as f:
            cfg_dict = dict(json.load(f), **kwargs)
        model = cls(cfg_dict, name=name_prefix)

        log.info('loading pretrained model from %s' % pretrain_dir)

        if os.path.exists(state_dict_path + '.pdparams'):
            m, _ = D.load_dygraph(state_dict_path)
            for k, v in model.state_dict().items():
                if k not in m:
                    # Keep the randomly initialized value for params absent
                    # from the checkpoint so set_dict does not complain.
                    log.warn('param:%s not set in pretrained model, skip' % k)
                    m[k] = v  # FIXME: no need to do this in the future
            model.set_dict(m)
        else:
            raise ValueError('weight file not found in pretrain dir: %s' %
                             pretrain_dir)
        return model