def test_flatten_and_classify(self):
    """
    Flattening a four-turn episode yields four tagged examples.

    The utterances are consumed pairwise as (text, label) turns and packed
    into a single episode whose last message is marked ``episode_done``.
    After ``flatten_and_classify``, each resulting example's text must end
    with the corresponding expected classification token.
    """
    word_lists = ControllableTaskTeacher.build_wordlists(
        ParlaiParser().parse_args([])
    )
    dialogue = [
        "hello there",
        "hi there dad, what's up",
        "not much, do you know where your sister is?",
        "I have not seen her, I thought she was with grandpa",
        "well, if you see her, let me know",
        "will do!",
        "ok, have a good day",
        "bye bye! tell mom I say hello",
    ]
    expected_tokens = ['f0m1', 'f1m1', 'f0m0', 'f1m0']

    # Even-indexed utterances are the prompts; each is answered by the
    # utterance immediately after it.
    turns = zip(dialogue[::2], dialogue[1::2])
    episode = [
        Message({'text': prompt, 'labels': [reply], 'episode_done': False})
        for prompt, reply in turns
    ]
    # Only the final turn closes the episode.
    episode[-1].force_set('episode_done', True)

    flattened = flatten_and_classify(episode, -1, word_lists)
    assert len(flattened) == 4
    tokens_match = all(
        example['text'].endswith(token)
        for example, token in zip(flattened, expected_tokens)
    )
    assert tokens_match, f"new episode: {flattened}"
Example #2
0
    def _setup_data(self, opt: Opt) -> List[List[Message]]:
        """
        Flatten and classify the normal task data.

        Save/load where applicable: if ``load_data`` returns previously saved
        data it is returned directly; otherwise the data is built from the
        original task's teacher, saved via ``save_data``, and returned.

        :param opt:
            options dict. Reads ``task``, ``datapath``, ``datatype``,
            ``max_examples``, ``flatten_max_context_length``,
            ``flatten_include_labels`` and ``flatten_delimiter``.

        :return:
            the loaded data, or the concatenation of every
            ``flatten_and_classify`` result built from the original teacher.
        """
        # The original task name is everything after the first two
        # ':'-separated components of opt['task'].
        self.original_task_name = ':'.join(opt['task'].split(':')[2:])
        # NOTE(review): the current timestamp is passed to _get_save_path; if
        # it becomes part of the directory name, load_data below will rarely
        # find previously saved data -- confirm against _get_save_path.
        self.save_dir = self._get_save_path(
            opt['datapath'], str(datetime.datetime.today())
        )
        # create save directory, if it does not already exist
        os.makedirs(self.save_dir, exist_ok=True)

        # One file per base datatype, e.g. 'train:ordered' -> 'train.json'.
        fname = f"{opt['datatype'].split(':')[0]}.json"
        self.save_path = os.path.join(self.save_dir, fname)

        data = self.load_data(opt, fname)
        if data is not None:
            # successfully loaded previously saved data
            return data

        # Build the original teacher from a deep copy so that overriding
        # 'task' does not mutate the caller's opt.
        original_task_module = get_original_task_module(opt)
        teacher_opt = deepcopy(opt)
        teacher_opt['task'] = self.original_task_name
        teacher = original_task_module(teacher_opt)

        total_exs = teacher.num_examples()
        # Read from the ``opt`` argument like every other option in this
        # method, instead of depending on ``self.opt`` being set beforehand.
        if opt['max_examples'] > 0:
            total_exs = min(opt['max_examples'], total_exs)

        all_episodes = []
        num_exs = 0
        # The context manager guarantees the progress bar is closed, even if
        # building the data raises.
        with tqdm(
            total=total_exs, unit='ex', unit_scale=True, desc='Building flattened data'
        ) as progress_bar:
            while num_exs < total_exs:
                # Collect one complete episode. The inner loop only stops on
                # episode_done, so num_exs may overshoot total_exs rather than
                # truncating the final episode.
                current_episode = []
                episode_done = False
                while not episode_done:
                    action = Message(teacher.act())
                    current_episode.append(action)
                    episode_done = action.get('episode_done', False)
                    num_exs += 1

                # flatten the episode into 1-example episodes with context
                flattened_ep = flatten_and_classify(
                    current_episode,
                    opt['flatten_max_context_length'],
                    include_labels=opt['flatten_include_labels'],
                    delimiter=opt['flatten_delimiter'],
                    word_lists=self.word_lists,
                )
                all_episodes += flattened_ep

                progress_bar.update(len(flattened_ep))

        # save data for future use
        self.save_data(all_episodes)

        return all_episodes