def create_agent(opt, requireModelExists=False):
    """
    Build an agent from the ``model`` and ``model_file`` options.

    The ``model`` option is either of the form
    ``parlai.agents.ir_baseline.agents:IrBaselineAgent`` (i.e. the path
    followed by the class name) or else just ``ir_baseline``, which assumes
    the path above and a class name suffixed with 'Agent'.

    When ``model_file`` is set, loading the agent from that location is tried
    first, so that the remaining options (including the model name) need not
    be given: they are read from the options file if it exists (the file
    ``opt['model_file'] + '.opt'``, containing a pickled or json dict of the
    model's options).

    :param opt:
        dict-like options. Modified in place: when ``datapath`` is missing,
        keys produced by parsing an empty argument string are added for every
        key not already present, and a non-empty ``model_file`` is replaced by
        the result of ``modelzoo_path``.
    :param requireModelExists:
        if true, raise when ``model_file`` is set but is not an existing
        file, and print a warning when the created agent has a ``load``
        attribute but no ``model_file`` was given.

    :return:
        the agent produced by ``load_agent_module`` if it returns one,
        otherwise a new instance of the class resolved from ``opt['model']``.

    :raises RuntimeError:
        if the required model file does not exist, or if no agent could be
        loaded from ``model_file`` and ``model`` is not set.
    """
    if opt.get('datapath') is None:
        # The datapath is missing: parse an empty argument string and copy
        # over whatever keys the options do not have yet.
        from parlai.core.params import ParlaiParser, get_model_name

        arg_parser = ParlaiParser(add_parlai_args=False)
        arg_parser.add_parlai_data_path()
        # Also register the model's own arguments when a model name is known.
        model_name = get_model_name(opt)
        if model_name is not None:
            arg_parser.add_model_subargs(model_name)
        parsed = arg_parser.parse_args("", print_args=False)
        for key, value in parsed.items():
            if key in opt:
                # never overwrite an option the caller already supplied
                continue
            opt[key] = value

    if opt.get('model_file'):
        opt['model_file'] = modelzoo_path(opt.get('datapath'), opt['model_file'])
        if requireModelExists and not os.path.isfile(opt['model_file']):
            raise RuntimeError(
                'WARNING: Model file does not exist, check to make '
                'sure it is correct: {}'.format(opt['model_file'])
            )
        # Try the model file first; this way the model name does not even
        # have to be specified as a parameter.
        loaded_agent = load_agent_module(opt)
        if loaded_agent is not None:
            return loaded_agent
        print("[ no model with opt yet at: " + opt.get('model_file') + "(.opt) ]")

    if not opt.get('model'):
        raise RuntimeError('Need to set `model` argument to use create_agent.')

    agent_class = get_agent_module(opt['model'])
    # If weights are to be loaded from --init-model, compare opts with the
    # loaded ones.
    compare_init_model_opts(opt, opt)
    agent = agent_class(opt)
    if requireModelExists and hasattr(agent, 'load') and not opt.get('model_file'):
        # double check that model_file was not forgotten on a loadable model
        print('WARNING: model_file unset but model has a `load` function.')
    return agent
def add_datapath_and_model_args(opt: Opt):
    """
    Add the datapath and any missing model arguments to ``opt`` in place.

    A ``ParlaiParser`` without the general ParlAI arguments is created, the
    data path arguments are registered on it and, when ``get_model_name``
    finds a model name in ``opt``, that model's arguments as well. An empty
    argument string is then parsed (presumably yielding the parser defaults;
    confirm against ``ParlaiParser``) and every resulting key that ``opt``
    does not already contain is copied into it. Existing keys are never
    overwritten.

    :param opt:
        options to complete; modified in place.
    """
    # Imported inside the function, as in the original code (presumably to
    # avoid a circular import with parlai.core.params; confirm).
    from parlai.core.params import ParlaiParser, get_model_name

    arg_parser = ParlaiParser(add_parlai_args=False)
    arg_parser.add_parlai_data_path()

    # Register the model's own arguments only when a model name is known.
    model_name = get_model_name(opt)
    if model_name is not None:
        arg_parser.add_model_subargs(model_name, opt)

    parsed = arg_parser.parse_args("")
    for key, value in parsed.items():
        if key in opt:
            # keep whatever the caller already set
            continue
        opt[key] = value