def maybe_update_lr(self, epoch=None):
        """
        Set the learning rate of the first optimizer param group and log it.

        For ep < 1000 the rate is poly_lr(ep, 1000, self.initial_lr, 0.9);
        from ep >= 1000 onward it is cycle_lr(ep, 200, min_lr=1e-6, max_lr=1e-2).
        (The exact curve shapes are defined by poly_lr / cycle_lr, which live
        elsewhere.)

        :param epoch: epoch to compute the lr for. If None, self.epoch + 1 is used.
        :return: None
        """
        ep = self.epoch + 1 if epoch is None else epoch

        # Compute the rate exactly once so the logged value is the applied value.
        if ep < 1000:
            new_lr = poly_lr(ep, 1000, self.initial_lr, 0.9)
        else:
            # we don't go all the way back up to initial lr
            new_lr = cycle_lr(ep, 200, min_lr=1e-6, max_lr=1e-2)

        self.optimizer.param_groups[0]['lr'] = new_lr
        self.print_to_log_file("lr:", new_lr)
示例#2
0
    def maybe_update_lr(self, epoch=None):
        """
        Apply poly_lr to the first param group of every optimizer in self.opt_loss.

        Each entry of self.opt_loss is unpacked as a two-element pair; only the
        first element (the optimizer) is used. The optimizer at position i gets
        poly_lr(ep, self.max_num_epochs, self.initial_lrs[i], 0.9), and the value
        it now holds is logged rounded to 6 decimals.

        :param epoch: epoch to compute the lr for; defaults to self.epoch + 1.
        :return: None
        """
        ep = self.epoch + 1 if epoch is None else epoch

        for opt_index, (optimizer, _unused) in enumerate(self.opt_loss):
            new_lr = poly_lr(ep, self.max_num_epochs, self.initial_lrs[opt_index], 0.9)
            lr_group = optimizer.param_groups[0]
            lr_group['lr'] = new_lr
            self.print_to_log_file("lr:", np.round(lr_group['lr'], decimals=6))
示例#3
0
    def maybe_update_lr(self, epoch=None):
        """
        Set the first optimizer param group's lr from poly_lr and log it.

        An explicit ``epoch`` is used as-is. Otherwise ``self.epoch + 1`` is used,
        because maybe_update_lr is called in on_epoch_end, which runs before
        self.epoch is incremented.

        :param epoch: optional epoch override
        :return: None
        """
        current_epoch = epoch if epoch is not None else self.epoch + 1
        new_lr = poly_lr(current_epoch, self.max_num_epochs, self.initial_lr, 0.9)
        lr_group = self.optimizer.param_groups[0]
        lr_group['lr'] = new_lr
        self.print_to_log_file("lr:", np.round(lr_group['lr'], decimals=6))
示例#4
0
    def maybe_update_lr(self, epoch=None):
        """
        Set the first optimizer param group's lr with a step schedule followed
        by poly_lr, then log it.

        With ep = epoch (or self.epoch + 1 when epoch is None):
          * 0 <= ep < 500:   self.initial_lr
          * 500 <= ep < 675: self.initial_lr * 0.1
          * ep >= 675:       poly_lr(ep - 675, self.max_num_epochs - 675,
                                     self.initial_lr * 0.1, 0.9)

        :param epoch: optional epoch override; defaults to self.epoch + 1
        :return: None
        :raises RuntimeError: if ep is negative or NaN
        """
        ep = self.epoch + 1 if epoch is None else epoch

        # `not ep >= 0` (instead of `ep < 0`) also rejects NaN. Use %s rather than
        # %d: %d truncates float epochs and raises ValueError on NaN, which would
        # mask the intended RuntimeError.
        if not ep >= 0:
            raise RuntimeError("Really unexpected things happened, ep=%s" % ep)

        if ep < 500:
            new_lr = self.initial_lr
        elif ep < 675:
            new_lr = self.initial_lr * 0.1
        else:
            new_lr = poly_lr(ep - 675, self.max_num_epochs - 675,
                             self.initial_lr * 0.1, 0.9)

        self.optimizer.param_groups[0]['lr'] = new_lr
        self.print_to_log_file("lr:", self.optimizer.param_groups[0]['lr'])