def test_sched_config_parse_from_cls(self):
    """Test that we can parse a scheduler from a class"""
    model = TempModel()
    opt_cls = get_optimizer("novograd")
    opt = opt_cls(model.parameters(), lr=self.INITIAL_LR)

    basic_sched_config = {
        "_target_": "mridc.core.conf.schedulers.CosineAnnealingParams",
        "params": {"min_lr": 0.1},
        "max_steps": self.MAX_STEPS,
    }
    scheduler_setup = optim.lr_scheduler.prepare_lr_scheduler(opt, basic_sched_config)
    if not isinstance(scheduler_setup["scheduler"], optim.lr_scheduler.CosineAnnealing):
        raise AssertionError

    # The same config should also parse when wrapped in an OmegaConf DictConfig.
    dict_config = omegaconf.OmegaConf.create(basic_sched_config)
    scheduler_setup = optim.lr_scheduler.prepare_lr_scheduler(opt, dict_config)
    if not isinstance(scheduler_setup["scheduler"], optim.lr_scheduler.CosineAnnealing):
        raise AssertionError

def test_SquareAnnealing(self):
    """Test SquareAnnealing"""
    model = TempModel()
    opt_cls = get_optimizer("novograd")
    opt = opt_cls(model.parameters(), lr=self.INITIAL_LR)

    # No warmup case
    policy = optim.lr_scheduler.SquareAnnealing(opt, max_steps=self.MAX_STEPS, min_lr=self.MIN_LR)
    initial_lr = policy.get_last_lr()[0]
    if initial_lr != self.INITIAL_LR:
        raise AssertionError

    for _ in range(self.MAX_STEPS):
        if policy.get_last_lr()[0] > self.INITIAL_LR:
            raise AssertionError
        opt.step()
        policy.step()

    policy.step()
    final_lr = policy.get_last_lr()[0]
    if final_lr != self.MIN_LR:
        raise AssertionError

    # Warmup steps available
    policy = optim.lr_scheduler.SquareAnnealing(opt, warmup_steps=5, max_steps=self.MAX_STEPS, min_lr=self.MIN_LR)
    initial_lr = policy.get_last_lr()[0]
    if initial_lr >= self.INITIAL_LR:
        raise AssertionError

    for i in range(self.MAX_STEPS):
        if i <= 5:
            if policy.get_last_lr()[0] > self.INITIAL_LR:
                raise AssertionError
        elif policy.get_last_lr()[0] >= self.INITIAL_LR:
            raise AssertionError
        opt.step()
        policy.step()

    policy.step()
    final_lr = policy.get_last_lr()[0]
    if final_lr != self.MIN_LR:
        raise AssertionError

def test_register_optimizer(self):
    """Test that we can register a new optimizer"""

    class TempOpt(torch.optim.SGD):
        """A dummy optimizer"""

    class TempOptParams(optimizers.SGDParams):
        """A dummy optimizer params"""

    register_optimizer("TempOpt", TempOpt, TempOptParams)

    model = TempModel()
    opt_cls = get_optimizer("TempOpt")
    opt = opt_cls(model.parameters(), lr=self.INITIAL_LR)
    if not isinstance(opt, TempOpt):
        raise AssertionError

def test_get_optimizer(self):
    """Test that the optimizer is correctly created"""
    model = TempModel()

    for opt_name in AVAILABLE_OPTIMIZERS:
        if opt_name == "fused_adam" and not torch.cuda.is_available():
            # fused_adam requires CUDA kernels, so skip it on CPU-only runs.
            continue
        opt_cls = get_optimizer(opt_name)
        if opt_name == "adafactor":
            # Adafactor's default mode uses relative_step without any lr.
            opt = opt_cls(model.parameters())
        else:
            opt = opt_cls(model.parameters(), lr=self.INITIAL_LR)
        if not isinstance(opt, AVAILABLE_OPTIMIZERS[opt_name]):
            raise AssertionError

def test_sched_config_parse_simple(self):
    """Test that scheduler config is parsed correctly"""
    model = TempModel()
    opt_cls = get_optimizer("novograd")
    opt = opt_cls(model.parameters(), lr=self.INITIAL_LR)

    basic_sched_config = {"name": "CosineAnnealing", "max_steps": 10}
    scheduler_setup = optim.lr_scheduler.prepare_lr_scheduler(opt, basic_sched_config)
    if not isinstance(scheduler_setup["scheduler"], optim.lr_scheduler.CosineAnnealing):
        raise AssertionError

    dict_config = omegaconf.OmegaConf.create(basic_sched_config)
    scheduler_setup = optim.lr_scheduler.prepare_lr_scheduler(opt, dict_config)
    if not isinstance(scheduler_setup["scheduler"], optim.lr_scheduler.CosineAnnealing):
        raise AssertionError

def test_register_scheduler(self):
    """Test registering a new scheduler"""

    class TempSched(optim.lr_scheduler.CosineAnnealing):
        """Temporary scheduler class."""

    class TempSchedParams(CosineAnnealingParams):
        """Temporary scheduler params class."""

    optim.lr_scheduler.register_scheduler("TempSched", TempSched, TempSchedParams)

    model = TempModel()
    opt_cls = get_optimizer("novograd")
    opt = opt_cls(model.parameters(), lr=self.INITIAL_LR)
    sched_cls = optim.lr_scheduler.get_scheduler("TempSched")
    sched = sched_cls(opt, max_steps=self.MAX_STEPS)
    if not isinstance(sched, TempSched):
        raise AssertionError

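# A scheduler registered this way should also be resolvable by name through the
# config path exercised in test_sched_config_parse_simple. A minimal sketch,
# assuming the "TempSched" registration above and the same prepare_lr_scheduler()
# contract shown in that test:
#
#   sched_config = {"name": "TempSched", "max_steps": 10}
#   scheduler_setup = optim.lr_scheduler.prepare_lr_scheduler(opt, sched_config)
#   assert isinstance(scheduler_setup["scheduler"], TempSched)
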
def test_CosineAnnealing_with_noop_steps(self):
    """Test CosineAnnealing with noop steps."""
    model = TempModel()
    opt_cls = get_optimizer("novograd")
    opt = opt_cls(model.parameters(), lr=self.INITIAL_LR)

    # No warmup case
    policy = optim.lr_scheduler.CosineAnnealing(opt, max_steps=self.MAX_STEPS, min_lr=self.MIN_LR)
    initial_lr = policy.get_last_lr()[0]
    if initial_lr != self.INITIAL_LR:
        raise AssertionError

    update_steps = 0
    for i in range(self.MAX_STEPS):
        if policy.get_last_lr()[0] > self.INITIAL_LR:
            raise AssertionError
        opt.step()
        policy.step()

        # Perform a no-op for the scheduler every 2 steps by rolling back
        # last_epoch, so only half the iterations count as real updates.
        if i % 2 == 0:
            policy.last_epoch -= 1
        else:
            update_steps += 1

    policy.step()
    update_steps += 1

    if update_steps >= self.MAX_STEPS:
        raise AssertionError

    final_lr = policy.get_last_lr()[0]
    if final_lr <= self.MIN_LR:
        raise AssertionError

    # update_steps is the true number of updates performed after the skipped
    # steps, so the final LR must equal the LR the policy computes at that step.
    true_end_lr = policy._get_lr(step=update_steps)[0]
    if final_lr != true_end_lr:
        raise AssertionError

def test_PolynomialHoldDecayAnnealing(self):
    """Test PolynomialHoldDecayAnnealing"""
    model = TempModel()
    opt_cls = get_optimizer("novograd")
    opt = opt_cls(model.parameters(), lr=self.INITIAL_LR)

    # No warmup case
    policy = optim.lr_scheduler.PolynomialHoldDecayAnnealing(
        opt, power=2, max_steps=self.MAX_STEPS, min_lr=self.MIN_LR
    )
    initial_lr = policy.get_last_lr()[0]
    if initial_lr != self.INITIAL_LR:
        raise AssertionError

    for _ in range(self.MAX_STEPS):
        if policy.get_last_lr()[0] > self.INITIAL_LR:
            raise AssertionError
        opt.step()
        policy.step()

    policy.step()
    final_lr = policy.get_last_lr()[0]
    if final_lr <= self.MIN_LR:
        raise AssertionError

    # Warmup steps available
    policy = optim.lr_scheduler.PolynomialHoldDecayAnnealing(
        opt, power=2, warmup_steps=5, max_steps=self.MAX_STEPS, min_lr=self.MIN_LR
    )
    initial_lr = policy.get_last_lr()[0]
    if initial_lr >= self.INITIAL_LR:
        raise AssertionError

    for _ in range(self.MAX_STEPS):
        if policy.get_last_lr()[0] > self.INITIAL_LR:
            raise AssertionError
        opt.step()
        policy.step()

    policy.step()
    final_lr = policy.get_last_lr()[0]
    if final_lr < self.MIN_LR:
        raise AssertionError

    # Warmup + Hold steps available
    policy = optim.lr_scheduler.PolynomialHoldDecayAnnealing(
        opt, warmup_steps=5, hold_steps=3, max_steps=self.MAX_STEPS, min_lr=self.MIN_LR, power=2
    )
    initial_lr = policy.get_last_lr()[0]
    if initial_lr >= self.INITIAL_LR:
        raise AssertionError

    for i in range(self.MAX_STEPS):
        if i <= 4:
            if policy.get_last_lr()[0] > self.INITIAL_LR:
                raise AssertionError
        elif i <= 8:
            if policy.get_last_lr()[0] < self.INITIAL_LR:
                raise AssertionError
        elif policy.get_last_lr()[0] > self.INITIAL_LR:
            raise AssertionError
        opt.step()
        policy.step()

    policy.step()
    final_lr = policy.get_last_lr()[0]
    if final_lr < self.MIN_LR:
        raise AssertionError

def test_WarmupAnnealing(self):
    """Test that the warmup annealing policy works as expected."""
    model = TempModel()
    opt_cls = get_optimizer("novograd")
    opt = opt_cls(model.parameters(), lr=self.INITIAL_LR)

    # No warmup case
    policy = optim.lr_scheduler.WarmupAnnealing(opt, max_steps=self.MAX_STEPS, min_lr=self.MIN_LR)
    initial_lr = policy.get_last_lr()[0]
    if initial_lr != self.INITIAL_LR:
        raise AssertionError

    for _ in range(self.MAX_STEPS):
        if policy.get_last_lr()[0] > self.INITIAL_LR:
            raise AssertionError
        opt.step()
        policy.step()

    policy.step()
    final_lr = policy.get_last_lr()[0]
    if final_lr < self.MIN_LR:
        raise AssertionError

    # Warmup steps available
    policy = optim.lr_scheduler.WarmupAnnealing(opt, warmup_steps=5, max_steps=self.MAX_STEPS, min_lr=self.MIN_LR)
    initial_lr = policy.get_last_lr()[0]
    if initial_lr >= self.INITIAL_LR:
        raise AssertionError

    for i in range(self.MAX_STEPS):
        if i <= 5:
            if policy.get_last_lr()[0] > self.INITIAL_LR:
                raise AssertionError
        elif policy.get_last_lr()[0] >= self.INITIAL_LR:
            raise AssertionError
        opt.step()
        policy.step()

    policy.step()
    final_lr = policy.get_last_lr()[0]
    if final_lr != self.MIN_LR:
        raise AssertionError

    # Warmup + Hold steps available
    policy = optim.lr_scheduler.WarmupHoldPolicy(
        opt, warmup_steps=5, hold_steps=3, max_steps=self.MAX_STEPS, min_lr=self.MIN_LR
    )
    initial_lr = policy.get_last_lr()[0]
    if initial_lr >= self.INITIAL_LR:
        raise AssertionError

    for i in range(self.MAX_STEPS):
        if i <= 4:
            if policy.get_last_lr()[0] > self.INITIAL_LR:
                raise AssertionError
        elif policy.get_last_lr()[0] != self.INITIAL_LR:
            raise AssertionError
        opt.step()
        policy.step()

    policy.step()
    final_lr = policy.get_last_lr()[0]
    if final_lr < self.MIN_LR:
        raise AssertionError
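
# ---------------------------------------------------------------------------
# Minimal harness assumed by the tests above. This excerpt references
# TempModel, the class constants INITIAL_LR / MIN_LR / MAX_STEPS, and several
# imports that live outside it. The sketch below shows one shape that would
# make the tests runnable; the import paths are assumptions inferred from the
# "_target_" string in test_sched_config_parse_from_cls, and the constant
# values and TempModel layout are illustrative placeholders, not the
# repository's actual fixture.
# ---------------------------------------------------------------------------
#
# import omegaconf
# import torch
#
# from mridc.core import optim
# from mridc.core.conf import optimizers
# from mridc.core.conf.schedulers import CosineAnnealingParams
# from mridc.core.optim.optimizers import AVAILABLE_OPTIMIZERS, get_optimizer, register_optimizer
#
#
# class TempModel(torch.nn.Module):
#     """Trivial model whose parameters the optimizers under test can consume."""
#
#     def __init__(self):
#         super().__init__()
#         self.layer = torch.nn.Linear(5, 1)
#
#     def forward(self, x):
#         return self.layer(x)
#
#
# class TestOptimizersSchedulers:
#     INITIAL_LR = 0.1
#     MIN_LR = 1e-3
#     MAX_STEPS = 10
#
#     ... (the test methods above) ...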