def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Reduce per-epoch XLA memory stats and epoch time, then report them.

        Free and used ("peak" = total - free) device memory are read with
        ``xm.get_memory_info`` on the strategy's root device. They are passed
        through ``trainer.strategy.reduce`` and scaled by 0.001 (KB -> MB).
        The results are then sent to every attached logger at
        ``step=trainer.current_epoch``. The elapsed time since
        ``self._start_time`` is reduced too and, like the memory figures, is
        printed on rank zero when ``self._verbose`` is set.

        Raises:
            MisconfigurationException: if the trainer has no loggers.
        """
        if not trainer.loggers:
            raise MisconfigurationException(
                "Cannot use XLAStatsMonitor callback with Trainer that has no logger."
            )

        strategy = trainer.strategy
        info = xm.get_memory_info(strategy.root_device)
        elapsed = time.time() - self._start_time

        kb_free = info["kb_free"]
        kb_used = info["kb_total"] - kb_free

        # These reductions are collective calls; every rank issues them in the
        # same order (free, used, time).
        free_mb = strategy.reduce(kb_free) * 0.001
        peak_mb = strategy.reduce(kb_used) * 0.001
        elapsed = strategy.reduce(elapsed)

        free_value = float(free_mb)
        peak_value = float(peak_mb)
        for lg in trainer.loggers:
            # A fresh dict per logger, so no logger can observe another's edits.
            lg.log_metrics(
                {
                    "avg. free memory (MB)": free_value,
                    "avg. peak memory (MB)": peak_value,
                },
                step=trainer.current_epoch,
            )

        if not self._verbose:
            return
        rank_zero_info(f"Average Epoch time: {elapsed:.2f} seconds")
        rank_zero_info(f"Average Peak memory: {peak_mb:.2f} MB")
        rank_zero_info(f"Average Free memory: {free_mb:.2f} MB")

    def on_train_start(self, trainer, pl_module) -> None:
        """Validate preconditions and print the reduced total XLA memory.

        NOTE(review): another ``on_train_start`` is defined right after this
        one; if both live in the same class, this definition is shadowed and
        never runs. Confirm and remove one of them.

        Raises:
            MisconfigurationException: if the trainer has no logger, or if
                ``trainer._device_type`` is not ``_AcceleratorType.TPU``.
        """
        if not trainer.logger:
            raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")

        # Keep `!=` (not `not ==`): for str-based enums the two can differ.
        if trainer._device_type != _AcceleratorType.TPU:
            cores = trainer.tpu_cores
            raise MisconfigurationException(
                "You are using XLAStatsMonitor but are not running on TPU"
                f" since `tpu_cores` attribute in Trainer is set to {cores}."
            )

        device_memory = xm.get_memory_info(pl_module.device)
        # KB -> MB after the cross-process reduction.
        total_mb = trainer.strategy.reduce(device_memory["kb_total"]) * 0.001
        rank_zero_info(f"Average Total memory: {total_mb:.2f} MB")

    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Check that a logger and a TPU accelerator are present, then print total memory.

        The total device memory (``kb_total`` from ``xm.get_memory_info`` on the
        strategy's root device) is passed through ``trainer.strategy.reduce``,
        scaled by 0.001 (KB -> MB) and printed on rank zero.

        Raises:
            MisconfigurationException: if the trainer has no loggers, or if its
                accelerator is not a ``TPUAccelerator``.
        """
        if not trainer.loggers:
            raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")

        accelerator = trainer.accelerator
        if not isinstance(accelerator, TPUAccelerator):
            accelerator_name = accelerator.__class__.__name__
            raise MisconfigurationException(
                "You are using XLAStatsMonitor but are not running on TPU."
                f" The accelerator is set to {accelerator_name}."
            )

        strategy = trainer.strategy
        kb_total = xm.get_memory_info(strategy.root_device)["kb_total"]
        total_mb = strategy.reduce(kb_total) * 0.001
        rank_zero_info(f"Average Total memory: {total_mb:.2f} MB")
Ejemplo n.º 4
0
    def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
        """Gets memory stats for the given TPU device.

        ``xm.get_memory_info`` reports sizes under ``kb_free`` / ``kb_total``
        (kilobytes). The returned metric names promise megabytes, so the
        values are converted with the ``* 0.001`` KB -> MB factor. Previously
        the raw KB numbers were returned under the "(MB)" labels, overstating
        them by 1000x.

        Args:
            device: TPU device for which to get stats

        Returns:
            A dictionary mapping the metrics (free memory and used "peak"
            memory, both in MB) to their values.
        """
        memory_info = xm.get_memory_info(device)
        free_kb = memory_info["kb_free"]
        # "peak" is total minus currently free, i.e. memory in use right now.
        used_kb = memory_info["kb_total"] - free_kb
        # NOTE: nothing is averaged here despite the "avg." prefix; the key
        # names are kept unchanged so existing dashboards/loggers still match.
        return {
            "avg. free memory (MB)": free_kb * 0.001,
            "avg. peak memory (MB)": used_kb * 0.001,
        }
Ejemplo n.º 5
0
    def on_train_epoch_end(self, trainer, pl_module) -> None:
        """Reduce and log per-epoch XLA memory stats (legacy plugin API).

        Free and used ("peak" = total - free) memory come from
        ``xm.get_memory_info(pl_module.device)``. They are passed through
        ``trainer.training_type_plugin.reduce``, scaled by 0.001 (KB -> MB)
        and logged via ``trainer.logger`` at ``step=trainer.current_epoch``.
        The elapsed time since ``self._start_time`` is reduced too and printed
        on rank zero together with the memory figures when ``self._verbose``
        is set.
        """
        info = xm.get_memory_info(pl_module.device)
        elapsed = time.time() - self._start_time
        plugin = trainer.training_type_plugin

        kb_free = info["kb_free"]
        kb_used = info["kb_total"] - kb_free

        # Collective reductions: same order on every rank (free, used, time).
        free_mb = plugin.reduce(kb_free) * 0.001
        peak_mb = plugin.reduce(kb_used) * 0.001
        elapsed = plugin.reduce(elapsed)

        trainer.logger.log_metrics(
            {"avg. free memory (MB)": free_mb, "avg. peak memory (MB)": peak_mb},
            step=trainer.current_epoch,
        )

        if not self._verbose:
            return
        rank_zero_info(f"Average Epoch time: {elapsed:.2f} seconds")
        rank_zero_info(f"Average Peak memory: {peak_mb:.2f} MB")
        rank_zero_info(f"Average Free memory: {free_mb:.2f} MB")
Ejemplo n.º 6
0
    def train_step(self, samples, raise_oom=False):
        """Do forward, backward and parameter update.

        Gradients are accumulated over every mini-batch in *samples*, then a
        single optimizer step is taken.

        Args:
            samples: sequence of mini-batches. A batch that ``_prepare_sample``
                turns into ``None`` is replaced by the stored dummy batch and
                run with ``ignore_grad=True``.
            raise_oom: if true, re-raise "out of memory" ``RuntimeError``s from
                the forward/backward pass instead of trying to recover.

        Returns:
            The reduced logging output. This is ``{}`` on TPU steps where stats
            are not logged this update. It is ``None`` when the update is
            skipped because of a gradient overflow (unless the distributed
            wrapper is SlowMo), or when an OOM occurs with
            ``distributed_world_size == 1``.

        Raises:
            FloatingPointError: if the gradient norm is NaN/Inf (non-TPU only),
                after re-running the last sample under ``NanDetector``.
            RuntimeError: non-OOM errors, OOMs when *raise_oom* is set, and any
                OOM raised during the optimization step.
        """
        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        metrics.log_start_time("train_wall", priority=800, round=0)

        # forward and backward pass
        logging_outputs, sample_size, ooms = [], 0, 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                is_dummy_batch = True
            else:
                if self._dummy_batch == "DUMMY":
                    # remember the first real batch as the future dummy batch
                    self._dummy_batch = sample
                is_dummy_batch = False

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (self.data_parallel_world_size > 1
                        and hasattr(self.model, "no_sync")
                        and i < len(samples) - 1):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size_i, logging_output = self.task.train_step(
                        sample=sample,
                        model=self.model,
                        criterion=self.criterion,
                        optimizer=self.optimizer,
                        update_num=self.get_num_updates(),
                        ignore_grad=is_dummy_batch,
                    )
                    del loss

                logging_outputs.append(logging_output)
                sample_size += sample_size_i

                # emptying the CUDA cache after the first step can
                # reduce the chance of OOM
                if self.cuda and self.get_num_updates() == 0:
                    torch.cuda.empty_cache()
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise
                    logger.warning(
                        "attempting to recover from OOM in forward/backward pass"
                    )
                    ooms += 1
                    self.zero_grad()
                    if self.cuda:
                        torch.cuda.empty_cache()
                    # With a single worker the whole update is abandoned;
                    # presumably multi-worker runs continue so that the
                    # collective calls below stay matched across ranks.
                    if self.cfg.distributed_training.distributed_world_size == 1:
                        return None
                else:
                    raise

            if self.tpu and i < len(samples) - 1:
                # tpu-comment: every XLA operation before marking step is
                # appended to the IR graph, and processing too many batches
                # before marking step can lead to OOM errors.
                # To handle gradient accumulation use case, we explicitly
                # mark step here for every forward pass without a backward pass
                import torch_xla.core.xla_model as xm

                xm.mark_step()

        # Only the flag of the last processed batch is consulted here.
        if is_dummy_batch:
            if torch.is_tensor(sample_size):
                sample_size.zero_()
            else:
                sample_size *= 0.0

        if torch.is_tensor(sample_size):
            sample_size = sample_size.float()
        else:
            sample_size = float(sample_size)

        # gather logging outputs from all replicas
        if self._sync_stats():
            train_time = self._local_cumulative_training_time()
            logging_outputs, (
                sample_size,
                ooms,
                total_train_time,
            ) = self._aggregate_logging_outputs(
                logging_outputs,
                sample_size,
                ooms,
                train_time,
                ignore=is_dummy_batch,
            )
            self._cumulative_training_time = (total_train_time /
                                              self.data_parallel_world_size)

        overflow = False
        try:
            with torch.autograd.profiler.record_function("reduce-grads"):
                self.optimizer.all_reduce_grads(self.model)
                if utils.has_parameters(self.criterion):
                    self.optimizer.all_reduce_grads(self.criterion)

            with torch.autograd.profiler.record_function("multiply-grads"):
                # multiply gradients by (data_parallel_size / sample_size) since
                # DDP already normalizes by the number of data parallel workers.
                # Thus we get (sum_of_gradients / sample_size) at the end.
                if not self.cfg.optimization.use_bmuf:
                    # sample_size can be 0 here (it is zeroed above for a dummy
                    # batch). Dividing by it would raise ZeroDivisionError for
                    # a float, or give inf for a tensor, turning zero gradients
                    # into NaN. Divide by 1 instead: the dummy-batch gradients
                    # were computed with ignore_grad=True. torch.where avoids a
                    # host/device sync for tensor sample sizes.
                    if torch.is_tensor(sample_size):
                        denom = torch.where(sample_size != 0, sample_size,
                                            torch.ones_like(sample_size))
                    else:
                        denom = sample_size if sample_size != 0 else 1.0
                    self.optimizer.multiply_grads(
                        self.data_parallel_world_size / denom)
                elif sample_size > 0:  # BMUF needs to check sample size
                    num = self.data_parallel_world_size if self._sync_stats() else 1
                    self.optimizer.multiply_grads(num / sample_size)

            with torch.autograd.profiler.record_function("clip-grads"):
                # clip grads
                grad_norm = self.clip_grad_norm(
                    self.cfg.optimization.clip_norm)

            # check that grad norms are consistent across workers
            # on tpu check tensor is slow
            if not self.tpu:
                if (not self.cfg.optimization.use_bmuf
                        and self.cfg.distributed_training.distributed_wrapper
                        != "SlowMo"):
                    self._check_grad_norms(grad_norm)
                if not torch.isfinite(grad_norm).all():
                    # check local gradnorm single GPU case, trigger NanDetector
                    raise FloatingPointError("gradients are Nan/Inf")

            with torch.autograd.profiler.record_function("optimizer"):
                # take an optimization step
                self.optimizer.step()

        except FloatingPointError:
            # re-run the forward and backward pass with hooks attached to print
            # out where it fails
            with NanDetector(self.get_model()):
                self.task.train_step(
                    sample,
                    self.model,
                    self.criterion,
                    self.optimizer,
                    self.get_num_updates(),
                    ignore_grad=False,
                )
            raise
        except OverflowError as e:
            overflow = True
            logger.info("NOTE: overflow detected, " + str(e))
            grad_norm = torch.tensor(0.0).cuda()
            self.zero_grad()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                logger.error("OOM during optimization, irrecoverable")
            raise

        # Some distributed wrappers (e.g., SlowMo) need access to the optimizer after the step
        if hasattr(self.model, "perform_additional_optimizer_actions"):
            if hasattr(self.optimizer, "fp32_params"):
                self.model.perform_additional_optimizer_actions(
                    self.optimizer.optimizer, self.optimizer.fp32_params)
            else:
                self.model.perform_additional_optimizer_actions(
                    self.optimizer.optimizer)

        logging_output = None
        if (not overflow or self.cfg.distributed_training.distributed_wrapper
                == "SlowMo"):
            self.set_num_updates(self.get_num_updates() + 1)

            if self.tpu:
                # mark step on TPUs
                import torch_xla.core.xla_model as xm

                xm.mark_step()

                # only log stats every log_interval steps
                # this causes wps to be misreported when log_interval > 1
                logging_output = {}
                if self.get_num_updates() % self.cfg.common.log_interval == 0:
                    # log memory usage (kb -> gb)
                    mem_info = xm.get_memory_info(self.device)
                    gb_free = mem_info["kb_free"] / 1024 / 1024
                    gb_total = mem_info["kb_total"] / 1024 / 1024
                    metrics.log_scalar(
                        "gb_free",
                        gb_free,
                        priority=1500,
                        round=1,
                        weight=0,
                    )
                    metrics.log_scalar(
                        "gb_total",
                        gb_total,
                        priority=1600,
                        round=1,
                        weight=0,
                    )

                    logging_output = self._reduce_and_log_stats(
                        logging_outputs,
                        sample_size,
                        grad_norm,
                    )

                # log whenever there's an XLA compilation, since these
                # slow down training and may indicate opportunities for
                # optimization
                self._check_xla_compilation()
            else:
                # log stats
                logging_output = self._reduce_and_log_stats(
                    logging_outputs,
                    sample_size,
                    grad_norm,
                )

                # clear CUDA cache to reduce memory fragmentation
                if (self.cuda and self.cfg.common.empty_cache_freq > 0
                        and ((self.get_num_updates() +
                              self.cfg.common.empty_cache_freq - 1) %
                             self.cfg.common.empty_cache_freq) == 0):
                    torch.cuda.empty_cache()

        if self.cfg.common.fp16:
            metrics.log_scalar(
                "loss_scale",
                self.optimizer.scaler.loss_scale,
                priority=700,
                round=4,
                weight=0,
            )

        metrics.log_stop_time("train_wall")
        return logging_output