Esempio n. 1
0
def create_agent_from_opt_file_and_model_class(opt, model_class):
    """
    Instantiate ``model_class`` from the options saved next to a model file.

    The saved options are read from ``opt['model_file'] + '.opt'`` and then
    adjusted in this order: keys listed in ``NOCOPY_ARGS`` are removed, entries
    of ``opt['override']`` are applied, ``model_class.upgrade_opt`` is called
    when the class defines it, keys present in ``opt`` but missing from the
    saved options are copied over, and finally ``model_file``, ``init_model``
    and ``dict_file`` are repointed.

    :param opt:
        options for this run; must contain ``'model_file'``.
    :param model_class:
        callable invoked with the merged options to build the agent.

    :return:
        ``model_class(opt_from_file)``, or ``None`` when the ``.opt`` file
        does not exist.
    """
    model_file = opt['model_file']
    optfile = model_file + '.opt'

    if not PathManager.exists(optfile):
        return None

    opt_from_file = Opt.load(optfile)

    # delete args that we do not want to copy over when loading the model
    for arg in NOCOPY_ARGS:
        if arg in opt_from_file:
            del opt_from_file[arg]

    # only override opts specified in 'override' dict
    if opt.get('override'):
        for k, v in opt['override'].items():
            # values are compared as strings, so the warning is only emitted
            # when the printed form of the value actually changes
            if k in opt_from_file and str(v) != str(opt_from_file.get(k)):
                logging.warn(
                    f'Overriding opt["{k}"] to {v} (previously: {opt_from_file.get(k)})'
                )
            opt_from_file[k] = v

    if hasattr(model_class, 'upgrade_opt'):
        opt_from_file = model_class.upgrade_opt(opt_from_file)

    # add model arguments to opt_from_file if they aren't in opt_from_file already
    for k, v in opt.items():
        if k not in opt_from_file:
            opt_from_file[k] = v

    # update model file path to the one set by opt
    opt_from_file['model_file'] = model_file
    # update init model path to the one set by opt
    # NOTE: this step is necessary when for example the 'init_model' is
    # set by the Train Loop (as is the case when loading from checkpoint)
    if opt.get('init_model') is not None:
        opt_from_file['init_model'] = opt['init_model']

    # update dict file path
    # Bind old_dict_file unconditionally: previously it was only assigned
    # inside the fallback branches, so the warning below could raise
    # UnboundLocalError if the two existence checks disagreed.
    old_dict_file = opt_from_file.get('dict_file') or None
    if old_dict_file is None or not PathManager.exists(old_dict_file):
        # no usable dict file recorded: fall back to the one beside the model
        opt_from_file['dict_file'] = model_file + '.dict'
    if not PathManager.exists(opt_from_file['dict_file']):
        warn_once(
            'WARNING: Neither the specified dict file ({}) nor the '
            '`model_file`.dict file ({}) exists, check to make sure either '
            'is correct. This may manifest as a shape mismatch later '
            'on.'.format(old_dict_file, opt_from_file['dict_file']))

    # if we want to load weights from --init-model, compare opts with
    # loaded ones
    compare_init_model_opts(opt, opt_from_file)
    return model_class(opt_from_file)
    def _initialize_bart(self, opt: Opt) -> Opt:
        """
        Prepare ``opt`` for BART: fetch the pre-trained weights and apply BART args.

        Unless ``opt['converting']`` is truthy, ``download`` is called with the
        data path and ``opt['init_model']`` is pointed at the BART-large model
        under it. When ``opt['init_fairseq_model']`` is truthy, the options are
        passed through ``self._convert_model`` first.

        :param opt:
            ParlAI-parsed options

        :return opt:
            the options updated with ``BART_ARGS``.
        """
        needs_conversion = opt.get('init_fairseq_model')

        if not opt.get('converting'):
            data_root = opt['datapath']
            download(data_root)
            bart_model_path = os.path.join(data_root, 'models/bart/bart_large/model')
            opt['init_model'] = bart_model_path

        if needs_conversion:
            opt = self._convert_model(opt)

        opt.update(BART_ARGS)
        # the same options are passed as both arguments
        compare_init_model_opts(opt, opt)
        return opt