def test_lr_scheduler():
    """Check ``LRScheduler`` against a manually stepped torch LR scheduler.

    For each torch scheduler class, two SGD optimizers share one tensor. The
    first is driven by ignite's ``LRScheduler`` wrapper attached to an
    ``Engine``. The second is driven by stepping a plain torch scheduler from
    an event handler. The recorded learning rates must agree with each other
    and with ``LRScheduler.simulate_values``. The tail of the test covers
    ``LRScheduler._replicate_lr_scheduler``.

    NOTE(review): a second ``def test_lr_scheduler`` appears later in this
    file. That later definition rebinds the name at import time, so pytest
    only collects it. Remove or rename this version if it should still run.
    """

    def _test(torch_lr_scheduler_cls, **kwargs):
        tensor = torch.zeros([1], requires_grad=True)
        optimizer1 = torch.optim.SGD([tensor], lr=0.1)
        optimizer2 = torch.optim.SGD([tensor], lr=0.1)

        torch_lr_scheduler1 = torch_lr_scheduler_cls(optimizer=optimizer1, **kwargs)
        torch_lr_scheduler2 = torch_lr_scheduler_cls(optimizer=optimizer2, **kwargs)
        scheduler = LRScheduler(torch_lr_scheduler1)

        lrs = []
        lrs_true = []
        # The update function is a no-op: only the LR schedules are under test.
        # NOTE(review): both schedulers are stepped on ITERATION_STARTED with no
        # preceding optimizer.step(). PyTorch >= 1.1 warns about this order on
        # the first scheduler step. Changing it would shift the recorded LRs,
        # so it is left as is here.
        trainer = Engine(lambda engine, batch: None)

        @trainer.on(Events.ITERATION_STARTED)
        def torch_lr_scheduler_step(engine):
            torch_lr_scheduler2.step()

        @trainer.on(Events.ITERATION_COMPLETED)
        def save_lr(engine):
            lrs.append(optimizer1.param_groups[0]['lr'])

        @trainer.on(Events.ITERATION_COMPLETED)
        def save_true_lr(engine):
            lrs_true.append(optimizer2.param_groups[0]['lr'])

        trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

        data = [0] * 10
        max_epochs = 2
        trainer.run(data, max_epochs=max_epochs)

        # The wrapped scheduler must produce the same LRs as the reference one.
        assert lrs_true == pytest.approx(lrs)

        # simulate_values on a fresh scheduler/optimizer pair must reproduce
        # the learning rates recorded during the run.
        optimizer3 = torch.optim.SGD([tensor], lr=0.1)
        torch_lr_scheduler3 = torch_lr_scheduler_cls(optimizer=optimizer3, **kwargs)
        simulated_values = LRScheduler.simulate_values(
            num_events=len(data) * max_epochs, lr_scheduler=torch_lr_scheduler3
        )
        assert lrs == pytest.approx([v for i, v in simulated_values])

    _test(torch.optim.lr_scheduler.StepLR, step_size=5, gamma=0.5)
    _test(torch.optim.lr_scheduler.ExponentialLR, gamma=0.78)

    # test _replicate_lr_scheduler: stepping the original scheduler must not
    # change the state of the replica taken before the steps.
    tensor = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0.1)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.78)
    init_lr_scheduler_state = dict(lr_scheduler.state_dict())
    copy_lr_scheduler = LRScheduler._replicate_lr_scheduler(lr_scheduler)
    for _ in range(10):
        # Since PyTorch 1.1, optimizer.step() must precede lr_scheduler.step().
        # backward() is never called, so there are no gradients and this SGD
        # step should leave the tensor unchanged.
        optimizer.step()
        lr_scheduler.step()
    assert copy_lr_scheduler.state_dict() == init_lr_scheduler_state
def test_lr_scheduler():
    """Compare ignite's ``LRScheduler`` wrapper with a directly stepped torch scheduler.

    For each torch scheduler class, the same schedule is run twice through an
    ``Engine``, restoring optimizer and scheduler state between runs. The
    learning rates observed through the wrapper are checked against a
    reference optimizer whose scheduler is stepped manually, and finally
    against ``LRScheduler.simulate_values``. The tail of the test covers
    ``LRScheduler._replicate_lr_scheduler``.
    """

    def _test(torch_lr_scheduler_cls, **kwargs):
        param = torch.zeros([1], requires_grad=True)
        wrapped_opt = torch.optim.SGD([param], lr=0.01)
        reference_opt = torch.optim.SGD([param], lr=0.01)
        # Snapshot both optimizers before any scheduler is constructed on them.
        wrapped_opt_state = wrapped_opt.state_dict()
        reference_opt_state = reference_opt.state_dict()

        wrapped_torch_sched = torch_lr_scheduler_cls(optimizer=wrapped_opt, **kwargs)
        scheduler = LRScheduler(wrapped_torch_sched)
        scheduler_state = scheduler.state_dict()

        reference_sched = torch_lr_scheduler_cls(optimizer=reference_opt, **kwargs)
        reference_sched_state = reference_sched.state_dict()

        def dummy_update(engine, batch):
            # Step both optimizers every iteration; the schedulers themselves
            # are stepped on ITERATION_COMPLETED.
            wrapped_opt.step()
            reference_opt.step()

        trainer = Engine(dummy_update)

        wrapped_lrs = []
        reference_lrs = []

        def save_lr(engine):
            wrapped_lrs.append(wrapped_opt.param_groups[0]["lr"])

        def save_true_lr(engine):
            reference_lrs.append(reference_opt.param_groups[0]["lr"])

        def torch_lr_scheduler_step(engine):
            reference_sched.step()

        # LRs are sampled when an iteration starts; schedulers advance when it
        # completes. Registration order within each event is significant.
        trainer.add_event_handler(Events.ITERATION_STARTED, save_lr)
        trainer.add_event_handler(Events.ITERATION_STARTED, save_true_lr)
        trainer.add_event_handler(Events.ITERATION_COMPLETED, torch_lr_scheduler_step)
        trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)

        data = [0] * 10
        max_epochs = 2

        for run in range(2):
            wrapped_lrs.clear()
            reference_lrs.clear()
            trainer.run(data, max_epochs=max_epochs)
            assert reference_lrs == pytest.approx(wrapped_lrs), "{}: {} ({}) vs {} ({})".format(
                run, reference_lrs, len(reference_lrs), wrapped_lrs, len(wrapped_lrs)
            )
            # Rewind everything so the next run starts from the initial schedule.
            wrapped_opt.load_state_dict(wrapped_opt_state)
            scheduler.load_state_dict(scheduler_state)
            reference_opt.load_state_dict(reference_opt_state)
            reference_sched.load_state_dict(reference_sched_state)

        simulated_opt = torch.optim.SGD([param], lr=0.01)
        simulated_sched = torch_lr_scheduler_cls(optimizer=simulated_opt, **kwargs)
        simulated_values = LRScheduler.simulate_values(
            num_events=len(data) * max_epochs, lr_scheduler=simulated_sched
        )
        assert wrapped_lrs == pytest.approx([lr for _, lr in simulated_values])

    _test(torch.optim.lr_scheduler.StepLR, step_size=5, gamma=0.5)
    _test(torch.optim.lr_scheduler.ExponentialLR, gamma=0.78)

    # _replicate_lr_scheduler: stepping the original scheduler must leave the
    # replica's state equal to the state captured before the steps.
    param = torch.zeros([1], requires_grad=True)
    opt = torch.optim.SGD([param], lr=0.01)
    source_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer=opt, gamma=0.78)
    initial_state = dict(source_sched.state_dict())
    replica = LRScheduler._replicate_lr_scheduler(source_sched)
    for _ in range(10):
        opt.step()
        source_sched.step()
    assert replica.state_dict() == initial_state

    # A non-scheduler argument (here an int) must raise TypeError.
    with pytest.raises(TypeError):
        LRScheduler._replicate_lr_scheduler(12)