Esempio n. 1
0
 def add_task_args(self, task):
     """
     Let every task named by ``task`` register its own command-line arguments.

     ``ids_to_tasks(task)`` is split on commas; each entry is passed to
     ``get_task_module`` and, when the returned object exposes
     ``add_cmdline_args``, that hook is called with this object.

     :param task: task identifier string handed to ``ids_to_tasks``.
     """
     task_names = ids_to_tasks(task).split(',')
     for task_name in task_names:
         module = get_task_module(task_name)
         try:
             if not hasattr(module, 'add_cmdline_args'):
                 continue
             module.add_cmdline_args(self)
         except argparse.ArgumentError:
             # The arguments were already added; move on to the next task.
             continue
Esempio n. 2
0
def create_task(opt, user_agents, default_world=None):
    """
    Build the world (task agents + user agents) described by ``opt``.

    ``opt['task']`` is expected to look like ``"task_dir:teacher_class:options"``
    (e.g. ``"babi:Task1k:1"``) or be a hashtag label such as ``"#babi-1k"`` or
    ``"#QA"``; see ``parlai/tasks/tasks.py`` and ``parlai/tasks/task_list.py``
    for the list of tasks.

    :param opt: options mapping. If ``opt['task']`` is empty it is set to
        ``'pytorch_teacher'`` on the caller's object before a deep copy is
        taken for the rest of the setup.
    :param user_agents: a single agent or a list of agents.
    :param default_world: forwarded to ``create_task_world`` / ``MultiWorld``.
    :raises RuntimeError: if none of ``task``, ``pytorch_teacher_task`` and
        ``pytorch_teacher_dataset`` is set.
    :return: the created world, wrapped in ``HogwildWorld`` when
        ``numthreads > 1``, else in ``BatchWorld`` when ``batchsize > 1``.
    """
    requested = opt.get('task')
    teacher_task = opt.get('pytorch_teacher_task')
    teacher_dataset = opt.get('pytorch_teacher_dataset')
    if not requested and not teacher_task and not teacher_dataset:
        raise RuntimeError('No task specified. Please select a task with ' +
                           '--task {task_name}.')

    # While pytorch data is being built the regular task and the pytorch
    # teacher task can coincide; that case must not count as multitasking.
    pyt_multitask = False
    if requested is not None:
        if teacher_task is not None and teacher_task != requested:
            pyt_multitask = True
        elif teacher_dataset is not None and teacher_dataset != requested:
            pyt_multitask = True

    if not requested:
        # Written to the caller's opt: the deep copy is only taken below.
        opt['task'] = 'pytorch_teacher'
    if type(user_agents) != list:
        user_agents = [user_agents]

    # Expand hashtag labels (e.g. "#QA") into the task names they stand for,
    # working on a private copy of the options from here on.
    opt = copy.deepcopy(opt)
    print("\t\t opt['task'] :: ", opt['task'])
    opt['task'] = ids_to_tasks(opt['task'])
    if pyt_multitask and 'pytorch_teacher' not in opt['task']:
        opt['task'] += ',pytorch_teacher'
    task_names = opt['task']
    print('\t\t[creating task(s): ' + task_names + ']')

    # A comma in the expanded task string means several tasks were requested.
    if ',' in task_names:
        print("\t\tMultiWorld")
        # Multitask teacher/agent
        # TODO: remove and replace with multiteachers only?
        world = MultiWorld(opt, user_agents, default_world=default_world)
    else:
        print("\t\tcreate_task_world")
        world = create_task_world(
            opt, user_agents, default_world=default_world)

    if opt.get('numthreads', 1) > 1:
        # Hogwild takes precedence; per the original note it builds its own
        # sub batch worlds when bsz > 1.
        return HogwildWorld(opt, world)
    if opt.get('batchsize', 1) > 1:
        return BatchWorld(opt, world)
    return world
Esempio n. 3
0
 def add_task_args(self, task):
     """
     Register the command-line arguments of each task named by ``task``.

     ``ids_to_tasks(task)`` is split on commas; each entry is loaded with
     ``load_teacher_module`` and, when the result exposes
     ``add_cmdline_args``, that hook is called with this object.

     :param task: task identifier string handed to ``ids_to_tasks``.
     """
     teacher_names = ids_to_tasks(task).split(',')
     for teacher_name in teacher_names:
         teacher = load_teacher_module(teacher_name)
         try:
             if not hasattr(teacher, 'add_cmdline_args'):
                 continue
             teacher.add_cmdline_args(self)
         except argparse.ArgumentError:
             # The arguments were already added; move on to the next task.
             continue
Esempio n. 4
0
 def add_task_args(self, args=None):
     """
     Find the task named on the command line and add its specific arguments.

     ``args`` is scanned for ``-t`` / ``--task``; the token that follows the
     flag is taken as the task string. When the flag appears several times
     the last occurrence with a value wins. The task string is expanded with
     ``ids_to_tasks`` and, for every comma-separated entry, the object
     returned by ``get_task_module`` gets its ``add_cmdline_args`` hook
     called with this object (if it has one).

     :param args: list of command-line tokens; ``sys.argv`` is used when
         this is ``None``.
     """
     args = sys.argv if args is None else args
     task = None
     for index, item in enumerate(args):
         # A flag that is the very last token has no value after it; skip it
         # instead of indexing past the end of the list (IndexError).
         if item in ('-t', '--task') and index + 1 < len(args):
             task = args[index + 1]
     if task:
         for t in ids_to_tasks(task).split(','):
             agent = get_task_module(t)
             if hasattr(agent, 'add_cmdline_args'):
                 agent.add_cmdline_args(self)
Esempio n. 5
0
 def add_task_args(self, args):
     """
     Find the task named on the command line and add its specific arguments.

     ``args`` is scanned for ``-t`` / ``--task``; the token that follows the
     flag is taken as the task string. When the flag appears several times
     the last occurrence with a value wins. The task string is expanded with
     ``ids_to_tasks`` and, for every comma-separated entry, the object
     returned by ``get_task_module`` gets its ``add_cmdline_args`` hook
     called with this object (if it has one).

     :param args: list of command-line tokens; ``sys.argv`` is used when
         this is ``None``.
     """
     args = sys.argv if args is None else args
     task = None
     for index, item in enumerate(args):
         # A flag that is the very last token has no value after it; skip it
         # instead of indexing past the end of the list (IndexError).
         if item in ('-t', '--task') and index + 1 < len(args):
             task = args[index + 1]
     if task:
         for t in ids_to_tasks(task).split(','):
             agent = get_task_module(t)
             if hasattr(agent, 'add_cmdline_args'):
                 agent.add_cmdline_args(self)
Esempio n. 6
0
def create_task(opt, user_agents):
    """Builds the world (task agents + user agents) described by ``opt``.

    ``opt['task']`` is expected to look like ``"task_dir:teacher_class:options"``
    (e.g. ``"babi:Task1k:1"``) or be a hashtag label such as ``"#babi-1k"`` or
    ``"#QA"``; see ``parlai/tasks/tasks.py`` and ``parlai/tasks/task_list.py``
    for the list of tasks.

    Raises ``RuntimeError`` when no task is set and ``NotImplementedError``
    when hogwild training is requested for more than one task.
    """
    if not opt.get('task'):
        raise RuntimeError('No task specified. Please select a task with ' +
                           '--task {task_name}.')
    if type(user_agents) != list:
        user_agents = [user_agents]

    # Expand hashtag labels (e.g. "#QA") into the task names they stand for,
    # working on a private copy of the options.
    opt = copy.deepcopy(opt)
    opt['task'] = ids_to_tasks(opt['task'])
    print('[creating task(s): ' + opt['task'] + ']')
    several_tasks = ',' in opt['task']

    # Hogwild (multi-threaded) creation is only used for training data with
    # more than one thread and no batching. Valid and test stay
    # single-threaded to guarantee exactly one epoch, and with batchsize > 1
    # BatchWorld is preferred because numthreads can then serve
    # multithreading in batch mode (e.g. multithreaded data loading).
    use_hogwild = (opt.get('numthreads', 1) != 1
                   and 'train' in opt['datatype']
                   and not opt.get('batchsize', 1) > 1)

    if use_hogwild:
        if several_tasks:
            # TODO(ahm): fix this
            raise NotImplementedError('hogwild multiworld not supported yet')
        # TODO(ahm): fix metrics for multiteacher hogwild training
        world_class, task_agents = _get_task_world(opt)
        return HogwildWorld(world_class, opt, task_agents + user_agents)

    if several_tasks:
        # Multitask teacher/agent
        world = MultiWorld(opt, user_agents)
    else:
        world = create_task_world(opt, user_agents)

    if opt.get('batchsize', 1) > 1:
        return BatchWorld(opt, world)
    return world
Esempio n. 7
0
def create_task(opt: Opt, user_agents, default_world=None):
    """
    Build the world (task agents + user agents) described by ``opt``.

    ``opt['task']`` is expected to look like ``"task_dir:teacher_class:options"``
    (e.g. ``"babi:Task1k:1"``) or be a hashtag label such as ``"#babi-1k"`` or
    ``"#QA"``; see ``parlai/tasks/tasks.py`` and ``parlai/tasks/task_list.py``
    for the list of tasks.

    :param opt: options; a deep copy is used once the task check has passed.
    :param user_agents: a single agent or a list of agents.
    :param default_world: forwarded to ``create_task_world`` / ``MultiWorld``.
    :raises RuntimeError: if ``opt`` has no task set.
    :return: the created world, possibly wrapped in ``BackgroundDriverWorld``,
        ``DynamicBatchWorld`` or ``BatchWorld`` depending on ``opt``.
    """
    if not opt.get('task'):
        raise RuntimeError('No task specified. Please select a task with ' +
                           '--task {task_name}.')
    if type(user_agents) != list:
        user_agents = [user_agents]

    # Expand hashtag labels (e.g. "#QA") into the task names they stand for,
    # working on a private copy of the options.
    opt = copy.deepcopy(opt)
    opt['task'] = ids_to_tasks(opt['task'])
    logging.info(f"creating task(s): {opt['task']}")

    # A comma in the expanded task string means several tasks were requested.
    if ',' in opt['task']:
        # Multitask teacher/agent
        # TODO: remove and replace with multiteachers only?
        world = MultiWorld(opt, user_agents, default_world=default_world)
    else:
        world = create_task_world(
            opt, user_agents, default_world=default_world)

    training = DatatypeHelper.is_training(opt['datatype'])
    if training and opt.get('num_workers', 0) > 0:
        # Background preprocessing is never used for the valid/test worlds:
        # Teacher.observe(model_act) cannot be called in BG preprocessing, so
        # metrics could not be computed and MultiWorld stats could not be
        # told apart accurately.
        return BackgroundDriverWorld(opt, world)

    if opt.get('batchsize', 1) > 1:
        if opt.get('dynamic_batching'):
            return DynamicBatchWorld(opt, world)
        return BatchWorld(opt, world)
    return world
Esempio n. 8
0
def create_task(opt, user_agents):
    """Builds the world (task agents + user agents) described by ``opt``.

    ``opt['task']`` is expected to look like ``"task_dir:teacher_class:options"``
    (e.g. ``"babi:Task1k:1"``) or be a hashtag label such as ``"#babi-1k"`` or
    ``"#QA"``; see ``parlai/tasks/tasks.py`` and ``parlai/tasks/task_list.py``
    for the list of tasks.

    Raises ``RuntimeError`` when no task is set and ``NotImplementedError``
    when hogwild training is requested for more than one task.
    """
    if not opt.get('task'):
        raise RuntimeError('No task specified. Please select a task with ' +
                           '--task {task_name}.')
    if type(user_agents) != list:
        user_agents = [user_agents]

    # Expand hashtag labels (e.g. "#QA") into the task names they stand for,
    # working on a private copy of the options.
    opt = copy.deepcopy(opt)
    opt['task'] = ids_to_tasks(opt['task'])
    print('[creating task(s): ' + opt['task'] + ']')
    several_tasks = ',' in opt['task']

    # Hogwild (multi-threaded) creation is only used for training data with
    # more than one thread; valid and test stay single-threaded to guarantee
    # exactly one epoch.
    use_hogwild = (opt.get('numthreads', 1) != 1
                   and 'train' in opt['datatype'])

    if use_hogwild:
        if several_tasks:
            # TODO(ahm): fix this
            raise NotImplementedError('hogwild multiworld not supported yet')
        # TODO(ahm): fix metrics for multiteacher hogwild training
        world_class, task_agents = _get_task_world(opt)
        return HogwildWorld(world_class, opt, task_agents + user_agents)

    if several_tasks:
        # Multitask teacher/agent
        world = MultiWorld(opt, user_agents)
    else:
        world = create_task_world(opt, user_agents)

    if opt.get('batchsize', 1) > 1:
        return BatchWorld(opt, world)
    return world
Esempio n. 9
0
 def add_task_args(self, task: str, partial: Opt):
     """
     Register the command-line arguments of each task named by ``task``.

     ``ids_to_tasks(task)`` is split on commas; each entry is loaded with
     ``load_teacher_module`` and, when the result exposes
     ``add_cmdline_args``, that hook is called with this object and
     ``partial``.

     :param task: task identifier string handed to ``ids_to_tasks``.
     :param partial: passed through as the hook's second argument.
     :raises TypeError: re-raised with an explanatory message whenever the
         hook call raises ``TypeError``.
     """
     teacher_names = ids_to_tasks(task).split(',')
     for teacher_name in teacher_names:
         teacher = load_teacher_module(teacher_name)
         try:
             if not hasattr(teacher, 'add_cmdline_args'):
                 continue
             teacher.add_cmdline_args(self, partial)
         except argparse.ArgumentError:
             # The arguments were already added; move on to the next task.
             continue
         except TypeError as typ:
             # Reported as the old one-argument hook signature being used.
             message = (
                 f"Task '{task}' appears to have signature "
                 "add_cmdline_args(argparser) but we have updated the signature "
                 "to add_cmdline_args(argparser, partial_opt). For details, see "
                 "https://github.com/facebookresearch/ParlAI/pull/3328."
             )
             raise TypeError(message) from typ
Esempio n. 10
0
def create_task(opt, user_agents, default_world=None):
    """Builds the world (task agents + user agents) described by ``opt``.

    ``opt['task']`` is expected to look like ``"task_dir:teacher_class:options"``
    (e.g. ``"babi:Task1k:1"``) or be a hashtag label such as ``"#babi-1k"`` or
    ``"#QA"``; see ``parlai/tasks/tasks.py`` and ``parlai/tasks/task_list.py``
    for the list of tasks.

    Raises ``RuntimeError`` when none of ``task``, ``pytorch_teacher_task``
    and ``pytorch_teacher_dataset`` is set. If only a pytorch teacher option
    is given, ``opt['task']`` is set to ``'pytorch_teacher'`` on the caller's
    object before the options are deep-copied.
    """
    any_task = (opt.get('task')
                or opt.get('pytorch_teacher_task')
                or opt.get('pytorch_teacher_dataset'))
    if not any_task:
        raise RuntimeError('No task specified. Please select a task with ' +
                           '--task {task_name}.')
    if not opt.get('task'):
        # Written to the caller's opt: the deep copy is only taken below.
        opt['task'] = 'pytorch_teacher'
    if type(user_agents) != list:
        user_agents = [user_agents]

    # Expand hashtag labels (e.g. "#QA") into the task names they stand for,
    # working on a private copy of the options.
    opt = copy.deepcopy(opt)
    opt['task'] = ids_to_tasks(opt['task'])
    print('[creating task(s): ' + opt['task'] + ']')

    # A comma in the expanded task string means several tasks were requested.
    if ',' in opt['task']:
        # Multitask teacher/agent
        # TODO: remove and replace with multiteachers only?
        world = MultiWorld(opt, user_agents, default_world=default_world)
    else:
        world = create_task_world(
            opt, user_agents, default_world=default_world)

    if opt.get('numthreads', 1) > 1:
        # Hogwild takes precedence; per the original note it builds its own
        # sub batch worlds when bsz > 1.
        return HogwildWorld(opt, world)
    if opt.get('batchsize', 1) > 1:
        return BatchWorld(opt, world)
    return world
Esempio n. 11
0
def create_task(opt, user_agents):
    """Creates a world + task_agents (aka a task)
    assuming opt['task']="task_dir:teacher_class:options"
    e.g. "babi:Task1k:1" or "#babi-1k" or "#QA",
    see parlai/tasks/tasks.py and see parlai/tasks/task_list.py
    for list of tasks.

    ``user_agents`` may be a single agent or a list of agents. The options
    are deep-copied before the task string is rewritten, so the caller's
    ``opt`` is left untouched.

    Raises ``RuntimeError`` when ``opt`` has no (or an empty) ``'task'``
    entry, and ``NotImplementedError`` when hogwild training is requested
    for more than one task.
    """
    # Fail early with an actionable message instead of a bare KeyError from
    # the opt['task'] lookup below.
    if not opt.get('task'):
        raise RuntimeError('No task specified. Please select a task with ' +
                           '--task {task_name}.')
    if type(user_agents) != list:
        user_agents = [user_agents]

    # Convert any hashtag task labels to task directory path names.
    # (e.g. "#QA" to the list of tasks that are QA tasks).
    opt = copy.deepcopy(opt)
    opt['task'] = ids_to_tasks(opt['task'])
    print('[creating task(s): ' + opt['task'] + ']')

    # Single threaded or hogwild task creation (the latter creates multiple threads).
    # Check datatype for train, because we need to do single-threaded for
    # valid and test in order to guarantee exactly one epoch of training.
    # Note the comparison is exact: only the datatype 'train' selects hogwild.
    if opt.get('numthreads', 1) == 1 or opt['datatype'] != 'train':
        if ',' not in opt['task']:
            # Single task
            world = create_task_world(opt, user_agents)
        else:
            # Multitask teacher/agent
            world = MultiWorld(opt, user_agents)

        if opt.get('batchsize', 1) > 1:
            return BatchWorld(opt, world)
        else:
            return world
    else:
        # more than one thread requested: do hogwild training
        if ',' not in opt['task']:
            # Single task
            # TODO(ahm): fix metrics for multiteacher hogwild training
            world_class, task_agents = _get_task_world(opt)
            return HogwildWorld(world_class, opt, task_agents + user_agents)
        else:
            # TODO(ahm): fix this
            raise NotImplementedError('hogwild multiworld not supported yet')
Esempio n. 12
0
def create_task(opt: Opt, user_agents, default_world=None):
    """
    Build the world (task agents + user agents) described by ``opt``.

    ``opt['task']`` is expected to look like ``"task_dir:teacher_class:options"``
    (e.g. ``"babi:Task1k:1"``) or be a hashtag label such as ``"#babi-1k"`` or
    ``"#QA"``; see ``parlai/tasks/tasks.py`` and ``parlai/tasks/task_list.py``
    for the list of tasks.

    :param opt: options; a deep copy is used once the task check has passed.
    :param user_agents: a single agent or a list of agents.
    :param default_world: forwarded to ``create_task_world`` / ``MultiWorld``.
    :raises RuntimeError: if ``opt`` has no task set.
    :return: the created world, wrapped in ``DynamicBatchWorld`` or
        ``BatchWorld`` when ``batchsize > 1``.
    """
    if not opt.get('task'):
        raise RuntimeError('No task specified. Please select a task with ' +
                           '--task {task_name}.')
    if type(user_agents) != list:
        user_agents = [user_agents]

    # Expand hashtag labels (e.g. "#QA") into the task names they stand for,
    # working on a private copy of the options.
    opt = copy.deepcopy(opt)
    opt['task'] = ids_to_tasks(opt['task'])
    logging.info(f"creating task(s): {opt['task']}")

    # A comma in the expanded task string means several tasks were requested.
    if ',' in opt['task']:
        # Multitask teacher/agent
        # TODO: remove and replace with multiteachers only?
        world = MultiWorld(opt, user_agents, default_world=default_world)
    else:
        world = create_task_world(
            opt, user_agents, default_world=default_world)

    # Batching only applies for batchsize > 1; dynamic batching wins over
    # the plain batch world when it is enabled.
    if opt.get('batchsize', 1) > 1:
        if opt.get('dynamic_batching'):
            return DynamicBatchWorld(opt, world)
        return BatchWorld(opt, world)
    return world
Esempio n. 13
0
def create_task(opt, user_agents):
    """Builds the world (task agents + user agents) described by ``opt``.

    ``opt['task']`` is expected to look like ``"task_dir:teacher_class:options"``
    (e.g. ``"babi:Task1k:1"``) or be a hashtag label such as ``"#babi-1k"`` or
    ``"#QA"``; see ``parlai/tasks/tasks.py`` and ``parlai/tasks/task_list.py``
    for the list of tasks.

    Raises ``RuntimeError`` when no task is set. The returned world is
    wrapped in ``HogwildWorld`` when ``numthreads > 1``, otherwise in
    ``BatchWorld`` when ``batchsize > 1``.
    """
    if not opt.get('task'):
        raise RuntimeError('No task specified. Please select a task with ' +
                           '--task {task_name}.')
    if type(user_agents) != list:
        user_agents = [user_agents]

    # Expand hashtag labels (e.g. "#QA") into the task names they stand for,
    # working on a private copy of the options.
    opt = copy.deepcopy(opt)
    opt['task'] = ids_to_tasks(opt['task'])
    print('[creating task(s): ' + opt['task'] + ']')

    # A comma in the expanded task string means several tasks were requested.
    if ',' in opt['task']:
        # Multitask teacher/agent
        # TODO: remove and replace with multiteachers only?
        world = MultiWorld(opt, user_agents)
    else:
        world = create_task_world(opt, user_agents)

    if opt.get('numthreads', 1) > 1:
        # Hogwild takes precedence; per the original note it builds its own
        # sub batch worlds when bsz > 1.
        return HogwildWorld(opt, world)
    if opt.get('batchsize', 1) > 1:
        return BatchWorld(opt, world)
    return world