Ejemplo n.º 1
0
def verify(opt, printargs=None, print_parser=None):
    """
    Walk one epoch of a task and tally structural problems in its messages.

    A ``RepeatLabelAgent`` is paired with the task named in ``opt`` and the
    first act of every parley is inspected for missing text, missing labels,
    missing or empty label candidates, and labels absent from the candidates.

    :param opt: options dict; ``opt['datatype']`` is rewritten in place from
        ``'train'`` to ``'train:ordered'``, and ``log_every_n_secs`` is read
        (non-positive values disable periodic reporting).
    :param printargs: not used by this function.
    :param print_parser: when truthy, the text of each periodic report is
        printed.
    :return: the value of ``report(world, counts, log_time)`` after the epoch.
    """
    if opt['datatype'] == 'train':
        logging.warn("changing datatype from train to train:ordered")
        opt['datatype'] = 'train:ordered'

    # The repeat-label agent simply drives the task's teacher.
    world = create_task(opt, RepeatLabelAgent(opt))

    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        # An infinite interval means the periodic report never fires.
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    counts = {
        'missing_text': 0,
        'missing_labels': 0,
        'missing_label_candidates': 0,
        'empty_string_label_candidates': 0,
        'label_candidates_with_missing_label': 0,
        'did_not_return_message': 0,
    }

    while not world.epoch_done():
        world.parley()
        act = world.acts[0]

        if not isinstance(act, Message):
            counts['did_not_return_message'] += 1

        if not ('text' in act or 'image' in act):
            warn("warning: missing text field:\n", act, opt)
            counts['missing_text'] += 1

        if not ('labels' in act or 'eval_labels' in act):
            warn("warning: missing labels/eval_labels field:\n", act, opt)
            counts['missing_labels'] += 1
        elif 'label_candidates' not in act:
            counts['missing_label_candidates'] += 1
        else:
            labels = act.get('labels', act.get('eval_labels'))
            # Maps each label to whether it has been seen among candidates.
            found_in_cands = dict.fromkeys(labels, False)
            for cand in act['label_candidates']:
                if cand == '':
                    warn("warning: empty string label_candidate:\n", act, opt)
                    counts['empty_string_label_candidates'] += 1
                if cand not in found_in_cands:
                    continue
                if found_in_cands[cand]:
                    warn(
                        "warning: label mentioned twice in candidate_labels:\n",
                        act,
                        opt,
                    )
                found_in_cands[cand] = True
            for was_found in found_in_cands.values():
                if not was_found:
                    warn("warning: label missing in candidate_labels:\n", act, opt)
                    counts['label_candidates_with_missing_label'] += 1

        if log_time.time() > log_every_n_secs:
            text, _ = report(world, counts, log_time)
            if print_parser:
                print(text)

    try:
        # Dataset size is best-effort: any failure here is ignored.
        logging.info(
            f'Loaded {world.num_episodes()} episodes with a '
            f'total of {world.num_examples()} examples'
        )
    except Exception:
        pass

    return report(world, counts, log_time)
Ejemplo n.º 2
0
    def test_check_examples(self):
        """
        Compare the first entry of the first episode of each listed task
        against a stored example message, skipping ``label_candidates``.
        """
        # (task string, expected message) pairs
        tasks_and_messages = [
            (
                "blended_skill_talk:ConvAI2PersonaTopicifier",
                {
                    'text': "your persona: i like to remodel homes.\nyour persona: i like to go hunting.\nyour persona: i like to shoot a bow.\nyour persona: my favorite holiday is halloween.\nNicholas Sparks\nhi , how are you doing ? i'm getting ready to do some cheetah chasing to stay in shape .",
                    'labels': (
                        'you must be very fast . hunting is one of my favorite hobbies .',
                    ),
                    'reward': 0,
                    'label_candidates': (
                        'my mom was single with 3 boys , so we never left the projects .',
                        'i try to wear all black every day . it makes me feel comfortable .',
                        'well nursing stresses you out so i wish luck with sister',
                        'yeah just want to pick up nba nfl getting old',
                        'i really like celine dion . what about you ?',
                        'no . i live near farms .',
                        "i wish i had a daughter , i'm a boy mom . they're beautiful boys though still lucky",
                        'yeah when i get bored i play gone with the wind my favorite movie .',
                        "hi how are you ? i'm eatingdinner with my hubby and 2 kids .",
                        'were you married to your high school sweetheart ? i was .',
                        'that is great to hear ! are you a competitive rider ?',
                        "hi , i'm doing ok . i'm abanker . how about you ?",
                        "i'm 5 years old",
                        'hi there . how are you today ?',
                        'i totally understand how stressful that can be .',
                        'yeah sometimes you do not know what you are actually watching',
                        'mother taught me to cook ! we are looking for an exterminator .',
                        'i enjoy romantic movie . what is your favorite season ? mine is summer .',
                        'editing photos takesa lot of work .',
                        'you must be very fast . hunting is one of my favorite hobbies .',
                    ),
                    'episode_done': False,
                },
            ),
            (
                "blended_skill_talk:EDPersonaTopicifier",
                {
                    'situation': 'I remember going to the fireworks with my best friend. There was a lot of people, but it only felt like us in the world.',
                    'emotion': 'sentimental',
                    'prepend_ctx': None,
                    'prepend_cand': None,
                    'deepmoji_ctx': None,
                    'deepmoji_cand': None,
                    'text': 'your persona: people hate that i obsess about the poor.\nyour persona: i like to make cellphone apps that would help heal our world.\nyour persona: i like to watch people pray together.\nyour persona: people don t like me too much but i like them anyways.\nAndroid (operating system)#Applications\nI remember going to see the fireworks with my best friend. It was the first time we ever spent time alone together. Although there was a lot of people, we felt like the only people in the world.',
                    'labels': [
                        'Was this a friend you were in love with, or just a best friend?',
                    ],
                    'episode_done': False,
                },
            ),
            (
                "blended_skill_talk:WoWPersonaTopicifier",
                {
                    'id': 'WizardDialogKnowledgeTeacher',
                    'text': "your persona: not a day goes by that i don't drink four mountain dews.\nyour persona: i enjoy movies about aliens invading the earth.\nyour persona: my favorite hobby is chess.\nyour persona: i just dyed my hair hot pink with purple highlights.\nScience fiction\n",
                    'labels': [
                        "I think science fiction is an amazing genre for anything. Future science, technology, time travel, FTL travel, they're all such interesting concepts.",
                    ],
                    'chosen_topic': 'Science fiction',
                    'episode_done': False,
                    'label_candidates': [],
                    'knowledge': 'Science fiction Science fiction (often shortened to SF or sci-fi) is a genre of speculative fiction, typically dealing with imaginative concepts such as futuristic science and technology, space travel, time travel, faster than light travel, parallel universes, and extraterrestrial life.\nScience fiction Science fiction often explores the potential consequences of scientific and other innovations, and has been called a "literature of ideas".\nScience fiction It usually avoids the supernatural, unlike the related genre of fantasy.\nScience fiction Historically, science-fiction stories have had a grounding in actual science, but now this is only expected of hard science fiction.\nScience fiction Science fiction is difficult to define, as it includes a wide range of subgenres and themes.\nScience fiction Hugo Gernsback, who suggested the term "scientifiction" for his "Amazing Stories" magazine, wrote: "By \'scientifiction\' I mean the Jules Verne, H. G. Wells and Edgar Allan Poe type of story—a charming romance intermingled with scientific fact and prophetic vision... Not only do these amazing tales make tremendously interesting reading—they are always instructive.\nScience fiction They supply knowledge... in a very palatable form... New adventures pictured for us in the scientifiction of today are not at all impossible of realization tomorrow...\n',
                    'title': 'Science fiction',
                    'checked_sentence': 'Science fiction (often shortened to SF or sci-fi) is a genre of speculative fiction, typically dealing with imaginative concepts such as futuristic science and technology, space travel, time travel, faster than light travel, parallel universes, and extraterrestrial life.',
                },
            ),
        ]

        for task_string, desired_message in tasks_and_messages:
            # Build the teacher for this task and fetch its very first entry.
            parser = setup_args()
            parser.set_defaults(task=task_string, datatype='train:ordered')
            opt = parser.parse_args([])
            world = create_task(opt, RepeatLabelAgent(opt))
            actual_message = world.get_task_agent().get(episode_idx=0, entry_idx=0)

            print(f'\nChecking {task_string}:')
            for key, expected_value in desired_message.items():
                if key == 'label_candidates':
                    # These are often created randomly and thus will vary
                    continue
                print(key)
                self.assertEqual(expected_value, actual_message[key])
            print('')
Ejemplo n.º 3
0
    def test_check_examples(self):
        """
        Check stored example messages of the blended_skill_talk teacher.

        The first entry of the second episode is compared in full for the
        train set; the second entry of the second episode is compared for
        train, valid and test, with label candidates checked only by count,
        first element and last element.
        """
        with testing_utils.tempdir() as tmpdir:
            data_path = tmpdir

            def build_teacher(overrides):
                # Create the task teacher from default args plus `overrides`.
                parser = setup_args()
                parser.set_defaults(
                    **{
                        **overrides,
                        'task': 'blended_skill_talk',
                        'datapath': data_path,
                    }
                )
                opt = parser.parse_args([])
                return create_task(opt, RepeatLabelAgent(opt)).get_task_agent()

            # Check the first entry (entry_idx==0) of the second episode for the train
            # set, in order to check the context for an episode that has a WoW topic
            # string
            first_entry_overrides = {'datatype': 'train'}
            first_entry_expected = {
                'text': "your persona: i just bought a new house with my partner.\nyour persona: i like to make my own coffee.\nLasagne\nOh, I love lasagne. I make my own noodles as well as the sauce. \nWow.  That's amazing.  I read where lasagne originated in Italy during the Middle Ages.  \nOh really!? That is interesting. I am actually italian myself.",
                'labels': [
                    "Awesome. Me and my partner just bought a house. I can't wait to cook in my kitchen.",
                ],
                'context_dataset': 'wizard_of_wikipedia',
                'free_message': 'Oh really!? That is interesting. I am actually italian myself.',
                'convai2': 'yum . i like to make lasagna and it s so good',
                'empathetic_dialogues': 'Cool. I love italian. Real italian.',
                'wizard_of_wikipedia': "Wow.  That's amazing.  I read where lasagne originated in Italy during the Middle Ages.",
                'guided_chosen_suggestion': ' ',
                'episode_done': False,
            }
            teacher = build_teacher(first_entry_overrides)
            self.assertEqual(
                teacher.get(episode_idx=1, entry_idx=0), first_entry_expected
            )

            # Check the second entry (entry_idx==1) of the second episode for each dataset
            opts_and_examples = [
                (
                    {'datatype': 'train'},
                    {
                        'text': 'Moving in a new place can be a lot of fun. Are you a good cook?',
                        'labels': [
                            'I like to think so. I love to make coffee for an after dinner treat too.',
                        ],
                        'context_dataset': 'wizard_of_wikipedia',
                        'free_message': 'Moving in a new place can be a lot of fun. Are you a good cook?',
                        'convai2': 'yes ! trying to master lasagna .',
                        'empathetic_dialogues': "See. I'm not a great cook.",
                        'wizard_of_wikipedia': 'With the training and skills I have, I can cook pretty much anything.',
                        'guided_chosen_suggestion': ' ',
                        'episode_done': False,
                    },
                ),
                (
                    {'datatype': 'valid'},
                    {
                        'text': 'I like to go mountain biking with my friends.',
                        'labels': [
                            "I have never done that.  Not really the physical activity type, but I'd be willing to give it a try, I guess",
                        ],
                        'context_dataset': 'empathetic_dialogues',
                        'free_message': 'I like to go mountain biking with my friends.',
                        'convai2': "that's so cool , i love biking",
                        'empathetic_dialogues': "Ive never been on any but I'll try it out",
                        'wizard_of_wikipedia': "That's interesting!  Most mountain biking is in the categories of Trail and Cross Country riding styles",
                        'guided_chosen_suggestion': '',
                        'label_candidates': {
                            'num_cands': 100,
                            'first': 'i work as a vet so no days off over here!',
                            'last': 'And what else? ',
                        },
                        'episode_done': False,
                    },
                ),
                (
                    {'datatype': 'test'},
                    {
                        'text': "He eats insects, leaves and sun flower seeds. It's easy. They don't need walking and cleanup is simple. Do you have any pets?",
                        'labels': [
                            'No, not at the moment.  I have 3 girls and they are enough trouble! LOL',
                        ],
                        'context_dataset': 'empathetic_dialogues',
                        'free_message': "He eats insects, leaves and sun flower seeds. It's easy. They don't need walking and cleanup is simple. Do you have any pets?",
                        'convai2': "no , i don't have any pets either .",
                        'empathetic_dialogues': 'I do not just a cat',
                        'wizard_of_wikipedia': "I actually do.  He is ten years old and loves to be outside.  He's fat and furry.",
                        'guided_chosen_suggestion': '',
                        'label_candidates': {
                            'num_cands': 100,
                            'first': "Wow, engineering, sounds impressive.  I'm sure the income will be awesome.",
                            'last': 'but the worst part is you have to clean every day and keep the flat tidy all the time.  ',
                        },
                        'episode_done': False,
                    },
                ),
            ]
            for kwargs, example in opts_and_examples:
                teacher = build_teacher(kwargs)
                actual_message = teacher.get(episode_idx=1, entry_idx=1)

                # Check for field equality
                self.assertEqual(set(actual_message.keys()), set(example.keys()))

                # Check label candidates
                if 'label_candidates' in example:
                    params = example['label_candidates']
                    actual_cands = actual_message['label_candidates']
                    self.assertEqual(len(actual_cands), params['num_cands'])
                    self.assertEqual(actual_cands[0], params['first'])
                    self.assertEqual(actual_cands[-1], params['last'])

                # Check other fields
                for key, expected_value in example.items():
                    if key == 'label_candidates':
                        continue
                    self.assertEqual(expected_value, actual_message[key])
Ejemplo n.º 4
0
def main():
    '''Main script for running an eval task against the LIGHT dataset.

    Special CLI arguments are
      --light-eval-task-type [speech, emote, action]
      --light-eval-unseen [False, True]

    This launches a task that, on a worker's first attempt, pairs them with
    an entry from the training set. Then, based on whether the worker performs
    above a specified benchmark, they will either be soft blocked from
    evaluating or allowed to try against the test set.

    Episodes are read from the `light_dialog` task up front (first with the
    parsed datatype, then with `train:stream`) and served to workers from two
    pools that are refilled from the collected lists when they run empty.
    '''
    # Get relevant arguments
    argparser = ParlaiParser(False, False)
    argparser.add_parlai_data_path()
    argparser.add_mturk_args()
    argparser.set_defaults(datatype='test:stream')
    argparser.add_argument('--light-eval-task-type',
                           default='speech',
                           help='Type of task to be evaluating')
    argparser.add_argument(
        '--light-eval-unseen',
        default=False,
        type='bool',
        help='Evaluate against the unseen test rather than the seen test')
    opt = argparser.parse_args()

    task_opt = opt.copy()
    task_opt['task'] = 'light_dialog'
    assert opt['light_eval_task_type'] in [
        'speech', 'emote', 'action'
    ], ('--light-eval-task-type must be one of speech, emote, or action')
    LABEL_TYPE = opt['light_eval_task_type']  # speech, emote, action
    # TRAIN_TURNS: minimum number of turns an episode needs to be kept.
    # TRAININGS: completed training rounds required before a worker is served
    #            from the test pool.
    # MAX_WRONG: forwarded to LightEvalTaskWorld as `max_wrong`.
    TRAIN_TURNS = 7
    TRAININGS = 1
    MAX_WRONG = 1
    if LABEL_TYPE != 'speech':
        TRAIN_TURNS = 3
        TRAININGS = 2
        MAX_WRONG = 3 if LABEL_TYPE == 'emote' else 2
    task_opt['light_label_type'] = LABEL_TYPE
    task_opt['light_use_action'] = 'all'
    task_opt['light_use_cands'] = '20'
    task_opt['light_use_emote'] = 'all'
    task_opt['light_use_objects'] = True
    task_opt['light_use_person_names'] = True
    task_opt['light_use_persona'] = 'self'
    task_opt['light_use_repeat'] = 'none'
    task_opt['light_use_setting'] = True
    task_opt['light_use_speech'] = 'all'
    task_opt['light_use_current_self_output'] = 'all'
    task_opt['light_use_clip_cands'] = 10000
    task_opt['light_unseen_test'] = task_opt['light_eval_unseen']

    def collect_episodes(task_world, stop_after=None):
        # Step `task_world` and return the episodes with >= TRAIN_TURNS turns.
        # Stops at the end of the epoch, or as soon as more than `stop_after`
        # episodes have been kept (when `stop_after` is given).
        episodes = []
        current = []
        while True:
            task_world.parley()
            current.append(task_world.acts[0].copy())
            if task_world.acts[0]['episode_done']:
                if len(current) >= TRAIN_TURNS:
                    episodes.append(current)
                current = []
            if task_world.epoch_done():
                break
            if stop_after is not None and len(episodes) > stop_after:
                break
        return episodes

    random.seed(10)
    agent = RepeatLabelAgent(task_opt)
    world = create_task(task_opt, agent)

    # Populate dialogues from the LIGHT dataset
    samples = collect_episodes(world)

    task_opt['datatype'] = 'train:stream'
    task_opt['light_unseen_test'] = False
    agent = RepeatLabelAgent(task_opt)
    world = create_task(task_opt, agent)
    train_samples = collect_episodes(world, stop_after=2000)

    # Set up temporary pools to pull tasks from. The evaluation pool must be
    # seeded from the evaluation samples (it was previously seeded from the
    # training samples, so test-phase workers were served training episodes
    # until the first pool reset).
    use_train_samples = train_samples.copy()
    use_samples = samples.copy()

    # Set the task name to be the folder name
    opt['task'] = os.path.basename(os.path.dirname(os.path.abspath(__file__)))

    # append the contents of task_config.py to the configuration
    opt.update(task_config)

    # Select an agent_id that worker agents will be assigned in their world
    mturk_agent_roles = [LABEL_TYPE]

    opt['assignment_duration_in_seconds'] = 20 * 60

    # Instantiate an MTurkManager with the given options and a maximum number
    # of agents per world of 1 (based on the length of mturk_agent_ids)
    mturk_manager = MTurkManager(
        opt=opt,
        mturk_agent_ids=mturk_agent_roles,
        use_db=True,
    )
    mturk_manager.setup_server(
        task_directory_path=os.path.dirname(os.path.abspath(__file__)))

    # Create an onboard_function, which will be run for workers who have
    # accepted your task and must be completed before they are put in the
    # queue for a task world.
    completed_agents = []

    # worker_id -> number of training rounds the worker has completed
    completed_train = {}

    def run_onboard(worker):
        nonlocal completed_agents
        if worker.worker_id in completed_agents:
            return
        else:
            world = LightEvalTestWorld(opt=opt, mturk_agent=worker)
            while not world.episode_done():
                world.parley()
            if world.did_complete:
                completed_agents.append(worker.worker_id)
            else:
                print(worker.worker_id, 'Failed the onboarding')
            world.shutdown()
            return world.prep_save_data([worker])

    mturk_manager.set_onboard_function(onboard_function=run_onboard)

    try:
        # Initialize run information
        mturk_manager.start_new_run()

        # Set up the sockets and threads to receive workers
        mturk_manager.ready_to_accept_workers()

        # Create the hits as specified by command line arguments
        mturk_manager.create_hits(qualifications=[])

        # Check workers eligibility acts as a filter, and should return
        # the list of all workers currently eligible to work on the task
        # Can be used to pair workers that meet certain criteria
        def check_workers_eligibility(workers):
            return workers

        eligibility_function = {
            'func': check_workers_eligibility,
            'multiple': True,
        }

        # Assign worker roles is used to determine what the role each worker
        # in the given worker list will play. Setting `id` to None will return
        # the worker to the pool rather than putting them in a given task,
        # which is useful for having tasks with different possible worker
        # counts.
        def assign_worker_roles(workers):
            workers[0].id = LABEL_TYPE

        # Define the task function, which will be run with workers that are
        # as the main task.
        # NOTE(review): the function is bound at module scope via `global`;
        # presumably the manager needs it reachable there - confirm before
        # removing.
        global run_conversation

        def run_conversation(mturk_manager, opt, workers):
            nonlocal completed_train
            nonlocal use_samples
            nonlocal use_train_samples
            worker_id = workers[0].worker_id
            use_train = True
            if worker_id not in completed_train:
                completed_train[worker_id] = 0
            if completed_train[worker_id] >= TRAININGS:
                use_train = False

            # Create the real task world
            if not use_train:
                if len(use_samples) == 0:
                    # reset the pool if none are left
                    use_samples = samples.copy()
                sample = use_samples.pop()
            else:
                if len(use_train_samples) == 0:
                    # reset the pool if none are left
                    use_train_samples = train_samples.copy()
                # Pop from the working pool, not the master list: popping
                # from `train_samples` drained the list the pool is refilled
                # from and eventually raised IndexError.
                sample = use_train_samples.pop()

            world = LightEvalTaskWorld(
                opt=opt,
                mturk_agents=workers,
                sample=sample,
                use_train=use_train,
                max_wrong=MAX_WRONG,
            )
            # run the world to completion
            while not world.episode_done():
                world.parley()

            # shutdown and review the work
            world.shutdown()
            world.review_work()

            if not world.completed and not use_train:
                # NOTE(review): this re-adds the episode to the master list
                # (which still contains it), so it is only served again after
                # the next pool reset and then appears twice. Confirm whether
                # `use_samples.append(sample)` was intended.
                samples.append(sample)
            if use_train and world.completed:
                completed_train[worker_id] += 1
                print('Worker passed train: ', worker_id)

            # Return the contents for saving
            return world.prep_save_data(workers)

        # Begin the task, allowing mturk_manager to start running the task
        # world on any workers who connect
        mturk_manager.start_task(eligibility_function=eligibility_function,
                                 assign_role_function=assign_worker_roles,
                                 task_function=run_conversation)

    finally:
        # Any exception still propagates; cleanup below always runs.
        print('Accepted agents:', repr(completed_agents))
        # Shutdown the manager and free all related resources
        mturk_manager.shutdown()
Ejemplo n.º 5
0
def main(opt):
    """
    Load every image of a task once and, for a custom dataset, dump them to HDF5.

    A deep copy of ``opt`` is used, with the datatype forced to its
    ``:ordered`` variant and a single epoch.

    * When ``opt['dataset']`` is ``None``, the task is simply iterated for one
      epoch with a ``RepeatLabelAgent`` while progress is logged.
    * Otherwise the dataset class from ``get_dataset_class(opt)`` is built; if
      the image directory or its ``.built`` marker is missing, the task is run
      first to compute and save image features. The dataset is then read
      through a ``DataLoader`` and each distinct ``image_id`` is written to an
      ``images`` dataset in an HDF5 file, together with a JSON id-to-index map
      and a ``.built`` marker. Nothing is rewritten if the HDF5 file and its
      marker already exist.

    :param opt: options dict; read keys include ``datatype``, ``batchsize``,
        ``dataset``, ``numworkers``, ``image_mode``, ``attention``, ``task``
        and ``pytorch_buildteacher``.
    :raises ModuleNotFoundError: if PyTorch cannot be imported on the dataset
        path.
    """
    # Work on a private copy so the caller's options are not modified.
    opt = copy.deepcopy(opt)
    dt = opt['datatype'].split(':')[0] + ':ordered'
    opt['datatype'] = dt
    bsz = opt.get('batchsize', 1)
    opt['no_cuda'] = False
    opt['gpu'] = 0
    opt['num_epochs'] = 1
    opt['no_hdf5'] = True
    logger = ProgressLogger(should_humanize=False, throttle=0.1)
    print("[ Loading Images ]")
    # create repeat label agent and assign it to the specified task
    if opt.get('dataset') is None:
        agent = RepeatLabelAgent(opt)
        world = create_task(opt, agent)

        exs_seen = 0
        total_exs = world.num_examples()
        while not world.epoch_done():
            world.parley()
            exs_seen += bsz
            logger.log(exs_seen, total_exs)
    else:
        # One can specify a Pytorch Dataset for custom image loading
        nw = opt.get('numworkers', 1)
        im = opt.get('image_mode', 'raw')
        opt['batchsize'] = 1
        opt['extract_image'] = True
        bsz = 1
        try:
            import torch
            from torch.utils.data import DataLoader
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(
                'Need to install Pytorch: go to pytorch.org') from err

        dataset = get_dataset_class(opt)(opt)
        pre_image_path, _ = os.path.split(dataset.image_path)
        # Use the defaulted image mode: a missing 'image_mode' key would
        # otherwise pass None to os.path.join and raise TypeError.
        image_path = os.path.join(pre_image_path, im)
        images_built_file = image_path + '.built'

        if not os.path.exists(image_path) or not os.path.isfile(
                images_built_file):
            # Image features have not been computed yet
            opt['num_load_threads'] = 20
            agent = RepeatLabelAgent(opt)
            if opt['task'] == 'pytorch_teacher':
                opt['task'] = opt['pytorch_buildteacher']
            world = create_task(opt, agent)
            exs_seen = 0
            total_exs = world.num_examples()
            print('[ Computing and Saving Image Features ]')
            while exs_seen < total_exs:
                world.parley()
                exs_seen += bsz
                logger.log(exs_seen, total_exs)
            print('[ Feature Computation Done ]')
            with open(images_built_file, 'w') as write:
                write.write(str(datetime.datetime.today()))

        # batch_size is 1, so the collate function unwraps the single example.
        dataloader = DataLoader(dataset,
                                batch_size=bsz,
                                shuffle=False,
                                num_workers=nw,
                                collate_fn=lambda batch: batch[0])

        dataset_shape = None
        image_id_to_index = {}
        num_images = dataset.num_images()
        attention = opt.get('attention', False)
        if attention:
            hdf5_path = '{}mode_{}.hdf5'.format(dataset.image_path, im)
        else:
            hdf5_path = '{}mode_{}_noatt.hdf5'.format(dataset.image_path, im)
        image_id_to_idx_path = '{}mode_{}_id_to_idx.txt'.format(
            dataset.image_path, im)
        hdf5_built_file = hdf5_path + '.built'
        if os.path.isfile(hdf5_path) and os.path.isfile(hdf5_built_file):
            print('[ Images already extracted at: {} ]'.format(hdf5_path))
            return

        print("[ Beginning image extraction for {} images ]".format(
            dt.split(':')[0]))
        idx = 0
        # The context manager closes the file even if extraction fails midway.
        with h5py.File(hdf5_path, 'w') as hdf5_file:
            for ex in dataloader:
                # Each distinct image is stored once, at the next free row.
                if ex['image_id'] in image_id_to_index:
                    continue
                else:
                    image_id_to_index[ex['image_id']] = idx

                img = ex['image']
                if isinstance(img, torch.autograd.Variable):
                    img = img.cpu().data

                if not attention:
                    # Average over the last two dimensions and flatten to
                    # rows of 2048 values.
                    nb_regions = img.size(2) * img.size(3)
                    img = img.sum(3).sum(2).div(nb_regions).view(-1, 2048)

                if dataset_shape is None:
                    # The on-disk shape is derived from the first image seen.
                    if attention:
                        dataset_shape = (num_images, img.size(1), img.size(2),
                                         img.size(3))
                    else:
                        dataset_shape = (num_images, img.size(1))
                    hdf5_dataset = hdf5_file.create_dataset('images',
                                                            dataset_shape,
                                                            dtype='f')

                hdf5_dataset[idx] = img
                logger.log(idx, num_images)
                idx += 1

        if not os.path.exists(image_id_to_idx_path):
            with open(image_id_to_idx_path, 'w') as f:
                json.dump(image_id_to_index, f)
        # The marker is written last so a partial run is not treated as done.
        with open(hdf5_built_file, 'w') as write:
            write.write(str(datetime.datetime.today()))

    print("[ Finished extracting images ]")
Ejemplo n.º 6
0
def verify(opt, printargs=None, print_parser=None):
    """
    Gather token and utterance statistics over one pass of a task.

    Each act of the agent selected by ``opt['agent']`` is tokenized with a
    ``DictionaryAgent``. Tokens containing any of the comma-separated
    substrings in ``opt['ignore_tokens']`` are dropped. Statistics are kept
    separately for the input text, for the labels, and for both combined.

    :param opt: options dict. ``opt['datatype']`` is rewritten in place from
        ``train`` to ``train:ordered``.
    :param printargs: unused.
    :param print_parser: when truthy, the periodic progress text is printed.
    :return: whatever ``report(world, counts, log_time)`` returns.
    """
    if opt['datatype'] == 'train':
        print("[ note: changing datatype from train to train:ordered ]")
        opt['datatype'] = 'train:ordered'

    # create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)

    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    dictionary = DictionaryAgent(opt)
    # A missing option means "ignore nothing"; empty entries never match, so
    # they are filtered out once here instead of on every token.
    ignore_tokens = [
        s for s in (opt.get('ignore_tokens') or '').split(',') if s != ''
    ]

    # Tuples (not sets) keep the key order of `counts` deterministic.
    counts = {}
    for t in ('input', 'labels', 'both'):
        counts['tokens_in_' + t] = 0
        counts['utterances_in_' + t] = 0
        counts['avg_utterance_length_in_' + t] = 0
        counts['unique_tokens_in_' + t] = 0
        counts['unique_utterances_in_' + t] = 0
        # for counting the stats..
        counts['token_dict_' + t] = {}
        counts['utterance_dict_' + t] = {}

    def tokenize(txt):
        return dictionary.tokenize(txt)

    def keep_token(t):
        return not any(s in t for s in ignore_tokens)

    def update_counts(bucket, kept, utterance):
        # Fold one utterance (its kept tokens and their joined text) into the
        # statistics stored under the given bucket suffix.
        counts['tokens_in_' + bucket] += len(kept)
        counts['utterances_in_' + bucket] += 1
        counts['avg_utterance_length_in_' + bucket] = (
            counts['tokens_in_' + bucket] / counts['utterances_in_' + bucket]
        )
        seen_tokens = counts['token_dict_' + bucket]
        for tok in kept:
            if tok not in seen_tokens:
                counts['unique_tokens_in_' + bucket] += 1
                seen_tokens[tok] = True
        seen_utterances = counts['utterance_dict_' + bucket]
        if utterance not in seen_utterances:
            counts['unique_utterances_in_' + bucket] += 1
            seen_utterances[utterance] = True

    # Walk through the whole epoch.
    while not world.epoch_done():
        world.parley()
        act = world.get_acts()[opt.get('agent')]
        for itype in ('input', 'labels'):
            if itype == 'input':
                # An act without text is counted as one empty utterance,
                # mirroring the [''] default used for missing labels below.
                text = act.get('text') or ''
                if opt.get('new_line_new_utt'):
                    txts = text.split('\n')
                else:
                    txts = [text]
            else:
                txts = act.get('labels', act.get('eval_labels', ['']))

            for txt in txts:
                kept = [t for t in tokenize(txt) if keep_token(t)]
                utterance = ' '.join(kept)
                update_counts(itype, kept, utterance)
                update_counts('both', kept, utterance)

        if log_time.time() > log_every_n_secs:
            text, _ = report(world, counts, log_time)
            if print_parser:
                print(text)

    try:
        # print dataset size if available
        print('[ loaded {} episodes with a total of {} examples ]'.format(
            world.num_episodes(), world.num_examples()))
    except Exception:
        # Best effort only: the size is informational and its absence must
        # not prevent the final report.
        pass
    return report(world, counts, log_time)
Ejemplo n.º 7
0
def extract_feats(opt):
    """
    Run a task's image loading once and optionally dump the images to HDF5.

    The function works on a deep copy of ``opt`` with the datatype forced to
    ``<split>:ordered``. Two modes are visible here:

    * ``pytorch_teacher_dataset`` not set: a ``RepeatLabelAgent`` world is
      stepped through one epoch.
    * ``use_hdf5_extraction`` set: the dataset returned by
      ``get_dataset_class(opt)`` is iterated through a torch ``DataLoader``
      and one entry per distinct ``image_id`` is written to an HDF5 file next
      to ``dataset.image_path``, together with an id-to-index JSON file and a
      ``.built`` marker. Without ``opt['attention']`` the image tensor is
      averaged over its last two dimensions first.

    :param opt: options dict. A ``ParlaiParser`` is also accepted, in which
        case an error is logged and its parsed arguments are used.
    :return: None.
    :raises ImportError: if torch cannot be imported in HDF5 extraction mode.
    """
    if isinstance(opt, ParlaiParser):
        logging.error('extract_feats should be passed opt not parser')
        opt = opt.parse_args()
    # Work on a private copy so the overrides below do not leak to the caller.
    opt = copy.deepcopy(opt)
    dt = opt['datatype'].split(':')[0] + ':ordered'
    opt['datatype'] = dt
    bsz = opt.get('batchsize', 1)
    opt['no_cuda'] = False
    opt['gpu'] = 0
    opt['num_epochs'] = 1
    opt['use_hdf5'] = False
    opt['num_load_threads'] = 20
    logging.info("Loading Images")
    # create repeat label agent and assign it to the specified task
    if opt.get('pytorch_teacher_dataset') is None:
        agent = RepeatLabelAgent(opt)
        world = create_task(opt, agent)

        total_exs = world.num_examples()
        pbar = tqdm.tqdm(unit='ex', total=total_exs)
        while not world.epoch_done():
            world.parley()
            pbar.update()
        pbar.close()
    elif opt.get('use_hdf5_extraction', False):
        # TODO Deprecate
        # One can specify a Pytorch Dataset for custom image loading.
        nw = opt.get('numworkers', 1)
        im = opt.get('image_mode', 'raw')
        opt['batchsize'] = 1
        opt['extract_image'] = True
        bsz = 1
        try:
            import torch
            from torch.utils.data import DataLoader
        except ImportError:
            raise ImportError('Need to install Pytorch: go to pytorch.org')

        dataset = get_dataset_class(opt)(opt)
        pre_image_path, _ = os.path.split(dataset.image_path)
        # Use `im` so a missing image_mode falls back to 'raw' here exactly as
        # it does for the HDF5 file names below.
        image_path = os.path.join(pre_image_path, im)
        images_built_file = image_path + '.built'

        # Only known when the features are (re)computed below; otherwise the
        # progress bar falls back to whatever length tqdm can infer itself.
        total_exs = None
        if not os.path.exists(image_path) or not os.path.isfile(images_built_file):
            # Image features have not been computed yet.
            opt['num_load_threads'] = 20
            agent = RepeatLabelAgent(opt)
            if opt['task'] == 'pytorch_teacher':
                if opt.get('pytorch_teacher_task'):
                    opt['task'] = opt['pytorch_teacher_task']
                else:
                    opt['task'] = opt['pytorch_teacher_dataset']
            world = create_task(opt, agent)
            exs_seen = 0
            total_exs = world.num_examples()
            pbar = tqdm.tqdm(unit='ex', total=total_exs)
            logging.info('Computing and Saving Image Features')
            while exs_seen < total_exs:
                world.parley()
                exs_seen += bsz
                pbar.update(bsz)
            pbar.close()
            logging.info('Feature Computation Done')
            with open(images_built_file, 'w') as write:
                write.write(str(datetime.datetime.today()))

        dataloader = DataLoader(
            dataset,
            batch_size=bsz,
            shuffle=False,
            num_workers=nw,
            collate_fn=lambda batch: batch[0],
        )

        dataset_shape = None
        image_id_to_index = {}
        num_images = dataset.num_images()
        attention = opt.get('attention', False)
        if attention:
            hdf5_path = '{}mode_{}.hdf5'.format(dataset.image_path, im)
        else:
            hdf5_path = '{}mode_{}_noatt.hdf5'.format(dataset.image_path, im)
        image_id_to_idx_path = '{}mode_{}_id_to_idx.txt'.format(dataset.image_path, im)
        hdf5_built_file = hdf5_path + '.built'
        if os.path.isfile(hdf5_path) and os.path.isfile(hdf5_built_file):
            logging.info(f'Images already extracted at: {hdf5_path}')
            return

        logging.info(
            "Beginning image extraction for {} images".format(dt.split(':')[0])
        )
        total_batches = None if total_exs is None else total_exs // bsz
        idx = 0
        # The context manager closes the HDF5 file even if extraction fails,
        # and the '.built' marker is only written after a clean run.
        with h5py.File(hdf5_path, 'w') as hdf5_file:
            iterator = tqdm.tqdm(
                dataloader, unit='batch', unit_scale=True, total=total_batches
            )
            for ex in iterator:
                # Store each image once, at the index of its first occurrence.
                if ex['image_id'] in image_id_to_index:
                    continue
                image_id_to_index[ex['image_id']] = idx

                img = ex['image']
                if isinstance(img, torch.autograd.Variable):
                    img = img.cpu().data

                if not attention:
                    # Average over the last two dimensions.
                    nb_regions = img.size(2) * img.size(3)
                    img = img.sum(3).sum(2).div(nb_regions).view(-1, 2048)

                if dataset_shape is None:
                    # The dataset is created lazily because its shape depends
                    # on the first image seen.
                    if attention:
                        dataset_shape = (
                            num_images, img.size(1), img.size(2), img.size(3)
                        )
                    else:
                        dataset_shape = (num_images, img.size(1))
                    hdf5_dataset = hdf5_file.create_dataset(
                        'images', dataset_shape, dtype='f'
                    )

                hdf5_dataset[idx] = img
                idx += 1

        if not os.path.exists(image_id_to_idx_path):
            with open(image_id_to_idx_path, 'w') as f:
                json.dump(image_id_to_index, f)
        with open(hdf5_built_file, 'w') as write:
            write.write(str(datetime.datetime.today()))

    logging.info("Finished extracting images")
def dump_data(opt):
    """
    Dump task data to ACUTE-Eval.

    Steps a ``RepeatLabelAgent`` world through up to ``opt['num_episodes']``
    episodes (all of them when it is ``-1``). Every turn becomes a pair of
    messages: the teacher text attributed to ``speaker_0_id`` and one randomly
    chosen label attributed to ``speaker_1_id``. With
    ``opt['prepended_context']`` set, all but the last line of an episode's
    first text are split off into a separate ``context`` turn. The collected
    dialogues are passed to ``Conversations.save_conversations``.

    :param opt: options dict. When ``opt['outfile']`` is None a temporary
        ``.txt`` file is created and used as the output path.
    :return: None.
    """
    import os

    # create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)
    task = opt.get('task')
    speaker_0_id = opt.get('speaker_0_id') or f'{task}_as_human'
    speaker_1_id = opt.get('speaker_1_id') or f'{task}_as_model'
    if opt['outfile'] is None:
        fd, outfile = tempfile.mkstemp(
            prefix='{}_{}_'.format(opt['task'], opt['datatype']),
            suffix='.txt',
        )
        # mkstemp returns an already-open descriptor; only the path is needed,
        # so close it instead of leaking it for the life of the process.
        os.close(fd)
    else:
        outfile = opt['outfile']

    num_episodes = (world.num_episodes() if opt['num_episodes'] == -1 else min(
        opt['num_episodes'], world.num_episodes()))
    log_timer = TimeLogger()

    print(f'[ starting to convert, saving output to {outfile} ]')
    dialogues = []
    for _ in range(num_episodes):
        episode = []
        episode_done = False
        while not episode_done:
            world.parley()
            acts = world.get_acts()
            text = acts[0].get('text')
            split_text = text.split('\n')
            # Note: 'eval_labels' is always popped from the act, even when
            # 'labels' is present and wins.
            label = random.choice(
                acts[0].get('labels', acts[0].pop('eval_labels', None))
            )
            if not episode and opt.get('prepended_context'):
                # first turn
                context = split_text[:-1]
                text = split_text[-1]
                context_turn = [{
                    'text': context,
                    'episode_done': False,
                    'id': 'context'
                } for _ in range(2)]
                episode.append(context_turn)
            turn = [
                {
                    'text': text,
                    'episode_done': False,
                    'id': speaker_0_id
                },
                {
                    'text': label,
                    'episode_done': False,
                    'id': speaker_1_id
                },
            ]
            episode.append(turn)
            if acts[0].get('episode_done', False):
                # Mark the last message of the episode and store the episode.
                episode[-1][-1]['episode_done'] = True
                episode_done = True
                dialogues.append(episode)

            if log_timer.time() > opt['log_every_n_secs']:
                log_text, _log = log_timer.log(world.total_parleys,
                                               world.num_examples())
                print(log_text)

        if world.epoch_done():
            break

    Conversations.save_conversations(dialogues, outfile, opt)
Ejemplo n.º 9
0
def interactive(opt, print_parser=None):
    """
    Run an endless persona chat between a local human and a model.

    The task is forced to the ``LocalHumanAgent`` and the model is built with
    ``create_agent``. A second ``convai2:both`` world is used only as a source
    of personas: at the start of every chat the "your persona:" lines of a
    fresh episode are prepended to the text of the first message produced by
    ``world.agents[0]``.

    :param opt: options dict, or a ``ParlaiParser`` (deprecated; its parsed
        arguments are used).
    :param print_parser: ``True`` together with a parser ``opt`` reuses that
        parser to print the arguments; ``False`` disables printing; any other
        truthy value is used as the parser itself.
    :return: never returns normally; the chat loop has no exit condition.
    """
    # Normalise print_parser (None falls through both tests untouched).
    if print_parser is True and isinstance(opt, ParlaiParser):
        print_parser = opt
    elif print_parser is False:
        print_parser = None

    if isinstance(opt, ParlaiParser):
        print(
            '[ Deprecated Warning: interactive should be passed opt not Parser ]'
        )
        opt = opt.parse_args()

    opt['task'] = 'parlai.agents.local_human.local_human:LocalHumanAgent'
    # Build the model and pair it with the human task.
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)
    if print_parser:
        # Arguments are shown only after the model has been loaded.
        print_parser.opt = agent.opt
        print_parser.print_args()

    # A ConvAI2 world serves purely as a persona source.
    persona_opt = opt.copy()
    persona_opt['task'] = 'convai2:both'
    persona_agent = RepeatLabelAgent(persona_opt)
    persona_world = create_task(persona_opt, persona_agent)

    def next_persona():
        # Advance to the end of the current episode, then step once more so
        # the message in hand is the first one of a new episode.
        persona_world.parley()
        while not persona_world.get_acts()[0]['episode_done']:
            persona_world.parley()
        persona_world.parley()
        opening = persona_world.get_acts()[0]

        persona_lines = []
        for line in opening.get('text', '').split('\n'):
            if line.startswith("partner's persona:"):
                print(line.replace("partner's ", 'your '))
            if line.startswith('your persona:'):
                persona_lines.append(line + '\n')
        print("Enter [DONE] if you want a new partner at any time.")
        return ''.join(persona_lines)

    # Chat loop: world.parley() is unrolled by hand so the persona can be
    # injected into the first message of every chat.
    turns_taken = 0
    while True:
        chat_starts = turns_taken == 0
        if chat_starts:
            bot_persona = next_persona()

        acts, agents = world.acts, world.agents
        acts[0] = agents[0].act()
        if chat_starts:
            first_text = acts[0].get('text', 'hi')
            acts[0].force_set('text', bot_persona + first_text)
        agents[1].observe(acts[0])
        acts[1] = agents[1].act()
        agents[0].observe(acts[1])
        world.update_counters()
        turns_taken += 1

        if opt.get('display_examples'):
            print("---")
            print(world.display())
        if world.episode_done():
            print("CHAT DONE ")
            print("In case you were curious you were talking to this bot:")
            print(bot_persona.split('\n'))
            print("\n... preparing new chat... \n")
            turns_taken = 0
def bucket_data(opt):
    """
    Print bucket lower bounds and bucket sizes for one control variable.

    The value stored under ``opt['control']`` is read from ``world.acts[0]``
    for up to ``opt['num_examples']`` examples (all of them when it is
    ``-1``). The values are sorted, ``opt['num_buckets']`` lower bounds are
    taken at evenly spaced positions of the sorted list, and the number of
    values that ``sort_into_bucket`` assigns to each bucket is printed.

    Supported controls are ``avg_nidf`` (values in [0, 1]), ``question``
    (0 or 1) and ``lastuttsim`` (values in [-1, 1], or the string "None",
    which is excluded from the bucket calculation).

    :param opt: options dict.
    :return: None; all results are printed.
    :raises Exception: if the control key is missing from an act or is not
        one of the supported names.
    """
    # Repeat-label agent attached to the requested task.
    teacher_agent = RepeatLabelAgent(opt)
    world = create_task(opt, teacher_agent)

    requested = opt['num_examples']
    num_examples = world.num_examples() if requested == -1 else requested
    log_timer = TimeLogger()

    assert opt['control'] != ''
    ctrl = opt['control']
    num_buckets = opt['num_buckets']

    collected = []  # floats, plus None placeholders for lastuttsim

    for _ in range(num_examples):
        world.parley()
        act = world.acts[0]
        # 'eval_labels' is always removed; it is only used when 'labels' is absent.
        eval_labels = act.pop('eval_labels', None)
        act['labels'] = act.get('labels', eval_labels)

        if ctrl not in act.keys():
            available = ', '.join(act.keys())
            raise Exception(
                f"Error: control {ctrl} isn't in the data. "
                f"available keys: {available}")

        raw_val = act[ctrl]
        if raw_val == "None":
            assert ctrl == 'lastuttsim'
            value = None
        else:
            value = float(raw_val)

        # Range checks per control.
        if ctrl == 'avg_nidf':
            assert 0 <= value <= 1
        elif ctrl == 'question':
            assert value in (0, 1)
        elif ctrl == 'lastuttsim':
            if value is not None:
                assert -1 <= value <= 1
        else:
            raise Exception('Unexpected ctrl name: %s' % ctrl)
        collected.append(value)

        if log_timer.time() > opt['log_every_n_secs']:
            progress, _log = log_timer.log(world.total_parleys,
                                           world.num_examples())
            print(progress)

        if world.epoch_done():
            print('EPOCH DONE')
            break

    if ctrl == 'lastuttsim':
        none_count = sum(1 for v in collected if v is None)
        collected = [v for v in collected if v is not None]
        print("Have %i Nones for lastuttsim; these have been removed "
              "for bucket calculation" % none_count)

    print('Collected %i control vals between %.6f and %.6f' %
          (len(collected), min(collected), max(collected)))

    # Lower bounds sit at evenly spaced positions of the sorted values.
    print('Calculating lowerbounds for %i buckets...' % num_buckets)
    ordered = sorted(collected)
    total = len(ordered)
    lbs = [
        ordered[int(total * bucket / num_buckets)]
        for bucket in range(num_buckets)
    ]
    print('\nBucket lowerbounds for control %s: ' % ctrl)
    print(lbs)

    # Count how many values land in each bucket.
    bucket_sizes = Counter(sort_into_bucket(v, lbs) for v in ordered)
    print('\nBucket sizes: ')
    for bucket_id in sorted(bucket_sizes):
        print("%i: %i" % (bucket_id, bucket_sizes[bucket_id]))
Ejemplo n.º 11
0
def interactive(opt, print_parser=None):
    """
    Chat with a model interactively, or feed it a script of utterances.

    The task is forced to the ``LocalHumanAgent`` and the model is built with
    ``create_agent``. A ``convai2:both`` world is used as a persona source.

    Without ``opt['chat_script']`` the function runs an endless chat in which
    the persona is prepended to the first message of every chat.

    With ``opt['chat_script']`` every line of ``opt['script_input_path']`` is
    sent to ``world.agents[1]`` and the reply text is written, one per line,
    to ``<script_output_path>/<input name>_<model tag>_<timestamp>.txt``.
    When ``opt['chateval_multi']`` is True a line holds 2 or 3 turns
    (``opt['chateval_multi_num']``) and only the reply to the last turn is
    written. After the script is processed the process exits.

    :param opt: options dict, or a ``ParlaiParser`` (deprecated; its parsed
        arguments are used).
    :param print_parser: ``True`` together with a parser ``opt`` reuses that
        parser to print the arguments; ``False`` disables printing; any other
        truthy value is used as the parser itself.
    :raises SystemExit: in script mode, once the script has been processed.
    """
    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    if isinstance(opt, ParlaiParser):
        print(
            '[ Deprecated Warning: interactive should be passed opt not Parser ]'
        )
        opt = opt.parse_args()
    opt['task'] = 'parlai.agents.local_human.local_human:LocalHumanAgent'
    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)
    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()

    # Create ConvAI2 data so we can assign personas.
    convai2_opt = opt.copy()
    convai2_opt['task'] = 'convai2:both'
    convai2_agent = RepeatLabelAgent(convai2_opt)
    convai2_world = create_task(convai2_opt, convai2_agent)

    def get_new_personas():
        # Find a new episode
        while True:
            convai2_world.parley()
            msg = convai2_world.get_acts()[0]
            if msg['episode_done']:
                convai2_world.parley()
                msg = convai2_world.get_acts()[0]
                break
        txt = msg.get('text', '').split('\n')
        bot_persona = ""
        for t in txt:
            if t.startswith("partner's persona:"):
                print(t.replace("partner's ", 'your '))
            if t.startswith('your persona:'):
                bot_persona += t + '\n'
        print("Enter [DONE] if you want a new partner at any time.")
        return bot_persona

    def scripted_exchange(text, episode_done):
        # Play one scripted human message through the hand-unrolled parley
        # (observe -> act -> observe -> update counters) and return the reply.
        acts = world.acts
        agents = world.agents
        acts[0] = {
            'id': 'localHuman',
            'episode_done': episode_done,
            'label_candidates': None,
            'text': str(text),
        }
        agents[1].observe(acts[0])
        acts[1] = agents[1].act()
        agents[0].observe(acts[1])
        world.update_counters()
        return acts[1]

    if not opt.get('chat_script'):
        # Now run interactive mode, chatting with personas!
        cnt = 0
        while True:
            if cnt == 0:
                bot_persona = get_new_personas()
            # Run the parts of world.parley() in turn,
            # but insert persona into user message.
            acts = world.acts
            agents = world.agents
            acts[0] = agents[0].act()
            # add the persona on to the first message
            if cnt == 0:
                acts[0].force_set('text',
                                  bot_persona + acts[0].get('text', 'hi'))
            agents[1].observe(acts[0])
            acts[1] = agents[1].act()
            agents[0].observe(acts[1])
            world.update_counters()
            cnt = cnt + 1

            if opt.get('display_examples'):
                print("---")
                print(world.display())
            if world.episode_done():
                print("CHAT DONE ")
                print("In case you were curious you were talking to this bot:")
                print(bot_persona.split('\n'))
                print("\n... preparing new chat... \n")
                cnt = 0
    else:
        # Script mode runs exactly once and then exits the process.
        script_input_path = str(opt.get('script_input_path'))
        script_out_path = str(opt.get('script_output_path'))
        model_name = str(opt.get('model_file'))
        multi_check = opt.get('chateval_multi')
        turn_n = opt.get('chateval_multi_num')

        # Both files are managed by `with` so they are closed on every path,
        # including errors raised while the model is responding.
        with open(script_input_path, 'r', encoding='utf-8') as script_file:
            timestr = time.strftime("%Y%m%d-%H%M%S")
            file_name = script_input_path.split('/')[-1].split('.')[0]

            if model_name.find(":") != -1:
                model_name = model_name.split(':')[-1]
            else:
                model_name = model_name.split('/')[-1]

            if model_name.find('blender') != -1:
                # NOTE(review): this needs at least one '/' left in
                # model_name, which only the ':' branch above can leave
                # behind; otherwise it raises IndexError. Confirm the
                # intended model_file formats.
                model_tag = model_name.split('/')[-2]
            else:
                model_tag = model_name
            response_path = (script_out_path + '/' + file_name + '_' +
                             model_tag + '_' + timestr + '.txt')

            # NOTE(review): the output uses the platform default encoding
            # while the input is read as UTF-8 - confirm this is intended.
            with open(response_path, 'w') as script_response:
                # The persona text itself is not used in script mode; the call
                # is kept for its console output and world stepping.
                get_new_personas()

                for raw_text in script_file:
                    raw_text = raw_text.replace('\n', '')
                    # Strict comparison kept on purpose: only True (or a value
                    # equal to it) enables multi-turn parsing.
                    if multi_check == True:  # noqa: E712
                        if turn_n == 2:
                            turns = [
                                raw_text.split('</s>')[0],
                                raw_text.split('</s>')[1],
                            ]
                        elif turn_n == 3:
                            # The third turn is delimited by a backslash
                            # variant of the marker. '<\\s>' is the same
                            # runtime string as the former '<\s>' literal,
                            # without the invalid escape sequence.
                            turns = [
                                raw_text.split('</s>')[0].replace('`', ''),
                                raw_text.split('</s>')[1].replace('`', ''),
                                raw_text.split('<\\s>')[1].replace('`', ''),
                            ]
                        else:
                            # Other turn counts are not handled; skip the line.
                            continue

                        last_index = len(turns) - 1
                        for index, turn_each in enumerate(turns):
                            if index > 0:
                                reply = scripted_exchange(turn_each, False)
                                if index == last_index:
                                    # Only the reply to the final turn is kept.
                                    script_response.write(
                                        "%s\n" % (reply['text']))
                            # NOTE(review): every turn is then sent with
                            # episode_done=True and that reply is discarded,
                            # so turns after the first are sent twice. Kept as
                            # in the original - confirm this is intended.
                            scripted_exchange(turn_each, True)
                    else:
                        reply = scripted_exchange(raw_text, False)
                        script_response.write("%s\n" % (reply['text']))

        print("script response complete!")
        acts = world.acts
        agents = world.agents
        acts[0] = {
            'id': 'localHuman',
            'episode_done': False,
            'label_candidates': None,
            'text': '[DONE]'
        }
        agents[1].observe(acts[0])
        import sys
        sys.exit()
Ejemplo n.º 12
0
def display_data(opt, lang=None):
    """Regroup a task's examples into per-dialogue context/target lists.

    Builds a world around a ``RepeatLabelAgent``, steps it up to
    ``opt['num_examples']`` times and parses each ``world.display()`` output
    line by line: line 0 is taken as the situation, line 1 as the emotion,
    line 6 as the utterance and line 7 as its label.  Consecutive examples
    sharing the same situation line are accumulated into one dialogue; when
    the situation line changes, the accumulated dialogue is flushed into the
    output lists.

    NOTE(review): the dialogue still being accumulated when the loop ends is
    never flushed, so the last dialogue seen is absent from every returned
    list.  This matches the original behaviour -- confirm it is intended.

    Args:
        opt: option dict; ``'datatype'`` and ``'num_examples'`` are read here,
            and the whole dict is passed to the agent and the task factory.
        lang: optional vocabulary object; when truthy, every cleaned turn is
            passed to ``add_sentence(lang, turn)``.

    Returns:
        A 14-tuple of lists: ``(situations, emotions, sys_emotions,
        sys_sentiments, sys_sentiments_binary, full_dialogs, dialogs,
        targets, usr_dialogs, usr_targets, sys_dialogs, sys_targets,
        questions, answers)``.
    """
    # create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)

    dialog = []  # cleaned turns of the dialogue currently being accumulated
    full_dialogs = []
    dialogs = []
    situation = ''
    situations = []
    emotion = ''
    emotions = []
    sys_emotions = []
    sys_sentiments = []
    sys_sentiments_binary = []
    targets = []
    usr_dialogs = []  # Listener
    usr_targets = []  # Speaker
    sys_dialogs = []  # Speaker
    sys_targets = []  # Listener
    questions = []
    answers = []

    # Map the datatype onto the split suffix used by the pickle file names,
    # e.g. 'valid:stream' -> 'dev'.  Plain literal substitutions, no regex.
    split = (opt['datatype'].replace(':ordered', '').replace(
        ':stream', '').replace('valid', 'dev'))
    sure_sentiments = load_pkl(
        'data/prep/empathetic-dialogue/sure_sentiments.{}'.format(split))
    unsure_sentiments = load_pkl(
        'data/prep/empathetic-dialogue/unsure_sentiments.{}'.format(split))
    unsure_sentiments_binary = load_pkl(
        'data/prep/empathetic-dialogue/unsure_sentiments_binary.{}'.format(
            split))

    # Field name handed to clean_msg for the label line; loop-invariant.
    if 'train' not in opt['datatype']:
        label_field = 'eval_labels'
    else:
        label_field = 'labels'

    def record_sys_affect(situation_line, emotion_line):
        # Append one emotion and its two sentiment lookups; called once per
        # entry added to sys_targets so the four lists grow in lockstep.
        cleaned_emotion = clean_msg(emotion_line, 'emotion')
        sys_emotions.append(cleaned_emotion)
        sys_sentiments.append(
            emo2sentiment(situation_line,
                          cleaned_emotion,
                          unsure_sentiments,
                          sure_sentiments=sure_sentiments,
                          binary=False))
        sys_sentiments_binary.append(
            emo2sentiment(situation_line,
                          cleaned_emotion,
                          unsure_sentiments_binary,
                          binary=True))

    for i in range(opt['num_examples']):
        world.parley()

        # NOTE: If you want to look at the data from here rather than calling
        # world.display() you could access world.acts[0] directly
        message = world.display().split('\n')

        # A new situation line means the previous dialogue is complete:
        # flush it.  `situation` and `emotion` still hold the values of the
        # dialogue being flushed at this point.
        if situation != message[0] and i > 0:
            targets.append(dialog[-1] + ' __eou__')
            full_dialogs.append(flatten(dialog))
            dialogs.append(flatten(dialog[:-1]))
            # The opening turn has no preceding listener turn, so it is stored
            # on its own rather than as a flattened pair.
            sys_dialogs.append('SPEAKER: ' + dialog[0])
            questions.append(dialog[0])
            answers.append(dialog[1] + ' __eou__')

            last_turn = len(dialog) - 1
            # Turns at even indices are handled as speaker turns, odd indices
            # as listener turns.
            for t in range(1, len(dialog)):
                if t % 2 == 0:
                    # Speaker
                    sys_dialogs.append(
                        flatten([
                            'LISTENER: ' + dialog[t - 1],
                            'SPEAKER: ' + dialog[t]
                        ]))
                else:
                    # Listener
                    sys_targets.append(dialog[t] + ' __eou__')
                    record_sys_affect(situation, emotion)
                    # Listener is not last turn: the following turn becomes
                    # the target for the (previous turn, listener turn) pair.
                    if t < last_turn:
                        usr_dialogs.append(flatten(dialog[t - 1:t + 1]))
                        usr_targets.append(dialog[t + 1] + ' __eou__')

            dialog = []
            situations.append(clean_msg(situation, 'situation'))
            emotions.append(clean_msg(emotion, 'emotion'))

        situation = message[0]
        emotion = message[1]
        dialog.append(clean_msg(message[6], 'empathetic_dialogues'))
        if lang:
            add_sentence(lang, dialog[-1])
        dialog.append(clean_msg(message[7], label_field))
        if lang:
            add_sentence(lang, dialog[-1])

        if world.epoch_done():
            print('EPOCH DONE')
            break

    try:
        # print dataset size if available
        print('[ loaded {} episodes with a total of {} examples ]'.format(
            world.num_episodes(), world.num_examples()))
    except Exception:
        # Best effort only: the size summary is skipped if the world cannot
        # report it.
        pass

    return (
        situations,
        emotions,
        sys_emotions,
        sys_sentiments,
        sys_sentiments_binary,
        full_dialogs,
        dialogs,
        targets,
        usr_dialogs,
        usr_targets,
        sys_dialogs,
        sys_targets,
        questions,
        answers,
    )