Example 1
    def test_swa_raises(self):
        # Tests that SWA raises errors for wrong parameter values

        x, y, loss_fun, opt = self._define_vars_loss_opt()

        with self.assertRaisesRegex(ValueError, "Invalid SWA learning rate: -0.0001"):
            opt = StochasticWeightAveraging(opt, swa_start=1, swa_freq=2, swa_lr=-1e-4)

        with self.assertRaisesRegex(ValueError, "Invalid swa_freq: 0"):
            opt = StochasticWeightAveraging(opt, swa_start=1, swa_freq=0, swa_lr=1e-4)

        with self.assertRaisesRegex(ValueError, "Invalid swa_start: -1"):
            opt = StochasticWeightAveraging(opt, swa_start=-1, swa_freq=0, swa_lr=1e-4)
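Every test in this listing calls a `_define_vars_loss_opt` fixture that is not part of the excerpt. Judging by how its return values are used (two tensor parameters, a two-argument differentiable loss, and an SGD-with-momentum optimizer whose second param group holds `y`), it plausibly looks like the following; treat this as a hypothetical reconstruction, not the actual fixture.

import torch
from torch import optim

def _define_vars_loss_opt(self):
    # Hypothetical reconstruction of the unshown fixture: two scalar
    # parameters, a simple convex loss, and SGD with one group per parameter.
    x = torch.tensor([5.0], requires_grad=True)
    y = torch.tensor([-3.0], requires_grad=True)

    def loss_fun(a, b):
        return (a - 1.0) ** 2 + (b + 1.0) ** 2

    opt = optim.SGD([{"params": [x]}, {"params": [y]}], lr=1e-3, momentum=0.9)
    return x, y, loss_fun, opt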
Example 2
    def test_swa_manual(self):
        # Tests SWA in manual mode: values of x and y after opt.finalize()
        # should be equal to the manually computed averages
        x, y, loss_fun, opt = self._define_vars_loss_opt()
        opt = StochasticWeightAveraging(opt)
        swa_start = 5
        swa_freq = 2

        x_sum = torch.zeros_like(x)
        y_sum = torch.zeros_like(y)
        n_avg = 0
        for i in range(1, 11):
            opt.zero_grad()
            loss = loss_fun(x, y)
            loss.backward()
            opt.step()
            n_avg, x_sum, y_sum = self._update_test_vars(
                i,
                swa_freq,
                swa_start,
                n_avg,
                x_sum,
                y_sum,
                x,
                y,
                upd_fun=opt.update_swa,
            )

        opt.finalize()
        x_avg = x_sum / n_avg
        y_avg = y_sum / n_avg
        assert equal(x_avg, x)
        assert equal(y_avg, y)
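`_update_test_vars` is likewise referenced but never shown. From its call sites here and in `test_swa_manual_group`, and from the inline equivalent in `test_swa_auto_group_added_during_run` below, it presumably fires the supplied update function and accumulates parameter sums every `swa_freq` steps once `swa_start` has passed. A sketch under those assumptions:

def _update_test_vars(self, i, swa_freq, swa_start, n_avg, x_sum, y_sum, x, y,
                      upd_fun):
    # Hypothetical reconstruction: mirror the SWA schedule by hand so the
    # test can compare the optimizer's averages with manual running sums.
    if i % swa_freq == 0 and i > swa_start:
        upd_fun()  # e.g. opt.update_swa, or a per-group update
        n_avg += 1
        x_sum = x_sum + x.data  # also works with the 0 placeholder passed in test_swa_manual_group
        y_sum = y_sum + y.data
    return n_avg, x_sum, y_sum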
Example 3
    def _test_bn_update(self, data_tensor, dnn, device, label_tensor=None):
        class DatasetFromTensors(data.Dataset):
            def __init__(self, X, y=None):
                self.X = X
                self.y = y
                self.N = self.X.shape[0]

            def __getitem__(self, index):
                x = self.X[index]
                if self.y is None:
                    return x
                else:
                    y = self.y[index]
                    return x, y

            def __len__(self):
                return self.N

        with_y = label_tensor is not None
        ds = DatasetFromTensors(data_tensor, y=label_tensor)
        dl = data.DataLoader(ds, batch_size=5, shuffle=True)

        preactivation_sum = torch.zeros(dnn.n_features, device=device)
        preactivation_squared_sum = torch.zeros(dnn.n_features, device=device)
        total_num = 0
        for x in dl:
            if with_y:
                x, _ = x
            x = x.to(device)

            dnn(x)  # forward pass; the output itself is not used here
            preactivations = dnn.compute_preactivation(x)
            if len(preactivations.shape) == 4:
                # conv output: move channels to the last dim so the reshape
                # below groups values per channel
                preactivations = preactivations.transpose(1, 3)
            preactivations = preactivations.reshape(-1, dnn.n_features)
            total_num += preactivations.shape[0]

            preactivation_sum += torch.sum(preactivations, dim=0)
            preactivation_squared_sum += torch.sum(preactivations**2, dim=0)

        # population statistics over all batches: E[x] and E[x^2] - E[x]^2
        preactivation_mean = preactivation_sum / total_num
        preactivation_var = preactivation_squared_sum / total_num
        preactivation_var = preactivation_var - preactivation_mean**2

        swa = StochasticWeightAveraging(optim.SGD(dnn.parameters(), lr=1e-3))
        swa.bn_update(dl, dnn, device=device)
        assert equal(preactivation_mean, dnn.bn.running_mean)
        assert equal(preactivation_var, dnn.bn.running_var, prec=1e-1)
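The test relies on a particular `dnn` interface: an `n_features` attribute, a BatchNorm layer exposed as `dnn.bn`, and a `compute_preactivation` method returning whatever tensor feeds that layer. A minimal module satisfying this contract might look as follows (hypothetical, for orientation only; the sizes are arbitrary):

import torch.nn as nn

class BNNet(nn.Module):
    # Hypothetical module matching the interface _test_bn_update expects.
    def __init__(self, in_features=10, n_features=4):
        super().__init__()
        self.n_features = n_features
        self.fc = nn.Linear(in_features, n_features)
        self.bn = nn.BatchNorm1d(n_features)

    def compute_preactivation(self, x):
        return self.fc(x)  # the tensor that feeds the BatchNorm layer

    def forward(self, x):
        return self.bn(self.fc(x))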
Example 4
    def test_swa_manual_group(self):
        # Tests SWA in manual mode with only the y param group updated:
        # the value of x should not change after opt.finalize(), and y should
        # equal the manually computed average
        x, y, loss_fun, opt = self._define_vars_loss_opt()
        opt = StochasticWeightAveraging(opt)
        swa_start = 5
        swa_freq = 2

        y_sum = torch.zeros_like(y)
        n_avg = 0
        for i in range(1, 11):
            opt.zero_grad()
            loss = loss_fun(x, y)
            loss.backward()
            opt.step()
            n_avg, _, y_sum = self._update_test_vars(
                i,
                swa_freq,
                swa_start,
                n_avg,
                0,
                y_sum,
                x,
                y,
                upd_fun=lambda: opt.update_swa_group(opt.param_groups[1]),
            )

        x_before_swap = x.data.clone()

        with self.assertWarnsRegex(
            re.escape(r"SWA wasn't applied to param {}".format(x))
        ):
            opt.finalize()

        y_avg = y_sum / n_avg
        assert equal(y_avg, y)
        assert equal(x_before_swap, x)
Example 5
    def test_swa_auto_mode_detection(self):
        # Tests that SWA mode (auto or manual) is chosen correctly based on
        # parameters provided

        # Auto mode
        x, y, loss_fun, base_opt = self._define_vars_loss_opt()
        swa_start = 5
        swa_freq = 2
        swa_lr = 0.001

        opt = StochasticWeightAveraging(base_opt,
                                        swa_start=swa_start,
                                        swa_freq=swa_freq,
                                        swa_lr=swa_lr)
        assert equal(opt._auto_mode, True)

        opt = StochasticWeightAveraging(base_opt,
                                        swa_start=swa_start,
                                        swa_freq=swa_freq)
        assert equal(opt._auto_mode, True)

        with self.assertWarnsRegex("Some of swa_start, swa_freq is None"):
            opt = StochasticWeightAveraging(base_opt,
                                            swa_start=swa_start,
                                            swa_lr=swa_lr)
            assert equal(opt._auto_mode, False)

        with self.assertWarnsRegex("Some of swa_start, swa_freq is None"):
            opt = StochasticWeightAveraging(base_opt,
                                            swa_freq=swa_freq,
                                            swa_lr=swa_lr)
            assert equal(opt._auto_mode, False)

        with self.assertWarnsRegex("Some of swa_start, swa_freq is None"):
            opt = StochasticWeightAveraging(base_opt, swa_start=swa_start)
            assert equal(opt._auto_mode, False)

        with self.assertWarnsRegex("Some of swa_start, swa_freq is None"):
            opt = StochasticWeightAveraging(base_opt, swa_freq=swa_freq)
            assert equal(opt._auto_mode, False)

        with self.assertWarnsRegex("Some of swa_start, swa_freq is None"):
            opt = StochasticWeightAveraging(base_opt, swa_lr=swa_lr)
            assert equal(opt._auto_mode, False)
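The rule these assertions pin down: auto mode requires both `swa_start` and `swa_freq`; if either is missing, the wrapper warns and falls back to manual mode, and `swa_lr` on its own never enables auto mode. A condensed illustration (assuming `StochasticWeightAveraging` is imported from wherever this codebase defines it):

import torch
from torch import optim

p = torch.randn(3, requires_grad=True)
base = optim.SGD([p], lr=1e-3)

auto = StochasticWeightAveraging(base, swa_start=10, swa_freq=5)  # auto mode
manual = StochasticWeightAveraging(base, swa_lr=1e-2)  # warns; manual mode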
Example 6
def lamb_constructor(params):
    return StochasticWeightAveraging(Lamb(params, weight_decay=0.01),
                                     swa_start=1000,
                                     swa_freq=1,
                                     swa_lr=1e-2)
Example 7
    def test_swa_lr(self):
        # Tests SWA learning rate: in auto mode after swa_start steps the
        # learning rate should be changed to swa_lr; in manual mode swa_lr
        # must be ignored

        # Auto mode
        x, y, loss_fun, opt = self._define_vars_loss_opt()
        swa_start = 5
        swa_freq = 2
        initial_lr = opt.param_groups[0]["lr"]
        swa_lr = initial_lr * 0.1
        opt = StochasticWeightAveraging(opt,
                                        swa_start=swa_start,
                                        swa_freq=swa_freq,
                                        swa_lr=swa_lr)

        for i in range(1, 11):
            opt.zero_grad()
            loss = loss_fun(x, y)
            loss.backward()
            opt.step()
            lr = opt.param_groups[0]["lr"]
            if i > swa_start:
                assert equal(lr, swa_lr)
            else:
                assert equal(lr, initial_lr)

        # Manual mode
        x, y, loss_fun, opt = self._define_vars_loss_opt()
        initial_lr = opt.param_groups[0]["lr"]
        swa_lr = initial_lr * 0.1
        with self.assertWarnsRegex("Some of swa_start, swa_freq is None"):
            opt = StochasticWeightAveraging(opt, swa_lr=swa_lr)

        for _ in range(1, 11):
            opt.zero_grad()
            loss = loss_fun(x, y)
            loss.backward()
            opt.step()
            lr = opt.param_groups[0]["lr"]
            assert equal(lr, initial_lr)
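Taken together with `test_swa_manual`, this pins down the intended manual-mode workflow: the caller decides when to average by calling `update_swa()`, and `finalize()` swaps the averaged weights into the parameters at the end. Roughly (a hedged sketch using only methods exercised by these tests; the loss and schedule are placeholders):

import torch
from torch import optim

x = torch.tensor([5.0], requires_grad=True)
opt = StochasticWeightAveraging(optim.SGD([x], lr=1e-3))  # manual mode
swa_start, swa_freq = 5, 2

for step in range(1, 11):
    opt.zero_grad()
    ((x - 1.0) ** 2).sum().backward()
    opt.step()
    if step > swa_start and step % swa_freq == 0:
        opt.update_swa()  # fold the current weights into the running average
opt.finalize()  # replace the parameters with the SWA average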
Example 8
    def test_swa_auto_group_added_during_run(self):
        # Tests SWA in auto mode with a second param group added after several
        # optimization steps. The expected behavior is that averaging for the
        # second group starts swa_start steps after it is added. For the first
        # group, averaging should start swa_start steps after the optimizer's
        # first step.

        x, y, loss_fun, _ = self._define_vars_loss_opt()
        opt = optim.SGD([x], lr=1e-3, momentum=0.9)
        swa_start = 5
        swa_freq = 2
        opt = StochasticWeightAveraging(opt,
                                        swa_start=swa_start,
                                        swa_freq=swa_freq,
                                        swa_lr=0.001)

        x_sum = torch.zeros_like(x)
        y_sum = torch.zeros_like(y)
        x_n_avg = 0
        y_n_avg = 0
        x_step = 0
        for i in range(1, 11):
            opt.zero_grad()
            loss = loss_fun(x, y)
            loss.backward()
            opt.step()
            x_step += 1
            if i % swa_freq == 0 and i > swa_start:
                x_n_avg += 1
                x_sum += x.data

        opt.add_param_group({"params": y, "lr": 1e-4})

        for y_step in range(1, 11):
            opt.zero_grad()
            loss = loss_fun(x, y)
            loss.backward()
            opt.step()
            x_step += 1
            if y_step % swa_freq == 0 and y_step > swa_start:
                y_n_avg += 1
                y_sum += y.data
            if x_step % swa_freq == 0 and x_step > swa_start:
                x_n_avg += 1
                x_sum += x.data

        opt.finalize()
        x_avg = x_sum / x_n_avg
        y_avg = y_sum / y_n_avg
        assert equal(x_avg, x)
        assert equal(y_avg, y)
Example 9
def lbfgs_constructor(params):
    lbfgs = optim.LBFGS(params, lr=5e-2, max_iter=5)
    return StochasticWeightAveraging(lbfgs,
                                     swa_start=1000,
                                     swa_freq=1,
                                     swa_lr=1e-3)
Example 10
def asgd_constructor(params):
    asgd = optim.ASGD(params, lr=1e-3)
    return StochasticWeightAveraging(asgd,
                                     swa_start=1000,
                                     swa_freq=1,
                                     swa_lr=1e-3)
Example 11
def rprop_constructor(params):
    rprop = optim.Rprop(params, lr=1e-2)
    return StochasticWeightAveraging(rprop,
                                     swa_start=1000,
                                     swa_freq=1,
                                     swa_lr=1e-3)
Example 12
def adamax_constructor(params):
    adamax = optim.Adamax(params, lr=1e-1)
    return StochasticWeightAveraging(adamax,
                                     swa_start=1000,
                                     swa_freq=1,
                                     swa_lr=1e-2)
Example 13
def adadelta_constructor(params):
    adadelta = optim.Adadelta(params)
    return StochasticWeightAveraging(adadelta,
                                     swa_start=1000,
                                     swa_freq=1)
Example 14
def sgd_momentum_constructor(params):
    sgd = optim.SGD(params, lr=1e-3, momentum=0.9, weight_decay=1e-4)
    return StochasticWeightAveraging(sgd,
                                     swa_start=1000,
                                     swa_freq=1,
                                     swa_lr=1e-3)
Example 15
def sgd_manual_constructor(params):
    sgd = optim.SGD(params, lr=1e-3)
    return StochasticWeightAveraging(sgd)
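All of these `*_constructor` helpers share one shape: take `params`, build a base optimizer, and wrap it in `StochasticWeightAveraging` (they assume `optim`, `Lamb`, and the wrapper itself are already imported). They are presumably handed to a generic optimizer test harness; a typical call would look something like this (hypothetical usage, not part of the test suite):

import torch

model = torch.nn.Linear(10, 2)
opt = sgd_manual_constructor(model.parameters())

# The wrapper behaves like an ordinary optimizer:
opt.zero_grad()
model(torch.randn(4, 10)).sum().backward()
opt.step()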