def test_swa_raises(self):
    # SWA must validate its hyperparameters at construction time and
    # raise ValueError with a descriptive message for each bad value.
    x, y, loss_fun, opt = self._define_vars_loss_opt()
    with self.assertRaisesRegex(ValueError, "Invalid SWA learning rate: -0.0001"):
        opt = contriboptim.SWA(opt, swa_start=1, swa_freq=2, swa_lr=-1e-4)
    with self.assertRaisesRegex(ValueError, "Invalid swa_freq: 0"):
        opt = contriboptim.SWA(opt, swa_start=1, swa_freq=0, swa_lr=1e-4)
    with self.assertRaisesRegex(ValueError, "Invalid swa_start: -1"):
        opt = contriboptim.SWA(opt, swa_start=-1, swa_freq=0, swa_lr=1e-4)
def test_swa_manual_group(self):
    """SWA in manual mode with only the y param group averaged.

    After ``opt.swap_swa_sgd()`` the value of x must be unchanged, while
    y must equal the manually accumulated running average.
    """
    x, y, loss_fun, opt = self._define_vars_loss_opt()
    opt = contriboptim.SWA(opt)
    swa_start = 5
    swa_freq = 2
    y_sum = torch.zeros_like(y)
    n_avg = 0
    for step in range(1, 11):
        opt.zero_grad()
        loss_fun(x, y).backward()
        opt.step()
        # Only the second param group (y) is pushed into the SWA buffers.
        n_avg, _, y_sum = self._update_test_vars(
            step, swa_freq, swa_start, n_avg, 0, y_sum, x, y,
            upd_fun=lambda: opt.update_swa_group(opt.param_groups[1]))

    x_before_swap = x.data.clone()
    opt.swap_swa_sgd()
    self.assertEqual(y_sum / n_avg, y)
    self.assertEqual(x_before_swap, x)
def test_swa_manual(self):
    """SWA in manual mode: after ``opt.swap_swa_sgd()`` both x and y
    must equal the manually computed running averages."""
    x, y, loss_fun, opt = self._define_vars_loss_opt()
    opt = contriboptim.SWA(opt)
    swa_start = 5
    swa_freq = 2
    x_sum = torch.zeros_like(x)
    y_sum = torch.zeros_like(y)
    n_avg = 0
    for step in range(1, 11):
        opt.zero_grad()
        loss_fun(x, y).backward()
        opt.step()
        n_avg, x_sum, y_sum = self._update_test_vars(
            step, swa_freq, swa_start, n_avg, x_sum, y_sum, x, y,
            upd_fun=opt.update_swa)

    opt.swap_swa_sgd()
    self.assertEqual(x_sum / n_avg, x)
    self.assertEqual(y_sum / n_avg, y)
def test_swa_auto_mode_detection(self):
    """SWA picks auto vs. manual mode from the supplied hyperparameters.

    Auto mode requires both ``swa_start`` and ``swa_freq``; any other
    combination falls back to manual mode.
    """
    x, y, loss_fun, base_opt = self._define_vars_loss_opt()
    swa_start = 5
    swa_freq = 2
    swa_lr = 0.001

    # (constructor kwargs, expected _auto_mode) — same order as before.
    cases = [
        (dict(swa_start=swa_start, swa_freq=swa_freq, swa_lr=swa_lr), True),
        (dict(swa_start=swa_start, swa_freq=swa_freq), True),
        (dict(swa_start=swa_start, swa_lr=swa_lr), False),
        (dict(swa_freq=swa_freq, swa_lr=swa_lr), False),
        (dict(swa_start=swa_start), False),
        (dict(swa_freq=swa_freq), False),
        (dict(swa_lr=swa_lr), False),
    ]
    for kwargs, expected in cases:
        opt = contriboptim.SWA(base_opt, **kwargs)
        self.assertEqual(opt._auto_mode, expected)
def test_swa_auto_group_added_during_run(self):
    """SWA in auto mode with a second param group added mid-run.

    Averaging for the second group must start ``swa_start`` steps after
    the group is added; averaging for the first group starts
    ``swa_start`` steps after the optimizer's very first step.
    """
    x, y, loss_fun, _ = self._define_vars_loss_opt()
    opt = optim.SGD([x], lr=1e-3, momentum=0.9)
    swa_start = 5
    swa_freq = 2
    opt = contriboptim.SWA(
        opt, swa_start=swa_start, swa_freq=swa_freq, swa_lr=0.001)

    x_sum = torch.zeros_like(x)
    y_sum = torch.zeros_like(y)
    x_n_avg = 0
    y_n_avg = 0
    x_step = 0

    # Phase 1: only x is being optimized.
    for i in range(1, 11):
        opt.zero_grad()
        loss = loss_fun(x, y)
        loss.backward()
        opt.step()
        x_step += 1
        if i % swa_freq == 0 and i > swa_start:
            x_n_avg += 1
            x_sum += x.data

    # Phase 2: y joins as a new param group; its averaging clock starts now,
    # while x keeps accumulating on its own (older) clock.
    opt.add_param_group({'params': y, 'lr': 1e-4})
    for y_step in range(1, 11):
        opt.zero_grad()
        loss = loss_fun(x, y)
        loss.backward()
        opt.step()
        x_step += 1
        if y_step % swa_freq == 0 and y_step > swa_start:
            y_n_avg += 1
            y_sum += y.data
        if x_step % swa_freq == 0 and x_step > swa_start:
            x_n_avg += 1
            x_sum += x.data

    opt.swap_swa_sgd()
    self.assertEqual(x_sum / x_n_avg, x)
    self.assertEqual(y_sum / y_n_avg, y)
def test_swa_lr(self):
    """SWA learning-rate handling.

    In auto mode the base learning rate must switch to ``swa_lr`` once
    ``i > swa_start``; in manual mode a provided ``swa_lr`` must be
    ignored (and a warning emitted at construction).
    """
    # Auto mode
    x, y, loss_fun, opt = self._define_vars_loss_opt()
    swa_start = 5
    swa_freq = 2
    initial_lr = opt.param_groups[0]["lr"]
    swa_lr = initial_lr * 0.1
    opt = contriboptim.SWA(
        opt, swa_start=swa_start, swa_freq=swa_freq, swa_lr=swa_lr)
    for i in range(1, 11):
        opt.zero_grad()
        loss = loss_fun(x, y)
        loss.backward()
        opt.step()
        lr = opt.param_groups[0]["lr"]
        if i > swa_start:
            self.assertEqual(lr, swa_lr)
        else:
            self.assertEqual(lr, initial_lr)

    # Manual mode
    # BUG FIX: the original unpacked the loss function into ``loss``, so the
    # loop below silently reused the stale ``loss_fun`` bound in the
    # auto-mode section above (and would NameError if this section stood
    # alone). Unpack it under the correct name.
    x, y, loss_fun, opt = self._define_vars_loss_opt()
    initial_lr = opt.param_groups[0]["lr"]
    swa_lr = initial_lr * 0.1
    with self.assertWarnsRegex("Some of swa_start, swa_freq is None"):
        opt = contriboptim.SWA(opt, swa_lr=swa_lr)
    for i in range(1, 11):
        opt.zero_grad()
        loss = loss_fun(x, y)
        loss.backward()
        opt.step()
        lr = opt.param_groups[0]["lr"]
        # swa_lr must be ignored in manual mode.
        self.assertEqual(lr, initial_lr)
def _test_bn_update(self, data_tensor, dnn, cuda=False, label_tensor=None):
    """Check that ``SWA.bn_update`` recomputes BatchNorm statistics.

    Manually accumulates the mean/variance of the network's
    pre-activations over the whole loader and compares them against the
    running statistics produced by ``bn_update``.
    """
    class DatasetFromTensors(data.Dataset):
        # Minimal in-memory dataset; yields x or (x, y) pairs.
        def __init__(self, X, y=None):
            self.X = X
            self.y = y
            self.N = self.X.shape[0]

        def __getitem__(self, index):
            x = self.X[index]
            if self.y is None:
                return x
            else:
                y = self.y[index]
                return x, y

        def __len__(self):
            return self.N

    with_y = label_tensor is not None
    ds = DatasetFromTensors(data_tensor, y=label_tensor)
    dl = data.DataLoader(ds, batch_size=5, shuffle=True)

    # Running sums of pre-activations and their squares, per feature.
    feat_sum = torch.zeros(dnn.n_features)
    feat_sq_sum = torch.zeros(dnn.n_features)
    if cuda:
        feat_sum = feat_sum.cuda()
        feat_sq_sum = feat_sq_sum.cuda()

    total_num = 0
    for x in dl:
        if with_y:
            x, _ = x
        if cuda:
            x = x.cuda()
        dnn.forward(x)
        preactivations = dnn.compute_preactivation(x)
        if len(preactivations.shape) == 4:
            # Conv output: move channels last before flattening.
            preactivations = preactivations.transpose(1, 3)
        preactivations = preactivations.contiguous().view(-1, dnn.n_features)
        total_num += preactivations.shape[0]
        feat_sum += torch.sum(preactivations, dim=0)
        feat_sq_sum += torch.sum(preactivations**2, dim=0)

    # Biased mean/variance, matching BatchNorm's running-stat convention.
    expected_mean = feat_sum / total_num
    expected_var = feat_sq_sum / total_num - expected_mean**2

    swa = contriboptim.SWA(optim.SGD(dnn.parameters(), lr=1e-3))
    swa.bn_update(dl, dnn, cuda=cuda)
    self.assertEqual(expected_mean, dnn.bn.running_mean)
    self.assertEqual(expected_var, dnn.bn.running_var, prec=1e-1)
def lbfgs_constructor(params):
    """Build an SWA-wrapped LBFGS optimizer (averaging effectively off)."""
    return contriboptim.SWA(
        optim.LBFGS(params, lr=5e-2, max_iter=5),
        swa_start=1000, swa_freq=1, swa_lr=1e-3)
def asgd_constructor(params):
    """Build an SWA-wrapped ASGD optimizer (averaging effectively off)."""
    return contriboptim.SWA(
        optim.ASGD(params, lr=1e-3),
        swa_start=1000, swa_freq=1, swa_lr=1e-3)
def rprop_constructor(params):
    """Build an SWA-wrapped Rprop optimizer (averaging effectively off)."""
    return contriboptim.SWA(
        optim.Rprop(params, lr=1e-2),
        swa_start=1000, swa_freq=1, swa_lr=1e-3)
def adamax_constructor(params):
    """Build an SWA-wrapped Adamax optimizer (averaging effectively off)."""
    return contriboptim.SWA(
        optim.Adamax(params, lr=1e-1),
        swa_start=1000, swa_freq=1, swa_lr=1e-2)
def adagrad_constructor(params):
    """Build an SWA-wrapped Adagrad optimizer (averaging effectively off)."""
    return contriboptim.SWA(
        optim.Adagrad(params, lr=1e-1),
        swa_start=1000, swa_freq=1, swa_lr=1e-2)
def adadelta_constructor(params):
    """Build an SWA-wrapped Adadelta optimizer (no swa_lr override)."""
    return contriboptim.SWA(
        optim.Adadelta(params),
        swa_start=1000, swa_freq=1)
def adam_constructor(params):
    """Build an SWA-wrapped Adam optimizer (averaging effectively off)."""
    return contriboptim.SWA(
        optim.Adam(params, lr=1e-2),
        swa_start=1000, swa_freq=1, swa_lr=1e-2)
def sgd_momentum_constructor(params):
    """Build an SWA-wrapped SGD-with-momentum optimizer."""
    return contriboptim.SWA(
        optim.SGD(params, lr=1e-3, momentum=0.9, weight_decay=1e-4),
        swa_start=1000, swa_freq=1, swa_lr=1e-3)
def sgd_manual_constructor(params):
    """Build an SWA-wrapped SGD optimizer in manual mode (no schedule)."""
    return contriboptim.SWA(optim.SGD(params, lr=1e-3))