Пример #1
0
    def test_passes_filtered_cgroups(
            self, filtered_optimizer, filter_requires_grad):
        pgroups = [{
            'params': [torch.zeros(1, requires_grad=True),
                       torch.zeros(1, requires_grad=False)],
            'lr': 0.1
        }, {
            'params': [torch.zeros(1, requires_grad=True)]
        }]

        opt = filtered_optimizer(torch.optim.SGD, filter_requires_grad)
        filtered_opt = opt(pgroups, lr=0.2)

        assert isinstance(filtered_opt, torch.optim.SGD)
        assert len(list(filtered_opt.param_groups[0]['params'])) == 1
        assert len(list(filtered_opt.param_groups[1]['params'])) == 1

        assert filtered_opt.param_groups[0]['lr'] == 0.1
        assert filtered_opt.param_groups[1]['lr'] == 0.2
Пример #2
0
    def test_passes_kwargs_to_neuralnet_optimizer(
            self, filtered_optimizer, filter_requires_grad):
        """Extra ``optimizer__*`` kwargs reach the filtered optimizer.

        A ``NeuralNetClassifier`` is built with a filtered SGD factory and
        ``optimizer__momentum=0.9``. After ``initialize()`` it should hold
        an SGD instance with a single param group whose momentum is 0.9.
        """
        from skorch import NeuralNetClassifier

        class MyModule(torch.nn.Module):
            """Minimal module: one ``Linear(1, 1)`` layer."""

            def __init__(self):
                super().__init__()
                self.dense0 = torch.nn.Linear(1, 1)

            def forward(self, X):
                return self.dense0(X)

        sgd_factory = filtered_optimizer(
            torch.optim.SGD, filter_requires_grad)
        net = NeuralNetClassifier(
            MyModule,
            optimizer=sgd_factory,
            optimizer__momentum=0.9,
        )
        net.initialize()

        optimizer = net.optimizer_
        assert isinstance(optimizer, torch.optim.SGD)
        groups = optimizer.param_groups
        assert len(groups) == 1
        assert groups[0]['momentum'] == 0.9
Пример #3
0
    def test_passes_kwargs_to_neuralnet_optimizer(
            self, filtered_optimizer, filter_requires_grad):
        """Extra ``optimizer__*`` kwargs reach the filtered optimizer.

        The module comes from ``skorch.toy.make_classifier``, configured
        with ``input_units=1``, ``num_hidden=0`` and ``output_units=1``.
        Building the filtered factory, constructing the net and
        initializing it must emit a ``DeprecationWarning`` somewhere in
        that span. The resulting optimizer should be an SGD instance with
        one param group whose momentum is 0.9.
        """
        from skorch import NeuralNetClassifier
        from skorch.toy import make_classifier

        module_cls = make_classifier(
            input_units=1, num_hidden=0, output_units=1)

        with pytest.warns(DeprecationWarning):
            sgd_factory = filtered_optimizer(
                torch.optim.SGD, filter_requires_grad)
            net = NeuralNetClassifier(
                module_cls,
                optimizer=sgd_factory,
                optimizer__momentum=0.9,
            )
            net.initialize()

        optimizer = net.optimizer_
        assert isinstance(optimizer, torch.optim.SGD)
        groups = optimizer.param_groups
        assert len(groups) == 1
        assert groups[0]['momentum'] == 0.9
Пример #4
0
 def test_pickle(self, filtered_optimizer, filter_requires_grad):
     """The filtered-optimizer factory must survive ``pickle.dumps``.

     Creating the factory is expected to emit a ``DeprecationWarning``.
     """
     with pytest.warns(DeprecationWarning):
         factory = filtered_optimizer(
             torch.optim.SGD, filter_requires_grad)
     # Any exception from pickling fails the test; the bytes are unused.
     pickle.dumps(factory)
Пример #5
0
 def test_pickle(self, filtered_optimizer, filter_requires_grad):
     opt = filtered_optimizer(torch.optim.SGD, filter_requires_grad)
     # Does not raise
     pickle.dumps(opt)