Example #1
        def get_tl(tmpdir):
            final_opt = Opt({
                'task': 'integration_tests',
                'datatype': 'valid',
                'validation_max_exs': 30,
                'short_final_eval': True,
            })
            final_opt.save(os.path.join(tmpdir, "final_opt.opt"))

            opt = Opt({
                'task': 'integration_tests',
                'validation_max_exs': 10,
                'model': 'repeat_label',
                'model_file': os.path.join(tmpdir, 'model'),
                'short_final_eval': True,
                'num_epochs': 1.0,
                'final_extra_opt': str(os.path.join(tmpdir, "final_opt.opt")),
            })
            parser = tms.setup_args()
            parser.set_params(**opt)
            popt = parser.parse_args([])
            for k, v in opt.items():
                popt[k] = v
            return tms.TrainLoop(popt)
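A hedged usage sketch for the helper above: build the loop inside a temporary directory and run it. The tempdir helper from parlai.utils.testing and the 'exs' metric key are assumptions here, not part of the original example.

import parlai.utils.testing as testing_utils

with testing_utils.tempdir() as tmpdir:
    # get_tl() returns a TrainLoop; train() yields (valid_report, test_report)
    valid, test = get_tl(tmpdir).train()
    assert valid['exs'] > 0  # 'exs' (example count) is an assumed metric key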
Example #2
    def __init__(self, opt: Opt, model=None):
        try:
            # wandb is a very expensive thing to import. Wait until the
            # last second to import it.
            import wandb

        except ImportError:
            raise ImportError('Please run `pip install wandb`.')

        name = opt.get('wandb_name')
        project = opt.get('wandb_project') or datetime.datetime.now().strftime(
            '%Y-%m-%d-%H-%M')

        self.run = wandb.init(
            name=name,
            project=project,
            dir=os.path.dirname(opt['model_file']),
            notes=f"{opt['model_file']}",
            entity=opt.get('wandb_entity'),
            reinit=True,  # in case of preemption
        )
        # suppress wandb's output
        logging.getLogger("wandb").setLevel(logging.ERROR)
        for key, value in opt.items():
            if value is None or isinstance(value,
                                           (str, numbers.Number, tuple)):
                setattr(self.run.config, key, value)
        if model is not None:
            self.run.watch(model)
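A hedged construction sketch, assuming this __init__ belongs to a WandbLogger-style class (the class name is assumed here) and that wandb is installed and configured:

opt = Opt({
    'model_file': '/tmp/model',          # placeholder path
    'wandb_name': 'demo-run',
    'wandb_project': 'parlai-demo',
    'wandb_entity': None,
})
logger = WandbLogger(opt)                # starts (or resumes) a wandb run
logger.run.log({'loss': 0.5})            # plain wandb logging on the underlying run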
Example #3
def train_model(opt: Opt) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Run through a TrainLoop.

    If model_file is not in opt, then this helper will create a temporary
    directory to store the model, dict, etc.

    :return: (valid_results, test_results)
    :rtype: (dict, dict)
    """
    import parlai.scripts.train_model as tms

    with tempdir() as tmpdir:
        if 'model_file' not in opt:
            opt['model_file'] = os.path.join(tmpdir, 'model')
        if 'dict_file' not in opt:
            opt['dict_file'] = os.path.join(tmpdir, 'model.dict')
        parser = tms.setup_args()
        # needed at the very least to set the overrides.
        parser.set_params(**opt)
        parser.set_params(log_every_n_secs=10)
        popt = parser.parse_args([], print_args=False)
        # in some rare cases, like for instance if the model class also
        # overrides its default params, the params override will not
        # be taken into account.
        for k, v in opt.items():
            popt[k] = v
        tl = tms.TrainLoop(popt)
        valid, test = tl.train()

    return valid, test
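A minimal sketch of calling this helper; the task and model values follow ParlAI's integration-test conventions, and the 'exs' key is an assumption about the returned reports:

valid, test = train_model(
    Opt({'task': 'integration_tests', 'model': 'repeat_label', 'num_epochs': 1.0})
)
print(valid.get('exs'), test.get('exs'))  # number of evaluated examples, if reported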
Example #4
def compare_init_model_opts(opt: Opt, curr_opt: Opt):
    """
    Print loud warning when `init_model` opts differ from previous configuration.
    """
    if opt.get('init_model') is None:
        return
    opt['init_model'] = modelzoo_path(opt['datapath'], opt['init_model'])
    optfile = opt['init_model'] + '.opt'
    if not os.path.isfile(optfile):
        return
    init_model_opt = Opt.load(optfile)

    extra_opts = {}
    different_opts = {}
    exempt_opts = [
        'model_file',
        'dict_file',
        'override',
        'starttime',
        'init_model',
        'batchindex',
    ]

    # search through init model opts
    for k, v in init_model_opt.items():
        if (k not in exempt_opts and k in init_model_opt
                and init_model_opt[k] != curr_opt.get(k)):
            if isinstance(v, list):
                if init_model_opt[k] != list(curr_opt[k]):
                    different_opts[k] = ','.join([str(x) for x in v])
            else:
                different_opts[k] = v

    # search through opts to load
    for k, v in curr_opt.items():
        if k not in exempt_opts and k not in init_model_opt:
            if isinstance(v, list):
                extra_opts[k] = ','.join([str(x) for x in v])
            else:
                extra_opts[k] = v

    # print warnings
    extra_strs = ['{}: {}'.format(k, v) for k, v in extra_opts.items()]
    if extra_strs:
        print('\n' + '*' * 75)
        print('[ WARNING ] : your model is being loaded with opts that do not '
              'exist in the model you are initializing the weights with: '
              '{}'.format(','.join(extra_strs)))

    different_strs = [
        '--{} {}'.format(k, v).replace('_', '-')
        for k, v in different_opts.items()
    ]
    if different_strs:
        print('\n' + '*' * 75)
        print('[ WARNING ] : your model is being loaded with opts that differ '
              'from the model you are initializing the weights with. Add the '
              'following args to your run command to change this: \n'
              '\n{}'.format(' '.join(different_strs)))
        print('*' * 75)
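For orientation, a hedged sketch of how this check is wired into agent loading (see Examples #5 and #6); all opt contents below are placeholders:

# `opt` is the runtime opt, `new_opt` the Opt loaded from <model_file>.opt.
opt = Opt({'init_model': '/tmp/init_model', 'datapath': '/tmp/data'})
new_opt = Opt({'model': 'repeat_label'})
# Prints loud warnings if /tmp/init_model.opt exists and disagrees with new_opt;
# returns silently if init_model is unset or the .opt file is missing.
compare_init_model_opts(opt, new_opt)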
Example #5
def create_agent_from_opt_file(opt: Opt):
    """
    Load agent options and module from file if opt file exists.

    Checks whether the file opt['model_file'] + ".opt" exists; if so, loads the
    options from that file and uses them to create an agent, reading the model
    type from the file and overriding its options with any values given in
    opt['override'] when instantiating the agent.

    If that file does not exist, return None.
    """
    model_file = opt['model_file']
    optfile = model_file + '.opt'
    if os.path.isfile(optfile):
        new_opt = Opt.load(optfile)
        # TODO we need a better way to say these options are never copied...
        if 'datapath' in new_opt:
            # never use the datapath from an opt dump
            del new_opt['datapath']
        if 'batchindex' in new_opt:
            # This saved variable can cause trouble if we switch to BS=1 at test time
            del new_opt['batchindex']
        # only override opts specified in 'override' dict
        if opt.get('override'):
            for k, v in opt['override'].items():
                if k in new_opt and str(v) != str(new_opt.get(k)):
                    logging.warning(
                        f"overriding opt['{k}'] to {v} (previously: {new_opt.get(k)})"
                    )
                new_opt[k] = v

        model_class = load_agent_module(new_opt['model'])

        if hasattr(model_class, 'upgrade_opt'):
            new_opt = model_class.upgrade_opt(new_opt)

        # add model arguments to new_opt if they aren't in new_opt already
        for k, v in opt.items():
            if k not in new_opt:
                new_opt[k] = v
        new_opt['model_file'] = model_file
        if not new_opt.get('dict_file'):
            old_dict_file = None
            new_opt['dict_file'] = model_file + '.dict'
        elif not os.path.isfile(new_opt['dict_file']):
            old_dict_file = new_opt['dict_file']
            new_opt['dict_file'] = model_file + '.dict'
        if not os.path.isfile(new_opt['dict_file']):
            warn_once(
                'WARNING: Neither the specified dict file ({}) nor the '
                '`model_file`.dict file ({}) exists, check to make sure either '
                'is correct. This may manifest as a shape mismatch later '
                'on.'.format(old_dict_file, new_opt['dict_file']))

        # if we want to load weights from --init-model, compare opts with
        # loaded ones
        compare_init_model_opts(opt, new_opt)
        return model_class(new_opt)
    else:
        return None
Example #6
def create_agent_from_opt_file(opt: Opt):
    """
    Load agent options and module from file if opt file exists.

    Checks whether the file opt['model_file'] + ".opt" exists; if so, loads the
    options from that file and uses them to create an agent, reading the model
    type from the file and overriding its options with any values given in
    opt['override'] when instantiating the agent.

    If that file does not exist, return None.
    """
    model_file = opt['model_file']
    optfile = model_file + '.opt'

    if not PathManager.exists(optfile):
        return None

    opt_from_file = Opt.load(optfile)

    # delete args that we do not want to copy over when loading the model
    for arg in NOCOPY_ARGS:
        if arg in opt_from_file:
            del opt_from_file[arg]

    # only override opts specified in 'override' dict
    if opt.get('override'):
        for k, v in opt['override'].items():
            if k in opt_from_file and str(v) != str(opt_from_file.get(k)):
                logging.warning(
                    f'Overriding opt["{k}"] to {v} (previously: {opt_from_file.get(k)})'
                )
            opt_from_file[k] = v

    model_class = load_agent_module(opt_from_file['model'])

    if hasattr(model_class, 'upgrade_opt'):
        opt_from_file = model_class.upgrade_opt(opt_from_file)

    # add model arguments to opt_from_file if they aren't in opt_from_file already
    for k, v in opt.items():
        if k not in opt_from_file:
            opt_from_file[k] = v

    # update model file path to the one set by opt
    opt_from_file['model_file'] = model_file
    # update init model path to the one set by opt
    # NOTE: this step is necessary when for example the 'init_model' is
    # set by the Train Loop (as is the case when loading from checkpoint)
    if opt.get('init_model') is not None:
        opt_from_file['init_model'] = opt['init_model']

    # update dict file path
    if not opt_from_file.get('dict_file'):
        old_dict_file = None
        opt_from_file['dict_file'] = model_file + '.dict'
    elif opt_from_file.get('dict_file') and not PathManager.exists(
            opt_from_file['dict_file']):
        old_dict_file = opt_from_file['dict_file']
        opt_from_file['dict_file'] = model_file + '.dict'
    if not PathManager.exists(opt_from_file['dict_file']):
        warn_once(
            'WARNING: Neither the specified dict file ({}) nor the '
            '`model_file`.dict file ({}) exists, check to make sure either '
            'is correct. This may manifest as a shape mismatch later '
            'on.'.format(old_dict_file, opt_from_file['dict_file']))

    # if we want to load weights from --init-model, compare opts with
    # loaded ones
    compare_init_model_opts(opt, opt_from_file)
    return model_class(opt_from_file)
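A hedged usage sketch; the path is a placeholder, and None is returned when <model_file>.opt does not exist:

agent = create_agent_from_opt_file(Opt({'model_file': '/tmp/model'}))
if agent is None:
    # no /tmp/model.opt on disk; the caller would build a fresh agent instead
    print('opt file not found')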
Example #7
class ParlaiParser(argparse.ArgumentParser):
    """
    Provide an opt-producer and CLI argument parser.

    Pseudo-extension of ``argparse`` which sets a number of parameters
    for the ParlAI framework. More options can be added specific to other
    modules by passing this object and calling ``add_arg()`` or
    ``add_argument()`` on it.

    For example, see ``parlai.core.dict.DictionaryAgent.add_cmdline_args``.

    :param add_parlai_args:
        (default True) initializes the default arguments for ParlAI
        package, including the data download paths and task arguments.
    :param add_model_args:
        (default False) initializes the default arguments for loading
        models, including initializing arguments from that model.
    """
    def __init__(self,
                 add_parlai_args=True,
                 add_model_args=False,
                 description=None,
                 **kwargs):
        """
        Initialize the ParlAI argparser.
        """
        if 'formatter_class' not in kwargs:
            kwargs['formatter_class'] = CustomHelpFormatter

        super().__init__(
            description=description,
            allow_abbrev=False,
            conflict_handler='resolve',
            add_help=True,
            **kwargs,
        )
        self.register('action', 'helpall', _HelpAllAction)
        self.register('type', 'nonestr', str2none)
        self.register('type', 'bool', str2bool)
        self.register('type', 'floats', str2floats)
        self.register('type', 'multitask_weights', str2multitask_weights)
        self.register('type', 'class', str2class)
        self.parlai_home = os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
        os.environ['PARLAI_HOME'] = self.parlai_home

        self.add_arg = self.add_argument

        # remember which args were specified on the command line
        self.overridable = {}

        if add_parlai_args:
            self.add_parlai_args()
        if add_model_args:
            self.add_model_args()

    def add_parlai_data_path(self, argument_group=None):
        """
        Add --datapath CLI arg.
        """
        if argument_group is None:
            argument_group = self
        argument_group.add_argument(
            '-dp',
            '--datapath',
            default=None,
            help='path to datasets, defaults to {parlai_dir}/data',
        )

    def add_mturk_args(self):
        """
        Add standard mechanical turk arguments.
        """
        mturk = self.add_argument_group('Mechanical Turk')
        default_log_path = os.path.join(self.parlai_home, 'logs', 'mturk')
        mturk.add_argument(
            '--mturk-log-path',
            default=default_log_path,
            help='path to MTurk logs, defaults to {parlai_dir}/logs/mturk',
        )
        mturk.add_argument(
            '-t',
            '--task',
            help='MTurk task, e.g. "qa_data_collection" or "model_evaluator"',
        )
        mturk.add_argument(
            '-nc',
            '--num-conversations',
            default=1,
            type=int,
            help='number of conversations you want to create for this task',
        )
        mturk.add_argument(
            '--unique',
            dest='unique_worker',
            default=False,
            action='store_true',
            help='enforce that no worker can work on your task twice',
        )
        mturk.add_argument(
            '--max-hits-per-worker',
            dest='max_hits_per_worker',
            default=0,
            type=int,
            help=
            'Max number of hits each worker can perform during current group run',
        )
        mturk.add_argument(
            '--unique-qual-name',
            dest='unique_qual_name',
            default=None,
            type=str,
            help='qualification name to use for uniqueness between HITs',
        )
        mturk.add_argument(
            '-r',
            '--reward',
            default=0.05,
            type=float,
            help='reward for each worker for finishing the conversation, '
            'in US dollars',
        )
        mturk.add_argument(
            '--sandbox',
            dest='is_sandbox',
            action='store_true',
            help='submit the HITs to MTurk sandbox site',
        )
        mturk.add_argument(
            '--live',
            dest='is_sandbox',
            action='store_false',
            help='submit the HITs to MTurk live site',
        )
        mturk.add_argument(
            '--debug',
            dest='is_debug',
            action='store_true',
            help='print and log all server interactions and messages',
        )
        mturk.add_argument(
            '--verbose',
            dest='verbose',
            action='store_true',
            help='print all messages sent to and from Turkers',
        )
        mturk.add_argument(
            '--hard-block',
            dest='hard_block',
            action='store_true',
            default=False,
            help='Hard block disconnecting Turkers from all of your HITs',
        )
        mturk.add_argument(
            '--log-level',
            dest='log_level',
            type=int,
            default=20,
            help='importance level for what to put into the logs. the lower '
            'the level the more that gets logged. values are 0-50',
        )
        mturk.add_argument(
            '--disconnect-qualification',
            dest='disconnect_qualification',
            default=None,
            help='Qualification to use for soft blocking users for '
            'disconnects. By default '
            'turkers are never blocked, though setting this will allow '
            'you to filter out turkers that have disconnected too many '
            'times on previous HITs where this qualification was set.',
        )
        mturk.add_argument(
            '--block-qualification',
            dest='block_qualification',
            default=None,
            help='Qualification to use for soft blocking users. This '
            'qualification is granted whenever soft_block_worker is '
            'called, and can thus be used to filter workers out from a '
            'single task or group of tasks by noted performance.',
        )
        mturk.add_argument(
            '--count-complete',
            dest='count_complete',
            default=False,
            action='store_true',
            help='continue until the requested number of conversations are '
            'completed rather than attempted',
        )
        mturk.add_argument(
            '--allowed-conversations',
            dest='allowed_conversations',
            default=0,
            type=int,
            help='number of concurrent conversations that one mturk worker '
            'is able to be involved in, 0 is unlimited',
        )
        mturk.add_argument(
            '--max-connections',
            dest='max_connections',
            default=30,
            type=int,
            help='number of HITs that can be launched at the same time, 0 is '
            'unlimited.',
        )
        mturk.add_argument(
            '--min-messages',
            dest='min_messages',
            default=0,
            type=int,
            help='number of messages required to be sent by MTurk agent when '
            'considering whether to approve a HIT in the event of a '
            'partner disconnect. I.e. if the number of messages '
            'exceeds this number, the turker can submit the HIT.',
        )
        mturk.add_argument(
            '--local',
            dest='local',
            default=False,
            action='store_true',
            help='Run the server locally on this server rather than setting up'
            ' a heroku server.',
        )
        mturk.add_argument(
            '--hobby',
            dest='hobby',
            default=False,
            action='store_true',
            help='Run the heroku server on the hobby tier.',
        )
        mturk.add_argument(
            '--max-time',
            dest='max_time',
            default=0,
            type=int,
            help='Maximum number of seconds per day that a worker is allowed '
            'to work on this assignment',
        )
        mturk.add_argument(
            '--max-time-qual',
            dest='max_time_qual',
            default=None,
            help='Qualification to use to share the maximum time requirement '
            'with other runs from other machines.',
        )
        mturk.add_argument(
            '--heroku-team',
            dest='heroku_team',
            default=None,
            help='Specify Heroku team name to use for launching Dynos.',
        )
        mturk.add_argument(
            '--tmp-dir',
            dest='tmp_dir',
            default=None,
            help='Specify location to use for scratch builds and such.',
        )

        # it helps to indicate to agents that they're in interactive mode, and
        # can avoid some otherwise troublesome behavior (not having candidates,
        # sharing self.replies, etc).
        mturk.set_defaults(interactive_mode=True)

        mturk.set_defaults(is_sandbox=True)
        mturk.set_defaults(is_debug=False)
        mturk.set_defaults(verbose=False)

    def add_chatservice_args(self):
        """
        Arguments for all chat services.
        """
        args = self.add_argument_group('Chat Services')
        args.add_argument(
            '--debug',
            dest='is_debug',
            action='store_true',
            help='print and log all server interactions and messages',
        )
        args.add_argument(
            '--config-path',
            default=None,
            type=str,
            help='/path/to/config/file for a given task.',
        )
        args.add_argument(
            '--password',
            dest='password',
            type=str,
            default=None,
            help='Require a password for entry to the bot',
        )

    def add_websockets_args(self):
        """
        Add websocket arguments.
        """
        self.add_chatservice_args()
        websockets = self.add_argument_group('Websockets')
        websockets.add_argument('--port',
                                default=35496,
                                type=int,
                                help='Port to run the websocket handler')

    def add_messenger_args(self):
        """
        Add Facebook Messenger arguments.
        """
        self.add_chatservice_args()
        messenger = self.add_argument_group('Facebook Messenger')
        messenger.add_argument(
            '--verbose',
            dest='verbose',
            action='store_true',
            help='print all messages sent to and from Turkers',
        )
        messenger.add_argument(
            '--log-level',
            dest='log_level',
            type=int,
            default=20,
            help='importance level for what to put into the logs. the lower '
            'the level the more that gets logged. values are 0-50',
        )
        messenger.add_argument(
            '--force-page-token',
            dest='force_page_token',
            action='store_true',
            help='override the page token stored in the cache for a new one',
        )
        messenger.add_argument(
            '--bypass-server-setup',
            dest='bypass_server_setup',
            action='store_true',
            default=False,
            help='should bypass traditional server and socket setup',
        )
        messenger.add_argument(
            '--local',
            dest='local',
            action='store_true',
            default=False,
            help='Run the server locally on this server rather than setting up'
            ' a heroku server.',
        )

        messenger.set_defaults(is_debug=False)
        messenger.set_defaults(verbose=False)

    def add_parlai_args(self, args=None):
        """
        Add common ParlAI args across all scripts.
        """
        self.add_argument(
            '--helpall',
            action='helpall',
            help='Show usage, including advanced arguments.',
        )
        parlai = self.add_argument_group('Main ParlAI Arguments')
        parlai.add_argument(
            '-o',
            '--init-opt',
            default=None,
            help='Path to json file of options. '
            'Note: Further Command-line arguments override file-based options.',
        )
        parlai.add_argument(
            '-t',
            '--task',
            help='ParlAI task(s), e.g. "babi:Task1" or "babi,cbt"')
        parlai.add_argument(
            '--download-path',
            default=None,
            hidden=True,
            help='path for non-data dependencies to store any needed files.'
            'defaults to {parlai_dir}/downloads',
        )
        parlai.add_argument(
            '--loglevel',
            default='info',
            hidden=True,
            choices=logging.get_all_levels(),
            help='Logging level',
        )
        parlai.add_argument(
            '-dt',
            '--datatype',
            metavar='DATATYPE',
            default='train',
            choices=[
                'train',
                'train:stream',
                'train:ordered',
                'train:ordered:stream',
                'train:stream:ordered',
                'train:evalmode',
                'train:evalmode:stream',
                'train:evalmode:ordered',
                'train:evalmode:ordered:stream',
                'train:evalmode:stream:ordered',
                'valid',
                'valid:stream',
                'test',
                'test:stream',
            ],
            help='choose from: train, train:ordered, valid, test. to stream '
            'data add ":stream" to any option (e.g., train:stream). '
            'by default train is random with replacement, '
            'valid is ordered, test is ordered.',
        )
        parlai.add_argument(
            '-im',
            '--image-mode',
            default='raw',
            type=str,
            help='image preprocessor to use. default is "raw". set to "none" '
            'to skip image loading.',
            hidden=True,
        )
        parlai.add_argument(
            '-nt',
            '--numthreads',
            default=1,
            type=int,
            help='number of threads. Used for hogwild if batchsize is 1, else '
            'for number of threads in threadpool loading.',
        )
        parlai.add_argument(
            '--hide-labels',
            default=False,
            type='bool',
            hidden=True,
            help='default (False) moves labels in valid and test sets to the '
            'eval_labels field. If True, they are hidden completely.',
        )
        parlai.add_argument(
            '-mtw',
            '--multitask-weights',
            type='multitask_weights',
            default=[1],
            help=
            ('list of floats, one for each task, specifying '
             'the probability of drawing the task in multitask case. You may also '
             'provide "stochastic" to simulate simple concatenation.'),
            hidden=True,
        )
        parlai.add_argument(
            '-bs',
            '--batchsize',
            default=1,
            type=int,
            help='batch size for minibatch training schemes',
        )
        parlai.add_argument(
            '-dynb',
            '--dynamic-batching',
            default=None,
            type='nonestr',
            choices={None, 'full', 'batchsort'},
            help='Use dynamic batching',
        )
        self.add_parlai_data_path(parlai)

    def add_distributed_training_args(self):
        """
        Add CLI args for distributed training.
        """
        grp = self.add_argument_group('Distributed Training')
        grp.add_argument('--distributed-world-size',
                         type=int,
                         help='Number of workers.')
        grp.add_argument(
            '--verbose',
            type='bool',
            default=False,
            help='All workers print output.',
            hidden=True,
        )
        return grp

    def add_model_args(self):
        """
        Add arguments related to models such as model files.
        """
        model_args = self.add_argument_group('ParlAI Model Arguments')
        model_args.add_argument(
            '-m',
            '--model',
            default=None,
            help='the model class name. can match parlai/agents/<model> for '
            'agents in that directory, or can provide a fully specified '
            'module for `from X import Y` via `-m X:Y` '
            '(e.g. `-m parlai.agents.seq2seq.seq2seq:Seq2SeqAgent`)',
        )
        model_args.add_argument(
            '-mf',
            '--model-file',
            default=None,
            help='model file name for loading and saving models',
        )
        model_args.add_argument(
            '-im',
            '--init-model',
            default=None,
            type=str,
            help='load model weights and dict from this file',
        )
        model_args.add_argument('--dict-class',
                                hidden=True,
                                help='the class of the dictionary agent uses')

    def add_model_subargs(self, model):
        """
        Add arguments specific to a particular model.
        """
        agent = load_agent_module(model)
        try:
            if hasattr(agent, 'add_cmdline_args'):
                agent.add_cmdline_args(self)
        except argparse.ArgumentError:
            # already added
            pass
        try:
            if hasattr(agent, 'dictionary_class'):
                s = class2str(agent.dictionary_class())
                self.set_defaults(dict_class=s)
        except argparse.ArgumentError:
            # already added
            pass

    def add_task_args(self, task):
        """
        Add arguments specific to the specified task.
        """
        for t in ids_to_tasks(task).split(','):
            agent = load_teacher_module(t)
            try:
                if hasattr(agent, 'add_cmdline_args'):
                    agent.add_cmdline_args(self)
            except argparse.ArgumentError:
                # already added
                pass

    def add_world_args(self, task, interactive_task, selfchat_task):
        """
        Add arguments specific to the world.
        """
        world_class = load_world_module(task,
                                        interactive_task=interactive_task,
                                        selfchat_task=selfchat_task)
        if world_class is not None and hasattr(world_class,
                                               'add_cmdline_args'):
            try:
                world_class.add_cmdline_args(self)
            except argparse.ArgumentError:
                # already added
                pass

    def add_image_args(self, image_mode):
        """
        Add additional arguments for handling images.
        """
        try:
            parlai = self.add_argument_group(
                'ParlAI Image Preprocessing Arguments')
            parlai.add_argument(
                '--image-size',
                type=int,
                default=256,
                help='resizing dimension for images',
                hidden=True,
            )
            parlai.add_argument(
                '--image-cropsize',
                type=int,
                default=224,
                help='crop dimension for images',
                hidden=True,
            )
        except argparse.ArgumentError:
            # already added
            pass

    def add_extra_args(self, args=None):
        """
        Add more args depending on how known args are set.
        """
        parsed = vars(self.parse_known_args(args, nohelp=True)[0])
        # Also load extra args options if a file is given.
        if parsed.get('init_opt') is not None:
            self._load_known_opts(parsed.get('init_opt'), parsed)
        parsed = self._infer_datapath(parsed)

        # find which image mode specified if any, and add additional arguments
        image_mode = parsed.get('image_mode', None)
        if image_mode is not None and image_mode != 'no_image_model':
            self.add_image_args(image_mode)

        # find which task specified if any, and add its specific arguments
        task = parsed.get('task', None)
        if task is not None:
            self.add_task_args(task)
        evaltask = parsed.get('evaltask', None)
        if evaltask is not None:
            self.add_task_args(evaltask)

        # find which model specified if any, and add its specific arguments
        model = get_model_name(parsed)
        if model is not None:
            self.add_model_subargs(model)

        # add world args, if we know a priori which world is being used
        if task is not None:
            self.add_world_args(
                task,
                parsed.get('interactive_task', False),
                parsed.get('selfchat_task', False),
            )

        # reset parser-level defaults over any model-level defaults
        try:
            self.set_defaults(**self._defaults)
        except AttributeError:
            raise RuntimeError('Please file an issue on github that argparse '
                               'got an attribute error when parsing.')

    def parse_known_args(self, args=None, namespace=None, nohelp=False):
        """
        Parse known args to ignore help flag.
        """
        if args is None:
            # args default to the system args
            args = _sys.argv[1:]
        args = fix_underscores(args)

        if nohelp:
            # ignore help
            args = [
                a for a in args
                if a != '-h' and a != '--help' and a != '--helpall'
            ]
        return super().parse_known_args(args, namespace)

    def _load_known_opts(self, optfile, parsed):
        """
        Pull in CLI args for proper models/tasks/etc.

        Called before args are parsed; ``_load_opts`` is used for actually overriding
        opts after they are parsed.
        """
        new_opt = Opt.load(optfile)
        for key, value in new_opt.items():
            # existing command line parameters take priority.
            if key not in parsed or parsed[key] is None:
                parsed[key] = value

    def _load_opts(self, opt):
        optfile = opt.get('init_opt')
        new_opt = Opt.load(optfile)
        for key, value in new_opt.items():
            # existing command line parameters take priority.
            if key not in opt:
                raise RuntimeError(
                    'Trying to set opt from file that does not exist: ' +
                    str(key))
            if key not in opt['override']:
                opt[key] = value
                opt['override'][key] = value

    def _infer_datapath(self, opt):
        """
        Set the value for opt['datapath'] and opt['download_path'].

        Sets the value for opt['datapath'] and opt['download_path'], correctly
        respecting environmental variables and the default.
        """
        # set environment variables
        # Priority for setting the datapath (same applies for download_path):
        # --datapath -> os.environ['PARLAI_DATAPATH'] -> <self.parlai_home>/data
        if opt.get('download_path'):
            os.environ['PARLAI_DOWNPATH'] = opt['download_path']
        elif os.environ.get('PARLAI_DOWNPATH') is None:
            os.environ['PARLAI_DOWNPATH'] = os.path.join(
                self.parlai_home, 'downloads')
        if opt.get('datapath'):
            os.environ['PARLAI_DATAPATH'] = opt['datapath']
        elif os.environ.get('PARLAI_DATAPATH') is None:
            os.environ['PARLAI_DATAPATH'] = os.path.join(
                self.parlai_home, 'data')

        opt['download_path'] = os.environ['PARLAI_DOWNPATH']
        opt['datapath'] = os.environ['PARLAI_DATAPATH']

        return opt

    def _process_args_to_opts(self,
                              args_that_override: Optional[List[str]] = None):
        self.opt = Opt(vars(self.args))

        # custom post-parsing
        self.opt['parlai_home'] = self.parlai_home
        self.opt = self._infer_datapath(self.opt)

        # set all arguments specified in command line as overridable
        option_strings_dict = {}
        store_true = []
        store_false = []
        for group in self._action_groups:
            for a in group._group_actions:
                if hasattr(a, 'option_strings'):
                    for option in a.option_strings:
                        option_strings_dict[option] = a.dest
                        if '_StoreTrueAction' in str(type(a)):
                            store_true.append(option)
                        elif '_StoreFalseAction' in str(type(a)):
                            store_false.append(option)

        if args_that_override is None:
            args_that_override = _sys.argv[1:]

        for i in range(len(args_that_override)):
            if args_that_override[i] in option_strings_dict:
                if args_that_override[i] in store_true:
                    self.overridable[option_strings_dict[
                        args_that_override[i]]] = True
                elif args_that_override[i] in store_false:
                    self.overridable[option_strings_dict[
                        args_that_override[i]]] = False
                elif (i < len(args_that_override) - 1 and
                      args_that_override[i + 1] not in option_strings_dict):
                    key = option_strings_dict[args_that_override[i]]
                    self.overridable[key] = self.opt[key]
        self.opt['override'] = self.overridable

        # load opts if a file is provided.
        if self.opt.get('init_opt', None) is not None:
            self._load_opts(self.opt)

        # map filenames that start with 'zoo:' to point to the model zoo dir
        options_to_change = {
            'model_file', 'dict_file', 'bpe_vocab', 'bpe_merge'
        }
        for each_key in options_to_change:
            if self.opt.get(each_key) is not None:
                self.opt[each_key] = modelzoo_path(self.opt.get('datapath'),
                                                   self.opt[each_key])
            if self.opt['override'].get(each_key) is not None:
                # also check override
                self.opt['override'][each_key] = modelzoo_path(
                    self.opt.get('datapath'), self.opt['override'][each_key])

        # add start time of an experiment
        self.opt['starttime'] = datetime.datetime.today().strftime(
            '%b%d_%H-%M')

    def parse_and_process_known_args(self, args=None):
        """
        Parse provided arguments and return parlai opts and unknown arg list.

        Runs the same arg->opt parsing that parse_args does, but doesn't throw an error
        if the args being parsed include additional command line arguments that parlai
        doesn't know what to do with.
        """
        self.args, unknowns = super().parse_known_args(args=args)
        self._process_args_to_opts(args)
        return self.opt, unknowns

    def parse_args(self, args=None, namespace=None, print_args=True):
        """
        Parse the provided arguments and returns a dictionary of the ``args``.

        We specifically remove items with ``None`` as values in order to support the
        style ``opt.get(key, default)``, which would otherwise return ``None``.
        """
        self.add_extra_args(args)
        self.args = super().parse_args(args=args)

        self._process_args_to_opts(args)

        if print_args:
            self.print_args()
            if GIT_AVAILABLE:
                print_git_commit()
            print_announcements(self.opt)

        logging.set_log_level(self.opt.get('loglevel', 'info').upper())

        return self.opt

    def _value2argstr(self, value) -> str:
        """
        Reverse-parse an opt value into one interpretable by argparse.
        """
        if isinstance(value, (list, tuple)):
            return ",".join(str(v) for v in value)
        else:
            return str(value)

    def _kwargs_to_str_args(self, **kwargs):
        """
        Attempt to map from python-code kwargs into CLI args.

        e.g. model_file -> --model-file.

        Works with short options too, like t="convai2".
        """

        # we have to do this large block of repetitive code twice, the first
        # round is basically just to become aware of anything that would have
        # been added by add_extra_args
        kwname_to_action = {}
        for action in self._actions:
            if action.dest == 'help':
                # no help allowed
                continue
            for option_string in action.option_strings:
                kwname = option_string.lstrip('-').replace('-', '_')
                assert (kwname not in kwname_to_action) or (
                    kwname_to_action[kwname] is action
                ), f"No duplicate names! ({kwname}, {kwname_to_action[kwname]}, {action})"
                kwname_to_action[kwname] = action

        string_args = []
        for kwname, value in kwargs.items():
            if kwname not in kwname_to_action:
                # best guess: we need to delay it and hope this option gets
                # registered by the add_extra_args call below
                continue
            action = kwname_to_action[kwname]
            last_option_string = action.option_strings[-1]
            if isinstance(action, argparse._StoreTrueAction):
                if bool(value):
                    string_args.append(last_option_string)
            elif isinstance(action,
                            argparse._StoreAction) and action.nargs is None:
                string_args.append(last_option_string)
                string_args.append(self._value2argstr(value))
            elif isinstance(action,
                            argparse._StoreAction) and action.nargs in '*+':
                string_args.append(last_option_string)
                string_args.extend([self._value2argstr(v) for v in value])
            else:
                raise TypeError(f"Don't know what to do with {action}")

        # become aware of any extra args that might be specified if the user
        # provides something like model="transformer/generator".
        self.add_extra_args(string_args)

        # do it again, this time knowing about ALL args.
        kwname_to_action = {}
        for action in self._actions:
            if action.dest == 'help':
                # no help allowed
                continue
            for option_string in action.option_strings:
                kwname = option_string.lstrip('-').replace('-', '_')
                assert (kwname not in kwname_to_action) or (
                    kwname_to_action[kwname] is action
                ), f"No duplicate names! ({kwname}, {kwname_to_action[kwname]}, {action})"
                kwname_to_action[kwname] = action

        string_args = []
        for kwname, value in kwargs.items():
            # note we don't have the if kwname not in kwname_to_action here.
            # it MUST appear, or else we legitimately should be throwing a KeyError
            # because user has provided an unspecified option
            action = kwname_to_action[kwname]
            last_option_string = action.option_strings[-1]
            if isinstance(action, argparse._StoreTrueAction):
                if bool(value):
                    string_args.append(last_option_string)
            elif isinstance(action,
                            argparse._StoreAction) and action.nargs is None:
                string_args.append(last_option_string)
                string_args.append(self._value2argstr(value))
            elif isinstance(action,
                            argparse._StoreAction) and action.nargs in '*+':
                string_args.append(last_option_string)
                # Special case: Labels
                string_args.extend([str(v) for v in value])
            else:
                raise TypeError(f"Don't know what to do with {action}")

        return string_args

    def parse_kwargs(self, **kwargs):
        """
        Parse kwargs, with type checking etc.
        """

        # hack: capture any error messages without raising a SystemExit
        def _captured_error(msg):
            raise ValueError(msg)

        old_error = self.error
        self.error = _captured_error
        try:
            string_args = self._kwargs_to_str_args(**kwargs)
            return self.parse_args(args=string_args, print_args=False)
        finally:
            self.error = old_error

    def print_args(self):
        """
        Print out all the arguments in this parser.
        """
        if not self.opt:
            self.parse_args(print_args=False)
        values = {}
        for key, value in self.opt.items():
            values[str(key)] = str(value)
        for group in self._action_groups:
            group_dict = {
                a.dest: getattr(self.args, a.dest, None)
                for a in group._group_actions
            }
            namespace = argparse.Namespace(**group_dict)
            count = 0
            for key in sorted(namespace.__dict__):
                if key in values:
                    if count == 0:
                        print('[ ' + group.title + ': ] ')
                    count += 1
                    print('[  ' + key + ': ' + values[key] + ' ]')

    def set_params(self, **kwargs):
        """
        Set overridable kwargs.
        """
        self.set_defaults(**kwargs)
        for k, v in kwargs.items():
            self.overridable[k] = v

    def _unsuppress_hidden(self):
        for action in self._actions:
            if hasattr(action, 'real_help'):
                action.help = action.real_help

    def _handle_custom_options(self, kwargs):
        """
        Handle custom parlai options.

        Includes hidden, recommended. Future may include no_save and no_override.
        """
        action_attr = {}
        if 'recommended' in kwargs:
            rec = kwargs.pop('recommended')
            action_attr['recommended'] = rec
        action_attr['hidden'] = kwargs.get('hidden', False)
        action_attr['real_help'] = kwargs.get('help', None)
        if 'hidden' in kwargs:
            if kwargs.pop('hidden'):
                kwargs['help'] = argparse.SUPPRESS

        if 'type' in kwargs and kwargs['type'] is bool:
            # common error, we really want simple form
            kwargs['type'] = 'bool'
        return kwargs, action_attr

    def add_argument(self, *args, **kwargs):
        """
        Override to convert underscores to hyphens for consistency.
        """
        kwargs, newattr = self._handle_custom_options(kwargs)
        action = super().add_argument(*fix_underscores(args), **kwargs)
        for k, v in newattr.items():
            setattr(action, k, v)
        return action

    def add_argument_group(self, *args, **kwargs):
        """
        Override to make arg groups also convert underscores to hyphens.
        """
        arg_group = super().add_argument_group(*args, **kwargs)
        original_add_arg = arg_group.add_argument

        def ag_add_argument(*args, **kwargs):
            kwargs, newattr = self._handle_custom_options(kwargs)
            action = original_add_arg(*fix_underscores(args), **kwargs)
            for k, v in newattr.items():
                setattr(action, k, v)
            return action

        arg_group.add_argument = ag_add_argument  # override _ => -
        arg_group.add_argument_group = self.add_argument_group
        return arg_group

    def error(self, message):
        """
        Override to print custom error message.
        """
        self.print_help()
        _sys.stderr.write('\nParse Error: %s\n' % message)
        _sys.exit(2)
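A hedged sketch of driving the parser from Python rather than the command line, assuming ParlaiParser is importable from parlai.core.params; the task/model values are illustrative:

from parlai.core.params import ParlaiParser

parser = ParlaiParser(add_parlai_args=True, add_model_args=True)
# parse_kwargs maps python kwargs onto CLI flags (e.g. model_file -> --model-file)
opt = parser.parse_kwargs(task='integration_tests', model='repeat_label')
assert opt['model'] == 'repeat_label'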
Example #8
class ParlaiParser(argparse.ArgumentParser):
    """
    Provide an opt-producer and CLI argument parser.

    Pseudo-extension of ``argparse`` which sets a number of parameters
    for the ParlAI framework. More options can be added specific to other
    modules by passing this object and calling ``add_arg()`` or
    ``add_argument()`` on it.

    For example, see ``parlai.core.dict.DictionaryAgent.add_cmdline_args``.

    :param add_parlai_args:
        (default True) initializes the default arguments for ParlAI
        package, including the data download paths and task arguments.
    :param add_model_args:
        (default False) initializes the default arguments for loading
        models, including initializing arguments from that model.
    """
    def __init__(self,
                 add_parlai_args=True,
                 add_model_args=False,
                 description='ParlAI parser'):
        """
        Initialize the ParlAI argparser.
        """
        super().__init__(
            description=description,
            allow_abbrev=False,
            conflict_handler='resolve',
            formatter_class=CustomHelpFormatter,
        )
        self.register('type', 'bool', str2bool)
        self.register('type', 'floats', str2floats)
        self.register('type', 'class', str2class)
        self.parlai_home = os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
        os.environ['PARLAI_HOME'] = self.parlai_home

        self.add_arg = self.add_argument

        # remember which args were specified on the command line
        self.overridable = {}

        if add_parlai_args:
            self.add_parlai_args()
        if add_model_args:
            self.add_model_args()

    def add_parlai_data_path(self, argument_group=None):
        """
        Add --datapath CLI arg.
        """
        if argument_group is None:
            argument_group = self
        argument_group.add_argument(
            '-dp',
            '--datapath',
            default=None,
            help='path to datasets, defaults to {parlai_dir}/data',
        )

    def add_mturk_args(self):
        """
        Add standard mechanical turk arguments.
        """
        mturk = self.add_argument_group('Mechanical Turk')
        default_log_path = os.path.join(self.parlai_home, 'logs', 'mturk')
        mturk.add_argument(
            '--mturk-log-path',
            default=default_log_path,
            help='path to MTurk logs, defaults to {parlai_dir}/logs/mturk',
        )
        mturk.add_argument(
            '-t',
            '--task',
            help='MTurk task, e.g. "qa_data_collection" or "model_evaluator"',
        )
        mturk.add_argument(
            '-nc',
            '--num-conversations',
            default=1,
            type=int,
            help='number of conversations you want to create for this task',
        )
        mturk.add_argument(
            '--unique',
            dest='unique_worker',
            default=False,
            action='store_true',
            help='enforce that no worker can work on your task twice',
        )
        mturk.add_argument(
            '--max-hits-per-worker',
            dest='max_hits_per_worker',
            default=0,
            type=int,
            help=
            'Max number of hits each worker can perform during current group run',
        )
        mturk.add_argument(
            '--unique-qual-name',
            dest='unique_qual_name',
            default=None,
            type=str,
            help='qualification name to use for uniqueness between HITs',
        )
        mturk.add_argument(
            '-r',
            '--reward',
            default=0.05,
            type=float,
            help='reward for each worker for finishing the conversation, '
            'in US dollars',
        )
        mturk.add_argument(
            '--sandbox',
            dest='is_sandbox',
            action='store_true',
            help='submit the HITs to MTurk sandbox site',
        )
        mturk.add_argument(
            '--live',
            dest='is_sandbox',
            action='store_false',
            help='submit the HITs to MTurk live site',
        )
        mturk.add_argument(
            '--debug',
            dest='is_debug',
            action='store_true',
            help='print and log all server interactions and messages',
        )
        mturk.add_argument(
            '--verbose',
            dest='verbose',
            action='store_true',
            help='print all messages sent to and from Turkers',
        )
        mturk.add_argument(
            '--hard-block',
            dest='hard_block',
            action='store_true',
            default=False,
            help='Hard block disconnecting Turkers from all of your HITs',
        )
        mturk.add_argument(
            '--log-level',
            dest='log_level',
            type=int,
            default=20,
            help='importance level for what to put into the logs. the lower '
            'the level the more that gets logged. values are 0-50',
        )
        mturk.add_argument(
            '--disconnect-qualification',
            dest='disconnect_qualification',
            default=None,
            help='Qualification to use for soft blocking users for '
            'disconnects. By default '
            'turkers are never blocked, though setting this will allow '
            'you to filter out turkers that have disconnected too many '
            'times on previous HITs where this qualification was set.',
        )
        mturk.add_argument(
            '--block-qualification',
            dest='block_qualification',
            default=None,
            help='Qualification to use for soft blocking users. This '
            'qualification is granted whenever soft_block_worker is '
            'called, and can thus be used to filter workers out from a '
            'single task or group of tasks by noted performance.',
        )
        mturk.add_argument(
            '--count-complete',
            dest='count_complete',
            default=False,
            action='store_true',
            help='continue until the requested number of conversations are '
            'completed rather than attempted',
        )
        mturk.add_argument(
            '--allowed-conversations',
            dest='allowed_conversations',
            default=0,
            type=int,
            help='number of concurrent conversations that one mturk worker '
            'is able to be involved in, 0 is unlimited',
        )
        mturk.add_argument(
            '--max-connections',
            dest='max_connections',
            default=30,
            type=int,
            help='number of HITs that can be launched at the same time, 0 is '
            'unlimited.',
        )
        mturk.add_argument(
            '--min-messages',
            dest='min_messages',
            default=0,
            type=int,
            help='number of messages required to be sent by MTurk agent when '
            'considering whether to approve a HIT in the event of a '
            'partner disconnect. I.e. if the number of messages '
            'exceeds this number, the turker can submit the HIT.',
        )
        mturk.add_argument(
            '--local',
            dest='local',
            default=False,
            action='store_true',
            help='Run the server locally on this server rather than setting up'
            ' a heroku server.',
        )
        mturk.add_argument(
            '--hobby',
            dest='hobby',
            default=False,
            action='store_true',
            help='Run the heroku server on the hobby tier.',
        )
        mturk.add_argument(
            '--max-time',
            dest='max_time',
            default=0,
            type=int,
            help='Maximum number of seconds per day that a worker is allowed '
            'to work on this assignment',
        )
        mturk.add_argument(
            '--max-time-qual',
            dest='max_time_qual',
            default=None,
            help='Qualification to use to share the maximum time requirement '
            'with other runs from other machines.',
        )
        mturk.add_argument(
            '--heroku-team',
            dest='heroku_team',
            default=None,
            help='Specify Heroku team name to use for launching Dynos.',
        )
        mturk.add_argument(
            '--tmp-dir',
            dest='tmp_dir',
            default=None,
            help='Specify location to use for scratch builds and such.',
        )

        # it helps to indicate to agents that they're in interactive mode, and
        # can avoid some otherwise troublesome behavior (not having candidates,
        # sharing self.replies, etc).
        mturk.set_defaults(interactive_mode=True)

        mturk.set_defaults(is_sandbox=True)
        mturk.set_defaults(is_debug=False)
        mturk.set_defaults(verbose=False)

    def add_chatservice_args(self):
        """
        Arguments for all chat services.
        """
        args = self.add_argument_group('Chat Services')
        args.add_argument(
            '--debug',
            dest='is_debug',
            action='store_true',
            help='print and log all server interactions and messages',
        )
        args.add_argument(
            '--config-path',
            default=None,
            type=str,
            help='/path/to/config/file for a given task.',
        )
        args.add_argument(
            '--password',
            dest='password',
            type=str,
            default=None,
            help='Require a password for entry to the bot',
        )

    def add_websockets_args(self):
        """
        Add websocket arguments.
        """
        self.add_chatservice_args()
        websockets = self.add_argument_group('Websockets')
        websockets.add_argument('--port',
                                default=35496,
                                type=int,
                                help='Port to run the websocket handler')

    def add_messenger_args(self):
        """
        Add Facebook Messenger arguments.
        """
        self.add_chatservice_args()
        messenger = self.add_argument_group('Facebook Messenger')
        messenger.add_argument(
            '--verbose',
            dest='verbose',
            action='store_true',
            help='print all messages sent to and from users',
        )
        messenger.add_argument(
            '--log-level',
            dest='log_level',
            type=int,
            default=20,
            help='importance level for what to put into the logs. The lower '
            'the level, the more that gets logged. Values are 0-50.',
        )
        messenger.add_argument(
            '--force-page-token',
            dest='force_page_token',
            action='store_true',
            help='override the page token stored in the cache for a new one',
        )
        messenger.add_argument(
            '--bypass-server-setup',
            dest='bypass_server_setup',
            action='store_true',
            default=False,
            help='should bypass traditional server and socket setup',
        )
        messenger.add_argument(
            '--local',
            dest='local',
            action='store_true',
            default=False,
            help='Run the server locally on this machine rather than setting '
            'up a heroku server.',
        )

        messenger.set_defaults(is_debug=False)
        messenger.set_defaults(verbose=False)

    def add_parlai_args(self, args=None):
        """
        Add common ParlAI args across all scripts.
        """
        parlai = self.add_argument_group('Main ParlAI Arguments')
        parlai.add_argument(
            '-o',
            '--init-opt',
            default=None,
            help='Path to json file of options. '
            'Note: further command-line arguments override file-based options.',
        )
        parlai.add_argument(
            '-v',
            '--show-advanced-args',
            action='store_true',
            help='Show hidden command line options (advanced users only)',
        )
        parlai.add_argument(
            '-t',
            '--task',
            help='ParlAI task(s), e.g. "babi:Task1" or "babi,cbt"')
        parlai.add_argument(
            '--download-path',
            default=None,
            hidden=True,
            help='path for non-data dependencies to store any needed files. '
            'Defaults to {parlai_dir}/downloads.',
        )
        parlai.add_argument(
            '-dt',
            '--datatype',
            default='train',
            choices=[
                'train',
                'train:stream',
                'train:ordered',
                'train:ordered:stream',
                'train:stream:ordered',
                'train:evalmode',
                'train:evalmode:stream',
                'train:evalmode:ordered',
                'train:evalmode:ordered:stream',
                'train:evalmode:stream:ordered',
                'valid',
                'valid:stream',
                'test',
                'test:stream',
            ],
            help='choose from: train, train:ordered, valid, test. to stream '
            'data add ":stream" to any option (e.g., train:stream). '
            'by default: train is random with replacement, '
            'valid is ordered, test is ordered.',
        )

        parlai.add_argument(
            '-im',
            '--image-mode',
            default='raw',
            type=str,
            help='image preprocessor to use. default is "raw". set to "none" '
            'to skip image loading.',
            hidden=True,
        )
        parlai.add_argument(
            '-nt',
            '--numthreads',
            default=1,
            type=int,
            help='number of threads. Used for hogwild if batchsize is 1, else '
            'for the number of threads in threadpool loading.',
        )
        parlai.add_argument(
            '--hide-labels',
            default=False,
            type='bool',
            hidden=True,
            help='default (False) moves labels in valid and test sets to the '
            'eval_labels field. If True, they are hidden completely.',
        )
        parlai.add_argument(
            '-mtw',
            '--multitask-weights',
            type='floats',
            default=[1],
            help='list of floats, one for each task, specifying '
            'the probability of drawing the task in multitask case',
            hidden=True,
        )
        parlai.add_argument(
            '-bs',
            '--batchsize',
            default=1,
            type=int,
            help='batch size for minibatch training schemes',
        )
        self.add_parlai_data_path(parlai)

    def add_distributed_training_args(self):
        """
        Add CLI args for distributed training.
        """
        grp = self.add_argument_group('Distributed Training')
        grp.add_argument('--distributed-world-size',
                         type=int,
                         help='Number of workers.')
        grp.add_argument(
            '--verbose',
            type='bool',
            default=False,
            help='All workers print output.',
            hidden=True,
        )
        return grp

    def add_pytorch_datateacher_args(self):
        """
        Add CLI args for PytorchDataTeacher.
        """
        pytorch = self.add_argument_group('PytorchData Arguments')
        pytorch.add_argument(
            '-pyt',
            '--pytorch-teacher-task',
            help='Use the PytorchDataTeacher for multiprocessed '
            'data loading with a standard ParlAI task, e.g. "babi:Task1k"',
        )
        pytorch.add_argument(
            '-pytd',
            '--pytorch-teacher-dataset',
            help='Use the PytorchDataTeacher for multiprocessed '
            'data loading with a pytorch Dataset, e.g. "vqa_1" or "flickr30k"',
        )
        pytorch.add_argument(
            '--pytorch-datapath',
            type=str,
            default=None,
            help='datapath for pytorch data loader '
            '(note: only specify if the data does not reside '
            'in the normal ParlAI datapath)',
            hidden=True,
        )
        pytorch.add_argument(
            '-nw',
            '--numworkers',
            type=int,
            default=4,
            help='how many workers the Pytorch dataloader should use',
            hidden=True,
        )
        pytorch.add_argument(
            '--pytorch-preprocess',
            type='bool',
            default=False,
            help='Whether the agent should preprocess the data while building '
            'the pytorch data',
            hidden=True,
        )
        pytorch.add_argument(
            '-pybsrt',
            '--pytorch-teacher-batch-sort',
            type='bool',
            default=False,
            help='Whether to construct batches of similarly sized episodes '
            'when using the PytorchDataTeacher (either via specifying `-pyt` '
            'or `-pytd`)',
            hidden=True,
        )
        pytorch.add_argument(
            '--batch-sort-cache-type',
            type=str,
            choices=['pop', 'index', 'none'],
            default='pop',
            help='how to build up the batch cache',
            hidden=True,
        )
        pytorch.add_argument(
            '--batch-length-range',
            type=int,
            default=5,
            help='degree of variation of size allowed in batch',
            hidden=True,
        )
        pytorch.add_argument(
            '--shuffle',
            type='bool',
            default=False,
            help='Whether to shuffle the data',
            hidden=True,
        )
        pytorch.add_argument(
            '--batch-sort-field',
            type=str,
            default='text',
            help='What field to use when determining the length of an episode',
            hidden=True,
        )
        pytorch.add_argument(
            '-pyclen',
            '--pytorch-context-length',
            default=-1,
            type=int,
            help='Number of past utterances to remember when building flattened '
            'batches of data in multi-example episodes. '
            '(For use with PytorchDataTeacher)',
            hidden=True,
        )
        pytorch.add_argument(
            '-pyincl',
            '--pytorch-include-labels',
            default=True,
            type='bool',
            help='Specifies whether or not to include labels as past utterances '
            'when building flattened batches of data in multi-example episodes. '
            '(For use with PytorchDataTeacher)',
            hidden=True,
        )

    def add_model_args(self):
        """
        Add arguments related to models such as model files.
        """
        model_args = self.add_argument_group('ParlAI Model Arguments')
        model_args.add_argument(
            '-m',
            '--model',
            default=None,
            help='the model class name. can match parlai/agents/<model> for '
            'agents in that directory, or can provide a fully specified '
            'module for `from X import Y` via `-m X:Y` '
            '(e.g. `-m parlai.agents.seq2seq.seq2seq:Seq2SeqAgent`)',
        )
        model_args.add_argument(
            '-mf',
            '--model-file',
            default=None,
            help='model file name for loading and saving models',
        )
        model_args.add_argument(
            '-im',
            '--init-model',
            default=None,
            type=str,
            help='load model weights and dict from this file',
        )
        model_args.add_argument('--dict-class',
                                hidden=True,
                                help='the class of the dictionary agent uses')

    def add_model_subargs(self, model):
        """
        Add arguments specific to a particular model.
        """
        agent = get_agent_module(model)
        try:
            if hasattr(agent, 'add_cmdline_args'):
                agent.add_cmdline_args(self)
        except argparse.ArgumentError:
            # already added
            pass
        try:
            if hasattr(agent, 'dictionary_class'):
                s = class2str(agent.dictionary_class())
                self.set_defaults(dict_class=s)
        except argparse.ArgumentError:
            # already added
            pass

    def add_task_args(self, task):
        """
        Add arguments specific to the specified task.
        """
        for t in ids_to_tasks(task).split(','):
            agent = get_task_module(t)
            try:
                if hasattr(agent, 'add_cmdline_args'):
                    agent.add_cmdline_args(self)
            except argparse.ArgumentError:
                # already added
                pass

    def add_pyt_dataset_args(self, opt):
        """
        Add arguments specific to specified pytorch dataset.
        """
        from parlai.core.pytorch_data_teacher import get_dataset_classes

        dataset_classes = get_dataset_classes(opt)
        for dataset, _, _ in dataset_classes:
            try:
                if hasattr(dataset, 'add_cmdline_args'):
                    dataset.add_cmdline_args(self)
            except argparse.ArgumentError:
                # already added
                pass

    def add_image_args(self, image_mode):
        """
        Add additional arguments for handling images.
        """
        try:
            parlai = self.add_argument_group(
                'ParlAI Image Preprocessing Arguments')
            parlai.add_argument(
                '--image-size',
                type=int,
                default=256,
                help='resizing dimension for images',
                hidden=True,
            )
            parlai.add_argument(
                '--image-cropsize',
                type=int,
                default=224,
                help='crop dimension for images',
                hidden=True,
            )
        except argparse.ArgumentError:
            # already added
            pass

    def add_extra_args(self, args=None):
        """
        Add more args depending on how known args are set.
        """
        parsed = vars(self.parse_known_args(args, nohelp=True)[0])
        # Also load extra args options if a file is given.
        if parsed.get('init_opt', None) is not None:
            self._load_known_opts(parsed.get('init_opt'), parsed)
        parsed = self._infer_datapath(parsed)

        # find which image mode specified if any, and add additional arguments
        image_mode = parsed.get('image_mode', None)
        if image_mode is not None and image_mode != 'no_image_model':
            self.add_image_args(image_mode)

        # find which task specified if any, and add its specific arguments
        task = parsed.get('task', None)
        if task is not None:
            self.add_task_args(task)
        evaltask = parsed.get('evaltask', None)
        if evaltask is not None:
            self.add_task_args(evaltask)

        # find pytorch teacher task if specified, add its specific arguments
        pytorch_teacher_task = parsed.get('pytorch_teacher_task', None)
        if pytorch_teacher_task is not None:
            self.add_task_args(pytorch_teacher_task)

        # find pytorch dataset if specified, add its specific arguments
        pytorch_teacher_dataset = parsed.get('pytorch_teacher_dataset', None)
        if pytorch_teacher_dataset is not None:
            self.add_pyt_dataset_args(parsed)

        # find which model specified if any, and add its specific arguments
        model = get_model_name(parsed)
        if model is not None:
            self.add_model_subargs(model)

        # reset parser-level defaults over any model-level defaults
        try:
            self.set_defaults(**self._defaults)
        except AttributeError:
            raise RuntimeError('Please file an issue on GitHub: argparse '
                               'raised an AttributeError when parsing.')

    def parse_known_args(self, args=None, namespace=None, nohelp=False):
        """
        Parse known args to ignore help flag.
        """
        if args is None:
            # args default to the system args
            args = _sys.argv[1:]
        args = fix_underscores(args)

        if nohelp:
            # ignore help
            args = [a for a in args if a != '-h' and a != '--help']
        return super().parse_known_args(args, namespace)

    def _load_known_opts(self, optfile, parsed):
        """
        Pull in CLI args for proper models/tasks/etc.

        Called before args are parsed; ``_load_opts`` is used for actually overriding
        opts after they are parsed.
        """
        new_opt = load_opt_file(optfile)
        for key, value in new_opt.items():
            # existing command line parameters take priority.
            if key not in parsed or parsed[key] is None:
                parsed[key] = value

    def _load_opts(self, opt):
        optfile = opt.get('init_opt')
        new_opt = load_opt_file(optfile)
        for key, value in new_opt.items():
            # existing command line parameters take priority.
            if key not in opt:
                raise RuntimeError(
                    'Trying to set an option from the opt file that does not '
                    'exist in the parsed options: ' + str(key))
            if key not in opt['override']:
                opt[key] = value
                opt['override'][key] = value

    def _infer_datapath(self, opt):
        """
        Set the value for opt['datapath'] and opt['download_path'].

        Sets the value for opt['datapath'] and opt['download_path'], correctly
        respecting environmental variables and the default.
        """
        # set environment variables
        # Priority for setting the datapath (same applies for download_path):
        # --datapath -> os.environ['PARLAI_DATAPATH'] -> <self.parlai_home>/data
        if opt.get('download_path'):
            os.environ['PARLAI_DOWNPATH'] = opt['download_path']
        elif os.environ.get('PARLAI_DOWNPATH') is None:
            os.environ['PARLAI_DOWNPATH'] = os.path.join(
                self.parlai_home, 'downloads')
        if opt.get('datapath'):
            os.environ['PARLAI_DATAPATH'] = opt['datapath']
        elif os.environ.get('PARLAI_DATAPATH') is None:
            os.environ['PARLAI_DATAPATH'] = os.path.join(
                self.parlai_home, 'data')

        opt['download_path'] = os.environ['PARLAI_DOWNPATH']
        opt['datapath'] = os.environ['PARLAI_DATAPATH']

        return opt

    def _process_args_to_opts(self,
                              args_that_override: Optional[List[str]] = None):
        self.opt = Opt(vars(self.args))

        # custom post-parsing
        self.opt['parlai_home'] = self.parlai_home
        self.opt = self._infer_datapath(self.opt)

        # set all arguments specified in command line as overridable
        option_strings_dict = {}
        store_true = []
        store_false = []
        for group in self._action_groups:
            for a in group._group_actions:
                if hasattr(a, 'option_strings'):
                    for option in a.option_strings:
                        option_strings_dict[option] = a.dest
                        if '_StoreTrueAction' in str(type(a)):
                            store_true.append(option)
                        elif '_StoreFalseAction' in str(type(a)):
                            store_false.append(option)

        if args_that_override is None:
            args_that_override = _sys.argv[1:]

        for i in range(len(args_that_override)):
            if args_that_override[i] in option_strings_dict:
                if args_that_override[i] in store_true:
                    self.overridable[option_strings_dict[
                        args_that_override[i]]] = True
                elif args_that_override[i] in store_false:
                    self.overridable[option_strings_dict[
                        args_that_override[i]]] = False
                elif (i < len(args_that_override) - 1
                      and args_that_override[i + 1][:1] != '-'):
                    key = option_strings_dict[args_that_override[i]]
                    self.overridable[key] = self.opt[key]
        self.opt['override'] = self.overridable

        # load opts if a file is provided.
        if self.opt.get('init_opt', None) is not None:
            self._load_opts(self.opt)

        # map filenames that start with 'zoo:' to point to the model zoo dir
        if self.opt.get('model_file') is not None:
            self.opt['model_file'] = modelzoo_path(self.opt.get('datapath'),
                                                   self.opt['model_file'])
        if self.opt['override'].get('model_file') is not None:
            # also check override
            self.opt['override']['model_file'] = modelzoo_path(
                self.opt.get('datapath'), self.opt['override']['model_file'])
        if self.opt.get('dict_file') is not None:
            self.opt['dict_file'] = modelzoo_path(self.opt.get('datapath'),
                                                  self.opt['dict_file'])
        if self.opt['override'].get('dict_file') is not None:
            # also check override
            self.opt['override']['dict_file'] = modelzoo_path(
                self.opt.get('datapath'), self.opt['override']['dict_file'])

        # add start time of an experiment
        self.opt['starttime'] = datetime.datetime.today().strftime(
            '%b%d_%H-%M')

    def parse_and_process_known_args(self, args=None):
        """
        Parse provided arguments and return parlai opts and unknown arg list.

        Runs the same arg->opt parsing that parse_args does, but doesn't throw an error
        if the args being parsed include additional command line arguments that parlai
        doesn't know what to do with.
        """
        self.args, unknowns = super().parse_known_args(args=args)
        self._process_args_to_opts(args)
        return self.opt, unknowns

    def parse_args(self, args=None, namespace=None, print_args=True):
        """
        Parse the provided arguments and returns a dictionary of the ``args``.

        We specifically remove items with ``None`` as values in order to support the
        style ``opt.get(key, default)``, which would otherwise return ``None``.
        """
        self.add_extra_args(args)
        self.args = super().parse_args(args=args)

        self._process_args_to_opts(args)

        if print_args:
            self.print_args()
            # print_git_commit()
            print_announcements(self.opt)

        return self.opt

    def print_args(self):
        """
        Print out all the arguments in this parser.
        """
        if not self.opt:
            self.parse_args(print_args=False)
        values = {}
        for key, value in self.opt.items():
            values[str(key)] = str(value)
        for group in self._action_groups:
            group_dict = {
                a.dest: getattr(self.args, a.dest, None)
                for a in group._group_actions
            }
            namespace = argparse.Namespace(**group_dict)
            count = 0
            for key in sorted(namespace.__dict__):
                if key in values:
                    if count == 0:
                        print('[ ' + group.title + ': ] ')
                    count += 1
                    print('[  ' + key + ': ' + values[key] + ' ]')

    def set_params(self, **kwargs):
        """
        Set overridable kwargs.
        """
        self.set_defaults(**kwargs)
        for k, v in kwargs.items():
            self.overridable[k] = v

    @property
    def show_advanced_args(self):
        """
        Check if we should show arguments marked as hidden.
        """
        if hasattr(self, '_show_advanced_args'):
            return self._show_advanced_args
        known_args, _ = self.parse_known_args(nohelp=True)
        if hasattr(known_args, 'show_advanced_args'):
            self._show_advanced_args = known_args.show_advanced_args
        else:
            self._show_advanced_args = True
        return self._show_advanced_args

    def _handle_hidden_args(self, kwargs):
        """
        Hide help messages for arguments marked as hidden.
        """
        if 'hidden' in kwargs:
            flag = kwargs['hidden']
            del kwargs['hidden']
            if flag and not self.show_advanced_args:
                kwargs['help'] = argparse.SUPPRESS
        return kwargs

    def _augment_help_msg(self, kwargs):
        """
        Add recommended value to help string if recommended exists.
        """
        if 'help' in kwargs:
            if 'recommended' in kwargs:
                kwargs['help'] += " (recommended: " + str(
                    kwargs['recommended']) + ")"
                del kwargs['recommended']
        return kwargs

    def add_argument(self, *args, **kwargs):
        """
        Override to convert underscores to hyphens for consistency.
        """
        kwargs = self._augment_help_msg(kwargs)
        return super().add_argument(*fix_underscores(args),
                                    **self._handle_hidden_args(kwargs))

    def add_argument_group(self, *args, **kwargs):
        """
        Override to make arg groups also convert underscores to hyphens.
        """
        arg_group = super().add_argument_group(*args, **kwargs)
        original_add_arg = arg_group.add_argument

        def ag_add_argument(*args, **kwargs):
            kwargs = self._augment_help_msg(kwargs)
            return original_add_arg(*fix_underscores(args),
                                    **self._handle_hidden_args(kwargs))

        arg_group.add_argument = ag_add_argument  # override _ => -
        arg_group.add_argument_group = self.add_argument_group
        return arg_group

    def error(self, message):
        """
        Override to print custom error message.
        """
        self.print_help()
        _sys.stderr.write('\nParse Error: %s\n' % message)
        _sys.exit(2)
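
For orientation, here is a minimal usage sketch of the parser class defined above. The class name and import path (ParlaiParser in parlai.core.params) and the add_parlai_args/add_model_args constructor flags are assumptions based on stock ParlAI and are not shown in the snippet itself.

# Minimal sketch, assuming the class above is parlai.core.params.ParlaiParser
# and that its constructor accepts add_parlai_args/add_model_args flags.
from parlai.core.params import ParlaiParser

parser = ParlaiParser(add_parlai_args=True, add_model_args=True)
# set_params() sets defaults and also records them as overridable values
parser.set_params(task='babi:Task1k:1', batchsize=4)
# parse_args() first runs add_extra_args(), so task- and model-specific flags
# are registered before parsing; passing an explicit list avoids sys.argv
opt = parser.parse_args(['--datatype', 'valid'], print_args=False)
print(opt['task'], opt['datatype'], opt['override'])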
Example #9
0
def create_agent_from_opt_file(opt: Opt):
    """
    Load agent options and module from file if opt file exists.

    Checks to see if the file opt['model_file'] + ".opt" exists; if so, loads the
    options from that file, reads the model type from them, applies any values in
    opt['override'] on top, and instantiates the agent with the result.

    If that file does not exist, return None.
    """
    model_file = opt['model_file']
    optfile = model_file + '.opt'
    if os.path.isfile(optfile):
        new_opt = Opt.load(optfile)
        # TODO we need a better way to say these options are never copied...
        if 'datapath' in new_opt:
            # never use the datapath from an opt dump
            del new_opt['datapath']
        if 'batchindex' in new_opt:
            # This saved variable can cause trouble if we switch to BS=1 at test time
            del new_opt['batchindex']
        # only override opts specified in 'override' dict
        if opt.get('override'):
            for k, v in opt['override'].items():
                if str(v) != str(new_opt.get(k, None)):
                    print(
                        "[ warning: overriding opt['{}'] to {} ("
                        "previously: {} )]".format(k, v, new_opt.get(k, None))
                    )
                new_opt[k] = v

        model_class = load_agent_module(new_opt['model'])

        # check for model version
        if hasattr(model_class, 'model_version'):
            curr_version = new_opt.get('model_version', 0)
            if curr_version != model_class.model_version():
                model = new_opt['model']
                m = (
                    'It looks like you are trying to load an older version of'
                    ' the selected model. Change your model argument to use '
                    'the old version from parlai/agents/legacy_agents: for '
                    'example: `-m legacy:{m}:{v}` or '
                    '`--model parlai.agents.legacy_agents.{m}.{m}_v{v}:{c}`'
                )
                if '.' not in model:
                    # give specific error message if it's easy
                    raise RuntimeError(
                        m.format(m=model, v=curr_version, c=model_class.__name__)
                    )
                else:
                    # otherwise generic one
                    raise RuntimeError(
                        m.format(m='modelname', v=curr_version, c='ModelAgent')
                    )

        if hasattr(model_class, 'upgrade_opt'):
            new_opt = model_class.upgrade_opt(new_opt)

        # add model arguments to new_opt if they aren't in new_opt already
        for k, v in opt.items():
            if k not in new_opt:
                new_opt[k] = v
        new_opt['model_file'] = model_file
        old_dict_file = new_opt.get('dict_file')
        if not old_dict_file or not os.path.isfile(old_dict_file):
            # fall back to the dictionary that lives next to the model file
            new_opt['dict_file'] = model_file + '.dict'
        if not os.path.isfile(new_opt['dict_file']):
            warn_once(
                'WARNING: Neither the specified dict file ({}) nor the '
                '`model_file`.dict file ({}) exists, check to make sure either '
                'is correct. This may manifest as a shape mismatch later '
                'on.'.format(old_dict_file, new_opt['dict_file'])
            )

        # if we want to load weights from --init-model, compare opts with
        # loaded ones
        compare_init_model_opts(opt, new_opt)
        return model_class(new_opt)
    else:
        return None
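
A short usage sketch of create_agent_from_opt_file follows, with hypothetical paths; it assumes an earlier training run saved /tmp/model.opt (and usually /tmp/model.dict) next to the model weights.

# Minimal sketch (hypothetical paths). Only the keys listed in opt['override']
# replace values stored in the .opt file; the rest come from the saved options.
from parlai.core.opt import Opt

opt = Opt({
    'model_file': '/tmp/model',       # hypothetical prefix of a saved model
    'datapath': '/tmp/parlai_data',   # the datapath saved in the .opt is never reused
    'override': {'batchsize': 1},     # force batch size 1 at load time
})
agent = create_agent_from_opt_file(opt)
if agent is None:
    print('No /tmp/model.opt found; create the agent from scratch instead.')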