def _before_task(self, train_loader, val_loader):
        # Grow the classifier head to account for the new task's classes.
        self._n_classes += self._task_size
        self._network.add_classes(self._task_size)
        logger.info("Now {} exemplars per class.".format(
            self._memory_per_class))

        # A fresh optimizer and multi-step decay schedule are built for every task.
        self._optimizer = factory.get_optimizer(self._network.parameters(),
                                                self._opt_name, self._lr,
                                                self._weight_decay)

        base_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self._optimizer, self._scheduling, gamma=self._lr_decay)

        if self._warmup_config:
            if self._warmup_config.get("only_first_step",
                                       True) and self._task != 0:
                # Warmup is restricted to the first task; later tasks must
                # still pick up the fresh multi-step schedule, otherwise the
                # stale scheduler from the previous task would be kept.
                self._scheduler = base_scheduler
            else:
                logger.info("Using WarmUp")
                self._scheduler = schedulers.GradualWarmupScheduler(
                    optimizer=self._optimizer,
                    after_scheduler=base_scheduler,
                    **self._warmup_config)
        else:
            self._scheduler = base_scheduler
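The same warmup-then-decay pattern can be sketched with only built-in PyTorch schedulers. The snippet below is a standalone illustration, not the project's GradualWarmupScheduler: the model, milestones, and epoch counts are made up, and it assumes PyTorch >= 1.11, which provides LinearLR and SequentialLR.

import torch

model = torch.nn.Linear(8, 2)                    # dummy model for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Linearly ramp the LR up for 5 epochs, then hand over to a multi-step decay.
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.1, total_iters=5)
decay = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[30, 60], gamma=0.1)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, decay], milestones=[5])

for epoch in range(80):
    # train_one_epoch(model, optimizer, ...)
    scheduler.step()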
Example #2
import warnings

import torch

# `schedulers` below refers to the project's local module providing
# CosineWithRestarts and GradualWarmupScheduler.


def get_lr_scheduler(scheduling_config,
                     optimizer,
                     nb_epochs,
                     lr_decay=0.1,
                     warmup_config=None,
                     task=0):
    if scheduling_config is None:
        return None
    elif isinstance(scheduling_config, str):
        warnings.warn("Use a dict not a string for scheduling config!",
                      DeprecationWarning)
        scheduling_config = {"type": scheduling_config}
    elif isinstance(scheduling_config, list):
        warnings.warn("Use a dict not a list for scheduling config!",
                      DeprecationWarning)
        scheduling_config = {"type": "step", "epochs": scheduling_config}

    if scheduling_config["type"] == "step":
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            scheduling_config["epochs"],
            gamma=scheduling_config.get("gamma") or lr_decay)
    elif scheduling_config["type"] == "exponential":
        scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, scheduling_config["gamma"])
    elif scheduling_config["type"] == "plateau":
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=scheduling_config["gamma"])
    elif scheduling_config["type"] == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, nb_epochs)
    elif scheduling_config["type"] == "cosine_with_restart":
        scheduler = schedulers.CosineWithRestarts(
            optimizer,
            t_max=scheduling_config.get("cycle_len", nb_epochs),
            factor=scheduling_config.get("factor", 1.))
    else:
        raise ValueError("Unknown LR scheduling type {}.".format(
            scheduling_config["type"]))

    if warmup_config:
        if warmup_config.get("only_first_step", True) and task != 0:
            # Warmup only applies to the first task; keep the base scheduler.
            pass
        else:
            print("Using WarmUp")
            scheduler = schedulers.GradualWarmupScheduler(
                optimizer=optimizer,
                after_scheduler=scheduler,
                **warmup_config)

    return scheduler
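
A minimal usage sketch, assuming only the function above plus a throwaway model and optimizer invented here for illustration: it requests a step schedule that decays the learning rate at epochs 30 and 60, with no warmup.

import torch

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

scheduler = get_lr_scheduler(
    {"type": "step", "epochs": [30, 60], "gamma": 0.1},
    optimizer,
    nb_epochs=90,
)

for epoch in range(90):
    # train_one_epoch(model, optimizer, ...)
    scheduler.step()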