import copy

from torch.optim.lr_scheduler import OneCycleLR

# Note: create_lr_scheduler below refers to the project's scheduler-building
# helper (assumed to be imported elsewhere in this module).


def setup_experiment(self, config):
    """
    :param config:
        - multi_cycle_lr_args: A list of (start_epoch, dict) pairs. The
          dicts hold OneCycleLR arguments and don't need to include epoch
          counts; each cycle's length is inferred from the next cycle's
          start epoch and the config's total "epochs". (See the example
          sketch following pre_epoch below.)
    """
    config = copy.deepcopy(config)
    ignored_class = config.pop("lr_scheduler_class", None)
    ignored_args = config.pop("lr_scheduler_args", None)
    config["lr_scheduler_step_every_batch"] = True

    super().setup_experiment(config)

    if ignored_class is not None and ignored_class != OneCycleLR:
        self.logger.warning("Ignoring lr_scheduler_class, using OneCycleLR")

    if ignored_args is not None and len(ignored_args) > 0:
        self.logger.warning("Ignoring lr_scheduler_args, using "
                            "multi_cycle_lr_args")

    # Insert epoch counts and div_factors
    improved_args = {}
    multi_cycle_lr_args = sorted(config["multi_cycle_lr_args"],
                                 key=lambda x: x[0])
    for i, (start_epoch, cycle_config) in enumerate(multi_cycle_lr_args):
        if i + 1 < len(multi_cycle_lr_args):
            end_epoch = multi_cycle_lr_args[i + 1][0]
        else:
            end_epoch = config["epochs"]

        cycle_config = copy.deepcopy(cycle_config)
        cycle_config["epochs"] = end_epoch - start_epoch

        # Default behavior: no sudden change in learning rate between
        # cycles. Carry the previous cycle's final_div_factor over as this
        # cycle's div_factor unless one was given explicitly.
        if "div_factor" not in cycle_config and i > 0:
            prev_cycle_config = multi_cycle_lr_args[i - 1][1]
            if "final_div_factor" in prev_cycle_config:
                cycle_config["div_factor"] = \
                    prev_cycle_config["final_div_factor"]

        improved_args[start_epoch] = cycle_config

    self.multi_cycle_args_by_epoch = improved_args
    self.logger.info("MultiCycleLR regime: "
                     f"{self.multi_cycle_args_by_epoch}")

    # Set it immediately, rather than waiting for the pre_epoch, in case a
    # restore is occurring.
    args = self.multi_cycle_args_by_epoch[0]
    self.lr_scheduler = create_lr_scheduler(
        optimizer=self.optimizer,
        lr_scheduler_class=OneCycleLR,
        lr_scheduler_args=args,
        steps_per_epoch=self.total_batches)
def pre_epoch(self):
    super().pre_epoch()
    # At each configured cycle boundary (other than epoch 0, which was
    # handled in setup_experiment), start a fresh OneCycleLR cycle.
    if self.current_epoch != 0 and \
            self.current_epoch in self.multi_cycle_args_by_epoch:
        args = self.multi_cycle_args_by_epoch[self.current_epoch]
        self.lr_scheduler = create_lr_scheduler(
            optimizer=self.optimizer,
            lr_scheduler_class=OneCycleLR,
            lr_scheduler_args=args,
            steps_per_epoch=self.total_batches)
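# Hypothetical sketch (not referenced by the methods above): one way a
# multi-cycle config could look for a 60-epoch run. The cycle start epochs
# and OneCycleLR values are made-up placeholders. Each cycle's "epochs"
# entry is filled in by setup_experiment, and cycles after the first inherit
# the previous cycle's final_div_factor as their div_factor unless one is
# given explicitly.
def _example_multi_cycle_config():
    return dict(
        epochs=60,
        multi_cycle_lr_args=[
            # (start_epoch, OneCycleLR args without "epochs")
            (0, dict(max_lr=6.0, div_factor=25, final_div_factor=1000)),
            (30, dict(max_lr=6.0, final_div_factor=1000)),
            (50, dict(max_lr=3.0, final_div_factor=2500)),
        ],
        # Any lr_scheduler_class / lr_scheduler_args entries would be
        # ignored by this mixin.
    )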
@classmethod
def create_lr_scheduler(cls, config, optimizer, total_batches):
    """
    Create the lr scheduler from the experiment config.

    :param config:
        - lr_scheduler_class: (optional) class of the lr scheduler
        - lr_scheduler_args: (optional) dict of args to pass to the
          lr scheduler class
    :param optimizer: torch optimizer
    :param total_batches: number of batches/steps in an epoch
    """
    lr_scheduler_class = config.get("lr_scheduler_class", None)
    if lr_scheduler_class is not None:
        lr_scheduler_args = config.get("lr_scheduler_args", {})
        # Delegate to the module-level create_lr_scheduler helper (the
        # method name does not shadow it inside this body).
        return create_lr_scheduler(optimizer=optimizer,
                                   lr_scheduler_class=lr_scheduler_class,
                                   lr_scheduler_args=lr_scheduler_args,
                                   steps_per_epoch=total_batches)
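# Hypothetical sketch (not part of the method above): a config fragment of
# the shape that method expects. StepLR and its argument values are
# placeholder choices, not defaults from the codebase.
def _example_lr_scheduler_config():
    from torch.optim.lr_scheduler import StepLR
    return dict(
        lr_scheduler_class=StepLR,
        # Forwarded to the scheduler class by the create_lr_scheduler helper.
        lr_scheduler_args=dict(step_size=30, gamma=0.1),
    )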