Example #1
    def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
        if trainer.current_epoch == self.swa_start:
            # move the average model to the requested device.
            self._average_model = self._average_model.to(self._device or pl_module.device)

            optimizers = trainer.optimizers

            for param_group in optimizers[0].param_groups:
                if self._swa_lrs is None:
                    initial_lr = param_group["lr"]

                elif isinstance(self._swa_lrs, float):
                    initial_lr = self._swa_lrs

                else:
                    initial_lr = self._swa_lrs[0]

                param_group["initial_lr"] = initial_lr

            self._swa_lrs = initial_lr

            self._swa_scheduler = SWALR(
                optimizers[0],
                swa_lr=initial_lr,
                anneal_epochs=self._annealing_epochs,
                anneal_strategy=self._annealing_strategy,
                last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1
            )
            _scheduler_config = _get_default_scheduler_config()
            assert _scheduler_config["interval"] == "epoch" and _scheduler_config["frequency"] == 1
            _scheduler_config["scheduler"] = self._swa_scheduler

            if trainer.lr_schedulers:
                lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
                rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}")
                trainer.lr_schedulers[0] = _scheduler_config
            else:
                trainer.lr_schedulers.append(_scheduler_config)

            self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)

        if self.swa_start <= trainer.current_epoch <= self.swa_end:
            self.update_parameters(self._average_model, pl_module, self.n_averaged, self.avg_fn)

        # Note: No > here in case the callback is saved with the model and training continues
        if trainer.current_epoch == self.swa_end + 1:

            # Transfer weights from average model to pl_module
            self.transfer_weights(self._average_model, pl_module)

            # Reset BatchNorm for update
            self.reset_batch_norm_and_save_state(pl_module)

            # There is no need to perform either backward or optimizer.step as we are
            # performing only one pass over the train data-loader to compute activation statistics
            # Therefore, we will virtually increase `num_training_batches` by 1 and skip backward.
            trainer.num_training_batches += 1
            trainer.train_loop._skip_backward = True
            self._accumulate_grad_batches = trainer.accumulate_grad_batches
            trainer.accumulate_grad_batches = len(trainer.train_dataloader)
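All of these examples fill in a dict returned by `_get_default_scheduler_config()`; the helper itself is not shown on this page. Judging from the assertion above, it returns a plain scheduler-config dict with `interval` set to `"epoch"` and `frequency` set to `1`. A minimal sketch of what it plausibly looks like follows; the exact key set varies across PyTorch Lightning versions, so treat it as an assumption rather than the library's definition.

from typing import Any, Dict


def _get_default_scheduler_config() -> Dict[str, Any]:
    # Assumed shape: callers overwrite "scheduler" with a real instance
    # (e.g. SWALR above); the assert only guarantees "interval"/"frequency".
    return {
        "scheduler": None,
        "name": None,
        "interval": "epoch",  # step the scheduler once per epoch
        "frequency": 1,       # every epoch
        "reduce_on_plateau": False,
        "monitor": None,
        "strict": True,
    }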
Example #2
    def _initialize_deepspeed_train(self, model):
        if "optimizer" in self.config:
            optimizer, lr_scheduler = None, _get_default_scheduler_config()
        else:
            rank_zero_info(
                "You have not specified an optimizer or scheduler within the DeepSpeed config."
                " Using `configure_optimizers` to define optimizer and scheduler."
            )
            optimizer, lr_scheduler, _ = self._init_optimizers()

        scheduler = lr_scheduler["scheduler"]

        model_parameters = filter(lambda p: p.requires_grad,
                                  self.model.parameters())
        model, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize(
            config=self.config,
            model=model,
            model_parameters=model_parameters,
            optimizer=optimizer,
            lr_scheduler=scheduler,
            dist_init_required=False,
        )

        self._set_deepspeed_activation_checkpointing()

        # although we set these here, deepspeed manages the specific optimizer logic
        self.lightning_module.trainer.optimizers = [deepspeed_optimizer]
        if deepspeed_scheduler is not None:
            lr_scheduler["scheduler"] = deepspeed_scheduler
            self.lightning_module.trainer.lr_schedulers = [lr_scheduler]
        self.model = model
Example #3
    def _initialize_deepspeed_train(self, model):
        if "optimizer" in self.config:
            optimizer, lr_scheduler = None, _get_default_scheduler_config()
        else:
            rank_zero_info(
                "You have not specified an optimizer or scheduler within the DeepSpeed config."
                " Using `configure_optimizers` to define optimizer and scheduler."
            )
            optimizer, lr_scheduler, _ = self._init_optimizers()

        scheduler = lr_scheduler["scheduler"]
        model, deepspeed_optimizer = self._setup_model_and_optimizer(
            model, optimizer, scheduler)
        self._set_deepspeed_activation_checkpointing()

        # although we set these here, deepspeed manages the specific optimizer logic
        self.lightning_module.trainer.optimizers = [deepspeed_optimizer]

        deepspeed_scheduler = model.lr_scheduler
        if deepspeed_scheduler is not None:
            # disable deepspeed lr scheduling as lightning manages scheduling
            model.lr_scheduler = None
            lr_scheduler["scheduler"] = deepspeed_scheduler
            self.lightning_module.trainer.lr_schedulers = [lr_scheduler]
        self.model = model
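Both DeepSpeed variants branch on whether the user supplied an `optimizer` section in the DeepSpeed config; only when it is absent does Lightning fall back to `configure_optimizers`. For reference, a minimal (and deliberately incomplete) config dict that would take the first branch might look like this:

# Sketch of a DeepSpeed config with its own "optimizer"/"scheduler" sections,
# so _initialize_deepspeed_train() skips configure_optimizers() entirely.
deepspeed_config = {
    "train_micro_batch_size_per_gpu": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 3e-4}},
    "scheduler": {"type": "WarmupLR", "params": {"warmup_num_steps": 100}},
}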
Example #4
    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        if trainer.current_epoch == self.swa_start:
            # move the average model to the requested device.
            self._average_model = self._average_model.to(self._device or pl_module.device)

            optimizer = trainer.optimizers[0]
            if self._swa_lrs is None:
                self._swa_lrs = [param_group["lr"] for param_group in optimizer.param_groups]
            if isinstance(self._swa_lrs, float):
                self._swa_lrs = [self._swa_lrs] * len(optimizer.param_groups)

            for lr, group in zip(self._swa_lrs, optimizer.param_groups):
                group["initial_lr"] = lr

            self._swa_scheduler = SWALR(
                optimizer,
                swa_lr=self._swa_lrs,
                anneal_epochs=self._annealing_epochs,
                anneal_strategy=self._annealing_strategy,
                last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
            )
            default_scheduler_cfg = _get_default_scheduler_config()
            assert default_scheduler_cfg["interval"] == "epoch" and default_scheduler_cfg["frequency"] == 1
            default_scheduler_cfg["scheduler"] = self._swa_scheduler

            if trainer.lr_schedulers:
                scheduler_cfg = trainer.lr_schedulers[0]
                if scheduler_cfg["interval"] != "epoch" or scheduler_cfg["frequency"] != 1:
                    rank_zero_warn(f"SWA is currently only supported every epoch. Found {scheduler_cfg}")
                rank_zero_info(
                    f"Swapping scheduler `{scheduler_cfg['scheduler'].__class__.__name__}`"
                    f" for `{self._swa_scheduler.__class__.__name__}`"
                )
                trainer.lr_schedulers[0] = default_scheduler_cfg
            else:
                trainer.lr_schedulers.append(default_scheduler_cfg)

            self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)

        if self.swa_start <= trainer.current_epoch <= self.swa_end:
            self.update_parameters(self._average_model, pl_module, self.n_averaged, self.avg_fn)

        # Note: No > here in case the callback is saved with the model and training continues
        if trainer.current_epoch == self.swa_end + 1:

            # Transfer weights from average model to pl_module
            self.transfer_weights(self._average_model, pl_module)

            # Reset BatchNorm for update
            self.reset_batch_norm_and_save_state(pl_module)

            # There is no need to perform either backward or optimizer.step as we are
            # performing only one pass over the train data-loader to compute activation statistics
            # Therefore, we will virtually increase `num_training_batches` by 1 and skip backward.
            trainer.num_training_batches += 1
            trainer.fit_loop._skip_backward = True
            self._accumulate_grad_batches = trainer.accumulate_grad_batches

            trainer.accumulate_grad_batches = trainer.num_training_batches
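For context, this callback ships with PyTorch Lightning as `StochasticWeightAveraging` and is attached through the `Trainer`. Argument names have shifted between releases, so the call below is only a sketch; `MyModel` and `train_loader` are assumed to exist.

import pytorch_lightning as pl
from pytorch_lightning.callbacks import StochasticWeightAveraging

# Sketch: run SWA over roughly the last 20% of training with a single SWA lr.
trainer = pl.Trainer(
    max_epochs=50,
    callbacks=[StochasticWeightAveraging(swa_lrs=0.05, swa_epoch_start=0.8)],
)
trainer.fit(MyModel(), train_loader)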
Example #5
    def _init_optimizers(
        self,
    ) -> Tuple[Optimizer, Optional[Union[LRSchedulerTypeTuple]], Optional[int]]:
        optimizers, schedulers, optimizer_frequencies = self.lightning_module.trainer.init_optimizers(
            self.lightning_module
        )
        if len(optimizers) > 1 or len(schedulers) > 1:
            raise MisconfigurationException(
                "DeepSpeed currently only supports single optimizer, single optional scheduler."
            )
        return (
            optimizers[0],
            schedulers[0] if schedulers else _get_default_scheduler_config(),
            optimizer_frequencies[0] if optimizer_frequencies else None,
        )
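`_init_optimizers` defers to the LightningModule's `configure_optimizers` and only accepts a single optimizer with at most one scheduler. A configuration that satisfies that constraint might look like the following sketch:

import pytorch_lightning as pl
import torch


class MyLitModule(pl.LightningModule):
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
        # One optimizer, one scheduler dict -- the shape _init_optimizers expects.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "epoch", "frequency": 1},
        }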
Example #6
    def on_train_epoch_start(
        self,
        trainer: 'pl.Trainer',
        pl_module: 'pl.LightningModule',
    ):
        """
        Replace the current lr scheduler with the SWA scheduler.
        """
        if trainer.current_epoch == self.swa_start:
            optimizer = trainer.optimizers[0]

            # move the average model to the requested device.
            self._average_model = self._average_model.to(self._device
                                                         or pl_module.device)

            _scheduler = self.get_swa_scheduler(optimizer)
            self._swa_scheduler = _get_default_scheduler_config()
            if not isinstance(_scheduler, dict):
                _scheduler = {"scheduler": _scheduler}
            self._swa_scheduler.update(_scheduler)

            if trainer.lr_schedulers:
                lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
                rank_zero_warn(
                    f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}"
                )
                trainer.lr_schedulers[0] = self._swa_scheduler
            else:
                trainer.lr_schedulers.append(self._swa_scheduler)

            self.n_averaged = torch.tensor(0,
                                           dtype=torch.long,
                                           device=pl_module.device)

        if self.swa_start <= trainer.current_epoch <= self.swa_end:
            self.update_parameters(self._average_model, pl_module,
                                   self.n_averaged, self.avg_fn)

        if trainer.current_epoch == self.swa_end + 1:
            raise NotImplementedError("This should never happen (yet)")
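Example #6 relies on a `get_swa_scheduler` hook that is not shown here; since its result is merged into the default config via `update`, it can apparently return either a bare scheduler or a partial scheduler dict. A hypothetical override might be:

from torch.optim.swa_utils import SWALR


def get_swa_scheduler(self, optimizer):
    # Hypothetical hook: return a partial dict; missing keys ("interval",
    # "frequency", ...) are taken from _get_default_scheduler_config().
    return {"scheduler": SWALR(optimizer, swa_lr=0.05, anneal_epochs=10, anneal_strategy="cos")}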
Example #7
        def func(model):
            # Decide the structure of the output from init_optimizers
            optimizers, _, _ = init_optimizers(model)

            if len(optimizers) != 1:
                raise MisconfigurationException(
                    f"`model.configure_optimizers()` returned {len(optimizers)}, but"
                    " learning rate finder only works with single optimizer"
                )

            optimizer = optimizers[0]

            new_lrs = [self.lr_min] * len(optimizer.param_groups)
            for param_group, new_lr in zip(optimizer.param_groups, new_lrs):
                param_group["lr"] = new_lr
                param_group["initial_lr"] = new_lr

            args = (optimizer, self.lr_max, self.num_training)
            scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
            sched_config = _get_default_scheduler_config()
            sched_config.update({"scheduler": scheduler, "interval": "step"})

            return [optimizer], [sched_config], []
Example #8
    def configure_scheduler(self, lr_scheduler):
        scheduler = _get_default_scheduler_config()
        scheduler["scheduler"] = lr_scheduler
        return [scheduler]
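Finally, `configure_scheduler` just wraps a raw scheduler in the default config and returns it as a single-element list. A hedged usage sketch, assuming the method lives on a plugin object named `plugin` and that `model` already exists:

import torch

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
raw_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

[scheduler_cfg] = plugin.configure_scheduler(raw_scheduler)
assert scheduler_cfg["scheduler"] is raw_scheduler
assert scheduler_cfg["interval"] == "epoch" and scheduler_cfg["frequency"] == 1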