def add_cmdline_args(cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None) -> ParlaiParser:
    """
    Register command-line options for this teacher.

    Adds one group of flattening options and one group of teacher-control
    options, then delegates to the wrapped task teacher(s) so their own
    options are registered as well.

    :param parser:
        the argument parser to extend.
    :param partial_opt:
        partially-parsed options, forwarded to the wrapped teachers.
    :return:
        the same parser, for chaining.
    """
    super().add_cmdline_args(parser, partial_opt)

    flatten_group = parser.add_argument_group(
        'ControllableTaskTeacher Flattening Args')
    flatten_group.add_argument(
        '--flatten-include-labels',
        type='bool',
        default=True,
        help='Include labels in the history when flattening an episode',
    )
    flatten_group.add_argument(
        '--flatten-delimiter',
        type=str,
        default='\n',
        help='How to join the dialogue history from previous turns.',
    )
    flatten_group.add_argument(
        '--flatten-max-context-length',
        type=int,
        default=-1,
        help='Maximum number of utterances to include per episode. Default -1 keeps all.',
    )

    control_group = parser.add_argument_group('ControllableTaskTeacher Args')
    control_group.add_argument(
        '--invalidate-cache',
        type='bool',
        default=False,
        help='Set this to True to rebuild the data (may want to do this if '
        'original data has changed or you want to rebuild with new options)',
    )
    control_group.add_argument(
        '--max-examples',
        type=int,
        default=-1,
        help='If greater than zero, will stop building after a certain num of exs',
    )
    control_group.add_argument(
        '--fixed-control',
        type=str,
        default='',
        help='Always append this fixed control string, good for deploy time.',
    )

    # Let each wrapped task teacher register its own arguments too.
    parsed_opt = parser.parse_and_process_known_args()[0]
    for wrapped_task in get_original_task_module(parsed_opt, multi_possible=True):
        if hasattr(wrapped_task, 'add_cmdline_args'):
            wrapped_task.add_cmdline_args(parser, partial_opt=partial_opt)

    return parser
def _setup_data(self, opt: Opt) -> List[List[Message]]:
    """
    Flatten and classify the normal task data. Save/load where applicable.

    Tries to load a previously-built cache first; otherwise builds the
    wrapped task teacher, flattens each of its episodes into 1-example
    episodes with context, and saves the result for future runs.

    :param opt:
        options dict.
    :return:
        list of flattened episodes (each a list of Messages).
    """
    # create save directory, if it does not already exist
    self.original_task_name = ':'.join(opt['task'].split(':')[2:])
    self.save_dir = self._get_save_path(
        opt['datapath'], str(datetime.datetime.today())
    )
    os.makedirs(self.save_dir, exist_ok=True)
    fname = f"{opt['datatype'].split(':')[0]}.json"
    self.save_path = os.path.join(self.save_dir, fname)

    data = self.load_data(opt, fname)
    if data is not None:
        # successfully load data
        return data

    # build the original teacher
    original_task_module = get_original_task_module(opt)
    teacher_opt = deepcopy(opt)
    teacher_opt['task'] = self.original_task_name
    teacher = original_task_module(teacher_opt)
    total_exs = teacher.num_examples()
    if self.opt['max_examples'] > 0:
        total_exs = min(self.opt['max_examples'], total_exs)
    progress_bar = tqdm(
        total=total_exs, unit='ex', unit_scale=True, desc='Building flattened data'
    )
    # BUGFIX: close the progress bar on every exit path (including errors
    # raised while building) so tqdm's display state is flushed and the
    # bar object does not leak.
    try:
        all_episodes = []
        num_exs = 0
        while num_exs < total_exs:
            # collect one full episode from the wrapped teacher
            current_episode = []
            episode_done = False
            while not episode_done:
                action = Message(teacher.act())
                current_episode.append(action)
                episode_done = action.get('episode_done', False)
                num_exs += 1
            # flatten the episode into 1-example episodes with context
            flattened_ep = flatten_and_classify(
                current_episode,
                opt['flatten_max_context_length'],
                include_labels=opt['flatten_include_labels'],
                delimiter=opt['flatten_delimiter'],
                word_lists=self.word_lists,
            )
            all_episodes += flattened_ep
            progress_bar.update(len(flattened_ep))
    finally:
        progress_bar.close()

    # save data for future use
    self.save_data(all_episodes)
    return all_episodes