Example #1
    def test_scheduler(self):
        config = self._get_valid_config()

        # Check as warmup
        scheduler = LinearParamScheduler.from_config(config)
        schedule = [
            round(scheduler(epoch_num / self._num_epochs), 4)
            for epoch_num in range(self._num_epochs)
        ]
        expected_schedule = [config["start_value"]] + self._get_valid_intermediate()
        self.assertEqual(schedule, expected_schedule)

        # Check as decay
        config["start_value"], config["end_value"] = (
            config["end_value"],
            config["start_value"],
        )
        scheduler = LinearParamScheduler.from_config(config)
        schedule = [
            round(scheduler(epoch_num / self._num_epochs), 4)
            for epoch_num in range(self._num_epochs)
        ]
        expected_schedule = [config["start_value"]] + list(
            reversed(self._get_valid_intermediate()))
        self.assertEqual(schedule, expected_schedule)
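
The fixtures used above (_get_valid_config, _get_valid_intermediate, _num_epochs) are defined elsewhere in the test class. A minimal sketch of what they could look like, with values chosen only so the warmup/decay assertions hold (the exact numbers are an assumption, not taken from the source):

    _num_epochs = 10

    def _get_valid_config(self):
        # Hypothetical config; start_value/end_value are the keys the scheduler expects
        return {"name": "linear", "start_value": 0.0, "end_value": 0.1}

    def _get_valid_intermediate(self):
        # Values a linear ramp from 0.0 to 0.1 passes through at epochs 1..9
        return [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09]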
Example #2
    def test_invalid_config(self):
        config = self._get_valid_config()

        bad_config = copy.deepcopy(config)
        # Missing start_value
        del bad_config["start_value"]
        with self.assertRaises((AssertionError, TypeError)):
            LinearParamScheduler.from_config(bad_config)

        # Missing end_value
        bad_config["start_value"] = config["start_value"]
        del bad_config["end_value"]
        with self.assertRaises((AssertionError, TypeError)):
            LinearParamScheduler.from_config(bad_config)
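
For reference, from_config should be interchangeable with calling the constructor directly. A quick sanity check, assuming the config keys map one-to-one onto the start_value/end_value constructor arguments and that extra keys such as "name" are ignored:

    config = {"name": "linear", "start_value": 0.0, "end_value": 0.1}
    scheduler_a = LinearParamScheduler.from_config(config)
    scheduler_b = LinearParamScheduler(start_value=0.0, end_value=0.1)
    # Both should report the same value halfway through training
    assert round(scheduler_a(0.5), 4) == round(scheduler_b(0.5), 4) == 0.05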
Example #3
    def test_batchnorm_weight_decay(self):
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = nn.Linear(2, 3)
                self.relu = nn.ReLU()
                self.bn = nn.BatchNorm1d(3)

            def forward(self, x):
                return self.bn(self.relu(self.lin(x)))

        torch.manual_seed(1)
        model = MyModel()

        opt = build_optimizer(self._get_config())
        bn_params, lin_params = split_batchnorm_params(model)

        lin_param_before = model.lin.weight.detach().clone()
        bn_param_before = model.bn.weight.detach().clone()

        with torch.enable_grad():
            x = torch.tensor([[1.0, 1.0], [1.0, 2.0]])
            out = model(x).pow(2).sum()
            out.backward()

        opt.set_param_groups([
            {
                "params": lin_params,
                "lr": LinearParamScheduler(1, 2),
                "weight_decay": 0.5,
            },
            {
                "params": bn_params,
                "lr": 0,
                "weight_decay": 0
            },
        ])

        opt.step(where=0.5)

        # Make sure the linear parameters are trained but not the batch norm
        self.assertFalse(torch.allclose(model.lin.weight, lin_param_before))
        self.assertTrue(torch.allclose(model.bn.weight, bn_param_before))

        opt.step(where=0.5)

        # Same, but after another step and triggering the lr scheduler
        self.assertFalse(torch.allclose(model.lin.weight, lin_param_before))
        self.assertTrue(torch.allclose(model.bn.weight, bn_param_before))
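
split_batchnorm_params is a library helper that partitions the model's parameters into batch-norm parameters and everything else, so each group can get its own lr and weight decay. A simplified sketch of the idea (illustrative only, not the library's implementation, which may cover more norm variants):

    import torch.nn as nn

    def split_batchnorm_params_sketch(model):
        bn_params, other_params = [], []
        for module in model.modules():
            # recurse=False so each parameter is assigned exactly once
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                bn_params.extend(module.parameters(recurse=False))
            else:
                other_params.extend(module.parameters(recurse=False))
        return bn_params, other_params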
Example #4
    def test_get_lr(self):
        opt = build_optimizer(self._get_config())
        param = torch.tensor([1.0], requires_grad=True)
        opt.set_param_groups([{"params": [param], "lr": 1}])

        self.assertEqual(opt.options_view.lr, 1)

        # Case two: verify LR changes
        opt = build_optimizer(self._get_config())
        param = torch.tensor([1.0], requires_grad=True)
        opt.set_param_groups([{"params": [param], "lr": LinearParamScheduler(1, 2)}])

        self.assertAlmostEqual(opt.options_view.lr, 1)
        opt.step(where=0.5)
        self.assertAlmostEqual(opt.options_view.lr, 1.5)
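
The asserted learning rates follow directly from the scheduler's linear interpolation between start_value and end_value at the given training fraction; a sketch of the formula (not the library's code):

    def linear_value(start_value, end_value, where):
        # "where" is the fraction of training completed, in [0, 1)
        return start_value + where * (end_value - start_value)

    assert linear_value(1, 2, 0.0) == 1.0  # lr right after set_param_groups
    assert linear_value(1, 2, 0.5) == 1.5  # lr after opt.step(where=0.5)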
Example #5
    def test_one(self):
        train_dataset = MyDataset(
            batchsize_per_replica=32,
            shuffle=False,
            transform=GenericImageTransform(transform=transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])),
            num_samples=100,
            crop_size=224,
            class_ratio=0.5,
            seed=0,
        )

        test_dataset = MyDataset(
            batchsize_per_replica=32,
            shuffle=False,
            transform=GenericImageTransform(transform=transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])),
            num_samples=100,
            crop_size=224,
            class_ratio=0.5,
            seed=0,
        )

        model = MyModel()
        loss = MyLoss()

        optimizer = SGD(momentum=0.9, weight_decay=1e-4, nesterov=True)

        task = (
            ClassificationTask()
            .set_model(model)
            .set_dataset(train_dataset, "train")
            .set_dataset(test_dataset, "test")
            .set_loss(loss)
            .set_optimizer(optimizer)
            .set_optimizer_schedulers(
                {"lr": LinearParamScheduler(start_value=0.01, end_value=0.009)}
            )
            .set_num_epochs(1)
        )

        trainer = LocalTrainer()
        trainer.train(task)
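
With this schedule the learning rate decays linearly from 0.01 toward 0.009 over the single epoch; for example, halfway through training it would be 0.0095 (same interpolation as above, shown only to make the chosen endpoints concrete):

    start_value, end_value = 0.01, 0.009
    halfway_lr = start_value + 0.5 * (end_value - start_value)  # 0.0095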
Example #6
    def test_lr_step(self):
        opt = SGD()

        param = torch.tensor([0.0], requires_grad=True)
        opt.set_param_groups([param], lr=LinearParamScheduler(1, 2))

        param.grad = torch.tensor([1.0])

        self.assertAlmostEqual(opt.options_view.lr, 1.0)

        # lr=1, param should go from 0 to -1
        opt.step(where=0)
        self.assertAlmostEqual(opt.options_view.lr, 1.0)

        self.assertAlmostEqual(param.item(), -1.0, delta=1e-5)

        # lr=1.5, param should go from -1 to -1-1.5 = -2.5
        opt.step(where=0.5)
        self.assertAlmostEqual(param.item(), -2.5, delta=1e-5)

        # lr=1.9, param should go from -2.5 to -2.5-1.9 = -4.4
        opt.step(where=0.9)
        self.assertAlmostEqual(param.item(), -4.4, delta=1e-5)
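
The expected parameter values can be reproduced by hand with plain SGD (param -= lr * grad, grad fixed at 1.0), taking the lr that LinearParamScheduler(1, 2) yields at each "where"; a sketch of the arithmetic:

    param_value = 0.0
    for where in (0.0, 0.5, 0.9):
        lr = 1.0 + where * (2.0 - 1.0)  # scheduler value at this point in training
        param_value -= lr * 1.0         # gradient is fixed at 1.0
    # param_value passes through -1.0, then -2.5, and ends at -4.4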