Example no. 1
    def _auto_step_lr_scheduler_per_batch(self, batch_idx: int, lr_scheduler: LRScheduler) -> None:
        """
        Automatically step an LR scheduler. This should be called once per batch.
        """
        if lr_scheduler._step_mode == LRScheduler.StepMode.STEP_EVERY_BATCH:
            lr_scheduler.step()
        elif lr_scheduler._step_mode == LRScheduler.StepMode.STEP_EVERY_EPOCH:
            # Step once per epoch. Because the optimizer only steps every
            # `aggregation_frequency` batches, the exact epoch-boundary batch may
            # not be a steppable one, so also step when the current batch falls
            # within `aggregation_frequency` batches of the boundary.
            mod = (batch_idx + 1) % len(self.training_loader)
            if mod == 0 or mod < self.hvd_config.aggregation_frequency:
                lr_scheduler.step()
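
The revised implementation below additionally skips scheduler steps on batches where the optimizer does not step (gradient aggregation), honors each scheduler's `_frequency`, and supports the new STEP_EVERY_OPTIMIZER_STEP mode:
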
    def _auto_step_lr_scheduler_per_batch(
            self, batch_idx: int, lr_scheduler: pytorch.LRScheduler) -> None:
        """
        This function aims at automatically step a LR scheduler. It should be called per batch.
        """

        # Never step the LR scheduler when we do not step the optimizer.
        if not self.context._should_communicate_and_update():
            return

        if lr_scheduler._step_mode == pytorch.LRScheduler.StepMode.STEP_EVERY_BATCH:
            # The optimizer only steps on the last batch of each aggregation
            # window, so replay every batch in the window that just completed
            # and step once for each index that lands on the configured frequency.
            start_idx = batch_idx - self.context._aggregation_frequency + 1
            for i in range(start_idx, batch_idx + 1):
                if (i + 1) % lr_scheduler._frequency == 0:
                    lr_scheduler.step()
        elif lr_scheduler._step_mode == pytorch.LRScheduler.StepMode.STEP_EVERY_OPTIMIZER_STEP:
            # This point is only reached on batches where the optimizer steps,
            # so step whenever the completed batch count hits the frequency.
            if (batch_idx + 1) % lr_scheduler._frequency == 0:
                lr_scheduler.step()
        elif lr_scheduler._step_mode == pytorch.LRScheduler.StepMode.STEP_EVERY_EPOCH:
            # Step now if the next optimizer step will land in a later epoch.
            epoch_idx = self.get_epoch_idx(batch_idx)
            next_steppable_batch = batch_idx + self.context._aggregation_frequency
            next_batch_epoch_idx = self.get_epoch_idx(next_steppable_batch)
            for e in range(epoch_idx, next_batch_epoch_idx):
                if (e + 1) % lr_scheduler._frequency == 0:
                    lr_scheduler.step()
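
For reference, here is a minimal, self-contained sketch of the same dispatch that can be run standalone. The StepMode values and the frequency/aggregation parameters mirror the method above, but CountingScheduler, run, and the `batch_idx // epoch_len` definition of get_epoch_idx are simplified assumptions for illustration, not the surrounding framework's real API.

import enum


class StepMode(enum.Enum):
    STEP_EVERY_EPOCH = 1
    STEP_EVERY_BATCH = 2
    STEP_EVERY_OPTIMIZER_STEP = 3


class CountingScheduler:
    # Stands in for a real torch.optim scheduler; it only counts step() calls.
    def __init__(self) -> None:
        self.steps = 0

    def step(self) -> None:
        self.steps += 1


def run(step_mode: StepMode, frequency: int, epoch_len: int,
        num_batches: int, aggregation_frequency: int) -> int:
    sched = CountingScheduler()
    for batch_idx in range(num_batches):
        # The optimizer (and hence this dispatch) only runs on the last batch
        # of each aggregation window, mirroring the early return above.
        if (batch_idx + 1) % aggregation_frequency != 0:
            continue
        if step_mode is StepMode.STEP_EVERY_BATCH:
            # Replay every batch in the window that just completed.
            start_idx = batch_idx - aggregation_frequency + 1
            for i in range(start_idx, batch_idx + 1):
                if (i + 1) % frequency == 0:
                    sched.step()
        elif step_mode is StepMode.STEP_EVERY_EPOCH:
            # get_epoch_idx is assumed to be batch_idx // epoch_len.
            epoch_idx = batch_idx // epoch_len
            next_epoch_idx = (batch_idx + aggregation_frequency) // epoch_len
            for e in range(epoch_idx, next_epoch_idx):
                if (e + 1) % frequency == 0:
                    sched.step()
    return sched.steps


# 20 batches, epochs of 5 batches, optimizer steps every 4 batches:
assert run(StepMode.STEP_EVERY_BATCH, 1, 5, 20, 4) == 20  # once per batch
assert run(StepMode.STEP_EVERY_EPOCH, 1, 5, 20, 4) == 4   # once per epoch

The asserts show the intended invariant: with frequency 1, STEP_EVERY_BATCH steps once per batch regardless of aggregation, while STEP_EVERY_EPOCH steps exactly once per completed epoch.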