def test_batch_size_per_device(self):
    """A global batch size is split across devices; a per-device one is kept."""

    def build_loaded_trainer(**batch_kwargs):
        # 100 samples, max_updates=2, no max_epochs. The batch size resolved by
        # get_batch_size() is written back into the config before the datasets
        # are loaded.
        mock_trainer = TrainerTrainingLoopMock(100, 2, None, **batch_kwargs)
        registry.register("config", mock_trainer.config)
        resolved = get_batch_size()
        mock_trainer.config.training.batch_size = resolved
        mock_trainer.load_datasets()
        return mock_trainer

    # Patch get_world_size as seen by mmf.utils.general, not
    # mmf.utils.distributed, because the former is the one that gets used.
    with patch("mmf.utils.general.get_world_size", return_value=2):
        # A global batch size of 4 over a world size of 2 should give a
        # per-device batch size of 4 // 2 = 2 in the train loader.
        trainer = build_loaded_trainer(batch_size=4)
        self.assertEqual(trainer.train_loader.current_loader.batch_size, 2)

        # batch_size_per_device is already per device, so it stays the same.
        trainer = build_loaded_trainer(batch_size_per_device=4)
        self.assertEqual(trainer.train_loader.current_loader.batch_size, 4)

    self.assertEqual(trainer._calculate_max_updates(), 2)

    self.check_values(trainer, 0, 0, 0)
    trainer.training_loop()
    self.check_values(trainer, 2, 1, 2)
def _add_extra_args_for_dataloader(
    dataset_instance: mmf_typings.DatasetType,
    other_args: mmf_typings.DataLoaderArgsType = None,
) -> mmf_typings.DataLoaderArgsType:
    """Fill in the sampler and batch size for a dataloader.

    Args:
        dataset_instance: dataset the loader will wrap. Its ``dataset_type``
            decides whether the distributed sampler shuffles (everything
            except "test" is shuffled).
        other_args: extra dataloader kwargs. Mutated in place and returned.
            ``None`` is treated as an empty dict. Any ``"shuffle"`` entry is
            removed because a sampler is always installed.

    Returns:
        ``other_args`` with ``sampler`` and ``batch_size`` set.
    """
    from mmf.utils.general import get_batch_size, get_class_weight

    if other_args is None:
        other_args = {}

    batch_size = get_batch_size()
    dataset_type = dataset_instance.dataset_type
    # Keep the shuffle decision in a local. The "shuffle" key is removed from
    # other_args below, so reading it back later would raise KeyError.
    shuffle = dataset_type != "test"

    if is_dist_initialized():
        # In distributed mode, we use DistributedSampler from PyTorch, which
        # also takes care of shuffling.
        other_args["sampler"] = torch.utils.data.DistributedSampler(
            dataset_instance, shuffle=shuffle
        )
    else:
        # NOTE(review): WeightedRandomSampler expects one weight per dataset
        # sample and draws ``num_samples`` indices per epoch. Here the weights
        # come from get_class_weight() and num_samples is the batch size. If
        # get_class_weight() returns per-class weights, indices would only span
        # the class count and an epoch would be a single batch. Confirm.
        other_args["sampler"] = WeightedRandomSampler(
            torch.from_numpy(np.array(get_class_weight())), batch_size
        )

    # Shuffle is mutually exclusive with sampler; the sampler now owns it.
    other_args.pop("shuffle", None)
    other_args["batch_size"] = batch_size
    return other_args
def _add_extra_args_for_dataloader(
    dataset_instance: torch.utils.data.Dataset, other_args: Dict[str, Any] = None
) -> Dict[str, Any]:
    """Complete dataloader kwargs with shuffle/sampler and batch-size defaults.

    Args:
        dataset_instance: dataset the loader will wrap. Its ``dataset_type``
            decides the default shuffle behaviour.
        other_args: caller-provided dataloader kwargs, updated in place and
            returned. ``None`` is treated as an empty dict. A ``"shuffle"`` or
            ``"batch_size"`` entry that is missing or ``None`` is filled in.

    Returns:
        The kwargs mapping with ``batch_size`` set. When running under
        torch.distributed or XLA, ``shuffle`` is replaced by a ``sampler``.
    """
    from mmf.utils.general import get_batch_size

    if other_args is None:
        other_args = {}

    dataset_type = dataset_instance.dataset_type

    # Respect an explicit caller choice; otherwise shuffle all but "test".
    # Kept in a local because the key may be popped by the first sampler
    # branch and still be needed by the second.
    shuffle = other_args.get("shuffle")
    if shuffle is None:
        shuffle = dataset_type != "test"
    other_args["shuffle"] = shuffle

    # In distributed mode, we use DistributedSampler from PyTorch
    if is_dist_initialized():
        other_args["sampler"] = torch.utils.data.DistributedSampler(
            dataset_instance, shuffle=shuffle
        )
        # Shuffle is mutually exclusive with sampler, let DistributedSampler
        # take care of shuffle and pop from main args
        other_args.pop("shuffle", None)

    if is_xla():
        # One replica per XLA device. drop_last=True so every replica gets the
        # same number of samples (uneven shapes across cores are a problem).
        other_args["sampler"] = torch.utils.data.DistributedSampler(
            dataset_instance,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=shuffle,
            drop_last=True,
        )
        other_args.pop("shuffle", None)

    if other_args.get("batch_size") is None:
        other_args["batch_size"] = get_batch_size()

    return other_args
def _add_extra_args_for_dataloader(
    dataset_instance: torch.utils.data.Dataset,
    other_args: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]:
    """Populate dataloader kwargs with shuffle/sampler and batch size.

    ``other_args`` (a fresh dict when None) is mutated in place and returned.
    Any caller-provided ``shuffle`` is overwritten: only the "test" split is
    left unshuffled. When torch.distributed is initialized, the flag is handed
    to a DistributedSampler instead, since shuffle and sampler are mutually
    exclusive DataLoader options.
    """
    from mmf.utils.general import get_batch_size

    kwargs = {} if other_args is None else other_args
    kwargs["shuffle"] = dataset_instance.dataset_type != "test"

    if is_dist_initialized():
        kwargs["sampler"] = torch.utils.data.DistributedSampler(
            dataset_instance, shuffle=kwargs["shuffle"]
        )
        # The sampler owns shuffling now; drop the plain flag.
        del kwargs["shuffle"]

    kwargs["batch_size"] = get_batch_size()
    return kwargs
def __init__(self, config):
    """Create one MultiDatasetLoader per split and load each from ``config``."""
    super().__init__()
    self.config = config
    self.batch_size = get_batch_size()

    # Construct all three loaders first, then load them in the same order.
    self.train_loader, self.val_loader, self.test_loader = [
        MultiDatasetLoader(split) for split in ("train", "val", "test")
    ]
    for split_loader in (self.train_loader, self.val_loader, self.test_loader):
        split_loader.load(self.config)
def __init__(self, config: DictConfig):
    """Resolve the dataset list from ``config`` and build one datamodule each.

    The train/val/test loaders are only declared here (set to None); they are
    assigned elsewhere.
    """
    super().__init__()
    self.config = config
    self.batch_size = get_batch_size()

    names: List[str] = dataset_list_from_config(config)
    self.dataset_list: List[str] = names
    self.datamodules: List[pl.LightningDataModule] = build_multiple_datamodules(
        names, config.dataset_config
    )

    self.train_loader: Optional[MultiDataLoader] = None
    self.val_loader: Optional[MultiDataLoader] = None
    self.test_loader: Optional[MultiDataLoader] = None
def _add_extra_args_for_dataloader(self, other_args=None):
    """Return dataloader kwargs that never shuffle, plus the batch size.

    Under torch.distributed, a non-shuffling DistributedSampler over
    ``self.current_dataset`` is used. Otherwise ``shuffle=False`` is set
    directly. ``other_args`` is mutated in place (a new dict when None).
    """
    kwargs = other_args if other_args is not None else {}

    if not is_dist_initialized():
        kwargs["shuffle"] = False
    else:
        kwargs["sampler"] = DistributedSampler(self.current_dataset, shuffle=False)

    kwargs["batch_size"] = get_batch_size()
    return kwargs
def __init__(self, config, dataset_type, imdb_file_index, *args, **kwargs):
    """Set up an airstore-backed dataset for one split.

    Args:
        config: dataset config. ``config.annotations.get(dataset_type)`` is
            indexed with ``imdb_file_index`` to pick the airstore URI.
        dataset_type: split name. It is also stored as ``self.split``.
        imdb_file_index: index into that split's annotation list.
        *args, **kwargs: accepted for constructor compatibility; unused here.
    """
    super().__init__("airstore", config, dataset_type)
    self.pathmanager = create_path_manager()
    self.config = config
    self.batch_size = get_batch_size()
    self.airstore_uri = config.annotations.get(dataset_type)[imdb_file_index]
    self.split = dataset_type
    self.epoch = 0
    self.start_iter = 0

    # torch.distributed.get_rank()/get_world_size() raise when no process
    # group exists, and torch.distributed exposes no other APIs when it is
    # unavailable. Fall back to a single-process view in those cases.
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        self.global_rank = dist.get_rank()
        self.global_world_size = dist.get_world_size()
    else:
        self.global_rank = 0
        self.global_world_size = 1

    self._iterator = None
def __len__(self):
    """Return the length in batches, since this object is consumed as an iterator."""
    per_batch = get_batch_size()

    if not is_xla():
        # Non-XLA loaders are assumed to use drop_last=False (see also
        # build_dataloader_and_sampler()), so a trailing partial batch counts.
        return (self._total_length + per_batch - 1) // per_batch

    # On XLA the batch may be split across multiple cores, and some cores may
    # not get enough examples. drop_last=True is required there, so only full
    # batches are counted.
    logging.info(
        "drop_last is set to True to avoid uneven dimension shapes "
        "across cores."
    )
    return self._total_length // per_batch
def test_exit_on_nan_losses(self, a):
    """The training loop must raise RuntimeError when the model's loss is NaN.

    ``a`` is accepted for the caller's signature (presumably injected by a
    patch decorator) and is not used here.
    """
    config = self._get_config(max_updates=2, max_epochs=None, batch_size=4)
    trainer = TrainerTrainingLoopMock(config=config)
    add_model(trainer, SimpleNaNLossModel({"in_dim": 1}))
    add_optimizer(trainer, config)

    registry.register("config", trainer.config)
    batch_size = get_batch_size()
    trainer.config.training.batch_size = batch_size
    trainer.load_datasets()

    # assertRaises fails the test if no RuntimeError is raised, which replaces
    # the manual flag-and-assert pattern.
    with self.assertRaises(RuntimeError):
        trainer.training_loop()
def _add_extra_args_for_dataloader(self, dataset, opts, other_args=None):
    """Fill dataloader kwargs with shuffle/sampler and batch size.

    Args:
        dataset: dataset the loader will wrap (used for the distributed sampler).
        opts: accepted for interface compatibility; not used here.
        other_args: kwargs dict mutated in place and returned (new dict when None).

    Returns:
        The kwargs with ``batch_size`` set. Also set is either ``shuffle``
        (True unless this loader's split is "test") or, under
        torch.distributed, a DistributedSampler carrying that shuffle flag.
    """
    if other_args is None:
        other_args = {}

    other_args["shuffle"] = self._dataset_type != "test"

    # torch.distributed exposes no other APIs (is_initialized included) when
    # is_available() is False, so check availability first.
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        other_args["sampler"] = DistributedSampler(
            dataset, shuffle=other_args["shuffle"]
        )
        # Shuffle is mutually exclusive with sampler, let DistributedSampler
        # take care of shuffle and pop from main args
        other_args.pop("shuffle")

    other_args["batch_size"] = get_batch_size()
    return other_args
def __init__(
    self,
    num_train_data=100,
    max_updates=None,
    max_epochs=None,
    config=None,
    optimizer=None,
    update_frequency=1,
    batch_size=1,
    batch_size_per_device=None,
    fp16=False,
    on_update_end_fn=None,
    scheduler_config=None,
    grad_clipping_config=None,
):
    """Build a lightweight trainer stand-in for training-loop tests.

    Args:
        num_train_data: size of the NumbersDataset used for the train loader.
        max_updates: written into the training config when not None.
        max_epochs: written into the training config when not None.
        config: full config to use. When None, a minimal config is built
            from ``update_frequency``, ``fp16``, ``batch_size`` and
            ``batch_size_per_device``.
        optimizer: optimizer to use; a MagicMock is installed when None.
        scheduler_config: when truthy, enables ``lr_scheduler`` on the config
            and attaches an LRSchedulerCallback.
        grad_clipping_config: dict with ``"max_grad_l2_norm"`` and
            ``"clip_norm_mode"``; when truthy, enables gradient clipping.
        on_update_end_fn: hook for ``on_update_end``. It defaults to the LR
            scheduler's hook when a scheduler is configured, else a MagicMock.
    """
    if config is None:
        self.config = OmegaConf.create(
            {
                "training": {
                    "detect_anomaly": False,
                    "evaluation_interval": 10000,
                    "update_frequency": update_frequency,
                    "fp16": fp16,
                    "batch_size": batch_size,
                    "batch_size_per_device": batch_size_per_device,
                }
            }
        )
        self.training_config = self.config.training
    else:
        self.training_config = config.training
        self.config = config

    # Resolve the effective batch size against this mock's config, then
    # restore the global registry. This happens even if get_batch_size()
    # raises, so one test cannot leak its config into the next.
    original_config = registry.get("config")
    registry.register("config", self.config)
    try:
        batch_size = get_batch_size()
    finally:
        registry.register("config", original_config)

    if max_updates is not None:
        self.training_config["max_updates"] = max_updates
    if max_epochs is not None:
        self.training_config["max_epochs"] = max_epochs

    self.model = SimpleModel({"in_dim": 1})
    self.model.build()
    if torch.cuda.is_available():
        self.model = self.model.cuda()
        self.device = "cuda"
    else:
        self.device = "cpu"
    self.distributed = False

    self.dataset_loader = MagicMock()
    self.dataset_loader.seed_sampler = MagicMock(return_value=None)
    self.dataset_loader.prepare_batch = lambda x: SampleList(x)
    if optimizer is None:
        self.optimizer = MagicMock()
        self.optimizer.step = MagicMock(return_value=None)
        self.optimizer.zero_grad = MagicMock(return_value=None)
    else:
        self.optimizer = optimizer

    if scheduler_config:
        # Use self.config rather than the ``config`` argument, which is None
        # when the default config was built above.
        self.config.training.lr_scheduler = True
        self.config.scheduler = scheduler_config
        self.lr_scheduler_callback = LRSchedulerCallback(self.config, self)
        # NOTE(review): self.callbacks is not created in this __init__. If it
        # is a class-level list on a base class, this append is shared across
        # instances. Confirm.
        self.callbacks.append(self.lr_scheduler_callback)
        on_update_end_fn = (
            on_update_end_fn
            if on_update_end_fn
            else self.lr_scheduler_callback.on_update_end
        )

    if grad_clipping_config:
        self.training_config.clip_gradients = True
        self.training_config.max_grad_l2_norm = grad_clipping_config[
            "max_grad_l2_norm"
        ]
        self.training_config.clip_norm_mode = grad_clipping_config["clip_norm_mode"]

    dataset = NumbersDataset(num_train_data)
    self.train_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=1,
        drop_last=False,
    )
    self.train_loader.current_dataset = dataset
    self.on_batch_start = MagicMock(return_value=None)
    self.on_update_start = MagicMock(return_value=None)
    self.logistics_callback = MagicMock(return_value=None)
    self.logistics_callback.log_interval = MagicMock(return_value=None)
    self.on_batch_end = MagicMock(return_value=None)
    self.on_update_end = (
        on_update_end_fn if on_update_end_fn else MagicMock(return_value=None)
    )
    self.meter = Meter()
    self.after_training_loop = MagicMock(return_value=None)
    self.on_validation_start = MagicMock(return_value=None)
    self.evaluation_loop = MagicMock(return_value=(None, None))
    self.scaler = torch.cuda.amp.GradScaler(enabled=False)
    self.val_loader = MagicMock(return_value=None)
    self.early_stop_callback = MagicMock(return_value=None)
    self.on_validation_end = MagicMock(return_value=None)
    self.metrics = MagicMock(return_value=None)
def __len__(self):
    """Return the number of batches in one pass over the data.

    Since this object is consumed as an iterator, its length is counted in
    batches, not samples. Ceiling division counts a trailing partial batch.
    NOTE(review): this assumes the underlying loaders use drop_last=False;
    confirm against the dataloader builder.
    """
    batch_size = get_batch_size()
    # Plain floor division would drop the final, smaller batch and
    # under-report the length whenever total % batch_size != 0.
    return (self._total_length + batch_size - 1) // batch_size
def __len__(self):
    """Return the length in batches, since this object is consumed as an iterator."""
    per_batch = get_batch_size()
    # Round up so that a final partial batch is counted. This relies on the
    # loaders being built with drop_last=False; see also
    # build_dataloader_and_sampler().
    padded_total = self._total_length + per_batch - 1
    return padded_total // per_batch