コード例 #1
0
    def __init__(self, opt, shared=None):
        """
        Prepare the retriever: ensure the document db and the tfidf file
        exist on disk (building them if needed), then open both.

        :param opt: options dict; the ``retriever_*`` entries are read here.
        :param shared: passed straight through to the parent constructor.
        """
        super().__init__(opt, shared)
        self.id = 'SparseTfidfRetrieverAgent'

        # Looked up before the db is touched: is the tfidf file absent?
        tfidf_missing = not os.path.exists(opt['retriever_tfidfpath'] + '.npz')

        # Create the document db when there is none yet.
        db_missing = not os.path.exists(opt['retriever_dbpath'])
        if db_missing:
            build_db(
                opt,
                opt.get('retriever_task'),
                opt['retriever_dbpath'],
                context_length=opt.get('context_length', -1),
                include_labels=opt.get('include_labels', True),
            )

        # (tfidf argument name, opt key it is filled from)
        arg_sources = (
            ('db_path', 'retriever_dbpath'),
            ('out_dir', 'retriever_tfidfpath'),
            ('ngram', 'retriever_ngram'),
            ('hash_size', 'retriever_hashsize'),
            ('tokenizer', 'retriever_tokenizer'),
            ('num_workers', 'retriever_numworkers'),
        )
        self.tfidf_args = AttrDict(
            {arg_name: opt[opt_key] for arg_name, opt_key in arg_sources})

        # Build the tfidf if it is absent, and also whenever the db was just
        # built (a fresh db forces a tfidf rebuild).
        if db_missing or tfidf_missing:
            build_tfidf(self.tfidf_args)

        self.db = DocDB(db_path=opt['retriever_dbpath'])
        self.ranker = TfidfDocRanker(
            tfidf_path=opt['retriever_tfidfpath'],
            strict=False,
        )
        self.ret_mode = opt['retriever_mode']
        # cache for candidates
        self.cands_hash = {}
        # in case we want to add more entries
        self.triples_to_add = []
コード例 #2
0
    def __init__(self, opt, shared=None):
        """
        Set up the retriever from ``--model_file`` plus optional extra models.

        The db and tfidf paths are derived from ``opt['model_file']`` and
        written back into ``opt``. If the db file does not exist, an empty
        ``documents`` table is created. One ``TfidfDocRanker`` is loaded for
        every model file (the main one first, then each comma-separated entry
        of ``opt['extra_mf']``) whose ``<model_file>.tfidf.npz`` exists.

        :param opt: options dict; ``model_file``, ``extra_mf`` and the
            ``retriever_*`` / ``tfidf_*`` entries are read here.
        :param shared: passed straight through to the parent constructor.
        :raises RuntimeError: if ``model_file`` is unset or empty.
        """
        super().__init__(opt, shared)
        self.id = 'SparseTfidfRetrieverAgent'
        # A missing, None or empty model_file are all falsy.
        if not opt.get('model_file'):
            raise RuntimeError('Must set --model_file')

        # Tolerate a missing/None extra_mf and drop empty entries, so that an
        # empty option value does not end up probing a bare '.tfidf.npz' path.
        extra_mf = [mf for mf in (opt.get('extra_mf') or '').split(',') if mf]
        all_mf = [opt['model_file']] + extra_mf
        print(all_mf)

        opt['retriever_dbpath'] = opt['model_file'] + '.db'
        opt['retriever_tfidfpath'] = opt['model_file'] + '.tfidf'

        self.db_path = opt['retriever_dbpath']
        self.tfidf_path = opt['retriever_tfidfpath']

        self.tfidf_args = AttrDict({
            'db_path': opt['retriever_dbpath'],
            'out_dir': opt['retriever_tfidfpath'],
            'ngram': opt['retriever_ngram'],
            'hash_size': opt['retriever_hashsize'],
            'tokenizer': opt['retriever_tokenizer'],
            'num_workers': opt['retriever_numworkers'],
        })

        if not os.path.exists(self.db_path):
            # Create an empty documents table. The connection is closed even
            # if the statement fails, so no handle on the file is leaked.
            conn = sqlite3.connect(self.db_path)
            try:
                conn.execute('CREATE TABLE documents '
                             '(id INTEGER PRIMARY KEY, text, value);')
                conn.commit()
            finally:
                conn.close()
        self.db = DocDB(db_path=opt['retriever_dbpath'])

        # Only models whose tfidf file is already on disk get a ranker.
        self.rankers = [
            TfidfDocRanker(mf + '.tfidf', strict=False)
            for mf in all_mf
            if os.path.exists(mf + '.tfidf.npz')
        ]

        self.ret_mode = opt['retriever_mode']
        self.cands_hash = {}  # cache for candidates
        self.triples_to_add = []  # in case we want to add more entries

        # A negative context length is stored as None.
        clen = opt.get('tfidf_context_length', -1)
        self.context_length = clen if clen >= 0 else None
        self.include_labels = opt.get('tfidf_include_labels', True)
        self.reset()
コード例 #3
0
def main():
    """Run the MTurk task in which a worker evaluates a wizard model.

    Workers are assigned a topic and asked to chat. With ``--human-eval``
    no model agent is created and two workers (``PERSON_1`` / ``PERSON_2``)
    are paired with each other instead.

    Unassigned HITs are always expired and the manager shut down on exit,
    whether the run finishes normally or raises.
    """
    start_time = datetime.datetime.today().strftime('%Y-%m-%d-%H-%M')
    argparser = ParlaiParser(False, add_model_args=True)
    argparser.add_parlai_data_path()
    argparser.add_mturk_args()
    argparser.add_argument('-mt',
                           '--max-turns',
                           default=10,
                           type=int,
                           help='maximal number of chat turns')
    argparser.add_argument('--max-resp-time',
                           default=240,
                           type=int,
                           help='time limit for entering a dialog message')
    argparser.add_argument('--max-choice-time',
                           type=int,
                           default=300,
                           help='time limit for turker '
                           'choosing the topic')
    # NOTE(review): this help text is identical to --max-resp-time's; the
    # value is forwarded to the world as agent_timeout_shutdown, so the
    # wording looks like a copy-paste leftover - confirm and reword.
    argparser.add_argument('--ag-shutdown-time',
                           default=120,
                           type=int,
                           help='time limit for entering a dialog message')
    argparser.add_argument('-rt',
                           '--range-turn',
                           default='3,5',
                           help='sample range of number of turns')
    argparser.add_argument('--human-eval',
                           type='bool',
                           default=False,
                           help='human vs human eval, no models involved')
    argparser.add_argument('--auto-approve-delay',
                           type=int,
                           default=3600 * 24 * 1,
                           help='how long to wait for auto approval')
    argparser.add_argument('--only-masters',
                           type='bool',
                           default=False,
                           help='Set to true to use only master turks for '
                           'this test eval')
    argparser.add_argument('--unique-workers',
                           type='bool',
                           default=False,
                           help='Each worker must be unique')
    argparser.add_argument('--mturk-log',
                           type=str,
                           default='data/mturklogs/{}.log'.format(start_time))

    def inject_override(opt, override_dict):
        """Store override_dict under opt['override'] and copy its items
        into opt itself."""
        opt['override'] = override_dict
        for k, v in override_dict.items():
            opt[k] = v

    def get_logger(opt):
        """Configure the root logger (console, plus a file when
        opt['mturk_log'] is present) and log the command line and opt."""
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)

        fmt = logging.Formatter('%(asctime)s: [ %(message)s ]',
                                '%m/%d/%Y %I:%M:%S %p')
        console = logging.StreamHandler()
        console.setFormatter(fmt)
        logger.addHandler(console)
        if 'mturk_log' in opt:
            # FileHandler opens the file right away, so the directory of
            # the log file has to exist before the handler is created.
            log_dir = os.path.dirname(opt['mturk_log'])
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)
            logfile = logging.FileHandler(opt['mturk_log'], 'a')
            logfile.setFormatter(fmt)
            logger.addHandler(logfile)
        logger.info('COMMAND: %s' % ' '.join(sys.argv))
        logger.info('-' * 100)
        logger.info('CONFIG:\n%s' % json.dumps(opt, indent=4, sort_keys=True))

        return logger

    # MODEL CONFIG
    # NOTE: please edit this to test your own models
    config = {
        'model':
        'projects:wizard_of_wikipedia:interactive_retrieval',
        'retriever_model_file':
        'models:wikipedia_full/tfidf_retriever/model',
        'responder_model_file':
        'models:wizard_of_wikipedia/full_dialogue_retrieval_model/model',
    }

    argparser.add_model_subargs(config['model'])  # add model args to opt
    start_opt = argparser.parse_args()

    inject_override(start_opt, config)

    if start_opt['human_eval']:
        # human vs human: no model agent is created
        shared_bot_params = None
        get_logger(start_opt)
        folder_name = 'human_eval-{}'.format(start_time)
    else:
        bot = create_agent(start_opt)
        shared_bot_params = bot.share()
        get_logger(bot.opt)
        folder_name = '{}-{}'.format(start_opt['model'], start_time)

    directory_path = os.path.dirname(os.path.abspath(__file__))
    start_opt['task'] = os.path.basename(directory_path)
    if 'data_path' not in start_opt:
        start_opt['data_path'] = os.path.join(os.getcwd(), 'data',
                                              'wizard_eval', folder_name)
    start_opt.update(task_config)

    if not start_opt.get('human_eval'):
        mturk_agent_ids = ['PERSON_1']
    else:
        mturk_agent_ids = ['PERSON_1', 'PERSON_2']

    mturk_manager = MTurkManager(opt=start_opt,
                                 mturk_agent_ids=mturk_agent_ids)

    topics_generator = TopicsGenerator(start_opt)
    mturk_manager.setup_server(task_directory_path=directory_path)
    # worker_id -> role handed out during onboarding (human eval only)
    worker_roles = {}
    connect_counter = AttrDict(value=0)

    try:
        mturk_manager.start_new_run()
        agent_qualifications = []
        if not start_opt['is_sandbox']:
            # assign qualifications
            if start_opt['only_masters']:
                agent_qualifications.append(MASTER_QUALIF)
            if start_opt['unique_workers']:
                qual_name = 'UniqueChatEval'
                qual_desc = (
                    'Qualification to ensure each worker completes a maximum '
                    'of one of these chat/eval HITs')
                qualification_id = \
                    mturk_utils.find_or_create_qualification(qual_name, qual_desc,
                                                             False)
                print('Created qualification: ', qualification_id)
                UNIQUE_QUALIF = {
                    'QualificationTypeId': qualification_id,
                    'Comparator': 'DoesNotExist',
                    'RequiredToPreview': True
                }
                start_opt['unique_qualif_id'] = qualification_id
                agent_qualifications.append(UNIQUE_QUALIF)
        mturk_manager.create_hits(qualifications=agent_qualifications)

        def run_onboard(worker):
            """Pick the worker's role and run the topic-choice world."""
            if start_opt['human_eval']:
                # alternate roles in connection order and remember the choice
                role = mturk_agent_ids[connect_counter.value %
                                       len(mturk_agent_ids)]
                connect_counter.value += 1
                worker_roles[worker.worker_id] = role
            else:
                role = 'PERSON_1'
            worker.topics_generator = topics_generator
            world = TopicChooseWorld(start_opt, worker, role=role)
            world.parley()
            world.shutdown()

        mturk_manager.set_onboard_function(onboard_function=run_onboard)
        mturk_manager.ready_to_accept_workers()

        def check_single_worker_eligibility(worker):
            """Every worker is eligible in the model-eval setting."""
            return True

        def check_multiple_workers_eligibility(workers):
            """Return one worker per role once both roles are covered,
            otherwise an empty list."""
            valid_workers = {}
            for worker in workers:
                worker_id = worker.worker_id
                if worker_id not in worker_roles:
                    print('Something went wrong')
                    continue
                role = worker_roles[worker_id]
                # keep the first worker seen for each role
                if role not in valid_workers:
                    valid_workers[role] = worker
                if len(valid_workers) == 2:
                    break
            # Always hand back a list (not a dict view) so both outcomes
            # have the same type.
            if len(valid_workers) == 2:
                return list(valid_workers.values())
            return []

        if not start_opt['human_eval']:
            eligibility_function = {
                'func': check_single_worker_eligibility,
                'multiple': False,
            }
        else:
            eligibility_function = {
                'func': check_multiple_workers_eligibility,
                'multiple': True,
            }

        def assign_worker_roles(workers):
            """Set worker.id from the onboarding role (human eval) or by
            position in the list (model eval)."""
            if start_opt['human_eval']:
                for worker in workers:
                    worker.id = worker_roles[worker.worker_id]
            else:
                for index, worker in enumerate(workers):
                    worker.id = mturk_agent_ids[index % len(mturk_agent_ids)]

        def run_conversation(mturk_manager, opt, workers):
            """Run one evaluation conversation to completion and save it.

            Only ``ag_shutdown_time`` is read from the ``opt`` passed in;
            every other setting comes from the enclosing ``start_opt``.
            """
            conv_idx = mturk_manager.conversation_index
            world = WizardEval(
                opt=start_opt,
                agents=workers,
                range_turn=[
                    int(s) for s in start_opt['range_turn'].split(',')
                ],
                max_turn=start_opt['max_turns'],
                max_resp_time=start_opt['max_resp_time'],
                model_agent_opt=shared_bot_params,
                world_tag='conversation t_{}'.format(conv_idx),
                agent_timeout_shutdown=opt['ag_shutdown_time'],
            )
            while not world.episode_done():
                world.parley()
            world.save_data()

            world.shutdown()
            gc.collect()

        mturk_manager.start_task(eligibility_function=eligibility_function,
                                 assign_role_function=assign_worker_roles,
                                 task_function=run_conversation)

    finally:
        # Runs on success and on any exception, which then propagates.
        mturk_manager.expire_all_unassigned_hits()
        mturk_manager.shutdown()
コード例 #4
0
ファイル: run.py プロジェクト: tonydeep/dl4dial-mt-beam
def main():
    """
        Wizard of Wikipedia Data Collection Task.

        The task involves two people holding a conversation. One dialog partner
        chooses a topic to discuss, and then dialog proceeds.

        One partner is the Wizard, who has access to retrieved external
        information conditioned on the last two utterances, as well as
        information regarding the chosen topic.

        The other partner is the Apprentice, who assumes the role of someone
        eager to learn about the chosen topic.

        Outside sandbox mode, every worker id listed (one per line) in
        ``mtdont.txt`` in the current working directory is soft-blocked
        before HITs are created. Unassigned HITs are always expired and the
        manager shut down on exit, whether the run finishes or raises.
    """
    argparser = ParlaiParser(False, False)
    DictionaryAgent.add_cmdline_args(argparser)
    argparser.add_parlai_data_path()
    argparser.add_mturk_args()
    argparser.add_argument('-min_t',
                           '--min_turns',
                           default=3,
                           type=int,
                           help='minimum number of turns')
    argparser.add_argument('-max_t',
                           '--max_turns',
                           default=5,
                           type=int,
                           help='maximal number of chat turns')
    argparser.add_argument('-mx_rsp_time',
                           '--max_resp_time',
                           default=120,
                           type=int,
                           help='time limit for entering a dialog message')
    argparser.add_argument('-mx_onb_time',
                           '--max_onboard_time',
                           type=int,
                           default=300,
                           help='time limit for turker '
                           'in onboarding')
    argparser.add_argument('--persona-type',
                           default='both',
                           type=str,
                           choices=['both', 'self', 'other'],
                           help='Which personas to load from personachat')
    argparser.add_argument('--auto-approve-delay',
                           type=int,
                           default=3600 * 24 * 1,
                           help='how long to wait for  \
                           auto approval')
    argparser.add_argument(
        '--word-overlap-threshold',
        type=int,
        default=2,
        help='How much word overlap we want between message \
                           and checked sentence')
    argparser.add_argument(
        '--num-good-sentence-threshold',
        type=int,
        default=2,
        help='How many good sentences with sufficient overlap \
                           are necessary for turker to be considered good.')
    argparser.add_argument('--num-passages-retrieved',
                           type=int,
                           default=7,
                           help='How many passages to retrieve per dialog \
                           message')

    opt = argparser.parse_args()
    directory_path = os.path.dirname(os.path.abspath(__file__))
    opt['task'] = os.path.basename(directory_path)
    cwd = os.getcwd()
    if 'data_path' not in opt:
        opt['data_path'] = cwd + '/data/' + opt['task']
        opt['current_working_dir'] = cwd
    elif 'current_working_dir' not in opt:
        # The block-list lookup below reads this key unconditionally, so it
        # must exist even when data_path was supplied.
        opt['current_working_dir'] = cwd
    opt.update(task_config)

    mturk_agent_ids = [APPRENTICE, WIZARD]
    opt['min_messages'] = 2

    mturk_manager = MTurkManager(opt=opt, mturk_agent_ids=mturk_agent_ids)
    setup_personas_with_wiki_links(opt)
    ir_agent, task = setup_retriever(opt)
    persona_generator = PersonasGenerator(opt)
    wiki_title_to_passage = setup_title_to_passage(opt)
    mturk_manager.setup_server(task_directory_path=directory_path)
    # worker_id -> role handed out during onboarding
    worker_roles = {}
    connect_counter = AttrDict(value=0)

    try:
        mturk_manager.start_new_run()
        if not opt['is_sandbox']:
            # Soft-block every worker id listed in mtdont.txt (one per line).
            with open(os.path.join(opt['current_working_dir'],
                                   'mtdont.txt')) as f:
                for line in f:
                    blocked_id = line.replace('\n', '')
                    # skip blank lines instead of blocking an empty id
                    if blocked_id:
                        mturk_manager.soft_block_worker(blocked_id)

        def run_onboard(worker):
            """Alternate roles in connection order, remember the worker's
            role, and run the onboarding world."""
            role = mturk_agent_ids[connect_counter.value %
                                   len(mturk_agent_ids)]
            connect_counter.value += 1
            worker_roles[worker.worker_id] = role
            worker.persona_generator = persona_generator
            world = RoleOnboardWorld(opt, worker, role)
            world.parley()
            world.shutdown()

        mturk_manager.set_onboard_function(onboard_function=run_onboard)
        mturk_manager.ready_to_accept_workers()
        mturk_manager.create_hits()

        def check_workers_eligibility(workers):
            """In sandbox mode accept the workers as given; otherwise return
            one worker per role once both roles are covered, else an empty
            list."""
            if opt['is_sandbox']:
                return workers
            valid_workers = {}
            for worker in workers:
                worker_id = worker.worker_id
                if worker_id not in worker_roles:
                    # Something went wrong...
                    continue
                role = worker_roles[worker_id]
                # keep the first worker seen for each role
                if role not in valid_workers:
                    valid_workers[role] = worker
                if len(valid_workers) == 2:
                    break
            # Always hand back a list (not a dict view) so both outcomes
            # have the same type.
            if len(valid_workers) == 2:
                return list(valid_workers.values())
            return []

        eligibility_function = {
            'func': check_workers_eligibility,
            'multiple': True,
        }

        def assign_worker_roles(workers):
            """Set worker.id by list position (sandbox) or from the role
            recorded during onboarding."""
            if opt['is_sandbox']:
                for i, worker in enumerate(workers):
                    worker.id = mturk_agent_ids[i % len(mturk_agent_ids)]
            else:
                for worker in workers:
                    worker.id = worker_roles[worker.worker_id]

        def run_conversation(mturk_manager, opt, workers):
            """Run one wizard/apprentice conversation, save it, and
            soft-block the wizard when the world flags the finished
            conversation as not good (outside sandbox mode)."""
            agents = workers[:]
            if not opt['is_sandbox']:
                # these workers are now in a conversation; drop their
                # recorded onboarding roles
                for agent in agents:
                    worker_roles.pop(agent.worker_id)
            conv_idx = mturk_manager.conversation_index
            world = MTurkWizardOfWikipediaWorld(
                opt,
                agents=agents,
                world_tag='conversation t_{}'.format(conv_idx),
                ir_agent=ir_agent,
                wiki_title_to_passage=wiki_title_to_passage,
                task=task)
            world.reset_random()
            while not world.episode_done():
                world.parley()
            world.save_data()
            if (world.convo_finished and not world.good_wiz
                    and not opt['is_sandbox']):
                mturk_manager.soft_block_worker(world.wizard_worker)
            world.shutdown()
            world.review_work()

        mturk_manager.start_task(eligibility_function=eligibility_function,
                                 assign_role_function=assign_worker_roles,
                                 task_function=run_conversation)

    finally:
        # Runs on success and on any exception, which then propagates.
        mturk_manager.expire_all_unassigned_hits()
        mturk_manager.shutdown()