コード例 #1
0
    def _parse_config(self, opt: Opt):
        """
        Parse config for task.

        Use this to parse all options and settings necessary to set the variables for
        the conversation.

        :param opt:
            options for the chat service. ``is_debug`` (default True) and
            ``max_workers`` (default 30) are read here; a deep copy of the whole
            opt is handed to the world runner, and the original opt is passed on
            to ``parse_additional_args``.
        """
        self.debug = opt.get('is_debug', True)
        # The overworld name and the module path holding the worlds are fixed
        # here rather than read from the opt.
        self.overworld = 'MessengerBotChatTaskWorld'
        self.world_path = 'chat_service.core.worlds'
        self.world_module = utils.get_world_module(self.world_path)
        self.max_workers = opt.get('max_workers', 30)
        # Deepcopy the opts so the manager opts aren't changed by the world runner
        self.runner_opt = copy.deepcopy(opt)
        self.world_runner = ChatServiceWorldRunner(
            self.runner_opt, self.world_module, self.max_workers, self,
            self.debug)  # Replace with base runner
        self.onboard_map = None  #'chat_service.core.onboardworld.OnboardWorld'
        self.taskworld_map = None  #'chatbot'
        # Assigned exactly once; an earlier throwaway value of 1 was always
        # overwritten before anything could read it.
        self.max_agents_for = 5
        self.service_reference_id = None
        self.parse_additional_args(opt)
コード例 #2
0
    def __init__(self,
                 opt,
                 datatype: str = 'train',
                 seed: Optional[int] = None):
        """
        Initialize the context generator.

        opt: only a 'datapath' key is required, to specify the ParlAI data folder
        datatype: datatype handed to each of the three teachers built here
        seed: optional seed for the generator's private random number generator
        """
        # random.Random(None) seeds itself the same way random.Random() does,
        # so one call covers both the seeded and the unseeded case.
        self.rng = random.Random(seed)

        data_folder = opt['datapath']

        self.convai2_teacher = BothTeacher(
            Opt({'datapath': data_folder, 'datatype': datatype}))

        # train_experiencer_only=True because we want to ensure that the text
        # will correspond to a Speaker utterance and the label to a Listener
        # response.
        self.ed_teacher = EmpatheticDialoguesTeacher(
            Opt({
                'datapath': data_folder,
                'datatype': datatype,
                'train_experiencer_only': True,
            }))

        self.wow_teacher = WizardDialogKnowledgeTeacher(
            Opt({'datapath': data_folder, 'datatype': datatype}))

        # Lookup tables; the two setup helpers run after all teachers exist.
        self.topic_to_persona_path = _topic_to_persona_path(opt)
        self.wow_topics_to_episode_idxes = self._setup_topics_to_episodes()
        self.persona_strings_to_wow_topics = self._setup_personas_to_topics()
コード例 #3
0
def compare_init_model_opts(opt: Opt, curr_opt: Opt):
    """
    Print loud warning when `init_model` opts differ from previous configuration.

    :param opt:
        options holding ``init_model`` and ``datapath``. ``opt['init_model']``
        is rewritten in place with the path returned by ``modelzoo_path``.
    :param curr_opt:
        options of the current run; they are compared against the ``.opt``
        file stored next to the init model.

    Returns nothing. Nothing is printed when ``init_model`` is unset or when
    no ``<init_model>.opt`` file exists.
    """
    if opt.get('init_model') is None:
        return
    opt['init_model'] = modelzoo_path(opt['datapath'], opt['init_model'])
    optfile = opt['init_model'] + '.opt'
    if not os.path.isfile(optfile):
        return
    init_model_opt = Opt.load(optfile)

    # Keys that legitimately differ between runs and are never reported.
    exempt_opts = {
        'model_file',
        'dict_file',
        'override',
        'starttime',
        'init_model',
        'batchindex',
    }

    def _printable(value):
        # Lists are shown the way they would be typed on the command line.
        if isinstance(value, list):
            return ','.join(str(x) for x in value)
        return value

    # search through init model opts
    different_opts = {}
    for k, v in init_model_opt.items():
        if k in exempt_opts:
            continue
        curr_v = curr_opt.get(k)
        if v == curr_v:
            continue
        if isinstance(v, list) and curr_v is not None:
            # The current value may be another sequence type holding the same
            # items (e.g. a tuple); that does not count as a difference. A
            # missing, None or non-iterable current value is simply different
            # instead of raising KeyError/TypeError.
            try:
                curr_as_list = list(curr_v)
            except TypeError:
                curr_as_list = None
            if v == curr_as_list:
                continue
        different_opts[k] = _printable(v)

    # search through opts to load
    extra_opts = {}
    for k, v in curr_opt.items():
        if k not in exempt_opts and k not in init_model_opt:
            extra_opts[k] = _printable(v)

    # print warnings
    extra_strs = ['{}: {}'.format(k, v) for k, v in extra_opts.items()]
    if extra_strs:
        print('\n' + '*' * 75)
        print('[ WARNING ] : your model is being loaded with opts that do not '
              'exist in the model you are initializing the weights with: '
              '{}'.format(','.join(extra_strs)))

    # Only the option name is turned into its dashed CLI form; the value is
    # printed untouched so paths or names containing '_' stay correct.
    different_strs = [
        '--{} {}'.format(k.replace('_', '-'), v)
        for k, v in different_opts.items()
    ]
    if different_strs:
        print('\n' + '*' * 75)
        print('[ WARNING ] : your model is being loaded with opts that differ '
              'from the model you are initializing the weights with. Add the '
              'following args to your run command to change this: \n'
              '\n{}'.format(' '.join(different_strs)))
        print('*' * 75)
コード例 #4
0
ファイル: params.py プロジェクト: anoopkarnik/Nelly-Chatbot
    def _process_args_to_opts(self,
                              args_that_override: Optional[List[str]] = None):
        """
        Build ``self.opt`` from the parsed args and record command line overrides.

        :param args_that_override:
            raw argument strings to treat as explicitly given by the user.
            Defaults to ``sys.argv[1:]``.

        Every option found in ``args_that_override`` is stored in
        ``self.overridable`` (and therefore in ``self.opt['override']``):
        store-true/store-false flags as ``True``/``False``, options followed by
        a value with the parsed value from ``self.opt``. Afterwards an
        ``init_opt`` file is applied if given, file options are passed through
        ``modelzoo_path`` and ``starttime`` is stamped.
        """
        self.opt = Opt(vars(self.args))

        # custom post-parsing
        self.opt['parlai_home'] = self.parlai_home
        self.opt = self._infer_datapath(self.opt)

        # set all arguments specified in command line as overridable
        option_strings_dict = {}
        store_true = set()
        store_false = set()
        for group in self._action_groups:
            for a in group._group_actions:
                if not hasattr(a, 'option_strings'):
                    continue
                action_type = str(type(a))
                for option in a.option_strings:
                    option_strings_dict[option] = a.dest
                    if '_StoreTrueAction' in action_type:
                        store_true.add(option)
                    elif '_StoreFalseAction' in action_type:
                        store_false.add(option)

        if args_that_override is None:
            args_that_override = _sys.argv[1:]

        last_index = len(args_that_override) - 1
        for i, arg in enumerate(args_that_override):
            if arg not in option_strings_dict:
                continue
            key = option_strings_dict[arg]
            if arg in store_true:
                self.overridable[key] = True
            elif arg in store_false:
                self.overridable[key] = False
            elif (i < last_index
                  and args_that_override[i + 1] not in option_strings_dict):
                # The next token is this option's value unless it is itself a
                # registered option. Checking for a leading '-' instead would
                # drop values such as negative numbers.
                self.overridable[key] = self.opt[key]
        self.opt['override'] = self.overridable

        # load opts if a file is provided.
        if self.opt.get('init_opt', None) is not None:
            self._load_opts(self.opt)

        # map filenames that start with 'zoo:' to point to the model zoo dir
        options_to_change = ('model_file', 'dict_file', 'bpe_vocab',
                             'bpe_merge')
        for each_key in options_to_change:
            if self.opt.get(each_key) is not None:
                self.opt[each_key] = modelzoo_path(self.opt.get('datapath'),
                                                   self.opt[each_key])
            if self.opt['override'].get(each_key) is not None:
                # also check override
                self.opt['override'][each_key] = modelzoo_path(
                    self.opt.get('datapath'), self.opt['override'][each_key])

        # add start time of an experiment
        self.opt['starttime'] = datetime.datetime.today().strftime(
            '%b%d_%H-%M')
コード例 #5
0
def create_agent(opt: Opt, requireModelExists=False):
    """
    Create an agent from the options ``model``, ``model_params`` and ``model_file``.

    The model is named either in full, as in
    ``parlai.agents.ir_baseline.agents:IrBaselineAgent`` (module path followed
    by the class name), or in short form such as ``ir_baseline``, which
    assumes the path above and a class name suffixed with 'Agent'.

    When ``model_file`` is set, the agent is first built from the options
    stored alongside it (``opt['model_file'] + '.opt'``, a pickled or json
    dict of the model's options). In that case neither the model name nor its
    other options need to be given explicitly.

    :param opt:
        options for the agent; missing defaults are filled in when
        ``datapath`` is unset, and ``model_file`` is rewritten with the result
        of ``modelzoo_path``.
    :param requireModelExists:
        if True, raise ``RuntimeError`` when ``model_file`` is set but is not
        a file, and warn when a loadable model has no ``model_file``.
    :raises RuntimeError:
        if the required model file is missing, or if no agent could be built
        from the model file and ``model`` is not set.
    """
    if opt.get('datapath', None) is None:
        # add datapath, it is missing
        from core.params import ParlaiParser, get_model_name

        parser = ParlaiParser(add_parlai_args=False)
        parser.add_parlai_data_path()
        # add model args if they are missing
        model_name = get_model_name(opt)
        if model_name is not None:
            parser.add_model_subargs(model_name)
        default_opt = parser.parse_args("", print_args=False)
        for name, default_value in default_opt.items():
            # never clobber a value the caller already supplied
            if name not in opt:
                opt[name] = default_value

    if opt.get('model_file'):
        opt['model_file'] = modelzoo_path(opt.get('datapath'),
                                          opt['model_file'])
        file_missing = not os.path.isfile(opt['model_file'])
        if requireModelExists and file_missing:
            raise RuntimeError(
                'WARNING: Model file does not exist, check to make '
                'sure it is correct: {}'.format(opt['model_file']))
        # Attempt to load the model from the model file first (this way we do
        # not even have to specify the model name as a parameter)
        agent = create_agent_from_opt_file(opt)
        if agent is not None:
            return agent
        print(f"[ no model with opt yet at: {opt['model_file']}(.opt) ]")

    if not opt.get('model'):
        raise RuntimeError('Need to set `model` argument to use create_agent.')

    agent_class = load_agent_module(opt['model'])
    # if we want to load weights from --init-model, compare opts with
    # loaded ones
    compare_init_model_opts(opt, opt)
    agent = agent_class(opt)
    loadable_without_file = hasattr(agent,
                                    'load') and not opt.get('model_file')
    if requireModelExists and loadable_without_file:
        # double check that we didn't forget to set model_file on loadable model
        print('WARNING: model_file unset but model has a `load` function.')
    return agent
コード例 #6
0
ファイル: params.py プロジェクト: anoopkarnik/Nelly-Chatbot
 def _load_opts(self, opt):
     """
     Apply the options stored in the ``init_opt`` file to ``opt``.

     Every key in the file must already exist in ``opt``; otherwise a
     ``RuntimeError`` is raised. Keys already present in ``opt['override']``
     are left alone; all other keys are written to both ``opt`` and
     ``opt['override']``.
     """
     file_opt = Opt.load(opt.get('init_opt'))
     for name, file_value in file_opt.items():
         if name not in opt:
             raise RuntimeError(
                 'Trying to set opt from file that does not exist: ' +
                 str(name))
         # existing command line parameters take priority.
         if name in opt['override']:
             continue
         opt[name] = file_value
         opt['override'][name] = file_value
コード例 #7
0
ファイル: params.py プロジェクト: anoopkarnik/Nelly-Chatbot
    def _load_known_opts(self, optfile, parsed):
        """
        Pull in CLI args for proper models/tasks/etc.

        Called before args are parsed; ``_load_opts`` is used for actually overriding
        opts after they are parsed.

        Values from ``optfile`` are copied into ``parsed`` only for keys that
        are absent from it or currently ``None``.
        """
        for name, stored_value in Opt.load(optfile).items():
            # existing command line parameters take priority.
            already_given = name in parsed and parsed[name] is not None
            if not already_given:
                parsed[name] = stored_value
コード例 #8
0
ファイル: params.py プロジェクト: anoopkarnik/Nelly-Chatbot
def get_model_name(opt):
    """
    Get the model name from either `--model` or `--model-file`.

    ``opt['model']`` wins when it is set. Otherwise the name is read from the
    ``.opt`` file next to ``opt['model_file']``, if that file exists. Returns
    ``None`` when neither source yields a name.
    """
    explicit = opt.get('model', None)
    if explicit is not None:
        return explicit

    # try to get model name from model opt file
    model_file = opt.get('model_file', None)
    if model_file is None:
        return None

    optfile = modelzoo_path(opt.get('datapath'), model_file) + '.opt'
    if not os.path.isfile(optfile):
        return None
    return Opt.load(optfile).get('model', None)
コード例 #9
0
    def upgrade_opt(cls, opt_from_disk: Opt):
        """
        Upgrade an opt loaded from disk, dropping the deprecated attention-keys flag.

        The parent class upgrades run first. A stored
        ``polyencoder_attention_keys`` of ``'context'`` is removed from the
        opt; any other non-None value raises ``NotImplementedError``.
        """
        # call the parent upgrades
        opt_from_disk = super(PolyencoderAgent, cls).upgrade_opt(opt_from_disk)

        attention_keys = opt_from_disk.get('polyencoder_attention_keys')
        if attention_keys is None:
            return opt_from_disk

        # 2020-02-19 We are deprecating this flag because it was used for a one-time
        # set of experiments and won't be used again. This flag was defaulted to
        # 'context', so throw an exception otherwise.
        if attention_keys != 'context':
            raise NotImplementedError(
                'This --polyencoder-attention-keys mode (found in commit 06f0d9f) is no longer supported!'
            )
        del opt_from_disk['polyencoder_attention_keys']
        return opt_from_disk
コード例 #10
0
ファイル: dict.py プロジェクト: anoopkarnik/Nelly-Chatbot
    def __init__(self, opt: Opt, shared=None):
        """
        Initialize DictionaryAgent.

        :param opt:
            options; the ``dict_*`` keys configure the dictionary, each falling
            back to a ``DictionaryAgent.default_*`` class attribute. Note that
            ``opt`` itself is modified: ``dict_file`` / ``dict_initpath`` are
            rewritten with the result of ``modelzoo_path`` and ``dict_loaded``
            is set (only when ``shared`` is falsy).
        :param shared:
            optional mapping providing ``freq``, ``tok2ind`` and ``ind2tok``
            to reuse instead of building/loading a dictionary.
        :raises AttributeError:
            if no ``<dict_tokenizer>_tokenize`` method exists on this object.
        :raises ImportError:
            if the ``nltk`` tokenizer is selected but nltk is not installed.
        """
        # Snapshot taken before any of the writes to `opt` further down, so
        # self.opt does not see the rewritten paths or `dict_loaded`.
        self.opt = copy.deepcopy(opt)
        self.minfreq = opt.get('dict_minfreq', DictionaryAgent.default_minfreq)
        self.null_token = opt.get('dict_nulltoken',
                                  DictionaryAgent.default_null)
        self.end_token = opt.get('dict_endtoken', DictionaryAgent.default_end)
        self.unk_token = opt.get('dict_unktoken', DictionaryAgent.default_unk)
        self.start_token = opt.get('dict_starttoken',
                                   DictionaryAgent.default_start)
        self.max_ngram_size = opt.get('dict_max_ngram_size',
                                      DictionaryAgent.default_maxngram)
        self.tokenizer = opt.get('dict_tokenizer', DictionaryAgent.default_tok)
        self.lower = opt.get('dict_lower', DictionaryAgent.default_lower)
        self.maxtokens = opt.get('dict_maxtokens',
                                 DictionaryAgent.default_maxtokens)
        # comma-separated string turned into a list of field names
        self.textfields = opt.get(
            'dict_textfields', DictionaryAgent.default_textfields).split(",")

        # The tokenizer is resolved by name to a method `<name>_tokenize`.
        try:
            self.tokenizer_fun = getattr(self, self.tokenizer + '_tokenize')
        except AttributeError:
            raise AttributeError('tokenizer type {} not yet supported'.format(
                self.tokenizer))

        if shared:
            # reuse the tables handed over by the caller
            self.freq = shared.get('freq', {})
            self.tok2ind = shared.get('tok2ind', {})
            self.ind2tok = shared.get('ind2tok', {})
        else:
            self.freq = defaultdict(int)
            self.tok2ind = {}
            self.ind2tok = {}

            # Special tokens are added first, in the fixed order
            # null, start, end, unk; empty/None tokens are skipped.
            if self.null_token:
                self.add_token(self.null_token)

            if self.start_token:
                # set special start of sentence word token
                self.add_token(self.start_token)

            if self.end_token:
                # set special end of sentence word token
                self.add_token(self.end_token)

            if self.unk_token:
                # set special unknown word token
                self.add_token(self.unk_token)

            loaded = False
            # If data built via pytorch data teacher, we need to load prebuilt dict
            if opt.get('dict_file'):
                opt['dict_file'] = modelzoo_path(opt.get('datapath'),
                                                 opt['dict_file'])
                if os.path.isfile(opt['dict_file']):
                    # load pre-existing dictionary
                    self.load(opt['dict_file'])
                    loaded = True

            if not loaded and opt.get('dict_initpath'):
                # load seed dictionary
                opt['dict_initpath'] = modelzoo_path(opt.get('datapath'),
                                                     opt['dict_initpath'])
                # don't check isfile first, should fail if file not found
                self.load(opt['dict_initpath'])
            # `loaded` only reflects the dict_file branch; it stays False when
            # the dictionary came from dict_initpath.
            opt['dict_loaded'] = loaded

        # cache unk token for later
        self._unk_token_idx = self.tok2ind.get(self.unk_token)

        # initialize tokenizers
        if self.tokenizer == 'nltk':
            try:
                import nltk
            except ImportError:
                raise ImportError('Please install nltk (pip install nltk)')
            # nltk-specific setup
            # `dict_language` is indexed directly, so it must be present in opt.
            st_path = 'tokenizers/punkt/{0}.pickle'.format(
                opt['dict_language'])
            try:
                self.sent_tok = nltk.data.load(st_path)
            except LookupError:
                # resource not available yet: download punkt, then retry once
                nltk.download('punkt')
                self.sent_tok = nltk.data.load(st_path)
            self.word_tok = nltk.tokenize.treebank.TreebankWordTokenizer()
        elif self.tokenizer in [
                'bpe', 'gpt2', 'bytelevelbpe', 'slow_bytelevel_bpe'
        ]:
            self.bpe = bpe_factory(opt, shared)
            self.bpe.sync_with_dict(self)

        if not shared:
            # Pin the special-token counts after loading / BPE sync; the
            # values decrease in the order null > start > end > unk.
            if self.null_token:
                # fix count for null token to one billion and three
                self.freq[self.null_token] = 1000000003

            if self.start_token:
                # fix count for start of sentence token to one billion and two
                self.freq[self.start_token] = 1000000002

            if self.end_token:
                # fix count for end of sentence token to one billion and one
                self.freq[self.end_token] = 1000000001

            if self.unk_token:
                # fix count for unknown token to one billion
                self.freq[self.unk_token] = 1000000000

            # save_path is only set here, i.e. when not shared and a
            # dict_file is configured.
            if opt.get('dict_file'):
                self.save_path = opt['dict_file']
コード例 #11
0
ファイル: params.py プロジェクト: anoopkarnik/Nelly-Chatbot
class ParlaiParser(argparse.ArgumentParser):
    """
    Provide an opt-producer and CLI argument parser.

    Pseudo-extension of ``argparse`` which sets a number of parameters
    for the ParlAI framework. More options can be added specific to other
    modules by passing this object and calling ``add_arg()`` or
    ``add_argument()`` on it.

    For example, see ``core.dict.DictionaryAgent.add_cmdline_args``.

    :param add_parlai_args:
        (default True) initializes the default arguments for ParlAI
        package, including the data download paths and task arguments.
    :param add_model_args:
        (default False) initializes the default arguments for loading
        models, including initializing arguments from that model.
    """
    def __init__(self,
                 add_parlai_args=True,
                 add_model_args=False,
                 description='ParlAI parser'):
        """
        Initialize the ParlAI argparser.

        Registers the custom argument types, records the ParlAI home directory
        (two directory levels above this file) on the instance and in the
        ``PARLAI_HOME`` environment variable, and optionally adds the default
        ParlAI and model arguments.
        """
        super().__init__(
            description=description,
            allow_abbrev=False,
            conflict_handler='resolve',
            formatter_class=CustomHelpFormatter,
            add_help=add_parlai_args,
        )

        # converters that can be referenced as ``type='<name>'``
        custom_types = (
            ('nonestr', str2none),
            ('bool', str2bool),
            ('floats', str2floats),
            ('class', str2class),
        )
        for type_name, converter in custom_types:
            self.register('type', type_name, converter)

        this_file = os.path.realpath(__file__)
        self.parlai_home = os.path.dirname(os.path.dirname(this_file))
        os.environ['PARLAI_HOME'] = self.parlai_home

        # short alias for add_argument
        self.add_arg = self.add_argument

        # remember which args were specified on the command line
        self.overridable = {}

        if add_parlai_args:
            self.add_parlai_args()
        if add_model_args:
            self.add_model_args()

    def add_parlai_data_path(self, argument_group=None):
        """
        Add --datapath CLI arg.

        The option is added to ``argument_group`` when one is given, otherwise
        to the parser itself.
        """
        target = self if argument_group is None else argument_group
        target.add_argument(
            '-dp', '--datapath',
            default=None,
            help='path to datasets, defaults to {parlai_dir}/data',
        )

    def add_mturk_args(self):
        """
        Add standard mechanical turk arguments.

        All options are registered on a 'Mechanical Turk' argument group.
        ``dest`` is only passed where the stored name differs from the one
        argparse derives from the long flag.
        """
        mturk = self.add_argument_group('Mechanical Turk')
        add = mturk.add_argument

        add(
            '--mturk-log-path',
            default=os.path.join(self.parlai_home, 'logs', 'mturk'),
            help='path to MTurk logs, defaults to {parlai_dir}/logs/mturk',
        )
        add(
            '-t', '--task',
            help='MTurk task, e.g. "qa_data_collection" or "model_evaluator"',
        )
        add(
            '-nc', '--num-conversations',
            type=int,
            default=1,
            help='number of conversations you want to create for this task',
        )
        add(
            '--unique',
            dest='unique_worker',
            action='store_true',
            default=False,
            help='enforce that no worker can work on your task twice',
        )
        add(
            '--max-hits-per-worker',
            type=int,
            default=0,
            help=
            'Max number of hits each worker can perform during current group run',
        )
        add(
            '--unique-qual-name',
            type=str,
            default=None,
            help='qualification name to use for uniqueness between HITs',
        )
        add(
            '-r', '--reward',
            type=float,
            default=0.05,
            help='reward for each worker for finishing the conversation, '
            'in US dollars',
        )

        # --sandbox / --live toggle the same destination in opposite directions
        add(
            '--sandbox',
            dest='is_sandbox',
            action='store_true',
            help='submit the HITs to MTurk sandbox site',
        )
        add(
            '--live',
            dest='is_sandbox',
            action='store_false',
            help='submit the HITs to MTurk live site',
        )
        add(
            '--debug',
            dest='is_debug',
            action='store_true',
            help='print and log all server interactions and messages',
        )
        add(
            '--verbose',
            action='store_true',
            help='print all messages sent to and from Turkers',
        )
        add(
            '--hard-block',
            action='store_true',
            default=False,
            help='Hard block disconnecting Turkers from all of your HITs',
        )
        add(
            '--log-level',
            type=int,
            default=20,
            help='importance level for what to put into the logs. the lower '
            'the level the more that gets logged. values are 0-50',
        )
        add(
            '--disconnect-qualification',
            default=None,
            help='Qualification to use for soft blocking users for '
            'disconnects. By default '
            'turkers are never blocked, though setting this will allow '
            'you to filter out turkers that have disconnected too many '
            'times on previous HITs where this qualification was set.',
        )
        add(
            '--block-qualification',
            default=None,
            help='Qualification to use for soft blocking users. This '
            'qualification is granted whenever soft_block_worker is '
            'called, and can thus be used to filter workers out from a '
            'single task or group of tasks by noted performance.',
        )
        add(
            '--count-complete',
            action='store_true',
            default=False,
            help='continue until the requested number of conversations are '
            'completed rather than attempted',
        )
        add(
            '--allowed-conversations',
            type=int,
            default=0,
            help='number of concurrent conversations that one mturk worker '
            'is able to be involved in, 0 is unlimited',
        )
        add(
            '--max-connections',
            type=int,
            default=30,
            help='number of HITs that can be launched at the same time, 0 is '
            'unlimited.',
        )
        add(
            '--min-messages',
            type=int,
            default=0,
            help='number of messages required to be sent by MTurk agent when '
            'considering whether to approve a HIT in the event of a '
            'partner disconnect. I.e. if the number of messages '
            'exceeds this number, the turker can submit the HIT.',
        )
        add(
            '--local',
            action='store_true',
            default=False,
            help='Run the server locally on this server rather than setting up'
            ' a heroku server.',
        )
        add(
            '--hobby',
            action='store_true',
            default=False,
            help='Run the heroku server on the hobby tier.',
        )
        add(
            '--max-time',
            type=int,
            default=0,
            help='Maximum number of seconds per day that a worker is allowed '
            'to work on this assignment',
        )
        add(
            '--max-time-qual',
            default=None,
            help='Qualification to use to share the maximum time requirement '
            'with other runs from other machines.',
        )
        add(
            '--heroku-team',
            default=None,
            help='Specify Heroku team name to use for launching Dynos.',
        )
        add(
            '--tmp-dir',
            default=None,
            help='Specify location to use for scratch builds and such.',
        )

        # interactive_mode helps to indicate to agents that they're in
        # interactive mode, and can avoid some otherwise troublesome behavior
        # (not having candidates, sharing self.replies, etc).
        mturk.set_defaults(
            interactive_mode=True,
            is_sandbox=True,
            is_debug=False,
            verbose=False,
        )

    def add_chatservice_args(self):
        """
        Arguments for all chat services.

        Registers ``--debug`` (stored as ``is_debug``), ``--config-path`` and
        ``--password`` on a 'Chat Services' argument group.
        """
        chat_group = self.add_argument_group('Chat Services')
        add = chat_group.add_argument

        add(
            '--debug',
            dest='is_debug',
            action='store_true',
            help='print and log all server interactions and messages',
        )
        add(
            '--config-path',
            type=str,
            default=None,
            help='/path/to/config/file for a given task.',
        )
        add(
            '--password',
            type=str,
            default=None,
            help='Require a password for entry to the bot',
        )

    def add_websockets_args(self):
        """
        Add websocket arguments.

        The shared chat service arguments are added first, followed by a
        'Websockets' group holding ``--port``.
        """
        self.add_chatservice_args()
        ws_group = self.add_argument_group('Websockets')
        ws_group.add_argument(
            '--port',
            type=int,
            default=35496,
            help='Port to run the websocket handler',
        )

    def add_messenger_args(self):
        """
        Add Facebook Messenger arguments.

        The shared chat service arguments are added first, followed by a
        'Facebook Messenger' group. ``is_debug`` and ``verbose`` default to
        False.
        """
        self.add_chatservice_args()
        messenger = self.add_argument_group('Facebook Messenger')
        add = messenger.add_argument

        add(
            '--verbose',
            action='store_true',
            help='print all messages sent to and from Turkers',
        )
        add(
            '--log-level',
            type=int,
            default=20,
            help='importance level for what to put into the logs. the lower '
            'the level the more that gets logged. values are 0-50',
        )
        add(
            '--force-page-token',
            action='store_true',
            help='override the page token stored in the cache for a new one',
        )
        add(
            '--bypass-server-setup',
            action='store_true',
            default=False,
            help='should bypass traditional server and socket setup',
        )
        add(
            '--local',
            action='store_true',
            default=False,
            help='Run the server locally on this server rather than setting up'
            ' a heroku server.',
        )

        messenger.set_defaults(is_debug=False, verbose=False)

    def add_parlai_args(self, args=None):
        """
        Add common ParlAI args across all scripts.

        ``args`` is accepted but not used by this method.
        """
        # NOTE(review): the 'Main ParlAI Arguments' group only receives
        # --datapath (last line); every other option below is registered on
        # the parser itself. Confirm whether that is intended.
        parlai = self.add_argument_group('Main ParlAI Arguments')
        add = self.add_argument

        add(
            '-o', '--init-opt',
            default=None,
            help='Path to json file of options. '
            'Note: Further Command-line arguments override file-based options.',
        )
        add(
            '-v', '--show-advanced-args',
            action='store_true',
            help='Show hidden command line options (advanced users only)',
        )
        add(
            '-t', '--task',
            help='ParlAI task(s), e.g. "babi:Task1" or "babi,cbt"',
        )
        add(
            '--download-path',
            default=None,
            hidden=True,
            help='path for non-data dependencies to store any needed files.'
            'defaults to {parlai_dir}/downloads',
        )
        add(
            '-dt', '--datatype',
            default='train',
            choices=[
                'train',
                'train:stream',
                'train:ordered',
                'train:ordered:stream',
                'train:stream:ordered',
                'train:evalmode',
                'train:evalmode:stream',
                'train:evalmode:ordered',
                'train:evalmode:ordered:stream',
                'train:evalmode:stream:ordered',
                'valid',
                'valid:stream',
                'test',
                'test:stream',
            ],
            help='choose from: train, train:ordered, valid, test. to stream '
            'data add ":stream" to any option (e.g., train:stream). '
            'by default: train is random with replacement, '
            'valid is ordered, test is ordered.',
        )
        add(
            '-im', '--image-mode',
            type=str,
            default='raw',
            hidden=True,
            help='image preprocessor to use. default is "raw". set to "none" '
            'to skip image loading.',
        )
        add(
            '-nt', '--numthreads',
            type=int,
            default=1,
            help='number of threads. Used for hogwild if batchsize is 1, else '
            'for number of threads in threadpool loading,',
        )
        add(
            '--hide-labels',
            type='bool',
            default=False,
            hidden=True,
            help='default (False) moves labels in valid and test sets to the '
            'eval_labels field. If True, they are hidden completely.',
        )
        add(
            '-mtw', '--multitask-weights',
            type='floats',
            default=[1],
            hidden=True,
            help='list of floats, one for each task, specifying '
            'the probability of drawing the task in multitask case',
        )
        add(
            '-bs', '--batchsize',
            type=int,
            default=1,
            help='batch size for minibatch training schemes',
        )
        add(
            '-dynb', '--dynamic-batching',
            type='nonestr',
            default=None,
            choices={None, 'full', 'batchsort'},
            help='Use dynamic batching',
        )
        self.add_parlai_data_path(parlai)

    def add_distributed_training_args(self):
        """
        Register the command line options used for distributed training.

        The options live in their own 'Distributed Training' argument group.

        :return: the argument group the options were added to.
        """
        dist_group = self.add_argument_group('Distributed Training')
        dist_group.add_argument(
            '--distributed-world-size', type=int, help='Number of workers.'
        )
        verbose_kwargs = {
            'type': 'bool',
            'default': False,
            'help': 'All workers print output.',
            'hidden': True,
        }
        dist_group.add_argument('--verbose', **verbose_kwargs)
        return dist_group

    def add_model_args(self):
        """
        Register the generic model options.

        Adds --model, --model-file, --init-model and --dict-class to a dedicated
        'ParlAI Model Arguments' group.
        """
        group = self.add_argument_group('ParlAI Model Arguments')

        model_help = (
            'the model class name. can match parlai/agents/<model> for '
            'agents in that directory, or can provide a fully specified '
            'module for `from X import Y` via `-m X:Y` '
            '(e.g. `-m agents.seq2seq.seq2seq:Seq2SeqAgent`)'
        )
        group.add_argument('-m', '--model', default=None, help=model_help)

        group.add_argument(
            '-mf',
            '--model-file',
            help='model file name for loading and saving models',
            default=None,
        )

        # NOTE(review): '-im' is also declared as the short form of --image-mode
        # elsewhere in this parser; which option ends up owning it depends on the
        # parser's conflict handling -- confirm this is intended.
        group.add_argument(
            '-im',
            '--init-model',
            type=str,
            default=None,
            help='load model weights and dict from this file',
        )

        group.add_argument(
            '--dict-class',
            help='the class of the dictionary agent uses',
            hidden=True,
        )

    def add_model_subargs(self, model):
        """
        Register the options declared by one particular model.

        The agent module for ``model`` is loaded; if it exposes ``add_cmdline_args``
        that hook is called with this parser, and if it exposes ``dictionary_class``
        the string form of that class becomes the default of ``dict_class``.
        """
        agent_module = load_agent_module(model)

        if hasattr(agent_module, 'add_cmdline_args'):
            try:
                agent_module.add_cmdline_args(self)
            except argparse.ArgumentError:
                # the arguments were already registered by an earlier call
                pass

        if hasattr(agent_module, 'dictionary_class'):
            try:
                dict_class_name = class2str(agent_module.dictionary_class())
                self.set_defaults(dict_class=dict_class_name)
            except argparse.ArgumentError:
                # already registered
                pass

    def add_task_args(self, task):
        """
        Register the options declared by every teacher named in ``task``.

        ``task`` is expanded with ``ids_to_tasks`` and split on commas; each
        resulting teacher module may contribute options via ``add_cmdline_args``.
        """
        task_names = ids_to_tasks(task).split(',')
        for task_name in task_names:
            teacher = load_teacher_module(task_name)
            if not hasattr(teacher, 'add_cmdline_args'):
                continue
            try:
                teacher.add_cmdline_args(self)
            except argparse.ArgumentError:
                # these arguments were registered by an earlier call
                pass

    def add_world_args(self, task, interactive_task, selfchat_task):
        """
        Register the options declared by the world class used for ``task``.

        Nothing happens when no world class is found or when the class does not
        provide ``add_cmdline_args``.
        """
        world = load_world_module(
            task,
            interactive_task=interactive_task,
            selfchat_task=selfchat_task,
        )
        if world is None or not hasattr(world, 'add_cmdline_args'):
            return
        try:
            world.add_cmdline_args(self)
        except argparse.ArgumentError:
            # these arguments were registered by an earlier call
            pass

    def add_image_args(self, image_mode):
        """
        Add additional arguments for handling images.

        Registers --image-size and --image-cropsize on the 'ParlAI Image
        Preprocessing Arguments' group, so they are listed under that title
        instead of leaving the group empty.

        :param image_mode: the selected image mode. It is not inspected here; the
            caller decides whether image arguments are needed at all.
        """
        try:
            image_group = self.add_argument_group(
                'ParlAI Image Preprocessing Arguments')
            image_group.add_argument(
                '--image-size',
                type=int,
                default=256,
                help='resizing dimension for images',
                hidden=True,
            )
            image_group.add_argument(
                '--image-cropsize',
                type=int,
                default=224,
                help='crop dimension for images',
                hidden=True,
            )
        except argparse.ArgumentError:
            # already added
            pass

    def add_extra_args(self, args=None):
        """
        Register further options that depend on already-known option values.

        The known arguments are parsed once (ignoring help flags), optionally
        completed from an ``init_opt`` file, and then used to pull in image, task,
        model and world specific options. Finally the parser-level defaults are
        re-applied so they win over defaults set by models.
        """
        known = vars(self.parse_known_args(args, nohelp=True)[0])

        # an opt file may name the model/task whose options must be registered
        init_opt = known.get('init_opt', None)
        if init_opt is not None:
            self._load_known_opts(init_opt, known)
        known = self._infer_datapath(known)

        # image options, unless image handling is switched off
        image_mode = known.get('image_mode', None)
        if image_mode not in (None, 'no_image_model'):
            self.add_image_args(image_mode)

        # task options, for both the training task and the evaluation task
        task = known.get('task', None)
        for task_spec in (task, known.get('evaltask', None)):
            if task_spec is not None:
                self.add_task_args(task_spec)

        # model options
        model = get_model_name(known)
        if model is not None:
            self.add_model_subargs(model)

        # world options, if we know a priori which world is being used
        if task is not None:
            interactive = known.get('interactive_task', False)
            selfchat = known.get('selfchat_task', False)
            self.add_world_args(task, interactive, selfchat)

        # parser-level defaults take precedence over model-level defaults
        try:
            self.set_defaults(**self._defaults)
        except AttributeError:
            raise RuntimeError('Please file an issue on github that argparse '
                               'got an attribute error when parsing.')

    def parse_known_args(self, args=None, namespace=None, nohelp=False):
        """
        Parse the known arguments, optionally dropping help flags first.

        :param args: argument strings; defaults to the process arguments.
        :param namespace: passed through to argparse.
        :param nohelp: when true, '-h' and '--help' are removed before parsing.
        """
        raw_args = _sys.argv[1:] if args is None else args
        normalized = fix_underscores(raw_args)
        if nohelp:
            # ignore help
            normalized = [tok for tok in normalized if tok not in ('-h', '--help')]
        return super().parse_known_args(normalized, namespace)

    def _load_known_opts(self, optfile, parsed):
        """
        Fill gaps in ``parsed`` from an opt file so the right extra args get added.

        Runs before the real parse; ``_load_opts`` is what overrides opts once
        parsing is done. Values already present (and not None) in ``parsed`` are
        kept, so command line parameters take priority over the file.
        """
        file_opt = Opt.load(optfile)
        for name, file_value in file_opt.items():
            if parsed.get(name) is None:
                parsed[name] = file_value

    def _load_opts(self, opt):
        """
        Apply the values stored in the ``init_opt`` file to ``opt``.

        A key from the file that is not a known option raises RuntimeError. Keys
        already present in ``opt['override']`` are left alone; every other key is
        written to both ``opt`` and ``opt['override']``.
        """
        for name, file_value in Opt.load(opt.get('init_opt')).items():
            if name not in opt:
                raise RuntimeError(
                    'Trying to set opt from file that does not exist: ' +
                    str(name))
            if name in opt['override']:
                # existing command line parameters take priority.
                continue
            opt[name] = file_value
            opt['override'][name] = file_value

    def _infer_datapath(self, opt):
        """
        Resolve opt['datapath'] and opt['download_path'].

        For each of the two paths the priority is: the value already in ``opt``,
        then the matching environment variable, then a directory under
        ``self.parlai_home``. The chosen value is written back to both the
        environment variable and ``opt``.

        :return: the same ``opt`` object, updated in place.
        """
        # (opt key, environment variable, default directory under parlai_home)
        path_specs = (
            ('download_path', 'PARLAI_DOWNPATH', 'downloads'),
            ('datapath', 'PARLAI_DATAPATH', 'data'),
        )

        for opt_key, env_key, default_dir in path_specs:
            explicit = opt.get(opt_key)
            if explicit:
                os.environ[env_key] = explicit
            elif os.environ.get(env_key) is None:
                os.environ[env_key] = os.path.join(self.parlai_home,
                                                   default_dir)

        for opt_key, env_key, _ in path_specs:
            opt[opt_key] = os.environ[env_key]

        return opt

    def _process_args_to_opts(self,
                              args_that_override: Optional[List[str]] = None):
        """
        Turn the parsed namespace in ``self.args`` into ``self.opt``.

        Besides copying the parsed values this:

        - records ``parlai_home`` and resolves the data/download paths;
        - works out which options were given explicitly and stores them in
          ``self.opt['override']``;
        - applies an ``init_opt`` file if one was requested;
        - resolves model-zoo style file names for a fixed set of file options;
        - stamps the opt with the start time.

        :param args_that_override: the argument strings to scan for explicitly
            given options; defaults to the process arguments.
        """
        self.opt = Opt(vars(self.args))

        # custom post-parsing
        self.opt['parlai_home'] = self.parlai_home
        self.opt = self._infer_datapath(self.opt)

        # Map every registered option string to its destination key and note
        # which option strings are value-less flags.
        option_strings_dict = {}
        store_true = set()
        store_false = set()
        for group in self._action_groups:
            for action in group._group_actions:
                for option in getattr(action, 'option_strings', ()):
                    option_strings_dict[option] = action.dest
                    if isinstance(action, argparse._StoreTrueAction):
                        store_true.add(option)
                    elif isinstance(action, argparse._StoreFalseAction):
                        store_false.add(option)

        if args_that_override is None:
            args_that_override = _sys.argv[1:]
        # Option strings are registered in hyphenated form and the parser
        # normalises underscores before parsing, so do the same here; otherwise
        # an option typed as e.g. --model_file is parsed but never recorded as
        # an override.
        args_that_override = fix_underscores(args_that_override)

        # set all arguments specified in command line as overridable
        last_index = len(args_that_override) - 1
        for i, arg in enumerate(args_that_override):
            if arg not in option_strings_dict:
                continue
            key = option_strings_dict[arg]
            if arg in store_true:
                self.overridable[key] = True
            elif arg in store_false:
                self.overridable[key] = False
            elif i < last_index and args_that_override[i + 1][:1] != '-':
                # only counts as given when a value (not another option)
                # follows; the stored value is the parsed one
                self.overridable[key] = self.opt[key]
        self.opt['override'] = self.overridable

        # load opts if a file is provided.
        if self.opt.get('init_opt', None) is not None:
            self._load_opts(self.opt)

        # map filenames that start with 'zoo:' to point to the model zoo dir
        options_to_change = {
            'model_file', 'dict_file', 'bpe_vocab', 'bpe_merge'
        }
        for each_key in options_to_change:
            if self.opt.get(each_key) is not None:
                self.opt[each_key] = modelzoo_path(self.opt.get('datapath'),
                                                   self.opt[each_key])
            if self.opt['override'].get(each_key) is not None:
                # also check override
                self.opt['override'][each_key] = modelzoo_path(
                    self.opt.get('datapath'), self.opt['override'][each_key])

        # add start time of an experiment
        self.opt['starttime'] = datetime.datetime.today().strftime(
            '%b%d_%H-%M')

    def parse_and_process_known_args(self, args=None):
        """
        Parse the given arguments, tolerating ones this parser does not know.

        Performs the same args-to-opt processing as ``parse_args`` but, instead
        of failing on unrecognised command line arguments, hands them back.

        :return: a tuple of (opt, list of unrecognised argument strings).
        """
        parsed_namespace, unknown_args = super().parse_known_args(args=args)
        self.args = parsed_namespace
        self._process_args_to_opts(args)
        return self.opt, unknown_args

    def parse_args(self, args=None, namespace=None, print_args=True):
        """
        Parse the provided arguments and return the resulting ``opt``.

        Extra arguments that depend on the given values (task, model, ...) are
        registered first, then the arguments are parsed and converted to
        ``self.opt``.

        :param args: argument strings; defaults to the process arguments.
        :param namespace: optional namespace object to populate; forwarded to
            argparse.
        :param print_args: when true, print the parsed arguments, the git commit
            (if available) and announcements.

        :return: the processed options, also stored in ``self.opt``.
        """
        self.add_extra_args(args)
        # forward ``namespace`` so a caller-supplied namespace is honoured
        # instead of being silently ignored
        self.args = super().parse_args(args=args, namespace=namespace)

        self._process_args_to_opts(args)

        if print_args:
            self.print_args()
            if GIT_AVAILABLE:
                print_git_commit()
            print_announcements(self.opt)

        # the environment variable forces verbose mode regardless of the flag
        if os.environ.get('PARLAI_VERBOSE'):
            self.opt['verbose'] = True

        if self.opt.get('verbose'):
            logging.set_verbose_mode()

        return self.opt

    def _kwargs_to_str_args(self, **kwargs):
        """
        Attempt to map from python-code kwargs into CLI args.

        e.g. model_file -> --model-file.

        Works with short options too, like t="convai2".

        The conversion runs twice. The first pass skips kwargs the parser does
        not know yet and is only used to let ``add_extra_args`` register
        whatever those values imply (e.g. model specific options). The second
        pass must know every kwarg.

        :return: list of argument strings equivalent to ``kwargs``.

        :raises KeyError: if a kwarg matches no option after extra args were added.
        :raises TypeError: if the matching action is of a kind that cannot be
            expressed here (e.g. a fixed integer ``nargs``).
        """

        def _kwname_to_action_map():
            # kwarg name -> action, for every option string except help
            mapping = {}
            for action in self._actions:
                if action.dest == 'help':
                    # no help allowed
                    continue
                for option_string in action.option_strings:
                    kwname = option_string.lstrip('-').replace('-', '_')
                    assert (kwname not in mapping) or (
                        mapping[kwname] is action
                    ), f"No duplicate names! ({kwname}, {mapping[kwname]}, {action})"
                    mapping[kwname] = action
            return mapping

        def _to_string_args(mapping, skip_unknown):
            # render kwargs as CLI strings using the given name -> action map
            string_args = []
            for kwname, value in kwargs.items():
                if skip_unknown and kwname not in mapping:
                    # best guess, we need to delay it. hopefully this gets
                    # added by add_extra_args
                    continue
                # on the final pass an unknown name legitimately raises
                # KeyError: the user provided an unspecified option
                action = mapping[kwname]
                last_option_string = action.option_strings[-1]
                if isinstance(action, argparse._StoreTrueAction):
                    # a falsy value simply means "do not pass the flag"
                    if bool(value):
                        string_args.append(last_option_string)
                elif isinstance(action,
                                argparse._StoreAction) and action.nargs is None:
                    string_args.append(last_option_string)
                    string_args.append(str(value))
                elif isinstance(
                        action,
                        argparse._StoreAction) and action.nargs in ('*', '+'):
                    string_args.append(last_option_string)
                    string_args.extend([str(v) for v in value])
                else:
                    raise TypeError(f"Don't know what to do with {action}")
            return string_args

        # become aware of any extra args that might be specified if the user
        # provides something like model="transformer/generator".
        self.add_extra_args(
            _to_string_args(_kwname_to_action_map(), skip_unknown=True))

        # do it again, this time knowing about ALL args.
        return _to_string_args(_kwname_to_action_map(), skip_unknown=False)

    def parse_kwargs(self, **kwargs):
        """
        Parse options given as python keyword arguments, with type checking etc.

        While parsing, ``self.error`` is swapped for a function that raises
        ValueError, so a bad value surfaces as an exception rather than a
        SystemExit. The original handler is always restored.
        """

        def _raise_instead_of_exit(msg):
            raise ValueError(msg)

        original_error = self.error
        self.error = _raise_instead_of_exit
        try:
            cli_args = self._kwargs_to_str_args(**kwargs)
            parsed_opt = self.parse_args(args=cli_args, print_args=False)
        finally:
            self.error = original_error
        return parsed_opt

    def print_args(self):
        """
        Print every parsed option, grouped by argument group.

        Parses the arguments first if no opt is available yet. A group title is
        printed only when at least one of the group's destinations is in the opt.
        """
        if not self.opt:
            self.parse_args(print_args=False)

        values = {str(key): str(value) for key, value in self.opt.items()}

        for group in self._action_groups:
            group_values = {
                action.dest: getattr(self.args, action.dest, None)
                for action in group._group_actions
            }
            header_printed = False
            for key in sorted(group_values):
                if key not in values:
                    continue
                if not header_printed:
                    print('[ ' + group.title + ': ] ')
                    header_printed = True
                print('[  ' + key + ': ' + values[key] + ' ]')

    def set_params(self, **kwargs):
        """
        Set parser defaults from kwargs and mark each of them as an override.
        """
        self.set_defaults(**kwargs)
        for name in kwargs:
            self.overridable[name] = kwargs[name]

    @property
    def show_advanced_args(self):
        """
        Tell whether arguments marked as hidden should be shown.

        The answer is computed once from the known arguments (help flags
        ignored) and cached on the instance; it is True when the parsed
        namespace has no ``show_advanced_args`` attribute.
        """
        try:
            return self._show_advanced_args
        except AttributeError:
            pass
        namespace, _unknown = self.parse_known_args(nohelp=True)
        self._show_advanced_args = getattr(namespace, 'show_advanced_args',
                                           True)
        return self._show_advanced_args

    def _handle_custom_options(self, kwargs):
        """
        Strip the parser's own keywords out of ``add_argument`` kwargs.

        Handles ``recommended`` and ``hidden``; a hidden option has its help
        suppressed. A ``type`` of the builtin ``bool`` is replaced by the string
        'bool'. ``kwargs`` is modified in place.

        :return: tuple of (cleaned kwargs, attributes to set on the action).
        """
        custom_attrs = {}
        if 'recommended' in kwargs:
            custom_attrs['recommended'] = kwargs.pop('recommended')

        is_hidden = kwargs.pop('hidden', False)
        custom_attrs['hidden'] = is_hidden
        if is_hidden:
            kwargs['help'] = argparse.SUPPRESS

        if kwargs.get('type') is bool:
            # common error, we really want simple form
            kwargs['type'] = 'bool'
        return kwargs, custom_attrs

    def add_argument(self, *args, **kwargs):
        """
        Add an argument, accepting the parser's custom keywords.

        Option strings go through ``fix_underscores`` for consistency, and the
        custom attributes (e.g. ``hidden``) are attached to the created action.
        """
        kwargs, custom_attrs = self._handle_custom_options(kwargs)
        option_strings = fix_underscores(args)
        action = super().add_argument(*option_strings, **kwargs)
        for attr_name, attr_value in custom_attrs.items():
            setattr(action, attr_name, attr_value)
        return action

    def add_argument_group(self, *args, **kwargs):
        """
        Create an argument group whose ``add_argument`` behaves like the parser's.

        The group's ``add_argument`` is wrapped so it handles the custom keywords
        and passes option strings through ``fix_underscores``; nested groups are
        created through this same method.
        """
        group = super().add_argument_group(*args, **kwargs)
        plain_add_argument = group.add_argument

        def _group_add_argument(*option_strings, **option_kwargs):
            option_kwargs, custom_attrs = self._handle_custom_options(
                option_kwargs)
            action = plain_add_argument(*fix_underscores(option_strings),
                                        **option_kwargs)
            for attr_name, attr_value in custom_attrs.items():
                setattr(action, attr_name, attr_value)
            return action

        group.add_argument = _group_add_argument  # override _ => -
        group.add_argument_group = self.add_argument_group
        return group

    def error(self, message):
        """
        Report a parse error: print the help, then the message, then exit.

        The message goes to stderr and the process exits with status 2.
        """
        self.print_help()
        report = '\nParse Error: %s\n' % message
        _sys.stderr.write(report)
        raise SystemExit(2)
コード例 #12
0
    def __init__(self, opt: Opt, shared=None):
        """
        Set up the class list, decision threshold, model, criterion and optimizer.

        :param opt: agent options. This constructor reads 'classes',
            'classes_from_file', 'class_weights', 'ref_class' and 'threshold'.
        :param shared: state handed over from another instance. When given, the
            class information, the model and (if present) the optimizer are taken
            from it instead of being built here.

        :raises RuntimeError: if neither 'classes' nor 'classes_from_file' is set.
        :raises AttributeError: if build_model() or build_criterion() returns None.
        """
        # resolve the init model before the parent constructor runs
        init_model, self.is_finetune = self._get_init_model(opt, shared)
        super().__init__(opt, shared)

        # set up classes
        if opt.get('classes') is None and opt.get('classes_from_file') is None:
            raise RuntimeError(
                'Must specify --classes or --classes-from-file argument.')
        if not shared:
            # a classes file (one class per line) wins over the 'classes' option
            if opt['classes_from_file'] is not None:
                with open(opt['classes_from_file']) as f:
                    self.class_list = f.read().splitlines()
            else:
                self.class_list = opt['classes']
            # class name -> position in the class list as loaded
            self.class_dict = {val: i for i, val in enumerate(self.class_list)}
            if opt.get('class_weights', None) is not None:
                self.class_weights = opt['class_weights']
            else:
                # default: every class weighted equally
                self.class_weights = [1.0 for c in self.class_list]
            self.reset_metrics()
        else:
            self.class_list = shared['class_list']
            self.class_dict = shared['class_dict']
            self.class_weights = shared['class_weights']

        # in binary classification, opt['threshold'] applies to ref class
        if opt['ref_class'] is None or opt['ref_class'] not in self.class_dict:
            # no usable reference class given: fall back to the first class
            self.ref_class = self.class_list[0]
        else:
            self.ref_class = opt['ref_class']
            ref_class_id = self.class_list.index(self.ref_class)
            if ref_class_id != 0:
                # move to the front of the class list
                # NOTE(review): this reorders class_list in place (the same list
                # object as opt['classes'] / shared['class_list']) and class_dict
                # is not rebuilt afterwards -- confirm the indices stored in
                # class_dict are meant to keep the original order.
                self.class_list.insert(0, self.class_list.pop(ref_class_id))

        # set up threshold, only used in binary classification
        # (a threshold equal to 0.5 is treated the same as no threshold)
        if len(self.class_list) == 2 and opt.get('threshold', 0.5) != 0.5:
            self.threshold = opt['threshold']
        else:
            self.threshold = None

        # set up model and optimizers

        if shared:
            self.model = shared['model']
        else:
            self.model = self.build_model()
            self.criterion = self.build_criterion()
            if self.model is None or self.criterion is None:
                raise AttributeError(
                    'build_model() and build_criterion() need to return the model or criterion'
                )
            if init_model:
                print('Loading existing model parameters from ' + init_model)
                self.load(init_model)
            if self.use_cuda:
                # model parallel and plain CUDA placement are mutually exclusive;
                # data parallel wrapping is applied on top of either
                if self.model_parallel:
                    self.model = PipelineHelper().make_parallel(self.model)
                    logging.info("Model parallelized via PipelineHelper")
                else:
                    self.model.cuda()
                    logging.info("Model made just CUDA")
                if self.data_parallel:
                    self.model = torch.nn.DataParallel(self.model)
                self.criterion.cuda()
                logging.info("Criterion loaded as cuda")

        if shared:
            # We don't use get here because hasattr is used on optimizer later.
            if 'optimizer' in shared:
                self.optimizer = shared['optimizer']
        elif self._should_initialize_optimizer():
            # only parameters that require gradients are handed to the optimizer
            optim_params = [
                p for p in self.model.parameters() if p.requires_grad
            ]
            self.init_optim(optim_params)
            self.build_lr_scheduler()
コード例 #13
0
def create_agent_from_opt_file(opt: Opt):
    """
    Load agent options and module from file if opt file exists.

    Checks to see if file exists opt['model_file'] + ".opt"; if so, load up the
    options from the file and use that to create an agent, loading the model
    type from that file and overriding any options specified in that file when
    instantiating the agent.

    If that file does not exist, return None.

    :param opt: the current options; 'model_file' is required, 'override' is
        used when present.

    :return: the instantiated agent, or None when no opt file exists.

    :raises RuntimeError: if the saved opt belongs to a different model version
        than the loaded model class reports.
    """
    model_file = opt['model_file']
    optfile = model_file + '.opt'
    if not os.path.isfile(optfile):
        return None

    new_opt = Opt.load(optfile)
    # TODO we need a better way to say these options are never copied...
    if 'datapath' in new_opt:
        # never use the datapath from an opt dump
        del new_opt['datapath']
    if 'batchindex' in new_opt:
        # This saved variable can cause trouble if we switch to BS=1 at test time
        del new_opt['batchindex']
    # only override opts specified in 'override' dict
    if opt.get('override'):
        for k, v in opt['override'].items():
            if str(v) != str(new_opt.get(k, None)):
                print("[ warning: overriding opt['{}'] to {} ("
                      "previously: {} )]".format(k, v,
                                                 new_opt.get(k, None)))
            new_opt[k] = v

    model_class = load_agent_module(new_opt['model'])

    # check for model version
    if hasattr(model_class, 'model_version'):
        curr_version = new_opt.get('model_version', 0)
        if curr_version != model_class.model_version():
            model = new_opt['model']
            m = ('It looks like you are trying to load an older version of'
                 ' the selected model. Change your model argument to use '
                 'the old version from core/agents/legacy_agents: for '
                 'example: `-m legacy:{m}:{v}` or '
                 '`--model parlai.agents.legacy_agents.{m}.{m}_v{v}:{c}`')
            if '.' not in model:
                # give specific error message if it's easy
                raise RuntimeError(
                    m.format(m=model,
                             v=curr_version,
                             c=model_class.__name__))
            else:
                # otherwise generic one
                raise RuntimeError(
                    m.format(m='modelname', v=curr_version,
                             c='ModelAgent'))

    if hasattr(model_class, 'upgrade_opt'):
        new_opt = model_class.upgrade_opt(new_opt)

    # add model arguments to new_opt if they aren't in new_opt already
    for k, v in opt.items():
        if k not in new_opt:
            new_opt[k] = v
    new_opt['model_file'] = model_file

    # Remember the requested dict file up front so the warning below can always
    # name it; previously it was only bound when a dict file had been given,
    # and the warning crashed with UnboundLocalError otherwise.
    old_dict_file = new_opt.get('dict_file')
    if not old_dict_file or not os.path.isfile(old_dict_file):
        # nothing usable was specified: fall back to the dict next to the model
        new_opt['dict_file'] = model_file + '.dict'
    if not os.path.isfile(new_opt['dict_file']):
        warn_once(
            'WARNING: Neither the specified dict file ({}) nor the '
            '`model_file`.dict file ({}) exists, check to make sure either '
            'is correct. This may manifest as a shape mismatch later '
            'on.'.format(old_dict_file, new_opt['dict_file']))

    # if we want to load weights from --init-model, compare opts with
    # loaded ones
    compare_init_model_opts(opt, new_opt)
    return model_class(new_opt)
コード例 #14
0
ファイル: run.py プロジェクト: anoopkarnik/Nelly-Chatbot
    try:
        manager.start_task()
    except BaseException:
        raise
    finally:
        manager.shutdown()


def get_model_path(model_path):
    """
    Resolve ``model_path`` inside the ``data`` directory two levels above the cwd.

    NOTE: the lookup is done by changing directory, so on return the process
    working directory is the resolved model directory.

    :return: the absolute path of the model directory.
    """
    for _ in range(2):
        os.chdir('../')
    data_dir = os.path.join(os.getcwd(), 'data')
    os.chdir(data_dir)
    os.chdir(os.path.join(data_dir, model_path))
    return os.getcwd()


if __name__ == '__main__':
    # Earlier parser-based start-up, kept for reference:
    #     parser = setup_bst()
    #     opt = parser.parse_args(print_args=True, print_parser=parser, withparley=False)
    #     import core.agents as agents
    #     opt = agents.create_agent_from_opt_file()

    # The first command line argument names the model; its saved options are
    # looked up under <two levels above the cwd>/data/models/<name>.opt.
    project_root = os.path.abspath(os.path.join(os.getcwd(), '../..'))
    model_name = _sys.argv[1]
    opt_file = project_root + '/data/models' + '/' + model_name + '.opt'

    # nothing is started when the opt file is missing
    if os.path.isfile(opt_file):
        opt = options.load(opt_file)
        run(opt)
コード例 #15
0
    def __init__(self, opt: Opt, shared=None):
        """
        Set up the criterion, model, candidate sets and optimizer.

        :param opt: agent options; 'rank_candidates' is forced to True here and
            'rank_top_k' and 'gpu' are read.
        :param shared: state handed over from another instance. When given, no
            model or criterion is built here and the optimizer, if present, is
            taken from it.

        :raises AttributeError: if build_model() or build_criterion() returns None.
        """
        # Must call _get_init_model() first so that paths are updated if necessary
        # (e.g., a .dict file)
        init_model, is_finetune = self._get_init_model(opt, shared)
        opt['rank_candidates'] = True
        self._set_candidate_variables(opt)
        super().__init__(opt, shared)

        # whatever load() returned for the init model; empty when nothing was
        # loaded. Used below for the optimizer and lr scheduler.
        states: Dict[str, Any]
        if shared:
            states = {}
        else:
            # Note: we cannot change the type of metrics ahead of time, so you
            # should correctly initialize to floats or ints here
            self.criterion = self.build_criterion()
            self.model = self.build_model()

            if self.model is None or self.criterion is None:
                raise AttributeError(
                    'build_model() and build_criterion() need to return the model '
                    'or criterion')
            train_params = trainable_parameters(self.model)
            total_params = total_parameters(self.model)
            print(
                f"Total parameters: {total_params:,d} ({train_params:,d} trainable)"
            )

            # convert to half precision before any weights are loaded
            if self.fp16:
                self.model = self.model.half()
            if init_model:
                print('Loading existing model parameters from ' + init_model)
                states = self.load(init_model)
            else:
                states = {}

            if self.use_cuda:
                # model parallel and plain CUDA placement are mutually exclusive;
                # data parallel wrapping is applied on top of either
                if self.model_parallel:
                    self.model = PipelineHelper().make_parallel(self.model)
                else:
                    self.model.cuda()
                if self.data_parallel:
                    self.model = torch.nn.DataParallel(self.model)
                self.criterion.cuda()

        self.rank_top_k = opt.get('rank_top_k', -1)

        # Set fixed and vocab candidates if applicable
        self.set_fixed_candidates(shared)
        self.set_vocab_candidates(shared)

        if shared:
            # We don't use get here because hasattr is used on optimizer later.
            if 'optimizer' in shared:
                self.optimizer = shared['optimizer']
        elif self._should_initialize_optimizer():
            # only build an optimizer if we're training
            optim_params = [
                p for p in self.model.parameters() if p.requires_grad
            ]
            self.init_optim(optim_params, states.get('optimizer'),
                            states.get('optimizer_type'))
            self.build_lr_scheduler(states, hard_reset=is_finetune)

        # only the original (non-shared) instance wraps the model for
        # distributed training; no device ids are passed under model parallel
        if shared is None and is_distributed():
            device_ids = None if self.model_parallel else [self.opt['gpu']]
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model, device_ids=device_ids, broadcast_buffers=False)