Пример #1
0
def main():
    """Print a few examples from a task using a ``RepeatLabelAgent``.

    Seeds ``random`` with 42, parses the command line (including the options
    registered by ``ImageLoader``), builds the world and steps it at most
    ``--num-examples`` times, stopping early when the world reports that the
    epoch is done.
    """
    random.seed(42)

    # Command-line setup: example count, default datatype, image-loader flags.
    cli = ParlaiParser()
    cli.add_argument('-n', '--num-examples', default=10)
    cli.set_defaults(datatype='train:ordered')
    ImageLoader.add_cmdline_args(cli)
    opt = cli.parse_args()

    # Device-related options are forced here regardless of the command line.
    opt['no_cuda'] = False
    opt['gpu'] = 0

    # The agent is attached to whatever task the options specify.
    repeater = RepeatLabelAgent(opt)
    world = create_task(opt, repeater)

    with world:
        # No ``type`` was declared for the flag, so the value is cast here.
        remaining = int(opt['num_examples'])
        while remaining > 0:
            remaining -= 1
            world.parley()
            print(world.display() + '\n~~')
            if world.epoch_done():
                print('EPOCH DONE')
                break
Пример #2
0
def main():
    """Run one short train/validate cycle with a base ``Agent`` and time it.

    Two worlds are created from the same options, one with datatype
    ``'train'`` and one with ``'valid'``. The training world is stepped ten
    times and the validation world once; each world's report is printed.
    """
    train_steps = 10
    valid_steps = 1

    cli = ParlaiParser()
    cli.add_argument('-n', '--num-examples', default=10)
    opt = cli.parse_args()

    agent = Agent(opt)

    # The same opt object is reused; only its datatype is switched before
    # each world is created.
    opt['datatype'] = 'train'
    world_train = create_task(opt, agent)
    opt['datatype'] = 'valid'
    world_valid = create_task(opt, agent)

    began = time.time()
    for _cycle in range(1):
        print('[ training ]')
        for _step in range(train_steps):
            world_train.parley()
        print('[ training summary. ]')
        print(world_train.report())

        print('[ validating ]')
        for _step in range(valid_steps):
            world_valid.parley()
        print('[ validation summary. ]')
        print(world_valid.report())

    elapsed = round(time.time() - began, 2)
    print('finished in {} s'.format(elapsed))
Пример #3
0
    def test_basic_parse(self):
        """Check that the dictionary adds and parses a short sentence.

        After observing 'hello world' the dictionary must have grown by two
        entries, and parsing the same text (with the default vector type, as
        a list, and as a tuple) must yield the two newly assigned indices.
        """
        from parlai.core.dict import DictionaryAgent
        from parlai.core.params import ParlaiParser

        parser = ParlaiParser()
        DictionaryAgent.add_cmdline_args(parser)
        dictionary = DictionaryAgent(parser.parse_args())
        baseline = len(dictionary)

        dictionary.observe({'text': 'hello world'})
        dictionary.act()
        assert len(dictionary) - baseline == 2

        # First pass uses the default vec_type, then list, then tuple.
        for parse_kwargs in ({}, {'vec_type': list}, {'vec_type': tuple}):
            vec = dictionary.parse('hello world', **parse_kwargs)
            assert len(vec) == 2
            assert vec[0] == baseline
            assert vec[1] == baseline + 1
Пример #4
0
def main():
    """Talk to a remote agent, either through a task or a local partner.

    If a task is specified the remote agent is placed in that task's world.
    Otherwise it is paired with a local agent (a model when ``model`` is set,
    a ``LocalHumanAgent`` if not) in a ``DialogPartnerWorld``. The world is
    then stepped and displayed in an endless loop.
    """
    random.seed(42)

    cli = ParlaiParser(True, True)
    RemoteAgentAgent.add_cmdline_args(cli)
    opt = cli.parse_args()

    remote = RemoteAgentAgent(opt)
    if opt.get('task'):
        world = create_task(opt, [remote])
    else:
        local = create_agent(opt) if opt.get('model') else LocalHumanAgent(opt)
        # Speaking order depends on whether 'remote_host' is set.
        if opt['remote_host']:
            speakers = [remote, local]
        else:
            speakers = [local, remote]
        world = DialogPartnerWorld(opt, speakers)

    # There is no exit condition here; the loop only ends via an exception.
    with world:
        while True:
            world.parley()
            print(world.display())
Пример #5
0
def main():
    """Seed ``random``, parse the command line and hand the opts to ``display_data``."""
    random.seed(42)

    cli = ParlaiParser()
    cli.add_argument('-n', '--num-examples', type=int, default=10)
    display_data(cli.parse_args())
Пример #6
0
def main():
    """Set up a ``MessengerManager`` and run QA data-collection conversations.

    The task name is taken from the directory containing this file. Each
    conversation builds a fresh teacher from ``parlai.tasks.squad.agents`` and
    runs a ``QADataCollectionWorld`` until its episode is done. The manager is
    always shut down, even if starting the run or the task raises.
    """
    parser = ParlaiParser(False, False)
    parser.add_parlai_data_path()
    parser.add_messenger_args()
    opt = parser.parse_args()
    this_dir = os.path.dirname(os.path.abspath(__file__))
    opt['task'] = os.path.basename(this_dir)

    # Teacher class used to build the task handed to each conversation.
    task_class = getattr(
        importlib.import_module('parlai.tasks.squad.agents'), 'DefaultTeacher'
    )
    task_opt = {'datatype': 'train', 'datapath': opt['datapath']}

    messenger = MessengerManager(opt=opt)
    messenger.setup_server()
    messenger.init_new_state()

    def get_overworld(agent):
        return MessengerOverworld(None, agent)

    def assign_agent_role(agent):
        # Receives an indexable collection; only the first entry is labelled.
        agent[0].disp_id = 'Agent'

    def run_conversation(manager, opt, agents, task_id):
        world = QADataCollectionWorld(
            opt=opt, task=task_class(task_opt), agent=agents[0]
        )
        while not world.episode_done():
            world.parley()
        world.shutdown()

    # World with no onboarding
    messenger.set_onboard_functions({'default': None})
    messenger.set_agents_required({'default': 1})
    messenger.set_overworld_func(get_overworld)
    messenger.setup_socket()
    try:
        messenger.start_new_run()
        messenger.start_task(
            assign_role_functions={'default': assign_agent_role},
            task_functions={'default': run_conversation},
        )
    finally:
        # Any exception still propagates after the manager is shut down.
        messenger.shutdown()
Пример #7
0
def main():
    """Parse evaluation options (datatype defaults to 'valid') and call ``eval_model``."""
    random.seed(42)

    cli = ParlaiParser(True, True)
    cli.add_argument('-n', '--num-examples', default=100000000)
    cli.add_argument('-d', '--display-examples', default=False, type='bool')
    cli.add_argument('-ltim', '--log-every-n-secs', default=2, type=float)
    cli.set_defaults(datatype='valid')

    # The parser itself is passed along with the parsed options.
    parsed = cli.parse_args(print_args=False)
    eval_model(parsed, cli)
Пример #8
0
def main():
    """Parse the data-building options and pass them to ``build_data``.

    Registers a 'Data Building Args' group with ``--datafile``,
    ``--pytorch-buildteacher`` and ``--pytorch-preprocess`` on top of the
    parser's own arguments.
    """
    # Get command line arguments
    argparser = ParlaiParser(True, True)
    build = argparser.add_argument_group('Data Building Args')
    build.add_argument(
        '--datafile',
        help='The file to be loaded, preprocessed, and saved',
    )
    build.add_argument(
        '--pytorch-buildteacher',
        type=str,
        default='',
        help='Which teacher to use when building the pytorch data',
    )
    build.add_argument(
        '--pytorch-preprocess',
        type='bool',
        default=True,
        # The trailing space matters: the two literals are concatenated
        # implicitly and previously rendered as "buildingthe".
        help='Whether the agent should preprocess the data while building '
             'the pytorch data',
    )
    opt = argparser.parse_args()
    build_data(opt)
Пример #9
0
def main():
    """Run a ``ConvAISampleAgent`` inside a ``ConvAIWorld`` forever.

    Exceptions raised by a step are printed and the loop carries on, so the
    function never returns on its own.
    """
    cli = ParlaiParser(True, True)
    ConvAIWorld.add_cmdline_args(cli)
    opt = cli.parse_args()

    world = ConvAIWorld(opt, [ConvAISampleAgent(opt)])

    while True:
        try:
            world.parley()
        except Exception as err:
            # Best-effort loop: report the failure and keep going.
            print('Exception: {}'.format(err))
Пример #10
0
    def test_fvqa(self):
        """Build the fvqa default teacher for two datatypes and check one reply each.

        The temporary path ``self.TMP_PATH`` is removed afterwards.
        """
        from parlai.core.params import ParlaiParser

        cli = ParlaiParser()
        cli.add_task_args(['-t', 'fvqa'])
        opt = cli.parse_args(args=self.args)

        from parlai.tasks.fvqa.agents import DefaultTeacher

        for datatype in ('train:ordered', 'test'):
            opt['datatype'] = datatype
            first_reply = DefaultTeacher(opt).act()
            check(opt, first_reply)

        shutil.rmtree(self.TMP_PATH)
Пример #11
0
def main():
    """Create MTurk HITs in which a worker evaluates an ``IrBaselineAgent``.

    One ``run_hit`` job is scheduled per (hit, assignment) pair on a
    thread-backed ``Parallel`` pool; the manager is shut down once all jobs
    have returned.
    """
    argparser = ParlaiParser(False, False)
    argparser.add_parlai_data_path()
    argparser.add_mturk_args()

    # The dialog model we want to evaluate
    from parlai.agents.ir_baseline.ir_baseline import IrBaselineAgent
    IrBaselineAgent.add_cmdline_args(argparser)
    opt = argparser.parse_args()
    this_dir = os.path.dirname(os.path.abspath(__file__))
    opt['task'] = os.path.basename(this_dir)
    opt.update(task_config)

    # The task that we will evaluate the dialog model on
    task_opt = {
        'datatype': 'test',
        'datapath': opt['datapath'],
        'task': '#MovieDD-Reddit',
    }

    mturk_agent_id = 'Worker'
    mturk_manager = MTurkManager(
        opt=opt,
        mturk_agent_ids=[mturk_agent_id],
        # In speaking order
        all_agent_ids=[ModelEvaluatorWorld.evaluator_agent_id, mturk_agent_id],
    )
    mturk_manager.init_aws(opt=opt)

    # run_hit is deliberately bound at module level via ``global``.
    global run_hit
    def run_hit(hit_index, assignment_index, opt, task_opt, mturk_manager):
        conversation_id = '{}_{}'.format(hit_index, assignment_index)

        model_agent = IrBaselineAgent(opt=opt)
        # Create the MTurk agent which provides a chat interface to the Turker
        turker = MTurkAgent(
            id=mturk_agent_id,
            manager=mturk_manager,
            conversation_id=conversation_id,
            opt=opt,
        )
        world = ModelEvaluatorWorld(
            opt=opt, model_agent=model_agent, task_opt=task_opt, mturk_agent=turker
        )

        while not world.episode_done():
            world.parley()
        world.shutdown()
        world.review_work()

    mturk_manager.create_hits(opt=opt)
    num_hits = opt['num_hits']
    num_assignments = opt['num_assignments']
    pool = Parallel(n_jobs=num_hits * num_assignments, backend='threading')
    # Indices are 1-based for both hits and assignments.
    pool(
        delayed(run_hit)(hit_index, assignment_index, opt, task_opt, mturk_manager)
        for hit_index, assignment_index in product(
            range(1, num_hits + 1), range(1, num_assignments + 1)
        )
    )
    mturk_manager.shutdown()
Пример #12
0
def main():
    """Create MTurk HITs pairing two local humans with two MTurk agents.

    One ``run_hit`` job is scheduled per (hit, assignment) pair on a
    thread-backed ``Parallel`` pool; each job runs a
    ``MultiAgentDialogWorld`` until its episode is done.
    """
    argparser = ParlaiParser(False, False)
    argparser.add_parlai_data_path()
    argparser.add_mturk_args()
    opt = argparser.parse_args()
    this_dir = os.path.dirname(os.path.abspath(__file__))
    opt['task'] = os.path.basename(this_dir)
    opt.update(task_config)

    # Tuples keep the ids immutable; fresh lists are handed to the manager.
    human_ids = ('human_1', 'human_2')
    mturk_ids = ('mturk_agent_1', 'mturk_agent_2')
    mturk_manager = MTurkManager(
        opt=opt,
        mturk_agent_ids=list(mturk_ids),
        # In speaking order
        all_agent_ids=list(human_ids + mturk_ids),
    )
    mturk_manager.init_aws(opt=opt)

    # run_hit is deliberately bound at module level via ``global``.
    global run_hit
    def run_hit(hit_index, assignment_index, opt, mturk_manager):
        conversation_id = '{}_{}'.format(hit_index, assignment_index)

        # Create mturk agents
        turkers = [
            MTurkAgent(
                id=agent_id,
                manager=mturk_manager,
                conversation_id=conversation_id,
                opt=opt,
            )
            for agent_id in mturk_ids
        ]

        # Create the local human agents
        humans = []
        for agent_id in human_ids:
            human = LocalHumanAgent(opt=None)
            human.id = agent_id
            humans.append(human)

        world = MultiAgentDialogWorld(opt=opt, agents=humans + turkers)

        while not world.episode_done():
            world.parley()
        world.shutdown()

    mturk_manager.create_hits(opt=opt)
    num_hits = opt['num_hits']
    num_assignments = opt['num_assignments']
    pool = Parallel(n_jobs=num_hits * num_assignments, backend='threading')
    # Indices are 1-based for both hits and assignments.
    pool(
        delayed(run_hit)(hit_index, assignment_index, opt, mturk_manager)
        for hit_index, assignment_index in product(
            range(1, num_hits + 1), range(1, num_assignments + 1)
        )
    )
    mturk_manager.shutdown()
Пример #13
0
def main():
    """Create MTurk HITs that run a ``QADataCollectionWorld`` per assignment.

    Each job builds a fresh teacher from ``parlai.tasks.squad.agents`` and an
    ``MTurkAgent``, runs the world until its episode is done, shuts it down
    and reviews the work.
    """
    argparser = ParlaiParser(False, False)
    argparser.add_parlai_data_path()
    argparser.add_mturk_args()
    opt = argparser.parse_args()
    this_dir = os.path.dirname(os.path.abspath(__file__))
    opt['task'] = os.path.basename(this_dir)
    opt.update(task_config)

    # Teacher class from which each conversation's task is built.
    task_class = getattr(
        importlib.import_module('parlai.tasks.squad.agents'), 'DefaultTeacher'
    )
    task_opt = {'datatype': 'train', 'datapath': opt['datapath']}

    mturk_agent_id = 'Worker'
    mturk_manager = MTurkManager(
        opt=opt,
        mturk_agent_ids=[mturk_agent_id],
        # In speaking order
        all_agent_ids=[QADataCollectionWorld.collector_agent_id, mturk_agent_id],
    )
    mturk_manager.init_aws(opt=opt)

    # run_hit is deliberately bound at module level via ``global``.
    global run_hit
    def run_hit(hit_index, assignment_index, task_class, task_opt, opt, mturk_manager):
        conversation_id = '{}_{}'.format(hit_index, assignment_index)

        task = task_class(task_opt)
        # Create the MTurk agent which provides a chat interface to the Turker
        turker = MTurkAgent(
            id=mturk_agent_id,
            manager=mturk_manager,
            conversation_id=conversation_id,
            opt=opt,
        )
        world = QADataCollectionWorld(opt=opt, task=task, mturk_agent=turker)
        while not world.episode_done():
            world.parley()
        world.shutdown()
        world.review_work()

    mturk_manager.create_hits(opt=opt)
    num_hits = opt['num_hits']
    num_assignments = opt['num_assignments']
    pool = Parallel(n_jobs=num_hits * num_assignments, backend='threading')
    # Indices are 1-based for both hits and assignments.
    pool(
        delayed(run_hit)(
            hit_index, assignment_index, task_class, task_opt, opt, mturk_manager
        )
        for hit_index, assignment_index in product(
            range(1, num_hits + 1), range(1, num_assignments + 1)
        )
    )
    mturk_manager.shutdown()
Пример #14
0
    def test_upgrade_opt(self):
        """
        Test whether upgrade_opt works.

        Writes a dummy model file and a matching ``.opt`` JSON file into a
        temporary directory, then creates an agent from ``--model-file``.
        """
        with testing_utils.tempdir() as workdir:
            model_path = os.path.join(workdir, 'model')
            with open(model_path, 'w') as handle:
                handle.write('Test.')

            saved_opt = {
                'model': 'tests.test_params:_ExampleUpgradeOptAgent',
                'dict_file': model_path + '.dict',
                'model_file': model_path,
            }
            # The opt file sits next to the model file with an '.opt' suffix.
            with open(model_path + '.opt', 'w') as handle:
                json.dump(saved_opt, handle)

            opt = ParlaiParser(True, True).parse_args(['--model-file', model_path])
            agents.create_agent(opt)
Пример #15
0
    def setUp(self):
        """Create an ``MTurkManager`` fixture and register three test workers.

        Any leftover 'disconnect-test.pickle' file in ``parent_dir`` is
        deleted first.
        """
        stale_pickle = os.path.join(parent_dir, 'disconnect-test.pickle')
        if os.path.exists(stale_pickle):
            os.remove(stale_pickle)

        parser = ParlaiParser(False, False)
        parser.add_parlai_data_path()
        parser.add_mturk_args()
        # An explicit empty list keeps the runner's own argv out of the parse.
        self.opt = parser.parse_args([])
        self.opt['task'] = 'unittest'
        self.opt['assignment_duration_in_seconds'] = 6

        # The manager receives its own copy of the options.
        self.mturk_manager = MTurkManager(
            opt=self.opt.copy(), mturk_agent_ids=['mturk_agent_1']
        )
        self.worker_manager = self.mturk_manager.worker_manager

        mark_alive = self.worker_manager.worker_alive
        self.worker_state_1 = mark_alive(TEST_WORKER_ID_1)
        self.worker_state_2 = mark_alive(TEST_WORKER_ID_2)
        self.worker_state_3 = mark_alive(TEST_WORKER_ID_3)
Пример #16
0
def main():
    """Print a few examples from a task using a ``RepeatLabelAgent``.

    Steps the world at most ``--num-examples`` times and stops early when the
    world reports that the epoch is done.
    """
    random.seed(42)

    cli = ParlaiParser()
    cli.add_argument('-n', '--num-examples', default=10)
    opt = cli.parse_args()

    # The agent is attached to whatever task the options specify.
    repeater = RepeatLabelAgent(opt)
    world = create_task(opt, repeater)

    with world:
        # No ``type`` was declared for the flag, so the value is cast here.
        remaining = int(opt['num_examples'])
        while remaining > 0:
            remaining -= 1
            world.parley()
            print(world.display() + "\n~~")
            if world.epoch_done():
                print("EPOCH DONE")
                break
Пример #17
0
def main():
    """Print a few examples from a task using a ``RepeatLabelAgent``.

    ``--num-examples`` is parsed as an int and bounds the number of steps;
    the loop ends early when the world reports that the epoch is done.
    """
    random.seed(42)

    cli = ParlaiParser()
    cli.add_argument('-n', '--num-examples', type=int, default=10)
    opt = cli.parse_args()

    # The agent is attached to whatever task the options specify.
    repeater = RepeatLabelAgent(opt)
    world = create_task(opt, repeater)

    with world:
        remaining = opt['num_examples']
        while remaining > 0:
            remaining -= 1
            world.parley()
            print(world.display() + '\n~~')
            if world.epoch_done():
                print('EPOCH DONE')
                break
Пример #18
0
    def _create_safety_model(self, custom_model_file, device):
        """Create a 'transformer/classifier' agent from ``custom_model_file``.

        ``device`` is inspected as a string: if it contains "cuda" the
        ``no_cuda`` override is set to False, otherwise True. A device of the
        form "cuda:N" selects GPU N; a "cuda" device without an index selects
        GPU 0. The model file is required to exist.
        """
        from parlai.core.params import ParlaiParser

        parser = ParlaiParser(False, False)
        TransformerClassifierAgent.add_cmdline_args(parser, partial_opt=None)
        parser.set_params(
            model='transformer/classifier',
            model_file=custom_model_file,
            print_scores=True,
            data_parallel=False,
        )
        safety_opt = parser.parse_args([])

        # Device choices are written into the existing "override" mapping.
        overrides = safety_opt["override"]
        wants_cuda = "cuda" in device
        overrides["no_cuda"] = not wants_cuda

        if "cuda:" in device:
            gpu_index = device.split(":")[1]
            overrides["gpu"] = int(gpu_index)
        elif wants_cuda:
            overrides["gpu"] = 0

        return create_agent(safety_opt, requireModelExists=True)
Пример #19
0
def main():
    """Create an agent from the options and show a few task examples.

    Steps the world at most ``--num-examples`` times and stops early when the
    world reports that the epoch is done.
    """
    random.seed(42)

    cli = ParlaiParser(True, True)
    cli.add_argument('-n', '--num-examples', default=10)
    opt = cli.parse_args()

    # Create model and assign it to the specified task
    model_agent = create_agent(opt)
    world = create_task(opt, model_agent)

    with world:
        # No ``type`` was declared for the flag, so the value is cast here.
        remaining = int(opt['num_examples'])
        while remaining > 0:
            remaining -= 1
            world.parley()
            print(world.display() + "\n~~")
            if world.epoch_done():
                print("EPOCH DONE")
                break
Пример #20
0
def main():
    """Run one dialog between two local humans and two MTurk agents.

    All agent configuration is passed through ``opt``: the id lists are stored
    there, and ``opt['agent_id']`` is set before each ``MTurkAgent`` is
    constructed. The task name is the basename of the current working
    directory.
    """
    argparser = ParlaiParser(False, False)
    argparser.add_parlai_data_path()
    argparser.add_mturk_args()
    opt = argparser.parse_args()
    opt['task'] = os.path.basename(os.getcwd())

    human_ids = ('human_1', 'human_2')
    mturk_ids = ('mturk_agent_1', 'mturk_agent_2')

    # Create the MTurk agents
    opt.update(task_config)
    opt['conversation_id'] = str(int(time.time()))
    opt['mturk_agent_ids'] = list(mturk_ids)
    opt['all_agent_ids'] = list(human_ids + mturk_ids)

    turkers = []
    for agent_id in mturk_ids:
        # The same opt object is reused with a different agent_id each time.
        opt['agent_id'] = agent_id
        turkers.append(MTurkAgent(opt=opt))

    # Create the local human agents
    humans = []
    for agent_id in human_ids:
        human = LocalHumanAgent(opt=None)
        human.id = agent_id
        humans.append(human)

    world = MultiAgentDialogWorld(opt=opt, agents=humans + turkers)

    while not world.episode_done():
        world.parley()

    world.shutdown()
Пример #21
0
    def _test_resize_embeddings(self, model):
        """Train ``model`` briefly, then reload it with extra special tokens.

        The reloaded agent must report ``resized_embeddings`` and expose the
        two special tokens 'PARTY' and 'PARROT'.
        """
        with testing_utils.tempdir() as tmpdir:
            model_file = os.path.join(tmpdir, 'model_file')

            # Settings shared by the training run and the reloaded agent.
            common = dict(
                model=model,
                task='integration_tests:short_fixed',
                n_layers=1,
                n_encoder_layers=2,
                n_decoder_layers=4,
                dict_tokenizer='bytelevelbpe',
                bpe_vocab=DEFAULT_BYTELEVEL_BPE_VOCAB,
                bpe_merge=DEFAULT_BYTELEVEL_BPE_MERGE,
                bpe_add_prefix_space=False,
                model_file=model_file,
                save_after_valid=True,
            )

            # Both returned values are ignored; only the saved model matters.
            _first, _second = testing_utils.train_model(dict(common, num_epochs=1))

            # now create agent with special tokens
            parser = ParlaiParser()
            parser.set_params(special_tok_lst='PARTY,PARROT', **common)
            agent = create_agent(parser.parse_args([]))

            # assert that the embeddings were resized
            assert agent.resized_embeddings
            # assert model has special tokens
            self.assertEqual(agent.special_toks, ['PARTY', 'PARROT'])
Пример #22
0
    def get_bot_agents(args: DictConfig,
                       model_opts: Dict[str, str],
                       no_cuda=False) -> Dict[str, dict]:
        """
        Return shared bot agents.

        `model_opts` maps a model name to a string of command-line style model
        params (i.e. `--model image_seq2seq`). Each string is split on
        whitespace and parsed, an agent is created from the result (the model
        is required to exist), and the value returned by its `share()` is
        stored under the model name.
        """
        # Overrides applied to every model via the parser.
        overrides = {'model_parallel': args.blueprint.task_model_parallel}
        if not no_cuda:
            logging.warning(
                'WARNING: MTurk task has no_cuda FALSE. Models will run on GPU. Will '
                'not work if loading many models at once.')
        else:
            # If we load many models at once, we have to keep it on CPU
            overrides['no_cuda'] = no_cuda

        # All opt strings are parsed before any agent is created.
        parser = ParlaiParser(True, True)
        parser.set_params(**overrides)
        processed_opts = {
            name: parser.parse_args(opt_string.split())
            for name, opt_string in model_opts.items()
        }

        logging.info(
            f'Got {len(processed_opts)} models: {processed_opts.keys()}.'
        )

        # Load and share all model agents
        shared_bot_agents = {}
        for model_name in processed_opts:
            model_opt = processed_opts[model_name]
            logging.info('\n\n--------------------------------')
            logging.info(f'model_name: {model_name}, opt_dict: {model_opt}')
            bot = create_agent(model_opt, requireModelExists=True)
            shared_bot_agents[model_name] = bot.share()
        return shared_bot_agents
Пример #23
0
    def __init__(self, args=None, **kwargs):
        """Initialize the predictor, setting up opt automatically if needed.

        Args is expected to be in the same format as sys.argv: e.g. a list in
        the form ['--model', 'seq2seq', '-hs', 128, '-lr', 0.5]. Items are
        converted with ``str`` before parsing, and the list passed in is left
        unmodified.

        kwargs is interpreted by appending '--' to it and replacing underscores
        with hyphens, so 'dict_file=/tmp/dict.tsv' would be interpreted as
        '--dict-file /tmp/dict.tsv'.
        """
        from parlai.core.params import ParlaiParser
        from parlai.core.agents import create_agent

        # Build a private list: appending to ``args`` directly would mutate
        # the caller's list, and non-string items (as in the example above)
        # are not accepted by argparse.
        argv = [] if args is None else [str(item) for item in args]
        for key, value in kwargs.items():
            argv.append('--' + str(key).replace('_', '-'))
            argv.append(str(value))
        parser = ParlaiParser(True, True)
        self.opt = parser.parse_args(argv)
        self.agent = create_agent(self.opt)
Пример #24
0
    def __init__(self, args=None, **kwargs):
        """Initialize the predictor, setting up opt automatically if necessary.

        Args is expected to be in the same format as sys.argv: e.g. a list in
        the form ['--model', 'seq2seq', '-hs', 128, '-lr', 0.5]. Items are
        converted with ``str`` before parsing, and the list passed in is left
        unmodified.

        kwargs is interpreted by appending '--' to it and replacing underscores
        with hyphens, so 'dict_file=/tmp/dict.tsv' would be interpreted as
        '--dict-file /tmp/dict.tsv'.
        """
        from parlai.core.params import ParlaiParser
        from parlai.core.agents import create_agent

        # Build a private list: appending to ``args`` directly would mutate
        # the caller's list, and non-string items (as in the example above)
        # are not accepted by argparse.
        argv = [] if args is None else [str(item) for item in args]
        for key, value in kwargs.items():
            argv.append('--' + str(key).replace('_', '-'))
            argv.append(str(value))
        # The same argument list is given to the parser constructor and to
        # parse_args.
        parser = ParlaiParser(True, True, model_argv=argv)
        self.opt = parser.parse_args(argv)
        self.agent = create_agent(self.opt)
Пример #25
0
def setup_args():
    """
    Parse the terminal chat options from the command line.

    Registers ``--port``, ``--host`` and ``--jsonl`` in a 'Terminal Chat'
    argument group.

    :return: The result of ``parse_args()`` on the configured parser.
    """
    parser = ParlaiParser(False, False)
    chat_group = parser.add_argument_group('Terminal Chat')

    # (flag, type, default, help) for every option in the group.
    option_specs = (
        ('--port', int, 35496, 'Port to run the terminal chat server'),
        ('--host', str, 'localhost', 'Host to run the terminal chat server'),
        ('--jsonl', str, 'conv.jsonl', 'Where to store conversation'),
    )
    for flag, value_type, default, help_text in option_specs:
        chat_group.add_argument(
            flag, default=default, type=value_type, help=help_text
        )
    return parser.parse_args()
def main():
    """Run a ``TransformerAgent`` inside a ``ConvAIWorld`` forever.

    Parser defaults are set before the world's and agent's own arguments are
    registered. Exceptions raised by a step are printed and the loop carries
    on, so the function never returns on its own.
    """
    preset_defaults = dict(
        batchsize=10,
        sample=True,
        max_seq_len=256,
        beam_size=3,
        annealing_topk=None,
        annealing=0.6,
        length_penalty=0.7,
    )

    cli = ParlaiParser(True, True)
    cli.set_defaults(**preset_defaults)
    ConvAIWorld.add_cmdline_args(cli)
    TransformerAgent.add_cmdline_args(cli)
    opt = cli.parse_args()

    world = ConvAIWorld(opt, [TransformerAgent(opt)])

    while True:
        try:
            world.parley()
        except Exception as err:
            # Best-effort loop: report the failure and keep going.
            print('Exception: {}'.format(err))
Пример #27
0
def main():
    """Step a model agent through a world until the epoch is done.

    The 'task' option is overwritten with the ``LocalHumanAgent`` path and
    the full options are printed before the agent and world are created.
    The world display is printed after each step only when
    ``--display-examples`` is set.
    """
    random.seed(42)

    cli = ParlaiParser(True, True)
    cli.add_argument('-d', '--display-examples', type='bool', default=False)
    opt = cli.parse_args()
    opt['task'] = 'parlai.agents.local_human.local_human:LocalHumanAgent'
    print(opt)

    # Create model and assign it to the specified task
    model_agent = create_agent(opt)
    world = create_task(opt, model_agent)

    # At least one step always runs; the epoch check follows each step.
    epoch_finished = False
    while not epoch_finished:
        world.parley()
        if opt['display_examples']:
            print("---")
            print(world.display() + "\n~~")
        epoch_finished = world.epoch_done()
    print("EPOCH DONE")
    def run_actor(self):
        """Load the agent from a fixed model file, run one epoch and reply.

        The logger level is raised to ``ERROR``, an agent and a world are
        created from the parsed options, and the world is stepped until its
        epoch is done. The agent then observes ``obs`` and the constant string
        'Hello' is returned.

        The previously declared ``DB_PATH`` and ``TFIDF_PATH`` locals were
        never read and have been removed.
        """
        model_file = '/Volumes/Data/ParlAI/from_pretrained_wiki/wikiqa_tdifd'
        # keep things quiet
        logger.setLevel(ERROR)
        parser = ParlaiParser(True, True)
        parser.set_defaults(
            model_file=model_file,
            interactive_mode=True,
        )
        opt = parser.parse_args([], print_args=False)
        # Set again after parsing, in addition to the parser default above.
        opt['interactive_mode'] = True
        agent = create_agent(opt)
        train_world = create_task(opt, agent)
        # pass examples to dictionary
        while not train_world.epoch_done():
            train_world.parley()
        # NOTE(review): `obs` is not defined in this method; unless a
        # module-level `obs` exists this line raises NameError. Confirm where
        # the observation is supposed to come from.
        agent.observe(obs)
        return 'Hello'
Пример #29
0
def main():
    """Step a model agent through a world until the epoch is done.

    The 'task' option is overwritten with the ``LocalHumanAgent`` path and
    the full options are printed before the agent and world are created.
    The world display is printed after each step only when
    ``--display-examples`` is set.
    """
    random.seed(42)

    cli = ParlaiParser(True, True)
    cli.add_argument('-d', '--display-examples', type='bool', default=False)
    opt = cli.parse_args()
    opt['task'] = 'parlai.agents.local_human.local_human:LocalHumanAgent'
    print(opt)

    # Create model and assign it to the specified task
    model_agent = create_agent(opt)
    world = create_task(opt, model_agent)

    # At least one step always runs; the epoch check follows each step.
    epoch_finished = False
    while not epoch_finished:
        world.parley()
        if opt['display_examples']:
            print("---")
            print(world.display() + "\n~~")
        epoch_finished = world.epoch_done()
    print("EPOCH DONE")
Пример #30
0
def main():
    """Parse the candidate-building options and pass them to ``build_cands``.

    The datatype defaults to 'train:evalmode'.
    """
    random.seed(42)

    cli = ParlaiParser()
    cli.add_argument(
        '-n', '--num-examples', type=int, default=-1,
        help='Total number of exs to convert, -1 to convert all examples',
    )
    cli.add_argument(
        '-of', '--outfile', type=str, default=None,
        help='Output file where to save, by default will be created in /tmp',
    )
    cli.add_argument('-ltim', '--log-every-n-secs', default=2, type=float)
    cli.set_defaults(datatype='train:evalmode')

    build_cands(cli.parse_args())
Пример #31
0
def main():
    """Create an agent from the options and show a few task examples.

    The datatype defaults to 'valid'. The world is stepped at most
    ``--num-examples`` times and the loop stops early when the world reports
    that the epoch is done.
    """
    random.seed(42)

    cli = ParlaiParser(True, True)
    cli.add_argument('-n', '--num-examples', default=10)
    # by default we want to display info about the validation set
    cli.set_defaults(datatype='valid')
    opt = cli.parse_args()

    # Create model and assign it to the specified task
    model_agent = create_agent(opt)
    world = create_task(opt, model_agent)

    with world:
        # No ``type`` was declared for the flag, so the value is cast here.
        remaining = int(opt['num_examples'])
        while remaining > 0:
            remaining -= 1
            world.parley()
            print(world.display() + "\n~~")
            if world.epoch_done():
                print("EPOCH DONE")
                break
Пример #32
0
def main():
    """Run one model-evaluation dialog between an IR baseline and an MTurk agent.

    The model agent is built before ``task_config`` is merged into the
    options. All MTurk agent configuration is passed through ``opt``. The task
    name is the basename of the current working directory.
    """
    argparser = ParlaiParser(False, False)
    argparser.add_parlai_data_path()
    argparser.add_mturk_args()

    # The dialog model we want to evaluate
    from parlai.agents.ir_baseline.ir_baseline import IrBaselineAgent
    IrBaselineAgent.add_cmdline_args(argparser)
    opt = argparser.parse_args()
    opt['task'] = os.path.basename(os.getcwd())
    model_agent = IrBaselineAgent(opt=opt)

    # The task that we will evaluate the dialog model on
    task_opt = {
        'datatype': 'test',
        'datapath': opt['datapath'],
        'task': '#MovieDD-Reddit',
    }

    # Create the MTurk agent which provides a chat interface to the Turker
    worker_id = 'Worker'
    opt.update(task_config)
    opt['agent_id'] = worker_id
    opt['mturk_agent_ids'] = [worker_id]
    opt['all_agent_ids'] = [ModelEvaluatorWorld.evaluator_agent_id, worker_id]
    opt['conversation_id'] = str(int(time.time()))

    world = ModelEvaluatorWorld(
        opt=opt,
        model_agent=model_agent,
        task_opt=task_opt,
        mturk_agent=MTurkAgent(opt=opt),
    )

    while not world.episode_done():
        world.parley()

    world.shutdown()
Пример #33
0
    def setUp(self):
        """Create an ``MTurkManager`` with mocked senders and one ``MTurkAgent``.

        ``send_message``, ``send_state_change`` and ``send_command`` on the
        manager are replaced with ``MagicMock`` objects. The manager and the
        agent each receive their own copy of the options.
        """
        argparser = ParlaiParser(False, False)
        argparser.add_parlai_data_path()
        argparser.add_mturk_args()
        # Parse an explicit empty argument list: without it the parser reads
        # sys.argv, which under a test runner holds the runner's own flags.
        self.opt = argparser.parse_args([])
        self.opt['task'] = 'unittest'
        self.opt['assignment_duration_in_seconds'] = 6
        mturk_agent_ids = ['mturk_agent_1']
        self.mturk_manager = MTurkManager(opt=self.opt.copy(),
                                          mturk_agent_ids=mturk_agent_ids)
        self.worker_manager = self.mturk_manager.worker_manager
        self.mturk_manager.send_message = mock.MagicMock()
        self.mturk_manager.send_state_change = mock.MagicMock()
        self.mturk_manager.send_command = mock.MagicMock()

        self.turk_agent = MTurkAgent(
            self.opt.copy(),
            self.mturk_manager,
            TEST_HIT_ID_1,
            TEST_ASSIGNMENT_ID_1,
            TEST_WORKER_ID_1,
        )
Пример #34
0
    def test_fused_adam(self):
        """Check loading of the zoo 'apex_fused_adam' unit-test model.

        Training from the model file without further options is expected to
        raise ``ImportError``; creating the agent with '--optimizer adam'
        given explicitly is expected to succeed.
        """
        zoo_model = 'zoo:unittest/apex_fused_adam/model'
        task = 'integration_tests:nocandidate'

        # we should crash if the user tries not giving --opt adam
        with self.assertRaises(ImportError):
            testing_utils.train_model(dict(model_file=zoo_model, task=task))

        # no problem if we give the option
        cli_args = [
            '--model-file', zoo_model,
            '--dict-file', zoo_model + '.dict',
            '--task', task,
            '--optimizer', 'adam',
        ]
        opt = ParlaiParser(True, True).parse_args(cli_args)
        create_agent(opt, requireModelExists=True)
Пример #35
0
def main(args=None):
    """Parse the training-loop options and start training.

    Cross-validation training is run when more than one split is requested
    and no specific model index is given; otherwise a single model is trained.

    Args:
        args: optional argument list forwarded to ``parser.parse_args``.
    """
    parser = ParlaiParser(True, True)
    loop_group = parser.add_argument_group('Training Loop Arguments')
    loop_group.add_argument(
        '-et', '--evaltask',
        help='task to use for valid/test (defaults to the '
             'one used for training if not set)',
    )
    loop_group.add_argument(
        '-d', '--display-examples', type='bool', default=False
    )
    loop_group.add_argument('-e', '--num-epochs', type=float, default=-1)
    loop_group.add_argument(
        '-ttim', '--max-train-time', type=float, default=-1
    )
    loop_group.add_argument(
        '-ltim', '--log-every-n-secs', type=float, default=2
    )
    loop_group.add_argument(
        '-le', '--log-every-n-epochs', type=int, default=0
    )
    loop_group.add_argument(
        '-vtim', '--validation-every-n-secs', type=float, default=-1
    )
    loop_group.add_argument(
        '-ve', '--validation-every-n-epochs', type=int, default=0
    )
    loop_group.add_argument(
        '-vme', '--validation-max-exs', type=int, default=-1,
        help='max examples to use during validation (default '
             '-1 uses all)',
    )
    loop_group.add_argument(
        '-vp', '--validation-patience', type=int, default=5,
        help='number of iterations of validation where result '
             'does not improve before we stop training',
    )
    loop_group.add_argument(
        '-dbf', '--dict-build-first', type='bool', default=True,
        help='build dictionary first before training agent',
    )
    loop_group.add_argument(
        '--chosen-metric', default='accuracy',
        help='metric with which to measure improvement',
    )
    opt = parser.parse_args(args=args)

    use_cross_validation = (
        opt.get('cross_validation_splits_count', 0) > 1
        and opt.get('cross_validation_model_index') is None
    )
    if not use_cross_validation:
        train_model(opt)
    else:
        train_cross_valid(opt)
Пример #36
0
def main():
    """Start the messenger server and run the demo worlds until shutdown."""
    parser = ParlaiParser(False, False)
    parser.add_parlai_data_path()
    parser.add_messenger_args()
    opt = parser.parse_args()
    task_dir = os.path.dirname(os.path.abspath(__file__))
    opt['task'] = os.path.basename(task_dir)
    opt['password'] = '******'  # If password is none anyone can chat

    manager = MessengerManager(opt=opt)
    manager.setup_server()
    manager.init_new_state()

    def get_overworld(opt, agent):
        return MessengerOverworld(opt, agent)

    def per_demo(pick):
        # Map every demo name to a value taken from its tuple of worlds.
        return {
            demo_name: pick(demo_worlds)
            for demo_name, demo_worlds in MessengerOverworld.DEMOS.items()
        }

    # Index 0 of each DEMOS entry is used for onboarding, index 1 for the task.
    manager.set_onboard_functions(per_demo(lambda w: w[0].run))
    task_functions = per_demo(lambda w: w[1].run)
    assign_agent_roles = per_demo(lambda w: w[1].assign_roles)
    manager.set_agents_required(per_demo(lambda w: w[1].MAX_AGENTS))

    manager.set_overworld_func(get_overworld)
    manager.setup_socket()
    try:
        manager.start_new_run()
        manager.start_task(
            assign_role_functions=assign_agent_roles,
            task_functions=task_functions,
        )
    finally:
        # Always shut down, whether the run finished or raised.
        manager.shutdown()
Пример #37
0
def main():
    """Run a model on a task, printing the world's report after each parley."""
    random.seed(42)

    # Command line options
    parser = ParlaiParser(True, True)
    parser.add_argument('-n', '--num_examples', default=1000)
    parser.add_argument('-d', '--display_examples', type='bool', default=False)
    opt = parser.parse_args()

    # Build the model and attach it to the requested task
    agent = create_agent(opt)
    world = create_task(opt, agent)

    # Step through at most num_examples parleys, stopping early at epoch end
    remaining = int(opt['num_examples'])
    while remaining > 0:
        remaining -= 1
        world.parley()
        print("---")
        if opt['display_examples']:
            print(world.display() + "\n~~")
        print(world.report())
        if world.epoch_done():
            print("EPOCH DONE")
            break
    world.shutdown()
Пример #38
0
def main():
    """Run one ordered pass over the training data, logging progress."""
    # Command line options; the data is walked in order by default
    parser = ParlaiParser(True, False)
    parser.set_defaults(datatype='train:ordered')
    opt = parser.parse_args()

    batch_size = opt.get('batchsize', 1)
    opt['no_cuda'] = False
    opt['gpu'] = 0
    opt['num_epochs'] = 1

    # A repeat-label agent is enough to drive the task world
    world = create_task(opt, RepeatLabelAgent(opt))

    progress = ProgressLogger(should_humanize=False)
    print("Beginning image extraction...")
    seen = 0
    total = world.num_examples()
    while True:
        if world.epoch_done():
            break
        world.parley()
        # Each parley is counted as one full batch of examples
        seen += batch_size
        progress.log(seen, total)
    print("Finished extracting images")
Пример #39
0
def parse_options():
    """Build the command line parser for this script and parse the arguments.

    Adds ``--load-path``, ``--print-conv`` and ``--conv-save-path`` on top of
    the default ``ParlaiParser`` options.

    Returns:
        The value returned by ``parser.parse_args()``.
    """
    parser = ParlaiParser()
    pth_file = "world_best.pth"
    parser.add_argument(
        "--load-path",
        type=str,
        default=pth_file,
        help="path to pth file of the world checkpoint",
    )
    parser.add_argument(
        "--print-conv",
        default=False,
        action="store_true",
        help="whether to print the conversation between bots or not",
    )
    # The help text here used to be a copy of the --print-conv help.
    parser.add_argument(
        "--conv-save-path",
        type=str,
        default=None,
        help="path to save the conversation between bots to (default: None)",
    )
    return parser.parse_args()
Пример #40
0
def main():
    """Evaluate a model on a task (``valid`` datatype by default)."""
    random.seed(42)

    # Get command line arguments
    arg_parser = ParlaiParser(True, True)
    arg_parser.add_argument('-n', '--num-examples', default=100000000)
    arg_parser.add_argument(
        '-d', '--display-examples', type='bool', default=False
    )
    arg_parser.set_defaults(datatype='valid')
    opt = arg_parser.parse_args()

    # Create model and assign it to the specified task
    model = create_agent(opt)
    world = create_task(opt, model)

    # Print a report after every parley until the cap or the epoch end
    max_steps = int(opt['num_examples'])
    for _ in range(max_steps):
        world.parley()
        print("---")
        if opt['display_examples']:
            shown = world.display()
            print(shown + "\n~~")
        print(world.report())
        if not world.epoch_done():
            continue
        print("EPOCH DONE")
        break
    world.shutdown()
Пример #41
0
def main():
    """Run ``ir_baseline`` over the ordered training data, then validate it."""
    # Command line options
    parser = ParlaiParser(True, True)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    opt = parser.parse_args()

    # The model is forced before the agent is built, the datatype before the
    # training world is built.
    opt['model'] = 'ir_baseline'
    agent = create_agent(opt)
    opt['datatype'] = 'train:ordered'
    train_world = create_task(opt, agent)

    # One pass over the training world.
    train_steps = len(train_world)
    step = 0
    while step < train_steps:
        step += 1
        train_world.parley()
        if opt['display_examples']:
            print(train_world.display() + "\n~~")
        if train_world.epoch_done():
            print("TRAIN EPOCH DONE")
            break
    train_world.shutdown()

    # Save only when a model file was configured.
    model_file = opt['model_file']
    if model_file:
        agent.save(model_file)

    # Now eval on validation data.
    opt['datatype'] = 'valid'
    valid_world = create_task(opt, agent)
    valid_steps = len(valid_world)
    step = 0
    while step < valid_steps:
        step += 1
        valid_world.parley()
        print("---")
        if opt['display_examples']:
            print(valid_world.display() + "\n~~")
        print(valid_world.report())
        if valid_world.epoch_done():
            print("VALIDATION DONE")
            break
    valid_world.shutdown()
Пример #42
0
    def test_knowledge_retriever(self):
        """Check title, knowledge token and checked sentence of one reply."""
        from parlai.core.params import ParlaiParser

        retriever_parser = ParlaiParser(False, False)
        KnowledgeRetrieverAgent.add_cmdline_args(retriever_parser)
        retriever_parser.set_params(
            model='projects:wizard_of_wikipedia:knowledge_retriever',
            add_token_knowledge=True,
        )
        retriever_opt = retriever_parser.parse_args([], print_args=False)
        retriever = create_agent(retriever_opt)

        observation = {
            'text': 'what do you think of mountain dew?',
            'chosen_topic': 'Mountain Dew',
            'episode_done': False,
        }
        retriever.observe(observation)
        reply = retriever.act()

        self.assertEqual(
            reply['title'], 'Mountain Dew', 'Did not save chosen topic correctly'
        )

        self.assertIn(
            TOKEN_KNOWLEDGE,
            reply['text'],
            'Knowledge token was not inserted correctly',
        )

        expected_sentence = (
            'Mountain Dew (stylized as Mtn Dew) is a carbonated soft drink brand produced and owned by PepsiCo.'
        )
        self.assertEqual(
            reply['checked_sentence'],
            expected_sentence,
            'Did not correctly choose the checked sentence',
        )
Пример #43
0
import os
import json
from parlai.core.params import ParlaiParser
from parlai.tasks.light_dialog.agents import DefaultTeacher

# Value passed to the parser below as the ``-dt`` argument.
SPLIT = 'train'
# NOTE(review): not referenced in the visible part of this script; presumably
# an output/base directory used further down -- confirm.
BASE_DIR = '.'

# Build the teacher at import time: the parser gets the teacher's own command
# line options, then parses a fixed argument list instead of sys.argv.
parser = ParlaiParser(True, True, 'Evaluate a model')
DefaultTeacher.add_cmdline_args(parser)
args = ['-dt', SPLIT]
opt = parser.parse_args(args, print_args=False)
teacher = DefaultTeacher(opt)


def first_word(s):
    """Return the first whitespace-separated token of *s*, or '' if none."""
    tokens = s.split()
    if tokens:
        return tokens[0]
    return ''


def extractTurn(t, step):
    """Copy the say-lines of ``step['text']`` into the turn dict *t*.

    Lines are scanned in order. Nothing is recorded until a line whose first
    word contains 'partner' has been seen; from that line on, a line starting
    with '_partner_say' or '_self_say' stores the rest of the line (everything
    after the first ``len(first word)`` characters) under 'partner_say' or
    'self_say' respectively. *t* is modified in place.
    """
    targets = {'_partner_say': 'partner_say', '_self_say': 'self_say'}
    seen_partner = False
    for line in step['text'].split('\n'):
        head = first_word(line)
        if 'partner' in head:
            seen_partner = True
        if seen_partner and head in targets:
            t[targets[head]] = line[len(head):]
Пример #44
0
def main():
    """Run the QA data collection MTurk task with a SQuAD teacher as context."""
    parser = ParlaiParser(False, False)
    parser.add_parlai_data_path()
    parser.add_mturk_args()
    opt = parser.parse_args()
    task_dir = os.path.dirname(os.path.abspath(__file__))
    opt['task'] = os.path.basename(task_dir)
    opt.update(task_config)

    # Initialize a SQuAD teacher agent, which we will get context from
    task_class = getattr(
        importlib.import_module('parlai.tasks.squad.agents'), 'DefaultTeacher'
    )
    task_opt = {'datatype': 'train', 'datapath': opt['datapath']}

    mturk_agent_id = 'Worker'
    mturk_manager = MTurkManager(opt=opt, mturk_agent_ids=[mturk_agent_id])
    mturk_manager.setup_server()

    def run_onboard(worker):
        # Not registered below: onboarding is disabled by passing None.
        onboard_world = QADataCollectionOnboardWorld(
            opt=opt, mturk_agent=worker
        )
        while not onboard_world.episode_done():
            onboard_world.parley()
        onboard_world.shutdown()

    mturk_manager.set_onboard_function(onboard_function=None)

    try:
        mturk_manager.start_new_run()
        mturk_manager.create_hits()

        mturk_manager.ready_to_accept_workers()

        def check_workers_eligibility(workers):
            # Every worker passed in is considered eligible.
            return workers

        eligibility_function = {
            'func': check_workers_eligibility,
            'multiple': True,
        }

        def assign_worker_roles(worker):
            worker[0].id = mturk_agent_id

        # run_conversation is deliberately published as a module global.
        global run_conversation

        def run_conversation(mturk_manager, opt, workers):
            teacher = task_class(task_opt)
            conversation = QADataCollectionWorld(
                opt=opt, task=teacher, mturk_agent=workers[0]
            )
            while not conversation.episode_done():
                conversation.parley()
            conversation.shutdown()
            conversation.review_work()

        mturk_manager.start_task(
            eligibility_function=eligibility_function,
            assign_role_function=assign_worker_roles,
            task_function=run_conversation,
        )
    finally:
        # Clean up whether the run finished or raised.
        mturk_manager.expire_all_unassigned_hits()
        mturk_manager.shutdown()
Пример #45
0
def main():
    """
    Run a dialog between one local human agent and two MTurk agents.

    Each MTurk agent first goes through the onboarding world, where it
    provides information about itself, and is then placed in the
    conversation. The conversation can be ended by sending a message ending
    with `[DONE]` from human_1.
    """
    parser = ParlaiParser(False, False)
    parser.add_parlai_data_path()
    parser.add_mturk_args()
    opt = parser.parse_args()
    task_dir = os.path.dirname(os.path.abspath(__file__))
    opt['task'] = os.path.basename(task_dir)
    opt.update(task_config)

    human_agent_1_id = 'human_1'
    mturk_agent_ids = ['mturk_agent_1', 'mturk_agent_2']
    mturk_manager = MTurkManager(opt=opt, mturk_agent_ids=mturk_agent_ids)
    mturk_manager.setup_server()

    try:
        mturk_manager.start_new_run()
        mturk_manager.create_hits()

        def run_onboard(worker):
            onboard_world = MTurkMultiAgentDialogOnboardWorld(
                opt=opt, mturk_agent=worker
            )
            while not onboard_world.episode_done():
                onboard_world.parley()
            onboard_world.shutdown()

        # You can set onboard_function to None to skip onboarding
        mturk_manager.set_onboard_function(onboard_function=run_onboard)
        mturk_manager.ready_to_accept_workers()

        def check_worker_eligibility(worker):
            return True

        def assign_worker_roles(workers):
            # Agent ids are handed out round-robin over the configured ids.
            id_count = len(mturk_agent_ids)
            for position, worker in enumerate(workers):
                worker.id = mturk_agent_ids[position % id_count]

        def run_conversation(mturk_manager, opt, workers):
            # The first two workers are the MTurk side of the dialog
            first_turker = workers[0]
            second_turker = workers[1]

            # The local human joins under its fixed id
            local_human = LocalHumanAgent(opt=None)
            local_human.id = human_agent_1_id

            dialog = MTurkMultiAgentDialogWorld(
                opt=opt, agents=[local_human, first_turker, second_turker]
            )
            while not dialog.episode_done():
                dialog.parley()
            dialog.shutdown()

        mturk_manager.start_task(
            eligibility_function=check_worker_eligibility,
            assign_role_function=assign_worker_roles,
            task_function=run_conversation,
        )
    finally:
        # Clean up whether the run finished or raised.
        mturk_manager.expire_all_unassigned_hits()
        mturk_manager.shutdown()
Пример #46
0
def main():
    """Run an MTurk task in which a single worker goes through a
    RephrasePersonaWorld, with personas supplied by a PersonasGenerator.
    """
    parser = ParlaiParser(False, False)
    parser.add_parlai_data_path()
    parser.add_mturk_args()
    parser.add_argument(
        '-min_t', '--min_turns',
        default=5, type=int, help='minimum number of turns',
    )
    parser.add_argument(
        '-mt', '--max_turns',
        default=10, type=int, help='maximal number of chat turns',
    )
    parser.add_argument(
        '-mx_rsp_time', '--max_resp_time',
        default=150, type=int,
        help='time limit for entering a dialog message',
    )
    # NOTE(review): this help text is identical to --max_resp_time's and looks
    # copy-pasted; confirm what ag_shutdown_time actually controls.
    parser.add_argument(
        '--ag_shutdown_time',
        default=120, type=int,
        help='time limit for entering a dialog message',
    )
    parser.add_argument(
        '--persona-type',
        default='both', type=str, choices=['both', 'self', 'other'],
        help='Which personas to load from personachat',
    )
    opt = parser.parse_args()

    directory_path = os.path.dirname(os.path.abspath(__file__))
    task_name = os.path.basename(directory_path)
    opt['task'] = task_name

    opt['extract_personas_path'] = os.path.join(opt['datapath'], task_name)
    opt.update(task_config)

    mturk_agent_ids = ['PERSON_1']
    mturk_manager = MTurkManager(opt=opt, mturk_agent_ids=mturk_agent_ids)
    mturk_manager.setup_server(task_directory_path=directory_path)

    # The generator is handed to the worlds through the options.
    opt['personas_generator'] = PersonasGenerator(opt)

    try:
        mturk_manager.start_new_run()
        mturk_manager.create_hits()

        if not opt['is_sandbox']:
            # ADD BLOCKED WORKERS HERE
            blocked_worker_list = []
            for blocked in blocked_worker_list:
                mturk_manager.block_worker(
                    blocked,
                    'We found that you have unexpected behaviors in our '
                    'previous HITs. For more questions please email us.',
                )

        def run_onboard(worker):
            # Onboarding is a no-op for this task.
            pass

        mturk_manager.set_onboard_function(onboard_function=run_onboard)
        mturk_manager.ready_to_accept_workers()

        def check_worker_eligibility(worker):
            return True

        def assign_worker_roles(workers):
            id_count = len(mturk_agent_ids)
            for position, worker in enumerate(workers):
                worker.id = mturk_agent_ids[position % id_count]

        def run_conversation(mturk_manager, opt, workers):
            rephrase_world = RephrasePersonaWorld(opt, workers[0])
            while not rephrase_world.episode_done():
                rephrase_world.parley()
            rephrase_world.save_data()
            rephrase_world.shutdown()
            rephrase_world.review_work()

        mturk_manager.start_task(
            eligibility_function=check_worker_eligibility,
            assign_role_function=assign_worker_roles,
            task_function=run_conversation,
        )
    finally:
        # Clean up whether the run finished or raised.
        mturk_manager.expire_all_unassigned_hits()
        mturk_manager.shutdown()
Пример #47
0
    print('---Extracting and saving personas---')
    teacher_name = 'personachat:{}'.format(opt.get('persona_type'))
    teacher_name += 'Revised' if opt.get('revised') else 'Original'
    opt['task'] = teacher_name
    if 'personas_path' not in opt:
        personas_path = '{}/data/wizard_of_perZOna/personas-{}/'.format(
            os.getcwd(), opt['task']
        )
        opt['personas_path'] = personas_path
    opt['datatype'] = 'train:ordered:stream'
    opt['numthreads'] = 1
    opt['batchsize'] = 1
    personas = extract_and_save(opt)
    return personas


# Script entry point: parse the persona options from the command line and pass
# the resulting options to main().
if __name__ == '__main__':
    parser = ParlaiParser()
    # Restricted to the three listed choices; defaults to 'both'.
    parser.add_argument(
        '--persona-type',
        default='both',
        type=str,
        choices=['both', 'self', 'other'],
        help='Which personas to load from personachat',
    )
    parser.add_argument(
        '--revised', default=False, type='bool', help='Whether to use revised personas'
    )
    opt = parser.parse_args()
    personas = main(opt)
Пример #48
0
def main():
    """
    Run a dialog between one local human agent and two MTurk agents.

    Every MTurk agent passes through the onboarding world, where it provides
    information about itself, before joining the conversation. Sending a
    message ending with `[DONE]` from human_1 ends the conversation.
    """
    parser = ParlaiParser(False, False)
    parser.add_parlai_data_path()
    parser.add_mturk_args()
    opt = parser.parse_args()
    this_dir = os.path.dirname(os.path.abspath(__file__))
    opt['task'] = os.path.basename(this_dir)
    opt.update(task_config)

    human_agent_1_id = 'human_1'
    mturk_agent_ids = ['mturk_agent_1', 'mturk_agent_2']
    mturk_manager = MTurkManager(opt=opt, mturk_agent_ids=mturk_agent_ids)
    mturk_manager.setup_server()

    try:
        mturk_manager.start_new_run()
        mturk_manager.create_hits()

        def run_onboard(worker):
            onboarding = MTurkMultiAgentDialogOnboardWorld(
                opt=opt, mturk_agent=worker
            )
            while True:
                if onboarding.episode_done():
                    break
                onboarding.parley()
            onboarding.shutdown()

        # You can set onboard_function to None to skip onboarding
        mturk_manager.set_onboard_function(onboard_function=run_onboard)
        mturk_manager.ready_to_accept_workers()

        def check_worker_eligibility(worker):
            return True

        # Eligibility is checked one worker at a time ('multiple' is False).
        eligibility_function = dict(
            func=check_worker_eligibility, multiple=False
        )

        def assign_worker_roles(workers):
            total_ids = len(mturk_agent_ids)
            for slot, worker in enumerate(workers):
                worker.id = mturk_agent_ids[slot % total_ids]

        def run_conversation(mturk_manager, opt, workers):
            # MTurk side of the dialog
            turker_a = workers[0]
            turker_b = workers[1]

            # Local human side of the dialog
            human = LocalHumanAgent(opt=None)
            human.id = human_agent_1_id

            chat = MTurkMultiAgentDialogWorld(
                opt=opt, agents=[human, turker_a, turker_b]
            )
            while True:
                if chat.episode_done():
                    break
                chat.parley()
            chat.shutdown()

        mturk_manager.start_task(
            eligibility_function=eligibility_function,
            assign_role_function=assign_worker_roles,
            task_function=run_conversation,
        )
    finally:
        # Clean up whether the run finished or raised.
        mturk_manager.expire_all_unassigned_hits()
        mturk_manager.shutdown()
Пример #49
0
    argparser.add_argument('--num_machines', type=int, default=1)
    argparser.add_argument('--job_timeout', type=float, default=3600*4)

    argparser.add_argument('--split', action='store_true', default=False)
    argparser.add_argument('--train', action='store_true', default=False)
    argparser.add_argument('--eval', action='store_true', default=False)
    argparser.add_argument('--seq2seq', action='store_true', default=False)
    argparser.add_argument('--constrain', action='store_true', default=False)
    argparser.add_argument('--rounds_breakdown', action='store_true', default=False)
    argparser.add_argument('--data_breakdown', action='store_true', default=False)
    argparser.add_argument('--ablation', action='store_true', default=False)

    argparser.add_parlai_data_path()
    argparser.add_mturk_args()
    opt = argparser.parse_args()

    # ============ below copied from projects/graph_world2/train.py ============
    opt['bidir'] = True
    opt['action_type_emb_dim'] = 5
    opt['counter_max'] = 3
    opt['counter_emb_dim'] = 5
    # ============ above copied from projects/graph_world2/train.py ============

    if opt['split']:
        overall_split(opt)
        quit()
    if opt['train']:
        overall_run(opt, seq2seq=opt['seq2seq'])
        quit()
    if opt['eval']:
Пример #50
0
def main():
    """Train an agent on a task with periodic logging and validation.

    The loop parleys until one of these happens: the requested number of
    epochs is reached, the maximum training time elapses, validation stops
    improving for ``validation_patience`` rounds, or the validation accuracy
    reaches 1. Afterwards the agent is saved (or re-created from the options
    if a best model was already saved) and evaluated on the 'valid' and
    'test' data.
    """
    # Get command line arguments
    parser = ParlaiParser(True, True)
    train = parser.add_argument_group('Training Loop Arguments')
    train.add_argument('-et', '--evaltask',
                       help=('task to use for valid/test (defaults to the '
                             'one used for training if not set)'))
    train.add_argument('-d', '--display-examples',
                       type='bool', default=False)
    train.add_argument('-e', '--num-epochs', type=float, default=-1)
    train.add_argument('-ttim', '--max-train-time',
                       type=float, default=-1)
    train.add_argument('-ltim', '--log-every-n-secs',
                       type=float, default=2)
    train.add_argument('-vtim', '--validation-every-n-secs',
                       type=float, default=-1)
    train.add_argument('-vme', '--validation-max-exs',
                       type=int, default=-1,
                       help='max examples to use during validation (default '
                            '-1 uses all)')
    train.add_argument('-vp', '--validation-patience',
                       type=int, default=5,
                       help=('number of iterations of validation where result'
                             ' does not improve before we stop training'))
    train.add_argument('-vmt', '--validation-metric', default='accuracy',
                       help='key into report table for selecting best '
                            'validation')
    train.add_argument('-dbf', '--dict-build-first',
                       type='bool', default=True,
                       help='build dictionary first before training agent')
    opt = parser.parse_args()
    # Possibly build a dictionary (not all models do this).
    if opt['dict_build_first'] and 'dict_file' in opt:
        if opt['dict_file'] is None and opt.get('model_file'):
            opt['dict_file'] = opt['model_file'] + '.dict'
        print("[ building dictionary first... ]")
        build_dict.build_dict(opt)
    # Create model and assign it to the specified task
    agent = create_agent(opt)
    world = create_task(opt, agent)

    train_time = Timer()
    validate_time = Timer()
    log_time = Timer()
    print('[ training... ]')
    parleys = 0
    total_exs = 0
    # Only meaningful when num_epochs > 0; every use below is guarded by that.
    max_exs = opt['num_epochs'] * len(world)
    max_parleys = math.ceil(max_exs / opt['batchsize'])
    best_valid = 0
    impatience = 0
    saved = False
    valid_world = None
    while True:
        world.parley()
        parleys += 1

        if opt['num_epochs'] > 0 and parleys >= max_parleys:
            print('[ num_epochs completed: {} ]'.format(opt['num_epochs']))
            break
        if opt['max_train_time'] > 0 and train_time.time() > opt['max_train_time']:
            print('[ max_train_time elapsed: {} ]'.format(train_time.time()))
            break
        if opt['log_every_n_secs'] > 0 and log_time.time() > opt['log_every_n_secs']:
            if opt['display_examples']:
                print(world.display() + '\n~~')

            logs = []
            # time elapsed
            logs.append('time:{}s'.format(math.floor(train_time.time())))
            logs.append('parleys:{}'.format(parleys))

            # get report and update total examples seen so far
            if hasattr(agent, 'report'):
                train_report = agent.report()
                agent.reset_metrics()
            else:
                train_report = world.report()
                world.reset_metrics()

            if hasattr(train_report, 'get') and train_report.get('total'):
                total_exs += train_report['total']
                logs.append('total_exs:{}'.format(total_exs))

            # check if we should log amount of time remaining
            time_left = None
            # total_exs stays 0 until a report carries a 'total'; skip the
            # estimate in that case instead of dividing by zero.
            if opt['num_epochs'] > 0 and total_exs > 0:
                secs_per_ex = train_time.time() / total_exs
                time_left = (max_exs - total_exs) * secs_per_ex
            if opt['max_train_time'] > 0:
                other_time_left = opt['max_train_time'] - train_time.time()
                if time_left is not None:
                    time_left = min(time_left, other_time_left)
                else:
                    time_left = other_time_left
            if time_left is not None:
                logs.append('time_left:{}s'.format(math.floor(time_left)))

            # join log string and add full metrics report to end of log
            log = '[ {} ] {}'.format(' '.join(logs), train_report)

            print(log)
            log_time.reset()

        if (opt['validation_every_n_secs'] > 0 and
                validate_time.time() > opt['validation_every_n_secs']):
            valid_report, valid_world = run_eval(
                agent, opt, 'valid', opt['validation_max_exs'],
                valid_world=valid_world)
            if valid_report[opt['validation_metric']] > best_valid:
                best_valid = valid_report[opt['validation_metric']]
                impatience = 0
                print('[ new best {}: {} ]'.format(
                    opt['validation_metric'], best_valid))
                world.save_agents()
                saved = True
                if opt['validation_metric'] == 'accuracy' and best_valid == 1:
                    print('[ task solved! stopping. ]')
                    break
            else:
                impatience += 1
                print('[ did not beat best {}: {} impatience: {} ]'.format(
                        opt['validation_metric'], round(best_valid, 4),
                        impatience))
            validate_time.reset()
            if opt['validation_patience'] > 0 and impatience >= opt['validation_patience']:
                print('[ ran out of patience! stopping training. ]')
                break
    world.shutdown()
    if not saved:
        # Nothing was saved during validation, so save the final agent.
        world.save_agents()
    else:
        # reload best validation model
        agent = create_agent(opt)

    run_eval(agent, opt, 'valid', write_log=True)
    run_eval(agent, opt, 'test', write_log=True)
Пример #51
0
def main():
    """Run an MTurk task in which a worker evaluates an IrBaselineAgent."""
    parser = ParlaiParser(False, False)
    parser.add_parlai_data_path()
    parser.add_mturk_args()

    # The dialog model we want to evaluate
    from parlai.agents.ir_baseline.ir_baseline import IrBaselineAgent
    IrBaselineAgent.add_cmdline_args(parser)
    opt = parser.parse_args()
    task_dir = os.path.dirname(os.path.abspath(__file__))
    opt['task'] = os.path.basename(task_dir)
    opt.update(task_config)

    # The task that we will evaluate the dialog model on
    task_opt = {
        'datatype': 'test',
        'datapath': opt['datapath'],
        'task': '#MovieDD-Reddit',
    }

    mturk_agent_id = 'Worker'
    mturk_manager = MTurkManager(opt=opt, mturk_agent_ids=[mturk_agent_id])
    mturk_manager.setup_server()

    try:
        mturk_manager.start_new_run()
        mturk_manager.create_hits()

        def run_onboard(worker):
            onboard_world = ModelEvaluatorOnboardWorld(
                opt=opt, mturk_agent=worker
            )
            while not onboard_world.episode_done():
                onboard_world.parley()
            onboard_world.shutdown()

        mturk_manager.set_onboard_function(onboard_function=run_onboard)
        mturk_manager.ready_to_accept_workers()

        def check_worker_eligibility(worker):
            return True

        def assign_worker_roles(worker):
            worker[0].id = mturk_agent_id

        # run_conversation is deliberately published as a module global.
        global run_conversation

        def run_conversation(mturk_manager, opt, workers):
            evaluator = workers[0]
            model = IrBaselineAgent(opt=opt)
            eval_world = ModelEvaluatorWorld(
                opt=opt,
                model_agent=model,
                task_opt=task_opt,
                mturk_agent=evaluator,
            )
            while not eval_world.episode_done():
                eval_world.parley()
            eval_world.shutdown()
            eval_world.review_work()

        mturk_manager.start_task(
            eligibility_function=check_worker_eligibility,
            assign_role_function=assign_worker_roles,
            task_function=run_conversation,
        )
    finally:
        # Clean up whether the run finished or raised.
        mturk_manager.expire_all_unassigned_hits()
        mturk_manager.shutdown()
Пример #52
0
def main():
    """Run the MTurk QA data collection task.

    Parses the ParlAI data-path and MTurk command line options, looks up the
    SQuAD ``DefaultTeacher`` class (used to get context from), and starts an
    ``MTurkManager`` with a single worker role.  Each call to
    ``run_conversation`` builds a ``QADataCollectionWorld`` and steps it until
    the episode is done.  Unassigned HITs are expired and the manager is shut
    down on exit, whether or not the run raised.
    """
    parser = ParlaiParser(False, False)
    parser.add_parlai_data_path()
    parser.add_mturk_args()
    opt = parser.parse_args()
    # The task name is the name of the directory that contains this file.
    task_dir = os.path.dirname(os.path.abspath(__file__))
    opt['task'] = os.path.basename(task_dir)
    opt.update(task_config)

    # Initialize a SQuAD teacher agent, which we will get context from
    squad_agents = importlib.import_module('parlai.tasks.squad.agents')
    task_class = squad_agents.DefaultTeacher
    task_opt = {
        'datatype': 'train',
        'datapath': opt['datapath'],
    }

    mturk_agent_id = 'Worker'
    mturk_manager = MTurkManager(opt=opt, mturk_agent_ids=[mturk_agent_id])
    mturk_manager.setup_server()

    def run_onboard(worker):
        # Onboarding loop.  It is not registered below: None is passed to
        # set_onboard_function instead.
        onboard_world = QADataCollectionOnboardWorld(
            opt=opt,
            mturk_agent=worker,
        )
        while not onboard_world.episode_done():
            onboard_world.parley()
        onboard_world.shutdown()

    mturk_manager.set_onboard_function(onboard_function=None)

    def check_workers_eligibility(workers):
        # Every worker passed in is returned unchanged.
        return workers

    def assign_worker_roles(worker):
        # `worker` is indexed like a sequence; only its first entry gets a role.
        worker[0].id = mturk_agent_id

    try:
        mturk_manager.start_new_run()
        mturk_manager.create_hits()
        mturk_manager.ready_to_accept_workers()

        # The task function is bound as a module-level name.
        global run_conversation

        def run_conversation(mturk_manager, opt, workers):
            task = task_class(task_opt)
            conversation = QADataCollectionWorld(
                opt=opt,
                task=task,
                mturk_agent=workers[0],
            )
            while not conversation.episode_done():
                conversation.parley()
            conversation.shutdown()
            conversation.review_work()

        mturk_manager.start_task(
            eligibility_function={
                'func': check_workers_eligibility,
                'multiple': True,
            },
            assign_role_function=assign_worker_roles,
            task_function=run_conversation,
        )
    finally:
        # Always clean up, including when an exception propagates.
        mturk_manager.expire_all_unassigned_hits()
        mturk_manager.shutdown()
Пример #53
0
def main():
    """Train and validate a remote agent driven by a luajit script.

    Steps:

    1. Build the argument parser.  When ``--remote-cmd`` / ``--remote-args``
       are missing from ``sys.argv``, defaults pointing into ``PARLAI_HOME``
       are appended to ``sys.argv`` before parsing.
    2. Create a ``DictionaryAgent``; if no ``dict_file`` option is set, build
       the dictionary from ordered training data (at most ``--dict-max-exs``
       examples when that value is positive) and save it to
       ``/tmp/dict.txt``.
    3. Alternate training and validation for up to ``--num-its`` iterations,
       stopping early once validation accuracy exceeds 0.95.
    4. Display three example dialogs from a fresh validation world.

    Raises:
        KeyError: if the ``PARLAI_HOME`` environment variable is not set.
        RuntimeError: if ``--remote-cmd`` is not given and no ``luajit``
            executable can be found on ``PATH``.
    """
    # Local import: only needed for the luajit lookup below.
    import shutil

    # Get command line arguments
    argparser = ParlaiParser()
    DictionaryAgent.add_cmdline_args(argparser)
    ParsedRemoteAgent.add_cmdline_args(argparser)
    argparser.add_argument('--num-examples', default=1000, type=int)
    argparser.add_argument('--num-its', default=100, type=int)
    argparser.add_argument('--dict-max-exs', default=10000, type=int)
    parlai_home = os.environ['PARLAI_HOME']
    if '--remote-cmd' not in sys.argv:
        # shutil.which searches PATH directly: no shell is spawned, nothing
        # is printed to stdout, and no external `which` binary is required.
        if shutil.which('luajit') is None:
            raise RuntimeError('Could not detect torch luajit installed: ' +
                               'please install torch from http://torch.ch ' +
                               'or manually set --remote-cmd for this example.')
        sys.argv.append('--remote-cmd')
        sys.argv.append('luajit {}/parlai/agents/'.format(parlai_home) +
                        'memnn_luatorch_cpu/memnn_zmq_parsed.lua')
    if '--remote-args' not in sys.argv:
        sys.argv.append('--remote-args')
        sys.argv.append('{}/examples/'.format(parlai_home) +
                        'memnn_luatorch_cpu/params_default.lua')

    opt = argparser.parse_args()

    # set up dictionary
    print('Setting up dictionary.')
    dictionary = DictionaryAgent(opt)
    if not opt.get('dict_file'):
        # build dictionary since we didn't load it
        ordered_opt = copy.deepcopy(opt)
        ordered_opt['datatype'] = 'train:ordered'
        ordered_opt['numthreads'] = 1
        world_dict = create_task(ordered_opt, dictionary)

        print('Dictionary building on training data.')
        max_exs = opt['dict_max_exs']
        cnt = 0
        # pass examples to dictionary
        for _ in world_dict:
            cnt += 1
            # a non-positive limit means "no limit"
            if max_exs > 0 and cnt > max_exs:
                print('Processed {} exs, moving on.'.format(max_exs))
                # don't wait too long...
                break

            world_dict.parley()

        # we need to save the dictionary to load it in memnn (sort it by freq)
        dictionary.sort()
        dictionary.save('/tmp/dict.txt', sort=True)

    print('Dictionary ready, moving on to training.')

    # The same agent is shared by the training and validation worlds.
    opt['datatype'] = 'train'
    agent = ParsedRemoteAgent(opt, {'dictionary_shared': dictionary.share()})
    world_train = create_task(opt, agent)
    opt['datatype'] = 'valid'
    world_valid = create_task(opt, agent)

    start = time.time()
    with world_train:
        for _ in range(opt['num_its']):
            print('[ training ]')
            for _ in range(opt['num_examples'] * opt.get('numthreads', 1)):
                world_train.parley()
            world_train.synchronize()

            print('[ validating ]')
            world_valid.reset()
            for _ in world_valid:  # check valid accuracy
                world_valid.parley()

            print('[ validation summary. ]')
            report_valid = world_valid.report()
            print(report_valid)
            if report_valid['accuracy'] > 0.95:
                # good enough: stop training early
                break

        # show some example dialogs after training:
        world_valid = create_task(opt, agent)
        for _k in range(3):
            world_valid.parley()
            print(world_valid.display())

    print('finished in {} s'.format(round(time.time() - start, 2)))
Пример #54
0
def main():
    """Run an MTurk dialog between two MTurk workers and one local human.

    Parses the ParlAI data-path and MTurk command line options, starts an
    ``MTurkManager`` with two worker roles, registers an onboarding world and
    then runs ``MTurkMultiAgentDialogWorld`` conversations in which a
    ``LocalHumanAgent`` talks with the two workers.  Unassigned HITs are
    expired and the manager is shut down on exit, whether or not the run
    raised.
    """
    argparser = ParlaiParser(False, False)
    argparser.add_parlai_data_path()
    argparser.add_mturk_args()
    opt = argparser.parse_args()
    # The task name is the name of the directory that contains this file.
    opt['task'] = os.path.basename(os.path.dirname(os.path.abspath(__file__)))
    opt.update(task_config)

    mturk_agent_1_id = 'mturk_agent_1'
    mturk_agent_2_id = 'mturk_agent_2'
    human_agent_1_id = 'human_1'
    mturk_agent_ids = [mturk_agent_1_id, mturk_agent_2_id]
    mturk_manager = MTurkManager(
        opt=opt,
        mturk_agent_ids=mturk_agent_ids
    )
    mturk_manager.setup_server()

    # No `except` clause: the previous bare `except: raise` only re-raised,
    # so `try`/`finally` gives the same propagation plus guaranteed cleanup.
    try:
        mturk_manager.start_new_run()
        mturk_manager.create_hits()

        def run_onboard(worker):
            world = MTurkMultiAgentDialogOnboardWorld(
                opt=opt,
                mturk_agent=worker
            )
            while not world.episode_done():
                world.parley()
            world.shutdown()

        # You can set onboard_function to None to skip onboarding
        mturk_manager.set_onboard_function(onboard_function=run_onboard)
        mturk_manager.ready_to_accept_workers()

        def check_worker_eligibility(worker):
            # Every worker is accepted.
            return True

        def assign_worker_roles(workers):
            # Roles are handed out round-robin over the configured ids.
            for index, worker in enumerate(workers):
                worker.id = mturk_agent_ids[index % len(mturk_agent_ids)]

        def run_conversation(mturk_manager, opt, workers):
            # Create mturk agents
            mturk_agent_1 = workers[0]
            mturk_agent_2 = workers[1]

            # Create the local human agents
            human_agent_1 = LocalHumanAgent(opt=None)
            human_agent_1.id = human_agent_1_id

            world = MTurkMultiAgentDialogWorld(
                opt=opt,
                agents=[human_agent_1, mturk_agent_1, mturk_agent_2]
            )

            while not world.episode_done():
                world.parley()

            world.shutdown()

        mturk_manager.start_task(
            eligibility_function=check_worker_eligibility,
            assign_role_function=assign_worker_roles,
            task_function=run_conversation
        )
    finally:
        mturk_manager.expire_all_unassigned_hits()
        mturk_manager.shutdown()
Пример #55
0
def main():
    """This task consists of an MTurk agent evaluating a chit-chat model. They
    are asked to chat to the model adopting a specific persona. After their
    conversation, they are asked to evaluate their partner on several metrics.

    The model is created once with ``create_agent`` and its shared parameters
    are passed to every ``Convai2EvalWorld``.  Unassigned HITs are expired and
    the manager is shut down on exit, whether or not the run raised.
    """
    argparser = ParlaiParser(False, add_model_args=True)
    argparser.add_parlai_data_path()
    argparser.add_mturk_args()
    argparser.add_argument('-mt', '--max-turns', default=10, type=int,
                           help='maximal number of chat turns')
    argparser.add_argument('--max-resp-time', default=240,
                           type=int,
                           help='time limit for entering a dialog message')
    # Note the trailing space: adjacent literals are concatenated verbatim.
    argparser.add_argument('--max-persona-time', type=int,
                           default=300, help='time limit for turker '
                           'entering the persona')
    # Passed to the world as `agent_timeout_shutdown`.
    argparser.add_argument('--ag-shutdown-time', default=120,
                           type=int,
                           help='agent shutdown timeout')
    argparser.add_argument('--persona-type', default='both', type=str,
                           choices=['both', 'self', 'other'],
                           help='Which personas to load from personachat')
    argparser.add_argument('--revised', default=False, type='bool',
                           help='Whether to use revised personas')
    argparser.add_argument('-rt', '--range-turn', default='5,6',
                           help='sample range of number of turns')
    argparser.add_argument('--auto-approve-delay', type=int,
                           default=3600 * 24 * 1,
                           help='how long to wait for auto approval')
    argparser.add_argument('--only-masters', type='bool', default=False,
                           help='Set to True to use only master turks for this' +
                                ' test eval, default is %(default)s')

    # ADD MODEL ARGS HERE, UNCOMMENT TO USE KVMEMNN MODEL AS AN EXAMPLE
    # argparser.set_defaults(
    #     model='projects.personachat.kvmemnn.kvmemnn:Kvmemnn',
    #     model_file='models:convai2/kvmemnn/model',
    # )

    opt = argparser.parse_args()

    # add additional model args
    opt['override'] = {
        'no_cuda': True,
        'interactive_mode': True,
        'tensorboard_log': False
    }

    bot = create_agent(opt)
    shared_bot_params = bot.share()
    print(
        '=== Actual bot opt === :\n {}'.format(
            '\n'.join(["[{}] : {}".format(k, v) for k, v in bot.opt.items()])
        )
    )
    folder_name = (
        'master_{}_YOURCOMMENT__'.format(opt['only_masters']) +
        '__'.join(['{}_{}'.format(k, v) for k, v in opt['override'].items()])
    )

    #  this is mturk task, not convai2 task from ParlAI
    opt['task'] = 'convai2:self'
    if 'data_path' not in opt:
        opt['data_path'] = os.path.join(os.getcwd(), 'data', folder_name)
    opt.update(task_config)

    mturk_agent_ids = ['PERSON_1']

    mturk_manager = MTurkManager(
        opt=opt,
        mturk_agent_ids=mturk_agent_ids
    )

    persona_generator = PersonasGenerator(opt)
    mturk_manager.setup_server()

    try:
        mturk_manager.start_new_run()
        agent_qualifications = []
        if opt['only_masters']:
            # The masters qualification differs between sandbox and live.
            if opt['is_sandbox']:
                agent_qualifications.append(MASTER_QUALIF_SDBOX)
            else:
                agent_qualifications.append(MASTER_QUALIF)
        mturk_manager.create_hits(qualifications=agent_qualifications)

        if not opt['is_sandbox']:
            # ADD SOFT-BLOCKED WORKERS HERE
            # NOTE: blocking qual *must be* specified
            blocked_worker_list = []
            for w in blocked_worker_list:
                print('Soft Blocking {}\n'.format(w))
                mturk_manager.soft_block_worker(w)
                time.sleep(0.1)  # do the sleep to prevent amazon query drop

        def run_onboard(worker):
            # The onboarding world reads the generator from the worker.
            worker.persona_generator = persona_generator
            world = PersonaProfileWorld(opt, worker)
            world.parley()
            world.shutdown()
        mturk_manager.set_onboard_function(onboard_function=run_onboard)
        mturk_manager.ready_to_accept_workers()

        def check_worker_eligibility(worker):
            # Every worker is accepted.
            return True

        def assign_worker_roles(workers):
            # Roles are handed out round-robin over the configured ids.
            for index, worker in enumerate(workers):
                worker.id = mturk_agent_ids[index % len(mturk_agent_ids)]

        def run_conversation(mturk_manager, opt, workers):
            # Only the first worker takes part in the conversation.
            mturk_agent = workers[0]
            conv_idx = mturk_manager.conversation_index
            world = Convai2EvalWorld(
                opt=opt,
                agents=[mturk_agent],
                # '5,6' -> [5, 6]
                range_turn=[int(s) for s in opt['range_turn'].split(',')],
                max_turn=opt['max_turns'],
                max_resp_time=opt['max_resp_time'],
                model_agent_opt=shared_bot_params,
                world_tag='conversation t_{}'.format(conv_idx),
                agent_timeout_shutdown=opt['ag_shutdown_time'],
            )
            world.reset_random()
            while not world.episode_done():
                world.parley()
            world.save_data()

            world.shutdown()

        mturk_manager.start_task(
            eligibility_function=check_worker_eligibility,
            assign_role_function=assign_worker_roles,
            task_function=run_conversation
        )

    except BaseException:
        raise
    finally:
        mturk_manager.expire_all_unassigned_hits()
        mturk_manager.shutdown()
Пример #56
0
def main():
    """Run the MTurk qualification-flow task.

    A qualification with a randomly suffixed name is found or created, and
    HITs are created requiring that the worker does NOT hold it.  Each
    conversation runs a ``QualificationFlowSoloWorld``; the world is told
    whether this is the worker's first conversation in this run.  On exit the
    qualification is deleted, unassigned HITs are expired and the manager is
    shut down, whether or not the run raised.
    """
    # Worker ids that have already finished a conversation in this run.
    completed_workers = []

    parser = ParlaiParser(False, False)
    parser.add_parlai_data_path()
    parser.add_mturk_args()
    opt = parser.parse_args()
    # The task name is the name of the directory that contains this file.
    task_dir = os.path.dirname(os.path.abspath(__file__))
    opt['task'] = os.path.basename(task_dir)
    opt.update(task_config)

    mturk_agent_id = 'Worker'
    mturk_manager = MTurkManager(opt=opt, mturk_agent_ids=[mturk_agent_id])
    mturk_manager.setup_server()

    # Two random five-digit numbers are embedded in the qualification name.
    first_suffix = random.randint(10000, 99999)
    second_suffix = random.randint(10000, 99999)
    qual_name = 'ParlAIExcludeQual{}t{}'.format(first_suffix, second_suffix)
    qual_desc = (
        'Qualification for a worker not correctly completing the '
        'first iteration of a task. Used to filter to different task pools.'
    )
    qualification_id = mturk_utils.find_or_create_qualification(
        qual_name, qual_desc)
    print('Created qualification: ', qualification_id)

    def run_onboard(worker):
        onboard_world = QualificationFlowOnboardWorld(opt, worker)
        while not onboard_world.episode_done():
            onboard_world.parley()
        onboard_world.shutdown()

    mturk_manager.set_onboard_function(onboard_function=run_onboard)

    def check_worker_eligibility(worker):
        # Every worker is accepted.
        return True

    def assign_worker_roles(worker):
        # `worker` is indexed like a sequence; only its first entry gets a role.
        worker[0].id = mturk_agent_id

    try:
        mturk_manager.start_new_run()
        # Only workers without the qualification may preview/accept the HIT.
        exclusion_requirement = {
            'QualificationTypeId': qualification_id,
            'Comparator': 'DoesNotExist',
            'RequiredToPreview': True,
        }
        mturk_manager.create_hits(qualifications=[exclusion_requirement])
        mturk_manager.ready_to_accept_workers()

        # The task function is bound as a module-level name.
        global run_conversation

        def run_conversation(mturk_manager, opt, workers):
            mturk_agent = workers[0]
            is_first_time = mturk_agent.worker_id not in completed_workers
            conversation = QualificationFlowSoloWorld(
                opt=opt,
                mturk_agent=mturk_agent,
                qualification_id=qualification_id,
                firstTime=is_first_time,
            )
            while not conversation.episode_done():
                conversation.parley()
            completed_workers.append(mturk_agent.worker_id)
            conversation.shutdown()
            conversation.review_work()

        mturk_manager.start_task(
            eligibility_function=check_worker_eligibility,
            assign_role_function=assign_worker_roles,
            task_function=run_conversation,
        )
    finally:
        # Always clean up, including when an exception propagates.
        mturk_utils.delete_qualification(qualification_id)
        mturk_manager.expire_all_unassigned_hits()
        mturk_manager.shutdown()
Пример #57
0
def main():
    """This task consists of one agent, model or MTurk worker, talking to an
    MTurk worker to negotiate a deal.

    With ``--two_mturk_agents`` both participants are MTurk workers.
    Otherwise a local agent joins the single worker: a model built with
    ``create_agent`` when ``'model'`` is present in the options, else a
    ``LocalHumanAgent``.  Unassigned HITs are expired and the manager is shut
    down on exit, whether or not the run raised.
    """
    parser = ParlaiParser(False, False)
    parser.add_parlai_data_path()
    parser.add_mturk_args()
    parser.add_argument(
        '--two_mturk_agents',
        dest='two_mturk_agents',
        action='store_true',
        help='data collection mode with '
             'converations between two MTurk agents',
    )

    opt = parser.parse_args()
    opt['task'] = 'dealnodeal'
    opt['datatype'] = 'valid'
    opt.update(task_config)

    local_agent_1_id = 'local_1'
    if opt['two_mturk_agents']:
        mturk_agent_ids = ['mturk_agent_1', 'mturk_agent_2']
    else:
        mturk_agent_ids = ['mturk_agent_1']

    mturk_manager = MTurkManager(opt=opt, mturk_agent_ids=mturk_agent_ids)
    mturk_manager.setup_server()

    def check_worker_eligibility(worker):
        # Every worker is accepted.
        return True

    def assign_worker_roles(workers):
        # Roles are handed out round-robin over the configured ids.
        role_count = len(mturk_agent_ids)
        for position, participant in enumerate(workers):
            participant.id = mturk_agent_ids[position % role_count]

    def run_conversation(mturk_manager, opt, workers):
        # Copy so that appending the local agent leaves `workers` untouched.
        agents = workers[:]

        # Create a local agent
        if not opt['two_mturk_agents']:
            local_agent = (
                create_agent(opt) if 'model' in opt
                else LocalHumanAgent(opt=None)
            )
            local_agent.id = local_agent_1_id
            agents.append(local_agent)

        opt["batchindex"] = mturk_manager.started_conversations

        conversation = MTurkDealNoDealDialogWorld(opt=opt, agents=agents)
        while not conversation.episode_done():
            conversation.parley()
        conversation.shutdown()

    try:
        mturk_manager.start_new_run()
        mturk_manager.create_hits()

        # Passing None registers no onboarding function.
        mturk_manager.set_onboard_function(onboard_function=None)
        mturk_manager.ready_to_accept_workers()

        mturk_manager.start_task(
            eligibility_function=check_worker_eligibility,
            assign_role_function=assign_worker_roles,
            task_function=run_conversation,
        )
    finally:
        # Always clean up, including when an exception propagates.
        mturk_manager.expire_all_unassigned_hits()
        mturk_manager.shutdown()
Пример #58
0
def main():
    """Parse the command line options and pass them to ``build_dict``.

    The parser is a ``ParlaiParser`` extended with the command line
    arguments registered by ``DictionaryAgent``.
    """
    # Get command line arguments
    argparser = ParlaiParser()
    DictionaryAgent.add_cmdline_args(argparser)
    opt = argparser.parse_args()
    build_dict(opt)