Exemplo n.º 1
0
    def test_parse_args_help(self, parse_args, estimator):
        """Requesting help prints it and exits without applying any params.

        ``print_help`` and ``sys.exit`` inside ``skorch.cli`` are patched so
        the test process keeps running. An extra ``'foo'`` kwarg is passed
        alongside ``help``; even so, ``estimator.set_params`` must never be
        called.
        """
        cli_kwargs = {'help': True, 'foo': 'bar'}
        with patch('skorch.cli.sys.exit') as mock_exit, \
                patch('skorch.cli.print_help') as mock_help:
            apply_params = parse_args(cli_kwargs)
            apply_params(estimator)

        # Neither defaults nor kwargs are applied when help is requested.
        estimator.set_params.assert_not_called()
        mock_help.assert_called_once()
        mock_exit.assert_called_once()
Exemplo n.º 2
0
    def test_parse_args_net_custom_defaults(self, parse_args, net):
        """Command-line kwargs take precedence over defaults on a net.

        ``batch_size`` is supplied both as a default and as a kwarg, so the
        kwarg value must win. ``module__hidden_units`` comes only from the
        defaults and ``module__nonlin`` only from the kwargs; both are
        checked on the net's ``module_`` attribute.
        """
        default_params = {'batch_size': 256, 'module__hidden_units': 55}
        cli_params = {
            'batch_size': 123,
            'module__nonlin': nn.Hardtanh(1, 2),
        }
        configured = parse_args(cli_params, default_params)(net)

        assert configured.batch_size == 123  # kwarg beats default (256)
        module = configured.module_
        assert module.hidden_units == 55  # taken from defaults only
        nonlin = module.nonlin
        assert isinstance(nonlin, nn.Hardtanh)
        assert nonlin.min_val == 1
        assert nonlin.max_val == 2
Exemplo n.º 3
0
    def test_parse_args_sklearn_pipe_custom_defaults(self, parse_args, pipe_sklearn):
        """Kwargs override defaults for nested params of an sklearn pipeline.

        Distinct sentinel numbers are used so each assertion reveals whether
        the default or the kwarg value ended up on the estimator.
        """
        default_params = {
            'features__scale__copy': 123,
            'clf__fit_intercept': 456,
        }
        cli_params = {
            'features__scale__copy': 555,
            'clf__normalize': 789,
        }
        configured = parse_args(cli_params, default_params)(pipe_sklearn)

        # The first step exposes ``transformer_list`` (presumably a
        # FeatureUnion) whose first entry is the scaler; the last step is
        # the final estimator.
        feature_step = configured.steps[0][1]
        scaler = feature_step.transformer_list[0][1]
        clf = configured.steps[-1][1]

        assert scaler.copy == 555  # kwarg beats default (123)
        assert clf.fit_intercept == 456  # taken from defaults only
        # NOTE(review): if ``clf`` is scikit-learn's LinearRegression, its
        # ``normalize`` parameter was removed in scikit-learn 1.2 and
        # set_params would reject it -- confirm against the fixture.
        assert clf.normalize == 789  # taken from kwargs only
Exemplo n.º 4
0
    def test_parse_args_pipe_custom_defaults(self, parse_args, pipe):
        """Kwargs override defaults for net params nested inside a pipeline.

        ``net__``-prefixed keys target the pipeline's last step. Module
        parameters are checked through the ``module__*`` attributes read
        directly off that net.
        """
        default_params = {
            'net__batch_size': 256,
            'net__module__hidden_units': 55,
        }
        cli_params = {
            'net__batch_size': 123,
            'net__module__nonlin': nn.Hardtanh(1, 2),
        }
        configured = parse_args(cli_params, default_params)(pipe)
        net = configured.steps[-1][1]

        assert net.batch_size == 123  # kwarg beats default (256)
        assert net.module__hidden_units == 55  # taken from defaults only
        nonlin = net.module__nonlin
        assert isinstance(nonlin, nn.Hardtanh)
        assert nonlin.min_val == 1
        assert nonlin.max_val == 2
Exemplo n.º 5
0
    def test_parse_args_run(self, parse_args, estimator):
        """Without ``help``, params are applied and nothing is printed/exited.

        ``estimator.set_params`` is expected to be called exactly twice:
        first with the defaults (none are given here), then with the kwargs.
        The ``'math.cos'`` string is compared against ``cos`` (presumably
        ``math.cos`` imported at module level -- verify).
        """
        cli_kwargs = {'foo': 'bar', 'baz': 'math.cos'}
        with patch('skorch.cli.sys.exit') as mock_exit, \
                patch('skorch.cli.print_help') as mock_help:
            parse_args(cli_kwargs)(estimator)

        set_params = estimator.set_params
        assert set_params.call_count == 2  # defaults first, then kwargs

        # Safe to unpack: exactly two calls were just asserted. Index 1 of a
        # recorded call is its keyword-argument dict.
        first_call, second_call = set_params.call_args_list
        assert not first_call[1]  # no defaults specified

        applied = second_call[1]
        assert applied['foo'] == 'bar'
        assert applied['baz'] == cos

        mock_help.assert_not_called()
        mock_exit.assert_not_called()