# Example 1
# 0
def test_update_no_improvement_count_errors(epoch_error, expected):
    """Check the no-improvement counter after one reported epoch error.

    A model built with ``tol=0.5`` is seeded with a counter of 5, a best
    error of 1 and an empty error history. After
    ``_update_no_improvement_count_errors(epoch_error)`` runs, the counter
    must equal ``expected``.

    ``epoch_error`` and ``expected`` are presumably supplied by a
    ``pytest.mark.parametrize`` decorator outside this view.
    """
    model = TorchModelBase(tol=0.5)
    # Fix the starting state so the expected counter value is deterministic.
    model.no_improvement_count = 5
    model.best_error = 1
    model.errors = []
    model._update_no_improvement_count_errors(epoch_error)
    observed = model.no_improvement_count
    assert observed == expected
# Example 2
# 0
def test_build_validation_split(arg_count):
    """Check that ``_build_validation_split`` splits every input row-wise.

    ``arg_count`` arrays of ones, each of shape ``(10, 2)``, are split with a
    validation fraction of 0.2. Both returned sides must hold one array per
    input. Each training array must have 8 rows and each dev array 2 rows,
    and all arrays keep 2 columns.

    ``arg_count`` is presumably supplied by a ``pytest.mark.parametrize``
    decorator outside this view.
    """
    n_examples, n_features = 10, 2
    validation_fraction = 0.2
    dev_rows = int(n_examples * validation_fraction)
    train_rows = n_examples - dev_rows
    inputs = [np.ones((n_examples, n_features)) for _ in range(arg_count)]
    train, dev = TorchModelBase._build_validation_split(
        *inputs, validation_fraction=validation_fraction)
    # One output array per input array on each side of the split.
    assert len(train) == arg_count
    assert len(dev) == arg_count
    # Rows are divided between the sides; the column count is unchanged.
    for split, rows in ((train, train_rows), (dev, dev_rows)):
        assert all(arr.shape == (rows, n_features) for arr in split)
# Example 3
# 0
def test_no_setting_of_missing_param():
    """``set_params`` must raise ``ValueError`` for an unknown parameter name."""
    model = TorchModelBase(amsgrad=0.5)
    # The name is passed via a dict so it reads clearly as a non-identifier
    # "bogus" key rather than a real keyword argument.
    bogus = {'NON_EXISTENT_PARAM': False}
    with pytest.raises(ValueError):
        model.set_params(**bogus)
# Example 4
# 0
def test_parameter_setting(param, expected):
    """``set_params`` must store the given value on the attribute named *param*.

    ``param`` and ``expected`` are presumably supplied by a
    ``pytest.mark.parametrize`` decorator outside this view.
    """
    model = TorchModelBase(amsgrad=0.5)
    updates = {param: expected}
    model.set_params(**updates)
    assert getattr(model, param) == expected