def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    if not trainer.loggers:
        raise MisconfigurationException(
            "Cannot use XLAStatsMonitor callback with Trainer that has no logger."
        )

    device = trainer.strategy.root_device
    memory_info = xm.get_memory_info(device)
    epoch_time = time.time() - self._start_time

    free_memory = memory_info["kb_free"]
    peak_memory = memory_info["kb_total"] - free_memory

    free_memory = trainer.strategy.reduce(free_memory) * 0.001
    peak_memory = trainer.strategy.reduce(peak_memory) * 0.001
    epoch_time = trainer.strategy.reduce(epoch_time)

    for logger in trainer.loggers:
        logger.log_metrics(
            {"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
            step=trainer.current_epoch,
        )

    if self._verbose:
        rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
        rank_zero_info(f"Average Peak memory: {peak_memory:.2f} MB")
        rank_zero_info(f"Average Free memory: {free_memory:.2f} MB")
def on_train_start(self, trainer, pl_module) -> None:
    if not trainer.logger:
        raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")

    if trainer._device_type != _AcceleratorType.TPU:
        raise MisconfigurationException(
            "You are using XLAStatsMonitor but are not running on TPU"
            f" since `tpu_cores` attribute in Trainer is set to {trainer.tpu_cores}."
        )

    memory_info = xm.get_memory_info(pl_module.device)
    total_memory = trainer.strategy.reduce(memory_info["kb_total"]) * 0.001
    rank_zero_info(f"Average Total memory: {total_memory:.2f} MB")
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    if not trainer.loggers:
        raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")

    if not isinstance(trainer.accelerator, TPUAccelerator):
        raise MisconfigurationException(
            "You are using XLAStatsMonitor but are not running on TPU."
            f" The accelerator is set to {trainer.accelerator.__class__.__name__}."
        )

    device = trainer.strategy.root_device
    memory_info = xm.get_memory_info(device)
    total_memory = trainer.strategy.reduce(memory_info["kb_total"]) * 0.001
    rank_zero_info(f"Average Total memory: {total_memory:.2f} MB")
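# Usage sketch (assumption, not part of the original source): how the callback that
# defines the hooks above might be attached to a Trainer. The import path and the
# `verbose` argument are inferred from the `self._verbose` flag used in the hooks;
# the TPU settings below are placeholders.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import XLAStatsMonitor

trainer = pl.Trainer(
    accelerator="tpu",
    devices=8,
    callbacks=[XLAStatsMonitor(verbose=True)],
)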
def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
    """Gets stats for the given TPU device.

    Args:
        device: TPU device for which to get stats

    Returns:
        A dictionary mapping the metrics (free memory and peak memory) to their values.
    """
    memory_info = xm.get_memory_info(device)
    free_memory = memory_info["kb_free"]
    peak_memory = memory_info["kb_total"] - free_memory
    device_stats = {
        "avg. free memory (MB)": free_memory,
        "avg. peak memory (MB)": peak_memory,
    }
    return device_stats
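# Usage sketch (assumption, not part of the original source): if `get_device_stats`
# above belongs to Lightning's TPUAccelerator, the generic DeviceStatsMonitor callback
# is the usual way to have these stats collected each batch and forwarded to the
# configured logger. The TPU settings below are placeholders.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import DeviceStatsMonitor

trainer = pl.Trainer(
    accelerator="tpu",
    devices=8,
    callbacks=[DeviceStatsMonitor()],
)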
def on_train_epoch_end(self, trainer, pl_module) -> None:
    logs = {}
    memory_info = xm.get_memory_info(pl_module.device)
    epoch_time = time.time() - self._start_time

    free_memory = memory_info["kb_free"]
    peak_memory = memory_info["kb_total"] - free_memory

    free_memory = trainer.training_type_plugin.reduce(free_memory) * 0.001
    peak_memory = trainer.training_type_plugin.reduce(peak_memory) * 0.001
    epoch_time = trainer.training_type_plugin.reduce(epoch_time)

    logs["avg. free memory (MB)"] = free_memory
    logs["avg. peak memory (MB)"] = peak_memory
    trainer.logger.log_metrics(logs, step=trainer.current_epoch)

    if self._verbose:
        rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
        rank_zero_info(f"Average Peak memory: {peak_memory:.2f} MB")
        rank_zero_info(f"Average Free memory: {free_memory:.2f} MB")
def train_step(self, samples, raise_oom=False):
    """Do forward, backward and parameter update."""
    self._set_seed()
    self.model.train()
    self.criterion.train()
    self.zero_grad()

    metrics.log_start_time("train_wall", priority=800, round=0)

    # forward and backward pass
    logging_outputs, sample_size, ooms = [], 0, 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            is_dummy_batch = True
        else:
            if self._dummy_batch == "DUMMY":
                self._dummy_batch = sample
            is_dummy_batch = False

        def maybe_no_sync():
            """
            Whenever *samples* contains more than one mini-batch, we
            want to accumulate gradients locally and only call
            all-reduce in the last backwards pass.
            """
            if (
                self.data_parallel_world_size > 1
                and hasattr(self.model, "no_sync")
                and i < len(samples) - 1
            ):
                return self.model.no_sync()
            else:
                return contextlib.ExitStack()  # dummy contextmanager

        try:
            with maybe_no_sync():
                # forward and backward
                loss, sample_size_i, logging_output = self.task.train_step(
                    sample=sample,
                    model=self.model,
                    criterion=self.criterion,
                    optimizer=self.optimizer,
                    update_num=self.get_num_updates(),
                    ignore_grad=is_dummy_batch,
                )
                del loss

            logging_outputs.append(logging_output)
            sample_size += sample_size_i

            # emptying the CUDA cache after the first step can
            # reduce the chance of OOM
            if self.cuda and self.get_num_updates() == 0:
                torch.cuda.empty_cache()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                if raise_oom:
                    raise e
                logger.warning(
                    "attempting to recover from OOM in forward/backward pass"
                )
                ooms += 1
                self.zero_grad()
                if self.cuda:
                    torch.cuda.empty_cache()
                if self.cfg.distributed_training.distributed_world_size == 1:
                    return None
            else:
                raise e

        if self.tpu and i < len(samples) - 1:
            # tpu-comment: every XLA operation before marking step is
            # appended to the IR graph, and processing too many batches
            # before marking step can lead to OOM errors.
            # To handle gradient accumulation use case, we explicitly
            # mark step here for every forward pass without a backward pass
            import torch_xla.core.xla_model as xm

            xm.mark_step()

    if is_dummy_batch:
        if torch.is_tensor(sample_size):
            sample_size.zero_()
        else:
            sample_size *= 0.0

    if torch.is_tensor(sample_size):
        sample_size = sample_size.float()
    else:
        sample_size = float(sample_size)

    # gather logging outputs from all replicas
    if self._sync_stats():
        train_time = self._local_cumulative_training_time()
        logging_outputs, (
            sample_size,
            ooms,
            total_train_time,
        ) = self._aggregate_logging_outputs(
            logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch
        )
        self._cumulative_training_time = (
            total_train_time / self.data_parallel_world_size
        )

    overflow = False
    try:
        with torch.autograd.profiler.record_function("reduce-grads"):
            self.optimizer.all_reduce_grads(self.model)
            if utils.has_parameters(self.criterion):
                self.optimizer.all_reduce_grads(self.criterion)

        with torch.autograd.profiler.record_function("multiply-grads"):
            # multiply gradients by (data_parallel_size / sample_size) since
            # DDP already normalizes by the number of data parallel workers.
            # Thus we get (sum_of_gradients / sample_size) at the end.
            if not self.cfg.optimization.use_bmuf:
                self.optimizer.multiply_grads(
                    self.data_parallel_world_size / sample_size
                )
            elif sample_size > 0:  # BMUF needs to check sample size
                num = self.data_parallel_world_size if self._sync_stats() else 1
                self.optimizer.multiply_grads(num / sample_size)

        with torch.autograd.profiler.record_function("clip-grads"):
            # clip grads
            grad_norm = self.clip_grad_norm(self.cfg.optimization.clip_norm)

        # check that grad norms are consistent across workers
        # on tpu check tensor is slow
        if not self.tpu:
            if (
                not self.cfg.optimization.use_bmuf
                and self.cfg.distributed_training.distributed_wrapper != "SlowMo"
            ):
                self._check_grad_norms(grad_norm)
            if not torch.isfinite(grad_norm).all():
                # check local gradnorm single GPU case, trigger NanDetector
                raise FloatingPointError("gradients are Nan/Inf")

        with torch.autograd.profiler.record_function("optimizer"):
            # take an optimization step
            self.optimizer.step()
    except FloatingPointError:
        # re-run the forward and backward pass with hooks attached to print
        # out where it fails
        with NanDetector(self.get_model()):
            self.task.train_step(
                sample,
                self.model,
                self.criterion,
                self.optimizer,
                self.get_num_updates(),
                ignore_grad=False,
            )
        raise
    except OverflowError as e:
        overflow = True
        logger.info("NOTE: overflow detected, " + str(e))
        grad_norm = torch.tensor(0.0).cuda()
        self.zero_grad()
    except RuntimeError as e:
        if "out of memory" in str(e):
            self._log_oom(e)
            logger.error("OOM during optimization, irrecoverable")
        raise e

    # Some distributed wrappers (e.g., SlowMo) need access to the optimizer
    # after the step
    if hasattr(self.model, "perform_additional_optimizer_actions"):
        if hasattr(self.optimizer, "fp32_params"):
            self.model.perform_additional_optimizer_actions(
                self.optimizer.optimizer, self.optimizer.fp32_params
            )
        else:
            self.model.perform_additional_optimizer_actions(
                self.optimizer.optimizer
            )

    logging_output = None
    if (
        not overflow
        or self.cfg.distributed_training.distributed_wrapper == "SlowMo"
    ):
        self.set_num_updates(self.get_num_updates() + 1)

        if self.tpu:
            # mark step on TPUs
            import torch_xla.core.xla_model as xm

            xm.mark_step()

            # only log stats every log_interval steps
            # this causes wps to be misreported when log_interval > 1
            logging_output = {}
            if self.get_num_updates() % self.cfg.common.log_interval == 0:
                # log memory usage
                mem_info = xm.get_memory_info(self.device)
                gb_free = mem_info["kb_free"] / 1024 / 1024
                gb_total = mem_info["kb_total"] / 1024 / 1024
                metrics.log_scalar(
                    "gb_free", gb_free, priority=1500, round=1, weight=0
                )
                metrics.log_scalar(
                    "gb_total", gb_total, priority=1600, round=1, weight=0
                )

                logging_output = self._reduce_and_log_stats(
                    logging_outputs, sample_size, grad_norm
                )

            # log whenever there's an XLA compilation, since these
            # slow down training and may indicate opportunities for
            # optimization
            self._check_xla_compilation()
        else:
            # log stats
            logging_output = self._reduce_and_log_stats(
                logging_outputs, sample_size, grad_norm
            )

            # clear CUDA cache to reduce memory fragmentation
            if (
                self.cuda
                and self.cfg.common.empty_cache_freq > 0
                and (
                    (self.get_num_updates() + self.cfg.common.empty_cache_freq - 1)
                    % self.cfg.common.empty_cache_freq
                )
                == 0
            ):
                torch.cuda.empty_cache()

    if self.cfg.common.fp16:
        metrics.log_scalar(
            "loss_scale",
            self.optimizer.scaler.loss_scale,
            priority=700,
            round=4,
            weight=0,
        )

    metrics.log_stop_time("train_wall")

    return logging_output
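# Standalone sketch (assumption, not part of the original source) of the TPU memory
# logging used inside train_step above: torch_xla's get_memory_info returns kilobyte
# counts under "kb_free" and "kb_total" on the XLA runtimes this code targets, which
# the trainer converts to GB before logging via metrics.log_scalar.
import torch_xla.core.xla_model as xm

device = xm.xla_device()
mem_info = xm.get_memory_info(device)
gb_free = mem_info["kb_free"] / 1024 / 1024
gb_total = mem_info["kb_total"] / 1024 / 1024
print(f"TPU memory: {gb_free:.1f} GB free / {gb_total:.1f} GB total")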