Example #1
def eval_model(opt, parser, printargs=True):
    # Create model and assign it to the specified task
    agent = create_agent(opt)
    world = create_task(opt, agent)
    # Show arguments after loading model
    parser.opt = agent.opt
    if printargs:
        parser.print_args()
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = Timer()
    tot_time = 0

    # Show some example dialogs:
    for _ in range(int(opt['num_examples'])):
        world.parley()
        if opt['display_examples']:
            print("---")
            print(world.display() + "\n~~")
        if log_time.time() > log_every_n_secs:
            tot_time += log_time.time()
            print(str(int(tot_time)) + "s elapsed: " + str(world.report()))
            log_time.reset()
        if world.epoch_done():
            print("EPOCH DONE")
            break
    print(world.report())
    world.shutdown()
Example #2
def main():
    random.seed(42)

    # Get command line arguments
    parser = ParlaiParser(True, True)
    RemoteAgentAgent.add_cmdline_args(parser)
    opt = parser.parse_args()

    remote = RemoteAgentAgent(opt)
    if opt.get('task'):
        world = create_task(opt, [remote])
    else:
        if opt.get('model'):
            local = create_agent(opt)
        else:
            local = LocalHumanAgent(opt)
        # the remote agent goes second unless this instance is the remote host
        agents = [local, remote] if not opt['remote_host'] else [remote, local]
        world = DialogPartnerWorld(opt, agents)


    # Talk to the remote agent
    with world:
        while True:
            world.parley()
            print(world.display())
Example #3
    def __init__(self, parser):
        opt = parser.parse_args()
        # Possibly build a dictionary (not all models do this).
        if opt['dict_build_first'] and 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt)
        # Create model and assign it to the specified task
        self.agent = create_agent(opt)
        self.world = create_task(opt, self.agent)
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')
        self.parleys = 0
        self.max_num_epochs = (
            opt['num_epochs'] if opt['num_epochs'] > 0 else float('inf'))
        self.max_train_time = (
            opt['max_train_time'] if opt['max_train_time'] > 0 else float('inf'))
        self.log_every_n_secs = (
            opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 else float('inf'))
        self.val_every_n_secs = (
            opt['validation_every_n_secs']
            if opt['validation_every_n_secs'] > 0 else float('inf'))
        self.save_every_n_secs = (
            opt['save_every_n_secs'] if opt['save_every_n_secs'] > 0 else float('inf'))
        self.best_valid = 0
        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt
Example #4
def main():
    random.seed(42)

    # Get command line arguments
    parser = ParlaiParser(True, True)
    parser.add_argument('-n', '--num-examples', default=10)
    opt = parser.parse_args()

    # Create model and assign it to the specified task
    agent = create_agent(opt)
    world = create_task(opt, agent)

    # Show some example dialogs.
    with world:
        for k in range(int(opt['num_examples'])):
            world.parley()
            print(world.display() + "\n~~")
            if world.epoch_done():
                print("EPOCH DONE")
                break
Example #5
    def __init__(self, args=None, **kwargs):
        """Initializes the predictor, setting up opt automatically if necessary.

        Args is expected to be in the same format as sys.argv: e.g. a list in
        the form ['--model', 'seq2seq', '-hs', 128, '-lr', 0.5].

        kwargs is interpreted by appending '--' to it and replacing underscores
        with hyphens, so 'dict_file=/tmp/dict.tsv' would be interpreted as
        '--dict-file /tmp/dict.tsv'.
        """
        from parlai.core.params import ParlaiParser
        from parlai.core.agents import create_agent

        if args is None:
            args = []
        for k, v in kwargs.items():
            args.append('--' + str(k).replace('_', '-'))
            args.append(str(v))
        parser = ParlaiParser(True, True, model_argv=args)
        self.opt = parser.parse_args(args)
        self.agent = create_agent(self.opt)
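
A minimal usage sketch of the kwargs-to-flags conversion described in the docstring above. The class name Predictor is assumed from the docstring, the observe()/act() calls mirror the agent API used in Example #17, and the model name and path are illustrative only:

# Hypothetical usage: each kwarg becomes a CLI-style flag before parsing,
# e.g. dict_file='/tmp/dict.tsv' turns into '--dict-file /tmp/dict.tsv'.
predictor = Predictor(model='seq2seq', dict_file='/tmp/dict.tsv')
# Query the wrapped agent directly with a standard observation dict.
predictor.agent.observe({'text': 'hello', 'episode_done': True})
reply = predictor.agent.act()
print(reply.get('text'))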
Example #6
def main():
    random.seed(42)

    # Get command line arguments
    parser = ParlaiParser(True, True)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    opt = parser.parse_args()
    opt['task'] = 'parlai.agents.local_human.local_human:LocalHumanAgent'
    print(opt)
    # Create model and assign it to the specified task
    agent = create_agent(opt)
    world = create_task(opt, agent)

    # Show some example dialogs:
    while True:
        world.parley()
        if opt['display_examples']:
            print("---")
            print(world.display() + "\n~~")
        if world.epoch_done():
            print("EPOCH DONE")
            break
Example #7
    def train(self):
        opt = self.opt
        world = self.world
        with world:
            while True:
                world.parley()
                self.parleys += 1

                if world.get_total_epochs() >= self.max_num_epochs:
                    self.log()
                    print('[ num_epochs completed:{} time elapsed:{}s ]'.format(
                        self.max_num_epochs, self.train_time.time()))
                    break
                if self.train_time.time() > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(self.train_time.time()))
                    break
                if self.log_time.time() > self.log_every_n_secs:
                    self.log()
                if self.validate_time.time() > self.val_every_n_secs:
                    stop_training = self.validate()
                    if stop_training:
                        break
                if self.save_time.time() > self.save_every_n_secs:
                    print("[ saving model: " + opt['model_file'] + " ]")
                    world.save_agents()
                    self.save_time.reset()

        if not self.saved:
            # save agent
            world.save_agents()
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        _rep, wrld = run_eval(self.agent, opt, 'valid', write_log=True)
        wrld.shutdown()  # may need to shut down threads, remote connections
        _rep, wrld = run_eval(self.agent, opt, 'test', write_log=True)
        wrld.shutdown()  # may need to shut down threads, remote connections
Example #8
        def run_conversation(mturk_manager, opt, workers):
            agents = workers[:]

            # Create a local agent
            if not opt['two_mturk_agents']:
                if 'model' in opt:
                    local_agent = create_agent(opt)
                else:
                    local_agent = LocalHumanAgent(opt=None)

                local_agent.id = local_agent_1_id
                agents.append(local_agent)

            opt["batchindex"] = mturk_manager.started_conversations

            world = MTurkDealNoDealDialogWorld(
                opt=opt,
                agents=agents
            )

            while not world.episode_done():
                world.parley()

            world.shutdown()
Example #9
def main():
    random.seed(42)

    # Get command line arguments
    parser = ParlaiParser(True, True)
    parser.add_argument('-n', '--num-examples', default=100000000)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.set_defaults(datatype='valid')
    opt = parser.parse_args()
    # Create model and assign it to the specified task
    agent = create_agent(opt)
    world = create_task(opt, agent)

    # Show some example dialogs:
    for k in range(int(opt['num_examples'])):
        world.parley()
        print("---")
        if opt['display_examples']:
            print(world.display() + "\n~~")
        print(world.report())
        if world.epoch_done():
            print("EPOCH DONE")
            break
    world.shutdown()
Example #10
def main():
    """
    This task consists of an MTurk agent evaluating a wizard model.

    They are assigned a topic and asked to chat.
    """
    start_time = datetime.datetime.today().strftime('%Y-%m-%d-%H-%M')
    argparser = ParlaiParser(False, add_model_args=True)
    argparser.add_parlai_data_path()
    argparser.add_mturk_args()
    argparser.add_argument('-mt',
                           '--max-turns',
                           default=10,
                           type=int,
                           help='maximal number of chat turns')
    argparser.add_argument(
        '--max-resp-time',
        default=240,
        type=int,
        help='time limit for entering a dialog message',
    )
    argparser.add_argument(
        '--max-choice-time',
        type=int,
        default=300,
        help='time limit for turker '
        'choosing the topic',
    )
    argparser.add_argument(
        '--ag-shutdown-time',
        default=120,
        type=int,
        help='time limit for entering a dialog message',
    )
    argparser.add_argument('-rt',
                           '--range-turn',
                           default='3,5',
                           help='sample range of number of turns')
    argparser.add_argument(
        '--human-eval',
        type='bool',
        default=False,
        help='human vs human eval, no models involved',
    )
    argparser.add_argument(
        '--auto-approve-delay',
        type=int,
        default=3600 * 24 * 1,
        help='how long to wait for auto approval',
    )
    argparser.add_argument(
        '--only-masters',
        type='bool',
        default=False,
        help='Set to true to use only master turks for '
        'this test eval',
    )
    argparser.add_argument(
        '--unique-workers',
        type='bool',
        default=False,
        help='Each worker must be unique',
    )
    argparser.add_argument(
        '--mturk-log',
        type=str,
        default='data/mturklogs/wizard_of_wikipedia/{}.log'.format(start_time),
    )

    def inject_override(opt, override_dict):
        opt['override'] = override_dict
        for k, v in override_dict.items():
            opt[k] = v

    def get_logger(opt):
        fmt = '%(asctime)s: [ %(message)s ]'
        logfile = None
        if 'mturk_log' in opt:
            logfile = opt['mturk_log']
            if not os.path.isdir(os.path.dirname(logfile)):
                os.makedirs(os.path.dirname(logfile))
        logger = ParlaiLogger(
            "mturk_woz",
            console_level=INFO,
            file_level=INFO,
            console_format=fmt,
            file_format=fmt,
            filename=logfile,
        )
        logger.info('COMMAND: %s' % ' '.join(sys.argv))
        logger.info('-' * 100)
        logger.info('CONFIG:\n%s' % json.dumps(opt, indent=4, sort_keys=True))

        return logger

    # MODEL CONFIG
    # NOTE: please edit this to test your own models
    config = {
        'model':
        'projects:wizard_of_wikipedia:interactive_retrieval',
        'retriever_model_file':
        'models:wikipedia_full/tfidf_retriever/model',
        'responder_model_file':
        'models:wizard_of_wikipedia/full_dialogue_retrieval_model/model',
    }

    argparser.add_model_subargs(config['model'])  # add model args to opt
    start_opt = argparser.parse_args()

    inject_override(start_opt, config)

    if not start_opt.get('human_eval'):
        bot = create_agent(start_opt)
        shared_bot_params = bot.share()
    else:
        shared_bot_params = None

    if not start_opt['human_eval']:
        get_logger(bot.opt)
    else:
        get_logger(start_opt)

    if start_opt['human_eval']:
        folder_name = 'human_eval-{}'.format(start_time)
    else:
        folder_name = '{}-{}'.format(start_opt['model'], start_time)

    start_opt['task'] = os.path.basename(
        os.path.dirname(os.path.abspath(__file__)))
    if 'data_path' not in start_opt:
        start_opt['data_path'] = os.path.join(os.getcwd(), 'data',
                                              'wizard_eval', folder_name)
    start_opt.update(task_config)

    if not start_opt.get('human_eval'):
        mturk_agent_ids = ['PERSON_1']
    else:
        mturk_agent_ids = ['PERSON_1', 'PERSON_2']

    mturk_manager = MTurkManager(opt=start_opt,
                                 mturk_agent_ids=mturk_agent_ids)

    topics_generator = TopicsGenerator(start_opt)
    directory_path = os.path.dirname(os.path.abspath(__file__))
    mturk_manager.setup_server(task_directory_path=directory_path)
    worker_roles = {}
    connect_counter = AttrDict(value=0)

    try:
        mturk_manager.start_new_run()
        agent_qualifications = []
        if not start_opt['is_sandbox']:
            # assign qualifications
            if start_opt['only_masters']:
                agent_qualifications.append(MASTER_QUALIF)
            if start_opt['unique_workers']:
                qual_name = 'UniqueChatEval'
                qual_desc = (
                    'Qualification to ensure each worker completes a maximum '
                    'of one of these chat/eval HITs')
                qualification_id = mturk_utils.find_or_create_qualification(
                    qual_name, qual_desc, False)
                print('Created qualification: ', qualification_id)
                UNIQUE_QUALIF = {
                    'QualificationTypeId': qualification_id,
                    'Comparator': 'DoesNotExist',
                    'RequiredToPreview': True,
                }
                start_opt['unique_qualif_id'] = qualification_id
                agent_qualifications.append(UNIQUE_QUALIF)
        mturk_manager.create_hits(qualifications=agent_qualifications)

        def run_onboard(worker):
            if start_opt['human_eval']:
                role = mturk_agent_ids[connect_counter.value %
                                       len(mturk_agent_ids)]
                connect_counter.value += 1
                worker_roles[worker.worker_id] = role
            else:
                role = 'PERSON_1'
            worker.topics_generator = topics_generator
            world = TopicChooseWorld(start_opt, worker, role=role)
            world.parley()
            world.shutdown()

        mturk_manager.set_onboard_function(onboard_function=run_onboard)
        mturk_manager.ready_to_accept_workers()

        def check_single_worker_eligibility(worker):
            return True

        def check_multiple_workers_eligibility(workers):
            valid_workers = {}
            for worker in workers:
                worker_id = worker.worker_id
                if worker_id not in worker_roles:
                    print('Something went wrong')
                    continue
                role = worker_roles[worker_id]
                if role not in valid_workers:
                    valid_workers[role] = worker
                if len(valid_workers) == 2:
                    break
            return valid_workers.values() if len(valid_workers) == 2 else []

        if not start_opt['human_eval']:
            eligibility_function = {
                'func': check_single_worker_eligibility,
                'multiple': False,
            }
        else:
            eligibility_function = {
                'func': check_multiple_workers_eligibility,
                'multiple': True,
            }

        def assign_worker_roles(workers):
            if start_opt['human_eval']:
                for worker in workers:
                    worker.id = worker_roles[worker.worker_id]
            else:
                for index, worker in enumerate(workers):
                    worker.id = mturk_agent_ids[index % len(mturk_agent_ids)]

        def run_conversation(mturk_manager, opt, workers):
            conv_idx = mturk_manager.conversation_index
            world = WizardEval(
                opt=start_opt,
                agents=workers,
                range_turn=[
                    int(s) for s in start_opt['range_turn'].split(',')
                ],
                max_turn=start_opt['max_turns'],
                max_resp_time=start_opt['max_resp_time'],
                model_agent_opt=shared_bot_params,
                world_tag='conversation t_{}'.format(conv_idx),
                agent_timeout_shutdown=opt['ag_shutdown_time'],
            )
            while not world.episode_done():
                world.parley()
            world.save_data()

            world.shutdown()
            gc.collect()

        mturk_manager.start_task(
            eligibility_function=eligibility_function,
            assign_role_function=assign_worker_roles,
            task_function=run_conversation,
        )

    except BaseException:
        raise
    finally:
        mturk_manager.expire_all_unassigned_hits()
        mturk_manager.shutdown()
Example #11
    def __init__(self, opt):
        # if python is called from a non-interactive shell, like a bash script,
        # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions are
        # not produced. This line brings them back
        signal.signal(signal.SIGINT, signal.default_int_handler)
        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'  # we might load training statistics from here
        if (opt['load_from_checkpoint'] and opt.get('model_file')
                and PathManager.exists(opt['model_file'] + '.checkpoint')):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'
        # Possibly build a dictionary (not all models do this).
        if not (opt.get('dict_file') or opt.get('model_file')):
            raise RuntimeError(
                'WARNING: For train_model, please specify either a '
                'model_file or dict_file.')
        if 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            logging.info("building dictionary first...")
            build_dict(opt, skip_if_built=True)

        # Create model and assign it to the specified task
        self.agent = create_agent(opt)
        self.agent.opt.log()
        self.world = create_task(opt, self.agent)
        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()

        self.parleys = 0
        self.max_num_epochs = (opt['num_epochs']
                               if opt['num_epochs'] > 0 else float('inf'))
        self.max_train_time = (opt['max_train_time']
                               if opt['max_train_time'] > 0 else float('inf'))
        self.log_every_n_secs = (opt['log_every_n_secs'] if
                                 opt['log_every_n_secs'] > 0 else float('inf'))
        self.val_every_n_secs = (opt['validation_every_n_secs']
                                 if opt['validation_every_n_secs'] > 0 else
                                 float('inf'))
        self.save_every_n_secs = (opt['save_every_n_secs']
                                  if opt['save_every_n_secs'] > 0 else
                                  float('inf'))
        self.val_every_n_epochs = (opt['validation_every_n_epochs']
                                   if opt['validation_every_n_epochs'] > 0 else
                                   float('inf'))

        # smart defaults for --validation-metric-mode
        if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
            opt['validation_metric_mode'] = 'min'
        elif opt['validation_metric'] in {
                'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'
        }:
            opt['validation_metric_mode'] = 'max'
        if opt.get('validation_metric_mode') is None:
            opt['validation_metric_mode'] = 'max'

        self.last_valid_epoch = 0
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.train_reports = []
        self.valid_reports = []
        self.best_valid = None

        self.impatience = 0
        self.saved = False
        self.valid_worlds = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if opt.get('model_file') and PathManager.exists(opt['model_file'] +
                                                        trainstats_suffix):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with PathManager.open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self.parleys = obj.get('parleys', 0)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self.impatience = obj.get('impatience', 0)
                self.valid_reports = obj.get('valid_reports', [])
                self.train_reports = obj.get('train_reports', [])
                if 'best_valid' in obj:
                    self.best_valid = obj['best_valid']
                else:
                    # old method
                    if opt.get('model_file') and PathManager.exists(
                            opt['model_file'] + '.best_valid'):
                        with PathManager.open(
                                opt['model_file'] + ".best_valid", 'r') as f:
                            x = f.readline()
                            self.best_valid = float(x)

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger = TensorboardLogger(opt)
Example #12
def eval_ppl(opt, build_dict=None, dict_file=None):
    """Evaluates the the perplexity of a model.

    This uses a dictionary which implements the following functions:
    - tokenize(text): splits string up into list of tokens
    - __contains__(text): checks whether the dictionary contains a token
    - keys(): returns an iterator over all tokens in the dictionary

    :param opt: option dict
    :param build_dict: function which returns a dictionary class implementing
        the functions above.
    :param dict_file: file used when loading the dictionary class set via the
        "dictionary_class" argument (defaults to
        parlai.core.dict:DictionaryAgent).

    Either build_dict or dict_file must be set (both default to None) to
    determine the dictionary used for the evaluation.
    """
    if not build_dict and not dict_file:
        raise RuntimeError('eval_ppl script either needs a dictionary build '
                           'function or a dictionary file.')

    if build_dict:
        dict_agent = build_dict()
    else:
        dict_opt = copy.deepcopy(opt)
        dict_opt['model'] = dict_opt.get('dictionary_class',
                                         'parlai.core.dict:DictionaryAgent')
        dict_opt['model_file'] = dict_file
        if 'override' in dict_opt:
            del dict_opt['override']
        dict_agent = create_agent(dict_opt, requireModelExists=True)

    # create agents
    agent = create_agent(opt)
    world = create_task(opt, [agent, dict_agent],
                        default_world=PerplexityWorld)

    # set up logging
    log_time = Timer()
    tot_time = 0

    while not world.epoch_done():
        world.parley()  # process an example

        if log_time.time() > 1:  # log every 1 sec
            tot_time += log_time.time()
            report = world.report()
            print('{}s elapsed, {}% complete, {}'.format(
                int(tot_time),
                round_sigfigs(report['exs'] / world.num_examples() * 100, 3),
                report))
            log_time.reset()
    print('EPOCH DONE')
    tot_time += log_time.time()
    final_report = world.report()
    print('{}s elapsed: {}'.format(int(tot_time), final_report))
    print("============================")
    print("FINAL PPL: " + str(final_report['ppl']))
    if final_report.get('ppl', 0) == float('inf'):
        print('Note: you got inf perplexity. Consider adding (or raising) the '
              'minimum probability you assign to each possible word. If you '
              'assign zero probability to the correct token in the evaluation '
              'vocabulary, you get inf perplexity immediately.')
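
A minimal sketch of the dictionary protocol eval_ppl expects from a build_dict callable (tokenize, __contains__, keys), assuming simple whitespace tokenization; the class name and token list are illustrative only:

class WhitespaceDict:
    """Hypothetical dictionary implementing the interface listed in the docstring."""

    def __init__(self, tokens):
        self._tokens = set(tokens)

    def tokenize(self, text):
        # split a string into a list of tokens
        return text.split()

    def __contains__(self, token):
        # membership check used by `token in dict_agent`
        return token in self._tokens

    def keys(self):
        # iterator over all tokens in the dictionary
        return iter(self._tokens)

# eval_ppl(opt, build_dict=lambda: WhitespaceDict(['hello', 'world', '__unk__']))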
Example #13
    def train(self):
        if is_distributed():
            warn_once(
                "Distributed training outputs average-per-worker metrics during "
                "training, and may be slightly distorted. Validation/test are "
                "unadulterated.")
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                world.parley()
                self.parleys += 1

                # get the total training examples done, compute epochs
                self._total_epochs = (
                    self._preempted_epochs +
                    num_workers() * self.world.get_total_epochs())
                exs_per_epoch = self.world.num_examples()
                self._total_exs = int(
                    np.round(self._total_epochs * exs_per_epoch))

                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object(
                    (self.train_time.time(), self.log_time.time(),
                     self.validate_time.time()))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    print(
                        '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                            self.max_num_epochs, train_time))
                    break
                if train_time > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(train_time))
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (validate_time > self.val_every_n_secs
                        or self._total_epochs - self.last_valid_epoch >=
                        self.val_every_n_epochs):
                    stop_training = self.validate()
                    self.last_valid_epoch = self._total_epochs
                    if stop_training:
                        break
                if (self.save_time.time() > self.save_every_n_secs
                        and opt.get('model_file') and is_primary_worker()):
                    print("[ saving model checkpoint: {}.checkpoint".format(
                        opt['model_file']))
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        valid_world = _maybe_load_eval_world(self.agent, opt, 'valid')
        v_report = run_eval(valid_world, opt, 'valid', write_log=True)
        test_world = _maybe_load_eval_world(self.agent, opt, 'test')
        t_report = run_eval(test_world, opt, 'test', write_log=True)
        if valid_world:
            valid_world.shutdown()
        if test_world:
            test_world.shutdown()

        return v_report, t_report
Example #14
def eval_wordstat(opt):
    """
    Evaluates a model.

    :param opt: tells the evaluation function how to run
    """
    random.seed(42)

    # Setup control information
    initialize_control_information(opt)

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)

    if opt.get('external_dict'):
        print('[ Using external dictionary from: {} ]'.format(opt['external_dict']))
        dict_opt = copy.deepcopy(opt)
        dict_opt['dict_file'] = opt['external_dict']
        dictionary = DictionaryAgent(dict_opt)
    else:
        print('[ Using model bundled dictionary ]')
        dictionary = agent.dict

    batch_size = opt['batchsize']

    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    data = {}  # This will be written to the output json file
    data['opt'] = agent.opt  # Save the opt to json

    # Determine the output filename
    if opt['gold_response']:  # Special output file for gold response
        model_dir, _ = os.path.split(opt.get('model_file'))
        outfile = os.path.join(model_dir, 'goldresponse')
        if opt['use_reply'] != 'label':
            raise ValueError(
                'You should set --use-reply label (not --use-reply model) '
                'when measuring goldresponse stats'
            )
    else:
        outfile = "%s.%s.%s.%s" % (
            opt.get('model_file'),
            opt.get('datatype'),
            "use%sreply" % agent.opt['use_reply'],
            "beam%i" % agent.opt['beam_size'],
        )
        if agent.opt['beam_size'] > 1:
            outfile += ".beamminnbest%i" % agent.opt['beam_min_n_best']
        if len(agent.control_settings) > 0:
            outfile += ".setcontrols:" + "_".join(
                [
                    "%s%s" % (c, str(agent.control_settings[c]['set_value']))
                    for c in sorted(agent.control_settings.keys())
                ]
            )
        if agent.opt['beam_reorder'] not in ['none', False]:
            outfile += ".beamreorder_%s" % agent.opt['beam_reorder']
        if len(agent.wd_features) > 0:
            sorted_bfw = sorted(
                list(zip(agent.wd_features, agent.wd_wts)), key=lambda x: x[0]
            )
            outfile += ".WDfeatures:" + "_".join(
                ["%s%s" % (f, str(w)) for f, w in sorted_bfw]
            )
    if opt['num_examples'] != -1:
        outfile += ".numex%i" % opt['num_examples']
    outfile += ".wordstats.json"
    print("\nOutfile: %s\n" % outfile)

    cnt = 0
    word_statistics = {
        'mean_wlength': [],  # list of length (in words) of utterances
        'mean_clength': [],  # list of length (in chars) of utterances
        'freqs_cnt': Counter(),  # Counter for word frequencies, bucketed
        'word_cnt': 0,  # total number of words in all utterances
        'pred_list': [],  # list of generated utterances after applying normalize_answer
        'pure_pred_list': [],  # list of generated utterances
        'context_list': [],  # list of text inputs (persona and conversation history)
    }
    bins = [int(i) for i in opt['freq_bins'].split(',')]

    # This dictionary records all the sentence-level controllable attributes
    # For each attribute, we have a list of all the values
    sent_attrs = {attr: [] for attr in ATTR2SENTSCOREFN.keys()}  # str to list of floats

    # histories will be a list of ConvAI2History objects
    histories = []

    def process_prediction(prediction, word_statistics):
        word_statistics['pred_list'].append(normalize_answer(prediction))
        freqs, _cnt, wlength, clength = get_word_stats(
            prediction, dictionary, bins=bins
        )
        word_statistics['word_cnt'] += _cnt
        word_statistics['mean_wlength'].append(wlength)
        word_statistics['mean_clength'].append(clength)
        word_statistics['freqs_cnt'] += Counter(freqs)
        return word_statistics

    t0 = time.time()
    while not world.epoch_done():
        world.parley()
        # orig eval_wordstat.py handles bsz=1 but for simplicity we assume bsz>1
        assert batch_size != 1
        for w in world.worlds:
            try:
                try:
                    response_act = w.acts[-1]
                    prediction = response_act['text']
                except KeyError:
                    continue
                if opt['gold_response']:
                    # If we're measuring gold response, use eval_label as prediction
                    prediction = w.acts[0]['eval_labels'][0]
                    response_act = {'text': prediction}
                word_statistics['context_list'].append(w.acts[0]['text'])
                word_statistics['pure_pred_list'].append(prediction)
            except IndexError:
                continue
            cnt += 1
            word_statistics = process_prediction(prediction, word_statistics)

            # Compute and record sentence-level attributes
            history = ConvAI2History(w.acts[0]['text'])
            histories.append(history)
            sent_attrs = update_sent_attr_stats(sent_attrs, history, prediction)

        # Periodically log some info
        if log_time.time() > log_every_n_secs:
            report = world.report()
            text, report = log_time.log(report['exs'], world.num_examples(), report)
            print(text)

        if opt['num_examples'] > 0 and cnt >= opt['num_examples']:
            break
    if world.epoch_done():
        print("EPOCH DONE")
    print("Time to process %i examples: %f seconds" % (cnt, time.time() - t0))

    # Compute percent unique
    # Note this is w.r.t. normalized pred_list not original pure_pred_list
    unique_list = []
    cntr = Counter(word_statistics['pred_list'])
    for k, v in cntr.items():
        if v == 1:
            unique_list.append(k)
    unique_percent = len(unique_list) / len(word_statistics['pred_list']) * 100

    # Print a final report
    report = world.report()
    if opt['gold_response']:
        report['ppl'] = 0.0  # For gold responses, overwrite the perplexity
    print(report)

    # Put all information in data dict
    data['unique_percent'] = unique_percent  # percent of all responses that are unique
    data['word_statistics'] = word_statistics  # word stats, as in orig eval_wordstat
    data['report'] = report  # the final report
    data['histories'] = [
        (hist.persona_lines, hist.partner_utts, hist.own_utts) for hist in histories
    ]  # history for each example
    data['sent_attrs'] = sent_attrs  # all sentence attribute values for responses

    # Write data to outfile
    print("Writing to %s..." % outfile)
    with open(outfile, 'w') as f:
        json.dump(data, f)
Example #15
def build_data(opt):
    if not opt.get('model', False):
        opt['model'] = 'repeat_label'
    agent = create_agent(opt)
    #If build teacher not specified, we are simply looking for the file
    if not opt.get('pytorch_teacher_task', None):
        df = opt.get('pytorch_datafile')
        # check if the user set a datafile
        if not df:
            raise Exception(
                'Tried to find data but `--pytorch-datafile` is not set')
        # check if the user provided the already built file
        if 'pytorch' not in df:
            df += '.pytorch' + (agent.getID() if opt.get(
                'pytorch_preprocess', True) else '')
        if not os.path.isfile(df):
            raise Exception('Tried to find data but it is not built, please '
                            'specify `--pytorch-teacher-task`')
        else:
            return df

    ordered_opt = copy.deepcopy(opt)
    # we use streaming to build the data
    dt = opt['datatype'].split(':')[0]
    ordered_opt['datatype'] = dt + ':ordered:stream'
    ordered_opt['numthreads'] = 1
    ordered_opt['batchsize'] = 1
    ordered_opt['task'] = ordered_opt['pytorch_teacher_task']
    ordered_opt['no_cuda'] = True
    world_data = create_task(ordered_opt, agent)
    teacher = world_data.agents[0]
    agent = world_data.agents[1]

    datafile = None
    if opt.get('pytorch_datafile'):
        datafile = opt.get('pytorch_datafile')
    elif hasattr(teacher, 'datafile') and teacher.datafile:
        datafile = teacher.datafile
    else:
        dpath = os.path.join(opt.get('datapath', '~'), ordered_opt['task'], dt)
        os.makedirs(dpath, exist_ok=True)
        datafile = os.path.join(dpath, 'pytorch_data')
    if not datafile:
        raise Exception(
            'Tried to build data but either `pytorch-teacher-task` does not '
            'have a datafile or `--pytorch-datafile` is not set')

    if isinstance(datafile,
                  collections.Sequence) and not type(datafile) == str:
        datafile = datafile[0] + "".join(
            ["_".join(d.split("/")) for d in datafile[1:]])
    pytorch_datafile = datafile + ".pytorch"
    preprocess = opt.get('pytorch_preprocess', True)
    if preprocess:
        pytorch_datafile += agent.getID()
    if os.path.isfile(pytorch_datafile):
        # Data already built
        print("[ pytorch data already built. ]")
        return pytorch_datafile
    print('----------\n[ setting up pytorch data, saving to {}. ]\n----------'.
          format(pytorch_datafile))

    num_eps = 0
    num_exs = 0
    current = []
    episode_done = False
    include_labels = opt.get('include_labels', True)
    context_length = opt.get('context_length', -1)
    context = deque(maxlen=context_length if context_length > 0 else None)
    logger = ProgressLogger(should_humanize=False, throttle=0.1)
    total_exs = world_data.num_examples()
    # pass examples to dictionary
    with open(pytorch_datafile, 'w') as pytorch_data:
        while num_exs < total_exs:
            while not episode_done:
                action = teacher.act()
                current.append(action)
                episode_done = action.get('episode_done', False)

            #build separate episodes
            for ex in current:
                context.append(ex.get('text', ''))
                if len(context) > 1:
                    ex['text'] = '\n'.join(context)
                ex['episode_done'] = True
                labels = ex.get('labels', ex.get('eval_labels', None))
                if labels is not None and include_labels:
                    context.append(random.choice(labels))
                #generate observation from new example
                if preprocess:
                    ex = agent.observe(ex)
                    ex.pop('label_candidates', '')
                    ex['preprocessed'] = True
                num_eps += 1
                num_exs += 1
                logger.log(num_exs, total_exs)
                pytorch_data.write(json.dumps(make_serializable(ex)) + "\n")
            #reset
            episode_done = False
            current.clear()
            context.clear()

    with open(pytorch_datafile + '.length', 'w') as pytorch_data_len:
        pytorch_data_len.write(
            json.dumps({
                'num_eps': num_eps,
                'num_exs': num_exs
            }))

    print('[ pytorch data built. ]')
    return pytorch_datafile
Example #16
            'dict_unktoken': '__UNK__', 'dict_tokenizer': 'split', 'dict_lower': False, 'hiddensize': 1024, \
            'embeddingsize': 300, 'numlayers': 2, 'learningrate': 0.5, 'dropout': 0.1, 'bidirectional': False, \
            'attention': 'general', 'no_cuda': False, 'gpu': -1, 'rank_candidates': False, 'truncate': -1, 'encoder': 'lstm', \
            'decoder': 'same', 'optimizer': 'adam', 'personachat_useprevdialog': True, 'personachat_printattn': False, \
            'personachat_attnsentlevel': True, 'personachat_sharelt': False, 'personachat_reweight': 'use', \
            'personachat_guidesoftmax': False, 'personachat_newsetting': '', 'personachat_interact': False, \
            'personachat_pdmn': False, 'personachat_tfidfperp': False, 'personachat_learnreweight': True, \
            'personachat_embshareonly_pm_dec': False, 'personachat_s2sinit': False, 'interactive_mode': True, \
            'use_persona': 'self', 'parlai_home': ros_integration_home + '/../../../', 'override': {}, \
            'starttime': 'Jun15_16-58'}
    opt['model_type'] = 'profilememory'  # for builder

    opt['task'] = 'parlai.agents.local_human.local_human:LocalHumanAgent'

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)
    #### model ready to go ####

    logging.info("Service /roboy/cognition/generative_nlp/answer is ready")

    # start socket server
    server = p.procbridge.ProcBridgeServer(host, port, request_handler)
    server.start()
    print('listening...')

    try:
        for line in sys.stdin:
            if line.strip() == 'exit':
                break
    except KeyboardInterrupt:
Example #17
    def test_token_level_loss_logging(self):
        """
        Test functionality of token level probability + ranking logging.

        Regression for all inference types: 'beam', 'greedy', 'topk', 'nucleus',
        'delayedbeam'
        """
        inference_types = ['beam', 'greedy', 'topk', 'nucleus', 'delayedbeam']
        gold_data = {
            'beam': {
                'text_token_info': [
                    ('__start__', 0.0, 1.0),
                    ('5', -2.5510462364763953e-05, 0.0),
                    ('__end__', -1.1920922133867862e-06, 0.0),
                ],
                'extra_args': ['--beam-size', '3'],
            },
            'greedy': {
                'text_token_info': [
                    ('__start__', 0.0, 1.0),
                    ('5', -2.5510462364763953e-05, 0.0),
                    ('__end__', -1.1920922133867862e-06, 0.0),
                ],
                'extra_args': [],
            },
            # sampling based token selection will produce non-deterministic output, so we can't do data regression
            'topk': {
                'extra_args': ['--topk', '2']
            },
            'topk_multiple_beams': {
                'extra_args': ['--topk', '2', '--beam-size', '5']
            },
            # sampling based token selection will produce non-deterministic output, so we can't do data regression
            'nucleus': {
                'extra_args': ['--topp', '0.3']
            },
            'nucleus_multiple_beams': {
                'extra_args': ['--topp', '0.3', '--beam-size', '5']
            },
            # sampling based token selection will produce non-deterministic output, so we can't do data regression
            'delayedbeam': {
                'extra_args': ['--topk', '2', '--beam-delay', '2']
            },
        }

        for inference_type in inference_types:
            args = [
                '--model-file',
                'zoo:unittest/transformer_generator2/model',
                '--inference',
                inference_type,
                '--truncate',
                '1024',
                '-v',
            ] + gold_data[inference_type]['extra_args']

            pp = ParlaiParser(True, True)
            agent = create_agent(pp.parse_args(args), True)
            obs = {'text': '5', 'episode_done': False}
            agent.observe(obs)
            act = agent.act()

            if 'text_token_info' in gold_data[inference_type]:
                for i, tok_data in enumerate(act['text_token_info']):
                    assert (
                        gold_data[inference_type]['text_token_info'][i][0] ==
                        tok_data[0]
                    ), f"failed token prediction for inference type {inference_type} at token {gold_data[inference_type]['text_token_info'][i][0]}"
                    assert math.isclose(
                        gold_data[inference_type]['text_token_info'][i][1],
                        tok_data[1]
                    ), f"failed token probability prediction for inference type {inference_type} at token {gold_data[inference_type]['text_token_info'][i][0]}"
                    assert math.isclose(
                        gold_data[inference_type]['text_token_info'][i][2],
                        tok_data[2]
                    ), f"failed token rank prediction for inference type {inference_type} at token {gold_data[inference_type]['text_token_info'][i][0]}"
Example #18
def detect(opt):
    """
    Checks a task for offensive language.
    """
    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)
    agent.opt.log()
    if opt['safety'] == 'string_matcher' or opt['safety'] == 'all':
        offensive_string_matcher = OffensiveStringMatcher()
    if opt['safety'] == 'classifier' or opt['safety'] == 'all':
        offensive_classifier = OffensiveLanguageClassifier()

    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    stats = {
        'bad_words': [],
        'bad_words_cnt': 0,
        'string_offensive': 0,
        'classifier_offensive': 0,
        'total_offensive': 0,
        'total': 0,
    }

    def report(world, stats):
        report = world.report()
        log = {
            'word_offenses':
            stats['bad_words_cnt'],
            'classifier_offenses%':
            100 * (stats['classifier_offensive'] / stats['total']),
            'string_offenses%':
            100 * (stats['string_offensive'] / stats['total']),
            'total_offenses%':
            100 * (stats['total_offensive'] / stats['total']),
        }
        text, log = log_time.log(report['exs'], world.num_examples(), log)
        logging.info(text)

    def classify(text, stats):
        offensive = False
        stats['total'] += 1
        if opt['safety'] == 'string_matcher' or opt['safety'] == 'all':
            bad_words = offensive_string_matcher.contains_offensive_language(
                text)
            if bad_words:
                stats['string_offensive'] += 1
                offensive = True
                stats['bad_words'].append(bad_words)
        if opt['safety'] == 'classifier' or opt['safety'] == 'all':
            if text in offensive_classifier:
                stats['classifier_offensive'] += 1
                offensive = True
        if offensive:
            stats['total_offensive'] += 1

    while not world.epoch_done():
        world.parley()
        stats['bad_words'] = []
        for a in world.acts:
            text = a.get('text', '')
            classify(text, stats)
            labels = a.get('labels', a.get('eval_labels', ''))
            for l in labels:
                classify(l, stats)
        if len(stats['bad_words']) > 0 and opt['display_examples']:
            logging.info(world.display())
            logging.info("Offensive words detected: {}".format(', '.join(
                stats['bad_words'])))
        stats['bad_words_cnt'] += len(stats['bad_words'])
        if log_time.time() > log_every_n_secs:
            report(world, stats)

    if world.epoch_done():
        logging.info("epoch done")
    report(world, stats)
    return world.report()
Example #19
from flask import Flask
from parlai.core.agents import create_agent
from parlai.core.worlds import create_task
import json
import logging
import traceback
import time
logging.basicConfig(level=logging.INFO)

# Set up Flask app
app = Flask(__name__)

# ParlAI bot
BOT = {}
with open("hred_model_opt.json") as f:
    BOT["opt"] = json.load(f)
BOT["opt"]["task"] = "parlai.agents.local_human.local_human:LocalHumanAgent"
agent = create_agent(BOT.get("opt"), requireModelExists=True)
BOT["agent"] = agent
BOT["world"] = create_task(BOT.get("opt"), BOT.get("agent"))
logging.debug(f"Bot parlai settings: {BOT['opt']}")


def create_response(bot_reply, end_of_session=False):
    """Base reply for Alexa"""
    response = {
        "version": "1.0",
        "response": {
            "outputSpeech": {
                "type": "PlainText",
                "text": bot_reply,
            },
            "reprompt": {
Example #20
def build_data(opt):
    agent = create_agent(opt)
    #If build teacher not specified, we are simply looking for the file
    if not opt.get('pytorch_buildteacher', None):
        df = opt.get('datafile')
        # check if the user set a datafile
        if not df:
            raise Exception('Tried to find data but `--datafile` is not set')
        # check if the user provided the already built file
        if 'pytorch' not in df:
            df += '.pytorch' + (agent.getID() if opt.get(
                'pytorch_preprocess', True) else '')
        if not os.path.isfile(df):
            raise Exception('Tried to find data but it is not built, please '
                            'specify `--pytorch_buildteacher`')
        else:
            return df

    ordered_opt = copy.deepcopy(opt)
    # we use streaming to build the data
    dt = opt['datatype'].split(':')[0]
    ordered_opt['datatype'] = dt + ':ordered:stream'
    ordered_opt['numthreads'] = 1
    ordered_opt['batchsize'] = 1
    ordered_opt['task'] = ordered_opt['pytorch_buildteacher']
    world_data = create_task(ordered_opt, agent)
    teacher = world_data.agents[0]

    datafile = teacher.datafile if hasattr(teacher,
                                           'datafile') else opt.get('datafile')
    if not datafile:
        raise Exception(
            'Tried to build data but either `pytorch_buildteacher` does not '
            'have a datafile or `--datafile` is not set')

    pytorch_datafile = datafile + ".pytorch"
    if opt.get('preprocess', True):
        pytorch_datafile += agent.getID()
    if os.path.isfile(pytorch_datafile):
        # Data already built
        print("[ pytorch data already built. ]")
        return pytorch_datafile
    print('----------\n[ setting up pytorch data. ]\n----------')

    num_eps = 0
    num_exs = 0
    current = []
    episode_done = False
    include_labels = opt.get('include_labels', True)
    context_length = opt.get('context_length', -1)
    context = deque(maxlen=context_length if context_length > 0 else None)
    preprocess = opt.get('pytorch_preprocess', True)
    # pass examples to dictionary
    with open(pytorch_datafile, 'w') as pytorch_data:
        while not world_data.epoch_done():
            while not episode_done:
                action = teacher.act()
                current.append(action)
                episode_done = action.get('episode_done', False)

            #build separate episodes
            for ex in current:
                context.append(ex.get('text', ''))
                if len(context) > 1:
                    ex['text'] = '\n'.join(context)
                ex['episode_done'] = True
                labels = ex.get('labels', ex.get('eval_labels', None))
                if labels is not None and include_labels:
                    context.append(random.choice(labels))
                #generate observation from new example
                if preprocess:
                    ex = agent.observe(ex)
                    ex.pop('label_candidates', '')
                    ex['preprocessed'] = True
                num_eps += 1
                num_exs += 1
                pytorch_data.write(json.dumps(make_serializable(ex)) + "\n")
            #reset
            episode_done = False
            current.clear()
            context.clear()

    with open(pytorch_datafile + '.length', 'w') as pytorch_data_len:
        pytorch_data_len.write(
            json.dumps({
                'num_eps': num_eps,
                'num_exs': num_exs
            }))

    print('[ pytorch data built. ]')
    return pytorch_datafile
Example #21
    def train_steps(self):
        """
        Core training loop.

        Yields a metrics dict with each log.
        """
        logging.info('training...')
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                try:
                    world.parley()
                except StopTrainException as e:
                    logging.info(f"Stopping from {e}")
                    break

                self.parleys += 1
                self._train_steps = self.parleys // self.update_freq
                self._last_log_steps += 1 / self.update_freq

                # the following additionally updates self._total_epochs
                train_time, log_time, validate_time, save_time = self._get_time(
                    world)
                # get the total training examples done, compute epochs
                exs_per_epoch = world.num_examples()
                self._total_exs = int(
                    np.round(self._total_epochs * exs_per_epoch))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    yield self.log()
                    logging.info(
                        f'num_epochs completed:{self.max_num_epochs} time elapsed:{train_time}s'
                    )
                    break
                if train_time > self.max_train_time:
                    logging.info(f'max_train_time elapsed:{train_time}s')
                    break
                if self._train_steps >= self.max_train_steps:
                    logging.info(
                        f'max_train_steps elapsed:{self._train_steps} '
                        f'time elapsed:{train_time}s')
                    break
                if (log_time > self.log_every_n_secs
                        or self._last_log_steps >= self.log_every_n_steps):
                    yield self.log()
                if (validate_time > self.val_every_n_secs
                        or self._total_epochs - self.last_valid_epoch >=
                        self.val_every_n_epochs
                        or self._train_steps - self._last_valid_steps >=
                        self.val_every_n_steps):
                    try:
                        # log before we validate
                        if self._last_log_steps:
                            yield self.log()
                        world.reset_metrics()
                        stop_training = self.validate()
                    except StopTrainException:
                        break
                    # reset the log time because we logged right before validating
                    self.log_time.reset()
                    self.last_valid_epoch = self._total_epochs
                    self._last_valid_steps = self._train_steps
                    if stop_training:
                        break
                    # make sure metrics are clean before we log
                    world.reset_metrics()
                if save_time > self.save_every_n_secs and opt.get(
                        'model_file'):
                    logging.info(
                        f"saving model checkpoint: {opt['model_file']}.checkpoint"
                    )
                    if opt['tensorboard_log'] and is_primary_worker():
                        self.tb_logger.flush()
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not sync_object(self.saved):
            # save agent
            self.save_model()

        # there's a rare edge case where we never saved the model, and we try
        # to reload it. This sync_object ensures all workers wait for the primary
        # worker to finish flushing before loading from disk.
        sync_object(None)
        if opt.get('model_file'):
            # clean up all our memory, just to make sure we don't OOM on GPU when
            # reloading the world
            del world
            del self.world
            del self.agent
            del self.valid_worlds
            # reload best validation model
            self.agent = create_agent(opt)
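
A hedged sketch of how a generator-style train_steps() might be driven; the TrainLoop name and constructor follow Example #11, but this driver loop is an assumption for illustration, not part of the snippet above:

# Hypothetical driver: iterate the generator and handle each yielded metrics dict.
loop = TrainLoop(opt)
for metrics in loop.train_steps():
    # each yield corresponds to one self.log() call in the training loop
    print(metrics)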
Example #22
    def test_load_dpr(self):
        opt = ParlaiParser(True, True).parse_args([])
        # First, we'll load up a DPR model from the zoo dpr file.
        default_query_encoder = DprQueryEncoder(opt,
                                                dpr_model='bert',
                                                pretrained_path=DPR_ZOO_MODEL)
        rag_sequence_query_encoder = DprQueryEncoder(
            opt,
            dpr_model='bert_from_parlai_rag',
            pretrained_path=RAG_SEQUENCE_ZOO_MODEL,
        )
        assert not torch.allclose(
            default_query_encoder.embeddings.weight.float().cpu(),
            rag_sequence_query_encoder.embeddings.weight.float().cpu(),
        )
        # 1. Create a zoo RAG Agent, which involves a trained DPR model
        rag = create_agent(
            Opt({
                'model_file': modelzoo_path(opt['datapath'], RAG_TOKEN_ZOO_MODEL),
                'override': {
                    'retriever_debug_index': 'compressed',
                    'fp16': False,
                },
            })
        )
        # The default rag token model should have different query encoders
        # from both the RAG_SEQUENCE_ZOO_MODEL, and the default DPR_ZOO_MODEL
        assert not torch.allclose(
            rag_sequence_query_encoder.embeddings.weight.float().cpu(),
            rag.model.retriever.query_encoder.embeddings.weight.float().cpu(),
        )
        assert not torch.allclose(
            default_query_encoder.embeddings.weight.float().cpu(),
            rag.model.retriever.query_encoder.embeddings.weight.float().cpu(),
        )

        # 2. create a RAG Agent with the rag_sequence_zoo_model DPR model
        rag = create_agent(
            Opt({
                'model_file': modelzoo_path(opt['datapath'], RAG_TOKEN_ZOO_MODEL),
                'override': {
                    'retriever_debug_index': 'compressed',
                    'dpr_model_file': modelzoo_path(
                        opt['datapath'], RAG_SEQUENCE_ZOO_MODEL
                    ),
                    'query_model': 'bert_from_parlai_rag',
                    'fp16': False,
                },
            })
        )
        # If we override the DPR Model file, we should now have the same
        # weights as the query encoder from above.
        assert torch.allclose(
            rag_sequence_query_encoder.embeddings.weight.float().cpu(),
            rag.model.retriever.query_encoder.embeddings.weight.float().cpu(),
        )

        # 3. Create a RAG Agent with the default DPR zoo model
        rag = create_agent(
            Opt({
                'model_file': modelzoo_path(opt['datapath'], RAG_TOKEN_ZOO_MODEL),
                'override': {
                    'retriever_debug_index': 'compressed',
                    'dpr_model_file': modelzoo_path(opt['datapath'], DPR_ZOO_MODEL),
                    'fp16': False,
                },
            })
        )

        # This model was originally trained with the DPR_ZOO_MODEL; now that we have
        # explicitly specified that model file, the query encoder weights should match it.
        assert torch.allclose(
            default_query_encoder.embeddings.weight.float().cpu(),
            rag.model.retriever.query_encoder.embeddings.weight.float().cpu(),
        )
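
Related usage (not from the test above): the same pattern of loading a pretrained model while overriding a few options can also be expressed with create_agent_from_model_file, which applies the overrides for you. A minimal sketch; the zoo path and override keys are placeholders:

# Hedged sketch: load a zoo model and override selected options.
# The model path and override keys below are illustrative placeholders.
from parlai.core.agents import create_agent_from_model_file

agent = create_agent_from_model_file(
    'zoo:blender/blender_90M/model',
    opt_overrides={'skip_generation': False, 'fp16': False},
)
print(type(agent).__name__)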
Example #23
0
def main():
    # Get command line arguments
    parser = ParlaiParser(True, True)
    train = parser.add_argument_group('Training Loop Arguments')
    train.add_argument('-et', '--evaltask',
                       help=('task to use for valid/test (defaults to the '
                             'one used for training if not set)'))
    train.add_argument('-d', '--display-examples',
                       type='bool', default=False)
    train.add_argument('-e', '--num-epochs', type=float, default=-1)
    train.add_argument('-ttim', '--max-train-time',
                       type=float, default=-1)
    train.add_argument('-ltim', '--log-every-n-secs',
                       type=float, default=2)
    train.add_argument('-vtim', '--validation-every-n-secs',
                       type=float, default=-1)
    train.add_argument('-vme', '--validation-max-exs',
                       type=int, default=-1,
                       help='max examples to use during validation (default '
                            '-1 uses all)')
    train.add_argument('-vp', '--validation-patience',
                       type=int, default=5,
                       help=('number of iterations of validation where result'
                             ' does not improve before we stop training'))
    train.add_argument('-vmt', '--validation-metric', default='accuracy',
                       help='key into report table for selecting best '
                            'validation')
    train.add_argument('-dbf', '--dict-build-first',
                       type='bool', default=True,
                       help='build dictionary first before training agent')
    opt = parser.parse_args()
    # Possibly build a dictionary (not all models do this).
    if opt['dict_build_first'] and 'dict_file' in opt:
        if opt['dict_file'] is None and opt.get('model_file'):
            opt['dict_file'] = opt['model_file'] + '.dict'
        print("[ building dictionary first... ]")
        build_dict.build_dict(opt)
    # Create model and assign it to the specified task
    agent = create_agent(opt)
    world = create_task(opt, agent)

    train_time = Timer()
    validate_time = Timer()
    log_time = Timer()
    print('[ training... ]')
    parleys = 0
    total_exs = 0
    max_exs = opt['num_epochs'] * len(world)
    max_parleys = math.ceil(max_exs / opt['batchsize'])
    best_valid = 0
    impatience = 0
    saved = False
    valid_world = None
    while True:
        world.parley()
        parleys += 1

        if opt['num_epochs'] > 0 and parleys >= max_parleys:
            print('[ num_epochs completed: {} ]'.format(opt['num_epochs']))
            break
        if opt['max_train_time'] > 0 and train_time.time() > opt['max_train_time']:
            print('[ max_train_time elapsed: {} ]'.format(train_time.time()))
            break
        if opt['log_every_n_secs'] > 0 and log_time.time() > opt['log_every_n_secs']:
            if opt['display_examples']:
                print(world.display() + '\n~~')

            logs = []
            # time elapsed
            logs.append('time:{}s'.format(math.floor(train_time.time())))
            logs.append('parleys:{}'.format(parleys))

            # get report and update total examples seen so far
            if hasattr(agent, 'report'):
                train_report = agent.report()
                agent.reset_metrics()
            else:
                train_report = world.report()
                world.reset_metrics()

            if hasattr(train_report, 'get') and train_report.get('total'):
                total_exs += train_report['total']
                logs.append('total_exs:{}'.format(total_exs))

            # check if we should log amount of time remaining
            time_left = None
            if opt['num_epochs'] > 0:
                exs_per_sec = train_time.time() / total_exs
                time_left = (max_exs - total_exs) * exs_per_sec
            if opt['max_train_time'] > 0:
                other_time_left = opt['max_train_time'] - train_time.time()
                if time_left is not None:
                    time_left = min(time_left, other_time_left)
                else:
                    time_left = other_time_left
            if time_left is not None:
                logs.append('time_left:{}s'.format(math.floor(time_left)))

            # join log string and add full metrics report to end of log
            log = '[ {} ] {}'.format(' '.join(logs), train_report)

            print(log)
            log_time.reset()

        if (opt['validation_every_n_secs'] > 0 and
                validate_time.time() > opt['validation_every_n_secs']):
            valid_report, valid_world = run_eval(
                agent, opt, 'valid', opt['validation_max_exs'],
                valid_world=valid_world)
            if valid_report[opt['validation_metric']] > best_valid:
                best_valid = valid_report[opt['validation_metric']]
                impatience = 0
                print('[ new best {}: {} ]'.format(
                    opt['validation_metric'], best_valid))
                world.save_agents()
                saved = True
                if opt['validation_metric'] == 'accuracy' and best_valid == 1:
                    print('[ task solved! stopping. ]')
                    break
            else:
                impatience += 1
                print('[ did not beat best {}: {} impatience: {} ]'.format(
                        opt['validation_metric'], round(best_valid, 4),
                        impatience))
            validate_time.reset()
            if opt['validation_patience'] > 0 and impatience >= opt['validation_patience']:
                print('[ ran out of patience! stopping training. ]')
                break
    world.shutdown()
    if not saved:
        world.save_agents()
    else:
        # reload best validation model
        agent = create_agent(opt)

    run_eval(agent, opt, 'valid', write_log=True)
    run_eval(agent, opt, 'test', write_log=True)
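
The run_eval helper called above is defined elsewhere in the original script. Below is a hypothetical minimal version, consistent with how it is called here (agent, opt, datatype, optional example cap, reusable validation world, optional log writing); the body is an assumption, not the original implementation:

# Hypothetical sketch of a run_eval helper matching the call sites above.
# The body is an assumption about its shape, not the original implementation.
import copy
from parlai.core.worlds import create_task

def run_eval(agent, opt, datatype, max_exs=-1, valid_world=None, write_log=False):
    eval_opt = copy.deepcopy(opt)
    eval_opt['datatype'] = datatype
    if valid_world is None:
        valid_world = create_task(eval_opt, agent)
    valid_world.reset()
    cnt = 0
    while not valid_world.epoch_done():
        valid_world.parley()
        cnt += 1
        if max_exs > 0 and cnt >= max_exs:
            break
    report = valid_world.report()
    print('[ eval {}: {} ]'.format(datatype, report))
    if write_log and opt.get('model_file'):
        with open(opt['model_file'] + '.' + datatype, 'a') as f:
            f.write(str(report) + '\n')
    return report, valid_world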
Example #24
0
    def act(self):
        obs = self.observation
        reply = {}
        reply['id'] = self.getID()
        if 'labels' in obs:
            return self.train_act()
        if 'text' in obs:
            self.rebuild()  # no-op if nothing has been queued to store
            all_doc_ids = []
            all_doc_scores = []
            for r in self.rankers:
                doc_ids, doc_scores = r.closest_docs(
                    obs['text'], self.opt.get('retriever_num_retrieved', 5))
                all_doc_ids.append(doc_ids)
                all_doc_scores.append(doc_scores)
            '''doc_ids, doc_scores = self.ranker.closest_docs(
                obs['text'],
                self.opt.get('retriever_num_retrieved', 5)
            )
            doc_ids2, doc_scores2 = self.ranker2.closest_docs(
                obs['text'],
                self.opt.get('retriever_num_retrieved', 5)
            )'''
            if False and obs.get(
                    'label_candidates'):  # TODO: Alex (doesn't work)
                # these are better selection than stored facts
                # rank these options instead
                cands = obs['label_candidates']
                cands_id = id(cands)
                if cands_id not in self.cands_hash:
                    # cache candidate set
                    # will not update if cand set changes contents
                    c_list = list(cands)
                    self.cands_hash[cands_id] = (get_tfidf_matrix(
                        live_count_matrix(self.tfidf_args, c_list)), c_list)
                c_ids, c_scores = self.ranker.closest_docs(
                    obs['text'],
                    self.opt.get('retriever_num_retrieved', 5),
                    matrix=self.cands_hash[cands_id][0])
                reply['text_candidates'] = [
                    self.cands_hash[cands_id][1][cid] for cid in c_ids
                ]
                reply['candidate_scores'] = c_scores
                if len(reply['text_candidates']) > 0:
                    reply['text'] = reply['text_candidates'][0]
                else:
                    reply['text'] = ''
            elif len(doc_ids) > 0:
                # return stored fact
                # total = sum(doc_scores)
                # doc_probs = [d / total for d in doc_scores]

                # returned
                all_picks = []
                all_pick = []
                for doc_ids in all_doc_ids:
                    picks = [self.doc2txt(int(did)) for did in doc_ids]
                    if len(doc_ids) > 0:
                        pick = self.doc2txt(int(doc_ids[0]))
                    else:
                        pick = ''
                    if self.opt.get('remove_title', False):
                        picks = ['\n'.join(p.split('\n')[1:]) for p in picks]
                        pick = '\n'.join(pick.split('\n')[1:])
                    all_picks.append(picks)
                    all_pick.append(pick)
                '''picks = [self.doc2txt(int(did)) for did in doc_ids]
                pick = self.doc2txt(int(doc_ids[0]))  # select best response
                
                picks2 = [self.doc2txt(int(did)) for did in doc_ids2]
                if len(doc_ids2) > 0:
                    pick2 = self.doc2txt(int(doc_ids2[0]))
                else:
                    pick2 = ''
                if self.opt.get('remove_title', False):
                    picks = ['\n'.join(p.split('\n')[1:]) for p in picks]
                    pick = '\n'.join(pick.split('\n')[1:])
                    picks2 = ['\n'.join(p.split('\n')[1:]) for p in picks2]
                    pick2 = '\n'.join(pick2.split('\n')[1:])'''

                reply['text_candidates'] = [
                    p for sublist in all_picks for p in sublist
                ]
                '''reply['text_candidates'] = picks+picks2'''
                #print(doc_scores)
                #print(doc_scores2)
                #print(type(doc_scores))
                reply['candidate_scores'] = [
                    d for sublist in all_doc_scores for d in sublist
                ]
                '''reply['candidate_scores'] = np.concatenate((doc_scores, doc_scores2), axis=None)'''
                #reply['candidate_scores'] = doc_scores+doc_scores2

                # could pick single choice based on probability scores?
                # pick = int(choice(doc_ids, p=doc_probs))
                reply['text'] = '\n'.join(all_pick)

                context = ' '.join(reply['text_candidates'])
                #print('1:', picks[0])
                #print('2:', picks2[0])
                #print('context:', context)
                #print('question:', obs['text'])
                qa_opts = {}
                qa_opts['task'] = (
                    'parlai.agents.local_human_silent.local_human_silent'
                    ':LocalHumanAgent'
                )
                qa_opts['model'] = 'drqa'
                qa_opts['model_file'] = 'models:drqa/squad/model'
                qa_opts['query'] = context + '\n' + obs['text']
                qa_agent = create_agent(qa_opts, requireModelExists=True)
                qa_world = create_task(qa_opts, qa_agent)
                qa_world.parley()

            else:
                # no cands and nothing found, return generic response
                reply['text'] = choice([
                    'Can you say something more interesting?',
                    'Why are you being so short with me?',
                    'What are you really thinking?',
                    'Can you expand on that?',
                ])

        return reply
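
For context (not part of the example above): act() is one half of ParlAI's observe/act protocol. A minimal sketch of how a caller drives any such agent; RepeatLabelAgent is used only as a stand-in so the sketch does not need the retriever's setup:

# Minimal sketch of the observe/act cycle that the act() method above participates in.
# RepeatLabelAgent is only a stand-in; the retriever agent is driven the same way.
from parlai.core.params import ParlaiParser
from parlai.agents.repeat_label.repeat_label import RepeatLabelAgent

opt = ParlaiParser(True, True).parse_args([])
agent = RepeatLabelAgent(opt)
agent.observe({'text': 'where is the milk?',
               'labels': ['kitchen'],
               'episode_done': True})
reply = agent.act()   # the agent builds and returns its reply message
print(reply.get('text', ''))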
Example #25
0
    def train(self):
        """
        Perform a training run.

        :return: tuple of reports (validation_report, test_report)
        """
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                try:
                    world.parley()
                except StopTrainException:
                    if is_distributed():
                        raise RuntimeError(
                            "StopTrainException not supported for " "distributed mode"
                        )
                    break

                self.parleys += 1

                # get the total training examples done, compute epochs
                self._total_epochs = self._preempted_epochs + sum(
                    all_gather_list(world.get_total_epochs())
                )
                exs_per_epoch = world.num_examples()
                self._total_exs = int(np.round(self._total_epochs * exs_per_epoch))
                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object(
                    (
                        self.train_time.time(),
                        self.log_time.time(),
                        self.validate_time.time(),
                    )
                )

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    print(
                        '[ num_epochs completed:{} time elapsed:{}s ]'.format(
                            self.max_num_epochs, train_time
                        )
                    )
                    break
                if train_time > self.max_train_time:
                    print('[ max_train_time elapsed:{}s ]'.format(train_time))
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (
                    validate_time > self.val_every_n_secs
                    or self._total_epochs - self.last_valid_epoch
                    >= self.val_every_n_epochs
                ):
                    try:
                        # log before we validate
                        self.log()
                        world.reset_metrics()
                        stop_training = self.validate()
                    except StopTrainException:
                        if is_distributed():
                            raise RuntimeError(
                                "StopTrainException not supported for distributed mode"
                            )
                        break
                    # reset the log time because we logged right before validating
                    self.log_time.reset()
                    self.last_valid_epoch = self._total_epochs
                    if stop_training:
                        break
                    # make sure metrics are clean before we log
                    world.reset_metrics()
                if (
                    self.save_time.time() > self.save_every_n_secs
                    and opt.get('model_file')
                    and is_primary_worker()
                ):
                    print(
                        "[ saving model checkpoint: {}.checkpoint".format(
                            opt['model_file']
                        )
                    )
                    if opt['tensorboard_log'] and is_primary_worker():
                        self.tb_logger.flush()
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()
        elif opt.get('model_file'):
            # reload best validation model
            self.agent = create_agent(opt)

        valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
        v_report = self._run_eval(valid_worlds, opt, 'valid', max_exs, write_log=True)
        test_worlds = load_eval_worlds(self.agent, opt, 'test')
        t_report = self._run_eval(test_worlds, opt, 'test', max_exs, write_log=True)
        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        print_announcements(opt)

        return v_report, t_report
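
Usage note (not from the example): a TrainLoop like the one above is usually constructed from the options produced by the script's own setup_args parser and then run once. A minimal sketch with placeholder task, model, and path values:

# Hedged sketch: constructing and running a TrainLoop.
# The task/model/model-file values are illustrative placeholders.
from parlai.scripts.train_model import setup_args, TrainLoop

parser = setup_args()
opt = parser.parse_args([
    '--task', 'dailydialog',
    '--model', 'transformer/generator',
    '--model-file', '/tmp/example_model',
])
valid_report, test_report = TrainLoop(opt).train()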
Example #26
0
def main(parser):
    opt = parser.parse_args()
    # Possibly build a dictionary (not all models do this).
    if opt['dict_build_first'] and 'dict_file' in opt:
        if opt['dict_file'] is None and opt.get('model_file'):
            opt['dict_file'] = opt['model_file'] + '.dict'
        print("[ building dictionary first... ]")
        build_dict.build_dict(opt)
    # Create model and assign it to the specified task
    agent = create_agent(opt)
    world = create_task(opt, agent)

    train_time = Timer()
    validate_time = Timer()
    log_time = Timer()
    print('[ training... ]')
    parleys = 0
    total_exs = 0
    max_exs = opt['num_epochs'] * len(world)
    max_parleys = math.ceil(max_exs / opt['batchsize'])
    best_valid = 0
    impatience = 0
    saved = False
    valid_world = None
    with world:
        while True:
            world.parley()
            parleys += 1

            if opt['num_epochs'] > 0 and parleys >= max_parleys:
                print('[ num_epochs completed: {} ]'.format(opt['num_epochs']))
                break
            if opt['max_train_time'] > 0 and train_time.time() > opt['max_train_time']:
                print('[ max_train_time elapsed: {} ]'.format(train_time.time()))
                break
            if opt['log_every_n_secs'] > 0 and log_time.time() > opt['log_every_n_secs']:
                if opt['display_examples']:
                    print(world.display() + '\n~~')

                logs = []
                # time elapsed
                logs.append('time:{}s'.format(math.floor(train_time.time())))
                logs.append('parleys:{}'.format(parleys))

                # get report and update total examples seen so far
                if hasattr(agent, 'report'):
                    train_report = agent.report()
                    agent.reset_metrics()
                else:
                    train_report = world.report()
                    world.reset_metrics()

                if hasattr(train_report, 'get') and train_report.get('total'):
                    total_exs += train_report['total']
                    logs.append('total_exs:{}'.format(total_exs))

                # check if we should log amount of time remaining
                time_left = None
                if opt['num_epochs'] > 0:
                    exs_per_sec = train_time.time() / total_exs
                    time_left = (max_exs - total_exs) * exs_per_sec
                if opt['max_train_time'] > 0:
                    other_time_left = opt['max_train_time'] - train_time.time()
                    if time_left is not None:
                        time_left = min(time_left, other_time_left)
                    else:
                        time_left = other_time_left
                if time_left is not None:
                    logs.append('time_left:{}s'.format(math.floor(time_left)))

                # join log string and add full metrics report to end of log
                log = '[ {} ] {}'.format(' '.join(logs), train_report)

                print(log)
                log_time.reset()

            if (opt['validation_every_n_secs'] > 0 and
                    validate_time.time() > opt['validation_every_n_secs']):
                valid_report, valid_world = run_eval(
                    agent, opt, 'valid', opt['validation_max_exs'],
                    valid_world=valid_world)
                if valid_report[opt['validation_metric']] > best_valid:
                    best_valid = valid_report[opt['validation_metric']]
                    impatience = 0
                    print('[ new best {}: {} ]'.format(
                        opt['validation_metric'], best_valid))
                    world.save_agents()
                    saved = True
                    if opt['validation_metric'] == 'accuracy' and best_valid > 99.5:
                        print('[ task solved! stopping. ]')
                        break
                else:
                    impatience += 1
                    print('[ did not beat best {}: {} impatience: {} ]'.format(
                            opt['validation_metric'], round(best_valid, 4),
                            impatience))
                validate_time.reset()
                if opt['validation_patience'] > 0 and impatience >= opt['validation_patience']:
                    print('[ ran out of patience! stopping training. ]')
                    break
    if not saved:
        # save agent
        world.save_agents()
    elif opt.get('model_file'):
        # reload best validation model
        agent = create_agent(opt)

    run_eval(agent, opt, 'valid', write_log=True)
    run_eval(agent, opt, 'test', write_log=True)
Example #27
0
        def run_conversation(mturk_manager, opt, workers):
            """
            Runs the conversation
            :param mturk_manager: MTurk manager
            :param opt: command line arguments
            :param workers: list of workers
            :return: Nothing.
            """
            global game_id
            global worker_record

            conversation_start_time = time.time()

            # Copy workers into agents list
            agents = workers[:]
            # Get worker names
            names = get_worker_names(agents)
            print(names)

            # Create a local agent
            if not opt['two_mturk_agents']:
                if 'model' in opt:
                    local_agent = create_agent(opt)
                else:
                    local_agent = LocalHumanAgent(opt=None)

                local_agent.id = local_agent_1_id
                agents.append(local_agent)

            opt["batchindex"] = mturk_manager.started_conversations

            print("Loading game {}".format(game_id))

            print(list(worker_record.keys()))
            print(agents[0].worker_id)
            print(agents[1].worker_id)

            # If the workers never played before, start with the warm-up round
            if (agents[0].worker_id
                    not in worker_record) and (agents[1].worker_id
                                               not in worker_record):
                world = MTurkDMGDialogWarmupWorld(
                    opt=opt,
                    agents=agents,
                )

                print("--- Starting Warming-Up Round ---")
                while not world.episode_done():
                    if world.parley():
                        break

            world = MTurkDMGDialogWorld(opt=opt,
                                        agents=agents,
                                        game_id=game_id,
                                        names=names)

            get_pay = {agents[0].worker_id: False, agents[1].worker_id: False}

            print("--- Starting Game ---")
            while not world.episode_done():
                print("Parley!")
                world.parley()

            print("# # # DONE # # #")

            if world.disconnected:
                print("Game ended due to disconnect.")
                if world.round_nr > 1:
                    for agent in agents:
                        if not agent.disconnected:
                            print("CHECK: Agent {} did NOT disconnect".format(
                                agent.worker_id))
                            get_pay[agent.worker_id] = True
                        else:
                            print("CHECK: Agent {} DID disconnect".format(
                                agent.worker_id))

            else:
                # Only save records when game was complete
                print("Updating records")
                update_records(agents, game_id)
                save_records()

                if world.total_score > 24:
                    print("Total score was above 24, paying both workers.")
                    get_pay = {
                        agents[0].worker_id: True,
                        agents[1].worker_id: True
                    }
                else:
                    print("Score too low!")

            if world.end_time:
                conversation_end_time = world.end_time
            else:
                conversation_end_time = conversation_start_time
            world.shutdown()
            print("# # # Game ended # # #")

            duration = conversation_end_time - conversation_start_time
            duration_mins = duration / 60.0
            time_bonus = None

            if duration_mins > 1:
                if duration_mins >= 25:
                    time_bonus = 1.50
                else:
                    time_bonus = int(duration_mins - 10) * 0.10
                    time_bonus = round(time_bonus, 2)

            if time_bonus and time_bonus > 1.5:
                time_bonus = 1.5
            if time_bonus and time_bonus < 0:
                time_bonus = None
            pay_workers(agents, get_pay, time_bonus)
            print("Conversation closed.")
Example #28
0
def eval_wordstat(opt, print_parser=None):
    """Evaluates a model.

    Arguments:
    opt -- tells the evaluation function how to run
    print_parser -- if provided, prints the options that are set within the
        model after loading the model
    """
    random.seed(42)

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)

    if opt.get('external_dict'):
        print('[ Using external dictionary from: {} ]'.format(
            opt['external_dict']))
        dict_opt = copy.deepcopy(opt)
        dict_opt['dict_file'] = opt['external_dict']
        dictionary = DictionaryAgent(dict_opt)
    else:
        print('[ Using model bundled dictionary ]')
        dictionary = agent.dict

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    cnt = 0
    mean_wlength = []
    mean_clength = []
    freqs_cnt = Counter()
    word_cnt = 0
    bins = [int(i) for i in opt['freq_bins'].split(',')]
    pred_list = []

    while not world.epoch_done():
        cnt += 1
        world.parley()
        prediction = world.acts[-1]['text']
        pred_list.append(normalize_answer(prediction))
        freqs, _cnt, wlength, clength = get_word_stats(prediction,
                                                       dictionary,
                                                       bins=bins)
        word_cnt += _cnt

        mean_wlength.append(wlength)
        mean_clength.append(clength)

        freqs_cnt += Counter(freqs)

        if log_time.time() > log_every_n_secs or (
                opt['num_examples'] > 0
                and cnt >= opt['num_examples']) or world.epoch_done():
            report = world.report()
            text, report = log_time.log(report['exs'], world.num_examples(),
                                        report)
            print(text)
            stat_str = 'total_words: {}, '.format(word_cnt) + ', '.join([
                '<{}:{} ({:.{prec}f}%)'.format(
                    b,
                    freqs_cnt.get(b, 0),
                    (freqs_cnt.get(b, 0) / word_cnt) * 100,
                    prec=2) for b in bins
            ])
            print(
                "Word statistics: {}, avg_word_length: {:.{prec}f}, avg_char_length: {:.{prec}f}"
                .format(stat_str,
                        numpy.array(mean_wlength).mean(),
                        numpy.array(mean_clength).mean(),
                        prec=2))
        if opt['num_examples'] > 0 and cnt >= opt['num_examples']:
            break
    if world.epoch_done():
        print("EPOCH DONE")

    if opt['compute_unique'] is True:
        unique_list = []
        cntr = Counter(pred_list)
        for k, v in cntr.items():
            if v == 1:
                unique_list.append(k)
        print("Unique responses: {:.{prec}f}%".format(len(unique_list) /
                                                      len(pred_list) * 100,
                                                      prec=2))

    if opt['dump_predictions_path'] is not None:
        with open(opt['dump_predictions_path'], 'w') as f:
            f.writelines(['{}\n'.format(i) for i in pred_list])
        if opt['compute_unique'] is True:
            with open(opt['dump_predictions_path'] + '_unique', 'w') as f:
                f.writelines(['{}\n'.format(i) for i in unique_list])

    report = world.report()
    print(report)
    return report
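
Usage note (not part of the example): eval_wordstat reads several extra options ('num_examples', 'external_dict', 'freq_bins', 'compute_unique', 'dump_predictions_path'), so the caller's parser has to register them before building opt. A minimal sketch with assumed flag names and placeholder values:

# Hedged sketch of building opt for eval_wordstat above.
# Flag names and defaults are assumptions that mirror the keys read by the function;
# the model file path is a placeholder, and eval_wordstat is assumed to be in scope.
from parlai.core.params import ParlaiParser

parser = ParlaiParser(True, True)
parser.add_argument('-ne', '--num-examples', type=int, default=-1)
parser.add_argument('-ed', '--external-dict', type=str, default=None)
parser.add_argument('-fb', '--freq-bins', type=str, default='0,100,1000,10000')
parser.add_argument('-dup', '--dump-predictions-path', type=str, default=None)
parser.add_argument('-cun', '--compute-unique', type='bool', default=True)
opt = parser.parse_args(['--model-file', '/tmp/example_model'])
report = eval_wordstat(opt, print_parser=parser)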
Example #29
0
def main(parser):
    opt = parser.parse_args()
    # Possibly build a dictionary (not all models do this).
    if opt['dict_build_first'] and 'dict_file' in opt:
        if opt['dict_file'] is None and opt.get('model_file'):
            opt['dict_file'] = opt['model_file'] + '.dict'
        print("[ building dictionary first... ]")
        build_dict.build_dict(opt)
    # Create model and assign it to the specified task
    agent = create_agent(opt)
    world = create_task(opt, agent)

    train_time = Timer()
    validate_time = Timer()
    log_time = Timer()
    print('[ training... ]')
    parleys = 0
    total_exs = 0
    max_exs = opt['num_epochs'] * len(world)
    max_parleys = math.ceil(max_exs / opt['batchsize'])
    best_valid = 0
    impatience = 0
    saved = False
    valid_world = None
    with world:
        while True:
            world.parley()
            parleys += 1

            if opt['num_epochs'] > 0 and parleys >= max_parleys:
                print('[ num_epochs completed: {} ]'.format(opt['num_epochs']))
                break
            if opt['max_train_time'] > 0 and train_time.time() > opt['max_train_time']:
                print('[ max_train_time elapsed: {} ]'.format(train_time.time()))
                break
            if opt['log_every_n_secs'] > 0 and log_time.time() > opt['log_every_n_secs']:
                if opt['display_examples']:
                    print(world.display() + '\n~~')

                logs = []
                # time elapsed
                logs.append('time:{}s'.format(math.floor(train_time.time())))
                logs.append('parleys:{}'.format(parleys))

                # get report and update total examples seen so far
                if hasattr(agent, 'report'):
                    train_report = agent.report()
                    agent.reset_metrics()
                else:
                    train_report = world.report()
                    world.reset_metrics()

                if hasattr(train_report, 'get') and train_report.get('total'):
                    total_exs += train_report['total']
                    logs.append('total_exs:{}'.format(total_exs))

                # check if we should log amount of time remaining
                time_left = None
                if opt['num_epochs'] > 0:
                    exs_per_sec = train_time.time() / total_exs
                    time_left = (max_exs - total_exs) * exs_per_sec
                if opt['max_train_time'] > 0:
                    other_time_left = opt['max_train_time'] - train_time.time()
                    if time_left is not None:
                        time_left = min(time_left, other_time_left)
                    else:
                        time_left = other_time_left
                if time_left is not None:
                    logs.append('time_left:{}s'.format(math.floor(time_left)))

                # join log string and add full metrics report to end of log
                log = '[ {} ] {}'.format(' '.join(logs), train_report)

                print(log)
                log_time.reset()

            if (opt['validation_every_n_secs'] > 0
                    and validate_time.time() > opt['validation_every_n_secs']):
                valid_report, valid_world = run_eval(agent,
                                                     opt,
                                                     'valid',
                                                     opt['validation_max_exs'],
                                                     valid_world=valid_world)
                if valid_report[opt['validation_metric']] > best_valid:
                    best_valid = valid_report[opt['validation_metric']]
                    impatience = 0
                    print('[ new best {}: {} ]'.format(
                        opt['validation_metric'], best_valid))
                    world.save_agents()
                    saved = True
                    if opt['validation_metric'] == 'accuracy' and best_valid > 99.5:
                        print('[ task solved! stopping. ]')
                        break
                else:
                    impatience += 1
                    print('[ did not beat best {}: {} impatience: {} ]'.format(
                        opt['validation_metric'], round(best_valid, 4),
                        impatience))
                validate_time.reset()
                if opt['validation_patience'] > 0 and impatience >= opt[
                        'validation_patience']:
                    print('[ ran out of patience! stopping training. ]')
                    break
    if not saved:
        # save agent
        world.save_agents()
    elif opt.get('model_file'):
        # reload best validation model
        agent = create_agent(opt)

    run_eval(agent, opt, 'valid', write_log=True)
    run_eval(agent, opt, 'test', write_log=True)
Example #30
0
def main():
    argparser = ParlaiParser(
        False, True, description="MTurk evaluator for GeneratorMMIAgent")
    argparser.add_parlai_data_path()
    argparser.add_mturk_args()

    # Custom args
    agent.add_cmdline_args(argparser)
    argparser.set_defaults(
        model='transformer/generatorMMI',
        model_file='parlai_internal/forward.ckpt.checkpoint',
        model_file_backwards='parlai_internal/backward.ckpt.checkpoint',
        inference='beam',
        beam_size=8)

    opt = argparser.parse_args()
    opt['task'] = os.path.basename(os.path.dirname(os.path.abspath(__file__)))
    opt.update(task_config)

    # add additional model args
    opt['override'] = {
        'task': opt['task'],
        'inference': opt['inference'],
        'beam_size': opt['beam_size'],
        'no_cuda': True,
        'interactive_mode': True,
        'tensorboard_log': False,
    }

    # print(f"CURRENT OPTIONS {opt}")

    # Set up the model we want to evaluate
    tester_agent = create_agent(opt)

    # The task that we will evaluate the dialog model on
    task_opt = {}
    task_opt['datatype'] = 'test'
    task_opt['datapath'] = opt['datapath']
    task_opt['task'] = '#DailyDialog'
    # task_opt['task'] = '#Persona-Chat'

    mturk_agent_id = 'Worker'
    mturk_manager = MTurkManager(opt=opt, mturk_agent_ids=[mturk_agent_id])
    mturk_manager.setup_server(heroku_app_name="dialogue-hw4-mturk-eval",
                               existing_app=True)

    try:
        mturk_manager.start_new_run()
        mturk_manager.create_hits()

        def run_onboard(worker):
            world = ModelEvaluatorOnboardWorld(opt=opt, mturk_agent=worker)
            while not world.episode_done():
                world.parley()
            world.shutdown()

        mturk_manager.set_onboard_function(onboard_function=run_onboard)
        mturk_manager.ready_to_accept_workers()

        def check_worker_eligibility(worker):
            return True

        def assign_worker_roles(worker):
            worker[0].id = mturk_agent_id

        global run_conversation

        def run_conversation(mturk_manager, opt, workers):
            mturk_agent = workers[0]

            world = ModelEvaluatorWorld(
                opt=opt,
                model_agent=tester_agent,
                task_opt=task_opt,
                mturk_agent=mturk_agent,
            )

            for i in range(51):
                # while not world.episode_done():
                world.parley()
            world.shutdown()
            world.review_work()

        mturk_manager.start_task(
            eligibility_function=check_worker_eligibility,
            assign_role_function=assign_worker_roles,
            task_function=run_conversation,
        )
    except BaseException:
        raise
    finally:
        mturk_manager.expire_all_unassigned_hits()
        mturk_manager.shutdown()
Example #31
0
def create_supp(opt):
    """
    Evaluates a model.

    :param opt: tells the evaluation function how to run
    :return: the final result of calling report()
    """
    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)

    # Extract supp examples from misses on deploy set
    num_seen = 0
    num_misses = 0
    num_supp = 0
    num_supp_correct = 0
    examples = []
    while not world.epoch_done():
        world.parley()
        # Examples are considered one at a time
        num_seen += 1
        if num_seen % 1000 == 0:
            print(f"{num_seen}/{world.num_examples()}")
        report = world.report()
        if report['accuracy'] < 1.0:
            # Example is a miss (i.e., model got it wrong)
            num_misses += 1
            if random.random() < opt['conversion_rate']:
                # Example will be converted (e.g., bot recognized mistake and asked)
                num_supp += 1
                texts = world.acts[0]['text'].split('\n')
                context = texts[-1]
                memories = texts[:-1]
                candidates = world.acts[0]['label_candidates']
                # Reward of 1 indicates positive, -1 indicates negative (for training)
                # For now, we only train with positives, and the reward field is unused
                reward = 1

                if random.random() < opt['conversion_acc']:
                    # Example will be converted correctly (e.g., good user response)
                    num_supp_correct += 1
                    response = world.acts[0]['eval_labels'][0]
                else:
                    # Example will be converted incorrectly (e.g., bad user response)
                    response = random.choice(
                        world.acts[0]['label_candidates'][:NUM_INLINE_CANDS -
                                                          1])

                example = Parley(context, response, reward, candidates,
                                 memories)
                examples.append(example)
        world.reset_metrics()

    print("EPOCH DONE")
    print(f"Model file: {opt['model_file']}")
    print(f"Deploy file: {opt['task']}")
    print(f"Supp file: {opt['outfile']}")
    print(f"Deploy size (# examples seen): {num_seen}")
    print(f"Supp size (# examples converted): {num_supp}")

    acc = 1 - (num_misses / num_seen)
    print(f"Accuracy (% of deploy): {acc * 100:.1f}% ({num_misses} misses)")
    print(f"Conversion rate (% of misses): {num_supp/num_misses * 100:.2f}% "
          f"({num_supp}/{num_misses})")
    print(
        f"Conversion acc (% of converted): {num_supp_correct/num_supp * 100:.2f}% "
        f"({num_supp_correct}/{num_supp})")

    with open(opt['outfile'], 'w') as outfile:
        for ex in examples:
            outfile.write(json.dumps(ex.to_dict()) + '\n')
Example #32
0
def build_data(opt):
    agent = create_agent(opt)
    #If build teacher not specified, we are simply looking for the file
    if not opt.get('pytorch_buildteacher', None):
        df = opt.get('datafile')
        # check if the user set a datafile
        if not df:
            raise Exception('Tried to find data but `--datafile` is not set')
        # check if the user provided the already built file
        if 'pytorch' not in df:
            df += '.pytorch' + (agent.getID() if opt.get('pytorch_preprocess', True) else '')
        if not os.path.isfile(df):
            raise Exception('Tried to find data but it is not built, please '
                            'specify `--pytorch-buildteacher`')
        else:
            return df

    ordered_opt = copy.deepcopy(opt)
    # we use streaming to build the data
    dt = opt['datatype'].split(':')[0]
    ordered_opt['datatype'] = dt + ':ordered:stream'
    ordered_opt['numthreads'] = 1
    ordered_opt['batchsize'] = 1
    ordered_opt['task'] = ordered_opt['pytorch_buildteacher']
    world_data = create_task(ordered_opt, agent)
    teacher = world_data.agents[0]

    datafile = teacher.datafile if hasattr(teacher, 'datafile') else opt.get('datafile')
    if not datafile:
        raise Exception('Tried to build data but either `pytorch-buildteacher` does not '
                        'have a datafile or `--datafile` is not set')

    if isinstance(datafile, collections.Sequence):
        datafile = datafile[0] + "".join(["_".join(d.split("/")) for d in datafile[1:]])
    pytorch_datafile = datafile + ".pytorch"
    preprocess = opt.get('pytorch_preprocess', True)
    if preprocess:
        pytorch_datafile += agent.getID()
    if os.path.isfile(pytorch_datafile):
        # Data already built
        print("[ pytorch data already built. ]")
        return pytorch_datafile
    print('----------\n[ setting up pytorch data, saving to {}. ]\n----------'.format(pytorch_datafile))

    num_eps = 0
    num_exs = 0
    current = []
    episode_done = False
    include_labels = opt.get('include_labels', True)
    context_length = opt.get('context_length', -1)
    context = deque(maxlen=context_length if context_length > 0 else None)
    # pass examples to dictionary
    with open(pytorch_datafile, 'w') as pytorch_data:
        while not world_data.epoch_done():
            while not episode_done:
                action = teacher.act()
                current.append(action)
                episode_done = action.get('episode_done', False)

            #build separate episodes
            for ex in current:
                context.append(ex.get('text', ''))
                if len(context) > 1:
                    ex['text'] = '\n'.join(context)
                ex['episode_done'] = True
                labels = ex.get('labels', ex.get('eval_labels', None))
                if labels is not None and include_labels:
                    context.append(random.choice(labels))
                #generate observation from new example
                if preprocess:
                    ex = agent.observe(ex)
                    ex.pop('label_candidates', '')
                    ex['preprocessed'] = True
                num_eps += 1
                num_exs += 1
                pytorch_data.write(json.dumps(make_serializable(ex)) + "\n")
            #reset
            episode_done = False
            current.clear()
            context.clear()

    with open(pytorch_datafile + '.length', 'w') as pytorch_data_len:
        pytorch_data_len.write(json.dumps({'num_eps':num_eps, 'num_exs':num_exs}))

    print('[ pytorch data built. ]')
    return pytorch_datafile
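
For reference (not part of the example): build_data writes one JSON object per line to the .pytorch file and a small .length sidecar with the episode/example counts. A minimal sketch of reading those back; the file path is a placeholder:

# Hedged sketch: consuming the files written by build_data above.
# 'data.pytorch' is a placeholder path.
import json

datafile = 'data.pytorch'
with open(datafile + '.length') as f:
    lengths = json.load(f)
print('episodes:', lengths['num_eps'], 'examples:', lengths['num_exs'])

with open(datafile) as f:
    first_example = json.loads(next(f))   # each line is one serialized example
print(first_example.get('text', ''))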
Example #33
0
    def __init__(self, opt):
        # if python is called from a non-interactive shell, like a bash script,
        # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions are
        # not produced. This line brings them back
        signal.signal(signal.SIGINT, signal.default_int_handler)

        if isinstance(opt, ParlaiParser):
            print(
                '[ Deprecated Warning: TrainLoop should be passed opt not Parser ]'
            )
            opt = opt.parse_args()
        # Possibly load from checkpoint
        trainstats_suffix = '.trainstats'  # we might load training statistics from here
        if opt['load_from_checkpoint'] and opt.get(
                'model_file') and os.path.isfile(opt['model_file'] +
                                                 '.checkpoint'):
            opt['init_model'] = opt['model_file'] + '.checkpoint'
            trainstats_suffix = '.checkpoint.trainstats'
        # Possibly build a dictionary (not all models do this).
        if opt['dict_build_first'] and 'dict_file' in opt:
            # If data built via pytorch data teacher, we need to load prebuilt dict
            if opt.get('pytorch_teacher_task'):
                opt['dict_file'] = get_pyt_dict_file(opt)
            elif opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict(opt, skip_if_built=True)
        # Create model and assign it to the specified task
        self.agent = create_agent(opt)
        self.world = create_task(opt, self.agent)
        # set up timers
        self.train_time = Timer()
        self.validate_time = Timer()
        self.log_time = Timer()
        self.save_time = Timer()
        print('[ training... ]')
        self.parleys = 0
        self.max_num_epochs = opt[
            'num_epochs'] if opt['num_epochs'] > 0 else float('inf')
        self.max_train_time = opt['max_train_time'] if opt['max_train_time'] > 0 \
            else float('inf')
        self.log_every_n_secs = opt['log_every_n_secs'] if opt['log_every_n_secs'] > 0 \
            else float('inf')
        self.val_every_n_secs = \
            opt['validation_every_n_secs'] if opt['validation_every_n_secs'] > 0 \
            else float('inf')
        self.save_every_n_secs = opt['save_every_n_secs'] if opt['save_every_n_secs'] \
            > 0 else float('inf')
        self.val_every_n_epochs = \
            opt['validation_every_n_epochs'] if opt['validation_every_n_epochs'] > 0 \
            else float('inf')
        self.last_valid_epoch = 0
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.best_valid = None
        if opt.get('model_file') and os.path.isfile(opt['model_file'] +
                                                    '.best_valid'):
            with open(opt['model_file'] + ".best_valid", 'r') as f:
                x = f.readline()
                self.best_valid = float(x)
        self.impatience = 0
        self.saved = False
        self.valid_world = None
        self.opt = opt

        # we may have been preempted, make sure we note that amount
        self._preempted_epochs = 0.0
        if (opt.get('model_file')
                and os.path.isfile(opt['model_file'] + trainstats_suffix)):
            # looks like we were preempted. make sure we load up our total
            # training stats, etc
            with open(opt['model_file'] + trainstats_suffix) as ts:
                obj = json.load(ts)
                self._preempted_epochs = obj.get('total_epochs', 0)
                self.train_time.total = obj.get('train_time', 0)
                self.impatience = obj.get('impatience', 0)

        if opt['tensorboard_log'] is True:
            self.writer = TensorboardLogger(opt)
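
For reference (not part of the example): the constructor above resumes preempted runs from a small '.trainstats' (or '.checkpoint.trainstats') JSON file, of which only 'total_epochs', 'train_time', and 'impatience' are read here. A minimal sketch of a compatible file; the path is a placeholder and the real writer lives elsewhere in the training script:

# Hedged sketch of the trainstats JSON the constructor above can resume from.
# The path is a placeholder; the real file is written elsewhere in the training script.
import json

stats = {'total_epochs': 1.5, 'train_time': 3600.0, 'impatience': 2}
with open('/tmp/example_model.checkpoint.trainstats', 'w') as f:
    json.dump(stats, f)

with open('/tmp/example_model.checkpoint.trainstats') as f:
    obj = json.load(f)
print(obj.get('total_epochs', 0), obj.get('train_time', 0), obj.get('impatience', 0))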
Example #34
0
def main():
    # Get command line arguments
    parser = ParlaiParser(True, True)
    train = parser.add_argument_group('Training Loop Arguments')
    train.add_argument('-et', '--evaltask',
                        help=('task to use for valid/test (defaults to the ' +
                              'one used for training if not set)'))
    train.add_argument('-d', '--display-examples',
                        type='bool', default=False)
    train.add_argument('-e', '--num-epochs', type=float, default=-1)
    train.add_argument('-ttim', '--max-train-time',
                        type=float, default=-1)
    train.add_argument('-ltim', '--log-every-n-secs',
                        type=float, default=2)
    train.add_argument('-vtim', '--validation-every-n-secs',
                        type=float, default=-1)
    train.add_argument('-vme', '--validation-max-exs',
                        type=int, default=-1,
                        help='max examples to use during validation (default ' +
                             '-1 uses all)')
    train.add_argument('-vp', '--validation-patience',
                        type=int, default=5,
                        help=('number of iterations of validation where result '
                              + 'does not improve before we stop training'))
    train.add_argument('-vmt', '--validation-metric', default='accuracy',
                       help='key into report table for selecting best validation')
    train.add_argument('-dbf', '--dict-build-first',
                        type='bool', default=True,
                        help='build dictionary first before training agent')
    opt = parser.parse_args()
    # Possibly build a dictionary (not all models do this).
    if opt['dict_build_first'] and 'dict_file' in opt:
        if opt['dict_file'] is None and opt.get('model_file'):
            opt['dict_file'] = opt['model_file'] + '.dict'
        print("[ building dictionary first... ]")
        build_dict.build_dict(opt)
    # Create model and assign it to the specified task
    agent = create_agent(opt)
    world = create_task(opt, agent)

    train_time = Timer()
    validate_time = Timer()
    log_time = Timer()
    print('[ training... ]')
    parleys = 0
    total_exs = 0
    max_exs = opt['num_epochs'] * len(world)
    max_parleys = math.ceil(max_exs / opt['batchsize'])
    best_valid = 0
    impatience = 0
    saved = False
    valid_world = None
    while True:
        world.parley()
        parleys += 1

        if opt['num_epochs'] > 0 and parleys >= max_parleys:
            print('[ num_epochs completed: {} ]'.format(opt['num_epochs']))
            break
        if opt['max_train_time'] > 0 and train_time.time() > opt['max_train_time']:
            print('[ max_train_time elapsed: {} ]'.format(train_time.time()))
            break
        if opt['log_every_n_secs'] > 0 and log_time.time() > opt['log_every_n_secs']:
            if opt['display_examples']:
                print(world.display() + '\n~~')

            logs = []
            # time elapsed
            logs.append('time:{}s'.format(math.floor(train_time.time())))
            logs.append('parleys:{}'.format(parleys))

            # get report and update total examples seen so far
            if hasattr(agent, 'report'):
                train_report = agent.report()
                agent.reset_metrics()
            else:
                train_report = world.report()
                world.reset_metrics()

            if hasattr(train_report, 'get') and train_report.get('total'):
                total_exs += train_report['total']
                logs.append('total_exs:{}'.format(total_exs))

            # check if we should log amount of time remaining
            time_left = None
            if opt['num_epochs'] > 0:
                exs_per_sec = train_time.time() / total_exs
                time_left = (max_exs - total_exs) * exs_per_sec
            if opt['max_train_time'] > 0:
                other_time_left = opt['max_train_time'] - train_time.time()
                if time_left is not None:
                    time_left = min(time_left, other_time_left)
                else:
                    time_left = other_time_left
            if time_left is not None:
                logs.append('time_left:{}s'.format(math.floor(time_left)))

            # join log string and add full metrics report to end of log
            log = '[ {} ] {}'.format(' '.join(logs), train_report)

            print(log)
            log_time.reset()

        if (opt['validation_every_n_secs'] > 0 and
                validate_time.time() > opt['validation_every_n_secs']):
            valid_report, valid_world = run_eval(
                agent, opt, 'valid', opt['validation_max_exs'],
                valid_world=valid_world)
            if valid_report[opt['validation_metric']] > best_valid:
                best_valid = valid_report[opt['validation_metric']]
                impatience = 0
                print('[ new best {}: {} ]'.format(
                    opt['validation_metric'], best_valid))
                world.save_agents()
                saved = True
                if opt['validation_metric'] == 'accuracy' and best_valid == 1:
                    print('[ task solved! stopping. ]')
                    break
            else:
                impatience += 1
                print('[ did not beat best {}: {} impatience: {} ]'.format(
                        opt['validation_metric'], round(best_valid, 4),
                        impatience))
            validate_time.reset()
            if opt['validation_patience'] > 0 and impatience >= opt['validation_patience']:
                print('[ ran out of patience! stopping training. ]')
                break
    world.shutdown()
    if not saved:
        world.save_agents()
    else:
        # reload best validation model
        agent = create_agent(opt)

    run_eval(agent, opt, 'valid', write_log=True)
    run_eval(agent, opt, 'test', write_log=True)
Example #35
0
def self_chat(opt):
    random.seed(opt['seed'])
    partner = opt['partner_model_file']
    assert partner is not None
    partner_opt_file = opt.get('partner_opt_file')
    if partner_opt_file:
        assert partner_opt_file == partner + '.opt', (
            'If you think it is safe,'
            ' you can remove this assert')
    else:
        partner_opt_file = partner + '.opt'

    # Create agents
    if opt['model_file'].split(':')[0] == 'human':
        agent1 = MyLocalHumanAgent(opt)
        assert partner is not None
    else:
        agent1 = create_agent(opt, requireModelExists=True)
    if partner is None:
        # Self chat with same model
        agent2 = agent1.clone()
    else:
        # Self chat with different models
        if partner_opt_file:
            print(f"WARNING: Loading override opts from: {partner_opt_file}")
            with open(partner_opt_file) as f:
                partner_opt = json.load(f)
        else:
            partner_opt = {}
        partner_opt['interactive_mode'] = opt.get('interactive_mode', True)
        print(
            f"WARNING: Setting partner interactive mode to: {partner_opt['interactive_mode']}"
        )
        agent2 = create_agent_from_model_file(partner, partner_opt)

    # Set IDs
    agent1.id = agent1.id + '_1'
    agent2.id = agent2.id + '_2'

    model_id = agent1.id + '_' + agent2.id

    world = create_task(opt, user_agents=[agent1, agent2])

    # Set up world logging
    logger = WorldLogger(opt)
    log_time = TimeLogger()

    # Run some self chats.
    all_report = []
    if opt['num_self_chats'] < 0:
        opt['num_self_chats'] = len(world.messages)

    for i in range(opt['num_self_chats']):
        _run_self_chat_episode(opt, world, logger)
        report = world.report()
        text, report = log_time.log(i + 1, opt['num_self_chats'], report)
        logging.info(text)
        all_report.append(report)

        world.write(logger, all_report, opt['outfile'])

    # Save chats
    if opt['outfile'] is None:
        outfile = '/tmp/{}_selfchat'.format(model_id)
    else:
        outfile = opt['outfile']

    if opt['save_format'] == 'conversations' and hasattr(world, 'write'):
        # use self chat specific world to write conversation
        # this might be useful for logging extra contextual
        # information (like personas)
        world.write(logger, all_report, outfile)
    else:
        # use default logger write function
        logger.write(outfile, world, opt['save_format'])

    return logger.get_logs()
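The partner-opt handling at the top of self_chat boils down to a small resolution rule: honor an explicit partner opt file if one is given, otherwise fall back to the .opt file saved next to the partner checkpoint. A hedged sketch of that rule as a standalone helper (resolve_partner_opt_file is hypothetical, not a ParlAI function):

def resolve_partner_opt_file(partner_model_file, partner_opt_file=None):
    # no partner model means self-chat with a clone of the primary agent
    if partner_model_file is None:
        return None
    if partner_opt_file is None:
        # default: the opt file ParlAI saves alongside the checkpoint
        return partner_model_file + '.opt'
    return partner_opt_file

assert resolve_partner_opt_file('models/partner') == 'models/partner.opt'
assert resolve_partner_opt_file(None) is None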
Example #36
def build_data(opt):
    if not opt.get('model', False):
        opt['model'] = 'repeat_label'
    preprocess = opt.get('pytorch_preprocess', True)
    opt['dict_file'] = get_pyt_dict_file(opt)
    dictionary = None
    if 'dict_maxexs' in opt:
        # Note: only build dictionary if dict loop args specified
        dictionary = build_dict(opt, skip_if_built=True)
    agent = create_agent(opt)
    # If build teacher not specified, we are simply looking for the file
    if not opt.get('pytorch_teacher_task', None):
        df = opt.get('pytorch_datapath')
        # check if the user set a datafile
        if not df:
            raise Exception(
                'Tried to find data but `--pytorch-datapath` is not set')
        # check if the user provided the already built file
        if 'pytorch' not in df:
            df += '.pytorch' + (agent.getID() if opt.get(
                'pytorch_preprocess', True) else '')
        if not os.path.isfile(df):
            raise Exception('Tried to find data but it is not built; please '
                            'specify `--pytorch-teacher-task`')
        else:
            return df

    ordered_opt = copy.deepcopy(opt)
    # we use streaming to build the data
    dt = opt['datatype'].split(':')[0]
    ordered_opt['datatype'] = dt + ':ordered:stream'
    ordered_opt['numthreads'] = 1
    ordered_opt['batchsize'] = 1
    ordered_opt['task'] = ordered_opt['pytorch_teacher_task']
    ordered_opt.pop('pytorch_teacher_dataset')
    ordered_opt['no_cuda'] = True
    world_data = create_task(ordered_opt, agent)
    teacher = world_data.get_task_agent()
    agent = world_data.agents[1]
    datapath = os.path.join(
        opt.get('datapath', '.'),
        '{}_pyt_data'.format(ordered_opt['task'].replace(':', '_')),
        dt,
    )
    if preprocess:
        datapath += '_{}_preprocess'.format(agent.getID().replace(':', '_'))
    if os.path.isdir(datapath) and 'data_length' in os.listdir(datapath):
        # Data already built
        print("[ pytorch data already built, at {}. ]".format(datapath))
        return datapath
    print('----------\n[ setting up pytorch data, saving to {}/ ]\n----------'.
          format(datapath))
    os.makedirs(datapath, exist_ok=True)
    num_eps = 0
    num_exs = 0
    current = []
    episode_done = False
    include_labels = opt.get('pytorch_include_labels', True)
    context_length = opt.get('pytorch_context_length', -1)
    context = deque(maxlen=context_length if context_length > 0 else None)
    total_exs = world_data.num_examples()
    pbar = tqdm.tqdm(total=total_exs,
                     unit='ex',
                     unit_scale=True,
                     desc='Building pytorch data')
    idx_to_char = []
    cumulative_char_len = 0
    # pass examples to dictionary
    with open(os.path.join(datapath, 'data'), 'w') as pytorch_data:
        while num_exs < total_exs:
            while not episode_done:
                # TODO: eventually all teachers should return Messages, so
                # we should assert this
                action = Message(teacher.act())
                current.append(action)
                episode_done = action.get('episode_done', False)

            # build separate episodes
            for ex in current:
                context.append(ex.get('text', ''))
                if len(context) > 1:
                    ex.force_set('text', '\n'.join(context))
                ex.force_set('episode_done', True)
                labels = ex.get('labels', ex.get('eval_labels', None))
                if labels is not None and include_labels:
                    context.append(random.choice(labels))
                # generate observation from new example
                if preprocess:
                    ex = agent.observe(ex)
                    ex.pop('label_candidates', '')
                    ex['preprocessed'] = True
                num_eps += 1
                num_exs += 1
                pbar.update(1)
                ex_len = pytorch_data.write(
                    json.dumps(make_serializable(ex)) + "\n")
                idx_to_char.append(cumulative_char_len)
                cumulative_char_len += ex_len
            # reset
            episode_done = False
            current.clear()
            context.clear()
    pbar.close()
    with open(os.path.join(datapath, 'char_index'), 'w') as char_index:
        json.dump(idx_to_char, char_index)
    with open(os.path.join(datapath, 'data_length'), 'w') as pytorch_data_len:
        pytorch_data_len.write(
            json.dumps({
                'num_eps': num_eps,
                'num_exs': num_exs
            }))
    if dictionary:
        dictionary.save(get_pyt_dict_file(opt), sort=True)

    print('[ pytorch data built. ]')
    return datapath
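build_data writes one JSON object per line to the data file and records each example's cumulative character offset in char_index, which is what makes random access by example index possible later. A minimal reader sketch, assuming the serialized lines are ASCII so character offsets coincide with byte offsets (read_example is illustrative, not part of ParlAI):

import json
import os

def read_example(datapath, i):
    # load the offset table written by build_data
    with open(os.path.join(datapath, 'char_index')) as f:
        idx_to_char = json.load(f)
    with open(os.path.join(datapath, 'data'), 'rb') as f:
        f.seek(idx_to_char[i])      # jump to the start of example i
        return json.loads(f.readline())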
Example #37
def eval_wordstat(opt, print_parser=None):
    """
    Evaluates a model and reports word-level statistics of its predictions.

    :param opt: tells the evaluation function how to run
    :param print_parser: if provided, prints the options that are set within the
        model after loading the model
    """
    random.seed(42)

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)

    if opt.get('external_dict'):
        print('[ Using external dictionary from: {} ]'.format(
            opt['external_dict']))
        dict_opt = copy.deepcopy(opt)
        dict_opt['dict_file'] = opt['external_dict']
        dictionary = DictionaryAgent(dict_opt)
    else:
        print('[ Using model bundled dictionary ]')
        dictionary = agent.dict

    batch_size = opt['batchsize']

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    cnt = 0
    word_statistics = {
        'mean_wlength': [],
        'mean_clength': [],
        'freqs_cnt': Counter(),
        'word_cnt': 0,
        'pred_list': [],
        'pure_pred_list': [],
        'context_list': [],
        'unique_words': set(),
    }
    bins = [int(i) for i in opt['freq_bins'].split(',')]

    def process_prediction(prediction, word_statistics):
        normalized = normalize_answer(prediction)
        word_statistics['pred_list'].append(normalized)
        freqs, _cnt, wlength, clength = get_word_stats(prediction,
                                                       dictionary,
                                                       bins=bins)
        word_statistics['word_cnt'] += _cnt
        word_statistics['mean_wlength'].append(wlength)
        word_statistics['mean_clength'].append(clength)
        word_statistics['freqs_cnt'] += Counter(freqs)
        word_statistics['unique_words'] |= set(normalized.split(" "))
        return word_statistics

    while not world.epoch_done():
        world.parley()
        if batch_size == 1:
            cnt += 1
            prediction = world.acts[-1]['text']
            word_statistics['context_list'].append(world.acts[0]['text'])
            word_statistics['pure_pred_list'].append(prediction)
            word_statistics = process_prediction(prediction, word_statistics)
        else:
            for w in world.worlds:
                try:
                    if 'text' not in w.acts[-1]:
                        continue
                    prediction = w.acts[-1]['text']
                    word_statistics['context_list'].append(w.acts[0]['text'])
                    word_statistics['pure_pred_list'].append(prediction)
                except IndexError:
                    continue
                cnt += 1
                word_statistics = process_prediction(prediction,
                                                     word_statistics)

        if log_time.time() > log_every_n_secs:
            report = world.report()
            text, report = log_time.log(report['exs'], world.num_examples(),
                                        report)
            print(text)
            stat_str = 'total_words: {}, '.format(word_statistics['word_cnt'])
            stat_str += ', '.join([
                '<{}:{} ({:.{prec}f}%)'.format(
                    b,
                    word_statistics['freqs_cnt'].get(b, 0),
                    (word_statistics['freqs_cnt'].get(b, 0) /
                     word_statistics['word_cnt']) * 100,
                    prec=2,
                ) for b in bins
            ])
            print("Word statistics: {}, avg_word_length: {:.{prec}f}, "
                  "avg_char_length: {:.{prec}f}".format(
                      stat_str,
                      numpy.array(word_statistics['mean_wlength']).mean(),
                      numpy.array(word_statistics['mean_clength']).mean(),
                      prec=2,
                  ))
        if opt['num_examples'] > 0 and cnt >= opt['num_examples']:
            break
    if world.epoch_done():
        print("EPOCH DONE")

    if opt['compute_unique'] is True:
        unique_list = []
        cntr = Counter(word_statistics['pred_list'])
        for k, v in cntr.items():
            if v == 1:
                unique_list.append(k)
        print("Unique responses: {:.{prec}f}%".format(
            len(unique_list) / len(word_statistics['pred_list']) * 100,
            prec=2))
    print("Total unique tokens:", len(word_statistics['unique_words']))

    if opt['dump_predictions_path'] is not None:
        with open(opt['dump_predictions_path'], 'w') as f:
            f.writelines([
                'CONTEXT: {}\nPREDICTION:{}\n\n'.format(c, p) for c, p in zip(
                    word_statistics['context_list'],
                    word_statistics['pure_pred_list'],
                )
            ])
        if opt['compute_unique'] is True:
            with open(opt['dump_predictions_path'] + '_unique', 'w') as f:
                f.writelines(['{}\n'.format(i) for i in unique_list])

    stat_str = 'total_words: {}, '.format(word_statistics['word_cnt'])
    stat_str += ', '.join([
        '<{}:{} ({:.{prec}f}%)'.format(
            b,
            word_statistics['freqs_cnt'].get(b, 0),
            (word_statistics['freqs_cnt'].get(b, 0) /
             word_statistics['word_cnt']) * 100,
            prec=2,
        ) for b in bins
    ])
    print("Word statistics: {}, avg_word_length: {:.{prec}f}, "
          "avg_char_length: {:.{prec}f}".format(
              stat_str,
              numpy.array(word_statistics['mean_wlength']).mean(),
              numpy.array(word_statistics['mean_clength']).mean(),
              prec=2,
          ))

    report = world.report()
    print(report)
    return report
Example #38
def self_chat(opt):
    random.seed(opt['seed'])
    partner = opt['partner_model_file']
    partner_opt_file = opt.get('partner_opt_file')

    # Create agents
    agent1 = create_agent(opt, requireModelExists=True)
    agent1.opt.log("Agent 1 Opt")
    if partner is None:
        # Self chat with same model
        agent2 = agent1.clone()
    else:
        # Self chat with different models
        if partner_opt_file:
            print(f"WARNING: Loading override opts from: {partner_opt_file}")
            with PathManager.open(partner_opt_file) as f:
                partner_opt = json.load(f)
        else:
            partner_opt = {}
        partner_opt['interactive_mode'] = opt.get('interactive_mode', True)
        print(
            f"WARNING: Setting partner interactive mode to: {partner_opt['interactive_mode']}"
        )
        agent2 = create_agent_from_model_file(partner, partner_opt)
        agent2.opt.log("Agent 2 Opt")

    # Set IDs
    agent1.id = agent1.id + "_1"
    agent2.id = agent2.id + "_2"

    model_id = agent1.id + "_" + agent2.id

    world = create_task(opt, user_agents=[agent1, agent2])

    # Set up world logging
    logger = WorldLogger(opt)
    log_time = TimeLogger()

    # Run some self chats.
    for i in range(opt['num_self_chats']):
        _run_self_chat_episode(opt, world, logger)
        report = world.report()
        text, report = log_time.log(i + 1, opt['num_self_chats'], report)
        logging.info(text)

    # Save chats
    if opt['outfile'] is None:
        outfile = '/tmp/{}_selfchat'.format(model_id)
    else:
        outfile = opt['outfile']

    if opt['save_format'] == 'conversations' and hasattr(world, 'write'):
        # use self chat specific world to write conversation
        # this might be useful for logging extra contextual
        # information (like personas)
        world.write(logger, outfile)
    else:
        # use default logger write function
        logger.write(outfile, world, opt['save_format'])

    return logger.get_logs()
Example #39
    def train(self):
        """
        Perform a training run.

        :return: tuple of reports (validation_report, test_report)
        """
        logging.info('training...')
        opt = self.opt
        world = self.world
        with world:
            while True:
                # do one example / batch of examples
                try:
                    world.parley()
                except StopTrainException as e:
                    logging.info(f"Stopping from {e}")
                    break

                self.parleys += 1

                # get the total training examples done, compute epochs
                self._total_epochs = self._preempted_epochs + sum(
                    all_gather_list(world.get_total_epochs()))
                exs_per_epoch = world.num_examples()
                self._total_exs = int(
                    np.round(self._total_epochs * exs_per_epoch))
                # and use the primary worker's timings for everything
                train_time, log_time, validate_time = sync_object((
                    self.train_time.time(),
                    self.log_time.time(),
                    self.validate_time.time(),
                ))

                # check counters and timers
                if self._total_epochs >= self.max_num_epochs:
                    self.log()
                    logging.info(
                        f'num_epochs completed:{self.max_num_epochs} time elapsed:{train_time}s'
                    )
                    break
                if train_time > self.max_train_time:
                    logging.info(f'max_train_time elapsed:{train_time}s')
                    break
                if log_time > self.log_every_n_secs:
                    self.log()
                if (validate_time > self.val_every_n_secs
                        or self._total_epochs - self.last_valid_epoch >=
                        self.val_every_n_epochs):
                    try:
                        # log before we validate
                        self.log()
                        world.reset_metrics()
                        stop_training = self.validate()
                    except StopTrainException:
                        break
                    # reset the log time because we logged right before validating
                    self.log_time.reset()
                    self.last_valid_epoch = self._total_epochs
                    if stop_training:
                        break
                    # make sure metrics are clean before we log
                    world.reset_metrics()
                if (self.save_time.time() > self.save_every_n_secs
                        and opt.get('model_file') and is_primary_worker()):
                    logging.info(
                        f"saving model checkpoint: {opt['model_file']}.checkpoint"
                    )
                    if opt['tensorboard_log'] and is_primary_worker():
                        self.tb_logger.flush()
                    self.save_model('.checkpoint')
                    self.save_time.reset()

        if not self.saved and is_primary_worker():
            # save agent
            self.save_model()

        # There is a rare edge case where we never saved the model and then try
        # to reload it. This sync_object ensures all workers wait for the
        # primary worker to finish flushing before loading from disk.
        sync_object(None)
        if opt.get('model_file'):
            # clean up all our memory, just to make sure we don't OOM on GPU when
            # reloading the world
            del world
            del self.world
            del self.agent
            del self.valid_worlds
            # reload best validation model
            self.agent = create_agent(opt)

        # perform final validation/testing
        valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get(
            'short_final_eval') else -1
        v_report = self._run_eval(valid_worlds,
                                  opt,
                                  'valid',
                                  max_exs,
                                  write_log=True)
        test_worlds = load_eval_worlds(self.agent, opt, 'test')
        t_report = self._run_eval(test_worlds,
                                  opt,
                                  'test',
                                  max_exs,
                                  write_log=True)
        if valid_worlds:
            for valid_world in valid_worlds:
                valid_world.shutdown()
        if test_worlds:
            for test_world in test_worlds:
                test_world.shutdown()

        print_announcements(opt)

        return v_report, t_report
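Validation in train() triggers on whichever limit is hit first: seconds since the last validation or epochs since the last validation. A tiny sketch of that predicate in isolation (the should_validate name is illustrative):

def should_validate(secs_since_valid, val_every_n_secs,
                    total_epochs, last_valid_epoch, val_every_n_epochs):
    # either the time-based or the epoch-based threshold can fire
    return (secs_since_valid > val_every_n_secs
            or total_epochs - last_valid_epoch >= val_every_n_epochs)

# 90s elapsed with a 60s threshold -> validate even though <0.5 epochs passed
assert should_validate(90, 60, 1.2, 1.0, 0.5)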
Example #40
def interactive(opt, print_parser=None):
    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    if isinstance(opt, ParlaiParser):
        print(
            '[ Deprecated Warning: interactive should be passed opt not Parser ]'
        )
        opt = opt.parse_args()
    opt['task'] = 'parlai.agents.local_human.local_human:LocalHumanAgent'
    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)
    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()

    # Create ConvAI2 data so we can assign personas.
    convai2_opt = opt.copy()
    convai2_opt['task'] = 'convai2:both'
    convai2_agent = RepeatLabelAgent(convai2_opt)
    convai2_world = create_task(convai2_opt, convai2_agent)

    def get_new_personas():
        # Find a new episode
        while True:
            convai2_world.parley()
            msg = convai2_world.get_acts()[0]
            if msg['episode_done']:
                convai2_world.parley()
                msg = convai2_world.get_acts()[0]
                break
        txt = msg.get('text', '').split('\n')
        bot_persona = ""
        for t in txt:
            if t.startswith("partner's persona:"):
                print(t.replace("partner's ", 'your '))
            if t.startswith('your persona:'):
                bot_persona += t + '\n'
        print("Enter [DONE] if you want a new partner at any time.")
        return bot_persona

    # Now run interactive mode, chatting with personas!
    cnt = 0
    while True:
        if cnt == 0:
            bot_persona = get_new_personas()
        # Run the parts of world.parley() in turn,
        # but insert persona into user message.
        acts = world.acts
        agents = world.agents
        acts[0] = agents[0].act()
        # add the persona on to the first message
        if cnt == 0:
            acts[0].force_set('text', bot_persona + acts[0].get('text', 'hi'))
        agents[1].observe(acts[0])
        acts[1] = agents[1].act()
        agents[0].observe(acts[1])
        world.update_counters()
        cnt = cnt + 1

        if opt.get('display_examples'):
            print("---")
            print(world.display())
        if world.episode_done():
            print("CHAT DONE ")
            print("In case you were curious you were talking to this bot:")
            print(bot_persona.split('\n'))
            if not world.epoch_done():
                print("\n... preparing new chat... \n")
            cnt = 0
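The persona extraction above assumes the convai2:both teacher prefixes each persona line with either "your persona:" or "partner's persona:". A hedged sketch of that split on a made-up message, mirroring get_new_personas:

msg_text = (
    "your persona: i love gardening.\n"
    "your persona: i have two cats.\n"
    "partner's persona: i am a chef.\n"
    "hello there!"
)

bot_persona = ""
for line in msg_text.split('\n'):
    if line.startswith("partner's persona:"):
        # shown to the human as their own persona
        print(line.replace("partner's ", 'your '))
    elif line.startswith('your persona:'):
        # accumulated and prepended to the human's first message
        bot_persona += line + '\n'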
Example #41
def main():
    """This task consists of an MTurk agent evaluating a Controllable Dialog model.
    """
    start_time = datetime.datetime.today().strftime('%Y-%m-%d-%H-%M')
    argparser = ParlaiParser(False, add_model_args=True)
    argparser.add_parlai_data_path()
    argparser.add_mturk_args()
    argparser.add_argument(
        '--max-resp-time',
        default=240,
        type=int,
        help='time limit for entering a dialog message',
    )
    argparser.add_argument(
        '--max-choice-time',
        type=int,
        default=300,
        help='time limit for turker '
        'choosing the topic',
    )
    argparser.add_argument(
        '--ag-shutdown-time',
        default=120,
        type=int,
        help='time limit for entering a dialog message',
    )
    argparser.add_argument('--num-turns',
                           default=6,
                           type=int,
                           help='number of turns of dialogue')
    argparser.add_argument(
        '--human-eval',
        type='bool',
        default=False,
        help='human vs human eval, no models involved',
    )
    argparser.add_argument(
        '--auto-approve-delay',
        type=int,
        default=3600 * 24 * 2,
        help='how long to wait for auto approval',
    )
    argparser.add_argument(
        '--only-masters',
        type='bool',
        default=False,
        help='Set to true to use only Masters turkers for '
        'this test eval',
    )
    argparser.add_argument(
        '--create-model-qualif',
        type='bool',
        default=True,
        help='Create model qualif so unique eval between '
        'models.',
    )
    argparser.add_argument(
        '--limit-workers',
        type=int,
        default=len(SETTINGS_TO_RUN),
        help='max HITs a worker can complete',
    )
    argparser.add_argument(
        '--mturk-log',
        type=str,
        default=('data/mturklogs/controllable/{}.log'.format(start_time)),
    )
    argparser.add_argument(
        '--short-eval',
        type='bool',
        default=True,
        help='Only ask engagingness question and persona '
        'question.',
    )
    # persona specific arguments
    argparser.add_argument('--persona-type',
                           type=str,
                           default='self',
                           choices=['self', 'other', 'none'])
    argparser.add_argument(
        '--persona-datatype',
        type=str,
        default='valid',
        choices=['train', 'test', 'valid'],
    )
    argparser.add_argument('--max-persona-time',
                           type=int,
                           default=360,
                           help='max time to view persona')

    def get_logger(opt):
        fmt = '%(asctime)s: [ %(message)s ]'
        logfn = None
        if 'mturk_log' in opt:
            logfn = opt['mturk_log']
            if not os.path.isdir(os.path.dirname(logfn)):
                os.makedirs(os.path.dirname(logfn), exist_ok=True)
        logger = ParlaiLogger(
            name="mturk_controllable",
            console_level=INFO,
            file_level=INFO,
            console_format=fmt,
            file_format=fmt,
            filename=logfn,
        )
        logger.info('COMMAND: %s' % ' '.join(sys.argv))
        logger.info('-' * 100)
        logger.info('CONFIG:\n%s' % json.dumps(opt, indent=4, sort_keys=True))

        return logger

    start_opt = argparser.parse_args()

    task_config['task_description'] = task_config['task_description'].format(
        start_opt['reward'])

    # set options
    start_opt['limit_workers'] = len(SETTINGS_TO_RUN)
    start_opt['allowed_conversations'] = 1
    start_opt['max_hits_per_worker'] = start_opt['limit_workers']
    start_opt['task'] = os.path.basename(
        os.path.dirname(os.path.abspath(__file__)))

    start_opt.update(task_config)

    logger = get_logger(start_opt)

    model_share_params = {}
    worker_models_seen = {}
    model_opts = {}
    model_counts = {}

    lock = Lock()

    for setup in SETTINGS_TO_RUN:
        assert 'human' not in setup
        model_counts[setup] = 0
        agent_config = getattr(mcf, setup)
        combined_config = copy.deepcopy(start_opt)
        for k, v in agent_config.items():
            combined_config[k] = v
            combined_config['override'][k] = v
        folder_name = '{}-{}'.format(setup, start_time)
        combined_config['save_data_path'] = os.path.join(
            start_opt['datapath'], 'local_controllable_dialogue', folder_name)
        model_opts[setup] = combined_config
        bot = create_agent(combined_config, True)
        model_share_params[setup] = bot.share()

    if not start_opt.get('human_eval'):
        mturk_agent_ids = ['PERSON_1']
    else:
        mturk_agent_ids = ['PERSON_1', 'PERSON_2']

    mturk_manager = MTurkManager(opt=start_opt,
                                 mturk_agent_ids=mturk_agent_ids)

    personas_generator = PersonasGenerator(start_opt)

    directory_path = os.path.dirname(os.path.abspath(__file__))

    mturk_manager.setup_server(task_directory_path=directory_path)

    try:
        mturk_manager.start_new_run()
        agent_qualifications = []
        # assign qualifications
        if start_opt['create_model_qualif']:
            qual_name = 'ControlEvalRound2'
            qual_desc = (
                'Qualification to ensure workers complete only a certain '
                'number of these HITs')
            qualification_id = mturk_utils.find_or_create_qualification(
                qual_name, qual_desc, False)
            print('Created qualification: ', qualification_id)
            start_opt['unique_qualif_id'] = qualification_id

        def run_onboard(worker):
            worker.personas_generator = personas_generator
            world = PersonaAssignWorld(start_opt, worker)
            world.parley()
            world.shutdown()

        def check_worker_eligibility(worker):
            worker_id = worker.worker_id
            lock.acquire()
            retval = len(worker_models_seen.get(worker_id,
                                                [])) < len(SETTINGS_TO_RUN)
            lock.release()
            return retval

        def assign_worker_roles(workers):
            for index, worker in enumerate(workers):
                worker.id = mturk_agent_ids[index % len(mturk_agent_ids)]

        mturk_manager.set_onboard_function(onboard_function=run_onboard)
        mturk_manager.ready_to_accept_workers()
        mturk_manager.create_hits(qualifications=agent_qualifications)

        def run_conversation(mturk_manager, opt, workers):
            conv_idx = mturk_manager.conversation_index

            # gotta find a bot this worker hasn't seen yet
            assert len(workers) == 1
            worker_id = workers[0].worker_id
            lock.acquire()
            if worker_id not in worker_models_seen:
                worker_models_seen[worker_id] = set()
            print("MODELCOUNTS:")
            print(pprint.pformat(model_counts))
            logger.info("MODELCOUNTS\n" + pprint.pformat(model_counts))
            model_options = [
                (model_counts[setup_name] + 10 * random.random(), setup_name)
                for setup_name in SETTINGS_TO_RUN
                if setup_name not in worker_models_seen[worker_id]
            ]
            if not model_options:
                lock.release()
                logger.error(
                    "Worker {} already finished all settings! Returning none".
                    format(worker_id))
                return None
            _, model_choice = min(model_options)

            worker_models_seen[worker_id].add(model_choice)
            model_counts[model_choice] += 1
            lock.release()

            world = ControllableDialogEval(
                opt=model_opts[model_choice],
                agents=workers,
                num_turns=start_opt['num_turns'],
                max_resp_time=start_opt['max_resp_time'],
                model_agent_opt=model_share_params[model_choice],
                world_tag='conversation t_{}'.format(conv_idx),
                agent_timeout_shutdown=opt['ag_shutdown_time'],
                model_config=model_choice,
            )
            world.reset_random()
            while not world.episode_done():
                world.parley()
            world.save_data()

            lock.acquire()
            if not world.convo_finished:
                model_counts[model_choice] -= 1
                worker_models_seen[worker_id].remove(model_choice)
            lock.release()

            world.shutdown()
            gc.collect()

        mturk_manager.start_task(
            eligibility_function=check_worker_eligibility,
            assign_role_function=assign_worker_roles,
            task_function=run_conversation,
        )

    except BaseException:
        raise
    finally:
        mturk_manager.expire_all_unassigned_hits()
        mturk_manager.shutdown()
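Model assignment in run_conversation is effectively a least-used-first pick with random jitter, so workers get under-sampled settings first without ties always resolving the same way. A standalone sketch of that selection (pick_setting and the counts are illustrative):

import random

def pick_setting(model_counts, already_seen, jitter=10.0):
    # score = times used so far + random jitter; lowest score wins
    options = [
        (count + jitter * random.random(), name)
        for name, count in model_counts.items()
        if name not in already_seen
    ]
    if not options:
        return None           # this worker has seen every setting
    _, choice = min(options)
    return choice

print(pick_setting({'greedy': 5, 'beam': 1, 'nucleus': 1}, {'beam'}))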
Example #42
def detect(opt, printargs=None, print_parser=None):
    """
    Checks a task for offensive language.
    """
    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    random.seed(42)

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)
    if opt['safety'] == 'string_matcher' or opt['safety'] == 'all':
        offensive_string_matcher = OffensiveStringMatcher()
    if opt['safety'] == 'classifier' or opt['safety'] == 'all':
        offensive_classifier = OffensiveLanguageClassifier()

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    stats = {
        'bad_words': [],
        'bad_words_cnt': 0,
        'string_offensive': 0,
        'classifier_offensive': 0,
        'total_offensive': 0,
        'total': 0,
    }

    def report(world, stats):
        report = world.report()
        log = {
            'word_offenses': stats['bad_words_cnt'],
            'classifier_offenses%':
                100 * (stats['classifier_offensive'] / stats['total']),
            'string_offenses%':
                100 * (stats['string_offensive'] / stats['total']),
            'total_offenses%':
                100 * (stats['total_offensive'] / stats['total']),
        }
        text, log = log_time.log(report['exs'], world.num_examples(), log)
        print(text)

    def classify(text, stats):
        offensive = False
        stats['total'] += 1
        if opt['safety'] == 'string_matcher' or opt['safety'] == 'all':
            bad_words = offensive_string_matcher.contains_offensive_language(
                text)
            if bad_words:
                stats['string_offensive'] += 1
                offensive = True
                stats['bad_words'].append(bad_words)
        if opt['safety'] == 'classifier' or opt['safety'] == 'all':
            if text in offensive_classifier:
                stats['classifier_offensive'] += 1
                offensive = True
        if offensive:
            stats['total_offensive'] += 1

    while not world.epoch_done():
        world.parley()
        stats['bad_words'] = []
        for a in world.acts:
            text = a.get('text', '')
            classify(text, stats)
            labels = a.get('labels', a.get('eval_labels', ''))
            for l in labels:
                classify(l, stats)
        if len(stats['bad_words']) > 0 and opt['display_examples']:
            print(world.display())
            print("[Offensive words detected:]", ', '.join(stats['bad_words']))
            print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
        stats['bad_words_cnt'] += len(stats['bad_words'])
        if log_time.time() > log_every_n_secs:
            report(world, stats)

    if world.epoch_done():
        print("EPOCH DONE")
    report(world, stats)
    return world.report()
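The `text in offensive_classifier` test above relies on the classifier exposing a membership-style check. A hedged sketch of a toy wrapper with the same calling convention (KeywordFlagger is a stand-in for illustration, not the ParlAI OffensiveLanguageClassifier):

class KeywordFlagger:
    """Toy stand-in: flags text containing any blocklisted word."""

    def __init__(self, blocklist=('darn', 'heck')):
        self.blocklist = set(blocklist)

    def __contains__(self, text):
        # membership means "this text looks offensive"
        return any(w in self.blocklist for w in text.lower().split())

flagger = KeywordFlagger()
print('well darn it' in flagger)    # True
print('hello friend' in flagger)    # False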