예제 #1
0
    def test_dist_truncated_normal_batched(self):
        """Compare a batched TruncatedNormal against reference values.

        Checks the sample shape, the mean/stddev reported by the distribution,
        the mean/stddev of an Empirical built from ``empirical_samples`` draws,
        and the log-probability evaluated at the non-truncated means. All
        numeric comparisons use ``np.allclose`` with ``atol=0.1``.
        """
        dist_sample_shape_correct = [2, 1]
        dist_means_non_truncated_correct = [[0], [2]]
        dist_stddevs_non_truncated_correct = [[1], [3]]
        dist_means_correct = [[0], [0.901189]]
        dist_stddevs_correct = [[0.53956], [1.95118]]
        dist_lows_correct = [[-1], [-4]]
        dist_highs_correct = [[1], [4]]
        dist_log_probs_correct = [[-0.537223], [-1.69563]]

        dist = TruncatedNormal(
            dist_means_non_truncated_correct,
            dist_stddevs_non_truncated_correct,
            dist_lows_correct,
            dist_highs_correct,
        )
        dist_sample_shape = list(dist.sample().size())
        dist_empirical = Empirical(
            [dist.sample() for _ in range(empirical_samples)])
        dist_means = util.to_numpy(dist.mean)
        dist_means_empirical = util.to_numpy(dist_empirical.mean)
        dist_stddevs = util.to_numpy(dist.stddev)
        dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
        dist_log_probs = util.to_numpy(
            dist.log_prob(dist_means_non_truncated_correct))

        # util.debug is given variable names as strings (presumably resolved
        # in this frame), so the locals above keep these exact names.
        util.debug(
            'dist_sample_shape', 'dist_sample_shape_correct',
            'dist_means', 'dist_means_empirical', 'dist_means_correct',
            'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct',
            'dist_log_probs', 'dist_log_probs_correct')

        self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
        comparisons = (
            (dist_means, dist_means_correct),
            (dist_means_empirical, dist_means_correct),
            (dist_stddevs, dist_stddevs_correct),
            (dist_stddevs_empirical, dist_stddevs_correct),
            (dist_log_probs, dist_log_probs_correct),
        )
        for actual, expected in comparisons:
            self.assertTrue(np.allclose(actual, expected, atol=0.1))
예제 #2
0
    def test_dist_mixture(self):
        """Compare a three-component Normal Mixture against reference values.

        Checks the sample shape, the mean/stddev reported by the mixture, the
        mean/stddev of an Empirical built from ``empirical_samples`` draws,
        and the log-probability evaluated at the reference mean. All numeric
        comparisons use ``np.allclose`` with ``atol=0.1``.
        """
        dist_sample_shape_correct = [1]
        dist_1 = Normal(0, 0.1)
        dist_2 = Normal(2, 0.1)
        dist_3 = Normal(3, 0.1)
        dist_means_correct = [0.7]
        dist_stddevs_correct = [1.10454]
        dist_log_probs_correct = [-23.473]

        components = [dist_1, dist_2, dist_3]
        dist = Mixture(components, probs=[0.7, 0.2, 0.1])

        dist_sample_shape = list(dist.sample().size())
        dist_empirical = Empirical(
            [dist.sample() for _ in range(empirical_samples)])
        dist_means = util.to_numpy(dist.mean)
        dist_means_empirical = util.to_numpy(dist_empirical.mean)
        dist_stddevs = util.to_numpy(dist.stddev)
        dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
        dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

        # util.debug is given variable names as strings (presumably resolved
        # in this frame), so the locals above keep these exact names.
        util.debug(
            'dist_sample_shape', 'dist_sample_shape_correct',
            'dist_means', 'dist_means_empirical', 'dist_means_correct',
            'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct',
            'dist_log_probs', 'dist_log_probs_correct')

        self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
        comparisons = (
            (dist_means, dist_means_correct),
            (dist_means_empirical, dist_means_correct),
            (dist_stddevs, dist_stddevs_correct),
            (dist_stddevs_empirical, dist_stddevs_correct),
            (dist_log_probs, dist_log_probs_correct),
        )
        for actual, expected in comparisons:
            self.assertTrue(np.allclose(actual, expected, atol=0.1))
예제 #3
0
    def test_dist_mixture_batched(self):
        """Compare a batched three-component Mixture against reference values.

        Each component is a Normal built from nested-list parameters and the
        mixture weights are given per batch row. Checks the sample shape, the
        reported and empirical mean/stddev, and the log-probability at the
        reference means. All numeric comparisons use ``np.allclose`` with
        ``atol=0.1``.
        """
        dist_sample_shape_correct = [2, 1]
        dist_1 = Normal([[0], [1]], [[0.1], [1]])
        dist_2 = Normal([[2], [5]], [[0.1], [1]])
        dist_3 = Normal([[3], [10]], [[0.1], [1]])
        dist_means_correct = [[0.7], [8.1]]
        dist_stddevs_correct = [[1.10454], [3.23883]]
        dist_log_probs_correct = [[-23.473], [-3.06649]]

        components = [dist_1, dist_2, dist_3]
        dist = Mixture(
            components,
            probs=[[0.7, 0.2, 0.1], [0.1, 0.2, 0.7]],
        )

        dist_sample_shape = list(dist.sample().size())
        dist_empirical = Empirical(
            [dist.sample() for _ in range(empirical_samples)])
        dist_means = util.to_numpy(dist.mean)
        dist_means_empirical = util.to_numpy(dist_empirical.mean)
        dist_stddevs = util.to_numpy(dist.stddev)
        dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
        dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

        # util.debug is given variable names as strings (presumably resolved
        # in this frame), so the locals above keep these exact names.
        util.debug(
            'dist_sample_shape', 'dist_sample_shape_correct',
            'dist_means', 'dist_means_empirical', 'dist_means_correct',
            'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct',
            'dist_log_probs', 'dist_log_probs_correct')

        self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
        comparisons = (
            (dist_means, dist_means_correct),
            (dist_means_empirical, dist_means_correct),
            (dist_stddevs, dist_stddevs_correct),
            (dist_stddevs_empirical, dist_stddevs_correct),
            (dist_log_probs, dist_log_probs_correct),
        )
        for actual, expected in comparisons:
            self.assertTrue(np.allclose(actual, expected, atol=0.1))
예제 #4
0
    def test_dist_kumaraswamy_batched(self):
        """Compare a batched Kumaraswamy distribution against reference values.

        Checks the sample shape, the reported and empirical mean/stddev, and
        the log-probability at ``dist_values``. The last two rows evaluate at
        0 and 1, where the expected log-probability is ``-inf``. All numeric
        comparisons use ``np.allclose`` with ``atol=0.1``.
        """
        dist_sample_shape_correct = [4, 1]
        dist_shape1s_correct = [[0.5], [7.5], [7.5], [7.5]]
        dist_shape2s_correct = [[0.75], [2.5], [2.5], [2.5]]
        dist_means_correct = [[0.415584], [0.807999], [0.807999], [0.807999]]
        dist_stddevs_correct = [[0.327509], [0.111605], [0.111605], [0.111605]]
        dist_values = [[0.415584], [0.807999], [0.], [1.]]
        dist_log_probs_correct = [
            [-0.283125],
            [1.20676],
            [float('-inf')],
            [float('-inf')],
        ]

        dist = Kumaraswamy(dist_shape1s_correct, dist_shape2s_correct)
        dist_sample_shape = list(dist.sample().size())
        dist_empirical = Empirical(
            [dist.sample() for _ in range(empirical_samples)])
        dist_means = util.to_numpy(dist.mean)
        dist_means_empirical = util.to_numpy(dist_empirical.mean)
        dist_stddevs = util.to_numpy(dist.stddev)
        dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
        dist_log_probs = util.to_numpy(dist.log_prob(dist_values))

        # util.debug is given variable names as strings (presumably resolved
        # in this frame), so the locals above keep these exact names.
        util.debug(
            'dist_sample_shape', 'dist_sample_shape_correct',
            'dist_means', 'dist_means_empirical', 'dist_means_correct',
            'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct',
            'dist_values', 'dist_log_probs', 'dist_log_probs_correct')

        self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
        comparisons = (
            (dist_means, dist_means_correct),
            (dist_means_empirical, dist_means_correct),
            (dist_stddevs, dist_stddevs_correct),
            (dist_stddevs_empirical, dist_stddevs_correct),
            (dist_log_probs, dist_log_probs_correct),
        )
        for actual, expected in comparisons:
            self.assertTrue(np.allclose(actual, expected, atol=0.1))
예제 #5
0
    def test_dist_normal(self):
        """Compare a Normal(0, 1) distribution against reference values.

        Checks the sample shape, the reported and empirical mean/stddev, and
        the log-probability evaluated at the reference mean. All numeric
        comparisons use ``np.allclose`` with ``atol=0.1``.
        """
        dist_sample_shape_correct = [1]
        dist_means_correct = [0]
        dist_stddevs_correct = [1]
        dist_log_probs_correct = [-0.918939]

        dist = Normal(dist_means_correct, dist_stddevs_correct)
        dist_sample_shape = list(dist.sample().size())
        dist_empirical = Empirical([dist.sample() for _ in range(empirical_samples)])
        dist_means = util.to_numpy(dist.mean)
        dist_means_empirical = util.to_numpy(dist_empirical.mean)
        dist_stddevs = util.to_numpy(dist.stddev)
        dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
        dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

        # util.debug is given variable names as strings (presumably resolved
        # in this frame), so the locals above keep these exact names.
        util.debug(
            'dist_sample_shape', 'dist_sample_shape_correct',
            'dist_means', 'dist_means_empirical', 'dist_means_correct',
            'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct',
            'dist_log_probs', 'dist_log_probs_correct')

        self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
        comparisons = (
            (dist_means, dist_means_correct),
            (dist_means_empirical, dist_means_correct),
            (dist_stddevs, dist_stddevs_correct),
            (dist_stddevs_empirical, dist_stddevs_correct),
            (dist_log_probs, dist_log_probs_correct),
        )
        for actual, expected in comparisons:
            self.assertTrue(np.allclose(actual, expected, atol=0.1))
예제 #6
0
    def test_inference_branching_random_walk_metropolis_hastings(self):
        """Check the RANDOM_WALK_METROPOLIS_HASTINGS posterior for ``obs=6``.

        The posterior from ``self._model`` and the one from
        ``self.true_posterior()`` are both converted to categoricals
        (``max_val=40``); the test passes when their KL divergence is below
        0.75. The inference wall-clock time and the divergence are reported
        through the module-level ``add_random_walk_metropolis_hastings_*``
        hooks.
        """
        samples = importance_sampling_samples
        reference_empirical = self.true_posterior()
        posterior_correct = util.empirical_to_categorical(
            reference_empirical, max_val=40)

        # Time only the inference call and its conversion to a categorical.
        start = time.time()
        inferred_empirical = self._model.posterior_distribution(
            samples,
            inference_engine=InferenceEngine.RANDOM_WALK_METROPOLIS_HASTINGS,
            observe={'obs': 6},
        )
        posterior = util.empirical_to_categorical(inferred_empirical, max_val=40)
        elapsed = time.time() - start
        add_random_walk_metropolis_hastings_duration(elapsed)

        # These two locals are only consumed by util.eval_print below, which
        # is given variable names as strings (presumably resolved in this
        # frame), so they must keep these exact names.
        posterior_probs = util.to_numpy(posterior._probs)
        posterior_probs_correct = util.to_numpy(posterior_correct._probs)
        divergence = pyprob.distributions.Distribution.kl_divergence(
            posterior, posterior_correct)
        kl_divergence = float(divergence)

        util.eval_print(
            'samples',
            'posterior_probs',
            'posterior_probs_correct',
            'kl_divergence',
        )
        add_random_walk_metropolis_hastings_kl_divergence(kl_divergence)

        self.assertLess(kl_divergence, 0.75)
예제 #7
0
    def test_dist_kumaraswamy(self):
        """Compare a Kumaraswamy(2, 5) distribution against reference values.

        Checks the sample shape, the reported and empirical mean/stddev, and
        the log-probability evaluated at the reference mean. All numeric
        comparisons use ``np.allclose`` with ``atol=0.1``.
        """
        dist_sample_shape_correct = [1]
        dist_shape1s_correct = [2]
        dist_shape2s_correct = [5]
        dist_means_correct = [0.369408]
        dist_stddevs_correct = [0.173793]
        dist_log_probs_correct = [0.719861]

        dist = Kumaraswamy(dist_shape1s_correct, dist_shape2s_correct)
        dist_sample_shape = list(dist.sample().size())
        dist_empirical = Empirical(
            [dist.sample() for _ in range(empirical_samples)])
        dist_means = util.to_numpy(dist.mean)
        dist_means_empirical = util.to_numpy(dist_empirical.mean)
        dist_stddevs = util.to_numpy(dist.stddev)
        dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
        dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

        # util.debug is given variable names as strings (presumably resolved
        # in this frame), so the locals above keep these exact names.
        util.debug(
            'dist_sample_shape', 'dist_sample_shape_correct',
            'dist_means', 'dist_means_empirical', 'dist_means_correct',
            'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct',
            'dist_log_probs', 'dist_log_probs_correct')

        self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
        comparisons = (
            (dist_means, dist_means_correct),
            (dist_means_empirical, dist_means_correct),
            (dist_stddevs, dist_stddevs_correct),
            (dist_stddevs_empirical, dist_stddevs_correct),
            (dist_log_probs, dist_log_probs_correct),
        )
        for actual, expected in comparisons:
            self.assertTrue(np.allclose(actual, expected, atol=0.1))
예제 #8
0
    def test_dist_normal_multivariate_batched(self):
        """Compare a 2x3 batched Normal against reference values.

        Checks the sample shape, the reported and empirical mean/stddev, and
        the log-probability at the reference means. The expected
        log-probability for each row is the sum of three per-element
        reference values. All numeric comparisons use ``np.allclose`` with
        ``atol=0.1``.
        """
        dist_sample_shape_correct = [2, 3]
        dist_means_correct = [[0, 2, 0], [2, 0, 2]]
        dist_stddevs_correct = [[1, 3, 1], [3, 1, 3]]
        dist_log_probs_correct = [
            [sum([-0.918939, -2.01755, -0.918939])],
            [sum([-2.01755, -0.918939, -2.01755])],
        ]

        dist = Normal(dist_means_correct, dist_stddevs_correct)
        dist_sample_shape = list(dist.sample().size())
        dist_empirical = Empirical([dist.sample() for _ in range(empirical_samples)])
        dist_means = util.to_numpy(dist.mean)
        dist_means_empirical = util.to_numpy(dist_empirical.mean)
        dist_stddevs = util.to_numpy(dist.stddev)
        dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
        dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

        # util.debug is given variable names as strings (presumably resolved
        # in this frame), so the locals above keep these exact names.
        util.debug(
            'dist_sample_shape', 'dist_sample_shape_correct',
            'dist_means', 'dist_means_empirical', 'dist_means_correct',
            'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct',
            'dist_log_probs', 'dist_log_probs_correct')

        self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
        comparisons = (
            (dist_means, dist_means_correct),
            (dist_means_empirical, dist_means_correct),
            (dist_stddevs, dist_stddevs_correct),
            (dist_stddevs_empirical, dist_stddevs_correct),
            (dist_log_probs, dist_log_probs_correct),
        )
        for actual, expected in comparisons:
            self.assertTrue(np.allclose(actual, expected, atol=0.1))
예제 #9
0
    def test_dist_empirical_resample(self):
        """Check mean/stddev of an Empirical after resampling to half size.

        An Empirical is built from ``empirical_samples`` draws of
        Normal(2, 5), resampled to ``int(empirical_samples/2)`` values, and
        its mean/stddev are compared with the Normal's parameters using
        ``np.allclose`` with ``atol=0.25``.
        """
        dist_means_correct = [2]
        dist_stddevs_correct = [5]

        dist = Normal(dist_means_correct, dist_stddevs_correct)
        draws = [dist.sample() for _ in range(empirical_samples)]
        resample_count = int(empirical_samples/2)
        dist_empirical = Empirical(draws)
        dist_empirical = dist_empirical.resample(resample_count)
        dist_means_empirical = util.to_numpy(dist_empirical.mean)
        dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)

        # util.debug is given variable names as strings (presumably resolved
        # in this frame), so the locals above keep these exact names.
        util.debug(
            'dist_means_empirical', 'dist_means_correct',
            'dist_stddevs_empirical', 'dist_stddevs_correct')

        comparisons = (
            (dist_means_empirical, dist_means_correct),
            (dist_stddevs_empirical, dist_stddevs_correct),
        )
        for actual, expected in comparisons:
            self.assertTrue(np.allclose(actual, expected, atol=0.25))
예제 #10
0
    def test_inference_remote_branching_importance_sampling(self):
        """Check the posterior from ``posterior_results`` for ``obs=6``.

        The posterior from ``self._model`` and the one from
        ``self.true_posterior()`` are both converted to categoricals
        (``max_val=40``); the test passes when their KL divergence is below
        0.75. The inference wall-clock time and the divergence are reported
        through the module-level ``add_importance_sampling_*`` hooks.
        """
        samples = importance_sampling_samples
        reference_empirical = self.true_posterior()
        posterior_correct = util.empirical_to_categorical(
            reference_empirical, max_val=40)

        # Time only the inference call and its conversion to a categorical.
        start = time.time()
        inferred_empirical = self._model.posterior_results(
            samples, observe={'obs': 6})
        posterior = util.empirical_to_categorical(
            inferred_empirical, max_val=40)
        elapsed = time.time() - start
        add_importance_sampling_duration(elapsed)

        # These two locals are only consumed by util.eval_print below, which
        # is given variable names as strings (presumably resolved in this
        # frame), so they must keep these exact names.
        posterior_probs = util.to_numpy(posterior._probs)
        posterior_probs_correct = util.to_numpy(posterior_correct._probs)
        divergence = pyprob.distributions.Distribution.kl_divergence(
            posterior, posterior_correct)
        kl_divergence = float(divergence)

        util.eval_print(
            'samples',
            'posterior_probs',
            'posterior_probs_correct',
            'kl_divergence',
        )
        add_importance_sampling_kl_divergence(kl_divergence)

        self.assertLess(kl_divergence, 0.75)
예제 #11
0
    def test_dist_uniform(self):
        """Compare a Uniform(0, 1) distribution against reference values.

        Checks the sample shape, the ``low``/``high`` attributes, the reported
        and empirical mean/stddev, and the log-probability evaluated at the
        reference mean. All numeric comparisons use ``np.allclose`` with
        ``atol=0.1``.
        """
        dist_sample_shape_correct = [1]
        dist_means_correct = [0.5]
        dist_stddevs_correct = [0.288675]
        dist_lows_correct = [0]
        dist_highs_correct = [1]
        dist_log_probs_correct = [0]

        dist = Uniform(dist_lows_correct, dist_highs_correct)
        dist_sample_shape = list(dist.sample().size())
        dist_empirical = Empirical(
            [dist.sample() for _ in range(empirical_samples)])
        dist_lows = util.to_numpy(dist.low)
        dist_highs = util.to_numpy(dist.high)
        dist_means = util.to_numpy(dist.mean)
        dist_means_empirical = util.to_numpy(dist_empirical.mean)
        dist_stddevs = util.to_numpy(dist.stddev)
        dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
        dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

        # util.debug is given variable names as strings (presumably resolved
        # in this frame), so the locals above keep these exact names.
        util.debug(
            'dist_sample_shape', 'dist_sample_shape_correct',
            'dist_lows', 'dist_lows_correct',
            'dist_highs', 'dist_highs_correct',
            'dist_means', 'dist_means_empirical', 'dist_means_correct',
            'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct',
            'dist_log_probs', 'dist_log_probs_correct')

        self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
        comparisons = (
            (dist_means, dist_means_correct),
            (dist_means_empirical, dist_means_correct),
            (dist_stddevs, dist_stddevs_correct),
            (dist_stddevs_empirical, dist_stddevs_correct),
            (dist_lows, dist_lows_correct),
            (dist_highs, dist_highs_correct),
            (dist_log_probs, dist_log_probs_correct),
        )
        for actual, expected in comparisons:
            self.assertTrue(np.allclose(actual, expected, atol=0.1))
예제 #12
0
    def test_dist_uniform_batched(self):
        """Compare a batched Uniform distribution against reference values.

        Checks the sample shape, the ``low``/``high`` attributes, the reported
        and empirical mean/stddev, and the log-probability at ``dist_values``.
        The last two rows evaluate at the bounds 0 and 1, where the expected
        log-probability is ``-inf``. All numeric comparisons use
        ``np.allclose`` with ``atol=0.1``.
        """
        dist_sample_shape_correct = [4, 1]
        dist_means_correct = [[0.5], [7.5], [0.5], [0.5]]
        dist_stddevs_correct = [[0.288675], [1.44338], [0.288675], [0.288675]]
        dist_lows_correct = [[0], [5], [0], [0]]
        dist_highs_correct = [[1], [10], [1], [1]]
        dist_values = [[0.5], [7.5], [0], [1]]
        dist_log_probs_correct = [
            [0],
            [-1.60944],
            [float('-inf')],
            [float('-inf')],
        ]

        dist = Uniform(dist_lows_correct, dist_highs_correct)
        dist_sample_shape = list(dist.sample().size())
        dist_empirical = Empirical(
            [dist.sample() for _ in range(empirical_samples)])
        dist_lows = util.to_numpy(dist.low)
        dist_highs = util.to_numpy(dist.high)
        dist_means = util.to_numpy(dist.mean)
        dist_means_empirical = util.to_numpy(dist_empirical.mean)
        dist_stddevs = util.to_numpy(dist.stddev)
        dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
        dist_log_probs = util.to_numpy(dist.log_prob(dist_values))

        # util.debug is given variable names as strings (presumably resolved
        # in this frame), so the locals above keep these exact names.
        util.debug(
            'dist_sample_shape', 'dist_sample_shape_correct',
            'dist_lows', 'dist_lows_correct',
            'dist_highs', 'dist_highs_correct',
            'dist_means', 'dist_means_empirical', 'dist_means_correct',
            'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct',
            'dist_values', 'dist_log_probs', 'dist_log_probs_correct')

        self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
        comparisons = (
            (dist_means, dist_means_correct),
            (dist_means_empirical, dist_means_correct),
            (dist_stddevs, dist_stddevs_correct),
            (dist_stddevs_empirical, dist_stddevs_correct),
            (dist_lows, dist_lows_correct),
            (dist_highs, dist_highs_correct),
            (dist_log_probs, dist_log_probs_correct),
        )
        for actual, expected in comparisons:
            self.assertTrue(np.allclose(actual, expected, atol=0.1))
예제 #13
0
    def test_model_remote_branching_importance_sampling(self):
        """Check the default-engine posterior for observation 6.

        The posterior from ``self._model.posterior_distribution`` and the one
        from ``self.true_posterior(observation)`` are both converted to
        categoricals (``max_val=40``); the test passes when
        ``util.kl_divergence_categorical`` of the two is below 0.25.
        """
        observation = 6
        reference_empirical = self.true_posterior(observation)
        posterior_correct = util.empirical_to_categorical(
            reference_empirical, max_val=40)

        # NOTE(review): `samples` is not defined in this method; presumably a
        # module-level constant - confirm.
        inferred_empirical = self._model.posterior_distribution(
            samples, observation=observation)
        posterior = util.empirical_to_categorical(
            inferred_empirical, max_val=40)

        # These two locals are only consumed by util.debug below, which is
        # given variable names as strings (presumably resolved in this frame),
        # so they must keep these exact names.
        posterior_probs = util.to_numpy(posterior._probs[0])
        posterior_probs_correct = util.to_numpy(posterior_correct._probs[0])

        divergence = util.kl_divergence_categorical(
            posterior_correct, posterior)
        kl_divergence = float(divergence)

        util.debug(
            'samples',
            'posterior_probs',
            'posterior_probs_correct',
            'kl_divergence',
        )

        self.assertLess(kl_divergence, 0.25)
예제 #14
0
    def test_dist_truncated_normal_clamped_batched(self):
        """Compare a clamped, batched TruncatedNormal against reference values.

        The distribution is built with ``clamp_mean_between_low_high=True``.
        Checks the sample shape, the stored ``_mean_non_truncated``, the
        reported and empirical mean/stddev, and the log-probability at
        ``dist_log_prob_arguments``. All numeric comparisons use
        ``np.allclose`` with ``atol=0.1``.
        """
        dist_sample_shape_correct = [2, 1]
        dist_means_non_truncated = [[0], [2]]
        dist_means_non_truncated_correct = [[0.5], [1]]
        dist_stddevs_non_truncated = [[1], [3]]
        dist_means_correct = [[0.744836], [-0.986679]]
        dist_stddevs_correct = [[0.143681], [1.32416]]
        dist_lows_correct = [[0.5], [-4]]
        dist_highs_correct = [[1], [1]]
        dist_log_prob_arguments = [[0.75], [-3]]
        dist_log_probs_correct = [[0.702875], [-2.11283]]

        dist = TruncatedNormal(
            dist_means_non_truncated,
            dist_stddevs_non_truncated,
            dist_lows_correct,
            dist_highs_correct,
            clamp_mean_between_low_high=True,
        )
        dist_sample_shape = list(dist.sample().size())
        dist_empirical = Empirical(
            [dist.sample() for _ in range(empirical_samples)])
        # Rebinds the input name to the value stored on the distribution.
        dist_means_non_truncated = util.to_numpy(dist._mean_non_truncated)
        dist_means = util.to_numpy(dist.mean)
        dist_means_empirical = util.to_numpy(dist_empirical.mean)
        dist_stddevs = util.to_numpy(dist.stddev)
        dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
        dist_log_probs = util.to_numpy(dist.log_prob(dist_log_prob_arguments))

        # util.debug is given variable names as strings (presumably resolved
        # in this frame), so the locals above keep these exact names.
        util.debug(
            'dist_sample_shape', 'dist_sample_shape_correct',
            'dist_means_non_truncated', 'dist_means_non_truncated_correct',
            'dist_means', 'dist_means_empirical', 'dist_means_correct',
            'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct',
            'dist_log_probs', 'dist_log_probs_correct')

        self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
        comparisons = (
            (dist_means_non_truncated, dist_means_non_truncated_correct),
            (dist_means, dist_means_correct),
            (dist_means_empirical, dist_means_correct),
            (dist_stddevs, dist_stddevs_correct),
            (dist_stddevs_empirical, dist_stddevs_correct),
            (dist_log_probs, dist_log_probs_correct),
        )
        for actual, expected in comparisons:
            self.assertTrue(np.allclose(actual, expected, atol=0.1))
예제 #15
0
    def test_dist_poisson_batched(self):
        """Compare a batched Poisson distribution against reference values.

        Checks the sample shape, the ``rate`` attribute, the reported and
        empirical mean/stddev, and the log-probability evaluated at the
        reference means. All numeric comparisons use ``np.allclose`` with
        ``atol=0.1``.
        """
        dist_sample_shape_correct = [2, 1]
        dist_means_correct = [[4], [100]]
        dist_stddevs_correct = [[math.sqrt(4)], [math.sqrt(100)]]
        dist_rates_correct = [[4], [100]]
        dist_log_probs_correct = [[-1.63288], [-3.22236]]

        dist = Poisson(dist_rates_correct)
        dist_sample_shape = list(dist.sample().size())
        dist_empirical = Empirical(
            [dist.sample() for _ in range(empirical_samples)])
        dist_rates = util.to_numpy(dist.rate)
        dist_means = util.to_numpy(dist.mean)
        dist_means_empirical = util.to_numpy(dist_empirical.mean)
        dist_stddevs = util.to_numpy(dist.stddev)
        dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
        dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

        # util.debug is given variable names as strings (presumably resolved
        # in this frame), so the locals above keep these exact names.
        util.debug(
            'dist_sample_shape', 'dist_sample_shape_correct',
            'dist_rates', 'dist_rates_correct',
            'dist_means', 'dist_means_empirical', 'dist_means_correct',
            'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct',
            'dist_log_probs', 'dist_log_probs_correct')

        self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
        comparisons = (
            (dist_means, dist_means_correct),
            (dist_means_empirical, dist_means_correct),
            (dist_stddevs, dist_stddevs_correct),
            (dist_stddevs_empirical, dist_stddevs_correct),
            (dist_rates, dist_rates_correct),
            (dist_log_probs, dist_log_probs_correct),
        )
        for actual, expected in comparisons:
            self.assertTrue(np.allclose(actual, expected, atol=0.1))
예제 #16
0
    def test_dist_poisson_multivariate_batched(self):
        """Compare a 2x3 batched Poisson against reference values.

        Checks the sample shape, the ``rate`` attribute, the reported and
        empirical mean/stddev, and the log-probability at the reference means
        (each row's expected value is the sum of three per-element reference
        values). The empirical comparisons use ``atol=0.25``; the others use
        ``atol=0.1``.
        """
        dist_sample_shape_correct = [2, 3]
        dist_means_correct = [[1, 2, 15], [100, 200, 300]]
        dist_stddevs_correct = [
            [math.sqrt(1), math.sqrt(2), math.sqrt(15)],
            [math.sqrt(100), math.sqrt(200), math.sqrt(300)],
        ]
        dist_rates_correct = [[1, 2, 15], [100, 200, 300]]
        dist_log_probs_correct = [
            [sum([-1, -1.30685, -2.27852])],
            [sum([-3.22236, -3.56851, -3.77110])],
        ]

        dist = Poisson(dist_rates_correct)
        dist_sample_shape = list(dist.sample().size())
        dist_empirical = Empirical(
            [dist.sample() for _ in range(empirical_samples)])
        dist_rates = util.to_numpy(dist.rate)
        dist_means = util.to_numpy(dist.mean)
        dist_means_empirical = util.to_numpy(dist_empirical.mean)
        dist_stddevs = util.to_numpy(dist.stddev)
        dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
        dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

        # util.debug is given variable names as strings (presumably resolved
        # in this frame), so the locals above keep these exact names.
        util.debug(
            'dist_sample_shape', 'dist_sample_shape_correct',
            'dist_rates', 'dist_rates_correct',
            'dist_means', 'dist_means_empirical', 'dist_means_correct',
            'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct',
            'dist_log_probs', 'dist_log_probs_correct')

        self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
        # (actual, expected, absolute tolerance)
        comparisons = (
            (dist_means, dist_means_correct, 0.1),
            (dist_means_empirical, dist_means_correct, 0.25),
            (dist_stddevs, dist_stddevs_correct, 0.1),
            (dist_stddevs_empirical, dist_stddevs_correct, 0.25),
            (dist_rates, dist_rates_correct, 0.1),
            (dist_log_probs, dist_log_probs_correct, 0.1),
        )
        for actual, expected, tolerance in comparisons:
            self.assertTrue(np.allclose(actual, expected, atol=tolerance))
예제 #17
0
    def test_dist_poisson_multivariate_from_flat_params(self):
        """Compare a Poisson built from a flat rate list with 1x3 references.

        The distribution is constructed from ``dist_rates_correct[0]`` (a flat
        list) while the reference values are nested one level deeper. Checks
        the sample shape, the ``rate`` attribute, the reported and empirical
        mean/stddev, and the log-probability at the reference means. All
        numeric comparisons use ``np.allclose`` with ``atol=0.1``.
        """
        dist_sample_shape_correct = [1, 3]
        dist_means_correct = [[1, 2, 15]]
        dist_stddevs_correct = [[math.sqrt(1), math.sqrt(2), math.sqrt(15)]]
        dist_rates_correct = [[1, 2, 15]]
        dist_log_probs_correct = [sum([-1, -1.30685, -2.27852])]

        flat_rates = dist_rates_correct[0]
        dist = Poisson(flat_rates)
        dist_sample_shape = list(dist.sample().size())
        dist_empirical = Empirical(
            [dist.sample() for _ in range(empirical_samples)])
        dist_rates = util.to_numpy(dist.rate)
        dist_means = util.to_numpy(dist.mean)
        dist_means_empirical = util.to_numpy(dist_empirical.mean)
        dist_stddevs = util.to_numpy(dist.stddev)
        dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
        dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

        # util.debug is given variable names as strings (presumably resolved
        # in this frame), so the locals above keep these exact names.
        util.debug(
            'dist_sample_shape', 'dist_sample_shape_correct',
            'dist_rates', 'dist_rates_correct',
            'dist_means', 'dist_means_empirical', 'dist_means_correct',
            'dist_stddevs', 'dist_stddevs_empirical', 'dist_stddevs_correct',
            'dist_log_probs', 'dist_log_probs_correct')

        self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
        comparisons = (
            (dist_means, dist_means_correct),
            (dist_means_empirical, dist_means_correct),
            (dist_stddevs, dist_stddevs_correct),
            (dist_stddevs_empirical, dist_stddevs_correct),
            (dist_rates, dist_rates_correct),
            (dist_log_probs, dist_log_probs_correct),
        )
        for actual, expected in comparisons:
            self.assertTrue(np.allclose(actual, expected, atol=0.1))
예제 #18
0
    def test_dist_categorical(self):
        """Check sample shape and ``log_prob(0)`` of a three-way Categorical.

        The log-probability is compared with the reference value using
        ``np.allclose`` with ``atol=0.1``.
        """
        dist_sample_shape_correct = [1]
        dist_log_probs_correct = [-2.30259]

        probabilities = [0.1, 0.2, 0.7]
        dist = Categorical(probabilities)

        dist_sample_shape = list(dist.sample().size())
        dist_log_probs = util.to_numpy(dist.log_prob(0))

        # util.debug is given variable names as strings (presumably resolved
        # in this frame), so the locals above keep these exact names.
        util.debug(
            'dist_sample_shape', 'dist_sample_shape_correct',
            'dist_log_probs', 'dist_log_probs_correct')

        self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
        log_probs_match = np.allclose(
            dist_log_probs, dist_log_probs_correct, atol=0.1)
        self.assertTrue(log_probs_match)
예제 #19
0
    def test_model_remote_branching_random_walk_metropolis_hastings(self):
        """Check the RANDOM_WALK_METROPOLIS_HASTINGS posterior for observation 6.

        The posterior from ``self._model.posterior_distribution`` and the one
        from ``self.true_posterior(observation)`` are both converted to
        categoricals (``max_val=40``); the test passes when
        ``util.kl_divergence_categorical`` of the two is below 0.25.
        """
        observation = 6
        reference_empirical = self.true_posterior(observation)
        posterior_correct = util.empirical_to_categorical(
            reference_empirical, max_val=40)

        # NOTE(review): `samples` is not defined in this method; presumably a
        # module-level constant - confirm.
        engine = pyprob.InferenceEngine.RANDOM_WALK_METROPOLIS_HASTINGS
        inferred_empirical = self._model.posterior_distribution(
            samples,
            observation=observation,
            inference_engine=engine,
        )
        posterior = util.empirical_to_categorical(
            inferred_empirical, max_val=40)

        # These two locals are only consumed by util.debug below, which is
        # given variable names as strings (presumably resolved in this frame),
        # so they must keep these exact names.
        posterior_probs = util.to_numpy(posterior._probs[0])
        posterior_probs_correct = util.to_numpy(posterior_correct._probs[0])

        divergence = util.kl_divergence_categorical(
            posterior_correct, posterior)
        kl_divergence = float(divergence)

        util.debug(
            'samples',
            'posterior_probs',
            'posterior_probs_correct',
            'kl_divergence',
        )

        self.assertLess(kl_divergence, 0.25)
예제 #20
0
    def test_dist_categorical_batched(self):
        """Check sample shape and log-probabilities of a two-row Categorical.

        ``log_prob`` is evaluated at ``[[0, 1]]`` and compared with the
        reference values using ``np.allclose`` with ``atol=0.1``.
        """
        dist_sample_shape_correct = [2]
        dist_log_probs_correct = [[-2.30259], [-0.693147]]

        probabilities = [
            [0.1, 0.2, 0.7],
            [0.2, 0.5, 0.3],
        ]
        dist = Categorical(probabilities)

        dist_sample_shape = list(dist.sample().size())
        dist_log_probs = util.to_numpy(dist.log_prob([[0, 1]]))

        # util.debug is given variable names as strings (presumably resolved
        # in this frame), so the locals above keep these exact names.
        util.debug(
            'dist_sample_shape', 'dist_sample_shape_correct',
            'dist_log_probs', 'dist_log_probs_correct')

        self.assertEqual(dist_sample_shape, dist_sample_shape_correct)
        log_probs_match = np.allclose(
            dist_log_probs, dist_log_probs_correct, atol=0.1)
        self.assertTrue(log_probs_match)
예제 #21
0
    def test_distributions_remote(self):
        """Compare prior sample means of named variables with reference means.

        Draws ``num_samples`` traces from ``self._model.prior``, extracts the
        value of each named variable, and compares its mean with the ``.mean``
        of a locally constructed distribution using ``np.allclose`` with
        ``atol=0.1``. The categorical case differs: its reference is the
        hard-coded value 1.0 and the sampled mean is truncated with ``int()``
        before the comparison.
        """
        num_samples = 4000

        def named_value(name):
            # Build a trace -> value accessor for the variable called `name`.
            return lambda trace: trace.named_variables[name].value

        # Reference means.
        prior_normal_mean_correct = Normal(1.75, 0.5).mean
        prior_uniform_mean_correct = Uniform(1.2, 2.5).mean
        prior_categorical_mean_correct = 1.  # Categorical([0.1, 0.5, 0.4])
        prior_poisson_mean_correct = Poisson(4.0).mean
        prior_bernoulli_mean_correct = Bernoulli(0.2).mean
        prior_beta_mean_correct = Beta(1.2, 2.5).mean
        prior_exponential_mean_correct = Exponential(2.2).mean
        prior_gamma_mean_correct = Gamma(0.5, 1.2).mean
        prior_log_normal_mean_correct = LogNormal(0.5, 0.2).mean
        prior_binomial_mean_correct = Binomial(10, 0.72).mean
        prior_weibull_mean_correct = Weibull(1.1, 0.6).mean

        # Per-variable views of the sampled prior traces.
        prior = self._model.prior(num_samples)
        prior_normal = prior.map(named_value('normal'))
        prior_uniform = prior.map(named_value('uniform'))
        prior_categorical = prior.map(named_value('categorical'))
        prior_poisson = prior.map(named_value('poisson'))
        prior_bernoulli = prior.map(named_value('bernoulli'))
        prior_beta = prior.map(named_value('beta'))
        prior_exponential = prior.map(named_value('exponential'))
        prior_gamma = prior.map(named_value('gamma'))
        prior_log_normal = prior.map(named_value('log_normal'))
        prior_binomial = prior.map(named_value('binomial'))
        prior_weibull = prior.map(named_value('weibull'))

        # Sample means. util.eval_print below is given these variable names as
        # strings (presumably resolved in this frame), so they are kept as-is.
        prior_normal_mean = util.to_numpy(prior_normal.mean)
        prior_uniform_mean = util.to_numpy(prior_uniform.mean)
        prior_categorical_mean = util.to_numpy(int(prior_categorical.mean))
        prior_poisson_mean = util.to_numpy(prior_poisson.mean)
        prior_bernoulli_mean = util.to_numpy(prior_bernoulli.mean)
        prior_beta_mean = util.to_numpy(prior_beta.mean)
        prior_exponential_mean = util.to_numpy(prior_exponential.mean)
        prior_gamma_mean = util.to_numpy(prior_gamma.mean)
        prior_log_normal_mean = util.to_numpy(prior_log_normal.mean)
        prior_binomial_mean = util.to_numpy(prior_binomial.mean)
        prior_weibull_mean = util.to_numpy(prior_weibull.mean)

        util.eval_print(
            'num_samples',
            'prior_normal_mean', 'prior_normal_mean_correct',
            'prior_uniform_mean', 'prior_uniform_mean_correct',
            'prior_categorical_mean', 'prior_categorical_mean_correct',
            'prior_poisson_mean', 'prior_poisson_mean_correct',
            'prior_bernoulli_mean', 'prior_bernoulli_mean_correct',
            'prior_beta_mean', 'prior_beta_mean_correct',
            'prior_exponential_mean', 'prior_exponential_mean_correct',
            'prior_gamma_mean', 'prior_gamma_mean_correct',
            'prior_log_normal_mean', 'prior_log_normal_mean_correct',
            'prior_binomial_mean', 'prior_binomial_mean_correct',
            'prior_weibull_mean', 'prior_weibull_mean_correct')

        comparisons = (
            (prior_normal_mean, prior_normal_mean_correct),
            (prior_uniform_mean, prior_uniform_mean_correct),
            (prior_categorical_mean, prior_categorical_mean_correct),
            (prior_poisson_mean, prior_poisson_mean_correct),
            (prior_bernoulli_mean, prior_bernoulli_mean_correct),
            (prior_beta_mean, prior_beta_mean_correct),
            (prior_exponential_mean, prior_exponential_mean_correct),
            (prior_gamma_mean, prior_gamma_mean_correct),
            (prior_log_normal_mean, prior_log_normal_mean_correct),
            (prior_binomial_mean, prior_binomial_mean_correct),
            (prior_weibull_mean, prior_weibull_mean_correct),
        )
        for actual, expected in comparisons:
            self.assertTrue(np.allclose(actual, expected, atol=0.1))