Example 1
    def setup_experiment(self, config):
        """
        :param config:
            - multi_cycle_lr_args: A list of (epoch, dict) pairs. The dicts
                                   don't need to include epoch counts; these
                                   are inferred from the config.
        """
        config = copy.deepcopy(config)

        # This class manages its own per-cycle OneCycleLR schedule, so any
        # scheduler configured elsewhere is set aside (with a warning below)
        # and the scheduler is stepped every batch.
        ignored_class = config.pop("lr_scheduler_class", None)
        ignored_args = config.pop("lr_scheduler_args", None)

        config["lr_scheduler_step_every_batch"] = True

        super().setup_experiment(config)

        if ignored_class is not None and ignored_class != OneCycleLR:
            self.logger.warning("Ignoring lr_scheduler_class, using OneCycleLR")
        if ignored_args is not None and len(ignored_args) > 0:
            self.logger.warning("Ignoring lr_scheduler_args, using "
                                "multi_cycle_lr_args")

        # Insert epoch counts and div_factors
        improved_args = {}
        multi_cycle_lr_args = sorted(config["multi_cycle_lr_args"],
                                     key=lambda x: x[0])
        for i, (start_epoch, cycle_config) in enumerate(multi_cycle_lr_args):
            if i + 1 < len(multi_cycle_lr_args):
                end_epoch = multi_cycle_lr_args[i + 1][0]
            else:
                end_epoch = config["epochs"]

            cycle_config = copy.deepcopy(cycle_config)
            cycle_config["epochs"] = end_epoch - start_epoch

            # Default behavior: no sudden change in learning rate between
            # cycles.
            if "div_factor" not in cycle_config and i > 0:
                prev_cycle_config = multi_cycle_lr_args[i - 1][1]
                if "final_div_factor" in prev_cycle_config:
                    cycle_config["div_factor"] = \
                        prev_cycle_config["final_div_factor"]

            improved_args[start_epoch] = cycle_config

        self.multi_cycle_args_by_epoch = improved_args

        self.logger.info("MultiCycleLR regime: "
                         f"{self.multi_cycle_args_by_epoch}")

        # Set it immediately, rather than waiting for the pre_epoch, in case a
        # restore is occurring.
        args = self.multi_cycle_args_by_epoch[0]
        self.lr_scheduler = create_lr_scheduler(
            optimizer=self.optimizer,
            lr_scheduler_class=OneCycleLR,
            lr_scheduler_args=args,
            steps_per_epoch=self.total_batches)
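
For context, here is a minimal sketch of what the multi_cycle_lr_args entries described in the docstring above might look like inside an experiment config. The epoch boundaries and OneCycleLR keyword arguments are illustrative assumptions, not values taken from any of these examples.

    config = dict(
        epochs=60,
        # Each entry is (start_epoch, OneCycleLR kwargs). Per-cycle epoch
        # counts are omitted; setup_experiment infers them from the next
        # cycle's start epoch (or from config["epochs"] for the last cycle).
        multi_cycle_lr_args=[
            (0, dict(max_lr=1.0, pct_start=0.2, final_div_factor=1000.0)),
            (30, dict(max_lr=0.1, pct_start=0.1, final_div_factor=1000.0)),
            (45, dict(max_lr=0.01, pct_start=0.1)),
        ],
    )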
Example 2
    def pre_epoch(self):
        super().pre_epoch()

        # The cycle that starts at epoch 0 was already installed in
        # setup_experiment, so only swap in a new OneCycleLR at the start
        # of later cycles.
        if self.current_epoch != 0 and \
           self.current_epoch in self.multi_cycle_args_by_epoch:

            args = self.multi_cycle_args_by_epoch[self.current_epoch]
            self.lr_scheduler = create_lr_scheduler(
                optimizer=self.optimizer,
                lr_scheduler_class=OneCycleLR,
                lr_scheduler_args=args,
                steps_per_epoch=self.total_batches)
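
Continuing the hypothetical config from the sketch above, setup_experiment would leave multi_cycle_args_by_epoch keyed by cycle start epoch roughly as follows, with epoch counts inferred and each later cycle's div_factor carried over from the previous cycle's final_div_factor. pre_epoch then swaps in a fresh OneCycleLR when current_epoch reaches 30 or 45, while the epoch-0 cycle is handled in setup_experiment.

    multi_cycle_args_by_epoch = {
        0: dict(max_lr=1.0, pct_start=0.2, final_div_factor=1000.0,
                epochs=30),
        30: dict(max_lr=0.1, pct_start=0.1, final_div_factor=1000.0,
                 epochs=15, div_factor=1000.0),
        45: dict(max_lr=0.01, pct_start=0.1, epochs=15, div_factor=1000.0),
    }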
Example 3
    @classmethod
    def create_lr_scheduler(cls, config, optimizer, total_batches):
        """
        Create the lr scheduler from the experiment config.

        :param config:
            - lr_scheduler_class: (optional) class of the lr scheduler
            - lr_scheduler_args: (optional) dict of args to pass to the class
        :param optimizer: torch optimizer
        :param total_batches: number of batches/steps in an epoch
        :return: the lr scheduler, or None if no lr_scheduler_class is given
        """
        lr_scheduler_class = config.get("lr_scheduler_class", None)
        if lr_scheduler_class is not None:
            lr_scheduler_args = config.get("lr_scheduler_args", {})
            # Calls the module-level create_lr_scheduler helper (not this
            # classmethod).
            return create_lr_scheduler(optimizer=optimizer,
                                       lr_scheduler_class=lr_scheduler_class,
                                       lr_scheduler_args=lr_scheduler_args,
                                       steps_per_epoch=total_batches)
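
A usage sketch for the classmethod above, assuming it is defined on an experiment class (MyExperiment is a placeholder name, not part of these examples) and that the config follows the same OneCycleLR conventions as the examples above:

    import torch
    from torch.optim.lr_scheduler import OneCycleLR

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    config = dict(
        lr_scheduler_class=OneCycleLR,
        lr_scheduler_args=dict(max_lr=0.1, epochs=30, pct_start=0.2),
    )

    # MyExperiment stands in for whichever class defines the classmethod
    # above; total_batches is the number of batches per training epoch.
    lr_scheduler = MyExperiment.create_lr_scheduler(
        config=config, optimizer=optimizer, total_batches=100)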