コード例 #1
0
    def slurm_sigusr1_handler_fn(self, signum: _SIGNUM,
                                 frame: FrameType) -> None:
        """Handle SLURM's SIGUSR1: checkpoint the run and ask SLURM to requeue the job.

        Every process running this handler finalizes its loggers and saves an HPC
        checkpoint; only the global-zero process calls ``scontrol requeue``.

        Args:
            signum: The received signal number (unused).
            frame: The interrupted stack frame (unused).
        """
        import shlex

        rank_zero_info("handling SIGUSR1")

        # finalize every logger so the metrics logged so far are saved
        for logger in self.trainer.loggers:
            logger.finalize("finished")
        # TODO: in v1.8 change this to use self.trainer.default_root_dir
        hpc_save_path = self.trainer._checkpoint_connector.hpc_save_path(
            self.trainer._weights_save_path_internal)
        self.trainer.save_checkpoint(hpc_save_path)

        if not self.trainer.is_global_zero:
            return

        # SLURM exports the id of the running job in this variable
        job_id = os.environ["SLURM_JOB_ID"]
        cmd = ["scontrol", "requeue", job_id]

        log.info(f"requeuing job {job_id}...")
        try:
            result = call(cmd)
        except FileNotFoundError:
            # This can occur if a subprocess call to `scontrol` is run outside a shell context.
            # Re-attempt the call through a shell; any error raised now propagates to the user.
            # A shell command must be a single string, so quote every argument to keep the
            # shell from interpreting any part of it.
            result = call(" ".join(shlex.quote(str(part)) for part in cmd), shell=True)

        if result == 0:
            log.info(f"requeued exp {job_id}")
        else:
            log.warning("requeue failed...")
コード例 #2
0
    def resume_end(self) -> None:
        """Finish restoring from a checkpoint and release the loaded checkpoint's memory.

        Logs what was restored (all states when fitting, model weights when validating,
        testing or predicting), clears the resume bookkeeping, empties the CUDA cache and
        finally waits on the strategy barrier.
        """
        fn = self.trainer.state.fn
        assert fn is not None
        ckpt_path = self.resume_checkpoint_path

        if ckpt_path:
            if fn == TrainerFn.FITTING:
                rank_zero_info(f"Restored all states from the checkpoint file at {ckpt_path}")
            elif fn in (TrainerFn.VALIDATING, TrainerFn.TESTING, TrainerFn.PREDICTING):
                rank_zero_info(f"Loaded model weights from checkpoint at {ckpt_path}")

        # TODO: remove resume_from_checkpoint_fit_path in v2.0
        fit_path_consumed = fn == TrainerFn.FITTING and ckpt_path == self.resume_from_checkpoint_fit_path
        if fit_path_consumed:
            self.resume_from_checkpoint_fit_path = None

        # drop references to the checkpoint so its memory can be reclaimed
        self.resume_checkpoint_path = None
        self._loaded_checkpoint = {}
        torch.cuda.empty_cache()

        # wait for all processes to catch up
        self.trainer.strategy.barrier("CheckpointConnector.resume_end")
コード例 #3
0
 def _initialize_deepspeed_inference(self, model):
     """Wrap ``model`` in a DeepSpeed engine set up for inference and store it on ``self.model``.

     If the DeepSpeed config has no ``"optimizer"`` section, the optimizer and scheduler
     come from ``self._init_optimizers()``. The inference config carries the ``"fp16"``
     section when present and, under ZeRO stage 3, the ZeRO settings from ``self.config``.
     Module hooks are removed from ``model`` before it is handed to ``deepspeed.initialize``.
     """
     # todo: Currently DeepSpeed requires optimizers at inference to partition weights correctly
     optimizer = None
     scheduler = None
     if "optimizer" not in self.config:
         rank_zero_info(
             "You have not specified an optimizer or scheduler within the DeepSpeed config."
             " Using `configure_optimizers` to define optimizer and scheduler."
         )
         optimizer, scheduler_cfg, _ = self._init_optimizers()
         if scheduler_cfg is not None:
             scheduler = scheduler_cfg.scheduler

     # todo: this is required for DeepSpeed throughput timers
     inference_config = {"train_micro_batch_size_per_gpu": 1}
     if "fp16" in self.config:
         inference_config["fp16"] = self.config["fp16"]
     if self.zero_stage_3:
         inference_config["zero_allow_untested_optimizer"] = self.config["zero_allow_untested_optimizer"]
         inference_config["zero_optimization"] = self.config["zero_optimization"]

     # Remove all module hooks before initializing new model
     remove_module_hooks(model)
     engine, _, _, _ = deepspeed.initialize(
         args=argparse.Namespace(device_rank=self.root_device.index),
         config=inference_config,
         model=model,
         optimizer=optimizer,
         lr_scheduler=scheduler,
         model_parameters=[],
         dist_init_required=False,
     )
     self.model = engine
コード例 #4
0
    def _initialize_deepspeed_train(self, model):
        """Wrap ``model`` in a DeepSpeed engine for training and register optimizer/scheduler.

        The optimizer and scheduler come from ``configure_optimizers`` unless the DeepSpeed
        config defines an ``"optimizer"`` section. If the resulting engine owns an LR
        scheduler, it is detached from the engine and registered in
        ``self.lr_scheduler_configs`` so that Lightning steps it.
        """
        optimizer, scheduler, scheduler_cfg = None, None, None
        if "optimizer" not in self.config:
            optimizer, scheduler_cfg, _ = self._init_optimizers()
            if scheduler_cfg is not None:
                scheduler = scheduler_cfg.scheduler
        else:
            rank_zero_info(
                "You have specified an optimizer and/or scheduler within the DeepSpeed config."
                " It is recommended to define it in `LightningModule.configure_optimizers`."
            )

        engine, deepspeed_optimizer = self._setup_model_and_optimizer(model, optimizer, scheduler)
        self._set_deepspeed_activation_checkpointing()

        # although we set these here, deepspeed manages the specific optimizer logic
        self.optimizers = [deepspeed_optimizer]

        deepspeed_scheduler = engine.lr_scheduler
        if deepspeed_scheduler is not None:
            # disable deepspeed lr scheduling as lightning manages scheduling
            engine.lr_scheduler = None
            if scheduler_cfg is not None:
                scheduler_cfg.scheduler = deepspeed_scheduler
            else:
                scheduler_cfg = LRSchedulerConfig(deepspeed_scheduler, interval="step", opt_idx=0)
            self.lr_scheduler_configs = [scheduler_cfg]
        self.model = engine
コード例 #5
0
    def __validate_init_configuration(self) -> None:
        if self.save_top_k < -1:
            raise MisconfigurationException(
                f"Invalid value for save_top_k={self.save_top_k}. Must be >= -1"
            )
        if self._every_n_train_steps < 0:
            raise MisconfigurationException(
                f"Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0"
            )
        if self._every_n_epochs < 0:
            raise MisconfigurationException(
                f"Invalid value for every_n_epochs={self._every_n_epochs}. Must be >= 0"
            )

        every_n_train_steps_triggered = self._every_n_train_steps >= 1
        every_n_epochs_triggered = self._every_n_epochs >= 1
        train_time_interval_triggered = self._train_time_interval is not None
        if every_n_train_steps_triggered + every_n_epochs_triggered + train_time_interval_triggered > 1:
            raise MisconfigurationException(
                f"Combination of parameters every_n_train_steps={self._every_n_train_steps}, "
                f"every_n_epochs={self._every_n_epochs} and train_time_interval={self._train_time_interval} "
                "should be mutually exclusive.")

        if self.monitor is None:
            # -1: save all epochs, 0: nothing is saved, 1: save last epoch
            if self.save_top_k not in (-1, 0, 1):
                raise MisconfigurationException(
                    f"ModelCheckpoint(save_top_k={self.save_top_k}, monitor=None) is not a valid"
                    " configuration. No quantity for top_k to track.")

            if self.save_top_k == -1 and self.save_last:
                rank_zero_info(
                    "ModelCheckpoint(save_last=True, save_top_k=-1, monitor=None)"
                    " will duplicate the last checkpoint saved.")
コード例 #6
0
    def teardown(self) -> None:
        """Tear down the strategy after running the parent teardown.

        Suggests ``static_graph=True`` when DDP reports it could be used, unwraps the DDP
        model, reverts layer synchronization after fitting, and on CUDA moves the module to
        CPU and empties the CUDA cache. ``self.lightning_module`` may be ``None`` (e.g. when
        teardown runs on an exception raised before a module was attached); the steps that
        need the module are then skipped instead of raising ``AttributeError``.
        """
        log.detail(f"{self.__class__.__name__}: tearing down strategy")
        super().teardown()

        pl_module = self.lightning_module
        if isinstance(self.model, DistributedDataParallel):
            if (_TORCH_GREATER_EQUAL_1_11 and not self.model.static_graph
                    and self.model._get_ddp_logging_data().get(
                        "can_set_static_graph")):
                rank_zero_info(
                    "Your model can run with static graph optimizations. For future training runs, we suggest you"
                    f" pass `Trainer(..., strategy={self.__class__.__name__}(static_graph=True))` to enable them."
                )
            # unwrap model
            self.model = pl_module

        if (pl_module is not None
                # `pl_module.trainer` can be None if teardown gets called on an exception before
                # the trainer gets set on the LightningModule
                and pl_module.trainer is not None
                and pl_module.trainer.state.fn == TrainerFn.FITTING
                and self._layer_sync):
            self.model = self._layer_sync.revert(self.model)

        if self.root_device.type == "cuda":
            # GPU teardown
            if pl_module is not None:
                log.detail(f"{self.__class__.__name__}: moving model to CPU")
                pl_module.cpu()
            # clean up memory
            torch.cuda.empty_cache()
コード例 #7
0
    def _attach_model_callbacks(self) -> None:
        """Attaches the callbacks defined in the model.

        If a callback returned by the model's configure_callback method has the same type as one or several
        callbacks already present in the trainer callbacks list, it will replace them.
        In addition, all :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks
        will be pushed to the end of the list, ensuring they run last.
        """
        model_callbacks = self.trainer._call_lightning_module_hook(
            "configure_callbacks")
        if not model_callbacks:
            return

        model_callbacks = [
            model_callbacks
        ] if not isinstance(model_callbacks, Sequence) else model_callbacks
        model_callback_types = {type(c) for c in model_callbacks}
        trainer_callback_types = {type(c) for c in self.trainer.callbacks}
        override_types = model_callback_types.intersection(
            trainer_callback_types)
        if override_types:
            rank_zero_info(
                "The following callbacks returned in `LightningModule.configure_callbacks` will override"
                " existing callbacks passed to Trainer:"
                f" {', '.join(sorted(t.__name__ for t in override_types))}")
        # remove all callbacks with a type that occurs in model callbacks
        all_callbacks = [
            c for c in self.trainer.callbacks if type(c) not in override_types
        ]
        all_callbacks.extend(model_callbacks)
        all_callbacks = CallbackConnector._reorder_callbacks(all_callbacks)
        # TODO: connectors refactor: move callbacks list to connector and do not write Trainer state
        self.trainer.callbacks = all_callbacks
コード例 #8
0
ファイル: timer.py プロジェクト: neptune-ai/pytorch-lightning
 def _check_time_remaining(self, trainer: "pl.Trainer") -> None:
     """Ask ``trainer`` to stop once ``self._duration`` has elapsed.

     The local decision is passed through ``trainer.strategy.broadcast`` and then OR-ed into
     ``trainer.should_stop``, so an already-requested stop is never cleared here.
     """
     assert self._duration is not None
     time_is_up = trainer.strategy.broadcast(self.time_elapsed() >= self._duration)
     trainer.should_stop = trainer.should_stop or time_is_up
     if not (time_is_up and self._verbose):
         return
     elapsed = timedelta(seconds=int(self.time_elapsed(RunningStage.TRAINING)))
     rank_zero_info(f"Time limit reached. Elapsed time is {elapsed}. Signaling Trainer to stop.")
コード例 #9
0
 def _load_config(self, config):
     if config is None and self.DEEPSPEED_ENV_VAR in os.environ:
         rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
         config = os.environ[self.DEEPSPEED_ENV_VAR]
     if isinstance(config, (str, Path)):
         if not os.path.isfile(config):
             raise MisconfigurationException(
                 f"You passed in a path to a DeepSpeed config but the path does not exist: {config}"
             )
         with open(config) as f:
             config = json.load(f)
     return config
コード例 #10
0
 def _save_monitor_checkpoint(
         self, trainer: "pl.Trainer",
         monitor_candidates: Dict[str, _METRIC]) -> None:
     """Save a top-k checkpoint keyed on ``self.monitor``, or log that the value did not qualify.

     The value of ``self.monitor`` is looked up in ``monitor_candidates`` (``None`` if absent)
     and checked with ``check_monitor_top_k``. When it does not qualify and ``self.verbose`` is
     set, ``"epoch"`` and ``"step"`` are read from ``monitor_candidates`` for the log message.
     """
     current = monitor_candidates.get(self.monitor)
     if not self.check_monitor_top_k(trainer, current):
         if self.verbose:
             epoch, step = monitor_candidates["epoch"], monitor_candidates["step"]
             rank_zero_info(
                 f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}"
             )
         return
     self._update_best_and_save(current, trainer, monitor_candidates)
コード例 #11
0
 def configure_optimizers(self):
     """Create an Adam optimizer over the trainable parameters plus a MultiStepLR schedule.

     Only parameters with ``requires_grad`` set are optimized; the learning rate, milestones
     and decay factor come from ``self.lr``, ``self.milestones`` and ``self.lr_scheduler_gamma``.
     """
     all_params = list(self.parameters())
     trainable = [param for param in all_params if param.requires_grad]
     rank_zero_info(
         f"The model will start training with only {len(trainable)} "
         f"trainable parameters out of {len(all_params)}.")
     optimizer = optim.Adam(trainable, lr=self.lr)
     scheduler = MultiStepLR(optimizer, milestones=self.milestones, gamma=self.lr_scheduler_gamma)
     return [optimizer], [scheduler]
コード例 #12
0
 def _configure_timer_callback(
     self,
     max_time: Optional[Union[str, timedelta, Dict[str,
                                                   int]]] = None) -> None:
     if max_time is None:
         return
     if any(isinstance(cb, Timer) for cb in self.trainer.callbacks):
         rank_zero_info(
             "Ignoring `Trainer(max_time=...)`, callbacks list already contains a Timer."
         )
         return
     timer = Timer(duration=max_time, interval="step")
     self.trainer.callbacks.append(timer)
コード例 #13
0
    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Check the monitor can run, then log the TPU's total memory reduced across processes.

        Raises:
            MisconfigurationException: If the trainer has no logger, or its accelerator is
                not a ``TPUAccelerator``.
        """
        if not trainer.loggers:
            raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")

        accelerator = trainer.accelerator
        if not isinstance(accelerator, TPUAccelerator):
            raise MisconfigurationException(
                "You are using XLAStatsMonitor but are not running on TPU."
                f" The accelerator is set to {accelerator.__class__.__name__}."
            )

        strategy = trainer.strategy
        memory_info = xm.get_memory_info(strategy.root_device)
        # `kb_total` is reduced across processes, then scaled by 0.001 (KB -> MB)
        total_memory_mb = strategy.reduce(memory_info["kb_total"]) * 0.001
        rank_zero_info(f"Average Total memory: {total_memory_mb:.2f} MB")
コード例 #14
0
    def _configure_model_summary_callback(
            self,
            enable_model_summary: bool,
            weights_summary: Optional[str] = None) -> None:
        if weights_summary is None:
            rank_zero_deprecation(
                "Setting `Trainer(weights_summary=None)` is deprecated in v1.5 and will be removed"
                " in v1.7. Please set `Trainer(enable_model_summary=False)` instead."
            )
            return
        if not enable_model_summary:
            return

        model_summary_cbs = [
            type(cb) for cb in self.trainer.callbacks
            if isinstance(cb, ModelSummary)
        ]
        if model_summary_cbs:
            rank_zero_info(
                f"Trainer already configured with model summary callbacks: {model_summary_cbs}."
                " Skipping setting a default `ModelSummary` callback.")
            return

        if weights_summary == "top":
            # special case the default value for weights_summary to preserve backward compatibility
            max_depth = 1
        else:
            rank_zero_deprecation(
                f"Setting `Trainer(weights_summary={weights_summary})` is deprecated in v1.5 and will be removed"
                " in v1.7. Please pass `pytorch_lightning.callbacks.model_summary.ModelSummary` with"
                " `max_depth` directly to the Trainer's `callbacks` argument instead."
            )
            if weights_summary not in ModelSummaryMode.supported_types():
                raise MisconfigurationException(
                    f"`weights_summary` can be None, {', '.join(ModelSummaryMode.supported_types())}",
                    f" but got {weights_summary}",
                )
            max_depth = ModelSummaryMode.get_max_depth(weights_summary)

        progress_bar_callback = self.trainer.progress_bar_callback
        is_progress_bar_rich = isinstance(progress_bar_callback,
                                          RichProgressBar)

        if progress_bar_callback is not None and is_progress_bar_rich:
            model_summary = RichModelSummary(max_depth=max_depth)
        else:
            model_summary = ModelSummary(max_depth=max_depth)
        self.trainer.callbacks.append(model_summary)
        self.trainer._weights_summary = weights_summary
コード例 #15
0
    def _update_best_and_save(self, current: Tensor, trainer: "pl.Trainer",
                              monitor_candidates: Dict[str, Tensor]) -> None:
        """Record ``current`` among the best-k scores, save its checkpoint, and evict the displaced one.

        Args:
            current: The monitored value for this save. A NaN tensor is replaced by +inf when
                ``self.mode == "min"`` and by -inf otherwise, so it ranks as the worst score.
            trainer: The running trainer, used to build the filepath and to save/remove checkpoints.
            monitor_candidates: Metrics used to format the checkpoint filename; ``"epoch"`` and
                ``"step"`` are read from it when ``self.verbose`` is set.

        Updates ``best_k_models``, ``current_score``, ``kth_best_model_path``/``kth_value``
        (once k entries are held), ``best_model_path`` and ``best_model_score``.
        """
        # save_top_k == -1 keeps everything: k is always one more than the current count,
        # so the eviction branch below never triggers.
        k = len(self.best_k_models
                ) + 1 if self.save_top_k == -1 else self.save_top_k

        del_filepath = None
        if len(self.best_k_models) == k and k > 0:
            # Already holding k entries: drop the current k-th best (the worst kept score)
            # from the bookkeeping; its file is removed only after the new one is saved.
            del_filepath = self.kth_best_model_path
            self.best_k_models.pop(del_filepath)

        # do not save nan, replace with +/- inf
        if isinstance(current, Tensor) and torch.isnan(current):
            current = torch.tensor(
                float("inf" if self.mode == "min" else "-inf"),
                device=current.device)

        # NOTE(review): del_filepath is forwarded, presumably so the helper may reuse the name
        # being evicted — confirm in `_get_metric_interpolated_filepath_name`.
        filepath = self._get_metric_interpolated_filepath_name(
            monitor_candidates, trainer, del_filepath)

        # save the current score
        self.current_score = current
        self.best_k_models[filepath] = current

        if len(self.best_k_models) == k:
            # monitor dict has reached k elements
            # the k-th best is the worst kept score: the max in "min" mode, the min otherwise
            _op = max if self.mode == "min" else min
            self.kth_best_model_path = _op(
                self.best_k_models,
                key=self.best_k_models.get)  # type: ignore[arg-type]
            self.kth_value = self.best_k_models[self.kth_best_model_path]

        # the best is the min in "min" mode, the max otherwise
        _op = min if self.mode == "min" else max
        self.best_model_path = _op(
            self.best_k_models,
            key=self.best_k_models.get)  # type: ignore[arg-type]
        self.best_model_score = self.best_k_models[self.best_model_path]

        if self.verbose:
            epoch = monitor_candidates["epoch"]
            step = monitor_candidates["step"]
            rank_zero_info(
                f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} reached {current:0.5f}"
                f" (best {self.best_model_score:0.5f}), saving model to {filepath!r} as top {k}"
            )
        self._save_checkpoint(trainer, filepath)

        # remove the evicted checkpoint only if the new save did not reuse its path
        if del_filepath is not None and filepath != del_filepath:
            trainer.strategy.remove_checkpoint(del_filepath)
コード例 #16
0
 def _format_precision_config(self) -> None:
     """Add DeepSpeed precision settings derived from the precision plugin.

     Applies only to half/mixed precision, and only when the user's config lacks the section:
     native AMP adds an ``"fp16"`` section built from the loss-scaling attributes, APEX AMP
     adds an ``"amp"`` section using the plugin's ``amp_level`` as ``opt_level``.
     """
     plugin = self.precision_plugin
     if plugin.precision not in (PrecisionType.HALF, PrecisionType.MIXED):
         return
     if "fp16" not in self.config and plugin.amp_type == AMPType.NATIVE:
         # FP16 is a DeepSpeed standalone AMP implementation
         rank_zero_info("Enabling DeepSpeed FP16.")
         self.config["fp16"] = dict(
             enabled=True,
             loss_scale=self.loss_scale,
             initial_scale_power=self.initial_scale_power,
             loss_scale_window=self.loss_scale_window,
             hysteresis=self.hysteresis,
             min_loss_scale=self.min_loss_scale,
         )
     elif "amp" not in self.config and plugin.amp_type == AMPType.APEX:
         rank_zero_info("Enabling DeepSpeed APEX Implementation.")
         self.config["amp"] = dict(enabled=True, opt_level=plugin.amp_level)
コード例 #17
0
    def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
        """Pre-load a checkpoint into memory so that states can later be restored from it.

        The path is chosen in this priority:

        1. ``self._hpc_resume_path`` (HPC weights), if it is set
        2. the given ``checkpoint_path``, if provided
        3. otherwise nothing is loaded

        The chosen path is stored in ``self.resume_checkpoint_path`` and the loaded, validated
        checkpoint in ``self._loaded_checkpoint``.
        """
        chosen_path = self._hpc_resume_path or checkpoint_path
        self.resume_checkpoint_path = chosen_path
        if chosen_path:
            rank_zero_info(f"Restoring states from the checkpoint path at {chosen_path}")
            self._loaded_checkpoint = self._load_and_validate_checkpoint(chosen_path)
        else:
            log.detail("`checkpoint_path` not specified. Skipping checkpoint loading.")
コード例 #18
0
    def teardown(self) -> None:
        """Tear down the strategy.

        Reverts layer synchronization if the attached trainer was fitting, then tears down
        the cluster environment, the precision plugin and the accelerator, in that order.
        """
        rank_zero_info(f"{self.__class__.__name__}: tearing down strategy...")

        pl_module = self.lightning_module
        # `self.lightning_module._trainer` can be None if teardown gets called on an exception before
        # the trainer gets set on the LightningModule
        attached_trainer = None if pl_module is None else pl_module._trainer
        was_fitting = attached_trainer is not None and attached_trainer.state.fn == TrainerFn.FITTING
        if was_fitting and self._layer_sync:
            assert self.model is not None
            self.model = self._layer_sync.revert(self.model)

        assert self.cluster_environment is not None
        self.cluster_environment.teardown()
        self.precision_plugin.teardown()
        self.accelerator.teardown()
コード例 #19
0
    def on_train_epoch_end(self, trainer: "pl.Trainer",
                           pl_module: "pl.LightningModule") -> None:
        """Log TPU memory stats reduced across processes, and optionally print epoch stats.

        ``kb_total - kb_free`` from ``xm.get_memory_info`` is reported as the "peak" memory;
        memory values are reduced and then scaled by 0.001 (KB -> MB).

        Raises:
            MisconfigurationException: If the trainer has no logger.
        """
        if not trainer.loggers:
            raise MisconfigurationException(
                "Cannot use XLAStatsMonitor callback with Trainer that has no logger."
            )

        strategy = trainer.strategy
        memory_info = xm.get_memory_info(strategy.root_device)
        elapsed_s = time.time() - self._start_time

        free_kb = memory_info["kb_free"]
        used_kb = memory_info["kb_total"] - free_kb

        # keep the reduce calls in this order (presumably collective — the order must match across processes)
        free_memory = strategy.reduce(free_kb) * 0.001
        peak_memory = strategy.reduce(used_kb) * 0.001
        epoch_time = strategy.reduce(elapsed_s)

        for logger in trainer.loggers:
            # a fresh dict per logger, so no logger sees another's modifications
            logger.log_metrics(
                {"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
                step=trainer.current_epoch,
            )

        if self._verbose:
            report = (
                ("Epoch time", epoch_time, "seconds"),
                ("Peak memory", peak_memory, "MB"),
                ("Free memory", free_memory, "MB"),
            )
            for label, value, unit in report:
                rank_zero_info(f"Average {label}: {value:.2f} {unit}")
コード例 #20
0
    def teardown(self) -> None:
        """Tear down the strategy, then defer to the parent teardown.

        Suggests ``static_graph=True`` when DDP reports it could be used, unwraps the DDP
        model back to the LightningModule, and reverts layer synchronization if the attached
        trainer was fitting.
        """
        log.detail(f"{self.__class__.__name__}: tearing down strategy")

        pl_module = self.lightning_module
        if isinstance(self.model, DistributedDataParallel):
            static_graph_possible = (
                _TORCH_GREATER_EQUAL_1_11
                and not self.model.static_graph
                and self.model._get_ddp_logging_data().get("can_set_static_graph")
            )
            if static_graph_possible:
                rank_zero_info(
                    "Your model can run with static graph optimizations. For future training runs, we suggest you"
                    f" pass `Trainer(..., strategy={self.__class__.__name__}(static_graph=True))` to enable them."
                )
            # unwrap model
            self.model = pl_module

        # `self.lightning_module._trainer` can be None if teardown gets called on an exception before
        # the trainer gets set on the LightningModule
        attached_trainer = None if pl_module is None else pl_module._trainer
        if attached_trainer is not None and attached_trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
            self.model = self._layer_sync.revert(self.model)
        super().teardown()
コード例 #21
0
    def _configure_model_summary_callback(self,
                                          enable_model_summary: bool) -> None:
        if not enable_model_summary:
            return

        model_summary_cbs = [
            type(cb) for cb in self.trainer.callbacks
            if isinstance(cb, ModelSummary)
        ]
        if model_summary_cbs:
            rank_zero_info(
                f"Trainer already configured with model summary callbacks: {model_summary_cbs}."
                " Skipping setting a default `ModelSummary` callback.")
            return

        progress_bar_callback = self.trainer.progress_bar_callback
        is_progress_bar_rich = isinstance(progress_bar_callback,
                                          RichProgressBar)

        if progress_bar_callback is not None and is_progress_bar_rich:
            model_summary = RichModelSummary()
        else:
            model_summary = ModelSummary()
        self.trainer.callbacks.append(model_summary)
コード例 #22
0
    def done(self) -> bool:
        """Evaluates when to leave the loop.

        Returns ``True`` (after logging the reason) when there are no training batches, when
        ``max_steps`` or ``max_epochs`` is reached, or when ``trainer.should_stop`` is set and
        the ``min_epochs``/``min_steps`` requirements are met. Otherwise resets
        ``trainer.should_stop`` to ``False`` and returns ``False``.
        """
        if self.trainer.num_training_batches == 0:
            rank_zero_info("`Trainer.fit` stopped: No training batches.")
            return True

        # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
        stop_steps = _is_max_limit_reached(self.epoch_loop.global_step, self.max_steps)
        if stop_steps:
            rank_zero_info(f"`Trainer.fit` stopped: `max_steps={self.max_steps!r}` reached.")
            return True

        # `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
        # we use it here because the checkpoint data won't have `completed` increased yet
        stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
        if stop_epochs:
            # in case they are not equal, override so `trainer.current_epoch` has the expected value
            self.epoch_progress.current.completed = self.epoch_progress.current.processed
            rank_zero_info(f"`Trainer.fit` stopped: `max_epochs={self.max_epochs!r}` reached.")
            return True

        if self.trainer.should_stop:
            # early stopping
            # a falsy min_epochs / min_steps (None or 0) imposes no minimum
            met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True
            met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True
            if met_min_epochs and met_min_steps:
                self.trainer.should_stop = True
                rank_zero_debug("`Trainer.fit` stopped: `trainer.should_stop` was set.")
                return True
            else:
                rank_zero_info(
                    f"Trainer was signaled to stop but the required `min_epochs={self.min_epochs!r}` or"
                    f" `min_steps={self.min_steps!r}` has not been met. Training will continue..."
                )
        # training continues: clear any stop request that could not be honoured yet
        self.trainer.should_stop = False
        return False
コード例 #23
0
    def on_train_epoch_start(self, trainer: "pl.Trainer",
                             pl_module: "pl.LightningModule"):
        """Drive the SWA schedule at the start of each training epoch.

        - When ``current_epoch == swa_start``: move the averaged model to ``self._device``
          (or ``pl_module.device``), replace the first LR scheduler config with an ``SWALR``
          scheduler (or append one), and reset ``n_averaged`` to 0.
        - While ``swa_start <= current_epoch <= swa_end``: call ``update_parameters`` with the
          averaged model and ``pl_module``.
        - When ``current_epoch == swa_end + 1``: transfer the averaged weights into
          ``pl_module``, reset BatchNorm state, and set the trainer up for one extra
          batch that skips backward.
        """
        if trainer.current_epoch == self.swa_start:
            # move average model to request device.
            self._average_model = self._average_model.to(self._device
                                                         or pl_module.device)

            optimizer = trainer.optimizers[0]
            # a single float SWA lr is broadcast to every param group
            if isinstance(self._swa_lrs, float):
                self._swa_lrs = [self._swa_lrs] * len(optimizer.param_groups)

            # NOTE(review): presumably needed because torch LR schedulers built with
            # last_epoch != -1 expect "initial_lr" in every param group — confirm.
            for lr, group in zip(self._swa_lrs, optimizer.param_groups):
                group["initial_lr"] = lr

            self._swa_scheduler = SWALR(
                optimizer,
                swa_lr=self._swa_lrs,
                anneal_epochs=self._annealing_epochs,
                anneal_strategy=self._annealing_strategy,
                last_epoch=trainer.max_epochs
                if self._annealing_strategy == "cos" else -1,
            )
            # We assert that there is only one optimizer on fit start, so know opt_idx is always 0
            default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler,
                                                      opt_idx=0)
            assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1

            if trainer.lr_scheduler_configs:
                # an existing first scheduler is replaced (with a warning if it was not per-epoch)
                scheduler_cfg = trainer.lr_scheduler_configs[0]
                if scheduler_cfg.interval != "epoch" or scheduler_cfg.frequency != 1:
                    rank_zero_warn(
                        f"SWA is currently only supported every epoch. Found {scheduler_cfg}"
                    )
                rank_zero_info(
                    f"Swapping scheduler `{scheduler_cfg.scheduler.__class__.__name__}`"
                    f" for `{self._swa_scheduler.__class__.__name__}`")
                trainer.lr_scheduler_configs[0] = default_scheduler_cfg
            else:
                trainer.lr_scheduler_configs.append(default_scheduler_cfg)

            # count of averaged models, reset when SWA starts
            self.n_averaged = torch.tensor(0,
                                           dtype=torch.long,
                                           device=pl_module.device)

        if self.swa_start <= trainer.current_epoch <= self.swa_end:
            self.update_parameters(self._average_model, pl_module,
                                   self.n_averaged, self._avg_fn)

        # Note: No > here in case the callback is saved with the model and training continues
        if trainer.current_epoch == self.swa_end + 1:

            # Transfer weights from average model to pl_module
            self.transfer_weights(self._average_model, pl_module)

            # Reset BatchNorm for update
            self.reset_batch_norm_and_save_state(pl_module)

            # There is no need to perform either backward or optimizer.step as we are
            # performing only one pass over the train data-loader to compute activation statistics
            # Therefore, we will virtually increase `num_training_batches` by 1 and skip backward.
            trainer.num_training_batches += 1
            trainer.fit_loop._skip_backward = True
            # remember the user's value; presumably restored later in this callback — confirm
            self._accumulate_grad_batches = trainer.accumulate_grad_batches

            trainer.accumulate_grad_batches = trainer.num_training_batches