def _get_eval_dataloader_dict(self, phase, task_name_list, use_subset=False):
    """Build a dict mapping task name -> finite eval dataloader for the given phase."""
    val_dataloader_dict = {}
    for task_name in task_name_list:
        task = self.jiant_task_container.task_dict[task_name]
        eval_cache = self.jiant_task_container.task_cache_dict[task_name][phase]
        task_specific_config = self.jiant_task_container.task_specific_configs[task_name]
        val_dataloader_dict[task_name] = get_eval_dataloader_from_cache(
            eval_cache=eval_cache,
            task=task,
            eval_batch_size=task_specific_config.eval_batch_size,
            # Optionally evaluate on only the first eval_subset_num examples.
            subset_num=task_specific_config.eval_subset_num if use_subset else None,
        )
    return val_dataloader_dict
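
# Hedged usage sketch (assumption, not part of the original source): `runner`
# is a hypothetical instance of the class these methods belong to, and the
# task names are illustrative. Builds one finite eval loader per task:
#
#     val_loaders = runner._get_eval_dataloader_dict(
#         phase="val",
#         task_name_list=["rte", "mnli"],
#         use_subset=True,  # cap each loader at eval_subset_num examples
#     )
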
def get_train_dataloader_dict(self, for_eval=False, use_subset=False, do_weighted_sampling=False):
    """Build a dict mapping task name -> train dataloader.

    Not currently supported for distributed parallel training.
    """
    train_dataloader_dict = {}
    for task_name in self.jiant_task_container.task_run_config.train_task_list:
        task = self.jiant_task_container.task_dict[task_name]
        train_cache = self.jiant_task_container.task_cache_dict[task_name]["train"]
        task_specific_config = self.jiant_task_container.task_specific_configs[task_name]
        if for_eval:
            # Evaluating on the training set: use a finite eval-style loader.
            train_dataloader_dict[task_name] = get_eval_dataloader_from_cache(
                eval_cache=train_cache,
                task=task,
                eval_batch_size=task_specific_config.eval_batch_size,
                subset_num=task_specific_config.eval_subset_num if use_subset else None,
            )
        else:
            if do_weighted_sampling:
                sample_weights_path = task_specific_config.train_sample_weights
                logger.info(
                    'building train loader with sample weights "%s"', sample_weights_path
                )
            else:
                logger.info("building train loader without sample weights")
                sample_weights_path = None
            # Wrap the train loader so batches can be drawn from it indefinitely.
            train_dataloader_dict[task_name] = InfiniteYield(
                get_train_dataloader_from_cache(
                    train_cache=train_cache,
                    task=task,
                    train_batch_size=task_specific_config.train_batch_size,
                    sample_weights_path=sample_weights_path,
                    fix_seed_for_weighted_sampler=(
                        self.jiant_task_container.global_train_config.fix_seed_for_weighted_sampler
                    ),
                ),
            )
    return train_dataloader_dict
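
# Hedged usage sketch (assumption, not part of the original source): a minimal
# training loop over the InfiniteYield loaders built above. `runner` and
# `num_steps` are hypothetical names, and pulling batches with next() assumes
# InfiniteYield follows the iterator protocol; adjust to the wrapper's actual
# interface if it differs.
#
#     train_loaders = runner.get_train_dataloader_dict(do_weighted_sampling=True)
#     for _ in range(num_steps):
#         for task_name, loader in train_loaders.items():
#             batch = next(loader)  # assumed iterator protocol; loader never exhausts
#             ...  # forward/backward pass for task_name
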