def test_swa_raises(self):
    # Tests that SWA raises errors for wrong parameter values
    x, y, loss_fun, opt = self._define_vars_loss_opt()
    with self.assertRaisesRegex(
            ValueError, "Invalid SWA learning rate: -0.0001"):
        opt = StochasticWeightAveraging(
            opt, swa_start=1, swa_freq=2, swa_lr=-1e-4)
    with self.assertRaisesRegex(ValueError, "Invalid swa_freq: 0"):
        opt = StochasticWeightAveraging(
            opt, swa_start=1, swa_freq=0, swa_lr=1e-4)
    with self.assertRaisesRegex(ValueError, "Invalid swa_start: -1"):
        opt = StochasticWeightAveraging(
            opt, swa_start=-1, swa_freq=0, swa_lr=1e-4)
def test_swa_manual(self):
    # Tests SWA in manual mode: values of x and y after opt.finalize()
    # should be equal to the manually computed averages
    x, y, loss_fun, opt = self._define_vars_loss_opt()
    opt = StochasticWeightAveraging(opt)
    swa_start = 5
    swa_freq = 2
    x_sum = torch.zeros_like(x)
    y_sum = torch.zeros_like(y)
    n_avg = 0
    for i in range(1, 11):
        opt.zero_grad()
        loss = loss_fun(x, y)
        loss.backward()
        opt.step()
        n_avg, x_sum, y_sum = self._update_test_vars(
            i, swa_freq, swa_start, n_avg, x_sum, y_sum, x, y,
            upd_fun=opt.update_swa)
    opt.finalize()
    x_avg = x_sum / n_avg
    y_avg = y_sum / n_avg
    assert equal(x_avg, x)
    assert equal(y_avg, y)
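# For reference, a minimal sketch of what the _update_test_vars helper used
# above and in test_swa_manual_group is assumed to do (hypothetical; the
# real helper is defined elsewhere in this test class): on every
# swa_freq-th step after swa_start it triggers the SWA update and
# accumulates the current parameter values, so the test can compare the
# optimizer's average against a manually computed one.
#
# def _update_test_vars(self, i, swa_freq, swa_start, n_avg, x_sum, y_sum,
#                       x, y, upd_fun):
#     if i % swa_freq == 0 and i > swa_start:
#         upd_fun()
#         n_avg += 1
#         x_sum += x.data
#         y_sum += y.data
#     return n_avg, x_sum, y_sum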
def _test_bn_update(self, data_tensor, dnn, device, label_tensor=None):
    # Checks that bn_update sets the BatchNorm running statistics to the
    # mean and variance of the pre-activation features computed over the
    # whole dataset
    class DatasetFromTensors(data.Dataset):
        def __init__(self, X, y=None):
            self.X = X
            self.y = y
            self.N = self.X.shape[0]

        def __getitem__(self, index):
            x = self.X[index]
            if self.y is None:
                return x
            else:
                y = self.y[index]
                return x, y

        def __len__(self):
            return self.N

    with_y = label_tensor is not None
    ds = DatasetFromTensors(data_tensor, y=label_tensor)
    dl = data.DataLoader(ds, batch_size=5, shuffle=True)

    # Accumulate the empirical mean and variance of the pre-BN activations
    preactivation_sum = torch.zeros(dnn.n_features, device=device)
    preactivation_squared_sum = torch.zeros(dnn.n_features, device=device)
    total_num = 0
    for x in dl:
        if with_y:
            x, _ = x
        x = x.to(device)

        dnn(x)
        preactivations = dnn.compute_preactivation(x)
        if len(preactivations.shape) == 4:
            preactivations = preactivations.transpose(1, 3)
        preactivations = preactivations.reshape(-1, dnn.n_features)
        total_num += preactivations.shape[0]

        preactivation_sum += torch.sum(preactivations, dim=0)
        preactivation_squared_sum += torch.sum(preactivations ** 2, dim=0)

    preactivation_mean = preactivation_sum / total_num
    preactivation_var = preactivation_squared_sum / total_num
    preactivation_var = preactivation_var - preactivation_mean ** 2

    swa = StochasticWeightAveraging(optim.SGD(dnn.parameters(), lr=1e-3))
    swa.bn_update(dl, dnn, device=device)
    assert equal(preactivation_mean, dnn.bn.running_mean)
    assert equal(preactivation_var, dnn.bn.running_var, prec=1e-1)
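# A short usage sketch for bn_update (illustrative only, not exercised by
# the tests): after swapping the averaged weights into the model with
# finalize(), the BatchNorm running statistics still correspond to the last
# SGD iterate, so they need to be recomputed for the averaged model with a
# single pass over the training data. model, train_loader, loss_fn and
# device below are placeholders, not names defined in this file.
def _example_bn_update_usage(model, train_loader, loss_fn, device):
    base = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
    opt = StochasticWeightAveraging(
        base, swa_start=10, swa_freq=5, swa_lr=5e-3)
    for x, y in train_loader:
        opt.zero_grad()
        loss_fn(model(x.to(device)), y.to(device)).backward()
        opt.step()
    opt.finalize()  # put the SWA running average into the model parameters
    opt.bn_update(train_loader, model, device=device)  # refresh BN stats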
def test_swa_manual_group(self):
    # Tests SWA in manual mode with only the y param group updated:
    # the value of x should not change after opt.finalize() and y should
    # be equal to the manually computed average
    x, y, loss_fun, opt = self._define_vars_loss_opt()
    opt = StochasticWeightAveraging(opt)
    swa_start = 5
    swa_freq = 2
    y_sum = torch.zeros_like(y)
    n_avg = 0
    for i in range(1, 11):
        opt.zero_grad()
        loss = loss_fun(x, y)
        loss.backward()
        opt.step()
        n_avg, _, y_sum = self._update_test_vars(
            i, swa_freq, swa_start, n_avg, 0, y_sum, x, y,
            upd_fun=lambda: opt.update_swa_group(opt.param_groups[1]))
    x_before_swap = x.data.clone()
    with self.assertWarnsRegex(
            re.escape(r"SWA wasn't applied to param {}".format(x))):
        opt.finalize()
    y_avg = y_sum / n_avg
    assert equal(y_avg, y)
    assert equal(x_before_swap, x)
def test_swa_auto_mode_detection(self):
    # Tests that the SWA mode (auto or manual) is chosen correctly based
    # on the parameters provided

    # Auto mode: both swa_start and swa_freq are given
    x, y, loss_fun, base_opt = self._define_vars_loss_opt()
    swa_start = 5
    swa_freq = 2
    swa_lr = 0.001

    opt = StochasticWeightAveraging(
        base_opt, swa_start=swa_start, swa_freq=swa_freq, swa_lr=swa_lr)
    assert equal(opt._auto_mode, True)

    opt = StochasticWeightAveraging(
        base_opt, swa_start=swa_start, swa_freq=swa_freq)
    assert equal(opt._auto_mode, True)

    # Manual mode: if either swa_start or swa_freq is missing, SWA falls
    # back to manual mode and warns
    with self.assertWarnsRegex("Some of swa_start, swa_freq is None"):
        opt = StochasticWeightAveraging(
            base_opt, swa_start=swa_start, swa_lr=swa_lr)
    assert equal(opt._auto_mode, False)

    with self.assertWarnsRegex("Some of swa_start, swa_freq is None"):
        opt = StochasticWeightAveraging(
            base_opt, swa_freq=swa_freq, swa_lr=swa_lr)
    assert equal(opt._auto_mode, False)

    with self.assertWarnsRegex("Some of swa_start, swa_freq is None"):
        opt = StochasticWeightAveraging(base_opt, swa_start=swa_start)
    assert equal(opt._auto_mode, False)

    with self.assertWarnsRegex("Some of swa_start, swa_freq is None"):
        opt = StochasticWeightAveraging(base_opt, swa_freq=swa_freq)
    assert equal(opt._auto_mode, False)

    with self.assertWarnsRegex("Some of swa_start, swa_freq is None"):
        opt = StochasticWeightAveraging(base_opt, swa_lr=swa_lr)
    assert equal(opt._auto_mode, False)
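# A minimal usage sketch of the two modes exercised above (illustrative
# only, not called by the tests): in auto mode the wrapper switches the
# learning rate to swa_lr and averages every swa_freq steps after
# swa_start on its own; in manual mode the user calls update_swa()
# explicitly. model, loader and loss_fn below are placeholders.
def _example_auto_vs_manual_swa(model, loader, loss_fn):
    # Auto mode: averaging is handled inside step()
    base = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
    opt = StochasticWeightAveraging(
        base, swa_start=10, swa_freq=5, swa_lr=5e-3)
    for x, y in loader:
        opt.zero_grad()
        loss_fn(model(x), y).backward()
        opt.step()
    opt.finalize()  # swap model parameters with the SWA running average

    # Manual mode: the user decides when to accumulate the average
    base = optim.SGD(model.parameters(), lr=1e-2)
    opt = StochasticWeightAveraging(base)
    for step, (x, y) in enumerate(loader, 1):
        opt.zero_grad()
        loss_fn(model(x), y).backward()
        opt.step()
        if step % 5 == 0:
            opt.update_swa()
    opt.finalize()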
def lamb_constructor(params):
    lamb = Lamb(params, weight_decay=0.01)
    return StochasticWeightAveraging(
        lamb, swa_start=1000, swa_freq=1, swa_lr=1e-2)
def test_swa_lr(self):
    # Tests the SWA learning rate: in auto mode, after swa_start steps the
    # learning rate should be switched to swa_lr; in manual mode swa_lr
    # must be ignored

    # Auto mode
    x, y, loss_fun, opt = self._define_vars_loss_opt()
    swa_start = 5
    swa_freq = 2
    initial_lr = opt.param_groups[0]["lr"]
    swa_lr = initial_lr * 0.1
    opt = StochasticWeightAveraging(
        opt, swa_start=swa_start, swa_freq=swa_freq, swa_lr=swa_lr)
    for i in range(1, 11):
        opt.zero_grad()
        loss = loss_fun(x, y)
        loss.backward()
        opt.step()
        lr = opt.param_groups[0]["lr"]
        if i > swa_start:
            assert equal(lr, swa_lr)
        else:
            assert equal(lr, initial_lr)

    # Manual mode
    x, y, loss_fun, opt = self._define_vars_loss_opt()
    initial_lr = opt.param_groups[0]["lr"]
    swa_lr = initial_lr * 0.1
    with self.assertWarnsRegex("Some of swa_start, swa_freq is None"):
        opt = StochasticWeightAveraging(opt, swa_lr=swa_lr)
    for _ in range(1, 11):
        opt.zero_grad()
        loss = loss_fun(x, y)
        loss.backward()
        opt.step()
        lr = opt.param_groups[0]["lr"]
        assert equal(lr, initial_lr)
def test_swa_auto_group_added_during_run(self):
    # Tests SWA in auto mode with the second param group added after
    # several optimization steps. The expected behavior is that averaging
    # for the second param group starts swa_start steps after it is added.
    # For the first group, averaging should start swa_start steps after
    # the first step of the optimizer.
    x, y, loss_fun, _ = self._define_vars_loss_opt()
    opt = optim.SGD([x], lr=1e-3, momentum=0.9)
    swa_start = 5
    swa_freq = 2
    opt = StochasticWeightAveraging(
        opt, swa_start=swa_start, swa_freq=swa_freq, swa_lr=0.001)

    x_sum = torch.zeros_like(x)
    y_sum = torch.zeros_like(y)
    x_n_avg = 0
    y_n_avg = 0
    x_step = 0

    # Optimize only x for the first 10 steps
    for i in range(1, 11):
        opt.zero_grad()
        loss = loss_fun(x, y)
        loss.backward()
        opt.step()
        x_step += 1
        if i % swa_freq == 0 and i > swa_start:
            x_n_avg += 1
            x_sum += x.data
    x_avg = x_sum / x_n_avg

    # Add the y param group and continue optimizing both groups
    opt.add_param_group({"params": y, "lr": 1e-4})

    for y_step in range(1, 11):
        opt.zero_grad()
        loss = loss_fun(x, y)
        loss.backward()
        opt.step()
        x_step += 1
        if y_step % swa_freq == 0 and y_step > swa_start:
            y_n_avg += 1
            y_sum += y.data
        if x_step % swa_freq == 0 and x_step > swa_start:
            x_n_avg += 1
            x_sum += x.data
            x_avg = x_sum / x_n_avg

    opt.finalize()
    x_avg = x_sum / x_n_avg
    y_avg = y_sum / y_n_avg
    assert equal(x_avg, x)
    assert equal(y_avg, y)
def lbfgs_constructor(params):
    lbfgs = optim.LBFGS(params, lr=5e-2, max_iter=5)
    return StochasticWeightAveraging(
        lbfgs, swa_start=1000, swa_freq=1, swa_lr=1e-3)
def asgd_constructor(params):
    asgd = optim.ASGD(params, lr=1e-3)
    return StochasticWeightAveraging(
        asgd, swa_start=1000, swa_freq=1, swa_lr=1e-3)
def rprop_constructor(params):
    rprop = optim.Rprop(params, lr=1e-2)
    return StochasticWeightAveraging(
        rprop, swa_start=1000, swa_freq=1, swa_lr=1e-3)
def adamax_constructor(params):
    adamax = optim.Adamax(params, lr=1e-1)
    return StochasticWeightAveraging(
        adamax, swa_start=1000, swa_freq=1, swa_lr=1e-2)
def adadelta_constructor(params):
    adadelta = optim.Adadelta(params)
    return StochasticWeightAveraging(adadelta, swa_start=1000, swa_freq=1)
def sgd_momentum_constructor(params):
    sgd = optim.SGD(params, lr=1e-3, momentum=0.9, weight_decay=1e-4)
    return StochasticWeightAveraging(
        sgd, swa_start=1000, swa_freq=1, swa_lr=1e-3)
def sgd_manual_constructor(params):
    sgd = optim.SGD(params, lr=1e-3)
    return StochasticWeightAveraging(sgd)
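# Illustrative usage of the constructors above (hypothetical; the shared
# test harness that actually consumes them is defined elsewhere in this
# file): each constructor takes an iterable of parameters and returns an
# SWA-wrapped optimizer that is stepped exactly like the wrapped optimizer.
#
#     weight = torch.randn(10, 5, requires_grad=True)
#     bias = torch.randn(10, requires_grad=True)
#     opt = sgd_momentum_constructor([weight, bias])
#     opt.zero_grad()
#     loss = (weight.mv(torch.randn(5)) + bias).pow(2).sum()
#     loss.backward()
#     opt.step()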