示例#1
0
def _load_personas(opt):
    """Collect the distinct persona strings found in the ConvAI2 task data.

    A ``convai2:both`` world is built from a copy of ``opt`` and stepped
    through until its epoch is done. The text of the first message of every
    episode is scanned for lines starting with ``"partner's persona:"`` or
    ``'your persona:'``; each group of lines is joined with newlines into one
    persona string (partner lines are rewritten to the ``'your persona:'``
    prefix).

    :param opt: options mapping; it is copied, so the caller's object is not
        modified. ``opt['datatype']`` must be a string.
    :return: list of unique, non-empty persona strings (unordered).
    """
    print('[ loading personas.. ]')
    # Create ConvAI2 data so we can assign personas.
    convai2_opt = opt.copy()
    convai2_opt['task'] = 'convai2:both'
    # NOTE(review): presumably plain train data never reports epoch_done, as
    # the comment in load_openers says; evalmode is used so the loop ends.
    if convai2_opt['datatype'].startswith('train'):
        convai2_opt['datatype'] = 'train:evalmode'
    convai2_opt['interactive_task'] = False
    convai2_opt['selfchat_task'] = False
    convai2_agent = FixedResponseAgent({'fixed_response': None})
    convai2_world = create_task(convai2_opt, convai2_agent)

    partner_prefix = "partner's persona:"
    own_prefix = 'your persona:'

    personas = set()
    # The very first message already opens an episode, so start as True;
    # afterwards a new episode begins right after an 'episode_done' message.
    is_first_turn = True
    while not convai2_world.epoch_done():
        convai2_world.parley()
        msg = convai2_world.get_acts()[0]
        if is_first_turn:
            a1_persona = []
            a2_persona = []
            for line in msg.get('text', '').split('\n'):
                if line.startswith(partner_prefix):
                    a1_persona.append(line.replace(partner_prefix, own_prefix))
                elif line.startswith(own_prefix):
                    a2_persona.append(line)
            # Skip sides with no persona lines so '' is never returned as a
            # persona.
            for persona_lines in (a1_persona, a2_persona):
                if persona_lines:
                    personas.add('\n'.join(persona_lines))
        is_first_turn = msg.get('episode_done', False)

    print(f'[ loaded {len(personas)} personas ]')
    return list(personas)
示例#2
0
def load_openers(opt) -> Optional[List[str]]:
    """Gather the distinct opening messages of a task's episodes.

    The base task (the part of ``opt['task']`` before the first ``':'``) is
    instantiated from a deep copy of ``opt`` and stepped through until its
    epoch is done; the non-empty text of each episode's first message is
    collected.

    :param opt: options mapping; reads ``task`` and ``datatype``. It is deep
        copied, so the caller's object is left untouched.
    :return: ``None`` when the base task is ``self_chat``; otherwise a list of
        unique opener strings (unordered).
    """
    task_name = opt['task'].partition(':')[0]
    if task_name == 'self_chat':
        # TODO(#2284): Load default openers from s3
        return None

    print('[ loading conversation openers... ]')

    # Build a throwaway task configuration so openers can be read from data.
    dummy_opt = copy.deepcopy(opt)
    dummy_opt['task'] = task_name

    # Plain train data loops forever; evalmode stops after a single epoch.
    original_datatype = dummy_opt['datatype']
    needs_evalmode = (
        'train' in original_datatype and 'evalmode' not in original_datatype
    )
    if needs_evalmode:
        dummy_opt['datatype'] = f'{original_datatype}:evalmode'
    dummy_opt['interactive_task'] = False
    dummy_opt['selfchat_task'] = False
    dummy_opt['fixed_response'] = None
    world = create_task(dummy_opt, FixedResponseAgent(dummy_opt))

    # Walk the data once, remembering the first message of every episode.
    found = set()
    at_episode_start = True
    while not world.epoch_done():
        world.parley()
        act = world.get_acts()[0]
        text = act.get('text')
        if at_episode_start and text:
            found.add(text)
        # The message after an 'episode_done' one starts the next episode.
        at_episode_start = act.get('episode_done', False)

    print(f'[ loaded {len(found)} openers ]')
    return list(found)
示例#3
0
def display_data(opt):
    """Print up to ``opt['num_examples']`` examples from the configured task.

    :param opt: options object; must support item access, ``get`` and
        ``log()``. It is mutated: ``':ordered'`` is appended to a train
        datatype that lacks it, and ``fixed_response`` is set to ``None``.
    :return: ``None``; output goes to stdout and the logger.
    """
    # Ordered data prevents the same example from being shown twice.
    datatype = opt['datatype']
    if 'train' in datatype and 'ordered' not in datatype:
        opt['datatype'] = f"{datatype}:ordered"

    # Pair the task with a dummy agent.
    opt.log()
    opt['fixed_response'] = None
    world = create_task(opt, FixedResponseAgent(opt))

    # Show some example dialogs.
    turn = 0
    for _ in range(opt['num_examples']):
        world.parley()

        # NOTE: to inspect the data here instead of calling world.display(),
        # world.acts[0] can be read directly (see simple_display).
        full_display = opt.get('verbose', False) or opt.get(
            'display_add_fields', ''
        )
        if not full_display:
            simple_display(opt, world, turn)
            # The turn counter restarts whenever an episode finishes.
            turn = 0 if world.get_acts()[0]['episode_done'] else turn + 1
        else:
            print(world.display() + '\n~~')

        if world.epoch_done():
            logging.info('epoch done')
            break

    try:
        # Report the dataset size when the world can provide it.
        episode_count = world.num_episodes()
        example_count = world.num_examples()
        logging.info(
            f'loaded {episode_count} episodes with a '
            f'total of {example_count} examples'
        )
    except Exception:
        # Best effort only: any failure to obtain the size is ignored.
        pass
示例#4
0
def verify(opt):
    """Tally token and utterance statistics over a task's data.

    Steps through the world built from ``opt`` and, using the act of the
    agent at index ``opt.get('agent')``, counts total and unique tokens and
    utterances for the input text, for the labels, and for both combined.

    :param opt: options object; reads ``datatype``, ``num_examples``,
        ``agent``, ``ignore_tokens``, ``new_line_new_utt`` and
        ``log_every_n_secs``. It is mutated: ``fixed_response`` is set to
        ``None`` and a datatype of exactly ``'train'`` becomes
        ``'train:ordered'``.
    :return: the result of ``_report(world, counts)`` with its ``'exs'``
        entry removed.
    """
    if opt['datatype'] == 'train':
        logging.warn('changing datatype from train to train:ordered')
        opt['datatype'] = 'train:ordered'

    # Pair the task with a dummy agent.
    opt['fixed_response'] = None
    world = create_task(opt, FixedResponseAgent(opt))
    opt.log()

    # A non-positive interval disables periodic logging.
    log_interval = opt.get('log_every_n_secs', -1)
    log_interval = float('inf') if log_interval <= 0 else log_interval
    timer = TimeLogger()

    dictionary = DictionaryAgent(opt)
    ignore_tokens = opt.get('ignore_tokens').split(',')

    counts = {}
    for scope in {'input', 'labels', 'both'}:
        counts.update(
            {
                f'{scope}/tokens': 0,
                f'{scope}/utterances': 0,
                # Starts as None and is later combined with AverageMetric via
                # `+=`; presumably AverageMetric accepts None as the other
                # operand -- TODO confirm.
                f'{scope}/avg_utterance_length': None,
                f'{scope}/unique_tokens': 0,
                f'{scope}/unique_utterances': 0,
                # Membership tables backing the unique counters.
                f'{scope}/token_dict': {},
                f'{scope}/utterance_dict': {},
            }
        )

    def is_ignored(token):
        # A token is dropped when any non-empty ignore pattern occurs in it.
        return any(
            pattern != '' and pattern in token for pattern in ignore_tokens
        )

    # Maximum number of examples to evaluate; non-positive means no limit.
    example_limit = float('inf')
    if opt['num_examples'] > 0:
        example_limit = opt['num_examples']

    while not world.epoch_done() and world.total_exs < example_limit:
        world.parley()
        act = world.get_acts()[opt.get('agent')]
        for kind in {'input', 'labels'}:
            if kind != 'input':
                utterances = act.get('labels', act.get('eval_labels', ['']))
            elif opt.get('new_line_new_utt'):
                utterances = act.get('text').split('\n')
            else:
                utterances = [act.get('text')]

            for utterance in utterances:
                kept = [
                    tok
                    for tok in dictionary.tokenize(utterance)
                    if not is_ignored(tok)
                ]
                n_kept = len(kept)

                # Every utterance counts towards its own kind and 'both'.
                for scope in (kind, 'both'):
                    counts[f'{scope}/tokens'] += n_kept
                    counts[f'{scope}/utterances'] += 1
                    counts[f'{scope}/avg_utterance_length'] += AverageMetric(
                        n_kept, 1
                    )

                for tok in kept:
                    for scope in (kind, 'both'):
                        seen_tokens = counts[f'{scope}/token_dict']
                        if tok not in seen_tokens:
                            counts[f'{scope}/unique_tokens'] += 1
                            seen_tokens[tok] = True

                joined = ' '.join(kept)
                for scope in (kind, 'both'):
                    seen_utterances = counts[f'{scope}/utterance_dict']
                    if joined not in seen_utterances:
                        counts[f'{scope}/unique_utterances'] += 1
                        seen_utterances[joined] = True

        if timer.time() > log_interval:
            report = _report(world, counts)
            seen_examples = report.pop('exs')
            text, _ = timer.log(seen_examples, world.num_examples(), report)
            logging.info(text)

    try:
        # Report the dataset size when the world can provide it.
        episode_count = world.num_episodes()
        example_count = world.num_examples()
        logging.info(
            f'loaded {episode_count} episodes with a total '
            f'of {example_count} examples'
        )
    except AttributeError:
        pass

    final_report = _report(world, counts)
    final_report.pop('exs')
    return final_report