def batches(self, stage: Stage, rank=0, world_size=1, data_source=None):
        """Yield batches drawn from every entry of ``self.data_dict``.

        Each entry's own ``batches(stage, rank, world_size)`` iterator is
        collected under its name, a sampler decides the order in which those
        iterators are consumed, and every yielded batch is tagged with the
        name of the entry it came from.

        Args:
            stage: Stage being run; ``Stage.TRAIN`` selects the configured
                sampler, any other stage uses ``EvalBatchSampler``.
            rank: Forwarded unchanged to each entry's ``batches`` call.
            world_size: Forwarded unchanged to each entry's ``batches`` call.
            data_source: Accepted but not used by this method.

        Yields:
            The batch objects produced by the sampler, mutated in place so
            that ``batch[BatchContext.TASK_NAME]`` holds the source name.
        """
        per_task_iterators = {}
        for task_name, task_data in self.data_dict.items():
            per_task_iterators[task_name] = task_data.batches(
                stage, rank, world_size
            )

        if stage == Stage.TRAIN:
            # Training uses whatever sampler the config describes.
            chosen_sampler = create_component(
                self.sampler_config, iterators=per_task_iterators
            )
        else:
            chosen_sampler = EvalBatchSampler(per_task_iterators)

        for task_name, task_batch in chosen_sampler.batchify(per_task_iterators):
            # Record the originating task on the batch itself before handing it out.
            task_batch[BatchContext.TASK_NAME] = task_name
            yield task_batch
 def from_config(
     cls,
     config: Config,
     data_dict: Dict[str, Data],
     task_key: str = BatchContext.TASK_NAME,
     rank=0,
     world_size=1,
     init_tensorizers=True,
 ):
     """Build an instance of ``cls`` with one batch sampler per stage.

     The TRAIN sampler is created from ``config.sampler`` through
     ``create_component``; EVAL and TEST each get their own new
     ``EvalBatchSampler``.

     Args:
         cls: Class to instantiate.
             NOTE(review): no ``@classmethod`` decorator is visible in this
             excerpt; confirm it is present in the real file.
         config: Supplies ``sampler`` and ``test_key``.
         data_dict: Mapping of names to ``Data`` objects, passed through
             to the constructor.
         task_key: Passed through to the constructor as its last argument.
         rank: Accepted but not used in this body.
         world_size: Accepted but not used in this body.
         init_tensorizers: Accepted but not used in this body.

     Returns:
         ``cls(data_dict, samplers, config.test_key, task_key)``.
     """
     train_sampler = create_component(ComponentType.BATCH_SAMPLER, config.sampler)
     samplers = {Stage.TRAIN: train_sampler}
     # EVAL and TEST must not share one sampler object, so build one per stage.
     for eval_like_stage in (Stage.EVAL, Stage.TEST):
         samplers[eval_like_stage] = EvalBatchSampler()
     return cls(data_dict, samplers, config.test_key, task_key)
示例#3
0
    def test_batch_sampler(self):
        """Check the item order produced by round-robin and eval samplers.

        Two sources of unequal length ("A" with five items, "B" with three)
        are fed to each sampler, and the resulting sequence is compared with
        the expected one via ``self._check_iterator``.
        """
        numbers = ["1", "2", "3", "4", "5"]
        letters = ["a", "b", "c"]

        def fresh_sources():
            # A new mapping per call, so no sampler sees a dict used by another.
            return {"A": numbers, "B": letters}

        # Round robin without iter_to_set_epoch: the expected sequence
        # stops at "4", the first item after "B" has run out.
        default_round_robin = RoundRobinBatchSampler().batchify(fresh_sources())
        self._check_iterator(
            default_round_robin, ["1", "a", "2", "b", "3", "c", "4"]
        )

        # Round robin with iter_to_set_epoch="A": the expected sequence keeps
        # going past "4", with "B" items appearing again, until "A" is used up.
        epoch_on_a = RoundRobinBatchSampler(iter_to_set_epoch="A").batchify(
            fresh_sources()
        )
        self._check_iterator(
            epoch_on_a, ["1", "a", "2", "b", "3", "c", "4", "a", "5", "b"]
        )

        # Eval sampler: all of "A" is expected first, then all of "B".
        sequential = EvalBatchSampler().batchify(fresh_sources())
        self._check_iterator(
            sequential, ["1", "2", "3", "4", "5", "a", "b", "c"]
        )
示例#4
0
 def test_eval_batch_sampler(self):
     """Check the order in which ``EvalBatchSampler`` emits ``self.iter_dict``.

     The expected sequence lists the numeric items first and the letter
     items after them.
     """
     self._check_iterator(
         EvalBatchSampler().batchify(self.iter_dict),
         ["1", "2", "3", "4", "5", "a", "b", "c"],
     )
 class Config(Data.Config):
     # Configuration extending Data.Config with a batch sampler setting.
     # The default is an EvalBatchSampler.Config() instance, created once
     # when this class body is executed.
     sampler: BaseBatchSampler.Config = EvalBatchSampler.Config()
示例#6
0
 class Config(Component.Config):
     # Component configuration with a batch sampler setting and a test key.
     # The sampler default is an EvalBatchSampler.Config() instance, created
     # once when this class body is executed.
     sampler: BaseBatchSampler.Config = EvalBatchSampler.Config()
     # Optional string, None by default.
     # NOTE(review): how this key is used is not visible here; confirm
     # against the code that reads config.test_key.
     test_key: Optional[str] = None
 class Config(Component.Config):
     # Component configuration with an epoch size, a batch sampler setting
     # and a test key.
     # Optional integer, None by default.
     # NOTE(review): the meaning of None (e.g. "no limit") is not visible
     # here; confirm against the code that reads config.epoch_size.
     epoch_size: Optional[int] = None
     # The sampler default is an EvalBatchSampler.Config() instance, created
     # once when this class body is executed.
     sampler: BaseBatchSampler.Config = EvalBatchSampler.Config()
     # Optional string, None by default.
     # NOTE(review): how this key is used is not visible here; confirm
     # against the code that reads config.test_key.
     test_key: Optional[str] = None