Esempio n. 1
0
 def test_torch_model_on_batch_begin(self):
     """on_batch_begin should push the scheduled LR into the torch optimizer.

     A cosine-decay schedule (cycle_length=3750, init_lr=1e-3) is evaluated
     with the system's global_step set to 3. The test then checks that the
     first optimizer param group's 'lr' is within a 1e-5 relative tolerance
     of the expected value.
     """
     # The parameter must stay named `step`; a callable with the same
     # signature as the original lambda is passed to LRScheduler.
     def schedule(step):
         return fe.schedule.cosine_decay(step, cycle_length=3750, init_lr=1e-3)

     scheduler = LRScheduler(model=self.torch_model, lr_fn=schedule)
     scheduler.system = sample_system_object()
     scheduler.system.global_step = 3
     scheduler.on_batch_begin(data=self.data)

     first_group = list(self.torch_model.optimizer.param_groups)[0]
     observed_lr = first_group['lr']
     self.assertTrue(math.isclose(observed_lr, 0.0009999993, rel_tol=1e-5))
Esempio n. 2
0
 def test_torch_model_on_batch_end(self):
     """on_batch_end should record the current LR into the batch data dict.

     The system is configured with global_step=3 and log_steps=1. The test
     then reads the value stored under "<model_name>_lr" in self.data and
     checks it against 0.001 with a 1e-3 relative tolerance.
     """
     lr_key = self.torch_model.model_name + '_lr'

     # The parameter must stay named `step`, mirroring the original lambda.
     def schedule(step):
         return fe.schedule.cosine_decay(step, cycle_length=3750, init_lr=1e-3)

     scheduler = LRScheduler(model=self.torch_model, lr_fn=schedule)
     scheduler.system = sample_system_object()
     scheduler.system.global_step = 3
     scheduler.system.log_steps = 1
     scheduler.on_batch_end(data=self.data)

     logged_lr = self.data[lr_key]
     self.assertTrue(math.isclose(logged_lr, 0.001, rel_tol=1e-3))
Esempio n. 3
0
 def test_tf_model_on_batch_begin(self):
     """on_batch_begin should push the scheduled LR into the TF optimizer.

     A cosine-decay schedule (cycle_length=3750, init_lr=1e-3) is evaluated
     with the system's global_step set to 3. The test then checks the
     optimizer's `lr` variable (read via `.numpy()`) against the expected
     value with a 1e-5 relative tolerance.
     """
     # The parameter must stay named `step`, mirroring the original lambda.
     def schedule(step):
         return fe.schedule.cosine_decay(step, cycle_length=3750, init_lr=1e-3)

     scheduler = LRScheduler(model=self.tf_model, lr_fn=schedule)
     scheduler.system = sample_system_object()
     scheduler.system.global_step = 3
     scheduler.on_batch_begin(data=self.data)

     observed_lr = self.tf_model.optimizer.lr.numpy()
     self.assertTrue(math.isclose(observed_lr, 0.0009999973, rel_tol=1e-5))