Exemplo n.º 1
0
    def test_dist_normal(self):
        # Standard normal N(0, 1): check the sample shape, the analytic and
        # empirical mean/stddev, and the log-density evaluated at the mean.
        # NOTE(review): util.debug is given variable *names* as strings and
        # presumably looks them up in this frame, so the dist_* locals below
        # must keep exactly these names.
        dist_sample_shape_correct = [1]
        dist_means_correct = [0]
        dist_stddevs_correct = [1]
        # log N(0 | 0, 1) = -0.5 * log(2 * pi) ~= -0.918939
        dist_log_probs_correct = [-0.918939]

        dist = Normal(dist_means_correct, dist_stddevs_correct)
        # The shape-probe sample is drawn before the empirical batch.
        dist_sample_shape = list(dist.sample().size())
        draws = [dist.sample() for _ in range(empirical_samples)]
        dist_empirical = Empirical(draws)

        dist_means = util.to_numpy(dist.mean)
        dist_means_empirical = util.to_numpy(dist_empirical.mean)
        dist_stddevs = util.to_numpy(dist.stddev)
        dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
        dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

        util.debug(
            'dist_sample_shape',
            'dist_sample_shape_correct',
            'dist_means',
            'dist_means_empirical',
            'dist_means_correct',
            'dist_stddevs',
            'dist_stddevs_empirical',
            'dist_stddevs_correct',
            'dist_log_probs',
            'dist_log_probs_correct',
        )

        self.assertEqual(dist_sample_shape, dist_sample_shape_correct)

        # (actual, expected) pairs, checked in order with a shared tolerance.
        tolerance = 0.1
        comparisons = (
            (dist_means, dist_means_correct),
            (dist_means_empirical, dist_means_correct),
            (dist_stddevs, dist_stddevs_correct),
            (dist_stddevs_empirical, dist_stddevs_correct),
            (dist_log_probs, dist_log_probs_correct),
        )
        for actual, expected in comparisons:
            self.assertTrue(np.allclose(actual, expected, atol=tolerance))
Exemplo n.º 2
0
    def test_dist_normal_multivariate_batched(self):
        # Batch of two 3-dimensional normals with element-wise means/stddevs:
        # check the sample shape, the analytic and empirical mean/stddev, and
        # the log-density evaluated at the means.
        # NOTE(review): util.debug is given variable *names* as strings and
        # presumably looks them up in this frame, so the dist_* locals below
        # must keep exactly these names.
        dist_sample_shape_correct = [2, 3]
        dist_means_correct = [[0, 2, 0], [2, 0, 2]]
        dist_stddevs_correct = [[1, 3, 1], [3, 1, 3]]
        # Per-element log N(mu | mu, s) = -0.5 * log(2 * pi) - log(s):
        # -0.918939 for s = 1 and -2.01755 for s = 3. The expected values
        # are summed along each row, one entry per batch element.
        # NOTE(review): this assumes Normal.log_prob reduces the last axis.
        dist_log_probs_correct = [[sum([-0.918939, -2.01755, -0.918939])],
                                  [sum([-2.01755, -0.918939, -2.01755])]]

        dist = Normal(dist_means_correct, dist_stddevs_correct)
        # The shape-probe sample is drawn before the empirical batch.
        dist_sample_shape = list(dist.sample().size())
        draws = [dist.sample() for _ in range(empirical_samples)]
        dist_empirical = Empirical(draws)

        dist_means = util.to_numpy(dist.mean)
        dist_means_empirical = util.to_numpy(dist_empirical.mean)
        dist_stddevs = util.to_numpy(dist.stddev)
        dist_stddevs_empirical = util.to_numpy(dist_empirical.stddev)
        dist_log_probs = util.to_numpy(dist.log_prob(dist_means_correct))

        util.debug(
            'dist_sample_shape',
            'dist_sample_shape_correct',
            'dist_means',
            'dist_means_empirical',
            'dist_means_correct',
            'dist_stddevs',
            'dist_stddevs_empirical',
            'dist_stddevs_correct',
            'dist_log_probs',
            'dist_log_probs_correct',
        )

        self.assertEqual(dist_sample_shape, dist_sample_shape_correct)

        # (actual, expected) pairs, checked in order with a shared tolerance.
        tolerance = 0.1
        comparisons = (
            (dist_means, dist_means_correct),
            (dist_means_empirical, dist_means_correct),
            (dist_stddevs, dist_stddevs_correct),
            (dist_stddevs_empirical, dist_stddevs_correct),
            (dist_log_probs, dist_log_probs_correct),
        )
        for actual, expected in comparisons:
            self.assertTrue(np.allclose(actual, expected, atol=tolerance))
Exemplo n.º 3
0
 def forward(self, observations=(8, 9)):
     """Sample a latent mean ``mu`` and condition it on observed values.

     ``mu`` is drawn via ``pyprob.sample`` from
     ``Normal(self.prior_mean, self.prior_stddev)``. For each observation, in
     order, ``Normal(mu, self.likelihood_stddev).log_prob(obs)`` is passed to
     ``pyprob.factor``.

     Args:
         observations: Values to condition on. Defaults to ``(8, 9)``, which
             reproduces the original two ``factor`` calls exactly.

     Returns:
         The sampled ``mu``.
     """
     mu = pyprob.sample(Normal(self.prior_mean, self.prior_stddev))
     likelihood = Normal(mu, self.likelihood_stddev)
     # Call log_prob directly instead of binding a lambda to a name (E731).
     for obs in observations:
         pyprob.factor(log_prob=likelihood.log_prob(obs))
     return mu