def test_discrete_validate_args(jax_dist, valid_args, invalid_args, invalid_sample):
    """Check that validation errors surface for a discrete distribution.

    Under ``validation_enabled()``, two failures are expected:

    * calling ``jax_dist`` with ``invalid_args`` raises ``ValueError``
      matching ``'Invalid parameters'``;
    * calling ``logpmf`` with ``invalid_sample`` on the object built from
      ``valid_args`` raises ``ValueError`` matching ``'Invalid values'``.
    """
    # Build the expectation contexts up front so the body below reads as a
    # plain sequence of "this must fail" steps.
    rejects_params = pytest.raises(ValueError, match='Invalid parameters')
    rejects_values = pytest.raises(ValueError, match='Invalid values')

    with validation_enabled():
        # Step 1: bad constructor arguments.
        with rejects_params:
            jax_dist(*invalid_args)

        # Step 2: good constructor arguments, bad sample.
        dist = jax_dist(*valid_args)
        with rejects_values:
            dist.logpmf(invalid_sample)
def test_continuous_validate_args(jax_dist, dist_args, sample):
    """Check that validation errors surface for a continuous distribution.

    Placeholder arguments ``1, 2, ..., jax_dist.numargs`` are used wherever a
    successfully constructed distribution is needed. Under
    ``validation_enabled()`` the following must raise ``ValueError``:

    * ``jax_dist(*dist_args)`` (only checked when ``dist_args`` is non-empty),
      matching ``'Invalid parameters'``;
    * construction with ``scale=-1``, matching ``'Invalid scale parameter'``;
    * ``logpdf(sample)``, matching ``'Invalid values'``.
    """
    # One positional argument per declared shape parameter: 1, 2, ..., numargs.
    ok_args = list(range(1, jax_dist.numargs + 1))

    with validation_enabled():
        # An empty ``dist_args`` means there is nothing invalid to pass in,
        # so the parameter check is skipped.
        if dist_args:
            with pytest.raises(ValueError, match='Invalid parameters'):
                jax_dist(*dist_args)

        # A negative scale must be rejected even with acceptable positionals.
        with pytest.raises(ValueError, match='Invalid scale parameter'):
            jax_dist(*ok_args, scale=-1)

        # A properly built distribution must still reject the given sample.
        dist = jax_dist(*ok_args)
        with pytest.raises(ValueError, match='Invalid values'):
            dist.logpdf(sample)