def test_torch_model_on_epoch_begin(self):
    """Check that ``on_epoch_begin`` writes the scheduled learning rate into a PyTorch optimizer.

    An ``LRScheduler`` is built around ``self.torch_model`` with a cosine-decay
    ``lr_fn`` (cycle_length=3750, init_lr=1e-3) and attached to a sample system
    whose ``epoch_idx`` is 3. After ``on_epoch_begin`` runs, the first parameter
    group's ``'lr'`` entry must hold the decayed value.
    """
    lr_scheduler = LRScheduler(
        model=self.torch_model,
        lr_fn=lambda epoch: fe.schedule.cosine_decay(epoch, cycle_length=3750, init_lr=1e-3),
    )
    lr_scheduler.system = sample_system_object()
    # presumably the scheduler evaluates lr_fn at system.epoch_idx — TODO confirm
    lr_scheduler.system.epoch_idx = 3
    lr_scheduler.on_epoch_begin(data=self.data)

    # param_groups is already an indexable list of dicts; no copy is needed.
    new_lr = self.torch_model.optimizer.param_groups[0]['lr']
    # Just under init_lr=1e-3: epoch 3 is at the very start of a 3750-long
    # cosine cycle, so almost no decay is expected yet.
    expected_lr = 0.0009999993
    self.assertTrue(
        math.isclose(new_lr, expected_lr, rel_tol=1e-5),
        msg=f"expected lr ~{expected_lr}, got {new_lr}",
    )
def test_tf_model_on_epoch_begin(self):
    """Check that ``on_epoch_begin`` writes the scheduled learning rate into a TensorFlow optimizer.

    An ``LRScheduler`` is built around ``self.tf_model`` with a cosine-decay
    ``lr_fn`` (cycle_length=3750, init_lr=1e-3) and attached to a sample system
    whose ``epoch_idx`` is 3. After ``on_epoch_begin`` runs, the optimizer's
    ``lr`` must hold the decayed value.
    """
    lr_scheduler = LRScheduler(
        model=self.tf_model,
        lr_fn=lambda epoch: fe.schedule.cosine_decay(epoch, cycle_length=3750, init_lr=1e-3),
    )
    lr_scheduler.system = sample_system_object()
    # presumably the scheduler evaluates lr_fn at system.epoch_idx — TODO confirm
    lr_scheduler.system.epoch_idx = 3
    lr_scheduler.on_epoch_begin(data=self.data)

    new_lr = self.tf_model.optimizer.lr.numpy()
    # The lr_fn and epoch_idx are identical to test_torch_model_on_epoch_begin,
    # so both frameworks must land on the same learning rate. The TF variable
    # may be float32 (TODO confirm); rel_tol=1e-5 leaves ample room for that rounding.
    expected_lr = 0.0009999993
    self.assertTrue(
        math.isclose(new_lr, expected_lr, rel_tol=1e-5),
        msg=f"expected lr ~{expected_lr}, got {new_lr}",
    )