def test_transformation_by_constant(self):
    """A plain constant may appear among the base priors.

    np.maximum(Uniform(0, 5, guess=3), 2) must have guess 3 and every
    sample must lie in [2, 5).
    """
    base_prior = [Uniform(0, 5, guess=3), 2]
    transform = TransformedPrior(np.maximum, base_prior)
    samples = transform.sample(100)
    self.assertEqual(transform.guess, 3)
    # ndarray.all(), as in the other sampling tests, instead of the builtin
    # all(): it is vectorized and does not break on non-1-D arrays.
    self.assertTrue((samples < 5).all())
    self.assertTrue((samples >= 2).all())
def test_sample_single_prior_with_size(self):
    """sample(100) on sqrt(Uniform(81, 100)) returns 100 values in (9, 10).

    The result must be a numpy array, and its first two draws must differ.
    """
    transformed = TransformedPrior(np.sqrt, Uniform(81, 100))
    sample = transformed.sample(100)
    self.assertEqual(len(sample), 100)
    # assertIsInstance reports the offending type on failure.
    self.assertIsInstance(sample, np.ndarray)
    self.assertTrue((9 < sample).all())
    self.assertTrue((sample < 10).all())
    # Guards against a single draw being repeated across the array.
    self.assertNotEqual(sample[0], sample[1])
def test_sample_multiple_priors_with_size(self):
    """sample(100) on max(Uniform(0, 1), Uniform(2, 3)) returns 100 values in (2, 3).

    The result must be a numpy array, and its first two draws must differ.
    """
    base_prior = [Uniform(0, 1), Uniform(2, 3)]
    transformed = TransformedPrior(np.maximum, base_prior)
    sample = transformed.sample(100)
    self.assertEqual(len(sample), 100)
    # assertIsInstance reports the offending type on failure.
    self.assertIsInstance(sample, np.ndarray)
    self.assertTrue((2 < sample).all())
    self.assertTrue((sample < 3).all())
    # Guards against a single draw being repeated across the array.
    self.assertNotEqual(sample[0], sample[1])
def test_hierarchical_transformed_prior(self):
    """A TransformedPrior can be the base of another TransformedPrior.

    sqrt(sqrt(Uniform(10, 20, guess=16))) must have guess 2 and samples
    strictly between sqrt(sqrt(10)) and sqrt(sqrt(20)).
    """
    base_prior = Uniform(10, 20, guess=16)
    transform = TransformedPrior(np.sqrt, base_prior)
    double = TransformedPrior(np.sqrt, transform)
    samples = double.sample(100)
    self.assertEqual(double.guess, 2)
    # Vectorized ndarray.all(), consistent with the other sampling tests;
    # one assertion per bound so a failure pinpoints which bound broke.
    self.assertTrue((samples < np.sqrt(np.sqrt(20))).all())
    self.assertTrue((samples > np.sqrt(np.sqrt(10))).all())
def test_two_arg_numpy_ufunc_with_both_priors(self):
    """np.maximum applied to two priors equals the explicit TransformedPrior."""
    first, second = Uniform(2, 4), Uniform(1, 3)
    expected = TransformedPrior(np.maximum, [first, second])
    actual = np.maximum(first, second)
    self.assertEqual(actual, expected)
def test_two_arg_numpy_ufunc_with_const(self):
    """np.maximum(prior, 3) equals TransformedPrior(np.maximum, [prior, 3])."""
    prior = Uniform(2, 4)
    expected = TransformedPrior(np.maximum, [prior, 3])
    actual = np.maximum(prior, 3)
    self.assertEqual(actual, expected)
def test_single_arg_numpy_ufunc(self):
    """np.sqrt(prior) equals TransformedPrior(np.sqrt, prior)."""
    prior = Uniform(2, 3)
    expected = TransformedPrior(np.sqrt, prior)
    actual = np.sqrt(prior)
    self.assertEqual(actual, expected)
def test_guess_with_multiple_priors(self):
    """np.maximum over priors guessing 4 and 1 has guess 4."""
    priors = [
        Uniform(0, 10, guess=4),
        Uniform(0, 2, guess=1),
    ]
    combined = TransformedPrior(np.maximum, priors)
    self.assertEqual(combined.guess, 4)
def test_guess_with_single_prior(self):
    """sqrt of a prior guessing 4 has guess 2."""
    prior = Uniform(0, 10, guess=4)
    rooted = TransformedPrior(np.sqrt, prior)
    self.assertEqual(rooted.guess, 2)
def test_sample_multiple_priors_once(self):
    """sample() without a size on max(Uniform(0, 1), Uniform(2, 3)) lies in (2, 3)."""
    priors = [Uniform(0, 1), Uniform(2, 3)]
    combined = TransformedPrior(np.maximum, priors)
    draw = combined.sample()
    lower, upper = 2, 3
    self.assertTrue(lower < draw < upper)
def test_sample_single_prior_once(self):
    """sample() without a size on sqrt(Uniform(81, 100)) lies in (9, 10)."""
    rooted = TransformedPrior(np.sqrt, Uniform(81, 100))
    draw = rooted.sample()
    lower, upper = 9, 10
    self.assertTrue(lower < draw < upper)
def test_single_base_prior_becomes_tuple(self):
    """A lone base prior is stored as a one-element tuple in base_prior."""
    prior = Uniform(0, 2)
    wrapped = TransformedPrior(np.sqrt, prior)
    expected = (prior,)
    self.assertEqual(wrapped.base_prior, expected)
def transformed_prior(transformation, base_priors):
    """Wrap ``transformation`` in a TransformedPrior if any argument is a Prior.

    Parameters
    ----------
    transformation : callable
        Function to apply to the base priors (e.g. ``np.maximum``).
    base_priors : sequence
        Positional arguments for ``transformation``; may mix Prior
        instances and plain values. It is iterated more than once, so it
        should be a list or tuple rather than a one-shot iterator.

    Returns
    -------
    TransformedPrior or object
        ``TransformedPrior(transformation, base_priors)`` when at least one
        element is a Prior; otherwise the direct result of
        ``transformation(*base_priors)``.
    """
    # Generator, not a list: stops at the first Prior instead of
    # materializing a throwaway list of booleans.
    if any(isinstance(bp, Prior) for bp in base_priors):
        return TransformedPrior(transformation, base_priors)
    return transformation(*base_priors)