Example #1
    def run(self):
        opt = self.opt

        world = self._setup_world()
        logger = TodWorldLogger(opt)

        # set up logging
        log_every_n_secs = opt.get("log_every_n_secs", -1)
        if log_every_n_secs <= 0:
            log_every_n_secs = float("inf")
        log_time = TimeLogger()

        # episode counter
        max_episodes = opt.get("num_episodes", -1)
        if max_episodes < 0:
            max_episodes = float("inf")
        world_num_episodes = world.num_episodes()
        if world_num_episodes > 0:
            max_episodes = min(max_episodes, world_num_episodes)

        ep_count = 0
        episode_metrics = []
        while not world.epoch_done() and ep_count < max_episodes:
            episode_metrics.extend(self._run_episode(opt, world, logger))
            ep_count += opt.get("batchsize", 1)
            if log_time.time() > log_every_n_secs:
                report = world.report()
                text, report = log_time.log(ep_count, max_episodes, report)
                logging.info(text)

        return self._save_outputs(opt, world, logger, episode_metrics)
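
The logging setup above (a non-positive --log-every-n-secs disables timed logging by mapping it to infinity) recurs in almost every example below. A minimal helper sketch that factors the idiom out; the resolve_log_interval name is hypothetical, not part of ParlAI:

def resolve_log_interval(opt):
    """Return the timed-logging interval in seconds; infinity disables it."""
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        # Non-positive means "never trigger the timer-based log".
        log_every_n_secs = float('inf')
    return log_every_n_secs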
Example #2
def _eval_single_world(opt, agent, task):
    print('[ Evaluating task {} using datatype {}. ] '.format(
        task, opt.get('datatype', 'N/A')))
    task_opt = opt.copy()  # copy opt since we're editing the task
    task_opt['task'] = task
    world = create_task(task_opt, agent)  # create worlds for tasks

    # set up logging
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    # max number of examples to evaluate
    max_cnt = opt['num_examples'] if opt['num_examples'] > 0 else float('inf')
    cnt = 0

    while not world.epoch_done() and cnt < max_cnt:
        cnt += opt.get('batchsize', 1)
        world.parley()
        if opt['display_examples']:
            # display examples
            print(world.display() + '\n~~')
        if log_time.time() > log_every_n_secs:
            report = world.report()
            text, report = log_time.log(report['exs'], world.num_examples(),
                                        report)
            print(text)

    report = world.report()
    world.reset()
    return report
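
_eval_single_world evaluates a single task; a hedged sketch of a caller that loops over a comma-separated task list and collects one report per task (the eval_all_worlds name is an assumption for illustration, not the actual ParlAI entry point):

def eval_all_worlds(opt, agent):
    # Hypothetical driver: run _eval_single_world once per task.
    reports = {}
    for task in opt['task'].split(','):
        reports[task] = _eval_single_world(opt, agent, task)
    return reports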
Example #3
def self_chat(opt, print_parser=None):
    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    if isinstance(opt, ParlaiParser):
        print('[ Deprecated Warning: self_chat should be passed opt not Parser ]')
        opt = opt.parse_args()

    random.seed(opt['seed'])
    # Create models
    agent1 = create_agent(opt, requireModelExists=True)
    agent2 = agent1.clone()
    if hasattr(agent2, 'id'):
        agent2.id = agent2.id + "2"

    # Check for `selfchat` in the task name
    if 'selfchat' not in opt['task']:
        warn_once(
            'You are using self chat with task {}. '.format(opt['task'])
            + 'If your task has an existing self chat world, then run with '
            '-t {}:selfchat'.format(opt['task'])
        )

    world = create_task(opt, [agent1, agent2])

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent1.opt
        print_parser.print_args()

    # set up logging
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()
    logger = WorldLogger(opt)

    # Run some self chats.
    max_cnt = opt['num_examples']
    cnt = 0
    while cnt < max_cnt:
        cnt += opt.get('batchsize', 1)
        world.parley()
        logger.log(world)

        if opt.get('display_examples'):
            print(world.display())
        if log_time.time() > log_every_n_secs:
            text = log_time.log(cnt, max_cnt)
            print(text)

    if opt.get('display_examples'):
        print('-- end of episode --')

    logger.reset_world()  # flush last episode
    logger.write(opt['outfile'], opt['format'])
    return logger.get_logs()
Example #4
def detect(opt, printargs=None, print_parser=None):
    """
    Checks a task for offensive language.
    """
    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    random.seed(42)

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)
    bad = OffensiveStringMatcher()

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    # Show some example dialogs:
    cnt = 0
    while not world.epoch_done():
        world.parley()
        words = []
        for a in world.acts:
            offensive = bad.contains_offensive_language(a.get('text', ''))
            if offensive:
                words.append(offensive)
            labels = a.get('labels', a.get('eval_labels', ''))
            for l in labels:
                offensive = bad.contains_offensive_language(l)
                if offensive:
                    words.append(offensive)
        if len(words) > 0 and opt['display_examples']:
            print(world.display())
            print("[Offensive words detected:]", ', '.join(words))
            print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
        cnt += len(words)
        if log_time.time() > log_every_n_secs:
            report = world.report()
            log = {'offenses': cnt}
            text, log = log_time.log(report['exs'], world.num_examples(), log)
            print(text)

    if world.epoch_done():
        print("EPOCH DONE")
    print(
        str(cnt) + " offensive messages found out of " +
        str(world.num_examples()) + " messages.")
    return world.report()
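
As the loop above shows, OffensiveStringMatcher.contains_offensive_language returns a truthy match (the offending phrase) or a falsy value when the text is clean. A minimal standalone sketch, assuming ParlAI is installed; the import path is parlai.utils.safety in recent releases and may differ in older ones:

from parlai.utils.safety import OffensiveStringMatcher

bad = OffensiveStringMatcher()  # downloads the word list on first use
match = bad.contains_offensive_language('some message to screen')
if match:
    print('offensive phrase detected:', match)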
Example #5
def get_word_counts(opt, count_inputs):
    """
    Goes through the dataset specified in opt, returns word counts and all utterances.

    Inputs:
      count_inputs: If True, include both input and reply when counting words and
        utterances. Otherwise, only include reply text.

    Returns:
      word_counter: a Counter mapping each word to the total number of times it appears
      total_count: int. total word count, i.e. the sum of the counts for each word
      all_utts: list of strings. all the utterances that were used for counting words
    """
    # Create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)

    # Count word frequency for all words in dataset
    word_counter = Counter()
    total_count = 0
    all_utts = []
    log_timer = TimeLogger()
    while True:
        world.parley()

        # Count words in reply
        reply = world.acts[0].get('labels',
                                  world.acts[0].get('eval_labels'))[0]
        words = reply.split()
        word_counter.update(words)
        total_count += len(words)
        all_utts.append(reply)

        # Optionally count words in input text
        if count_inputs:
            input = world.acts[0]['text']
            input = input.split('\n')[
                -1]  # e.g. in ConvAI2, this removes persona
            words = input.split()
            word_counter.update(words)
            total_count += len(words)
            all_utts.append(input)

        if log_timer.time() > opt['log_every_n_secs']:
            text, _log = log_timer.log(world.total_parleys,
                                       world.num_examples())
            print(text)

        if world.epoch_done():
            print('EPOCH DONE')
            break

    assert total_count == sum(word_counter.values())

    return word_counter, total_count, all_utts
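
The word_counter and total_count returned here are exactly what learn_arora (Example #21 below) turns into a unigram distribution, for instance:

word_counter, total_count, all_utts = get_word_counts(opt, count_inputs=False)
word2prob = {w: c / total_count for w, c in word_counter.items()}
assert abs(sum(word2prob.values()) - 1.0) < 1e-6  # probabilities sum to 1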
Example #6
def _eval_single_world(opt, agent, task):
    print('[ Evaluating task {} using datatype {}. ] '.format(
        task, opt.get('datatype', 'N/A')))
    # set up world logger
    world_logger = WorldLogger(opt) if opt['save_world_logs'] else None

    task_opt = opt.copy()  # copy opt since we're editing the task
    task_opt['task'] = task
    world = create_task(task_opt, agent)  # create worlds for tasks

    # set up logging
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    # max number of examples to evaluate
    max_cnt = opt['num_examples'] if opt['num_examples'] > 0 else float('inf')
    cnt = 0

    while not world.epoch_done() and cnt < max_cnt:
        cnt += opt.get('batchsize', 1)
        world.parley()
        if world_logger is not None:
            world_logger.log(world)
        if opt['display_examples']:
            # display examples
            print(world.display() + '\n~~')

        if log_time.time() > log_every_n_secs:
            report = world.report()
            text, report = log_time.log(report.get('exs', 0),
                                        world.num_examples(), report)
            print(text)

    report = world.report()
    print("Printing Report")
    print(report)
    world.reset()

    if world_logger is not None:
        # dump world acts to file
        world_logger.reset()  # add final acts to logs
        base_outfile = opt['report_filename'].split('.')[0]
        print("filename: ", base_outfile)
        outfile = base_outfile + f'_{task}_replies.jsonl'
        world_logger.write_parlai_format(outfile)

    return report
Example #7
def get_word_counts(opt, count_inputs):
    """
    Goes through the dataset specified in opt and gets word counts.

    Inputs:
      count_inputs: If True, include both input and reply when counting words
        and utterances. Otherwise, only include reply text.

    Returns:
      word_counter_per_sent: a Counter mapping each word to the number of
        utterances in which it appears.
      num_sents: int. number of utterances counted
    """
    # Create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)

    # Count word frequency for all words in dataset
    word_counter_per_sent = Counter()
    num_sents = 0
    count = 0
    log_timer = TimeLogger()
    while True:
        count += 1

        world.parley()
        reply = world.acts[0].get('labels',
                                  world.acts[0].get('eval_labels'))[0]

        words = reply.split()
        words_no_dups = list(set(words))  # remove duplicates
        word_counter_per_sent.update(words_no_dups)
        num_sents += 1

        # Optionally count words in input text
        if count_inputs:
            input = world.acts[0]['text']
            input = input.split('\n')[
                -1]  # e.g. in ConvAI2, this removes persona
            words = input.split()
            words_no_dups = list(set(words))  # remove duplicates
            word_counter_per_sent.update(words_no_dups)
            num_sents += 1

        if log_timer.time() > opt['log_every_n_secs']:
            text, _log = log_timer.log(world.total_parleys,
                                       world.num_examples())
            print(text)

        if world.epoch_done():
            print('EPOCH DONE')
            break

    return word_counter_per_sent, num_sents
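
Because word_counter_per_sent counts the number of utterances containing each word, it is a document-frequency table. A hedged sketch of turning it into normalized inverse document frequency (NIDF), the statistic the controllable-dialogue controls use; the exact normalization here is an assumption:

import math

def compute_nidf(word_counter_per_sent, num_sents):
    # IDF per word, then min-max normalized into [0, 1].
    idf = {w: math.log(num_sents / c) for w, c in word_counter_per_sent.items()}
    min_idf = min(idf.values())
    max_idf = max(idf.values())
    spread = (max_idf - min_idf) or 1.0  # guard against a degenerate corpus
    return {w: (v - min_idf) / spread for w, v in idf.items()}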
Example #8
def _eval_single_world(opt, agent, task):
    logging.info(
        f'Evaluating task {task} using datatype {opt.get("datatype")}.')
    # set up world logger
    world_logger = WorldLogger(opt) if opt['world_logs'] else None

    task_opt = opt.copy()  # copy opt since we're editing the task
    task_opt['task'] = task
    world = create_task(task_opt, agent)  # create worlds for tasks

    # set up logging
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    # max number of examples to evaluate
    max_cnt = opt['num_examples'] if opt['num_examples'] > 0 else float('inf')
    cnt = 0
    total_cnt = world.num_examples()

    if is_distributed():
        logging.warning('Progress bar is approximate in distributed mode.')

    while not world.epoch_done() and cnt < max_cnt:
        cnt += opt.get('batchsize', 1)
        world.parley()
        if world_logger is not None:
            world_logger.log(world)
        if opt['display_examples']:
            # display examples
            print(world.display() + '\n~~')
        if log_time.time() > log_every_n_secs:
            report = world.report()
            text, report = log_time.log(report.get('exs', 0),
                                        min(max_cnt, total_cnt), report)
            logging.info(text)

    if world_logger is not None:
        # dump world acts to file
        world_logger.reset()  # add final acts to logs
        if is_distributed():
            rank = get_rank()
            base_outfile, extension = os.path.splitext(opt['world_logs'])
            outfile = base_outfile + f'_{rank}' + extension
        else:
            outfile = opt['world_logs']
        world_logger.write(outfile, world, file_format=opt['save_format'])

    report = aggregate_unnamed_reports(all_gather_list(world.report()))
    world.reset()

    return report
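
In distributed mode each rank writes its own shard named base_outfile_{rank}{extension}. A small post-hoc merge sketch under the assumption that the save format is line-oriented (e.g. jsonl); merge_rank_logs is a hypothetical helper:

import glob
import os

def merge_rank_logs(world_logs_path):
    # Concatenate per-rank shards (e.g. logs_0.jsonl, logs_1.jsonl) into one file.
    base, ext = os.path.splitext(world_logs_path)
    with open(world_logs_path, 'w') as out:
        for shard in sorted(glob.glob(f'{base}_*{ext}')):
            with open(shard) as f:
                out.write(f.read())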
Example #9
def build_cands(opt):
    # create repeat label agent and assign it to the specified task
    if opt['numthreads'] > 1:
        # Broken in hogwild mode. Just fall back to single processing mode
        opt['numthreads'] = 1
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)
    if opt['outfile'] is None:
        outfile = tempfile.mkstemp(
            prefix='{}_{}_'.format(opt['task'], opt['datatype']),
            suffix='.txt',
        )[1]
    else:
        outfile = opt['outfile']

    if opt.get('num_examples', -1) == -1:
        num_examples = world.num_examples()
    else:
        num_examples = opt['num_examples']
    log_timer = TimeLogger()

    print('[ starting to build candidates from task.. (ex:' +
          str(num_examples) + ')]')
    print('[ saving output to {} ]'.format(outfile))
    cands = set()
    for _ in range(num_examples):
        world.parley()
        # Unlike the stock script, which looks only at the teacher's acts,
        # this variant collects label candidates from every agent's acts.
        for acts in world.get_acts():
            if isinstance(acts, dict):
                # We turn into a batch of 1 example, in case batching is being used.
                acts = [acts]
            for a in acts:
                candidate = a.get('labels', a.get('eval_labels', None))
                if candidate is not None:
                    candidate = candidate[0]
                    cands.add(candidate)

        if log_timer.time() > opt['log_every_n_secs']:
            text, _log = log_timer.log(world.total_parleys,
                                       world.num_examples())
            print(text)
        if world.epoch_done():
            print('EPOCH DONE')
            break
    with open(outfile, 'w') as fw:
        fw.write('\n'.join(cands))
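
The output file holds one candidate per line, so loading it back is symmetric:

with open(outfile) as f:
    cands = [line.strip() for line in f if line.strip()]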
Example #10
def self_chat(opt, print_parser=None):
    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    if isinstance(opt, ParlaiParser):
        print(
            '[ Deprecated Warning: self_chat should be passed opt not Parser ]'
        )
        opt = opt.parse_args()

    random.seed(opt['seed'])
    # Create models
    agent1 = create_agent(opt, requireModelExists=True)
    agent2 = agent1.clone()
    if hasattr(agent2, 'id'):
        agent2.id = agent2.id + "2"

    world = create_task(opt, [agent1, agent2])

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent1.opt
        print_parser.print_args()

    # set up logging
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()
    logger = WorldLogger(opt)

    # Run some self chats.
    max_cnt = opt['num_examples']
    cnt = 0
    while cnt < max_cnt:
        cnt += opt.get('batchsize', 1)
        world.parley()
        logger.log(world)

        if opt.get('display_examples'):
            print("---")
            print(world.display())
        if log_time.time() > log_every_n_secs:
            text = log_time.log(cnt, max_cnt)
            print(text)

    logger.write(opt['outfile'], opt['format'])
Example #11
    def run(self):
        self.opt['no_cuda'] = True
        if ('ordered' not in self.opt['datatype']
                and 'train' in self.opt['datatype']):
            self.opt['datatype'] = self.opt['datatype'] + ':ordered'
        agent = create_agent(self.opt)
        agent.opt.log()
        num_examples = self.opt['num_examples']
        field = self.opt['field'] + '_vec'
        if num_examples < 0:
            num_examples = float('inf')
        assert self.opt['batchsize'] == 1
        assert isinstance(agent, TorchAgent)

        world = create_task(self.opt, agent)
        teacher = world.get_task_agent()

        # set up logging
        log_every_n_secs = self.opt.get('log_every_n_secs', -1)
        if log_every_n_secs <= 0:
            log_every_n_secs = float('inf')
        log_time = TimeLogger()

        lengths = []

        cnt = 0
        total = min(teacher.num_examples(), num_examples)
        while not teacher.epoch_done() and cnt < num_examples:
            act = teacher.act()
            processed = agent.observe(act)
            try:
                text_vec = processed[field]
            except KeyError:
                raise KeyError(f"Pick one of {list(processed.keys())}")
            if text_vec is not None and (not self.opt['final_only']
                                         or act.get('episode_done')):
                cnt += 1
                lengths.append(float(len(text_vec)))
            agent.self_observe({})

            if log_time.time() > log_every_n_secs:
                report = self._compute_stats(lengths)
                text, report = log_time.log(report['exs'], total, report)
                logging.info(text)

        report = self._compute_stats(lengths)
        print(nice_report(report))
        return report
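
_compute_stats is not shown in this snippet. A minimal sketch of what it plausibly computes over the collected lengths; every field name except 'exs' (which the log_time.log call above requires) is an assumption:

import numpy as np

def _compute_stats(self, lengths):  # method on the same class as run()
    arr = np.array(lengths) if lengths else np.zeros(1)
    return {
        'exs': len(lengths),  # 'exs' is read by log_time.log above
        'min': float(arr.min()),
        'max': float(arr.max()),
        'mean': float(arr.mean()),
        'p50': float(np.percentile(arr, 50)),
    }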
Example #12
    def run_generation(self):
        """
        Actually run the evaluations.
        """
        # set up logging
        log_every_n_secs = self.opt.get('log_every_n_secs', -1)
        if log_every_n_secs <= 0:
            log_every_n_secs = float('inf')
        log_time = TimeLogger()

        # max number of examples to evaluate
        max_cnt = (
            self.opt['num_examples'] if self.opt['num_examples'] > 0 else float('inf')
        )
        self.cnt = 0
        self.n_valid = 0
        self.log_count = 0
        total_cnt = self.world.num_examples()

        while not self.world.epoch_done() and self.cnt < max_cnt:
            self.cnt += self.opt.get('batchsize', 1)
            self.world.parley()
            acts = self.world.get_acts()
            if acts[-1]['text'] != INVALID:
                try:
                    self.world.acts[0]['text'] += f"\n{acts[-1]['knowledge']}"
                except RuntimeError:
                    self.world.acts[0].force_set(
                        'text', f"{self.world.acts[0]['text']}\n{acts[-1]['knowledge']}"
                    )
                self.world.acts[0]['f1_overlap'] = acts[-1]['f1_overlap']
                self.world_logger.log(self.world)
                self.n_valid += 1
                if (
                    self.n_valid > 0
                    and self.n_valid % self.opt['write_every_n_valid_exs'] == 0
                ):
                    self.log()
            if log_time.time() > log_every_n_secs:
                report = self.world.report()
                report['n_valid'] = self.n_valid
                text, report = log_time.log(
                    report.get('exs', 0), min(max_cnt, total_cnt), report
                )
                logging.info(text)
Example #13
def build_cands(opt):
    opt.log()
    # create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)
    if opt['outfile'] is None:
        outfile = tempfile.mkstemp(
            prefix='{}_{}_'.format(opt['task'], opt['datatype']),
            suffix='.txt',
        )[1]
    else:
        outfile = opt['outfile']

    if opt.get('num_examples', -1) == -1:
        num_examples = world.num_examples()
    else:
        num_examples = opt['num_examples']
    log_timer = TimeLogger()

    logging.info(
        f'Starting to build candidates from task.. (ex: {num_examples})')
    logging.info(f'Saving output to {outfile}')
    cands = set()
    for _ in range(num_examples):
        world.parley()
        # We get the acts of the first agent, which is the teacher.
        acts = world.get_acts()[0]
        if isinstance(acts, dict):
            # We turn into a batch of 1 example, in case batching is being used.
            acts = [acts]
        for a in acts:
            candidate = a.get('labels', a.get('eval_labels', None))
            if candidate is not None:
                candidate = candidate[0]
                cands.add(candidate)
        if log_timer.time() > opt['log_every_n_secs']:
            text, _log = log_timer.log(world.total_parleys,
                                       world.num_examples())
            logging.info(text)
        if world.epoch_done():
            logging.info('epoch done')
            break
    with open(outfile, 'w') as fw:
        fw.write('\n'.join(cands))
Example #14
def dump_data(opt):
    # create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)
    opt.log()
    if opt['outfile'] is None:
        outfile = tempfile.mkstemp(
            prefix='{}_{}_'.format(opt['task'], opt['datatype']),
            suffix='.txt',
        )[1]
    else:
        outfile = opt['outfile']

    if opt['num_examples'] == -1:
        num_examples = world.num_examples()
    else:
        num_examples = opt['num_examples']
    log_timer = TimeLogger()

    logging.debug('starting to convert...')
    logging.info(f'saving output to {outfile}')
    fw = open(outfile, 'w')
    text = ''
    for _ in range(num_examples):
        world.parley()
        world.acts[0]['labels'] = world.acts[0].get(
            'labels', world.acts[0].pop('eval_labels', None))

        samp = world.acts[0]
        text += samp["text"].replace("\n", " ") + " "
        fw.write("__label__%s %s\n" %
                 (samp["labels"][0].replace(' ', '_'), text))
        if world.acts[0].get('episode_done', False):
            text = ''

        if log_timer.time() > opt['log_every_n_secs']:
            # Use a fresh name here: `text` accumulates the dialogue context
            # above and must not be clobbered by the log string.
            log_text, _log = log_timer.log(world.total_parleys,
                                           world.num_examples())
            logging.info(log_text)

        if world.epoch_done():
            logging.info('epoch done')
            break
    fw.close()
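
Each line follows fastText's supervised format (__label__<label> <accumulated text>), so under that assumption the dump can be fed straight to a fastText classifier. This requires the separate fasttext package; a sketch, not part of ParlAI:

import fasttext

model = fasttext.train_supervised(input=outfile)
print(model.predict('some input text'))  # -> (labels, probabilities)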
Example #15
def dump_data(opt):
    # create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)
    opt.log()
    ignorefields = opt.get('ignore_fields', '')
    if opt['outfile'] is None:
        outfile = tempfile.mkstemp(
            prefix='{}_{}_'.format(opt['task'], opt['datatype']),
            suffix='.txt',
        )[1]
    else:
        outfile = opt['outfile']

    if opt['num_examples'] == -1:
        num_examples = world.num_examples()
    else:
        num_examples = opt['num_examples']
    log_timer = TimeLogger()

    logging.debug('starting to convert...')
    logging.info(f'saving output to {outfile}')
    fw = open(outfile, 'w')
    for _ in range(num_examples):
        world.parley()
        acts = world.get_acts()
        value = acts[0].get('labels', acts[0].pop('eval_labels', None))
        acts[0].force_set('labels', value)
        txt = msg_to_str(acts[0], ignore_fields=ignorefields)
        fw.write(txt + '\n')
        if acts[0].get('episode_done', False):
            fw.write('\n')

        if log_timer.time() > opt['log_every_n_secs']:
            text, _log = log_timer.log(world.total_parleys,
                                       world.num_examples())
            logging.info(text)

        if world.epoch_done():
            logging.info('epoch done')
            break
    fw.close()
Example #16
def make_dataset(opt):

    # Initialize control information so we can compute sentence attributes.
    # Here we set build_task=False so we don't download data/controllable_dialogue
    # (because we're trying to create it instead).
    initialize_control_information(opt, build_task=False)

    # Create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)
    ignorefields = opt.get('ignore_fields', '')
    outfile = opt['outfile']

    # Number of examples to process
    if opt['num_examples'] == -1:
        num_examples = world.num_examples()
    else:
        num_examples = opt['num_examples']

    # List of controls to include:
    controls = opt['controls'].split(',') if opt['controls'] != '' else []

    print('[ starting to convert.. ]')
    print('[ saving output to {} ]'.format(outfile))
    fw = open(outfile, 'w')
    log_timer = TimeLogger()

    for _ in range(num_examples):
        world.parley()
        world.acts[0]['labels'] = world.acts[0].get(
            'labels', world.acts[0].pop('eval_labels', None))

        # Need to get history in order to compute control values
        hist = ConvAI2History(world.acts[0]['text'], assume_persontokens=False)
        response = world.acts[0]['labels'][0]

        # Compute control values
        for ctrl in controls:
            ctrl_val = eval_attr(response, hist, ctrl)
            if ctrl == 'avg_nidf':
                assert ctrl_val >= 0
                assert ctrl_val <= 1
            elif ctrl == 'question':
                assert ctrl_val in [0, 1]
            elif ctrl == 'lastuttsim':
                if ctrl_val is not None:
                    assert ctrl_val >= -1
                    assert ctrl_val <= 1
            else:
                raise Exception('unexpected ctrl name: %s' % ctrl)
            world.acts[0][ctrl] = ctrl_val  # add control value to act

        # Write to file
        txt = msg_to_str(world.acts[0], ignore_fields=ignorefields)
        fw.write(txt + '\n')
        if world.acts[0].get('episode_done', False):
            fw.write('\n')

        if log_timer.time() > opt['log_every_n_secs']:
            text, _log = log_timer.log(world.total_parleys,
                                       world.num_examples())
            print(text)

        if world.epoch_done():
            print('EPOCH DONE')
            break
    fw.close()
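
The per-control assert chain above could equally be table-driven. A hedged refactor sketch; CONTROL_RANGES and check_ctrl are hypothetical names, and the range check loosens the exact {0, 1} membership test for 'question' into a bounds test:

CONTROL_RANGES = {
    'avg_nidf': (0.0, 1.0),     # continuous score in [0, 1]
    'question': (0, 1),         # binary flag
    'lastuttsim': (-1.0, 1.0),  # similarity; may also be None
}

def check_ctrl(ctrl, ctrl_val):
    if ctrl not in CONTROL_RANGES:
        raise Exception('unexpected ctrl name: %s' % ctrl)
    if ctrl_val is None:
        return  # e.g. lastuttsim on a first turn
    lo, hi = CONTROL_RANGES[ctrl]
    assert lo <= ctrl_val <= hi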
Example #17
def verify(opt, printargs=None, print_parser=None):
    if opt['datatype'] == 'train':
        logging.warn("changing datatype from train to train:ordered")
        opt['datatype'] = 'train:ordered'
    # create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)

    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    counts = {}
    counts['missing_text'] = 0
    counts['missing_labels'] = 0
    counts['missing_label_candidates'] = 0
    counts['empty_string_label_candidates'] = 0
    counts['label_candidates_with_missing_label'] = 0
    counts['did_not_return_message'] = 0

    # Show some example dialogs.
    while not world.epoch_done():
        world.parley()

        act = world.acts[0]

        if not isinstance(act, Message):
            counts['did_not_return_message'] += 1

        if 'text' not in act and 'image' not in act:
            warn("warning: missing text field:\n", act, opt)
            counts['missing_text'] += 1

        if 'labels' not in act and 'eval_labels' not in act:
            warn("warning: missing labels/eval_labels field:\n", act, opt)
            counts['missing_labels'] += 1
        else:
            if 'label_candidates' not in act:
                counts['missing_label_candidates'] += 1
            else:
                labels = act.get('labels', act.get('eval_labels'))
                is_label_cand = {}
                for l in labels:
                    is_label_cand[l] = False
                for c in act['label_candidates']:
                    if c == '':
                        warn("warning: empty string label_candidate:\n", act,
                             opt)
                        counts['empty_string_label_candidates'] += 1
                    if c in is_label_cand:
                        if is_label_cand[c] is True:
                            warn(
                                "warning: label mentioned twice in candidate_labels:\n",
                                act,
                                opt,
                            )
                        is_label_cand[c] = True
                for _, has in is_label_cand.items():
                    if has is False:
                        warn("warning: label missing in candidate_labels:\n",
                             act, opt)
                        counts['label_candidates_with_missing_label'] += 1

        if log_time.time() > log_every_n_secs:
            text, log = report(world, counts, log_time)
            if print_parser:
                print(text)

    try:
        # print dataset size if available
        logging.info(f'Loaded {world.num_episodes()} episodes with a '
                     f'total of {world.num_examples()} examples')
    except Exception:
        pass

    return report(world, counts, log_time)
Example #18
def eval_wordstat(opt):
    """
    Evaluates a model.

    :param opt: tells the evaluation function how to run
    """
    random.seed(42)

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)
    agent.opt.log()

    if opt.get('external_dict'):
        print('[ Using external dictionary from: {} ]'.format(
            opt['external_dict']))
        dict_opt = copy.deepcopy(opt)
        dict_opt['dict_file'] = opt['external_dict']
        dictionary = DictionaryAgent(dict_opt)
    else:
        print('[ Using model bundled dictionary ]')
        dictionary = agent.dict

    batch_size = opt['batchsize']

    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    cnt = 0
    max_cnt = opt['num_examples'] if opt['num_examples'] > 0 else float('inf')
    word_statistics = {
        'mean_wlength': [],
        'mean_clength': [],
        'freqs_cnt': Counter(),
        'word_cnt': 0,
        'pred_list': [],
        'pure_pred_list': [],
        'context_list': [],
        'unique_words': set(),
    }
    bins = [int(i) for i in opt['freq_bins'].split(',')]

    def process_prediction(prediction, word_statistics):
        normalized = normalize_answer(prediction)
        word_statistics['pred_list'].append(normalized)
        freqs, _cnt, wlength, clength = get_word_stats(prediction,
                                                       dictionary,
                                                       bins=bins)
        word_statistics['word_cnt'] += _cnt
        word_statistics['mean_wlength'].append(wlength)
        word_statistics['mean_clength'].append(clength)
        word_statistics['freqs_cnt'] += Counter(freqs)
        word_statistics['unique_words'] |= set(normalized.split(" "))
        return word_statistics

    while not world.epoch_done():
        world.parley()
        if batch_size == 1:
            cnt += 1
            prediction = world.acts[-1]['text']
            word_statistics['context_list'].append(world.acts[0]['text'])
            word_statistics['pure_pred_list'].append(prediction)
            word_statistics = process_prediction(prediction, word_statistics)
        else:
            for w in world.worlds:
                try:
                    if 'text' not in w.acts[-1]:
                        continue
                    prediction = w.acts[-1]['text']
                    word_statistics['context_list'].append(w.acts[0]['text'])
                    word_statistics['pure_pred_list'].append(prediction)
                except IndexError:
                    continue
                cnt += 1
                word_statistics = process_prediction(prediction,
                                                     word_statistics)

        if log_time.time() > log_every_n_secs:
            report = world.report()
            text, report = log_time.log(report['exs'],
                                        min(max_cnt, world.num_examples()),
                                        report)
            print(text)
            stat_str = 'total_words: {}, '.format(word_statistics['word_cnt'])
            stat_str += ', '.join([
                '<{}:{} ({:.{prec}f}%)'.format(
                    b,
                    word_statistics['freqs_cnt'].get(b, 0),
                    (word_statistics['freqs_cnt'].get(b, 0) /
                     word_statistics['word_cnt']) * 100,
                    prec=2,
                ) for b in bins
            ])
            print("Word statistics: {}, avg_word_length: {:.{prec}f}, "
                  "avg_char_length: {:.{prec}f}".format(
                      stat_str,
                      numpy.array(word_statistics['mean_wlength']).mean(),
                      numpy.array(word_statistics['mean_clength']).mean(),
                      prec=2,
                  ))
        if cnt >= max_cnt:
            break
    if world.epoch_done():
        print("EPOCH DONE")

    if opt['compute_unique'] is True:
        unique_list = []
        cntr = Counter(word_statistics['pred_list'])
        for k, v in cntr.items():
            if v == 1:
                unique_list.append(k)
        print("Unique responses: {:.{prec}f}%".format(
            len(unique_list) / len(word_statistics['pred_list']) * 100,
            prec=2))
    print("Total unique tokens:", len(word_statistics['unique_words']))

    if opt['dump_predictions_path'] is not None:
        with PathManager.open(opt['dump_predictions_path'], 'w') as f:
            f.writelines([
                'CONTEXT: {}\nPREDICTION:{}\n\n'.format(c, p) for c, p in zip(
                    word_statistics['context_list'],
                    word_statistics['pure_pred_list'],
                )
            ])
        if opt['compute_unique'] is True:
            with PathManager.open(opt['dump_predictions_path'] + '_unique',
                                  'w') as f:
                f.writelines(['{}\n'.format(i) for i in unique_list])

    stat_str = 'total_words: {}, '.format(word_statistics['word_cnt'])
    stat_str += ', '.join([
        '<{}:{} ({:.{prec}f}%)'.format(
            b,
            word_statistics['freqs_cnt'].get(b, 0),
            (word_statistics['freqs_cnt'].get(b, 0) /
             word_statistics['word_cnt']) * 100,
            prec=2,
        ) for b in bins
    ])
    print("Word statistics: {}, avg_word_length: {:.{prec}f}, "
          "avg_char_length: {:.{prec}f}".format(
              stat_str,
              numpy.array(word_statistics['mean_wlength']).mean(),
              numpy.array(word_statistics['mean_clength']).mean(),
              prec=2,
          ))

    report = world.report()
    print(report)
    return report
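
The frequency-bin summary string is assembled twice above, once in the logging loop and once at the end. A small helper sketch that factors it out; format_word_stats is a hypothetical name, and the max() guard avoids division by zero on an empty run:

def format_word_stats(word_statistics, bins):
    stat_str = 'total_words: {}, '.format(word_statistics['word_cnt'])
    stat_str += ', '.join(
        '<{}:{} ({:.2f}%)'.format(
            b,
            word_statistics['freqs_cnt'].get(b, 0),
            word_statistics['freqs_cnt'].get(b, 0)
            / max(word_statistics['word_cnt'], 1) * 100,
        )
        for b in bins
    )
    return stat_str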
Example #19
def _eval_single_world(opt, agent, task):
    logging.info(
        f'Evaluating task {task} using datatype {opt.get("datatype")}.')
    # set up world logger
    task_opt = opt.copy()  # copy opt since we're editing the task
    task_opt['task'] = task
    # add task suffix in case of multi-tasking
    if opt['world_logs']:
        task_opt['world_logs'] = get_task_world_logs(
            task,
            task_opt['world_logs'],
            is_multitask=len(opt['task'].split(',')) > 1)

    world_logger = WorldLogger(task_opt) if task_opt['world_logs'] else None

    world = create_task(task_opt, agent)  # create worlds for tasks

    # set up logging
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    # max number of examples to evaluate
    max_cnt = opt['num_examples'] if opt['num_examples'] > 0 else float('inf')
    cnt = 0
    total_cnt = world.num_examples()

    if is_distributed():
        logging.warning('Progress bar is approximate in distributed mode.')

    while not world.epoch_done() and cnt < max_cnt:
        cnt += opt.get('batchsize', 1)
        world.parley()
        if world_logger is not None:
            world_logger.log(world)
        if opt['display_examples']:
            # display examples
            print(world.display() + '\n~~')
        if log_time.time() > log_every_n_secs:
            report = world.report()
            text, report = log_time.log(report.get('exs', 0),
                                        min(max_cnt, total_cnt), report)
            logging.info(text)

    if world_logger is not None:
        # dump world acts to file
        world_logger.reset()  # add final acts to logs
        if is_distributed():
            rank = get_rank()
            base_outfile, extension = os.path.splitext(task_opt['world_logs'])
            outfile = base_outfile + f'_{rank}' + extension
        else:
            outfile = task_opt['world_logs']
        world_logger.write(outfile, world, file_format=opt['save_format'])

    report = aggregate_unnamed_reports(all_gather_list(world.report()))

    if isinstance(world.agents, list) and len(world.agents) > 1:
        classifier_agent = world.agents[CLASSIFIER_AGENT]
        if hasattr(classifier_agent, 'calc_auc') and classifier_agent.calc_auc:
            for class_indices, curr_auc in zip(
                    classifier_agent.auc_class_indices, classifier_agent.aucs):
                report[
                    f'AUC_{classifier_agent.class_list[class_indices]}'] = curr_auc
            classifier_agent.reset_auc()
            # for safety measures
            agent.reset_auc()
    world.reset()
    return report
Example #20
def verify(opt):
    if opt['datatype'] == 'train':
        logging.warning('changing datatype from train to train:ordered')
        opt['datatype'] = 'train:ordered'

    # create repeat label agent and assign it to the specified task
    opt['fixed_response'] = None
    agent = FixedResponseAgent(opt)
    world = create_task(opt, agent)
    opt.log()

    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    dictionary = DictionaryAgent(opt)
    ignore_tokens = opt.get('ignore_tokens', '').split(',')

    counts = {}
    for t in {'input', 'labels', 'both'}:
        counts[f'{t}/tokens'] = 0
        counts[f'{t}/utterances'] = 0
        counts[f'{t}/avg_utterance_length'] = None
        counts[f'{t}/unique_tokens'] = 0
        counts[f'{t}/unique_utterances'] = 0
        # for counting the stats..
        counts[f'{t}/token_dict'] = {}
        counts[f'{t}/utterance_dict'] = {}

    def tokenize(txt):
        return dictionary.tokenize(txt)

    def keep_token(t):
        for s in ignore_tokens:
            if s != '' and s in t:
                return False
        return True

    # max number of examples to evaluate
    max_cnt = opt['num_examples'] if opt['num_examples'] > 0 else float('inf')
    cnt = 0

    # Show some example dialogs.
    while not world.epoch_done() and world.total_exs < max_cnt:
        world.parley()
        act = world.get_acts()[opt.get('agent')]
        for itype in {'input', 'labels'}:
            if itype == 'input':
                if opt.get('new_line_new_utt'):
                    txts = act.get('text').split('\n')
                else:
                    txts = [act.get('text')]
            else:
                txts = act.get('labels', act.get('eval_labels', ['']))

            for txt in txts:
                tokens = tokenize(txt)
                retxt = [t for t in tokens if keep_token(t)]
                counts[f'{itype}/tokens'] += len(retxt)
                counts['both/tokens'] += len(retxt)
                counts[f'{itype}/utterances'] += 1
                counts['both/utterances'] += 1
                counts[f'{itype}/avg_utterance_length'] += AverageMetric(
                    len(retxt), 1)
                counts[f'both/avg_utterance_length'] += AverageMetric(
                    len(retxt), 1)
                for t in retxt:
                    if t not in counts[f'{itype}/token_dict']:
                        counts[f'{itype}/unique_tokens'] += 1
                        counts[f'{itype}/token_dict'][t] = True
                    if t not in counts['both/token_dict']:
                        counts['both/unique_tokens'] += 1
                        counts['both/token_dict'][t] = True
                retxt = ' '.join(retxt)
                if retxt not in counts[f'{itype}/utterance_dict']:
                    counts[f'{itype}/unique_utterances'] += 1
                    counts[f'{itype}/utterance_dict'][retxt] = True
                if retxt not in counts['both/utterance_dict']:
                    counts['both/unique_utterances'] += 1
                    counts['both/utterance_dict'][retxt] = True

        if log_time.time() > log_every_n_secs:
            report = _report(world, counts)
            cnt = report.pop('exs')
            text, log = log_time.log(cnt, world.num_examples(), report)
            logging.info(text)

    try:
        # print dataset size if available
        logging.info(f'loaded {world.num_episodes()} episodes with a total '
                     f'of {world.num_examples()} examples')
    except AttributeError:
        pass

    retval = _report(world, counts)
    retval.pop('exs')
    return retval
Example #21
def learn_arora(opt):
    """
    Go through ConvAI2 data and collect word counts, from which the unigram
    probability distribution is computed. Use those probs to compute weighted
    sentence embeddings for all utterances, then compute the first principal
    component.

    Save all info to arora.pkl file.
    """
    arora_file = os.path.join(opt['datapath'], 'controllable_dialogue',
                              'arora.pkl')

    opt['task'] = 'fromfile:parlaiformat'
    opt['log_every_n_secs'] = 2

    print('Getting word counts from ConvAI2 train set...')
    opt['datatype'] = 'train:ordered'
    opt['fromfile_datapath'] = os.path.join(opt['datapath'],
                                            'controllable_dialogue',
                                            'ConvAI2_parlaiformat',
                                            'train.txt')
    # No need to count inputs: the ConvAI2 train set reverses every convo,
    # so each utterance already appears as a reply:
    word_counter_train, total_count_train, all_utts_train = get_word_counts(
        opt, count_inputs=False)

    print('Getting word counts from ConvAI2 val set...')
    opt['datatype'] = 'valid'
    opt['fromfile_datapath'] = os.path.join(opt['datapath'],
                                            'controllable_dialogue',
                                            'ConvAI2_parlaiformat',
                                            'valid.txt')
    # Do count inputs: the ConvAI2 val set doesn't reverse convos, so input
    # utterances would otherwise be missed:
    word_counter_valid, total_count_valid, all_utts_valid = get_word_counts(
        opt, count_inputs=True)

    # Merge word counts
    word_counter = word_counter_train
    for word, count in word_counter_valid.items():
        word_counter[word] += count
    total_count = total_count_train + total_count_valid

    # Merge all_utts
    all_utts = all_utts_train + all_utts_valid

    # Compute unigram prob for every word
    print("Computing unigram probs for all words...")
    word2prob = {w: c / total_count for w, c in word_counter.items()}

    # Settings for sentence embedder
    arora_a = 0.0001
    glove_name = '840B'
    glove_dim = 300
    glove_cache = modelzoo_path(opt['datapath'], 'models:glove_vectors')

    # Embed every sentence, without removing first singular value
    print('Embedding all sentences...')
    sent_embedder = SentenceEmbedder(
        word2prob,
        arora_a,
        glove_name,
        glove_dim,
        first_sv=None,
        glove_cache=glove_cache,
    )
    utt_embs = []
    log_timer = TimeLogger()
    for n, utt in enumerate(all_utts):
        utt_emb = sent_embedder.embed_sent(utt.split(), rem_first_sv=False)
        utt_embs.append(utt_emb)
        if log_timer.time() > opt['log_every_n_secs']:
            text, _log = log_timer.log(n, len(all_utts))
            print(text)

    # Use SVD to calculate singular vector
    # https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.linalg.svd.html
    print('Calculating SVD...')
    utt_embs = np.stack(utt_embs, axis=0)  # shape (num_embs, glove_dim)
    U, s, V = np.linalg.svd(utt_embs, full_matrices=False)
    first_sv = V[0, :]  # first row of V. shape (glove_dim)

    # Remove singular vector from all embs to get complete Arora-style sent embs
    print('Removing singular vec from all sentence embeddings...')
    utt_embs_adj = [
        remove_first_sv(torch.Tensor(emb), torch.Tensor(first_sv)).numpy()
        for emb in utt_embs
    ]  # list of np arrays shape (glove_dim)

    # Make dict mapping ConvAI2 dataset utterances to Arora sent emb
    # We save this to file for convenience (e.g. if you want to inspect)
    utt2emb = {utt: emb for (utt, emb) in zip(all_utts, utt_embs_adj)}

    # Save unigram distribution, first singular value, hyperparameter value for a,
    # info about GloVe vectors used, and full dict of utt->emb to file
    print("Saving Arora embedding info to %s..." % arora_file)
    with open(arora_file, "wb") as f:
        pickle.dump(
            {
                'word2prob':
                word2prob,  # dict: string to float between 0 and 1
                'first_sv': first_sv,  # np array shape (glove_dim)
                'arora_a': arora_a,  # float, 0.0001
                'glove_name': glove_name,  # string, '840B'
                'glove_dim': glove_dim,  # int, 300
                'utt2emb':
                utt2emb,  # dict: string to np array shape (glove_dim)
            },
            f,
        )
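
Reading the pickle back is symmetric with the dump above:

import pickle

with open(arora_file, 'rb') as f:
    arora_data = pickle.load(f)
word2prob = arora_data['word2prob']  # unigram probabilities
first_sv = arora_data['first_sv']    # first singular vector, shape (glove_dim,)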
Example #22
def eval_wordstat(opt):
    """
    Evaluates a model.

    :param opt: tells the evaluation function how to run
    """
    random.seed(42)

    # Setup control information
    initialize_control_information(opt)

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)

    if opt.get('external_dict'):
        print('[ Using external dictionary from: {} ]'.format(opt['external_dict']))
        dict_opt = copy.deepcopy(opt)
        dict_opt['dict_file'] = opt['external_dict']
        dictionary = DictionaryAgent(dict_opt)
    else:
        print('[ Using model bundled dictionary ]')
        dictionary = agent.dict

    batch_size = opt['batchsize']

    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    data = {}  # This will be written to the output json file
    data['opt'] = agent.opt  # Save the opt to json

    # Determine the output filename
    if opt['gold_response']:  # Special output file for gold response
        model_dir, _ = os.path.split(opt.get('model_file'))
        outfile = os.path.join(model_dir, 'goldresponse')
        if opt['use_reply'] != 'label':
            raise ValueError(
                'You should set --use-reply label (not --use-reply model) '
                'when measuring goldresponse stats'
            )
    else:
        outfile = "%s.%s.%s.%s" % (
            opt.get('model_file'),
            opt.get('datatype'),
            "use%sreply" % agent.opt['use_reply'],
            "beam%i" % agent.opt['beam_size'],
        )
        if agent.opt['beam_size'] > 1:
            outfile += ".beamminnbest%i" % agent.opt['beam_min_n_best']
        if len(agent.control_settings) > 0:
            outfile += ".setcontrols:" + "_".join(
                [
                    "%s%s" % (c, str(agent.control_settings[c]['set_value']))
                    for c in sorted(agent.control_settings.keys())
                ]
            )
        if agent.opt['beam_reorder'] not in ['none', False]:
            outfile += ".beamreorder_%s" % agent.opt['beam_reorder']
        if len(agent.wd_features) > 0:
            sorted_bfw = sorted(
                list(zip(agent.wd_features, agent.wd_wts)), key=lambda x: x[0]
            )
            outfile += ".WDfeatures:" + "_".join(
                ["%s%s" % (f, str(w)) for f, w in sorted_bfw]
            )
    if opt['num_examples'] != -1:
        outfile += ".numex%i" % opt['num_examples']
    outfile += ".wordstats.json"
    print("\nOutfile: %s\n" % outfile)

    cnt = 0
    word_statistics = {
        'mean_wlength': [],  # list of length (in words) of utterances
        'mean_clength': [],  # list of length (in chars) of utterances
        'freqs_cnt': Counter(),  # Counter for word frequencies, bucketed
        'word_cnt': 0,  # total number of words in all utterances
        'pred_list': [],  # list of generated utterances after applying normalize_answer
        'pure_pred_list': [],  # list of generated utterances
        'context_list': [],  # list of text inputs (persona and conversation history)
    }
    bins = [int(i) for i in opt['freq_bins'].split(',')]

    # This dictionary records all the sentence-level controllable attributes
    # For each attribute, we have a list of all the values
    sent_attrs = {attr: [] for attr in ATTR2SENTSCOREFN.keys()}  # str to list of floats

    # histories will be a list of ConvAI2History objects
    histories = []

    def process_prediction(prediction, word_statistics):
        word_statistics['pred_list'].append(normalize_answer(prediction))
        freqs, _cnt, wlength, clength = get_word_stats(
            prediction, dictionary, bins=bins
        )
        word_statistics['word_cnt'] += _cnt
        word_statistics['mean_wlength'].append(wlength)
        word_statistics['mean_clength'].append(clength)
        word_statistics['freqs_cnt'] += Counter(freqs)
        return word_statistics

    t0 = time.time()
    while not world.epoch_done():
        world.parley()
        # orig eval_wordstat.py handles bsz=1 but for simplicity we assume bsz>1
        assert batch_size != 1
        for w in world.worlds:
            try:
                try:
                    response_act = w.acts[-1]
                    prediction = response_act['text']
                except KeyError:
                    continue
                if opt['gold_response']:
                    # If we're measuring gold response, use eval_label as prediction
                    prediction = w.acts[0]['eval_labels'][0]
                    response_act = {'text': prediction}
                word_statistics['context_list'].append(w.acts[0]['text'])
                word_statistics['pure_pred_list'].append(prediction)
            except IndexError:
                continue
            cnt += 1
            word_statistics = process_prediction(prediction, word_statistics)

            # Compute and record sentence-level attributes
            history = ConvAI2History(w.acts[0]['text'])
            histories.append(history)
            sent_attrs = update_sent_attr_stats(sent_attrs, history, prediction)

        # Periodically log some info
        if log_time.time() > log_every_n_secs:
            report = world.report()
            text, report = log_time.log(report['exs'], world.num_examples(), report)
            print(text)

        if opt['num_examples'] > 0 and cnt >= opt['num_examples']:
            break
    if world.epoch_done():
        print("EPOCH DONE")
    print("Time to process %i examples: %f seconds" % (cnt, time.time() - t0))

    # Compute percent unique
    # Note this is w.r.t. normalized pred_list not original pure_pred_list
    unique_list = []
    cntr = Counter(word_statistics['pred_list'])
    for k, v in cntr.items():
        if v == 1:
            unique_list.append(k)
    unique_percent = len(unique_list) / len(word_statistics['pred_list']) * 100

    # Print a final report
    report = world.report()
    if opt['gold_response']:
        report['ppl'] = 0.0  # For gold responses, overwrite the perplexity
    print(report)

    # Put all information in data dict
    data['unique_percent'] = unique_percent  # percent of all responses that are unique
    data['word_statistics'] = word_statistics  # word stats, as in orig eval_wordstat
    data['report'] = report  # the final report
    data['histories'] = [
        (hist.persona_lines, hist.partner_utts, hist.own_utts) for hist in histories
    ]  # history for each example
    data['sent_attrs'] = sent_attrs  # all sentence attribute values for responses

    # Write data to outfile
    print("Writing to %s..." % outfile)
    with open(outfile, 'w') as f:
        json.dump(data, f)
Example #23
def dump_data(opt):
    """
    Dump task data to ACUTE-Eval.
    """
    # create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)
    task = opt.get('task')
    speaker_0_id = opt.get('speaker_0_id') or f'{task}_as_human'
    speaker_1_id = opt.get('speaker_1_id') or f'{task}_as_model'
    if opt['outfile'] is None:
        outfile = tempfile.mkstemp(
            prefix='{}_{}_'.format(opt['task'], opt['datatype']),
            suffix='.txt',
        )[1]
    else:
        outfile = opt['outfile']

    num_episodes = (world.num_episodes() if opt['num_episodes'] == -1 else min(
        opt['num_episodes'], world.num_episodes()))
    log_timer = TimeLogger()

    print(f'[ starting to convert, saving output to {outfile} ]')
    dialogues = []
    for _ in range(num_episodes):
        episode = []
        episode_done = False
        while not episode_done:
            world.parley()
            acts = world.get_acts()
            text = acts[0].get('text')
            split_text = text.split('\n')
            label = random.choice(acts[0].get('labels',
                                              acts[0].pop('eval_labels',
                                                          None)))
            if not episode and opt.get('prepended_context'):
                # first turn
                context = split_text[:-1]
                text = split_text[-1]
                context_turn = [{
                    'text': context,
                    'episode_done': False,
                    'id': 'context'
                } for _ in range(2)]
                episode.append(context_turn)
            turn = [
                {
                    'text': text,
                    'episode_done': False,
                    'id': speaker_0_id
                },
                {
                    'text': label,
                    'episode_done': False,
                    'id': speaker_1_id
                },
            ]
            episode.append(turn)
            if acts[0].get('episode_done', False):
                episode[-1][-1]['episode_done'] = True
                episode_done = True
                dialogues.append(episode)

            if log_timer.time() > opt['log_every_n_secs']:
                text, _log = log_timer.log(world.total_parleys,
                                           world.num_examples())
                print(text)

        if world.epoch_done():
            break

    Conversations.save_conversations(dialogues, outfile, opt)
Example #24
def detect(opt, printargs=None, print_parser=None):
    """
    Checks a task for offensive language.
    """
    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    random.seed(42)

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)
    if opt['safety'] == 'string_matcher' or opt['safety'] == 'all':
        offensive_string_matcher = OffensiveStringMatcher()
    if opt['safety'] == 'classifier' or opt['safety'] == 'all':
        offensive_classifier = OffensiveLanguageClassifier()

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    stats = {
        'bad_words': [],
        'bad_words_cnt': 0,
        'string_offensive': 0,
        'classifier_offensive': 0,
        'total_offensive': 0,
        'total': 0,
    }

    def report(world, stats):
        report = world.report()
        log = {
            'word_offenses': stats['bad_words_cnt'],
            'classifier_offenses%': 100 * (stats['classifier_offensive'] / stats['total']),
            'string_offenses%': 100 * (stats['string_offensive'] / stats['total']),
            'total_offenses%': 100 * (stats['total_offensive'] / stats['total']),
        }
        text, log = log_time.log(report['exs'], world.num_examples(), log)
        print(text)

    def classify(text, stats):
        offensive = False
        stats['total'] += 1
        if opt['safety'] == 'string_matcher' or opt['safety'] == 'all':
            bad_words = offensive_string_matcher.contains_offensive_language(text)
            if bad_words:
                stats['string_offensive'] += 1
                offensive = True
                stats['bad_words'].append(bad_words)
        if opt['safety'] == 'classifier' or opt['safety'] == 'all':
            if text in offensive_classifier:
                stats['classifier_offensive'] += 1
                offensive = True
        if offensive:
            stats['total_offensive'] += 1

    while not world.epoch_done():
        world.parley()
        stats['bad_words'] = []
        for a in world.acts:
            text = a.get('text', '')
            classify(text, stats)
            labels = a.get('labels', a.get('eval_labels', ''))
            for l in labels:
                classify(l, stats)
        if len(stats['bad_words']) > 0 and opt['display_examples']:
            print(world.display())
            print("[Offensive words detected:]", ', '.join(stats['bad_words']))
            print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
        stats['bad_words_cnt'] += len(stats['bad_words'])
        if log_time.time() > log_every_n_secs:
            report(world, stats)

    if world.epoch_done():
        print("EPOCH DONE")
    report(world, stats)
    return world.report()
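
One caveat: `report` divides by `stats['total']`, so calling it before any text has been classified raises `ZeroDivisionError`. A defensive percentage helper might look like this (a sketch, not part of the original script):

def safe_pct(numerator, total):
    # Guard against division by zero when no examples have been seen yet.
    return 100 * numerator / total if total > 0 else 0.0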
Example #25
def verify(opt, printargs=None, print_parser=None):
    if opt['datatype'] == 'train':
        logging.warn('changing datatype from train to train:ordered')
        opt['datatype'] = 'train:ordered'

    # create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)

    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    dictionary = DictionaryAgent(opt)
    ignore_tokens = opt.get('ignore_tokens', '').split(',')

    counts = {}
    for t in {'input', 'labels', 'both'}:
        counts['tokens_in_' + t] = 0
        counts['utterances_in_' + t] = 0
        counts['avg_utterance_length_in_' + t] = 0
        counts['unique_tokens_in_' + t] = 0
        counts['unique_utterances_in_' + t] = 0
        # helper dicts for counting unique tokens and utterances
        counts['token_dict_' + t] = {}
        counts['utterance_dict_' + t] = {}

    def tokenize(txt):
        return dictionary.tokenize(txt)

    def keep_token(t):
        for s in ignore_tokens:
            if s != '' and s in t:
                return False
        return True

    # max number of examples to evaluate
    max_cnt = opt['num_examples'] if opt['num_examples'] > 0 else float('inf')
    cnt = 0

    # Show some example dialogs.
    while not world.epoch_done() and cnt < max_cnt:
        cnt += opt.get('batchsize', 1)
        world.parley()
        act = world.get_acts()[opt.get('agent')]
        for itype in {'input', 'labels'}:
            if itype == 'input':
                if opt.get('new_line_new_utt'):
                    txts = act.get('text').split('\n')
                else:
                    txts = [act.get('text')]
            else:
                txts = act.get('labels', act.get('eval_labels', ['']))

            for txt in txts:
                tokens = tokenize(txt)
                retxt = []
                for t in tokens:
                    if keep_token(t):
                        retxt.append(t)
                counts['tokens_in_' + itype] += len(retxt)
                counts['tokens_in_' + 'both'] += len(retxt)
                counts['utterances_in_' + itype] += 1
                counts['utterances_in_' + 'both'] += 1
                counts['avg_utterance_length_in_' + itype] = (
                    counts['tokens_in_' + itype] / counts['utterances_in_' + itype]
                )
                counts['avg_utterance_length_in_' + 'both'] = (
                    counts['tokens_in_' + 'both'] / counts['utterances_in_' + 'both']
                )
                for t in retxt:
                    if t not in counts['token_dict_' + itype]:
                        counts['unique_tokens_in_' + itype] += 1
                        counts['token_dict_' + itype][t] = True
                    if t not in counts['token_dict_' + 'both']:
                        counts['unique_tokens_in_' + 'both'] += 1
                        counts['token_dict_' + 'both'][t] = True
                retxt = ' '.join(retxt)
                if retxt not in counts['utterance_dict_' + itype]:
                    counts['unique_utterances_in_' + itype] += 1
                    counts['utterance_dict_' + itype][retxt] = True
                if retxt not in counts['utterance_dict_' + 'both']:
                    counts['unique_utterances_in_' + 'both'] += 1
                    counts['utterance_dict_' + 'both'][retxt] = True

        if log_time.time() > log_every_n_secs:
            text, log = report(world, counts, log_time)
            if print_parser:
                logging.info(text)

    try:
        # print dataset size if available
        logging.info(
            f'loaded {world.num_episodes()} episodes with a total '
            f'of {world.num_examples()} examples'
        )
    except Exception:
        pass
    return report(world, counts, log_time)
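
Note that `verify` delegates to a module-level `report` helper that is not shown in this snippet. A plausible sketch, assuming it strips the internal bookkeeping dicts before handing the counters to the TimeLogger:

def report(world, counts, log_time):
    # Hypothetical reconstruction: drop the helper token/utterance dicts
    # and log only the scalar counters.
    stats = {
        k: v for k, v in counts.items()
        if not k.startswith(('token_dict_', 'utterance_dict_'))
    }
    text, log = log_time.log(world.report().get('exs', 0), world.num_examples(), stats)
    return text, log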
Example #26
def detect(opt):
    """
    Checks a task for offensive language.
    """
    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)
    agent.opt.log()
    if opt['safety'] == 'string_matcher' or opt['safety'] == 'all':
        offensive_string_matcher = OffensiveStringMatcher()
    if opt['safety'] == 'classifier' or opt['safety'] == 'all':
        offensive_classifier = OffensiveLanguageClassifier()

    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    stats = {
        'bad_words': [],
        'bad_words_cnt': 0,
        'string_offensive': 0,
        'classifier_offensive': 0,
        'total_offensive': 0,
        'total': 0,
    }

    def report(world, stats):
        report = world.report()
        log = {
            'word_offenses': stats['bad_words_cnt'],
            'classifier_offenses%': 100 * (stats['classifier_offensive'] / stats['total']),
            'string_offenses%': 100 * (stats['string_offensive'] / stats['total']),
            'total_offenses%': 100 * (stats['total_offensive'] / stats['total']),
        }
        text, log = log_time.log(report['exs'], world.num_examples(), log)
        logging.info(text)
        return log

    def classify(text, stats):
        offensive = False
        stats['total'] += 1
        if opt['safety'] == 'string_matcher' or opt['safety'] == 'all':
            bad_words = offensive_string_matcher.contains_offensive_language(text)
            if bad_words:
                stats['string_offensive'] += 1
                offensive = True
                stats['bad_words'].append(bad_words)
        if opt['safety'] == 'classifier' or opt['safety'] == 'all':
            if text in offensive_classifier:
                stats['classifier_offensive'] += 1
                offensive = True
        if offensive:
            stats['total_offensive'] += 1

    while not world.epoch_done():
        world.parley()
        stats['bad_words'] = []
        for a in world.acts:
            text = a.get('text', '')
            classify(text, stats)
            labels = a.get('labels', a.get('eval_labels', ''))
            for l in labels:
                classify(l, stats)
        if len(stats['bad_words']) > 0 and opt['display_examples']:
            logging.info(world.display())
            logging.info("Offensive words detected: {}".format(', '.join(
                stats['bad_words'])))
        stats['bad_words_cnt'] += len(stats['bad_words'])
        if log_time.time() > log_every_n_secs:
            report(world, stats)

    if world.epoch_done():
        logging.info("epoch done")
    return report(world, stats)
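
For reference, both detectors can also be exercised directly; the membership test and the `contains_offensive_language` call below mirror exactly how `classify` uses them above (a sketch; the word lists and classifier model are whatever ParlAI provides by default):

matcher = OffensiveStringMatcher()
classifier = OffensiveLanguageClassifier()

candidate = 'some candidate utterance'
# Returns the matched phrase (or None), as in `classify` above.
if matcher.contains_offensive_language(candidate):
    print('string matcher flagged this text')
# The classifier supports membership testing, as in `classify` above.
if candidate in classifier:
    print('classifier flagged this text')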
Example #27
def bucket_data(opt):
    # create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)

    if opt['num_examples'] == -1:
        num_examples = world.num_examples()
    else:
        num_examples = opt['num_examples']
    log_timer = TimeLogger()

    assert opt['control'] != ''
    ctrl = opt['control']

    num_buckets = opt['num_buckets']

    ctrl_vals = []  # list of floats

    for _ in range(num_examples):
        world.parley()
        world.acts[0]['labels'] = world.acts[0].get(
            'labels', world.acts[0].pop('eval_labels', None))

        if ctrl not in world.acts[0]:
            raise Exception(
                "Error: control %s isn't in the data. available keys: %s"
                % (ctrl, ', '.join(world.acts[0].keys()))
            )
        ctrl_val = world.acts[0][ctrl]
        if ctrl_val == "None":
            assert ctrl == 'lastuttsim'
            ctrl_val = None
        else:
            ctrl_val = float(ctrl_val)
        if ctrl == 'avg_nidf':
            assert ctrl_val >= 0
            assert ctrl_val <= 1
        elif ctrl == 'question':
            assert ctrl_val in [0, 1]
        elif ctrl == 'lastuttsim':
            if ctrl_val is not None:
                assert ctrl_val >= -1
                assert ctrl_val <= 1
        else:
            raise Exception('Unexpected ctrl name: %s' % ctrl)
        ctrl_vals.append(ctrl_val)

        if log_timer.time() > opt['log_every_n_secs']:
            text, _log = log_timer.log(world.total_parleys,
                                       world.num_examples())
            print(text)

        if world.epoch_done():
            print('EPOCH DONE')
            break

    if ctrl == 'lastuttsim':
        num_nones = len([v for v in ctrl_vals if v is None])
        ctrl_vals = [v for v in ctrl_vals if v is not None]
        print("Have %i Nones for lastuttsim; these have been removed "
              "for bucket calculation" % num_nones)

    print('Collected %i control vals between %.6f and %.6f' %
          (len(ctrl_vals), min(ctrl_vals), max(ctrl_vals)))

    # Calculate bucket lower bounds
    print('Calculating lowerbounds for %i buckets...' % num_buckets)
    ctrl_vals = sorted(ctrl_vals)
    lb_indices = [
        int(len(ctrl_vals) * i / num_buckets) for i in range(num_buckets)
    ]
    lbs = [ctrl_vals[idx] for idx in lb_indices]
    print('\nBucket lowerbounds for control %s: ' % ctrl)
    print(lbs)

    # Calculate the actual bucket sizes
    bucket_sizes = Counter()
    bucket_ids = [sort_into_bucket(ctrl_val, lbs) for ctrl_val in ctrl_vals]
    bucket_sizes.update(bucket_ids)
    print('\nBucket sizes: ')
    for bucket_id in sorted(bucket_sizes.keys()):
        print("%i: %i" % (bucket_id, bucket_sizes[bucket_id]))