def test_discrete_validate_args(jax_dist, valid_args, invalid_args, invalid_sample):
    """Argument and sample validation for a discrete distribution.

    While validation is enabled:

    * building ``jax_dist`` from ``invalid_args`` must raise ``ValueError``
      whose message matches ``'Invalid parameters'``;
    * a distribution frozen with ``valid_args`` must raise ``ValueError``
      matching ``'Invalid values'`` when ``logpmf`` is evaluated at
      ``invalid_sample``.

    The arguments are presumably supplied by test parametrization.
    """
    with validation_enabled():
        # Invalid shape parameters must be rejected at construction time.
        with pytest.raises(ValueError, match='Invalid parameters'):
            jax_dist(*invalid_args)

        # With good parameters, construction succeeds; the bad sample is
        # what must be rejected, this time by the log-pmf evaluation.
        dist_instance = jax_dist(*valid_args)
        with pytest.raises(ValueError, match='Invalid values'):
            dist_instance.logpmf(invalid_sample)
def test_continuous_validate_args(jax_dist, dist_args, sample):
    """Argument, scale and sample validation for a continuous distribution.

    A valid set of shape parameters is built as ``1, 2, ..., jax_dist.numargs``.
    Then, while validation is enabled:

    * if ``dist_args`` is truthy, building ``jax_dist`` from it must raise
      ``ValueError`` matching ``'Invalid parameters'``;
    * building with valid shapes but ``scale=-1`` must raise ``ValueError``
      matching ``'Invalid scale parameter'``;
    * a distribution frozen with the valid shapes must raise ``ValueError``
      matching ``'Invalid values'`` when ``logpdf`` is evaluated at ``sample``.
    """
    shape_count = jax_dist.numargs
    # Positive integers 1..numargs; these are treated here as acceptable
    # shape parameters for every distribution under test.
    good_shapes = list(range(1, shape_count + 1))

    with validation_enabled():
        # Some cases may provide no invalid parameter set; skip that check.
        if dist_args:
            with pytest.raises(ValueError, match='Invalid parameters'):
                jax_dist(*dist_args)

        # A negative scale must be rejected even with valid shapes.
        with pytest.raises(ValueError, match='Invalid scale parameter'):
            jax_dist(*good_shapes, scale=-1)

        frozen = jax_dist(*good_shapes)
        with pytest.raises(ValueError, match='Invalid values'):
            frozen.logpdf(sample)