Esempio n. 1
0
 def add_cmdline_args(
     cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
 ) -> ParlaiParser:
     """
     Register this teacher's command-line arguments on ``parser``.

     Adds the flattening options and the teacher's own options, then lets every
     module returned by ``get_original_task_module`` add its arguments as well.

     :param parser:
         parser that receives the new argument groups.
     :param partial_opt:
         forwarded unchanged to the parent class and to each task module.

     :return:
         the same ``parser`` that was passed in.
     """
     super().add_cmdline_args(parser, partial_opt)

     # Options that control how an episode is flattened.
     flatten_group = parser.add_argument_group(
         'ControllableTaskTeacher Flattening Args')
     flatten_specs = (
         (
             '--flatten-include-labels',
             {
                 'type': 'bool',
                 'default': True,
                 'help': 'Include labels in the history when flattening an episode',
             },
         ),
         (
             '--flatten-delimiter',
             {
                 'type': str,
                 'default': '\n',
                 'help': 'How to join the dialogue history from previous turns.',
             },
         ),
         (
             '--flatten-max-context-length',
             {
                 'type': int,
                 'default': -1,
                 'help': 'Maximum number of utterances to include per episode. '
                 'Default -1 keeps all.',
             },
         ),
     )
     for flag, settings in flatten_specs:
         flatten_group.add_argument(flag, **settings)

     # Options that belong to the teacher itself.
     teacher_group = parser.add_argument_group('ControllableTaskTeacher Args')
     teacher_specs = (
         (
             '--invalidate-cache',
             {
                 'type': 'bool',
                 'default': False,
                 'help': 'Set this to True to rebuild the data (may want to do this if '
                 'original data has changed or you want to rebuild with new options)',
             },
         ),
         (
             '--max-examples',
             {
                 'type': int,
                 'default': -1,
                 'help': 'If greater than zero, will stop building after a certain num of exs',
             },
         ),
         (
             '--fixed-control',
             {
                 'type': str,
                 'default': '',
                 'help': 'Always append this fixed control string, good for deploy time.',
             },
         ),
     )
     for flag, settings in teacher_specs:
         teacher_group.add_argument(flag, **settings)

     # The options parsed so far are used to look up the original task
     # module(s), so that they can contribute their own arguments too.
     known_opt = parser.parse_and_process_known_args()[0]
     task_modules = get_original_task_module(known_opt, multi_possible=True)
     for task_module in task_modules:
         if not hasattr(task_module, 'add_cmdline_args'):
             continue
         task_module.add_cmdline_args(parser, partial_opt=partial_opt)
     return parser
Esempio n. 2
0
    def _setup_data(self, opt: Opt) -> List[List[Message]]:
        """
        Flatten and classify the normal task data.

        Returns previously saved data when ``self.load_data`` finds some.
        Otherwise builds the original task's teacher, reads its episodes one at a
        time, flattens each via ``flatten_and_classify``, saves the result with
        ``self.save_data`` and returns it.

        Sets ``self.original_task_name``, ``self.save_dir`` and ``self.save_path``.

        :param opt:
            options dict. Keys read here: ``task``, ``datapath``, ``datatype``,
            ``flatten_max_context_length``, ``flatten_include_labels`` and
            ``flatten_delimiter``. ``max_examples`` is read from ``self.opt``.

        :return:
            the loaded data, or the flattened episodes built by this call.
        """
        # The original task name is the task string without its first two
        # colon-separated fields.
        self.original_task_name = ':'.join(opt['task'].split(':')[2:])

        # create save directory, if it does not already exist
        # NOTE(review): the directory is keyed on the current timestamp and is
        # created before the load attempt below, so it exists even when cached
        # data is returned -- confirm this is intended.
        self.save_dir = self._get_save_path(
            opt['datapath'], str(datetime.datetime.today())
        )
        os.makedirs(self.save_dir, exist_ok=True)

        fname = f"{opt['datatype'].split(':')[0]}.json"
        self.save_path = os.path.join(self.save_dir, fname)

        data = self.load_data(opt, fname)
        if data is not None:
            # successfully load data
            return data

        # build the original teacher
        original_task_module = get_original_task_module(opt)
        teacher_opt = deepcopy(opt)
        teacher_opt['task'] = self.original_task_name
        teacher = original_task_module(teacher_opt)

        total_exs = teacher.num_examples()
        if self.opt['max_examples'] > 0:
            total_exs = min(self.opt['max_examples'], total_exs)

        all_episodes = []
        num_exs = 0
        # Using the bar as a context manager guarantees it is closed (and its
        # final state flushed) both on success and if building raises.
        with tqdm(
            total=total_exs, unit='ex', unit_scale=True, desc='Building flattened data'
        ) as progress_bar:
            while num_exs < total_exs:
                current_episode = []
                episode_done = False

                # An episode is always consumed in full, so the last one may
                # push num_exs past total_exs.
                while not episode_done:
                    action = Message(teacher.act())
                    current_episode.append(action)
                    episode_done = action.get('episode_done', False)
                    num_exs += 1

                # flatten the episode into 1-example episodes with context
                flattened_ep = flatten_and_classify(
                    current_episode,
                    opt['flatten_max_context_length'],
                    include_labels=opt['flatten_include_labels'],
                    delimiter=opt['flatten_delimiter'],
                    word_lists=self.word_lists,
                )
                all_episodes.extend(flattened_ep)

                progress_bar.update(len(flattened_ep))

        # save data for future use
        self.save_data(all_episodes)

        return all_episodes