def test_update_no_improvement_count_errors(epoch_error, expected):
    """After one call to `_update_no_improvement_count_errors` with
    `epoch_error`, the model's `no_improvement_count` should equal
    `expected` (parametrized by the caller)."""
    model = TorchModelBase(tol=0.5)
    # Seed the early-stopping bookkeeping the method reads and updates.
    model.no_improvement_count = 5
    model.best_error = 1
    model.errors = []
    model._update_no_improvement_count_errors(epoch_error)
    assert model.no_improvement_count == expected
def test_build_validation_split(arg_count):
    """Splitting `arg_count` aligned arrays should produce train/dev
    groups with one entry per input array, each sized according to
    `validation_fraction`."""
    num_examples, num_features = 10, 2
    dev_fraction = 0.2
    dev_size = int(num_examples * dev_fraction)
    train_size = num_examples - dev_size
    # One identical (10, 2) array per positional argument.
    datasets = [np.ones((num_examples, num_features)) for _ in range(arg_count)]
    train, dev = TorchModelBase._build_validation_split(
        *datasets, validation_fraction=dev_fraction)
    assert len(train) == arg_count
    assert len(dev) == arg_count
    for split in train:
        assert split.shape == (train_size, num_features)
    for split in dev:
        assert split.shape == (dev_size, num_features)
def test_no_setting_of_missing_param():
    """`set_params` must raise ValueError when given a parameter name
    the model does not define."""
    model = TorchModelBase(amsgrad=0.5)
    with pytest.raises(ValueError):
        model.set_params(NON_EXISTENT_PARAM=False)
def test_parameter_setting(param, expected):
    """`set_params` followed by attribute access should round-trip the
    value for each (param, expected) pair supplied by parametrization."""
    model = TorchModelBase(amsgrad=0.5)
    model.set_params(**{param: expected})
    assert getattr(model, param) == expected