def test_batch_reparametrization_sampler_sample_raises_for_inconsistent_batch_size() -> None:
    """A second ``sample`` call with a different batch size than the first must raise."""
    sampler = BatchReparametrizationSampler(100, QuadraticMeanAndRBFKernel())

    three_points = tf.constant([[0.0], [1.0], [2.0]])
    two_points = tf.constant([[0.0], [1.0]])

    # The first call uses a batch of three points; the test expects a later call
    # with a batch of only two points to be rejected.
    sampler.sample(three_points)
    with pytest.raises(TF_DEBUGGING_ERROR_TYPES):
        sampler.sample(two_points)
def test_batch_reparametrization_sampler_samples_are_distinct_for_new_instances() -> None:
    """Two separately constructed samplers over the same model give different samples."""
    gp = _dim_two_gp()
    first = BatchReparametrizationSampler(100, gp)
    second = BatchReparametrizationSampler(100, gp)

    points = tf.random.uniform([3, 5, 7, 2], dtype=tf.float64)

    # Sample from the second sampler before the first, matching the evaluation
    # order of the expression this replaces.
    second_samples = second.sample(points)
    first_samples = first.sample(points)

    # Every element must differ by more than 1e-9.
    npt.assert_array_less(1e-9, tf.abs(second_samples - first_samples))
def test_batch_reparametrization_sampler_samples_approximate_mean_and_covariance() -> None:
    """Empirical moments of many samples should match ``predict_joint``'s mean and covariance."""
    gp = _dim_two_gp()
    num_samples = 10_000
    batch_shape = [3]
    num_points = 4

    query_points = tf.random.uniform(
        [*batch_shape, num_points, 2], maxval=1.0, dtype=tf.float64
    )
    samples = BatchReparametrizationSampler(num_samples, gp).sample(query_points)
    assert samples.shape == [*batch_shape, num_samples, num_points, 2]

    # Axis -3 of ``samples`` is the sample axis; reduce over it for both moments.
    empirical_mean = tf.reduce_mean(samples, axis=-3)
    empirical_cov = tfp.stats.covariance(samples, sample_axis=-3, event_axis=-2)
    # Reorder axes so the empirical covariance lines up with the layout of the
    # covariance returned by ``predict_joint``.
    empirical_cov = tf.transpose(empirical_cov, [0, 3, 1, 2])

    predicted_mean, predicted_cov = gp.predict_joint(query_points)
    npt.assert_allclose(empirical_mean, predicted_mean, rtol=0.02)
    npt.assert_allclose(empirical_cov, predicted_cov, rtol=0.04)
def test_batch_reparametrization_sampler_sample_raises_for_negative_jitter() -> None:
    """Passing a negative ``jitter`` to ``sample`` must be rejected."""
    model = QuadraticMeanAndRBFKernel()
    sampler = BatchReparametrizationSampler(100, model)
    single_point = tf.constant([[0.0]])
    negative_jitter = -1e-6

    with pytest.raises(TF_DEBUGGING_ERROR_TYPES):
        sampler.sample(single_point, jitter=negative_jitter)
def test_batch_reparametrization_sampler_sample_raises_for_invalid_at_shape(
    at: tf.Tensor,
) -> None:
    """``sample`` must raise when called with query points ``at`` of an invalid shape.

    ``at`` is presumably supplied by a pytest parametrization not visible here.
    """
    model = QuadraticMeanAndRBFKernel()
    sampler = BatchReparametrizationSampler(100, model)

    with pytest.raises(TF_DEBUGGING_ERROR_TYPES):
        sampler.sample(at)
def test_batch_reparametrization_sampler_samples_are_repeatable() -> None:
    """Sampling twice at the same points with one sampler should give matching results."""
    model = _dim_two_gp()
    sampler = BatchReparametrizationSampler(100, model)
    points = tf.random.uniform([3, 5, 7, 2], dtype=tf.float64)

    first_draw = sampler.sample(points)
    second_draw = sampler.sample(points)

    # Compared with ``assert_allclose``'s default tolerances.
    npt.assert_allclose(first_draw, second_draw)
def test_batch_reparametrization_sampler_samples_are_continuous() -> None:
    """Samples at slightly shifted points should stay within a tiny tolerance."""
    model = _dim_two_gp()
    sampler = BatchReparametrizationSampler(100, model)
    points = tf.random.uniform([3, 5, 7, 2], dtype=tf.float64)

    # NOTE(review): ``tf.random.uniform`` draws from [0, 1) here. A shift of 1e-20
    # is far below float64 resolution at that scale (machine epsilon is about
    # 2.2e-16), so ``points + shift`` equals ``points`` except where an element is
    # exactly 0. As written this is close to a repeatability check. A meaningful
    # continuity test needs a representable shift and a tolerance chosen for the
    # model's conditioning — TODO.
    shift = 1e-20
    tolerance = 1e-20

    shifted = sampler.sample(points + shift)
    unshifted = sampler.sample(points)
    npt.assert_array_less(tf.abs(shifted - unshifted), tolerance)
def test_batch_reparametrization_sampler_raises_for_invalid_sample_size(
    sample_size: int,
) -> None:
    """Constructing the sampler with an invalid ``sample_size`` must raise.

    ``sample_size`` is presumably supplied by a pytest parametrization not visible here.
    """
    with pytest.raises(TF_DEBUGGING_ERROR_TYPES):
        # The model is built inside the ``raises`` block, as in the original test.
        model = _dim_two_gp()
        BatchReparametrizationSampler(sample_size, model)