def test_passes_filtered_cgroups(
        self, filtered_optimizer, filter_requires_grad):
    """Check that per-group options survive the filtered optimizer.

    Two param groups are handed to the factory returned by
    ``filtered_optimizer``:

    - the first holds one tensor with ``requires_grad=True``, one tensor
      with ``requires_grad=False``, and its own ``lr``;
    - the second holds a single ``requires_grad=True`` tensor and no ``lr``.

    The assertions expect exactly one tensor to remain in each group. The
    first group should keep ``lr=0.1``. The second should fall back to the
    ``lr=0.2`` default passed when the optimizer is constructed.

    NOTE(review): "cgroups" in the test name looks like a typo for
    "pgroups"; left unchanged so the pytest node id stays stable.
    """
    trainable_first = torch.zeros(1, requires_grad=True)
    frozen_first = torch.zeros(1, requires_grad=False)
    trainable_second = torch.zeros(1, requires_grad=True)
    param_groups = [
        {'params': [trainable_first, frozen_first], 'lr': 0.1},
        {'params': [trainable_second]},
    ]

    optimizer_factory = filtered_optimizer(
        torch.optim.SGD, filter_requires_grad)
    # lr=0.2 is only the default; the first group overrides it with 0.1.
    optimizer = optimizer_factory(param_groups, lr=0.2)
    assert isinstance(optimizer, torch.optim.SGD)

    first_group = optimizer.param_groups[0]
    second_group = optimizer.param_groups[1]
    assert len(list(first_group['params'])) == 1
    assert len(list(second_group['params'])) == 1
    assert first_group['lr'] == 0.1
    assert second_group['lr'] == 0.2
def test_passes_kwargs_to_neuralnet_optimizer(
        self, filtered_optimizer, filter_requires_grad):
    """Check that ``optimizer__*`` kwargs reach a filtered SGD optimizer.

    A module holding a single ``Linear(1, 1)`` layer is wrapped in
    ``NeuralNetClassifier``, with the filtered factory as ``optimizer`` and
    ``optimizer__momentum=0.9``. After ``initialize()``, the net's optimizer
    is expected to be an SGD instance with one param group whose momentum
    is 0.9.

    NOTE(review): the definition that follows this one in the same chunk
    reuses this exact method name. If both live in the same class body,
    the later definition replaces this one and this test never runs.
    Confirm, and remove the stale copy.
    """
    from skorch import NeuralNetClassifier

    class OneLayerModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.dense0 = torch.nn.Linear(1, 1)

        def forward(self, X):
            return self.dense0(X)

    optimizer_factory = filtered_optimizer(
        torch.optim.SGD, filter_requires_grad)
    classifier = NeuralNetClassifier(
        OneLayerModule,
        optimizer=optimizer_factory,
        optimizer__momentum=0.9,
    )
    classifier.initialize()

    optimizer = classifier.optimizer_
    assert isinstance(optimizer, torch.optim.SGD)
    groups = optimizer.param_groups
    assert len(groups) == 1
    assert groups[0]['momentum'] == 0.9
def test_passes_kwargs_to_neuralnet_optimizer(
        self, filtered_optimizer, filter_requires_grad):
    """Check that ``optimizer__*`` kwargs reach a filtered SGD optimizer.

    The test expects building the filtered factory to emit a
    ``DeprecationWarning``. A toy classifier module from ``skorch.toy`` is
    wrapped in ``NeuralNetClassifier`` with that factory and
    ``optimizer__momentum=0.9``. After ``initialize()``, the optimizer is
    expected to be SGD with a single param group whose momentum is 0.9.

    NOTE(review): an earlier definition in this chunk uses the same method
    name. Within one class body, this later definition is the one that
    takes effect.
    """
    from skorch import NeuralNetClassifier
    from skorch.toy import make_classifier

    module_cls = make_classifier(
        input_units=1,
        num_hidden=0,
        output_units=1,
    )

    with pytest.warns(DeprecationWarning):
        optimizer_factory = filtered_optimizer(
            torch.optim.SGD, filter_requires_grad)

    classifier = NeuralNetClassifier(
        module_cls,
        optimizer=optimizer_factory,
        optimizer__momentum=0.9,
    )
    classifier.initialize()

    optimizer = classifier.optimizer_
    assert isinstance(optimizer, torch.optim.SGD)
    groups = optimizer.param_groups
    assert len(groups) == 1
    assert groups[0]['momentum'] == 0.9
def test_pickle(self, filtered_optimizer, filter_requires_grad):
    """Check that the filtered optimizer factory can be pickled.

    The test expects building the factory to emit a ``DeprecationWarning``.
    The only other requirement is that ``pickle.dumps`` does not raise.

    NOTE(review): the definition that follows this one in the same chunk
    reuses the name ``test_pickle``. If both live in the same class body,
    the later definition replaces this one and this test never runs.
    Confirm, and remove the stale copy.
    """
    with pytest.warns(DeprecationWarning):
        factory = filtered_optimizer(torch.optim.SGD, filter_requires_grad)

    # The real check is that serialization does not raise; pickle.dumps
    # always returns bytes, so this assertion cannot fail on its own.
    payload = pickle.dumps(factory)
    assert isinstance(payload, bytes)
def test_pickle(self, filtered_optimizer, filter_requires_grad):
    """Check that the filtered optimizer factory can be pickled.

    The only requirement is that ``pickle.dumps`` does not raise.

    NOTE(review): an earlier ``test_pickle`` in this chunk wraps the factory
    call in ``pytest.warns(DeprecationWarning)``; this copy does not. Within
    one class body, this later definition is the one that takes effect.
    Confirm which version is intended.
    """
    factory = filtered_optimizer(torch.optim.SGD, filter_requires_grad)

    # The real check is that serialization does not raise; pickle.dumps
    # always returns bytes, so this assertion cannot fail on its own.
    payload = pickle.dumps(factory)
    assert isinstance(payload, bytes)