# Imports needed by these helpers (module level in the original file).
# DistributedSampler is standard PyTorch; the get_world_size location is an
# assumption based on Pythia's utils layout.
from torch.utils.data.distributed import DistributedSampler
from pythia.utils.distributed_utils import get_world_size


def _add_extra_args_for_dataloader(self, task, other_args=None):
    # Avoid a shared mutable default: the original `other_args={}` dict is
    # mutated below, so it would leak state across calls.
    if other_args is None:
        other_args = {}

    training_parameters = self.config.training_parameters

    # Shuffle every split except test.
    other_args["shuffle"] = task.dataset_type != "test"

    if (training_parameters.local_rank is not None
            and training_parameters.distributed):
        other_args["sampler"] = DistributedSampler(
            task, shuffle=other_args["shuffle"])
        # "shuffle" is mutually exclusive with "sampler" in DataLoader, so
        # let DistributedSampler take care of shuffling and pop it from the
        # main args.
        other_args.pop("shuffle")
        setattr(self, "{}_sampler".format(task.dataset_type),
                other_args["sampler"])

    batch_size = training_parameters.batch_size
    world_size = get_world_size()

    if batch_size % world_size != 0:
        raise RuntimeError(
            "Batch size {} must be divisible by the number "
            "of GPUs {} used.".format(batch_size, world_size))

    # DataLoader receives the per-GPU share of the global batch size.
    other_args["batch_size"] = batch_size // world_size

    return other_args
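
# A minimal usage sketch (not from the original source): wiring the extra
# args into a torch DataLoader. `build_dataloader` is a hypothetical name;
# `task` is assumed to be a dataset-like object exposing a `dataset_type`
# attribute ("train", "val", or "test").
from torch.utils.data import DataLoader

def build_dataloader(self, task):
    extra_args = self._add_extra_args_for_dataloader(task)
    # When a DistributedSampler was attached, "shuffle" has already been
    # popped, so these kwargs never conflict inside DataLoader.
    return DataLoader(task, **extra_args)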

def get_batch_size():
    # Local import kept as in the original (commonly done to avoid a
    # circular import with the registry module).
    from pythia.common.registry import registry

    batch_size = registry.get("config").training_parameters.batch_size
    world_size = get_world_size()

    if batch_size % world_size != 0:
        raise RuntimeError(
            "Batch size {} must be divisible by the number "
            "of GPUs {} used.".format(batch_size, world_size))

    # Return the per-GPU batch size.
    return batch_size // world_size
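
# Usage sketch (illustrative, not from the original): with a global batch
# size of 512 registered in the config and 8 GPUs, each process loads
# 512 // 8 = 64 samples per step. The dataset below is a placeholder.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.zeros(512, 3))  # placeholder dataset
loader = DataLoader(dataset, batch_size=get_batch_size())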

def _add_extra_args_for_dataloader(self, other_args=None):
    # Avoid a shared mutable default for the same reason as above.
    if other_args is None:
        other_args = {}

    training_parameters = self.config.training_parameters

    if (training_parameters.local_rank is not None
            and training_parameters.distributed):
        # In the distributed case the sampler handles shard assignment and
        # shuffling (DistributedSampler shuffles by default), so "shuffle"
        # is never set alongside it.
        other_args["sampler"] = DistributedSampler(self.current_dataset)
    else:
        other_args["shuffle"] = True

    batch_size = training_parameters.batch_size
    world_size = get_world_size()

    if batch_size % world_size != 0:
        raise RuntimeError(
            "Batch size {} must be divisible by the number "
            "of GPUs {} used.".format(batch_size, world_size))

    other_args["batch_size"] = batch_size // world_size

    return other_args
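
# Usage sketch for the refactored variant (hypothetical caller): the dataset
# now comes from self.current_dataset rather than a `task` argument, and
# "shuffle"/"sampler" are set in mutually exclusive branches, so no pop()
# is needed before constructing the DataLoader.
def build_dataloader(self):
    extra_args = self._add_extra_args_for_dataloader()
    return DataLoader(self.current_dataset, **extra_args)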