Пример #1
0
    def test_knowledge_retriever(self):
        """
        Exercise the knowledge retriever on a single Mountain Dew utterance.

        Verifies three fields of the retriever's reply: the title echoes the
        chosen topic, the text contains the knowledge token, and the checked
        sentence is the expected Mountain Dew sentence.
        """
        from parlai.core.params import ParlaiParser

        # Build retriever options with the knowledge token switched on.
        retriever_parser = ParlaiParser(False, False)
        KnowledgeRetrieverAgent.add_cmdline_args(retriever_parser)
        retriever_parser.set_params(
            model='projects:wizard_of_wikipedia:knowledge_retriever',
            add_token_knowledge=True,
        )
        retriever_opt = retriever_parser.parse_args([], print_args=False)
        retriever = create_agent(retriever_opt)

        # Feed one opening utterance and collect the retriever's reply.
        observation = {
            'text': 'what do you think of mountain dew?',
            'chosen_topic': 'Mountain Dew',
            'episode_done': False,
        }
        retriever.observe(observation)
        reply = retriever.act()

        retrieved_title = reply['title']
        self.assertEqual(
            retrieved_title,
            'Mountain Dew',
            'Did not save chosen topic correctly',
        )

        retrieved_text = reply['text']
        self.assertIn(
            TOKEN_KNOWLEDGE,
            retrieved_text,
            'Knowledge token was not inserted correctly',
        )

        expected_sentence = 'Mountain Dew (stylized as Mtn Dew) is a carbonated soft drink brand produced and owned by PepsiCo.'
        chosen_sentence = reply['checked_sentence']
        self.assertEqual(
            chosen_sentence,
            expected_sentence,
            'Did not correctly choose the checked sentence',
        )
Пример #2
0
def create_retriever():
    """
    Build a Wizard of Wikipedia knowledge retriever agent.

    The options are parsed from an empty argument list, with the model pinned
    to the knowledge retriever and ``add_token_knowledge`` disabled.

    :return: whatever ``create_agent`` produces for those options.
    """
    retriever_parser = ParlaiParser(False, False)
    KnowledgeRetrieverAgent.add_cmdline_args(retriever_parser)
    fixed_params = {
        'model': 'projects:wizard_of_wikipedia:knowledge_retriever',
        'add_token_knowledge': False,
    }
    retriever_parser.set_params(**fixed_params)
    retriever_opt = retriever_parser.parse_args([])
    return create_agent(retriever_opt)
Пример #3
0
    def _set_up_knowledge_agent(self, add_token_knowledge=False):
        """
        Create the knowledge retriever agent and store it on ``self``.

        :param add_token_knowledge: forwarded to the retriever options under
            the same name.
        """
        from parlai.core.params import ParlaiParser

        retriever_parser = ParlaiParser(False, False)
        # The retriever registers its own arguments, given this object's opt.
        KnowledgeRetrieverAgent.add_cmdline_args(
            retriever_parser, partial_opt=self.opt
        )
        fixed_params = {
            'model': 'projects:wizard_of_wikipedia:knowledge_retriever',
            'add_token_knowledge': add_token_knowledge,
        }
        retriever_parser.set_params(**fixed_params)
        retriever_opt = retriever_parser.parse_args([])
        self.knowledge_agent = create_agent(retriever_opt)
Пример #4
0
 def _set_up_knowledge_agent(self,
                             add_token_knowledge: bool = False,
                             shared=None) -> None:
     """
     Instantiate the WoW knowledge retriever and keep it as ``self.knowledge_agent``.

     :param add_token_knowledge: forwarded to the retriever options under the
         same name.
     :param shared: optional mapping; when truthy, its 'knowledge_retriever'
         entry (or None if absent) is handed to the retriever constructor.
     """
     retriever_parser = ParlaiParser(False, False)
     KnowledgeRetrieverAgent.add_cmdline_args(retriever_parser)
     fixed_params = {
         'model': 'projects:wizard_of_wikipedia:knowledge_retriever',
         'add_token_knowledge': add_token_knowledge,
     }
     retriever_parser.set_params(**fixed_params)
     retriever_opt = retriever_parser.parse_args([])

     # Only pass a second constructor argument when shared state was supplied.
     if shared:
         ctor_args = (retriever_opt, shared.get('knowledge_retriever', None))
     else:
         ctor_args = (retriever_opt,)
     self.knowledge_agent = KnowledgeRetrieverAgent(*ctor_args)
Пример #5
0
class InteractiveWorld(InteractiveBaseWorld):
    """
    InteractiveWorld combined with Blended skill talk InteractiveWorld
    and WoW knowledge retrieval InteractiveWorld.

    On top of the base interactive world, this world owns a knowledge
    retriever agent. When the ``use_knowledge`` option is set, each human act
    is enriched with the retriever's output before the model observes it.
    """

    @staticmethod
    def add_cmdline_args(argparser: ParlaiParser) -> None:
        """
        Add command-line arguments specifically for this task world.

        Registers the base world's arguments first, then the two flags that
        this world reads: ``--print-checked-sentence`` and
        ``--add-token-knowledge``.
        """
        InteractiveBaseWorld.add_cmdline_args(argparser)
        parser = argparser.add_argument_group('RetNRef Interactive World Args')
        parser.add_argument(
            '--print-checked-sentence',
            type='bool',
            default=True,
            help='Print sentence that the model checks.',
        )
        parser.add_argument(
            '--add-token-knowledge',
            type='bool',
            default=False,
            help='Add knowledge token to retrieved knowledge',
        )

    def __init__(self, opt: Opt, agents: tp.List[tp.Any], shared=None) -> None:
        """
        Initialise the base world, then attach the knowledge retriever.

        :param opt: world options; ``add_token_knowledge`` and
            ``print_checked_sentence`` are read here.
        :param agents: the agents taking part; ``parley`` treats index 0 as the
            human and index 1 as the model.
        :param shared: optional shared state, forwarded to the base class and
            to the knowledge retriever set-up.
        """
        super().__init__(opt, agents, shared)
        self._set_up_knowledge_agent(opt.get('add_token_knowledge', False),
                                     shared=shared)
        # Read defensively, like add_token_knowledge above, so an opt that
        # never went through add_cmdline_args does not raise KeyError. The
        # fallback is the flag's declared default.
        self.print_checked_sentence = opt.get('print_checked_sentence', True)

    def _set_up_knowledge_agent(self,
                                add_token_knowledge: bool = False,
                                shared=None) -> None:
        """
        Set up knowledge agent for knowledge retrieval generated from WoW project.

        :param add_token_knowledge: forwarded to the retriever options under
            the same name.
        :param shared: optional mapping; when truthy, its 'knowledge_retriever'
            entry (or None if absent) is passed to the retriever constructor.
        """
        parser = ParlaiParser(False, False)
        KnowledgeRetrieverAgent.add_cmdline_args(parser)
        parser.set_params(
            model='projects:wizard_of_wikipedia:knowledge_retriever',
            add_token_knowledge=add_token_knowledge,
        )
        knowledge_opt = parser.parse_args([])
        if shared:
            self.knowledge_agent = KnowledgeRetrieverAgent(
                knowledge_opt, shared.get('knowledge_retriever', None))
        else:
            self.knowledge_agent = KnowledgeRetrieverAgent(knowledge_opt)

    @staticmethod
    def _make_context_act(text: str) -> Message:
        """
        Wrap a context string in the message shown to an agent on turn zero.
        """
        return Message({
            'id': 'context',
            'text': text,
            'episode_done': False
        })

    def _add_knowledge_to_act(self, act: Message) -> Message:
        """
        After human agent act, if use_knowledge is True, add knowledge to act.

        Knowledge agent first observes human agent's act, then acts itself.
        Key 'knowledge' receives the retriever reply's 'text', key
        'checked_sentence' its 'checked_sentence', and key 'title' its 'title'.

        :param act: the human agent's act; modified in place when knowledge is
            added.
        :return: the same ``act`` object. It is returned unchanged when
            ``use_knowledge`` is off or when retrieval raises ValueError.
        """
        if not self.opt.get('use_knowledge', False):
            return act

        try:
            self.knowledge_agent.observe(act, actor_id='apprentice')
            knowledge_act = self.knowledge_agent.act()
        except ValueError:
            # Best effort: a failed retrieval must not break the conversation.
            warn_once("Knowledge Retrieval Failed Once")
            return act

        act['knowledge'] = knowledge_act['text']
        act['checked_sentence'] = knowledge_act['checked_sentence']
        if self.print_checked_sentence:
            print('[ Using chosen sentence from Wikpedia ]: {}'.format(
                knowledge_act['checked_sentence']))
        act['title'] = knowledge_act['title']
        return act

    def parley(self) -> None:
        """
        Run one human -> model exchange.

        On the first turn both agents are shown their (non-empty) context
        before acting or observing. The human act is optionally enriched with
        retrieved knowledge, the model replies, and the reply is fed both to
        the knowledge retriever and back to the human.
        """
        # random initialize human and model persona
        if self.turn_cnt == 0:
            self.p1, self.p2 = self.get_contexts()

        if self.turn_cnt == 0 and self.p1 != '':
            # human agent observes his/her persona, which is added on to the
            # first message to the human
            self.agents[0].observe(validate(self._make_context_act(self.p1)))
        try:
            # human agent act first; copy so that knowledge keys are added to
            # the copy and not to the object the human agent returned
            act = deepcopy(self.agents[0].act())
        except StopIteration:
            self.reset()
            self.finalize_episode()
            self.turn_cnt = 0
            return
        self.acts[0] = act
        if self.turn_cnt == 0 and self.p2 != '':
            # model observes its persona, which is added on to the first
            # message to agent 1
            self.agents[1].observe(validate(self._make_context_act(self.p2)))

        # add knowledge to the model observation
        if 'text' in act:
            act = self._add_knowledge_to_act(act)

        # model observe human act and knowledge
        self.agents[1].observe(validate(act))
        # model agent act
        self.acts[1] = self.agents[1].act()

        # add the model reply to the knowledge retriever's dialogue history
        # NOTE(review): the human act is observed with actor_id='apprentice',
        # while the model reply is observed without an actor_id. Confirm the
        # retriever's default actor is the intended one for model replies.
        if 'text' in self.acts[1]:
            self.knowledge_agent.observe(validate(self.acts[1]))

        # human agent observes model act
        self.agents[0].observe(validate(self.acts[1]))
        self.update_counters()
        self.turn_cnt += 1

        if act['episode_done']:
            self.finalize_episode()
            self.turn_cnt = 0

    def share(self) -> tp.Dict[str, tp.Any]:
        """
        Share the knowledge retriever model.

        :return: the base world's shared dict, extended with a
            'knowledge_retriever' entry holding the retriever's shared state.
        """
        shared = super().share()
        shared['knowledge_retriever'] = self.knowledge_agent.share()
        return shared
Пример #6
0
def main():
    """
    This task consists of an MTurk agent evaluating a wizard model.

    They are assigned a topic and asked to chat.

    The function parses the command line, optionally builds the dialogue model
    and the knowledge retriever (skipped for ``--human-eval``), sets up the
    MTurk manager and runs conversations. In every case it expires unassigned
    HITs and shuts the manager down at the end.
    """
    start_time = datetime.datetime.today().strftime('%Y-%m-%d-%H-%M')
    argparser = ParlaiParser(False, add_model_args=True)
    argparser.add_parlai_data_path()
    argparser.add_mturk_args()
    argparser.add_argument('-mt',
                           '--max-turns',
                           default=10,
                           type=int,
                           help='maximal number of chat turns')
    argparser.add_argument(
        '--max-resp-time',
        default=240,
        type=int,
        help='time limit for entering a dialog message',
    )
    # Declared as a bool like the other flags: without a type, any value given
    # on the command line (including "False") is kept as a truthy string.
    argparser.add_argument(
        '--generative-setup',
        type='bool',
        default=False,
        help='mimic setup for the WoW generator task (use knowledge token)',
    )
    argparser.add_argument(
        '--max-choice-time',
        type=int,
        default=300,
        help='time limit for turker choosing the topic',
    )
    # NOTE(review): this help text is identical to --max-resp-time; the value
    # is passed to WizardEval as agent_timeout_shutdown below. Confirm wording.
    argparser.add_argument(
        '--ag-shutdown-time',
        default=120,
        type=int,
        help='time limit for entering a dialog message',
    )
    argparser.add_argument('-rt',
                           '--range-turn',
                           default='3,5',
                           help='sample range of number of turns')
    argparser.add_argument(
        '--human-eval',
        type='bool',
        default=False,
        help='human vs human eval, no models involved',
    )
    argparser.add_argument(
        '--auto-approve-delay',
        type=int,
        default=3600 * 24 * 1,
        help='how long to wait for auto approval',
    )
    argparser.add_argument(
        '--only-masters',
        type='bool',
        default=False,
        help='Set to true to use only master turks for '
        'this test eval',
    )
    argparser.add_argument(
        '--unique-workers',
        type='bool',
        default=False,
        help='Each worker must be unique',
    )
    argparser.add_argument(
        '--prepend-gold-knowledge',
        type='bool',
        default=False,
        help='Add the gold knowledge to the input text from the human for '
        'the model observation.',
    )
    argparser.add_argument(
        '--mturk-log',
        type=str,
        default='data/mturklogs/wizard_of_wikipedia/{}.log'.format(start_time),
    )

    def inject_override(opt, override_dict):
        """Store override_dict under opt['override'] and copy each entry into opt."""
        opt['override'] = override_dict
        for k, v in override_dict.items():
            opt[k] = v

    def get_logger(opt):
        """Create the run logger and record the command line and config."""
        fmt = '%(asctime)s: [ %(message)s ]'
        logfile = opt.get('mturk_log')
        if logfile:
            log_dir = os.path.dirname(logfile)
            # A bare file name has no directory part, and os.makedirs('')
            # raises. exist_ok avoids the check-then-create race.
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)
        else:
            logfile = None
        logger = ParlaiLogger(
            "mturk_woz",
            console_level=INFO,
            file_level=INFO,
            console_format=fmt,
            file_format=fmt,
            filename=logfile,
        )
        logger.info('COMMAND: %s' % ' '.join(sys.argv))
        logger.info('-' * 100)
        logger.info('CONFIG:\n%s' % json.dumps(opt, indent=4, sort_keys=True))

        return logger

    # MODEL CONFIG
    # NOTE: please edit this to test your own models
    config = {
        'model_file': 'models:wizard_of_wikipedia/end2end_generator/model',
        'generative_setup': True,
        'prepend_gold_knowledge': True,
        'model': 'projects:wizard_of_wikipedia:generator',
        'beam_size': 10,  # add inference arguments here
        'inference': 'beam',
        'beam_block_ngram': 3,
    }

    # add dialogue model args
    argparser.add_model_subargs(config['model'])
    # add knowledge retriever args
    argparser.add_model_subargs(
        'projects:wizard_of_wikipedia:knowledge_retriever')
    start_opt = argparser.parse_args()

    # config entries take precedence over whatever was parsed
    inject_override(start_opt, config)

    if not start_opt.get('human_eval'):
        # make dialogue responder model
        bot = create_agent(start_opt)
        shared_bot_params = bot.share()
        # make knowledge retriever
        knowledge_opt = {
            'model': 'projects:wizard_of_wikipedia:knowledge_retriever',
            'add_token_knowledge': not start_opt['generative_setup'],
            'datapath': start_opt['datapath'],
            'interactive_mode':
            False,  # interactive mode automatically sets fixed cands
        }
        # add all existing opt to the knowledge opt, without overriding
        # the above arguments
        for k, v in start_opt.items():
            if k not in knowledge_opt and k not in config:
                knowledge_opt[k] = v

        knowledge_agent = KnowledgeRetrieverAgent(knowledge_opt)
        knowledge_agent_shared_params = knowledge_agent.share()

    else:
        shared_bot_params = None
        # run_conversation closes over this name in both modes. Without a
        # binding here the human-vs-human path raised NameError.
        # NOTE(review): None mirrors shared_bot_params; confirm that
        # WizardEval accepts knowledge_retriever_opt=None.
        knowledge_agent_shared_params = None

    if not start_opt['human_eval']:
        get_logger(bot.opt)
    else:
        get_logger(start_opt)

    if start_opt['human_eval']:
        folder_name = 'human_eval-{}'.format(start_time)
    else:
        folder_name = '{}-{}'.format(start_opt['model'], start_time)

    # the task name is the name of the directory containing this file
    start_opt['task'] = os.path.basename(
        os.path.dirname(os.path.abspath(__file__)))
    if 'data_path' not in start_opt:
        start_opt['data_path'] = os.path.join(os.getcwd(), 'data',
                                              'wizard_eval', folder_name)
    start_opt.update(task_config)

    # one turker when chatting with the model, two for human vs human
    if not start_opt.get('human_eval'):
        mturk_agent_ids = ['PERSON_1']
    else:
        mturk_agent_ids = ['PERSON_1', 'PERSON_2']

    mturk_manager = MTurkManager(opt=start_opt,
                                 mturk_agent_ids=mturk_agent_ids)

    topics_generator = TopicsGenerator(start_opt)
    directory_path = os.path.dirname(os.path.abspath(__file__))
    mturk_manager.setup_server(task_directory_path=directory_path)
    # worker_id -> role, filled during onboarding in human-eval mode
    worker_roles = {}
    # mutable counter shared with run_onboard to alternate roles
    connect_counter = AttrDict(value=0)

    try:
        mturk_manager.start_new_run()
        agent_qualifications = []
        if not start_opt['is_sandbox']:
            # assign qualifications
            if start_opt['only_masters']:
                agent_qualifications.append(MASTER_QUALIF)
            if start_opt['unique_workers']:
                qual_name = 'UniqueChatEval'
                qual_desc = (
                    'Qualification to ensure each worker completes a maximum '
                    'of one of these chat/eval HITs')
                qualification_id = mturk_utils.find_or_create_qualification(
                    qual_name, qual_desc, False)
                print('Created qualification: ', qualification_id)
                UNIQUE_QUALIF = {
                    'QualificationTypeId': qualification_id,
                    'Comparator': 'DoesNotExist',
                    'RequiredToPreview': True,
                }
                start_opt['unique_qualif_id'] = qualification_id
                agent_qualifications.append(UNIQUE_QUALIF)
        mturk_manager.create_hits(qualifications=agent_qualifications)

        def run_onboard(worker):
            """Assign the worker a role and run the topic-choosing world."""
            if start_opt['human_eval']:
                # alternate roles in connection order
                role = mturk_agent_ids[connect_counter.value %
                                       len(mturk_agent_ids)]
                connect_counter.value += 1
                worker_roles[worker.worker_id] = role
            else:
                role = 'PERSON_1'
            worker.topics_generator = topics_generator
            world = TopicChooseWorld(start_opt, worker, role=role)
            world.parley()
            world.shutdown()

        mturk_manager.set_onboard_function(onboard_function=run_onboard)
        mturk_manager.ready_to_accept_workers()

        def check_single_worker_eligibility(worker):
            """Every worker is eligible when chatting with the model."""
            return True

        def check_multiple_workers_eligibility(workers):
            """Pick one onboarded worker per role; return [] unless two roles are filled."""
            valid_workers = {}
            for worker in workers:
                worker_id = worker.worker_id
                if worker_id not in worker_roles:
                    print('Something went wrong')
                    continue
                role = worker_roles[worker_id]
                if role not in valid_workers:
                    valid_workers[role] = worker
                if len(valid_workers) == 2:
                    break
            return valid_workers.values() if len(valid_workers) == 2 else []

        if not start_opt['human_eval']:
            eligibility_function = {
                'func': check_single_worker_eligibility,
                'multiple': False,
            }
        else:
            eligibility_function = {
                'func': check_multiple_workers_eligibility,
                'multiple': True,
            }

        def assign_worker_roles(workers):
            """Set each worker's id to its role."""
            if start_opt['human_eval']:
                for worker in workers:
                    worker.id = worker_roles[worker.worker_id]
            else:
                for index, worker in enumerate(workers):
                    worker.id = mturk_agent_ids[index % len(mturk_agent_ids)]

        def run_conversation(mturk_manager, opt, workers):
            """Run one WizardEval conversation to completion and save its data."""
            conv_idx = mturk_manager.conversation_index
            world = WizardEval(
                opt=start_opt,
                agents=workers,
                range_turn=[
                    int(s) for s in start_opt['range_turn'].split(',')
                ],
                max_turn=start_opt['max_turns'],
                max_resp_time=start_opt['max_resp_time'],
                model_agent_opt=shared_bot_params,
                world_tag='conversation t_{}'.format(conv_idx),
                agent_timeout_shutdown=opt['ag_shutdown_time'],
                knowledge_retriever_opt=knowledge_agent_shared_params,
            )
            while not world.episode_done():
                world.parley()
            world.save_data()

            world.shutdown()
            gc.collect()

        mturk_manager.start_task(
            eligibility_function=eligibility_function,
            assign_role_function=assign_worker_roles,
            task_function=run_conversation,
        )

    finally:
        # Runs on success, on error and on interrupt alike. Exceptions
        # propagate unchanged after the cleanup.
        mturk_manager.expire_all_unassigned_hits()
        mturk_manager.shutdown()