def build_dataloader_and_sampler(
    dataset_instance: torch.utils.data.Dataset, datamodule_config: DictConfig
) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
    """Builds and returns a dataloader along with its sampler

    Args:
        dataset_instance (torch.utils.data.Dataset): Instance of dataset for which
            dataloader has to be created
        datamodule_config (omegaconf.DictConfig): Datamodule configuration; required
            for inferring params for dataloader

    Returns:
        Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
            Tuple of Dataloader and Sampler instance
    """
    from mmf.common.batch_collator import BatchCollator

    training_config = get_global_config("training")
    # Support params coming in from dataloader params
    other_args = {
        "num_workers": datamodule_config.get(
            "num_workers", training_config.get("num_workers", 4)
        ),
        "pin_memory": datamodule_config.get(
            "pin_memory", training_config.get("pin_memory", False)
        ),
        "shuffle": datamodule_config.get("shuffle", None),
        "batch_size": datamodule_config.get("batch_size", None),
    }

    # IterableDataset returns batches directly, so no need to add Sampler
    # or batch size as user is expected to control those. This is a fine
    # assumption for now to not support single item based IterableDataset
    # as it will add unnecessary complexity and config parameters
    # to the codebase
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance, other_args)
    else:
        other_args.pop("shuffle")

    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        collate_fn=BatchCollator(
            dataset_instance.dataset_name, dataset_instance.dataset_type
        ),
        drop_last=is_xla(),  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    if is_xla():
        device = xm.xla_device()
        loader = xla_pl.MpDeviceLoader(loader, device)

    if other_args["num_workers"] >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, other_args.get("sampler", None)

def summarize_report(
    current_iteration,
    num_updates,
    max_updates,
    meter,
    should_print=True,
    extra=None,
    tb_writer=None,
):
    if extra is None:
        extra = {}
    if not is_master() and not is_xla():
        return

    if tb_writer:
        scalar_dict = meter.get_scalar_dict()
        tb_writer.add_scalars(scalar_dict, current_iteration)

    if not should_print:
        return
    log_dict = {}
    if num_updates is not None and max_updates is not None:
        log_dict.update({"progress": f"{num_updates}/{max_updates}"})

    log_dict.update(meter.get_log_dict())
    log_dict.update(extra)

    log_progress(log_dict)

def _add_extra_args_for_dataloader(
    dataset_instance: torch.utils.data.Dataset, other_args: Dict[str, Any] = None
) -> Dict[str, Any]:
    from mmf.utils.general import get_batch_size

    dataset_type = dataset_instance.dataset_type

    if other_args["shuffle"] is None:
        other_args["shuffle"] = False
        if dataset_type != "test":
            other_args["shuffle"] = True

    # In distributed mode, we use DistributedSampler from PyTorch
    if is_dist_initialized():
        other_args["sampler"] = torch.utils.data.DistributedSampler(
            dataset_instance, shuffle=other_args["shuffle"]
        )
        # Shuffle is mutually exclusive with sampler, let DistributedSampler
        # take care of shuffle and pop from main args
        other_args.pop("shuffle")

    if is_xla():
        other_args["sampler"] = torch.utils.data.DistributedSampler(
            dataset_instance,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=other_args["shuffle"],
            drop_last=True,
        )
        other_args.pop("shuffle")

    if other_args["batch_size"] is None:
        other_args["batch_size"] = get_batch_size()

    return other_args

def get_current_device():
    if is_xla():
        import torch_xla.core.xla_model as xm

        return xm.xla_device()
    if torch.cuda.is_available():
        return f"cuda:{torch.cuda.current_device()}"
    else:
        return torch.device("cpu")

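# Hedged usage sketch, not part of the original source: because get_current_device()
# may return an XLA device, a "cuda:N" string, or a torch.device, callers can pass
# the result straight to `.to()` instead of branching on the device type.
# `to_current_device` is a hypothetical helper name used for illustration only.
import torch


def to_current_device(tensor: torch.Tensor) -> torch.Tensor:
    # `.to()` accepts both device objects and device strings, so every possible
    # return value of get_current_device() is handled uniformly.
    return tensor.to(get_current_device())
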
def __len__(self):
    # Since this is an iterator, we need to return total length == number of batches
    batch_size = get_batch_size()
    # Changed the length to accommodate drop_last == True.
    # drop_last is required if the batch is split across multiple cores;
    # otherwise some of the cores may not have enough examples.
    if is_xla():
        logging.info(
            "drop_last is set to True to avoid uneven dimension shapes "
            "across cores."
        )
        return self._total_length // batch_size
    else:
        # This assumes drop_last=False for all loaders. See also
        # build_dataloader_and_sampler().
        return (self._total_length + batch_size - 1) // batch_size

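# Worked example (not from the original source) of the two length formulas above:
# with self._total_length == 10 and batch_size == 4, the XLA path (drop_last=True)
# reports 10 // 4 == 2 full batches, while the non-XLA path (drop_last=False)
# reports ceil(10 / 4) == (10 + 4 - 1) // 4 == 3 batches, the last one partial.
total_length, batch_size = 10, 4
assert total_length // batch_size == 2  # XLA: the ragged last batch is dropped
assert (total_length + batch_size - 1) // batch_size == 3  # ceil division
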
def __call__(self, update, iteration, meter):
    """Method to be called every time you need to check whether to
    early stop or not

    Arguments:
        update {number}: Current update number
        iteration {number}: Current iteration number

    Returns:
        bool -- Tells whether early stopping occurred or not
    """
    # There are operations involving synchronization downstream.
    # For XLA those calls must be executed from all cores,
    # therefore we do not return early here in case of XLA.
    if not is_master() and not is_xla():
        return False

    value = meter.meters.get(self.early_stop_criteria, None)
    if value is None:
        raise ValueError(
            "Criteria used for early stopping ({}) is not "
            "present in meter.".format(self.early_stop_criteria)
        )

    value = value.global_avg

    if isinstance(value, torch.Tensor):
        value = value.item()

    if (self.minimize and value < self.best_monitored_value) or (
        not self.minimize and value > self.best_monitored_value
    ):
        self.best_monitored_value = value
        self.best_monitored_iteration = iteration
        self.best_monitored_update = update
        self.checkpoint.save(update, iteration, update_best=True)
    elif self.best_monitored_update + self.patience < update:
        self.activated = True
        if self.should_stop is True:
            self.checkpoint.restore()
            self.checkpoint.finalize()
            return True
        else:
            return False
    else:
        self.checkpoint.save(update, iteration, update_best=False)

    return False

def _check_nan_losses(self, report):
    # Skip this check in XLA mode as calling .item() in the forward pass
    # greatly slows down the training
    if not is_xla():
        # Check whether NaN has occurred in the losses, and exit the training
        # when NaN happens
        loss_dict = report.losses
        nan_loss_keys = []
        for key, value in loss_dict.items():
            if torch.any(torch.isnan(value)).item():
                nan_loss_keys.append(key)
        if len(nan_loss_keys) > 0:
            keys_str = ", ".join(nan_loss_keys)
            error_msg = (
                f"NaN occurred in the following loss(es): {keys_str}; "
                f"exiting the training"
            )
            logger.info(error_msg)
            raise RuntimeError(error_msg)

def _finish_update(self):
    if self.training_config.clip_gradients:
        clip_gradients(
            self.model,
            self.num_updates,
            self.logistics_callback.tb_writer,
            self.config,
            scale=self.scaler.get_scale(),
        )
    if is_xla():
        import torch_xla.core.xla_model as xm

        # Assumes no model parallel
        xm.reduce_gradients(self.optimizer)

    self.scaler.step(self.optimizer)
    self.scaler.update()
    self.num_updates += 1
    self.profile("Finished update")

def validate_batch_sizes(my_batch_size: int) -> bool:
    """Validates that all workers got the same batch size."""
    # Skip batch size validation on XLA (as there's too much overhead
    # and the data loader automatically drops the last batch in XLA mode)
    if is_xla():
        return True

    batch_size_tensor = torch.IntTensor([my_batch_size])
    if torch.cuda.is_available():
        batch_size_tensor = batch_size_tensor.cuda()
    all_batch_sizes = gather_tensor(batch_size_tensor)
    for j, oth_batch_size in enumerate(all_batch_sizes.data):
        if oth_batch_size != my_batch_size:
            logger.error(f"Node {j} batch {oth_batch_size} != {my_batch_size}")
            return False
    return True

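# Hedged usage sketch, not part of the original source: the validation above is
# meant to run before any collective operation that assumes equal batch sizes on
# every rank (e.g. an all-reduce of metrics). `assert_equal_batch_sizes` is a
# hypothetical wrapper name; validate_batch_sizes is the function defined above.
def assert_equal_batch_sizes(my_batch_size: int) -> None:
    # Fail fast so a ragged last batch surfaces immediately instead of
    # corrupting or deadlocking a downstream collective call.
    if not validate_batch_sizes(my_batch_size):
        raise RuntimeError(
            "Batch sizes differ across workers; check drop_last on the dataloaders."
        )
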
def summarize_report(
    current_iteration,
    num_updates,
    max_updates,
    meter,
    should_print=True,
    extra=None,
    tb_writer=None,
    wandb_logger=None,
):
    if extra is None:
        extra = {}
    if not is_main() and not is_xla():
        return

    # Log the learning rate if available
    if wandb_logger and "lr" in extra:
        wandb_logger.log_metrics(
            {"train/learning_rate": float(extra["lr"])}, commit=False
        )

    if tb_writer:
        scalar_dict = meter.get_scalar_dict()
        tb_writer.add_scalars(scalar_dict, current_iteration)

    if wandb_logger:
        metrics = meter.get_scalar_dict()
        wandb_logger.log_metrics(
            {**metrics, "trainer/global_step": current_iteration}
        )

    if not should_print:
        return
    log_dict = {}
    if num_updates is not None and max_updates is not None:
        log_dict.update({"progress": f"{num_updates}/{max_updates}"})

    log_dict.update(meter.get_log_dict())
    log_dict.update(extra)

    log_progress(log_dict)

def _summarize_report(self, meter, should_print=True, extra=None):
    if extra is None:
        extra = {}
    if not is_master() and not is_xla():
        return

    if self.training_config.tensorboard:
        scalar_dict = meter.get_scalar_dict()
        self.tb_writer.add_scalars(scalar_dict, self.trainer.current_iteration)

    if not should_print:
        return
    log_dict = {}
    if hasattr(self.trainer, "num_updates") and hasattr(
        self.trainer, "max_updates"
    ):
        log_dict.update(
            {"progress": f"{self.trainer.num_updates}/{self.trainer.max_updates}"}
        )
    log_dict.update(meter.get_log_dict())
    log_dict.update(extra)

    log_progress(log_dict)

def parallelize_model(self) -> None:
    registry.register("data_parallel", False)
    registry.register("distributed", False)

    if (
        "cuda" in str(self.device)
        and torch.cuda.device_count() > 1
        and not self.distributed
    ):
        registry.register("data_parallel", True)
        self.model = torch.nn.DataParallel(self.model)

    if "cuda" in str(self.device) and self.distributed:
        registry.register("distributed", True)
        set_torch_ddp = True
        try:
            from fairscale.nn.data_parallel import ShardedDataParallel
            from fairscale.optim.oss import OSS

            if isinstance(self.optimizer, OSS):
                self.model = ShardedDataParallel(self.model, self.optimizer)
                set_torch_ddp = False
                logger.info("Using FairScale ShardedDataParallel")
        except ImportError:
            logger.info("Using PyTorch DistributedDataParallel")
            warnings.warn(
                "You can enable ZeRO and Sharded DDP, by installing fairscale "
                "and setting optimizer.enable_state_sharding=True."
            )

        if set_torch_ddp:
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.local_rank],
                output_device=self.local_rank,
                find_unused_parameters=self.config.training.find_unused_parameters,
            )

    if is_xla() and get_world_size() > 1:
        broadcast_xla_master_model_param(self.model)

def run(opts: typing.Optional[typing.List[str]] = None, predict: bool = False):
    """Run starts a job based on the command passed from the command line.
    You can optionally run the mmf job programmatically by passing an optlist as opts.

    Args:
        opts (typing.Optional[typing.List[str]], optional): Optlist which can be used
            to override opts programmatically. For example, if you pass
            opts = ["training.batch_size=64", "checkpoint.resume=True"], this will
            set the batch size to 64 and resume from the checkpoint if present.
            Defaults to None.
        predict (bool, optional): If predict is passed True, then the program runs in
            prediction mode. Defaults to False.
    """
    setup_imports()

    if opts is None:
        parser = flags.get_parser()
        args = parser.parse_args()
    else:
        args = argparse.Namespace(config_override=None)
        args.opts = opts

    configuration = Configuration(args)
    # Do set runtime args which can be changed by MMF
    configuration.args = args
    config = configuration.get_config()
    config.start_rank = 0
    if config.distributed.init_method is None:
        infer_init_method(config)

    if config.distributed.init_method is not None:
        if torch.cuda.device_count() > 1 and not config.distributed.no_spawn:
            config.start_rank = config.distributed.rank
            config.distributed.rank = None
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(configuration, predict),
                nprocs=torch.cuda.device_count(),
            )
        else:
            distributed_main(0, configuration, predict)
    elif config.distributed.world_size > 1:
        if is_xla():
            import torch_xla.distributed.xla_multiprocessing as xmp

            torch.multiprocessing.set_sharing_strategy("file_system")
            xmp.spawn(
                fn=distributed_main,
                args=(configuration, predict),
                nprocs=8,  # use all 8 TPU cores
                start_method="fork",
            )
        else:
            assert config.distributed.world_size <= torch.cuda.device_count()
            port = random.randint(10000, 20000)
            config.distributed.init_method = f"tcp://localhost:{port}"
            config.distributed.rank = None
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(configuration, predict),
                nprocs=config.distributed.world_size,
            )
    else:
        config.device_id = 0
        main(configuration, predict=predict)

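# Hedged usage sketch based on the docstring above: starting a job programmatically
# with an optlist instead of command-line arguments. The two override keys are taken
# from the docstring; a real invocation would typically also select a config, model,
# and dataset through additional opts.
if __name__ == "__main__":
    run(opts=["training.batch_size=64", "checkpoint.resume=True"])
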
def build_dataloader_and_sampler(
    dataset_instance: torch.utils.data.Dataset, datamodule_config: DictConfig
) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
    """Builds and returns a dataloader along with its sampler

    Args:
        dataset_instance (torch.utils.data.Dataset): Instance of dataset for which
            dataloader has to be created
        datamodule_config (omegaconf.DictConfig): Datamodule configuration; required
            for inferring params for dataloader

    Returns:
        Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
            Tuple of Dataloader and Sampler instance
    """
    from mmf.common.batch_collator import BatchCollator

    training_config = get_global_config("training")
    # Support params coming in from dataloader params
    other_args = {
        "num_workers": datamodule_config.get(
            "num_workers", training_config.get("num_workers", 4)
        ),
        "pin_memory": datamodule_config.get(
            "pin_memory", training_config.get("pin_memory", False)
        ),
        "shuffle": datamodule_config.get("shuffle", None),
        "batch_size": datamodule_config.get("batch_size", None),
    }
    if version.parse(torch.__version__) >= version.parse("1.8"):
        # only use persistent workers in PyTorch 1.8 or higher
        # (PyTorch 1.7 also has this option but doesn't support it correctly due to
        # https://github.com/pytorch/pytorch/issues/48370)
        other_args["persistent_workers"] = datamodule_config.get(
            "persistent_workers", training_config.get("persistent_workers", True)
        )
        if other_args["persistent_workers"] and other_args["num_workers"] == 0:
            logger.warning(
                "persistent_workers cannot be used together with num_workers == 0; "
                "setting persistent_workers to False"
            )
            other_args["persistent_workers"] = False

    # IterableDataset returns batches directly, so no need to add Sampler
    # or batch size as user is expected to control those. This is a fine
    # assumption for now to not support single item based IterableDataset
    # as it will add unnecessary complexity and config parameters
    # to the codebase
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance, other_args)
    else:
        other_args.pop("shuffle")

    # Set drop_last=True when using XLA to have constant batch size.
    # In this case we also need to set drop_last=True in DistributedSampler.
    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        collate_fn=BatchCollator(
            dataset_instance.dataset_name, dataset_instance.dataset_type
        ),
        drop_last=is_xla(),  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    if is_xla():
        device = xm.xla_device()
        loader = xla_pl.MpDeviceLoader(loader, device)

    if other_args["num_workers"] >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, other_args.get("sampler", None)

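# Hedged usage sketch, not part of the original source: wiring a toy dataset that
# exposes the `dataset_name`/`dataset_type` attributes BatchCollator expects through
# the builder above. It assumes an MMF configuration has already been registered
# (e.g. inside a running job) so that get_global_config("training") resolves; the
# ToyDataset class and config values below are illustrative only.
import torch
from omegaconf import OmegaConf

from mmf.common.sample import Sample


class ToyDataset(torch.utils.data.Dataset):
    dataset_name = "toy"
    dataset_type = "train"

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        # A real MMF dataset would return a Sample built by its processors.
        return Sample({"feature": torch.randn(4)})


loader_config = OmegaConf.create({"num_workers": 0, "batch_size": 4})
loader, sampler = build_dataloader_and_sampler(ToyDataset(), loader_config)
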
def save_func(self, *args):
    return save_xla_ckpt(*args) if is_xla() else torch.save(*args)

def finalize(self):
    if is_master() or is_xla():
        with PathManager.open(self.pth_filepath, "wb") as f:
            self.save_func(self.trainer.model.state_dict(), f)

def save(self, update, iteration=None, update_best=False):
    # Only save in the main process. For XLA we use the xm.save method, which
    # ensures that the actual checkpoint saving happens only on the master node
    # and also takes care of all the necessary synchronization.
    if not is_master() and not is_xla():
        return

    logger.info("Checkpoint save operation started!")
    if not iteration:
        iteration = update

    ckpt_filepath = os.path.join(self.models_foldername, "model_%d.ckpt" % update)
    best_ckpt_filepath = os.path.join(
        self.ckpt_foldername, self.ckpt_prefix + "best.ckpt"
    )
    current_ckpt_filepath = os.path.join(
        self.ckpt_foldername, self.ckpt_prefix + "current.ckpt"
    )

    best_iteration = (
        self.trainer.early_stop_callback.early_stopping.best_monitored_iteration
    )
    best_update = (
        self.trainer.early_stop_callback.early_stopping.best_monitored_update
    )
    best_metric = (
        self.trainer.early_stop_callback.early_stopping.best_monitored_value
    )
    model = self.trainer.model
    data_parallel = registry.get("data_parallel") or registry.get("distributed")
    fp16_scaler = getattr(self.trainer, "scaler", None)
    fp16_scaler_dict = None

    if fp16_scaler is not None:
        fp16_scaler_dict = fp16_scaler.state_dict()

    if data_parallel is True:
        model = model.module

    ckpt = {
        "model": model.state_dict(),
        "optimizer": self.trainer.optimizer.state_dict(),
        "best_iteration": best_iteration,
        "current_iteration": iteration,
        "current_epoch": self.trainer.current_epoch,
        "num_updates": update,
        "best_update": best_update,
        "best_metric_value": best_metric,
        "fp16_scaler": fp16_scaler_dict,
        # Convert to container to avoid any dependencies
        "config": OmegaConf.to_container(self.config, resolve=True),
    }

    lr_scheduler = self.trainer.lr_scheduler_callback._scheduler
    if lr_scheduler is not None:
        ckpt["lr_scheduler"] = lr_scheduler.state_dict()

    if self.git_repo:
        git_metadata_dict = self._get_vcs_fields()
        ckpt.update(git_metadata_dict)

    with PathManager.open(ckpt_filepath, "wb") as f:
        self.save_func(ckpt, f)

    if update_best:
        logger.info("Saving best checkpoint")
        with PathManager.open(best_ckpt_filepath, "wb") as f:
            self.save_func(ckpt, f)

    # Save current always
    logger.info("Saving current checkpoint")
    with PathManager.open(current_ckpt_filepath, "wb") as f:
        self.save_func(ckpt, f)

    # Remove old checkpoints if max_to_keep is set
    if self.max_to_keep > 0:
        if len(self.saved_iterations) == self.max_to_keep:
            self.remove(self.saved_iterations.pop(0))
        self.saved_iterations.append(update)

    logger.info("Checkpoint save operation finished!")

def save_func(self, *args):
    return xm.save(*args) if is_xla() else torch.save(*args)

def finalize(self):
    if is_main() or is_xla():
        with open_if_main(self.pth_filepath, "wb") as f:
            self.save_func(self.trainer.model.state_dict(), f)

def build_dataloader_and_sampler(
    dataset_instance: mmf_typings.DatasetType, training_config: mmf_typings.DictConfig
) -> mmf_typings.DataLoaderAndSampler:
    """Builds and returns a dataloader along with its sampler

    Args:
        dataset_instance (mmf_typings.DatasetType): Instance of dataset for which
            dataloader has to be created
        training_config (mmf_typings.DictConfig): Training configuration; required
            for inferring params for dataloader

    Returns:
        mmf_typings.DataLoaderAndSampler: Tuple of Dataloader and Sampler instance
    """
    from mmf.common.batch_collator import BatchCollator

    num_workers = training_config.num_workers
    pin_memory = training_config.pin_memory

    other_args = {}

    # IterableDataset returns batches directly, so no need to add Sampler
    # or batch size as user is expected to control those. This is a fine
    # assumption for now to not support single item based IterableDataset
    # as it will add unnecessary complexity and config parameters
    # to the codebase
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance, other_args)

    if is_xla():
        other_args["sampler"] = torch.utils.data.DistributedSampler(
            dataset_instance,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=other_args["shuffle"],
        )
        other_args.pop("shuffle")

    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        pin_memory=pin_memory,
        collate_fn=BatchCollator(
            dataset_instance.dataset_name, dataset_instance.dataset_type
        ),
        num_workers=num_workers,
        drop_last=False,  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    if is_xla():
        device = xm.xla_device()
        loader = pl.MpDeviceLoader(loader, device)

    if num_workers >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, other_args.get("sampler", None)