Пример #1
0
 def _get_eval_dataloader_dict(self, phase, task_name_list, use_subset=False):
     """Build an evaluation dataloader for each requested task.

     Args:
         phase: Key selecting which cached split to read from
             ``task_cache_dict[task_name]`` (presumably something like
             "val" or "test" -- confirm against callers).
         task_name_list: Iterable of task names to build loaders for.
         use_subset: If True, forward each task's ``eval_subset_num`` as
             ``subset_num``; otherwise forward ``None``.

     Returns:
         dict mapping each task name to the object returned by
         ``get_eval_dataloader_from_cache``. If a name appears more than
         once in ``task_name_list``, the last build wins.
     """

     def _loader_for(name):
         # Look everything up per task, in the same order as before, so a
         # missing key raises the same KeyError it always did.
         container = self.jiant_task_container
         task = container.task_dict[name]
         cache = container.task_cache_dict[name][phase]
         config = container.task_specific_configs[name]
         batch_size = config.eval_batch_size
         subset = config.eval_subset_num if use_subset else None
         return get_eval_dataloader_from_cache(
             eval_cache=cache,
             task=task,
             eval_batch_size=batch_size,
             subset_num=subset,
         )

     return {name: _loader_for(name) for name in task_name_list}
Пример #2
0
    def get_train_dataloader_dict(self,
                                  for_eval=False,
                                  use_subset=False,
                                  do_weighted_sampling=False):
        """Build a dataloader over the "train" cache of every training task.

        Distributed parallel training is not currently supported here (per
        the original author's note).

        Args:
            for_eval: If True, build evaluation-style loaders over each
                task's "train" cache via ``get_eval_dataloader_from_cache``
                with ``eval_batch_size``. Otherwise build training loaders
                via ``get_train_dataloader_from_cache`` wrapped in
                ``InfiniteYield`` (presumably an endlessly repeating
                iterator -- confirm).
            use_subset: Only consulted when ``for_eval`` is True. If set,
                forward the task's ``eval_subset_num`` as ``subset_num``;
                otherwise forward ``None``.
            do_weighted_sampling: Only consulted when ``for_eval`` is False.
                If set, forward the task's ``train_sample_weights`` as
                ``sample_weights_path``; otherwise forward ``None``.

        Returns:
            dict mapping each name in ``task_run_config.train_task_list``
            to its dataloader.
        """
        container = self.jiant_task_container
        loaders = {}
        for name in container.task_run_config.train_task_list:
            task = container.task_dict[name]
            cache = container.task_cache_dict[name]["train"]
            config = container.task_specific_configs[name]

            if for_eval:
                # Finite, eval-style pass over the training split.
                loaders[name] = get_eval_dataloader_from_cache(
                    eval_cache=cache,
                    task=task,
                    eval_batch_size=config.eval_batch_size,
                    subset_num=config.eval_subset_num if use_subset else None,
                )
                continue

            if do_weighted_sampling:
                weights_path = config.train_sample_weights
                logger.info('building train loader with sample weights "%s"',
                            weights_path)
            else:
                logger.info('building train loader without sample weights')
                weights_path = None

            # The seed flag is read per task, so global_train_config is only
            # touched when at least one training loader is actually built.
            fix_seed = (container.global_train_config
                        .fix_seed_for_weighted_sampler)
            loaders[name] = InfiniteYield(
                get_train_dataloader_from_cache(
                    train_cache=cache,
                    task=task,
                    train_batch_size=config.train_batch_size,
                    sample_weights_path=weights_path,
                    fix_seed_for_weighted_sampler=fix_seed,
                ))
        return loaders