Example #1
0
 def assert_lr_correct(self, optimizer, targets, epochs, min_lr, max_lr,
                       base_period, period_mult):
     """Test that learning rate was set correctly.

     Builds a ``WarmRestartLR`` on ``optimizer`` (``min_lr``, ``max_lr``,
     ``base_period`` and ``period_mult`` are forwarded positionally). It then
     steps the scheduler with each epoch in ``range(epochs)`` and compares
     every param group's ``'lr'`` against the expected value for that epoch.

     ``targets`` holds the expected learning rates indexed by epoch. When the
     optimizer has a single param group it is one such sequence. Otherwise it
     is one sequence per param group, in ``optimizer.param_groups`` order.
     """
     # A single group's targets arrive as one flat sequence; wrap it so it
     # pairs with the groups the same way as in the multi-group case.
     targets = [targets] if len(optimizer.param_groups) == 1 else targets
     scheduler = WarmRestartLR(optimizer, min_lr, max_lr, base_period,
                               period_mult)
     for epoch in range(epochs):
         optimizer.step()  # suppress warning about .step call order
         scheduler.step(epoch)
         for param_group, target in zip(optimizer.param_groups, targets):
             assert param_group['lr'] == pytest.approx(target[epoch])
Example #2
0
 def assert_lr_correct(self, optimizer, targets, epochs, min_lr, max_lr,
                       base_period, period_mult):
     """Test that learning rate was set correctly.

     A ``WarmRestartLR`` is attached to ``optimizer`` with ``min_lr``,
     ``max_lr``, ``base_period`` and ``period_mult`` passed positionally.
     For every epoch in ``range(epochs)`` the optimizer is stepped, then the
     scheduler is stepped with that epoch. Each param group's ``'lr'`` must
     then approximately equal its expected value for the epoch.

     ``targets``: expected learning rates indexed by epoch. This is one
     sequence for a single-group optimizer, or a sequence of such sequences
     (one per param group, in ``optimizer.param_groups`` order) otherwise.
     """
     if len(optimizer.param_groups) == 1:
         # Normalise the single-group form into the per-group form.
         targets = [targets]
     scheduler = WarmRestartLR(
         optimizer, min_lr, max_lr, base_period, period_mult)
     for epoch in range(epochs):
         # Step the optimizer before the scheduler to suppress the warning
         # about .step call order.
         optimizer.step()
         scheduler.step(epoch)
         for group, expected in zip(optimizer.param_groups, targets):
             assert group['lr'] == pytest.approx(expected[epoch])
 def assert_lr_correct(
         self, optimizer, targets, epochs, min_lr, max_lr, base_period,
         period_mult):
     """Test that learning rate was set correctly.

     Builds a ``WarmRestartLR`` on ``optimizer`` (``min_lr``, ``max_lr``,
     ``base_period`` and ``period_mult`` are forwarded positionally). It then
     steps the scheduler with each epoch in ``range(epochs)`` and compares
     every param group's ``'lr'`` against the expected value for that epoch.

     ``targets`` holds the expected learning rates indexed by epoch. When the
     optimizer has a single param group it is one such sequence. Otherwise it
     is one sequence per param group, in ``optimizer.param_groups`` order.
     """
     # A single group's targets arrive as one flat sequence; wrap it so it
     # pairs with the groups the same way as in the multi-group case.
     targets = [targets] if len(optimizer.param_groups) == 1 else targets
     scheduler = WarmRestartLR(
         optimizer, min_lr, max_lr, base_period, period_mult
     )
     for epoch in range(epochs):
         optimizer.step()  # suppress warning about .step call order
         scheduler.step(epoch)
         for param_group, target in zip(optimizer.param_groups, targets):
             assert param_group['lr'] == pytest.approx(target[epoch])
Example #4
0
 def test_raise_incompatible_len_on_max_lr_err(self, init_optimizer):
     """A ``max_lr`` list of the wrong length must raise a ValueError.

     The error message is expected to mention ``max_lr``.
     """
     bad_max_lr = [1e-1, 1e-2]
     with pytest.raises(ValueError, match='max_lr'):
         WarmRestartLR(init_optimizer, max_lr=bad_max_lr)