def _auto_step_lr_scheduler_per_batch(self, batch_idx: int, lr_scheduler: LRScheduler) -> None:
    """Step ``lr_scheduler`` automatically according to its step mode.

    Intended to be invoked once per training batch.

    Args:
        batch_idx: Index of the batch that just finished.
        lr_scheduler: Wrapped scheduler whose ``_step_mode`` selects the policy.

    Behavior by step mode:
        * ``STEP_EVERY_BATCH``: ``step()`` is called on every invocation.
        * ``STEP_EVERY_EPOCH``: ``step()`` is called when ``batch_idx + 1`` is exactly
          on an epoch boundary of ``self.training_loader``, or fewer than
          ``self.hvd_config.aggregation_frequency`` batches past one.
        * any other mode: nothing happens.
    """
    mode = lr_scheduler._step_mode

    if mode == LRScheduler.StepMode.STEP_EVERY_BATCH:
        lr_scheduler.step()
        return

    if mode != LRScheduler.StepMode.STEP_EVERY_EPOCH:
        return

    # How many batches past the most recent epoch boundary we are (0 == exactly on it).
    # NOTE(review): accepting any offset below aggregation_frequency presumably
    # assumes the caller only invokes this on batches where the optimizer actually
    # steps; if it is called on every batch with aggregation_frequency > 1, the
    # scheduler would step several times per epoch. Confirm against the caller.
    batches_past_boundary = (batch_idx + 1) % len(self.training_loader)
    if (
        batches_past_boundary == 0
        or batches_past_boundary < self.hvd_config.aggregation_frequency
    ):
        lr_scheduler.step()
def _auto_step_lr_scheduler_per_batch(
    self, batch_idx: int, lr_scheduler: pytorch.LRScheduler
) -> None:
    """
    Automatically step ``lr_scheduler`` according to its step mode. Call once per batch.

    The scheduler is only ever stepped on batches where the optimizer is stepped
    (``self.context._should_communicate_and_update()`` is true). Any scheduler steps
    owed for the batches aggregated since the previous optimizer step are applied then.

    Args:
        batch_idx: Index of the batch that just finished.
        lr_scheduler: Wrapped scheduler. Its ``_step_mode`` picks the policy and its
            ``_frequency`` is the number of batches / optimizer steps / epochs between
            scheduler steps.

    Step modes:
        * ``STEP_EVERY_BATCH``: one ``step()`` for each index ``i`` among the last
          ``_aggregation_frequency`` batches (ending at ``batch_idx``) whose
          ``i + 1`` is a multiple of ``_frequency``.
        * ``STEP_EVERY_OPTIMIZER_STEP``: one ``step()`` on every ``_frequency``-th
          optimizer step.
        * ``STEP_EVERY_EPOCH``: one ``step()`` for each epoch index ``e`` from the
          current batch's epoch up to (excluding) the epoch of the next optimizer-step
          batch (``batch_idx + _aggregation_frequency``), when ``e + 1`` is a multiple
          of ``_frequency``. Epoch indices come from ``self.get_epoch_idx``.
        * any other mode: nothing happens.
    """
    # Never step lr when we do not step optimizer.
    if not self.context._should_communicate_and_update():
        return

    aggregation_frequency = self.context._aggregation_frequency
    step_mode = lr_scheduler._step_mode

    if step_mode == pytorch.LRScheduler.StepMode.STEP_EVERY_BATCH:
        # Account for every batch aggregated into this optimizer step.
        start_idx = batch_idx - aggregation_frequency + 1
        for i in range(start_idx, batch_idx + 1):
            if (i + 1) % lr_scheduler._frequency == 0:
                lr_scheduler.step()
    elif step_mode == pytorch.LRScheduler.StepMode.STEP_EVERY_OPTIMIZER_STEP:
        # `_frequency` counts optimizer steps, not batches. As the other branches here
        # assume, optimizer steps land on every `aggregation_frequency`-th batch, so
        # convert the batch count into an optimizer-step count first. Testing
        # `(batch_idx + 1) % _frequency` directly stepped too often whenever the two
        # frequencies shared a factor (e.g. both 2 -> every optimizer step).
        # With either frequency equal to 1 this matches the old behavior exactly.
        optimizer_steps_taken = (batch_idx + 1) // aggregation_frequency
        if optimizer_steps_taken % lr_scheduler._frequency == 0:
            lr_scheduler.step()
    elif step_mode == pytorch.LRScheduler.StepMode.STEP_EVERY_EPOCH:
        # We will step if the next optimizer step will land in the next epoch.
        epoch_idx = self.get_epoch_idx(batch_idx)
        next_steppable_batch = batch_idx + aggregation_frequency
        next_batch_epoch_idx = self.get_epoch_idx(next_steppable_batch)
        for e in range(epoch_idx, next_batch_epoch_idx):
            if (e + 1) % lr_scheduler._frequency == 0:
                lr_scheduler.step()