Example #1
    def test_swa_raises(self):
        # Tests that SWA raises errors for wrong parameter values

        x, y, loss_fun, opt = self._define_vars_loss_opt()

        with self.assertRaisesRegex(ValueError,
                                    "Invalid SWA learning rate: -0.0001"):
            opt = contriboptim.SWA(opt, swa_start=1, swa_freq=2, swa_lr=-1e-4)

        with self.assertRaisesRegex(ValueError, "Invalid swa_freq: 0"):
            opt = contriboptim.SWA(opt, swa_start=1, swa_freq=0, swa_lr=1e-4)

        with self.assertRaisesRegex(ValueError, "Invalid swa_start: -1"):
            opt = contriboptim.SWA(opt, swa_start=-1, swa_freq=0, swa_lr=1e-4)
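Every test in this listing relies on a `_define_vars_loss_opt` helper that the snippets do not reproduce. Purely to make them easier to follow, here is a minimal sketch of what such a fixture could return; the tensor values, loss, and optimizer settings are assumptions, not the suite's actual code:

    # Assumes the same imports as the rest of the suite: torch,
    # torch.optim as optim, and the SWA implementation aliased as contriboptim.
    def _define_vars_loss_opt(self):
        # Hypothetical fixture: two small parameter tensors in separate param
        # groups, a simple differentiable loss, and a plain SGD base optimizer
        # for SWA to wrap.
        x = torch.tensor([5.0, 2.0], requires_grad=True)
        y = torch.tensor([3.0, 7.0], requires_grad=True)

        def loss_fun(a, b):
            return torch.sum(a * b)

        opt = optim.SGD([{'params': [x]}, {'params': [y]}],
                        lr=1e-3, momentum=0.9)
        return x, y, loss_fun, opt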
Example #2
    def test_swa_manual_group(self):
        # Tests SWA in manual mode with only y param group updated: 
        # value of x should not change after opt.swap_swa_sgd() and y should
        # be equal to the manually computed average
        x, y, loss_fun, opt = self._define_vars_loss_opt()
        opt = contriboptim.SWA(opt)
        swa_start = 5
        swa_freq = 2

        y_sum = torch.zeros_like(y)
        n_avg = 0
        for i in range(1, 11):
            opt.zero_grad()
            loss = loss_fun(x, y)
            loss.backward()
            opt.step()
            n_avg, _, y_sum = self._update_test_vars(
                i, swa_freq, swa_start, n_avg, 0, y_sum, x, y,
                upd_fun=lambda: opt.update_swa_group(opt.param_groups[1]))

        x_before_swap = x.data.clone()
        opt.swap_swa_sgd()
        y_avg = y_sum / n_avg
        self.assertEqual(y_avg, y)      
        self.assertEqual(x_before_swap, x)      
Example #3
    def test_swa_manual(self):
        # Tests SWA in manual mode: values of x and y after opt.swap_swa_sgd()
        # should be equal to the manually computed averages
        x, y, loss_fun, opt = self._define_vars_loss_opt()
        opt = contriboptim.SWA(opt)
        swa_start = 5
        swa_freq = 2

        x_sum = torch.zeros_like(x)
        y_sum = torch.zeros_like(y)
        n_avg = 0
        for i in range(1, 11):
            opt.zero_grad()
            loss = loss_fun(x, y)
            loss.backward()
            opt.step()
            n_avg, x_sum, y_sum = self._update_test_vars(
                i,
                swa_freq,
                swa_start,
                n_avg,
                x_sum,
                y_sum,
                x,
                y,
                upd_fun=opt.update_swa)

        opt.swap_swa_sgd()
        x_avg = x_sum / n_avg
        y_avg = y_sum / n_avg
        self.assertEqual(x_avg, x)
        self.assertEqual(y_avg, y)
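The two manual-mode tests above delegate their bookkeeping to a `_update_test_vars` helper that is likewise not shown. Assuming it mirrors the `i % swa_freq == 0 and i > swa_start` condition written out explicitly in Example #5 below, it plausibly looks like this sketch:

    def _update_test_vars(self, i, swa_freq, swa_start, n_avg, x_sum, y_sum,
                          x, y, upd_fun):
        # Hypothetical helper: when an SWA update is due, trigger the manual
        # update and accumulate the current parameter values into the running
        # sums used to compute the expected averages.
        if i % swa_freq == 0 and i > swa_start:
            upd_fun()
            n_avg += 1
            x_sum = x_sum + x.data  # also works when x_sum is passed as 0, as in Example #2
            y_sum = y_sum + y.data
        return n_avg, x_sum, y_sum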
Example #4
    def test_swa_auto_mode_detection(self):
        # Tests that SWA mode (auto or manual) is chosen correctly based on
        # parameters provided

        # Auto mode
        x, y, loss_fun, base_opt = self._define_vars_loss_opt()
        swa_start = 5
        swa_freq = 2
        swa_lr = 0.001

        opt = contriboptim.SWA(
            base_opt, swa_start=swa_start, swa_freq=swa_freq, swa_lr=swa_lr)
        self.assertEqual(opt._auto_mode, True)

        opt = contriboptim.SWA(base_opt, swa_start=swa_start, swa_freq=swa_freq)
        self.assertEqual(opt._auto_mode, True)

        opt = contriboptim.SWA(base_opt, swa_start=swa_start, swa_lr=swa_lr)
        self.assertEqual(opt._auto_mode, False)

        opt = contriboptim.SWA(base_opt, swa_freq=swa_freq, swa_lr=swa_lr)
        self.assertEqual(opt._auto_mode, False)

        opt = contriboptim.SWA(base_opt, swa_start=swa_start)
        self.assertEqual(opt._auto_mode, False)

        opt = contriboptim.SWA(base_opt, swa_freq=swa_freq)
        self.assertEqual(opt._auto_mode, False)

        opt = contriboptim.SWA(base_opt, swa_lr=swa_lr)
        self.assertEqual(opt._auto_mode, False)
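Outside the test suite, the auto/manual distinction checked above corresponds to two usage patterns. The sketch below illustrates both; it assumes `contriboptim` is the SWA module these tests import (for instance `torchcontrib.optim`), and the model, data, and hyperparameters are placeholders:

import torch
from torch import nn, optim
import torchcontrib.optim as contriboptim  # assumed source of the SWA class used above

model = nn.Linear(10, 1)
loss_fn = nn.MSELoss()
batches = [(torch.randn(4, 10), torch.randn(4, 1)) for _ in range(20)]

# Auto mode: averaging starts after swa_start steps, runs every swa_freq
# steps, and the learning rate is switched to swa_lr.
opt = contriboptim.SWA(optim.SGD(model.parameters(), lr=1e-2, momentum=0.9),
                       swa_start=10, swa_freq=5, swa_lr=5e-3)
for inp, target in batches:
    opt.zero_grad()
    loss_fn(model(inp), target).backward()
    opt.step()
opt.swap_swa_sgd()  # replace the model weights with their SWA average

# Manual mode: no swa_start/swa_freq, so the caller decides when to average.
opt = contriboptim.SWA(optim.SGD(model.parameters(), lr=1e-2))
for inp, target in batches:
    opt.zero_grad()
    loss_fn(model(inp), target).backward()
    opt.step()
    opt.update_swa()
opt.swap_swa_sgd()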
Example #5
    def test_swa_auto_group_added_during_run(self):
        # Tests SWA in Auto mode with the second param group added after several
        # optimizations steps. The expected behavior is that the averaging for
        # the second param group starts at swa_start steps after it is added.
        # For the first group averaging should start swa_start steps after the
        # first step of the optimizer.

        x, y, loss_fun, _ = self._define_vars_loss_opt()
        opt = optim.SGD([x], lr=1e-3, momentum=0.9)
        swa_start = 5
        swa_freq = 2
        opt = contriboptim.SWA(opt,
                               swa_start=swa_start,
                               swa_freq=swa_freq,
                               swa_lr=0.001)

        x_sum = torch.zeros_like(x)
        y_sum = torch.zeros_like(y)
        x_n_avg = 0
        y_n_avg = 0
        x_step = 0
        for i in range(1, 11):
            opt.zero_grad()
            loss = loss_fun(x, y)
            loss.backward()
            opt.step()
            x_step += 1
            if i % swa_freq == 0 and i > swa_start:
                x_n_avg += 1
                x_sum += x.data

        x_avg = x_sum / x_n_avg

        opt.add_param_group({'params': y, 'lr': 1e-4})

        for y_step in range(1, 11):
            opt.zero_grad()
            loss = loss_fun(x, y)
            loss.backward()
            opt.step()
            x_step += 1
            if y_step % swa_freq == 0 and y_step > swa_start:
                y_n_avg += 1
                y_sum += y.data
            if x_step % swa_freq == 0 and x_step > swa_start:
                x_n_avg += 1
                x_sum += x.data
                x_avg = x_sum / x_n_avg

        opt.swap_swa_sgd()
        x_avg = x_sum / x_n_avg
        y_avg = y_sum / y_n_avg
        self.assertEqual(x_avg, x)
        self.assertEqual(y_avg, y)
Example #6
    def test_swa_lr(self):
        # Tests SWA learning rate: in auto mode after swa_start steps the
        # learning rate should be changed to swa_lr; in manual mode swa_lr
        # must be ignored

        # Auto mode
        x, y, loss_fun, opt = self._define_vars_loss_opt()
        swa_start = 5
        swa_freq = 2
        initial_lr = opt.param_groups[0]["lr"]
        swa_lr = initial_lr * 0.1
        opt = contriboptim.SWA(opt,
                               swa_start=swa_start,
                               swa_freq=swa_freq,
                               swa_lr=swa_lr)

        for i in range(1, 11):
            opt.zero_grad()
            loss = loss_fun(x, y)
            loss.backward()
            opt.step()
            lr = opt.param_groups[0]["lr"]
            if i > swa_start:
                self.assertEqual(lr, swa_lr)
            else:
                self.assertEqual(lr, initial_lr)

        # Manual Mode
        x, y, loss_fun, opt = self._define_vars_loss_opt()
        initial_lr = opt.param_groups[0]["lr"]
        swa_lr = initial_lr * 0.1
        with self.assertWarnsRegex(UserWarning,
                                   "Some of swa_start, swa_freq is None"):
            opt = contriboptim.SWA(opt, swa_lr=swa_lr)

        for i in range(1, 11):
            opt.zero_grad()
            loss = loss_fun(x, y)
            loss.backward()
            opt.step()
            lr = opt.param_groups[0]["lr"]
            self.assertEqual(lr, initial_lr)
Example #7
    def _test_bn_update(self, data_tensor, dnn, cuda=False, label_tensor=None):

        class DatasetFromTensors(data.Dataset):
            def __init__(self, X, y=None):
                self.X = X
                self.y = y
                self.N = self.X.shape[0]
        
            def __getitem__(self, index):
                x = self.X[index]
                if self.y is None:
                    return x
                else:
                    y = self.y[index]
                    return x, y
        
            def __len__(self):
                return self.N 
        
        with_y = label_tensor is not None
        ds = DatasetFromTensors(data_tensor, y=label_tensor)
        dl = data.DataLoader(ds, batch_size=5, shuffle=True)

        preactivation_sum = torch.zeros(dnn.n_features)
        preactivation_squared_sum = torch.zeros(dnn.n_features)
        if cuda:
            preactivation_sum = preactivation_sum.cuda()
            preactivation_squared_sum = preactivation_squared_sum.cuda()
        total_num = 0
        for x in dl:
            if with_y:
                x, _ = x
            if cuda:
                x = x.cuda()

            dnn.forward(x)
            preactivations = dnn.compute_preactivation(x)
            if len(preactivations.shape) == 4:
                preactivations = preactivations.transpose(1, 3)
            preactivations = preactivations.contiguous().view(-1, dnn.n_features)
            total_num += preactivations.shape[0]
        
            preactivation_sum += torch.sum(preactivations, dim=0)
            preactivation_squared_sum += torch.sum(preactivations**2, dim=0)  
            
        preactivation_mean = preactivation_sum / total_num
        preactivation_var = preactivation_squared_sum / total_num 
        preactivation_var = preactivation_var - preactivation_mean**2

        swa = contriboptim.SWA(optim.SGD(dnn.parameters(), lr=1e-3))
        swa.bn_update(dl, dnn, cuda=cuda)
        self.assertEqual(preactivation_mean, dnn.bn.running_mean)
        self.assertEqual(preactivation_var, dnn.bn.running_var, prec=1e-1)
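`_test_bn_update` expects the `dnn` argument to expose an `n_features` attribute, a BatchNorm layer named `bn`, and a `compute_preactivation` method returning the values fed into that layer. A minimal module satisfying that interface might look like this; the layer sizes and the fully connected architecture are illustrative assumptions:

from torch import nn

class BNTestDNN(nn.Module):
    # Hypothetical network matching the interface assumed by _test_bn_update.
    def __init__(self, in_features=5, n_features=10):
        super().__init__()
        self.n_features = n_features
        self.fc1 = nn.Linear(in_features, n_features)
        self.bn = nn.BatchNorm1d(n_features)
        self.fc2 = nn.Linear(n_features, 1)

    def compute_preactivation(self, x):
        # The pre-activations are the inputs to the BatchNorm layer.
        return self.fc1(x)

    def forward(self, x):
        return self.fc2(self.bn(self.fc1(x)))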
Example #8
def lbfgs_constructor(params):
    lbfgs = optim.LBFGS(params, lr=5e-2, max_iter=5)
    return contriboptim.SWA(lbfgs,
                            swa_start=1000,
                            swa_freq=1,
                            swa_lr=1e-3)
Example #9
def asgd_constructor(params):
    asgd = optim.ASGD(params, lr=1e-3)
    return contriboptim.SWA(asgd,
                            swa_start=1000,
                            swa_freq=1,
                            swa_lr=1e-3)
Example #10
def rprop_constructor(params):
    rprop = optim.Rprop(params, lr=1e-2)
    return contriboptim.SWA(rprop,
                            swa_start=1000,
                            swa_freq=1,
                            swa_lr=1e-3)
Example #11
def adamax_constructor(params):
    adamax = optim.Adamax(params, lr=1e-1)
    return contriboptim.SWA(adamax,
                            swa_start=1000,
                            swa_freq=1,
                            swa_lr=1e-2)
Example #12
def adagrad_constructor(params):
    adagrad = optim.Adagrad(params, lr=1e-1)
    return contriboptim.SWA(adagrad,
                            swa_start=1000,
                            swa_freq=1,
                            swa_lr=1e-2)
Example #13
def adadelta_constructor(params):
    adadelta = optim.Adadelta(params)
    return contriboptim.SWA(adadelta, swa_start=1000, swa_freq=1)
Example #14
def adam_constructor(params):
    adam = optim.Adam(params, lr=1e-2)
    return contriboptim.SWA(adam,
                            swa_start=1000,
                            swa_freq=1,
                            swa_lr=1e-2)
Example #15
def sgd_momentum_constructor(params):
    sgd = optim.SGD(params, lr=1e-3, momentum=0.9, weight_decay=1e-4)
    return contriboptim.SWA(sgd,
                            swa_start=1000,
                            swa_freq=1,
                            swa_lr=1e-3)
Example #16
def sgd_manual_constructor(params):
    sgd = optim.SGD(params, lr=1e-3)
    return contriboptim.SWA(sgd)
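Examples #8 through #16 are constructor callables rather than self-contained tests: with swa_start=1000 the averaging never triggers during a short run, so each wrapped optimizer should behave like its base optimizer. A sketch of how a generic harness might exercise one of them; the quadratic objective and step count are assumptions, and optim.LBFGS would additionally require a closure passed to step():

import torch

def smoke_test(constructor, steps=50):
    # Hypothetical harness: minimize a simple quadratic with the wrapped
    # optimizer and check that the objective decreases.
    p = torch.randn(10, requires_grad=True)
    opt = constructor([p])
    start = (p ** 2).sum().item()
    for _ in range(steps):
        opt.zero_grad()
        loss = (p ** 2).sum()
        loss.backward()
        opt.step()
    assert (p ** 2).sum().item() < start

smoke_test(sgd_manual_constructor)
smoke_test(adam_constructor)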